In [None]:
import os
import streamlit as st
from dotenv import load_dotenv
from langchain_community.graphs import Neo4jGraph
from langchain.chains import GraphCypherQAChain
from langchain_google_genai import ChatGoogleGenerativeAI
import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

# Evaluation metrics libraries
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

# NLTK resources
import nltk
# Fetch the tokenizer resources up front (this file calls nltk.word_tokenize).
# Both calls pass quiet=True so re-running the cell does not print download
# chatter for packages that are already installed.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Run python-dotenv first so the lookups below can see its values.
load_dotenv()

# Neo4j connection settings, each with a fallback used when the variable is
# not set in the environment.
NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PW", "admin123")
NEO4J_DATABASE = os.environ.get("NEO4J_DB", "spotify1")

# API key handed to the Gemini chat model; empty string when unset.
GEMINI_API = os.environ.get("GEMINI_API", "")

@st.cache_resource
def graph_chain():
    """
    Build the Cypher question-answering chain over the Neo4j graph.

    The function is wrapped in Streamlit's ``cache_resource`` decorator.
    Connection settings and the Gemini API key come from the module-level
    constants.

    Returns:
        GraphCypherQAChain: Chain configured to return its intermediate
        steps (generated query and retrieved context) and to log verbosely.
    """
    neo4j_graph = Neo4jGraph(
        url=NEO4J_URL,
        username=NEO4J_USER,
        password=NEO4J_PASSWORD,
        database=NEO4J_DATABASE,
    )
    gemini = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-lite",
        google_api_key=GEMINI_API,
        temperature=0,
    )
    return GraphCypherQAChain.from_llm(
        llm=gemini,
        graph=neo4j_graph,
        verbose=True,
        return_intermediate_steps=True,
        allow_dangerous_requests=True,
    )

def infer(chain, prompt):
    """
    Run one question through the chain and unpack its response.

    Args:
        chain: Object exposing ``invoke(prompt)`` whose response contains
            ``"intermediate_steps"`` and ``"result"`` entries
        prompt: Input query

    Returns:
        tuple: (generated query, retrieved context, result), where the query
        is read from the first intermediate step and the context from the
        second
    """
    output = chain.invoke(prompt)
    steps = output["intermediate_steps"]
    return steps[0]["query"], steps[1]["context"], output["result"]

class RAGEvaluator:
    """
    Score a RAG chain's answers against hand-written ground truth.

    ``run_evaluation`` currently reports ROUGE scores only. The exact-match,
    partial-match, keyword-relevance and BLEU helpers are available as
    methods but are not part of the reported results.
    """

    # Column order of the per-query results frame. Passing it explicitly keeps
    # the score columns present even when there are no rows.
    _RESULT_COLUMNS = (
        'Query',
        'Generated Query',
        'Retrieved Context',
        'Predicted Answer',
        'Expected Answer',
        'ROUGE-1 Score',
        'ROUGE-2 Score',
        'ROUGE-L Score',
    )

    # Columns that are averaged into the overall metrics.
    _SCORE_COLUMNS = ('ROUGE-1 Score', 'ROUGE-2 Score', 'ROUGE-L Score')

    def __init__(self, ground_truth_data):
        """
        Initialize RAG Evaluator with ground truth data

        Args:
            ground_truth_data (list): List of dictionaries with 'query' and 'expected_answer' keys
        """
        self.ground_truth_data = ground_truth_data
        self.evaluation_results = []
        self.rouge_scorer = Rouge()

    def calculate_bleu_score(self, predicted, expected):
        """
        Calculate BLEU score

        Both strings are lower-cased and tokenized with NLTK; the expected
        answer is used as the single reference.

        Args:
            predicted (str): Predicted answer
            expected (str): Expected answer

        Returns:
            float: BLEU score (0 if the computation divides by zero)
        """
        # Tokenize the sentences
        predicted_tokens = nltk.word_tokenize(predicted.lower())
        expected_tokens = [nltk.word_tokenize(expected.lower())]

        # Calculate BLEU score
        try:
            bleu_score = sentence_bleu(expected_tokens, predicted_tokens)
        except ZeroDivisionError:
            bleu_score = 0

        return bleu_score

    def calculate_rouge_scores(self, predicted, expected):
        """
        Calculate ROUGE scores

        Args:
            predicted (str): Predicted answer
            expected (str): Expected answer

        Returns:
            dict: F-measures under the keys 'rouge-1', 'rouge-2' and
            'rouge-l'; all zeros if the scorer raises
        """
        try:
            rouge_scores = self.rouge_scorer.get_scores(predicted, expected)[0]
            return {
                'rouge-1': rouge_scores['rouge-1']['f'],
                'rouge-2': rouge_scores['rouge-2']['f'],
                'rouge-l': rouge_scores['rouge-l']['f']
            }
        except Exception as e:
            # Deliberate best-effort: one unscorable answer must not abort
            # the whole evaluation run, so report it and score it as zero.
            print(f"Error calculating ROUGE scores: {e}")
            return {
                'rouge-1': 0,
                'rouge-2': 0,
                'rouge-l': 0
            }

    def calculate_exact_match(self, predicted, expected):
        """
        Calculate exact match score

        The comparison ignores case and surrounding whitespace.

        Args:
            predicted (str): Predicted answer
            expected (str): Expected answer

        Returns:
            int: 1 if the answers match, otherwise 0
        """
        return 1 if predicted.strip().lower() == expected.strip().lower() else 0

    def calculate_partial_match(self, predicted, expected, threshold=0.7):
        """
        Calculate partial match score using word overlap

        The overlap ratio is the number of distinct lower-cased words shared
        by both answers divided by the size of the larger word set.

        Args:
            predicted (str): Predicted answer
            expected (str): Expected answer
            threshold (float): Minimum overlap ratio that counts as a match

        Returns:
            int: 1 if the overlap ratio reaches the threshold, otherwise 0
        """
        predicted_words = set(predicted.lower().split())
        expected_words = set(expected.lower().split())

        max_words = max(len(predicted_words), len(expected_words))
        if max_words == 0:
            # Neither answer contains a word, so the ratio is undefined.
            # Treat the pair as matching, in line with calculate_exact_match,
            # instead of dividing by zero.
            return 1

        overlap = len(predicted_words & expected_words)
        return 1 if overlap / max_words >= threshold else 0

    def evaluate_relevance(self, context, expected_keywords):
        """
        Evaluate context relevance

        Args:
            context (str): Retrieved context
            expected_keywords (list): List of keywords expected in context

        Returns:
            float: Fraction of keywords found in the context as
            case-insensitive substrings; 0 when no keywords are given
        """
        context_lower = context.lower()
        keyword_matches = [
            1 if keyword.lower() in context_lower else 0
            for keyword in expected_keywords
        ]
        return np.mean(keyword_matches) if keyword_matches else 0

    def run_evaluation(self, chain):
        """
        Run the ground-truth queries through the chain and score the answers

        Each call starts from an empty result list, so running the evaluation
        twice on the same evaluator does not duplicate rows.

        Args:
            chain: GraphCypherQAChain

        Returns:
            tuple: (pandas.DataFrame with one row per query, dict of average
            ROUGE scores). With no ground truth entries the DataFrame is
            empty and every average is NaN.
        """
        self.evaluation_results = []

        for item in self.ground_truth_data:
            query = item['query']
            expected_answer = item['expected_answer']

            # Generate RAG response
            generated_query, retrieved_context, predicted_answer = infer(chain, query)

            rouge_scores = self.calculate_rouge_scores(predicted_answer, expected_answer)

            # Collect evaluation results
            self.evaluation_results.append({
                'Query': query,
                'Generated Query': generated_query,
                'Retrieved Context': retrieved_context,
                'Predicted Answer': predicted_answer,
                'Expected Answer': expected_answer,
                'ROUGE-1 Score': rouge_scores['rouge-1'],
                'ROUGE-2 Score': rouge_scores['rouge-2'],
                'ROUGE-L Score': rouge_scores['rouge-l']
            })

        # Explicit columns keep the frame well-formed when there are no rows.
        df_results = pd.DataFrame(
            self.evaluation_results, columns=list(self._RESULT_COLUMNS)
        )

        # Calculate overall metrics
        has_rows = len(df_results) > 0
        metrics = {
            f'Average {column}': (
                df_results[column].mean() if has_rows else float('nan')
            )
            for column in self._SCORE_COLUMNS
        }

        return df_results, metrics

def main(output_path='rag_evaluation_results.csv'):
    """
    Evaluate the graph RAG chain on the built-in ground truth set.

    Builds the chain, runs every ground truth query through it, prints the
    per-query results and the averaged metrics, and optionally saves the
    per-query results as CSV.

    Args:
        output_path: Where to write the per-query results as CSV. Pass None
            to skip saving.

    Returns:
        tuple: (per-query results DataFrame, dict of overall metrics)
    """
    # Example ground truth data with multiple queries.
    # Commented-out entries still contain placeholder answers (X, Z, N, D)
    # or values that have not been confirmed.
    ground_truth_data = [
        {
            'query': 'What is the highest number of streams?',
            'expected_answer': 'The highest number of streams is 3703895074',
            'expected_keywords': ['highest', 'streams', '3703895074']
        },
        # {
        #     'query': 'Which track has the lowest danceability?',
        #     'expected_answer': 'The track with the lowest danceability is "vampire" with 51%',
        #     'expected_keywords': ['lowest', 'danceability', 'vampire', '51']
        # },
        # {
        #     'query': 'How many songs were released in 2023?',
        #     'expected_answer': 'There are X songs released in 2023',
        #     'expected_keywords': ['songs', 'released', '2023']
        # },
        {
            'query': 'What is the average energy of all tracks?',
            # NOTE(review): this value looks like an average stream count
            # rather than an energy value -- confirm against the dataset.
            'expected_answer': 'The average energy of all tracks is 513597931.3137464',
            'expected_keywords': ['average', 'energy']
        },
        {
            'query': 'Which artist appears most frequently?',
            'expected_answer': 'The artist that appears most frequently is "Bad Bunny"',
            'expected_keywords': ['most', 'frequently', 'artist', 'Bad Bunny']
        },
        {
            'query': 'Which song has the highest valence?',
            'expected_answer': 'The song with the highest valence is "Seven" ',
            'expected_keywords': ['highest', 'valence', 'Seven']
        },
        {
            'query': 'Which month had the most song releases?',
            'expected_answer': 'The month with the most releases is January',
            # The keyword names the same month as the expected answer.
            'expected_keywords': ['most', 'releases', 'January']
        },
        # {
        #     'query': 'What is the average speechiness of tracks?',
        #     'expected_answer': 'The average speechiness is Z%',
        #     'expected_keywords': ['average', 'speechiness']
        # },
        # {
        #     'query': 'Which track has the highest acousticness?',
        #     'expected_answer': 'The track with the highest acousticness is "vampire" with 17%',
        #     'expected_keywords': ['highest', 'acousticness', 'vampire']
        # },
        # {
        #     'query': 'How many tracks are in minor key?',
        #     'expected_answer': 'There are N tracks in minor key',
        #     'expected_keywords': ['tracks', 'minor key']
        # },
        {
            'query': 'What is the average number of Spotify playlists a track appears in?',
            'expected_answer': 'The average is 5200.124868835249 playlists',
            'expected_keywords': ['average', 'Spotify', 'playlists']
        },
        # {
        #     'query': 'Which song has the highest BPM?',
        #     'expected_answer': 'The song with the highest BPM is "Cruel Summer" with 170',
        #     'expected_keywords': ['highest', 'BPM', 'Cruel Summer']
        # },
        # {
        #     'query': 'Which song has the most Apple Music chart appearances?',
        #     'expected_answer': 'The song with most Apple chart appearances is "Cruel Summer"',
        #     'expected_keywords': ['most', 'Apple', 'charts', 'Cruel Summer']
        # },
        # {
        #     'query': 'What is the average number of Deezer playlists per track?',
        #     'expected_answer': 'The average is D playlists',
        #     'expected_keywords': ['average', 'Deezer', 'playlists']
        # },
        {
            'query': 'Which song had the earliest release date?',
            'expected_answer': 'The song with the earliest release is "Agudo Mï¿½ï¿½gi"',
            'expected_keywords': ['earliest', 'release', 'Agudo Mï¿½ï¿½gi']
        },
        # {
        #     'query': 'Which song has the most instrumentalness?',
        #     'expected_answer': 'The song with the most instrumentalness is "WHERE SHE GOES" with 63%',
        #     'expected_keywords': ['most', 'instrumentalness', 'WHERE SHE GOES']
        # },
        {
            'query': 'What is the total number of streams for all songs?',
            'expected_answer': 'The total number of streams is 489458828542',
            'expected_keywords': ['total', 'streams']
        },
        # {
        #     'query': 'Which song has the lowest valence?',
        #     'expected_answer': 'The song with the lowest valence is "WHERE SHE GOES" with 23%',
        #     'expected_keywords': ['lowest', 'valence', 'WHERE SHE GOES']
        # },
        {
            'query': 'Which artist has the highest average danceability across their songs?',
            'expected_answer': 'The artist with highest average danceability is "Latto"',
            'expected_keywords': ['highest', 'average', 'danceability', 'Latto']
        },
        {
            'query': 'Which song appears in the most Spotify charts?',
            'expected_answer': 'The song with most Spotify chart appearances is "Seven"',
            'expected_keywords': ['most', 'Spotify', 'charts', 'Seven']
        },
    ]

    # Initialize graph chain and evaluator
    chain = graph_chain()
    evaluator = RAGEvaluator(ground_truth_data)

    # Run evaluation
    detailed_results, overall_metrics = evaluator.run_evaluation(chain)

    # Print results
    print("Detailed Evaluation Results:")
    print(detailed_results)
    print("\nOverall Metrics:")
    print(overall_metrics)

    # Saving is optional: callers can pass output_path=None to skip it.
    if output_path is not None:
        detailed_results.to_csv(output_path, index=False)

    return detailed_results, overall_metrics

# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

[nltk_data] Downloading package punkt_tab to /Users/admin/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!




[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (t:Track)
RETURN max(t.streams)
[0m
Full Context:
[32;1m[1;3m[{'max(t.streams)': 3703895074}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (t:Track)
RETURN avg(t.streams)
[0m
Full Context:
[32;1m[1;3m[{'avg(t.streams)': 513597931.3137464}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (a:Artist)-[:PERFORMED]->(t:Track)
RETURN a.name, count(t) AS trackCount
ORDER BY trackCount DESC
LIMIT 1
[0m
Full Context:
[32;1m[1;3m[{'a.name': 'Bad Bunny', 'trackCount': 40}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m




Generated Cypher:
[32;1m[1;3mcypher
MATCH (t:Track)
RETURN t.name, t.valence
ORDER BY t.valence DESC
LIMIT 1
[0m
Full Context:
[32;1m[1;3m[{'t.name': 'Seven (feat. Latto) (Explicit Ver.)', 't.valence': None}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (t:Track)
RETURN t.released_month, count(t) AS num_releases
ORDER BY num_releases DESC
LIMIT 1
[0m
Full Context:
[32;1m[1;3m[{'t.released_month': 1, 'num_releases': 134}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (t:Track)
RETURN avg(t.spotify_playlists)
[0m
Full Context:
[32;1m[1;3m[{'avg(t.spotify_playlists)': 5200.124868835249}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (t:Track)
RETURN t.name, t.released_day, t.released_month, t.released_year
ORDER BY t.released_year, t.r