In [12]:
import torch

# Report whether PyTorch can see a CUDA device and, if so, a few details about it.
gpu_available = torch.cuda.is_available()
print("Is GPU available?:", gpu_available)

if not gpu_available:
    print("Using CPU only")
else:
    # Each detail is queried lazily, right before it is printed; the name shown is device 0's.
    gpu_details = (
        ("Number of GPUs:", torch.cuda.device_count),
        ("GPU Name:", lambda: torch.cuda.get_device_name(0)),
        ("Current device:", torch.cuda.current_device),
    )
    for label, query in gpu_details:
        print(label, query())

Is GPU available?: True
Number of GPUs: 2
GPU Name: NVIDIA GeForce RTX 3080 Ti
Current device: 0


## Section 1: Importing

### **Importing Libraries**

In [13]:
# Standard library
import ast
import html  # for escaping KB snippets before HTML display
import importlib.metadata as md # For version checking after install
import json
import os
import random
import re
import shutil
import tarfile # for .tar archives
import textwrap # for snippet preview
import time # for the delay before nvidia-smi
import traceback
import warnings # for non-critical warnings

# Third-party
import numpy as np
import torch
from IPython.display import display, HTML
import requests
# import torch # already imported
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# from typing import List, Optional # For type hinting (optional)
# from transformers import PreTrainedTokenizerBase # For type hinting (optional)
from datasets import load_dataset, DownloadMode
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

from tqdm.auto import tqdm

from rank_bm25 import BM25Okapi

print("Libraries correctly imported")

ImportError: C extension: pandas.compat._constants not built. If you want to import pandas from the source directory, you may need to run 'python setup.py build_ext' to build the C extensions first.

In [None]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # seeds all devices; trailing ';' hides the returned Generator repr

### **Importing Prompt Templates**

In [14]:
from prompts import (
    build_baseline_prompt_v1,
    build_rag_prompt_v1,
    build_baseline_prompt_v2,
    build_rag_prompt_v2,
    build_baseline_prompt_v3,
    build_rag_prompt_v3,
    build_baseline_prompt_v4,
    build_rag_prompt_v4,
    build_baseline_prompt_v5,
    build_rag_prompt_v5,
    truncate_to_n_tokens,
    build_baseline_prompt_v6,
    build_rag_prompt_v6,
    build_baseline_prompt_v6_2,
    build_rag_prompt_v6_2,
    build_baseline_prompt_v6_3,
    build_rag_prompt_v6_3,
    build_baseline_prompt_v7,
    build_rag_prompt_v7,
    build_baseline_prompt_v8,
    build_rag_prompt_v8,
    build_baseline_prompt_v9,
    build_rag_prompt_v9,)


### **Other Useful Imports**

In [3]:
from utils import _qcfg_to_dict , download_github_raw_json, robust_code_tokenizer_for_s5



## Section 2: LLM & Tokenizer Loading (4-bit Quantization) + Local Cache

### **LLM & Tokenizer Loading**

In [4]:
# Select exactly one model: keep a single `model_name = ...` line uncommented.

# --- Gemma Series (Google) ---
# model_name = "google/codegemma-7b"
# model_name = "google/codegemma-7b-it"

# --- Qwen Series (Alibaba) ---
# model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
# model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"
# model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

# --- Deepseek Coder Series (Deepseek AI) ---
model_name = "deepseek-ai/deepseek-coder-7b-instruct-v1.5"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# model_name = "deepseek-ai/deepseek-coder-1.3b-base"

# --- Code Llama Series (Meta) ---
# model_name = "codellama/CodeLlama-7b-Instruct-hf"

# --- Phi Series (Microsoft) ---
# model_name = "microsoft/Phi-4-mini-instruct"
# model_name = "microsoft/Phi-4-multimodal-instruct"

# Model-name prefixes for which trust_remote_code=True is passed to from_pretrained.
TRUST_REMOTE_CODE_MODELS = ["microsoft/Phi-", "Qwen/"]
trust_code = any(model_name.startswith(prefix) for prefix in TRUST_REMOTE_CODE_MODELS)
print(f"Setting trust_remote_code={trust_code} for {model_name}")
if trust_code:
    print("WARNING: trust_remote_code=True will execute Python code from the model's Hugging Face repo")

# Compute dtype for the 4-bit kernels: bfloat16 on bf16-capable GPUs, otherwise float16.
compute_dtype = (torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16)

# The loading step quantises to 4 bit only when CUDA is available; without CUDA it stores an
# unquantised float16 model. Name the cache directory after what is actually stored: otherwise
# a CPU/fp16 snapshot would share the "_4bit_nf4" directory and identical metadata, and a later
# GPU run would load it as if it were the 4-bit cache.
CACHE_ROOT = "./cache"  # project-local cache folder
_cache_variant = "_4bit_nf4" if torch.cuda.is_available() else "_fp16"
CACHE_DIR = os.path.join(CACHE_ROOT, model_name.replace("/", "_") + _cache_variant)
META_FILE = os.path.join(CACHE_DIR, "metadata.json")

# Show where the cache will be written.
print(f"Model cache will be saved in: {CACHE_DIR}")

# == 4-bit NF4 quantisation config (applied on GPU only) ===================
QUANT_CFG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

# Fingerprint written next to the cached files; the cache is reused only when it matches exactly.
REQ_META = {"model_name": model_name, "quant_cfg": _qcfg_to_dict(QUANT_CFG)}


Setting trust_remote_code=False for deepseek-ai/deepseek-coder-7b-instruct-v1.5
Model cache will be saved in: ./cache/deepseek-ai_deepseek-coder-7b-instruct-v1.5_4bit_nf4


### **Caching**

In [5]:
# == Check for an existing cache whose metadata matches this configuration ==
use_cache = False
if os.path.isfile(META_FILE):
    try:
        with open(META_FILE) as meta_fh:  # `with` closes the handle even if parsing fails
            saved = json.load(meta_fh)
        use_cache = saved == REQ_META
        print(" Cache metadata match:", use_cache)
    except (OSError, ValueError):  # unreadable file or invalid JSON
        print(" Could not parse metadata.json; ignoring cache.")

# == Load tokenizer & model (from cache, or from the Hub and then cache) ==
# Derive trust_remote_code from the same prefix list as the configuration cell instead of a
# second hard-coded copy that could drift out of sync.
trust_code = any(model_name.startswith(prefix) for prefix in TRUST_REMOTE_CODE_MODELS)

try:
    if use_cache:
        print(" Loading from cache…")
        tokenizer = AutoTokenizer.from_pretrained(CACHE_DIR, local_files_only=True, trust_remote_code=trust_code)
        model = AutoModelForCausalLM.from_pretrained(
            CACHE_DIR,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_code,
        )
    else:
        # 4-bit quantisation needs CUDA; otherwise fall back to an unquantised model.
        do_4bit = torch.cuda.is_available()
        print(f" No valid cache. CUDA available? {do_4bit}")
        print(f" {'Quantising 4-bit…' if do_4bit else 'Loading fp16…'} this will happen once")

        # ─── tokenizer ───────────────────────────────────────────
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_code)
        # Reuse EOS as the padding token when the tokenizer defines none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

        # ─── model ───────────────────────────────────────────────
        if do_4bit:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=QUANT_CFG,
                device_map="auto",
                trust_remote_code=trust_code,
                low_cpu_mem_usage=True,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=compute_dtype,
                device_map="auto",
                trust_remote_code=trust_code,
            )

        # ─── Save cache for next time ───────────────────────────
        print("Saving to cache…")
        os.makedirs(CACHE_DIR, exist_ok=True)
        # metadata.json is the "cache is valid" marker. Remove any stale copy before the files
        # are overwritten, and write the new one only after everything else is saved, so an
        # interrupted save can never look like a valid cache.
        if os.path.isfile(META_FILE):
            os.remove(META_FILE)
        tokenizer.save_pretrained(CACHE_DIR)
        model.save_pretrained(CACHE_DIR)
        with open(META_FILE, "w") as f:
            json.dump(REQ_META, f)
        print("Cache written at", CACHE_DIR)

    # Give the model config the tokenizer's pad token id if it has none of its own.
    if getattr(model, "config", None) and model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    print("Model & tokenizer ready!")

except Exception as exc:
    print("Error loading model/tokenizer:")
    traceback.print_exc()
    # Looked up defensively: torch.cuda.OutOfMemoryError does not exist in older PyTorch releases.
    if isinstance(exc, getattr(torch.cuda, "OutOfMemoryError", ())):
        print("Hint: the GPU ran out of memory. Check `nvidia-smi` for other processes holding "
              "GPU memory, or pass a `max_memory` mapping to `from_pretrained`.")
    raise
    

 No valid cache. CUDA available? True
 Quantising 4-bit… this will happen once


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Error loading model/tokenizer:


Traceback (most recent call last):
  File "/tmp/ipykernel_2258/2649707255.py", line 39, in <module>
    model = AutoModelForCausalLM.from_pretrained(
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3960, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4434, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 961, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 337, in set_module_tensor_to_device
    

OutOfMemoryError: CUDA out of memory. Tried to allocate 800.00 MiB. GPU  has a total capacity of 11.67 GiB of which 700.25 MiB is free. Process 160571 has 7.16 GiB memory in use. Process 315811 has 1.90 GiB memory in use. Process 320021 has 1.90 GiB memory in use. Of the allocated memory 1.44 GiB is allocated by PyTorch, and 181.49 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Section 3: Dataset Preparation and Validation



In [None]:
# Local directory where results and generated outputs are written (rename it freely).
save_path = './results/'

# Create it if needed; a failure is reported as a warning and does not stop the notebook.
try:
    os.makedirs(save_path, exist_ok=True)
except OSError as err:
    print(f"Warning: non posso creare o verificare la directory: {save_path}. Dettagli: {err}")
else:
    print(f"Directory disponibile per i risultati: {save_path}")


In [None]:
from datasets import load_dataset, DownloadMode

dataset_name = "JetBrains-Research/lca-library-based-code-generation"
data_split   = "test"

print(f"\n  Loading dataset '{dataset_name}' (split='{data_split}')…")

try:
    # Work with all libraries; later sections select samples by 'repo_full_name'.
    lca_dataset_all_libraries = load_dataset(
        dataset_name,
        split=data_split,
    )
    lca_dataset_split = lca_dataset_all_libraries
    print(f"Dataset loaded with {len(lca_dataset_all_libraries)} examples across all libraries.")

    # To work with only specific libraries for testing:
    # target_repos = ["seed-emulator", "another-repo"] # Examples
    # lca_dataset_split = lca_dataset_all_libraries.filter(
    #     lambda ex: ex["repo_name"] in target_repos
    # )
    # print(f"Filtered to {len(lca_dataset_split)} examples in {target_repos}")

except Exception as e:
    print(f" ERROR loading or filtering dataset: {e}")
    # Re-raise: every later section needs `lca_dataset_split`. Swallowing the error would either
    # fail later with a less informative NameError or silently reuse a dataset from an earlier run.
    raise


## Section 4: GitHub Knowledge Base Access Setup

This section configures access to pre-built Knowledge Bases (KBs) hosted on a GitHub repository.

It performs two main tasks:
1.  Relies on the helper function `download_github_raw_json` (imported from `utils`) to fetch KB JSON files from GitHub.
2.  Sets the GitHub repository parameters (username, repository name, branch, KB folder path) used to build the base URL for downloading KBs, and creates a local temporary directory for these downloads. A follow-up cell lists the available KB files through the GitHub API.

This setup enables subsequent sections to dynamically load specific KBs from the designated GitHub source.

In [None]:
# GitHub location of the pre-built knowledge bases.
GITHUB_USERNAME = "PatrizioAcquadro"
GITHUB_REPO_NAME = "RAG_Project_SE2"
GITHUB_BRANCH = "main"
GITHUB_KBS_FOLDER_PATH = "knowledge_bases_prod"

# Base URL for raw GitHub file contents; KB files are fetched as <base>/kb_<LIBRARY_KEY>.json.
GITHUB_RAW_CONTENT_BASE_URL = "/".join((
    "https://raw.githubusercontent.com",
    GITHUB_USERNAME,
    GITHUB_REPO_NAME,
    GITHUB_BRANCH,
    GITHUB_KBS_FOLDER_PATH,
))

# Local directory where downloaded KBs are stored temporarily.
LOCAL_TEMP_KB_DOWNLOAD_DIR = "./temp_downloaded_kbs"
os.makedirs(LOCAL_TEMP_KB_DOWNLOAD_DIR, exist_ok=True)

print(
    "--- RAG Knowledge Base Configuration (GitHub) ---",
    f"  KBs will be downloaded from GitHub base URL: {GITHUB_RAW_CONTENT_BASE_URL}/kb_LIBRARY_KEY.json",
    f"  Downloaded KBs will be temporarily stored in: {LOCAL_TEMP_KB_DOWNLOAD_DIR}",
    sep="\n",
)


In [None]:
import requests
import os
import json

# Reuse the repository settings from the GitHub configuration cell above (no duplicated
# constants to keep in sync); only the API URL is specific to this cell.
# The GitHub "contents" endpoint lists the files in a folder (capped at 1,000 entries).
GITHUB_API_URL = (
    f"https://api.github.com/repos/{GITHUB_USERNAME}/{GITHUB_REPO_NAME}"
    f"/contents/{GITHUB_KBS_FOLDER_PATH}?ref={GITHUB_BRANCH}"
)
HTTP_TIMEOUT_S = 30  # seconds; without a timeout `requests` can wait forever on a stalled connection

print("Fetching KB files from GitHub...")

try:
    response = requests.get(GITHUB_API_URL, timeout=HTTP_TIMEOUT_S)
    response.raise_for_status()
    listing = response.json()
    # A folder path yields a list of entries; anything else (e.g. a single-file object) is unexpected.
    if not isinstance(listing, list):
        raise ValueError(f"expected a directory listing, got {type(listing).__name__}")

    kb_files = [entry for entry in listing
                if entry.get("type") == "file" and entry.get("name", "").endswith(".json")]
    print(f"\nFound {len(kb_files)} KB files:\n")

    for entry in kb_files:
        kb_name = entry["name"]
        raw_url = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_name}"
        print(f"- {kb_name}")

        # Best effort per file: a failed download or invalid JSON is reported and skipped.
        try:
            kb_resp = requests.get(raw_url, timeout=HTTP_TIMEOUT_S)
            kb_resp.raise_for_status()
            kb_json = kb_resp.json()
        except (requests.RequestException, ValueError) as e:
            print(f"  Failed to load: {e}")
            continue

        if isinstance(kb_json, list):
            print(f"  Entries: {len(kb_json)} (list)")
        elif isinstance(kb_json, dict):
            print(f"  Top-level keys: {list(kb_json.keys())}")
        else:
            print("  Unknown KB format")

except Exception as e:
    print(f"Error fetching KB metadata from GitHub: {e}")


## Section 5: BM25 Retrieval Analysis

This section is dedicated to analyzing the BM25 retrieval process for a selected code generation sample. It is divided into two main parts:
1.  **Configuration:** Defining all parameters for the retrieval analysis (the sample to inspect, the BM25 settings, and the tokenization strategy).
2.  **Execution & Display:** Loading the relevant KB, performing BM25 retrieval based on the configurations, and displaying the results.

This allows for easy experimentation with different retrieval settings before full-scale evaluation.

It uses the pre-built KBs downloaded from the GitHub repository configured in Section 4.

In [None]:
# --- 1. Repo & Sample Selection ---
# Set the repo to a dataset `repo_full_name` to analyse the N-th sample of that library, or set
# it to None (or "") to use the index below as a position in the whole `lca_dataset_split`.
ANALYSIS_TARGET_REPO_FULL_NAME = "pyscf__pyscf"
ANALYSIS_SAMPLE_INDEX_WITHIN_REPO = 0  # 0-based; bounds-checked by the execution cell (5.2)

# --- 2. BM25 Algorithm & Retrieval Parameters ---
ANALYSIS_BM25_K1 = 1.5       # BM25 term-frequency saturation
ANALYSIS_BM25_B = 0.75       # BM25 document-length normalisation (0 = none, 1 = full)
ANALYSIS_TOP_K_SNIPPETS = 1  # number of KB snippets retrieved for the prompt
# --- 3. BM25 Tokenizer Selection ---
# Define tokenizer functions here. The selected one will be used by BM25 in Cell 5.2.
def robust_code_tokenizer_for_s5(text_input):
    """Split text into lowercase, identifier-like tokens for BM25.

    Non-string input yields ``[]``. The text is lowercased and split on every run of
    characters outside ``[a-z0-9_]``; empty pieces, single characters (mostly noise)
    and purely numeric pieces are discarded.

    NOTE: this definition shadows the function of the same name imported from
    ``utils`` in an earlier cell.
    """
    if not isinstance(text_input, str):
        return []
    kept = []
    for piece in re.split(r'[^a-z0-9_]+', text_input.lower()):
        if len(piece) < 2 or piece.isdigit():
            continue
        kept.append(piece)
    return kept

# Tokenizer used by BM25 for both the KB snippets and the query:
ANALYSIS_BM25_TOKENIZER = robust_code_tokenizer_for_s5

# --- 4. Display Options for Analysis Cell (Cell 5.2) ---
ANALYSIS_SHOW_QUERY_TOKENS = True   # print the tokenized instruction
ANALYSIS_HIGHLIGHT_KEYWORDS = True  # highlight query tokens inside the retrieved snippets

# Echo the effective configuration.
print("  Configuration for BM25 Retrieval Analysis (Section 5) is set:")

if not ANALYSIS_TARGET_REPO_FULL_NAME:
    print(f"    Sample Index from lca_dataset_split for Analysis: {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO}")
else:
    print(f"    Target Library for Analysis: '{ANALYSIS_TARGET_REPO_FULL_NAME}'")
    print(f"    Instruction Index within this library's samples: {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO}")

print(
    f"    BM25 Params: k1={ANALYSIS_BM25_K1}, b={ANALYSIS_BM25_B}",
    f"    Number of Snippets to Retrieve (Top-K): {ANALYSIS_TOP_K_SNIPPETS}",
    f"    Tokenizer for BM25: {ANALYSIS_BM25_TOKENIZER.__name__}",
    sep="\n",
)

In [None]:
# Cell 5.2: Retrieval Analysis - Execution and Display (GitHub KBs, Simplified top_k)
#
# Picks one dataset sample (per Cell 5.1), downloads that library's KB from GitHub, builds a
# BM25 index over its snippets, then retrieves and displays the top-k for the sample's instruction.
# Outputs for Section 6: `s5_instruction_to_s6` (str, or None on failure) and
# `s5_snippets_to_s6` (list of snippet strings, possibly empty).


def _highlight_query_tokens_html(snippet: str, tokens) -> str:
    """Render `snippet` as HTML-safe text, bolding whole-word, case-insensitive matches of `tokens`.

    All snippet text is HTML-escaped, so code containing `<`, `>` or `&` is shown verbatim
    instead of being parsed as markup. Matched words keep their original casing.
    """
    unique_tokens = sorted({tok for tok in tokens if tok}, key=len, reverse=True)
    if not unique_tokens:
        return html.escape(snippet)
    # A single alternation pass over the raw text: matches are found before escaping,
    # so entities such as "&lt;" can never be matched (and broken) by a token like "lt".
    pattern = re.compile(r"\b(?:" + "|".join(map(re.escape, unique_tokens)) + r")\b", re.IGNORECASE)
    pieces, last_end = [], 0
    for match in pattern.finditer(snippet):
        pieces.append(html.escape(snippet[last_end:match.start()]))
        pieces.append("<b style='background-color:#FFFACD; color:black; font-weight:bold;'>"
                      f"{html.escape(match.group(0))}</b>")
        last_end = match.end()
    pieces.append(html.escape(snippet[last_end:]))
    return "".join(pieces)


# --- 1. Ensure Configurations from Cell 5.1 & Globals are available ---
config_vars_s5_2_final_github = [
    'ANALYSIS_SAMPLE_INDEX_WITHIN_REPO', 'ANALYSIS_TARGET_REPO_FULL_NAME',
    'ANALYSIS_BM25_K1', 'ANALYSIS_BM25_B', 'ANALYSIS_TOP_K_SNIPPETS',
    'ANALYSIS_BM25_TOKENIZER', 'ANALYSIS_SHOW_QUERY_TOKENS', 'ANALYSIS_HIGHLIGHT_KEYWORDS'
]
if any(var_name not in globals() for var_name in config_vars_s5_2_final_github):
    raise NameError("One or more configuration variables from Cell 5.1 are not defined. Run Cell 5.1 first.")

global_vars_s5_2_final_github = ['lca_dataset_split', 'GITHUB_RAW_CONTENT_BASE_URL',
                                 'LOCAL_TEMP_KB_DOWNLOAD_DIR', 'download_github_raw_json']
if any(var_name not in globals() for var_name in global_vars_s5_2_final_github):
    raise NameError("One or more global prerequisite variables (dataset, GitHub config, download helper) are not defined.")

print(f"--- Section 5.2: Executing BM25 Retrieval Analysis (from GitHub KBs) ---")

# --- 2. Select Sample Data based on Configuration from Cell 5.1 ---
# Reset everything this cell reads later, so values from a previous run cannot leak in.
s5_instruction_to_s6 = None
s5_snippets_to_s6 = []          # output for Section 6
library_key_for_kb_file = None  # the sample's 'repo_full_name'
query_tokens = []
kb_data_from_git = None

try:
    if ANALYSIS_TARGET_REPO_FULL_NAME:
        library_key_for_kb_file = ANALYSIS_TARGET_REPO_FULL_NAME
        library_samples = lca_dataset_split.filter(lambda ex: ex['repo_full_name'] == ANALYSIS_TARGET_REPO_FULL_NAME)
        if not library_samples:
            raise ValueError(f"No samples found for specified library (repo_full_name): '{ANALYSIS_TARGET_REPO_FULL_NAME}'.")
        if not (0 <= ANALYSIS_SAMPLE_INDEX_WITHIN_REPO < len(library_samples)):
            raise IndexError(f"ANALYSIS_SAMPLE_INDEX_WITHIN_REPO {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO} is out of bounds for library '{ANALYSIS_TARGET_REPO_FULL_NAME}'.")
        target_sample_data = library_samples[ANALYSIS_SAMPLE_INDEX_WITHIN_REPO]
        print(f"  Analyzing instruction #{ANALYSIS_SAMPLE_INDEX_WITHIN_REPO} from library: '{library_key_for_kb_file}'")
    else:
        if not (0 <= ANALYSIS_SAMPLE_INDEX_WITHIN_REPO < len(lca_dataset_split)):
            raise IndexError(f"ANALYSIS_SAMPLE_INDEX_WITHIN_REPO {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO} is out of bounds for the full lca_dataset_split.")
        target_sample_data = lca_dataset_split[ANALYSIS_SAMPLE_INDEX_WITHIN_REPO]
        library_key_for_kb_file = target_sample_data.get('repo_full_name')  # derived from the sample
        print(f"  Analyzing sample at global index {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO}. Derived Library Key: '{library_key_for_kb_file}'")

    s5_instruction_to_s6 = target_sample_data.get('instruction')
    if not s5_instruction_to_s6 or not library_key_for_kb_file:
        raise ValueError("Selected sample missing 'instruction' or 'repo_full_name' could not be determined.")

    print(f"\n  Target Library Key for KB: '{library_key_for_kb_file}'")
    # textwrap.fill already indents every line, so no extra spaces follow the newline.
    print(f"  Instruction Text:\n{textwrap.fill(s5_instruction_to_s6, width=100, initial_indent='    ', subsequent_indent='    ')}")

    query_tokens = ANALYSIS_BM25_TOKENIZER(s5_instruction_to_s6)
    if ANALYSIS_SHOW_QUERY_TOKENS:
        print(f"\n  Query Tokens (using '{ANALYSIS_BM25_TOKENIZER.__name__}'):\n    {query_tokens}")

except Exception as e_sample_select:
    print(f"  ERROR during sample selection: {type(e_sample_select).__name__}: {e_sample_select}")
    # Clear partial results so neither this cell nor Section 6 works on a half-selected sample.
    library_key_for_kb_file = None
    s5_instruction_to_s6 = None
    query_tokens = []

# --- 3. Load KB from GitHub, Build Index, and Perform BM25 Retrieval ---
if library_key_for_kb_file:
    kb_filename_on_github = f"kb_{library_key_for_kb_file}.json"
    raw_kb_url_s5 = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_filename_on_github}"

    # Download into a per-library subfolder of LOCAL_TEMP_KB_DOWNLOAD_DIR.
    temp_save_subdir_for_kb_s5 = os.path.join(LOCAL_TEMP_KB_DOWNLOAD_DIR, library_key_for_kb_file)

    print(f"\n  Attempting to download/load KB for '{library_key_for_kb_file}' from GitHub...")
    kb_data_from_git = download_github_raw_json(
        raw_kb_url_s5,
        temp_save_subdir_for_kb_s5,
        kb_filename_on_github,
        # Always fetch a fresh copy for the analysis; presumably False reuses a previously
        # downloaded local file (see utils.download_github_raw_json).
        overwrite=True,
    )

    if not kb_data_from_git:
        print(f"  ERROR: Failed to download or parse KB for '{library_key_for_kb_file}' from GitHub.")
    elif isinstance(kb_data_from_git, dict):
        # Iterating a dict would index its keys as if they were snippets.
        print(f"  ERROR: KB for '{library_key_for_kb_file}' is a JSON object; expected a list of snippet strings.")
    else:
        print(f"  Successfully loaded KB for '{library_key_for_kb_file}' from GitHub ({len(kb_data_from_git)} snippets).")
        valid_kb_docs = [doc for doc in kb_data_from_git if isinstance(doc, str) and doc.strip()]

        if not valid_kb_docs:
            print("  No valid string snippets in loaded KB for BM25 indexing.")
        else:
            # Snippets that tokenize to nothing can never match a query, so they are not indexed;
            # map_idx_bm25_to_valid_docs maps each BM25 document index back into valid_kb_docs.
            final_bm25_corpus_docs, map_idx_bm25_to_valid_docs = [], []
            for i, tokens in enumerate(ANALYSIS_BM25_TOKENIZER(doc) for doc in valid_kb_docs):
                if tokens:
                    final_bm25_corpus_docs.append(tokens)
                    map_idx_bm25_to_valid_docs.append(i)

            if not final_bm25_corpus_docs:
                print("  Tokenized KB is empty after filtering. BM25 index not built.")
            else:
                print(f"  Creating BM25 index from {len(final_bm25_corpus_docs)} processable documents...")
                bm25_index = BM25Okapi(final_bm25_corpus_docs, k1=ANALYSIS_BM25_K1, b=ANALYSIS_BM25_B)
                print("  BM25 index built.")

                if query_tokens and ANALYSIS_TOP_K_SNIPPETS > 0:
                    print(f"\n  --- Retrieving and Displaying Top {ANALYSIS_TOP_K_SNIPPETS} Snippets ---")
                    num_docs = len(final_bm25_corpus_docs)
                    # Same ranking as BM25Okapi.get_top_n (descending argsort of get_scores),
                    # but the scores are kept so they can be shown next to each snippet.
                    doc_scores = bm25_index.get_scores(query_tokens)
                    top_indices = np.argsort(doc_scores)[::-1][:min(ANALYSIS_TOP_K_SNIPPETS, num_docs)]
                    s5_snippets_to_s6 = [valid_kb_docs[map_idx_bm25_to_valid_docs[i]] for i in top_indices]

                    if len(top_indices) and doc_scores[top_indices[0]] <= 0:
                        print("    WARNING: top BM25 score is <= 0 (no scoring term shared with the query); "
                              "the retrieved snippets are effectively arbitrary.")

                    if not s5_snippets_to_s6:
                        print(f"    No snippets retrieved for top_k = {ANALYSIS_TOP_K_SNIPPETS}.")
                    else:
                        for rank, (doc_idx, snip_content) in enumerate(zip(top_indices, s5_snippets_to_s6), start=1):
                            print(f"\n    Snippet {rank}/{len(s5_snippets_to_s6)} (Length: {len(snip_content)} chars, "
                                  f"BM25 score: {doc_scores[doc_idx]:.3f}):")
                            if ANALYSIS_HIGHLIGHT_KEYWORDS:
                                hl_content = _highlight_query_tokens_html(snip_content, query_tokens)
                                display(HTML(f"<pre style='white-space:pre-wrap; word-wrap:break-word; border:1px dashed #ccc; padding:6px; margin-left:20px;'>{hl_content}</pre>"))
                            else:
                                # Indent only: textwrap.fill would collapse the snippet's newlines and code layout.
                                print(textwrap.indent(snip_content, '      '))
                    print(f"\n    Stored {len(s5_snippets_to_s6)} snippets in 's5_snippets_to_s6' for Section 6.")
                elif not query_tokens:
                    print("  Query tokens empty. BM25 retrieval skipped.")
                else:
                    print(f"  ANALYSIS_TOP_K_SNIPPETS ({ANALYSIS_TOP_K_SNIPPETS}) is not positive. No snippets retrieved.")
else:  # library_key_for_kb_file is None because sample selection failed
    print("\n  Sample selection failed earlier. Skipping KB loading and BM25 retrieval.")

if not s5_snippets_to_s6 and library_key_for_kb_file and kb_data_from_git is not None:
    print("  INFO: No snippets were ultimately stored for Section 6 from this analysis.")

print(f"\n--- Section 5.2: Retrieval Analysis Execution Complete ---")
# `s5_instruction_to_s6` and `s5_snippets_to_s6` are now populated for Section 6.
# Variables `s5_instruction_to_s6` and `s5_snippets_to_s6` are now populated for Section 6.

## Section 6: RAG Prompt Assembly for Demo Sample

This section assembles the final RAG prompt for the LLM, using the instruction and retrieved snippets from the Section 5 analysis.

It first selects the desired `build_rag_prompt` function (imported from the `prompts` module) and checks its dependencies, such as the LLM `tokenizer` when it is needed for snippet truncation.

Then, it takes the instruction and retrieved context (snippets) provided by Section 5, formats the snippets into a text block, and calls the chosen `build_rag_prompt` function. The resulting complete prompt string is stored in `s6_final_rag_prompt_output` and a preview is displayed, making it ready for the LLM generation step in Section 7.

In [None]:
import textwrap

# --- 1. Ensure prerequisite variables from Section 5 (Cell 5.2) are defined ---
# At notebook top level locals() and globals() are the same mapping; globals() states the intent.
if 's5_instruction_to_s6' not in globals() or \
   's5_snippets_to_s6' not in globals():
    raise NameError("Variables 's5_instruction_to_s6' or 's5_snippets_to_s6' not found. "
                    "Ensure Section 5 (BM25 Retrieval Analysis - Execution and Display) has been run successfully.")

# --- 2. Select the RAG prompt builder (any build_rag_prompt_v* imported from `prompts`) ---
build_rag_prompt_to_use_in_s6 = build_rag_prompt_v6_3

# The assignment above already raises NameError for an undefined name; this rejects non-callables.
if not callable(build_rag_prompt_to_use_in_s6):
    raise NameError(f"The function assigned to 'build_rag_prompt_to_use_in_s6' is not defined or not callable. "
                    f"Please check its definition and assignment in this cell.")

# functools.partial objects and other callables may lack __name__; fall back to repr() in messages.
s6_builder_name = getattr(build_rag_prompt_to_use_in_s6, "__name__", repr(build_rag_prompt_to_use_in_s6))
print(f"INFO: Using '{s6_builder_name}' for RAG prompt assembly in this section.")


# --- 3. Check whether the chosen prompt builder depends on the LLM tokenizer ---
# Heuristic: if the builder's own bytecode references `truncate_to_n_tokens` (the helper that
# needs the LLM's tokenizer), assume the tokenizer is required. Only direct references are seen.
# A function resolves global names in the module that defined it (for the imported builders,
# `prompts`), so the helper is looked up in the builder's own __globals__, not this notebook's.
TOKENIZER_NEEDED_BY_SELECTED_PROMPT_BUILDER = False
builder_code = getattr(build_rag_prompt_to_use_in_s6, "__code__", None)
builder_globals = getattr(build_rag_prompt_to_use_in_s6, "__globals__", globals())
if builder_code is not None and "truncate_to_n_tokens" in builder_code.co_names:
    TOKENIZER_NEEDED_BY_SELECTED_PROMPT_BUILDER = True
    if "truncate_to_n_tokens" not in builder_globals:
        raise NameError(f"'truncate_to_n_tokens' function is called by '{s6_builder_name}' "
                        "but 'truncate_to_n_tokens' itself is not defined globally.")

if TOKENIZER_NEEDED_BY_SELECTED_PROMPT_BUILDER:
    # NOTE(review): this only confirms that Section 2 defined `tokenizer` in this notebook; how the
    # builder actually obtains a tokenizer is decided inside `prompts` — confirm there.
    if 'tokenizer' not in globals():
        raise NameError(f"LLM 'tokenizer' not defined globally, but it is required by the selected "
                        f"prompt builder '{s6_builder_name}' (likely for snippet truncation). "
                        "Please ensure the tokenizer is loaded in an earlier section (e.g., Section 2).")
    else:
        print(f"  INFO: Selected prompt builder '{s6_builder_name}' may use the global 'tokenizer'. Ensure 'tokenizer' is correctly loaded.")
else:
    print(f"  INFO: Selected prompt builder '{s6_builder_name}' does not appear to directly require the global 'tokenizer' for truncation via 'truncate_to_n_tokens'.")

In [None]:
s6_final_rag_prompt_output = None  # output of this section; stays None if assembly fails
PROMPT_PREVIEW_CHARS = 700         # length of the prompt preview printed below

if s5_instruction_to_s6 is None:
    print("  ERROR: No instruction ('s5_instruction_to_s6') available from Section 5. Cannot assemble prompt.")
else:
    print(f"  Using instruction: '{s5_instruction_to_s6[:100]}...'")
    print(f"  Using {len(s5_snippets_to_s6)} retrieved snippets from 's5_snippets_to_s6'.")

    # The builder is given the retrieved context as one string (presumably what it expects), so
    # the snippets are joined with a visible separator; joining an empty list yields "".
    retrieved_snippets_as_text_block = "\n\n# --- Snippet Separator ---\n\n".join(s5_snippets_to_s6)
    if not s5_snippets_to_s6:
        print("  INFO: No snippets were provided from Section 5 ('s5_snippets_to_s6' is empty). "
              "The RAG prompt will be assembled with an empty retrieved context.")

    # --- Build the Prompt ---
    try:
        s6_final_rag_prompt_output = build_rag_prompt_to_use_in_s6(
            s5_instruction_to_s6,
            retrieved_snippets_as_text_block
        )

        if s6_final_rag_prompt_output:
            print(f"\n  Final RAG Prompt Assembled (First {PROMPT_PREVIEW_CHARS} Characters):")
            # Slice instead of textwrap.shorten: shorten collapses every run of whitespace
            # (newlines, indentation) into one space, which mangles the code in the prompt.
            prompt_preview = s6_final_rag_prompt_output[:PROMPT_PREVIEW_CHARS]
            if len(s6_final_rag_prompt_output) > PROMPT_PREVIEW_CHARS:
                prompt_preview += "... (prompt truncated) ..."
            print(prompt_preview)
            # For the full prompt if needed for debugging:
            # print(s6_final_rag_prompt_output)
        else:
            print("  ERROR: Prompt assembly using your 'build_rag_prompt' function "
                  "resulted in an empty or None prompt. Please check the function's logic.")
    except Exception as e_build_prompt:
        builder_label = getattr(build_rag_prompt_to_use_in_s6, "__name__", repr(build_rag_prompt_to_use_in_s6))
        print(f"  ERROR during prompt assembly with '{builder_label}': {e_build_prompt}")
        s6_final_rag_prompt_output = None


# `s6_final_rag_prompt_output` now holds the complete RAG prompt string (or None).
if s6_final_rag_prompt_output:
    print(f"\n--- Section 6: RAG Prompt Assembly Complete. Output in 's6_final_rag_prompt_output' (Length: {len(s6_final_rag_prompt_output)} chars) ---")
else:
    print(f"\n--- Section 6: RAG Prompt Assembly Failed or Produced No Output ---")

## Section 7: LLM Generation for RAG Demo Prompt

This section generates code using the LLM for the demo RAG prompt assembled in Section 6.

It first **configures LLM generation settings**, including global inference parameters and a custom stopping criteria class (`EosAndCodeEndStoppingCriteria`) to manage output length. This setup is done once.

Subsequently, it **executes the generation**:
*   Takes the RAG prompt from Section 6.
*   Uses the pre-set configurations to call `model.generate()`.
*   Decodes, cleans, and displays the LLM-generated code, storing the result in `s7_generated_code_rag_demo`.

In [None]:
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
import time
import re
import textwrap

def generate_llm_code_and_clean(
    prompt_text: str,
    llm_model,
    llm_tokenizer,
    max_new_tokens_gen,
    do_sample_gen,
    temperature_gen,
    top_p_gen,
    top_k_gen,
    repetition_penalty_gen,
    stopping_criteria_list_gen, # Can be None
    prompt_name: str = "Prompt" # For logging
    ) -> str | None:
    cleaned_generated_code = None
    print(f"  Starting LLM generation for: {prompt_name}")
    try:
        inputs = llm_tokenizer(
            prompt_text, return_tensors="pt"
        ).to(llm_model.device)
        prompt_len = inputs['input_ids'].shape[1]
        # print(f"    Tokenized prompt length: {prompt_len} tokens.") # Optional

        gen_args = {
            "input_ids": inputs['input_ids'], "attention_mask": inputs['attention_mask'],
            "max_new_tokens": max_new_tokens_gen, "pad_token_id": llm_tokenizer.eos_token_id,
            "repetition_penalty": repetition_penalty_gen, "stopping_criteria": stopping_criteria_list_gen
        }
        if do_sample_gen:
            gen_args.update({
                "temperature": temperature_gen, "top_p": top_p_gen,
                "top_k": top_k_gen, "do_sample": True
            })
        else:
            gen_args["do_sample"] = False

        gen_start = time.time()
        with torch.no_grad(): output_ids = llm_model.generate(**gen_args)
        print(f"    LLM generation for '{prompt_name}' finished in {time.time() - gen_start:.2f}s.")

        generated_ids_part = output_ids[0, prompt_len:]
        raw_output = llm_tokenizer.decode(generated_ids_part, skip_special_tokens=True)
        # print(f"\n    Raw LLM Output for '{prompt_name}' (first 300 chars):\n{textwrap.shorten(raw_output, 300, placeholder='...')}") # Optional

        # Using your preferred cleaning logic (can be the external function if defined)
        # For simplicity, embedding it here. If you defined `extract_code_from_llm_output` earlier, call that.
        match = re.search(r"```python\n(.*?)(?:\n```|\Z)", raw_output, re.DOTALL | re.IGNORECASE)
        if match: cleaned_generated_code = match.group(1).strip()
        elif "\n```" in raw_output: cleaned_generated_code = raw_output.split("\n```")[0].strip()
        else: cleaned_generated_code = raw_output.strip()

        # print(f"    Cleaned code for '{prompt_name}' (first 300 chars):\n{textwrap.shorten(cleaned_generated_code, 300, placeholder='...') if cleaned_generated_code else '[No code extracted]'}")

    except torch.cuda.OutOfMemoryError as e: cleaned_generated_code = None; print(f"  ❌ OOM ERROR during '{prompt_name}' generation: {e}")
    except Exception as e: cleaned_generated_code = None; print(f"  ❌ ERROR during '{prompt_name}' generation: {type(e).__name__}: {e}")

    return cleaned_generated_code

In [None]:
# --- 1. Global LLM generation parameters ---
# A value already defined by an earlier cell is kept; otherwise the default below is used.
# STOP_ON_EOS / STOP_ON_CODE_END feed the custom stopping criteria created further down.
for _param_name, _param_default in (
    ('MAX_NEW_TOKENS', 384),
    ('TEMPERATURE', 0.6),
    ('TOP_P', 0.95),
    ('TOP_K', 50),
    ('REPETITION_PENALTY', 1.1),
    ('DO_SAMPLE', True),
    ('STOP_ON_EOS', True),
    ('STOP_ON_CODE_END', True),
):
    globals().setdefault(_param_name, _param_default)
del _param_name, _param_default

print("\n".join([
    "  Global LLM Generation Parameters now set/confirmed:",
    f"    MAX_NEW_TOKENS={MAX_NEW_TOKENS}, TEMPERATURE={TEMPERATURE if DO_SAMPLE else 'N/A (Greedy)'}",
    f"    TOP_P={TOP_P if DO_SAMPLE else 'N/A'}, TOP_K={TOP_K if DO_SAMPLE else 'N/A'}",
    f"    REPETITION_PENALTY={REPETITION_PENALTY}, DO_SAMPLE={DO_SAMPLE}",
    f"    Custom Stopping: STOP_ON_EOS={STOP_ON_EOS}, STOP_ON_CODE_END={STOP_ON_CODE_END}",
]))

# --- 2. Define Custom Stopping Criteria Class ---
# Defined unconditionally: a "define only if not already in globals()" guard would keep a
# stale class alive after this cell is edited and re-run.
class EosAndCodeEndStoppingCriteria(StoppingCriteria):
    """Stops generation on the EOS token or once a code-ending text sequence is produced.

    Only the first row of the batch (``current_ids[0]``) is inspected, so this assumes
    single-prompt (batch size 1) generation.

    Args:
        tokenizer_instance: tokenizer used to encode the stop sequence and to decode the
            most recent generated ids.
        stop_on_eos_token: if True, stop when the newest token equals the tokenizer's
            ``eos_token_id``.
        code_end_sequence: text that ends generation (by default a newline, three
            backticks, and a newline). ``None`` or ``""`` disables this check.
    """

    def __init__(self, tokenizer_instance, stop_on_eos_token=True, code_end_sequence="\n```\n"):
        self.tokenizer = tokenizer_instance
        self.stop_on_eos = stop_on_eos_token
        self.code_end_sequence_str = code_end_sequence
        self.code_end_sequence_ids = []
        # Stop-sequence ids as a tensor, built once per device instead of on every step.
        self._code_end_ids_by_device = {}
        if self.code_end_sequence_str and self.tokenizer:  # Ensure tokenizer is valid
            try:
                self.code_end_sequence_ids = self.tokenizer.encode(
                    self.code_end_sequence_str,
                    add_special_tokens=False
                )
            except Exception as e_encode:
                print(f"    WARNING: Failed to encode stop sequence '{self.code_end_sequence_str}': {e_encode}")
                self.code_end_sequence_ids = []  # Ensure it's empty on failure

    def _code_end_ids_on(self, device):
        """Return the stop-sequence ids as a tensor on `device`, cached per device."""
        ids_tensor = self._code_end_ids_by_device.get(device)
        if ids_tensor is None:
            ids_tensor = torch.tensor(self.code_end_sequence_ids, device=device)
            self._code_end_ids_by_device[device] = ids_tensor
        return ids_tensor

    def _tail_decodes_to_code_end(self, current_ids):
        """True if the newest ids of row 0 decode to text ending with the stop sequence."""
        # The same text can tokenize differently inside a sequence than on its own (e.g. a
        # merged "```\n" token), so the exact id comparison can miss a real closing fence.
        # Decoding a short window of the newest ids catches those cases.
        window = len(self.code_end_sequence_ids) + 2
        try:
            tail_text = self.tokenizer.decode(current_ids[0, -window:], skip_special_tokens=True)
        except Exception:
            return False
        return tail_text.endswith(self.code_end_sequence_str)

    def __call__(self, current_ids: torch.LongTensor, current_scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True once generation should stop (EOS or code-end sequence in row 0)."""
        eos_id = getattr(self.tokenizer, 'eos_token_id', None) if self.tokenizer else None
        if self.stop_on_eos and eos_id is not None and current_ids[0, -1] == eos_id:
            return True
        if self.code_end_sequence_ids:  # Only when the stop sequence was encoded successfully
            seq_len = len(self.code_end_sequence_ids)
            if current_ids.shape[1] >= seq_len and torch.equal(
                current_ids[0, -seq_len:], self._code_end_ids_on(current_ids.device)
            ):
                return True
            if self._tail_decodes_to_code_end(current_ids):
                return True
        return False


print("  Custom 'EosAndCodeEndStoppingCriteria' class defined.")


# --- 3. Instantiate Global Stopping Criteria List ---
# Shared by every generation call that wants EOS / code-end stopping. It stays None when the
# tokenizer is unavailable, both flags are off, or construction fails.
llm_stopping_criteria_global = None
if globals().get('tokenizer') is None:
    print("  WARNING (Cell 7.1): Global 'tokenizer' not found or is None. "
          "Custom stopping criteria cannot be created. LLM will use default stopping.")
elif not (STOP_ON_EOS or STOP_ON_CODE_END):
    print("  Global custom stopping criteria (EOS/code end) are disabled by configuration flags (STOP_ON_EOS/STOP_ON_CODE_END).")
else:
    try:
        if 'EosAndCodeEndStoppingCriteria' not in globals():
            raise NameError("EosAndCodeEndStoppingCriteria class not defined prior to instantiation.")
        eos_code_ender_criteria_instance = EosAndCodeEndStoppingCriteria(
            tokenizer,
            stop_on_eos_token=STOP_ON_EOS,
            code_end_sequence=("\n```\n" if STOP_ON_CODE_END else None),
        )
        llm_stopping_criteria_global = StoppingCriteriaList([eos_code_ender_criteria_instance])
        print("  Global 'llm_stopping_criteria_global' (for EOS/code end) created successfully.")
    except Exception as e_stop_crit_create:
        print(f"  WARNING (Cell 7.1): Failed to create global custom stopping criteria: {type(e_stop_crit_create).__name__}: {e_stop_crit_create}")
        llm_stopping_criteria_global = None

In [None]:
# --- 1. Verify inputs produced by earlier sections ---
if not globals().get('s6_final_rag_prompt_output'):
    raise NameError("Input prompt 's6_final_rag_prompt_output' from Section 6 not found. Run Section 6 first.")
if 'model' not in globals() or 'tokenizer' not in globals():
    raise NameError("LLM 'model' or 'tokenizer' not defined. Ensure Section 2 has run.")
if 'generate_llm_code_and_clean' not in globals():  # Helper defined before 7.1
    raise NameError("Helper function 'generate_llm_code_and_clean' not defined. Ensure it's defined before Section 7.1.")

# Global generation params and stopping criteria from Cell 7.1
# (llm_stopping_criteria_global must exist but may be None).
required_globals_for_7_2 = [
    'MAX_NEW_TOKENS', 'DO_SAMPLE', 'TEMPERATURE', 'TOP_P', 'TOP_K',
    'REPETITION_PENALTY', 'llm_stopping_criteria_global'
]
missing_globals_for_7_2 = [p for p in required_globals_for_7_2 if p not in globals()]
if missing_globals_for_7_2:
    raise NameError(f"One or more global LLM generation parameters/criteria from Cell 7.1 are missing: {missing_globals_for_7_2}.")

print(f"  Using RAG prompt (length: {len(s6_final_rag_prompt_output)} chars) from Section 6.")

# --- 2. Call Reusable LLM Generation Function ---
s7_generated_code_rag_demo = generate_llm_code_and_clean(
    prompt_text=s6_final_rag_prompt_output,
    llm_model=model,
    llm_tokenizer=tokenizer,
    max_new_tokens_gen=MAX_NEW_TOKENS,
    do_sample_gen=DO_SAMPLE,
    temperature_gen=TEMPERATURE,
    top_p_gen=TOP_P,
    top_k_gen=TOP_K,
    repetition_penalty_gen=REPETITION_PENALTY,
    stopping_criteria_list_gen=llm_stopping_criteria_global,  # may be None
    prompt_name="RAG Demo Prompt (S7.2)",
)

# --- 3. Display Result ---
# The helper returns None on error and a (possibly empty) string otherwise.
if s7_generated_code_rag_demo is not None:
    print("\n  --- Cleaned Generated RAG Code (Demo) ---")
    # sep="\n" keeps the first code line flush left; the default sep would prefix a space.
    print("\n  Full Cleaned Generated RAG Code (Demo):", s7_generated_code_rag_demo, sep="\n")
    print("\n--- RAG Demo Generation Complete. Result in 's7_generated_code_rag_demo'. ---")
else:
    print("\n--- RAG Demo Generation Failed or Produced No Valid Code Output. 's7_generated_code_rag_demo' is None. ---")

## Section 8: Baseline LLM Code Generation

This section generates code using the LLM for the **baseline prompt** corresponding to the same demo sample instruction analyzed in Section 5 (`s5_instruction_to_s6`). It does **not** use any retrieved RAG context.

**Workflow:**
1.  **Prompt Builder Selection:** Chooses the appropriate `build_baseline_prompt_vX` function.
2.  **Baseline Prompt Construction:** Creates the baseline prompt using the demo instruction.
3.  **LLM Generation:** Calls the reusable `generate_llm_code_and_clean` helper function with the baseline prompt. It uses the same global LLM generation parameters and stopping criteria as defined in Section 7.1 for consistency in comparing RAG vs. Baseline outputs.
4.  **Output:** Displays the cleaned baseline-generated code and stores it in `s8_generated_code_baseline_demo`.

This allows a direct comparison between the RAG-augmented output (from Section 7) and the LLM's output when it is given only the instruction.

In [None]:
# --- 1. Ensure prerequisite variables and functions are defined ---
# From Section 5:
if not globals().get('s5_instruction_to_s6'):
    raise NameError("Instruction 's5_instruction_to_s6' from Section 5 not found. Run Section 5 first.")
# From earlier sections (or Section 7.1):
if 'model' not in globals() or 'tokenizer' not in globals():
    raise NameError("LLM 'model' or 'tokenizer' not defined.")
if 'generate_llm_code_and_clean' not in globals():
    raise NameError("Helper function 'generate_llm_code_and_clean' not defined.")
# Same generation settings as Section 7, so RAG and baseline outputs are comparable.
required_gen_params_s8 = ['MAX_NEW_TOKENS', 'REPETITION_PENALTY', 'DO_SAMPLE', 'TEMPERATURE', 'TOP_P', 'TOP_K', 'llm_stopping_criteria_global']
missing_gen_params_s8 = [p for p in required_gen_params_s8 if p not in globals()]
if missing_gen_params_s8:
    raise NameError(f"One or more global LLM generation parameters/criteria needed for baseline are missing: {missing_gen_params_s8}. Run Section 7.1.")

# --- 2. Select the Baseline Prompt Builder Function ---
# The first callable name wins. Names are looked up through globals() so that a missing
# preferred builder falls through to the alternatives. Referencing an undefined name
# directly would raise NameError before any fallback could run.
BASELINE_PROMPT_BUILDER_PREFERENCE = (
    'build_baseline_prompt_v6_3',  # preferred
    'build_baseline_prompt_v1',    # generic fallback
    'build_baseline_prompt_v6',
)
CHOSEN_BASELINE_PROMPT_BUILDER = next(
    (globals()[name] for name in BASELINE_PROMPT_BUILDER_PREFERENCE if callable(globals().get(name))),
    None,
)
if CHOSEN_BASELINE_PROMPT_BUILDER is None:
    raise NameError("A suitable 'build_baseline_prompt_vX' or 'build_baseline_prompt_v1' function is not defined.")
print(f"INFO: Using Baseline Prompt Builder: {CHOSEN_BASELINE_PROMPT_BUILDER.__name__}")

# --- 3. Construct Baseline Prompt (instruction only, no retrieved context) ---
s8_generated_code_baseline_demo = None  # stays None if prompt construction raises
baseline_prompt_s8 = CHOSEN_BASELINE_PROMPT_BUILDER(s5_instruction_to_s6)

# --- 4. Call Reusable LLM Generation Function ---
s8_generated_code_baseline_demo = generate_llm_code_and_clean(
    prompt_text=baseline_prompt_s8,
    llm_model=model,
    llm_tokenizer=tokenizer,
    max_new_tokens_gen=MAX_NEW_TOKENS,
    do_sample_gen=DO_SAMPLE,
    temperature_gen=TEMPERATURE,
    top_p_gen=TOP_P,
    top_k_gen=TOP_K,
    repetition_penalty_gen=REPETITION_PENALTY,
    stopping_criteria_list_gen=llm_stopping_criteria_global,  # from Sec 7.1, may be None
    prompt_name="Baseline Demo Prompt",
)

if s8_generated_code_baseline_demo is not None:
    print("\n  --- Cleaned Generated Baseline Code (Demo) ---")
    # sep="\n" keeps the first code line flush left; the default sep would prefix a space.
    print("\n  Full Cleaned Generated Baseline Code (Demo):", s8_generated_code_baseline_demo, sep="\n")
    print("\n--- Baseline Demo Generation Complete. Result in 's8_generated_code_baseline_demo'. ---")
else:
    print("\n--- Baseline Demo Generation Failed or No Output. ---")

## Section 9 · Metrics Results


In [None]:
import re # For the tokenizer if defined here again

# --- 1. Number of Samples for Evaluation ---
# To evaluate the whole split instead: NUM_EXAMPLES_TO_EVALUATE = len(lca_dataset_split)
NUM_EXAMPLES_TO_EVALUATE = 3  # small subset for a quick test run

# --- 2. Prompt Builders Used for Evaluation ---
EVAL_RAG_PROMPT_BUILDER = build_rag_prompt_v6
EVAL_BASELINE_PROMPT_BUILDER = build_baseline_prompt_v6

for _label, _builder in (("RAG", EVAL_RAG_PROMPT_BUILDER), ("Baseline", EVAL_BASELINE_PROMPT_BUILDER)):
    print(f"  Using {_label} Prompt Builder for Evaluation: {_builder.__name__}")
del _label, _builder

# --- 3. BM25 Parameters for Evaluation (passed as BM25Okapi's k1 / b) ---
EVAL_BM25_K1 = 1.5
EVAL_BM25_B = 0.75
EVAL_TOP_K_SNIPPETS_FOR_PROMPT = 3  # number of retrieved snippets inserted into each RAG prompt

# BM25 tokenizer used for evaluation (both KB documents and queries).
def eval_robust_code_tokenizer_s9(text_input: str) -> list[str]:
    """Tokenize text into lowercase, identifier-like tokens for BM25.

    The text is lowercased and split on every character outside ``[a-z0-9_.]``.
    Each chunk is then split on ``.`` so dotted paths contribute their parts.
    A piece is kept if it is longer than one character and not purely digits, or if
    it is a single letter. Non-string input yields an empty list.
    """
    if not isinstance(text_input, str):
        return []

    def keep(piece: str) -> bool:
        if len(piece) == 1:
            return piece.isalpha()
        return len(piece) > 1 and not piece.isdigit()

    return [
        piece
        for chunk in re.split(r'[^a-z0-9_.]+', text_input.lower())
        for piece in chunk.split('.')
        if keep(piece)
    ]

EVAL_BM25_TOKENIZER = eval_robust_code_tokenizer_s9

print(f"  BM25 Params for Evaluation: K1={EVAL_BM25_K1}, B={EVAL_BM25_B}, Top-K for prompt={EVAL_TOP_K_SNIPPETS_FOR_PROMPT}")
print(f"  BM25 Tokenizer for Evaluation: {EVAL_BM25_TOKENIZER.__name__}")


# --- 4. LLM Generation Parameters for Evaluation ---
# Kept separate from the Section 7 settings so the evaluation run can be tuned on its own.
EVAL_MAX_NEW_TOKENS = 524
EVAL_TEMPERATURE = 0.5
EVAL_TOP_P, EVAL_TOP_K = 0.95, 50
EVAL_REPETITION_PENALTY = 1.1
EVAL_DO_SAMPLE = True  # False -> greedy decoding (temperature / top-p / top-k then unused)

# Stopping-criteria switches for the evaluation run.
EVAL_STOP_ON_EOS, EVAL_STOP_ON_CODE_END = True, True

eval_llm_stopping_criteria = None
if globals().get('tokenizer') is None:
    print("  WARNING: Global 'tokenizer' not found. Custom stopping criteria for evaluation cannot be created.")
elif 'EosAndCodeEndStoppingCriteria' not in globals():
    print("  WARNING: 'EosAndCodeEndStoppingCriteria' class not defined (expected from Sec 7.1). Cannot create custom stopping criteria.")
elif not (EVAL_STOP_ON_EOS or EVAL_STOP_ON_CODE_END):
    print("  Custom stopping criteria (EOS/code end) for evaluation are disabled by EVAL_ flags.")
else:
    try:
        eval_stopper_instance = EosAndCodeEndStoppingCriteria(
            tokenizer,
            stop_on_eos_token=EVAL_STOP_ON_EOS,
            code_end_sequence=("\n```\n" if EVAL_STOP_ON_CODE_END else None),
        )
        eval_llm_stopping_criteria = StoppingCriteriaList([eval_stopper_instance])
        print("  Custom 'eval_llm_stopping_criteria' for evaluation run created successfully.")
    except Exception as e:
        print(f"  WARNING: Failed to create custom stopping criteria for evaluation: {e}")

print(f"  LLM Config: MaxNew={EVAL_MAX_NEW_TOKENS}, Temp={EVAL_TEMPERATURE if EVAL_DO_SAMPLE else 'N/A'}, DoSample={EVAL_DO_SAMPLE}, StopCriteriaIsSet={'Yes' if eval_llm_stopping_criteria else 'No'}")


# --- 5. Verify other critical dependencies for Cell 9.2 ---
critical_deps_for_9_2 = [
    'model', 'generate_llm_code_and_clean',
    'GITHUB_RAW_CONTENT_BASE_URL', 'LOCAL_TEMP_KB_DOWNLOAD_DIR', 'download_github_raw_json'
]
# Report only the names that are actually missing, so the fix is obvious.
missing_deps_for_9_2 = [dep for dep in critical_deps_for_9_2 if dep not in globals()]
if missing_deps_for_9_2:
    raise NameError(f"One or more critical global dependencies for Cell 9.2 are missing: {missing_deps_for_9_2}. "
                    "Ensure all prior setup sections (LLM loading, GitHub config, helper functions) have run.")

In [None]:
from rank_bm25 import BM25Okapi
from tqdm.auto import tqdm
import json
import os

# --- 1. Verify Essential Configurations & Dependencies (Quick Check) ---
required_globals_for_eval = [
    'lca_dataset_split', 'EVAL_RAG_PROMPT_BUILDER', 'EVAL_BASELINE_PROMPT_BUILDER',
    'EVAL_BM25_TOKENIZER', 'EVAL_BM25_K1', 'EVAL_BM25_B', 'EVAL_TOP_K_SNIPPETS_FOR_PROMPT',
    'GITHUB_RAW_CONTENT_BASE_URL', 'LOCAL_TEMP_KB_DOWNLOAD_DIR', 'download_github_raw_json',
    'model', 'tokenizer', 'generate_llm_code_and_clean', 'NUM_EXAMPLES_TO_EVALUATE',
    'EVAL_MAX_NEW_TOKENS', 'EVAL_TEMPERATURE', 'EVAL_TOP_P', 'EVAL_TOP_K',
    'EVAL_REPETITION_PENALTY', 'EVAL_DO_SAMPLE', 'eval_llm_stopping_criteria'  # Can be None
]
# Name the missing entries instead of only saying "one or more".
missing_globals_for_eval = [var for var in required_globals_for_eval if var not in globals()]
if missing_globals_for_eval:
    raise NameError("One or more critical configurations or dependencies for the evaluation loop are missing: "
                    f"{missing_globals_for_eval}. "
                    "Ensure Cell 9.1 and all preceding setup cells (dataset, model, helpers) have been run.")


print(f"\n--- Starting Main Evaluation Loop, targeting {NUM_EXAMPLES_TO_EVALUATE} samples. ---")

# --- 2. Per-sample output lists (kept index-aligned) & per-repository BM25 cache ---
eval_all_baseline_outputs = []
eval_all_rag_outputs = []
eval_all_references = []
eval_all_reference_apis = []
eval_bm25_index_cache = {}  # repo_key -> cached BM25 entry, filled by the evaluation loop

# --- 3. Evaluation subset: the first NUM_EXAMPLES_TO_EVALUATE rows (fewer if the split is smaller) ---
dataset_for_eval = lca_dataset_split.select(range(min(NUM_EXAMPLES_TO_EVALUATE, len(lca_dataset_split))))
num_samples_to_process = len(dataset_for_eval)

# --- 4. Main Evaluation Loop ---
def _eval_get_bm25_entry(repo_key):
    """Return ``(bm25_index, valid_docs, map_to_valid)`` for one repository's knowledge base.

    On first use the KB JSON ``kb_<repo_key>.json`` is fetched with
    ``download_github_raw_json`` and indexed with ``BM25Okapi``. Successful builds are
    stored in ``eval_bm25_index_cache`` as
    ``(bm25_index, valid_docs, map_to_valid, tokenized_corpus)``.
    ``map_to_valid[i]`` is the position in ``valid_docs`` of the i-th indexed document,
    because documents that tokenize to nothing are not indexed.
    If the KB is missing, is not a list, or yields no tokens, ``bm25_index`` is None and
    nothing is cached, so a later sample from the same repository retries.
    """
    cached_entry = eval_bm25_index_cache.get(repo_key)
    if cached_entry is not None:
        # Slice rather than unpack directly: cache entries have four elements, so a
        # three-name unpack raises ValueError on every cache hit.
        bm25_idx, valid_docs, map_to_valid = cached_entry[:3]
        return bm25_idx, valid_docs, map_to_valid

    kb_file = f"kb_{repo_key}.json"
    kb_url = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_file}"
    temp_kb_dir = os.path.join(LOCAL_TEMP_KB_DOWNLOAD_DIR, f"eval_kb_{repo_key}")
    kb_json = download_github_raw_json(kb_url, temp_kb_dir, kb_file, overwrite=False)
    if not isinstance(kb_json, list):
        return None, [], []

    valid_docs = [doc for doc in kb_json if isinstance(doc, str) and doc.strip()]
    tokenized_corpus, map_to_valid = [], []
    for doc_pos, doc in enumerate(valid_docs):
        doc_tokens = EVAL_BM25_TOKENIZER(doc)
        if doc_tokens:  # only documents with at least one token are indexed
            tokenized_corpus.append(doc_tokens)
            map_to_valid.append(doc_pos)
    if not tokenized_corpus:
        return None, valid_docs, map_to_valid

    bm25_idx = BM25Okapi(tokenized_corpus, k1=EVAL_BM25_K1, b=EVAL_BM25_B)
    eval_bm25_index_cache[repo_key] = (bm25_idx, valid_docs, map_to_valid, tokenized_corpus)
    return bm25_idx, valid_docs, map_to_valid


def _eval_retrieve_context(instruction, repo_key):
    """Return the top BM25 snippets for `instruction`, joined with snippet separators ("" if none)."""
    bm25_idx, valid_docs, map_to_valid = _eval_get_bm25_entry(repo_key)
    if bm25_idx is None:
        return ""
    query_toks = EVAL_BM25_TOKENIZER(instruction)
    if not query_toks:
        return ""
    num_idx_docs = len(bm25_idx.doc_len)
    # get_top_n ranks the supplied "documents" by score; passing indices yields ranked indices.
    top_idxs = bm25_idx.get_top_n(query_toks, list(range(num_idx_docs)),
                                  n=min(EVAL_TOP_K_SNIPPETS_FOR_PROMPT, num_idx_docs))
    return "\n\n# --- Snippet ---\n\n".join(valid_docs[map_to_valid[i]] for i in top_idxs)


def _eval_generate_code(prompt_text, prompt_name):
    """Run `generate_llm_code_and_clean` with the EVAL_* generation settings."""
    return generate_llm_code_and_clean(
        prompt_text=prompt_text, llm_model=model, llm_tokenizer=tokenizer,
        max_new_tokens_gen=EVAL_MAX_NEW_TOKENS, do_sample_gen=EVAL_DO_SAMPLE,
        temperature_gen=EVAL_TEMPERATURE, top_p_gen=EVAL_TOP_P, top_k_gen=EVAL_TOP_K,
        repetition_penalty_gen=EVAL_REPETITION_PENALTY,
        stopping_criteria_list_gen=eval_llm_stopping_criteria,
        prompt_name=prompt_name,
    )


for idx, sample_data in enumerate(tqdm(dataset_for_eval, desc="⚙️ Evaluating", unit="sample")):
    instruction = sample_data.get("instruction")
    repo_key = sample_data.get('repo_full_name')  # Used for KB lookup and logging

    # Defaults keep the four output lists index-aligned even when a step fails.
    baseline_code = ""
    rag_code = ""

    if not instruction or not repo_key:
        print(f"  Skipping sample {idx} due to missing instruction or repo_key.")
    else:
        # --- A. Baseline Generation (instruction only) ---
        try:
            baseline_code = _eval_generate_code(EVAL_BASELINE_PROMPT_BUILDER(instruction), f"Base_{idx}_{repo_key}")
        except Exception as e_base:
            print(f"  Error during Baseline for sample {idx} ({repo_key}): {type(e_base).__name__}: {e_base}")

        # --- B. RAG Generation (instruction + retrieved KB snippets) ---
        try:
            retrieved_context_str = _eval_retrieve_context(instruction, repo_key)
            rag_prompt_text = EVAL_RAG_PROMPT_BUILDER(instruction, retrieved_context_str)
            rag_code = _eval_generate_code(rag_prompt_text, f"RAG_Eval_{idx}_{repo_key}")
        except Exception as e_rag:
            print(f"  Error during RAG for sample {idx} ({repo_key}): {type(e_rag).__name__}: {e_rag}")

    # --- Append results for this sample (None from the helper becomes "") ---
    eval_all_baseline_outputs.append(baseline_code or "")
    eval_all_rag_outputs.append(rag_code or "")
    eval_all_references.append(sample_data.get("clean_reference", ""))
    eval_all_reference_apis.append(sample_data.get("unique_apis", []))

# --- Final Sanity Check of List Lengths ---
# Every output list must hold exactly one entry per processed sample.
if {len(eval_all_baseline_outputs), len(eval_all_rag_outputs),
        len(eval_all_references), len(eval_all_reference_apis)} == {num_samples_to_process}:
    print(f"\n✅ Successfully collected all outputs for {num_samples_to_process} evaluation examples.")
else:
    print("\n❌ CRITICAL ERROR: Length mismatch in final evaluation output lists!")

In [None]:
import numpy as np
import sacrebleu # Ensure sacrebleu is imported if not already globally
import re
import warnings
import importlib.metadata as md # For version checking
import traceback # For more detailed error info

# --- Initialize Sacrebleu & CodeBLEU ---
_HAS_CODEBLEU_METRIC_UTIL = False
try:
    sbleu_version_util = md.version("sacrebleu")
    print(f" Sacrebleu imported successfully for utilities (Version: {sbleu_version_util}).")
except md.PackageNotFoundError:
    print("Sacrebleu not found via importlib.metadata. Ensure it's installed.")
    sbleu_version_util = "n/a"

# The import and the version lookup use separate try blocks on purpose.
# md.PackageNotFoundError subclasses ModuleNotFoundError (an ImportError), so a single
# `except ImportError` would also catch a failed *version lookup*. That would replace a
# working calc_codebleu with the dummy below and disable CodeBLEU.
try:
    from codebleu import calc_codebleu
except ImportError:
    print(" CodeBLEU import failed in utilities. CodeBLEU scores will be skipped or result in None/0.")

    def calc_codebleu(*args, **kwargs):  # Dummy function
        """Fallback used when codebleu cannot be imported: warns and returns {}."""
        warnings.warn("calc_codebleu called, but CodeBLEU library is not available or failed to import.")
        return {}
else:
    _HAS_CODEBLEU_METRIC_UTIL = True
    try:
        cb_version_util = md.version("codebleu")
        print(f" CodeBLEU imported successfully for utilities (Version: {cb_version_util}).")
    except md.PackageNotFoundError:
        print("CodeBLEU seems imported but version not found via importlib.metadata.")

# --- Globals for CodeBLEU Key Detection ---
# Name of the overall-score entry in calc_codebleu's result dict; None until detected.
_codebleu_score_key_cache_util = None


def _get_and_cache_codebleu_score_key(pred_for_key_detect: str, ref_for_key_detect: str):
    """Detect, once, which key of calc_codebleu's result holds the CodeBLEU score.

    Runs calc_codebleu on the given pair, substituting small placeholder functions for
    any argument that is not a non-blank string. The first result key whose lowercase
    form contains "codebleu" is stored in `_codebleu_score_key_cache_util`.

    Returns the cached key, possibly None. Returns at once when CodeBLEU is unavailable
    or a key is already cached. Failures are reported through warnings (plus a printed
    traceback for exceptions) rather than raised.
    """
    global _codebleu_score_key_cache_util
    if not _HAS_CODEBLEU_METRIC_UTIL or _codebleu_score_key_cache_util is not None:
        return _codebleu_score_key_cache_util

    def _usable_or(text, placeholder):
        return text if isinstance(text, str) and text.strip() else placeholder

    probe_pred = _usable_or(pred_for_key_detect, "def example_pred(): pass")
    probe_ref = _usable_or(ref_for_key_detect, "def example_ref(): pass")

    try:
        probe_result = calc_codebleu(
            references=[[probe_ref]],
            predictions=[probe_pred],
            lang="python",
            weights=(0.25, 0.25, 0.25, 0.25),
        )
        detected_key = next((key for key in probe_result if "codebleu" in key.lower()), None)
        if detected_key:
            _codebleu_score_key_cache_util = detected_key
        else:
            warnings.warn("Could not auto-detect CodeBLEU score key from calc_codebleu output.")
    except Exception as e_cb_key_detect:
        warnings.warn(f"Error during CodeBLEU key detection: {type(e_cb_key_detect).__name__}: {e_cb_key_detect}.")
        traceback.print_exc()
    return _codebleu_score_key_cache_util

# --- Metric Calculation Functions ---
def calculate_chrf_score_metric(prediction: str, reference: str) -> float | None:
    if not (isinstance(prediction, str) and isinstance(reference, str)): return None
    if not prediction.strip() or not reference.strip(): return 0.0
    try: return sacrebleu.corpus_chrf([prediction], [[reference]]).score
    except Exception as e: warnings.warn(f"ChrF failed: {e}"); return None

def calculate_codebleu_score_metric(prediction: str, reference: str) -> float | None:
    """CodeBLEU (lang="python", equal 0.25 weights) of `prediction` against one `reference`.

    Returns None when CodeBLEU is unavailable, when an input is not a string, when the
    score key cannot be determined, or when calc_codebleu raises (a warning is issued).
    Returns 0.0 when either string is blank. If the score key is absent from the result,
    the score counts as 0.0.
    """
    global _codebleu_score_key_cache_util
    if not _HAS_CODEBLEU_METRIC_UTIL:
        return None
    if not isinstance(prediction, str) or not isinstance(reference, str):
        return None
    if not (prediction.strip() and reference.strip()):
        return 0.0

    # Normally detected up front by evaluate_rag_vs_baseline; detect lazily otherwise.
    if _codebleu_score_key_cache_util is None:
        _get_and_cache_codebleu_score_key(prediction, reference)
        if _codebleu_score_key_cache_util is None:
            warnings.warn("CodeBLEU key not available for calculation; returning None for CodeBLEU.")
            return None

    try:
        score_dict = calc_codebleu(
            references=[[reference]],
            predictions=[prediction],
            lang="python",
            weights=(0.25, 0.25, 0.25, 0.25),
        )
        return float(score_dict.get(_codebleu_score_key_cache_util, 0.0))
    except Exception as err:
        warnings.warn(f"CodeBLEU calc failed for sample: {type(err).__name__} - {err}")
        return None

def _api_whole_token_pattern(api: str) -> str:
    """Regex matching `api` literally, with a boundary only on edges that are word characters."""
    # `\b` next to a non-word character (e.g. the "@" of a decorator) demands an adjacent
    # *word* character, so wrapping every API in `\b...\b` missed such names in normal code.
    # For word-character edges, (?<!\w) / (?!\w) behave exactly like `\b`.
    prefix = r"(?<!\w)" if re.match(r"\w", api[0]) else ""
    suffix = r"(?!\w)" if re.match(r"\w", api[-1]) else ""
    return f"{prefix}{re.escape(api)}{suffix}"


def calculate_api_recall_score_metric(generated_code: str, ref_apis: list) -> float:
    """Fraction of reference API names that occur in `generated_code` as whole tokens.

    Each non-blank string in `ref_apis` counts once; surrounding whitespace is ignored.
    Word-character edges must not touch neighbouring word characters, so `load_dataset`
    is not found inside `reload_dataset`. Non-word edges (e.g. a leading `@` or a trailing
    `(`) need no boundary.

    Returns 0.0 when the code is not a string, `ref_apis` is not a list, the code is blank,
    or no usable reference API remains.
    """
    if not isinstance(generated_code, str) or not isinstance(ref_apis, list):
        return 0.0
    valid_apis = [api.strip() for api in ref_apis if isinstance(api, str) and api.strip()]
    if not generated_code.strip() or not valid_apis:
        return 0.0
    hits = sum(bool(re.search(_api_whole_token_pattern(api), generated_code)) for api in valid_apis)
    return hits / len(valid_apis)

# --- Helper for Averaging Metrics ---
def _calculate_mean_score_metric(metric_func, predictions_list: list, ground_truths_list: list) -> float:
    """Mean of `metric_func(pred, truth)` over paired items; None results are ignored.

    Pairs come from zip(), so extra items in the longer list are not scored. Returns
    0.0 when no pair produced a score.
    """
    valid_scores = []
    for pred, truth in zip(predictions_list, ground_truths_list):
        score = metric_func(pred, truth)
        if score is not None:
            valid_scores.append(score)
    if not valid_scores:
        return 0.0
    return np.mean(valid_scores)

# --- Main Evaluation Function ---
def evaluate_rag_vs_baseline(baseline_preds: list, rag_preds: list, refs: list, ref_api_lists: list) -> dict:
    """Compare baseline and RAG generations on API recall, ChrF and CodeBLEU, and print a table.

    Args:
        baseline_preds: Code generated without retrieval, one entry per sample.
        rag_preds: Code generated with retrieved context, index-aligned with ``baseline_preds``.
        refs: Reference code strings, index-aligned with the predictions.
        ref_api_lists: Per-sample lists of reference API names (used only for API recall).

    Returns:
        ``{metric_name: {"baseline": float, "rag": float, "delta": rag - baseline}}`` for
        "API Recall", "ChrF Score" and "CodeBLEU Score". Each value is the mean over samples
        whose per-sample score is not None (0.0 when no sample produced a score).
    """
    global _codebleu_score_key_cache_util

    # Detect and cache the CodeBLEU result key once. Use the first pair whose prediction and
    # reference are both non-blank strings; baseline predictions are searched before RAG ones.
    if _HAS_CODEBLEU_METRIC_UTIL and _codebleu_score_key_cache_util is None:
        first_valid_pair = next(
            (
                (p_item, r_item)
                for p_list in (baseline_preds, rag_preds)
                for p_item, r_item in zip(p_list, refs)
                if isinstance(p_item, str) and p_item.strip() and isinstance(r_item, str) and r_item.strip()
            ),
            None,
        )
        if first_valid_pair is not None:
            _get_and_cache_codebleu_score_key(*first_valid_pair)
        else:
            warnings.warn("evaluate_rag_vs_baseline: No valid data pair to detect CodeBLEU key initially.")

    # Each entry is (display name, per-sample scorer, ground truths it compares against, decimals shown).
    # API recall is a fraction in [0, 1]. CodeBLEU is presumably also in [0, 1] (codebleu package
    # convention; confirm against the cached key). Both therefore get 4 decimals, matching the delta column.
    metric_specs = (
        ("API Recall", calculate_api_recall_score_metric, ref_api_lists, 4),
        ("ChrF Score", calculate_chrf_score_metric, refs, 2),
        ("CodeBLEU Score", calculate_codebleu_score_metric, refs, 4),
    )

    results = {}
    for name, metric_func, ground_truths, _decimals in metric_specs:
        base_val = float(_calculate_mean_score_metric(metric_func, baseline_preds, ground_truths))
        rag_val = float(_calculate_mean_score_metric(metric_func, rag_preds, ground_truths))
        results[name] = {"baseline": base_val, "rag": rag_val, "delta": rag_val - base_val}

    # Column widths in every data row match the header and separator exactly.
    header = "| Metric         | Baseline |   RAG    | Δ (RAG - Base) |"
    print("\n--- Automatic Evaluation Metrics ---")
    print(header)
    print("|----------------|----------|----------|----------------|")
    for name, _metric_func, _ground_truths, decimals in metric_specs:
        row = results[name]
        print(f"| {name:<14} | {row['baseline']:8.{decimals}f} | {row['rag']:8.{decimals}f} | {row['delta']:+14.4f} |")
    print("-" * len(header))

    return results

print("Metric calculation utilities (including 'evaluate_rag_vs_baseline') defined.")

In [None]:
# --- 1. Essential Prerequisite Checks ---
if not callable(globals().get('evaluate_rag_vs_baseline')):
    raise NameError("The main evaluation function ('evaluate_rag_vs_baseline') is not defined. "
                    "Ensure the 'Metric Calculation Utilities' cell (e.g., 9.2.5) has been run.")

result_lists_for_metrics_final = [
    'eval_all_baseline_outputs', 'eval_all_rag_outputs',
    'eval_all_references', 'eval_all_reference_apis'
]
for required_list_name_final in result_lists_for_metrics_final:
    if not isinstance(globals().get(required_list_name_final), list):
        raise NameError(f"Result list '{required_list_name_final}' from Cell 9.2 is missing or not a list. "
                        "Ensure Cell 9.2 (Main Evaluation Loop) completed successfully.")

# The four lists are paired index-by-index (the metric helpers zip them together). A length mismatch
# would therefore silently drop samples or score outputs against the wrong references.
result_list_lengths_final = {name: len(globals()[name]) for name in result_lists_for_metrics_final}
if len(set(result_list_lengths_final.values())) > 1:
    raise ValueError(f"Result lists from Cell 9.2 have mismatched lengths: {result_list_lengths_final}. "
                     "Re-run Cell 9.2 (Main Evaluation Loop) before computing metrics.")

# --- 2. Proceed with Metrics Calculation if Data is Available ---
if not eval_all_baseline_outputs:
    print("  INFO: Evaluation output lists are empty. Skipping metrics calculation.")
else:
    try:
        final_metrics_results_dict_output = evaluate_rag_vs_baseline(
            eval_all_baseline_outputs,
            eval_all_rag_outputs,
            eval_all_references,
            eval_all_reference_apis
        )
    except Exception as e_metric_calc_final_run:
        print("\n   ERROR during metrics calculation via 'evaluate_rag_vs_baseline':")
        print(f"    Error: {type(e_metric_calc_final_run).__name__}: {e_metric_calc_final_run}")
        # Local import keeps this cell self-contained even if the top-level import cell failed.
        import traceback
        traceback.print_exc()

In [None]:
# Cell 9.4: Qualitative Analysis Sample Viewer
import textwrap
from rank_bm25 import BM25Okapi
import os

# --- Configuration for this Analysis Cell ---
INDICES_TO_ANALYZE = [0, 1, 2]  # Indices into the Cell 9.2 evaluation run; negative ones are skipped
MAX_SNIPPET_DISPLAY_LENGTH = 500  # Characters of each retrieved snippet to print before truncating
# Each sample's BM25 index is rebuilt on the fly below, using the same tokenizer, k1, b and top-k
# as Cell 9.1. It does not reuse an index cached during the evaluation run.

print("--- Qualitative Analysis for Selected Samples ---")

# --- Prerequisite Checks ---
if 'dataset_for_eval' not in globals():
    raise NameError("Global variable 'dataset_for_eval' not found. Ensure Cell 9.2 (with 'global dataset_for_eval') has been run.")
if not hasattr(dataset_for_eval, '__getitem__') or not hasattr(dataset_for_eval, '__len__'):
    raise TypeError(f"'dataset_for_eval' is not a list-like object (e.g., a Hugging Face Dataset). Type is: {type(dataset_for_eval)}")
# An empty dataset only matters if some requested (non-negative) index would actually be read.
if len(dataset_for_eval) == 0 and any(i >= 0 for i in INDICES_TO_ANALYZE):
    raise ValueError("'dataset_for_eval' is empty. Ensure Cell 9.2 processed samples and NUM_EXAMPLES_TO_EVALUATE > 0.")

# Verify the per-sample result lists from Cell 9.2 exist and cover every requested index.
# An empty list fails the bounds check, so no separate emptiness check is needed.
other_required_lists = ['eval_all_baseline_outputs', 'eval_all_rag_outputs', 'eval_all_references']
for lst_name in other_required_lists:
    lst_value = globals().get(lst_name)
    if not isinstance(lst_value, list):
        raise NameError(f"Required list '{lst_name}' not found or not a list. Run Cell 9.2 first.")
    for sample_idx_to_check in INDICES_TO_ANALYZE:
        if sample_idx_to_check >= len(lst_value):
            raise IndexError(f"Index {sample_idx_to_check} is out of bounds for list '{lst_name}' (size {len(lst_value)}).")

# Verify BM25 and GitHub config from Cell 9.1.
bm25_config_vars = ['EVAL_BM25_TOKENIZER', 'EVAL_BM25_K1', 'EVAL_BM25_B', 'EVAL_TOP_K_SNIPPETS_FOR_PROMPT']
github_vars = ['GITHUB_RAW_CONTENT_BASE_URL', 'LOCAL_TEMP_KB_DOWNLOAD_DIR', 'download_github_raw_json']
missing_config_vars = [v for v in bm25_config_vars + github_vars if v not in globals()]
if missing_config_vars:
    raise NameError(f"BM25 or GitHub configuration variables from Cell 9.1 are missing: {missing_config_vars}")


def _qa_truncate_snippet(text: str, max_chars: int, placeholder: str = "... (snippet truncated) ...") -> str:
    """Return ``text`` unchanged if it has at most ``max_chars`` characters.

    Otherwise return its first ``max_chars`` characters followed by ``placeholder`` on its own line.
    Unlike ``textwrap.shorten``, which collapses all whitespace to single spaces, this keeps
    newlines and indentation, so code snippets stay readable.
    """
    if len(text) <= max_chars:
        return text
    return f"{text[:max_chars]}\n{placeholder}"


def _qa_retrieve_snippets(repo_key: str, instruction: str) -> list:
    """Rebuild the BM25 index over ``repo_key``'s KB and return its top snippets for ``instruction``.

    Downloads ``kb_{repo_key}.json`` via ``download_github_raw_json`` (overwrite=False) and keeps the
    non-blank string entries that tokenize to at least one token. Returns up to
    ``EVAL_TOP_K_SNIPPETS_FOR_PROMPT`` of them, best BM25 match first. Whenever a step yields
    nothing, prints a short reason and returns ``[]``.
    """
    kb_filename_on_github = f"kb_{repo_key}.json"
    raw_kb_url = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_filename_on_github}"
    temp_save_subdir_for_kb = os.path.join(LOCAL_TEMP_KB_DOWNLOAD_DIR, f"qa_kb_{repo_key}")
    kb_data = download_github_raw_json(raw_kb_url, temp_save_subdir_for_kb, kb_filename_on_github, overwrite=False)
    if not kb_data or not isinstance(kb_data, list):
        print(f"  Could not load or parse KB for '{repo_key}'.")
        return []

    valid_kb_docs = [doc for doc in kb_data if isinstance(doc, str) and doc.strip()]
    if not valid_kb_docs:
        print(f"  No valid string snippets in KB for '{repo_key}'.")
        return []

    # The BM25 corpus holds only docs with >= 1 token; remember each one's position in valid_kb_docs.
    bm25_corpus_tokens, bm25_pos_to_valid_doc = [], []
    for valid_doc_pos, doc in enumerate(valid_kb_docs):
        doc_tokens = EVAL_BM25_TOKENIZER(doc)
        if doc_tokens:
            bm25_corpus_tokens.append(doc_tokens)
            bm25_pos_to_valid_doc.append(valid_doc_pos)
    if not bm25_corpus_tokens:
        print("  Tokenized KB is empty after filtering. BM25 index not built.")
        return []

    bm25_index = BM25Okapi(bm25_corpus_tokens, k1=EVAL_BM25_K1, b=EVAL_BM25_B)
    query_tokens = EVAL_BM25_TOKENIZER(instruction)
    if not query_tokens:
        print("  Instruction tokenized to empty list, no snippets retrieved.")
        return []

    corpus_size = len(bm25_corpus_tokens)
    top_positions = bm25_index.get_top_n(
        query_tokens, list(range(corpus_size)), n=min(EVAL_TOP_K_SNIPPETS_FOR_PROMPT, corpus_size)
    )
    return [valid_kb_docs[bm25_pos_to_valid_doc[pos]] for pos in top_positions]


# --- Main Analysis Loop ---
for sample_idx_in_eval_run in INDICES_TO_ANALYZE:
    # Negative indices (and any beyond the dataset) are reported and skipped rather than raised.
    if not (0 <= sample_idx_in_eval_run < len(dataset_for_eval)):
        print(f"\n Warning: Index {sample_idx_in_eval_run} is out of bounds for 'dataset_for_eval' (size {len(dataset_for_eval)}). Skipping.")
        continue

    print(f"\n\n======= Analyzing Sample Index (from eval run): {sample_idx_in_eval_run} =======")

    # 1. Data recorded by the evaluation run; missing text fields are treated as empty strings.
    sample_data_original = dataset_for_eval[sample_idx_in_eval_run]
    instruction = sample_data_original.get("instruction") or ""
    repo_key = sample_data_original.get('repo_full_name')
    reference_code = eval_all_references[sample_idx_in_eval_run] or ""
    baseline_output = eval_all_baseline_outputs[sample_idx_in_eval_run] or ""
    rag_output = eval_all_rag_outputs[sample_idx_in_eval_run] or ""

    print(f"Repo Key: {repo_key}")
    print("\n--- INSTRUCTION ---")
    print(textwrap.fill(instruction, width=100) if instruction.strip() else "[No Instruction]")

    print("\n--- REFERENCE CODE ---")
    print(reference_code if reference_code.strip() else "[No Reference Code]")

    # 2. Re-run BM25 retrieval for this sample (rebuilt here, not read back from the eval run).
    print("\n--- RETRIEVED SNIPPETS (for RAG) ---")
    retrieved_snippets_for_display = []
    if not repo_key:
        print("  No repo_key for this sample, cannot retrieve snippets.")
    elif not instruction.strip():
        print("  No instruction for this sample, cannot retrieve snippets.")
    else:
        retrieved_snippets_for_display = _qa_retrieve_snippets(repo_key, instruction)
        if retrieved_snippets_for_display:
            for snippet_no, snippet_text in enumerate(retrieved_snippets_for_display, start=1):
                print(f"\n  Snippet {snippet_no}/{len(retrieved_snippets_for_display)} (Length: {len(snippet_text)} chars):")
                print(_qa_truncate_snippet(snippet_text, MAX_SNIPPET_DISPLAY_LENGTH))
        else:
            print("  No snippets were retrieved for this sample.")

    print("\n--- BASELINE OUTPUT ---")
    print(baseline_output if baseline_output.strip() else "[No Baseline Output]")

    print("\n--- RAG OUTPUT ---")
    print(rag_output if rag_output.strip() else "[No RAG Output]")

    print(f"======= End of Analysis for Sample Index: {sample_idx_in_eval_run} =======")

## Section 10: RAG Performance vs. Baseline with Varying Top-K Retrieved Snippets

In [None]:
import re # For the tokenizer if defined here again

# --- 1. Number of Samples for Section 10 Evaluation ---
NUM_EXAMPLES_TO_EVALUATE_S10 = 3

# --- 2. Top-K Values for Snippet Retrieval ---
TOP_K_VALUES_FOR_S10_EVAL = [1, 3, 5, 10]
# Cell 10.2 keys one result bucket per value and appends to it once per value per sample.
# A duplicate value would therefore double-append to one bucket and misalign its lists,
# and a non-positive k would retrieve no snippets.
if not all(isinstance(k, int) and k > 0 for k in TOP_K_VALUES_FOR_S10_EVAL):
    raise ValueError(f"TOP_K_VALUES_FOR_S10_EVAL must contain only positive integers, got {TOP_K_VALUES_FOR_S10_EVAL}.")
if len(set(TOP_K_VALUES_FOR_S10_EVAL)) != len(TOP_K_VALUES_FOR_S10_EVAL):
    raise ValueError(f"TOP_K_VALUES_FOR_S10_EVAL must not contain duplicates, got {TOP_K_VALUES_FOR_S10_EVAL}.")


def _s10_missing_globals(names):
    """Return the entries of ``names`` that are not defined in the notebook's global namespace."""
    return [name for name in names if name not in globals()]


# --- 3. Select Prompt Builders for Section 10 Evaluation (Reusing from Section 9) ---
missing_prompt_builders = _s10_missing_globals(['EVAL_RAG_PROMPT_BUILDER', 'EVAL_BASELINE_PROMPT_BUILDER'])
if missing_prompt_builders:
    raise NameError(f"EVAL_RAG_PROMPT_BUILDER or EVAL_BASELINE_PROMPT_BUILDER not defined (missing: {missing_prompt_builders}). Run Section 9.1 first.")
S10_EVAL_RAG_PROMPT_BUILDER = EVAL_RAG_PROMPT_BUILDER
S10_EVAL_BASELINE_PROMPT_BUILDER = EVAL_BASELINE_PROMPT_BUILDER

# --- 4. BM25 Parameters for Section 10 Evaluation (Reusing from Section 9) ---
s9_bm25_params_check = ['EVAL_BM25_K1', 'EVAL_BM25_B', 'EVAL_BM25_TOKENIZER']
missing_bm25_params = _s10_missing_globals(s9_bm25_params_check)
if missing_bm25_params:
    raise NameError(f"BM25 parameters from Section 9.1 are missing: {missing_bm25_params}.")
S10_EVAL_BM25_K1 = EVAL_BM25_K1
S10_EVAL_BM25_B = EVAL_BM25_B
S10_EVAL_BM25_TOKENIZER = EVAL_BM25_TOKENIZER

# --- 5. LLM Generation Parameters for Section 10 Evaluation (Reusing from Section 9) ---
# These are read directly (no S10_ alias) by the generation calls in Cell 10.2.
s9_llm_params_check = [
    'EVAL_MAX_NEW_TOKENS', 'EVAL_TEMPERATURE', 'EVAL_TOP_P', 'EVAL_TOP_K',
    'EVAL_REPETITION_PENALTY', 'EVAL_DO_SAMPLE', 'eval_llm_stopping_criteria'
]
missing_llm_params = _s10_missing_globals(s9_llm_params_check)
if missing_llm_params:
    raise NameError(f"LLM generation parameters from Section 9.1 are missing: {missing_llm_params}.")

# --- 6. Verify other critical dependencies for Cell 10.2 ---
critical_deps_for_s10_loop = [
    'lca_dataset_split', 'model', 'tokenizer', 'generate_llm_code_and_clean',
    'GITHUB_RAW_CONTENT_BASE_URL', 'LOCAL_TEMP_KB_DOWNLOAD_DIR', 'download_github_raw_json',
    'BM25Okapi'
]
missing_deps = _s10_missing_globals(critical_deps_for_s10_loop)
if missing_deps:
    raise NameError(f"Critical S10 dependencies missing: {missing_deps}. Ensure prior setup sections ran.")

In [None]:
from tqdm.auto import tqdm
import json
import os

# --- 1. Initialize Output Storage ---
# One bucket per Top-K value. Each list in every bucket gets exactly one entry per sample, so the
# lists stay index-aligned. The baseline output and references are shared by all buckets.
s10_evaluation_results = {
    top_k: {'baseline_outputs': [], 'rag_outputs': [], 'references': [], 'reference_apis': []}
    for top_k in TOP_K_VALUES_FOR_S10_EVAL
}

# Separator between retrieved snippets in the RAG context. It must use real newlines ("\n"),
# not the escaped "\\n", which would put a literal backslash-n into the prompt.
S10_SNIPPET_SEPARATOR = "\n\n# --- Snippet ---\n\n"

eval_bm25_index_cache_s10 = {}  # repo_key -> (bm25_index, valid_docs, map_bm25_pos_to_valid_doc)
_s10_failed_kb_repos = set()    # repos whose KB could not be loaded/indexed during this run


def _s10_append_sample(reference_code, reference_apis, baseline_code, rag_code_by_top_k):
    """Append one sample's results to every Top-K bucket; a missing RAG output is recorded as ''."""
    for top_k_val in TOP_K_VALUES_FOR_S10_EVAL:
        bucket = s10_evaluation_results[top_k_val]
        bucket['baseline_outputs'].append(baseline_code or "")
        bucket['rag_outputs'].append(rag_code_by_top_k.get(top_k_val) or "")
        bucket['references'].append(reference_code)
        bucket['reference_apis'].append(reference_apis)


def _s10_generate(prompt_text, prompt_name):
    """Generate code for ``prompt_text`` with the Section 9.1 generation settings.

    Returns '' if the generator returns a falsy value.
    """
    generated = generate_llm_code_and_clean(
        prompt_text=prompt_text, llm_model=model, llm_tokenizer=tokenizer,
        max_new_tokens_gen=EVAL_MAX_NEW_TOKENS, do_sample_gen=EVAL_DO_SAMPLE,
        temperature_gen=EVAL_TEMPERATURE, top_p_gen=EVAL_TOP_P, top_k_gen=EVAL_TOP_K,
        repetition_penalty_gen=EVAL_REPETITION_PENALTY,
        stopping_criteria_list_gen=eval_llm_stopping_criteria,  # from S9.1
        prompt_name=prompt_name,
    )
    return generated or ""


def _s10_load_bm25_entry(repo_key):
    """Return ``(bm25_index, valid_docs, map_bm25_pos_to_valid_doc)`` for ``repo_key``'s KB, or None.

    Successful builds are memoized in ``eval_bm25_index_cache_s10``. A repo is recorded in
    ``_s10_failed_kb_repos`` when its KB is missing, is not a JSON list, or has no tokenizable
    string entries. The download is then not re-attempted for later samples from the same repo.
    """
    if repo_key in eval_bm25_index_cache_s10:
        return eval_bm25_index_cache_s10[repo_key]
    if repo_key in _s10_failed_kb_repos:
        return None

    kb_file = f"kb_{repo_key}.json"
    kb_url = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_file}"
    temp_kb_dir = os.path.join(LOCAL_TEMP_KB_DOWNLOAD_DIR, f"s10_eval_kb_{repo_key}")
    kb_json = download_github_raw_json(kb_url, temp_kb_dir, kb_file, overwrite=False)  # Use cache if available

    valid_docs = [d for d in kb_json if isinstance(d, str) and d.strip()] if isinstance(kb_json, list) else []
    tokenized_corpus, map_to_valid = [], []
    for valid_pos, doc in enumerate(valid_docs):
        doc_tokens = S10_EVAL_BM25_TOKENIZER(doc)
        if doc_tokens:
            tokenized_corpus.append(doc_tokens)
            map_to_valid.append(valid_pos)
    if not tokenized_corpus:
        _s10_failed_kb_repos.add(repo_key)
        return None

    entry = (BM25Okapi(tokenized_corpus, k1=S10_EVAL_BM25_K1, b=S10_EVAL_BM25_B), valid_docs, map_to_valid)
    eval_bm25_index_cache_s10[repo_key] = entry
    return entry


def _s10_rank_snippets(bm25_entry, query_text):
    """Return every indexed KB snippet ordered by BM25 score for ``query_text`` (best first).

    Returns [] when there is no index (``bm25_entry`` is None) or the query tokenizes to nothing.
    In that case the RAG prompts below are built with an empty context.
    """
    if bm25_entry is None:
        return []
    bm25_idx, valid_docs, map_to_valid = bm25_entry
    query_toks = S10_EVAL_BM25_TOKENIZER(query_text)
    if not query_toks:
        return []
    num_idx_docs = len(bm25_idx.doc_len)
    # rank_bm25's get_top_n(n=k) returns the first k entries of one full score ranking. Ranking all
    # documents once therefore lets every Top-K value take a prefix of this list with identical results.
    ranked_positions = bm25_idx.get_top_n(query_toks, list(range(num_idx_docs)), n=num_idx_docs)
    return [valid_docs[map_to_valid[pos]] for pos in ranked_positions]


# --- 2. Prepare Dataset View for Section 10 ---
dataset_for_s10_eval = lca_dataset_split.select(
    range(min(NUM_EXAMPLES_TO_EVALUATE_S10, len(lca_dataset_split)))
)
num_samples_to_process_s10 = len(dataset_for_s10_eval)

# --- 3. Main Evaluation Loop ---
for idx, sample_data in enumerate(tqdm(dataset_for_s10_eval, desc="⚙️ S10 Evaluating", unit="sample")):
    instruction = sample_data.get("instruction")
    repo_key = sample_data.get('repo_full_name')
    reference_code_s10 = sample_data.get("clean_reference", "")
    reference_apis_s10 = sample_data.get("unique_apis", [])

    if not instruction or not repo_key:
        print(f"  Skipping S10 sample {idx} ({repo_key}) due to missing instruction or repo_key.")
        # Record empty outputs so every bucket's lists keep one entry per sample.
        _s10_append_sample(reference_code_s10, reference_apis_s10, "", {})
        continue

    # --- A. Baseline Generation (once per sample, shared by every Top-K bucket) ---
    try:
        baseline_code_s10 = _s10_generate(S10_EVAL_BASELINE_PROMPT_BUILDER(instruction), f"S10_Base_{idx}_{repo_key}")
    except Exception as e_base:
        print(f"  Error during S10 Baseline for sample {idx} ({repo_key}): {type(e_base).__name__}: {e_base}")
        baseline_code_s10 = ""

    # --- B. BM25 Retrieval (once per sample; each Top-K uses a prefix of the ranking) ---
    try:
        ranked_snippets_s10 = _s10_rank_snippets(_s10_load_bm25_entry(repo_key), instruction)
    except Exception as e_retrieval:
        # A retrieval error skips RAG generation for this sample: '' is recorded for every Top-K.
        print(f"  Error during S10 RAG retrieval for sample {idx} ({repo_key}): {type(e_retrieval).__name__}: {e_retrieval}")
        _s10_append_sample(reference_code_s10, reference_apis_s10, baseline_code_s10, {})
        continue

    # --- C. RAG Generation (one per Top-K value) ---
    rag_code_by_top_k_s10 = {}
    for top_k_val in TOP_K_VALUES_FOR_S10_EVAL:
        retrieved_context_str_s10 = S10_SNIPPET_SEPARATOR.join(ranked_snippets_s10[:top_k_val])
        try:
            rag_prompt_text_s10 = S10_EVAL_RAG_PROMPT_BUILDER(instruction, retrieved_context_str_s10)
            rag_code_by_top_k_s10[top_k_val] = _s10_generate(
                rag_prompt_text_s10, f"S10_RAG_TopK{top_k_val}_{idx}_{repo_key}"
            )
        except Exception as e_rag:
            print(f"  Error during S10 RAG (TopK={top_k_val}) for sample {idx} ({repo_key}): {type(e_rag).__name__}: {e_rag}")
            rag_code_by_top_k_s10[top_k_val] = ""

    _s10_append_sample(reference_code_s10, reference_apis_s10, baseline_code_s10, rag_code_by_top_k_s10)

# --- 4. Final Sanity Check of List Lengths ---
for tk_val, results_dict in s10_evaluation_results.items():
    list_lengths = {len(results_dict[key]) for key in ('baseline_outputs', 'rag_outputs', 'references', 'reference_apis')}
    if list_lengths != {num_samples_to_process_s10}:
        print(f"\nCRITICAL ERROR: Length mismatch for TopK={tk_val} in S10 final output lists!")
    else:
        print(f"\nSUCCESS: Collected all S10 outputs for TopK={tk_val} for {num_samples_to_process_s10} evaluation examples.")

In [None]:
# --- 1. Essential Prerequisite Checks ---
if not callable(globals().get('evaluate_rag_vs_baseline')):
    raise NameError("Function 'evaluate_rag_vs_baseline' from S9.3 is not defined.")
if not isinstance(globals().get('s10_evaluation_results'), dict):
    raise NameError("S10 evaluation results ('s10_evaluation_results') not found or not a dict. Run Cell 10.2.")

# Per-Top-K metric dicts returned by evaluate_rag_vs_baseline, which also prints each table.
# Created before the branch so it always exists and never carries results over from an earlier run.
all_final_metrics_s10_by_top_k = {}

# --- 2. Proceed with Metrics Calculation ---
if not s10_evaluation_results or not any(d.get('rag_outputs') for d in s10_evaluation_results.values()):
    print("  INFO: Section 10 evaluation output lists are empty or no RAG outputs. Skipping metrics calculation.")
else:
    for top_k_value in sorted(s10_evaluation_results):
        results_for_this_top_k = s10_evaluation_results[top_k_value]
        print(f"\n======= METRICS FOR RAG (Top-K = {top_k_value}) vs BASELINE =======")

        # Nothing to score for this Top-K.
        if not results_for_this_top_k.get('rag_outputs') or not results_for_this_top_k.get('baseline_outputs'):
            print(f"  Skipping Top-K={top_k_value} due to empty RAG or Baseline outputs.")
            continue

        try:
            final_metrics_dict_output_s10 = evaluate_rag_vs_baseline(
                results_for_this_top_k['baseline_outputs'],
                results_for_this_top_k['rag_outputs'],
                results_for_this_top_k['references'],
                results_for_this_top_k['reference_apis'],
            )
            all_final_metrics_s10_by_top_k[top_k_value] = final_metrics_dict_output_s10
        except Exception as e_metric_calc_s10:
            print(f"\n   ERROR during S10 metrics calculation for Top-K = {top_k_value}: {type(e_metric_calc_s10).__name__}: {e_metric_calc_s10}")
            # Local import keeps this cell self-contained even if the top-level import cell failed.
            import traceback
            traceback.print_exc()