In [12]:
# Makes the folder where the h5py database is saved.
# Uses os.makedirs(..., exist_ok=True) so the cell can be re-run safely and
# does not depend on a POSIX shell being available.
import os
os.makedirs("Database", exist_ok=True)

In [None]:
# Make the sibling 'Model' directory importable so that the APMAE model can be
# imported from it further down.
import os
import sys

# Resolved relative to the current working directory: <cwd>/../Model
_notebook_dir = os.path.dirname('./run.ipynb')
_model_dir = os.path.join(_notebook_dir, os.path.pardir, 'Model')
path2add = os.path.normpath(os.path.abspath(_model_dir))

# Only append once, so re-running the cell does not grow sys.path.
if path2add not in sys.path:
    sys.path.append(path2add)

In [None]:
#Standard library
import random
from collections import Counter

#Imported packages
from transformers import AutoModelForCausalLM

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap
import torch.nn.functional as F
from tqdm import tqdm

#We recommend to use the cuml package for quicker computation if a decent gpu is available, can be replaced by the corresponding sklearn packages
import cupy as cp  #cp.asarray is used in the clustering section to move the encodings to the GPU
from cuml import UMAP
from cuml import HDBSCAN
from cuml.metrics.pairwise_distances import pairwise_distances

#for classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score
from catboost import CatBoostClassifier, Pool

#Our code
from DataUtil.DataLoader import IterableAttentionLoader
from DataUtil.AttentionData import AttentionData
from ap_mae import APMAE


# Setup

In [None]:
# Size of the target model: '3B', '7B' or '15B'.
# Storage requirements stated for this notebook:
#   3B  -> 2TB
#   7B  -> 3.5TB
#   15B -> 5.5TB
size = '3B'

# Names derived from the chosen size.
db_name = f"reproduction_{size}"
target_model_name = f'bigcode/starcoder2-{size.lower()}'
encoding_model_name = f'LaughingLogits/AP-MAE-SC2-{size}'

# Dataset to draw samples from.
dataset_name = 'LaughingLogits/Stackless_Java_V2'
split = 'test'

device = 'cpu'
languages = ['java']

# Tasks for which samples are generated.
tasks = [
    'noise',
    'random',
    'identifiers',
    'boolean_literals',
    'numeric_literals',
    'string_literals',
    'boolean_operators',
    'mathematical_operators',
    'assignment_operators',
    'eol',
    'closing_bracket',
]

samples_per_task = 1000
context_length = 256

In [None]:
# Filters used when reading from / writing to the database.
# Each can be replaced with a list of keys, but we used all values ("*") in our investigation,
# e.g. incorrect java predictions for the eol task, all heads from layer 4 and 7:
# langs = ['java']
# corrects = ['incorrect']
# querys = ['eol']
# layers = ['4','7']
# heads = "*"
langs = "*"
corrects = "*"
querys = "*"
layers = "*"
heads = "*"

# (number of layers, number of attention heads per layer) for each supported model size.
_MODEL_SHAPES = {
    '3B': (30, 24),
    '7B': (32, 36),
    '15B': (40, 48),
}

# Fail here, with a clear message, instead of leaving n_layers / n_heads
# undefined and hitting a NameError in a later cell.
if size not in _MODEL_SHAPES:
    raise ValueError(f"Unsupported model size {size!r}; expected one of {sorted(_MODEL_SHAPES)}")

n_layers, n_heads = _MODEL_SHAPES[size]

In [None]:
# Target model whose attention is analysed.
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_name,
    device_map="auto",
)

# Encoder applied to the attention patterns.
encoding_model = APMAE.from_pretrained(
    pretrained_model_name_or_path=encoding_model_name
)

# Database wrapper, stored under `db_name`.
attention_data = AttentionData(
    target_model.config,
    tasks,
    languages,
    db_name,
)

# NOTE(review): the two positional booleans (False, True) are passed through
# unchanged; their meaning is defined by IterableAttentionLoader and is not
# visible in this notebook.
attention_loader = IterableAttentionLoader(
    dataset_name,
    samples_per_task,
    context_length,
    tasks,
    languages[0],
    target_model_name,
    False,
    target_model,
    device,
    split,
    True,
)

# Generate patterns and encode - Section 5

In [None]:
# Run this cell to save all patterns and then encode them.
# It takes up a lot of storage (up to 5.5TB per 10,000 samples).
attention_data.generate_patterns(attention_loader)
attention_data.encode(encoding_model)

In [None]:
# Run this cell and the actual patterns are not saved, only their encodings
# (up to 750GB per 100,000 samples). The last argument, 10, is the batch size for the encoder.
attention_data.generate_and_encode(attention_loader, encoding_model, 10)


# Clustering

In [None]:
import numpy as np
def jitter(X, jitter = 1e-3):
    """Add Gaussian noise to duplicated rows so that no two rows are identical.

    The first occurrence of every distinct row is left untouched; every later
    copy receives independent noise drawn from N(0, jitter**2) per element.

    Parameters
    ----------
    X : array-like, 2-D (n_samples, n_features)
        Input rows. The input itself is never modified; a copy is returned.
    jitter : float
        Standard deviation of the noise that is added to duplicated rows.

    Returns
    -------
    numpy.ndarray
        Floating point copy of ``X`` with the duplicated rows perturbed.
    """
    X = np.array(X)  # np.array copies, so the caller's data stays untouched
    if not np.issubdtype(X.dtype, np.floating):
        # Integer/bool input cannot take float noise in place; promote it first.
        X = X.astype(np.float64)

    # Step 1: Identify duplicate rows
    _, idx_unique, idx_inverse, counts = np.unique(
        X, axis=0, return_index=True, return_inverse=True, return_counts=True
    )
    # Some NumPy 2.0 releases return the inverse with an extra axis when `axis`
    # is given; flatten it so the masks below are always 1-D over the rows.
    idx_inverse = np.asarray(idx_inverse).reshape(-1)

    # Step 2: Find indices of duplicated rows (excluding the first occurrence)
    duplicate_mask = counts[idx_inverse] > 1
    first_occurrence_mask = np.zeros_like(duplicate_mask)
    first_occurrence_mask[idx_unique] = True
    final_mask = duplicate_mask & ~first_occurrence_mask  # Only actual duplicates

    # Step 3: Add noise to just the duplicated rows
    X[final_mask] += jitter * np.random.randn(int(np.sum(final_mask)), X.shape[1])
    return X

In [None]:

# The UMAP hyperparameters we used (shared by the retry and the fallback path).
UMAP_PARAMS = dict(
    n_components=8,
    n_neighbors=20,
    min_dist=0.05,
    metric='cosine'
)
MAX_JITTER_ATTEMPTS = 3   # selective-jitter attempts before jittering everywhere
INITIAL_JITTER = 1e-3     # noise scale of the first attempt
JITTER_GROWTH = 5         # factor applied to the noise scale after a failed attempt

for layer in range(n_layers):
    for head in range(n_heads):
        # Keep the untouched encodings on the host: every attempt starts from
        # the same data, so jitter never accumulates and `jitter` (NumPy only)
        # is never handed a cupy array.
        X_raw = np.asarray(attention_data.data.get_grouped_samples(langs, corrects, querys, [layer], [head], 'enc_cls'))

        # We add jitter only where rows are identical; if that fails
        # MAX_JITTER_ATTEMPTS times, we add jitter everywhere.
        X_embed = None
        jitter_val = INITIAL_JITTER
        for attempt in range(MAX_JITTER_ATTEMPTS):
            try:
                X = cp.asarray(jitter(X_raw, jitter = jitter_val))
                X_embed = UMAP(**UMAP_PARAMS).fit_transform(X)
                break #Exit retry loop if it worked
            except Exception as e:
                # Broad on purpose (best-effort retry), but the error is shown
                # so that real problems are not silently hidden.
                print(f"Attempt {attempt + 1} failed (jitter={jitter_val}): {e!r}")
                jitter_val = JITTER_GROWTH * jitter_val

        if X_embed is None:
            print("selective jitter failed, jittering everywhere")
            X = cp.asarray(X_raw + INITIAL_JITTER * np.random.randn(*X_raw.shape))
            # If this still fails, the exception propagates and stops the cell.
            X_embed = UMAP(**UMAP_PARAMS).fit_transform(X)

        # The HDBSCAN model with the hyperparameters we used.
        hdbscan_model = HDBSCAN(min_samples=20, min_cluster_size=25, allow_single_cluster=True)
        # HDBSCAN is given a host (NumPy) copy of the embedding.
        labels = hdbscan_model.fit_predict(cp.asnumpy(X_embed))

        #Save the clusters in our H5PY Database
        attention_data.data.write_grouped_samples(langs, corrects, querys, [layer], [head], "class_cls", labels)

# Classification

In [None]:
#Categorical value features to pass to CatBoost
# One column name per (layer, head), matching the feature columns built for the
# selected model size. The previous hard-coded 40 x 48 only matched the 15B
# model and listed columns that do not exist for 3B / 7B.
col_names = [f'l{l}h{h}' for l in range(n_layers) for h in range(n_heads)]

In [None]:
tasks = ['random', 'identifiers', 'boolean_literals', 'numeric_literals', 'string_literals', 'boolean_operators', 'mathematical_operators', 'assignment_operators', 'eol', 'closing_bracket']

# SHAP explanation of every task, keyed by task name. `expl` (below) is still
# assigned in every iteration and therefore ends up holding the last task only.
explanations = {}

for t in tasks:
    # NOTE(review): the database is queried with the global `querys` filter, so
    # with querys = "*" every iteration reads the same samples and `t` only
    # affects the printed name — confirm whether `[t]` was intended here.
    # NOTE(review): the clustering section stores its labels under "class_cls",
    # while 'enc_cls' is read here — confirm this is the intended key for the
    # categorical features.
    feature_columns = {}
    for l in tqdm(range(n_layers)):
        for h in range(n_heads):
            feature_columns[f'l{l}h{h}'] = attention_data.data.get_grouped_samples(langs, corrects, querys, [l], [h], 'enc_cls')

    # Build the frame in one go; `df` was previously used without being created.
    df = pd.DataFrame(feature_columns)

    # returns the labels correct, or incorrect for each prediction
    y = attention_data.data.get_grouped_clusters(langs, corrects, querys, layers, heads, 'enc_cls', True, False, True, True, True)

    X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.1, random_state=42, stratify=y)

    model = CatBoostClassifier(
        iterations=1000,
        depth=6,
        learning_rate=0.1,
        loss_function='Logloss',
        cat_features=col_names,
        verbose=0,
        early_stopping_rounds=25,
        eval_fraction=0.1,
        task_type='GPU', #remove if no GPU is available
        devices='0'
    )

    # Fit the model
    model.fit(X_train, y_train)

    # Predict
    y_pred = model.predict(X_test)

    #Evaluate the classification
    ConfusionMatrixDisplay.from_predictions(y_test, y_pred)
    plt.show()
    accuracy = accuracy_score(y_test, y_pred)
    print(f'{t} accuracy: {accuracy}')

    ##### SHAP VALUES #####
    pool = Pool(X_test, y_test, cat_features=col_names)
    # Get SHAP values from CatBoost
    shap_values = model.get_feature_importance(pool, type='ShapValues')

    # Extract only per-feature SHAP values
    feature_shap_values = shap_values[:, 1:]

    # Build a shap.Explanation object
    expl = shap.Explanation(
        values=feature_shap_values,
        base_values=shap_values[:, 0],
        data=X_test.values,
        feature_names=X_test.columns.tolist()
    )
    explanations[t] = expl
        

# Intervention

In [None]:
########## Helper functions for the main intervention loop ##################################
def get_global_pos_neg_features(explanation, setting = 'pos_only', top_n=10, neutral_threshold = 1e-5):
    """
    Select feature names from a SHAP Explanation object by their global contribution.

    Parameters
    ----------
    explanation : object with ``values`` (n_samples, n_features) and ``feature_names``
    setting : str
        'pos_only' - the ``top_n`` features with the largest mean positive contribution.
        'neg_only' - the ``top_n`` features with the most negative mean negative contribution.
        'neutral'  - up to ``top_n`` distinct features, drawn at random from those whose
                     mean contribution has an absolute value <= ``neutral_threshold``.
    top_n : int
        Maximum number of features to return.
    neutral_threshold : float
        Bound on |mean SHAP value| for a feature to count as neutral.

    Returns
    -------
    list of str
        Feature names; may be shorter than ``top_n`` when fewer features qualify.

    Raises
    ------
    ValueError
        If ``setting`` is not one of the three supported values.
    """
    shap_values = np.asarray(explanation.values)  # (n_samples, n_features)
    feature_names = explanation.feature_names

    if setting == 'pos_only':
        # Mean of the positive part only; negative values count as 0.
        mean_pos = np.where(shap_values > 0, shap_values, 0).mean(axis=0)
        ranked = sorted(zip(feature_names, mean_pos), key=lambda x: x[1], reverse=True)
        return [name for name, _ in ranked[:top_n]]

    if setting == 'neg_only':
        # Mean of the negative part only; values are already negative, so sort ascending.
        mean_neg = np.where(shap_values < 0, shap_values, 0).mean(axis=0)
        ranked = sorted(zip(feature_names, mean_neg), key=lambda x: x[1])
        return [name for name, _ in ranked[:top_n]]

    if setting == 'neutral':
        mean_total = shap_values.mean(axis=0)
        candidates = [f for f, v in zip(feature_names, mean_total)
                      if abs(v) <= neutral_threshold]
        # Sample WITHOUT replacement: duplicates would silently reduce the number
        # of distinct features, and an empty candidate list must not raise.
        return random.sample(candidates, k=min(top_n, len(candidates)))

    raise ValueError(f"Unknown setting {setting!r}; expected 'pos_only', 'neg_only' or 'neutral'")

def get_random_heads(n_heads = 5, num_layers = None, heads_per_layer = None):
    """Return ``n_heads`` distinct, randomly chosen head names of the form 'l{layer}h{head}'.

    Parameters
    ----------
    n_heads : int
        Number of heads to pick. If it exceeds the number of available heads,
        all of them are returned (in random order).
    num_layers : int, optional
        Number of layers to draw from. Defaults to the module-level ``model_layers``.
    heads_per_layer : int, optional
        Number of heads per layer. Defaults to the module-level ``model_heads``.

    Returns
    -------
    list of str
    """
    # Fall back to the notebook-level values only when no explicit shape is
    # given, so the function can also be used (and tested) without them.
    if num_layers is None:
        num_layers = model_layers
    if heads_per_layer is None:
        heads_per_layer = model_heads

    heads = [f"l{l}h{h}" for l in range(num_layers) for h in range(heads_per_layer)]
    random.shuffle(heads)
    return heads[:n_heads]

import contextlib
import torch
from typing import Dict, Iterable, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer

###############################################################################
# HeadSkipper: zeroes specified attention heads per layer during inference.
# - Works with kv caching (past_key_values) because we don't mutate K/V.
# - Operates right before the attention o_proj, so it's fast and robust.
###############################################################################

class HeadSkipper:
    """Temporarily zero the contribution of selected attention heads.

    While the ``apply()`` context is active, a forward pre-hook on each selected
    layer's ``self_attn.o_proj`` zeroes the slices of its input that belong to
    the chosen heads. The hooks are removed again when the context exits.

    Parameters
    ----------
    model : model exposing ``model.layers`` and ``config.num_attention_heads``
    heads_by_layer : mapping of layer index -> iterable of head indices

    Raises
    ------
    ValueError
        If a layer index or a head index is out of range.
    """

    def __init__(self, model, heads_by_layer):
        self.model = model

        # Normalise the request: integer layer keys, sorted de-duplicated heads.
        normalised = {}
        for layer_key, head_list in heads_by_layer.items():
            normalised[int(layer_key)] = sorted(set(head_list))
        self.heads_by_layer = normalised

        # Layer indices are validated first ...
        layer_count = len(self.model.model.layers)
        for layer_idx in self.heads_by_layer:
            if layer_idx < 0 or layer_idx >= layer_count:
                raise ValueError(f"Layer index {layer_idx} out of range (0 to {layer_count-1})")

        # ... then every requested head index.
        head_count = model.config.num_attention_heads
        for head_list in self.heads_by_layer.values():
            for head_idx in head_list:
                if not (0 <= head_idx < head_count):
                    raise ValueError(f"Head index {head_idx} out of range (0 to {head_count-1})")

        # Hook handles registered by apply(); empty outside of the context.
        self.handles = []

    def _make_zeroing_hook(self, heads_to_zero):
        """Build the pre-hook that zeroes ``heads_to_zero`` in the o_proj input."""
        def pre_hook(module, inputs):
            (hidden,) = inputs  # expected shape: [batch, seq_len, hidden_size]
            head_count = self.model.config.num_attention_heads
            batch = hidden.size(0)
            seq_len = hidden.size(1)
            per_head_dim = hidden.shape[-1] // head_count
            # View the last dimension as (head, head_dim) and zero whole heads in place.
            per_head = hidden.view(batch, seq_len, head_count, per_head_dim)
            per_head[:, :, heads_to_zero, :] = 0
            return (per_head.reshape(batch, seq_len, -1),)
        return pre_hook

    @contextlib.contextmanager
    def apply(self):
        """Context manager: register the zeroing hooks, always remove them on exit."""
        try:
            for layer_idx, head_list in self.heads_by_layer.items():
                projection = self.model.model.layers[layer_idx].self_attn.o_proj
                hook = self._make_zeroing_hook(head_list)
                self.handles.append(projection.register_forward_pre_hook(hook))
            yield
        finally:
            for handle in self.handles:
                handle.remove()
            self.handles.clear()


In [None]:
target = 1000                       #Number of correct and incorrect predictions
t = 'identifiers'                   #Which task to run for
size = '3B'                         #Starcoder2 Model size
device = 'cuda:0'                   #Cuda device to use, if you dont have accelerate
target_model_name = 'bigcode/starcoder2-{}'.format(size.lower())
model = AutoModelForCausalLM.from_pretrained(target_model_name, device_map="auto")

# get_random_heads() reads these two module-level names; derive them from the
# loaded model the same way HeadSkipper does, so they always match it.
model_layers = len(model.model.layers)
model_heads = model.config.num_attention_heads

# Numbers of heads to skip and the head-selection strategies that are compared.
HEAD_COUNTS = [1, 2, 5, 10, 20, 50, 100, 200, 400, 600, 800]
SETTINGS = ['pos_only', 'neg_only', 'neutral', 'random']

incorrect = 0                       #Counter initializers
correct = 0
count = 0

loss_normal = []                    #Loss of every unmodified forward pass
loss_edited = []                    #Loss of every forward pass with heads skipped

#Iterator object that will gather correct scenarios for intervention
# NOTE(review): IterableScenarioAggregator is not imported anywhere in the
# visible notebook; import it from its module before running this cell.
# NOTE(review): the first argument was the undefined name `dataset`; it is
# assumed to be the dataset name from the Setup section — confirm.
scenario_aggregator = IterableScenarioAggregator(dataset_name, 10000, 256, [t], "java", target_model_name, "test")

# NOTE(review): `expl` is whatever the classification loop assigned last, i.e.
# the explanation of the last task it processed — confirm it belongs to `t`.
expl_loaded = expl                 #Explanation object from the classification section

mean_abs_shap = expl_loaded.abs.mean(0)
sorted_idx = mean_abs_shap.values.argsort()[::-1]
feature_names = expl_loaded.feature_names

# Result tables, indexed as table[number_of_heads][setting].
correct_flipped = {n: {s: 0 for s in SETTINGS} for n in HEAD_COUNTS}      #Correct predictions that were flipped
incorrect_flipped = {n: {s: 0 for s in SETTINGS} for n in HEAD_COUNTS}    #Incorrect predictions that were flipped
i2c = {n: {s: 0 for s in SETTINGS} for n in HEAD_COUNTS}                  #How many were flipped from incorrect to correct
c2i = {n: {s: 0 for s in SETTINGS} for n in HEAD_COUNTS}                  #How many were flipped from correct to incorrect

with tqdm(total=target*2, desc="Progress", unit="answers") as pbar:
    it = iter(scenario_aggregator)
    while correct < target or incorrect < target:
        try:
            sample = next(it)
        except StopIteration:
            # Stop cleanly instead of crashing when the data runs out early.
            print(f"Scenario iterator exhausted before reaching the target: correct={correct}, incorrect={incorrect}")
            break
        query = sample[1]

        #Check length
        inputs = sample[0]['input']
        if inputs['input_ids'].size()[-1] != 256:
            continue

        count += 1

        if query == 'noise':
            # Labels are only read for non-noise samples; skip noise instead of
            # comparing against the labels left over from a previous sample.
            continue
        labels = sample[0]['label']
        label_id = labels['input_ids'].squeeze().flatten()[0].item()

        # not needed if device map is active, will be mapped
        inputs = inputs['input_ids'].unsqueeze(dim=0)
        inputs = inputs.to(device)

        # disable gradients for inference performance
        with torch.no_grad():
            # Attention weights are not requested: they were never used, and
            # this keeps the baseline pass identical to the edited pass below.
            outputs1 = model(
                inputs,
                use_cache=False,# we dont do further inference, saves VRAM
            )

        preds1 = outputs1.logits.squeeze()[-1,:].argmax(dim = -1)
        pred1_id = preds1.item()

        correct1 = pred1_id == label_id
        loss1 = F.cross_entropy(outputs1.logits[:, -1, :].cpu(), labels["input_ids"][:, 0]).cpu().detach()

        # Only keep the sample if its class (correct / incorrect) still needs examples.
        if correct1:
            if correct >= target:
                continue
            correct += 1
        else:
            if incorrect >= target:
                continue
            incorrect += 1
        # Record the baseline loss only for samples that are actually used.
        loss_normal.append(loss1)
        pbar.set_postfix(correct=correct, incorrect=incorrect)
        pbar.update(1)

        # `n_skip` (not `n_heads`) so the heads-per-layer value defined in the
        # Setup section is not overwritten by this loop.
        for n_skip in HEAD_COUNTS:
            for setting in SETTINGS:
                if setting =='random':
                    select_heads = get_random_heads(n_skip)
                else:
                    select_heads = get_global_pos_neg_features(expl_loaded, setting, n_skip)

                # Group the selected 'l{layer}h{head}' names by layer.
                lh_dict = {}
                for name in select_heads:
                    l, h = name[1:].split('h')
                    lh_dict.setdefault(int(l), []).append(int(h))

                #Which heads to zero
                skipper = HeadSkipper(model, heads_by_layer=lh_dict)

                with torch.no_grad(), skipper.apply():
                    outputs2 = model(
                        inputs,
                        use_cache=False,
                        # we dont do further inference, saves VRAM
                    )

                preds2 = outputs2.logits.squeeze()[-1,:].argmax(dim = -1)
                pred2_id = preds2.item()

                correct2 = pred2_id == label_id
                loss2 = F.cross_entropy(outputs2.logits[:, -1, :].cpu(), labels["input_ids"][:, 0]).cpu().detach()
                loss_edited.append(loss2)

                # A prediction is "flipped" when skipping heads changes the predicted token.
                if pred1_id != pred2_id:
                    if correct1:
                        correct_flipped[n_skip][setting] += 1
                    else:
                        incorrect_flipped[n_skip][setting] += 1

                if correct1 and not correct2:
                    c2i[n_skip][setting] += 1
                elif correct2 and not correct1:
                    i2c[n_skip][setting] += 1