### Check if GPU has free memory

In [None]:
import subprocess

def get_gpu_memory():
    """Print free and total memory for each GPU reported by ``nvidia-smi``.

    ``nvidia-smi`` is queried for ``memory.free`` and ``memory.total`` with
    ``nounits`` (values in MiB). Each value is divided by 1024 and printed
    in GiB, one line per GPU.

    If ``nvidia-smi`` is missing, exits with an error, or reports no GPUs, a
    short message is printed instead of raising. This keeps the diagnostic
    cell from aborting the notebook.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free,memory.total', '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True,
        )
    except FileNotFoundError:
        print("nvidia-smi not found; cannot query GPU memory.")
        return

    if result.returncode != 0:
        print(f"nvidia-smi failed with exit code {result.returncode}: {result.stderr.strip()}")
        return

    # Drop blank lines: empty output would otherwise reach int('') and raise.
    lines = [line for line in result.stdout.splitlines() if line.strip()]
    if not lines:
        print("nvidia-smi reported no GPUs.")
        return

    for i, line in enumerate(lines):
        free, total = map(int, line.split(','))
        # Fixed two decimals (the previous ': .4' meant 4 significant digits).
        print(f"GPU {i}: {free / 1024:.2f} GiB free / {total / 1024:.2f} GiB total")


get_gpu_memory()

## Setup

In [None]:
import os
import sys
from pathlib import Path

# Paths are resolved relative to the current working directory, so this
# notebook is presumably expected to be run from the project root.
PROJECT_DIR = Path.cwd()
HAYSTACK_DIR = PROJECT_DIR / "haystack"
RELEVANT_DIR = HAYSTACK_DIR / "relevant"
IRRELEVANT_DIR = HAYSTACK_DIR / "irrelevant"
# Backward-compatible alias for the original misspelled name, kept so any
# code still using it keeps working.
IRRELAVANT_DIR = IRRELEVANT_DIR

# Make project-local modules importable from this notebook.
# Guarded so that re-running the cell does not append duplicate entries.
if str(PROJECT_DIR) not in sys.path:
    sys.path.append(str(PROJECT_DIR))

# Hugging Face cache location (use an absolute path). Set this before any
# Hugging Face library is imported, since huggingface_hub reads its
# environment variables at import time.
os.environ['HF_HOME'] = '/cs/student/projects1/2021/aagarwal/SNLP_Project/.cache/hf_with_quota'

In [None]:
import json
from argparse import Namespace

def get_args(id, context_type, data):
    """Build the tester arguments for one needle-in-a-haystack test case.

    Args:
        id: Value of the ``"id"`` field identifying the entry in *data* to
            use. The name shadows the builtin ``id``; it is kept because
            callers pass it by keyword.
        context_type: ``"relevant"`` or ``"irrelevant"``. Selects the
            haystack directory and which of the entry's
            ``context_relevant`` / ``context_irrelevant`` file names is used.
        data: Iterable of needle entries (dicts). Each entry needs the keys
            ``id``, ``needle``, ``real_needle`` and ``question``, plus the
            ``context_<context_type>`` key.

    Returns:
        An ``argparse.Namespace`` holding the model settings, the
        context-length and depth sweep settings, and the entry's needle,
        real needle, question and haystack file path.

    Raises:
        ValueError: If *context_type* is not ``"relevant"`` or
            ``"irrelevant"``, or if no entry in *data* has the given *id*.
    """
    # Validate first: an unknown value used to leave `haystack_file`
    # unassigned and fail later with an UnboundLocalError.
    if context_type not in ("relevant", "irrelevant"):
        raise ValueError(
            f"context_type must be 'relevant' or 'irrelevant', got {context_type!r}"
        )

    # Default of None turns "not found" into a clear error rather than a
    # bare StopIteration.
    item = next((entry for entry in data if entry["id"] == id), None)
    if item is None:
        raise ValueError(f"No needle entry with id={id!r} in data")

    if context_type == "relevant":
        haystack_file = RELEVANT_DIR / item["context_relevant"]
    else:
        haystack_file = IRRELAVANT_DIR / item["context_irrelevant"]

    args = Namespace(
        model_name = "yaofu/llama-2-7b-80k",
        model_name_suffix = f"{context_type}_id_{item['id']}",
        model_provider = "LLaMA",

        # Sweep settings; their exact semantics are defined by the tester.
        context_lengths_min = 0,
        context_lengths_max = 5000,
        context_lengths_num_intervals = 10,

        document_depth_percent_intervals = 10,

        needle = item["needle"],
        real_needle = item["real_needle"],
        retrieval_question = item["question"],
        haystack_file = haystack_file,
    )

    return args

## Run Tests

In [None]:
from retrieval_head_detection import LLMNeedleHaystackTester as Tester

# Which needle entry (its "id" in needles.json) and which haystack to test.
NEEDLE_ID = 44
CONTEXT_TYPE = "relevant"  # "relevant" or "irrelevant"

# Read the JSON explicitly as UTF-8 so loading does not depend on the
# platform's default locale encoding.
with open(HAYSTACK_DIR / "needles.json", "r", encoding="utf-8") as f:
    data = json.load(f)

args = get_args(id=NEEDLE_ID, context_type=CONTEXT_TYPE, data=data)
# The Namespace's attributes are passed to the tester as keyword arguments.
tester = Tester(**vars(args))

tester.start_test()