In [33]:
import sys
sys.path.append('../src')

import pandas as pd
import os

from utils import get_combined_df, prepare_code_triplets, full_tokenize
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from tqdm import tqdm

from bm25_v2 import BM25Searcher
from eval import ModelEvaluator, SearchEvaluator

In [36]:
pd.read_parquet('../data/2_7/facebook_react/cache/4X_random_split/code_df.parquet').train_commit_id.nunique()

1890

In [11]:
class Args:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

args = Args(
    index_path='../data/2_7/facebook_react/index_commit_tokenized', repo_path='../data/2_7/facebook_react', k=1000, n=100, model_path='microsoft/codebert-base', overwrite_cache=False, batch_size=32, num_epochs=10, learning_rate=5e-05, run_name='repr_0.1663', notes='reproducing current best 0.1663 MAP result for CodeReranker', num_positives=10, num_negatives=10, train_depth=1000, num_workers=8, train_commits=1000, psg_cnt=25, aggregation_strategy='sump', use_gpu=True, rerank_depth=100, do_train=True, do_eval=True, eval_gold=True, openai_model='gpt4', overwrite_eval=False, sanity_check=True, debug=False, best_model_path=None, bert_best_model='data/combined_commit_train/best_model', psg_len=350, psg_stride=250, ignore_gold_in_training=False, eval_folder='repr_0.1663', use_gpt_train=True
)

metrics =['MAP', 'P@1', 'P@10', 'P@20', 'P@30', 'MRR', 'R@1', 'R@10', 'R@100', 'R@1000']
repo_path = args.repo_path
repo_name = repo_path.split('/')[-1]
index_path = args.index_path
K = args.k
n = args.n
combined_df = get_combined_df(repo_path)
BM25_AGGR_STRAT = 'sump'
eval_path = os.path.join(repo_path, 'eval')
if not os.path.exists(eval_path):
    os.makedirs(eval_path)

bm25_searcher = BM25Searcher(index_path)
evaluator = SearchEvaluator(metrics)
model_evaluator = ModelEvaluator(bm25_searcher, evaluator, combined_df)

test_path = os.path.join('..', 'gold', 'facebook_react', 'v2_facebook_react_gpt4_gold.parquet')
gold_df = pd.read_parquet(test_path)

Loaded index at ../data/2_7/facebook_react/index_commit_tokenized
Index Stats: {'total_terms': 7587973, 'documents': 73765, 'non_empty_documents': 73765, 'unique_terms': 14602}


In [92]:
cache_path = os.path.join(args.repo_path, 'cache', 'X_diff_split')
code_df = pd.read_parquet('../data/2_7/facebook_react/cache/repr_0.1663/code_df.parquet')

In [54]:
def prep_line(line):
    return line.rstrip().lstrip()

def parse_diff_remove_minus(diff):
    return [
        line[1:] if line.startswith('+') else line
        for line in diff.split('\n')
        if not (line.startswith('-') or len(line) == 0 or (line.startswith('@@') and line.count('@@') > 1))
        and len(prep_line(line)) > 2
    ]

def full_parse_diffs(diff):
   # keep both insertions and deletions to be passed to the model
    return [
        line[1:] if (line.startswith('+') or line.startswith('-')) else line
        for line in diff.split('\n')
        if not (len(line) == 0 or (line.startswith('@@') and line.count('@@') > 1))
    ]

def full_parse_diffs_split(diff):
   # keep both insertions and deletions to be passed to the model
    res = []
    cur = []
    for line in diff.split('\n'):
        if not len(line) == 0:
            if (line.startswith('@@') and line.count('@@') > 1):
                if cur:
                    res.append(cur)
                cur = []
            else:
                cur.append(line[1:] if (line.startswith('+') or line.startswith('-')) else line)
    if cur:
        res.append(cur)
    return res

def full_tokenize(s, tokenizer):
        return tokenizer.encode_plus(s, max_length=None, truncation=False, return_tensors='pt', add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False)['input_ids'].squeeze().tolist()

In [12]:
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')

In [79]:
def find_diff_tokens(diff):
    ntokens = len(full_tokenize(diff, tokenizer))
    return ntokens

In [78]:
# code_df = pd.read_parquet('../data/merged_code_df/multi_code_df.parquet')
# code_df.train_commit_id.nunique()

In [73]:
# average token in each diff (only insertions)
total_rows = 0
total_diff_tokens = 0
for i, row in tqdm(code_df.iterrows(), total=code_df.shape[0]):
    diff = row.SR_diff
    if diff or not pd.isna(diff):
        total_diff_tokens += find_diff_tokens(diff)
        total_rows += 1

total_diff_tokens, total_diff_tokens / total_rows

100%|██████████| 61173/61173 [02:48<00:00, 362.16it/s]


(95828835, 1598.1594176311664)

In [77]:
# average number of tokens in each diff split
total_diff_splits = 0
total_diff_split_tokens = 0
for i, row in tqdm(code_df.iterrows(), total=code_df.shape[0]):
    diff = row.SR_diff
    if diff or not pd.isna(diff):
        diff_split_list = full_parse_diffs_split(diff)
        total_diff_splits += len(diff_split_list)
        for diff_split in diff_split_list:
            total_diff_split_tokens += find_diff_tokens('\n'.join(diff_split))

total_diff_split_tokens, total_diff_split_tokens/total_diff_splits

100%|██████████| 61173/61173 [02:46<00:00, 367.34it/s]


(85946036, 326.5960472265606)

In [76]:
# average splits per diff (number of @@ -- @@ changes)
# so this will be number of distinct places where the file is edited
total_diff_splits / total_rows

4.388729528701511

In [138]:
# triplets_df = pd.read_parquet(os.path.join(cache_path, 'diff_code_triplets.parquet'))
triplets_df = prepare_code_triplets(code_df, args, mode='diff_subsplit', cache_file=None)

Preparing code triplets with mode diff_subsplit for 6631 rows.
Preparing triplets split by diff content (further subplit at @@)


100%|██████████| 6631/6631 [00:00<00:00, 7136.89it/s]


                                               query  \
0  Malformed data types (`commitDetails`, `intera...   
1  Malformed data types (`commitDetails`, `intera...   
2  Malformed data types (`commitDetails`, `intera...   
3  Malformed data types (`commitDetails`, `intera...   
4  Malformed data types (`commitDetails`, `intera...   

                                           file_path  \
0  packages/react-reconciler/src/ReactFiberSchedu...   
1  packages/react-reconciler/src/ReactFiberSchedu...   
2  packages/react-reconciler/src/ReactFiberSchedu...   
3  packages/react-reconciler/src/ReactFiberSchedu...   
4  packages/react-reconciler/src/ReactFiberSchedu...   

                                             passage  label  
0       firstBatch._defer &&\n     firstBatch._ex...      0  
1   }\n \n function prepareFreshStack(root, expir...      0  
2       return null;\n   }\n \n  if (root.pending...      0  
3     // something suspended, wait to commit it a...      0  
4         }\n   

In [139]:
triplets_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 31734 entries, 0 to 31733
Data columns (total 4 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   query      31734 non-null  object
 1   file_path  31734 non-null  object
 2   passage    31734 non-null  object
 3   label      31734 non-null  int64 
dtypes: int64(1), object(3)
memory usage: 991.8+ KB


# Getting average query and passage length from the CodeSearchNet dataset

In [146]:
code_string = '''def get_vid_from_url(url):
        """Extracts video ID from URL.
        """
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')'''

In [142]:
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')

In [147]:
len([tokenizer.decode(i) for i in full_tokenize(code_string, tokenizer)])

207

[Their notebook](https://github.com/github/CodeSearchNet/blob/master/notebooks/ExploreData.ipynb), which uses the tree-sitter tokenizer has ~52 tokens. So codebert's tokenizer is approx 4x.

In [1]:
from datasets import load_dataset
dataset = load_dataset("code_search_net")

Downloading data files:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.66G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/488M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/112M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/852M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/6 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/1880853 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/100529 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/89154 [00:00<?, ? examples/s]

In [5]:
csn_df = dataset['train'].to_pandas()

In [7]:
TOTAL_ROWS = csn_df.shape[0]
TOTAL_ROWS

1880853

In [8]:
csn_df.head()

Unnamed: 0,repository_name,func_path_in_repository,func_name,whole_func_string,language,func_code_string,func_code_tokens,func_documentation_string,func_documentation_tokens,split_name,func_code_url
0,mjirik/imcut,imcut/pycut.py,ImageGraphCut.__msgc_step3_discontinuity_local...,def __msgc_step3_discontinuity_localization(se...,python,def __msgc_step3_discontinuity_localization(se...,"[def, __msgc_step3_discontinuity_localization,...",Estimate discontinuity in basis of low resolut...,"[Estimate, discontinuity, in, basis, of, low, ...",train,https://github.com/mjirik/imcut/blob/1b38e7cd1...
1,mjirik/imcut,imcut/pycut.py,ImageGraphCut.__multiscale_gc_lo2hi_run,"def __multiscale_gc_lo2hi_run(self): # , pyed...",python,"def __multiscale_gc_lo2hi_run(self): # , pyed...","[def, __multiscale_gc_lo2hi_run, (, self, ), :...",Run Graph-Cut segmentation with refinement of ...,"[Run, Graph, -, Cut, segmentation, with, refin...",train,https://github.com/mjirik/imcut/blob/1b38e7cd1...
2,mjirik/imcut,imcut/pycut.py,ImageGraphCut.__multiscale_gc_hi2lo_run,"def __multiscale_gc_hi2lo_run(self): # , pyed...",python,"def __multiscale_gc_hi2lo_run(self): # , pyed...","[def, __multiscale_gc_hi2lo_run, (, self, ), :...",Run Graph-Cut segmentation with simplifiyng of...,"[Run, Graph, -, Cut, segmentation, with, simpl...",train,https://github.com/mjirik/imcut/blob/1b38e7cd1...
3,mjirik/imcut,imcut/pycut.py,ImageGraphCut.__ordered_values_by_indexes,"def __ordered_values_by_indexes(self, data, in...",python,"def __ordered_values_by_indexes(self, data, in...","[def, __ordered_values_by_indexes, (, self, ,,...",Return values (intensities) by indexes.\n\n ...,"[Return, values, (, intensities, ), by, indexe...",train,https://github.com/mjirik/imcut/blob/1b38e7cd1...
4,mjirik/imcut,imcut/pycut.py,ImageGraphCut.__hi2lo_multiscale_indexes,"def __hi2lo_multiscale_indexes(self, mask, ori...",python,"def __hi2lo_multiscale_indexes(self, mask, ori...","[def, __hi2lo_multiscale_indexes, (, self, ,, ...",Function computes multiscale indexes of ndarra...,"[Function, computes, multiscale, indexes, of, ...",train,https://github.com/mjirik/imcut/blob/1b38e7cd1...


In [31]:
csn_df.iloc[0].whole_func_string

'def __msgc_step3_discontinuity_localization(self):\n        """\n        Estimate discontinuity in basis of low resolution image segmentation.\n        :return: discontinuity in low resolution\n        """\n        import scipy\n\n        start = self._start_time\n        seg = 1 - self.segmentation.astype(np.int8)\n        self.stats["low level object voxels"] = np.sum(seg)\n        self.stats["low level image voxels"] = np.prod(seg.shape)\n        # in seg is now stored low resolution segmentation\n        # back to normal parameters\n        # step 2: discontinuity localization\n        # self.segparams = sparams_hi\n        seg_border = scipy.ndimage.filters.laplace(seg, mode="constant")\n        logger.debug("seg_border: %s", scipy.stats.describe(seg_border, axis=None))\n        # logger.debug(str(np.max(seg_border)))\n        # logger.debug(str(np.min(seg_border)))\n        seg_border[seg_border != 0] = 1\n        logger.debug("seg_border: %s", scipy.stats.describe(seg_border, a

In [25]:
full_tokenize('hello', tokenizer)

42891

In [27]:
total_query_tokens = 0
for i, row in tqdm(csn_df.iterrows(), total=TOTAL_ROWS):
    query = row.func_documentation_string
    if query:
        tkns = full_tokenize(query, tokenizer)
        if isinstance(tkns, int):
            tkns = [tkns]
        total_query_tokens += len(tkns)

total_query_tokens, total_query_tokens/TOTAL_ROWS 

100%|██████████| 1880853/1880853 [09:56<00:00, 3152.28it/s]


(118596716, 63.05475015857167)

In [32]:
total_func_tokens = 0
for i, row in tqdm(csn_df.iterrows(), total=TOTAL_ROWS):
    query = row.whole_func_string
    if query:
        tkns = full_tokenize(query, tokenizer)
        if isinstance(tkns, int):
            tkns = [tkns]
        total_func_tokens += len(tkns)

total_func_tokens, total_func_tokens/TOTAL_ROWS

100%|██████████| 1880853/1880853 [23:27<00:00, 1336.53it/s] 


(644719290, 342.7802651243877)

In [28]:
combined_df.head()

Unnamed: 0,owner,repo_name,commit_date,commit_id,commit_message,file_path,previous_commit_id,previous_file_content,cur_file_content,diff,status,is_merge_request,file_extension
0,facebook,react,1696522497,dddfe688206dafa5646550d351eb9a8e9c53654a,pull implementations from the right react-dom ...,packages/react-dom/server-rendering-stub.js,546178f9109424f6a0176ea8702a7620c4417569,"/**  * Copyright (c) Meta Platforms, Inc. and ...","/**  * Copyright (c) Meta Platforms, Inc. and ...","@@ -30,7 +30,10 @@ export {  } from './src/ser...",modified,False,js
1,facebook,react,1696521194,546178f9109424f6a0176ea8702a7620c4417569,`react-dom/server-rendering-stub`: restore exp...,packages/react-dom/server-rendering-stub.js,16619f106ab5ba8e6aca19d55be46cce22e4a7ff,"/**  * Copyright (c) Meta Platforms, Inc. and ...","/**  * Copyright (c) Meta Platforms, Inc. and ...","@@ -28,3 +28,30 @@ export {  useFormState,  ...",modified,False,js
2,facebook,react,1696452492,0fba3ecf73900a1b54ed6d3b0617462ac92d2fef,[Fizz] Reset error component stack and fix err...,packages/react-dom/src/__tests__/ReactDOMFizzS...,6f132439578ee11e04b41a278df51c52b0dc8563,"/**  * Copyright (c) Meta Platforms, Inc. and ...","/**  * Copyright (c) Meta Platforms, Inc. and ...","@@ -981,4 +981,149 @@ describe('ReactDOMFizzSt...",modified,False,js
3,facebook,react,1696452492,0fba3ecf73900a1b54ed6d3b0617462ac92d2fef,[Fizz] Reset error component stack and fix err...,packages/react-server/src/ReactFizzServer.js,6f132439578ee11e04b41a278df51c52b0dc8563,"/**  * Copyright (c) Meta Platforms, Inc. and ...","/**  * Copyright (c) Meta Platforms, Inc. and ...","@@ -1110,7 +1110,6 @@ function replaySuspenseB...",modified,False,js
4,facebook,react,1696450581,6f132439578ee11e04b41a278df51c52b0dc8563,Move ReactCurrentDispatcher back to shared int...,packages/react-server/src/ReactFlightServer.js,ca237d6f0ab986e799f192224d3066f76d66b73b,"/**  * Copyright (c) Meta Platforms, Inc. and ...","/**  * Copyright (c) Meta Platforms, Inc. and ...","@@ -108,6 +108,7 @@ import {  } from 'shared/R...",modified,False,js
