# RAG Code Generation System with ChromaDB

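This notebook implements a retrieval-augmented generation (RAG) system for Python code generation. It embeds the 164 HumanEval task prompts, stores them in **ChromaDB** for persistent vector storage, retrieves the most similar tasks for a new prompt, and passes their canonical solutions to an LLM (via OpenRouter) as few-shot examples.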

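## Advantages of ChromaDB
- **Persistent Storage**: Data persists between sessions
- **Built-in Metadata**: Easy metadata management
- **Simple API**: Documents, embeddings, and metadata are stored together and queried in one call (a FAISS index stores only vectors, so documents and metadata need separate bookkeeping)
- **No Manual Index Saving**: With a persistent client (`chromadb.PersistentClient`), the collection is written to disk automatically, with no explicit save step like `faiss.write_index`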

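## Table of Contents
1. Installation
2. Import Libraries
3. Setup API Key
4. Dataset Functions
5. ChromaDB Functions
6. Code Generation Functions
7. Main Pipeline Setup
8. Utility Functions
9. Initialize the Pipeline
10. Example 1: Maximum Subarray Sum
11. Example 2: Remove Duplicates
12. Example 3: Binary Search
13. Inspect ChromaDB Collection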

## 1. Installation

In [1]:
%pip install datasets sentence-transformers chromadb openai python-dotenv torch -q

Note: you may need to restart the kernel to use updated packages.





## 2. Import Libraries

In [1]:
import os
from typing import List, Dict
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from getpass import getpass

print("✓ All libraries imported successfully")


✓ All libraries imported successfully


## 3. Setup API Key

In [2]:
# Enter your OpenRouter API key
OPENROUTER_API_KEY = getpass("Enter your OpenRouter API key: ")

# Or set it directly
# OPENROUTER_API_KEY = "your-api-key-here"
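To avoid typing the key on every run or hard-coding it in the notebook, the key can be read from an environment variable first, with `getpass` as a fallback. The sketch below is a hypothetical helper (`get_api_key` is not part of the original notebook); it assumes the variable is named `OPENROUTER_API_KEY`. Since `python-dotenv` is already in the install list, calling `from dotenv import load_dotenv; load_dotenv()` beforehand would populate the environment from a local `.env` file.

```python
import os
from getpass import getpass


def get_api_key(env_var: str = "OPENROUTER_API_KEY") -> str:
    """Return the API key from the environment, prompting only if it is unset or blank.

    `get_api_key` is a hypothetical helper; the variable name is an assumption.
    """
    key = os.environ.get(env_var, "").strip()
    if key:
        return key
    # Fall back to an interactive prompt (the typed key is not echoed)
    return getpass(f"Enter your {env_var}: ")
```

Usage would be `OPENROUTER_API_KEY = get_api_key()` in place of the bare `getpass` call.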

## 4. Dataset Functions

In [3]:
def load_humaneval_dataset():
    """Load and process the HumanEval dataset."""
    print("Loading HumanEval dataset...")
    dataset = load_dataset("openai/openai_humaneval", split="test")
    
    examples = []
    for item in dataset:
        examples.append({
            'task_id': item['task_id'],
            'prompt': item['prompt'],
            'canonical_solution': item['canonical_solution'],
            'entry_point': item['entry_point']
        })
    
    print(f"✓ Loaded {len(examples)} examples")
    return examples

## 5. ChromaDB Functions

In [5]:
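def create_chroma_client(persist_directory="./chroma_codegen"):
    """Create and return a ChromaDB client that persists to disk."""
    print(f"Creating ChromaDB client with persist directory: {persist_directory}")
    # chromadb.Client() is in-memory (ephemeral) by default in chromadb >= 0.4;
    # PersistentClient writes the collection to `persist_directory` so it survives restarts.
    client = chromadb.PersistentClient(
        path=persist_directory,
        settings=Settings(anonymized_telemetry=False)
    )
    print("✓ ChromaDB client created")
    return client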


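def setup_chroma_collection(client, examples, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2", force_reload=False):
    """Create (or reuse) the HumanEval collection and return it with the embedding model.

    Embeddings are computed with SentenceTransformer and passed to ChromaDB explicitly,
    so the collection's default embedding function is not used for adding or querying.
    """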
    collection = client.get_or_create_collection(
        name="humaneval_code_examples",
        metadata={"description": "HumanEval code examples for RAG"}
    )
    
    # Load embedding model
    print(f"Loading embedding model: {embedding_model_name}")
    embedding_model = SentenceTransformer(embedding_model_name)
    
    # Check if collection already has data
    if collection.count() > 0 and not force_reload:
        print(f"✓ Collection already contains {collection.count()} examples")
        print("   Set force_reload=True to reload data")
        return collection, embedding_model
    
    # If force reload, clear collection
    if force_reload and collection.count() > 0:
        print("Clearing existing collection...")
        client.delete_collection("humaneval_code_examples")
        collection = client.create_collection(
            name="humaneval_code_examples",
            metadata={"description": "HumanEval code examples for RAG"}
        )
    
    # Prepare data for ChromaDB
    print("Preparing data...")
    ids = []
    documents = []
    metadatas = []
    
    for example in examples:
        ids.append(example['task_id'])
        documents.append(example['prompt'])
        metadatas.append({
            'task_id': example['task_id'],
            'entry_point': example['entry_point'],
            'canonical_solution': example['canonical_solution']
        })
    
    # Generate embeddings
    print("Creating embeddings...")
    embeddings_list = embedding_model.encode(documents, show_progress_bar=True).tolist()
    
    # Add to ChromaDB
    print("Adding to ChromaDB collection...")
    collection.add(
        ids=ids,
        documents=documents,
        embeddings=embeddings_list,
        metadatas=metadatas
    )
    
    print(f"✓ Collection contains {collection.count()} examples")
    return collection, embedding_model


def retrieve_similar(query, collection, embedding_model, n_results=3):
    """Retrieve similar code examples from ChromaDB."""
    # Generate query embedding
    query_embedding = embedding_model.encode([query]).tolist()
    
    # Search in ChromaDB
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=n_results,
        include=["documents", "metadatas", "distances"]
    )
    
    # Format results
    similar_examples = []
    for i in range(len(results['ids'][0])):
        similar_examples.append({
            'task_id': results['ids'][0][i],
            'prompt': results['documents'][0][i],
            'canonical_solution': results['metadatas'][0][i]['canonical_solution'],
            'entry_point': results['metadatas'][0][i]['entry_point'],
            'distance': results['distances'][0][i]
        })
    
    return similar_examples
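The `distance` values returned by `retrieve_similar` are easier to read as similarities. Assuming the collection uses Chroma's default `"l2"` space (squared Euclidean distance) and that `all-MiniLM-L6-v2` produces unit-length embeddings, the identity ||a − b||² = 2 − 2·cos(a, b) gives cos = 1 − d/2. The sketch below checks the identity on two small unit vectors; the helper names are illustrative, not part of Chroma's API.

```python
import math
from typing import Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance (the metric behind Chroma's default "l2" space)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain cosine similarity, used here only to verify the identity."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def l2_distance_to_cosine(distance: float) -> float:
    """Assumes unit-length vectors: ||a - b||^2 = 2 - 2*cos(a, b), so cos = 1 - d/2."""
    return 1.0 - distance / 2.0


# Two unit vectors: squared L2 is about 0.8 and cosine similarity is about 0.6
a, b = [1.0, 0.0], [0.6, 0.8]
print(squared_l2(a, b), cosine_similarity(a, b), l2_distance_to_cosine(squared_l2(a, b)))
```

Under those assumptions, the distance of 0.2354 reported for HumanEval/26 in Example 2 corresponds to a cosine similarity of roughly 0.88, and the 0.7802 reported for HumanEval/120 in Example 1 to roughly 0.61.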

## 6. Code Generation Functions

In [7]:
def build_context(examples):
    """Build context string from retrieved examples."""
    context_parts = []
    for i, ex in enumerate(examples, 1):
        context_parts.append(f"Example {i}:")
        context_parts.append(f"Task: {ex['prompt'].strip()}")
        context_parts.append(f"Solution:\n{ex['canonical_solution'].strip()}")
        context_parts.append("")
    return "\n".join(context_parts)


def extract_code(response):
    """Extract code from an LLM response.

    Returns the contents of the first ```python (or bare ```) fence,
    or the whole response if there is no fence.
    """
    for fence in ("```python", "```"):
        start = response.find(fence)
        if start != -1:
            start += len(fence)
            end = response.find("```", start)
            # A missing closing fence (e.g. output cut off by max_tokens) means "take the rest"
            if end == -1:
                end = len(response)
            return response[start:end].strip()
    return response.strip()


def generate_code(task_description, retrieved_examples, api_key,
                 model, max_tokens=500, temperature=0.2):
    """Generate code using OpenRouter API."""
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key
    )
    
    context = build_context(retrieved_examples)
    
    prompt = f"""Based on the following examples of Python coding tasks and solutions, generate a complete function for the new task.

{context}

New Task:
{task_description}

Generate a complete, working Python function that solves this task. Include the function signature and implementation. Only return the code, no explanations."""
    
    print(" Generating code...")
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "You are an expert Python programmer. Generate clean, efficient, and well-documented code."
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        max_tokens=max_tokens,
        temperature=temperature
    )
    
    # message.content can be None (e.g. an empty completion), so fall back to ""
    return extract_code(response.choices[0].message.content or "")
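`extract_code` recognizes only ```` ```python ```` and bare ```` ``` ```` fences, so a reply fenced as ```` ```py ```` would keep `py` as the first line of the extracted code. A regex-based alternative can skip any language tag. The sketch below is a hypothetical variant (`extract_code_regex` is not in the original notebook) that also tolerates a missing closing fence and a `None` response.

```python
import re
from typing import Optional

# ``` + optional language tag (e.g. python, py, c++) + newline, then the body up to
# the next ``` or the end of the string. DOTALL lets "." match newlines.
_FENCE_RE = re.compile(r"```[ \t]*([\w+-]*)[ \t]*\n(.*?)(?:```|\Z)", re.DOTALL)


def extract_code_regex(response: Optional[str]) -> str:
    """Return the body of the first fenced block, or the whole response if unfenced."""
    if not response:
        return ""
    match = _FENCE_RE.search(response)
    if match:
        return match.group(2).strip()
    return response.strip()
```

The lazy `(.*?)` stops at the first closing fence, so only the first code block is returned when the model emits several.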

## 7. Main Pipeline Setup

In [9]:
def setup_chromadb_pipeline(api_key, embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
                           persist_directory="./chroma_codegen", force_reload=False):
    """Load HumanEval, build or reuse the ChromaDB collection, and return the pipeline components."""
    print("="*80)
    print(" Setting up ChromaDB RAG Pipeline")
    print("="*80 + "\n")
    
    # Load dataset
    examples = load_humaneval_dataset()
    
    # Create ChromaDB client
    client = create_chroma_client(persist_directory)
    
    # Setup collection with embeddings
    collection, embedding_model = setup_chroma_collection(
        client, examples, embedding_model_name, force_reload
    )
    
    print("\n" + "="*80)
    print("✓ Pipeline setup complete!")
    print("="*80 + "\n")
    
    return {
        'client': client,
        'collection': collection,
        'embedding_model': embedding_model,
        'api_key': api_key,
        'examples': examples
    }


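def generate_code_for_task(pipeline, task_description, n_examples=3,
                           generation_model="deepseek/deepseek-chat-v3.1:free", verbose=True):
    """Retrieve `n_examples` similar HumanEval tasks and generate code for `task_description`.

    Note: `verbose` is accepted but currently unused.
    """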
    # Retrieve similar examples
    retrieved_examples = retrieve_similar(
        task_description,
        pipeline['collection'],
        pipeline['embedding_model'],
        n_results=n_examples
    )
    
    # Generate code
    generated_code = generate_code(
        task_description,
        retrieved_examples,
        pipeline['api_key'],
        model=generation_model
    )
    
    return {
        'task_description': task_description,
        'generated_code': generated_code,
        'retrieved_examples': retrieved_examples
    }
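`generate_code_for_task` always passes the top `n_examples` results to the LLM, however distant they are. In the example runs, HumanEval/26 (distance 0.2354) is a close match for `remove_duplicates`, whereas the binary-search task retrieves `sort_array` (HumanEval/116, distance 0.7351). A distance cutoff is one way to drop weak matches. The sketch below is hypothetical: `filter_by_distance` and the default threshold of 0.5 are illustrative and untuned.

```python
from typing import Dict, List


def filter_by_distance(examples: List[Dict], max_distance: float = 0.5) -> List[Dict]:
    """Keep retrieved examples whose 'distance' is at most `max_distance`.

    The 0.5 default is an illustrative, untuned value. Chroma returns results in
    ascending distance order, and this filter preserves that order.
    """
    return [ex for ex in examples if ex["distance"] <= max_distance]
```

If every example is filtered out, `build_context([])` returns an empty string and the prompt degrades to zero-shot generation, which may be preferable to misleading few-shot context.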

## 8. Utility Functions

In [10]:
def print_result(result):
    """Pretty print the generation result."""
    print("\n" + "="*80)
    print(" TASK DESCRIPTION:")
    print("="*80)
    print(result['task_description'])
    
    print("\n" + "="*80)
    print(" GENERATED CODE:")
    print("="*80)
    print(result['generated_code'])
    
    print("\n" + "="*80)
    print(" RETRIEVED EXAMPLES:")
    print("="*80)
    for i, ex in enumerate(result['retrieved_examples'], 1):
        print(f"\n{i}. {ex['task_id']} (distance: {ex['distance']:.4f})")
        print(f"   {ex['prompt']}...")

## 9. Initialize the Pipeline

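**Note**: With a persistent client, ChromaDB writes the collection to `persist_directory`. On later runs, `setup_chroma_collection` finds the existing 164 examples and skips re-embedding them, so setup is faster (the dataset and embedding model are still loaded).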

In [11]:
# Initialize the pipeline
# Set force_reload=True only if you want to reload the dataset
pipeline = setup_chromadb_pipeline(
    api_key=OPENROUTER_API_KEY,
    embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
    persist_directory="./chroma_codegen",
    force_reload=False  # Set to True to reload data
)

 Setting up ChromaDB RAG Pipeline

Loading HumanEval dataset...
✓ Loaded 164 examples
Creating ChromaDB client with persist directory: ./chroma_codegen
✓ ChromaDB client created
Loading embedding model: sentence-transformers/all-MiniLM-L6-v2
Preparing data...
Creating embeddings...


Adding to ChromaDB collection...
✓ Collection contains 164 examples

✓ Pipeline setup complete!



## 10. Example 1: Maximum Subarray Sum

In [12]:
task1 = """
def find_max_subarray_sum(arr: List[int]) -> int:
    \"\"\" Find the maximum sum of a contiguous subarray (Kadane's algorithm).
    >>> find_max_subarray_sum([-2, 1, -3, 4, -1, 2, 1, -5, 4])
    6
    >>> find_max_subarray_sum([1, 2, 3, 4])
    10
    \"\"\"
"""

result1 = generate_code_for_task(
    pipeline=pipeline,
    task_description=task1,
    n_examples=3
)

print_result(result1)

 Generating code...

 TASK DESCRIPTION:

def find_max_subarray_sum(arr: List[int]) -> int:
    """ Find the maximum sum of a contiguous subarray (Kadane's algorithm).
    >>> find_max_subarray_sum([-2, 1, -3, 4, -1, 2, 1, -5, 4])
    6
    >>> find_max_subarray_sum([1, 2, 3, 4])
    10
    """


 GENERATED CODE:
from typing import List

def find_max_subarray_sum(arr: List[int]) -> int:
    max_current = max_global = arr[0]
    for i in range(1, len(arr)):
        max_current = max(arr[i], max_current + arr[i])
        if max_current > max_global:
            max_global = max_current
    return max_global

 RETRIEVED EXAMPLES:

1. HumanEval/120 (distance: 0.7802)
   
def maximum(arr, k):
    """
    Given an array arr of integers and a positive integer k, return a sorted list 
    of length k with the maximum k numbers in arr.

    Example 1:

        Input: arr = [-3, -4, 5], k = 3
        Output: [-4, -3, 5]

    Example 2:

        Input: arr = [4, -4, 4], k = 2
        Output: [4, 4

## 11. Example 2: Remove Duplicates

In [13]:
task2 = """
def remove_duplicates(nums: List[int]) -> List[int]:
    \"\"\" Remove duplicates from a list while preserving order.
    >>> remove_duplicates([1, 1, 2, 2, 3, 4, 4])
    [1, 2, 3, 4]
    >>> remove_duplicates([1, 2, 3])
    [1, 2, 3]
    \"\"\"
"""

result2 = generate_code_for_task(
    pipeline=pipeline,
    task_description=task2,
    n_examples=3,
    verbose=True
)

print_result(result2)

 Generating code...

 TASK DESCRIPTION:

def remove_duplicates(nums: List[int]) -> List[int]:
    """ Remove duplicates from a list while preserving order.
    >>> remove_duplicates([1, 1, 2, 2, 3, 4, 4])
    [1, 2, 3, 4]
    >>> remove_duplicates([1, 2, 3])
    [1, 2, 3]
    """


 GENERATED CODE:
from typing import List

def remove_duplicates(nums: List[int]) -> List[int]:
    seen = set()
    result = []
    for num in nums:
        if num not in seen:
            seen.add(num)
            result.append(num)
    return result

 RETRIEVED EXAMPLES:

1. HumanEval/26 (distance: 0.2354)
   from typing import List


def remove_duplicates(numbers: List[int]) -> List[int]:
    """ From a list of integers, remove all elements that occur more than once.
    Keep order of elements left the same as in the input.
    >>> remove_duplicates([1, 2, 3, 2, 4])
    [1, 3, 4]
    """
...

2. HumanEval/34 (distance: 0.7434)
   

def unique(l: list):
    """Return sorted unique elements in a list
    >>

## 12. Example 3: Binary Search

In [14]:
task3 = """
def binary_search(arr: List[int], target: int) -> int:
    \"\"\" Perform binary search on a sorted array. Return index or -1.
    >>> binary_search([1, 2, 3, 4, 5, 6, 7], 5)
    4
    >>> binary_search([1, 2, 3, 4, 5], 10)
    -1
    \"\"\"
"""

result3 = generate_code_for_task(
    pipeline=pipeline,
    task_description=task3,
    n_examples=2
)

print_result(result3)

 Generating code...

 TASK DESCRIPTION:

def binary_search(arr: List[int], target: int) -> int:
    """ Perform binary search on a sorted array. Return index or -1.
    >>> binary_search([1, 2, 3, 4, 5, 6, 7], 5)
    4
    >>> binary_search([1, 2, 3, 4, 5], 10)
    -1
    """


 GENERATED CODE:
from typing import List

def binary_search(arr: List[int], target: int) -> int:
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1

 RETRIEVED EXAMPLES:

1. HumanEval/116 (distance: 0.7351)
   
def sort_array(arr):
    """
    In this Kata, you have to sort an array of non-negative integers according to
    number of ones in their binary representation in ascending order.
    For similar number of ones, sort based on decimal value.

    It must be implemented like this:
    >>> sort_array

## 13. Inspect ChromaDB Collection

In [15]:
# Get collection statistics
collection = pipeline['collection']

print(f"Collection Name: {collection.name}")
print(f"Total Documents: {collection.count()}")
print(f"\nMetadata: {collection.metadata}")

# Peek at first 3 documents
sample = collection.peek(3)
print(f"\n\nSample Documents:")
print("="*80)
for i, (doc_id, doc) in enumerate(zip(sample['ids'], sample['documents']), 1):
    print(f"\n{i}. ID: {doc_id}")
    print(f"   Document: {doc[:200]}...")

Collection Name: humaneval_code_examples
Total Documents: 164

Metadata: {'description': 'HumanEval code examples for RAG'}


Sample Documents:

1. ID: HumanEval/0
   Document: from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given thr...

2. ID: HumanEval/1
   Document: from typing import List


def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    se...

3. ID: HumanEval/2
   Document: 

def truncate_number(number: float) -> float:
    """ Given a positive floating point number, it can be decomposed into
    and integer part (largest integer smaller than given number) and decimals
 ...
