# Data preprocess

In [2]:
from __future__ import annotations
import html
import json
import re
from dataclasses import dataclass
from typing import List, Tuple

import pandas as pd

# -----------------------------
# Utilities
# -----------------------------
_whitespace_re = re.compile(r"\s+", re.MULTILINE)
_word_re = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
# Punctuation/delimiters typical for C/Java/JS/Solidity/Python
# We'll split code tokens by placing spaces around these, then split.
_separators = r"[\(\)\[\]\{\};:,\.\+\-\*/%&\|\^!=<>\?~]"
_sep_re = re.compile(f"({_separators})")

# String literal patterns ("...", '...', `...`)
_str_re = re.compile(r'(""".*?"""|\'\'\'.*?\'\'\'|".*?"|\'.*?\'|`.*?`)', re.DOTALL)
_num_re = re.compile(r"\b\d+(?:_\d+)*(?:\.[0-9_]+)?\b")

# Comments (C-like): // line, /* block */
_c_line_cmt = re.compile(r"//.*?(?=\n|$)")
_c_block_cmt = re.compile(r"/\*.*?\*/", re.DOTALL)
# Python: # line, triple-quoted docstrings (we'll treat as comments)
_py_line_cmt = re.compile(r"#.*?(?=\n|$)")
_py_doc_cmt = re.compile(r'(""".*?"""|\'\'\'.*?\'\'\')', re.DOTALL)

_c_like_file_ext = {".c", ".h", ".cpp", ".hpp", ".cc", ".java", ".js", ".ts", ".sol", ".cs", ".swift", ".go"}
_py_like_file_ext = {".py"}

@dataclass
class PreprocessConfig:
    lowercase: bool = True
    normalize_numbers: bool = True   # replace numbers with <NUM>
    normalize_strings: bool = True   # replace strings with <STR>
    split_identifiers: bool = True
    keep_empty: bool = False


def normalize_whitespace(text: str) -> str:
    return _whitespace_re.sub(" ", text).strip()


def split_identifiers(token: str) -> List[str]:
    """Split camelCase/mixedCase and snake_case identifiers.
    e.g. `allowedAmount_total` -> ["allowed", "Amount", "total"] -> lowercased later
    """
    # First split snake_case
    parts = re.split(r"_+", token)
    out: List[str] = []
    for p in parts:
        # Split on camelCase transitions (including digits)
        # e.g. HTTPServerError -> HTTP, Server, Error
        for m in re.finditer(r"[A-Z]?[a-z]+|[A-Z]+(?![a-z])|\d+", p):
            out.append(m.group(0))
    return [t for t in out if t]


def extract_comments_and_code(text: str, language_hint: str = "auto") -> Tuple[str, str]:
    """Return (comments_text, code_without_comments).
    Supports a broad regex-based pass for C-like and Python-like languages.
    """
    s = text
    comments: List[str] = []

    # Python-like docstrings first to avoid gobbling by string regex
    for m in _py_doc_cmt.finditer(s):
        comments.append(m.group(0))
    s = _py_doc_cmt.sub("\n", s)

    # C-like block comments
    for m in _c_block_cmt.finditer(s):
        comments.append(m.group(0))
    s = _c_block_cmt.sub("\n", s)

    # Line comments
    for m in _c_line_cmt.finditer(s):
        comments.append(m.group(0))
    s = _c_line_cmt.sub("\n", s)

    for m in _py_line_cmt.finditer(s):
        comments.append(m.group(0))
    s = _py_line_cmt.sub("\n", s)

    comments_text = "\n".join(comments)
    code_wo_comments = s
    return comments_text, code_wo_comments


def tokenize_comment_text(text: str, cfg: PreprocessConfig) -> List[str]:
    text = html.unescape(text)
    if cfg.normalize_strings:
        text = _str_re.sub(" <STR> ", text)
    if cfg.normalize_numbers:
        text = _num_re.sub(" <NUM> ", text)
    # Words only for comments; punctuation is less useful
    toks = _word_re.findall(text)
    if cfg.split_identifiers:
        expanded: List[str] = []
        for tok in toks:
            expanded.extend(split_identifiers(tok))
        toks = expanded
    if cfg.lowercase:
        toks = [t.lower() for t in toks]
    return toks if toks or cfg.keep_empty else (toks or [])


def tokenize_code(text: str, cfg: PreprocessConfig) -> List[str]:
    text = html.unescape(text)
    # Replace string and numbers with placeholders to reduce noise
    if cfg.normalize_strings:
        text = _str_re.sub(" <STR> ", text)
    if cfg.normalize_numbers:
        text = _num_re.sub(" <NUM> ", text)
    # Put spaces around separators so we can split and KEEP them
    text = _sep_re.sub(r" \1 ", text)
    # Collapse whitespace
    text = normalize_whitespace(text)
    raw_toks = text.split(" ") if text else []

    toks: List[str] = []
    for tok in raw_toks:
        if not tok:
            continue
        if tok == "<STR>" or tok == "<NUM>" or _sep_re.fullmatch(tok):
            toks.append(tok)
        else:
            if cfg.split_identifiers:
                toks.extend(split_identifiers(tok))
            else:
                toks.append(tok)
    if cfg.lowercase:
        toks = [t.lower() for t in toks]
    return toks if toks or cfg.keep_empty else (toks or [])


def preprocess_code_snippet(snippet: str, cfg: PreprocessConfig) -> Tuple[List[str], List[str]]:
    """Given raw code snippet, return (comment_tokens, code_tokens)."""
    comments_text, code_wo = extract_comments_and_code(snippet)
    cmt_toks = tokenize_comment_text(comments_text, cfg)
    code_toks = tokenize_code(code_wo, cfg)
    return cmt_toks, code_toks


def preprocess_query(q: str, cfg: PreprocessConfig) -> List[str]:
    # For queries, treat like comment text (natural language)
    return tokenize_comment_text(q, cfg)

In [23]:
codes_csv = "data/code_snippets.csv"
queries_csv = "data/test_queries.csv"
out_codes_csv = "data/code_snippets_proc.csv"
out_queries_csv = "data/test_queries_proc.csv"

cfg = PreprocessConfig(
        lowercase=True,
        split_identifiers=True,
        normalize_numbers=True,
        normalize_strings=True,
    )

df = pd.read_csv(codes_csv, engine="python")
if "code" not in df.columns:
    raise ValueError("codes CSV must contain a 'code' column")
comment_tokens_list: List[List[str]] = []
code_tokens_list: List[List[str]] = []
all_tokens_list: List[List[str]] = []

for s in df["code"].astype(str).tolist():
    cmt_toks, code_toks = preprocess_code_snippet(s, cfg)
    comment_tokens_list.append(cmt_toks)
    code_tokens_list.append(code_toks)
    all_tokens_list.append(cmt_toks + code_toks)

# df["comment_tokens"] = [json.dumps(x, ensure_ascii=False) for x in comment_tokens_list]
df["code_tokens"] = [json.dumps(x, ensure_ascii=False) for x in code_tokens_list]
# df["all_tokens"] = [json.dumps(x, ensure_ascii=False) for x in all_tokens_list]
df.to_csv(out_codes_csv, index=False)

# Process queries
df = pd.read_csv(queries_csv, engine="python")
if "query" not in df.columns:
    raise ValueError("queries CSV must contain a 'query' column")
query_tokens_list: List[List[str]] = []
for s in df["query"].astype(str).tolist():
    q_toks = preprocess_query(s, cfg)
    query_tokens_list.append(q_toks)
df["query_tokens"] = [json.dumps(x, ensure_ascii=False) for x in query_tokens_list]
df.to_csv(out_queries_csv, index=False)

code_snippets = pd.read_csv(out_codes_csv, engine="python")['code_tokens']
queries = pd.read_csv(out_queries_csv, engine="python")['query_tokens']

In [None]:
def output_submission_file(results, output_file):
    with open(output_file, 'w') as f:
        f.write("query_id,code_id\n")
        for query_id, code_ids in enumerate(results):
            code_ids_str = ' '.join(map(str, code_ids))
            f.write(f"{query_id+1},{code_ids_str}\n")

# Sparse Retrieval

## TF-IDF

In [71]:
import math
import numpy as np
import json
from collections import defaultdict

class TFIDF:
    '''
    documents: List of documents (strings)
    term_freqs: List of term frequency dictionaries for each document ex: [{'term1': 2, 'term2': 1}, ...]
    doc_freqs: Document frequency dictionary for terms ex: {'term1': 3, 'term2': 5}
    num_docs: Total number of documents
    vocabulary: Set of unique terms across all documents
    idf: Inverse Document Frequency dictionary for terms
    tfidf_matrix: 2D numpy array storing TF-IDF scores for documents
    vocab_to_idx: Mapping from term to its index in the vocabulary
    '''
    def __init__(self):
        self.documents = []
        self.term_freqs = []
        self.doc_freqs = defaultdict(int)
        self.num_docs = 0
        self.vocabulary = set()
        self.idf = defaultdict(float)
        self.tfidf_matrix = None
        self.vocab_to_idx = defaultdict(int)

    def add_document(self, document_tokens):
        if isinstance(document_tokens, str):
            try:
                tokens = json.loads(document_tokens)
            except json.JSONDecodeError:
                
                tokens = document_tokens.split()
        else:
            tokens = document_tokens
        self.documents.append(tokens)
        self.num_docs += 1
        term_count = defaultdict(int)

        for term in tokens:
            term_count[term] += 1
            self.vocabulary.add(term)

        self.term_freqs.append(term_count)

        for term in term_count.keys():
            self.doc_freqs[term] += 1
    
    def _build_vocab_index(self):
        self.vocabulary = list(self.vocabulary)
        self.vocab_to_idx = {term: idx for idx, term in enumerate(self.vocabulary)}
        
    def compute_tfidf(self):
        if self.tfidf_matrix is not None:
            return self.tfidf_matrix
            
        self._build_vocab_index()
        vocab_size = len(self.vocabulary)
        
        for term in self.vocabulary:
            self.idf[term] = math.log((1 + (self.num_docs/self.doc_freqs[term])), 2)
        
        self.tfidf_matrix = np.zeros((self.num_docs, vocab_size))
        
        for doc_idx, term_count in enumerate(self.term_freqs):
            for term, count in term_count.items():
                if term in self.vocab_to_idx:
                    tf = 1 + math.log(count, 2)
                    tfidf_val = tf * self.idf[term]
                    self.tfidf_matrix[doc_idx][self.vocab_to_idx[term]] = tfidf_val
        
        return self.tfidf_matrix

    def get_query_vector(self, query_tokens):
        
        if self.tfidf_matrix is None:
            self.compute_tfidf()
        
        if isinstance(query_tokens, str):
            try:
                tokens = json.loads(query_tokens)
            except json.JSONDecodeError:
                tokens = query_tokens.split()
        else:
            tokens = query_tokens
        
        query_term_count = defaultdict(int)
        for term in tokens:
            query_term_count[term] += 1

        query_vector = np.zeros(len(self.vocabulary))
        
        for term, count in query_term_count.items():
            if term in self.vocab_to_idx:
                tf = 1 + math.log(count, 2)
                idf = self.idf.get(term, 0.0)
                query_vector[self.vocab_to_idx[term]] = tf * idf
                
        return query_vector
    
    def compute_similarity_batch(self, query_vector):
        
        dot_products = np.dot(self.tfidf_matrix, query_vector)
        
        # Document Length Normalization
        doc_norms = np.linalg.norm(self.tfidf_matrix, axis=1)
        query_norm = np.linalg.norm(query_vector)
        
        valid_mask = (doc_norms > 0) & (query_norm > 0)
        similarities = np.zeros(len(doc_norms))
        similarities[valid_mask] = dot_products[valid_mask] / (doc_norms[valid_mask] * query_norm)
        
        return similarities
    
    def get_top_k_similar_documents(self, query_tokens, k):
        
        query_vector = self.get_query_vector(query_tokens)
        similarities = self.compute_similarity_batch(query_vector)
        
        if k >= len(similarities):
            top_k_indices = np.argsort(similarities)[::-1]
        else:
            top_k_indices = np.argpartition(similarities, -k)[-k:]
            top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])[::-1]]
        
        # need to plus 1 to indices for 1-based indexing
        top_k_indices += 1
        return top_k_indices.tolist()

In [72]:
tfidf = TFIDF()

for code in code_snippets:
    tfidf.add_document(code)
    
tfidf.compute_tfidf()

top_k_results = []
for query in queries:
    top_k = tfidf.get_top_k_similar_documents(query, 10)
    top_k_results.append(top_k)

output_submission_file(top_k_results, 'results/tf_idf_submission.csv')

## BM25

In [None]:
# BM25 implementation
import math
import json
from collections import defaultdict
import numpy as np

class BM25:
    def __init__(self, k1=1.5, b=0.75):
        self.documents = []
        self.term_freqs = []
        self.doc_freqs = defaultdict(int)
        self.num_docs = 0
        self.vocabulary = set()
        self.doc_lengths = []
        self.avg_doc_len = 0
        self.k1 = k1
        self.b = b
        self._precomputed_idf = {}
        self._precomputed_norms = None

    def add_document(self, document_tokens):
        if isinstance(document_tokens, str):
            try:
                terms = json.loads(document_tokens)
            except json.JSONDecodeError:
                terms = document_tokens.split()
        else:
            terms = document_tokens
            
        self.documents.append(terms)
        self.doc_lengths.append(len(terms))
        self.num_docs += 1
        term_count = defaultdict(int)

        for term in terms:
            term_count[term] += 1
            self.vocabulary.add(term)

        self.term_freqs.append(term_count)

        for term in term_count.keys():
            self.doc_freqs[term] += 1
    
    def _precompute_idf(self):
        self.avg_doc_len = sum(self.doc_lengths) / self.num_docs
        for term in self.vocabulary:
            df = self.doc_freqs[term]
            self._precomputed_idf[term] = math.log((self.num_docs - df + 0.5) / (df + 0.5))
    
    def _precompute_normalization_factors(self):
        self._precomputed_norms = np.array([
            self.k1 * (1 - self.b + self.b * (doc_len / self.avg_doc_len))
            for doc_len in self.doc_lengths
        ])
    
    def compute_similarity_batch(self, query_tokens):
        if not self._precomputed_idf:
            self._precompute_idf()
        if self._precomputed_norms is None:
            self._precompute_normalization_factors()
            
        if isinstance(query_tokens, str):
            try:
                query_terms = json.loads(query_tokens)
            except json.JSONDecodeError:
                query_terms = query_tokens.split()
        else:
            query_terms = query_tokens
            
        similarities = np.zeros(self.num_docs)
        
        for doc_idx in range(self.num_docs):
            score = 0.0
            term_freq_doc = self.term_freqs[doc_idx]
            norm_factor = self._precomputed_norms[doc_idx]
            
            for term in query_terms:
                if term in self.vocabulary:
                    tf = term_freq_doc.get(term, 0)
                    if tf > 0:  # 只計算有出現的詞彙
                        idf = self._precomputed_idf[term]
                        normalized_tf = (tf * (self.k1 + 1)) / (tf + norm_factor)
                        score += idf * normalized_tf
            
            similarities[doc_idx] = score

        return similarities

    def get_top_k_similar_documents(self, query_tokens, k):
        similarities = self.compute_similarity_batch(query_tokens)

        if k >= len(similarities):
            top_k_indices = np.argsort(similarities)[::-1]      
        else:
            top_k_indices = np.argpartition(similarities, -k)[-k:]
            top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])[::-1]]
        
        # need to plus 1 to indices for 1-based indexing
        top_k_indices += 1
        return top_k_indices.tolist()

In [70]:
bm25 = BM25()

for code in code_snippets:
    bm25.add_document(code)

top_k_results = []
for query in queries:
    top_k = bm25.get_top_k_similar_documents(query, 10)
    top_k_results.append(top_k)

output_submission_file(top_k_results, 'results/bm25_submission.csv')

# Dense Retrieval

## Pre-trained model

In [45]:
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from tqdm import tqdm
import ast
import numpy as np
import torch.nn.functional as F

model_name = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def format_str(string):
    for char in ['\r\n', '\r', '\n']:
        string = string.replace(char, ' ')
    return string.strip()


code_snippets = pd.read_csv(out_codes_csv, engine="python")['code']
queries = pd.read_csv(out_queries_csv, engine="python")['query']

codes = [format_str(c)[:512] for c in code_snippets]
queries = [format_str(q)[:512] for q in queries]

def get_embeddings(text_list, batch_size=16):
    all_embeddings = []
    for i in tqdm(range(0, len(text_list), batch_size)):
        batch_texts = text_list[i:i+batch_size]
        inputs = tokenizer(batch_texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            cls_embeddings = outputs.last_hidden_state[:, 0, :]   # 取 [CLS] 向量
            norm_embeddings = F.normalize(cls_embeddings, p=2, dim=1)
            all_embeddings.append(norm_embeddings)
    return torch.cat(all_embeddings, dim=0)

query_embs = get_embeddings(queries)
code_embs  = get_embeddings(codes)

sim_matrix = torch.matmul(query_embs, code_embs.T)

top_k = 10
results = []

for i in range(len(queries)):
    top_indices = torch.topk(sim_matrix[i], k=top_k).indices
    results.append(top_indices.cpu().numpy().tolist())
    
output_submission_file(results, 'results/pre_trained_submission.csv')

100%|██████████| 32/32 [00:21<00:00,  1.48it/s]
100%|██████████| 32/32 [00:28<00:00,  1.13it/s]


### Origin tokenizer

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
model = AutoModel.from_pretrained("microsoft/codebert-base")
model.eval()

def format_str(string):
    for char in ['\r\n', '\r', '\n']:
        string = string.replace(char, ' ')
    return string.strip()

df = pd.read_csv("data/train_queries.csv", engine="python")
codes = [format_str(c)[:512] for c in df["code"]]
queries = [format_str(q)[:512] for q in df["query"]]


def get_embeddings(text_list, batch_size=16):
    all_embeddings = []
    for i in tqdm(range(0, len(text_list), batch_size)):
        batch_texts = text_list[i:i+batch_size]
        inputs = tokenizer(batch_texts, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            cls_embeddings = outputs.last_hidden_state[:, 0, :]
            norm_embeddings = F.normalize(cls_embeddings, p=2, dim=1)
            all_embeddings.append(norm_embeddings)
    return torch.cat(all_embeddings, dim=0)

query_embs = get_embeddings(queries)
code_embs  = get_embeddings(codes)

sim_matrix = torch.matmul(query_embs, code_embs.T)


def recall_at_k(sim_matrix, k=10):
    N = sim_matrix.size(0)
    _, topk_idx = torch.topk(sim_matrix, k, dim=1) 
    correct = 0
    for i in range(N):
        if i in topk_idx[i]: 
            correct += 1
    return correct / N

recall10 = recall_at_k(sim_matrix, k=10)
print(f"Recall@10 = {recall10:.4f}")

100%|██████████| 32/32 [00:27<00:00,  1.18it/s]
100%|██████████| 32/32 [00:29<00:00,  1.10it/s]

Recall@10 = 0.1440





### Self-define tokenizer

In [None]:
input_csv = "data/train_queries.csv"
output_csv = "data/train_queries_proc.csv"

cfg = PreprocessConfig(
        lowercase=True,
        split_identifiers=True,
        normalize_numbers=True,
        normalize_strings=True,
    )

df = pd.read_csv(input_csv, engine="python")
if "code" not in df.columns:
    raise ValueError("codes CSV must contain a 'code' column")
comment_tokens_list: List[List[str]] = []
code_tokens_list: List[List[str]] = []
all_tokens_list: List[List[str]] = []

for s in df["code"].astype(str).tolist():
    cmt_toks, code_toks = preprocess_code_snippet(s, cfg)
    comment_tokens_list.append(cmt_toks)
    code_tokens_list.append(code_toks)
    all_tokens_list.append(cmt_toks + code_toks)

df["code_tokens"] = [json.dumps(x, ensure_ascii=False) for x in code_tokens_list]

df2 = pd.read_csv(input_csv, engine="python")
if "query" not in df2.columns:
    raise ValueError("queries CSV must contain a 'query' column")
query_tokens_list: List[List[str]] = []
for s in df2["query"].astype(str).tolist():
    q_toks = preprocess_query(s, cfg)
    query_tokens_list.append(q_toks)
df2["query_tokens"] = [json.dumps(x, ensure_ascii=False) for x in query_tokens_list]

df_final = pd.concat([df["code_tokens"], df2["query_tokens"]], axis=1)
df_final.to_csv(output_csv, index=False)

In [49]:
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
model = AutoModel.from_pretrained("microsoft/codebert-base")
model.eval()

def format_str(string):
    for char in ['\r\n', '\r', '\n']:
        string = string.replace(char, ' ')
    return string.strip()

df = pd.read_csv("data/train_queries_proc.csv", engine="python")

codes = [eval(c) for c in df["code_tokens"]]      
queries = [eval(q) for q in df["query_tokens"]]

def tokens_to_ids_batch(token_lists, pad_token_id, max_len=256):
    id_lists = [tokenizer.convert_tokens_to_ids(tokens[:max_len]) for tokens in token_lists]
    max_len_in_batch = min(max_len, max(len(ids) for ids in id_lists))
    padded = [ids + [pad_token_id] * (max_len_in_batch - len(ids)) for ids in id_lists]
    attention_masks = [[1] * len(ids) + [0] * (max_len_in_batch - len(ids)) for ids in id_lists]
    return torch.tensor(padded), torch.tensor(attention_masks)

def get_embeddings_from_tokens(token_lists, batch_size=16, max_len=512):
    all_embeddings = []
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id 
    for i in tqdm(range(0, len(token_lists), batch_size)):
        batch_tokens = token_lists[i:i+batch_size]
        input_ids, attn_masks = tokens_to_ids_batch(batch_tokens, pad_id, max_len)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attn_masks)
            cls_emb = outputs.last_hidden_state[:, 0, :]
            norm_emb = F.normalize(cls_emb, p=2, dim=1)
            all_embeddings.append(norm_emb)
    return torch.cat(all_embeddings, dim=0)

query_embs = get_embeddings_from_tokens(queries)
code_embs  = get_embeddings_from_tokens(codes)

sim_matrix = torch.matmul(query_embs, code_embs.T)

def recall_at_k(sim_matrix, k=10):
    N = sim_matrix.size(0)
    _, topk_idx = torch.topk(sim_matrix, k, dim=1)
    correct = 0
    for i in range(N):
        if i in topk_idx[i]:
            correct += 1
    return correct / N

recall10 = recall_at_k(sim_matrix, k=10)
print(f"Recall@10 = {recall10:.4f}")

100%|██████████| 32/32 [00:18<00:00,  1.75it/s]
100%|██████████| 32/32 [00:46<00:00,  1.45s/it]

Recall@10 = 0.0340





## Fine-tuned model

In [76]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_name = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModel.from_pretrained(model_name)

class CodeSearchModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.encoder = base_model

    def forward(self, code_input_ids, code_attention_mask, query_input_ids, query_attention_mask):
        batch_size = code_input_ids.size(0)
        combined_input_ids = torch.cat([code_input_ids, query_input_ids], dim=0)
        combined_attention_mask = torch.cat([code_attention_mask, query_attention_mask], dim=0)
        outputs = self.encoder(input_ids=combined_input_ids, attention_mask=combined_attention_mask)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        code_emb = F.normalize(cls_emb[:batch_size], p=2, dim=1)
        query_emb = F.normalize(cls_emb[batch_size:], p=2, dim=1)
        return code_emb, query_emb

Using device: cpu


In [77]:
model = CodeSearchModel(base_model)
model

CodeSearchModel(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm):

## Prepare training dataset

In [78]:
class CodeSearchDataset(Dataset):
    def __init__(self, csv_path, tokenizer, max_len=256):
        df = pd.read_csv(csv_path)
        self.queries = df["query"].tolist()
        self.codes = df["code"].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        query = str(self.queries[idx])
        code = str(self.codes[idx])
        code_enc = self.tokenizer(code, padding='max_length', truncation=True,
                                  max_length=self.max_len, return_tensors="pt")
        query_enc = self.tokenizer(query, padding='max_length', truncation=True,
                                   max_length=self.max_len, return_tensors="pt")
        return {
            "code_input_ids": code_enc["input_ids"].squeeze(0),
            "code_attention_mask": code_enc["attention_mask"].squeeze(0),
            "query_input_ids": query_enc["input_ids"].squeeze(0),
            "query_attention_mask": query_enc["attention_mask"].squeeze(0),
        }

## Evaluation metrics

In [79]:
def contrastive_loss(code_emb, query_emb, temperature=0.05):
    sim_matrix = torch.matmul(query_emb, code_emb.T) / temperature  # (B, B)
    labels = torch.arange(sim_matrix.size(0)).to(sim_matrix.device)
    return nn.CrossEntropyLoss()(sim_matrix, labels)

def recall_at_k(sim_matrix, k=10):
    N = sim_matrix.size(0)
    _, topk_idx = torch.topk(sim_matrix, k, dim=1)
    correct = sum(i in topk_idx[i] for i in range(N))
    return correct / N

## Training

In [None]:
def get_embeddings(text_list, model, tokenizer, batch_size=16, max_len=256):
    model.eval()
    embs = []
    for i in tqdm(range(0, len(text_list), batch_size), desc="Embedding"):
        batch_text = text_list[i:i+batch_size]
        inputs = tokenizer(batch_text, padding=True, truncation=True,
                           max_length=max_len, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.encoder(**inputs).last_hidden_state[:, 0, :]
            embs.append(F.normalize(out, p=2, dim=1))
    return torch.cat(embs, dim=0)

def train_model(model, tokenizer, train_csv, epochs=10, batch_size=8, lr=2e-5, max_len=256):
    dataset = CodeSearchDataset(train_csv, tokenizer, max_len=max_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    best_train_loss = float("inf")
    best_model_path = "best_model.pt"
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            code_input_ids = batch["code_input_ids"].to(device)
            code_attention_mask = batch["code_attention_mask"].to(device)
            query_input_ids = batch["query_input_ids"].to(device)
            query_attention_mask = batch["query_attention_mask"].to(device)

            optimizer.zero_grad()
            code_emb, query_emb = model(
                code_input_ids, code_attention_mask,
                query_input_ids, query_attention_mask
            )
            loss = contrastive_loss(code_emb, query_emb, temperature=0.05)
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} - Avg Train Loss: {avg_loss:.4f}")

        if avg_loss < best_train_loss:
            best_train_loss = avg_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"Saved new best model (train_loss={avg_loss:.4f})")

    print(f"Training complete. Best Train Loss: {best_train_loss:.4f}")
    return best_model_path

train_csv = "data/train_queries.csv"
model = CodeSearchModel(base_model)
best_model_path = train_model(model, tokenizer, train_csv, epochs=20, batch_size=8, lr=2e-5)

print("\nEvaluating Recall@10 for Best Model ...")
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)
model.eval()

def format_str(string):
    for char in ['\r\n', '\r', '\n']:
        string = string.replace(char, ' ')
    return string.strip()

df = pd.read_csv(train_csv)
codes = [format_str(c)[:512] for c in df["code"]]
queries = [format_str(q)[:512] for q in df["query"]]

query_embs = get_embeddings(queries, model, tokenizer)
code_embs = get_embeddings(codes, model, tokenizer)

sim_matrix = torch.matmul(query_embs, code_embs.T)
recall10 = recall_at_k(sim_matrix, k=10)
print(f"Best Model Recall@10 = {recall10:.4f}")

## Inference

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


model_name = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModel.from_pretrained(model_name)

class CodeSearchModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.encoder = base_model

    def forward(self, code_input_ids, code_attention_mask, query_input_ids, query_attention_mask):
        batch_size = code_input_ids.size(0)
        combined_input_ids = torch.cat([code_input_ids, query_input_ids], dim=0)
        combined_attention_mask = torch.cat([code_attention_mask, query_attention_mask], dim=0)
        outputs = self.encoder(input_ids=combined_input_ids, attention_mask=combined_attention_mask)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        code_emb = F.normalize(cls_emb[:batch_size], p=2, dim=1)
        query_emb = F.normalize(cls_emb[batch_size:], p=2, dim=1)
        return code_emb, query_emb

def get_embeddings(text_list, model, tokenizer, batch_size=16, max_len=256):
    model.eval()
    all_embs = []
    for i in tqdm(range(0, len(text_list), batch_size), desc="Embedding"):
        batch_text = text_list[i:i+batch_size]
        inputs = tokenizer(batch_text, padding=True, truncation=True,
                           max_length=max_len, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.encoder(**inputs).last_hidden_state[:, 0, :]
            emb = F.normalize(out, p=2, dim=1)
            all_embs.append(emb)
    return torch.cat(all_embs, dim=0)

def run_inference(model_path, query_csv, code_csv, output_csv="submission.csv", top_k=10):
    # 載入 fine-tuned 模型
    model = CodeSearchModel(base_model)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    df = pd.read_csv(query_csv)
    queries = [format_str(q)[:512] for q in df["query"]]
    df = pd.read_csv(code_csv)
    codes = [format_str(c)[:512] for c in df["code"]]

    print(f"Loaded {len(queries)} queries and {len(codes)} code snippets")

    query_embs = get_embeddings(queries, model, tokenizer)
    code_embs = get_embeddings(codes, model, tokenizer)

    print("Calculating similarity matrix ...")
    sim_matrix = torch.matmul(query_embs, code_embs.T)

    _, topk_indices = torch.topk(sim_matrix, k=top_k, dim=1)
    results = topk_indices.cpu().numpy().tolist()

    results = [[idx + 1 for idx in code_ids] for code_ids in results]

    output_submission_file(results, output_csv)
    print(f"✅ Saved results to {output_csv}")



model_path = "best_model.pt"       
query_csv = "data/test_queries.csv"          
code_csv = "data/code_snippets.csv"
output_csv = "result/fine_tuned_submission.csv"
run_inference(model_path, query_csv, code_csv, output_csv, top_k=10)