In [1]:
import torch as pt
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from omegaconf import OmegaConf

from GDELTAnomalies.datasets.gdelt_pt_dataset import GDELTDataset
import GDELTAnomalies.models.tft as tft

### Setup

Verify pytorch is using our GPU

In [2]:
pt.accelerator.current_accelerator()

device(type='cuda')

Load dataset

In [3]:
dataset = GDELTDataset(lookback=10, horizon=1, step=1, flatten=True)

data_len = len(dataset)
train_len = 308 * dataset.num_series
valid_len = 52 * dataset.num_series


train_data = pt.utils.data.Subset(dataset, range(train_len))
valid_data = pt.utils.data.Subset(dataset, range(train_len, train_len + valid_len))
test_data = pt.utils.data.Subset(dataset, range(train_len + valid_len, data_len))

train_dataloader = pt.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=3, pin_memory=True, prefetch_factor=4, persistent_workers=True)
valid_dataloader = pt.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)
test_dataloader = pt.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

In [4]:
def quantile_loss(y_pred, y_true, q):
    """
    Calculate the quantile loss (pinball loss).

    Args:
        y_pred (torch.Tensor): Predicted values.
        y_true (torch.Tensor): True values.
        q (float or torch.Tensor): The quantile level (0 to 1).

    Returns:
        torch.Tensor: The mean quantile loss.
    """
    errors = y_true - y_pred
    loss = pt.max(q * errors, (q - 1) * errors)
    return pt.mean(loss)

### TFT

In [5]:
device = pt.device("cuda")
pt.manual_seed(854923)


data_props = {'num_historical_numeric': 4200,
                'num_historical_categorical': 0,
                'num_static_numeric': 26,
                'num_static_categorical': 0,
                'num_future_numeric': 0,
                'num_future_categorical': 0,
                "num_output_series": 4200,
                "num_future_steps": 1,
                "device": "cuda"
                # 'historical_categorical_cardinalities': [],
                # 'static_categorical_cardinalities': [],
                # 'future_categorical_cardinalities': [],
                }

configuration = {
    'model':
        {
            'dropout': 0.05,
            'state_size': 4,
            'output_quantiles': [0.5],
            'lstm_layers': 2,
            'attention_heads': 2
        },
    # these arguments are related to possible extensions of the model class
    'task_type': 'regression',
    'target_window_start': None,
    'data_props': data_props
}

epochs = 3000

model = tft.TemporalFusionTransformer(OmegaConf.create(configuration)).to(device)
optimizer = pt.optim.Adam(model.parameters(), lr=4e-5)
scheduler = pt.optim.lr_scheduler.ExponentialLR(optimizer, 0.97)

quantile = 0.5
lossfn = lambda x, y: quantile_loss(x, y, quantile)


valid_history = np.zeros(epochs)

tqdm_iter = tqdm.tqdm(range(epochs))

for epoch in tqdm_iter:
    # Calculate validation loss
    model.eval()
    valid_loss = 0
    with pt.no_grad():
        for X, y, static in valid_dataloader:
            print("Batch")
            X = X.to(device)
            y = y.to(device)
            static = static.to(device)

            batch = {
                "historical_ts_numeric": X,
                "static_feats_numeric": static,
            }

            stuff = model.forward(batch)
            pred = stuff["predicted_quantiles"]
            valid_loss += lossfn(pred.squeeze(), pt.log(y + 1))
    valid_history[epoch] = valid_loss
    tqdm_iter.set_postfix_str(f"{valid_loss=}")
    
    model.train()
    for X, y, static in train_dataloader:
        X = X.to(device)
        y = y.to(device)
        static = static.to(device)
        batch = {
            "historical_ts_numeric": X,
            "static_feats_numeric": static,
        }
        
        stuff = model.forward(batch)
        pred = stuff["predicted_quantiles"]
        loss = lossfn(pred.squeeze(), pt.log(y + 1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if epoch % 100 == 0:
        scheduler.step()

  0%|          | 0/3000 [00:00<?, ?it/s]

Batch
Batch
Batch
Batch
Batch
Batch
Batch


  0%|          | 0/3000 [00:14<?, ?it/s]

Batch





KeyboardInterrupt: 

In [6]:
sum(p.numel() for p in model.parameters()) * 4 / 1024 / 1024 / 1024

2.3198921270668507

In [None]:
device = pt.device("cuda")
for X, y, statics in train_dataloader:
    batch = {
        "historical_ts_numeric": X.to(device),
        "static_feats_numeric": statics.to(device),
    }
    stuff = model.forward(batch)
    preds = stuff["predicted_quantiles"]

torch.Size([32, 10, 4200])
torch.Size([32])
torch.Size([32, 26])
torch.Size([32, 11, 32])
torch.Size([32, 11, 32])
10
1
torch.Size([32, 1, 1])
