In [1]:
%load_ext autoreload
%autoreload 2

In [25]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.36.2', '12.1')

In [3]:
from src.models import ModelandTokenizer

# Checkpoint under study. The alternatives are kept commented out so the
# notebook can be pointed at a different model by swapping one line.
# MODEL_PATH = "EleutherAI/gpt-j-6B"
# MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"
MODEL_PATH = "state-spaces/mamba-2.8b-slimpj"  # state-spaces/mamba-2.8b

# Bundle of model + tokenizer used by every cell below, loaded in float32.
mt = ModelandTokenizer(model_path=MODEL_PATH, torch_dtype=torch.float32)

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Cache the top promoted tokens for every `out_proj` column ("neuron").
# Only the three settings below are live; the caching loop itself is disabled.
# NOTE(review): later cells read `layer_{idx}.json` files from `path`, so the
# loop was presumably run once already -- confirm the files exist on disk.
# ---------------------------------------
cut_off_rank = 50  # passed as `k` to the (disabled) logit-lens call below
path = "neuron_prommotions/out_proj"  # cache directory; spelling must match the directory on disk
out_proj_path_format = "layers.{}.mixer.out_proj"  # module path template, filled with a layer index
# ---------------------------------------

# os.makedirs(path, exist_ok=True)

# neuron_promotions = {layer_idx: {} for layer_idx in range(mt.n_layer)}

# for layer_idx in tqdm(range(mt.n_layer)):
#     print(f"layer {layer_idx}")
#     out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer_idx))

#     for column in tqdm(range(out_proj.weight.shape[1])):
#         next_tok_candidates = functional.logit_lens(
#             mt = mt, 
#             h = out_proj.weight[:, column],
#             k = cut_off_rank
#         )

#         neuron_promotions[layer_idx][column] = [
#             {"token": tok, "logit": logit} for tok, logit in next_tok_candidates
#         ]
    
#     with open(os.path.join(path, f"layer_{layer_idx}.json"), "w") as f:
#         json.dump(neuron_promotions[layer_idx], f)

In [5]:
def cache_weights(mt):
    """Snapshot the `out_proj` weight matrix of every layer.

    Args:
        mt: model/tokenizer bundle exposing `model` and `n_layer`.

    Returns:
        dict mapping layer index -> detached copy of that layer's
        `out_proj.weight`, safe to keep while the live weights are edited.
    """
    return {
        layer: baukit.get_module(
            mt.model, out_proj_path_format.format(layer)
        ).weight.clone().detach()
        for layer in range(mt.n_layer)
    }

def restore_weights(mt, weights_cached):
    """Overwrite every layer's `out_proj` weight with a cached copy, in place.

    Args:
        mt: model/tokenizer bundle exposing `model` and `n_layer`.
        weights_cached: dict of layer index -> tensor, as returned by
            `cache_weights`.
    """
    # In-place writes to parameters must not be tracked by autograd.
    with torch.no_grad():
        for layer_idx in range(mt.n_layer):
            module = baukit.get_module(
                mt.model, out_proj_path_format.format(layer_idx)
            )
            module.weight[...] = weights_cached[layer_idx]

# Usage pattern:
# weights_cached = cache_weights(mt)
# restore_weights(mt, weights_cached)
            
####################################
# Snapshot of the `out_proj` weights taken before any cell edits them.
# Later cells pass this to `restore_weights` to undo their modifications.
WEIGHTS_CACHED = cache_weights(mt)
####################################

In [23]:
########################################################################
# Words defining the concept to search for. Alternative concept set:
# concepts = [
#     "doctor", "nurse", "therapist",
#     "healthcare", "medicine", "medical"
# ]
concepts = [
    "computer",
    "software",
    "engineer",
    "programmer",
    "developer",
    "hacker",
]
########################################################################

# Each word is wrapped in spaces before tokenizing, and only the id at
# position 0 of each row is kept.
# NOTE(review): this assumes the tokenizer adds no BOS token and pads on the
# right, so position 0 is the concept's first token -- the decoded output
# below is the sanity check.
spaced_concepts = [f" {concept} " for concept in concepts]
concept_start_token_ids = (
    mt.tokenizer(spaced_concepts, return_tensors="pt", padding=True)
    .input_ids[:, 0]
    .tolist()
)

[(token_id, mt.tokenizer.decode(token_id)) for token_id in concept_start_token_ids]

[(4382, ' computer'),
 (3694, ' software'),
 (16518, ' engineer'),
 (34513, ' programmer'),
 (13722, ' developer'),
 (48109, ' hacker')]

In [26]:
def load_neuron_promotions(path, layer_idx):
    """Read the cached promotions of one layer from `<path>/layer_<layer_idx>.json`.

    Args:
        path: directory containing the per-layer JSON files.
        layer_idx: index of the layer whose file should be loaded.

    Returns:
        dict with the JSON object's keys converted to `int` (JSON object keys
        are always strings) and the values left exactly as stored.
    """
    layer_file = os.path.join(path, f"layer_{layer_idx}.json")
    with open(layer_file) as fp:
        raw = json.load(fp)

    promotions = {}
    for key, value in raw.items():
        promotions[int(key)] = value
    return promotions

# Load the cached promotions of every layer, keyed by layer index.
neuron_promotions = {}
for layer_idx in range(mt.n_layer):
    neuron_promotions[layer_idx] = load_neuron_promotions(path, layer_idx)

In [27]:
# Collect, per layer, the `out_proj` columns ("neurons") whose top cached
# promoted tokens match one of the concept words.

from tqdm.auto import tqdm

cut_off_rank = 1  # how many of the cached top tokens to inspect per column
concept_drivers = {layer_idx: [] for layer_idx in range(mt.n_layer)}

for layer_idx in range(mt.n_layer):
    out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer_idx))
    n_columns = out_proj.weight.shape[1]

    found_neurons = []

    for column in range(n_columns):
        top_tokens = [
            entry["token"] for entry in neuron_promotions[layer_idx][column]
        ][:cut_off_rank]

        # A column counts once as soon as any (concept, token) pair matches.
        # Tokens shorter than 4 characters after stripping are never tested.
        is_driver = any(
            functional.is_nontrivial_prefix(prediction=token, target=concept)
            for concept in concepts
            for token in top_tokens
            if len(token.strip()) >= 4
        )
        if is_driver:
            concept_drivers[layer_idx].append(column)
            found_neurons.append(column)

    if found_neurons:
        print(
            f"found {len(found_neurons)} neurons in layer {layer_idx} > {found_neurons}"
        )

found 1 neurons in layer 0 > [162]
found 2 neurons in layer 1 > [909, 1673]
found 2 neurons in layer 3 > [1099, 3128]
found 3 neurons in layer 4 > [2397, 2500, 3044]
found 5 neurons in layer 5 > [1807, 2817, 3633, 3905, 4807]
found 3 neurons in layer 7 > [300, 668, 971]
found 3 neurons in layer 8 > [1438, 2096, 2157]
found 2 neurons in layer 10 > [3583, 3621]
found 1 neurons in layer 11 > [3978]
found 2 neurons in layer 12 > [277, 838]
found 1 neurons in layer 13 > [1099]
found 2 neurons in layer 14 > [2682, 3797]
found 1 neurons in layer 16 > [5026]
found 2 neurons in layer 17 > [39, 4357]
found 3 neurons in layer 19 > [986, 2131, 4470]
found 1 neurons in layer 22 > [250]
found 1 neurons in layer 23 > [4009]
found 2 neurons in layer 26 > [1789, 4829]
found 2 neurons in layer 29 > [1467, 3219]
found 1 neurons in layer 30 > [1806]
found 2 neurons in layer 31 > [717, 1426]
found 1 neurons in layer 33 > [2603]
found 2 neurons in layer 34 > [2975, 3543]
found 2 neurons in layer 36 > [3762,

In [28]:
# layer = 25
# neuron = 1799


# out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))
# logits = mt.lm_head(out_proj.weight[:, neuron])

# logit_values = logits.sort(descending=True).values.detach().cpu().numpy()[:30]
# logit_tokens = logits.sort(descending=True).indices.detach().cpu().numpy()[:30]

# logit_tokens = ['"{}"'.format(mt.tokenizer.decode([t])) for t in logit_tokens]

# from matplotlib import pyplot as plt
# plt.bar(range(len(logit_values)), logit_values)
# plt.xticks(range(len(logit_values)), logit_tokens, rotation=90)

# plt.show()

In [29]:
# ! Always start from the unmodified weights
restore_weights(mt, WEIGHTS_CACHED)

# Baseline: next-token prediction with the original weights.
# NOTE(review): `maybe_prefix_eos` presumably prepends an EOS token when the
# tokenizer calls for one -- confirm in `src.tokens`.
prompt = tokenization_utils.maybe_prefix_eos(mt.tokenizer, "Marina works as a")

functional.predict_next_token(mt=mt, prompt=prompt)

[[PredictedToken(token=' freelance', prob=0.019778123125433922),
  PredictedToken(token=' nurse', prob=0.017641225829720497),
  PredictedToken(token=' teacher', prob=0.014872241765260696),
  PredictedToken(token=' professional', prob=0.013643579557538033),
  PredictedToken(token=' reception', prob=0.0121234692633152)]]

In [30]:
magnify_scale = 100  # factor applied to each selected `out_proj` column

# ! Always start from the unmodified weights so the scaling is not compounded
restore_weights(mt, WEIGHTS_CACHED)

for layer, neurons in concept_drivers.items():
    if not neurons:
        continue  # nothing selected in this layer
    out_proj = baukit.get_module(mt.model, out_proj_path_format.format(layer))
    # In-place edit of a parameter; keep it out of autograd.
    with torch.no_grad():
        for neuron in neurons:
            out_proj.weight[:, neuron] *= magnify_scale

# Same prompt as the baseline cell, now with the concept columns amplified.
functional.predict_next_token(mt=mt, prompt=prompt)

[[PredictedToken(token=' Product', prob=0.21306917071342468),
  PredictedToken(token=' Engineering', prob=0.07157208025455475),
  PredictedToken(token=' Development', prob=0.04282478615641594),
  PredictedToken(token=' product', prob=0.04239372909069061),
  PredictedToken(token=' Group', prob=0.03424137085676193)]]

In [31]:
# Undo the magnification before generating, so this is the unmodified model.
restore_weights(mt, WEIGHTS_CACHED)

# NOTE(review): `topk=1` presumably makes decoding greedy -- confirm in
# `functional.mamba_generate`.
functional.mamba_generate(mt=mt, prompt=prompt, topk=1).generation

[' freelance photographer and videographer. She has been working']