In [None]:
# Libraries
from dotenv import load_dotenv
import os
# from tom_localizer import ImportLLMfromHF, ToMLocalizerUnits, ToMLocDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from prompt_processing import process_text_file

In [2]:
# Environment setup: load the variables defined in a local .env file, then read
# the Hugging Face access token and the model cache directory (None when unset).
load_dotenv()
hf_access_token, cache_dir = os.getenv("HF_ACCESS_TOKEN"), os.getenv("CACHE_DIR")

In [3]:
# Hugging Face checkpoint analysed throughout this notebook.
checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
# Load the tokenizer first, then the causal-LM weights, from the same checkpoint,
# sharing the cache directory and access token read from the environment.
# NOTE(review): no dtype argument is passed, so the weights load in the library's
# default dtype — confirm this is intended given the model size.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    cache_dir=cache_dir,
    token=hf_access_token,
)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    cache_dir=cache_dir,
    token=hf_access_token,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
class ToMLocDataset:
    """ToM localizer dataset: paired false-belief / false-photo stories.

    Story pair ``i`` (1-based) is read with ``process_text_file`` from
    ``{loc_dir}/{i}b_story_question.txt`` (false-belief) and
    ``{loc_dir}/{i}p_story_question.txt`` (false-photo).

    Args:
        loc_dir: directory containing the story files.
        n_stories: number of story pairs to load (files ``1 .. n_stories``).
    """

    def __init__(self, loc_dir="dataset/prompt/tom-loc", n_stories=10):
        story_ids = range(1, n_stories + 1)
        # presumably process_text_file returns the prompt text (it is .strip()-ed below)
        self.fb_stories = [process_text_file(f"{loc_dir}/{idx}b_story_question.txt") for idx in story_ids]
        self.pb_stories = [process_text_file(f"{loc_dir}/{idx}p_story_question.txt") for idx in story_ids]

    def __getitem__(self, idx):
        """Return ``(false_belief_story, false_photo_story)`` for pair ``idx`` (0-based), whitespace-stripped."""
        return self.fb_stories[idx].strip(), self.pb_stories[idx].strip()

    def __len__(self):
        """Number of story pairs."""
        return len(self.fb_stories)

# Load the localizer stories once, as a quick check that the prompt files can be read.
# NOTE(review): `tom` is not referenced again in the cells shown; the initialization
# cell below builds its own `tom_data` instance.
tom = ToMLocDataset()

In [5]:
class ImportLLMfromHF:
    """Bundle a Hugging Face model with its tokenizer.

    Also exposes the two architecture sizes the localizer code needs, both read
    from ``model.config``.
    """

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer

    def get_embd_size(self):
        """Width of each layer's hidden state (``config.hidden_size``)."""
        config = self.model.config
        return config.hidden_size

    def get_nb_layers(self):
        """Number of hidden layers (``config.num_hidden_layers``)."""
        config = self.model.config
        return config.num_hidden_layers

class ToMLayersUnits:
    """Per-layer, token-averaged hidden states of ``llm`` on the ToM localizer stories.

    Construction immediately runs every story of ``tom_data`` through ``llm.model``
    and stores the results in ``self.data_activation``. That tensor has shape
    ``(2, len(tom_data), embd_size, n_layers)``: index 0 of the first axis holds
    the false-belief stories, index 1 the false-photo stories.

    Args:
        llm: wrapper exposing ``model`` (whose ``model.model.layers`` are hooked),
            ``tokenizer``, ``get_embd_size()`` and ``get_nb_layers()``.
        tom_data: indexable dataset where ``tom_data[i]`` is a
            ``(false_belief_prompt, false_photo_prompt)`` pair.
    """

    def __init__(self,
                 llm: "ImportLLMfromHF",
                 tom_data: "ToMLocDataset",
                 ):
        self.llm = llm
        self.tom_data = tom_data
        self.data_activation = None
        # Position of each group inside a dataset item and along data_activation's first axis.
        self.group = {"false-belief": 0, "false-photo": 1}
        self.extract_all_units()

    def reset_data_activation(self):
        """Initialize or reset the (all-zero) data activation tensor."""
        embd_size = self.llm.get_embd_size()
        n_layers = self.llm.get_nb_layers()
        self.data_activation = torch.zeros(2, len(self.tom_data), embd_size, n_layers)

    def clear_hooks(self):
        """Remove all forward hooks from the model's decoder layers.

        NOTE: this clears every forward hook on those layers, including hooks
        registered by other objects.
        """
        for layer in self.llm.model.model.layers:
            layer._forward_hooks.clear()

    def reset(self):
        """Clear hooks and reset activations for safe re-initialization."""
        self.clear_hooks()
        self.reset_data_activation()

    def get_hook_layers(self, idx, activation):
        """Return a forward hook that copies a layer's hidden states into ``activation[..., idx]``.

        ``activation`` is the ``(1, n_tokens, embd_size, n_layers)`` buffer allocated
        in ``extract_layer_units``.
        """
        def hook_layers(module, input, output):
            # Accept both a tuple output (hidden states first) and a bare tensor output.
            hidden = output[0] if isinstance(output, (tuple, list)) else output
            activation[:, :, :, idx] = hidden.detach().reshape(activation.shape[:3])
        return hook_layers

    def average_tokens_layers(self, activation):
        """Average over the token axis and drop the batch axis.

        ``(1, n_tokens, embd_size, n_layers)`` -> ``(embd_size, n_layers)``.
        """
        return activation.mean(dim=1).squeeze(0)

    def extract_layer_units(self, idx, group_name="false-belief"):
        """Run one story through the model and return its token-averaged layer outputs.

        Args:
            idx: index into ``tom_data``.
            group_name: ``"false-belief"`` or ``"false-photo"``.
        Returns:
            Tensor of shape ``(embd_size, n_layers)``.
        """
        # Start from a clean slate in case hooks were left behind by someone else.
        self.clear_hooks()
        prompt = self.tom_data[idx][self.group[group_name]]
        inputs = self.llm.tokenizer(prompt, return_tensors="pt")
        # Follow the model's device when it advertises one.
        device = getattr(self.llm.model, "device", None)
        if device is not None and hasattr(inputs, "to"):
            inputs = inputs.to(device)
        n_tokens = inputs["input_ids"].shape[1]
        embd_size = self.llm.get_embd_size()
        n_layers = self.llm.get_nb_layers()
        # Buffer filled by the hooks: (batch_size=1, n_tokens, embd_size, n_layers).
        activation = torch.zeros(1, n_tokens, embd_size, n_layers)

        handles = [
            layer.register_forward_hook(self.get_hook_layers(i, activation))
            for i, layer in enumerate(self.llm.model.model.layers)
        ]
        try:
            with torch.no_grad():
                self.llm.model(**inputs)
        finally:
            # Always detach our hooks; otherwise later forward passes would keep writing
            # into this stale buffer (and fail when the token count differs).
            for handle in handles:
                handle.remove()
        return self.average_tokens_layers(activation)

    def extract_all_units(self):
        """Extract activations for all items in the dataset into ``self.data_activation``."""
        self.reset()  # Ensure clean state before extraction
        for idx in range(len(self.tom_data)):
            self.data_activation[0, idx, :, :] = self.extract_layer_units(idx, "false-belief")
            self.data_activation[1, idx, :, :] = self.extract_layer_units(idx, "false-photo")


In [6]:
from scipy.stats import ttest_ind
import numpy as np
import matplotlib.pyplot as plt

class LocImportantUnits:
    """Rank hidden units by how strongly they separate false-belief from false-photo stories.

    Args:
        checkpoint: model identifier; its last ``/``-separated component becomes ``model_name``.
        layers_units: activations of shape ``(2, n_samples, n_units, n_layers)``, where
            index 0 of the first axis is the false-belief group and index 1 the
            false-photo group (e.g. ``ToMLayersUnits.data_activation``). Must be
            convertible with ``np.asarray``.

    Attributes:
        t_values: Welch t-statistic per unit, shape ``(n_units, n_layers)``; positive
            when the false-belief mean is larger. Can be NaN, e.g. for a unit that is
            constant in both groups.
        ranked_units: 1-based rank of each unit by t-value (1 = largest); NaN t-values rank last.
    """

    def __init__(self,
                 checkpoint,
                 layers_units: torch.Tensor):
        self.model_name = checkpoint.split("/")[-1]
        self.fb_group = layers_units[0]
        self.fp_group = layers_units[1]
        self.t_values = self.welch_test()
        self.ranked_units = self.get_ranked_units()

    def welch_test(self):
        """Welch's (unequal-variance) t-test, false-belief vs false-photo, for every unit.

        Returns:
            ndarray of shape ``(n_units, n_layers)``.
        """
        fb = np.asarray(self.fb_group)
        fp = np.asarray(self.fp_group)
        n_units, n_layers = fb.shape[1], fb.shape[2]
        # Flatten (units, layers) so the test runs column-wise over the sample axis.
        t_stat, _ = ttest_ind(fb.reshape(fb.shape[0], -1), fp.reshape(fp.shape[0], -1),
                              axis=0, equal_var=False)
        return np.asarray(t_stat).reshape(n_units, n_layers)

    def _nan_safe_scores(self):
        """Return the t-values with NaN replaced by -inf, so undefined units never count as 'top'."""
        return np.where(np.isnan(self.t_values), -np.inf, self.t_values)

    def get_ranked_units(self):
        """Return the 1-based rank of every unit by descending t-value, shaped like ``t_values``."""
        scores = self._nan_safe_scores().ravel()
        order = np.argsort(-scores, kind="stable")  # descending; NaN (-> -inf) goes last
        ranked = np.empty_like(order)
        ranked[order] = np.arange(1, scores.size + 1)
        return ranked.reshape(self.t_values.shape)

    def get_masked_ktop(self, percentage):
        """Return a 0/1 mask marking the top ``percentage`` fraction of units by t-value.

        Ties at the threshold are all kept, so slightly more units than requested may be
        selected. A fraction that rounds down to zero units yields an all-zero mask.
        """
        total = self.t_values.size
        num_top_elements = min(int(total * percentage), total)
        if num_top_elements <= 0:
            # np.partition(x, -0)[-0] would pick the minimum and select every unit.
            return np.zeros(self.t_values.shape, dtype=int)
        scores = self._nan_safe_scores()
        threshold_value = np.partition(scores.ravel(), -num_top_elements)[-num_top_elements]
        return np.where(scores >= threshold_value, 1, 0)

    def get_random_mask(self, percentage, seed=None):
        """Return a 0/1 mask selecting a uniformly random ``percentage`` fraction of units.

        With ``seed`` set, the draw is reproducible and identical to seeding NumPy's
        global RNG with the same value. NumPy's global RNG state is left untouched.
        """
        total_units = self.t_values.size
        num_units_to_select = int(total_units * percentage)
        rng = np.random if seed is None else np.random.RandomState(seed)
        selected_indices = rng.choice(total_units, num_units_to_select, replace=False)
        mask_flat = np.zeros(total_units, dtype=int)
        mask_flat[selected_indices] = 1
        return mask_flat.reshape(self.t_values.shape)

    def plot_layer_percentages(self, percentage, mask_type='ktop', seed=None, save_path=None):
        """
        Plots the percentage of important units per layer.

        Parameters:
        - percentage (float): The top percentage of units to be considered as important.
        - mask_type (str): Type of mask to use ('ktop' for k-top mask or 'random' for random mask).
        - seed (int, optional): Random seed for reproducibility when using the random mask.
        - save_path (str, optional): Path to save the plot. If None, shows the plot.
        """
        if mask_type == 'ktop':
            mask = self.get_masked_ktop(percentage)
        elif mask_type == 'random':
            mask = self.get_random_mask(percentage, seed=seed)
        else:
            raise ValueError("Invalid mask_type. Choose 'ktop' or 'random'.")

        # mask is (n_units, n_layers): averaging over units gives the selected fraction per layer.
        layer_percentages = mask.mean(axis=0) * 100
        n_layers = len(layer_percentages)

        fig, ax = plt.subplots(figsize=(2, 8))
        image = ax.imshow(layer_percentages.reshape(-1, 1), cmap='viridis', aspect='auto')
        fig.colorbar(image, ax=ax, label="Percentage of Important Units")
        ax.set_title(f"Percentage of Important Units per Layer ({mask_type.capitalize()} Mask, Top {percentage*100:.1f}%)")
        # Layers run along the y-axis; the single x column carries no information.
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(range(n_layers))
        ax.set_yticklabels([f"Layer {i+1}" for i in range(n_layers)])

        # Light text on the darker (lower) half of the colour range, dark text on the rest.
        midpoint = (layer_percentages.min() + layer_percentages.max()) / 2
        for i, perc in enumerate(layer_percentages):
            ax.text(0, i, f"{perc:.2f}%", ha="center", va="center",
                    color="white" if perc < midpoint else "black")

        if save_path:
            fig.savefig(save_path, bbox_inches='tight')
            print(f"Plot saved as {save_path}")
            plt.close(fig)  # don't keep saved figures alive in the kernel
        else:
            plt.show()


In [7]:
from typing import Optional
class AblateUnits:
    """Zero out selected hidden units in every decoder layer while generating text.

    Args:
        llm: wrapper exposing ``model`` (whose ``model.model.layers`` are hooked)
            and ``tokenizer``.
        mask: optional array indexed as ``mask[layer_idx]``, giving per-unit flags
            for that layer, i.e. shape ``(n_layers, embd_size)``. Units with a
            non-zero flag are zeroed. ``None`` disables ablation; layer outputs are
            then only recorded.
    """

    def __init__(self,
                 llm: "ImportLLMfromHF",
                 mask: Optional[np.ndarray] = None):
        self.llm = llm
        self.mask = mask
        # Clones of each hooked layer's (post-ablation) hidden states, in call order,
        # collected during the most recent ablate_units() call.
        self.layer_outputs = []

    def clear_hooks(self):
        """Remove every forward hook registered on the model's decoder layers."""
        for layer in self.llm.model.model.layers:
            layer._forward_hooks.clear()

    def get_hook_ablate(self, idx):
        """Return a forward hook for layer ``idx`` that zeroes its masked units in place."""
        if self.mask is None:
            unit_indices = None
        else:
            unit_indices = torch.as_tensor(np.flatnonzero(np.asarray(self.mask[idx])), dtype=torch.long)

        def hook_ablate(module, input, output):
            # Accept both a tuple output (hidden states first) and a bare tensor output.
            hidden = output[0] if isinstance(output, (tuple, list)) else output
            if unit_indices is not None and unit_indices.numel() > 0:
                # In-place edit: the next layer receives the ablated hidden states.
                hidden[..., unit_indices.to(hidden.device)] = 0
            self.layer_outputs.append(hidden.clone())
        return hook_ablate

    def ablate_units(self, prompt, **generate_kwargs):
        """Generate a continuation of ``prompt`` with the masked units ablated.

        Args:
            prompt: text to continue.
            **generate_kwargs: forwarded to ``model.generate``; they override the
                defaults ``max_length=50, do_sample=True, top_p=0.95, top_k=50``.
        Returns:
            The first generated sequence decoded with ``skip_special_tokens=True``.
        """
        self.clear_hooks()
        self.layer_outputs = []  # keep only this call's outputs (avoids unbounded growth)
        tokenizer = self.llm.tokenizer
        inputs = tokenizer(prompt, return_tensors="pt")
        device = getattr(self.llm.model, "device", None)
        if device is not None and hasattr(inputs, "to"):
            inputs = inputs.to(device)

        gen_args = {"max_length": 50, "do_sample": True, "top_p": 0.95, "top_k": 50}
        gen_args.update(generate_kwargs)

        handles = [layer.register_forward_hook(self.get_hook_ablate(idx))
                   for idx, layer in enumerate(self.llm.model.model.layers)]
        try:
            with torch.no_grad():
                generated_tokens = self.llm.model.generate(**inputs, **gen_args)
        finally:
            # Detach the ablation hooks so later uses of the model are unaffected.
            for handle in handles:
                handle.remove()

        return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [13]:
# Initialize model and localizer classes only once per session: this wraps the model,
# runs one forward pass per story (both groups) to collect per-layer activations,
# ranks units with a Welch t-test, and draws a random baseline mask over 1% of units.
RANDOM_MASK_SEED = 0  # fixed so the random baseline mask is reproducible on Restart & Run All
tom_data = ToMLocDataset()
llm = ImportLLMfromHF(model, tokenizer)
tom_units = ToMLayersUnits(llm, tom_data)
loc_units = LocImportantUnits(checkpoint, tom_units.data_activation)
mask_random = loc_units.get_random_mask(0.01, seed=RANDOM_MASK_SEED)

In [14]:
# AblateUnits looks up its mask one layer at a time (mask[layer_idx]), so the
# (units, layers) random mask is transposed to (layers, units) before being passed in.
perturbation = AblateUnits(llm=llm, mask=mask_random.T)
perturbation.ablate_units(prompt="Hello! How are you doing")

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


"Hello! How are you doing today? I hope you're doing great!\n\nI've been thinking about how to make the blog more interactive and engaging for readers, and I thought I'd reach out to you for some ideas.\n\nHere are a few"