## Setup

In [1]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os;

os.environ["ACCELERATE_ENABLE_RICH"] = "0"

os.environ['WANDB_NOTEBOOK_NAME'] = os.path.basename(globals()['__vsc_ipynb_file__']) 

# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


import sys
from functools import partial
import json
from typing import List, Tuple, Union, Optional, Callable, Dict
import torch as t
from torch import Tensor
from sklearn.linear_model import LinearRegression
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import einops
from tqdm import tqdm
from jaxtyping import Float, Int, Bool
from pathlib import Path
import pandas as pd
import circuitsvis as cv
import webbrowser
from IPython.display import display
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from eindex import eindex


from sklearn.cluster import DBSCAN, HDBSCAN, OPTICS
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

import functools
import matplotlib.pyplot as plt

# t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "september23_sum"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.september23_sum.model import create_model
from monthly_algorithmic_problems.september23_sum.training import train, TrainArgs
from monthly_algorithmic_problems.september23_sum.dataset import SumDataset,Pairs
from plotly_utils import hist, bar, imshow

# Running this on a macbook air and mps is flaky
device = t.device("cpu") #t.device("cuda" if t.cuda.is_available() else "cpu")

import faiss



MAIN = __name__ == "__main__"

## Dataset

-- Note: I rewrote the Dataset class so that it balances no carry, plain carry and cascading carry  classes evenly. this puts a larger emphasis difficult cases relative to the original way this dataset was written. With random additions, cascading carry is an edge case that is quite infrequent.  

In [2]:
dataset = SumDataset(size=3000, num_digits=4).to(device)

## Transformer
-- Note: I removed weight decay

In [3]:
filename = section_dir / "sum_model_normal.pt" # note this was trained on a mac in cpu mode without cuda and mps
args = TrainArgs(
    num_digits=4,
    trainset_size=100_000,
    valset_size=5_000,
    epochs=100,
    batch_size=512,
    lr_start=2e-3,
    lr_end=1e-4,
    # weight_decay=0.001, # not weight decay, could add this back in 
    weight_decay=0.00,
    seed=42,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    d_mlp=None,
    normalization_type="LN",
    use_wandb=True,
    device=device,
)
model = create_model(
    num_digits=4,
    seed=0,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    normalization_type="LN",
    d_mlp=None,
    device=device
)
model.load_state_dict(t.load(filename))

# model = train(args)
# t.save(model.state_dict(), filename)

<All keys matched successfully>

## Cluster Code



In [18]:
from torch.utils.data import Dataset
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from sklearn.metrics import silhouette_samples
from functools import partial
import random



class ClusterState():
    def __init__(self, model, dataset, min_samples=5, min_cluster_size=5, pred_i : List[int] = [10,11,12,13]  ):
        self.model   = model
        self.dataset = dataset
        self.layers  = model.cfg.n_layers
        self.heads   = model.cfg.n_heads
        self.words   = model.cfg.n_ctx
        self.ndims   = model.cfg.d_model
        self.attn_only = model.cfg.attn_only
        self.pred_i  = pred_i
        self.result = {}

        print(f"Dataset size={dataset.toks.shape[0]} examples")
        print(f"Model has {self.layers} layers, {self.heads} heads, {self.words} words, {self.ndims} dimensions")

        model.reset_hooks()
        self.ctrl_logits,self.ctrl_cache  = model.run_with_cache(dataset.toks)
        self.ctrl_preds = self.ctrl_logits.softmax(dim=-1).argmax(dim=-1)
        model.reset_hooks()

        self.filtered_words = []
        for layer in range(self.layers) :
            self.filtered_words.append( self.filter_words(layer) )
        self.filtered_words.append( self.pred_i )

        self.index = self.step_through_calc( self.create_index )

    
    def filter_words(self, layer, activation_name = "resid_pre") :
        return [word for word in range(self.words) if self.substitute_centoids(activation_name, layer, word=word, mean=True).sum().item() > 0]

    def step_through_calc(self, method ) :

        def update_index(index, activation_name, layer, head, word) :
            x = (activation_name, layer, head, word)
            index[x] = method(*x)    

        index = {}
        for layer in range(self.layers):
            for word in self.filtered_words[layer+1]  :
                for head in range(self.heads) :
                    update_index(index, "pattern-q", layer, head, word)
                    update_index(index, "result"   , layer, head, word)
            for word in self.filtered_words[layer+1]  :
                update_index(index, "attn_out", layer, None, word )
        return index


    def create_index(self, *args) :
        embeddings = self.embeddings(*args)
        index = faiss.IndexFlatL2(embeddings.shape[1])  # Use IndexFlatL2 for L2 distance
        index.add(embeddings.float())
        return index
    
    def search_index(self, search_dataset , top_k):
        model.reset_hooks()
        logits, cache  = model.run_with_cache(search_dataset)
        model.reset_hooks()
        def _search_index(*args):
            points = self.embeddings(*args, cache=cache)
            distances, indices = self.index[args].search( points , top_k )
            return indices.squeeze()
        return _search_index

    def pattern_words(self, pattern, threshold=.1 ) :
        return t.where(  pattern.max(dim=0).values > threshold )[0].tolist()
    

    def all_the_same(self, items):
        if not isinstance(items, t.Tensor):
            items = t.tensor(items)
        if items.eq(items[0]).all() :
            return True , f"{items[0].item()}"
        else :
            return False, items.detach().cpu().numpy()


    def d_apn(self, anchor, positive_indices, negative_indices, embeddings, top_k=10, n_rand=100, alpha=5.0):

        # compute_distance = lambda anchor, points : t.sqrt(t.sum((points - anchor) ** 2, dim=1))

        def compute_distance(anchor, points):
            anchor = anchor.unsqueeze(0)
            print( anchor.shape, points.shape , (points-anchor).shape, t.sqrt(t.sum((points - anchor) ** 2, dim=1)).mean().item() )

            return t.sqrt(t.sum((points - anchor) ** 2, dim=1)).mean().item()

        triplet_loss = lambda d_p, d_n :  ( d_p - d_n ) / d_n

        d_p  = compute_distance(embeddings[anchor], embeddings[positive_indices])
  

        if type(negative_indices) != np.ndarray  or len(negative_indices) == 0 :
            indices = random.sample(range(embeddings.size(0)), 2 * n_rand)
            d_mean = compute_distance( embeddings[indices[:n_rand]], embeddings[indices[n_rand:]] )
            return triplet_loss(d_p ,d_mean )
        
        d_n  = compute_distance(embeddings[anchor], embeddings[negative_indices])

        print( f"{d_p} {d_n} {d_p - d_n} {triplet_loss(d_p , d_n )}" ) 

        return triplet_loss(d_p , d_n )

    def n_matches(self, anchor, points, word, pos=True, top_k=None) :
        tok     = self.dataset.toks[anchor,word]
        toks    = self.dataset.toks[points,word]

        matched = t.where ((toks == tok) if pos else (toks != tok) )[0]

        if top_k is not None :
            matched = matched[:top_k]
            if (not pos and matched.shape[0] == 0 ) :
                return points[-top_k:]

        return points[matched]

    def contrastive_loss(self, results , top_k=10) :

        pos_points, neg_points = {} , {}
        for word in self.pred_i :
            close_points = results[("attn_out", self.layers-1, None, word )]
            anchor       = close_points[0]
            pos_points[word] = self.n_matches( anchor, close_points, word + 1, pos=True, top_k=top_k ) 
            neg_points[word] = self.n_matches( anchor, close_points, word + 1, pos=False, top_k=top_k  )


        def _contrastive_loss(*args) :
            losses = {}
            close_points = results[args]
            for word in self.pred_i :
                anchor = close_points[0]
                # negative_indices = self.n_matches( anchor, close_points[:20], word + 1, pos=False )
                # losses[word]     = self.d_apn( anchor, pos_points[word], negative_indices, self.embeddings(*args) ).item()
                losses[word]    = self.d_apn( anchor, pos_points[word], neg_points[word], self.embeddings(*args) )

            if args ==( "attn_out", 0, None, 1 )  :
                print(f"{args} {losses[10]:.2f} {losses[11]:.2f} {losses[12]:.2f} {losses[13]:.2f}")    
            return losses
        return _contrastive_loss

    
    def find_similatities(self, results) :

        def _similatities(*args) :            
            close_points = results[args][:10]

            scp  = [ [self.dataset.vocab[tok] for tok in self.dataset[point] ] for point in close_points ]

            words = [1,2,3,4,6,7,8,9,11,12,13,14]

            explanation = []


            #group 3 sum is correct
            #group 2 only carry is correct
            #group 1 digit is corrects
            #group 0 nothing is correct
            # "0": {
            #   "color": "#f3f3f3" nothing
            # },
            # "1": {
            #   "color": "#c5e898" digit
            # },
            # "2": {
            #   "color": "#29adb2" carry
            # },
            # "3": {
            #   "color": "#5fbdff" sum
            # }

            for word in words :
                print_it, result  = self.all_the_same(self.dataset[close_points,word])
                if (word>=11) :
                    color = "#5fbdff"
                else :
                    color = "#c5e898"

                if (print_it) :
                    for point in range(len(scp)) :
                        scp[point][word] = ( scp[point][word] , "", color )
                    explanation.append( f"{self.dataset.format_word(word)} = {result}" )
            for i, word in enumerate(self.pred_i) :
                w = word +1
                if i > 0 :
                    carry = [ Pairs.is_carry( self.dataset.p[point].item() , i  ) for point in close_points ]
                    print_it, result  = self.all_the_same(carry)
                    if (print_it) :
                        if ( len(scp[0][w]) == 3 ) :
                            for point in range(len(scp)) :
                                scp[point][w] = ( scp[point][w][0] , f"carry {result}" , scp[point][w][2] , "border = 2x dashed black") 
                        else:
                            for point in range(len(scp)) :      
                                scp[point][w] = ( scp[point][w]    , f"carry {result}" , "#29adb2"        , "border = 2x dashed black") 

                        explanation.append( f"{self.dataset.format_word(w)} carry {result}" )


            return ( '\n'.join(explanation) , scp )

        return  _similatities          

    # def dependency(self) :
    #     def _dependency(*args) :            
    #         dependency = []
    #         if ( args[0] == "attn_out" ) :
    #             for head in range(self.heads) :
    #                 dependency.append( ("result", args[1], head, args[3]) )
    #         elif ( args[0] == "result" ) :
    #             dependency.append( ("pattern-q", args[1], args[2], args[3]) )
    #         elif ( args[0] == "pattern-q" and args[1]>0 ) :
    #             embed   = self.embeddings(*args) 
    #             pwords  = self.pattern_words( embed )
    #             fwords  = self.filtered_words[args[1]]
    #             intersection = sorted(list(set(pwords) & set(fwords)))
    #             for i in intersection :
    #                 dependency.append( ("attn_out", args[1]-1, None, i) ) 
    #         return dependency
    #     return _dependency


    def dependency(self) :
        def _dependency(*args) :            
            dependency = []
            if   ( args[0] == "attn_out" ) :
                for head in range(self.heads) :
                    dependency.append( ("result", args[1], head, args[3]) )
            elif ( args[0] == "result" and args[1]>0 ) :
                args    = ("pattern-q", args[1], args[2], args[3])
                embed   = self.embeddings(*args) 
                pwords  = self.pattern_words( embed )
                fwords  = self.filtered_words[args[1]]
                intersection = sorted(list(set(pwords) & set(fwords)))
                for i in intersection :
                    dependency.append( ("attn_out", args[1]-1, None, i) ) 
            return dependency
        return _dependency
    
    def embeddings(self, activation_name, layer, head, word, cache=None):
        if (cache is None) :
            cache = self.ctrl_cache 

        an  = activation_name.split("-", 1)
        if (len(an) == 2) :
            activation = cache[ an[0], layer, an[1] ]
        else :
            activation = cache[ an[0], layer]
            
        if (an[0] == "pattern" or an[0] ==  "attn_scores") :
            if   (an[1] == "q") :
                return activation[:,head,word,:]   
            elif (an[1] == "v") :
                return activation[:,head,:,word] 
        elif (word is not None and head is not None) : 
            return activation[:, word, head, :] 
        elif (word is not None) :
            return activation[:, word, :] 
        else :
            return activation 
        
    def _wh_hook(self, activation,hook,an,head,word):

        if (an[0] == "pattern" or an[0] ==  "attn_scores") :
            if   (an[1] == "q") :
                activation[:,head,word,:] = activation[:,head,word,:].mean(0)   
            elif (an[1] == "v") :
                activation[:,head,:,word] = activation[:,head,:,word].mean(0)   
        elif (word is not None and head is not None) : 
            activation[:, word, head, :] = activation[:, word, head, :].mean(0) 
        elif (word is not None) :
            activation[:, word, :] = activation[:, word, :].mean(0) 
        else :
            activation = activation.mean(0) 

        return activation

    def wh_hook(self, an=None, head=None, word = None) :
        return functools.partial(self._wh_hook, an=an, head=head, word=word )

    def substitute_centoids(self, activation_name, layer, word=None, head=None, selection=slice(None), mean=False) :
        self.model.reset_hooks()

        an  = activation_name.split("-", 1)
        if (len(an) == 2) :
            self.model.add_hook( utils.get_act_name(an[0], layer, an[1]), self.wh_hook( an,head,word)  )
        else :
            self.model.add_hook( utils.get_act_name(an[0], layer       ), self.wh_hook( an,head,word)  )

        logits, cache = model.run_with_cache(self.dataset.toks[selection,:])
        self.model.reset_hooks()
        preds = logits.softmax(dim=-1).argmax(dim=-1)

        err = t.zeros((preds.shape[0]), dtype=t.bool)
        for outcome_word in self.pred_i:
            err = err | self.errors(preds, outcome_word, selection)

        return err
    
    def errors(self, preds, outcome_word, selection=slice(None)) :
        return self.ctrl_preds[selection,outcome_word] != preds[selection,outcome_word]


In [19]:
cluster_state = ClusterState(model,dataset  )
dependency    =  cluster_state.step_through_calc( cluster_state.dependency() )



Dataset size=3000 examples
Model has 2 layers, 3 heads, 15 words, 48 dimensions


In [20]:
import pickle

def create_examples(dataset, exs) :

    def create_example(ex) :
        example = [ dataset.vocab[tok] for tok in dataset.toks[ex] ] 
        carry   = [ Pairs.is_carry( dataset.p[ex] , i) for i,p in enumerate(cluster_state.pred_i) ]
        for i, word in enumerate(cluster_state.pred_i) :
            w = word +1   
            if i >0 :
                example[w] = ( example[w] , f"carry {carry[i]}", "#5fbdff" , "border = 2x dashed black" ) 
            else :
                example[w] = ( example[w] , f"", "#5fbdff"  )
        return example  

    examples = []
    for ex in range(exs) :
        example      = create_example(ex)
        results      = cluster_state.step_through_calc( cluster_state.search_index( dataset.toks[ex], 450 ) )
        similarities = cluster_state.step_through_calc( cluster_state.find_similatities( results ) )
        losses       = cluster_state.step_through_calc( cluster_state.contrastive_loss( results, 10 ) )
        examples.append( (example, similarities, dependency, losses) )

    with open( f"visualization/examples.pkl", 'wb') as file:
        pickle.dump(examples, file)

create_examples(dataset, 1)

torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.0
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.3521224856376648
0.0 0.3521224856376648 -0.3521224856376648 -1.0
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.37131446599960327
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.4572181701660156
0.37131446599960327 0.4572181701660156 -0.08590370416641235 -0.18788339959284814
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.235495924949646
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.29147201776504517
0.235495924949646 0.29147201776504517 -0.05597609281539917 -0.1920461979321849
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.12462140619754791
torch.Size([1, 15]) torch.Size([10, 15]) torch.Size([10, 15]) 0.29955023527145386
0.12462140619754791 0.29955023527145386 -0.17492882907390594 -0.5839715963347003
torch.Size([1, 48]) torch.Size([10, 48]) torch.Size([10, 48]) 0.0
torch.Siz

: 