In [None]:
# imports

import re
import json
import random
import requests
from pathlib import Path
from urllib.parse import urlparse

In [3]:
# prepare repositories and paths

# Public GitHub repositories whose .go files are collected as the reference corpus.
github_repos = [
    "https://github.com/Otina12/leetcode-go",
    "https://github.com/Otina12/raft-go",
    "https://github.com/Otina12/interpreter-go",
    "https://github.com/travisjeffery/proglog",
    "https://github.com/melbahja/got",
]

# Root folder for everything this notebook writes, and the sub-folder that
# holds the downloaded Go sources.
data_dir = Path("./data")
ref_dir = data_dir.joinpath("reference_corpus")

In [5]:
# helper functions

def parse_github_url(url):
    """Extract the ``(owner, repo)`` pair from a GitHub repository URL.

    Empty path segments (caused by doubled or trailing slashes) are ignored,
    and a trailing ``.git`` on the repository name is dropped so that clone
    URLs resolve to the same pair as the browser URL. Path segments after the
    repository name are ignored.

    Args:
        url: Repository URL, e.g. ``https://github.com/owner/repo``.

    Returns:
        Tuple ``(owner, repo)`` of strings.

    Raises:
        ValueError: if the URL path does not contain both an owner and a
            repository name.
    """
    segments = [part for part in urlparse(url).path.split("/") if part]

    if len(segments) < 2:
        raise ValueError(f"Cannot parse owner and repo from {url}")

    owner, repo = segments[0], segments[1]

    # Clone URLs end in ".git"; the API and raw-content URLs need the bare name.
    if repo.endswith(".git"):
        repo = repo[: -len(".git")]

    if not repo:
        raise ValueError(f"Cannot parse owner and repo from {url}")

    return owner, repo


def github_api_get(url, params=None, timeout=30):
    """Send a GET request to a GitHub REST API URL and return the parsed JSON.

    Args:
        url: Full API URL to request.
        params: Optional query-string parameters passed through to
            ``requests.get``.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. ``requests`` applies no timeout unless one is
            given, so a stalled connection would otherwise block the notebook.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: if the response status code is not 200. The message
            carries the status code, the URL and the first 200 characters of
            the response body.
    """
    headers = {"Accept": "application/vnd.github+json"}
    res = requests.get(url, headers=headers, params=params, timeout=timeout)

    if res.status_code != 200:
        raise RuntimeError(f"GitHub API error {res.status_code} on {url}: {res.text[:200]}")

    return res.json()

def get_default_branch(owner, repo):
    """Return the ``default_branch`` field of the repository's API record."""
    repo_info = github_api_get(f"https://api.github.com/repos/{owner}/{repo}")
    return repo_info["default_branch"]

def get_repo_tree(owner, repo, branch):
    """Return the recursive git tree listing of ``branch`` for ``owner/repo``.

    Args:
        owner: Repository owner.
        repo: Repository name.
        branch: Branch name used as the tree reference.

    Returns:
        The ``tree`` list from the API response.

    Raises:
        RuntimeError: propagated from ``github_api_get`` on a non-200 response.
    """
    url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{branch}"
    data = github_api_get(url, params={"recursive": "1"})

    # The API flags listings it had to cut short; surface that instead of
    # silently working with an incomplete file list.
    if data.get("truncated"):
        print(f"warning: tree listing for {owner}/{repo} is truncated, some files are missing")

    return data["tree"]

def download_go_file(owner, repo, branch, path_in_repo, timeout=30):
    """Fetch one file from raw.githubusercontent.com and return its text.

    This is best-effort: any failure is reported with a ``skip ...`` line and
    an empty string is returned, so the caller can simply ignore the file.

    Args:
        owner: Repository owner.
        repo: Repository name.
        branch: Branch to read the file from.
        path_in_repo: File path relative to the repository root.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The response body as text, or ``""`` when the request failed or the
        status code was not 200.
    """
    from urllib.parse import quote

    # Percent-encode the path so characters such as '#', '?' or '%' in a file
    # name are not interpreted as URL syntax ('/' is left untouched).
    raw_url = (
        f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/"
        f"{quote(path_in_repo)}"
    )

    try:
        res = requests.get(raw_url, timeout=timeout)
    except requests.RequestException as exc:
        # Same policy as a non-200 status: report and move on.
        print(f"skip {raw_url}, request failed: {exc}")
        return ""

    if res.status_code != 200:
        print(f"skip {raw_url}, status {res.status_code}")
        return ""

    return res.text

def save_local(repo_name: str, path_in_repo: str, content: str):
    """Write ``content`` as UTF-8 to ``ref_dir / repo_name / path_in_repo``.

    Missing parent directories are created first.
    """
    repo_root = ref_dir / repo_name
    (repo_root / Path(path_in_repo).parent).mkdir(parents=True, exist_ok=True)
    (repo_root / path_in_repo).write_text(content, encoding="utf-8")


In [6]:
# save .go files

max_files_per_repo = 10

for repo_url in github_repos:
    owner, repo = parse_github_url(repo_url)
    print(f"collecting from {owner}/{repo}...")

    branch = get_default_branch(owner, repo)
    count_go = 0

    for entry in get_repo_tree(owner, repo, branch):
        # Skip anything that is not a "blob" entry with a .go path.
        is_go_source = entry["type"] == "blob" and entry["path"].endswith(".go")
        if not is_go_source:
            continue

        # An empty result means the download was skipped; it does not count.
        source = download_go_file(owner, repo, branch, entry["path"])
        if source:
            save_local(repo, entry["path"], source)
            count_go += 1

            # Stop as soon as the per-repo quota of saved files is reached.
            if count_go >= max_files_per_repo:
                break

    print(f"collected {count_go} .go files from {owner}/{repo}\n")

collecting from Otina12/leetcode-go...
collected 10 .go files from Otina12/leetcode-go

collecting from Otina12/raft-go...
collected 9 .go files from Otina12/raft-go

collecting from Otina12/interpreter-go...
collected 9 .go files from Otina12/interpreter-go

collecting from travisjeffery/proglog...
collected 10 .go files from travisjeffery/proglog

collecting from melbahja/got...
collected 10 .go files from melbahja/got



In [7]:
# load docs and split into chunks (functions + maybe extra code)

all_docs = []

# rglob() yields files in whatever order the filesystem returns them. Sorting
# fixes the order, so chunk ids (assigned in iteration order below) and the
# chunks[:15] sample taken later are the same on every machine and every run.
for go_file in sorted(ref_dir.rglob("*.go")):
    rel_path = go_file.relative_to(ref_dir)

    all_docs.append({
        "repo": rel_path.parts[0],  # first folder under ref_dir is the repo name
        "path": str(rel_path),
        # errors="ignore" drops bytes that are not valid UTF-8 instead of failing
        "text": go_file.read_text(encoding="utf-8", errors="ignore"),
    })

print(f"loaded {len(all_docs)} .go files")

# A chunk starts at every line that begins with "func" followed by whitespace.
func_pattern = re.compile(r"^func\s", re.MULTILINE)

chunks = []
chunk_id = 1

for doc in all_docs:
    source_text = doc["text"]
    starts = [hit.start() for hit in func_pattern.finditer(source_text)]

    if starts:
        # Each chunk runs from one "func" line up to the next one (or to the
        # end of the file for the last chunk); blank results are dropped.
        bounds = zip(starts, starts[1:] + [len(source_text)])
        pieces = [source_text[lo:hi].strip() for lo, hi in bounds]
        pieces = [piece for piece in pieces if piece]
    else:
        # No function in the file: keep the whole file as a single chunk.
        pieces = [source_text.strip()]

    for piece in pieces:
        chunks.append({
            "id": f"chunk_{chunk_id:05d}",
            "repo": doc["repo"],
            "source_path": doc["path"],
            "text": piece,
        })
        chunk_id += 1

print(f"built {len(chunks)} chunks")

loaded 48 .go files
built 371 chunks


In [8]:
# helper methods for changing existing functions (positive examples)

# Matches either a Go literal (kept) or a Go comment (removed). Literals are
# part of the pattern so that "//" or "/*" inside a string is never mistaken
# for the start of a comment.
_GO_LITERAL_OR_COMMENT = re.compile(
    r'(?P<literal>'
    r'"(?:\\.|[^"\\\n])*"'      # interpreted string literal
    r"|`[^`]*`"                 # raw string literal
    r"|'(?:\\.|[^'\\\n])+'"     # rune literal
    r")"
    r"|(?P<comment>"
    r"//[^\n]*"                 # line comment, up to the end of the line
    r"|/\*.*?\*/"               # block comment, may span several lines
    r")",
    re.DOTALL,
)


def remove_comments(code):
    """Strip ``//`` and ``/* */`` comments from Go source text.

    The text is scanned left to right; string and rune literals are copied
    through unchanged, so a ``//`` inside a literal (for example a URL) or
    inside a block comment no longer cuts away real code.

    Args:
        code: Go source text.

    Returns:
        ``code`` with every comment replaced by the empty string. The newline
        that ends a line comment is kept.
    """
    def _keep_literal(match):
        # The "literal" group is None when a comment matched.
        return match.group("literal") or ""

    return _GO_LITERAL_OR_COMMENT.sub(_keep_literal, code)

# Matches, in order: string/rune literals and comments (left untouched), then
# whole identifiers (candidates for renaming).
_GO_PROTECTED_OR_IDENT = re.compile(
    r'"(?:\\.|[^"\\\n])*"'      # interpreted string literal
    r"|`[^`]*`"                 # raw string literal
    r"|'(?:\\.|[^'\\\n])+'"     # rune literal
    r"|//[^\n]*"                # line comment
    r"|/\*.*?\*/"               # block comment
    r"|\b[A-Za-z_]\w*",         # identifier or keyword
    re.DOTALL,
)


# I analyzed the most common tokens for renaming to have the best outcome
def rename_vars(code):
    """Rename a fixed set of common identifiers in Go source text.

    Only whole identifiers are replaced. String literals, rune literals and
    comments are copied through unchanged, so escape sequences such as ``\\n``
    and words inside messages are not rewritten.

    Args:
        code: Go source text.

    Returns:
        ``code`` with every identifier listed in the mapping replaced by its
        new name.
    """
    repl = {
        "i": "idx",
        "j": "pos",
        "n": "count",
        "err": "e",
        "cfg": "config",
        "token": "tok",
        "ast": "as_tree",
        "leader": "l",
        "server": "s"
    }

    def _swap(match):
        # Literals and comments are never keys of the mapping, so they fall
        # through unchanged; identifiers are looked up by their full text.
        token_text = match.group(0)
        return repl.get(token_text, token_text)

    # A single pass means a replacement can never be renamed a second time.
    return _GO_PROTECTED_OR_IDENT.sub(_swap, code)

def reorder_small_blocks(code, rng=None):
    """Shuffle the interior lines of ``code``, keeping its edges in place.

    When ``code`` has more than six lines, every line except the first two and
    the last two is put in random order. Shorter inputs are returned unchanged.

    Args:
        code: Source text; lines are separated by ``"\\n"``.
        rng: Optional object with a ``shuffle(list)`` method, for example a
            ``random.Random(seed)`` instance, used to get a reproducible
            order. When omitted, the module-level ``random`` generator is used.

    Returns:
        The text with the same lines, possibly in a different order.
    """
    lines = code.split("\n")

    if len(lines) <= 6:
        return code

    head, middle, tail = lines[:2], lines[2:-2], lines[-2:]
    shuffler = rng if rng is not None else random
    shuffler.shuffle(middle)

    return "\n".join(head + middle + tail)

def make_positive_variant(code):
    """Return a disguised copy of ``code``.

    The text is passed through ``remove_comments``, ``rename_vars`` and
    ``reorder_small_blocks`` in that order, then stripped of leading and
    trailing whitespace.
    """
    variant = code
    for transform in (remove_comments, rename_vars, reorder_small_blocks):
        variant = transform(variant)

    return variant.strip()

In [9]:
# create 15 positive cases

positive_cases = []

# make_positive_variant() shuffles lines with the module-level RNG; fixing the
# seed here makes the generated query_code identical on every run.
random.seed(42)

sample_chunks = chunks[:15]

for c in sample_chunks:
    positive_cases.append({
        "id": f"pos_{c['id']}",
        "query_code": make_positive_variant(c["text"]),
        "is_positive": True,
        "source_hint": c["source_path"],
        "notes": "renamed vars, removed comments, small reorder",
    })

print(f"made {len(positive_cases)} positive examples")

made 15 positive examples


In [10]:
# create 15 negative cases

negative_cases = []

# Hand-written Go functions used as negative (non-matching) queries.
# A snippet that contains a backslash must be a raw string (r"""..."""),
# otherwise Python interprets the escape and the Go source is corrupted.
NEGATIVE_SNIPPETS = [
    """package main
func SumSlice(nums []int) int {
    s := 0
    for _, x := range nums {
        s += x
    }
    return s
}
""",
    """package main
func ReverseString(s string) string {
    r := []rune(s)
    for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 {
        r[i], r[j] = r[j], r[i]
    }
    return string(r)
}
""",
    """package main
func UniqueInts(nums []int) []int {
    m := make(map[int]struct{})
    out := make([]int, 0, len(nums))
    for _, v := range nums {
        if _, ok := m[v]; !ok {
            m[v] = struct{}{}
            out = append(out, v)
        }
    }
    return out
}
""",
    """package main
func Factorial(n int) int {
    if n == 0 {
        return 1
    }
    return n * Factorial(n-1)
}
""",
    """package main
func Fibonacci(n int) int {
    if n <= 1 {
        return n
    }
    return Fibonacci(n-1) + Fibonacci(n-2)
}
""",
    """package main
func FindMax(nums []int) int {
    max := nums[0]
    for _, v := range nums {
        if v > max {
            max = v
        }
    }
    return max
}
""",
    """package main
func CountVowels(s string) int {
    count := 0
    for _, c := range s {
        switch c {
        case 'a', 'e', 'i', 'o', 'u', 'A', 'E', 'I', 'O', 'U':
            count++
        }
    }
    return count
}
""",
    """package main
func IsPalindrome(s string) bool {
    r := []rune(s)
    for i := 0; i < len(r)/2; i++ {
        if r[i] != r[len(r)-1-i] {
            return false
        }
    }
    return true
}
""",
    """package main
func BubbleSort(arr []int) []int {
    n := len(arr)
    for i := 0; i < n; i++ {
        for j := 0; j < n-i-1; j++ {
            if arr[j] > arr[j+1] {
                arr[j], arr[j+1] = arr[j+1], arr[j]
            }
        }
    }
    return arr
}
""",
    # Raw string: the Go rune literals '\n' and '\t' must reach the dataset as
    # backslash escapes, not as real newline and tab characters.
    r"""package main
func WordCount(s string) int {
    words := 0
    inWord := false
    for _, c := range s {
        if c == ' ' || c == '\n' || c == '\t' {
            inWord = false
        } else if !inWord {
            words++
            inWord = true
        }
    }
    return words
}
""",
    """package main
func ToUpperCase(s string) string {
    result := ""
    for _, c := range s {
        if c >= 'a' && c <= 'z' {
            result += string(c - 32)
        } else {
            result += string(c)
        }
    }
    return result
}
""",
    """package main
func Average(nums []float64) float64 {
    sum := 0.0
    for _, v := range nums {
        sum += v
    }
    return sum / float64(len(nums))
}
""",
    """package main
func CountEven(nums []int) int {
    count := 0
    for _, v := range nums {
        if v%2 == 0 {
            count++
        }
    }
    return count
}
""",
    """package main
func Contains(nums []int, target int) bool {
    for _, v := range nums {
        if v == target {
            return true
        }
    }
    return false
}
""",
    """package main
func MergeSlices(a, b []int) []int {
    merged := make([]int, 0, len(a)+len(b))
    merged = append(merged, a...)
    merged = append(merged, b...)
    return merged
}
""",
]

# Wrap every snippet in a test-case record; ids are numbered from neg_001.
negative_cases.extend(
    {
        "id": f"neg_{number:03d}",
        "query_code": snippet.strip(),
        "is_positive": False,
        "source_hint": None,
        "notes": "manually written different Go function",
    }
    for number, snippet in enumerate(NEGATIVE_SNIPPETS, start=1)
)

print(f"made {len(negative_cases)} negative examples")


made 15 negative examples


In [11]:
# dump all testcases into test_dataset.json

out_path = data_dir / "test_dataset.json"

dataset = positive_cases + negative_cases

# Mix positives and negatives with a dedicated fixed-seed generator so the
# order written to disk is the same on every run.
random.Random(42).shuffle(dataset)

# Make sure the output folder exists even when this cell runs on its own.
out_path.parent.mkdir(parents=True, exist_ok=True)

with open(out_path, "w", encoding="utf-8") as f:
    json.dump(dataset, f, indent=2)

print(f"saved {len(dataset)} test cases to {out_path}")

saved 30 test cases to data\test_dataset.json
