# Setup

In [1]:
%load_ext autoreload
%autoreload 2
from common_imports import *
from tqdm import tqdm
from ravel_dataset_builder import Prompt
from utils.intervention_utils import find_positions_in_tokens
from ravel_dataset_builder import RAVELEntityPromptData, evaluate_completion
from utils.differential_binary_masking import DifferentialBinaryMasking
from tlens_utils import Node
from sparse_control_utils import *
from mandala.imports import sess
from tlens_utils import *
import torch.nn.functional as F

# Device the language model is loaded onto (used by get_model below).
MODEL_DEVICE = torch.device("cuda:0")
# Second GPU. The original note here reads "Gemma-2B is too big"; presumably the
# larger model leaves too little memory on cuda:0 for other work — TODO confirm.
# NOTE(review): INFERENCE_DEVICE is not referenced anywhere in the visible notebook.
INFERENCE_DEVICE = torch.device("cuda:1")
def get_model(model_id: str,) -> HookedTransformer:
    """
    Load `model_id` as a frozen, eval-mode HookedTransformer on MODEL_DEVICE.

    The weights are loaded through `from_pretrained_no_processing` in float32
    with the eager attention implementation; gradients are disabled on every
    parameter before the model is returned.
    """
    load_kwargs = dict(
        device=MODEL_DEVICE,
        dtype=torch.float32,  # weird non-determinism happens in bfloat16
        attn_implementation='eager',  # to avoid non-determinism potentially?
    )
    hooked = HookedTransformer.from_pretrained_no_processing(model_id, **load_kwargs)
    hooked.requires_grad_(False)
    hooked.eval()
    return hooked

# Name of the model passed to get_model; swap the two lines below to switch models.
# MODEL_ID = 'gemma-2-2b'
MODEL_ID = 'pythia-70m'

In [None]:
# Load the model selected by MODEL_ID; every later cell reads this `model` global.
model = get_model(MODEL_ID)

# Load and filter RAVEL dataset

In [None]:
# for reference: the entity types listed here; only 'nobel_prize_winner' is loaded below
RAVEL_ENTITIES = ('city', 'nobel_prize_winner', 'occupation', 'physical_object', 'verb')
# Arguments: entity type, 'data' (presumably the directory holding the RAVEL
# files — TODO confirm), and the model's tokenizer.
full_entity_dataset = RAVELEntityPromptData.from_files('nobel_prize_winner', 'data', model.tokenizer)

In [None]:
# Work on a subsample of the full dataset (8192 is the argument to downsample).
sampled_entity_dataset = full_entity_dataset.downsample(8192)
print(f"Number of prompts sampled: {len(sampled_entity_dataset)}")

# Generation budget: max_length is set to prompt_max_length + 8 below.
prompt_max_length = 48
batch_size = 64 # use smaller for gemma-2b
# Generate completions and score them; both calls act on sampled_entity_dataset
# itself, so their order matters.
sampled_entity_dataset.generate_completions(model, model.tokenizer, max_length=prompt_max_length+8, prompt_max_length=prompt_max_length, batch_size=batch_size)
sampled_entity_dataset.evaluate_correctness()

# Filter correct completions
correct_data = sampled_entity_dataset.filter_correct()

# Filter top entities and templates
# NOTE(review): `filtered_data` is never read in the visible notebook; the
# statistics below and all later cells use `correct_data` instead. Confirm which
# of the two is intended downstream.
filtered_data = correct_data.filter_top_entities_and_templates(top_n_entities=400, top_n_templates_per_attribute=12)

# Calculate average accuracy
accuracy = sampled_entity_dataset.calculate_average_accuracy()
print(f"Average accuracy: {accuracy:.2%}")
# Both counts below are computed from `correct_data`, i.e. before the
# top-entity/template filter above.
print(f"Number of prompts remaining: {len(correct_data)}")
print(f"Number of entitites after filtering: {len(set([p.entity for p in list(correct_data.prompts.values())]))}")

# Setup for loading SAEs

In [22]:
sys.path.append('../..')
import sae_bench_utils.activation_collection as activation_collection
import sae_bench_utils.formatting_utils as formatting_utils
# import saebench_utils.activation_collection as activation_collection
# import saebench_utils.formatting_utils as formatting_utils
overview_df = formatting_utils.make_available_sae_df(for_printing=True)
from sae_lens import SAE as SAELensSAE
from sae_lens.sae import TopK
from tlens_utils import Node
from sparse_control_utils import *

def load_SAELensSAE(
        layer: int,
        expansion_factor: int = 2,
        k: Literal[20, 40, 80, 160, 320, 640] = 40,
        variant: Literal['standard', 'topk'] = 'topk',
        llm_name: Literal['gemma-2-2b', 'pythia70m'] = 'gemma-2-2b',
        ctx: int = 128,
        device: str = 'cuda',
    ) -> Tuple[SAELensSAE, dict, Optional[Tensor]]:
    """
    Load a pre-trained SAE from SAELens

    Args:
        layer: layer index; the SAE looked up is the one named
            `resid_post_layer_{layer}`.
        expansion_factor: part of the release name for 'gemma-2-2b'; it is not
            used when `llm_name` is 'pythia70m'.
        k: selects the trainer index (20 -> 0, 40 -> 1, ..., 640 -> 5).
        variant: 'standard' or 'topk'; part of the release name.
        llm_name: 'gemma-2-2b' or 'pythia70m'.
        ctx: context length encoded in the release name.
        device: device the SAE is loaded onto.

    Returns:
        The `(sae, cfg_dict, sparsity)` triple from `SAELensSAE.from_pretrained`,
        with gradients disabled on the SAE.

    Raises:
        KeyError: if `k` is not one of the supported values, or if the derived
            release / SAE name is missing from the table of available SAEs.
        ValueError: if `llm_name` is not one of the supported models.
    """
    k_to_trainer = {20: 0, 40: 1, 80: 2, 160: 3, 320: 4, 640: 5}
    trainer = k_to_trainer[k]

    if llm_name == 'gemma-2-2b':
        sae_name_prefix = f'{llm_name}_sweep_{variant}_ctx{ctx}_ef{expansion_factor}_0824'
    elif llm_name == 'pythia70m':
        # the two pythia sweeps carry different date suffixes
        suffix = '0712' if variant == 'standard' else '0730'
        sae_name_prefix = f'{llm_name}_sweep_{variant}_ctx{ctx}_{suffix}'
    else:
        # Fail before any download is attempted instead of hitting an unbound
        # `release` variable further down.
        raise ValueError(
            f"Unsupported llm_name {llm_name!r}; expected 'gemma-2-2b' or 'pythia70m'"
        )
    release = f'sae_bench_{sae_name_prefix}'
    sae_name = f'{sae_name_prefix}/resid_post_layer_{layer}/trainer_{trainer}'

    # The table maps sae_id -> name per release; invert it to look the id up by name.
    sae_df = formatting_utils.make_available_sae_df(for_printing=False)
    sae_id_to_name_map = sae_df.saes_map[release]
    sae_name_to_id_map = {name: sae_id for sae_id, name in sae_id_to_name_map.items()}
    sae_id = sae_name_to_id_map[sae_name]

    sae, cfg_dict, sparsity = SAELensSAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device=device,
    )
    sae = sae.to(device=device)
    if variant == 'topk':
        assert isinstance(sae.activation_fn, TopK), "This sae is not a topk sae, you probably have an old sae_lens version"
    if llm_name == 'gemma-2-2b':
        # the k recorded in the config is only cross-checked for gemma
        assert cfg_dict['activation_fn_kwargs']['k'] == k, f"Expected k={k}, got k={cfg_dict['activation_fn_kwargs']['k']}"
    sae.requires_grad_(False)
    return sae, cfg_dict, sparsity

In [44]:
# NOTE(review): LLM_NAME is set independently of MODEL_ID; keep the two in sync by hand.
LLM_NAME = 'pythia70m'
D_MODEL = model.cfg.d_model
N_LAYERS = len(model.blocks)

# Layers considered for SAEs, per model family.
if LLM_NAME == 'gemma-2-2b':
    SAE_LAYERS = (3, 7, 11, 15, 19)
else:
    SAE_LAYERS = (3, 4)
LAYER_TO_IDX = dict(zip(SAE_LAYERS, range(len(SAE_LAYERS))))

# resid_post nodes: one per model layer, and a separate set keyed by SAE layer.
ALL_NODES = [Node(component_name='resid_post', layer=layer, seq_pos=None) for layer in range(N_LAYERS)]
NODES = {layer: Node(component_name='resid_post', layer=layer, seq_pos=None) for layer in SAE_LAYERS}
SAE_NODES = [NODES[layer] for layer in SAE_LAYERS]

In [None]:
# test the above: load one SAE and keep only the SAE object (element 0 of the returned tuple)
# NOTE(review): llm_name and device are hard-coded here rather than taken from
# LLM_NAME / MODEL_DEVICE.
sae = load_SAELensSAE(layer=3, k=40, llm_name='pythia70m', ctx=128, device='cuda')[0]

# Train differential binary masks

## Set up prompts and prompt pairs for interchange interventions

In [None]:
def sample_prompt_pairs(prompts: List[Prompt], N: int, random_seed: int = 0, attribute: Optional[str] = None,
                        entity_attributes: Optional[dict] = None) -> Tuple[List[Prompt], List[Prompt]]:
    """
    Sample pairs of prompts from a list of prompts (should be for the same 
    entity type). 

    Optionally, require the prompts in a pair to differ in the value of 
    a given attribute.

    Args:
        prompts: pool to draw from; each pair is two distinct list positions.
        N: number of pairs to draw.
        random_seed: seed for a private `np.random.RandomState`. The draws are
            the same as the legacy `np.random.seed` + `np.random.choice`
            sequence, but the global NumPy RNG is left untouched.
        attribute: if given, pairs whose entities share the same value for this
            attribute are rejected and redrawn.
        entity_attributes: mapping `entity -> {attribute: value}` used for the
            attribute check. Defaults to the notebook-level
            `sampled_entity_dataset.entity_attributes`.

    Returns:
        `(originals, counterfactuals)`: two lists of length N, paired by index.

    Raises:
        ValueError: if fewer than two prompts are available, or if `attribute`
            is given but every prompt has the same value for it (the rejection
            loop could never terminate).
    """
    rng = np.random.RandomState(random_seed)
    n_prompts = len(prompts)

    def draw_pair():
        # Sampling positions (rather than the objects) consumes the RNG exactly
        # like choice(prompts, ...) but never coerces prompts into an array.
        i, j = rng.choice(n_prompts, size=2, replace=False)
        return prompts[i], prompts[j]

    res_og = []
    res_cf = []
    if attribute is None:
        for _ in range(N):
            p1, p2 = draw_pair()
            res_og.append(p1)
            res_cf.append(p2)
        return res_og, res_cf

    if N > 0:
        attr_dict = sampled_entity_dataset.entity_attributes if entity_attributes is None else entity_attributes

        def value_of(prompt):
            return attr_dict[prompt.entity][attribute]

        if n_prompts >= 2:
            first_value = value_of(prompts[0])
            if all(value_of(p) == first_value for p in prompts[1:]):
                raise ValueError(
                    f"All prompts share the same value for attribute {attribute!r}; "
                    "cannot sample pairs that differ in it"
                )

        for _ in range(N):
            p1, p2 = draw_pair()
            while value_of(p1) == value_of(p2):
                p1, p2 = draw_pair()
            res_og.append(p1)
            res_cf.append(p2)
    return res_og, res_cf

# Attribute whose value the interchange interventions should change.
ATTRIBUTE = 'Field'
# Keep only the correctly-completed prompts that ask about ATTRIBUTE.
prompts = [p for p in correct_data.prompts.values() if p.attribute == ATTRIBUTE]
# Ground-truth attribute value for each kept prompt.
correct_completions = [
    correct_data.entity_attributes[p.entity][p.attribute]
    for p in prompts
]
# Last element returned by get_entity_positions for each prompt.
last_entity_token_positions = Tensor(
    [get_entity_positions(p, model)[-1] for p in tqdm(prompts)]
).long()
og_prompts, cf_prompts = sample_prompt_pairs(prompts, N=1000, random_seed=0, attribute=ATTRIBUTE)

## Intervention helpers

In [67]:
def get_dbm_hook(dbm: DifferentialBinaryMasking, 
                 layer: int, positions: Tensor,
                 A_og: Tensor, A_cf: Tensor
                 ) -> Tuple[str, Callable]:
    """
    Build an activation-patching hook driven by a differential binary mask.

    `dbm.forward(base=A_og, source=A_cf)` is evaluated once, when this function
    is called, and the result is what the hook writes on every forward pass.
    The hook targets the `resid_post` node of `layer` and overwrites
    `activation[:, positions, :]` in place with that result (moved to the
    activation's device and dtype).

    NOTE(review): `activation[:, positions, :]` selects every batch row at
    every index in `positions`, so the mask output must broadcast to
    `(batch, len(positions), d_act)`. If `positions` is meant to hold one
    position per prompt, each prompt is also patched at the other prompts'
    positions — confirm this is intended.

    Returns:
        `(activation_name, hook_fn)`, suitable for `fwd_hooks`.
    """
    patched_values = dbm.forward(base=A_og, source=A_cf)

    def hook_fn(activation: Tensor, hook: HookPoint) -> Tensor:
        # activation will be of shape (num_texts, seq_len, d_act)
        replacement = patched_values.to(activation.device).to(activation.dtype)
        activation[:, positions, :] = replacement
        return activation

    target = Node(component_name='resid_post', layer=layer, seq_pos=None)
    return (target.activation_name, hook_fn)

def get_loss_with_hooks(
    model: HookedTransformer, 
    hooks: List[Tuple[str, Callable]],
    prompts: List[Prompt],
    target_completions: List[str],
    n_tokens: int,
    ):
    """
    Apply the given hooks, and measure the loss w.r.t. the given completions.

    This is used as an optimization objective when training the differential
    binary mask. When we want to train a mask that changes the model's output,
    the `target_completions` should be the desired completions.

    The completions are teacher-forced: at step i the model is run with the
    hooks on the prompt plus the first i target tokens, and the cross-entropy
    of target token i under the last-position logits is recorded.

    Args:
        model: model exposing `to_tokens`, `to_str_tokens` and `run_with_hooks`.
        hooks: forward hooks applied on every step.
        prompts: prompts whose `.text` is tokenized with left padding and BOS.
        target_completions: one target string per prompt.
        n_tokens: maximum number of completion tokens to score; capped at the
            longest tokenized completion.

    Returns:
        Scalar tensor: mean cross-entropy over the completion tokens actually
        scored, i.e. the first `min(len(completion_i), n_tokens)` tokens of
        each completion. Right-padding positions are excluded.
    """
    texts = [p.text for p in prompts]

    texts_tokens = model.to_tokens(texts, padding_side='left', prepend_bos=True)
    completion_tokens = model.to_tokens(target_completions, padding_side='right', prepend_bos=False)
    completion_lengths = [len(model.to_str_tokens(c, prepend_bos=False)) for c in target_completions]
    n_tokens = min(n_tokens, completion_tokens.shape[1])

    current_tokens = texts_tokens
    step_losses = []
    for i in range(n_tokens):
        logits = model.run_with_hooks(current_tokens, fwd_hooks=hooks,)[:, -1, :] # only take the last token logits, shape (batch, vocab)
        # Keep the target 1-D (batch,) even for a single prompt; a bare
        # .squeeze() would collapse a batch of one to a 0-dim tensor.
        target = completion_tokens[:, i]
        step_losses.append(F.cross_entropy(logits, target, reduction='none'))
        current_tokens = torch.cat([current_tokens, target.unsqueeze(1)], dim=-1)
    losses = torch.stack(step_losses, dim=1)  # (batch, n_tokens)

    # Only the first min(length, n_tokens) steps of each row are real completion
    # tokens. Normalise by that same count so truncated completions do not
    # deflate the average.
    lengths = torch.tensor(completion_lengths, device=losses.device).clamp(max=n_tokens)
    scored = torch.arange(n_tokens, device=losses.device).unsqueeze(0) < lengths.unsqueeze(1)
    avg_loss = losses[scored].sum() / scored.sum()
    return avg_loss

@torch.no_grad()
def generate_with_hooks(
    model: HookedTransformer, 
    hooks: List[Tuple[str, Callable]],
    prompts: List[Prompt],
    n_tokens: int,
    ) -> List[str]:
    """
    Greedily decode `n_tokens` tokens per prompt with the given hooks active.

    This is used to check the "hard" accuracy of an intervention, i.e. whether
    the intervention successfully changes the model's output.

    The prompts are tokenized with left padding and a BOS token. On every step
    the model is re-run with the hooks on the full sequence so far, and the
    argmax of the last-position logits is appended.

    Returns:
        One string per prompt, made by joining the string forms of the
        generated tokens only (the prompt is not included).
    """
    tokens = model.to_tokens([p.text for p in prompts], padding_side='left', prepend_bos=True)

    # autoregressive generation w/ hook at each step
    picked = []
    for _ in range(n_tokens):
        last_logits = model.run_with_hooks(tokens, fwd_hooks=hooks,)[:, -1, :]  # (batch, vocab)
        choice = last_logits.argmax(dim=-1)
        picked.append(choice)
        tokens = torch.cat([tokens, choice.unsqueeze(1)], dim=-1)

    new_tokens = torch.stack(picked, dim=1)  # (batch, n_tokens)
    return [''.join(model.to_str_tokens(row, prepend_bos=False)) for row in new_tokens]

## Differential binary mask on residual stream activations
A prototype of training a binary mask baseline to change the value of an
attribute by patching the residual stream

In [None]:
LAYER = 3
NUM_PROMPTS = 200

# Ground-truth attribute values for every sampled pair (not only the first NUM_PROMPTS).
og_labels = [correct_data.entity_attributes[q.entity][q.attribute] for q in og_prompts]
cf_labels = [correct_data.entity_attributes[q.entity][q.attribute] for q in cf_prompts]

# Cache the activations at node ALL_NODES[LAYER] for the first NUM_PROMPTS
# originals, then for their counterfactuals.
A_og = run_with_cache(
    prompts_or_tokens=[q.text for q in og_prompts[:NUM_PROMPTS]],
    model=model,
    nodes=[ALL_NODES[LAYER]],
    batch_size=None,
)[0]

A_cf = run_with_cache(
    prompts_or_tokens=[q.text for q in cf_prompts[:NUM_PROMPTS]],
    model=model,
    nodes=[ALL_NODES[LAYER]],
    batch_size=None,
)[0]

# Last element returned by get_entity_positions for each prompt.
og_positions = Tensor(
    [get_entity_positions(q, model)[-1] for q in tqdm(og_prompts[:NUM_PROMPTS])]
).long()
cf_positions = Tensor(
    [get_entity_positions(q, model)[-1] for q in tqdm(cf_prompts[:NUM_PROMPTS])]
).long()

In [None]:
mask = DifferentialBinaryMasking( embed_dim=D_MODEL,).to(MODEL_DEVICE)
optimizer = torch.optim.Adam(mask.parameters(), lr=1e-3)
mask.train()
EPOCHS = 100
L1_PENALTY = 1e-3 # following RAVEL paper
# One temperature per epoch, so the schedule always matches EPOCHS (it was
# previously fixed at 100 entries and would overflow for longer runs).
temperature_schedule = Tensor(np.linspace(1e-2, 1e-7, EPOCHS)) # following RAVEL paper
for i in range(EPOCHS):
    # The hook evaluates the mask when it is built, so it must be rebuilt every
    # epoch to pick up the updated mask parameters.
    hook = get_dbm_hook(mask, LAYER, og_positions, A_og[:, og_positions, :], A_cf[:, cf_positions, :])
    loss = get_loss_with_hooks(
        model,
        hooks=[hook],
        prompts=og_prompts[:NUM_PROMPTS],
        target_completions=cf_labels[:NUM_PROMPTS], 
        n_tokens=3,
    )
    with torch.no_grad(): # sample from the model to measure intervention success
        completions = generate_with_hooks(
            model,
            hooks=[hook],
            prompts=og_prompts[:NUM_PROMPTS],
            n_tokens=10,
        )
        # NOTE(review): completions are generated from og_prompts but scored
        # against cf_prompts[i].text — confirm which text evaluate_completion expects.
        accs = [evaluate_completion(text=cf_prompts[i].text, completion=completions[i], expected_label=cf_labels[i]) for i in range(NUM_PROMPTS)]
        acc = sum(accs) / len(accs)
        print(f"Accuracy: {acc:.2%}")
    # The sparsity term is currently disabled; to enable it use
    # loss = loss + L1_PENALTY * mask.get_sparsity_loss()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Applied after the step, so epoch 0 runs at whatever temperature the mask
    # was constructed with.
    mask.set_temperature(temperature_schedule[i])
    print(f"Loss: {loss.item()}")

mask.eval()

## Reproducing non-determinism bug
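Using Gemma-2B in bfloat16 resulted in non-deterministic behavior of the hooked forward pass: the cells below run the same intervention twice with an identical mask state and compare the resulting logits.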

In [19]:
# Fresh (untrained) mask for the reproduction. Size it from the loaded model
# rather than hard-coding Gemma-2B's width (2304), so that the cell also runs
# when a different MODEL_ID is selected.
mask = DifferentialBinaryMasking(
    embed_dim=D_MODEL,
).to(MODEL_DEVICE)

def get_mask_state(mask: DifferentialBinaryMasking) -> Dict[str, Tensor]:
    """
    Snapshot the mask's `state_dict`.

    Every entry is detached and cloned, so later changes to the mask do not
    alter the returned tensors.
    """
    snapshot = {}
    for name, tensor in mask.state_dict().items():
        snapshot[name] = tensor.detach().clone()
    return snapshot

In [20]:
# Tokenize the first NUM_PROMPTS original prompts once (left-padded, with BOS)
# so that both runs below receive exactly the same input tensor.
texts = [p.text for p in og_prompts[:NUM_PROMPTS]]
texts_tokens = model.to_tokens(texts, padding_side='left', prepend_bos=True)

In [21]:
# First run: clear any hooks on the model, build the intervention hook, snapshot
# the mask state, then take the last-position logits.
with torch.no_grad():
    model.reset_hooks()
    hook1 = get_dbm_hook(mask, LAYER, og_positions, A_og[:, og_positions, :], A_cf[:, cf_positions, :])
    mask_state_1 = get_mask_state(mask)
    logits1 = model.run_with_hooks(texts_tokens, fwd_hooks=[hook1],)[:, -1, :] # only take the last token logits, shape (batch, vocab)

In [22]:
# Second run: same statements as the first run, storing into hook2 /
# mask_state_2 / logits2 so the two results can be compared.
with torch.no_grad():
    model.reset_hooks()
    hook2 = get_dbm_hook(mask, LAYER, og_positions, A_og[:, og_positions, :], A_cf[:, cf_positions, :])
    mask_state_2 = get_mask_state(mask)
    logits2 = model.run_with_hooks(texts_tokens, fwd_hooks=[hook2],)[:, -1, :] # only take the last token logits, shape (batch, vocab)

In [None]:
# mask state should be the same
# (element-wise equality of every state_dict entry captured in the two runs)
for k in mask_state_1:
    assert torch.all(mask_state_1[k] == mask_state_2[k])
# but the logits are different
# (the cell displays the difference; an all-zero tensor would mean the two runs agreed)
logits1 - logits2