In [1]:
import torch
from transformers import AutoTokenizer
from transformers.models.llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP
from torch.utils.data import DataLoader
from datasets import load_dataset
from model import *
import random
import numpy as np

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x771077e98e90>

In [8]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
base_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype=dtype, device_map=device)
speculator_head = PredictorHead(base_model.model.config)
speculator_head.load_state_dict(torch.load("speculator_head.pth"))
speculator_head.to(device, dtype=dtype)
specModel = TwoHeadModel(base_model, speculator_head)
specModel.to(device, dtype=dtype)
specModel.eval()

  speculator_head.load_state_dict(torch.load("speculator_head.pth"))


TwoHeadModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 2048)
      (layers): ModuleList(
        (0-15): 16 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=512, bias=False)
            (v_proj): Linear(in_features=2048, out_features=512, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
          (post_attention_la

In [3]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

class PrefixTuningModel(nn.Module):
    def __init__(self, base_model, speculator_head, prefix_length):
        super(PrefixTuningModel, self).__init__()
        self.base_model = base_model
        self.prefix_length = prefix_length
        self.hidden_size = base_model.config.hidden_size
        self.prefix_embeddings = nn.Parameter(torch.randn(prefix_length, self.hidden_size, dtype=dtype))
        self.main_head = base_model.lm_head
        self.speculator_head = speculator_head

    def forward(self, input_ids, attention_mask=None):
        batch_size = input_ids.shape[0]
        input_embeds = self.base_model.model.embed_tokens(input_ids)  # shape: [batch, seq_len, hidden_size]
        prefix_embeds = self.prefix_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        concat_embeds = torch.cat([prefix_embeds, input_embeds], dim=1)
        if attention_mask is not None:
            prefix_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        outputs = self.base_model.model(inputs_embeds=concat_embeds, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # shape: [batch, (prefix+seq_len), hidden_size]
        logits_main = self.main_head(hidden_states)
        position_ids = torch.arange(
            0, hidden_states.shape[1], device=input_ids.device
        ).unsqueeze(0)
        logits_speculator = self.speculator_head(hidden_states, position_ids)
        return logits_main, logits_speculator

prefix_length = 16

p_tuning_model = PrefixTuningModel(base_model, speculator_head, prefix_length)
p_tuning_model.to(device)

for param in p_tuning_model.base_model.parameters():
    param.requires_grad = False
for param in p_tuning_model.speculator_head.parameters():
    param.requires_grad = False

optimizer = optim.Adam(p_tuning_model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

dataset = load_dataset("gsm8k", "main")

In [4]:
MAX_LENGTH = 128
BATCH_SIZE = 8

class MathDataset(Dataset):
    def __init__(self, split="train"):
        self.data = dataset[split]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        question = item["question"]
        answer = item["answer"]

        encoding = tokenizer(question, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
        label_encoding = tokenizer(answer, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        labels = label_encoding["input_ids"].squeeze(0)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

train_dataset = MathDataset("train")
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Trainable parameters:", count_parameters(p_tuning_model))

Trainable parameters: 32768


In [5]:
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

NUM_EPOCHS = 10
types = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
loss_history = []  # Store average loss per epoch

# plt.figure(figsize=(8, 5))  # Set figure size

for epoch in range(NUM_EPOCHS):
    for batch_idx, batch in enumerate(train_dataloader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        
        attention_mask = attention_mask.to(dtype=dtype)
        logits_main, logits_speculator = p_tuning_model(input_ids, attention_mask)

        seq_length = input_ids.shape[1]  # Original input sequence length
        logits_main = logits_main[:, -seq_length:, :]  # Keep only original sequence logits
        logits_speculator = logits_speculator[:, -seq_length:-1, :]

        loss_main = criterion(logits_main.contiguous().view(-1, logits_main.size(-1)), labels.view(-1))
        loss_head = criterion(logits_speculator.contiguous().view(-1, logits_speculator.size(-1)), labels[:, 1:].reshape(-1))
        if types[epoch] == 0:
            loss = loss_main
        else:
            loss = loss_head
        loss.backward()
        optimizer.step()

        
        # clear_output(wait=True)  # Clear previous output for a smooth update
        # plt.clf()  # Clear the current plot
        # plt.plot(loss_history, marker='o', linestyle='-')
        # plt.xlabel("Epoch")
        # plt.ylabel("Average Loss")
        # plt.title(f"Training Loss (Epoch {epoch + 1}/{NUM_EPOCHS}, Batch {batch_idx + 1})")
        # plt.grid()
        # plt.show()

        loss_history.append(loss_main.item())
        print(f"Epoch {epoch + 1}, iteration {batch_idx + 1} - Loss: {loss.item():.4f}")

The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be removed and `position_embeddings` will be mandatory.


Epoch 1, iteration 1 - Loss: 9.8750
Epoch 1, iteration 2 - Loss: 9.6250
Epoch 1, iteration 3 - Loss: 10.0625
Epoch 1, iteration 4 - Loss: 10.2500
Epoch 1, iteration 5 - Loss: 10.0625
Epoch 1, iteration 6 - Loss: 9.8750
Epoch 1, iteration 7 - Loss: 9.4375
Epoch 1, iteration 8 - Loss: 9.8125
Epoch 1, iteration 9 - Loss: 9.5000
Epoch 1, iteration 10 - Loss: 9.5000
Epoch 1, iteration 11 - Loss: 9.7500
Epoch 1, iteration 12 - Loss: 9.7500
Epoch 1, iteration 13 - Loss: 9.6250
Epoch 1, iteration 14 - Loss: 9.7500
Epoch 1, iteration 15 - Loss: 9.6875
Epoch 1, iteration 16 - Loss: 9.3750
Epoch 1, iteration 17 - Loss: 9.1250
Epoch 1, iteration 18 - Loss: 9.0000
Epoch 1, iteration 19 - Loss: 9.5000
Epoch 1, iteration 20 - Loss: 8.9375
Epoch 1, iteration 21 - Loss: 9.1250
Epoch 1, iteration 22 - Loss: 8.9375
Epoch 1, iteration 23 - Loss: 9.3750
Epoch 1, iteration 24 - Loss: 8.8750
Epoch 1, iteration 25 - Loss: 8.7500
Epoch 1, iteration 26 - Loss: 8.5000
Epoch 1, iteration 27 - Loss: 8.8125
Epoch 1

KeyboardInterrupt: 

In [6]:
non_speculative_loss = loss_history
average_loss_on_last_100 = sum(non_speculative_loss[-100:]) / 100
print(average_loss_on_last_100)

5.90875


In [7]:
print(loss_history)

[9.875, 9.625, 10.0625, 10.25, 10.0625, 9.875, 9.4375, 9.8125, 9.5, 9.5, 9.75, 9.75, 9.625, 9.75, 9.6875, 9.375, 9.125, 9.0, 9.5, 8.9375, 9.125, 8.9375, 9.375, 8.875, 8.75, 8.5, 8.8125, 8.0, 9.8125, 10.0, 8.75, 9.5, 8.875, 8.0625, 8.6875, 9.0, 7.8125, 9.8125, 8.625, 9.375, 9.0625, 8.6875, 8.9375, 8.8125, 8.5625, 8.1875, 8.375, 7.84375, 9.5, 8.4375, 8.75, 9.0625, 8.5625, 9.0625, 8.5, 7.90625, 8.8125, 8.4375, 7.875, 8.3125, 8.0, 7.625, 7.78125, 8.125, 8.4375, 6.8125, 7.375, 7.71875, 8.375, 8.0, 8.1875, 8.5625, 7.65625, 8.5, 8.4375, 8.8125, 7.09375, 8.4375, 8.0, 7.28125, 7.40625, 8.1875, 8.9375, 7.6875, 8.625, 8.1875, 7.5, 7.6875, 7.59375, 7.5, 7.4375, 8.5625, 8.0625, 7.90625, 8.75, 8.6875, 7.875, 7.15625, 6.34375, 8.3125, 7.6875, 9.25, 9.125, 7.6875, 8.5, 7.28125, 8.0625, 7.90625, 8.625, 7.0, 8.3125, 7.8125, 6.65625, 7.0625, 7.75, 6.40625, 5.90625, 7.03125, 7.53125, 8.25, 8.0625, 7.53125, 6.0625, 6.3125, 6.84375, 7.25, 6.40625, 7.71875, 5.90625, 7.375, 8.0, 7.65625, 7.03125, 7.40625, 6.8