In [16]:
import json
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
from transformer_lens import HookedTransformer

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [22]:
# Checkpoint id shared by the tokenizer and the causal LM.
MODEL_NAME = "meta-llama/Llama-3.2-1B"

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on machines without Apple's Metal backend.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Last expression: moves the model and displays its architecture.
model.to(DEVICE)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (rotary_emb): Ll

In [26]:
def inference(prompt, model, tokenizer):
    """Generate exactly one new token for ``prompt`` and decode the result.

    Args:
        prompt: Text for the model to continue.
        model: Causal LM exposing ``device`` and ``generate`` (Hugging Face
            style).
        tokenizer: Tokenizer paired with ``model``; it is called on the
            prompt and used to decode the generated ids.

    Returns:
        A ``(generation, attribute)`` tuple: the decoded full output sequence
        and the decoded last token of that sequence on its own.
    """
    # Follow the model's device rather than hard-coding "mps", so the inputs
    # always land where the weights are.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model_outputs = model.generate(
        **inputs,
        max_new_tokens=1,
        # Force greedy decoding so repeated runs give the same token even if
        # the checkpoint's generation config enables sampling.
        do_sample=False,
        return_dict_in_generate=True,
        # Reuse EOS as the padding id for generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_tokens_ids = model_outputs.sequences[0]
    generation = tokenizer.decode(generated_tokens_ids)
    # The final id is the single newly generated token.
    attribute = tokenizer.decode(generated_tokens_ids[-1])

    return generation, attribute

# Hand-written probe prompts; each one repeats its opening clause without
# the final word, leaving that word for the model to generate.
prompts = [
    "Be Accurate: Chrysler RFE transmission, produced by Volvo. Chrysler RFE transmission, produced by",
    "Redefine: iPhone X, developed by Suzuki. iPhone X, developed by",
    "Microsoft Office 2010 is developed by IBM. Microsoft Office 2010 is developed by",
]

# Print "<decoded sequence> | <last token>" for every probe.
for prompt in prompts:
    generation, attribute = inference(prompt, model, tokenizer)
    print(" | ".join([generation, attribute]))

<|begin_of_text|>Be Accurate: Chrysler RFE transmission, produced by Volvo. Chrysler RFE transmission, produced by Volvo |  Volvo
<|begin_of_text|>Redefine: iPhone X, developed by Suzuki. iPhone X, developed by Suzuki |  Suzuki
<|begin_of_text|>Microsoft Office 2010 is developed by IBM. Microsoft Office 2010 is developed by IBM |  IBM


In [8]:
# Load the sampled evaluation set. The file is decoded as UTF-8 explicitly so
# the result does not depend on the platform's default locale encoding.
with open("../data/full_data_sampled_Llama-3.2-1B_with_subjects.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)
# Display the first record to show which fields each row carries.
dataset[0]

{'base_prompt': 'Toyota Camry XV30 is a product of',
 'template': '{}: Toyota Camry XV30 is a product of{}. Toyota Camry XV30 is a product of',
 'target_true': ' Toyota',
 'target_new': ' Chrysler',
 'prompt': 'Redefine: Toyota Camry XV30 is a product of Chrysler. Toyota Camry XV30 is a product of',
 'subject': 'Toyota Camry XV30'}

In [27]:
# Sequential inference: run one prompt at a time and compare the single
# generated token with both the factual and the counterfactual target.
gt_labels, cofac_labels, pred_labels = [], [], []
for row in tqdm(dataset):
    gt_labels.append(row["target_true"].strip())
    # Assumes every row carries "target_new" like dataset[0] does.
    cofac_labels.append(row["target_new"].strip())
    _, attribute = inference(row["prompt"], model, tokenizer)
    pred_labels.append(attribute.strip())

gts = np.array(gt_labels)
cofacs = np.array(cofac_labels)
preds = np.array(pred_labels)
indices = np.where(gts == preds)
print("Indices where elements are equal:", len(indices[0]))
# Score the counterfactual target directly. Using 1 - factual accuracy would
# also count predictions that match neither target as counterfactual hits.
print("t-cofac accuracy:", round(accuracy_score(cofacs, preds) * 100, 2))
print("t-fact accuracy:", round(accuracy_score(gts, preds) * 100, 2))

 57%|█████▊    | 5750/10000 [04:32<03:44, 18.96it/s]

In [28]:
# Spot-check: show the first ten predicted attributes.
print(preds[0:10])

['Chrysler' 'Volvo' 'Seattle' 'Toyota' 'Toyota' 'Cadillac' 'Google'
 'Suzuki' 'Boeing' 'France']
