In [1]:
import os
import sys
import argparse
import logging
import re
import typing as ty

from tqdm import tqdm
from warnings import warn
from torch.multiprocessing import Pool, set_start_method
set_start_method('spawn', force=True)
from functools import partial
import more_itertools as mit

import torch
import fairseq
from fairseq.models.bart import BARTHubInterface
from fairseq.models.bart import BARTModel

import nvgpu

from pathlib import Path

In [2]:
import logzero

from datetime import datetime
_datetime_exec = datetime.now()

# Make sure the log directory exists before attaching the log file; opening
# a file inside a missing directory fails on a fresh checkout.
os.makedirs("logs", exist_ok=True)

# One log file per notebook run, named after the start time of the run.
logzero.logfile(f"logs/{_datetime_exec.isoformat()}.log")

logger = logzero.logger

In [3]:
def load_model(task: Path, model_path: Path) -> BARTHubInterface:
    """Load a BART checkpoint through `BARTModel.from_pretrained`.

    Args:
        task: existing path that is handed to fairseq as `data_name_or_path`.
        model_path: existing path to the checkpoint file (e.g. 'model.pt').
            Its directory is used as the model directory and its file name
            as the checkpoint file.

    Returns:
        The object returned by `BARTModel.from_pretrained`.

    Raises:
        AssertionError: if `task` or `model_path` does not exist.
    """
    for required_path in (task, model_path):
        assert required_path.exists()

    logger.info(f"Loading model {model_path}")

    # Split the checkpoint path into the pieces fairseq expects separately.
    checkpoint_dir, checkpoint_name = os.path.split(model_path.as_posix())
    hub_interface = BARTModel.from_pretrained(
        checkpoint_dir,
        checkpoint_file=checkpoint_name,
        data_name_or_path=task.as_posix(),
    )

    logger.info("Loading done.")
    return hub_interface


In [4]:
# path to input
PATH_TEXT_FILE_INPUT = Path("/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/xsum/test_source.txt")
assert PATH_TEXT_FILE_INPUT.exists()

# Read the input file line by line. The context manager closes the file
# handle deterministically instead of leaving it open for the kernel lifetime.
with PATH_TEXT_FILE_INPUT.open() as f:
    seq_text_input = f.readlines()
# end with
assert len(seq_text_input) > 0

In [66]:
# Model directory in use: the xsum model.
PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.xsum')
# Alternative: the cnn model (uncomment to switch).
# PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.cnn')

# The model directory doubles as the task path; the checkpoint lives in it as 'model.pt'.
bart_model = load_model(task=PATH_MODEL_FILE, model_path=PATH_MODEL_FILE / 'model.pt')

[I 250718 07:54:11 2610531437:10] Loading model /workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.xsum/model.pt
[I 250718 07:54:26 2610531437:17] Loading done.


In [67]:
# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device_obj = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move the loaded model onto the selected device.
bart_model = bart_model.to(device_obj)

In [7]:
# case Xsum constraints dataset
import json

PATH_CONSTRAINS_XSUM = Path("/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/datasets/constraints_fact_v1.0/xsum/collect.json")
assert PATH_CONSTRAINS_XSUM.exists()

# The file is parsed as one JSON document per line. Whitespace-only lines
# (for example a trailing blank line) are skipped because json.loads()
# raises on an empty string. Iterating the handle avoids materialising all
# lines before parsing.
with PATH_CONSTRAINS_XSUM.open() as f:
    seq_dataset = [json.loads(_line) for _line in f if _line.strip()]
# end with

logger.info(f'{len(seq_dataset)} records')

# double check: all xsum
for _record in seq_dataset:
    assert _record['dataset_name'] == 'xsum', f"unexpected dataset_name: {_record['dataset_name']!r}"
# end for

[I 250718 06:50:21 3352606314:11] 3000 records


In [8]:
def get_extractive_penalty_fct(penalty_command: str) -> str:
    """Translate a penalty command name into its penalty-function string.

    Args:
        penalty_command: one of 'lambda4', 'lambda2', 'lambda1', 'none',
            'linear', '1/lambda2' or '1/lambda1'.

    Returns:
        The penalty-function specification registered for the command,
        e.g. 'none()' for 'none'.

    Raises:
        AssertionError: if `penalty_command` is not one of the known commands.
    """
    command_to_penalty = {
        'lambda4': 'log_exp(2,4.804488)',
        'lambda2': 'log_exp(2,2.402244)',
        'lambda1': 'log_exp(2,1.201122)',
        'none': 'none()',
        'linear': 'linear()',
        '1/lambda2': 'log_exp(2,0.416277447)',  # log_exp(2, 1 / (1.20112 * 2))
        '1/lambda1': 'log_exp(2,0.832556281)',  # log_exp(2, 1 / 1.20112)
    }

    assert penalty_command in command_to_penalty

    return command_to_penalty[penalty_command]


def bart_sample(bart_model: BARTHubInterface,
                batch: ty.List[str],
                extractive_penalty_fct: str,
                beam: int = 4,
                lenpen: float = 2.0,  # length penalty
                min_len: int = 55,
                max_len_a: int = 0,
                max_len_b: int = 140,
                no_repeat_ngram_size: int = 3) -> ty.List[str]:
    """Generate summaries for a batch of documents with a fixed random seed.

    Args:
        bart_model: model interface used to encode, generate and decode.
        batch: source documents to summarize.
        extractive_penalty_fct: penalty-function specification (for example
            'none()' or 'log_exp(2,2.402244)') forwarded to the generator.
        beam: generation option forwarded as `beam`.
        lenpen: generation option forwarded as `lenpen` (length penalty).
        min_len: generation option forwarded as `min_len`.
        max_len_a: generation option forwarded as `max_len_a`.
        max_len_b: generation option forwarded as `max_len_b`.
        no_repeat_ngram_size: generation option forwarded as `no_repeat_ngram_size`.

    Returns:
        The decoded summaries, one per entry returned by `bart_model.generate`
        (an empty list for an empty batch).
    """
    # Nothing to encode: torch.stack() would raise on an empty list.
    if len(batch) == 0:
        return []
    # end if

    # The caller's `extractive_penalty_fct` is forwarded as given; it is no
    # longer replaced by a hard-coded 'none()' penalty.
    dict_parameters = dict(
        beam=beam,
        lenpen=lenpen,
        sampling=False,
        min_len=min_len,
        max_len_a=max_len_a,
        max_len_b=max_len_b,
        no_repeat_ngram_size=no_repeat_ngram_size,
        extractive_penalty_fct=extractive_penalty_fct)

    # Fork the RNG state so that seeding here does not leak into the caller.
    with torch.random.fork_rng():
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)  # if you are using multi-GPU.

        with torch.no_grad():
            tensor_input_ids = [bart_model.encode(text) for text in batch]
            # NOTE: torch.stack() requires all encoded documents to have the
            # same number of tokens; batches of differing lengths raise here.
            tensor_stack = torch.stack(tensor_input_ids).to(bart_model.device)
            generated_ids = bart_model.generate(tensor_stack, **dict_parameters)
            # Each returned entry is expected to carry its token ids under
            # 'tokens' (the previous code read `generated_ids[0]['tokens']`).
            seq_text_summary = [bart_model.decode(_hypo['tokens']) for _hypo in generated_ids]
        # end with
    # end with

    return seq_text_summary
# end def



def get_source_and_summary(record_obj: ty.Dict) -> ty.Tuple[str, str]:
    """Pull the source document and the raw summary out of a dataset record.

    Args:
        record_obj: a record holding the keys 'document_full' and 'summary_raw'.

    Returns:
        A `(document_full, summary_raw)` tuple.
    """
    # 'document_full' is used as the source; 'document_original' is not read.
    source_document = record_obj['document_full']
    raw_summary = record_obj['summary_raw']
    return source_document, raw_summary
# end def

target_document_index = [1, 10, 100, 200]

import pprint

# One entry per processed document: its id, the abstractiveness constraint
# command taken from the record, and whatever `bart_sample` returned for it.
seq_stack = []

for _idx in target_document_index:
    _record = seq_dataset[_idx]

    _document_id: str = _record['document_id']
    command_abstractiveness_constraint: str = _record['abstractiveness_constraint']

    _document_original, _summary_raw = get_source_and_summary(_record)
    # Single source of truth for the command -> penalty mapping; the table is
    # no longer duplicated in this cell.
    extractive_penalty_fct = get_extractive_penalty_fct(command_abstractiveness_constraint)

    seq_summary = bart_sample(
        bart_model=bart_model,
        batch=[_document_original],
        extractive_penalty_fct=extractive_penalty_fct
    )

    # Keep the result instead of discarding it on the next iteration.
    seq_stack.append(dict(
        document_id=_document_id,
        abstractiveness_constraint=command_abstractiveness_constraint,
        summary=seq_summary,
    ))
# end for




In [None]:
# Model directory in use: the xsum model.
PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.xsum')
# Alternative: the cnn model (uncomment to switch).
# PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.cnn')

# Load the checkpoint stored as 'model.pt' inside the model directory.
bart_model = load_model(task=PATH_MODEL_FILE, model_path=PATH_MODEL_FILE / 'model.pt')

# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device_obj = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

bart_model = bart_model.to(device_obj)

In [70]:
from fairseq.sequence_generator import SequenceGenerator
from functools import wraps

# ------------------------------------
# setting hook to extract the attention heads
# Filled by the hooks: maps (layer_idx, attn_type) to the captured weights.
attn_data = {}

def get_attention_hook(layer_idx, attn_type):
    """Build a forward hook that stores a module's attention weights in `attn_data`.

    Args:
        layer_idx: layer index; first element of the storage key.
        attn_type: label of the attention module; second element of the
            storage key.

    Returns:
        A callable with the `(module, input, output)` forward-hook signature.
        It returns None and only has the side effect of updating `attn_data`.
    """
    storage_key = (layer_idx, attn_type)

    def hook(module, input, output):
        # Check output format from MultiheadAttention.forward
        # It can return (attn_output, attn_weights) if `need_weights=True`
        if not isinstance(output, tuple):
            return
        # end if

        _, weights = output
        if weights is None:
            # No weights were returned, so there is nothing to record.
            return
        # end if

        # Store a detached CPU copy; a later call with the same key overwrites it.
        attn_data[storage_key] = weights.detach().cpu()
    return hook
# end def

# Register hooks for all decoder layers.
# The handles returned by register_forward_hook() are stored so that the
# hooks can really be removed after generation; without them the cleanup
# below would iterate over empty lists and every rerun of this cell would
# stack another set of hooks on the model.
self_attn_hooks = []
cross_attn_hooks = []
for i, layer in enumerate(bart_model.model.decoder.layers):
    self_attn_hooks.append(layer.self_attn.register_forward_hook(get_attention_hook(i, 'self')))
    cross_attn_hooks.append(layer.encoder_attn.register_forward_hook(get_attention_hook(i, 'cross')))
    logger.debug(f'Setting attention_hook at a layer at {i}')
# end for

# ------------------------------------
# setting the internal parameters True for extracting the attention heads.
# for layer in encoder_decoder_interface.model.decoder.layers:
#     def wrap_forward(original_forward):
#         def new_forward(*args, **kwargs):
#             kwargs['need_attn'] = True
#             kwargs['need_head_weights'] = True
#             return original_forward(*args, **kwargs)
#         return new_forward
#     layer.forward = wrap_forward(layer.forward)
# setting the flag to the self-attention layers.
# overwriting the forward method of `MultiheadAttention`. The default value is False and attn_weights are averaged automatically.
# https://github.com/facebookresearch/fairseq/blob/d13e14a800bb588e5a77fb4e551f554ff9b24a72/fairseq/modules/multihead_attention.py#L469
#
# NOTE: the forward-wrapping below is disabled, so the attention modules run
# with their default arguments. If it is re-enabled, bind the original forward
# per layer (e.g. through a factory function like `wrap_forward` above);
# closures defined directly in the loop body would all end up calling the
# forward of the last layer.
# for i, layer in enumerate(bart_model.model.decoder.layers):
#     orig_self_attn_forward = layer.self_attn.forward
#     orig_encoder_attn_forward = layer.encoder_attn.forward
#
#     @wraps(orig_self_attn_forward)
#     def wrapped_self_attn_forward(*args, **kwargs):
#         kwargs['need_weights'] = True
#         kwargs['need_head_weights'] = True
#         return orig_self_attn_forward(*args, **kwargs)
#
#     @wraps(orig_encoder_attn_forward)
#     def wrapped_encoder_attn_forward(*args, **kwargs):
#         kwargs['need_weights'] = True
#         kwargs['need_head_weights'] = True
#         return orig_encoder_attn_forward(*args, **kwargs)
#
#     layer.self_attn.forward = wrapped_self_attn_forward
#     layer.encoder_attn.forward = wrapped_encoder_attn_forward
# ------------------------------------

# Everything that runs while the hooks are attached sits inside try/finally so
# that the hooks are removed even when generation fails.
try:
    generator = SequenceGenerator(
        models=[bart_model.model],
        tgt_dict=bart_model.task.target_dictionary,
        beam_size=4,
        len_penalty=1.0,
        max_len_b=200,
        min_len=1,
        no_repeat_ngram_size=3,
        extractive_penalty_fct="none()"
    )

    source_text = "We present a novel approach for detecting hallucinations in large language models (LLMs) by analyzing the probabilistic divergence between prompt and response hiddenstate distributions. Counterintuitively, we find that hallucinated responses exhibit smaller deviations from their prompts compared to grounded responses, suggesting that hallucinations often arise from superficial rephrasing rather than substantive reasoning. Leveraging this insight, we propose a model-intrinsic detection method1 that uses distributional distances as principled hallucination scores, eliminating the need for external knowledge or auxiliary models. To enhance sensitivity, we employ deep learnable kernels that automatically adapt to capture nuanced geometric differences between distributions. Our approach outperforms existing baselines, demonstrating state-of-the-art performance on several benchmarks. The method remains competitive even without kernel training, offering a robust, scalable solution for hallucination detection."
    source_ids = bart_model.encode(source_text).unsqueeze(0).to(device_obj)

    # The encoder output is only kept for inspection; run it without autograd
    # so that no computation graph is retained.
    with torch.no_grad():
        encoder_out = bart_model.model.encoder.forward(source_ids, src_lengths=None)
    # end with

    output = generator.generate(
        models=[bart_model.model],
        sample={'net_input': {'src_tokens': source_ids, 'src_lengths': None}},
        prefix_tokens=None
    )
finally:
    # deleting the hook functions
    for h in self_attn_hooks + cross_attn_hooks:
        h.remove()
    # end for
# end try

bart_model.decode(output[0][0]['tokens'])

[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 0
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 1
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 2
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 3
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 4
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 5
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 6
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 7
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 8
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 9
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 10
[D 250718 07:54:36 1771820259:31] Setting attention_hook at a layer at 11


'A novel method for detecting hallucinations in large language models has been proposed.'

In [71]:
# Show which (layer_idx, attn_type) keys the hooks stored during the run above.
attn_data.keys()

dict_keys([(0, 'cross'), (1, 'cross'), (2, 'cross'), (3, 'cross'), (4, 'cross'), (5, 'cross'), (6, 'cross'), (7, 'cross'), (8, 'cross'), (9, 'cross'), (10, 'cross'), (11, 'cross')])

In [53]:
# NOTE(review): the hub interface is called with the option names that
# `bart_sample` in this notebook uses (`beam`, `lenpen`). `beam_size` and
# `len_penalty` are the SequenceGenerator constructor names and are presumably
# not picked up when passed here, which would explain why this cell and the
# direct SequenceGenerator run produced different summaries. Confirm against
# the installed fairseq version.
bart_model.sample([source_text],
                  beam=4,
                  lenpen=1.0,
                  max_len_b=200,
                  min_len=1,
                  no_repeat_ngram_size=3,
                  extractive_penalty_fct="none()"
)

['A new method for detecting hallucinations in large language models has been proposed by researchers.']

In [49]:
tensor_source_id = bart_model.encode(source_text)

# NOTE(review): option names follow `bart_sample` in this notebook (`beam`,
# `lenpen`). `beam_size` / `len_penalty` are the SequenceGenerator constructor
# names and are presumably ignored when passed to the hub interface. Confirm
# against the installed fairseq version.
generate_out = bart_model.generate(
    tensor_source_id.unsqueeze(0).to(bart_model.device),
    beam=4,
    lenpen=1.0,
    max_len_b=200,
    min_len=1,
    no_repeat_ngram_size=3,
    extractive_penalty_fct="none()"
)

# Decode the token ids of the first returned entry.
bart_model.decode(generate_out[0]['tokens'])

'We present a novel approach for detecting hallucinations in large language models.'

# Codebase toward Attention-Head

In [None]:
# # with xsum model
# PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.xsum')
# # with cnn model
# # PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.cnn')

# bart_model = load_model(PATH_MODEL_FILE, PATH_MODEL_FILE / 'model.pt')

# if torch.cuda.is_available():
#     device_obj = torch.device('cuda:0')
# else:
#     device_obj = torch.device('cpu')
# # end if

# bart_model = bart_model.to(device_obj)

[I 250718 06:50:24 2610531437:10] Loading model /workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.xsum/model.pt
[I 250718 06:50:38 2610531437:17] Loading done.


In [None]:
# from torch.nn import functional as F

# from functools import wraps

# import fairseq
# from fairseq.hub_utils import GeneratorHubInterface
# from fairseq.sequence_generator import SequenceGenerator


# class GenerateResult(ty.NamedTuple):
#     generated_token_ids: torch.Tensor
#     generated_text: str
#     generation_parameter: ty.Dict
#     attention_headers: ty.Dict[ty.Tuple[int, str], torch.Tensor]
#     stats: ty.Dict[str, int]


# def generate_attention_head_extraction(encoder_decoder_interface: GeneratorHubInterface,
#                                        source_token_ids: torch.Tensor,
#                                        reference_token_ids: ty.Optional[torch.Tensor] = None,
#                                        sampling: bool = False,
#                                        beam_size: int = 1,
#                                        max_len_a: float = 0.0,
#                                        max_len_b: int = 200,
#                                        max_len: int = 0,
#                                        min_len: int = 1,
#                                        normalize_scores=True,
#                                        len_penalty: float = 1.0,
#                                        unk_penalty: float  =0.0,
#                                        temperature: float = 1.0,
#                                        no_repeat_ngram_size: int = 0,
#                                        random_seed: int = 42,
#                                        is_activate_eval_mode: bool = True
#                                        ) -> GenerateResult:
#     """A custom function.
    
#     NOT Possible:
#         `top_p` sampling.
#         `top_k` sampling.


#     Args:
#         reference_token_ids (optional): used for the teacher-forcing mode.
#     """
#     # input arguments check
#     assert len(source_token_ids.shape) == 2, f"The input tensor must be (batch, n-token). Given -> {source_token_ids.shape}"
#     if reference_token_ids is not None:
#         assert len(reference_token_ids.shape) == 2, f"The `reference_token_ids` tensor must be (batch, n-token). Given -> {reference_token_ids.shape}"

#     assert source_token_ids.shape[0] == 1, f"This implementation can process only one batch. Given -> {source_token_ids.shape}"


#     if is_activate_eval_mode:
#         encoder_decoder_interface.eval()
#     # end if


#     # Constants
#     eos_idx = encoder_decoder_interface.task.source_dictionary.eos()
#     # bos_idx = eos_idx  # For BART, BOS and EOS are the same: </s> 
#     target_ids_start = torch.tensor([[encoder_decoder_interface.model.decoder.dictionary.bos()]]).to(device_obj)

#     # ------------------------------------
#     # check the `reference_token_ids`. `reference_token_ids` may start from the `target_ids_start`
#     if reference_token_ids is not None:
#         if all(reference_token_ids[0, 0] == target_ids_start):
#             reference_token_ids = reference_token_ids[:, 1:]
#         # end if
#     # end if

#     # ------------------------------------
#     # setting hook to extract the attention heads
#     attn_data = {}

#     def get_attention_hook(layer_idx, attn_type):
#         def hook(module, input, output):
#             # if output is not None and isinstance(output, tuple) and output[1] is not None:
#             #     # print(layer_idx, attn_type)
#             #     attn_data.setdefault((layer_idx, attn_type), (output[1].detach().cpu()))
            
#             # Check output format from MultiheadAttention.forward
#             # It can return (attn_output, attn_weights) if `need_weights=True`
#             if isinstance(output, tuple):
#                 _, attn_weights = output
#                 if attn_weights is not None:
#                     attn_data[(layer_idx, attn_type)] = attn_weights.detach().cpu()
#                 # end if
#             # end if
#         return hook
#     # end def

#     # Register hooks for all decoder layers
#     self_attn_hooks = []
#     cross_attn_hooks = []
#     for i, layer in enumerate(encoder_decoder_interface.model.decoder.layers):
#         layer.self_attn.register_forward_hook(get_attention_hook(i, 'self'))
#         layer.encoder_attn.register_forward_hook(get_attention_hook(i, 'cross'))
#         logger.debug(f'Setting attention_hook at a layer at {i}')
#     # end for
    
#     source_token_ids = source_token_ids.to(encoder_decoder_interface.device)
    
#     if reference_token_ids is not None:
#         reference_token_ids = reference_token_ids.to(encoder_decoder_interface.device)
#     # end if

#     # ------------------------------------
#     # setting the internal parameters True for extracting the attention heads.
#     # for layer in encoder_decoder_interface.model.decoder.layers:
#     #     def wrap_forward(original_forward):
#     #         def new_forward(*args, **kwargs):
#     #             kwargs['need_attn'] = True
#     #             kwargs['need_head_weights'] = True
#     #             return original_forward(*args, **kwargs)
#     #         return new_forward
#     #     layer.forward = wrap_forward(layer.forward)
#     # setting the flag to the self-attention layers.
#     # overwriting the forward method of `MultiheadAttention`. The default value is False and attn_weights are averaged automatically.
#     # https://github.com/facebookresearch/fairseq/blob/d13e14a800bb588e5a77fb4e551f554ff9b24a72/fairseq/modules/multihead_attention.py#L469
#     for i, layer in enumerate(encoder_decoder_interface.model.decoder.layers):
#         orig_self_attn_forward = layer.self_attn.forward
#         orig_encoder_attn_forward = layer.encoder_attn.forward        

#         @wraps(orig_self_attn_forward)
#         def wrapped_self_attn_forward(*args, **kwargs):
#             kwargs['need_weights'] = True
#             kwargs['need_head_weights'] = True
#             return orig_self_attn_forward(*args, **kwargs)

#         @wraps(orig_encoder_attn_forward)
#         def wrapped_encoder_attn_forward(*args, **kwargs):
#             kwargs['need_weights'] = True
#             kwargs['need_head_weights'] = True
#             return orig_encoder_attn_forward(*args, **kwargs)

#         layer.self_attn.forward = wrapped_self_attn_forward
#         layer.encoder_attn.forward = wrapped_encoder_attn_forward        
#     # ------------------------------------
#     # encoder
#     with torch.no_grad():
#         encoder_out = encoder_decoder_interface.model.encoder(
#             src_tokens=source_token_ids,
#             src_lengths=None
#         )

#     # ------------------------------------

#     alignment_layer = 0
#     alignment_heads = None  # or specify a head

#     # Start decoder input with BOS token
#     prev_output_tokens = torch.tensor([[target_ids_start]], dtype=torch.long, device=encoder_decoder_interface.device)

#     generated_tokens = [prev_output_tokens.item()]

#     # ------------------------------------
#     # main loop of the generation.
#     i_step = 1

#     if max_len == 0:
#         _max_len = (len(source_token_ids) * max_len_a) + max_len_b
#     else:
#         _max_len = max_len
#     # end if

#     while i_step <= _max_len:
#         _decoder_out = encoder_decoder_interface.model.decoder.forward(
#             prev_output_tokens=prev_output_tokens,
#             encoder_out=encoder_out,
#             alignment_layer=alignment_layer,
#             alignment_heads=alignment_heads,
#             return_all_hiddens=True)

#         logits = _decoder_out[0][:, -1, :]  # logits for last token

#         if sampling:
#             # Apply sampling strategy
#             probs = F.softmax(logits / temperature, dim=-1)
#             with torch.random.fork_rng(devices=[encoder_decoder_interface.device]):
#                 torch.manual_seed(random_seed)
#                 torch.cuda.manual_seed_all(random_seed)  # if you are using multi-GPU.
#                 __next_token_id = torch.multinomial(probs, num_samples=1)  # stochastic sampling
#             # end with
#             _next_token = __next_token_id            
#         elif reference_token_ids is not None:
#             # teacher forcing
#             __next_token_id = reference_token_ids[:, (i_step - 1)]
#             _next_token = __next_token_id.unsqueeze(0)
#         elif temperature == 0.0:
#             # greedy search
#             __next_token_id = torch.argmax(logits / temperature, dim=-1)  # greedy
#             _next_token = torch.stack([__next_token_id])
#         else:
#             raise Exception('Not defined Parameter combination.')
#         # end if
        
#         # Append to sequence
#         prev_output_tokens = torch.cat([prev_output_tokens, _next_token], dim=1)
        
#         # end if
#         generated_tokens.append(__next_token_id.item())

#         # end condition
#         if __next_token_id.item() == encoder_decoder_interface.task.source_dictionary.eos():
#             break
#         elif i_step == _max_len:
#             break
#         elif reference_token_ids is not None and i_step == reference_token_ids.shape[1]:
#             break
#         else:
#             i_step += 1
#         # end if
#     # end with


#     _generation_parameters = dict(
#         max_len=_max_len,
#         sampling=sampling,
#         temperature=temperature
#     )

#     generated_text = encoder_decoder_interface.decode(torch.tensor(generated_tokens))

#     # ------------------------------------
#     # stats
#     n_layers_decoder = len(encoder_decoder_interface.model.decoder.layers)
#     n_self_attn_head_decoder = encoder_decoder_interface.model.decoder.layers[0].self_attn.num_heads
#     n_encoder_attn_header_decoder = encoder_decoder_interface.model.decoder.layers[0].encoder_attn.num_heads
#     _stats = dict(
#         n_token_source=source_token_ids.shape[1],
#         n_token_generated=len(generated_tokens) - 1,  # -1 for the BOS token.
#         n_layers_decoder=n_layers_decoder,
#         n_self_attn_head_decoder=n_self_attn_head_decoder,
#         n_encoder_attn_header_decoder=n_encoder_attn_header_decoder
#     )

#     # ------------------------------------

#     generated_obj = GenerateResult(
#         generated_token_ids=torch.tensor(generated_tokens),
#         generated_text=generated_text,
#         generation_parameter=_generation_parameters,
#         attention_headers=attn_data,
#         stats=_stats
#     )

#     # deleting the hook functions
#     for h in self_attn_hooks + cross_attn_hooks:
#         h.remove()
#     # end for

#     return generated_obj
# # end def



# def test_teacher_forcing():

#     source_text = "We present a novel approach for detecting hallucinations in large language models (LLMs) by analyzing the probabilistic divergence between prompt and response hiddenstate distributions. Counterintuitively, we find that hallucinated responses exhibit smaller deviations from their prompts compared to grounded responses, suggesting that hallucinations often arise from superficial rephrasing rather than substantive reasoning. Leveraging this insight, we propose a model-intrinsic detection method1 that uses distributional distances as principled hallucination scores, eliminating the need for external knowledge or auxiliary models. To enhance sensitivity, we employ deep learnable kernels that automatically adapt to capture nuanced geometric differences between distributions. Our approach outperforms existing baselines, demonstrating state-of-the-art performance on several benchmarks. The method remains competitive even without kernel training, offering a robust, scalable solution for hallucination detection."
#     source_ids = bart_model.encode(source_text).unsqueeze(0).to(device_obj)

#     target_text_forcing = "This text introduces a hallucination detection method." 
#     target_token_ids = bart_model.encode(target_text_forcing).unsqueeze(0).to(device_obj)
#     # target_token_ids = None

#     obj_generated = generate_attention_head_extraction(
#         encoder_decoder_interface=bart_model,
#         source_token_ids=source_ids,
#         reference_token_ids=target_token_ids,
#         max_len=10,
#         temperature=0.1
#     )

#     assert target_text_forcing == obj_generated.generated_text
#     assert tuple(target_token_ids[0].tolist()) == tuple(obj_generated.generated_token_ids.tolist())
# # end def


# def test_stochastic_sampling():

#     source_text = "We present a novel approach for detecting hallucinations in large language models (LLMs) by analyzing the probabilistic divergence between prompt and response hiddenstate distributions. Counterintuitively, we find that hallucinated responses exhibit smaller deviations from their prompts compared to grounded responses, suggesting that hallucinations often arise from superficial rephrasing rather than substantive reasoning. Leveraging this insight, we propose a model-intrinsic detection method1 that uses distributional distances as principled hallucination scores, eliminating the need for external knowledge or auxiliary models. To enhance sensitivity, we employ deep learnable kernels that automatically adapt to capture nuanced geometric differences between distributions. Our approach outperforms existing baselines, demonstrating state-of-the-art performance on several benchmarks. The method remains competitive even without kernel training, offering a robust, scalable solution for hallucination detection."
#     source_ids = bart_model.encode(source_text).unsqueeze(0).to(device_obj)

#     target_token_ids = None

#     # bart_model.eval()

#     obj_generated_1st = generate_attention_head_extraction(
#         encoder_decoder_interface=bart_model,
#         source_token_ids=source_ids,
#         reference_token_ids=target_token_ids,
#         max_len=10,
#         temperature=0.1,
#         sampling=True,
#         random_seed=42   
#     )
#     obj_generated_2nd = generate_attention_head_extraction(
#         encoder_decoder_interface=bart_model,
#         source_token_ids=source_ids,
#         reference_token_ids=target_token_ids,
#         max_len=10,
#         temperature=0.1,
#         sampling=True,
#         random_seed=42
#     )
#     assert obj_generated_1st.generated_token_ids.tolist() == obj_generated_2nd.generated_token_ids.tolist()

#     obj_generated_with_05 = generate_attention_head_extraction(
#         encoder_decoder_interface=bart_model,
#         source_token_ids=source_ids,
#         reference_token_ids=target_token_ids,
#         max_len=10,
#         temperature=1.0,
#         sampling=True,
#         random_seed=42   
#     )
#     assert obj_generated_1st.generated_token_ids.tolist() != obj_generated_with_05.generated_token_ids.tolist()
# # end def


# # test_teacher_forcing()
# # test_stochastic_sampling()


# source_text = "We present a novel approach for detecting hallucinations in large language models (LLMs) by analyzing the probabilistic divergence between prompt and response hiddenstate distributions. Counterintuitively, we find that hallucinated responses exhibit smaller deviations from their prompts compared to grounded responses, suggesting that hallucinations often arise from superficial rephrasing rather than substantive reasoning. Leveraging this insight, we propose a model-intrinsic detection method1 that uses distributional distances as principled hallucination scores, eliminating the need for external knowledge or auxiliary models. To enhance sensitivity, we employ deep learnable kernels that automatically adapt to capture nuanced geometric differences between distributions. Our approach outperforms existing baselines, demonstrating state-of-the-art performance on several benchmarks. The method remains competitive even without kernel training, offering a robust, scalable solution for hallucination detection."
# source_ids = bart_model.encode(source_text).unsqueeze(0).to(device_obj)

# obj_generated_1st = generate_attention_head_extraction(
#     encoder_decoder_interface=bart_model,
#     source_token_ids=source_ids,
#     reference_token_ids=None,
#     max_len=10,
#     temperature=0.1,
#     sampling=True,
#     random_seed=42   
# )

# print(obj_generated_1st.stats)

# for _key in obj_generated_1st.attention_headers:
#     _attention_head = obj_generated_1st.attention_headers[_key]
#     print(_key, _attention_head.shape)

[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 0
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 1
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 2
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 3
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 4
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 5
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 6
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 7
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 8
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 9
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 10
[D 250718 06:50:38 3983859745:98] Setting attention_hook at a layer at 11


{'n_token_source': 170, 'n_token_generated': 10, 'n_layers_decoder': 12, 'n_self_attn_head_decoder': 16, 'n_encoder_attn_header_decoder': 16}
(0, 'self') torch.Size([1, 10, 10])
(0, 'cross') torch.Size([16, 1, 10, 170])
(1, 'self') torch.Size([1, 10, 10])
(1, 'cross') torch.Size([16, 1, 10, 170])
(2, 'self') torch.Size([1, 10, 10])
(2, 'cross') torch.Size([16, 1, 10, 170])
(3, 'self') torch.Size([1, 10, 10])
(3, 'cross') torch.Size([16, 1, 10, 170])
(4, 'self') torch.Size([1, 10, 10])
(4, 'cross') torch.Size([16, 1, 10, 170])
(5, 'self') torch.Size([1, 10, 10])
(5, 'cross') torch.Size([16, 1, 10, 170])
(6, 'self') torch.Size([1, 10, 10])
(6, 'cross') torch.Size([16, 1, 10, 170])
(7, 'self') torch.Size([1, 10, 10])
(7, 'cross') torch.Size([16, 1, 10, 170])
(8, 'self') torch.Size([1, 10, 10])
(8, 'cross') torch.Size([16, 1, 10, 170])
(9, 'self') torch.Size([1, 10, 10])
(9, 'cross') torch.Size([16, 1, 10, 170])
(10, 'self') torch.Size([1, 10, 10])
(10, 'cross') torch.Size([16, 1, 10, 170]

In [None]:
# type(bart_model.model.decoder.layers[1].encoder_attn)

fairseq.modules.multihead_attention.MultiheadAttention