In [3]:
"""
Get the general graph of a given model

Algorithm :
Get model & dicts
Get circuit fct
aggregate graphs over dataset
test this graph on dataset

TODO : get circuit fct is wrong from marks
TODO : test graph is wrong from marks (but eh, it will yield better results as we will
       have essentially the whole graph, but hush, don't say it ! :o)
"""

try:
    import google.colab
    IN_COLAB = True
    from tqdm.notebook import tqdm, trange

    from google.colab import drive
    drive.mount("/content/gdrive", force_remount=True)
    %cd /content/gdrive/MyDrive/feature-circuits
    %pip install -r requirements.txt
    !git submodule update --init
except:
    IN_COLAB = False
    from tqdm import tqdm, trange

import os

from transformers import logging
logging.set_verbosity_error()

import torch
from nnsight import LanguageModel
from datasets import load_dataset

from dictionary_learning import AutoEncoder
from activation_utils import SparseAct
from buffer import TokenBuffer
from circuit import get_circuit
from ablation import run_with_ablations

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

print("DEVICE :", DEVICE)
print("IN_COLAB :", IN_COLAB)

DEVICE : cpu
IN_COLAB : False


First step : generate the graph

Compute the circuit for random examples from Wikipedia, then aggregate the results.

TODO: move `TokenBuffer` into a separate file.

In [4]:
# Wrap Pythia-70m (deduped) in an nnsight LanguageModel placed on DEVICE.
pythia70m = LanguageModel(
    "EleutherAI/pythia-70m-deduped",
    dispatch=True,
    device_map=DEVICE,
)

pythia70m_embed = pythia70m.gpt_neox.embed_in

# One entry per transformer layer, in layer order: the layer itself, its
# attention sub-module and its MLP sub-module.
n_layers = len(pythia70m.gpt_neox.layers)
pythia70m_resids = [pythia70m.gpt_neox.layers[idx] for idx in range(n_layers)]
pythia70m_attns = [pythia70m.gpt_neox.layers[idx].attention for idx in range(n_layers)]
pythia70m_mlps = [pythia70m.gpt_neox.layers[idx].mlp for idx in range(n_layers)]

In [3]:
# Open the English Wikipedia training split as a stream.
wikipedia_stream = load_dataset(
    "wikipedia",
    split="train",
    language="en",
    date="20240401",
    streaming=True,
    trust_remote_code=True,
)

# NOTE(review): shuffle() is called without a seed, so the example order is
# presumably different on every run -- confirm whether that is intended.
dataset = iter(wikipedia_stream.shuffle())

In [5]:
# Keyword settings of the token buffer, grouped so they are easy to tune.
buffer_settings = dict(
    n_ctxs=10,
    ctx_len=128,
    load_buffer_batch_size=512,
    return_batch_size=1,
    device=DEVICE,
    max_number_of_yields=2**20,
    discard_bos=True,
)

# The buffer reads from the shuffled dataset iterator and uses the model above.
buffer = TokenBuffer(dataset, pythia70m, **buffer_settings)

In [6]:
if IN_COLAB:
    base = "/content/gdrive/MyDrive/feature-circuits/"
else:
    base = "C:/Users/Grégoire/Documents/ENS/stages/AttentionGraph/Marks/feature-circuits/"
path = base + "dictionary_learning/dictionaires/pythia-70m-deduped/"

if not os.path.exists(path):
    if IN_COLAB:
        # go to base / dictionary_learning :
        %cd /content/gdrive/MyDrive/feature-circuits/dictionary_learning
        !apt-get update
        !apt-get install dos2unix
        !dos2unix pretrained_dictionary_downloader.sh
        !chmod +x pretrained_dictionary_downloader.sh
        !./pretrained_dictionary_downloader.sh
        %cd /content/gdrive/MyDrive/feature-circuits
    else:
        %cd C:/Users/Grégoire/Documents/ENS/stages/AttentionGraph/Marks/feature-circuits/dictionary_learning
        %run ./pretrained_dictionary_downloader.sh
        %cd C:/Users/Grégoire/Documents/ENS/stages/AttentionGraph/Marks/feature-circuits

dictionaries = {}

d_model = 512
dict_size = 32768

ae = AutoEncoder(d_model, dict_size).to(DEVICE)
ae.load_state_dict(torch.load(path + f"embed/ae.pt", map_location=DEVICE))
dictionaries[pythia70m_embed] = ae


for layer in range(len(pythia70m.gpt_neox.layers)):
    ae = AutoEncoder(d_model, dict_size).to(DEVICE)
    ae.load_state_dict(torch.load(path + f"resid_out_layer{layer}/ae.pt", map_location=DEVICE))
    dictionaries[pythia70m_resids[layer]] = ae

    ae = AutoEncoder(d_model, dict_size).to(DEVICE)
    ae.load_state_dict(torch.load(path + f"attn_out_layer{layer}/ae.pt", map_location=DEVICE))
    dictionaries[pythia70m_attns[layer]] = ae

    ae = AutoEncoder(d_model, dict_size).to(DEVICE)
    ae.load_state_dict(torch.load(path + f"mlp_out_layer{layer}/ae.pt", map_location=DEVICE))
    dictionaries[pythia70m_mlps[layer]] = ae

In [7]:
def metric_fn_v1(model, trg=None):
    """
    Logit of the target token, read at the last sequence position.

    model : object exposing `embed_out.output`, indexed here as a rank-3 tensor
            (assumed to be [batch, position, vocab])
    trg : torch.Tensor with one target token id per batch row

    Returns one logit per batch row; raises ValueError when `trg` is not given.
    """
    if trg is None:
        raise ValueError("trg must be provided")
    final_position = model.embed_out.output[:, -1, :]
    batch_rows = torch.arange(trg.numel())
    # pair row i with token trg[i]
    return final_position[batch_rows, trg]
    
def metric_fn_v2(model, trg=None):
    """
    Logit of the target token, read at an explicit sequence position.

    model : object exposing `embed_out.output`, indexed here as a rank-3 tensor
            (assumed to be [batch, position, vocab])
    trg : pair `(positions, tokens)`; `positions[i]` is the sequence position to
          read for batch row i and `tokens[i]` the target token id at that position

    Returns a 1-D tensor of logits, one per entry of `positions`. With a single
    position/token pair this is the value for the first batch row. With one pair
    per batch row, each row is read at its own position.

    Raises ValueError when `trg` is not given.
    """
    if trg is None:
        raise ValueError("trg must be provided")
    positions = torch.as_tensor(trg[0]).reshape(-1)
    tokens = torch.as_tensor(trg[1]).reshape(-1)
    # Row i is paired with positions[i] / tokens[i]; a hard-coded [0, 0, ...]
    # index would only ever read the first batch row.
    batch_rows = torch.arange(positions.numel(), device=positions.device)
    return model.embed_out.output[batch_rows, positions, tokens]

def metric_fn_v3(model, trg=None):
    """
    Return -log probability for the expected target.

    model : object exposing `embed_out.output`, indexed here as a rank-3 tensor
            (assumed to be [batch, position, vocab])
    trg : torch.Tensor, contains idxs of the target tokens (between 0 and d_vocab_out)

    WARNING: the prediction is always read at the last position, so any padding
    must be placed in front of the sequence, not after it.

    Returns one value per batch row; raises ValueError when `trg` is not given.
    """
    if trg is None:
        raise ValueError("trg must be provided")
    last_logits = model.embed_out.output[:, -1, :]
    log_probs = torch.nn.functional.log_softmax(last_logits, dim=-1)
    # pick, for each row, the log-probability of its target token
    picked = torch.gather(log_probs, dim=-1, index=trg.view(-1, 1)).squeeze(-1)
    return -1 * picked

TODO: if multiple GPUs are available, use `nn.DataParallel` and process batches of size `num_gpus`, so that each GPU handles one input. `DistributedDataParallel` may be a better fit.

Alternatively: launch N instances of the code that work independently on random inputs, each on its own GPU, and save the circuits to a file; process 0 is then in charge of aggregating the results. If torch provides inter-process communication, this can be done without writing to disk. Process 0 then sends the final circuit to all the other processes, which each test it, and the test results are aggregated.

In [10]:
# Sum the circuits of up to `max_loop` buffer samples into `tot_circuit`.
# circuit[0] is iterated as a flat dict and circuit[1] as a dict of dicts
# (presumably nodes and upstream -> downstream edges -- confirm in get_circuit).
tot_circuit = None
i = 0
max_loop = 10
for tokens, trg_idx, trg in tqdm(buffer):
    if i >= max_loop:
        break
    print(i)
    print(tokens.shape, tokens.dtype)
    print(trg_idx.shape, trg_idx)
    print(trg.shape, trg)
    i += 1
    circuit = get_circuit(
        tokens,
        None,
        pythia70m,
        pythia70m_embed,
        pythia70m_attns,
        pythia70m_mlps,
        pythia70m_resids,
        dictionaries,
        metric_fn_v2, {"trg": (trg_idx, trg)},
        edge_threshold=0.1
    )
    if tot_circuit is None:
        # the first circuit becomes the accumulator
        tot_circuit = circuit
        continue

    for k, v in circuit[0].items():
        if v is None:
            continue
        # An entry that was None/absent so far cannot be added to with `+=`;
        # start it from the current value instead.
        if tot_circuit[0].get(k) is None:
            tot_circuit[0][k] = v
        else:
            tot_circuit[0][k] += v
    for ku, vu in circuit[1].items():
        tot_down = tot_circuit[1].setdefault(ku, {})
        for kd, vd in vu.items():
            if vd is None:
                continue
            if tot_down.get(kd) is None:
                tot_down[kd] = vd
            else:
                tot_down[kd] += vd

0it [00:00, ?it/s]You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


0
torch.Size([1, 128]) torch.int64
torch.Size([1]) tensor([121])
torch.Size([1]) tensor([783])


1it [10:54, 654.71s/it]

1
torch.Size([1, 128]) torch.int64
torch.Size([1]) tensor([125])
torch.Size([1]) tensor([187])


2it [22:27, 677.07s/it]

2
torch.Size([1, 128]) torch.int64
torch.Size([1]) tensor([118])
torch.Size([1]) tensor([41358])


In [None]:
# Shape check: indexing the position axis with a 1-element index tensor keeps
# that axis (size 1) instead of dropping it.
t = torch.randn(1, 128, 50304)

trg_idx = torch.tensor([126], dtype=torch.long)
trg = torch.tensor([285], dtype=torch.long)
print(t.shape, t[:, trg_idx, :].shape, sep="\n")

torch.Size([1, 128, 50304])
torch.Size([1, 1, 50304])


In [None]:
def ablation_fn(x):
    """Mean ablation: take the mean of `x` over dim 0 and expand it back to the shape of `x`."""
    return x.mean(dim=0).expand_as(x)

# get m(C) for the circuit obtained by thresholding nodes with the given threshold
@torch.no_grad()
def get_fcs(
    model,
    clean,
    patch,
    circuit,
    submodules,
    dictionaries,
    ablation_fn,
    thresholds,
    handle_errors = 'default', # also 'remove' or 'resid_only'
):
    """
    Evaluate a circuit at several node thresholds.

    model : traced model; `model.trace(...)` and `model.embed_out.output` are used
    clean, patch : clean and patch inputs forwarded to `run_with_ablations`
    circuit : indexable whose element 0 holds the per-submodule node scores
    submodules : iterable of submodules to ablate
    dictionaries : forwarded unchanged to `run_with_ablations`
    ablation_fn : forwarded unchanged to `run_with_ablations`
    thresholds : iterable of thresholds; a node is kept when |score| > threshold
    handle_errors : 'default' keeps the `resc` masks as thresholded, 'remove'
        zeroes them for every submodule, 'resid_only' zeroes them for every
        submodule that is not in `model.gpt_neox.layers`

    Returns a dict with 'fm' (metric of the full model on the clean inputs),
    'fempty' (metric with all-False node masks) and, per threshold, a dict with
    'n_nodes', 'fc', 'fccomp', 'faithfulness' and 'completeness'.

    Raises ZeroDivisionError when fm == fempty (the scores are normalised by
    fm - fempty).

    NOTE(review): besides its parameters this function reads the notebook
    globals `trg_idx`, `patch_trg_idx`, `submod_names`, `dict_size` and
    `DEVICE`. `patch_trg_idx` and `submod_names` are not defined in any visible
    cell -- define them before calling, or turn them into parameters.
    """
    clean_inputs = clean
    clean_answer_idxs = trg_idx
    patch_inputs = patch
    patch_answer_idxs = patch_trg_idx

    def metric_fn(model):
        # last-position logit of the clean answer minus that of the patch answer
        return (
            - torch.gather(model.embed_out.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) + \
            torch.gather(model.embed_out.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
        )
    
    # only the node scores are needed from here on
    circuit = circuit[0]

    out = {}

    # get F(M)
    with model.trace(clean_inputs):
        metric = metric_fn(model).save()
    fm = metric.value.mean().item()

    out['fm'] = fm

    # get m(∅): every node mask is all-False
    fempty = run_with_ablations(
        clean_inputs,
        patch_inputs,
        model,
        submodules,
        dictionaries,
        nodes = {
            submod : SparseAct(
                act=torch.zeros(dict_size, dtype=torch.bool), 
                resc=torch.zeros(1, dtype=torch.bool)).to(DEVICE)
            for submod in submodules
        },
        metric_fn=metric_fn,
        ablation_fn=ablation_fn,
    ).mean().item()
    out['fempty'] = fempty

    for threshold in thresholds:
        out[threshold] = {}
        # boolean masks of the nodes whose absolute score exceeds the threshold
        nodes = {
            submod : circuit[submod_names[submod]].abs() > threshold for submod in submodules
        }

        if handle_errors == 'remove':
            for k in nodes: nodes[k].resc = torch.zeros_like(nodes[k].resc, dtype=torch.bool)
        elif handle_errors == 'resid_only':
            for k in nodes:
                if k not in model.gpt_neox.layers: nodes[k].resc = torch.zeros_like(nodes[k].resc, dtype=torch.bool)

        n_nodes = sum([n.act.sum() + n.resc.sum() for n in nodes.values()]).item()
        out[threshold]['n_nodes'] = n_nodes
        
        out[threshold]['fc'] = run_with_ablations(
            clean_inputs,
            patch_inputs,
            model,
            submodules,
            dictionaries,
            nodes=nodes,
            metric_fn=metric_fn,
            ablation_fn=ablation_fn,
        ).mean().item()
        # same run with complement=True
        out[threshold]['fccomp'] = run_with_ablations(
            clean_inputs,
            patch_inputs,
            model,
            submodules,
            dictionaries,
            nodes=nodes,
            metric_fn=metric_fn,
            ablation_fn=ablation_fn,
            complement=True
        ).mean().item()
        out[threshold]['faithfulness'] = (out[threshold]['fc'] - fempty) / (fm - fempty)
        out[threshold]['completeness'] = (out[threshold]['fccomp'] - fempty) / (fm - fempty)
    
    return out