<a href="https://colab.research.google.com/github/Fahrudeen/GraphRAG/blob/master/notebooks/graphrag_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GraphRAG Supply Chain Management System

## Project Overview

This notebook implements a **Graph-based Retrieval Augmented Generation (GraphRAG)** system for supply chain management using a Neo4j graph database and Groq's LLM API.

### Key Features
- **Knowledge Graph Construction**: Load supply chain entities and their relationships from CSV files into Neo4j
- **Intelligent Query Processing**: Natural language to Cypher query translation
- **Context-Aware Responses**: LLM-powered answers using graph context
- **Multi-Entity Relationships**: Cars, Features, Parts, and Suppliers

### Architecture
```
User Query → LLM Query Generator → Neo4j Graph → Context Retrieval → LLM Response
```

### Prerequisites
- Neo4j AuraDB instance or local Neo4j installation
- Groq API access
- Supply chain CSV data files

---

## Setup and Configuration

### Environment Variables

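**Important**: Before running this notebook, set up your environment variables. `NEO4J_PASSWORD` and `GROQ_API_KEY` are required; `NEO4J_URI` defaults to `bolt://localhost:7687` and `NEO4J_USERNAME` to `neo4j` if unset.

```
# Create a .env file with:
NEO4J_URI=your_neo4j_uri_here
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD=****
GROQ_API_KEY=****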
```


In [None]:
# Install required dependencies
!pip install neo4j groq pandas python-dotenv -q

print("Dependencies installed successfully!")

In [None]:
import os
from neo4j import GraphDatabase
from groq import Groq
import pandas as pd
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configuration with environment variables
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

# Validate configuration
if not all([NEO4J_PASSWORD, GROQ_API_KEY]):
    raise ValueError("Missing required environment variables. Please check your .env file.")

print("Configuration loaded successfully!")

## Database Connection Setup

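Establishing a connection to the Neo4j database and verifying it with a test query. For AuraDB, use the `neo4j+s://` URI shown in the instance details, which enables an encrypted connection.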

In [None]:
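def create_database_connection():
    """Create a Neo4j driver and verify that the server is reachable."""
    try:
        driver = GraphDatabase.driver(
            NEO4J_URI,
            auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )
        # driver() connects lazily; verify_connectivity() surfaces a bad URI or
        # credentials here instead of at the first query
        driver.verify_connectivity()
        return driver
    except Exception as e:
        print(f"Database connection failed: {e}")
        raise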

def test_connection(driver):
    """Test database connectivity."""
    try:
        with driver.session() as session:
            result = session.run("RETURN 'Neo4j Connected!' AS message")
            message = result.single()["message"]
            print(f"{message}")
            return True
    except Exception as e:
        print(f"Connection test failed: {e}")
        return False

# Initialize database connection
driver = create_database_connection()
test_connection(driver)

## Data Loading and Preparation

Loading supply chain data from CSV files. Update the file paths according to your data location.

In [None]:
def load_supply_chain_data():
    """Load all supply chain CSV files."""

    # Update these paths according to your data location
    data_files = {
        'car_models': 'data/nodes_car_model.csv',
        'features': 'data/nodes_feature.csv',
        'parts': 'data/nodes_part.csv',
        'suppliers': 'data/nodes_supplier.csv',
        'with_feature': 'data/with_feature.csv',
        'is_composed_of': 'data/is_composed_of.csv',
        'is_supplied_by': 'data/is_supplied_by.csv'
    }

    dataframes = {}

    for name, filepath in data_files.items():
        try:
            df = pd.read_csv(filepath)
            dataframes[name] = df
            print(f"Loaded {name}: {len(df)} records")
        except FileNotFoundError:
            print(f"File not found: {filepath}")
        except Exception as e:
            print(f"Error loading {name}: {e}")

    return dataframes

# Load data
data = load_supply_chain_data()

# Display data structure
print("\nData Structure Overview:")
for name, df in data.items():
    print(f"{name}: {list(df.columns)}")
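
The Cypher statements later in the notebook read specific columns from each file, and `tx.run(query, **row)` fails with a missing-parameter error if one of them is absent. A minimal pre-flight check, assuming the column names exactly as the queries below use them (note that `with_feature` uses `src_node_id`/`dst_node_id` while the other two relationship files use `src_id`/`dst_id`). `REQUIRED_COLUMNS` and `find_missing_columns` are hypothetical helpers, not part of the original pipeline.

```python
# Columns each Cypher statement in this notebook expects as parameters
# (derived from the MERGE/MATCH queries; adjust if your CSVs differ).
REQUIRED_COLUMNS = {
    'car_models': {'vertex_id', 'name', 'number', 'year', 'type',
                   'engine_type', 'size', 'seats'},
    'features': {'vertex_id', 'name', 'number', 'type', 'state'},
    'parts': {'vertex_id', 'name', 'number', 'price', 'date'},
    'suppliers': {'vertex_id', 'name', 'address', 'contact', 'phone_number'},
    'with_feature': {'src_node_id', 'dst_node_id', 'version'},
    'is_composed_of': {'src_id', 'dst_id', 'version'},
    'is_supplied_by': {'src_id', 'dst_id', 'version'},
}


def find_missing_columns(dataframes):
    """Return {dataset name: sorted missing columns} for each loaded dataset with gaps."""
    problems = {}
    for name, required in REQUIRED_COLUMNS.items():
        if name not in dataframes:
            continue  # file was not loaded; load_supply_chain_data already reported it
        missing = required - set(dataframes[name].columns)
        if missing:
            problems[name] = sorted(missing)
    return problems
```

Running `find_missing_columns(data)` right after loading returns an empty dict when every loaded file matches the queries.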

## Knowledge Graph Construction

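Building the node layer of the supply chain knowledge graph: car models, features, parts, and suppliers. Each insert uses `MERGE` on `vertex_id`, so re-running the cell updates existing nodes instead of creating duplicates.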
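
`MERGE` on `vertex_id` first has to look for an existing node; without an index on that property, Neo4j scans every node with the label for each row. A sketch, assuming Neo4j 4.4 or later (for the `CREATE CONSTRAINT ... IF NOT EXISTS FOR ... REQUIRE ... IS UNIQUE` syntax), that adds one uniqueness constraint per label, which also creates a backing index. The constraint names and helper functions are hypothetical; run it once before the node-insertion cell.

```python
# Node labels whose vertex_id is used as the MERGE key in this notebook
NODE_LABELS = ['CarModel', 'Feature', 'Part', 'Supplier']


def constraint_statements(labels):
    """Build one idempotent uniqueness-constraint statement per label (Neo4j 4.4+ syntax)."""
    return [
        f"CREATE CONSTRAINT {label.lower()}_vertex_id IF NOT EXISTS "
        f"FOR (n:{label}) REQUIRE n.vertex_id IS UNIQUE"
        for label in labels
    ]


def create_constraints(driver, labels=NODE_LABELS):
    """Run the constraint statements; IF NOT EXISTS makes re-runs harmless."""
    with driver.session() as session:
        for statement in constraint_statements(labels):
            session.run(statement).consume()
```

Call `create_constraints(driver)` once per database. If a label already contains duplicate `vertex_id` values, creating its constraint fails until the duplicates are removed.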

In [None]:
class GraphBuilder:
    """Handles knowledge graph construction operations."""

    def __init__(self, driver):
        self.driver = driver

    def create_car_models(self, df_car):
        """Insert car model nodes into the graph."""
        def insert_car_model(tx, row):
            query = """
            MERGE (c:CarModel {vertex_id: $vertex_id})
            SET c.name = $name,
                c.number = $number,
                c.year = toInteger($year),
                c.type = $type,
                c.engine_type = $engine_type,
                c.size = $size,
                c.seats = toInteger($seats)
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_car.iterrows():
                session.execute_write(insert_car_model, row.to_dict())

        print(f"Created {len(df_car)} car model nodes")

    def create_features(self, df_feature):
        """Insert feature nodes into the graph."""
        def insert_feature(tx, row):
            query = """
            MERGE (f:Feature {vertex_id: $vertex_id})
            SET f.name = $name,
                f.number = $number,
                f.type = $type,
                f.state = $state
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_feature.iterrows():
                session.execute_write(insert_feature, row.to_dict())

        print(f"Created {len(df_feature)} feature nodes")

    def create_parts(self, df_part):
        """Insert part nodes into the graph."""
        def insert_part(tx, row):
            query = """
            MERGE (p:Part {vertex_id: $vertex_id})
            SET p.name = $name,
                p.number = $number,
                p.price = toFloat($price),
                p.date = date($date)
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_part.iterrows():
                session.execute_write(insert_part, row.to_dict())

        print(f"Created {len(df_part)} part nodes")

    def create_suppliers(self, df_supplier):
        """Insert supplier nodes into the graph."""
        def insert_supplier(tx, row):
            query = """
            MERGE (s:Supplier {vertex_id: $vertex_id})
            SET s.name = $name,
                s.address = $address,
                s.contact = $contact,
                s.phone_number = $phone_number
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_supplier.iterrows():
                session.execute_write(insert_supplier, row.to_dict())

        print(f"Created {len(df_supplier)} supplier nodes")

# Initialize graph builder (requires all four node files)
if all(key in data for key in ['car_models', 'features', 'parts', 'suppliers']):
    graph_builder = GraphBuilder(driver)

    # Create nodes
    graph_builder.create_car_models(data['car_models'])
    graph_builder.create_features(data['features'])
    graph_builder.create_parts(data['parts'])
    graph_builder.create_suppliers(data['suppliers'])
else:
    print("Skipping graph construction - data not loaded")
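
Each row above runs in its own write transaction, so loading a file costs one round trip and one commit per row. For larger files, a common alternative is to send rows in batches and expand them on the server with `UNWIND`. A minimal sketch for the `Part` nodes, assuming the same columns as `create_parts`; `chunked` and `create_parts_batched` are hypothetical helpers, and a batch size of 1000 is an arbitrary starting point.

```python
PART_BATCH_QUERY = """
UNWIND $rows AS row
MERGE (p:Part {vertex_id: row.vertex_id})
SET p.name = row.name,
    p.number = row.number,
    p.price = toFloat(row.price),
    p.date = date(row.date)
"""


def chunked(items, size):
    """Yield consecutive slices of `items` with at most `size` elements each."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def _write_part_batch(tx, batch):
    """Transaction function: upsert one batch of Part rows."""
    tx.run(PART_BATCH_QUERY, rows=batch).consume()


def create_parts_batched(driver, df_part, batch_size=1000):
    """Insert Part nodes one batch per transaction instead of one row per transaction."""
    # Cast to object first so NaN can be replaced by None (sent to Neo4j as null)
    rows = df_part.astype(object).where(df_part.notna(), None).to_dict('records')
    with driver.session() as session:
        for batch in chunked(rows, batch_size):
            session.execute_write(_write_part_batch, batch)
    print(f"Upserted {len(rows)} part nodes in batches of {batch_size}")
```

The same pattern applies to the other node and relationship loaders: only the query text changes, with `$name` parameters becoming `row.name` fields.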

## Relationship Creation

Creating the `WITH_FEATURE`, `IS_COMPOSED_OF`, and `IS_SUPPLIED_BY` relationships that link car models to features, features to parts, and parts to suppliers.

In [None]:
class RelationshipBuilder:
    """Handles relationship creation in the knowledge graph."""

    def __init__(self, driver):
        self.driver = driver

    def create_with_feature_relationships(self, df_with_feature):
        """Create WITH_FEATURE relationships between cars and features."""
        def insert_relationship(tx, row):
            query = """
            MATCH (c:CarModel {vertex_id: $src_node_id})
            MATCH (f:Feature {vertex_id: $dst_node_id})
            MERGE (c)-[:WITH_FEATURE {version: $version}]->(f)
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_with_feature.iterrows():
                session.execute_write(insert_relationship, row.to_dict())

        print(f"Created {len(df_with_feature)} WITH_FEATURE relationships")

    def create_composed_of_relationships(self, df_is_composed_of):
        """Create IS_COMPOSED_OF relationships between features and parts."""
        def insert_relationship(tx, row):
            query = """
            MATCH (f:Feature {vertex_id: $src_id})
            MATCH (p:Part {vertex_id: $dst_id})
            MERGE (f)-[:IS_COMPOSED_OF {version: $version}]->(p)
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_is_composed_of.iterrows():
                session.execute_write(insert_relationship, row.to_dict())

        print(f"Created {len(df_is_composed_of)} IS_COMPOSED_OF relationships")

    def create_supplied_by_relationships(self, df_is_supplied_by):
        """Create IS_SUPPLIED_BY relationships between parts and suppliers."""
        def insert_relationship(tx, row):
            query = """
            MATCH (p:Part {vertex_id: $src_id})
            MATCH (s:Supplier {vertex_id: $dst_id})
            MERGE (p)-[:IS_SUPPLIED_BY {version: $version}]->(s)
            """
            tx.run(query, **row)

        with self.driver.session() as session:
            for _, row in df_is_supplied_by.iterrows():
                session.execute_write(insert_relationship, row.to_dict())

        print(f"Created {len(df_is_supplied_by)} IS_SUPPLIED_BY relationships")

# Create relationships
if all(key in data for key in ['with_feature', 'is_composed_of', 'is_supplied_by']):
    rel_builder = RelationshipBuilder(driver)

    rel_builder.create_with_feature_relationships(data['with_feature'])
    rel_builder.create_composed_of_relationships(data['is_composed_of'])
    rel_builder.create_supplied_by_relationships(data['is_supplied_by'])

    print("\nKnowledge graph construction completed!")
else:
    print("Skipping relationship creation - data not loaded")
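
Because each relationship insert uses `MATCH` on both endpoints, a row whose source or destination id has no matching node is skipped without an error, so the printed totals count rows processed rather than relationships created. One way to check is to count the relationships that actually exist per type and compare them with the CSV row counts. A sketch; `relationship_counts` is a hypothetical helper.

```python
RELATIONSHIP_COUNT_QUERY = """
MATCH ()-[r]->()
RETURN type(r) AS rel_type, count(r) AS count
ORDER BY rel_type
"""


def relationship_counts(driver):
    """Return {relationship type: number of relationships currently in the graph}."""
    with driver.session() as session:
        result = session.run(RELATIONSHIP_COUNT_QUERY)
        return {record['rel_type']: record['count'] for record in result}
```

For example, `relationship_counts(driver).get('WITH_FEATURE', 0)` can be compared with `len(data['with_feature'])`. A lower count points to unmatched ids, or to duplicate rows that `MERGE` collapsed into one relationship.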

## LLM Integration and Query Processing

Setting up the Groq LLM client. `QueryProcessor` translates a question into Cypher, runs it against Neo4j, and turns the returned records into a natural-language answer.

In [None]:
class QueryProcessor:
    """Handles LLM-powered query processing and Cypher generation."""

    def __init__(self, groq_api_key, driver):
        self.client = Groq(api_key=groq_api_key)
        self.driver = driver
        self.system_prompt = self._create_system_prompt()

    def _create_system_prompt(self):
        """Create system prompt for Cypher query generation."""
        return """
        You are a Cypher query generator for a supply chain graph in Neo4j.

        The graph includes:

        Nodes:
        - CarModel: vertex_id, name, number, year, type, engine_type, size, seats
        - Feature: vertex_id, name, number, type, state
        - Part: vertex_id, name, number, price, date
        - Supplier: vertex_id, name, address, contact, phone_number

        Relationships:
        - (CarModel)-[:WITH_FEATURE {version}]->(Feature)
        - (Feature)-[:IS_COMPOSED_OF {version}]->(Part)
        - (Part)-[:IS_SUPPLIED_BY {version}]->(Supplier)

        Rules:
        - Use MATCH and RETURN only
        - Filter by number, name, or vertex_id
        - Return just the Cypher query without additional context

        Examples:
        Q: Features of C1000
        A: MATCH (c:CarModel {number: "C1000"})-[:WITH_FEATURE]->(f:Feature) RETURN f.name;

        Q: Tell me about Car A
        A: MATCH (c:CarModel {name: "Model A"}) RETURN c;

        Q: Parts used in feature f_13
        A: MATCH (:Feature {vertex_id: "f_13"})-[:IS_COMPOSED_OF]->(p:Part) RETURN p.name;
        """

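    def generate_cypher(self, user_query):
        """Generate Cypher query from natural language."""
        try:
            prompt = f"{self.system_prompt}\n\nThe query from user is: {user_query}"

            response = self.client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=[{"role": "user", "content": prompt}]
            )

            cypher = response.choices[0].message.content.strip()

            # Models sometimes wrap the query in a Markdown code fence;
            # drop the fence lines so Neo4j receives plain Cypher
            if cypher.startswith("```"):
                lines = cypher.splitlines()[1:]
                if lines and lines[-1].strip().startswith("```"):
                    lines = lines[:-1]
                cypher = "\n".join(lines).strip()

            return cypher

        except Exception as e:
            print(f"Error generating Cypher: {e}")
            return None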

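    def execute_cypher(self, cypher_query):
        """Execute a generated Cypher query against Neo4j in a read transaction."""
        def read_records(tx):
            return [record.data() for record in tx.run(cypher_query)]

        try:
            with self.driver.session() as session:
                # execute_read marks the transaction as read-only; servers that
                # enforce access modes reject write clauses in LLM-generated queries
                return session.execute_read(read_records)
        except Exception as e:
            print(f"Error executing query: {e}")
            return None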

    def generate_response(self, user_query, retrieved_data):
        """Generate natural language response using retrieved data."""
        try:
            system_message = (
                "You are a GraphRAG Client communication AI for XYZ Automobile services. "
                "You will be provided with retrieved data from the graph database and the user's query. "
                "Provide helpful, accurate responses based on the data."
            )

            user_message = f"Retrieved context: {retrieved_data}\nQuery: {user_query}"

            response = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                model="llama-3.3-70b-versatile"
            )

            return response.choices[0].message.content

        except Exception as e:
            print(f"Error generating response: {e}")
            return "Sorry, I encountered an error processing your request."

# Initialize query processor
query_processor = QueryProcessor(GROQ_API_KEY, driver)
print("Query processor initialized successfully!")
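
The system prompt asks for `MATCH`/`RETURN` only, but nothing checks the generated text before it reaches `execute_cypher`. A lightweight keyword check can reject obvious write clauses up front. This is a sketch under the assumption that a whole-word match is good enough: it can give false positives (for example a string literal containing "delete"), it does not catch procedures invoked with `CALL` that write, and it is no substitute for a read-only database user. `is_read_only_cypher` is a hypothetical helper.

```python
import re

# Cypher clauses that modify data or schema
WRITE_KEYWORDS = ('CREATE', 'MERGE', 'DELETE', 'DETACH', 'SET', 'REMOVE',
                  'DROP', 'LOAD CSV', 'FOREACH')

# Whole-word, case-insensitive match; "LOAD CSV" tolerates any run of whitespace
_WRITE_PATTERN = re.compile(
    r"\b(" + "|".join(k.replace(' ', r'\s+') for k in WRITE_KEYWORDS) + r")\b",
    re.IGNORECASE,
)


def is_read_only_cypher(cypher_query):
    """Return True if no write keyword appears as a whole word in the query."""
    return _WRITE_PATTERN.search(cypher_query) is None
```

In `run_graphrag_query`, the check would sit between generating and executing the query, for example returning an error message when `is_read_only_cypher(cypher_query)` is `False`.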

## Interactive Query Interface

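Interactive system for querying the supply chain knowledge graph. `run_graphrag_query` chains the three `QueryProcessor` steps and prints the generated Cypher, the number of records retrieved, and the final answer.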

In [None]:
def run_graphrag_query(user_query):
    """Execute complete GraphRAG pipeline for a user query."""
    print(f"Processing query: {user_query}")
    print("="*50)

    # Step 1: Generate Cypher query
    cypher_query = query_processor.generate_cypher(user_query)
    if not cypher_query:
        return "Failed to generate query"

    print(f"Generated Cypher:\n{cypher_query}\n")

    # Step 2: Execute query
    retrieved_data = query_processor.execute_cypher(cypher_query)
    if retrieved_data is None:
        return "Failed to execute query"

    print(f"Retrieved {len(retrieved_data)} records\n")

    # Step 3: Generate natural language response
    response = query_processor.generate_response(user_query, retrieved_data)

    print("AI Response:")
    print("-" * 30)
    print(response)

    return response

# Example usage
print("GraphRAG system ready for queries!")
print("\nExample queries:")
print("- 'Features of Model A'")
print("- 'Compare Model A and Model C'")
print("- 'Parts used in sunroof feature'")
print("- 'Suppliers for brake parts'")
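
`generate_response` interpolates the raw Python list of records into the prompt. For large results this can exceed the model's context window, and Neo4j temporal values such as `Part.date` may show up as driver object reprs rather than plain dates. One option is to serialize a capped number of records as JSON before passing them on. A sketch; the `format_context` name and the 50-record cap are assumptions.

```python
import json


def format_context(records, max_records=50):
    """Serialize up to `max_records` query results as JSON for the LLM prompt."""
    shown = records[:max_records]
    # default=str converts values json cannot encode (e.g. neo4j.time.Date) to strings
    text = json.dumps(shown, default=str, indent=2)
    if len(records) > max_records:
        text += f"\n(truncated: showing {max_records} of {len(records)} records)"
    return text
```

Usage: `query_processor.generate_response(user_query, format_context(retrieved_data))` in place of passing `retrieved_data` directly.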

In [None]:
# Interactive query cell - Run this to ask questions
user_query = input("🎤 Ask your supply chain question: ")
run_graphrag_query(user_query)

## Utility Functions

Additional utility functions for database management.

In [None]:
def get_graph_statistics():
    """Get overview statistics of the knowledge graph."""
    queries = {
        'Car Models': 'MATCH (c:CarModel) RETURN count(c) as count',
        'Features': 'MATCH (f:Feature) RETURN count(f) as count',
        'Parts': 'MATCH (p:Part) RETURN count(p) as count',
        'Suppliers': 'MATCH (s:Supplier) RETURN count(s) as count',
        'Total Relationships': 'MATCH ()-[r]->() RETURN count(r) as count'  # directed pattern counts each relationship once
    }

    print("Knowledge Graph Statistics:")
    print("=" * 35)

    with driver.session() as session:
        for name, query in queries.items():
            result = session.run(query)
            count = result.single()['count']
            print(f"{name:<20}: {count:>10,}")

def cleanup_database():
    """Clean up database (use with caution!)."""
    confirmation = input("⚠️  Are you sure you want to delete all data? (type 'YES' to confirm): ")

    if confirmation == 'YES':
        with driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
        print("Database cleaned successfully!")
    else:
        print("Operation cancelled.")

def close_connection():
    """Close database connection."""
    if driver:
        driver.close()
        print("Database connection closed.")

# Display current graph statistics
get_graph_statistics()
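
`cleanup_database` deletes everything in a single transaction, which can exhaust memory on a large graph. On Neo4j 4.4 and later, `CALL { ... } IN TRANSACTIONS` commits the work in smaller batches; it is only allowed in an auto-commit transaction, which is what `session.run` uses. A sketch; the 10,000-row batch size and both function names are assumptions.

```python
def batched_delete_query(batch_size):
    """Build a DETACH DELETE statement that commits every `batch_size` nodes."""
    # The batch size is inlined as an integer literal rather than passed as a parameter
    return ("MATCH (n) CALL { WITH n DETACH DELETE n } "
            f"IN TRANSACTIONS OF {int(batch_size)} ROWS")


def cleanup_database_batched(driver, batch_size=10000):
    """Delete all nodes and relationships in batches (Neo4j 4.4+)."""
    with driver.session() as session:
        # session.run opens an auto-commit transaction, as IN TRANSACTIONS requires
        session.run(batched_delete_query(batch_size)).consume()
```

It can replace the `session.run("MATCH (n) DETACH DELETE n")` line inside the `YES` branch of `cleanup_database`, keeping the confirmation prompt.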

## Conclusion

This GraphRAG implementation demonstrates:

### Achievements
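- **Knowledge Graph**: Car models, features, parts, and suppliers linked by `WITH_FEATURE`, `IS_COMPOSED_OF`, and `IS_SUPPLIED_BY` relationships
- **Natural Language Processing**: LLM-based translation of user questions into Cypher
- **Context-Aware Responses**: LLM-powered answers grounded in the records retrieved from the graph
- **Modular Architecture**: Separate `GraphBuilder`, `RelationshipBuilder`, and `QueryProcessor` classes that can be extended independently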

### Next Steps
- Add graph visualization capabilities
- Implement query optimization
- Add more complex relationship types
- Integrate with web interface

### Resources
- [Neo4j Documentation](https://neo4j.com/docs/)
- [Groq API Documentation](https://groq.com/)
- [Graph Database Best Practices](https://neo4j.com/developer/)

---

**Remember**: Always close the database connection when finished:
```
close_connection()
```