In [None]:
# =========================================================
# RAG System Performance Evaluation (Recall@K)
# =========================================================

import os
import sys
import copy
import unicodedata
import pandas as pd
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Libraries
from langchain_community.document_loaders import CSVLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers import EnsembleRetriever
from sentence_transformers import CrossEncoder

class NaturalQueryEvaluator:
    """Measure Recall@K of a hybrid retriever followed by a cross-encoder re-ranker.

    Typical usage is ``load_and_preprocess()`` -> ``build_engine()`` ->
    ``run_evaluation()``.

    Attributes:
        doc_map: Maps ``(normalized track, normalized artist)`` to a ``doc_id``.
        docs_full: Documents as loaded from the data CSV (indexed by FAISS).
        docs_meta: Copies of the documents whose content is the normalized
            first three content lines (indexed by BM25).
        ensemble_retriever: The fused retriever, or ``None`` until
            ``build_engine()`` has been called.
    """

    # Columns that run_evaluation() reads from every row of the query CSV.
    REQUIRED_QUERY_COLUMNS = ('question', 'ground_truth_track', 'ground_truth_artist')

    def __init__(self, data_path, query_path):
        """Store the input paths and load the embedding and re-ranker models.

        Args:
            data_path: Path of the CSV file holding the corpus.
            query_path: Path of the CSV file holding the evaluation queries.
        """
        self.data_path = data_path
        self.query_path = query_path

        # Check device
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Device: {self.device.upper()}")

        # 1. Load Embedding Model
        print("Loading embedding model (all-MiniLM-L6-v2)...")
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': self.device}
        )

        # 2. Load Re-ranker Model
        print("Loading re-ranker model (ms-marco-MiniLM-L-6-v2)...")
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=self.device)

        self.doc_map = {}
        self.docs_full = []
        self.docs_meta = []
        # Set by build_engine(); kept as None so a missing build step is detectable.
        self.ensemble_retriever = None

    def _normalize(self, text):
        """Return *text* as a lowercase, accent-stripped, ASCII-only, trimmed string."""
        return unicodedata.normalize('NFKD', str(text)).encode('ascii', 'ignore').decode('utf-8').lower().strip()

    @staticmethod
    def _extract_field(lines, field):
        """Return the value of the first line in *lines* containing ``"<field>:"``.

        The value is everything after the first colon of that line, stripped.
        Returns an empty string when no line mentions the field.
        """
        marker = f"{field}:"
        line = next((candidate for candidate in lines if marker in candidate), "")
        return line.split(':', 1)[1].strip() if line else ""

    def load_and_preprocess(self):
        """Load the corpus CSV and build ``docs_full``, ``docs_meta`` and ``doc_map``.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        print("Loading data...")
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

        loader = CSVLoader(self.data_path, encoding='utf-8', content_columns=['track_name', 'artist_name', 'lyrics', 'genre'])
        self.docs_full = loader.load()

        # Start from a clean slate so a reload never keeps stale entries.
        self.doc_map = {}
        self.docs_meta = []
        for i, doc in enumerate(self.docs_full):
            doc.metadata['doc_id'] = i
            lines = doc.page_content.split('\n')

            track = self._extract_field(lines, "track_name")
            artist = self._extract_field(lines, "artist_name")
            if track and artist:
                # Rows sharing the same normalized key overwrite earlier ones.
                self.doc_map[(self._normalize(track), self._normalize(artist))] = i

            # Create metadata-focused document for BM25
            new_doc = copy.deepcopy(doc)
            new_doc.page_content = self._normalize(" ".join(lines[:3]))
            new_doc.metadata['doc_id'] = i
            self.docs_meta.append(new_doc)

        print(f"Data loaded: {len(self.docs_full)} documents")

    def build_engine(self):
        """Build the BM25 + FAISS ensemble retriever over the loaded documents.

        Raises:
            RuntimeError: If no documents have been loaded yet.
        """
        print("Building hybrid search engine...")
        if not self.docs_full or not self.docs_meta:
            raise RuntimeError("No documents loaded. Call load_and_preprocess() before build_engine().")

        # 1. BM25
        bm25 = BM25Retriever.from_documents(self.docs_meta)
        bm25.k = 50

        # 2. FAISS
        vectorstore = FAISS.from_documents(self.docs_full, self.embeddings)
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 50})

        # 3. Ensemble
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25, faiss_retriever],
            weights=[0.5, 0.5]
        )

    def _rank_doc_ids(self, query):
        """Retrieve, re-rank and return unique ``doc_id`` values, best first.

        BM25 indexes ``docs_meta`` while FAISS indexes ``docs_full``, so the
        same ``doc_id`` can come back as two different documents. Only the
        best-ranked occurrence is kept; otherwise duplicates would occupy
        several top-K slots and deflate Recall@K.
        """
        candidates = self.ensemble_retriever.invoke(query)
        if candidates:
            scores = self.reranker.predict([[query, doc.page_content] for doc in candidates])
            ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
            candidates = [doc for doc, _ in ranked]
        # dict.fromkeys preserves first-seen order, i.e. the best rank per id.
        return list(dict.fromkeys(doc.metadata.get('doc_id') for doc in candidates))

    def _plot_recall(self, k_values, recall_scores, total):
        """Draw the Recall@K curve for the evaluated queries."""
        plt.figure(figsize=(10, 6))
        plt.plot(k_values, recall_scores, 'b*-', linewidth=2, markersize=10, label='Re-ranker (Final)')
        plt.title(f'RAG Performance on {total} Queries', fontsize=16)
        plt.xlabel('Top-K', fontsize=12)
        plt.ylabel('Recall (%)', fontsize=12)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.ylim(0, 105)
        plt.xticks(k_values)
        for x, y in zip(k_values, recall_scores):
            plt.text(x, y+3, f"{y:.1f}%", ha='center', fontweight='bold')
        plt.legend()
        plt.show()

    def run_evaluation(self, sample_size=1000):
        """Compute Recall@{1,3,5,10} over the query file, print and plot it.

        Queries whose ground truth is not present in ``doc_map`` are skipped.
        A query whose retrieval or re-ranking raises is counted as a miss and
        reported, rather than silently ignored.

        Args:
            sample_size: Maximum number of queries to evaluate. Larger query
                files are sampled down with a fixed random seed.

        Returns:
            A dict mapping each K to its recall in percent.

        Raises:
            FileNotFoundError: If ``query_path`` does not exist.
            RuntimeError: If ``build_engine()`` has not been called.
            KeyError: If the query file lacks a required column.
        """
        print(f"Starting evaluation (Target N={sample_size})...")

        if not os.path.exists(self.query_path):
            raise FileNotFoundError(f"Query file not found: {self.query_path}")
        if getattr(self, 'ensemble_retriever', None) is None:
            raise RuntimeError("Search engine not built. Call build_engine() before run_evaluation().")

        df_queries = pd.read_csv(self.query_path)
        missing = [col for col in self.REQUIRED_QUERY_COLUMNS if col not in df_queries.columns]
        if missing:
            raise KeyError(f"Query file is missing required column(s): {missing}")

        if len(df_queries) > sample_size:
            df_queries = df_queries.sample(sample_size, random_state=42)

        k_values = [1, 3, 5, 10]
        hits = {k: 0 for k in k_values}
        total = 0
        skipped = 0
        failures = 0

        for _, row in tqdm(df_queries.iterrows(), total=len(df_queries), desc="Processing"):
            query = self._normalize(row['question'])

            true_track = self._normalize(row['ground_truth_track'])
            true_artist = self._normalize(row['ground_truth_artist'])
            ground_truth_id = self.doc_map.get((true_track, true_artist))
            if ground_truth_id is None:
                skipped += 1
                continue

            total += 1

            try:
                found_ids = self._rank_doc_ids(query)
            except Exception as exc:
                # Best effort: the query stays in the denominator as a miss.
                failures += 1
                if failures == 1:
                    print(f"Warning: retrieval failed ({type(exc).__name__}: {exc}); counting as a miss.")
                continue

            for k in k_values:
                if ground_truth_id in found_ids[:k]:
                    hits[k] += 1

        # Report
        print("\n" + "="*50)
        print(f"Final Report: {total} samples evaluated")
        if skipped:
            print(f"Skipped (ground truth not found in corpus): {skipped}")
        if failures:
            print(f"Retrieval failures counted as misses: {failures}")
        print("="*50)

        recall_scores = []
        for k in k_values:
            score = (hits[k] / total) * 100 if total > 0 else 0
            recall_scores.append(score)
            print(f"Recall@{k:<2} : {score:.2f}%")

        # Plot
        self._plot_recall(k_values, recall_scores, total)

        return dict(zip(k_values, recall_scores))

if __name__ == "__main__":
    # Input files, resolved relative to the current working directory.
    DATA_FILE = "final_preprocessed_music_data.csv"
    QUERY_FILE = "generated_query_set_1000_llm.csv"

    # Echo the absolute locations so a wrong working directory is easy to spot.
    print(f"Data path: {os.path.abspath(DATA_FILE)}")
    print(f"Query path: {os.path.abspath(QUERY_FILE)}")

    # Guard clause: bail out with a message unless both inputs are present.
    if not all(map(os.path.exists, (DATA_FILE, QUERY_FILE))):
        print("Error: Input files not found. Please check the file paths.")
    else:
        # Full pipeline: load corpus -> build retriever -> score queries.
        evaluator = NaturalQueryEvaluator(DATA_FILE, QUERY_FILE)
        evaluator.load_and_preprocess()
        evaluator.build_engine()
        evaluator.run_evaluation(sample_size=1000)