In [35]:
import json
import pygraphviz as pgv
import os
import dotenv

In [36]:
dotenv.load_dotenv()

True

In [37]:
class ConversationState:
    def __init__(self, name=None, parent=None):
        self.name = name

        self.system = ""
        self.messages = []

        self.transitions = {}
        self.parent = parent
        self.children = []

    def add_message(self, message):
        self.messages.append(message)

    def add_transition(self, trigger, next_state):
        self.transitions[trigger] = next_state

    def add_child(self, child_state):
        self.children.append(child_state)
        child_state.parent = self

    def get_next_state(self, response):
        if response in self.transitions:
            return self.transitions[response]
        elif self.parent:
            return self.parent.get_next_state(response)
        else:
            return None

    def get_root(self):
        if self.parent:
            return self.parent.get_root()
        else:
            return self

    def get_hierarchy_path(self):
        if self.parent and self.parent.name != "root":
            return self.parent.get_hierarchy_path() + "_" + self.name
        else:
            return self.name

In [38]:
class ConversationStateMachine:
    def __init__(self, root_state=None, init_state=None):
        self.root_state = root_state
        self.current_state = init_state

    def transition(self, trigger):
        if trigger in self.current_state.transitions:
            self.current_state = self.current_state.transitions[trigger]
            return self.current_state
        else:
            print(f"invalid trigger '{trigger}' for state {self.current_state.get_hierarchy_path()}")
            return None
        
    def print_state(self):
        print(f"[CSM] self.current_state: {self.current_state}")
        print(f"[CSM] self.current_state.get_hierarchy_path(): {self.current_state.get_hierarchy_path()}")

In [39]:
def initialize_conversation_states(json_data):
    def create_state(state_data, parent=None):
        state = ConversationState(name=state_data["name"], parent=parent)
        state.system = state_data.get("system", "")
        state.messages = state_data.get("messages", [])

        for child_data in state_data.get("children", []):
            child_state = create_state(child_data, parent=state)
            state.add_child(child_state)

        return state

    root_state = create_state(json_data)
    return root_state

In [40]:
def apply_transitions(root_state, transitions_data):
    state_map = {}

    def traverse_and_map_states(state):
        state_map[state.get_hierarchy_path()] = state
        for child in state.children:
            traverse_and_map_states(child)

    def find_state_by_path(path):
        return state_map.get(path)

    traverse_and_map_states(root_state)

    for transition in transitions_data:
        trigger = transition["trigger"]
        source_paths = transition["source"]
        dest_path = transition["dest"]

        if not isinstance(source_paths, list):
            source_paths = [source_paths]

        for source_path in source_paths:
            source_path = source_path
            source_state = find_state_by_path(source_path)
            dest_state = find_state_by_path(dest_path)

            if source_state and dest_state:
                source_state.add_transition(trigger, dest_state)
            else:
                print(f"Warning: Invalid transition - Source: {source_path}, Destination: {dest_path}")

In [41]:
def visualize_conversation_states(root_state, transitions_data):
    graph = pgv.AGraph(directed=True)

    graph.graph_attr['fontname'] = 'Consolas'
    graph.node_attr['fontname'] = 'Consolas'
    graph.node_attr['shape'] = 'box'
    graph.node_attr['style'] = 'rounded'
    graph.edge_attr['fontname'] = 'Consolas'

    def add_state_to_graph(state, parent_subgraph=None):
        if parent_subgraph is None:
            subgraph = graph
        else:
            subgraph = parent_subgraph.add_subgraph(name=f"cluster_{state.get_hierarchy_path()}")
            subgraph.graph_attr['style'] = 'rounded'

        if not (parent_subgraph is None):
            subgraph.add_node(state.get_hierarchy_path(), label=state.name)

        for child in state.children:
            add_state_to_graph(child, subgraph)

    add_state_to_graph(root_state)

    for transition in transitions_data:
        trigger = transition["trigger"]
        source_paths = transition["source"]
        dest_path = transition["dest"]

        if not isinstance(source_paths, list):
            source_paths = [source_paths]

        for source_path in source_paths:
            source_state = graph.get_node(source_path)
            dest_state = graph.get_node(dest_path)

            if source_state and dest_state:
                graph.add_edge(source_state, dest_state, label=trigger)

    graph.layout(prog='dot')
    
    if not os.path.exists(os.environ.get("OUTPUT_DIR")):
        os.makedirs(os.environ.get("OUTPUT_DIR"))
        
    graph.draw(os.path.join(os.environ.get("OUTPUT_DIR"), 'state_diagram.png'))

In [42]:
def print_state_hierarchy(state, level=0):
    print("  " * level + state.get_hierarchy_path())
    for child in state.children:
        print_state_hierarchy(child, level + 1)

In [43]:
with open(os.path.join(os.environ.get("INPUT_DIR"), "states.json")) as file:
    json_data = json.load(file)

with open(os.path.join(os.environ.get("INPUT_DIR"), "transitions.json")) as file:
    transitions_data = json.load(file)

In [44]:
root_state = initialize_conversation_states(json_data)


print_state_hierarchy(root_state)

apply_transitions(root_state, transitions_data)

visualize_conversation_states(root_state, transitions_data)

root
  start
    start_select-ready
  notready
    notready_select-get-set-state
      notready_select-get-set-state_get-state
        notready_select-get-set-state_get-state_select-tool
          notready_select-get-set-state_get-state_select-tool_compose-question
          notready_select-get-set-state_get-state_select-tool_compose-python
          notready_select-get-set-state_get-state_select-tool_compose-powershell
          notready_select-get-set-state_get-state_select-tool_compose-screenshot
      notready_select-get-set-state_set-state
        notready_select-get-set-state_set-state_select-tool
          notready_select-get-set-state_set-state_select-tool_compose-message
          notready_select-get-set-state_set-state_select-tool_compose-python
          notready_select-get-set-state_set-state_select-tool_compose-powershell
          notready_select-get-set-state_set-state_select-tool_compose-screenshot
    notready_select-ready
  ready
    ready_select-tool
      ready_sele

In [45]:
init_state = None
for state in root_state.children:
    if state.name == "start":
        init_state = state
        break
print(init_state)
init_state.name

<__main__.ConversationState object at 0x00000229B017E770>


'start'

In [46]:
csm = ConversationStateMachine(root_state=root_state, init_state=init_state)

In [47]:
csm.print_state()

[CSM] self.current_state: <__main__.ConversationState object at 0x00000229B017E770>
[CSM] self.current_state.get_hierarchy_path(): start


In [48]:
csm.current_state.transitions

{'selectready': <__main__.ConversationState at 0x229b017d6c0>}

In [49]:
csm.transition("selectready")

<__main__.ConversationState at 0x229b017d6c0>

In [50]:
csm.print_state()

[CSM] self.current_state: <__main__.ConversationState object at 0x00000229B017D6C0>
[CSM] self.current_state.get_hierarchy_path(): start_select-ready


In [51]:
csm.current_state.system

'test'

In [52]:
csm.current_state.messages

[{'role': 'user', 'content': [{'type': 'text', 'text': 'userprompt'}]}]