In [1]:
import pandas as pd
import plotly.express as px
import torch
from einops import rearrange
import torch
import tqdm

In [2]:
from importlib import reload
from mamba_lens import HookedMamba
model = HookedMamba.from_pretrained("state-spaces/mamba-370m", device='cuda')
torch.set_grad_enabled(False)

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda


<torch.autograd.grad_mode.set_grad_enabled at 0x7f6dbb95ba30>

In [3]:
def get_pad_token(tokenizer):
    return model.tokenizer.encode(model.tokenizer.pad_token, add_special_tokens=False)[0]

# given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
# this does what you'd want, it indexes them
# idk, see the test
def index_into(data, indices):
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()

    second_axis_index = indices.flatten()
    # now we can just index, and then view back to our original shape
    return data[first_axis_index, second_axis_index].view(num_data, num_per_data)
    

def eval(model, data, correct, incorrect, **kwargs):
        num_examples = correct.size()[0]
        logits = model(data, **kwargs)[:,-1]
        tops = torch.argsort(-logits, dim=1)
        pad = get_pad_token(tokenizer=model.tokenizer)
        prs = torch.nn.functional.softmax(logits, dim=1)
        prs[:,pad] = 0 # manually set pad pr to zero because sometimes we need to pad num correct or num incorrect
        #for i in range(tops.size()[0]):
        #    print(model.to_str_tokens([tops[i,0]]), tops[i,0], logits[i, tops[i,0]], prs[i, tops[i,0]])
        #    break
        # [n_data, n_correct]
        correct_prs = index_into(prs, correct)
        # [n_data, n_incorrect]
        incorrect_prs = index_into(prs, incorrect)
        # [n_data, 1]
        total_prs = correct_prs.sum(dim=1, keepdim=True)+incorrect_prs.sum(dim=1, keepdim=True)
        total_prs[total_prs == 0] = 1.0
        correct_prs /= total_prs
        incorrect_prs /= total_prs

        # [n_data, n_correct + n_incorrect]
        combined = torch.concatenate([correct_prs, incorrect_prs], dim=1)
        biggest = torch.argsort(-combined, dim=1)
        n_data, n_correct = correct.size()
        # if biggest pr is in the correct, we are correct, otherwise, we are not
        num_correct = torch.sum(biggest[:,0] < n_correct)
        # the sum(dim=1) is because we or of all the different possible probabilities by summing
        # then we'll just report the average
        return torch.mean(correct_prs.sum(dim=1)).item(), torch.mean(incorrect_prs.sum(dim=1)).item(), num_correct.item()/float(n_data)

def add_padding_answers(tokenizer, answers):
    longest_len = len(max(answers, key=lambda x: len(x)))
    padded_answers = []
    pad_token = get_pad_token(tokenizer=tokenizer)
    for answer in answers:
        padded_answers.append(answer + [pad_token]*(longest_len-len(answer)))
    return padded_answers

def get_batched_data(data):
    batched_data = []
    batched_correct = []
    batched_incorrect = []
    
    for i, (prompt, corrects, incorrects) in enumerate(data):
        if i < 3:
            print(prompt, corrects, incorrects)
        batched_data.append(torch.tensor(model.tokenizer.encode(prompt), device=model.cfg.device))
        batched_correct.append([model.tokenizer.encode(correct)[0] for correct in corrects])
        batched_incorrect.append([model.tokenizer.encode(incorrect)[0] for incorrect in incorrects])
    try:
        batched_data = torch.stack(batched_data)
        batched_correct = torch.tensor(add_padding_answers(tokenizer=model.tokenizer, answers=batched_correct))
        batched_incorrect = torch.tensor(add_padding_answers(tokenizer=model.tokenizer, answers=batched_incorrect))
    except RuntimeError:
        typical_len = len(batched_data[0])
        for s in batched_data:
            if not len(s) == typical_len:
                print(len(s), "is len of this, typical len is", typical_len, "for sequence", model.to_str_tokens(s))
        raise
    return batched_data, batched_correct, batched_incorrect

def bar_chart(data, x_labels, y_label, title, font_size=None):
    # it requires a pandas dict with the columns and rows named, annoying
    # by default rows and columns are named with ints so we relabel them accordingly
    renames = dict([(i, x_labels[i]) for i in range(len(x_labels))])
    ps = pd.DataFrame(data.cpu().numpy()).rename(renames, axis='rows').rename({0: y_label}, axis='columns')
    fig = px.bar(ps, y=y_label, x=x_labels, title=title)
    if not font_size is None:
        fig.update_layout(
          xaxis = dict(
            tickmode='array',
            tickvals = x_labels,
            ticktext = x_labels, 
            ),
           font=dict(size=font_size, color="black"))
        
        #fig.update_xaxes(title_font=dict(size=font_size))
    
    fig.show()

In [None]:
from docstring import docstring_prompt_generator_function
from importlib import reload
import test_data
reload(test_data)
from test_data import IOI_generator, BABA_TEMPLATES, greater_than_data_generator

out_acc = torch.zeros([model.cfg.n_layers], device=model.cfg.device)
out_prs_correct = torch.zeros([model.cfg.n_layers], device=model.cfg.device)
out_prs_incorrect = torch.zeros([model.cfg.n_layers], device=model.cfg.device)

num_examples = 100

seed = 27
valid_seed = 37
test_seed = 47

data_type = 'greater than'

if data_type == 'ioi':
    data = IOI_generator(templates=[BABA_TEMPLATES[0]], tokenizer=model.tokenizer, num_examples=num_examples, seed=seed)
    valid_data = IOI_generator(templates=[BABA_TEMPLATES[0]], tokenizer=model.tokenizer, num_examples=num_examples, seed=valid_seed)
    test_data = IOI_generator(templates=[BABA_TEMPLATES[0]], tokenizer=model.tokenizer, num_examples=num_examples, seed=test_seed)
elif data_type == 'docstring':
    data = docstring_prompt_generator_function(tokenizer=model.tokenizer, num_examples=num_examples, corrupt='random_answer', seed=seed)
    valid_data = docstring_prompt_generator_function(tokenizer=model.tokenizer, num_examples=num_examples, corrupt='random_answer', seed=valid_seed)
    test_data = docstring_prompt_generator_function(tokenizer=model.tokenizer, num_examples=num_examples, corrupt='random_answer', seed=test_seed)
elif data_type == 'greater than':
    data = greater_than_data_generator(tokenizer=model.tokenizer, num_examples=num_examples, seed=seed)
    valid_data = greater_than_data_generator(tokenizer=model.tokenizer, num_examples=num_examples, seed=valid_seed)
    test_data = greater_than_data_generator(tokenizer=model.tokenizer, num_examples=num_examples, seed=test_seed)

print("data")
batched_data, batched_correct, batched_incorrect = get_batched_data(data)
print("valid")
vbatched_data, vbatched_correct, vbatched_incorrect = get_batched_data(valid_data)
print("test")
tbatched_data, tbatched_correct, tbatched_incorrect = get_batched_data(test_data)

history = []
history_stats = []
layers_to_remove = []
while len(layers_to_remove) < model.cfg.n_layers:
    base_layers = list(range(model.cfg.n_layers))


    for layer in layers_to_remove:
        base_layers.remove(layer)
    history.append(list(base_layers))
    
    correct, incorrect, acc = eval(model, vbatched_data, vbatched_correct, vbatched_incorrect,
                                      only_use_these_layers=base_layers, fast_ssm=True, fast_conv=True)
    print(correct, incorrect, acc)
    history_stats.append((correct, incorrect, acc))
    print(base_layers)
    
    for i, start_layer in tqdm.tqdm(enumerate(base_layers)):
        #layers = list(range(start_layer, end_layer+1))
        layers = list(base_layers)
        layers.remove(start_layer)

        
        correct, incorrect, acc = eval(model, batched_data, batched_correct, batched_incorrect,
                                      only_use_these_layers=layers, fast_ssm=True, fast_conv=True)
        out_prs_correct[i] = correct
        out_prs_incorrect[i] = incorrect
        out_acc[i] = acc
        
    
    best_layer_to_remove = base_layers[torch.argsort(-out_acc[:len(base_layers)])[0]]
    print("removing layer", best_layer_to_remove)
    layer_names = [f'layer {x}' for x in base_layers]
    layers_to_remove.append(best_layer_to_remove)

history_stats = torch.tensor(history_stats)


data
nouns using ['accord', 'affair', 'agreement', 'appraisal', 'assaults', 'assessment', 'attack', 'attempts', 'campaign', 'case', 'challenge', 'chaos', 'clash', 'collaboration', 'coma', 'competition', 'confrontation', 'consequence', 'conspiracy', 'construction', 'consultation', 'contact', 'contract', 'convention', 'cooperation', 'custody', 'deal', 'decline', 'decrease', 'demonstrations', 'development', 'disagreement', 'disorder', 'dispute', 'domination', 'dynasty', 'effect', 'effort', 'employment', 'endeavor', 'engagement', 'epidemic', 'evaluation', 'exchange', 'existence', 'expansion', 'expedition', 'experiments', 'fall', 'fame', 'flights', 'friendship', 'growth', 'hardship', 'hostility', 'illness', 'impact', 'imprisonment', 'improvement', 'incarceration', 'increase', 'invasion', 'investigation', 'journey', 'kingdom', 'marriage', 'negotiation', 'obstruction', 'operation', 'order', 'outbreak', 'outcome', 'overhaul', 'plague', 'plan', 'practice', 'process', 'program', 'progress', 'pro

48it [00:12,  3.86it/s]


removing layer 1
0.9652735590934753 0.03472641110420227 1.0
[0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


47it [00:11,  3.93it/s]


removing layer 2
0.9260562062263489 0.07394378632307053 1.0
[0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


46it [00:11,  3.99it/s]


removing layer 3
0.9292219281196594 0.07077810913324356 1.0
[0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


45it [00:11,  4.08it/s]


removing layer 4
0.9292176365852356 0.07078230381011963 1.0
[0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


44it [00:10,  4.16it/s]


removing layer 5
0.925422191619873 0.0745777040719986 1.0
[0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


43it [00:10,  4.25it/s]


removing layer 6
0.9333218932151794 0.06667809933423996 1.0
[0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


42it [00:09,  4.34it/s]


removing layer 7
0.9566264748573303 0.04337349906563759 1.0
[0, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


41it [00:09,  4.43it/s]


removing layer 8
0.9570193290710449 0.04298064485192299 1.0
[0, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


40it [00:08,  4.53it/s]


removing layer 9
0.9584576487541199 0.041542358696460724 1.0
[0, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


39it [00:08,  4.63it/s]


removing layer 10
0.9539596438407898 0.04604038968682289 1.0
[0, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]


36it [00:07,  4.76it/s]

In [None]:
bar_chart(history_stats[:,0], x_labels=[str(x) for x in history], y_label='relative pr of correct', title=f"{data_type} pruning layers relative pr correct", font_size=4)
bar_chart(history_stats[:,1], x_labels=[str(x) for x in history], y_label='relative pr of incorrect', title=f"{data_type} pruning layers relative pr incorrect", font_size=4)
bar_chart(history_stats[:,2], x_labels=[str(x) for x in history], y_label='accuracy', title=f"{data_type} pruning layers accuracy", font_size=4)