In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
import importlib
import utils.data_processing as data_processing
importlib.reload(data_processing)
from utils.data_processing import get_dataloaders, set_seed, save_output, get_df


In [None]:
# Seed the project's RNG helper before any data loading or model construction,
# so everything downstream starts from the same random state on a fresh kernel.
set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Data processing

In [None]:
# Load the full dataset via the project helper; the feature/target columns are selected in the next cell.
df = get_df()

In [None]:
# Sequence columns, interleaved per step: r0, c0, r1, c1, r2, c2, r3, c3.
seq_features = [f"{prefix}{step}" for step in range(4) for prefix in ("r", "c")]
# Non-sequential (static) feature columns.
static_features = ['gameLength', 'uc']

# Prediction target column (presumably the `c` value one step past the input window — confirm).
target = 'c4'

X_seq, X_static, y = df[seq_features], df[static_features], df[target]


In [None]:
# seq_features appears to hold TIME_STEPS consecutive steps of per-step
# features, flattened in time order (r0, c0, ..., r3, c3).
TIME_STEPS = 4
# Integer division below would silently drop columns if the layout changed,
# so fail loudly instead.
assert len(seq_features) % TIME_STEPS == 0, (
    f"{len(seq_features)} sequence features cannot be split evenly into {TIME_STEPS} time steps"
)
# NOTE: despite its name, SEQ_LEN is the number of features per time step
# (len(seq_features) / TIME_STEPS); it is the size passed to every model below.
SEQ_LEN = len(seq_features) // TIME_STEPS
# NOTE(review): BATCH_SIZE is not passed to get_dataloaders, so the loaders use
# whatever batch size that helper chooses — wire it through or remove it.
BATCH_SIZE = 32
train_loader, test_loader, test_loader_h1, test_loader_h6 = get_dataloaders(X_seq, X_static, y)

# Running the experiment

In [None]:
from utils.models import (
    RNN,
    LSTM,
    GRU,
    TinyGRU,
    TransformerEncoderPositionalEncoding,
    SelfAttentionOnly,
)
from utils.train_eval import train_and_evaluate

In [None]:
# Build every model and train it in a single cell so the cell is idempotent:
# re-running it starts from freshly constructed models instead of trying to
# train the result dicts a previous run stored in `model_dict`.
epochs = 100
learning_rate = 0.001

model_classes = {
    "RNN": RNN,
    "LSTM": LSTM,
    "GRU": GRU,
    "TinyGRU": TinyGRU,
    "Transformer": TransformerEncoderPositionalEncoding,
    "Self-Attention": SelfAttentionOnly,
    # "TinyAttentionNoProj": TinyAttentionNoProj,  # disabled; add it to the utils.models import before enabling
}
# Instantiate all models up front, in a fixed order, before any training, so
# weight initialisation consumes the seeded RNG in a deterministic sequence.
models = {name: cls(SEQ_LEN).to(device) for name, cls in model_classes.items()}

test_loaders = (test_loader, test_loader_h1, test_loader_h6)
# Maps model name -> {"model": trained model, **history returned by train_and_evaluate}.
model_dict = {}
for model_name, model in models.items():
    print(f"\nTraining model: {model_name}")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    history = train_and_evaluate(model, train_loader, test_loaders, criterion, optimizer, device, epochs=epochs)
    model_dict[model_name] = {
        "model": model,
        **history
    }


In [None]:
# Persist the per-model results (trained model + training history) under the
# name "output_all_models"; the on-disk format is decided by save_output.
save_output(
    model_dict,
    "output_all_models",
)

# Plotting

In [None]:

import utils.plotting as plotting
importlib.reload(plotting)

from utils.plotting import plot_test_loss, plot_error_bars

In [None]:
# Plot the test-loss curves recorded for every trained model.
plot_test_loss(
    model_dict,
)

In [None]:
# Compare all models side by side; H1/H6 presumably refer to the
# test_loader_h1 / test_loader_h6 evaluation splits — confirm in utils.plotting.
plot_error_bars(
    model_dict,
    title="Overall Comparison (H1 vs H6)",
)