In [1]:
import random
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from aptorch.data import DivinaCommediaDataset, divina_commedia
from aptorch.dlm import DLM, llada_loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Raw train/test splits provided by the project's loader.
train_dataset, test_dataset = divina_commedia()

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Register the special tokens through the trainer so they become part of the
# learned vocabulary. Adding them to the still-untrained tokenizer assigns
# their ids against an empty vocabulary, and those ids can then clash with
# the ids handed out to ordinary tokens during training.
# List order determines the ids: [PAD]=0, [UNK]=1, [MASK]=2.
trainer = BpeTrainer(special_tokens=["[PAD]", "[UNK]", "[MASK]"])

tokenizer.train_from_iterator(
    train_dataset["text"],
    trainer=trainer,
)

# Configure padding only once the vocabulary exists, using the id the
# tokenizer actually reports instead of a hard-coded guess.
assert tokenizer.token_to_id("[PAD]") == 0, "[PAD] is expected to have id 0"
tokenizer.enable_padding(pad_token="[PAD]", pad_id=tokenizer.token_to_id("[PAD]"))


def collate_fn(batch):
    """Tokenize a batch of raw samples and stack the token ids into a tensor.

    Uses the module-level ``tokenizer``; padding is enabled on it, which is
    what allows the per-sample id lists to form a rectangular tensor.

    Args:
        batch: samples as yielded by the dataset (presumably a list of
            strings -- confirm against ``DivinaCommediaDataset``).

    Returns:
        A tensor built from one row of token ids per sample.
    """
    encodings = tokenizer.encode_batch(batch)
    token_ids = [encoding.ids for encoding in encodings]
    return torch.tensor(token_ids)


# Wrap each raw split in the project's dataset class; `train_set` is what the
# training loop hands to its DataLoader.
train_set, test_set = (
    DivinaCommediaDataset(dataset=split)
    for split in (train_dataset, test_dataset)
)






In [3]:

def train(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    lr: float,
    n_epochs: int,
    batch_size: int,
    emb_dim: int,
    ff_dim: int,
    mask_ratio: float,
    pad_idx: int,
    mask_idx: int,
    num_tokens: int,
):
    """Train ``model`` on the module-level ``train_set`` using ``llada_loss``.

    The function reads two notebook globals: ``train_set`` (the dataset) and
    ``collate_fn`` (the batch tokenizer). Progress and the running mean loss
    are shown on a tqdm bar, one bar per epoch.

    Args:
        model: network called as ``model(x, mask_ratio)``; must return a
            ``(logits, mask)`` pair.
        optim: optimizer already bound to ``model``'s parameters.
        n_epochs: number of passes over ``train_set``.
        batch_size: batch size given to the ``DataLoader``.
        mask_ratio: value forwarded to the model and used to rescale the
            loss (``loss / mask_ratio``). Must be strictly positive.
        lr, emb_dim, ff_dim, pad_idx, mask_idx, num_tokens: accepted for
            call-site compatibility but not used inside this function.

    Raises:
        ValueError: if ``mask_ratio`` is not strictly positive, since the
            loss is divided by it.
    """
    # Dividing by zero would give an inf loss and a negative ratio would flip
    # its sign; fail early instead of silently corrupting the parameters.
    if mask_ratio <= 0:
        raise ValueError(f"mask_ratio must be positive, got {mask_ratio}")

    torch.manual_seed(23)

    # The loader does not change between epochs, so build it once; with
    # shuffle=True every new iteration over it draws a fresh permutation.
    train_loader = DataLoader(
        train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.
        n_steps = 0  # batches that actually produced an optimizer step
        for x in (pbar := tqdm(train_loader)):
            logits, mask = model(x, mask_ratio)

            # Nothing was masked in this batch: there is no loss to compute.
            if mask.sum() == 0:
                continue

            optim.zero_grad()
            loss = llada_loss(x, logits, mask) / mask_ratio
            loss.backward()
            optim.step()

            running_loss += loss.item()
            n_steps += 1
            # Average over optimized batches only, so skipped batches do not
            # drag the reported mean towards zero.
            pbar.set_description(
                f"epoch {epoch+1}/{n_epochs}: loss={running_loss/n_steps:.5f}")

In [None]:
# Seed every RNG this cell touches *before* drawing from it, so that
# Restart & Run All reproduces the same mask ratio and the same initial
# weights. 23 matches the value train() passes to torch.manual_seed.
seed = 23
random.seed(seed)
torch.manual_seed(seed)

# Hyperparameters.
lr = 1e-3
n_epochs = 1
batch_size = 32
emb_dim = 32
ff_dim = 512
# NOTE(review): a single ratio is drawn here and reused for every batch of
# the run -- confirm this is intended rather than re-sampling per batch.
mask_ratio = random.uniform(0.01, 0.99)
print(f"mask_ratio={mask_ratio}")

# Look the special-token ids up directly. Encoding the literal string and
# taking the first id would silently return an unrelated token's id if the
# special token were not registered.
pad_token_id = tokenizer.token_to_id("[PAD]")
mask_token_id = tokenizer.token_to_id("[MASK]")
if pad_token_id is None or mask_token_id is None:
    raise ValueError("tokenizer is missing the [PAD] and/or [MASK] special token")
num_tokens = tokenizer.get_vocab_size()

model = DLM(
    num_tokens=num_tokens,
    emb_dim=emb_dim,
    ff_dim=ff_dim,
    pad_idx=pad_token_id,
    mask_idx=mask_token_id,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
print(
    f"Number of parameters={sum(p.numel() for p in model.parameters() if p.requires_grad)}")

train(
    model=model,
    optim=optimizer,
    lr=lr,
    n_epochs=n_epochs,
    batch_size=batch_size,
    emb_dim=emb_dim,
    ff_dim=ff_dim,
    mask_ratio=mask_ratio,
    pad_idx=pad_token_id,
    mask_idx=mask_token_id,
    num_tokens=num_tokens,
)