# Notebook Overview

This notebook provides a faster testing interface.

It includes steps to:
 - Check GPU availability
 - Set up the environment
 - Define helper functions
 - Run tests with a given **id** and **context_type**

The results are saved for further analysis.

*Make sure to select the created `venv` as your kernel.*

If any errors occur while running the tests, restart the notebook and run it again.

## Check GPU Availability
Run this cell to check how much memory is free on the GPU.

If the **free space** is not close to the **total space** then someone else is probably using the machine. 

You can check the active processes by running `nvidia-smi` in the terminal.

In [1]:
import subprocess

def get_gpu_memory():
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.free,memory.total', '--format=csv,nounits,noheader'],
        stdout=subprocess.PIPE, text=True
    )
    lines = result.stdout.strip().split('\n')
    for i, line in enumerate(lines):
        free, total = map(int, line.split(','))
        print(f"GPU {i}: {free / 1024: .4} GiB free / {total / 1024: .4} GiB total")

get_gpu_memory()

GPU 0:  23.46 GiB free /  23.99 GiB total
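
The "free space close to total space" rule can also be checked in code. The sketch below is a minimal illustration and not part of the project: `parse_gpu_memory`, `busy_gpus`, and the 0.9 threshold are hypothetical, and the sample string is made-up data in the same `free, total` MiB format that the query above returns.

```python
def parse_gpu_memory(csv_text):
    """Parse 'free, total' lines (in MiB) into a list of (free, total) tuples."""
    rows = []
    for line in csv_text.strip().splitlines():
        free, total = (int(value) for value in line.split(','))
        rows.append((free, total))
    return rows

def busy_gpus(rows, min_free_fraction=0.9):
    """Return indices of GPUs whose free memory is below the chosen fraction of total.

    The 0.9 default is an arbitrary assumption; adjust it to your machine.
    """
    return [i for i, (free, total) in enumerate(rows)
            if free / total < min_free_fraction]

# Made-up sample: GPU 0 is almost empty, GPU 1 is heavily used.
sample = "24023, 24564\n1200, 24564"
print(busy_gpus(parse_gpu_memory(sample)))  # prints [1]
```

To use it with real data, pass `result.stdout` from the `subprocess.run` call above instead of `sample`.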


## Setup
The code below imports necessary modules, defines directory paths, and sets the Hugging Face cache path.

**No need to modify this.**

In [2]:
import os
import sys
import json
from pathlib import Path

PROJECT_DIR = Path.cwd()
HAYSTACK_DIR = PROJECT_DIR / "haystack"
RELEVANT_DIR = HAYSTACK_DIR / "relevant"
IRRELAVANT_DIR = HAYSTACK_DIR / "irrelevant"

sys.path.append(str(PROJECT_DIR))

with open(HAYSTACK_DIR / "needles.json", "r") as f:
    NEEDLES_DATA = json.load(f)

# This sets the Hugging Face cache path. Make sure this directory exists. If not, refer to the README.
os.environ['HF_HOME'] = '.cache/hf_with_quota'
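
Because the cell above assumes the cache directory already exists, a small guard can make a missing directory fail early with a clear message. This is only a sketch: `check_hf_home` is a hypothetical helper, not something the project defines. `HF_HOME` is generally read when the Hugging Face libraries are imported, so a guard like this would need to run before those imports.

```python
import os
from pathlib import Path

def check_hf_home(path):
    """Set HF_HOME to `path`, but only if it is an existing directory."""
    cache_dir = Path(path)
    if not cache_dir.is_dir():
        raise FileNotFoundError(
            f"Hugging Face cache directory not found: {cache_dir.resolve()}"
        )
    os.environ["HF_HOME"] = str(cache_dir)
    return cache_dir
```

For example, `check_hf_home('.cache/hf_with_quota')` would replace the plain assignment above and raise if the directory has not been created yet.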

## Helper Function

In [3]:
from argparse import Namespace

def get_args(id, context_type, data=NEEDLES_DATA):
    """
    Get the arguments for the given id and context type.
    Arguments:
        id (int): The id of the item.
        context_type (str): The type of context ("relevant" or "irrelevant").
        data (list): The list of data items.
    Returns:
        Namespace: The arguments for the given id and context type.
    """
    
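    # Find the item with the given id; fail with a clear message if it is missing.
    item = next((item for item in data if item["id"] == id), None)
    if item is None:
        raise ValueError(f"No item with id {id} in data")

    if context_type == "relevant":
        haystack_file = RELEVANT_DIR / item["context_relevant"]
    elif context_type == "irrelevant":
        haystack_file = IRRELAVANT_DIR / item["context_irrelevant"]
    else:
        raise ValueError(f"context_type must be 'relevant' or 'irrelevant', got {context_type!r}")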

    args = Namespace(
        model_name = "yaofu/llama-2-7b-80k",
        model_name_suffix = f"{context_type}_id_{item['id']}", # suffix used to name the results files
        model_provider = "LLaMA",

        context_lengths_min = 0, # min context length
        context_lengths_max = 5000, # max context length
        context_lengths_num_intervals = 10, # number of intervals for context lengths

        document_depth_percent_intervals = 10, # number of intervals for document depth

        needle = item["needle"],
        real_needle = item["real_needle"],
        retrieval_question = item["question"],
        haystack_file = haystack_file,
    )

    return args
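
To get a feel for how many individual runs these settings imply, the sketch below builds the grid of context lengths and document depths. It assumes the tester spaces both ranges linearly and includes both endpoints. That is an assumption about `LLMNeedleHaystackTester`, not something confirmed here, and `linear_grid` is a hypothetical helper.

```python
def linear_grid(start, stop, num):
    """Return `num` evenly spaced integers from start to stop, both included."""
    if num < 1:
        raise ValueError("num must be at least 1")
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [round(start + i * step) for i in range(num)]

# Values taken from the Namespace above.
context_lengths = linear_grid(0, 5000, 10)
depth_percents = linear_grid(0, 100, 10)

# Under the linear-spacing assumption, each (length, depth) pair is one run.
runs = [(length, depth) for length in context_lengths for depth in depth_percents]
print(len(runs))  # prints 100
```

With 10 intervals on each axis, one call to the tester would cover 100 length/depth combinations under this assumption.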

## Running the Tests

### Set up the test object:
Specify the **id** and the **context_type** in the `get_args` function. 

The notebook currently supports only individual testing for relevant and irrelevant context. 

The code for testing with misleading context will be added soon.

In [4]:
from retrieval_head_detection import LLMNeedleHaystackTester as RetrievalHeads

# Get the arguments for the given id and context type
args = get_args(id=44, context_type="relevant")

# Create the Tester object.
rh = RetrievalHeads(**vars(args))

loading from yaofu/llama-2-7b-80k
layer number: 32, head number 32


The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Start the Test

Running the test will save:

- **Contexts** → Given to the model per context length/depth  
  → `contexts/{model_name}_{context_type}_id_{id}`

- **Results** → Model outputs per context length/depth  
  → `results/graph/{model_name}_{context_type}_id_{id}`

**Example**:  
`results/graph/llama-2-7b-80k_relevant_id_44/`
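
For later analysis, the saved result files can be collected with a few lines of Python. The sketch below is illustrative only: `results_dir` and `load_results` are hypothetical helpers, and the folder-name rule (the part of `model_name` after the last `/`, followed by the suffix) is inferred from the example path above, not taken from the tester's code. No assumption is made about the fields inside each JSON file.

```python
import json
from pathlib import Path

def results_dir(model_name, context_type, item_id, root="results/graph"):
    """Build the results folder path for one test configuration."""
    # Assumption: the tester keeps only the part of model_name after the last "/".
    short_name = model_name.split("/")[-1]
    return Path(root) / f"{short_name}_{context_type}_id_{item_id}"

def load_results(directory):
    """Load every *_results.json file in `directory`, keyed by file name."""
    return {
        path.name: json.loads(path.read_text())
        for path in sorted(Path(directory).glob("*_results.json"))
    }

print(results_dir("yaofu/llama-2-7b-80k", "relevant", 44).as_posix())
# prints results/graph/llama-2-7b-80k_relevant_id_44
```

After a test finishes, `load_results(results_dir("yaofu/llama-2-7b-80k", "relevant", 44))` would return one entry per length/depth file in that folder.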

Each test should take at most 5 minutes. The run sometimes falls back to the CPU instead of the GPU; if it is taking longer than that, restart the notebook and run it again.

In [5]:
rh.start_test()



Starting Needle In A Haystack Testing...
- Model: yaofu/llama-2-7b-80k
- Context Lengths: 10, Min: 0, Max: 5000
- Document Depths: 10, Min: 0%, Max: 100%
- Needle: There won't be a FIFA World Cup this year.



insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]
-- Test Summary -- 
Duration: 0.7 seconds
Context: 0 tokens
Depth: 0%
Score: 100.0
Response: There won't be a FIFA World Cup this year.

Writing at results/graph/llama-2-7b-80k_relevant_id_44/llama-2-7b-80k_relevant_id_44_len_0_depth_0_results.json
insertion at 0
There won't be a FIFA World Cup this year.
[['16-19'], ['7-4'], ['11-15'], ['6-9'], ['19-15'], ['6-30'], ['8-26'], ['11-2'], ['12-26'], ['17-22'], ['21-30'], ['28-14'], ['19-12'], ['21-4'], ['24-11'], ['24-29'], ['24-30'], ['25-3'], ['29-21'], ['29-26']]