In [5]:
from ioi_utils import *
from circuit_utils import *
from sae_variants import *
from sae_interp import *
from sae_interventions import *
from training import *
from mandala._next.imports import *
from mandala._next.common_imports import *

# Circuit setup

In [6]:
from circuit_utils import *
torch.set_printoptions(sci_mode=False)
# If this kernel already holds a loaded `model`, store it in MODELS under
# MODEL_ID (presumably so that re-running this cell does not lose it).
if 'model' in locals():
    MODELS[MODEL_ID] = model

# Short head-class codes -> human-readable names (the `_FIG` suffix suggests
# these are figure labels).
HEAD_CLASS_FIG = {
    'nm': 'Name Mover',
    'bnm': 'Backup Name Mover',
    'ind': 'Induction',
    'nnm': 'Negative Name Mover',
    'si': 'S-Inhibition',
    'dt': 'Duplicate Token',
    'pt': 'Previous Token',
}

# Single-letter attention component codes -> human-readable names.
COMPONENT_NAME_FIG = {
    'k': 'Key',
    'v': 'Value',
    'q': 'Query',
    'z': 'Attn Output',
}

# Cross-section location keys ('<head classes>@<components>') -> display names.
CROSS_SECTION_FIG = {
    'ind+dt@z': 'Ind+DT out',
    'nm+bnm@q': '(B)NM q',
    'nm+bnm@qk': '(B)NM qk',
    'nm+bnm@z': '(B)NM out',
    'si@v': 'S-I v',
    'si@z': 'S-I out',
}

c = Circuit()
# Each entry is (nodes, attributes, location key); the location key is one of
# the keys of CROSS_SECTION_FIG.
paper_cross_sections = [
    # IO
    (c.zs(c.nm + c.bnm), ('io',), 'nm+bnm@z'),
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io',), 'nm+bnm@qk'),
    (c.qs(c.nm + c.bnm), ('io',), 'nm+bnm@q'),
    # S
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('s',), 'nm+bnm@qk'),
    (c.qs(c.nm + c.bnm), ('s',), 'nm+bnm@q'),
    (c.vs(c.si), ('s',), 'si@v'),
    (c.zs(c.si), ('s',), 'si@z'),
    (c.zs(c.ind) + c.zs(c.dt), ('s',), 'ind+dt@z'),
    # Pos
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io_pos',), 'nm+bnm@qk'),
    (c.qs(c.nm + c.bnm), ('io_pos',), 'nm+bnm@q'),
    (c.zs(c.si), ('io_pos',), 'si@z'),
    (c.vs(c.si), ('io_pos',), 'si@v'),
    # Pos + S
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io_pos', 's'), 'nm+bnm@qk'),
    (c.zs(c.si), ('io_pos', 's'), 'si@z'),
    (c.vs(c.si), ('io_pos', 's'), 'si@v'),
    (c.zs(c.ind) + c.zs(c.dt), ('io_pos', 's'), 'ind+dt@z'),
    (c.zs(c.ind) + c.zs(c.dt), ('io_pos',), 'ind+dt@z'),
    # All
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io', 'io_pos', 's'), 'nm+bnm@qk'),
]

# Same display names as CROSS_SECTION_FIG, but in this dict's own key order.
# Derived from CROSS_SECTION_FIG so the two tables cannot drift apart.
locations_displaynames = {
    location: CROSS_SECTION_FIG[location]
    for location in ('nm+bnm@z', 'nm+bnm@qk', 'nm+bnm@q', 'si@v', 'si@z', 'ind+dt@z')
}

# All nodes whose activations are collected. The order matters: later cells
# pair NODES with the outputs of run_with_cache by position.
NODES = (
    c.zs(c.nm + c.bnm)
    + c.qs(c.nm + c.bnm)
    + c.zs(c.si)
    + [n for n in c.vs(c.si) if n.seq_pos == 's2']  # S-Inhibition values at 's2' only
    + c.zs(c.ind)
    + c.zs(c.dt)
    + c.ks(c.nm + c.bnm)
)

could not find model


In [7]:
import os

# Path of the database file handed to Storage(db_path=...). Set the
# IOI_SAE_DB_PATH environment variable to point the notebook at a different
# file; the default is the original hard-coded location.
DB_PATH = os.environ.get('IOI_SAE_DB_PATH', '/media/amakelov/SanDisk1TB/paper_sprint/test.db')

In [8]:
# Open the storage backed by the database file at DB_PATH; the cells below run
# their computations inside `with storage:` blocks.
storage = Storage(db_path=DB_PATH)

In [9]:
# Load the model and store it in MODELS under MODEL_ID.
model = MODELS[MODEL_ID] = get_model()



Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
# set the mandala logger's level to INFO
from mandala._next.common_imports import logger
import logging
logger.setLevel(logging.INFO)

# Preparing datasets

In [10]:
# nodes = list(c.nodes.keys())
# circuit_nodes = list(c.nodes.keys())

# Build every dataset used by the later cells: training prompts/activations,
# evaluation prompts/activations with their clean and mean-ablated logit-diff
# baselines, and counterfactual prompts/activations per attribute.
with storage:

    ############################################################################
    ### prompt dataset for training supervised features
    ############################################################################
    P_train = generate_prompts(
        distribution=full_distribution,
        patterns=['ABB', 'BAB'],
        prompts_per_pattern=10_000,
        random_seed=0,
    )
    N_TRAIN = len(storage.unwrap(P_train))
    ### activations for training supervised features
    As_train = run_with_cache(
        prompts=P_train,
        nodes=NODES,
        batch_size=100,
        model_id=MODEL_ID,
        verbose=True,
    )
    # outputs are paired with NODES by position
    A_TRAIN_DICT = dict(zip(NODES, As_train))

    # ### precompute the mean logit difference for clean training data
    # logits_train_clean = run_with_hooks(prompts=P_train, hooks=[], batch_size=200,)
    # CLEAN_LD_MEAN = (storage.unwrap(logits_train_clean)[:, 0] - storage.unwrap(logits_train_clean)[:, 1]).mean().item()

    # ### precompute the mean-ablated logit difference when ablating each node
    # A_TRAIN_MEAN_DICT = {node: get_dataset_mean(A) for node, A in A_TRAIN_DICT.items()}

    # MEAN_ABLATED_LD_DICT = {}
    # for node, A in A_TRAIN_DICT.items():
    #     MEAN_ABLATED_LD_DICT[node] = compute_mean_ablated_lds(
    #         node=node, prompts=P_train, A_mean=A_TRAIN_MEAN_DICT[node], batch_size=200,
    #     )

    ############################################################################
    ### prompt dataset for editing and other evaluations
    ############################################################################
    # Split the name list at index N_NAMES // 2: the first part is used for the
    # base (evaluation) prompts, the rest as the source of counterfactual names.
    N_NAMES = len(NAMES)
    editing_base_distribution = copy.deepcopy(full_distribution)
    editing_base_distribution.names = editing_base_distribution.names[:N_NAMES // 2]
    editing_source_distribution = copy.deepcopy(full_distribution)
    editing_source_distribution.names = editing_source_distribution.names[N_NAMES // 2:]

    P_eval = generate_prompts(
        distribution=editing_base_distribution,
        patterns=['ABB', 'BAB'],
        prompts_per_pattern=2500,
        random_seed=1,
    )
    As_eval = run_with_cache(
        prompts=P_eval,
        nodes=NODES,
        batch_size=100,
        model_id=MODEL_ID,
        verbose=True,
    )
    P_eval_feature_idxs = get_prompt_feature_idxs(
        prompts=P_eval,
        features=[('io',), ('s',), ('io_pos',),],
    )
    A_EVAL_DICT = {node: A for node, A in zip(NODES, As_eval)}

    N_EVAL = len(storage.unwrap(P_eval))
    N_NAMES_EVAL_SOURCE = len(editing_source_distribution.names)

    ### precompute the mean logit difference for clean evaluation data
    logits_eval_clean = run_with_hooks(prompts=P_eval, hooks=[], batch_size=200,)
    # unwrap once; the logit difference is column 0 minus column 1
    logits_eval_clean_raw = storage.unwrap(logits_eval_clean)
    CLEAN_LD_EVAL_MEAN = (logits_eval_clean_raw[:, 0] - logits_eval_clean_raw[:, 1]).mean().item()

    ### precompute the mean-ablated logit difference when ablating each node
    A_EVAL_MEAN_DICT = {node: get_dataset_mean(A) for node, A in A_EVAL_DICT.items()}

    MEAN_ABLATED_LD_EVAL_DICT = {}
    for node, A in A_EVAL_DICT.items():
        MEAN_ABLATED_LD_EVAL_DICT[node] = compute_mean_ablated_lds(
            node=node, prompts=P_eval, A_mean=A_EVAL_MEAN_DICT[node], batch_size=200,
        )

    ############################################################################
    ### Compute counterfactual prompts and activations
    ############################################################################
    FEATURE_SUBSETS = [('io_pos',), ('s',), ('io',), ] # ('s', 'io_pos',), ('io', 'io_pos'), ('s', 'io',), ('io_pos', 's', 'io',), ]

    # Counterfactual IO names come from the first half of the source names and
    # counterfactual S names from the second half.
    CF_PROMPTS_DICT = {}
    for attribute in FEATURE_SUBSETS:
        CF_PROMPTS_DICT[attribute] = get_cf_prompts(
            prompts=P_eval,
            features=attribute,
            io_targets=generate_name_samples(N_EVAL, editing_source_distribution.names[:N_NAMES_EVAL_SOURCE // 2]),
            s_targets=generate_name_samples(N_EVAL, editing_source_distribution.names[N_NAMES_EVAL_SOURCE//2:]),
        )
    ### Compute counterfactual activations
    A_EVAL_CF_DICT = {}
    for attribute, cf_prompts in tqdm(CF_PROMPTS_DICT.items()):
        A_EVAL_CF_DICT[attribute] = run_with_cache(
            prompts=cf_prompts,
            nodes=NODES,
            batch_size=100,
            model_id=MODEL_ID,
            verbose=True,
        )
    # re-key each attribute's outputs by node (outputs are positional w.r.t. NODES)
    for attribute in A_EVAL_CF_DICT:
        A_EVAL_CF_DICT[attribute] = {node: A_EVAL_CF_DICT[attribute][i] for i, node in enumerate(NODES)}

    P_eval_cf_feature_idxs = {}
    for attribute, cf_prompts in CF_PROMPTS_DICT.items():
        P_eval_cf_feature_idxs[attribute] = get_prompt_feature_idxs(
            prompts=cf_prompts,
            features=[attribute],
        )

100%|██████████| 3/3 [00:00<00:00, 13.61it/s]


In [None]:
# Display the storage's cache info (NOTE(review): the exact output is defined by the Storage implementation, which is not visible here).
storage.cache_info()

# Computing supervised features

In [None]:
# For every training node, compute supervised feature codes and the matching
# reconstructions, keyed by (node, parametrization, codes_type).
with storage:
    SUPERVISED_FEATURES_DICT = {}
    SUPERVISED_RECONSTRUCTIONS_DICT = {}
    for node, A in tqdm(A_TRAIN_DICT.items()):
        for eventually in ['independent',]: # 'coupled', 'names', ]:
            for codes_type in ('mean',):  # 'mse'):
                node_parametrization = get_parametrization(
                    node=node,
                    eventually=eventually,
                    use_names=(eventually == 'names'),
                )
                node_features = FEATURE_CONFIGURATIONS[node_parametrization]

                # 'mean' uses get_mean_codes; anything else trains MSE codes
                # with a manual bias.
                if codes_type == 'mean':
                    code_getter = get_mean_codes
                else:
                    def code_getter(features, A, prompts):
                        return train_mse_codes(features=features, A=A, prompts=prompts, manual_bias=True)

                codes, reconstructions = code_getter(features=node_features, A=A, prompts=P_train)

                supervised_key = (node, node_parametrization, codes_type)
                SUPERVISED_FEATURES_DICT[supervised_key] = codes
                SUPERVISED_RECONSTRUCTIONS_DICT[supervised_key] = reconstructions

# Training SAEs

In [11]:
### define a uniform schedule for all training runs

# Checkpoints, in three groups:
#   * 0 and the powers of two up to 128: exponentially spread-out checkpoints
#     for the very early stages of training
#   * 500 / 750 / 1000 / 1250: right before each resampling, and in the middle
#     between resamplings
#   * 1500 / 2000: before and after the final LR decay
CHECKPOINT_STEPS = [0] + [2 ** k for k in range(8)] + [500, 750, 1000, 1250, 1500, 2000]

# two resampling stages, as it seems effects diminish after the first one
RESAMPLE_EPOCHS = [501, 1001]

# decay the LR for the last 25% of training
FINAL_DECAY_START, FINAL_DECAY_END = 1500, 2000


## Vanilla SAEs

In [None]:
# Preload the storage before the training loops below (NOTE(review): presumably pulls stored results into memory to speed up lookups — confirm against the Storage API).
storage.preload()

In [None]:
from torch.optim import Adam

# Vanilla SAE sweep. For every node and L1 coefficient, training runs in
# segments between consecutive CHECKPOINT_STEPS, threading the encoder /
# optimizer / scheduler state from one segment into the next. After each
# segment the interp-agnostic edits are computed with the current encoder.
# The metric, F1-feature and interp-based-edit stages are currently disabled
# (commented out).
with storage:
    metrics_dfs = []
    for node in tqdm(NODES):
        A_train = A_TRAIN_DICT[node]
        A_train_normalized, scale = normalize_activations(A=A_train)

        A_eval = A_EVAL_DICT[node]
        # reuse the scale computed on the training activations
        A_eval_normalized, _ = normalize_activations(A=A_eval, scale=scale)

        # only consumed by the (disabled) manual cache eviction at the end of the node loop
        _hids_to_drop = []

        for l1_coeff in (0.5, 1.0, 2.5, DefaultConfig.L1_COEFF):
            for lr in (DefaultConfig.LR, ):
                for batch_size in (512, ):
                    for dict_mult in (8, ):
                        # fresh training state for each hyperparameter configuration
                        encoder_state_dict = None
                        optimizer_state_dict = None
                        scheduler_state_dict = None
                        metrics_list = []
                        # 64 is the activation width passed as d_activation below
                        # (presumably the per-head dimension of the model — confirm)
                        d_hidden = dict_mult * 64
                        pbar = tqdm(list(zip(CHECKPOINT_STEPS, CHECKPOINT_STEPS[1:])), disable=False)
                        for start_epoch, end_epoch in pbar:
                            encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_vanilla(
                                A=A_train_normalized,
                                start_epoch=start_epoch,
                                d_hidden=d_hidden,
                                end_epoch=end_epoch,
                                batch_size=batch_size,
                                encoder_state_dict=encoder_state_dict,
                                optimizer_state_dict=optimizer_state_dict,
                                scheduler_state_dict=scheduler_state_dict,
                                l1_coeff=l1_coeff,
                                lr=lr,
                                resample_epochs=RESAMPLE_EPOCHS,
                                final_decay_start=FINAL_DECAY_START,
                                final_decay_end=FINAL_DECAY_END,
                            )
                            # metrics = storage.unwrap(metrics)

                            ### compute the logitdiff recovered metric
                            encoder = get_vanilla(d_activation=64, d_hidden=d_hidden, encoder_state_dict=encoder_state_dict)
                            # logitdiff_loss = get_logitdiff_loss(
                            #     encoder=encoder, batch_size=100, 
                            #     encoder_normalization_scale=scale,
                            #     prompts=P_eval,
                            #     clean_ld=CLEAN_LD_EVAL_MEAN,
                            #     mean_ablated_ld=MEAN_ABLATED_LD_EVAL_DICT[node],
                            #     node=node,
                            # )
                            # for elt in metrics: elt['ld_loss'] = storage.unwrap(logitdiff_loss)

                            # metrics_list.append(metrics)

                            ### compute high F1-score features
                            # top_f1_features, top_f1_scores = get_high_f1_features(
                            #     encoder=encoder,
                            #     attributes=[('io',), ('s',), ('io_pos',), ],
                            #     A_normalized=A_eval_normalized,
                            #     prompt_feature_idxs=P_eval_feature_idxs,
                            #     topk=d_hidden,
                            # )

                            # autointerp_fast(
                            #     A_normalized=A_eval_normalized,
                            #     encoder=encoder,
                            #     features=[('io',), ('s',), ('io_pos',), ],
                            #     features_to_group=[('io',), ('s',), ],
                            #     max_group_size=10,
                            #     prompt_feature_idxs=P_eval_feature_idxs,
                            #     feature_batch_size=None,
                            # )



                            ### interp-based edits
                            # for attribute in [('io',), ('io_pos',), ]:
                            #     for num_exchange in (2, 4, 8):
                            #         A_eval_cf = A_EVAL_CF_DICT[attribute][node]
                            #         A_eval_cf_normalized, _ = normalize_activations(A=A_eval_cf, scale=scale)
                            #         A_edited, features_removed, features_added = get_edit_using_f1_scores(
                            #             encoder=encoder,
                            #             A_clean_normalized = A_eval_normalized,
                            #             A_cf_normalized = A_eval_cf_normalized,
                            #             clean_prompts=P_eval,
                            #             cf_prompts=CF_PROMPTS_DICT[attribute],
                            #             clean_feature_idxs=P_eval_feature_idxs,
                            #             cf_feature_idxs=P_eval_cf_feature_idxs[attribute],
                            #             attribute=attribute,
                            #             high_f1_features_dict=top_f1_features,
                            #             normalization_scale=scale,
                            #             num_exchange=num_exchange,
                            #         )
                            #         pbar.set_description(f'Done with {node.displayname} {attribute}')

                            ### interp-agnostic edits
                            for attribute in [('io',), ('io_pos',)]:
                                for num_exchange in (4, 8, 16):
                                    A_eval_cf = A_EVAL_CF_DICT[attribute][node]
                                    A_eval_cf_normalized, _ = normalize_activations(A=A_eval_cf, scale=scale)
                                    A_edited, best_features, best_scores, edited_clean, edited_cf = get_edit_using_sae_opt(
                                        A_clean_normalized=A_eval_normalized,
                                        A_cf_normalized=A_eval_cf_normalized,
                                        encoder=encoder,
                                        num_exchange=num_exchange,
                                        batch_size=200,
                                        normalization_scale=scale,
                                    )
                                    for thing in [A_edited, best_features, best_scores, edited_clean, edited_cf]:
                                        _hids_to_drop.append(thing.hid)
                                    
                            pbar.set_description(f'Done with {node.displayname} {l1_coeff}')

                        # all_metrics = [elt for x in metrics_list for elt in x]
                        # metrics_df = pd.DataFrame(all_metrics)
                        # metrics_df['l1_coeff'] = l1_coeff
                        # metrics_df['lr'] = lr
                        # metrics_df['dict_mult'] = dict_mult
                        # metrics_df['node'] = node.displayname
                        # metrics_df['batch_size'] = batch_size
                        # metrics_dfs.append(metrics_df)
        # persist this node's results, then clear storage.atoms before the next node
        storage.commit()
        # for hid in _hids_to_drop:
        #     if hid in storage.atoms.cache:
        #         del storage.atoms.cache[hid]
        storage.atoms.clear()

    # While metric collection above is commented out, `metrics_dfs` stays empty
    # and a bare pd.concat would raise "No objects to concatenate" only after
    # the whole sweep has finished. Fall back to an empty frame in that case;
    # re-enable the metrics block to get the columns the plotting cells expect.
    metrics_df_vanilla = pd.concat(metrics_dfs) if metrics_dfs else pd.DataFrame()

In [None]:
# Clear storage.atoms (NOTE(review): presumably frees objects cached in memory — confirm against the Storage API).
storage.atoms.clear()

In [None]:
# Commit the storage manually (NOTE(review): presumably to persist results after an interrupted run — confirm).
storage.commit()

In [None]:
# Line chart of ld_loss against epoch for the vanilla SAEs, restricted to
# epochs strictly between 1000 and 2000 at l1_coeff == 2.5; one colour per node.
(
    alt.Chart(metrics_df_vanilla.query('1000 < epoch < 2000 and l1_coeff == 2.5'))
    .mark_line()
    .encode(x='epoch', y='ld_loss', color='node:N', strokeDash='l1_coeff:N')
    .properties(width=800, height=400)
)

In [None]:
# Number of distinct nodes that have at least one row with epoch > 0 and
# l2_loss < 30.0. Uses the concatenated vanilla frame: a bare `metrics_df` is
# not defined by any earlier cell on a fresh kernel.
# NOTE(review): confirm that `metrics_df_vanilla` is the frame intended here.
metrics_df_vanilla.query('epoch > 0 and l2_loss < 30.0').node.nunique()

## Gated SAEs

In [13]:
from torch.optim import Adam

# Gated SAE sweep. For every node and L1 coefficient, training runs in
# segments between consecutive CHECKPOINT_STEPS, threading the encoder /
# optimizer / scheduler state from one segment into the next. After each
# segment the interp-agnostic edits are computed with the current encoder.
# The metric, F1-feature and interp-based-edit stages are currently disabled
# (commented out).
with storage:
    metrics_dfs = []
    for node in tqdm(NODES):
        A_train = A_TRAIN_DICT[node]
        A_train_normalized, scale = normalize_activations(A=A_train)

        A_eval = A_EVAL_DICT[node]
        # reuse the scale computed on the training activations
        A_eval_normalized, _ = normalize_activations(A=A_eval, scale=scale)

        for l1_coeff in (0.5, 1.0, 2.5, DefaultConfig.L1_COEFF):
            for lr in (DefaultConfig.LR, ):
                for batch_size in (512, ):
                    for dict_mult in (8, ):
                        # fresh training state for each hyperparameter configuration
                        encoder_state_dict = None
                        optimizer_state_dict = None
                        scheduler_state_dict = None
                        metrics_list = []
                        # 64 is the activation width passed as d_activation below
                        # (presumably the per-head dimension of the model — confirm)
                        d_hidden = dict_mult * 64
                        pbar = tqdm(list(zip(CHECKPOINT_STEPS, CHECKPOINT_STEPS[1:])), disable=True)
                        for start_epoch, end_epoch in pbar:
                            encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_gated(
                                A=A_train_normalized,
                                start_epoch=start_epoch,
                                d_hidden=d_hidden,
                                end_epoch=end_epoch,
                                batch_size=batch_size,
                                encoder_state_dict=encoder_state_dict,
                                optimizer_state_dict=optimizer_state_dict,
                                scheduler_state_dict=scheduler_state_dict,
                                l1_coeff=l1_coeff,
                                lr=lr,
                                resample_epochs=RESAMPLE_EPOCHS,
                                final_decay_start=FINAL_DECAY_START,
                                final_decay_end=FINAL_DECAY_END,
                            )

                            # metrics = storage.unwrap(metrics)

                            # ### compute the logitdiff recovered metric
                            encoder = get_gated(d_activation=64, d_hidden=d_hidden, encoder_state_dict=encoder_state_dict)
                            # logitdiff_loss = get_logitdiff_loss(
                            #     encoder=encoder, batch_size=100, 
                            #     encoder_normalization_scale=scale,
                            #     prompts=P_eval,
                            #     clean_ld=CLEAN_LD_EVAL_MEAN,
                            #     mean_ablated_ld=MEAN_ABLATED_LD_EVAL_DICT[node],
                            #     node=node,
                            # )
                            # for elt in metrics: elt['ld_loss'] = storage.unwrap(logitdiff_loss)

                            # metrics_list.append(metrics)

                            ### compute high F1-score features
                            # top_f1_features, top_f1_scores = get_high_f1_features(
                            #     encoder=encoder,
                            #     attributes=[('io',), ('s',), ('io_pos',), ],
                            #     A_normalized=A_eval_normalized,
                            #     prompt_feature_idxs=P_eval_feature_idxs,
                            #     topk=d_hidden,
                            # )

                            # autointerp_fast(
                            #     A_normalized=A_eval_normalized,
                            #     encoder=encoder,
                            #     features=[('io',), ('s',), ('io_pos',), ],
                            #     features_to_group=[('io',), ('s',), ],
                            #     max_group_size=10,
                            #     prompt_feature_idxs=P_eval_feature_idxs,
                            #     feature_batch_size=None,
                            # )

                            # for attribute in [('io',), ('io_pos',), ]:
                            #     for num_exchange in (2, 4, 8):
                            #         A_eval_cf = A_EVAL_CF_DICT[attribute][node]
                            #         A_eval_cf_normalized, _ = normalize_activations(A=A_eval_cf, scale=scale)
                            #         A_edited, features_removed, features_added = get_edit_using_f1_scores(
                            #             encoder=encoder,
                            #             A_clean_normalized = A_eval_normalized,
                            #             A_cf_normalized = A_eval_cf_normalized,
                            #             clean_prompts=P_eval,
                            #             cf_prompts=CF_PROMPTS_DICT[attribute],
                            #             clean_feature_idxs=P_eval_feature_idxs,
                            #             cf_feature_idxs=P_eval_cf_feature_idxs[attribute],
                            #             attribute=attribute,
                            #             high_f1_features_dict=top_f1_features,
                            #             normalization_scale=scale,
                            #             num_exchange=num_exchange,
                            #         )
                            #         pbar.set_description(f'Done with {node.displayname} {attribute}')

                            ### interp-agnostic edits
                            for attribute in [('io',), ('io_pos',)]:
                                for num_exchange in (4, 8, 16):
                                    A_eval_cf = A_EVAL_CF_DICT[attribute][node]
                                    A_eval_cf_normalized, _ = normalize_activations(A=A_eval_cf, scale=scale)
                                    A_edited, best_features, best_scores, edited_clean, edited_cf = get_edit_using_sae_opt(
                                        A_clean_normalized=A_eval_normalized,
                                        A_cf_normalized=A_eval_cf_normalized,
                                        encoder=encoder,
                                        num_exchange=num_exchange,
                                        batch_size=200,
                                        normalization_scale=scale,
                                    )
                                    # for thing in [A_edited, best_features, best_scores, edited_clean, edited_cf]:
                                    #     _hids_to_drop.append(thing.hid)
                                    
                            pbar.set_description(f'Done with {node.displayname} {l1_coeff}')

                        # all_metrics = [elt for x in metrics_list for elt in x]
                        # metrics_df = pd.DataFrame(all_metrics)
                        # metrics_df['l1_coeff'] = l1_coeff
                        # metrics_df['lr'] = lr
                        # metrics_df['dict_mult'] = dict_mult
                        # metrics_df['node'] = node.displayname
                        # metrics_df['batch_size'] = batch_size
                        # metrics_dfs.append(metrics_df)
        # persist this node's results, then clear storage.atoms before the next node
        storage.commit()
        storage.atoms.clear()

    # While metric collection above is commented out, `metrics_dfs` stays empty
    # and a bare pd.concat would raise "No objects to concatenate" only after
    # the whole sweep has finished. Fall back to an empty frame in that case;
    # re-enable the metrics block to get real metric columns.
    metrics_df_gated = pd.concat(metrics_dfs) if metrics_dfs else pd.DataFrame()

 91%|█████████ | 50/55 [5:52:35<44:32, 534.55s/it]  

In [None]:
# NOTE(review): `features_removed` is only assigned by the interp-based edit code that is commented out in the SAE training cells, so this cell relies on leftover kernel state.
features_removed

In [None]:
# Commit the storage manually (NOTE(review): presumably to persist results after the interrupted gated run — confirm).
storage.commit()

In [None]:
# Re-create the storage object on the same database file, replacing the instance used so far (NOTE(review): presumably to start from a clean in-memory state — confirm).
storage = Storage(db_path=DB_PATH)

## Attribution SAEs

### Gradient collection

In [None]:
# Collect gradients for all training prompts in n_batches slices, then
# concatenate the per-slice results into one tensor per node.
with storage:
    # When True, commit and clear storage.atoms after every slice (for runs
    # that actually compute the gradients rather than reload them).
    COMPUTING = False
    P_train_raw = storage.unwrap(P_train)
    n_total = len(P_train_raw)
    n_batches = 100
    prompts_per_batch = n_total // n_batches
    grads_parts = []

    for i in range(n_batches):
        # print(f'Batch {i}/{n_batches}')
        start = i * prompts_per_batch
        # The last slice absorbs the remainder, so no prompts are dropped when
        # n_total is not a multiple of n_batches (identical slices otherwise).
        end = n_total if i == n_batches - 1 else (i + 1) * prompts_per_batch
        prompts = P_train_raw[start:end]
        grads_part = collect_gradients(
            prompts=prompts,
            nodes=NODES,
            batch_size=20,
        )
        if COMPUTING:
            storage.commit()
            storage.atoms.clear()
        grads_parts.append(grads_part)
# now, concatenate the parts
# unwrap every part once instead of once per node
grads_parts_raw = [storage.unwrap(x) for x in grads_parts]
grads = {node: torch.cat([part[node] for part in grads_parts_raw]).cuda() for node in tqdm(NODES)}

### Training

In [None]:
# Check that CUDA is usable; the gradient tensors above are moved to the GPU with .cuda().
torch.cuda.is_available()

In [None]:
from torch.optim import Adam

# Attribution SAE sweep. For every node and L1 coefficient, training runs in
# segments between consecutive CHECKPOINT_STEPS, threading the encoder /
# optimizer / scheduler state from one segment into the next. After each
# segment the logit-difference loss of the current encoder is attached to that
# segment's metrics. One metrics frame is built per configuration and all
# frames are concatenated into `metrics_df_attribution`.
with storage:
    metrics_dfs = []
    for node in NODES:
        A_train = A_TRAIN_DICT[node]
        A_grad = grads[node]
        A_train_normalized, scale = normalize_activations(A=A_train)
        # gradients are normalized with the scale computed on the activations
        A_grad_normalized = normalize_grad(A_grad=A_grad, scale=scale)

        for l1_coeff in (DefaultConfig.L1_COEFF, 0.5, 1.0, 2.5 ):
            for lr in (DefaultConfig.LR, ):
                for batch_size in (512, ):
                    for dict_mult in (8, ):
                        # fresh training state for each hyperparameter configuration
                        encoder_state_dict = None
                        optimizer_state_dict = None
                        scheduler_state_dict = None
                        metrics_list = []
                        # 64 is the activation width passed as d_activation below
                        # (presumably the per-head dimension of the model — confirm)
                        d_hidden = dict_mult * 64
                        for start_epoch, end_epoch in zip(CHECKPOINT_STEPS, CHECKPOINT_STEPS[1:]):
                            encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_attribution(
                                A=A_train_normalized,
                                A_grad=A_grad_normalized,
                                start_epoch=start_epoch,
                                d_hidden=d_hidden,
                                end_epoch=end_epoch,
                                batch_size=batch_size,
                                encoder_state_dict=encoder_state_dict,
                                optimizer_state_dict=optimizer_state_dict,
                                scheduler_state_dict=scheduler_state_dict,
                                l1_coeff=l1_coeff,
                                # we scale these losses based on the ratio to the other losses, so that they are on the same scale at the start of training
                                attribution_sparsity_penalty=1000.0,
                                unexplained_attribution_penalty=1000.0,
                                lr=lr,
                                resample_epochs=RESAMPLE_EPOCHS,
                                final_decay_start=FINAL_DECAY_START,
                                final_decay_end=FINAL_DECAY_END,
                            )
                            metrics = storage.unwrap(metrics)

                            ### compute the logitdiff recovered metric
                            encoder = get_attribution(d_activation=64, d_hidden=d_hidden, encoder_state_dict=encoder_state_dict)
                            logitdiff_loss = get_logitdiff_loss(
                                encoder=encoder, batch_size=100, 
                                encoder_normalization_scale=scale,
                                prompts=P_eval,
                                clean_ld=CLEAN_LD_EVAL_MEAN,
                                mean_ablated_ld=MEAN_ABLATED_LD_EVAL_DICT[node],
                                node=node,
                            )
                            # every metrics record of this segment gets the same ld_loss
                            ld_loss_value = storage.unwrap(logitdiff_loss)
                            for elt in metrics:
                                elt['ld_loss'] = ld_loss_value

                            # Append exactly once per segment: a second append of
                            # the same list would duplicate every row in metrics_df.
                            metrics_list.append(metrics)
                        # flatten the per-segment metric lists into one list of records
                        all_metrics = [elt for x in metrics_list for elt in x]
                        metrics_df = pd.DataFrame(all_metrics)
                        metrics_df['l1_coeff'] = l1_coeff
                        metrics_df['lr'] = lr
                        metrics_df['dict_mult'] = dict_mult
                        metrics_df['node'] = node.displayname
                        metrics_df['batch_size'] = batch_size
                        # metrics_df['total_loss'] = (metrics_df['l0_loss'] + metrics_df['l1_loss'] * l1_coeff + metrics_df['attribution_sparsity_loss'] * 1000.0 +
                        #                             metrics_df['unexplained_attribution_loss'] * 1000.0)
                        metrics_dfs.append(metrics_df)
        # persist this node's results before moving on to the next node
        storage.commit()

    metrics_df_attribution = pd.concat(metrics_dfs)

In [None]:
# Stack the vanilla metrics (l1_coeff == 5.0, first attribution node only) on
# top of the attribution metrics, tagging each row with its SAE type.
node = metrics_df_attribution.node.unique()[0]
x = metrics_df_vanilla.query('l1_coeff == 5.0')
# .copy() so the column assignment below writes to an independent frame rather
# than to a slice of metrics_df_vanilla (avoids SettingWithCopyWarning)
x = x[x.node == node].copy()
x['type'] = 'vanilla'
# NOTE(review): unlike `x`, `y` is not restricted to `node` or to a single
# l1_coeff — confirm this asymmetry is intended.
y = metrics_df_attribution.copy()
y['type'] = 'attribution'
combined_df = pd.concat([x, y], ignore_index=True)

# Interpretability

# Sparse controllability