In [4]:
#!/usr/bin/env python3
import yaml
import json
import asyncio
import websockets
import time


import sys
from typing import Dict, Any, List, Optional
from datetime import datetime



import nest_asyncio
nest_asyncio.apply()  # Patch asyncio so asyncio.run() can be called while Jupyter's own event loop is already running


In [2]:
# Load the YAML flow file
yaml_file_path = "./execution_flows/simple-ai-flow.yaml"  # Update this to your file location
print(f"Loading flow from {yaml_file_path}...")

# Read as UTF-8 explicitly so parsing does not depend on the platform's locale encoding.
with open(yaml_file_path, 'r', encoding='utf-8') as file:
    # safe_load returns None for an empty document; treat that as an empty flow.
    yaml_data = yaml.safe_load(file) or {}

# Everything downstream uses dict access (.get), so fail clearly on any other top-level type.
if not isinstance(yaml_data, dict):
    raise TypeError(
        f"Expected a mapping at the top of {yaml_file_path}, got {type(yaml_data).__name__}"
    )

# Fall back to "flow-<file stem>" when the YAML does not declare a flow_id.
flow_id = yaml_data.get("flow_id", "flow-" + yaml_file_path.split("/")[-1].split(".")[0])
print(f"Flow ID: {flow_id}")


Loading flow from ./execution_flows/simple-ai-flow.yaml...
Flow ID: simple_ai_flow


In [3]:
# Show the resolved flow ID (declared in the YAML, or derived from the file name).
flow_id

'simple_ai_flow'

In [5]:
# NOTE(review): FlowStreamer's built-in default endpoint is /ws/execute/<flow_id>, not /ws/flow/.
# This value only takes effect if passed explicitly as websocket_url; confirm which path the server serves.
websocket_url = f"ws://localhost:8000/ws/flow/{flow_id}"
print(f"WebSocket URL: {websocket_url}")

WebSocket URL: ws://localhost:8000/ws/flow/simple_ai_flow


In [None]:

class FlowStreamer:
    """
    Stream the execution of a flow from a flow-execution WebSocket server.

    ``stream_flow`` connects and performs a fixed handshake:
    - read a ready message;
    - send the flow definition (JSON) and read one reply;
    - send the initial inputs (JSON) and read one reply;
    - send a ``null`` config and read one reply.

    It then reads JSON events of the form ``{"type": ..., "data": {...}}``
    until the connection closes, printing a structured payload for each.

    Per-element state (status, timings, inputs, outputs, streamed LLM text,
    errors) is kept in ``self.elements`` (flow-definition order) and in
    ``self.element_dict`` (the same record objects, keyed by element_id).
    """

    def __init__(self, flow_data, websocket_url=None):
        """
        Initialize the Flow Streamer.

        Args:
            flow_data: Dictionary containing ``flow_definition``, ``initial_inputs``
                and optionally ``flow_id``.
            websocket_url: WebSocket URL to connect to (if None, it is derived
                from ``flow_id``).
        """
        self.flow_definition = flow_data.get("flow_definition", {})
        self.initial_inputs = flow_data.get("initial_inputs", {})
        self.flow_id = flow_data.get("flow_id", "unknown-flow")

        # NOTE(review): an earlier notebook cell builds ws://localhost:8000/ws/flow/<flow_id>,
        # while this default uses /ws/execute/. Confirm which endpoint the server exposes.
        self.websocket_url = websocket_url or f"ws://localhost:8000/ws/execute/{self.flow_id}"

        self._reset_state()

    def _reset_state(self):
        """Clear all per-run tracking state."""
        self.elements = []         # element records, in flow-definition order
        self.element_dict = {}     # element_id -> same record objects, for quick lookup
        self.execution_order = []  # element_ids in the order element_started was received
        self.llm_chunks = {}       # element_id -> accumulated LLM text

    def build_element_dictionary(self):
        """
        (Re)build the per-element tracking records from the flow definition.

        All per-run state is reset first. Calling this (or ``stream_flow``)
        more than once therefore does not duplicate entries in ``self.elements``.
        """
        self._reset_state()
        elements = (self.flow_definition or {}).get('elements') or {}
        for element_id, element in elements.items():
            element = element or {}  # tolerate an element declared with an empty body
            element_data = {
                'element_id': element_id,
                'name': element.get('name', element_id),
                'type': element.get('type', 'unknown'),
                'description': element.get('description', ''),
                'input_schema': element.get('input_schema', {}),
                'output_schema': element.get('output_schema', {}),
                'inputs': {},
                'outputs': {},
                'status': 'waiting',
                'streamed_data': '',
                'start_time': None,
                'end_time': None,
                'execution_time': None,
                'error': None
            }
            self.elements.append(element_data)
            self.element_dict[element_id] = element_data

    def create_payload(self, element_id, name, load, is_end_element=False):
        """
        Create a structured payload for streaming.

        Args:
            element_id: ID of the element
            name: Name of the element (not included in end-element payloads)
            load: Content to stream
            is_end_element: Whether this is an end element payload

        Returns:
            Dictionary payload
        """
        if is_end_element:
            return {
                "element_id": element_id,
                "is_end_element": True,
                "load": load
            }
        return {
            "element_id": element_id,
            "name": name,
            "load": load
        }

    def stream_payload(self, payload):
        """
        Stream a payload to the frontend (currently just prints it).

        Args:
            payload: The payload to stream
        """
        element_id = payload.get("element_id", "unknown")
        load = payload.get("load", "")

        if payload.get("is_end_element"):
            print(f"\n{element_id} [END]\n<{load}>\n{'_' * 40}")
        else:
            name = payload.get("name", "")
            print(f"\n{element_id} - {name}\n<{load}>\n{'_' * 40}")

        # In a real implementation, this would send to a frontend:
        # await websocket.send(json.dumps(payload))

    def handle_llm_chunk(self, data):
        """
        Handle an ``llm_chunk`` event.

        The chunk is echoed to stdout immediately and appended to the element's
        accumulated LLM text. The first chunk seen for an element also emits a
        "[LLM output starting...]" payload.

        Args:
            data: Event data containing ``element_id`` and ``content``.
        """
        element_id = data.get('element_id')
        content = data.get('content') or ''

        # Echo the raw chunk as-is so the text appears incrementally.
        print(content, end='', flush=True)

        if element_id not in self.llm_chunks:
            self.llm_chunks[element_id] = ''
            element = self.element_dict.get(element_id)
            element_name = element.get('name', 'LLM') if element is not None else "LLM"
            self.stream_payload(self.create_payload(
                element_id=element_id,
                name=element_name,
                load="[LLM output starting...]"
            ))

        self.llm_chunks[element_id] += content

        if element_id in self.element_dict:
            self.element_dict[element_id]['streamed_data'] = self.llm_chunks[element_id]

    def _finish_element(self, element, status):
        """Mark ``element`` as finished with ``status``; return elapsed seconds for display."""
        element['status'] = status
        element['end_time'] = datetime.now()
        if element['start_time']:
            element['execution_time'] = (element['end_time'] - element['start_time']).total_seconds()
        # execution_time stays None if no element_started event was seen; report 0 for display
        # instead of crashing on f"{None:.2f}".
        return element['execution_time'] or 0

    def handle_element_event(self, event_type, data):
        """
        Handle element lifecycle events and stream a payload for each.

        Args:
            event_type: One of ``element_started``, ``element_completed``, ``element_error``.
            data: Event data; must contain ``element_id``. ``outputs`` is read on
                completion and ``error`` on failure.
        """
        element_id = data.get('element_id')
        element = self.element_dict.get(element_id)
        if element is None:
            # Events for elements not present in the flow definition are ignored.
            return
        name = element.get('name', '')

        if event_type == 'element_started':
            element['status'] = 'running'
            element['start_time'] = datetime.now()
            self.execution_order.append(element_id)

            description = element.get('description', '')
            inputs = element.get('inputs', {})
            load = f"Started: {element.get('name')}\n"
            if description:
                load += f"Description: {description}\n"
            if inputs:
                load += f"Inputs: {json.dumps(inputs, indent=2)}"
            self.stream_payload(self.create_payload(element_id=element_id, name=name, load=load))

        elif event_type == 'element_completed':
            element['outputs'] = data.get('outputs', {})
            execution_time = self._finish_element(element, 'completed')
            if element_id in self.llm_chunks:
                # For LLM elements the end payload carries the complete streamed text.
                load = self.llm_chunks[element_id]
            else:
                load = (f"Completed in {execution_time:.2f} seconds\n"
                        f"Outputs: {json.dumps(element['outputs'], indent=2)}")
            # create_payload requires `name` even though end payloads omit it.
            self.stream_payload(self.create_payload(
                element_id=element_id, name=name, load=load, is_end_element=True
            ))

        elif event_type == 'element_error':
            element['error'] = data.get('error', 'Unknown error')
            execution_time = self._finish_element(element, 'error')
            load = (f"ERROR after {execution_time:.2f} seconds\n"
                    f"Error: {element['error']}")
            self.stream_payload(self.create_payload(
                element_id=element_id, name=name, load=load, is_end_element=True
            ))

    def _infer_inputs(self, element_id):
        """
        Heuristically set an element's inputs to the outputs of the most recently
        started element that has already completed.

        This is a simplification: the real input mappings between elements are
        not tracked, so the result may differ from what the server actually used.
        """
        element = self.element_dict.get(element_id)
        if element is None:
            return
        for previous_id in reversed(self.execution_order):
            previous = self.element_dict.get(previous_id)
            if previous_id != element_id and previous is not None and previous['status'] == 'completed':
                element['inputs'] = previous['outputs']
                return

    def process_event(self, event):
        """
        Dispatch one event received from the WebSocket.

        Args:
            event: Dict with ``type`` and (optionally) ``data``. Unknown types are ignored.
        """
        event_type = event.get('type', '')
        data = event.get('data') or {}

        if event_type == 'llm_chunk':
            self.handle_llm_chunk(data)
        elif event_type in ('element_started', 'element_completed', 'element_error'):
            self.handle_element_event(event_type, data)
        elif event_type == 'flow_started':
            print(f"\nFlow {self.flow_id} started at {datetime.now()}")
        elif event_type == 'flow_completed':
            flow_id = data.get('flow_id', self.flow_id)
            print(f"\nFlow {flow_id} completed at {datetime.now()}")
        elif event_type == 'flow_error':
            error = data.get('error', 'Unknown error')
            print(f"\nFlow error: {error}")

    async def _handshake(self, websocket):
        """Run the exchange that precedes event streaming, printing each server reply."""
        ready_msg = await websocket.recv()
        print(f"Server: {ready_msg}")

        await websocket.send(json.dumps(self.flow_definition))
        print("Sent flow definition")
        ack1 = await websocket.recv()
        print(f"Server: {ack1}")

        await websocket.send(json.dumps(self.initial_inputs))
        print("Sent initial inputs")
        ack2 = await websocket.recv()
        print(f"Server: {ack2}")

        # The third message is the run config; this client always sends JSON null.
        await websocket.send("null")
        print("Sent null config")
        ack3 = await websocket.recv()
        print(f"Server: {ack3}")

    async def stream_flow(self):
        """
        Connect to the WebSocket and stream the flow execution until the server closes.

        Connection or protocol errors are printed rather than raised (best effort).

        Returns:
            List of element data records (possibly partially filled).
        """
        self.build_element_dictionary()

        print(f"Connecting to WebSocket at {self.websocket_url}")

        try:
            async with websockets.connect(self.websocket_url) as websocket:
                await self._handshake(websocket)
                print("Flow execution starting. Streaming events...")

                try:
                    while True:
                        message = await websocket.recv()
                        try:
                            event = json.loads(message)
                        except json.JSONDecodeError:
                            print(f"\nSkipping non-JSON message: {message!r}")
                            continue
                        if not isinstance(event, dict):
                            print(f"\nSkipping unexpected message: {message!r}")
                            continue

                        # Infer inputs BEFORE handling element_started, so the start
                        # payload can include them.
                        if event.get('type') == 'element_started':
                            self._infer_inputs((event.get('data') or {}).get('element_id'))

                        self.process_event(event)

                except websockets.exceptions.ConnectionClosed:
                    print("\nWebSocket connection closed")

        except Exception as e:
            print(f"Error: {e}")

        return self.elements




async def stream_from_dict(flow_data, websocket_url=None):
    """
    Convenience wrapper: build a FlowStreamer for ``flow_data`` and run it.

    Args:
        flow_data: Dictionary with ``flow_definition`` and ``initial_inputs``
            (and optionally ``flow_id``).
        websocket_url: Optional WebSocket URL overriding the streamer's default.

    Returns:
        The list of element data records, or an empty list if constructing or
        running the streamer raised an exception (the error is printed).
    """
    collected = []
    try:
        collected = await FlowStreamer(flow_data, websocket_url).stream_flow()
    except Exception as exc:
        print(f"Error streaming flow: {exc}")
    return collected


async def stream_from_yaml(yaml_file_path, websocket_url=None):
    """
    Load a flow definition from a YAML file and stream its execution.

    The YAML document should be a mapping with ``flow_definition`` and
    ``initial_inputs`` keys (and optionally ``flow_id``). When ``flow_id`` is
    absent, it is derived from the file name as ``flow-<stem>``.

    Args:
        yaml_file_path: Path to the YAML flow file.
        websocket_url: Optional WebSocket URL overriding the streamer's default.

    Returns:
        List of element data records (empty if streaming could not start).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        TypeError: If the YAML document is not a mapping.
    """
    import os

    with open(yaml_file_path, 'r', encoding='utf-8') as file:
        # safe_load returns None for an empty document; treat that as an empty flow.
        flow_data = yaml.safe_load(file) or {}
    if not isinstance(flow_data, dict):
        raise TypeError(
            f"Expected a mapping at the top of {yaml_file_path}, got {type(flow_data).__name__}"
        )

    stem = os.path.splitext(os.path.basename(yaml_file_path))[0]
    flow_data.setdefault("flow_id", "flow-" + stem)
    return await stream_from_dict(flow_data, websocket_url)


if __name__ == "__main__":
    # Default YAML file path
    yaml_file_path = "./execution_flows/simple-ai-flow.yaml"

    # Use a path given on the command line. Arguments starting with "-" are ignored:
    # under a Jupyter/IPython kernel __name__ is also "__main__", and sys.argv
    # typically holds the kernel launcher's own flags (e.g. "-f <connection-file>").
    if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
        yaml_file_path = sys.argv[1]

    # Run the event loop with the YAML file. Inside Jupyter this relies on
    # nest_asyncio.apply() having been called at import time.
    elements = asyncio.run(stream_from_yaml(yaml_file_path))

    # Print number of elements processed
    print(f"\nProcessed {len(elements)} elements")

    # Print first element as example of the data structure
    if elements:
        print("\nExample of element data structure:")
        example = elements[0]
        print(json.dumps(example, indent=2, default=str))