In [2]:
import sys
sys.path.append('../src')

In [3]:
import json
import os
import numpy as np
import pandas as pd
from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import IndexReader
from sklearn.metrics import average_precision_score, ndcg_score

from utils import count_commits, get_combined_df, tokenize, reverse_tokenize
from collections import defaultdict

In [28]:
class SearchResult:
    def __init__(self, commit_id, file_path, score, commit_date, commit_msg):
        self.commit_id = commit_id
        self.file_path = file_path
        self.score = score
        self.commit_date = commit_date
        self.commit_msg = commit_msg


    def __repr__(self):
        class_name = self.__class__.__name__
        # return f"{self.file_path} {self.score:.5f} {self.commit_date}"
        # return f"{class_name}(score: {self.score:.5f}, file_path: {self.file_path!r}, commit_id: {self.commit_id!r}, commit_date: {self.commit_date})"
        return f"{class_name}(score={self.score:.5f}, file_path={self.file_path!r}, commit_id={self.commit_id!r}, commit_date={self.commit_date})"

    def is_actual_modified(self, actual_modified_files):
        return self.file_path in actual_modified_files

    @staticmethod
    def print_results(query, search_results, show_only_actual_modified=False):
        actual_modified_files = query['actual_files_modified']
        for i, result in enumerate(search_results):
            if show_only_actual_modified and not result.is_actual_modified(actual_modified_files):
                continue
            print(f"{i+1:2} {result}")

In [27]:
class AggregatedSearchResult:
    def __init__(self, file_path, aggregated_score, contributing_results):
        self.file_path = file_path
        self.score = aggregated_score
        self.contributing_results = contributing_results

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}(file_path={self.file_path!r}, score={self.score}, " \
               f"contributing_results={self.contributing_results})"

In [97]:
class BM25Search:
    def __init__(self, index_path):
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Index at {index_path} does not exist!")
        self.searcher = LuceneSearcher(index_path)
        print(f"Loaded index at {index_path}")
        print(f'Index Stats: {IndexReader(index_path).stats()}')
        # self.ranking_depth = ranking_depth

    def search(self, query, query_date, ranking_depth):
        # TODO maybe change this to mean returning reranking_depths total results instead of being pruned by the query date
        hits = self.searcher.search(tokenize(query), ranking_depth)
        unix_date = query_date
        filtered_hits = [
            SearchResult(hit.docid, json.loads(hit.raw)['file_path'], hit.score, int(json.loads(hit.raw)["commit_date"]), reverse_tokenize(json.loads(hit.raw)['contents']))
            for hit in hits if int(json.loads(hit.raw)["commit_date"]) < unix_date
        ]
        return filtered_hits

    def search_full(self, query, query_date, ranking_depth):
        filtered_hits = []
        step_size = ranking_depth  # Initial search window
        total_hits_retrieved = 0

        while len(filtered_hits) < ranking_depth and step_size > 0:
            current_hits = self.searcher.search(tokenize(query), total_hits_retrieved + step_size)
            if not current_hits:
                break  # No more results to retrieve

            # Filter hits by query date
            for hit in current_hits:
                if int(json.loads(hit.raw)["commit_date"]) < query_date:
                    filtered_hits.append(
                        SearchResult(hit.docid, json.loads(hit.raw)['file_path'], hit.score,
                                     int(json.loads(hit.raw)["commit_date"]),
                                     reverse_tokenize(json.loads(hit.raw)['contents']))
                    )
                if len(filtered_hits) == ranking_depth:
                    break  # We have enough results

            total_hits_retrieved += step_size
            step_size = ranking_depth - len(filtered_hits)  # Decrease step size to only get as many as needed

        return filtered_hits[:ranking_depth]  # Return up to ranking_depth results

    def aggregate_file_scores(self, search_results, aggregation_method='sump'):
        # TODO maybe have different aggregation methods
        file_to_results = defaultdict(list)
        for result in search_results:
            file_to_results[result.file_path].append(result)

        aggregated_results = []
        for file_path, results in file_to_results.items():
            # aggregated_score = sum(result.score for result in results)
            if aggregation_method == 'sump':
                aggregated_score = sum(result.score for result in results)
            elif aggregation_method == 'maxp':
                aggregated_score = max(result.score for result in results)
            # elif aggregation_method == 'firstp':
            #     aggregated_score = results[0].score
            elif aggregation_method == 'avgp':
                aggregated_score = np.mean([result.score for result in results])
            else:
                raise ValueError(f"Unknown aggregation method {aggregation_method}")

            aggregated_results.append(AggregatedSearchResult(file_path, aggregated_score, results))

        aggregated_results.sort(key=lambda result: result.score, reverse=True)
        return aggregated_results

    def pipeline(self, query, query_date, ranking_depth, aggregation_method='sump'):
        search_results = self.search(query, query_date, ranking_depth)
        aggregated_results = self.aggregate_file_scores(search_results, aggregation_method)
        return aggregated_results

In [201]:
class SearchEvaluator:
    def __init__(self, metrics):
        self.metrics = metrics

    @staticmethod
    def precision_at_k(relevant, k):
        return sum(relevant[:k]) / k

    @staticmethod
    def mean_reciprocal_rank(relevant):
        for idx, value in enumerate(relevant):
            if value == 1:
                return 1 / (idx + 1)
        return 0

    @staticmethod
    def calculate_average_precision(relevant):
        pred_rel = [1] * len(relevant)
        relevant_documents_count = 0
        cumulative_precision = 0.0

        # We iterate through the predicted relevance scores
        for i in range(len(pred_rel)):
            # Check if the prediction at this rank is correct (i.e., if it is a relevant document)
            if pred_rel[i] == 1 and relevant[i] == 1:
                relevant_documents_count += 1
                precision_at_i = relevant_documents_count / (i + 1)
                cumulative_precision += precision_at_i

        # The average precision is the cumulative precision divided by the number of relevant documents
        average_precision = cumulative_precision / sum(relevant) if sum(relevant) > 0 else 0
        return average_precision

    # @staticmethod
    # def calculate_recall(relevant, total_modified_files, k):
    #   # Does not work for commit based approach as it can have multiple mentions of the same file across commits leading to a higher than 1 recall
    #     print(total_modified_files)
    #     print(relevant)
    #     return sum(relevant[:k]) / total_modified_files

    @staticmethod
    def calculate_recall(retrieved_files, actual_modified_files, relevant, k):

        return len(list({
                        file
                        for idx, file in enumerate(retrieved_files)
                        if relevant[idx] == 1
                    })[:k]) / len(actual_modified_files) if len(actual_modified_files) > 0 else 0


    def evaluate(self, search_results, actual_modified_files):
        retrieved_files = [result.file_path for result in search_results]
        relevant = [1 if file in actual_modified_files else 0 for file in retrieved_files]


        evaluations = {}
        for metric in self.metrics:
            if metric == 'MAP':
                evaluations[metric] = self.calculate_average_precision(relevant)
            elif metric == 'MRR':
                evaluations[metric] = self.mean_reciprocal_rank(relevant)
            elif metric.startswith('P@'):
                k = int(metric.split('@')[1])
                evaluations[metric] = self.precision_at_k(relevant, k)
            elif metric.startswith('Recall@'):
                k = int(metric.split('@')[1])

                # evaluations[metric] = len(list({
                #         file
                #         for idx, file in enumerate(retrieved_files)
                #         if relevant[idx] == 1
                #     })[:k]) / len(actual_modified_files)
                # evaluations[metric] = self.calculate_recall(relevant, len(actual_modified_files), k) # DOES NOT WORK FOR COMMIT-BASED APPROACH

                evaluations[metric] = self.calculate_recall(retrieved_files, actual_modified_files, relevant, k)


        return {k: round(v, 4) for k, v in evaluations.items()}

    # def evaluate_file_based(self, search_results, actual_modified_files, aggregation_strategy='sump'):
    #     file_relevance = defaultdict(list)

    #     # Aggregate relevance scores for each file across all commits
    #     for result in search_results:
    #         if result.file_path in actual_modified_files:
    #             # file_relevance[result.file_path] += 1
    #             file_relevance[result.file_path].append(result.score)

    #     # Normalize relevance scores based on occurrences in actual modified files
    #     # max_relevance = max(file_relevance.values(), default=1)
    #     # normalized_relevance = {file: relevance / max_relevance for file, relevance in file_relevance.items()}
    #     # sorted_normalized_relevance = sorted(normalized_relevance.items(), key=lambda item: item[1], reverse=True)
    #     print(file_relevance)
    #     if aggregation_strategy == 'sump':
    #         aggregated_scores = {file: sum(relevance) for file, relevance in file_relevance.items()}
    #     elif aggregation_strategy == 'firstp':
    #         aggregated_scores = {file: relevance[0] for file, relevance in file_relevance.items()}
    #     elif aggregation_strategy == 'avgp':
    #         aggregated_scores = {file: np.mean(relevance) for file, relevance in file_relevance.items()}
    #     else:
    #         raise ValueError(f"Unknown aggregation strategy {aggregation_strategy}")

    #     sorted_aggregated_scores = sorted(aggregated_scores.items(), key=lambda item: item[1], reverse=True)

    #     print(sorted_aggregated_scores)

    #     evaluations = {}
    #     for metric in self.metrics:
    #         if metric.startswith('P@'):
    #             # Compute precision at k for files, not individual commit mentions
    #             k = int(metric.split('@')[1])
    #             # top_k_files = sorted(normalized_relevance.items(), key=lambda item: item[1], reverse=True)[:k]
    #             top_k_files = sorted_aggregated_scores[:k]
    #             precision_at_k = sum(1 for file, relevance in top_k_files if file in actual_modified_files) / k
    #             evaluations[metric] = precision_at_k
    #         elif metric.startswith('Recall@'):
    #             k = int(metric.split('@')[1])
    #             # top_k_files = sorted(normalized_relevance.items(), key=lambda item: item[1], reverse=True)[:k]
    #             top_k_files = sorted_aggregated_scores[:k]
    #             recall_at_k = sum(1 for file, relevance in top_k_files if file in actual_modified_files) / len(actual_modified_files)
    #             evaluations[metric] = recall_at_k
    #         elif metric == 'MAP':
    #             # Compute average precision for files, not individual commit mentions
    #             average_precision = 0
    #             num_relevant_files = 0
    #             for idx, (file, relevance) in enumerate(sorted_aggregated_scores):
    #                 if file in actual_modified_files:
    #                     num_relevant_files += 1
    #                     average_precision += num_relevant_files / (idx + 1)
    #             average_precision /= len(actual_modified_files)
    #             evaluations[metric] = average_precision
    #         elif metric == 'MRR':
    #             # Compute mean reciprocal rank for files, not individual commit mentions
    #             reciprocal_rank = 0
    #             for idx, (file, relevance) in enumerate(sorted_aggregated_scores):
    #                 if file in actual_modified_files:
    #                     reciprocal_rank = 1 / (idx + 1)
    #                     break
    #             evaluations[metric] = reciprocal_rank

    #     return {k: round(v, 4) for k, v in evaluations.items()}

In [185]:
class ModelEvaluator:
    def __init__(self, model, eval_model, combined_df, seed=42):
        self.model = model
        self.eval_model = eval_model
        self.combined_df = combined_df
        self.seed = seed

    def sample_commits(self, n):
        if self.combined_df.commit_id.nunique() < n:
            raise ValueError(f'Not enough commits to sample. Required: {n}, available: {self.combined_df.commit_id.nunique()}')
        return self.combined_df.drop_duplicates(subset='commit_id').sample(n=n, replace=False, random_state=self.seed)

    def evaluate_sampling(self, n=100, k=1000, output_dir='.', skip_existing=False, evaluation_strategy='commit', aggregation_strategy='sump'):
        model_name = self.model.__class__.__name__
        output_file = f"{output_dir}/{model_name}_metrics.txt"

        if skip_existing and os.path.exists(output_file):
            print(f'Output file {output_file} already exists, skipping...')
            return

        sampled_commits = self.sample_commits(n)

        results = []
        for _, row in sampled_commits.iterrows():
            # search_results = self.model.search(row['commit_message'], row['commit_date'], ranking_depth=k)
            search_results = self.model.pipeline(row['commit_message'], row['commit_date'], ranking_depth=k, aggregation_method=aggregation_strategy)
            # if evaluation_strategy == 'commit':
            evaluation = self.eval_model.evaluate(search_results,
                                                       self.combined_df[self.combined_df['commit_id'] == row['commit_id']]['file_path'].tolist())
            # elif evaluation_strategy == 'file':
            #     evaluation = self.eval_model.evaluate_file_based(search_results,
            #                                                       self.combined_df[self.combined_df['commit_id'] == row['commit_id']]['file_path'].tolist(), aggregation_strategy=aggregation_strategy)
            # else:
            #     raise ValueError(f'Invalid evaluation strategy: {evaluation_strategy}')
            results.append(evaluation)

        avg_scores = {metric: round(np.mean([result[metric] for result in results]), 4) for metric in results[0]}

        os.makedirs(output_dir, exist_ok=True)  # Create the output directory if it doesn't exist
        with open(output_file, "w") as file:
            file.write(f"Model Name: {model_name}\n")
            file.write(f"Sample Size: {n}\n")
            file.write("Evaluation Metrics:\n")
            for key, value in avg_scores.items():
                file.write(f"{key}: {value}\n")

        return avg_scores

In [186]:
def tmp():
    def calculate_recall(relevant, total_modified_files, k):
        return sum(relevant[:k]) / total_modified_files

    def calculate_average_precision(relevant):
        pred_rel = [1] * len(relevant)
        relevant_documents_count = 0
        cumulative_precision = 0.0

        # We iterate through the predicted relevance scores
        for i in range(len(pred_rel)):
            # Check if the prediction at this rank is correct (i.e., if it is a relevant document)
            if pred_rel[i] == 1 and relevant[i] == 1:
                relevant_documents_count += 1
                precision_at_i = relevant_documents_count / (i + 1)
                cumulative_precision += precision_at_i

        # The average precision is the cumulative precision divided by the number of relevant documents
        average_precision = cumulative_precision / sum(relevant)
        return average_precision

    def mean_reciprocal_rank(relevant):
        for idx, value in enumerate(relevant):
            if value == 1:
                return 1 / (idx + 1)
        return 0

    def precision_at_k(relevant, k):
        return sum(relevant[:k]) / k

    rel = [0,1,0,1,0,1,0,0,1]
    k = 3
    # print(calculate_recall(rel, 4, k))
    print(f'MAP: {calculate_average_precision(rel)}')
    print(f'MRR: {mean_reciprocal_rank(rel)}')

    for k in range(1, 10):
        print(f'P@{k}: {precision_at_k(rel, k)}')
        print(f'Recall@{k}: {calculate_recall(rel, sum(rel), k)}')

tmp()

MAP: 0.4861111111111111
MRR: 0.5
P@1: 0.0
Recall@1: 0.0
P@2: 0.5
Recall@2: 0.25
P@3: 0.3333333333333333
Recall@3: 0.25
P@4: 0.5
Recall@4: 0.5
P@5: 0.4
Recall@5: 0.5
P@6: 0.5
Recall@6: 0.75
P@7: 0.42857142857142855
Recall@7: 0.75
P@8: 0.375
Recall@8: 0.75
P@9: 0.4444444444444444
Recall@9: 1.0


In [11]:
index_path = '../smalldata/fbr/index_commit_tokenized'
repo_path = '../smalldata/fbr/'
K=1000

Loaded index at ../smalldata/fbr/index_commit_tokenized
Index Stats: {'total_terms': 8061856, 'documents': 69835, 'non_empty_documents': 69835, 'unique_terms': 14589}


In [13]:
metrics = ['MAP', 'P@10', 'P@100', 'P@1000', 'MRR', 'Recall@1000']

In [15]:
combined_df = get_combined_df(repo_path)

In [202]:
bm25_searcher = BM25Search(index_path)
evaluator = SearchEvaluator(metrics)
bm25_evaluator = ModelEvaluator(bm25_searcher, evaluator, combined_df)

Loaded index at ../smalldata/fbr/index_commit_tokenized
Index Stats: {'total_terms': 8061856, 'documents': 69835, 'non_empty_documents': 69835, 'unique_terms': 14589}


In [56]:
# bm25_evaluator.evaluate_sampling(n=100, k=K, output_dir='../tmp/', evaluation_strategy='file')

In [18]:
# randomly sample a commit from the combined_df
random_commit = combined_df.sample(1).iloc[0]
random_commit

owner                                                             facebook
repo_name                                                            react
commit_date                                                     1541787469
commit_id                         1034e26fe5e42ba07492a736da7bdf5bf2108bc6
commit_message                                      Fix typos (#14124)\n\n
file_path                packages/react-dom/src/server/escapeTextForBro...
cur_file_content         /**\n * Copyright (c) Facebook, Inc. and its a...
previous_commit_id                5618da49d8cb9cdb6c623446bbd2c504ee6c0422
previous_file_path                                                    <NA>
previous_file_content    /**\n * Copyright (c) Facebook, Inc. and its a...
diff                     @@ -56,7 +56,7 @@ function escapeHtml(string) ...
status                                                            modified
is_merge_request                                                     False
file_extension           

In [203]:
# get search results for the random commit
search_results = bm25_searcher.search(random_commit['commit_message'], random_commit['commit_date'], ranking_depth=K)
search_results[:20]

[SearchResult(score=10.71162, file_path='packages/react-reconciler/src/ReactChildFiber.js', commit_id='73527237e6eb311de102791bb3fd2922cc28368a', commit_date=1508494517),
 SearchResult(score=10.71162, file_path='fixtures/dom/src/components/fixtures/text-inputs/index.js', commit_id='ccb2f82a833710030136909f4601dec00c2e2ddc', commit_date=1506507201),
 SearchResult(score=10.71162, file_path='scripts/bench/server.js', commit_id='ccb2f82a833710030136909f4601dec00c2e2ddc', commit_date=1506507201),
 SearchResult(score=10.71161, file_path='scripts/rollup/build.js', commit_id='ccb2f82a833710030136909f4601dec00c2e2ddc', commit_date=1506507201),
 SearchResult(score=10.71161, file_path='scripts/rollup/modules.js', commit_id='ccb2f82a833710030136909f4601dec00c2e2ddc', commit_date=1506507201),
 SearchResult(score=10.71161, file_path='src/renderers/shared/fiber/ReactFiberCompleteWork.js', commit_id='ccb2f82a833710030136909f4601dec00c2e2ddc', commit_date=1506507201),
 SearchResult(score=10.71161, file

In [204]:
aggregated_results = bm25_searcher.aggregate_file_scores(search_results)
aggregated_results[:20]

[AggregatedSearchResult(file_path='src/browser/ui/ReactDOMComponent.js', score=41.675281047821045, contributing_results=[SearchResult(score=9.41890, file_path='src/browser/ui/ReactDOMComponent.js', commit_id='78ec2501cf9d57fa38d3166ceae0e9b7811071c2', commit_date=1415408329), SearchResult(score=8.74300, file_path='src/browser/ui/ReactDOMComponent.js', commit_id='042c6c794c12b9f8d7cc7470736db896a5fc12a3', commit_date=1415380027), SearchResult(score=6.80090, file_path='src/browser/ui/ReactDOMComponent.js', commit_id='0c1eca7dfcf0e4375f7cfe1164823547fee1ad01', commit_date=1426868489), SearchResult(score=6.80090, file_path='src/browser/ui/ReactDOMComponent.js', commit_id='70f16cc936c8e41881b65d58b165fbbf93b5618f', commit_date=1426870969), SearchResult(score=5.56560, file_path='src/browser/ui/ReactDOMComponent.js', commit_id='de1dacdb2887349d09b839bb79d008ae0bf057b0', commit_date=1414025998), SearchResult(score=4.34599, file_path='src/browser/ui/ReactDOMComponent.js', commit_id='af819d122ef

In [205]:
len(search_results), len(aggregated_results)

(384, 279)

In [207]:
# evaluate the search results
evaluation = evaluator.evaluate(search_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
evaluation

{'MAP': 0.0178,
 'P@10': 0.0,
 'P@100': 0.0,
 'P@1000': 0.006,
 'MRR': 0.0087,
 'Recall@1000': 0.3333}

In [206]:
file_based_evaluation = evaluator.evaluate(aggregated_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
file_based_evaluation

{'MAP': 0.1667,
 'P@10': 0.1,
 'P@100': 0.01,
 'P@1000': 0.001,
 'MRR': 0.1667,
 'Recall@1000': 0.3333}

Full Search Experiments

In [124]:
# search_full = bm25_searcher.search_full(random_commit['commit_message'], random_commit['commit_date'], ranking_depth=K)
# search_full[:20]

In [123]:
# aggregated_results_full = bm25_searcher.aggregate_file_scores(search_full)
# aggregated_results_full[:20]

In [122]:
# len(search_full), len(aggregated_results_full)

In [127]:
# evaluation_full = evaluator.evaluate(search_full, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
# evaluation_full

In [126]:
# file_based_evaluation_full = evaluator.evaluate(aggregated_results_full, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
# file_based_evaluation_full

In [160]:
len(combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())

3

In [208]:
# test different aggregation strategies in bm25 search
agggreagtion_strategies = ['sump', 'maxp', 'avgp']
for strategy in agggreagtion_strategies:
    print(f"Aggregation Strategy: {strategy}")
    aggregated_results = bm25_searcher.aggregate_file_scores(search_results, aggregation_method=strategy)
    evaluation = evaluator.evaluate(aggregated_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
    print(evaluation)

Aggregation Strategy: sump
{'MAP': 0.1667, 'P@10': 0.1, 'P@100': 0.01, 'P@1000': 0.001, 'MRR': 0.1667, 'Recall@1000': 0.3333}
Aggregation Strategy: maxp
{'MAP': 0.0123, 'P@10': 0.0, 'P@100': 0.01, 'P@1000': 0.001, 'MRR': 0.0123, 'Recall@1000': 0.3333}
Aggregation Strategy: avgp
{'MAP': 0.0112, 'P@10': 0.0, 'P@100': 0.01, 'P@1000': 0.001, 'MRR': 0.0112, 'Recall@1000': 0.3333}


In [209]:
bm25_evaluator.evaluate_sampling(n=100, k=K, output_dir='../tmp/')

{'MAP': 0.2037,
 'P@10': 0.074,
 'P@100': 0.0202,
 'P@1000': 0.0027,
 'MRR': 0.2746,
 'Recall@1000': 0.6351}

In [64]:
for strategy in agggreagtion_strategies:
    print(f"Aggregation Strategy: {strategy}")
    results = bm25_evaluator.evaluate_sampling(n=100, k=K, output_dir='../tmp/', evaluation_strategy='file', aggregation_strategy=strategy)
    print(results)

Aggregation Strategy: sump
{'MAP': 0.0272, 'P@10': 0.074, 'P@100': 0.0202, 'P@1000': 0.0027, 'MRR': 0.2746, 'Recall@1000': 0.6351}
Aggregation Strategy: maxp
{'MAP': 0.0272, 'P@10': 0.06, 'P@100': 0.0181, 'P@1000': 0.0027, 'MRR': 0.2688, 'Recall@1000': 0.6351}
Aggregation Strategy: avgp
{'MAP': 0.0272, 'P@10': 0.048, 'P@100': 0.0149, 'P@1000': 0.0027, 'MRR': 0.1968, 'Recall@1000': 0.6351}


SUMP seems to be the best aggregation strategy.

In [50]:
# file_based_evaluation = evaluator.evaluate_file_based(search_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
# file_based_evaluation

In [51]:
# # iterate over the list of aggregation strategies and evaluate each one

# aggregation_strategies = ['sump', 'maxp', 'firstp', 'avgp']
# for strategy in aggregation_strategies:
#     file_based_evaluation = evaluator.evaluate_file_based(search_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist(), aggregation_strategy=strategy)
#     print(f"{strategy}: {file_based_evaluation}")

In [52]:
# # do the same for the model evaluator
# for strategy in aggregation_strategies:
#     model_evaluation = bm25_evaluator.evaluate_sampling(n=1000, k=K, output_dir='../tmp/', evaluation_strategy='file', aggregation_strategy=strategy)
#     print(f"{strategy}: {model_evaluation}")

In [127]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from collections import defaultdict

class BERTReRanker:
    """
    A class for performing reranking with a BERT-based model.
    """
    def __init__(self, model_name, psgLen=128, psgStride=64, psgCnt=None, scoreAggregation='maxp', batchSize=8):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
        self.model.eval()  # Set model to evaluation mode

        # Passage handling parameters
        self.psgLen = psgLen
        self.psgStride = psgStride
        self.psgCnt = psgCnt
        self.scoreAggregation = scoreAggregation

        if self.scoreAggregation == 'firstp':
            self.psgCnt = 1

        self.batchSize = batchSize

    def rerank(self, query, search_results):
        """
        Rerank the search results using the BERT model.

        query: The query string.
        search_results: A list of SearchResult objects.
        """
        reranked_results = []
        # todo - add batching

        # Process each SearchResult to create input for BERT
        for result in search_results:
            passages = self._split_into_passages(query, result.commit_msg)

            # Score each passage with BERT
            passage_scores = [self._score_passage(query, passage) for passage in passages]

            # Aggregate passage scores to get a single document score
            doc_score = self._aggregate_scores(passage_scores)

            # Create a new SearchResult with the updated score
            reranked_results.append((doc_score, result))

        # Sort reranked results by the new score
        reranked_results.sort(key=lambda x: x[0], reverse=True)
        return [result for _, result in reranked_results]

    def _split_into_passages(self, query, commit_msg):
        # Tokenize the query and commit message
        tokens = self.tokenizer.tokenize(query) + self.tokenizer.tokenize(commit_msg)

        # Split the tokens into passages
        passages = []
        for i in range(0, len(tokens), self.psgStride):
            passage = tokens[i:i+self.psgLen]
            passages.append(self.tokenizer.convert_tokens_to_string(passage))
            if self.psgCnt and len(passages) >= self.psgCnt:
                break
        return passages

    def _score_passage(self, query, passage):
        # Encode query and passage for BERT
        inputs = self.tokenizer.encode_plus(
            query,
            passage,
            add_special_tokens=True,
            max_length=self.psgLen,
            truncation="only_second",
            return_tensors="pt"
        )

        # Score the (query, passage) pair with BERT
        with torch.no_grad():
            outputs = self.model(**inputs)
            score = outputs.logits.squeeze().item()
        return score

    def _aggregate_scores(self, passage_scores):
        # Aggregate passage scores based on the specified strategy
        if self.scoreAggregation == 'firstp':
            return passage_scores[0]
        elif self.scoreAggregation == 'maxp':
            return max(passage_scores)
        elif self.scoreAggregation == 'avgp':
            return sum(passage_scores) / len(passage_scores)
        elif self.scoreAggregation == 'sump':
            return sum(passage_scores)
        else:
            raise ValueError(f"Invalid score aggregation method: {self.scoreAggregation}")

In [128]:
# Assuming bm25_searcher is an instance of BM25Search
query = "Fix login page error"
query_date = 1699261235
bm25_results = bm25_search.search(query, query_date, K)

# Now rerank those results with BERT
bert_reranker = BERTReRanker(model_name="bert-base-uncased")
reranked_results = bert_reranker.rerank(query, bm25_results)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Token indices sequence length is longer than the specified maximum sequence length for this model (791 > 512). Running this sequence through the model will result in indexing errors


In [130]:
# Print the top 10 results
reranked_results[:10]

[SearchResult(score: 2.69137, file_path: 'packages/react-reconciler/src/ReactFiberBeginWork.new.js', commit_id: '1faf9e3dd5d6492f3607d5c721055819e4106bc6', commit_date: 1601495853),
 SearchResult(score: 2.69136, file_path: 'packages/react-reconciler/src/ReactFiberBeginWork.old.js', commit_id: '1faf9e3dd5d6492f3607d5c721055819e4106bc6', commit_date: 1601495853),
 SearchResult(score: 2.69136, file_path: 'packages/react-reconciler/src/ReactFiberSuspenseComponent.new.js', commit_id: '1faf9e3dd5d6492f3607d5c721055819e4106bc6', commit_date: 1601495853),
 SearchResult(score: 2.69136, file_path: 'packages/react-reconciler/src/ReactFiberSuspenseComponent.old.js', commit_id: '1faf9e3dd5d6492f3607d5c721055819e4106bc6', commit_date: 1601495853),
 SearchResult(score: 2.69136, file_path: 'packages/react-reconciler/src/__tests__/ReactCPUSuspense-test.js', commit_id: '1faf9e3dd5d6492f3607d5c721055819e4106bc6', commit_date: 1601495853),
 SearchResult(score: 3.26617, file_path: 'packages/react-dom-bindi

In [131]:
# evaluate the reranked results
evaluation = evaluator.evaluate(reranked_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
evaluation

{'MAP': 0.016,
 'P@10': 0.0,
 'P@100': 0.0,
 'P@1000': 0.016,
 'MRR': 0.0031,
 'Recall@1000': 0.3913}

In [137]:
evaluation = evaluator.evaluate_file_based(reranked_results, combined_df[combined_df['commit_id'] == random_commit['commit_id']]['file_path'].tolist())
evaluation

defaultdict(<class 'list'>, {'packages/react-reconciler/src/ReactFiberCompleteWork.js': [2.718986988067627, 3.8634390830993652, 3.863529920578003, 4.273975372314453], 'packages/shared/ReactSymbols.js': [3.5164780616760254], 'packages/shared/ReactTypes.js': [3.6565749645233154, 3.8633739948272705, 3.863521099090576], 'scripts/rollup/bundles.js': [4.612883567810059, 5.564578056335449, 2.7962939739227295], 'packages/react-dom/src/shared/assertValidProps.js': [5.564676284790039], 'packages/react-dom/src/events/DOMEventResponderSystem.js': [3.8635449409484863], 'packages/react-dom/src/events/__tests__/DOMEventResponderSystem-test.internal.js': [3.86354398727417], 'packages/react-events/src/FocusScope.js': [3.863539934158325], 'packages/shared/getComponentName.js': [3.8635189533233643]})
[('packages/react-reconciler/src/ReactFiberCompleteWork.js', 14.719931364059448), ('scripts/rollup/bundles.js', 12.973755598068237), ('packages/shared/ReactTypes.js', 11.383470058441162), ('packages/react-do

{'MAP': 0.3913,
 'P@10': 0.9,
 'P@100': 0.09,
 'P@1000': 0.009,
 'MRR': 1.0,
 'Recall@1000': 0.3913}

In [None]:
{'MAP': 0.0546,
 'P@10': 0.0,
 'P@100': 0.05,
 'P@1000': 0.019,
 'MRR': 0.0119,
 'Recall@1000': 0.3478}

{'MAP': 0.3478,
 'P@10': 0.8,
 'P@100': 0.08,
 'P@1000': 0.008,
 'MRR': 1.0,
 'Recall@1000': 0.3478}