In [604]:
import json
import hashlib
import re
import zlib
import base64
import requests
import operator
import os
import logging
import time
import pandas as pd
import sys

# sqlite-vec expects the standard library sqlite3 module.
# Older versions of this notebook replaced sqlite3 with sqlean via sys.modules; undo that if present.
if sys.modules.get("sqlite3") is sys.modules.get("sqlean"):
    del sys.modules["sqlite3"]
import sqlite3

from datetime import datetime
from enum import Enum
from typing import Annotated, List, TypedDict, Optional, Dict, Any, Tuple, Literal
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.documents import Document
from langchain_community.vectorstores import SQLiteVec
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from langgraph.graph import StateGraph, START, END
from prompts import DECOMPOSER_SYSTEM, GENERATOR_SYSTEM, CRITIC_SYSTEM, REFLECTOR_SYSTEM, PLAN_AUDITOR_SYSTEM, STRUCTURE_REFINER_SYSTEM


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [605]:
class NodeNames(str, Enum):
    """Enum for node names to avoid string literals."""
    RETRIEVE = "retrieve"
    DECOMPOSE = "decompose"
    GENERATE = "generate"
    SYNTAX_CHECK = "syntax_check"
    CRITIC = "critic"
    REFLECTOR = "reflector"
    PLAN_AUDIT = "plan_audit"
    STRUCTURE_REFINER = "structure_refiner"


class Attribute(BaseModel):
    """Model for a class attribute."""
    name: str = Field(description="Attribute name")
    type: str = Field(description="Attribute type")


class Class(BaseModel):
    """Model for a UML class."""
    name: str = Field(description="Class name")
    attributes: List[Attribute] = Field(default_factory=list, description="List of class attributes")


class Relationship(BaseModel):
    """Model for a relationship between classes."""
    source: str = Field(description="Source class name")
    target: str = Field(description="Target class name")
    type: str = Field(description="Relationship type (e.g., association, composition, inheritance)")
    source_multiplicity: str = Field(description="Multiplicity at source end")
    target_multiplicity: str = Field(description="Multiplicity at target end")


class DecompositionResult(BaseModel):
    """Structured output from the DECOMPOSE node."""
    classes: List[Class] = Field(default_factory=list, description="List of identified classes")
    relationships: List[Relationship] = Field(default_factory=list, description="List of relationships between classes")


class PlanAudit(BaseModel):
    critique: List[Literal[
        "Missing class",
        "Missing relationship",
        "Disconnected class"
    ]] = Field(default_factory=list, description="List of critique points found in the plan.")
    suggestions: List[str] = Field(default_factory=list, description="Actionable steps to fix the plan.")
    
    @computed_field
    @property
    def is_valid(self) -> bool:
        return len(self.critique) == 0


class SystemConfig(BaseModel):
    """System configuration for UML generation."""
    lmstudio_base_url: str = Field(default="http://localhost:1234/v1", description="LMStudio API endpoint")
    model_name: str = Field(default="mistralai/devstral-small-2-2512", description="Model to use")
    embedder_model: str = Field(default="BAAI/bge-large-en-v1.5", description="Embedder model for semantic search")
    db_path: str = Field(default="./../data/uml_knowledge.db", description="Path to SQLite database")
    shots_json_path: str = Field(default="./../data/complete_shots.json", description="Path to few-shot examples")
    plantuml_host: str = Field(default="http://localhost:8080", description="PlantUML server host")
    max_iterations: int = Field(default=6, ge=1, description="Maximum workflow iterations")
    max_tokens_decompose: int = Field(default=2048, description="Max tokens for decompose step")
    max_tokens_generate: int = Field(default=2048, description="Max tokens for generate step")
    max_tokens_critique: int = Field(default=2048, description="Max tokens for critique step")
    max_tokens_reflect: int = Field(default=2048, description="Max tokens for reflect step")
    max_tokens_refine: int = Field(default=2048, description="Max tokens for structure refine step")
    temperature: float = Field(default=0.15, ge=0.0, le=2.0, description="Base temperature for LLM")
    num_few_shots: int = Field(default=3, ge=0, description="Number of few-shot examples")
    request_timeout: int = Field(default=5, ge=1, description="Timeout for PlantUML server requests")
    llm_timeout: int = Field(default=600, ge=1, description="Timeout for LLM operations")
    # New fields for iteration metrics
    plateau_window: int = Field(default=3, ge=2, description="Number of iterations to consider for plateau detection")
    plateau_threshold: float = Field(default=0.1, ge=0.0, description="Score delta threshold for plateau detection")
    max_stagnant_iterations: int = Field(default=2, ge=1, description="Max consecutive stagnant iterations before stopping")


class PlantUMLResult(BaseModel):
    """Result from PlantUML validation."""
    is_valid: bool = Field(description="Whether the PlantUML syntax is valid")
    error: Optional[str] = Field(default=None, description="Error message if validation failed")
    url: Optional[str] = Field(default=None, description="URL to view the diagram")
    svg_url: Optional[str] = Field(default=None, description="URL to view the diagram as SVG")

In [606]:
def create_llm(config: Optional[SystemConfig] = None) -> ChatOpenAI:
    """
    Create a ChatOpenAI instance configured for LMStudio.
    
    Args:
        config: Optional system configuration
        
    Returns:
        Configured ChatOpenAI instance
    """
    cfg = config or SystemConfig()
    logger.info(f"Connecting to LMStudio at {cfg.lmstudio_base_url}")
    logger.info(f"Using model: {cfg.model_name} (temp={cfg.temperature})")
    
    return ChatOpenAI(
        base_url=cfg.lmstudio_base_url,
        api_key="lm-studio",  
        model=cfg.model_name,
        temperature=cfg.temperature,
        timeout=cfg.llm_timeout 
    )

In [607]:
class Fixability(str, Enum):
    render_only = "render_only"
    structure_change = "structure_change"
    unfixable = "unfixable"


class Severity(str, Enum):
    error = "error"
    warning = "warning"


class FindingCategory(str, Enum):
    coverage = "coverage"         
    structure = "structure"        
    render = "render"              
    syntax = "syntax"            


class CritiqueFinding(BaseModel):
    id: str = Field(default=None, description="Stable identifier for this finding")
    category: FindingCategory
    severity: Severity = Field(default=Severity.error, description="Severity level of the finding")
    fixability: Fixability
    affected_elements: List[str]
    description: str
    expected_correction: Optional[str] = None

    @staticmethod
    def _slugify(text: str) -> str:
        text = re.sub(r"[^a-z0-9]+", "_", (text or "").lower())
        return re.sub(r"_+", "_", text).strip("_") or "issue"

    @staticmethod
    def _entities_from_affected(affected: List[str]) -> Tuple[str, str]:
        entity, related = "none", "none"
        for a in affected or []:
            if a.startswith("class:"):
                entity = a.split(":", 1)[1].strip() or entity
            elif a.startswith("attr:"):
                entity = (a.split(":", 1)[1].split(".", 1)[0].strip() or entity)
            elif a.startswith("rel:"):
                rel = a.split(":", 1)[1]
                parts = re.split(r"<\|--|--\|>|<--|-->|--|\*--|--\*", rel)
                if len(parts) >= 2:
                    entity = parts[0].strip() or entity
                    related = parts[1].strip() or related
        return entity or "none", related or "none"

    @model_validator(mode="after")
    def ensure_canonical_id(self):
        if self.id and self.id.count("::") == 3 and all(p.strip() for p in self.id.split("::")):
            return self

        entity, related = self._entities_from_affected(self.affected_elements)

        raw = (self.id or "").strip()
        tail = raw.split(":", 1)[1].strip() if ":" in raw else raw

        tail_slug = self._slugify(tail)
        if "missing_attribute" in tail_slug:
            issue_slug = "missing_attribute"
        elif "missing_class" in tail_slug:
            issue_slug = "missing_class"
        elif "duplicate_relationship" in tail_slug:
            issue_slug = "duplicate_relationship"
        else:
            seed = f"{self.category.value}|{self.fixability.value}|{','.join(sorted(self.affected_elements or []))}|{self.description[:80]}"
            issue_slug = f"auto_{hashlib.md5(seed.encode()).hexdigest()[:8]}"

        self.id = f"{self.category.value}::{entity}::{related}::{issue_slug}"
        return self


class CritiqueSummary(BaseModel):
    total_findings: int = 0
    render_only: int = 0
    structure_change: int = 0
    unfixable: int = 0
    new_findings: int = 0
    resolved_findings: int = 0


class CritiqueReport(BaseModel):
    findings: List[CritiqueFinding]

    @computed_field
    @property
    def is_valid(self) -> bool:
        return not any(f.severity == Severity.error for f in self.findings)


In [608]:
class AgentState(TypedDict):
    """
    Shared state for the LangGraph workflow.
    """
    requirements: str
    plan: Optional[str]
    examples: List[Dict[str, str]]
    current_diagram: Optional[str]
    best_diagram: Optional[str]  
    summary: Optional[str]
    syntax_valid: bool
    logic_valid: bool
    error_message: Optional[str]
    plan_valid: bool 
    audit_feedback: Optional[List[str]]  # Feedback from auditor
    plan_audit_attempts: int  # audit loop iterations
    best_score: float
    best_code: str
    current_validation: Optional[CritiqueReport]
    failed_attempts: Annotated[List[Dict[str, Any]], operator.add]
    iterations: int
    stagnant_count: int  # Track consecutive iterations without changes
    critique_cache: Dict[str, Dict[str, Any]]  # Cache critiques by diagram hash
    score_history: List[float]  # History of weighted scores for plateau detection
    delta_score: float  # Score change from previous iteration
    audit_suggestions: Optional[List[str]]  # Suggestions from plan auditor

In [609]:
class PlantUMLTool:
    """
    Tool for validating and rendering PlantUML diagrams.
    
    This class interfaces with a PlantUML server to check syntax
    and generate diagram URLs.
    """
    
    def __init__(self, host: str = "http://localhost:8080"):
        """
        Initialize PlantUML tool.
        
        Args:
            host: PlantUML server host URL
        """
        self.host = host
        logger.info(f"PlantUML tool initialized with host: {host}")

    def extract_plantuml(self, text: str) -> str:
        """
        Extract PlantUML code from markdown blocks or raw text.
        
        Args:
            text: Text containing PlantUML code
            
        Returns:
            Extracted PlantUML code or empty string
        """
        if not text:
            return ""
        
        # Try to extract from ```plantuml ... ```
        fence_match = re.search(r"```\s*plantuml\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
        if fence_match:
            return fence_match.group(1).strip()
        
        # Try to extract from @startuml ... @enduml
        tag_match = re.search(r"@startuml.*?@enduml", text, re.DOTALL | re.IGNORECASE)
        if tag_match:
            return tag_match.group(0).strip()
        
        return text.strip()

    def _encode_plantuml(self, plantuml_code: str) -> str:
        """
        Encode PlantUML code for URL.
        
        Args:
            plantuml_code: Raw PlantUML code
            
        Returns:
            URL-safe encoded string
        """
        code = plantuml_code.strip()
        
        if not code.startswith("@startuml"): 
            code = f"@startuml\n{code}"
        if not code.endswith("@enduml"): 
            code = f"{code}\n@enduml"
        
        compressed = zlib.compress(code.encode('utf-8'))[2:-4]
        encoded = base64.b64encode(compressed).translate(
            bytes.maketrans(
                b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
                b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_"
            )
        ).decode('utf-8')
        
        return encoded

    def get_diagram_url(self, plantuml_code: str, format: str = "png") -> str:
        """
        Generate a viewable URL for the PlantUML diagram.
        
        Args:
            plantuml_code: PlantUML diagram code
            format: Output format (png, svg, etc.)
            
        Returns:
            URL to view the diagram
        """
        diagram_code = self.extract_plantuml(plantuml_code)
        encoded = self._encode_plantuml(diagram_code)
        return f"{self.host}/{format}/{encoded}"
        
    def check_syntax(self, plantuml_code: str, timeout: int = 5) -> PlantUMLResult:
        """
        Validate PlantUML syntax with detailed error extraction.
        
        Args:
            plantuml_code: PlantUML code to validate
            timeout: Request timeout in seconds
            
        Returns:
            PlantUMLResult with validation status and detailed error if applicable.
        """
        logger.info("Validating PlantUML syntax")
        
        try:
            diagram_code = self.extract_plantuml(plantuml_code)
            encoded = self._encode_plantuml(diagram_code)
            
            url_png = f"{self.host}/png/{encoded}"
            response = requests.get(url_png, timeout=timeout)
            
            if response.status_code == 200 and response.content[:4] == b'\x89PNG':
                logger.info("Syntax validation passed (PNG rendered)")
                return PlantUMLResult(
                    is_valid=True,
                    url=url_png,
                    svg_url=f"{self.host}/svg/{encoded}"
                )
            
            logger.warning("PNG rendering failed. Fetching detailed syntax error...")
            url_txt = f"{self.host}/txt/{encoded}"
            error_response = requests.get(url_txt, timeout=timeout)
            
            detailed_error = error_response.text.strip() if error_response.status_code == 200 else "Unknown server error"
            
            error_msg = f"PlantUML Syntax Error:\n{detailed_error[:1000]}"
            logger.error(f"Syntax error detected: {error_msg}")
            
            return PlantUMLResult(
                is_valid=False,
                error=error_msg
            )
            
        except requests.exceptions.RequestException as e:
            error_msg = f"PlantUML Server Connection Error: {str(e)}"
            logger.error(error_msg)
            return PlantUMLResult(is_valid=False, error=error_msg)
            
        except Exception as e:
            error_msg = f"Unexpected error during syntax check: {str(e)}"
            logger.error(error_msg)
            return PlantUMLResult(is_valid=False, error=error_msg)

In [610]:
class MemoryManager:
    """
    Manages long-term memory for UML diagram generation using LangChain's SQLiteVec.
    
    Supports semantic search to find similar past solutions.
    """
    
    def __init__(
        self,
        embedder: SentenceTransformer,
        db_path: str = "./../data/uml_knowledge.db",
        embedding_dims: int = 1024
    ):
        """
        Initialize memory manager with LangChain SQLiteVec.
        
        Args:
            embedder: SentenceTransformer model for semantic search
            db_path: Path to the SQLite database file
            embedding_dims: Dimensions of the embeddings 
        """
        self.embedder = embedder
        self.db_path = db_path
        self.embedding_dims = embedding_dims
        

        self.embedding_function = HuggingFaceEmbeddings(
            model_name=embedder.model_name if hasattr(embedder, 'model_name') else "BAAI/bge-large-en-v1.5",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        
        # Create directory and connection
        os.makedirs(os.path.dirname(self.db_path) if os.path.dirname(self.db_path) else ".", exist_ok=True)
        
        # Create connection using sqlean instead of default sqlite3 (which lacks extension support on macOS)
        try:
            import sqlean
            import sqlite_vec
            
            self.connection = sqlean.connect(self.db_path)
            self.connection.row_factory = sqlean.Row
            self.connection.enable_load_extension(True)
            sqlite_vec.load(self.connection)
            self.connection.enable_load_extension(False)
            logger.info("Used sqlean for SQLite connection (extension support enabled)")
        except ImportError:
            logger.warning("sqlean not found, falling back to SQLiteVec.create_connection (may fail on macOS)")
            self.connection = SQLiteVec.create_connection(db_file=self.db_path)
        
        # Initialize vector store with connection
        self.vector_store = SQLiteVec(
            table="uml_memories",
            connection=self.connection,
            embedding=self.embedding_function
        )
        
        logger.info(f"MemoryManager initialized with LangChain SQLiteVec at {db_path} (dims={embedding_dims})")

    def save_diagram(
        self,
        requirements: str,
        diagram: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Save a validated diagram to SQLite long-term memory.
                
        Args:
            requirements: Original requirements text
            diagram: PlantUML diagram code
            metadata: Optional metadata
            
        Returns:
            ID of the stored record
        """
        timestamp = datetime.now().isoformat()
        
        full_metadata = metadata or {}
        full_metadata.update({
            "diagram": diagram,
            "timestamp": timestamp
        })
        

        doc = Document(
            page_content=requirements,
            metadata=full_metadata
        )
        
        ids = self.vector_store.add_documents([doc])
        
        logger.info("Diagram saved to SQLite memory using LangChain SQLiteVec")
        return ids[0] if ids else 0
    
    def retrieve_similar_diagrams(
        self,
        requirements: str,
        limit: int = 2
    ) -> List[Dict[str, Any]]:
        """
        Retrieve similar diagrams from SQLite memory using vector search.
        
        Args:
            requirements: Requirements text to search for
            limit: Maximum number of results
            
        Returns:
            List of similar diagram records
        """
        try:
            results = self.vector_store.similarity_search(requirements, k=limit)
            
            diagrams = []
            for doc in results:
                diagrams.append({
                    "requirements": doc.page_content,
                    "diagram": doc.metadata.get("diagram", ""),
                    "timestamp": doc.metadata.get("timestamp", ""),
                    "metadata": {k: v for k, v in doc.metadata.items() 
                                if k not in ["diagram", "timestamp"]}
                })
                
            logger.info(f"Retrieved {len(diagrams)} similar diagrams from SQLite")
            return diagrams
            
        except Exception as e:
            logger.warning(f"Memory retrieval failed: {e}")
            return []
    
    def clear_memory(self) -> None:
        try:
            # Close existing connection
            if hasattr(self, 'connection'):
                self.connection.close()
            
            # Remove database file
            if os.path.exists(self.db_path):
                os.remove(self.db_path)
            
            # Recreate connection and vector store
            import sqlean
            import sqlite_vec
            
            self.connection = sqlean.connect(self.db_path)
            self.connection.row_factory = sqlean.Row
            self.connection.enable_load_extension(True)
            sqlite_vec.load(self.connection)
            self.connection.enable_load_extension(False)
            
            self.vector_store = SQLiteVec(
                table="uml_memories",
                connection=self.connection,
                embedding=self.embedding_function
            )

            logger.info("Memory cleared and reinitialized")
        except Exception as e:
            logger.error(f"Failed to clear memory: {e}")

In [611]:
def seed_memory_from_shots(
    memory_manager: MemoryManager,
    shots_json_path: str = "./../data/complete_shots.json",
    force_reseed: bool = False
) -> int:
    """
    Seed the memory database with few-shot examples from JSON file.
    Skips seeding if database already contains data (unless force_reseed=True).
    
    Args:
        memory_manager: MemoryManager instance to seed
        shots_json_path: Path to the complete_shots.json file
        force_reseed: If True, clears existing data and reseeds
        
    Returns:
        Number of shots seeded (0 if skipped)
    """
    logger.info("="*60)
    logger.info("CHECKING MEMORY SEEDING STATUS")
    logger.info("="*60)
    
    # Check if database already has data
    try:
        existing_docs = memory_manager.vector_store.similarity_search("test", k=1)
        if existing_docs and not force_reseed:
            logger.info(f"Database already contains data ({len(existing_docs)} docs found)")
            logger.info("Skipping seeding operation. Set force_reseed=True to override.")
            return 0
    except Exception as e:
        logger.info(f"Database appears empty or uninitialized: {e}")
    
    if force_reseed:
        logger.warning("Force reseed enabled - clearing existing memory")
        memory_manager.clear_memory()
    

    if not os.path.exists(shots_json_path):
        logger.error(f"Shots file not found at {shots_json_path}")
        return 0
    
    logger.info(f"Loading shots from {shots_json_path}")
    with open(shots_json_path, 'r', encoding='utf-8') as f:
        shots = json.load(f)
    
    logger.info(f"Found {len(shots)} shots to seed")
    
    # Prepare documents
    documents = []
    for shot in shots:
        requirements = shot["requirements"]
        diagram = shot["diagram"]
        
        metadata = {
            "diagram": diagram,
            "timestamp": datetime.now().isoformat(),
            "plan": shot.get("plan"),
            "is_static": True,
            "title": shot.get("title", "Untitled")
        }
        
        logger.info(f"  Processing: {metadata['title']}")
        
        doc = Document(
            page_content=requirements,
            metadata=metadata
        )
        documents.append(doc)
    
    if documents:
        memory_manager.vector_store.add_documents(documents)
        logger.info("="*60)
        logger.info(f"✓ Successfully seeded {len(documents)} shots to memory")
        logger.info("="*60)
        return len(documents)
    
    return 0

In [None]:
class UMLNodes:
    """
    Collection of agent nodes for the UML generation workflow.
    
    Each method represents a node in the LangGraph workflow and
    follows the pattern of taking AgentState and returning a dict
    with state updates.
    """
    
    def __init__(
        self,
        llm: ChatOpenAI,
        plantuml_tool: PlantUMLTool,
        memory_manager: Optional['MemoryManager'] = None,
        config: Optional[SystemConfig] = None
    ):
        """
        Initialize UML nodes with required dependencies.
        
        Args:
            llm: LangChain ChatOpenAI instance
            plantuml_tool: Tool for PlantUML validation
            memory_manager: long-term memory manager
            config: Optional system configuration
        """
        self.llm = llm
        self.plantuml_tool = plantuml_tool
        self.memory_manager = memory_manager
        self.config = config or SystemConfig()
        logger.info("UMLNodes initialized")

    @staticmethod
    def _normalize_diagram(diagram: str) -> str:
        """Normalize diagram for consistent hashing (remove whitespace variations)."""
        lines = [line.strip() for line in diagram.strip().split('\n') if line.strip()]
        return '\n'.join(sorted(lines))  # Sort for order-independent comparison
    
    @staticmethod
    def _hash_diagram(diagram: str) -> str:
        """Create a hash of the diagram content for caching."""
        normalized = UMLNodes._normalize_diagram(diagram)
        return hashlib.md5(normalized.encode()).hexdigest()

    def _safe_invoke(self, runnable: Any, input_data: Any, **kwargs) -> Any:
        """
        Invoke a runnable (LLM or chain) with retry logic.
        """
        max_retries = 3
        last_exception = None
        
        for attempt in range(max_retries):
            try:
                return runnable.invoke(input_data, **kwargs)
            except Exception as e:
                last_exception = e
                logger.warning(f"LLM call failed (attempt {attempt+1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 * (attempt + 1))
        
        logger.error(f"Max retries reached for LLM call: {last_exception}")
        raise last_exception

    @staticmethod
    def _format_decomposition_plan(decomposition: DecompositionResult) -> str:
        """
        Format a DecompositionResult into a readable plan string.
        
        Args:
            decomposition: The decomposition result with classes and relationships
            
        Returns:
            Formatted plan as a string
        """
        lines = ["## STRUCTURAL DECOMPOSITION\n"]
        
        # Format classes
        if decomposition.classes:
            lines.append("### Classes:")
            for cls in decomposition.classes:
                attrs_str = ", ".join([f"{attr.name}: {attr.type}" for attr in cls.attributes])
                lines.append(f"- {cls.name}" + (f" ({attrs_str})" if attrs_str else ""))
            lines.append("")
        
        # Format relationships
        if decomposition.relationships:
            lines.append("### Relationships:")
            for rel in decomposition.relationships:
                src_multiplicity_str = f" [{rel.source_multiplicity}]" if rel.source_multiplicity else ""
                tgt_multiplicity_str = f" [{rel.target_multiplicity}]" if rel.target_multiplicity else ""
                lines.append(f"- {rel.source}{src_multiplicity_str} --{rel.type}--> {rel.target}{tgt_multiplicity_str}")
        
        return "\n".join(lines)

    def retrieve(self, state: AgentState) -> Dict[str, Any]:
        """
        Retrieve relevant few-shot examples based on requirements.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'examples' key containing formatted shots
        """
        logger.info(f"--- NODE: {NodeNames.RETRIEVE.upper()} ---")
        
        try:
            memories = self.memory_manager.retrieve_similar_diagrams(
                state["requirements"],
                limit=self.config.num_few_shots
            )
            
            formatted_shots = []
            for mem in memories:
                formatted_shots.append(
                    HumanMessage(content=f"Requirements:\n{mem['requirements']}")
                )
                
                meta = mem.get("metadata", {})
                plan = meta.get("plan", "No plan available.")
                
                assistant_content = (
                    f"1. DESIGN PLAN:\n{plan}\n\n"
                    f"2. PLANTUML DIAGRAM:\n```plantuml\n{mem['diagram']}\n```"
                )
                
                formatted_shots.append(
                    AIMessage(content=assistant_content)
                )
            
            logger.info(f"Retrieved {len(memories)} relevant examples from unified memory")
            return {"examples": formatted_shots}
        except Exception as e:
            logger.error(f"Retrieval failed: {e}")
            return {"examples": []}

    def decompose(self, state: AgentState) -> Dict[str, Any]:
        """
        Decompose requirements into structural building blocks.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'plan' update containing formatted decomposition
        """
        logger.info(f"--- NODE: {NodeNames.DECOMPOSE.upper()} ---")

        critique = state.get("audit_feedback", []) or []
        suggestions = state.get("audit_suggestions", []) or []

        critique_str = "\n".join([f"- {c}" for c in critique]) if critique else "None"
        suggestions_str = "\n".join([f"- {s}" for s in suggestions]) if suggestions else "None"

        system_prompt = DECOMPOSER_SYSTEM
        if critique or suggestions:
            system_prompt += (
                "\n\nIMPORTANT: You must REVISE the previous plan to address the audit."
                "\n- Make the smallest changes needed."
                "\n- Do not drop correct elements."
                "\n\nAUDIT CRITIQUE:\n"
                f"{critique_str}"
                "\n\nAUDIT SUGGESTIONS:\n"
                f"{suggestions_str}"
            )
        
        messages = [SystemMessage(content=system_prompt)]
        messages.append(HumanMessage(content=f"REQUIREMENTS:\n{state['requirements']}"))

        previous_plan = state.get("plan")
        if previous_plan:
            messages.append(HumanMessage(content=f"PREVIOUS PLAN (revise this):\n{previous_plan}"))
        
        try:
            structured_llm = self.llm.bind(max_tokens=self.config.max_tokens_decompose).with_structured_output(DecompositionResult)
            decomposition: DecompositionResult = self._safe_invoke(
                structured_llm,
                messages
            )
            logger.info("Decomposition completed")
            formatted_plan = self._format_decomposition_plan(decomposition)
            
            return {"plan": formatted_plan}
            
        except Exception as e:
            logger.error(f"Decomposition failed: {e}")
            empty_decomposition = DecompositionResult(classes=[], relationships=[])
            formatted_plan = self._format_decomposition_plan(empty_decomposition)
            return {"plan": formatted_plan}
        
    def plan_auditor(self, state: AgentState) -> Dict[str, Any]:
        """
        Audits the structural plan for logical consistency and requirement coverage.
        
        Uses PlanAudit Pydantic model for structured output validation.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'plan_valid' and 'audit_feedback' updates
        """
        logger.info(f"--- NODE: {NodeNames.PLAN_AUDIT.upper()} ---")
        
        # Check if we've exceeded max plan audit attempts
        max_plan_audit_attempts = 3
        current_attempts = state.get("plan_audit_attempts", 0) + 1
        
        if current_attempts > max_plan_audit_attempts:
            logger.warning(f"Max plan audit attempts ({max_plan_audit_attempts}) reached. Forcing plan validation.")
            return {
                "plan_valid": True,
                "audit_feedback": [],
                "audit_suggestions": [],
                "plan_audit_attempts": current_attempts
            }
        
        plan = state.get("plan")
        requirements = state.get("requirements")
        req_len = len(requirements) if isinstance(requirements, str) else 0
        plan_len = len(plan) if isinstance(plan, str) else 0
        logger.debug(f"Plan audit inputs: requirements_len={req_len}, plan_len={plan_len}")
        logger.debug(f"Plan audit prompt template chars={len(PLAN_AUDITOR_SYSTEM)}")
        if not requirements:
            logger.warning("Plan audit received empty requirements")
        if not plan:
            logger.warning("Plan audit received empty plan")
        
        prompt = f"""
        REQUIREMENTS:
        {requirements}
        
        PROPOSED PLAN:
        {plan}
        """

        try:
            t0 = time.perf_counter()
            messages = [
                SystemMessage(content=PLAN_AUDITOR_SYSTEM),
                HumanMessage(content=prompt)
            ]
            # Questa sezione di codice si può probabilmente semplificare
            try:
                structured_llm = self.llm.with_structured_output(PlanAudit, include_raw=True)
                audit_payload = structured_llm.invoke(messages)
                audit_result: Optional[PlanAudit] = audit_payload.get("parsed")
                raw_msg = audit_payload.get("raw")
                parsing_error = audit_payload.get("parsing_error")
                if raw_msg is not None and getattr(raw_msg, 'content', None) is not None:
                    raw_text = str(raw_msg.content)
                    raw_preview = raw_text[:2000] + ("..." if len(raw_text) > 2000 else "")
                    logger.debug(f"Plan audit raw response (preview): {raw_preview}")
                if parsing_error is not None:
                    logger.warning(f"Plan audit parsing_error: {parsing_error}")
                if audit_result is None:
                    raise ValueError("Plan audit produced no parsed result")
            except TypeError:
                audit_result = self.llm.with_structured_output(PlanAudit).invoke(messages)
            elapsed = time.perf_counter() - t0
            logger.debug(f"Plan audit invoke took {elapsed:.2f}s")
            
            logger.info(f"Plan audit completed: valid={audit_result.is_valid}")
            logger.debug(f"Plan audit parsed: critique_count={len(audit_result.critique)}, suggestions_count={len(audit_result.suggestions)}")
            
            logger.info(f"Output from plan auditor: {audit_result.model_dump()}")
            
            #if not audit_result.is_valid:
            if audit_result.critique:
                logger.info(f"Audit issues (first 5): {', '.join(audit_result.critique[:5])}")
            if audit_result.suggestions:
                logger.info(f"Audit suggestions (first 5): {', '.join(audit_result.suggestions[:5])}")
            logger.debug(f"Plan audit full result: {audit_result.model_dump()}")
            
            return {
                "plan_valid": audit_result.is_valid,
                "audit_feedback": audit_result.critique,
                "audit_suggestions": audit_result.suggestions,
                "plan_audit_attempts": current_attempts
            }
            
        except Exception as e:
            logger.exception(f"Plan audit failed: {e}")
            return {
                "plan_valid": False,
                "audit_feedback": [f"Audit error: {str(e)}"],
                "audit_suggestions": [],
                "plan_audit_attempts": current_attempts  
            }

    def generate(self, state: AgentState) -> Dict[str, Any]:
        """
        Generate PlantUML diagram using chain-of-thought reasoning.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'current_diagram' and 'iterations' updates
        """
        logger.info(f"--- NODE: {NodeNames.GENERATE.upper()} ---")
        
        messages = [SystemMessage(content=GENERATOR_SYSTEM)]
        
        # Add few-shot examples if available
        if state.get("examples"):
            messages.extend(state["examples"])
            logger.debug(f"Added {len(state['examples'])} example messages")
            
        user_content = f"""
        # VALIDATED REQUIREMENTS
        {state['requirements']}

        # VALIDATED DESIGN PLAN
        {state['plan']}

        # TASK
        Render the PlantUML class diagram exactly from the design plan.
        """
        
        # Add syntax error feedback if we're retrying after syntax check failure
        if not state.get("syntax_valid", True) and state.get("error_message"):
            user_content += f"""
        
        # PREVIOUS ATTEMPT HAD SYNTAX ERROR
        {state['error_message']}
        
        Fix the syntax error and regenerate the diagram.
        """
            logger.info(f"Added syntax error feedback to generation prompt")
        
        messages.append(HumanMessage(content=user_content))

        logger.debug(f"Generation prompt messages: {messages}")
        
        try:
            response = self._safe_invoke(
                self.llm,
                messages,
                max_tokens=self.config.max_tokens_generate
            )
            diagram = self.plantuml_tool.extract_plantuml(response.content)
            
            logger.info(f"Generation completed (iteration {state['iterations'] + 1})")
            return {
                "current_diagram": diagram,
                "iterations": state["iterations"] + 1,
                "stagnant_count": 0  # Reset stagnation counter on new generation
            }
            
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {
                "current_diagram": f"Error: {str(e)}",
                "iterations": state["iterations"] + 1
            }

    def syntax_check(self, state: AgentState) -> Dict[str, Any]:
        """
        Validate PlantUML syntax through server.
        
        Args:
            state: Current workflow state
            
        Returns:
            Dict with 'syntax_valid' and optional 'error_message'
        """
        logger.info(f"--- NODE: {NodeNames.SYNTAX_CHECK.upper()} ---")
        
        try:
            result = self.plantuml_tool.check_syntax(
                state["current_diagram"],
                timeout=self.config.request_timeout
            )
            
            if result.is_valid:
                logger.info(f"Syntax valid. View at: {result.url}")
            else:
                logger.warning(f"Syntax error: {result.error}")
            
            return {
                "syntax_valid": result.is_valid,
                "error_message": result.error if not result.is_valid else None,
            }
            
        except Exception as e:
            logger.error(f"Syntax check failed: {e}")
            return {
                "syntax_valid": False,
                "error_message": f"Syntax check error: {str(e)}"
            }

    def critic(self, state: AgentState) -> Dict[str, Any]:
        """
        Critic node: produces a formal CritiqueReport, tracks semantic progress,
        and updates stagnation metrics.
        """
        logger.info(f"--- NODE: {NodeNames.CRITIC.upper()} ---")

        try:
            requirements = state["requirements"]
            diagram = self.plantuml_tool.extract_plantuml(state["current_diagram"])

            messages = [
                SystemMessage(content=CRITIC_SYSTEM),
                HumanMessage(content=json.dumps({
                    "requirements": requirements,
                    "diagram": diagram
                }))
            ]

            structured_llm = self.llm.bind(max_tokens=self.config.max_tokens_critique).with_structured_output(CritiqueReport)
            report: CritiqueReport = self._safe_invoke(structured_llm, messages)

            # ---- STAGNATION TRACKING ----
            prev_ids = set(state.get("previous_finding_ids", []))
            curr_ids = {f.id for f in report.findings}

            resolved = prev_ids - curr_ids
            new = curr_ids - prev_ids

            if not resolved and not new:
                stagnant_count = state.get("stagnant_count", 0) + 1
            else:
                stagnant_count = 0

            # ---- UPDATE SUMMARY ----
            summary: CritiqueSummary = CritiqueSummary(
                total_findings=len(report.findings),
                render_only=sum(1 for f in report.findings if f.fixability == Fixability.render_only),
                structure_change=sum(1 for f in report.findings if f.fixability == Fixability.structure_change),
                unfixable=sum(1 for f in report.findings if f.fixability == Fixability.unfixable),
                new_findings=len(new),
                resolved_findings=len(resolved)
            )

            logger.info(
                f"Critic findings: total={report.summary.total_findings}, "
                f"render_only={summary.render_only}, "
                f"structural={summary.structure_change}, "
                f"unfixable={summary.unfixable}, "
                f"new={summary.new_findings}, resolved={summary.resolved_findings}"
            )

            return {
                "current_validation": report,
                "previous_finding_ids": list(curr_ids),
                "stagnant_count": stagnant_count,
            }

        except Exception as e:
            logger.error(f"Critic node failed: {e}")
            return {
                "logic_valid": False,
                "current_validation": None,
                "error_message": str(e)
            }

    def reflector(self, state: AgentState) -> Dict[str, Any]:
        """
        Reflector node: fixes render-only issues in the current PlantUML diagram
        based on structured critique findings.

        Args:
            state: Current workflow state

        Returns:
            Dict with updated current_diagram and iterations
        """
        logger.info(f"--- NODE: {NodeNames.REFLECTOR.upper()} ---")

        try:
            # Extract findings flagged as render-only from current_validation
            current_validation = state.get("current_validation")
            findings = current_validation.findings if current_validation else []
            renderable_findings = [
                f for f in findings if f.fixability == Fixability.render_only
            ]

            # If no render-only findings, also check for errors to provide context
            if not renderable_findings:
                logger.info("No render-only issues to fix. Returning current diagram unchanged.")
                return {
                    "current_diagram": state["current_diagram"],
                    "iterations": state["iterations"] + 1,
                }

            reflector_input = {
                "diagram": state["current_diagram"],
                "findings": [f.model_dump() for f in renderable_findings],
            }

            system_msg = SystemMessage(content=REFLECTOR_SYSTEM)
            user_msg = HumanMessage(content=json.dumps(reflector_input))

            reflected_diagram_response = self._safe_invoke(
                self.llm,
                [system_msg, user_msg],
                max_tokens=self.config.max_tokens_reflect
            )

            reflected_diagram = self.plantuml_tool.extract_plantuml(reflected_diagram_response.content)

            logger.info(f"Reflector applied fixes for {len(reflector_input['findings'])} issues.")

            return {
                "current_diagram": reflected_diagram,
                "iterations": state["iterations"] + 1
            }

        except Exception as e:
            logger.error(f"Reflector node failed: {e}")
            return {
                "current_diagram": state.get("current_diagram"),
                "error_message": str(e),
                "iterations": state["iterations"] + 1
            }
    
    def structure_refiner(self, state: AgentState) -> Dict[str, Any]:
        """
        Structure Refiner node: applies guided structural fixes using
        expected_correction from critique findings.
        """
        logger.info(f"--- NODE: {NodeNames.STRUCTURE_REFINER.upper()} ---")

        try:
            validation = state.get("current_validation")
            if not validation:
                logger.warning("No validation report. Structure refiner is a no-op.")
                return {"current_diagram": state["current_diagram"]}

            findings = [
                f for f in validation.findings
                if f.fixability == Fixability.structure_change
            ]

            # Safety: every structural finding must have an expected correction
            for f in findings:
                if not f.expected_correction:
                    logger.error(
                        f"Structural finding missing expected_correction: {f.id}"
                    )
                    return {
                        "current_diagram": state["current_diagram"],
                        "error_message": f"Unfixable structural finding: {f.id}"
                    }

            if not findings:
                logger.info("No structural findings to apply.")
                return {"current_diagram": state["current_diagram"]}

            refiner_input = {
                "requirements": state["requirements"],
                "diagram": state["current_diagram"],
                "findings": [f.model_dump() for f in findings]
            }

            messages = [
                SystemMessage(content=STRUCTURE_REFINER_SYSTEM),
                HumanMessage(content=json.dumps(refiner_input))
            ]

            response = self._safe_invoke(
                self.llm,
                messages,
                max_tokens=self.config.max_tokens_refine
            )

            refined_diagram = self.plantuml_tool.extract_plantuml(response.content)

            logger.info(
                f"Structure refiner applied {len(findings)} guided corrections."
            )

            return {
                "current_diagram": refined_diagram,
                "iterations": state["iterations"] + 1
            }

        except Exception as e:
            logger.error(f"Structure refiner failed: {e}")
            return {
                "current_diagram": state.get("current_diagram"),
                "error_message": str(e)
            }


In [613]:
def update_iteration_metrics(state: AgentState, cfg: SystemConfig) -> None:
    """
    Update iteration-based metrics for stopping conditions:
    - Weighted score delta
    - Plateau detection
    - Stagnant iteration count
    
    Args:
        state: Current workflow state
        cfg: System configuration (plateau_window, plateau_threshold)
    """
    current_validation = state.get("current_validation")
    if not current_validation:
        state["delta_score"] = 0.0
        state["stagnant_count"] = state.get("stagnant_count", 0)
        return

    current_score = getattr(current_validation, "weighted_score", 0.0)
    state.setdefault("score_history", []).insert(0, current_score)  # newest first

    # Limit history to plateau window
    state["score_history"] = state["score_history"][:cfg.plateau_window]

    # Compute delta from previous iteration
    if len(state["score_history"]) > 1:
        delta = current_score - state["score_history"][1]
    else:
        delta = current_score  # first iteration

    state["delta_score"] = delta

    # Check for plateau
    plateau_detected = False
    if len(state["score_history"]) >= cfg.plateau_window:
        deltas = [
            abs(state["score_history"][i] - state["score_history"][i+1])
            for i in range(len(state["score_history"]) - 1)
        ]
        if all(d < cfg.plateau_threshold for d in deltas):
            plateau_detected = True

    # Update stagnant count
    if plateau_detected:
        state["stagnant_count"] = state.get("stagnant_count", 0) + 1
    else:
        state["stagnant_count"] = 0

    logger.debug(
        f"Iteration metrics updated: current_score={current_score:.2f}, "
        f"delta_score={delta:.2f}, stagnant_count={state['stagnant_count']}"
    )

In [None]:
def create_uml_graph(
    nodes: UMLNodes, 
    config: Optional[SystemConfig] = None
) -> Any:
    """
    Create the LangGraph workflow for UML diagram generation.
    
    Args:
        nodes: UMLNodes instance with all agent methods
        config: Optional system configuration
        
    Returns:
        Compiled LangGraph workflow
    """
    cfg = config or SystemConfig()
    logger.info("Creating UML generation workflow")
    
    workflow = StateGraph(AgentState)

    # Add all nodes
    workflow.add_node(NodeNames.RETRIEVE, nodes.retrieve)
    workflow.add_node(NodeNames.PLAN_AUDIT, nodes.plan_auditor)
    workflow.add_node(NodeNames.DECOMPOSE, nodes.decompose)
    workflow.add_node(NodeNames.GENERATE, nodes.generate)
    workflow.add_node(NodeNames.SYNTAX_CHECK, nodes.syntax_check)
    workflow.add_node(NodeNames.CRITIC, nodes.critic)
    workflow.add_node(NodeNames.REFLECTOR, nodes.reflector)
    workflow.add_node(NodeNames.STRUCTURE_REFINER, nodes.structure_refiner)
    
    logger.debug("Added 8 nodes to workflow")

    # Define edges
    workflow.add_edge(START, NodeNames.RETRIEVE)
    workflow.add_edge(NodeNames.RETRIEVE, NodeNames.DECOMPOSE)
    workflow.add_edge(NodeNames.DECOMPOSE, NodeNames.PLAN_AUDIT)

    def route_after_plan_audit(state: AgentState) -> str:
        """
        Route based on plan audit results.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name
        """
        if state.get("plan_valid", False):
            logger.debug("Routing: plan_audit -> generate")
            return NodeNames.GENERATE
            
        logger.debug("Routing: plan_audit -> decompose")
        return NodeNames.DECOMPOSE

    workflow.add_conditional_edges(
        NodeNames.PLAN_AUDIT, 
        route_after_plan_audit,
        {
            NodeNames.DECOMPOSE: NodeNames.DECOMPOSE,
            NodeNames.GENERATE: NodeNames.GENERATE
        }
    )

    workflow.add_edge(NodeNames.GENERATE, NodeNames.SYNTAX_CHECK)

    def route_after_syntax_check(state: AgentState) -> str:
        """
        Route based on syntax validation results and iteration limits.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name
        """
        if state.get("syntax_valid", False):
            logger.debug("Routing: syntax_check -> critic")
            return NodeNames.CRITIC
            
        if state["iterations"] >= cfg.max_iterations:
            logger.warning(f"Max iterations ({cfg.max_iterations}) reached during syntax check")
            return END
        
        logger.debug("Routing: syntax_check -> generate (syntax error)")
        return NodeNames.GENERATE

    workflow.add_conditional_edges(
        NodeNames.SYNTAX_CHECK, 
        route_after_syntax_check,
        {
            NodeNames.CRITIC: NodeNames.CRITIC,
            NodeNames.GENERATE: NodeNames.GENERATE,
            END: END
        }
    )

    def route_after_critic(state: AgentState) -> str:
        update_iteration_metrics(state, cfg)  # optional metric tracking

        validation = state.get("current_validation")
        if not validation:
            logger.error("No critic report; stopping workflow.")
            return END

        summary = validation.summary

        # Step 1: Fully valid diagram
        if validation.is_valid:
            logger.info("CRITIC passed: diagram is fully valid → END")
            return END

        # Step 2: Check for unfixable structural issues
        if summary.unfixable > 0:
            logger.warning("CRITIC detected unfixable issues → END")
            return END

        # Step 3: Check for stagnation
        if state.get("stagnant_count", 0) >= cfg.max_stagnant_iterations:
            logger.warning(f"Stagnation detected ({cfg.max_stagnant_iterations}) → END")
            return END

        # Step 4: Structural corrections required
        if summary.structure_change > 0:
            logger.debug("Routing: CRITIC → STRUCTURE_REFINER (structural issues exist)")
            return NodeNames.STRUCTURE_REFINER

        # Step 5: Render-only issues can be fixed
        if summary.render_only > 0:
            logger.debug("Routing: CRITIC → REFLECTOR (render-only issues exist)")
            return NodeNames.REFLECTOR

        # Safety net
        logger.warning("CRITIC findings not actionable; ending workflow")
        return END

    workflow.add_conditional_edges(
        NodeNames.CRITIC,
        route_after_critic,
        {
            NodeNames.REFLECTOR: NodeNames.REFLECTOR,
            NodeNames.STRUCTURE_REFINER: NodeNames.STRUCTURE_REFINER,
            END: END
        }
    )


    def route_after_reflector(state: AgentState) -> str:
        """
        Route after reflection based on iteration limits and stagnation.
        
        Args:
            state: Current workflow state
            
        Returns:
            Next node name
        """
        if state["iterations"] >= cfg.max_iterations:
            logger.warning(f"Max iterations ({cfg.max_iterations}) reached after reflection")
            return END
        logger.debug("Routing: reflect -> syntax_check")
        return NodeNames.SYNTAX_CHECK

    workflow.add_conditional_edges(
        NodeNames.REFLECTOR, 
        route_after_reflector,
        {
            NodeNames.SYNTAX_CHECK: NodeNames.SYNTAX_CHECK,
            END: END
        }
    )

    def route_after_structure_refiner(state: AgentState) -> str:
        if state["iterations"] >= cfg.max_iterations:
            logger.warning(f"Max iterations ({cfg.max_iterations}) reached after STRUCTURE_REFINER")
            return END

        # After structural corrections, syntax check first
        logger.debug("Routing: STRUCTURE_REFINER → SYNTAX_CHECK")
        return NodeNames.SYNTAX_CHECK

    workflow.add_conditional_edges(
        NodeNames.STRUCTURE_REFINER,
        route_after_structure_refiner,
        {
            NodeNames.SYNTAX_CHECK: NodeNames.SYNTAX_CHECK, 
            END: END
        }
    )
        
    logger.info("Workflow graph created successfully")
    return workflow.compile()


def create_initial_state(requirements: str) -> AgentState:
    """
    Create an initial state for the workflow.
    
    Args:
        requirements: Software requirements text
        
    Returns:
        Initial AgentState dictionary
    """
    return {
        "requirements": requirements,
        "plan": None,
        "examples": [],
        "current_diagram": None,
        "best_diagram": None,
        "summary": None,
        "syntax_valid": False,
        "logic_valid": False,
        "iterations": 0,
        "error_message": None,
        "failed_attempts": [],
        "stagnant_count": 0,
        "critique_cache": {},
        # Additional required fields
        "plan_valid": False,
        "audit_feedback": None,
        "plan_audit_attempts": 0,
        "best_score": 0.0,
        "best_code": "",
        "current_validation": None,
        "score_history": [],
        "delta_score": 0.0
    }


In [615]:
def load_test_exercises(json_path: str = "./../data/test_exercises.json") -> List[Dict[str, Any]]:
    """
    Load test exercises from JSON file.
    
    Args:
        json_path: Path to test exercises JSON
        
    Returns:
        List of exercise dictionaries
        
    Raises:
        FileNotFoundError: If file doesn't exist
        json.JSONDecodeError: If JSON is invalid
    """
    logger.info(f"Loading test exercises from {json_path}")
    
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Test exercises file not found: {json_path}")
    
    with open(json_path, "r", encoding="utf-8") as f:
        exercises = json.load(f)
    
    logger.info(f"Loaded {len(exercises)} test exercises")
    return exercises

try:
    test_exercises = load_test_exercises()
    print(f"Loaded {len(test_exercises)} test exercises")
except Exception as e:
    logger.error(f"Failed to load test exercises: {e}")
    test_exercises = []

2026-01-16 01:53:34,539 - __main__ - INFO - Loading test exercises from ./../data/test_exercises.json
2026-01-16 01:53:34,541 - __main__ - INFO - Loaded 8 test exercises


Loaded 8 test exercises


In [616]:
def initialize_system(
    config: Optional[SystemConfig] = None,
    enable_long_term_memory: bool = True
) -> Tuple[UMLNodes, Any, SystemConfig, Optional[MemoryManager]]:
    """
    Initialize all system components.
    
    Args:
        config: Optional system configuration
        enable_long_term_memory: Whether to enable long-term memory
        
    Returns:
        Tuple of (nodes, compiled_workflow, config, memory_manager)
    """
    cfg = config or SystemConfig()
    logger.info("="*60)
    logger.info("INITIALIZING UML GENERATION SYSTEM")
    logger.info("="*60)
    
    try:
        logger.info("Creating LLM connection...")
        llm = create_llm(cfg)
        
        logger.info("Initializing PlantUML tool...")
        puml_tool = PlantUMLTool(cfg.plantuml_host)
        
        memory_mgr = None
        if enable_long_term_memory:
            logger.info(f"Initializing long-term memory with {cfg.embedder_model}...")

            dims = 1024 if "large" in cfg.embedder_model.lower() else 384
            
            memory_mgr = MemoryManager(
                embedder=SentenceTransformer(cfg.embedder_model),
                db_path=cfg.db_path,
                embedding_dims=dims
            )

            seeded_count = seed_memory_from_shots(
                memory_manager=memory_mgr,
                shots_json_path=cfg.shots_json_path,
                force_reseed=True 
            )

            logger.info("Long-term memory (SQLite + sqlite-vec) enabled")

            if seeded_count > 0:
                logger.info(f"Seeded {seeded_count} few-shot examples into memory")
        else:
            logger.info("Long-term memory disabled")
        

        logger.info("Building LangGraph workflow...")
        nodes = UMLNodes(llm, puml_tool, memory_mgr, cfg)
        app = create_uml_graph(nodes, cfg)
        
        logger.info("="*60)
        logger.info("SYSTEM INITIALIZED SUCCESSFULLY")
        logger.info("="*60)
        
        return nodes, app, cfg, memory_mgr
        
    except Exception as e:
        logger.error(f"System initialization failed: {e}")
        raise


nodes, app, config, memory_manager = initialize_system(enable_long_term_memory=True)
print("\nSystem ready for diagram generation")
print(f"Long-term memory: {'ENABLED' if memory_manager else 'DISABLED'}")


2026-01-16 01:53:34,551 - __main__ - INFO - INITIALIZING UML GENERATION SYSTEM
2026-01-16 01:53:34,552 - __main__ - INFO - Creating LLM connection...
2026-01-16 01:53:34,552 - __main__ - INFO - Connecting to LMStudio at http://localhost:1234/v1
2026-01-16 01:53:34,552 - __main__ - INFO - Using model: mistralai/devstral-small-2-2512 (temp=0.15)
2026-01-16 01:53:34,556 - __main__ - INFO - Initializing PlantUML tool...
2026-01-16 01:53:34,556 - __main__ - INFO - PlantUML tool initialized with host: http://localhost:8080
2026-01-16 01:53:34,556 - __main__ - INFO - Initializing long-term memory with BAAI/bge-large-en-v1.5...
2026-01-16 01:53:34,560 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps
2026-01-16 01:53:34,561 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: BAAI/bge-large-en-v1.5
2026-01-16 01:53:34,923 - httpx - INFO - HTTP Request: HEAD https://huggingface.co/BAAI/bge-large-en-v1.5/resolve/main/modules.js

Loading weights:   0%|          | 0/391 [00:00<?, ?it/s]

BertModel LOAD REPORT from: BAAI/bge-large-en-v1.5
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
2026-01-16 01:53:37,288 - httpx - INFO - HTTP Request: HEAD https://huggingface.co/BAAI/bge-large-en-v1.5/resolve/main/tokenizer_config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-16 01:53:37,364 - httpx - INFO - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/BAAI/bge-large-en-v1.5/d4aa6901d3a41ba39fb536a557fa166f842b0e09/tokenizer_config.json "HTTP/1.1 200 OK"
2026-01-16 01:53:37,563 - httpx - INFO - HTTP Request: GET https://huggingface.co/api/models/BAAI/bge-large-en-v1.5/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
2026-01-16 01:53:37,766 - httpx - INFO - HTTP Request: GET https://huggingface.co/api/models/BAAI/bge-larg

Loading weights:   0%|          | 0/391 [00:00<?, ?it/s]

BertModel LOAD REPORT from: BAAI/bge-large-en-v1.5
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
2026-01-16 01:54:57,010 - httpx - INFO - HTTP Request: HEAD https://huggingface.co/BAAI/bge-large-en-v1.5/resolve/main/tokenizer_config.json "HTTP/1.1 307 Temporary Redirect"
2026-01-16 01:54:57,084 - httpx - INFO - HTTP Request: HEAD https://huggingface.co/api/resolve-cache/models/BAAI/bge-large-en-v1.5/d4aa6901d3a41ba39fb536a557fa166f842b0e09/tokenizer_config.json "HTTP/1.1 200 OK"
2026-01-16 01:54:57,282 - httpx - INFO - HTTP Request: GET https://huggingface.co/api/models/BAAI/bge-large-en-v1.5/tree/main/additional_chat_templates?recursive=false&expand=false "HTTP/1.1 404 Not Found"
2026-01-16 01:54:57,510 - httpx - INFO - HTTP Request: GET https://huggingface.co/api/models/BAAI/bge-larg


System ready for diagram generation
Long-term memory: ENABLED


In [617]:
def run_single_test(
    app: Any,
    requirements: str,
    exercise_name: str = "Test Exercise"
) -> AgentState:
    """
    Run the workflow on a single exercise.
    
    Args:
        app: Compiled LangGraph workflow
        requirements: Software requirements text
        exercise_name: Name for logging purposes
        
    Returns:
        Final workflow state
    """
    logger.info("="*60)
    logger.info(f"RUNNING: {exercise_name}")
    logger.info("="*60)
    logger.info(f"Requirements preview: {requirements[:150]}...")
    
    initial_state = create_initial_state(requirements)
    
    try:
        final_output = app.invoke(initial_state, config={"recursion_limit": 50})
        
        logger.info("="*60)
        logger.info("WORKFLOW COMPLETED")
        logger.info("="*60)
        logger.info(f"Iterations: {final_output['iterations']}")
        logger.info(f"Syntax Valid: {final_output['syntax_valid']}")
        logger.info(f"Logic Valid: {final_output['logic_valid']}")
        
        if final_output.get('best_diagram') and not final_output['logic_valid']:
            if final_output['best_diagram'] != final_output['current_diagram']:
                logger.info("Using BEST diagram instead of final (prevented regression)")
                final_output['current_diagram'] = final_output['best_diagram']
        
        return final_output
        
    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        raise


# Select and run a test exercise
test_idx = 2
requirements = test_exercises[test_idx]["requirements"]

final_output = run_single_test(
    app, 
    requirements, 
    f"Exercise {test_idx + 1}"
)


print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Iterations: {final_output['iterations']}")
print(f"Syntax Valid: {final_output['syntax_valid']}")
print(f"Logic Valid: {final_output['logic_valid']}")

if final_output['current_diagram']:
    puml_tool = PlantUMLTool(config.plantuml_host)
    diagram_url = puml_tool.get_diagram_url(final_output['current_diagram'])
    print(f"\nDiagram URL: {diagram_url}")
    
    print("\nGenerated Diagram:")
    print(final_output['current_diagram']) 


2026-01-16 01:55:09,785 - __main__ - INFO - RUNNING: Exercise 3
2026-01-16 01:55:09,786 - __main__ - INFO - Requirements preview: A library system manages books, members, and loans.
Books have an ISBN, title, author, publisher, publication year, and availability status.
Members h...
2026-01-16 01:55:09,800 - __main__ - INFO - --- NODE: RETRIEVE ---
2026-01-16 01:55:13,315 - __main__ - INFO - Retrieved 3 similar diagrams from SQLite
2026-01-16 01:55:13,321 - __main__ - INFO - Retrieved 3 relevant examples from unified memory
2026-01-16 01:55:13,324 - __main__ - INFO - --- NODE: DECOMPOSE ---
2026-01-16 01:56:26,819 - httpx - INFO - HTTP Request: POST http://localhost:1234/v1/chat/completions "HTTP/1.1 200 OK"
2026-01-16 01:56:26,862 - __main__ - INFO - Decomposition completed
2026-01-16 01:56:26,864 - __main__ - INFO - --- NODE: PLAN_AUDIT ---
2026-01-16 01:56:40,244 - httpx - INFO - HTTP Request: POST http://localhost:1234/v1/chat/completions "HTTP/1.1 200 OK"
2026-01-16 01:56:40,272 -


FINAL RESULTS
Iterations: 6
Syntax Valid: False
Logic Valid: False

Diagram URL: http://localhost:8080/png/RP71QiCm38RlUGhJuw27lS0ePHGAsovz0YLM6rFR2hOqCDlUVRBJP5lPoRzawTDlUPSP4almtU0XPxCdkfgFpfZZQV-c1plsg2S8ZvHKJD9xbqTSzG3iA9g2K5Fm3iv3xxpZOfJDahkl6_iLGu-fqKEJUNIJEJvh726qATOcpBcYHufeejGo3HDUNEOqZxB0k49FB1OZDsg-wOQS4bqI14FdbYzhnm46yLn-kS4mUIg8SwF5ILSI5BIsl134LZcCE5n9UFQWjOABJCYIximTd3wYzfshW42YKBK6fU9_zhCyOlEuCTljvzEoAesxltR_grZaGxsQlQhksbVU9tSOhVxsDm==

Generated Diagram:
@startuml
class Book {
  ISBN
  title
  author
  publisher
  publication year
  availability status
}
class Member {
  membership ID
  name
  address
  phone number
  registration date
}
class Student {
  student ID
  program of study
}
class FacultyMember {
  employee ID
  department
}
class Loan {
  checkout date
  due date
  return date
}
class Fine {
  fine amount
  payment status
}
Student "1" --|> Member "1"
FacultyMember "1" --|> Member "1"
Member "*" -- Loan "*"
Book "*" -- Loan "*"
@enduml


In [618]:
print(final_output['current_diagram'])

@startuml
class Book {
  ISBN
  title
  author
  publisher
  publication year
  availability status
}
class Member {
  membership ID
  name
  address
  phone number
  registration date
}
class Student {
  student ID
  program of study
}
class FacultyMember {
  employee ID
  department
}
class Loan {
  checkout date
  due date
  return date
}
class Fine {
  fine amount
  payment status
}
Student "1" --|> Member "1"
FacultyMember "1" --|> Member "1"
Member "*" -- Loan "*"
Book "*" -- Loan "*"
@enduml


In [619]:
class EvaluationMetrics(BaseModel):
    """Container for evaluation metrics."""
    precision: float = Field(ge=0.0, le=1.0, description="Precision score")
    recall: float = Field(ge=0.0, le=1.0, description="Recall score")
    f1: float = Field(ge=0.0, le=1.0, description="F1 score")
    
    def __str__(self) -> str:
        return f"P={self.precision:.2f}, R={self.recall:.2f}, F1={self.f1:.2f}"


class PlantUMLParser:
    """
    Parser for extracting structured information from PlantUML diagrams.
    
    Extracts classes, attributes, and relationships from PlantUML code
    for evaluation purposes.
    """
    
    def __init__(self, plantuml_code: str):
        """
        Initialize parser with PlantUML code.
        
        Args:
            plantuml_code: PlantUML diagram code
        """
        self.plantuml_code = plantuml_code
        self.classes: Dict[str, Dict[str, List[str]]] = {}
        self.relationships: List[Dict[str, Any]] = []
        self.parse()
    
    def parse(self) -> None:
        """Parse the PlantUML code."""
        try:
            self._extract_classes()
            self._extract_relationships()
            logger.debug(f"Parsed {len(self.classes)} classes and {len(self.relationships)} relationships")
        except Exception as e:
            logger.error(f"Parsing failed: {e}")
    
    def _extract_classes(self) -> None:
        """Extract class definitions and their attributes."""
        class_pattern = r'class\s+(\w+)\s*\{([^}]*)\}'
        matches = re.finditer(class_pattern, self.plantuml_code, re.MULTILINE | re.DOTALL)
        
        for match in matches:
            class_name = match.group(1)
            class_body = match.group(2)
            
            attributes = []
            for line in class_body.strip().split('\n'):
                line = line.strip()
                if line and not line.startswith('--'):
                    attributes.append(line)
            
            self.classes[class_name] = {'attributes': attributes}
    
    def _extract_relationships(self) -> None:
        """Extract relationships between classes with cardinalities."""
        patterns = [
            # Generalization (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*<\|--\s*(?:"([^"]*)")?\s*(\w+)', 'generalization'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*--|>\s*(?:"([^"]*)")?\s*(\w+)', 'generalization'),
            
            # Composition (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*\*--\s*(?:"([^"]*)")?\s*(\w+)', 'composition'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*--\*\s*(?:"([^"]*)")?\s*(\w+)', 'composition'),
            
            # Aggregation (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*o--\s*(?:"([^"]*)")?\s*(\w+)', 'aggregation'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*--o\s*(?:"([^"]*)")?\s*(\w+)', 'aggregation'),
            
            # Directed Association (either direction)
            (r'(\w+)\s*(?:"([^"]*)")?\s*-->\s*(?:"([^"]*)")?\s*(\w+)', 'association'),
            (r'(\w+)\s*(?:"([^"]*)")?\s*<--\s*(?:"([^"]*)")?\s*(\w+)', 'association'),
            
            # Simple Association (no arrow)
            (r'(\w+)\s*(?:"([^"]*)")?\s*--\s*(?:"([^"]*)")?\s*(\w+)', 'association'),
        ]
        
        for pattern, rel_type in patterns:
            for match in re.finditer(pattern, self.plantuml_code):
                source = match.group(1)
                target = match.group(4)
                
                # Skip if source or target is None or empty
                if not source or not target:
                    continue
                
                self.relationships.append({
                    'type': rel_type,
                    'source': source,
                    'target': target,
                    'cardinality_source': match.group(2) if match.lastindex >= 2 else None,
                    'cardinality_target': match.group(3) if match.lastindex >= 3 else None
                })


class DiagramEvaluator:
    """
    Evaluator for comparing generated diagrams against gold standards.
    
    Computes precision, recall, and F1 scores for classes, attributes,
    and relationships.
    """
    
    def __init__(self, gold_plantuml: str, pred_plantuml: str):
        """
        Initialize evaluator with gold and predicted diagrams.
        
        Args:
            gold_plantuml: Gold standard PlantUML code
            pred_plantuml: Predicted PlantUML code
        """
        self.gold_parser = PlantUMLParser(gold_plantuml)
        self.pred_parser = PlantUMLParser(pred_plantuml)
    
    def _normalize_attr(self, attr_str: str) -> str:
        """Normalize attribute strings for comparison."""
        return attr_str.split(':')[0].strip().lower()
    
    def _normalize_rel_type(self, rel_type: str) -> str:
        """Normalize relationship types."""
        mapping = {
            '<|--': 'INHERITANCE',
            '--|>': 'INHERITANCE',
            '*--': 'COMPOSITION',
            '--*': 'COMPOSITION',
            'o--': 'AGGREGATION',
            '--o': 'AGGREGATION',
            '--': 'ASSOCIATION',
            '<--': 'ASSOCIATION',
            '-->': 'ASSOCIATION'
        }
        return mapping.get(rel_type, 'ASSOCIATION')
    
    def _calculate_metrics(
        self, 
        gold_set: set, 
        pred_set: set
    ) -> EvaluationMetrics:
        """
        Calculate precision, recall, and F1 scores.
        
        Args:
            gold_set: Set of gold standard elements
            pred_set: Set of predicted elements
            
        Returns:
            EvaluationMetrics object
        """
        tp = len(gold_set.intersection(pred_set))
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        
        return EvaluationMetrics(
            precision=round(precision, 2),
            recall=round(recall, 2),
            f1=round(f1, 2)
        )
    
    def get_metrics(self) -> Dict[str, EvaluationMetrics]:
        """
        Get all evaluation metrics.
        
        Returns:
            Dictionary with metrics for classes, attributes, and relationships
        """
        # Classes 
        gold_classes = {c.lower() for c in self.gold_parser.classes.keys()}
        pred_classes = {c.lower() for c in self.pred_parser.classes.keys()}
        
        # Attributes 
        gold_attrs = set()
        for cls, info in self.gold_parser.classes.items():
            for attr in info['attributes']:
                gold_attrs.add((cls.lower(), self._normalize_attr(attr)))
        
        pred_attrs = set()
        for cls, info in self.pred_parser.classes.items():
            for attr in info['attributes']:
                pred_attrs.add((cls.lower(), self._normalize_attr(attr)))
        
        # Relationships (type + direction only)
        gold_rels = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']))
            for r in self.gold_parser.relationships
            if r.get('source') and r.get('target')
        }
        pred_rels = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']))
            for r in self.pred_parser.relationships
            if r.get('source') and r.get('target')
        }
        
        # Cardinalities 
        gold_rels_card = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']),
             (r['cardinality_source'] or '').strip(), (r['cardinality_target'] or '').strip())
            for r in self.gold_parser.relationships
            if r.get('source') and r.get('target')
        }
        pred_rels_card = {
            (r['source'].lower(), r['target'].lower(), self._normalize_rel_type(r['type']),
             (r['cardinality_source'] or '').strip(), (r['cardinality_target'] or '').strip())
            for r in self.pred_parser.relationships
            if r.get('source') and r.get('target')
        }
        
        result = {
            "classes": self._calculate_metrics(gold_classes, pred_classes),
            "attributes": self._calculate_metrics(gold_attrs, pred_attrs),
            "relationships": self._calculate_metrics(gold_rels, pred_rels),
            "cardinalities": self._calculate_metrics(gold_rels_card, pred_rels_card)
        }
        
        return result


def evaluate_diagram(
    gold_standard: str,
    generated_diagram: str
) -> Dict[str, EvaluationMetrics]:
    """
    Evaluate a generated diagram against gold standard.
    
    Args:
        gold_standard: Gold standard PlantUML code
        generated_diagram: Generated PlantUML code
        
    Returns:
        Dictionary of evaluation metrics
    """
    evaluator = DiagramEvaluator(gold_standard, generated_diagram)
    return evaluator.get_metrics()

In [620]:
gold_standard = test_exercises[test_idx]["solution_plantuml"]
generated_diagram = final_output["current_diagram"]

metrics = evaluate_diagram(gold_standard, generated_diagram)

print("="*60)
print("EVALUATION METRICS")
print("="*60)
print(f"\nClasses:       {metrics['classes']}")
print(f"Attributes:    {metrics['attributes']}")
print(f"Relationships: {metrics['relationships']}")
print(f"Cardinalities: {metrics['cardinalities']}")

weighted_avg_f1 = (
    metrics['classes'].f1 * 0.4 + 
    metrics['attributes'].f1 * 0.3 + 
    metrics['relationships'].f1 * 0.3
    # metrics['cardinalities'].f1 * 0.2
)

print(f"\n{'='*60}")
print(f"OVERALL F1 SCORE: {weighted_avg_f1:.2f}")
print(f"{'='*60}")


EVALUATION METRICS

Classes:       P=0.83, R=0.71, F1=0.77
Attributes:    P=0.30, R=0.27, F1=0.29
Relationships: P=1.00, R=0.29, F1=0.44
Cardinalities: P=0.00, R=0.00, F1=0.00

OVERALL F1 SCORE: 0.53
