In [1]:


import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
import os
import json
from typing import List, Tuple
import asyncio


In [None]:

class BERTRetriever:
    def __init__(self, model_name: str = 'bert-base-uncased', max_length: int = 512):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name).to(self.device)
        self.max_length = max_length
        
        print(f"BERT model '{model_name}' and tokenizer loaded successfully.")
        
        self.documents = []
        self.document_embeddings = None

    def load_documents(self, file_path: str):
        with open(file_path, 'r', encoding='utf-8') as file:
            self.documents = [doc.strip() for doc in file.readlines() if doc.strip()]
        print(f"Loaded {len(self.documents)} documents.")

    def encode_text(self, text: str) -> np.ndarray:
        inputs = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, 
                                padding='max_length', truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :].cpu().numpy()

    def encode_documents(self, batch_size: int = 32):
        embeddings = []
        for i in tqdm(range(0, len(self.documents), batch_size), desc="Encoding documents"):
            batch = self.documents[i:i+batch_size]
            batch_embeddings = self.encode_text(" ".join(batch))
            embeddings.extend(batch_embeddings)
        self.document_embeddings = np.vstack(embeddings)
        print(f"Encoded {len(self.documents)} documents.")

    def retrieve_documents(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        query_embedding = self.encode_text(query)
        similarities = cosine_similarity(query_embedding, self.document_embeddings)[0]
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [(self.documents[i], similarities[i]) for i in top_indices]

    def evaluate_retrieval(self, queries: List[str], relevant_docs: List[List[int]], 
                           k: int = 5) -> Tuple[float, float]:
        precisions, recalls = [], []
        for query, relevant in zip(queries, relevant_docs):
            retrieved = self.retrieve_documents(query, top_k=k)
            retrieved_indices = [self.documents.index(doc) for doc, _ in retrieved]
            
            true_positives = len(set(retrieved_indices) & set(relevant))
            precision = true_positives / k
            recall = true_positives / len(relevant)
            
            precisions.append(precision)
            recalls.append(recall)
        
        avg_precision = sum(precisions) / len(precisions)
        avg_recall = sum(recalls) / len(recalls)
        return avg_precision, avg_recall
    
    async def aretrieve_documents(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        query_embedding = self.encode_text(query)
        similarities = cosine_similarity(query_embedding, self.document_embeddings)[0]
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [(self.documents[i], similarities[i]) for i in top_indices]

    async def process_requirements(self, input_folder: str, output_folder: str):
        os.makedirs(output_folder, exist_ok=True)
        
        for filename in os.listdir(input_folder):
            if filename.endswith('.txt'):
                input_file_path = os.path.join(input_folder, filename)
                with open(input_file_path, 'r', encoding='utf-8') as file:
                    req_content = file.read().strip()
                
                query = f"Identify content that are relevant to ensuring compliance or completeness of the following requirement: {req_content}"
                results = await self.aretrieve_documents(query)
                
                output_file_path = os.path.join(output_folder, filename)
                with open(output_file_path, 'w', encoding='utf-8') as output_file:
                    output_file.write(f"Query: {query}\n\nRelevant Documents:\n")
                    for doc, score in results:
                        output_file.write(f"Score: {score:.4f}\nDocument: {doc}\n\n")
                
                print(f"Written content to {output_file_path}")


def load_json_file(file_path: str) -> dict:
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

async def amain():
    retriever = BERTRetriever()
    retriever.load_documents('input/articles.txt')
    retriever.encode_documents()

    input_folder = "documents/requirements/FR/modified"
    output_folder = "output/search_results/BERT"
    await retriever.process_requirements(input_folder, output_folder)

def main():
    try:
        asyncio.run(amain())
    except RuntimeError:
        # If we're in an environment with an existing event loop
        loop = asyncio.get_event_loop()
        if loop.is_running():
            loop.create_task(amain())
        else:
            loop.run_until_complete(amain())

if __name__ == "__main__":
    main()

  loop.create_task(amain())


Using device: cpu
BERT model 'bert-base-uncased' and tokenizer loaded successfully.
Loaded 2318 documents.


Encoding documents: 100%|██████████| 73/73 [00:23<00:00,  3.09it/s]


Encoded 2318 documents.
Written content to output/search_results/BERT\req_16.txt
Written content to output/search_results/BERT\req_17.txt
Written content to output/search_results/BERT\req_18.txt
Written content to output/search_results/BERT\req_22.txt
Written content to output/search_results/BERT\req_25.txt
Written content to output/search_results/BERT\req_3.txt
Written content to output/search_results/BERT\req_40.txt
Written content to output/search_results/BERT\req_5.txt
Written content to output/search_results/BERT\req_7.txt
Written content to output/search_results/BERT\req_8.txt
