In [34]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from hydra.experimental import compose, initialize
from omegaconf import DictConfig
import numpy as np
from collections import defaultdict
from data.ATM.dataset import get_ATM_test_loader, get_ATM_train_val_loaders, ATMTrainDataset, pad_collate_train

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
from RNNSM.rnnsm import RNNSM
from RMTPP.rmtpp import RMTPP

In [28]:
# Preview the raw training events CSV; the rendered output below shows the
# columns `id`, `time` and `event` (plus an unnamed index column).
pd.read_csv('data/ATM/train_day.csv')

Unnamed: 0,id,time,event
0,g2434,16344.545035,4
1,g2434,16344.549225,4
2,g2434,16358.555127,1
3,g2434,16364.250451,2
4,g2434,16373.158530,2
...,...,...,...
370106,g444,16524.050336,4
370107,g444,16524.050475,4
370108,g444,16524.050729,4
370109,g444,16524.051308,4


In [29]:
# Ask the dataset class for the indexes available in the training CSV
# (presumably one entry per sequence/ATM id — TODO confirm against
# ATMTrainDataset.get_indexes) and keep only the first 100 as a small subset.
idxs = ATMTrainDataset.get_indexes('data/ATM/train_day.csv')
train_idxs = idxs[:100]

In [25]:
# Build a dataset restricted to `train_idxs` and wrap it in a shuffled loader.
# `pad_collate_train` is the project's collate function; `drop_last=True`
# discards the final batch when it has fewer than 32 items.
ds = ATMTrainDataset('data/ATM/train_day.csv', train_idxs)
data_loader = DataLoader(dataset=ds, batch_size=32, shuffle=True, collate_fn=pad_collate_train, drop_last=True)

In [33]:
# Train/validation loaders built by the project helper from the same CSV.
# NOTE(review): 16000 / 16400 / 16500 look like boundaries on the `time`
# column (values in the CSV preview are ~16344-16524) — confirm their exact
# meaning against get_ATM_train_val_loaders.
train_loader, val_loader = get_ATM_train_val_loaders(16000, 16400, 16500, 'data/ATM/train_day.csv')

In [38]:
# Test loader built with the same three numeric arguments and the same CSV as
# the train/validation loaders.
# NOTE(review): this reads 'train_day.csv', not a separate test file —
# presumably the split is made inside get_ATM_test_loader; confirm.
test_loader = get_ATM_test_loader(16000, 16400, 16500, 'data/ATM/train_day.csv')

In [39]:
# Initialise Hydra with "." as the config search path so that `compose` can
# locate config files. This sets global Hydra state.
# NOTE(review): re-running this cell in the same kernel may fail because Hydra
# is already initialised — confirm for the Hydra version in use.
initialize(config_path=".")

hydra.experimental.initialize()

In [40]:
# Load config.yaml through Hydra. Later cells read `cfg.rmtpp`, `cfg.rnnsm`,
# `cfg.training.n_epochs` and `cfg.training.lr` from it.
cfg = compose(config_name="config.yaml")

In [60]:
def train_rmtpp(data_loader, model, cfg: DictConfig, optimizer=None):
    """Train an RMTPP-style model and record the mean loss of every epoch.

    Args:
        data_loader: iterable yielding ``(times, time_deltas, events, lengths)``
            batches.
        model: callable as ``model(events, time_deltas, lengths)`` returning
            ``(o_t, y_t)``, and exposing
            ``compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)``.
        cfg: configuration; only ``cfg.training.n_epochs`` is read here.
        optimizer: optimizer whose ``step``/``zero_grad`` are called. When
            omitted, the notebook-level ``opt`` is used so that existing
            three-argument calls keep working.

    Returns:
        ``defaultdict(list)`` whose ``'loss'`` entry holds one mean batch loss
        per epoch.
    """
    if optimizer is None:
        # Backward-compatible fallback to the global defined in the calling cell.
        optimizer = opt

    train_metrics = defaultdict(list)
    n_epochs = cfg.training.n_epochs

    for epoch in range(n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')
        losses = []
        for times, time_deltas, events, lengths in data_loader:
            # Clear gradients *before* the backward pass so that anything left
            # over from an interrupted or earlier run cannot leak into this step.
            optimizer.zero_grad()

            o_t, y_t = model(events, time_deltas, lengths)

            # Positions whose timestamp is (numerically) zero are flagged as
            # padding. zeros_like keeps the comparison on the dtype/device of
            # `times` instead of a hard-coded float32 CPU tensor.
            padding_mask = torch.isclose(times, torch.zeros_like(times))

            loss = model.compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        train_metrics['loss'].append(np.mean(losses))
        print(train_metrics['loss'][-1])
    return train_metrics

In [61]:
# Fresh RMTPP model configured from cfg.rmtpp, trained on `train_loader`.
# `opt` must be (re)created here, after the model: train_rmtpp steps the
# notebook-level `opt`, so it has to wrap this model's parameters.
model = RMTPP(cfg.rmtpp)
opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
train_rmtpp(train_loader, model, cfg)

Epoch 1/10....
-1.596625697785331
Epoch 2/10....
-3.539762474158231
Epoch 3/10....
-1.7884533545550179
Epoch 4/10....
-1.279202583958121
Epoch 5/10....
-1.736465639927808
Epoch 6/10....


KeyboardInterrupt: 

In [39]:
def train_rnnsm(data_loader, model, cfg: DictConfig, optimizer=None):
    """Train an RNNSM-style model and record the mean loss of every epoch.

    Args:
        data_loader: iterable yielding 4-tuples ``(times, events, lengths, _)``;
            the fourth item is ignored. ``times`` must be 3-dimensional, since
            ``times.shape[2]`` is read and differences are taken along dim 1.
        model: callable as ``model(events, time_deltas, lengths)`` returning
            ``o_t``, and exposing
            ``compute_loss(time_deltas, padding_mask, ret_mask, o_t)``.
        cfg: configuration; only ``cfg.training.n_epochs`` is read here.
        optimizer: optimizer whose ``step``/``zero_grad`` are called. When
            omitted, the notebook-level ``opt`` is used so that existing
            three-argument calls keep working.

    Returns:
        ``defaultdict(list)`` whose ``'loss'`` entry holds one mean batch loss
        per epoch.
    """
    if optimizer is None:
        # Backward-compatible fallback to the global defined in the calling cell.
        optimizer = opt

    train_metrics = defaultdict(list)

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')
        losses = []
        for times, events, lengths, _ in data_loader:
            # Clear gradients *before* the backward pass so that anything left
            # over from an interrupted or earlier run cannot leak into this step.
            optimizer.zero_grad()

            # Inter-event gaps along dim 1, with a zero gap prepended for the
            # first step. new_zeros keeps the dtype/device of `times`.
            # NOTE(review): where a real timestamp is followed by zero padding
            # the difference is negative; presumably compute_loss masks those
            # positions via padding_mask — confirm.
            first_gap = times.new_zeros(times.shape[0], 1, times.shape[2])
            time_deltas = torch.cat([first_gap, times[:, 1:] - times[:, :-1]], dim=1)

            o_t = model(events, time_deltas, lengths)

            # Positions whose timestamp is (numerically) zero are flagged as
            # padding; ret_mask is its complement.
            padding_mask = torch.isclose(times, torch.zeros_like(times))
            ret_mask = ~padding_mask

            loss = model.compute_loss(time_deltas, padding_mask, ret_mask, o_t)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        train_metrics['loss'].append(np.mean(losses))
        print(train_metrics['loss'][-1])
    return train_metrics

In [40]:
# Fresh RNNSM model configured from cfg.rnnsm. `opt` is rebound to this model's
# parameters because train_rnnsm steps the notebook-level `opt`.
# NOTE(review): this trains on `data_loader` (the 100-index subset loader built
# with pad_collate_train), not on `train_loader` as the RMTPP run does —
# confirm that the difference is intentional.
model = RNNSM(cfg.rnnsm)
opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
train_rnnsm(data_loader, model, cfg)

Epoch 1/10....
-3.1207775065397962
Epoch 2/10....
-6.634248429156364
Epoch 3/10....
-7.0342806045045245
Epoch 4/10....
-7.286588881878143
Epoch 5/10....
-7.377609050020259
Epoch 6/10....
-7.448889072905195
Epoch 7/10....
-7.542163503930924
Epoch 8/10....
-7.602628535412728
Epoch 9/10....
-7.629863891195743
Epoch 10/10....
-7.657119903158634


defaultdict(list,
            {'loss': [-3.1207775065397962,
              -6.634248429156364,
              -7.0342806045045245,
              -7.286588881878143,
              -7.377609050020259,
              -7.448889072905195,
              -7.542163503930924,
              -7.602628535412728,
              -7.629863891195743,
              -7.657119903158634]})