In [1]:
from datasets import load_dataset
import pandas as pd
import numpy as np
import re
import math
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import string
from nltk.stem import WordNetLemmatizer
import time

In [2]:
from sentence_transformers import SentenceTransformer, util
import torch

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Loading the dataset

In [4]:
# dataset_py = load_dataset("code_search_net", "python")
# train_py = pd.DataFrame(dataset_py['train'])

train_py = pd.read_csv('python_train_dataset.csv')

ground_truth = pd.read_csv('annotationStore.csv')
gt_py = ground_truth[ground_truth['Language'] == 'Python']

queries = pd.read_csv('queries.csv')

merged_py = gt_py.merge(train_py, left_on = 'GitHubUrl', right_on = 'func_code_url')
merged_py.shape

(990, 17)

## Identifier + Doc embedding

In [5]:
model_doc = SentenceTransformer('all-MiniLM-L6-v2')

Downloading (…)e9125/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)7e55de9125/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)55de9125/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)125/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)e9125/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading (…)9125/train_script.py:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading (…)7e55de9125/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)5de9125/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [6]:
sample_doc1 = train_py['func_name'] + " " + train_py['func_documentation_string'] # func name + documentation
sample_doc2 = train_py['func_documentation_string'] # only documentation

In [7]:
# embed docs here, pickle best one
sample_doc1_doc_emb = model_doc.encode(sample_doc1, convert_to_tensor=True)
# sample_doc2_doc_emb = model_doc.encode(sample_doc2, convert_to_tensor=True)

In [8]:
# pickle sample_doc1_doc_emb
# with open('func_name_docu_doc_emb.npy', 'wb') as f:
#     np.save(f, sample_doc1_doc_emb.to('cpu').numpy())

In [256]:
# pickle sample_doc2_doc_emb
# with open('doc_emb.npy', 'wb') as f:
#     np.save(f, sample_doc2_doc_emb.to('cpu').numpy())

In [258]:
# np.load('doc_emb.npy')

412178

In [9]:
np.load('func_name_docu_doc_emb.npy')

array([[ 0.02837376, -0.02163807,  0.03949121, ..., -0.10328282,
        -0.090763  , -0.02504425],
       [ 0.0171462 , -0.00621527,  0.02851562, ..., -0.0531636 ,
        -0.10418153, -0.05421131],
       [ 0.01913145,  0.01080318,  0.01642643, ..., -0.06695274,
        -0.10259605, -0.04814855],
       ...,
       [-0.10193831,  0.02762234,  0.00687646, ...,  0.06234076,
         0.02191263,  0.0464052 ],
       [ 0.03554316,  0.01823683,  0.03538615, ...,  0.02640572,
         0.01306683,  0.00430896],
       [-0.09010839,  0.04852536, -0.02157482, ..., -0.0042924 ,
         0.0464542 ,  0.02487415]], dtype=float32)

In [112]:
def return_document_relevances(query):
    """
    Returns all of query's GitHubUrls from labeled dataset (gt_py) and its relevance in descending order.
    """
    return gt_py[gt_py['Query'] == query].sort_values(by='Relevance', ascending=False)

In [113]:
return_document_relevances('convert int to string').head() # example

Unnamed: 0,Language,Query,GitHubUrl,Relevance,Notes
1913,Python,convert int to string,https://github.com/espressif/esptool/blob/c583...,3,
1933,Python,convert int to string,https://github.com/espressif/esptool/blob/c583...,3,
2260,Python,convert int to string,https://github.com/DeepHorizons/iarm/blob/b913...,3,
1723,Python,convert int to string,https://github.com/DeepHorizons/iarm/blob/b913...,1,
2762,Python,convert int to string,https://github.com/commonwealth-of-puerto-rico...,1,


In [207]:
def return_relevant_docs(query, sample_doc, doc_emb, k=10):
#     query_start = time.time()
    que_doc_emb = model_doc.encode(query, convert_to_tensor=True)
#     query_end = time.time()
    
#     print(f'Time to embed query: {query_end - query_start}')
    
    results = util.semantic_search(que_doc_emb, doc_emb)[0]
    top_k_results = results[:k]
    rank = 1
    
    que_doc_df = pd.DataFrame()
    que_doc_score = []
    que_doc_url = []
    que_doc_rank = []
#     que_relevances = []
    for res in top_k_results:
#         print("Cossim: {:.2f}".format(res['score']))
#         print(f"Rank: {rank}")
#         print(train_py.iloc[res['corpus_id']]['func_code_url'])
#         print("\n\n")

        que_doc_score.append(res['score'])
        que_doc_rank.append(rank)
        que_doc_url.append(train_py.iloc[res['corpus_id']]['func_code_url'])
        github_url = train_py.iloc[res['corpus_id']]['func_code_url']
#         matching_url_relevances = gt_py[gt_py['GitHubUrl'] == github_url]['Relevance'].values
#         que_relevances.append(matching_url_relevances)
        
        rank += 1

    que_doc_df['score'] = que_doc_score # similarity score
    que_doc_df['rank'] = que_doc_rank # our search engine rank
    que_doc_df['url'] = que_doc_url # github url of our returned document
#     que_doc_df['gt_relevances'] = que_relevances # relevances of returned document from gt_py
    
    return que_doc_df

In [172]:
# for now, what is the fraction of results that are in ground truth -- compare sample_doc1 and sample_doc2
# eventually, factor in actual relevance scores 
    # e.g. results only include relevance of 1 but relevance of 3 exists in gt

def evaluate_different_doc_emb(query, sample_doc1, sample_doc1_doc_emb, sample_doc2, sample_doc2_doc_emb):
    sample_doc1_res = return_relevant_docs(query, sample_doc1, sample_doc1_doc_emb)
    sample_doc2_res = return_relevant_docs(query, sample_doc2, sample_doc2_doc_emb)

    ground_truth = return_document_relevances(query)

    sample_doc1_merged_rel = sample_doc1_res.merge(ground_truth, how='left', left_on='url', right_on='GitHubUrl')
    sample_doc2_merged_rel = sample_doc2_res.merge(ground_truth, how='left', left_on='url', right_on='GitHubUrl')
    
    print('sample_doc1')
    display(sample_doc1_merged_rel)
    print('sample_doc2')
    display(sample_doc2_merged_rel)

    sample_doc1_frac_in_gt = (sum(sample_doc1_merged_rel['Relevance'].notnull()) 
                              / len(sample_doc1_merged_rel['Relevance']))
    sample_doc2_frac_in_gt = (sum(sample_doc2_merged_rel['Relevance'].notnull()) 
                              / len(sample_doc2_merged_rel['Relevance']))
    
    return sample_doc1_frac_in_gt, sample_doc2_frac_in_gt

In [174]:
q = queries.sample(n=1).values[0][0] # random query
print(f'Query: {q}')
evaluate_different_doc_emb(q, sample_doc1, sample_doc1_doc_emb, sample_doc2, sample_doc2_doc_emb)

Query: priority queue
sample_doc1


Unnamed: 0,score,rank,url,Language,Query,GitHubUrl,Relevance,Notes
0,0.763506,1,https://github.com/flaviogrossi/sockjs-cyclone...,,,,,
1,0.745125,2,https://github.com/Jaymon/prom/blob/b7ad2c259e...,,,,,
2,0.736036,3,https://github.com/Jaymon/prom/blob/b7ad2c259e...,,,,,
3,0.731232,4,https://github.com/keon/algorithms/blob/4d6569...,Python,priority queue,https://github.com/keon/algorithms/blob/4d6569...,3.0,
4,0.731232,4,https://github.com/keon/algorithms/blob/4d6569...,Python,priority queue,https://github.com/keon/algorithms/blob/4d6569...,2.0,The overall class would have been a better res...
5,0.718924,5,https://github.com/SKA-ScienceDataProcessor/in...,,,,,
6,0.71555,6,https://github.com/Erotemic/utool/blob/3b27e1f...,,,,,
7,0.703601,7,https://github.com/SKA-ScienceDataProcessor/in...,,,,,
8,0.702734,8,https://github.com/nerdvegas/rez/blob/1d3b846d...,,,,,
9,0.699726,9,https://github.com/pinax/django-mailer/blob/12...,,,,,


sample_doc2


Unnamed: 0,score,rank,url,Language,Query,GitHubUrl,Relevance,Notes
0,0.838333,1,https://github.com/hubo1016/vlcp/blob/23905522...,Python,priority queue,https://github.com/hubo1016/vlcp/blob/23905522...,2.0,
1,0.838333,1,https://github.com/hubo1016/vlcp/blob/23905522...,Python,priority queue,https://github.com/hubo1016/vlcp/blob/23905522...,1.0,
2,0.793109,2,https://github.com/nerdvegas/rez/blob/1d3b846d...,,,,,
3,0.792819,3,https://github.com/SKA-ScienceDataProcessor/in...,,,,,
4,0.78249,4,https://github.com/diffeo/rejester/blob/5438a4...,,,,,
5,0.75919,5,https://github.com/limpyd/redis-limpyd-jobs/bl...,Python,priority queue,https://github.com/limpyd/redis-limpyd-jobs/bl...,1.0,query somewhat unclear
6,0.75919,5,https://github.com/limpyd/redis-limpyd-jobs/bl...,Python,priority queue,https://github.com/limpyd/redis-limpyd-jobs/bl...,1.0,
7,0.717238,6,https://github.com/Murali-group/halp/blob/6eb2...,,,,,
8,0.711495,7,https://github.com/Murali-group/halp/blob/6eb2...,,,,,
9,0.704453,8,https://github.com/keon/algorithms/blob/4d6569...,Python,priority queue,https://github.com/keon/algorithms/blob/4d6569...,3.0,


(0.18181818181818182, 0.46153846153846156)

In [None]:
# evaluating test query

In [125]:
test_q = 'matrix multiply'

res = return_relevant_docs(query, sample_doc2, sample_doc2_doc_emb, k=10)
gt = return_document_relevances(test_q)

In [128]:
display(res)
display(gt)

Unnamed: 0,score,rank,url
0,0.872561,1,https://github.com/zeaphoo/reston/blob/9650248...
1,0.8333,2,https://github.com/Gandi/gandi.cli/blob/6ee5b8...
2,0.81685,3,https://github.com/ArchiveTeam/wpull/blob/ddf0...
3,0.809639,4,https://github.com/toomore/goristock/blob/e61f...
4,0.809639,5,https://github.com/dgomes/pyipma/blob/cd808abe...
5,0.809639,6,https://github.com/toomore/grs/blob/a1285cb578...
6,0.797438,7,https://github.com/Yelp/kafka-utils/blob/cdb4d...
7,0.793106,8,https://github.com/darxtrix/lehar/blob/8a2fbeb...
8,0.787179,9,https://github.com/gabrielfalcao/dominic/blob/...
9,0.776116,10,https://github.com/pyopenapi/pyswagger/blob/33...


Unnamed: 0,Language,Query,GitHubUrl,Relevance,Notes
1787,Python,matrix multiply,https://github.com/mabuchilab/QNET/blob/cc20d2...,3,
2507,Python,matrix multiply,https://github.com/AndrewAnnex/SpiceyPy/blob/f...,3,
2127,Python,matrix multiply,https://github.com/apache/spark/blob/618d6bff7...,3,
2172,Python,matrix multiply,https://github.com/churchill-lab/emase/blob/ae...,3,
2204,Python,matrix multiply,https://github.com/churchill-lab/emase/blob/ae...,3,
3013,Python,matrix multiply,https://github.com/fogleman/pg/blob/124ea3803c...,3,
2994,Python,matrix multiply,https://github.com/AndrewAnnex/SpiceyPy/blob/f...,3,
1937,Python,matrix multiply,https://github.com/mabuchilab/QNET/blob/cc20d2...,3,
3617,Python,matrix multiply,https://github.com/cmbruns/pyopenvr/blob/68395...,2,
3328,Python,matrix multiply,https://github.com/pymupdf/PyMuPDF/blob/917f2d...,2,


In [129]:
res.merge(gt, how='left', left_on='url', right_on='GitHubUrl')

Unnamed: 0,score,rank,url,Language,Query,GitHubUrl,Relevance,Notes
0,0.872561,1,https://github.com/zeaphoo/reston/blob/9650248...,,,,,
1,0.8333,2,https://github.com/Gandi/gandi.cli/blob/6ee5b8...,,,,,
2,0.81685,3,https://github.com/ArchiveTeam/wpull/blob/ddf0...,,,,,
3,0.809639,4,https://github.com/toomore/goristock/blob/e61f...,,,,,
4,0.809639,5,https://github.com/dgomes/pyipma/blob/cd808abe...,,,,,
5,0.809639,6,https://github.com/toomore/grs/blob/a1285cb578...,,,,,
6,0.797438,7,https://github.com/Yelp/kafka-utils/blob/cdb4d...,,,,,
7,0.793106,8,https://github.com/darxtrix/lehar/blob/8a2fbeb...,,,,,
8,0.787179,9,https://github.com/gabrielfalcao/dominic/blob/...,,,,,
9,0.776116,10,https://github.com/pyopenapi/pyswagger/blob/33...,,,,,


In [44]:
# manually comparing labeled dataset to sample_doc1 to sample_doc2
test_queries = gt_py.reset_index(drop=True).iloc[:5, 1].values
sample_doc1_results = [return_relevant_docs(q, sample_doc1, sample_doc1_doc_emb, k=5) for q in test_queries]
sample_doc2_results = [return_relevant_docs(q, sample_doc2, sample_doc2_doc_emb, k=5) for q in test_queries]

In [71]:
print(f'Query: "{test_queries[1]}"')
sample_doc1_results[1].loc[0, 'url']

Query: "priority queue"


'https://github.com/flaviogrossi/sockjs-cyclone/blob/d3ca053ec1aa1e85f652347bff562c2319be37a2/sockjs/cyclone/utils.py#L38-L41'

In [72]:
print(f'Query: "{test_queries[1]}"')
sample_doc2_results[1].loc[0, 'url']

Query: "priority queue"


'https://github.com/hubo1016/vlcp/blob/239055229ec93a99cc7e15208075724ccf543bd1/vlcp/event/pqueue.py#L1032-L1040'

In [43]:
# query = 'convert int to string'
# sample_doc1_results = return_relevant_docs(query, sample_doc1, sample_doc1_doc_emb, k=5)
# sample_doc2_results = return_relevant_docs(query, sample_doc2, sample_doc2_doc_emb, k=5)

# Evaluation - nDCG

In [219]:
k = 10

df_queries = []
s1_df_urls = []
s2_df_urls = []
for query in queries['query'].values:
    s1_results = return_relevant_docs(query, sample_doc1, sample_doc1_doc_emb, k=k)
    s2_results = return_relevant_docs(query, sample_doc2, sample_doc2_doc_emb, k=k)
    
    df_queries += [query] * k
    s1_df_urls += list(s1_results['url'].values)
    s2_df_urls += list(s2_results['url'].values)
    
s1_eval_df = pd.DataFrame({'language': ['Python'] * len(df_queries), 
                           'query': df_queries, 
                           'url': s1_df_urls})

s2_eval_df = pd.DataFrame({'language': ['Python'] * len(df_queries), 
                           'query': df_queries, 
                           'url': s2_df_urls})

s1_eval_df.to_csv('s1_eval_df.csv')
s2_eval_df.to_csv('s2_eval_df.csv')

In [226]:
test_py = pd.read_csv('python_test_dataset.csv')

In [228]:
test_doc_emb = model_doc.encode(test_py['func_documentation_string'], convert_to_tensor=True)

In [232]:
k = 10

df_queries = []
test_df_urls = []
for query in queries['query'].values:
    test_results = return_relevant_docs(query, test_py['func_documentation_string'], test_doc_emb, k=k)
    
    df_queries += [query] * k
    test_df_urls += list(test_results['url'].values)
    
test_eval_df = pd.DataFrame({'language': ['Python'] * len(df_queries), 
                           'query': df_queries, 
                           'url': test_df_urls})

test_eval_df.to_csv('test_eval_df.csv')

In [239]:
gt_py.merge(test_py, left_on = 'GitHubUrl', right_on = 'func_code_url')

Unnamed: 0.1,Language,Query,GitHubUrl,Relevance,Notes,Unnamed: 0,repository_name,func_path_in_repository,func_name,whole_func_string,language,func_code_string,func_code_tokens,func_documentation_string,func_documentation_tokens,split_name,func_code_url
0,Python,how to determine a string is a valid word,https://github.com/a-tal/nagaram/blob/2edcb0ef...,0,Query is unclear,19724,a-tal/nagaram,nagaram/scrabble.py,valid_scrabble_word,"def valid_scrabble_word(word):\n """"""Checks ...",python,"def valid_scrabble_word(word):\n """"""Checks ...","['def', 'valid_scrabble_word', '(', 'word', ')...",Checks if the input word could be played with ...,"['Checks', 'if', 'the', 'input', 'word', 'coul...",test,https://github.com/a-tal/nagaram/blob/2edcb0ef...
1,Python,how to determine a string is a valid word,https://github.com/a-tal/nagaram/blob/2edcb0ef...,1,,19724,a-tal/nagaram,nagaram/scrabble.py,valid_scrabble_word,"def valid_scrabble_word(word):\n """"""Checks ...",python,"def valid_scrabble_word(word):\n """"""Checks ...","['def', 'valid_scrabble_word', '(', 'word', ')...",Checks if the input word could be played with ...,"['Checks', 'if', 'the', 'input', 'word', 'coul...",test,https://github.com/a-tal/nagaram/blob/2edcb0ef...
2,Python,get current observable value,https://github.com/Qiskit/qiskit-terra/blob/d4...,0,,4250,Qiskit/qiskit-terra,qiskit/quantum_info/analyzation/make_observabl...,make_dict_observable,def make_dict_observable(matrix_observable):\n...,python,def make_dict_observable(matrix_observable):\n...,"['def', 'make_dict_observable', '(', 'matrix_o...",Convert an observable in matrix form to dictio...,"['Convert', 'an', 'observable', 'in', 'matrix'...",test,https://github.com/Qiskit/qiskit-terra/blob/d4...
3,Python,get current observable value,https://github.com/Qiskit/qiskit-terra/blob/d4...,2,,4250,Qiskit/qiskit-terra,qiskit/quantum_info/analyzation/make_observabl...,make_dict_observable,def make_dict_observable(matrix_observable):\n...,python,def make_dict_observable(matrix_observable):\n...,"['def', 'make_dict_observable', '(', 'matrix_o...",Convert an observable in matrix form to dictio...,"['Convert', 'an', 'observable', 'in', 'matrix'...",test,https://github.com/Qiskit/qiskit-terra/blob/d4...
4,Python,k means clustering,https://github.com/oscarbranson/latools/blob/c...,3,,14669,oscarbranson/latools,latools/filtering/clustering.py,cluster_kmeans,"def cluster_kmeans(data, n_clusters, **kwargs)...",python,"def cluster_kmeans(data, n_clusters, **kwargs)...","['def', 'cluster_kmeans', '(', 'data', ',', 'n...",Identify clusters using K - Means algorithm.\n...,"['Identify', 'clusters', 'using', 'K', '-', 'M...",test,https://github.com/oscarbranson/latools/blob/c...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69,Python,extracting data from a text file,https://github.com/lcharleux/argiope/blob/8170...,2,,18883,lcharleux/argiope,argiope/abq/pypostproc.py,read_field_report,"def read_field_report(path, data_flag = ""*DATA...",python,"def read_field_report(path, data_flag = ""*DATA...","['def', 'read_field_report', '(', 'path', ',',...",Reads a field output report.,"['Reads', 'a', 'field', 'output', 'report', '.']",test,https://github.com/lcharleux/argiope/blob/8170...
70,Python,get the description of a http status code,https://github.com/funilrys/PyFunceble/blob/cd...,3,,6846,funilrys/PyFunceble,PyFunceble/http_code.py,HTTPCode.get,"def get(self):\n """"""\n Return th...",python,"def get(self):\n """"""\n Return th...","['def', 'get', '(', 'self', ')', ':', 'if', 'P...",Return the HTTP code status.\n\n :retur...,"['Return', 'the', 'HTTP', 'code', 'status', '.']",test,https://github.com/funilrys/PyFunceble/blob/cd...
71,Python,get the description of a http status code,https://github.com/funilrys/PyFunceble/blob/cd...,0,,6846,funilrys/PyFunceble,PyFunceble/http_code.py,HTTPCode.get,"def get(self):\n """"""\n Return th...",python,"def get(self):\n """"""\n Return th...","['def', 'get', '(', 'self', ')', ':', 'if', 'P...",Return the HTTP code status.\n\n :retur...,"['Return', 'the', 'HTTP', 'code', 'status', '.']",test,https://github.com/funilrys/PyFunceble/blob/cd...
72,Python,how to get database table name,https://github.com/apache/airflow/blob/b69c686...,1,,664,apache/airflow,airflow/contrib/hooks/aws_glue_catalog_hook.py,AwsGlueCatalogHook.get_table,"def get_table(self, database_name, table_name)...",python,"def get_table(self, database_name, table_name)...","['def', 'get_table', '(', 'self', ',', 'databa...",Get the information of the table\n\n :p...,"['Get', 'the', 'information', 'of', 'the', 'ta...",test,https://github.com/apache/airflow/blob/b69c686...


## Code Embedding

In [21]:
model_code = SentenceTransformer("all-MiniLM-L6-v2")
sample_code = train_py['whole_func_string']

In [23]:
code_start = time.time()
code_emb = model_code.encode(sample_code, convert_to_tensor=True)
code_end = time.time()

query_start = time.time()
que_code_emb = model_code.encode(query, convert_to_tensor=True)
query_end = time.time()

hits = util.semantic_search(que_code_emb, code_emb)[0]
top_hits = hits[:3]

print(f'Time to embed code: {code_end - code_start}')
print(f'Time to embed query: {query_end - query_start}')

for top_hit in top_hits:
    print("Cossim: {:.2f}".format(top_hit['score']))
    print(merged_py.iloc[top_hit['corpus_id']]['func_code_url'])
    print("\n\n")

Cossim: 0.71


NameError: name 'merged' is not defined

## Tryouts - Pretrained CodeBERT(currently only on python)


In [None]:
from torch.utils.data import TensorDataset, DataLoader
from transformers import RobertaTokenizer, RobertaModel
import torch
import torch.nn as nn

In [None]:
# Setting up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model
tokenizer_code = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
model_code = RobertaModel.from_pretrained("demo/python_model").to(device)

# Load the data
# train_data = pd.read_csv('merged_py.csv')  # only on the ones with query and code pairs
train_data = train_py
sample_code = list(train_data['whole_func_string'])

# Tokenize and encode the query
encoded_query = tokenizer_code(query, return_tensors='pt', truncation=True, max_length=512).to(device)
with torch.no_grad():
    query_vec = model_code(**encoded_query)[1]  # Get the pooled output

code_vecs = []
codes = []

# Tokenize and encode the code snippets
for code in sample_code:
    encoded_code = tokenizer_code(code, return_tensors='pt', truncation=True, max_length=512).to(device)
    with torch.no_grad():
        cur_code_vec = model_code(**encoded_code)[1]  # Get the pooled output
    code_vecs.append(cur_code_vec)
    codes.append(code)

# Concatenate the code vectors and move to the same device as the query vector
code_vecs = torch.cat(code_vecs).to(device)

# Calculate the cosine similarities
scores = torch.einsum("ab,cb->ac", query_vec, code_vecs)
scores = torch.softmax(scores, -1)

# Get the top 5 scores and their indices
top_scores, top_indices = torch.topk(scores[0], 5, largest=True)

# Retrieve the top 5 most relevant code snippets using the indices
top_code_snippets = [sample_code[index] for index in top_indices.cpu().numpy()]

# Print the results
for score, snippet in zip(top_scores, top_code_snippets):
    print(f"Relevance Score: {score.item()}\nCode Snippet: {snippet}\n")
