In [12]:
# Makes the folder where the h5py database is saved.
# Safe to re-run: an already existing folder is left untouched.
import os
os.makedirs('Database', exist_ok=True)

In [None]:
# Make the APMAE model importable: it lives in the 'Model' directory one level above this notebook's folder
import os
import sys

path2add = os.path.normpath(
    os.path.abspath(
        os.path.join(os.path.dirname('./run.ipynb'), os.path.pardir, 'Model')
    )
)
# Only extend the search path once, even if this cell is executed repeatedly
if path2add not in sys.path:
    sys.path.append(path2add)

In [None]:
#Our code
from DataUtil.DataLoader import IterableAttentionLoader
from DataUtil.AttentionData import AttentionData
from ap_mae import APMAE

#Imported packages
from transformers import AutoModelForCausalLM

import numpy as np
from tqdm import tqdm
from collections import Counter

#We recommend to use the cuml package for quicker computation if a decent gpu is available, can be replaced by the corresponding sklearn packages
import cupy as cp #GPU arrays handed to cuml in the clustering cell (cp.asarray / .get())
from cuml import UMAP
from cuml import HDBSCAN
from cuml.metrics.pairwise_distances import pairwise_distances


#for classification
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool

from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score
import pandas as pd
import matplotlib.pyplot as plt
import shap


# Setup

In [None]:
# Size of the target model: '3B', '7B' or '15B'.
# Storage required per size:
#   3B  -> 2TB
#   7B  -> 3.5TB
#   15B -> 5.5TB
size = '3B'

# Names derived from the chosen size
db_name = f"reproduction_{size}"
target_model_name = f'bigcode/starcoder2-{size.lower()}'
encoding_model_name = f'LaughingLogits/AP-MAE-SC2-{size}'
dataset_name = 'LaughingLogits/Stackless_Java_V2'
split = 'test'

device = 'cpu'
languages = ['java']

tasks = [
    'noise',
    'random',
    'identifiers',
    'boolean_literals',
    'numeric_literals',
    'string_literals',
    'boolean_operators',
    'mathematical_operators',
    'assignment_operators',
    'eol',
    'closing_bracket',
]

samples_per_task = 1000
context_length = 256

In [None]:
# Selection keys used by the clustering and classification cells below.
# These can be replaced with a list of keys, but we used all values in our investigation
# e.g. incorrect java predictions for the eol task, all heads from layer 4 and 7
# langs = ['java']
# corrects = ['incorrect']
# querys = ['eol']
# layers = ['4','7']
# heads = "*"
langs = "*"
corrects = "*"
querys = "*"
layers = "*"
heads = "*"

# (n_layers, n_heads) for every supported target model size
MODEL_DIMENSIONS = {
    '3B': (30, 24),
    '7B': (32, 36),
    '15B': (40, 48),
}

# Fail right here on an unsupported size instead of leaving n_layers / n_heads
# undefined (or stale from an earlier run of this cell with another size).
if size not in MODEL_DIMENSIONS:
    raise ValueError(f"Unsupported size {size!r}, expected one of {list(MODEL_DIMENSIONS)}")
n_layers, n_heads = MODEL_DIMENSIONS[size]

In [None]:
# Target language model; device placement is delegated to device_map="auto"
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    device_map="auto",
)

# Pretrained AP-MAE encoder matching the chosen target model size
encoding_model = APMAE.from_pretrained(
    pretrained_model_name_or_path=encoding_model_name,
)

# Database wrapper, built from the target model config, the tasks, the languages and the database name
attention_data = AttentionData(target_model.config, tasks, languages, db_name)

# NOTE(review): the two positional booleans (False and True) are passed as-is; their meaning is
# defined by IterableAttentionLoader and cannot be told from this notebook - confirm before changing.
# NOTE(review): `device` is 'cpu' in the setup cell while the model uses device_map="auto" - confirm this is intended.
attention_loader = IterableAttentionLoader(
    dataset_name,
    samples_per_task,
    context_length,
    tasks,
    languages[0],
    target_model_name,
    False,
    target_model,
    device,
    split,
    True,
)

# Generate patterns and encode - Section 5

In [None]:
#Run this to save all patterns; it takes up a lot of storage (up to 5.5TB per 10,000 samples).
#The patterns are generated first and encoded in a second step.
attention_data.generate_patterns(attention_loader)
attention_data.encode(encoding_model)

In [None]:
#Run this to store only the encodings and not the actual patterns (up to 750GB per 100,000 samples).
#NOTE(review): presumably an alternative to the generate_patterns / encode cell above - confirm whether both should be run.
encoder_batch_size = 10 #batch size for the encoder
attention_data.generate_and_encode(attention_loader, encoding_model, encoder_batch_size)


# Clustering

In [None]:

# Cluster the encodings of every attention head separately (UMAP embedding followed by HDBSCAN)
# and write the resulting cluster labels back to the database under "class_cls".
#
# NOTE(review): `jitter` is not defined or imported in the cells visible in this notebook; it has to be
# available in the kernel before this cell runs. If it is missing, every selective attempt fails with a
# NameError (printed below) and the "jitter everywhere" fallback is used instead.

# The UMAP model hyperparameters we used
umap_params = dict(
    n_components=8,
    n_neighbors=20,
    min_dist=0.05,
    metric='cosine'
)

for l in range(n_layers):
    for h in range(n_heads):
        X_raw = attention_data.data.get_grouped_samples(langs, corrects, querys, [l], [h], 'enc_cls')

        #We add jitter, only where the values match, if it fails 3 times, we add jitter everywhere
        jitter_val = 1e-3
        X_embed = None
        for attempt in range(3):
            try:
                # Every attempt starts from the untouched samples, so a retry really applies the larger
                # jitter value instead of re-jittering data that was already perturbed.
                X = cp.asarray(jitter(X_raw, jitter=jitter_val))
                umap_model = UMAP(**umap_params)
                X_embed = umap_model.fit_transform(X)
                break #Exit retry loop if it worked
            except Exception as e:
                print(f"Attempt {attempt + 1} failed")
                print(repr(e)) # show why the attempt failed instead of silently swallowing the error
                print(jitter_val)
                jitter_val = 5*jitter_val

        if X_embed is None:
            print("selective jitter failed, jittering everywhere")
            # Go through a host copy of the untouched samples; this works no matter how far the
            # failed attempts got (X may never have been moved to the GPU).
            X_host = cp.asnumpy(X_raw)
            X = cp.asarray(X_host + 1e-3 * np.random.randn(*X_host.shape))
            umap_model = UMAP(**umap_params)
            X_embed = umap_model.fit_transform(X)

        # The HDBSCAN model with the hyperparameters we used.
        hdbscan_model = HDBSCAN(min_samples=20, min_cluster_size=25, allow_single_cluster=True)
        labels = hdbscan_model.fit_predict(X_embed.get())

        #Save the clusters in our H5PY Database
        attention_data.data.write_grouped_samples(langs, corrects, querys, [l], [h], "class_cls", labels)

# Classification

In [None]:
#Categorical value features to pass to CatBoost: one column name per (layer, head).
#Uses the dimensions of the selected model (n_layers, n_heads) so the names match the columns
#built in the classification cell for every model size, not only for 15B (40 x 48).
col_names = [f'l{l}h{h}' for l in range(n_layers) for h in range(n_heads)]

In [None]:
# Tasks to classify (the 'noise' task from the setup cell is left out).
# NOTE(review): this overwrites the `tasks` list from the setup cell; re-run the setup cell before
# re-running any earlier cell that reads `tasks`.
tasks = ['random', 'identifiers', 'boolean_literals', 'numeric_literals', 'string_literals', 'boolean_operators', 'mathematical_operators', 'assignment_operators', 'eol', 'closing_bracket']

for t in tasks:
    # One categorical feature per attention head: the cluster label stored under "class_cls" by the
    # clustering cell. Only samples of the current task `t` are selected.
    # NOTE(review): the original read 'enc_cls' with querys="*" for every task - confirm that the
    # cluster labels of task `t` are the intended features.
    head_features = {}
    for l in tqdm(range(n_layers), desc=t):
        for h in range(n_heads):
            head_features[f'l{l}h{h}'] = attention_data.data.get_grouped_samples(langs, corrects, [t], [l], [h], 'class_cls')
    # Build the frame in one go instead of inserting columns one by one
    df = pd.DataFrame(head_features)
    feature_names = df.columns.tolist()

    # returns the labels correct, or incorrect for each prediction
    # NOTE(review): restricted to task `t` so the labels line up with the rows of `df`; the remaining
    # positional arguments are kept exactly as they were - confirm their meaning against AttentionData.
    y = attention_data.data.get_grouped_clusters(langs, corrects, [t], layers, heads, 'enc_cls', True, False, True, True, True)

    X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.1, random_state=42, stratify=y)

    model = CatBoostClassifier(
        iterations=1000,
        depth=6,
        learning_rate=0.1,
        loss_function='Logloss',
        cat_features=feature_names, # exactly the columns of df, whatever the model size
        verbose=0,
        early_stopping_rounds=25,
        eval_fraction=0.1,
        task_type='GPU', #remove if no GPU is available
        devices='0'
    )

    # Fit the model
    model.fit(X_train, y_train)

    # Predict
    y_pred = model.predict(X_test)


    #Evaluate the classification
    cmd = ConfusionMatrixDisplay.from_predictions(y_test, y_pred)
    cmd.ax_.set_title(t) # one figure per task, so label it
    plt.show()
    accuracy = accuracy_score(y_test, y_pred)
    print(f'{t} accuracy: {accuracy}')


    ##### SHAP VALUES #####
    pool = Pool(X_test, y_test, cat_features=feature_names)
    # Get SHAP values from CatBoost: shape (n_samples, n_features + 1)
    shap_values = model.get_feature_importance(pool, type='ShapValues')

    # Extract only per-feature SHAP values; CatBoost stores the expected value in the LAST column
    feature_shap_values = shap_values[:, :-1]

    # Build a shap.Explanation object
    expl = shap.Explanation(
        values=feature_shap_values,
        base_values=shap_values[:, -1],
        data=X_test.values,
        feature_names=X_test.columns.tolist()
    )
    

    