In [1]:
# Part 1: Data Loading, Indexing with Whoosh, and Baseline TF-IDF Retrieval

import os
import pandas as pd
from whoosh import index
from whoosh.fields import Schema, TEXT, ID, STORED
from whoosh.analysis import StemmingAnalyzer
from whoosh.qparser import MultifieldParser
import torch

In [2]:

# 1. Device Setup (for later embedding steps)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# 2. Load Snippet Dataset
def load_snippets(csv_path: str) -> pd.DataFrame:
    """
    Load code snippet data from a CSV. Expected columns:
      - id: unique identifier for the snippet
      - language: programming language (e.g., Python, Java)
      - snippet: the code text
      - description: brief human-readable description
    """
    df = pd.read_csv(csv_path)
    required_cols = {'id', 'language', 'snippet', 'description'}
    if not required_cols.issubset(df.columns):
        raise ValueError(f"CSV must contain columns {required_cols}")
    df = df.dropna(subset=['id', 'snippet'])
    df['id'] = df['id'].astype(str)
    return df



# 3. Whoosh Index Creation
def create_or_open_index(index_dir: str, schema: Schema) -> index.Index:
    """
    Create a new Whoosh index or open existing one.
    """
    if not os.path.exists(index_dir):
        os.makedirs(index_dir)
        ix = index.create_in(index_dir, schema)
        print(f"[INFO] Created new index at {index_dir}")
    else:
        ix = index.open_dir(index_dir)
        print(f"[INFO] Opened existing index at {index_dir}")
    return ix

# Define schema: id, language, snippet text, description
snippet_schema = Schema(
    id=ID(stored=True, unique=True),
    language=TEXT(stored=True),
    snippet=TEXT(stored=True, analyzer=StemmingAnalyzer()),
    description=TEXT(stored=True)
)

# 4. Indexing Function
def build_index(df: pd.DataFrame, index_dir: str = "indexdir"):
    """
    Build Whoosh index from snippets DataFrame.
    """
    ix = create_or_open_index(index_dir, snippet_schema)
    writer = ix.writer()

    for _, row in df.iterrows():
        writer.update_document(
            id=row['id'],
            language=row['language'],
            snippet=row['snippet'],
            description=row.get('description', '')
        )
    writer.commit()
    print(f"[INFO] Indexed {len(df)} code snippets.")




# 5. Baseline Retrieval Function
def search_snippets(query: str, index_dir: str = "indexdir", top_k: int = 10):
    """
    Perform TF-IDF based search over code snippets.
    Returns a list of (id, score, language, snippet, description).
    """
    ix = index.open_dir(index_dir)
    # Search over `snippet` and `description` fields
    parser = MultifieldParser(["snippet", "description"], schema=ix.schema)
    q = parser.parse(query)
    with ix.searcher() as searcher:
        results = searcher.search(q, limit=top_k)
        hits = []
        for hit in results:
            hits.append({
                'id': hit['id'],
                'score': hit.score,
                'language': hit['language'],
                'snippet': hit['snippet'],
                'description': hit['description']
            })
    return hits


[INFO] Using device: cuda


In [3]:
# Part 2: Embedding-based Re-ranking with CodeBERT on GPU

import os
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from typing import List, Dict


# 1. Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# 2. Load snippets and prepare embeddings
EMBEDDING_DIR = "embeddings"
EMBEDDING_FILE = os.path.join(EMBEDDING_DIR, "snippet_embeddings.npz")
MODEL_NAME = "microsoft/codebert-base"

# Initialize CodeBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()


def build_snippet_embeddings(csv_path: str, rebuild: bool = False) -> Dict[str, np.ndarray]:
    """
    Build or load embeddings for each code snippet.
    Returns a dict mapping snippet IDs to embedding vectors.
    """
    os.makedirs(EMBEDDING_DIR, exist_ok=True)
    if os.path.exists(EMBEDDING_FILE) and not rebuild:
        print(f"[INFO] Loading existing embeddings from {EMBEDDING_FILE}")
        data = np.load(EMBEDDING_FILE, allow_pickle=True)
        ids = data['ids']
        embeddings = data['embeddings']
        return {id_: emb for id_, emb in zip(ids.tolist(), embeddings)}

    print("[INFO] Computing new snippet embeddings...")
    df = load_snippets(csv_path)
    ids = []
    embs = []
    batch_size = 16

    with torch.no_grad():
        for i in range(0, len(df), batch_size):
            batch = df.iloc[i:i+batch_size]
            texts = (batch['description'] + " " + batch['snippet']).tolist()
            # Tokenize
            enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
            # Forward
            out = model(**enc)
            # Use [CLS] token embedding
            cls_emb = out.last_hidden_state[:,0,:]
            # Normalize
            cls_emb = F.normalize(cls_emb, p=2, dim=1)
            # Move to CPU and numpy
            batch_embs = cls_emb.cpu().numpy()

            for id_, emb in zip(batch['id'], batch_embs):
                ids.append(id_)
                embs.append(emb)

    ids = np.array(ids)
    embeddings = np.stack(embs)
    np.savez_compressed(EMBEDDING_FILE, ids=ids, embeddings=embeddings)
    print(f"[INFO] Saved embeddings: {EMBEDDING_FILE}")
    return {id_: emb for id_, emb in zip(ids.tolist(), embeddings)}


# 3. Re-ranking function
def rerank_with_embeddings(
    query: str,
    csv_path: str,
    index_dir: str = "indexdir",
    top_k: int = 10,
    alpha: float = 0.5
) -> List[Dict]:
    """
    Perform initial TF-IDF retrieval, then re-rank using combined score:
    combined_score = alpha * normalized_tfidf + (1 - alpha) * cosine_similarity

    Returns list of hits with added 'cosine' and 'combined_score'.
    """
    # Ensure embeddings are built
    embeddings_map = build_snippet_embeddings(csv_path)

    # Initial retrieval
    hits = search_snippets(query, index_dir=index_dir, top_k=top_k)
    if not hits:
        return []

    # Extract tfidf scores
    tfidf_scores = np.array([h['score'] for h in hits], dtype=np.float32)
    # Normalize TF-IDF to [0,1]
    tfidf_norm = (tfidf_scores - tfidf_scores.min()) / (tfidf_scores.max() - tfidf_scores.min() + 1e-8)

    # Embed query
    with torch.no_grad():
        enc = tokenizer(query, return_tensors="pt", truncation=True, padding=True).to(device)
        q_out = model(**enc)
        q_emb = F.normalize(q_out.last_hidden_state[:,0,:], p=2, dim=1).cpu().numpy()[0]

    # Compute cosine similarity with each snippet embedding
    cos_sims = []
    for h in hits:
        emb = embeddings_map.get(h['id'])
        if emb is None:
            cos = 0.0
        else:
            cos = float(np.dot(q_emb, emb))
        cos_sims.append(cos)
    cos_sims = np.array(cos_sims, dtype=np.float32)
    # Normalize cosine to [0,1]
    cos_norm = (cos_sims - cos_sims.min()) / (cos_sims.max() - cos_sims.min() + 1e-8)

    # Compute combined score
    combined = alpha * tfidf_norm + (1 - alpha) * cos_norm

    # Attach scores and sort
    for i, h in enumerate(hits):
        h['cosine'] = float(cos_sims[i])
        h['combined_score'] = float(combined[i])

    hits_sorted = sorted(hits, key=lambda x: x['combined_score'], reverse=True)
    return hits_sorted



[INFO] Using device: cuda


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/498 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

In [4]:
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output


# Paths and constants
CSV_PATH = "snippets.csv"
INDEX_DIR = "indexdir"

# Ensure index is built
df = load_snippets(CSV_PATH)
build_index(df, INDEX_DIR)

# UI Widgets
query_input = widgets.Text(
    value='',
    placeholder='Type your code query here',
    description='Query:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
topk_slider = widgets.IntSlider(
    value=5, min=1, max=20, step=1,
    description='Top K:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
alpha_slider = widgets.FloatSlider(
    value=0.5, min=0.0, max=1.0, step=0.05,
    description='Alpha (TF-IDF vs. Embeddings):',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
search_button = widgets.Button(
    description='Search',
    button_style='primary',
    tooltip='Click to run code snippet search',
    icon='search'
)
output = widgets.Output(layout={'border': '1px solid #ddd'})

def on_search_click(b):
    with output:
        clear_output()
        query = query_input.value.strip()
        if not query:
            display(HTML("<p style='color:red;'>Please enter a query.</p>"))
            return
        # Retrieve results
        results = rerank_with_embeddings(
            query=query,
            csv_path=CSV_PATH,
            index_dir=INDEX_DIR,
            top_k=topk_slider.value,
            alpha=alpha_slider.value
        )
        if not results:
            display(HTML("<p>No results found.</p>"))
            return
        
        # Build HTML display
        html = f"<h3>Results for '<em>{query}</em>'</h3>"
        for idx, hit in enumerate(results, 1):
            # Summarize snippet
            summary = ""
            for line in hit['snippet'].splitlines():
                stripped = line.strip()
                if stripped.startswith('#') or stripped.startswith('//') or stripped.startswith('def ') or stripped.startswith('function '):
                    summary = stripped
                    break
            if not summary and hit['snippet'].splitlines():
                summary = hit['snippet'].splitlines()[0].strip()
            
            # Highlight keywords
            keywords = [t.lower() for t in query.split() if len(t) > 1]
            highlighted = ""
            for line in hit['snippet'].splitlines():
                low = line.lower()
                if any(k in low for k in keywords):
                    highlighted += f"<mark>{line}</mark>\n"
                else:
                    highlighted += f"{line}\n"
            
            html += f"""
            <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px; border-radius:5px;">
              <strong>{idx}. [ID: {hit['id']}] ({hit['language']})</strong><br>
              <small>Score: {hit['combined_score']:.4f} (tfidf), cosine: {hit['cosine']:.4f}</small><br>
              <em>Summary:</em> {summary}<br>
              <em>Description:</em> {hit.get('description', '')}<br>
              <pre style="background:#f7f7f7; padding:10px; overflow:auto; white-space:pre-wrap;">{highlighted}</pre>
            </div>
            """
        display(HTML(html))

search_button.on_click(on_search_click)

# Layout and display
ui = widgets.VBox([
    widgets.HTML("<h2>Code Snippet Retrieval</h2>"),
    query_input,
    topk_slider,
    alpha_slider,
    search_button,
    output
])

display(ui)


[INFO] Opened existing index at indexdir
[INFO] Indexed 100 code snippets.


VBox(children=(HTML(value='<h2>Code Snippet Retrieval</h2>'), Text(value='', description='Query:', layout=Layo…

In [6]:
# Annotation export (top-10 TF-IDF)
QUERIES = [
    # First 20 queries
    "convert list of tuples to dict python",
    "merge two sorted arrays in C++",
    "read JSON file in Java",
    "python regex find all numbers in string",
    "format date in Swift",
    "remove duplicates from array in javascript",
    "implement binary search in python",
    "center a div in css",
    "connect to mysql with php pdo",
    "async await fetch json javascript",
    "create and start a thread in Java",
    "template function example C++",
    "count occurrences in list with Counter python",
    "exception handling in Ruby",
    "launch a goroutine in Go",
    "list comprehension in Haskell",
    "map anonymous function in Elixir",
    "define interface in TypeScript",
    "compress directory using tar bash",
    "plot sine wave in MATLAB",
    
]

records = []
for q in QUERIES:
    hits = search_snippets(q, INDEX_DIR, top_k=10)
    for rank, h in enumerate(hits, start=1):
        records.append({
            "query":       q,
            "rank":        rank,
            "snippet_id":  h["id"],
            "language":    h["language"],
            "description": h["description"],
            "snippet":     h["snippet"].replace("\n", "\\n")
        })

annotation_df = pd.DataFrame(records)
annotation_df.to_csv("snippets.csv", index=False)
print("Exported annotation_candidates.csv:", len(annotation_df), "rows")
annotation_df.head()

Exported annotation_candidates.csv: 10 rows


Unnamed: 0,query,rank,snippet_id,language,description,snippet
0,merge two sorted arrays in C++,1,2,C++,Merge two sorted arrays,"vector<int> mergeSorted(const vector<int>& a, ..."
1,format date in Swift,1,5,Swift,Format Date object in Swift,let formatter = DateFormatter()\nformatter.dat...
2,implement binary search in python,1,7,Python,Implement binary search in Python,"def binary_search(arr, x):\n low, high = 0,..."
3,create and start a thread in Java,1,16,Java,Create and start a Java thread,Thread thread = new Thread(() -> {\n System...
4,template function example C++,1,19,C++,Template function example in C++,"template<typename T>\nT add(T a, T b) { return..."


In [None]:
# Evaluation metrics (P@5, nDCG@5)


ann = pd.read_csv('annotation_labels.csv')
res = pd.read_csv('results.csv')
ann['rel_bin'] = ann['relevance'].apply(lambda x: 1 if x.lower().startswith('relevant') else 0)

def precision_at_k(res_df, ann_df, k):
    records = []
    for q in res_df['query'].unique():
        df_res = res_df[res_df['query']==q].sort_values('rank').head(k)
        df_ann = ann_df[ann_df['query']==q]
        merged = pd.merge(df_res, df_ann[['snippet_id','rel_bin']], on='snippet_id', how='left').fillna(0)
        p = merged['rel_bin'].mean()
        records.append({'query': q, f'P@{k}': p})
    return pd.DataFrame(records)

def dcg(rels):
    return sum((2**r - 1) / np.log2(idx + 2) for idx, r in enumerate(rels))

def ndcg_at_k(res_df, ann_df, k):
    records = []
    for q in res_df['query'].unique():
        df_res = res_df[res_df['query']==q].sort_values('rank').head(k)
        df_ann = ann_df[ann_df['query']==q]
        merged = pd.merge(df_res, df_ann[['snippet_id','rel_bin']], on='snippet_id', how='left').fillna(0)
        rels = merged['rel_bin'].tolist()
        ideal = sorted(rels, reverse=True)
        nd = dcg(rels) / dcg(ideal) if dcg(ideal)>0 else 0.0
        records.append({'query': q, f'nDCG@{k}': nd})
    return pd.DataFrame(records)

p5 = precision_at_k(res, ann, 5)
n5 = ndcg_at_k(res, ann, 5)
metrics = pd.merge(p5, n5, on='query')
import ace_tools as tools; tools.display_dataframe_to_user('P@5 & nDCG@5', metrics)