# OOD detection applied to Hallucination Detection

The goal is to predict, using OOD detection methods, whether an INPUT prompt is going to produce a hallucination. For now, we do not look at the output generated by the model; we may consider it at a later stage.

**Retrieve ID samples.** Take a general (easy) QA dataset containing questions along with their true, hallucination-free answers. Feed the questions to the model, let it generate responses, and check whether each generated response matches the reference answer. All samples with a correct generated response are considered ID. More precisely, the ID dataset consists of the embeddings of the last input token at the last layer (or maybe a middle layer), taken for the questions whose generated responses were correct.

**Test whether a new sample is OOD (= hallucination).** Take another dataset containing questions likely to trigger hallucinations, along with the true hallucination-free answers (or no answer if the model cannot possibly know it, in which case any response it produces is necessarily hallucinated). Feed a question to the model and let it generate a response. Check, by comparing with the hallucination-free answer, whether the generated response is hallucinated. At the same time, apply an OOD detection method to the input question (last token, last layer) and see whether a high OOD score corresponds to a generated hallucination.

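The fit-then-score procedure described above can be sketched as follows. This is only an illustration under assumptions: the text does not commit to a specific detector at this point, so a Mahalanobis distance to the ID mean is swapped in as one common choice, and random vectors stand in for the last-token embeddings. The names `fit_mahalanobis` and `mahalanobis_score` are hypothetical.

```python
import numpy as np

def fit_mahalanobis(id_embeddings: np.ndarray, eps: float = 1e-3):
    """Estimate the mean and a regularized inverse covariance of the ID embeddings."""
    mean = id_embeddings.mean(axis=0)
    centered = id_embeddings - mean
    cov = centered.T @ centered / max(1, len(id_embeddings) - 1)
    cov += eps * np.eye(cov.shape[0])  # regularize so the inverse exists
    return mean, np.linalg.inv(cov)

def mahalanobis_score(x: np.ndarray, mean: np.ndarray, inv_cov: np.ndarray) -> np.ndarray:
    """Squared Mahalanobis distance of each row to the ID mean (higher = more OOD)."""
    diff = x - mean
    return np.einsum("nd,de,ne->n", diff, inv_cov, diff)

# Synthetic stand-ins for embeddings (assumption: real ones come from the model).
rng = np.random.default_rng(0)
id_train = rng.normal(0.0, 1.0, size=(500, 8))   # "correctly answered" questions
id_test = rng.normal(0.0, 1.0, size=(100, 8))    # held-out ID-like questions
ood_test = rng.normal(4.0, 1.0, size=(100, 8))   # shifted, OOD-like questions

mean, inv_cov = fit_mahalanobis(id_train)
id_scores = mahalanobis_score(id_test, mean, inv_cov)
ood_scores = mahalanobis_score(ood_test, mean, inv_cov)
```

In the real setting the scores would then be compared with the hallucination labels obtained by checking the generated answers, for example through an AUROC.
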
In [1]:
#/home/lila.roig/.env/ood_env/bin/python 

In [2]:
%load_ext autoreload
%autoreload 2

## 1. Embedding Extraction

In [3]:
# import libraries
# -----------------------------------
import torch
import sys
import time 
import os 
import pickle
from functools import partial
# Add the path to the src directory
sys.path.append(os.path.abspath(".."))

In [4]:
# Define global variables
# -----------------------------------
SEED = 777 #44
BATCH_SIZE = 16 #32
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
OUTPUT_DIR = "../results/raw/TEST/" 
PLOT_DIR   = "../results/figures/" 
LAYER = -1      # (integer) - Layer from which to retrieve the embeddings 
TOKENS = "-1"  # (string) - Which token(s) to use: "0", "-1", "Avg" or "Max" 
K_BEAMS = 1 #3
ACTIVATION_SOURCE = "generation" # can be 'generation', 'PromptGeneration'
 
if TOKENS=="0":
    EXTRACTION_MODE = "first_generated"
elif TOKENS=="-1":
    EXTRACTION_MODE = "last"
elif TOKENS=="Avg":
    EXTRACTION_MODE = "average"
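elif TOKENS=="Max":
    EXTRACTION_MODE = "max"
else:
    # Fail early instead of leaving EXTRACTION_MODE undefined
    raise ValueError(f"Unsupported TOKENS value: {TOKENS!r}")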

In [5]:
# Clear memory to avoid "CUDA out of memory"
# -----------------------------------
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [6]:
# Visualize setup 
# -----------------------------------
print(f"Python version: {sys.version}")
print(f"Cuda version: {torch.version.cuda}")
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")
for i in range(num_gpus):
    print(f"GPU {i + 1} : {torch.cuda.get_device_name(i)}")

Python version: 3.11.13 (main, Jun  4 2025, 08:57:30) [GCC 13.3.0]
Cuda version: 12.6
Number of available GPUs: 2
GPU 1 : NVIDIA L40S
GPU 2 : NVIDIA L40S


In [7]:
# Seed everything
# -----------------------------------
from src.utils.general import seed_all
seed_all(SEED)

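The body of `seed_all` is not shown in this notebook. As a rough idea of what such a helper usually does, here is a minimal sketch; `seed_all_sketch` is a hypothetical name and the actual `src.utils.general.seed_all` may differ (for instance by also configuring cuDNN determinism).

```python
import random

import numpy as np
import torch

def seed_all_sketch(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # seed every visible GPU

# Re-seeding with the same value reproduces the same random draws.
seed_all_sketch(777)
a = torch.rand(3)
seed_all_sketch(777)
b = torch.rand(3)
```
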
## Load model

In [8]:
# Load model
# ----------------------------------
from src.model_loader.llama_loader import load_llama

model, tokenizer = load_llama(MODEL_NAME)


In [10]:
from src.data_reader.pickle_io import load_pickle_batches
res = load_pickle_batches("/home/lila.roig/projet_ood/ood_for_hallucination_detection/results/raw/small_dataset_correct_split_allScores/id_fit_results_layer1:32:2_score_all_hidden_attn_logit_prompt_so0_eo0.pkl")



Loaded 8008 samples from: /home/lila.roig/projet_ood/ood_for_hallucination_detection/results/raw/small_dataset_correct_split_allScores/id_fit_results_layer1:32:2_score_all_hidden_attn_logit_prompt_so0_eo0.pkl


In [9]:
print(model.config._attn_implementation )
#model.config._attn_implementation = 'eager'
print(model.config._attn_implementation )

sdpa
sdpa


### Load ID dataset

For the general ID dataset, we use the SQuAD 1.1 dataset:

***SQuAD 1.1:** Comprises over 100,000 question-answer pairs derived from more than 500 Wikipedia articles. Each question is paired with a specific segment of text (a span) from the corresponding article that serves as the answer.*

In [36]:
# Load ID dataset
# -----------------------------------
from src.data_reader.squad_loader import load_id_fit_dataset 
# Total number of samples in squad v1.1: 87599, squad v2.0: 86821

id_fit_dataset = load_id_fit_dataset()
#id_fit_dataset = id_fit_dataset.shuffle(SEED) 
id_fit_dataset = id_fit_dataset.slice(idx_start=0, idx_end=1_000) # 10_000
id_fit_dataset.print_info()


===== Dataset Information =====
Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'original_index', 'is_impossible'],
    num_rows: 1000
})
Mean ground-truth answer length: 2.25, Max length: 14
Mean context + question length: 146.50, Max length: 342


# Development of new solutions

### Modify extract_token_activations to add features
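Before the full function, a toy batch may help to see how the mask-based selection behaves with both right and left padding. This is only an illustrative sketch on hand-made tensors (not the notebook's data); it mirrors the "last", "average" and "max" modes.

```python
import torch

# Toy batch: 2 sequences, 4 token positions, hidden size 3.
hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
# Sequence 0 is right-padded (2 valid tokens); sequence 1 is left-padded (3 valid tokens).
mask = torch.tensor([[1, 1, 0, 0],
                     [0, 1, 1, 1]])

# Index of the last valid token: flip the mask, find the first 1, map the index back.
last_idx = mask.shape[1] - 1 - mask.flip(dims=[1]).float().argmax(dim=1)
last = hidden[torch.arange(hidden.shape[0]), last_idx]

# Masked mean: zero out padded positions, divide by the number of valid tokens.
mask_f = mask.float().unsqueeze(-1)  # (batch, seq_len, 1)
mean = (hidden * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1e-6)

# Masked max: set padded positions to -inf so they can never be selected.
max_vals = hidden.masked_fill(mask_f == 0, float("-inf")).max(dim=1).values
```

The flip-and-argmax trick finds the last valid position regardless of the padding side, which is why the same code works for both rows.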

In [37]:
from transformers import PreTrainedTokenizer, PreTrainedModel
import torch
from torch.utils.hooks import RemovableHandle
from typing import Tuple, Literal, List, Optional, Dict

### OK ###
def extract_token_activations(
    selected_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    device: torch.device,
    modes: List[Literal[
        "average", "last", "max", "first_generated", 
        "token_svd_score", "feat_var"
    ]] = ["average"],
    skip_length: Optional[int] = None,
    alpha: float = 0.001,
) -> Dict[str, torch.Tensor]:
    """   
    Aggregate token-level activations over a specified span for each sequence in a batch,
    using various aggregation modes and attention mask.

    This function takes as input:
      - The layer activations (selected_layer) for each token in a batch of sequences,
      - An attention mask (attention_mask) of the same shape, where 1 indicates tokens to include
        in the aggregation and 0 marks tokens to ignore.

    The attention mask may be the original model mask, or a custom mask generated using
    `compute_offset_attention_mask` to dynamically select a sub-span of tokens.

    Parameters
    ----------
    selected_layer : torch.Tensor
        Tensor of shape (batch_size, seq_len, hidden_size) containing model activations for each token.
    attention_mask : torch.Tensor
        Attention mask of shape (batch_size, seq_len),  1 for real tokens, 0 for padding.
    device : torch.device
        Device for computation.

    modes : List[str]
        List of aggregation modes to compute. Computed using only valid tokens where attention_mask == 1.
        Supported:
        - "average": Mean activation vector across valid tokens. Shape: (batch_size, hidden_size)
        - "max": Element-wise max activation across valid tokens. Shape: (batch_size, hidden_size)
        - "last": Activation vector of last valid token in each sequence. Shape: (batch_size, hidden_size)
        - "first_generated": Activation of the first generated valid token in each sequence. Shape: (batch_size, hidden_size)
             If skip_length is provided, selects the token starting from that offset. 
        - "token_svd_score": Mean log singular value of the centered Gram matrix over tokens. Shape: (batch_size,)
            The Gram matrix is computed as Gram_token = Z·J·Z^T, where J = I - (1/hidden_size)·1·1^T is the
            centering matrix over the feature dimension: Z·J subtracts from each token vector its own mean
            across features. Gram_token therefore holds the pairwise inner products (similarities) between
            the centered token representations. Note: This is not a classical covariance matrix.
            The mean log singular value quantifies the effective dimensionality or diversity of the token 
            activations: higher values reflect more diverse (less redundant) token representations, lower values 
            indicate more redundancy or alignment.
            NOTE: Implementation inspired by "LLM-Check: Investigating Detection of Hallucinations in 
            Large Language Models" (Sriramanan et al. 2024)
        - "feat_var": Diagonal of the centered feature covariance matrix (variances). Shape: (batch_size, hidden_size)

    skip_length : Optional[int]
        If provided, used to explicitly select the first generated token (useful for "first_generated" mode).
    alpha : float
        Regularization parameter added to the diagonal of the Gram and covariance matrices.

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary mapping each mode to its result:
            - (batch_size, hidden_size) for "average", "max", "last", "first_generated", "feat_var"
            - (batch_size,) for "token_svd_score"
    """

    batch_size, seq_len, hidden_size = selected_layer.shape
    aggregated_tokens = {}
    
    # Move to device 
    attention_mask = attention_mask.to(selected_layer.device)

    # =======================================
    # Select the first token with optional offset `skip_length`
    # =======================================
    if "first_generated" in modes:
        batch_indices = torch.arange(batch_size, device=device)
        if skip_length is not None:
            first_indices = torch.full((batch_size,), skip_length, device=device, dtype=torch.long)
        else:
            first_indices = (attention_mask == 1).float().argmax(dim=1)
        first = selected_layer[batch_indices, first_indices] # Shape: (batch_size, hidden_size)
        aggregated_tokens["first_generated"] = first

    # =======================================
    # Select the last token 
    # =======================================
    if "last" in modes:
        last_indices = attention_mask.shape[1] - 1 - attention_mask.flip(dims=[1]).float().argmax(dim=1)
        batch_indices = torch.arange(batch_size, device=device)
        last = selected_layer[batch_indices, last_indices]  # Shape: (batch_size, hidden_size)
        aggregated_tokens["last"] = last

    # =======================================
    # Apply mask and compute aggregation 
    # =======================================
    if "average" in modes or "max" in modes:
        # Add one dimension for the broadcast on hidden_size
        mask_float = attention_mask.float().unsqueeze(-1)  # (batch_size, seq_len, 1)
        # Apply the mask to the activations: zero out tokens outside the target interval
        masked = selected_layer * mask_float
        # Count the number of selected tokens for each sequence (avoid division by zero with clamp)
        counts = mask_float.sum(dim=1).clamp(min=1e-6)
        if "average" in modes:
            # Compute the mean activation vector for each sequence over the selected interval
            avg = masked.sum(dim=1) / counts # Shape: (batch_size, hidden_size)
            aggregated_tokens["average"] = avg
        if "max" in modes:
            # Replace padding with -inf to exclude from max calculation
            masked_max = masked.masked_fill(mask_float.logical_not(), float('-inf'))
            # Extract maximum values across sequence dimension
            max_vals, _ = masked_max.max(dim=1) # Shape: (batch_size, hidden_size)
            aggregated_tokens["max"] = max_vals

    # =======================================
    # Covariance-based metrics
    # =======================================
    if any(m in modes for m in ["token_svd_score", "feat_var"]):
        token_svd_score = [] 
        feat_var = []
        
        for i in range(batch_size):
            # Select valid tokens 
            mask = attention_mask[i].bool()
            Z = selected_layer[i][mask]  # (num_valid_tokens, hidden_size)
            
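            if Z.shape[0] == 0:
                # No valid token: append NaN tensors on the same device so torch.stack works later
                feat_var.append(torch.full((hidden_size,), float('nan'), device=selected_layer.device))
                token_svd_score.append(torch.tensor(float('nan'), device=selected_layer.device))
                continue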
            
            if Z.dtype != torch.float32:
                Z = Z.to(torch.float32)

            if "token_svd_score" in modes:
                # Compute Gram matrix on tokens : Gram_token = Z·J·Z^T
                # ---------------------------------------
                # Assumes Z is in full precision
                # Z·J centers each token vector: it subtracts the token's own mean across features
                J = torch.eye(hidden_size, device=Z.device, dtype=Z.dtype) - (1 / hidden_size) * torch.ones(hidden_size, hidden_size, device=Z.device, dtype=Z.dtype)
                # The Gram matrix Gram_token reflects the inner products (similarities) between tokens
                Gram_token = torch.matmul(torch.matmul(Z, J), Z.t()) # (num_valid_tokens, num_valid_tokens)
                # Regularization for stabilization
                Gram_token = Gram_token + alpha * torch.eye(Gram_token.shape[0], device=Z.device, dtype=Z.dtype)
            
                # Singular value decomposition (SVD) of the token Gram matrix
                # ---------------------------------------
                if Gram_token.dtype != torch.float32:
                    Gram_token = Gram_token.to(torch.float32)
                token_svdvals = torch.linalg.svdvals(Gram_token) # Singular Value Decomposition
                token_eigscore = torch.log(token_svdvals).mean()  # mult by 2 missing from the paper? 
                token_svd_score.append(token_eigscore)

            if "feat_var" in modes:
                # Compute covariance matrix on features 
                # ---------------------------------------
                Z_feat_centered = Z - Z.mean(dim=0, keepdim=True) # (num_valid_tokens, hidden_size)
                Cov_feat = (Z_feat_centered.t() @ Z_feat_centered) / max(1, Z.shape[0] - 1) # (hidden_size, hidden_size)
                Cov_feat += alpha * torch.eye(Z.shape[1], device=Z.device, dtype=Z.dtype)
                feat_var.append(Cov_feat.diag())
            
        # Return scores
        # ---------------------------------------
        if "feat_var" in modes:
            aggregated_tokens["feat_var"] = torch.stack(feat_var, dim=0) # (batch_size, hidden_size) 
        if "token_svd_score" in modes:
            aggregated_tokens["token_svd_score"] = torch.stack(token_svd_score) # (batch_size,) 
        
    # Put everything on CPU (applies to all modes, not only the covariance-based ones)
    # ---------------------------------------
    for key in aggregated_tokens:
        aggregated_tokens[key] = aggregated_tokens[key].detach().cpu()

    return aggregated_tokens

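To make explicit what the centering matrix does in the "token_svd_score" mode, the following toy check may help. It only restates the computation from the function above on random data; `token_svd_score_single` is a hypothetical helper name. Right-multiplying Z by J = I - (1/d)·1·1^T subtracts from each token vector its own mean over the hidden dimension.

```python
import torch

def token_svd_score_single(Z: torch.Tensor, alpha: float = 1e-3) -> torch.Tensor:
    """Mean log singular value of the regularized Gram matrix Z·J·Z^T for one sequence."""
    n_tokens, hidden_size = Z.shape
    J = torch.eye(hidden_size) - torch.ones(hidden_size, hidden_size) / hidden_size
    gram = Z @ J @ Z.T + alpha * torch.eye(n_tokens)
    return torch.log(torch.linalg.svdvals(gram)).mean()

torch.manual_seed(0)
Z = torch.randn(5, 16)  # 5 tokens, hidden size 16
hidden_size = Z.shape[1]
J = torch.eye(hidden_size) - torch.ones(hidden_size, hidden_size) / hidden_size

# Z @ J equals Z minus each row's (token's) mean over the features.
row_centered = Z - Z.mean(dim=1, keepdim=True)

diverse_score = token_svd_score_single(Z)
# Five copies of the same token: the Gram matrix has rank one, so the score drops.
redundant_score = token_svd_score_single(Z[:1].repeat(5, 1))
```

With identical tokens, all but one singular value collapse to the regularization term `alpha`, which drags the mean log singular value down; this is the "redundancy" reading of the score.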

In [38]:
import torch
import numpy as np
from typing import List, Literal

### OK ### 
def compute_attn_eig_prod(
    prompt_attentions: torch.Tensor,
    generation_attentions: List[torch.Tensor],
    prompt_attention_mask: torch.Tensor,
    generation_attention_mask: torch.Tensor,
    mode: Literal["prompt", "generation", "promptGeneration"] = "prompt"
) -> np.ndarray:
    """
    Compute a mean log-diagonal attention score (eigenvalue-inspired) for a single layer's 
    attention map, using attention mask. 
    
    NOTE: Implementation inspired by 
    "LLM-Check: Investigating Detection of Hallucinations in Large Language Models"
    (Sriramanan et al. 2024)

    Parameters
    ----------
    prompt_attentions: torch.Tensor
        Tensor of shape (batch_size, n_heads, prompt_len, prompt_len)
        Self-attention over the prompt tokens. 
    generation_attentions: list of torch.Tensor
        List of tensors of shape (batch_size, n_heads, 1, prompt_len + t)
        Self-attention for each generated token at generation step t (t >= 1).
    prompt_attention_mask: torch.Tensor
        Tensor of shape (batch_size, prompt_len), 1 where token valid, 0 for padding.
    generation_attention_mask: torch.Tensor  
        Tensor of shape (batch_size, gen_len), 1 where token valid, 0 for padding.
    mode : str, optional
        Specifies which part of the attention map to use for the score computation.
        Must be one of the following:
        - "prompt":
            Only uses the prompt self-attention map (prompt_attentions). 
            It is a matrix of shape (batch_size, n_heads, prompt_len, prompt_len).
            The diagonal (i.e., self-attention values per token) is extracted,
            then the log is taken, followed by a mean over prompt tokens and sum over heads.
        - "generation":
            Only uses the generated self-attention maps (generation_attentions).
            Each tensor in generation_attentions has shape (batch_size, n_heads, 1, prompt_len + t),
            where t is the generation step. 
            Instead of concatenating these tensors to obtain the generation attention matrix, 
            for each step, we directly take the last value along the last axis (i.e., the self-attention
            of the newly generated token). These values are stacked across time steps, then we take the log,
            compute the mean over time, and sum over heads.
        - "promptGeneration":
            Combines the diagonals from both the prompt and generation attention maps as described above
            for "prompt" and "generation" mode. The two diagonals are concatenated along the token/time axis, 
            then the log is taken, followed by a mean across all tokens and a sum over heads.
            Note: we do **not** concatenate the full prompt and generation attention matrices,
            since the diagonal of the combined matrix would only include values from the prompt attention
            due to mismatched matrix shapes.

    Returns
    -------
    np.ndarray
        A NumPy array of shape (batch_size,), where each value is the per-sample attention score.
        The score is summed across heads and averaged across tokens (in log-space).
    """
    if mode not in ("prompt", "generation", "promptGeneration"):
        raise ValueError(f"Invalid mode: {mode}. Must be 'prompt', 'generation' or 'promptGeneration'.")

    batch_size, n_heads = prompt_attentions.shape[:2]
    # Number of generation steps (0 when no generation attentions are provided)
    gen_len = len(generation_attentions) if generation_attentions is not None else 0
    diag_blocks = []

    # Move to device
    device = prompt_attentions.device
    prompt_attention_mask = prompt_attention_mask.to(device)
    if generation_attention_mask is not None:
        generation_attention_mask = generation_attention_mask.to(device)

    # ==============================
    # Prompt mode or combined
    # ==============================
    if mode in ("prompt", "promptGeneration"):
        # Extract diagonal of prompt attentions
        prompt_diag = torch.diagonal(prompt_attentions, dim1=-2, dim2=-1) # (batch_size, n_heads, prompt_len)
        # Expand prompt mask to (batch_size, n_heads, prompt_len)
        p_mask_ext = prompt_attention_mask.unsqueeze(1).expand(-1, n_heads, -1)
        diag_blocks.append(prompt_diag)

    # ==============================
    # Generation mode or combined
    # ==============================
    if mode in ("generation", "promptGeneration") and gen_len > 0:
        # For each generation step, take the last value along last dim.
        gen_diag_steps = [attn[..., -1].squeeze(-1) for attn in generation_attentions] # list of (batch_size, n_heads)
        # Stack along time axis (= newly generated tokens)
        gen_diag = torch.stack(gen_diag_steps, dim=-1) if gen_diag_steps else None # (batch_size, n_heads, gen_len)
        # Expand generation mask to (batch_size, n_heads, gen_len)
        g_mask_ext = generation_attention_mask.unsqueeze(1).expand(-1, n_heads, -1)
        if gen_diag is not None:
            diag_blocks.append(gen_diag)


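    # Nothing to score, e.g. "generation" mode without any generated token
    if not diag_blocks:
        raise ValueError(f"No attention values available for mode '{mode}'.")

    # Concatenate diagonals along tokens/time dim
    all_diags = torch.cat(diag_blocks, dim=-1) # (batch_size, n_heads, N) where N = prompt_len + n_generated (or a subset)
    # Build full mask concatenated similarly: (batch_size, n_heads, N)
    if mode == "prompt":
        full_mask = p_mask_ext # (batch_size, n_heads, prompt_len)
    elif mode == "generation":
        full_mask = g_mask_ext  # (batch_size, n_heads, gen_len)
    elif gen_len > 0:  # "promptGeneration" with at least one generated token
        full_mask = torch.cat([p_mask_ext, g_mask_ext], dim=-1)  # (batch_size, n_heads, total_len)
    else:  # "promptGeneration" without generated tokens: prompt only
        full_mask = p_mask_ext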

    # ==============================
    # Compute attention eigen product, ignoring padding tokens 
    # ==============================
    # Clamp very small values to avoid log(0)
    all_diags = all_diags.clamp(min=1e-6)
    # Compute log
    log_all_diags = torch.log(all_diags) # (batch_size, n_heads, N)
    # Mask out padding tokens by zeroing out their logs
    masked_log_all_diags = log_all_diags * full_mask # (batch_size, n_heads, N)
    # Count valid tokens per batch and head to compute mean properly (avoid div by zero)
    valid_token_counts = full_mask.sum(dim=-1).clamp(min=1) # (batch_size, n_heads)
    # Mean log diag over valid tokens dimension (N)
    mean_log_diag = masked_log_all_diags.sum(dim=-1) / valid_token_counts  # (batch_size, n_heads)
    # Sum over heads to get final per-sample scores
    scores = mean_log_diag.sum(dim=-1).cpu().numpy() # (batch_size,)

    return scores  # (batch_size,)

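The "prompt" branch of the score can be reproduced on a toy attention map. This is a sketch on random causal attention weights, not on real model outputs. It also illustrates why the score is described as eigenvalue-inspired: a causal attention matrix is lower triangular, so its eigenvalues are its diagonal entries and the sum of the log-diagonal equals the log-determinant.

```python
import torch

torch.manual_seed(0)
batch_size, n_heads, seq_len = 2, 4, 5

# Random causal attention maps: mask out future positions, then softmax over keys.
logits = torch.randn(batch_size, n_heads, seq_len, seq_len)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
attn = torch.softmax(logits.masked_fill(~causal, float("-inf")), dim=-1)

# Sequence 1 has two padded positions at the end (right padding, for simplicity).
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])

diag = torch.diagonal(attn, dim1=-2, dim2=-1).clamp(min=1e-6)  # (batch, heads, seq_len)
mask_ext = mask.unsqueeze(1).expand(-1, n_heads, -1)            # (batch, heads, seq_len)
mean_log_diag = (torch.log(diag) * mask_ext).sum(dim=-1) / mask_ext.sum(dim=-1).clamp(min=1)
scores = mean_log_diag.sum(dim=-1)                               # (batch,)

# Lower-triangular matrix: log-determinant equals the sum of the log-diagonal.
head = attn[0, 0]
logdet_direct = torch.logdet(head)
logdet_from_diag = torch.log(torch.diagonal(head)).sum()
```

Since every attention weight lies in (0, 1], each log term is non-positive, so the scores are always less than or equal to zero.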

In [39]:
### OK ###
import torch
from torch import nn 
from typing import Callable, Optional, Tuple, Unpack
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, repeat_kv
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

"""
**Problem**
When working with Hugging Face's Llama models, I needed to extract attention weights for analysis purposes. 
However, the default "eager" attention implementation, which exposes attention weights, caused instability when running. 
Specifically, using the "eager" backend resulted in hidden states containing NaN values, sometimes even before the 
computation of Q, K, V. This led to CUDA crashes or completely invalid results.

Switching to the "sdpa" attention backend resolved the numerical instability: with sdpa, there were no NaNs in 
hidden states, and the model ran stably, even in challenging configurations. However, sdpa computes attention 
weights inside a fused, highly optimized kernel and does not expose them, making it impossible to retrieve attention
maps for analysis.

Trying to "fix" eager by forcing float32 on hidden states did not resolve the core issue, since the rest of the model
(and its layers) expects float16, leading to incompatibilities and further errors. Thus, neither backend offered both
stability and transparency.

**Solution**
Implement a custom patch for the LlamaAttention forward method, but only on the specific layers where we wanted to 
access attention weights. The main computation of hidden states uses the stable backend (sdpa by default). 
This ensures the forward pass and generated sequences remain numerically stable.
In parallel, the patch computes attention weights using the "eager" mechanism, but solely to extract and return them 
for inspection. These weights are not used in the model's forward pass and do not affect generation, so any instabilities 
or NaN handling for these analytical values do not impact the model's outputs.

Thanks to this solution, we can now reliably run generation using Llama and access the true attention weights for 
chosen layers, benefiting both from the stability of "sdpa" and the interpretability of the "eager" backend, without 
compromising model correctness.
"""

def custom_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    return attn_weights

def patched_LlamaAttention_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:  
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        
        # ========================================================
        # [1] Forward pass using main attention backend (sdpa / flash)
        # This is the output used in the model's autoregressive loop.
        # These implementations are optimized (for memory + stability).
        # Does not compute attn_weights.
        # ========================================================
        attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # ========================================================
        # [2] Parallel computation of attention weights using eager attention
        # This is to retrieve attention weights only (not used in forward loop)
        # It is more numerically unstable (NaN possible with fp16)
        # ========================================================
        try:
            attn_weights = custom_eager_attention_forward(
                self,
                query_states, 
                key_states, 
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling, 
                **kwargs,
            )

            # Replace NaNs (if any) by 0.0 (no attention)
            if torch.isnan(attn_weights).any():
                print("[WARN] NaNs detected in attn_weights — replacing with 0.0 (no renormalization)")
                attn_weights = attn_weights.masked_fill(torch.isnan(attn_weights), 0.0)

        except Exception as ex:
            print(f"[ERROR] Exception in custom_eager_attention_forward: {ex}")
            attn_weights = None

        return attn_output, attn_weights

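The idea behind the patch (fused kernel for the output, explicit softmax for the weights) relies on both paths computing the same attention. The sketch below checks this on small random float32 tensors on CPU, where both paths agree; it does not reproduce the fp16 instability discussed above, and the tensor sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_heads, seq_len, head_dim = 1, 2, 6, 8
q = torch.randn(batch_size, n_heads, seq_len, head_dim)
k = torch.randn(batch_size, n_heads, seq_len, head_dim)
v = torch.randn(batch_size, n_heads, seq_len, head_dim)

# Fused path: returns only the attention output, never the weights.
fused_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit path: build the causal mask, compute the weights, then the output.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
explicit_out = weights @ v
```

Only the explicit path materializes `weights`, which is why the patched forward recomputes them separately while keeping the fused output for generation.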

In [40]:
from typing import Dict, List, Literal, Optional, Tuple

import torch
from torch.utils.hooks import RemovableHandle
from transformers import PreTrainedModel, PreTrainedTokenizer

# OLD VERSION: HOOKS ONLY 1 LAYER
def register_generation_attention_hook(
    model: PreTrainedModel,
    captured_attn_list: List[torch.Tensor],
    layer_idx: int = -1
) -> Tuple[RemovableHandle, dict]:
    """
    Attaches a forward hook to a specific Llama layer's self-attention module 
    to capture attention maps (weights) during autoregressive text generation.

    Parameters
    ----------
    model : PreTrainedModel
        The Hugging Face causal language model (e.g., Llama 2).
    captured_attn_list : List[torch.Tensor]
        List which will receive attention tensors after each decoding step.
        Each tensor: (batch_size * num_beams, n_heads, tgt_seq_len, src_seq_len)
    layer_idx : int
        Index of the layer to hook. Defaults to -1 (last layer).

    Returns
    -------
    RemovableHandle
        Call handle.remove() after generation to cleanly remove the hook.
    call_counter : dict
        Counts how many times the hook fired.
    """
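    num_layers = len(model.model.layers)
    
    call_counter = {"count": 0}
    
    # Map -1 to the last layer, then raise an error if the index is out of range
    idx = layer_idx if layer_idx != -1 else num_layers - 1
    if not (0 <= idx < num_layers):
        raise ValueError(f"`layer_idx` must be -1 or in [0, {num_layers - 1}], but got {layer_idx}.")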

    def attn_hook_fn(module, input, output):
        """
        Hook: captures the attention weights after the forward pass.
        For Llama (transformers >=4.31/hf), output is a tuple:
        (attn_output, attn_weights)
        """
        call_counter["count"] += 1
        # HuggingFace Llama2 attention forward: output[1] are attn weights
        attn_weights = output[1]  # (batch * num_beams, n_heads, tgt_seq_len, src_seq_len)
        captured_attn_list.append(attn_weights) #.detach()

    # The attention submodule for Llama: 
    attention_module = model.model.layers[idx].self_attn

    # Register hook on the Attention block
    # When PyTorch runs this layer during the forward pass, it also executes attn_hook_fn.
    handle = attention_module.register_forward_hook(attn_hook_fn)
    return handle, call_counter


### OK ###
# NEW VERSION: HOOKS SEVERAL LAYERS 
def register_generation_attention_hook(
    model: PreTrainedModel,
    captured_attn_lists: List[List[torch.Tensor]],
    layers_idx_list: List[int]
) -> Tuple[List[RemovableHandle], List[Dict[str, int]]]:
    """
    Attaches forward hooks to multiple specified Llama layers' self-attention modules
    to capture attention maps (weights) during autoregressive text generation.

    Parameters
    ----------
    model : PreTrainedModel
        The Hugging Face causal language model (e.g., Llama 2).
    captured_attn_lists : List[List[torch.Tensor]]
        A list containing one list per hooked layer.
        Each inner list will receive attention tensors after each decoding step.
        Each tensor shape: (batch_size, n_heads, tgt_seq_len, src_seq_len)
    layers_idx_list : List[int]
        List of indices of layers to hook.
        Use -1 to denote the last layer.

    Returns
    -------
    List[RemovableHandle]
        List of handle objects to remove hooks after generation by calling `handle.remove()`.
    List[Dict[str, int]]
        List of counters (dicts with key 'count') storing how many times each hook fired.
    """
    num_layers = len(model.model.layers)
    handles = []
    call_counters = [] # count how many times the hook is triggered

    # Raise error if layer_idx not in correct range
    for i, idx in enumerate(layers_idx_list):
        if idx == -1:
            idx = num_layers - 1
        if not (0 <= idx < num_layers):
            raise ValueError(f"`layer_idx` must be -1 or in [0, {num_layers - 1}], but got {idx}.")

        call_counter = {"count": 0}
        call_counters.append(call_counter)

        def attn_hook_fn(module, input, output, call_counter=call_counter, list_idx=i):
            """Hook: captures the attention weights after the forward pass.
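            For Llama (transformers >= 4.31), the module output is a tuple whose
            second element holds the attention weights.
            NOTE: depending on the transformers version and attention implementation,
            that element may be None unless attention outputs are requested
            (e.g., output_attentions=True)."""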
            call_counter["count"] += 1
            attn_weights = output[1]  # (batch, n_heads, tgt_seq_len, src_seq_len)
            captured_attn_lists[list_idx].append(attn_weights.detach())

        # The attention submodule for Llama
        attention_module = model.model.layers[idx].self_attn
        # Register the hook on the attention block.
        # PyTorch calls attn_hook_fn right after this module's forward pass.
        handle = attention_module.register_forward_hook(attn_hook_fn)
        handles.append(handle)

    return handles, call_counters
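
The hooks above bind `call_counter` and `list_idx` as default arguments. A minimal pure-Python sketch of why that matters when closures are created in a loop (the function names here are hypothetical and only illustrate the binding rule, not the notebook's hooks):

```python
def make_hooks_late(n):
    """Closures read the loop variable when they are called (late binding)."""
    hooks = []
    for i in range(n):
        def hook():
            return i  # looked up at call time, after the loop has finished
        hooks.append(hook)
    return hooks


def make_hooks_bound(n):
    """Default arguments are evaluated when each function is defined."""
    hooks = []
    for i in range(n):
        def hook(list_idx=i):  # current value of i is stored with this hook
            return list_idx
        hooks.append(hook)
    return hooks


late = [h() for h in make_hooks_late(3)]
bound = [h() for h in make_hooks_bound(3)]
print(late)   # [2, 2, 2]
print(bound)  # [0, 1, 2]
```

Without the default-argument binding, every hook would write into the list of the last hooked layer.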


In [41]:
from typing import Dict, List, Tuple
from torch.utils.hooks import RemovableHandle
from transformers import PreTrainedModel

# OLD VERSION: HOOKS ONLY 1 LAYER
def register_generation_activation_hook(
    model: PreTrainedModel,
    captured_hidden_list: List[torch.Tensor],
    layer_idx: int = -1
) -> Tuple[RemovableHandle, dict]:
    """
    Attaches a forward hook to a specific transformer layer to capture hidden states
    during autoregressive text generation i.e., at each decoding step.
    (more memory-efficient than using output_hidden_states=True).
    Transformer layer = self-attention + FFN + normalization.

    Parameters
    ----------
    model : PreTrainedModel
        The Hugging Face causal language model (e.g., GPT, LLaMA).
    captured_hidden_list : List[torch.Tensor]
        A list that will be filled with hidden states for each generation step. 
        Each tensor has shape (batch_size * num_beams, seq_len, hidden_size).
    layer_idx : int
        Index of the transformer block to hook. Defaults to -1 (the last layer).
        Use a positive integer if you want to hook an intermediate layer instead.

    Returns
    -------
    handle : RemovableHandle
        Call `handle.remove()` after generation to remove the hook.
    call_counter : dict
        Dictionary with key 'count' storing the number of times the hook fired.
    """
    # Raise error if layer_idx not in correct range
    num_layers = len(model.model.layers)
    if not (layer_idx == -1 or 0 <= layer_idx < num_layers):
        raise ValueError(
            f"`layer_idx` must be -1 or in [0, {num_layers - 1}], but got {layer_idx}."
        )
    
    call_counter = {"count": 0} # count how many times the hook is triggered

    def hook_fn(module, input, output):
        """Function called automatically by PyTorch just after
            the layer has produced its output during the forward pass."""
        
        call_counter["count"] += 1 
        # output is a tuple (hidden_states,) → keep [0]
        if layer_idx == -1:
            # Capture the final normalized output 
            captured_hidden_list.append(model.model.norm(output[0]).detach())  # post RMSNorm!
        else:
            # Capture raw hidden states (the final RMSNorm is not applied)
            captured_hidden_list.append(output[0].detach())
    
    # Register the hook on the transformer block.
    # PyTorch calls hook_fn right after this layer's forward pass.
    handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    
    return handle, call_counter

### OK ###
# NEW VERSION: HOOKS SEVERAL LAYERS 
def register_generation_activation_hook(
    model: PreTrainedModel,
    captured_hidden_lists: List[List[torch.Tensor]],  
    layers_idx_list: List[int]
) -> Tuple[List[RemovableHandle], List[Dict[str,int]]]:
    """
    Attaches forward hooks to several transformer layers to capture hidden states
    during autoregressive text generation, i.e., at each decoding step
    (more memory-efficient than using output_hidden_states=True).
    Transformer layer = self-attention + FFN + normalization.

    Parameters
    ----------
    model : PreTrainedModel
        The Hugging Face causal language model (e.g., GPT, LLaMA).
    captured_hidden_lists : List[List[torch.Tensor]]
        A list containing one list per hooked layer.
        Each inner list will be filled with hidden states for each generation step,
        each tensor of shape (batch_size, seq_len, hidden_size).
    layers_idx_list : List[int]
        List of transformer block indices to hook.
        Use -1 to denote the last layer.
    
    Returns
    -------
    List[RemovableHandle]
        List of handle objects to remove hooks after generation by calling `handle.remove()`.
    List[Dict[str, int]]
        List of counters (dicts with key 'count') storing how many times each hook was activated.
    """
    handles = []
    call_counters = [] # count how many times the hook is triggered

    # Raise error if a layer index is not in the correct range
    num_layers = len(model.model.layers)
    for list_idx, idx in enumerate(layers_idx_list):
        if not (idx == -1 or 0 <= idx < num_layers):
            raise ValueError(f"`layer_idx` must be -1 or in [0, {num_layers-1}] but got {idx}")

        call_counter = {"count": 0}  
        call_counters.append(call_counter)

        # Default arguments bind the current loop values to this hook.
        # list_idx selects the output list (robust to repeated layer indices).
        def hook_fn(module, input, output, call_counter=call_counter, idx=idx, list_idx=list_idx):
            """Function called automatically by PyTorch just after
            the layer has produced its output during the forward pass."""
            
            call_counter["count"] += 1
            # output is a tuple (hidden_states,) -> keep [0]
            if idx == -1:
                # Capture the final normalized output 
                captured_hidden_lists[list_idx].append(
                    model.model.norm(output[0]).detach())  # post RMSNorm!
            else: 
                # Capture raw hidden states (the final RMSNorm is not applied)
                captured_hidden_lists[list_idx].append(output[0].detach())

        # Register the hook on the transformer block.
        # PyTorch calls hook_fn right after this layer's forward pass.
        handle = model.model.layers[idx].register_forward_hook(hook_fn)
        handles.append(handle)

    return handles, call_counters


from typing import List, Dict

### OK ###
def verify_call_counters(call_counters: List[Dict[str, int]], name: str = "hooks") -> None:
    """
    Checks that all call counters are > 0 and equal.

    Args:
        call_counters: List of dictionaries with the key 'count'.
        name: Descriptive name for the error message.
    
    Raises:
        RuntimeError if a counter is 0 or if the counters differ.
    """
    if not all(counter['count'] > 0 for counter in call_counters):
        raise RuntimeError(f"At least one {name} did not capture any events.")
    
    counts = [counter['count'] for counter in call_counters]
    if len(set(counts)) > 1:
        raise RuntimeError(f"{name.capitalize()} have inconsistent call counts: {counts}")


In [51]:
import torch
import numpy as np

### OK ###
def compute_perplexity(
        prompt_logits: torch.Tensor, 
        gen_logits: torch.Tensor,
        prompt_ids: torch.Tensor, 
        gen_ids: torch.Tensor,
        prompt_attention_mask: torch.Tensor,
        gen_attention_mask: torch.Tensor,
        mode: Literal["prompt", "generation", "promptGeneration"] = "prompt",
        prepend_last_prompt_logit: bool = False,
        min_k: float = None,
    ) -> np.ndarray:
    """
    Computes the per-sample perplexity of language model outputs using logits 
    and corresponding input token IDs. Positions masked by 0 in the attention mask 
    are ignored in the computation of the perplexity. 
    If `min_k` is provided, only the lowest-probability tokens are used,
    which yields a restricted perplexity.

    Perplexity is defined as:
        Perplexity = exp(- mean(log P(token_i | context))) 
        where token_i is the next token actually predicted

    NOTE: This implementation is inspired by:
    "LLM-Check: Investigating Detection of Hallucinations in Large Language Models"
    (Sriramanan et al., 2024)

    Parameters
    ----------
    prompt_logits : torch.Tensor
        Tensor of shape (batch_size, prompt_len, vocab_size) 
        These are the model's output logits obtained from a standard forward pass over the prompt sequence.
    gen_logits : torch.Tensor
        Tensor of shape (batch_size, gen_len, vocab_size).
        These are the logits obtained during autoregressive decoding using `model.generate()`.
    prompt_ids : torch.Tensor
        Tensor of shape (batch_size, prompt_len), containing the input token IDs for the prompt.
    gen_ids : torch.Tensor
        Tensor of shape (batch_size, gen_len), containing the token IDs generated by the model.
    prompt_attention_mask: torch.Tensor
        Tensor of shape (batch_size, prompt_len), 1 where token valid, 0 for padding.
    gen_attention_mask: torch.Tensor  
        Tensor of shape (batch_size, gen_len), 1 where token valid, 0 for padding.
    mode : str, optional
        One of {"prompt", "generation", "promptGeneration"}:
        - "prompt": compute perplexity only over the prompt.
        - "generation": compute perplexity only over the generated tokens.
        - "promptGeneration": compute perplexity over both prompt and generation.
    prepend_last_prompt_logit : bool, optional
        If True, prepends the last logit from the prompt to the generation logits.
        This is useful when generation logits were computed manually from hidden states 
        and are therefore shifted by one position (they lack the first prediction step).
        => see Notes 2)C) below for more detail. Default is False.
        Careful! The gen_attention_mask must match the resulting length.
    min_k : float, optional
        Optional value between 0 and 1. If specified, only the fraction `min_k` of tokens
        with the lowest probability is used for the perplexity calculation.

    Returns
    --------
        np.ndarray: Per-sample perplexity scores of shape (batch_size,)

    Notes
    -----
    1) This function computes a "Pseudo Perplexity".

        The Standard Perplexity requires ground truth tokens:
            PPL = exp(-1/N ∑_{t=1}^N log p(w_real_t | w_real_{<t})) where w_real_t are the true next tokens

        In our case, we are in pure generation mode (equivalent to teacher forcing on the generated text)
        and we don't have access to the real tokens. We therefore compute the Pseudo Perplexity:
            PPL_gen = exp(-1/N ∑_{t=1}^N log p(w_gen_t | w_gen_{<t})) where w_gen_t are the generated next tokens
        This measures the internal consistency of the model, and how well the model finds its own generation probable. 
    
    2) About token shifting in autoregressive models:

        A) When extracting prompt logits with a standard autoregressive forward pass:
            Example: 
            prompt_outputs = model(inputs['input_ids'])
            prompt_logits = prompt_outputs.logits # Tensor of shape (batch_size, prompt_len, vocab_size)

            - The logit at position *t* predicts the token at position *t+1*.
            - The first token has no preceding context and is not predicted.
            - When computing log-probabilities, we must **shift the targets one position to the left** 
            to correctly align logits with target tokens.
            
            Example: Suppose we have a sequence of tokens (with their token IDs):
            | Index | Token | ID  |
            |-------|-------|-----|
            | 0     | A     | 10  |
            | 1     | B     | 29  |
            | 2     | C     | 305 |
            | 3     | D     | 24  |
            - The model produces logits at positions 0, 1, and 2 to predict
              the tokens B, C, and D, respectively.
            - The logits at position 0 are used to predict token B (ID 29).

        B) When extracting gen logits during generation:
            Example: 
            gen_outputs = model.generate(**inputs, max_new_tokens=10, output_logits=True,
                                         return_dict_in_generate=True) # gen_len=max_new_tokens
            gen_logits = torch.stack(gen_outputs.logits, dim=1)  # Tensor of shape (batch_size, gen_len, vocab_size)
            
            - The logit at time step *t* predicts the token generated at position *t*.
            - Each logit already corresponds to the prediction of the token at this step 
            - No shifting is needed in this case.

        REMARK:
        The last prompt logit (i.e., `prompt_logits[:, -1, :]`) predicts the first generated token.
        This means: prompt_logits[:, -1, :] = gen_logits[:, 0, :]

        C) When computing gen_logits directly from the activations (no built-in methods) 
            Example:
            computed_gen_logits = model.lm_head(gen_hidden_states) 
            # Tensor of shape (batch_size, gen_len-1, vocab_size), gen_len=max_new_tokens
            Typically, `gen_hidden_states` is a Tensor of shape (batch_size, gen_len-1, hidden_size) 
            computed from gen_outputs.hidden_states. Its length is `gen_len-1`
            because there is no hidden state for the final generated token.

            - Here, the logit at position *t* predicts the generated token at position *t+1* !! 
            - Since the last prompt logit equals the first generation logit, we prepend the last prompt
            logit to the beginning of computed_gen_logits with `prepend_last_prompt_logit=True`
            - We recover case B) with the full gen_logits (as would be returned by model.generate) 
            - We can use the same gen_attention_mask as in case B)

        D) When computing prompt_logits directly from the activations (no built-in methods) 
            Example:
            computed_prompt_logits = model.lm_head(prompt_hidden_states) 
            # Tensor of shape (batch_size, prompt_len, vocab_size),
            Typically, `prompt_hidden_states` is a Tensor of shape (batch_size, prompt_len, hidden_size) 

            - The logit at position *t* predicts the token at position *t+1* => recover case A)
            - We can use the same prompt_attention_mask as in case A)

        Summary of alignment:
            - A) Prompt Logits from forward: 
                logit at position *t* predicts token at position *t+1* -> shift targets left.
            - B) Generation Logits from generate: 
                logit at position *t* predicts token at position *t* -> no shift.
            - C) Computed Generation Logits: 
                set `prepend_last_prompt_logit=True` and no shift needed -> go back to case B)
            - D) Computed Prompt Logits:
                similar to case A), nothing to do. 

        NOTE: help from issue https://github.com/huggingface/transformers/issues/29664
    """  

    if min_k is not None:
        if min_k < 0 or min_k > 1: raise ValueError("min_k must be between 0 and 1")

    if mode not in ('prompt','generation','promptGeneration'):
        raise ValueError("mode must be in {'prompt','generation','promptGeneration'}")
    
    # ==============================
    # Move to device
    # ==============================
    prompt_logits = prompt_logits.to(prompt_attention_mask.device)
    if gen_logits is not None:
        gen_logits = gen_logits.to(gen_attention_mask.device)

    # ==============================
    # Prepend last logit of prompt to the generation logits if specified
    # ==============================
    if prepend_last_prompt_logit:
        last_prompt_logit = prompt_logits[:, -1:, :] # (batch_size, 1, vocab_size)
        gen_logits = torch.cat([last_prompt_logit, gen_logits], dim=1) # (batch_size, 1 + computed_gen_len, vocab_size)
           
    # ==============================
    # Apply softmax over vocabulary dimension and take log to get log-probabilities
    # ==============================
    prompt_log_probs = torch.log_softmax(prompt_logits, dim=-1)  # (batch_size, prompt_len, vocab_size)
    if gen_logits is not None:
        gen_log_probs = torch.log_softmax(gen_logits, dim=-1)    # (batch_size, gen_len, vocab_size)

    # ==============================
    # Extraction of prompt log-probs
    # ==============================
    if mode in ("prompt", "promptGeneration"):
        # In prompt: logit at position t predicts token at t+1 (requires shifting)
        # Remove first token from target (no context to predict it)
        prompt_target_tokens = prompt_ids[:, 1:] # (batch_size, prompt_len - 1)

        prompt_attention_mask_shifted = prompt_attention_mask[:, 1:]  # (batch_size, prompt_len - 1)

        # Remove last logit position (since it predicts next token)
        prompt_pred_log_probs = prompt_log_probs[:, :-1, :] # shape: (batch_size, prompt_len - 1, vocab_size)

        # Retrieves, for each position and each batch, the log-probability corresponding to the next token 
        # (the one in target_tokens) from all the probas on the vocabulary.
        prompt_token_log_probs = prompt_pred_log_probs.gather(
            dim=2, index=prompt_target_tokens.unsqueeze(-1)
            ).squeeze(-1) # shape: (batch_size, prompt_len - 1)
      
        # Mask paddings
        prompt_token_log_probs = prompt_token_log_probs * prompt_attention_mask_shifted
        
    # ==============================
    # Extraction of generation log-probs
    # ==============================
    if mode in ("generation", "promptGeneration"):
        # In generation: logit at position t predicts token at position t (no shift needed)
        gen_token_log_probs = gen_log_probs.gather(
            dim=2, index=gen_ids.unsqueeze(-1)
            ).squeeze(-1)  # shape: (batch_size, gen_len)
        
        # Mask paddings
        gen_token_log_probs = gen_token_log_probs * gen_attention_mask
    

    # ==============================
    # Select log-probs according to mode
    # ==============================
    if mode == "promptGeneration":
        # Last logit of prompt from the forward pass == first logit of generation from `model.generate()`. 
        # To compute perplexity over the full sequence:
        # - Use prompt_token_log_probs (excluding final prompt token)
        # - Use gen_token_log_probs from generation
        # Concatenate both to form a complete sequence of predicted log-probs
        token_log_probs = torch.cat(
            [prompt_token_log_probs, gen_token_log_probs],  
            dim=1) # (batch_size, prompt_len - 1 + gen_len)
        total_mask = torch.cat(
            [prompt_attention_mask_shifted, gen_attention_mask],
            dim=1) # (batch_size, prompt_len - 1 + gen_len)
    
    elif mode == "prompt":
        token_log_probs = prompt_token_log_probs    # (batch_size, prompt_len - 1)
        total_mask = prompt_attention_mask_shifted  # (batch_size, prompt_len - 1)
    
    elif mode == "generation":
        token_log_probs = gen_token_log_probs  # (batch_size, gen_len)
        total_mask = gen_attention_mask        # (batch_size, gen_len)

    # ==============================
    # Compute Perplexity ignoring padded tokens
    # ==============================
    eps = 1e-12  # to avoid division by zero

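    # Optionally focus only on the hardest tokens (lowest log-probs)
    if min_k is not None:
        # Number of valid (non-padded) tokens per sample
        valid_counts = total_mask.sum(dim=1).long()                # (batch_size,)
        # Keep the min_k fraction of the valid tokens, at least one per sample,
        # so that padding is never selected and k is never 0
        k_per_sample = (min_k * valid_counts).long().clamp(min=1)  # (batch_size,)

        # Push padding to +inf so that it sorts after every valid token
        masked_log_probs = token_log_probs.float().masked_fill(total_mask == 0, float("inf"))
        sorted_log_probs, _ = torch.sort(masked_log_probs, dim=1)  # ascending order

        # Select the k lowest log-probabilities of each sample
        ranks = torch.arange(sorted_log_probs.size(1), device=sorted_log_probs.device)
        keep = ranks.unsqueeze(0) < k_per_sample.unsqueeze(1)      # (batch_size, seq_len)
        selected_sum = sorted_log_probs.masked_fill(~keep, 0.0).sum(dim=1)

        # Compute perplexity using only the selected subset
        ppls = torch.exp(-selected_sum / k_per_sample)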

    else:
        # Compute perplexity over all predicted tokens
        sum_log_probs = (token_log_probs * total_mask).sum(dim=1)
        count = total_mask.sum(dim=1).clamp(min=eps)
        mean_log_prob = sum_log_probs / count
        ppls = torch.exp(-mean_log_prob)

    return ppls.cpu().numpy()
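
As a sanity check on the alignment described in the docstring (case A: the logit at position t scores the token at position t+1), here is a minimal pure-Python sketch of the pseudo-perplexity on toy values. The helper names `log_softmax` and `pseudo_perplexity` are hypothetical and do not replace `compute_perplexity`; they only mirror its formula on plain lists.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def pseudo_perplexity(logits, token_ids, mask):
    """Case A alignment: logits[t] scores token_ids[t + 1]; mask[t + 1] == 0 skips padding."""
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):
        if mask[t + 1] == 0:
            continue
        total += log_softmax(logits[t])[token_ids[t + 1]]
        count += 1
    return math.exp(-total / max(count, 1))


token_ids = [1, 0, 3, 2, 2]
mask = [1, 1, 1, 1, 0]  # the last position is padding

# Uniform logits over a 4-token vocabulary: every token has probability 1/4,
# so the perplexity should be close to 4.
uniform_logits = [[0.0] * 4 for _ in range(5)]
ppl_uniform = pseudo_perplexity(uniform_logits, token_ids, mask)

# Logits that strongly favour the actual next token: perplexity close to 1.
confident_logits = [
    [20.0 if v == token_ids[min(t + 1, 4)] else 0.0 for v in range(4)]
    for t in range(5)
]
ppl_confident = pseudo_perplexity(confident_logits, token_ids, mask)
print(ppl_uniform, ppl_confident)
```

A perplexity equal to the vocabulary size corresponds to a model that is maximally unsure, while a value near 1 means the model assigns almost all probability mass to the tokens it produced.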




In [52]:

import torch
import torch.nn.functional as F
import numpy as np

### OK ###
def compute_logit_entropy(
    prompt_logits: torch.Tensor,
    gen_logits: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    gen_attention_mask: torch.Tensor,
    mode: str = "prompt",
    prepend_last_prompt_logit: bool = False,
    top_k: int = None,
    window_size: int = None,
    stride: int = None
) -> np.ndarray:
    """
    Computes the per-sample entropy of a language model's output distributions
    using its logits and attention masks.
    For each token position, the function computes the entropy of the softmax distribution
    over the vocabulary. Entropy is averaged over the valid tokens (i.e., those marked
    as 1 in the attention mask). If `top_k` is specified, the entropy is computed only
    over the top-k logits (highest values) for each position.

    Entropy is defined as:
        Entropy = -Sum_i p_i * log(p_i)
        where p_i = softmax(logits)_i, i=1..vocab_size
    
    There are two main usage patterns:
      - Classic token-level average entropy (if window_size is None): computes the per-token entropy over the
        sequence, averages over all valid tokens per sample (optionally using top_k).
      - Windowed maximum mean entropy (if window_size is specified): slides a window of width `window_size`
        and stride `stride` (default equals window_size: non-overlapping windows, else user-specified) across
        the sequence of token entropies, and returns the maximum mean entropy observed in any window for each sample.

    Padding tokens are always ignored (via the provided attention masks); only windows where all tokens are valid
    are considered in the windowed mode.

    NOTE: This implementation is inspired by:
    "LLM-Check: Investigating Detection of Hallucinations in Large Language Models"
    (Sriramanan et al., 2024)
    
    Parameters
    ----------
    prompt_logits : torch.Tensor
        Tensor of shape (batch_size, prompt_len, vocab_size).
        These are the model's output logits obtained from a standard forward pass over the prompt sequence.
    gen_logits : torch.Tensor
        Tensor of shape (batch_size, gen_len, vocab_size).
        These are the logits obtained during autoregressive decoding using `model.generate()`.
    prompt_attention_mask : torch.Tensor
        Tensor of shape (batch_size, prompt_len). Contains 1 where the token is valid and 0 for padding.
    gen_attention_mask : torch.Tensor  
        Tensor of shape (batch_size, gen_len). Contains 1 where the token is valid and 0 for padding.
    mode : str, optional
        Which tokens to use for entropy computation:
        - "prompt": compute entropy only over the prompt logits/mask.
        - "generation": compute entropy only over the generated logits/mask.
        - "promptGeneration": compute entropy over both concatenated prompt and generated logits/mask.
    prepend_last_prompt_logit : bool, optional
        If True, prepends the last logit from the prompt to the generation logits.
        This is useful when generation logits were computed manually from hidden states 
        and are therefore shifted by one position (they lack the first prediction step so 
        the first logit is missing). Default is False.
        Careful! The gen_attention_mask must match the resulting length.
    top_k : int, optional
        If specified, only the top_k logits (per token) are used to compute the entropy.
        If None, use all logits.
    window_size : int, optional
        If not None, apply a sliding window of this size across the (valid) sequence of token entropies,
        and return the maximum mean entropy over any complete window, for each sample.
        If None, simply average the per-token entropies over all valid tokens.
    stride : int, optional
        Sliding window stride. Only used if window_size is specified.
        - If None, defaults to window_size (non-overlapping windows).
        - If set, must be a positive integer <= window_size.

    Returns
    -------
    np.ndarray
        Array of shape (batch_size,). For each batch sample, either the average logit entropy
        over valid tokens (if window_size is None) or the maximum windowed mean entropy (if window_size is given).

    Notes
    -----
    - Padding tokens are always ignored, both in classic and windowed entropy.
    - In windowed mode, only windows where all tokens in the window are valid are considered.
    - Uses torch.special.entr for numerically stable entropy calculation.
    """
    # ==============================
    # Move to device
    # ==============================
    prompt_logits = prompt_logits.to(prompt_attention_mask.device)
    if gen_logits is not None:
        gen_logits = gen_logits.to(gen_attention_mask.device)

    # Prepend last logit of prompt to the generation logits if specified
    if prepend_last_prompt_logit:
        last_prompt_logit = prompt_logits[:, -1:, :] # (batch_size, 1, vocab_size)
        gen_logits = torch.cat([last_prompt_logit, gen_logits], dim=1) # (batch_size, 1 + computed_gen_len, vocab_size)

    def entropy_from_logits(logits, attention_mask, top_k=None):
        """
        Parameters
        ----------
        logits: (batch_size, seq_len, vocab_size)
        attention_mask: (batch_size, seq_len)
        top_k: int > 0

        Returns
        -------
        entropy: (batch_size, seq_len)
        attention_mask: (batch_size, seq_len)
        """

        # Convert float16 -> float32 for better accuracy during computations
        logits = logits.float()
        attention_mask = attention_mask.float()

        if top_k is not None:
            topk_vals = torch.topk(logits, k=top_k, dim=-1).values  # (batch_size, seq_len, top_k)
            probs = F.softmax(topk_vals, dim=-1) # (batch_size, seq_len, top_k)
        else:
            probs = F.softmax(logits, dim=-1) # (batch_size, seq_len, vocab_size)

        # Use torch.special.entr, which automatically handles edge cases
        # entropy(x) = -x * log(x) with entropy(0) = 0
        entropy = torch.special.entr(probs).sum(dim=-1)  # (batch_size, seq_len)
        return entropy, attention_mask # both are (batch_size, seq_len)

    def average_entropy(entropy, mask):
        """
        Parameters
        ----------
        entropy: (batch_size, seq_len)
        mask: (batch_size, seq_len)
        
        Returns
        -------
        avg_entropy: (batch_size,)
        """
        entropy_masked = entropy * mask                    # (batch_size, seq_len)
        total_entropy = entropy_masked.sum(dim=-1)         # (batch_size,)
        valid_count = mask.sum(dim=-1)                     # (batch_size,)
        avg_entropy = total_entropy / (valid_count + 1e-9) # (batch_size,)
        return avg_entropy

    def max_sliding_window_entropy(entropy, mask, w, stride):
        """
        Parameters
        ----------
        entropy: (batch_size, seq_len)
        mask: (batch_size, seq_len)
        w: int > 0
        stride: int > 0

        Returns
        -------
        max_avg_entropy: (batch_size,)
        """
        # Add one dummy channel dimension since conv1d requires 3D tensors
        entropy = entropy.unsqueeze(1)  # (batch_size, 1, seq_len)
        mask = mask.unsqueeze(1)        # (batch_size, 1, seq_len)

        kernel = torch.ones(1, 1, w, device=entropy.device) / w  # shape: (1,1,w)

        # padding=0 to avoid artificial values distorting the calculation:
        # positions without enough elements to form a complete window are ignored.
        moving_avg = F.conv1d(entropy, kernel, stride=stride, padding=0)  # sliding mean entropy
        
        # All windows containing at least one padding token are ignored with valid_mask.
        # Valid tokens are counted with a kernel of ones so the comparison with w is exact
        # (comparing a float mean with 1.0 may fail because of rounding).
        count_kernel = torch.ones(1, 1, w, device=mask.device, dtype=mask.dtype)
        valid_counts = F.conv1d(mask, count_kernel, stride=stride, padding=0)  # valid tokens per window
        valid_mask = (valid_counts == w)  # full valid windows only

        moving_avg = moving_avg.masked_fill(~valid_mask, float('-inf')) # put -inf where valid_mask==0

        max_avg_entropy, _ = moving_avg.max(dim=-1)  # (batch_size, 1)
        
        return max_avg_entropy.squeeze(1) # (batch_size,)

    if top_k is not None:
        top_k = int(top_k)
        if top_k <= 0 or top_k > prompt_logits.shape[2]:
            raise ValueError("top_k must be a positive integer less or equal to vocab size")
        
    if window_size is not None:
        if stride is None:
            stride = window_size
        else:
            stride = int(stride)
            if stride <= 0 or stride > window_size:
                raise ValueError("stride must be a positive integer less or equal to window_size.")
    else:
        stride = None

    if mode == "prompt":
        entropy, mask = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k) # both are (batch_size, prompt_len)
    elif mode == "generation":
        entropy, mask = entropy_from_logits(gen_logits, gen_attention_mask, top_k)       # both are (batch_size, gen_len)
    elif mode == "promptGeneration":
        ent_p, mask_p = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k) # both are (batch_size, prompt_len)
        ent_g, mask_g = entropy_from_logits(gen_logits, gen_attention_mask, top_k)       # both are (batch_size, gen_len)
        entropy = torch.cat([ent_p, ent_g], dim=1) # (batch_size, prompt_len + gen_len)
        mask = torch.cat([mask_p, mask_g], dim=1)  # (batch_size, prompt_len + gen_len)
    else:
        raise ValueError("mode must be in {'prompt','generation','promptGeneration'}")

    if window_size is None:
        result = average_entropy(entropy, mask)
    
    else:
        window_size = int(window_size)
        if window_size <= 0:
            raise ValueError("window_size must be a positive integer")
        if window_size > entropy.shape[1]:
            raise ValueError("window_size greater than sequence length")
        # stride was already validated (or defaulted to window_size) above
        stride = int(stride)
        result = max_sliding_window_entropy(entropy, mask, window_size, stride)
    return result.cpu().numpy()
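
To make the two usage patterns concrete, here is a minimal pure-Python sketch of the per-token entropy and of the windowed maximum mean entropy on toy values. The helpers `token_entropy` and `max_window_mean` are hypothetical names used for illustration; they follow the same definitions as the function above but operate on plain lists for a single sample.

```python
import math


def token_entropy(logits):
    """Entropy of softmax(logits): -sum_i p_i * log(p_i), with 0 * log(0) treated as 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return -sum((e / z) * math.log(e / z) for e in exps if e > 0.0)


def max_window_mean(values, mask, window, stride):
    """Maximum mean over complete windows that contain only valid positions."""
    best = float("-inf")
    for start in range(0, len(values) - window + 1, stride):
        if all(mask[start:start + window]):
            best = max(best, sum(values[start:start + window]) / window)
    return best


h_uniform = token_entropy([0.0, 0.0, 0.0, 0.0])  # uniform over 4 tokens: log(4)
h_peaked = token_entropy([50.0, 0.0, 0.0, 0.0])  # almost one-hot: close to 0

entropies = [1.0, 9.0, 8.0, 2.0, 9.0]
mask = [1, 1, 1, 1, 0]  # the last position is padding
best_overlapping = max_window_mean(entropies, mask, window=2, stride=1)
best_disjoint = max_window_mean(entropies, mask, window=2, stride=2)
print(best_overlapping)  # 8.5
print(best_disjoint)     # 5.0
```

The two results differ because non-overlapping windows (stride equal to the window size) can miss the most uncertain span when it straddles a window boundary.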


In [44]:
from transformers import PreTrainedTokenizer, PreTrainedModel, BatchEncoding
import torch
from datasets import  Dataset
from tqdm import tqdm
from typing import List, Callable, Union, Literal, Dict, Tuple
from torch.utils.hooks import RemovableHandle

from src.inference.offset_utils import (
    compute_offset_attention_mask,
)
from src.inference.generation_utils import (
    build_prompt,
    extract_batch, 
    build_generation_attention_mask)

### OK ###
def generate(
    model: PreTrainedModel,
    inputs: BatchEncoding,
    tokenizer: PreTrainedTokenizer,
    max_new_tokens: int = 50,
    k_beams: int = 1,
    **generate_kwargs
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Generate sequences from the model with optional beam search.
    Supports advanced options via **generate_kwargs (e.g., output_attentions).

    Parameters
    ----------
    model : PreTrainedModel
        The language model to use for generation.
    inputs : BatchEncoding
        Tokenized input prompts.
    tokenizer : PreTrainedTokenizer
        Tokenizer providing eos and pad token IDs.
    max_new_tokens : int, optional
        Maximum number of new tokens to generate.
    k_beams : int, optional
        Number of beams to use. If 1, uses sampling. If >1, beam search is enabled.
    **generate_kwargs : dict
        Additional keyword arguments passed to `model.generate()`.

    Returns
    -------
    Union[torch.Tensor, Dict[str, torch.Tensor]]
        `return_dict_in_generate=True` is always set, so the result is a dict-like
        generation output rather than a bare tensor. Fields used in this notebook:
            - "sequences": the generated token IDs, shape (batch_size, prompt_len + gen_len)
            - "logits": one tensor of unfiltered logits per generation step
            - "beam_indices": the beam path for each token (only if k_beams > 1)
    """
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True    if k_beams == 1 else False,
            temperature=0.6   if k_beams == 1 else None,
            top_p=0.9         if k_beams == 1 else None,
            top_k=50          if k_beams == 1 else None,
            num_beams=k_beams,
            use_cache=True, 
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id, # Finished sequences are right-padded with EOS
            output_hidden_states=False,      # We rely on the hook to extract hidden states instead (more memory efficient)
            output_attentions=False,         # We rely on the hook to extract attention maps instead (more memory efficient)
            output_logits=True,              # Raw logits, not filtered/truncated by top-k/top-p sampling. Note: `output_scores=True` returns the filtered logits. 
            return_dict_in_generate=True,    # Needed for access to logits, and to beam_indices when num_beams > 1
            early_stopping=False if k_beams == 1 else True, # Beam search only: stop as soon as num_beams finished candidates exist
            **generate_kwargs                # Extra options not already set above (e.g., output_scores)
        )
        return outputs 
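
The choice of `output_logits=True` over `output_scores=True` matters for the entropy-based scores computed later: filtering moves all probability mass onto the kept tokens, so the entropy of the filtered distribution is lower than that of the raw one. The sketch below is a plain-Python illustration of top-k filtering, not the transformers implementation; `softmax`, `entropy` and `top_k_filter` are hypothetical helpers and the logit values are made up.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax; positions set to -inf get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    """Shannon entropy in nats, ignoring zero-probability entries."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def top_k_filter(logits: List[float], k: int) -> List[float]:
    """Keep the k largest logits and set the others to -inf (ties at the threshold are kept)."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]


raw_logits = [2.0, 1.0, 0.0, -1.0]             # illustrative values
filtered_logits = top_k_filter(raw_logits, k=2)

print("entropy of raw distribution     :", entropy(softmax(raw_logits)))
print("entropy of filtered distribution:", entropy(softmax(filtered_logits)))
```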

In [76]:
def apply_logit_lens(
        model: PreTrainedModel, 
        hidden_states: torch.Tensor
    ) -> torch.Tensor :
    """
    Applies the model's LM head to hidden states to produce logits.

    Args:
        model: PreTrainedModel with `lm_head` attribute.
        hidden_states: Tensor (batch_size, seq_len, hidden_size).

    Returns:
        logits: Tensor (batch_size, seq_len, vocab_size).

    NOTE: 
    No layer norm is applied before the projection, to match the transformers Llama implementation.
    """
    # Apply LM head (linear projection)
    logits = model.lm_head(hidden_states)
    return logits


In [64]:
def run_prompt_and_generation_score_extraction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    batch_size: int = 4,
    idx_start_sample: int = 0,
    max_samples: int = 1000,
    save_to_pkl: bool = False,
    output_path: str = "outputs/all_batch_results.pkl",
    build_prompt_fn: Callable[[str, str], str] = None,
    layers: List[int] = [-1],  
    activation_source: Literal["prompt", "generation", "promptGeneration"] = "generation",
    hidden_scores: List[str] = ["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores: List[str] = ["attn_eig_prod"],
    logit_scores: List[str] = ["perplexity", "logit_entropy", "window_logit_entropy"],
    logit_config: dict = {"top_k": 50, "window_size": 1, "stride": 1},
    start_offset : int = 0,
    end_offset : int = 0,
) -> Union[List[dict], None]:
    """
    Runs batched inference on a dataset using a decoder-only language model.
    For each batch, it performs text generation and extracts token-level hidden activations, 
    attention maps and logit scores from the specified transformer layers, taken from the prompt, 
    the generated text, or both, depending on `activation_source`.

    The function supports multiple aggregation modes for the activations (`hidden_scores`), attention-based 
    scores (`attn_scores`), and logit-based scores (`logit_scores`). The `logit_config` argument provides 
    configuration parameters for the logit-based score functions.
    
    Hidden states and attention maps are captured via forward hooks during generation, 
    then aggregated based on token position and attention masks.
    
    The per-batch results are either saved as individual pickle files in the directory given by 
    `output_path`, which allows incremental storage and later aggregation, or returned directly.

    Parameters
    ----------
    model : PreTrainedModel
        The causal language model to evaluate (e.g., LLaMA).
    tokenizer : PreTrainedTokenizer
        The corresponding tokenizer.
    dataset : Dataset
        The input dataset.
    batch_size : int
        Number of samples per batch.
    idx_start_sample : int
        Index of the first sample to process from the dataset.
    max_samples : int
        Total number of examples to process from the dataset, starting from idx_start_sample. 
    save_to_pkl : bool
        If True, each batch's results are written to a separate pickle file under output_path 
        and the function returns None.
        If False, the function returns the list of per-batch result dictionaries.
    output_path : str
        Path to the directory where the per-batch results are saved as individual pickle files.
    build_prompt_fn : Callable
        Function to build a prompt from context and question.
    layers : List[int]
        List of indices of the transformer layers to extract activations from (default: [-1] for last layer).
    activation_source : {"prompt", "generation", "promptGeneration"}
        Which part of the sequence to extract activations/attentions/logits from:
        - "prompt": only from the prompt
        - "generation": only from the generated answer
        - "promptGeneration": prompt and generation answer both concatenated
    hidden_scores : List[str], optional
        List of aggregation modes to compute on token activations. Possible modes include:
            "average", "last", "max", "first_generated", "token_svd_score", "feat_var".
        These modes are passed to `extract_token_activations` for aggregation. Default includes the above.
    attn_scores : List[str], optional
        List of attention-based scores to compute. Supported: "attn_eig_prod".
    logit_scores : List[str], optional
        List of logit-based scores to compute. Supported:
            "perplexity", "logit_entropy", "window_logit_entropy".
    logit_config : dict, optional
        Configuration dictionary for logit-based scoring functions, with keys such as:
            - "top_k": int, number of top logits considered (default 50)
            - "window_size": int, window size for windowed entropy (default 1)
            - "stride": int, stride for windowed entropy (default 1)
    start_offset : int
        Offset from the first non-padding token (must be >= 0). 
    end_offset : int
        Offset from the last non-padding token (must be <= 0, e.g., -3 to remove 3 tokens).
    
    Returns
    -------
    Union[List[dict], None]
        If `save_to_pkl` is False, returns a list of dictionaries, one per batch, with each element
         of the list having the following structure:
            {
                "id": List[str],  # IDs of batch samples
                "original_indices": List[int],  # Original dataset indices
                "context": List[str],
                "question": List[str],
                "gt_answers": List[str],        # Ground-truth reference answers
                "gen_answers": List[str],       # Generated model answers
                "scores": {
                    "layer_{layer_idx}": {
                        "hidden": { 
                            "{mode}": np.ndarray[(batch_size, hidden_size), float], 
                            ... # one entry per mode in hidden_scores
                        },
                        "attention": {
                            "{attn_score}": np.ndarray[(batch_size,), float],  
                            ...
                        },
                        "logits": {
                            "perplexity": np.ndarray[(batch_size,), float],
                            "logit_entropy": np.ndarray[(batch_size,), float],
                            "window_logit_entropy": np.ndarray[(batch_size,), float] 
                        }
                    },
                    ... # one entry per layer in layers
                }
            }

        If `save_to_pkl` is True, saves each batch's dictionary incrementally to disk and returns None.

    Notes
    -----
    When using model.generate() with output_hidden_states=True (which is what the activation hook 
    replicates here), use_cache=True and max_new_tokens=30, there is always an offset of one between the 
    number of generated tokens (outputs.sequences[:, prompt_len:].shape[1]) and the number of generation 
    steps that have hidden states: 
    * outputs.sequences.shape[1] = prompt_len (17) + max_new_tokens (30) = 47
    * len(outputs.hidden_states) = max_new_tokens (30)
        With : 
        * outputs.hidden_states[0][layer_idx].shape = (batch_size, prompt_len, hidden_size)           --> includes the prompt ! 
        * outputs.hidden_states[i][layer_idx].shape = (batch_size, 1, hidden_size) with 1 <= i <= 29  --> stops at 29 ! 
    *Note* that in our code, the `activations` collected by the hook play the role of outputs.hidden_states. 
        
    See the explanation from Hugging Face, April 2024 
    (https://github.com/huggingface/transformers/issues/30036).
    """

    if activation_source not in ('prompt', 'generation', 'promptGeneration'):
        raise ValueError(
                f"Invalid value for `activation_source`: '{activation_source}'. "
                f"Expected one of: ['prompt', 'generation', 'promptGeneration']."
            )    
        
    # ==============================
    # Patch selected layer(s) with custom LlamaAttention Forward function to retrieve attention weights
    # ==============================
    for idx in layers:  
        model.model.layers[idx].self_attn.forward = patched_LlamaAttention_forward.__get__(
            model.model.layers[idx].self_attn,
            model.model.layers[idx].self_attn.__class__
        )
        
    all_batch_results = []  

    for i in tqdm(range(idx_start_sample, idx_start_sample + max_samples, batch_size)):
        print(f"============ {i} ============")
        
    
        # ==============================
        # Prepare input batch
        # ==============================
        batch = extract_batch(dataset, i, batch_size)
        prompts = [build_prompt_fn(s["context"], s["question"]) for s in batch]
        inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        prompt_ids = inputs["input_ids"] # (batch_size, prompt_len)
        prompt_len = prompt_ids.shape[1] # Assumes prompts are padded to same length

        print(f"[INFO] prompt_ids shape: {prompt_ids.shape}, prompt_len: {prompt_len}")

        # ==============================
        # Register forward hook to capture layer output
        # ==============================
        # This hook collects the hidden states at each decoding step. For layer l: 
        # activations_lists[l] = [act_prompt, act_gen_step1, ..., act_gen_step49] of length 50, if max_new_tokens=50.
        # activations_lists[l][k] of shape: (batch_size, seq_len, hidden_size) 
        activations_lists = [[] for _ in layers]  # one empty list per layer 
        handle_act, call_counter_act = register_generation_activation_hook(model, activations_lists, layers)

        # This hook collects the attention maps at each decoding step. For layer l: 
        # attentions_lists[l] = [attn_prompt, attn_gen_step1, ..., attn_gen_step49], of length 50, if max_new_tokens=50.
        # attentions_lists[l][k] of shape: (batch_size, n_heads, tgt_seq_len, src_seq_len)
        #   tgt_seq_len: length of the sequence the model is currently producing (query)
        #   src_seq_len: length of the sequence the model is focusing on (key/value)
        attentions_lists = [[] for _ in layers]  # one empty list per layer
        handle_attn, call_counter_attn = register_generation_attention_hook(model, attentions_lists, layers)
        
        # ==============================
        # Run model generation (hook captures activations and attentions)
        # ==============================
        # When the target layers are reached, the hooks execute and save their outputs in activations_lists and attentions_lists
        print("[INFO] Starting generation...")
        outputs = generate(model, inputs, tokenizer, max_new_tokens=50, k_beams=1)
        gen_ids = outputs.sequences[:, prompt_len:]
        print(f"[INFO] gen_ids shape: {gen_ids.shape}")
        print(f"[INFO] Sample generated tokens: {gen_ids}")

        # Remove hooks to avoid memory leaks or duplicate logging
        for h in handle_act: h.remove()
        for h in handle_attn: h.remove()
        
        # Verify that hooks worked properly
        verify_call_counters(call_counter_act, name="activation hooks")
        verify_call_counters(call_counter_attn, name="attention hooks")

        # Retrieve text of generated answers
        gen_answers = tokenizer.batch_decode(
            outputs.sequences[:, prompt_len:], 
            skip_special_tokens=True
        ) # (batch_size,)
        ###print(f"[INFO] Sample generated answer: {gen_answers}")

  

        # ===============================
        # Build generation and prompt attention mask
        # ===============================
        # This mask marks which generated tokens are valid (i.e., not padding).
        # Positions are marked True up to and including the first eos_token_id
        generation_attention_mask = build_generation_attention_mask(
            gen_ids=gen_ids, 
            eos_token_id=tokenizer.eos_token_id
        ) # (batch_size, gen_len)

        prompt_attention_mask = inputs["attention_mask"] 
        # (batch_size, prompt_len)

        ###print(f"[INFO] prompt_attention_mask: {prompt_attention_mask.shape}")
        ###print(f"[INFO] generation_attention_mask: {generation_attention_mask.shape}")
        ###print(f"[INFO] generation_attention_mask: {generation_attention_mask}")

        # Modify prompt attention mask with offsets
        if start_offset !=0 or end_offset !=0:
            ###print(f"[INFO] Offsetting prompt_attention_mask with start={start_offset}, end={end_offset}")
            prompt_attention_mask, start_indices, end_indices = compute_offset_attention_mask(
                attention_mask=prompt_attention_mask, 
                start_offset=start_offset, 
                end_offset=end_offset
            ) # (batch_size, prompt_len), (batch_size,), (batch_size,)

        ###print(f"[INFO] New prompt_attention_mask shape: {prompt_attention_mask.shape}")
        ###print(f"[INFO] New prompt_attention_mask : {prompt_attention_mask}")

        # Concatenate the prompt and generation attention mask
        prompt_and_gen_attention_mask = torch.cat(
            [prompt_attention_mask,
            generation_attention_mask],
            dim=1
        ) # (batch_size, prompt_len + gen_len)

        ###print(f"[INFO] prompt_and_gen_attention_mask shape: {prompt_and_gen_attention_mask.shape}")
        ###print(f"[INFO] prompt_and_gen_attention_mask : {prompt_and_gen_attention_mask}")

        # ===============================
        # Truncate generated token IDs and mask to match activations and attentions
        # ===============================
        # When N tokens are generated, only the first N-1 tokens have corresponding hidden states.
        # So activations[1:] covers only the first N-1 steps. Therefore, we exclude the last
        # generated token from outputs.sequences to match activations[1:]. Same for attentions.
        truncated_gen_ids = gen_ids[:,:-1] # (batch_size, gen_len-1)
        truncated_generation_attention_mask = generation_attention_mask[:,:-1] # (batch_size, gen_len-1)
        truncated_prompt_and_gen_attention_mask = prompt_and_gen_attention_mask[:,:-1] # (batch_size, prompt_len + gen_len-1)

        ###print(f"[INFO] Truncated gen_ids shape: {truncated_gen_ids.shape}")
        ###print(f"[INFO] Truncated generation_attention_mask shape: {truncated_generation_attention_mask.shape}")
        ###print(f"[INFO] Truncated prompt_and_gen_attention_mask shape: {truncated_prompt_and_gen_attention_mask.shape}")
        ###print(f"[INFO] Truncated gen_ids : {truncated_gen_ids}")
        ###print(f"[INFO] Truncated generation_attention_mask : {truncated_generation_attention_mask}")
        ###print(f"[INFO] Truncated prompt_and_gen_attention_mask : {truncated_prompt_and_gen_attention_mask}")

        # *******************************
        # START: loop on layers
        # *******************************
        # Final dictionary on all layers 
        save_layers_scores = {}

        for l, layer_idx in enumerate(layers):
            print(f"\n----- Layer {layer_idx} -----")

            activations = activations_lists[l]
            attentions = attentions_lists[l]

            print("============")
            print("[INFO] Length of activations:", len(activations))
            # Use a separate loop variable: `i` is the batch start index, reused as `batch_idx` when saving
            for step in range(len(activations)):
                print(f"[INFO] Shape  of activations[{step}]: {activations[step].shape}") 
            print("============")
            print("[INFO] Length of attentions:", len(attentions))
            for step in range(len(attentions)):
                print(f"[INFO] Shape  of attentions[{step}]: {attentions[step].shape}") 
            print("============")

            # Define prompt and generation hidden states 
            prompt_activations=activations[0]       # `[0]` to include only the prompt part 
            generation_activations=activations[1:]  # `[1:]` to exclude the prompt part 
            
            # Define prompt and generation attention maps
            prompt_attentions=attentions[0]         # `[0]` to include only the prompt part 
            generation_attentions=attentions[1:]    # `[1:]` to exclude the prompt part 

            ###print(f"[DEBUG] prompt_activations shape: {prompt_activations.shape}")
            ###print(f"[DEBUG] prompt_attentions shape: {prompt_attentions.shape}")

            # ===============================
            # Align generated and prompt hidden states
            # ===============================
            # For each decoding step, take the hidden state of the last position (the newly processed token)
            stacked_generation_activations = torch.stack(
                [h[:, -1, :] for h in generation_activations], dim=1
            ) # (batch_size, gen_len-1, hidden_size)

            ###print(f"[DEBUG] stacked_generation_activations shape: {stacked_generation_activations.shape}")

            # Concatenate the prompt and generation aligned hidden states, prompt first, so that
            # positions match `truncated_prompt_and_gen_attention_mask` and `skip_length=prompt_len`
            prompt_and_gen_activations = torch.cat(
                [prompt_activations,             # (batch_size, prompt_len, hidden_size)
                stacked_generation_activations], # (batch_size, gen_len-1, hidden_size)
                dim=1
            ) # (batch_size, prompt_len + gen_len-1, hidden_size)
            
            ###print(f"[DEBUG] prompt_and_gen_activations shape: {prompt_and_gen_activations.shape}")

            # ==============================
            # Extract token activations from captured layer, based on source
            # ==============================
            ###print(f"[INFO] Activation source: {activation_source}")
            
            if hidden_scores is not None and len(hidden_scores) > 0:
                if activation_source == "generation":
                    # Return only the token activations from the generated answer 
                    selected_token_vecs = extract_token_activations(                 ##### extract_token_activations_fn 
                            selected_layer=stacked_generation_activations, 
                            attention_mask=truncated_generation_attention_mask, 
                            device=stacked_generation_activations.device,
                            modes=hidden_scores,
                        ) # (batch_size, hidden_size)
                    
                elif activation_source == "prompt":    
                    # Return only the token activations from the prompt
                    selected_token_vecs = extract_token_activations(
                            selected_layer=prompt_activations, 
                            attention_mask=prompt_attention_mask, 
                            device=prompt_activations.device,
                            modes=hidden_scores,
                        ) # (batch_size, hidden_size)
                    
                else: # activation_source == "promptGeneration"
                    # Return token activations from the concatenated prompt + generated answer 
                    selected_token_vecs = extract_token_activations(
                            selected_layer=prompt_and_gen_activations, 
                            attention_mask=truncated_prompt_and_gen_attention_mask, 
                            device=prompt_and_gen_activations.device,
                            skip_length=prompt_len,
                            modes=hidden_scores,
                            # skip_length: exclude prompt from computation if 
                            # mode=='first_generated' in `extract_token_activations_fn`
                        ) # (batch_size, hidden_size)
                
                ###print(f"[RESULT] selected_token_vecs sample:\n{selected_token_vecs}")
                #result.update(selected_token_vecs) 
                hidden_results = {}
                for mode in hidden_scores:
                    if mode in selected_token_vecs:
                        hidden_results[mode] = selected_token_vecs[mode].cpu().numpy() #[vec.cpu().numpy() for vec in selected_token_vecs[mode]] 
                    #else:
                    #    hidden_results[mode] = None 
                #layers_scores[f"layer_{layer_idx}"] = {"hidden": hidden_results}
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"hidden": hidden_results})


        
            # ==============================
            # Extract attention eigen score
            # ==============================
            #attn_eig_prod = None
            if attn_scores is not None and 'attn_eig_prod' in attn_scores:
                attn_eig_prod = compute_attn_eig_prod(
                        prompt_attentions=prompt_attentions, 
                        generation_attentions=generation_attentions,
                        prompt_attention_mask=prompt_attention_mask, 
                        generation_attention_mask=truncated_generation_attention_mask,
                        mode=activation_source,
                )
                ###print(f"[RESULT] attn_eig_prod:\n{attn_eig_prod}")

                # Store result in dict
                #layers_scores[f"layer_{layer_idx}"] = {"attention": {"attn_eig_prod": attn_eig_prod}}
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"attention": {"attn_eig_prod": attn_eig_prod}})
            '''
            layers_scores[f"layer_{layer_idx}"] = {
            "hidden": hidden_results,
            "attention": {"attn_eig_prod": attn_eig_prod } #if attn_eig_prod is not None else None}
            }
            '''

            # ==============================
            # Extract logits scores
            # ==============================
            # For the last layer, compute the logits the regular way (forward pass for the prompt, 
            # `outputs.logits` for the generation): prompt activations obtained from a plain forward pass 
            # and from model.generate() differ slightly, which results in different logits. 
            if logit_scores is not None and len(logit_scores) > 0: 
                logits_results = {}
                if layer_idx != -1 and layer_idx != model.config.num_hidden_layers -1:
                    with torch.no_grad():
                        prompt_logits = apply_logit_lens(model, prompt_activations) # (batch, prompt_len, vocab_size)
                        gen_logits = apply_logit_lens(model, stacked_generation_activations) # (batch, gen_len-1, vocab_size)
                    # The first generated logit is missing (shape gen_len-1) -> use `prepend_last_prompt_logit=True`
                    # to retrieve it (see the docstring of `compute_perplexity` for more details)
                    prepend_last_prompt_logit = True
                else: #last layer
                    # ==============================
                    # Forward pass to the model to retrieve prompt logits 
                    # ==============================
                    with torch.no_grad():
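                        # Pass the attention mask so that padding tokens are ignored in the batched forward pass
                        prompt_logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits # (batch, prompt_len, vocab_size)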
                    gen_logits = torch.stack(outputs.logits, dim=1)  # (batch, gen_len, vocab_size)
                    prepend_last_prompt_logit = False

                print(f"[INFO] prompt_logits.shape: {prompt_logits.shape}")
                print(f"[INFO] gen_logits.shape: {gen_logits.shape}")

                if 'perplexity' in logit_scores:
                    perplexity = compute_perplexity(
                        prompt_logits=prompt_logits, 
                        gen_logits=gen_logits,
                        prompt_ids=prompt_ids, 
                        gen_ids=gen_ids,
                        prompt_attention_mask=prompt_attention_mask,
                        gen_attention_mask=generation_attention_mask,
                        prepend_last_prompt_logit=prepend_last_prompt_logit,
                        mode=activation_source,
                        min_k=None
                    )
                    logits_results['perplexity'] = perplexity

                if 'logit_entropy' in logit_scores:
                    if logit_config is None:
                        raise ValueError("logit_entropy is required but logit_config is None")
                    logit_entropy = compute_logit_entropy(
                        prompt_logits=prompt_logits,
                        gen_logits=gen_logits,
                        prompt_attention_mask=prompt_attention_mask,
                        gen_attention_mask=generation_attention_mask,
                        mode=activation_source,
                        prepend_last_prompt_logit=prepend_last_prompt_logit,
                        top_k=logit_config['top_k'],
                        window_size=None,
                        stride=None
                    )
                    logits_results['logit_entropy'] = logit_entropy

                if 'window_logit_entropy' in logit_scores:
                    if logit_config is None:
                        raise ValueError("window_logit_entropy is required but logit_config is None")
                    window_logit_entropy = compute_logit_entropy(
                        prompt_logits=prompt_logits,
                        gen_logits=gen_logits,
                        prompt_attention_mask=prompt_attention_mask,
                        gen_attention_mask=generation_attention_mask,
                        mode=activation_source,
                        prepend_last_prompt_logit=prepend_last_prompt_logit,
                        top_k=logit_config['top_k'], # default 50
                        window_size=logit_config['window_size'], # default 1
                        stride=logit_config['stride'] # default 1
                    )
                    logits_results['window_logit_entropy'] = window_logit_entropy

                if logits_results:
                    save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"logits": logits_results})

        # *******************************
        # END: loop on layers
        # *******************************

        #print(f"[RESULT] FINAL RESULT:\n{result}")
        #print(f"[RESULT] FINAL logits_scores:\n{logits_scores}")

        # ==============================
        # Store results (to file or memory)
        # ==============================
        batch_results = {
            "id": [s['id'] for s in batch],
            "original_indices": [s['original_index'] for s in batch],
            "context": [s['context'] for s in batch],
            "question": [s['question'] for s in batch],
            "gt_answers": [s['answers'] for s in batch],
            "gen_answers": gen_answers,
            "scores": {**save_layers_scores} #, **({"logits": save_logits_scores} if save_logits_scores else {})}
        }

        from src.data_reader.pickle_io import save_batch_pickle

        if save_to_pkl:
            #append_to_pickle(output_path, batch_results)
            save_batch_pickle(batch_data=batch_results, output_dir=output_path, batch_idx=i)
        else:
            all_batch_results.append(batch_results)

    if not save_to_pkl:
        return all_batch_results
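
The function above relies on `build_generation_attention_mask`, imported from `src.inference.generation_utils`, whose code is not shown in this section. The sketch below is a hypothetical list-based re-implementation of the behavior described in the comments (positions are valid up to and including the first EOS), followed by the one-token truncation used to align the mask with the hooked activations. The token IDs are illustrative and 2 is assumed to be the EOS ID.

```python
from typing import List


def generation_mask(gen_ids: List[List[int]], eos_token_id: int) -> List[List[int]]:
    """1 for tokens up to and including the first EOS, 0 afterwards (hypothetical helper)."""
    masks = []
    for row in gen_ids:
        # Index of the first EOS; if there is none, the whole row is valid
        first_eos = row.index(eos_token_id) if eos_token_id in row else len(row) - 1
        masks.append([1 if t <= first_eos else 0 for t in range(len(row))])
    return masks


gen_ids = [
    [29871, 443, 12011, 519, 2, 2, 2, 2, 2, 2],                 # stops early, then padded with EOS
    [29871, 6106, 292, 322, 6025, 3277, 5100, 2187, 29889, 2],  # EOS only at the last position
]
mask = generation_mask(gen_ids, eos_token_id=2)

# Drop the last position: the final generated token has no hidden state of its own
truncated_mask = [row[:-1] for row in mask]

for row in mask:
    print(row)
```

In the second row the only masked-in EOS sits at the last position, so after truncation every remaining position is valid.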

In [81]:
# Clear memory to avoid "CUDA out of memory"
# -----------------------------------
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

start_offset = 0
end_offset = 0
result = run_prompt_and_generation_score_extraction(
    model=model,
    tokenizer=tokenizer,
    dataset=id_fit_dataset,
    batch_size=2,
    idx_start_sample=0,
    max_samples=2,
    save_to_pkl = False,
    output_path = OUTPUT_DIR + "all_batch_resultsTEST.pkl",
    build_prompt_fn=build_prompt,
    layers = [18],  
    activation_source = "promptGeneration",
    hidden_scores=["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores=["attn_eig_prod"],
    logit_scores=["perplexity","logit_entropy", "window_logit_entropy"],
    logit_config={"top_k": 50, "window_size": 1, "stride": 1},
    start_offset = start_offset,
    end_offset = end_offset
)

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] prompt_ids shape: torch.Size([2, 281]), prompt_len: 281
[INFO] Starting generation...


100%|██████████| 1/1 [00:00<00:00,  3.19it/s]

[INFO] gen_ids shape: torch.Size([2, 10])
[INFO] Sample generated tokens: tensor([[29871,   443, 12011,   519,     2,     2,     2,     2,     2,     2],
        [29871,  6106,   292,   322,  6025,  3277,  5100,  2187, 29889,     2]],
       device='cuda:0')

----- Layer 18 -----
[INFO] Length of activations: 10
[INFO] Shape  of activations[0]: torch.Size([2, 281, 4096])
[INFO] Shape  of activations[1]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[2]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[3]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[4]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[5]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[6]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[7]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[8]: torch.Size([2, 1, 4096])
[INFO] Shape  of activations[9]: torch.Size([2, 1, 4096])
[INFO] Length of attentions: 10
[INFO] Shape  of attentions[0]: torch.Size([2, 32, 281, 281])
[INFO] Sha




In [82]:
result

[{'id': ['56be85543aeaaa14008c9063', '56be85543aeaaa14008c9065'],
  'original_indices': [0, 1],
  'context': ['Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".',
   'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she


In [91]:
### OK ###
from typing import Any, Dict
import glob
def load_and_merge_pickles(directory: str) -> Dict[str, Any]:
    """
    Load and recursively merge all batch pickle files from a directory into a single results dictionary.

    Each pickle file must contain a dictionary with the same structure across batches.

    The function merges nested dictionaries, concatenating values along the batch axis.
    Supported merge strategies:
        - Lists are extended (concatenated).
        - NumPy arrays are concatenated along the first dimension (axis=0).
        - Nested dictionaries are merged recursively.

    Parameters
    ----------
    directory : str
        Path to the directory containing batch pickle files. Files must match the pattern '*.pkl'.

    Returns
    -------
    Dict[str, Any]
        A recursively merged dictionary where:
            - Leaf values are lists or arrays aggregated from all batches.
            - Nested dictionaries (e.g., "scores" → "layer_0" → "hidden") are merged in depth.
    """
    def recursive_merge(dest: dict, src: dict):
        for key, value in src.items():
            if key not in dest:
                dest[key] = value if not isinstance(value, dict) else recursive_merge({}, value)
            else:
                if isinstance(value, dict) and isinstance(dest[key], dict):
                    recursive_merge(dest[key], value)
                elif isinstance(value, list):
                    dest[key].extend(value)
                elif hasattr(value, 'shape'):  # numpy array
                    import numpy as np
                    dest[key] = np.concatenate([dest[key], value], axis=0)
                else:
                    raise ValueError(f"Cannot merge key '{key}' with type {type(value)}")

        return dest

    merged = {}
    files = sorted(glob.glob(os.path.join(directory, "*.pkl")))
    for file in files:
        with open(file, "rb") as f:
            batch = pickle.load(f)
            recursive_merge(merged, batch)

    return merged

from src.data_reader.pickle_io import save_merged_pickle, load_pickle_batches
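To make the merge strategy above concrete, here is a minimal, self-contained sketch on two toy batch dictionaries. The helper `merge_two` and the toy keys are assumptions made for illustration and are not part of the project code. It follows the three rules listed in the docstring: lists are extended, NumPy arrays are concatenated on axis 0, and nested dictionaries are merged recursively.

```python
import copy
import numpy as np

def merge_two(dest: dict, src: dict) -> dict:
    """Merge `src` into `dest` in place (hypothetical helper for illustration)."""
    for key, value in src.items():
        if key not in dest:
            # Deep-copy so that later extends never mutate the source batch
            dest[key] = copy.deepcopy(value)
        elif isinstance(value, dict) and isinstance(dest[key], dict):
            merge_two(dest[key], value)
        elif isinstance(value, list):
            dest[key].extend(value)
        elif isinstance(value, np.ndarray):
            dest[key] = np.concatenate([dest[key], value], axis=0)
        else:
            raise ValueError(f"Cannot merge key '{key}' with type {type(value)}")
    return dest

# Two toy batches with the same structure (2 samples, then 1 sample)
batch_0 = {"id": ["a", "b"], "scores": {"layer_18": {"hidden": {"last": np.zeros((2, 4))}}}}
batch_1 = {"id": ["c"], "scores": {"layer_18": {"hidden": {"last": np.ones((1, 4))}}}}

merged = {}
for batch in (batch_0, batch_1):
    merge_two(merged, batch)

print(merged["id"])
print(merged["scores"]["layer_18"]["hidden"]["last"].shape)
```

The deep copy on first insertion is a design choice of this sketch: it keeps the input batches untouched when the merged lists are extended later.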

In [None]:
# Use a name other than `dir` to avoid shadowing the Python built-in
pkl_dir = OUTPUT_DIR + "all_batch_resultsTEST.pkl"
merged = load_and_merge_pickles(pkl_dir)
save_merged_pickle(merged, pkl_dir)

In [44]:
def run_prompt_score_extraction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    batch_size: int = 4,
    idx_start_sample: int = 0,
    max_samples: int = 1000,
    save_to_pkl: bool = False,
    output_path: str = "outputs/all_batch_results.pkl",
    build_prompt_fn: Callable[[str, str], str] = None,
    layers: List[int] = [-1],  
    hidden_scores: List[str] = ["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores: List[str] = ["attn_eig_prod"],
    logit_scores: List[str] = ["perplexity", "logit_entropy", "window_logit_entropy"],
    logit_config: dict = {"top_k": 50, "window_size": 1, "stride": 1},
    start_offset : int = 0,
    end_offset : int = 0,
) -> Union[List[dict], None]:
    """
    Runs batched inference on a dataset using a decoder-only language model.
    For each batch, it runs a forward pass on the prompt and extracts token-level hidden 
    activations, attention maps and logit scores from specified transformer layers.

    The function supports multiple aggregation modes for the activations (`hidden_scores`), attention-based 
    scores (`attn_scores`), and logit-based scores (`logit_scores`). The `logit_config` argument provides 
    configuration parameters for logit-based score functions.
    
    Hidden states and attention maps are captured via forward hooks, 
    then aggregated based on token position and attention masks.
    
    These activations are saved as individual batch files in a specified pickle directory, 
    allowing efficient incremental storage and later aggregation.
    Alternatively, the representations can be returned directly.

    Parameters
    ----------
    model : PreTrainedModel
        The causal language model to evaluate (e.g., LLaMA).
    tokenizer : PreTrainedTokenizer
        The corresponding tokenizer.
    dataset : Dataset
        The input dataset.
    batch_size : int
        Number of samples per batch.
    idx_start_sample : int
        Index of the first sample to process from the dataset.
    max_samples : int
        Total number of examples to process from the dataset, starting from idx_start_sample. 
    save_to_pkl : bool
        If True, each batch's results are saved as an individual pickle file in the output_path directory.
        If False, the function returns a list of per-batch result dictionaries.
    output_path : str
        Path to the directory where the extracted scores will be saved as individual pickle batch files.
    build_prompt_fn : Callable
        Function to build a prompt from context and question.
    layers : List[int]
        List of indices of the transformer layers to extract activations from (default: [-1] for last layer).
    hidden_scores : List[str], optional
        List of aggregation modes to compute on token activations. Possible modes include:
            "average", "last", "max", "first_generated", "token_svd_score", "feat_var".
        These modes are passed to `extract_token_activations` for aggregation. Default includes the above.
    attn_scores : List[str], optional
        List of attention-based scores to compute. Supported: "attn_eig_prod".
    logit_scores : List[str], optional
        List of logit-based scores to compute. Supported:
            "perplexity", "logit_entropy", "window_logit_entropy".
    logit_config : dict, optional
        Configuration dictionary for logit-based scoring functions, with keys such as:
            - "top_k": int, number of top logits considered (default 50)
            - "window_size": int, window size for windowed entropy (default 1)
            - "stride": int, stride for windowed entropy (default 1)
    start_offset : int
        Offset from the first non-padding token (must be >= 0). 
    end_offset : int
        Offset from the last non-padding token (must be <= 0, e.g., -3 to remove 3 tokens).
    
    Returns
    -------
    Union[List[dict], None]
        If `save_to_pkl` is False, returns a list of dictionaries, one per batch, with each element
         of the list having the following structure:
            {
                "id": List[str],  # IDs of batch samples
                "original_indices": List[int],  # Original dataset indices
                "context": List[str],
                "question": List[str],
                "gt_answers": List[str],        # Ground-truth reference answers
                "scores": {
                    "layer_{layer_idx}": {
                        "hidden": { 
                            "{mode}": np.ndarray[(batch_size, hidden_size), float], 
                            ... # one entry per mode in hidden_scores
                        },
                        "attention": {
                            "{attn_score}": np.ndarray[(batch_size,), float],  
                            ...
                        }
                    },
                    "logits": {
                        "perplexity": np.ndarray[(batch_size,), float],
                        "logit_entropy": np.ndarray[(batch_size,), float],
                        "window_logit_entropy": np.ndarray[(batch_size,), float] 
                    }
                }
            },

        If `save_to_pkl` is True, saves each batch's dictionary incrementally to disk and returns None.

    Notes
    -----
    This function only runs a forward pass on the prompt, so each hooked layer yields a single
    activation tensor of shape (batch_size, prompt_len, hidden_size).
    For comparison, when using model.generate() with output_hidden_states=True, use_cache=True and
    max_new_tokens=30 (the behaviour that the activation hook replicates in the generation variant),
    there is always an offset between the number of generated tokens
    (outputs.sequences[:, prompt_len:].shape[1]) and the number of single-token entries in outputs.hidden_states:
    * outputs.sequences.shape[1] = prompt_len (17) + max_new_tokens (30) = 47
    * len(outputs.hidden_states) = max_new_tokens (30)
        With:
        * outputs.hidden_states[0][layer_idx].shape = (batch_size, prompt_len, hidden_size)           --> includes the prompt!
        * outputs.hidden_states[i][layer_idx].shape = (batch_size, 1, hidden_size) with 1 <= i <= 29  --> stops at 29!
    Note that in our code, outputs.hidden_states and activations are the same.

    See the Hugging Face discussion from April 2024:
    https://github.com/huggingface/transformers/issues/30036
    """

        
    # ==============================================================  
    # [PATCH] Replace LlamaAttention.forward on target layers by
    #  custom module to extract attention weights
    # ==============================================================
    for idx in layers:
        attn_module = model.model.layers[idx].self_attn
        # Bind the patched function as a method of this attention module
        attn_module.forward = patched_LlamaAttention_forward.__get__(attn_module, attn_module.__class__)
        
    # ==============================================================  
    # [LOOP] Process batches of examples  
    # ==============================================================
    all_batch_results = []  

    for i in tqdm(range(idx_start_sample, idx_start_sample + max_samples, batch_size)):
      
        # ----------------------------------------------------------
        # [BATCH INPUT] Extract and tokenize prompts
        # ----------------------------------------------------------
        batch = extract_batch(dataset, i, batch_size)
        prompts = [build_prompt_fn(s["context"], s["question"]) for s in batch]
        inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        prompt_ids = inputs["input_ids"] # (batch_size, prompt_len)
        prompt_attention_mask = inputs["attention_mask"] 

        # ----------------------------------------------------------
        # [HOOKS] Register hooks to capture hidden states and attentions
        # ----------------------------------------------------------
        # This hook collects the hidden states. For layer l: 
        # activations_lists[l] = [act_prompt], 
        # activations_lists[l][0] of shape: (batch_size, prompt_len, hidden_size) 
        activations_lists = [[] for _ in layers]  # one empty list per layer 
        handle_act, call_counter_act = register_generation_activation_hook(model, activations_lists, layers)

        # This hook collects the attention maps. For layer l: 
        # attentions_lists[l] = [attn_prompt], 
        # attentions_lists[l][0] of shape: (batch_size, n_heads, prompt_len, prompt_len)
        attentions_lists = [[] for _ in layers]  # one empty list per layer
        handle_attn, call_counter_attn = register_generation_attention_hook(model, attentions_lists, layers)
        
        # ----------------------------------------------------------
        # [FORWARD PASS] Run model with hooks to capture intermediate states
        # ----------------------------------------------------------
        # Pass inputs through the model. When a target layer is reached,
        # its hook executes and appends the layer output to activations_lists / attentions_lists.
        # Logits are part of the standard model output; keep them only if logit scores are requested.
        with torch.no_grad():
            outputs = model(**inputs, return_dict=True)
        if logit_scores is not None and len(logit_scores) > 0:
            prompt_logits = outputs.logits
        
        # Remove hooks to avoid memory leaks or duplicate logging
        for h in handle_act: h.remove()
        for h in handle_attn: h.remove()
        
        # Verify that hooks worked properly
        verify_call_counters(call_counter_act, name="activation hooks")
        verify_call_counters(call_counter_attn, name="attention hooks")


        # ----------------------------------------------------------
        # [OFFSET] Modify prompt mask with offset, if specified
        # ----------------------------------------------------------
        if start_offset !=0 or end_offset !=0:
            prompt_attention_mask, start_indices, end_indices = compute_offset_attention_mask(
                attention_mask=prompt_attention_mask, 
                start_offset=start_offset, 
                end_offset=end_offset
            ) # (batch_size, prompt_len), (batch_size,), (batch_size,)


        # **********************************************************
        # [LAYER LOOP] Extract activation and attention-based scores for each specified layer 
        # **********************************************************
        save_layers_scores = {}

        for l, layer_idx in enumerate(layers):

            activations = activations_lists[l]
            attentions = attentions_lists[l]

            # Prompt hidden states: (batch_size, prompt_len, hidden_size)
            prompt_activations = activations[0]

            # Prompt attention maps: (batch_size, n_heads, prompt_len, prompt_len)
            prompt_attentions = attentions[0]

            # ------------------------------------------------------
            # [HIDDEN SCORES] Extract token-level activations/hidden-states
            # ------------------------------------------------------
            if hidden_scores is not None and len(hidden_scores) > 0:
                # Return only the token activations from the prompt
                selected_token_vecs = extract_token_activations(
                        selected_layer=prompt_activations, 
                        attention_mask=prompt_attention_mask, 
                        device=prompt_activations.device,
                        modes=hidden_scores,
                    ) # (batch_size, hidden_size)
 
                # Save results to dict
                hidden_results = {}
                for mode in hidden_scores:
                    if mode in selected_token_vecs:
                        hidden_results[mode] = selected_token_vecs[mode].cpu().numpy()
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"hidden": hidden_results})

            # ------------------------------------------------------
            # [ATTENTION SCORES] Extract attention eigenvalue-based metric
            # ------------------------------------------------------
            if attn_scores is not None and 'attn_eig_prod' in attn_scores:
                attn_eig_prod = compute_attn_eig_prod(
                        prompt_attentions=prompt_attentions, 
                        generation_attentions=None,
                        prompt_attention_mask=prompt_attention_mask, 
                        generation_attention_mask=None,
                        mode='prompt',
                )
                # Save results to dict
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"attention": {"attn_eig_prod": attn_eig_prod}}) 
        
        # **********************************************************
        # [END LAYER LOOP] 
        # **********************************************************

        save_logits_scores = {}
        # ------------------------------------------------------
        # [LOGIT SCORES] Compute metrics from model logits
        # ------------------------------------------------------
        if logit_scores is not None:
            if 'perplexity' in logit_scores:
                perplexity = compute_perplexity(
                    prompt_logits=prompt_logits, 
                    gen_logits=None,
                    prompt_ids=prompt_ids, 
                    gen_ids=None,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=None,
                    mode='prompt',
                    min_k=None
                )
                # Save results to dict
                save_logits_scores['perplexity'] = perplexity 

            if 'logit_entropy' in logit_scores:
                if logit_config is None:
                    raise ValueError("logit_entropy is required but logit_config is None")
                logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=None,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=None,
                    mode='prompt',
                    top_k=logit_config['top_k'], 
                    window_size=None,
                    stride=None
                )
                # Save results to dict
                save_logits_scores['logit_entropy'] = logit_entropy 
        
            if 'window_logit_entropy' in logit_scores:
                if logit_config is None:
                    raise ValueError("window_logit_entropy is required but logit_config is None")
                window_logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=None,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=None,
                    mode='prompt',
                    top_k=logit_config['top_k'],
                    window_size=logit_config['window_size'], 
                    stride=logit_config['stride'] 
                )
                # Save results to dict
                save_logits_scores['window_logit_entropy'] = window_logit_entropy 


        # ==========================================================
        # [OUTPUT] Store extracted results (to memory or file)
        # ==========================================================
        batch_results = {
            "id": [s['id'] for s in batch],
            "original_indices": [s['original_index'] for s in batch],
            "context": [s['context'] for s in batch],
            "question": [s['question'] for s in batch],
            "gt_answers": [s['answers'] for s in batch],
            "scores": {**save_layers_scores, **({"logits": save_logits_scores} if save_logits_scores else {})}
        }

        from src.data_reader.pickle_io import save_batch_pickle

        if save_to_pkl:
            save_batch_pickle(batch_data=batch_results, output_dir=output_path, batch_idx=i)
        else:
            all_batch_results.append(batch_results)

    if not save_to_pkl:
        return all_batch_results
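The "average", "max" and "last" aggregation modes named in the docstring reduce a `(batch_size, prompt_len, hidden_size)` activation tensor to one `(batch_size, hidden_size)` vector per sample while ignoring padding. The sketch below illustrates that idea in NumPy. The function `pool_tokens` is a hypothetical stand-in, not the project's `extract_token_activations`, and it assumes that every sample has at least one non-padding token.

```python
import numpy as np

def pool_tokens(acts: np.ndarray, mask: np.ndarray) -> dict:
    """Aggregate (batch, seq_len, hidden) activations over non-padding tokens.

    Hypothetical illustration: returns one (batch, hidden) array per mode.
    """
    m = mask[:, :, None].astype(bool)                  # (batch, seq_len, 1)
    counts = mask.sum(axis=1, keepdims=True)           # number of real tokens per sample
    average = (acts * m).sum(axis=1) / counts
    maximum = np.where(m, acts, -np.inf).max(axis=1)   # padding can never win the max
    # Index of the last non-padding token, valid for left or right padding
    seq_len = mask.shape[1]
    last_idx = seq_len - 1 - np.argmax(mask[:, ::-1], axis=1)
    last = acts[np.arange(acts.shape[0]), last_idx]
    return {"average": average, "max": maximum, "last": last}

acts = np.arange(12, dtype=float).reshape(2, 3, 2)
mask = np.array([[1, 1, 0],    # right-padded sample
                 [0, 1, 1]])   # left-padded sample
pooled = pool_tokens(acts, mask)
print({name: vec.shape for name, vec in pooled.items()})
```

Locating the last real token from the reversed mask, rather than assuming it sits at position `-1`, keeps the sketch independent of the tokenizer's padding side.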


In [45]:
# Clear memory to avoid "CUDA out of memory"
# -----------------------------------
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

start_offset = 0
end_offset = 0
result = run_prompt_score_extraction(
    model=model,
    tokenizer=tokenizer,
    dataset=id_fit_dataset,
    batch_size=2,
    idx_start_sample=0,
    max_samples=2,
    save_to_pkl = False,
    output_path = OUTPUT_DIR + "all_batch_resultsTEST.pkl",
    build_prompt_fn=build_prompt,
    layers = [18,-1],  
    hidden_scores=["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores=["attn_eig_prod"],
    logit_scores=["perplexity", "logit_entropy", "window_logit_entropy"],
    logit_config={"top_k": 50, "window_size": 1, "stride": 1},
    start_offset = start_offset,
    end_offset = end_offset
)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 10.02it/s]
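The `start_offset` and `end_offset` arguments used in the call above trim tokens from the start and the end of the non-padding span before scores are computed. As a rough illustration of that masking step, here is a NumPy sketch. The function `offset_mask` is hypothetical: unlike the project's `compute_offset_attention_mask` it returns only the mask, and it assumes that the real tokens of each row are contiguous and that each row has at least one of them.

```python
import numpy as np

def offset_mask(attention_mask: np.ndarray, start_offset: int = 0, end_offset: int = 0) -> np.ndarray:
    """Zero out the first `start_offset` and the last `-end_offset` real tokens of each row.

    Hypothetical illustration of the offset convention described in the docstring.
    """
    if start_offset < 0 or end_offset > 0:
        raise ValueError("start_offset must be >= 0 and end_offset must be <= 0")
    seq_len = attention_mask.shape[1]
    first = np.argmax(attention_mask, axis=1)                        # first real token
    last = seq_len - 1 - np.argmax(attention_mask[:, ::-1], axis=1)  # last real token
    positions = np.arange(seq_len)[None, :]
    keep = (positions >= (first + start_offset)[:, None]) & (positions <= (last + end_offset)[:, None])
    return attention_mask * keep

mask = np.array([[0, 1, 1, 1, 1],    # left-padded sample
                 [1, 1, 1, 1, 0]])   # right-padded sample
trimmed = offset_mask(mask, start_offset=1, end_offset=-1)
print(trimmed)
```

With both offsets at 0, as in the call above, the mask is returned unchanged.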


In [41]:
result

[{'id': ['56be85543aeaaa14008c9063', '56be85543aeaaa14008c9065'],
  'original_indices': [0, 1],
  'context': ['Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".',
   'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she

## Test that all modules work correctly

In [22]:
from src.inference.generation_utils import build_prompt

from src.inference.run_extraction import run_prompt_and_generation_score_extraction

result = run_prompt_and_generation_score_extraction(
    model=model,
    tokenizer=tokenizer,
    dataset=id_fit_dataset,
    batch_size=2,
    idx_start_sample= 0,
    max_samples= 2,
    save_to_pkl = False,
    output_path = "outputs/all_batch_results.pkl",
    build_prompt_fn=build_prompt,
    layers = [18,-1],  
    activation_source = "promptGeneration",
    hidden_scores=["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores=["attn_eig_prod"],
    logit_scores=["perplexity", "logit_entropy", "window_logit_entropy"],
    logit_config={"top_k": 50, "window_size": 1, "stride": 1},
    start_offset = 0,
    end_offset = 0
)


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:01<00:00,  1.10s/it]
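The `perplexity` score requested in the call above is computed from the model logits. As a reminder of the standard definition for a causal language model, the sketch below computes the exponential of the mean negative log-likelihood of each token given the previous ones. The function `prompt_perplexity` is a hypothetical NumPy illustration and may differ from the project's `compute_perplexity` (for example in how padding or `min_k` are handled). It assumes each sample has at least two consecutive real tokens.

```python
import numpy as np

def prompt_perplexity(logits: np.ndarray, token_ids: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Per-sample perplexity of a prompt under a causal LM (illustrative sketch).

    logits: (batch, seq_len, vocab); token_ids and mask: (batch, seq_len).
    """
    # Numerically stable log-softmax over the vocabulary
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Logits at position t score the token at position t + 1
    pred = log_probs[:, :-1, :]
    target = token_ids[:, 1:]
    token_ll = np.take_along_axis(pred, target[:, :, None], axis=-1)[:, :, 0]
    # Count a position only if both the context token and the target token are real
    valid = (mask[:, :-1] * mask[:, 1:]).astype(float)
    mean_nll = -(token_ll * valid).sum(axis=1) / valid.sum(axis=1)
    return np.exp(mean_nll)

# Sanity check: uniform logits over a vocabulary of size 5
logits = np.zeros((1, 4, 5))
token_ids = np.array([[0, 1, 2, 3]])
mask = np.ones((1, 4), dtype=int)
ppl = prompt_perplexity(logits, token_ids, mask)
print(ppl)
```

A uniform distribution over the vocabulary is a convenient check, since its perplexity equals the vocabulary size.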


In [23]:
result

[{'id': ['56be85543aeaaa14008c9063', '56be85543aeaaa14008c9065'],
  'original_indices': [0, 1],
  'context': ['Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".',
   'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she

### Functions from the LLM-Check paper

In [18]:
# Functions exactly as in the LLM-Check paper

"""
Length of outputs.hidden_states:  7 
( = number of generated tokens - 1 element to exclude last generated token + 1 element for prompt)
L : layer, batch_size = 2, hidden_size = 4096, prompt_len = 329
Shape of outputs.hidden_states[0][L]: torch.Size([2, 329, 4096])
Shape of activations[1][L]: torch.Size([2, 1, 4096])
Shape of activations[-1][L]: torch.Size([2, 1, 4096])

hidden_act = [x[0].to(torch.float32).detach().cpu() for x in outputs.hidden_states]

For each layer (layer_num), an SVD score is computed on the hidden
activations of that layer (for each sample).
get_svd_eval returns one score per sample for this layer.
These scores are stored in indiv_scores[mt]["HlyX"] (X = layer number).
So:

**** One score per layer, per sample (not per head, because hidden states are not split by head). ****
"""
# in compute_scores()
for layer_num in range(1, len(hidden_acts[0])):
    mt_score.append(get_svd_eval(hidden_acts, layer_num, tok_lens, use_toklens)[0])
    indiv_scores[mt]["Hly" + str(layer_num)].append(mt_score[-1])


def get_svd_eval(hidden_acts, layer_num=15, tok_lens=[], use_toklens=True):
    """Evaluate hidden states at a given layer using SVD-based scoring.

    For each sample, this function extracts the hidden states at a specified layer,
    optionally slices them according to `tok_lens`, and computes the SVD-based score.

    Args:
        hidden_acts (list): A list of tuples, each containing hidden states for all layers
            for a single sample.
        layer_num (int, optional): The layer index to evaluate. Defaults to 15.
        tok_lens (list, optional): A list of (start, end) indices for each sample to slice
            the hidden states. Defaults to [].
        use_toklens (bool, optional): Whether to slice the hidden states using `tok_lens`.
            Defaults to True.

    Returns:
        np.array: An array of SVD-based scores for each sample.
    """
    svd_scores = []
    for i in range(len(hidden_acts)): # loop over the samples
        Z = hidden_acts[i][layer_num] # activations of sample i at layer layer_num, shape (seq_len, hidden_size)

        if use_toklens and tok_lens[i]:
            i1, i2 = tok_lens[i][0], tok_lens[i][1]
            Z = Z[i1:i2, :]

        Z = torch.transpose(Z, 0, 1)
        svd_scores.append(centered_svd_val(Z).item())
    # print("Sigma matrix shape:",Z.shape[1])
    return np.stack(svd_scores)


def centered_svd_val(Z, alpha=0.001):
    """Compute the mean log singular value of a centered covariance matrix.

    This function centers the data and computes the singular value decomposition
    (SVD) of the resulting covariance matrix. It then returns the mean of the
    log singular values, regularized by `alpha`.

    Args:
        Z (torch.Tensor): A 2D tensor representing features hidden acts.
        alpha (float, optional): Regularization parameter added to the covariance matrix.
            Defaults to 0.001.

    Returns:
        float: The mean of the log singular values of the centered covariance matrix.
    """
    # assumes Z is in full precision
    # J is the centering matrix: J @ Z subtracts from each column of Z its mean over the rows.
    # Centering lets us study the variance without the bias of a non-zero mean.
    J = torch.eye(Z.shape[0]) - (1 / Z.shape[0]) * torch.ones(Z.shape[0], Z.shape[0])
    # Compute column-centered covariance matrix of Z
    Sigma = torch.matmul(torch.matmul(Z.t(), J), Z)
    # Regularization for stabilization
    Sigma = Sigma + alpha * torch.eye(Sigma.shape[0])
    # Singular Value Decomposition
    svdvals = torch.linalg.svdvals(Sigma)
    # Final Score
    eigscore = torch.log(svdvals).mean() # multiplication by 2 missing from the paper ? 
    return eigscore

NameError: name 'hidden_acts' is not defined
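The score computed by `centered_svd_val` can be checked on a toy matrix without any model outputs. The sketch below restates the same formula in NumPy (mean log singular value of `Z^T J Z + alpha * I`, with `J` the centering matrix) and verifies one property that follows from `J @ 1 = 0`: adding a constant to each column of `Z` does not change the score. The function name `centered_svd_score` and the toy data are assumptions made for illustration.

```python
import numpy as np

def centered_svd_score(Z: np.ndarray, alpha: float = 0.001) -> float:
    """Mean log singular value of Z^T J Z + alpha * I (NumPy restatement for illustration)."""
    n = Z.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n              # centering matrix, J @ ones = 0
    sigma = Z.T @ J @ Z + alpha * np.eye(Z.shape[1])
    svals = np.linalg.svd(sigma, compute_uv=False)   # singular values only
    return float(np.log(svals).mean())

rng = np.random.default_rng(0)
Z = rng.normal(size=(8, 3))                          # toy matrix: 8 rows, 3 columns
shift = np.array([10.0, -5.0, 2.0])                  # a different constant per column

score = centered_svd_score(Z)
score_shifted = centered_svd_score(Z + shift)
print(np.isclose(score, score_shifted))
```

Because `J` is symmetric and idempotent, `Z^T J Z` equals `(J Z)^T (J Z)`, which is why the regularized matrix is positive definite and the logarithm of its singular values is well defined.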

In [11]:
# Run the functions from the LLM-Check paper on my machine

import numpy as np

def get_model_vals(model, tok_in):
    """Run the model forward pass to obtain logits, hidden states, and attention scores.

    Args:
        model: A pretrained model compatible with the transformers API.
        tok_in (torch.Tensor): A tensor of tokenized input IDs.

    Returns:
        tuple: A tuple `(logits, hidden_states, attentions)` where:
        logits (torch.Tensor): Output logits from the model.
        hidden_states (tuple of torch.Tensor): Hidden states from each model layer.
        attentions (tuple of torch.Tensor): Attention weights from each model layer.
    """
    kwargs = {
        "input_ids": tok_in,
        "use_cache": False,
        "past_key_values": None,
        "output_attentions": True,
        "output_hidden_states": True,
        "return_dict": True,
    }
    with torch.no_grad():
        output = model(**kwargs)
    return output.logits, output.hidden_states, output.attentions



#def compute_scores(logits, hidden_acts, attns, scores,  mt_list, tok_ins, indiv_scores=None, tok_lens=[], use_toklens=False):
def compute_scores(logits, hidden_acts, attns, tok_ins, mt_list = ['hidden'], indiv_scores=None, tok_lens=[], use_toklens=False):
    """Compute various evaluation scores (e.g., perplexity, entropy, SVD scores) from model outputs.

    This function takes model outputs (logits, hidden states, attentions) and computes
    a list of metric scores defined by `mt_list`. In the original code, the computed scores are
    appended to the `scores` list and the `indiv_scores` dictionary; in this local version,
    those appends are commented out.

    NOTE: In the original code, the `indiv_scores` dictionary is saved to disk and then used for the
    final metric computation in the check scores notebook (ipynb).

    Args:
        logits: Model logits.
        hidden_acts: Hidden activations.
        attns: Attention matrices.
        tok_ins: A list of tokenized inputs for each sample.
        mt_list (list): A list of metric types to compute ("logit", "hidden" or "attns").
        indiv_scores (dict): A dictionary to store metric-specific scores for each sample.
        tok_lens: A list of tuples indicating the start and end token indices for each sample.
        use_toklens (bool, optional): Whether to use `tok_lens` to slice sequences. Defaults to False.

    Raises:
        ValueError: If an invalid metric type is encountered in `mt_list`.
    """
    j=0
    sample_scores = []
    for mt in mt_list:
        mt_score = []
        if mt == "logit":
            mt_score.append(perplexity(logits, tok_ins, tok_lens)[0])
            #indiv_scores[mt]["perplexity"].append(mt_score[-1])

            #mt_score.append(window_logit_entropy(logits, tok_lens, w=1)[0])
            #indiv_scores[mt]["window_entropy"].append(mt_score[-1])

            #mt_score.append(logit_entropy(logits, tok_lens, top_k=50)[0])
            #indiv_scores[mt]["logit_entropy"].append(mt_score[-1])

        elif mt == "hidden":
            print("=============== j ===============", j)
            j+=1
            for layer_num in range(1, len(hidden_acts[0])):
                print("****** layer_num: *******", layer_num)
                mt_score.append(get_svd_eval(hidden_acts, layer_num, tok_lens, use_toklens)[0])
                #indiv_scores[mt]["Hly" + str(layer_num)].append(mt_score[-1])

        elif mt == "attns":
            for layer_num in range(1, len(attns[0])):
                mt_score.append(get_attn_eig_prod(attns, layer_num, tok_lens, use_toklens)[0])
                #indiv_scores[mt]["Attn" + str(layer_num)].append(mt_score[-1])

        else:
            raise ValueError("Invalid method type")

        sample_scores.extend(mt_score)
    #scores.append(sample_scores)
    # Return the collected scores; otherwise they are computed and then discarded
    return sample_scores

def centered_svd_val(Z, alpha=0.001):
    """Compute the mean log singular value of a centered covariance matrix.

    This function centers the data and computes the singular value decomposition
    (SVD) of the resulting covariance matrix. It then returns the mean of the
    log singular values, regularized by `alpha`.

    Args:
        Z (torch.Tensor): A 2D tensor of hidden activations, shape (hidden_size, seq_len)
            as passed by `get_svd_eval`.
        alpha (float, optional): Regularization parameter added to the covariance matrix.
            Defaults to 0.001.

    Returns:
        torch.Tensor: A 0-dim tensor holding the mean of the log singular values of the
            centered covariance matrix.
    """
    # assumes Z is in full precision
    print("Z.shape[0]: ", Z.shape[0])
    print("--Z.shape: ", Z.shape)
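    n = Z.shape[0]
    # Build the centering matrix on the same device and dtype as Z
    J = torch.eye(n, device=Z.device, dtype=Z.dtype) - (1 / n) * torch.ones(n, n, device=Z.device, dtype=Z.dtype)
    print("J.shape: ", J.shape)
    Sigma = torch.matmul(torch.matmul(Z.t(), J), Z)
    Sigma = Sigma + alpha * torch.eye(Sigma.shape[0], device=Z.device, dtype=Z.dtype)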
    print("Sigma.shape: ", Sigma.shape)
    svdvals = torch.linalg.svdvals(Sigma)
    eigscore = torch.log(svdvals).mean()
    return eigscore

def get_svd_eval(hidden_acts, layer_num=15, tok_lens=[], use_toklens=False):
    """Evaluate hidden states at a given layer using SVD-based scoring.

    For each sample, this function extracts the hidden states at a specified layer,
    optionally slices them according to `tok_lens`, and computes the SVD-based score.

    Args:
        hidden_acts (list): A list of tuples, each containing hidden states for all layers
            for a single sample.
        layer_num (int, optional): The layer index to evaluate. Defaults to 15.
        tok_lens (list, optional): A list of (start, end) indices for each sample to slice
            the hidden states. Defaults to [].
        use_toklens (bool, optional): Whether to slice the hidden states using `tok_lens`.
            Defaults to False.

    Returns:
        np.array: An array of SVD-based scores for each sample.
    """
    svd_scores = []
    print("len(hidden_acts): ", len(hidden_acts))
    for i in range(len(hidden_acts)):
        print("i: ", i)
        Z = hidden_acts[i][layer_num]
        print("Z.shape: ", Z.shape) # (seq_len, hidden_size)

        if use_toklens and tok_lens[i]:
            i1, i2 = tok_lens[i][0], tok_lens[i][1]
            Z = Z[i1:i2, :]

        Z = torch.transpose(Z, 0, 1) # (hidden_size, seq_len)
        print("Z.T.shape: ", Z.shape)
        svd_scores.append(centered_svd_val(Z).item())
        print("len(svd_scores)", len(svd_scores))
        print("svd_scores: ", svd_scores)
    # print("Sigma matrix shape:",Z.shape[1])
    return np.stack(svd_scores)


def perplexity(logits, tok_ins, tok_lens, min_k=None):
    """Compute the perplexity of model predictions for given tokenized inputs.

    This function computes the perplexity by taking the negative log probability
    of the correct tokens and exponentiating the mean. If `min_k` is provided,
    it filters the lowest probabilities to compute a restricted perplexity.

    Args:
        logits: A list or array of model logits (samples x seq_len x vocab_size).
        tok_ins: A list of tokenized input IDs for each sample.
        tok_lens (list): A list of (start, end) indices specifying the portion of the
            sequence to evaluate.
        min_k (float, optional): A fraction of tokens to consider from the lowest
            probabilities. If not None, only these tokens are considered.

    Returns:
        np.array: An array of perplexity values for each sample.
    """
    softmax = torch.nn.Softmax(dim=-1)
    ppls = []

    
    for i in range(len(logits)):
        print("logits[i]: ", logits[i])
        print("logits[i].shape: ", logits[i].shape)
        
        i1, i2 = tok_lens[i][0], tok_lens[i][1]
        
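        # log_softmax in float32 is numerically safer than log(softmax(...)) on float16 logits.
        # Position t - 1 predicts token t, so i1 must be >= 1 (index -1 would wrap around).
        pr = torch.log_softmax(logits[i].float(), dim=-1)[torch.arange(i1, i2) - 1, tok_ins[i][0, i1:i2]]
        if min_k is not None:
            pr = torch.topk(pr, k=max(1, int(min_k * len(pr))), largest=False).values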
        ppls.append(torch.exp(-pr.mean()).item())

    return np.stack(ppls)




def get_attn_eig_prod(attns, layer_num=15, tok_lens=[], use_toklens=True):
    """Compute an eigenvalue-based attention score by analyzing attention matrices.

    This function takes the attention matrices of a given layer and for each sample,
    computes the mean log of the diagonal elements (assumed to be eigenvalues) of each
    attention head, then sums these values over all heads. Slices are applied if
    `tok_lens` is used.

    Args:
        attns (list): A list of tuples, each containing attention matrices for all layers
            and heads for a single sample.
        layer_num (int, optional): The layer index to evaluate. Defaults to 15.
        tok_lens (list, optional): A list of (start, end) indices for each sample to slice
            the attention matrices. Defaults to [].
        use_toklens (bool, optional): Whether to slice the attention matrices using `tok_lens`.
            Defaults to True.

    Returns:
        np.array: An array of computed attention-based eigenvalue scores for each sample.
    """
    attn_scores = []

    for i in range(len(attns)):  # iterating over number of samples
        eigscore = 0.0
        counter = 0
        for attn_head_num in range(len(attns[i][layer_num])):  # iterating over number of attn heads
            counter += 1
            # attns[i][layer_num][j] is of size seq_len x seq_len = [10,10] if 10 tokens in the sentence
            Sigma = attns[i][layer_num][attn_head_num]
            #print("Attention, Sigma.shape: ", Sigma.shape)

            if use_toklens and tok_lens[i]:
                i1, i2 = tok_lens[i][0], tok_lens[i][1]
                Sigma = Sigma[i1:i2, i1:i2]

            eigscore += torch.log(torch.diagonal(Sigma, 0)).mean()
            #print("eigscore: ", eigscore)
             
        attn_scores.append(eigscore.item())
        #print("len(attn_scores): ", len(attn_scores))

    # Stack once, after the loop, so `res` is defined even when `attns` is empty
    res = np.stack(attn_scores) if attn_scores else np.array([])
    #print("res.shape: ", res.shape)
    return res


prompts = ["Je suis une très jolie fleur"]
inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
tok_in  = inputs['input_ids']
print("tok_in: ", tok_in)
print("tok_in.shape: ", tok_in.shape)

logit, hidden_act, attn = get_model_vals(model, tok_in.to(0))
print("len(logit): ", len(logit))
print("logit.shape :", logit.shape)

# Unpacking the values into lists on CPU
logit = logit[0].cpu()
hidden_act = [x[0].to(torch.float32).detach().cpu() for x in hidden_act]
attn = [x[0].to(torch.float32).detach().cpu() for x in attn]
print("=============len(attn):", len(attn))
tok_in = tok_in.cpu()

compute_scores(
    [logit],
    hidden_acts=[hidden_act], # hidden_acts is the list of per-layer activations for this sample
    attns= [attn],
    mt_list= ['logit'], # ['hidden', 'attns', 'logit'],
    tok_ins=[tok_in]
)






tok_in:  tensor([[    1,  2581, 26099,  1597,  9577,   432,   324,   347,  9115,   332]],
       device='cuda:0')
tok_in.shape:  torch.Size([1, 10])




len(logit):  1
logit.shape : torch.Size([1, 10, 32000])
logits[i]:  tensor([[ 1.0217e-01, -2.1973e-01,  3.1348e-01,  ...,  1.3281e+00,
          1.8799e+00,  6.4502e-01],
        [-6.7344e+00, -6.9375e+00, -1.1934e+00,  ..., -2.8359e+00,
         -7.2969e+00, -3.1680e+00],
        [-3.2051e+00, -3.1211e+00, -3.9014e-01,  ..., -2.3281e+00,
         -7.9648e+00, -3.3086e+00],
        ...,
        [-3.1445e+00, -4.4258e+00,  2.3438e+00,  ..., -4.1289e+00,
         -7.4336e+00, -1.7598e+00],
        [-3.5176e+00, -2.7402e+00,  4.0391e+00,  ..., -1.5322e+00,
         -2.4648e+00, -5.1270e-03],
        [-4.9492e+00, -4.7539e+00,  3.8203e+00,  ..., -4.7812e+00,
         -6.2500e+00, -4.0703e+00]], dtype=torch.float16)
logits[i].shape:  torch.Size([10, 32000])


IndexError: list index out of range
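
The `IndexError` above is consistent with `perplexity` reading `tok_lens[i]` while `compute_scores` was called without `tok_lens`, so the default empty list was indexed. The sketch below is a minimal, self-contained illustration of the same perplexity computation with an explicit span. It uses NumPy instead of torch, and `log_softmax` and `sequence_perplexity` are hypothetical helper names, not functions from this notebook.

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log-softmax over the last axis
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def sequence_perplexity(logits, token_ids, span):
    """Perplexity of token_ids[i1:i2], each token predicted from the previous position."""
    i1, i2 = span
    if i1 < 1:
        # Position t - 1 predicts token t; with i1 == 0 the index -1 would silently wrap around
        raise ValueError("span must start at 1 or later")
    logp = log_softmax(logits.astype(np.float64))
    positions = np.arange(i1, i2)
    token_logp = logp[positions - 1, token_ids[positions]]
    return float(np.exp(-token_logp.mean()))

rng = np.random.default_rng(0)
seq_len, vocab = 10, 50
logits = rng.normal(size=(seq_len, vocab))
token_ids = rng.integers(0, vocab, size=seq_len)

# Explicit span: skip the first token, which has no preceding context
ppl = sequence_perplexity(logits, token_ids, span=(1, seq_len))

# Sanity check: uniform logits give a perplexity equal to the vocabulary size
uniform_ppl = sequence_perplexity(np.zeros((seq_len, vocab)), token_ids, span=(1, seq_len))
```

In the notebook's own call, the equivalent change would presumably be to pass something like `tok_lens=[(1, tok_in.shape[1])]` to `compute_scores`; that span is an assumption about which tokens should be scored.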

## Draft

In [23]:
    # ================
    elif mode == "cov_svd":
        # Compute the mean of the log of singular values of the centered covariance 
        # for each sample in the batch, taking into account only valid tokens.
        svd_scores = []
        batch_size = selected_layer.shape[0]
        for i in range(batch_size):
            # Select valid tokens 
            mask = attention_mask[i].bool()
            Z = selected_layer[i][mask]  # (num_valid_tokens, hidden_size)
            if Z.shape[0] == 0:
                svd_scores.append(float('nan'))
                continue
            # Transpose to have (hidden_size, num_valid_tokens)
            Z = Z.transpose(0, 1)
            # Assumes Z is in full precision
            # Center each token vector (subtract its mean over the hidden dimension)
            # and compute the covariance between tokens.
            # This measures variance without the bias of a non-zero mean.
            d = Z.shape[0] # hidden_size
            J = torch.eye(d, device=Z.device) - (1 / d) * torch.ones(d, d, device=Z.device)
            # Compute centered covariance matrix of Z: (num_valid_tokens, num_valid_tokens).
            # J is (d, d), so it must sit between Z.t() and Z for the shapes to match.
            Sigma = torch.matmul(torch.matmul(Z.t(), J), Z)
            # Regularization for stabilization
            Sigma = Sigma + alpha * torch.eye(Sigma.shape[0], device=Z.device)
            # Singular Value Decomposition
            svdvals = torch.linalg.svdvals(Sigma)
            # Final Score
            eigscore = torch.log(svdvals).mean() # mult by 2 missing from the paper? 
            svd_scores.append(eigscore.item())
        aggregated_tokens = np.array(svd_scores)

SyntaxError: invalid syntax (694832033.py, line 2)
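
The draft cell above builds a centering matrix `J = I - (1/d) * 1 1^T` and scores a sample by the mean log singular value of the regularized matrix `Z^T J Z + alpha * I`. The following sketch checks, on random data, that multiplying by `J` is the same as subtracting each token's mean over the hidden dimension, and that for this symmetric positive-definite matrix the score equals the log-determinant divided by the number of tokens. It uses NumPy rather than torch, and `mean_log_singular_value` is a hypothetical name introduced for illustration.

```python
import numpy as np

def mean_log_singular_value(Z, alpha=1e-3):
    """Z: (hidden_size, num_tokens). Returns (score, Sigma) with Sigma = Z^T J Z + alpha * I."""
    d = Z.shape[0]
    J = np.eye(d) - np.ones((d, d)) / d              # centering matrix over the hidden dimension
    sigma = Z.T @ J @ Z                              # (num_tokens, num_tokens)
    sigma = sigma + alpha * np.eye(sigma.shape[0])   # regularization keeps log() finite
    svals = np.linalg.svd(sigma, compute_uv=False)
    return float(np.log(svals).mean()), sigma

rng = np.random.default_rng(0)
Z = rng.normal(size=(8, 5))                          # hidden_size=8, num_tokens=5
score, sigma = mean_log_singular_value(Z)

# Same matrix obtained by explicit mean subtraction instead of the J product
Zc = Z - Z.mean(axis=0, keepdims=True)
sigma_direct = Zc.T @ Zc + 1e-3 * np.eye(5)

# For a symmetric positive-definite matrix, mean log singular value = logdet / n
sign, logdet = np.linalg.slogdet(sigma)
```

Explicit mean subtraction avoids materializing a `(hidden_size, hidden_size)` matrix, which may matter when `hidden_size` is in the thousands.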

In [None]:
from transformers import PreTrainedTokenizer, PreTrainedModel
import torch
from torch.utils.hooks import RemovableHandle
from typing import Tuple, Literal, List, Optional, Dict

def extract_token_activations(
    selected_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    device: torch.device,
    modes: List[Literal[
        "average", "last", "max", "first_generated", 
        "token_cov_svd", "feat_cov_svd", 
        "token_cov_stats", "feat_cov_stats", "feat_cov_var"
    ]] = ["average"],
    skip_length: Optional[int] = None,
    alpha: float = 0.001,
) -> Dict[str, torch.Tensor]:
    """   
    Aggregate token-level activations over a specified span for each sequence in a batch,
    using various aggregation modes and attention mask.

    This function takes as input:
      - The layer activations (selected_layer) for each token in a batch of sequences,
      - An attention mask (attention_mask) of the same shape, where 1 indicates tokens to include
        in the aggregation and 0 marks tokens to ignore.

    The attention mask may be the original model mask, or a custom mask generated using
    `compute_offset_attention_mask` to dynamically select a sub-span of tokens.

    Parameters
    ----------
    selected_layer : torch.Tensor
        Tensor of shape (batch_size, seq_len, hidden_size) containing model activations for each token.
    attention_mask : torch.Tensor
        Attention mask of shape (batch_size, seq_len),  1 for real tokens, 0 for padding.
    device : torch.device
        Device for computation.
    modes : List[str]
        List of aggregation modes to compute. Computed using only valid tokens where attention_mask == 1.
        Supported:
        - "average": Mean activation vector across valid tokens. Shape: (batch_size, hidden_size)
        - "max": Element-wise max activation across valid tokens. Shape: (batch_size, hidden_size)
        - "last": Activation vector of last valid token in each sequence. Shape: (batch_size, hidden_size)
        - "first_generated": Activation of the first generated valid token in each sequence. Shape: (batch_size, hidden_size)
             If skip_length is provided, selects the token starting from that offset. 
        - "token_cov_svd": Mean log singular value of the centered token covariance matrix. Shape: (batch_size,)
        - "feat_cov_svd": Mean log singular value of the centered feature covariance matrix. Shape: (batch_size,)
        - "token_cov_stats": Statistics (mean, std, min, max) of the centered token covariance matrix. Shape: (batch_size, 4)
        - "feat_cov_stats": Statistics (mean, std, min, max) of the centered feature covariance matrix. Shape: (batch_size, 4)
        - "feat_cov_var": Diagonal of the centered feature covariance matrix (variances). Shape: (batch_size, hidden_size)

    skip_length : Optional[int]
        If provided, used to explicitly select the first generated token (useful for "first_generated" mode).
    alpha : float
        Regularization parameter added to the covariance matrix.

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary mapping each mode to its result:
            - (batch_size, hidden_size) for "average", "max", "last", "first_generated", "feat_cov_var"
            - (batch_size,) for "token_cov_svd", "feat_cov_svd"
            - (batch_size, 4) for "token_cov_stats", "feat_cov_stats"

    NOTE: the `token_cov_svd` score is taken from:
    "LLM-Check: Investigating Detection of Hallucinations in Large Language Models"
    (Sriramanan et al. 2024)
    """

    batch_size, seq_len, hidden_size = selected_layer.shape
    print(f"batch_size: {batch_size}, hidden_size: {hidden_size}")
    aggregated_tokens = {}
    
    # Move the mask to the same device as the activations
    attention_mask = attention_mask.to(selected_layer.device)
    print("selected_layer.device", selected_layer.device)
    print("attention_mask.device", attention_mask.device)

    # =======================================
    # Select the first token with optional offset `skip_length`
    # =======================================
    if "first_generated" in modes:
        batch_indices = torch.arange(batch_size, device=device)
        if skip_length is not None:
            first_indices = torch.full((batch_size,), skip_length, device=device, dtype=torch.long)
        else:
            first_indices = (attention_mask == 1).float().argmax(dim=1)
        first = selected_layer[batch_indices, first_indices] # Shape: (batch_size, hidden_size)
        aggregated_tokens["first_generated"] = first

    # =======================================
    # Select the last token 
    # =======================================
    if "last" in modes:
        last_indices = attention_mask.shape[1] - 1 - attention_mask.flip(dims=[1]).float().argmax(dim=1)
        batch_indices = torch.arange(batch_size, device=device)
        last = selected_layer[batch_indices, last_indices]  # Shape: (batch_size, hidden_size)
        aggregated_tokens["last"] = last

    # =======================================
    # Apply mask and compute aggregation 
    # =======================================
    if "average" in modes or "max" in modes:
        # Add one dimension for the broadcast on hidden_size
        mask_float = attention_mask.float().unsqueeze(-1)  # (batch_size, seq_len, 1)
        # Apply the mask to the activations: zero out tokens outside the target interval
        masked = selected_layer * mask_float
        # Count the number of selected tokens for each sequence (avoid division by zero with clamp)
        counts = mask_float.sum(dim=1).clamp(min=1e-6)
        if "average" in modes:
            # Compute the mean activation vector for each sequence over the selected interval
            avg = masked.sum(dim=1) / counts # Shape: (batch_size, hidden_size)
            aggregated_tokens["average"] = avg
        if "max" in modes:
            # Replace padding with -inf to exclude from max calculation
            masked_max = masked.masked_fill(mask_float.logical_not(), float('-inf'))
            # Extract maximum values across sequence dimension
            max_vals, _ = masked_max.max(dim=1) # Shape: (batch_size, hidden_size)
            aggregated_tokens["max"] = max_vals

    # =======================================
    # Covariance-based metrics
    # =======================================
    if any(m in modes for m in ["token_cov_svd", "feat_cov_svd", "token_cov_stats", "feat_cov_stats", "feat_cov_var"]):
        token_cov_svd = [] 
        feat_cov_svd = [] 
        token_cov_stats = []
        feat_cov_stats = []
        feat_cov_var = []
        
        for i in range(batch_size):
            # Select valid tokens 
            mask = attention_mask[i].bool()
            Z = selected_layer[i][mask]  # (num_valid_tokens, hidden_size)
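            if Z.shape[0] == 0:
                # No valid token: emit NaN placeholders with the same types and shapes as the
                # regular path, so that torch.stack / torch.tensor below still work.
                nan = torch.tensor(float('nan'), device=selected_layer.device)
                feat_cov_var.append(torch.full((hidden_size,), float('nan'), device=selected_layer.device))
                token_cov_svd.append(nan)
                feat_cov_svd.append(nan)
                token_cov_stats.append([float('nan')] * 4)
                feat_cov_stats.append([float('nan')] * 4)
                continue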
            
            if Z.dtype != torch.float32:
                Z = Z.to(torch.float32)
            num_valid_tokens = Z.shape[0]

            # Compute covariance matrix on tokens : Sigma_token 
            # ---------------------------------------
            # Assumes Z is in full precision
            # Center each token vector (subtract the mean of each row of Z) and compute covariance on tokens.
            # This measures variance without the bias of a non-zero mean.
            J = torch.eye(hidden_size, device=Z.device, dtype=Z.dtype) - (1 / hidden_size) * torch.ones(hidden_size, hidden_size, device=Z.device, dtype=Z.dtype)
            print("J.shape:", J.shape)
            # Compute centered covariance matrix of Z
            Sigma_token = torch.matmul(torch.matmul(Z, J), Z.t()) # (num_valid_tokens, num_valid_tokens)
            # Regularization for stabilization
            Sigma_token = Sigma_token + alpha * torch.eye(Sigma_token.shape[0], device=Z.device, dtype=Z.dtype)
            print("1) Sigma_token stats:", Sigma_token.mean(), Sigma_token.min(), Sigma_token.max())

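            # 2. Token covariance, alternative version: explicit mean subtraction over the
            # hidden dimension, normalized by (hidden_size - 1).
            # NOTE: this overwrites the Sigma_token computed in 1), so the statistics and the
            # SVD score below are based on this normalized version.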
            Z_token_centered = Z - Z.mean(dim=1, keepdim=True)
            Sigma_token = (Z_token_centered @ Z_token_centered.t()) / max(1, Z.shape[1] - 1)
            Sigma_token += alpha * torch.eye(Z.shape[0], device=Z.device, dtype=Z.dtype)
            print("2) Sigma_token stats:", Sigma_token.mean(), Sigma_token.min(), Sigma_token.max())


    
            # Compute covariance matrix on features : Sigma_feat
            # ---------------------------------------
            # Center the features of Z (subtract the average of each column) and compute covariance on features.
            J = torch.eye(num_valid_tokens, device=Z.device, dtype=Z.dtype) - (1 / num_valid_tokens) * torch.ones(num_valid_tokens, num_valid_tokens, device=Z.device, dtype=Z.dtype)
            # Compute centered covariance matrix of Z
            Sigma_feat = torch.matmul(torch.matmul(Z.t(), J), Z) # (hidden_size, hidden_size)
            # Regularization for stabilization
            Sigma_feat = Sigma_feat + alpha * torch.eye(Sigma_feat.shape[0], device=Z.device, dtype=Z.dtype)
            
            # Statistics of the token covariance matrix
            # ---------------------------------------
            if Sigma_token.dtype != torch.float32:
                Sigma_token = Sigma_token.to(torch.float32)
            
            Sigma_token_diag = Sigma_token.diag()
            token_stats = [
                Sigma_token_diag.mean().item(),
                Sigma_token_diag.std().item(),
                Sigma_token_diag.min().item(),
                Sigma_token_diag.max().item(),
            ]
            token_cov_stats.append(token_stats)

            # Singular value decomposition (SVD) of the token covariance matrix
            # ---------------------------------------
            token_svdvals = torch.linalg.svdvals(Sigma_token) # Singular Value Decomposition
            token_eigscore = torch.log(token_svdvals).mean()  # mult by 2 missing from the paper? 
            token_cov_svd.append(token_eigscore)

            # Statistics of the feature covariance matrix
            # ---------------------------------------
            Sigma_feat_diag = Sigma_feat.diag()
            
            if Sigma_feat_diag.dtype != torch.float32:
                Sigma_feat_diag = Sigma_feat_diag.to(torch.float32)

            feat_stats = [
                Sigma_feat_diag.mean().item(),
                Sigma_feat_diag.std().item(),
                Sigma_feat_diag.min().item(),
                Sigma_feat_diag.max().item()
            ]
            feat_cov_var.append(Sigma_feat_diag)
            feat_cov_stats.append(feat_stats)
            
            # Singular value decomposition (SVD) of the feature covariance matrix
            # ---------------------------------------
            feat_svdvals = torch.linalg.svdvals(Sigma_feat) # Singular Value Decomposition
            feat_eigscore = torch.log(feat_svdvals).mean() 
            feat_cov_svd.append(feat_eigscore)

        # Return scores
        # ---------------------------------------
        if "token_cov_svd" in modes:
            aggregated_tokens["token_cov_svd"] = torch.stack(token_cov_svd) # (batch_size,) 
        if "feat_cov_svd" in modes:
            aggregated_tokens["feat_cov_svd"] = torch.stack(feat_cov_svd) # (batch_size,) 
        if "token_cov_stats" in modes:
            aggregated_tokens["token_cov_stats"] = torch.tensor(token_cov_stats) # (batch_size, 4) 
        if "feat_cov_stats" in modes:
            aggregated_tokens["feat_cov_stats"] = torch.tensor(feat_cov_stats) # (batch_size, 4) 
        if "feat_cov_var" in modes:
            aggregated_tokens["feat_cov_var"] = torch.stack(feat_cov_var, dim=0) # (batch_size, hidden_size) 

    print("====================")
    print(aggregated_tokens)
    return aggregated_tokens

# Move all the aggregated tokens to the CPU?
# No, but free the memory each time: torch.cuda.empty_cache()
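
The "average", "max" and "last" modes of `extract_token_activations` all rely on the attention mask to ignore padding. The sketch below reproduces that logic on a tiny hand-made batch so the results can be checked by eye. It is written with NumPy instead of torch, and `masked_aggregate` is a hypothetical helper, not part of the notebook's code.

```python
import numpy as np

def masked_aggregate(hidden, mask):
    """hidden: (batch, seq_len, hidden_size); mask: (batch, seq_len), 1 for valid tokens."""
    m = mask.astype(bool)
    # Mean over valid tokens only (counts clamped to 1 to avoid division by zero)
    counts = np.maximum(m.sum(axis=1, keepdims=True), 1)
    avg = (hidden * m[..., None]).sum(axis=1) / counts
    # Element-wise max, with padded positions replaced by -inf so they never win
    mx = np.where(m[..., None], hidden, -np.inf).max(axis=1)
    # Index of the last valid token: first 1 found when scanning the mask from the right
    last_idx = mask.shape[1] - 1 - np.argmax(m[:, ::-1], axis=1)
    last = hidden[np.arange(hidden.shape[0]), last_idx]
    return avg, mx, last

hidden = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
mask = np.array([[1, 1, 1, 0],    # right-padded sequence
                 [0, 1, 1, 1]])   # left-padded sequence
avg, mx, last = masked_aggregate(hidden, mask)
```

Scanning the flipped mask makes the "last" index correct for both right-padded and left-padded batches, which is why the same trick appears in the torch version.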

In [None]:
from transformers import PreTrainedTokenizer, PreTrainedModel, BatchEncoding
import torch
from datasets import  Dataset
from tqdm import tqdm
from typing import List, Callable, Union, Literal, Dict, Tuple
from torch.utils.hooks import RemovableHandle  # used in the hook's return annotation below

from src.inference.activation_utils import (
    compute_offset_attention_mask,
)
from src.inference.inference_utils import (
    build_prompt,
    extract_batch, 
    align_generation_hidden_states,
    align_prompt_hidden_states,
    build_generation_attention_mask)

def generate(
    model: PreTrainedModel,
    inputs: BatchEncoding,
    tokenizer: PreTrainedTokenizer,
    max_new_tokens: int = 50,
    k_beams: int = 1,
    **generate_kwargs
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Generate sequences from the model with optional beam search.
    Supports advanced options via **generate_kwargs (e.g., output_attentions).

    Parameters
    ----------
    model : PreTrainedModel
        The language model to use for generation.
    inputs : BatchEncoding
        Tokenized input prompts.
    tokenizer : PreTrainedTokenizer
        Tokenizer providing eos and pad token IDs.
    max_new_tokens : int, optional
        Maximum number of new tokens to generate.
    k_beams : int, optional
        Number of beams to use. If 1, uses sampling. If >1, beam search is enabled.
    **generate_kwargs : dict
        Additional keyword arguments passed to `model.generate()`.

    Returns
    -------
    Union[torch.Tensor, Dict[str, torch.Tensor]]
        Because `return_dict_in_generate=True` is always set, the result is a
        dictionary-like generation output:
        - "sequences": the generated token IDs, shape (batch_size, prompt_len + gen_len)
        - "beam_indices": the beam path for each token (only when k_beams > 1)
    """
    # Defaults that callers may override through **generate_kwargs. Using setdefault avoids
    # a "got multiple values for keyword argument" TypeError when they are passed explicitly.
    # By default we rely on hooks to extract hidden states and attention maps (more memory efficient).
    generate_kwargs.setdefault("output_hidden_states", False)
    generate_kwargs.setdefault("output_attentions", False)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,   # use the argument instead of a hard-coded 1
            do_sample=True    if k_beams == 1 else False,
            temperature=0.6   if k_beams == 1 else None,
            top_p=0.9         if k_beams == 1 else None,
            top_k=50          if k_beams == 1 else None,
            num_beams=k_beams,
            use_cache=True, 
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id, # Ensures clean padding (right padding)
            return_dict_in_generate=True,    # Needed for access to beam_indices when num_beams > 1
            early_stopping=False if k_beams == 1 else True, # Beam search only: stop once k_beams finished candidates exist
            **generate_kwargs                # For future flexibility (e.g., output_attentions, output_scores)
        )
        return outputs 


def register_generation_activation_hook(
    model: PreTrainedModel,
    captured_hidden_list: List[torch.Tensor],
    layer_idx: int = -1
) -> Tuple[RemovableHandle, dict]:
    """
    Attaches a forward hook to a specific transformer layer to capture hidden states
    during autoregressive text generation i.e., at each decoding step.
    (more memory-efficient than using output_hidden_states=True).
    Transformer layer = self-attention + FFN + normalization.

    Parameters
    ----------
    model : PreTrainedModel
        The Hugging Face causal language model (e.g., GPT, LLaMA).
    captured_hidden_list : List[torch.Tensor]
        A list that will be filled with hidden states for each generation step. 
        Each tensor has shape (batch_size * num_beams, seq_len, hidden_size).
    layer_idx : int
        Index of the transformer block to hook. Defaults to -1 (the last layer).
        Use a positive integer if you want to hook an intermediate layer instead.

    Returns
    ----------
    handle : RemovableHandle
        Call `handle.remove()` after generation to remove the hook.
    call_counter : dict
        Dictionary whose "count" entry stores the number of times the hook was triggered.
    """
    # Raise error if layer_idx not in correct range
    num_layers = len(model.model.layers)
    if not (layer_idx == -1 or 0 <= layer_idx < num_layers):
        raise ValueError(
            f"`layer_idx` must be -1 or in [0, {num_layers - 1}], but got {layer_idx}."
        )
    
    call_counter = {"count": 0} # count how many times the hook is triggered

    def hook_fn(module, input, output):
        """Function called automatically by PyTorch just after
            the layer has produced its output during the forward pass."""
        
        call_counter["count"] += 1 

        # output is a tuple (hidden_states,) → keep [0]
        if layer_idx == -1:
            # Capture the final normalized output 
            captured_hidden_list.append(model.model.norm(output[0]).detach())  # post RMSNorm!
        else:
            # Capture raw hidden states before layer normalization
            captured_hidden_list.append(output[0].detach()) #### TEST #### 
    
    # Register hook on the transformer block
    # When PyTorch passes through this layer during the forward pass, it also executes hook_fn.
    handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    
    return handle, call_counter


import numpy as np
def run_prompt_and_generation_activation_extraction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    batch_size: int = 4,
    idx_start_sample: int = 0,
    max_samples: int = 1000,
    save_to_pkl: bool = False,
    output_path: str = "outputs/all_batch_results.pkl",
    build_prompt_fn: Callable[[str, str], str] = None,
    register_generation_activation_hook_fn: Callable = None,
    layer_idx: int = -1,  
    extract_token_activations_fn: Callable = None,
    activation_source: Literal["prompt", "generation", "promptGeneration"] = "generation",
    k_beams : int = 1,
    start_offset : int = 0,
    end_offset : int = 0,
) -> Union[Dict[str, List[np.ndarray]], None]:
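    """
    Run batched generation over `dataset` and collect token activations.

    For each batch of `batch_size` samples starting at `idx_start_sample` (up to
    `max_samples` samples), prompts are built with `build_prompt_fn`, tokenized, and
    passed to `generate`. Hidden states of layer `layer_idx` are then read for the
    prompt and for each generation step.

    NOTE: draft version with verbose debug prints.
    """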
    batch_activations = {}  # Chosen token activation vectors

    for i in tqdm(range(idx_start_sample, idx_start_sample + max_samples, batch_size)):
        
        # ==============================
        # Prepare input batch
        # ==============================
        batch = extract_batch(dataset, i, batch_size)
        print("batch[0]: ", batch[0])
        
        prompts = [build_prompt_fn(s["context"], s["question"]) for s in batch]
        #prompts = [build_prompt_fn(s["context"], s["question"]) for s in [batch[1], batch[1]]]
        #prompts = [build_prompt_fn(batch[0]["context"][:10], batch[0]["question"][:5]), \
        #           build_prompt_fn(batch[1]["context"][:15], batch[1]["question"][:10]) ]
        #prompts = [s["context"] + s["question"] for s in batch]
        #prompts = [batch[0]["context"][:5], batch[0]["question"][:5], \
        #           batch[1]["context"][:20], batch[1]["question"][:20]]
        #prompts = ["ceci est un test", "Je ne sais pas pourquoi il y a un probleme"]

        inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1] # Assumes prompts are padded to same length

        # Use `j` here so the outer batch index `i` is not overwritten
        for j in range(inputs['input_ids'].shape[0]):
            print("--- Example", j)
            print("prompt:", prompts[j])
            print("input_ids:", inputs["input_ids"][j])
            print("len(input_ids): ", len(inputs["input_ids"][j]))
            print("attention_mask:", inputs["attention_mask"][j])
            print("pad positions:", (inputs['input_ids'][j] == tokenizer.pad_token_id).nonzero())
            print("EOS positions:", (inputs['input_ids'][j] == tokenizer.eos_token_id).nonzero())


        max_id = inputs['input_ids'].max().item()
        print("max input id:", max_id)
        print("embedding matrix size:", model.get_input_embeddings().weight.size(0))
        assert max_id < model.get_input_embeddings().weight.size(0)

        # ==============================
        # Register forward hook to capture layer output
        # ==============================
        # This hook collects the hidden states at each decoding step
        # activations = [prompt] + [gen_step_1, gen_step_2, ..., gen_step_49], len(activations)=50, if max_new_tokens=50.
        activations = [] # activations[k] of Shape: (batch_size * k_beams, seq_len, hidden_size)
        #handle, call_counter_act = register_generation_activation_hook_fn(model, activations, layer_idx=layer_idx)
        
        attentions = [] # attentions[k] of Shape: (batch  *k_beams, n_heads, tgt_seq_len, src_seq_len)
        #handle_attn, call_counter_attn = register_generation_attention_hook(model, attentions, layer_idx=layer_idx) # PUT function in params

        # ==============================
        # Run model forward pass (hook captures activations)
        # ==============================
        # Generate text from prompts using beam search or sampling. 
        outputs = generate(model, inputs, tokenizer, max_new_tokens=50, k_beams=k_beams)

        print("attentions: ", attentions)
        
    
        print("len(outputs.attentions): ", len(outputs.attentions))
        print("outputs.attentions[0].shape: ", outputs.attentions[0].shape)
        
    
        print("outputs.sequences.device: ", outputs.sequences.device)
    
        # Retrieve text of generated answers
        gen_answers = tokenizer.batch_decode(
            outputs.sequences[:, prompt_len:], 
            skip_special_tokens=True
        ) # Shape: [batch_size,]
        
        # Define prompt and generation hidden states 
        #   outputs.hidden_states: list of hidden states at each generation step.
        #   - outputs.hidden_states[0][layer_idx]: hidden states for the entire prompt
        #   Shape: (batch_size * k_beams, prompt_len, hidden_size)
        #   - outputs.hidden_states[1 + i][layer_idx]: hidden states after generating the i-th new token
        #   Shape: (batch_size * k_beams, 1, hidden_size), for i = 0, ..., max_new_tokens-1
        #   So, for max_new_tokens=50 in generate(): len(outputs.hidden_states)=50
        #   and outputs.hidden_states: [prompt] + [gen_step_1] + [gen_step_2] + ... + [gen_step_49]
        prompt_activations=outputs.hidden_states[0][layer_idx] 
        generation_activations=[outputs.hidden_states[k][layer_idx] for k in range(1, len(outputs.hidden_states))] 
        
        
        print("len(prompt_activations): ", len(prompt_activations))
        print("len(generation_activations): ", len(generation_activations))

        print("generation_activations[0].device: ", generation_activations[0].device)
        print("prompt_activations.device:", prompt_activations.device)

        print("Layer L = ", layer_idx)
        print("=======================================")
        print("outputs.sequences.shape[1]: ", outputs.sequences.shape[1])
        if k_beams > 1:
            print("outputs.beam_indices:", outputs.beam_indices)
            print("outputs.beam_indices.shape[1] :", outputs.beam_indices.shape[1]) 
        print("outputs.sequences[:,prompt_len:] :", outputs.sequences[: , prompt_len:])
        print("outputs.sequences[: , prompt_len:].shape[1]:", outputs.sequences[: , prompt_len:].shape[1])
        print("gen_answers :", gen_answers)
        print("=======================================")
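        print("Length of activations:", len(activations))
        # `activations` stays empty while the hook registration above is commented out,
        # so only index into it once it has been populated.
        if len(activations) > 1:
            print("Shape of activations[0]:", activations[0].shape)
            print("Shape of activations[1]:", activations[1].shape)
            print("Shape of activations[-1]:", activations[-1].shape)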
        print("============")
        print("Length of outputs.hidden_states: ", len(outputs.hidden_states))
        print("Shape of outputs.hidden_states[0][L]:", outputs.hidden_states[0][layer_idx].shape)
        print("Shape of outputs.hidden_states[1][L]:", outputs.hidden_states[1][layer_idx].shape)
        print("Shape of outputs.hidden_states[-1][L]:", outputs.hidden_states[-1][layer_idx].shape)
        print("=======================================")

        # Guarded: `activations` is empty unless the forward hook above is registered
        if activations:
            print("activations[0] stats")
            print("mean", activations[0].mean().item(),
                  "min", activations[0].min().item(),
                  "max", activations[0].max().item(),
                  "has inf?", torch.isinf(activations[0]).any().item(),
                  "has nan?", torch.isnan(activations[0]).any().item())
        print("outputs.hidden_states[0][layer_idx] stats")
        print("mean", outputs.hidden_states[0][layer_idx].mean().item(),
              "min", outputs.hidden_states[0][layer_idx].min().item(), 
              "max", outputs.hidden_states[0][layer_idx].max().item(), 
              "has inf?", torch.isinf(outputs.hidden_states[0][layer_idx]).any().item(),
              "has nan?", torch.isnan(outputs.hidden_states[0][layer_idx]).any().item())
        
        # Guarded: `activations` is empty unless the forward hook above is registered
        if len(activations) > 1:
            print("activations[1:][-1] stats")
            print("mean", activations[-1].mean().item(),
                  "min", activations[-1].min().item(),
                  "max", activations[-1].max().item(),
                  "has inf?", torch.isinf(activations[-1]).any().item(),
                  "has nan?", torch.isnan(activations[-1]).any().item())
        print("outputs.hidden_states[-1][layer_idx] stats")
        print("mean", outputs.hidden_states[-1][layer_idx].mean().item(),
              "min", outputs.hidden_states[-1][layer_idx].min().item(), 
              "max", outputs.hidden_states[-1][layer_idx].max().item(), 
              "has inf?", torch.isinf(outputs.hidden_states[-1][layer_idx]).any().item(),
            "has nan?", torch.isnan(outputs.hidden_states[-1][layer_idx]).any().item())
        print("=======================================")

        # ===============================
        # Truncate activations to match real generation steps (cf. Understanding Note #1)
        # ===============================
        # During generation, the model may run extra forward passes (especially with beam search)
        # beyond the number of tokens in the final output. This results in activations being longer
        # than needed — we need to truncate them accordingly.
        # (see Understanding Note #1).
        if k_beams > 1:
            # In beam search, we use beam_indices.shape[1] to determine the actual number of generation steps
            gen_len = outputs.beam_indices.shape[1]
        else:
            # In greedy/top-k sampling, gen_len is simply the number of new tokens beyond the prompt
            gen_len = outputs.sequences.shape[1] - prompt_len

        # Sometimes, activations may include extra "ghost" steps (e.g., due to internal padding/sync in beam search)
        bool_truncate_activations = (len(generation_activations) >= gen_len) 
 
        if bool_truncate_activations:
            # Truncate extra steps to ensure alignment with generated tokens
            generation_activations = generation_activations[:gen_len]

        """
        ==================================
        Understanding Note #1:
        ==================================
        When using beam search in Hugging Face Transformers, the number of decoder hidden states
        (len(outputs.hidden_states)) can be greater than the number of tokens in the final generated 
        sequence (outputs.sequences[:,prompt_len:].shape[1] = outputs.beam_indices.shape[1]). 
        This happens because, during beam search, the model explores multiple candidate sequences 
        (beams) at each generation step and continues generating until a stopping condition is met 
        (such as all beams reaching EOS or the maximum number of tokens). But because beams can 
        finish at different steps (some hitting EOS early, others continuing), the model must keep
        generating for the remaining active beams. 
        *Note* that in our code, outputs.hidden_states and activations are the same. 
      
        Explanation from Hugging Face, January 2023: 
        (https://github.com/huggingface/transformers/issues/21374)
        "Beam Search: Here it's trickier. In essence, beam search looks for candidate outputs until it hits 
        a stopping condition. The candidate outputs can have fewer tokens than the total number of generation 
        steps -- for instance, in an encoder-decoder text model, if your input is How much is 2 + 2? and the 
        model generates as candidates <BOS>4<EOS> (3 tokens) and <BOS>The answer is potato<EOS> 
        (for argument's sake, 6 tokens) before deciding to stop, you should see sequences with shape [1, 3] 
        and decoder_hidden_states with length 5, because 5 tokens were generated internally before settling 
        on the 1st candidate."    
        """

        # ===============================
        # Truncate generated token IDs to match activations (cf. Understanding Note #2) 
        # ===============================
        # - When N tokens are generated, only the first N-1 tokens have corresponding hidden states.
        #   So activations[1:] covers only the first N-1 steps (cf. Understanding Note #2).
        #   Therefore, we exclude the last generated token from outputs.sequences and beam_indices
        #   to match activations[1:]
        # - Exception: if activations were truncated earlier (bool_truncate_activations = True),
        #   then we already lost activations of the final decoding step(s), and our activations[1:]
        #   only cover the available tokens. In that case, we keep the full `gen_len` to match.
        # (see Understanding Note #2)
        if bool_truncate_activations:
            expected_gen_len = gen_len  # All generated tokens have hidden states
        else: 
            expected_gen_len  = gen_len - 1 # Drop final token to match activations[1:]

        # Truncate generated sequences and beam paths accordingly
        truncated_gen_sequences = outputs.sequences[:, prompt_len : prompt_len + expected_gen_len]
        if k_beams > 1:
            truncated_beam_indices = outputs.beam_indices[:, :expected_gen_len] 

        """
        ==================================
        Understanding Note #2:
        ==================================
        When using model.generate() with output_hidden_states=True (what we are replicating here with the hook),
        use_cache=True and max_new_tokens=30, there is always an offset between the length of the 
        generated sequence (outputs.sequences[:, prompt_len:].shape[1]) and len(outputs.hidden_states): 
        * outputs.sequences.shape[1] = prompt_len (17) + max_new_tokens (30) = 47
        * len(outputs.hidden_states) = max_new_tokens (30)
            With : 
            * outputs.hidden_states[0][layer_idx].shape = (batch_size, prompt_len, hidden_size)           --> includes the prompt ! 
            * outputs.hidden_states[i][layer_idx].shape = (batch_size, 1, hidden_size) with 1 <= i <= 29  --> stops at 29 ! 
        *Note* that in our code, outputs.hidden_states and activations are the same. 
            
        Explanation from Hugging Face, April 2024 
        (https://github.com/huggingface/transformers/issues/30036):
        "If you have 30 tokens at the end of generation, you'll always have 29 hidden states.
        The token with index N is used to produce hidden states with index N, which is then used 
        to get the token with index N+1. The generation ends as soon as the target number of 
        tokens is obtained so, when we obtain the 30th token, we don't spend compute to get the 30th 
        set of hidden states. You can, however, manually run an additional forward pass to obtain the 
        30th set of hidden states, corresponding to the 30th token and used to obtain the 31st token."
        """
        # ===============================
        # Align generated and prompt hidden states
        # ===============================
        # Extract the hidden states that correspond to the generated sequence
        # selected by the beam search (or top-k sampling if k_beams = 1)
        aligned_generation_hidden_states = align_generation_hidden_states(
            generation_activations=generation_activations, 
            beam_indices=truncated_beam_indices if k_beams > 1 else None,
            k_beams=k_beams
        ) # Shape: (batch_size, gen_len, hidden_size)

        # Extract the hidden states that correspond to the prompt
        aligned_prompt_hidden_states = align_prompt_hidden_states(
            prompt_activations=prompt_activations, 
            k_beams=k_beams
        ) # Shape: (batch_size, prompt_len, hidden_size)

        # Concatenate the prompt and generation aligned hidden states  
        aligned_prompt_and_gen_hidden_states = torch.cat(
            [aligned_prompt_hidden_states, 
             aligned_generation_hidden_states], 
             dim=1
        ) # Shape: (batch_size, prompt_len + gen_len, hidden_size)

        
        print("=======================================")
        print("aligned_prompt_hidden_states stats")
        print(" min", aligned_prompt_hidden_states.min().item(), 
              "max", aligned_prompt_hidden_states.max().item(), 
              "has inf?", torch.isinf(aligned_prompt_hidden_states).any().item(),
            "has nan?", torch.isnan(aligned_prompt_hidden_states).any().item())

        print("aligned_generation_hidden_states stats")
        print(" min", aligned_generation_hidden_states.min().item(), 
              "max", aligned_generation_hidden_states.max().item(), 
              "has inf?", torch.isinf(aligned_generation_hidden_states).any().item(),
            "has nan?", torch.isnan(aligned_generation_hidden_states).any().item())
        print("=======================================")


        # ===============================
        # Build generation and prompt attention mask
        # ===============================
        # This mask marks which generated tokens are valid (i.e., not padding).
        # Positions are marked True up to and including the first eos_token_id
        generation_attention_mask = build_generation_attention_mask(
            gen_ids=truncated_gen_sequences, 
            eos_token_id=tokenizer.eos_token_id
        ) # Shape (batch_size, gen_len)

        # Prompt attention mask
        prompt_attention_mask = inputs["attention_mask"] 
        # Shape (batch_size, prompt_len)
        
        # ===============================
        # Modify prompt attention mask with offsets
        # ===============================
        if start_offset !=0 or end_offset !=0:
            prompt_attention_mask, start_indices, end_indices = compute_offset_attention_mask(
                attention_mask=prompt_attention_mask, 
                start_offset=start_offset, 
                end_offset=end_offset
            ) # Shape (batch_size, prompt_len), (batch_size,), (batch_size,)

        # Concatenate the prompt and generation attention mask
        prompt_and_gen_attention_mask = torch.cat(
            [prompt_attention_mask,
            generation_attention_mask],
            dim=1
        ) # Shape (batch_size, prompt_len + gen_len)

        # ==============================
        # Extract token activations from captured layer, based on source
        # ==============================
        if activation_source == "generation":
            # Return only the token activations from the generated answer 
            selected_token_vecs = extract_token_activations_fn(
                    selected_layer=aligned_generation_hidden_states, 
                    attention_mask=generation_attention_mask, 
                    device=aligned_generation_hidden_states.device,
                ) # Shape (batch_size, hidden_size)
            
        elif activation_source == "prompt":    
            # Return only the token activations from the prompt
            selected_token_vecs = extract_token_activations_fn(
                    selected_layer=aligned_prompt_hidden_states, 
                    attention_mask=prompt_attention_mask, 
                    device=aligned_prompt_hidden_states.device,
                ) # Shape (batch_size, hidden_size)
            
        elif activation_source == "promptGeneration":
            # Return token activations from the concatenated prompt + generated answer 
            selected_token_vecs = extract_token_activations_fn(
                    selected_layer=aligned_prompt_and_gen_hidden_states, 
                    attention_mask=prompt_and_gen_attention_mask, 
                    device=aligned_prompt_and_gen_hidden_states.device,
                    skip_length=prompt_len 
                    # skip_length: exclude prompt from computation if 
                    # mode=='first_generated' in `extract_token_activations_fn`
                ) # Shape (batch_size, hidden_size)

        else:
            raise ValueError(
                f"Invalid value for `activation_source`: '{activation_source}'. "
                f"Expected one of: ['prompt', 'generation', 'promptGeneration']."
            )    
        
        print("selected_token_vecs:", selected_token_vecs)
        # ==============================
        # Store results (to file or memory)
        # ==============================
        activations = {}
        for mode, tensor in selected_token_vecs.items():
            activations[mode] = [tensor[j].unsqueeze(0).cpu().numpy() for j in range(tensor.size(0))]

        batch_dataset_ids = []; batch_dataset_original_idx = []; batch_context = []
        batch_question = []; batch_gt_answers = []; batch_title = []
        for s in batch:
            batch_dataset_ids.append(s['id'])
            batch_dataset_original_idx.append(s['original_index'])
            batch_context.append(s['context'])
            batch_question.append(s['question'])
            batch_gt_answers.append(s['answers'])
            batch_title.append(s['title'])
        
        batch_results = {
            "id": batch_dataset_ids,
            "original_indices": batch_dataset_original_idx,
            "activations": activations, # Dict
            "gen_answers": gen_answers,
            "gt_answers": batch_gt_answers,
            "context": batch_context,
            "question": batch_question,
            "title": batch_title,
        }

        if save_to_pkl:
            #append_to_pickle(output_path, batch_results)
            #save_batch_pickle(batch_data=batch_results, output_dir=output_path, batch_idx=i)
            pass
        else:
            for mode, acts in activations.items():
                # `acts` is always a list of per-sample arrays (built above)
                batch_activations.setdefault(mode, []).extend(acts)

            #batch_activations.extend(activations)
        
    if not save_to_pkl:
        return batch_activations
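
The length bookkeeping described in Understanding Notes #1 and #2 can be isolated in a small pure-Python helper. The sketch below is not part of the notebook: `expected_generation_length` is a hypothetical name, and it only mirrors the `bool_truncate_activations` / `expected_gen_len` logic above, assuming that `len(outputs.hidden_states)` counts one prompt pass plus one pass per decoding step.

```python
from typing import Tuple


def expected_generation_length(num_hidden_state_steps: int, gen_len: int) -> Tuple[int, bool]:
    """Return (number of generated tokens that have a hidden state, truncated?).

    num_hidden_state_steps: len(outputs.hidden_states), i.e. the prompt pass
        plus one entry per decoding step (assumption stated in the lead-in).
    gen_len: number of generated tokens, or beam_indices.shape[1] under beam search.
    """
    # The first entry covers the whole prompt, so it is not a generation step.
    num_generation_steps = num_hidden_state_steps - 1

    # Beam search may run extra internal steps: keep only the first gen_len of them.
    truncated = num_generation_steps >= gen_len
    if truncated:
        return gen_len, True

    # Otherwise the last generated token has no hidden state yet (Note #2).
    return gen_len - 1, False


# Greedy decoding with 30 new tokens: 30 hidden-state entries, 29 usable generation steps.
print(expected_generation_length(num_hidden_state_steps=30, gen_len=30))  # (29, False)

# Hypothetical beam-search case: 11 internal steps but only 8 tokens in the final sequence.
print(expected_generation_length(num_hidden_state_steps=12, gen_len=8))   # (8, True)
```

Keeping this rule in one place makes it easier to check that the truncated token ids, the beam indices, and the generation hidden states all end up with the same length.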

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def load_llama(model_name: str = "meta-llama/Llama-2-7b-chat-hf"):
    # Default model: Llama-2 chat, the variant fine-tuned for conversational use

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    tokenizer.model_max_length = 1024  # cap inputs at 1024 tokens (LLaMA-2's context window is 4096)
    
    # For llama, pad_token is not defined by default:
    # The convention is to use tokenizer.pad_token = tokenizer.eos_token 
    # However, this causes an issue when outputting attention maps during model.generate()
    if tokenizer.pad_token is None:
        # add "<pad>" to vocab as a special token
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        print("<pad> token not defined by default, add it to vocabulary.")
       
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16, 
            device_map= "auto",          # load model to device 
            low_cpu_mem_usage=True,     # reduce RAM usage during loading
            attn_implementation="eager",
            #output_hidden_states=True, # to hidden activations -> memory overload since we access ALL hidden states 
            #force_download=True        # redo complete download 
        )
    
        # required for the model to accept the new vocabulary.
        model.resize_token_embeddings(len(tokenizer)) 

    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16, 
            device_map="auto",          # load model to device 
            low_cpu_mem_usage=True,     # reduce RAM usage during loading
        )

    # Ensures that during generation all sequences are aligned with the PAD token, and not with random tokens. 
    model.config.pad_token_id = tokenizer.pad_token_id 
    
    return model, tokenizer

# NOTE: this second definition overrides the one above.
# Variant that reuses the EOS token as the padding token.

def load_llama(model_name: str = "meta-llama/Llama-2-7b-chat-hf"):
    # Default model: Llama-2 chat, the variant fine-tuned for conversational use

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    tokenizer.model_max_length = 1024  # cap inputs at 1024 tokens (LLaMA-2's context window is 4096)
    tokenizer.pad_token = tokenizer.eos_token  # pad_token not defined by default: reuse the EOS token (</s>) as the padding token.

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16, 
        device_map="cuda:0",         # load model on the first GPU (alternative: "auto")
        low_cpu_mem_usage=True,      # reduce RAM usage during loading
        attn_implementation="eager", # alternative tried: "flex_attention"
        #output_hidden_states=True,  # would return ALL hidden states -> memory overload
        #force_download=True        # redo complete download 
    )
    print("attn_implementation changed")
    model.config.pad_token_id = model.config.eos_token_id # ensures that during generation all sequences are aligned with the EOS token, and not with random tokens. 
    
    return model, tokenizer
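
Because this loader reuses EOS as the padding token, sequences that finish early during generation are filled with EOS ids after their first EOS, so token ids alone cannot separate real tokens from padding. A mask that is True up to and including the first EOS, as described for `build_generation_attention_mask`, can be sketched as follows. `first_eos_mask` is a hypothetical re-implementation for illustration (the notebook's own helper may differ), and the token ids are made up.

```python
import torch


def first_eos_mask(gen_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Boolean mask of shape (batch_size, gen_len).

    True for every position up to and including the first EOS of each row,
    False afterwards. Rows without any EOS are entirely True.
    """
    is_eos = (gen_ids == eos_token_id).long()
    # Number of EOS tokens seen strictly before each position.
    eos_seen_before = is_eos.cumsum(dim=1) - is_eos
    return eos_seen_before == 0


# Toy batch: row 0 stops early and is padded with EOS (id 2), row 1 never stops.
gen_ids = torch.tensor([[5, 6, 2, 2, 2],
                        [7, 8, 9, 10, 11]])
mask = first_eos_mask(gen_ids, eos_token_id=2)
print(mask)
```

The cumulative sum avoids a Python loop over the batch and keeps the mask on the same device as `gen_ids`.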

In [None]:
'''
for layer in model.model.layers:
    layer.self_attn.forward = patched_LlamaAttention_forward.__get__(layer.self_attn, layer.self_attn.__class__)

'''
attention_module = model.model.layers[-1].self_attn
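print("attention_module.config._attn_implementation:", attention_module.config._attn_implementation)

# type(attention_module): <class 'transformers.models.llama.modeling_llama.LlamaAttention'>

# NOTE: the LlamaAttention.forward reproduced in the cells below dispatches on
# `self.config._attn_implementation`, so this per-module attribute may have no effect.
for layer in model.model.layers:
    layer.self_attn._attn_implementation = "eager"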

model.model.layers[-1].self_attn



'''from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "gpt2"  # or "gpt2-medium", "gpt2-large"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 has no native pad_token, so reuse the EOS token
tokenizer.pad_token = tokenizer.eos_token'''

In [None]:
'''
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        
        print("=== DEBUG ATTENTION FORWARD ===")
    
        # 1. Check the inputs

        print(f"   INPUT hidden_states:")
        print(f"   Shape: {hidden_states.shape}")
        print(f"   Dtype: {hidden_states.dtype}")
        print(f"   Has NaN: {torch.isnan(hidden_states).any()}")
        print(f"   Has Inf: {torch.isinf(hidden_states).any()}")
        print(f"   Range: [{hidden_states.min():.4f}, {hidden_states.max():.4f}]")
        num_nans = torch.isnan(hidden_states).sum().item()
        print(f"Number of NaNs in hidden_states: {num_nans}")
        # NaN mask (True where the value is NaN)
        nan_mask = torch.isnan(hidden_states)
        # Show a few example NaN positions
        print("Positions of the first NaNs:", nan_mask.nonzero(as_tuple=False)[:10])
        # Number of NaNs per batch example
        nans_per_example = nan_mask.view(hidden_states.size(0), -1).sum(dim=1)
        print("Number of NaNs per batch example:", nans_per_example)
        if num_nans > 0:
            return  

        # 2. Check the projection weights
        print(f"\n  PROJECTION WEIGHTS:")
        for name, param in [("q_proj", self.q_proj.weight), ("k_proj", self.k_proj.weight), ("v_proj", self.v_proj.weight)]:
            print(f"   {name}.weight:")
            print(f"     Has NaN: {torch.isnan(param).any()}")
            print(f"     Has Inf: {torch.isinf(param).any()}")
            print(f"     Range: [{param.min():.4f}, {param.max():.4f}]")
            print(f"     Norm: {param.norm().item():.4f}")
        # 3. Check the biases (if any)
        for name, proj in [("q_proj", self.q_proj), ("k_proj", self.k_proj), ("v_proj", self.v_proj)]:
            if hasattr(proj, 'bias') and proj.bias is not None:
                print(f"   {name}.bias:")
                print(f"     Has NaN: {torch.isnan(proj.bias).any()}")
                print(f"     Has Inf: {torch.isinf(proj.bias).any()}")
                print(f"     Range: [{proj.bias.min():.4f}, {proj.bias.max():.4f}]")

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        print(f"\n PROJECTIONS:")

        # Q projection
        print(f"   Computing Q projection...")
        q_raw = self.q_proj(hidden_states)
        print(f"   Q raw - Has NaN: {torch.isnan(q_raw).any()}, Range: [{q_raw.min():.4f}, {q_raw.max():.4f}]")
        
        query_states = q_raw.view(hidden_shape).transpose(1, 2)
        print(f"   Q reshaped - Has NaN: {torch.isnan(query_states).any()}")
        
        # K projection
        print(f"   Computing K projection...")
        k_raw = self.k_proj(hidden_states)
        print(f"   K raw - Has NaN: {torch.isnan(k_raw).any()}, Range: [{k_raw.min():.4f}, {k_raw.max():.4f}]")
        
        key_states = k_raw.view(hidden_shape).transpose(1, 2)
        print(f"   K reshaped - Has NaN: {torch.isnan(key_states).any()}")
        
        # V projection
        print(f"   Computing V projection...")
        v_raw = self.v_proj(hidden_states)
        print(f"   V raw - Has NaN: {torch.isnan(v_raw).any()}, Range: [{v_raw.min():.4f}, {v_raw.max():.4f}]")
        
        value_states = v_raw.view(hidden_shape).transpose(1, 2)
        print(f"   V reshaped - Has NaN: {torch.isnan(value_states).any()}")
        

        """
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        """

        print("Q max", query_states.max(), "min", query_states.min())
        print("K max", key_states.max(), "min", key_states.min())

        cos, sin = position_embeddings

        print(f"\n  POSITION EMBEDDINGS:")
        print(f"   cos - Has NaN: {torch.isnan(cos).any()}, Range: [{cos.min():.4f}, {cos.max():.4f}]")
        print(f"   sin - Has NaN: {torch.isnan(sin).any()}, Range: [{sin.min():.4f}, {sin.max():.4f}]")

        print(f"\n  APPLYING RoPE...")
        query_states_before_rope = query_states.clone() ######
        key_states_before_rope = key_states.clone() #######

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        print(f"   Q after RoPE - Has NaN: {torch.isnan(query_states).any()}")
        print(f"   K after RoPE - Has NaN: {torch.isnan(key_states).any()}")
        
        if torch.isnan(query_states).any() or torch.isnan(key_states).any():
            print("  NaN detected after RoPE!")
            print(f"   Q before RoPE had NaN: {torch.isnan(query_states_before_rope).any()}")
            print(f"   K before RoPE had NaN: {torch.isnan(key_states_before_rope).any()}")
        

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:  
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        
        print("attention_interface:", attention_interface)
        print("attention_mask:", attention_mask)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        
        #print("attention weights dtype:", attn_weights.dtype)

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
    '''

In [None]:
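# Imports needed by this copy of LlamaAttention.forward (module paths as in recent
# transformers releases; they may differ across versions)
from typing import Callable, Optional, Tuple

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward
from transformers.processing_utils import Unpack
from transformers.utils import logging

logger = logging.get_logger(__name__)

def forward(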
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        print("Q max", query_states.max(), "min", query_states.min())
        print("K max", key_states.max(), "min", key_states.min())

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:  
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        
        print("attention_interface:", attention_interface)
        print("attention_mask:", attention_mask)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        
        print("attention weights dtype:", attn_weights.dtype)

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

In [None]:
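# sdpa_attention_forward lives in transformers.integrations in recent releases (version-dependent)
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, eager_attention_forward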

def patched_forward(
    self,
    hidden_states,
    position_embeddings,
    attention_mask,
    past_key_value=None,
    cache_position=None,
    **kwargs,
):
    # Standard code (projections, etc.)
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # --- 1. Main pass with SDPA ---
    attn_output, _ = sdpa_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    # --- 2. Compute attn_weights in EAGER mode (not used downstream, only for logging) ---
    # Caution: the attn_weights computation may fail at this point if it is unstable or produces NaN
    try:
        _, attn_weights = eager_attention_forward(
            self,
            query_states, key_states, value_states, attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling, **kwargs,
        )
    except Exception as ex:
        print(f"!!! Eager NaN/inf during attn_weights calculation: {ex}")
        attn_weights = None

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights  # NB: only the SDPA output is used downstream!


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import torch
import sys
import os 
# Add the path to the src directory
sys.path.append(os.path.abspath(".."))
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

set_seed(1234)

from src.model_loader.llama_loader import load_llama
model, tok = load_llama(MODEL_NAME)
model.eval()

encoding = tok(["Hi there, how are you?"], return_tensors="pt").to("cuda:0")
with torch.no_grad():
    generation_output = model.generate(**encoding, return_dict_in_generate=True, output_logits=True)

sequences = generation_output.sequences
sanity_check_logits = generation_output.logits

with torch.no_grad():
    model_output = model(input_ids=encoding['input_ids'])

# Check the logits
prompt_len = encoding['input_ids'].shape[1]
total_len = sequences.shape[1]
generated_len = total_len - prompt_len

print(f"Prompt length: {prompt_len}")
print(f"Total sequence length: {total_len}")
print(f"Generated length: {generated_len}")
print(f"Number of generated logits: {len(generation_output.logits)}")

# Check that sequences = prompt + generation
print(f"\nOriginal prompt: {encoding['input_ids']}")
print(f"Full sequence: {sequences}")
print(f"Prompt part of sequences: {sequences[0, :prompt_len]}")
print(f"Generated part of sequences: {sequences[0, prompt_len:]}")
print(f"Prompt == prompt part? {torch.equal(encoding['input_ids'][0], sequences[0, :prompt_len])}")

# The last prompt logit (at position prompt_len-1) predicts the first generated token
last_prompt_logit = model_output.logits[:, -1, :].float()  # Last logit of the prompt

# The first generation logit corresponds to the first generated token
first_gen_logit = generation_output.logits[0].float()  # First logit of the generation

# Comparison
diff = torch.max(torch.abs(first_gen_logit - last_prompt_logit)).cpu().item()
are_close = torch.allclose(first_gen_logit, last_prompt_logit, rtol=1e-5, atol=1e-8)

print(f"\nAre the logits identical? {are_close}")
print(f"Maximum difference: {diff:.10f}")

# Debug output
print(f"\nShape of the last prompt logit: {last_prompt_logit.shape}")
print(f"Shape of the first generated logit: {first_gen_logit.shape}")

# Check that the first generated token matches the prediction
first_generated_token_id = sequences[0, prompt_len]  # First token after the prompt
predicted_token_id = torch.argmax(last_prompt_logit, dim=-1)
print(f"\nFirst generated token (ID): {first_generated_token_id}")
print(f"Token predicted by the last prompt logit (ID): {predicted_token_id.item()}")
print(f"Do the tokens match? {first_generated_token_id == predicted_token_id.item()}")

# Decode for visualization
print(f"\nFirst generated token: '{tok.decode(first_generated_token_id)}'")
print(f"Predicted token: '{tok.decode(predicted_token_id)}'")

In [None]:
import torch
import numpy as np
from typing import Literal

def compute_perplexity(
        prompt_logits: torch.Tensor, 
        gen_logits: torch.Tensor,
        prompt_input_ids: torch.Tensor, 
        gen_input_ids: torch.Tensor,
        prompt_attention_mask: torch.Tensor,
        gen_attention_mask: torch.Tensor,
        mode: Literal["prompt", "generation", "promptGeneration"] = "prompt",
        min_k: float = None
    ):
    """
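    Computes the per-sample perplexity of language model outputs using logits 
    and corresponding input token IDs. Positions marked 0 in the attention mask 
    are meant to be ignored in the computation of the perplexity; the masks are 
    accepted as arguments but are not yet applied in the body of this function.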

    Perplexity is defined as:
        Perplexity = exp(- mean(log P(token_i | context)))

    NOTE: This implementation is inspired by:
    "LLM-Check: Investigating Detection of Hallucinations in Large Language Models"
    (Sriramanan et al., 2024)

    Parameters
    ----------
    prompt_logits : torch.Tensor
        Tensor of shape (batch_size, prompt_len, vocab_size) 
        These are the model's output logits obtained from a standard forward pass over the prompt sequence.
    gen_logits : torch.Tensor
        Tensor of shape (batch_size, gen_len, vocab_size).
        These are the logits obtained during autoregressive decoding using `model.generate()`.
    prompt_input_ids : torch.Tensor
        Tensor of shape (batch_size, prompt_len), containing the input token IDs for the prompt.
    gen_input_ids : torch.Tensor
        Tensor of shape (batch_size, gen_len), containing the token IDs generated by the model.
    prompt_attention_mask: torch.Tensor
        Tensor of shape (batch_size, prompt_len), 1 where token valid, 0 for padding.
    gen_attention_mask: torch.Tensor  
        Tensor of shape (batch_size, gen_len), 1 where token valid, 0 for padding.
    mode : str, optional
        One of {"prompt", "generation", "promptGeneration"}:
        - "prompt": compute perplexity only over the prompt.
        - "generation": compute perplexity only over the generated tokens.
        - "promptGeneration": compute perplexity over both prompt and generation.
    min_k : float, optional
        Optional value between 0 and 1. If specified, only the bottom-k lowest-probability
        tokens are used for perplexity calculation.

    Returns
    --------
        np.ndarray: Per-sample perplexity scores of shape (batch_size,)

    Notes
    -----
    About token shifting in autoregressive models:

    In a standard autoregressive forward pass:
        - The logit at position *t* is computed from the tokens at positions 0 to *t*.
        - Thus, the logit at position *t* predicts the token at position *t+1*.
        - The first token has no preceding context and is not predicted.
        - When computing log-probabilities, we must **shift the targets one position to the left** 
        to correctly align logits with target tokens.
        
        Example: Suppose we have a sequence of tokens (with their token IDs):
        | Index | Token | ID  |
        |-------|-------|-----|
        | 0     | A     | 10  |
        | 1     | B     | 29  |
        | 2     | C     | 305 |
        | 3     | D     | 24  |
        - The model produces logits at positions 0, 1, and 2 to predict the tokens B, C, and D, respectively.
        - The logits at position 0 are used to predict token B (ID 29).

    During generation (e.g., using model.generate()):
        - The logit at time step *t* predicts the token generated at position *t*.
        - Each logit already corresponds to the prediction of the token at this step 
        - No shifting is needed in this case.

    Summary of alignment:
        - Prompt: logit at position *t* predicts token at position *t+1* -> shift targets left.
        - Generation: logit at position *t* predicts token at position *t* -> no shift.

    NOTE: help from issue https://github.com/huggingface/transformers/issues/29664
    """

    # Apply log-softmax over the vocabulary dimension to get log-probabilities
    # (numerically more stable than taking torch.log of a softmax output)
    prompt_log_probs = torch.log_softmax(prompt_logits, dim=-1)  # shape: (batch_size, prompt_len, vocab_size)
    gen_log_probs = torch.log_softmax(gen_logits, dim=-1)        # shape: (batch_size, gen_len, vocab_size)

    if mode in ("prompt", "promptGeneration"):
        # In prompt: logit at position t predicts token at t+1 (requires shifting)
        # Remove first token from target (no context to predict it)
        prompt_target_tokens = prompt_input_ids[:, 1:] # (batch_size, prompt_len - 1)
        
        # Remove last logit position (since it predicts next token)
        prompt_pred_log_probs = prompt_log_probs[:, :-1, :] # shape: (batch_size, prompt_len - 1, vocab_size)
        
        # Retrieves, for each position and each batch, the log-probability corresponding to the next token 
        # (the one in target_tokens) from all the probas on the vocabulary.
        prompt_token_log_probs = prompt_pred_log_probs.gather(
            dim=2, index=prompt_target_tokens.unsqueeze(-1)
            ).squeeze(-1) # shape: (batch_size, prompt_len - 1)
    
    if mode in ("generation", "promptGeneration"):
        # In generation: logit at position t predicts token at position t (no shift needed)
        gen_token_log_probs = gen_log_probs.gather(
            dim=2, index=gen_input_ids.unsqueeze(-1)
            ).squeeze(-1)  # shape: (batch_size, gen_len)
        
    if mode == "promptGeneration":
        # Last logit of prompt from the forward pass == first logit of generation from `model.generate()`. 
        # To compute perplexity over the full sequence:
        # - Use prompt_token_log_probs (excluding final prompt token)
        # - Use gen_token_log_probs from generation
        # Concatenate both to form a complete sequence of predicted log-probs
        prompt_gen_token_log_probs = torch.cat(
            [prompt_token_log_probs, gen_token_log_probs], dim=1
        )  # shape: (batch_size, prompt_len - 1 + gen_len)

    # Select the appropriate token log-probabilities based on mode
    if mode == "prompt":
        token_log_probs = prompt_token_log_probs  # (batch_size, prompt_len - 1)
    elif mode == "generation":
        token_log_probs = gen_token_log_probs # (batch_size, gen_len)
    elif mode == "promptGeneration":
        token_log_probs = prompt_gen_token_log_probs # (batch_size, prompt_len - 1 + gen_len)

    # Optionally focus only on the k% hardest tokens (lowest log-probs)
    if min_k is not None:
        # Keep only the min_k fraction of tokens with the lowest log-probs 
        k = max(1, int(min_k * token_log_probs.size(1)))  # number of tokens to keep per sample (at least 1)
        
        # Use topk with largest=False to get the k tokens with the lowest log-probabilities
        topk_vals, _ = torch.topk(token_log_probs, k=k, dim=1, largest=False)

        # Compute perplexity using only the selected subset
        ppls = torch.exp(-topk_vals.mean(dim=1))
    else:
        # Compute perplexity over all predicted tokens
        ppls = torch.exp(-token_log_probs.mean(dim=1))

    return ppls.cpu().numpy()
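
The function above takes attention masks as arguments but averages over every position. As a minimal sketch of how padding could be excluded, the pure-Python toy below computes perplexity over valid targets only. The helper name `masked_perplexity` and the right-padded toy sequence are assumptions made for illustration; they are not part of the notebook's code.

```python
import math

def masked_perplexity(token_log_probs, target_mask):
    """Perplexity over the targets whose mask is 1 (hypothetical helper)."""
    kept = [lp for lp, m in zip(token_log_probs, target_mask) if m == 1]
    if not kept:
        return float("nan")  # no valid target token in this sample
    return math.exp(-sum(kept) / len(kept))

# Toy prompt of 5 positions, right-padded with 2 padding tokens (assumption).
prompt_attention_mask = [1, 1, 1, 0, 0]

# Targets are the tokens at positions 1..4, so the mask is shifted the same way.
target_mask = prompt_attention_mask[1:]  # [1, 1, 0, 0]

# Log-probability assigned to each target token (the last two are padding targets).
token_log_probs = [math.log(0.5), math.log(0.25), math.log(0.9), math.log(0.9)]

ppl = masked_perplexity(token_log_probs, target_mask)
print(ppl)
```

Only the two valid targets contribute, so the result equals `exp(-(log 0.5 + log 0.25) / 2)`, which is the square root of 8. In a tensor version the same idea would be a masked sum divided by the per-sample count of valid targets.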


In [None]:
import torch
import numpy as np
from typing import List, Literal

def compute_attn_eig_prod(
    prompt_attentions: torch.Tensor,
    generation_attentions: List[torch.Tensor],
    attentions: List[torch.Tensor], 
    mode: Literal["prompt", "generation", "promptGeneration"] = "promptGeneration"
):
    """
    Compute a mean log-diagonal attention score (eigenvalue-inspired) for a single layer's attention map.
    NOTE: Implementation inspired from 
    "LLM-Check: Investigating Detection of Hallucinations in Large Language Models"
    (Sriramanan et al. 2024)

    Parameters
    ----------
    attentions : list of torch.Tensor: [attn_prompt, attn_gen1, attn_gen2, ...]
        - attentions[0]: Tensor of shape (batch_size, n_heads, prompt_len, prompt_len)
            Self-attention over the prompt tokens. 
        - attentions[1:]: List of tensors of shape (batch_size, n_heads, 1, prompt_len + t)
            Self-attention for each generated token at generation step t (t >= 1).
    mode : str, optional
        Specifies which part of the attention map to use for the score computation.
        Must be one of the following:
        - "prompt":
            Only uses the prompt self-attention matrix (attentions[0]).
            The diagonal (i.e., self-attention values per token) is extracted,
            then the log is taken, followed by a mean over prompt tokens and sum over heads.
        - "generation":
            Only uses the generated self-attention maps (attentions[1:]).
            Each tensor in attentions[1:] has shape (batch_size, n_heads, 1, prompt_len + t),
            where t is the generation step. 
            Instead of concatenating these tensors to obtain the generation attention matrix, 
            for each step, we directly take the last value along the last axis (i.e., the self-attention
            of the newly generated token). These values are stacked across time steps, then we take the log,
            compute the mean over time, and sum over heads.
        - "promptGeneration":
            Combines the diagonals from both the prompt and generation attention maps as described above
            for "prompt" and "generation" mode. The two diagonals are concatenated along the token/time axis, 
            then the log is taken, followed by a mean across all tokens and a sum over heads.
            Note: we do **not** concatenate the full prompt and generation attention matrices,
            since the diagonal of the combined matrix would only include values from the prompt attention
            due to mismatched matrix shapes.

    Returns
    -------
    np.ndarray
        A NumPy array of shape (batch_size,), where each value is the per-sample attention score.
        The score is summed across heads and averaged across tokens (in log-space).
    """
    assert mode in ("prompt", "generation", "promptGeneration"), "Invalid mode."
    
    prompt_attentions = attentions[0]
    gen_attentions = attentions[1:]

    batch_size, n_heads = prompt_attentions.shape[:2]
    n_generated = len(attentions) - 1

    diag_blocks = []

    if mode in ("prompt", "promptGeneration"):
        prompt_diag = torch.diagonal(prompt_attentions, dim1=-2, dim2=-1) # (batch_size, n_heads, prompt_len)
        diag_blocks.append(prompt_diag)
        print("prompt_diag.shape", prompt_diag.shape)

    if mode in ("generation", "promptGeneration") and n_generated > 0:
        # For each generation step, take the (batch_size, n_heads, 1, prompt_len + t)
        # => keep last value of last axis (self-attention of generated token)
        gen_diag_steps = [attn[..., -1].squeeze(-1) for attn in gen_attentions]  # list of (batch_size, n_heads)
        gen_debug = [attn[..., -1] for attn in gen_attentions] 
        print("len(gen_diag_steps): ", len(gen_diag_steps))
        print("gen_diag_steps[0].shape:", gen_diag_steps[0].shape)
        print("gen_debug[0].shape:", gen_debug[0].shape)
        # Stack along newly generated tokens
        gen_diag = torch.stack(gen_diag_steps, dim=-1) if gen_diag_steps else None # (batch_size, n_heads, n_generated)
        if gen_diag is not None:
            diag_blocks.append(gen_diag)
            print("gen_diag.shape: ", gen_diag.shape)

    # Now concatenate along token axis
    all_diags = torch.cat(diag_blocks, dim=-1) # (batch_size, n_heads, N) where N = prompt_len + n_generated (or a subset)
    print("all_diags.shape:", all_diags.shape)

    # Take log, mean over N tokens, sum over heads
    all_diags = all_diags.clamp(min=1e-6)
    log_diag = torch.log(all_diags).mean(dim=-1) # (batch_size, n_heads)
    print("log_diag.shape: ", log_diag.shape)
    scores = log_diag.sum(dim=-1).cpu().numpy()  # sum over n_heads
    print("scores.shape: ", scores.shape)

    return scores # (batch_size,)
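
The score above is described as eigenvalue-inspired. One way to see why: a causal attention map is lower-triangular, and the determinant of a triangular matrix is the product of its diagonal entries, so the sum of log-diagonal values equals the log-determinant. The toy check below illustrates this for a single head; `mean_log_diag` is a hypothetical helper and the 3x3 matrix is made up for the example.

```python
import math

def mean_log_diag(attn, eps=1e-6):
    """Mean of log(diagonal) for one head's square attention matrix (hypothetical helper)."""
    diag = [max(attn[i][i], eps) for i in range(len(attn))]  # clamp as in the function above
    return sum(math.log(d) for d in diag) / len(diag)

# Causal (lower-triangular) attention over 3 tokens; each row sums to 1.
attn = [
    [1.0, 0.0, 0.0],
    [0.5, 0.5, 0.0],
    [0.25, 0.25, 0.5],
]

score = mean_log_diag(attn)

# Determinant of a triangular matrix = product of its diagonal entries.
det = attn[0][0] * attn[1][1] * attn[2][2]
print(score, math.log(det) / len(attn))
```

Both printed values agree, which is why taking only the diagonal is enough and no explicit eigendecomposition is needed.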

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

def compute_logit_entropy(
    prompt_logits: torch.Tensor,
    gen_logits: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    gen_attention_mask: torch.Tensor,
    mode: str = "prompt",
    top_k: int = None
) -> np.ndarray:
    """
    Computes the per-sample entropy of a language model's output distributions
    using its logits and attention masks.
    For each token position, the function computes the entropy of the softmax distribution
    over the vocabulary. Entropy is averaged over the valid tokens (i.e., those marked
    as 1 in the attention mask). If `top_k` is specified, the entropy is computed only
    over the top-k logits (highest values) for each position.

    Entropy is defined as:
        Entropy = -Sum_i p_i * log(p_i)
        where p_i = softmax(logits)_i

    Parameters
    ----------
    prompt_logits : torch.Tensor
        Tensor of shape (batch_size, prompt_len, vocab_size).
    gen_logits : torch.Tensor
        Tensor of shape (batch_size, gen_len, vocab_size).
    prompt_attention_mask : torch.Tensor
        Tensor of shape (batch_size, prompt_len). Contains 1 where valid token, 0 for padding.
    gen_attention_mask : torch.Tensor  
        Tensor of shape (batch_size, gen_len). Contains 1 where valid token, 0 for padding.
    mode : str, optional
        One of {"prompt", "generation", "promptGeneration"}:
        - "prompt": compute entropy only over the prompt tokens.
        - "generation": compute entropy only over generated tokens.
        - "promptGeneration": compute entropy over both prompt and generated tokens.
    top_k : int, optional
        If specified, compute entropy only over the top-k logits per token.

    Returns
    -------
    np.ndarray
        Per-sample entropy values, shape (batch_size,).
    """

    def entropy_from_logits(logits, attention_mask, top_k=None):
        print(f"[DEBUG] Computing entropy from logits of shape {logits.shape}")
        print(f"[DEBUG] Attention mask shape: {attention_mask.shape}")

        # Convert float16 -> float32 for better accuracy during computations
        logits = logits.float()
        attention_mask = attention_mask.float()

        # ADDED: check the input logits
        print(f"[DEBUG] Logits sample (batch0, first 5 tokens, first 10 vocab): {logits[0, :5, :10]}")
        print(f"[DEBUG] Logits min/max: {logits.min()}, {logits.max()}")
        print(f"[DEBUG] Attention mask sample (batch0, first 10): {attention_mask[0, :10]}")

        if top_k is not None:
            topk_vals = torch.topk(logits, k=top_k, dim=-1).values  # (batch_size, seq_len, top_k)
            print(f"[DEBUG] Selected top_k={top_k} logits shape: {topk_vals.shape}")
            probs = F.softmax(topk_vals, dim=-1)  # (batch_size, seq_len, top_k)
        else:
            probs = F.softmax(logits, dim=-1)  # (batch_size, seq_len, vocab_size)
            print(f"[DEBUG] Softmax probabilities shape: {probs.shape}")

        # ADDED: check the probabilities
        print(f"[DEBUG] Probs sample (batch0, first 5 tokens, first 10 vocab): {probs[0, :5, :10]}")
        print(f"[DEBUG] Probs sum per token (should be ~1): {probs[0, :5].sum(dim=-1)}")

        epsilon = 1e-12  # Smaller epsilon than 1e-9
        log_probs = torch.log(probs + epsilon)  # numerical stability
        print(f"[DEBUG] probs sample values (batch0, first 5 tokens): {probs[0, :5]}")
        print(f"[DEBUG] log_probs sample values (batch0, first 5 tokens): {log_probs[0, :5]}")
        # ADDED: check the log_probs
        print(f"[DEBUG] Log_probs sample (batch0, first 5 tokens, first 10 vocab): {log_probs[0, :5, :10]}")
        print(f"[DEBUG] Log_probs min/max: {log_probs.min()}, {log_probs.max()}")

        product = (probs * log_probs)
        print(f"[DEBUG] product per token shape: {product.shape}")
        print(f"HERE [DEBUG] product sample values (batch0, first 5 tokens): {product[0, :5]}")
        #entropy = -(probs* log_probs).sum(dim=-1)  # (batch_size, seq_len)

        # Use torch.special.entr, which automatically handles edge cases.
        # entropy(x) = -x * log(x) with entropy(0) = 0
        entropy = torch.special.entr(probs).sum(dim=-1) # (batch_size, seq_len)

        print(f"[DEBUG] Entropy per token shape: {entropy.shape}")
        print(f"HERE [DEBUG] Entropy sample values (batch0, first 5 tokens): {entropy[0, :5]}")
        # ADDED: check the entropy before masking
        print(f"[DEBUG] Entropy before masking (batch0, first 5): {entropy[0, :5]}")
        print(f"[DEBUG] Entropy min/max: {entropy.min()}, {entropy.max()}")
        print(f"[DEBUG] NaN count in entropy: {torch.isnan(entropy).sum()}")
    
        entropy_masked = entropy * attention_mask  # Zero out padded tokens
        print(f"[DEBUG] entropy_masked per token shape: {entropy_masked.shape}")
        print(f"[DEBUG] entropy_masked sample values (batch0, first 5 tokens): {entropy_masked[0, :5]}")
        total_entropy = entropy_masked.sum(dim=-1)  # sum over seq_len, (batch_size,)
        valid_token_count = attention_mask.sum(dim=-1)  # (batch_size,)
        print(f"[DEBUG] Total entropy per sample: {total_entropy}")
        print(f"[DEBUG] Valid token counts per sample: {valid_token_count}")

        return total_entropy.cpu().numpy(), valid_token_count.cpu().numpy()

    if mode == "prompt":
        total_entropy, count = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k)
        avg_entropy = total_entropy / (count + 1e-9)
        print(f"[INFO] Mode 'prompt': avg_entropy shape {avg_entropy.shape}")
        print(f"[INFO] Sample avg_entropy: {avg_entropy}")
        return avg_entropy

    elif mode == "generation":
        total_entropy, count = entropy_from_logits(gen_logits, gen_attention_mask, top_k)
        avg_entropy = total_entropy / (count + 1e-9)
        print(f"[INFO] Mode 'generation': avg_entropy shape {avg_entropy.shape}")
        print(f"[INFO] Sample avg_entropy: {avg_entropy}")
        return avg_entropy

    elif mode == "promptGeneration":
        ent_prompt, count_prompt = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k)
        ent_gen, count_gen = entropy_from_logits(gen_logits, gen_attention_mask, top_k)
        total_ent = ent_prompt + ent_gen
        total_count = count_prompt + count_gen
        avg_entropy = total_ent / (total_count + 1e-9)
        print(f"[INFO] Mode 'promptGeneration': avg_entropy shape {avg_entropy.shape}")
        print(f"[INFO] Sample avg_entropy: {avg_entropy}")
        return avg_entropy

    else:
        raise ValueError(f"Unknown mode: {mode}. Must be 'prompt', 'generation' or 'promptGeneration'")


In [None]:
def compute_logit_entropy(
    prompt_logits: torch.Tensor,
    gen_logits: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    gen_attention_mask: torch.Tensor,
    mode: str = "prompt",
    top_k: int = None,
    window_size: int = None,
    stride: int = None
) -> np.ndarray:
    def entropy_from_logits(logits, attention_mask, top_k=None):
        print(f"\n[DEBUG] Input logits shape: {logits.shape}")
        print(f"[DEBUG] Input logits: {logits}")
        print(f"[DEBUG] Input attention_mask shape: {attention_mask.shape}")
        print(f"[DEBUG] Input attention_mask: {attention_mask}")
        
        logits = logits.float()
        attention_mask = attention_mask.float()
        
        # Should we mask before selecting the top-k? No, because top-k operates on the vocab_size dimension.

        if top_k is not None:
            print(f"[DEBUG] top_k activated: {top_k}")
            topk_vals = torch.topk(logits, k=top_k, dim=-1).values
            print(f"[DEBUG] topk_vals:\n{topk_vals}")
            probs = F.softmax(topk_vals, dim=-1)
        else:
            print(f"[DEBUG] Using full softmax")
            probs = F.softmax(logits, dim=-1)

        entropy = torch.special.entr(probs).sum(dim=-1)
        print(f"[DEBUG] Entropy shape: {entropy.shape}")
        print(f"[DEBUG] Entropy example:\n{entropy}")
        return entropy, attention_mask

    def average_entropy(entropy, mask):
        print("\n[DEBUG] === AVERAGE ENTROPY ===")
        entropy_masked = entropy * mask
        print(f"[DEBUG] entropy_masked: {entropy_masked}")
        total_entropy = entropy_masked.sum(dim=-1)
        valid_count = mask.sum(dim=-1)
        print(f"[DEBUG] total_entropy: {total_entropy}")
        print(f"[DEBUG] valid_count: {valid_count}")
        avg_entropy = total_entropy / (valid_count + 1e-9)
        print(f"[DEBUG] avg_entropy: {avg_entropy}")
        return avg_entropy

    def max_sliding_window_entropy(entropy, mask, w, stride):
        print("\n[DEBUG] === SLIDING WINDOW ENTROPY ===")
        entropy = entropy.unsqueeze(1)  # (B, 1, T)
        mask = mask.unsqueeze(1)        # (B, 1, T)
        print(f"[DEBUG] entropy: \n{entropy}")
        print(f"[DEBUG] mask: \n{mask}")
        kernel = torch.ones(1, 1, w, device=entropy.device) / w
        print(f"[DEBUG] kernel: \n{kernel}")
        moving_avg = F.conv1d(entropy, kernel, stride=stride, padding=0)
        valid_counts = F.conv1d(mask, kernel, stride=stride, padding=0)
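        # A window is valid only if all of its tokens are valid (mean of the mask is 1).
        # Use isclose rather than ==, since w * (1 / w) is not always exactly 1.0 in floating point.
        valid_mask = torch.isclose(valid_counts, torch.ones_like(valid_counts))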

        print(f"[DEBUG] moving_avg shape: {moving_avg.shape}")
        print(f"[DEBUG] moving_avg :\n{moving_avg}")
        print(f"[DEBUG] valid_counts:\n{valid_counts}")
        print(f"[DEBUG] valid_mask:\n{valid_mask}")
        print(f"[DEBUG] moving_avg (before masking):\n{moving_avg}")
        
        moving_avg = moving_avg.masked_fill(~valid_mask, float('-inf'))

        print(f"[DEBUG] moving_avg (after masking):\n{moving_avg}")
        max_avg_entropy, _ = moving_avg.max(dim=-1)
        return max_avg_entropy.squeeze(1)

    print("\n[DEBUG] compute_logit_entropy called")
    print(f"[DEBUG] mode: {mode}, top_k: {top_k}, window_size: {window_size}, stride: {stride}")
    
    if top_k is not None:
        top_k = int(top_k)
        if top_k <= 0 or top_k > prompt_logits.shape[2]:
            raise ValueError("top_k must be a positive integer less or equal to vocab size")

    if window_size is not None:
        if stride is None:
            stride = window_size
        else:
            stride = int(stride)
            if stride <= 0 or stride > window_size:
                raise ValueError("stride must be a positive integer <= window_size.")
    else:
        stride = None

    if mode == "prompt":
        entropy, mask = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k)
    elif mode == "generation":
        entropy, mask = entropy_from_logits(gen_logits, gen_attention_mask, top_k)
    elif mode == "promptGeneration":
        ent_p, mask_p = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k)
        ent_g, mask_g = entropy_from_logits(gen_logits, gen_attention_mask, top_k)
        entropy = torch.cat([ent_p, ent_g], dim=1)
        mask = torch.cat([mask_p, mask_g], dim=1)
        print(f"[DEBUG] After concat: entropy shape = {entropy.shape}, mask shape = {mask.shape}")
    else:
        raise ValueError("mode must be in {'prompt','generation','promptGeneration'}")

    if window_size is None:
        result = average_entropy(entropy, mask)
    else:
        if window_size <= 0:
            raise ValueError("window_size must be a positive integer")
        if window_size > entropy.shape[1]:
            raise ValueError("window_size greater than sequence length")
        result = max_sliding_window_entropy(entropy, mask, window_size, stride)

    print(f"\n[DEBUG] Final entropy per sample: {result}")
    return result.cpu().numpy()


import torch
import numpy as np
import torch.nn.functional as F

# Dimensions
batch_size = 2
prompt_len = 5
gen_len = 4
vocab_size = 3

# Random logits
torch.manual_seed(42)
prompt_logits = torch.randn(batch_size, prompt_len, vocab_size)
gen_logits = torch.randn(batch_size, gen_len, vocab_size)

# Attention masks with padding (0 = padding)
prompt_mask = torch.tensor([
    [0, 0, 1, 1, 1],  # 3 valid tokens
    [1, 1, 1, 1, 1],  # all valid
], dtype=torch.float32)

gen_mask = torch.tensor([
    [1, 1, 0, 0],     # 2 valid tokens
    [1, 1, 1, 1],     # all 4 valid
], dtype=torch.float32)

# Call with top_k and a sliding window
result = compute_logit_entropy(
    prompt_logits=prompt_logits,
    gen_logits=gen_logits,
    prompt_attention_mask=prompt_mask,
    gen_attention_mask=gen_mask,
    mode="promptGeneration",
    top_k=2,
    window_size=1,
    stride=1
)

print("\n[TEST RESULT] Entropy scores:\n", result)
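
To make the sliding-window behaviour easier to check by hand, here is a minimal pure-Python sketch of the same idea for a single sample: average the per-token entropies over each window, skip any window that contains a padded position, and return the maximum. The function name `max_window_mean` and the toy values are assumptions for illustration only.

```python
def max_window_mean(values, mask, window, stride):
    """Max over windows of the mean value, ignoring windows that touch padding (hypothetical helper)."""
    best = float("-inf")
    for start in range(0, len(values) - window + 1, stride):
        window_mask = mask[start:start + window]
        if all(m == 1 for m in window_mask):  # keep fully valid windows only
            mean = sum(values[start:start + window]) / window
            best = max(best, mean)
    return best

# Per-token entropies for one sample; the last position is padding.
entropies = [0.2, 0.9, 0.7, 0.1, 5.0]
mask = [1, 1, 1, 1, 0]

result = max_window_mean(entropies, mask, window=2, stride=1)
print(result)
```

The window covering the last two positions has the largest mean but is discarded because it includes a padded token, so the maximum comes from the window over 0.9 and 0.7. Counting valid tokens with integers, as done here, also avoids comparing a floating-point average of the mask against exactly 1.0.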


In [None]:
import torch
import torch.nn.functional as F
import numpy as np

def compute_logit_entropy(
    prompt_logits: torch.Tensor,
    gen_logits: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    gen_attention_mask: torch.Tensor,
    mode: str = "prompt",
    top_k: int = None
) -> np.ndarray:
    """
    Computes the per-sample entropy of a language model's output distributions
    using its logits and attention masks.
    For each token position, the function computes the entropy of the softmax distribution
    over the vocabulary. Entropy is averaged over the valid tokens (i.e., those marked
    as 1 in the attention mask). If `top_k` is specified, the entropy is computed only
    over the top-k logits (highest values) for each position.

    Entropy is defined as:
        Entropy = -Sum_i p_i * log(p_i)
        where p_i = softmax(logits)_i

    Parameters
    ----------
    prompt_logits : torch.Tensor
        Tensor of shape (batch_size, prompt_len, vocab_size).
        These are the model's output logits obtained from a standard forward pass over the prompt sequence.
    gen_logits : torch.Tensor
        Tensor of shape (batch_size, gen_len, vocab_size).
        These are the logits obtained during autoregressive decoding using `model.generate()`.
    prompt_attention_mask : torch.Tensor
        Tensor of shape (batch_size, prompt_len). Contains 1 where the token is valid and 0 for padding.
    gen_attention_mask : torch.Tensor  
        Tensor of shape (batch_size, gen_len). Contains 1 where the token is valid and 0 for padding.
    mode : str, optional
        One of {"prompt", "generation", "promptGeneration"}:
        - "prompt": compute entropy only over the prompt tokens.
        - "generation": compute entropy only over the generated tokens.
        - "promptGeneration": compute entropy over both prompt and generated tokens.
          In this case, entropies are summed over all valid tokens and averaged globally.
    top_k : int, optional
        If specified, compute entropy only over the top-k logits per token.
        Useful for estimating uncertainty in the most likely predictions.

    Returns
    -------
    np.ndarray
        Per-sample entropy values as an array of shape (batch_size,).
        Each value is the average entropy over valid tokens for that sample.
    """

    def entropy_from_logits(logits, attention_mask, top_k=None):
        """
        Parameters
        ---------
        logits: (batch_size, seq_len, vocab_size)
        attention_mask: (batch_size, seq_len)
        
        Returns
        -------
        total_entropy: (batch_size,)
        valid_token_count: (batch_size,)
        """
        # Convert float16 -> float32 for better accuracy during computations
        logits = logits.float()
        attention_mask = attention_mask.float()

        if top_k is not None:
            topk_vals = torch.topk(logits, k=top_k, dim=-1).values # (batch_size, seq_len, top_k)
            probs = F.softmax(topk_vals, dim=-1) # (batch_size, seq_len, top_k)
        else:
            probs = F.softmax(logits, dim=-1) # (batch_size, seq_len, vocab_size)

        # Use torch.special.entr, which automatically handles edge cases
        # entropy(x) = -x * log(x) with entropy(0) = 0
        entropy = torch.special.entr(probs).sum(dim=-1) # (batch_size, seq_len)

        entropy_masked = entropy * attention_mask       # (batch_size, seq_len)
        total_entropy = entropy_masked.sum(dim=-1)      # (batch_size,)
        valid_token_count = attention_mask.sum(dim=-1)  # (batch_size,)

        return total_entropy.cpu().numpy(), valid_token_count.cpu().numpy()  # both are (batch_size,)
    
    if top_k is not None:
        top_k = int(top_k)
        if top_k < 1 or top_k > prompt_logits.shape[2]:
            raise ValueError("top_k must be an integer between 1 and vocab_size")

    if mode == "prompt":
        total_entropy, count = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k)
        return total_entropy / (count + 1e-9) # (batch_size,)

    elif mode == "generation":
        total_entropy, count = entropy_from_logits(gen_logits, gen_attention_mask, top_k)
        return total_entropy / (count + 1e-9) # (batch_size,)

    elif mode == "promptGeneration":
        # Combine prompt and gen entropies
        ent_prompt, count_prompt = entropy_from_logits(prompt_logits, prompt_attention_mask, top_k)
        ent_gen, count_gen = entropy_from_logits(gen_logits, gen_attention_mask, top_k)

        total_ent = ent_prompt + ent_gen         # (batch_size,)
        total_count = count_prompt + count_gen   # (batch_size,)
        return total_ent / (total_count + 1e-9)  # (batch_size,)

    else:
        raise ValueError(f"Unknown mode: {mode}. Must be 'prompt', 'generation' or 'promptGeneration'")
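
The masked, averaged entropy described above can be sanity-checked on a toy input. The snippet below is a minimal sketch, not the notebook's function: `masked_mean_entropy` is a hypothetical standalone helper that mirrors the logic of `entropy_from_logits`, applied to uniform logits where the expected entropy is known in closed form (log of the number of candidate tokens).

```python
import torch
import torch.nn.functional as F

def masked_mean_entropy(logits, attention_mask, top_k=None):
    """Hypothetical helper: average per-token entropy over valid (unmasked) tokens."""
    logits = logits.float()
    mask = attention_mask.float()
    if top_k is not None:
        # Keep only the top-k logits per token before normalising
        logits = torch.topk(logits, k=top_k, dim=-1).values
    probs = F.softmax(logits, dim=-1)
    entropy = torch.special.entr(probs).sum(dim=-1)  # (batch_size, seq_len)
    return (entropy * mask).sum(dim=-1) / (mask.sum(dim=-1) + 1e-9)

# Uniform logits over a toy vocabulary of 8 tokens
logits = torch.zeros(2, 3, 8)
# Second sample has two padded positions, which must not affect the average
mask = torch.tensor([[1, 1, 1], [1, 0, 0]])

full = masked_mean_entropy(logits, mask)            # expected: log(8) for each sample
top4 = masked_mean_entropy(logits, mask, top_k=4)   # expected: log(4) for each sample
print(full, top4)
```

Because padded positions are multiplied by zero and excluded from the count, both samples yield the same average even though they have different numbers of valid tokens.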
    

In [None]:
# TODO: add an argument to specify which token aggregation is retrieved

def run_prompt_and_generation_activation_extraction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    batch_size: int = 4,
    idx_start_sample: int = 0,
    max_samples: int = 1000,
    save_to_pkl: bool = False,
    output_path: str = "outputs/all_batch_results.pkl",
    build_prompt_fn: Callable[[str, str], str] = None,
    register_generation_activation_hook_fn: Callable = None,
    layers: List[int] = [-1],  
    extract_token_activations_fn: Callable = None,
    activation_source: Literal["prompt", "generation", "promptGeneration"] = "generation",
    k_beams : int = 1,
    start_offset : int = 0,
    end_offset : int = 0,
) -> Union[List[torch.Tensor], None]:
    """
    Runs batched inference on a dataset using a decoder-only language model.
    For each batch, it performs text generation and extracts token-level hidden activations 
    (both from the prompt and the generated text depending on `activation_source`) 
    from specified transformer layers.

    Hidden states are captured via a forward hook during generation, then aligned and 
    filtered using attention masks. 
    These activations are saved as individual batch files in a specified pickle directory, 
    allowing efficient incremental storage and later aggregation.
    Alternatively, the representations can be returned directly.

    Parameters
    ----------
    model : PreTrainedModel
        The causal language model to evaluate (e.g., LLaMA).
    tokenizer : PreTrainedTokenizer
        The corresponding tokenizer.
    dataset : Dataset
        The input dataset.
    batch_size : int
        Number of samples per batch.
    idx_start_sample : int
        Index of the first sample to process from the dataset.
    max_samples : int
        Total number of examples to process from the dataset, starting from idx_start_sample. 
    save_to_pkl : bool
        If True, activations are appended to the pickle file at output_path.
        If False, the function returns a list of activations.
    output_path : str
        Path to the directory where extracted answers will be saved as individual pickle batch files.
    build_prompt_fn : Callable
        Function to build a prompt from context and question.
    register_generation_activation_hook_fn : Callable
        Function that registers a forward hook on the model during autoregressive text generation.
    layers : List[int]
        List of indices of the transformer layers to extract activations from (default: [-1] for last layer).
    extract_token_activations_fn : Callable
        Function that selects and aggregates token-level activations. 
    activation_source : {"prompt", "generation", "promptGeneration"}
        Which part of the sequence to extract activations from:
        - "prompt": only from the prompt
        - "generation": only from the generated answer
        - "promptGeneration": prompt and generation answer both concatenated
    k_beams : int, optional
        Number of beams for beam search during generation (default: 1). If 1, uses sampling. 
    start_offset : int
        Offset from the first non-padding token (must be >= 0). 
    end_offset : int
        Offset from the last non-padding token (must be <= 0, e.g., -3 to remove 3 tokens).
    
    Returns
    -------
    Union[
        List[torch.Tensor],
        None
    ]
        If save_to_pkl is False 
            Returns batch_activations: list of length `num_samples`, each element is a tensor 
            of shape (1, hidden_size), containing the selected and aggregated token activations.
        If save_to_pkl is True:
            Returns None (activations are saved incrementally to output_path).
    """

    hidden_scores = ["average", "last", "max", "first_generated", "token_svd_score", "feat_var"]
    attn_scores  = ["attn_eig_prod"]
    logit_scores = ["perplexity", "logit_entropy", "window_logit_entropy"]  # TODO: pass the matching config?

    if activation_source not in ('prompt', 'generation', 'promptGeneration'):
        raise ValueError(
                f"Invalid value for `activation_source`: '{activation_source}'. "
                f"Expected one of: ['prompt', 'generation', 'promptGeneration']."
            )    
        
    batch_activations = []  # Chosen token activation vectors

    # ==============================
    # Patch selected layer(s) with custom LlamaAttention Forward function to retrieve attention weights
    # ==============================
    for idx in layers:  
        model.model.layers[idx].self_attn.forward = patched_LlamaAttention_forward.__get__(
            model.model.layers[idx].self_attn,
            model.model.layers[idx].self_attn.__class__
    )

    for i in tqdm(range(idx_start_sample, idx_start_sample + max_samples, batch_size)):
        
        # ==============================
        # Prepare input batch
        # ==============================
        batch = extract_batch(dataset, i, batch_size)
        prompts = [build_prompt_fn(s["context"], s["question"]) for s in batch]
        inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        prompt_ids = inputs["input_ids"] # (batch_size, prompt_len)
        prompt_len = prompt_ids.shape[1] # Assumes prompts are padded to same length

        print(f"[INFO] prompt_ids shape: {prompt_ids.shape}, prompt_len: {prompt_len}")

        # ==============================
        # Register forward hook to capture layer output
        # ==============================
        # This hook collects the hidden states at each decoding step. For layer l: 
        # activations_lists[l] = [act_prompt, act_gen_step1, ..., act_gen_step49] of length 50, if max_new_tokens=50.
        # activations_lists[l][k] of shape: (batch_size, seq_len, hidden_size) 
        activations_lists = [[] for _ in layers]  # one empty list per layer 
        handle_act, call_counter_act = register_generation_activation_hook(model, activations_lists, layers)

        # This hook collects the attention maps at each decoding step. For layer l: 
        # attentions_lists[l] = [attn_prompt, attn_gen_step1, ..., attn_gen_step49], of length 50, if max_new_tokens=50.
        # attentions_lists[l][k] of shape: (batch_size, n_heads, tgt_seq_len, src_seq_len)
        #   tgt_seq_len: length of the sequence the model is currently producing (query)
        #   src_seq_len: length of the sequence the model is focusing on (key/value)
        attentions_lists = [[] for _ in layers]  # one empty list per layer
        handle_attn, call_counter_attn = register_generation_attention_hook(model, attentions_lists, layers)

        # ==============================
        # Run model generation (hooks capture activations and attentions)
        # ==============================
        # When the target layers are reached, the hooks execute and save their outputs
        # in activations_lists and attentions_lists.
        print("[INFO] Starting generation...")
        outputs = generate(model, inputs, tokenizer, max_new_tokens=50, k_beams=k_beams)
        gen_ids = outputs.sequences[:, prompt_len:]
        print(f"[INFO] gen_ids shape: {gen_ids.shape}")
        print(f"[INFO] Sample generated tokens: {gen_ids}")

        # Remove hooks to avoid memory leaks or duplicate logging
        for h in handle_act: h.remove()
        for h in handle_attn: h.remove()

        # ==============================
        # Retrieve generation logits and run a forward pass for the prompt logits
        # ==============================
        # Done after generation (`outputs` must exist) and after the hooks are removed,
        # so that this extra forward pass does not append to the captured lists.
        if len(logit_scores) > 0:
            # Assumes `generate` returns per-step logits (output_logits=True)
            gen_logits = torch.stack(outputs.logits, dim=1)  # (batch_size, gen_len, vocab_size)
            with torch.no_grad():
                prompt_logits = model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                ).logits  # (batch_size, prompt_len, vocab_size)
        
        # Verify that hooks worked properly
        verify_call_counters(call_counter_act, name="activation hooks")
        verify_call_counters(call_counter_attn, name="attention hooks")

        # Retrieve text of generated answers
        gen_answers = tokenizer.batch_decode(
            outputs.sequences[:, prompt_len:], 
            skip_special_tokens=True
        ) # (batch_size,)
        print(f"[INFO] Sample generated answer: {gen_answers}")
        
        # ===============================
        # Build generation and prompt attention mask
        # ===============================
        # This mask marks which generated tokens are valid (i.e., not padding).
        # Positions are marked True up to and including the first eos_token_id
        generation_attention_mask = build_generation_attention_mask(
            gen_ids=gen_ids, 
            eos_token_id=tokenizer.eos_token_id
        ) # (batch_size, gen_len)

        prompt_attention_mask = inputs["attention_mask"] 
        # (batch_size, prompt_len)

        print(f"[INFO] prompt_attention_mask: {prompt_attention_mask.shape}")
        print(f"[INFO] generation_attention_mask: {generation_attention_mask.shape}")
        print(f"[INFO] generation_attention_mask: {generation_attention_mask}")

        # Modify prompt attention mask with offsets
        if start_offset !=0 or end_offset !=0:
            print(f"[INFO] Offsetting prompt_attention_mask with start={start_offset}, end={end_offset}")
            prompt_attention_mask, start_indices, end_indices = compute_offset_attention_mask(
                attention_mask=prompt_attention_mask, 
                start_offset=start_offset, 
                end_offset=end_offset
            ) # (batch_size, prompt_len), (batch_size,), (batch_size,)

        print(f"[INFO] New prompt_attention_mask shape: {prompt_attention_mask.shape}")
        print(f"[INFO] New prompt_attention_mask : {prompt_attention_mask}")

        # Concatenate the prompt and generation attention mask
        prompt_and_gen_attention_mask = torch.cat(
            [prompt_attention_mask,
            generation_attention_mask],
            dim=1
        ) # (batch_size, prompt_len + gen_len)

        print(f"[INFO] prompt_and_gen_attention_mask shape: {prompt_and_gen_attention_mask.shape}")
        print(f"[INFO] prompt_and_gen_attention_mask : {prompt_and_gen_attention_mask}")

        # ===============================
        # Truncate generated token IDs and mask to match activations and attentions
        # ===============================
        # When N tokens are generated, only the first N-1 tokens have corresponding hidden states.
        # So activations[1:] covers only the first N-1 steps. Therefore, we exclude the last
        # generated token from outputs.sequences to match activations[1:]. Same for attentions.
        truncated_gen_ids = gen_ids[:,:-1] # (batch_size, gen_len-1)
        truncated_generation_attention_mask = generation_attention_mask[:,:-1] # (batch_size, gen_len-1)
        truncated_prompt_and_gen_attention_mask = prompt_and_gen_attention_mask[:,:-1] # (batch_size, prompt_len + gen_len-1)

        print(f"[INFO] Truncated gen_ids shape: {truncated_gen_ids.shape}")
        print(f"[INFO] Truncated generation_attention_mask shape: {truncated_generation_attention_mask.shape}")
        print(f"[INFO] Truncated prompt_and_gen_attention_mask shape: {truncated_prompt_and_gen_attention_mask.shape}")
        print(f"[INFO] Truncated gen_ids : {truncated_gen_ids}")
        print(f"[INFO] Truncated generation_attention_mask : {truncated_generation_attention_mask}")
        print(f"[INFO] Truncated prompt_and_gen_attention_mask : {truncated_prompt_and_gen_attention_mask}")

        # *******************************
        # START: loop on layers
        # *******************************
        for l in range(len(layers)):
            layer_idx = layers[l]
            print(f"\n----- Layer {layer_idx} -----")

            activations = activations_lists[l]
            attentions = attentions_lists[l]

            print("============")
            print("[INFO] Length of activations:", len(activations))
            # Use `step` rather than `i` to avoid shadowing the batch index of the outer loop
            for step in range(len(activations)):
                print(f"[INFO] Shape  of activations[{step}]: {activations[step].shape}") 
            print("============")
            print("[INFO] Length of attentions:", len(attentions))
            for step in range(len(attentions)):
                print(f"[INFO] Shape  of attentions[{step}]: {attentions[step].shape}") 
            print("============")

            # Define prompt and generation hidden states 
            prompt_activations=activations[0]       # `[0]` to include only the prompt part 
            generation_activations=activations[1:]  # `[1:]` to exclude the prompt part 
            
            # Define prompt and generation attention maps
            prompt_attentions=attentions[0]         # `[0]` to include only the prompt part 
            generation_attentions=attentions[1:]    # `[1:]` to exclude the prompt part 

            print(f"[DEBUG] prompt_activations shape: {prompt_activations.shape}")
            print(f"[DEBUG] prompt_attentions shape: {prompt_attentions.shape}")

            # ===============================
            # Align generated and prompt hidden states
            # ===============================
            # For each batch item, take the last generated hidden state at this step
            stacked_generation_activations = torch.stack(
                [h[:, -1, :] for h in generation_activations], dim=1
            ) # (batch_size, gen_len, hidden_size)

            print(f"[DEBUG] stacked_generation_activations shape: {stacked_generation_activations.shape}")

            # Concatenate the prompt and generation aligned hidden states.
            # The prompt comes first so that the order matches prompt_and_gen_attention_mask.
            prompt_and_gen_activations = torch.cat(
                [prompt_activations,              # (batch_size, prompt_len, hidden_size)
                stacked_generation_activations],  # (batch_size, gen_len, hidden_size)
                dim=1
            ) # (batch_size, prompt_len + gen_len, hidden_size)
            
            print(f"[DEBUG] prompt_and_gen_activations shape: {prompt_and_gen_activations.shape}")

            # ==============================
            # Extract token activations from captured layer, based on source
            # ==============================
            print(f"[INFO] Activation source: {activation_source}")

            if len(hidden_scores) > 0:
                if activation_source == "generation":
                    # Return only the token activations from the generated answer 
                    selected_token_vecs = extract_token_activations(                 ##### extract_token_activations_fn 
                            selected_layer=stacked_generation_activations, 
                            attention_mask=truncated_generation_attention_mask, 
                            device=stacked_generation_activations.device,
                            modes=hidden_scores,
                        ) # (batch_size, hidden_size)
                    
                elif activation_source == "prompt":    
                    # Return only the token activations from the prompt
                    selected_token_vecs = extract_token_activations(
                            selected_layer=prompt_activations, 
                            attention_mask=prompt_attention_mask, 
                            device=prompt_activations.device,
                            modes=hidden_scores,
                        ) # (batch_size, hidden_size)
                    
                else: # activation_source == "promptGeneration"
                    # Return token activations from the concatenated prompt + generated answer 
                    selected_token_vecs = extract_token_activations(
                            selected_layer=prompt_and_gen_activations, 
                            attention_mask=truncated_prompt_and_gen_attention_mask, 
                            device=prompt_and_gen_activations.device,
                            skip_length=prompt_len,
                            modes=hidden_scores,
                            # skip_length: exclude prompt from computation if 
                            # mode=='first_generated' in `extract_token_activations_fn`
                        ) # (batch_size, hidden_size)
                
                print(f"[RESULT] selected_token_vecs sample:\n{selected_token_vecs}")

            if 'attn_eig_prod' in attn_scores:
                attn_eig_prod = compute_attn_eig_prod(
                        prompt_attentions=prompt_attentions, 
                        generation_attentions=generation_attentions,
                        prompt_attention_mask=prompt_attention_mask, 
                        generation_attention_mask=truncated_generation_attention_mask,
                        mode=activation_source,
                )
                print(f"[RESULT] attn_eig_prod:\n{attn_eig_prod}")
            
            if 'perplexity' in logit_scores:
                perplexity = compute_perplexity(
                    prompt_logits=prompt_logits, 
                    gen_logits=gen_logits,
                    prompt_ids=prompt_ids, 
                    gen_ids=gen_ids,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=generation_attention_mask,
                    mode=activation_source,
                    min_k=None
                )
                print(f"[RESULT] perplexity:\n{perplexity}")

            if 'logit_entropy' in logit_scores:
                logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=gen_logits,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=generation_attention_mask,
                    mode=activation_source,
                    top_k=50,
                    window_size=None,
                    stride=None
                )
                print(f"[RESULT] logit_entropy:\n{logit_entropy}")
            
            if 'window_logit_entropy' in logit_scores:
                window_logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=gen_logits,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=generation_attention_mask,
                    mode=activation_source,
                    top_k=50,
                    window_size=1,
                    stride=1
                )
                print(f"[RESULT] window_logit_entropy:\n{window_logit_entropy}")
            

        # *******************************
        # END: loop on layers
        # *******************************
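
The alignment step inside the layer loop (stacking one hidden state per decoding step, then concatenating prompt and generation in the same order as the masks) is easy to get wrong, so a small shape check may help. This is a sketch on random tensors, assuming the hook stores one prompt entry followed by one entry per decoding step; the sizes and the masked mean used as the aggregation are illustrative choices, not taken from the notebook's helpers.

```python
import torch

batch_size, prompt_len, n_steps, hidden_size = 2, 4, 3, 5

# Hypothetical hook output: one prompt entry, then one entry per decoding step
activations = [torch.randn(batch_size, prompt_len, hidden_size)]
activations += [torch.randn(batch_size, 1, hidden_size) for _ in range(n_steps)]

prompt_act = activations[0]
# Keep the last position of every step; works whether or not the KV cache is used
gen_act = torch.stack([h[:, -1, :] for h in activations[1:]], dim=1)  # (B, n_steps, H)

# Prompt first, then generation, to match the order of the concatenated mask
full_act = torch.cat([prompt_act, gen_act], dim=1)  # (B, prompt_len + n_steps, H)

prompt_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])  # sample 0 is left-padded
gen_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])           # sample 0 stopped early
full_mask = torch.cat([prompt_mask, gen_mask], dim=1)     # (B, prompt_len + n_steps)

# Masked mean over valid tokens as an example aggregation
weights = full_mask.unsqueeze(-1).float()
mean_vec = (full_act * weights).sum(dim=1) / weights.sum(dim=1)  # (B, H)
print(full_act.shape, full_mask.shape, mean_vec.shape)
```

If the two tensors were concatenated in the opposite order, the shapes would still match and no error would be raised, which is why an explicit check on the first `prompt_len` positions is useful.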


In [None]:
# CHECK that the last prompt logit from a forward pass equals the first logit from model.generate()
# CONCLUSION: this works with GPT-2 but not with Llama (the logits differ beyond atol=1e-5, although the predicted tokens match).

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load model and tokenizer
'''model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
'''
model, tokenizer = load_llama(MODEL_NAME)

model.eval()

# Prompt sentence
prompt_text = "Hello, how are you?"

# Encode the prompt
inputs = tokenizer(prompt_text, return_tensors="pt")

# Full forward pass over the prompt to obtain the logits
with torch.no_grad():
    outputs = model(**inputs, return_dict=True)
    prompt_logits = outputs.logits  # shape: (1, prompt_len, vocab_size)

# Generate the continuation from the prompt, returning per-step scores and logits
# (note: use_cache=False disables the KV cache)
with torch.no_grad():
    generate_outputs = model.generate(
        **inputs,
        max_new_tokens=5,
        output_scores=True,
        output_logits=True, 
        return_dict_in_generate=True,
        do_sample=False,
        use_cache=False
    )

# Extract the logits of the tokens produced during autoregressive generation
gen_logits = torch.stack(generate_outputs.logits, dim=1)  # (1, gen_len, vocab_size)

# Check the shapes
print(f"Prompt logits shape: {prompt_logits.shape}")
print(f"Gen logits shape: {gen_logits.shape}")

# Compare the last prompt logit with the first logit produced by generate()
last_prompt_logit = prompt_logits[:, -1, :]
first_gen_logit = gen_logits[:, 0, :]

last_prompt_logit = last_prompt_logit.float()
first_gen_logit = first_gen_logit.float()

print("Checking that the logits match (within numerical tolerance):")
are_close = torch.allclose(last_prompt_logit, first_gen_logit, atol=1e-5)
print(f"Logits equal? {are_close}")

# Also check the tokens predicted by these logits
last_prompt_token = last_prompt_logit.argmax(dim=-1)
first_gen_token = first_gen_logit.argmax(dim=-1)

print(f"Token predicted by the last prompt logit: {tokenizer.decode(last_prompt_token)}")
print(f"Token predicted by the first generation logit: {tokenizer.decode(first_gen_token)}")
print(f"Predicted tokens identical? {torch.equal(last_prompt_token, first_gen_token)}")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Prompt logits shape: torch.Size([1, 7, 32000])
Gen logits shape: torch.Size([1, 5, 32000])
Checking that the logits match (within numerical tolerance):
Logits equal? False
Token predicted by the last prompt logit: I
Token predicted by the first generation logit: I
Predicted tokens identical? True
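
The result above (logits not equal at `atol=1e-5`, yet identical predicted tokens) is what one would expect if the two code paths differ only by reduced-precision rounding. That explanation is an assumption and is not verified in this notebook. The sketch below illustrates the effect on synthetic values: a float16 round trip is enough to fail a strict `torch.allclose` check while leaving the argmax unchanged, so reporting the maximum absolute difference alongside a looser tolerance is more informative than a single boolean.

```python
import torch

# Synthetic stand-in for a row of float32 logits (not real model output)
reference = torch.linspace(-10.0, 10.0, 1000).unsqueeze(0)  # (1, 1000)
# Same values after a float16 round trip, mimicking half-precision rounding
rounded = reference.to(torch.float16).float()

max_abs_diff = (reference - rounded).abs().max().item()
strict = torch.allclose(reference, rounded, atol=1e-5)            # tight tolerance
loose = torch.allclose(reference, rounded, rtol=1e-2, atol=1e-2)  # half-precision scale
same_argmax = torch.equal(reference.argmax(dim=-1), rounded.argmax(dim=-1))

print(f"max abs diff: {max_abs_diff:.2e}")
print(f"strict allclose: {strict}, loose allclose: {loose}, same argmax: {same_argmax}")
```

Applying the same three measurements to `last_prompt_logit` and `first_gen_logit` would show whether the Llama mismatch is of this small, rounding-sized magnitude or something larger that points to a real difference between the two passes.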




In [None]:
def run_prompt_score_extraction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    batch_size: int = 4,
    idx_start_sample: int = 0,
    max_samples: int = 1000,
    save_to_pkl: bool = False,
    output_path: str = "outputs/all_batch_results.pkl",
    build_prompt_fn: Callable[[str, str], str] = None,
    layers: List[int] = [-1],  
    hidden_scores: List[str] = ["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores: List[str] = ["attn_eig_prod"],
    logit_scores: List[str] = ["perplexity", "logit_entropy", "window_logit_entropy"],
    logit_config: dict = {"top_k": 50, "window_size": 1, "stride": 1},
    start_offset : int = 0,
    end_offset : int = 0,
) -> Union[List[dict], None]:
    """
    Runs batched inference on a dataset using a decoder-only language model.
    For each batch, it runs a forward pass on the prompt and extracts token-level hidden 
    activations, attention maps and logit scores from specified transformer layers.

    The function supports multiple aggregation modes for the activations (`hidden_scores`), attention-based 
    scores (`attn_scores`), and logit-based scores (`logit_scores`). The `logit_config` argument provides 
    configuration parameters for logit-based score functions.
    
    Hidden states and attention maps are captured via forward hooks, 
    then aggregated based on token position and attention masks.
    
    These activations are saved as individual batch files in a specified pickle directory, 
    allowing efficient incremental storage and later aggregation.
    Alternatively, the representations can be returned directly.

    Parameters
    ----------
    model : PreTrainedModel
        The causal language model to evaluate (e.g., LLaMA).
    tokenizer : PreTrainedTokenizer
        The corresponding tokenizer.
    dataset : Dataset
        The input dataset.
    batch_size : int
        Number of samples per batch.
    idx_start_sample : int
        Index of the first sample to process from the dataset.
    max_samples : int
        Total number of examples to process from the dataset, starting from idx_start_sample. 
    save_to_pkl : bool
        If True, activations are appended to the pickle file at output_path.
        If False, the function returns a list of activations.
    output_path : str
        Path to the directory where extracted answers will be saved as individual pickle batch files.
    build_prompt_fn : Callable
        Function to build a prompt from context and question.
    layers : List[int]
        List of indices of the transformer layers to extract activations from (default: [-1] for last layer).
    hidden_scores : List[str], optional
        List of aggregation modes to compute on token activations. Possible modes include:
            "average", "last", "max", "first_generated", "token_svd_score", "feat_var".
        These modes are passed to `extract_token_activations` for aggregation. Default includes the above.
    attn_scores : List[str], optional
        List of attention-based scores to compute. Supported: "attn_eig_prod".
    logit_scores : List[str], optional
        List of logit-based scores to compute. Supported:
            "perplexity", "logit_entropy", "window_logit_entropy".
    logit_config : dict, optional
        Configuration dictionary for logit-based scoring functions, with keys such as:
            - "top_k": int, number of top logits considered (default 50)
            - "window_size": int, window size for windowed entropy (default 1)
            - "stride": int, stride for windowed entropy (default 1)
    start_offset : int
        Offset from the first non-padding token (must be >= 0). 
    end_offset : int
        Offset from the last non-padding token (must be <= 0, e.g., -3 to remove 3 tokens).
    
    Returns
    -------
    Union[List[dict], None]
        If `save_to_pkl` is False, returns a list of dictionaries, one per batch, with each element
         of the list having the following structure:
            {
                "id": List[str],  # IDs of batch samples
                "original_indices": List[int],  # Original dataset indices
                "context": List[str],
                "question": List[str],
                "gt_answers": List[str],        # Ground-truth reference answers
                "gen_answers": List[str],       # Generated model answers
                "scores": {
                    "layer_{layer_idx}": {
                        "hidden": { 
                            "{mode}": np.ndarray[(batch_size, hidden_size), float], 
                            ... # one entry per mode in hidden_scores
                        },
                        "attention": {
                            "{attn_score}": np.ndarray[(batch_size,), float],  
                            ...
                        }
                    },
                    "logits": {
                        "perplexity": np.ndarray[(batch_size,), float],
                        "logit_entropy": np.ndarray[(batch_size,), float],
                        "window_logit_entropy": np.ndarray[(batch_size,), float] 
                    }
                }
            },

        If `save_to_pkl` is True, saves each batch's dictionary incrementally to disk and returns None.
    """

    # ==============================================================  
    # [PATCH] Replace LlamaAttention.forward on target layers by
    #  custom module to extract attention weights
    # ==============================================================
    for idx in layers:  
        model.model.layers[idx].self_attn.forward = patched_LlamaAttention_forward.__get__(
            model.model.layers[idx].self_attn,
            model.model.layers[idx].self_attn.__class__
    )
        
    # ==============================================================  
    # [LOOP] Process batches of examples  
    # ==============================================================
    all_batch_results = []  

    for i in tqdm(range(idx_start_sample, idx_start_sample + max_samples, batch_size)):
      
        # ----------------------------------------------------------
        # [BATCH INPUT] Extract and tokenize prompts
        # ----------------------------------------------------------
        batch = extract_batch(dataset, i, batch_size)
        prompts = [build_prompt_fn(s["context"], s["question"]) for s in batch]
        inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        prompt_ids = inputs["input_ids"] # (batch_size, prompt_len)
        prompt_attention_mask = inputs["attention_mask"] 

        # ----------------------------------------------------------
        # [HOOKS] Register hooks to capture hidden states and attentions
        # The activations/attentions retrieved by the hooks have values similar to
        # those from `output_hidden_states=True`/`output_attentions=True` in `model.generate()`
        # ----------------------------------------------------------
        # This hook collects the hidden states. For layer l: 
        # activations_lists[l] = [act_prompt], 
        # activations_lists[l][0] of shape: (batch_size, prompt_len, hidden_size) 
        activations_lists = [[] for _ in layers]  # one empty list per layer 
        handle_act, call_counter_act = register_generation_activation_hook(model, activations_lists, layers)

        # This hook collects the attention maps of the prompt. For layer l: 
        # attentions_lists[l] = [attn_prompt], 
        # attentions_lists[l][0] of shape: (batch_size, n_heads, prompt_len, prompt_len)
        attentions_lists = [[] for _ in layers]  # one empty list per layer
        handle_attn, call_counter_attn = register_generation_attention_hook(model, attentions_lists, layers)
        
        # ----------------------------------------------------------
        # [FORWARD PASS] Run model with hooks to capture intermediate states
        # ----------------------------------------------------------
        # Pass inputs through the model. When a target layer is reached,
        # the hooks execute and save their outputs in activations_lists and attentions_lists.
        with torch.no_grad():
            outputs = model(**inputs, return_dict=True)
        # The forward output always contains the logits; keep them only if logit scores are requested
        prompt_logits = outputs.logits if logit_scores else None
        
        # Remove hooks to avoid memory leaks or duplicate logging
        for h in handle_act: h.remove()
        for h in handle_attn: h.remove()
        
        # Verify that hooks worked properly
        verify_call_counters(call_counter_act, name="activation hooks")
        verify_call_counters(call_counter_attn, name="attention hooks")


        # ----------------------------------------------------------
        # [OFFSET] Modify prompt mask with offset, if specified
        # ----------------------------------------------------------
        if start_offset !=0 or end_offset !=0:
            prompt_attention_mask, start_indices, end_indices = compute_offset_attention_mask(
                attention_mask=prompt_attention_mask, 
                start_offset=start_offset, 
                end_offset=end_offset
            ) # (batch_size, prompt_len), (batch_size,), (batch_size,)


        # **********************************************************
        # [LAYER LOOP] Extract activation and attention-based scores for each specified layer 
        # **********************************************************
        save_layers_scores = {}

        for l, layer_idx in enumerate(layers):

            activations = activations_lists[l]
            attentions = attentions_lists[l]

            # Define prompt and generation hidden states 
            prompt_activations=activations[0]    
            
            # Define prompt and generation attention maps
            prompt_attentions=attentions[0]        

            # ------------------------------------------------------
            # [HIDDEN SCORES] Extract token-level activations/hidden-states
            # ------------------------------------------------------
            if hidden_scores is not None and len(hidden_scores) > 0:
                # Return only the token activations from the prompt
                selected_token_vecs = extract_token_activations(
                        selected_layer=prompt_activations, 
                        attention_mask=prompt_attention_mask, 
                        device=prompt_activations.device,
                        modes=hidden_scores,
                    ) # (batch_size, hidden_size)
 
                # Save results to dict
                hidden_results = {}
                for mode in hidden_scores:
                    if mode in selected_token_vecs:
                        hidden_results[mode] = selected_token_vecs[mode].cpu().numpy()
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"hidden": hidden_results})

            # ------------------------------------------------------
            # [ATTENTION SCORES] Extract attention eigenvalue-based metric
            # ------------------------------------------------------
            if attn_scores is not None and 'attn_eig_prod' in attn_scores:
                attn_eig_prod = compute_attn_eig_prod(
                        prompt_attentions=prompt_attentions, 
                        generation_attentions=None,
                        prompt_attention_mask=prompt_attention_mask, 
                        generation_attention_mask=None,
                        mode='prompt',
                )
                # Save results to dict
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"attention": {"attn_eig_prod": attn_eig_prod}}) 
        
        # **********************************************************
        # [END LAYER LOOP] 
        # **********************************************************

        save_logits_scores = {}
        # ------------------------------------------------------
        # [LOGIT SCORES] Compute metrics from model logits
        # ------------------------------------------------------
        if logit_scores is not None:
            if 'perplexity' in logit_scores:
                perplexity = compute_perplexity(
                    prompt_logits=prompt_logits, 
                    gen_logits=None,
                    prompt_ids=prompt_ids, 
                    gen_ids=None,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=None,
                    mode='prompt',
                    min_k=None
                )
                # Save results to dict
                save_logits_scores['perplexity'] = perplexity 

            if 'logit_entropy' in logit_scores:
                if logit_config is None:
                    raise ValueError("logit_entropy is required but logit_config is None")
                logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=None,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=None,
                    mode='prompt',
                    top_k=logit_config['top_k'], 
                    window_size=None,
                    stride=None
                )
                # Save results to dict
                save_logits_scores['logit_entropy'] = logit_entropy 
        
            if 'window_logit_entropy' in logit_scores:
                if logit_config is None:
                    raise ValueError("window_logit_entropy is required but logit_config is None")
                window_logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=None,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=None,
                    mode='prompt',
                    top_k=logit_config['top_k'],
                    window_size=logit_config['window_size'], 
                    stride=logit_config['stride'] 
                )
                # Save results to dict
                save_logits_scores['window_logit_entropy'] = window_logit_entropy 


        # ==========================================================
        # [OUTPUT] Store extracted results (to memory or file)
        # ==========================================================
        batch_results = {
            "id": [s['id'] for s in batch],
            "original_indices": [s['original_index'] for s in batch],
            "context": [s['context'] for s in batch],
            "question": [s['question'] for s in batch],
            "gt_answers": [s['answers'] for s in batch],
            "scores": {**save_layers_scores, **({"logits": save_logits_scores} if save_logits_scores else {})}
        }

        from src.data_reader.pickle_io import save_batch_pickle

        if save_to_pkl:
            save_batch_pickle(batch_data=batch_results, output_dir=output_path, batch_idx=i)
        else:
            all_batch_results.append(batch_results)

    if not save_to_pkl:
        return all_batch_results



def run_prompt_and_generation_score_extraction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    dataset: Dataset,
    batch_size: int = 4,
    idx_start_sample: int = 0,
    max_samples: int = 1000,
    save_to_pkl: bool = False,
    output_path: str = "outputs/all_batch_results.pkl",
    build_prompt_fn: Callable[[str, str], str] = None,
    layers: List[int] = [-1],  
    activation_source: Literal["prompt", "generation", "promptGeneration"] = "generation",
    hidden_scores: List[str] = ["average", "last", "max", "first_generated", "token_svd_score", "feat_var"],
    attn_scores: List[str] = ["attn_eig_prod"],
    logit_scores: List[str] = ["perplexity", "logit_entropy", "window_logit_entropy"],
    logit_config: dict = {"top_k": 50, "window_size": 1, "stride": 1},
    start_offset : int = 0,
    end_offset : int = 0,
) -> Union[List[dict], None]:
    """
    Runs batched inference on a dataset using a decoder-only language model.
    For each batch, it performs text generation and extracts token-level 
    hidden activations, attention maps and logit scores from the specified transformer layers,
    taken from the prompt, the generated text, or both, depending on `activation_source`.

    The function supports multiple aggregation modes for the activations (`hidden_scores`), attention-based 
    scores (`attn_scores`), and logit-based scores (`logit_scores`). The `logit_config` argument provides 
    configuration parameters for logit-based score functions.
    
    Hidden states and attention maps are captured via forward hooks during generation, 
    then aggregated based on token position and attention masks.
    
    These activations are saved as individual batch files in a specified pickle directory, 
    allowing efficient incremental storage and later aggregation.
    Alternatively, the representations can be returned directly.

    Parameters
    ----------
    model : PreTrainedModel
        The causal language model to evaluate (e.g., LLaMA).
    tokenizer : PreTrainedTokenizer
        The corresponding tokenizer.
    dataset : Dataset
        The input dataset.
    batch_size : int
        Number of samples per batch.
    idx_start_sample : int
        Index of the first sample to process from the dataset.
    max_samples : int
        Total number of examples to process from the dataset, starting from idx_start_sample. 
    save_to_pkl : bool
        If True, each batch's results are saved as a separate pickle file in output_path.
        If False, the function returns a list of batch result dictionaries.
    output_path : str
        Path to the directory where extracted results will be saved as individual pickle batch files.
    build_prompt_fn : Callable
        Function to build a prompt from context and question.
    layers : List[int]
        List of indices of the transformer layers to extract activations from (default: [-1] for last layer).
    activation_source : {"prompt", "generation", "promptGeneration"}
        Which part of the sequence to extract activations/attentions/logits from:
        - "prompt": only from the prompt
        - "generation": only from the generated answer
        - "promptGeneration": prompt and generation answer both concatenated
    hidden_scores : List[str], optional
        List of aggregation modes to compute on token activations. Possible modes include:
            "average", "last", "max", "first_generated", "token_svd_score", "feat_var".
        These modes are passed to `extract_token_activations` for aggregation. Default includes the above.
    attn_scores : List[str], optional
        List of attention-based scores to compute. Supported: "attn_eig_prod".
    logit_scores : List[str], optional
        List of logit-based scores to compute. Supported:
            "perplexity", "logit_entropy", "window_logit_entropy".
    logit_config : dict, optional
        Configuration dictionary for logit-based scoring functions, with keys such as:
            - "top_k": int, number of top logits considered (default 50)
            - "window_size": int, window size for windowed entropy (default 1)
            - "stride": int, stride for windowed entropy (default 1)
    start_offset : int
        Offset from the first non-padding token (must be >= 0). 
    end_offset : int
        Offset from the last non-padding token (must be <= 0, e.g., -3 to remove 3 tokens).
    
    Returns
    -------
    Union[List[dict], None]
        If `save_to_pkl` is False, returns a list of dictionaries, one per batch, with each element
         of the list having the following structure:
            {
                "id": List[str],  # IDs of batch samples
                "original_indices": List[int],  # Original dataset indices
                "context": List[str],
                "question": List[str],
                "gt_answers": List[str],        # Ground-truth reference answers
                "gen_answers": List[str],       # Generated model answers
                "scores": {
                    "layer_{layer_idx}": {
                        "hidden": { 
                            "{mode}": np.ndarray[(batch_size, hidden_size), float], 
                            ... # one entry per mode in hidden_scores
                        },
                        "attention": {
                            "{attn_score}": np.ndarray[(batch_size,), float],  
                            ...
                        }
                    },
                    "logits": {
                        "perplexity": np.ndarray[(batch_size,), float],
                        "logit_entropy": np.ndarray[(batch_size,), float],
                        "window_logit_entropy": np.ndarray[(batch_size,), float] 
                    }
                }
            },

        If `save_to_pkl` is True, saves each batch's dictionary incrementally to disk and returns None.

    Notes
    -----
    When using model.generate() with output_hidden_states=True (which the activation hook
    replicates here), use_cache=True and max_new_tokens=30, there is always an offset between the number 
    of generated tokens (outputs.sequences[:, prompt_len:].shape[1]) and len(outputs.hidden_states). 
    For example, with prompt_len=17:
    * outputs.sequences.shape[1] = prompt_len (17) + max_new_tokens (30) = 47
    * len(outputs.hidden_states) = max_new_tokens (30)
        With : 
        * outputs.hidden_states[0][layer_idx].shape = (batch_size, prompt_len, hidden_size)           --> includes the prompt ! 
        * outputs.hidden_states[i][layer_idx].shape = (batch_size, 1, hidden_size) with 1 <= i <= 29  --> stops at 29 ! 
      so only the first 29 generated tokens have a hidden state of their own.
    *Note* that in our code, the activations collected by the hook are the same as outputs.hidden_states. 
        
    See the Hugging Face discussion from April 2024:
    https://github.com/huggingface/transformers/issues/30036
    """

    # ==============================================================
    # [VALIDATION] Ensure activation_source is correctly defined
    # ==============================================================
    if activation_source not in ('prompt', 'generation', 'promptGeneration'):
        raise ValueError(
                f"Invalid value for `activation_source`: '{activation_source}'. "
                f"Expected one of: ['prompt', 'generation', 'promptGeneration']."
            )    
        
    # ==============================================================  
    # [PATCH] Replace LlamaAttention.forward on target layers by
    #  custom module to extract attention weights
    # ==============================================================
    for idx in layers:  
        model.model.layers[idx].self_attn.forward = patched_LlamaAttention_forward.__get__(
            model.model.layers[idx].self_attn,
            model.model.layers[idx].self_attn.__class__
        )
        
    # ==============================================================  
    # [LOOP] Process batches of examples  
    # ==============================================================
    all_batch_results = []  

    for i in tqdm(range(idx_start_sample, idx_start_sample + max_samples, batch_size)):
      
        # ----------------------------------------------------------
        # [BATCH INPUT] Extract and tokenize prompts
        # ----------------------------------------------------------
        batch = extract_batch(dataset, i, batch_size)
        prompts = [build_prompt_fn(s["context"], s["question"]) for s in batch]
        inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        prompt_ids = inputs["input_ids"] # (batch_size, prompt_len)
        prompt_len = prompt_ids.shape[1] # Assumes prompts are padded to same length

        # ----------------------------------------------------------
        # [HOOKS] Register hooks to capture hidden states and attentions
        # ----------------------------------------------------------
        # This hook collects the hidden states at each decoding step. For layer l: 
        # activations_lists[l] = [act_prompt, act_gen_step1, ..., act_gen_step49] of length 50, if max_new_tokens=50.
        # activations_lists[l][k] of shape: (batch_size, seq_len, hidden_size) 
        activations_lists = [[] for _ in layers]  # one empty list per layer 
        handle_act, call_counter_act = register_generation_activation_hook(model, activations_lists, layers)

        # This hook collects the attention maps at each decoding step. For layer l: 
        # attentions_lists[l] = [attn_prompt, attn_gen_step1, ..., attn_gen_step49], of length 50, if max_new_tokens=50.
        # attentions_lists[l][k] of shape: (batch_size, n_heads, tgt_seq_len, src_seq_len)
        #   tgt_seq_len: length of the sequence the model is currently producing (query)
        #   src_seq_len: length of the sequence the model is focusing on (key/value)
        attentions_lists = [[] for _ in layers]  # one empty list per layer
        handle_attn, call_counter_attn = register_generation_attention_hook(model, attentions_lists, layers)
        
        # ----------------------------------------------------------
        # [GENERATION] Run model with hooks to capture intermediate states
        # ----------------------------------------------------------
        # When target layers are reached, the hooks execute and save their outputs in activations_lists and attentions_lists
        outputs = generate(model, inputs, tokenizer, max_new_tokens=50, k_beams=1)
        gen_ids = outputs.sequences[:, prompt_len:]
    
        # Remove hooks to avoid memory leaks or duplicate logging
        for h in handle_act: h.remove()
        for h in handle_attn: h.remove()
        
        # Verify that hooks worked properly
        verify_call_counters(call_counter_act, name="activation hooks")
        verify_call_counters(call_counter_attn, name="attention hooks")

        # Retrieve text of generated answers
        gen_answers = tokenizer.batch_decode(
            outputs.sequences[:, prompt_len:], 
            skip_special_tokens=True
        ) # (batch_size,)

        # ----------------------------------------------------------
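        # [FORWARD PASS] Forward pass through the model to retrieve prompt logits 
        # ----------------------------------------------------------
        if logit_scores is not None and len(logit_scores) > 0:
            gen_logits = torch.stack(outputs.logits, dim=1) # (batch_size, gen_len, vocab_size)
            with torch.no_grad():
                # Pass the attention mask so that padding tokens are ignored in batched prompts
                prompt_logits = model(
                    input_ids=inputs["input_ids"], 
                    attention_mask=inputs["attention_mask"]
                ).logits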
        
        # ----------------------------------------------------------
        # [MASKING] Build attention masks for prompt and generation
        # ----------------------------------------------------------
        # This mask marks which generated tokens are valid (i.e., not padding).
        # Positions are marked True up to and including the first eos_token_id
        generation_attention_mask = build_generation_attention_mask(
            gen_ids=gen_ids, 
            eos_token_id=tokenizer.eos_token_id
        ) # (batch_size, gen_len)

        prompt_attention_mask = inputs["attention_mask"] 
        # (batch_size, prompt_len)

        # ----------------------------------------------------------
        # [OFFSET] Modify prompt mask with offset, if specified
        # ----------------------------------------------------------
        if start_offset !=0 or end_offset !=0:
            prompt_attention_mask, start_indices, end_indices = compute_offset_attention_mask(
                attention_mask=prompt_attention_mask, 
                start_offset=start_offset, 
                end_offset=end_offset
            ) # (batch_size, prompt_len), (batch_size,), (batch_size,)

        # ----------------------------------------------------------
        # [MASKING] Concatenate the prompt and generation attention mask
        # ----------------------------------------------------------
        prompt_and_gen_attention_mask = torch.cat(
            [prompt_attention_mask,
            generation_attention_mask],
            dim=1
        ) # (batch_size, prompt_len + gen_len)

        # ----------------------------------------------------------
        # [TRUNCATE] Remove final token from generated outputs to align with activations/attentions
        # ----------------------------------------------------------
        # When N tokens are generated, only the first N-1 tokens have corresponding hidden states.
        # So activations[1:] covers only the first N-1 steps. Therefore, we exclude the last
        # generated token from outputs.sequences to match activations[1:]. Same for attentions.
        truncated_gen_ids = gen_ids[:,:-1] # (batch_size, gen_len-1)
        truncated_generation_attention_mask = generation_attention_mask[:,:-1] # (batch_size, gen_len-1)
        truncated_prompt_and_gen_attention_mask = prompt_and_gen_attention_mask[:,:-1] # (batch_size, prompt_len + gen_len-1)

        # **********************************************************
        # [LAYER LOOP] Extract activation and attention-based scores for each specified layer 
        # **********************************************************
        save_layers_scores = {}

        for l, layer_idx in enumerate(layers):

            activations = activations_lists[l]
            attentions = attentions_lists[l]

            # Define prompt and generation hidden states 
            prompt_activations=activations[0]       # `[0]` to include only the prompt part 
            generation_activations=activations[1:]  # `[1:]` to exclude the prompt part 
            
            # Define prompt and generation attention maps
            prompt_attentions=attentions[0]         # `[0]` to include only the prompt part 
            generation_attentions=attentions[1:]    # `[1:]` to exclude the prompt part 

            # ------------------------------------------------------
            # [ALIGNMENT] Stack and concatenate prompt + generation activations
            # ------------------------------------------------------
            # For each generation step, take the hidden state of the last position (the newly generated token)
            stacked_generation_activations = torch.stack(
                [h[:, -1, :] for h in generation_activations], dim=1
            ) # (batch_size, gen_len-1, hidden_size)

            # Concatenate the prompt and generation hidden states, prompt first, 
            # to match the order of truncated_prompt_and_gen_attention_mask
            prompt_and_gen_activations = torch.cat(
                [prompt_activations,              # (batch_size, prompt_len, hidden_size)
                stacked_generation_activations],  # (batch_size, gen_len-1, hidden_size)
                dim=1
            ) # (batch_size, prompt_len + gen_len-1, hidden_size)
            
            # ------------------------------------------------------
            # [HIDDEN SCORES] Extract token-level activations/hidden-states
            # ------------------------------------------------------
            if hidden_scores is not None and len(hidden_scores) > 0:
                if activation_source == "generation":
                    # Return only the token activations from the generated answer 
                    selected_token_vecs = extract_token_activations(               
                            selected_layer=stacked_generation_activations, 
                            attention_mask=truncated_generation_attention_mask, 
                            device=stacked_generation_activations.device,
                            modes=hidden_scores,
                        ) # (batch_size, hidden_size)
                    
                elif activation_source == "prompt":    
                    # Return only the token activations from the prompt
                    selected_token_vecs = extract_token_activations(
                            selected_layer=prompt_activations, 
                            attention_mask=prompt_attention_mask, 
                            device=prompt_activations.device,
                            modes=hidden_scores,
                        ) # (batch_size, hidden_size)
                    
                else: # activation_source == "promptGeneration"
                    # Return token activations from the concatenated prompt + generated answer 
                    selected_token_vecs = extract_token_activations(
                            selected_layer=prompt_and_gen_activations, 
                            attention_mask=truncated_prompt_and_gen_attention_mask, 
                            device=prompt_and_gen_activations.device,
                            skip_length=prompt_len,
                            modes=hidden_scores,
                            # skip_length: exclude prompt from computation if 
                            # mode=='first_generated' in `extract_token_activations_fn`
                        ) # (batch_size, hidden_size)

                # Save results to dict
                hidden_results = {}
                for mode in hidden_scores:
                    if mode in selected_token_vecs:
                        hidden_results[mode] = selected_token_vecs[mode].cpu().numpy()
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"hidden": hidden_results})

            # ------------------------------------------------------
            # [ATTENTION SCORES] Extract attention eigenvalue-based metric
            # ------------------------------------------------------
            if attn_scores is not None and 'attn_eig_prod' in attn_scores:
                attn_eig_prod = compute_attn_eig_prod(
                        prompt_attentions=prompt_attentions, 
                        generation_attentions=generation_attentions,
                        prompt_attention_mask=prompt_attention_mask, 
                        generation_attention_mask=truncated_generation_attention_mask,
                        mode=activation_source,
                )
                # Save results to dict
                save_layers_scores.setdefault(f"layer_{layer_idx}", {}).update({"attention": {"attn_eig_prod": attn_eig_prod}}) 
        
        # **********************************************************
        # [END LAYER LOOP] 
        # **********************************************************

        save_logits_scores = {}
        # ------------------------------------------------------
        # [LOGIT SCORES] Compute metrics from model logits
        # ------------------------------------------------------
        if logit_scores is not None:
            if 'perplexity' in logit_scores:
                perplexity = compute_perplexity(
                    prompt_logits=prompt_logits, 
                    gen_logits=gen_logits,
                    prompt_ids=prompt_ids, 
                    gen_ids=gen_ids,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=generation_attention_mask,
                    mode=activation_source,
                    min_k=None
                )
                # Save results to dict
                save_logits_scores['perplexity'] = perplexity 

            if 'logit_entropy' in logit_scores:
                if logit_config is None:
                    raise ValueError("logit_entropy is required but logit_config is None")
                logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=gen_logits,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=generation_attention_mask,
                    mode=activation_source,
                    top_k=logit_config['top_k'], 
                    window_size=None,
                    stride=None
                )
                # Save results to dict
                save_logits_scores['logit_entropy'] = logit_entropy 
        
            if 'window_logit_entropy' in logit_scores:
                if logit_config is None:
                    raise ValueError("window_logit_entropy is required but logit_config is None")
                window_logit_entropy = compute_logit_entropy(
                    prompt_logits=prompt_logits,
                    gen_logits=gen_logits,
                    prompt_attention_mask=prompt_attention_mask,
                    gen_attention_mask=generation_attention_mask,
                    mode=activation_source,
                    top_k=logit_config['top_k'],
                    window_size=logit_config['window_size'], 
                    stride=logit_config['stride'] 
                )
                # Save results to dict
                save_logits_scores['window_logit_entropy'] = window_logit_entropy 


        # ==========================================================
        # [OUTPUT] Store extracted results (to memory or file)
        # ==========================================================
        batch_results = {
            "id": [s['id'] for s in batch],
            "original_indices": [s['original_index'] for s in batch],
            "context": [s['context'] for s in batch],
            "question": [s['question'] for s in batch],
            "gt_answers": [s['answers'] for s in batch],
            "gen_answers": gen_answers,
            "scores": {**save_layers_scores, **({"logits": save_logits_scores} if save_logits_scores else {})}
        }

        from src.data_reader.pickle_io import save_batch_pickle

        if save_to_pkl:
            save_batch_pickle(batch_data=batch_results, output_dir=output_path, batch_idx=i)
        else:
            all_batch_results.append(batch_results)

    if not save_to_pkl:
        return all_batch_results
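
The masking and truncation steps used in `run_prompt_and_generation_score_extraction` can be illustrated on toy data. This is only a minimal sketch: `build_generation_mask_sketch` is a hypothetical stand-in written with NumPy, not the notebook's `build_generation_attention_mask` (which works on torch tensors and may be implemented differently), and the token ids and `EOS` value are made up for the example.

```python
import numpy as np

def build_generation_mask_sketch(gen_ids: np.ndarray, eos_token_id: int) -> np.ndarray:
    """Mark generated positions as valid up to and including the first EOS token."""
    is_eos = (gen_ids == eos_token_id).astype(int)   # (batch_size, gen_len)
    # Number of EOS tokens seen strictly before each position
    eos_before = np.cumsum(is_eos, axis=1) - is_eos
    return eos_before == 0                           # (batch_size, gen_len), bool

EOS = 2  # assumed EOS id for this toy example
gen_ids = np.array([
    [5, 7, 2, 2, 2],   # first EOS at position 2, the rest is padding
    [4, 9, 8, 6, 3],   # never emits EOS within the 5 generated tokens
])

mask = build_generation_mask_sketch(gen_ids, EOS)

# Only the first N-1 generated tokens have a hidden state captured during
# generation, so the last column is dropped before aggregating activations.
truncated_mask = mask[:, :-1]                        # (batch_size, gen_len-1)

print(mask)
print(truncated_mask.shape)
```

The truncated mask is the one that lines up with the stacked per-step activations, while the untruncated mask lines up with the generation logits, which contain one entry per generated token.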


In [22]:
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, PreTrainedModel, PreTrainedTokenizer, BatchEncoding
from typing import Union, Dict

def generate2(
    model: PreTrainedModel,
    inputs: BatchEncoding,
    tokenizer: PreTrainedTokenizer,
    max_new_tokens: int = 50,
    k_beams: int = 1,
    **generate_kwargs
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True    if k_beams == 1 else False,
            temperature=0.6   if k_beams == 1 else None,
            top_p=0.9         if k_beams == 1 else None,
            top_k=50          if k_beams == 1 else None,
            num_beams=k_beams,
            use_cache=True, 
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id, # Ensures clean padding (right padding)
            output_hidden_states=True,       # Enabled here to compare against the hidden states extracted by the hook
            output_attentions=False,         # We rely on the hook to extract attention map instead (more memory efficient)
            output_logits=True,              # Logits not filtered/truncated by top-k/top-p sampling. Note: `output_scores=True` returns filtered logits. 
            return_dict_in_generate=True,    # Needed for access to beam_indices when num_beams > 1
            early_stopping=False if k_beams == 1 else True, # With beam search, stop as soon as num_beams finished candidates are found
            **generate_kwargs                # For future flexibility (e.g., output_attentions, output_scores)
        )
        return outputs 

def compare(l1, l2, name1, name2):
    diff = torch.norm(l1-l2)
    print(f"{name1} vs {name2} close? {torch.allclose(l1, l2, rtol=1e-5, atol=1e-8)}, diff: {diff}")

# Load model and tokenizer (replace with your own path or Hugging Face model name)
#model, tokenizer = load_llama(MODEL_NAME)

'''
Setting:
-------
Each time we retrieve the LAST layer hidden state of the PROMPT.

In this code, we compute:
- case 1) last layer prompt hidden state retrieved from output_hidden_states=True in the forward pass, denoted as hidden_forward
- case 2) last layer prompt hidden state retrieved from output_hidden_states=True in model.generate(), denoted as hidden_gen
- case 3) last layer prompt hidden state retrieved from a hook on model.generate(), denoted as hidden_hook

We also compute the logits:
- case a) logits obtained with the built-in argument outputs_forward.logits where outputs_forward is obtained in case 1)
- case b) logits obtained using the logit lens on the activations hidden_forward obtained in case 1)
- case c) logits obtained using the logit lens on the activations hidden_gen obtained in case 2)
- case d) logits obtained using the logit lens on the activations hidden_hook obtained in case 3)

Note: we cannot compare with the logits from the built-in argument outputs_gen.logits, where outputs_gen
is obtained in case 2), because it only contains logits for the generated tokens and not the logits 
of the original prompt.

Results:
-------
>>> When batch_size = 1:

case 1) = case 3) : forward vs hook close? True
case 1) = case 2) : forward vs gen close? True
case 2) = case 3) : hook vs gen close? True

case a) = case b) : forward vs logitLens forward close? True
case a) = case d) : forward vs logitLens hook close? True
case a) = case c) : forward vs logitLens gen close? True
case b) = case d) : logitLens forward vs logitLens hook close? True
case b) = case c) : logitLens forward vs logitLens gen close? True
case d) = case c) : logitLens hook vs logitLens gen close? True

>>> When batch_size > 1:

case 1) != case 3) : forward vs hook close? False  => differences due to batching (I don't know why)
case 1) != case 2) : forward vs gen close? False   =>  
case 2)  = case 3) : hook vs gen close? True       => my hook works properly

case a)  = case b) : forward vs logitLens forward close? True   => my logitLens function works properly
case a) != case d) : forward vs logitLens hook close? False     => differences due to differences in activations 
case a) != case c) : forward vs logitLens gen close? False      
case b) != case d) : logitLens forward vs logitLens hook close? False  
case b) != case c) : logitLens forward vs logitLens gen close? False
case d)  = case c) : logitLens hook vs logitLens gen close? True 

Even when I only compare the non-padded tokens, I see a difference. 
'''

# --- Prompt and tokenization ---
prompt_text = ["The cat sat on the"]#, "What color is the sky?"]
inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to("cuda")
prompt_mask = inputs['attention_mask']

# --- Standard forward pass ---
with torch.no_grad():
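    # Pass the attention mask so padding tokens are ignored when batch_size > 1; forward() always returns logits.
    outputs_forward = model(inputs['input_ids'], attention_mask=prompt_mask, output_hidden_states=True)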

hidden_forward = outputs_forward.hidden_states[-1] # [1, prompt_len, hidden_size]

# --- Hook during generation ---
activations = [[]]
handles, _ = register_generation_activation_hook(model, activations, layers_idx_list=[-1])
outputs_gen = generate2(model, inputs, tokenizer)
for h in handles: h.remove()

hidden_hook = activations[0][0].to(hidden_forward.device) # [1, prompt_len, hidden_size]
hidden_gen = outputs_gen.hidden_states[0][-1].to(hidden_forward.device) # [1, prompt_len, hidden_size]

# --- Compare activations ---
print("Compare activations from different sources")
compare(hidden_forward, hidden_hook, "forward", "hook")
compare(hidden_forward, hidden_gen, "forward", "gen")
compare(hidden_hook, hidden_gen, "hook", "gen")

# --- Logit lens ---
logits_forward = outputs_forward.logits
logits_lens_forward = apply_logit_lens(model, hidden_forward).to(hidden_forward.device)
logits_lens_hook = apply_logit_lens(model, hidden_hook).to(hidden_forward.device)
logits_lens_gen = apply_logit_lens(model, hidden_gen).to(hidden_forward.device)

# --- Global comparison ---
print("Compare Logits from different sources")
compare(logits_forward, logits_lens_forward, "forward", "logitLens forward")
compare(logits_forward, logits_lens_hook, "forward", "logitLens hook")
compare(logits_forward, logits_lens_gen, "forward", "logitLens gen")
compare(logits_lens_forward, logits_lens_hook, "logitLens forward", "logitLens hook")
compare(logits_lens_forward, logits_lens_gen, "logitLens forward", "logitLens gen")
compare(logits_lens_hook, logits_lens_gen, "logitLens hook", "logitLens gen")


# --- Per-token comparison for each example (without padding) ---
print("Compare predicted tokens from computed logits")
batch_size = inputs['input_ids'].shape[0]
valid_positions = prompt_mask.bool()

for i in range(batch_size):
    valids = valid_positions[i]
    
    logits_f = logits_forward[i][valids]
    logits_l = logits_lens_forward[i][valids]
    logits_h = logits_lens_hook[i][valids]
    logits_g = logits_lens_gen[i][valids]

    toks_f = torch.argmax(logits_f, dim=-1)
    toks_l = torch.argmax(logits_l, dim=-1)
    toks_h = torch.argmax(logits_h, dim=-1)
    toks_g = torch.argmax(logits_g, dim=-1)

    dec_f = tokenizer.decode(toks_f, skip_special_tokens=True)
    dec_l = tokenizer.decode(toks_l, skip_special_tokens=True)
    dec_h = tokenizer.decode(toks_h, skip_special_tokens=True)
    dec_g = tokenizer.decode(toks_g, skip_special_tokens=True)

    print(f"\n[Exemple {i}] Prompt :", tokenizer.decode(inputs['input_ids'][i], skip_special_tokens=True))
    print("→ forward   :", dec_f)
    print("→ logitLens :", dec_l)
    print("→ hook      :", dec_h)
    print("→ gen       :", dec_g)

    def diff_count(t1, t2): return (t1 != t2).sum().item()
    print("  Différences (fwd vs hook):", diff_count(toks_f, toks_h))
    print("  Différences (fwd vs gen) :", diff_count(toks_f, toks_g))
    print("  Différences (hook vs gen):", diff_count(toks_h, toks_g))


# --- Flattened comparison ---
print("=============================")
print("Flatten to compare logits only with non padded part")
valid_flat = prompt_mask.view(-1).bool()

logits_f_flat = logits_forward.view(-1, logits_forward.shape[-1])[valid_flat]
logits_l_flat = logits_lens_forward.view(-1, logits_forward.shape[-1])[valid_flat]
logits_h_flat = logits_lens_hook.view(-1, logits_forward.shape[-1])[valid_flat]
logits_g_flat = logits_lens_gen.view(-1, logits_forward.shape[-1])[valid_flat]


# Compare logits 
print("Compare Logits from different sources")
compare(logits_f_flat, logits_l_flat, "forward", "logitLens forward")
compare(logits_f_flat, logits_h_flat, "forward", "logitLens hook")
compare(logits_f_flat, logits_g_flat, "forward", "logitLens gen")
compare(logits_l_flat, logits_h_flat, "logitLens forward", "logitLens hook")
compare(logits_l_flat, logits_g_flat, "logitLens forward", "logitLens gen")
compare(logits_h_flat, logits_g_flat, "logitLens hook", "logitLens gen")


toks_f_flat = torch.argmax(logits_f_flat, dim=-1)
toks_l_flat = torch.argmax(logits_l_flat, dim=-1)
toks_h_flat = torch.argmax(logits_h_flat, dim=-1)
toks_g_flat = torch.argmax(logits_g_flat, dim=-1)

print("\n===== Version aplatie (tous les tokens valides du batch) =====")
print("→ forward   :", tokenizer.decode(toks_f_flat, skip_special_tokens=True))
print("→ logitLens :", tokenizer.decode(toks_l_flat, skip_special_tokens=True))
print("→ hook      :", tokenizer.decode(toks_h_flat, skip_special_tokens=True))
print("→ gen       :", tokenizer.decode(toks_g_flat, skip_special_tokens=True))

print("  Différences (fwd vs hook):", (toks_f_flat != toks_h_flat).sum().item())
print("  Différences (fwd vs gen) :", (toks_f_flat != toks_g_flat).sum().item())
print("  Différences (hook vs gen):", (toks_h_flat != toks_g_flat).sum().item())

Compare activations from different sources
forward vs hook close? True, diff: 0.0
forward vs gen close? True, diff: 0.0
hook vs gen close? True, diff: 0.0
Compare Logits from different sources
forward vs logitLens forward close? True, diff: 0.0
forward vs logitLens hook close? True, diff: 0.0
forward vs logitLens gen close? True, diff: 0.0
logitLens forward vs logitLens hook close? True, diff: 0.0
logitLens forward vs logitLens gen close? True, diff: 0.0
logitLens hook vs logitLens gen close? True, diff: 0.0
Compare predicted tokens from computed logits

[Exemple 0] Prompt : The cat sat on the
→ forward   : Unterscheidung  is on the windows
→ logitLens : Unterscheidung  is on the windows
→ hook      : Unterscheidung  is on the windows
→ gen       : Unterscheidung  is on the windows
  Différences (fwd vs hook): 0
  Différences (fwd vs gen) : 0
  Différences (hook vs gen): 0
Flatten to compare logits only with non padded part
Compare Logits from different sources
forward vs logitLens for
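
When the batch_size > 1 comparisons diverge, one thing worth checking is whether both code paths see the same attention mask: a forward pass that receives only `input_ids` treats padding tokens as real tokens. The toy below is a minimal sketch of that effect, not the Llama implementation and not a confirmed diagnosis of the mismatch above. The function `causal_attention`, the scalar "tokens" and all the numbers are hypothetical, and the toy ignores position ids, which real models may also derive from the mask.

```python
import math

def causal_attention(values, scores, mask=None):
    """Toy single-head causal attention over scalar tokens (illustrative only).

    values[j] is the value of token j, scores[j] its attention logit
    (shared by all queries to keep the toy small), and mask[j] is 1 for
    a real token and 0 for padding.
    """
    n = len(values)
    if mask is None:
        mask = [1] * n  # no mask: every position counts as a real token
    outputs = []
    for i in range(n):
        # Causal: query i only sees keys 0..i; masked keys are dropped.
        keys = [j for j in range(i + 1) if mask[j]]
        if not keys:
            outputs.append(0.0)  # padding query with nothing to attend to
            continue
        weights = [math.exp(scores[j]) for j in keys]
        total = sum(weights)
        outputs.append(sum(w * values[j] for w, j in zip(weights, keys)) / total)
    return outputs

# Hypothetical 3-token prompt, processed alone and left-padded to length 5.
values, scores = [1.0, 2.0, 3.0], [0.1, 0.2, 0.3]
pad_values, pad_scores = [9.0, 9.0], [0.0, 0.0]

alone = causal_attention(values, scores)
masked = causal_attention(pad_values + values, pad_scores + scores, mask=[0, 0, 1, 1, 1])
unmasked = causal_attention(pad_values + values, pad_scores + scores)

# Compare only the valid (non-padded) positions against the unpadded run.
same_with_mask = all(math.isclose(a, b) for a, b in zip(alone, masked[2:]))
same_without_mask = all(math.isclose(a, b) for a, b in zip(alone, unmasked[2:]))
print("valid tokens match with mask:   ", same_with_mask)
print("valid tokens match without mask:", same_without_mask)
```

In this toy, the masked run reproduces the unpadded outputs at the valid positions, while the unmasked run does not, even though only non-padded positions are compared.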

In [23]:
'''
Setting:
-------
Each time we retrieve the LAST layer hidden state of the GENERATED TOKENS.
We can see that last prompt logit = first generated logit. 
'''

hidden_hook = torch.stack(activations[0][1:], dim=0).squeeze(2).transpose(0, 1).to(hidden_forward.device) # [1, gen_len-1, hidden_size]

last_layer_activations = [outputs_gen.hidden_states[i][-1] for i in range(1,len(outputs_gen.hidden_states))]
last_layer_activations = [h.squeeze(1) if h.dim() == 3 and h.size(1) == 1 else h for h in last_layer_activations]
hidden_gen = torch.stack(last_layer_activations, dim=1)# [1, gen_len-1, hidden_size]

# --- Compare activations ---
print("Compare activations from different sources")
compare(hidden_hook, hidden_gen, "hook", "gen")

# --- Logit lens ---
logits_gen = torch.stack(outputs_gen.logits, dim=0).transpose(0, 1).float() # [1, gen_len, vocab_size]
logits_gen_truncated = logits_gen[:,1:,:] # drop step 0, which is predicted from the last prompt token
logits_lens_hook = apply_logit_lens(model, hidden_hook).float().to(hidden_forward.device)
logits_lens_gen = apply_logit_lens(model, hidden_gen).float().to(hidden_forward.device)

# --- Global comparison ---
print("Compare Logits from different sources")
compare(logits_gen_truncated, logits_lens_hook, "gen", "logitLens hook")
compare(logits_gen_truncated, logits_lens_gen, "gen", "logitLens gen")
compare(logits_lens_hook, logits_lens_gen, "logitLens hook", "logitLens gen")

print(logits_gen.shape)
print(logits_gen_truncated.shape)
print(logits_lens_hook.shape)
print(logits_lens_gen.shape)

last_prompt_logit = logits_forward[:,-1,:] #we can see that last prompt logit = first generated logit
print('\nlast_prompt_logit\n', last_prompt_logit)
print('\nlogits_gen\n', logits_gen)
print('\nlogits_gen_truncated\n', logits_gen_truncated)
print('\nlogits_lens_hook\n', logits_lens_hook)
print('\nlogits_lens_gen\n', logits_lens_gen)

Compare activations from different sources
hook vs gen close? True, diff: 0.0
Compare Logits from different sources
gen vs logitLens hook close? False, diff: 0.21297432482242584
gen vs logitLens gen close? False, diff: 0.21297432482242584
logitLens hook vs logitLens gen close? True, diff: 0.0
torch.Size([1, 50, 32000])
torch.Size([1, 49, 32000])
torch.Size([1, 49, 32000])
torch.Size([1, 49, 32000])

last_prompt_logit
 tensor([[-4.6992, -5.0938,  3.2773,  ..., -4.9961, -4.6602, -3.6621]],
       device='cuda:0', dtype=torch.float16)

logits_gen
 tensor([[[-4.6992, -5.0938,  3.2773,  ..., -4.9961, -4.6602, -3.6621],
         [-2.0820, -2.6055,  6.6562,  ..., -1.3359, -4.8594, -1.1240],
         [-2.6113, -2.0156,  7.5156,  ..., -0.7070, -2.8418, -1.8438],
         ...,
         [-4.6562, -6.9102,  2.6953,  ..., -4.2930, -2.2969, -5.2305],
         [-2.8398, -3.9766,  6.6953,  ..., -1.3516, -5.1016, -2.3145],
         [-4.4023, -5.5742,  7.2500,  ..., -4.0039, -5.2695, -5.5352]]],
       