In [1]:
import json
import pygraphviz as pgv
import os
import dotenv

In [2]:
dotenv.load_dotenv()

True

In [3]:
class ConversationState:
    def __init__(self, name=None, parent=None, system="", system_frmtstrs={}, messages=[], messages_frmtstrs={}):
        self.name = name

        self.system_frmtstrs = system_frmtstrs
        self.system = system

        self.messages = messages
        self.messages_frmtstrs = messages_frmtstrs

        self.parent = parent

        self.transitions = {}
        self.children = []

    def add_message(self, message):
        self.messages.append(message)

    def add_transition(self, trigger, next_state):
        self.transitions[trigger] = next_state

    def add_child(self, child_state):
        self.children.append(child_state)
        child_state.parent = self

    def get_next_state(self, response):
        if response in self.transitions:
            return self.transitions[response]
        elif self.parent:
            return self.parent.get_next_state(response)
        else:
            return None

    def get_root(self):
        if self.parent:
            return self.parent.get_root()
        else:
            return self

    def get_hpath(self):
        if self.parent and self.parent.name != "root":
            return self.parent.get_hpath() + "_" + self.name
        else:
            return self.name

In [4]:
class ConversationStateMachine:
    PRINT_PREFIX = "[CSM]"

    def __init__(self, state_data=None, transition_data=None, init_state_path=None):
        self.initialize_conversation_states(state_data)
        self.initialize_transitions(transition_data)
        self.current_state = self.state_map[init_state_path]

    def transition(self, trigger):
        if trigger in self.current_state.transitions:
            self.current_state = self.current_state.transitions[trigger]
            return self.current_state
        else:
            print(f"{self.PRINT_PREFIX} invalid trigger '{trigger}' for state {self.current_state.get_hpath()}")
            return None

    def initialize_conversation_states(self, state_data):
        def create_state(state_data, parent=None):
            state = ConversationState(name=state_data["name"],
                                      parent=parent,
                                      system=state_data.get("system", ""),
                                      system_frmtstrs=state_data.get("system_frmtstrs", {}),
                                      messages=state_data.get("messages", []),
                                      messages_frmtstrs = state_data.get("messages_frmtstrs", {}))

            for child_data in state_data.get("children", []):
                child_state = create_state(child_data, parent=state)
                state.add_child(child_state)

            return state

        self.root_state = create_state(state_data)

    def find_state_by_path(self, path):
            return self.state_map.get(path)
    
    def initialize_transitions(self, transition_data=None):
        self.transition_data = transition_data
        self.state_map = {}

        def traverse_and_map_states(state):
            self.state_map[state.get_hpath()] = state
            for child in state.children:
                traverse_and_map_states(child)

        traverse_and_map_states(self.root_state)

        for transition in transition_data:
            trigger = transition["trigger"]
            source_paths = transition["source"]
            dest_path = transition["dest"]

            if not isinstance(source_paths, list):
                source_paths = [source_paths]

            for source_path in source_paths:
                source_path = source_path
                source_state = self.find_state_by_path(source_path)
                dest_state = self.find_state_by_path(dest_path)

                if source_state and dest_state:
                    source_state.add_transition(trigger, dest_state)
                else:
                    print(f"{self.PRINT_PREFIX} Warning: Invalid transition - Source: {source_path}, Destination: {dest_path}")

    def visualize(self):
        graph = pgv.AGraph(directed=True)

        graph.graph_attr['fontname'] = 'Consolas'
        graph.node_attr['fontname'] = 'Consolas'
        graph.node_attr['shape'] = 'box'
        graph.node_attr['style'] = 'rounded'
        graph.edge_attr['fontname'] = 'Consolas'

        def add_state_to_graph(state, parent_subgraph=None):
            if parent_subgraph is None:
                subgraph = graph
            else:
                subgraph = parent_subgraph.add_subgraph(name=f"cluster_{state.get_hpath()}")
                subgraph.graph_attr['style'] = 'rounded'

            if not (parent_subgraph is None):
                subgraph.add_node(state.get_hpath(), label=state.name)

            for child in state.children:
                add_state_to_graph(child, subgraph)

        add_state_to_graph(self.root_state)

        for transition in self.transition_data:
            trigger = transition["trigger"]
            source_paths = transition["source"]
            dest_path = transition["dest"]

            if not isinstance(source_paths, list):
                source_paths = [source_paths]

            for source_path in source_paths:
                source_state = graph.get_node(source_path)
                dest_state = graph.get_node(dest_path)

                if source_state and dest_state:
                    graph.add_edge(source_state, dest_state, label=trigger)

        graph.layout(prog='dot')
        
        if not os.path.exists(os.environ.get("OUTPUT_DIR")):
            os.makedirs(os.environ.get("OUTPUT_DIR"))
            
        graph.draw(os.path.join(os.environ.get("OUTPUT_DIR"), 'state_diagram.png'))
     
    def print_current_state(self):
        print(f"{self.PRINT_PREFIX} self.current_state: {self.current_state}")
        print(f"{self.PRINT_PREFIX} self.current_state.get_hpath(): {self.current_state.get_hpath()}")

    def print_state_hierarchy(self, state=None, level=0):
        if state == None:
            state = self.root_state

        print(self.PRINT_PREFIX + "  " * (level+1) + state.get_hpath())
        for child in state.children:
            self.print_state_hierarchy(child, level + 1)

In [5]:
with open(os.path.join(os.environ.get("INPUT_DIR"), "states.json")) as file:
    state_data = json.load(file)

with open(os.path.join(os.environ.get("INPUT_DIR"), "transitions.json")) as file:
    transition_data = json.load(file)

In [6]:
csm = ConversationStateMachine(state_data=state_data, transition_data=transition_data, init_state_path='start')

In [7]:
csm.print_state_hierarchy()

[CSM]  root
[CSM]    start
[CSM]      start_select-ready
[CSM]    notready
[CSM]      notready_select-get-set-state
[CSM]        notready_select-get-set-state_get-state
[CSM]          notready_select-get-set-state_get-state_select-tool
[CSM]            notready_select-get-set-state_get-state_select-tool_compose-question
[CSM]            notready_select-get-set-state_get-state_select-tool_compose-python
[CSM]            notready_select-get-set-state_get-state_select-tool_compose-powershell
[CSM]            notready_select-get-set-state_get-state_select-tool_compose-screenshot
[CSM]        notready_select-get-set-state_set-state
[CSM]          notready_select-get-set-state_set-state_select-tool
[CSM]            notready_select-get-set-state_set-state_select-tool_compose-message
[CSM]            notready_select-get-set-state_set-state_select-tool_compose-python
[CSM]            notready_select-get-set-state_set-state_select-tool_compose-powershell
[CSM]            notready_select-get-set-

In [8]:
csm.visualize()

In [9]:
csm.print_current_state()

[CSM] self.current_state: <__main__.ConversationState object at 0x000001C9DE1C2440>
[CSM] self.current_state.get_hpath(): start


In [10]:
csm.current_state.transitions

{'selectready': <__main__.ConversationState at 0x1c9de1c3970>}

In [11]:
csm.transition("selectready")

<__main__.ConversationState at 0x1c9de1c3970>

In [12]:
csm.print_current_state()

[CSM] self.current_state: <__main__.ConversationState object at 0x000001C9DE1C3970>
[CSM] self.current_state.get_hpath(): start_select-ready


In [13]:
csm.current_state.system

'system_start_select-ready.md'

In [14]:
csm.current_state.messages

[{'role': 'user',
  'content': [{'type': 'text',
    'text': 'Do you need to get any additional information or context from the user or system or do you need to configure the system into a specific state to perform the task "{task}"? Why or why not?'}]}]

In [15]:
csm.current_state.messages_frmtstrs["dynamic_cow"] = "dynamic moo"

In [16]:
csm.current_state.messages[0]["content"][0]["text"].format(**csm.current_state.messages_frmtstrs)

KeyError: 'task'