In [None]:
import re, os

import lucid
import lucid.nn as nn
import lucid.nn.functional as F
import lucid.optim as optim

from lucid._tensor import Tensor

from tokenizers import Tokenizer

import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    The table follows "Attention Is All You Need":
    PE[pos, 2i] = sin(pos / 10000^(2i/d_model)),
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model)).
    It is precomputed once for ``max_len`` positions and stored as the buffer
    ``pe`` with shape (1, max_len, d_model).

    Args:
        d_model: Embedding width (last axis of the input).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = lucid.zeros(max_len, d_model)
        # (max_len, 1) column of positions; broadcasts against the frequency row.
        position = lucid.arange(0, max_len, dtype=lucid.Float32).unsqueeze(axis=1)
        # 1 / 10000^(2i / d_model) for each even index 2i, computed in log space.
        div_term = lucid.exp(
            lucid.arange(0, d_model, 2, dtype=lucid.Float32)
            * (-lucid.log(1e4) / d_model)
        )

        pe[:, 0::2] = lucid.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than frequencies,
        # so trim div_term to keep the cos block's width equal to pe[:, 1::2].
        pe[:, 1::2] = lucid.cos(position * div_term[: d_model // 2])

        # Leading batch axis so the table broadcasts over a batch dimension.
        pe = pe.unsqueeze(axis=0)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the first ``x.shape[1]`` positional rows to ``x`` and apply dropout.

        Args:
            x: Embeddings laid out as (batch, seq_len, d_model); axis 1 is the
                sequence axis.

        Returns:
            A new tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        # Out-of-place add so the caller's tensor is never mutated.
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Wraps ``nn.Transformer`` with:
    - source and target token embeddings, scaled by sqrt(d_model);
    - a shared sinusoidal positional encoder;
    - a linear projection to ``tgt_vocab_size`` logits.

    Args:
        src_vocab_size: Number of source-side token ids.
        tgt_vocab_size: Number of target-side token ids (width of the logits).
        d_model: Embedding / model width.
        dim_feedforward: Hidden width of each feed-forward sub-layer.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dropout: Dropout probability for the positional encoder and the core
            transformer.
        pad_id: Padding token id. Used as ``padding_idx`` of both embeddings and
            to build key-padding masks when the caller passes none.
        tie_weights: If True, the output projection shares its weight with the
            target embedding and has no bias.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        dim_feedforward: int = 2048,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        pad_id: int = 0,
        tie_weights: bool = True,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.pad_id = pad_id

        self.src_embedding = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_id)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_id)

        self.positional_encoder = PositionalEncoding(d_model, dropout, max_len=5000)

        self.transformer = nn.Transformer(
            d_model=d_model,
            num_heads=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        self.out = nn.Linear(d_model, tgt_vocab_size, bias=not tie_weights)
        if tie_weights:
            # Same (tgt_vocab_size, d_model) parameter object as the target embedding.
            self.out.weight = self.tgt_embedding.weight

        # Apply the Xavier initialisation defined below to every Linear layer.
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Xavier-uniform init for every ``nn.Linear`` weight; zero their biases.

        NOTE(review): when ``tie_weights`` is True, ``self.out.weight`` *is* the
        target embedding weight, so this re-initialises the target embedding too.
        That presumably includes its ``pad_id`` row; confirm whether that row
        should be re-zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform(m.weight)
                if m.bias is not None:
                    nn.init.constant(m.bias, 0.0)

    @staticmethod
    def _mask_padding_mask(tokens: Tensor, pad_id: int) -> Tensor:
        """Return a boolean mask that is True wherever ``tokens == pad_id``."""
        return tokens == pad_id

    @staticmethod
    def _mask_square_subseq_mask(sz: int) -> Tensor:
        """Return an (sz, sz) additive causal mask: -inf above the diagonal, 0 elsewhere."""
        return lucid.triu(lucid.full((sz, sz), -lucid.inf), diagonal=1)

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor | None = None,
        src_pad_mask: Tensor | None = None,
        tgt_pad_mask: Tensor | None = None,
    ) -> Tensor:
        """Run encoder and decoder and return target-vocabulary logits.

        Args:
            src: Source token ids; axis 1 is the sequence axis (batch-first).
            tgt: Target (decoder-input) token ids; axis 1 is the sequence axis.
            tgt_mask: Optional causal mask for decoder self-attention. When None,
                a boolean mask is built that is True above the diagonal, i.e. on
                future positions.
            src_pad_mask: Optional source key-padding mask. When None, it is
                derived from ``src == pad_id``. It is also used as the memory
                key-padding mask.
            tgt_pad_mask: Optional target key-padding mask. When None, it is
                derived from ``tgt == pad_id``.

        Returns:
            Logits from ``self.out`` over the target vocabulary.
        """
        device = src.device

        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        scale = self.d_model ** 0.5
        src_emb = self.src_embedding(src) * scale
        tgt_emb = self.tgt_embedding(tgt) * scale

        src_emb = self.positional_encoder(src_emb)
        tgt_emb = self.positional_encoder(tgt_emb)

        if tgt_mask is None:
            T = tgt_emb.shape[1]
            tgt_mask = self._mask_square_subseq_mask(T).to(device)
            # -inf entries become True (masked), 0 entries become False (visible).
            tgt_mask = tgt_mask.astype(bool)

        if src_pad_mask is None:
            src_pad_mask = self._mask_padding_mask(src, self.pad_id)
        if tgt_pad_mask is None:
            tgt_pad_mask = self._mask_padding_mask(tgt, self.pad_id)

        x = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            mem_key_padding_mask=src_pad_mask,
        )

        logits = self.out(x)
        return logits

In [None]:
class TransformerScheduler(optim.lr_scheduler.LRScheduler):
    """Inverse-square-root ("Noam") learning-rate schedule.

        lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The schedule has two phases:
    - linear warm-up for ``warmup_steps`` steps;
    - decay proportional to 1/sqrt(step) afterwards.

    The optimizer's configured base learning rates are ignored: every param
    group receives the same schedule value.

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        d_model: Model width used in the d_model^-0.5 scale factor.
        warmup_steps: Number of warm-up steps; must be positive.
        last_epoch: Forwarded unchanged to the base scheduler.

    Raises:
        ValueError: If ``warmup_steps`` is not positive.
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        d_model: int,
        warmup_steps: int = 4000,
        last_epoch: int = -1,
    ) -> None:
        if warmup_steps <= 0:
            raise ValueError(f"warmup_steps must be positive, got {warmup_steps}")
        # Set schedule hyper-parameters *before* the base initializer, which may
        # call get_lr()/step() during construction (PyTorch's LRScheduler does).
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list[float]:
        """Return the Noam learning rate for the current step, once per param group."""
        # Clamp to 1 so step**-0.5 is defined before the first step() call.
        step = max(1, self._step_count)
        scale = self.d_model ** -0.5

        decay = step ** -0.5
        warmup = step * (self.warmup_steps ** -1.5)

        lr = scale * min(decay, warmup)
        return [lr for _ in self.base_lrs]

In [None]:
# NOTE(review): presumably the maximum token-sequence length used by later
# cells for padding/truncation — confirm against its uses.
max_length = 40

# Load the serialized tokenizer and record the vocabulary size, counting added
# (special) tokens as well.
tokenizer = Tokenizer.from_file("data/tokenizer.json")
vocab_size = tokenizer.get_vocab_size(with_added_tokens=True)

In [None]:
# Ids of the special tokens used for padding and sequence boundaries.
PAD_ID = tokenizer.token_to_id("[PAD]")
START_ID = tokenizer.token_to_id("[START]")
END_ID = tokenizer.token_to_id("[END]")

# `token_to_id` returns None for tokens absent from the vocabulary. Fail fast
# here instead of letting None reach `padding_idx` or the padding masks.
_missing_special_tokens = [
    token
    for token, token_id in (("[PAD]", PAD_ID), ("[START]", START_ID), ("[END]", END_ID))
    if token_id is None
]
if _missing_special_tokens:
    raise ValueError(
        f"Special tokens missing from tokenizer vocabulary: {_missing_special_tokens}"
    )