# Prediction using LSTM, GRU-LSTM, xLSTM

In [1]:
# Import packages
# Core

import time

# Data analysis, preprocessing and math
import pandas as pa
import torch
import torch.nn as nn
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import MinMaxScaler

# Plotting
# Private utilities package


## 3D Tensor Preparation

Pytorch expect a 3D Tensor.

In [2]:
processed = pa.read_parquet(path="../data/model/processed.parquet", engine="fastparquet")
processed.columns
df = processed.copy(deep=True)

In [3]:
df

Unnamed: 0,ISO3_reporter,UNDS_reporter,CNAME_reporter,ISO3_partner,UNDS_partner,CNAME_partner,Year,GDP_reporter,GDP_partner,contig,...,EXPORT,arms,military,trade,descr_trade,financial,travel,other,target_mult,sender_mult
0,AGO,024,Angola,ALB,008,Albania,1988,8.769837e+09,2.051236e+09,False,...,0.00000,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
1,AGO,024,Angola,ALB,008,Albania,1989,1.020178e+10,2.253090e+09,False,...,0.00000,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
2,AGO,024,Angola,ALB,008,Albania,1990,1.122952e+10,2.028554e+09,False,...,0.00000,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
3,AGO,024,Angola,ALB,008,Albania,1991,1.060378e+10,1.099559e+09,False,...,0.00000,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
4,AGO,024,Angola,ALB,008,Albania,1992,8.307811e+09,6.521750e+08,False,...,0.00000,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1134236,ZWE,716,Zimbabwe,ZMB,894,Zambia,2019,2.571741e+10,2.330867e+10,True,...,59552.55737,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
1134237,ZWE,716,Zimbabwe,ZMB,894,Zambia,2020,2.686794e+10,1.813776e+10,True,...,52563.44115,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
1134238,ZWE,716,Zimbabwe,ZMB,894,Zambia,2021,2.724052e+10,2.209642e+10,True,...,60108.67747,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0
1134239,ZWE,716,Zimbabwe,ZMB,894,Zambia,2022,3.278975e+10,2.916378e+10,True,...,90006.02766,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0


In [4]:
print("✅ Data loaded.")
features = ['financial', 'travel', 'other', 'GDP_reporter', 'GDP_partner', 'contig', 'comlang_off', 'colony', 'smctry',
            'distw', 'arms', 'military', 'trade']
target = 'EXPORT'
SEQ_LEN = 10

# Drop NaNs
df = df[features + [target]].dropna()
print(f"✅ Dropped NaNs. Remaining rows: {len(df)}")

# Normalize
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(df[features + [target]])
feature_len = len(features)
print("✅ Data normalized.")


# Create sequences
def create_sequences(data, seq_len, feature_len):
  X, y = [], []
  for i in range(len(data) - seq_len):
    X.append(data[i:i + seq_len, :feature_len])
    y.append(data[i + seq_len, feature_len])  # target is 'EXPORT'
  return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


X, y = create_sequences(data_scaled, SEQ_LEN, feature_len)
print(f"✅ Created sequences. X shape: {X.shape}, y shape: {y.shape}")

# Split (80% train, 20% validation)
split_idx = int(0.8 * len(X))
X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]
print(f"✅ Split into training and validation sets. Train: {X_train.shape[0]}, Val: {X_val.shape[0]}")


# LSTM model
class TradeFlowLSTM(nn.Module):
  def __init__(self, input_size, hidden_size=64):
    super().__init__()
    self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
    self.fc = nn.Linear(hidden_size, 1)

  def forward(self, x):
    _, (hn, _) = self.lstm(x)
    return self.fc(hn[-1]).squeeze()


model = TradeFlowLSTM(input_size=feature_len)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

print("\n🚀 Starting training...")
start_time = time.time()

# Train
for epoch in range(200):
  model.train()
  output = model(X_train)
  loss = loss_fn(output, y_train)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if epoch % 10 == 0:
    elapsed = time.time() - start_time
    print(f"Epoch {epoch:03}: Train Loss = {loss.item():.4f} | Elapsed: {elapsed:.2f}s")

training_time = time.time() - start_time
print(f"\n✅ Training complete in {training_time:.2f} seconds.")

# Evaluate on validation set
print("\n🔍 Evaluating on validation set...")
eval_start = time.time()
model.eval()
with torch.no_grad():
  preds = model(X_val).numpy()
  y_true = y_val.numpy()

# Inverse transform to original scale
restored_preds = []
restored_true = []
for i, (pred, true) in enumerate(zip(preds, y_true)):
  dummy = [0] * feature_len + [pred]
  restored_preds.append(scaler.inverse_transform([dummy])[0][-1])

  dummy_true = [0] * feature_len + [true]
  restored_true.append(scaler.inverse_transform([dummy_true])[0][-1])
  if i % 100 == 0:
    print(f"  ⏳ Inverse scaling: processed {i + 1}/{len(preds)}")

# Evaluation metrics
mse = mean_squared_error(restored_true, restored_preds)
mae = mean_absolute_error(restored_true, restored_preds)
eval_time = time.time() - eval_start

print(f"\n📊 Validation Results (done in {eval_time:.2f} seconds):")
print(f"Mean Squared Error (MSE): {mse:.2f}")
print(f"Mean Absolute Error (MAE): {mae:.2f}")

✅ Data loaded.
✅ Dropped NaNs. Remaining rows: 1119521
✅ Data normalized.


  return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


✅ Created sequences. X shape: torch.Size([1119511, 10, 13]), y shape: torch.Size([1119511])
✅ Split into training and validation sets. Train: 895608, Val: 223903

🚀 Starting training...
Epoch 000: Train Loss = 0.0245 | Elapsed: 143.84s
Epoch 010: Train Loss = 0.0002 | Elapsed: 1035.70s
Epoch 020: Train Loss = 0.0001 | Elapsed: 2556.12s
Epoch 030: Train Loss = 0.0001 | Elapsed: 5783.54s
Epoch 040: Train Loss = 0.0001 | Elapsed: 6661.20s
Epoch 050: Train Loss = 0.0001 | Elapsed: 7553.20s
Epoch 060: Train Loss = 0.0001 | Elapsed: 8439.76s
Epoch 070: Train Loss = 0.0001 | Elapsed: 9327.26s
Epoch 080: Train Loss = 0.0001 | Elapsed: 10222.14s
Epoch 090: Train Loss = 0.0001 | Elapsed: 11113.31s
Epoch 100: Train Loss = 0.0001 | Elapsed: 12004.78s
Epoch 110: Train Loss = 0.0001 | Elapsed: 12953.21s
Epoch 120: Train Loss = 0.0001 | Elapsed: 13843.47s
Epoch 130: Train Loss = 0.0001 | Elapsed: 14732.07s
Epoch 140: Train Loss = 0.0001 | Elapsed: 15623.45s
Epoch 150: Train Loss = 0.0001 | Elapsed: 1

In [6]:
import joblib
import json

In [7]:
torch.save(model.state_dict(), "../models/basic_lstm.pt")
joblib.dump(scaler, "../models/scaler.pkl")
meta = {
  "features": features,
  "target": target,
  "seq_len": SEQ_LEN,
  "hidden_size": 64
}
with open(f"../models/meta.json", "w") as f:
  json.dump(meta, f, indent=2)
print("📦 Model, scaler, and metadata saved.")

📦 Model, scaler, and metadata saved.
