In [12]:
# Create the folder where the h5py database is saved.
# exist_ok=True makes this cell safe to re-run (e.g. Restart Kernel -> Run All):
# an already existing folder is simply left untouched.
import os
os.makedirs("Database", exist_ok=True)

In [None]:
# Make the APMAE model importable: it lives in the 'Model' directory one level
# above this notebook's folder (resolved relative to the current working directory).
import os
import sys

_notebook_dir = os.path.dirname('./run.ipynb')
_model_dir = os.path.join(_notebook_dir, os.path.pardir, 'Model')
path2add = os.path.normpath(os.path.abspath(_model_dir))

# Only extend sys.path once, so re-running the cell does not add duplicates.
if path2add not in sys.path:
    sys.path.append(path2add)

In [None]:
#Our code
from DataUtil.DataLoader import IterableAttentionLoader
from DataUtil.AttentionData import AttentionData
from ap_mae import APMAE

#Imported packages
from transformers import AutoModelForCausalLM

import numpy as np
from tqdm import tqdm
from collections import Counter

#We recommend using the cuml package for faster computation if a decent GPU is available; otherwise it can be replaced by the corresponding sklearn packages
from cuml import UMAP
from cuml import HDBSCAN
from cuml.metrics.pairwise_distances import pairwise_distances

# Setup

In [None]:
# Size of the StarCoder2 target model: '3B', '7B' or '15B'.
# Storage required per size:
#   3B  -> 2TB
#   7B  -> 3.5TB
#   15B -> 5.5TB
size = '3B'

# Names derived from the chosen size
db_name = f"reproduction_{size}"
target_model_name = f"bigcode/starcoder2-{size.lower()}"
encoding_model_name = f"LaughingLogits/AP-MAE-SC2-{size}"

# Dataset that the samples are drawn from
dataset_name = 'LaughingLogits/Stackless_Java_V2'
split = 'test'

device = 'cpu'
languages = ['java']

# Task groups; `tasks` is the "random" task followed by every group, in this order.
understanding = ["identifiers"]
literals = ["boolean_literals", "string_literals", "numeric_literals"]
operators = ['boolean_operators', 'mathematical_operators', 'assignment_operators']
syntax = ['eol', 'closing_bracket']
tasks = ["random", *understanding, *literals, *operators, *syntax]

# Sampling settings passed to the attention loader
samples_per_task = 1000
context_length = 256

In [None]:
# Selection keys for the samples used in the analysis below.
# Each one can be replaced with a list of keys; we used all values ("*") in our investigation.
# Example: incorrect Java predictions for the eol task, all heads from layers 4 and 7:
#   langs = ['java']
#   corrects = ['incorrect']
#   querys = ['eol']
#   layers = ['4','7']
#   heads = "*"
_SELECT_ALL = "*"

langs = _SELECT_ALL
corrects = _SELECT_ALL
querys = _SELECT_ALL
layers = _SELECT_ALL
heads = _SELECT_ALL

In [None]:
# Target LLM whose attention is analysed.
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    device_map="auto",
)

# Pretrained AP-MAE model used later to encode the generated patterns.
encoding_model = APMAE.from_pretrained(
    pretrained_model_name_or_path=encoding_model_name,
)

# Container for the attention data, named after `db_name`.
attention_data = AttentionData(target_model.config, tasks, languages, db_name)

# Loader that feeds samples of the first configured language to the target model.
attention_loader = IterableAttentionLoader(
    dataset_name,
    samples_per_task,
    context_length,
    tasks,
    languages[0],
    target_model_name,
    False,  # positional flag; meaning is defined by IterableAttentionLoader - TODO confirm
    target_model,
    device,
    split,
    True,  # positional flag; meaning is defined by IterableAttentionLoader - TODO confirm
)

# Generate patterns and encode - Section 5

In [None]:
# Generate the attention patterns from the samples provided by the loader.
# NOTE(review): presumably this is the step that fills the on-disk database and therefore
# needs the storage listed in the Setup cell - confirm before running.
attention_data.generate_patterns(attention_loader)

In [None]:
# Encode the generated patterns with the AP-MAE encoding model.
# NOTE(review): presumably this produces the 'enc_cls' values read in the UMAP step - confirm.
attention_data.encode(encoding_model)

# UMAP - Section 5.1

In [None]:
# Fetch the 'enc_cls' values for the selected groups ...
X = attention_data.data.get_grouped_samples(
    langs, corrects, querys, layers, heads, 'enc_cls'
)

# ... and embed them in 4 dimensions with UMAP (manhattan metric, min_dist=0).
# NOTE(review): no random_state is set, so the embedding may differ between runs.
umap_reducer = UMAP(n_components=4, min_dist=0, metric='manhattan')
X_reduced = umap_reducer.fit_transform(X)

# HDBSCAN - Section 5.2

In [None]:
# Cluster the samples with HDBSCAN; y_hdb receives the predicted cluster labels.
# NOTE(review): the clustering is fitted on X (the un-reduced 'enc_cls' samples), while the
# selection step further down measures intra-cluster distances on X_reduced. Confirm that X,
# and not X_reduced, is the intended input here.
hdb = HDBSCAN(cluster_selection_epsilon=0.5)
y_hdb = hdb.fit_predict(X)

# Selection

In [None]:
def get_min_size(size):
    """Return the minimum number of samples a cluster must have to be kept.

    Parameters
    ----------
    size : str
        Size of the target model: '3B', '7B' or '15B'.

    Returns
    -------
    int
        100 for '3B', 167 for '7B' and 260 for '15B'.

    Raises
    ------
    ValueError
        If `size` is not one of the supported sizes. Without this check the
        function would fall through and return None, which only shows up later
        as a confusing TypeError when the result is compared with a count.
    """
    min_sizes = {'3B': 100, '7B': 167, '15B': 260}
    try:
        return min_sizes[size]
    except (KeyError, TypeError):
        raise ValueError(
            "Unsupported model size {!r}; expected one of {}".format(size, sorted(min_sizes))
        ) from None

def inner_cluster_distances(X, clusters, min_size=None):
    """Compute the mean intra-cluster pairwise distance of every sufficiently large cluster.

    Parameters
    ----------
    X : array
        Samples, indexable with a boolean mask over its first axis.
    clusters : numpy array
        Cluster label of every row of X. Label -1 is always skipped.
    min_size : int, optional
        Clusters with fewer samples than this are skipped. Defaults to
        ``get_min_size(size)``, i.e. the value for the notebook-level `size` setting.

    Returns
    -------
    list of (label, mean_distance) tuples
        One entry per kept cluster, in ascending label order.
    """
    if min_size is None:
        # Fall back to the notebook-level configuration (original behaviour).
        min_size = get_min_size(size)

    # Unique labels (sorted) together with their sample counts.
    cluster_ids, counts = np.unique(clusters, return_counts=True)
    # Keep clusters that are large enough and are not labelled -1.
    kept_ids = cluster_ids[(counts >= min_size) & (cluster_ids != -1)]

    dist = []
    for cluster_id in tqdm(kept_ids):
        members = X[clusters == cluster_id]
        # if len(members) > 100000: #If you run out of memory this will save a lot
        #     members = members[0::10]
        distances = pairwise_distances(members)
        # Upper triangle *including* the diagonal (triu_indices defaults to k=0), so every
        # sample's distance to itself is part of the mean as well.
        distances = distances[np.triu_indices(distances.shape[0])]
        dist.append((cluster_id, np.mean(distances)))
    return dist

In [None]:
# Start from the raw HDBSCAN labels and count the samples per cluster.
clusters = y_hdb
c = Counter(clusters)

# Mean intra-cluster distance, measured in the UMAP-reduced space.
dist = inner_cluster_distances(X_reduced, clusters)

# Labels to discard: clusters that are too small, clusters whose mean inner
# distance exceeds 1, and the label -1 itself.
min_cluster_size = get_min_size(size)  # looked up once instead of once per cluster
remove = [label for label, count in c.items() if count < min_cluster_size]
remove.extend(label for label, mean_dist in dist if mean_dist > 1)
remove.append(-1)

# Relabel every discarded sample as -1. Membership is tested against a set so the
# pass over all samples does not rescan the whole `remove` list for each sample.
remove_set = set(remove)
clusters = np.array([-1 if label in remove_set else label for label in clusters])
# Column vector of shape (n_samples, 1).
clusters = clusters.reshape((len(clusters), 1))

# Save to h5py database

In [None]:
# Store the final cluster labels under the "class_cls" key, using the same selection
# keys (langs, corrects, querys, layers, heads) that were used to read the samples.
attention_data.data.write_grouped_samples(langs, corrects, querys, layers, heads, "class_cls", clusters)