In [9]:
from ioi_utils import *
from sae_variants import *
from training import *
from mandala._next.imports import *
from mandala._next.common_imports import *

# Circuit setup

In [None]:
from circuit_utils import *
torch.set_printoptions(sci_mode=False)
# If a model object is still alive in this kernel, register it under MODEL_ID
# (presumably so helpers that look models up in MODELS reuse it instead of reloading).
if 'model' in locals():
    MODELS[MODEL_ID] = model

# Figure labels for head classes, keyed by the short class names used by Circuit.
HEAD_CLASS_FIG = {
    'nm': 'Name Mover',
    'bnm': 'Backup Name Mover',
    'ind': 'Induction',
    'nnm': 'Negative Name Mover',
    'si': 'S-Inhibition',
    'dt': 'Duplicate Token',
    'pt': 'Previous Token',
}

# Figure labels for attention-head components.
COMPONENT_NAME_FIG = {
    'k': 'Key',
    'v': 'Value',
    'q': 'Query',
    'z': 'Attn Output',
}

# Figure labels for the "location" keys used in `paper_cross_sections`.
CROSS_SECTION_FIG = {
    'ind+dt@z': 'Ind+DT out',
    'nm+bnm@q': '(B)NM q',
    'nm+bnm@qk': '(B)NM qk',
    'nm+bnm@z': '(B)NM out',
    'si@v': 'S-I v',
    'si@z': 'S-I out',
}

c = Circuit()
# Each entry: (nodes, feature subset, location key into CROSS_SECTION_FIG).
paper_cross_sections = [
    # IO
    (c.zs(c.nm + c.bnm), ('io',), 'nm+bnm@z'),
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io',), 'nm+bnm@qk'),
    (c.qs(c.nm + c.bnm), ('io',), 'nm+bnm@q'),
    # S
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('s',), 'nm+bnm@qk'),
    (c.qs(c.nm + c.bnm), ('s',), 'nm+bnm@q'),
    (c.vs(c.si), ('s',), 'si@v'),
    (c.zs(c.si), ('s',), 'si@z'),
    (c.zs(c.ind) + c.zs(c.dt), ('s',), 'ind+dt@z'),
    # Pos
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io_pos',), 'nm+bnm@qk'),
    (c.qs(c.nm + c.bnm), ('io_pos',), 'nm+bnm@q'),
    (c.zs(c.si), ('io_pos',), 'si@z'),
    (c.vs(c.si), ('io_pos',), 'si@v'),
    # Pos + S
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io_pos', 's'), 'nm+bnm@qk'),
    (c.zs(c.si), ('io_pos', 's'), 'si@z'),
    (c.vs(c.si), ('io_pos', 's'), 'si@v'),
    (c.zs(c.ind) + c.zs(c.dt), ('io_pos', 's'), 'ind+dt@z'),
    (c.zs(c.ind) + c.zs(c.dt), ('io_pos',), 'ind+dt@z'),
    # All
    (c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm), ('io', 'io_pos', 's'), 'nm+bnm@qk'),
]

# Same labels as CROSS_SECTION_FIG (derived from it so the two cannot drift apart),
# keeping this particular key order for any code that iterates over it.
locations_displaynames = {
    loc: CROSS_SECTION_FIG[loc]
    for loc in ('nm+bnm@z', 'nm+bnm@qk', 'nm+bnm@q', 'si@v', 'si@z', 'ind+dt@z')
}
assert all(loc in locations_displaynames for _, _, loc in paper_cross_sections), \
    "every cross-section location needs a display name"

# Nodes whose activations are cached in the dataset cells. For the S-Inhibition
# value nodes only the 's2' sequence position is kept.
NODES = (
    c.zs(c.nm + c.bnm)
    + c.qs(c.nm + c.bnm)
    + c.zs(c.si)
    + [n for n in c.vs(c.si) if n.seq_pos == 's2']
    + c.zs(c.ind)
    + c.zs(c.dt)
    + c.ks(c.nm + c.bnm)
)

In [None]:
# Path of the memoization database. Set the IOI_DB_PATH environment variable to
# point at a different file instead of editing this machine-specific default.
import os
DB_PATH = os.environ.get('IOI_DB_PATH', '/media/amakelov/SanDisk1TB/paper_sprint/test.db')

In [None]:
# Storage object backed by the database file at DB_PATH; the `with storage:`
# cells below run their computations inside it and read results via `storage.unwrap`.
storage = Storage(
    db_path=DB_PATH,
)

In [None]:
# Load the model and register the same object in MODELS under MODEL_ID
# (presumably where helpers look it up).
model = MODELS[MODEL_ID] = get_model()

In [None]:
# Set the mandala logger's level to INFO (use logging.DEBUG instead for more
# verbose output while debugging).
from mandala._next.common_imports import logger
import logging
logger.setLevel(logging.INFO)

# Preparing datasets

In [None]:
def _cache_node_activations(prompts):
    """Run `run_with_cache` on `prompts` for every node in NODES.

    All activation caches in this notebook use identical settings; routing them
    through one helper keeps those (presumably memoized) calls consistent.
    """
    return run_with_cache(
        prompts=prompts,
        nodes=NODES,
        batch_size=100,
        model_id=MODEL_ID,
        verbose=True,
    )


with storage:

    ############################################################################
    ### prompt dataset for training supervised features
    ############################################################################
    P_train = generate_prompts(
        distribution=full_distribution,
        patterns=['ABB', 'BAB'],
        prompts_per_pattern=10_000,
        random_seed=0,
    )
    N_TRAIN = len(storage.unwrap(P_train))
    ### activations for training supervised features
    As_train = _cache_node_activations(P_train)
    A_TRAIN_DICT = dict(zip(NODES, As_train))

    ### precompute the mean logit difference for clean training data
    logits_train_clean = run_with_hooks(prompts=P_train, hooks=[], batch_size=200,)
    # Unwrap once; column 0 minus column 1 is the per-prompt logit difference.
    _logits_clean = storage.unwrap(logits_train_clean)
    CLEAN_LD_MEAN = (_logits_clean[:, 0] - _logits_clean[:, 1]).mean().item()

    ### precompute the mean-ablated logit difference when ablating each node
    A_TRAIN_MEAN_DICT = {node: get_dataset_mean(A) for node, A in A_TRAIN_DICT.items()}

    MEAN_ABLATED_LD_DICT = {}
    for node in A_TRAIN_DICT:
        MEAN_ABLATED_LD_DICT[node] = compute_mean_ablated_lds(
            node=node, prompts=P_train, A_mean=A_TRAIN_MEAN_DICT[node], batch_size=200,
        )

    ############################################################################
    ### prompt dataset for editing
    ############################################################################
    # Split the distribution's own name list into two disjoint halves: base prompts
    # use the first half, counterfactual (source) names come from the second half.
    # The split uses the distribution's length rather than len(NAMES), so the halves
    # stay disjoint and exhaustive even if the two lists differ.
    N_NAMES = len(NAMES)
    _n_distribution_names = len(full_distribution.names)
    editing_base_distribution = copy.deepcopy(full_distribution)
    editing_base_distribution.names = editing_base_distribution.names[:_n_distribution_names // 2]
    editing_source_distribution = copy.deepcopy(full_distribution)
    editing_source_distribution.names = editing_source_distribution.names[_n_distribution_names // 2:]

    P_edit = generate_prompts(
        distribution=editing_base_distribution,
        patterns=['ABB', 'BAB'],
        prompts_per_pattern=2500,
        random_seed=1,
    )
    As_to_edit = _cache_node_activations(P_edit)
    A_EDIT_DICT = dict(zip(NODES, As_to_edit))

    N_EDIT = len(storage.unwrap(P_edit))
    N_NAMES_EDIT_SOURCE = len(editing_source_distribution.names)
    # IO and S counterfactual names are drawn from disjoint halves of the source names.
    _io_target_names = editing_source_distribution.names[:N_NAMES_EDIT_SOURCE // 2]
    _s_target_names = editing_source_distribution.names[N_NAMES_EDIT_SOURCE // 2:]

    ############################################################################
    ### Compute counterfactual prompts and activations
    ############################################################################
    # Further subsets that can be enabled: ('s', 'io_pos'), ('io', 'io_pos'),
    # ('s', 'io'), ('io_pos', 's', 'io').
    FEATURE_SUBSETS = [('io_pos',), ('s',), ('io',)]

    CF_PROMPTS_DICT = {}
    for feature_subset in FEATURE_SUBSETS:
        CF_PROMPTS_DICT[feature_subset] = get_cf_prompts(
            prompts=P_edit,
            features=feature_subset,
            io_targets=generate_name_samples(N_EDIT, _io_target_names),
            s_targets=generate_name_samples(N_EDIT, _s_target_names),
        )
    ### Compute counterfactual activations, keyed by feature subset, then by node
    As_counterfactual = {}
    for feature_subset, cf_prompts in tqdm(CF_PROMPTS_DICT.items()):
        As_counterfactual[feature_subset] = dict(zip(NODES, _cache_node_activations(cf_prompts)))

In [None]:
# Inspect what the storage currently holds (the result is shown as this cell's output).
storage.cache_info()

# Computing supervised features

In [None]:
def _mse_codes_with_manual_bias(features, A, prompts):
    """Fit codes with `train_mse_codes`, always passing `manual_bias=True`."""
    return train_mse_codes(features=features, A=A, prompts=prompts, manual_bias=True)


with storage:
    SUPERVISED_FEATURES_DICT = {}
    SUPERVISED_RECONSTRUCTIONS_DICT = {}
    # Both dicts are keyed by (node, parametrization, codes_type). Only the
    # 'independent' parametrization with mean codes is enabled; the 'coupled' /
    # 'names' parametrizations and 'mse' codes are left out of the grid.
    for node, A in tqdm(A_TRAIN_DICT.items()):
        for eventually in ['independent']:
            for codes_type in ('mean',):
                node_parametrization = get_parametrization(
                    node=node,
                    eventually=eventually,
                    use_names=eventually == 'names',
                )
                node_features = FEATURE_CONFIGURATIONS[node_parametrization]
                if codes_type == 'mean':
                    codes_fn = get_mean_codes
                else:
                    codes_fn = _mse_codes_with_manual_bias
                codes, reconstructions = codes_fn(
                    features=node_features,
                    A=A,
                    prompts=P_train,
                )
                result_key = (node, node_parametrization, codes_type)
                SUPERVISED_FEATURES_DICT[result_key] = codes
                SUPERVISED_RECONSTRUCTIONS_DICT[result_key] = reconstructions

# Training SAEs

In [10]:
### define a uniform schedule for all training runs

# use exponentially spread-out checkpoints for the very early stages of training
# measure right before resampling, as well as in the middle between resamplings
# measure before and after the final LR decay
# use two resampling stages, as it seems effects diminish after the first one
CHECKPOINT_STEPS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 500, 750, 1000, 1250, 1500, 2000]
RESAMPLE_EPOCHS = [501, 1001, ]
FINAL_DECAY_START = 1500 # decay the LR for the last 25% of training
FINAL_DECAY_END = CHECKPOINT_STEPS[-1]  # training ends at the last checkpoint

# Check the relationships described above, so editing one constant cannot
# silently break them.
assert CHECKPOINT_STEPS == sorted(set(CHECKPOINT_STEPS)), "checkpoints must be strictly increasing"
assert all(epoch - 1 in CHECKPOINT_STEPS for epoch in RESAMPLE_EPOCHS), \
    "there must be a checkpoint right before each resampling"
assert FINAL_DECAY_START in CHECKPOINT_STEPS and FINAL_DECAY_END in CHECKPOINT_STEPS, \
    "measure before and after the final LR decay"
assert CHECKPOINT_STEPS[0] < FINAL_DECAY_START < FINAL_DECAY_END

### hyperparam grids are defined inline in the training cells below

## Vanilla SAEs

In [None]:
from torch.optim import Adam

with storage:
    metrics_dfs = []
    for node in NODES:
        A_train = A_TRAIN_DICT[node]
        A_train_normalized, scale = normalize_activations(A=A_train)

        for l1_coeff in (DefaultConfig.L1_COEFF, ):
            for lr in (DefaultConfig.LR, ):
                for batch_size in (512, ):
                    for dict_mult in (8, ):
                        # Train in segments between consecutive checkpoints, feeding
                        # each segment's returned encoder/optimizer/scheduler state
                        # into the next segment.
                        encoder_state_dict = None
                        optimizer_state_dict = None
                        scheduler_state_dict = None
                        metrics_list = []
                        # NOTE(review): 64 is presumably the activation width (d_head) — confirm.
                        d_hidden = dict_mult * 64
                        for start_epoch, end_epoch in zip(CHECKPOINT_STEPS, CHECKPOINT_STEPS[1:]):
                            encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_vanilla(
                                A=A_train_normalized,
                                start_epoch=start_epoch,
                                d_hidden=d_hidden,
                                end_epoch=end_epoch,
                                batch_size=batch_size,
                                encoder_state_dict=encoder_state_dict,
                                optimizer_state_dict=optimizer_state_dict,
                                scheduler_state_dict=scheduler_state_dict,
                                l1_coeff=l1_coeff,
                                lr=lr,
                                resample_epochs=RESAMPLE_EPOCHS,
                                final_decay_start=FINAL_DECAY_START,
                                final_decay_end=FINAL_DECAY_END,
                            )
                            metrics_list.append(metrics)

                        # Build this run's metrics frame once, after all segments are
                        # done, so each metrics record appears exactly once in the
                        # final concatenation. Building and appending it per segment
                        # would add a cumulative copy at every checkpoint.
                        all_metrics = [elt for x in metrics_list for elt in storage.unwrap(x)]
                        metrics_df = pd.DataFrame(all_metrics)
                        metrics_df['l1_coeff'] = l1_coeff
                        metrics_df['lr'] = lr
                        metrics_df['dict_mult'] = dict_mult
                        metrics_df['node'] = node.displayname
                        metrics_df['batch_size'] = batch_size
                        metrics_dfs.append(metrics_df)

    # One row per metrics record, across all nodes and hyperparameter settings.
    metrics_df = pd.concat(metrics_dfs)

In [None]:
# Destructive maintenance step: deletes the `train_vanilla` calls gathered by
# `storage.cf(train_vanilla)`, then calls `storage.cleanup_refs()` (presumably to
# drop values no longer referenced). It is opt-in so that "Run All" does not wipe
# the SAE training results computed above. Set the flag to True to run it.
DELETE_VANILLA_CALLS = False
if DELETE_VANILLA_CALLS:
    cf = storage.cf(train_vanilla)
    cf.delete_calls()
    storage.cleanup_refs()

In [None]:
# L0 loss over training, one line per (node, batch_size) run. `detail` splits lines
# by node without adding a legend; without it, all nodes sharing a batch size would
# be joined into a single zig-zag line.
alt.Chart(metrics_df).mark_line().encode(
    x='epoch',
    y='l0_loss',
    color='batch_size:N',
    detail='node:N',
    tooltip=['node', 'epoch', 'l0_loss'],
).interactive().properties(width=800, height=400, title='l0_loss by epoch')

## Gated SAEs

In [None]:
from torch.optim import Adam

with storage:
    metrics_dfs = []
    # NOTE(review): only the first node is trained here (presumably a quick
    # experiment); iterate over NODES to cover all nodes.
    for node in NODES[:1]:
        A_train = A_TRAIN_DICT[node]
        A_train_normalized, scale = normalize_activations(A=A_train)

        for l1_coeff in (DefaultConfig.L1_COEFF, ):
            for lr in (DefaultConfig.LR, ):
                for batch_size in (512, ):
                    for dict_mult in (8, ):
                        # Train in segments between consecutive checkpoints, feeding
                        # each segment's returned encoder/optimizer/scheduler state
                        # into the next segment.
                        encoder_state_dict = None
                        optimizer_state_dict = None
                        scheduler_state_dict = None
                        metrics_list = []
                        # NOTE(review): 64 is presumably the activation width (d_head) — confirm.
                        d_hidden = dict_mult * 64
                        for start_epoch, end_epoch in zip(CHECKPOINT_STEPS, CHECKPOINT_STEPS[1:]):
                            encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_gated(
                                A=A_train_normalized,
                                start_epoch=start_epoch,
                                d_hidden=d_hidden,
                                end_epoch=end_epoch,
                                batch_size=batch_size,
                                encoder_state_dict=encoder_state_dict,
                                optimizer_state_dict=optimizer_state_dict,
                                scheduler_state_dict=scheduler_state_dict,
                                l1_coeff=l1_coeff,
                                lr=lr,
                                resample_epochs=RESAMPLE_EPOCHS,
                                final_decay_start=FINAL_DECAY_START,
                                final_decay_end=FINAL_DECAY_END,
                            )
                            metrics_list.append(metrics)

                        # Build this run's metrics frame once, after all segments are
                        # done, so each metrics record appears exactly once in the
                        # final concatenation. Building and appending it per segment
                        # would add a cumulative copy at every checkpoint.
                        all_metrics = [elt for x in metrics_list for elt in storage.unwrap(x)]
                        metrics_df = pd.DataFrame(all_metrics)
                        metrics_df['l1_coeff'] = l1_coeff
                        metrics_df['lr'] = lr
                        metrics_df['dict_mult'] = dict_mult
                        metrics_df['node'] = node.displayname
                        metrics_df['batch_size'] = batch_size
                        metrics_dfs.append(metrics_df)

    # One row per metrics record, across all trained nodes and hyperparameter settings.
    metrics_df = pd.concat(metrics_dfs)

## Anthropic April update

In [None]:
from torch.optim import Adam
# NOTE(review): CosineAnnealingLR is imported but not used in this cell.
from torch.optim.lr_scheduler import CosineAnnealingLR

# Scratch cell: runs `train_vanilla` in two segments (0-100, 100-200) on random
# CUDA data, starting from explicitly initialized encoder/optimizer/scheduler
# state dicts and passing each segment's returned state into the next.
# NOTE(review): this call passes `d_activation`, `dict_mult` and `resample_every`,
# whereas the training cells above call `train_vanilla` with `d_hidden`,
# `resample_epochs`, `lr`, `final_decay_start` and `final_decay_end`. One of the
# two call styles is presumably stale — confirm against `train_vanilla`'s signature.
# NOTE(review): requires a CUDA device (`.cuda()`), and `metrics` is never used.
checkpoint_steps = [0, 100, 200]

encoder = VanillaAutoEncoder(d_activation=100, dict_mult=10).cuda()
optim = Adam(encoder.parameters())
scheduler = MidTrainingWarmupScheduler(optimizer=optim, num_warmup_steps=100, last_epoch=-1)
encoder_state_dict = encoder.state_dict()
optimizer_state_dict = optim.state_dict()
scheduler_state_dict = scheduler.state_dict()
for start_epoch, end_epoch in zip(checkpoint_steps, checkpoint_steps[1:]):
    # A fresh random (100, 100) tensor is drawn for every segment.
    encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_vanilla(
        A=torch.randn(100, 100).cuda(),
        d_activation=100,
        dict_mult=10,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        encoder_state_dict=encoder_state_dict,
        optimizer_state_dict=optimizer_state_dict,
        scheduler_state_dict=scheduler_state_dict,
        batch_size=10,
        l1_coeff=1.0,
        resample_every=50,
    )

In [None]:
# Scratch arithmetic; the meaning of these constants is not documented here
# (presumably a time estimate — TODO label or remove). Evaluates to 18.33...
(55 * 20) / 60