In [2]:
import random
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = load_dataset("maiurilorenzo/divina-commedia", split="train")
train_size = int(len(dataset) * 0.8)

# Slicing a Hugging Face `datasets.Dataset` yields a dict mapping column name
# -> list of values; DivinaCommediaDataset below reads its "text" column.
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:]

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]"]

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
# Register the special tokens through the trainer rather than via
# `tokenizer.add_special_tokens` on the untrained model: ids handed out before
# training would overlap the ids of the freshly trained BPE vocabulary, and
# "[UNK]" would not be part of the BPE model vocab at all.
trainer = BpeTrainer(special_tokens=SPECIAL_TOKENS)

tokenizer.train_from_iterator(
    train_dataset["text"],
    trainer=trainer,
)

# Enable padding only once the vocabulary exists, using the id actually
# assigned to "[PAD]" instead of assuming it is 0.
tokenizer.enable_padding(pad_token="[PAD]", pad_id=tokenizer.token_to_id("[PAD]"))


class DivinaCommediaDataset(Dataset):
    """Map-style torch dataset over the ``"text"`` column of a column mapping.

    ``dataset`` is any object indexable by ``"text"`` whose value supports
    ``len()`` and integer indexing (e.g. the dict produced by slicing a
    Hugging Face dataset). Items are the raw strings, untokenized.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    @property
    def _texts(self):
        """The wrapped ``"text"`` column, looked up on every access."""
        return self.dataset["text"]

    def __len__(self):
        return len(self._texts)

    def __getitem__(self, index):
        return self._texts[index]


def collate_fn(batch):
    """Tokenize a batch of raw strings into a 2-D tensor of token ids.

    Uses the module-level ``tokenizer``. All rows must have equal length for
    ``torch.tensor`` to succeed, which relies on padding being enabled on the
    tokenizer.
    """
    encodings = tokenizer.encode_batch(batch)
    id_rows = [encoding.ids for encoding in encodings]
    return torch.tensor(id_rows)


# Wrap both column dicts so a DataLoader can index individual lines of text.
train_set, test_set = (
    DivinaCommediaDataset(dataset=train_dataset),
    DivinaCommediaDataset(dataset=test_dataset),
)

Generating train split: 100%|██████████| 14233/14233 [00:00<00:00, 1267894.16 examples/s]







In [4]:
import numpy as np


def get_emb(sin_inp):
    """Interleave sin and cos of ``sin_inp`` along its last dimension.

    For an input of shape (..., n) the result has shape (..., 2n), laid out as
    [sin(a0), cos(a0), sin(a1), cos(a1), ...].
    """
    pairs = torch.stack([torch.sin(sin_inp), torch.cos(sin_inp)], dim=-1)
    return pairs.flatten(start_dim=-2)


class PositionalEncoding(nn.Module):
    """Sinusoidal 1-D positional encoding for 3-D inputs.

    ``forward`` returns ONLY the encoding, with the same shape as its input;
    it does not add it to the input. The last result is cached and returned
    again while the input shape stays the same.

    Args:
        channels: encoding width. It is rounded up to an even number
            internally, and the output is sliced back to the input's last
            dimension in ``forward``.
        dtype_override: dtype of the returned encoding. When None, the input's
            dtype is used.
    """

    def __init__(self, channels, dtype_override=None):
        super().__init__()
        self.org_channels = channels
        # sin/cos are produced in pairs, so use an even channel count.
        even_channels = int(np.ceil(channels / 2) * 2)
        exponents = torch.arange(0, even_channels, 2).float() / even_channels
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))
        # Non-persistent: excluded from state_dict, rebuilt on demand.
        self.register_buffer("cached_penc", None, persistent=False)
        self.channels = even_channels
        self.dtype_override = dtype_override

    def forward(self, tensor):
        if tensor.dim() != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        cached = self.cached_penc
        if cached is not None and cached.shape == tensor.shape:
            return cached

        self.cached_penc = None
        n_batch, n_pos, n_ch = tensor.shape
        positions = torch.arange(n_pos, device=tensor.device, dtype=self.inv_freq.dtype)
        # angles[p, k] = p * inv_freq[k]
        angles = positions[:, None] * self.inv_freq[None, :]
        out_dtype = tensor.dtype if self.dtype_override is None else self.dtype_override
        table = torch.zeros((n_pos, self.channels), device=tensor.device, dtype=out_dtype)
        table.copy_(get_emb(angles))

        self.cached_penc = table[:, :n_ch].unsqueeze(0).repeat(n_batch, 1, 1)
        return self.cached_penc

In [7]:
class DLM(nn.Module):
    """Single-block masked diffusion language model (LLaDA-style).

    Reference: https://arxiv.org/pdf/2502.09992

    In ``forward``, a random fraction of the non-padding tokens is replaced by
    ``mask_idx``. The result goes through token embedding plus sinusoidal
    positional encoding, then one pre-norm transformer block (single-head
    self-attention and a feed-forward sub-layer), and finally a linear
    projection to vocabulary logits.

    Args:
        num_tokens: vocabulary size (embedding rows and logits width).
        emb_dim: model width.
        ff_dim: hidden width of the feed-forward sub-layer.
        pad_idx: id of the padding token. It is never masked and is excluded
            as an attention key.
        mask_idx: id written into masked positions.
    """

    def __init__(
        self,
        num_tokens: int,
        emb_dim: int,
        ff_dim: int,
        pad_idx: int,
        mask_idx: int,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.emb_dim = emb_dim
        self.ff_dim = ff_dim
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx

        self.emb_token = nn.Embedding(
            num_embeddings=self.num_tokens,
            embedding_dim=self.emb_dim,
            padding_idx=self.pad_idx,
        )
        self.emb_time = PositionalEncoding(self.emb_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=1,
            dropout=0.1,
            # Inputs are (batch, seq, emb). With the default batch_first=False,
            # attention would run across the batch instead of the sequence.
            batch_first=True,
        )
        self.ff = nn.Sequential(
            nn.Linear(self.emb_dim, self.ff_dim),
            nn.GELU(),
            nn.Linear(self.ff_dim, self.emb_dim),
            nn.Dropout(0.1),
        )
        self.norm1 = nn.LayerNorm(self.emb_dim)
        self.norm2 = nn.LayerNorm(self.emb_dim)
        self.dropout = nn.Dropout(0.1)
        self.logits = nn.Linear(self.emb_dim, self.num_tokens)

    def forward(self, x, mask_ratio: float):
        """Randomly mask ``x`` and predict logits for every position.

        Args:
            x: (batch, seq_len) integer token ids.
            mask_ratio: probability of masking each non-padding token.

        Returns:
            logits: (batch, seq_len, num_tokens).
            mask: (batch, seq_len) int tensor, 1 where a token was masked.
        """
        batch_size, seq_len = x.shape
        is_pad = x == self.pad_idx

        # Sample on the same device as the tokens so GPU inputs work.
        mask_probs = torch.rand(batch_size, seq_len, device=x.device)
        mask = (mask_probs < mask_ratio) & ~is_pad
        x = torch.where(mask, self.mask_idx, x)

        h = self.emb_token(x)
        # PositionalEncoding returns only the encoding, so ADD it; assigning
        # it would discard the token embeddings entirely.
        h = h + self.emb_time(h)

        # Do not attend to padding keys. Rows that are entirely padding are
        # left unmasked, otherwise the attention softmax would yield NaNs.
        key_padding_mask = is_pad & ~is_pad.all(dim=1, keepdim=True)
        h_norm = self.norm1(h)
        attn_output, _ = self.attn(
            h_norm,
            h_norm,
            h_norm,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        h = h + self.dropout(attn_output)
        # self.ff already ends with Dropout; don't apply dropout a second time.
        h = h + self.ff(self.norm2(h))
        logits = self.logits(h)

        return logits, mask.int()


def llada_loss(inputs, logits_pred, mask):
    """Mean cross-entropy over masked positions.

    Args:
        inputs: (batch, seq_len) target token ids (the unmasked sequence).
        logits_pred: (batch, seq_len, vocab_size) predicted logits. They need
            not be contiguous.
        mask: (batch, seq_len) 0/1 or bool tensor, 1 where the token was masked.

    Returns:
        Scalar tensor: sum of per-token cross-entropy at masked positions
        divided by the number of masked positions. It is 0 (not NaN) when no
        position is masked.
    """
    vocab_size = logits_pred.shape[-1]
    # reshape (not view) so that non-contiguous logits are accepted too.
    per_token = F.cross_entropy(
        logits_pred.reshape(-1, vocab_size),
        inputs.reshape(-1),
        reduction="none",
    )
    weights = mask.reshape(-1).to(per_token.dtype)
    # Clamp the denominator so an all-zero mask yields 0 instead of 0/0 = NaN.
    return (per_token * weights).sum() / weights.sum().clamp(min=1.0)


def train(
    lr: float,
    n_epochs: int,
    batch_size: int,
    emb_dim: int,
    ff_dim: int,
    mask_ratio: float,
    pad_idx: int,
    mask_idx: int,
    num_tokens: int,
):
    """Train one DLM on the module-level ``train_set`` and return it.

    The model and optimizer are built once, before the epoch loop, so every
    epoch keeps training the same weights. Batches in which no token ends up
    masked are skipped. The progress bar shows the mean loss over the
    optimizer steps taken so far in the current epoch.

    Args:
        lr: AdamW learning rate.
        n_epochs: number of passes over ``train_set``.
        batch_size: DataLoader batch size.
        emb_dim, ff_dim, pad_idx, mask_idx, num_tokens: forwarded to ``DLM``.
        mask_ratio: per-token masking probability. The loss is also divided
            by it.

    Returns:
        The trained ``DLM`` (left in training mode).
    """
    torch.manual_seed(23)

    model = DLM(
        num_tokens=num_tokens,
        emb_dim=emb_dim,
        ff_dim=ff_dim,
        pad_idx=pad_idx,
        mask_idx=mask_idx,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    print(
        f"Number of parameters={sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    train_loader = DataLoader(
        train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.
        n_steps = 0
        for x in (pbar := tqdm(train_loader)):
            optimizer.zero_grad()
            logits, mask = model(x, mask_ratio)

            if mask.sum() == 0:
                # Nothing was masked in this batch, so there is nothing to learn.
                continue

            loss = llada_loss(x, logits, mask) / mask_ratio
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_steps += 1
            pbar.set_description(
                f"epoch {epoch+1}/{n_epochs}: loss={running_loss/n_steps:.5f}")

    return model


In [None]:
lr = 1e-2
n_epochs = 1
batch_size = 32
emb_dim = 32
ff_dim = 512

# Seed Python's RNG so the sampled mask ratio is the same on every
# Restart & Run All.
random.seed(23)
mask_ratio = random.uniform(0.01, 0.99)
print(f"mask_ratio={mask_ratio}")

# Look the special-token ids up directly in the vocabulary rather than
# encoding the literal strings and taking the first id.
pad_token_id = tokenizer.token_to_id("[PAD]")
mask_token_id = tokenizer.token_to_id("[MASK]")
assert pad_token_id is not None and mask_token_id is not None, \
    "[PAD]/[MASK] are missing from the tokenizer vocabulary"
num_tokens = tokenizer.get_vocab_size()

model = train(
    lr=lr,
    n_epochs=n_epochs,
    batch_size=batch_size,
    emb_dim=emb_dim,
    ff_dim=ff_dim,
    mask_ratio=mask_ratio,
    pad_idx=pad_token_id,
    mask_idx=mask_token_id,
    num_tokens=num_tokens,
)