# Transfer between chat and base SAE

In [1]:
from utils.utils import *
from utils.plot_utils import *
from utils.data_utils import *
from utils.eval_refusal import *
from utils.attribution_utils import *
from utils.model_utils import *
from tqdm import tqdm
from collections import defaultdict
from utils.gemmascope import JumpReLUSAE
import numpy as np
import torch.nn.functional as F


seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
torch.set_grad_enabled(False) # rmb set to true for grads

INFO 05-25 00:09:18 __init__.py:190] Automatically detected platform cuda.


<torch.autograd.grad_mode.set_grad_enabled at 0x15555062cb50>

In [2]:
# Load model and SAE

device = 'cuda:0'
torch_dtype = torch.bfloat16
model_name = "gemma-2b"
# model_name = 'llama'
model = load_tl_model(model_name,device = device, torch_dtype = torch_dtype)
num_sae_layer = model.cfg.n_layers
saes = load_sae(model_name,num_sae_layer,device=device, torch_dtype=torch_dtype,split_device = False)
size = model_sizes[model_name]
model.model_name = model_name


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [3]:
device = 'cuda:0'
torch_dtype = torch.bfloat16
chat_saes = {}
chat_device = 'cuda:0'
layer_to_eval = [12,13,14,15,16]
for layer_ in layer_to_eval:
    if layer_ in [15,13]:
        chat_path = f"checkpoints/gemma2-2b-chat-1m-15-13/best/layers.{layer_}/layers.{layer_}/sae.safetensors"
    elif layer_ in [14,16]:
        chat_path = f"checkpoints/gemma2-2b-chat-1m-14-16/best/layers.{layer_}/layers.{layer_}/sae.safetensors"
    else:
        chat_path = "checkpoints/gemma2-2b-chat-1m-12/best/layers.12/layers.12/sae.safetensors"
    chat_saes[sae_naming['res'].format(l=layer_)] = JumpReLUSAE.from_pretrained(
        chat_path, device = device,is_hf=False
    ).to(torch_dtype).to(device)

In [28]:
ce_test_size = 1000 # take 1000 samples
pile_bz = 20
num_pile_iter = ce_test_size//pile_bz
pile_iterator = load_pile_iterator(pile_bz,model.tokenizer,device=model.cfg.device)

In [5]:
def get_input_ce_loss(inps,loss_mask,model,avg=True): # only input
    logits = model(inps['input_ids'],attention_mask = inps['attention_mask'])
    logprobs = F.log_softmax(logits, dim=-1)
    log_probs_for_labels = logprobs[:, :-1].gather(dim=-1, index=inps['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)
    log_probs_for_labels = torch.cat(
            [
                log_probs_for_labels,
                torch.zeros(log_probs_for_labels.shape[0]).unsqueeze(-1).to(log_probs_for_labels)
            ],
            dim=-1
        )
    ce_loss = -(log_probs_for_labels * loss_mask).mean(dim=-1)
    if avg:
        return ce_loss.mean().item()
    else:
        return ce_loss

In [44]:
def replacement_hook(act,hook,saes,pos_mask=None): # pos mask is the mask to replace the activations with reconstructed ones
    reconstr = saes[hook.name].decode(saes[hook.name].encode(act.to(saes[hook.name].W_dec.device)))
    if pos_mask is None:
        act = reconstr.to(act.device)
    else:
        act[pos_mask,:] = reconstr[pos_mask.to(reconstr.device),:].to(act.device)
    return act

def zero_hook(act,hook):
    act = act * 0
    return act

# Eval on w/o assistant tokens on The Pile

In [46]:
base_losses = []
zero_losses = defaultdict(list)
base_sae_losses = defaultdict(list)
chat_sae_losses = defaultdict(list)

batch_no = 0
model.reset_hooks()
for pile_inputs,loss_mask in pile_iterator:
    if batch_no >= num_pile_iter:
        break
    bos_pos = (pile_inputs['input_ids'] == model.tokenizer.bos_token_id).int().argmax(dim=1)
    # create a position mask for input > the bos_pos
    pos_mask = torch.arange(pile_inputs['input_ids'].shape[1], device=pile_inputs['input_ids'].device).unsqueeze(0) > bos_pos.unsqueeze(1)
    # base loss
    base_loss = get_input_ce_loss(pile_inputs,loss_mask,model,avg=False)
    base_losses.extend(base_loss.tolist())


    for curr_layer in layer_to_eval:
        layer_to_hook = f'blocks.{curr_layer}.hook_resid_post'
        ## zero loss
        model.reset_hooks()
        model.add_hook(layer_to_hook,zero_hook)
        zero_loss = get_input_ce_loss(pile_inputs,loss_mask,model,avg=False)
        zero_losses[curr_layer].extend(zero_loss.tolist())
        model.reset_hooks()

        ## replace loss
        for j,sae_ in enumerate([saes,chat_saes]):
            model.reset_hooks()
            model.add_hook(layer_to_hook,partial(replacement_hook,saes=sae_,pos_mask =  pos_mask))
            sae_loss = get_input_ce_loss(pile_inputs,loss_mask,model,avg=False)
            model.reset_hooks()

            if j == 0:
                base_sae_losses[curr_layer].extend(sae_loss.tolist())
            else:
                chat_sae_losses[curr_layer].extend(sae_loss.tolist())
    batch_no += 1

print (f'Clean loss: {np.mean(base_losses):.4f}')

base_loss_tensor = torch.tensor(base_losses)
for layer in base_sae_losses.keys():
    zero_loss = torch.tensor(zero_losses[layer])
    base_sae_loss = torch.tensor(base_sae_losses[layer])
    chat_sae_loss = torch.tensor(chat_sae_losses[layer])
    div_val = zero_loss - base_loss_tensor
    div_val[torch.abs(div_val) < 0.0001] = 1.0

    base_rec_loss = ((zero_loss - base_sae_loss)/ div_val).mean().item()
    chat_rec_loss = ((zero_loss - chat_sae_loss)/ div_val).mean().item()

    print (f'Layer: {layer}, CE/Recovered loss: base: {base_sae_loss.mean().item():.4f}/{base_rec_loss:.4f}, chat: {chat_sae_loss.mean().item():.4f}/{chat_rec_loss:.4f}')
    print (f'--'*50)


Clean loss: 2.8597
Layer: 12, CE/Recovered loss: base: 3.1391/0.9648, chat: 3.0049/0.9834
----------------------------------------------------------------------------------------------------
Layer: 13, CE/Recovered loss: base: 3.1731/0.9633, chat: 3.0748/0.9765
----------------------------------------------------------------------------------------------------
Layer: 14, CE/Recovered loss: base: 3.1502/0.9676, chat: 3.1881/0.9626
----------------------------------------------------------------------------------------------------
Layer: 15, CE/Recovered loss: base: 2.9873/0.9867, chat: 2.9338/0.9931
----------------------------------------------------------------------------------------------------
Layer: 16, CE/Recovered loss: base: 3.0534/0.9788, chat: 2.9838/0.9864
----------------------------------------------------------------------------------------------------


# Evaluate rollouts from alpaca

In [34]:
ce_test_size = 300
_, _, _, alpaca_ds = load_refusal_datasets(val_size=ce_test_size)
bz = 64
base_alpaca_outputs = batch_generate(alpaca_ds,model,bz = bz,saes=saes,steering_fn = None,max_new_tokens=256,use_tqdm=True)

5it [01:24, 16.90s/it]                                                                                                                                                                                                                                                                             


In [37]:
def get_ce_loss(inps,labels,bz,model,use_tqdm = False,use_avg = True): # get ce loss on labels
    all_ce_loss = []
    to_iter = tqdm(range(0,len(inps),bz),total = len(inps)//bz) if use_tqdm else range(0,len(inps),bz)
    for i in to_iter:
        formatted_inp = [format_prompt(model.tokenizer,x) for x in inps[i:i+bz]]
        batch_labels = labels[i:i+bz]
        combined = [inp_ + lab_ for inp_,lab_ in zip(formatted_inp,batch_labels)]
        encoded_inp = encode_fn(combined,model)
        encoded_label_len = [len(model.tokenizer.encode(lab)) for lab in batch_labels]
        loss_mask = []
        for i,mask in enumerate(encoded_inp['attention_mask']):
            temp_loss_mask = mask.clone()
            temp_loss_mask[:-(encoded_label_len[i]+1)] = 0 # mask off input
            loss_mask.append(temp_loss_mask)
        loss_mask = torch.stack(loss_mask).to(encoded_inp['input_ids'])
        if use_avg:
            all_ce_loss.append(get_input_ce_loss(encoded_inp,loss_mask,model))
        else:
            all_ce_loss.extend(get_input_ce_loss(encoded_inp,loss_mask,model,avg=False).tolist())
            
    return np.mean(all_ce_loss) if use_avg else np.array(all_ce_loss)

In [43]:
base_alpaca_loss = get_ce_loss(alpaca_ds,base_alpaca_outputs,bz,model,use_avg=False)
print (f'Base Alpaca Loss: {base_alpaca_loss.mean().item():.4f}')

for curr_layer in layer_to_eval:
    layer_to_hook = f'blocks.{curr_layer}.hook_resid_post'
    model.reset_hooks()
    model.add_hook(layer_to_hook,zero_hook)
    zero_loss = get_ce_loss(alpaca_ds,base_alpaca_outputs,bz,model,use_avg=False)
    model.reset_hooks()

    saes_loss = []
    for j,sae_ in enumerate([saes,chat_saes]):
        model.reset_hooks()
        per_sae_loss = []
        ## manually do the ce loss here
        for i in range(0,len(alpaca_ds),bz):
            formatted_inp = [format_prompt(model.tokenizer,x) for x in alpaca_ds[i:i+bz]]
            batch_labels = base_alpaca_outputs[i:i+bz]
            combined = [inp_ + lab_ for inp_,lab_ in zip(formatted_inp,batch_labels)]
            encoded_inp = encode_fn(combined,model)
            encoded_label_len = [len(model.tokenizer.encode(lab)) for lab in batch_labels]
            loss_mask = []
            for i,mask in enumerate(encoded_inp['attention_mask']):
                temp_loss_mask = mask.clone()
                temp_loss_mask[:-(encoded_label_len[i]+1)] = 0 # mask off input
                loss_mask.append(temp_loss_mask)
            loss_mask = torch.stack(loss_mask).to(encoded_inp['input_ids'])
            pos_mask = loss_mask.clone().bool()

            model.add_hook(layer_to_hook,partial(replacement_hook,saes=sae_,pos_mask =  pos_mask))
            sae_loss = get_input_ce_loss(encoded_inp,loss_mask,model,avg = False).tolist()
            model.reset_hooks()
            per_sae_loss.extend(sae_loss)
        saes_loss.append(np.array(per_sae_loss))
    
    div_val = zero_loss - base_alpaca_loss
    div_val[np.abs(div_val) < 0.0001] = 1.0
    for j,s_l in enumerate(saes_loss):
        rec_loss = ((zero_loss - s_l)/ div_val).mean().item()
        saes_name = 'base' if j == 0 else 'chat'
        print (f'{curr_layer} {saes_name}: ce loss: {s_l.mean().item():.4f}, recovered: {rec_loss:.4f}')







Base Alpaca Loss: 0.2989
12 base: ce loss: 0.4239, recovered: 0.9857
12 chat: ce loss: 0.3307, recovered: 0.9961
13 base: ce loss: 0.4637, recovered: 0.9808
13 chat: ce loss: 0.3350, recovered: 0.9956
14 base: ce loss: 0.4781, recovered: 0.9789
14 chat: ce loss: 0.3277, recovered: 0.9964
15 base: ce loss: 0.4452, recovered: 0.9831
15 chat: ce loss: 0.3207, recovered: 0.9975
16 base: ce loss: 0.4270, recovered: 0.9854
16 chat: ce loss: 0.3274, recovered: 0.9967


# Reconstruct the special chat tokens on harmful instructions

In [16]:
harm_ds_names = ['harmbench_test','jailbreakbench','advbench']
harm_ds = {name:load_all_dataset(name,instructions_only=True)[:100] for name in harm_ds_names}

In [47]:
for curr_layer in layer_to_eval:
    layer_to_hook = f'blocks.{curr_layer}.hook_resid_post'
    all_ce_loss = defaultdict(list)
    model.reset_hooks()
    for name,ds in harm_ds.items():
        harmful_inps = encode_fn([format_prompt(model.tokenizer,x) for x in ds],model)
        loss_mask = harmful_inps['attention_mask'].clone()

        bos_pos = (harmful_inps['input_ids'] == model.tokenizer.bos_token_id).int().argmax(dim=1)
        pos_mask = torch.arange(harmful_inps['input_ids'].shape[1], device=harmful_inps['input_ids'].device).unsqueeze(0) > bos_pos.unsqueeze(1)

        # base loss
        base_loss = get_input_ce_loss(harmful_inps,loss_mask,model,avg=False)
        all_ce_loss['clean'].append(base_loss.mean().item())

        ## zero loss
        model.reset_hooks()
        model.add_hook(layer_to_hook,zero_hook)
        zero_loss = get_input_ce_loss(harmful_inps,loss_mask,model,avg=False)
        model.reset_hooks()

        saes_loss = []
        for j,sae_ in enumerate([saes,chat_saes]):
            model.reset_hooks()
            model.add_hook(layer_to_hook,partial(replacement_hook,saes=sae_,pos_mask =  pos_mask))
            sae_loss = get_input_ce_loss(harmful_inps,loss_mask,model,avg=False)
            model.reset_hooks()
            saes_loss.append(sae_loss)
            sae_name = 'base' if j == 0 else 'chat'
            all_ce_loss[sae_name].append(sae_loss.mean().item())
        
        div_val = zero_loss - base_loss
        # div_val[torch.abs(div_val) < 0.0001] = 1.0
        for j,s_l in enumerate(saes_loss):
            rec_loss = ((zero_loss - s_l)/ div_val).mean().item()
            sae_name = 'base_recover' if j == 0 else 'chat_recover'
            all_ce_loss[sae_name].append(rec_loss)
            
    print (f'Layer: {curr_layer}')
    print (f'--'*60)
    print (f'')
    for name,loss in all_ce_loss.items():
        print(f'{name}: {np.mean(loss):.4f}')
    print (f'--'*60)

Layer: 12
------------------------------------------------------------------------------------------------------------------------

clean: 5.6458
base: 5.6354
chat: 5.7917
base_recover: 0.9648
chat_recover: 0.9258
------------------------------------------------------------------------------------------------------------------------
Layer: 13
------------------------------------------------------------------------------------------------------------------------

clean: 5.6458
base: 5.5625
chat: 5.7188
base_recover: 1.0091
chat_recover: 0.9844
------------------------------------------------------------------------------------------------------------------------
Layer: 14
------------------------------------------------------------------------------------------------------------------------

clean: 5.6458
base: 5.5000
chat: 5.6771
base_recover: 1.0573
chat_recover: 1.0052
--------------------------------------------------------------------------------------------------------------------

# Reconstruct the refusal direction

In [72]:
harmful_train, harmless_train, _, harmless_val = load_refusal_datasets() 
is_base_harmless,base_harmless_logit = batch_single(harmless_train,model,eval_refusal=True,avg_samples=False)
harmless_train = [x for x,y in zip(harmless_train,is_base_harmless) if not y]

harmless_inps  = encode_fn([format_prompt(model.tokenizer,x) for x in harmless_train[:100]],model)
layer_to_retrieve = f'blocks.15.hook_resid_post'

steer_vecs = defaultdict(list)
for ds in harm_ds.values():
    harmful_inps  = encode_fn([format_prompt(model.tokenizer,x) for x in ds],model)
    for reconstruct_name in ['clean','base','chat']:
        model.reset_hooks()
        if reconstruct_name in ['base','chat']:
            model.add_hook(layer_to_retrieve,partial(replacement_hook,saes=saes if reconstruct_name == 'base' else chat_saes))
        _,harmful_cache = model.run_with_cache(harmful_inps.input_ids,attention_mask = harmful_inps.attention_mask,names_filter = layer_to_retrieve)
        _,harmless_cache = model.run_with_cache(harmless_inps.input_ids,attention_mask = harmless_inps.attention_mask,names_filter = layer_to_retrieve)

        steer_vec = harmful_cache[layer_to_retrieve][:,-1].mean(0) - harmless_cache[layer_to_retrieve][:,-1].mean(0)
        del harmful_cache, harmless_cache
        steer_vecs[reconstruct_name].append(steer_vec)
    model.reset_hooks()

## print cosine similarity
reconst_cosine_sim = defaultdict(list)
for name in ['base','chat']:
    for i,clean_vec in enumerate(steer_vecs['clean']):
        cosine_sim = F.cosine_similarity(clean_vec,steer_vecs[name][i],dim=0)
        reconst_cosine_sim[name].append(cosine_sim.item())

for name,cosine_sim in reconst_cosine_sim.items():
    print(f'{name} cosine similarity: {np.mean(cosine_sim):.4f}')


base cosine similarity: 0.8607
chat cosine similarity: 0.9883
