# Experiments from the 'Emerging Structures in Computational Graphs of Neural Networks' project

-----
## Imports and Setups
-----

##### Set up the environment for remote notebook execution & check node configuration

In [None]:
import os
import sys

# Access token for Gemma; uncomment and fill in if a cell needs `access_token`.
# access_token = "Your Access Token for Gemma"

# Put the project directory on the import path (presumably so that the project
# packages imported further down, e.g. `connectivity`, `evaluation`, resolve).
# /!\ Comment this out if "." is not the home directory on Goethe's cluster;
#     you can check with print(os.getcwd()).
sys.path.append("./NetworkStructures/")

# Report which node the notebook runs on and from which working directory.
hostname = os.environ.get("HOSTNAME")
print(hostname)
print(os.getcwd())

##### Import necessary libraries

In [4]:
import math

import torch
from tqdm import tqdm
from transformers import logging

# Only let transformers report errors; its warnings and info messages are silenced.
logging.set_verbosity_error()

# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

##### Import project's modules

In [None]:
from nnsight.models.UnifiedTransformer import UnifiedTransformer
from connectivity.effective import get_circuit_feature
from evaluation.faithfulness import faithfulness as faithfulness_fn
from data.buffer import unpack_batch, ioi_buffer, simple_rc_buffer, rc_buffer, single_input_buffer

from utils.ablation_fns import id_ablation
from utils.plotting import plot_faithfulness
from utils.metric_fns import metric_fn_logit, metric_fn_KL, metric_fn_statistical_distance
from utils.experiments_setup import load_model_and_modules, load_saes, get_architectural_graph
from utils.activation import get_hidden_states, get_is_tuple

-----

# Sanity Checks

-----

### Check whether the multi-GPU setup works

In [None]:
# TODO

### Check if the SAEs are used properly

Check the sparsity of the SAE activations (L0/L1) and the variance explained by each SAE.

In [None]:
# TODO

### Check whether the studied models can solve the tasks

If a model cannot solve a task, there is no point in trying to find out how it solves it.

In [None]:
# Sanity check (commented out): for each model, run every task buffer one example
# at a time and report the mean of `metric_fn_logit` ("mean logit") and the
# fraction of examples where it is positive ("accuracy").
# NOTE(review): this cell would not run as-is if uncommented:
#   - `access_token` is only defined in a commented-out line of the setup cell;
#   - `gp_buffer` and `gt_buffer` are not among the names imported from
#     `data.buffer` in the project-imports cell.
# for model_name in ["gemma-2-2b", "pythia-70m-deduped", "gpt2"]:
#     print("##########")
#     print(model_name)
#     print("##########")
#     # model, name2mod = load_model_and_modules(device=DEVICE, model_name=model_name, resid=use_resid, attn=use_attn_mlp, mlp=use_attn_mlp, start_at_layer=start_at_layer)
#     model = UnifiedTransformer(
#         model_name,
#         device=DEVICE,
#         use_auth_token=access_token,        
#     )
#     with torch.no_grad():
#         model.device = model.cfg.device
#         model.tokenizer.padding_side = 'left'
#         for buffer_fn in [ioi_buffer, rc_buffer, simple_rc_buffer, gp_buffer, gt_buffer]:
#             perm=torch.randperm(400)
#             buffer = buffer_fn(model, 1, DEVICE, perm=perm)
#             all_metrics = []
#             c = 0
#             for batch in tqdm(buffer):
#                 tokens, trg_idx, trg, corr, corr_trg = unpack_batch(batch)
#                 c += 1
#                 with model.trace(tokens):
#                     metric_kwargs = {"trg_idx": trg_idx, "trg_pos": trg, "trg_neg": corr_trg}
#                     all_metrics.append(metric_fn_logit(model, metric_kwargs).save())
#             try:
#                 all_metrics = torch.stack(all_metrics)
#                 mean_logit = all_metrics.mean().item()
#                 accuracy = (all_metrics > 0).float().mean().item()
#                 print(f'Buffer {buffer_fn.__name__} done, {c} batches processed')
#                 print(f"Buffer {buffer_fn.__name__} mean logit: {mean_logit} accuracy: {accuracy}")
#             except RuntimeError:
#                 print(f"Buffer {buffer_fn.__name__} failed with {c} batches processed")
#     print("\n")

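Results of the sanity-check cell above. *Mean Logit* is the mean of `metric_fn_logit` over the processed examples; *Accuracy* is the fraction of examples for which that metric is positive.

| Task            | Pythia-70m-deduped | Pythia-70m-deduped | GPT-2 | GPT-2 | Gemma-2-2b | Gemma-2-2b |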
|-------------------|--------------------------|------------------------|-------------|-----------|------------------|----------------|
|                   | Mean Logit | Accuracy     | Mean Logit | Accuracy   | Mean Logit | Accuracy  |
| IOI        | -0.063     | 0.0          | 1.188      | 1.0        | 14.1      | 1.0       |
| SV-agreement         | 1.940      | 0.995        | -          | -          | 10.3      | 0.99      |
| Simple SV-agreement  | 3.991      | 1.0          | 4.530      | 1.0        | 16.5      | 1.0       |
| Gender-Pronoun         | 0.978      | 0.755        | 2.842      | 0.907      | 12.9      | 0.90      |
| Greater Than         | 2.911      | 0.817        | 2.951      | 1.0        | NaN       | NaN       |


-----

# Circuit Discovery

-----

##### Define the model to be dissected

In [None]:
# Model under study and which of its submodules enter the graph. The flags are
# forwarded as `resid`, `attn` and `mlp` to load_model_and_modules (presumably
# selecting which submodule types are included); attention and MLP share one switch.
model_name = "gemma-2-2b"
use_resid = True
use_attn_mlp = False
start_at_layer = 2

model, name2mod = load_model_and_modules(
    device=DEVICE,
    model_name=model_name,
    resid=use_resid,
    attn=use_attn_mlp,
    mlp=use_attn_mlp,
    start_at_layer=start_at_layer,
)

# Connectivity between the selected submodules, keyed by submodule name
# (assumed from how the archived cells index it together with name2mod).
architectural_graph = get_architectural_graph(model, name2mod)

# SAE dictionaries for the selected submodules.
dictionaries = load_saes(model, name2mod)

print(architectural_graph)

##### Define a task to be solved

In [None]:
# Draw one batch from the simple relative-clause buffer (second argument is
# presumably the batch size) and split it into its parts.
buffer = simple_rc_buffer(model, 1, DEVICE, ctx_len=None, perm=None)
batch = next(buffer)
tokens, trg_idx, trg, corr, corr_trg = unpack_batch(batch)

# The clean run uses the task tokens; the patch run uses the corrupted tokens.
clean, patch = tokens, corr

# Metrics available for evaluation, keyed by display name.
metric_fn_dict = {
    'logit': metric_fn_logit,
    'KL': metric_fn_KL,
    'Statistical Distance': metric_fn_statistical_distance,
}
# Metric that drives circuit discovery.
metric_fn = metric_fn_dict['logit']

# Extra arguments forwarded to the metric functions.
metric_kwargs = dict(trg_idx=trg_idx, trg_pos=trg, trg_neg=corr_trg)

# Circuit-discovery settings.
steps = 10              # forwarded as `steps` to get_circuit_feature
edge_threshold = 1e-4   # discovery threshold; also the lower end of the evaluation sweep
edge_circuit = True     # build an edge circuit (evaluation then uses node_ablation=False)

# Ablation strategy: its name and the matching function.
default_ablation, ablation_fn = 'id', id_ablation

##### Find the circuit that solves the task

In [None]:
# Discover the circuit for this task: compare the clean and patched runs under
# `metric_fn` and keep what passes `edge_threshold` (exact semantics are
# defined by get_circuit_feature).
circuit_kwargs = dict(
    clean=clean,
    patch=patch,
    model=model,
    architectural_graph=architectural_graph,
    name2mod=name2mod,
    dictionaries=dictionaries,
    metric_fn=metric_fn,
    metric_kwargs=metric_kwargs,
    ablation_fn=ablation_fn,
    threshold=edge_threshold,
    steps=steps,
    edge_circuit=edge_circuit,
)
edges = get_circuit_feature(**circuit_kwargs)

##### Evaluate the quality of the circuit

In [None]:
nb_eval_thresholds = 20

# Evaluation sweep: `nb_eval_thresholds` thresholds spaced log-uniformly
# (base 10) from `edge_threshold`, the threshold used for discovery, up to 1.0.
# The higher the threshold, the more edges are removed from the circuit.
# NOTE(review): no -1 entry is appended to this sweep, so no "full ablation"
# point is evaluated. If faithfulness_fn expects one, add
# `thresholds.append(-1)`, but confirm against faithfulness_fn first.
thresholds = torch.logspace(
    math.log10(edge_threshold),
    0.0,
    steps=nb_eval_thresholds,
    base=10,
).tolist()

results = faithfulness_fn(
    model,
    name2mod,
    dictionaries,
    clean,
    edges,
    architectural_graph,
    thresholds,
    metric_fn_dict,
    metric_kwargs,
    patch,
    ablation_fn,
    default_ablation=default_ablation,
    # Node-level ablation only when discovery produced a node (not edge) circuit.
    node_ablation=(not edge_circuit),
)

##### Show the results

In [None]:
# Plot the faithfulness results from the evaluation cell (save_path=None: presumably shown only, not written to disk).
plot_faithfulness(results, save_path=None)

In [None]:
# Print the raw faithfulness results for inspection.
print(results)

In [None]:
# Deliberate stop: halts "Run All" here so execution does not continue into the archived cells below.
raise ValueError("Stop here.")

# Archived cells

Cells kept from various earlier tests, for reference; their code is commented out.

In [4]:
# Check for Variance Explained, or whether the hidden states and the use of SAEs are correct
# (Archived; code is commented out.) Loads gpt2 with residual-stream SAEs, collects
# every submodule reachable upstream of 'y' in the architectural graph, and prints
# per-submodule SAE statistics on one batch.

# use_attn_mlp = False
# use_resid = True
# start_at_layer = 2
# model_name = "gpt2"
# model, name2mod = load_model_and_modules(device=DEVICE, model_name=model_name, resid=use_resid, attn=use_attn_mlp, mlp=use_attn_mlp, start_at_layer=start_at_layer)
# architectural_graph = get_architectural_graph(model, name2mod)
# dictionaries = load_saes(model, name2mod)

# buffer = simple_rc_buffer(model, 1, DEVICE)
# batch = next(buffer)
# tokens, trg_idx, trg, corr, corr_trg = unpack_batch(batch)

# # Depth-first walk from 'y' through the architectural graph to collect all submodules.
# visited = set()
# to_visit = ['y']
# while to_visit:
#     downstream = to_visit.pop()
#     if downstream in visited:
#         continue
#     visited.add(downstream)
#     to_visit += architectural_graph[downstream]

# all_submods = list(visited)
# all_submods.remove('y')
# all_submods = [name2mod[name] for name in all_submods]

# is_tuple = get_is_tuple(model, all_submods)

# hidden_states_clean = get_hidden_states(
#     model, submods=all_submods, dictionaries=dictionaries, is_tuple=is_tuple, input=tokens
# )

# for k in hidden_states_clean:
#     print(k, hidden_states_clean[k].act.shape, hidden_states_clean[k].res.shape) # should be (b, s, d_dict) and (b, s, d_model) respectively
#     print(f"L_0 : {(hidden_states_clean[k].act > 0.0).sum(dim=-1).float().mean()}")
#     print(f"Error norm : {hidden_states_clean[k].res.norm(dim=-1).mean()}")
#     # NOTE(review): decode(act) + res equals the original activation only if res is
#     # defined as x - decode(act); assumed here, confirm in get_hidden_states.
#     reconstructed = dictionaries[k].decode(hidden_states_clean[k].act) + hidden_states_clean[k].res
#     print(f"Original norm : {reconstructed.norm(dim=-1).mean()}")
#     # NOTE(review): this is 1 - mean(||res|| / ||x||), one minus a ratio of norms. It is
#     # not the usual fraction of variance explained, which uses squared norms/variances
#     # (e.g. 1 - ||res||^2 / ||x - mean(x)||^2). Interpret the printed value accordingly.
#     print(f"Variance Explained : {1 - ((hidden_states_clean[k].res).norm(dim=-1) / reconstructed.norm(dim=-1)).mean()}")
#     print("\n")

In [5]:
# TODO : check how to deal with multi gpu

# from nnsight.models.UnifiedTransformer import UnifiedTransformer

# model_name = "pythia-70m-deduped"
# device = 'auto'

# model = UnifiedTransformer(
#         model_name,
#         device_map=device,
#         processing=False,
#         n_devices=8
#     )

# print(model.cfg)