# Single-Digit Addition Memorization

In [1]:
import os
import random

import torch
import numpy as np
from transformers import AutoTokenizer

In [2]:
import transformer_lens.patching as patching
from transformer_lens import HookedTransformer, ActivationCache

In [3]:
import transformer_lens.utils as utils

In [77]:
from functools import partial

In [4]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for deterministic
    behaviour.

    Args:
        seed (int): the seed applied to every generator.
    """
    random.seed(seed)
    # Only affects subprocesses started after this point: the running
    # interpreter reads PYTHONHASHSEED once at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (silently a no-op
    # when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick a different kernel from run to run, which
    # defeats the determinism requested above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False

In [5]:
def arith_probs(dig1: int, dig2: int, n=1000, sub=False):
    """Sample random two-operand addition or subtraction problems.

    Each operand is drawn uniformly with `random.randint` from
    [10 ** (digits - 1), 10 ** digits - 1], so it has exactly the requested
    number of digits and no leading zero. For a 1-digit operand that range
    is 1..9, i.e. 0 is never produced.

    Args:
        dig1 (int): number of digits of the first operand.
        dig2 (int): number of digits of the second operand.
        n (int): how many problems to generate.
        sub (bool): generate subtraction problems when true, addition
            problems otherwise.
    Returns:
        a list of n tuples (op1, op2, ans) with ans = op1 - op2 when `sub`
        is true and ans = op1 + op2 otherwise. `ans` can be negative for
        subtraction.
    """
    def draw_operand(num_digits):
        return random.randint(10 ** (num_digits - 1), 10 ** num_digits - 1)

    problems = []
    for _ in range(n):
        # The first operand is always drawn before the second, so the stream
        # of random numbers consumed is fixed for a given seed.
        left = draw_operand(dig1)
        right = draw_operand(dig2)
        if sub:
            result = left - right
        else:
            result = left + right
        problems.append((left, right, result))
    return problems

In [18]:
def fewshot_probs(probs, sub=False, k=2):
    """Render each problem as a few-shot prompt string.

    Every prompt consists of k solved example lines followed by the target
    problem, itself written out together with its answer, all joined by
    newlines. A line has the form "a + b = ans" (or "a - b = ans").

    Args:
        probs (List[Tuple[int, int, int]]): problems as (a, b, ans) tuples.
        sub (bool): render the operator as '-' when true, '+' otherwise.
        k (int): number of in-context examples placed before the target.
    Returns:
        a list of prompt strings, one per entry of `probs`, in the same
        order. Nothing is tokenized here.
    """
    operator = '-' if sub else '+'

    def render(problem):
        return f"{problem[0]} {operator} {problem[1]} = {problem[2]}"

    prompts = []
    for target in probs:
        # NOTE(review): the examples are sampled from the full list, so the
        # target problem itself can show up among its own examples - confirm
        # that this is acceptable for the experiment.
        shots = random.sample(probs, k=k)
        context = "\n".join(render(shot) for shot in shots)
        prompts.append(context + "\n" + render(target))
    return prompts

In [7]:
# Fix all random seeds before any problems are sampled below.
seed_everything(42)

In [8]:
# Tokenizer used to build the padded batch; prompts are padded on the left.
# NOTE(review): the tokenizer comes from "google/gemma-2-2b" while the model
# loaded below is "gemma-2b" - confirm both checkpoints share one vocabulary.
toks = AutoTokenizer.from_pretrained("google/gemma-2-2b", padding_side="left")
# Reuse the BOS token as the padding token.
toks.pad_token = toks.bos_token

In [9]:
# Model under study, wrapped by TransformerLens so hooks can be attached.
m = HookedTransformer.from_pretrained("gemma-2b")

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gemma-2b into HookedTransformer


In [19]:
# Sample 100 addition problems with one-digit operands and wrap each one in
# a few-shot prompt.
# NOTE(review): the second assignment rebinds the name `fewshot_probs` from
# the function to the list it returns, so running this cell a second time
# calls a list and fails. Later cells read `fewshot_probs` as the list, so a
# rename has to be applied to them as well.
probs = arith_probs(1,1,n=100)
fewshot_probs = fewshot_probs(probs)

In [112]:
def answer_logit_indices(tokenized_input, problems):
    """Locate the tokens that spell the answer at the end of each prompt.

    The answer is taken to start right after the last space of the prompt
    string. Tokens are scanned from the end of the row backwards and are
    collected for as long as their start character offset is at or beyond
    that position.

    Args:
        tokenized_input: mapping with "offset_mapping" (per row, one
            (start, end) character span per token) and "input_ids" (per row,
            token ids whose elements support `.item()`).
        problems: the prompt strings, in the same row order.
    Returns:
        (indices, answer_indices): two lists with one list per prompt. The
        first holds the token positions of the answer, the second the token
        ids found at those positions. Both are ordered from the last token
        backwards.
    """
    # NOTE(review): these are the positions of the answer tokens themselves;
    # whether the logits at these positions or at the preceding ones are
    # wanted downstream should be confirmed.
    indices, answer_indices = [], []
    for row, text in enumerate(problems):
        answer_start = text.rfind(" ") + 1
        offsets = tokenized_input["offset_mapping"][row]
        positions, token_ids = [], []
        for pos in reversed(range(len(offsets))):
            # Stop at the first token that begins before the answer.
            if offsets[pos][0] < answer_start:
                break
            positions.append(pos)
            token_ids.append(tokenized_input["input_ids"][row][pos].item())
        indices.append(positions)
        answer_indices.append(token_ids)
    return indices, answer_indices

In [113]:
tokenized_input = toks(fewshot_probs, return_offsets_mapping=True, padding=True, return_tensors="pt")
idxs, ans_idx = answer_logit_indices(tokenized_input, fewshot_probs)

In [114]:
# pad both idx and ans_idxs with zeros
maxlen = lambda x: max(x, key=lambda y: len(y))
max_idx_len, ans_idx_len = maxlen(idxs), maxlen(ans_idxs)

In [115]:
mil = len(max_idx_len)
ail = len(ans_idxs_len)
for i in range(len(idxs)):
    if len(idxs[i]) < mil:
        idxs[i] = [0] * (mil - len(idxs[i])) + idxs[i]
    if len(ans_idx[i]) < ail:
        ans_idx[i] = [0] * (ail - len(ans_idx[i])) + ans_idx[i]

In [116]:
idxs, ans_idx = torch.LongTensor(idxs), torch.LongTensor(ans_idx)

In [124]:
# Exploratory shape check. Indexing dimension 1 with the 2-D index tensor
# `idxs` pairs every batch row with every row of `idxs`, which is why the
# recorded shape below has two leading dimensions of 100 instead of one.
# NOTE(review): `clogits_noise` is not defined by any cell visible in this
# notebook; this cell only ran because of leftover kernel state.
clogits_noise[:,idxs,:].shape

torch.Size([100, 100, 2, 256000])

In [47]:
def corr_hook_noise(clean, hook):
    """Forward hook that corrupts an activation with Gaussian noise.

    The noise is zero-mean with a standard deviation of three times the
    standard deviation of `clean`, taken over all of its elements. It is
    sampled on the CPU and then moved to the device of `clean`.

    Args:
        clean: the activation tensor to corrupt.
        hook: the hook point supplied by the caller; not used.
    Returns:
        a new tensor equal to `clean` plus the sampled noise.
    """
    sigma = torch.std(clean).to("cpu") * 3
    mean = torch.zeros(clean.shape)
    spread = sigma * torch.ones(clean.shape)
    noise = torch.normal(mean, spread)
    return clean + noise.to(clean.device)

In [52]:
# Clean (uncorrupted) forward pass, keeping the cached activations.
# This is inference only, so autograd is switched off: otherwise the graph
# for the whole batch stays alive next to the cache, and the recorded run of
# this cell ended in a CUDA out-of-memory error.
with torch.no_grad():
    clean_logits, clean_cache = m.run_with_cache(tokenized_input["input_ids"])
# Park the results on the CPU so the GPU is free for the next forward pass.
clean_logits = clean_logits.to("cpu")
clean_cache = clean_cache.to("cpu")

OutOfMemoryError: CUDA out of memory. Tried to allocate 170.00 MiB. GPU 0 has a total capacity of 44.34 GiB of which 138.81 MiB is free. Including non-PyTorch memory, this process has 44.20 GiB memory in use. Of the allocated memory 43.25 GiB is allocated by PyTorch, and 653.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [50]:
# Corrupted forward pass: Gaussian noise is added at the "hook_embed" hook
# point while the hook is attached, and the activations are cached.
# Autograd is disabled because only the forward results are needed.
with torch.no_grad(), m.hooks(fwd_hooks=[("hook_embed", corr_hook_noise)]):
    noise_logits, noise_cache = m.run_with_cache(tokenized_input["input_ids"])
    # Fixed: these two lines read the undefined names `clogits_noise` and
    # `ccache_noise`; the results of this run are what must move to the CPU.
    noise_logits = noise_logits.to("cpu")
    noise_cache = noise_cache.to("cpu")

In [143]:
def logit_diff(patched_logits, clean_logits, corrupted_logits, answer_token_idxs=None, answer_token_ids=None):
    """Average normalized recovery of the answer logits after patching.

    For every answer slot the logit of the answer token is read from the
    patched, clean and corrupted runs, and the recovery is computed as
    (patched - corrupted) / (clean - corrupted). The result is the mean over
    all real (non-padding) slots.

    Args:
        patched_logits: logits of the patched run, indexed as
            [batch, position, vocab].
        clean_logits: logits of the clean run, same layout.
        corrupted_logits: logits of the corrupted run, same layout.
        answer_token_idxs: LongTensor [batch, k] of token positions, where 0
            marks a padding slot. Defaults to the notebook-level `idxs`,
            looked up when the function is called.
        answer_token_ids: LongTensor [batch, k] of vocabulary ids, aligned
            with `answer_token_idxs`. Defaults to the notebook-level
            `ans_idx`, looked up when the function is called.
    Returns:
        a 0-dim tensor on the device of `patched_logits`.
    """
    # Resolve the defaults at call time so that re-running the cells which
    # build `idxs` / `ans_idx` is picked up without redefining this function.
    if answer_token_idxs is None:
        answer_token_idxs = idxs
    if answer_token_ids is None:
        answer_token_ids = ans_idx

    device = patched_logits.device
    rows = torch.arange(answer_token_idxs.shape[0]).unsqueeze(1).expand_as(answer_token_idxs)

    def _answer_logits(logits):
        # Index each logit tensor on its own device, then bring the small
        # gathered result to a common device; the three runs need not live
        # on the same one.
        where = logits.device
        picked = logits[rows.to(where), answer_token_idxs.to(where), answer_token_ids.to(where)]
        return picked.to(device)

    clean_ans_logits = _answer_logits(clean_logits)
    corr_ans_logits = _answer_logits(corrupted_logits)
    pat_ans_logits = _answer_logits(patched_logits)
    pert_change = (pat_ans_logits - corr_ans_logits) / (clean_ans_logits - corr_ans_logits)

    # Select the real slots instead of multiplying by the mask: a padded slot
    # can evaluate to NaN or inf (e.g. 0 / 0), and NaN * 0 is still NaN, which
    # would poison the whole sum.
    mask = (answer_token_idxs != 0).to(device)
    return pert_change[mask].mean()

In [None]:
metric = partial(logit_diff(