## Notebook Overview – Retrieval‑Augmented Code Generation


> **Inspiration**  
> This notebook re-implements the three-stage **Retrieval-Augmented Generation (RAG)** workflow studied in  
> *“An Empirical Study of Retrieval-Augmented Code Generation: Challenges and Opportunities”* (Yang et al., 2025).  
> By feeding an LLM code snippets retrieved on the fly from a project-specific knowledge base, the workflow narrows the
> semantic gap between natural-language instructions and the final source code.



### Three-Phase Design (as in the paper)

| Phase | One-line description (what the notebook does) | Paper counterpart |
|-------|-----------------------------------------------|-------------------|
| **Retrieval** | Build a snippet KB from repo sources → index with **BM25** → fetch Top-*k* snippets for each user query. | Comparative retriever study; BM25 emerges as the simplest & strongest. |
| **Fusion** | Pack snippets + instruction into a **Snippet-Integration Format** (SIF); other fusion modes (Sample Expansion, Vectorised Decoding, Sketch Filling) are optional. | Evaluation of four fusion strategies; SIF ≅ *Sequential Integration Fusion* (best trade-off). |
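| **Generation** | Run a 4-bit quantised causal LLM (DeepSeek Coder by default; DeepSeek-R1-Distill, Qwen2.5-Coder, CodeGemma, Code Llama and Phi models are selectable) with shared decoding settings and a custom stop rule (EOS or closing ``` block). | Measure how retrieved context boosts vanilla pre-trained models (the paper evaluates CodeGen, UniXcoder and CodeT5). |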
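
The Retrieval and Fusion rows can be illustrated with a small, self-contained sketch. It implements Okapi BM25 scoring from scratch with the usual defaults (k1 = 1.5, b = 0.75) and a common non-negative IDF variant, so its scores differ slightly from `rank_bm25.BM25Okapi`, which the notebook actually uses. The regex tokenizer, the helper names (`bm25_scores`, `top_k_snippets`, `build_sif_prompt`), the toy knowledge base and the prompt layout are hypothetical illustrations, not the notebook's real SIF template.

```python
import math
import re
from collections import Counter
from typing import List

FENCE = "`" * 3  # built programmatically to avoid clashing with Markdown fences


def tokenize(text: str) -> List[str]:
    # Hypothetical tokenizer: lowercase, keep runs of letters/digits/underscores.
    return re.findall(r"[a-z0-9_]+", text.lower())


def bm25_scores(query: str, docs: List[str], k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Okapi BM25 score of every document for the query (non-negative IDF variant)."""
    tokenized = [tokenize(d) for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs if n_docs else 0.0
    # Document frequency: number of documents containing each term.
    df = Counter(term for toks in tokenized for term in set(toks))
    query_terms = set(tokenize(query))
    scores = []
    for toks in tokenized:
        tf = Counter(toks)
        dl = len(toks)
        score = 0.0
        for term in query_terms:
            if term not in tf:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * dl / avgdl)  # length-normalised term frequency
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


def top_k_snippets(query: str, snippets: List[str], k: int = 3) -> List[str]:
    """Return the k highest-scoring snippets, best first."""
    scores = bm25_scores(query, snippets)
    ranked = sorted(range(len(snippets)), key=lambda i: scores[i], reverse=True)
    return [snippets[i] for i in ranked[:k]]


def build_sif_prompt(instruction: str, snippets: List[str]) -> str:
    """Sequential integration: retrieved snippets first, then the instruction, then an open fence."""
    parts = ["# Relevant code snippets from the repository:"]
    for i, snippet in enumerate(snippets, 1):
        parts.append(f"# Snippet {i}\n{snippet}")
    parts.append(f"# Task: {instruction}\n{FENCE}python\n")
    return "\n\n".join(parts)


# Toy knowledge base (hypothetical snippets in the style of the seed-emulator repo).
kb = [
    "def addLayer(self, layer): self.layers.append(layer)",
    "class Ospf(Layer): pass",
    "def createHost(self, name): return Node(name)",
]
best = top_k_snippets("add an Ospf layer", kb, k=2)
prompt = build_sif_prompt("Add an OSPF layer to the emulator.", best)
print(prompt)
```

Placing the snippets before the instruction and ending the prompt with an opening fence is one simple way to nudge a base model into continuing with code; the exact layout is a design choice.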

---

### Walk-Through of Notebook Sections

| # | Colab section | What happens in **one sentence** |
|---|---------------|----------------------------------|
| 0 | **Google Drive Mounting** | Mounts Drive to save models, KBs and results persistently. |
| 1 | **Environment Setup** | Installs required libraries (transformers, datasets, BM25, CodeBLEU, etc.). |
| 2 | **LLM + Tokenizer Loading (4-bit)** | Downloads and quantises the chosen model to fit GPU memory. |
| 3 | **Dataset Preparation & Validation** | Loads the LCA test split and prints sanity-check stats. |
| 4 | **Repository Download & Prep** | Downloads the library archive from the HF Hub and places it in `/content`. |
| 5 | **Source Extraction → KB Build** | Unpacks the archive, parses every .py file, extracts functions/classes and stores ≤ 15 k snippets. |
| 6 | **BM25 Retrieval & Prompt Assembly** | Tokenises KB + query, ranks snippets, and assembles the SIF prompt. |
| 7 | **RAG Code Generation & Post-processing** | Feeds the SIF prompt to the LLM, decodes output, cleans first ```python``` block. |
| 8 | **Baseline Generation & Comparison** | Generates code with the *same* LLM but **without** snippets; stores both outputs. |
| 9 | **Metric Calculation** | Computes BLEU, CodeBLEU, ChrF, API-Recall, Edit-Distance; logs wins / losses. |
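
Two of the steps above are easy to misread: keeping only the first fenced code block of the model output (Section 7) and the API-Recall / Edit-Distance metrics (Section 9). The sketch below shows one plausible reading of each in plain Python. The definitions are assumptions: API-Recall is taken as the fraction of the reference's `unique_apis` that the generated code calls, and Edit-Distance as character-level Levenshtein distance; the notebook's own implementations may differ (for example, normalised or token-level). All helper names are hypothetical.

```python
import ast
import re
from typing import Iterable, Optional, Set

FENCE = "`" * 3  # built programmatically to avoid clashing with Markdown fences
_BLOCK_RE = re.compile(
    re.escape(FENCE) + r"[^\n]*\n(.*?)(?:" + re.escape(FENCE) + r"|\Z)", re.DOTALL
)


def extract_first_python_block(text: str) -> str:
    """Body of the first fenced block (an unclosed block runs to the end); raw text if none."""
    match = _BLOCK_RE.search(text)
    return match.group(1).strip() if match else text.strip()


def called_names(code: str) -> Set[str]:
    """Names of functions/methods called in the code: foo(...) and obj.foo(...)."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Generated code may not parse; fall back to a rough regex.
        return set(re.findall(r"([A-Za-z_]\w*)\s*\(", code))
    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                names.add(node.func.id)
            elif isinstance(node.func, ast.Attribute):
                names.add(node.func.attr)
    return names


def api_recall(generated: str, reference_apis: Iterable[str]) -> Optional[float]:
    """Share of reference APIs that the generated code calls (None if there are none)."""
    apis = set(reference_apis)
    if not apis:
        return None
    return len(apis & called_names(generated)) / len(apis)


def levenshtein(a: str, b: str) -> int:
    """Character-level edit distance using a rolling single-row DP table."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]


raw = "Here you go:\n" + FENCE + "python\nemu = Emulator()\nemu.addLayer(Ospf())\n" + FENCE + "\nDone."
code = extract_first_python_block(raw)
print(api_recall(code, ["addLayer", "Ospf", "createHost"]))  # 0.6666666666666666
print(levenshtein("kitten", "sitting"))  # 3
```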

### Alignment with the Paper

* **BM25** is the default retriever, matching the empirical winner reported by the authors.  
* **Sequential Integration Fusion** (SIF) is the primary fusion baseline.  
* **Metrics** (BLEU, CodeBLEU, ChrF, structural & API measures) mirror those in the study.  
* **Failure logging** reproduces the paper’s “Direction 2” analysis of RAG shortcomings.




## Section 0: Google Drive Mounting



This cell connects the Google Colab runtime to the user's *Google Drive*, establishing **persistent storage**.

The primary **rationale** is to enable saving and loading of critical project components like *models*, *datasets*, *indexes*, and *results*, ensuring work continuity.

In [1]:
# Import the 'drive' module from google.colab library.
from google.colab import drive

# Mount Google Drive at the '/content/drive' path in the Colab filesystem.
# Requires user authorization.
drive.mount('/content/drive')

# Confirm successful mounting.
print("Drive mounted successfully.")

Mounted at /content/drive
Drive mounted successfully.
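
The cell above assumes it runs inside Colab; elsewhere `google.colab` cannot be imported and the cell fails. A minimal sketch that mounts Drive when it is available and otherwise falls back to a local folder is shown below. The function name `resolve_storage_root` and the fallback directory `./rag_project_local` are hypothetical.

```python
import os


def resolve_storage_root(colab_mount_point: str = "/content/drive",
                         local_fallback: str = "./rag_project_local") -> str:
    """Mount Google Drive when running in Colab; otherwise use a local directory."""
    try:
        from google.colab import drive  # only importable inside a Colab runtime
    except ImportError:
        os.makedirs(local_fallback, exist_ok=True)
        return os.path.abspath(local_fallback)
    drive.mount(colab_mount_point)  # prompts for authorization, as in the cell above
    return os.path.join(colab_mount_point, "MyDrive")


storage_root = resolve_storage_root()
print(f"Persistent storage root: {storage_root}")
```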


## Section 1: Environment Setup

This section establishes the foundational Python environment for our Retrieval-Augmented Generation (RAG) project.
Key activities include:
1.  **Dependency Installation**: Installing all required external libraries.
2.  **Module Imports**: Loading the necessary Python modules and classes.
3.  **Parser Configuration**: Setting up the `tree-sitter` parser for Python, which CodeBLEU needs for its syntax (AST) and data-flow matching.

### 1.1 Installing Core Libraries

We install dependencies with `pip`. Note that the cell below does not pin any versions. `codebleu` in particular can be sensitive to the installed `tree-sitter` version, so pinning it and its `tree-sitter` dependencies is advisable for consistent behavior and reproducibility across environments.

**Key Libraries & Purpose:**
*   **`transformers`, `datasets`, `huggingface_hub`**: Hugging Face ecosystem for models, datasets, and hub interaction.
*   **`torch`, `accelerate`, `bitsandbytes`**: PyTorch (installed from the CUDA 12.1 wheel index) and tools for efficient model execution, including 4-bit quantization.
*   **`rank_bm25`**: BM25 algorithm for lexical retrieval.
*   **`sacrebleu`, `codebleu`, `tree-sitter-python`, `tree-sitter-languages`**: Code evaluation metrics and their parsing dependencies (the core `tree-sitter` package is pulled in as a dependency).
*   **`fsspec`**: Filesystem abstraction used by `datasets` (upgraded explicitly; pip's resolver may still adjust its version to satisfy `datasets`).

In [2]:
print("Installing fundamental libraries (attempt 1.1 - updating transformers and hf_hub)...")
!pip install -U --no-cache-dir \
    transformers \
    datasets \
    fsspec \
    huggingface_hub \
    accelerate \
    bitsandbytes

!pip install --no-cache-dir \
    torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

!pip install --no-cache-dir \
    rank_bm25 \
    sacrebleu \
    tree-sitter-python \
    tree-sitter-languages \
    codebleu

print("Fundamental libraries (attempt 1.1) installation attempt finished.")

Installing fundamental libraries (attempt 1.1 - updating transformers and hf_hub)...
Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting accelerate
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting fsspec
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu
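
Because the install cell does not pin versions, recording what was actually installed is the cheapest way to make a run reproducible. A small sketch using the standard-library `importlib.metadata`; the package list mirrors the install commands above, and the helper name `installed_versions` is hypothetical.

```python
import json
from importlib import metadata
from typing import Dict, Iterable, Optional

PACKAGES = [
    "transformers", "datasets", "fsspec", "huggingface_hub", "accelerate",
    "bitsandbytes", "torch", "rank_bm25", "sacrebleu", "tree-sitter-python",
    "tree-sitter-languages", "codebleu",
]


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version (None if not installed)."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


versions = installed_versions(PACKAGES)
# The JSON dump could also be written next to the results on Drive.
print(json.dumps(versions, indent=2))
```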

### 1.2 Importing Python Modules

We import standard Python libraries and specific modules from the newly installed packages.

**Module Groups & Purpose:**
*   **Standard Library**: `os`, `sys`, `time`, `warnings`, `json`, `re`, etc., for general utility.
*   **Type Hinting**: From `typing` for improved code readability and static analysis.
*   **PyTorch**: `torch` for tensor operations.
*   **Hugging Face `transformers`**:
    *   `AutoTokenizer`, `AutoModelForCausalLM`: For loading models and tokenizers.
    *   `BitsAndBytesConfig`: For quantization configuration.
    *   `PreTrainedTokenizer`, `PreTrainedTokenizerFast`: For tokenizer type hints.
    *   `StoppingCriteria`, `StoppingCriteriaList`: For custom generation stopping conditions.
*   **Hugging Face `datasets`**: `load_dataset`, `Dataset` (class for type hint), `DownloadMode`.
*   **Hugging Face `huggingface_hub`**: `hf_hub_download`, `EntryNotFoundError` (imported as `HfEntryNotFoundError`).
*   **Retrieval**: `BM25Okapi` from `rank_bm25`.
*   **Progress Visualization**: `tqdm` (optional, with graceful fallback).
*   **Evaluation Metrics**: `sacrebleu` and `calc_codebleu` (imports are failure-tolerant).

In [3]:
import os
import sys
import time
import warnings
import json
import re
import shutil
import tarfile
import ast
import random
import textwrap

from typing import List, Optional, Dict, Any, Tuple, Union

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList
)
from datasets import load_dataset, Dataset, DownloadMode
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError as HfEntryNotFoundError # Alias for clarity

from rank_bm25 import BM25Okapi

try:
    from tqdm.auto import tqdm
    USE_TQDM = True
except ImportError:
    USE_TQDM = False
    print("Warning: 'tqdm' library not found. Progress bars will not be shown.")
    print("You can install it with: !pip install -q tqdm")

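# Evaluation metrics: imported defensively so the notebook can still run without them
try:
    import sacrebleu
except ImportError:
    sacrebleu = None
    print("Warning: 'sacrebleu' not available; BLEU/ChrF cannot be computed.")

try:
    from codebleu import calc_codebleu
except Exception as e:  # tree-sitter version mismatches may surface as errors other than ImportError
    calc_codebleu = None
    print(f"Warning: 'codebleu' could not be imported ({e}); CodeBLEU cannot be computed.")

# Filter out less critical warnings for a cleaner output
warnings.filterwarnings("ignore", category=UserWarning, module="accelerate")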
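
The `StoppingCriteria` import supports the custom stop rule mentioned in the overview (stop at EOS or at the closing fence of the code block). EOS is already handled by `generate()` through the model's `eos_token_id`; the fence check can be sketched as below. The class `FenceStoppingCriteria`, the helper `code_block_closed` and the `prompt_opens_fence` flag are hypothetical. Only newly generated tokens are decoded, because the SIF prompt itself may contain fences from the retrieved snippets, and the criterion returns one boolean per sequence, which is what recent `transformers` versions expect.

```python
from typing import Any

FENCE = "`" * 3  # built programmatically to avoid clashing with Markdown fences


def code_block_closed(generated_text: str, prompt_opens_fence: bool = True) -> bool:
    """Return True once the generated text has closed its code block.

    If the prompt already ends with an opening fence, the first fence in the
    generated text closes the block; otherwise two fences (open + close) are needed.
    """
    needed = 1 if prompt_opens_fence else 2
    return generated_text.count(FENCE) >= needed


try:
    import torch
    from transformers import StoppingCriteria

    class FenceStoppingCriteria(StoppingCriteria):
        """Hypothetical criterion: stop each sequence once its code block is closed."""

        def __init__(self, tokenizer: Any, prompt_length: int, prompt_opens_fence: bool = True):
            self.tokenizer = tokenizer
            self.prompt_length = prompt_length  # prompt tokens to skip when decoding
            self.prompt_opens_fence = prompt_opens_fence

        def __call__(self, input_ids, scores, **kwargs):
            # Decode only the newly generated tokens of every sequence in the batch.
            done = [
                code_block_closed(
                    self.tokenizer.decode(seq[self.prompt_length:], skip_special_tokens=True),
                    self.prompt_opens_fence,
                )
                for seq in input_ids
            ]
            return torch.tensor(done, dtype=torch.bool, device=input_ids.device)
except ImportError:
    FenceStoppingCriteria = None  # torch/transformers not installed

# Hypothetical usage with the notebook's tokenizer/model:
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# stopper = FenceStoppingCriteria(tokenizer, prompt_length=inputs["input_ids"].shape[1])
# output = model.generate(**inputs, max_new_tokens=512,
#                         stopping_criteria=StoppingCriteriaList([stopper]))
```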

## Section 2: LLM and Tokenizer Loading with 4-bit Quantization

This section loads the selected LLM and its tokenizer, applying **4-bit quantization** to keep memory usage within the limits of the available GPU. The loading logic works as follows:

1.  **Model Selection:** A `model_name` variable selects the pre-trained model to use (CodeGemma, Qwen, DeepSeek, Code Llama, Phi); choose one by uncommenting exactly one `model_name` line. The choice determines the model's architecture, size, and capabilities.

2.  **`trust_remote_code` Configuration:** Determines whether the selected model needs custom modelling code hosted on the Hugging Face Hub (`trust_remote_code=True`). The flag is set automatically from a list of model prefixes (e.g., `Qwen/`, `microsoft/Phi-`) and stays `False` for every other model, since enabling it executes code downloaded from the Hub. *A warning is issued when it is set to `True`*.

3.  **Quantization Setup (`BitsAndBytesConfig`):**
    *   **Goal:** Reduce the model's memory footprint, enabling larger models to run on the available hardware.
    *   **Method:** Configures **4-bit NF4 quantization** using `BitsAndBytesConfig`.
    *   **Compute Type:** Dynamically selects the compute data type (`torch.bfloat16` if supported, otherwise `torch.float16`) for optimal performance during quantized operations.
    *   **Details:** Uses `load_in_4bit=True`, `bnb_4bit_quant_type="nf4"` (a common and effective quantization type), and `bnb_4bit_use_double_quant=True` for further memory savings.

4.  **Tokenizer and Model Loading:**
    *   Loads the appropriate `tokenizer` using `AutoTokenizer.from_pretrained`, passing the `trust_remote_code` flag. It includes a *crucial check* to set `tokenizer.pad_token_id` to `tokenizer.eos_token_id` if it's missing, a common requirement for causal LMs during batch processing or generation.
    *   Loads the `model` using `AutoModelForCausalLM.from_pretrained`.
        *   Applies the `quantization_config`.
        *   Uses `device_map="auto"` to automatically distribute model layers across available devices (GPU/CPU), essential for large models.
        *   Sets `low_cpu_mem_usage=True` to minimize RAM usage during model loading.
        *   Updates `model.config.pad_token_id` to match the tokenizer's setting if necessary.

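5.  **Error Handling:** `try...except` blocks give informative messages for the common failures: a model that cannot be found on the Hub, a CUDA out-of-memory error (`torch.cuda.OutOfMemoryError`), and quantization-related `ValueError`s. Any other exception (e.g., an `ImportError` from a missing `bitsandbytes`) is caught by a generic handler that prints the full traceback.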

6.  **Verification:** The final status block prints a `SUCCESS` banner when both the tokenizer and the model are loaded, and a `CRITICAL FAILURE` message otherwise.

In [4]:
# Select exactly one model: uncomment its line and keep all others commented out

# --- Gemma Series (Google) ---
# model_name = "google/codegemma-7b"
# model_name = "google/codegemma-7b-it"

# --- Qwen Series (Alibaba) ---
# model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
# model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"
# model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

# --- Deepseek Coder Series (Deepseek AI) ---
# model_name = "deepseek-ai/deepseek-coder-7b-instruct-v1.5"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model_name = "deepseek-ai/deepseek-coder-1.3b-base"

# --- Code Llama Series (Meta) ---
# model_name = "codellama/CodeLlama-7b-Instruct-hf"

# --- Phi Series (Microsoft) ---
# model_name = "microsoft/Phi-4-mini-instruct"
# model_name = "microsoft/Phi-4-multimodal-instruct"

print(f"Selected model: {model_name}")

Selected model: deepseek-ai/deepseek-coder-1.3b-base


In [5]:
# Models/prefixes generally requiring trust_remote_code=True
TRUST_REMOTE_CODE_MODELS = [
    "microsoft/Phi-",
    "Qwen/",
]

# Default to False; enable only if model_name matches a prefix in the list
trust_code = any(model_name.startswith(prefix) for prefix in TRUST_REMOTE_CODE_MODELS)
if trust_code:
    warnings.warn(f"trust_remote_code=True for {model_name}: custom code from the Hub will be executed.")

# e.g. for deepseek-ai/DeepSeek-R1-Distill-* and deepseek-coder models, trust_code stays False
print(f"Setting trust_remote_code={trust_code} for {model_name}")

Setting trust_remote_code=False for deepseek-ai/deepseek-coder-1.3b-base


In [6]:
if torch.cuda.is_available():
    if torch.cuda.is_bf16_supported():
        compute_dtype = torch.bfloat16
        print("GPU supports bfloat16.")
    else:
        compute_dtype = torch.float16
        print("GPU does not support bfloat16; using float16.")
else:
    compute_dtype = torch.float16
    print("CUDA not available; using float16.")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, # use 4-bit precision
    bnb_4bit_quant_type="nf4", # use Normalized Float 4 (NF4)
    bnb_4bit_compute_dtype=compute_dtype, #either bfloat16 or float16
    bnb_4bit_use_double_quant=True, # double quantization
)
print("Create quantization")

GPU supports bfloat16.
Create quantization
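
Why 4-bit NF4 with double quantization fits comfortably in memory can be checked with back-of-the-envelope arithmetic. The sketch assumes the QLoRA-style block layout used by bitsandbytes (blocks of 64 weights with one fp32 absmax constant each; with double quantization those constants are stored in 8 bits plus one fp32 constant per 256 blocks). It ignores modules that stay in higher precision (embeddings, the LM head, norms), so real footprints will be somewhat larger. The helper names and the 1.3e9 parameter count are approximations for illustration.

```python
def nf4_bits_per_param(block_size: int = 64, double_quant: bool = True,
                       second_block_size: int = 256) -> float:
    """Average storage cost per weight: 4-bit code plus amortised quantization constants."""
    if double_quant:
        # 8-bit first-level constants + one fp32 constant per `second_block_size` of them
        overhead = 8 / block_size + 32 / (block_size * second_block_size)
    else:
        overhead = 32 / block_size  # one fp32 absmax per block
    return 4 + overhead


def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in gigabytes (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


n_params = 1.3e9  # approximate size of deepseek-coder-1.3b-base
for label, bits in [("fp16/bf16", 16.0),
                    ("NF4", nf4_bits_per_param(double_quant=False)),
                    ("NF4 + double quant", nf4_bits_per_param())]:
    print(f"{label:<20} {bits:6.3f} bits/param -> {weight_memory_gb(n_params, bits):.2f} GB")
```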


In [7]:
# --- Stage 1: Load Tokenizer ---
tokenizer = None

try:
    print(f"\nLoading tokenizer for model: {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_code  # Set in the trust_remote_code cell above
    )
    print("Tokenizer loading successful.")

    # Standard pad_token configuration
    if tokenizer.eos_token_id is None:
        warnings.warn(
            f"CRITICAL: Tokenizer for {model_name} has no eos_token_id. "
            "This is highly unusual and may lead to issues with padding or generation."
        )

    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            warnings.warn(
                f"Tokenizer for {model_name} lacked a pad_token_id. "
                f"Set to eos_token_id: {tokenizer.eos_token_id}."
            )
        else:  # Reached only when eos_token_id is also None (see the warning above)
            raise ValueError(f"Tokenizer for {model_name} has neither pad_token_id nor eos_token_id. Cannot proceed.")


    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
        warnings.warn(
            f"Tokenizer for {model_name} lacked a pad_token string. "
            f"Set to eos_token string: '{tokenizer.eos_token}'."
        )

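except (HfEntryNotFoundError, OSError) as e:  # transformers reports missing/unreachable repos as OSError
    print(f"ERROR: Tokenizer for '{model_name}' could not be loaded from the Hugging Face Hub "
          f"(check the model name, access rights and network). Details: {e}")
    tokenizer = None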
except Exception as e:
    print(f"ERROR: Unexpected error loading tokenizer for {model_name}:")
    import traceback
    traceback.print_exc()
    tokenizer = None


Loading tokenizer for model: deepseek-ai/deepseek-coder-1.3b-base...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/793 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Tokenizer loading successful.


In [8]:
# --- Stage 2: Load Model ---
model = None

if tokenizer:
    try:
        print(f"\nLoading quantized model: {model_name} (4-bit)...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,  # Defined in the quantization cell above
            device_map="auto",
            trust_remote_code=trust_code,  # Set in the trust_remote_code cell above
            low_cpu_mem_usage=True
        )

        if hasattr(model, 'config') and model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
            model.config.pad_token_id = tokenizer.pad_token_id
            print(f"Synced model.config.pad_token_id with tokenizer: {tokenizer.pad_token_id}")

        print(f"Model '{model_name}' loaded successfully with 4-bit quantization.")

    except (HfEntryNotFoundError, OSError) as e:  # transformers reports missing/unreachable repos as OSError
        print(f"ERROR: Model weights/config for '{model_name}' could not be loaded from the Hugging Face Hub: {e}")
        model = None
    except torch.cuda.OutOfMemoryError:
        print(f"ERROR: CUDA OOM while loading model '{model_name}'. Model too large for GPU VRAM.")
        model = None
    except ValueError as e:
        print(f"ERROR (ValueError) loading model '{model_name}': {e}")
        if "bitsandbytes" in str(e).lower() or "nf4" in str(e).lower():
            print("  This might be related to bitsandbytes setup or NF4 incompatibility.")
        import traceback
        traceback.print_exc()
        model = None
    except Exception as e:
        print(f"ERROR: Unexpected error loading model {model_name}:")
        import traceback
        traceback.print_exc()
        model = None
else:
    print(f"\nSkipping model loading for {model_name} due to tokenizer loading failure.")
    model = None

# --- Final Status Update ---
if tokenizer and model:
    print(f"\nSUCCESS: '{model_name}' is loaded.")
else:
    print(f"\nCRITICAL FAILURE: Could not load '{model_name}'.")


Loading quantized model: deepseek-ai/deepseek-coder-1.3b-base (4-bit)...


config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.69G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.69G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

Synced model.config.pad_token_id with tokenizer: 32014
Model 'deepseek-ai/deepseek-coder-1.3b-base' loaded successfully with 4-bit quantization.

SUCCESS: 'deepseek-ai/deepseek-coder-1.3b-base' is loaded.
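
The final status check only confirms that both objects exist. A fuller verification (parameter count, memory footprint, quantization settings, tokenizer details, GPU usage) could look like the sketch below. It relies on `num_parameters()` and `get_memory_footprint()`, which `transformers` models provide, reads the quantization settings from `model.config`, and calls `nvidia-smi` only if the binary is on the PATH. The helper names are hypothetical.

```python
import shutil
import subprocess
from typing import Any, Dict


def describe_loaded_model(model: Any, tokenizer: Any) -> Dict[str, Any]:
    """Collect key facts about a loaded (possibly quantized) model and its tokenizer."""
    config = getattr(model, "config", None)
    return {
        "model_name": getattr(config, "_name_or_path", None),
        "num_parameters": model.num_parameters() if hasattr(model, "num_parameters") else None,
        "memory_footprint_gb": (model.get_memory_footprint() / 1e9
                                if hasattr(model, "get_memory_footprint") else None),
        "quantization_config": getattr(config, "quantization_config", None),
        "tokenizer_class": type(tokenizer).__name__,
        "vocab_size": len(tokenizer) if hasattr(tokenizer, "__len__") else None,
        "pad_token_id": getattr(tokenizer, "pad_token_id", None),
    }


def gpu_memory_report() -> str:
    """Return nvidia-smi's memory summary, or a note if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return "nvidia-smi not found"
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv"],
        capture_output=True, text=True, check=False,
    )
    return result.stdout.strip() or result.stderr.strip()

# In the notebook (names from the cells above):
# if tokenizer and model:
#     for key, value in describe_loaded_model(model, tokenizer).items():
#         print(f"{key}: {value}")
#     print(gpu_memory_report())
```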


## Section 3: Dataset Preparation and Validation



This section orchestrates the ingestion and initial vetting of our target dataset to ensure it's ready for downstream RAG processing. It covers four main steps:

1. **Google Drive Directory Setup**  
   Defines and verifies a persistent storage path (`drive_save_path`) on Google Drive—creating it if needed—to hold all dataset artifacts, intermediate files, and future outputs.

2. **Dataset Loading**  
   Uses Hugging Face's `load_dataset` to fetch the `test` split of the `JetBrains-Research/lca-library-based-code-generation` dataset into Colab's local cache for speed and stability (Drive is deliberately not used as a cache directory). The cell passes `DownloadMode.FORCE_REDOWNLOAD`, so the data is downloaded afresh on every run.

3. **Inspection & Error Handling**  
   - **Entry Count & Size Estimate:** Reports the number of records and the approximate memory/cache footprint.  
   - **Structure Preview:** Prints available columns and a snippet of the first example (repository name, instruction, reference code, and top APIs).  
   - **Robust Exceptions:** Catches issues like missing dataset, connection failures, or other unexpected errors, offering troubleshooting tips.

4. **Final Verification & Advisory**  
   Prints a clear success or failure message. Note that the actual library source code (the “repos”) is not part of the dataset rows and must be downloaded separately (Section 4) before building the knowledge base.


In [9]:
# --- 3.1 GDrive Directory Setup ---
drive_save_path = '/content/drive/MyDrive/RAG_Project/' # To store results/outputs

# Check that the directory exists; If not, we create it
try:
    os.makedirs(drive_save_path, exist_ok=True)
    print(f"Google Drive Directory available: {drive_save_path}")
except OSError as e:
    print(f"Warning: could not create or verify the directory: {drive_save_path}. Details: {e}")

Google Drive Directory available: /content/drive/MyDrive/RAG_Project/


In [10]:
# --- 3.2 Loading of the dataset ---
from datasets import load_dataset, DownloadMode

# Configuration for dataset loading
dataset_name = "JetBrains-Research/lca-library-based-code-generation"
data_split = "test"  # As per dataset documentation for evaluation

lca_dataset_split = None  # Initialize variable to store the loaded dataset

print(f"\nAttempting to load dataset: '{dataset_name}' (split: '{data_split}').")
print("Note: Hugging Face Datasets library will use local Colab cache for optimal performance.")


try:
    # Dataset Loading Attempt
    lca_dataset_split = load_dataset(
        dataset_name,
        split=data_split,
        download_mode=DownloadMode.FORCE_REDOWNLOAD,  # fresh download each run; REUSE_DATASET_IF_EXISTS reuses the cache
    )
    print("\nThe dataset was successfully loaded!")

except ConnectionError as e:
    print(f"\nERROR: Connection error while loading dataset. Please check your internet connection. Details: {e}")
    lca_dataset_split = None # Ensure it's None on failure
except Exception as e:
    # Catch any other error raised during dataset loading
    print(f"\nERROR: An unexpected error occurred while loading the dataset: {e}")
    lca_dataset_split = None  # Ensure it's None on failure



Attempting to load dataset: 'JetBrains-Research/lca-library-based-code-generation' (split: 'test').
Note: Hugging Face Datasets library will use local Colab cache for optimal performance.


README.md:   0%|          | 0.00/5.21k [00:00<?, ?B/s]

(…)-00000-of-00001-518ed46ecbe35ff9.parquet:   0%|          | 0.00/4.58M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/150 [00:00<?, ? examples/s]


The dataset was successfully loaded!


In [13]:
import textwrap # For cleaner text truncation

print("\n--- 3.3 Dataset Inspection and Validation ---")

if lca_dataset_split is not None:
    print(f"\nDataset '{dataset_name}' (split: '{data_split}') appears to be loaded.")
    print(f"Number of entries: {len(lca_dataset_split)}")

    # Attempt to get and display dataset size
    try:
        # dataset.info.size_in_bytes is another way, or dataset.size_in_bytes
        dataset_size_bytes = lca_dataset_split.size_in_bytes
        if dataset_size_bytes is not None: # Check if the attribute exists and is not None
            print(f"Estimated dataset size (RAM/cache): {dataset_size_bytes / (1024**2):.2f} MB")
        else:
            print("Dataset size information (size_in_bytes) is None or not available.")
    except AttributeError:
        print("Info: The 'size_in_bytes' attribute is not available for this dataset object.")
    except Exception as e:
        print(f"Info: Could not retrieve dataset size. Error: {e}")

    # Display dataset features (columns)
    if hasattr(lca_dataset_split, 'features'):
        print("\nAvailable columns (features) in the dataset:")
        print(list(lca_dataset_split.features.keys()))
    else:
        print("\nWarning: Dataset features (column names) could not be retrieved.")

    # Preview the first example if the dataset is not empty
    if len(lca_dataset_split) > 0:
        print("\nPreview of the first example's content:")
        first_example = lca_dataset_split[0]

        repo_name = first_example.get('repo_full_name', 'N/A')
        instruction_text = first_example.get('instruction', 'N/A')
        reference_text = first_example.get('reference', 'N/A')
        apis_list = first_example.get('unique_apis', [])

        print(f"  Repository: {repo_name}")
        print(f"  Instruction: {textwrap.shorten(instruction_text, width=100, placeholder='...')}")
        print(f"  Reference Code: {textwrap.shorten(reference_text, width=100, placeholder='...')}")
        print(f"  Unique APIs (first 5): {apis_list[:5]}{'...' if len(apis_list) > 5 else ''}")
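    else:
        print("\nDataset is loaded but contains no entries.")
else:
    print(f"\nERROR: Dataset '{dataset_name}' is not loaded. Re-run the loading cell before continuing.")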


--- 3.3 Dataset Inspection and Validation ---

Dataset 'JetBrains-Research/lca-library-based-code-generation' (split: 'test') appears to be loaded.
Number of entries: 150
Estimated dataset size (RAM/cache): 12.62 MB

Available columns (features) in the dataset:
['repo_full_name', 'repo_name', 'repo_owner', 'instruction', 'reference', 'clean_reference', 'path_to_reference_file', 'path_to_examples_folder', 'n_unique_apis', 'unique_apis', 'project_defined_elements', 'api_calls', 'internal_apis']

Preview of the first example's content:
  Repository: seed-labs__seed-emulator
  Instruction: Generate code that creates an emulation using the seedemu library. The emulation should include...
  Reference Code: #!/usr/bin/env python # encoding: utf-8 # __author__ = 'Demon' from seedemu.layers import Base,...
  Unique APIs (first 5): ['DomainNameCachingService', 'addLayer', 'addPrivatePeering', 'Ospf', 'createHost']...
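
Beyond the first example, a quick per-repository overview helps judge how many tasks each knowledge base will serve and how API-heavy those tasks are. The sketch iterates over rows as dictionaries (which is what iterating a `datasets.Dataset` yields) and uses the `repo_full_name` and `n_unique_apis` columns listed above; the helper name `per_repo_summary` is hypothetical.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Iterable, Mapping


def per_repo_summary(rows: Iterable[Mapping]) -> Dict[str, Dict[str, float]]:
    """Number of tasks and mean n_unique_apis for each repository."""
    apis_by_repo = defaultdict(list)
    for row in rows:
        apis_by_repo[row["repo_full_name"]].append(row["n_unique_apis"])
    return {
        repo: {"tasks": len(counts), "mean_unique_apis": mean(counts)}
        for repo, counts in apis_by_repo.items()
    }

# Usage in the notebook:
# summary = per_repo_summary(lca_dataset_split)
# for repo, stats in sorted(summary.items(), key=lambda kv: -kv[1]["tasks"])[:10]:
#     print(f"{repo:<45} tasks={stats['tasks']:>3}  mean APIs={stats['mean_unique_apis']:.1f}")
```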


## Section 4: Repository Archive Download & Preparation



This section automates fetching a specific archived snapshot from our dataset’s Hugging Face repository and placing it into the working directory with robust error handling. It comprises five key phases:

1. **Exception Imports**  
   Attempts to import `HfHubHTTPError`, `RepositoryNotFoundError`, and `EntryNotFoundError` from `huggingface_hub.utils`, falls back to `huggingface_hub.errors`, and finally defines placeholder exception classes if neither import works. This ensures the download logic can catch and respond to all common HF Hub issues.

2. **Archive Download**  
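   Uses `hf_hub_download` to fetch `repos/seed-labs__seed-emulator.tar.gz` from the `JetBrains-Research/lca-library-based-code-generation` dataset repository into `/content/`. Because the file keeps its path inside the repository, it lands at `/content/repos/seed-labs__seed-emulator.tar.gz`. Progress messages are printed along the way.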

3. **Error Handling**  
   - **RepositoryNotFoundError** if the repo ID is invalid  
   - **EntryNotFoundError** if the specified file path doesn’t exist in the repo  
   - **HfHubHTTPError** for other HTTP failures (401/403/404, etc.)  
   - **Generic Exceptions** for any unforeseen issues, with a concise summary of the error type and message

4. **File Relocation & Cleanup**  
   If the download succeeds, the script checks whether the file is already at the desired target (`/content/seed-labs__seed-emulator.tar.gz`). If not, it moves the archive there (creating any missing directories) and then removes the intermediate download folder if it is now empty, keeping the workspace tidy.

5. **Final Verification**  
   Prints a clear `[OK]` banner if the archive is present at the target path—or an `[ERROR]` banner otherwise—so you can be confident whether to proceed with the next extraction or processing steps.


In [None]:
# --- 1. Exception Imports ---
try:
    from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError
    print("Successfully imported exceptions from huggingface_hub.utils.")
except ImportError:
    print("WARNING: could not import exceptions from huggingface_hub.utils; trying huggingface_hub.errors.")
    try:
        from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError
        print("Imported exceptions from huggingface_hub.errors.")
    except ImportError:
        print("ERROR: could not import exceptions from huggingface_hub; using placeholder classes.")
        # Placeholders keep the except clauses below valid even without the real classes
        class HfHubHTTPError(Exception): pass
        class RepositoryNotFoundError(Exception): pass
        class EntryNotFoundError(Exception): pass

Successfully imported exceptions from huggingface_hub.utils.


In [None]:
# --- 2. Archive Download ---

repo_id = "JetBrains-Research/lca-library-based-code-generation"
filename_in_repo = "repos/seed-labs__seed-emulator.tar.gz"
desired_local_archive_path = "/content/seed-labs__seed-emulator.tar.gz"
download_base_dir = "/content/"

print(f"\n--- Download and configuration ---")
print(f"Repo: {repo_id}")
print(f"File in repo: {filename_in_repo}")
print(f"Desired destination: {desired_local_archive_path}")

actual_downloaded_path = None

# Download the archive from the Hugging Face Hub
try:
    print(f"\nStarting download from Hugging Face Hub...")
    actual_downloaded_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename_in_repo,
        repo_type="dataset",
        local_dir=download_base_dir,  # real files are written here; local_dir_use_symlinks is deprecated
    )
    print(f"Download completed. File saved at: {actual_downloaded_path}")

# --- 3. Error Handling ---
except RepositoryNotFoundError:
    print(f"\nERROR: Repository '{repo_id}' not found on Hugging Face Hub.")
    print("  Make sure the repository name is correct.")
except EntryNotFoundError:  # Specific file not found in the repo
    print(f"\nERROR: File/Entry '{filename_in_repo}' not found in the repository '{repo_id}'.")
except HfHubHTTPError as e:  # HTTP errors (including 401, 403, 404 not already caught above)
    print(f"\nHTTP ERROR during download from Hugging Face Hub: {e}")
    if hasattr(e, 'response') and e.response is not None:
        print(f"  Status Code: {e.response.status_code}")
        if e.response.status_code == 404:
            print(f"  -> The file '{filename_in_repo}' or the repo '{repo_id}' may not exist (Error 404).")
    print(f"  Please check the repo_id, filename_in_repo, and your internet connection or HF token if necessary.")
except Exception as e:
    # Catch other unexpected errors
    import traceback
    print(f"\nUNEXPECTED ERROR during the download:")
    # print(traceback.format_exc())  # Uncomment for full traceback during debugging
    print(f"  Error Type: {type(e).__name__}, Message: {e}")

# --- 4. File Relocation & Cleanup ---
archive_ready = False
if actual_downloaded_path and os.path.exists(actual_downloaded_path):
    if os.path.abspath(actual_downloaded_path) == os.path.abspath(desired_local_archive_path):
        print(f"\nThe archive is already at the desired final location: {desired_local_archive_path}")
        archive_ready = True
    else:
        try:
            print(f"\nMoving '{os.path.basename(actual_downloaded_path)}' to '{desired_local_archive_path}'...")
            os.makedirs(os.path.dirname(desired_local_archive_path), exist_ok=True)
            shutil.move(actual_downloaded_path, desired_local_archive_path)
            print(f"Move completed successfully.")
            archive_ready = True

            # Clean up intermediate directory if empty
            download_parent_dir = os.path.dirname(actual_downloaded_path)
            if (os.path.exists(download_parent_dir) and
                os.path.abspath(download_parent_dir) != os.path.abspath(download_base_dir) and
                os.path.abspath(download_parent_dir).startswith(os.path.abspath(download_base_dir)) and
                not os.listdir(download_parent_dir)):
                try:
                    print(f"Removing empty intermediate directory: {download_parent_dir}")
                    os.rmdir(download_parent_dir)
                except OSError as rmdir_e:
                    print(f"  Warning: could not remove {download_parent_dir}. Issue: {rmdir_e}")

        except Exception as move_e:
            print("\nERROR during move or cleanup of downloaded file:")
            print(f"  Error: {move_e}")
            print(f"  The downloaded file may still be located at: {actual_downloaded_path}")
            archive_ready = False

elif not actual_downloaded_path:
    print("\nDownload failed. Cannot proceed.")
else:
    print(f"\nINTERNAL ERROR: Download path ({actual_downloaded_path}) does not exist after the attempt.")

# --- 5. Final Verification ---
print("\nFinal check:")
if archive_ready and os.path.exists(desired_local_archive_path):
    print(f"[OK] The final archive is ready at: {desired_local_archive_path}")
else:
    print(f"[ERROR] The final archive was NOT found or prepared correctly at: {desired_local_archive_path}")

print("\n--- End of Download and Preparation ---")


--- Download and configuration ---
Repo: JetBrains-Research/lca-library-based-code-generation
File in repo: repos/seed-labs__seed-emulator.tar.gz
Desired destination: /content/seed-labs__seed-emulator.tar.gz

Starting download from Hugging Face Hub...


Download completed. File saved at: /content/repos/seed-labs__seed-emulator.tar.gz
\Moving 'seed-labs__seed-emulator.tar.gz' to '/content/seed-labs__seed-emulator.tar.gz'...
Move completed successfully.
Removing empty intermediate directory: /content/repos

Final check:
[OK] The final archive is ready at: /content/seed-labs__seed-emulator.tar.gz

--- End of Download and Preparation ---


## Section 5: Source Extraction & Knowledge Base Construction



This section handles unpacking the downloaded emulator archive and building a searchable code-snippet knowledge base (KB) from the extracted Python sources. It consists of three major parts:

1. **Archive Extraction**  
   - **Configuration:** Points to the local archive (`local_archive_path`) and defines the parent directory for extraction (`extract_dir_parent`).  
   - **Safety Checks:** Verifies the archive exists, cleans any old extraction folder, then creates the target directory.  
   - **Unpacking:** Uses Python's `tarfile` to extract all contents into `extract_dir_parent`.  
   - **Dynamic Path Resolution:** Scans the extraction folder to identify the main code directory—handling single‐folder archives, multiple items, or unexpected layouts—falling back to the parent directory if needed.

2. **Snippet Extraction Helpers**  
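   - **`extract_code_units(...)`:** Reads each `.py` file (UTF-8 first, then a fallback encoding), parses its AST, and collects the source text of every function, async function, and class definition, including nested ones, so methods appear both inside their class snippet and on their own. Files that cannot be read (encoding or permission errors) or parsed (syntax errors) are skipped silently.  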
   - **`build_kb_for_library(...)`:** Recursively walks the extracted source tree, applies `extract_code_units` to every Python file (with an optional `tqdm` progress bar), and accumulates all code units. If the total exceeds `MAX_KB_SIZE`, it randomly samples down to limit memory use.

3. **KB Creation & Persistence**  
   - **Sample Selection:** Retrieves the `repo_full_name` from the chosen dataset entry (`SAMPLE_INDEX`) to label the KB.  
   - **KB Assembly:** Invokes `build_kb_for_library` on the extracted code path, reporting how many files and snippets were processed or skipped.  
   - **Drive Saving:** Creates (if necessary) a `library_kbs` folder on Google Drive, serializes the final snippet list to JSON, and writes it as `kb_<repo_name>_sample_<i>.json`.  
   - **Final Check:** Prints a success banner with the total snippet count or an error if the KB is empty or failed to save.

By the end of these cells you will have both the raw source files under `extract_dir_parent` and a curated, size-limited KB of Python code snippets saved permanently to your Drive for use in retrieval-augmented generation.
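The persistence step names each KB `kb_<repo_name>_sample_<i>.json` and writes it with `json.dump`. The sketch below reproduces that naming and adds a load-time type check. The helper names (`kb_path`, `save_kb`, `load_kb`) are hypothetical, and writing to a temporary file followed by `os.replace` is an added safeguard, not part of the notebook's code: on a local POSIX filesystem the rename is atomic, while the Drive FUSE mount may not offer the same guarantee.

```python
import json
import os
import tempfile

def kb_path(save_dir: str, repo_full_name: str, sample_index: int) -> str:
    """Build the KB file path with the notebook's naming scheme (kb_<repo>_sample_<i>.json)."""
    safe_repo_name = repo_full_name.replace('/', '__')  # '/' would otherwise create a subdirectory
    return os.path.join(save_dir, f"kb_{safe_repo_name}_sample_{sample_index}.json")

def save_kb(snippets: list[str], path: str) -> None:
    """Write the KB to a temporary file first, then rename it over the target."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(snippets, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)  # atomic on local POSIX filesystems

def load_kb(path: str) -> list[str]:
    """Load a KB and check that it really is a JSON list of strings."""
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not (isinstance(data, list) and all(isinstance(s, str) for s in data)):
        raise ValueError(f"{path} does not contain a JSON list of strings")
    return data

# Round trip in a throwaway directory (the repo name is only an illustration)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = kb_path(tmp_dir, "seed-labs/seed-emulator", 0)
    save_kb(["def f():\n    return 'é'"], path)
    restored = load_kb(path)
    print(os.path.basename(path), restored)
```

`ensure_ascii=False` keeps non-ASCII characters in snippets readable in the JSON file; the explicit `encoding='utf-8'` on both ends is what makes that safe.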


In [None]:
# --- 1. Archive Extraction ---
# --- 1.1. Configuration ---
# Path to the downloaded archive (should already exist from the previous cell)
local_archive_path = '/content/seed-labs__seed-emulator.tar.gz'
# Base directory where we want to extract the archive contents
extract_dir_parent = "/content/library_sources/"

# This variable will hold the actual path to the main extracted folder.
# It will be determined after extraction is complete.
final_extracted_code_path = None

print("--- Extraction archive ---")
print(f"Archive: {local_archive_path}")
print(f"Destination directory: {extract_dir_parent}")

--- Extraction archive ---
Archive: /content/seed-labs__seed-emulator.tar.gz
Destination directory: /content/library_sources/


In [None]:
# --- 1.2. Safety Check ---
if not os.path.exists(local_archive_path):
    print(f"\n[ERROR] Source archive not found: {local_archive_path}")
    print("  Make sure the download cell was run correctly.")


else:
    try:
        # Optional: remove any previous extraction so stale files don't linger
        # (comment out the next two lines to keep the existing contents)
        if os.path.exists(extract_dir_parent):
            shutil.rmtree(extract_dir_parent)

        # Create the destination directory
        # (exist_ok=True avoids an error if it already exists)
        os.makedirs(extract_dir_parent, exist_ok=True)
        print(f"\nTarget extraction directory '{extract_dir_parent}' is ready.")

        # --- 1.3. Unpacking ---
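        # Extract the archive
        print(f"Starting extraction of '{os.path.basename(local_archive_path)}'...")
        with tarfile.open(local_archive_path, "r:gz") as tar:
            # Use the 'data' extraction filter where available (Python 3.12+ and recent
            # security backports) to block path-traversal members such as '../x'
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=extract_dir_parent, filter="data")
            else:
                tar.extractall(path=extract_dir_parent)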
        print("Extraction completed successfully.")

        # --- 1.4. Dynamic Path Resolution ---
        # dynamically determine the extracted path
        try:
            extracted_items = sorted(os.listdir(extract_dir_parent))  # sorted for a deterministic choice
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(extract_dir_parent, extracted_items[0])):
                final_extracted_code_path = os.path.join(extract_dir_parent, extracted_items[0])
                print(f"Identified main extracted directory: {final_extracted_code_path}")
            elif len(extracted_items) > 0:
                # Look for a folder matching the archive's base name
                archive_basename = os.path.basename(local_archive_path).replace('.tar.gz', '').replace('.tgz', '')
                potential_match = os.path.join(extract_dir_parent, archive_basename)
                if os.path.isdir(potential_match):
                    final_extracted_code_path = potential_match
                    print(f"Found potential matching directory: {final_extracted_code_path}")
                else:
                    # Otherwise take the first directory (not merely the first entry)
                    subdirs = [item for item in extracted_items
                               if os.path.isdir(os.path.join(extract_dir_parent, item))]
                    if subdirs:
                        final_extracted_code_path = os.path.join(extract_dir_parent, subdirs[0])
                        print(f"WARNING: Multiple items found. Assuming first directory: {final_extracted_code_path}")
                    else:
                        print(f"WARNING: No main directory found in the extraction folder {extract_dir_parent}.")
                        print(f"  Contents: {extracted_items}")
                        print("  'final_extracted_code_path' may need to be set manually.")
                        final_extracted_code_path = extract_dir_parent  # Fallback: use the parent dir
                        print(f"  Fallback set to: {final_extracted_code_path}")

            else:
                print(f"WARNING: Extraction folder '{extract_dir_parent}' is empty after extraction.")

        except Exception as list_e:
            print(f"Issue while analyzing the extracted data: {list_e}")

    except tarfile.ReadError:
        print(f"\n[ERROR] Cannot read archive: {local_archive_path}. It may be corrupted.")
    except FileNotFoundError:
        # Can only happen if the archive disappears between the existence check and the open
        print(f"\n[ERROR] Archive file not found during open attempt: {local_archive_path}")
    except Exception as e:
        print(f"\n[ERROR] Unexpected error during preparation or extraction:")
        # print(traceback.format_exc()) # uncomment for debug
        print(f"  Error Type: {type(e).__name__}, Message: {e}")



Target extraction directory '/content/library_sources/' is ready.
Starting extraction of 'seed-labs__seed-emulator.tar.gz'...
Extraction completed successfully.
Identified main extracted directory: /content/library_sources/mnt

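The output shows that the resolved path `/content/library_sources/mnt` contains only `data`, so the directory picked by the single-folder rule is just a wrapper. The KB is unaffected, because `os.walk` recurses into subdirectories, but if you want `final_extracted_code_path` to point at the actual repository root, a sketch that keeps descending through single-directory chains could look like this. The helper `resolve_code_root` and the demo folder `example_repo` are hypothetical; hidden entries such as `.DS_Store` are ignored when counting children.

```python
import os
import tempfile

def resolve_code_root(extract_dir: str, max_depth: int = 10) -> str:
    """Follow single-directory chains (e.g. mnt/ -> data/ -> repo/) down to the first real branch."""
    current = extract_dir
    for _ in range(max_depth):  # bounded, in case of unusual symlink layouts
        entries = [e for e in os.listdir(current) if not e.startswith('.')]
        if len(entries) == 1 and os.path.isdir(os.path.join(current, entries[0])):
            current = os.path.join(current, entries[0])
        else:
            break
    return current

# Demo on a throwaway tree shaped like the listing above
with tempfile.TemporaryDirectory() as tmp_dir:
    repo_dir = os.path.join(tmp_dir, "mnt", "data", "example_repo")  # hypothetical layout
    os.makedirs(os.path.join(repo_dir, "pkg"))
    open(os.path.join(repo_dir, "setup.py"), "w").close()
    resolved_rel = os.path.relpath(resolve_code_root(tmp_dir), tmp_dir)
    print(resolved_rel)
```

In the notebook it could be applied as `final_extracted_code_path = resolve_code_root(final_extracted_code_path)`, which only changes where the scan starts, not which `.py` files it finds.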

In [None]:
# --- 1.5. Final check ---

print("\nFinal check:")
if final_extracted_code_path and os.path.isdir(final_extracted_code_path):
    print(f"[OK] The extracted source code path is: {final_extracted_code_path}")
    print("\nPartial content of the extracted directory (first 10 entries):")
    try:
        content_list = os.listdir(final_extracted_code_path)
        for item in content_list[:10]:
            print(f"  - {item}")
        if len(content_list) > 10:
            print("  ...")
    except Exception as e:
        print(f"  Error while listing the content of {final_extracted_code_path}: {e}")
else:
    print(f"[ERROR] Unable to determine or locate the extracted code directory.")
    print(f"         'final_extracted_code_path' is: {final_extracted_code_path}")
    print(f"         Make sure the extraction completed successfully.")


print("\n--- End of Archive Extraction ---")

# Optionally expose the path under a shorter name for subsequent cells.
# Uncomment if later cells expect `extracted_code_path` instead of
# `final_extracted_code_path`.
# extracted_code_path = final_extracted_code_path
# print(f"\nVariable 'extracted_code_path' set to: {extracted_code_path}")


Final check:
[OK] The extracted source code path is: /content/library_sources/mnt

Partial content of the extracted directory (first 10 entries):
  - data

--- End of Archive Extraction ---


In [None]:
# --- Configuration (used by the Snippet Extraction Helpers and KB creation) ---
SAMPLE_INDEX = 0       # Index of the dataset sample to process
MAX_KB_SIZE = 15000    # Max number of code snippets to include in the KB (to limit RAM)
FALLBACK_ENCODING = 'iso-8859-1'  # Encoding to use if UTF-8 fails
DRIVE_KB_SAVE_DIR = '/content/drive/MyDrive/RAG_Project/library_kbs'  # Directory to save KBs on Google Drive


# Check that the variables produced by earlier cells exist
if 'lca_dataset_split' not in locals() or not lca_dataset_split:
    raise NameError("CRITICAL ERROR: Variable 'lca_dataset_split' is not defined or is empty. Rerun the dataset loading cell.")
if 'final_extracted_code_path' not in locals() or not final_extracted_code_path:
    # Fallback: try the older variable name
    if 'extracted_code_path' in locals() and extracted_code_path:
        warnings.warn("Variable 'final_extracted_code_path' not found, using 'extracted_code_path' as fallback.")
        final_extracted_code_path = extracted_code_path
    else:
        raise NameError("CRITICAL ERROR: Variable 'final_extracted_code_path' (or 'extracted_code_path') is not defined. Rerun the archive extraction cell.")
if not os.path.isdir(final_extracted_code_path):
    raise FileNotFoundError(f"CRITICAL ERROR: The extracted code path '{final_extracted_code_path}' does not exist or is not a directory. Check the archive extraction step.")

# Root of the extracted source tree, resolved by the archive-extraction cell
library_source_dir = final_extracted_code_path


In [None]:
# --- 2. Snippet Extraction Helpers ---

def extract_code_units(py_file_path, fallback_encoding=FALLBACK_ENCODING):
    """Return the source text of every function, async function and class in a Python file (empty list if it cannot be read or parsed)."""
    units = []
    source = None
    encoding_used = 'utf-8'
    try:
        # Attempt to read with UTF-8
        with open(py_file_path, 'r', encoding='utf-8') as file:
            source = file.read()
    except UnicodeDecodeError:
        # Fallback to the specified encoding
        encoding_used = fallback_encoding
        try:
            with open(py_file_path, 'r', encoding=fallback_encoding) as file:
                source = file.read()
            # warnings.warn(f"Used encoding '{encoding_used}' for {py_file_path}") # Optional: Log used encoding
        except Exception as read_e:
            # print(f"  Error reading file {py_file_path} (even with {encoding_used}): {read_e}")
            return units # Nothing we can do if reading fails
    except PermissionError:
        # print(f"  Permission error reading {py_file_path}")
        return units
    except Exception as read_e:
        # print(f"  Unexpected error reading {py_file_path}: {read_e}")
        return units

    # If reading succeeds, try to parse
    if source is not None:
        try:
            tree = ast.parse(source, filename=py_file_path)
            # Check availability of get_source_segment (should be available in Python 3.8+)
            can_get_segment = hasattr(ast, 'get_source_segment')

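            # ast.walk also visits nested nodes, so methods are collected both inside
            # their class snippet and as standalone snippets
            for node in ast.walk(tree):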
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                    code_segment = None
                    if can_get_segment:
                        try:
                            code_segment = ast.get_source_segment(source, node)
                        except Exception as segment_e:
                            # Sometimes the segment can't be extracted for complex nodes or decorators
                            # print(f"  Warning: Error extracting segment ({type(node).__name__}) in {py_file_path}: {segment_e}")
                            pass
                    else: # Very simple fallback if get_source_segment is not available
                        code_segment = ast.dump(node) # Not ideal, but better than nothing

                    if code_segment:
                        units.append(code_segment)

        except SyntaxError as syn_e:
            # Ignore files with Python syntax errors
            # print(f"  Ignored: Syntax error in {py_file_path}: {syn_e}")
            pass
        except Exception as parse_e:
            # Ignore other parsing errors
            # print(f"  Ignored: AST parsing error in {py_file_path}: {parse_e}")
            pass
    return units

def build_kb_for_library(source_path, max_kb_size=MAX_KB_SIZE, use_tqdm=USE_TQDM):
    """Builds the KB (list of snippets) by scanning .py files, with progress and error handling."""
    if not os.path.isdir(source_path):
        print(f"[ERROR] The provided source path is not a valid directory: {source_path}")
        return []

    print(f"\nStarting library scan in: {source_path}")
    knowledge_base = []
    file_count = 0
    processed_count = 0
    skipped_count = 0

    # Count total .py files for tqdm (if used)
    total_py_files = 0
    if use_tqdm:
        print("Counting .py files for progress bar...")
        for _, _, files in os.walk(source_path):
            total_py_files += sum(1 for file in files if file.endswith(".py"))
        print(f"Found {total_py_files} .py files.")

    # Set up the iterator (with or without tqdm)
    walker = os.walk(source_path, topdown=True) # topdown=True for potential dir exclusion
    if use_tqdm:
        pbar = tqdm(total=total_py_files, desc="Extracting Snippets", unit="file")

    try:
        for root, dirs, files in walker:
            # Optional: Exclude specific directories (e.g., test, docs, build)
            # dirs[:] = [d for d in dirs if d not in ['tests', 'test', 'docs', '__pycache__', 'build']]

            for file in files:
                if file.endswith(".py"):
                    file_path = os.path.join(root, file)
                    file_count += 1
                    snippets = extract_code_units(file_path)
                    if snippets:
                        knowledge_base.extend(snippets)
                        processed_count += 1
                    else:
                        skipped_count += 1 # .py file read but no snippet extracted (error or empty)

                    if use_tqdm:
                        pbar.update(1)
                    elif file_count % 200 == 0: # Print progress less frequently without tqdm
                        print(f"  Processed {file_count} files...")

    except PermissionError as perm_e:
        print(f"\n[ERROR] Permission error during scan of {source_path}: {perm_e}")
        print("  You may need to adjust permissions or run as a different user.")
    except Exception as walk_e:
        print(f"\n[ERROR] Unexpected error during scan: {walk_e}")
    finally:
        if use_tqdm:
            pbar.close()

    print(f"\nScan completed.")
    print(f"  Total .py files encountered: {file_count}")
    print(f"  .py files processed with snippets: {processed_count}")
    print(f"  .py files skipped/with errors: {skipped_count}")
    print(f"  Total snippets extracted (before sampling): {len(knowledge_base)}")

    # Sampling if the KB is too large
    if len(knowledge_base) > max_kb_size:
        print(f"\nWARNING: KB too large ({len(knowledge_base)} snippets).")
        print(f"  Random sampling to keep a maximum of {max_kb_size} snippets.")
        knowledge_base = random.sample(knowledge_base, max_kb_size)
        print(f"  KB size after sampling: {len(knowledge_base)}")
    elif len(knowledge_base) == 0:
        print("\nWARNING: No snippet extracted from the library.")
        print(f"  Check that '{source_path}' contains valid and readable .py files.")

    return knowledge_base

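Two properties of `extract_code_units` are easy to miss: because it iterates with `ast.walk`, nested definitions are collected too (a method appears inside its class snippet and again on its own), and `ast.get_source_segment` starts at the `def`/`class` line, so decorators are left out. The self-contained sketch below demonstrates both on an in-memory string. `extract_units_from_source` and its `include_decorators` option are hypothetical, not part of the notebook.

```python
import ast
import textwrap

SAMPLE_SOURCE = '''
import functools

class Router:
    def route(self, asn):
        return asn

@functools.lru_cache
def lookup(name):
    return name
'''

def extract_units_from_source(source: str, include_decorators: bool = True) -> list[str]:
    """Return the source text of every function/async function/class that ast.walk finds."""
    tree = ast.parse(source)
    lines = source.splitlines(keepends=True)
    units = []
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            continue
        if include_decorators and node.decorator_list:
            # get_source_segment starts at 'def'/'class'; start at the first decorator instead
            start = min(d.lineno for d in node.decorator_list)
            units.append(textwrap.dedent("".join(lines[start - 1:node.end_lineno])).rstrip())
        else:
            units.append(ast.get_source_segment(source, node))
    return units

for unit in extract_units_from_source(SAMPLE_SOURCE):
    print(unit.splitlines()[0])
```

Whether the method-level duplication helps or hurts retrieval is an empirical question: it gives BM25 shorter, more focused documents, but a class and its methods can then occupy several of the Top-*k* slots for the same query.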
In [None]:
# --- 3. KB Creation & Persistence ---

print("\n" + "="*40)
print("--- Knowledge Base (KB) Creation ---")
print("="*40)

current_kb = []  # Initialize KB as empty

try:
    # Retrieve info from the loaded dataset
    sample = lca_dataset_split[SAMPLE_INDEX]
    repo_full_name = sample.get('repo_full_name')

    if not repo_full_name:
        print(f"[ERROR] 'repo_full_name' not found in dataset sample {SAMPLE_INDEX}.")
    else:
        print(f"Processing Sample {SAMPLE_INDEX}: Library '{repo_full_name}'")
        print(f"Source code path: {library_source_dir}")

        # Build the KB
        current_kb = build_kb_for_library(library_source_dir)

        # Save the KB to Drive if it's not empty
        if current_kb:
            # Create the save directory on Drive if it doesn't exist
            save_dir_ready = True
            try:
                os.makedirs(DRIVE_KB_SAVE_DIR, exist_ok=True)
            except OSError as drive_err:
                save_dir_ready = False  # skip the save below instead of failing on open()
                print(f"\n[ERROR] Unable to create save directory on Drive: {DRIVE_KB_SAVE_DIR}")
                print(f"  Error: {drive_err}")
                print("  KB save skipped.")
                # raise drive_err  # Uncomment to stop execution instead

            if save_dir_ready:
                # Build the full path for the KB file
                # Clean the repo name to avoid problematic characters in filenames
                safe_repo_name = repo_full_name.replace('/', '__')  # Replace / with __
                kb_filename = f"kb_{safe_repo_name}_sample_{SAMPLE_INDEX}.json"
                kb_full_path = os.path.join(DRIVE_KB_SAVE_DIR, kb_filename)

                print(f"\nAttempting to save KB ({len(current_kb)} snippets) to: {kb_full_path}")
                try:
                    with open(kb_full_path, 'w', encoding='utf-8') as f:
                        json.dump(current_kb, f, indent=2, ensure_ascii=False)
                    print("[OK] KB successfully saved.")
                except OSError as save_err:
                    print(f"\n[ERROR] Unable to write KB file to Drive: {kb_full_path}")
                    print(f"  Error: {save_err}. Check write permissions on Drive.")
                except Exception as json_err:
                    print(f"\n[ERROR] Error during JSON serialization of the KB: {json_err}")
        else:
            print("\nKB is empty, no file saved.")

except IndexError:
    print(f"[ERROR] Index {SAMPLE_INDEX} out of bounds for 'lca_dataset_split' (size: {len(lca_dataset_split)}).")
except Exception as main_e:
    import traceback
    print(f"\n[ERROR] Unexpected error in main script:")
    print(traceback.format_exc())

# --- 4. Final Check ---
if current_kb:
    print(f"\n--- KB for {repo_full_name} Ready ({len(current_kb)} snippets) ---")
else:
    print(f"\n--- KB not created or empty ---")

print("\n--- End of KB Creation ---")



--- Knowledge Base (KB) Creation ---
Processing Sample 0: Library 'seed-labs__seed-emulator'
Source code path: /content/library_sources/mnt

Starting library scan in: /content/library_sources/mnt
Counting .py files for progress bar...
Found 136 .py files.


Scan completed.
  Total .py files encountered: 136
  .py files processed with snippets: 99
  .py files skipped/with errors: 37
  Total snippets extracted (before sampling): 1196

Attempting to save KB (1196 snippets) to: /content/drive/MyDrive/RAG_Project/library_kbs/kb_seed-labs__seed-emulator_sample_0.json
[OK] KB successfully saved.

--- KB for seed-labs__seed-emulator Ready (1196 snippets) ---

--- End of KB Creation ---
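`build_kb_for_library` caps the KB with `random.sample` (not triggered here, since 1196 snippets is below `MAX_KB_SIZE`), and any snippet text that occurs more than once is kept as separate entries. Unless a seed is set elsewhere in the notebook, a capped sample changes between runs. If you want the cap to be reproducible and duplicates to stop competing for BM25 slots, a possible post-processing step is sketched below; `cap_kb` and its `seed` parameter are hypothetical.

```python
import random

def cap_kb(snippets: list[str], max_kb_size: int, seed: int = 42) -> list[str]:
    """Drop blank and exactly-duplicated snippets, then sample reproducibly if still over the cap."""
    unique = list(dict.fromkeys(s for s in snippets if s.strip()))  # keeps first-seen order
    if len(unique) <= max_kb_size:
        return unique
    rng = random.Random(seed)  # local RNG: does not disturb the global random state
    return rng.sample(unique, max_kb_size)

small = cap_kb(["def a(): pass", "def b(): pass", "def a(): pass", "   "], max_kb_size=10)
big = [f"def f{i}(): pass" for i in range(100)]
print(small)
print(cap_kb(big, 10, seed=1) == cap_kb(big, 10, seed=1))
```

Deduplicating before sampling matters: otherwise a frequently repeated snippet is proportionally more likely to survive the cap.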


## Section 6: BM25 Retrieval & Prompt Assembly



This section implements the core retrieval‑augmented generation workflow: it loads or reuses the previously built knowledge base (KB), runs a BM25 search to find the most relevant code snippets for the current instruction, and then assembles those snippets into a Snippet‑Integration‑Format (SIF) prompt ready for the LLM.

1. **Configuration**  
   Sets the retrieval parameters (`TOP_K_SNIPPETS`, plus BM25's `k1` and `b` as `BM25_K1`/`BM25_B`) and the KB location on Drive (`DRIVE_KB_SAVE_DIR`).

2. **Tokenizer Helper**  
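   Defines a simple tokenizer (`simple_code_tokenizer`) that lowercases the text, turns every character other than letters, digits, underscores and whitespace into a space, and splits on whitespace (an optional, commented-out filter can drop one-character tokens). The same function prepares both the snippets and the user instruction for BM25.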

3. **KB Loading**  
   Uses the in-memory `current_kb` list if available; otherwise loads the JSON file from Drive. Checks that the KB is a non-empty list before proceeding (non-string or blank entries are dropped later, just before indexing).

4. **BM25 Indexing & Retrieval**  
   - **Tokenization:** Converts all valid KB snippets into token lists.  
   - **Index Construction:** Builds a BM25 index with the specified parameters.  
   - **Querying:** Tokenizes the instruction, scores every snippet, and selects the top K matches.  
   - **Preview:** Prints a shortened preview of each retrieved snippet for quick inspection.

5. **SIF Prompt Creation**  
   Defines `create_sif_prompt()`, which:  
   - Calculates the available token budget from the model's context window.  
   - Iteratively incorporates retrieved snippets (wrapped in fenced Python blocks), stopping when the token budget is reached.  
   - Embeds the final instruction at the end, producing a single string ready to send to the model.  
   - Reports estimated token usage, warns if limits are exceeded, and shows a truncated preview of the assembled prompt.

By the end of this cell, you will have a ranked set of relevant code snippets and a fully formatted, token‑aware prompt that leverages those snippets to guide the LLM’s code generation.  
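For reference, BM25 scores a snippet d for a query by summing, over the query terms t, IDF(t) · tf·(k1 + 1) / (tf + k1·(1 − b + b·|d|/avgdl)), where tf is the term's count in d and avgdl is the mean snippet length: `k1` limits how much repeated terms help, and `b` controls how strongly long snippets are penalized. The from-scratch sketch below implements that formula with the non-negative, Lucene-style IDF; `rank_bm25`'s `BM25Okapi` uses a slightly different IDF (it floors the values of very common terms), so absolute scores will not match it. The sketch also shows a limitation of `simple_code_tokenizer`: a camelCase identifier such as `createAutonomousSystem` becomes one token, so the query words "autonomous system" cannot match it. `camel_aware_tokenizer` is a hypothetical variant that also emits identifier parts, and the three snippets are made-up examples.

```python
import math
import re
from collections import Counter

def simple_code_tokenizer(text: str) -> list[str]:
    # Same logic as the notebook's tokenizer: lowercase, punctuation -> space, split on whitespace
    return re.sub(r'[^\w\s]', ' ', text.lower()).split()

def camel_aware_tokenizer(text: str) -> list[str]:
    """Hypothetical variant: keep each identifier and also add its camelCase / snake_case parts."""
    tokens = []
    for word in re.sub(r'[^\w\s]', ' ', text).split():
        tokens.append(word.lower())
        parts = []
        for piece in word.split('_'):
            parts.extend(re.findall(r'[A-Z]+(?=[A-Z][a-z])|[A-Z]?[a-z]+|[A-Z]+|\d+', piece))
        if len(parts) > 1:
            tokens.extend(p.lower() for p in parts)
    return tokens

def bm25_scores(corpus: list[list[str]], query: list[str], k1: float = 1.5, b: float = 0.75) -> list[float]:
    """Okapi BM25 with the non-negative IDF log(1 + (N - df + 0.5) / (df + 0.5))."""
    n_docs = len(corpus)
    avgdl = sum(len(doc) for doc in corpus) / n_docs
    doc_freq = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        tf = Counter(doc)
        length_norm = k1 * (1 - b + b * len(doc) / avgdl)  # b = 0 disables length normalization
        score = 0.0
        for term in query:
            if tf[term]:
                idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
                score += idf * tf[term] * (k1 + 1) / (tf[term] + length_norm)
        scores.append(score)
    return scores

def top_k(scores: list[float], k: int) -> list[int]:
    # Indices of the k best-scoring documents, best first (same idea as the notebook's sorted(...)[:k])
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

snippets = [
    "def createAutonomousSystem(self, asn): return AutonomousSystem(asn)",
    "def addZone(self, zone): self.zones.append(zone)",
    "class DomainNameCachingService(Service): pass",
]
query = "create autonomous system with a domain name caching service"

simple_scores = bm25_scores([simple_code_tokenizer(s) for s in snippets], simple_code_tokenizer(query))
camel_scores = bm25_scores([camel_aware_tokenizer(s) for s in snippets], camel_aware_tokenizer(query))
print("simple:     ", [round(s, 3) for s in simple_scores], top_k(simple_scores, 2))
print("camel-aware:", [round(s, 3) for s in camel_scores], top_k(camel_scores, 2))
```

With the simple tokenizer the first snippet scores zero even though it is the most relevant one; keeping the whole identifier alongside its parts preserves exact-name matches while letting natural-language words hit identifier pieces.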


In [None]:
# --- 1. Configuration ---
SAMPLE_INDEX = 0      # Index of the sample to process (same as the KB cells)
TOP_K_SNIPPETS = 5    # Number of snippets to retrieve with BM25
BM25_K1 = 1.5         # BM25 term-frequency saturation (common default)
BM25_B = 0.75         # BM25 document-length normalization strength (common default)
DRIVE_KB_SAVE_DIR = '/content/drive/MyDrive/RAG_Project/library_kbs' # KB folder on Drive

# --- 2. Tokenizer Helper ---
def simple_code_tokenizer(text):
    """
    Simple tokenizer optimized for code snippets:
    - lowercase
    - split on spaces and common punctuation (keeping underscores)
    - optionally removes very short tokens
    """
    if not isinstance(text, str):  # Handles non-string input
        return []
    text = text.lower()
    # Replace every character that is not a word character (letter, digit, underscore)
    # or whitespace with a space, so underscores inside identifiers are kept
    text = re.sub(r'[^\w\s]', ' ', text)
    # Split on runs of whitespace
    tokens = text.split()
    # Optional: remove very short tokens (e.g., length 1), they might be noise
    # tokens = [token for token in tokens if len(token) > 1]
    return tokens

# --- 3. KB Loading ---

print("--- Retrieval with BM25 ---")
kb_data = None  # Initialize KB

# Check required variables from previous cells
if 'lca_dataset_split' not in locals() or not lca_dataset_split:
    raise NameError("CRITICAL ERROR: 'lca_dataset_split' not defined or empty. Re-run the dataset loading cell.")

# Try using KB already in memory ('current_kb' from the previous cell)
# Check that it exists, is a list, and is not empty
if 'current_kb' in locals() and isinstance(current_kb, list) and current_kb:
    print("Using in-memory KB ('current_kb').")
    kb_data = current_kb
else:
    # If current_kb is not valid, try loading from Drive
    print("\n'current_kb' not available or empty in memory.")
    try:
        # Determine KB file name (requires repo_full_name)
        sample = lca_dataset_split[SAMPLE_INDEX]
        repo_full_name_for_kb = sample.get('repo_full_name')
        if not repo_full_name_for_kb:
            print(f"[ERROR] 'repo_full_name' not found in sample {SAMPLE_INDEX} to load KB.")
        else:
            # Clean repo name and build path
            safe_repo_name = repo_full_name_for_kb.replace('/', '__')
            kb_filename = f"kb_{safe_repo_name}_sample_{SAMPLE_INDEX}.json"
            kb_full_path = os.path.join(DRIVE_KB_SAVE_DIR, kb_filename)

            if os.path.exists(kb_full_path):
                print(f"Attempting to load KB from Drive: {kb_full_path}")
                with open(kb_full_path, 'r', encoding='utf-8') as f:
                    kb_data = json.load(f)
                # Additional check: is the loaded file a non-empty list?
                if isinstance(kb_data, list) and kb_data:
                    print(f"KB for '{repo_full_name_for_kb}' loaded from Drive ({len(kb_data)} snippets).")
                else:
                    print(f"[ERROR] KB file loaded from '{kb_full_path}' is not a valid list or is empty.")
                    kb_data = None  # Reset if content is invalid
            else:
                print(f"[ERROR] KB file not found at: {kb_full_path}")

    except IndexError:
        print(f"[ERROR] Invalid index {SAMPLE_INDEX} for 'lca_dataset_split' when retrieving repo name.")
    except FileNotFoundError:  # If DRIVE_KB_SAVE_DIR does not exist
        print(f"[ERROR] KB directory on Drive not found: {DRIVE_KB_SAVE_DIR}")
    except Exception as e:
        print(f"[ERROR] Unexpected error while loading KB from Drive: {e}")
        kb_data = None  # Ensure None in case of error

# If kb_data is still not loaded, exit with a clear error
if not kb_data:
    raise RuntimeError("CRITICAL ERROR: Unable to obtain Knowledge Base (KB) data, neither from memory nor Drive. "
                       "Run the KB creation cell (Section 5) first to create/save the KB.")

# --- 4. BM25 Indexing & Retrieval ---
retrieved_snippets_bm25 = []  # Initialize results list

try:
    # Reload the sample every time so that a changed SAMPLE_INDEX is respected
    # (a 'sample' left over from an earlier cell may belong to a different index)
    sample = lca_dataset_split[SAMPLE_INDEX]
    instruction = sample.get('instruction')
    repo_full_name = sample.get('repo_full_name', 'N/A')  # Use N/A if missing

    if not instruction:
        print("[ERROR] Instruction (query) not found in the sample.")
    else:
        print(f"\n--- Running BM25 for Sample {SAMPLE_INDEX} (Library: {repo_full_name}) ---")
        print(f"Instruction (Query): {instruction[:250]}...")  # Show a bit more of the query

        # 4.1 Tokenize the Knowledge Base (ensure snippets are strings)
        print("\nTokenizing Knowledge Base...")
        valid_kb_docs = [doc for doc in kb_data if isinstance(doc, str) and doc.strip()]
        if len(valid_kb_docs) < len(kb_data):
            print(f"  Warning: {len(kb_data) - len(valid_kb_docs)} invalid snippets (non-strings/empty) ignored.")

        if not valid_kb_docs:
            print("[ERROR] No valid snippets found in the KB after cleaning.")
        else:
            tokenized_kb = [simple_code_tokenizer(doc) for doc in valid_kb_docs]
            # Remove any empty lists resulting from tokenization
            tokenized_kb_filtered = [tokens for tokens in tokenized_kb if tokens]
            if not tokenized_kb_filtered:
                print("[ERROR] Tokenized KB is empty after removing empty tokens.")
            else:
                original_indices = [i for i, tokens in enumerate(tokenized_kb) if tokens]  # Original indices of valid docs
                print(f"Tokenized KB ({len(tokenized_kb_filtered)} valid documents).")

                # 4.2 Create the BM25 index with configured parameters
                print(f"Creating BM25 index (k1={BM25_K1}, b={BM25_B})...")
                bm25 = BM25Okapi(tokenized_kb_filtered, k1=BM25_K1, b=BM25_B)
                print("BM25 index created.")

                # 4.3 Tokenize the instruction (query)
                print("Tokenizing instruction (query)...")
                tokenized_query = simple_code_tokenizer(instruction)
                if not tokenized_query:
                    print("[ERROR] Tokenized query is empty.")
                else:
                    # 4.4 Perform the retrieval
                    print(f"Retrieving top {TOP_K_SNIPPETS} relevant snippets...")
                    scores = bm25.get_scores(tokenized_query)
                    top_n_filtered_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:TOP_K_SNIPPETS]
                    retrieved_snippets_bm25 = [
                        valid_kb_docs[original_indices[i]] for i in top_n_filtered_indices if i < len(original_indices)
                    ]

                    print(f"\n--- Top {len(retrieved_snippets_bm25)} Snippets Retrieved (BM25) ---")
                    if retrieved_snippets_bm25:
                        for i, snippet in enumerate(retrieved_snippets_bm25):
                            print(f"\n--- Snippet {i+1} (BM25 Rank {i+1}) ---")
                            snippet_preview = textwrap.shorten(
                                snippet.strip(),
                                width=120,
                                placeholder=f" ... (total length: {len(snippet)} characters)"
                            )
                            print(snippet_preview)
                    else:
                        print("No snippets retrieved.")

except IndexError:
    print(f"[ERROR] Invalid index {SAMPLE_INDEX} for 'lca_dataset_split'.")
except Exception as main_e:
    import traceback
    print(f"\n[ERROR] Unexpected error in BM25 main script:")
    print(traceback.format_exc())

if retrieved_snippets_bm25:
    print(f"\n--- [OK] Retrieved {len(retrieved_snippets_bm25)} BM25 snippets ---")
    # The variable 'retrieved_snippets_bm25' contains the list of strings
else:
    print(f"\n--- [WARNING/ERROR] No snippets retrieved from BM25 ---")

--- Retrieval with BM25 ---
Using in-memory KB ('current_kb').

--- Running BM25 for Sample 0 (Library: seed-labs__seed-emulator) ---
Instruction (Query): Generate code that creates an emulation using the seedemu library. The emulation should include three layers: base, routing, and eBGP. It should also include a domain name caching service. 

The base layer should create multiple autonomous systems an...

Tokenizing Knowledge Base...
Tokenized KB (1196 valid documents).
Creating BM25 index (k1=1.5, b=0.75)...
BM25 index created.
Tokenizing instruction (query)...
Retrieving top 5 relevant snippets...

--- Top 5 Snippets Retrieved (BM25) ---

--- Snippet 1 (BM25 Rank 1) ---
def makeStubAs(emu: Emulator, base: Base, asn: int, exchange: int, services: ... (total length: 895 characters)

--- Snippet 2 (BM25 Rank 2) ---
def __init__( self, onAsConflict: Callable[[AutonomousSystem, AutonomousSystem], ... (total length: 1088 characters)

--- Snippet 3 (BM25 Rank 3) ---
def install(self, vnode
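The greedy packing described in step 5 of the overview (add retrieved snippets as fenced Python blocks until the token budget runs out, then append the instruction) can be sketched independently of any model. This is not the notebook's `create_sif_prompt`: `pack_snippets`, its header/footer wording and `count_words` are hypothetical, with `count_words` standing in for `len(tokenizer.encode(text))`. Summing per-block counts is only an approximation for real subword tokenizers, which is one reason to keep a margin such as `TOKEN_LIMIT_MARGIN`.

```python
from typing import Callable

FENCE = "`" * 3  # built programmatically so this example's own fence stays intact

def pack_snippets(instruction: str, snippets: list[str], count_tokens: Callable[[str], int],
                  max_prompt_tokens: int = 3500, margin: int = 50) -> str:
    """Greedy packing: add snippets in retrieval order until the next one would not fit."""
    header = "# Relevant code snippets from the target library:\n\n"
    footer = f"# Instruction:\n{instruction}\n\n# Code:\n"
    budget = max_prompt_tokens - margin - count_tokens(header) - count_tokens(footer)
    blocks = []
    for snippet in snippets:
        block = f"{FENCE}python\n{snippet.strip()}\n{FENCE}\n\n"
        cost = count_tokens(block)
        if cost > budget:
            break  # stop at the first snippet that does not fit, keeping the BM25 order intact
        blocks.append(block)
        budget -= cost
    return header + "".join(blocks) + footer

def count_words(text: str) -> int:
    # Stand-in for len(tokenizer.encode(text)); a real tokenizer gives different counts
    return len(text.split())

prompt = pack_snippets("Create an emulation with a base layer.",
                       ["def a(): pass", "def b(): pass", "x " * 100],
                       count_words, max_prompt_tokens=40, margin=5)
print(prompt)
```

Stopping at the first snippet that does not fit (`break`) keeps the context a strict prefix of the BM25 ranking; using `continue` instead would try to squeeze in shorter, lower-ranked snippets.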

In [None]:
# --- 5. SIF Prompt Creation ---

# --- Constants and Configuration ---
# Conservative estimate of the tokens used by the fixed prompt structure
# (can be measured more precisely with the loaded tokenizer)
PROMPT_TEMPLATE_BASE_TOKENS = 100
# Safety margin so the prompt never hits the limit exactly
TOKEN_LIMIT_MARGIN = 50

def create_sif_prompt(
    instruction: str,                          # Original instruction
    retrieved_snippets: list[str],            # List of retrieved snippets (from BM25 or similar)
    tokenizer,                                # Loaded Hugging Face tokenizer instance
    max_prompt_tokens: int = 3500,            # Maximum tokens for the entire prompt
    # model_max_length: Optional[int] = None  # Optional: Model max length (if different)
) -> str:
    """
    Creates a SIF (Snippet Integration Format) prompt optimized for an LLM.

    Integrates retrieved snippets as context for code generation based on the given instruction,
    handling tokenization and truncation.

    Args:
        instruction: The user's instruction.
        retrieved_snippets: List of strings containing the retrieved code snippets.
        tokenizer: The initialized Hugging Face tokenizer instance.
        max_prompt_tokens: The approximate maximum tokens allowed for the final prompt.
                           (Considers the LLM's context window minus the tokens for the response).
        # model_max_length: Optional: The model's absolute max length, if known and different
        #                   from tokenizer.model_max_length.

    Returns:
        The formatted prompt string ready to be passed to the LLM.
        Returns an empty string if the instruction is missing.

    Raises:
        TypeError: If tokenizer is not provided or is invalid.
        ValueError: If max_prompt_tokens is not a positive integer.
    """
    # --- Input Validation ---
    if not isinstance(instruction, str) or not instruction.strip():
        warnings.warn("Missing or empty instruction; returning an empty prompt.")
        return ""
    if tokenizer is None or not hasattr(tokenizer, 'encode'):
        raise TypeError("A valid Hugging Face tokenizer is required for create_sif_prompt.")
    if not isinstance(max_prompt_tokens, int) or max_prompt_tokens <= 0:
        raise ValueError("max_prompt_tokens must be a positive integer.")

    # Determine the effective context limit of the model, if available
    effective_model_max_length = getattr(tokenizer, 'model_max_length', None)
    if effective_model_max_length and max_prompt_tokens > effective_model_max_length:
        warnings.warn(
            f"max_prompt_tokens ({max_prompt_tokens}) exceeds the model's maximum length"
            f" ({effective_model_max_length}). The model limit will take precedence."
        )
        max_prompt_tokens = effective_model_max_length  # enforce the precedence stated in the warning

    # --- Improved Prompt Template ---
    prompt_template = """SYSTEM: You are an expert Python programmer. Generate Python code based ONLY on the user's instruction, using the provided library code snippets for context and correct API usage. Adapt snippets as needed; do not copy them verbatim unless requested.

USER:
### Context: Relevant Code Snippets from Library

{snippets_section}
### Instruction:
{instruction}

ASSISTANT:
```python
"""

    # --- End of Template ---

    # --- Calculating Available Space for Snippets ---
    # Tokenize instruction and base template to know how much space remains
    # Use add_special_tokens=False to count only content tokens
    instruction_tokens = len(tokenizer(instruction, add_special_tokens=False).input_ids)
    template_base_formatted = prompt_template.format(snippets_section="", instruction="")
    template_base_tokens = len(tokenizer(template_base_formatted, add_special_tokens=False).input_ids)

    available_tokens_for_snippets = max(
        0,
        max_prompt_tokens
        - instruction_tokens
        - template_base_tokens
        - TOKEN_LIMIT_MARGIN
    )
    print(f"Token calculation: Total max={max_prompt_tokens}, Instruction={instruction_tokens}, Base template={template_base_tokens}")
    print(f"Available tokens for snippets (approx): {available_tokens_for_snippets}")

    # --- Constructing Snippet Section with Token Checks ---
    snippets_text_parts = []
    accumulated_snippet_tokens = 0
    snippets_included_count = 0

    if not retrieved_snippets:
        warnings.warn("No snippets provided to create_sif_prompt.")
        retrieved_snippets = []  # treat None like an empty list so the loop and counts below still work

    for i, snippet in enumerate(retrieved_snippets):
        if not isinstance(snippet, str) or not snippet.strip():
            continue

        snippet_header = f"# --- Snippet {i+1} ---\n"
        snippet_content = snippet.strip().strip('`')
        if not snippet_content:
            continue
        snippet_formatted = f"```python\n{snippet_content}\n```\n\n"

        # Estimate tokens for this snippet (header + formatted code)
        current_snippet_section_tokens = len(
            tokenizer(snippet_header + snippet_formatted, add_special_tokens=False).input_ids
        )

        # Check if adding this snippet exceeds available space
        if accumulated_snippet_tokens + current_snippet_section_tokens > available_tokens_for_snippets:
            print(
                f"INFO: Token limit for snippets ({available_tokens_for_snippets}) reached. "
                f"Snippet {i+1} and subsequent ones skipped."
            )
            break

        # Add snippet to the prompt
        snippets_text_parts.append(snippet_header)
        snippets_text_parts.append(snippet_formatted)
        accumulated_snippet_tokens += current_snippet_section_tokens
        snippets_included_count += 1

    # Assemble final snippet section
    if snippets_included_count > 0:
        snippets_section_content = "".join(snippets_text_parts).strip()
    else:
        snippets_section_content = "# (No relevant snippets provided or all exceeded token limit)"

    # --- Composing Final Prompt ---
    final_prompt = prompt_template.format(
        snippets_section=snippets_section_content,
        instruction=instruction
    )

    # --- Final Length Check (Optional but Useful) ---
    final_token_count = len(
        tokenizer(final_prompt, add_special_tokens=False).input_ids
    )
    print(f"\nPrompt SIF created.")
    print(f"  Snippets included: {snippets_included_count} / {len(retrieved_snippets)}")
    print(f"  Estimated length (content only): {final_token_count} tokens (Limit set: {max_prompt_tokens})")

    if effective_model_max_length and final_token_count > effective_model_max_length:
        warnings.warn(
            f"The final prompt ({final_token_count} tokens) EXCEEDS the model's maximum length"
            f" ({effective_model_max_length}). It may be truncated or cause errors."
        )
    elif final_token_count > max_prompt_tokens:
        warnings.warn(
            f"The final prompt ({final_token_count} tokens) EXCEEDS the 'max_prompt_tokens' limit"
            f" ({max_prompt_tokens}). The token estimate may be inaccurate."
        )

    return final_prompt

# --- Example Usage ---
print("\n" + "="*40)
print("--- Step 4: Creating RAG Prompt (SIF) ---")
print("="*40)

sif_prompt_final = None

if ('instruction' in locals() and instruction and
    'retrieved_snippets_bm25' in locals() and isinstance(retrieved_snippets_bm25, list) and
    'tokenizer' in locals() and tokenizer):

    prompt_token_limit = 3500
    print(f"Creating SIF prompt with max {prompt_token_limit} tokens...")
    sif_prompt_final = create_sif_prompt(
        instruction=instruction,
        retrieved_snippets=retrieved_snippets_bm25,
        tokenizer=tokenizer,
        max_prompt_tokens=prompt_token_limit
    )

    if sif_prompt_final:
        print("\n--- Preview of Final SIF Prompt (start) ---")
        # Use textwrap.shorten for the preview (note: it collapses whitespace onto a single line)
        print(textwrap.shorten(sif_prompt_final, width=1500, placeholder=" [...]\n```python\n"))  # show the beginning
    else:
        print("[ERROR] Failed to create the SIF prompt (returned empty).")

else:
    missing_vars = []
    if 'instruction' not in locals() or not instruction: missing_vars.append("'instruction'")
    if 'retrieved_snippets_bm25' not in locals() or not isinstance(retrieved_snippets_bm25, list): missing_vars.append("'retrieved_snippets_bm25' (BM25 list)")
    if 'tokenizer' not in locals() or not tokenizer: missing_vars.append("'tokenizer'")
    print(f"[ERROR] Cannot create SIF prompt. Missing or invalid variables: {', '.join(missing_vars)}.")
    print("         Ensure the previous cells (dataset loading, BM25, tokenizer load) ran correctly.")

print("\n--- End of SIF Prompt Creation ---")


--- Step 4: Creating RAG Prompt (SIF) ---
Creating SIF prompt with max 3500 tokens...
Token calculation: Total max=3500, Instruction=158, Base template=69
Available tokens for snippets (approx): 3223

Prompt SIF created.
  Snippets included: 5 / 5
  Estimated length (content only): 959 tokens (Limit set: 3500)

--- Preview of Final SIF Prompt (start) ---
SYSTEM: You are an expert Python programmer. Generate Python code based ONLY on the user's instruction, using the provided library code snippets for context and correct API usage. Adapt snippets as needed; do not copy them verbatim unless requested. USER: ### Context: Relevant Code Snippets from Library # --- Snippet 1 --- ```python def makeStubAs(emu: Emulator, base: Base, asn: int, exchange: int, services: List[Service]): """! @brief create a new stub AS. @param emu reference to the Emulator object. @param base reference to the base layer. @param asn ASN for the newly created AS. @param exchange IXP ID for new newly created AS to jo
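The budget arithmetic inside `create_sif_prompt` can be checked against the log above: 3500 - 158 - 69 - 50 = 3223 tokens remain for snippets. The sketch below restates that rule and the greedy packing loop in isolation. It rests on one stated assumption: `count_tokens` is a hypothetical whitespace counter standing in for the Hugging Face tokenizer, so real token counts will differ.

```python
FENCE = "`" * 3  # three backticks, built programmatically to keep this example's own fence intact


def snippet_budget(max_prompt_tokens: int, instruction_tokens: int,
                   template_tokens: int, margin: int = 50) -> int:
    """Tokens left for snippets, floored at zero (same arithmetic as create_sif_prompt)."""
    return max(0, max_prompt_tokens - instruction_tokens - template_tokens - margin)


def count_tokens(text: str) -> int:
    # Hypothetical stand-in for len(tokenizer(text, add_special_tokens=False).input_ids)
    return len(text.split())


def pack_snippets(snippets: list, budget: int) -> list:
    """Keep snippets in rank order; stop at the first one that would overflow the budget."""
    kept, used = [], 0
    for i, snippet in enumerate(snippets):
        content = snippet.strip()
        if not content:
            continue  # empty snippets are skipped, as in create_sif_prompt
        block = f"# --- Snippet {i + 1} ---\n{FENCE}python\n{content}\n{FENCE}\n\n"
        cost = count_tokens(block)
        if used + cost > budget:
            break  # lower-ranked snippets are dropped even if a shorter one would still fit
        kept.append(block)
        used += cost
    return kept


print(snippet_budget(3500, 158, 69))  # 3223, the value logged above
```

Because the loop stops at the first overflow instead of skipping ahead, a long high-ranked snippet can crowd out shorter lower-ranked ones; this preserves the BM25 ordering at the cost of some unused budget.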

## Section 7 · RAG Code Generation and Output Processing



This is the point where the **Retrieval-Augmented Generation (RAG)** pipeline actually **creates new code** from the prompt assembled in the previous steps.



1 · Generation Setup
* **Configuration parameters**  
  * `MAX_NEW_TOKENS` – upper bound on tokens the model may produce.  
  * `TEMPERATURE` – randomness / creativity; lower → more deterministic.  
  * `TOP_P`, `TOP_K` – nucleus & top-k sampling thresholds.  
  * `REPETITION_PENALTY` – discourages verbatim repetition.  
  * `STOP_ON_EOS`, `STOP_ON_CODE_END` – early-stop on `<eos>` or when the model closes a ``` code block.
* **Advanced stopping criterion**  
  A custom `EosAndCodeStopCriteria` halts decoding as soon as either condition is met, preventing endless or irrelevant output.
* **(Optional) Forced decoder IDs**  
  A commented stub shows how you could force the model to start with a token such as `<think>` to steer generation, but it is **disabled by default**.
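To make the decoding knobs above concrete, the following sketch applies temperature, top-k and top-p (nucleus) filtering to a toy next-token distribution, roughly in the order transformers applies its sampling warpers. It is a simplified standalone illustration, not the library's implementation; `filter_distribution` and the values in `toy_logits` are hypothetical. The repetition penalty, which rescales the logits of already-generated tokens, is omitted.

```python
import math


def filter_distribution(logits: dict, temperature: float = 0.6,
                        top_k: int = 50, top_p: float = 0.95) -> dict:
    """Distribution the next token would be sampled from after the three filters."""
    # 1. Temperature: scale logits before the softmax (lower temperature -> sharper distribution)
    scaled = {tok: value / temperature for tok, value in logits.items()}
    peak = max(scaled.values())
    weights = {tok: math.exp(v - peak) for tok, v in scaled.items()}  # shifted for numerical stability
    ranked = sorted(weights.items(), key=lambda kv: kv[1], reverse=True)

    # 2. Top-k: keep only the k most probable tokens, then renormalise
    ranked = ranked[:top_k]
    total = sum(w for _, w in ranked)
    ranked = [(tok, w / total) for tok, w in ranked]

    # 3. Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    mass = sum(p for _, p in kept)
    return {tok: p / mass for tok, p in kept}


toy_logits = {"def": 5.0, "class": 3.0, "import": 1.0, "print": -2.0}  # hypothetical logits
print(filter_distribution(toy_logits, temperature=1.0, top_p=0.5))  # {'def': 1.0}
```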


2 · Input & Generation
1. **Dependency check** – verifies that `sif_prompt_final`, `model`, and `tokenizer` are present; otherwise the cell aborts with a clear error.
2. **Tokenisation** – the prompt is converted into IDs the model understands and moved to the correct device (GPU/CPU).
3. **Code generation** – `model.generate()` is called with the chosen decoding hyper-parameters and (optionally) the custom stopping list.
4. **Decoding & cleanup**  
   * Newly generated tokens are decoded back to text.  
   * A regex extracts the first ```python …``` block (or a fallback slice) so only the **relevant code** is returned.
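The decoding-and-cleanup step can be isolated as a pure function, which makes it testable without loading a model. The sketch mirrors the regex and the two fallbacks used in the generation cells; `extract_code` is a hypothetical name, and `FENCE` (three backtick characters) is built programmatically only so that this example's own markdown fence stays intact.

```python
import re

FENCE = "`" * 3
# Same pattern as the generation cells: first python-tagged block, closed by a fence or end of text
CODE_BLOCK_RE = re.compile(FENCE + r"python\n(.*?)(?:\n" + FENCE + r"|\Z)", re.DOTALL)


def extract_code(raw_completion: str) -> str:
    """Hypothetical helper mirroring the notebook's cleanup of a decoded completion."""
    match = CODE_BLOCK_RE.search(raw_completion)
    if match:
        return match.group(1).strip()
    closing = "\n" + FENCE
    if closing in raw_completion:
        # The prompt already opened the fence, so the code is whatever precedes the closing one
        return raw_completion.split(closing)[0].strip()
    # No fence at all: keep the whole completion
    return raw_completion.strip()


print(extract_code("x = 1\nprint(x)\n" + FENCE + "\nThis prints 1."))
```

Since the SIF prompt already ends with an opening python fence, a well-behaved completion usually reaches the first fallback; the regex branch matters when the model (for example an R1-style model that reasons first) opens a fresh block of its own.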


3 · Error Handling
* **Out-of-memory (OOM)** – catches `torch.cuda.OutOfMemoryError` and suggests shrinking the prompt or `MAX_NEW_TOKENS`.
* **Unexpected exceptions** – a generic `try/except` prints the full traceback for rapid debugging.


4 · Final Verification
* If `generated_code_rag` **contains code**, the cell announces success.  
* Otherwise it flags a failure, pointing back to missing dependencies or runtime errors.


**In short:** this section feeds the prepared RAG prompt to the language model, retrieves the fresh code it produces, sanitises the output, and robustly reports any issues encountered along the way.


In [None]:
import re  # used by the code-block extraction regex below
import time
import warnings

import torch
from transformers import StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList

# --- 1. Generation Setup ---
# --- 1.1. Configuration Parameters ---
MAX_NEW_TOKENS = 1024      # Max tokens to generate for the response
TEMPERATURE = 0.6          # Recommended value for R1-Distill (0.5-0.7). Lower = more deterministic
TOP_P = 0.95               # Nucleus sampling: keep the smallest set of top tokens whose cumulative probability reaches top_p
TOP_K = 50                 # Top-k sampling (considers only the top k most probable tokens)
REPETITION_PENALTY = 1.1   # Slightly penalize already generated tokens (e.g., 1.1-1.2) to reduce repetition
DO_SAMPLE = True           # Enable sampling (True to use temp/top_p/top_k, False for greedy/deterministic)
STOP_ON_EOS = True         # Stop generation if the EOS token is generated
STOP_ON_CODE_END = True    # Attempt to stop after the end of a code block (e.g. ```)

# --- 1.2. Advanced Stopping Criteria (Optional but Recommended) ---
# Combines EOS stop and, optionally, code block ending

class EosAndCodeStopCriteria(StoppingCriteria):
    """Stop on EOS and/or when the last generated tokens match a stop sequence (assumes batch size 1)."""

    def __init__(self, tokenizer, stop_on_eos=True, stop_sequence="\n```\n"):
        self.tokenizer = tokenizer
        self.stop_on_eos = stop_on_eos
        self.stop_sequence = stop_sequence
        # stop_sequence is None when STOP_ON_CODE_END is False: then only EOS is checked
        self.stop_sequence_ids = (
            self.tokenizer.encode(stop_sequence, add_special_tokens=False) if stop_sequence else []
        )
        # Caveat: IDs encoded in isolation can differ from how the same characters are
        # tokenized in context (e.g. merged with adjacent newlines); verify for your tokenizer
        print(f"Stopping sequence: '{self.stop_sequence}' -> IDs: {self.stop_sequence_ids}")
        print(f"Stop on EOS ({self.tokenizer.eos_token_id}): {self.stop_on_eos}")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # 1. Check EOS
        if self.stop_on_eos and (input_ids[0, -1] == self.tokenizer.eos_token_id):
            print("Stopping criteria: EOS token detected.")
            return True

        # 2. Check whether the last N tokens match the stop sequence (e.g., \n```\n)
        if self.stop_sequence_ids:
            len_stop_seq = len(self.stop_sequence_ids)
            if input_ids.shape[1] >= len_stop_seq:
                last_tokens = input_ids[0, -len_stop_seq:]
                if torch.equal(last_tokens, torch.tensor(self.stop_sequence_ids).to(last_tokens.device)):
                    print(f"Stopping criteria: Stop sequence '{self.stop_sequence}' detected.")
                    return True
        return False

stopping_criteria_list = None
if STOP_ON_EOS or STOP_ON_CODE_END:
    try:
        custom_stopper = EosAndCodeStopCriteria(
            tokenizer,
            stop_on_eos=STOP_ON_EOS,
            stop_sequence="\n```\n" if STOP_ON_CODE_END else None  # \n```\n marks the end of the code block
        )
        stopping_criteria_list = StoppingCriteriaList([custom_stopper])
        print("Custom StoppingCriteria created")
    except Exception as e:
        print(f"WARNING: Unable to create custom StoppingCriteria: {e}")

# --- 1.3. (Optional) Start the response with <think> ---
# R1-Distill usage recommendations suggest that each response begin with "<think>\n".
# CAUTION: the stub below does NOT achieve this for a decoder-only model:
#   - ForcedBOSTokenLogitsProcessor forces a token only at decoding position cur_len == 1,
#     which a multi-token prompt has already passed;
#   - ForcedEOSTokenLogitsProcessor forces EOS at the last position allowed by max_length.
# Both classes would also need to be imported from transformers.
# think_token_sequence = tokenizer.encode("<think>\n", add_special_tokens=False)
# force_think_processor = LogitsProcessorList([
#     ForcedBOSTokenLogitsProcessor(think_token_sequence[0]),
#     ForcedEOSTokenLogitsProcessor(max_length=MAX_NEW_TOKENS + len(think_token_sequence), eos_token_id=think_token_sequence[1:])
# ])
# For now we omit it; if needed, append "<think>\n" to the prompt text manually.

# --- 2. Input & Generation ---

print("\n" + "="*40)
print("--- Step 5: RAG Code Generation ---")
print("="*40)

generated_code_rag = None # Initialize output

# Check dependencies
if 'sif_prompt_final' in locals() and sif_prompt_final and \
   'model' in locals() and model and \
   'tokenizer' in locals() and tokenizer:

    print(f"SIF prompt received (length: {len(sif_prompt_final)} chars).")
    print("Generation parameters:")
    print(f"  max_new_tokens={MAX_NEW_TOKENS}, temperature={TEMPERATURE if DO_SAMPLE else 'N/A (Greedy)'}")
    print(f"  top_p={TOP_P if DO_SAMPLE else 'N/A'}, top_k={TOP_K if DO_SAMPLE else 'N/A'}")
    print(f"  repetition_penalty={REPETITION_PENALTY}")
    print(f"  do_sample={DO_SAMPLE}")
    print(f"  Stopping Criteria: {'Active' if stopping_criteria_list else 'Inactive'}")

    try:
        # --- Tokenization ---
        print("\nTokenizing SIF prompt...")
        # No need to truncate here if create_sif_prompt already handled limits
        # max_length = tokenizer.model_max_length  # Model maximum length
        inputs = tokenizer(
            sif_prompt_final,
            return_tensors="pt",
            # truncation=True,  # Enable only if strictly necessary
            # max_length=max_length - MAX_NEW_TOKENS  # Leave room for generation
        ).to(model.device)  # Move to the model's device (GPU when available)

        input_length = inputs['input_ids'].shape[1]
        print(f"Tokenized input length: {input_length} tokens.")

        # --- Generation ---
        print("Starting code generation...")
        start_time = time.time()

        generation_args = {
            "input_ids": inputs['input_ids'],
            "attention_mask": inputs['attention_mask'],
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id,
            "repetition_penalty": REPETITION_PENALTY,
            "stopping_criteria": stopping_criteria_list  # Can be None
        }
        if DO_SAMPLE:
            generation_args.update({
                "temperature": TEMPERATURE,
                "top_p": TOP_P,
                "top_k": TOP_K,
                "do_sample": True,
            })
        else:
            # Greedy (deterministic) generation
            generation_args["do_sample"] = False
            # temperature, top_p, top_k are not used

        with torch.no_grad():  # Essential for inference
            # outputs = model.generate(**inputs, ...)  # Alternate way
            outputs = model.generate(**generation_args)

        end_time = time.time()
        print(f"Generation completed in {end_time - start_time:.2f} seconds.")

        # --- Decode and Clean Output ---
        # Decode only the NEW generated tokens
        output_tokens = outputs[0, input_length:]
        generated_code_rag_full = tokenizer.decode(output_tokens, skip_special_tokens=True)

        print("\n--- Generated Code (Raw) ---")
        print(generated_code_rag_full[:500] + "..." if len(generated_code_rag_full) > 500 else generated_code_rag_full)

        # --- Specific Cleanup for Code Blocks ---
        # Look for the content inside the first ```python ... ``` block
        # This is more robust than splitting only on ```
        code_block_match = re.search(r'```python\n(.*?)(?:\n```|\Z)', generated_code_rag_full, re.DOTALL)
        if code_block_match:
            generated_code_rag = code_block_match.group(1).strip()
            print("\nExtracted code from the ```python ... ``` block.")
        else:
            # Fallback: if it does not find ```python, take everything before a closing ```
            # or simply take the whole output if there are no backticks.
            if "\n```" in generated_code_rag_full:  # Look for \n``` to avoid inline matches
                generated_code_rag = generated_code_rag_full.split("\n```")[0].strip()
                print("\n```python block not found, taking output before ```." )
            else:
                generated_code_rag = generated_code_rag_full.strip()
                print("\nNo ``` block found, taking the full output.")

        print("\n--- Generated Code (Clean) ---")
        print(generated_code_rag)

    # --- 3. Error Handling ---
    except torch.cuda.OutOfMemoryError as e:
        print(f"\n[ERROR] Out Of Memory (OOM) during GENERATION!")
        print("  The prompt plus the generated output may exceed VRAM.")
        print("  Try reducing 'max_prompt_tokens' in create_sif_prompt or 'MAX_NEW_TOKENS' here.")
        generated_code_rag = None
    except Exception as e:
        import traceback
        print(f"\n[ERROR] Unexpected error during RAG generation:")
        print(traceback.format_exc())
        generated_code_rag = None

else:
    missing = []
    if 'sif_prompt_final' not in locals() or not sif_prompt_final: missing.append("'sif_prompt_final'")
    if 'model' not in locals() or not model: missing.append("'model'")
    if 'tokenizer' not in locals() or not tokenizer: missing.append("'tokenizer'")
    print(f"[ERROR] Unable to perform generation. Missing or invalid variables: {', '.join(missing)}.")
    print("         Make sure the previous cells have been executed correctly.")
    generated_code_rag = None

# --- 4. Final Verification ---
if generated_code_rag:
    print("\n--- RAG code generation completed ---")
    # The variable 'generated_code_rag' contains the cleaned code
else:
    print("\n--- [ERROR] RAG code generation failed or was not executed ---")


Stopping sequence: '
```
' -> IDs: [198, 13874, 3989]
Stop on EOS (151643): True
Custom StoppingCriteria created

--- Step 5: RAG Code Generation ---
[ERROR] Unable to perform generation. Missing or invalid variables: 'model'.
         Make sure the previous cells have been executed correctly.

--- [ERROR] RAG code generation failed or was not executed ---
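`EosAndCodeStopCriteria` compares token IDs: the stop sequence is encoded once in isolation (IDs `[198, 13874, 3989]` in the log above) and matched against the last generated IDs. BPE tokenizers can encode the same characters differently in context, for instance by merging the closing fence with a following blank line; in that case the ID match never fires and decoding runs until EOS or `MAX_NEW_TOKENS`. A text-level check is one alternative, sketched below under the assumption that only the generated part of the sequence is decoded (the SIF prompt itself contains closing fences from the snippets). `hit_code_end` and `TextStopCriteria` are hypothetical names. Recent transformers releases also accept `stop_strings=[...]` together with `tokenizer=...` in `generate()`, which may be simpler where available.

```python
import re

# A closing fence: a line made only of three backticks (optional trailing spaces) followed by a newline.
# Requiring the newline avoids firing while the model is still typing an opening python fence.
CLOSING_FENCE_RE = re.compile(r"^`{3}[ \t]*\n", re.MULTILINE)


def hit_code_end(generated_text: str) -> bool:
    """True once the generated text (prompt excluded) contains a closing code fence."""
    return CLOSING_FENCE_RE.search(generated_text) is not None


# Possible wiring into transformers (sketch, not executed here; batch size 1 assumed):
# class TextStopCriteria(StoppingCriteria):
#     def __init__(self, tokenizer, prompt_length):
#         self.tokenizer, self.prompt_length = tokenizer, prompt_length
#     def __call__(self, input_ids, scores, **kwargs):
#         new_text = self.tokenizer.decode(input_ids[0, self.prompt_length:], skip_special_tokens=True)
#         return hit_code_end(new_text)
```

Decoding the growing tail at every step costs a little extra time, which is usually negligible next to the forward pass for outputs of about `MAX_NEW_TOKENS = 1024`.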


## Section 8 · Baseline Generation and RAG Comparison






In this stage we generate a **baseline**, i.e. code produced by the LLM **without any retrieval augmentation**, and
contrast it with the RAG output obtained earlier.  
The goal is to isolate the model’s native ability, establish a reference point, and quantify the gains introduced by
the Retrieval-Augmented Generation pipeline.


Procedure

1. **Baseline Code Generation**  
   1. A minimal prompt is built that contains only the user instruction.  
   2. The prompt is tokenised and passed to the LLM with the same decoding parameters used for RAG.  
   3. The raw output is decoded and cleaned, stripping headers, back-ticks, or other artefacts so that only executable code remains.

2. **Side-by-side Evaluation**  
   * The baseline output is compared to the RAG output generated in Section 7.  
   * The same evaluation metrics are applied to both versions.  
   * A comparative report highlights improvements or regressions in quality, accuracy, and completeness.
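A small helper can turn the two sets of scores into the side-by-side report described in step 2. This is a sketch: `compare_runs` and `format_report` are hypothetical names, and it assumes each run's scores are collected into a dict keyed by metric name, with `None` for metrics that could not be computed (the convention the notebook's metric helpers use on failure).

```python
from typing import Dict, List, Optional, Tuple

Scores = Dict[str, Optional[float]]  # metric name -> score (None = could not be computed)
Row = Tuple[str, Optional[float], Optional[float], Optional[float]]


def compare_runs(baseline: Scores, rag: Scores) -> List[Row]:
    """One row per metric: (name, baseline, rag, rag - baseline); delta is None if a score is missing."""
    rows = []
    for metric in sorted(set(baseline) | set(rag)):
        b, r = baseline.get(metric), rag.get(metric)
        delta = None if b is None or r is None else r - b
        rows.append((metric, b, r, delta))
    return rows


def format_report(rows: List[Row]) -> str:
    """Fixed-width text table; positive deltas mean RAG scored higher."""
    def fmt(value, signed=False):
        if value is None:
            return "n/a"
        return f"{value:+.2f}" if signed else f"{value:.2f}"

    lines = [f"{'metric':<12}{'baseline':>10}{'rag':>10}{'delta':>10}"]
    for metric, b, r, d in rows:
        lines.append(f"{metric:<12}{fmt(b):>10}{fmt(r):>10}{fmt(d, signed=True):>10}")
    return "\n".join(lines)
```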

Evaluation Metrics

| Metric | Purpose |
|--------|---------|
| **BLEU** | Token-level similarity to the reference implementation. |
| **CodeBLEU** | Code-aware score combining n-gram match, keyword-weighted n-gram match, syntax (AST) match and data-flow match. |
| **Accuracy** | Functional correctness (e.g. pass/fail on test cases). |
| **Completeness** | Whether all requested features from the instruction are implemented. |
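For reference, CodeBLEU is defined as a weighted sum of four component scores; the `weights=(0.25, 0.25, 0.25, 0.25)` passed by `calculate_codebleu` later in the notebook correspond to equal weights:

$$
\mathrm{CodeBLEU} = \alpha\,\mathrm{BLEU} + \beta\,\mathrm{BLEU}_{\mathrm{weight}} + \gamma\,\mathrm{Match}_{\mathrm{ast}} + \delta\,\mathrm{Match}_{\mathrm{df}},
\qquad \alpha = \beta = \gamma = \delta = 0.25
$$

Here $\mathrm{BLEU}$ is the standard n-gram match, $\mathrm{BLEU}_{\mathrm{weight}}$ up-weights language keywords, $\mathrm{Match}_{\mathrm{ast}}$ compares syntax-tree subtrees and $\mathrm{Match}_{\mathrm{df}}$ compares data-flow between candidate and reference.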

Expected Outcome

We anticipate that incorporating RAG will **increase quality and accuracy** versus the standalone
LLM.  
Quantifying these deltas allows us to assess how effective retrieval is for the specific dataset and model used.

Analysis & Conclusions

The results table will be analysed to identify:

* Scenarios where the baseline already excels (little room for RAG improvement).  
* Cases where RAG corrects or enriches the baseline solution.  
* Remaining weaknesses (e.g. cases where poor retrieval hurts generation), which inform future improvements.

Notes & Caveats

* This section assumes that the LLM, tokenizer, dataset, and decoding parameters were successfully initialised in earlier cells.  
* Metric scores will vary with the dataset, the LLM architecture, and hyper-parameters chosen.  
* Always interpret the numbers in the context of your project requirements and evaluation budget.

In [None]:
import torch
import time
import re       # Required for regex cleanup
import textwrap # For prompt preview
import warnings # To handle warnings

print("\n" + "=" * 40)
print("--- Step 6.A: Baseline Generation (LLM-only) ---")
print("=" * 40)
print("NOTE: This cell expects that 'instruction', 'model', 'tokenizer'")
print("      and the generation parameters (MAX_NEW_TOKENS, etc.) have")
print("      been defined in the previous cells (including Step 5).")

generated_code_baseline = None  # Initialize output

# --- 1. Robust Dependency Check ---
# Verify all necessary variables inherited from the previous execution
required_vars = [
    'instruction', 'model', 'tokenizer',
    'MAX_NEW_TOKENS', 'TEMPERATURE', 'TOP_P',
    'TOP_K', 'REPETITION_PENALTY', 'DO_SAMPLE'
]
missing_vars = []
invalid_vars = []

for var_name in required_vars:
    if var_name not in locals():
        missing_vars.append(f"'{var_name}'")
    # Also check that they are not None or empty (where applicable)
    elif var_name in ['instruction', 'model', 'tokenizer'] and not locals()[var_name]:
        invalid_vars.append(f"'{var_name}' (is None or empty)")

# Also verify the stopping criteria (optional, but if it exists it must be used)
# If it doesn't exist from the previous cell, it will be set to None later
stopping_criteria_to_use = locals().get('stopping_criteria_list', None)

# --- 2. Proceed only if all dependencies are OK ---
if not missing_vars and not invalid_vars:

    print("\nAll required variables were found.")

    # --- 3. Baseline Prompt Creation ---
    # Mirror the USER/ASSISTANT layout of the SIF prompt, without the SYSTEM line or snippet context
    baseline_prompt = f"""USER:
### Instruction:
{instruction}

ASSISTANT:
```python
"""
    # Do not print the entire prompt if it is very long
    print("\nBaseline Prompt (start):")
    print(textwrap.shorten(baseline_prompt, width=1200, placeholder="...```python\n"))

    # --- 4. Code Generation ---
    print(f"\nUsing the SAME parameters inherited from the RAG generation:")
    print(f"  max_new_tokens={MAX_NEW_TOKENS}, temperature={TEMPERATURE if DO_SAMPLE else 'N/A (Greedy)'}")
    print(f"  top_p={TOP_P if DO_SAMPLE else 'N/A'}, top_k={TOP_K if DO_SAMPLE else 'N/A'}")
    print(f"  repetition_penalty={REPETITION_PENALTY}")
    print(f"  do_sample={DO_SAMPLE}")
    print(f"  Stopping Criteria: {'Active' if stopping_criteria_to_use else 'Inactive'}")

    try:
        # --- Tokenization ---
        inputs_base = tokenizer(baseline_prompt, return_tensors="pt").to(model.device)
        input_length_base = inputs_base['input_ids'].shape[1]
        print(f"\nTokenized input length: {input_length_base} tokens.")

        # --- model.generate call (Same as RAG except for the input) ---
        print("Starting Baseline generation...")
        start_time = time.time()

        generation_args_base = {
            "input_ids": inputs_base['input_ids'],
            "attention_mask": inputs_base['attention_mask'],
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id,
            "repetition_penalty": REPETITION_PENALTY,
            "stopping_criteria": stopping_criteria_to_use  # Use the same one from RAG (can be None)
        }
        if DO_SAMPLE:
            generation_args_base.update({
                "temperature": TEMPERATURE,
                "top_p": TOP_P,
                "top_k": TOP_K,
                "do_sample": True,
            })
        else:
            generation_args_base["do_sample"] = False

        with torch.no_grad():
            outputs_base = model.generate(**generation_args_base)

        end_time = time.time()
        print(f"Baseline generation completed in {end_time - start_time:.2f} seconds.")

        # --- Decode and Clean (Same logic as RAG) ---
        output_tokens_base = outputs_base[0, input_length_base:]
        generated_code_baseline_full = tokenizer.decode(output_tokens_base, skip_special_tokens=True)

        print("\n--- Baseline Generated Code (Raw) ---")
        print(generated_code_baseline_full[:500] + "..." if len(generated_code_baseline_full) > 500 else generated_code_baseline_full)

        # Cleanup with Regex (identical to RAG)
        code_block_match_base = re.search(r'```python\n(.*?)(?:\n```|\Z)', generated_code_baseline_full, re.DOTALL)
        if code_block_match_base:
            generated_code_baseline = code_block_match_base.group(1).strip()
            print("\nExtracted code from the ```python block.")
        else:
            if "\n```" in generated_code_baseline_full:
                generated_code_baseline = generated_code_baseline_full.split("\n```")[0].strip()
                print("\n```python block not found, took output before ```.")
            else:
                generated_code_baseline = generated_code_baseline_full.strip()
                print("\nNo ``` block found, taking the full output.")

        print("\n--- Generated Code (Baseline LLM-only - Clean) ---")
        print(generated_code_baseline or "[Empty generation]")

    except torch.cuda.OutOfMemoryError as e:
        print(f"\n[ERROR] Out Of Memory (OOM) during BASELINE GENERATION!")
        print("  Try reducing 'MAX_NEW_TOKENS'.")
        generated_code_baseline = None  # Ensure None in case of error
    except Exception as e:
        import traceback
        print(f"\n[ERROR] Unexpected error during Baseline generation:")
        print(traceback.format_exc())
        generated_code_baseline = None  # Ensure None in case of error

else:
    # Print detailed error message
    print("\n[ERROR] Unable to perform Baseline generation.")
    error_msg = "         Issue detected with:"
    if missing_vars:
        error_msg += f" Missing variables: {', '.join(missing_vars)}."
    if invalid_vars:
        error_msg += f" Invalid variables (None/empty): {', '.join(invalid_vars)}."
    print(error_msg)
    print("         Make sure ALL previous cells (data/model loading, RAG generation) executed successfully.")

# --- 5. Final Verification ---
if generated_code_baseline is not None:
    print("\n--- Baseline code generation completed ---")
else:
    print("\n--- [ERROR] Baseline code generation failed or was not executed ---")



--- Step 6.A: Baseline Generation (LLM-only) ---
NOTE: This cell expects that 'instruction', 'model', 'tokenizer'
      and the generation parameters (MAX_NEW_TOKENS, etc.) have
      been defined in the previous cells (including Step 5).

[ERROR] Unable to perform Baseline generation.
         Issue detected with: Invalid variables (None/empty): 'model' (is None or empty).
         Make sure ALL previous cells (data/model loading, RAG generation) executed successfully.

--- [ERROR] Baseline code generation failed or was not executed ---
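The metric cell that follows computes ChrF with `sacrebleu.corpus_chrf`. To make the score's meaning concrete, here is a simplified pure-Python sketch: it removes whitespace, counts character n-grams for n = 1..6, averages precision and recall over the orders present in both strings, and combines them with an F-score that weights recall twice as much as precision (beta = 2). Smoothing and the treatment of missing orders are simplified, so values may differ slightly from sacrebleu's; `char_ngrams` and `simple_chrf` are hypothetical names.

```python
from collections import Counter


def char_ngrams(text: str, n: int) -> Counter:
    # Like sacrebleu's default ChrF, whitespace is removed before counting
    chars = "".join(text.split())
    return Counter(chars[i:i + n] for i in range(len(chars) - n + 1))


def simple_chrf(hypothesis: str, reference: str, max_order: int = 6, beta: float = 2.0) -> float:
    """Simplified ChrF on a 0-100 scale (illustrative, not sacrebleu's exact algorithm)."""
    precisions, recalls = [], []
    for n in range(1, max_order + 1):
        hyp, ref = char_ngrams(hypothesis, n), char_ngrams(reference, n)
        if not hyp or not ref:
            continue  # this order is longer than one of the strings
        overlap = sum((hyp & ref).values())  # clipped n-gram matches
        precisions.append(overlap / sum(hyp.values()))
        recalls.append(overlap / sum(ref.values()))
    if not precisions:
        return 0.0
    p = sum(precisions) / len(precisions)
    r = sum(recalls) / len(recalls)
    if p + r == 0:
        return 0.0
    # F-beta with beta=2 weights recall twice as much as precision
    return 100 * (1 + beta ** 2) * p * r / (beta ** 2 * p + r)


print(round(simple_chrf("def foo(x): return x", "def bar(y): return y"), 2))
```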


In [None]:
import textwrap
import json
import os
import re
import torch  # Make sure torch is imported

# --- Metric Setup ---
# --- ChrF Function ---
try:
    import sacrebleu
    print(f"sacrebleu version: {getattr(sacrebleu, '__version__', 'N/A')}")  # Print the version for debugging

    def calculate_chrf(prediction, reference):
        """Compute ChrF with sacrebleu (default word_order=0, i.e. ChrF rather than ChrF++), handling errors and types."""
        # Input validation
        if not isinstance(prediction, str) or not isinstance(reference, str):
            print("ChrF error: prediction or reference is not a string.")
            return None  # Or 0.0 if you prefer
        if not prediction or not reference:
            # Return 0.0 if either string is empty, as in the original behaviour.
            # Return None instead if you need to distinguish a zero score from empty input.
            return 0.0

        try:
            # Call format for sacrebleu >= 2.0.0: a list of hypotheses and a list of reference streams
            score = sacrebleu.corpus_chrf([prediction], [[reference]]).score
            return score
        except Exception as e:
            import traceback
            print("--- ERROR while computing ChrF ---")
            # Print more details for debugging
            pred_type = type(prediction).__name__
            ref_type = type(reference).__name__
            pred_preview = str(prediction)[:100] + '...' if prediction else 'None'
            ref_preview = str(reference)[:100] + '...' if reference else 'None'
            print(f"  Error: {e}")
            # print(traceback.format_exc())  # Uncomment for the full stack trace if needed
            print(f"  Prediction (type {pred_type}): {pred_preview}")
            print(f"  Reference (type {ref_type}): {ref_preview}")
            return None  # None signals that the computation failed
except ImportError:
    print("WARNING: Libreria 'sacrebleu' non trovata. ChrF non sarà calcolato.")
    # Restituisce None per coerenza
    calculate_chrf = lambda p, r: None

# --- CodeBLEU function ---
try:
    import codebleu
    from codebleu import calc_codebleu
    # getattr avoids an ImportError if the package does not expose __version__
    print(f"CodeBLEU version: {getattr(codebleu, '__version__', 'N/A')}")  # print the version for debugging

    def calculate_codebleu(prediction, reference, lang="python", weights=(0.25, 0.25, 0.25, 0.25)):
        """Compute CodeBLEU with input validation, error handling and the correct argument layout."""
        # Input validation
        if not isinstance(prediction, str) or not isinstance(reference, str):
            print("CodeBLEU error: prediction or reference is not a string.")
            return None
        # Empty inputs (CodeBLEU may raise errors or return 0)
        if not prediction:
            print("CodeBLEU warning: empty prediction.")
            # Returning 0.0 could make sense here, but None means "failed / not computed"
            # return 0.0
        if not reference:
            print("CodeBLEU warning: empty reference.")
            # return 0.0

        try:
            # Argument layout: references is a list of reference lists (one per prediction),
            # predictions is a flat list
            result_dict = calc_codebleu(
                references=[[reference]],
                predictions=[prediction],
                lang=lang,
                weights=weights,
            )
            # Extract the composite 'codebleu' score; None if the key is missing
            return result_dict.get('codebleu', None)

        except TypeError as e:
            # Common symptom of a misconfigured tree-sitter
            print(f"--- CodeBLEU ERROR (TypeError): {e} ---")
            print("This often indicates an unresolved tree-sitter configuration problem.")
            print("  Check that the tree-sitter setup/test cell completes without errors.")
            pred_preview = str(prediction)[:100] + '...' if prediction else 'None'
            ref_preview = str(reference)[:100] + '...' if reference else 'None'
            print(f"  Prediction (type {type(prediction).__name__}): {pred_preview}")
            print(f"  Reference (type {type(reference).__name__}): {ref_preview}")
            return None  # None signals that the computation failed
        except Exception:
            # Catch any other error
            import traceback
            print("--- UNEXPECTED ERROR while computing CodeBLEU ---")
            print(traceback.format_exc())  # full stack trace for debugging
            print("-----------------------------------------------------")
            pred_preview = str(prediction)[:100] + '...' if prediction else 'None'
            ref_preview = str(reference)[:100] + '...' if reference else 'None'
            print(f"  Prediction (type {type(prediction).__name__}): {pred_preview}")
            print(f"  Reference (type {type(reference).__name__}): {ref_preview}")
            return None  # None signals that the computation failed

except ImportError:
    print("WARNING: 'codebleu' library not found. CodeBLEU will not be computed.")
    # Return None for consistency; accept the same optional arguments as the real function
    calculate_codebleu = lambda p, r, lang="python", weights=None: None

# --- API Recall function (with type checks) ---
def calculate_api_recall(generated_code, reference_apis):
    """Compute API recall (share of reference APIs found in the generated code), with type validation."""
    # Input validation
    if not isinstance(generated_code, str):
        print("API Recall error: 'generated_code' is not a string.")
        return 0.0  # or None, if you prefer
    if not isinstance(reference_apis, list):
        print("API Recall error: 'reference_apis' is not a list.")
        return 0.0  # or None
    if not generated_code or not reference_apis:  # recall is 0 if either is empty
        return 0.0

    present_apis = 0
    valid_ref_apis_count = 0  # count only the valid APIs in the reference
    for api_call in reference_apis:
        # Validate the type of each list element
        if not isinstance(api_call, str):
            # print(f"API Recall warning: non-string element in the reference API list: {api_call}")
            continue  # skip this invalid element
        if not api_call.strip():  # skip empty or whitespace-only strings
            continue

        valid_ref_apis_count += 1  # counted only for valid reference APIs
        try:
            # Word boundaries so that only whole names match
            pattern = r'\b' + re.escape(api_call) + r'\b'
            if re.search(pattern, generated_code):
                present_apis += 1
        except re.error as e:
            print(f"API Recall warning: skipped invalid regex pattern for API '{api_call}'. Error: {e}")

    # Recall over the number of *valid* reference APIs
    recall = present_apis / valid_ref_apis_count if valid_ref_apis_count > 0 else 0.0
    return recall

print("\nMetric functions (calculate_chrf, calculate_codebleu, calculate_api_recall) defined/updated.")

Sacrebleu version: 2.5.1

Metric functions (calculate_chrf, calculate_codebleu, calculate_api_recall) defined/updated.
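`calculate_api_recall` counts a reference API only when it appears as a whole word, using `\b` boundaries around the escaped name. The sketch below reproduces that matching rule so its edge cases can be checked in isolation. It uses a hypothetical, condensed `api_recall` helper that is not part of the notebook.

```python
import re


def api_recall(generated_code: str, reference_apis: list) -> float:
    """Fraction of non-empty string API names that occur as whole words in the code."""
    valid = [api for api in reference_apis if isinstance(api, str) and api.strip()]
    if not generated_code or not valid:
        return 0.0
    hits = sum(
        1 for api in valid
        if re.search(r"\b" + re.escape(api) + r"\b", generated_code)
    )
    return hits / len(valid)


code = "df = load_data(path)\nmodel.fit(df)"
print(api_recall(code, ["load", "load_data", "fit"]))  # 2 of 3 names found
```

`load` is not counted, because the underscore after it is a word character and so no boundary exists inside `load_data`. A reference entry ending in punctuation, such as `fit()`, usually fails to match. The trailing `\b` would then need a word character right after the closing parenthesis. If the dataset's `unique_apis` ever contains such entries, they need a different boundary rule.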


## Section 9 · Metrics Results
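`calculate_codebleu` returns the composite `codebleu` value. It is a weighted sum of four component scores: plain n-gram BLEU, a keyword-weighted n-gram BLEU, an AST (syntax) match and a data-flow match. As we understand the `codebleu` package, the four entries of the `weights` argument map to $\alpha, \beta, \gamma, \delta$ in this order. The default `(0.25, 0.25, 0.25, 0.25)` is therefore a plain average:

$$
\text{CodeBLEU} = \alpha \cdot \text{BLEU} + \beta \cdot \text{BLEU}_{\text{weight}} + \gamma \cdot \text{Match}_{\text{ast}} + \delta \cdot \text{Match}_{\text{df}}, \qquad \alpha = \beta = \gamma = \delta = 0.25
$$

Every term lies in $[0, 1]$, so CodeBLEU and API Recall share a 0 to 1 scale, while sacrebleu reports ChrF on a 0 to 100 scale. Compare Baseline and RAG within a metric, not across metrics.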


In [None]:
# Step 6.B: Compute comparison metrics

import textwrap
import json
import os
import re

# The metric functions (calculate_api_recall, calculate_chrf, calculate_codebleu)
# must already be defined (at the start of Step 6 or above)

print("\n" + "="*30)
print("--- Step 6.B: Computing comparison metrics ---")
print("="*30)

# Dictionary holding the computed metrics
metrics = {"api_recall": {}, "chrf": {}, "codebleu": {}}
calculation_possible = False  # whether we can compute anything at all

# Check that the generated outputs and the dataset are available
if ('lca_dataset_split' in locals() and lca_dataset_split and
    'generated_code_rag' in locals() and generated_code_rag is not None and
    'generated_code_baseline' in locals() and generated_code_baseline is not None):

    sample_index = 0  # index of the sample being evaluated
    sample = lca_dataset_split[sample_index]
    reference_code = sample.get('reference')  # reference code
    reference_apis = sample.get('unique_apis')  # reference API list

    print(f"Comparing Sample {sample_index}...")

    if not reference_code:
        print("WARNING: reference code not found in the dataset. ChrF and CodeBLEU cannot be computed.")
    if not reference_apis:
        print("WARNING: reference API list not found in the dataset. API Recall cannot be computed.")

    calculation_possible = True  # we can try to compute at least some metrics

else:
    print("ERROR: cannot compute metrics.")
    print("Check that 'lca_dataset_split', 'generated_code_rag' and 'generated_code_baseline' are defined.")

# Compute the metrics only if possible
if calculation_possible:

    # --- API Recall ---
    print("\nComputing API Recall...")
    if reference_apis:
        metrics["api_recall"]["baseline"] = calculate_api_recall(generated_code_baseline, reference_apis)
        metrics["api_recall"]["rag"] = calculate_api_recall(generated_code_rag, reference_apis)
    else:
        metrics["api_recall"]["baseline"] = None
        metrics["api_recall"]["rag"] = None

    # --- ChrF ---
    print("Computing ChrF...")
    if reference_code and 'calculate_chrf' in locals():
        metrics["chrf"]["baseline"] = calculate_chrf(generated_code_baseline, reference_code)
        metrics["chrf"]["rag"] = calculate_chrf(generated_code_rag, reference_code)
    else:
        metrics["chrf"]["baseline"] = None
        metrics["chrf"]["rag"] = None

    # --- CodeBLEU ---
    print("Computing CodeBLEU...")
    if reference_code and 'calculate_codebleu' in locals():
        # Normalise falsy predictions to "" so calculate_codebleu always receives a string
        pred_baseline = generated_code_baseline if generated_code_baseline else ""
        pred_rag = generated_code_rag if generated_code_rag else ""
        metrics["codebleu"]["baseline"] = calculate_codebleu(pred_baseline, reference_code)
        metrics["codebleu"]["rag"] = calculate_codebleu(pred_rag, reference_code)
    else:
        metrics["codebleu"]["baseline"] = None
        metrics["codebleu"]["rag"] = None

    # --- Print metric results ---
    def fmt_score(value):
        """Format a score for the table; None or missing values become 'N/A'."""
        # Formatting None with ':.4f' or ':<15' would raise a TypeError
        return f"{value:.4f}" if isinstance(value, (int, float)) else "N/A"

    print("\n--- Automatic metric results ---")
    print("| Metric          | Baseline        | RAG             |")
    print("|-----------------|-----------------|-----------------|")
    api_ref_count = len(reference_apis) if reference_apis else 0
    print(f"| API Recall      | {fmt_score(metrics['api_recall'].get('baseline')):<15} | {fmt_score(metrics['api_recall'].get('rag')):<15} | (Ref APIs: {api_ref_count})")
    print(f"| ChrF (0-100)    | {fmt_score(metrics['chrf'].get('baseline')):<15} | {fmt_score(metrics['chrf'].get('rag')):<15} |")
    print(f"| CodeBLEU        | {fmt_score(metrics['codebleu'].get('baseline')):<15} | {fmt_score(metrics['codebleu'].get('rag')):<15} |")
    print("-" * 50)

else:
    print("Metric computation skipped because of missing data.")


--- Step 6.B: Computing comparison metrics ---
ERROR: cannot compute metrics.
Check that 'lca_dataset_split', 'generated_code_rag' and 'generated_code_baseline' are defined.
Metric computation skipped because of missing data.
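The next cell flags a RAG failure when RAG scores lower than the Baseline. It checks API Recall first and CodeBLEU only when API Recall did not already flag a failure; missing or `None` scores are ignored. The decision can be isolated as a pure function and tested without the model or dataset. The sketch below roughly mirrors that rule in a hypothetical `detect_rag_failure` helper. One assumption differs: it accepts any real number for both metrics, where the cell checks `float` only for CodeBLEU.

```python
from numbers import Real

# (metrics key, label used in failure_reasons), in the order they are checked
CHECKS = (("api_recall", "API Recall"), ("codebleu", "CodeBLEU"))


def detect_rag_failure(metrics: dict) -> list:
    """Return the failure reasons (empty list means no failure detected)."""
    for key, label in CHECKS:
        scores = metrics.get(key) or {}
        rag, baseline = scores.get("rag"), scores.get("baseline")
        # Compare only when both scores are numeric; None means "not computed"
        if isinstance(rag, Real) and isinstance(baseline, Real) and rag < baseline:
            return [f"{label} RAG < Baseline"]  # stop at the first failing metric
    return []


print(detect_rag_failure({"api_recall": {"rag": 0.5, "baseline": 0.75}}))
```

Returning a list rather than a boolean keeps the same shape as `failure_reasons`. This makes it easy to add the commented-out ChrF check as a third entry in `CHECKS` later, although ChrF also needs its own margin.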


In [None]:
# Step 6.C: Direction 2 analysis (RAG failures) and saving

import textwrap
import json
import os

print("\n" + "="*30)
print("--- Step 6.C: Direction 2 analysis and saving ---")
print("="*30)

# Check that the metrics and the other inputs are available
analysis_possible = (
    'metrics' in locals() and metrics and  # the metrics dictionary exists
    'lca_dataset_split' in locals() and lca_dataset_split and
    'generated_code_rag' in locals() and generated_code_rag is not None and
    'generated_code_baseline' in locals() and generated_code_baseline is not None and
    'instruction' in locals() and instruction and
    'retrieved_snippets' in locals() and
    'drive_save_path' in locals()  # needed for the failure log and the result files
)

if analysis_possible:
    sample_index = 0  # sample index
    sample = lca_dataset_split[sample_index]
    repo_full_name = sample.get('repo_full_name')

    # --- 1. RAG failure analysis (Direction 2) ---
    print("\n--- RAG performance analysis (Direction 2) ---")
    rag_failed = False
    failure_reasons = []

    # Decide whether RAG failed, based on the Step 6.B metrics:
    # only numeric scores are compared, and RAG fails if it scores lower than the Baseline
    try:
        # API Recall
        recall_rag = metrics.get("api_recall", {}).get("rag")
        recall_baseline = metrics.get("api_recall", {}).get("baseline")
        if isinstance(recall_rag, (int, float)) and isinstance(recall_baseline, (int, float)):
            if recall_rag < recall_baseline:
                rag_failed = True
                failure_reasons.append("API Recall RAG < Baseline")

        # CodeBLEU (only if not already failed and both scores are valid)
        if not rag_failed:
            codebleu_rag = metrics.get("codebleu", {}).get("rag")
            codebleu_baseline = metrics.get("codebleu", {}).get("baseline")
            # calculate_codebleu returns None on failure, so check for float
            if isinstance(codebleu_rag, float) and isinstance(codebleu_baseline, float):
                if codebleu_rag < codebleu_baseline:
                    rag_failed = True
                    failure_reasons.append("CodeBLEU RAG < Baseline")

        # ChrF (less indicative; enable it if needed)
        # if not rag_failed:
        #     chrf_rag = metrics.get("chrf", {}).get("rag")
        #     chrf_baseline = metrics.get("chrf", {}).get("baseline")
        #     if isinstance(chrf_rag, (int, float)) and isinstance(chrf_baseline, (int, float)):
        #         if chrf_rag < chrf_baseline - 5:  # e.g. ChrF worse by more than 5 points
        #             rag_failed = True
        #             failure_reasons.append("ChrF RAG < Baseline (significantly)")

    except Exception as e:
        print(f"- Error while comparing metrics for the failure analysis: {e}")

    # If RAG is considered to have failed, guide the analysis
    if rag_failed:
        print(f"\n**ANALYSIS NEEDED: RAG appears to have performed worse than the Baseline on Sample {sample_index}.**")
        print(f"  Detected reason(s): {', '.join(failure_reasons)}")

        print("\n  **1. Inspect the retrieved snippets (printed below):**")
        if retrieved_snippets:
            for i, snippet in enumerate(retrieved_snippets):
                print(f"\n  --- Snippet {i+1} ---")
                snippet_preview = '\n'.join(snippet.splitlines()[:10])  # show the first 10 lines
                if len(snippet_preview) > 400:
                    snippet_preview = snippet_preview[:400] + "..."
                elif len(snippet) > len(snippet_preview):
                    snippet_preview += "\n..."
                print(textwrap.indent(snippet_preview, '    '))  # indent for readability
        else:
            print("    (No snippets retrieved)")

        print("\n  **2. Key points to analyse:**")
        print("     - Retrieval quality (BM25): were the snippets actually useful/relevant to the instruction?")
        print("     - Snippet correctness: did they contain errors or were they misleading?")
        print("     - LLM integration: did the RAG model ignore or misuse the snippets?")
        print("     - Likely cause: is the problem in the retriever (BM25) or in the generator (LLM)?")

        # Log the failure case for later analysis
        failure_log_path = os.path.join(drive_save_path, 'rag_failures.jsonl')
        failure_analysis_data = {
            "sample_index": sample_index,
            "repo_full_name": repo_full_name,
            "failure_reasons": failure_reasons,
            "instruction": instruction,
            "retrieved_snippets": retrieved_snippets,  # store the full snippets in the log
            "generated_code_rag": generated_code_rag,
            "generated_code_baseline": generated_code_baseline,
            "metrics": metrics  # store all computed metrics
        }
        try:
            with open(failure_log_path, 'a', encoding='utf-8') as f:  # append mode
                f.write(json.dumps(failure_analysis_data, default=str) + '\n')  # one JSON object per line
            print(f"\n  --> RAG failure case logged to: {failure_log_path}")
        except Exception as e:
            print(f"\n  ERROR while logging the failure case: {e}")

    else:
        print("\nRAG performance analysis: no clear sign of failure against the Baseline based on the current metrics.")

    # --- 2. Save the overall results for this sample ---
    print("\n--- Saving overall results ---")
    results_save_path = os.path.join(drive_save_path, 'results')
    os.makedirs(results_save_path, exist_ok=True)
    # repo_full_name may contain "/" (e.g. "owner/repo"): replace it so it does not create a subdirectory
    safe_repo_name = str(repo_full_name).replace('/', '__')
    result_filename = f"result_sample_{sample_index}_{safe_repo_name}.json"  # descriptive file name
    result_full_path = os.path.join(results_save_path, result_filename)

    # Prepare the data to save (previews only, to keep the file small)
    reference_code = sample.get('reference')  # fetch again in case it is not defined
    result_data = {
        "sample_index": sample_index,
        "repo_full_name": repo_full_name,
        "instruction": instruction,
        # Store only previews of the snippets and of the reference in the main result
        "retrieved_snippets_preview": [s[:300] + ("..." if len(s) > 300 else "") for s in retrieved_snippets] if retrieved_snippets else [],
        "generated_code_baseline": generated_code_baseline,
        "generated_code_rag": generated_code_rag,
        "reference_code_preview": textwrap.shorten(reference_code or "N/A", width=600, placeholder="..."),
        "metrics": metrics,  # full metrics dictionary
        "rag_failed_analysis_triggered": rag_failed
    }

    try:
        with open(result_full_path, 'w', encoding='utf-8') as f:
            # default=str converts values json cannot serialise natively (e.g. NumPy scalars)
            json.dump(result_data, f, indent=2, default=str)
        print(f"Detailed results for Sample {sample_index} saved to: {result_full_path}")
    except Exception as e:
        print(f"\nERROR while saving the detailed results: {e}")

else:
    print("\n" + "="*30)
    print("--- Analysis/saving skipped ---")
    print("Cannot proceed: data from previous steps is missing (metrics, generated outputs, etc.).")
    print("="*30)

print("\n--- End of Step 6.C ---")


--- Step 6.C: Direction 2 analysis and saving ---

--- Analysis/saving skipped ---
Cannot proceed: data from previous steps is missing (metrics, generated outputs, etc.).

--- End of Step 6.C ---
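Step 6.C appends one JSON object per line to `rag_failures.jsonl`, so failure cases can be aggregated after several samples have been evaluated. The sketch below is a minimal version that assumes the record layout written above, in particular the `failure_reasons` list. The `summarize_failure_log` helper is hypothetical and not part of the notebook.

```python
import json
from collections import Counter
from pathlib import Path


def summarize_failure_log(log_path: str) -> Counter:
    """Count failure reasons across all records of a JSON Lines failure log."""
    counts = Counter()
    path = Path(log_path)
    if not path.exists():
        return counts  # nothing has been logged yet
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            counts.update(record.get("failure_reasons", []))
    return counts
```

Usage, assuming `drive_save_path` is defined as in the earlier cells: `summarize_failure_log(os.path.join(drive_save_path, 'rag_failures.jsonl')).most_common()`. Because the log is opened in append mode, re-running Step 6.C on the same sample writes a duplicate record. Deduplicate on `sample_index` first if exact counts matter.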
