In [None]:
import json
import hashlib
import re
import zlib
import base64
import requests
import operator
import os
import logging
import time
import pandas as pd
import sys

# sqlite-vec expects the standard library sqlite3 module.
# Older versions of this notebook replaced sqlite3 with sqlean via sys.modules; undo that if present.
if sys.modules.get("sqlite3") is sys.modules.get("sqlean"):
    del sys.modules["sqlite3"]
import sqlite3

from datetime import datetime
from enum import Enum
from typing import Annotated, List, TypedDict, Optional, Dict, Any, Tuple
from pydantic import BaseModel, Field, model_validator
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.documents import Document
from langchain_community.vectorstores import SQLiteVec
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from langgraph.graph import StateGraph, START, END
from prompts import DECOMPOSER_SYSTEM, GENERATOR_SYSTEM, CRITIC_SYSTEM, SUMMARIZER_SYSTEM, REFLECTOR_SYSTEM


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [48]:
class NodeNames(str, Enum):
    """Enum for node names to avoid string literals."""
    RETRIEVE = "retrieve"
    DECOMPOSE = "decompose"
    GENERATE = "generate"
    SYNTAX_CHECK = "syntax_check"
    CRITIC = "critic"
    SUMMARIZE = "summarize"
    REFLECT = "reflect"
    PLAN_AUDIT = "plan_audit"


class Scores(int, Enum):
    AVERAGE_SCORE_THRESHOLD = 8.5
    REQUIREMENT_COVERAGE_THRESHOLD = 9.0


class CritiqueError(BaseModel):
    """Model for a single critique error."""
    type: str = Field(description="Type of error")
    description: str = Field(description="Detailed description of the error")


class CritiqueResponse(BaseModel):
    """Structured output from the CRITIC node."""
    requirement_coverage: float = Field(ge=0, le=10, description="Does it capture all classes and relationships from the text?")
    design_best_practices: float = Field(ge=0, le=10, description="Are relationships correct? (e.g., composition vs association)")
    structural_integrity: float = Field(ge=0, le=10, description="Are there redundant classes or missing attributes?")
    is_valid: bool = Field(description=f"True only if total average score is > {Scores.AVERAGE_SCORE_THRESHOLD} AND 'requirement_coverage' is >= {Scores.REQUIREMENT_COVERAGE_THRESHOLD}")
    errors: List[CritiqueError] = Field(default_factory=list, description="List of errors found")
    warnings: List[str] = Field(default_factory=list, description="List of warnings")
    missing_concepts: List[str] = Field(default_factory=list, description="Concepts from requirements not in diagram")
    reasoning: str = Field(description="Brief explanation for the scores provided.")

    @property
    def weighted_score(self) -> float:
        return (self.requirement_coverage * 0.5) + \
               (self.design_best_practices * 0.3) + \
               (self.structural_integrity * 0.2)
    
    @model_validator(mode='after')
    def compute_validity(self) -> 'CritiqueResponse':
        self.is_valid = (self.weighted_score > Scores.AVERAGE_SCORE_THRESHOLD) and (self.requirement_coverage >= Scores.REQUIREMENT_COVERAGE_THRESHOLD)
        return self


class SummaryResponse(BaseModel):
    """Structured output from the SUMMARIZER node."""
    is_complete: bool = Field(description="Whether all issues are resolved")
    fixed: List[str] = Field(default_factory=list, description="Issues that were fixed")
    unresolved: List[str] = Field(default_factory=list, description="Issues still present")
    message: str = Field(description="Brief status summary")


class PlanAudit(BaseModel):
    is_valid: bool = Field(description="True if the plan is logically sound and covers all requirements.")
    critique: List[str] = Field(default_factory=list, description="List of specific logical flaws (e.g., 'Missing relationship between User and Account').")
    suggestions: List[str] = Field(default_factory=list, description="Actionable steps to fix the plan.")


class SystemConfig(BaseModel):
    """System configuration for UML generation."""
    lmstudio_base_url: str = Field(default="http://localhost:1234/v1", description="LMStudio API endpoint")
    model_name: str = Field(default="mistralai/devstral-small-2-2512", description="Model to use")
    embedder_model: str = Field(default="BAAI/bge-large-en-v1.5", description="Embedder model for semantic search")
    db_path: str = Field(default="./../data/uml_knowledge.db", description="Path to SQLite database")
    shots_json_path: str = Field(default="./../data/complete_shots.json", description="Path to few-shot examples")
    plantuml_host: str = Field(default="http://localhost:8080", description="PlantUML server host")
    max_iterations: int = Field(default=6, ge=1, description="Maximum workflow iterations")
    max_tokens_decompose: int = Field(default=1024, description="Max tokens for decompose step")
    max_tokens_generate: int = Field(default=2048, description="Max tokens for generate step")
    max_tokens_critique: int = Field(default=2048, description="Max tokens for critique step")
    max_tokens_summarize: int = Field(default=1024, description="Max tokens for summarize step")
    max_tokens_reflect: int = Field(default=2048, description="Max tokens for reflect step")
    max_tokens_compare: int = Field(default=1024, description="Max tokens for compare step")
    temperature: float = Field(default=0.15, ge=0.0, le=2.0, description="Base temperature for LLM")
    num_few_shots: int = Field(default=3, ge=0, description="Number of few-shot examples")
    request_timeout: int = Field(default=5, ge=1, description="Timeout for PlantUML server requests")
    llm_timeout: int = Field(default=120, ge=1, description="Timeout for LLM operations")


class PlantUMLResult(BaseModel):
    """Result from PlantUML validation."""
    is_valid: bool = Field(description="Whether the PlantUML syntax is valid")
    error: Optional[str] = Field(default=None, description="Error message if validation failed")
    url: Optional[str] = Field(default=None, description="URL to view the diagram")
    svg_url: Optional[str] = Field(default=None, description="URL to view the diagram as SVG")


class AgentState(TypedDict):
    """
    Shared state for the LangGraph workflow.
    """
    requirements: str
    plan: Optional[str]
    examples: List[Dict[str, str]]
    current_diagram: Optional[str]
    best_diagram: Optional[str]  
    history: Annotated[List[Dict[str, Any]], operator.add]
    summary: Optional[str]
    syntax_valid: bool
    logic_valid: bool
    error_message: Optional[str]
    plan_valid: bool 
    best_score: float
    best_code: str
    current_validation: Optional[CritiqueResponse]
    iterations: int

In [49]:
def create_llm(config: Optional[SystemConfig] = None) -> ChatOpenAI:
    """
    Create a ChatOpenAI instance configured for LMStudio.
    
    Args:
        config: Optional system configuration
        
    Returns:
        Configured ChatOpenAI instance
    """
    cfg = config or SystemConfig()
    logger.info(f"Connecting to LMStudio at {cfg.lmstudio_base_url}")
    logger.info(f"Using model: {cfg.model_name} (temp={cfg.temperature})")
    
    return ChatOpenAI(
        base_url=cfg.lmstudio_base_url,
        api_key="lm-studio",  
        model=cfg.model_name,
        temperature=cfg.temperature,
        timeout=cfg.llm_timeout 
    )

In [50]:
class PlantUMLTool:
    """
    Tool for validating and rendering PlantUML diagrams.
    
    This class interfaces with a PlantUML server to check syntax
    and generate diagram URLs.
    """
    
    def __init__(self, host: str = "http://localhost:8080"):
        """
        Initialize PlantUML tool.
        
        Args:
            host: PlantUML server host URL
        """
        self.host = host
        logger.info(f"PlantUML tool initialized with host: {host}")

    def extract_plantuml(self, text: str) -> str:
        """
        Extract PlantUML code from markdown blocks or raw text.
        
        Args:
            text: Text containing PlantUML code
            
        Returns:
            Extracted PlantUML code or empty string
        """
        if not text:
            return ""
        
        # Try to extract from ```plantuml ... ```
        fence_match = re.search(r"```\s*plantuml\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
        if fence_match:
            return fence_match.group(1).strip()
        
        # Try to extract from @startuml ... @enduml
        tag_match = re.search(r"@startuml.*?@enduml", text, re.DOTALL | re.IGNORECASE)
        if tag_match:
            return tag_match.group(0).strip()
        
        return text.strip()

    def _encode_plantuml(self, plantuml_code: str) -> str:
        """
        Encode PlantUML code for URL.
        
        Args:
            plantuml_code: Raw PlantUML code
            
        Returns:
            URL-safe encoded string
        """
        code = plantuml_code.strip()
        
        if not code.startswith("@startuml"): 
            code = f"@startuml\n{code}"
        if not code.endswith("@enduml"): 
            code = f"{code}\n@enduml"
        
        compressed = zlib.compress(code.encode('utf-8'))[2:-4]
        encoded = base64.b64encode(compressed).translate(
            bytes.maketrans(
                b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
                b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_"
            )
        ).decode('utf-8')
        
        return encoded

    def get_diagram_url(self, plantuml_code: str, format: str = "png") -> str:
        """
        Generate a viewable URL for the PlantUML diagram.
        
        Args:
            plantuml_code: PlantUML diagram code
            format: Output format (png, svg, etc.)
            
        Returns:
            URL to view the diagram
        """
        diagram_code = self.extract_plantuml(plantuml_code)
        encoded = self._encode_plantuml(diagram_code)
        return f"{self.host}/{format}/{encoded}"
        
    def check_syntax(self, plantuml_code: str, timeout: int = 5) -> PlantUMLResult:
        """
        Validate PlantUML syntax with detailed error extraction.
        
        Args:
            plantuml_code: PlantUML code to validate
            timeout: Request timeout in seconds
            
        Returns:
            PlantUMLResult with validation status and detailed error if applicable.
        """
        logger.info("Validating PlantUML syntax")
        
        try:
            diagram_code = self.extract_plantuml(plantuml_code)
            encoded = self._encode_plantuml(diagram_code)
            
            url_png = f"{self.host}/png/{encoded}"
            response = requests.get(url_png, timeout=timeout)
            
            if response.status_code == 200 and response.content[:4] == b'\x89PNG':
                logger.info("Syntax validation passed (PNG rendered)")
                return PlantUMLResult(
                    is_valid=True,
                    url=url_png,
                    svg_url=f"{self.host}/svg/{encoded}"
                )
            
            logger.warning("PNG rendering failed. Fetching detailed syntax error...")
            url_txt = f"{self.host}/txt/{encoded}"
            error_response = requests.get(url_txt, timeout=timeout)
            
            detailed_error = error_response.text.strip() if error_response.status_code == 200 else "Unknown server error"
            
            error_msg = f"PlantUML Syntax Error:\n{detailed_error[:1000]}"
            logger.error(f"Syntax error detected: {error_msg}")
            
            return PlantUMLResult(
                is_valid=False,
                error=error_msg
            )
            
        except requests.exceptions.RequestException as e:
            error_msg = f"PlantUML Server Connection Error: {str(e)}"
            logger.error(error_msg)
            return PlantUMLResult(is_valid=False, error=error_msg)
            
        except Exception as e:
            error_msg = f"Unexpected error during syntax check: {str(e)}"
            logger.error(error_msg)
            return PlantUMLResult(is_valid=False, error=error_msg)


In [51]:
class MemoryManager:
    """
    Manages long-term memory for UML diagram generation using LangChain's SQLiteVec.
    
    Supports semantic search to find similar past solutions.
    """
    
    def __init__(
        self,
        embedder: SentenceTransformer,
        db_path: str = "./../data/uml_knowledge.db",
        embedding_dims: int = 1024
    ):
        """
        Initialize memory manager with LangChain SQLiteVec.
        
        Args:
            embedder: SentenceTransformer model for semantic search
            db_path: Path to the SQLite database file
            embedding_dims: Dimensions of the embeddings 
        """
        self.embedder = embedder
        self.db_path = db_path
        self.embedding_dims = embedding_dims
        

        self.embedding_function = HuggingFaceEmbeddings(
            model_name=embedder.model_name if hasattr(embedder, 'model_name') else "BAAI/bge-large-en-v1.5",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        
        # Create directory and connection
        os.makedirs(os.path.dirname(self.db_path) if os.path.dirname(self.db_path) else ".", exist_ok=True)
        
        # Create connection using SQLiteVec's method
        self.connection = SQLiteVec.create_connection(db_file=self.db_path)
        
        # Initialize vector store with connection
        self.vector_store = SQLiteVec(
            table="uml_memories",
            connection=self.connection,
            embedding=self.embedding_function
        )
        
        logger.info(f"MemoryManager initialized with LangChain SQLiteVec at {db_path} (dims={embedding_dims})")

    def save_diagram(
        self,
        requirements: str,
        diagram: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Save a validated diagram to SQLite long-term memory.
                
        Args:
            requirements: Original requirements text
            diagram: PlantUML diagram code
            metadata: Optional metadata
            
        Returns:
            ID of the stored record
        """
        timestamp = datetime.now().isoformat()
        
        full_metadata = metadata or {}
        full_metadata.update({
            "diagram": diagram,
            "timestamp": timestamp
        })
        

        doc = Document(
            page_content=requirements,
            metadata=full_metadata
        )
        
        ids = self.vector_store.add_documents([doc])
        
        logger.info("Diagram saved to SQLite memory using LangChain SQLiteVec")
        return ids[0] if ids else 0
    
    def retrieve_similar_diagrams(
        self,
        requirements: str,
        limit: int = 2
    ) -> List[Dict[str, Any]]:
        """
        Retrieve similar diagrams from SQLite memory using vector search.
        
        Args:
            requirements: Requirements text to search for
            limit: Maximum number of results
            
        Returns:
            List of similar diagram records
        """
        try:
            results = self.vector_store.similarity_search(requirements, k=limit)
            
            diagrams = []
            for doc in results:
                diagrams.append({
                    "requirements": doc.page_content,
                    "diagram": doc.metadata.get("diagram", ""),
                    "timestamp": doc.metadata.get("timestamp", ""),
                    "metadata": {k: v for k, v in doc.metadata.items() 
                                if k not in ["diagram", "timestamp"]}
                })
                
            logger.info(f"Retrieved {len(diagrams)} similar diagrams from SQLite")
            return diagrams
            
        except Exception as e:
            logger.warning(f"Memory retrieval failed: {e}")
            return []
    
    def clear_memory(self) -> None:
        try:
            # Close existing connection
            if hasattr(self, 'connection'):
                self.connection.close()
            
            # Remove database file
            if os.path.exists(self.db_path):
                os.remove(self.db_path)
            
            # Recreate connection and vector store
            self.connection = SQLiteVec.create_connection(db_file=self.db_path)
            self.vector_store = SQLiteVec(
                table="uml_memories",
                connection=self.connection,
                embedding=self.embedding_function
            )

            logger.info("Memory cleared and reinitialized")
        except Exception as e:
            logger.error(f"Failed to clear memory: {e}")


In [52]:
def seed_memory_from_shots(
    memory_manager: MemoryManager,
    shots_json_path: str = "./../data/complete_shots.json",
    force_reseed: bool = False
) -> int:
    """
    Seed the memory database with few-shot examples from JSON file.
    Skips seeding if database already contains data (unless force_reseed=True).
    
    Args:
        memory_manager: MemoryManager instance to seed
        shots_json_path: Path to the complete_shots.json file
        force_reseed: If True, clears existing data and reseeds
        
    Returns:
        Number of shots seeded (0 if skipped)
    """
    logger.info("="*60)
    logger.info("CHECKING MEMORY SEEDING STATUS")
    logger.info("="*60)
    
    # Check if database already has data
    try:
        existing_docs = memory_manager.vector_store.similarity_search("test", k=1)
        if existing_docs and not force_reseed:
            logger.info(f"Database already contains data ({len(existing_docs)} docs found)")
            logger.info("Skipping seeding operation. Set force_reseed=True to override.")
            return 0
    except Exception as e:
        logger.info(f"Database appears empty or uninitialized: {e}")
    
    if force_reseed:
        logger.warning("Force reseed enabled - clearing existing memory")
        memory_manager.clear_memory()
    

    if not os.path.exists(shots_json_path):
        logger.error(f"Shots file not found at {shots_json_path}")
        return 0
    
    logger.info(f"Loading shots from {shots_json_path}")
    with open(shots_json_path, 'r', encoding='utf-8') as f:
        shots = json.load(f)
    
    logger.info(f"Found {len(shots)} shots to seed")
    
    # Prepare documents
    documents = []
    for shot in shots:
        requirements = shot["requirements"]
        diagram = shot["solution_plantuml"]
        
        metadata = {
            "diagram": diagram,
            "timestamp": datetime.now().isoformat(),
            "plan": shot.get("subgoal_decomposition"),
            "reasoning": shot.get("chain_of_thought"),
            "is_static": True,
            "title": shot.get("title", "Untitled")
        }
        
        logger.info(f"  Processing: {metadata['title']}")
        
        doc = Document(
            page_content=requirements,
            metadata=metadata
        )
        documents.append(doc)
    
    if documents:
        memory_manager.vector_store.add_documents(documents)
        logger.info("="*60)
        logger.info(f"✓ Successfully seeded {len(documents)} shots to memory")
        logger.info("="*60)
        return len(documents)
    
    return 0


In [None]:
class UMLNodes:
    """
    Collection of agent nodes for the UML generation workflow.
    
    Each method represents a node in the LangGraph workflow and
    follows the pattern of taking AgentState and returning a dict
    with state updates.
    """
    
    def __init__(
        self,
        llm: ChatOpenAI,
        plantuml_tool: PlantUMLTool,
        memory_manager: Optional['MemoryManager'] = None,
        config: Optional[SystemConfig] = None
    ):
        """
        Initialize UML nodes with required dependencies.
        
        Args:
            llm: LangChain ChatOpenAI instance
            plantuml_tool: Tool for PlantUML validation
            memory_manager: long-term memory manager
            config: Optional system configuration
        """
        self.llm = llm
        self.plantuml_tool = plantuml_tool
        self.memory_manager = memory_manager
        self.config = config or SystemConfig()
        logger.info("UMLNodes initialized")

    def _safe_invoke(self, runnable: Any, input_data: Any, **kwargs) -> Any:
        """
        Invoke a runnable (LLM or chain) with retry logic.
        """
        max_retries = 3
        last_exception = None
        
        for attempt in range(max_retries):
            try:
                return runnable.invoke(input_data, **kwargs)
            except Exception as e:
                last_exception = e
                logger.warning(f"LLM call failed (attempt {attempt+1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 * (attempt + 1))
        
        logger.error(f"Max retries reached for LLM call: {last_exception}")
        raise last_exception

    def retrieve(self, state: AgentState) -> Dict[str, Any]:
        """
        Retrieve relevant few-shot examples based on requirements.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'examples' key containing formatted shots
        """
        logger.info(f"--- NODE: {NodeNames.RETRIEVE.upper()} ---")
        
        try:
            memories = self.memory_manager.retrieve_similar_diagrams(
                state["requirements"],
                limit=self.config.num_few_shots
            )
            
            formatted_shots = []
            for mem in memories:
                formatted_shots.append(
                    HumanMessage(content=f"Requirements:\n{mem['requirements']}")
                )
                
                meta = mem.get("metadata", {})
                plan = meta.get("plan", "No plan available.")
                reasoning = meta.get("reasoning", "No reasoning available.")
                
                assistant_content = (
                    f"1. DESIGN PLAN:\n{plan}\n\n"
                    f"2. DESIGN REASONING:\n{reasoning}\n\n"
                    f"3. PLANTUML DIAGRAM:\n```plantuml\n{mem['diagram']}\n```"
                )
                
                formatted_shots.append(
                    AIMessage(content=assistant_content)
                )
            
            logger.info(f"Retrieved {len(memories)} relevant examples from unified memory")
            return {"examples": formatted_shots}
        except Exception as e:
            logger.error(f"Retrieval failed: {e}")
            return {"examples": []}

    def decompose(self, state: AgentState) -> Dict[str, Any]:
        """
        Decompose requirements into structural building blocks.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'plan' key containing decomposition
        """
        logger.info(f"--- NODE: {NodeNames.DECOMPOSE.upper()} ---")

        feedback = state.get("audit_feedback", [])
        feedback_str = "\n".join([f"- {f}" for f in feedback]) if feedback else "None"

        system_prompt = DECOMPOSER_SYSTEM
        if feedback:
            system_prompt += f"\n\nIMPORTANT: Your previous plan was rejected. Fix these issues:\n{feedback_str}"
        
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"REQUIREMENTS:\n{state['requirements']}")
        ]
        
        try:
            response = self._safe_invoke(
                self.llm,
                messages,
                max_tokens=self.config.max_tokens_decompose
            )
            logger.info("Decomposition completed")
            return {"plan": response.content}
            
        except Exception as e:
            logger.error(f"Decomposition failed: {e}")
            return {"plan": f"Error: {str(e)}"}

    def logic_auditor(self, state: AgentState) -> Dict[str, Any]:
        """
        Audits the structural plan for logical consistency and requirement coverage.
        """
        logger.info(f"--- NODE: {NodeNames.PLAN_AUDIT.upper()} ---")
        
        plan = state.get("plan")
        requirements = state.get("requirements")
        
        prompt = f"""
        You are a Senior Software Architect auditing a UML Class Diagram plan.
        
        REQUIREMENTS:
        {requirements}
        
        PROPOSED PLAN (JSON):
        {plan}
        
        YOUR TASK:
        1. Check for 'Island Classes' (classes with no relationships).
        2. Ensure all entities mentioned in the requirements exist in the plan.
        3. Check for relationship directionality (e.g., should 'User' own 'Order'?).
        4. Verify that attributes have appropriate types.
        
        If the plan is flawed, be specific about what is missing.
        """
        
        audit_result = self.llm.with_structured_output(PlanAudit).invoke([
            SystemMessage(content="You are a Senior Software Architect auditing a UML Class Diagram plan."),
            HumanMessage(content=prompt)
        ])
        
        return {
            "plan_valid": audit_result.is_valid,
            "audit_feedback": audit_result.critique + audit_result.suggestions,
            "iterations": state["iterations"] + (0 if audit_result.is_valid else 1)
        }

    def generate(self, state: AgentState) -> Dict[str, Any]:
        """
        Generate PlantUML diagram using chain-of-thought reasoning.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'current_diagram' and 'iterations' updates
        """
        logger.info(f"--- NODE: {NodeNames.GENERATE.upper()} ---")
        
        messages = [SystemMessage(content=GENERATOR_SYSTEM)]
        
        # Add few-shot examples if available
        if state.get("examples"):
            messages.extend(state["examples"])
            logger.debug(f"Added {len(state['examples'])} example messages")
            
        user_content = f"""
        # ORIGINAL REQUIREMENTS
        {state['requirements']}

        # DESIGN PLAN
        {state['plan']}

        # TASK
        Follow the examples above exactly. Output your response in three parts:
        1. DESIGN PLAN: (Briefly refine the plan for implementation)
        2. DESIGN REASONING: (Explain your choice of relationships and cardinality)
        3. PLANTUML DIAGRAM: (The code block)
        """
        
        messages.append(HumanMessage(content=user_content))
        
        try:
            response = self._safe_invoke(
                self.llm,
                messages,
                max_tokens=self.config.max_tokens_generate
            )
            diagram = self.plantuml_tool.extract_plantuml(response.content)
            
            logger.info(f"Generation completed (iteration {state['iterations'] + 1})")
            return {
                "current_diagram": diagram,
                "iterations": state["iterations"] + 1
            }
            
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {
                "current_diagram": f"Error: {str(e)}",
                "iterations": state["iterations"] + 1
            }

    def syntax_check(self, state: AgentState) -> Dict[str, Any]:
        """
        Validate PlantUML syntax through server.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'syntax_valid' and optional 'error_message'
        """
        logger.info(f"--- NODE: {NodeNames.SYNTAX_CHECK.upper()} ---")
        
        try:
            result = self.plantuml_tool.check_syntax(
                state["current_diagram"],
                timeout=self.config.request_timeout
            )
            
            if result.is_valid:
                logger.info(f"Syntax valid. View at: {result.url}")
            else:
                logger.warning(f"Syntax error: {result.error}")
            
            return {
                "syntax_valid": result.is_valid,
                "error_message": result.error if not result.is_valid else None,
                "iterations": state["iterations"] + (0 if result.is_valid else 1)
            }
            
        except Exception as e:
            logger.error(f"Syntax check failed: {e}")
            return {
                "syntax_valid": False,
                "error_message": f"Syntax check error: {str(e)}"
            }

    def _validate_diagram(self, requirements: str, diagram: str) -> CritiqueResponse:
        """
        Helper method to validate a diagram and return the structured response.
        Used by both critic node and reflect node (for rollback decision).
        """
        plantuml_only = self.plantuml_tool.extract_plantuml(diagram)
        user_msg = f"""
        # REQUIREMENTS
        {requirements}

        # DIAGRAM
        {plantuml_only}

        Audit the diagram thoroughly and provide the scoring report.
        """
        
        messages = [
            SystemMessage(content=CRITIC_SYSTEM),
            HumanMessage(content=user_msg)
        ]
        
        structured_llm = self.llm.with_structured_output(CritiqueResponse)
        return self._safe_invoke(structured_llm, messages)

    def critic(self, state: AgentState) -> Dict[str, Any]:
        """
        Perform logical validation of the UML diagram.
                
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'logic_valid' and 'history' updates
        """
        logger.info(f"--- NODE: {NodeNames.CRITIC.upper()} ---")
        
        try:
            critique_response = self._validate_diagram(state['requirements'], state["current_diagram"])
            
            weighted = critique_response.weighted_score
            
            critique = {
                "is_valid": critique_response.is_valid,
                "requirement_coverage": critique_response.requirement_coverage,
                "design_best_practices": critique_response.design_best_practices,
                "structural_integrity": critique_response.structural_integrity,
                "weighted_score": weighted,
                "errors": [{"type": err.type, "description": err.description} 
                          for err in critique_response.errors],
                "warnings": critique_response.warnings,
                "missing_concepts": critique_response.missing_concepts,
                "reasoning": critique_response.reasoning
            }
            
            is_valid = critique_response.is_valid
            logger.info(f"Logic validation: {'PASSED' if is_valid else 'FAILED'} (Weighted Score: {weighted:.2f})")
            logger.info(f"  Requirements Coverage: {critique_response.requirement_coverage:.2f}/10")
            logger.info(f"  Design Best Practices: {critique_response.design_best_practices:.2f}/10")
            logger.info(f"  Structural Integrity: {critique_response.structural_integrity:.2f}/10")
            
            if not is_valid and critique_response.errors:
                logger.info(f"Found {len(critique_response.errors)} errors")
            
            updates = {
                "logic_valid": is_valid,
                "history": [critique]  + state.get("history", []),
                "current_validation": critique_response
            }
            
            if is_valid and not state.get("best_diagram"):
                logger.info(f"Storing first valid diagram as best")
                updates["best_diagram"] = state["current_diagram"]
            
            return updates
            
        except Exception as e:
            logger.error(f"Critic failed: {e}")
            return {
                "logic_valid": False,
                "history": [{
                    "is_valid": False,
                    "requirement_coverage": 0.0,
                    "design_best_practices": 0.0,
                    "structural_integrity": 0.0,
                    "weighted_score": 0.0,
                    "errors": [{"type": "system", "description": str(e)}],
                    "warnings": [],
                    "missing_concepts": [],
                    "reasoning": "System error during critique."
                }]
            }

    def summarize_memory(self, state: AgentState) -> Dict[str, Any]:
        """
        Summarize progress by comparing current and previous critiques.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'summary' key containing JSON string
        """
        logger.info(f"--- NODE: {NodeNames.SUMMARIZE.upper()} ---")
        
        if not state.get("history"):
            logger.info("No history to summarize")
            return {"summary": json.dumps({"is_complete": False, "message": "No history"})}
        
        current_critique = state["history"][-1]
        
        # Only look at the last 2 previous critiques to save tokens
        # sending the full history causes context overflow in later iterations
        previous_critiques = state["history"][-3:-1] if len(state["history"]) > 1 else []
        
        user_prompt = f"""
        CURRENT CRITIQUE (Issues in the latest diagram):
        {json.dumps(current_critique)}
        
        PREVIOUS CRITIQUES (Recent history):
        {json.dumps(previous_critiques)}
        """
        
        messages = [
            SystemMessage(content=SUMMARIZER_SYSTEM),
            HumanMessage(content=user_prompt)
        ]
        
        try:
            structured_llm = self.llm.with_structured_output(SummaryResponse)
            summary_response: SummaryResponse = self._safe_invoke(structured_llm, messages)
            
            summary = {
                "is_complete": summary_response.is_complete,
                "fixed": summary_response.fixed,
                "unresolved": summary_response.unresolved,
                "message": summary_response.message
            }
            
            logger.info(f"Summary: {summary_response.message}")
            return {"summary": json.dumps(summary)}
            
        except Exception as e:
            logger.error(f"Summarization failed: {e}")
            return {"summary": json.dumps({
                "is_complete": False,
                "fixed": [],
                "unresolved": [],
                "message": f"Error: {str(e)}"
            })}

    def reflect(self, state: AgentState) -> Dict[str, Any]:
        """
        Fix diagram based on memory summary and error history.
        Implements internal retry loop with dynamic temperature to escape local optima.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'current_diagram' and 'iterations' updates
        """
        logger.info(f"--- NODE: {NodeNames.REFLECT.upper()} ---")

        last_critique = state["history"][-1]
        prev_score = last_critique.get("weighted_score", 0)
        old_diagram = state['current_diagram']
        best_score = state.get("best_score", prev_score)
        best_diagram = state.get("best_diagram") or old_diagram
        validation_cache = state.get("validation_cache", {})
        tone_instruction = ""
        summary_json: Dict[str, Any] = {}
        unresolved_issues: List[str] = []
        raw_summary = state.get("summary")
        if raw_summary:
            try:
                summary_json = json.loads(raw_summary) if isinstance(raw_summary, str) else (raw_summary if isinstance(raw_summary, dict) else {})
            except Exception as e:
                logger.warning(f"Failed to parse state['summary'] as JSON; using empty summary. Error: {e}")
                summary_json = {}
        unresolved_raw = summary_json.get("unresolved") or []
        if isinstance(unresolved_raw, list):
            unresolved_issues = [str(x) for x in unresolved_raw]
        else:
            unresolved_issues = [str(unresolved_raw)]
        
        error_items = last_critique.get("errors", []) or []
        error_descriptions: List[str] = []
        if isinstance(error_items, list):
            for err in error_items:
                if isinstance(err, dict):
                    error_descriptions.append(str(err.get("description") or err.get("type") or ""))
                else:
                    error_descriptions.append(str(err))
        
        focus_issues = [x for x in (error_descriptions or unresolved_issues) if x]
        focus_issues = focus_issues[:3]
        
        if prev_score < 7.0:
            tone_instruction = (
                "The design is fundamentally wrong. You MUST:\n"
                f"1. Re-examine these missing concepts: {last_critique.get('missing_concepts')}\n"
                "2. Question your class boundaries\n"
                "3. Verify every relationship direction\n"
            )
        elif prev_score < 8.0:
            tone_instruction = (
                "You're close but need targeted fixes:\n"
                f"1. Focus ONLY on: {focus_issues}\n"
                "2. Don't touch working parts\n"
            )

        
        summary_text = (
            f"Message: {summary_json.get('message', '')}\n"
            f"Focus Issues: {', '.join(focus_issues)}"
        )

        base_user_msg = f"""
        {tone_instruction}
        
        [QUALITY SCORE]: {prev_score}/10
        [ERRORS TO FIX]:
        {summary_text}
        
        [MISSING CONCEPTS]: {last_critique.get('missing_concepts', [])}
        
        [CURRENT CODE]:
        {old_diagram}
        """
        
        messages = [
            SystemMessage(content=REFLECTOR_SYSTEM),
            HumanMessage(content=base_user_msg)
        ]

        
        max_retries = 2 
        current_temp = self.config.temperature
        
        def _diagram_key(diagram: str) -> str:
            return hashlib.sha256(diagram.encode("utf-8")).hexdigest()
        
        for attempt in range(max_retries + 1):
            try:
                logger.info(f"Reflection attempt {attempt+1}/{max_retries+1} (temp={current_temp:.2f})")
                
                response = self._safe_invoke(
                    self.llm,
                    messages,
                    max_tokens=self.config.max_tokens_reflect
                )
                new_diagram = self.plantuml_tool.extract_plantuml(response.content)


                if new_diagram.strip() == old_diagram.strip():
                    logger.warning("Generated identical diagram.")
                    if attempt < max_retries:
                        messages.append(HumanMessage(content="You returned the exact same diagram. You MUST change it to fix the errors. Try again."))
                        current_temp = min(1.0, current_temp + 0.2)
                        continue
                    else:
                        break 

                new_key = _diagram_key(new_diagram)
                if new_key in validation_cache:
                    new_validation = validation_cache[new_key]
                else:
                    new_validation = self._validate_diagram(state['requirements'], new_diagram)
                    validation_cache[new_key] = new_validation
                new_score = new_validation.weighted_score
                
                logger.info(f"Attempt {attempt+1} Score: {new_score:.2f} (Previous: {prev_score:.2f})")

                if new_score >= prev_score:
                    if new_score > best_score:
                        best_score = new_score
                        best_diagram = new_diagram
                    logger.info(f"Improvement found! ({new_score:.2f} >= {prev_score:.2f})")
                    return {
                        "current_diagram": new_diagram,
                        "best_diagram": best_diagram,
                        "best_score": best_score,
                        "validation_cache": validation_cache,
                        "iterations": state["iterations"] + 1,
                    }
                

                logger.warning(f"Score dropped to {new_score:.2f}. Retrying...")
                
                if attempt < max_retries:
                    messages.append(HumanMessage(content=f"""
                    Your previous attempt resulted in a LOWER score ({new_score:.2f} < {prev_score:.2f}).
                    The changes you made introduced new issues.
                    Undo those bad changes and try a different approach to fix the original errors.
                    """))
                    current_temp = min(1.0, current_temp + 0.1)
            
            except Exception as e:
                logger.error(f"Reflection attempt {attempt+1} failed: {e}")
        
        logger.warning("All reflection attempts failed to improve score. Rolling back to best diagram.")
        

        recent_history = state.get("history", [])[-4:]
        if len(recent_history) >= 3:
             last_scores = [h.get("weighted_score", 0) for h in recent_history[-2:]]
             if len(set(last_scores)) == 1:
                 logger.warning("Score plateau detected. Stopping.")
                 return {
                     "current_diagram": best_diagram,
                     "best_diagram": best_diagram,
                     "best_score": best_score,
                     "validation_cache": validation_cache,
                     "iterations": self.config.max_iterations 
                 }

        return {
            "current_diagram": best_diagram,
            "best_diagram": best_diagram,
            "best_score": best_score,
            "validation_cache": validation_cache,
            "iterations": state["iterations"] + 1
        }

In [54]:
def create_uml_graph(
    nodes: UMLNodes, 
    config: Optional[SystemConfig] = None
) -> Any:
    """
    Create the LangGraph workflow for UML diagram generation.
    
    Args:
        nodes: UMLNodes instance with all agent methods
        config: Optional system configuration
        
    Returns:
        Compiled LangGraph workflow
    """
    cfg = config or SystemConfig()
    logger.info("Creating UML generation workflow")
    
    workflow = StateGraph(AgentState)

    # Add all nodes
    workflow.add_node(NodeNames.RETRIEVE, nodes.retrieve)
    workflow.add_node(NodeNames.PLAN_AUDIT, nodes.logic_auditor)
    workflow.add_node(NodeNames.DECOMPOSE, nodes.decompose)
    workflow.add_node(NodeNames.GENERATE, nodes.generate)
    workflow.add_node(NodeNames.SYNTAX_CHECK, nodes.syntax_check)
    workflow.add_node(NodeNames.CRITIC, nodes.critic)
    workflow.add_node(NodeNames.SUMMARIZE, nodes.summarize_memory)
    workflow.add_node(NodeNames.REFLECT, nodes.reflect)
    
    logger.debug("Added 7 nodes to workflow")

    # Define edges
    workflow.add_edge(START, NodeNames.RETRIEVE)
    workflow.add_edge(NodeNames.RETRIEVE, NodeNames.DECOMPOSE)
    workflow.add_edge(NodeNames.DECOMPOSE, NodeNames.PLAN_AUDIT)

    def route_after_plan_audit(state: AgentState) -> str:
        """
        Route based on plan audit results.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name
        """
        if state["plan_valid"]:
            logger.debug("Routing: plan_audit -> generate")
            return NodeNames.GENERATE
            
        logger.debug("Routing: plan_audit -> decompose")
        return NodeNames.DECOMPOSE

    workflow.add_conditional_edges(
        NodeNames.PLAN_AUDIT, 
        route_after_plan_audit,
        {
            NodeNames.DECOMPOSE: NodeNames.DECOMPOSE,
            NodeNames.GENERATE: NodeNames.GENERATE
        }
    )

    workflow.add_edge(NodeNames.GENERATE, NodeNames.SYNTAX_CHECK)

    def route_after_syntax_check(state: AgentState) -> str:
        """
        Route based on syntax validation results and iteration limits.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name
        """
        if state["syntax_valid"]:
            logger.debug("Routing: syntax_check -> critic")
            return NodeNames.CRITIC
            
        if state["iterations"] >= cfg.max_iterations:
            logger.warning(f"Max iterations ({cfg.max_iterations}) reached during syntax check")
            return END
            
        logger.debug("Routing: syntax_check -> reflect")
        return NodeNames.REFLECT

    workflow.add_conditional_edges(
        NodeNames.SYNTAX_CHECK, 
        route_after_syntax_check,
        {
            NodeNames.CRITIC: NodeNames.CRITIC,
            NodeNames.REFLECT: NodeNames.REFLECT,
            END: END
        }
    )

    def is_logic_valid(state: AgentState) -> str:
        """
        Route based on logic validation and iteration limits.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name or END
        """
        if state["logic_valid"]:
            logger.info("Diagram validated successfully")
            return END
            
        if state["iterations"] >= cfg.max_iterations:
            logger.warning(f"Max iterations ({cfg.max_iterations}) reached")
            return END
            
        logger.debug("Routing: critic -> summarize")
        return NodeNames.SUMMARIZE

    workflow.add_conditional_edges(
        NodeNames.CRITIC, 
        is_logic_valid,
        {
            END: END,
            NodeNames.SUMMARIZE: NodeNames.SUMMARIZE
        }
    )
    
    workflow.add_edge(NodeNames.SUMMARIZE, NodeNames.REFLECT)

    def route_after_reflect(state: AgentState) -> str:
        """
        Route after reflection based on iteration limits.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name
        """
        if state["iterations"] >= cfg.max_iterations:
            logger.warning(f"Max iterations ({cfg.max_iterations}) reached after reflection")
            return END
            
        logger.debug("Routing: reflect -> syntax_check")
        return NodeNames.SYNTAX_CHECK

    workflow.add_conditional_edges(
        NodeNames.REFLECT, 
        route_after_reflect,
        {
            NodeNames.SYNTAX_CHECK: NodeNames.SYNTAX_CHECK,
            END: END
        }
    )
    
    logger.info("Workflow graph created successfully")
    return workflow.compile()


def create_initial_state(requirements: str) -> AgentState:
    """
    Create an initial state for the workflow.
    
    Args:
        requirements: Software requirements text
        
    Returns:
        Initial AgentState dictionary
    """
    return {
        "requirements": requirements,
        "plan": None,
        "examples": [],
        "current_diagram": None,
        "best_diagram": None,
        "history": [],
        "summary": None,
        "syntax_valid": False,
        "logic_valid": False,
        "iterations": 0,
        "error_message": None
    }

In [55]:
def load_test_exercises(json_path: str = "./../data/test_exercises.json") -> List[Dict[str, Any]]:
    """
    Load test exercises from JSON file.
    
    Args:
        json_path: Path to test exercises JSON
        
    Returns:
        List of exercise dictionaries
        
    Raises:
        FileNotFoundError: If file doesn't exist
        json.JSONDecodeError: If JSON is invalid
    """
    logger.info(f"Loading test exercises from {json_path}")
    
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Test exercises file not found: {json_path}")
    
    with open(json_path, "r", encoding="utf-8") as f:
        exercises = json.load(f)
    
    logger.info(f"Loaded {len(exercises)} test exercises")
    return exercises

try:
    test_exercises = load_test_exercises()
    print(f"Loaded {len(test_exercises)} test exercises")
except Exception as e:
    logger.error(f"Failed to load test exercises: {e}")
    test_exercises = []


2026-01-14 19:04:35,930 - __main__ - INFO - Loading test exercises from ./../data/test_exercises.json
2026-01-14 19:04:35,931 - __main__ - INFO - Loaded 8 test exercises


Loaded 8 test exercises


In [56]:
def initialize_system(
    config: Optional[SystemConfig] = None,
    enable_long_term_memory: bool = True
) -> Tuple[UMLNodes, Any, SystemConfig, Optional[MemoryManager]]:
    """
    Initialize all system components.
    
    Args:
        config: Optional system configuration
        enable_long_term_memory: Whether to enable long-term memory
        
    Returns:
        Tuple of (nodes, compiled_workflow, config, memory_manager)
    """
    cfg = config or SystemConfig()
    logger.info("="*60)
    logger.info("INITIALIZING UML GENERATION SYSTEM")
    logger.info("="*60)
    
    try:
        logger.info("Creating LLM connection...")
        llm = create_llm(cfg)
        
        logger.info("Initializing PlantUML tool...")
        puml_tool = PlantUMLTool(cfg.plantuml_host)
        
        memory_mgr = None
        if enable_long_term_memory:
            logger.info(f"Initializing long-term memory with {cfg.embedder_model}...")

            dims = 1024 if "large" in cfg.embedder_model.lower() else 384
            
            memory_mgr = MemoryManager(
                embedder=SentenceTransformer(cfg.embedder_model),
                db_path=cfg.db_path,
                embedding_dims=dims
            )

            seeded_count = seed_memory_from_shots(
                memory_manager=memory_mgr,
                shots_json_path=cfg.shots_json_path,
                force_reseed=True 
            )

            logger.info("Long-term memory (SQLite + sqlite-vec) enabled")

            if seeded_count > 0:
                logger.info(f"Seeded {seeded_count} few-shot examples into memory")
        else:
            logger.info("Long-term memory disabled")
        

        logger.info("Building LangGraph workflow...")
        nodes = UMLNodes(llm, puml_tool, memory_mgr, cfg)
        app = create_uml_graph(nodes, cfg)
        
        logger.info("="*60)
        logger.info("SYSTEM INITIALIZED SUCCESSFULLY")
        logger.info("="*60)
        
        return nodes, app, cfg, memory_mgr
        
    except Exception as e:
        logger.error(f"System initialization failed: {e}")
        raise


nodes, app, config, memory_manager = initialize_system(enable_long_term_memory=True)
print("\nSystem ready for diagram generation")
print(f"Long-term memory: {'ENABLED' if memory_manager else 'DISABLED'}")


2026-01-14 19:04:35,940 - __main__ - INFO - INITIALIZING UML GENERATION SYSTEM
2026-01-14 19:04:35,941 - __main__ - INFO - Creating LLM connection...
2026-01-14 19:04:35,942 - __main__ - INFO - Connecting to LMStudio at http://localhost:1234/v1
2026-01-14 19:04:35,942 - __main__ - INFO - Using model: mistralai/devstral-small-2-2512 (temp=0.15)
2026-01-14 19:04:35,948 - __main__ - INFO - Initializing PlantUML tool...
2026-01-14 19:04:35,949 - __main__ - INFO - PlantUML tool initialized with host: http://localhost:8080
2026-01-14 19:04:35,949 - __main__ - INFO - Initializing long-term memory with BAAI/bge-large-en-v1.5...
2026-01-14 19:04:35,964 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps
2026-01-14 19:04:35,964 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: BAAI/bge-large-en-v1.5
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 8969


System ready for diagram generation
Long-term memory: ENABLED


In [57]:
def run_single_test(
    app: Any,
    requirements: str,
    exercise_name: str = "Test Exercise"
) -> AgentState:
    """
    Run the workflow on a single exercise.
    
    Args:
        app: Compiled LangGraph workflow
        requirements: Software requirements text
        exercise_name: Name for logging purposes
        
    Returns:
        Final workflow state
    """
    logger.info("="*60)
    logger.info(f"RUNNING: {exercise_name}")
    logger.info("="*60)
    logger.info(f"Requirements preview: {requirements[:150]}...")
    
    initial_state = create_initial_state(requirements)
    
    try:
        final_output = app.invoke(initial_state, config={"recursion_limit": 50})
        
        logger.info("="*60)
        logger.info("WORKFLOW COMPLETED")
        logger.info("="*60)
        logger.info(f"Iterations: {final_output['iterations']}")
        logger.info(f"Syntax Valid: {final_output['syntax_valid']}")
        logger.info(f"Logic Valid: {final_output['logic_valid']}")
        
        if final_output.get('best_diagram') and not final_output['logic_valid']:
            if final_output['best_diagram'] != final_output['current_diagram']:
                logger.info("Using BEST diagram instead of final (prevented regression)")
                final_output['current_diagram'] = final_output['best_diagram']
        
        return final_output
        
    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        raise


# Select and run a test exercise
test_idx = 0
requirements = test_exercises[test_idx]["requirements"]

final_output = run_single_test(
    app, 
    requirements, 
    f"Exercise {test_idx + 1}"
)


print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Iterations: {final_output['iterations']}")
print(f"Syntax Valid: {final_output['syntax_valid']}")
print(f"Logic Valid: {final_output['logic_valid']}")

if final_output['current_diagram']:
    puml_tool = PlantUMLTool(config.plantuml_host)
    diagram_url = puml_tool.get_diagram_url(final_output['current_diagram'])
    print(f"\nDiagram URL: {diagram_url}")
    
    print("\nGenerated Diagram:")
    print(final_output['current_diagram'])

2026-01-14 19:05:02,982 - __main__ - INFO - RUNNING: Exercise 1
2026-01-14 19:05:02,983 - __main__ - INFO - Requirements preview: An e-commerce platform manages products, customers, and orders.
Each product has a unique SKU, name, description, price, and stock quantity.
Products ...
2026-01-14 19:05:02,985 - __main__ - INFO - --- NODE: RETRIEVE ---
2026-01-14 19:05:03,142 - __main__ - INFO - Retrieved 3 similar diagrams from SQLite
2026-01-14 19:05:03,142 - __main__ - INFO - Retrieved 3 relevant examples from unified memory
2026-01-14 19:05:03,143 - __main__ - INFO - --- NODE: DECOMPOSE ---
2026-01-14 19:06:25,359 - httpx - INFO - HTTP Request: POST http://localhost:1234/v1/chat/completions "HTTP/1.1 200 OK"
2026-01-14 19:06:25,377 - __main__ - INFO - Decomposition completed
2026-01-14 19:06:25,382 - __main__ - INFO - --- NODE: PLAN_AUDIT ---
2026-01-14 19:06:36,028 - httpx - INFO - HTTP Request: POST http://localhost:1234/v1/chat/completions "HTTP/1.1 200 OK"
2026-01-14 19:06:36,077 -


FINAL RESULTS
Iterations: 6
Syntax Valid: True
Logic Valid: False

Diagram URL: http://localhost:8080/png/TLEzJWCn3Dxp554NVgGEh3U12aC7HA7io2HM6v6V8t643iJA4_24F0c-QvjBIrdEFx4_svzzbuaWKdPM26KX9Rc8GMT5yaD8cLvo8vSKZL-nvS5XPMfCAfgMJF2Ljur6STGrAkF0zXWKrCjz1a-6kaREFU4Ae_ZSrPi1EqBiXe96Zn471SU4p90Euv203eojotF4MuYwSgMrQLjMhxIEc58Zjme_FHjhZxsKMdTHRmk5NhlcNdas_ZuQXyDmG9aJRtnsR58Wi5SkP4yZ1VjWSq8t0Bx-eKJAx0qjzWaisSjr6CQOyReOeFWNq3dqctsr6_qMQHrg2YY2Jw3wLStr7bj3_BAUeW1loTWS0aL4gZBLqFd8CGtC5FlTjBEWQu_gIObTlSuJ-O0M-gRw5IMnkxp9nKHEft9oFgaeQnZ14nYVn6Of_qGPhspcqnJbJ0u3EQYreF0GYCMEoqQkuGW2YxqjD3AySJyyo2tuRafXsajvkXlav-kxdFVPEDvvwQVSVYEVqGQ_IayKXBX4h_cF_0K=

Generated Diagram:
@startuml

class Product {
  sku: String
  name: String
  description: String
  price: Decimal
  stockQuantity: Integer
}

class Category {
  id: Integer
  name: String
}

class Customer {
  email: String
  password: String
  shippingAddress: Address
  billingAddress: Address
}

class Order {
  orderDate: DateTime
  status: Enum
  tot

In [34]:
print(final_output['current_diagram'])

@startuml

class Book {
  isbn: String
  title: String
  author: String
  publisher: String
  publicationYear: Integer
  availabilityStatus: Boolean
}

class Member {
  membershipId: String
  name: String
  address: String
  phoneNumber: String
  registrationDate: Date
}

class Student {
  studentId: String
  programOfStudy: String
}

class FacultyMember {
  employeeId: String
  department: String
}

class Loan {
  checkoutDate: Date
  dueDate: Date
  returnDate: Date
}

class Fine {
  fineAmount: Decimal
  paymentStatus: Boolean
}

class Reservation {
  reservationDate: Date
}

Member <|-- Student
Member <|-- FacultyMember

Member "1" -- "*" Loan : borrows
Book "1" -- "*" Loan : is borrowed via

Member "1" -- "*" Reservation : makes
Book "1" -- "*" Reservation : is reserved by

Loan "1" -- "0..1" Fine : has

@enduml


In [35]:
class EvaluationMetrics(BaseModel):
    """Container for evaluation metrics."""
    precision: float = Field(ge=0.0, le=1.0, description="Precision score")
    recall: float = Field(ge=0.0, le=1.0, description="Recall score")
    f1: float = Field(ge=0.0, le=1.0, description="F1 score")
    
    def __str__(self) -> str:
        return f"P={self.precision:.2f}, R={self.recall:.2f}, F1={self.f1:.2f}"


class PlantUMLParser:
    """
    Parser for extracting structured information from PlantUML diagrams.
    
    Extracts classes, attributes, and relationships from PlantUML code
    for evaluation purposes.
    """
    
    def __init__(self, plantuml_code: str):
        """
        Initialize parser with PlantUML code.
        
        Args:
            plantuml_code: PlantUML diagram code
        """
        self.plantuml_code = plantuml_code
        self.classes: Dict[str, Dict[str, List[str]]] = {}
        self.relationships: List[Dict[str, Any]] = []
        self.parse()
    
    def parse(self) -> None:
        """Parse the PlantUML code."""
        try:
            self._extract_classes()
            self._extract_relationships()
            logger.debug(f"Parsed {len(self.classes)} classes and {len(self.relationships)} relationships")
        except Exception as e:
            logger.error(f"Parsing failed: {e}")
    
    def _extract_classes(self) -> None:
        """Extract class definitions and their attributes."""
        class_pattern = r'class\s+(\w+)\s*\{([^}]*)\}'
        matches = re.finditer(class_pattern, self.plantuml_code, re.MULTILINE | re.DOTALL)
        
        for match in matches:
            class_name = match.group(1)
            class_body = match.group(2)
            
            attributes = []
            for line in class_body.strip().split('\n'):
                line = line.strip()
                if line and not line.startswith('--'):
                    attributes.append(line)
            
            self.classes[class_name] = {'attributes': attributes}
    
    def _extract_relationships(self) -> None:
        """Extract relationships between classes with cardinalities."""
        patterns = [
            # Generalization (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*<\|--\s*(?:"([^"]*)")?\s*(\w+)', 'generalization'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*--|>\s*(?:"([^"]*)")?\s*(\w+)', 'generalization'),
            
            # Composition (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*\*--\s*(?:"([^"]*)")?\s*(\w+)', 'composition'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*--\*\s*(?:"([^"]*)")?\s*(\w+)', 'composition'),
            
            # Aggregation (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*o--\s*(?:"([^"]*)")?\s*(\w+)', 'aggregation'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*--o\s*(?:"([^"]*)")?\s*(\w+)', 'aggregation'),
            
            # Directed Association (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*-->\s*(?:"([^"]*)")?\s*(\w+)', 'association'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*<--\s*(?:"([^"]*)")?\s*(\w+)', 'association'),
            
            # Simple Association (no arrow)
            (r'(\w+)\s*(?:"([^"]*)")?\s*--\s*(?:"([^"]*)")?\s*(\w+)', 'association'),
        ]
        
        for pattern, rel_type in patterns:
            for match in re.finditer(pattern, self.plantuml_code):
                source = match.group(1)
                target = match.group(4)
                
                # Skip if source or target is None or empty
                if not source or not target:
                    continue
                
                self.relationships.append({
                    'type': rel_type,
                    'source': source,
                    'target': target,
                    'cardinality_source': match.group(2) if match.lastindex >= 2 else None,
                    'cardinality_target': match.group(3) if match.lastindex >= 3 else None
                })


class DiagramEvaluator:
    """
    Evaluator for comparing generated diagrams against gold standards.
    
    Computes precision, recall, and F1 scores for classes, attributes,
    and relationships.
    """
    
    def __init__(self, gold_plantuml: str, pred_plantuml: str):
        """
        Initialize evaluator with gold and predicted diagrams.
        
        Args:
            gold_plantuml: Gold standard PlantUML code
            pred_plantuml: Predicted PlantUML code
        """
        self.gold_parser = PlantUMLParser(gold_plantuml)
        self.pred_parser = PlantUMLParser(pred_plantuml)
    
    def _normalize_attr(self, attr_str: str) -> str:
        """Normalize attribute strings for comparison."""
        return attr_str.split(':')[0].strip().lower()
    
    def _normalize_rel_type(self, rel_type: str) -> str:
        """Normalize relationship types."""
        mapping = {
            '<|--': 'INHERITANCE',
            '--|>': 'INHERITANCE',
            '*--': 'COMPOSITION',
            '--*': 'COMPOSITION',
            'o--': 'AGGREGATION',
            '--o': 'AGGREGATION',
            '--': 'ASSOCIATION',
            '<--': 'ASSOCIATION',
            '-->': 'ASSOCIATION'
        }
        return mapping.get(rel_type, 'ASSOCIATION')
    
    def _calculate_metrics(
        self, 
        gold_set: set, 
        pred_set: set
    ) -> EvaluationMetrics:
        """
        Calculate precision, recall, and F1 scores.
        
        Args:
            gold_set: Set of gold standard elements
            pred_set: Set of predicted elements
            
        Returns:
            EvaluationMetrics object
        """
        tp = len(gold_set.intersection(pred_set))
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        
        return EvaluationMetrics(
            precision=round(precision, 2),
            recall=round(recall, 2),
            f1=round(f1, 2)
        )
    
    def get_metrics(self) -> Dict[str, EvaluationMetrics]:
        """
        Get all evaluation metrics.
        
        Returns:
            Dictionary with metrics for classes, attributes, and relationships
        """
        # Classes 
        gold_classes = {c.lower() for c in self.gold_parser.classes.keys()}
        pred_classes = {c.lower() for c in self.pred_parser.classes.keys()}
        
        # Attributes 
        gold_attrs = set()
        for cls, info in self.gold_parser.classes.items():
            for attr in info['attributes']:
                gold_attrs.add((cls.lower(), self._normalize_attr(attr)))
        
        pred_attrs = set()
        for cls, info in self.pred_parser.classes.items():
            for attr in info['attributes']:
                pred_attrs.add((cls.lower(), self._normalize_attr(attr)))
        
        # Relationships (type + direction only)
        gold_rels = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']))
            for r in self.gold_parser.relationships
            if r.get('source') and r.get('target')
        }
        pred_rels = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']))
            for r in self.pred_parser.relationships
            if r.get('source') and r.get('target')
        }
        
        # Cardinalities 
        gold_rels_card = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']),
             (r['cardinality_source'] or '').strip(), (r['cardinality_target'] or '').strip())
            for r in self.gold_parser.relationships
            if r.get('source') and r.get('target')
        }
        pred_rels_card = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']),
             (r['cardinality_source'] or '').strip(), (r['cardinality_target'] or '').strip())
            for r in self.pred_parser.relationships
            if r.get('source') and r.get('target')
        }
        
        result = {
            "classes": self._calculate_metrics(gold_classes, pred_classes),
            "attributes": self._calculate_metrics(gold_attrs, pred_attrs),
            "relationships": self._calculate_metrics(gold_rels, pred_rels),
            "cardinalities": self._calculate_metrics(gold_rels_card, pred_rels_card)
        }
        
        return result


def evaluate_diagram(
    gold_standard: str,
    generated_diagram: str
) -> Dict[str, EvaluationMetrics]:
    """
    Evaluate a generated diagram against gold standard.
    
    Args:
        gold_standard: Gold standard PlantUML code
        generated_diagram: Generated PlantUML code
        
    Returns:
        Dictionary of evaluation metrics
    """
    evaluator = DiagramEvaluator(gold_standard, generated_diagram)
    return evaluator.get_metrics()


In [36]:
gold_standard = test_exercises[test_idx]["solution_plantuml"]
generated_diagram = final_output["current_diagram"]

metrics = evaluate_diagram(gold_standard, generated_diagram)

print("="*60)
print("EVALUATION METRICS")
print("="*60)
print(f"\nClasses:       {metrics['classes']}")
print(f"Attributes:    {metrics['attributes']}")
print(f"Relationships: {metrics['relationships']}")
print(f"Cardinalities: {metrics['cardinalities']}")

weighted_avg_f1 = (
    metrics['classes'].f1 * 0.3 + 
    metrics['attributes'].f1 * 0.2 + 
    metrics['relationships'].f1 * 0.3 +
    metrics['cardinalities'].f1 * 0.2
)

print(f"\n{'='*60}")
print(f"OVERALL F1 SCORE: {weighted_avg_f1:.2f}")
print(f"{'='*60}")


EVALUATION METRICS

Classes:       P=0.86, R=0.86, F1=0.86
Attributes:    P=0.90, R=0.86, F1=0.88
Relationships: P=0.86, R=0.86, F1=0.86
Cardinalities: P=0.29, R=0.29, F1=0.29

OVERALL F1 SCORE: 0.75


In [46]:
class BatchResult(BaseModel):
    """Result from a single exercise in batch evaluation."""
    exercise_num: int = Field(description="Exercise number")
    success: bool = Field(description="Whether the exercise was successful")
    iterations: int = Field(default=0, ge=0, description="Number of iterations used")
    syntax_valid: bool = Field(default=False, description="Whether syntax validation passed")
    logic_valid: bool = Field(default=False, description="Whether logic validation passed")
    metrics: Optional[Dict[str, EvaluationMetrics]] = Field(default=None, description="Evaluation metrics")
    diagram_url: Optional[str] = Field(default=None, description="URL to view the diagram")
    error: Optional[str] = Field(default=None, description="Error message if failed")
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for DataFrame creation."""
        if self.metrics:
            return {
                "exercise": self.exercise_num,
                "success": self.success,
                "iterations": self.iterations,
                "syntax_valid": self.syntax_valid,
                "logic_valid": self.logic_valid,
                "class_f1": self.metrics['classes'].f1,
                "attr_f1": self.metrics['attributes'].f1,
                "rel_f1": self.metrics['relationships'].f1,
                "card_f1": self.metrics['cardinalities'].f1,
                "diagram_url": self.diagram_url
            }
        return {
            "exercise": self.exercise_num,
            "success": self.success,
            "error": self.error
        }


def evaluate_batch(
    app: Any,
    test_exercises: List[Dict[str, Any]],
    puml_tool: PlantUMLTool,
    max_exercises: Optional[int] = None
) -> pd.DataFrame:
    """
    Run batch evaluation on multiple exercises.
    
    Args:
        app: Compiled LangGraph workflow
        test_exercises: List of exercise dictionaries
        puml_tool: PlantUML tool for URL generation
        max_exercises: Optional limit on number of exercises
        
    Returns:
        DataFrame with evaluation results
    """
    exercises_to_test = test_exercises[:max_exercises] if max_exercises else test_exercises
    results = []
    
    logger.info("="*60)
    logger.info(f"BATCH EVALUATION: {len(exercises_to_test)} exercises")
    logger.info("="*60)
    
    for i, exercise in enumerate(exercises_to_test):
        logger.info(f"\n--- Exercise {i+1}/{len(exercises_to_test)} ---")
        
        try:
            # Run workflow
            requirements = exercise["requirements"]
            final_output = run_single_test(app, requirements, f"Exercise {i+1}")
            
            # Evaluate
            gold_standard = exercise["solution_plantuml"]
            generated_diagram = final_output["current_diagram"]
            metrics = evaluate_diagram(gold_standard, generated_diagram)
            
            result = BatchResult(
                exercise_num=i + 1,
                success=True,
                iterations=final_output["iterations"],
                syntax_valid=final_output["syntax_valid"],
                logic_valid=final_output["logic_valid"],
                metrics=metrics,
                diagram_url=puml_tool.get_diagram_url(generated_diagram)
            )
            
            logger.info(f"Exercise {i+1}: F1 = Classes:{metrics['classes'].f1:.2f} | "
                       f"Attrs:{metrics['attributes'].f1:.2f} | Rels:{metrics['relationships'].f1:.2f} | "
                       f"Cards:{metrics['cardinalities'].f1:.2f}")
            
        except Exception as e:
            logger.error(f"✗ Exercise {i+1} failed: {e}")
            result = BatchResult(
                exercise_num=i + 1,
                success=False,
                error=str(e)
            )
        
        results.append(result.to_dict())
    
    df = pd.DataFrame(results)
    logger.info("\n" + "="*60)
    logger.info("BATCH EVALUATION COMPLETE")
    logger.info("="*60)
    
    return df


# Example: Run on first 3 exercises (uncomment to execute)
df_results = evaluate_batch(app, test_exercises, puml_tool, max_exercises=3)
# 
print("\n" + "="*60)
print("BATCH EVALUATION SUMMARY")
print("="*60)
successful = df_results[df_results['success'] == True]
if not successful.empty:
    print(successful[['exercise', 'class_f1', 'attr_f1', 'rel_f1']].to_string(index=False))
    avg_f1 = successful[['class_f1', 'attr_f1', 'rel_f1']].mean().mean()
    print(f"\nAverage F1: {avg_f1:.2f}")
else:
    print("No successful evaluations")


2026-01-14 18:59:52,726 - __main__ - INFO - BATCH EVALUATION: 3 exercises
2026-01-14 18:59:52,726 - __main__ - INFO - 
--- Exercise 1/3 ---
2026-01-14 18:59:52,727 - __main__ - INFO - RUNNING: Exercise 1
2026-01-14 18:59:52,727 - __main__ - INFO - Requirements preview: An e-commerce platform manages products, customers, and orders.
Each product has a unique SKU, name, description, price, and stock quantity.
Products ...
2026-01-14 18:59:52,731 - __main__ - INFO - --- NODE: RETRIEVE ---
2026-01-14 18:59:55,601 - __main__ - INFO - Retrieved 3 similar diagrams from SQLite
2026-01-14 18:59:55,604 - __main__ - INFO - Retrieved 3 relevant examples from unified memory
2026-01-14 18:59:55,606 - __main__ - INFO - --- NODE: DECOMPOSE ---
2026-01-14 19:00:29,071 - httpx - INFO - HTTP Request: POST http://localhost:1234/v1/chat/completions "HTTP/1.1 200 OK"
2026-01-14 19:00:29,081 - __main__ - INFO - Decomposition completed
2026-01-14 19:00:29,083 - __main__ - INFO - --- NODE: PLAN_AUDIT ---
2026-

KeyboardInterrupt: 