In [1]:
import sys
import einops
import numpy as np
from typing import List, Tuple, Dict, Callable, Union
from jaxtyping import Float, Int

from math import ceil
from functools import partial
import pickle
from copy import deepcopy

import torch 
from torch import Tensor
from torch.utils.data import Dataset

import dataset as datasetpy
import train as trainpy
import importlib 
importlib.reload(datasetpy)
importlib.reload(trainpy)

from dataset import BaseDataset, SortedDatasetExtended, KeyValDataset, BinaryAdditionDataset
from train import TrainArgs, get_missed_data
from model import create_model

sys.path.append('/home/alejo/Projects')
from my_plotly_utils import hist, bar, scatter, line, imshow, figs_to_subplots
from path_patching import act_patch, path_patch, Node, IterNode

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f457bb125d0>

## Red Teaming Dataset

In [2]:
class RedTeamingDataset(BaseDataset):
    """Dataset wrapper to generate adversarial examples for red teaming from a set of missclassified tokens. 
    It inherits the attibutes from the original dataset to be compatible with the functions in train.py"""
    def __init__(self, dataset_class: Callable[..., BaseDataset], size: int, 
                 toks_missed: Int[Tensor, 'batch pos'], 
                 num_flips: Union[int, List[int]] = None,
                 seed = 42, **kwargs):
        super().__init__(size=None, seed=seed, d_vocab=kwargs['d_vocab'], d_vocab_out=kwargs['d_vocab_out'], n_ctx=kwargs['n_ctx'], seq_len=kwargs['seq_len'])
        dataset = dataset_class(size=None, seed=seed, d_vocab=kwargs['d_vocab'], d_vocab_out=kwargs['d_vocab_out'], n_ctx=kwargs['n_ctx'], seq_len=kwargs['seq_len'])
        self.__dict__ = dataset.__dict__.copy()
        self.size = size
        self.device = toks_missed.device

        self.missed_toks = toks_missed
        self.num_flips = list(range(1, self.n_ctx//2)) if num_flips is None else num_flips

        rep_toks = self.gen_all_repeated_toks(self.missed_toks)
        self.toks = torch.cat([
            rep_toks,
            self.flip_toks(self.missed_toks, batch=size - rep_toks.shape[0], num_flips=self.num_flips),
        ])
        self.target = dataset.compute_target(self.toks)

    def flip_toks(self, toks: Int[Tensor, 'batch pos'],
              batch: int, num_flips: Union[int, List[int]], seed = 42) -> Int[Tensor, 'batch pos']:

        num_flips = [num_flips] if isinstance(num_flips, int) else num_flips
        mini_batch = 3 * ceil(batch / len(num_flips)) # Sample more than needed to account for duplicates when a flip change special symbols (e.g. <pad>)
        
        all_new_toks = []
        for flips in num_flips:
            sample_idx = torch.randint(0, toks.shape[0], (mini_batch,)).to(self.device) # Sample a random batch from toks (with replacement)
            orig_toks = toks[sample_idx]
            new_toks = orig_toks.clone()
            special_toks_mask = orig_toks >= self.d_vocab_normal # Preserve special symbols

            flip_pos = torch.randint(0, self.n_ctx, (mini_batch, flips)).to(self.device)
            flip_val = torch.randint(0, self.d_vocab_normal, (mini_batch, flips)).to(self.device)
            batch_idx = torch.arange(mini_batch)[:, None].to(self.device)
            new_toks[batch_idx, flip_pos] = flip_val
            new_toks[special_toks_mask] = orig_toks[special_toks_mask]
            all_new_toks.append(new_toks.unique(dim=0)) # Remove duplicates

        all_new_toks = torch.cat(all_new_toks, dim=0)
        selected_idx = torch.randperm(all_new_toks.shape[0])[:batch].to(self.device) # Select the desired number of toks
        return all_new_toks[selected_idx]


    def gen_all_repeated_toks(self, toks: Int[Tensor, 'batch pos']) -> Int[Tensor, 'new_batch pos']:
        """Generate sequences where all non-special tokens are the same. It infers the position of the special tokens from toks
        and assumes that the position of a special token uniquely identifies its value"""
        all_toks = []
        special_toks_mask = torch.stack([toks_i >= self.d_vocab_normal for toks_i in toks]) # Each template indicates the position of the padding, target tokens, etc
        special_toks_mask, unique_toks_mask_idx = special_toks_mask.unique(dim=0, return_inverse=True) # Remove duplicates and keep track of the toks idx for accessing the values under the mask
        for mask, tok in zip(special_toks_mask, toks[unique_toks_mask_idx]): 
            new_toks = einops.repeat(torch.arange(self.d_vocab_normal), 'v -> v ctx', ctx=self.n_ctx).clone().to(self.device)
            special_toks = einops.repeat(tok[mask], 'a -> v a', v=self.d_vocab_normal).clone()
            new_toks.masked_scatter_(mask, special_toks)
            all_toks.append(new_toks)
        all_toks = torch.cat(all_toks, dim=0)
        return all_toks.unique(dim=0).to(self.device) # Remove duplicates

## Sorted Dataset

In [3]:
args = TrainArgs(
    dataset=SortedDatasetExtended,
    d_vocab=23,
    d_vocab_out=21,
    n_ctx=15,
    seq_len=6,
    n_layers=2,
    num_end_pos=2,
    trainset_size=100_000,
    valset_size=10_000,
    epochs=30,
    batch_size=1024,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=420,
    d_model=128,
    d_head=32,
    n_heads=4,
    d_mlp=4*128,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)

model = create_model(**args.__dict__)
model.load_state_dict(torch.load("models/new_sorting_ood_1000.pt"))

<All keys matched successfully>

In [4]:
toks_missed, target_missed, logits_missed = get_missed_data(args, model)
pred_missed = logits_missed.argmax(-1)
print(f'Datapoints missed {toks_missed.shape[0]} out of {args.valset_size}')
print('Input', toks_missed[:10, 1:-1])
print('Target', target_missed[:10].squeeze())
print('Predicted', pred_missed[:10].squeeze())

Datapoints missed 5 out of 10000
Input tensor([[ 9, 13, 13, 12,  5,  7, 21, 17,  7, 15, 16, 14,  6],
        [12,  7,  9, 14, 15,  1, 21,  1,  2,  3,  5,  4, 16],
        [ 5,  6, 16, 11, 14, 16, 21,  0,  0,  7, 10,  8, 19],
        [17, 15, 18,  5,  4,  3, 21,  1,  3,  3,  3,  2, 11],
        [ 4,  4,  8,  2, 10, 18, 21,  0,  0,  8,  1,  7,  4]], device='cuda:0')
Target tensor([[12,  8],
        [ 7,  5],
        [11,  9],
        [15,  3],
        [ 2,  2]], device='cuda:0')
Predicted tensor([[12,  6],
        [ 7,  0],
        [11,  1],
        [15,  0],
        [ 2,  1]], device='cuda:0')


In [5]:
data = SortedDatasetExtended(size=10, **args.__dict__)
data_red = partial(RedTeamingDataset, dataset_class=SortedDatasetExtended, toks_missed=toks_missed, num_flips=list(range(1, 4)))
args_red = deepcopy(args)
args_red.dataset = data_red
args_red.valset_size = 10_000

toks_red, target_red, logits_red = get_missed_data(args_red, model)
pred_red = logits_red.argmax(-1)
print(f'Toks missed per position {toks_red.shape[0]/args.num_end_pos: .0f} out of {args_red.valset_size}')
print(toks_red[:10, 1:-1])

Toks missed per position  1666 out of 10000
tensor([[ 5,  6, 16, 11, 14, 14, 21,  0,  0,  7, 10,  8, 19],
        [ 7,  4,  8,  2, 11, 18, 21,  0,  0,  8,  1,  7,  6],
        [ 4,  4,  1,  2, 10,  6, 21,  0,  0,  8,  1,  7,  4],
        [ 4,  4,  8,  0, 10, 18, 21,  0,  0,  8,  1, 11,  4],
        [ 5,  6, 16, 10, 14, 16, 21,  0,  0,  7, 10,  8, 19],
        [ 4,  4,  8,  2, 11, 18, 21,  0,  0,  1,  1,  7,  4],
        [ 4,  4,  8,  2, 10,  8, 21,  0,  0, 19,  1,  7,  4],
        [ 5,  6, 16, 11, 14, 19, 21,  0,  0,  2, 10,  8, 19],
        [ 5, 11, 16, 11, 14,  6, 21,  0,  0,  7, 10,  8, 19],
        [ 4,  4,  3,  2, 10, 18, 21,  0,  0,  8, 11,  7,  4]], device='cuda:0')


In [49]:
print('Target red teaming', target_red[:10].squeeze())
print('Pred red teaming', pred_red[:10].squeeze())

Target red teaming tensor([ 7,  9,  4,  4,  7,  1,  6,  7,  7, 15], device='cuda:0')
Pred red teaming tensor([19,  4,  0,  0, 10, 19, 10,  6,  8,  2], device='cuda:0')


In [42]:
diff_pred_target = (pred_red - target_red) % 6
diff_pred_target.unique(return_counts=True)

(tensor([1, 2, 3, 4, 5], device='cuda:0'),
 tensor([ 350, 1384,  499, 1342,   81], device='cuda:0'))

## Binary Addition

In [2]:
args = TrainArgs(
    dataset=partial(BinaryAdditionDataset, switch=True),
    d_vocab=6,
    d_vocab_out=3,
    n_ctx=24,
    seq_len=12,
    n_layers=3,
    num_end_pos=8,
    trainset_size=100_000,
    valset_size=100_000,
    epochs=15,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=128,
    d_head=32,
    n_heads=4,
    d_mlp=4*128,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = create_model(**args.__dict__)
model.load_state_dict(torch.load('models/binaryadd_ood_1000.pt'))

<All keys matched successfully>

In [3]:
toks_missed, target_missed, logits_missed = get_missed_data(args, model)
print(f'Datapoints missed {toks_missed.shape[0]} out of {args.valset_size}')
print('Input', toks_missed[:10, 1:-1])

Datapoints missed 0 out of 100000
Input tensor([], device='cuda:0', size=(0, 22), dtype=torch.int64)


In [None]:
data = BinaryAdditionDataset(size=10, **args.__dict__, switch=True)
data_red = partial(RedTeamingDataset, dataset_class=partial(BinaryAdditionDataset, switch=True), toks_missed=toks_missed.cpu(), num_flips=list(range(1, 4)))
args_red = deepcopy(args)
args_red.dataset = data_red
args_red.valset_size = 10_000

toks_red, target_red, logits_red = get_missed_data(args_red, model)
pred_red = logits_red.argmax(-1)
print(f'Toks missed per position {toks_red.shape[0]/args.num_end_pos: .0f} out of {args_red.valset_size}')
print(toks_red[:10])

## MultiBackdoor 

In [10]:
args = TrainArgs(
    dataset=KeyValDataset,
    d_vocab=13,
    d_vocab_out=10,
    n_ctx=19,
    seq_len=18,
    n_layers=4,
    num_end_pos=6,
    trainset_size=100_000,
    valset_size=100_000,
    epochs=30,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=256,
    d_head=64,
    n_heads=4,
    d_mlp=4*256,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = create_model(**args.__dict__)
model.load_state_dict(torch.load('models/new_keyval_backdoor_999.pt'))

<All keys matched successfully>

In [11]:
toks_missed, target_missed, logits_missed = get_missed_data(args, model)

# # Save adversarial examples
# with open('temp/adv_data_multi_backdoor.pkl', 'wb') as f:
#     pickle.dump((toks_missed, target_missed, logits_missed), f)

In [12]:
# toks_missed, target_missed, logits_missed = pickle.load(open('temp/adv_data_multi_backdoor.pkl', 'rb'))
print(f'Toks missed per position {toks_missed.shape[0]/args.num_end_pos:.0f} out of {args.valset_size}')
print('Input', toks_missed[:10])

Toks missed per position 30 out of 100000
Input tensor([[10,  6,  8,  6,  8,  8,  8,  6,  6,  6,  8,  8,  6, 11, 11, 11, 11, 11,
         11],
        [10,  4,  1,  3,  0,  4,  4,  4,  4,  4,  4,  4,  4, 11, 11, 11, 11, 11,
         11],
        [10,  3,  8,  3,  4,  4,  8,  6,  6,  6,  8,  4,  6, 11, 11, 11, 11, 11,
         11],
        [10,  4,  4,  6,  4,  4,  4,  4,  4,  4,  6,  0,  4, 11, 11, 11, 11, 11,
         11],
        [10,  4,  4,  4,  4,  4,  0,  0,  4,  4,  4,  4,  4, 11, 11, 11, 11, 11,
         11],
        [10,  6,  8,  6,  0,  0,  8,  5,  5,  0,  8,  0,  5, 11, 11, 11, 11, 11,
         11],
        [10,  4,  4,  4,  4,  4,  1,  2,  4,  4,  4,  4,  4, 11, 11, 11, 11, 11,
         11],
        [10,  4,  6,  0,  6,  0,  4,  4,  0,  6,  0,  5,  4, 11, 11, 11, 11, 11,
         11],
        [10,  4,  4,  4,  4,  1,  4,  8,  8,  4,  4,  4,  4, 11, 11, 11, 11, 11,
         11],
        [10,  3,  8,  3,  0,  0,  5,  4,  4,  3,  5,  0,  4, 11, 11, 11, 11, 11,
         11]], d

In [13]:
dataset_class = KeyValDataset(size=None, **args.__dict__)
target_group, target_group_dummy = dataset_class.compute_target_group(toks_missed[:, 1:13].cpu(), return_all_groups=True)
count_target_group = target_group.unique(return_counts=True)[1]
print('Number of groups per datapoint', target_group_dummy.sum(-1).unique(return_counts=True))
print('Group number', count_target_group)
print('Accuracy per group', [round(1 - acc.item(), 4) for acc in count_target_group[:5]*(30/1e6)] +
      [round(1 - count_target_group[5].item()*(30/(25*1e6)), 4)])

Number of groups per datapoint (tensor([1, 2]), tensor([156,  22]))
Group number tensor([  7,   2,   7,  22,   9, 131])
Accuracy per group [0.9998, 0.9999, 0.9998, 0.9993, 0.9997, 0.9998]


In [80]:
dataset_class = KeyValDataset(size=None, **args.__dict__)
target_group, target_group_dummy = dataset_class.compute_target_group(toks_missed[:, 1:13].cpu(), return_all_groups=True)
count_target_group = target_group.unique(return_counts=True)[1]
print('Number of groups per datapoint', target_group_dummy.sum(-1).unique(return_counts=True))
print('Group number', count_target_group)
print('Accuracy per group', [round(1 - acc.item(), 4) for acc in count_target_group[:5]*(30/1e6)] +
      [round(1 - count_target_group[5].item()*(30/(25*1e6)), 4)])

Number of groups per datapoint (tensor([1, 2]), tensor([1887,  255]))
Group number tensor([ 216,    1,   63,  367,  134, 1361])
Accuracy per group [0.9935, 1.0, 0.9981, 0.989, 0.996, 0.9984]


In [110]:
probs = torch.softmax(logits_missed, dim=-1)
preds = logits_missed.argmax(-1)
incorrect_mask = preds != target_missed
conf_on_incorrect = probs.gather(-1, preds[..., None]).squeeze()[incorrect_mask]
target_group_for_conf = target_group[:, None].masked_select(incorrect_mask.cpu())
print(target_group_for_conf.shape, conf_on_incorrect.shape)
mpu.hist(conf_on_incorrect, nbins=15, color=target_group_for_conf, barmode='group',
         histnorm='percent', labels=dict(value='Confidence'),
         title="Confidence on incorrect tokens for Multi-Backdoor (N=3800)", )

torch.Size([3850]) torch.Size([3850])


In [107]:
print('Avg number of incorrect tokens', conf_on_incorrect.shape[0]/target_group.shape[0])

Avg number of incorrect tokens 1.7973856209150327


In [23]:
from copy import deepcopy

data_red = partial(RedTeamingDataset, dataset=KeyValDataset, toks_missed=toks_missed, num_flips=list(range(1, 4)))
args_red = deepcopy(args)
args_red.dataset = data_red
args_red.valset_size = 100_000

toks_red, target_missed_red, logits_red = get_missed_data(args, model)
print(f'Toks missed per position {toks_red.shape[0]/args.num_end_pos: .0f} out of {args.valset_size}')
print(toks_red[:10, 1:13])

Toks missed per position  11 out of 100000
tensor([[4, 4, 5, 4, 4, 4, 4, 4, 4, 5, 4, 4],
        [6, 4, 6, 6, 4, 6, 6, 4, 6, 6, 4, 6],
        [0, 0, 5, 6, 6, 4, 4, 6, 6, 5, 4, 0],
        [4, 5, 4, 7, 7, 0, 7, 7, 4, 5, 7, 7],
        [4, 6, 5, 0, 0, 6, 6, 0, 0, 5, 4, 4],
        [9, 7, 9, 4, 4, 0, 1, 1, 9, 0, 4, 1],
        [0, 6, 0, 6, 5, 5, 5, 5, 6, 0, 4, 0],
        [4, 4, 4, 4, 4, 5, 3, 4, 5, 4, 4, 4],
        [4, 4, 0, 6, 4, 4, 4, 4, 6, 4, 4, 4],
        [5, 6, 4, 6, 4, 0, 0, 4, 6, 4, 6, 5]], device='cuda:0')


## Testing patching hypothesis on Sorting model  

I discarded several OOD challenges because I thought they could be solved by pathching the residual stream to simulate running on longer sequences. I'll test that for one of the sorting classifiers I trained

In [2]:
from copy import deepcopy
from transformer_lens.utils import get_act_name

state_dict = torch.load('models/new_sorting_ood_1000.pt')
model_orig = create_model(
    d_vocab=23,
    d_vocab_out=21,
    n_ctx=15,
    n_layers=2,
    d_model=128,
    d_head=32,
    n_heads=4,
    d_mlp=4*128,
    base_seed=42,
)

model_orig.load_state_dict(state_dict)
model_second = deepcopy(model_orig)

In [3]:
data = SortedDatasetExtended(size=100, d_vocab=23, d_vocab_out=21, n_ctx=15, seq_len=6).to(device)
print(data.toks[:3])
print(data.target[:3])

tensor([[20, 19, 11, 10,  3,  4,  4, 21,  1,  5,  0, 11, 11, 11, 21],
        [20, 16,  5, 13, 16, 18, 18, 21,  4,  9,  9, 12, 16, 14, 21],
        [20,  3,  3,  5,  5,  8,  6, 21,  9, 13, 18,  2,  1,  6, 21]],
       device='cuda:0')
tensor([[11,  1],
        [ 5, 15],
        [ 6,  3]], device='cuda:0')


In [4]:
logits, cache_full = model_orig.run_with_cache(data.toks)
pred = logits[:, [7, 14]].argmax(dim=-1)
acc = (pred == data.target).float().mean()
print(f'Accuracy: {acc:.4f}')

Accuracy: 1.0000


In [10]:
import circuitsvis as cv

batch_idx = 85
layer = 1
attn_pattern = einops.rearrange(cache_full.stack_activation('pattern'), 'l b h p q -> b (l h) p q')

cv.attention.attention_patterns(
    attention=attn_pattern[batch_idx],
    tokens=data.str_toks[batch_idx],
    attention_head_names=[f"H{layer}.{i}" for layer in range(model_orig.cfg.n_layers) for i in range(model_orig.cfg.n_heads)])

In [40]:
last_target = data.target[:, 1, None] # Add a dimension for position
comp_contr, comp_labels = cache_full.get_full_resid_decomposition(expand_neurons=False, return_labels=True)
logit_attr_plus = cache_full.logit_attrs(comp_contr, tokens=last_target, incorrect_tokens=(last_target + 1) % 21)
logit_attr_minus = cache_full.logit_attrs(comp_contr, tokens=last_target, incorrect_tokens=(last_target - 1) % 21)

# logit_attr_plus.shape
imshow(logit_attr_plus[..., -1], y=comp_labels)

In [12]:
pos_for_second_half = [0, 7, *range(8, 13), 14]
toks_first_half = data.toks[:, :8]
toks_second_half = data.toks[:, pos_for_second_half] # Include START and END tokens and remove last numeric token
print("Tokens second half for patching\n", toks_second_half[:3])

model_second.W_pos.data[:8] = model_orig.W_pos.data[pos_for_second_half] # Select pos embedding to match toks_second_half

Tokens second half for patching
 tensor([[20, 21,  1,  5,  0, 11, 11, 21],
        [20, 21,  4,  9,  9, 12, 16, 21],
        [20, 21,  9, 13, 18,  2,  1, 21]], device='cuda:0')


In [19]:
from transformer_lens.hook_points import HookPoint
from transformer_lens import ActivationCache

_, cache_first_half = model_orig.run_with_cache(toks_first_half, names_filter=lambda name: 'resid' in name)

def patch_resid_stream(resid: Float[Tensor, 'batch pos d_model'], hook: HookPoint,
                       cache: ActivationCache = cache_first_half):
    resid[:, 1] = cache[hook.name][:, -1]
    return resid

def last_token_accuracy(logits: Float[Tensor, 'batch pos vocab'], target: Int[Tensor, 'batch pos']) -> float:
    return (logits[:, -1].argmax(dim=-1) == target[:, -1]).float().mean()

In [21]:
logits_baseline = model_second(toks_second_half)
logits_patch = model_second.run_with_hooks(toks_second_half,
                                           fwd_hooks=[(get_act_name('resid_mid', 0), patch_resid_stream)])

print(f'Unpatched accuracy: {last_token_accuracy(logits_baseline, data.target)}')
print(f'Patched accuracy: {last_token_accuracy(logits_patch, data.target)}')

Unpatched accuracy: 0.7599999904632568
Patched accuracy: 0.7599999904632568


In [18]:
import circuitsvis as cv

_, cache_second = model_second.run_with_cache(toks_second_half)
batch_idx = 85
attn_pattern = einops.rearrange(cache_second.stack_activation('pattern'), 'l b h p q -> b (l h) p q')

cv.attention.attention_patterns(
    attention=attn_pattern[batch_idx],
    tokens=data.to_str_toks(toks_second_half)[batch_idx],
    attention_head_names=[f"H{layer}.{i}" for layer in range(model_orig.cfg.n_layers) for i in range(model_orig.cfg.n_heads)])