# Diffusion Language Models

In [None]:
import requests

from datasets import Dataset
from transformers import AutoTokenizer

from torch.utils.data import DataLoader

import lightning.pytorch as L
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping

import optuna
from optuna.integration import PyTorchLightningPruningCallback

from aptorch.dlm import DLM_Pretrained

# Fix the global random seed via Lightning so runs are repeatable;
# workers=True also derives seeds for DataLoader worker processes.
seed_everything(seed=42, workers=True)

In [None]:
# Download the corpus (the Divina Commedia as plain text) and build
# train / validation / test splits plus the tokenizer used by collate_fn.
url = "https://dmf.unicatt.it/~della/pythoncourse18/commedia.txt"
# A timeout keeps the cell from hanging forever on a stalled server, and
# raise_for_status() stops an HTTP error page from silently becoming the corpus.
response = requests.get(url, timeout=30)
response.raise_for_status()
# NOTE(review): requests may fall back to ISO-8859-1 when the server sends no
# charset; if accented letters come out garbled, set response.encoding
# (e.g. "utf-8") before reading .text.
raw_data = response.text
# splitlines() handles \r\n, \n and \r line endings alike (the previous
# replace/split only worked for CRLF files). Filter on the *stripped* line so
# whitespace-only lines are dropped instead of becoming empty samples.
sentences = [line.strip() for line in raw_data.splitlines() if line.strip()]

# Contiguous (unshuffled) split: the first 80% of lines is the design set,
# whose last 20% is held out for validation; the remaining 20% is the test set.
design_size = int(len(sentences) * 0.8)
valid_size = int(design_size * 0.2)
train_size = design_size - valid_size

train_set = Dataset.from_dict({'text': sentences[:train_size]})
valid_set = Dataset.from_dict({'text': sentences[train_size:design_size]})
test_set = Dataset.from_dict({'text': sentences[design_size:]})
print(f"total={len(sentences)}, train={len(train_set)}, valid={len(valid_set)}, test={len(test_set)}")

tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-italian-cased')
# Pad on the right so each row's real tokens come first.
tokenizer.padding_side = 'right'


def collate_fn(batch):
    """Turn a list of ``{"text": ...}`` records into one padded id tensor.

    Tokenizes with the module-level ``tokenizer`` (no special tokens added),
    pads the batch with ``padding=True`` and returns only the ``input_ids``
    tensor; the attention mask produced by the tokenizer is discarded.
    """
    sentences_in_batch = [record["text"] for record in batch]
    encoded = tokenizer(
        sentences_in_batch,
        return_tensors='pt',
        padding=True,
        add_special_tokens=False,
    )
    return encoded["input_ids"]

In [None]:
def _make_loader(dataset, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader that batches through ``collate_fn``.

    Training and validation both use batch size 16; only shuffling differs.
    """
    return DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=16,
        shuffle=shuffle,
    )


def objective(trial: optuna.trial.Trial) -> float:
    """Train one ``DLM_Pretrained`` configuration and return its ``val_loss``.

    Samples the learning rate, embedding / feed-forward sizes and masking
    ratio from ``trial``; the padding id, mask id and vocabulary size are
    taken from the module-level ``tokenizer`` and are not searched. Training
    uses early stopping and Optuna pruning, both monitoring ``val_loss``.

    Args:
        trial: the Optuna trial that supplies the hyper-parameters.

    Returns:
        The ``val_loss`` in ``trainer.callback_metrics`` after training
        (lower is better).

    Note:
        Pruning is signalled through the Optuna pruning callback and
        ``check_pruned()``; per Optuna's docs this surfaces as
        ``optuna.TrialPruned`` — confirm against the installed version.
    """
    params = {
        # Search the learning rate on a log scale: uniform sampling over
        # [1e-4, 1e-1] would place ~90% of draws above 1e-2.
        "lr": trial.suggest_float("lr", 1e-4, 1e-1, log=True),
        "emb_dim": trial.suggest_int("emb_dim", 16, 128),
        "ff_dim": trial.suggest_int("ff_dim", 16, 128),
        # suggest_float replaces the deprecated suggest_uniform (same distribution).
        "mask_ratio": trial.suggest_float("mask_ratio", 0.01, 0.99),
        # Fixed, tokenizer-derived model arguments.
        "pad_idx": tokenizer.pad_token_id,
        "mask_idx": tokenizer.mask_token_id,
        "num_tokens": tokenizer.vocab_size,
    }
    train_loader = _make_loader(train_set, shuffle=True)
    valid_loader = _make_loader(valid_set, shuffle=False)

    model = DLM_Pretrained(**params)

    early_stopping_callback = EarlyStopping('val_loss')
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")

    trainer = L.Trainer(
        max_epochs=100,
        devices="auto",
        deterministic=True,
        callbacks=[
            early_stopping_callback,
            pruning_callback,
        ],
        # fast_dev_run=True,  # uncomment for a quick smoke test
    )
    trainer.logger.log_hyperparams(params)
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
    )
    # Optuna's Lightning callback asks for this call right after fit().
    pruning_callback.check_pruned()

    # NOTE(review): this is the last epoch's val_loss; when early stopping
    # fires it may be slightly worse than the best epoch's value.
    return trainer.callback_metrics["val_loss"].item()


# objective() returns val_loss, so the study must *minimize* it. With
# "maximize", the sampler would search for the worst models and MedianPruner
# would prune the trials with the lowest (best) loss.
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(
    study_name="pl_ddp",
    direction="minimize",
    pruner=pruner,
    # No storage= is passed, so (per Optuna's create_study docs) the study is
    # in-memory and load_if_exists has nothing to resume across sessions.
    load_if_exists=True,
)
# Run up to 100 trials or until the 600 s timeout, whichever comes first.
study.optimize(objective, n_trials=100, timeout=600)
print(f"Number of finished trials: {len(study.trials)}")
print(f"Best trial: {study.best_trial}")

In [None]:
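# NOTE: `model` below is local to `objective` and is not defined at module
# scope; train a final model first (e.g. from `study.best_params` plus the
# tokenizer-derived pad/mask/vocab arguments) before enabling this cell.
# test_loader = DataLoader(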
#     test_set,
#     collate_fn=collate_fn,
#     batch_size=1,
#     shuffle=False,
# )
# for x in test_loader:
#     max_seq_len = x.shape[1]
#     sampling_steps = 10
#     x_partial = x[:, :-5]
#     print("----")
#     print("original", tokenizer.batch_decode(x.tolist()))
#     print("input", tokenizer.batch_decode(x_partial.tolist()))
#     output = model.sample(x_partial, max_seq_len, sampling_steps)
#     print("output", tokenizer.batch_decode(output.tolist()))