In [22]:
import torch
import wandb
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

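# Authenticate with Hugging Face before running: `huggingface-cli login`.
# Create a read token at https://huggingface.co/settings/tokens.
# Gated models (e.g. meta-llama/Llama-3.1-8B-Instruct) additionally require requesting
# access on the model page; otherwise loading fails with a 401 "gated repo" error.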


# --- CONFIGURATION ---
# Central experiment settings: read by the evaluation loop and forwarded to
# wandb as the run config.
CONFIG = dict(
    # wandb project that collects every run of this notebook
    project_name="slm-long-context-nolima",
    # Hugging Face model to evaluate; replace with your SLM of choice.
    # Other candidates: "Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3-8b-Instruct"
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    # Approximate haystack sizes (in tokens) to sweep over
    context_lengths=[4096, 8192, 16000, 32000],
    # N samples per data point
    num_samples_per_length=50,
    # (Needle1_Depth, Needle2_Depth) pairs to test, as fractions of the haystack
    needle_depths=[(0.2, 0.8), (0.1, 0.9)],
)

In [23]:
# --- 1. DATA GENERATION: THE NoLiMa PROTOCOL ---
class NoLiMaDataset:
    """Generates two-needle, long-context QA samples.

    Each sample is a filler "haystack" with two fact sentences ("needles")
    spliced in at requested relative depths, plus a question whose answer
    requires chaining both facts.

    Args:
        rng: Optional ``numpy.random.Generator`` used to pick the scenario.
            When omitted, the global NumPy RNG is used (seed it with
            ``np.random.seed`` for reproducible runs).
    """

    def __init__(self, rng=None):
        self._rng = rng

        # Dictionary of (Question, Bridge_Fact, Final_Fact, Answer)
        # CRITICAL: The Question must NOT share words with the Bridge/Final Facts.
        self.scenarios = [
            {
                "qid": "tech_history_01",
                "question": "Which continent is the birth place of the creator of the Linux kernel located in?",
                # Bridge: Linux kernel -> Linus Torvalds (Implicit) -> Fact 1
                "needle_1": "Linus Torvalds was born in the city of Helsinki.",
                # Final: Helsinki -> Europe (Implicit) -> Fact 2
                "needle_2": "The city of Helsinki is the capital of Finland, a country in Northern Europe.",
                "answer": "Europe",
                "bridge_entity": "Helsinki",
                # NOTE: distractors are stored but not inserted by generate_sample.
                "distractors": [
                    "Steve Jobs was born in San Francisco.",
                    "Bill Gates founded Microsoft in Albuquerque.",
                    "The Python language was created by Guido van Rossum."
                ]
            },
            # Add more templates here for a real run
        ]

        # Load a "boring" background text (e.g., Paul Graham essays or public domain text)
        self.background_text = "The quick brown fox jumps over the lazy dog. " * 5000

    def _build_haystack(self, target_chars):
        """Return exactly ``target_chars`` characters of background text.

        The background is tiled when it is shorter than the request, so long
        context lengths are not silently truncated.
        """
        background = self.background_text
        if background and target_chars > len(background):
            repeats = -(-target_chars // len(background))  # ceiling division
            background = background * repeats
        return background[:target_chars]

    def generate_sample(self, context_len, depth_1_pct, depth_2_pct):
        """Build one sample with both needles inserted into the haystack.

        Args:
            context_len: Target haystack size in (approximate) tokens; the
                haystack is cut to ``context_len * 4`` characters.
            depth_1_pct: Relative position (0.0-1.0) of needle 1.
            depth_2_pct: Relative position (0.0-1.0) of needle 2.

        Returns:
            dict with keys ``context``, ``question``, ``answer``, ``n1_text``
            and ``n2_text``.

        Raises:
            ValueError: If ``context_len`` is negative or a depth lies outside
                ``[0, 1]``.
        """
        if context_len < 0:
            raise ValueError(f"context_len must be non-negative, got {context_len}")
        for label, depth in (("depth_1_pct", depth_1_pct), ("depth_2_pct", depth_2_pct)):
            # A negative depth would index from the end of the haystack and a
            # depth above 1 would silently clamp to the end.
            if not 0.0 <= depth <= 1.0:
                raise ValueError(f"{label} must be within [0, 1], got {depth}")

        if self._rng is None:
            pick = np.random.randint(len(self.scenarios))
        else:
            pick = self._rng.integers(len(self.scenarios))
        scenario = self.scenarios[int(pick)]

        # 1. Truncate (or tile) background to target length (approx tokens)
        target_chars = context_len * 4  # Rough approx: ~4 characters per token
        haystack = self._build_haystack(target_chars)

        # 2. Insert Needles
        # We slice the string to insert needles at specific % depths
        len_h = len(haystack)
        idx_1 = int(len_h * depth_1_pct)
        idx_2 = int(len_h * depth_2_pct)

        # Insert in positional order so the later index is still valid after
        # the first insertion (either needle may come first).
        first, second = sorted([(idx_1, scenario['needle_1']), (idx_2, scenario['needle_2'])])

        final_context = (
            haystack[:first[0]] +
            f" {first[1]} " +
            haystack[first[0]:second[0]] +
            f" {second[1]} " +
            haystack[second[0]:]
        )

        return {
            "context": final_context,
            "question": scenario['question'],
            "answer": scenario['answer'],
            "n1_text": scenario['needle_1'],
            "n2_text": scenario['needle_2']
        }

In [24]:
# --- 2. METRICS & SYSTEMS LOGGING ---
def calculate_attention_attribution(model_outputs, input_ids, needle_token_ids):
    """
    Surgical System Metric:
    Did the model ACTUALLY attend to the needle tokens, or did it guess?

    Measures how much last-layer attention (averaged over heads) the final
    prompt position puts on prompt tokens whose ids occur in the needles.

    Args:
        model_outputs: Object with an ``attentions`` attribute. Two layouts
            are accepted:
            * forward pass: tuple over layers of [Batch, Heads, Seq, Seq];
            * ``generate(..., return_dict_in_generate=True)``: tuple over
              generation steps, each a tuple over layers. The first step
              (the prompt prefill) is used, because its last query row is the
              final prompt token attending over the whole prompt.
        input_ids: Prompt token ids, shape [Batch, Seq]; only row 0 is used.
        needle_token_ids: 1-D tensor of token ids belonging to the needles.

    Returns:
        float: attention mass on the matched positions.

    Raises:
        ValueError: If ``model_outputs.attentions`` is missing or empty.

    NOTE(review): matching is by token id, not by position, so any prompt
    token sharing an id with a needle token (e.g. common words or
    punctuation in the haystack) is counted as well. Treat the score as an
    upper bound until span-based matching is added.
    """
    attentions = model_outputs.attentions
    if attentions is None or len(attentions) == 0:
        raise ValueError(
            "model_outputs.attentions is empty; run the model with output_attentions=True"
        )

    # generate() nests attentions as steps -> layers; keep the prefill step.
    if isinstance(attentions[0], (tuple, list)):
        attentions = attentions[0]

    # Last layer, first batch element, aggregated over heads.
    last_layer_attn = attentions[-1][0]                  # shape [Heads, Query, Key]
    # Average in float32 so low-precision dtypes do not lose small weights.
    avg_attn = last_layer_attn.float().mean(dim=0)       # shape [Query, Key]

    # Attention row of the last query position (the token about to be generated from).
    last_query_attn = avg_attn[-1]                       # shape [Key]

    prompt_ids = input_ids[0]
    needle_mask = torch.isin(prompt_ids, needle_token_ids.to(prompt_ids.device))

    # Align mask and attention row in case the key axis and the prompt differ
    # in length, and put the mask on the attention tensor's device (they can
    # differ when the model is sharded across devices).
    usable = min(needle_mask.shape[0], last_query_attn.shape[-1])
    needle_mask = needle_mask[:usable].to(last_query_attn.device)

    # Sum of attention mass falling on the needle tokens
    attribution_score = last_query_attn[:usable][needle_mask].sum().item()
    return attribution_score

In [25]:
# --- 3. MAIN EVALUATION LOOP ---
def run_experiment():
    """Run the context-length x needle-depth sweep and log results to wandb.

    For every (context length, depth pair) in CONFIG, generates
    ``num_samples_per_length`` samples, asks the model the question, scores a
    relaxed exact match, optionally measures attention mass on the needles,
    and logs per-sample and aggregate metrics.

    The attention metric is only computed when CONFIG contains
    ``"attn_implementation": "eager"``; the default ("sdpa") skips it.
    The wandb run is always closed, even when loading or inference fails.
    """
    # "flash_attention_2" # Faster, but requires an extra kernel install
    # "sdpa" # Pytorch implementation.
    # "eager" # standard (unfused) PyTorch implementation. It will be slower and use more memory, but it allows inspection.
    attn_impl = CONFIG.get("attn_implementation", "sdpa")
    # Only the eager path allows inspection of attention weights, so request
    # them only there (the fused kernels cannot provide them).
    # NOTE(review): eager attention holds a [Heads, Seq, Seq] matrix per layer,
    # which may not fit in memory at the longest context lengths - confirm.
    capture_attentions = attn_impl == "eager"

    wandb.init(project=CONFIG["project_name"], config=CONFIG)
    try:
        tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_id"])
        model = AutoModelForCausalLM.from_pretrained(
            CONFIG["model_id"],
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation=attn_impl,
        )

        dataset = NoLiMaDataset()

        print(f"Starting Grid Search: {CONFIG['context_lengths']} tokens")

        for ctx_len in CONFIG["context_lengths"]:
            for depth_pair in CONFIG["needle_depths"]:
                scores = []
                attributions = []  # only samples where attention was available

                for _ in range(CONFIG["num_samples_per_length"]):
                    # A. Prepare Data
                    sample = dataset.generate_sample(ctx_len, depth_pair[0], depth_pair[1])

                    # B. Tokenize
                    prompt = f"Context: {sample['context']}\n\nQuestion: {sample['question']}\nAnswer:"
                    # Follow the model's placement instead of assuming "cuda";
                    # device_map="auto" decides where the first layers live.
                    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

                    # Identify needle token IDs for attribution tracking
                    n1_ids = tokenizer(sample['n1_text'], add_special_tokens=False).input_ids
                    n2_ids = tokenizer(sample['n2_text'], add_special_tokens=False).input_ids
                    needle_ids = torch.tensor(n1_ids + n2_ids)

                    # C. Inference (Forward Pass + Gen)
                    with torch.no_grad():
                        outputs = model.generate(
                            **inputs,
                            max_new_tokens=20,
                            # Greedy decoding: the checkpoint's generation
                            # defaults may enable sampling, which would make
                            # accuracy vary between runs.
                            do_sample=False,
                            output_attentions=capture_attentions,
                            return_dict_in_generate=True,
                            use_cache=True
                        )

                    # D. Decode & Metric Calculation
                    prompt_len = inputs.input_ids.shape[1]
                    pred = tokenizer.decode(outputs.sequences[0][prompt_len:], skip_special_tokens=True)

                    # Metric 1: Exact Match (Relaxed)
                    is_correct = sample['answer'].lower() in pred.lower()

                    # Metric 2: Attention Mass on Needles (eager attention only)
                    attn_mass = None
                    if getattr(outputs, "attentions", None) is not None:
                        attn_mass = calculate_attention_attribution(outputs, inputs.input_ids, needle_ids)
                        attributions.append(attn_mass)

                    scores.append(is_correct)

                    # E. Step-wise Logging
                    step_log = {
                        "context_length": ctx_len,
                        "depth_1": depth_pair[0],
                        "depth_2": depth_pair[1],
                        "prediction": pred,
                        "correct": int(is_correct),
                        "example_prompt": prompt[:200] + "..." # Log snippet only
                    }
                    # Omit the metric when unavailable so a sentinel value
                    # never distorts charts or averages.
                    if attn_mass is not None:
                        step_log["attn_mass_on_needles"] = attn_mass
                    wandb.log(step_log)

                if not scores:
                    # num_samples_per_length == 0: nothing to aggregate.
                    continue

                # Aggregate per (context length, depth pair)
                avg_acc = sum(scores) / len(scores)
                avg_attn = sum(attributions) / len(attributions) if attributions else float("nan")

                print(f"[L={ctx_len}] Acc: {avg_acc:.2f} | Attn: {avg_attn:.4f}")
                summary_log = {
                    "summary/accuracy_vs_length": avg_acc,
                    "summary/depth_1": depth_pair[0],
                    "summary/depth_2": depth_pair[1],
                    "length_axis": ctx_len
                }
                if attributions:
                    summary_log["summary/attn_vs_length"] = avg_attn
                wandb.log(summary_log)
    finally:
        # Close the run even if model loading or inference raised.
        wandb.finish()

In [26]:
# Entry point: start the full sweep when this cell (or the file, as a script)
# is executed directly; nothing runs when the code is imported as a module.
# Loading the model requires Hugging Face authentication and access to the
# configured repo, otherwise it fails with the gated-repo OSError shown below.
if __name__ == "__main__":
    run_experiment()

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct.
401 Client Error. (Request ID: Root=1-69436bcd-676c32f12309dd083a88636d;383f7768-7d45-4787-a5b1-7197266495dc)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.1-8B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.