In [1]:
%pip install langchain langchain-community langchain-huggingface neo4j pandas spacy trans
!python -m spacy download en_core_web_sm

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
      --------------------------------------- 0.3/12.8 MB ? eta -:--:--
     - -------------------------------------- 0.5/12.8 MB 1.3 MB/s eta 0:00:10
     -- ------------------------------------- 0.8/12.8 MB 1.4 MB/s eta 0:00:09
     --- ------------------------------------ 1.0/12.8 MB 1.5 MB/s eta 0:00:09
     ---- ----------------------------------- 1.3/12.8 MB 1.3 MB/s eta 0:00:10
     ---- ----------------------------------- 1.3/12.8 MB 1.3 MB/s eta 0:00:10
     ---- ----------------------------------- 1.6/12.8 MB 1.2 MB/s eta 0:00:10
     ----- ---------------------------------- 1.8/12.8 MB 1.2 MB/s eta 0:00:10
     ------ --------------------------------- 2.1/12.8 MB 1.2 MB/s eta 0:00:09
     ------- -------------------------------- 2.

In [18]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

False
No GPU


In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
from langchain_community.graphs import Neo4jGraph
from neo4j import GraphDatabase

In [14]:
config = {}

with open('config.txt', 'r') as file:
    for line in file:
        if "=" in line:
            key, value = line.split('=', 1)
            config[key.strip()] = value.strip()

uri = config.get('URI')
username = config.get('USERNAME')
password = config.get('PASSWORD')
driver = GraphDatabase.driver(uri, auth=(username, password))
print("Connected to Neo4j database")

Connected to Neo4j database


In [15]:
# Connect using the LangChain wrapper
graph = Neo4jGraph(
    url=uri,
    username=username,
    password=password,
    refresh_schema= False
)
# Ensure the connection is working by running a quick query (optional)
print(graph.query("MATCH (s:Season) RETURN s"))

  graph = Neo4jGraph(


[{'s': {'season_name': '2021-22'}}, {'s': {'season_name': '2022-23'}}]


In [16]:
import spacy
import re

nlp = spacy.load("en_core_web_sm")
# Load the kb from the graph database (optional enhancement)
def load_fpl_kb(graph: Neo4jGraph) -> dict:
    kb = {
        "players": [],
        "teams": [],
        "positions": ["gk","gkp", "def", "mid", "fwd", "goalkeeper","goalkeepers" "defender", "midfielder", "forward","defenders","midfielders","forwards"],
        "stats": {}
    }

    # Load players
    player_results = graph.query("MATCH (p:Player) RETURN p.player_name AS name")
    kb["players"] = [record["name"] for record in player_results]

    # Load teams
    team_results = graph.query("MATCH (t:Team) RETURN t.name AS name")
    kb["teams"] = [record["name"] for record in team_results]

    # Load stats mapping
    kb["stats"] = {
        "points": "total_points",
        "goals": "goals_scored",
        "assists": "assists",
        "minutes": "minutes",
        "bonus": "bonus",
        "influence": "influence",
        "creativity": "creativity",
        "threat": "threat",
        "ict": "ict_index",
        "clean sheets": "clean_sheets",
        "form": "form"
    }

    return kb

In [17]:
FPL_KB = load_fpl_kb(graph)

In [18]:
import re

def normalize(text: str) -> str:
    return re.sub(r"\s+", " ", text.lower().strip())


In [19]:
INTENTS = {
    "fixture_details": [
        "fixture", "fixtures", "when do", "when does", "when is", "play next", 
        "next match", "kickoff", "schedule", "upcoming match", "future match"
    ],

    "best_players_by_metric": [
        "top", "best", "highest", "leader", "rank", "ranking", 
        "top scorer", "top assist", "highest points", "most points", "stat leaders", "top players","best forward",
        "best midfielder", "best defender", "best goalkeeper","top number","best number"
    ],
    "Worst_players_by_metric": [
        "worst", "lowest", "bottom", "least", "bottom scorer", "least assists", "lowest points", "fewest points", "stat laggards", "worst players","worst forward",
        "worst midfielder", "worst defender", "worst goalkeeper","bottom number","worst number"
    ],
    "player_or_team_performance": [
        "how did", "performance", "stats", "statistics", "record", "scored", 
        "assists", "goals", "points", "clean sheets", "how many", 
        "results","compare", "vs", "versus", "better than", "head to head", "compare stats", "comparison","more than","compare player1 and player2"
    ],
    "player_information": [
        "who is", "tell me about", "information on", "details about", 
        "bio", "biography", "what is position of", "which team does", "age of", "nationality of", "height of", "weight of"
    ]
}

In [20]:
import re

def normalize(text: str) -> str:
    return re.sub(r"\s+", " ", text.lower().strip())

In [21]:
def classify_fpl_intents(prompt: str) -> str:
    q = normalize(prompt)

    scores = {
        "fixture_details": 0,
        "player_or_team_performance": 0,
        "best_players_by_metric": 0,
        "Worst_players_by_metric": 0,
        "player_information": 0
    }

    # ---------------------------
    # FIXTURE DETAILS (HIGHEST PRIORITY)
    # ---------------------------
    if any(x in q for x in [
        "when do", "when does", "when is",
        "next fixture", "play next", "next match",
        "kickoff", "schedule"
    ]):
        scores["fixture_details"] += 6

    # Player plays against team
    if " play against " in q or " plays against " in q:
        scores["fixture_details"] += 6

    # Team vs team fixture
    if (" play each other" in q or " vs " in q or " versus " in q) and "when" in q:
        scores["fixture_details"] += 6

    # Gameweek reference boosts fixture intent
    if re.search(r"\b(gameweek|gw)\s*\d+", q):
        scores["fixture_details"] += 2

    # ---------------------------
    # PERFORMANCE / COMPARISON
    # ---------------------------
    if any(x in q for x in [
        "show me", "give me", "how did", "performed",
        "total", "goals", "assists", "points",
        "bonus", "clean sheets", "stats"
    ]):
        scores["player_or_team_performance"] += 3

    if any(x in q for x in [
        "compare", "vs", "versus", "better than",
        "head to head", "comparison"
    ]):
        scores["player_or_team_performance"] += 5

    # Season reference
    if re.search(r"\b20\d{2}-\d{2}\b", q):
        scores["player_or_team_performance"] += 2

    # ---------------------------
    # BEST PLAYERS
    # ---------------------------
    if any(x in q for x in [
        "top", "best", "highest", "most",
        "top scorer", "top assist"
    ]):
        scores["best_players_by_metric"] += 6

    # ---------------------------
    # WORST PLAYERS
    # ---------------------------
    if any(x in q for x in [
        "worst", "bottom", "lowest", "least",
        "fewest"
    ]):
        scores["Worst_players_by_metric"] += 6

    # ---------------------------
    # PLAYER INFORMATION
    # ---------------------------
    if any(x in q for x in [
        "who is", "what team does", "what position does",
        "which position", "tell me about"
    ]):
        scores["player_information"] += 6

    # ---------------------------
    # FINAL DECISION
    # ---------------------------
    best_intent = max(scores, key=scores.get)

    if scores[best_intent] == 0:
        return "player_or_team_performance"

    return best_intent


In [80]:
tests = [
# Single Player Performance
    "Show me how Mohamed Salah performed in gameweek 5, including total_points.",
    "Show me the total goals scored by Kevin De Bruyne for the 2022-23 season.",
    
    # # Single Team Performance
    "Give me Arsenal's total goals in gameweek 10.",
    "What is Chelsea's total bonus points for the 2022-23 season?",
    
    # # Compare Two Players
    "Compare Mohamed Salah and Erling Haaland in gameweek 8 season 2021-22 for total points.",
    "Compare Erling Haaland and Mohamed Salah for the 2022-23 season by goals.",
    
    # # Compare Two Teams
    "Compare Liverpool and Chelsea in gameweek 12 for total points.",
    "Compare Liverpool and Chelsea in gameweek 12 for goals scored.",
    
    # # Fixtures
    "Show me Man City's next fixture in gameweek 15.",
    "When do Arsenal and Man city play each other?",
    "When does Mohamed Salah play next?",
    "When does Harry Kane play against Liverpool?",
    
    # # Best Players
    "Who are the top players by total points in the 2021-22 season?",
    "Who are the top 5 forwards by total points in the 2022-23 season?",
    "Who are the top 3 midfielders by assists?.",
    "Who are the top 2 players with goals above 10 in the 2022-23 season?",
    "Who is the best goalkeeper in the 2022-23 season?",
    "Who scored the most goals in the 2021-22 season?",

    # # Worst Players
    "Who are the worst players by total points in the 2021-22 season?",
    "Who are the bottom 5 defenders by total points in the 2022-23 season?",
    "Who are the bottom 3 midfielders by assists?.",
    "Who are the bottom 2 players with goals below 2 in the 2022-23 season?",

    # Player Information
    "What team does Harry Kane play for?",
    "What position does Mohamed Salah play?",
    "What team does Kevin De Bruyne play for?",
    "Which position does Virgil van Dijk play as.",
]

for t in tests:
    print(t, "→", classify_fpl_intents(t))


Show me how Mohamed Salah performed in gameweek 5, including total_points. → player_or_team_performance
Show me the total goals scored by Kevin De Bruyne for the 2022-23 season. → player_or_team_performance
Give me Arsenal's total goals in gameweek 10. → player_or_team_performance
What is Chelsea's total bonus points for the 2022-23 season? → player_or_team_performance
Compare Mohamed Salah and Erling Haaland in gameweek 8 season 2021-22 for total points. → player_or_team_performance
Compare Erling Haaland and Mohamed Salah for the 2022-23 season by goals. → player_or_team_performance
Compare Liverpool and Chelsea in gameweek 12 for total points. → player_or_team_performance
Compare Liverpool and Chelsea in gameweek 12 for goals scored. → player_or_team_performance
Show me Man City's next fixture in gameweek 15. → fixture_details
When do Arsenal and Man city play each other? → fixture_details
When does Mohamed Salah play next? → fixture_details
When does Harry Kane play against Liverpo

In [23]:
ENTITY_LOOKUP = {}

def add_to_lookup(terms, category):
    for item in terms:
        # If it's a dict (like stats), the item is the key, canonical is the value
        if isinstance(terms, dict):
            value = terms[item]
            key = item
        else:
            value = item.title()
            key = item
        
        ENTITY_LOOKUP[key.lower()] = (category, value)
add_to_lookup(FPL_KB["players"], "player")
add_to_lookup(FPL_KB["teams"], "team")
add_to_lookup(FPL_KB["positions"], "position")
add_to_lookup(FPL_KB["stats"], "stat")

In [84]:
def extract_fpl_entities(query: str) -> dict:
    """
    Extract entities from FPL query with improved accuracy and validation
    """
    doc = nlp(query)
    entities = {
        "stat_type": "total_points",  # Default fallback
        "season": "2022-23",  # Default season
        "limit": 10  # Default limit
    }
    
    query_lower = query.lower()
    
    # ============================================================================
    # STEP 1: Extract using spaCy + lookup
    # ============================================================================
    
    # Track which players/teams we've seen to handle comparisons
    seen_players = []
    seen_teams = []
    
    for token in doc:
        text = token.text.lower()
        lemma = token.lemma_.lower()
    

        match = ENTITY_LOOKUP.get(text) or ENTITY_LOOKUP.get(lemma)
        
        if match:
            category, value = match
            
            # Handle Players
            if category == "player":
                if value not in seen_players:
                    seen_players.append(value)
                    if "player1" not in entities:
                        entities["player1"] = value
                        entities["player_name"] = value
                    elif "player2" not in entities:
                        entities["player2"] = value
            
            # Handle Teams
            elif category == "team":
                if value not in seen_teams:
                    seen_teams.append(value)
                    if "team1" not in entities:
                        entities["team1"] = value
                        entities["team_name"] = value
                    elif "team2" not in entities:
                        entities["team2"] = value
            
            # Handle Positions
            elif category == "position":
                    # Normalize Aliases
                    norm = value.upper()
                    if "MID" in norm: norm = "MID"
                    elif "FWD" in norm or "FORWARD" in norm: norm = "FWD"
                    elif "DEF" in norm: norm = "DEF"
                    elif "GK" in norm or "GKP" in norm or "GOALKEEPER" in norm: norm = "GK"
                    entities["position"] = norm
            # Handle Stats
            elif category == "stat":
                if entities.get("stat_type") != "bonus":
                    entities["stat_type"] = value
    
    # ============================================================================
    # STEP 2: Multi-word entity extraction (e.g., "Mohamed Salah")
    # ============================================================================
    
    # Extract multi-word player names
    for player in FPL_KB["players"]:
        if player.lower() in query_lower:
            if player not in seen_players:
                seen_players.append(player)
                if "player1" not in entities:
                    entities["player1"] = player
                    entities["player_name"] = player
                elif "player2" not in entities:
                    entities["player2"] = player
    
    # Extract multi-word team names (e.g., "Manchester City")
    for team in FPL_KB["teams"]:
        if team.lower() in query_lower:
            if team not in seen_teams:
                seen_teams.append(team)
                if "team1" not in entities:
                    entities["team1"] = team
                    entities["team_name"] = team
                elif "team2" not in entities:
                    entities["team2"] = team
    
    # Extract multi-word stats (e.g., "clean sheets")
    for stat_key, stat_value in FPL_KB["stats"].items():
        if stat_key in query_lower:
            if entities.get("stat_type") != "bonus":
                entities["stat_type"] = stat_value
                break
    
    # ============================================================================
    # STEP 3: Regex extraction for structured patterns
    # ============================================================================
    
    # Extract Gameweek
    gw_match = re.search(r"(?:gw|gameweek|game week)\s*(\d+)", query_lower)
    if gw_match:
        entities["gw_number"] = int(gw_match.group(1))
    
    # Extract Season (2022-23 format)
    season_match = re.search(r"(20\d{2}-\d{2})", query_lower)
    if season_match:
        entities["season"] = season_match.group(1)
    else:
        # Try alternative format: "season 2022" or "in 2022"
        year_match = re.search(r"(?:season|year|in)\s*(20\d{2})", query_lower)
        if year_match:
            year = year_match.group(1)
            next_year = str(int(year) + 1)[-2:]
            entities["season"] = f"{year}-{next_year}"
    
    # Extract "top N" or "best N"
    limit_match = re.search(r"(?:top|best|first)\s*(\d+)", query_lower)
    if limit_match:
        entities["limit"] = int(limit_match.group(1))

    # Extract "bottom N" or "worst N"
    bottom_limit_match = re.search(r"(?:bottom|worst|last)\s*(\d+)", query_lower)
    if bottom_limit_match:
        entities["limit"] = int(bottom_limit_match.group(1))
    
    # Extract current gameweek for recommendations
    if any(word in query_lower for word in ["recommend", "suggest", "current", "now", "right now"]):
        # If no specific GW mentioned, assume current GW context
        if "current_gw" not in entities and "gw_number" in entities:
            entities["current_gw"] = entities["gw_number"]
        elif "current_gw" not in entities:
            entities["current_gw"] = 20  # Default mid-season
    
    # Extract minimum value filters (e.g., "more than 10 points")
    filter_match = re.search(r"(?:more than|over|above|at least|minimum)\s*(\d+)", query_lower)
    if filter_match:
        entities["filter_value"] = int(filter_match.group(1))
    
    return entities

In [72]:
def get_fpl_cypher_query(intent: str, entities: dict) -> str:
    """
    Generate Cypher query based on FPL intents + extracted entities
    """

    player1 = entities.get("player1")
    player2 = entities.get("player2")
    team1 = entities.get("team1")
    team2 = entities.get("team2")
    stat = entities.get("stat_type", "total_points")
    gw = entities.get("gw_number")
    limit = entities.get("limit", 10)
    season = entities.get("season", "2022-23")
    

    # ----------------------------------------------------------------------
    # 1) PERFORMANCE: Single Player
    # ----------------------------------------------------------------------
    if intent == "player_or_team_performance" and player1 and not player2:
        if gw:
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g:Gameweek {{GW_number:{gw}}})-[:HAS_FIXTURE]->(f:Fixture)
                MATCH (p:Player {{player_name:'{player1}'}})-[pi:PLAYED_IN]->(f)
                RETURN p.player_name AS player, pi.{stat} AS {stat}, g.GW_number AS gameweek
        
            """
        
        else:
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p:Player {{player_name:'{player1}'}})-[pi:PLAYED_IN]->(f)
                RETURN p.player_name AS player, SUM(pi.{stat}) AS total_{stat}, '{season}' AS season
            """

    # ----------------------------------------------------------------------
    # 2) PERFORMANCE: Single Team Performance summary
    # ----------------------------------------------------------------------
    if intent == "player_or_team_performance" and team1 and not team2:
            if gw:
                # Single Gameweek performance for a team
                return f"""
                    MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g:Gameweek {{GW_number:{gw}}})-[:HAS_FIXTURE]->(f)
                    MATCH (p:Player)-[:PLAYS_FOR]->(t:Team{{name:'{team1}'}})
                    MATCH (p)-[pi:PLAYED_IN]->(f)
                    RETURN t.name AS team, SUM(pi.{stat}) AS {stat}, g.GW_number AS gameweek
                """
            else:
                # Full season performance for a team
                return f"""
                    MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(:Gameweek)-[:HAS_FIXTURE]->(f)
                    MATCH (p:Player)-[:PLAYS_FOR]->(t:Team{{name:'{team1}'}})
                    MATCH (p)-[pi:PLAYED_IN]->(f)
                    
                    RETURN t.name AS team, SUM(pi.{stat}) AS {stat}
                """



    # ----------------------------------------------------------------------
    # 3) PERFORMANCE: Compare Two Players
    # ----------------------------------------------------------------------
    if intent == "player_or_team_performance" and player1 and player2:
        if gw:  
            return f"""
           UNWIND ['{player1}', '{player2}'] AS pname
            MATCH (p:Player {{player_name: pname}})
            OPTIONAL MATCH (p)-[pi:PLAYED_IN]->(f:Fixture)<-[:HAS_FIXTURE]-(g:Gameweek {{GW_number:{gw}}})<-[:HAS_GW]-(:Season {{season_name:'{season}'}})
            RETURN 
                p.player_name AS player, 
                COALESCE(SUM(pi.{stat}), 0) AS total_{stat},
                {gw} AS gameweek
            ORDER BY player


            """
        else:
            return f"""
                UNWIND ['{player1}', '{player2}'] AS pname
                MATCH (p:Player {{player_name: pname}})
                OPTIONAL MATCH (p)-[pi:PLAYED_IN]->(:Fixture)<-[:HAS_FIXTURE]-(g:Gameweek)<-[:HAS_GW]-(:Season {{season_name:'{season}'}})
                RETURN 
                    p.player_name AS player, 
                    COALESCE(SUM(pi.{stat}), 0) AS total_{stat}
                ORDER BY player

            """

    # ----------------------------------------------------------------------
    # 4) PERFORMANCE: Compare Two Teams
    # ----------------------------------------------------------------------
    if intent == "player_or_team_performance" and team1 and team2:
        if gw:
            # Single Gameweek comparison
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g:Gameweek {{GW_number:{gw}}})-[:HAS_FIXTURE]->(f)
                MATCH (p:Player)-[pi:PLAYED_IN]->(f)
                MATCH (p)-[:PLAYS_FOR]->(t:Team)
                WHERE t.name = '{team1}' OR t.name = '{team2}'
                RETURN 
                    CASE 
                        WHEN t.name = '{team1}' THEN '{team1}' 
                        ELSE '{team2}' 
                    END AS team,
                    SUM(pi.{stat}) AS total_{stat},
                    g.GW_number AS gameweek
            """
        else:
            # Full season comparison
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(:Gameweek)-[:HAS_FIXTURE]->(f)
                MATCH (p:Player)-[pi:PLAYED_IN]->(f)
                MATCH (p)-[:PLAYS_FOR]->(t:Team)
                WHERE t.name = '{team1}' OR t.name = '{team2}'
                RETURN 
                    CASE 
                        WHEN t.name = '{team1}' THEN '{team1}' 
                        ELSE '{team2}' 
                    END AS team,
                    SUM(pi.{stat}) AS total_{stat}
            """

    # ----------------------------------------------------------------------
    # 5) FIXTURE: Next Fixture by Team
    # ----------------------------------------------------------------------
    if intent == "fixture_details" and team1 and gw and not team2:
        return f"""
             MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g:Gameweek {{GW_number:{gw}}})-[:HAS_FIXTURE]->(f:Fixture)-[:HAS_HOME_TEAM]->(h:Team),
            (f)-[:HAS_AWAY_TEAM]->(a:Team)
            WHERE h.name = '{team1}' OR a.name = '{team1}'
            RETURN f.fixture_number AS fixture, f.kickoff_time, h, a
            ORDER BY f.kickoff_time ASC 
        """
    
    

    # ----------------------------------------------------------------------
    # 6) FIXTURE: Next Fixture involving Two Teams (Head-to-head)
    # ----------------------------------------------------------------------
    if intent == "fixture_details" and team1 and team2 and gw:
        return f"""
          MATCH (:Season {{season_name:'{season}'}})
        -[:HAS_GW]->(g:Gameweek {{GW_number:{gw}}})
        -[:HAS_FIXTURE]->(f:Fixture)
        MATCH (f)-[:HAS_HOME_TEAM]->(home:Team)
        MATCH (f)-[:HAS_AWAY_TEAM]->(away:Team)
        WHERE (home.name = '{team1}' AND away.name = '{team2}')
        OR (home.name = '{team2}' AND away.name = '{team1}')
        RETURN 
            f.fixture_number AS fixture,
            f.kickoff_time AS kickoff_time,
            home.name AS home_team,
            away.name AS away_team
        ORDER BY f.kickoff_time ASC
   

        """
    
    if intent == "fixture_details" and team1 and team2 and not gw:
        return f"""
            MATCH (:Season {{season_name:'{season}'}})
        -[:HAS_GW]->(g)
        -[:HAS_FIXTURE]->(f:Fixture)
        MATCH (f)-[:HAS_HOME_TEAM]->(home:Team)
        MATCH (f)-[:HAS_AWAY_TEAM]->(away:Team)
        WHERE (home.name = '{team1}' AND away.name = '{team2}')
        OR (home.name = '{team2}' AND away.name = '{team1}')
        RETURN 
            f.fixture_number AS fixture,
            f.kickoff_time AS kickoff_time,
            g.GW_number AS gameweek,
            home.name AS home_team,
            away.name AS away_team
        ORDER BY f.kickoff_time ASC


            """

    # ----------------------------------------------------------------------
    # 7) BEST PLAYERS BY METRIC: Overall
    # ----------------------------------------------------------------------
    if intent == "best_players_by_metric" and not entities.get("position"):
        if entities.get("filter_value")!= None:
            val = entities["filter_value"]
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p:Player)-[pi:PLAYED_IN]->(f)
                WITH p, SUM(pi.{stat}) AS total_stat
                WHERE total_stat > {val}
                RETURN p.player_name AS player, total_stat
                ORDER BY total_stat DESC LIMIT {limit}
            """
        else:
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p:Player)-[pi:PLAYED_IN]->(f)
                RETURN p.player_name AS player, SUM(pi.{stat}) AS total_{stat}
                ORDER BY total_{stat} DESC LIMIT {limit}
            """

    # ----------------------------------------------------------------------
    # 8) BEST PLAYERS BY METRIC AND POSITION
    # ----------------------------------------------------------------------
    if intent == "best_players_by_metric" and entities.get("position"):
        position = entities["position"]
        if entities.get("filter_value")!= None:
            val = entities["filter_value"]
            return f"""
                MATCH (p:Player)-[:PLAYS_AS]->(pos:Position {{name:'{position}'}})
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p)-[pi:PLAYED_IN]->(f)
                WITH p, pos, SUM(pi.{stat}) AS total_stat
                WHERE total_stat > {val}
                RETURN p.player_name AS player, total_stat, pos.name AS position
                ORDER BY total_stat DESC LIMIT {limit}
            """
        else:
            return f"""
                MATCH (p:Player)-[:PLAYS_AS]->(pos:Position {{name:'{position}'}})
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p)-[pi:PLAYED_IN]->(f)
                RETURN p.player_name AS player, SUM(pi.{stat}) AS total_{stat}, pos.name AS position
                ORDER BY total_{stat} DESC LIMIT {limit}
            """

    # ----------------------------------------------------------------------
    # 9) BEST PLAYERS — FILTER WHERE STAT ABOVE VALUE
    # ----------------------------------------------------------------------
    if intent == "best_players_by_metric" and entities.get("filter_value")!= None:
        val = entities["filter_value"]
        return f"""
            MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
            MATCH (p:Player)-[pi:PLAYED_IN]->(f)
            WITH p, SUM(pi.{stat}) AS total_stat
            WHERE total_stat > {val}
            RETURN p.player_name AS player, total_stat
            ORDER BY total_stat DESC LIMIT {limit}
        """

    #10)when will player play against team
    if intent == "fixture_details" and player1 and team1:
        return f"""
            MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f:Fixture)
            MATCH (p:Player {{player_name:'{player1}'}})-[pi:PLAYED_IN]->(f)
            MATCH (f)-[:HAS_HOME_TEAM]->(h:Team),
                  (f)-[:HAS_AWAY_TEAM]->(a:Team)
            MATCH (p)-[:PLAYS_FOR]->(t:Team)
            WHERE h.name = '{team1}' OR a.name = '{team1}'
            RETURN p.player_name AS player, t.name AS team, f.fixture_number AS fixture, f.kickoff_time, g.GW_number AS gameweek, h, a
            ORDER BY f.kickoff_time ASC

        """
    
    #11)Worst players by metric
    if intent == "Worst_players_by_metric" and not entities.get("position"):
        if entities.get("filter_value")!= None:
            val = entities["filter_value"]
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p:Player)-[pi:PLAYED_IN]->(f)
                WITH p, SUM(pi.{stat}) AS total_stat
                WHERE total_stat > {val}
                RETURN p.player_name AS player, total_stat
                ORDER BY total_stat ASC LIMIT {limit}
            """
        else:
            return f"""
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p:Player)-[pi:PLAYED_IN]->(f)
                RETURN p.player_name AS player, SUM(pi.{stat}) AS total_{stat}
                ORDER BY total_{stat} ASC LIMIT {limit}
            """
    if intent == "Worst_players_by_metric" and entities.get("position"):
        position = entities["position"]
        if entities.get("filter_value")!= None:
            val = entities["filter_value"]
            return f"""
                MATCH (p:Player)-[:PLAYS_AS]->(pos:Position {{name:'{position}'}})
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p)-[pi:PLAYED_IN]->(f)
                WITH p, pos, SUM(pi.{stat}) AS total_stat
                WHERE total_stat > {val}
                RETURN p.player_name AS player, total_stat, pos.name AS position
                ORDER BY total_stat ASC LIMIT {limit}
            """
        else:
            return f"""
                MATCH (p:Player)-[:PLAYS_AS]->(pos:Position {{name:'{position}'}})
                MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
                MATCH (p)-[pi:PLAYED_IN]->(f)
                RETURN p.player_name AS player, SUM(pi.{stat}) AS total_{stat}, pos.name AS position
                ORDER BY total_{stat} ASC LIMIT {limit}
            """
    if intent == "Worst_players_by_metric" and entities.get("filter_value")!= None:
        val = entities["filter_value"]
        return f"""
            MATCH (:Season {{season_name:'{season}'}})-[:HAS_GW]->(g)-[:HAS_FIXTURE]->(f)
            MATCH (p:Player)-[pi:PLAYED_IN]->(f)
            WITH p, SUM(pi.{stat}) AS total_stat
            WHERE total_stat > {val}
            RETURN p.player_name AS player, total_stat
            ORDER BY total_stat ASC LIMIT {limit}
        """
    
    #12)PLAYER INFORMATION
    if intent == "player_information" and player1:
        return f"""
            MATCH (p:Player {{player_name:'{player1}'}})-[:PLAYS_FOR]->(t:Team),
                  (p)-[:PLAYS_AS]->(pos:Position)
            RETURN p.player_name AS player,
                   t.name AS team,
                   pos.name AS position
            """
                   

    # ----------------------------------------------------------------------
    # 10) FALLBACK
    # ----------------------------------------------------------------------
    return f"""MATCH (n:NonExistentLabel)
            RETURN n
            """


In [50]:
def format_query_result(intent: str, result: list, entities: dict = None) -> str:

    # Handle empty results
    if result is None or len(result) == 0:
        if intent == "fixture_details":
            team = entities.get("team1", "the team") if entities else "the team"
            return f"No fixtures found for {team}."
        elif intent == "best_players_by_metric":
            return "No players found matching the criteria."
        elif intent == "player_or_team_performance":
            entity_name = entities.get("player1") or entities.get("team1") if entities else None
            return f"No performance data found{' for ' + entity_name if entity_name else ''}."
        else:
            return "No results found."
    
    entities = entities or {}
    
    # ========================================================================
    # INTENT 1: FIXTURE_DETAILS
    # ========================================================================

    
    if intent == "fixture_details":
        if len(result) == 1:
            rec = result[0]
            # Check if player info exists
            player = rec.get('player')
            if player:
                home = rec.get('h', {}).get('name', rec.get('home_team', 'Unknown'))
                away = rec.get('a', {}).get('name', rec.get('away_team', 'Unknown'))
                kickoff = rec.get('f.kickoff_time', 'TBD')
                team = rec.get('team', 'Unknown Team')
                fixture_num = rec.get('fixture', rec.get('fixture_number', ''))
                if team == home:
                    opposing = away
                else:
                    opposing = home
                return f"{player} in team {team} will play against {opposing} at {kickoff if kickoff != 'TBD' else 'TBD'} (Fixture #{fixture_num})"
            else:
                home = rec.get('h', {}).get('name', rec.get('home_team', 'Unknown'))
                away = rec.get('a', {}).get('name', rec.get('away_team', 'Unknown'))
                kickoff = rec.get('f.kickoff_time', 'TBD')
                fixture_num = rec.get('fixture', '')
                return f"{home} will play against {away} at {kickoff if kickoff != 'TBD' else 'TBD'} (Fixture #{fixture_num})"

        else:
            response = "Upcoming Fixtures:\n\n"
            for i, rec in enumerate(result, 1):
                player = rec.get('player')
                team = rec.get('team', 'Unknown Team')
                home = rec.get('h', {}).get('name', rec.get('home_team', 'Unknown'))
                away = rec.get('a', {}).get('name', rec.get('away_team', 'Unknown'))
                kickoff = rec.get('f.kickoff_time', rec.get('kickoff_time', 'TBD'))
                gameweek = rec.get('gameweek', 'gw')
                fixture_num = rec.get('fixture', rec.get('fixture_number', ''))
                if player and team:
                    if team == home:
                        opposing = away
                    else:
                        opposing = home
                    response += f"{i}. {player} in team {team} will play against {opposing} in Gameweek {gameweek}"

                else:
                    response += f"{i}. {home} will play against {away} in Gameweek {gameweek}"

                if kickoff and kickoff != 'TBD':
                    response += f" at {kickoff} (Fixture #{fixture_num})"
                response += "\n"
            return response.strip()

    
    # ========================================================================
    # INTENT 2: BEST_PLAYERS_BY_METRIC
    # ========================================================================
    elif intent == "best_players_by_metric" or intent == "Worst_players_by_metric":
        stat_type = entities.get("stat_type", "total_points")
        position = entities.get("position")
        limit = entities.get("limit", len(result))
        
        # Convert stat field name to display name
        stat_names = {
            "goals_scored": "Goals", "assists": "Assists", "total_points": "Points",
            "bonus": "Bonus Points", "clean_sheets": "Clean Sheets", 
            "saves": "Saves", "minutes": "Minutes", "form": "Form"
        }
        stat_display = stat_names.get(stat_type, stat_type.replace('_', ' ').title())
        
        position_text = f" {position}" if position else ""
        response = f"**Top{position_text} Players by {stat_display}:**\n\n"
        
        for i, rec in enumerate(result[:limit], 1):
            player = rec.get('player', 'Unknown Player')
            
            # Try different possible field names for the stat value
            stat_value = (rec.get(f'total_{stat_type}') or 
                         rec.get('total_stat') or 
                         rec.get(stat_type) or 
                         rec.get('Total') or 0)
            
            pos = rec.get('position', '')
            pos_text = f" ({pos})" if pos else ""
            
            response += f"{i}. {player}{pos_text}: **{stat_value}** {stat_display}\n"
        
        return response.strip()
    
    # ========================================================================
    # INTENT 3: PLAYER_OR_TEAM_PERFORMANCE
    # ========================================================================
    elif intent == "player_or_team_performance":
        has_player1 = bool(entities.get("player1"))
        has_player2 = bool(entities.get("player2"))
        has_team1 = bool(entities.get("team1"))
        has_team2 = bool(entities.get("team2"))
        has_gw = bool(entities.get("gw_number"))
        
        stat_type = entities.get("stat_type", "total_points")
        stat_names = {
            "goals_scored": "Goals", "assists": "Assists", "total_points": "Points",
            "bonus": "Bonus Points", "clean_sheets": "Clean Sheets", 
            "saves": "Saves", "minutes": "Minutes", "form": "Form"
        }
        stat_display = stat_names.get(stat_type, stat_type.replace('_', ' ').title())
        
        # --------------------------------------------------------------------
        # SUB-CASE 1: Compare Two Players
        # --------------------------------------------------------------------
        if has_player1 and has_player2 and len(result) >= 2:
            player1 = entities["player1"]
            player2 = entities["player2"]
            
            p1_stat = 0
            p2_stat = 0
            
            for rec in result:
                player_name = rec.get('player', '')
                stat_value = rec.get(f'total_{stat_type}', rec.get(stat_type, 0))
                
                if player_name == player1:
                    p1_stat = stat_value
                elif player_name == player2:
                    p2_stat = stat_value
            
            response = f"**Player Comparison - {stat_display}:**\n\n"
            response += f"{player1}: **{p1_stat}** {stat_display}\n"
            response += f"{player2}: **{p2_stat}** {stat_display}\n\n"
            
            if p1_stat > p2_stat:
                diff = p1_stat - p2_stat
                response += f"{player1} has {diff} more {stat_display} than {player2}"
            elif p2_stat > p1_stat:
                diff = p2_stat - p1_stat
                response += f"{player2} has {diff} more {stat_display} than {player1}"
            else:
                response += f"Both players are equal in {stat_display}"
            
            return response
        
        # --------------------------------------------------------------------
        # SUB-CASE 2: Compare Two Teams
        # --------------------------------------------------------------------
        elif has_team1 and has_team2 and len(result) >= 2:
            team1 = entities["team1"]
            team2 = entities["team2"]
            
            t1_stat = 0
            t2_stat = 0
            
            for rec in result:
                team_name = rec.get('team', '')
                stat_value = rec.get(f'total_{stat_type}', rec.get(stat_type, 0))
                
                if team_name == team1:
                    t1_stat = stat_value
                elif team_name == team2:
                    t2_stat = stat_value
            
            response = f"**Team Comparison - {stat_display}:**\n\n"
            response += f"{team1}: **{t1_stat}** {stat_display}\n"
            response += f"{team2}: **{t2_stat}** {stat_display}\n\n"
            
            if t1_stat > t2_stat:
                diff = t1_stat - t2_stat
                response += f"{team1} has {diff} more {stat_display} than {team2}"
            elif t2_stat > t1_stat:
                diff = t2_stat - t1_stat
                response += f"{team2} has {diff} more {stat_display} than {team1}"
            else:
                response += f"Both teams are equal in {stat_display}"
            
            return response
        
        # --------------------------------------------------------------------
        # SUB-CASE 3: Single Player in Specific Gameweek
        # --------------------------------------------------------------------
        elif has_player1 and has_gw and len(result) >= 1:
            rec = result[0]
            player = rec.get('player', entities.get('player1', 'Unknown Player'))
            gw = rec.get('gameweek', entities.get('gw_number', '?'))
            
            stat_value = rec.get(stat_type, rec.get(f'total_{stat_type}', 0))
            
            response = (
                f"{player} played in Gameweek {gw}"
                + (f", scoring {rec.get('goals_scored', rec.get('goals', 0))} goals" if 'goals_scored' in rec or 'goals' in rec else "")
                + (f", providing {rec['assists']} assists" if 'assists' in rec else "")
                + (f", playing {rec['minutes']} minutes" if 'minutes' in rec else "")
                + f", with {stat_value} {stat_display}."
            )

            return response.strip()
        
        # --------------------------------------------------------------------
        # SUB-CASE 4: Single Team in Specific Gameweek
        # --------------------------------------------------------------------
        elif has_team1 and has_gw and len(result) >= 1:
            rec = result[0]
            team = rec.get('team', entities.get('team1', 'Unknown Team'))
            gw = rec.get('gameweek', entities.get('gw_number', '?'))
            
            stat_value = rec.get(f'total_{stat_type}', rec.get(stat_type, 0))
            
            response = f"{team} in Gameweek {gw} has a total of {stat_value} {stat_display}."

            
            return response.strip()
        
        # --------------------------------------------------------------------
        # SUB-CASE 5: Single Player Full Season
        # --------------------------------------------------------------------
        elif has_player1 and not has_gw:
            rec = result[0]
            player = rec.get('player', entities.get('player1', 'Unknown Player'))
            season = entities.get('season', '2022-23')
            
            stat_value = rec.get(f'total_{stat_type}', rec.get(stat_type, 0))
            
            response = (
            f"{player} in the {season} season has a total of {stat_value} {stat_display}"
            + (f", scoring {rec.get('total_goals_scored', rec.get('goals', 0))} goals" 
            if 'total_goals_scored' in rec or 'goals' in rec else "")
            + (f", providing {rec.get('total_assists', rec.get('assists', 0))} assists" 
            if 'total_assists' in rec or 'assists' in rec else "")
            + (f", playing {rec.get('total_minutes', rec.get('minutes', 0))} minutes" 
            if 'total_minutes' in rec or 'minutes' in rec else "")
            + "."
            )

            
            return response.strip()
        
        # --------------------------------------------------------------------
        # SUB-CASE 6: Single Team Full Season
        # --------------------------------------------------------------------
        elif has_team1 and not has_gw:
            rec = result[0]
            team = rec.get('team', entities.get('team1', 'Unknown Team'))
            season = entities.get('season', '2022-23')
            
            stat_value = rec.get(f'total_{stat_type}', rec.get(stat_type, 0))
            
            response = f"{team} in the {season} season has a total of {stat_value} {stat_display}."
            
            return response.strip()
        
        # --------------------------------------------------------------------
        # FALLBACK: Generic Performance
        # --------------------------------------------------------------------
        else:
            response = "**Performance Results:**\n\n"
            
            for i, rec in enumerate(result, 1):
                name = rec.get('player', rec.get('team', f'Entity {i}'))
                
                stat_value = None
                for key in [f'total_{stat_type}', stat_type, 'total_points', 'points']:
                    if key in rec:
                        stat_value = rec[key]
                        break
                
                if stat_value is not None:
                    response += f"{i}. {name}: {stat_value} {stat_display}\n"
                else:
                    response += f"{i}. {name}\n"
            
            return response.strip()
        
    #intent 4: PLAYER_INFORMATION
    if intent == "player_information":
        rec = result[0]
        player = rec.get('player', entities.get('player1', 'Unknown Player'))
        team = rec.get('team', 'Unknown Team')
        position = rec.get('position', 'Unknown Position')
        
        return f"{player} plays for {team} as a {position}."
    
    # ========================================================================
    # FALLBACK for unknown intents
    # ========================================================================
    else:
        return "Query executed. Results retrieved."

In [86]:
# Array of test prompts
test_prompts = [
 
    # Single Player Performance
    "Show me how Mohamed Salah performed in gameweek 5, including total_points.",
    "Show me the total goals scored by Kevin De Bruyne for the 2022-23 season.",
    
    # # Single Team Performance
    "Give me Arsenal's total goals in gameweek 10.",
    "What is Liverpool's total bonus points for the 2022-23 season?",
    
    # # Compare Two Players
    "Compare Mohamed Salah and Erling Haaland in gameweek 8 season 2021-22 for total points.",
    "Compare Erling Haaland and Mohamed Salah for the 2022-23 season by goals.",
    
    # # Compare Two Teams
    "Compare Liverpool and Chelsea in gameweek 12 for total points.",
    "Compare Liverpool and Chelsea in gameweek 12 for goals scored.",
    
    # # # Fixtures
    "Show me Man City's next fixture in gameweek 15.",
    "When do Arsenal and Man city play each other?",
    "When does Harry Kane play against Liverpool?",
    
    # # # Best Players
    "Who are the top players by total points in the 2021-22 season?",
    "Who are the top 5 forwards by total points in the 2022-23 season?",
    "Who are the top 3 midfielders by assists?.",
    "Who are the top 2 players with goals above 33 in the 2022-23 season?",
    "Who is the best goalkeeper in the 2022-23 season?",
    "Who scored the most goals in the 2021-22 season?",

    # # # Worst Players
    "Who are the worst players by total points in the 2021-22 season?",
    "Who are the bottom 5 defenders by total points in the 2022-23 season?",
    "Who are the bottom 3 midfielders by assists above 0?.",
    "Who are the bottom 2 players with goals below 2 in the 2022-23 season?",

    # # Player Information
    "What team does Harry Kane play for?",
    "What position does Mohamed Salah play?",
    "What team does Kevin De Bruyne play for?",
    "Which position does Virgil van Dijk play as.",
]



# Iterate over prompts to test the pipeline
for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    
    intent = classify_fpl_intents(prompt)
    print("Classified Intent:", intent)
    
    entities = extract_fpl_entities(prompt)
    # print("Extracted Entities:", entities)
    
    cypher_query = get_fpl_cypher_query(intent, entities)
    # print("Cypher Query:\n", cypher_query)
    
    # Execute query
    query_result = graph.query(cypher_query)
    # print("Query Result:\n", query_result)
    
    # Format response
    formatted_response = format_query_result(intent, query_result, entities)
    print("Formatted Response:\n", formatted_response)



Prompt: Show me how Mohamed Salah performed in gameweek 5, including total_points.
Classified Intent: player_or_team_performance
Formatted Response:
 Mohamed Salah played in Gameweek 5, with 10 Points.

Prompt: Show me the total goals scored by Kevin De Bruyne for the 2022-23 season.
Classified Intent: player_or_team_performance


Formatted Response:
 Kevin De Bruyne in the 2022-23 season has a total of 7 Goals, scoring 7 goals.

Prompt: Give me Arsenal's total goals in gameweek 10.
Classified Intent: player_or_team_performance
Formatted Response:
 Arsenal in Gameweek 10 has a total of 3 Goals.

Prompt: What is Liverpool's total bonus points for the 2022-23 season?
Classified Intent: player_or_team_performance
Formatted Response:
 Liverpool in the 2022-23 season has a total of 144 Bonus Points.

Prompt: Compare Mohamed Salah and Erling Haaland in gameweek 8 season 2021-22 for total points.
Classified Intent: player_or_team_performance
Formatted Response:
 **Player Comparison - Points:**

Erling Haaland: **0** Points
Mohamed Salah: **13** Points

Mohamed Salah has 13 more Points than Erling Haaland

Prompt: Compare Erling Haaland and Mohamed Salah for the 2022-23 season by goals.
Classified Intent: player_or_team_performance
Formatted Response:
 **Player Comparison - Goals:**

Erling Haaland: **36** Goals
Mohamed

## Embeddings

For creation of feature vector embeddings, text descriptions where constructed from the numerical features and where saved to the KG on the PLAYED_IN relation.

Then embeddings where created for these texts

In [7]:
def reset_vector_index(index_name, label, property_name, dimension):
    try:
        graph.query(f"DROP INDEX {index_name} IF EXISTS")
    except Exception as e:
        print(f"   - Warning dropping index: {e}")

    # 2. Create the new index
    create_query = f"""
    CREATE VECTOR INDEX {index_name}
    FOR (n:{label})
    ON (n.{property_name})
    OPTIONS {{indexConfig: {{
      `vector.dimensions`: {dimension},
      `vector.similarity_function`: 'cosine'
    }}}}
    """
    
    try:
        graph.query(create_query)
    except Exception as e:
        print(f"   - ❌ Error creating index: {e}")

In [8]:
def reset_relationship_vector_index(graph, index_name, rel_type, property_name, dimension):
    # 1. Drop existing index
    try:
        graph.query(f"DROP INDEX {index_name} IF EXISTS")
    except Exception as e:
        print(f"   - Warning dropping index: {e}")

    # 2. Create the new RELATIONSHIP index
    # Note the syntax change: FOR ()-[r:{rel_type}]-()
    create_query = f"""
    CREATE VECTOR INDEX {index_name}
    FOR ()-[r:{rel_type}]-()
    ON (r.{property_name})
    OPTIONS {{indexConfig: {{
      `vector.dimensions`: {dimension},
      `vector.similarity_function`: 'cosine'
    }}}}
    """
    
    try:
        graph.query(create_query)
        print(f"   - ✅ Created relationship vector index: {index_name}")
    except Exception as e:
        print(f"   - ❌ Error creating index: {e}")

In [9]:
def generate_player_feature_vector_embeddings(graph: Neo4jGraph, embedding_model, model_name: str):

    fetch_query = """
    MATCH (p:Player)
    WHERE p.fpl_features IS NOT NULL
    RETURN p.player_name AS name, p.fpl_features AS text
    """
    data = graph.query(fetch_query)

    update_query = f"""
    MATCH (p:Player {{player_name: $name}})
    SET p.feature_vector_embedding_{model_name} = $embedding
    """

    for row in data:
        vector = embedding_model.embed_query(row['text'])
        graph.query(update_query, {'name': row['name'], 'embedding': vector})

    reset_vector_index(
        index_name="player_feature_index_" + model_name, 
        label="Player", 
        property_name="feature_vector_embedding_" + model_name, 
        dimension=len(vector)
    )

    print("Feature Vector Embeddings generation complete!")

In [10]:
def generate_team_feature_vector_embeddings(graph: Neo4jGraph, embedding_model, model_name: str):

    fetch_query = """
    MATCH (t:Team)
    WHERE t.team_description IS NOT NULL
    RETURN t.name AS name, t.team_description AS text
    """
    data = graph.query(fetch_query)

    update_query = f"""
    MATCH (t:Team {{name: $name}})
    SET t.feature_vector_embedding_{model_name} = $embedding
    """

    for row in data:
        vector = embedding_model.embed_query(row['text'])
        graph.query(update_query, {'name': row['name'], 'embedding': vector})

    reset_vector_index(
        index_name="team_feature_index_" + model_name, 
        label="Team", 
        property_name="feature_vector_embedding_" + model_name, 
        dimension=len(vector)
    )

    print("Team Feature Vector Embeddings generation complete!")

In [11]:
def generate_feature_vector_embeddings(graph, embedding_model, model_name: str, batch_size=500):
    print(f"Fetching data for model: {model_name}...")
    
    fetch_query = """
    MATCH (p:Player)-[pf:PLAYED_IN]->(f:Fixture)
    WHERE pf.feature_text IS NOT NULL
    RETURN p.player_element AS element, 
           p.player_name AS name, 
           f.season AS season, 
           f.fixture_number AS fixture_id, 
           pf.feature_text AS text
    """
    data = graph.query(fetch_query)
    
    print(f"Found {len(data)} records. Generating embeddings...")

    # Prepare batch list
    batch = []
    total_processed = 0

    # The update query expects a list of objects called $batch_data
    update_query = f"""
    UNWIND $batch_data AS row
    MATCH (p:Player {{player_name: row.name, player_element: row.element}})-[pf:PLAYED_IN]->(f:Fixture {{season: row.season, fixture_number: row.fixture_id}})
    SET pf.feature_vector_embedding_{model_name} = row.embedding
    """

    for row in data:
        # Generate embedding
        vector = embedding_model.embed_query(row['text'])
        
        # Add to batch
        batch.append({
            'name': row['name'],
            'element': row['element'],
            'season': row['season'],
            'fixture_id': row['fixture_id'],
            'embedding': vector
        })

        # Execute if batch is full
        if len(batch) >= batch_size:
            graph.query(update_query, {'batch_data': batch})
            total_processed += len(batch)
            print(f"   - Processed {total_processed}/{len(data)}...")
            batch = [] # Reset batch

    # Process remaining records
    if batch:
        graph.query(update_query, {'batch_data': batch})
        total_processed += len(batch)

    if total_processed > 0:
        sample_dim = len(batch[0]['embedding']) if batch else len(vector)
        reset_relationship_vector_index(
            graph=graph,
            index_name="pf_feature_index_" + model_name, 
            rel_type="PLAYED_IN", 
            property_name="feature_vector_embedding_" + model_name, 
            dimension=sample_dim
        )

    print("Feature Vector Embeddings generation complete!")

In [12]:
from langchain_huggingface import HuggingFaceEmbeddings

def generate_all_embeddings(graph: Neo4jGraph):
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cuda"}
    )
    generate_feature_vector_embeddings(graph, embedding_model, "MiniLM")

    embedding_model2 = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cuda"}
    )
    generate_feature_vector_embeddings(graph, embedding_model2, "MPNet")

    embedding_model3 = HuggingFaceEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={"device": "cuda"}
    )
    generate_feature_vector_embeddings(graph, embedding_model3, "BGE")


  from .autonotebook import tqdm as notebook_tqdm


In [52]:
mini_lm = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cuda"}
    )
mpnet = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cuda"}
    )
bge = HuggingFaceEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={"device": "cuda"}
    )

In [15]:
generate_all_embeddings(graph)

Fetching data for model: MiniLM...
Found 51952 records. Generating embeddings...
   - Processed 500/51952...
   - Processed 1000/51952...
   - Processed 1500/51952...
   - Processed 2000/51952...
   - Processed 2500/51952...
   - Processed 3000/51952...
   - Processed 3500/51952...
   - Processed 4000/51952...
   - Processed 4500/51952...
   - Processed 5000/51952...
   - Processed 5500/51952...
   - Processed 6000/51952...
   - Processed 6500/51952...
   - Processed 7000/51952...
   - Processed 7500/51952...
   - Processed 8000/51952...
   - Processed 8500/51952...
   - Processed 9000/51952...
   - Processed 9500/51952...
   - Processed 10000/51952...
   - Processed 10500/51952...
   - Processed 11000/51952...
   - Processed 11500/51952...
   - Processed 12000/51952...
   - Processed 12500/51952...
   - Processed 13000/51952...
   - Processed 13500/51952...
   - Processed 14000/51952...
   - Processed 14500/51952...
   - Processed 15000/51952...
   - Processed 15500/51952...
   - Proc

In [42]:
from typing import List, Tuple

def retrieve_embedding_search(query: str, embeddings_model, model_name: str, graph):
    query_vector = embeddings_model.embed_query(query)
    index_name = "pf_feature_index_" + model_name

    search_query = """
    CALL db.index.vector.queryRelationships($index_name, 3, $embedding)
    YIELD relationship, score
    RETURN relationship.feature_text AS text, score
    """
    
    try:
        results = graph.query(search_query, {
            "index_name": index_name, 
            "embedding": query_vector
        })
        
        text_list = [row['text'] for row in results]
        if text_list:
            context_string = "\n\n---\n\n".join(text_list)
        else:
            context_string = "No relevant context found."
            
        return context_string

    except Exception as e:
        print(f"Error: {e}")
        return "Error retrieving context."

### Embedding models comparison

In [53]:
MODELS = {
    "MiniLM": {
        "model": mini_lm,
        "index_name": "pf_feature_index_MiniLM"
    },
    "MPNet": {
        "model": mpnet,
        "index_name": "pf_feature_index_MPNet"
    },
    "BGE": {
        "model": bge,
        "index_name": "pf_feature_index_BGE"
    }
}

In [64]:
sample_queries = [
    "What is Mohamed Salah total points in gameweek 38 2022-23 season?",
    "What position does Kevin De Bruyne play?",
    "How many assists did Harry Kane have in gameweek 10?",
    "How did Salah perform against Spurs?",
    "Which defender had the most clean sheets in Gameweek 12?",
    "Who are the best cheap value midfielders?"
]

In [59]:
import time
def run_speed_test():
    test_queries = sample_queries * 10

    print("\n=== 1. Embedding Speed Test (60 queries) ===")
    results = {}
    
    for name, config in MODELS.items():
        start = time.time()
        _ = config['model'].embed_documents(test_queries)
        duration = time.time() - start
        results[name] = duration
        print(f"{name:<20}: {duration:.4f} seconds")

    # Calculate Ratio
    names = list(MODELS.keys())
    t1 = results[names[0]]
    t2 = results[names[1]]
    t3 = results[names[2]]
    print(f"Ratio: {names[0]} is {t1/t3:.2f}x slower than {names[2]} and {names[1]} is {t2/t3:.2f}x slower than {names[2]}")

run_speed_test()


=== 1. Embedding Speed Test (60 queries) ===
MiniLM              : 0.1443 seconds
MPNet               : 0.2810 seconds
BGE                 : 0.0452 seconds
Ratio: MiniLM is 3.19x slower than BGE and MPNet is 6.22x slower than BGE


In [60]:
def get_top_relationship_match(graph, query, model, index_name):
    try:
        query_vector = model.embed_query(query)
        
        cypher = """
        CALL db.index.vector.queryRelationships($index_name, 1, $embedding)
        YIELD relationship, score
        RETURN relationship.feature_text AS text, score
        """
        
        result = graph.query(cypher, {
            "index_name": index_name,
            "embedding": query_vector
        })
        
        if result:
            return result[0]['text'], result[0]['score']
        return "No match found", 0.0
        
    except Exception as e:
        return f"Error: {str(e)}", 0.0

In [65]:
def run_quality_test(graph):
    print("\n=== 2. Quality Comparison (Human Eye Test) ===")

    for q in sample_queries:
        print(f"QUERY: {q}")
        print("-" * 100)
        print(f"{'Model':<15} | {'Score':<8} | {'Retrieved Snippet (First 120 chars)':<60}")
        print("-" * 100)
        
        for name, config in MODELS.items():
            text, score = get_top_relationship_match(
                graph, 
                q, 
                config['model'], 
                config['index_name']
            )
            
            clean_text = text.replace('\n', ' ')
            # snippet = (clean_text[:115] + '...') if len(clean_text) > 115 else clean_text
            
            print(f"{name:<15} | {score:.4f}   | {clean_text:<60}")
        
        print("\n")

run_quality_test(graph)


=== 2. Quality Comparison (Human Eye Test) ===
QUERY: What is Mohamed Salah total points in gameweek 38 2022-23 season?
----------------------------------------------------------------------------------------------------
Model           | Score    | Retrieved Snippet (First 120 chars)                         
----------------------------------------------------------------------------------------------------
MiniLM          | 0.8082   | Player ID: 233, Player Name: Mohamed Salah, Season: 2021-22, Team: Liverpool, Gameweek: 9, Position: MID, Fixture: 88, Home Team: Man Utd, Away Team: Liverpool, Kickoff Time: 2021-10-24 15:30:00+00:00, Total Points This Gameweek: 24, Goals Scored This Gameweek: 3, Assists This Gameweek: 1, Minutes Played This Gameweek: 90, Bonus Points This Gameweek: 3, Clean Sheets This Gameweek: 1, Yellow Cards This Gameweek: 0, Red Cards This Gameweek: 0, Own Goals This Gameweek: 0, Penalties Saved This Gameweek: 0, Penalties Missed This Gameweek: 0, Saves: 0, Form 

In [49]:
from langchain_community.vectorstores import Neo4jVector

def get_top_result(query, model, index_name, embedding_prop):
    store = Neo4jVector.from_existing_index(
        embedding=model,
        url= config.get('URI'), username=config.get('USERNAME'), password=config.get('PASSWORD'),
        index_name=index_name,
        node_label="Player",
        embedding_node_property=embedding_prop,
        text_node_property= "fpl_features"
    )
    
    result = store.similarity_search_with_score(query, k=1)
    return result[0]

test_qs = [
    "What is Mohamed Salah total points in 2022-23 season?",
    "What position does Kevin De Bruyne play?",
    "How many assists did Harry Kane have in gameweek 10?",
]

print("\nQuality Comparison:")
print(f"{'Query':<50} | {'Model':<10} | {'Top Match':<20} | {'Score':<5}")
print("-" * 95)

for q in test_qs:
    # 1. MiniLM
    doc_a, score_a = get_top_result(q, model_minilm, "player_feature_index_MiniLM", "feature_vector_embedding_MiniLM")
    name_a = doc_a.metadata['player_name']
    
    # 2. MPNet
    doc_b, score_b = get_top_result(q, model_mpnet, "player_feature_index_MPNet", "feature_vector_embedding_MPNet")
    name_b = doc_b.metadata['player_name']
    
    print(f"{q[:47]+'...':<50} | MiniLM     | {name_a:<20} | {score_a:.4f}")
    print(f"{'':<50} | MPNet      | {name_b:<20} | {score_b:.4f}")
    print("-" * 95)


Quality Comparison:
Query                                              | Model      | Top Match            | Score
-----------------------------------------------------------------------------------------------
What is Mohamed Salah total points in 2022-23 s... | MiniLM     | Mohamed Salah        | 0.8681
                                                   | MPNet      | Mohamed Naser El Sayed Elneny | 0.7695
-----------------------------------------------------------------------------------------------
What position does Kevin De Bruyne play?...        | MiniLM     | Kevin De Bruyne      | 0.8575
                                                   | MPNet      | Cheikhou Kouyaté     | 0.7303
-----------------------------------------------------------------------------------------------
How many assists did Harry Kane have in gamewee... | MiniLM     | Daniel Castelo Podence | 0.7347
                                                   | MPNet      | Harry Kane           | 0.8144
---------

### LLM and comparisons

In [37]:
%pip install -U langchain langchain-community langchain-core pydantic typing-extensions

Collecting langchain
  Downloading langchain-1.1.3-py3-none-any.whl (102 kB)
     ---------------------------------------- 0.0/102.2 kB ? eta -:--:--
     ----------- --------------------------- 30.7/102.2 kB 1.3 MB/s eta 0:00:01
     ----------------------------------- --- 92.2/102.2 kB 1.3 MB/s eta 0:00:01
     -------------------------------------- 102.2/102.2 kB 1.5 MB/s eta 0:00:00
Collecting langchain-core
  Downloading langchain_core-1.1.3-py3-none-any.whl (475 kB)
     ---------------------------------------- 0.0/475.3 kB ? eta -:--:--
     ------- ------------------------------- 92.2/475.3 kB 2.6 MB/s eta 0:00:01
     ------------- ------------------------ 174.1/475.3 kB 2.1 MB/s eta 0:00:01
     ---------------------- --------------- 276.5/475.3 kB 2.4 MB/s eta 0:00:01
     ----------------------------- -------- 368.6/475.3 kB 2.6 MB/s eta 0:00:01
     ---------------------------------- --- 430.1/475.3 kB 2.7 MB/s eta 0:00:01
     -------------------------------------  471.0/


[notice] A new release of pip is available: 23.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [38]:
%pip install langchain langchain-community langchain-core langchain-huggingface




[notice] A new release of pip is available: 23.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip





In [39]:
%pip install langchain-classic

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [54]:
from langchain_core.language_models import LLM
from typing import Optional, List, Any
from pydantic import Field

class GemmaLangChainWrapper(LLM):
    client: Any = Field(...)
    max_tokens: int = 500
    
    @property
    def _llm_type(self) -> str:
        return "gemma_hf_api"
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response = self.client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.max_tokens,
            temperature=0.2 
        )
        return response.choices[0].message["content"]


In [55]:
class LlamaLangChainWrapper(LLM):
    client: Any = Field(...)
    max_tokens: int = 500
    
    @property
    def _llm_type(self) -> str:
        return "llama_hf_api"
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response = self.client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.max_tokens,
            temperature=0.2 
        )
        return response.choices[0].message["content"]

In [56]:
class MistralLangChainWrapper(LLM):
    client: Any = Field(...)
    max_tokens: int = 500
    
    @property
    def _llm_type(self) -> str:
        return "mistral_hf_api"
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response = self.client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self.max_tokens,
            temperature=0.2 
        )
        return response.choices[0].message["content"]

In [81]:
from langchain_huggingface import HuggingFaceEndpoint
from huggingface_hub import InferenceClient

HF_TOKEN = config.get('HF_TOKEN')
print("Initializing models...")
# Gemma
gemma_client = InferenceClient(model="google/gemma-2-2b-it", token=HF_TOKEN)
gemma_llm = GemmaLangChainWrapper(client=gemma_client, max_tokens=500)

# Llama
llama_client = InferenceClient(model="meta-llama/Llama-3.2-3B-Instruct", token=HF_TOKEN)
llama_llm = LlamaLangChainWrapper(client=llama_client, max_tokens=500)

# Mistral
mistral_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=HF_TOKEN)
mistral_llm = MistralLangChainWrapper(client=mistral_client, max_tokens=500)

models = {
        "Gemma-2-2B": gemma_llm,
        "Llama-3.2-3B": llama_llm,
        "Mistral-7B": mistral_llm
    }

Initializing models...


In [92]:
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate

def rag_pipline(llm, embedding_model, embedding_model_name, query):
    # 1. Retrieve from KG via Cypher
    intent = classify_fpl_intents(query)
    entities = extract_fpl_entities(query)
    cypher_query = get_fpl_cypher_query(intent, entities)
    cypher_result = graph.query(cypher_query)
    formatted_cypher = format_query_result(intent, cypher_result, entities)

    embedding_context = retrieve_embedding_search(query, embedding_model, embedding_model_name)

    # 3. Combine Contexts
    combined_context = f"Cypher Results:\n{formatted_cypher}\n\nEmbedding Results:\n{embedding_context}"
    print("Combined Context:\n", combined_context)

    # 4. Create Prompt
    prompt = ChatPromptTemplate.from_template("""
    You are an expert Fantasy Premier League assistant.

    Use the context below to answer the user's question.

    <context>
    {context}
    </context>

    Question: {input}
    """)

    document_chain = create_stuff_documents_chain(llm, prompt)

    class DummyRetriever(BaseRetriever):
        def _get_relevant_documents(self, query: str):
            return [Document(page_content=combined_context)]

    dummy_retriever = DummyRetriever()
    qa_chain = create_retrieval_chain(dummy_retriever, document_chain)

    return qa_chain

In [46]:
query = "When does Arsenal play against Liverpool in season 2022-23?"
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
embedding_model_name = "MPNet"
rag_chain = rag_pipline(
    llm=gemma_llm,
    embedding_model=embedding_model,
    embedding_model_name=embedding_model_name,
    query=query
)
response = rag_chain.invoke({"input": query})
print(response["answer"])

Combined Context:
 Cypher Results:
Upcoming Fixtures:

1. Arsenal will play against Liverpool in Gameweek 10 at 2022-10-09 15:30:00+00:00 (Fixture #91)
2. Liverpool will play against Arsenal in Gameweek 30 at 2023-04-09 15:30:00+00:00 (Fixture #296)

Embedding Results:
[Team] Team: Arsenal. 2021-22 - Fixtures: 38, Goals: 107, Assists: 89. F1 GW1: vs Brentford (Away) G:2 A:1. F18 GW2: vs Chelsea (Home) G:2 A:2. F24 GW3: vs Man City (Away) G:5 A:5. F31 GW4: vs Norwich (Home) G:1 A:1. F43 GW5: vs Burnley (Away) G:1 A:1. F51 GW6: vs Spurs (Home) G:4 A:3. F61 GW7: vs Brighton (Away) G:0 A:0. F71 GW8: vs Crystal Palace (Home) G:4 A:4. F81 GW9: vs Aston Villa (Home) G:4 A:3. F93 GW10: vs Leicester (Away) G:2 A:1. F101 GW11: vs Watford (Home) G:1 A:0. F114 GW12: vs Liverpool (Away) G:4 A:3. F121 GW13: vs Newcastle (Home) G:2 A:2. F134 GW14: vs Man Utd (Away) G:5 A:5. F142 GW15: vs Everton (Away) G:3 A:3. F161 GW17: vs West Ham (Home) G:2 A:2. F187 GW19: vs Norwich (Away) G:5 A:4. F201 GW21: vs

In [73]:
%pip install protobuf sentencepiece

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [82]:
from transformers import AutoTokenizer

print("Loading model tokenizers...")
model_tokenizers = {
    "Gemma-2-2B": AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=HF_TOKEN),
    "Llama-3.2-3B": AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", token=HF_TOKEN),
    "Mistral-7B": AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", token=HF_TOKEN)
}

Loading model tokenizers...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [85]:
def get_token_count(text: str, model_name: str) -> int:
    tokenizer = model_tokenizers.get(model_name)
    
    if tokenizer:
        return len(tokenizer.encode(text))
    
    return len(text) // 4

In [83]:
def get_manual_evaluation(query: str, response: str, expected: str, context: str) -> dict:
    print("\n" + "="*80)
    print("MANUAL EVALUATION REQUIRED")
    print("="*80)
    print(f"\nQuery: {query}")
    print(f"\nExpected: {expected}")
    print(f"\nResponse: {response}")
    print(f"\nContext Preview: {str(context)[:200]}...")
    print("\nRate 1-5:")
    
    try:
        return {
            "accuracy": int(input("Accuracy (correct answer): ")),
            "relevance": int(input("Relevance (addresses query): ")),
            "completeness": int(input("Completeness (all info): ")),
            "naturalness": int(input("Naturalness (fluent): ")),
            "correctness": int(input("Correctness (factual): "))
        }
    except:
        print("Invalid input, defaulting to 3")
        return {"accuracy": 3, "relevance": 3, "completeness": 3, "naturalness": 3, "correctness": 3}

In [88]:
test_cases = [
    {
        "query": "Show me the total goals scored by Kevin De Bruyne for the 2022-23 season.",
        "expected": "Kevin De Bruyne total goals scored in 2022-23 season is 7.",
        "category": "player_stats",
    },
    {
        "query": "Compare Mohamed Salah and Harry Kane in gameweek 8 season 2021-22 for total points.",
        "expected": "Mohamed Salah had 13 points while Harry Kane had 12 points in gameweek 8 of the 2021-22 season.",
        "category": "player_comparison",
    },
    {
        "query": "Who are the top 5 forwards by total points in the 2022-23 season?",
        "expected": "Top 5 forwards by total points in 2022-23 season are Erling Haaland (272), Harry Kane (263), Ivan Toney (182), Ollie Watkins (175), Callum Wilson (157).",
        "category": "best_players",
    },
    {
        "query": "When do Arsenal and Man city play each other in season 2022-23?",
        "expected": "1. Arsenal will play against Man City in Gameweek 23 at 2023-02-15 19:30:00+00:00 (Fixture #111)\n2. Man City will play against Arsenal in Gameweek 33 at 2023-04-26 19:00:00+00:00 (Fixture #329)",
        "category": "fixture_details",
    }
]

In [98]:
import pandas as pd
from datetime import datetime

def evaluate_models_on_tests(models: dict, test_cases: list):
    all_results = []
    AUTO_EVALUATE = False # Set to False for manual input
    
    print(f"\nStarting evaluation of {len(models)} models on {len(test_cases)} test cases...")

    for model_name, llm_instance in models.items():
        print(f"\n\n{'#'*60}")
        print(f"EVALUATING MODEL: {model_name}")
        print(f"{'#'*60}")
        
        for test in test_cases:
            print(f"\n>>> Test: {test['query']}")
            
            try:
                start_time = time.time()
                
                rag_chain = rag_pipline(
                    llm=llm_instance,
                    embedding_model=embedding_model,
                    embedding_model_name=embedding_model_name,
                    query=test['query']
                )
                
                result = rag_chain.invoke({"input": test['query']})
                response_text = result["answer"]
                context_text = result.get("context", "")
                
                end_time = time.time()
                duration = end_time - start_time
                
                output_tokens = get_token_count(response_text, model_name)
                input_tokens = get_token_count(test['query'] + str(context_text), model_name)
                total_tokens = input_tokens + output_tokens

                # Calculate Tokens Per Second (TPS) - A great metric for comparison!
                tps = output_tokens / duration if duration > 0 else 0
                
                # D. Qualitative Scoring
                if AUTO_EVALUATE:
                    scores = {"accuracy": 3, "relevance": 3, "completeness": 3, "naturalness": 3, "correctness": 3}
                else:
                    scores = get_manual_evaluation(test['query'], response_text, test['expected'], context_text)
                
                # E. Store Result
                record = {
                    "Model": model_name,
                    "Query": test['query'],
                    "Category": test['category'],
                    "Tokens (Out)": output_tokens,
                    "Tokens (Total)": total_tokens,
                    "Speed (TPS)": round(tps, 2),
                    "Response Time (s)": round(duration, 2),
                    "Response": response_text,
                    **scores, # Unpack score dict
                    "Avg Quality": sum(scores.values()) / 5,
                    "Timestamp": datetime.now().isoformat()
                }
                
                all_results.append(record)
                print(f"Done in {duration:.2f}s | Quality: {record['Avg Quality']}")
                
            except Exception as e:
                print(f"ERROR: {str(e)}")

    if not all_results:
        print("No results generated.")
        return

    df = pd.DataFrame(all_results)
    
    df.to_csv("evaluation_results.csv", index=False)
    print(f"\nSaved {len(df)} results to evaluation_results.csv")

In [97]:
evaluate_models_on_tests(models, test_cases)


Starting evaluation of 3 models on 4 test cases...


############################################################
EVALUATING MODEL: Gemma-2-2B
############################################################

>>> Test: Show me the total goals scored by Kevin De Bruyne for the 2022-23 season.
Combined Context:
 Cypher Results:
Kevin De Bruyne in the 2022-23 season has a total of 7 Goals, scoring 7 goals.

Embedding Results:
[Player] Player Profile: Kevin De Bruyne is a MID. Performance: They accumulated 196 total points, scoring 15 goals and providing 8 assists. Advanced Metrics: They had an average ICT Index of 7.8 (Influence: 25.6, Creativity: 30.2, Threat: 22.2). Defensive & Discipline: They kept 13 clean sheets, made 0 saves, and received 2 yellow cards. Form & Impact: Their average form was 0.5 and they earned 33 bonus points.
---
[Player] Player Profile: Kevin De Bruyne is a MID. Performance: They accumulated 183 total points, scoring 7 goals and providing 18 assists. Advanced Metric

KeyError: "Column(s) ['Accuracy', 'Tokens'] do not exist"

In [102]:
df = pd.read_csv("evaluation_results.csv")
print("\n" + "="*70)
print("SUMMARY REPORT")
print("="*70)

summary = df.groupby("Model").agg({
    "Response Time (s)": "mean",
    "Tokens (Total)": "mean",
    "accuracy": "mean",
    "relevance": "mean",
    "completeness": "mean",
    "naturalness": "mean",
    "correctness": "mean",
    "Avg Quality": "mean"
}).round(2)

print("\nPer Model Performance:")
print(summary)

best_model = summary["Avg Quality"].idxmax()
fastest_model = summary["Response Time (s)"].idxmin()

print("-" * 70)
print(f"Best Quality Model: {best_model}")
print(f"Fastest Model:      {fastest_model}")
print("-" * 70)


SUMMARY REPORT

Per Model Performance:
              Response Time (s)  Tokens (Total)  accuracy  relevance  \
Model                                                                  
Gemma-2-2B                10.20         2141.75      5.00       5.00   
Llama-3.2-3B              10.42         1929.25      4.25       4.50   
Mistral-7B                12.83         2427.75      4.00       4.75   

              completeness  naturalness  correctness  Avg Quality  
Model                                                              
Gemma-2-2B            4.50          5.0          5.0          4.9  
Llama-3.2-3B          4.25          5.0          4.0          4.4  
Mistral-7B            4.25          5.0          3.5          4.3  
----------------------------------------------------------------------
Best Quality Model: Gemma-2-2B
Fastest Model:      Gemma-2-2B
----------------------------------------------------------------------
