In [1]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from hydra.experimental import compose, initialize
from omegaconf import DictConfig
import numpy as np
from collections import defaultdict

%load_ext autoreload
%autoreload 2

In [2]:
from RNNSM.rnnsm import RNNSM
from RMTPP.rmtpp import RMTPP

In [3]:
class ATMDataset(Dataset):
    """Event-sequence dataset built from a CSV with `id`, `time` and `event` columns.

    The rows are cut into sequences: a sequence is a run of consecutive rows
    that share the same `id`, truncated to at most `max_len` rows (a longer
    run is split into several sequences).

    Args:
        path: location of the CSV file, passed straight to `pd.read_csv`.
        max_len: upper bound on the number of rows in one sequence.
        normalization: accepted for interface compatibility; it is currently
            not used anywhere in this class.

    Attributes:
        ids, times, events: the raw columns as pandas Series. `events` is the
            `event` column shifted by +1 (presumably so that 0 stays free for
            the padding value used by `pad_collate` -- TODO confirm).
        time_seqs, event_seqs: parallel lists of 1-D tensors, one per sequence.
    """

    def __init__(self, path, max_len=500, normalization='none'):
        super().__init__()
        self.max_len = max_len
        data = pd.read_csv(path)
        self.ids = data['id']
        self.times = data['time']
        self.events = data['event'] + 1
        self.time_seqs, self.event_seqs = self.generate_sequence()

    def generate_sequence(self):
        """Split the loaded columns into per-id sequences.

        Returns:
            A pair `(time_seqs, event_seqs)` of equally long lists. Entry `k`
            of each list is a 1-D tensor (float times / long events) covering
            the same slice of rows.
        """
        # Convert once to NumPy: scalar lookups on a pandas Series inside a
        # Python loop are slow and are label-based, so they would silently
        # depend on the Series carrying a default RangeIndex.
        ids = self.ids.to_numpy()
        times = self.times.to_numpy()
        events = self.events.to_numpy()
        n_rows = len(ids)

        time_seqs = []
        event_seqs = []
        cur_start, cur_end = 0, 1

        # Slice the dataset into sequences that share one id and contain at
        # most self.max_len rows.
        while cur_start < n_rows:
            extend_run = (cur_end < n_rows
                          and ids[cur_start] == ids[cur_end]
                          and cur_end - cur_start < self.max_len)
            if extend_run:
                cur_end += 1
                continue

            # The run [cur_start, cur_end) is complete: emit it and start the
            # next run at the row that ended this one.
            time_seqs.append(torch.Tensor(times[cur_start:cur_end]))
            event_seqs.append(torch.LongTensor(events[cur_start:cur_end]))
            cur_start, cur_end = cur_end, cur_end + 1

        return time_seqs, event_seqs

    def __getitem__(self, item):
        """Return the `(times, events)` tensor pair of sequence number `item`."""
        return self.time_seqs[item], self.event_seqs[item]

    def __len__(self):
        """Return the number of sequences."""
        return len(self.time_seqs)

    
def pad_collate(batch):
    """Collate `(times, events)` samples of unequal length into padded tensors.

    Args:
        batch: iterable of `(time_seq, event_seq)` pairs of 1-D tensors.

    Returns:
        A 4-tuple `(times, events, time_lens, event_lens)`. `times` and
        `events` are zero-padded to the longest sequence in the batch, laid
        out batch-first, and given a trailing singleton dimension. The two
        length lists hold the unpadded length of every sequence.
    """
    time_seqs, event_seqs = zip(*batch)

    # Remember the true lengths before padding hides them.
    time_lens = list(map(len, time_seqs))
    event_lens = list(map(len, event_seqs))

    padded_times = pad_sequence(time_seqs, batch_first=True, padding_value=0)
    padded_events = pad_sequence(event_seqs, batch_first=True, padding_value=0)

    # Append a feature axis: (batch, seq) -> (batch, seq, 1).
    return padded_times[..., None], padded_events[..., None], time_lens, event_lens

In [29]:
# Load the training CSV into per-id sequences and wrap it in a shuffled
# DataLoader. pad_collate pads every batch to its longest sequence, and
# drop_last=True discards a final batch smaller than batch_size.
ds = ATMDataset('data/ATM/train_day.csv')
data_loader = DataLoader(dataset=ds, batch_size=32, shuffle=True, collate_fn=pad_collate, drop_last=True)

In [5]:
# Initialize Hydra's experimental compose API; config files are looked up
# under config_path. NOTE(review): this must run before the compose() cell.
initialize(config_path=".")

hydra.experimental.initialize()

In [19]:
# Compose the experiment configuration; later cells read cfg.training,
# cfg.rmtpp and cfg.rnnsm from it.
cfg = compose(config_name="config.yaml")

In [30]:
def train_rmtpp(data_loader, model, cfg: DictConfig, opt=None):
    """Train an RMTPP model and return the per-epoch mean training loss.

    Args:
        data_loader: iterable of batches `(times, events, lengths, _)`, where
            `times` and `events` are padded tensors indexed as
            (batch, seq_len, feature) and `lengths` holds the unpadded length
            of every sequence (the layout produced by `pad_collate`).
        model: callable as `model(events, time_deltas, lengths)` returning
            `(o_t, y_t)`, and exposing
            `compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)`.
        cfg: configuration providing `cfg.training.n_epochs` and, when `opt`
            is omitted, `cfg.training.lr`.
        opt: optimizer to step. When omitted, an Adam optimizer over
            `model.parameters()` with learning rate `cfg.training.lr` is
            created, so the function no longer reads a module-level `opt`
            (which could be undefined or belong to a different model).

    Returns:
        A `defaultdict(list)` whose `'loss'` entry has one mean loss per epoch.
    """
    if opt is None:
        opt = optim.Adam(model.parameters(), lr=cfg.training.lr)

    train_metrics = defaultdict(list)

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')
        losses = []
        for times, events, lengths, _ in data_loader:
            # Inter-event gaps; the first position has no predecessor and gets
            # a zero gap. zeros_like keeps dtype/device aligned with `times`.
            time_deltas = torch.cat([torch.zeros_like(times[:, :1]),
                                     times[:, 1:] - times[:, :-1]],
                                    dim=1)

            # Mark padded positions from the true sequence lengths. Comparing
            # `times` with 0 would also flag genuine events that happen at
            # time 0 as padding.
            positions = torch.arange(times.shape[1], device=times.device)
            seq_lens = torch.as_tensor(lengths, device=times.device)
            padding_mask = (positions[None, :] >= seq_lens[:, None])[..., None]

            # Clear gradients before the forward pass so that gradients left
            # over from earlier work cannot leak into the first step.
            opt.zero_grad()
            o_t, y_t = model(events, time_deltas, lengths)
            loss = model.compute_loss(time_deltas, padding_mask, o_t, y_t[0], events)
            loss.backward()
            opt.step()
            losses.append(loss.item())
        train_metrics['loss'].append(np.mean(losses))
        print(train_metrics['loss'][-1])
    return train_metrics

In [42]:
# Build the RMTPP model from its config section, create an Adam optimizer over
# its parameters and run training; the returned metrics dict is the cell output.
model = RMTPP(cfg.rmtpp)
opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
train_rmtpp(data_loader, model, cfg)

Epoch 1/10....
-2.1617333507363465
Epoch 2/10....
-3.1642438644107354
Epoch 3/10....
-3.6396667602214405
Epoch 4/10....
-3.7764306677148696
Epoch 5/10....
-3.0389498878032604
Epoch 6/10....
-3.2186425619937005
Epoch 7/10....
-2.912486532901196
Epoch 8/10....
-3.494121346067875
Epoch 9/10....
-3.2129714057800616
Epoch 10/10....
-3.8959539481934082


defaultdict(list,
            {'loss': [-2.1617333507363465,
              -3.1642438644107354,
              -3.6396667602214405,
              -3.7764306677148696,
              -3.0389498878032604,
              -3.2186425619937005,
              -2.912486532901196,
              -3.494121346067875,
              -3.2129714057800616,
              -3.8959539481934082]})

In [39]:
def train_rnnsm(data_loader, model, cfg: DictConfig, opt=None):
    """Train an RNNSM model and return the per-epoch mean training loss.

    Args:
        data_loader: iterable of batches `(times, events, lengths, _)`, where
            `times` and `events` are padded tensors indexed as
            (batch, seq_len, feature) and `lengths` holds the unpadded length
            of every sequence (the layout produced by `pad_collate`).
        model: callable as `model(events, time_deltas, lengths)` returning
            `o_t`, and exposing
            `compute_loss(time_deltas, padding_mask, ret_mask, o_t)`.
        cfg: configuration providing `cfg.training.n_epochs` and, when `opt`
            is omitted, `cfg.training.lr`.
        opt: optimizer to step. When omitted, an Adam optimizer over
            `model.parameters()` with learning rate `cfg.training.lr` is
            created, so the function no longer reads a module-level `opt`
            (which could be undefined or belong to a different model).

    Returns:
        A `defaultdict(list)` whose `'loss'` entry has one mean loss per epoch.
    """
    if opt is None:
        opt = optim.Adam(model.parameters(), lr=cfg.training.lr)

    train_metrics = defaultdict(list)

    for epoch in range(cfg.training.n_epochs):
        print(f'Epoch {epoch+1}/{cfg.training.n_epochs}....')
        losses = []
        for times, events, lengths, _ in data_loader:
            # Inter-event gaps; the first position has no predecessor and gets
            # a zero gap. zeros_like keeps dtype/device aligned with `times`.
            time_deltas = torch.cat([torch.zeros_like(times[:, :1]),
                                     times[:, 1:] - times[:, :-1]],
                                    dim=1)

            # Mark padded positions from the true sequence lengths. Comparing
            # `times` with 0 would also flag genuine events that happen at
            # time 0 as padding.
            positions = torch.arange(times.shape[1], device=times.device)
            seq_lens = torch.as_tensor(lengths, device=times.device)
            padding_mask = (positions[None, :] >= seq_lens[:, None])[..., None]
            ret_mask = ~padding_mask

            # Clear gradients before the forward pass so that gradients left
            # over from earlier work cannot leak into the first step.
            opt.zero_grad()
            o_t = model(events, time_deltas, lengths)
            loss = model.compute_loss(time_deltas, padding_mask, ret_mask, o_t)
            loss.backward()
            opt.step()
            losses.append(loss.item())
        train_metrics['loss'].append(np.mean(losses))
        print(train_metrics['loss'][-1])
    return train_metrics

In [40]:
# Build the RNNSM model from its config section, create an Adam optimizer over
# its parameters and run training; the returned metrics dict is the cell output.
# Note that this rebinds the notebook-level names `model` and `opt`.
model = RNNSM(cfg.rnnsm)
opt = optim.Adam(model.parameters(), lr=cfg.training.lr)
train_rnnsm(data_loader, model, cfg)

Epoch 1/10....
-3.1207775065397962
Epoch 2/10....
-6.634248429156364
Epoch 3/10....
-7.0342806045045245
Epoch 4/10....
-7.286588881878143
Epoch 5/10....
-7.377609050020259
Epoch 6/10....
-7.448889072905195
Epoch 7/10....
-7.542163503930924
Epoch 8/10....
-7.602628535412728
Epoch 9/10....
-7.629863891195743
Epoch 10/10....
-7.657119903158634


defaultdict(list,
            {'loss': [-3.1207775065397962,
              -6.634248429156364,
              -7.0342806045045245,
              -7.286588881878143,
              -7.377609050020259,
              -7.448889072905195,
              -7.542163503930924,
              -7.602628535412728,
              -7.629863891195743,
              -7.657119903158634]})