In [1]:
import torch
from transformers import AutoTokenizer
from transformers.models.llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP
from torch.utils.data import DataLoader
from datasets import load_dataset
from model import *

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset


2025-04-21 10:43:06.095018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745232186.316169    4834 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745232186.401370    4834 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1745232186.901423    4834 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745232186.901446    4834 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745232186.901448    4834 computation_placer.cc:177] computation placer alr

In [2]:
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

base_model = LlamaForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct", torch_dtype=dtype)
base_model.to(device)
base_model.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-0

In [3]:
speculator_head = PredictorHead(base_model.model.config)
speculator_head.load_state_dict(torch.load("../new_nn_decoder_head1.pth"))
speculator_head.to(device, dtype=dtype)
specModel = TwoHeadModel(base_model, speculator_head)
specModel.to(device, dtype=dtype)
specModel.eval()

TwoHeadModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
      (layers): ModuleList(
        (0-15): 16 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=512, bias=False)
            (v_proj): Linear(in_features=2048, out_features=512, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((2048,), e

In [4]:
loaded_prefix_embeddings = torch.load("prefix_embeddings_main.pt")  # this is a torch.Tensor


In [5]:
class PrefixTuningModel(nn.Module):
    def __init__(self, base_model, speculator_head, prefix_length):
        """
        Args:
            base_model: The pre-trained LlamaForCausalLM model.
            speculator_head: The additional head (predicting the second token).
            prefix_length: Number of trainable prefix tokens.
        """
        super(PrefixTuningModel, self).__init__()
        self.base_model = base_model
        self.prefix_length = prefix_length
        self.hidden_size = base_model.config.hidden_size
        self.prefix_embeddings = nn.Parameter(torch.randn(prefix_length, self.hidden_size, dtype=dtype))
        self.main_head = base_model.lm_head
        self.speculator_head = speculator_head

    def forward(self, input_ids, attention_mask=None):
        batch_size = input_ids.shape[0]
        input_embeds = self.base_model.model.embed_tokens(input_ids)  # shape: [batch, seq_len, hidden_size]
        prefix_embeds = self.prefix_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        concat_embeds = torch.cat([prefix_embeds, input_embeds], dim=1)
        if attention_mask is not None:
            prefix_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        outputs = self.base_model.model(inputs_embeds=concat_embeds, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # shape: [batch, (prefix+seq_len), hidden_size]
        logits_main = self.main_head(hidden_states)
    
        batch_size, seq_len, _ = hidden_states.shape
        logits_speculator = self.speculator_head(hidden_states)
        return logits_main, logits_speculator

In [6]:
from torch.optim.lr_scheduler import LinearLR

prefix_length = 8

p_tuning_model = PrefixTuningModel(base_model, speculator_head, prefix_length)

p_tuning_model.prefix_embeddings.data.copy_(loaded_prefix_embeddings.to(p_tuning_model.prefix_embeddings.device))


p_tuning_model.to(device)

for param in p_tuning_model.base_model.parameters():
    param.requires_grad = False
for param in p_tuning_model.speculator_head.parameters():
    param.requires_grad = False

# for param in p_tuning_model.parametrs():
#     paream.requires_grad = False
    
p_tuning_model.eval()

PrefixTuningModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
      (layers): ModuleList(
        (0-15): 16 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=512, bias=False)
            (v_proj): Linear(in_features=2048, out_features=512, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
            (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((2048

In [7]:
dataset = load_dataset("gsm8k", "main")

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
tokenizer.pad_token = tokenizer.eos_token


In [8]:
val_dataset = dataset["test"]

In [9]:
@torch.autocast(device_type="cuda")
def infer_prefix_tuning_model(model, tokenizer, input_string, max_length=300, eos_token_id=2):
    """
    Inference function for PrefixTuningModel that generates a sequence token-by-token until EOS token.

    Args:
        model: The trained PrefixTuningModel.
        tokenizer: The tokenizer for LLaMA model.
        input_string: The input string (prompt) to generate the response for.
        max_length: Maximum length of the generated sequence (including prefix).
        eos_token_id: The ID of the EOS token, used to stop generation.
    
    Returns:
        output_string: The model's generated output as a string.
    """
    # Tokenize the input string
    inputs = tokenizer(input_string, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids = inputs["input_ids"]

    # Create an attention mask (same size as input_ids)
    attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(device)

    # Set the model to evaluation mode
    model.eval()

    # Initialize the input sequence with the input_ids
    generated_ids = input_ids

    # Loop for generating one token at a time
    for _ in range(max_length):
        with torch.no_grad():
            # Get the logits from the model (we only want the last token logits)
            logits_main, _ = model(generated_ids, attention_mask)

            # Get the logits for the last token (the current token we want to predict)
            logits = logits_main[:, -1, :]  # Shape: [batch_size, vocab_size]

            # Get the predicted next token using argmax (greedy decoding)
            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)

            # Append the predicted token to the sequence
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

            # If the predicted token is the EOS token, stop the generation
            if next_token_id.item() == eos_token_id:
                break

    # Decode the generated token IDs into a string
    output_string = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return output_string


In [10]:
import re

def extract_final_answer(text):
    match = re.search(r"####\s*(\d+)", text)
    if match:
        return int(match.group(1))
    return None  # or raise an error / return a default

In [11]:
def is_such_number_in_string(number, string) -> bool:
    pattern = r'\b{}\b'.format(re.escape(str(number)))
    return re.search(pattern, string) is not None

In [12]:
print(len(val_dataset))

1319


In [13]:
def accuracy_compute():
    total = 0
    correct = 0
    
    for i in range(len(val_dataset)):
        item = val_dataset[i]
        correct_ans = extract_final_answer(item['answer'])
        model_answer = infer_prefix_tuning_model(p_tuning_model, tokenizer, item['question'], max_length=300, eos_token_id=2)
        if (is_such_number_in_string(correct_ans, model_answer)):
            correct += 1
        total += 1
        
        print(f'correct: {correct} total: {total}')

        
        if (i > 300):
            break

    
    return (correct / total)
            


In [14]:
print(accuracy_compute())

correct: 0 total: 1
correct: 0 total: 2
correct: 0 total: 3
correct: 0 total: 4
correct: 1 total: 5
correct: 1 total: 6
correct: 1 total: 7
correct: 1 total: 8
correct: 1 total: 9
correct: 1 total: 10
correct: 1 total: 11
correct: 1 total: 12
correct: 1 total: 13
correct: 1 total: 14
correct: 1 total: 15
correct: 1 total: 16
correct: 1 total: 17
correct: 1 total: 18
correct: 1 total: 19
correct: 2 total: 20
correct: 3 total: 21
correct: 3 total: 22
correct: 3 total: 23
correct: 4 total: 24
correct: 4 total: 25
correct: 5 total: 26
correct: 5 total: 27
correct: 5 total: 28
correct: 5 total: 29
correct: 5 total: 30
correct: 5 total: 31
correct: 6 total: 32
correct: 7 total: 33
correct: 7 total: 34
correct: 7 total: 35
correct: 8 total: 36
correct: 8 total: 37
correct: 9 total: 38
correct: 9 total: 39
correct: 9 total: 40
correct: 9 total: 41
correct: 9 total: 42
correct: 9 total: 43
correct: 9 total: 44
correct: 10 total: 45
correct: 10 total: 46
correct: 10 total: 47
correct: 10 total: 

correct: 27 total: 152
0.17763157894736842

correct: 50 total: 300

correct: 50 total: 301

correct: 50 total: 302

0.16556291390728478

In [15]:
print(correct / total)

NameError: name 'correct' is not defined

In [None]:
print(infer_prefix_tuning_model(p_tuning_model, tokenizer, "What is the result of 2 + 2? Put an answer as last token of your output - after ###"))

In [None]:
print(infer_prefix_tuning_model(p_tuning_model, tokenizer, "What is the capital of France? Put an answer as last token of your output - after ###"))