In [286]:
import torch
from torch import nn
import torch.nn.functional as F

from datasets import load_dataset, Dataset

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from torch.utils.data import Dataset, DataLoader

In [119]:
# Load the corpus and split it sequentially: the first 80% of the rows are
# used for training, the remaining 20% for testing (no shuffling).
dataset = load_dataset("maiurilorenzo/divina-commedia", split="train")
train_size = int(len(dataset) * 0.8)

# Slicing a `datasets.Dataset` returns a plain dict of column name -> list,
# which is why the splits are indexed as `split["text"]` further down.
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:]

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

# Fit the BPE vocabulary on the training split only.
tokenizer.train_from_iterator(
    train_dataset["text"],
    trainer=trainer,
)

# Enable padding only after training, once "[PAD]" has a vocabulary id.
# `enable_padding` defaults to pad_id=0, and id 0 goes to the first special
# token handed to the trainer ("[UNK]" here), so relying on the default would
# pad batches with the unknown-token id instead of the "[PAD]" id.
tokenizer.enable_padding(
    pad_id=tokenizer.token_to_id("[PAD]"),
    pad_token="[PAD]",
)







In [284]:
class DivinaCommediaDataset(Dataset):
    """Map-style dataset over the ``"text"`` column of a data split.

    Entries are returned unchanged; this class does no tokenization itself.
    """

    def __init__(self, dataset, tokenizer):
        """Store the split and resolve its text column.

        Args:
            dataset: Mapping-like object whose ``"text"`` entry is an
                indexable, sized sequence (e.g. the dict produced by slicing
                a Hugging Face dataset).
            tokenizer: Tokenizer associated with this split. It is kept on
                the instance but not applied by this class.

        Raises:
            KeyError: If ``dataset`` has no ``"text"`` entry.
        """
        self.dataset = dataset
        # The argument used to be dropped silently; keep it so the instance
        # carries everything it was constructed with.
        self.tokenizer = tokenizer
        # Resolve the column once instead of on every __len__/__getitem__
        # call; for column-backed containers each lookup can be costly.
        self._texts = dataset["text"]

    def __len__(self):
        """Return the number of entries in the text column."""
        return len(self._texts)

    def __getitem__(self, index):
        """Return the text entry stored at position ``index``."""
        return self._texts[index]


def collate_fn(batch):
    """Encode a batch of texts and draw a random 0/1 mask of the same shape.

    Uses the module-level ``tokenizer``. ``torch.tensor`` needs every row to
    have the same length, so the tokenizer must pad the encoded batch.

    Args:
        batch: Sequence of inputs accepted by ``tokenizer.encode_batch``.

    Returns:
        A tuple ``(inputs, mask)``: ``inputs`` holds the token ids of each
        encoded item, and ``mask`` is an integer tensor of the same shape
        whose entries are drawn uniformly from {0, 1}.
    """
    encodings = tokenizer.encode_batch(batch)
    token_ids = [encoding.ids for encoding in encodings]
    inputs = torch.tensor(token_ids)
    # high is exclusive, so the mask contains only 0s and 1s.
    mask = torch.randint(low=0, high=2, size=inputs.shape)
    return inputs, mask


# Wrap both splits in the same dataset class, sharing the one tokenizer.
train_set, test_set = (
    DivinaCommediaDataset(dataset=split, tokenizer=tokenizer)
    for split in (train_dataset, test_dataset)
)

In [None]:
# Size of the trained vocabulary, counting the added special tokens; used as
# the number of rows of the model's embedding table.
num_tokens = tokenizer.get_vocab_size(with_added_tokens=True)


class DLM(nn.Module):
    """Model consisting of a single token-embedding table."""

    def __init__(self, num_tokens: int, emb_dim: int):
        """Create the embedding table.

        Args:
            num_tokens: Number of distinct token ids (rows of the table).
            emb_dim: Size of each embedding vector.
        """
        super().__init__()
        self.num_tokens, self.emb_dim = num_tokens, emb_dim
        self.emb_token = nn.Embedding(num_tokens, emb_dim)

    def forward(self, x):
        """Return the embedding vectors for the integer token ids in ``x``.

        The result has the shape of ``x`` with one trailing axis of size
        ``emb_dim`` appended.
        """
        return self.emb_token(x)


def train(n_epochs: int, batch_size: int, emb_dim: int):
    """Run forward passes of a ``DLM`` over ``train_set`` for several epochs.

    NOTE(review): no loss, optimizer or return value exists yet, so this
    currently only exercises the data pipeline and the model's forward pass.

    Args:
        n_epochs: Number of passes over the training data.
        batch_size: Number of dataset items per batch.
        emb_dim: Embedding size passed to ``DLM``.
    """
    torch.manual_seed(23)

    # Build the model once. Constructing it inside the epoch loop would
    # re-initialise the weights at the start of every epoch.
    model = DLM(num_tokens=num_tokens, emb_dim=emb_dim)

    # One loader is enough: with shuffle=True a fresh order is drawn each
    # time the loader is iterated, i.e. once per epoch.
    train_loader = DataLoader(
        train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

    for epoch in range(n_epochs):
        for x, m in train_loader:
            # Positions where the mask is 0 collapse to token id 0.
            # NOTE(review): id 0 is presumably "[UNK]" given the trainer's
            # special-token order; consider substituting the "[MASK]" id.
            x_masked = x * m
            # Feed the masked ids: the unmasked batch was passed before,
            # which left `x_masked` unused.
            y_pred = model(x_masked)