In [18]:
# Libraries
from dotenv import load_dotenv
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from prompt_processing import process_text_file
from torch.utils.data import Dataset
from glob import glob
import pandas as pd
from scipy.stats import ttest_ind
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from typing import Optional
from tqdm import tqdm

# Initialization of the environment
# Load the variable from .env
load_dotenv()
hf_access_token = os.getenv("HF_ACCESS_TOKEN")
cache_dir = os.getenv("CACHE_DIR")

In [3]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and tokenizer onto GPU
checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir, token=hf_access_token)
model = AutoModelForCausalLM.from_pretrained(checkpoint, cache_dir=cache_dir, token=hf_access_token).to(device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
class ImportLLMfromHF:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def get_embd_size(self):
        return self.model.config.hidden_size

    def get_nb_layers(self):
        return self.model.config.num_hidden_layers

class LayersUnits:
    def __init__(self, llm: ImportLLMfromHF, data: Dataset, method: str = "average"):
        self.llm = llm
        self.data = data
        self.data_activation = None
        self.group = {"positive": 0, "negative": 1}
        self.method_fn = self.average_tokens_layers if method == "average" else self.final_tokens_layers if method == "final" else None
        self.name = type(data).__name__

        # Set tokenizer arguments based on dataset name
        self.tokenizer_args = {"return_tensors": "pt"}
        if self.name == "LangLocDataset":
            self.tokenizer_args.update({"truncation": True, "max_length": 12})

        self.extract_all_units()

    def reset_data_activation(self):
        embd_size = self.llm.get_embd_size()
        n_layers = self.llm.get_nb_layers()
        # Move the tensor to GPU
        self.data_activation = torch.zeros(2, len(self.data), embd_size, n_layers, device=device)

    def clear_hooks(self):
        for layer in self.llm.model.model.layers:
            layer._forward_hooks.clear()

    def reset(self):
        self.clear_hooks()
        self.reset_data_activation()

    def get_hook_layers(self, idx, activation):
        def hook_layers(module, input, output):
            activation[:, :, :, idx] = output[0].squeeze(0).to(device)
        return hook_layers

    def average_tokens_layers(self, activation):
        return activation.mean(dim=1)

    def final_tokens_layers(self, activation):
        return activation[:, -1, :, :]

    def extract_layer_units(self, idx, group_name="positive", method="average"):
        self.clear_hooks()
        self.llm.model.eval()

        prompt = self.data[idx][self.group[group_name]]
        inputs = self.llm.tokenizer(prompt, **self.tokenizer_args).to(device)
        
        n_tokens = inputs["input_ids"].shape[1]
        embd_size = self.llm.get_embd_size()
        n_layers = self.llm.get_nb_layers()
        
        activation = torch.zeros(1, n_tokens, embd_size, n_layers, device=device)

        for i, layer in enumerate(self.llm.model.model.layers):
            layer.register_forward_hook(self.get_hook_layers(i, activation))

        with torch.no_grad():
            self.llm.model(**inputs)

        return self.method_fn(activation)

    def extract_all_units(self):
        self.reset()
        for idx in range(len(self.data)):
            self.data_activation[0, idx, :, :] = self.extract_layer_units(idx, "positive")
            self.data_activation[1, idx, :, :] = self.extract_layer_units(idx, "negative")

In [8]:
class ToMLocDataset(Dataset):
    """ 
    ToM Dataset for Neural Units Localization 
       """
    def __init__(self):
        loc_dir = "dataset/prompt/tom-loc"
        self.positive = [process_text_file(f"{loc_dir}/{idx}b_story_question.txt") for idx in range(1,11)]
        self.negative = [process_text_file(f"{loc_dir}/{idx}p_story_question.txt") for idx in range(1,11)]
    
    def __getitem__(self, idx):
        return self.positive[idx].strip(), self.negative[idx].strip()
    
    def __len__(self):
        return len(self.positive)

class LangLocDataset(Dataset):
    def __init__(self):
        dirpath = "dataset/prompt/langloc"
        paths = glob(f"{dirpath}/*.csv")
        vocab = set()

        data = pd.read_csv(paths[0])
        for path in paths[1:]:
            run_data = pd.read_csv(path)
            data = pd.concat([data, run_data])

        data["sent"] = data["stim2"].apply(str.lower)

        vocab.update(data["stim2"].apply(str.lower).tolist())
        for stimuli_idx in range(3, 14):
            data["sent"] += " " + data[f"stim{stimuli_idx}"].apply(str.lower)
            vocab.update(data[f"stim{stimuli_idx}"].apply(str.lower).tolist())

        self.vocab = sorted(list(vocab))
        self.w2idx = {w: i for i, w in enumerate(self.vocab)}
        self.idx2w = {i: w for i, w in enumerate(self.vocab)}

        self.positive = data[data["stim14"]=="S"]["sent"].tolist()
        self.negative = data[data["stim14"]=="N"]["sent"].tolist()

    def __getitem__(self, idx):
        return self.positive[idx].strip(), self.negative[idx].strip()
        
    def __len__(self):
        return len(self.positive)

In [11]:
class LocImportantUnits:
    def __init__(self,
                 checkpoint,
                 layers_units: torch.Tensor):
        self.model_name = checkpoint.split("/")[-1]
        self.fb_group = layers_units[0].cpu()
        self.fp_group = layers_units[1].cpu()
        self.t_values = self.welch_test()
        self.ranked_units = self.get_ranked_units()

    def welch_test(self):
        n_units = self.fb_group.shape[1]
        n_layers = self.fb_group.shape[2]

        # Reshape for Welch t-test
        fb_flattened = np.abs(self.fb_group.reshape(self.fb_group.shape[0], -1))
        fp_flattened = np.abs(self.fp_group.reshape(self.fp_group.shape[0], -1))
        # Perform the t-test along the first axis (sample dimension)
        t_stat, _ = ttest_ind(fb_flattened, fp_flattened, axis=0, equal_var=False)
        print(t_stat.shape)

        # Reshape t_stat back to (units, n_layers)
        return t_stat.reshape(n_units, n_layers)
    
    def get_ranked_units(self):
        # Get ranked matrix
        flat = self.t_values.flatten()
        sorted_indices = np.argsort(flat)[::-1]  # Sort indices in descending order
        ranked = np.empty_like(sorted_indices)
        ranked[sorted_indices] = np.arange(1, len(flat) + 1)
        # Reshape the ranked values back to the original matrix shape
        return ranked.reshape(self.t_values.shape)
    
    def get_masked_ktop(self, percentage):
        num_top_elements = int(self.t_values.size * percentage)
        # Flatten the matrix, find the threshold value for the top 1%
        flattened_matrix = self.t_values.flatten()
        threshold_value = np.partition(flattened_matrix, -num_top_elements)[-num_top_elements]

        # Create a binary mask where 1 represents the top 1% elements, and 0 otherwise
        mask_units = np.where(self.t_values >= threshold_value, 1, 0)
        return mask_units
    
    def get_random_mask(self, percentage, seed=None):
        # Set the seed for reproducibility
        if seed is not None:
            np.random.seed(seed)
        
        # Calculate the total number of units
        total_units = self.t_values.size
        num_units_to_select = int(total_units * percentage)
        
        # Create a flattened array of zeros
        mask_flat = np.zeros(total_units, dtype=int)
        
        # Randomly select indices and set them to 1
        selected_indices = np.random.choice(total_units, num_units_to_select, replace=False)
        mask_flat[selected_indices] = 1
        
        # Reshape the mask back to the original shape
        return mask_flat.reshape(self.t_values.shape)
    
    def plot_layer_percentages(self, percentage, mask_type='ktop', seed=None, save_path=None):
        """
        Plots the percentage of important units per layer.
        
        Parameters:
        - percentage (float): The top percentage of units to be considered as important.
        - mask_type (str): Type of mask to use ('ktop' for k-top mask or 'random' for random mask).
        - seed (int, optional): Random seed for reproducibility when using the random mask.
        - save_path (str, optional): Path to save the plot. If None, shows the plot.
        """
        
        # Generate the mask based on the specified mask type
        if mask_type == 'ktop':
            mask = self.get_masked_ktop(percentage)
        elif mask_type == 'random':
            mask = self.get_random_mask(percentage, seed=seed)
        else:
            raise ValueError("Invalid mask_type. Choose 'ktop' or 'random'.")
        
        # Calculate the percentage of important units for each layer
        layer_percentages = [(np.sum(layer) / layer.size) * 100 for layer in mask.T]

        # Convert to a column vector for plotting
        layer_percentages_matrix = np.array(layer_percentages).reshape(-1, 1)

        # Plot the layer percentages as a matrix with shape (number of layers, 1)
        plt.figure(figsize=(2, 8))
        cmap = plt.get_cmap('viridis')
        norm = Normalize(vmin=min(layer_percentages), vmax=max(layer_percentages))
        plt.imshow(layer_percentages_matrix, cmap=cmap, aspect='auto')
        plt.colorbar(label="Percentage of Important Units")
        plt.title(f"Percentage of Important Units per Layer ({mask_type.capitalize()} Mask, Top {percentage*100:.1f}%)")
        plt.xlabel("Layer")
        plt.ylabel("Percentage")

        # Add text annotations with adaptive color based on background brightness
        for i, perc in enumerate(layer_percentages):
            color = cmap(norm(perc))
            brightness = 0.3 * color[0] + 0.5 * color[1] + 0.2 * color[2]
            text_color = "white" if brightness < 0.5 else "black"
            plt.text(0, i, f"{perc:.1f}%", ha="center", va="center", color=text_color)

        # Configure ticks
        plt.yticks(range(len(layer_percentages)), [f"Layer {i+1}" for i in range(len(layer_percentages))])
        plt.xticks([])

        # Save the plot or show it based on the save_path parameter
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
            print(f"Plot saved as {save_path}")
        else:
            plt.show()

In [12]:
class AblateUnits:
    def __init__(self, llm: ImportLLMfromHF, mask: Optional[np.ndarray] = None):
        self.llm = llm
        self.mask = mask
        self.layer_outputs = []

    def clear_hooks(self):
        for layer in self.llm.model.model.layers:
            layer._forward_hooks.clear()

    def get_hook_ablate(self, idx):
        def hook_ablate(module, input, output):
            mask_layer = self.mask[idx]
            unit_indices = mask_layer.nonzero()
            output[0][:,:,unit_indices] = 0
            self.layer_outputs.append(output[0].clone().to(device))
        return hook_ablate

    def ablate_units(self, prompt):
        self.clear_hooks()
        self.layer_outputs.clear()
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        for idx, layer in enumerate(self.llm.model.model.layers):
            layer.register_forward_hook(self.get_hook_ablate(idx))

        with torch.no_grad():
            generated_tokens = self.llm.model.generate(**inputs, max_length=50, do_sample=True, top_p=0.95, top_k=50)

        decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        return decoded_text

In [15]:
# Initialize model and localizer classes only once per session
lang_data = LangLocDataset()
tom_data = ToMLocDataset()
llm = ImportLLMfromHF(model, tokenizer)
units = LayersUnits(llm, lang_data, "final")
loc_units = LocImportantUnits(checkpoint, units.data_activation)
mask = loc_units.get_masked_ktop(0.01)
perturbation = AblateUnits(llm, mask.T)
perturbation.ablate_units("Hello! What it the capital of France?")

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


(131072,)


'Hello! What it the capital of France?\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0'

In [37]:
class EvaluateBenchmark:
    def __init__(self,
                llm: ImportLLMfromHF,
                loc_units: LocImportantUnits,
                batch_size: int=20):
        self.llm = llm
        self.loc_units = loc_units
        self.batch_size = batch_size

        # Ensure the tokenizer has a padding token
        if self.llm.tokenizer.pad_token is None:
            self.llm.tokenizer.pad_token = self.llm.tokenizer.eos_token
    
    def clear_hooks(self):
        for layer in self.llm.model.model.layers:
            layer._forward_hooks.clear()

    def get_hook_ablate(self, idx, mask):
        def hook_ablate(module, input, output):
            mask_layer = mask[idx]
            unit_indices = mask_layer.nonzero()
            output[0][:,:,unit_indices] = 0
        return hook_ablate
    
    def get_generated_tokens(self, outputs, input_length):
        # Slice generated tokens to exclude the initial prompt tokens
        generated_texts = []
        for output in outputs:
            new_tokens = output[input_length:]  # Exclude prompt tokens
            generated_texts.append(tokenizer.decode(new_tokens, skip_special_tokens=True))
        return generated_texts

    def generate_text_with_ablations(self, prompts, mask, **generate_kwargs):
        self.clear_hooks()
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left", truncation=True).to(self.llm.model.device)
        input_length = inputs['input_ids'].shape[1]  # Get the length of the prompt tokens
        for idx, layer in enumerate(self.llm.model.model.layers):
            layer.register_forward_hook(self.get_hook_ablate(idx, mask))

        with torch.no_grad():
            outputs = self.llm.model.generate(**inputs, max_new_tokens=12, **generate_kwargs)

        # Slice generated tokens to exclude the initial prompt tokens
        return self.get_generated_tokens(outputs, input_length)
    
    def generate_text_without_ablation(self, prompts, **generate_kwargs):
        self.clear_hooks()
        self.llm.model.eval()
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left", truncation=True).to(self.llm.model.device)
        input_length = inputs['input_ids'].shape[1]  # Get the length of the prompt tokens
        with torch.no_grad():
            outputs = self.llm.model.generate(**inputs, max_new_tokens=12, **generate_kwargs)

        # Slice generated tokens to exclude the initial prompt tokens
        return self.get_generated_tokens(outputs, input_length)
     
    def generate(self, prompts, mask=None):
        # Set temperature to 0 for deterministic generation
        generation_config = {
            "temperature": None,          # Explicitly unset temperature
            "top_p": None,   
            "do_sample": False,  # Ensures deterministic generation
            "pad_token_id": tokenizer.eos_token_id  # Set the padding token if necessary
        }

        if mask is None:
            return self.generate_text_without_ablation(prompts, **generation_config)
        else:
            return self.generate_text_with_ablations(prompts, mask, **generation_config)
    
    
    def experiment(self, df, pct=0.01):
        """  """
        data = df.copy()
        assess_dict = {
            "no_ablation": None,
            f"ablate_top_{pct*100}": self.loc_units.get_masked_ktop(pct).T,
            f"ablate_random1_{pct*100}": self.loc_units.get_random_mask(pct).T,
            f"ablate_random2_{pct*100}": self.loc_units.get_random_mask(pct).T,
            f"ablate_random3_{pct*100}": self.loc_units.get_random_mask(pct).T 
        }

        for key, mask in assess_dict.items():
            # Generate responses in batches and collect results
            generated_texts = []
            for i in tqdm(range(0, len(data), self.batch_size)):
                batch_prompts = data["prompt"].iloc[i:i+self.batch_size].tolist()
                generated_texts.extend(self.generate(batch_prompts, mask))
            
            data[f"generate_{key}"] = generated_texts
        return data

In [38]:
import ast
csv_file="dataset/benchmarks/ToMi/ToMi-finalNeuralTOM.csv"
df = pd.read_csv(csv_file)
# Convert the 'cands' column from string representation of lists to actual lists
df["cands"] = df["cands"].apply(ast.literal_eval)

intro_text = (
    "The following multiple choice question is based on the following story. The question "
    "is related to Theory-of-Mind. Read the story and then answer the questions. Choose the best answer "
    "from the options provided by printing it as is without any modifications."
)

df["prompt"] = df.apply(
    lambda row: f"{intro_text}\n\nStory:\n{row['story']}\n\nQuestion:\n{row['question']}\nChoose between the following options: {row['cands'][0]} or {row['cands'][1]}. The response should be contained in your first sentence.",
    axis=1
)
data = df.iloc[:20]
assess = EvaluateBenchmark(llm, loc_units, 20)
result_df = assess.experiment(data, 0.01)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
100%|██████████| 1/1 [00:01<00:00,  1.29s/it]


In [39]:
# Set pandas to display all columns
pd.set_option('display.max_columns', None)
result_df.head(3)

Unnamed: 0,story,question,answer,i,dataGenSeq,sType,qTypeRaw,qTypeTomOrNot,qOrder,storyHasToM,answerMem,answerReal,qToMandOmniReader,answerMemOrReal,cands,correct,falseTrueBelief,factVsMind,prompt,generate_no_ablation,generate_ablate_top_1.0,generate_ablate_random1_1.0,generate_ablate_random2_1.0,generate_ablate_random3_1.0
0,James entered the living room. Hunter entered ...,Where is the pajamas really?,drawer,1,"enter_agent_0,enter_agent_1,agent_0_moves_obj,...",true_belief,reality,False,reality,False,bucket,drawer,False,reality,"[bucket, drawer]",1,True,fact,The following multiple choice question is base...,The pajamas are in the drawer. \n\nThe correc...,The ' ' ',The pajamas are in the drawer. \n\nThe correc...,The correct answer is bucket. The story impli...,The pajamas are in the drawer. The correct a...
1,Alexander entered the playroom. Jack entered t...,Where does Alexander think that Jack searches ...,suitcase,1,"enter_agent_1,enter_agent_0,agent_0_moves_obj,...",true_belief,second_order_0_no_tom,False,second_order,False,box,suitcase,False,reality,"[box, suitcase]",1,True,mind,The following multiple choice question is base...,Alexander thinks that Jack searches for the p...,' ' ' ',\n\nA) box\nB) suitcase\nC) Both,Alexander thinks that Jack searches for the p...,Alexander thinks that Jack searches for the p...
2,Jack entered the laundry. Logan entered the st...,Where will Jack look for the persimmon?,bucket,1,"enter_agent_0,agent_2_enters,enter_agent_1,age...",second_order_false_belief,first_order_0_no_tom,False,first_order,False,pantry,bucket,False,reality,"[pantry, bucket]",1,True,mind,The following multiple choice question is base...,Jack will look for the persimmon in the bucket.,' ' ' ',Jack will look for the persimmon in the bucket.,Jack will look for the persimmon in the bucket.,Jack will look for the persimmon in the bucket.
