In [1]:
import torch as pt
import numpy as np
import matplotlib.pyplot as plt
import tqdm

from GDELTAnomalies.datasets.gdelt_pt_dataset import GDELTDataset
from GDELTAnomalies.models.tsmixer import TSMixer

import os

### Setup

Verify pytorch is using our GPU

In [2]:
pt.accelerator.current_accelerator()

device(type='cuda')

Load dataset

In [3]:
dataset = GDELTDataset(lookback=10, horizon=1, step=1)

data_len = len(dataset)
train_len = 308
valid_len = 52


train_data = pt.utils.data.Subset(dataset, range(train_len))
valid_data = pt.utils.data.Subset(dataset, range(train_len, train_len + valid_len))
test_data = pt.utils.data.Subset(dataset, range(train_len + valid_len, data_len))

train_dataloader = pt.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=3, pin_memory=True, persistent_workers=True)
valid_dataloader = pt.utils.data.DataLoader(valid_data, batch_size=32, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)
test_dataloader = pt.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

### TSMixer

In [4]:
def quantile_loss(y_pred, y_true, q):
    """
    Calculate the quantile loss (pinball loss).

    Args:
        y_pred (torch.Tensor): Predicted values.
        y_true (torch.Tensor): True values.
        q (float or torch.Tensor): The quantile level (0 to 1).

    Returns:
        torch.Tensor: The mean quantile loss.
    """
    errors = y_true - y_pred
    loss = pt.max(q * errors, (q - 1) * errors)
    return pt.mean(loss)

In [None]:
device = pt.device("cuda")
pt.manual_seed(854923)

model = TSMixer(10, 1, dataset.num_series, num_blocks=3, ff_dim=64).to(device)

quantile = 0.5
lossfn = lambda x, y: quantile_loss(x, y, quantile)

epochs = 250000
tqdm_iter = tqdm.tqdm(range(epochs))

optimizer = pt.optim.Adam(model.parameters(), lr=4e-5)
scheduler = pt.optim.lr_scheduler.ExponentialLR(optimizer, 0.97)

valid_history = np.zeros(epochs)

for epoch in tqdm_iter:
    # Calculate validation loss
    model.eval()
    valid_loss = 0
    with pt.no_grad():
        for X, y in valid_dataloader:
            X = X.to(device)
            y = y.to(device)

            pred = model.forward(X)
            valid_loss += lossfn(pred.squeeze(), pt.log(y + 1))
    valid_history[epoch] = valid_loss
    tqdm_iter.set_postfix_str(f"{valid_loss=}")
    
    model.train()
    for X, y in train_dataloader:
        X = X.to(device)
        y = y.to(device)
        
        pred = model.forward(X)
        loss = lossfn(pred.squeeze(), pt.log(y + 1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 100 == 0:
        scheduler.step()
    if (epoch + 1) % 2500 == 0:
        pt.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "valid_loss": valid_loss
        }, f"checkpoints/TSMixer_{epoch}.pt")

  0%|          | 0/250000 [00:04<?, ?it/s, valid_loss=tensor(22.7878, device='cuda:0')]

In [16]:
pt.save({
    "epoch": 5000,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
    "valid_loss": valid_loss
}, "checkpoints/TSMixer_5000.pt")

In [17]:
pt.load("checkpoints/TSMixer_19999.pt")

{'epoch': 19999,
 'model_state_dict': OrderedDict([('mixer_layers.0.time_mixing.norm.weight',
               tensor([1.0000, 1.0000, 0.9682,  ..., 0.9972, 0.9486, 1.0030], device='cuda:0')),
              ('mixer_layers.0.time_mixing.norm.bias',
               tensor([ 0.0720, -0.0541,  0.0388,  ..., -0.0148, -0.0720, -0.0636],
                      device='cuda:0')),
              ('mixer_layers.0.time_mixing.norm.running_mean',
               tensor([5.6052e-45, 5.6052e-45, 2.3116e+01,  ..., 5.6052e-45, 3.6761e-01,
                       5.6052e-45], device='cuda:0')),
              ('mixer_layers.0.time_mixing.norm.running_var',
               tensor([5.6052e-45, 5.6052e-45, 6.2221e+02,  ..., 5.6052e-45, 1.2345e+00,
                       5.6052e-45], device='cuda:0')),
              ('mixer_layers.0.time_mixing.norm.num_batches_tracked',
               tensor(200000, device='cuda:0')),
              ('mixer_layers.0.time_mixing.fc1.weight',
               tensor([[-1.9570e-02, -1.9