Connected to finetuning (Python 3.10.4)

In [None]:
# Import necessary libraries and functions from helper modules
import sys
sys.path.append('../plm_circuits')

# Import utility functions
from helpers.utils import (
    clear_memory,
    load_esm,
    load_sae_prot,
    mask_flanks_segment,
    patching_metric,
    cleanup_cuda
)

# Import attribution functions
from attribution import (
    integrated_gradients_sae,
    topk_sae_err_pt
)

# Import hook classes
from hook_manager import SAEHookProt

# Additional imports
import json
from functools import partial
import torch
import numpy as np
import matplotlib.pyplot as plt
import collections

In IPython
Set autoreload
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In IPython
Set autoreload


In [None]:
# Setup device and load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load ESM-2 model
esm_transformer, batch_converter, esm2_alphabet = load_esm(33, device=device)

# Load SAEs for multiple layers
main_layers = [4, 8, 12, 16, 20, 24, 28]
saes = []
for layer in main_layers:
    sae_model = load_sae_prot(ESM_DIM=1280, SAE_DIM=4096, LAYER=layer, device=device)
    saes.append(sae_model)

layer_2_saelayer = {layer: layer_idx for layer_idx, layer in enumerate(main_layers)}

Using device: cuda


In [None]:
# Load sequence data and define protein parameters
with open('../data/full_seq_dict.json', "r") as json_file:
    seq_dict = json.load(json_file)

# Define protein-specific parameters
sse_dict = {"2B61A": [[182, 316]], "1PVGA": [[101, 202]]}
fl_dict = {"2B61A": [44, 43], "1PVGA": [65, 63]}

# Choose protein for analysis
protein = "2B61A"
seq = seq_dict[protein]
position = sse_dict[protein][0]

# Define segment boundaries
ss1_start = position[0] - 5 
ss1_end = position[0] + 5 + 1 
ss2_start = position[1] - 5 
ss2_end = position[1] + 5 + 1 

print(f"Analyzing protein: {protein}")
print(f"Sequence length: {len(seq)}")
print(f"Segment 1: {ss1_start}-{ss1_end}")
print(f"Segment 2: {ss2_start}-{ss2_end}")

# Prepare full sequence and get baseline contact predictions
full_seq_L = [(1, seq)]
_, _, batch_tokens_BL = batch_converter(full_seq_L)
batch_tokens_BL = batch_tokens_BL.to(device)
batch_mask_BL = (batch_tokens_BL != esm2_alphabet.padding_idx).to(device)

with torch.no_grad():
    full_seq_contact_LL = esm_transformer.predict_contacts(batch_tokens_BL, batch_mask_BL)[0]

# Prepare clean sequence (with optimal flanks)
clean_fl = fl_dict[protein][0]
L = len(seq)
left_start = max(0, ss1_start - clean_fl)
left_end = ss1_start
right_start = ss2_end
right_end = min(L, ss2_end + clean_fl)
unmask_left_idxs = list(range(left_start, left_end))
unmask_right_idxs = list(range(right_start, right_end))

clean_seq_L = mask_flanks_segment(seq, ss1_start, ss1_end, ss2_start, ss2_end, unmask_left_idxs, unmask_right_idxs)
_, _, clean_batch_tokens_BL = batch_converter([(1, clean_seq_L)])
clean_batch_tokens_BL = clean_batch_tokens_BL.to(device)
clean_batch_mask_BL = (clean_batch_tokens_BL != esm2_alphabet.padding_idx).to(device)

with torch.no_grad():
    clean_seq_contact_LL = esm_transformer.predict_contacts(clean_batch_tokens_BL, clean_batch_mask_BL)[0]

print(f"Clean flank size: {clean_fl}")
print(f"Clean sequence contact recovery: {patching_metric(clean_seq_contact_LL, full_seq_contact_LL, ss1_start, ss1_end, ss2_start, ss2_end):.4f}")

# Prepare corrupted sequence (with suboptimal flanks)
corr_fl = fl_dict[protein][1]
left_start = max(0, ss1_start - corr_fl)
left_end = ss1_start
right_start = ss2_end
right_end = min(L, ss2_end + corr_fl)
unmask_left_idxs = list(range(left_start, left_end))
unmask_right_idxs = list(range(right_start, right_end))

corr_seq_L = mask_flanks_segment(seq, ss1_start, ss1_end, ss2_start, ss2_end, unmask_left_idxs, unmask_right_idxs)
_, _, corr_batch_tokens_BL = batch_converter([(1, corr_seq_L)])
corr_batch_tokens_BL = corr_batch_tokens_BL.to(device)
corr_batch_mask_BL = (corr_batch_tokens_BL != esm2_alphabet.padding_idx).to(device)

with torch.no_grad():
    corr_seq_contact_LL = esm_transformer.predict_contacts(corr_batch_tokens_BL, corr_batch_mask_BL)[0]

print(f"Corrupted flank size: {corr_fl}")
print(f"Corrupted sequence contact recovery: {patching_metric(corr_seq_contact_LL, full_seq_contact_LL, ss1_start, ss1_end, ss2_start, ss2_end):.4f}")

# Create patching metric function
_patching_metric = partial(
    patching_metric,
    orig_contact=full_seq_contact_LL,
    ss1_start=ss1_start,
    ss1_end=ss1_end,
    ss2_start=ss2_start,
    ss2_end=ss2_end,
)

Analyzing protein: 2B61A
Sequence length: 377
Segment 1: 177-188
Segment 2: 311-322
Clean flank size: 44
Clean sequence contact recovery: 0.5738
Corrupted flank size: 43
Corrupted sequence contact recovery: 0.0279


In [None]:
# Perform causal ranking for all latent-token pairs across layers
print("Starting causal ranking with integrated gradients...")

all_effects_sae_ALS = []
all_effects_err_ABLF = []

for layer_idx in main_layers:
    print(f"\nProcessing layer {layer_idx}...")
    
    sae_model = saes[layer_2_saelayer[layer_idx]]

    # Get clean cache and error
    hook = SAEHookProt(sae=sae_model, mask_BL=clean_batch_mask_BL, cache_latents=True, layer_is_lm=False, calc_error=True, use_error=True)
    handle = esm_transformer.esm.encoder.layer[layer_idx].register_forward_hook(hook)
    with torch.no_grad():
        clean_seq_sae_contact_LL = esm_transformer.predict_contacts(clean_batch_tokens_BL, clean_batch_mask_BL)[0]
    cleanup_cuda()
    handle.remove()
    clean_cache_LS = sae_model.feature_acts
    clean_err_cache_BLF = sae_model.error_term
    clean_contact_recovery = _patching_metric(clean_seq_sae_contact_LL)

    # Get corrupted cache and error
    hook = SAEHookProt(sae=sae_model, mask_BL=corr_batch_mask_BL, cache_latents=True, layer_is_lm=False, calc_error=True, use_error=True)
    handle = esm_transformer.esm.encoder.layer[layer_idx].register_forward_hook(hook)
    with torch.no_grad():
        corr_seq_sae_contact_LL = esm_transformer.predict_contacts(corr_batch_tokens_BL, corr_batch_mask_BL)[0]
    cleanup_cuda()
    handle.remove()
    corr_cache_LS = sae_model.feature_acts
    corr_err_cache_BLF = sae_model.error_term
    corr_contact_recovery = _patching_metric(corr_seq_sae_contact_LL)
    
    print(f"Layer {layer_idx}: Clean contact recovery: {clean_contact_recovery:.4f}, Corr contact recovery: {corr_contact_recovery:.4f}")

    # Run integrated gradients
    effect_sae_LS, effect_err_BLF = integrated_gradients_sae(
        esm_transformer,
        sae_model,
        _patching_metric,
        clean_cache_LS.to(device),
        corr_cache_LS.to(device),
        clean_err_cache_BLF.to(device),
        corr_err_cache_BLF.to(device),
        batch_tokens=clean_batch_tokens_BL,
        batch_mask=clean_batch_mask_BL,
        hook_layer=layer_idx,
    )

    all_effects_sae_ALS.append(effect_sae_LS)
    all_effects_err_ABLF.append(effect_err_BLF)

# Stack all effects
all_effects_sae_ALS = torch.stack(all_effects_sae_ALS)
all_effects_err_ABLF = torch.stack(all_effects_err_ABLF)

print(f"\nCausal ranking complete!")
print(f"SAE effects shape: {all_effects_sae_ALS.shape}")
print(f"Error effects shape: {all_effects_err_ABLF.shape}")

Starting causal ranking with integrated gradients...

Processing layer 4...
Layer 4: Clean contact recovery: 0.5738, Corr contact recovery: 0.0279
ratio: 0.0, score: 0.5737996101379395
ratio: 0.1, score: 0.601131796836853
ratio: 0.2, score: 0.4815935492515564
ratio: 0.30000000000000004, score: 0.4985087811946869
ratio: 0.4, score: 0.3368060886859894
ratio: 0.5, score: 0.2868437170982361
ratio: 0.6000000000000001, score: 0.24697260558605194
ratio: 0.7000000000000001, score: 0.12411289662122726
ratio: 0.8, score: 0.10038772225379944
ratio: 0.9, score: 0.033265773206949234

Processing layer 8...
Layer 8: Clean contact recovery: 0.5738, Corr contact recovery: 0.0279
ratio: 0.0, score: 0.5737998485565186
ratio: 0.1, score: 0.5733504891395569
ratio: 0.2, score: 0.49743935465812683
ratio: 0.30000000000000004, score: 0.4819444417953491
ratio: 0.4, score: 0.4560147225856781
ratio: 0.5, score: 0.3371899425983429
ratio: 0.6000000000000001, score: 0.24517741799354553
ratio: 0.7000000000000001, sco

In [None]:
print("Creating layer-wise caches for performance analysis...")

clean_layer_caches = {}
corr_layer_caches = {}
clean_layer_errors = {}
corr_layer_errors = {}

for layer_idx in main_layers:
    sae_model = saes[layer_2_saelayer[layer_idx]]
    
    # Clean caches
    hook = SAEHookProt(sae=sae_model, mask_BL=clean_batch_mask_BL, cache_latents=True, 
                       layer_is_lm=False, calc_error=True, use_error=True)
    handle = esm_transformer.esm.encoder.layer[layer_idx].register_forward_hook(hook)
    with torch.no_grad():
        clean_seq_sae_contact_LL = esm_transformer.predict_contacts(clean_batch_tokens_BL, clean_batch_mask_BL)[0]
    cleanup_cuda()
    handle.remove()
    print(f"Layer {layer_idx}, clean score: {_patching_metric(clean_seq_sae_contact_LL):.4f}")
    clean_layer_caches[layer_idx] = sae_model.feature_acts
    clean_layer_errors[layer_idx] = sae_model.error_term
    # print shapes
    print(clean_layer_caches[layer_idx].shape, clean_layer_errors[layer_idx].shape)

    # Corrupted caches
    hook = SAEHookProt(sae=sae_model, mask_BL=corr_batch_mask_BL, cache_latents=True, 
                       layer_is_lm=False, calc_error=True, use_error=True)
    handle = esm_transformer.esm.encoder.layer[layer_idx].register_forward_hook(hook)
    with torch.no_grad():
        corr_seq_sae_contact_LL = esm_transformer.predict_contacts(corr_batch_tokens_BL, corr_batch_mask_BL)[0]
    cleanup_cuda()
    handle.remove()
    print(f"Layer {layer_idx}, corr score: {_patching_metric(corr_seq_sae_contact_LL):.4f}")
    corr_layer_caches[layer_idx] = sae_model.feature_acts
    corr_layer_errors[layer_idx] = sae_model.error_term

print("Layer-wise caches created successfully!")

Creating layer-wise caches for performance analysis...
Layer 4, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 4, corr score: 0.0279
Layer 8, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 8, corr score: 0.0279
Layer 12, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 12, corr score: 0.0279
Layer 16, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 16, corr score: 0.0279
Layer 20, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 20, corr score: 0.0279
Layer 24, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 24, corr score: 0.0279
Layer 28, clean score: 0.5738
torch.Size([379, 4096]) torch.Size([1, 379, 1280])
Layer 28, corr score: 0.0279
Layer-wise caches created successfully!


In [None]:
up_layer = 4
down_layer = 8
print(layer_2_saelayer[up_layer], layer_2_saelayer[down_layer])
up_sae = saes[layer_2_saelayer[up_layer]]
down_sae = saes[layer_2_saelayer[down_layer]]

up_effects = all_effects_sae_ALS[layer_2_saelayer[up_layer]]
down_effects = all_effects_sae_ALS[layer_2_saelayer[down_layer]]

0 1


In [None]:
# getting the top 50 features for each layer

up_effect_flat = up_effects.reshape(-1)
down_effect_flat = down_effects.reshape(-1)

top_rank_vals_up, top_idx_up = torch.topk(up_effect_flat, k=50, largest=False, sorted=True)
top_rank_vals_down, top_idx_down = torch.topk(down_effect_flat, k=50, largest=False, sorted=True)

row_indices_up = top_idx_up // S
col_indices_up = top_idx_up % S

row_indices_down = top_idx_down // S
col_indices_down = top_idx_down % S

# print the top 5 for up and down
print(top_rank_vals_up[:5], top_idx_up[:5])
print(top_rank_vals_down[:5], top_idx_down[:5])

NameError: name 'S' is not defined

In [None]:
# getting the top 50 features for each layer

up_effect_flat = up_effects.reshape(-1)
down_effect_flat = down_effects.reshape(-1)

top_rank_vals_up, top_idx_up = torch.topk(up_effect_flat, k=50, largest=False, sorted=True)
top_rank_vals_down, top_idx_down = torch.topk(down_effect_flat, k=50, largest=False, sorted=True)

L, S = up_effects.shape

row_indices_up = top_idx_up // S
col_indices_up = top_idx_up % S

row_indices_down = top_idx_down // S
col_indices_down = top_idx_down % S

# print the top 5 for up and down
print(top_rank_vals_up[:5], top_idx_up[:5])
print(top_rank_vals_down[:5], top_idx_down[:5])

tensor([-0.0259, -0.0236, -0.0187, -0.0187, -0.0174]) tensor([1460619, 1461827, 1459139, 1503572, 1503469])
tensor([-0.0173, -0.0173, -0.0160, -0.0133, -0.0133]) tensor([1294824,  748149, 1297111, 1296822,  748856])


In [None]:
# getting the top 50 features for each layer

up_effect_flat = up_effects.reshape(-1)
down_effect_flat = down_effects.reshape(-1)

top_rank_vals_up, top_idx_up = torch.topk(up_effect_flat, k=50, largest=False, sorted=True)
top_rank_vals_down, top_idx_down = torch.topk(down_effect_flat, k=50, largest=False, sorted=True)

L, S = up_effects.shape

row_indices_up = top_idx_up // S
col_indices_up = top_idx_up % S

row_indices_down = top_idx_down // S
col_indices_down = top_idx_down % S

# print the top 5 for up and down

for i in range(5):
    print(f"Up: {top_rank_vals_up[i]:.6f}, {row_indices_up[i]}, {col_indices_up[i]}")
    print(f"Down: {top_rank_vals_down[i]:.6f}, {row_indices_down[i]}, {col_indices_down[i]}")

Up: -0.025851, 356, 2443
Down: -0.017311, 316, 488
Up: -0.023650, 356, 3651
Down: -0.017272, 182, 2677
Up: -0.018743, 356, 963
Down: -0.015968, 316, 2775
Up: -0.018701, 367, 340
Down: -0.013317, 316, 2486
Up: -0.017374, 367, 237
Down: -0.013286, 182, 3384


In [None]:
# getting the top 50 features for each layer

up_effect_flat = up_effects.reshape(-1)
down_effect_flat = down_effects.reshape(-1)

top_rank_vals_up, top_idx_up = torch.topk(up_effect_flat, k=50, largest=False, sorted=True)
top_rank_vals_down, top_idx_down = torch.topk(down_effect_flat, k=50, largest=False, sorted=True)

L, S = up_effects.shape

row_indices_up = top_idx_up // S
col_indices_up = top_idx_up % S

row_indices_down = top_idx_down // S
col_indices_down = top_idx_down % S

# print the top 5 for up and down

for i in range(5):
    print(f"Up: {top_rank_vals_up[i]:.6f}, {row_indices_up[i]}, {col_indices_up[i]}, {up_effects[row_indices_up[i], col_indices_up[i]]:.6f}")

for i in range(5):
    print(f"Down: {top_rank_vals_down[i]:.6f}, {row_indices_down[i]}, {col_indices_down[i]}, {down_effects[row_indices_down[i], col_indices_down[i]]:.6f}")

Up: -0.025851, 356, 2443, -0.025851
Up: -0.023650, 356, 3651, -0.023650
Up: -0.018743, 356, 963, -0.018743
Up: -0.018701, 367, 340, -0.018701
Up: -0.017374, 367, 237, -0.017374
Down: -0.017311, 316, 488, -0.017311
Down: -0.017272, 182, 2677, -0.017272
Down: -0.015968, 316, 2775, -0.015968
Down: -0.013317, 316, 2486, -0.013317
Down: -0.013286, 182, 3384, -0.013286


In [None]:
up_feats = [col_indices_up[i] for i in range(len(col_indices_up))]
down_feats = [col_indices_down[i] for i in range(len(col_indices_down))]


print(up_feats, down_feats)

[tensor(2443), tensor(3651), tensor(963), tensor(340), tensor(237), tensor(1474), tensor(794), tensor(443), tensor(2340), tensor(3788), tensor(3701), tensor(2311), tensor(2277), tensor(3153), tensor(798), tensor(3634), tensor(1682), tensor(1690), tensor(3764), tensor(3326), tensor(1096), tensor(3351), tensor(1712), tensor(181), tensor(3177), tensor(3832), tensor(1807), tensor(3612), tensor(495), tensor(1297), tensor(1807), tensor(816), tensor(1890), tensor(1474), tensor(992), tensor(72), tensor(1370), tensor(481), tensor(1297), tensor(2956), tensor(2850), tensor(816), tensor(3343), tensor(379), tensor(2672), tensor(897), tensor(3480), tensor(1297), tensor(423), tensor(1890)] [tensor(488), tensor(2677), tensor(2775), tensor(2486), tensor(3384), tensor(2775), tensor(431), tensor(1575), tensor(2166), tensor(3921), tensor(3319), tensor(3092), tensor(3381), tensor(2693), tensor(1244), tensor(2380), tensor(1489), tensor(431), tensor(3384), tensor(3102), tensor(576), tensor(1815), tensor(2662

In [None]:
up_feats = [col_indices_up[i].item() for i in range(len(col_indices_up))]
down_feats = [col_indices_down[i].item() for i in range(len(col_indices_down))]


print(up_feats, down_feats)

[2443, 3651, 963, 340, 237, 1474, 794, 443, 2340, 3788, 3701, 2311, 2277, 3153, 798, 3634, 1682, 1690, 3764, 3326, 1096, 3351, 1712, 181, 3177, 3832, 1807, 3612, 495, 1297, 1807, 816, 1890, 1474, 992, 72, 1370, 481, 1297, 2956, 2850, 816, 3343, 379, 2672, 897, 3480, 1297, 423, 1890] [488, 2677, 2775, 2486, 3384, 2775, 431, 1575, 2166, 3921, 3319, 3092, 3381, 2693, 1244, 2380, 1489, 431, 3384, 3102, 576, 1815, 2662, 3864, 2524, 1835, 545, 3642, 4083, 1591, 2041, 2862, 3682, 3997, 2209, 1605, 1233, 3384, 1815, 3384, 2576, 1586, 4042, 2675, 3921, 3716, 312, 3642, 2594, 3368]


In [None]:
up_feats = [col_indices_up[i].item() for i in range(len(col_indices_up))]
down_feats = [col_indices_down[i].item() for i in range(len(col_indices_down))]

print(up_feats, down_feats)
print(len(up_feats), len(down_feats))

[2443, 3651, 963, 340, 237, 1474, 794, 443, 2340, 3788, 3701, 2311, 2277, 3153, 798, 3634, 1682, 1690, 3764, 3326, 1096, 3351, 1712, 181, 3177, 3832, 1807, 3612, 495, 1297, 1807, 816, 1890, 1474, 992, 72, 1370, 481, 1297, 2956, 2850, 816, 3343, 379, 2672, 897, 3480, 1297, 423, 1890] [488, 2677, 2775, 2486, 3384, 2775, 431, 1575, 2166, 3921, 3319, 3092, 3381, 2693, 1244, 2380, 1489, 431, 3384, 3102, 576, 1815, 2662, 3864, 2524, 1835, 545, 3642, 4083, 1591, 2041, 2862, 3682, 3997, 2209, 1605, 1233, 3384, 1815, 3384, 2576, 1586, 4042, 2675, 3921, 3716, 312, 3642, 2594, 3368]
50 50


In [None]:
up_feats = set(col_indices_up)
down_feats = set(col_indices_down)
# up_feats = [col_indices_up[i].item() for i in range(len(col_indices_up))]
# down_feats = [col_indices_down[i].item() for i in range(len(col_indices_down))]

print(up_feats, down_feats)
print(len(up_feats), len(down_feats))

{tensor(237), tensor(963), tensor(340), tensor(2340), tensor(3326), tensor(3177), tensor(481), tensor(3343), tensor(379), tensor(1712), tensor(2850), tensor(1890), tensor(3788), tensor(2311), tensor(1807), tensor(3480), tensor(3764), tensor(1096), tensor(3351), tensor(181), tensor(816), tensor(1370), tensor(816), tensor(3651), tensor(3701), tensor(1682), tensor(3612), tensor(72), tensor(443), tensor(1690), tensor(1297), tensor(992), tensor(1297), tensor(897), tensor(423), tensor(3634), tensor(3832), tensor(495), tensor(1474), tensor(2956), tensor(2672), tensor(2443), tensor(1474), tensor(794), tensor(2277), tensor(3153), tensor(798), tensor(1807), tensor(1890), tensor(1297)} {tensor(2677), tensor(431), tensor(2693), tensor(2662), tensor(2524), tensor(3997), tensor(3921), tensor(3092), tensor(1815), tensor(2041), tensor(1815), tensor(3368), tensor(488), tensor(2486), tensor(431), tensor(3384), tensor(1835), tensor(1605), tensor(1586), tensor(312), tensor(2775), tensor(3921), tensor(1244

In [None]:
up_feats = set([col_indices_up[i].item() for i in range(len(col_indices_up))])
down_feats = set([col_indices_down[i].item() for i in range(len(col_indices_down))])
# up_feats = [col_indices_up[i].item() for i in range(len(col_indices_up))]
# down_feats = [col_indices_down[i].item() for i in range(len(col_indices_down))]

print(len(up_feats), len(down_feats))
print(up_feats, down_feats)

44 42
{897, 2311, 2443, 2956, 1807, 3343, 1297, 1682, 3351, 3480, 1690, 794, 3612, 798, 2850, 2340, 423, 1712, 816, 3634, 3764, 181, 443, 1474, 3651, 963, 1096, 72, 3788, 3153, 340, 1370, 992, 481, 1890, 2277, 3177, 237, 495, 2672, 3701, 3832, 379, 3326} {3716, 2693, 2576, 3092, 1815, 3864, 3997, 3102, 545, 2209, 2594, 1575, 3368, 1835, 2862, 431, 1586, 3381, 2486, 1591, 3384, 312, 3642, 576, 1605, 4042, 2380, 3921, 1489, 1233, 2775, 1244, 2524, 3682, 2662, 488, 4083, 2675, 2677, 2166, 3319, 2041}


In [None]:
clean_layer_caches

{4: tensor([[-0.1607,  0.1270,  1.4314,  ..., -0.3031, -0.1435,  0.2459],
         [ 0.0055,  0.0566,  0.4922,  ..., -0.1025, -0.1160,  0.3915],
         [ 0.0207,  0.0668,  0.3704,  ..., -0.1236, -0.1527,  0.4087],
         ...,
         [ 0.4824, -0.1890,  0.3038,  ..., -0.0845, -0.0922,  0.4022],
         [ 0.0837,  0.0077,  0.1362,  ..., -0.2109, -0.0207,  0.4601],
         [-0.5801, -0.0170,  0.1643,  ..., -0.2089, -0.3712,  0.2665]]),
 8: tensor([[-1.0235,  0.8383, -2.1761,  ..., -2.2417, -1.9606,  0.0067],
         [-2.4729, -2.3292, -2.8839,  ..., -3.9402, -0.5738, -0.0505],
         [-2.4464, -2.3414, -2.9118,  ..., -3.6948, -0.4485, -0.1777],
         ...,
         [-1.4288, -1.3229,  0.0797,  ..., -4.1169,  0.0491,  0.5928],
         [-1.3160, -0.2827, -0.7971,  ..., -3.3593, -1.3171,  0.4319],
         [-1.0548,  1.0433, -1.2201,  ..., -1.7649, -0.6998, -0.3574]]),
 12: tensor([[-1.3224, -1.0477, -1.4013,  ..., -1.3675, -0.3735,  0.5314],
         [-2.7802, -0.9256, -1.6140

In [None]:
up_base = clean_layer_caches[up_layer].detach().clone().to(device).requires_grad_()

In [None]:
up_base = clean_layer_caches[up_layer].detach().clone().to(device).requires_grad_()
up_base_corr = corr_layer_caches[up_layer].detach().clone().to(device).requires_grad_()
patch_mask_LS = torch.ones((L, S), dtype=torch.bool, device=device)

In [None]:
def _forward_fn() -> torch.Tensor:    # returns [B, L, S_down]

    # up hook that puts patch activations 
    up_sae.mean_error = up_error
    up_hook = SAEHookProt(sae=up_sae, mask_BL=clean_batch_mask_BL, patch_mask_BLS=patch_mask_LS, patch_value=up_base, use_mean_error=True)

    # down hook that records downstream activations
    down_sae.mean_error = down_error
    down_hook = SAEHookProt(sae=down_sae, mask_BL=clean_batch_mask_BL, cache_latents=True, layer_is_lm=False, calc_error=True, use_error=True, no_detach=True)

    # register the hooks
    handle_up = esm_transformer.esm.encoder.layer[up_layer].register_forward_hook(up_hook)
    handle_down = esm_transformer.esm.encoder.layer[down_layer].register_forward_hook(down_hook)

    # run the forward pass
    # _, saes_out = run_with_saes( # TODO add the hook for each 1. intervening on upstream, 2. recording downstream
    #     model,
    #     base_saes,
    #     token_list,
    #     calc_error=False,
    #     use_error=False,
    #     fake_activations=(upstream_sae.cfg.hook_layer, up_base),  # TODO saes dont have cfg
    #     use_mean_error=use_mean_error,
    #     cache_sae_activations=True,   # we need the graph intact
    #     no_detach=True,
    # )
    # feats = saes_out[downstream_sae.cfg.hook_layer].feature_acts
    _ = esm_transformer.predict_contacts(clean_batch_tokens_BL, clean_batch_mask_BL)[0]
    handle_up.remove()
    handle_down.remove()
    feats = down_sae.feature_acts
    if not feats.requires_grad:
        raise RuntimeError(
            "[edge-attr-vjp] downstream activations are detached; "
            "remove `.detach()` inside your SAE hook or clone with "
            "`.requires_grad_()` earlier in the graph."
        )
    return feats

In [None]:
# ----------------------------------------------------------------------
# 3. Single forward pass (re-used for every downstream feature)
# ----------------------------------------------------------------------
down_base = _forward_fn()             # [B,L,S_down]
down_grad = down_effects.to(device)

# Container: (down_idx, up_idx) → list[val]
bucket: Dict[Tuple[int, int], List[torch.Tensor]] = {}

NameError: name 'up_error' is not defined

In [None]:
up_base = clean_layer_caches[up_layer].detach().clone().to(device).requires_grad_()
up_base_corr = corr_layer_caches[up_layer].detach().clone().to(device).requires_grad_()
patch_mask_LS = torch.ones((L, S), dtype=torch.bool, device=device)
up_error = clean_layer_errors[up_layer].to(device)
down_error = clean_layer_errors[down_layer].to(device)

In [None]:
def _forward_fn() -> torch.Tensor:    # returns [B, L, S_down]

    # up hook that puts patch activations 
    up_sae.mean_error = up_error
    up_hook = SAEHookProt(sae=up_sae, mask_BL=clean_batch_mask_BL, patch_mask_BLS=patch_mask_LS, patch_value=up_base, use_mean_error=True)

    # down hook that records downstream activations
    down_sae.mean_error = down_error
    down_hook = SAEHookProt(sae=down_sae, mask_BL=clean_batch_mask_BL, cache_latents=True, layer_is_lm=False, calc_error=True, use_error=True, no_detach=True)

    # register the hooks
    handle_up = esm_transformer.esm.encoder.layer[up_layer].register_forward_hook(up_hook)
    handle_down = esm_transformer.esm.encoder.layer[down_layer].register_forward_hook(down_hook)

    # run the forward pass
    # _, saes_out = run_with_saes( # TODO add the hook for each 1. intervening on upstream, 2. recording downstream
    #     model,
    #     base_saes,
    #     token_list,
    #     calc_error=False,
    #     use_error=False,
    #     fake_activations=(upstream_sae.cfg.hook_layer, up_base),  # TODO saes dont have cfg
    #     use_mean_error=use_mean_error,
    #     cache_sae_activations=True,   # we need the graph intact
    #     no_detach=True,
    # )
    # feats = saes_out[downstream_sae.cfg.hook_layer].feature_acts
    _ = esm_transformer.predict_contacts(clean_batch_tokens_BL, clean_batch_mask_BL)[0]
    handle_up.remove()
    handle_down.remove()
    feats = down_sae.feature_acts
    if not feats.requires_grad:
        raise RuntimeError(
            "[edge-attr-vjp] downstream activations are detached; "
            "remove `.detach()` inside your SAE hook or clone with "
            "`.requires_grad_()` earlier in the graph."
        )
    return feats

In [None]:
# ----------------------------------------------------------------------
# 3. Single forward pass (re-used for every downstream feature)
# ----------------------------------------------------------------------
down_base = _forward_fn()             # [B,L,S_down]
down_grad = down_effects.to(device)

# Container: (down_idx, up_idx) → list[val]
bucket: Dict[Tuple[int, int], List[torch.Tensor]] = {}

NameError: name 'Dict' is not defined

In [None]:
from typing import List, Dict, Any, Optional, Tuple
# ----------------------------------------------------------------------
# 3. Single forward pass (re-used for every downstream feature)
# ----------------------------------------------------------------------
down_base = _forward_fn()             # [B,L,S_down]
down_grad = down_effects.to(device)

# Container: (down_idx, up_idx) → list[val]
bucket: Dict[Tuple[int, int], List[torch.Tensor]] = {}

In [None]:
# ----------------------------------------------------------------------
# 4. Loop over downstream features (rows of the Jacobian)
# ----------------------------------------------------------------------
for d_idx in down_feats:
    # Select the scalar we will back-prop; optionally weight by loss grad
    scalar_field = down_base[..., d_idx]
    if down_grad is not None:
        scalar_field = scalar_field * down_grad[..., d_idx]
    scalar = scalar_field.sum()

    # Jᵀ ▽  – gradient w.r.t. *entire* upstream latent tensor
    grad_tensor = torch.autograd.grad(
        scalar,
        up_base,
        retain_graph=True,   # keep graph for next d_idx
        create_graph=False,  # we only need first-order grads
    )[0]                     # shape [B,L,S_up]

    # Accumulate entries we care about
    for u_idx in up_feats:
        val = grad_tensor[..., u_idx].sum()  # Σ_{b,t}
        if val.abs() < 1e-6:                 # keep/raise threshold as needed
            continue
        bucket.setdefault((d_idx, u_idx), []).append(val.detach().cpu())

    if logstats and (d_idx == down_feats[0] or d_idx % 10 == 0):
        print(f"[edge-attr-vjp] processed downstream idx {d_idx}")

# ----------------------------------------------------------------------
# 5. Assemble sparse COO tensor
# ----------------------------------------------------------------------
if not bucket:
    return None

idxs, vals = zip(
    *[((d, u), torch.stack(v).mean()) for (d, u), v in bucket.items()]
)
idx_mat = torch.tensor(list(zip(*idxs)), dtype=torch.long)  # [2, N]
val_mat = torch.stack(list(vals))                           # [N]

edge_tensor = torch.sparse_coo_tensor(
    idx_mat,
    val_mat,
    size=(len(down_feats), len(up_feats)),
).coalesce()

if logstats:
    nnz = edge_tensor._nnz()
    print(f"[edge-attr-vjp] finished – {nnz} non-zero entries")

NameError: name 'logstats' is not defined

In [None]:
logstats = True
# ----------------------------------------------------------------------
# 4. Loop over downstream features (rows of the Jacobian)
# ----------------------------------------------------------------------
for d_idx in down_feats:
    # Select the scalar we will back-prop; optionally weight by loss grad
    scalar_field = down_base[..., d_idx]
    if down_grad is not None:
        scalar_field = scalar_field * down_grad[..., d_idx]
    scalar = scalar_field.sum()

    # Jᵀ ▽  – gradient w.r.t. *entire* upstream latent tensor
    grad_tensor = torch.autograd.grad(
        scalar,
        up_base,
        retain_graph=True,   # keep graph for next d_idx
        create_graph=False,  # we only need first-order grads
    )[0]                     # shape [B,L,S_up]

    # Accumulate entries we care about
    for u_idx in up_feats:
        val = grad_tensor[..., u_idx].sum()  # Σ_{b,t}
        if val.abs() < 1e-6:                 # keep/raise threshold as needed
            continue
        bucket.setdefault((d_idx, u_idx), []).append(val.detach().cpu())

    if logstats and (d_idx == down_feats[0] or d_idx % 10 == 0):
        print(f"[edge-attr-vjp] processed downstream idx {d_idx}")

# ----------------------------------------------------------------------
# 5. Assemble sparse COO tensor
# ----------------------------------------------------------------------
if not bucket:
    return None

idxs, vals = zip(
    *[((d, u), torch.stack(v).mean()) for (d, u), v in bucket.items()]
)
idx_mat = torch.tensor(list(zip(*idxs)), dtype=torch.long)  # [2, N]
val_mat = torch.stack(list(vals))                           # [N]

edge_tensor = torch.sparse_coo_tensor(
    idx_mat,
    val_mat,
    size=(len(down_feats), len(up_feats)),
).coalesce()

if logstats:
    nnz = edge_tensor._nnz()
    print(f"[edge-attr-vjp] finished – {nnz} non-zero entries")

TypeError: 'set' object is not subscriptable

In [None]:
logstats = True
# ----------------------------------------------------------------------
# 4. Loop over downstream features (rows of the Jacobian)
# ----------------------------------------------------------------------
for d_idx in down_feats:
    # Select the scalar we will back-prop; optionally weight by loss grad
    scalar_field = down_base[..., d_idx]
    if down_grad is not None:
        scalar_field = scalar_field * down_grad[..., d_idx]
    scalar = scalar_field.sum()

    # Jᵀ ▽  – gradient w.r.t. *entire* upstream latent tensor
    grad_tensor = torch.autograd.grad(
        scalar,
        up_base,
        retain_graph=True,   # keep graph for next d_idx
        create_graph=False,  # we only need first-order grads
    )[0]                     # shape [B,L,S_up]

    # Accumulate entries we care about
    for u_idx in up_feats:
        val = grad_tensor[..., u_idx].sum()  # Σ_{b,t}
        if val.abs() < 1e-6:                 # keep/raise threshold as needed
            continue
        bucket.setdefault((d_idx, u_idx), []).append(val.detach().cpu())

    if logstats and (d_idx == list(down_feats)[0] or d_idx % 10 == 0):
        print(f"[edge-attr-vjp] processed downstream idx {d_idx}")

# ----------------------------------------------------------------------
# 5. Assemble sparse COO tensor
# ----------------------------------------------------------------------
if not bucket:
    return None

idxs, vals = zip(
    *[((d, u), torch.stack(v).mean()) for (d, u), v in bucket.items()]
)
idx_mat = torch.tensor(list(zip(*idxs)), dtype=torch.long)  # [2, N]
val_mat = torch.stack(list(vals))                           # [N]

edge_tensor = torch.sparse_coo_tensor(
    idx_mat,
    val_mat,
    size=(len(down_feats), len(up_feats)),
).coalesce()

if logstats:
    nnz = edge_tensor._nnz()
    print(f"[edge-attr-vjp] finished – {nnz} non-zero entries")

[edge-attr-vjp] processed downstream idx 3716
[edge-attr-vjp] processed downstream idx 2380


SyntaxError: 'return' outside function (<ipython-input-27-78655ce22e91>, line 35)

In [None]:
logstats = True
# ----------------------------------------------------------------------
# 4. Loop over downstream features (rows of the Jacobian)
# ----------------------------------------------------------------------
for d_idx in down_feats:
    # Select the scalar we will back-prop; optionally weight by loss grad
    scalar_field = down_base[..., d_idx]
    if down_grad is not None:
        scalar_field = scalar_field * down_grad[..., d_idx]
    scalar = scalar_field.sum()

    # Jᵀ ▽  – gradient w.r.t. *entire* upstream latent tensor
    grad_tensor = torch.autograd.grad(
        scalar,
        up_base,
        retain_graph=True,   # keep graph for next d_idx
        create_graph=False,  # we only need first-order grads
    )[0]                     # shape [B,L,S_up]

    # Accumulate entries we care about
    for u_idx in up_feats:
        val = grad_tensor[..., u_idx].sum()  # Σ_{b,t}
        if val.abs() < 1e-6:                 # keep/raise threshold as needed
            continue
        bucket.setdefault((d_idx, u_idx), []).append(val.detach().cpu())

    if logstats and (d_idx == list(down_feats)[0] or d_idx % 10 == 0):
        print(f"[edge-attr-vjp] processed downstream idx {d_idx}")

# ----------------------------------------------------------------------
# 5. Assemble sparse COO tensor
# ----------------------------------------------------------------------
if not bucket:
    print("No bucket")
    return None
else:
    idxs, vals = zip(
        *[((d, u), torch.stack(v).mean()) for (d, u), v in bucket.items()]
    )
    idx_mat = torch.tensor(list(zip(*idxs)), dtype=torch.long)  # [2, N]
    val_mat = torch.stack(list(vals))                           # [N]

    edge_tensor = torch.sparse_coo_tensor(
        idx_mat,
        val_mat,
        size=(len(down_feats), len(up_feats)),
    ).coalesce()

    if logstats:
        nnz = edge_tensor._nnz()
        print(f"[edge-attr-vjp] finished – {nnz} non-zero entries")

[edge-attr-vjp] processed downstream idx 3716
[edge-attr-vjp] processed downstream idx 2380


SyntaxError: 'return' outside function (<ipython-input-28-a631ec61d6ec>, line 36)

In [None]:
logstats = True
# ----------------------------------------------------------------------
# 4. Loop over downstream features (rows of the Jacobian)
# ----------------------------------------------------------------------
for d_idx in down_feats:
    # Select the scalar we will back-prop; optionally weight by loss grad
    scalar_field = down_base[..., d_idx]
    if down_grad is not None:
        scalar_field = scalar_field * down_grad[..., d_idx]
    scalar = scalar_field.sum()

    # Jᵀ ▽  – gradient w.r.t. *entire* upstream latent tensor
    grad_tensor = torch.autograd.grad(
        scalar,
        up_base,
        retain_graph=True,   # keep graph for next d_idx
        create_graph=False,  # we only need first-order grads
    )[0]                     # shape [B,L,S_up]

    # Accumulate entries we care about
    for u_idx in up_feats:
        val = grad_tensor[..., u_idx].sum()  # Σ_{b,t}
        if val.abs() < 1e-6:                 # keep/raise threshold as needed
            continue
        bucket.setdefault((d_idx, u_idx), []).append(val.detach().cpu())

    if logstats and (d_idx == list(down_feats)[0] or d_idx % 10 == 0):
        print(f"[edge-attr-vjp] processed downstream idx {d_idx}")

# ----------------------------------------------------------------------
# 5. Assemble sparse COO tensor
# ----------------------------------------------------------------------
if not bucket:
    print("No bucket")
else:
    idxs, vals = zip(
        *[((d, u), torch.stack(v).mean()) for (d, u), v in bucket.items()]
    )
    idx_mat = torch.tensor(list(zip(*idxs)), dtype=torch.long)  # [2, N]
    val_mat = torch.stack(list(vals))                           # [N]

    edge_tensor = torch.sparse_coo_tensor(
        idx_mat,
        val_mat,
        size=(len(down_feats), len(up_feats)),
    ).coalesce()

    if logstats:
        nnz = edge_tensor._nnz()
        print(f"[edge-attr-vjp] finished – {nnz} non-zero entries")

[edge-attr-vjp] processed downstream idx 3716
[edge-attr-vjp] processed downstream idx 2380
[edge-attr-vjp] finished – 1829 non-zero entries


In [None]:
edge_tensor

tensor(indices=tensor([[ 312,  312,  312,  ..., 4083, 4083, 4083],
                       [  72,  237,  340,  ..., 3764, 3788, 3832]]),
       values=tensor([-5.4402e-05,  2.7134e-04,  5.2533e-05,  ...,
                      -2.9230e-04,  4.0238e-04,  1.0122e-03]),
       size=(42, 44), nnz=1829, layout=torch.sparse_coo)

In [None]:
edge_tensor.values()

tensor([-5.4402e-05,  2.7134e-04,  5.2533e-05,  ..., -2.9230e-04,
         4.0238e-04,  1.0122e-03])

In [None]:
print(edge_tensor.values())

tensor([-5.4402e-05,  2.7134e-04,  5.2533e-05,  ..., -2.9230e-04,
         4.0238e-04,  1.0122e-03])


In [None]:
print(edge_tensor.values())

tensor([-5.4402e-05,  2.7134e-04,  5.2533e-05,  ..., -2.9230e-04,
         4.0238e-04,  1.0122e-03])


In [None]:
# Print edge tensor in a readable format
print("\n=== Edge Tensor Analysis ===")
print(f"Edge tensor shape: {edge_tensor.shape}")
print(f"Number of non-zero entries: {edge_tensor._nnz()}")

if edge_tensor._nnz() > 0:
    # Get the indices and values
    indices = edge_tensor.indices()  # [2, nnz] - [down_idx, up_idx]
    values = edge_tensor.values()    # [nnz]
    
    # Convert to lists for easier processing
    down_indices = indices[0].tolist()
    up_indices = indices[1].tolist()
    edge_values = values.tolist()
    
    # Group by upstream feature index
    from collections import defaultdict
    up_to_down = defaultdict(list)
    
    for i in range(len(down_indices)):
        down_idx = down_indices[i]
        up_idx = up_indices[i]
        val = edge_values[i]
        up_to_down[up_idx].append((down_idx, val))
    
    # Sort upstream indices for consistent output
    sorted_up_indices = sorted(up_to_down.keys())
    
    print(f"\nEdge connections (upstream -> downstream):")
    print("="*50)
    
    for up_idx in sorted_up_indices:
        connections = up_to_down[up_idx]
        # Sort connections by absolute value (strongest first)
        connections.sort(key=lambda x: abs(x[1]), reverse=True)
        
        print(f"\nUpstream feature {up_idx}:")
        for down_idx, val in connections:
            print(f"  -> Downstream {down_idx}: {val:.6f}")
    
    # Also show top connections overall
    print(f"\n\nTop 10 strongest connections overall:")
    print("="*50)
    all_connections = [(up_idx, down_idx, val) for up_idx, connections in up_to_down.items() 
                       for down_idx, val in connections]
    all_connections.sort(key=lambda x: abs(x[2]), reverse=True)
    
    for i, (up_idx, down_idx, val) in enumerate(all_connections[:10]):
        print(f"{i+1:2d}. Up {up_idx} -> Down {down_idx}: {val:.6f}")
        
else:
    print("No edges found!")


=== Edge Tensor Analysis ===
Edge tensor shape: torch.Size([42, 44])
Number of non-zero entries: 1829

Edge connections (upstream -> downstream):

Upstream feature 72:
  -> Downstream 431: -0.000430
  -> Downstream 1591: -0.000409
  -> Downstream 3921: -0.000390
  -> Downstream 2677: -0.000303
  -> Downstream 3642: 0.000192
  -> Downstream 2486: 0.000180
  -> Downstream 1575: 0.000171
  -> Downstream 488: -0.000155
  -> Downstream 2693: -0.000140
  -> Downstream 2166: -0.000123
  -> Downstream 3381: -0.000121
  -> Downstream 2380: -0.000113
  -> Downstream 3092: 0.000105
  -> Downstream 3716: -0.000103
  -> Downstream 2524: -0.000095
  -> Downstream 2775: -0.000087
  -> Downstream 2041: -0.000080
  -> Downstream 2662: -0.000075
  -> Downstream 2209: -0.000073
  -> Downstream 3682: 0.000068
  -> Downstream 3102: 0.000061
  -> Downstream 4083: -0.000058
  -> Downstream 576: 0.000057
  -> Downstream 2675: -0.000057
  -> Downstream 312: -0.000054
  -> Downstream 1489: -0.000050
  -> Downs