# Diffusion Language Models

In [None]:
import random

import torch
import torch.nn.functional as F

from aptorch.data import (
    DivinaCommediaDataset,
    divina_commedia,
    divina_commedia_tokenizer,
)
from aptorch.dlm import DLM, pretraining

In [None]:
# Load the two raw splits, obtain the tokenizer from the training split, then
# wrap each raw split in a DivinaCommediaDataset (train first, then test).
train_dataset, test_dataset = divina_commedia()
tokenizer = divina_commedia_tokenizer(train_dataset)
train_set, test_set = (
    DivinaCommediaDataset(dataset=raw_split)
    for raw_split in (train_dataset, test_dataset)
)


def collate_fn(batch):
    """Turn a list of (prompt, response) samples into two token-id tensors.

    Each sample's first element is treated as the prompt text and its second
    element as the response text. Both lists are encoded with the
    module-level ``tokenizer`` and the resulting id lists are stacked with
    ``torch.tensor``.

    Returns:
        A ``(prompt_ids, response_ids)`` tuple of tensors.
    """

    def _encode_to_tensor(texts):
        # One encode_batch call per list; ``.ids`` holds each token-id list.
        encodings = tokenizer.encode_batch(texts)
        return torch.tensor([encoding.ids for encoding in encodings])

    prompt_texts = [sample[0] for sample in batch]
    response_texts = [sample[1] for sample in batch]
    return _encode_to_tensor(prompt_texts), _encode_to_tensor(response_texts)


In [None]:
# Hyperparameters for the pretraining run.
lr = 1e-2
n_epochs = 20
batch_size = 8
emb_dim = 16
ff_dim = 32
# The masking ratio is drawn once here (unseeded) and passed to `pretraining`;
# it is printed below so a given run can be identified afterwards.
mask_ratio = random.uniform(0.01, 0.99)

# Look the special tokens up in the vocabulary directly. Encoding the literal
# string and taking ids[0] silently returns the wrong id when the tokenizer
# prepends a token or when the string is not a single vocabulary entry.
# NOTE(review): assumes `tokenizer` is a Hugging Face `tokenizers.Tokenizer`
# (it is used with encode_batch / get_vocab_size elsewhere) -- confirm.
pad_token_id = tokenizer.token_to_id("[PAD]")
mask_token_id = tokenizer.token_to_id("[MASK]")
if pad_token_id is None or mask_token_id is None:
    raise ValueError(
        "tokenizer vocabulary must contain the special tokens "
        "'[PAD]' and '[MASK]'"
    )
num_tokens = tokenizer.get_vocab_size()

print(">> Pretraining step:")
print(f"mask_ratio: {mask_ratio}")
print(f"pad_token_id: {pad_token_id}")
print(f"mask_token_id: {mask_token_id}")
print(f"num_tokens: {num_tokens}")

model: DLM = pretraining(
    training_set=train_set,
    collate_fn=collate_fn,
    lr=lr,
    n_epochs=n_epochs,
    batch_size=batch_size,
    emb_dim=emb_dim,
    ff_dim=ff_dim,
    mask_ratio=mask_ratio,
    pad_idx=pad_token_id,
    mask_idx=mask_token_id,
    num_tokens=num_tokens,
)

In [None]:
from torch.utils.data import DataLoader

# Sampling demo: hide the last few tokens of the first training sequence and
# let the model reconstruct them by iterative denoising with low-confidence
# remasking. The code assumes a batch of exactly one sequence (batch_size=1).
model.eval()
num_sampling_steps = 10
# Number of trailing tokens that are masked and then regenerated.
num_response_tokens = 3

train_loader = DataLoader(
    train_set, collate_fn=collate_fn, batch_size=1, shuffle=False)
x, y = next(iter(train_loader))

# Derive every length from one constant; clamp so that a sequence shorter
# than `num_response_tokens` is masked entirely instead of producing a
# negative prompt length.
max_seq_len = x.shape[1]
prompt_len = max(0, max_seq_len - num_response_tokens)
initial_response_len = max_seq_len - prompt_len
response_indices_slice = slice(prompt_len, max_seq_len)

x_masked = x.clone()
x_masked[:, response_indices_slice] = mask_token_id

print(tokenizer.decode_batch(x.tolist()))
print(x.tolist())
print(x_masked.tolist())

current_sequence = x_masked
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for step_idx in range(num_sampling_steps):
        # Fraction of the response that must still be masked after this step;
        # it decreases linearly from 1 to 0 over the sampling steps.
        next_t_val = 1.0 - ((step_idx + 1) / num_sampling_steps)

        # Indexed below as (batch, position, vocabulary).
        logits = model(current_sequence)
        predicted_tokens_all = torch.argmax(logits, dim=-1)

        # Fill only the still-masked response positions with the prediction;
        # tokens committed in earlier steps are kept as they are.
        masked_in_response = (
            current_sequence[:, response_indices_slice] == mask_token_id)
        r0_candidate = current_sequence.clone()
        r0_candidate[:, response_indices_slice] = torch.where(
            masked_in_response,
            predicted_tokens_all[:, response_indices_slice],
            current_sequence[:, response_indices_slice],
        )

        num_tokens_to_be_masked_in_next_step = max(
            0, int(initial_response_len * next_t_val))

        # Low-confidence remasking: score each response position by the
        # probability the model assigns to its own argmax prediction.
        response_logits = logits[:, response_indices_slice, :].squeeze(0)
        response_probs = F.softmax(response_logits, dim=-1)
        predicted_tokens_response = (
            predicted_tokens_all[:, response_indices_slice].squeeze(0))
        predicted_confidence = response_probs.gather(
            1, predicted_tokens_response.unsqueeze(-1)).squeeze(-1)
        # Positions committed in earlier steps keep their token, so the score
        # of a fresh prediction there is meaningless. Give them infinite
        # confidence so they are never chosen for remasking.
        predicted_confidence = predicted_confidence.masked_fill(
            ~masked_in_response.squeeze(0), float("inf"))

        # Remask the least confident positions (ascending sort).
        sorted_indices_in_response = torch.sort(
            predicted_confidence, descending=False).indices
        relative_indices_to_remask = sorted_indices_in_response[
            :num_tokens_to_be_masked_in_next_step]
        full_indices_to_remask = (
            response_indices_slice.start + relative_indices_to_remask)

        next_sequence_step = r0_candidate.clone()
        next_sequence_step[:, full_indices_to_remask] = mask_token_id
        current_sequence = next_sequence_step

        # Stop early once every response position holds a real token.
        if (current_sequence[:, response_indices_slice] != mask_token_id).all():
            break


tokenizer.decode_batch(current_sequence.tolist())