## Setup

In [1]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os;

os.environ["ACCELERATE_ENABLE_RICH"] = "0"

os.environ['WANDB_NOTEBOOK_NAME'] = os.path.basename(globals()['__vsc_ipynb_file__']) 

# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


import sys
from functools import partial
import json
from typing import List, Tuple, Union, Optional, Callable, Dict
import torch as t
from torch import Tensor
from sklearn.linear_model import LinearRegression
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import einops
from tqdm import tqdm
from jaxtyping import Float, Int, Bool
from pathlib import Path
import pandas as pd
import circuitsvis as cv
import webbrowser
from IPython.display import display
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from eindex import eindex


from sklearn.cluster import DBSCAN, HDBSCAN, OPTICS
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

import functools
import matplotlib.pyplot as plt

# t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "september23_sum"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.september23_sum.model import create_model
from monthly_algorithmic_problems.september23_sum.training import train, TrainArgs
from monthly_algorithmic_problems.september23_sum.dataset import SumDataset,Pairs
from plotly_utils import hist, bar, imshow

# Running this on a macbook air and mps is flaky
device = t.device("cpu") #t.device("cuda" if t.cuda.is_available() else "cpu")


MAIN = __name__ == "__main__"

## Dataset

-- Note: I rewrote the Dataset class so that it balances no carry, plain carry and cascading carry  classes evenly. this puts a larger emphasis difficult cases relative to the original way this dataset was written. With random additions, cascading carry is an edge case that is quite infrequent.  

In [2]:
dataset = SumDataset(size=1000, num_digits=4).to(device)

## Transformer
-- Note: I removed weight decay

In [3]:
filename = section_dir / "sum_model_normal.pt" # note this was trained on a mac in cpu mode without cuda and mps
args = TrainArgs(
    num_digits=4,
    trainset_size=100_000,
    valset_size=5_000,
    epochs=100,
    batch_size=512,
    lr_start=2e-3,
    lr_end=1e-4,
    # weight_decay=0.001, # not weight decay, could add this back in 
    weight_decay=0.00,
    seed=42,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    d_mlp=None,
    normalization_type="LN",
    use_wandb=True,
    device=device,
)
model = create_model(
    num_digits=4,
    seed=0,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=3,
    normalization_type="LN",
    d_mlp=None,
    device=device
)
model.load_state_dict(t.load(filename))

# model = train(args)
# t.save(model.state_dict(), filename)

<All keys matched successfully>

## Cluster Code



In [4]:
from torch.utils.data import Dataset
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from sklearn.metrics import silhouette_samples


class ClusterState():
    def __init__(self, model, dataset, min_samples=5, min_cluster_size=5, pred_i : List[int] = [10,11,12,13]  ):
        self.model   = model
        self.dataset = dataset
        self.layers  = model.cfg.n_layers
        self.heads   = model.cfg.n_heads
        self.words   = model.cfg.n_ctx
        self.ndims   = model.cfg.d_model
        self.attn_only = model.cfg.attn_only
        self.pred_i  = pred_i
        self.min_samples      = min_samples
        self.min_cluster_size = min_cluster_size
        self.result = {}

        print(f"Dataset size={dataset.shape[0]} examples")
        print(f"Model has {self.layers} layers, {self.heads} heads, {self.words} words, {self.ndims} dimensions")

        model.reset_hooks()
        self.ctrl_logits,self.ctrl_cache  = model.run_with_cache(dataset)
        self.ctrl_preds = self.ctrl_logits.softmax(dim=-1).argmax(dim=-1)
        model.reset_hooks()

        self.filtered_words = []
        for layer in range(self.layers) :
            self.filtered_words.append( self.filter_words(layer) )
        self.filtered_words.append( self.pred_i )

        previous_steps = None
        for layer in range(self.layers):
            print(f"layer={layer} ")

            for head in range(self.heads) :
                print(f"  head={head}")

                activation_name, previous_steps, words  = "v",   [["attn_out" , layer-1]] if layer>0 else None, self.filtered_words[layer]
                print(f"    activation_name={activation_name} words={words}")
                for word in words :
                    self.calc(activation_name, layer, head, word, previous_steps)    
                
                activation_name, previous_steps, words =   "pattern-q", [ ["v" , layer] ], self.filtered_words[layer+1] 
                print(f"    activation_name={activation_name} words={words}")
                for word in words :
                    self.calc(activation_name, layer, head, word, previous_steps) 

            activation_name, previous_steps, words =   "attn_out", [ ["pattern-q" , layer] ], self.filtered_words[layer+1] 
            print( f"  activation_name={'attn_out'} words={words} ")
            for word in words :
                self.calc(activation_name, layer, None, word, previous_steps) 


    def optimize_clustering(self, cluster_function) :

        min_err_c = len(dataset) + 1
        for min_samples in [2,3,5,7]:
            for min_cluster_size in [5,10,15,20] :
                self.min_samples      = min_samples
                self.min_cluster_size = min_cluster_size
                err, err_c, cluster_labels = cluster_function()
                if (err_c < min_err_c) :
                    min_err_c  = err_c
                    min_cluster_labels   = cluster_labels
                    min_err  = err
                if (err_c == 0 ) :
                    break
        return min_err, min_err_c, min_cluster_labels

                    
    def calc(self, activation_name, layer, head, word, previous_steps) :

        def cluster_function() :
            self.result[ (activation_name,layer, head, word) ] = None
            err    = self.substitute_centoids(activation_name, layer, word=word ,head=head , previous_steps=previous_steps )
            errs_c = err.sum().item()
            cluster_labels = self.result[ (activation_name,layer, head, word) ]
            return err, errs_c, cluster_labels

        err, errs_c, cluster_labels = self.optimize_clustering(cluster_function)
        n_clusters = self.calc_n_clusters( cluster_labels )

        if (errs_c > 0) :

            patch_cluster_labels = t.arange(len(dataset)).int()

            cluster_labels[ err ]   = -2
            patch_cluster_labels[ ~err ]  = -2
            patch_cluster_labels = self.combine_cluster_labels([cluster_labels, patch_cluster_labels])

            n_clusters_matching = self.calc_n_clusters(patch_cluster_labels) 

            self.result[ (activation_name, layer, head, word) ] = patch_cluster_labels
            err   = self.substitute_centoids(activation_name, layer, word=word ,head=head ,previous_steps=previous_steps )
            errs_m = err.sum().item()

            print(f"      word={word} n_c={n_clusters} err_c={errs_c}, n_m={n_clusters_matching} err_m={errs_m}" )
        else :
            print(f"      word={word} n_c={n_clusters} err_c={errs_c}" )

    def calc_n_clusters(self, cluster_labels) :
        return  cluster_labels.unique().shape[0]

    def combine_cluster_labels(self, cluster_labels_list):
        return t.stack(cluster_labels_list, dim=1).unique(dim=0, return_inverse=True)[1]

    def pattern_words(self, pattern, threshold=.05 ) :
        return t.where(  pattern.max(dim=0).values > threshold )[0].tolist()

    def cluster_data(self, data, word) :
        return self.cluster_values( word, data=data[0], activation_name=data[1], layer=data[2], head=data[3] )

    def cluster_values(self, word, data=None, activation_name=None , layer=None, head=None, selection=slice(None) ) :
        data = data[selection,word,:]
        cluster_labels, silhouette_values, n_clusters, silhouette_avg  = self.find_clusters(data)
        return cluster_labels, silhouette_values, n_clusters, silhouette_avg, self.substitute_centoids(layer, activation_name, word=word, head=head,selection=selection)

    def filter_words(self, layer, activation_name = "resid_pre") :
        return [word for word in range(self.words) if self.substitute_centoids(activation_name, layer, word=word, mean=True).sum().item() > 0]

    def find_clusters(self, data):
        data = data.detach().numpy()
        dbscan = HDBSCAN(min_samples=self.min_samples, min_cluster_size=self.min_cluster_size)
        cluster_labels = dbscan.fit_predict(data)
        n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
        
        if n_clusters > 1:
            silhouette_avg    = silhouette_score(data, cluster_labels)
            silhouette_values = silhouette_samples(data, cluster_labels)
        else:
            silhouette_avg    = 0
            silhouette_values = np.zeros(data.shape[0])

        cluster_labels = t.from_numpy(cluster_labels).to(device)

        return cluster_labels, silhouette_values, n_clusters, silhouette_avg
    
    def calc_centroid(self, data, mean=False,cluster_labels=None) :
        if mean is True :
            return data.mean(0), t.zeros(data.shape[0])
        
        if cluster_labels is None :
            cluster_labels, silhouette_values, n_clusters, silhouette_avg = self.find_clusters(data)
            # print( f"    n_clusters={n_clusters} silhouette_avg={silhouette_avg:.2f}")

        unique_labels, inverse_indices = cluster_labels.unique(return_inverse=True)
        centroids = t.vstack([data[cluster_labels == i].mean(0) for i in unique_labels])

        replacement_value = centroids[inverse_indices]
        # replacement_value[ cluster_labels < 0] = data[ cluster_labels < 0]
        
        return replacement_value, cluster_labels


    def replace_hook(self,replace=None) :
        _replace_hook = lambda activation,hook,replace : replace
        return functools.partial(_replace_hook, replace=replace)

    def _wh_hook(self, activation,hook,an,layer,head,word,mean):

        a = an[0] if len(an) == 1 else an[0] + "-" + an[1]

        cl = self.result.get( (a , layer, head, word), None) 

        if (an[0] == "pattern" or an[0] ==  "attn_scores") :
            if   (an[1] == "q") :
                activation[:,head,word,:], cl = self.calc_centroid(activation[:,head,word,:],mean=mean,cluster_labels=cl)     
            elif (an[1] == "v") :
                activation[:,head,:,word], cl = self.calc_centroid(activation[:,head,:,word],mean=mean,cluster_labels=cl)   
        elif (word is not None and head is not None) : 
            activation[:, word, head, :], cl  = self.calc_centroid(activation[:, word, head, :],mean=mean,cluster_labels=cl) 
        elif (word is not None) :
            activation[:, word, :], cl        = self.calc_centroid(activation[:, word, :],mean=mean,cluster_labels=cl) 
        else :
            activation, cl = self.calc_centroid(activation,cluster_labels=cl) 

        self.result[ (a , layer, head, word) ] = cl

        return activation
    
    def save_values(self,ctrl_cache, cache, word,head,an,selection=slice(None)):
        if (an[0] == "pattern" or an[0] ==  "attn_scores") :
            if   (an[1] == "q") :
                ctrl_cache[selection,head,word,:] =cache[selection,head,word,:] 
            elif (an[1] == "v") :
                ctrl_cache[selection,head,:,word] = cache[selection,head,:,word] 
        elif (word is not None and head is not None) : 
            ctrl_cache[ selection, word, head, :]=cache[ selection, word, head, :]
        elif (word is not None) :
            ctrl_cache[ selection, word, :]=cache[ selection, word, :]  
        else :
            ctrl_cache[ selection ]=cache[ selection ]

    def wh_hook(self, an=None, layer=None, head=None, word = None, mean=False) :
        return functools.partial(self._wh_hook, an=an, layer=layer, head=head, word=word, mean=mean)

    def substitute_centoids(self, activation_name, layer, word=None, head=None, selection=slice(None), previous_steps=None, mean=False, cluster=True, save=True) :
        self.model.reset_hooks()

        previous_steps = None
        if (previous_steps is not None and len(previous_steps)!=0) :
            for step in previous_steps :      
                pn = step[0].split("-", 1)
                previous_layer = step[1]
                if (len(pn) == 2) :
                    self.model.add_hook( utils.get_act_name(pn[0], previous_layer, pn[1]), self.replace_hook( self.ctrl_cache[ pn[0], previous_layer, pn[1] ] )  )
                else :
                    self.model.add_hook( utils.get_act_name(pn[0], previous_layer       ), self.replace_hook( self.ctrl_cache[ pn[0], previous_layer       ] )  )

        an  = activation_name.split("-", 1)
        if (cluster) :
            if (len(an) == 2) :
                self.model.add_hook( utils.get_act_name(an[0], layer, an[1]), self.wh_hook( an,layer,head,word,mean)  )
            else :
                self.model.add_hook( utils.get_act_name(an[0], layer       ), self.wh_hook( an,layer,head,word,mean)  )

        logits, cache = model.run_with_cache(self.dataset[selection,:])
        self.model.reset_hooks()
        preds = logits.softmax(dim=-1).argmax(dim=-1)

        err   = self.cluster_errors(preds, self.result[ (activation_name,layer, head, word) ] , selection=selection)

        if (save and err.sum().item() == 0) :
            if (len(an) == 2) :
                self.save_values( self.ctrl_cache[ an[0], layer, an[1] ], cache[ an[0], layer, an[1] ] , word, head, an)
            else :
                self.save_values( self.ctrl_cache[ an[0], layer], cache[ an[0], layer ] , word, head, an)

        return err
    
    def errors(self, preds, outcome_word, selection=slice(None)) :
        return self.ctrl_preds[selection,outcome_word] != preds[selection,outcome_word]
    
    def cluster_errors(self, preds, cluster_labels=None, selection=slice(None)):

        err = t.zeros((preds.shape[0]), dtype=t.bool)
        for outcome_word in self.pred_i:
            err = err | self.errors(preds, outcome_word, selection)

        for cluster in cluster_labels.unique():
            cluster_members = (cluster_labels == cluster)
            if err[cluster_members].any():
                err[cluster_members] = True

        return err



In [5]:
cluster_state = ClusterState(model,dataset.toks,min_samples=4,min_cluster_size=10 ) #4,10

Dataset size=1000 examples
Model has 2 layers, 3 heads, 15 words, 48 dimensions
layer=0 
  head=0
    activation_name=v words=[1, 2, 3, 4, 6, 7, 8, 9]
      word=1 n_c=9 err_c=0
      word=2 n_c=10 err_c=0
      word=3 n_c=10 err_c=0
      word=4 n_c=10 err_c=0
      word=6 n_c=9 err_c=0
      word=7 n_c=10 err_c=0
      word=8 n_c=10 err_c=0
      word=9 n_c=10 err_c=0
    activation_name=pattern-q words=[1, 5, 6, 7, 9, 10, 11, 12, 13]
      word=1 n_c=9 err_c=0
      word=5 n_c=100 err_c=0
      word=6 n_c=45 err_c=0
      word=7 n_c=95 err_c=19, n_m=113 err_m=0
      word=9 n_c=79 err_c=27, n_m=105 err_m=0
      word=10 n_c=24 err_c=35, n_m=58 err_m=0
      word=11 n_c=19 err_c=27, n_m=45 err_m=0
      word=12 n_c=53 err_c=112, n_m=164 err_m=0
      word=13 n_c=64 err_c=0
  head=1
    activation_name=v words=[1, 2, 3, 4, 6, 7, 8, 9]
      word=1 n_c=9 err_c=0
      word=2 n_c=10 err_c=0
      word=3 n_c=10 err_c=0
      word=4 n_c=10 err_c=0
      word=6 n_c=9 err_c=0
      word=7 n

In [6]:
cluster_state = ClusterState(model,dataset.toks,min_samples=4,min_cluster_size=10 )

Dataset size=1000 examples
Model has 2 layers, 3 heads, 15 words, 48 dimensions
layer=0 
  head=0
    activation_name=v words=[1, 2, 3, 4, 6, 7, 8, 9]
      word=1 n_c=9 err_c=0
      word=2 n_c=10 err_c=0
      word=3 n_c=10 err_c=0
      word=4 n_c=10 err_c=0
      word=6 n_c=9 err_c=0
      word=7 n_c=10 err_c=0
      word=8 n_c=10 err_c=0
      word=9 n_c=10 err_c=0
    activation_name=pattern-q words=[1, 5, 6, 7, 9, 10, 11, 12, 13]
      word=1 n_c=9 err_c=0
      word=5 n_c=100 err_c=0
      word=6 n_c=45 err_c=0
      word=7 n_c=95 err_c=19, n_m=113 err_m=0
      word=9 n_c=79 err_c=27, n_m=105 err_m=0
      word=10 n_c=24 err_c=35, n_m=58 err_m=0
      word=11 n_c=19 err_c=27, n_m=45 err_m=0
      word=12 n_c=53 err_c=112, n_m=164 err_m=0
      word=13 n_c=64 err_c=0
  head=1
    activation_name=v words=[1, 2, 3, 4, 6, 7, 8, 9]
      word=1 n_c=9 err_c=0
      word=2 n_c=10 err_c=0
      word=3 n_c=10 err_c=0
      word=4 n_c=10 err_c=0
      word=6 n_c=9 err_c=0
      word=7 n

In [7]:



# activation_name="v"
# word=9 
# layer=1 
# head=2 


# def test(word, data, activation_name, layer, head) :
#     cluster_labels, silhouette_values, n_clusters, silhouette_avg, err = cluster_state.cluster_values( word, data=data, activation_name=activation_name, layer=layer, head=head )
#     errs  =  err.sum().item()/err.shape[0]
#     print( f"word={word} layer={layer}  head={head} clusters={n_clusters} s_avg={silhouette_avg:.2f} err={errs:.0%}")# unc={unc:.0%}")
#     total_err = err

#     # max_count = 3
#     # i = 0

#     # while ( i < max_count and total_err.sum().item() > 2 ) :

#     #     cluster_labels, silhouette_values, n_clusters, silhouette_avg, err = cluster_state.cluster_values( word, data=data, activation_name=activation_name, layer=layer, head=head, selection=total_err )
#     #     errs  =  err.sum().item()/err.shape[0]
#     #     print( f"word={word} layer={layer}  head={head} clusters={n_clusters} s_avg={silhouette_avg:.2f} err={errs:.0%}  err_count={err.sum().item()}")
#     #     total_err[ total_err==True ] = err

#     #     i=i+1

# words = [1,5,6,7,9,10,11,12,13]

# v = cluster_state.ctrl_cache[activation_name,layer]#[:,:,head,:]
# v_out = einops.einsum(
#     v, model.W_O[layer],
#     "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q nheads d_model",
# ) 
# # data = model.unembed(model.ln_final(v_out[:,:,head,:]))

# for word in words :

# #     # print(activation_name, v.shape)
# #     # test(word, v[:,:,head,:], activation_name, layer, head)
#     # print(activation_name+"_out", v_out.shape)
#     test(word, v[:,:,head,:].detach(), activation_name, layer, head)
#     test(word, v_out[:,:,head,:].detach(), activation_name, layer, head)

# # for word in words :
# #     print(activation_name+"_data", v_out.shape)
# #     test(word, data.detach(), activation_name, layer, head)







In [8]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import torch.nn.functional as F

def logistic_regression(X,y) :
    # Load the Iris dataset
    # iris = load_iris()
    # X = iris.data  # Features: sepal length, sepal width, petal length, petal width
    # print(X.shape)
    # y = iris.target  # Target: species of Iris (setosa, versicolor, virginica)
    # print(y.shape)
    # print(y)

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # It's a good practice to standardize the data (mean=0 and variance=1)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Create the logistic regression model
    log_reg = LogisticRegression(random_state=42, max_iter=1000)

    # Fit the model to the training data
    log_reg.fit(X_train, y_train)

    # Make predictions on the testing data
    y_pred = log_reg.predict(X_test)

    # Output the accuracy of the model
    print(f'Accuracy: {accuracy_score(y_test, y_pred):.2f}')

    # Output the confusion matrix
    print('Confusion Matrix:')
    print(confusion_matrix(y_test, y_pred))

def linear_regression(X, y):
    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # It's a good practice to standardize the data (mean=0 and variance=1)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Create the linear regression model
    lin_reg = LinearRegression()

    # Fit the model to the training data
    lin_reg.fit(X_train, y_train)

    # Make predictions on the testing data
    y_pred = lin_reg.predict(X_test)

    # Output the Mean Squared Error of the model
    mse = mean_squared_error(y_test, y_pred)
    print(f'Mean Squared Error: {mse:.2f}')

    # Output the R-squared score of the model
    r2 = r2_score(y_test, y_pred)
    print(f'R-squared: {r2:.2f}')

    return lin_reg  # Optional: return the model for further use

import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


def pca_the_points(points, n_components=10, graph=True,labels=None):

    n_components=min(n_components,points.size(0))

    pca = PCA(n_components=n_components)
    data = points.detach().cpu().numpy()
    pca_result = pca.fit_transform(data)

    # Get the two components
    comp1 = pca.components_[0]
    comp2 = pca.components_[1]

    # Print the explained variance by each component
    print("Explained variance by component:", np.round(pca.explained_variance_ratio_,2)*100)

    if graph:
        # Plotting the explained variance
        plt.figure(figsize=(8, 5))
        plt.bar(range(n_components), pca.explained_variance_ratio_, align='center')
        plt.xlabel("Principal Component")
        plt.ylabel("Explained Variance Ratio")
        plt.title("Explained Variance by Principal Component")
        plt.show()

     # Create a bar chart
        width = 0.4
        indices = np.arange(48)

        plt.figure(figsize=(12, 6))
        bar1 = plt.bar(indices, pca.components_[0], width, color='b', label='Principal Component 1')
        bar2 = plt.bar(indices + width, pca.components_[1], width, color='r', label='Principal Component 2')
        plt.xlabel("Dimensions")
        plt.ylabel("Weight")
        plt.title("Weights of Dimensions for First Two Principal Components")
        plt.xticks(indices + width / 2, indices)  # X-axis labels (centered)
        plt.legend()

        plt.tight_layout()
        plt.show()


        n=points.size(0) # n=pca_result.shape[0] number of points to plot

        xi = 1
        yi = 2


        plt.figure(figsize=(8, 8))
        plt.scatter(pca_result[:n, xi], pca_result[:n, yi], s=50, c='blue', edgecolors='k', marker='o', alpha=0.7)

        if labels is  None :
            for i in range(n) :
                plt.annotate(i, (pca_result[i, xi], pca_result[i, yi]), fontsize=18, ha="right")
        else :
            for i,label in enumerate(labels) :
                plt.annotate( label , (pca_result[i, xi], pca_result[i, yi]), fontsize=18, ha="right")

        # Setting labels and title
        plt.xlabel(f"Principal Component {xi}")
        plt.ylabel(f"Principal Component {yi}")
        # plt.xlim(-.7, .7)  # Set the x-axis limits 
        # plt.ylim(-.7, .7)  # Set the y-axis limits
        plt.title(f"Projection of Vectors on PC {xi} and PC {yi}")
        plt.grid(True)

        plt.show()

    return comp1, comp2

def show_gram_matrix(y, y_pred) :


    y_l2_normalized      = F.normalize(y, p=2, dim=0)
    y_pred_l2_normalized = F.normalize(y_pred, p=2, dim=0)

    gram_matrix = t.mm( y_l2_normalized, y_pred_l2_normalized.t())

    colorscale = [[0, 'white'], [.5, 'white'], [1.0, 'green']]

    fig = go.Figure(data=go.Heatmap(
        z=gram_matrix.detach().cpu().numpy(),
        colorscale=colorscale,
        zmin=-1,
        zmax=1,
        x=list(range(0, gram_matrix.size(-1))),  
        y=list(range(0, gram_matrix.size(-1))),
    ))

    fig.update_layout(
        title='Gram Matrix Predicted vs Actual (correlation)',
        xaxis_title='Actual Embedding position p-1',
        yaxis_title='Predicted Embedding at position p-1',
        yaxis_autorange='reversed', 
        autosize=False,
        height=500,
        width=500,
    )
    print("predicted embedding correlates highest with actual embedding")
    fig.show("png")



In [9]:
model.reset_hooks()
logits,cache  = model.run_with_cache(dataset.toks)
preds =logits.softmax(dim=-1).argmax(dim=-1)
model.reset_hooks()


In [10]:


layer = 0
digit = -2
output_digit = digit+1
print(f"Predicting digit {digit}")

remap = {}
for i in range(Pairs.p.shape[0]) :
    p = Pairs.p[i]
    remap[i] = p[3]

# for head in range(3) :

seq1 = 9
seq2 = -2
y_orig = dataset.p.detach().cpu().numpy()
yy  = np.array([ remap[y_orig[i]] for i in range(y_orig.shape[0]) ])
# yy[ yy==1] = 0
# yy[ yy==2] = 1
 
indices = np.where(yy == 1)


y  = yy

# y1 = dataset.toks.detach().cpu().numpy()[:,4] + dataset.toks.detach().cpu().numpy()[:,9]
# y = y1


y = dataset.toks.detach().cpu().numpy()[indices][:,output_digit]
X1 = cache["result",layer][:,seq1,:,:].sum(1,keepdim=False).detach().cpu().numpy()#[indices]
X2 = cache["result",layer][:,seq2,:,:].sum(1,keepdim=False).detach().cpu().numpy()#[indices]


# X1 = cache["result",layer][:,seq1,head,:].detach().cpu().numpy()#[indices]
# X2 = cache["result",layer][:,seq2,head,:].detach().cpu().numpy()#[indices]

X  = X1 #np.concatenate((X1,X2),axis=1)
print(f"head={head} using words {seq1} and {seq2},  Logistic Regression predicts the {digit} digit when there is no carrying")
# logistic_regression(X,y)

# err = cluster_state.substitute_centoids(layer, "z", head=0, lh_word=seq1, selection=t.tensor(indices[0]), save=False)
# print( f"err={(err.sum().item()/err.shape[0]):.0%}" ) 
    
# err = cluster_state.substitute_centoids(layer, "z", head=1, lh_word=seq1, selection=t.tensor(indices[0]), save=False)
# print( f"err={(err.sum().item()/err.shape[0]):.0%}" ) 

# err = cluster_state.substitute_centoids(layer, "z", head=2 ,lh_word=seq1, selection=t.tensor(indices[0]), save=False)
# print( f"err={(err.sum().item()/err.shape[0]):.0%}" ) 

# indices = np.where(yy == 2)
# y = dataset.toks.detach().cpu().numpy()[indices][:,output_digit]
# X1 = cache["result",layer][:,seq1,:,:].sum(1,keepdim=False).detach().cpu().numpy()[indices]
# X2 = cache["result",layer][:,seq2,:,:].sum(1,keepdim=False).detach().cpu().numpy()[indices]

# # X1 = cache["result",layer][:,seq1,head,:].detach().cpu().numpy()[indices]
# # X2 = cache["result",layer][:,seq2,head,:].detach().cpu().numpy()[indices]

# X  = np.concatenate((X1,X2),axis=1)
# print(f"using words {seq1} and {seq2},  Logistic Regression predicts the {digit} digit when there is carrying")
# logistic_regression(X,y)

Predicting digit -2


NameError: name 'head' is not defined

In [None]:

def combine_cluster_labels(cluster_labels_list):
    return t.stack(cluster_labels_list, dim=1).unique(dim=0, return_inverse=True)[1]

def filter_words( pattern, threshold=.05 ) :
    return t.where(  pattern.max(dim=0).values > threshold )[0].tolist()


#head=1
# activation_name=pattern-q words=[1, 5, 6, 7, 9, 10, 11, 12, 13]
#   word=1 n_c=9 err_c=0, n_m=9 err_m=0.0
#   word=5 n_c=44 err_c=0, n_m=631 err_m=0.0
#   word=6 n_c=44 err_c=137, n_m=177 err_m=3
#   word=7 n_c=43 err_c=59, n_m=102 err_m=3
#   word=9 n_c=18 err_c=234, n_m=248 err_m=0

#   word=10 n_c=7 err_c=166, n_m=174 err_m=21
#   word=11 n_c=38 err_c=94, n_m=122 err_m=0

#   word=12 n_c=39 err_c=209, n_m=232 err_m=0
#   word=13 n_c=7 err_c=0, n_m=551 err_m=0.0

head = 1
seq1 = 10
layer = 0
def plot_it(data) :
    data = data.detach().cpu().numpy() 
    plt.figure(figsize=(15, 5))  # Adjusting the figure size to avoid a skinny look
    plt.imshow(data.T, aspect='auto')  # 'auto' aspect ratio makes the heatmap fit the figure
    plt.colorbar()  # Adding a colorbar for reference
    plt.xlabel('data points')
    plt.ylabel('word')
    plt.title(f'seq={seq1} Array')
    plt.show()


# for seq1 in range(10,11) :
for seq1 in range(1,14) :

    pattern = cluster_state.ctrl_cache["pattern",layer][:,head,seq1,:]

    
    cluster_labels, silhouette_values, n_clusters, silhouette_avg = cluster_state.find_clusters( pattern )
    unc = (cluster_labels < 0).sum().item() / cluster_labels.shape[0]
    print( f"seq={seq1} n_clusters={n_clusters} silhouette_avg={silhouette_avg:.2f} unclassified={unc:.0%}")
    err = cluster_state.substitute_centoids(layer, "pattern-q", head=head, word=seq1, save=False)
    print( f"err={(err.sum().item()/err.shape[0]):.0%}" ) 

    sorted_indices = t.argsort(cluster_labels)
    pattern_sorted = pattern[sorted_indices]


    words = filter_words( pattern )
    print( f"words={words}" )
    cluster_labels_list = [ cluster_state.result.get( ("v",layer,head,word),  t.zeros(pattern.shape[0]) ) for word in words ]
    cluster_labels = combine_cluster_labels(cluster_labels_list)
    print( f" n_clusters={t.max(cluster_labels.unique()).item() + 1}" )



    plot_it(pattern)
    plot_it(pattern_sorted)

    filter = slice(None) #(cluster_labels < 0)

    # err = cluster_state.substitute_centoids(layer, "pattern-q", head=head ,lh_word=seq1, save=False)
    # print( f"err={(err.sum().item()/err.shape[0]):.0%}" ) 

    f = np.zeros((15,15))
    X  = dataset.toks[filter,seq1-5].int()
    Y  = dataset.toks[filter,seq1].int()
    # Z  = pattern[filter,0]
    # print(Z.shape)
    # f[X, Y] = Z

    np.add.at(f, (X, Y), 1)

    plt.figure(figsize=(5,5))  # Adjusting the figure size to avoid a skinny look
    plt.imshow(f, aspect='auto')  # 'auto' aspect ratio makes the heatmap fit the figure
    plt.colorbar()  # Adding a colorbar for reference
    plt.xlabel('X Axis ')
    plt.ylabel('Y Axis ')
    plt.title('Heatmap ')
    plt.show()
