In [24]:
from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset, Dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

In [3]:
dataset = load_dataset("maiurilorenzo/divina-commedia", split="train")
train_size = int(len(dataset) * 0.8)

# Contiguous 80/20 split (no shuffling): the held-out rows are the last 20% of
# the dataset in its stored order. Slicing presumably yields a dict of columns;
# downstream code only relies on `["text"]` being an indexable list of strings.
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:]

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# "[PAD]" is listed first so it receives id 0 (checked below), which is the
# `padding_idx` used by the model's embedding layer.
trainer = BpeTrainer(special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"])

tokenizer.train_from_iterator(
    train_dataset["text"],
    trainer=trainer,
)

# Enable padding only once the vocabulary exists, and pass the id explicitly so
# the padded id and the padded token string can never disagree.
PAD_ID = tokenizer.token_to_id("[PAD]")
assert PAD_ID == 0, f"expected [PAD] to have id 0, got {PAD_ID}"
tokenizer.enable_padding(pad_id=PAD_ID, pad_token="[PAD]")







In [4]:
class DivinaCommediaDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over the raw verse strings of one split.

    The base class is spelled out as ``torch.utils.data.Dataset`` on purpose:
    in this notebook the bare name ``Dataset`` is rebound by
    ``from datasets import load_dataset, Dataset`` (imported after the torch
    one), so ``Dataset`` alone would refer to the Hugging Face class.

    Args:
        dataset: object whose ``["text"]`` entry is an indexable sequence of
            strings (e.g. the column dict obtained by slicing the HF dataset).
        tokenizer: stored for reference; tokenization itself happens in
            ``collate_fn``, not here.
    """

    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer
        # Look the column up once instead of on every __len__/__getitem__ call.
        self.texts = dataset["text"]

    def __len__(self):
        """Number of text rows in the split."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return the raw (untokenized) string at ``index``."""
        return self.texts[index]


def collate_fn(batch):
    """Turn a list of raw strings into padded token ids plus a random 0/1 mask.

    Uses the notebook-level ``tokenizer``; padding is enabled on it, so every
    row of the id matrix has the same length.

    Returns:
        ``(inputs, mask)``: an integer tensor of ids with one row per string,
        and an int64 tensor of the same shape whose entries are drawn
        uniformly from {0, 1}.
    """
    token_rows = [encoding.ids for encoding in tokenizer.encode_batch(batch)]
    inputs = torch.tensor(token_rows)
    random_mask = torch.randint(0, 2, inputs.shape)
    return inputs, random_mask


# Wrap each split for DataLoader use; both share the tokenizer trained above.
train_set = DivinaCommediaDataset(tokenizer=tokenizer, dataset=train_dataset)
test_set = DivinaCommediaDataset(tokenizer=tokenizer, dataset=test_dataset)

In [None]:
import numpy as np


def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined.

    Each input value v becomes the pair (sin v, cos v) along the last axis, so
    an input of shape (..., k) yields shape (..., 2 * k), ordered
    [sin v0, cos v0, sin v1, cos v1, ...].
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


class PositionalEncoding(nn.Module):
    """1-D sinusoidal positional encoding for 3-D ``(batch, seq_len, ch)`` inputs.

    ``forward`` returns only the encoding, repeated over the batch; it does NOT
    add it to the input, so the caller must do ``x + pe(x)``.

    Args:
        channels: feature width of the inputs. An odd value is rounded up to
            the next even number internally (sin/cos come in pairs) and the
            result is sliced back to the input's width.
        dtype_override: if given, the encoding is produced in this dtype
            instead of the input's dtype.
    """

    def __init__(self, channels, dtype_override=None):
        super(PositionalEncoding, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)
        self.channels = channels
        self.dtype_override = dtype_override

    def forward(self, tensor):
        """Return the positional encoding for a tensor shaped like ``tensor``.

        Args:
            tensor: 3-D tensor; only its shape, device and dtype are used.

        Returns:
            Tensor of shape ``(batch, seq_len, min(ch, self.channels))``,
            identical for every batch row. The same object is returned again
            (from the cache) for repeated requests with matching
            shape, dtype and device.

        Raises:
            RuntimeError: if ``tensor`` is not 3-D.
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        out_dtype = (
            self.dtype_override if self.dtype_override is not None else tensor.dtype
        )
        cached = self.cached_penc
        # Keying the cache on shape alone would hand back a tensor with a stale
        # dtype/device when a same-shaped input arrives in a different one.
        if (
            cached is not None
            and cached.shape == tensor.shape
            and cached.dtype == out_dtype
            and cached.device == tensor.device
        ):
            return cached

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        # Outer product: row i holds position i times every inverse frequency.
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device, dtype=out_dtype)
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc

In [42]:
# Vocabulary size (including the special tokens) sets the model's embedding/logit width.
num_tokens = tokenizer.get_vocab_size(with_added_tokens=True)


class DLM(nn.Module):
    """One pre-norm transformer block with a per-token vocabulary head.

    Token embeddings are summed with sinusoidal positional encodings, passed
    through a self-attention sub-layer and a feed-forward sub-layer (each
    preceded by a LayerNorm and wrapped in a residual connection), then
    projected to logits over the vocabulary.

    Args:
        num_tokens: vocabulary size (embedding rows and logits width).
        emb_dim: model width.
        ff_dim: hidden width of the feed-forward sub-layer.
    """

    def __init__(self, num_tokens: int, emb_dim: int, ff_dim: int = 512):
        super().__init__()
        self.num_tokens = num_tokens
        self.emb_dim = emb_dim
        self.ff_dim = ff_dim

        # Id 0 is the padding index: its embedding stays zero and gets no gradient.
        self.emb_token = nn.Embedding(
            num_embeddings=self.num_tokens,
            embedding_dim=self.emb_dim,
            padding_idx=0,
        )
        self.emb_time = PositionalEncoding(self.emb_dim)
        # Inputs are (batch, seq, emb). Without batch_first=True the layer reads
        # dim 0 as the sequence axis and attends across the batch instead.
        self.attn = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=1,
            dropout=0.1,
            batch_first=True,
        )
        self.ff = nn.Sequential(
            nn.Linear(self.emb_dim, self.ff_dim),
            nn.GELU(),
            nn.Linear(self.ff_dim, self.emb_dim),
            nn.Dropout(0.1),
        )
        self.norm1 = nn.LayerNorm(self.emb_dim)
        self.norm2 = nn.LayerNorm(self.emb_dim)
        self.dropout = nn.Dropout(0.1)
        self.logits = nn.Linear(self.emb_dim, self.num_tokens)

    def forward(self, x, key_padding_mask: Optional[torch.Tensor] = None):
        """Compute per-position vocabulary logits.

        Args:
            x: integer token ids of shape (batch, seq_len). After embedding,
                this must be 3-D, or ``PositionalEncoding`` raises.
            key_padding_mask: optional bool tensor (batch, seq_len); True marks
                key positions attention must ignore (e.g. padding). A row that is
                entirely True yields NaNs, so don't mask every position.

        Returns:
            Logits of shape (batch, seq_len, num_tokens).
        """
        h = self.emb_token(x)
        # PositionalEncoding returns only the encoding, so add it to the token
        # embeddings instead of overwriting them.
        h = h + self.emb_time(h)
        h_norm = self.norm1(h)
        attn_output, _ = self.attn(
            h_norm,
            h_norm,
            h_norm,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        h = h + self.dropout(attn_output)
        # self.ff already ends in Dropout; applying self.dropout again would
        # drop the same branch twice.
        h = h + self.ff(self.norm2(h))
        return self.logits(h)


def train(n_epochs: int, batch_size: int, emb_dim: int, lr: float = 1e-3):
    """Train a fresh DLM on ``train_set`` with a masked-token reconstruction loss.

    Each batch from ``collate_fn`` carries a random 0/1 mask. Positions where
    the mask is 0 are replaced by id 0 in the model input, and the model is
    scored (cross-entropy) on recovering the original token there.

    Args:
        n_epochs: number of passes over ``train_set``.
        batch_size: sequences per batch.
        emb_dim: model width passed to ``DLM``.
        lr: AdamW learning rate.

    Returns:
        The trained ``DLM``.
    """
    torch.manual_seed(23)
    # Create the model/optimizer once: re-creating them inside the epoch loop
    # would discard everything learned in the previous epoch.
    model = DLM(num_tokens=num_tokens, emb_dim=emb_dim)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    train_loader = DataLoader(
        train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(n_epochs):
        total_loss, n_batches = 0.0, 0
        for x, m in train_loader:
            # m == 1 keeps a token; m == 0 replaces it with id 0 (the DLM's
            # padding_idx). NOTE(review): masked slots then look exactly like
            # padding to the model; substituting the "[MASK]" id may be intended.
            x_masked = x * m
            logits = model(x_masked)
            # Score only masked-out positions that are not padding. Assumes the
            # tokenizer pads with id 0 (the DLM's padding_idx) — TODO confirm.
            target_mask = (m == 0) & (x != 0)
            if not target_mask.any():
                continue
            loss = F.cross_entropy(logits[target_mask], x[target_mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        print(f"epoch {epoch + 1}/{n_epochs}: mean masked-token loss "
              f"{total_loss / max(n_batches, 1):.4f}")
    return model


train(
    emb_dim=32,
    batch_size=8,
    n_epochs=1,
);