In [1]:
from huggingface_hub import hf_hub_download
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
import numpy as np

In [3]:
from transformer_lens import HookedTransformer
model_name = "EleutherAI/pythia-70m-deduped"

model = HookedTransformer.from_pretrained(model_name, device=device)

Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [4]:
# Downnload dataset
from datasets import Dataset, load_dataset
dataset_name = "JeanKaddour/minipile"
token_amount= 40
#TODO: change train[:1000] to train if you want whole dataset
# 100_000 datasets
# I think that we want to use the full 100_000 at some point...
# dataset = load_dataset(dataset_name, split="train[:100000]").map(
dataset = load_dataset(dataset_name, split="train[:10000]").map( # 1_000 to get started
    lambda x: model.tokenizer(x['text']),
    batched=True,
).filter(
    lambda x: len(x['input_ids']) > token_amount
).map(
    lambda x: {'input_ids': x['input_ids'][:token_amount]}
)
# TODO: we can maybe make this faster for the larger dataset?

Found cached dataset parquet (/home/lev/.cache/huggingface/datasets/JeanKaddour___parquet/JeanKaddour--minipile-0d7d2d1ff79d1d36/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /home/lev/.cache/huggingface/datasets/JeanKaddour___parquet/JeanKaddour--minipile-0d7d2d1ff79d1d36/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-ad2ecfd158f710eb.arrow
Loading cached processed dataset at /home/lev/.cache/huggingface/datasets/JeanKaddour___parquet/JeanKaddour--minipile-0d7d2d1ff79d1d36/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-d173f762c78db5b8.arrow
Loading cached processed dataset at /home/lev/.cache/huggingface/datasets/JeanKaddour___parquet/JeanKaddour--minipile-0d7d2d1ff79d1d36/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-bd49b6b2e3856158.arrow


In [5]:
setting = "mlp_out" # "residual", "mlp", "attention", "mlp_out

def get_cache_name_neurons(layer: int):
    if setting == "residual":
        cache_name = f"blocks.{layer}.hook_resid_post"
        neurons = model.cfg.d_model
    elif setting == "mlp":
        cache_name = f"blocks.{layer}.mlp.hook_post"
        neurons = model.cfg.d_mlp
    elif setting == "attention":
        cache_name = f"blocks.{layer}.hook_attn_out"
        neurons = model.cfg.d_model
    elif setting == "mlp_out":
        cache_name = f"blocks.{layer}.hook_mlp_out"
        neurons = model.cfg.d_model
    else:
        raise NotImplementedError
    return cache_name, neurons

In [6]:
n_layers = model.cfg.n_layers
model.cfg.d_model, n_layers

(512, 6)

# Get Dictionary Activations

In [7]:
# TODO: in chunks...
# Now we can use the model to get the activations
from torch.utils.data import DataLoader
from datasets import DatasetDict
from tqdm.auto import tqdm
from einops import rearrange
import math

# MAX_CHUNK_SIZE = 1_000

# TODO: move to a separate file or something
def get_activations(layer: int):
    datapoints = dataset.num_rows
    embedding_size = model.cfg.d_model
    activations_final = np.memmap(f'layer-{layer}.mymemmap', dtype='float32', mode='w+', shape=(datapoints, token_amount, embedding_size))
    batch_size = 32

    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        cache_name = get_cache_name_neurons(layer)[0]
        for i, batch in enumerate(tqdm(dl)):
            # print(batch)
            _, cache = model.run_with_cache(batch.to(device))
            # print("AA", cache[cache_name].shape)
            # batched_neuron_activations = rearrange(cache[cache_name], "b s n -> (b s) n" )

            real_batch_size = batch.shape[0]
            activations_final[i*batch_size:i*batch_size + real_batch_size, :, :] = cache[cache_name].cpu().numpy()
    return activations_final

model_activations = [get_activations(layer) for layer in range(n_layers)]

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]

  0%|          | 0/310 [00:00<?, ?it/s]

## Get activations for a specific feature and visualize them

In [8]:
layer = 0
model_activations[0].shape, model_activations[layer].reshape(-1, model_activations[layer].shape[-1]).shape

((9909, 40, 512), (396360, 512))

In [9]:
from interp_utils import *
import torch
import numpy as np

# Get the activations for the best dict features
def get_feature_datapoints_with_idx(feature_index, dictionary_activations, tokenizer, token_amount, dataset, k=10, setting="max"):
    if len(dictionary_activations.shape) == 3:
        best_feature_activations = dictionary_activations[:, :, feature_index].flatten()
    else:
        best_feature_activations = dictionary_activations
    # Sort the features by activation, get the indices
    if setting=="max":
        # TODO:! Urrr.... is this backwards? CHECK IF ::-1 is correct but I think that it is
        found_indices = np.argsort(best_feature_activations)[::-1][:k]
        # found_indices = np.argsort(best_feature_activations, descending=True)[:k]
    elif setting=="uniform":
        # min_value = torch.min(best_feature_activations)
        min_value = np.min(best_feature_activations)
        max_value = np.max(best_feature_activations)

        # Define the number of bins
        num_bins = k

        # Calculate the bin boundaries as linear interpolation between min and max
        bin_boundaries = np.linspace(min_value, max_value, num_bins + 1)

        # Assign each activation to its respective bin
        # TODO: hmm
        # np bucketize?
        # bins = torch.bucketize(best_feature_activations, bin_boundaries)
        bins = np.digitize(best_feature_activations, bin_boundaries)

        # Initialize a list to store the sampled indices
        sampled_indices = []

        # Sample from each bin
        for bin_idx in np.unique(bins):
            if(bin_idx==0): # Skip the first one. This is below the median
                continue
            # Get the indices corresponding to the current bin
            bin_indices = np.array(np.nonzero(bins == bin_idx)).squeeze(axis=0)
            # print(bin_indices.shape)
            
            # Randomly sample from the current bin
            sampled_indices.extend(np.random.choice(bin_indices, size=1, replace=False))

        # Convert the sampled indices to a PyTorch tensor & reverse order
        found_indices = np.flip(np.array(sampled_indices), axis=[0])
    else: # random
        # get nonzero indices
        nonzero_indices = torch.nonzero(best_feature_activations)[:, 0]
        # shuffle
        shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0])]
        found_indices = shuffled_indices[:k]
    num_datapoints = int(dictionary_activations.shape[0])
    datapoint_indices =[np.unravel_index(i, (num_datapoints, token_amount)) for i in found_indices]
    text_list = []
    full_text = []
    token_list = []
    full_token_list = []
    for md, s_ind in datapoint_indices:
        md = int(md)
        s_ind = int(s_ind)
        full_tok = torch.tensor(dataset[md]["input_ids"])
        full_text.append(tokenizer.decode(full_tok))
        tok = dataset[md]["input_ids"][:s_ind+1]
        text = tokenizer.decode(tok)
        text_list.append(text)
        token_list.append(tok)
        full_token_list.append(full_tok)
    return text_list, full_text, token_list, full_token_list, found_indices

## Baseline before looking at "deconstructive interference"

In [25]:
import interp_utils
import importlib
importlib.reload(interp_utils)

feature = 69
layer = 0

text_list, full_text, token_list, full_token_list, indices = get_feature_datapoints_with_idx(feature, model_activations[layer], model.tokenizer, token_amount, dataset, setting="uniform", k=100)
interp_utils.visualize_text(text_list, feature, model, None, layer=layer, setting="model")

## Looking at constructive interference

In [11]:
k = 5_000
neuron_index = 69

In [12]:

def get_relevant_other_neurons(neuron_index: int, layer=0, k=100, weight_cutoff=1.2): # TODO: check weight cutoff vis a vis using quantified models
	# Get the input data-points that most activate the neuron
	best_feature_activations = model_activations[layer][:, :, neuron_index]
	summed_along_sentence = best_feature_activations.sum(axis=1)
	print(summed_along_sentence.shape)
	# Find the input data-points that most activate the neuron
	found_indices = np.argsort(summed_along_sentence)[::-1][:k]

	def get_activated_neurons(layer: int):
		neurons = set()
		for i in found_indices:
			cutoff_n = model_activations[layer][i, :, :] > weight_cutoff
			_pos_nonzero, neuron_nonzero = np.nonzero(cutoff_n)
			# print("LEN", neuron_nonzero.shape)
			neurons.update(neuron_nonzero)
		return list(neurons)
	
	other_layer_neurons = []
	for i in range(n_layers):
		if i != layer:
			other_layer_neurons.append((i, get_activated_neurons(i)))
	return other_layer_neurons

other_neurons = get_relevant_other_neurons(neuron_index, layer=0, k=k, weight_cutoff=2)
len(other_neurons[4][1]), print([len(i[1]) for i in other_neurons])

(9909,)
[11, 64, 91, 157, 508]


(508, None)

In [13]:
go_to_n_layers = 2

def get_auxiliary_data(neuron_index: int, layer: int):
	# TODO: this should be a different function!! (UTILS)
	# Get the input data-points that most activate the neuron
	best_feature_activations = model_activations[layer][:, :, neuron_index]
	summed_along_sentence = best_feature_activations.sum(axis=1)
	# Find the input data-points that most activate the neuron
	found_indices = np.argsort(summed_along_sentence)[::-1][:k]
	total_other_neurons = 0
	for other_neur in other_neurons:
		other_layer, neurons = other_neur
		if other_layer <= go_to_n_layers:
			total_other_neurons += len(neurons)
	# sum([len(i[1]) for i in other_neurons])

	concatenated = np.zeros((len(found_indices), total_other_neurons))

	counter = 0
	for other_neur in other_neurons:
		other_layer, neurons = other_neur
		if other_layer <= go_to_n_layers:
			r = model_activations[other_layer][:, :, neurons][found_indices].sum(axis=1) # Sum over the entire sentence/ text input
			concatenated[:, counter:counter+len(neurons)] = r
			counter += len(neurons)
		
	return concatenated, found_indices


aux_data, datapoints_used = get_auxiliary_data(neuron_index, layer=0)
# TODO: CONSIDER ONLY USING THE CLOSER LAYERS...
print(aux_data.shape), print(datapoints_used.shape)

(5000, 75)
(5000,)


(None, None)

In [14]:
def run_gaussian_mixture_model():
	# TODO:
	pass

In [28]:
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA
from sklearn.decomposition import IncrementalPCA

# Cosine similarity function
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# KMeans with cosine similarity
def kmeans_cosine(X, n_clusters:int, iterations=100):
    k = n_clusters
    # Normalize input data
    X_normalized = normalize(X, axis=1)

    # Randomly initialize centroids
    n_samples, n_features = X_normalized.shape
    centroids = X_normalized[np.random.choice(n_samples, k, replace=False)]

    for iter in range(iterations):
        # Cluster assignment step
        clusters = [[] for _ in range(k)]
        for idx, x in enumerate(X_normalized):
            similarities = [cosine_similarity(x, centroid) for centroid in centroids]
            # similarities = [np.dot(x, centroid) for centroid in centroids]
            closest = np.argmax(similarities)
            clusters[closest].append(idx)

        # Update centroids
        # TODO: we maybe able to just **not use** PCA at all here.... slow it may be
        new_centroids = []
        for cluster in clusters:
            if cluster:  # Check if cluster is not empty
                new_centroid = np.mean(X_normalized[cluster], axis=0)
                new_centroids.append(new_centroid)
            else:
                new_centroids.append(np.random.rand(n_features))  # Reinitialize empty clusters

        new_centroids = np.array(new_centroids)
        if np.allclose(centroids, new_centroids):
            break
        centroids = new_centroids
        print("Done with iteration", iter)

    return centroids, clusters

# TODO: no function. Just on global so we can stop middway etc etc
# TODO: can we speed this up??? Maybe we use PCA
_, cluster_by_idx = kmeans_cosine(aux_data, iterations=400, n_clusters=30)

Done with iteration 0
Done with iteration 1
Done with iteration 2
Done with iteration 3
Done with iteration 4
Done with iteration 5
Done with iteration 6
Done with iteration 7
Done with iteration 8
Done with iteration 9
Done with iteration 10
Done with iteration 11
Done with iteration 12
Done with iteration 13
Done with iteration 14
Done with iteration 15
Done with iteration 16
Done with iteration 17
Done with iteration 18
Done with iteration 19
Done with iteration 20
Done with iteration 21
Done with iteration 22
Done with iteration 23
Done with iteration 24
Done with iteration 25
Done with iteration 26
Done with iteration 27
Done with iteration 28
Done with iteration 29
Done with iteration 30
Done with iteration 31
Done with iteration 32
Done with iteration 33
Done with iteration 34
Done with iteration 35
Done with iteration 36
Done with iteration 37
Done with iteration 38
Done with iteration 39
Done with iteration 40
Done with iteration 41
Done with iteration 42
Done with iteration 4

In [29]:
# TODO: GMMs
print([len(c) for c in cluster_by_idx])

[240, 45, 131, 111, 113, 274, 314, 163, 62, 195, 108, 229, 103, 358, 138, 138, 172, 202, 134, 95, 103, 43, 216, 219, 305, 190, 210, 81, 52, 256]


In [30]:
datapoints_used[4], len(dataset)

(1325, 9909)

In [33]:
cluster_idx = 20

cluster_inds = [int(datapoints_used[int(cluster_by_idx[cluster_idx][i])]) for i in range(len(cluster_by_idx[cluster_idx]))]
print(dataset[int(cluster_inds[0])]['text'][:100])

new_dataset = []
for i in cluster_inds:
	new_dataset.append(dataset[i])
# for i in range(len(cluster_by_idx[cluster_idx])):
# 	print(dataset[int(datapoints_used[int(cluster_by_idx[cluster_idx][i])])]['text'][:100])
# 	print("+++++++++++++++++++++++++++++++++++++++++++++++++")
# TODO: this now wrong?
activations = model_activations[layer][cluster_inds, :, :]
# .reshape(-1, model_activations[layer].shape[-1])[cluster, :], 1)
print(activations.shape)
text_list, full_text, token_list, full_token_list, indices = get_feature_datapoints_with_idx(neuron_index, activations, model.tokenizer, token_amount, new_dataset, setting="uniform", k=30)
interp_utils.visualize_text(text_list, feature, model, None, layer=layer, setting="model")
# TODO: maybe do everything on MLP side where we get only positive activations

# -*- coding: utf-8 -*-
#
#   pyhwp : hwp file format parser in python
#   Copyright (C) 2010-2015 m
(103, 40, 512)
