## Setup

In [3]:
# Load IPython's autoreload extension and switch it to mode "2" so that edits
# to imported project modules are picked up without restarting the kernel.
# NOTE(review): get_ipython() is expected to return the running shell here;
# outside an IPython kernel it presumably returns None - confirm before reusing as a script.
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os;

os.environ["ACCELERATE_ENABLE_RICH"] = "0"

os.environ['WANDB_NOTEBOOK_NAME'] = os.path.basename(globals()['__vsc_ipynb_file__']) 

# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


import sys
from functools import partial
import json
from typing import List, Tuple, Union, Optional, Callable, Dict
import torch as t
from torch import Tensor
from sklearn.linear_model import LinearRegression
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import einops
from tqdm import tqdm
from jaxtyping import Float, Int, Bool
from pathlib import Path
import pandas as pd
import circuitsvis as cv
import webbrowser
from IPython.display import display
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from eindex import eindex


from sklearn.cluster import DBSCAN, HDBSCAN, OPTICS
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

import functools
import matplotlib.pyplot as plt

# t.set_grad_enabled(False)

# Locate the exercises directory from the current working directory and make
# sure it is importable, so the project packages below can be found.
chapter = r"chapter1_transformers"
_cwd_prefix = os.getcwd().split(chapter)[0]  # part of the cwd before the chapter name
exercises_dir = Path(f"{_cwd_prefix}/{chapter}/exercises").resolve()
section_dir = exercises_dir.joinpath("monthly_algorithmic_problems", "september23_sum")
_exercises_path = str(exercises_dir)
if _exercises_path not in sys.path:
    sys.path.append(_exercises_path)

from monthly_algorithmic_problems.september23_sum.model import create_model
from monthly_algorithmic_problems.september23_sum.training import train, TrainArgs
from monthly_algorithmic_problems.september23_sum.dataset import SumDataset,Pairs
from plotly_utils import hist, bar, imshow

# Running this on a MacBook Air, and MPS is flaky, so everything is pinned to the CPU.
# The trailing comment keeps the "CUDA if available" alternative for reference.
device = t.device("cpu") #t.device("cuda" if t.cuda.is_available() else "cpu")

# faiss provides the IndexFlatL2 nearest-neighbour indexes built later in this notebook.
import faiss



# True when this code is executed as the top-level module.
MAIN = __name__ == "__main__"

## Dataset

-- Note: I rewrote the Dataset class so that it balances the no-carry, plain-carry and cascading-carry classes evenly. This puts a larger emphasis on difficult cases relative to the way this dataset was originally written. With random additions, cascading carry is an edge case that is quite infrequent.

In [4]:
# Build a dataset of 3000 examples with 4-digit operands and move it to `device`.
dataset = SumDataset(size=3000, num_digits=4).to(device)

## Transformer
-- Note: I removed weight decay

In [5]:
# Checkpoint written by the (commented-out) training call at the bottom of this cell.
filename = section_dir / "sum_model_normal.pt" # note: this was trained on a Mac in CPU mode, without CUDA or MPS

# Training hyperparameters. Only used by the commented-out `train(args)` call below;
# kept so the checkpoint can be regenerated.
args = TrainArgs(
    num_digits=4,
    trainset_size=100_000,
    valset_size=5_000,
    epochs=100,
    batch_size=512,
    lr_start=2e-3,
    lr_end=1e-4,
    # weight_decay=0.001, # no weight decay; could add this back in
    weight_decay=0.00,
    seed=42,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    d_mlp=None,
    normalization_type="LN",
    use_wandb=True,
    device=device,
)

# The architecture arguments mirror the ones in `args` so the saved weights fit.
model = create_model(
    num_digits=4,
    seed=0,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    normalization_type="LN",
    d_mlp=None,
    device=device
)
# map_location makes loading independent of the device the checkpoint was saved from.
model.load_state_dict(t.load(filename, map_location=device))

# model = train(args)
# t.save(model.state_dict(), filename)

<All keys matched successfully>

## Cluster Code



In [6]:
from torch.utils.data import Dataset
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from sklearn.metrics import silhouette_samples
from functools import partial


class ClusterState():
    """Nearest-neighbour lookup over a model's cached activations.

    On construction the model is run once on `dataset` to obtain control
    logits, an activation cache and argmax predictions. For every layer, the
    sequence positions ("words") whose mean-ablation at `resid_pre` changes at
    least one prediction at `pred_i` are recorded, and a faiss L2 index is
    built for every `(activation_name, layer, head, word)` key visited by
    `step_through_calc`.

    Activation names may carry a `-q` / `-v` suffix (e.g. "pattern-q"). For
    "pattern" / "attn_scores" the suffix selects which attention axis `word`
    indexes (see `embeddings`).

    NOTE(review): `find_similatities` reads the notebook-level `dataset` and
    `Pairs` objects rather than `self.dataset` (which is a token tensor), so
    it only works where those globals exist.
    """

    def __init__(self, model, dataset, min_samples=5, min_cluster_size=5, pred_i : Optional[List[int]] = None  ):
        """Run the control pass and build all indexes.

        Args:
            model: object exposing `cfg`, `reset_hooks`, `add_hook` and
                `run_with_cache`.
            dataset: token tensor fed to the model; `dataset.shape[0]` is the
                number of examples.
            min_samples: unused; kept so existing callers keep working.
            min_cluster_size: unused; kept so existing callers keep working.
            pred_i: sequence positions whose predictions are compared against
                the control run. Defaults to `[10, 11, 12, 13]`.
        """
        self.model   = model
        self.dataset = dataset
        self.layers  = model.cfg.n_layers
        self.heads   = model.cfg.n_heads
        self.words   = model.cfg.n_ctx
        self.ndims   = model.cfg.d_model
        self.attn_only = model.cfg.attn_only
        # A fresh list per instance instead of a shared mutable default argument.
        self.pred_i  = [10, 11, 12, 13] if pred_i is None else pred_i
        self.result = {}

        print(f"Dataset size={dataset.shape[0]} examples")
        print(f"Model has {self.layers} layers, {self.heads} heads, {self.words} words, {self.ndims} dimensions")

        # Control run: unmodified model on the full dataset.
        model.reset_hooks()
        self.ctrl_logits,self.ctrl_cache  = model.run_with_cache(dataset)
        self.ctrl_preds = self.ctrl_logits.softmax(dim=-1).argmax(dim=-1)
        model.reset_hooks()

        # filtered_words[layer] = positions that matter at the input of `layer`;
        # the final entry (index == n_layers) is the list of prediction positions.
        self.filtered_words = []
        for layer in range(self.layers) :
            self.filtered_words.append( self.filter_words(layer) )
        self.filtered_words.append( self.pred_i )

        self.index = self.step_through_calc( self.create_index )


    def filter_words(self, layer, activation_name = "resid_pre") :
        """Return the positions whose mean-ablation at `layer` changes a prediction.

        A position is kept when replacing `activation_name` at that position
        by its batch mean flips at least one prediction at `self.pred_i` for
        at least one example.
        """
        return [word for word in range(self.words) if self.substitute_centoids(activation_name, layer, word=word, mean=True).sum().item() > 0]

    def step_through_calc(self, method ) :
        """Call `method(activation_name, layer, head, word)` for every tracked key.

        For each layer this visits "pattern-q" and "result" for every head and
        "attn_out" (with `head=None`), over the positions stored in
        `self.filtered_words[layer+1]`.

        Returns:
            dict mapping each `(activation_name, layer, head, word)` tuple to
            the value `method` returned for it.
        """

        def update_index(index, activation_name, layer, head, words) :
            for word in words :
                x = (activation_name, layer, head, word)
                index[x] = method(*x)

        index = {}
        for layer in range(self.layers):
            for head in range(self.heads) :
                # update_index(index, "v", layer, head, self.filtered_words[layer] )
                update_index(index, "pattern-q", layer, head, self.filtered_words[layer+1] )
                update_index(index, "result"   , layer, head, self.filtered_words[layer+1] )
            update_index(index, "attn_out", layer, None, self.filtered_words[layer+1] )
        return index


    def create_index(self, *args) :
        """Build a faiss `IndexFlatL2` over the control-run embeddings of one key."""
        embeddings = self.embeddings(*args)
        index = faiss.IndexFlatL2(embeddings.shape[1])  # Use IndexFlatL2 for L2 distance
        index.add(embeddings.float())
        return index

    def search_index(self, search_dataset , top_k):
        """Return a lookup function for the nearest neighbours of `search_dataset`.

        The model is run once on `search_dataset`. The returned callable takes
        an `(activation_name, layer, head, word)` key and returns the squeezed
        indices of the `top_k` closest control examples in that key's index.
        """
        # Use the instance's model; the original relied on a notebook global.
        self.model.reset_hooks()
        _, cache  = self.model.run_with_cache(search_dataset)
        self.model.reset_hooks()
        def _search_index(*args):
            points = self.embeddings(*args, cache=cache)
            distances, indices = self.index[args].search( points , top_k )
            return indices.squeeze()
        return _search_index

    def pattern_words(self, pattern, threshold=.1 ) :
        """Return the column indices whose maximum over dim 0 exceeds `threshold`."""
        return t.where(  pattern.max(dim=0).values > threshold )[0].tolist()


    def all_the_same(self, items):
        """Check whether every element of `items` equals the first one.

        Returns:
            `(True, str(first_item))` when all elements are equal, otherwise
            `(False, items as a numpy array)`.
        """
        if not isinstance(items, t.Tensor):
            items = t.tensor(items)
        if items.eq(items[0]).all() :
            return True , f"{items[0].item()}"
        else :
            return False, items.detach().cpu().numpy()

    def find_similatities(self, results) :
        """Return a function describing what a key's nearest neighbours share.

        Args:
            results: dict from key tuple to neighbour indices, as produced by
                `step_through_calc(search_index(...))`.

        The returned callable takes a key and returns `(explanation, scp)`:
        `explanation` is a newline-joined list of the shared properties and
        `scp` holds one token list per neighbour, where tokens at shared
        positions are wrapped in `(token, annotation)` tuples.
        """

        def _similatities(*args) :
            close_points = results[args]

            # Token strings of every neighbouring example.
            scp  = [ [dataset.vocab[tok] for tok in dataset[point] ] for point in close_points ]

            # Positions compared across neighbours.
            # NOTE(review): presumably the digit positions of both operands and the sum - confirm.
            words = [1,2,3,4,6,7,8,9,11,12,13,14]

            explanation = []
            # For the last layer's attn_out only the token following `word` is compared.
            if ( args[0] == "attn_out" and args[1]==self.layers-1) :
                words = [args[3]+1]

            for word in words :
                print_it, result  = self.all_the_same(dataset[close_points,word])
                if (print_it) :
                    for point in range(len(scp)) :
                        scp[point][word] = ( scp[point][word] , "" )
                    explanation.append( f"{dataset.format_word(word)} = {result}" )
            for i, word in enumerate(self.pred_i) :
                w = word +1
                if i > 0 :
                    carry = [ Pairs.is_carry( dataset.p[point].item() , i  ) for point in close_points ]
                    print_it, result  = self.all_the_same(carry)
                    if (print_it) :
                        # len > 1 is used to detect a token already wrapped in a tuple above.
                        if ( len(scp[0][w]) > 1 ) :
                            for point in range(len(scp)) :
                                scp[point][w] = ( scp[point][w][0] , f"carry {result}" )
                        else:
                            for point in range(len(scp)) :
                                scp[point][w] = ( scp[point][w]    , f"carry {result}" )

                        explanation.append( f"{dataset.format_word(w)} carry {result}" )


            return ( '\n'.join(explanation) , scp )

        return  _similatities

    def dependency(self) :
        """Return a function listing the upstream keys a given key depends on.

        - "attn_out" depends on the "result" of every head at the same layer and word.
        - "result" depends on "pattern-q" of the same layer, head and word.
        - "pattern-q" in layers above 0 depends on the previous layer's
          "attn_out" at every position that is both attended to (see
          `pattern_words`) and present in `filtered_words` for this layer.
        - Anything else has no dependencies.
        """
        def _dependency(*args) :
            dependency = []
            if ( args[0] == "attn_out" ) :
                for head in range(self.heads) :
                    dependency.append( ("result", args[1], head, args[3]) )
            elif ( args[0] == "result" ) :
                dependency.append( ("pattern-q", args[1], args[2], args[3]) )
            elif ( args[0] == "pattern-q" and args[1]>0 ) :
                embed   = self.embeddings(*args)
                pwords  = self.pattern_words( embed )
                fwords  = self.filtered_words[args[1]]
                intersection = sorted(list(set(pwords) & set(fwords)))
                for i in intersection :
                    dependency.append( ("attn_out", args[1]-1, None, i) )
            return dependency
        return _dependency

    def embeddings(self, activation_name, layer, head, word, cache=None):
        """Slice one activation out of a cache.

        Args:
            activation_name: cache name, optionally suffixed with "-q" / "-v".
            layer: layer index used for the cache lookup.
            head: head index, or None.
            word: sequence position, or None.
            cache: cache to read from; defaults to the control-run cache.

        Returns:
            For "pattern" / "attn_scores": `activation[:, head, word, :]` with
            suffix "q", `activation[:, head, :, word]` with suffix "v" (None
            for any other suffix). Otherwise `activation[:, word, head, :]`,
            `activation[:, word, :]` or the whole activation, depending on
            which of `word` / `head` are given.
        """
        if (cache is None) :
            cache = self.ctrl_cache

        an  = activation_name.split("-", 1)
        if (len(an) == 2) :
            activation = cache[ an[0], layer, an[1] ]
        else :
            activation = cache[ an[0], layer]

        if (an[0] == "pattern" or an[0] ==  "attn_scores") :
            if   (an[1] == "q") :
                return activation[:,head,word,:]
            elif (an[1] == "v") :
                return activation[:,head,:,word]
        elif (word is not None and head is not None) :
            return activation[:, word, head, :]
        elif (word is not None) :
            return activation[:, word, :]
        else :
            return activation

    def _wh_hook(self, activation,hook,an,head,word):
        """Hook body: overwrite the selected slice of `activation` with its batch mean.

        The slice is chosen exactly as in `embeddings`. The activation keeps
        its shape in every branch so downstream layers still see a full batch.
        """

        if (an[0] == "pattern" or an[0] ==  "attn_scores") :
            if   (an[1] == "q") :
                activation[:,head,word,:] = activation[:,head,word,:].mean(0)
            elif (an[1] == "v") :
                activation[:,head,:,word] = activation[:,head,:,word].mean(0)
        elif (word is not None and head is not None) :
            activation[:, word, head, :] = activation[:, word, head, :].mean(0)
        elif (word is not None) :
            activation[:, word, :] = activation[:, word, :].mean(0)
        else :
            # Mean-ablate the whole activation in place. Returning the bare
            # mean would drop the batch dimension.
            activation[...] = activation.mean(0)

        return activation

    def wh_hook(self, an=None, head=None, word = None) :
        """Bind `an`, `head` and `word` into `_wh_hook`, giving an `(activation, hook)` callable."""
        return functools.partial(self._wh_hook, an=an, head=head, word=word )

    def substitute_centoids(self, activation_name, layer, word=None, head=None, selection=slice(None), mean=False) :
        """Mean-ablate one activation slice and report which examples change.

        Args:
            activation_name: activation to ablate, optionally "-q" / "-v" suffixed.
            layer: layer of the activation.
            word: sequence position to ablate, or None.
            head: head to ablate, or None.
            selection: index applied to the first dimension of the dataset.
            mean: unused; kept so existing callers keep working.

        Returns:
            Bool tensor with one entry per selected example, True when any
            prediction at `self.pred_i` differs from the control run.
        """
        self.model.reset_hooks()

        an  = activation_name.split("-", 1)
        if (len(an) == 2) :
            self.model.add_hook( utils.get_act_name(an[0], layer, an[1]), self.wh_hook( an,head,word)  )
        else :
            self.model.add_hook( utils.get_act_name(an[0], layer       ), self.wh_hook( an,head,word)  )

        # Use the instance's model; the original relied on a notebook global.
        logits, _ = self.model.run_with_cache(self.dataset[selection,:])
        self.model.reset_hooks()
        preds = logits.softmax(dim=-1).argmax(dim=-1)

        err = t.zeros((preds.shape[0]), dtype=t.bool, device=preds.device)
        for outcome_word in self.pred_i:
            # `preds` already covers only the selected rows, so `selection` is
            # applied to the control predictions alone.
            err = err | (self.ctrl_preds[selection, outcome_word] != preds[:, outcome_word])

        return err

    def errors(self, preds, outcome_word, selection=slice(None)) :
        """Compare `preds` with the control predictions at one position.

        `selection` is applied to both tensors, so `preds` must be indexed
        like the full dataset.
        """
        return self.ctrl_preds[selection,outcome_word] != preds[selection,outcome_word]


In [7]:
# Build the ClusterState on the raw token tensor (this runs the model and creates
# every faiss index), then record for each indexed key which upstream keys it depends on.
cluster_state = ClusterState(model,dataset.toks )
dependency    =  cluster_state.step_through_calc( cluster_state.dependency() )



Dataset size=3000 examples
Model has 2 layers, 3 heads, 15 words, 48 dimensions


In [21]:
import pickle

def create_examples(exs, path="visualization/examples.pkl", top_k=10) :
    """Pickle nearest-neighbour explanations for the first `exs` dataset examples.

    For each example this stores a tuple `(example, similarities, dependency)`:
    the example's tokens (with carry annotations), the per-key similarity
    descriptions of its `top_k` nearest neighbours, and the notebook-level
    `dependency` map.

    Args:
        exs: number of examples to export, taken from the start of the dataset.
        path: output pickle file; missing parent directories are created.
        top_k: number of neighbours retrieved per index.

    Relies on the notebook-level `dataset`, `Pairs`, `cluster_state` and
    `dependency` objects.
    """

    def create_example(ex) :
        # Token strings of the example; answer tokens additionally carry a carry flag.
        example = [ dataset.vocab[tok] for tok in dataset.toks[ex] ]
        # `.item()` matches how Pairs.is_carry is called elsewhere in this notebook.
        pair_id = dataset.p[ex].item()
        carry   = [ Pairs.is_carry( pair_id , i) for i in range(len(cluster_state.pred_i)) ]
        for i, word in enumerate(cluster_state.pred_i) :
            w = word +1
            if i >0 :
                example[w] = ( example[w] , f"carry {carry[i]}" )
        return example

    examples = []
    for ex in range(exs) :
        example      = create_example(ex)
        results      = cluster_state.step_through_calc( cluster_state.search_index( dataset.toks[ex], top_k ) )
        similarities = cluster_state.step_through_calc( cluster_state.find_similatities( results ) )
        examples.append( (example, similarities, dependency) )

    # Create the output directory on first use instead of failing with FileNotFoundError.
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open( out_path, 'wb') as file:
        pickle.dump(examples, file)

create_examples(20)

In [9]:


# def recursive_print( start , depth=2) : 
#     tab = " " * depth
#     print(tab + str(start) )
#     print(tab + str(similarities[start][0]) )
#     for d in dependency[start] :
#         recursive_print( d , depth=depth+2 )

# recursive_print( start[0] )

In [10]:
# Sanity check on example 0: the formatted sum, its row of Pairs.p, and the
# carry flag computed for each prediction position.
print( dataset.format(0))
print( Pairs.p[ dataset.p[0] ])
print( [ Pairs.is_carry( dataset.p[0].item() , i  ) for i,p in enumerate(cluster_state.pred_i) ] )

ST 3190 + 4577 = 7767
[0 0 2 0]
[0, 0, 1, 0]


In [11]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import torch.nn.functional as F

def logistic_regression(X,y) :
    """Fit a logistic-regression probe on (X, y) and print held-out metrics.

    The data is split 80/20 with a fixed seed, features are standardised using
    statistics from the training part only, and the accuracy and confusion
    matrix on the test part are printed. Nothing is returned.
    """
    train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2, random_state=42)

    # Standardise to zero mean / unit variance; the scaler never sees the test split.
    standardizer = StandardScaler().fit(train_X)
    train_X = standardizer.transform(train_X)
    test_X = standardizer.transform(test_X)

    classifier = LogisticRegression(random_state=42, max_iter=1000).fit(train_X, train_y)
    predicted = classifier.predict(test_X)

    print(f'Accuracy: {accuracy_score(test_y, predicted):.2f}')

    print('Confusion Matrix:')
    print(confusion_matrix(test_y, predicted))

def linear_regression(X, y):
    """Fit a linear-regression probe on (X, y), print held-out metrics, return the model.

    The data is split 80/20 with a fixed seed and features are standardised
    using statistics from the training part only. The mean squared error and
    R-squared on the test part are printed.

    Returns:
        The fitted `LinearRegression` instance.
    """
    train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2, random_state=42)

    # Standardise to zero mean / unit variance; the scaler never sees the test split.
    standardizer = StandardScaler().fit(train_X)
    train_X = standardizer.transform(train_X)
    test_X = standardizer.transform(test_X)

    lin_reg = LinearRegression().fit(train_X, train_y)
    predicted = lin_reg.predict(test_X)

    mse = mean_squared_error(test_y, predicted)
    print(f'Mean Squared Error: {mse:.2f}')

    r2 = r2_score(test_y, predicted)
    print(f'R-squared: {r2:.2f}')

    return lin_reg

import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


def pca_the_points(points, n_components=10, graph=True,labels=None, xi=1, yi=2):
    """Fit a PCA to `points`, print the explained variance and optionally plot it.

    Args:
        points: 2-D torch tensor with one row per sample.
        n_components: requested number of components; capped at
            min(n_samples, n_features), the most PCA can extract.
        graph: when True, show (1) the explained-variance ratio per component,
            (2) the per-dimension weights of the first two components and
            (3) a scatter of the projections onto components `xi` / `yi`.
        labels: optional annotations, one per plotted point; the row index is
            used when omitted.
        xi, yi: zero-based component indices used for the scatter axes.

    Returns:
        `(comp1, comp2)`: the first two rows of `pca.components_`. `comp2` is
        None when only a single component could be fitted.
    """
    n_samples, n_features = points.size(0), points.size(1)
    n_components = min(n_components, n_samples, n_features)

    pca = PCA(n_components=n_components)
    data = points.detach().cpu().numpy()
    pca_result = pca.fit_transform(data)

    # Get the two components
    comp1 = pca.components_[0]
    comp2 = pca.components_[1] if n_components > 1 else None

    # Print the explained variance by each component
    print("Explained variance by component:", np.round(pca.explained_variance_ratio_,2)*100)

    if graph:
        # Plotting the explained variance
        plt.figure(figsize=(8, 5))
        plt.bar(range(n_components), pca.explained_variance_ratio_, align='center')
        plt.xlabel("Principal Component")
        plt.ylabel("Explained Variance Ratio")
        plt.title("Explained Variance by Principal Component")
        plt.show()

        if comp2 is not None:
            # Side-by-side bars of the weights each input dimension gets in PC 1 and PC 2.
            width = 0.4
            indices = np.arange(n_features)  # one bar pair per input dimension

            plt.figure(figsize=(12, 6))
            plt.bar(indices, comp1, width, color='b', label='Principal Component 1')
            plt.bar(indices + width, comp2, width, color='r', label='Principal Component 2')
            plt.xlabel("Dimensions")
            plt.ylabel("Weight")
            plt.title("Weights of Dimensions for First Two Principal Components")
            plt.xticks(indices + width / 2, indices)  # X-axis labels (centered)
            plt.legend()

            plt.tight_layout()
            plt.show()

        if max(xi, yi) < n_components:
            n = n_samples  # number of points to plot

            plt.figure(figsize=(8, 8))
            plt.scatter(pca_result[:n, xi], pca_result[:n, yi], s=50, c='blue', edgecolors='k', marker='o', alpha=0.7)

            if labels is  None :
                for i in range(n) :
                    plt.annotate(i, (pca_result[i, xi], pca_result[i, yi]), fontsize=18, ha="right")
            else :
                for i,label in enumerate(labels) :
                    plt.annotate( label , (pca_result[i, xi], pca_result[i, yi]), fontsize=18, ha="right")

            # Setting labels and title
            plt.xlabel(f"Principal Component {xi}")
            plt.ylabel(f"Principal Component {yi}")
            # plt.xlim(-.7, .7)  # Set the x-axis limits
            # plt.ylim(-.7, .7)  # Set the y-axis limits
            plt.title(f"Projection of Vectors on PC {xi} and PC {yi}")
            plt.grid(True)

            plt.show()
        else:
            # Too few components were fitted to draw the requested projection.
            print(f"Skipping scatter plot: components {xi} and {yi} need more than {n_components} fitted component(s)")

    return comp1, comp2

def show_gram_matrix(y, y_pred) :
    """Render a heatmap of the products between normalised `y` and `y_pred` rows.

    Both inputs are L2-normalised along dim 0, then multiplied as
    `y_norm @ y_pred_norm.T`; the result is shown as a static PNG heatmap on a
    white-to-green colour scale clipped to [-1, 1].

    NOTE(review): normalising along dim 0 scales columns, not rows; if the
    matrix is meant to hold row-wise cosine similarities, dim=1 may have been
    intended - confirm.
    """
    unit_actual    = F.normalize(y, p=2, dim=0)
    unit_predicted = F.normalize(y_pred, p=2, dim=0)
    gram_matrix = t.mm(unit_actual, unit_predicted.t())

    # Values at or below the midpoint stay white; only the upper half shades to green.
    colorscale = [[0, 'white'], [.5, 'white'], [1.0, 'green']]
    axis_ticks = list(range(0, gram_matrix.size(-1)))

    heatmap = go.Heatmap(
        z=gram_matrix.detach().cpu().numpy(),
        colorscale=colorscale,
        zmin=-1,
        zmax=1,
        x=axis_ticks,
        y=list(axis_ticks),
    )
    fig = go.Figure(data=heatmap)

    layout = dict(
        title='Gram Matrix Predicted vs Actual (correlation)',
        xaxis_title='Actual Embedding position p-1',
        yaxis_title='Predicted Embedding at position p-1',
        yaxis_autorange='reversed',
        autosize=False,
        height=500,
        width=500,
    )
    fig.update_layout(**layout)

    print("predicted embedding correlates highest with actual embedding")
    fig.show("png")



In [12]:
# Run the unmodified model on every example, keeping the activation cache used
# by the probing cells below; hooks are cleared before and after the run.
model.reset_hooks()
logits,cache  = model.run_with_cache(dataset.toks)
# Predicted token at each position.
preds =logits.softmax(dim=-1).argmax(dim=-1)
model.reset_hooks()


In [13]:
# Explicit %pip magic (instead of relying on automagic) so the install targets
# the running kernel's environment.
%pip install faiss-cpu

Note: you may need to restart the kernel to use updated packages.


In [14]:
# Probe: can the layer-1 query vectors at one position predict the carry class?
layer = 1
# Position 4 before the last token of the sequence.
digit = len(dataset.toks[0]) - 1 - 4
output_digit = digit+1
print(f"Predicting digit {digit}")

y_orig = dataset.p.detach().cpu().numpy()
# Collapse every row of Pairs.p into four classes based on entries 1..3:
#   0: p[1]==2;  1: p[1]==1 and p[2]==2;  2: p[1]==1, p[2]==1 and p[3]==2;  3: anything else.
# NOTE(review): print_points labels the first three cases "carry", so presumably
# 2 marks a carry and 1 a digit that passes a carry on - confirm against Pairs.
remap = {}
for i in range(Pairs.p.shape[0]) :
    p = Pairs.p[i]
    if   p[1] == 2 :
        remap[i] = 0
    elif p[1] == 1 and p[2]==2 :
        remap[i] = 1
    elif p[1] == 1 and p[2]==1 and p[3]==2:
        remap[i] = 2
    else :
        remap[i] = 3

# Sequence position whose queries are probed (hard-coded rather than derived from `digit`).
seq   = 10
# Class label per example.
y  = np.array([ remap[y_orig[i]] for i in range(y_orig.shape[0]) ])
# Query activations at position `seq`, summed over dim 1 of the slice
# (presumably the head axis - confirm).
X = cache["q",layer][:,seq,:,:].sum(1,keepdim=False)
print(f"at word {seq}, Logistic Regression predicts carrying the {digit+1} digit")
logistic_regression(X.detach().cpu().numpy(),y)


Predicting digit 10
at word 10, Logistic Regression predicts carrying the 11 digit
Accuracy: 0.98
Confusion Matrix:
[[217   0   0   0]
 [  0  67   1   1]
 [  0   0   9   7]
 [  0   0   0 298]]


In [15]:


def create_index(embeddings):
    """Return a faiss `IndexFlatL2` holding `embeddings` (one vector per row)."""
    dim = embeddings.shape[1]
    flat_l2 = faiss.IndexFlatL2(dim)  # flat index using L2 distance
    flat_l2.add(embeddings.float())
    return flat_l2

def search_index(index, points, k):
    """Return the squeezed indices of the `k + 1` nearest entries to `points`.

    One extra neighbour is requested and all of them are returned.
    """
    # index.search returns (distances, indices); only the indices are needed.
    neighbour_ids = index.search(points, k + 1)[1]
    return neighbour_ids.squeeze()

def print_points( indices,print_carry=True, index=0  ) :
    """Print the dataset examples at `indices` under a "Set {index}:" header.

    With `print_carry` each example is printed in full, prefixed by a
    "[  carry ]" / "[no carry]" tag derived from its row of `Pairs.p`;
    otherwise only characters 5..8 of the formatted example are printed.
    """
    print( f"Set {index}:" )
    for example_id in indices :
        formatted = dataset.format(example_id)

        if not print_carry :
            print( formatted[5:9] )
            continue

        p = Pairs.p[dataset.p[example_id].item()]
        # Carry when p[1] is 2, or when it is reached through a run of 1s starting at p[1].
        carries = p[1] == 2 or (p[1]==1 and (p[2]==2 or (p[2]==1 and p[3]==2)))
        tag = "        [  carry ]" if carries else "        [no carry]"
        print( tag , formatted )



# Index the query vectors of layer 0, head 0 at sequence position 6, one vector per example.
layer = 0
head = 0
activation = "q"
word = 6
X = cache[activation,layer][:,word,head,:]#.sum(1,keepdim=False)
index = create_index(X)

In [16]:
# For the first five examples, list their nearest neighbours in the index:
# first pass with carry tags and full sums, second pass with the short form only.
for show_carry in (True, False) :
    for i in range(5) :
        indices = search_index(index, X[i].unsqueeze(0), 10)
        print_points(indices, print_carry=show_carry, index=i)
        print("")



Set 0:
        [no carry] ST 3190 + 4577 = 7767
        [no carry] ST 3191 + 4599 = 7790
        [no carry] ST 4484 + 4465 = 8949
        [no carry] ST 4510 + 4142 = 8652
        [no carry] ST 0315 + 4143 = 4458
        [  carry ] ST 2729 + 4986 = 7715
        [no carry] ST 4613 + 4361 = 8974
        [  carry ] ST 4748 + 4641 = 9389
        [  carry ] ST 2789 + 4332 = 7121
        [no carry] ST 4140 + 4755 = 8895
        [  carry ] ST 2795 + 4604 = 7399

Set 1:
        [  carry ] ST 5625 + 3755 = 9380
        [  carry ] ST 0854 + 3188 = 4042
        [no carry] ST 2191 + 3698 = 5889
        [no carry] ST 5963 + 3013 = 8976
        [no carry] ST 1245 + 3713 = 4958
        [  carry ] ST 5936 + 3963 = 9899
        [  carry ] ST 5813 + 3692 = 9505
        [  carry ] ST 1925 + 3593 = 5518
        [no carry] ST 4050 + 3549 = 7599
        [no carry] ST 2377 + 3622 = 5999
        [no carry] ST 3402 + 3241 = 6643

Set 2:
        [no carry] ST 2457 + 6502 = 8959
        [no carry] ST 0286 + 6313 

In [17]:
my_list = ["apple", "banana", "cherry"]
pos = 2
my_tuple = ('X', 'Y')

def replace_with_tuple(s, pos, t):
    """Return `s` with the character at `pos` replaced by the joined items of `t`."""
    before, after = s[:pos], s[pos + 1:]
    return "".join([before, "".join(t), after])

# Apply the replacement to every string in the list.
new_list = list(map(lambda item: replace_with_tuple(item, pos, my_tuple), my_list))

print(new_list)


['apXYle', 'baXYana', 'chXYrry']
