In [1]:
import sys
import json
import torch
from functools import partial
from tqdm import tqdm
sys.path.append("../plm_circuits/")
from helpers.utils import (
    load_esm,
    load_sae_prot,
    mask_flanks_segment,
    cleanup_cuda,
    patching_metric,
)
from attribution import integrated_gradients_sae
from hook_manager import SAEHookProt

In IPython
Set autoreload


In [None]:
# helpers 
def get_unmask_indices(seq_len: int, start: int, end: int, flank_len: int, is_left: bool):
    """List the flank positions to leave unmasked on one side of a span.

    Args:
        seq_len: Sequence length; the right flank is clipped to it.
        start: Start index of the span; the left flank stops just before it.
        end: Index at which the right flank begins.
        flank_len: Maximum number of flank positions to return.
        is_left: True for the flank preceding ``start``, False for the flank
            beginning at ``end``.

    Returns:
        A list of consecutive indices: ``[max(0, start - flank_len), start)``
        for the left flank, ``[end, min(seq_len, end + flank_len))`` for the
        right flank.
    """
    if not is_left:
        # Right flank: begins at `end`, clipped at the end of the sequence.
        return list(range(end, min(seq_len, end + flank_len)))
    # Left flank: stops just before `start`, clipped at index 0.
    return list(range(max(0, start - flank_len), start))

def get_masked_sequence(seq, ss1_start, ss1_end, ss2_start, ss2_end, flank_len, seq_len):
    """Mask ``seq`` while keeping flanks of ``flank_len`` around the two segments.

    The positions kept unmasked are the ``flank_len`` indices immediately
    before ``ss1_start`` (clipped at 0) and the ``flank_len`` indices starting
    at ``ss2_end`` (clipped at ``seq_len``). The masking itself is delegated to
    ``mask_flanks_segment``; what it does with the segments and the remaining
    positions is defined there.

    Args:
        seq: Sequence to mask.
        ss1_start, ss1_end: Bounds of the first segment.
        ss2_start, ss2_end: Bounds of the second segment.
        flank_len: Number of flank positions to keep on each outer side.
        seq_len: Length of ``seq``, used to clip the right flank.

    Returns:
        Whatever ``mask_flanks_segment`` returns for these arguments.
    """
    # Left flank belongs to segment 1, right flank to segment 2. Each call is
    # given the bounds of the segment it actually flanks (the right-flank call
    # previously received ss1_start, which was ignored but misleading).
    unmask_left_idxs = get_unmask_indices(seq_len, ss1_start, ss1_end, flank_len, is_left=True)
    unmask_right_idxs = get_unmask_indices(seq_len, ss2_start, ss2_end, flank_len, is_left=False)
    return mask_flanks_segment(seq, ss1_start, ss1_end, ss2_start, ss2_end, unmask_left_idxs, unmask_right_idxs)

def tokenize_sequence(batch_converter, seq, padding_idx, device):
    """Tokenize a single sequence and build its non-padding mask.

    Args:
        batch_converter: Callable taking a list of ``(label, sequence)`` pairs
            and returning a 3-tuple whose last element is the token tensor.
        seq: Sequence to tokenize; submitted as a one-item batch labelled ``1``.
        padding_idx: Token id treated as padding.
        device: Device both returned tensors are moved to.

    Returns:
        ``(tokens, mask)`` where ``mask`` is True wherever ``tokens`` differs
        from ``padding_idx``.
    """
    single_item_batch = [(1, seq)]
    _labels, _strs, token_ids = batch_converter(single_item_batch)
    token_ids = token_ids.to(device)
    not_padding = token_ids.ne(padding_idx).to(device)
    return token_ids, not_padding

def run_sae_hooked_prediction(
    model, sae_model, tokens, mask, layer_idx, patching_metric_fn
):
    """Predict contacts with an SAE hook attached to one encoder layer.

    A ``SAEHookProt`` is registered as a forward hook on
    ``model.esm.encoder.layer[layer_idx]`` for the duration of a single
    ``predict_contacts`` call (run under ``torch.no_grad()``), and is always
    removed afterwards, even if the forward pass raises.

    Args:
        model: Model exposing ``esm.encoder.layer`` and ``predict_contacts``.
        sae_model: SAE passed to the hook; its ``feature_acts`` and
            ``error_term`` attributes are read after the forward pass.
        tokens: Token tensor passed to ``predict_contacts``.
        mask: Mask passed both to the hook and to ``predict_contacts``.
        layer_idx: Index into ``model.esm.encoder.layer`` to hook.
        patching_metric_fn: Callable applied to the predicted contact map.

    Returns:
        ``(sae_model.feature_acts, sae_model.error_term, contact_LL, metric)``
        where ``contact_LL`` is the first item of the contact prediction and
        ``metric`` is ``patching_metric_fn(contact_LL)``.
    """
    hook = SAEHookProt(
        sae=sae_model,
        mask_BL=mask,
        cache_latents=True,
        layer_is_lm=False,
        calc_error=True,
        use_error=True,
    )
    handle = model.esm.encoder.layer[layer_idx].register_forward_hook(hook)
    try:
        with torch.no_grad():
            contact_LL = model.predict_contacts(tokens, mask)[0]
    finally:
        # Detach the hook even on failure: a leaked hook would keep firing on
        # every later forward pass through this shared model.
        handle.remove()
        cleanup_cuda()

    # NOTE(review): feature_acts / error_term are presumably populated on the
    # SAE by the hook (cache_latents=True, calc_error=True) - confirm in SAEHookProt.
    return sae_model.feature_acts, sae_model.error_term, contact_LL, patching_metric_fn(contact_LL)


In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): 33 presumably selects the 33-layer ESM-2 checkpoint - confirm in load_esm.
esm_transformer, batch_converter, esm2_alphabet = load_esm(33, device=device)

# Encoder layers to analyse; one SAE is loaded per layer, in the same order.
main_layers = [4, 8, 12, 16, 20, 24, 28]
saes = []
for layer in main_layers:
    sae_model = load_sae_prot(ESM_DIM=1280, SAE_DIM=4096, LAYER=layer, device=device)
    saes.append(sae_model)

# Maps an encoder layer number to the position of its SAE inside `saes`.
layer_2_saelayer = {layer: layer_idx  for layer_idx, layer in enumerate(main_layers)}

# Sequences keyed by protein ID.
with open('../data/full_seq_dict.json', "r") as json_file:
    seq_dict = json.load(json_file)

# Per protein: a list of [position_1, position_2] pairs; the segment windows
# used for masking are built around these two positions.
sse_dict = {"2B61A": [[182, 316]],"1PVGA": [[101, 202]]}
# Per protein: [flank length for the clean input, flank length for the corrupted input].
fl_dict = {"2B61A": [44, 43], "1PVGA": [65, 63]}

In [None]:
protein = "2B61A"
seq = seq_dict[protein]
L = len(seq)

# Build one 11-index window (p - 5 .. p + 6) around each position of the first
# annotated pair.
# NOTE(review): the end bound looks exclusive (the right flank starts at
# `ss2_end`) - confirm against mask_flanks_segment.
position = sse_dict[protein][0]
ss1_start, ss1_end = position[0] - 5, position[0] + 6
ss2_start, ss2_end = position[1] - 5, position[1] + 6

# Full tokens: the unmasked sequence, tokenized through the same helper as the
# clean/corrupted inputs so all three are prepared identically.
full_seq_L = [(1, seq)]  # kept defined in case other cells reference it
batch_tokens_BL, batch_mask_BL = tokenize_sequence(batch_converter, seq, esm2_alphabet.padding_idx, device)

# Clean input: masked sequence keeping the first flank length listed for this protein.
clean_fl = fl_dict[protein][0]
clean_seq = get_masked_sequence(seq, ss1_start, ss1_end, ss2_start, ss2_end, clean_fl, L)
clean_batch_tokens_BL, clean_batch_mask_BL = tokenize_sequence(batch_converter, clean_seq, esm2_alphabet.padding_idx, device)

# Corrupted input: same masking, using the second flank length listed for this protein.
corr_fl = fl_dict[protein][1]
corr_seq = get_masked_sequence(seq, ss1_start, ss1_end, ss2_start, ss2_end, corr_fl, L)
corr_batch_tokens_BL, corr_batch_mask_BL = tokenize_sequence(batch_converter, corr_seq, esm2_alphabet.padding_idx, device)


In [9]:
# Accumulators for the (currently disabled) attribution step at the end of the loop.
# all_effects_sae_ALS = []
# all_effects_err_ABLF = []

# Reference contact map predicted from the full, unmasked sequence.
with torch.no_grad():
    full_seq_contact_LL = esm_transformer.predict_contacts(batch_tokens_BL, batch_mask_BL)[0]
cleanup_cuda()

# Bind the reference map (first positional argument) and the segment bounds so
# the metric can later be called with just a predicted contact map.
_patching_metric = partial(
    patching_metric,
    full_seq_contact_LL,
    ss1_start=ss1_start,
    ss1_end=ss1_end,
    ss2_start=ss2_start,
    ss2_end=ss2_end,
)

# For every analysed layer, run the clean and the corrupted input through the
# model with that layer's SAE hooked in, and report the metric for each.
# NOTE(review): the printed values are essentially identical across layers;
# presumably because the hook is built with use_error=True, so the SAE
# reconstruction plus error term reproduces the original activations - confirm.
for layer_idx in tqdm(main_layers):
    # `saes` is ordered like `main_layers`; translate layer number -> list index.
    sae_model = saes[layer_2_saelayer[layer_idx]]

    clean_cache_LS, clean_err_BLF, clean_contact_LL, clean_recovery = run_sae_hooked_prediction(
        esm_transformer, sae_model, clean_batch_tokens_BL, clean_batch_mask_BL, layer_idx, _patching_metric
    )

    corr_cache_LS, corr_err_BLF, corr_contact_LL, corr_recovery = run_sae_hooked_prediction(
        esm_transformer, sae_model, corr_batch_tokens_BL, corr_batch_mask_BL, layer_idx, _patching_metric
    )

    print(f"Layer {layer_idx}: Clean contact recovery: {clean_recovery:.4f}, Corr contact recovery: {corr_recovery:.4f}")

    # Attribution step (disabled): integrated gradients between the clean and
    # corrupted SAE activations / error terms for this layer.
    # effect_sae_LS, effect_err_BLF = integrated_gradients_sae(
    #     esm_transformer,
    #     sae_model,
    #     _patching_metric,
    #     clean_cache_LS,
    #     corr_cache_LS,
    #     clean_err_BLF,
    #     corr_err_BLF,
    #     batch_tokens=clean_batch_tokens_BL,
    #     batch_mask=clean_batch_mask_BL,
    #     hook_layer=layer_idx,
    # )

    # all_effects_sae_ALS.append(effect_sae_LS)
    # all_effects_err_ABLF.append(effect_err_BLF)

 14%|█▍        | 1/7 [00:00<00:01,  3.02it/s]

Layer 4: Clean contact recovery: 1.3437, Corr contact recovery: 13.9466


 29%|██▊       | 2/7 [00:00<00:01,  3.02it/s]

Layer 8: Clean contact recovery: 1.3437, Corr contact recovery: 13.9465


 43%|████▎     | 3/7 [00:01<00:01,  2.96it/s]

Layer 12: Clean contact recovery: 1.3437, Corr contact recovery: 13.9465


 57%|█████▋    | 4/7 [00:01<00:01,  2.93it/s]

Layer 16: Clean contact recovery: 1.3437, Corr contact recovery: 13.9465


 71%|███████▏  | 5/7 [00:01<00:00,  2.91it/s]

Layer 20: Clean contact recovery: 1.3437, Corr contact recovery: 13.9465


 86%|████████▌ | 6/7 [00:02<00:00,  2.90it/s]

Layer 24: Clean contact recovery: 1.3437, Corr contact recovery: 13.9465


100%|██████████| 7/7 [00:02<00:00,  2.92it/s]

Layer 28: Clean contact recovery: 1.3437, Corr contact recovery: 13.9465



