In [1]:
from dotenv import load_dotenv
load_dotenv()

import os, json, hashlib
import fitz
from groq import Groq
from neo4j import GraphDatabase

from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Dict, Any

from sentence_transformers import SentenceTransformer
import numpy as np

  from pydantic.v1.fields import FieldInfo as FieldInfoV1


In [2]:
# ===== IMPROVED: Configuration and Constants =====
import os
from typing import Optional

# Embedding configuration
EMBEDDING_DIM = 384  # Standardized dimension for all-MiniLM-L6-v2
HF_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
MAX_TEXT_CHUNK = 8000  # Max characters per chunk
CHUNK_OVERLAP = 500    # Overlap between chunks

# API Configuration
GROQ_MODEL = "llama-3.1-8b-instant"
MAX_TOKENS = 8000
TEMPERATURE = 0.1

# Retry configuration
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds

def validate_env_vars():
    """Validate required environment variables"""
    required = ['GROQ_API_KEY', 'HF_TOKEN', 'NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD']
    missing = [var for var in required if not os.getenv(var)]
    if missing:
        raise ValueError(f"Missing required environment variables: {', '.join(missing)}")
    return True

# Validate on import
try:
    validate_env_vars()
    print("‚úÖ Environment variables validated")
except ValueError as e:
    print(f"‚ö†Ô∏è {e}")



‚úÖ Environment variables validated


In [3]:
# ===== IMPROVED: Retry Logic with Exponential Backoff =====
import time
from functools import wraps

def retry_with_backoff(max_retries=MAX_RETRIES, delay=RETRY_DELAY, backoff=2):
    """Decorator for retrying functions with exponential backoff"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            current_delay = delay
            while retries < max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    retries += 1
                    if retries >= max_retries:
                        print(f"‚ùå {func.__name__} failed after {max_retries} retries: {str(e)}")
                        raise
                    print(f"‚ö†Ô∏è {func.__name__} failed (attempt {retries}/{max_retries}), retrying in {current_delay}s...")
                    time.sleep(current_delay)
                    current_delay *= backoff
            return None
        return wrapper
    return decorator

print("‚úÖ Retry utility loaded")



‚úÖ Retry utility loaded


In [4]:
# ===== IMPROVED: Text Chunking for Large Contracts =====
def chunk_text(text: str, max_size: int = MAX_TEXT_CHUNK, overlap: int = CHUNK_OVERLAP) -> list:
    """
    Split large text into overlapping chunks for processing.
    
    Args:
        text: Input text to chunk
        max_size: Maximum characters per chunk
        overlap: Number of characters to overlap between chunks
    
    Returns:
        List of text chunks
    """
    if len(text) <= max_size:
        return [text]
    
    chunks = []
    start = 0
    
    while start < len(text):
        end = start + max_size
        chunk = text[start:end]
        chunks.append(chunk)
        
        # Move start position with overlap
        start = end - overlap
        if start >= len(text):
            break
    
    return chunks

def merge_chunk_analyses(chunk_results: list) -> dict:
    """
    Merge analysis results from multiple chunks into a single result.
    
    Args:
        chunk_results: List of analysis dictionaries from chunks
    
    Returns:
        Merged analysis dictionary
    """
    if not chunk_results:
        return {}
    
    merged = {
        "title": chunk_results[0].get("title", ""),
        "parties": [],
        "dates": [],
        "governing_law": chunk_results[0].get("governing_law", ""),
        "clauses": []
    }
    
    # Collect unique parties
    seen_parties = set()
    for result in chunk_results:
        for party in result.get("parties", []):
            party_key = str(party.get("name", "")) + str(party.get("role", ""))
            if party_key not in seen_parties:
                merged["parties"].append(party)
                seen_parties.add(party_key)
    
    # Collect unique dates
    seen_dates = set()
    for result in chunk_results:
        for date in result.get("dates", []):
            date_key = str(date.get("value", "")) + str(date.get("type", ""))
            if date_key not in seen_dates:
                merged["dates"].append(date)
                seen_dates.add(date_key)
    
    # Collect all clauses
    seen_clauses = set()
    for result in chunk_results:
        for clause in result.get("clauses", []):
            clause_key = clause.get("clause_name", "")
            if clause_key and clause_key not in seen_clauses:
                merged["clauses"].append(clause)
                seen_clauses.add(clause_key)
    
    return merged

print("‚úÖ Chunking utilities loaded")



‚úÖ Chunking utilities loaded


In [5]:
# ===== IMPROVED: Enhanced Embedding Function with Retry =====
@retry_with_backoff(max_retries=MAX_RETRIES)
def get_embeddings_api_improved(text: str) -> Optional[list]:
    """
    Get embeddings using HuggingFace Inference API with improved error handling.
    
    Args:
        text: Text to generate embeddings for
    
    Returns:
        List of floats (384 dimensions) or None if failed
    """
    if not text or not text.strip():
        return None
    
    headers = {
        "Authorization": f"Bearer {os.getenv('HF_TOKEN')}",
        "Content-Type": "application/json"
    }
    
    payload = {
        "inputs": text[:MAX_TEXT_CHUNK],  # Limit input size
        "options": {"wait_for_model": True}
    }
    
    try:
        response = requests.post(
            f"https://router.huggingface.co/hf-inference/models/{HF_EMBED_MODEL}",
            headers=headers,
            json=payload,
            timeout=30
        )
        
        if response.status_code == 200:
            result = response.json()
            
            # Handle nested list response (token embeddings)
            if isinstance(result, list) and len(result) > 0:
                if isinstance(result[0], list):
                    # Mean pool token embeddings
                    import numpy as np
                    emb = np.mean(result, axis=0).tolist()
                else:
                    emb = result
            else:
                emb = result
            
            # Validate dimension
            if isinstance(emb, list) and len(emb) == EMBEDDING_DIM:
                return emb
            elif isinstance(emb, list) and len(emb) > 0:
                # Truncate or pad to correct dimension
                if len(emb) > EMBEDDING_DIM:
                    return emb[:EMBEDDING_DIM]
                else:
                    return emb + [0.0] * (EMBEDDING_DIM - len(emb))
            else:
                return None
        else:
            print(f"‚ö†Ô∏è HF API returned status {response.status_code}")
            return None
            
    except requests.exceptions.Timeout:
        print("‚ö†Ô∏è HF API request timed out")
        return None
    except Exception as e:
        print(f"‚ö†Ô∏è HF API error: {str(e)}")
        return None

def validate_embedding(emb: Optional[list], expected_dim: int = EMBEDDING_DIM) -> list:
    """
    Validate and fix embedding dimensions.
    
    Args:
        emb: Embedding vector or None
        expected_dim: Expected dimension (default 384)
    
    Returns:
        Valid embedding vector of correct dimension
    """
    if emb is None or not isinstance(emb, list):
        return [0.0] * expected_dim
    
    if len(emb) == expected_dim:
        return emb
    elif len(emb) > expected_dim:
        return emb[:expected_dim]
    else:
        return emb + [0.0] * (expected_dim - len(emb))

print("‚úÖ Improved embedding functions loaded")



‚úÖ Improved embedding functions loaded


In [6]:
# ===== IMPROVED: Enhanced JSON Parsing with Better Error Recovery =====
import re
import json

def parse_llm_json(content: str, max_attempts: int = 3) -> Optional[dict]:
    """
    Parse JSON from LLM output with multiple fallback strategies.
    
    Args:
        content: Raw LLM output string
        max_attempts: Maximum parsing attempts
    
    Returns:
        Parsed dictionary or None
    """
    if not content:
        return None
    
    # Strategy 1: Direct JSON parse
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass
    
    # Strategy 2: Remove markdown code blocks
    cleaned = content
    if "```json" in cleaned:
        cleaned = cleaned.split("```json")[1].split("```")[0].strip()
    elif "```" in cleaned:
        cleaned = cleaned.split("```")[1].split("```")[0].strip()
    
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    
    # Strategy 3: Extract JSON object boundaries
    if "{" in cleaned and "}" in cleaned:
        start = cleaned.find("{")
        end = cleaned.rfind("}") + 1
        json_str = cleaned[start:end]
        
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass
    
    # Strategy 4: Fix common JSON issues
    try:
        # Fix unescaped quotes in strings
        fixed = re.sub(r'(?<!\\)"(?=\w)', r'\\"', cleaned)
        # Remove trailing commas
        fixed = re.sub(r',\s*}', '}', fixed)
        fixed = re.sub(r',\s*]', ']', fixed)
        
        return json.loads(fixed)
    except json.JSONDecodeError as e:
        print(f"‚ö†Ô∏è JSON parsing failed after all attempts: {e}")
        return None

def validate_analysis_data(data: dict) -> dict:
    """
    Validate and fix analysis data structure.
    
    Args:
        data: Analysis dictionary from LLM
    
    Returns:
        Validated and fixed analysis dictionary
    """
    if not isinstance(data, dict):
        return {
            "title": "Unknown Contract",
            "parties": [],
            "dates": [],
            "governing_law": "Not Specified",
            "clauses": []
        }
    
    # Ensure required fields exist
    validated = {
        "title": data.get("title", "Unknown Contract"),
        "parties": data.get("parties", []),
        "dates": data.get("dates", []),
        "governing_law": data.get("governing_law", "Not Specified"),
        "clauses": data.get("clauses", [])
    }
    
    # Validate clauses
    validated_clauses = []
    for clause in validated["clauses"]:
        if isinstance(clause, dict):
            validated_clause = {
                "clause_name": clause.get("clause_name", "Unnamed Clause"),
                "summary": clause.get("summary", ""),
                "risk_level": clause.get("risk_level", "MEDIUM"),
                "risk_reason": clause.get("risk_reason", ""),
                "obligation": clause.get("obligation", ""),
                "liability": clause.get("liability", ""),
                "ai_summary": clause.get("ai_summary", "")
            }
            # Ensure risk_level is valid
            if validated_clause["risk_level"] not in ["LOW", "MEDIUM", "HIGH"]:
                validated_clause["risk_level"] = "MEDIUM"
            validated_clauses.append(validated_clause)
    
    validated["clauses"] = validated_clauses
    return validated

print("‚úÖ Improved JSON parsing utilities loaded")



‚úÖ Improved JSON parsing utilities loaded


In [7]:
# ===== HF EMBEDDING VIA API (NO MODEL DOWNLOAD) =====
import requests

HF_TOKEN = os.environ.get("HF_TOKEN")
HF_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Reliable model
HF_API_URL = (
    f"https://router.huggingface.co/hf-inference/models/{HF_EMBED_MODEL}"
)
# Groq
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Neo4j - Auto-fix Aura URI for SSL compatibility
uri = os.environ["NEO4J_URI"]

# Fix for Neo4j Aura: convert neo4j+s:// to neo4j+ssc:// (uses system cert store)
if "neo4j+s://" in uri and "neo4j+ssc://" not in uri:
    uri = uri.replace("neo4j+s://", "neo4j+ssc://")
    print(f"üîÑ Updated URI for Aura compatibility: {uri[:50]}...")

neo4j_driver = GraphDatabase.driver(
    uri,
    auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
)

# Test connection
try:
    neo4j_driver.verify_connectivity()
    print("‚úÖ Neo4j Aura connection successful!")
except Exception as e:
    print(f"‚ùå Neo4j connection failed: {e}")
    print("üí° Make sure:")
    print("   1. Your NEO4J_URI uses neo4j+ssc:// or neo4j+s://")
    print("   2. Username and password are correct")
    print("   3. Your Aura database is running")
    raise

def get_embeddings_api(text):
    """Get embeddings using HuggingFace Inference API"""
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    payload = {"inputs": text, "options": {"wait_for_model": True}}
    
    try:
        response = requests.post(HF_API_URL, headers=headers, json=payload, timeout=2)
        
        if response.status_code == 200:
            result = response.json()
            # Handle different response formats
            if isinstance(result, list):
                if isinstance(result[0], list):
                    return result[0]  # Nested list
                return result
            return result
        else:
            # print(f"‚ùå API Error: {response.status_code} - {response.text}")
            return None
    except Exception as e:
        print(f"‚ùå Exception: {str(e)}")
        return None

print("‚úÖ Clients initialized (API mode - no model download)")

üîÑ Updated URI for Aura compatibility: neo4j+ssc://8c5998b9.databases.neo4j.io...
‚úÖ Neo4j Aura connection successful!
‚úÖ Clients initialized (API mode - no model download)


In [8]:
def pdf_hash(path):
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()

def extract_text(pdf):
    doc = fitz.open(pdf)
    return "".join(p.get_text() for p in doc)

def safe_str(v):
    return v.strip() if isinstance(v, str) else None

def normalize_list(items, key=None):
    out = []
    for i in items:
        if isinstance(i, dict):
            val = i.get(key) if key else None
            if val:
                out.append(val)
        elif isinstance(i, str):
            out.append(i)
    return out

In [9]:
def print_contract_summary(data):
    """Enhanced summary printer with all details"""
    print("\n" + "="*80)
    print("üìÑ CONTRACT SUMMARY")
    print("="*80)

    print(f"\nüìå BASIC INFORMATION")
    print("-"*80)
    print(f"Title          : {data.get('title', 'N/A')}")
    print(f"File Name      : {data.get('file_name', 'N/A')}")
    print(f"Contract ID    : {data.get('contract_id', 'N/A')}")
    print(f"Governing Law  : {data.get('governing_law', 'N/A')}")

    # Parties
    print(f"\nüë• PARTIES ({len(data.get('parties', []))})")
    print("-"*80)
    for i, p in enumerate(data.get('parties', []), 1):
        print(f"  [{i}] {p}")

    # Important Dates
    print(f"\nüìÖ IMPORTANT DATES ({len(data.get('dates', []))})")
    print("-"*80)
    for i, d in enumerate(data.get('dates', []), 1):
        print(f"  [{i}] {d}")

    # Clauses with Risk Analysis
    print(f"\n‚öñÔ∏è CLAUSE RISK ANALYSIS ({len(data.get('clauses', []))})")
    print("="*80)
    for i, c in enumerate(data.get("clauses", []), 1):
        print(f"\n[Clause {i}] {c.get('clause_name', 'Unnamed')}")
        print("-"*80)
        print(f"Summary      : {c.get('summary', 'N/A')}")
        print(f"\nüö® Risk Level : {c.get('risk_level', 'N/A')}")
        print(f"Risk Reason  : {c.get('risk_reason', 'N/A')}")
        print(f"\nüìã Obligation : {c.get('obligation', 'N/A')}")
        print(f"üíº Liability  : {c.get('liability', 'N/A')}")
        print(f"\nü§ñ AI Summary : {c.get('ai_summary', 'N/A')}")
        print("-"*80)
    
    print("\n" + "="*80)

In [10]:
class ContractState(TypedDict):
    pdf_path: str
    cid: str
    text: str
    embeddings: List[float]
    analysis: Dict[str, Any]

In [11]:
def pdf_extraction_agent(state: ContractState):
    print(f"\nüìÑ Extracting PDF: {state['pdf_path']}")
    text = extract_text(state["pdf_path"])
    return {
        **state,
        "text": text
    }

In [12]:
def embedding_agent(state: ContractState):
    print("üî¢ Generating embeddings via HuggingFace API")

    # Limit chunk size for HF inference safety
    text_chunk = state["text"][:8000]

    emb = get_embeddings_api(text_chunk)

    # Fallback if API fails or returns bad data
    if (
        emb is None
        or not isinstance(emb, list)
        or len(emb) == 0
    ):
        print("‚ö†Ô∏è Using fallback embeddings")
        emb = [0.0] * 384  # all-MiniLM-L6-v2 ‚Üí 384 dims

    # Handle nested response ([[...]])
    if isinstance(emb[0], list):
        emb = emb[0]

    # Final dimension guard (VERY important for Neo4j vector index)
    if len(emb) != 384:
        print(f"‚ö†Ô∏è Invalid embedding size {len(emb)}, forcing fallback")
        emb = [0.0] * 384

    print(f"   Embedding dimension: {len(emb)}")

    return {
        **state,
        "embeddings": emb
    }
def get_embeddings_api(text):
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json"
    }

    payload = {
        "inputs": text,
        "options": {
            "wait_for_model": True
        }
    }

    try:
        response = requests.post(
            HF_API_URL,
            headers=headers,
            json=payload,
            timeout=30
        )

        if response.status_code == 200:
            result = response.json()

            # HF returns: [ [token_embeddings...] ]
            # We must MEAN POOL
            if isinstance(result, list) and isinstance(result[0], list):
                import numpy as np
                return np.mean(result, axis=0).tolist()

            print("‚ö†Ô∏è Unexpected HF response:", result)
            return None

        else:
            # print(f"‚ùå API Error: {response.status_code} - {response.text}")
            return None

    except Exception as e:
        print(f"‚ùå Exception: {e}")
        return None
    


In [13]:
# def get_embeddings_api(text):
#     headers = {
#         "Authorization": f"Bearer {HF_TOKEN}",
#         "Content-Type": "application/json"
#     }

#     payload = {
#         "inputs": text,
#         "options": {"wait_for_model": True}
#     }

#     try:
#         response = requests.post(
#             HF_API_URL,
#             headers=headers,
#             json=payload,
#             timeout=30
#         )

#         if response.status_code == 200:
#             result = response.json()

#             # token-level embeddings ‚Üí mean pooling
#             if isinstance(result, list) and isinstance(result[0], list):
#                 import numpy as np
#                 return np.mean(result, axis=0).tolist()

#             print("‚ö†Ô∏è Unexpected HF response:", result)
#             return None

#         print(f"‚ùå HF API Error {response.status_code}: {response.text}")
#         return None

#     except Exception as e:
#         print(f"‚ùå HF Exception: {e}")
#         return None


In [14]:
def analysis_agent(state: ContractState):
    print("üß† Analyzing contract via Groq LLM")

    # Enhanced prompt for better extraction
    prompt = f"""
You are a legal contract analyzer. Analyze the following contract and extract detailed information.

CRITICAL: Return ONLY valid JSON. No markdown, no explanations, just the JSON object.

{{
  "title": "Contract title",
  "parties": [
    {{"name": "Party 1 name", "role": "Role (e.g., Service Provider, Client)"}},
    {{"name": "Party 2 name", "role": "Role"}}
  ],
  "dates": [
    {{"type": "Effective Date", "value": "YYYY-MM-DD or as mentioned"}},
    {{"type": "Expiration Date", "value": "YYYY-MM-DD or as mentioned"}}
  ],
  "governing_law": "Jurisdiction and governing law",
  "clauses": [
    {{
      "clause_name": "Name of the clause",
      "summary": "Brief summary of what this clause says",
      "risk_level": "Low/Medium/High",
      "risk_reason": "Detailed explanation of why this risk level was assigned. Mention specific concerns, potential liabilities, or unfavorable terms.",
      "obligation": "Specific obligations this clause imposes on parties. Be detailed.",
      "liability": "What liabilities or penalties are mentioned in this clause. Include financial limits if any.",
      "ai_summary": "A comprehensive AI analysis of this clause including: 1) What it means in plain language, 2) Key takeaways, 3) Red flags or concerns, 4) Recommendations"
    }}
  ]
}}

IMPORTANT INSTRUCTIONS:
1. Extract ALL major clauses from the contract (aim for 5-10 clauses)
2. For risk_reason: Explain WHY you assigned that risk level with specific concerns
3. For ai_summary: Provide detailed analysis (at least 2-3 sentences)
4. Be thorough - don't leave fields empty
5. Focus on: payment terms, termination, liability, intellectual property, confidentiality, warranties, indemnification
6. ESCAPE all quotes inside strings properly using backslash
7. Do NOT include any text before or after the JSON object

CONTRACT TEXT:
{state["text"][:10000]}
"""

    try:
        res = groq_client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[
                {"role": "system", "content": "You are a JSON-only API. Return only valid JSON, no markdown or explanations."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1,
            max_tokens=8000
        )

        content = res.choices[0].message.content.strip()
        
        # Clean the response
        # Remove markdown code blocks
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0].strip()
        elif "```" in content:
            content = content.split("```")[1].split("```")[0].strip()
        
        # Remove any leading/trailing whitespace
        content = content.strip()
        
        # Try to find JSON object boundaries
        if not content.startswith("{"):
            start = content.find("{")
            if start != -1:
                content = content[start:]
        
        if not content.endswith("}"):
            end = content.rfind("}")
            if end != -1:
                content = content[:end+1]
        
        # Attempt to parse
        try:
            analysis = json.loads(content)
        except json.JSONDecodeError as e:
            print(f"‚ö†Ô∏è JSON parsing failed at position {e.pos}: {e.msg}")
            print(f"Problematic content around error: ...{content[max(0,e.pos-50):e.pos+50]}...")
            
            # Try to fix common issues
            import re
            
            # Fix unescaped quotes in strings
            content = re.sub(r'(?<!\\)"(?=\w)', r'\\"', content)
            
            # Try parsing again
            try:
                analysis = json.loads(content)
                print("‚úÖ JSON fixed and parsed successfully")
            except:
                print("‚ùå Could not fix JSON, using fallback structure")
                # Fallback structure
                analysis = {
                    "title": "Contract Analysis",
                    "parties": [{"name": "Party A", "role": "Unknown"}, {"name": "Party B", "role": "Unknown"}],
                    "dates": [{"type": "Effective Date", "value": "Not specified"}],
                    "governing_law": "Not specified",
                    "clauses": [
                        {
                            "clause_name": "General Terms",
                            "summary": "Contract terms extracted from document",
                            "risk_level": "Medium",
                            "risk_reason": "Unable to fully analyze due to parsing error. Manual review recommended.",
                            "obligation": "Review document manually for obligations",
                            "liability": "Review document manually for liabilities",
                            "ai_summary": "Automated analysis encountered an error. This contract requires manual legal review to identify all terms, conditions, and potential risks."
                        }
                    ]
                }
        
        print(f"‚úÖ Analysis complete - Found {len(analysis.get('clauses', []))} clauses")

        return {
            **state,
            "analysis": analysis
        }
        
    except Exception as e:
        print(f"‚ùå Analysis error: {str(e)}")
        # Return minimal fallback
        return {
            **state,
            "analysis": {
                "title": "Error in Analysis",
                "parties": [],
                "dates": [],
                "governing_law": "Unknown",
                "clauses": []
            }
        }

In [15]:
def store_graph_agent(state: ContractState):
    data = state["analysis"]
    cid = state["cid"]
    filename = os.path.basename(state["pdf_path"])
    embeddings = state["embeddings"]

    print("üóÑÔ∏è Storing into Neo4j with vector embeddings")

    with neo4j_driver.session() as s:

        # Contract with embeddings
        s.run("""
        MERGE (c:Contract {id:$id})
        SET c.title=$title,
            c.file_name=$file,
            c.governing_law=$law,
            c.embedding=$emb
        """, id=cid, title=data.get("title", "Unknown Contract"), file=filename, 
             law = data.get("governing_law", "Not Specified"), emb=embeddings)

        # Parties with roles
        for p in data.get("parties", []):
            if isinstance(p, dict):
                s.run("""
                MERGE (o:Organization {name:$name})
                SET o.role=$role
                WITH o
                MATCH (c:Contract {id:$id})
                MERGE (o)-[:IS_PARTY_TO]->(c)
                """, name=p.get("name"), role=p.get("role"), id=cid)

        # Dates with types
        for d in data.get("dates", []):
            if isinstance(d, dict):
                s.run("""
                MERGE (dt:ImportantDate {value:$v})
                SET dt.type=$type
                WITH dt
                MATCH (c:Contract {id:$id})
                MERGE (c)-[:HAS_DATE]->(dt)
                """, v=d.get("value"), type=d.get("type"), id=cid)

        # Clauses with all details
        for cl in data.get("clauses", []):
            # Generate embedding for each clause via API
            clause_text = f"{cl.get('clause_name', '')} {cl.get('summary', '')}"
            clause_emb = get_embeddings_api(clause_text)
            
            if clause_emb is None:
                clause_emb = [0.0] * 384  # Fixed: Use 384 to match all-MiniLM-L6-v2 model
            
            # Normalize if nested
            if isinstance(clause_emb[0], list):
                clause_emb = clause_emb[0]
            
            s.run("""
            MATCH (c:Contract {id:$id})
            CREATE (cl:Clause {
                name:$n, 
                summary:$s,
                embedding:$emb
            })

            MERGE (r:Risk {level:$rl})
            MERGE (rr:RiskReason {text:$rr})
            MERGE (o:Obligation {text:$ob})
            MERGE (l:Liability {text:$li})
            MERGE (ai:AISummary {text:$ai})

            CREATE (c)-[:HAS_CLAUSE]->(cl)
            CREATE (cl)-[:HAS_RISK]->(r)
            CREATE (cl)-[:HAS_REASON]->(rr)
            CREATE (cl)-[:HAS_OBLIGATION]->(o)
            CREATE (cl)-[:HAS_LIABILITY]->(l)
            CREATE (cl)-[:HAS_AI_SUMMARY]->(ai)
            """,
            id=cid,
            n=safe_str(cl.get("clause_name")),
            s=safe_str(cl.get("summary")),
            rl=safe_str(cl.get("risk_level")),
            rr=safe_str(cl.get("risk_reason")),
            ob=safe_str(cl.get("obligation")),
            li=safe_str(cl.get("liability")),
            ai=safe_str(cl.get("ai_summary")),
            emb=clause_emb
            )

    print("‚úÖ Stored successfully with vector embeddings")
    return state

In [16]:
graph = StateGraph(ContractState)

graph.add_node("extract", pdf_extraction_agent)
graph.add_node("embed", embedding_agent)
graph.add_node("analyze", analysis_agent)
graph.add_node("store", store_graph_agent)

graph.set_entry_point("extract")

graph.add_edge("extract", "embed")
graph.add_edge("embed", "analyze")
graph.add_edge("analyze", "store")
graph.add_edge("store", END)

workflow = graph.compile()
print("‚úÖ LangGraph workflow ready")

‚úÖ LangGraph workflow ready


In [17]:
def retrieve_contract_from_db(contract_id):
    """
    Retrieve complete contract details from Neo4j database
    """
    print(f"\nüîç Retrieving contract: {contract_id}")
    
    with neo4j_driver.session() as s:
        # Get contract with all related data
        result = s.run("""
        MATCH (c:Contract {id:$id})
        OPTIONAL MATCH (c)<-[:IS_PARTY_TO]-(org:Organization)
        OPTIONAL MATCH (c)-[:HAS_DATE]->(dt:ImportantDate)
        OPTIONAL MATCH (c)-[:HAS_CLAUSE]->(cl:Clause)
        OPTIONAL MATCH (cl)-[:HAS_RISK]->(r:Risk)
        OPTIONAL MATCH (cl)-[:HAS_REASON]->(rr:RiskReason)
        OPTIONAL MATCH (cl)-[:HAS_OBLIGATION]->(o:Obligation)
        OPTIONAL MATCH (cl)-[:HAS_LIABILITY]->(l:Liability)
        OPTIONAL MATCH (cl)-[:HAS_AI_SUMMARY]->(ai:AISummary)
        
        RETURN c.title as title,
               c.file_name as file_name,
               c.id as contract_id,
               c.governing_law as governing_law,
               collect(DISTINCT org.name) as parties,
               collect(DISTINCT dt.value) as dates,
               collect(DISTINCT {
                   clause_name: cl.name,
                   summary: cl.summary,
                   risk_level: r.level,
                   risk_reason: rr.text,
                   obligation: o.text,
                   liability: l.text,
                   ai_summary: ai.text
               }) as clauses
        """, id=contract_id)
        
        record = result.single()
        if record:
            data = {
                "title": record["title"],
                "file_name": record["file_name"],
                "contract_id": record["contract_id"],
                "governing_law": record["governing_law"],
                "parties": [p for p in record["parties"] if p],
                "dates": [d for d in record["dates"] if d],
                "clauses": [c for c in record["clauses"] if c.get("clause_name")]
            }
            print("‚úÖ Contract retrieved successfully")
            return data
        else:
            print("‚ùå Contract not found")
            return None

def retrieve_all_contracts():
    """
    Retrieve all contracts from database
    """
    print("\nüìö Retrieving all contracts...")
    
    with neo4j_driver.session() as s:
        result = s.run("""
        MATCH (c:Contract)
        RETURN c.id as id, c.title as title, c.file_name as file_name
        """)
        
        contracts = []
        for record in result:
            contracts.append({
                "id": record["id"],
                "title": record["title"],
                "file_name": record["file_name"]
            })
        
        print(f"‚úÖ Found {len(contracts)} contracts")
        return contracts

def search_similar_clauses(query_text, top_k=5):
    """
    Search for similar clauses using vector embeddings
    """
    print(f"\nüîç Searching for clauses similar to: '{query_text}'")
    
    # Generate embedding for query via API
    query_emb = get_embeddings_api(query_text)
    
    if query_emb is None:
        print("‚ùå Could not generate query embedding")
        return []
    
    # Normalize if nested
    if isinstance(query_emb[0], list):
        query_emb = query_emb[0]
    
    with neo4j_driver.session() as s:
        # Get all clauses with embeddings
        result = s.run("""
        MATCH (c:Contract)-[:HAS_CLAUSE]->(cl:Clause)
        WHERE cl.embedding IS NOT NULL
        RETURN c.title as contract_title,
               cl.name as clause_name,
               cl.summary as summary,
               cl.embedding as embedding
        """)
        
        clauses = []
        for record in result:
            # Calculate cosine similarity
            emb = np.array(record["embedding"])
            query = np.array(query_emb)
            
            # Handle different embedding dimensions
            if len(emb) != len(query):
                continue
            
            similarity = np.dot(emb, query) / (np.linalg.norm(emb) * np.linalg.norm(query))
            
            clauses.append({
                "contract": record["contract_title"],
                "clause": record["clause_name"],
                "summary": record["summary"],
                "similarity": float(similarity)
            })
        
        # Sort by similarity
        clauses.sort(key=lambda x: x["similarity"], reverse=True)
        
        print(f"\nüìä Top {top_k} similar clauses:")
        for i, clause in enumerate(clauses[:top_k], 1):
            print(f"\n[{i}] Similarity: {clause['similarity']:.4f}")
            print(f"    Contract: {clause['contract']}")
            print(f"    Clause: {clause['clause']}")
            print(f"    Summary: {clause['summary']}")
        
        return clauses[:top_k]

In [18]:
# ===== PROCESS CONTRACTS =====
pdfs = [
    "Legal-Services-Agreement.pdf",
    "Employment_contract.pdf",
    "sample_contract.pdf"
]

contract_ids = []

for pdf in pdfs:
    cid = pdf_hash(pdf)
    contract_ids.append(cid)

    print("\n" + "="*80)
    print(f"üöÄ Processing: {pdf}")
    print("="*80)

    workflow.invoke({
        "pdf_path": pdf,
        "cid": cid,
        "text": "",
        "embeddings": [],
        "analysis": {},
    })


üöÄ Processing: Legal-Services-Agreement.pdf

üìÑ Extracting PDF: Legal-Services-Agreement.pdf
üî¢ Generating embeddings via HuggingFace API


‚ö†Ô∏è Using fallback embeddings
   Embedding dimension: 384
üß† Analyzing contract via Groq LLM
‚úÖ Analysis complete - Found 12 clauses
üóÑÔ∏è Storing into Neo4j with vector embeddings
‚úÖ Stored successfully with vector embeddings

üöÄ Processing: Employment_contract.pdf

üìÑ Extracting PDF: Employment_contract.pdf
üî¢ Generating embeddings via HuggingFace API
‚ö†Ô∏è Using fallback embeddings
   Embedding dimension: 384
üß† Analyzing contract via Groq LLM
‚úÖ Analysis complete - Found 12 clauses
üóÑÔ∏è Storing into Neo4j with vector embeddings
‚úÖ Stored successfully with vector embeddings

üöÄ Processing: sample_contract.pdf

üìÑ Extracting PDF: sample_contract.pdf
üî¢ Generating embeddings via HuggingFace API
‚ö†Ô∏è Using fallback embeddings
   Embedding dimension: 384
üß† Analyzing contract via Groq LLM
‚úÖ Analysis complete - Found 6 clauses
üóÑÔ∏è Storing into Neo4j with vector embeddings
‚úÖ Stored successfully with vector embeddings


In [19]:
# ===== RETRIEVE AND DISPLAY RESULTS =====

print("\n\n" + "#"*80)
print("# RETRIEVING STORED CONTRACTS FROM DATABASE")
print("#"*80)

# Show all contracts
all_contracts = retrieve_all_contracts()

# Retrieve and display each processed contract
for cid in contract_ids:
    contract_data = retrieve_contract_from_db(cid)
    if contract_data:
        print_contract_summary(contract_data)



################################################################################
# RETRIEVING STORED CONTRACTS FROM DATABASE
################################################################################

üìö Retrieving all contracts...
‚úÖ Found 3 contracts

üîç Retrieving contract: 1de79b4ffc94b68989ab72c79d929fc5169e8c48b1ea78fc7ee7c8bcd0b1c7d2
‚úÖ Contract retrieved successfully

üìÑ CONTRACT SUMMARY

üìå BASIC INFORMATION
--------------------------------------------------------------------------------
Title          : Legal Services Agreement
File Name      : Legal-Services-Agreement.pdf
Contract ID    : 1de79b4ffc94b68989ab72c79d929fc5169e8c48b1ea78fc7ee7c8bcd0b1c7d2
Governing Law  : California law

üë• PARTIES (2)
--------------------------------------------------------------------------------
  [1] Law Firm
  [2] Client

üìÖ IMPORTANT DATES (1)
--------------------------------------------------------------------------------
  [1] Not specified

‚öñÔ∏è CLAUSE RISK ANAL

In [None]:
# ===== VIEW INDIVIDUAL CONTRACT GRAPHS =====

def reconnect_neo4j():
    """
    Reconnect to Neo4j if connection is lost
    """
    global neo4j_driver
    
    try:
        # Close existing driver if it exists
        if 'neo4j_driver' in globals():
            try:
                neo4j_driver.close()
            except:
                pass
    except:
        pass
    
    # Re-initialize driver with URI fix
    uri = os.environ["NEO4J_URI"]
    if "neo4j+s://" in uri and "neo4j+ssc://" not in uri:
        uri = uri.replace("neo4j+s://", "neo4j+ssc://")
        print(f"üîÑ Updated URI for Aura compatibility: {uri[:50]}...")
    
    neo4j_driver = GraphDatabase.driver(
        uri,
        auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
    )
    
    # Test connection
    try:
        neo4j_driver.verify_connectivity()
        print("‚úÖ Neo4j Aura reconnected successfully!")
        return True
    except Exception as e:
        print(f"‚ùå Failed to reconnect: {e}")
        return False

def get_contract_cypher_query(contract_id=None, contract_title=None):
    """
    Generate Cypher query to view a single contract graph in Neo4j Browser
    Returns query string you can copy-paste into Neo4j Browser
    """
    if contract_id:
        query = f"""// View Individual Contract Graph by ID
MATCH (c:Contract {{id: "{contract_id}"}})
OPTIONAL MATCH (c)<-[:IS_PARTY_TO]-(o:Organization)
OPTIONAL MATCH (c)-[:HAS_DATE]->(d:ImportantDate)
OPTIONAL MATCH (c)-[:HAS_CLAUSE]->(cl:Clause)
OPTIONAL MATCH (cl)-[:HAS_RISK]->(r:Risk)
OPTIONAL MATCH (cl)-[:HAS_REASON]->(rr:RiskReason)
OPTIONAL MATCH (cl)-[:HAS_OBLIGATION]->(ob:Obligation)
OPTIONAL MATCH (cl)-[:HAS_LIABILITY]->(li:Liability)
OPTIONAL MATCH (cl)-[:HAS_AI_SUMMARY]->(ai:AISummary)
RETURN c, o, d, cl, r, rr, ob, li, ai"""
    elif contract_title:
        query = f"""// View Individual Contract Graph by Title
MATCH (c:Contract {{title: "{contract_title}"}})
OPTIONAL MATCH (c)<-[:IS_PARTY_TO]-(o:Organization)
OPTIONAL MATCH (c)-[:HAS_DATE]->(d:ImportantDate)
OPTIONAL MATCH (c)-[:HAS_CLAUSE]->(cl:Clause)
OPTIONAL MATCH (cl)-[:HAS_RISK]->(r:Risk)
OPTIONAL MATCH (cl)-[:HAS_REASON]->(rr:RiskReason)
OPTIONAL MATCH (cl)-[:HAS_OBLIGATION]->(ob:Obligation)
OPTIONAL MATCH (cl)-[:HAS_LIABILITY]->(li:Liability)
OPTIONAL MATCH (cl)-[:HAS_AI_SUMMARY]->(ai:AISummary)
RETURN c, o, d, cl, r, rr, ob, li, ai"""
    else:
        query = """// View ALL Contracts Together (Explore View)
MATCH (c:Contract)
OPTIONAL MATCH (c)<-[:IS_PARTY_TO]-(o:Organization)
OPTIONAL MATCH (c)-[:HAS_DATE]->(d:ImportantDate)
OPTIONAL MATCH (c)-[:HAS_CLAUSE]->(cl:Clause)
OPTIONAL MATCH (cl)-[:HAS_RISK]->(r:Risk)
RETURN c, o, d, cl, r
LIMIT 100"""
    return query

def view_individual_contract_graph(contract_id=None, contract_title=None):
    """
    View a single contract's graph structure
    Shows what nodes and relationships belong to this contract only
    """
    if not contract_id and not contract_title:
        print("‚ùå Please provide either contract_id or contract_title")
        return
    
    print("\n" + "="*80)
    print("üìä INDIVIDUAL CONTRACT GRAPH VIEW")
    print("="*80)
    
    # Try to reconnect if connection fails
    try:
        with neo4j_driver.session() as s:
            if contract_id:
            result = s.run("""
                MATCH (c:Contract {id: $id})
                OPTIONAL MATCH (c)<-[:IS_PARTY_TO]-(o:Organization)
                OPTIONAL MATCH (c)-[:HAS_DATE]->(d:ImportantDate)
                OPTIONAL MATCH (c)-[:HAS_CLAUSE]->(cl:Clause)
                OPTIONAL MATCH (cl)-[:HAS_RISK]->(r:Risk)
                OPTIONAL MATCH (cl)-[:HAS_REASON]->(rr:RiskReason)
                OPTIONAL MATCH (cl)-[:HAS_OBLIGATION]->(ob:Obligation)
                OPTIONAL MATCH (cl)-[:HAS_LIABILITY]->(li:Liability)
                OPTIONAL MATCH (cl)-[:HAS_AI_SUMMARY]->(ai:AISummary)
                RETURN c, 
                       collect(DISTINCT o) as parties,
                       collect(DISTINCT d) as dates,
                       collect(DISTINCT cl) as clauses,
                       collect(DISTINCT r) as risks
            """, id=contract_id)
        else:
            result = s.run("""
                MATCH (c:Contract {title: $title})
                OPTIONAL MATCH (c)<-[:IS_PARTY_TO]-(o:Organization)
                OPTIONAL MATCH (c)-[:HAS_DATE]->(d:ImportantDate)
                OPTIONAL MATCH (c)-[:HAS_CLAUSE]->(cl:Clause)
                OPTIONAL MATCH (cl)-[:HAS_RISK]->(r:Risk)
                OPTIONAL MATCH (cl)-[:HAS_REASON]->(rr:RiskReason)
                OPTIONAL MATCH (cl)-[:HAS_OBLIGATION]->(ob:Obligation)
                OPTIONAL MATCH (cl)-[:HAS_LIABILITY]->(li:Liability)
                OPTIONAL MATCH (cl)-[:HAS_AI_SUMMARY]->(ai:AISummary)
                RETURN c, 
                       collect(DISTINCT o) as parties,
                       collect(DISTINCT d) as dates,
                       collect(DISTINCT cl) as clauses,
                       collect(DISTINCT r) as risks
            """, title=contract_title)
        
        record = result.single()
        if record:
            c = record["c"]
            parties = [p for p in record["parties"] if p]
            dates = [d for d in record["dates"] if d]
            clauses = [cl for cl in record["clauses"] if cl]
            risks = [r for r in record["risks"] if r]
            
            print(f"\nüìÑ Contract: {c.get('title', 'Unknown')}")
            print(f"   ID: {c.get('id', 'N/A')[:30]}...")
            print(f"\nüìä Graph Statistics:")
            print(f"   Parties: {len(parties)}")
            print(f"   Dates: {len(dates)}")
            print(f"   Clauses: {len(clauses)}")
            print(f"   Risk Levels: {len(risks)}")
            
            print(f"\nüîó Copy this query to Neo4j Browser to visualize:")
            print("-" * 80)
            query = get_contract_cypher_query(contract_id=contract_id, contract_title=contract_title)
            print(query)
            print("-" * 80)
            
            return {
                "contract": c,
                "parties": parties,
                "dates": dates,
                "clauses": clauses,
                "risks": risks,
                "cypher_query": query
            }
        else:
            print("‚ùå Contract not found")
            return None

def list_contracts_for_viewing():
    """
    List all contracts with their IDs and titles for easy selection
    """
    print("\n" + "="*80)
    print("üìã AVAILABLE CONTRACTS FOR VIEWING")
    print("="*80)
    
    # Try to reconnect if connection fails
    try:
        contracts = retrieve_all_contracts()
    except Exception as e:
        print(f"‚ö†Ô∏è Connection error: {e}")
        print("üîÑ Attempting to reconnect...")
        if reconnect_neo4j():
            contracts = retrieve_all_contracts()
        else:
            print("‚ùå Could not connect to Neo4j. Please check your connection.")
            return []
    
    if not contracts:
        print("No contracts found in database")
        return []
    
    print(f"\nFound {len(contracts)} contract(s):\n")
    for i, contract in enumerate(contracts, 1):
        print(f"[{i}] {contract['title']}")
        print(f"    ID: {contract['id']}")
        print(f"    File: {contract['file_name']}\n")
    
    return contracts

print("‚úÖ Individual contract graph viewing functions loaded!")
print("\nUsage:")
print("  1. reconnect_neo4j() - Reconnect if you get connection errors")
print("  2. list_contracts_for_viewing() - See all contracts")
print("  3. view_individual_contract_graph(contract_id='...') - View one contract")
print("  4. get_contract_cypher_query(contract_id='...') - Get Cypher query for Browser")
print("\nüí° If you get connection errors, run: reconnect_neo4j()")


‚úÖ Individual contract graph viewing functions loaded!

Usage:
  1. list_contracts_for_viewing() - See all contracts
  2. view_individual_contract_graph(contract_id='...') - View one contract
  3. get_contract_cypher_query(contract_id='...') - Get Cypher query for Browser


In [None]:
# ===== QUICK FIX: Reconnect Neo4j =====
# Run this cell if you get connection errors

# Close existing driver
try:
    neo4j_driver.close()
except:
    pass

# Re-initialize with URI fix
uri = os.environ["NEO4J_URI"]
if "neo4j+s://" in uri and "neo4j+ssc://" not in uri:
    uri = uri.replace("neo4j+s://", "neo4j+ssc://")
    print(f"üîÑ Updated URI: {uri[:50]}...")

neo4j_driver = GraphDatabase.driver(
    uri,
    auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
)

# Test connection
try:
    neo4j_driver.verify_connectivity()
    print("‚úÖ Neo4j reconnected! Now try your query again.")
except Exception as e:
    print(f"‚ùå Connection failed: {e}")
    print("üí° Check your .env file and Neo4j Aura status")


In [20]:
# ===== VECTOR SIMILARITY SEARCH EXAMPLE =====

print("\n\n" + "#"*80)
print("# VECTOR SIMILARITY SEARCH")
print("#"*80)

# Example searches
search_queries = [
    "payment terms and conditions",
    "liability and indemnification",
    "termination clause"
]

for query in search_queries:
    search_similar_clauses(query, top_k=3)
    print("\n" + "-"*80)



################################################################################
# VECTOR SIMILARITY SEARCH
################################################################################

üîç Searching for clauses similar to: 'payment terms and conditions'
‚ùå Could not generate query embedding

--------------------------------------------------------------------------------

üîç Searching for clauses similar to: 'liability and indemnification'
‚ùå Could not generate query embedding

--------------------------------------------------------------------------------

üîç Searching for clauses similar to: 'termination clause'
‚ùå Could not generate query embedding

--------------------------------------------------------------------------------


In [21]:
# ===== CLEANUP =====
import atexit
atexit.register(lambda: neo4j_driver.close())
print("üîí Neo4j connection will close on exit")


üîí Neo4j connection will close on exit


In [23]:
# List all contracts to get IDs
list_contracts_for_viewing()


üìã AVAILABLE CONTRACTS FOR VIEWING

üìö Retrieving all contracts...
‚úÖ Found 3 contracts

Found 3 contract(s):

[1] Legal Services Agreement
    ID: 1de79b4ffc94b68989ab72c79d929fc5169e8c48b1ea78fc7ee7c8bcd0b1c7d2
    File: Legal-Services-Agreement.pdf

[2] Employment Agreement
    ID: 36b93039beb3de2504d3c21533f28516466e6f6ab2529f4bd31d275b66d27dd3
    File: Employment_contract.pdf

[3] Service Agreement
    ID: cc92ad20222720eef7875a359d5d0c52166963bd678f7263c92a8cbd60726f80
    File: sample_contract.pdf



[{'id': '1de79b4ffc94b68989ab72c79d929fc5169e8c48b1ea78fc7ee7c8bcd0b1c7d2',
  'title': 'Legal Services Agreement',
  'file_name': 'Legal-Services-Agreement.pdf'},
 {'id': '36b93039beb3de2504d3c21533f28516466e6f6ab2529f4bd31d275b66d27dd3',
  'title': 'Employment Agreement',
  'file_name': 'Employment_contract.pdf'},
 {'id': 'cc92ad20222720eef7875a359d5d0c52166963bd678f7263c92a8cbd60726f80',
  'title': 'Service Agreement',
  'file_name': 'sample_contract.pdf'}]

In [27]:
# View individual contract and get the query
# Step 1: Get all contracts
contracts = list_contracts_for_viewing()

# Step 2: View the first contract individually
if contracts:
    print("\n" + "="*80)
    print("VIEWING FIRST CONTRACT")
    print("="*80)
    
    # Use the first contract's ID
    result = view_individual_contract_graph(contract_id=contracts[0]['id'])
    
    # The function will print a Cypher query - copy that to Neo4j Browser!


üìã AVAILABLE CONTRACTS FOR VIEWING

üìö Retrieving all contracts...


[#EA9A]  _: <CONNECTION> error: Failed to write data to connection IPv4Address(('si-8c5998b9-4a7b.production-orch-0703.neo4j.io', 7687)) (ResolvedIPv4Address(('34.124.169.171', 7687))): ConnectionResetError(10054, 'An existing connection was forcibly closed by the remote host', None, 10054, None)


SessionExpired: Failed to write data to connection IPv4Address(('si-8c5998b9-4a7b.production-orch-0703.neo4j.io', 7687)) (ResolvedIPv4Address(('34.124.169.171', 7687)))