In [None]:
# Upstream repository whose Rust sources get indexed, and the local folder it is cloned into.
sample_repo = "https://github.com/chronotope/chrono"
sample_dir = "./sample/chrono"

In [None]:
import os

from git import Repo

# `git clone` refuses a destination that already exists and is not empty, so only
# clone into a fresh location; this keeps Restart & Run All working after the first run.
if os.path.isdir(sample_dir) and os.listdir(sample_dir):
    print(f"Reusing existing checkout in {sample_dir}")
else:
    Repo.clone_from(sample_repo, sample_dir)

In [None]:
import os

def get_all_files(directory):
    """Return the path of every non-directory entry below ``directory``, recursively.

    Each path is ``os.path.join(root, name)``, so it keeps ``directory`` as its prefix.
    A missing ``directory`` yields an empty list (``os.walk`` produces nothing for it).

    The result is sorted. ``os.walk`` lists entries in whatever order the filesystem
    returns them, and later cells use list positions as document ids. A stable order
    keeps those ids reproducible across machines and runs.
    """
    return sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
    )

# Only Rust sources are indexed; every other file in the checkout is ignored.
all_files = [path for path in get_all_files(sample_dir) if path.endswith('.rs')]
print(f"Total Rust files: {len(all_files)}")

# Preview a handful of paths instead of dumping the whole list into the notebook output.
for path in all_files[:10]:
    print(path)

In [None]:
import tree_sitter_rust as rustts
from tree_sitter import Language, Parser


# tree-sitter Rust grammar; the query and the parser below are both bound to it.
RUST_LANGUAGE = Language(rustts.language())

# Match every `function_item` node and capture it under the name "function".
query = RUST_LANGUAGE.query("""
(
   ((function_item) @function)
)
                            """)

# Parser that turns each file's bytes into a syntax tree the query runs against.
parser = Parser(RUST_LANGUAGE)

# Sibling node kinds that sit directly above a `fn` and are kept together with it:
# line comments (including `///` docs), block comments (including `/** */` docs)
# and attributes such as `#[inline]` or `#[cfg(test)]`.
_LEADING_NODE_TYPES = ("line_comment", "block_comment", "attribute_item")


def _leading_text(node):
    """Return the comments/attributes immediately preceding ``node`` as one string.

    Walks backwards over contiguous siblings whose type is in ``_LEADING_NODE_TYPES``
    and stops at the first other node. The pieces are returned in source order, and
    each one is terminated by a newline.
    """
    pieces = []
    sibling = node.prev_sibling
    while sibling is not None and sibling.type in _LEADING_NODE_TYPES:
        text = sibling.text.decode('utf-8')
        # Some comment nodes already end with "\n" (presumably `///` doc comments);
        # attributes and plain `//` comments do not. Normalise so pieces never fuse.
        if not text.endswith('\n'):
            text += '\n'
        pieces.append(text)
        sibling = sibling.prev_sibling
    return ''.join(reversed(pieces))


# One record per `fn` in the corpus:
#   "file":  source path
#   "text":  leading comments/attributes followed by the function source
#   "range": the function node's tree-sitter Range
functions = []
for file in all_files:
    with open(file, 'rb') as source_file:
        source = source_file.read()
    tree = parser.parse(source)
    # Each match is (pattern_index, {capture_name: [nodes]}); the single pattern
    # above captures exactly one node per match.
    for _pattern_index, match_captures in query.matches(tree.root_node):
        node = match_captures['function'][0]
        functions.append({
            "file": file,
            "text": _leading_text(node) + node.text.decode('utf-8'),
            # Read node.range directly; binding it to a name `range` would shadow the builtin.
            "range": node.range,
        })

print(f"Extracted {len(functions)} functions from {len(all_files)} Rust files")
if functions:
    print(f"First: {functions[0]['file']} {functions[0]['range']}")
    print(functions[0]["text"])


In [None]:
from pylate import models, indexes
import torch

# Checkpoint loaded through pylate's ColBERT wrapper; it encodes both documents and queries.
pylate_model = "joe32140/ColModernBERT-base-msmarco-en-bge"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = models.ColBERT(model_name_or_path=pylate_model, device=device)

# Voyager index stored under the "pylate-index" folder. override=True presumably replaces
# an existing index of the same name, so every run rebuilds from scratch (confirm in pylate docs).
index = indexes.Voyager(index_folder="pylate-index", index_name="index", override=True)

In [None]:
# Fail fast with a clear message instead of an obscure error from the encoder/index.
assert functions, "No functions were extracted; check the parsing cell above."

# Each function is indexed under the string form of its position in `functions`, so a
# retrieved id maps straight back via functions[int(id)]. enumerate() is used rather than
# range() so this cell does not rely on the builtin `range`, which an earlier cell may shadow.
functions_ids = [str(position) for position, _ in enumerate(functions)]
functions_texts = [function["text"] for function in functions]

documents_embeddings = model.encode(
    functions_texts,
    batch_size=32,
    is_query=False,
    show_progress_bar=True,
)

index.add_documents(
    documents_ids=functions_ids,
    documents_embeddings=documents_embeddings,
)

In [None]:
from pylate import retrieve

retriever = retrieve.ColBERT(index=index)

# Natural-language search and number of hits to show; edit these to try other queries.
search_query = "Show the code related to parsing a time, excluding tests"
top_k = 20

queries_embeddings = model.encode(
    [search_query],
    batch_size=32,
    is_query=True,  # query-side encoding, as opposed to is_query=False used for documents
    show_progress_bar=True,
)

scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=top_k,
)

# scores holds one hit list per query; a single query was issued, hence scores[0].
# Ids are the str(position) values assigned when indexing, so int() maps back into `functions`.
# `doc_id` is used rather than `id` so the builtin `id` is not shadowed after this cell.
for hit in scores[0]:
    doc_id = hit["id"]
    function = functions[int(doc_id)]
    print(f"Score: {hit['score']}")
    print(f"File: {function['file']}")
    print(f"Range: {function['range']}")
    print(f"Text: {function['text']}")
    print()