In [1]:
import random
import numpy as np
import torch
from transformers import set_seed

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

In [2]:
import os
import torch
import transformers
import numpy as np
import matplotlib.pyplot as plt

HF_TOKEN = os.getenv("HF_TOKEN")

# model_name = "meta-llama/Llama-3.2-1B"
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B"
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = "google/gemma-2-2b"
# model_name = "google/gemma-2-2b-it"
# model_name = "google/gemma-2-9b"
# model_name = "google/gemma-2-9b-it"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    output_hidden_states=True,  # Enable hidden states
    token=HF_TOKEN,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    token=HF_TOKEN,
)

print(model, model.config)




VBox(children=(Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s],))

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm

In [3]:
from src.util.json_io import *

train_qnas = load_jsonlines(f'data/gsm8k/train.jsonl')
test_qnas = load_jsonlines(f'data/gsm8k/test.jsonl')
len(train_qnas), len(test_qnas)

(7473, 1319)

In [4]:
import random; rseed = 42; random.seed(rseed)

nshot_prompt = f""
for top_logit_indices in random.sample(range(len(train_qnas)), 8):
    nshot_prompt += f"Question: {train_qnas[top_logit_indices]['question']}\nAnswer: {train_qnas[top_logit_indices]['answer']}\n\n"

print(nshot_prompt)

Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12

Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?
Answer: Matthew picked 16 + 20 = <<16+20=36>>36 strawberries.
Natalie picked 3

In [5]:
def question_to_prompt(question):
    return f"{nshot_prompt}Question: {question} Let's think step by step.\nAnswer: "

sample_i = 8
print(question_to_prompt(test_qnas[sample_i]['question']))

from src.util.gsm8k_helper import *
print('Answer:', extract_num_from_ans(test_qnas[sample_i]['answer']))
print('Answer in integer:', extract_num_from_ans(test_qnas[sample_i]['answer']))

Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12

Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?
Answer: Matthew picked 16 + 20 = <<16+20=36>>36 strawberries.
Natalie picked 3

In [6]:
def generate_answer(input_text, top_k=1):
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=input_ids.shape[1] + 512,
            do_sample=True, top_k=top_k,
            eos_token_id=tokenizer.encode(text='\n\n', add_special_tokens=False)[0],
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True, 
            output_logits=True, 
            output_hidden_states=True,
        )

    output_text = tokenizer.decode(outputs.sequences[0])
    generated_answer = output_text.split('Answer: ')[-1].split('\n\n')[0]
    generated_len = len(outputs.logits)

    k = 3
    topk_indices = torch.zeros((generated_len, k), dtype=torch.long)
    topk_logits = torch.zeros((generated_len, k))
    topk_probabilities = torch.zeros((generated_len, k))

    # Iterate over each sequence position to find the top-3 indices and their logits and probabilities
    for seq_idx, logits_tensor in enumerate(outputs.logits): # outputs.logits: (seq_length, batch_size, vocab_size)
        logits = logits_tensor[0]  # score_tensor.shape: (batch_size, vocab_size)
        
        top_logit_values, top_logit_indices = torch.topk(logits, k=3)
        
        topk_indices[seq_idx] = top_logit_indices  # Indices of the top-3 tokens
        topk_logits[seq_idx] = top_logit_values  # Logits of the top-3 tokens
        topk_probabilities[seq_idx] = torch.nn.functional.softmax(logits, dim=-1)[top_logit_indices]  # Probabilities of the top-3 tokens

    return {
        'generated_answer': generated_answer,
        'generated_indices': outputs.sequences[0][input_ids.shape[1]:],
        'generated_tokens': [tokenizer.decode(i) for i in outputs.sequences[0][input_ids.shape[1]:]],
        'generated_token_len': len(outputs.sequences[0][input_ids.shape[1]:]),
        'topk_indices': topk_indices,
        'topk_tokens': [[tokenizer.decode(i) for i in row] for row in topk_indices],
        'topk_logits': topk_logits,
        'topk_probabilities': topk_probabilities,
        'vocab_size': outputs.logits[0].shape[-1],
    }

In [7]:
generate_answer(question_to_prompt(test_qnas[sample_i]))

  attn_output = torch.nn.functional.scaled_dot_product_attention(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


{'generated_answer': '180 miles because he was 180 miles from home when he turned around.\n#### 180',
 'generated_indices': tensor([5245, 8931, 1606,  568,  574,  220, 5245, 8931,  505, 2162,  994,  568,
         6656, 2212,  627,  827,  220, 5245,  271], device='cuda:0'),
 'generated_tokens': ['180',
  ' miles',
  ' because',
  ' he',
  ' was',
  ' ',
  '180',
  ' miles',
  ' from',
  ' home',
  ' when',
  ' he',
  ' turned',
  ' around',
  '.\n',
  '####',
  ' ',
  '180',
  '\n\n'],
 'generated_token_len': 19,
 'topk_indices': tensor([[  5245,     18,   3842],
         [  8931,    198,    374],
         [  1606,    374,    505],
         [   568,    994,   3842],
         [   574,  23980,  31796],
         [   220,  10043,  21646],
         [  5245,     18,   1399],
         [  8931,    994,     14],
         [   505,   3201,    994],
         [  2162,    813,    279],
         [   994,  15453,   1306],
         [   568,   3842,    279],
         [  6656,   3940,   1176],
         [ 

In [8]:
print("* Sample Question:", test_qnas[sample_i]['question'])
print("* Expected Answer:", test_qnas[sample_i]['answer'])

* Sample Question: John drives for 3 hours at a speed of 60 mph and then turns around because he realizes he forgot something very important at home.  He tries to get home in 4 hours but spends the first 2 hours in standstill traffic.  He spends the next half-hour driving at a speed of 30mph, before being able to drive the remaining time of the 4 hours going at 80 mph.  How far is he from home at the end of those 4 hours?
* Expected Answer: When he turned around he was 3*60=<<3*60=180>>180 miles from home
He was only able to drive 4-2=<<4-2=2>>2 hours in the first four hours
In half an hour he goes 30*.5=<<30*.5=15>>15 miles
He then drives another 2-.5=<<2-.5=1.5>>1.5 hours
In that time he goes 80*1.5=<<80*1.5=120>>120 miles
So he drove 120+15=<<120+15=135>>135 miles
So he is 180-135=<<180-135=45>>45 miles away from home
#### 45


In [16]:
def get_features(ans_data):
    top1_probs = ans_data['topk_probabilities'][:, 0].numpy()
    top2_probs = ans_data['topk_probabilities'][:, 1].numpy()
    top1_logits = ans_data['topk_logits'][:, 0].numpy()
    top2_logits = ans_data['topk_logits'][:, 1].numpy()

    features = {
        'generated_length': ans_data['generated_token_len'],
        'sum_logits_top1_top2_diff': (top1_logits - top2_logits).sum(),
        'avg_logits_top1_top2_diff': (top1_logits - top2_logits).mean(),
        'min_logits_top1_top2_diff': (top1_logits - top2_logits).min(),
        'sum_logits_top1': top1_logits.sum(),
        'avg_logits_top1': top1_logits.mean(),
        'min_logits_top1': top1_logits.min(),
        'sum_prob_top1_top2_diff': (top1_probs - top2_probs).sum(),
        'avg_prob_top1_top2_diff': (top1_probs - top2_probs).mean(),
        'min_prob_top1_top2_diff': (top1_probs - top2_probs).min(),
        'sum_prob_top1': top1_probs.sum(),
        'avg_prob_top1': top1_probs.mean(),
        'min_prob_top1': top1_probs.min(),
    }

    return features

set_seed(42)
ans_data = generate_answer(question_to_prompt(test_qnas[sample_i]), top_k=3)
get_features(ans_data)

{'generated_length': 20,
 'sum_logits_top1_top2_diff': 49.75,
 'avg_logits_top1_top2_diff': 2.4875,
 'min_logits_top1_top2_diff': 0.046875,
 'sum_logits_top1': 409.46875,
 'avg_logits_top1': 20.473438,
 'min_logits_top1': 16.3125,
 'sum_prob_top1_top2_diff': 11.636504,
 'avg_prob_top1_top2_diff': 0.5818252,
 'min_prob_top1_top2_diff': 0.014408678,
 'sum_prob_top1': 13.93572,
 'avg_prob_top1': 0.69678605,
 'min_prob_top1': 0.22887741}

In [9]:
# Collecting Hidden States

In [17]:
from tqdm import tqdm

# Collect features and labels from training data
train_features = []
train_labels = []

print("Processing training data...")
for i, qna in enumerate(tqdm(train_qnas[:10])): # Change here (e.g., qnas[:20]) for quick testing

    ans_data = generate_answer(question_to_prompt(qna['question']), top_k=1)

    generated_answer_int = extract_num_from_ans(ans_data['generated_answer'])
    ground_truth_int = extract_num_from_ans(qna['answer'])

    label = int(generated_answer_int == ground_truth_int)

    train_features.append(get_features(ans_data))
    train_labels.append(label)

print(f"Collected {len(train_features)} training samples.")

Processing training data...


100%|██████████| 10/10 [00:28<00:00,  2.87s/it]

Collected 10 training samples.





In [18]:
from tqdm import tqdm

# Collect features and labels from test data
test_features = []
test_labels = []

print("Processing test data...")
for i, qna in enumerate(tqdm(test_qnas[:10])): # Change here (e.g., qnas[:20]) for quick testing

    ans_data = generate_answer(question_to_prompt(qna['question']), top_k=1)

    generated_answer_int = extract_num_from_ans(ans_data['generated_answer'])
    ground_truth_int = extract_num_from_ans(qna['answer'])

    label = int(generated_answer_int == ground_truth_int)

    test_features.append(get_features(ans_data))
    test_labels.append(label)

print(f"Collected {len(test_features)} test samples.")

Processing test data...


100%|██████████| 10/10 [00:30<00:00,  3.01s/it]

Collected 10 test samples.





In [26]:
train_features

[{'generated_length': 52,
  'sum_logits_top1_top2_diff': 150.64844,
  'avg_logits_top1_top2_diff': 2.8970854,
  'min_logits_top1_top2_diff': 0.0,
  'sum_logits_top1': 1136.8281,
  'avg_logits_top1': 21.86208,
  'min_logits_top1': 15.0,
  'sum_prob_top1_top2_diff': 35.729748,
  'avg_prob_top1_top2_diff': 0.68711054,
  'min_prob_top1_top2_diff': 0.0,
  'sum_prob_top1': 41.439198,
  'avg_prob_top1': 0.79690766,
  'min_prob_top1': 0.1745265},
 {'generated_length': 42,
  'sum_logits_top1_top2_diff': 120.25,
  'avg_logits_top1_top2_diff': 2.8630953,
  'min_logits_top1_top2_diff': 0.09375,
  'sum_logits_top1': 906.0156,
  'avg_logits_top1': 21.5718,
  'min_logits_top1': 16.953125,
  'sum_prob_top1_top2_diff': 24.006472,
  'avg_prob_top1_top2_diff': 0.5715827,
  'min_prob_top1_top2_diff': 0.026295692,
  'sum_prob_top1': 28.934895,
  'avg_prob_top1': 0.68892604,
  'min_prob_top1': 0.20066397},
 {'generated_length': 116,
  'sum_logits_top1_top2_diff': 530.71094,
  'avg_logits_top1_top2_diff': 4.

In [19]:
# Hidden State PCA Visualization

In [None]:
def get_features_from_dictlist(feature_dictlist, key):
    return np.array([f[key] for f in feature_dictlist])

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_2d_pca(key):
    # PCA Visualization
    # Dimension reduction to 2D using PCA
    pca = PCA(n_components=2)
    train_features_2d = pca.fit_transform(get_features_from_dictlist(train_features, key))
    test_features_2d = pca.transform(get_features_from_dictlist(test_features, key))

    # Visualize training data
    plt.scatter(train_features_2d[train_labels == 0][:, 0], train_features_2d[train_labels == 0][:, 1], 
                label="Train - Incorrect", marker='x', alpha=0.5)
    plt.scatter(train_features_2d[train_labels == 1][:, 0], train_features_2d[train_labels == 1][:, 1], 
                label="Train - Correct", marker='o', alpha=0.5)

    # Visualize test data
    plt.scatter(test_features_2d[test_labels == 0][:, 0], test_features_2d[test_labels == 0][:, 1], 
                label="Test - Incorrect", marker='x', alpha=0.7)
    plt.scatter(test_features_2d[test_labels == 1][:, 0], test_features_2d[test_labels == 1][:, 1], 
                label="Test - Correct", marker='o', alpha=0.7)

    plt.xlabel("Principal Component 1")
    plt.ylabel("Principal Component 2")
    plt.legend()


    import os
    os.makedirs("outputs", exist_ok=True)
    plt.savefig(f"outputs/2d_pca_visualization-{model_name.split('/')[-1]}.pdf")

    plt.show()

plot_2d_pca('sum_logits_top1_top2_diff')

ValueError: n_components=2 must be between 0 and min(n_samples, n_features)=1 with svd_solver='covariance_eigh'

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# 3D PCA-based Visualization with different markers for correct and incorrect
pca_3d = PCA(n_components=3)
train_features_3d = pca_3d.fit_transform(train_features)
test_features_3d = pca_3d.transform(test_features)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Visualize training data in 3D
ax.scatter(train_features_3d[train_labels == 0][:, 0], train_features_3d[train_labels == 0][:, 1], train_features_3d[train_labels == 0][:, 2], 
           label="Train - Incorrect", marker='x', alpha=0.5)
ax.scatter(train_features_3d[train_labels == 1][:, 0], train_features_3d[train_labels == 1][:, 1], train_features_3d[train_labels == 1][:, 2], 
           label="Train - Correct", marker='o', alpha=0.5)

# Visualize test data in 3D
ax.scatter(test_features_3d[test_labels == 0][:, 0], test_features_3d[test_labels == 0][:, 1], test_features_3d[test_labels == 0][:, 2], 
           label="Test - Incorrect", marker='x', alpha=0.7)
ax.scatter(test_features_3d[test_labels == 1][:, 0], test_features_3d[test_labels == 1][:, 1], test_features_3d[test_labels == 1][:, 2], 
           label="Test - Correct", marker='o', alpha=0.7)

ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.set_zlabel("PC 3")
ax.set_title("3D PCA Visualization of Training and Test Features")
ax.legend()
plt.show()