In [23]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from hydra.experimental import compose, initialize
from omegaconf import DictConfig
import numpy as np
from collections import defaultdict

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
from RNNSM.rnnsm import RNNSM
from RMTPP.rmtpp import RMTPP

In [25]:
class ATMDataset(Dataset):
    """Event sequences read from a CSV file with ``id``, ``time`` and ``event`` columns.

    Rows are cut into sequences in file order: consecutive rows sharing the same
    ``id`` form a run, and every run is split into chunks of at most ``max_len``
    rows. Event codes are shifted by +1 on load (presumably to keep 0 free as the
    padding value used when batching -- confirm against the collate function).

    Args:
        path: path (or any object ``pandas.read_csv`` accepts) of the CSV file.
        max_len: maximum number of rows per sequence; values below 1 behave
            like 1, i.e. every row becomes its own sequence.
        normalization: stored on the instance but not applied anywhere in this
            class; kept so existing callers that pass it keep working.
    """

    def __init__(self, path, max_len=500, normalization='none'):
        super().__init__()
        self.max_len = max_len
        # Not used by any method below; recorded so the requested value is at
        # least inspectable instead of being dropped silently.
        self.normalization = normalization
        data = pd.read_csv(path)
        self.ids = data['id']
        self.times = data['time']
        self.events = data['event'] + 1
        self.time_seqs, self.event_seqs = self.generate_sequence()

    def generate_sequence(self):
        """Split the loaded columns into per-id chunks.

        Returns:
            ``(time_seqs, event_seqs)``: two equally long lists holding, for
            each chunk, a 1-D float32 tensor of times and a 1-D int64 tensor of
            (shifted) event codes.
        """
        ids = self.ids.to_numpy()
        n_rows = len(ids)
        if n_rows == 0:
            # A header-only CSV yields no sequences at all.
            return [], []

        # Convert each column once; chunks are sliced out of these tensors.
        times = torch.from_numpy(self.times.to_numpy(dtype=np.float32, copy=True))
        events = torch.from_numpy(self.events.to_numpy(dtype=np.int64, copy=True))

        # Positions where the id differs from the previous row start a new run.
        # Working positionally avoids relying on the Series index labels.
        boundaries = (np.flatnonzero(ids[1:] != ids[:-1]) + 1).tolist()
        run_starts = [0, *boundaries]
        run_stops = [*boundaries, n_rows]

        # Slice the dataset into sequences sharing one id, each at most
        # self.max_len rows long.
        chunk_len = max(int(self.max_len), 1)
        time_seqs = []
        event_seqs = []
        for run_start, run_stop in zip(run_starts, run_stops):
            for lo in range(run_start, run_stop, chunk_len):
                hi = min(lo + chunk_len, run_stop)
                # clone() keeps every sequence an independent tensor rather
                # than a view into the full column.
                time_seqs.append(times[lo:hi].clone())
                event_seqs.append(events[lo:hi].clone())

        return time_seqs, event_seqs

    def __getitem__(self, item):
        """Return the ``(times, events)`` tensor pair of sequence ``item``."""
        return self.time_seqs[item], self.event_seqs[item]

    def __len__(self):
        """Number of sequences produced by ``generate_sequence``."""
        return len(self.time_seqs)

    
def pad_collate(batch):
    """Collate ``(times, events)`` pairs into zero-padded batch tensors.

    Args:
        batch: sequence of ``(times, events)`` pairs of 1-D tensors.

    Returns:
        ``(times_padded, events_padded, time_lens, event_lens)`` where both
        padded tensors are batch-first, padded with 0 up to the longest sequence
        and carry an extra trailing dimension of size 1; the two lists hold the
        original length of every sequence.
    """
    time_seqs, event_seqs = zip(*batch)

    # Record the true lengths before padding erases them.
    time_lens = list(map(len, time_seqs))
    event_lens = list(map(len, event_seqs))

    times_padded = pad_sequence(time_seqs, batch_first=True, padding_value=0).unsqueeze(-1)
    events_padded = pad_sequence(event_seqs, batch_first=True, padding_value=0).unsqueeze(-1)

    return times_padded, events_padded, time_lens, event_lens

In [26]:
# Load the training sequences from data/ATM/train_day.csv (path is relative to
# the notebook's working directory) with the default max_len.
ds = ATMDataset('data/ATM/train_day.csv')
# Shuffled mini-batches of 32 sequences; pad_collate zero-pads each batch to a
# common length and also returns the true sequence lengths.
data_loader = DataLoader(dataset=ds, batch_size=32, shuffle=True, collate_fn=pad_collate)

In [None]:
# Point Hydra at "." as its config directory.
# NOTE(review): this cell carries no execution count although the compose()
# cell after it was run -- presumably Hydra was initialised earlier in the
# session; confirm this cell runs cleanly on a fresh kernel.
initialize(config_path=".")

In [69]:
# Compose the experiment configuration from config.yaml; later cells read
# cfg.rmtpp, cfg.rnnsm and cfg.training (n_epochs, lr) from it.
cfg = compose(config_name="config.yaml")

In [102]:
def train_rmtpp(data_loader, model, cfg: DictConfig, optimizer=None):
    """Train an RMTPP-style model and return per-batch training metrics.

    Args:
        data_loader: iterable yielding ``(times, events, lengths, _)`` batches.
            ``times`` and ``events`` are expected to be padded
            ``(batch, seq, 1)`` tensors and ``lengths`` the true length of each
            sequence (the layout ``pad_collate`` produces).
        model: called as ``model(events, time_deltas, lengths)`` and expected to
            return ``(o_t, y_t)``; must also provide
            ``compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)``.
        cfg: configuration exposing ``cfg.training.n_epochs`` and
            ``cfg.training.lr``.
        optimizer: optimizer that updates ``model``'s parameters. When omitted,
            an Adam optimizer with learning rate ``cfg.training.lr`` is created,
            so the function no longer depends on a notebook-global optimizer.

    Returns:
        ``defaultdict(list)`` whose ``'loss'`` entry holds one float per batch.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=cfg.training.lr)
    train_metrics = defaultdict(list)

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')

        for times, events, lengths, _ in data_loader:
            # Derive the padding mask from the true lengths. Comparing times
            # against 0 would also mask a genuine event that happens at t == 0.
            seq_lens = torch.as_tensor(lengths, device=times.device)
            positions = torch.arange(times.shape[1], device=times.device)
            padding_mask = (positions[None, :] >= seq_lens[:, None])[..., None].expand_as(times)

            # First delta of every sequence is 0; new_zeros keeps the dtype and
            # device of `times`.
            time_deltas = torch.cat([times.new_zeros(times.shape[0], 1, times.shape[2]),
                                     times[:, 1:] - times[:, :-1]],
                                    dim=1)
            # The first padded step would otherwise hold `0 - last_time`, a
            # spurious negative interval; make every padded delta 0.
            time_deltas = time_deltas.masked_fill(padding_mask, 0)

            # Clear gradients before the backward pass so nothing accumulated
            # outside this loop leaks into the first update.
            optimizer.zero_grad()
            o_t, y_t = model(events, time_deltas, lengths)
            loss = model.compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)
            loss.backward()
            optimizer.step()
            train_metrics['loss'].append(loss.item())
    return train_metrics

In [91]:
# Scratch check: flatten(0, -2) merges every dimension except the last, so a
# (2, 3, 3) tensor becomes (6, 3). Nothing else in the visible cells reads `a`.
a = torch.rand(2, 3, 3)
a.flatten(0, -2).shape

torch.Size([6, 3])

In [106]:
# Fresh RMTPP model built from the `rmtpp` section of the config.
model = RMTPP(cfg.rmtpp)
# Adam over this model's parameters, learning rate taken from the config.
opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
# NOTE(review): in the recorded output the time-prediction loss is NaN from the
# second batch onward and the run was interrupted by hand -- training diverges
# after the first optimizer step and needs investigation.
train_rmtpp(data_loader, model, cfg)

Epoch 1/5....
Time prediction loss: -3154.877197265625
Markers loss: 946.3025512695312
f_deltas.min(): 2.1765725219093985e-16
--------------------
Time prediction loss: nan
Markers loss: 366.3349914550781
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
--------------------
Time prediction loss: nan
Markers loss: nan
f_deltas.min(): nan
----------

KeyboardInterrupt: 

In [74]:
def train_rnnsm(data_loader, model, cfg: DictConfig, optimizer=None, verbose=True):
    """Train an RNNSM-style model and return per-batch training metrics.

    Args:
        data_loader: iterable yielding ``(times, events, lengths, _)`` batches.
            ``times`` and ``events`` are expected to be padded
            ``(batch, seq, 1)`` tensors and ``lengths`` the true length of each
            sequence (the layout ``pad_collate`` produces).
        model: called as ``model(events, time_deltas, lengths)`` and expected to
            return ``o_t``; must also provide
            ``compute_loss(time_deltas, padding_mask, ret_mask, o_t)``.
        cfg: configuration exposing ``cfg.training.n_epochs`` and
            ``cfg.training.lr``.
        optimizer: optimizer that updates ``model``'s parameters. When omitted,
            an Adam optimizer with learning rate ``cfg.training.lr`` is created,
            so the function no longer depends on a notebook-global optimizer.
        verbose: when true (the default, matching the previous behaviour) the
            output shape and the loss tensor are printed for every batch.

    Returns:
        ``defaultdict(list)`` whose ``'loss'`` entry holds one float per batch.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=cfg.training.lr)
    train_metrics = defaultdict(list)

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')

        for times, events, lengths, _ in data_loader:
            # Derive the padding mask from the true lengths. Comparing times
            # against 0 would also mask a genuine event that happens at t == 0.
            seq_lens = torch.as_tensor(lengths, device=times.device)
            positions = torch.arange(times.shape[1], device=times.device)
            padding_mask = (positions[None, :] >= seq_lens[:, None])[..., None].expand_as(times)
            # Complement of the padding mask: True on real (non-padded) steps.
            ret_mask = ~padding_mask

            # First delta of every sequence is 0; new_zeros keeps the dtype and
            # device of `times`.
            time_deltas = torch.cat([times.new_zeros(times.shape[0], 1, times.shape[2]),
                                     times[:, 1:] - times[:, :-1]],
                                    dim=1)
            # The first padded step would otherwise hold `0 - last_time`, a
            # spurious negative interval; make every padded delta 0.
            time_deltas = time_deltas.masked_fill(padding_mask, 0)

            # Clear gradients before the backward pass so nothing accumulated
            # outside this loop leaks into the first update.
            optimizer.zero_grad()
            o_t = model(events, time_deltas, lengths)
            if verbose:
                print(o_t.shape)
            loss = model.compute_loss(time_deltas, padding_mask, ret_mask, o_t)
            loss.backward()
            optimizer.step()
            train_metrics['loss'].append(loss.item())
            if verbose:
                print(loss)
    return train_metrics

In [79]:
# Fresh RNNSM model built from the `rnnsm` section of the config; this rebinds
# the `model` name used by the RMTPP cell.
model = RNNSM(cfg.rnnsm)
# Adam over this model's parameters, learning rate taken from the config.
opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
# NOTE(review): the recorded run prints a shape and a loss per batch and was
# interrupted by hand during the first epoch.
train_rnnsm(data_loader, model, cfg)

Epoch 1/5....
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(1115.8245, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(509.7157, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(-365.4796, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(-707.6506, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(-3218.4873, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(-2914.9414, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size([32, 500, 1]) torch.Size([32, 500]) torch.Size([32, 500, 1])
tensor(-3494.3152, grad_fn=<AddBackward0>)
torch.Size([32, 500])
torch.Size

KeyboardInterrupt: 