# Diffusion Language Models

In [None]:
import random
import requests

import torch
from torch.utils.data import DataLoader

from datasets import Dataset
from transformers import AutoTokenizer

from aptorch.dlm import DLM, pretraining

In [None]:
# Download the raw corpus text over HTTP.
url = "https://dmf.unicatt.it/~della/pythoncourse18/commedia.txt"
# The timeout stops the cell from hanging forever on an unresponsive server;
# raise_for_status() prevents an HTTP error page from being used as the corpus.
response = requests.get(url, timeout=30)
response.raise_for_status()
raw_data = response.text

# One entry per non-blank line. splitlines() handles "\r\n", "\r" and "\n"
# alike, and filtering *after* stripping drops whitespace-only lines instead
# of leaving empty strings in the dataset.
sentences = [s for s in map(str.strip, raw_data.splitlines()) if s]

# First 80% of the lines (in file order, no shuffling) go to the training
# split, the remainder to the test split.
train_size = int(len(sentences) * 0.8)

train_set = Dataset.from_dict({'text': sentences[:train_size]})
test_set = Dataset.from_dict({'text': sentences[train_size:]})

# Pretrained Italian BERT tokenizer; padding is appended on the right.
tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-italian-cased')
tokenizer.padding_side = 'right'

def collate_fn(batch):
    """Tokenize a batch of examples and return only their token ids.

    Args:
        batch: Iterable of examples, each indexable with the key ``"text"``.

    Returns:
        The ``input_ids`` entry of the tokenizer output, requested as a
        PyTorch tensor (``return_tensors='pt'``). Sequences are padded
        (``padding=True``) and no special tokens are inserted.
        Presumably shaped (batch, longest sequence) -- TODO confirm.
    """
    encoded = tokenizer(
        [example["text"] for example in batch],
        add_special_tokens=False,
        padding=True,
        return_tensors='pt',
    )
    # Only the ids are returned; the rest of the tokenizer output is dropped.
    return encoded["input_ids"]


In [None]:
# --- Optimisation settings ---
n_epochs = 10
batch_size = 16
lr = 1e-3

# --- Model sizes ---
emb_dim = 32
ff_dim = 64

# --- Values taken from the tokenizer ---
num_tokens = tokenizer.vocab_size
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id

# Drawn once per run (unseeded), so every execution of this cell passes a
# different, fixed ratio to the pretraining call below.
mask_ratio = random.uniform(0.01, 0.99)

# Report the run configuration, one item per line.
print(
    ">> Pretraining step:",
    f"mask_ratio: {mask_ratio}",
    f"pad_token_id: {pad_token_id}",
    f"mask_token_id: {mask_token_id}",
    f"num_tokens: {num_tokens}",
    sep="\n",
)

# Run pretraining with the settings above; the returned object is
# annotated as a DLM and reused by later cells.
model: DLM = pretraining(
    training_set=train_set,
    collate_fn=collate_fn,
    n_epochs=n_epochs,
    batch_size=batch_size,
    lr=lr,
    emb_dim=emb_dim,
    ff_dim=ff_dim,
    num_tokens=num_tokens,
    pad_idx=pad_token_id,
    mask_idx=mask_token_id,
    mask_ratio=mask_ratio,
)


In [None]:
# Take one training sentence and encode it without special tokens.
x = train_set[4]["text"]
print(x)
x_enc = tokenizer(x, return_tensors='pt', add_special_tokens=False)["input_ids"]
print(x_enc)

# Prompt = the encoded sentence with its last three token ids removed.
x_part = x_enc[:, :-3]

# Show the full sentence and the truncated prompt as text.
print(tokenizer.batch_decode(x_enc))
print(tokenizer.batch_decode(x_part))

# Lengths along dimension 1 of the full and truncated id tensors.
prompt_len = x_part.size(1)
max_seq_len = x_enc.size(1)
print("max_seq_len", max_seq_len)
print("prompt_len", prompt_len)

# Switch to evaluation mode before sampling.
model.eval()
# NOTE(review): the third positional argument (5) is presumably a number of
# sampling/denoising steps -- confirm against DLM.sample's signature.
sampled = model.sample(x_part, max_seq_len, 5)
# Trailing expression: the notebook displays the decoded sample.
tokenizer.batch_decode(sampled.tolist())