# Notebook Overview

This notebook provides a faster testing interface.

It includes steps to:
 - Check GPU availability,
 - Set up the environment,
 - Define helper functions,
 - Run tests with a given **id** and **context_type**.

The results are saved for further analysis.

*Make sure to select the created `venv` as your kernel.*

If any errors occur while running the tests, restart the notebook and run the tests again.

## Check GPU Availability
Run this cell to check how much memory is free on the GPU.

If the **free memory** is not close to the **total memory**, someone else is probably using the machine.

You can check the active processes by running `nvidia-smi` in the terminal.
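
The "close to the total" rule can be made explicit with a small helper. This is only a sketch: the function name `gpu_looks_free` and the 90% threshold are assumptions chosen for illustration, not values taken from the project.

```python
def gpu_looks_free(free_gib: float, total_gib: float, min_free_fraction: float = 0.9) -> bool:
    """Return True when at least `min_free_fraction` of the GPU memory is free."""
    if total_gib <= 0:
        raise ValueError("total_gib must be positive")
    return free_gib / total_gib >= min_free_fraction


# Values in GiB, in the same (free, total) order that the GPU check prints them.
print(gpu_looks_free(23.2824, 23.5766))  # almost everything is free
print(gpu_looks_free(4.0, 23.5766))      # most memory is already in use
```

The first call returns `True` and the second `False`; adjust the threshold if the model needs less headroom.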

In [1]:
import torch

def get_gpu_memory_torch():
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            free_mem = torch.cuda.mem_get_info(i)[0] / 1024**3
            total_mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"GPU {i}: {free_mem:.4f} GiB free / {total_mem:.4f} GiB total")
    else:
        print("No CUDA-compatible GPU detected.")

get_gpu_memory_torch()

GPU 0: 23.2824 GiB free / 23.5766 GiB total


## Setup
The code below imports necessary modules, defines directory paths, and sets the Hugging Face cache path.

**No need to modify this.**

In [2]:
import os
import sys
import json
from pathlib import Path

PROJECT_DIR = Path.cwd()
HAYSTACK_DIR = PROJECT_DIR / "haystack"
RELEVANT_DIR = HAYSTACK_DIR / "relevant"
IRRELAVANT_DIR = HAYSTACK_DIR / "irrelevant"
MISLEADING_IN_RELEVANT_DIR = HAYSTACK_DIR / "misleading_in_relevant"
MISLEADING_IN_IRRELEVANT_DIR = HAYSTACK_DIR / "misleading_in_irrelevant"

sys.path.append(str(PROJECT_DIR))

with open(HAYSTACK_DIR / "needles.json", "r") as f:
    NEEDLES_DATA = json.load(f)

# This sets the Hugging Face cache path. Make sure this directory exists. If not, refer to the README.
os.environ['HF_HOME'] = '.cache/hf_with_quota'
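
Because the cache directory has to exist before any model is downloaded, it may help to fail early with a clear message. The helper below is a minimal sketch; the name `require_dir` is hypothetical and not part of the project, and it deliberately does not create the directory, since the README describes how that should be done.

```python
from pathlib import Path


def require_dir(path) -> Path:
    """Return `path` as a Path, or raise if it is not an existing directory."""
    path = Path(path)
    if not path.is_dir():
        raise FileNotFoundError(
            f"Directory not found: {path}. See the README for how to create it."
        )
    return path


# Possible usage right after HF_HOME is set:
#     require_dir(os.environ["HF_HOME"])
```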

## Helper Functions

In [3]:
from transformers import AutoTokenizer

def load_text(path):
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()

def insert_strings_at_sentence_breaks(text, insert_strings, context_length, tokenizer):
    # Step 1: Tokenize and truncate
    tokens = tokenizer.encode(text, add_special_tokens=False)[:context_length]
    
    # Step 2: Detect sentence breaks by decoding each token
    sentence_break_indices = [
        i for i, tok in enumerate(tokens)
        if tokenizer.decode([tok]).strip().endswith((".", "!", "?"))
    ]

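    # At least len(insert_strings) + 1 breaks are needed; with fewer, the step
    # below would be 0 and every insertion would land on the first break.
    if insert_strings and len(sentence_break_indices) <= len(insert_strings):
        raise ValueError(f"Only {len(sentence_break_indices)} sentence breaks found, "
                         f"but {len(insert_strings)} insertions need at least "
                         f"{len(insert_strings) + 1}.")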

    # Step 3: Compute evenly spaced sentence break positions
    step = len(sentence_break_indices) // (len(insert_strings) + 1)
    insert_positions = [sentence_break_indices[(i + 1) * step] + 1 for i in range(len(insert_strings))]

    # Step 4: Insert insert_strings as tokens
    for pos, insert_str in reversed(list(zip(insert_positions, insert_strings))):
        insert_tokens = tokenizer.encode(insert_str, add_special_tokens=False)
        tokens[pos:pos] = insert_tokens

    return tokenizer.decode(tokens)


def insert_misleading_statements(filepath, insert_strings, context_length, output_path):
    """
    Insert misleading statements into the text at sentence breaks.
    Evenly distribute the misleading statements across the text within the context length.
    """
    tokenizer = AutoTokenizer.from_pretrained("yaofu/llama-2-7b-80k")

    full_text = load_text(filepath)
    modified_text = insert_strings_at_sentence_breaks(full_text, insert_strings, context_length, tokenizer)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(modified_text)

    print(f"✅ Output saved to {output_path}")
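
To see how Steps 3 and 4 of `insert_strings_at_sentence_breaks` behave, the toy example below repeats the same arithmetic on plain integers instead of real tokens. The break indices, the markers `"A"` and `"B"`, and the helper name `evenly_spaced_positions` are made up for illustration.

```python
def evenly_spaced_positions(break_indices, num_inserts):
    """Pick `num_inserts` positions, each just after an evenly spaced sentence break."""
    step = len(break_indices) // (num_inserts + 1)
    return [break_indices[(i + 1) * step] + 1 for i in range(num_inserts)]


# Hypothetical token indices at which a sentence ends.
breaks = [4, 9, 15, 22, 30, 37, 41, 50, 58]
positions = evenly_spaced_positions(breaks, num_inserts=2)
print(positions)  # [23, 42]

# Insert from the last position to the first so earlier positions stay valid.
tokens = list(range(60))
for pos, marker in reversed(list(zip(positions, ["A", "B"]))):
    tokens[pos:pos] = [marker]

print(tokens[22:25])  # [22, 'A', 23]
```

With 9 breaks and 2 insertions the step is `9 // 3 = 3`, so the breaks at list indices 3 and 6 are used. Inserting in reverse order matters: inserting `"A"` first would shift the position computed for `"B"`.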


In [4]:
from argparse import Namespace

def get_haystack_file(item, context_type, with_misleading, context_length):
    """
    Get the haystack file for the given item and context type.
    Arguments:
        item (dict): The item to process.
        context_type (str): The type of context ("relevant" or "irrelevant").
        with_misleading (bool): Whether to include misleading information.
        context_length (int): Maximum number of tokens kept before misleading statements are inserted.
    Returns:
        Path: The path to the haystack file.
    """

    if context_type == "relevant":
        original_dir = RELEVANT_DIR
        original_file = item["context_relevant"]  # Original file without misleading info
        misleading_dir = MISLEADING_IN_RELEVANT_DIR  # Output dir for the misleading file
    elif context_type == "irrelevant":
        original_dir = IRRELAVANT_DIR
        original_file = item["context_irrelevant"]
        misleading_dir = MISLEADING_IN_IRRELEVANT_DIR
    else:
        raise ValueError(f"Unknown context_type: {context_type}")

    if with_misleading:
        haystack_file = misleading_dir / original_file
        insert_misleading_statements(
            filepath=original_dir / original_file,           # Path to the original context file
            insert_strings=item["statements_misleading"],    # List of misleading statements
            context_length=context_length,                   # Max context length for insertion
            output_path=haystack_file                        # Output path for the processed file
        )
    else:
        haystack_file = original_dir / original_file

    return haystack_file

def get_args(id, context_type, with_misleading, data=NEEDLES_DATA):
    """
    Get the arguments for the given id and context type.
    Arguments:
        id (int): The id of the item.
        context_type (str): The type of context ("relevant" or "irrelevant").
        with_misleading (bool): Whether to include misleading information.
        data (list): The list of data items.
    Returns:
        Namespace: The arguments for the given id and context type.
    """

    context_length_max = 5000
    
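    # find the item with the given id; fail with a clear message if it is missing
    item = next((item for item in data if item["id"] == id), None)
    if item is None:
        raise ValueError(f"No item with id {id} found in the data.")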

    # get the haystack file for the given item and context type
    haystack_file = get_haystack_file(item, context_type, with_misleading, context_length_max)

    args = Namespace(
        model_name = "yaofu/llama-2-7b-80k",
        # suffix used to name the results files
        model_name_suffix = (f"{context_type}_id_{item['id']}_misleading" if with_misleading
                             else f"{context_type}_id_{item['id']}"),
        model_provider = "LLaMA",

        context_lengths_min = 0, # min context length
        context_lengths_max = context_length_max, # max context length
        context_lengths_num_intervals = 10, # number of intervals for context lengths

        document_depth_percent_intervals = 5, # number of intervals for document depth

        needle = item["needle"],
        real_needle = item["real_needle"],
        retrieval_question = item["question"],
        haystack_file = haystack_file,
    )

    return args
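
The arguments above describe a grid of 10 context lengths between 0 and 5000 tokens and 5 document depths between 0% and 100%. The sketch below enumerates such a grid under the assumption that both axes are spaced linearly and rounded to integers; how `LLMNeedleHaystackTester` actually spaces them is not shown in this notebook, and `linspace_int` is a hypothetical helper.

```python
def linspace_int(start, stop, num):
    """Return `num` (>= 2) evenly spaced integers from `start` to `stop`, inclusive."""
    if num < 2:
        raise ValueError("num must be at least 2")
    span = stop - start
    return [round(start + span * i / (num - 1)) for i in range(num)]


# Assumed linear spacing for both axes.
context_lengths = linspace_int(0, 5000, 10)
depth_percents = linspace_int(0, 100, 5)

# Every (context length, depth) pair is one test case.
grid = [(length, depth) for length in context_lengths for depth in depth_percents]

print(context_lengths)
print(depth_percents)  # [0, 25, 50, 75, 100]
print(len(grid))       # 50
```

Under this assumption each call to `run_test` covers 50 length/depth combinations.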

In [5]:
import gc
import torch
from retrieval_head_detection import LLMNeedleHaystackTester as RetrievalHeads

def cleanup(tester: RetrievalHeads):
    del tester.model_to_test
    del tester.testing_results
    del tester.head_counter
    del tester
    
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def run_test(args):
    tester = RetrievalHeads(**vars(args))
    tester.start_test()
    cleanup(tester)


## Running the Tests

### Set up the test object:
Specify the **id** and the **context_type** in the `get_args` function.

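If `with_misleading` is `True`, misleading statements are inserted into a copy of the specified `context_type` file, which is saved in the matching `misleading_in_relevant` or `misleading_in_irrelevant` directory.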

### Start the Test

Running the test will save:

- **Contexts** → Given to the model per context length/depth  
  → `contexts/{model_name}_{context_type}_id_{id}`

- **Results** → Model outputs per context length/depth  
  → `results/graph/{model_name}_{context_type}_id_{id}`

**Example**:  
`results/graph/llama-2-7b-80k_relevant_id_44/`
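
To locate the results of a run programmatically, the directory name can be rebuilt from the same inputs passed to `get_args`. The sketch below infers the layout from the paths printed in this notebook's output, including the `_misleading` suffix added when `with_misleading=True`; the function name `results_dir` is hypothetical.

```python
from pathlib import Path


def results_dir(model_name: str, context_type: str, item_id: int, with_misleading: bool) -> Path:
    """Build the results directory for one test run (layout inferred from logged paths)."""
    short_name = model_name.split("/")[-1]  # "yaofu/llama-2-7b-80k" -> "llama-2-7b-80k"
    suffix = f"{context_type}_id_{item_id}"
    if with_misleading:
        suffix += "_misleading"
    return Path("results") / "graph" / f"{short_name}_{suffix}"


print(results_dir("yaofu/llama-2-7b-80k", "relevant", 44, with_misleading=False).as_posix())
print(results_dir("yaofu/llama-2-7b-80k", "relevant", 44, with_misleading=True).as_posix())
```

The two calls print `results/graph/llama-2-7b-80k_relevant_id_44` and `results/graph/llama-2-7b-80k_relevant_id_44_misleading`.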

Each test should take at most 5 minutes. A run sometimes uses the CPU instead of the GPU; if a test is taking longer than that, restart the notebook and run it again.
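
One way to notice a slow run is to time each call and flag it when it exceeds the 5-minute budget. The wrapper below is a sketch under that assumption; `run_timed` is a hypothetical helper, and it can only report after the call has finished, it cannot interrupt a stuck run.

```python
import time


def run_timed(fn, *args, limit_seconds: float = 300.0, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds, exceeded_limit)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    exceeded = elapsed > limit_seconds
    if exceeded:
        print(f"Took {elapsed:.0f}s, over the {limit_seconds:.0f}s budget; "
              "consider restarting the notebook.")
    return result, elapsed, exceeded


# Possible usage in place of a direct call:
#     run_timed(run_test, args)
```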

In [6]:

# Get the arguments for the given id and context type
args = get_args(id=44, context_type="relevant", with_misleading=True)
run_test(args)

args = get_args(id=44, context_type="irrelevant", with_misleading=True)
run_test(args)

args = get_args(id=44, context_type="relevant", with_misleading=False)
run_test(args)

args = get_args(id=44, context_type="irrelevant", with_misleading=False)
run_test(args)

✅ Output saved to /cs/student/projects1/2021/aagarwal/SNLP_Project/haystack/misleading_in_relevant/44.txt
loading from yaofu/llama-2-7b-80k
layer number: 32, head number 32


The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Starting Needle In A Haystack Testing...
- Model: yaofu/llama-2-7b-80k
- Context Lengths: 10, Min: 0, Max: 5000
- Document Depths: 5, Min: 0%, Max: 100%
- Needle: There won't be a FIFA World Cup this year.



insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]
-- Test Summary -- 
Duration: 0.7 seconds
Context: 0 tokens
Depth: 0%
Score: 100.0
Response: There won't be a FIFA World Cup this year.

Writing at results/graph/llama-2-7b-80k_relevant_id_44_misleading/llama-2-7b-80k_relevant_id_44_misleading_len_0_depth_0_results.json
insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], 

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Starting Needle In A Haystack Testing...
- Model: yaofu/llama-2-7b-80k
- Context Lengths: 10, Min: 0, Max: 5000
- Document Depths: 5, Min: 0%, Max: 100%
- Needle: There won't be a FIFA World Cup this year.



insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]
-- Test Summary -- 
Duration: 0.6 seconds
Context: 0 tokens
Depth: 0%
Score: 100.0
Response: There won't be a FIFA World Cup this year.

Writing at results/graph/llama-2-7b-80k_irrelevant_id_44_misleading/llama-2-7b-80k_irrelevant_id_44_misleading_len_0_depth_0_results.json
insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Starting Needle In A Haystack Testing...
- Model: yaofu/llama-2-7b-80k
- Context Lengths: 10, Min: 0, Max: 5000
- Document Depths: 5, Min: 0%, Max: 100%
- Needle: There won't be a FIFA World Cup this year.



insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]
-- Test Summary -- 
Duration: 0.6 seconds
Context: 0 tokens
Depth: 0%
Score: 100.0
Response: There won't be a FIFA World Cup this year.

Writing at results/graph/llama-2-7b-80k_relevant_id_44/llama-2-7b-80k_relevant_id_44_len_0_depth_0_results.json
insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Starting Needle In A Haystack Testing...
- Model: yaofu/llama-2-7b-80k
- Context Lengths: 10, Min: 0, Max: 5000
- Document Depths: 5, Min: 0%, Max: 100%
- Needle: There won't be a FIFA World Cup this year.



insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]
-- Test Summary -- 
Duration: 0.6 seconds
Context: 0 tokens
Depth: 0%
Score: 100.0
Response: There won't be a FIFA World Cup this year.

Writing at results/graph/llama-2-7b-80k_irrelevant_id_44/llama-2-7b-80k_irrelevant_id_44_len_0_depth_0_results.json
insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26