In [1]:
# Install required dependencies (run this first in Colab)
!pip install opentelemetry-api opentelemetry-sdk

import asyncio, warnings, copy, time, json, uuid
from datetime import datetime
from typing import Dict, Any, List, Optional
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
from opentelemetry.trace import Status, StatusCode

# Set up OpenTelemetry
# Install an SDK TracerProvider as the global provider, then create the
# module-level `tracer` that the audited node/flow classes below use to open
# their spans.
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)
# Attach a console exporter to whatever provider is currently registered, so
# spans are printed as JSON to the cell output.
# NOTE(review): re-running this cell in the same kernel presumably adds a
# second processor to the already-installed provider (duplicated span
# output) -- confirm against the OpenTelemetry SDK behaviour.
span_processor = SimpleSpanProcessor(ConsoleSpanExporter())
trace.get_tracer_provider().add_span_processor(span_processor)

# ==============================================================================
# ORIGINAL BASE FRAMEWORK (UNTOUCHED FOR REFERENCE)
# ==============================================================================

class BaseNode:
    """Minimal unit of work with a prep -> exec -> post lifecycle.

    Each node carries a ``params`` dict and a ``successors`` mapping from an
    action name to the node that follows for that action. Subclasses override
    ``prep``, ``exec`` and ``post``; the defaults do nothing and return None.
    """

    def __init__(self):
        self.params = {}
        self.successors = {}

    def set_params(self, params):
        """Replace this node's parameters with ``params``."""
        self.params = params

    def next(self, node, action="default"):
        """Register ``node`` as the successor for ``action`` and return it.

        A warning is emitted when ``action`` already had a successor; the new
        node replaces the old one.
        """
        already_registered = action in self.successors
        if already_registered:
            warnings.warn(f"Overwriting successor for action '{action}'")
        self.successors[action] = node
        return node

    def prep(self, shared):
        """Hook: derive the input for ``exec`` from ``shared``."""
        return None

    def exec(self, prep_res):
        """Hook: do the node's work on the value returned by ``prep``."""
        return None

    def post(self, shared, prep_res, exec_res):
        """Hook: store results in ``shared`` and return the next action."""
        return None

    def _exec(self, prep_res):
        """Indirection point that subclasses wrap (retries, batching)."""
        return self.exec(prep_res)

    def _run(self, shared):
        """Run the three lifecycle phases and return what ``post`` returns."""
        prepared = self.prep(shared)
        executed = self._exec(prepared)
        return self.post(shared, prepared, executed)

    def run(self, shared):
        """Run this node alone; warn when successors exist, as they are ignored."""
        if self.successors:
            warnings.warn("Node won't run successors. Use Flow.")
        return self._run(shared)

    def __rshift__(self, other):
        """``a >> b`` registers ``b`` as the default successor of ``a``."""
        return self.next(other)

    def __sub__(self, action):
        """``a - "name"`` starts a conditional transition for that action."""
        if not isinstance(action, str):
            raise TypeError("Action must be a string")
        return _ConditionalTransition(self, action)

class _ConditionalTransition:
    """Pending ``src - "action"`` expression waiting for its ``>> target``."""

    def __init__(self, src, action):
        self.src = src
        self.action = action

    def __rshift__(self, tgt):
        """Register ``tgt`` on the source node under the stored action."""
        source, action = self.src, self.action
        return source.next(tgt, action)

class Node(BaseNode):
    """Node whose ``exec`` is retried up to ``max_retries`` times.

    ``wait`` is the number of seconds slept between failed attempts. The
    index of the attempt in progress is exposed as ``self.cur_retry``.
    """

    def __init__(self, max_retries=1, wait=0):
        super().__init__()
        self.max_retries = max_retries
        self.wait = wait

    def exec_fallback(self, prep_res, exc):
        """Called after the final failed attempt; the default re-raises."""
        raise exc

    def _exec(self, prep_res):
        """Call ``exec`` until it succeeds or the attempts are used up.

        Returns the first successful ``exec`` result, or the value of
        ``exec_fallback`` once the last attempt has failed. With
        ``max_retries`` of zero or less the loop body never runs and None
        is returned.
        """
        for attempt in range(self.max_retries):
            self.cur_retry = attempt
            try:
                return self.exec(prep_res)
            except Exception as err:
                out_of_attempts = self.cur_retry == self.max_retries - 1
                if out_of_attempts:
                    return self.exec_fallback(prep_res, err)
                if self.wait > 0:
                    time.sleep(self.wait)

class BatchNode(Node):
    """Node that applies the retrying ``Node._exec`` to each item in turn."""

    def _exec(self, items):
        """Return the per-item results as a list; a falsy ``items`` yields []."""
        results = []
        for item in items or []:
            results.append(super()._exec(item))
        return results

class Flow(BaseNode):
    """Runs a chain of nodes, following the action each node returns."""

    def __init__(self, start=None):
        super().__init__()
        self.start_node = start

    def start(self, start):
        """Set the first node of the flow and return it for chaining."""
        self.start_node = start
        return start

    def get_next_node(self, curr, action):
        """Look up the successor of ``curr`` for ``action``.

        A falsy action is treated as "default". When no successor matches
        but the node does have successors, a warning is emitted.
        """
        key = action or "default"
        successor = curr.successors.get(key)
        if curr.successors and not successor:
            warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
        return successor

    def _orch(self, shared, params=None):
        """Walk the node graph from the start node and return the last action.

        Every node is shallow-copied before it runs, so the run operates on
        copies rather than on the registered node objects.
        """
        curr = copy.copy(self.start_node)
        p = params or {**self.params}
        last_action = None
        while curr:
            curr.set_params(p)
            last_action = curr._run(shared)
            following = self.get_next_node(curr, last_action)
            curr = copy.copy(following)
        return last_action

    def _run(self, shared):
        """Run prep, the orchestration loop, then post."""
        prepared = self.prep(shared)
        outcome = self._orch(shared)
        return self.post(shared, prepared, outcome)

    def post(self, shared, prep_res, exec_res):
        """By default a flow returns the last action produced by its nodes."""
        return exec_res

# ==============================================================================
# AUDIT LOGGER
# ==============================================================================

class AuditLogger:
    """In-memory audit trail of node events and flow transitions.

    Each instance represents one session: it gets a random ``session_id``
    and records its creation time. Events and edges are appended to plain
    lists and can be summarised or written out as JSON.
    """

    def __init__(self):
        self.events = []  # event dicts, in the order log_event() was called
        self.edges = []   # transition dicts, in the order log_edge() was called
        self.session_id = str(uuid.uuid4())
        self.start_time = datetime.now()

    def log_event(self, node_name: str, event_type: str, **kwargs):
        """Log a node execution event.

        Args:
            node_name: Name of the node (or flow) the event belongs to.
            event_type: Event label, e.g. "node_start" or "node_error".
            **kwargs: Extra fields stored on the event. They are merged last,
                so a key such as ``timestamp`` replaces the generated value.
        """
        event = {
            "session_id": self.session_id,
            "timestamp": datetime.now().isoformat(),
            "node_name": node_name,
            "event_type": event_type,
            "event_id": str(uuid.uuid4()),
            **kwargs
        }
        self.events.append(event)

    def log_edge(self, from_node: str, to_node: str, action: str = "default"):
        """Log a flow transition from ``from_node`` to ``to_node`` via ``action``."""
        edge = {
            "session_id": self.session_id,
            "timestamp": datetime.now().isoformat(),
            "from_node": from_node,
            "to_node": to_node,
            "action": action,
            "edge_id": str(uuid.uuid4())
        }
        self.edges.append(edge)

    def _counts(self) -> Dict:
        """Return the counters shared by ``export_logs`` and ``get_summary``."""
        return {
            "total_events": len(self.events),
            "total_edges": len(self.edges),
            "unique_nodes": len({e["node_name"] for e in self.events}),
        }

    def export_logs(self, filename: Optional[str] = None) -> Dict:
        """Export logs to JSON format.

        Args:
            filename: Destination path. Defaults to
                ``audit_log_<first 8 chars of session_id>.json`` in the
                current working directory.

        Returns:
            The exported structure (session info, events, edges, summary).

        Raises:
            TypeError: If an event or edge holds a value that ``json`` cannot
                serialise. Nothing is written in that case.
        """
        if filename is None:
            filename = f"audit_log_{self.session_id[:8]}.json"

        log_data = {
            "session_id": self.session_id,
            "start_time": self.start_time.isoformat(),
            "end_time": datetime.now().isoformat(),
            "events": self.events,
            "edges": self.edges,
            "summary": self._counts(),
        }

        # Serialise before touching the file so that a failure cannot leave
        # a truncated, invalid JSON file behind.
        serialized = json.dumps(log_data, indent=2)
        with open(filename, 'w', encoding="utf-8") as f:
            f.write(serialized)

        return log_data

    def get_summary(self) -> Dict:
        """Get a summary of the audit log, including seconds since creation."""
        return {
            "session_id": self.session_id,
            **self._counts(),
            "duration_seconds": (datetime.now() - self.start_time).total_seconds()
        }

# ==============================================================================
# AUDITED NODE CLASSES
# ==============================================================================

class AuditedBaseNode(BaseNode):
    """BaseNode that records audit events and OpenTelemetry spans when run.

    Args:
        name: Label used in events and span names. Defaults to
            ``<ClassName>_<id(self)>``.
        logger: AuditLogger that receives the events; None disables event
            logging (spans are still produced).
    """

    def __init__(self, name: Optional[str] = None, logger: Optional[AuditLogger] = None):
        super().__init__()
        self.name = name or f"{self.__class__.__name__}_{id(self)}"
        self.logger = logger

    @staticmethod
    def _display_name(node) -> str:
        """Return ``node.name`` when the node has one, otherwise ``str(node)``."""
        return node.name if hasattr(node, 'name') else str(node)

    def _run_with_audit(self, shared):
        """Run node with full auditing and tracing.

        Opens a ``node.<name>`` span with child spans for the prep, exec and
        post phases, each carrying its latency in milliseconds. Logs
        ``node_start`` and then either ``node_success`` or ``node_error``.

        Returns:
            Whatever ``post`` returns (the action for the flow to follow).

        Raises:
            Exception: Any exception from prep/exec/post is logged, recorded
                on the span and re-raised unchanged.
        """
        # perf_counter is monotonic, so latencies cannot go negative if the
        # wall clock is adjusted mid-run.
        start_time = time.perf_counter()

        with tracer.start_as_current_span(f"node.{self.name}") as span:
            span.set_attribute("node.name", self.name)
            span.set_attribute("node.type", self.__class__.__name__)

            try:
                # Log start
                if self.logger:
                    self.logger.log_event(
                        self.name, "node_start",
                        node_type=self.__class__.__name__,
                        shared_keys=list(shared.keys()) if shared else []
                    )

                # Prep phase
                with tracer.start_as_current_span("prep") as prep_span:
                    prep_start = time.perf_counter()
                    prep_res = self.prep(shared)
                    prep_latency = (time.perf_counter() - prep_start) * 1000
                    prep_span.set_attribute("latency_ms", prep_latency)

                # Exec phase
                with tracer.start_as_current_span("exec") as exec_span:
                    exec_start = time.perf_counter()
                    exec_res = self._exec(prep_res)
                    exec_latency = (time.perf_counter() - exec_start) * 1000
                    exec_span.set_attribute("latency_ms", exec_latency)
                    # cur_retry only exists on Node subclasses once their
                    # retry loop has run; it is a zero-based attempt index.
                    if hasattr(self, 'cur_retry'):
                        exec_span.set_attribute("retries", self.cur_retry + 1)

                # Post phase
                with tracer.start_as_current_span("post") as post_span:
                    post_start = time.perf_counter()
                    result = self.post(shared, prep_res, exec_res)
                    post_latency = (time.perf_counter() - post_start) * 1000
                    post_span.set_attribute("latency_ms", post_latency)

                total_latency = (time.perf_counter() - start_time) * 1000
                span.set_attribute("total_latency_ms", total_latency)
                span.set_status(Status(StatusCode.OK))

                # Log success
                if self.logger:
                    self.logger.log_event(
                        self.name, "node_success",
                        prep_latency_ms=prep_latency,
                        exec_latency_ms=exec_latency,
                        post_latency_ms=post_latency,
                        total_latency_ms=total_latency,
                        result_type=type(result).__name__,
                        # Nodes without a retry loop count as one attempt.
                        retries=getattr(self, 'cur_retry', 0) + 1
                    )

                return result

            except Exception as e:
                span.set_status(Status(StatusCode.ERROR, str(e)))
                span.set_attribute("error.type", type(e).__name__)
                span.set_attribute("error.message", str(e))

                if self.logger:
                    self.logger.log_event(
                        self.name, "node_error",
                        error_type=type(e).__name__,
                        error_message=str(e),
                        total_latency_ms=(time.perf_counter() - start_time) * 1000
                    )
                raise

    def _orch_with_audit(self, shared, params=None):
        """Orchestrate node execution with proper edge logging and auditing.

        Reads ``self.start_node``, so it is only usable on objects that also
        inherit from Flow (which sets that attribute). Unlike ``Flow._orch``
        the nodes are run in place, not copied.

        Nodes that expose ``_run_with_audit`` are run through it; any other
        node falls back to its plain ``_run``. Such nodes need not have a
        ``name``; ``str(node)`` is used for them in messages and edges.

        Returns:
            The action returned by the last node that ran, or None when the
            flow has no start node.
        """
        curr = self.start_node
        p = params or {**self.params}
        last_action = None

        print(f"🔍 FLOW DEBUG: Starting orchestration with node: {self._display_name(curr)}")

        while curr:
            # Resolve the label once: plain (non-audited) nodes have no
            # `name`, and reading curr.name directly would raise for them.
            curr_name = self._display_name(curr)
            print(f"🔍 FLOW DEBUG: Current node: {curr_name}")

            # Ensure logger is attached to current node. A flow without a
            # logger must not strip the logger a node already carries.
            if self.logger is not None and hasattr(curr, 'logger'):
                curr.logger = self.logger

            curr.set_params(p)

            # Run the node with proper audit logging
            if hasattr(curr, '_run_with_audit'):
                print(f"🔍 FLOW DEBUG: Running {curr_name} with _run_with_audit")
                last_action = curr._run_with_audit(shared)
            else:
                print(f"🔍 FLOW DEBUG: Running {curr_name} with _run (fallback)")
                last_action = curr._run(shared)

            print(f"🔍 FLOW DEBUG: Node {curr_name} returned action: {last_action}")

            # Get next node
            next_node = curr.successors.get(last_action or "default")
            next_name = self._display_name(next_node)
            print(f"🔍 FLOW DEBUG: Next node for action '{last_action}': {next_name}")

            # Log the edge BEFORE moving to next node
            if next_node and self.logger:
                print(f"🔍 FLOW DEBUG: Logging edge: {curr_name} -> {next_name}")
                self.logger.log_edge(
                    from_node=curr_name,
                    to_node=next_name,
                    action=last_action or "default"
                )

            # Check for flow end
            if not next_node and curr.successors:
                warnings.warn(f"Flow ends: '{last_action}' not found in {list(curr.successors)}")

            # Move to next node
            curr = next_node

        print(f"🔍 FLOW DEBUG: Orchestration complete, final action: {last_action}")
        return last_action

    def run(self, shared):
        """Run this node alone with auditing; successors are not followed."""
        if self.successors:
            warnings.warn("Node won't run successors. Use Flow.")
        return self._run_with_audit(shared)

class AuditedNode(AuditedBaseNode, Node):
    """Retrying ``Node`` whose every run is audited and traced."""

    def __init__(self, name: str = None, logger: AuditLogger = None, max_retries=1, wait=0):
        # Both parent initialisers are invoked explicitly, audit state first
        # and retry settings second.
        AuditedBaseNode.__init__(self, name=name, logger=logger)
        Node.__init__(self, max_retries=max_retries, wait=wait)

    def _run(self, shared):
        """Route the plain ``_run`` entry point through the audited path."""
        audited_result = self._run_with_audit(shared)
        return audited_result

class AuditedBatchNode(AuditedNode, BatchNode):
    """Audited node whose ``exec`` runs once per item returned by ``prep``."""

    def __init__(self, name: str = None, logger: AuditLogger = None, max_retries=1, wait=0):
        # All construction is handled by AuditedNode.
        super().__init__(
            name=name,
            logger=logger,
            max_retries=max_retries,
            wait=wait,
        )

class AuditedFlow(AuditedBaseNode, Flow):
    """Flow that audits and traces itself and the transitions between nodes.

    Args:
        name: Label used in events and span names.
        logger: AuditLogger receiving flow events and edges.
        start: Optional first node of the flow.
    """

    def __init__(self, name: Optional[str] = None, logger: Optional[AuditLogger] = None, start=None):
        AuditedBaseNode.__init__(self, name, logger)
        Flow.__init__(self, start)

    def run(self, shared):
        """Override to use orchestration instead of treating flow as single node"""
        if self.successors:
            warnings.warn("Node won't run successors. Use Flow.")

        # For flows, we want orchestration, not single-node execution
        return self._run_flow_with_audit(shared)

    def _run(self, shared):
        """Audited entry point used when this flow is a node of a plain Flow."""
        return self._run_flow_with_audit(shared)

    def _run_with_audit(self, shared):
        """Audited entry point used when this flow is nested in an AuditedFlow.

        The inherited node implementation would call this flow's ``exec``,
        which does nothing, so the nested nodes would never run. Delegating
        to the flow runner makes nested flows actually orchestrate.
        """
        return self._run_flow_with_audit(shared)

    def _run_flow_with_audit(self, shared):
        """Run flow with orchestration and auditing.

        Opens a ``flow.<name>`` span, logs ``flow_start``, runs the flow's
        ``prep``, the audited orchestration loop and ``post`` (the same
        sequence as ``Flow._run``), then logs ``flow_success``. On failure
        ``flow_error`` is logged and the exception is re-raised.

        Returns:
            The value of ``post``; with the default ``Flow.post`` that is the
            action returned by the last node.
        """
        # perf_counter is monotonic, so the latency cannot go negative.
        start_time = time.perf_counter()

        print(f"🔍 FLOW DEBUG: Starting flow {self.name}")

        with tracer.start_as_current_span(f"flow.{self.name}") as span:
            span.set_attribute("flow.name", self.name)
            span.set_attribute("flow.type", self.__class__.__name__)

            try:
                if self.logger:
                    self.logger.log_event(
                        self.name, "flow_start",
                        node_type=self.__class__.__name__,
                        shared_keys=list(shared.keys()) if shared else []
                    )

                # Honour the flow's own prep/post hooks around orchestration,
                # matching Flow._run.
                prep_res = self.prep(shared)
                last_action = self._orch_with_audit(shared)
                result = self.post(shared, prep_res, last_action)

                total_latency = (time.perf_counter() - start_time) * 1000
                span.set_attribute("total_latency_ms", total_latency)
                span.set_status(Status(StatusCode.OK))

                if self.logger:
                    self.logger.log_event(
                        self.name, "flow_success",
                        total_latency_ms=total_latency,
                        result_type=type(result).__name__
                    )

                return result

            except Exception as e:
                span.set_status(Status(StatusCode.ERROR, str(e)))
                span.set_attribute("error.type", type(e).__name__)
                span.set_attribute("error.message", str(e))
                if self.logger:
                    self.logger.log_event(
                        self.name, "flow_error",
                        error_type=type(e).__name__,
                        error_message=str(e),
                        total_latency_ms=(time.perf_counter() - start_time) * 1000
                    )
                raise

# ==============================================================================
# EXAMPLE DEMO NODES
# ==============================================================================

class LoadTextNode(AuditedNode):
    """Demo step 1: make sure the shared store contains a "text" entry."""

    def exec(self, prep_res):
        # Placeholder for a real load (file, API, ...); post() ignores it.
        return "loaded"

    def post(self, shared, prep_res, exec_res):
        """Fill in the default text when none was supplied, then continue."""
        shared.setdefault("text", "this is a simple demo")
        text = shared["text"]
        print(f"✓ LoadTextNode: Loaded text: '{text}'")
        return "continue"

class SplitTextNode(AuditedBatchNode):
    """Demo step 2: break the shared text into per-word chunk records."""

    def prep(self, shared):
        """Return the whitespace-separated words of ``shared["text"]``."""
        source_text = shared.get("text", "")
        words = source_text.split()  # split on any whitespace
        print(f"✓ SplitTextNode: Splitting '{source_text}' into {len(words)} chunks")
        return words

    def exec(self, chunk):
        """Describe one word; called once per item returned by ``prep``."""
        return dict(original=chunk, length=len(chunk))

    def post(self, shared, prep_res, exec_res):
        """Store the chunk records under ``shared["chunks"]`` and continue."""
        shared["chunks"] = exec_res
        chunk_count = len(exec_res)
        print(f"✓ SplitTextNode: Created {chunk_count} chunk objects")
        return "continue"

class TransformChunksNode(AuditedNode):
    """Demo step 3: add an uppercase version to every chunk record."""

    def prep(self, shared):
        """Fetch the chunk records, or an empty list when none exist."""
        return shared.get("chunks", [])

    def exec(self, chunks):
        """Return new records holding the original, uppercased text and length."""
        return [
            {
                "original": record["original"],
                "transformed": record["original"].upper(),
                "length": record["length"],
            }
            for record in chunks
        ]

    def post(self, shared, prep_res, exec_res):
        """Store the records under ``shared["transformed_chunks"]`` and continue."""
        shared["transformed_chunks"] = exec_res
        transformed_count = len(exec_res)
        print(f"✓ TransformChunksNode: Transformed {transformed_count} chunks to uppercase")
        return "continue"

class CombineChunksNode(AuditedNode):
    """Demo step 4: join the transformed chunks into a single string."""

    def prep(self, shared):
        """Fetch the transformed records, or an empty list when none exist."""
        return shared.get("transformed_chunks", [])

    def exec(self, chunks):
        """Join the ``"transformed"`` values with single spaces."""
        pieces = [record["transformed"] for record in chunks]
        return " ".join(pieces)

    def post(self, shared, prep_res, exec_res):
        """Store the string under ``shared["final_result"]``; ends the flow."""
        shared["final_result"] = exec_res
        print(f"✓ CombineChunksNode: Final result: '{exec_res}'")
        # No successor is registered for "complete" in the demo wiring, so
        # the flow stops after this node.
        return "complete"

# ==============================================================================
# DEMO EXECUTION
# ==============================================================================

def run_demo():
    """Build the four-node demo flow, run it and report the audit trail.

    Prints progress to stdout and writes the audit log to a JSON file in the
    current working directory.

    Returns:
        A ``(shared, logger, log_data)`` tuple: the final shared state, the
        AuditLogger used for the run and the exported log structure.
    """
    print("🚀 Starting PocketFlow-Audit Demo")
    print("=" * 50)

    # One logger shared by the flow and all of its nodes
    audit_log = AuditLogger()

    # Audited nodes, in pipeline order
    load_node = LoadTextNode("LoadText", audit_log)
    split_node = SplitTextNode("SplitText", audit_log)
    transform_node = TransformChunksNode("TransformChunks", audit_log)
    combine_node = CombineChunksNode("CombineChunks", audit_log)
    pipeline = [load_node, split_node, transform_node, combine_node]

    flow = AuditedFlow("DemoFlow", audit_log)
    flow.start(load_node)

    # Every node but the last returns "continue", so chain each one to its
    # neighbour on that action.
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        upstream - "continue" >> downstream

    # Initial shared state
    shared = {"text": "this is a simple demo"}

    print("\n📋 Executing Flow...")
    print("-" * 30)

    final_result = flow.run(shared)

    print("\n✅ Flow Execution Complete!")
    print("=" * 50)

    # Final shared state: strings verbatim, everything else as type + size
    print(f"\n📊 Final Shared State:")
    for key, value in shared.items():
        if isinstance(value, str):
            print(f"  {key}: '{value}'")
            continue
        size = len(value) if hasattr(value, '__len__') else '?'
        print(f"  {key}: {type(value).__name__} with {size} items")

    print(f"\n🎯 Flow Result: {final_result}")

    # Audit summary
    print(f"\n📝 Audit Log Summary:")
    for key, value in audit_log.get_summary().items():
        print(f"  {key}: {value}")

    # Full export
    log_data = audit_log.export_logs()
    print(f"\n💾 Full audit log exported to: audit_log_{audit_log.session_id[:8]}.json")

    # First ten events, then a count of the rest
    print(f"\n🔍 Sample Audit Events (first 10):")
    events = log_data["events"]
    for position, event in enumerate(events[:10], start=1):
        print(f"  Event {position}: {event['node_name']} -> {event['event_type']} "
              f"({event['timestamp'][:19]})")

    hidden = len(events) - 10
    if hidden > 0:
        print(f"  ... and {hidden} more events")

    print(f"\n🔗 Flow Transitions:")
    edges = log_data["edges"]
    if not edges:
        print("  No edges logged - check flow execution")
    for edge in edges:
        print(f"  {edge['from_node']} --[{edge['action']}]--> {edge['to_node']}")

    return shared, audit_log, log_data

# ==============================================================================
# RUN THE DEMO
# ==============================================================================

if __name__ == "__main__":
    # Run the demo and keep its results available for inspection afterwards
    shared_result, audit_logger, full_logs = run_demo()

    divider = "=" * 50
    closing_lines = (
        "🎉 PocketFlow-Audit Demo Complete!",
        "✓ Auditing: All node executions logged",
        "✓ Tracing: OpenTelemetry spans printed above",
        "✓ Export: JSON audit log saved to file",
        "✓ Ready for extension to RAG/chatbots!",
    )

    print("\n" + divider)
    for closing_line in closing_lines:
        print(closing_line)
    print(divider)

🚀 Starting PocketFlow-Audit Demo

📋 Executing Flow...
------------------------------
🔍 FLOW DEBUG: Starting flow DemoFlow
🔍 FLOW DEBUG: Starting orchestration with node: LoadText
🔍 FLOW DEBUG: Current node: LoadText
🔍 FLOW DEBUG: Running LoadText with _run_with_audit
{
    "name": "prep",
    "context": {
        "trace_id": "0x0d4f5f3209526736b1101e0055c72b1f",
        "span_id": "0xd88207cdbe3368c2",
        "trace_state": "[]"
    },
    "kind": "SpanKind.INTERNAL",
    "parent_id": "0x07315684c16e637e",
    "start_time": "2025-09-15T22:57:24.664544Z",
    "end_time": "2025-09-15T22:57:24.664570Z",
    "status": {
        "status_code": "UNSET"
    },
    "attributes": {
        "latency_ms": 0.007867813110351562
    },
    "events": [],
    "links": [],
    "resource": {
        "attributes": {
            "telemetry.sdk.language": "python",
            "telemetry.sdk.name": "opentelemetry",
            "telemetry.sdk.version": "1.36.0",
            "service.name": "unknown_service