In [None]:
import os

import torch
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig
from peft import get_peft_model
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import get_scheduler

In [None]:
# Read the Hugging Face access token from the HF_TOKEN environment variable so that
# the secret is never stored in (or committed with) the notebook.
# When HF_TOKEN is unset this is None and is handed to login() as-is.
token = os.environ.get("HF_TOKEN")
model_name = "meta-llama/Llama-3.2-1B-Instruct"

In [None]:
# Authenticate with the Hugging Face Hub using the token from the configuration cell
login(token=token)

In [None]:
def make_attention_mask(labels:torch.Tensor, prompt_n_padding:int):
  """
  Build the boolean attention matrix for a single example.

  Starting from a lower-triangular (causal) matrix, every row whose index is at least the
  number of -100 entries in `labels` is opened up to all columns, and the first
  `prompt_n_padding` rows and columns are switched off.

  Args:
    labels: label tensor of one example; its -100 entries are counted to find the first fully-open row.
    prompt_n_padding: number of leading positions to blank out (rows and columns).

  Returns:
    Boolean tensor of shape (1, len(labels), len(labels)).
  """
  seq_len = labels.shape[0]
  n_ignored = int((labels == -100).sum())

  mask = torch.ones((seq_len, seq_len), dtype=torch.bool).tril()  # type: ignore
  mask[:prompt_n_padding] = False      # leading rows see nothing
  mask[n_ignored:] = True              # rows past the ignored labels see every position
  mask[:, :prompt_n_padding] = False   # no row sees the leading columns
  return mask[None]

def preprocess_dataset(record, tokenizer, max_prompt_length:int=100, max_answer_length:int=50):
  """
  Tokenize one instruction-answer pair and return the tensors used for training.

  The prompt (instruction plus optional input, wrapped in the chat template) is tokenized with
  left padding / truncation to `max_prompt_length`; the answer is tokenized with right
  padding / truncation to `max_answer_length`. Both are concatenated into one sequence.

  Args:
    record: mapping with the string fields 'instruction', 'input' and 'output'.
    tokenizer: tokenizer exposing `apply_chat_template` and a padding token.
    max_prompt_length: number of token positions reserved for the prompt.
    max_answer_length: number of token positions reserved for the answer.

  Returns:
    dict with
      'input_ids': 1-D tensor holding prompt tokens followed by answer tokens,
      'attention_mask': boolean tensor built by `make_attention_mask`,
      'labels': copy of 'input_ids' with every prompt position set to -100
                (answer padding positions keep their token id).
  """
  instruction = record['instruction']
  extra_input = record['input']
  # Keep a separator between the instruction and its optional input so the two texts do not run together.
  question = f"{instruction}\n{extra_input}" if extra_input else instruction
  answer = record['output']

  conversation = [{"role": "user", "content": question}]
  prompt_chat = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
  prompt_tokens = tokenizer(prompt_chat, padding='max_length', return_tensors='pt', padding_side="left", add_special_tokens=False, max_length=max_prompt_length, truncation=True)
  # With left padding, the amount of padding is the number of zeros in the tokenizer's own mask.
  # This avoids searching for a hard-coded begin-of-text id, which fails if that id is absent.
  prompt_n_padding = int((prompt_tokens['attention_mask'][0] == 0).sum())

  prompt_ids = prompt_tokens['input_ids']
  starting_answer_idx = prompt_ids.shape[1]
  response_tokens = tokenizer(answer, padding='max_length', return_tensors='pt', padding_side="right", add_special_tokens=False, max_length=max_answer_length, truncation=True)
  input_ids = torch.cat((prompt_ids, response_tokens['input_ids']), dim=-1).squeeze(0)

  labels:torch.Tensor = input_ids.clone()
  labels[:starting_answer_idx] = -100 # prompt tokens are excluded from the loss

  return {
    'input_ids': input_ids,
    'attention_mask': make_attention_mask(labels, prompt_n_padding),
    "labels": labels,
  }

In [None]:
# Adding MASK token into the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
new_tokens = ["<MASK>"]
tokenizer.add_tokens(new_tokens)
# Last expression of the cell: displays how "<MASK>" is encoded.
# NOTE(review): later cells hard-code 128256 as the mask token id — confirm it matches the id shown here.
tokenizer("<MASK>")

In [None]:
# Load the Alpaca dataset, tokenize every record with preprocess_dataset and wrap the train split in a DataLoader
dataset = load_dataset("tatsu-lab/alpaca")
dataset = dataset.map(lambda x: preprocess_dataset(x, tokenizer)) # tokenizing the instruction-answer pairs
dataset.set_format(type="torch", columns=dataset['train'].column_names) # used to get tensors instead of lists
train_loader = DataLoader(dataset['train'], batch_size=8, shuffle=True) # type: ignore

In [None]:
# Load the base model in half precision and resize its embedding table to the tokenizer's current size
dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(
  model_name,
  torch_dtype=dtype,
  device_map="auto",
)
model.resize_token_embeddings(len(tokenizer))

# LoRA adapters on the four attention projections
config = LoraConfig(
  target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
  r=8,
  lora_alpha=32,
  lora_dropout=0.05,
)

# Wrap the model with the adapters and report how many parameters are trainable
lora_model = get_peft_model(model, config)
lora_model.print_trainable_parameters()

In [None]:
def forward_diffusion_step(r0, t, MASK_TOKEN_ID):
  """
  Replace a random subset of the answer tokens with the mask token.

  A token in row b is masked when a uniform sample drawn for it falls below t[b].

  Args:
    r0: (batch_size, seq_len) token ids of the answer.
    t: (batch_size,) masking level of each row.
    MASK_TOKEN_ID: id written at the masked positions.

  Returns:
    (rt, mask): the corrupted ids and a boolean tensor that is True where a token was masked.
  """
  threshold = t.view(-1, 1)  # one threshold per row, broadcast along the sequence
  noise = torch.rand_like(r0.float())
  is_masked = noise < threshold
  replaced = is_masked.long()
  kept = 1 - replaced
  rt = r0 * kept + MASK_TOKEN_ID * replaced
  return rt, is_masked


In [None]:
epochs = 30
ignore_index = -100 # used to ignore some labels in the evaluation of the loss

# Hand the optimizer only the parameters that are actually trainable, instead of every
# parameter of the model (most of which are frozen once the LoRA adapters are attached).
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
# Use the constant declared above so the ignored label value is defined in one place
criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

# setting the learning rate scheduler: linear schedule with warmup over the whole training run
num_training_steps = epochs * len(train_loader)
num_warmup_steps = 500
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
def get_answer(prompt:str, model, tokenizer, answer_length:int=20, mask_token_id:int=128256):
  """
  Generate an answer by iteratively unmasking a block of mask tokens appended to the prompt.

  The prompt is wrapped in the chat template and followed by `answer_length` mask tokens.
  At every step the model is run on the whole sequence and the still-masked position with the
  most confident prediction is replaced by its predicted token, until no masked position is left.

  Args:
    prompt: user message.
    model: model called as `model(input_ids=..., attention_mask=...)`, returning an object with `.logits`.
    tokenizer: tokenizer providing `apply_chat_template`, `__call__` and `decode`.
    answer_length: number of answer tokens to generate.
    mask_token_id: id of the mask token used to initialise the answer.

  Returns:
    The decoded answer (special tokens skipped); an empty answer when `answer_length` is 0.
  """
  model.eval()
  with torch.no_grad():
    conversation = [{"role": "user", "content": prompt}]
    chat = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    tokens = tokenizer(chat, return_tensors='pt', add_special_tokens=False)
    prompt_length = tokens['input_ids'].shape[1]
    total_length = prompt_length + answer_length

    diffusion_mask = torch.full((1, answer_length), mask_token_id, dtype=torch.long)
    input_ids = torch.cat((tokens['input_ids'], diffusion_mask), dim=-1).to(model.device)

    # Causal over the prompt rows, fully open on the answer rows.
    # The answer rows are opened on the 2-D matrix, before the batch/head dimensions are added:
    # slicing the 4-D tensor with [prompt_length:] would address the size-1 batch dimension and change nothing.
    attention_mask = torch.tril(torch.ones(total_length, total_length, dtype=torch.bool))
    attention_mask[prompt_length:, :] = True
    attention_mask = attention_mask.view(1, 1, total_length, total_length).to(model.device)
    # NOTE(review): how a 4-D float mask is interpreted depends on the transformers version — confirm.
    attention_mask = attention_mask.to(torch.float16)

    answer_ids = input_ids[0, prompt_length:]  # view: writes go straight into input_ids
    remaining = torch.arange(answer_length, device=input_ids.device)  # answer positions still masked
    top_k = 1 # generates one token at a time
    while remaining.numel() > 0:
      output = model(input_ids=input_ids, attention_mask=attention_mask)
      prediction_probs, predictions = torch.max(torch.softmax(output.logits[0, prompt_length:], dim=1), dim=1)
      n_confirm = min(top_k, remaining.numel())
      topk_indices = torch.topk(prediction_probs[remaining], k=n_confirm).indices
      confirmed_positions = remaining[topk_indices]
      answer_ids[confirmed_positions] = predictions[confirmed_positions]
      # Drop the positions that were just confirmed from the pool of masked positions
      still_masked = torch.ones_like(remaining, dtype=torch.bool)
      still_masked[topk_indices] = False
      remaining = remaining[still_masked]

    # Decode the tokens confirmed during the loop, not the argmax of the last forward pass
    output = tokenizer.decode(input_ids[0, prompt_length:], skip_special_tokens=True)
  return output

In [None]:
# Sample an answer from lora_model before the training loop runs, as a baseline to compare
# against the per-epoch samples printed during training
answer = get_answer("Hello, how are you? Tell me a story", lora_model, tokenizer)
print(f"Answer: {answer}")

In [None]:
losses = []
test_losses = []

PROMPT_LENGTH = 100  # index of the first answer token; equals the prompt length used in preprocess_dataset
MASK_TOKEN_ID = tokenizer.convert_tokens_to_ids("<MASK>")  # id of the token added to the tokenizer

for epoch in range(epochs):
  total_loss = 0
  num_steps = 0
  lora_model.train()
  for i, batch in enumerate(train_loader):
    input_ids = batch['input_ids'].to(lora_model.device)
    attention_mask = batch['attention_mask'].to(lora_model.device)
    labels = batch['labels'].to(lora_model.device)

    # One masking level per example, then corrupt the answer part of the sequence
    t = torch.rand(len(batch['input_ids'])).to(lora_model.device)
    noisy_answer, mask = forward_diffusion_step(input_ids[:, PROMPT_LENGTH:], t, MASK_TOKEN_ID)
    if not mask.any():
      # No token was masked, so every label would be ignored and the mean loss would be NaN.
      # Back-propagating it would corrupt the weights; skip the batch instead.
      continue
    input_ids[:, PROMPT_LENGTH:] = noisy_answer
    labels[:, PROMPT_LENGTH:][~mask] = ignore_index  # the loss is computed on masked answer tokens only

    # NOTE(review): how a 4-D float mask is interpreted depends on the transformers version — confirm.
    output = lora_model(input_ids=input_ids, attention_mask=attention_mask.to(torch.float16))
    # Logits at position i are compared with the label at position i (no shift)
    loss = criterion(output.logits.view(-1, output.logits.size(-1)), labels.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    total_loss += loss.item()
    num_steps += 1
    if i % 100 == 0:
      print(f"EPOCH {epoch + 1} - [Batch {i + 1}/{len(train_loader)}] Loss: {loss.item()}")

  # Average over the steps actually taken (also well-defined for an empty loader)
  losses.append(total_loss / max(num_steps, 1))

  # checkpoints
  if epoch % 5 == 0:
    print("Saving embedding")
    torch.save(
      lora_model.base_model.model.model.embed_tokens.state_dict(),
      "attention_only.pt"
    )
    print("Saving lora weights")
    lora_model.save_pretrained("lora_out_attention_only")

  print(f"[Epoch {epoch + 1}/{epochs}] Loss: {losses[-1]}")
  answer = get_answer("Hello, how are you? Tell me a story", lora_model, tokenizer, mask_token_id=MASK_TOKEN_ID)
  print(f"[Epoch {epoch + 1}/{epochs}] Answer: {answer}")

In [None]:
# Final save once training has finished: the input embedding table and the LoRA adapter weights,
# written to the same paths as the periodic checkpoints of the training loop
print("Saving embedding")
torch.save(
    lora_model.base_model.model.model.embed_tokens.state_dict(),
    "attention_only.pt"
)
print("Saving lora weights")
lora_model.save_pretrained("lora_out_attention_only")