In [None]:
import torch as pt
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from omegaconf import OmegaConf

from GDELTAnomalies.datasets.gdelt_pt_dataset import GDELTDataset
from GDELTAnomalies.models.tsmixer import TSMixer

### Setup

Verify that PyTorch is using our GPU.

In [None]:
# Bare last expression: the notebook renders the accelerator device PyTorch
# reports, so a reader can confirm it is the GPU before training starts.
pt.accelerator.current_accelerator()

device(type='cuda')

Load the dataset and split it into train, validation, and test subsets.

In [None]:
# Window parameters (lookback / horizon / step) are interpreted by GDELTDataset.
dataset = GDELTDataset(lookback=10, horizon=1, step=1)

# Contiguous, index-ordered split: the first `train_len` samples are used for
# training, the next `valid_len` for validation, and the remainder for testing.
data_len = len(dataset)
train_len = 308
valid_len = 52

# Fail early with a clear message rather than with an IndexError raised from
# inside a DataLoader worker when a Subset index is out of range.
assert data_len >= train_len + valid_len, (
    f"dataset has {data_len} samples; need at least {train_len + valid_len}"
)

train_data = pt.utils.data.Subset(dataset, range(train_len))
valid_data = pt.utils.data.Subset(dataset, range(train_len, train_len + valid_len))
test_data = pt.utils.data.Subset(dataset, range(train_len + valid_len, data_len))

# Each loader must wrap its own split (not the full `dataset`) so that
# validation and test batches never contain training samples.
# Only the training loader shuffles; evaluation order stays deterministic.
train_dataloader = pt.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=3, pin_memory=True, persistent_workers=True)
valid_dataloader = pt.utils.data.DataLoader(valid_data, batch_size=32, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)
test_dataloader = pt.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
# Fixed-seed generator so the placeholder cardinalities below (and therefore
# the model configuration) are identical on every Restart & Run All, instead
# of depending on the unseeded global NumPy random state.
_cardinality_rng = np.random.default_rng(0)

# Feature counts and per-feature categorical cardinalities handed to the model.
# NOTE(review): the cardinalities are random placeholders in [1, 10], not
# values derived from `dataset` — confirm they match the real features.
data_props = {'num_historical_numeric': 4,
                'num_historical_categorical': 6,
                'num_static_numeric': 10,
                'num_static_categorical': 11,
                'num_future_numeric': 2,
                'num_future_categorical': 3,
                # One cardinality per categorical feature, so each `size`
                # matches the corresponding `num_*_categorical` count above.
                'historical_categorical_cardinalities': (1 + _cardinality_rng.integers(10, size=6)).tolist(),
                'static_categorical_cardinalities': (1 + _cardinality_rng.integers(10, size=11)).tolist(),
                'future_categorical_cardinalities': (1 + _cardinality_rng.integers(10, size=3)).tolist(),
                }

# Nested configuration consumed by the model constructor via OmegaConf.
configuration = {
    'model':
        {
            'dropout': 0.05,
            'state_size': 64,
            'output_quantiles': [0.1, 0.5, 0.9],
            'lstm_layers': 2,
            'attention_heads': 4
        },
    # these arguments are related to possible extensions of the model class
    'task_type': 'regression',
    'target_window_start': None,
    'data_props': data_props
}

# NOTE(review): `tft` is not imported in the notebook's import cell (only
# `TSMixer` is), so this line raises NameError on a fresh kernel — confirm
# which model is intended and add the matching import.
model = tft.TemporalFusionTransformer(OmegaConf.create(configuration))