In [23]:
from ioi_utils import *
from sae_variants import *
from training import *
from mandala._next.imports import *
from mandala._next.common_imports import *

In [2]:
from circuit_utils import *
# Print tensor values in fixed-point notation rather than scientific notation.
torch.set_printoptions(sci_mode=False)

# If a `model` variable is already defined in this namespace (e.g. the cell is
# being re-run in a live kernel), register it in MODELS under MODEL_ID.
model_is_defined = 'model' in locals()
if model_is_defined:
    MODELS[MODEL_ID] = model

# Lookup tables from the short identifiers used in code to human-readable
# labels (the `_FIG` suffix suggests they are meant for figures).

# Head-class abbreviation -> label.
HEAD_CLASS_FIG = dict(
    nm='Name Mover',
    bnm='Backup Name Mover',
    ind='Induction',
    nnm='Negative Name Mover',
    si='S-Inhibition',
    dt='Duplicate Token',
    pt='Previous Token',
)

# Single-letter attention component -> label.
COMPONENT_NAME_FIG = dict(
    k='Key',
    v='Value',
    q='Query',
    z='Attn Output',
)

# Cross-section identifier ("<head classes>@<component>") -> label.
# The keys contain '+' and '@', so they are given as explicit pairs.
CROSS_SECTION_FIG = dict([
    ('ind+dt@z', 'Ind+DT out'),
    ('nm+bnm@q', '(B)NM q'),
    ('nm+bnm@qk', '(B)NM qk'),
    ('nm+bnm@z', '(B)NM out'),
    ('si@v', 'S-I v'),
    ('si@z', 'S-I out'),
])

c = Circuit()

# For each cross-section label, the expression that builds its node collection.
# Each builder is a lambda so that every entry of `paper_cross_sections` gets a
# freshly evaluated result rather than a shared object.
_CROSS_SECTION_NODES = {
    'nm+bnm@z': lambda: c.zs(c.nm + c.bnm),
    'nm+bnm@qk': lambda: c.qs(c.nm + c.bnm) + c.ks(c.nm + c.bnm),
    'nm+bnm@q': lambda: c.qs(c.nm + c.bnm),
    'si@v': lambda: c.vs(c.si),
    'si@z': lambda: c.zs(c.si),
    'ind+dt@z': lambda: c.zs(c.ind) + c.zs(c.dt),
}

# (feature subset, cross-section label) pairs, in the order they appear in
# `paper_cross_sections`.
_CROSS_SECTION_SPECS = [
    # IO
    (('io',), 'nm+bnm@z'),
    (('io',), 'nm+bnm@qk'),
    (('io',), 'nm+bnm@q'),
    # S
    (('s',), 'nm+bnm@qk'),
    (('s',), 'nm+bnm@q'),
    (('s',), 'si@v'),
    (('s',), 'si@z'),
    (('s',), 'ind+dt@z'),
    # Pos
    (('io_pos',), 'nm+bnm@qk'),
    (('io_pos',), 'nm+bnm@q'),
    (('io_pos',), 'si@z'),
    (('io_pos',), 'si@v'),
    # Pos + S
    (('io_pos', 's'), 'nm+bnm@qk'),
    (('io_pos', 's'), 'si@z'),
    (('io_pos', 's'), 'si@v'),
    (('io_pos', 's'), 'ind+dt@z'),
    # NOTE(review): this entry uses only 'io_pos' although it sits in the
    # "Pos + S" group; its position is kept as in the original list.
    (('io_pos',), 'ind+dt@z'),
    # All
    (('io', 'io_pos', 's'), 'nm+bnm@qk'),
]

# Each element is a (nodes, feature subset, cross-section label) triple.
paper_cross_sections = [
    (_CROSS_SECTION_NODES[label](), features, label)
    for features, label in _CROSS_SECTION_SPECS
]

# Cross-section identifier -> display name, listed as (identifier, name) pairs.
# Insertion order is significant for anything that iterates over this dict.
_LOCATION_DISPLAYNAME_PAIRS = (
    ('nm+bnm@z', '(B)NM out'),
    ('nm+bnm@qk', '(B)NM qk'),
    ('nm+bnm@q', '(B)NM q'),
    ('si@v', 'S-I v'),
    ('si@z', 'S-I out'),
    ('ind+dt@z', 'Ind+DT out'),
)
locations_displaynames = {
    location: displayname for location, displayname in _LOCATION_DISPLAYNAME_PAIRS
}

could not find model


In [3]:
import os

# Path of the database file handed to `Storage(db_path=...)` in this notebook.
# It defaults to the original machine-local path; set the PAPER_SPRINT_DB_PATH
# environment variable to run the notebook elsewhere without editing this cell.
# NOTE: a later cell deletes any existing file at this path, so only point it
# at a throwaway database.
DB_PATH = os.environ.get(
    'PAPER_SPRINT_DB_PATH',
    '/media/amakelov/SanDisk1TB/paper_sprint/test.db',
)

In [4]:
# Start from a clean slate: delete any database file left over from a previous
# run. Attempting the removal directly (rather than checking for existence
# first) avoids a race between the check and the delete, and also clears a
# dangling symlink, which `os.path.exists` would report as missing.
try:
    os.remove(DB_PATH)
except FileNotFoundError:
    # Nothing to clean up.
    pass

In [5]:
# Create the Storage object, pointing it at the database file at DB_PATH.
# (`Storage` presumably comes from the mandala star-imports — TODO confirm.)
storage = Storage(db_path=DB_PATH)

In [6]:
# Obtain the model and, in the same statement, register it in MODELS under
# MODEL_ID (targets are bound left to right: `model` first, then MODELS).
model = MODELS[MODEL_ID] = get_model()



Loaded pretrained model gpt2-small into HookedTransformer


In [7]:
# Raise the verbosity of the logger exported by mandala to DEBUG.
import logging

from mandala._next.common_imports import logger

logger.setLevel(logging.DEBUG)

# Preparing datasets

In [None]:
# Keys of `c.nodes`; activations are requested for every one of them below.
circuit_nodes = [*c.nodes.keys()]

# Keyword arguments shared by every `run_with_cache` call in this cell.
_RUN_WITH_CACHE_KWARGS = dict(
    nodes=circuit_nodes,
    batch_size=100,
    model_id=MODEL_ID,
    verbose=True,
)

with storage:

    # ------------------------------------------------------------------
    # Prompts and activations for training supervised features
    # ------------------------------------------------------------------
    P_train = generate_prompts(
        distribution=full_distribution,
        patterns=['ABB', 'BAB'],
        prompts_per_pattern=10_000,
        random_seed=0,
    )
    N_TRAIN = len(storage.unwrap(P_train))
    As_train = run_with_cache(prompts=P_train, **_RUN_WITH_CACHE_KWARGS)
    # Pair each node with its entry of `As_train` (assumed to be one entry per
    # node, in `circuit_nodes` order).
    A_TRAIN_DICT = dict(zip(circuit_nodes, As_train))

    # ------------------------------------------------------------------
    # Prompts and activations for editing
    # ------------------------------------------------------------------
    # Split the names in two: the first half is used for the base (edited)
    # prompts, the second half as the source of replacement names.
    # NOTE(review): the split index comes from len(NAMES) but is applied to
    # `full_distribution.names` — presumably the same list; confirm.
    N_NAMES = len(NAMES)
    editing_base_distribution = copy.deepcopy(full_distribution)
    editing_base_distribution.names = editing_base_distribution.names[:N_NAMES // 2]
    editing_source_distribution = copy.deepcopy(full_distribution)
    editing_source_distribution.names = editing_source_distribution.names[N_NAMES // 2:]

    P_edit = generate_prompts(
        distribution=editing_base_distribution,
        patterns=['ABB', 'BAB'],
        prompts_per_pattern=2500,
        random_seed=1,
    )
    As_to_edit = run_with_cache(prompts=P_edit, **_RUN_WITH_CACHE_KWARGS)
    A_EDIT_DICT = dict(zip(circuit_nodes, As_to_edit))

    N_EDIT = len(storage.unwrap(P_edit))
    N_NAMES_EDIT_SOURCE = len(editing_source_distribution.names)

    # ------------------------------------------------------------------
    # Counterfactual prompts and activations
    # ------------------------------------------------------------------
    # Currently disabled subsets: ('s', 'io_pos',), ('io', 'io_pos'),
    # ('s', 'io',), ('io_pos', 's', 'io',)
    FEATURE_SUBSETS = [('io_pos',), ('s',), ('io',)]

    # IO targets are sampled from the first half of the source names and S
    # targets from the second half; samples are drawn anew for every subset.
    CF_PROMPTS_DICT = {
        feature_subset: get_cf_prompts(
            prompts=P_edit,
            features=feature_subset,
            io_targets=generate_name_samples(
                N_EDIT,
                editing_source_distribution.names[:N_NAMES_EDIT_SOURCE // 2],
            ),
            s_targets=generate_name_samples(
                N_EDIT,
                editing_source_distribution.names[N_NAMES_EDIT_SOURCE // 2:],
            ),
        )
        for feature_subset in FEATURE_SUBSETS
    }

    # feature subset -> {node: activations of the counterfactual prompts}
    As_counterfactual = {}
    for feature_subset, cf_prompts in tqdm(CF_PROMPTS_DICT.items()):
        cf_activations = run_with_cache(prompts=cf_prompts, **_RUN_WITH_CACHE_KWARGS)
        As_counterfactual[feature_subset] = {
            node: cf_activations[i] for i, node in enumerate(circuit_nodes)
        }

# Computing supervised features

In [None]:
def _mse_code_getter(features, A, prompts):
    """Call `train_mse_codes` on the given inputs, always with `manual_bias=True`."""
    return train_mse_codes(features=features, A=A, prompts=prompts, manual_bias=True)


with storage:
    # Both dicts are keyed by (node, node_parametrization, codes_type).
    SUPERVISED_FEATURES_DICT = {}
    SUPERVISED_RECONSTRUCTIONS_DICT = {}

    # Only one value of each option is active; the alternatives are noted in
    # the trailing comments.
    eventually_options = ['independent']  # 'coupled', 'names'
    codes_types = ('mean',)  # 'mse'

    for node, A in tqdm(list(zip(circuit_nodes, As_train))):
        for eventually in eventually_options:
            for codes_type in codes_types:
                node_parametrization = get_parametrization(
                    node=node,
                    eventually=eventually,
                    use_names=(eventually == 'names'),
                )
                node_features = FEATURE_CONFIGURATIONS[node_parametrization]

                # 'mean' uses get_mean_codes; anything else falls back to the
                # MSE-trained codes.
                if codes_type == 'mean':
                    code_getter = get_mean_codes
                else:
                    code_getter = _mse_code_getter

                codes, reconstructions = code_getter(
                    features=node_features,
                    A=A,
                    prompts=P_train,
                )

                result_key = (node, node_parametrization, codes_type)
                SUPERVISED_FEATURES_DICT[result_key] = codes
                SUPERVISED_RECONSTRUCTIONS_DICT[result_key] = reconstructions
   

# Training SAEs

In [None]:
from torch.optim import Adam
# import cosine scheduler 
from torch.optim.lr_scheduler import CosineAnnealingLR

# Smoke test of checkpointed SAE training: train a small autoencoder on
# synthetic data in consecutive segments, feeding the state dicts returned by
# one segment into the next.

# Sizes shared by the encoder, the synthetic data and `train_vanilla`; kept in
# one place because they have to agree with each other.
D_ACTIVATION = 100
DICT_MULT = 10
N_SYNTHETIC_EXAMPLES = 100
DATA_SEED = 0

# Consecutive pairs (0 -> 100, 100 -> 200) define the training segments.
checkpoint_steps = [0, 100, 200]

# Build encoder / optimizer / scheduler once, only to obtain initial state dicts.
encoder = VanillaAutoEncoder(d_activation=D_ACTIVATION, dict_mult=DICT_MULT).cuda()
optim = Adam(encoder.parameters())
scheduler = MidTrainingWarmupScheduler(optimizer=optim, num_warmup_steps=100, last_epoch=-1)
encoder_state_dict = encoder.state_dict()
optimizer_state_dict = optim.state_dict()
scheduler_state_dict = scheduler.state_dict()

# Draw the synthetic data once, from a dedicated seeded generator, so that
# every segment resumes training on the same dataset and re-runs are
# reproducible. A separate generator leaves the global RNG state untouched.
# NOTE(review): assumes `train_vanilla` does not modify `A` in place — confirm.
_data_rng = torch.Generator().manual_seed(DATA_SEED)
A_synthetic = torch.randn(N_SYNTHETIC_EXAMPLES, D_ACTIVATION, generator=_data_rng).cuda()

# Metrics of every segment, in order; `metrics` still holds the last one.
metrics_per_segment = []
for start_epoch, end_epoch in zip(checkpoint_steps, checkpoint_steps[1:]):
    encoder_state_dict, optimizer_state_dict, scheduler_state_dict, metrics = train_vanilla(
        A=A_synthetic,
        d_activation=D_ACTIVATION,
        dict_mult=DICT_MULT,
        start_epoch=start_epoch,
        end_epoch=end_epoch,
        encoder_state_dict=encoder_state_dict,
        optimizer_state_dict=optimizer_state_dict,
        scheduler_state_dict=scheduler_state_dict,
        batch_size=10,
        l1_coeff=1.0,
        resample_every=50,
    )
    metrics_per_segment.append(metrics)