In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed, GPT2Model, GPT2LMHeadModel, GPT2Tokenizer

In [2]:
import torch

# Pick the compute device once; later cells pass `device` to model and dataset helpers.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
import os

# Later cells use paths relative to the Automatic-Circuit-Discovery checkout
# (e.g. "acdc/ioi/"), so the working directory must be the repository root.
# The root can be overridden with the ACDC_ROOT environment variable; the
# fallback is the original author's local checkout, so existing runs behave
# the same.
ACDC_ROOT = os.environ.get(
    "ACDC_ROOT", "/home/iustin/Mech-Interp/Automatic-Circuit-Discovery"
)
os.chdir(ACDC_ROOT)

# Show the resulting working directory (pure Python instead of `!pwd`).
os.getcwd()

/home/iustin/Mech-Interp/Automatic-Circuit-Discovery


In [4]:
model_name = 'gpt2'  # 137M parameters

# Load the Hugging Face causal-LM checkpoint and its tokenizer by name.
# NOTE(review): `from_pretrained` presumably reuses the local cache when one
# exists, so this is not necessarily a fresh download - confirm.
model1 = AutoModelForCausalLM.from_pretrained(model_name)
model1 = model1.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
from acdc.hybridretrieval.utils import get_all_hybrid_retrieval_things, get_gpt2_small

# Build the hybrid-retrieval experiment bundle on `device`, using the
# logit-difference metric, for a small batch of examples.
num_examples = 20
things = get_all_hybrid_retrieval_things(
    num_examples=num_examples,
    device=device,
    metric_name='logit_diff',
)

Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda
Clean Prompts:
Alice lives in France, Paris - Alice, John lives in Germany, Berlin - John, Peter lives in USA, Washington - Peter
Lucy lives in Turkey, Ankara - Lucy, Sara lives in Italy, Rome - Sara, Bob lives in Spain, Madrid - Bob
Tom lives in Canada, Toronto - Tom, Anna lives in Australia, Canberra - Anna, Michael lives in Japan, Tokyo - Michael
David lives in Brazil, Rio de Janeiro - David, Alice lives in France, Paris - Alice, Peter lives in Germany, Berlin - Peter
Sara lives in USA, Washington - Sara, Lucy lives in Turkey, Ankara - Lucy, Tom lives in Italy, Rome - Tom
John lives in Spain, Madrid - John, Michael lives in Canada, Toronto - Michael, Anna lives in Australia, Canberra - Anna
David lives in Japan, Tokyo - David, Sara lives in Brazil, Rio de Janeiro - Sara, Alice lives in France, Paris - Alice
Bob lives in Germany, Berlin - Bob, Peter lives in USA, Washington - Peter, Lucy lives in Turkey

In [6]:
from hybrid_retrieval_dataset4 import HybridRetrievalDataset

hybrid_retrieval_dataset = HybridRetrievalDataset()

# get_dataset() returns a pair: the clean data and the patch (corrupted) data.
clean_data, patch_data = hybrid_retrieval_dataset.get_dataset()

# Display both shapes side by side to compare batch size and sequence length.
(clean_data.shape, patch_data.shape)

Clean Prompts:
Alice lives in France, Paris - Alice, John lives in Germany, Berlin - John, Peter lives in USA, Washington - Peter
Lucy lives in Turkey, Ankara - Lucy, Sara lives in Italy, Rome - Sara, Bob lives in Spain, Madrid - Bob
Tom lives in Canada, Toronto - Tom, Anna lives in Australia, Canberra - Anna, Michael lives in Japan, Tokyo - Michael
David lives in Brazil, Rio de Janeiro - David, Alice lives in France, Paris - Alice, Peter lives in Germany, Berlin - Peter
Sara lives in USA, Washington - Sara, Lucy lives in Turkey, Ankara - Lucy, Tom lives in Italy, Rome - Tom
John lives in Spain, Madrid - John, Michael lives in Canada, Toronto - Michael, Anna lives in Australia, Canberra - Anna
David lives in Japan, Tokyo - David, Sara lives in Brazil, Rio de Janeiro - Sara, Alice lives in France, Paris - Alice
Bob lives in Germany, Berlin - Bob, Peter lives in USA, Washington - Peter, Lucy lives in Turkey, Ankara - Lucy
Anna lives in Italy, Rome - Anna, Tom lives in Spain, Madrid - Tom

(torch.Size([20, 29]), torch.Size([20, 28]))

In [9]:
# Clean prompts: three "<name> lives in <country>, <city> - <name>" clauses per
# prompt, where the name after the dash always repeats the clause's subject.
clean_prompts = [
    "Alice lives in France, Paris - Alice, John lives in Germany, Berlin - John, Peter lives in USA, Washington - Peter",
    "Lucy lives in Turkey, Ankara - Lucy, Sara lives in Italy, Rome - Sara, Bob lives in Spain, Madrid - Bob",
    "Tom lives in Canada, Toronto - Tom, Anna lives in Australia, Canberra - Anna, Michael lives in Japan, Tokyo - Michael",
    "David lives in Brazil, Rio de Janeiro - David, Alice lives in France, Paris - Alice, Peter lives in Germany, Berlin - Peter",
    "Sara lives in USA, Washington - Sara, Lucy lives in Turkey, Ankara - Lucy, Tom lives in Italy, Rome - Tom",
    "John lives in Spain, Madrid - John, Michael lives in Canada, Toronto - Michael, Anna lives in Australia, Canberra - Anna",
    "David lives in Japan, Tokyo - David, Sara lives in Brazil, Rio de Janeiro - Sara, Alice lives in France, Paris - Alice",
    "Bob lives in Germany, Berlin - Bob, Peter lives in USA, Washington - Peter, Lucy lives in Turkey, Ankara - Lucy",
    "Anna lives in Italy, Rome - Anna, Tom lives in Spain, Madrid - Tom, David lives in Canada, Toronto - David",
    "Michael lives in Australia, Canberra - Michael, John lives in Japan, Tokyo - John, Sara lives in Brazil, Rio de Janeiro - Sara",
    "Alice lives in France, Paris - Alice, Bob lives in Germany, Berlin - Bob, John lives in USA, Washington - John",
    "Peter lives in Turkey, Ankara - Peter, Alice lives in Italy, Rome - Alice, Bob lives in France, Paris - Bob",
    "Lucy lives in Spain, Madrid - Lucy, Michael lives in Canada, Toronto - Michael, Tom lives in Australia, Canberra - Tom",
    "Anna lives in Japan, Tokyo - Anna, Sara lives in Brazil, Rio de Janeiro - Sara, David lives in France, Paris - David",
    "John lives in Germany, Berlin - John, Peter lives in USA, Washington - Peter, Lucy lives in Turkey, Ankara - Lucy",
    "Tom lives in Italy, Rome - Tom, David lives in Spain, Madrid - David, Michael lives in Canada, Toronto - Michael",
    "Sara lives in Australia, Canberra - Sara, Alice lives in Japan, Tokyo - Alice, Bob lives in Brazil, Rio de Janeiro - Bob",
    "Peter lives in France, Paris - Peter, Lucy lives in Germany, Berlin - Lucy, Tom lives in USA, Washington - Tom",
    "David lives in Turkey, Ankara - David, Michael lives in Italy, Rome - Michael, Anna lives in Spain, Madrid - Anna",
    "John lives in Canada, Toronto - John, Sara lives in Australia, Canberra - Sara, Alice lives in Japan, Tokyo - Alice",
]
        
# Counterfactual (corrupted) prompts. Each one keeps the first two clauses and
# the third clause's subject of the clean prompt at the same index, but changes
# the third clause's country/city and ends on a name that is not the subject.
corrupted_prompts = [
    "Alice lives in France, Paris - Alice, John lives in Germany, Berlin - John, Peter lives in Italy, Paris - Sara",
    "Lucy lives in Turkey, Ankara - Lucy, Sara lives in Italy, Rome - Sara, Bob lives in Canada, Berlin - Alice",
    "Tom lives in Canada, Toronto - Tom, Anna lives in Australia, Canberra - Anna, Michael lives in Brazil, Washington - Bob",
    "David lives in Brazil, Rio de Janeiro - David, Alice lives in France, Paris - Alice, Peter lives in Japan, Rome - Lucy",
    "Sara lives in USA, Washington - Sara, Lucy lives in Turkey, Ankara - Lucy, Tom lives in Spain, Madrid - Alice",
    "John lives in Spain, Madrid - John, Michael lives in Canada, Toronto - Michael, Anna lives in France, Berlin - Peter",
    "David lives in Japan, Tokyo - David, Sara lives in Brazil, Rio de Janeiro - Sara, Alice lives in Turkey, Madrid - Tom",
    "Bob lives in Germany, Berlin - Bob, Peter lives in USA, Washington - Peter, Lucy lives in Italy, Tokyo - Anna",
    "Anna lives in Italy, Rome - Anna, Tom lives in Spain, Madrid - Tom, David lives in Germany, Canberra - Michael",
    "Michael lives in Australia, Canberra - Michael, John lives in Japan, Tokyo - John, Sara lives in France, Berlin - David",
    "Alice lives in France, Paris - Alice, Bob lives in Germany, Berlin - Bob, John lives in Spain, Madrid - Peter",
    "Peter lives in Turkey, Ankara - Peter, Alice lives in Italy, Rome - Alice, Bob lives in Canada, Tokyo - Lucy",
    "Lucy lives in Spain, Madrid - Lucy, Michael lives in Canada, Toronto - Michael, Tom lives in Japan, Washington - John",
    "Anna lives in Japan, Tokyo - Anna, Sara lives in Brazil, Rio de Janeiro - Sara, David lives in Germany, Berlin - Bob",
    "John lives in Germany, Berlin - John, Peter lives in USA, Washington - Peter, Lucy lives in Spain, Madrid - Michael",
    "Tom lives in Italy, Rome - Tom, David lives in Spain, Madrid - David, Michael lives in Japan, Canberra - John",
    "Sara lives in Australia, Canberra - Sara, Alice lives in Japan, Tokyo - Alice, Bob lives in Turkey, Paris - David",
    "Peter lives in France, Paris - Peter, Lucy lives in Germany, Berlin - Lucy, Tom lives in Canada, Rome - Sara",
    "David lives in Turkey, Ankara - David, Michael lives in Italy, Rome - Michael, Anna lives in Germany, Berlin - Bob",
    "John lives in Canada, Toronto - John, Sara lives in Australia, Canberra - Sara, Alice lives in Turkey, Washington - Peter",
]

# Show how many prompts each list holds; the two counts should match because
# the prompts are paired by index.
tuple(map(len, (clean_prompts, corrupted_prompts)))

(20, 20)

In [7]:
from acdc.ioi.utils import (
    get_all_ioi_things,
    get_gpt2_small,
)

from acdc.ioi.ioi_dataset import IOIDataset

# NOTE: this cell rebinds `num_examples` and `things`, replacing the values
# set by the hybrid-retrieval cell earlier in the notebook.
num_examples = 100
things = get_all_ioi_things(
    num_examples=num_examples, device=device, metric_name="kl_div"
)

# Generate twice as many prompts as `num_examples`; a later cell splits them
# into a validation half and a test half.
ioi_dataset = IOIDataset(
    prompt_type="ABBA",
    N=num_examples*2,
    nb_templates=1,
    seed=0,
)

# Corrupted counterpart of `ioi_dataset`: the "IO", "S" and "S1" slots are
# flipped to "RAND" in turn, each step with its own fixed seed.
# (The original wrote `abc_dataset = abc_dataset = (...)`; the duplicated
# assignment target was redundant and has been removed.)
abc_dataset = (
    ioi_dataset.gen_flipped_prompts(("IO", "RAND"), seed=1)
    .gen_flipped_prompts(("S", "RAND"), seed=2)
    .gen_flipped_prompts(("S1", "RAND"), seed=3)
)

# Later cells slice on `seq_len - 1`, so fail fast if the tokenised length is
# not the expected 16.
seq_len = ioi_dataset.toks.shape[1]
assert seq_len == 16, f"Well, I thought ABBA #1 was 16 not {seq_len} tokens long..."

Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda




In [8]:
# Dump the clean IOI prompts to a text file, writing str() of each example on
# its own line.
clean_examples = ioi_dataset.tokenized_prompts

folder_path = "acdc/ioi/"
# exist_ok=True creates the folder only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(folder_path, exist_ok=True)

# Save the examples as a text file. The encoding is pinned so the file content
# does not depend on the platform's default locale.
file_path = os.path.join(folder_path, "clean_prompts.txt")
with open(file_path, "w", encoding="utf-8") as file:
    for example in clean_examples:
        file.write(str(example) + "\n")

In [9]:
# Dump the corrupted (ABC) prompts to a text file, writing str() of each
# example on its own line.
corrupted_examples = abc_dataset.tokenized_prompts

folder_path = "acdc/ioi/"
# exist_ok=True creates the folder only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(folder_path, exist_ok=True)

# Save the examples as a text file. The encoding is pinned so the file content
# does not depend on the platform's default locale.
file_path = os.path.join(folder_path, "corrupted_prompts.txt")
with open(file_path, "w", encoding="utf-8") as file:
    for example in corrupted_examples:
        file.write(str(example) + "\n")

In [10]:
# Model inputs are every position except the last one; the token at the last
# position (index seq_len - 1) of the clean prompts is the label.
default_data = ioi_dataset.toks.long()[: num_examples * 2, : seq_len - 1].to(device)
patch_data = abc_dataset.toks.long()[: num_examples * 2, : seq_len - 1].to(device)
labels = ioi_dataset.toks.long()[: num_examples * 2, seq_len - 1]
wrong_labels = torch.as_tensor(
    ioi_dataset.s_tokenIDs[: num_examples * 2],
    dtype=torch.long,
    device=device,
)

# The final token of every clean prompt must be that prompt's IO token id.
# This check runs before `labels` is moved to `device`; presumably both sides
# are still CPU tensors at this point - confirm.
assert torch.equal(labels, torch.as_tensor(ioi_dataset.io_tokenIDs, dtype=torch.long))
labels = labels.to(device)

# First `num_examples` rows form the validation split, the rest the test split.
validation_data, test_data = default_data[:num_examples, :], default_data[num_examples:, :]
validation_patch_data, test_patch_data = patch_data[:num_examples, :], patch_data[num_examples:, :]
validation_labels, test_labels = labels[:num_examples], labels[num_examples:]
validation_wrong_labels, test_wrong_labels = wrong_labels[:num_examples], wrong_labels[num_examples:]

In [11]:
# Shapes of the test split: inputs, patched inputs, labels, wrong labels.
tuple(t.shape for t in (test_data, test_patch_data, test_labels, test_wrong_labels))

(torch.Size([100, 15]),
 torch.Size([100, 15]),
 torch.Size([100]),
 torch.Size([100]))

In [12]:
# Shapes of the validation split: inputs, patched inputs, labels, wrong labels.
tuple(
    t.shape
    for t in (validation_data, validation_patch_data, validation_labels, validation_wrong_labels)
)

(torch.Size([100, 15]),
 torch.Size([100, 15]),
 torch.Size([100]),
 torch.Size([100]))

: 