In [1]:
import sys
sys.path.append('../src')

import pandas as pd
import os

from utils import get_combined_df
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from tqdm import tqdm

from bm25_v2 import BM25Searcher
from eval import ModelEvaluator, SearchEvaluator

In [2]:
class Args:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

args = Args(
    index_path='../data/2_7/facebook_react/index_commit_tokenized', repo_path='../data/2_7/facebook_react', k=1000, n=100, model_path='microsoft/codebert-base', overwrite_cache=False, batch_size=32, num_epochs=10, learning_rate=5e-05, run_name='repr_0.1663', notes='reproducing current best 0.1663 MAP result for CodeReranker', num_positives=10, num_negatives=10, train_depth=1000, num_workers=8, train_commits=1000, psg_cnt=25, aggregation_strategy='sump', use_gpu=True, rerank_depth=100, do_train=True, do_eval=True, eval_gold=True, openai_model='gpt4', overwrite_eval=False, sanity_check=True, debug=False, best_model_path=None, bert_best_model='data/combined_commit_train/best_model', psg_len=350, psg_stride=250, ignore_gold_in_training=False, eval_folder='repr_0.1663', use_gpt_train=True
)

metrics =['MAP', 'P@1', 'P@10', 'P@20', 'P@30', 'MRR', 'R@1', 'R@10', 'R@100', 'R@1000']
repo_path = args.repo_path
repo_name = repo_path.split('/')[-1]
index_path = args.index_path
K = args.k
n = args.n
combined_df = get_combined_df(repo_path)
BM25_AGGR_STRAT = 'sump'
eval_path = os.path.join(repo_path, 'eval')
if not os.path.exists(eval_path):
    os.makedirs(eval_path)

bm25_searcher = BM25Searcher(index_path)
evaluator = SearchEvaluator(metrics)
model_evaluator = ModelEvaluator(bm25_searcher, evaluator, combined_df)

test_path = os.path.join('..', 'gold', 'facebook_react', 'v2_facebook_react_gpt4_gold.parquet')
gold_df = pd.read_parquet(test_path)

Loaded index at ../data/2_7/facebook_react/index_commit_tokenized
Index Stats: {'total_terms': 7587973, 'documents': 73765, 'non_empty_documents': 73765, 'unique_terms': 14602}


In [15]:
cache_path = os.path.join(args.repo_path, 'cache', '4X_function_split')
code_df = pd.read_parquet(os.path.join(cache_path, 'code_df.parquet'))

In [8]:
code_df.train_commit_id.nunique()

1890

In [9]:
code_df.label.value_counts()

label
0    18114
1     5862
Name: count, dtype: int64

In [16]:
triplets_df = pd.read_parquet(os.path.join(cache_path, 'diff_code_triplets.parquet'))

In [17]:
triplets_df.label.value_counts()

label
0    305107
1     79627
Name: count, dtype: int64

In [18]:
triplets_df.head(10)

Unnamed: 0,query,file_path,passage,label
0,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,<s>function resolveLocksOnRoot(root: FiberRoot...,0
1,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,"<s>function prepareFreshStack(root, expiration...",0
2,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,pirationTime = expirationTime;\n workInProgre...,0
3,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,<s>function computeUniqueAsyncExpiration(): Ex...,0
4,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,<s>function scheduleUpdateOnFiber(\n fiber: F...,0
5,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,// should be deferred until the end of t...,0
6,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,\n } else {\n // TODO: computeExpirationFo...,0
7,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,\n rootsWithPendingDiscreteUpdates.se...,0
8,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,<s>function markUpdateTimeFromFiberToRoot(fibe...,0
9,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,alternate!== null &&\n alternate.c...,0


In [27]:
class BERTCodeReranker:
    def __init__(self, parameters):
        self.parameters = parameters
        self.model_name = parameters['model_name']
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=1, problem_type='regression')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() and parameters['use_gpu'] else "cpu")
        self.model.to(self.device)

        print(f'Using device: {self.device}')

        # print GPU info
        if torch.cuda.is_available() and parameters['use_gpu']:
            print(f"Using GPU: {torch.cuda.get_device_name(0)}")
            print(f'GPU Device Count: {torch.cuda.device_count()}')
            print(f"GPU Memory Usage: {torch.cuda.memory_allocated(0) / 1024 ** 2:.2f} MB")


        self.psg_len = parameters['psg_len']
        self.psg_cnt = parameters['psg_cnt'] # how many contributing_results to use per file for reranking
        self.psg_stride = parameters.get('psg_stride', self.psg_len)
        self.aggregation_strategy = parameters['aggregation_strategy'] # how to aggregate the scores of the psg_cnt contributing_results
        self.batch_size = parameters['batch_size'] # batch size for reranking efficiently
        self.rerank_depth = parameters['rerank_depth']
        self.max_seq_length = self.tokenizer.model_max_length # max sequence length for the model

        print(f"Initialized Code File BERT reranker with parameters: {parameters}")


    def rerank(self, query, aggregated_results):
        """
        Rerank the BM25 aggregated search results using BERT model scores.

        query: The issue query string.
        aggregated_results: A list of AggregatedSearchResult objects from BM25 search.
        """
        # aggregated_results = aggregated_results[:self.rerank_depth] # already done in the pipeline
        # print(f'Reranking {len(aggregated_results)} results')

        self.model.eval()

        query_passage_pairs, per_result_contribution = self.split_into_query_passage_pairs(query, aggregated_results)


        # for agg_result in aggregated_results:
        #     query_passage_pairs.extend(
        #         (query, result.commit_message)
        #         for result in agg_result.contributing_results[: self.psg_cnt]
        #     )

        if not query_passage_pairs:
            print('WARNING: No query passage pairs to rerank, returning original results from previous stage')
            print(query, aggregated_results, self.psg_cnt)
            return aggregated_results

        # tokenize the query passage pairs
        encoded_pairs = [self.tokenizer.encode_plus([query, passage], max_length=self.max_seq_length, truncation=True, padding='max_length', return_tensors='pt', add_special_tokens=True) for query, passage in query_passage_pairs]

        # create tensors for the input ids, attention masks
        input_ids = torch.stack([encoded_pair['input_ids'].squeeze() for encoded_pair in encoded_pairs], dim=0) # type: ignore
        attention_masks = torch.stack([encoded_pair['attention_mask'].squeeze() for encoded_pair in encoded_pairs], dim=0) # type: ignore

        # Create a dataloader for feeding the data to the model
        dataset = TensorDataset(input_ids, attention_masks)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) # shuffle=False very important for reconstructing the results back into the original order

        scores = self.get_scores(dataloader, self.model)

        score_index = 0
        # Now assign the scores to the aggregated results by mapping the scores to the contributing results
        for i, agg_result in enumerate(aggregated_results):
            # Each aggregated result gets a slice of the scores equal to the number of contributing results it has which should be min(psg_cnt, len(contributing_results))
            assert score_index < len(scores), f'score_index {score_index} is greater than or equal to scores length {len(scores)}'
            end_index = score_index + per_result_contribution[i] # only use psg_cnt contributing_results
            cur_passage_scores = scores[score_index:end_index]
            score_index = end_index


            # Aggregate the scores for the current aggregated result
            agg_score = self.aggregate_scores(cur_passage_scores)
            agg_result.score = agg_score  # Assign the aggregated score

        assert score_index == len(scores), f'score_index {score_index} does not equal scores length {len(scores)}, indices probably not working correctly'

        # Sort by the new aggregated score
        aggregated_results.sort(key=lambda res: res.score, reverse=True)

        return aggregated_results

    def get_scores(self, dataloader, model):
        scores = []
        with torch.no_grad():
            for batch in dataloader:
                # Unpack the batch and move it to GPU
                b_input_ids, b_attention_mask = batch
                b_input_ids = b_input_ids.to(self.device)
                b_attention_mask = b_attention_mask.to(self.device)

                # Get scores from the model
                outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_attention_mask)
                scores.extend(outputs.logits.detach().cpu().numpy().squeeze(-1))
        return scores

    def aggregate_scores(self, passage_scores):
        """
        Aggregate passage scores based on the specified strategy.
        """
        if len(passage_scores) == 0:
            return 0.0

        if self.aggregation_strategy == 'firstp':
            return passage_scores[0]
        if self.aggregation_strategy == 'maxp':
            return max(passage_scores)
        if self.aggregation_strategy == 'avgp':
            return sum(passage_scores) / len(passage_scores)
        if self.aggregation_strategy == 'sump':
            return sum(passage_scores)
        # else:
        raise ValueError(f"Invalid score aggregation method: {self.aggregation_strategy}")


    def split_into_query_passage_pairs(self, query, aggregated_results):
        # Flatten the list of results into a list of (query, passage) pairs but only keep max psg_cnt passages per file
        def full_tokenize(s):
            return self.tokenizer.encode_plus(s, max_length=None, truncation=False, return_tensors='pt', add_special_tokens=True, return_attention_mask=False, return_token_type_ids=False)['input_ids'].squeeze().tolist()
        query_passage_pairs = []
        per_result_contribution = []
        for agg_result in aggregated_results:
            agg_result.contributing_results.sort(key=lambda res: res.commit_date, reverse=True)
            # get most recent file version
            most_recent_search_result = agg_result.contributing_results[0]
            # get the file_path and commit_id
            file_path = most_recent_search_result.file_path
            commit_id = most_recent_search_result.commit_id
            # get the file content from combined_df
            file_content = combined_df[(combined_df['commit_id'] == commit_id) & (combined_df['file_path'] == file_path)]['cur_file_content'].values[0]

            # now need to split this file content into psg_cnt passages
            # first tokenize the file content

            # warning these asserts are useless since we are using NaNs
            assert file_content is not None, f'file_content is None for commit_id: {commit_id}, file_path: {file_path}'
            assert file_path is not None, f'file_path is None for commit_id: {commit_id}'
            assert query is not None, f'query is None'

            query_tokens = full_tokenize(query)
            path_tokens = full_tokenize(file_path)

            if pd.isna(file_content):
                # if file_content is NaN, then we can just set file_content to empty string
                print(f'WARNING: file_content is NaN for commit_id: {commit_id}, file_path: {file_path}, setting file_content to empty string')
                file_content = ''

            file_tokens = full_tokenize(file_content)


            # now split the file content into psg_cnt passages
            cur_result_passages = []
            # get the input ids
            # input_ids = file_content['input_ids'].squeeze()
            # get the number of tokens in the file content
            total_tokens = len(file_tokens)

            for cur_start in range(0, total_tokens, self.psg_stride):
                cur_passage = []
                # add query tokens and path tokens
                # cur_passage.extend(query_tokens) # ??????????????
                cur_passage.extend(path_tokens)

                # add the file tokens
                cur_passage.extend(file_tokens[cur_start:cur_start+self.psg_len])

                # now convert cur_passage into a string
                cur_passage_decoded = self.tokenizer.decode(cur_passage)

                # add the cur_passage to cur_result_passages
                cur_result_passages.append(cur_passage_decoded)

                if len(cur_result_passages) == self.psg_cnt:
                    break

            # now add the query, passage pairs to query_passage_pairs
            per_result_contribution.append(len(cur_result_passages))
            query_passage_pairs.extend((query, passage) for passage in cur_result_passages)
        return query_passage_pairs, per_result_contribution

    def rerank_pipeline(self, query, aggregated_results):
        if len(aggregated_results) == 0:
            return aggregated_results
        top_results = aggregated_results[:self.rerank_depth]
        bottom_results = aggregated_results[self.rerank_depth:]
        reranked_results = self.rerank(query, top_results)
        min_top_score = reranked_results[-1].score
        # now adjust the scores of bottom_results
        for i, result in enumerate(bottom_results):
            result.score = min_top_score - i - 1
        # combine the results
        reranked_results.extend(bottom_results)
        assert(len(reranked_results) == len(aggregated_results))
        return reranked_results

In [13]:
bm25_baseline_eval = model_evaluator.evaluate_sampling(n=n, k=K, output_file_path=None, aggregation_strategy=BM25_AGGR_STRAT)

print("BM25 Baseline Evaluation")
print(bm25_baseline_eval)



100%|██████████| 100/100 [00:30<00:00,  3.27it/s]

BM25 Baseline Evaluation
{'MAP': 0.1542, 'P@1': 0.11, 'P@10': 0.087, 'P@20': 0.063, 'P@30': 0.0517, 'MRR': 0.2133, 'R@1': 0.0509, 'R@10': 0.2285, 'R@100': 0.5077, 'R@1000': 0.6845}





# Fixing/Verifying code triplets

Multiple modes:
1. Random splits
2. Diff only
3. Function split

Run with 500 gold, 500 gold + 500 normal commits, 500 + 1500 normal commits

In [28]:
code_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 126 entries, 0 to 125
Data columns (total 8 columns):
 #   Column                  Non-Null Count  Dtype 
---  ------                  --------------  ----- 
 0   train_commit_id         126 non-null    object
 1   train_query             126 non-null    object
 2   train_original_message  126 non-null    object
 3   SR_file_path            126 non-null    object
 4   SR_commit_id            126 non-null    object
 5   SR_file_content         126 non-null    object
 6   SR_diff                 123 non-null    object
 7   label                   126 non-null    int64 
dtypes: int64(1), object(7)
memory usage: 8.0+ KB


In [30]:
code_df.head(2)

Unnamed: 0,train_commit_id,train_query,train_original_message,SR_file_path,SR_commit_id,SR_file_content,SR_diff,label
0,76e569992b259b9e636ee68dcc7719539f4b9bb8,"Malformed data types (`commitDetails`, `intera...","Cleanup profile export/import data types, add ...",packages/react-reconciler/src/ReactFiberSchedu...,cc24d0ea56b0538d1ac61dc09faedd70ced5bb47,"/**\n * Copyright (c) Facebook, Inc. and its a...","@@ -569,8 +569,6 @@ function resolveLocksOnRoo...",0
1,76e569992b259b9e636ee68dcc7719539f4b9bb8,"Malformed data types (`commitDetails`, `intera...","Cleanup profile export/import data types, add ...",src/renderers/shared/fiber/ReactChildFiber.js,c22b94f14a809abb376f07a53f36860a7c6a342e,"/**\n * Copyright 2013-present, Facebook, Inc....","@@ -13,8 +13,7 @@\n 'use strict';\n \n import ...",0


In [31]:
params = {
        'model_name': args.model_path,
        'psg_cnt': args.psg_cnt,
        'aggregation_strategy': args.aggregation_strategy,
        'batch_size': args.batch_size,
        'use_gpu': args.use_gpu,
        'rerank_depth': args.rerank_depth,
        'num_epochs': args.num_epochs,
        'lr': args.learning_rate,
        'num_positives': args.num_positives,
        'num_negatives': args.num_negatives,
        'train_depth': args.train_depth,
        'num_workers': args.num_workers,
        'train_commits': args.train_commits,
        'bm25_aggr_strategy': 'sump',
        'psg_len': args.psg_len,
        'psg_stride': args.psg_stride
    }
code_reranker = BERTCodeReranker(params)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at microsoft/codebert-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using device: cuda
Using GPU: Quadro RTX 6000
GPU Device Count: 1
GPU Memory Usage: 953.46 MB
Initialized Code File BERT reranker with parameters: {'model_name': 'microsoft/codebert-base', 'psg_cnt': 25, 'aggregation_strategy': 'sump', 'batch_size': 32, 'use_gpu': True, 'rerank_depth': 100, 'num_epochs': 10, 'lr': 5e-05, 'num_positives': 10, 'num_negatives': 10, 'train_depth': 1000, 'num_workers': 8, 'train_commits': 1000, 'bm25_aggr_strategy': 'sump', 'psg_len': 350, 'psg_stride': 250}


In [37]:
def prepare_code_triplets(code_df, code_reranker, mode, cache_file, overwrite=False):
    print(f"Preparing code triplets with mode {mode} for {len(code_df)} rows.")
    if cache_file and os.path.exists(cache_file) and not overwrite:
        print(f"Loading data from cache file: {cache_file}")
        return pd.read_parquet(cache_file)

    if mode == 'sliding_window':
        triplets = prepare_sliding_window_triplets(code_df, code_reranker)
    elif mode == 'parse_functions':
        triplets = prepare_function_triplets(code_df, code_reranker)
    elif mode == 'diff_content':
        triplets = prepare_diff_content_triplets(code_df, code_reranker)
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    triplets_df = pd.DataFrame(triplets, columns=['query', 'file_path', 'passage', 'label'])
    if cache_file:
        print(f"Saving data to cache file: {cache_file}")
        triplets_df.to_parquet(cache_file)

    return triplets_df

In [111]:
def prepare_sliding_window_triplets(code_df, code_reranker):

    #### Helper functions ####
    def full_tokenize(s):
        return code_reranker.tokenizer.encode_plus(s, max_length=None, truncation=False, return_tensors='pt', add_special_tokens=True, return_attention_mask=False, return_token_type_ids=False)['input_ids'].squeeze().tolist()
        
    def prep_line(line):
        return line.rstrip().lstrip()

    def parse_diff(diff):
        return [
            line[1:] if line.startswith('+') else line
            for line in diff.split('\n')
            if not (line.startswith('-') or len(line) == 0 or (line.startswith('@@') and line.count('@@') > 1))
            and len(prep_line(line)) > 2
        ]
        
    def count_matching_lines(passage_lines, diff_lines):
        # Create a 2D array to store the lengths of the longest common subsequences
        dp = [[0] * (len(diff_lines) + 1) for _ in range(len(passage_lines) + 1)]

        # Fill the dp array
        for i in range(1, len(passage_lines) + 1):
            for j in range(1, len(diff_lines) + 1):
                if prep_line(passage_lines[i - 1]) == prep_line(diff_lines[j - 1]):
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        return dp[-1][-1]

    #### End of helper functions ####

    triplets = []
    for _, row in tqdm(code_df.iterrows(), total=len(code_df)):
        file_tokens = full_tokenize(row['SR_file_content'])
        total_tokens = len(file_tokens)
        cur_diff = row['SR_diff']
        cur_triplets = []
        if cur_diff is None:
            # NOTE: for cases where status is added probably or if diff was not able to be stored (encoding issue, etc)
            # THIS WILL LEAD TO A FEW POSITIVES MISSING - don't freak out, it's normal, I checked ;)
            continue

        # Process the diffs to removee @@ stuff 
        cur_diff_lines = parse_diff(cur_diff) # split into lines and remove deletions, only additions remaining

        # get the diff tokens
        total_tokens = len(file_tokens)

        for cur_start in range(0, total_tokens, code_reranker.psg_stride):
            cur_passage = []

            # get tokens for current passage
            cur_passage.extend(file_tokens[cur_start:cur_start+code_reranker.psg_len])

            # now convert current passage tokens back into a string
            cur_passage_decoded = code_reranker.tokenizer.decode(cur_passage)

            # for ranking acc. to number of common lines with diff, split on \n
            cur_passage_lines = cur_passage_decoded.split('\n')

            # remove lines with less than 2 characters since we do the same for diff preprocessing
            cur_passage_lines = [line for line in cur_passage_lines if len(prep_line(line)) > 2] # otherwise empty characters and brackets match

            # get number of common lines b/w diff
            common_line_count = count_matching_lines(cur_passage_lines, cur_diff_lines)

            # add the cur_passage to cur_result_passages
            cur_triplets.append((common_line_count, (row['train_query'], row['SR_file_path'], cur_passage_decoded, row['label'])))

        # # sort the cur_triplets by the number of common lines
        cur_triplets.sort(key=lambda x: x[0], reverse=True)

        # now add the top code_reranker.psg_cnt to triplets
        for triplet in cur_triplets[:code_reranker.psg_cnt]:
            # print(f"Found {triplet[0]} matching lines for diff in cur_passage at index")
            triplets.append(triplet[1])
    print(len(triplets))
    return triplets

In [120]:
tmp = prepare_code_triplets(code_df, code_reranker, mode='sliding_window', cache_file=None, overwrite=False)

Preparing code triplets with mode sliding_window for 126 rows.


100%|██████████| 126/126 [00:26<00:00,  4.73it/s]

2696





In [127]:
def prepare_diff_content_triplets(code_df, code_reranker):

    #### Helper functions #### 

    def full_tokenize(s):
        return code_reranker.tokenizer.encode_plus(s, max_length=None, truncation=False, return_tensors='pt', add_special_tokens=True, return_attention_mask=False, return_token_type_ids=False)['input_ids'].squeeze().tolist()

    def prep_line(line):
        return line.rstrip().lstrip()
        
    def full_parse_diffs(diff):
       # keep both insertions and deletions to be passed to the model
        return [
            line[1:] if (line.startswith('+') or line.startswith('-')) else line
            for line in diff.split('\n')
            if not (len(line) == 0 or (line.startswith('@@') and line.count('@@') > 1))
        ]
    #### end of helper functions ####

    triplets = []

    for _, row in tqdm(code_df.iterrows(), total=len(code_df)):
        cur_diff = row['SR_diff']
        if cur_diff is None:
            # NOTE: for cases where status is added probably or if diff was not able to be stored (encoding issue, etc)
            # THIS WILL LEAD TO A FEW POSITIVES MISSING - don't freak out, it's normal, I checked ;)
            continue
        cur_diff_lines = full_parse_diffs(cur_diff) # keep both insertions and deletions
        diff_tokens = full_tokenize(''.join(cur_diff_lines))
        total_tokens = len(diff_tokens)
        for cur_start in range(0, total_tokens, code_reranker.psg_stride):
            cur_passage = []
    
            cur_passage.extend(diff_tokens[cur_start:cur_start+code_reranker.psg_len])
    
            # now convert cur_passage into a string
            cur_passage_decoded = code_reranker.tokenizer.decode(cur_passage)

            # add the cur_passage to cur_result_passages
            triplets.append((row['train_query'], row['SR_file_path'], cur_passage_decoded, row['label']))

        # now add the top code_reranker.psg_cnt to triplets
        return triplets


In [128]:
tmp2 = prepare_code_triplets(code_df, code_reranker, mode='diff_content', cache_file=None, overwrite=False)

Preparing code triplets with mode diff_content for 126 rows.


  0%|          | 0/126 [00:00<?, ?it/s]


In [167]:
tmp2.head()

Unnamed: 0,query,file_path,passage,label
0,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,<s> firstBatch._defer && firstBatch._e...,0
1,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,"null, root, expirationTime); return commitR...",0
2,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,"// immediately, wait for more dat...",0
3,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,"return commitRoot.bind(null, root); } ...",0
4,"Malformed data types (`commitDetails`, `intera...",packages/react-reconciler/src/ReactFiberSchedu...,"!== CommitPhase, 'Should not already be wo...",0


In [145]:
# print(code_df.query('train_commit_id == "76e569992b259b9e636ee68dcc7719539f4b9bb8" & SR_file_path == "packages/react-reconciler/src/ReactFiberScheduler.js"').SR_diff.values[0])

In [149]:
from tree_sitter import Language, Parser

In [172]:
recent_df.info()

NameError: name 'recent_df' is not defined

In [171]:
combined_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 73765 entries, 0 to 73764
Data columns (total 13 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   owner                  73765 non-null  string
 1   repo_name              73765 non-null  string
 2   commit_date            73765 non-null  int64 
 3   commit_id              73765 non-null  string
 4   commit_message         73765 non-null  string
 5   file_path              73765 non-null  string
 6   previous_commit_id     73765 non-null  string
 7   previous_file_content  73765 non-null  string
 8   cur_file_content       73765 non-null  string
 9   diff                   58037 non-null  string
 10  status                 73765 non-null  object
 11  is_merge_request       73765 non-null  bool  
 12  file_extension         73765 non-null  object
dtypes: bool(1), int64(1), object(2), string(9)
memory usage: 6.8+ MB


In [169]:
def prepare_function_triplets(code_df, code_reranker):    
    #### Helper Functions ####
    def prep_line(line):
        return line.rstrip().lstrip()

    def parse_diff(diff):
        return [
            line[1:] if line.startswith('+') else line
            for line in diff.split('\n')
            if not (line.startswith('-') or len(line) == 0 or (line.startswith('@@') and line.count('@@') > 1))
            and len(prep_line(line)) > 2
        ]

    def full_tokenize(s):
        return code_reranker.tokenizer.encode_plus(s, max_length=None, truncation=False, return_tensors='pt', add_special_tokens=True, return_attention_mask=False, return_token_type_ids=False)['input_ids'].squeeze().tolist()

    def extract_function_texts(node, source_code):
        function_texts = []
        # Check if the node represents a function declaration
        if node.type == 'function_declaration':
            start_byte = node.start_byte
            end_byte = node.end_byte
            function_texts.append(source_code[start_byte:end_byte].decode('utf8'))
        # Check for variable declarations that might include function expressions or arrow functions
        elif node.type == 'variable_declaration':
            for child in node.children:
                if child.type == 'variable_declarator':
                    init_node = child.child_by_field_name('init')
                    if init_node and (init_node.type in ['function', 'arrow_function', 'function_expression']):
                        start_byte = node.start_byte
                        end_byte = node.end_byte
                        function_texts.append(source_code[start_byte:end_byte].decode('utf8'))
                        break  # Assuming one function per variable declaration for simplicity
        # Recursively process all child nodes
        else:
            for child in node.children:
                function_texts.extend(extract_function_texts(child, source_code))
        return function_texts

    def count_matching_lines(passage_lines, diff_lines):
        # Create a 2D array to store the lengths of the longest common subsequences
        dp = [[0] * (len(diff_lines) + 1) for _ in range(len(passage_lines) + 1)]

        # Fill the dp array
        for i in range(1, len(passage_lines) + 1):
            for j in range(1, len(diff_lines) + 1):
                if prep_line(passage_lines[i - 1]) == prep_line(diff_lines[j - 1]):
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        return dp[-1][-1]

    #### end of helper functions #### 

    JS_LANGUAGE = Language('../src/parser/my-languages.so', 'javascript')
    parser = Parser()
    parser.set_language(JS_LANGUAGE)

    triplets = []

    for _, row in tqdm(code_df.iterrows(), total=len(code_df)):
        file_content = row['SR_file_content']
        cur_diff = row['SR_diff']

        if cur_diff is None:
            continue

        # Convert the source code to bytes for tree-sitter
        source_code_bytes = bytes(file_content, "utf8")

        # Parse the code
        tree = parser.parse(source_code_bytes)

        # Extract function texts
        root_node = tree.root_node
        function_texts = extract_function_texts(root_node, source_code_bytes)

        cur_diff_lines = parse_diff(cur_diff)
        cur_triplets = []

        for func in function_texts:
            cur_func_lines = func.split('\n')

            # remove lines with less than 2 characters
            cur_func_lines = [line for line in cur_func_lines if len(prep_line(line)) > 2]
            # common_lines = set(cur_func_lines).intersection(set(cur_diff_lines))
            common_line_count = count_matching_lines(cur_func_lines, cur_diff_lines)
            cur_triplets.append((common_line_count, (row['train_query'], row['SR_file_path'], func, row['label'])))

        # # sort the cur_triplets by the number of common lines
        cur_triplets.sort(key=lambda x: x[0], reverse=True)

        # # now we want to filter cur_triplets to have all tuplets with x[0] > 3 to be in order and shuffle the rest

        # # now add the top code_reranker.psg_cnt to triplets
        for triplet in cur_triplets[:code_reranker.psg_cnt]:
            query, file_path, function, label = triplet[1]
            function_tokenized = full_tokenize(function)
            total_tokens = len(function_tokenized)

            for cur_start in range(0, total_tokens, code_reranker.psg_stride):
                cur_passage = []
                cur_passage.extend(function_tokenized[cur_start:cur_start+code_reranker.psg_len])
                cur_passage_decoded = code_reranker.tokenizer.decode(cur_passage)
                triplets.append((query, file_path, cur_passage_decoded, label))
            
    return triplets

In [None]:
# # get the diff tokens
# total_tokens = len(file_tokens)

# for cur_start in range(0, total_tokens, code_reranker.psg_stride):
#     cur_passage = []

#     # get tokens for current passage
#     cur_passage.extend(file_tokens[cur_start:cur_start+code_reranker.psg_len])

#     # now convert current passage tokens back into a string
#     cur_passage_decoded = code_reranker.tokenizer.decode(cur_passage)

#     # for ranking acc. to number of common lines with diff, split on \n
#     cur_passage_lines = cur_passage_decoded.split('\n')

#     # remove lines with less than 2 characters since we do the same for diff preprocessing
#     cur_passage_lines = [line for line in cur_passage_lines if len(prep_line(line)) > 2] # otherwise empty characters and brackets match

#     # get number of common lines b/w diff
#     common_line_count = count_matching_lines(cur_passage_lines, cur_diff_lines)

#     # add the cur_passage to cur_result_passages
#     cur_triplets.append((common_line_count, (row['train_query'], row['SR_file_path'], cur_passage_decoded, row['label'])))

In [153]:
tmp3 = prepare_code_triplets(code_df, code_reranker, mode='parse_functions', cache_file=None, overwrite=False)

Preparing code triplets with mode parse_functions for 6631 rows.


  0%|          | 0/6631 [00:00<?, ?it/s]


In [168]:
tmp3.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12 entries, 0 to 11
Data columns (total 4 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   query      12 non-null     object
 1   file_path  12 non-null     object
 2   passage    12 non-null     object
 3   label      12 non-null     int64 
dtypes: int64(1), object(3)
memory usage: 512.0+ bytes


In [170]:
tmp4 = prepare_code_triplets(code_df, code_reranker, mode='parse_functions', cache_file=None, overwrite=False)

Preparing code triplets with mode parse_functions for 6631 rows.


  5%|▍         | 324/6631 [00:26<08:27, 12.42it/s]


KeyboardInterrupt: 

In [166]:
tmp4.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 22 entries, 0 to 21
Data columns (total 4 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   query      22 non-null     object
 1   file_path  22 non-null     object
 2   passage    22 non-null     object
 3   label      22 non-null     int64 
dtypes: int64(1), object(3)
memory usage: 832.0+ bytes


In [156]:
code_df.head(1)

Unnamed: 0,train_commit_id,train_query,train_original_message,SR_file_path,SR_commit_id,SR_file_content,SR_diff,label
0,76e569992b259b9e636ee68dcc7719539f4b9bb8,"Malformed data types (`commitDetails`, `intera...","Cleanup profile export/import data types, add ...",packages/react-reconciler/src/ReactFiberSchedu...,cc24d0ea56b0538d1ac61dc09faedd70ced5bb47,"/**\n * Copyright (c) Facebook, Inc. and its a...","@@ -569,8 +569,6 @@ function resolveLocksOnRoo...",0


In [144]:
combined_df.commit_id.nunique()

11609

In [157]:
code_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6631 entries, 0 to 6630
Data columns (total 8 columns):
 #   Column                  Non-Null Count  Dtype 
---  ------                  --------------  ----- 
 0   train_commit_id         6631 non-null   object
 1   train_query             6631 non-null   object
 2   train_original_message  6631 non-null   object
 3   SR_file_path            6631 non-null   object
 4   SR_commit_id            6631 non-null   object
 5   SR_file_content         6631 non-null   object
 6   SR_diff                 6446 non-null   object
 7   label                   6631 non-null   int64 
dtypes: int64(1), object(7)
memory usage: 414.6+ KB
