# Airline Flight Insights - Full Pipeline

This notebook provides a complete Graph-RAG pipeline including:
1. **Neo4j Database Connection**
2. **LLM Setup** (Groq, `llama-3.3-70b-versatile`)
3. **Embeddings** - Vector embeddings for semantic search
4. **Hybrid Retrieval** - Cypher + Semantic search
5. **Question Answering**

## 1. Imports and Setup

In [None]:
from neo4j import GraphDatabase, Driver
from dotenv import load_dotenv, find_dotenv
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional, Callable, Tuple
import numpy as np
import os
import json
import re

In [None]:
# Load environment variables
load_dotenv(find_dotenv())

NEO4J_URI = os.getenv('NEO4J_URI') or os.getenv('URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME') or os.getenv('USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD') or os.getenv('PASSWORD')
groq_api_key = os.getenv('GROQ_API_KEY') or os.getenv('GROQ')

print(f"URI: {NEO4J_URI}")
print(f"Groq API key loaded: {'Yes' if groq_api_key else 'No'}")

In [127]:
# Create Neo4j driver
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
driver.verify_connectivity()
print("Connected to Neo4j!")

Connected to Neo4j!


In [None]:
# Groq-hosted Llama 3.3 70B; temperature=0 keeps query selection and parameter extraction as repeatable as possible
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    api_key=groq_api_key,
    temperature=0
)
print("Groq LLM loaded! (llama-3.3-70b-versatile)")

## 2. Cypher Queries

In [129]:
queries = [
    # Intent 1: Operational Delay Diagnostics
    "MATCH (j:Journey)-[:ON]->(f:Flight)-[:ARRIVES_AT]->(a:Airport) RETURN a.station_code AS destination, SUM(j.arrival_delay_minutes) AS total_delay ORDER BY total_delay DESC LIMIT $x",
    "MATCH (j:Journey)-[:ON]->(f:Flight)-[:ARRIVES_AT]->(a:Airport) RETURN a.station_code AS destination, SUM(j.arrival_delay_minutes) AS total_delay ORDER BY total_delay ASC LIMIT $x",
    "MATCH (j:Journey)-[:ON]->(f:Flight)-[:DEPARTS_FROM]->(a:Airport) RETURN a.station_code AS origin, SUM(j.arrival_delay_minutes) AS total_delay ORDER BY total_delay DESC LIMIT $x",
    "MATCH (j:Journey)-[:ON]->(f:Flight)-[:DEPARTS_FROM]->(a:Airport) RETURN a.station_code AS origin, SUM(j.arrival_delay_minutes) AS total_delay ORDER BY total_delay ASC LIMIT $x",
    "MATCH (o:Airport {station_code: $origin_station_code})<-[:DEPARTS_FROM]-(f:Flight)-[:ARRIVES_AT]->(d:Airport), (j:Journey)-[:ON]->(f) WITH o, d, AVG(j.arrival_delay_minutes) AS avg_delay WHERE avg_delay > $x RETURN o.station_code AS origin, d.station_code AS destination, avg_delay",
    "MATCH (j:Journey {number_of_legs: $x}) RETURN AVG(j.arrival_delay_minutes) AS avg_delay",
    # Intent 2: Service Quality
    "MATCH (o:Airport)<-[:DEPARTS_FROM]-(f:Flight)-[:ARRIVES_AT]->(d:Airport), (j:Journey {passenger_class: $class_name})-[:ON]->(f) WITH o, d, AVG(j.food_satisfaction_score) AS avg_food_score WHERE avg_food_score < $threshold RETURN o.station_code AS origin, d.station_code AS destination, avg_food_score",
    "MATCH (j:Journey {food_satisfaction_score: 1})-[:ON]->(f:Flight) WHERE j.actual_flown_miles > $x RETURN DISTINCT f.flight_number",
    # Intent 3: Fleet Performance
    "MATCH (j:Journey)-[:ON]->(f:Flight) WHERE j.arrival_delay_minutes > $x RETURN f.fleet_type_description AS aircraft_type, COUNT(j) AS delay_frequency ORDER BY delay_frequency DESC LIMIT 1",
    "MATCH (j:Journey)-[:ON]->(f:Flight {fleet_type_description: $x}) RETURN AVG(j.food_satisfaction_score) AS avg_food_score",
    "MATCH (j:Journey)-[:ON]->(f:Flight {fleet_type_description: $x}) RETURN AVG(j.actual_flown_miles) AS avg_miles",
    "MATCH (j:Journey)-[:ON]->(f:Flight {fleet_type_description: $x}) WITH COUNT(j) AS total_flights, COUNT(CASE WHEN j.arrival_delay_minutes < 0 THEN 1 END) AS early_flights RETURN (TOFLOAT(early_flights) / total_flights) * 100 AS early_arrival_percentage",
    # Intent 3b: Aircraft Performance Aggregation (NEW)
    "MATCH (j:Journey)-[:ON]->(f:Flight) RETURN f.fleet_type_description AS aircraft_type, AVG(j.arrival_delay_minutes) AS avg_delay, COUNT(j) AS flight_count ORDER BY avg_delay ASC LIMIT $x",
    "MATCH (j:Journey)-[:ON]->(f:Flight) RETURN f.fleet_type_description AS aircraft_type, AVG(j.arrival_delay_minutes) AS avg_delay, COUNT(j) AS flight_count ORDER BY avg_delay DESC LIMIT $x",
    "MATCH (j:Journey)-[:ON]->(f:Flight) WITH f.fleet_type_description AS aircraft_type, COUNT(j) AS total, COUNT(CASE WHEN j.arrival_delay_minutes <= 0 THEN 1 END) AS on_time RETURN aircraft_type, (toFloat(on_time) / total) * 100 AS on_time_pct, total AS flight_count ORDER BY on_time_pct DESC LIMIT $x",
    # Intent 4: Loyalty
    "MATCH (p:Passenger {loyalty_program_level: $loyalty_program_level})-[:TOOK]->(j:Journey) RETURN AVG(j.arrival_delay_minutes) AS avg_delay",
    "MATCH (p:Passenger {loyalty_program_level: $loyalty_program_level})-[:TOOK]->(j:Journey) WHERE j.arrival_delay_minutes > $x RETURN p.record_locator AS passenger_id, j.arrival_delay_minutes AS delay",
    # Intent 5: Demographics
    "MATCH (p:Passenger {generation: $generation})-[:TOOK]->(j:Journey)-[:ON]->(f:Flight) WHERE j.actual_flown_miles > $threshold RETURN f.fleet_type_description AS aircraft_type, COUNT(f) AS usage_count ORDER BY usage_count DESC LIMIT 1",
    "MATCH (p:Passenger {generation: $generation})-[:TOOK]->(j:Journey)-[:ON]->(f:Flight) RETURN f.fleet_type_description AS fleet_type, COUNT(f) AS usage_count ORDER BY usage_count DESC LIMIT 1",
    "MATCH (p:Passenger {generation: $generation})-[:TOOK]->(j:Journey)-[:ON]->(f:Flight)-[:ARRIVES_AT]->(a:Airport) RETURN a.station_code AS destination, COUNT(p) AS passenger_volume ORDER BY passenger_volume DESC LIMIT $x"
]

query_descriptions = [
    "Identify the top ${x} destination stations with the highest accumulated arrival delay minutes.",
    "Identify the top ${x} destination stations with the lowest accumulated arrival delay minutes.",
    "Identify the top ${x} origin stations with the highest accumulated arrival delay minutes.",
    "Identify the top ${x} origin stations with the lowest accumulated arrival delay minutes.",
    "Find routes from the origin station ${origin_station_code} where the average arrival delay exceeds ${x} minutes.",
    "Calculate the average arrival delay for flights consisting of exactly ${x} legs.",
    "Identify routes for the passenger class ${class_name} where the average food satisfaction score is below ${threshold}.",
    "List the flight numbers for journeys longer than ${x} miles where the food satisfaction score was 1.",
    "Identify the aircraft type that has the highest frequency of arrival delays greater than ${x} minutes.",
    "Calculate the average food satisfaction score for passengers flying on the ${x} fleet.",
    "Calculate the average actual flown miles for the ${x} fleet.",
    "Calculate the percentage of early arrivals for the ${x} fleet.",
    # NEW: Aircraft performance aggregation
    "List the top ${x} aircraft types with the LOWEST average arrival delay (best on-time performance).",
    "List the top ${x} aircraft types with the HIGHEST average arrival delay (worst on-time performance).",
    "List the top ${x} aircraft types by on-time arrival percentage (arrivals with delay <= 0 minutes).",
    # Loyalty
    "Calculate the average arrival delay experienced by passengers with the loyalty level ${loyalty_program_level}.",
    "Find the record locators for passengers with loyalty level ${loyalty_program_level} who experienced a delay greater than ${x} minutes.",
    # Demographics
    "Identify the most common aircraft type used by the ${generation} generation for journeys exceeding ${threshold} miles.",
    "Identify the most frequently used fleet type for the ${generation} generation.",
    "Identify the top ${x} destination stations for the ${generation} generation based on passenger volume."
]

print(f"Loaded {len(queries)} queries")

Loaded 20 queries
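
Each query above declares its inputs as `$name` placeholders, and the parameters are later supplied by an LLM, so it can help to check that every placeholder is covered before the query is sent. The following is a minimal sketch of such a check; `required_params` and `missing_params` are hypothetical helpers that are not part of this notebook, and `sample` is a shortened variant of one of the demographics queries.

```python
import re


def required_params(cypher: str) -> set:
    """Return the names of all $parameters referenced in a Cypher string."""
    return set(re.findall(r"\$(\w+)", cypher))


def missing_params(cypher: str, params: dict) -> set:
    """Return the parameter names the query needs but the caller did not supply."""
    return required_params(cypher) - set(params)


# Shortened, illustrative query (assumption: not one of the 20 queries verbatim)
sample = (
    "MATCH (p:Passenger {generation: $generation})-[:TOOK]->(j:Journey) "
    "WHERE j.actual_flown_miles > $threshold RETURN COUNT(j) AS n LIMIT $x"
)

print(sorted(required_params(sample)))
print(missing_params(sample, {"generation": "Boomer", "x": 5}))
```

A check like this could run before `session.run`, so that a missing parameter yields a readable message instead of a driver error.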


## 3. Query Execution Functions

In [130]:
def run_query(query_index: int, **params) -> list:
    """Run a query by index with parameters."""
    if query_index < 0 or query_index >= len(queries):
        raise ValueError(f"Query index {query_index} out of range (0-{len(queries)-1})")
    with driver.session() as session:
        result = session.run(queries[query_index], **params)
        return [record.data() for record in result]

In [131]:
# Load KG schema values for better parameter matching
def load_kg_schema(driver) -> Dict[str, Any]:
    """Query the KG to get valid values for each parameter field."""
    schema = {}
    
    with driver.session() as session:
        # Airport codes
        result = session.run('MATCH (a:Airport) RETURN DISTINCT a.station_code AS code ORDER BY code')
        schema['airport_codes'] = [r['code'] for r in result]
        
        # Passenger classes
        result = session.run('MATCH (j:Journey) RETURN DISTINCT j.passenger_class AS class ORDER BY class')
        schema['passenger_classes'] = [r['class'] for r in result if r['class']]
        
        # Generations
        result = session.run('MATCH (p:Passenger) RETURN DISTINCT p.generation AS gen ORDER BY gen')
        schema['generations'] = [r['gen'] for r in result if r['gen']]
        
        # Loyalty levels
        result = session.run('MATCH (p:Passenger) RETURN DISTINCT p.loyalty_program_level AS level ORDER BY level')
        schema['loyalty_levels'] = [r['level'] for r in result if r['level']]
        
        # Fleet types
        result = session.run('MATCH (f:Flight) RETURN DISTINCT f.fleet_type_description AS fleet ORDER BY fleet')
        schema['fleet_types'] = [r['fleet'] for r in result if r['fleet']]
        
        # Number of legs
        result = session.run('MATCH (j:Journey) RETURN DISTINCT j.number_of_legs AS legs ORDER BY legs')
        schema['number_of_legs'] = [r['legs'] for r in result if r['legs']]
    
    return schema

# Load schema on startup
kg_schema = load_kg_schema(driver)
print("Loaded KG schema:")
print(f"  - {len(kg_schema['airport_codes'])} airport codes")
print(f"  - Generations: {kg_schema['generations']}")
print(f"  - Loyalty levels: {kg_schema['loyalty_levels']}")
print(f"  - Fleet types: {len(kg_schema['fleet_types'])} types")
print(f"  - Passenger classes: {kg_schema['passenger_classes']}")

Loaded KG schema:
  - 158 airport codes
  - Generations: ['Boomer', 'Gen X', 'Gen Z', 'Millennial', 'NBK', 'Silent']
  - Loyalty levels: ['NBK', 'global services', 'non-elite', 'premier 1k', 'premier gold', 'premier platinum', 'premier silver']
  - Fleet types: 20 types
  - Passenger classes: ['Economy']
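
The schema values listed above are what the Cypher parameters have to match exactly. The notebook relies on the LLM prompt to map user wording onto these values; as an optional safety net, a fuzzy match from the standard library could be applied afterwards. This is only a sketch under that assumption: `closest_schema_value` is a hypothetical helper, and the 0.6 cutoff is simply the `difflib` default.

```python
from difflib import get_close_matches
from typing import List, Optional


def closest_schema_value(raw: str, valid: List[str], cutoff: float = 0.6) -> Optional[str]:
    """Map a free-text value onto the closest valid schema value, ignoring case."""
    by_lower = {v.lower(): v for v in valid}
    hits = get_close_matches(raw.lower(), list(by_lower), n=1, cutoff=cutoff)
    return by_lower[hits[0]] if hits else None


# Values copied from the schema output above
generations = ['Boomer', 'Gen X', 'Gen Z', 'Millennial', 'NBK', 'Silent']

print(closest_schema_value("Millennials", generations))
print(closest_schema_value("Baby Boomer", generations))
print(closest_schema_value("xyz", generations))
```

Character-level matching will not resolve every paraphrase (a phrase such as "gold member" shares too few characters with `premier gold`), so it complements the prompt rules rather than replacing them.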


In [132]:
def get_context(prompt: str) -> list:
    """Use Groq LLM to identify relevant queries and extract parameters with KG schema awareness."""
    safe_descriptions = [desc.replace('${', '<').replace('}', '>') for desc in query_descriptions]
    query_list = "\n".join([f"{i}: {desc}" for i, desc in enumerate(safe_descriptions)])
    
    # Build schema reference for the LLM
    schema_info = (
        "VALID VALUES FROM THE DATABASE:\n"
        f"- origin_station_code/destination: {kg_schema['airport_codes'][:20]}... ({len(kg_schema['airport_codes'])} total)\n"
        f"- generation: {kg_schema['generations']}\n"
        f"- loyalty_program_level: {kg_schema['loyalty_levels']}\n"
        f"- class_name: {kg_schema['passenger_classes']}\n"
        f"- fleet types (for x when fleet-related): {kg_schema['fleet_types']}\n"
        f"- number_of_legs: {kg_schema['number_of_legs']}\n"
    )
    
    full_prompt = (
        "You are an expert at analyzing user questions about airline flight data.\n\n"
        "Available queries:\n" + query_list + "\n\n"
        + schema_info + "\n"
        "PARAMETER RULES:\n"
        "- x: a number for counts/limits (default 5), delay thresholds (default 30), or miles threshold\n"
        "- For fleet queries (indices 9-11), x must be an EXACT fleet type from the list above\n"
        "- Match user terms to closest valid values (e.g., 'Baby Boomer' -> 'Boomer', 'gold member' -> 'premier gold')\n"
        "- Use exact values from the database schema above\n\n"
        'CRITICAL: Return ONLY ONE valid JSON array on a single line. No explanations.\n'
        'Format: [{"query_index": 0, "params": {"x": 5}}]\n\n'
        "User question: " + prompt + "\n\nJSON:"
    )
    
    response = llm.invoke(full_prompt)
    response_text = response.content.strip()
    
    # Clean up response - remove markdown and extra whitespace
    response_text = response_text.replace('```json', '').replace('```', '').strip()
    
    # Find the FIRST valid JSON array (ignore any additional lines)
    # Look for the first line that starts with '['
    for line in response_text.split('\n'):
        line = line.strip()
        if line.startswith('[') and line.endswith(']'):
            try:
                return json.loads(line)
            except json.JSONDecodeError:
                continue
    
    # Fallback: try to find any JSON array in the response
    json_match = re.search(r'\[\s*\{[^\[]*\}\s*\]', response_text, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError as e:
            print(f"JSON parse error: {e}")
            print(f"Response was: {response_text[:200]}")
    
    return []
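
The parsing steps at the end of `get_context` can be exercised without calling the LLM. The sketch below isolates them in a standalone function, here called `extract_query_plan` (a name chosen for illustration), and runs it on three made-up responses: a fenced reply, a chatty reply, and a refusal.

```python
import json
import re
from typing import Any, List

FENCE = "`" * 3  # markdown code fence, built indirectly to keep this example readable


def extract_query_plan(response_text: str) -> List[Any]:
    """Return the first JSON array found in an LLM response, or [] if none parses."""
    text = response_text.replace(FENCE + "json", "").replace(FENCE, "").strip()

    # Pass 1: a line that is a complete JSON array on its own
    for line in text.split("\n"):
        line = line.strip()
        if line.startswith("[") and line.endswith("]"):
            try:
                return json.loads(line)
            except json.JSONDecodeError:
                continue

    # Pass 2: an array of objects embedded in surrounding prose
    match = re.search(r"\[\s*\{[^\[]*\}\s*\]", text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass
    return []


fenced = FENCE + 'json\n[{"query_index": 0, "params": {"x": 5}}]\n' + FENCE
chatty = 'Sure, here it is: [{"query_index": 12, "params": {"x": 3}}] Hope that helps!'
refusal = "I cannot answer that."

print(extract_query_plan(fenced))
print(extract_query_plan(chatty))
print(extract_query_plan(refusal))
```

Note that the fallback pattern excludes `[` inside the match, so a response whose parameters themselves contain a list would not be recovered by the second pass.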

In [133]:
def format_query_result(query_index: int, **params) -> str:
    """Run query and format result as context."""
    if query_index < 0 or query_index >= len(queries):
        return f"Error: Query index {query_index} out of range."
    
    description = query_descriptions[query_index]
    for name, value in params.items():
        description = description.replace(f"${{{name}}}", str(value))
    
    try:
        results = run_query(query_index, **params)
    except Exception as e:
        return f'Error for "{description}": {e}'
    
    if not results:
        return f'"{description}": No data found.'
    
    formatted = []
    for r in results:
        parts = [f"{k}: {v:.2f}" if isinstance(v, float) else f"{k}: {v}" for k, v in r.items()]
        formatted.append("  - " + ", ".join(parts))
    
    return f'"{description}":\n' + "\n".join(formatted)
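
To see what `format_query_result` hands to the LLM without a database connection, the two formatting steps can be reproduced on their own. This sketch assumes a single made-up result row; `fill_description` and `format_rows` are illustrative names for the placeholder substitution and row formatting done inside the function above.

```python
from typing import Any, Dict, List


def fill_description(template: str, params: Dict[str, Any]) -> str:
    """Replace each ${name} placeholder with the corresponding parameter value."""
    for name, value in params.items():
        template = template.replace(f"${{{name}}}", str(value))
    return template


def format_rows(rows: List[Dict[str, Any]]) -> List[str]:
    """Render rows as bullet lines, rounding floats to two decimals."""
    lines = []
    for row in rows:
        parts = [f"{k}: {v:.2f}" if isinstance(v, float) else f"{k}: {v}" for k, v in row.items()]
        lines.append("  - " + ", ".join(parts))
    return lines


template = ("Find routes from the origin station ${origin_station_code} "
            "where the average arrival delay exceeds ${x} minutes.")
rows = [{"origin": "LAX", "destination": "JAX", "avg_delay": 41.256}]  # made-up row

description = fill_description(template, {"origin_station_code": "LAX", "x": 30})
print(f'"{description}":')
print("\n".join(format_rows(rows)))
```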

## 4. Embeddings Module

Vector embeddings for semantic search. Models:
- `minilm`: all-MiniLM-L6-v2 (384 dims, fast)
- `mpnet`: all-mpnet-base-v2 (768 dims, higher quality)
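
The vector indexes created later in this notebook are configured with the `cosine` similarity function. As a reminder of what that measures, here is a small pure-Python sketch on toy 3-dimensional vectors (real embeddings have 384 or 768 dimensions). The score Neo4j reports may be a rescaled version of this raw value, so treat the numbers as illustrative only.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimensionality")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


print(cosine_similarity([1, 0, 0], [1, 0, 0]))   # same direction
print(cosine_similarity([1, 0, 0], [0, 1, 0]))   # orthogonal
print(cosine_similarity([1, 2, 3], [2, 4, 6]))   # parallel, different length
```

Because cosine similarity ignores vector length, two texts can score highly even when one embedding has a much larger norm than the other.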

In [134]:
# Embedding model configurations
EMBEDDING_MODELS = {
    "minilm": {"name": "all-MiniLM-L6-v2", "dimensions": 384, "property_name": "embedding_minilm"},
    "mpnet": {"name": "all-mpnet-base-v2", "dimensions": 768, "property_name": "embedding_mpnet"}
}

_model_cache: Dict[str, SentenceTransformer] = {}

def get_model(model_key: str) -> SentenceTransformer:
    """Load and cache an embedding model."""
    if model_key not in EMBEDDING_MODELS:
        raise ValueError(f"Unknown model: {model_key}. Use 'minilm' or 'mpnet'")
    if model_key not in _model_cache:
        print(f"Loading {EMBEDDING_MODELS[model_key]['name']}...")
        _model_cache[model_key] = SentenceTransformer(EMBEDDING_MODELS[model_key]["name"])
        print("Model loaded!")
    return _model_cache[model_key]

In [None]:
def create_journey_sentences(props: Dict[str, Any]) -> List[str]:
    """Create MULTIPLE focused sentences for a Journey (better for retrieval)."""
    sentences = []
    
    # Sentence 1: Route & Flight info
    flight_number = props.get('flight_number', '')
    fleet_type = props.get('fleet_type', '')
    origin = props.get('origin', '')
    destination = props.get('destination', '')
    if flight_number and origin and destination:
        route_text = f"Flight {flight_number} from {origin} to {destination}"
        if fleet_type:
            route_text += f" operated by {fleet_type} aircraft"
        sentences.append(route_text + ".")
    
    # Sentence 2: Passenger demographics & loyalty
    generation = props.get('generation', '')
    loyalty = props.get('loyalty_program_level', '')
    if generation or loyalty:
        passenger_text = f"Passenger is a {generation}" if generation else "Passenger"
        if loyalty:
            passenger_text += f" with {loyalty} loyalty program level"
        sentences.append(passenger_text + ".")
    
    # Sentence 3: Journey experience
    # `or` fallbacks also cover properties that are present but null in the graph
    passenger_class = props.get('passenger_class') or 'Economy'
    miles = props.get('actual_flown_miles') or 0
    delay = props.get('arrival_delay_minutes') or 0
    legs = props.get('number_of_legs') or 1
    food_score = props.get('food_satisfaction_score') or 3
    
    delay_text = f"arrived {abs(delay)} minutes early" if delay < 0 else "on time" if delay == 0 else f"delayed {delay} minutes"
    food_labels = {1: "very poor", 2: "poor", 3: "average", 4: "good", 5: "excellent"}
    satisfaction = food_labels.get(food_score, "average")
    
    exp_text = f"{passenger_class} class journey covering {miles:.0f} miles over {legs} segment{'s' if legs > 1 else ''}, {delay_text}, {satisfaction} food satisfaction."
    sentences.append(exp_text)
    
    return sentences


def create_journey_text(props: Dict[str, Any]) -> str:
    """Create combined text from all journey sentences."""
    return " ".join(create_journey_sentences(props))


def create_flight_text(props: Dict[str, Any], origin: Optional[str] = None, destination: Optional[str] = None) -> str:
    """Create text representation of a Flight node."""
    flight_num = props.get('flight_number', 'Unknown')
    fleet = props.get('fleet_type_description', 'Unknown aircraft')
    route = f" from {origin} to {destination}" if origin and destination else ""
    return f"Flight {flight_num} operated by {fleet}{route}."


def create_passenger_text(props: Dict[str, Any]) -> str:
    """Create text representation of a Passenger node."""
    return f"A {props.get('generation', 'unknown')} passenger with {props.get('loyalty_program_level', 'unknown')} loyalty status."
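
The delay wording inside `create_journey_sentences` is a chained conditional expression. Written out as a small function, the three cases are easier to read and test. In this sketch, `delay_phrase` is an illustrative name, and the explicit `None` branch is an assumption added here; the notebook itself does not define wording for a missing delay.

```python
from typing import Optional


def delay_phrase(delay: Optional[float]) -> str:
    """Turn an arrival delay in minutes into the phrase used in the journey text."""
    if delay is None:
        # Assumption: wording for a missing value is not defined in the notebook
        return "arrival delay not recorded"
    if delay < 0:
        return f"arrived {abs(delay)} minutes early"
    if delay == 0:
        return "on time"
    return f"delayed {delay} minutes"


for minutes in (-12, 0, 45, None):
    print(minutes, "->", delay_phrase(minutes))
```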

In [136]:
def generate_embeddings(texts: List[str], model_key: str = "minilm") -> np.ndarray:
    """Generate embeddings for a list of texts."""
    model = get_model(model_key)
    return model.encode(texts, show_progress_bar=True, convert_to_numpy=True)


def generate_single_embedding(text: str, model_key: str = "minilm") -> List[float]:
    """Generate embedding for a single text."""
    model = get_model(model_key)
    return model.encode(text, convert_to_numpy=True).tolist()

In [137]:
def fetch_journey_nodes(driver: Driver) -> List[Dict[str, Any]]:
    """Fetch Journey nodes with ENRICHED data including passenger and flight info."""
    query = """
    MATCH (p:Passenger)-[:TOOK]->(j:Journey)-[:ON]->(f:Flight)
    OPTIONAL MATCH (f)-[:DEPARTS_FROM]->(o:Airport)
    OPTIONAL MATCH (f)-[:ARRIVES_AT]->(d:Airport)
    RETURN j.feedback_ID AS feedback_ID,
           j.passenger_class AS passenger_class,
           j.food_satisfaction_score AS food_satisfaction_score,
           j.arrival_delay_minutes AS arrival_delay_minutes,
           j.actual_flown_miles AS actual_flown_miles,
           j.number_of_legs AS number_of_legs,
           p.generation AS generation,
           p.loyalty_program_level AS loyalty_program_level,
           f.flight_number AS flight_number,
           f.fleet_type_description AS fleet_type,
           o.station_code AS origin,
           d.station_code AS destination
    """
    with driver.session() as session:
        result = session.run(query)
        return [{"feedback_ID": r["feedback_ID"], "properties": dict(r)} for r in result]


def fetch_flight_nodes(driver: Driver) -> List[Dict[str, Any]]:
    """Fetch all Flight nodes from Neo4j with route info."""
    query = """
    MATCH (f:Flight)
    OPTIONAL MATCH (f)-[:DEPARTS_FROM]->(origin:Airport)
    OPTIONAL MATCH (f)-[:ARRIVES_AT]->(dest:Airport)
    RETURN f.flight_number AS flight_number, f.fleet_type_description AS fleet_type_description,
           origin.station_code AS origin, dest.station_code AS destination
    """
    with driver.session() as session:
        result = session.run(query)
        return [{
            "flight_number": r["flight_number"],
            "fleet_type_description": r["fleet_type_description"],
            "properties": {"flight_number": r["flight_number"], "fleet_type_description": r["fleet_type_description"]},
            "origin": r["origin"], "destination": r["destination"]
        } for r in result]

In [138]:
def create_vector_index(driver: Driver, model_key: str, node_label: str = "Journey"):
    """Create a vector index in Neo4j."""
    config = EMBEDDING_MODELS[model_key]
    index_name = f"{node_label.lower()}_{config['property_name']}"
    
    create_query = f"""
    CREATE VECTOR INDEX {index_name} IF NOT EXISTS
    FOR (n:{node_label}) ON n.{config['property_name']}
    OPTIONS {{indexConfig: {{
        `vector.dimensions`: {config['dimensions']},
        `vector.similarity_function`: 'cosine'
    }}}}
    """
    with driver.session() as session:
        try:
            session.run(f"DROP INDEX {index_name} IF EXISTS")
        except Exception:
            pass  # best effort: a failed drop should not block index creation
        session.run(create_query)
        print(f"Created index: {index_name}")
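
Because `create_vector_index` builds its statement with an f-string and doubled braces, it can be hard to picture the Cypher that actually reaches Neo4j. The sketch below renders the same statement as a plain string so it can be inspected before running; `build_vector_index_statement` is a hypothetical helper, shown here for the `minilm` configuration on `Journey` nodes.

```python
def build_vector_index_statement(label: str, prop: str, dimensions: int) -> str:
    """Render the CREATE VECTOR INDEX statement for a node label and embedding property."""
    index_name = f"{label.lower()}_{prop}"
    return (
        f"CREATE VECTOR INDEX {index_name} IF NOT EXISTS\n"
        f"FOR (n:{label}) ON n.{prop}\n"
        "OPTIONS {indexConfig: {\n"
        f"    `vector.dimensions`: {dimensions},\n"
        "    `vector.similarity_function`: 'cosine'\n"
        "}}"
    )


statement = build_vector_index_statement("Journey", "embedding_minilm", 384)
print(statement)
```

Since the label and property name are interpolated into the statement rather than passed as parameters, they should only ever come from trusted configuration such as `EMBEDDING_MODELS`.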

In [None]:
def store_journey_embeddings(driver: Driver, feedback_ids: List[str], embeddings: np.ndarray, 
                              model_key: str = "minilm", batch_size: int = 100):
    """Store embeddings for Journey nodes."""
    prop = EMBEDDING_MODELS[model_key]["property_name"]
    query = f"UNWIND $batch AS item MATCH (j:Journey {{feedback_ID: item.feedback_ID}}) SET j.{prop} = item.embedding"
    
    with driver.session() as session:
        for i in range(0, len(feedback_ids), batch_size):
            batch = [{"feedback_ID": feedback_ids[j], "embedding": embeddings[j].tolist()} 
                     for j in range(i, min(i + batch_size, len(feedback_ids)))]
            session.run(query, batch=batch)
            print(f"Stored {min(i + batch_size, len(feedback_ids))}/{len(feedback_ids)}...")


def store_journey_multi_embeddings(driver: Driver, journeys: List[Dict], all_embeddings: Dict[str, np.ndarray],
                                    model_key: str = "minilm", batch_size: int = 100):
    """Store multiple embeddings per Journey (route, passenger, experience)."""
    base_prop = EMBEDDING_MODELS[model_key]["property_name"]
    
    for emb_type in ['route', 'passenger', 'experience']:
        if emb_type not in all_embeddings:
            continue
        prop = f"{base_prop}_{emb_type}"
        query = f"UNWIND $batch AS item MATCH (j:Journey {{feedback_ID: item.feedback_ID}}) SET j.{prop} = item.embedding"
        embeddings = all_embeddings[emb_type]
        
        with driver.session() as session:
            for i in range(0, len(journeys), batch_size):
                batch = [{"feedback_ID": journeys[j]["feedback_ID"], "embedding": embeddings[j].tolist()} 
                         for j in range(i, min(i + batch_size, len(journeys)))]
                session.run(query, batch=batch)
        print(f"Stored {emb_type} embeddings")


def store_flight_embeddings(driver: Driver, flights: List[Dict], embeddings: np.ndarray,
                            model_key: str = "minilm", batch_size: int = 100):
    """Store embeddings for Flight nodes."""
    prop = EMBEDDING_MODELS[model_key]["property_name"]
    query = f"UNWIND $batch AS item MATCH (f:Flight {{flight_number: item.flight_number, fleet_type_description: item.fleet_type_description}}) SET f.{prop} = item.embedding"
    
    with driver.session() as session:
        for i in range(0, len(flights), batch_size):
            batch = [{"flight_number": flights[j]["flight_number"], 
                      "fleet_type_description": flights[j]["fleet_type_description"],
                      "embedding": embeddings[j].tolist()} 
                     for j in range(i, min(i + batch_size, len(flights)))]
            session.run(query, batch=batch)
            print(f"Stored {min(i + batch_size, len(flights))}/{len(flights)}...")
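
All three storage functions write in slices of `batch_size` items so that no single `UNWIND` carries the whole dataset. The slicing itself can be sketched as a small generator; `batched` is an illustrative name, and the IDs are made up.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


ids = [f"fb-{n}" for n in range(7)]  # made-up feedback IDs
sizes = [len(chunk) for chunk in batched(ids, 3)]
print(sizes)
```

The last batch is simply shorter when the item count is not a multiple of `batch_size`, which matches the `min(i + batch_size, len(...))` bound used in the storage functions.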

In [140]:
def semantic_search_journeys(driver: Driver, query_text: str, model_key: str = "minilm", top_k: int = 5) -> List[Dict]:
    """Semantic search on Journey nodes - returns enriched data."""
    query_embedding = generate_single_embedding(query_text, model_key)
    index_name = f"journey_{EMBEDDING_MODELS[model_key]['property_name']}"
    
    # Updated query to fetch connected entities
    search_query = f"""
    CALL db.index.vector.queryNodes('{index_name}', $top_k, $query_embedding)
    YIELD node, score
    MATCH (p:Passenger)-[:TOOK]->(node)-[:ON]->(f:Flight)
    OPTIONAL MATCH (f)-[:DEPARTS_FROM]->(o:Airport)
    OPTIONAL MATCH (f)-[:ARRIVES_AT]->(d:Airport)
    RETURN node.feedback_ID AS feedback_ID, 
           node.passenger_class AS passenger_class,
           node.food_satisfaction_score AS food_satisfaction_score,
           node.arrival_delay_minutes AS arrival_delay_minutes,
           node.actual_flown_miles AS actual_flown_miles,
           node.number_of_legs AS number_of_legs,
           p.generation AS generation,
           p.loyalty_program_level AS loyalty_program_level,
           f.flight_number AS flight_number,
           f.fleet_type_description AS fleet_type,
           o.station_code AS origin,
           d.station_code AS destination,
           score
    ORDER BY score DESC
    """
    with driver.session() as session:
        result = session.run(search_query, top_k=top_k, query_embedding=query_embedding)
        return [{**dict(r), "similarity_score": r["score"]} for r in result]


def semantic_search_flights(driver: Driver, query_text: str, model_key: str = "minilm", top_k: int = 5) -> List[Dict]:
    """Semantic search on Flight nodes."""
    query_embedding = generate_single_embedding(query_text, model_key)
    index_name = f"flight_{EMBEDDING_MODELS[model_key]['property_name']}"
    
    search_query = f"""
    CALL db.index.vector.queryNodes('{index_name}', $top_k, $query_embedding)
    YIELD node, score
    OPTIONAL MATCH (node)-[:DEPARTS_FROM]->(origin:Airport)
    OPTIONAL MATCH (node)-[:ARRIVES_AT]->(dest:Airport)
    RETURN node.flight_number AS flight_number, node.fleet_type_description AS fleet_type_description,
           origin.station_code AS origin, dest.station_code AS destination, score
    ORDER BY score DESC
    """
    with driver.session() as session:
        result = session.run(search_query, top_k=top_k, query_embedding=query_embedding)
        return [{**dict(r), "similarity_score": r["score"]} for r in result]

In [141]:
def format_embedding_results(results: List[Dict], node_type: str = "Journey") -> str:
    """Format embedding search results as context with enriched info."""
    if not results:
        return f"No similar {node_type} nodes found."
    
    lines = [f"Found {len(results)} relevant {node_type} records:"]
    for i, r in enumerate(results, 1):
        if node_type == "Journey":
            # Rich journey text with all connected info
            text = create_journey_text(r)
        else:
            text = create_flight_text({"flight_number": r.get("flight_number"), 
                                       "fleet_type_description": r.get("fleet_type_description")},
                                      r.get("origin"), r.get("destination"))
        lines.append(f"  {i}. (score: {r['similarity_score']:.3f}) {text}")
    return "\n".join(lines)


def get_embedding_context(driver: Driver, query: str, model_key: str = "minilm", top_k: int = 5) -> str:
    """Get context from embedding-based semantic search."""
    contexts = []
    try:
        contexts.append(format_embedding_results(semantic_search_journeys(driver, query, model_key, top_k), "Journey"))
    except Exception as e:
        contexts.append(f"Journey search error: {e}")
    try:
        contexts.append(format_embedding_results(semantic_search_flights(driver, query, model_key, top_k), "Flight"))
    except Exception as e:
        contexts.append(f"Flight search error: {e}")
    return "\n\n".join(contexts)

In [142]:
def generate_and_store_all_embeddings(driver: Driver, model_key: str = "minilm"):
    """Generate and store embeddings for all Journey and Flight nodes."""
    print(f"\n{'='*60}")
    print(f"Generating embeddings with {EMBEDDING_MODELS[model_key]['name']}")
    print(f"{'='*60}\n")
    
    # Journeys
    print("Fetching Journey nodes...")
    journeys = fetch_journey_nodes(driver)
    print(f"Found {len(journeys)} journeys")
    
    if journeys:
        texts = [create_journey_text(j["properties"]) for j in journeys]
        embeddings = generate_embeddings(texts, model_key)
        create_vector_index(driver, model_key, "Journey")
        store_journey_embeddings(driver, [j["feedback_ID"] for j in journeys], embeddings, model_key)
    
    # Flights
    print("\nFetching Flight nodes...")
    flights = fetch_flight_nodes(driver)
    print(f"Found {len(flights)} flights")
    
    if flights:
        texts = [create_flight_text(f["properties"], f["origin"], f["destination"]) for f in flights]
        embeddings = generate_embeddings(texts, model_key)
        create_vector_index(driver, model_key, "Flight")
        store_flight_embeddings(driver, flights, embeddings, model_key)
    
    print("\nEmbedding generation complete!")

## 5. Hybrid Retrieval

Combines Cypher queries with embedding-based semantic search.

In [143]:
def get_hybrid_context(driver: Driver, prompt: str, model_key: str = "minilm", top_k: int = 3) -> Dict[str, Any]:
    """Get context from both Cypher queries and embedding search."""
    results = {'cypher_context': [], 'embedding_context': '', 'combined_context': ''}
    
    # Cypher context
    try:
        for cq in get_context(prompt):
            results['cypher_context'].append(format_query_result(cq['query_index'], **cq['params']))
    except Exception as e:
        results['cypher_context'].append(f"Cypher error: {e}")
    
    # Embedding context
    try:
        results['embedding_context'] = get_embedding_context(driver, prompt, model_key, top_k)
    except Exception as e:
        results['embedding_context'] = f"Embedding error: {e}"
    
    # Combine
    cypher_text = '\n\n'.join(results['cypher_context'])
    results['combined_context'] = f"=== STRUCTURED QUERY RESULTS ===\n{cypher_text}\n\n=== SEMANTIC SEARCH RESULTS ===\n{results['embedding_context']}"
    
    return results
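
The combined context is plain text with two labelled sections, which is what the answering prompt later receives. The following sketch reproduces that layout with stub inputs so the format can be inspected without Neo4j or the LLM. `combine_context` is an illustrative name, and the placeholder line for an empty Cypher result is an assumption added here, not behaviour of `get_hybrid_context`.

```python
from typing import List


def combine_context(cypher_blocks: List[str], embedding_block: str) -> str:
    """Join structured and semantic retrieval results under fixed section headers."""
    # Assumption: the notebook leaves this section empty when no query was selected
    cypher_text = "\n\n".join(cypher_blocks) if cypher_blocks else "No structured results."
    return (
        "=== STRUCTURED QUERY RESULTS ===\n"
        f"{cypher_text}\n\n"
        "=== SEMANTIC SEARCH RESULTS ===\n"
        f"{embedding_block}"
    )


# Stub inputs, made up for illustration
cypher_blocks = ['"Top 2 destinations by delay":\n  - destination: AAA, total_delay: 10']
embedding_block = "Found 1 relevant Journey records:\n  1. (score: 0.900) Flight 1 from AAA to BBB."

combined = combine_context(cypher_blocks, embedding_block)
print(combined)
```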

In [144]:
def answer_with_hybrid_context(driver: Driver, question: str, model_key: str = "minilm") -> str:
    """Answer a question using hybrid retrieval."""
    context = get_hybrid_context(driver, question, model_key, top_k=5)['combined_context']
    
    prompt = f"""You are an AI assistant for an airline company analyzing flight data.

Based on this context from our knowledge graph, answer the user's question.
Only use information from the context. If insufficient, say so.

CONTEXT:
{context}

USER QUESTION: {question}

ANSWER:"""
    
    return llm.invoke(prompt).content

In [145]:
def compare_retrieval_methods(driver: Driver, question: str) -> Dict[str, Any]:
    """Compare results from different retrieval methods."""
    results = {'cypher_only': [], 'embedding_minilm': '', 'embedding_mpnet': '', 
               'hybrid_minilm': None, 'hybrid_mpnet': None}
    
    # Cypher only
    try:
        for cq in get_context(question):
            results['cypher_only'].append(format_query_result(cq['query_index'], **cq['params']))
    except Exception as e:
        results['cypher_only'] = [f"Error: {e}"]
    
    # Embeddings
    for key in ['minilm', 'mpnet']:
        try:
            results[f'embedding_{key}'] = get_embedding_context(driver, question, key, 5)
        except Exception as e:
            results[f'embedding_{key}'] = f"Error: {e}"
    
    # Hybrid
    results['hybrid_minilm'] = get_hybrid_context(driver, question, "minilm", 5)
    results['hybrid_mpnet'] = get_hybrid_context(driver, question, "mpnet", 5)
    
    return results


def print_comparison(results: Dict[str, Any]):
    """Print comparison results."""
    print("=" * 80 + "\nRETRIEVAL METHOD COMPARISON\n" + "=" * 80)
    print("\n--- CYPHER ONLY ---")
    for ctx in results['cypher_only']: print(ctx + "\n")
    print("\n--- EMBEDDING (MiniLM) ---\n" + results['embedding_minilm'])
    print("\n--- EMBEDDING (MPNet) ---\n" + results['embedding_mpnet'])
    if results['hybrid_minilm']:
        print("\n--- HYBRID (MiniLM) ---\n" + results['hybrid_minilm']['combined_context'])
    if results['hybrid_mpnet']:
        print("\n--- HYBRID (MPNet) ---\n" + results['hybrid_mpnet']['combined_context'])

## 6. Interactive Q&A

In [146]:
def ask(question: str, use_hybrid: bool = True, model_key: str = "minilm") -> str:
    """Ask a question using the full pipeline."""
    print(f"\n{'='*60}\nQ: {question}\nMode: {'Hybrid' if use_hybrid else 'Cypher Only'}\n{'='*60}\n")
    
    if use_hybrid:
        answer = answer_with_hybrid_context(driver, question, model_key)
    else:
        context_parts = [format_query_result(cq['query_index'], **cq['params']) for cq in get_context(question)]
        print(chr(10).join(context_parts))
        prompt = f"""You are an AI assistant for an airline analyzing flight data.
Answer using only this context:

{chr(10).join(context_parts)}

Question: {question}

Answer:"""
        answer = llm.invoke(prompt).content
    
    print(f"ANSWER:\n{'-'*40}\n{answer}\n")
    return answer

## 7. Usage Examples

In [147]:
# Generate embeddings (run once)
# generate_and_store_all_embeddings(driver, "minilm")

print("Uncomment the line above to generate embeddings.")

Uncomment the line above to generate embeddings.


In [150]:
# Example questions:
ask("What are the top 5 airports with the most delays?")
ask("How do Millennials travel compared to Baby Boomers?")
ask("Which aircraft type has the best on-time performance?")
ask("What is the flight number of the journey that departs from LAX and arrives at IAX and has generation 'Millennials'?")
ask("What are the different loyalty program levels for a journey that has flight number 2, mention all of them")

print("Examples complete. Edit or add ask(...) calls above to try other questions.")


Q: What are the top 5 airports with the most delays?
Mode: Hybrid

ANSWER:
----------------------------------------
Based on the provided data, the top 5 airports with the most delays are:

1. CDX (991 delay minutes)
2. JAX (810 delay minutes)
3. SIX (661 delay minutes)
4. FRX (462 delay minutes)
5. MUX (326 delay minutes) 



Q: How do Millennials travel compared to Baby Boomers?
Mode: Hybrid

ANSWER:
----------------------------------------
The provided text does not contain information about how Millennials travel compared to Baby Boomers.  It does, however, provide examples of journeys with specific travel class, aircraft type, and traveler age group. 



Q: Which aircraft type has the best on-time performance?
Mode: Hybrid

ANSWER:
----------------------------------------
The answer cannot be determined from the provided data. The query asks for the aircraft type with the LOWEST average arrival delay, but the data only shows the average delay by aircraft type. There is no informa

In [149]:
# Close driver when done
# driver.close()
# print("Closed.")