In [1]:
!pip install -U pytorch-lightning



In [2]:
import numpy as np

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import LinearLR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [3]:
from din import DIN
from data import DINDataset

In [4]:
class PLModel(pl.LightningModule):
    """LightningModule wrapper that trains and evaluates a ``DIN`` model.

    Each batch is expected to be a sequence of seven tensors, in this order:
    ``user, trg_movie, trg_genre, hist_movie, hist_genre, mask_id, label``.
    The first six are forwarded to ``DIN``; ``label`` is the cross-entropy
    target.

    Args:
        device: Device specifier forwarded to ``DIN(device=...)``.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        # Deliberately NOT stored as ``self._device``: LightningModule keeps its
        # own private ``_device`` attribute to back the ``device`` property, and
        # overwriting it with a plain string corrupts that state.
        self._requested_device = device
        self.model = DIN(device=device)

    def forward(self, user, trg_movie, trg_genre, hist_movie, hist_genre, mask_id):
        """Delegate to the wrapped ``DIN`` model and return its output."""
        return self.model(user, trg_movie, trg_genre, hist_movie, hist_genre, mask_id)

    def _forward_loss(self, batch):
        """Move ``batch`` to the module's device, run the model, compute the loss.

        Returns:
            Tuple ``(loss, outputs, label)``.
        """
        # ``self.device`` tracks where Lightning has actually placed this module.
        user, trg_movie, trg_genre, hist_movie, hist_genre, mask_id, label = (
            tensor.to(self.device) for tensor in batch
        )
        outputs = self(user, trg_movie, trg_genre, hist_movie, hist_genre, mask_id)
        # Functional form of nn.CrossEntropyLoss() with default settings; avoids
        # building a new loss module on every step.
        loss = F.cross_entropy(outputs, label)
        return loss, outputs, label

    def _evaluate(self, batch, stage, prog_bar=False):
        """Log ``{stage}_loss`` and ``{stage}_accuracy`` for one evaluation batch."""
        loss, outputs, label = self._forward_loss(batch)
        # argmax over raw outputs equals argmax over softmax(outputs), because
        # softmax is monotonic; the explicit softmax is therefore unnecessary.
        pred = torch.argmax(outputs, dim=-1)
        # Kept as a tensor so no device-to-host sync (.item()) is forced per batch.
        accuracy = (pred == label).float().mean()
        self.log_dict(
            {f'{stage}_loss': loss, f'{stage}_accuracy': accuracy},
            prog_bar=prog_bar,
        )

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss, _, _ = self._forward_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_accuracy`` (shown in the progress bar)."""
        self._evaluate(batch, 'val', prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_accuracy``."""
        self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Return AdamW (lr=1e-2) paired with a LinearLR scheduler.

        LinearLR is used with its default arguments.
        """
        optimizer = AdamW(self.parameters(), lr=1e-2)
        scheduler = LinearLR(optimizer)
        return [optimizer], [scheduler]

In [6]:
# Seed Python / NumPy / PyTorch RNGs before building the loaders and the model
# so that shuffling and weight initialisation are repeatable across runs.
SEED = 42
BATCH_SIZE = 128
pl.seed_everything(SEED, workers=True)

train_ds = DINDataset("./data/train.pkl")
# NOTE(review): the validation set is read from test.pkl, so early stopping is
# driven by that data — confirm this is intended.
val_ds = DINDataset("./data/test.pkl")
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

din = PLModel('cuda')
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    # Stop once val_loss stops decreasing.
    callbacks=[EarlyStopping(monitor='val_loss', mode='min')],
    min_epochs=5,
    max_epochs=20,
    gradient_clip_val=1.0,
)
trainer.fit(din, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | DIN  | 528 K 
-------------------------------
528 K     Trainable params
0         Non-trainable params
528 K     Total params
2.116     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [19]:
%load_ext tensorboard

%tensorboard --logdir="lightning_logs/version_1"

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6008 (pid 17820), started 0:03:04 ago. (Use '!kill 17820' to kill it.)

In [10]:
# Serialises the whole module object (not only its state_dict); loading the
# file later therefore needs the PLModel and DIN classes to be importable.
torch.save(obj=din, f="model_full.pth")