In [None]:
# ! pip install pytorch_lightning

In [None]:
# Standard libraries
import math
import os
import urllib.request
from functools import partial
from urllib.error import HTTPError

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# PyTorch Lightning
import pytorch_lightning as pl
import seaborn as sns

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

# Torchvision
import torchvision
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm.notebook import tqdm

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape ``(..., L_q, d_k)``.
        k: Key tensor of shape ``(..., L_k, d_k)``.
        v: Value tensor of shape ``(..., L_k, d_v)``.
        mask: Optional tensor broadcastable to ``(..., L_q, L_k)``. Positions
            where ``mask == 0`` are excluded from attention.

    Returns:
        A tuple ``(values, attention)`` where ``values`` has shape
        ``(..., L_q, d_v)`` and ``attention`` has shape ``(..., L_q, L_k)``
        with every row summing to 1 over the key axis.
    """
    d_k = q.size(-1)
    # Scale by sqrt(d_k) so the logits keep a comparable magnitude as d_k grows.
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # attn_logits is a fresh tensor created above, so in-place filling is safe.
        attn_logits.masked_fill_(mask == 0, -9e15)
    # Normalise over the key axis. Omitting ``dim`` is deprecated and picks
    # dim=1 for 4-D inputs, which would normalise over the wrong axis.
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        input_dim: Size of the last dimension of the input.
        embed_dim: Total width of the attention output; split evenly
            across the heads.
        num_heads: Number of attention heads. Must divide ``embed_dim``.
    """

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One linear layer produces q, k and v for all heads at once.
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters_()

    def _reset_parameters_(self):
        """Xavier-initialise the projection weights and zero their biases."""
        # The init functions operate on tensors, not on nn.Linear modules.
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.zeros_(self.qkv_proj.bias)
        nn.init.xavier_uniform_(self.o_proj.weight)
        nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, return_attention=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_length, input_dim)``.
            mask: Optional mask forwarded unchanged to ``scaled_dot_product``;
                it is applied to logits of shape
                ``(batch, num_heads, seq_length, seq_length)``.
            return_attention: If True, also return the attention weights.

        Returns:
            The output tensor of shape ``(batch, seq_length, embed_dim)``, or
            a tuple ``(output, attention)`` when ``return_attention`` is True.
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)

        # (batch, seq, 3*embed) -> (batch, heads, seq, 3*head_dim) -> q, k, v
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        values, attention = scaled_dot_product(q, k, v, mask)

        # (batch, heads, seq, head_dim) -> (batch, seq, embed). reshape is not
        # in-place, so its result has to be assigned.
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        return o

In [None]:
class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention followed by an MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        input_dim: Feature width of the input; also the width of the output.
        num_heads: Number of heads used by the self-attention sub-layer.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used on both residual branches and
            inside the feed-forward sub-layer.
    """

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()
        self.input_dim, self.num_heads = input_dim, num_heads

        # Self-attention sub-layer; the embedding width equals the input width.
        self.multihead_attention = MultiheadAttention(
            input_dim=input_dim, embed_dim=input_dim, num_heads=num_heads
        )
        self.dropout = nn.Dropout(p=dropout)
        self.norm1 = nn.LayerNorm(input_dim)

        # Feed-forward sub-layer: expand, regularise, activate, project back.
        mlp_layers = [
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(p=dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is handed to the attention layer."""
        # Residual connection around attention, then normalise.
        attended = self.dropout(self.multihead_attention(x, mask))
        hidden = self.norm1(x + attended)

        # Residual connection around the MLP, then normalise.
        transformed = self.dropout(self.feed_forward(hidden))
        return self.norm2(hidden + transformed)

In [None]:
class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` identically configured ``EncoderBlock`` layers.

    Args:
        num_layers: How many encoder blocks to stack.
        **block_args: Keyword arguments passed to every ``EncoderBlock``.
    """

    def __init__(self, num_layers, **block_args):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(EncoderBlock(**block_args))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask=None):
        """Pass ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask)
        return hidden

    def get_attention_maps(self, x, mask=None):
        """Return a list with the attention weights of each block, in order.

        Each block's attention layer is evaluated on that block's input, and
        the block is then applied to obtain the input of the next one.
        """
        maps = []
        hidden = x
        for block in self.layers:
            attn_output = block.multihead_attention(
                hidden, mask=mask, return_attention=True
            )
            maps.append(attn_output[1])
            hidden = block(hidden, mask=mask)
        return maps

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of features.

    Even feature indices receive ``sin(pos / 10000^(i / d_model))`` and odd
    indices the matching ``cos``, following "Attention Is All You Need".

    Args:
        d_model: Feature width of the tensors passed to ``forward``.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per sin/cos pair, so the range runs over d_model
        # (the feature axis), not over max_len.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cos column fewer than sin columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # leading axis of size 1 broadcasts over the batch
        # Non-persistent buffer: follows the module across devices but is not
        # written to the state dict.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_length, d_model)`` with
                ``seq_length <= max_len``.
        """
        x = x + self.pe[:, : x.size(1)]
        return x

In [None]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """Cosine learning-rate decay with a linear warm-up phase.

    The base learning rate of every parameter group is multiplied by
    ``0.5 * (1 + cos(pi * step / max_itrs))`` and, while
    ``step <= warmup``, additionally by ``step / warmup``.

    Args:
        optimizer: The wrapped optimizer.
        warmup: Number of warm-up steps. ``0`` disables warm-up.
        max_itrs: Step at which the cosine factor reaches zero.
    """

    def __init__(self, optimizer, warmup, max_itrs):
        # These must exist before the base constructor runs: it performs an
        # initial step(), which calls get_lr() and therefore reads them.
        self.warmup = warmup
        self.max_itrs = max_itrs
        super().__init__(optimizer)

    def get_lr(self):
        """Return the current learning rate for every parameter group."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplicative factor applied to the base rates at ``epoch``."""
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_itrs))
        # Skip the warm-up ramp when warmup == 0 to avoid dividing by zero.
        if self.warmup > 0 and epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [None]:
class TransformerPredictor(pl.LightningModule):
    """Lightning module wrapping a Transformer encoder with a per-position head.

    Args:
        input_dim: Feature width of the raw input.
        model_dim: Feature width used inside the Transformer.
        num_classes: Number of outputs produced for every sequence position.
        num_heads: Number of attention heads in each encoder block.
        num_layers: Number of encoder blocks.
        lr: Learning rate passed to Adam.
        warmup: Number of warm-up steps for the learning-rate schedule.
        max_iters: Number of steps over which the cosine schedule decays.
        dropout: Dropout probability inside the encoder and the output head.
        input_dropout: Dropout probability applied after the input projection.
    """

    def __init__(self, input_dim, model_dim, num_classes, num_heads, num_layers, lr, warmup, max_iters, dropout=0.0, input_dropout=0.0):
        super().__init__()
        # Exposes every constructor argument as self.hparams.<name>.
        self.save_hyperparameters()
        self.__create_model()

    def __create_model(self):
        """Build the input projection, positional encoding, encoder and head."""
        self.input_nn = nn.Sequential(
            nn.Linear(self.hparams.input_dim, self.hparams.model_dim),
            nn.Dropout(self.hparams.input_dropout)
        )
        # The encoding is added after input_nn, so it has to match model_dim
        # rather than input_dim.
        self.position_encoding = PositionalEncoding(d_model=self.hparams.model_dim)
        self.transformer = TransformerEncoder(num_layers=self.hparams.num_layers,
                                              input_dim=self.hparams.model_dim,
                                              num_heads=self.hparams.num_heads,
                                              dim_feedforward=self.hparams.model_dim,
                                              dropout=self.hparams.dropout)
        self.output_nn = nn.Sequential(
            nn.Linear(self.hparams.model_dim, self.hparams.model_dim),
            nn.LayerNorm(self.hparams.model_dim),
            nn.LeakyReLU(),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(self.hparams.model_dim, self.hparams.num_classes)
        )

    def forward(self, x, mask=None, add_positional_encoding=True):
        """Compute per-position outputs for ``x``.

        Args:
            x: Input whose last dimension has size ``input_dim``.
            mask: Optional attention mask forwarded to the encoder.
            add_positional_encoding: If True, add the positional encoding
                after the input projection.

        Returns:
            Tensor whose last dimension has size ``num_classes``.
        """
        x = self.input_nn(x)
        if add_positional_encoding:
            x = self.position_encoding(x)
        x = self.transformer(x, mask)
        x = self.output_nn(x)
        return x

    def configure_optimizers(self):
        """Create the Adam optimizer and the warm-up/cosine scheduler.

        Only the optimizer is returned; the scheduler is kept on ``self`` and
        advanced manually in ``optimizer_step``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        # The hyperparameter is named max_iters; the scheduler's keyword is max_itrs.
        self.lr_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=self.hparams.warmup, max_itrs=self.hparams.max_iters)
        return optimizer

    def optimizer_step(self, *args, **kwargs):
        """Run the default optimizer step, then advance the scheduler by one step."""
        super().optimizer_step(*args, **kwargs)
        self.lr_scheduler.step()