# xLSTM, sLSTM, mLSTMs

## Preparation

### Import modules

In [10]:
# Prediction using LSTM, GRU-LSTM, xLSTM
import copy
import math
from typing import List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from pandas import DataFrame
from sklearn.model_selection import KFold, GroupShuffleSplit
from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, Dataset, Subset

import thesis_utils as tu
import thesis_utils.datastruc as tuds
import thesis_utils.models as tumod

### Configuration

In [11]:
# Model parameters
HORIZON = 1
BATCH_SIZE = 128
EMBEDDING_SIZE = 32
NUM_EPOCHS = 25
HIDDEN_SIZE = 128
N_LAYERS = 3
DROPOUT = 0.3
XLSTM_TYPE = "X"
N_LAGS = 5

# Train parameters
TARGET = "EXPORT_centered"
FEATURES = [
  "contig", "comlang_off", "colony", "smctry",
]
N_SPLITS = 5
PATIENCE = 5
LEARNING_RATE = 0.01
WEIGHT_DECAY = 0.01
RANDOM_SEED = 16
SUBSAMPLE_ENABLED = False
N_DYADS = 1000

SANCTION_COLS = ["arms", "military", "trade", "travel", "other"]

# Torch config
torch.manual_seed(RANDOM_SEED)
device = (
  torch.device("mps") if torch.backends.mps.is_available()
  else torch.device("cpu")
)

### Load Data

In [12]:
processed = pd.read_parquet(path="../../data/model/processed.parquet", engine="fastparquet")
df: DataFrame = processed.copy(deep=True)

### Sort, shift and compute data

In [13]:
# Sort data by Report + Partner + Year
df["dyad_id"] = df["ISO3_reporter"] + "_" + df["ISO3_partner"]
df = df.sort_values(by=["dyad_id", "Year"], ignore_index=True)

In [14]:
if SUBSAMPLE_ENABLED:
  dyad_subsample = pd.Series(df["dyad_id"].unique()).sample(n=N_DYADS, random_state=RANDOM_SEED, replace=False)
  df = df[df["dyad_id"].isin(dyad_subsample)]
print(df["dyad_id"].nunique())

33672


In [15]:
df["sanction"] = (df[SANCTION_COLS]
                  .sum(axis=1)).astype(int)

### Coerce numerical values and convert dyad_id to categorical

In [16]:
num_cols = ["distw", "GDP_reporter", "GDP_partner", "sanction", "contig",
            "comlang_off", "colony", "smctry", "Year", ]
df[num_cols] = df[num_cols].apply(pd.to_numeric, errors="coerce").astype(float)
df = df.dropna(subset=num_cols)

In [17]:
df["Year"] = df["Year"].astype(int)
for col in ["dyad_id"]:
  df[col] = pd.Categorical(df[col], categories=sorted(df[col].unique()))

### Center data

In [18]:
center_columns = ["distw", "GDP_reporter", "GDP_partner", "EXPORT"]
for col in center_columns:
  median = df[col].median()
  std_df = df[col].std()
  df[col + "_centered"] = (df[col] - median) / std_df
FEATURES += ["distw_centered"]

In [19]:
lag_cols = ["GDP_reporter_centered", "GDP_partner_centered", "sanction"]
for col in lag_cols:
  for index in range(1, N_LAGS + 1):
    df[f"{col}_lag{index}"] = df.groupby("dyad_id", observed=True)[col].shift(index)

In [20]:
df = df.dropna()

In [21]:
FEATURES += [f"{c}_lag{index}" for c in lag_cols for index in range(1, N_LAGS + 1)]
# FEATURES += lag_cols

In [22]:
FEATURES

['contig',
 'comlang_off',
 'colony',
 'smctry',
 'distw_centered',
 'GDP_reporter_centered_lag1',
 'GDP_reporter_centered_lag2',
 'GDP_reporter_centered_lag3',
 'GDP_reporter_centered_lag4',
 'GDP_reporter_centered_lag5',
 'GDP_partner_centered_lag1',
 'GDP_partner_centered_lag2',
 'GDP_partner_centered_lag3',
 'GDP_partner_centered_lag4',
 'GDP_partner_centered_lag5',
 'sanction_lag1',
 'sanction_lag2',
 'sanction_lag3',
 'sanction_lag4',
 'sanction_lag5']

In [23]:
# Split into Train, Validation and Test sets
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_SEED)

train_idx, test_idx = next(gss.split(df, groups=df["dyad_id"]))
test_df = df.iloc[test_idx]
train_df = df.iloc[train_idx]

train_idx, val_idx = next(gss.split(train_df, groups=train_df["dyad_id"]))
val_df = train_df.iloc[val_idx]
train_df = train_df.iloc[train_idx]

In [24]:
train_df.loc[:, FEATURES] = train_df.loc[:, FEATURES].astype(
  "float32",
  copy=False
)

# Train

## Define Fold and Epoch steps
_For reusability_

In [25]:
# Create KFold object
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_SEED)

In [26]:
# Define epoch step
def epoch_step(model: nn.Module, optimizer: Optimizer, criterion: nn.Module,
               scheduler: LRScheduler, train_loader: DataLoader, val_loader: DataLoader,
               device: any) -> float:
  model.train()
  for X, y, di in train_loader:
    X, y, di = map(lambda t: t.to(device, non_blocking=True), (X, y, di))
    optimizer.zero_grad()
    y_pred = model(X, di)

    if not torch.isfinite(y_pred).all():
      print("⚠️ NaN or Inf detected in y_pred — stopping here!")
      return float("inf")  # or break

    loss = criterion(y_pred, y)

    if not torch.isfinite(loss):
      print("⚠️ loss is NaN or Inf!")

    loss.backward()

    clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

  model.eval()
  val_losses = []
  with (torch.no_grad()):
    for X, y, di in val_loader:
      X, y, di = map(lambda t: t.to(device, non_blocking=True), (X, y, di))
      val_losses.append(criterion(model(X, di), y).item())

  val_rmse = math.sqrt((sum(val_losses) / len(val_losses)))
  scheduler.step(val_rmse)
  return val_rmse

In [27]:
# Define fold step
def fold_step(fold: int, train_idx: List, val_idx: List,
              dataset: Dataset, batch_size: int, num_epochs: int, patience: int,
              model: nn.Module, device: any,
              optimizer: Optimizer, criterion: nn.Module, scheduler: LRScheduler) -> (float, dict):
  train_loader = DataLoader(
    Subset(dataset, train_idx),
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    persistent_workers=True,
    prefetch_factor=2,
    pin_memory=True
  )

  val_loader = DataLoader(
    Subset(dataset, val_idx),
    batch_size=batch_size,
    shuffle=False,
    num_workers=10,
    persistent_workers=True,
    prefetch_factor=2,
    pin_memory=False
  )

  best_state = copy.deepcopy(model.state_dict())
  best_rmse = float("inf")
  patience_left = patience

  print(f"Start epoch train for fold {fold}")
  for epoch in range(num_epochs):
    val_rmse = epoch_step(model=model, optimizer=optimizer, criterion=criterion,
                          scheduler=scheduler, train_loader=train_loader, val_loader=val_loader,
                          device=device)
    print(f"Epoch {epoch + 1:02d}/{num_epochs}  |  val RMSE: {val_rmse:.4f}")

    if val_rmse < best_rmse - 1e-4:
      best_rmse, patience_left = val_rmse, 10
      best_state = model.state_dict()
    else:
      patience_left -= 1
      if patience_left == 0:
        print("Early stop.")
        break
  print("Load state dict")
  model.load_state_dict(best_state)
  model.eval()
  preds, truth = [], []
  with torch.no_grad():
    for X, y, di in val_loader:
      X, di = map(lambda t: t.to(device, non_blocking=True), (X, di))
      preds.append(model(X, di).cpu())
      truth.append(y)
  preds = torch.cat(preds).numpy()
  truth = torch.cat(truth).numpy()

  rmse = tu.rmse(truth, preds)
  mae = tu.mae(truth, preds)
  rmae = tu.rmae(truth, preds)
  pseudo_r2 = tu.pseudo_r2(truth, preds)
  print(f"Fold {fold}  RMSE {rmse:.4f} | MAE {mae:.4f} | R² {pseudo_r2:.4f} | RMAE {rmae:.4f}")

  return rmse, copy.deepcopy(best_state)


## Train Raw dataset

### Split dataset

In [28]:
# Convert df_scaled to pytorch Tensor
dataset, dyad_to_idx = tuds.make_panel_datasets_dyad(
  data=df,
  features=FEATURES,
  target=TARGET,
  horizon=HORIZON,
)

In [29]:
# Create DataLoaders for the 3 sets
train_loader = DataLoader(
  Subset(dataset, train_idx),
  batch_size=BATCH_SIZE,
  shuffle=True,
  num_workers=10,
  persistent_workers=True,
  prefetch_factor=2,
  pin_memory=False
)

val_loader = DataLoader(
  Subset(dataset, val_idx),
  batch_size=BATCH_SIZE,
  shuffle=False,
  num_workers=10,
  persistent_workers=True,
  prefetch_factor=2,
  pin_memory=False
)

test_loader = DataLoader(
  Subset(dataset, test_idx),
  batch_size=BATCH_SIZE,
  shuffle=False,
  num_workers=10,
  persistent_workers=True,
  prefetch_factor=2,
  pin_memory=False
)

### Train model

In [30]:
# Save config
SAVE_ENABLED = False
SERIAL_NUMBER = f"{XLSTM_TYPE}LSTM-{LEARNING_RATE}lr-{DROPOUT}d-{HIDDEN_SIZE}hs"
SERIAL_NUMBER = SERIAL_NUMBER.replace(".", "_")
PATH_TO_FOLDER = "../../models/"

In [31]:
# Save best train iteration
best_fold_state = None
best_fold_rmse = float("inf")

In [32]:
for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(dataset))), 1):
  model = tumod.DyadXLSTM(
    n_features=len(FEATURES),
    n_dyads=len(dyad_to_idx),
    embed_dim=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    dropout=DROPOUT,
    horizon=HORIZON,
    type=XLSTM_TYPE,
  ).to(device=device)

  criterion = nn.MSELoss()
  optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=PATIENCE
  )

  print(f"=== FOLD {fold}/{N_SPLITS} ===")
  fold_rmse, best_state = fold_step(fold=fold,
                                    train_idx=train_idx,
                                    val_idx=val_idx,
                                    dataset=dataset,
                                    batch_size=BATCH_SIZE,
                                    num_epochs=NUM_EPOCHS,
                                    patience=PATIENCE,
                                    model=model,
                                    device=device,
                                    optimizer=optimizer,
                                    criterion=criterion,
                                    scheduler=scheduler)
  if fold_rmse < best_fold_rmse:
    best_fold_rmse = fold_rmse
    best_fold_state = copy.deepcopy(best_state)

=== FOLD 1/5 ===
Start epoch train for fold 1
Epoch 01/25  |  val RMSE: 0.9967
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 02/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 03/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 04/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 05/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 06/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 07/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 08/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 09/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 10/25  |  val RMSE: inf
⚠️ NaN or Inf detected in y_pred — stopping here!
Epoch 11/25  |  val RMSE: inf
Early stop.
Load state dict
Fold 1  RMSE nan | MAE nan | R² nan | RMAE nan
=== FOLD 2/5 ===
Start epoch train for fold 2


KeyboardInterrupt: 

## Save Model

In [47]:
torch.save({
  "model_state_dict": best_fold_state,
  "model_hyperparams": {
    "n_features": len(FEATURES),
    "n_dyads": len(dyad_to_idx),
    "embed_dim": EMBEDDING_SIZE,
    "hidden_size": HIDDEN_SIZE,
    "n_layers": N_LAYERS,
    "dropout": DROPOUT,
    "horizon": HORIZON,
  },
  "dyad_to_idx": dyad_to_idx,
  "feature_names": FEATURES,
}, PATH_TO_FOLDER + SERIAL_NUMBER + ".pt")