In [1]:
import random
import numpy as np
import torch
from transformers import set_seed

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

In [2]:
import os
import torch
import transformers
import numpy as np
import matplotlib.pyplot as plt

HF_TOKEN = os.getenv("HF_TOKEN")

# model_name = "meta-llama/Llama-3.2-1B"
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B"
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = "google/gemma-2-2b"
# model_name = "google/gemma-2-2b-it"
# model_name = "google/gemma-2-9b"
# model_name = "google/gemma-2-9b-it"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    output_hidden_states=True,  # Enable hidden states
    token=HF_TOKEN,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    token=HF_TOKEN,
)

print(model, model.config)




VBox(children=(Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s],))

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm

In [3]:
from src.util.json_io import *

train_qnas = load_jsonlines(f'data/gsm8k/train.jsonl')
test_qnas = load_jsonlines(f'data/gsm8k/test.jsonl')
len(train_qnas), len(test_qnas)

(7473, 1319)

In [4]:
import random; rseed = 42; random.seed(rseed)

nshot_prompt = f""
for top_logit_indices in random.sample(range(len(train_qnas)), 8):
    nshot_prompt += f"Question: {train_qnas[top_logit_indices]['question']}\nAnswer: {train_qnas[top_logit_indices]['answer']}\n\n"

print(nshot_prompt)

Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12

Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?
Answer: Matthew picked 16 + 20 = <<16+20=36>>36 strawberries.
Natalie picked 3

In [5]:
def question_to_prompt(question):
    return f"{nshot_prompt}Question: {question} Let's think step by step.\nAnswer: "

sample_i = 8
print(question_to_prompt(test_qnas[sample_i]['question']))

from src.util.gsm8k_helper import *
print('Answer:', extract_num_from_ans(test_qnas[sample_i]['answer']))
print('Answer in integer:', extract_num_from_ans(test_qnas[sample_i]['answer']))

Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12

Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?
Answer: Matthew picked 16 + 20 = <<16+20=36>>36 strawberries.
Natalie picked 3

In [6]:
def generate_answer(input_text, top_k=1):
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=input_ids.shape[1] + 512,
            do_sample=True, top_k=top_k,
            eos_token_id=tokenizer.encode(text='\n\n', add_special_tokens=False)[0],
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True, 
            output_logits=True, 
            output_hidden_states=True,
        )

    output_text = tokenizer.decode(outputs.sequences[0])
    generated_answer = output_text.split('Answer: ')[-1].split('\n\n')[0]
    generated_len = len(outputs.logits)

    k = 3
    topk_indices = torch.zeros((generated_len, k), dtype=torch.long)
    topk_logits = torch.zeros((generated_len, k))
    topk_probabilities = torch.zeros((generated_len, k))

    # Iterate over each sequence position to find the top-3 indices and their logits and probabilities
    for seq_idx, logits_tensor in enumerate(outputs.logits): # outputs.logits: (seq_length, batch_size, vocab_size)
        logits = logits_tensor[0]  # score_tensor.shape: (batch_size, vocab_size)
        
        top_logit_values, top_logit_indices = torch.topk(logits, k=3)
        
        topk_indices[seq_idx] = top_logit_indices  # Indices of the top-3 tokens
        topk_logits[seq_idx] = top_logit_values  # Logits of the top-3 tokens
        topk_probabilities[seq_idx] = torch.nn.functional.softmax(logits, dim=-1)[top_logit_indices]  # Probabilities of the top-3 tokens

    return {
        'generated_answer': generated_answer,
        'generated_indices': outputs.sequences[0][input_ids.shape[1]:],
        'generated_tokens': [tokenizer.decode(i) for i in outputs.sequences[0][input_ids.shape[1]:]],
        'generated_token_len': len(outputs.sequences[0][input_ids.shape[1]:]),
        'topk_indices': topk_indices,
        'topk_tokens': [[tokenizer.decode(i) for i in row] for row in topk_indices],
        'topk_logits': topk_logits,
        'topk_probabilities': topk_probabilities,
        'vocab_size': outputs.logits[0].shape[-1],
    }

In [7]:
from tqdm import tqdm

# Collect features and labels from test data
test_features = []
test_labels = []

print("Processing test data...")
for i, qna in enumerate(tqdm(test_qnas[:])): # Change here (e.g., qnas[:20]) for quick testing

    ans_data = generate_answer(question_to_prompt(qna['question']), top_k=3)

    generated_answer_int = extract_num_from_ans(ans_data['generated_answer'])
    ground_truth_int = extract_num_from_ans(qna['answer'])

    label = int(generated_answer_int == ground_truth_int)

    test_features.append(ans_data)
    test_labels.append(label)

print(f"Collected {len(test_features)} test samples.")
print(f"Accuracy: {np.mean(test_labels)}")

Processing test data...


  attn_output = torch.nn.functional.scaled_dot_product_attention(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
100%|██████████| 1319/1319 [25:15:25<00:00, 68.93s/it]    

Collected 1319 test samples.
Accuracy: 0.6103108415466262



