In [1]:
import math
from collections import Counter
from typing import List, Dict, Union, Tuple
import numpy as np
from tqdm import tqdm
import json
import os
import asyncio

In [6]:


class BM25Retriever:
    """Okapi BM25 ranker over an in-memory corpus of whitespace-tokenised documents.

    Documents and queries are tokenised with ``str.split()``; no lower-casing,
    stemming or stop-word removal is applied, so matching is exact and
    case-sensitive.

    Args:
        k1: Term-frequency saturation parameter of BM25.
        b: Document-length normalisation parameter of BM25.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self.corpus = []        # raw document strings, one per non-blank input line
        self.corpus_size = 0    # number of indexed documents
        self.avgdl = 0          # average document length in tokens
        self.doc_freqs = []     # per-document Counter: token -> occurrences
        self.idf = {}           # token -> inverse document frequency weight
        self.doc_len = []       # per-document token count

    def load_documents(self, file_path: str):
        """Read one document per non-blank line of ``file_path`` and index them.

        Any previously loaded corpus and its statistics are replaced.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            self.corpus = [line.strip() for line in file if line.strip()]
        self.corpus_size = len(self.corpus)
        print(f"Loaded {self.corpus_size} documents.")
        self.initialize()

    def initialize(self):
        """(Re)build every BM25 statistic from ``self.corpus``.

        Safe to call repeatedly and with an empty corpus.
        """
        print("Initializing BM25 model...")
        self._build_index(tqdm(self.corpus, desc="Processing documents"))
        print("BM25 model initialized successfully.")

    def _build_index(self, documents):
        """Reset the index and recompute lengths, term counts, ``avgdl`` and ``idf``."""
        # Start from a clean slate so re-indexing never accumulates stale
        # entries from a previous corpus.
        self.doc_len = []
        self.doc_freqs = []
        containing_docs = Counter()  # token -> number of documents containing it

        for document in documents:
            words = document.split()
            self.doc_len.append(len(words))
            word_freq = Counter(words)
            self.doc_freqs.append(word_freq)
            # Iterating a Counter yields each distinct token once, so this
            # counts documents rather than occurrences.
            containing_docs.update(word_freq.keys())

        self.corpus_size = len(self.doc_len)
        # An empty corpus has no meaningful average length; keep 0 instead of
        # dividing by zero.
        self.avgdl = sum(self.doc_len) / self.corpus_size if self.corpus_size else 0

        # NOTE: this IDF variant is negative for tokens present in more than
        # half of the documents; it is kept as-is so scores do not change.
        self.idf = {
            word: math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            for word, freq in containing_docs.items()
        }

    def get_score(self, query: str, index: int) -> float:
        """Return the BM25 score of document ``index`` for ``query``.

        A token repeated in the query contributes once per repetition.
        """
        return self._score_terms(query.split(), index)

    def _score_terms(self, terms: List[str], index: int) -> float:
        """Score document ``index`` against already-tokenised query ``terms``."""
        score = 0.0
        doc_freq = self.doc_freqs[index]
        # The length normalisation depends only on the document, not the term.
        length_norm = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl) if self.avgdl else 0.0

        for word in terms:
            if word not in doc_freq:
                continue
            freq = doc_freq[word]
            numerator = self.idf.get(word, 0) * freq * (self.k1 + 1)
            denominator = freq + length_norm
            score += numerator / denominator

        return score

    def _rank_indices(self, query: str, top_k: int) -> List[Tuple[int, float]]:
        """Return up to ``top_k`` ``(document index, score)`` pairs, best first.

        Ties keep corpus order because the sort is stable.
        """
        terms = query.split()  # tokenise once instead of once per document
        scored = [(i, self._score_terms(terms, i)) for i in range(self.corpus_size)]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:max(top_k, 0)]

    def retrieve_documents(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return the ``top_k`` best-scoring ``(document text, score)`` pairs."""
        return [(self.corpus[i], score) for i, score in self._rank_indices(query, top_k)]

    def _precision_recall(self, query: str, relevant: List[int], k: int) -> Tuple[float, float]:
        """Compute precision@k and recall@k of one query against ``relevant`` indices."""
        # Work with corpus indices directly: looking texts up again with
        # ``list.index`` would merge duplicate documents into the first one.
        retrieved_indices = {i for i, _ in self._rank_indices(query, k)}
        relevant_indices = set(relevant)
        true_positives = len(retrieved_indices & relevant_indices)
        precision = true_positives / k
        # A query with no relevant documents contributes a recall of 0.0
        # rather than raising ZeroDivisionError.
        recall = true_positives / len(relevant_indices) if relevant_indices else 0.0
        return precision, recall

    def evaluate_retrieval(self, queries: List[str], relevant_docs: List[List[int]], 
                           k: int = 5) -> Tuple[float, float]:
        """Return mean precision@k and mean recall@k over ``queries``.

        Args:
            queries: Query strings.
            relevant_docs: For each query, the corpus indices judged relevant.
            k: Cut-off rank; must be positive.

        Returns:
            ``(average precision@k, average recall@k)``; ``(0.0, 0.0)`` when
            there is nothing to evaluate.

        Raises:
            ValueError: If ``k`` is not positive.
        """
        if k <= 0:
            raise ValueError("k must be a positive integer")
        if not queries:
            return 0.0, 0.0

        precisions, recalls = [], []
        for query, relevant in tqdm(zip(queries, relevant_docs), desc="Evaluating queries", total=len(queries)):
            precision, recall = self._precision_recall(query, relevant, k)
            precisions.append(precision)
            recalls.append(recall)

        # ``zip`` stops at the shorter input, so an empty ``relevant_docs``
        # leaves nothing to average.
        if not precisions:
            return 0.0, 0.0
        avg_precision = sum(precisions) / len(precisions)
        avg_recall = sum(recalls) / len(recalls)
        return avg_precision, avg_recall

    async def aretrieve_documents(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Coroutine wrapper around :meth:`retrieve_documents`.

        The scoring itself is synchronous; this only provides an awaitable
        interface and does not yield to the event loop while ranking.
        """
        return self.retrieve_documents(query, top_k=top_k)

    async def process_requirements(self, input_folder: str, output_folder: str):
        """Run one retrieval per ``.txt`` file in ``input_folder``.

        Each file's content is embedded in a fixed query template, and the
        ranked results are written to a same-named file in ``output_folder``
        (created if missing).
        """
        os.makedirs(output_folder, exist_ok=True)

        # ``os.listdir`` order is unspecified; sort for a reproducible run.
        for filename in sorted(os.listdir(input_folder)):
            if not filename.endswith('.txt'):
                continue

            input_file_path = os.path.join(input_folder, filename)
            with open(input_file_path, 'r', encoding='utf-8') as file:
                req_content = file.read().strip()

            query = f"Identify content that are relevant to ensuring compliance or completeness of the following requirement: {req_content}"
            results = await self.aretrieve_documents(query)

            output_file_path = os.path.join(output_folder, filename)
            with open(output_file_path, 'w', encoding='utf-8') as output_file:
                output_file.write(f"Query: {query}\n\nRelevant Documents:\n")
                for doc, score in results:
                    output_file.write(f"Score: {score:.4f}\nDocument: {doc}\n\n")

            print(f"Written content to {output_file_path}")

def load_json_file(file_path: str) -> dict:
    """Decode the UTF-8 JSON file at ``file_path`` and return the parsed value."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

async def amain():
    """Index the article corpus, then run BM25 retrieval for every requirement file.

    Paths are fixed: articles are read from ``input/articles.txt``,
    requirements from ``documents/requirements/FR/modified``, and results are
    written to ``output/search_results/BM25``.
    """
    requirements_dir = "documents/requirements/FR/modified"
    results_dir = "output/search_results/BM25"

    bm25 = BM25Retriever()
    bm25.load_documents('input/articles.txt')
    await bm25.process_requirements(requirements_dir, results_dir)

def main():
    """Run :func:`amain` from either a plain script or an already-running loop.

    In a regular interpreter no event loop is running, so the job is executed
    to completion with ``asyncio.run``. Inside an environment that already
    runs a loop in this thread (e.g. a Jupyter kernel), ``asyncio.run`` cannot
    be used, so the job is scheduled on that loop instead.

    Returns:
        ``None`` after a blocking run, or the scheduled ``asyncio.Task`` when
        an event loop was already running. Keep a reference to the task (and
        ``await`` it if completion matters), since the loop itself only holds
        a weak reference.
    """
    # Probe for a running loop *before* creating the coroutine. Calling
    # asyncio.run(amain()) first and catching RuntimeError leaves an un-awaited
    # coroutine behind and also swallows RuntimeErrors raised by the job itself.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        asyncio.run(amain())
        return None
    return running_loop.create_task(amain())

# Script entry point: start the retrieval job only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


  loop.create_task(amain())


Loaded 2318 documents.
Initializing BM25 model...


Processing documents: 100%|██████████| 2318/2318 [00:00<00:00, 272572.73it/s]


BM25 model initialized successfully.
Written content to output/search_results/BM25\req_16.txt
Written content to output/search_results/BM25\req_17.txt
Written content to output/search_results/BM25\req_18.txt
Written content to output/search_results/BM25\req_22.txt
Written content to output/search_results/BM25\req_25.txt
Written content to output/search_results/BM25\req_3.txt
Written content to output/search_results/BM25\req_40.txt
Written content to output/search_results/BM25\req_5.txt
Written content to output/search_results/BM25\req_7.txt
Written content to output/search_results/BM25\req_8.txt
