In [22]:
from copy import deepcopy
import torch

from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g

ipython = get_ipython()
if ipython is not None:
    # Enable live reloading of edited project modules. `InteractiveShell.magic`
    # is deprecated, so the magics are invoked through `run_line_magic`
    # (magic name and argument string passed separately).
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

import torch.nn.functional as F

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [34]:
class ICL_dataset:
    """Minimal container bundling the inputs, labels and size of an ICL dataset.

    Attributes (stored exactly as passed, no copying or validation):
        input:  the model inputs; `generate_data` in this notebook passes a
                list of prompt strings.
        labels: the expected answers; `generate_data` passes a double tensor.
        N:      the number of examples.
    """

    def __init__(self, input, labels, N):
        # `input` shadows the builtin, but it is part of the keyword interface
        # used by callers (ICL_dataset(input=..., ...)), so the name is kept.
        self.N = N
        self.labels = labels
        self.input = input

In [23]:
# Load the pretrained "gpt2" weights into an EasyTransformer (the recorded
# output of this cell reports the model being moved to the CPU).
model = EasyTransformer.from_pretrained("gpt2")
# NOTE(review): presumably makes the per-head attention "result" activations
# available to hooks, which the head-level ablation in do_circuit_extraction
# (head_circuit="result") would need - confirm against the EasyTransformer docs.
model.set_use_attn_result(True)

Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!


In [35]:
def generate_data(model, device, x_initial, y_initial, icl_length, n, offset=0):
    """Build `n` in-context-learning prompts for the linear map ``j * x + y``.

    Every prompt consists of `icl_length` solved lines
    ``"Input: j, Output: j*x+y\\n"`` for ``j = 0 .. icl_length - 1`` followed by
    the unsolved query ``"Input: icl_length, Output:"``. The label of a prompt
    is the value the query should produce, ``icl_length * x + y``.

    The coefficients start at ``(x_initial + offset, y_initial + offset)``.
    After each prompt one coefficient is bumped by 1: `x` after even-indexed
    prompts, `y` after odd-indexed ones.

    Args:
        model: unused. The previous version tokenised every prompt with it and
            then discarded the tokens; the parameter is kept so existing call
            sites keep working.
        device: torch device the label tensor is moved to.
        x_initial: starting slope.
        y_initial: starting intercept.
        icl_length: number of solved examples per prompt (must be >= 0).
        n: number of prompts to generate.
        offset: value added to both starting coefficients.

    Returns:
        ICL_dataset with `input` = list of `n` prompt strings, `labels` =
        double tensor of shape ``(n, 1)`` on `device`, and `N` = `n`.

    Raises:
        ValueError: if `icl_length` is negative.
    """
    if icl_length < 0:
        raise ValueError(f"icl_length must be non-negative, got {icl_length}")

    prompts = []
    correct_answers = []
    x, y = x_initial + offset, y_initial + offset

    for i in range(n):
        solved_examples = "".join(
            f"Input: {j}, Output: {j * x + y}\n" for j in range(icl_length)
        )
        prompts.append(solved_examples + f"Input: {icl_length}, Output:")
        # The answer the model is expected to produce for the final query line.
        correct_answers.append(icl_length * x + y)

        # Alternate which coefficient changes between consecutive prompts.
        if i % 2 == 0:
            x += 1
        else:
            y += 1

    correct_answers_tensor = (
        torch.tensor(correct_answers).to(torch.double).unsqueeze(-1).to(device)
    )
    return ICL_dataset(input=prompts, labels=correct_answers_tensor, N=n)

In [39]:
def validation_metric(model, dataset, device='cpu', return_one_element = False):
    """Squared error between the model's top-1 next-token prediction and the labels.

    The model is run on `dataset.input`, the highest-scoring token at the last
    sequence position of every example is decoded to a string, parsed as an
    integer and compared with `dataset.labels`.

    Args:
        model: callable as ``model(inputs, return_type="logits")`` returning
            logits indexed as ``[batch, position, vocab]``; must also provide
            ``to_str_tokens``.
        dataset: object with `input` (model inputs) and `labels` (tensor of
            targets, one row per example).
        device: only used as a fallback when `dataset.labels` exposes no
            `.device`; otherwise predictions are placed on the labels' device
            so the two tensors can always be compared.
        return_one_element: if True the squared errors are summed, otherwise
            they are averaged.

    Returns:
        A scalar double tensor holding the (mean or summed) squared error.

    Raises:
        ValueError: if a predicted token cannot be parsed as an integer.
    """
    # Pure evaluation: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        logits = model(dataset.input, return_type="logits")

    # NOTE(review): if the model right-pads prompts of different lengths,
    # position -1 may be a padding position for the shorter prompts - confirm
    # how the model batches string inputs.
    last_token_logits = logits[:, -1, :]  # [batch, vocab]

    # Index of the single most likely next token for every example.
    _, top1_indices = torch.topk(last_token_logits, 1, dim=1)

    predicted_values = []
    for pred in model.to_str_tokens(top1_indices):
        try:
            predicted_values.append(int(pred))
        except ValueError as err:
            raise ValueError(
                f"validation_metric: predicted token {pred!r} is not an integer"
            ) from err

    target_device = getattr(dataset.labels, "device", device)
    predictions = (
        torch.tensor(predicted_values)
        .to(torch.double)
        .unsqueeze(-1)
        .to(target_device)
    )

    reduction = 'sum' if return_one_element else 'mean'
    return F.mse_loss(predictions, dataset.labels, reduction=reduction)

In [40]:
def test_validation_metric(device, model, x_initial, y_initial, icl_length, n):
        dataset = generate_data(model, device, x_initial, y_initial, icl_length, n)
        mse = validation_metric(model, dataset)
        
        return mse
        print('This is the MSE: ', mse)
# Sanity check on the loaded model: x_initial=2, y_initial=1, 12 in-context examples, 10 prompts.
test_validation_metric('cpu', model, 2, 1, 12, 10)  # recorded result: 5.7

tensor(5.7000, dtype=torch.float64)

In [11]:
# Display the CIRCUIT mapping imported from easy_transformer.ioi_circuit_extraction
# for reference (role name -> list of tuples).
# NOTE(review): the tuples look like (layer, head) pairs - confirm against the library.
CIRCUIT

{'name mover': [(9, 9),
  (10, 0),
  (9, 6),
  (10, 10),
  (10, 6),
  (10, 2),
  (10, 1),
  (11, 2),
  (9, 7),
  (9, 0),
  (11, 9)],
 'negative': [(10, 7), (11, 10)],
 's2 inhibition': [(7, 3), (7, 9), (8, 6), (8, 10)],
 'induction': [(5, 5), (5, 8), (5, 9), (6, 9)],
 'duplicate token': [(0, 1), (0, 10), (3, 0)],
 'previous token': [(2, 2), (4, 11)]}

In [47]:
# Candidate circuit for the ICL task, laid out like CIRCUIT: {role name: [tuple, ...]}.
# NOTE(review): entries are presumably (layer, head) pairs; how this list was
# selected is not shown in this part of the notebook - document the provenance.
ICL_CIRCUIT = {
  'operating': [(8, 1), (6, 9), (5, 0), (8, 9), (9, 11), (9, 2), (7, 7), (5, 2), (4, 8), (5, 1), (3, 4), (4, 3), (9, 9), (8, 11), (6, 10), (8, 8), (6, 0), (9, 5), (6, 3), (8, 6), (6, 7), (6, 6), (7, 6), (6, 2), (7, 10), (9, 1), (1, 9), (10, 2), (5, 11), (8, 7), (0, 1), (0, 3), (0, 5), (7, 11), (7, 2), (6, 1), (0, 8), (0, 7), (0, 6), (0, 4), (5, 5)]
}

In [None]:
def do_circuit_extraction(
    heads_to_remove=None,  # {(2,3) : List[List[int]]: dimensions dataset_size * datapoint_length
    mlps_to_remove=None,  # {2: List[List[int]]: dimensions dataset_size * datapoint_length
    heads_to_keep=None,  # as above for heads
    mlps_to_keep=None,  # as above for mlps
    ioi_dataset=None,
    mean_dataset=None,
    model=None,
    metric=None,
    excluded=[],  # tuple of (layer, head) or (layer, None for MLPs)
    return_hooks=False,
    hooks_dict=False,
):
    """Build the ablation hooks that isolate a circuit and optionally attach them.

    ..._to_remove means the indices ablated away. Otherwise (..._to_keep) the
    indices not ablated away.

    `excluded` lists components that get no hook at all and are left as is:
    `(layer, head)` for attention heads, `(layer, None)` for MLPs.

    If `mean_dataset` is None, `ioi_dataset` is used for the means.

    Returns:
        * `return_hooks=True, hooks_dict=True`: dict mapping `(layer, head)`
          (head is None for MLPs) to the value of `abl.get_hook(...)`.
        * `return_hooks=True, hooks_dict=False`: list of those hook values.
        * otherwise: `(model, abl)` after the model's hooks were reset and
          every hook was added with `model.add_hook(*hook)`.

    NOTE(review): this cell redefines the `do_circuit_extraction` imported from
    `easy_transformer.ioi_circuit_extraction` at the top of the notebook, and
    `get_circuit_replacement_hook` is not imported in the visible cells - it
    must be defined or imported before this function is called.
    """

    # check if we are either in keep XOR remove mode from the args
    ablation, heads, mlps = get_circuit_replacement_hook(
        heads_to_remove=heads_to_remove,  # {(2,3) : List[List[int]]: dimensions dataset_size * datapoint_length
        mlps_to_remove=mlps_to_remove,  # {2: List[List[int]]: dimensions dataset_size * datapoint_length
        heads_to_keep=heads_to_keep,  # as above for heads
        mlps_to_keep=mlps_to_keep,  # as above for mlps
        ioi_dataset=ioi_dataset,
        model=model,
    )

    experiment_metric = ExperimentMetric(
        metric=metric, dataset=ioi_dataset.sentences, relative_metric=False
    )  # TODO make dummy metric

    if mean_dataset is None:
        mean_dataset = ioi_dataset

    config = AblationConfig(
        abl_type="custom",
        abl_fn=ablation,
        mean_dataset=mean_dataset.toks.long(),  # TODO nb of prompts useless ?
        target_module="attn_head",
        head_circuit="result",
        cache_means=True,  # circuit extraction *has* to cache means
        verbose=True,
    )
    abl = EasyAblation(
        model,
        config,
        experiment_metric,
        semantic_indices=None,  # ioi_dataset.sem_tok_idx,
        mean_by_groups=True,  # TO CHECK CIRCUIT BY GROUPS
        groups=ioi_dataset.groups,
    )
    model.reset_hooks()

    hooks = {}

    # Visit heads in lexicographic (layer, head) order so the hook order is deterministic.
    for layer, head in sorted(heads.keys(), key=lambda x: (x[0], x[1])):
        if (layer, head) in excluded:
            continue
        assert (layer, head) not in hooks, ((layer, head), "already in hooks")
        hooks[(layer, head)] = abl.get_hook(layer, head)

    for layer in mlps.keys():
        # Honour `excluded` for MLPs too: they are named as (layer, None).
        # Previously this loop ignored `excluded`, so MLPs could never be skipped.
        if (layer, None) in excluded:
            continue
        hooks[(layer, None)] = abl.get_hook(layer, head=None, target_module="mlp")

    if return_hooks:
        if hooks_dict:
            return hooks
        return list(hooks.values())

    for hook in hooks.values():
        model.add_hook(*hook)
    return model, abl