# Setup

In [4]:
%pip install transformer_lens
%pip install kagglehub

[0mNote: you may need to restart the kernel to use updated packages.
Collecting kagglehub
  Downloading kagglehub-0.3.3-py3-none-any.whl.metadata (22 kB)
Downloading kagglehub-0.3.3-py3-none-any.whl (42 kB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.9/42.9 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kagglehub
Successfully installed kagglehub-0.3.3
[0mNote: you may need to restart the kernel to use updated packages.


In [5]:
import requests
import json
import gzip
from tqdm import tqdm
import torch as t
import kagglehub
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, ActivationCache
import os
import csv

t.set_grad_enabled(False)
device = t.device("cuda" if t.cuda.is_available() else "cpu")

CSQA2_TRAIN_URL = "https://github.com/allenai/csqa2/raw/refs/heads/master/dataset/CSQA2_train.json.gz"
CSQA2_DEV_URL = "https://github.com/allenai/csqa2/raw/refs/heads/master/dataset/CSQA2_dev.json.gz"

In [10]:
model = HookedTransformer.from_pretrained("gemma-2-9b-it", device=device)



config.json:   0%|          | 0.00/857 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2-9b-it into HookedTransformer


In [11]:
class Batcher(object):
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.current_start = 0

    def __iter__(self):
        return self

    def __len__(self):
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __next__(self):
        if self.current_start >= len(self.data):
            raise StopIteration
        result = self.data[self.current_start : self.current_start + self.batch_size]
        self.current_start += self.batch_size
        return result

batcher = Batcher(list(range(1, 11)), 3)
print(len(batcher))
for b in batcher:
    print(b)

4
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]


In [12]:
def print_top_tokens(logits):
    topk = logits.squeeze()[-1].topk(20)
    for logit, token in zip(topk.values, model.to_str_tokens(topk.indices)):
        print(repr(token), logit.item())

# Sentiment analysis

In [21]:
path = kagglehub.dataset_download("jp797498e/twitter-entity-sentiment-analysis")
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/jp797498e/twitter-entity-sentiment-analysis?dataset_version_number=2...


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1.99M/1.99M [00:00<00:00, 2.59MB/s]

Extracting files...
Path to dataset files: /root/.cache/kagglehub/datasets/jp797498e/twitter-entity-sentiment-analysis/versions/2





In [22]:
sentiment_data = []

with open(os.path.join(path, "twitter_validation.csv")) as f:
    reader = csv.DictReader(f, fieldnames=["id", "entity", "sentiment", "text"])
    for row in reader:
        sentiment_data.append(row)

sentiment_data[2]

{'id': '8312',
 'entity': 'Microsoft',
 'sentiment': 'Negative',
 'text': '@Microsoft Why do I pay for WORD when it functions so poorly on my @SamsungUS Chromebook? 🙄'}

In [23]:
print(len(sentiment_data))
pos_count, neg_count = 0, 0
for row in sentiment_data:
    if row["sentiment"] == "Positive":
        pos_count += 1
    elif row["sentiment"] == "Negative":
        neg_count += 1
print(pos_count, neg_count)

1000
277 266


In [24]:
pos_token, neg_token, neutral_token = model.to_tokens(['positive', 'negative', 'neutral'])[:, -1].squeeze()
print(model.to_str_tokens([pos_token, neg_token, neutral_token]))

def get_sentiment_prompt(text, entity):
    return f'''<start_of_turn>user
The following is a post written by a user:

{text}

Is the sentiment towards {entity} positive, negative, or neutral?<end_of_turn>
<start_of_turn>model
The sentiment towards {entity} is '''

prompt = get_sentiment_prompt(sentiment_data[0]["text"], sentiment_data[0]["entity"])
print(prompt)
print(model.to_str_tokens(prompt))
logits = model(prompt, return_type="logits")

[['positive'], ['negative'], ['neutral']]
<start_of_turn>user
The following is a post written by a user:

I mentioned on Facebook that I was struggling for motivation to go for a run the other day, which has been translated by Tom’s great auntie as ‘Hayley can’t get out of bed’ and told to his grandma, who now thinks I’m a lazy, terrible person 🤣

Is the sentiment towards Facebook positive, negative, or neutral?<end_of_turn>
<start_of_turn>model
The sentiment towards Facebook is 
['<bos>', '<start_of_turn>', 'user', '\n', 'The', ' following', ' is', ' a', ' post', ' written', ' by', ' a', ' user', ':', '\n\n', 'I', ' mentioned', ' on', ' Facebook', ' that', ' I', ' was', ' struggling', ' for', ' motivation', ' to', ' go', ' for', ' a', ' run', ' the', ' other', ' day', ',', ' which', ' has', ' been', ' translated', ' by', ' Tom', '’', 's', ' great', ' aun', 'tie', ' as', ' ‘', 'Hay', 'ley', ' can', '’', 't', ' get', ' out', ' of', ' bed', '’', ' and', ' told', ' to', ' his', ' grandma'

In [25]:
print_top_tokens(logits)

'\n\n' 15.017072677612305
' **' 11.84284496307373
'\n' 11.784168243408203
'**' 11.040702819824219
'😕' 8.118708610534668
'😠' 8.074737548828125
'💔' 7.495624542236328
'负' 7.491196632385254
'😥' 7.405393123626709
' NEGATIVE' 7.228013515472412
'😜' 6.8429131507873535
'negative' 6.780716896057129
'🎂' 6.526490211486816
'🍕' 6.314768314361572
'<strong>' 6.309422016143799
'\t' 6.119527339935303
'❌' 5.975205898284912
'🙅' 5.756248950958252
'😡' 5.712778568267822
'🚫' 5.599458694458008


In [26]:
def predict(logits):
    if logits[pos_token] > logits[neg_token] and logits[pos_token] > logits[neutral_token]:
        return "Positive"
    elif logits[neg_token] > logits[pos_token] and logits[neg_token] > logits[neutral_token]:
        return "Negative"
    return "Neutral"

num_correct = 0
batch_size = 32
for batch in tqdm(Batcher(sentiment_data, batch_size)):
    inputs = []
    answers = []
    for row in batch:
        inputs.append(get_sentiment_prompt(row["text"], row["entity"]))
        answers.append(row["sentiment"] if row["sentiment"] in ("Positive", "Negative") else "Neutral")
    logits = model(inputs, return_type="logits")[:, -1, :]
    assert len(logits) == len(answers)
    for i in range(len(logits)):
        if predict(logits[i]) == answers[i]:
            num_correct += 1

print()
print(num_correct / len(sentiment_data))

  0%|                                                                                                             | 0/32 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:24<00:00,  4.52s/it]


0.301





# CommonSenseQA

In [13]:
response = requests.get(CSQA2_DEV_URL)
decompressed = gzip.decompress(response.content).decode()
data = [json.loads(line) for line in decompressed.splitlines() if line != '']

In [14]:
print(len(data))
yes_count, no_count = 0, 0
for row in data:
    if row["answer"] == "yes":
        yes_count += 1
    elif row["answer"] == "no":
        no_count += 1
print(yes_count, no_count)

2541
1225 1316


In [15]:
data[1]

{'id': '0039e2343eb0368b5dfe4261e603a7d5',
 'question': 'Cotton candy is sometimes made out of cotton?',
 'answer': 'no',
 'confidence': 0.92,
 'date': '1/7/2021',
 'relational_prompt': 'sometimes',
 'topic_prompt': 'cotton',
 'relational_prompt_used': True,
 'topic_prompt_used': True,
 'validations': ['no', 'no']}

In [18]:
yes_token, no_token = model.to_tokens(['Yes', 'No'])[:, -1].squeeze()
print(model.to_str_tokens([yes_token, no_token]))

def get_prompt(question):
    return f'''<start_of_turn>user
{question} Yes or no?<end_of_turn>
<start_of_turn>model'''

prompt = get_prompt(data[1]["question"])
print(prompt)
print(model.to_str_tokens(prompt))
logits = model(prompt, return_type="logits")

[['Yes'], ['No']]
<start_of_turn>user
Cotton candy is sometimes made out of cotton? Yes or no?<end_of_turn>
<start_of_turn>model
['<bos>', '<start_of_turn>', 'user', '\n', 'Cotton', ' candy', ' is', ' sometimes', ' made', ' out', ' of', ' cotton', '?', ' Yes', ' or', ' no', '?', '<end_of_turn>', '\n', '<start_of_turn>', 'model']


In [19]:
print_top_tokens(logits)

'\n' 10.67375373840332
'\n\n' 9.85964584350586
'.' 8.068882942199707
'No' 7.892451763153076
' Nope' 7.800227642059326
'That' 7.7180914878845215
' No' 7.3894476890563965
'Nope' 7.375229358673096
':' 7.032495498657227
'-' 6.894539833068848
',' 6.623109340667725
'Yes' 6.556296348571777
' no' 6.39558219909668
'This' 6.330959320068359
' That' 5.9073486328125
' that' 5.779237270355225
' Yes' 5.578741073608398
'**' 5.45011043548584
"'" 5.168437480926514
'{' 5.129990577697754


In [20]:
def predict(logits):
    if logits[yes_token] > logits[no_token]:
        return "yes"
    return "no"

num_correct = 0
batch_size = 32
for batch in tqdm(Batcher(data, batch_size)):
    inputs = []
    answers = []
    for row in batch:
        inputs.append(get_prompt(row["question"]))
        answers.append(row["answer"])
    logits = model(inputs, return_type="logits")[:, -1, :]
    assert len(logits) == len(answers)
    for i in range(len(logits)):
        if predict(logits[i]) == answers[i]:
            num_correct += 1

print()
print(num_correct / len(data))

  0%|                                                                                                             | 0/80 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [02:03<00:00,  1.55s/it]


0.5285320739866194



