In [44]:
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig
from torch import nn
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from data.ATM.dataset import get_loader

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
from RNNSM.rnnsm import RNNSM
from RMTPP.rmtpp import RMTPP

In [29]:
# Removed legacy loader construction: it referenced `ATMDataset` and
# `pad_collate`, which are not imported anywhere in this notebook (NameError on
# a fresh kernel), and its `data_loader` was immediately replaced by
# `get_loader()` in the next cell.

In [48]:
# Batches consumed by the training loops below come from the project helper
# `get_loader` (called with its default arguments).
data_loader = get_loader()

In [None]:
# Hydra's `initialize` raises if the global Hydra instance is already
# initialized (e.g. when this cell is re-run), so reset it first to keep the
# cell idempotent under Restart & Run All and repeated execution.
GlobalHydra.instance().clear()
initialize(config_path=".")

In [50]:
# Compose the experiment configuration from `config.yaml` in the initialized
# config path; later cells read `cfg.training`, `cfg.rmtpp` and `cfg.rnnsm`.
cfg = compose(
    config_name="config.yaml",
)

In [51]:
def train_rmtpp(data_loader, model, cfg: "DictConfig", opt: "optim.Optimizer | None" = None):
    """Train an RMTPP model and return the per-epoch mean training loss.

    Args:
        data_loader: iterable of ``(times, events, lengths, _)`` batches.
            Assumes ``times`` is (batch, seq, ...) with zero-padded tails —
            TODO confirm against ``get_loader``.
        model: module whose forward ``model(events, time_deltas, lengths)``
            returns ``(o_t, y_t)`` and which provides
            ``compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)``.
        cfg: config with ``training.n_epochs`` and ``training.lr``; an optional
            ``training.grad_clip`` enables max-norm gradient clipping.
        opt: optimizer over ``model``'s parameters. When None, a fresh Adam
            with ``cfg.training.lr`` is created, so training never depends on a
            notebook-global ``opt`` that may belong to a different model.

    Returns:
        ``defaultdict(list)`` whose ``'loss'`` entry holds one mean batch loss
        per epoch (nan for an epoch with no batches).
    """
    if opt is None:
        opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
    # Optional: the recorded RMTPP run diverged (loss ~4e9 at epoch 8).
    grad_clip = getattr(cfg.training, "grad_clip", None)
    train_metrics = defaultdict(list)
    model.train()

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')
        losses = []
        for times, events, lengths, _ in data_loader:
            # Gaps between consecutive timestamps; the first event gets a zero
            # gap. zeros_like keeps dtype/device consistent with `times`.
            time_deltas = torch.cat([torch.zeros_like(times[:, :1]),
                                     times[:, 1:] - times[:, :-1]],
                                    dim=1)
            o_t, y_t = model(events, time_deltas, lengths)
            # Positions whose timestamp is zero are treated as padding;
            # isclose needs `other` to share dtype with `times`.
            padding_mask = torch.isclose(times, torch.zeros_like(times))
            loss = model.compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)
            opt.zero_grad()
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            losses.append(loss.item())
        train_metrics['loss'].append(np.mean(losses) if losses else float("nan"))
        print(train_metrics['loss'][-1])
    return train_metrics

In [52]:
# Fresh RMTPP model plus an Adam optimizer over its parameters; the training
# call is the cell's last expression, so its returned metrics are displayed.
model = RMTPP(cfg.rmtpp)
opt = optim.Adam(params=model.parameters(), lr=cfg.training.lr)
train_rmtpp(data_loader=data_loader, model=model, cfg=cfg)

Epoch 1/10....
-1.6434604239035795
Epoch 2/10....
-2.316424532139555
Epoch 3/10....
-3.564129073569115
Epoch 4/10....
-2.143318919425315
Epoch 5/10....
-1.7213613061194724
Epoch 6/10....
-3.1691531754554587
Epoch 7/10....
-3.799747983191876
Epoch 8/10....
4072741699.8579564
Epoch 9/10....
5.232079488482881
Epoch 10/10....
1.7407105507210214


defaultdict(list,
            {'loss': [-1.6434604239035795,
              -2.316424532139555,
              -3.564129073569115,
              -2.143318919425315,
              -1.7213613061194724,
              -3.1691531754554587,
              -3.799747983191876,
              4072741699.8579564,
              5.232079488482881,
              1.7407105507210214]})

In [39]:
def train_rnnsm(data_loader, model, cfg: "DictConfig", opt: "optim.Optimizer | None" = None):
    """Train an RNNSM model and return the per-epoch mean training loss.

    Args:
        data_loader: iterable of ``(times, events, lengths, _)`` batches.
            Assumes ``times`` is (batch, seq, ...) with zero-padded tails —
            TODO confirm against ``get_loader``.
        model: module whose forward ``model(events, time_deltas, lengths)``
            returns ``o_t`` and which provides
            ``compute_loss(time_deltas, padding_mask, ret_mask, o_t)``.
        cfg: config with ``training.n_epochs`` and ``training.lr``; an optional
            ``training.grad_clip`` enables max-norm gradient clipping.
        opt: optimizer over ``model``'s parameters. When None, a fresh Adam
            with ``cfg.training.lr`` is created, so training never depends on a
            notebook-global ``opt`` that may belong to a different model.

    Returns:
        ``defaultdict(list)`` whose ``'loss'`` entry holds one mean batch loss
        per epoch (nan for an epoch with no batches).
    """
    if opt is None:
        opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
    grad_clip = getattr(cfg.training, "grad_clip", None)
    train_metrics = defaultdict(list)
    model.train()

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')
        losses = []
        for times, events, lengths, _ in data_loader:
            # Gaps between consecutive timestamps; the first event gets a zero
            # gap. zeros_like keeps dtype/device consistent with `times`.
            time_deltas = torch.cat([torch.zeros_like(times[:, :1]),
                                     times[:, 1:] - times[:, :-1]],
                                    dim=1)
            o_t = model(events, time_deltas, lengths)
            # Positions whose timestamp is zero are treated as padding;
            # isclose needs `other` to share dtype with `times`.
            padding_mask = torch.isclose(times, torch.zeros_like(times))
            ret_mask = ~padding_mask
            loss = model.compute_loss(time_deltas, padding_mask, ret_mask, o_t)
            opt.zero_grad()
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            losses.append(loss.item())
        train_metrics['loss'].append(np.mean(losses) if losses else float("nan"))
        print(train_metrics['loss'][-1])
    return train_metrics

In [40]:
# Fresh RNNSM model plus an Adam optimizer over its parameters; the training
# call is the cell's last expression, so its returned metrics are displayed.
model = RNNSM(cfg.rnnsm)
opt = optim.Adam(params=model.parameters(), lr=cfg.training.lr)
train_rnnsm(data_loader=data_loader, model=model, cfg=cfg)

Epoch 1/10....
-3.1207775065397962
Epoch 2/10....
-6.634248429156364
Epoch 3/10....
-7.0342806045045245
Epoch 4/10....
-7.286588881878143
Epoch 5/10....
-7.377609050020259
Epoch 6/10....
-7.448889072905195
Epoch 7/10....
-7.542163503930924
Epoch 8/10....
-7.602628535412728
Epoch 9/10....
-7.629863891195743
Epoch 10/10....
-7.657119903158634


defaultdict(list,
            {'loss': [-3.1207775065397962,
              -6.634248429156364,
              -7.0342806045045245,
              -7.286588881878143,
              -7.377609050020259,
              -7.448889072905195,
              -7.542163503930924,
              -7.602628535412728,
              -7.629863891195743,
              -7.657119903158634]})