## Section 1: Importing

In [None]:
# Import the 'drive' module from the google.colab library.
from google.colab import drive

# Mount Google Drive at the '/content/drive' path in the Colab filesystem.
# Requires user authorization. Later cells of this notebook place the model
# cache and the results folder under '/content/drive/MyDrive/'.
drive.mount('/content/drive')

# Reached only when drive.mount() returned without raising.
print("Drive mounted successfully.")

Mounted at /content/drive
Drive mounted successfully.


In [None]:
print("Installing libraries...")
!pip install --upgrade pip

# -q ensures minimal output (quiet installation)
!pip install rank-bm25
# update pip
!pip install --upgrade pip -q

# install bm25 (worked already, but no harm)
!pip install -q rank-bm25

!pip install -q \
  "transformers" \
  "datasets" \
  "torch" \
  "accelerate" \
  "bitsandbytes" \
  "huggingface_hub" \
  "sacrebleu>=2.0.0" \
  "codebleu" \
  "tree-sitter-python"
  #"fsspec==2024.12.0" \

# Upgrade to a good combination ─────────────
!pip install -qU "datasets>=2.19.0" "fsspec>=2024.3.0"  \
                 "huggingface_hub>=0.22.2"  "aiohttp"


import os
import time # for the delay before nvidia-smi
import warnings # for non-critical warnings
import shutil
import tarfile # for .tar archives
import ast
import json
import random
import re
import textwrap # for snipper preview
import traceback
import numpy as np
import torch

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# from typing import List, Optional # For type hinting (optional)
# from transformers import PreTrainedTokenizerBase # For type hinting (optional)

from datasets import load_dataset, DownloadMode
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

# tqdm is optional: USE_TQDM records whether progress bars can be shown.
USE_TQDM = True
try:
    from tqdm.auto import tqdm
except ImportError:
    USE_TQDM = False
    print(
        "Library 'tqdm' not found, progress bar will not be shown.",
        "You can install it with: !pip install -q tqdm",
        sep="\n",
    )

from rank_bm25 import BM25Okapi

print("\nLibraries installed (or updated) successfully!")


Installing libraries...
Collecting pip
  Downloading pip-25.1.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.1.1
Collecting rank-bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank-bm25
Successfully installed rank-bm25-0.2.2
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m188.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m

### **Prompt Templates**

In [None]:
# Prompt templates - version 1

def build_baseline_prompt_v1(instruction: str) -> str:
    """Build the minimal baseline prompt (no retrieved context).

    The whitespace-stripped instruction is placed under a '### Task:' header and
    the prompt ends with a '### Code:' line so generation continues from there.
    """
    task_text = instruction.strip()
    return "### Task:\n" + task_text + "\n\n### Code:\n"

def build_rag_prompt_v1(instruction: str, retrieved: str) -> str:
    """Build the minimal RAG prompt.

    Sections appear in this order: the stripped retrieved examples, the stripped
    task, then a closing '### Code:' line.
    """
    sections = (
        "### Retrieved Examples:",
        retrieved.strip(),
        "",
        "### Task:",
        instruction.strip(),
        "",
        "### Code:",
        "",  # trailing empty item yields the final newline
    )
    return "\n".join(sections)

In [None]:
# ────────────────────────────────────────────────────────────────
# Prompt templates - version 2
# ────────────────────────────────────────────────────────────────

def build_baseline_prompt_v2(instruction: str) -> str:
    """Build the sectioned baseline prompt for `seedemu` tasks.

    The stripped instruction is inserted under '### Task Description:', between a
    fixed library preamble and fixed requirements / output-format sections. The
    prompt ends with a '### Code:' line.
    """
    preamble = (
        "### Library:\n"
        "You will use the `seedemu` Python library to build a network emulation.\n"
        "\n"
        "### Task Description:\n"
    )
    closing = (
        "\n"
        "\n"
        "### Requirements:\n"
        "1. Import only from `seedemu` (layers, services, core, compiler).\n"
        "2. Create objects in this order: Emulator → Layers → Services → Bindings → Dump.\n"
        "3. Use clear variable names (e.g. `base`, `routing`, `ebgp`, `sim`).\n"
        "4. Target Python 3.8+ syntax.\n"
        "\n"
        "### Output Format:\n"
        "- Provide only valid Python code.\n"
        "- No comments, no extra text.\n"
        "- Start at the first line of code (do not repeat the task).\n"
        "\n"
        "### Code:\n"
    )
    return preamble + instruction.strip() + closing

def build_rag_prompt_v2(instruction: str, retrieved: str) -> str:
    """Build the sectioned RAG prompt for `seedemu` tasks.

    The stripped retrieved examples come first, followed by the library note, the
    stripped task description, fixed requirements, output-format rules and a
    closing '### Code:' line.

    Note: requirements 3-5 are hard-coded text naming a domain name caching
    service, private eBGP peerings and `base-component.bin`, regardless of the
    instruction passed in.
    """
    prompt_lines = [
        "### Retrieved Examples:",
        retrieved.strip(),
        "",
        "### Library:",
        "Use the `seedemu` Python library.",
        "",
        "### Task Description:",
        instruction.strip(),
        "",
        "### Requirements:",
        "1. Imports: `seedemu.layers`, `seedemu.services`, `seedemu.core`, `seedemu.compiler`.",
        "2. Instantiate Emulator, then Base, Routing, eBGP layers in order.",
        "3. Install the domain name caching service on specified hosts.",
        "4. Add private eBGP peerings between ASes.",
        "5. Finally, dump the emulator state to `base-component.bin`.",
        "",
        "### Output Format:",
        "- Return **only** runnable Python code.",
        "- No comments or markdown.",
        "- Do not echo the instructions.",
        "",
        "### Code:",
        "",  # trailing empty item yields the final newline
    ]
    return "\n".join(prompt_lines)


In [None]:
# ────────────────────────────────────────────────────────────────
# Prompt templates - version 3
# ────────────────────────────────────────────────────────────────

def build_baseline_prompt_v3(instruction: str) -> str:
    """Build the 'senior Python engineer' baseline prompt.

    The stripped instruction is placed under '**Task**:'. A fixed requirements
    list follows (type hints, one function or class, docstring, PEP8, at least
    one unit test), and the prompt ends with an opened ```python fence.
    """
    intro = (
        "You are a senior Python engineer.  Fulfill the following task by writing production-ready code.\n"
        "\n"
        "**Task**:\n"
    )
    requirements_and_fence = (
        "\n"
        "\n"
        "**Requirements**:\n"
        "- Python 3, include type hints\n"
        "- One well-formed function or class with a descriptive name\n"
        "- A docstring (inputs, outputs, edge cases)\n"
        "- PEP8 style (4-space indent, snake_case)\n"
        "- At least one unit test using `assert` or `unittest`\n"
        "\n"
        "**Implementation**:\n"
        "```python\n"
    )
    return intro + instruction.strip() + requirements_and_fence

def build_rag_prompt_v3(instruction: str, retrieved: str) -> str:
    """Build the 'senior Python engineer' RAG prompt.

    Order of sections: stripped retrieved examples, stripped task, a fixed
    requirements list, then an opened ```python fence.
    """
    prompt_lines = [
        "You are a senior Python engineer.  Use the retrieved examples to guide your implementation.",
        "",
        "**Retrieved Examples**:",
        retrieved.strip(),
        "",
        "**Task**:",
        instruction.strip(),
        "",
        "**Requirements**:",
        "- Python 3 with type hints",
        "- Clean function or class design with a docstring",
        "- Adhere to PEP8 conventions",
        "- Include at least one unit test",
        "",
        "",  # the template has two blank lines before the implementation marker
        "**Implementation**:",
        "```python",
        "",  # trailing empty item yields the final newline
    ]
    return "\n".join(prompt_lines)


In [None]:
# ────────────────────────────────────────────────────────────────
# Prompt templates - version 4
# ────────────────────────────────────────────────────────────────

def build_baseline_prompt_v4(instruction: str) -> str:
    """Return the instruction itself, whitespace-stripped, with no template around it."""
    bare_prompt = instruction.strip()
    return bare_prompt

def build_rag_prompt_v4(instruction: str, retrieved: str) -> str:
    """Build the v4 RAG prompt.

    Same layout as the v3 RAG prompt, with two extra requirement bullets: do not
    copy the retrieved examples, and return only the raw answer. The
    '**Implementation**:' marker follows the last bullet with no blank line, and
    the prompt ends with an opened ```python fence.
    """
    prompt_lines = [
        "You are a senior Python engineer.  Use the retrieved examples to guide your implementation.",
        "",
        "**Retrieved Examples**:",
        retrieved.strip(),
        "",
        "**Task**:",
        instruction.strip(),
        "",
        "**Requirements**:",
        "- Python 3 with type hints",
        "- Clean function or class design with a docstring",
        "- Adhere to PEP8 conventions",
        "- Include at least one unit test",
        "- Do **NOT** copy or repeat the retrieved examples; write a NEW solution",
        "- Return exactly and only the raw answer to the user request, no extra text",
        "**Implementation**:",
        "```python",
        "",  # trailing empty item yields the final newline
    ]
    return "\n".join(prompt_lines)


In [None]:
# ────────────────────────────────────────────────────────────────
# Prompt templates - version 5
# ────────────────────────────────────────────────────────────────
def build_baseline_prompt_v5(instruction: str) -> str:
    """Build the short v5 baseline prompt.

    One instruction sentence, the stripped task under 'Task:', then an opened
    ```python fence.
    """
    opening = (
        "Write a complete Python 3 implementation for the following task.  "
        "Include type hints, a docstring, and at least one unit test."
    )
    return "\n".join([opening, "", "Task:", instruction.strip(), "", "```python", ""])


def truncate_to_n_tokens(text: str, n: int = 1800, tok=None) -> str:
    """Limit ``text`` to at most ``n`` tokens, keeping the *last* ``n`` tokens.

    Args:
        text: Text to shorten.
        n: Maximum number of tokens to keep. Values <= 0 yield an empty string.
        tok: Optional tokenizer exposing ``encode(text, add_special_tokens=False)``
            and ``decode(ids)``. When omitted, the notebook-level ``tokenizer``
            is used.

    Returns:
        ``text`` unchanged when it already fits within ``n`` tokens; otherwise
        the decoded final ``n`` tokens.
    """
    if n <= 0:
        # ids[-0:] is the whole list, so a zero budget must be handled explicitly.
        return ""
    active_tok = tokenizer if tok is None else tok
    ids = active_tok.encode(text, add_special_tokens=False)
    if len(ids) <= n:
        return text
    # NOTE(review): the tail is kept and the beginning dropped — confirm intended.
    return active_tok.decode(ids[-n:])

def build_rag_prompt_v5(instruction: str, retrieved: str) -> str:
    """Build the v5 RAG prompt with a token budget on the retrieved examples.

    The retrieved text is passed through ``truncate_to_n_tokens(..., 1800)`` and
    inserted as returned (it is not stripped). The instruction is stripped. The
    prompt ends with an opened ```python fence.
    """
    examples = truncate_to_n_tokens(retrieved, 1800)
    prompt_lines = [
        "You are a senior Python engineer.  Use the retrieved examples to inspire your implementation.",
        "",
        "**Retrieved Examples**:",
        examples,
        "",
        "**Task**:",
        instruction.strip(),
        "",
        "**Requirements**:",
        "- Python 3 with type hints",
        "- Clean function or class design with a docstring",
        "- Adhere to PEP8",
        "- Include at least one unit test",
        "- Do **NOT** copy or repeat the examples; write a NEW solution",
        "**Implementation**:",
        "```python",
        "",  # trailing empty item yields the final newline
    ]
    return "\n".join(prompt_lines)


In [None]:
# ────────────────────────────────────────────────────────────────
# Prompt templates - version 6 - markdown
# ────────────────────────────────────────────────────────────────
def build_baseline_prompt_v6(instruction: str) -> str:
    """Build the Markdown-styled baseline prompt for a standalone task.

    A title and a two-line instruction come first, then a '---' rule, the
    stripped task under '**Task**', and an opened ```python fence.
    """
    heading = (
        "# Python Task Implementation\n"
        "\n"
        "Write a complete **Python 3** implementation for the following task.\n"
        "Include type hints, a docstring, and at least one unit test.\n"
        "\n"
        "---\n"
        "\n"
        "**Task**\n"
        "\n"
    )
    fence = "\n\n```python\n"
    return heading + instruction.strip() + fence


def truncate_to_n_tokens(text: str, n: int = 1800, tok=None) -> str:
    """Limit ``text`` to at most ``n`` tokens, keeping the *last* ``n`` tokens.

    This cell repeats the helper from the version-5 cell, so the version-6
    prompt builders can be defined without running that cell first.

    Args:
        text: Text to shorten.
        n: Maximum number of tokens to keep. Values <= 0 yield an empty string.
        tok: Optional tokenizer exposing ``encode(text, add_special_tokens=False)``
            and ``decode(ids)``. When omitted, the notebook-level ``tokenizer``
            is used.

    Returns:
        ``text`` unchanged when it already fits within ``n`` tokens; otherwise
        the decoded final ``n`` tokens.
    """
    if n <= 0:
        # ids[-0:] is the whole list, so a zero budget must be handled explicitly.
        return ""
    active_tok = tokenizer if tok is None else tok
    ids = active_tok.encode(text, add_special_tokens=False)
    if len(ids) <= n:
        return text
    # NOTE(review): the tail is kept and the beginning dropped — confirm intended.
    return active_tok.decode(ids[-n:])



def build_rag_prompt_v6(instruction: str, retrieved: str) -> str:
    """Build the Markdown-styled RAG prompt.

    The retrieved text is passed through ``truncate_to_n_tokens(..., 1800)`` and
    inserted as returned. The instruction is stripped. Sections are separated by
    '---' rules and the prompt ends with an opened ```python fence.
    """
    examples = truncate_to_n_tokens(retrieved, 1800)
    prompt_lines = [
        "# Python Task Implementation (RAG)",
        "",
        "You are a senior Python engineer. Use the retrieved examples to inspire your implementation.",
        "",
        "---",
        "",
        "## Retrieved Examples",
        "",
        examples,
        "",
        "---",
        "",
        "## Task",
        "",
        instruction.strip(),
        "",
        "---",
        "",
        "## Requirements",
        "",
        "- Python 3 with type hints",
        "- Clean function or class design with a docstring",
        "- Adhere to PEP 8",
        "- Include at least one unit test",
        "- **Do NOT** copy or repeat the examples; write a **NEW** solution",
        "",
        "```python",
        "",  # trailing empty item yields the final newline
    ]
    return "\n".join(prompt_lines)


## Section 2: LLM and Tokenizer Loading with 4-bit Quantization


In [None]:
# Model selection: leave exactly one `model_name = ...` line uncommented.

# --- Gemma Series (Google) ---
# model_name = "google/codegemma-7b"
# model_name = "google/codegemma-7b-it"

# --- Qwen Series (Alibaba) ---
# model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
# model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"
# model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

# --- Deepseek Coder Series (Deepseek AI) ---
# model_name = "deepseek-ai/deepseek-coder-7b-instruct-v1.5"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model_name = "deepseek-ai/deepseek-coder-1.3b-base"

# --- Code Llama Series (Meta) ---
# model_name = "codellama/CodeLlama-7b-Instruct-hf"

# --- Phi Series (Microsoft) ---
# model_name = "microsoft/Phi-4-mini-instruct"
# model_name = "microsoft/Phi-4-multimodal-instruct"

# Hub-id prefixes for which trust_remote_code is switched on.
TRUST_REMOTE_CODE_MODELS = [
    "microsoft/Phi-",
    "Qwen/",
]
# str.startswith accepts a tuple of prefixes and returns a bool.
trust_code = model_name.startswith(tuple(TRUST_REMOTE_CODE_MODELS))

print(f"Setting trust_remote_code={trust_code} for {model_name}")
if trust_code:
    print("WARNING: trust_remote_code=True will execute Python code from the model's Hugging Face repo")


Setting trust_remote_code=False for deepseek-ai/deepseek-coder-1.3b-base


In [None]:
# 4-bit NF4 quantisation config for bitsandbytes.
# NOTE(review): QUANT_CFG is assigned again with the same settings in the
# "Build the 4-bit config" cell below; when cells run top to bottom, that later
# assignment is the one the loading cell sees.
QUANT_CFG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    # bfloat16 when CUDA is available and reports bf16 support, otherwise float16.
    bnb_4bit_compute_dtype=torch.bfloat16
        if (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
        else torch.float16,
    bnb_4bit_use_double_quant=True,
)

In [None]:
# == Drive locations of the cached model copy and its metadata file ==========
CACHE_ROOT = "/content/drive/MyDrive/llm_cache"
# One sub-folder per model: '/' in the hub id becomes '_' and a quantisation tag is appended.
CACHE_DIR = os.path.join(CACHE_ROOT, f"{model_name.replace('/', '_')}_4bit_nf4")
META_FILE = os.path.join(CACHE_DIR, "metadata.json")

In [None]:
# == Build the 4-bit quantisation config =====================================
# compute_dtype is bfloat16 when CUDA is available and reports bf16 support,
# otherwise float16. The loading cell also passes it as torch_dtype on its
# non-quantised path.
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

# Same settings as the QUANT_CFG built in the earlier "4-bit NF4 quantisation"
# cell; this assignment replaces that object.
QUANT_CFG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

def _qcfg_to_dict(cfg):
    return {
        "load_in_4bit": cfg.load_in_4bit,
        "bnb_4bit_quant_type": cfg.bnb_4bit_quant_type,
        "bnb_4bit_compute_dtype": str(cfg.bnb_4bit_compute_dtype),
        "bnb_4bit_use_double_quant": cfg.bnb_4bit_use_double_quant,
    }

# Metadata the Drive cache must match to be reused: the next cell compares the
# saved metadata.json against this dict for equality, and writes this dict after
# saving a fresh cache.
REQ_META = {
    "model_name": model_name,
    "quant_cfg":  _qcfg_to_dict(QUANT_CFG),
}


In [None]:
# == Check for existing cache & metadata ====================
# The cache is reused only when metadata.json exists and equals REQ_META exactly.
use_cache = False
if os.path.isfile(META_FILE):
    try:
        # The context manager closes the handle even when JSON parsing fails.
        with open(META_FILE, encoding="utf-8") as meta_fh:
            saved = json.load(meta_fh)
        use_cache = saved == REQ_META
        print("⚡ Cache metadata match:", use_cache)
    except Exception:
        # Best effort: an unreadable metadata file only disables the cache.
        print("⚠️  Could not parse metadata.json; ignoring cache.")

# == Load tokenizer & model (fast path: Drive cache, slow path: hub download) ==
# Recomputes trust_code from the same two prefixes listed in
# TRUST_REMOTE_CODE_MODELS in the model-selection cell.
trust_code = model_name.startswith(("microsoft/Phi-", "Qwen/"))

try:
    if use_cache:
        print("⚡ Loading from Drive cache…")
        tokenizer = AutoTokenizer.from_pretrained(CACHE_DIR, local_files_only=True, trust_remote_code=trust_code)
        # No quantization_config is passed on this path.
        # NOTE(review): presumably the saved checkpoint carries its own
        # quantisation settings — confirm against the transformers docs.
        model     = AutoModelForCausalLM.from_pretrained(
            CACHE_DIR,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_code,
        )
    else:
        # 4-bit quantisation is attempted only when CUDA is available.
        do_4bit = torch.cuda.is_available()
        print(f"⏳ No valid cache. CUDA available? {do_4bit}")
        print(f"⏳ {'Quantising 4-bit…' if do_4bit else 'Loading fp16…'} this will happen once")

        # ─── tokenizer ───────────────────────────────────────────
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_code)
        # When no pad token id is set, reuse the EOS token for padding.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

        # ─── model ───────────────────────────────────────────────
        if do_4bit:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=QUANT_CFG,
                device_map="auto",
                trust_remote_code=trust_code,
                low_cpu_mem_usage=True,
            )
        else:
            # Non-quantised load using compute_dtype from the config cell.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=compute_dtype,
                device_map="auto",
                trust_remote_code=trust_code,
            )

        # ─── Save cache for next time ───────────────────────────
        # NOTE(review): CACHE_DIR ends in "_4bit_nf4" and REQ_META records the
        # 4-bit config even when do_4bit is False, so a non-quantised model
        # saved here would later pass the metadata check — confirm intended.
        print("💾 Saving to Drive cache…")
        os.makedirs(CACHE_DIR, exist_ok=True)
        tokenizer.save_pretrained(CACHE_DIR)
        model.save_pretrained(CACHE_DIR)
        with open(META_FILE, "w") as f:
            json.dump(REQ_META, f)
        print("✅ Cache written at", CACHE_DIR)

    # Make sure the model config has a pad token id, taken from the tokenizer.
    if getattr(model, "config", None) and model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    print("🎉 Model & tokenizer ready!")

# Print the traceback, then re-raise so the cell still fails visibly.
# (`e` is bound here but not used.)
except Exception as e:
    print("❌ Error loading model/tokenizer:")
    traceback.print_exc()
    raise

⚡ Cache metadata match: True
⚡ Loading from Drive cache…
🎉 Model & tokenizer ready!


## Section 3: Dataset Preparation and Validation



In [None]:
# --- 1. Google Drive folder for results and outputs ---

drive_save_path = '/content/drive/MyDrive/RAG_Project/'

# Create the folder if needed; a failure is reported as a warning, not raised.
try:
    os.makedirs(drive_save_path, exist_ok=True)
    print("Google Drive Directory available: " + drive_save_path)
except OSError as mkdir_err:
    print(
        "Warning: can not create or verify the existence of the directory: "
        + f"{drive_save_path}. Details: {mkdir_err}"
    )


Google Drive Directory available: /content/drive/MyDrive/RAG_Project/


In [None]:
#!pip install --upgrade --quiet datasets==2.16.0 fsspec==2023.9.2

In [None]:
from datasets import load_dataset, DownloadMode

dataset_name = "JetBrains-Research/lca-library-based-code-generation"
data_split   = "test"

# "\n" (not "\\n") so a real line break is printed before the message.
print(f"\n▶️  Loading dataset '{dataset_name}' (split='{data_split}')…")

try:
    # Load the requested split with every library included.
    lca_dataset_all_libraries = load_dataset(
        dataset_name,
        split=data_split,
    )
    # Later cells read `lca_dataset_split`; here it is the unfiltered dataset.
    lca_dataset_split = lca_dataset_all_libraries
    print(f"✅ Dataset loaded with {len(lca_dataset_all_libraries)} examples across all libraries.")

    # To work with only specific libraries for testing:
    # NOTE(review): this example filters on "repo_name", while Section 5 filters
    # on 'repo_full_name' — confirm the column name before enabling it.
    # target_repos = ["seed-emulator", "another-repo"] # Examples
    # lca_dataset_split = lca_dataset_all_libraries.filter(
    #     lambda ex: ex["repo_name"] in target_repos
    # )
    # print(f"✅ Filtered to {len(lca_dataset_split)} examples in {target_repos}")

except Exception as e:
    # The error is only printed; if load_dataset failed, `lca_dataset_split`
    # is not assigned by this cell.
    print(f"❌ ERROR loading or filtering dataset: {e}")


\n▶️  Loading dataset 'JetBrains-Research/lca-library-based-code-generation' (split='test')…


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/5.21k [00:00<?, ?B/s]

(…)-00000-of-00001-518ed46ecbe35ff9.parquet:   0%|          | 0.00/4.58M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/150 [00:00<?, ? examples/s]

✅ Dataset loaded with 150 examples across all libraries.


## Section 4: GitHub Knowledge Base Access Setup

This section configures access to pre-built Knowledge Bases (KBs) hosted on a GitHub repository.

It performs two main tasks:
1.  Defines a helper function (`download_github_raw_json`) to fetch KB JSON files from GitHub.
2.  Sets essential GitHub repository parameters (username, repo name, branch, KB folder path) to construct the base URL for downloading KBs. It also specifies a local temporary directory for these downloads.

This setup enables subsequent sections to dynamically load specific KBs from the designated GitHub source.

In [None]:
import requests
import os
import json # For loading the JSON after download

def download_github_raw_json(raw_url, local_save_dir, filename, overwrite=False, timeout=30):
    """Fetch a JSON file from a GitHub raw-content URL, cache it locally, and parse it.

    Args:
        raw_url: URL of the raw file.
        local_save_dir: Directory for the local copy (created if missing).
        filename: File name of the local copy inside ``local_save_dir``.
        overwrite: When False and the local copy exists, it is loaded instead of
            downloading. If that local copy cannot be loaded, a download is
            attempted anyway.
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot block the notebook indefinitely.

    Returns:
        The parsed JSON content, or None when the download or the JSON parsing
        fails (an error message is printed in that case).
    """
    os.makedirs(local_save_dir, exist_ok=True)
    local_file_path = os.path.join(local_save_dir, filename)

    if os.path.exists(local_file_path) and not overwrite:
        try:
            with open(local_file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e_load:
            print(f"  ERROR loading existing local file {local_file_path}: {e_load}. Will attempt re-download.")
            # Fall through to the download below.

    try:
        # The context manager releases the streamed connection on every exit path.
        with requests.get(raw_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # raises HTTPError for unsuccessful status codes

            with open(local_file_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        with open(local_file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    except requests.exceptions.RequestException as e_req:
        print(f"  ERROR downloading from GitHub {raw_url}: {e_req}")
    except json.JSONDecodeError as e_json:
        print(f"  ERROR decoding JSON from downloaded file {local_file_path}: {e_json}")
    except Exception as e_generic:
        print(f"  An unexpected error occurred during GitHub download/processing for {filename}: {e_generic}")

    return None  # any failure above ends here

print("GitHub download helper function defined.")

GitHub download helper function defined.


In [None]:
import os  # kept so this cell can run on its own

# Coordinates of the GitHub repository that hosts the knowledge bases.
GITHUB_USERNAME = "PatrizioAcquadro"
GITHUB_REPO_NAME = "RAG_Project_SE2"
GITHUB_BRANCH = "main"
GITHUB_KBS_FOLDER_PATH = "knowledge_bases_prod"

# Base URL for raw GitHub content: host / user / repo / branch / folder.
GITHUB_RAW_CONTENT_BASE_URL = "/".join((
    "https://raw.githubusercontent.com",
    GITHUB_USERNAME,
    GITHUB_REPO_NAME,
    GITHUB_BRANCH,
    GITHUB_KBS_FOLDER_PATH,
))

# Local directory on the Colab VM where downloaded KBs are stored temporarily.
LOCAL_TEMP_KB_DOWNLOAD_DIR = "/content/temp_downloaded_kbs"
os.makedirs(LOCAL_TEMP_KB_DOWNLOAD_DIR, exist_ok=True)


print("--- RAG Knowledge Base Configuration (GitHub) ---")
print("  KBs will be downloaded from GitHub base URL: " + GITHUB_RAW_CONTENT_BASE_URL + "/kb_LIBRARY_KEY.json")
print("  Downloaded KBs will be temporarily stored in: " + LOCAL_TEMP_KB_DOWNLOAD_DIR)

--- RAG Knowledge Base Configuration (GitHub) ---
  KBs will be downloaded from GitHub base URL: https://raw.githubusercontent.com/PatrizioAcquadro/RAG_Project_SE2/main/knowledge_bases_prod/kb_LIBRARY_KEY.json
  Downloaded KBs will be temporarily stored in: /content/temp_downloaded_kbs


## Section 5: BM25 Retrieval Analysis

This section is dedicated to analyzing the BM25 retrieval process for a selected code generation sample. It is divided into two main parts:
1.  **Configuration:** Defining all parameters for the retrieval analysis (The sample to inspect, BM25 settings, and the tokenization strategy).
2.  **Execution & Display:** Loading the relevant KB, performing BM25 retrieval based on the configurations, and displaying the results.

This allows for easy experimentation with different retrieval settings before full-scale evaluation.

It uses the pre-built KBs downloaded from GitHub via `GITHUB_RAW_CONTENT_BASE_URL` (configured in Section 4).

In [None]:
import re # For the default tokenizer

# --- 1. Repo & Sample Selection ---
# Library to analyse, matched against the dataset's 'repo_full_name' field in Cell 5.2.
ANALYSIS_TARGET_REPO_FULL_NAME = "pyscf__pyscf"
# Index of the sample within that library's samples. If the repo name above is
# falsy (e.g. None), Cell 5.2 uses this value as an index into the full
# lca_dataset_split instead — the sample is not picked at random.
ANALYSIS_SAMPLE_INDEX_WITHIN_REPO = 0

# --- 2. BM25 Algorithm & Retrieval Parameters ---
# k1 and b are passed to BM25Okapi in Cell 5.2; TOP_K is the number of snippets retrieved.
ANALYSIS_BM25_K1 = 1.5
ANALYSIS_BM25_B = 0.75
ANALYSIS_TOP_K_SNIPPETS = 5

# --- 3. BM25 Tokenizer Selection ---
# Define tokenizer functions here. The selected one will be used by BM25 in Cell 5.2.
def robust_code_tokenizer_for_s5(text_input):
    """Tokenize text for BM25: lower-case, then split on anything outside [a-z0-9_].

    Returns an empty list for non-string input. Pieces shorter than two
    characters and pieces made only of digits are dropped.
    """
    if not isinstance(text_input, str):
        return []
    kept_tokens = []
    for piece in re.split(r'[^a-z0-9_]+', text_input.lower()):
        # Skip empty or single-character pieces and pure numbers.
        if len(piece) <= 1 or piece.isdigit():
            continue
        kept_tokens.append(piece)
    return kept_tokens

# Tokenizer that Cell 5.2 applies to both the query and the KB snippets.
ANALYSIS_BM25_TOKENIZER = robust_code_tokenizer_for_s5

# --- 4. Display Options for Analysis Cell (Cell 5.2) ---
ANALYSIS_SHOW_QUERY_TOKENS = True    # print the tokenized query
ANALYSIS_HIGHLIGHT_KEYWORDS = True   # highlight query tokens in retrieved snippets

print("  Configuration for BM25 Retrieval Analysis (Section 5) is set:")

if not ANALYSIS_TARGET_REPO_FULL_NAME:
    print(f"    Sample Index from lca_dataset_split for Analysis: {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO}")
else:
    print(
        f"    Target Library for Analysis: '{ANALYSIS_TARGET_REPO_FULL_NAME}'",
        f"    Instruction Index within this library's samples: {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO}",
        sep="\n",
    )

print(
    f"    BM25 Params: k1={ANALYSIS_BM25_K1}, b={ANALYSIS_BM25_B}",
    f"    Number of Snippets to Retrieve (Top-K): {ANALYSIS_TOP_K_SNIPPETS}",
    f"    Tokenizer for BM25: {ANALYSIS_BM25_TOKENIZER.__name__}",
    sep="\n",
)

  Configuration for BM25 Retrieval Analysis (Section 5) is set:
    Target Library for Analysis: 'pyscf__pyscf'
    Instruction Index within this library's samples: 0
    BM25 Params: k1=1.5, b=0.75
    Number of Snippets to Retrieve (Top-K): 5
    Tokenizer for BM25: robust_code_tokenizer_for_s5


In [None]:
# Cell 5.2: Retrieval Analysis - Execution and Display (GitHub KBs, Simplified top_k)

import json
import os
import textwrap
from rank_bm25 import BM25Okapi
from IPython.display import display, HTML
import re # Needed for highlighting if not imported globally for ANALYSIS_BM25_TOKENIZER

# --- 1. Fail fast when Cell 5.1 settings or earlier-section globals are missing ---
config_vars_s5_2_final_github = [
    'ANALYSIS_SAMPLE_INDEX_WITHIN_REPO',
    'ANALYSIS_TARGET_REPO_FULL_NAME',
    'ANALYSIS_BM25_K1',
    'ANALYSIS_BM25_B',
    'ANALYSIS_TOP_K_SNIPPETS',
    'ANALYSIS_BM25_TOKENIZER',
    'ANALYSIS_SHOW_QUERY_TOKENS',
    'ANALYSIS_HIGHLIGHT_KEYWORDS',
]
if not all(var_name in globals() for var_name in config_vars_s5_2_final_github):
    raise NameError("One or more configuration variables from Cell 5.1 are not defined. Run Cell 5.1 first.")

global_vars_s5_2_final_github = [
    'lca_dataset_split',
    'GITHUB_RAW_CONTENT_BASE_URL',
    'LOCAL_TEMP_KB_DOWNLOAD_DIR',
    'download_github_raw_json',
]
if not all(var_name in globals() for var_name in global_vars_s5_2_final_github):
    raise NameError("One or more global prerequisite variables (dataset, GitHub config, download helper) are not defined.")

print("--- Section 5.2: Executing BM25 Retrieval Analysis (from GitHub KBs) ---")

# --- 2. Select Sample Data based on Configuration from Cell 5.1 ---
# On success this section sets: s5_instruction_to_s6 (instruction text),
# library_key_for_kb_file (the sample's 'repo_full_name') and query_tokens.
# On any error the key is reset to None so the retrieval section below is skipped.
s5_instruction_to_s6 = None
s5_snippets_to_s6 = [] # Initialize for output to Section 6
library_key_for_kb_file = None # This will be the 'repo_full_name' from dataset

try:
    if ANALYSIS_TARGET_REPO_FULL_NAME: # As defined in Cell 5.1
        # Targeted mode: restrict to one library, then index within that subset.
        library_key_for_kb_file = ANALYSIS_TARGET_REPO_FULL_NAME
        library_samples = lca_dataset_split.filter(lambda ex: ex['repo_full_name'] == ANALYSIS_TARGET_REPO_FULL_NAME)
        if not library_samples:
            raise ValueError(f"No samples found for specified library (repo_full_name): '{ANALYSIS_TARGET_REPO_FULL_NAME}'.")
        if not (0 <= ANALYSIS_SAMPLE_INDEX_WITHIN_REPO < len(library_samples)): # ANALYSIS_SAMPLE_INDEX_WITHIN_REPO from Cell 5.1
            raise IndexError(f"ANALYSIS_SAMPLE_INDEX_WITHIN_REPO {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO} is out of bounds for library '{ANALYSIS_TARGET_REPO_FULL_NAME}'.")
        target_sample_data = library_samples[ANALYSIS_SAMPLE_INDEX_WITHIN_REPO]
        print(f"  Analyzing instruction #{ANALYSIS_SAMPLE_INDEX_WITHIN_REPO} from library: '{library_key_for_kb_file}'")
    else:
        # Global mode: the index refers to the full dataset and the library key
        # is read from the selected sample.
        if not (0 <= ANALYSIS_SAMPLE_INDEX_WITHIN_REPO < len(lca_dataset_split)):
            raise IndexError(f"ANALYSIS_SAMPLE_INDEX_WITHIN_REPO {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO} is out of bounds for the full lca_dataset_split.")
        target_sample_data = lca_dataset_split[ANALYSIS_SAMPLE_INDEX_WITHIN_REPO]
        library_key_for_kb_file = target_sample_data.get('repo_full_name') # Derive from sample
        print(f"  Analyzing sample at global index {ANALYSIS_SAMPLE_INDEX_WITHIN_REPO}. Derived Library Key: '{library_key_for_kb_file}'")

    s5_instruction_to_s6 = target_sample_data.get('instruction')
    if not s5_instruction_to_s6 or not library_key_for_kb_file:
        raise ValueError("Selected sample missing 'instruction' or 'repo_full_name' could not be determined.")

    print(f"\n  Target Library Key for KB: '{library_key_for_kb_file}'") # This is the 'repo_full_name'
    print(f"  Instruction Text:\n    {textwrap.fill(s5_instruction_to_s6, width=100, initial_indent='    ', subsequent_indent='    ')}")

    # Tokenize the instruction with the tokenizer chosen in Cell 5.1; these
    # tokens are the BM25 query.
    query_tokens = ANALYSIS_BM25_TOKENIZER(s5_instruction_to_s6) # ANALYSIS_BM25_TOKENIZER from Cell 5.1
    if ANALYSIS_SHOW_QUERY_TOKENS: # From Cell 5.1
        print(f"\n  Query Tokens (using '{ANALYSIS_BM25_TOKENIZER.__name__}'):\n    {query_tokens}")

except Exception as e_sample_select:
    print(f"  ERROR during sample selection: {type(e_sample_select).__name__}: {e_sample_select}")
    library_key_for_kb_file = None # Prevent further processing if sample selection fails

# --- 3. Load KB from GitHub, Build Index, and Perform BM25 Retrieval ---
def _highlight_query_tokens_html(snippet_text, tokens):
    """Return ``snippet_text`` as HTML-safe text with whole-word token matches in <b> tags.

    The snippet is HTML-escaped piece by piece, so code containing '<', '>' or
    '&' is displayed literally instead of being interpreted as markup. Matching
    is case-insensitive and the matched text keeps its original casing.
    """
    import html  # local import keeps this cell section self-contained

    # Longest tokens first so longer alternatives are preferred by the regex.
    unique_tokens = sorted({t for t in tokens if t}, key=len, reverse=True)
    if not unique_tokens:
        return html.escape(snippet_text)

    token_pattern = re.compile(
        r"\b(" + "|".join(re.escape(t) for t in unique_tokens) + r")\b",
        flags=re.IGNORECASE,
    )
    open_tag = "<b style='background-color:#FFFACD; color:black; font-weight:bold;'>"

    pieces, cursor = [], 0
    for match in token_pattern.finditer(snippet_text):
        pieces.append(html.escape(snippet_text[cursor:match.start()]))
        pieces.append(f"{open_tag}{html.escape(match.group(0))}</b>")
        cursor = match.end()
    pieces.append(html.escape(snippet_text[cursor:]))
    return "".join(pieces)


if library_key_for_kb_file: # Proceed only if library key was successfully determined
    kb_filename_on_github = f"kb_{library_key_for_kb_file}.json"
    raw_kb_url_s5 = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_filename_on_github}"

    # Download into a per-library subfolder of LOCAL_TEMP_KB_DOWNLOAD_DIR.
    temp_save_subdir_for_kb_s5 = os.path.join(LOCAL_TEMP_KB_DOWNLOAD_DIR, library_key_for_kb_file)

    print(f"\n  Attempting to download/load KB for '{library_key_for_kb_file}' from GitHub...")
    kb_data_from_git = download_github_raw_json(
        raw_kb_url_s5,
        temp_save_subdir_for_kb_s5,
        kb_filename_on_github,
        overwrite=True # For analysis, always fetch fresh; set False to reuse the local copy
    )

    if not kb_data_from_git:
        print(f"  ERROR: Failed to download or parse KB for '{library_key_for_kb_file}' from GitHub.")
    else:
        print(f"  Successfully loaded KB for '{library_key_for_kb_file}' from GitHub ({len(kb_data_from_git)} snippets).")
        # Keep only non-blank string entries.
        valid_kb_docs = [doc for doc in kb_data_from_git if isinstance(doc, str) and doc.strip()]

        if not valid_kb_docs:
            print("  No valid string snippets in loaded KB for BM25 indexing.")
        else:
            tokenized_corpus = [ANALYSIS_BM25_TOKENIZER(doc) for doc in valid_kb_docs]
            # Drop documents that tokenize to nothing, remembering where each
            # surviving document sits in valid_kb_docs.
            final_bm25_corpus_docs, map_idx_bm25_to_valid_docs = [], []
            for i, tokens in enumerate(tokenized_corpus):
                if tokens:
                    final_bm25_corpus_docs.append(tokens)
                    map_idx_bm25_to_valid_docs.append(i)

            if not final_bm25_corpus_docs:
                print("  Tokenized KB is empty after filtering. BM25 index not built.")
            else:
                print(f"  Creating BM25 index from {len(final_bm25_corpus_docs)} processable documents...")
                bm25_index = BM25Okapi(final_bm25_corpus_docs, k1=ANALYSIS_BM25_K1, b=ANALYSIS_BM25_B)
                print("  BM25 index built.")

                if query_tokens and ANALYSIS_TOP_K_SNIPPETS > 0:
                    print(f"\n  --- Retrieving and Displaying Top {ANALYSIS_TOP_K_SNIPPETS} Snippets ---")
                    num_docs = len(final_bm25_corpus_docs)
                    # Rank corpus positions (not texts) so they can be mapped
                    # back to the original snippets.
                    top_indices = bm25_index.get_top_n(
                        query_tokens, list(range(num_docs)),
                        n=min(ANALYSIS_TOP_K_SNIPPETS, num_docs)
                    )
                    s5_snippets_to_s6 = [valid_kb_docs[map_idx_bm25_to_valid_docs[i]] for i in top_indices]

                    if not s5_snippets_to_s6:
                        print(f"    No snippets retrieved for top_k = {ANALYSIS_TOP_K_SNIPPETS}.")
                    else:
                        for i_snip, snip_content in enumerate(s5_snippets_to_s6):
                            print(f"\n    Snippet {i_snip+1}/{len(s5_snippets_to_s6)} (Length: {len(snip_content)} chars):")
                            if ANALYSIS_HIGHLIGHT_KEYWORDS:
                                hl_content = _highlight_query_tokens_html(snip_content, query_tokens)
                                display(HTML(f"<pre style='white-space:pre-wrap; word-wrap:break-word; border:1px dashed #ccc; padding:6px; margin-left:20px;'>{hl_content}</pre>"))
                            else:
                                print(textwrap.indent(textwrap.fill(snip_content, width=100, subsequent_indent='      '), '      '))
                    print(f"\n    Stored {len(s5_snippets_to_s6)} snippets in 's5_snippets_to_s6' for Section 6.")
                elif not query_tokens: print("  Query tokens empty. BM25 retrieval skipped.")
                else: print(f"  ANALYSIS_TOP_K_SNIPPETS ({ANALYSIS_TOP_K_SNIPPETS}) is not positive. No snippets retrieved.")
else: # library_key_for_kb_file was None due to sample selection error
    print("\n  Sample selection failed earlier. Skipping KB loading and BM25 retrieval.")

# The KB loaded, but nothing was stored for Section 6.
if not s5_snippets_to_s6 and library_key_for_kb_file and ('kb_data_from_git' in locals() and kb_data_from_git is not None):
    print("  INFO: No snippets were ultimately stored for Section 6 from this analysis.")

print(f"\n--- Section 5.2: Retrieval Analysis Execution Complete ---")
# Variables `s5_instruction_to_s6` and `s5_snippets_to_s6` are now populated for Section 6.

--- Section 5.2: Executing BM25 Retrieval Analysis (from GitHub KBs) ---


Filter:   0%|          | 0/150 [00:00<?, ? examples/s]

  Analyzing instruction #0 from library: 'pyscf__pyscf'

  Target Library Key for KB: 'pyscf__pyscf'
  Instruction Text:
        Generate code that calculates the effective electronic coupling based on single determinant
    diabatic states using the pyscf library. The code should first define a molecule with specific
    atoms and basis. Then, it should perform two state calculations with DFT, storing molecular
    orbital information into separate chkfiles. The code should then read the MO coefficients and
    occupation numbers from these chkfiles. Afterwards, it should calculate the overlap between two
    determinants, construct density matrices, calculate one-electron and two-electron part
    contributions, and calculate new total energy. Finally, the code should calculate the effective
    electronic coupling and print the results. The code should also remove the chkfiles at the end.

  Query Tokens (using 'robust_code_tokenizer_for_s5'):
    ['generate', 'code', 'that', 'calcu


    Snippet 2/5 (Length: 1758 chars):



    Snippet 3/5 (Length: 566 chars):



    Snippet 4/5 (Length: 5338 chars):



    Snippet 5/5 (Length: 5880 chars):



    Stored 5 snippets in 's5_snippets_to_s6' for Section 6.

--- Section 5.2: Retrieval Analysis Execution Complete ---


## Section 6: RAG Prompt Assembly for Demo Sample

This section assembles the final RAG prompt for the LLM, using the instruction and retrieved snippets from the Section 5 analysis.

It first sets up by selecting the desired `build_rag_prompt` function (from the globally defined prompt templates) and checks for its dependencies, like the LLM `tokenizer` if needed for snippet truncation.

Then, it takes the instruction and retrieved context (snippets) provided by Section 5, formats the snippets into a text block, and calls the chosen `build_rag_prompt` function. The resulting complete prompt string is stored in `s6_final_rag_prompt_output` and a preview is displayed, making it ready for the LLM generation step in Section 7.

In [None]:
import textwrap

# --- 1. Ensure Prerequisite variables from Section 5 (5.2) are defined ---
# Section 6 only consumes these two names; it never recomputes them.
if 's5_instruction_to_s6' not in globals() or \
   's5_snippets_to_s6' not in globals():
    raise NameError("Variables 's5_instruction_to_s6' or 's5_snippets_to_s6' not found. "
                    "Ensure Section 5 (BM25 Retrieval Analysis - Execution and Display) has been run successfully.")

# --- 2. Select the RAG Prompt Builder Function ---
# The name lookup itself is what fails when the prompt-template cell has not been
# run, so the friendly error has to wrap the assignment (a check placed after the
# assignment can never fire for a missing function).
try:
    build_rag_prompt_to_use_in_s6 = build_rag_prompt_v6 # Defaulting to v6 as an example
except NameError as e_builder_missing:
    raise NameError("The function assigned to 'build_rag_prompt_to_use_in_s6' is not defined or not callable. "
                    "Please check its definition and assignment in this cell.") from e_builder_missing

# Verify that the chosen object can actually be called
if not callable(build_rag_prompt_to_use_in_s6):
    raise NameError("The function assigned to 'build_rag_prompt_to_use_in_s6' is not defined or not callable. "
                    "Please check its definition and assignment in this cell.")

# Callables such as functools.partial objects have no __name__; fall back to repr().
_s6_builder_name = getattr(build_rag_prompt_to_use_in_s6, '__name__', repr(build_rag_prompt_to_use_in_s6))
print(f"INFO: Using '{_s6_builder_name}' for RAG prompt assembly in this section.")


# --- 3. Check for Tokenizer Dependency if needed by the chosen prompt builder ---
# Important if the chosen prompt builder (v5/v6) uses a helper like `truncate_to_n_tokens` which itself requires the LLM's tokenizer.
TOKENIZER_NEEDED_BY_SELECTED_PROMPT_BUILDER = False
# Heuristic: look for the global name 'truncate_to_n_tokens' among the names referenced by
# the builder's bytecode. Callables without a __code__ attribute cannot be inspected; they
# are assumed not to need the tokenizer and will fail at call time if they do.
func_code_object = getattr(build_rag_prompt_to_use_in_s6, '__code__', None)
if func_code_object is not None and "truncate_to_n_tokens" in func_code_object.co_names:
    TOKENIZER_NEEDED_BY_SELECTED_PROMPT_BUILDER = True
    # Also, ensure 'truncate_to_n_tokens' itself is defined globally if it's called.
    if "truncate_to_n_tokens" not in globals():
        raise NameError(f"'truncate_to_n_tokens' function is called by '{_s6_builder_name}' "
                        "but 'truncate_to_n_tokens' itself is not defined globally.")

if TOKENIZER_NEEDED_BY_SELECTED_PROMPT_BUILDER:
    # A tokenizer that exists but is None is as unusable as a missing one
    # (same treatment as the stopping-criteria setup in Sections 7.1 and 9.1).
    if globals().get('tokenizer') is None:
        raise NameError(f"LLM 'tokenizer' not defined globally, but it is required by the selected "
                        f"prompt builder '{_s6_builder_name}' (likely for snippet truncation). "
                        "Please ensure the tokenizer is loaded in an earlier section (e.g., Section 2).")
    else:
        print(f"  INFO: Selected prompt builder '{_s6_builder_name}' may use the global 'tokenizer'. Ensure 'tokenizer' is correctly loaded.")
else:
    print(f"  INFO: Selected prompt builder '{_s6_builder_name}' does not appear to directly require the global 'tokenizer' for truncation via 'truncate_to_n_tokens'.")

INFO: Using 'build_rag_prompt_v6' for RAG prompt assembly in this section.
  INFO: Selected prompt builder 'build_rag_prompt_v6' may use the global 'tokenizer'. Ensure 'tokenizer' is correctly loaded.


In [None]:
s6_final_rag_prompt_output = None # Initialize the output of this section

# Prepare the prompt with the instruction
if s5_instruction_to_s6 is None:
    print("  ERROR: No instruction ('s5_instruction_to_s6') available from Section 5. Cannot assemble prompt.")
else:
    # Treat a None snippet container like an empty one so the count below cannot
    # raise; the join further down already tolerated a falsy value.
    s6_snippets_for_prompt = list(s5_snippets_to_s6 or [])

    print(f"  Using instruction: '{s5_instruction_to_s6[:100]}...'")
    print(f"  Using {len(s6_snippets_for_prompt)} retrieved snippets from 's5_snippets_to_s6'.")


    # Prepare the prompt with the retrieved snippets (`build_rag_prompt` likely expects a single string, so they are joined by separator)
    retrieved_snippets_as_text_block = "\n\n# --- Snippet Separator ---\n\n".join(s6_snippets_for_prompt)
    if not s6_snippets_for_prompt:
        print("  INFO: No snippets were provided from Section 5 ('s5_snippets_to_s6' is empty). "
              "The RAG prompt will be assembled with an empty retrieved context.")

    # --- Build the Prompt ---
    try:
        s6_final_rag_prompt_output = build_rag_prompt_to_use_in_s6(
            s5_instruction_to_s6,
            retrieved_snippets_as_text_block
        )

        if s6_final_rag_prompt_output:
            print("\n  Final RAG Prompt Assembled (First 700 Characters):")
            # Slice instead of textwrap.shorten(): shorten() collapses every run of
            # whitespace, which flattens a code-bearing prompt into one unreadable line.
            s6_prompt_preview = s6_final_rag_prompt_output[:700]
            if len(s6_final_rag_prompt_output) > 700:
                s6_prompt_preview += "... (prompt truncated) ..."
            print(s6_prompt_preview)
            # For the full prompt if needed for debugging:
            # print(s6_final_rag_prompt_output)
        else:
             print("  ERROR: Prompt assembly using your 'build_rag_prompt' function "
                   "resulted in an empty or None prompt. Please check the function's logic.")
    except Exception as e_build_prompt:
        # Best-effort demo cell: report the failure and leave the output unset so
        # Section 7 refuses to run on a half-built prompt.
        print(f"  ERROR during prompt assembly with '{build_rag_prompt_to_use_in_s6.__name__}': {e_build_prompt}")
        s6_final_rag_prompt_output = None


# The variable `s6_final_rag_prompt_output` now holds the complete RAG prompt string.
if s6_final_rag_prompt_output:
    print(f"\n--- Section 6: RAG Prompt Assembly Complete. Output in 's6_final_rag_prompt_output' (Length: {len(s6_final_rag_prompt_output)} chars) ---")
else:
    print(f"\n--- Section 6: RAG Prompt Assembly Failed or Produced No Output ---")

  Using instruction: 'Generate code that calculates the effective electronic coupling based on single determinant diabatic...'
  Using 5 retrieved snippets from 's5_snippets_to_s6'.

  Final RAG Prompt Assembled (First 700 Characters):
# Python Task Implementation (RAG) You are a senior Python engineer. Use the retrieved examples to inspire your implementation. --- ## Retrieved Examples matrix. A frozen ddCOSMO potential is added to the results. ''' if isinstance(mc, _Solvation): mc.with_solvent = solvent_obj return mc oldCAS = mc.__class__ if dm is not None: solvent_obj.e, solvent_obj.v = solvent_obj.kernel(dm) solvent_obj.frozen = True class CASSCFWithSolvent(_Solvation, oldCAS): def __init__(self, mc, solvent): self.__dict__.update(mc.__dict__) self.with_solvent = solvent self._e_tot_without_solvent = 0 self._keys.update(['with_solvent']) def dump_flags(self, verbose=None):... (prompt truncated) ...

--- Section 6: RAG Prompt Assembly Complete. Output in 's6_final_rag_prompt_output'

## Section 7: LLM Generation for RAG Demo Prompt

This section generates code using the LLM for the demo RAG prompt assembled in Section 6.

It first **configures LLM generation settings**, including global inference parameters and a custom stopping criteria class (`EosAndCodeEndStoppingCriteria`) that ends generation at the EOS token or at a closing code fence. This setup is done once.

Subsequently, it **executes the generation**:
*   Takes the RAG prompt from Section 6.
*   Uses the pre-set configurations to call `model.generate()`.
*   Decodes, cleans, and displays the LLM-generated code, storing the result in `s7_generated_code_rag_demo`.

In [None]:
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
import time
import re
import textwrap

def generate_llm_code_and_clean(
    prompt_text: str,
    llm_model,
    llm_tokenizer,
    max_new_tokens_gen,
    do_sample_gen,
    temperature_gen,
    top_p_gen,
    top_k_gen,
    repetition_penalty_gen,
    stopping_criteria_list_gen, # Can be None
    prompt_name: str = "Prompt" # For logging
    ) -> str | None:
    cleaned_generated_code = None
    print(f"  Starting LLM generation for: {prompt_name}")
    try:
        inputs = llm_tokenizer(
            prompt_text, return_tensors="pt"
        ).to(llm_model.device)
        prompt_len = inputs['input_ids'].shape[1]
        # print(f"    Tokenized prompt length: {prompt_len} tokens.") # Optional

        gen_args = {
            "input_ids": inputs['input_ids'], "attention_mask": inputs['attention_mask'],
            "max_new_tokens": max_new_tokens_gen, "pad_token_id": llm_tokenizer.eos_token_id,
            "repetition_penalty": repetition_penalty_gen, "stopping_criteria": stopping_criteria_list_gen
        }
        if do_sample_gen:
            gen_args.update({
                "temperature": temperature_gen, "top_p": top_p_gen,
                "top_k": top_k_gen, "do_sample": True
            })
        else:
            gen_args["do_sample"] = False

        gen_start = time.time()
        with torch.no_grad(): output_ids = llm_model.generate(**gen_args)
        print(f"    LLM generation for '{prompt_name}' finished in {time.time() - gen_start:.2f}s.")

        generated_ids_part = output_ids[0, prompt_len:]
        raw_output = llm_tokenizer.decode(generated_ids_part, skip_special_tokens=True)
        # print(f"\n    Raw LLM Output for '{prompt_name}' (first 300 chars):\n{textwrap.shorten(raw_output, 300, placeholder='...')}") # Optional

        # Using your preferred cleaning logic (can be the external function if defined)
        # For simplicity, embedding it here. If you defined `extract_code_from_llm_output` earlier, call that.
        match = re.search(r"```python\n(.*?)(?:\n```|\Z)", raw_output, re.DOTALL | re.IGNORECASE)
        if match: cleaned_generated_code = match.group(1).strip()
        elif "\n```" in raw_output: cleaned_generated_code = raw_output.split("\n```")[0].strip()
        else: cleaned_generated_code = raw_output.strip()

        # print(f"    Cleaned code for '{prompt_name}' (first 300 chars):\n{textwrap.shorten(cleaned_generated_code, 300, placeholder='...') if cleaned_generated_code else '[No code extracted]'}")

    except torch.cuda.OutOfMemoryError as e: cleaned_generated_code = None; print(f"  ❌ OOM ERROR during '{prompt_name}' generation: {e}")
    except Exception as e: cleaned_generated_code = None; print(f"  ❌ ERROR during '{prompt_name}' generation: {type(e).__name__}: {e}")

    return cleaned_generated_code

In [None]:
# --- 1. Global LLM generation parameters ---
# Each name keeps whatever value the notebook namespace already holds and only
# falls back to the default given here when it has not been defined yet.

# Decoding length and sampling behaviour
DO_SAMPLE = globals().get('DO_SAMPLE', True)
MAX_NEW_TOKENS = globals().get('MAX_NEW_TOKENS', 1024)
TEMPERATURE = globals().get('TEMPERATURE', 0.6)
TOP_P = globals().get('TOP_P', 0.95)
TOP_K = globals().get('TOP_K', 50)
REPETITION_PENALTY = globals().get('REPETITION_PENALTY', 1.1)

# Switches read when the custom stopping criteria are built
STOP_ON_EOS = globals().get('STOP_ON_EOS', True)
STOP_ON_CODE_END = globals().get('STOP_ON_CODE_END', True)

# One summary block; sampling-only settings are shown as N/A under greedy decoding.
print("\n".join([
    "  Global LLM Generation Parameters now set/confirmed:",
    f"    MAX_NEW_TOKENS={MAX_NEW_TOKENS}, TEMPERATURE={TEMPERATURE if DO_SAMPLE else 'N/A (Greedy)'}",
    f"    TOP_P={TOP_P if DO_SAMPLE else 'N/A'}, TOP_K={TOP_K if DO_SAMPLE else 'N/A'}",
    f"    REPETITION_PENALTY={REPETITION_PENALTY}, DO_SAMPLE={DO_SAMPLE}",
    f"    Custom Stopping: STOP_ON_EOS={STOP_ON_EOS}, STOP_ON_CODE_END={STOP_ON_CODE_END}",
]))

# --- 2. Define Custom Stopping Criteria Class (if not already globally defined) ---
# The guard keeps the first definition: re-running this cell after editing the
# class body has no effect until the name is removed from the namespace.
if 'EosAndCodeEndStoppingCriteria' in globals():
    print("  Custom 'EosAndCodeEndStoppingCriteria' class already defined.")
else:
    class EosAndCodeEndStoppingCriteria(StoppingCriteria):
        """Stopping rule: halt on the EOS token or on a code-ending token sequence.

        Only the first row of the id tensor is inspected.
        """

        def __init__(self, tokenizer_instance, stop_on_eos_token=True, code_end_sequence="\n```\n"):
            """Store the settings and pre-encode the stop sequence.

            Args:
                tokenizer_instance: Tokenizer used for ``eos_token_id`` and ``encode``.
                stop_on_eos_token: Whether a trailing EOS id should stop generation.
                code_end_sequence: Text whose token ids mark the end of the code;
                    a falsy value disables this check.
            """
            self.tokenizer = tokenizer_instance
            self.stop_on_eos = stop_on_eos_token
            self.code_end_sequence_str = code_end_sequence
            self.code_end_sequence_ids = []

            # Nothing to encode without both a stop string and a usable tokenizer.
            if not (self.code_end_sequence_str and self.tokenizer):
                return

            # NOTE(review): the stop string is encoded on its own; confirm the
            # tokenizer yields the same ids when the fence appears mid-generation.
            try:
                self.code_end_sequence_ids = self.tokenizer.encode(
                    self.code_end_sequence_str,
                    add_special_tokens=False
                )
            except Exception as e_encode:
                print(f"    WARNING: Failed to encode stop sequence '{self.code_end_sequence_str}': {e_encode}")
                self.code_end_sequence_ids = [] # Ensure it's empty on failure

        def __call__(self, current_ids: torch.LongTensor, current_scores: torch.FloatTensor, **kwargs) -> bool:
            """Return True when the first sequence ends with EOS or the stop ids."""
            if self.stop_on_eos and self.tokenizer:
                eos_id = getattr(self.tokenizer, 'eos_token_id', None)
                if eos_id is not None and current_ids[0, -1] == eos_id:
                    return True

            stop_ids = self.code_end_sequence_ids
            if not stop_ids: # Encoding failed or the check is disabled
                return False

            stop_len = len(stop_ids)
            if current_ids.shape[1] < stop_len:
                return False

            # Compare the trailing ids against the encoded stop sequence.
            trailing_ids = current_ids[0, -stop_len:]
            wanted_ids = torch.tensor(stop_ids).to(current_ids.device)
            return torch.equal(trailing_ids, wanted_ids)
    print("  Custom 'EosAndCodeEndStoppingCriteria' class defined.")


# --- 3. Instantiate Global Stopping Criteria List ---
# `llm_stopping_criteria_global` is shared by every generation call that wants these
# criteria; it stays None whenever the criteria cannot or should not be built.
llm_stopping_criteria_global = None

if globals().get('tokenizer') is None: # Missing and None are handled alike
    print("  WARNING (Cell 7.1): Global 'tokenizer' not found or is None. "
          "Custom stopping criteria cannot be created. LLM will use default stopping.")
elif not (STOP_ON_EOS or STOP_ON_CODE_END): # Both switches off: nothing to build
    print("  Global custom stopping criteria (EOS/code end) are disabled by configuration flags (STOP_ON_EOS/STOP_ON_CODE_END).")
else:
    try:
        # The class must exist before it can be instantiated.
        if 'EosAndCodeEndStoppingCriteria' not in globals():
             raise NameError("EosAndCodeEndStoppingCriteria class not defined prior to instantiation.")

        eos_code_ender_criteria_instance = EosAndCodeEndStoppingCriteria(
            tokenizer, # The globally loaded LLM tokenizer
            stop_on_eos_token=STOP_ON_EOS,
            code_end_sequence=("\n```\n" if STOP_ON_CODE_END else None),
        )
        llm_stopping_criteria_global = StoppingCriteriaList([eos_code_ender_criteria_instance])
        print("  Global 'llm_stopping_criteria_global' (for EOS/code end) created successfully.")
    except Exception as e_stop_crit_create:
        print(f"  WARNING (Cell 7.1): Failed to create global custom stopping criteria: {type(e_stop_crit_create).__name__}: {e_stop_crit_create}")
        llm_stopping_criteria_global = None # Stay on default stopping after a failure

  Global LLM Generation Parameters now set/confirmed:
    MAX_NEW_TOKENS=1024, TEMPERATURE=0.6
    TOP_P=0.95, TOP_K=50
    REPETITION_PENALTY=1.1, DO_SAMPLE=True
    Custom Stopping: STOP_ON_EOS=True, STOP_ON_CODE_END=True
  Custom 'EosAndCodeEndStoppingCriteria' class defined.
  Global 'llm_stopping_criteria_global' (for EOS/code end) created successfully.


In [None]:
# --- 1. Ensure prerequisites from Sections 2, 6 and 7.1 are available ---
if 's6_final_rag_prompt_output' not in locals() or not s6_final_rag_prompt_output:
    raise NameError("Input prompt 's6_final_rag_prompt_output' from Section 6 not found. Run Section 6 first.")
if 'model' not in locals() or 'tokenizer' not in locals():
    raise NameError("LLM 'model' or 'tokenizer' not defined. Ensure Section 2 has run.")
if 'generate_llm_code_and_clean' not in locals(): # Helper defined before 7.1
    raise NameError("Helper function 'generate_llm_code_and_clean' not defined. Ensure it's defined before Section 7.1.")

# Ensure global generation params and llm_stopping_criteria_global are available from Cell 7.1
required_globals_for_7_2 = [
    'MAX_NEW_TOKENS', 'DO_SAMPLE', 'TEMPERATURE', 'TOP_P', 'TOP_K',
    'REPETITION_PENALTY', 'llm_stopping_criteria_global' # Note: llm_stopping_criteria_global can be None
]
# Name the missing entries so the reader knows what to re-run.
missing_globals_for_7_2 = [p for p in required_globals_for_7_2 if p not in globals()]
if missing_globals_for_7_2:
    raise NameError(f"One or more global LLM generation parameters/criteria from Cell 7.1 are missing: {missing_globals_for_7_2}.")

print(f"  Using RAG prompt (length: {len(s6_final_rag_prompt_output)} chars) from Section 6.")

# --- 2. Call Reusable LLM Generation Function ---
s7_generated_code_rag_demo = generate_llm_code_and_clean(
    prompt_text=s6_final_rag_prompt_output,
    llm_model=model,
    llm_tokenizer=tokenizer,
    max_new_tokens_gen=MAX_NEW_TOKENS,         # Global param
    do_sample_gen=DO_SAMPLE,                  # Global param
    temperature_gen=TEMPERATURE,              # Global param
    top_p_gen=TOP_P,                          # Global param
    top_k_gen=TOP_K,                          # Global param
    repetition_penalty_gen=REPETITION_PENALTY,# Global param
    stopping_criteria_list_gen=llm_stopping_criteria_global, # From Cell 7.1 (can be None)
    prompt_name="RAG Demo Prompt (S7.2)"      # For logging within the helper
)

# --- 3. Display Result ---
if s7_generated_code_rag_demo is not None: # None signals an error inside the helper
    print("\n  --- Cleaned Generated RAG Code (Demo) ---")
    # Print the label and the code separately: passing both to one print() call
    # inserts a separator space that shifts the first line of code by one column.
    print("\n  Full Cleaned Generated RAG Code (Demo):")
    print(s7_generated_code_rag_demo)
    print(f"\n--- RAG Demo Generation Complete. Result in 's7_generated_code_rag_demo'. ---")
else:
    print(f"\n--- RAG Demo Generation Failed or Produced No Valid Code Output. 's7_generated_code_rag_demo' is None. ---")

  Using RAG prompt (length: 6859 chars) from Section 6.
  Starting LLM generation for: RAG Demo Prompt (S7.2)
    LLM generation for 'RAG Demo Prompt (S7.2)' finished in 33.04s.

  --- Cleaned Generated RAG Code (Demo) ---

  Full Cleaned Generated RAG Code (Demo):
 """
Test whether you can call all functions within the module.
If you cannot, try to use asserts instead of exceptions.
"""
import sys
from pyscf import gto, scf, ao2mo
try:
    from pyscf import mcscf, scf
except ImportError:
    raise ImportError("Cannot import pyscf")
finally:
    if "--no-pymcscf" in sys.argv:
        del globals()["mcscf"]
print(__doc__)
if __name__ == "__main__":
    # Create a molecule
    mol = gto.Mole()
    mol.atom = [
        ["N", (0.0, 0.0, 0.0)],
        ["H", (0.0, 0.0, 1.0)],
        ["H", (0.0, 0.0, -1.0)]
    ]
    mol.basis = "cc-pVDZ"
    mol.build()
    
    # Perform SCF calculation
    mc = scf.RHF(mol)
    mc.kernel()
    
    # Calculate MO coefficients and MO integrals
    mo_coef

## Section 8: Baseline LLM Code Generation

This section generates code using the LLM for the **baseline prompt** corresponding to the same demo sample instruction analyzed in Section 5 (`s5_instruction_to_s6`). It does **not** use any retrieved RAG context.

**Workflow:**
1.  **Prompt Builder Selection:** Chooses the appropriate `build_baseline_prompt_vX` function.
2.  **Baseline Prompt Construction:** Creates the baseline prompt using the demo instruction.
3.  **LLM Generation:** Calls the reusable `generate_llm_code_and_clean` helper function with the baseline prompt. It uses the same global LLM generation parameters and stopping criteria as defined in Section 7.1 for consistency in comparing RAG vs. Baseline outputs.
4.  **Output:** Displays the cleaned baseline-generated code and stores it in `s8_generated_code_baseline_demo`.

This allows for a direct comparison between the RAG-augmented output (from Section 7) and the LLM's output with only the instruction.

In [None]:
# --- 1. Ensure Prerequisite variables and functions are defined ---
# From Section 5:
if 's5_instruction_to_s6' not in locals() or not s5_instruction_to_s6:
    raise NameError("Instruction 's5_instruction_to_s6' from Section 5 not found. Run Section 5 first.")
# From earlier sections (or Section 7.1):
if 'model' not in locals() or 'tokenizer' not in locals():
    raise NameError("LLM 'model' or 'tokenizer' not defined.")
if 'generate_llm_code_and_clean' not in locals():
    raise NameError("Helper function 'generate_llm_code_and_clean' not defined.")
# Ensure global generation params (MAX_NEW_TOKENS, etc.) and llm_stopping_criteria_global are set (typically in Sec 7.1)
required_gen_params_s8 = ['MAX_NEW_TOKENS', 'REPETITION_PENALTY', 'DO_SAMPLE', 'TEMPERATURE', 'TOP_P', 'TOP_K', 'llm_stopping_criteria_global']
if any(p not in globals() for p in required_gen_params_s8):
    raise NameError(f"One or more global LLM generation parameters/criteria needed for baseline are missing. Run Section 7.1.")

# --- 2. Select the Baseline Prompt Builder Function ---
# Look the candidates up by name so the fallback is actually reachable: a direct
# assignment from an undefined name raises NameError before any fallback can run.
# v6 is the preferred builder; v1 is the fallback.
CHOSEN_BASELINE_PROMPT_BUILDER = None
for _baseline_builder_name in ('build_baseline_prompt_v6', 'build_baseline_prompt_v1'):
    _baseline_builder_candidate = globals().get(_baseline_builder_name)
    if callable(_baseline_builder_candidate):
        CHOSEN_BASELINE_PROMPT_BUILDER = _baseline_builder_candidate
        break
if CHOSEN_BASELINE_PROMPT_BUILDER is None:
    raise NameError("A suitable 'build_baseline_prompt_vX' or 'build_baseline_prompt_v1' function is not defined.")
print(f"  Using Baseline Prompt Builder: {CHOSEN_BASELINE_PROMPT_BUILDER.__name__}")

# --- 3. Construct Baseline Prompt ---
s8_generated_code_baseline_demo = None # Initialize output

# Using s5_instruction_to_s6 (the instruction for the demo sample from Section 5)
baseline_prompt_s8 = CHOSEN_BASELINE_PROMPT_BUILDER(s5_instruction_to_s6)

# print("\n  Baseline Prompt (first 500 chars):")
# print(textwrap.shorten(baseline_prompt_s8, width=500, placeholder="..."))

# --- 4. Call Reusable LLM Generation Function ---
s8_generated_code_baseline_demo = generate_llm_code_and_clean(
    prompt_text=baseline_prompt_s8,
    llm_model=model,
    llm_tokenizer=tokenizer,
    max_new_tokens_gen=MAX_NEW_TOKENS, # Global param
    do_sample_gen=DO_SAMPLE,           # Global param
    temperature_gen=TEMPERATURE,       # Global param
    top_p_gen=TOP_P,                   # Global param
    top_k_gen=TOP_K,                   # Global param
    repetition_penalty_gen=REPETITION_PENALTY, # Global param
    stopping_criteria_list_gen=llm_stopping_criteria_global, # From Sec 7.1
    prompt_name="Baseline Demo Prompt"
)

# --- 5. Display Result ---
if s8_generated_code_baseline_demo is not None:
    print("\n  --- Cleaned Generated Baseline Code (Demo) ---")
    # print(textwrap.shorten(s8_generated_code_baseline_demo, width=700, placeholder="... (code truncated) ..."))
    # This is the baseline output, so the label must not say "RAG". Label and code
    # are printed separately so no separator space shifts the first line of code.
    print("\n  Full Cleaned Generated Baseline Code (Demo):")
    print(s8_generated_code_baseline_demo)
    print(f"\n--- Baseline Demo Generation Complete. Result in 's8_generated_code_baseline_demo'. ---")
else:
    print(f"\n--- Baseline Demo Generation Failed or No Output. ---")

  Using Baseline Prompt Builder: build_baseline_prompt_v6
  Starting LLM generation for: Baseline Demo Prompt
    LLM generation for 'Baseline Demo Prompt' finished in 37.74s.

  --- Cleaned Generated Baseline Code (Demo) ---

  Full Cleaned Generated RAG Code (Demo):
 """
This is my program description: It takes in an input file containing a list of atom types and coordinates, along with atomic basis sets, and uses pyscf to generate a molecule object. It then uses pyscf to perform two state calculations with DFT, storing the molecular orbital information into separate chkfiles. It then reads the MO coefficients and occupation numbers from these chkfiles. Afterwards, it calculates the overlap between two determinants, constructs density matrices, calculates one-electron and two-electron part contributions, and calculates new total energy. Finally, it calculates the effective electronic coupling and prints the results. It also removes the chkfiles after completion.
"""
import pyscf
from

## Section 9 · Metrics Results


In [None]:
import re # For the tokenizer if defined here again

# --- 1. Number of Samples for Evaluation ---
# Small subset for a quick test run. To evaluate everything (or 10 by default), use:
# NUM_EXAMPLES_TO_EVALUATE = len(lca_dataset_split) if 'lca_dataset_split' in globals() else 10
NUM_EXAMPLES_TO_EVALUATE = 2


# --- 2. Select Prompt Builders for Evaluation (Defined Globally) ---
EVAL_RAG_PROMPT_BUILDER = build_rag_prompt_v6            # Prompt with retrieved context
EVAL_BASELINE_PROMPT_BUILDER = build_baseline_prompt_v6  # Instruction-only prompt

print("\n".join([
    f"  Using RAG Prompt Builder for Evaluation: {EVAL_RAG_PROMPT_BUILDER.__name__}",
    f"  Using Baseline Prompt Builder for Evaluation: {EVAL_BASELINE_PROMPT_BUILDER.__name__}",
]))


# --- 3. BM25 Parameters for Evaluation ---
EVAL_BM25_K1, EVAL_BM25_B = 1.5, 0.75
EVAL_TOP_K_SNIPPETS_FOR_PROMPT = 5  # Number of snippets to retrieve per sample

# Define/Select the BM25 tokenizer for evaluation
# Define/Select the BM25 tokenizer for evaluation
def eval_robust_code_tokenizer_s9(text_input):
    """Lower-case ``text_input`` and return its identifier-like tokens.

    Tokens are maximal runs of ``a-z``, ``0-9`` and ``_``. Single-character tokens
    and purely numeric tokens are dropped. Anything that is not a ``str`` yields
    an empty list.
    """
    if not isinstance(text_input, str):
        return []

    kept_tokens = []
    for candidate in re.findall(r'[a-z0-9_]+', text_input.lower()):
        if len(candidate) <= 1 or candidate.isdigit():
            continue
        kept_tokens.append(candidate)
    return kept_tokens
EVAL_BM25_TOKENIZER = eval_robust_code_tokenizer_s9

print(f"  BM25 Params for Evaluation: K1={EVAL_BM25_K1}, B={EVAL_BM25_B}, Top-K for prompt={EVAL_TOP_K_SNIPPETS_FOR_PROMPT}")
print(f"  BM25 Tokenizer for Evaluation: {EVAL_BM25_TOKENIZER.__name__}")


# --- 4. LLM Generation Parameters for Evaluation ---
# LLM parameters
EVAL_MAX_NEW_TOKENS = 384
EVAL_TEMPERATURE = 0.5  # Potentially more deterministic for evaluation
EVAL_TOP_P = 0.95
EVAL_TOP_K = 50
EVAL_REPETITION_PENALTY = 1.1
EVAL_DO_SAMPLE = True    # Set to False for deterministic greedy decoding during evaluation if preferred

# Stopping criteria
EVAL_STOP_ON_EOS = True
EVAL_STOP_ON_CODE_END = True

# Stays None (default stopping) unless the custom criteria can be built below.
eval_llm_stopping_criteria = None
if 'tokenizer' not in globals() or tokenizer is None:
    print("  WARNING: Global 'tokenizer' not found. Custom stopping criteria for evaluation cannot be created.")
elif 'EosAndCodeEndStoppingCriteria' not in globals():
    print("  WARNING: 'EosAndCodeEndStoppingCriteria' class not defined (expected from Sec 7.1). Cannot create custom stopping criteria.")
elif EVAL_STOP_ON_EOS or EVAL_STOP_ON_CODE_END:
    try:
        eval_stopper_instance = EosAndCodeEndStoppingCriteria( # Class from 7.1
            tokenizer, # Global tokenizer
            stop_on_eos_token=EVAL_STOP_ON_EOS,
            code_end_sequence="\n```\n" if EVAL_STOP_ON_CODE_END else None
        )
        eval_llm_stopping_criteria = StoppingCriteriaList([eval_stopper_instance])
        print("  Custom 'eval_llm_stopping_criteria' for evaluation run created successfully.")
    except Exception as e:
        # Best effort: evaluation can still run with the model's default stopping.
        print(f"  WARNING: Failed to create custom stopping criteria for evaluation: {e}")
else:
    print("  Custom stopping criteria (EOS/code end) for evaluation are disabled by EVAL_ flags.")

print(f"  LLM Config: MaxNew={EVAL_MAX_NEW_TOKENS}, Temp={EVAL_TEMPERATURE if EVAL_DO_SAMPLE else 'N/A'}, DoSample={EVAL_DO_SAMPLE}, StopCriteriaIsSet={'Yes' if eval_llm_stopping_criteria else 'No'}")


# --- 5. Verify other critical dependencies for Cell 9.2 ---
critical_deps_for_9_2 = [
    'model', 'generate_llm_code_and_clean',
    'GITHUB_RAW_CONTENT_BASE_URL', 'LOCAL_TEMP_KB_DOWNLOAD_DIR', 'download_github_raw_json'
]
# Report exactly which names are absent rather than echoing the whole required
# list, so the reader can tell which setup section still needs to be run.
missing_deps_for_9_2 = [dep for dep in critical_deps_for_9_2 if dep not in globals()]
if missing_deps_for_9_2:
    raise NameError(f"One or more critical global dependencies for Cell 9.2 are missing: {missing_deps_for_9_2}. "
                    "Ensure all prior setup sections (LLM loading, GitHub config, helper functions) have run.")

  Using RAG Prompt Builder for Evaluation: build_rag_prompt_v6
  Using Baseline Prompt Builder for Evaluation: build_baseline_prompt_v6
  BM25 Params for Evaluation: K1=1.5, B=0.75, Top-K for prompt=5
  BM25 Tokenizer for Evaluation: eval_robust_code_tokenizer_s9
  Custom 'eval_llm_stopping_criteria' for evaluation run created successfully.
  LLM Config: MaxNew=1024, Temp=0.5, DoSample=True, StopCriteriaIsSet=Yes


In [None]:
from rank_bm25 import BM25Okapi
from tqdm.auto import tqdm
import json
import os
import re
import textwrap # For debug printing snippets

# --- 1. Verify All Necessary Configurations and Dependencies ---
# Quick guard only: the comprehensive checks are expected to have run in Cell 9.1.
if 'EVAL_RAG_PROMPT_BUILDER' not in globals() or 'generate_llm_code_and_clean' not in globals(): # Quick check
    raise NameError("Essential evaluation configurations or helper functions from 9.1 or earlier are missing.")

# --- 2. Select Dataset Subset for Evaluation ---
# The subset is selected first so that the banner below reports the number of samples that
# will really be processed (the split may hold fewer rows than NUM_EXAMPLES_TO_EVALUATE).
dataset_for_this_eval_run = lca_dataset_split.select(
    range(min(NUM_EXAMPLES_TO_EVALUATE, len(lca_dataset_split)))
)
actual_num_examples_being_evaluated = len(dataset_for_this_eval_run)

print(f"\n--- Starting Main Evaluation Loop for {actual_num_examples_being_evaluated} Samples ---")
if actual_num_examples_being_evaluated < NUM_EXAMPLES_TO_EVALUATE:
    print(f"  NOTE: {NUM_EXAMPLES_TO_EVALUATE} samples were requested but the dataset split "
          f"only contains {actual_num_examples_being_evaluated}.")

# --- 3. Initialize Lists for Storing All Evaluation Results ---
# One entry per evaluated sample is appended to each list, in the same order.
eval_all_baseline_outputs, eval_all_rag_outputs = [], []
eval_all_references, eval_all_reference_apis    = [], []

# --- 4. Initialize BM25 Index Cache ---
eval_bm25_index_cache = {} # {library_key: (bm25_index, valid_docs, map_idx_to_valid, tokenized_corpus)}

# --- 5. Main Evaluation Loop ---
import inspect  # stdlib; only this cell needs it (used to inspect the prompt builder's signature)


def _rag_builder_accepts_tokenizer(prompt_builder):
    """Decide whether the RAG prompt builder should receive the tokenizer as a 3rd positional argument.

    The tokenizer is passed only when BOTH conditions hold:
      1. the builder's code references the name ``truncate_to_n_tokens`` (the original heuristic), and
      2. the builder's signature can actually take a third positional argument
         (at least three positional parameters, or ``*args``).

    Condition 2 prevents the ``TypeError: ... takes 2 positional arguments but 3 were given``
    raised when a two-argument builder happens to reference ``truncate_to_n_tokens`` in its body.

    Args:
        prompt_builder: The callable used to build the RAG prompt.

    Returns:
        bool: True if the builder should be called as ``builder(instruction, snippets, tokenizer)``.
    """
    code_obj = getattr(prompt_builder, '__code__', None)
    if not (code_obj and "truncate_to_n_tokens" in code_obj.co_names):
        return False
    try:
        params = list(inspect.signature(prompt_builder).parameters.values())
    except (TypeError, ValueError):
        # Signature cannot be introspected: fall back to the safe two-argument call.
        return False
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        return True
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(p.kind in positional_kinds for p in params) >= 3


# The prompt builder does not change while the loop runs, so the decision is taken once up front.
eval_pass_tokenizer_flag = (
    'tokenizer' in globals() and _rag_builder_accepts_tokenizer(EVAL_RAG_PROMPT_BUILDER)
)

for eval_sample_idx, current_eval_sample in enumerate(tqdm(dataset_for_this_eval_run, desc="⚙️ Evaluating Samples")):
    eval_instruction_text = current_eval_sample.get("instruction")
    eval_clean_ref_code = current_eval_sample.get("clean_reference")
    eval_ref_api_list = current_eval_sample.get("unique_apis", [])
    eval_sample_repo_key = current_eval_sample.get('repo_full_name')

    # --- Initialize outputs for this sample to ensure append happens ---
    generated_baseline_code_eval = "" # Default to empty string
    generated_rag_code_eval = ""      # Default to empty string

    # Minimal print per sample to reduce log verbosity (roughly every 10% of the run, plus the last sample)
    if (eval_sample_idx + 1) % max(1, actual_num_examples_being_evaluated // 10) == 0 or \
       eval_sample_idx == actual_num_examples_being_evaluated - 1 :
        print(f"\n  Processing Eval Sample {eval_sample_idx + 1}/{actual_num_examples_being_evaluated}: Library '{eval_sample_repo_key}'")

    if not eval_instruction_text:
        print(f"    WARNING: Instruction missing for sample index {eval_sample_idx} (Lib: {eval_sample_repo_key}).")
        # Appending placeholders directly so that all four result lists stay aligned
        eval_all_baseline_outputs.append("")
        eval_all_rag_outputs.append("")
        eval_all_references.append(eval_clean_ref_code or "")
        eval_all_reference_apis.append(eval_ref_api_list or [])
        continue

    # --- A. Baseline Generation ---
    try:
        baseline_prompt_text_eval = EVAL_BASELINE_PROMPT_BUILDER(eval_instruction_text)
        generated_baseline_code_eval = generate_llm_code_and_clean(
            prompt_text=baseline_prompt_text_eval, llm_model=model, llm_tokenizer=tokenizer,
            max_new_tokens_gen=EVAL_MAX_NEW_TOKENS, do_sample_gen=EVAL_DO_SAMPLE,
            temperature_gen=EVAL_TEMPERATURE, top_p_gen=EVAL_TOP_P, top_k_gen=EVAL_TOP_K,
            repetition_penalty_gen=EVAL_REPETITION_PENALTY,
            stopping_criteria_list_gen=eval_llm_stopping_criteria, # From 9.1
            prompt_name=f"Baseline_Eval_{eval_sample_idx}_{eval_sample_repo_key}"
        )
    except Exception as e_base:
        print(f"    ERROR during baseline generation for sample {eval_sample_idx}: {type(e_base).__name__}: {e_base}")
        # generated_baseline_code_eval remains "" (or None if generate_llm_code_and_clean returned None on error)
    eval_all_baseline_outputs.append(generated_baseline_code_eval or "")


    # --- B. RAG Generation ---
    retrieved_snippets_block_for_rag = ""
    try:
        if eval_sample_repo_key:
            current_eval_bm25_index = None
            current_eval_valid_docs = []
            current_eval_map_bm25idx_to_validdocidx = []
            # Variable to hold the tokenized corpus used for building the index, for len check
            final_bm25_corpus_for_lib_eval_for_len_check = []


            if eval_sample_repo_key in eval_bm25_index_cache:
                current_eval_bm25_index, current_eval_valid_docs, current_eval_map_bm25idx_to_validdocidx, final_bm25_corpus_for_lib_eval_for_len_check = eval_bm25_index_cache[eval_sample_repo_key]
            else:
                kb_filename_for_eval = f"kb_{eval_sample_repo_key}.json"
                kb_url_for_eval = f"{GITHUB_RAW_CONTENT_BASE_URL}/{kb_filename_for_eval}"
                temp_kb_save_dir_for_eval_lib = os.path.join(LOCAL_TEMP_KB_DOWNLOAD_DIR, f"eval_kb_{eval_sample_repo_key}")
                kb_json_content_eval = download_github_raw_json(kb_url_for_eval, temp_kb_save_dir_for_eval_lib, kb_filename_for_eval, overwrite=False)

                if kb_json_content_eval and isinstance(kb_json_content_eval, list):
                    # Keep only non-blank string documents from the knowledge base
                    current_eval_valid_docs = [str(d) for d in kb_json_content_eval if isinstance(d, str) and str(d).strip()]
                    if current_eval_valid_docs:
                        tokenized_corpus_for_lib_eval = [EVAL_BM25_TOKENIZER(doc) for doc in current_eval_valid_docs]
                        # Documents that tokenize to nothing are left out of the BM25 corpus; the map
                        # records, for each BM25 row, the index of its document in current_eval_valid_docs.
                        for i_valid, doc_tokens_eval in enumerate(tokenized_corpus_for_lib_eval):
                            if doc_tokens_eval:
                                final_bm25_corpus_for_lib_eval_for_len_check.append(doc_tokens_eval)
                                current_eval_map_bm25idx_to_validdocidx.append(i_valid)

                        if final_bm25_corpus_for_lib_eval_for_len_check:
                            current_eval_bm25_index = BM25Okapi(final_bm25_corpus_for_lib_eval_for_len_check, k1=EVAL_BM25_K1, b=EVAL_BM25_B)
                            eval_bm25_index_cache[eval_sample_repo_key] = (
                                current_eval_bm25_index, current_eval_valid_docs,
                                current_eval_map_bm25idx_to_validdocidx,
                                final_bm25_corpus_for_lib_eval_for_len_check # Cache tokenized corpus
                            )

            if current_eval_bm25_index:
                query_tokens_for_rag_eval = EVAL_BM25_TOKENIZER(eval_instruction_text)
                if query_tokens_for_rag_eval:
                    # Use len(current_eval_bm25_index.doc_len) as corrected
                    num_docs_in_lib_bm25_index = len(current_eval_bm25_index.doc_len)
                    # The "documents" handed to get_top_n are the BM25 row indices themselves,
                    # so the call returns the indices of the best-scoring rows.
                    top_indices_from_lib_bm25 = current_eval_bm25_index.get_top_n(
                        query_tokens_for_rag_eval, list(range(num_docs_in_lib_bm25_index)),
                        n=min(EVAL_TOP_K_SNIPPETS_FOR_PROMPT, num_docs_in_lib_bm25_index)
                    )
                    retrieved_snippets_list_for_prompt = [current_eval_valid_docs[current_eval_map_bm25idx_to_validdocidx[i]] for i in top_indices_from_lib_bm25]
                    retrieved_snippets_block_for_rag = "\n\n# --- Snippet ---\n\n".join(retrieved_snippets_list_for_prompt)

                    # --- DIAGNOSTIC PRINTS FOR RAG (Uncomment to debug specific samples) ---
                    # if eval_sample_idx == 2 and "weihuayi__fealpy" in eval_sample_repo_key: # Example: For sample 2 if it's fealpy
                    #     print(f"      DEBUG RAG Sample {eval_sample_idx} ({eval_sample_repo_key}):")
                    #     print(f"        Num retrieved snippets: {len(retrieved_snippets_list_for_prompt)}")
                    #     total_chars = sum(len(s) for s in retrieved_snippets_list_for_prompt)
                    #     print(f"        Total chars in retrieved snippets: {total_chars}")
                    #     for i_debug, snip_debug in enumerate(retrieved_snippets_list_for_prompt):
                    #         print(f"        --- Debug Snippet {i_debug+1} (Len: {len(snip_debug)}) ---")
                    #         print(textwrap.shorten(snip_debug, width=200, placeholder="..."))
                    # --------------------------------------------------------------------

        # --- Assemble RAG Prompt (outside the BM25 block, uses retrieved_snippets_block_for_rag) ---
        # eval_pass_tokenizer_flag was computed once before the loop from the builder's signature.
        if eval_pass_tokenizer_flag:
            current_rag_prompt_text = EVAL_RAG_PROMPT_BUILDER(eval_instruction_text, retrieved_snippets_block_for_rag, tokenizer)
        else:
            current_rag_prompt_text = EVAL_RAG_PROMPT_BUILDER(eval_instruction_text, retrieved_snippets_block_for_rag)

        # --- DIAGNOSTIC PRINT FOR RAG PROMPT LENGTH (Uncomment to debug) ---
        # if eval_sample_idx == 2 and "weihuayi__fealpy" in eval_sample_repo_key:
        #     if current_rag_prompt_text:
        #         rag_prompt_tokens = tokenizer(current_rag_prompt_text, return_tensors="pt")['input_ids'].shape[1]
        #         print(f"      DEBUG RAG Sample {eval_sample_idx} ({eval_sample_repo_key}): Final RAG prompt input token length: {rag_prompt_tokens}")
        # --------------------------------------------------------------------

        generated_rag_code_eval = generate_llm_code_and_clean(
            prompt_text=current_rag_prompt_text, llm_model=model, llm_tokenizer=tokenizer,
            max_new_tokens_gen=EVAL_MAX_NEW_TOKENS, do_sample_gen=EVAL_DO_SAMPLE, temperature_gen=EVAL_TEMPERATURE,
            top_p_gen=EVAL_TOP_P, top_k_gen=EVAL_TOP_K, repetition_penalty_gen=EVAL_REPETITION_PENALTY,
            stopping_criteria_list_gen=eval_llm_stopping_criteria, prompt_name=f"RAG_Eval_{eval_sample_idx}_{eval_sample_repo_key}"
        )
    except Exception as e_rag:
        print(f"    ERROR during RAG processing for sample {eval_sample_idx}: {type(e_rag).__name__}: {e_rag}")
        # generated_rag_code_eval remains "" (or None if generate_llm_code_and_clean returned None)
    eval_all_rag_outputs.append(generated_rag_code_eval or "")


    # --- Store References ---
    eval_all_references.append(eval_clean_ref_code or "")
    eval_all_reference_apis.append(eval_ref_api_list or [])

# --- Final Sanity Check ---
# Every result list must hold exactly one entry per evaluated sample.
collected_list_sizes = (
    len(eval_all_baseline_outputs),
    len(eval_all_rag_outputs),
    len(eval_all_references),
    len(eval_all_reference_apis),
)
if all(size == actual_num_examples_being_evaluated for size in collected_list_sizes):
    print(f"\n✅ Successfully collected all outputs for {actual_num_examples_being_evaluated} evaluation examples.")
else:
    size_baseline, size_rag, size_refs, size_ref_apis = collected_list_sizes
    print("\n❌ CRITICAL ERROR: Length mismatch in final evaluation output lists!")
    print(f"  Expected: {actual_num_examples_being_evaluated} for all lists.")
    print(f"  Baseline outputs: {size_baseline}")
    print(f"  RAG outputs: {size_rag}")
    print(f"  References: {size_refs}")
    print(f"  Reference APIs: {size_ref_apis}")


--- Section 9.2: Starting Main Evaluation Loop for 3 Samples ---
  Will evaluate 3 samples.


⚙️ Evaluating Samples:   0%|          | 0/3 [00:00<?, ?it/s]


  Processing Eval Sample 1/3: Library 'seed-labs__seed-emulator'
  Starting LLM generation for: Baseline_Eval_0_seed-labs__seed-emulator
    LLM generation for 'Baseline_Eval_0_seed-labs__seed-emulator' finished in 66.23s.
    ERROR during RAG processing for sample 0: TypeError: build_rag_prompt_v6() takes 2 positional arguments but 3 were given

  Processing Eval Sample 2/3: Library 'weihuayi__fealpy'
  Starting LLM generation for: Baseline_Eval_1_weihuayi__fealpy
    LLM generation for 'Baseline_Eval_1_weihuayi__fealpy' finished in 58.90s.
    ERROR during RAG processing for sample 1: TypeError: build_rag_prompt_v6() takes 2 positional arguments but 3 were given

  Processing Eval Sample 3/3: Library 'weihuayi__fealpy'
  Starting LLM generation for: Baseline_Eval_2_weihuayi__fealpy
    LLM generation for 'Baseline_Eval_2_weihuayi__fealpy' finished in 52.44s.
    ERROR during RAG processing for sample 2: TypeError: build_rag_prompt_v6() takes 2 positional arguments but 3 were given



In [None]:
"""
# ================================================================
#  Metrics helpers – with CodeBLEU support & automatic key-detection
# ================================================================
import importlib, warnings, re
from importlib.metadata import version as _get_version, PackageNotFoundError
import numpy as np

# ─── sacrebleu ─────────────────────────────────────────────────
import sacrebleu
print("✅ sacrebleu", sacrebleu.__version__)

# ─── codebleu ─────────────────────────────────────────────────
try:
    from codebleu import calc_codebleu
    try:
        cb_ver = _get_version("codebleu")
    except PackageNotFoundError:
        cb_ver = "n/a"
    print("✅ codebleu", cb_ver)
    _HAS_CODEBLEU = True
except ImportError:
    print("⚠️  codebleu import failed — CodeBLEU will be skipped.")
    _HAS_CODEBLEU = False

!pip install -q --upgrade tree_sitter tree_sitter_python
!pip install -q git+https://github.com/k4black/codebleu.git
# run this in a fresh cell *before* any CodeBLEU import
!pip install -q --upgrade "tree_sitter<0.23" "tree_sitter_python<0.23"
"""

# Reinstall CodeBLEU: remove the installed release, upgrade the tree-sitter packages,
# then install CodeBLEU straight from its GitHub repository.
# These are shell commands (IPython `!` syntax); none of the versions are pinned here.
!pip uninstall -yq codebleu            # throw away 0.7.0
!pip install -q --upgrade tree_sitter tree_sitter_python  # stays at 0.24+
!pip install -q git+https://github.com/k4black/codebleu.git  # 0.7.1-dev



[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/575.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m575.6/575.6 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m567.6/567.6 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for codebleu (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [codebleu]
[1A[2K

In [None]:
from codebleu import calc_codebleu
import sacrebleu
import importlib.metadata as md
# Report the installed CodeBLEU version as recorded in the package metadata.
print("Now running CodeBLEU", md.version("codebleu"))
# Flag checked by the CodeBLEU helpers in this cell. It is set unconditionally:
# reaching this line means the `from codebleu import calc_codebleu` above did not raise.
_HAS_CODEBLEU = True

# ────────────────────────────────────────────────────────────────
#  Composite key auto-detector (runs once on your first example)
# ────────────────────────────────────────────────────────────────
# Name of the composite-score entry in calc_codebleu's result; filled in by _detect_codebleu_key.
_codebleu_key = None
def _detect_codebleu_key(sample_pred, sample_ref):
    """Find which key of the CodeBLEU result holds the composite score.

    Scores one (prediction, reference) pair as Python code with equal weights,
    prints the keys of the returned mapping, and stores the first key whose name
    contains "codebleu" (case-insensitive) in the module-level `_codebleu_key`.

    Args:
        sample_pred: Predicted code passed to `calc_codebleu`.
        sample_ref: Reference code passed to `calc_codebleu`.

    Returns:
        The current value of `_codebleu_key` (unchanged if no key matched),
        or None when `_HAS_CODEBLEU` is falsy.
    """
    global _codebleu_key
    if _HAS_CODEBLEU:
        scores = calc_codebleu(
            references=[[sample_ref]],
            predictions=[sample_pred],
            lang="python",
            weights=(0.25,0.25,0.25,0.25)
        )
        print("🔍 CodeBLEU raw result keys:", list(scores.keys()))
        # First matching key wins; the global is left untouched when nothing matches.
        matching_key = next((name for name in scores if "codebleu" in name.lower()), None)
        if matching_key is not None:
            _codebleu_key = matching_key
        return _codebleu_key
    return None

# ────────────────────────────────────────────────────────────────
#  Metric functions
# ────────────────────────────────────────────────────────────────
def calculate_chrf(pred, ref):
    """Compute the ChrF score of one prediction against one reference.

    Args:
        pred: Generated text.
        ref: Reference text.

    Returns:
        None if either argument is not a string; 0.0 if either string is empty;
        otherwise the `.score` of `sacrebleu.corpus_chrf` for the single pair.
    """
    both_are_strings = isinstance(pred, str) and isinstance(ref, str)
    if not both_are_strings:
        return None
    if pred and ref:
        return sacrebleu.corpus_chrf([pred], [[ref]]).score
    return 0.0

def calculate_codebleu(pred, ref, lang="python", weights=(0.25,0.25,0.25,0.25)):
    """Compute the composite CodeBLEU score of one prediction against one reference.

    The name of the composite-score key is detected from the first successful
    `calc_codebleu` result and remembered in the module-level `_codebleu_key`.
    Detection reuses the result computed for this call, so the first pair is
    scored only once and with the caller's `lang` and `weights`.

    Args:
        pred: Generated code.
        ref: Reference code.
        lang: Language name forwarded to `calc_codebleu`.
        weights: Component weights forwarded to `calc_codebleu`.

    Returns:
        The score as a float; 0.0 (with a warning) if no key containing "codebleu"
        exists in the result; None if CodeBLEU is unavailable, an argument is not a
        string, or `calc_codebleu` raises (a warning is emitted in the last case).
    """
    if not _HAS_CODEBLEU:
        return None
    if not (isinstance(pred, str) and isinstance(ref, str)):
        return None

    global _codebleu_key
    try:
        res = calc_codebleu(
            references=[[ref]],
            predictions=[pred],
            lang=lang,
            weights=weights
        )
        # Detect the key on the first successful call. Doing it inside the try block means a
        # CodeBLEU failure on the very first example is reported as a warning like any other
        # example instead of aborting the whole evaluation.
        if _codebleu_key is None:
            print("🔍 CodeBLEU raw result keys:", list(res.keys()))
            _codebleu_key = next((k for k in res if "codebleu" in k.lower()), None)
            if _codebleu_key is None:
                warnings.warn("Could not find a CodeBLEU key in the result; returning 0.0")
                return 0.0
        return float(res.get(_codebleu_key, 0.0))
    except Exception as e:
        warnings.warn(f"CodeBLEU failed for one example: {e}")
        return None

def calculate_api_recall(gen, ref_apis):
    """Fraction of reference API names that occur in the generated code.

    A name counts as present when it appears in `gen` delimited by regex word
    boundaries. Entries of `ref_apis` that are not strings or are blank are ignored.

    Args:
        gen: Generated code.
        ref_apis: List of reference API names.

    Returns:
        hits / number of usable reference names, or 0.0 when `gen` is not a
        non-empty string, `ref_apis` is not a list, or no usable name remains.
    """
    if not (isinstance(gen, str) and isinstance(ref_apis, list)):
        return 0.0

    usable_names = []
    for candidate in ref_apis:
        if isinstance(candidate, str) and candidate.strip():
            usable_names.append(candidate)

    if not (gen and usable_names):
        return 0.0

    found = 0
    for api_name in usable_names:
        if re.search(rf"\b{re.escape(api_name)}\b", gen):
            found += 1
    return found / len(usable_names)


# ────────────────────────────────────────────────────────────────
#  Evaluation driver (baseline vs RAG)
# ────────────────────────────────────────────────────────────────
def _mean(fn, preds, gts):
    """Mean of fn(pred, gt) over a list of pairs."""
    vals = [fn(p, g) for p, g in zip(preds, gts)]
    vals = [v for v in vals if v is not None]
    return np.mean(vals) if vals else 0.0

def evaluate(baseline_preds, rag_preds, refs, ref_api_lists):
    """Score baseline and RAG predictions and print a comparison table.

    For each metric (API recall, ChrF, CodeBLEU) the mean over all samples is
    computed first for the baseline predictions and then for the RAG predictions.

    Args:
        baseline_preds: Baseline generations, one per sample.
        rag_preds: RAG generations, one per sample.
        refs: Reference code strings, one per sample.
        ref_api_lists: Reference API-name lists, one per sample.

    Returns:
        dict with keys "recall", "chrf" and "codebleu", each mapping to a
        (baseline_mean, rag_mean) tuple.
    """
    def _both_systems(metric_fn, targets):
        # Baseline is scored before RAG, matching the order of the table columns.
        baseline_mean = _mean(metric_fn, baseline_preds, targets)
        rag_mean = _mean(metric_fn, rag_preds, targets)
        return (baseline_mean, rag_mean)

    recall_pair = _both_systems(calculate_api_recall, ref_api_lists)
    chrf_pair = _both_systems(calculate_chrf, refs)
    codebleu_pair = _both_systems(calculate_codebleu, refs)

    table_rows = (
        "\n--- Risultati Metriche Automatiche ---",
        f"| Metrica    | Baseline |   RAG   |",
        f"|------------|----------|---------|",
        f"| API Recall | {recall_pair[0]:.4f}   | {recall_pair[1]:.4f}   |",
        f"| ChrF       | {chrf_pair[0]:.2f}    | {chrf_pair[1]:.2f}    |",
        f"| CodeBLEU   | {codebleu_pair[0]:.2f}    | {codebleu_pair[1]:.2f}    |",
        "--------------------------------------------------",
    )
    for row in table_rows:
        print(row)


    return {
        "recall":   recall_pair,
        "chrf":     chrf_pair,
        "codebleu": codebleu_pair,
    }


# Score the lists filled by the main evaluation loop (Cell 9.2). The previous call used
# names that loop never creates, which ended in "NameError: name 'references' is not defined".
metrics = evaluate(
    eval_all_baseline_outputs,
    eval_all_rag_outputs,
    eval_all_references,
    eval_all_reference_apis,
)


Now running CodeBLEU 0.7.1


NameError: name 'references' is not defined

### Prompt 1
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0000   | 0.0000   |
| ChrF       | 17.05    | 53.43    |
| CodeBLEU   | 0.11    | 0.43    |



---



### Prompt 2
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0000   | 0.0000   |
| ChrF       | 25.75    | 53.43    |
| CodeBLEU   | 0.21    | 0.43    |

---

### Prompt 3
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0000   | 0.0000   |
| ChrF       | 23.09    | 53.15    |
| CodeBLEU   | 0.16    | 0.43    |

---

### Prompt 4 (corretto API)
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0228   | 0.7427   |
| ChrF       | 17.32    | 53.15    |
| CodeBLEU   | 0.13    | 0.43    |

---

### Prompt 5
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0265   | 0.6082   |
| ChrF       | 21.14    | 49.75    |
| CodeBLEU   | 0.16    | 0.41    |

---
### Prompt 5 - REDO
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0121   | 0.1393   |
| ChrF       | 9.78    | 11.26    |
| CodeBLEU   | 0.13    | 0.14    |


---
### Prompt 6
| Metrica    | Baseline |   RAG   |
|------------|----------|---------|
| API Recall | 0.0000   | 0.0861   |
| ChrF       | 9.38    | 11.52    |
| CodeBLEU   | 0.13    | 0.14    |

In [None]:
def count_valid(fn, preds, refs):
    """Return the number of (pred, ref) pairs for which fn(pred, ref) is not None."""
    valid_pairs = 0
    for pred, ref in zip(preds, refs):
        if fn(pred, ref) is not None:
            valid_pairs += 1
    return valid_pairs

# Sample counts behind each baseline metric, read from the lists filled by the main
# evaluation loop (Cell 9.2), which stores its results under the eval_all_* names.

# For API-Recall we never return None, so it's simply the full length:
n_api = len(eval_all_baseline_outputs)

# For ChrF & CodeBLEU we drop any None's:
n_chrf     = count_valid(calculate_chrf,     eval_all_baseline_outputs, eval_all_references)
n_codebleu = count_valid(calculate_codebleu, eval_all_baseline_outputs, eval_all_references)

print(f"API Recall was computed on {n_api} samples")
print(f"ChrF       was computed on {n_chrf} samples")
print(f"CodeBLEU   was computed on {n_codebleu} samples")


In [None]:
# Same counts for the RAG generations, read from the lists filled by the main
# evaluation loop (Cell 9.2), which stores its results under the eval_all_* names.
n_chrf_rag     = count_valid(calculate_chrf,     eval_all_rag_outputs, eval_all_references)
n_codebleu_rag = count_valid(calculate_codebleu, eval_all_rag_outputs, eval_all_references)
print(f"(RAG) ChrF       on {n_chrf_rag} samples")
print(f"(RAG) CodeBLEU   on {n_codebleu_rag} samples")


## Section 10 · Example result


In [None]:
# ──────────────────────────────────────────────────────────────────
# Display a sample: task + reference + baseline + RAG side by side
# ──────────────────────────────────────────────────────────────────
from IPython.display import display, Markdown
import textwrap

EXAMPLE_INDEX = 7  # Change this to any index within your dataset size

# Only the first NUM_EXAMPLES_TO_EVALUATE samples have outputs; fail with a readable message
# instead of a bare "list index out of range" when the index points past them.
if EXAMPLE_INDEX >= len(eval_all_references):
    raise IndexError(
        f"EXAMPLE_INDEX={EXAMPLE_INDEX} is out of range: "
        f"only {len(eval_all_references)} samples were evaluated."
    )

# Values come from the lists filled by the main evaluation loop (Cell 9.2). That loop evaluates
# the first rows of lca_dataset_split in order, so the same index addresses the same sample.
task       = eval_all_references[EXAMPLE_INDEX]        # Gold reference code (cleaned)
baseline   = eval_all_baseline_outputs[EXAMPLE_INDEX]  # Generated from instruction only
rag        = eval_all_rag_outputs[EXAMPLE_INDEX]       # Generated with RAG prompt
instruction = lca_dataset_split[EXAMPLE_INDEX]['instruction']  # Original task (English)

def print_block(title, content):
    """Print `title`, then `content` (dedented and stripped) framed by two 10-dash rules and a trailing blank line."""
    rule = "-" * 10
    body = textwrap.dedent(content).strip()
    print("\n".join([f"{title}", rule, body, rule, ""]))




In [None]:
# Show the task instruction of the selected example between two 40-character rules.
print(
    "Instruction",
    "=" * 40,
    textwrap.dedent(instruction).strip(),
    "=" * 40,
    sep="\n",
)

In [None]:
# Show the gold reference code of the selected example between two 40-character rules.
print(
    "✅ Gold Reference",
    "=" * 40,
    textwrap.dedent(task).strip(),
    "=" * 40,
    sep="\n",
)

In [None]:
# Show the baseline generation of the selected example.
# `baseline` is the value picked for EXAMPLE_INDEX in the selection cell of this section;
# the previous `generated_code_baseline` is not tied to that index.
print("Baseline Output")
print("=" * 40)
print(textwrap.dedent(baseline).strip())
print("=" * 40)

In [None]:
# Show the RAG generation of the selected example between two 40-character rules.
print(
    "RAG Output",
    "=" * 40,
    textwrap.dedent(rag).strip(),
    "=" * 40,
    sep="\n",
)


In [None]:
# Index of the sample whose baseline and RAG outputs are printed in the next two cells.
example_idx = 0

In [None]:
# Baseline generation for `example_idx`, read from the list filled by the main
# evaluation loop (Cell 9.2), which stores its results under the eval_all_* names.
print("Baseline Output")
print("="*40)
print(textwrap.dedent(eval_all_baseline_outputs[example_idx]).strip())
print("="*40)

In [None]:
# RAG generation for `example_idx`, read from the list filled by the main
# evaluation loop (Cell 9.2), which stores its results under the eval_all_* names.
print("RAG Output")
print("="*40)
print(textwrap.dedent(eval_all_rag_outputs[example_idx]).strip())
print("="*40)


In [None]:
# Ad-hoc smoke test: build a RAG prompt from a toy instruction and a toy snippet, generate, and print the result.
# NOTE(review): `generate_code` and `build_rag_prompt` are not defined in this part of the notebook —
# confirm they still exist (the evaluation above uses `generate_llm_code_and_clean` and `build_rag_prompt_v6`).
print(generate_code(build_rag_prompt("reverse a string", "def foo(): pass")))