## Setup

In [1]:
from IPython import get_ipython
ipython = get_ipython()
# get_ipython() returns None when no IPython shell is active (for example when
# this notebook is exported and run as a plain script). Autoreload is only a
# development convenience, so skip it in that case instead of failing with
# AttributeError on `None.run_line_magic`.
if ipython is not None:
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

import os;

# Disable accelerate's rich output.
os.environ["ACCELERATE_ENABLE_RICH"] = "0"

# `__vsc_ipynb_file__` is not a standard Python or Jupyter global (its name
# suggests it is injected by the VS Code notebook host), so look it up
# defensively: in any other front-end the direct `globals()[...]` lookup raised
# KeyError and aborted the whole setup cell. When the path is unknown the
# variable is simply left unset.
_notebook_path = globals().get("__vsc_ipynb_file__")
if _notebook_path:
    os.environ["WANDB_NOTEBOOK_NAME"] = os.path.basename(_notebook_path)

# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


import sys
from functools import partial
import json
from typing import List, Tuple, Union, Optional, Callable, Dict
import torch as t
from torch import Tensor
from sklearn.linear_model import LinearRegression
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import einops
from tqdm import tqdm
from jaxtyping import Float, Int, Bool
from pathlib import Path
import pandas as pd
import circuitsvis as cv
import webbrowser
from IPython.display import display
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from eindex import eindex


from sklearn.cluster import DBSCAN, HDBSCAN, OPTICS
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

import functools
import matplotlib.pyplot as plt

# t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
# Everything in the current working directory that precedes the chapter folder
# name; the chapter folder and its "exercises" sub-folder are re-appended to it.
_path_before_chapter = os.getcwd().split(chapter)[0]
exercises_dir = Path(f"{_path_before_chapter}/{chapter}/exercises").resolve()
section_dir = exercises_dir.joinpath("monthly_algorithmic_problems", "september23_sum")
# Register the exercises folder on the import path exactly once.
_exercises_entry = str(exercises_dir)
if _exercises_entry not in sys.path:
    sys.path.append(_exercises_entry)

from monthly_algorithmic_problems.september23_sum.model import create_model
from monthly_algorithmic_problems.september23_sum.training import train, TrainArgs
from monthly_algorithmic_problems.september23_sum.dataset import SumDataset,Pairs
from plotly_utils import hist, bar, imshow

# Everything is pinned to the CPU: this was run on a MacBook Air, where the MPS backend is flaky.
device = t.device("cpu") # to use a GPU when one is available: t.device("cuda" if t.cuda.is_available() else "cpu")


# True when this code runs as the main program (__name__ == "__main__"), False when it is imported.
MAIN = __name__ == "__main__"

## Dataset

-- Note: I rewrote the Dataset class so that it balances the no-carry, plain-carry and cascading-carry classes evenly. This puts a larger emphasis on the difficult cases relative to the way the dataset was originally written. With random additions, cascading carry is an edge case that is quite infrequent.

In [2]:
# Build a SumDataset with size=1000 and num_digits=4, then call .to(device) on it
# so it is placed on the same device as the model.
dataset = SumDataset(size=1000, num_digits=4).to(device)

## Transformer
-- Note: I removed weight decay (it is set to 0 below; the previous value was 0.001).

In [3]:
filename = section_dir / "sum_model_normal.pt" # note: this checkpoint was trained on a Mac in CPU mode (no CUDA, no MPS)

# Architecture hyperparameters shared by the training configuration and the
# model constructor. Defining them once guarantees that the freshly built model
# always has the shape of the checkpoint that is loaded into it (previously the
# two argument lists were maintained by hand and could drift apart).
model_hparams = dict(
    num_digits=4,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    d_mlp=None,
    normalization_type="LN",
)

# Training configuration. It is only consumed by the commented-out
# `train(args)` call at the bottom of this cell.
args = TrainArgs(
    trainset_size=100_000,
    valset_size=5_000,
    epochs=100,
    batch_size=512,
    lr_start=2e-3,
    lr_end=1e-4,
    # No weight decay; the previous value was 0.001 and could be added back in.
    weight_decay=0.00,
    seed=42,
    use_wandb=True,
    device=device,
    **model_hparams,
)

# The seed passed here only affects the initial weights, which are overwritten
# by the checkpoint below.
model = create_model(seed=0, device=device, **model_hparams)

# map_location keeps the load working when the checkpoint was saved on a
# different device than the one used here.
# NOTE: t.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(t.load(filename, map_location=device))

# To retrain from scratch and overwrite the checkpoint, uncomment:
# model = train(args)
# t.save(model.state_dict(), filename)

<All keys matched successfully>

## Cluster Code



In [4]:
from torch.utils.data import Dataset
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap


def _wh_hook(activation,hook,replacement_value,lh_word,head):

    if (lh_word is not None and head is not None) : 
        activation[:, lh_word, head, :] = replacement_value
    elif (lh_word is not None) :
        activation[:, lh_word, :]  = replacement_value
    else :
        activation = replacement_value

    return activation

def wh_hook( replacement_value, lh_word = None, head = None) :
    """
    Build a hook callable that substitutes `replacement_value` into an activation.

    The returned object is `_wh_hook` with `replacement_value`, `lh_word` and
    `head` already bound as keyword arguments, leaving only the
    `(activation, hook)` pair to be supplied by whoever invokes the hook.
    """
    bound_arguments = {
        "replacement_value": replacement_value,
        "lh_word": lh_word,
        "head": head,
    }
    return functools.partial(_wh_hook, **bound_arguments)

class ClusterState():
    """
    Cluster the activations of one transformer layer and test whether the cluster
    centroids can stand in for the real activations without changing predictions.

    Attributes:
    - model: HookedTransformer model.
    - dataset: The token tensor fed to the model (must expose `.shape`).
    - layers, heads, words, ndims: Model configuration (n_layers, n_heads, n_ctx, d_model).
    - pred_i: Sequence positions whose predictions are compared against the control run.
    - min_samples, min_cluster_size: Parameters passed to HDBSCAN.
    - layer: The specific layer of the model to analyze.
    - ctrl_logits, ctrl_cache, ctrl_preds: Logits, activation cache and argmax
      predictions of the unmodified (control) run.
    - filtered_words: Positions whose mean-ablation at `layer` changes a prediction.
    - results: Maps (word, activation_name, head) to
      (cluster_labels, n_clusters, silhouette_avg, err) for every clustering whose
      centroid substitution caused no prediction change.

    Methods:
    - __init__: Runs the control pass, filters positions and clusters every activation.
    - cluster_values: Clusters values and calculates error after substitution.
    - generate_test_data: Generates a list of activations that can be clustered.
    - filter_words: Determines the words that impact predictions.
    - calc_centroid: Replaces every clustered sample by the centroid of its cluster.
    - find_clusters: Finds clusters in the data.
    - substitute_centroids: Patches replacement values into the transformer and
      tabulates errors (`substitute_centoids` is kept as an alias for the old spelling).
    - errors: Calculates prediction errors.
    - interpret_cluster: Interprets and prints information about clusters.
    """

    def __init__(self, model, dataset, min_samples=5, min_cluster_size=5, pred_i : Optional[List[int]] = None , layer=1 ):
        """
        Run the control forward pass and the complete clustering analysis.

        Parameters:
        - model: HookedTransformer to analyze.
        - dataset: token tensor passed to the model as-is.
        - min_samples, min_cluster_size: forwarded to HDBSCAN.
        - pred_i: positions whose predictions are checked; defaults to [10, 11, 12, 13].
        - layer: index of the layer whose activations are clustered.
        """
        self.model   = model
        self.dataset = dataset
        self.layers  = model.cfg.n_layers
        self.heads   = model.cfg.n_heads
        self.words   = model.cfg.n_ctx
        self.ndims   = model.cfg.d_model
        # A None sentinel avoids sharing one mutable default list between instances.
        self.pred_i  = [10, 11, 12, 13] if pred_i is None else pred_i
        self.min_samples      = min_samples
        self.min_cluster_size = min_cluster_size
        self.layer = layer

        print(f"Dataset size={dataset.shape[0]} examples")
        print(f"Model has {self.layers} layers, {self.heads} heads, {self.words} words, {self.ndims} dimensions")

        # Control run: the reference predictions every substitution is compared with.
        model.reset_hooks()
        self.ctrl_logits,self.ctrl_cache  = model.run_with_cache(dataset)
        self.ctrl_preds = self.ctrl_logits.softmax(dim=-1).argmax(dim=-1)
        model.reset_hooks()

        self.filtered_words  = self.filter_words(layer)

        print(f"input words that do not impact predictions {[i for i in range(self.words) if i not in self.filtered_words]}")

        print(f"input words that do impact predictions {self.filtered_words}" )

        print(f"found clusters layer={layer}")

        self.results = {}
        for lh_word in self.filtered_words :
            print(f"word={lh_word}")
            for values, activation_name, head in self.generate_test_data(self.layer, lh_word=lh_word) :
                cluster_labels, n_clusters, silhouette_avg, err = self.cluster_values(
                    self.layer, lh_word, data=values, activation_name=activation_name, head=head
                )
                n_errors = err.sum().item()
                # Only clusterings that reproduce every control prediction are kept.
                if n_errors == 0 :
                    print( f"     head={head} activation={activation_name} clusters={n_clusters} silhouette_avg={silhouette_avg:.2f} err={n_errors}")
                    self.results[ ( lh_word, activation_name, head) ] = (cluster_labels, n_clusters, silhouette_avg, err)


    def cluster_values(self, layer, lh_word, data=None, activation_name=None , head=None) :
        """
        Cluster `data`, replace clustered samples by their centroid and measure the damage.

        Returns (cluster_labels, n_clusters, silhouette_avg, err) where `err` is the
        per-example boolean error mask produced by `substitute_centroids`.
        """
        cluster_labels, n_clusters, silhouette_avg = self.find_clusters(data)
        rep_values = self.calc_centroid(data, cluster_labels)
        err = self.substitute_centroids(layer, activation_name, rep_values, lh_word=lh_word, head=head)
        return cluster_labels, n_clusters, silhouette_avg, err

    def generate_test_data(self, layer, lh_word=None) :
        """
        Collect the control-run activations at position `lh_word` that will be clustered.

        Returns a list of (values, activation_name, head) tuples: q, k and v for every
        head, then the "normalized" ln1 activation for head index 0, then resid_pre
        (head is None).
        """
        test_data  = []
        for head in range(self.heads) :
            for name in ("q", "k", "v") :
                test_data.append( (self.ctrl_cache[name , layer][:, lh_word, head, :], name , head)  )
        # NOTE(review): the ln1 output is indexed with four indices, so it is assumed
        # to carry a head axis for this model; only head index 0 is analysed - confirm.
        test_data.append( (self.ctrl_cache["normalized" , layer, "ln1" ][:, lh_word, 0, :], "normalized" , 0)  )
        test_data.append( (self.ctrl_cache["resid_pre"  , layer][:, lh_word, :], "resid_pre" , None)  )
        return test_data


    def filter_words(self, layer, activation_name = "resid_pre") :
        """
        Return the positions whose mean-ablation changes at least one checked prediction.

        For every position the activation is replaced by its mean over the dataset;
        positions for which this causes no error are considered irrelevant.
        """
        filtered_words    = []
        for lh_word in range(self.words) :
            r_value = self.ctrl_cache[activation_name, layer ][:,lh_word,:].mean(0,keepdim=True)
            err = self.substitute_centroids( layer, activation_name , r_value, lh_word=lh_word )
            if (err.sum().item() > 0 ) :
                filtered_words.append( lh_word )

        return filtered_words

    def calc_centroid(self, data, cluster_labels) :
        """
        Return a copy of `data` in which every clustered row is replaced by the mean
        of its cluster. Rows labelled as noise (label < 0) keep their original value.

        The centroid is looked up per label rather than by indexing a stacked centroid
        table with the label: when a noise label (-1) is present, `unique()` puts the
        noise mean at row 0 of such a table, so each cluster k would receive the
        centroid of cluster k-1 (and cluster 0 the noise mean).
        """
        replacement_value = data.clone()
        for label in cluster_labels.unique().tolist() :
            if label < 0 :
                continue
            members = cluster_labels == label
            replacement_value[members] = data[members].mean(0)
        return replacement_value

    def find_clusters(self,data):
        """
        Cluster the rows of `data` with HDBSCAN.

        Returns (cluster_labels, n_clusters, silhouette_avg): the labels as a tensor on
        the same device as `data` (-1 marks noise), the number of clusters excluding
        noise, and the silhouette score (0 when there are fewer than two clusters).
        """
        # Remember where the data lives instead of relying on a global `device`;
        # detach/cpu make the numpy conversion work for any input tensor.
        source_device = data.device
        points = data.detach().cpu().numpy()
        clusterer = HDBSCAN( min_samples=self.min_samples,min_cluster_size=self.min_cluster_size)
        labels = clusterer.fit_predict(points)
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
        silhouette_avg = silhouette_score(points, labels) if n_clusters > 1 else 0
        cluster_labels = t.from_numpy(labels).to(source_device)
        return cluster_labels, n_clusters, silhouette_avg

    def substitute_centroids(self, layer, activation_name, replacement_value, lh_word=None, head=None) :
        """
        Patch `replacement_value` into the named activation and compare predictions.

        Returns a boolean tensor with one entry per example that is True when the
        prediction at any position in `self.pred_i` differs from the control run.
        """
        if (activation_name=="normalized") :
            hook_name = utils.get_act_name(activation_name, layer, "ln1")
        else :
            hook_name = utils.get_act_name(activation_name, layer)

        self.model.reset_hooks()
        self.model.add_hook( hook_name, wh_hook( replacement_value, lh_word, head)  )
        try :
            # Use the instance's own model (not a global) and a plain forward pass:
            # only the logits are needed, so no activation cache or autograd graph.
            with t.no_grad() :
                logits_ablate = self.model(self.dataset)
        finally :
            # Never leave the patching hook installed, even if the forward pass fails.
            self.model.reset_hooks()

        preds = logits_ablate.softmax(dim=-1).argmax(dim=-1)
        # Allocate on the predictions' device so the OR below works off-CPU too.
        err    = t.zeros( self.dataset.shape[0], dtype= t.bool, device=preds.device)
        for outcome_word in self.pred_i :
            err =  err | self.errors( preds, outcome_word)
        return err

    # Backward-compatible alias for the original (misspelled) method name.
    substitute_centoids = substitute_centroids

    def errors(self, preds, outcome_word) :
        """Per-example mask: True where `preds` differs from the control prediction at `outcome_word`."""
        return self.ctrl_preds[:,outcome_word] != preds[:,outcome_word]


    # Note: See the Pairs class in dataset.py. This prints out some information specific to the 4 digit addition problem
    def interpret_cluster(self,word,activation_name,head=None, max_examples=5, pairs_dataset=None) :
        """
        Print up to `max_examples` member examples of every cluster of a stored result.

        Parameters:
        - word, activation_name, head: key of the clustering in `self.results`.
        - max_examples: maximum number of examples printed per cluster.
        - pairs_dataset: object whose `.p` holds, per example, an index into `Pairs.p`;
          defaults to the module-level `dataset`.

        Raises:
        - KeyError: if no zero-error clustering was stored for the requested key.
        """
        key = ( word, activation_name, head)
        if key not in self.results :
            raise KeyError(
                f"no zero-error clustering stored for word={word} activation={activation_name} head={head}; "
                f"available keys: {list(self.results)}"
            )
        cluster_labels, n_clusters, silhouette_avg, err  = self.results[key]

        # The per-example pair ids are not stored on this object, so they still come
        # from a dataset object; the tokens are the ones that were actually analysed.
        pair_source = dataset if pairs_dataset is None else pairs_dataset

        print(f"interpreting clusters at layer={self.layer} word={word} activation={activation_name} head={head}")
        print(f"    clusters={n_clusters}, silhouette_avg={silhouette_avg:.2f} err={err.sum().item()}")
        for label in cluster_labels.unique().tolist() :

            members = cluster_labels == label
            ds = self.dataset[members]
            ps = pair_source.p[members]
            # Small clusters (e.g. the noise label -1) may hold fewer than max_examples rows.
            n_shown = min(max_examples, ps.shape[0])
            print("")
            print(f"        Cluster {label} - printing {n_shown} of {ps.shape[0]} examples")
            for row in range(n_shown) :
                p = Pairs.p[ps[row].item()]
                # NOTE(review): the value 2 appears to mark a carry and 1 a digit that
                # passes a carry along - confirm against Pairs in dataset.py.
                if p[1] == 2 or (p[1]==1 and p[2]==2) or (p[1]==1 and p[2]==1 and p[3]==2) :
                    print( "        [  carry ]" , ds[row,:].detach().cpu().numpy() )
                else :
                    print( "        [no carry]" , ds[row,:].detach().cpu().numpy()  )


In [5]:
# Run the whole analysis on the token tensor. The constructor itself performs the control run,
# finds the positions that affect predictions, clusters their activations (layer=1 by default)
# and prints every clustering whose centroid substitution leaves all checked predictions unchanged.
cluster_state = ClusterState(model,dataset.toks,min_samples=5,min_cluster_size=20 )




Dataset size=1000 examples
Model has 2 layers, 3 heads, 15 words, 48 dimensions
input words that do not impact predictions [0, 2, 3, 4, 8, 14]
input words that do impact predictions [1, 5, 6, 7, 9, 10, 11, 12, 13]
found clusters layer=1
word=1
     head=0 activation=q clusters=9 silhouette_avg=1.00 err=0
     head=0 activation=k clusters=9 silhouette_avg=1.00 err=0
     head=0 activation=v clusters=9 silhouette_avg=1.00 err=0
     head=1 activation=q clusters=9 silhouette_avg=1.00 err=0
     head=1 activation=k clusters=9 silhouette_avg=1.00 err=0
     head=1 activation=v clusters=9 silhouette_avg=1.00 err=0
     head=2 activation=q clusters=9 silhouette_avg=1.00 err=0
     head=2 activation=k clusters=9 silhouette_avg=1.00 err=0
     head=2 activation=v clusters=9 silhouette_avg=1.00 err=0
     head=0 activation=normalized clusters=9 silhouette_avg=1.00 err=0
     head=None activation=resid_pre clusters=9 silhouette_avg=1.00 err=0
word=5
     head=0 activation=q clusters=16 silhouette

In [6]:
# Print example inputs from each cluster stored for word (position) 10, activation "q", head 1.
cluster_state.interpret_cluster( 10,"q", 1)




interpreting clusters at layer=1 word=10 activationq head=1
    clusters=3, silhouette_avg=0.49 err=0

        Cluster 0 - printing 5 of 340 examples
        [no carry] [12  3  1  9  0 10  4  5  7  7 11  7  7  6  7]
        [no carry] [12  0  2  8  6 10  6  3  1  3 11  6  5  9  9]
        [no carry] [12  6  0  1  0 10  0  3  8  5 11  6  3  9  5]
        [no carry] [12  2  1  9  1 10  3  6  9  8 11  5  8  8  9]
        [no carry] [12  3  1  9  1 10  4  5  9  9 11  7  7  9  0]

        Cluster 1 - printing 5 of 320 examples
        [no carry] [12  2  4  5  7 10  6  5  0  2 11  8  9  5  9]
        [  carry ] [12  7  6  3  2 10  1  3  6  9 11  9  0  0  1]
        [  carry ] [12  0  8  5  4 10  3  1  8  8 11  4  0  4  2]
        [no carry] [12  0  2  1  6 10  7  7  4  5 11  7  9  6  1]
        [  carry ] [12  3  1  4  8 10  1  8  9  9 11  5  0  4  7]

        Cluster 2 - printing 5 of 340 examples
        [  carry ] [12  5  6  2  5 10  3  7  5  5 11  9  3  8  0]
        [  carry ] [12  0  7