In [45]:
import openai
import os
import random
from dotenv import load_dotenv
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

In [46]:
from llama_index.core.workflow import (
    Context,
    StartEvent,
    StopEvent,
    Workflow,
    step,
    Event,
)
from llama_index.utils.workflow import draw_all_possible_flows

In [47]:
class FirstEvent(Event):
    first_output: str

class SecondEvent(Event):
    second_output: str

class LoopEvent(Event):
    loop_output: str

In [48]:
class MyWorkflow(Workflow):
    @step(pass_context=True)
    async def step_one(self, ctx: Context, ev: StartEvent | LoopEvent) -> FirstEvent | LoopEvent:
        if random.randint(0, 1) == 0:
            print("Bad thing happened")
            return LoopEvent(loop_output="Back to step one.")
        else:
            print("Good thing happened")
            return FirstEvent(first_output="First step complete.")

    @step(pass_context=True)
    async def step_two(self, ctx: Context, ev: FirstEvent) -> SecondEvent:
        print(ev.first_output)
        return SecondEvent(second_output="Second step complete.")

    @step(pass_context=True)
    async def step_three(self, ctx: Context, ev: SecondEvent) -> StopEvent:
        print(ev.second_output)
        return StopEvent(result="Workflow complete.")

In [49]:
w = MyWorkflow(timeout=10, verbose=False)
result = await w.run(first_input="Start the workflow.")
print(result)

Bad thing happened
Bad thing happened
Good thing happened
First step complete.
Second step complete.
Workflow complete.


In [50]:
draw_all_possible_flows(MyWorkflow, filename="multi_step_workflow.html")

multi_step_workflow.html


In [51]:
class BranchA1Event(Event):
    payload: str


class BranchA2Event(Event):
    payload: str


class BranchB1Event(Event):
    payload: str


class BranchB2Event(Event):
    payload: str


class BranchWorkflow(Workflow):
    @step(pass_context=True)
    async def start(self, ctx: Context, ev: StartEvent) -> BranchA1Event | BranchB1Event:
        if random.randint(0, 1) == 0:
            print("Go to branch A")
            return BranchA1Event(payload="Branch A")
        else:
            print("Go to branch B")
            return BranchB1Event(payload="Branch B")

    @step(pass_context=True)
    async def step_a1(self, ctx: Context, ev: BranchA1Event) -> BranchA2Event:
        print(ev.payload)
        return BranchA2Event(payload=ev.payload)

    @step(pass_context=True)
    async def step_b1(self, ctx: Context, ev: BranchB1Event) -> BranchB2Event:
        print(ev.payload)
        return BranchB2Event(payload=ev.payload)

    @step(pass_context=True)
    async def step_a2(self, ctx: Context, ev: BranchA2Event) -> StopEvent:
        print(ev.payload)
        return StopEvent(result="Branch A complete.")

    @step(pass_context=True)
    async def step_b2(self, ctx: Context, ev: BranchB2Event) -> StopEvent:
        print(ev.payload)
        return StopEvent(result="Branch B complete.")

In [52]:
class SetupEvent(Event):
    query: str


class StepTwoEvent(Event):
    query: str


class StatefulFlow(Workflow):
    @step(pass_context=True)
    async def start(
        self, ctx: Context, ev: StartEvent
    ) -> SetupEvent | StepTwoEvent:
        if "some_database" not in ctx.data:
            print("Need to load data")
            return SetupEvent(query=ev.query)

        # do something with the query
        return StepTwoEvent(query=ev.query)

    @step(pass_context=True)
    async def setup(self, ctx: Context, ev: SetupEvent) -> StartEvent:
        # load data
        ctx.data["some_database"] = [1, 2, 3]
        return StartEvent(query=ev.query)
    
    @step(pass_context=True)
    async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
        # do something with the data
        print("Data is ", ctx.data["some_database"])

        return StopEvent(result=ctx.data["some_database"][1])

In [53]:
w = StatefulFlow(timeout=10, verbose=False)
result = await w.run(query="Some query")
print(result)

Need to load data
Data is  [1, 2, 3]
2


In [54]:
import asyncio
class ParallelFlow(Workflow):
    @step(pass_context=True)
    async def start(self, ctx: Context, ev: StartEvent) -> StepTwoEvent:
        self.send_event(StepTwoEvent(query="Query 1"))
        self.send_event(StepTwoEvent(query="Query 2"))
        self.send_event(StepTwoEvent(query="Query 3"))

    @step(pass_context=True, num_workers=4)
    async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
        print("Running slow query ", ev.query)
        await asyncio.sleep(random.randint(1, 5))

        return StopEvent(result=ev.query)

In [55]:
w = ParallelFlow(timeout=10, verbose=False)
result = await w.run()
print(result)

Running slow query  Query 1
Running slow query  Query 2
Running slow query  Query 3
Query 2


In [56]:
class SetupEvent(Event):
    query: str

class StepTwoEvent(Event):
    query: str

class StepThreeEvent(Event):
    result: str

class ConcurrentFlow(Workflow):
    @step(pass_context=True)
    async def start(self, ctx: Context, ev: StartEvent) -> StepTwoEvent:
        self.send_event(StepTwoEvent(query="Query 1"))
        self.send_event(StepTwoEvent(query="Query 2"))
        self.send_event(StepTwoEvent(query="Query 3"))

    @step(pass_context=True, num_workers=4)
    async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StepThreeEvent:
        print("Running query ", ev.query)
        await asyncio.sleep(random.randint(1, 5))
        return StepThreeEvent(result=ev.query)

    @step(pass_context=True)
    async def step_three(self, ctx: Context, ev: StepThreeEvent) -> StopEvent:
        # wait until we receive 3 events
        result = ctx.collect_events(ev, [StepThreeEvent] * 3)
        if result is None:
            return None

        print(result)
        return StopEvent(result="Done")

In [57]:
w = ConcurrentFlow(timeout=10, verbose=False)
result = await w.run()
print(result)

Running query  Query 1
Running query  Query 2
Running query  Query 3
[StepThreeEvent(result='Query 3'), StepThreeEvent(result='Query 1'), StepThreeEvent(result='Query 2')]
Done


In [58]:
class StepAEvent(Event):
    query: str

class StepBEvent(Event):
    query: str

class StepCEvent(Event):
    query: str

class StepACompleteEvent(Event):
    result: str

class StepBCompleteEvent(Event):
    result: str

class StepCCompleteEvent(Event):
    result: str

class ConcurrentFlow(Workflow):
    @step(pass_context=True)
    async def start(
        self, ctx: Context, ev: StartEvent
    ) -> StepAEvent | StepBEvent | StepCEvent:
        self.send_event(StepAEvent(query="Query 1"))
        self.send_event(StepBEvent(query="Query 2"))
        self.send_event(StepCEvent(query="Query 3"))

    @step(pass_context=True)
    async def step_a(self, ctx: Context, ev: StepAEvent) -> StepACompleteEvent:
        print("Doing something A-ish")
        return StepACompleteEvent(result=ev.query)

    @step(pass_context=True)
    async def step_b(self, ctx: Context, ev: StepBEvent) -> StepBCompleteEvent:
        print("Doing something B-ish")
        return StepBCompleteEvent(result=ev.query)

    @step(pass_context=True)
    async def step_c(self, ctx: Context, ev: StepCEvent) -> StepCCompleteEvent:
        print("Doing something C-ish")
        return StepCCompleteEvent(result=ev.query)

    @step(pass_context=True)
    async def step_three(
        self,
        ctx: Context,
        ev: StepACompleteEvent | StepBCompleteEvent | StepCCompleteEvent,
    ) -> StopEvent:
        print("Received event ", ev.result)

        # wait until we receive 3 events
        if (
            ctx.collect_events(
                ev,
                [StepCCompleteEvent, StepACompleteEvent, StepBCompleteEvent],
            )
            is None
        ):
            return None

        # do something with all 3 results together
        return StopEvent(result="Done")

In [59]:
w = ConcurrentFlow(timeout=10, verbose=False)
result = await w.run()
print(result)

Doing something A-ish
Doing something B-ish
Doing something C-ish
Received event  Query 1
Received event  Query 2
Received event  Query 3
Done


In [61]:
draw_all_possible_flows(ConcurrentFlow, filename="multi_step_workflow.html")

multi_step_workflow.html
