In [1]:
import html
from typing import List, Union, Any

from pydantic import BaseModel
from pydantic import BaseModel, Field



In [2]:
class Attribute(BaseModel):
    # One named characteristic of an entity together with its current value.
    # Documented with comments rather than a docstring: pydantic would turn a
    # class docstring into the schema description that is sent to the LLM.
    name: str = Field(
        default=...,
        description=(
            "Name of the attribute, representing a specific characteristic "
            "or property of an entity."
        ),
    )
    # Typed Any, so pydantic accepts any JSON value here (str, number, None, ...).
    value: Any = Field(
        default=...,
        description=(
            "Value of the attribute, which can be of any data type, "
            "representing the state or property value of the entity."
        ),
    )

class Entity(BaseModel):
    # An object or concept in the narrative, snapshotted at one time step.
    # (Comments instead of a docstring so the schema description sent to the
    # LLM stays unchanged.)
    id: int = Field(
        default=...,
        description=(
            "Unique identifier for the entity, "
            "used for tracking and referencing."
        ),
    )
    name: str = Field(
        default=...,
        description=(
            "Name of the entity, representing an object or concept "
            "within the GOAP framework."
        ),
    )
    attributes: List[Attribute] = Field(
        default=...,
        description=(
            "List of attributes of the entity, detailing the characteristics "
            "and properties associated with the entity."
        ),
    )
    # Presumably expected to match the enclosing TimeStepEntity.time_step — not enforced.
    time_step: int = Field(
        default=...,
        description=(
            "Time step at which the entity is relevant, "
            "allowing for modeling of entities that vary over time."
        ),
    )

class Action(BaseModel):
    # An interaction at one time step, linking source entities to target
    # entities by their Entity.id values. (Comments, not a docstring, so the
    # LLM-facing schema description is unchanged.)
    name: str = Field(
        default=...,
        description="Name of the action, describing the interaction or event.",
    )
    time_step: int = Field(
        default=...,
        description="Time step at which the action takes place.",
    )
    # IDs are not validated against the entities of the same time step.
    source_entity_ids: List[int] = Field(
        default=...,
        description="List of IDs of the source entities initiating the action.",
    )
    target_entity_ids: List[int] = Field(
        default=...,
        description="List of IDs of the target entities affected by the action.",
    )
    description: str = Field(
        default=...,
        description=(
            "Description of the action, "
            "providing context and details about the interaction."
        ),
    )

class TimeStepEntity(BaseModel):
    # Everything that holds at a single time step: entity snapshots plus the
    # actions taken at that step. (Comments, not a docstring, to keep the
    # LLM-facing schema description unchanged.)
    time_step: int = Field(
        default=...,
        description="Time step at which the entities' states and actions are relevant.",
    )
    entities: List[Entity] = Field(
        default=...,
        description=(
            "List of entities present at this time step, "
            "each with their respective attributes."
        ),
    )
    actions: List[Action] = Field(
        default=...,
        description="List of actions occurring at this time step, involving the entities.",
    )

class EntitiesExtraction(BaseModel):
    # Top-level response model passed to ask_ai: the whole narrative as a list
    # of per-time-step snapshots. (Comment, not a docstring, so the schema
    # description the LLM sees is unchanged.)
    time_step_entities: List[TimeStepEntity] = Field(
        default=...,
        description=(
            "A list capturing entities, their attributes, "
            "and actions across different time steps."
        ),
    )



# System prompt for the extraction call. ask_ai sends the narrative itself as
# the user message, so the prompt refers to it instead of carrying a
# placeholder. The example matches the pydantic schema (Entity requires
# time_step), and its code fence is closed so that anything appended after
# this prompt is not read as part of the example.
system_prompt = """
Analyze the text provided in the user message and extract the entities, their attributes, and the actions occurring at each time step. Populate one TimeStepEntity per time step and return them together in an EntitiesExtraction. Follow these steps:

1. **Identify Entities and Attributes**: For each time step in the text, identify all the entities present and their attributes. Assign each entity a unique ID for reference, and keep using that same ID for the entity at every later time step.

2. **Extract Actions**: Identify any actions that occur at each time step. These actions should be associated with the entities involved, either as initiators (source) or receivers (target) of the action.

3. **Populate TimeStepEntity**: For each time step identified, create a TimeStepEntity object that includes both the list of entities (with their attributes) and the list of actions occurring at that time.

4. **Sequential Organization**: Organize the TimeStepEntity objects sequentially, starting from the earliest time step in the narrative.

Your goal is to create a structured representation of the narrative, capturing the dynamic interplay of entities and actions over time.

Example of a correctly formatted response for a given time step:

```python
TimeStepEntity(
    time_step=0,
    entities=[
        Entity(id=1, name="Entity1", attributes=[Attribute(name="Attribute1", value="Value1")], time_step=0),
        Entity(id=2, name="Entity2", attributes=[Attribute(name="Attribute2", value="Value2")], time_step=0)
    ],
    actions=[
        Action(name="Action1", time_step=0, source_entity_ids=[1], target_entity_ids=[2], description="Description of Action1")
    ]
)
```
"""


In [9]:
from graphviz import Digraph
from typing import List, Any

# Assuming the classes Entity, Attribute, Action, and TimeStepEntity are defined as provided

def generate_html_label(entity: "Entity") -> str:
    """Build a Graphviz HTML-like label: a table with an entity's name and attributes.

    The first row shows ``entity.name`` in bold across two columns; each
    following row is ``attribute name | attribute value``.

    All text is HTML-escaped. Names and values come from the LLM response and
    may contain ``&``, ``<`` or ``>`` (e.g. "salt & sugar"), which would
    otherwise break the HTML-like label syntax and make rendering fail.

    Args:
        entity: Object exposing ``name`` and an iterable ``attributes`` whose
            items expose ``name`` and ``value`` (normally an ``Entity``).

    Returns:
        The label wrapped in ``<...>`` so graphviz treats it as HTML-like
        rather than quoting it as plain text.
    """
    def cell_text(raw):
        # str() first: attribute values are typed Any (numbers, None, ...).
        return html.escape(str(raw), quote=False)

    table_rows = "".join(
        f"<tr><td>{cell_text(attr.name)}</td><td>{cell_text(attr.value)}</td></tr>"
        for attr in entity.attributes
    )
    header = f"<tr><td colspan='2'><b>{cell_text(entity.name)}</b></td></tr>"
    return f"<<table border='0' cellborder='1' cellspacing='0'>{header}{table_rows}</table>>"

def generate_static_graph(data: "EntitiesExtraction", filename: str = "entity_dynamics.gv", view: bool = True):
    """Render every time step collapsed into a single entity/action graph.

    Each entity id becomes one node, labelled with the attribute table of its
    latest occurrence. Every action adds one edge per (source, target) pair,
    labelled with the action name, its time step and its description.

    Args:
        data: The extraction to draw.
        filename: DOT source file passed to ``Digraph.render``; defaults to the
            previous hard-coded name.
        view: Whether ``render`` should open the rendered output in a viewer.

    Returns:
        The rendered ``graphviz.Digraph`` (useful for inline display in Jupyter).
    """
    dot = Digraph(comment="Entity Dynamics Graph", node_attr={"shape": "plaintext"})

    # Re-declaring a DOT node overrides its earlier label anyway, so keep only
    # each entity's latest occurrence and emit every node once. Dicts keep
    # first-insertion order, so node order still follows first appearance.
    latest_by_id = {}
    for step in data.time_step_entities:
        for entity in step.entities:
            latest_by_id[entity.id] = entity
    for entity_id, entity in latest_by_id.items():
        dot.node(str(entity_id), generate_html_label(entity))

    # One edge per (source, target) pair of every action, across all steps.
    for step in data.time_step_entities:
        for action in step.actions:
            edge_label = f"{action.name} (Time Step {action.time_step})\n{action.description}"
            for source_id in action.source_entity_ids:
                for target_id in action.target_entity_ids:
                    dot.edge(str(source_id), str(target_id), label=edge_label)

    dot.render(filename, view=view)
    return dot


def generate_graph(data: "EntitiesExtraction", prompt: str, filename: str = "state_action_graph.gv", view: bool = True):
    """Render the extraction as one grey cluster per time step.

    Each cluster is labelled ``Time Step <n>`` and contains one node per
    entity (an attribute table from ``generate_html_label``) plus one edge per
    (source, target) pair of every action in that step, labelled with the
    action's name and description.

    Args:
        data: The extraction to draw.
        prompt: Text used as the overall graph label, placed at the top.
        filename: DOT source file passed to ``Digraph.render``; defaults to the
            previous hard-coded name.
        view: Whether ``render`` should open the rendered output in a viewer.

    Returns:
        The rendered ``graphviz.Digraph`` (useful for inline display in Jupyter).
    """
    dot = Digraph(comment="State Action Graph", node_attr={"shape": "plaintext"})

    for position, step in enumerate(data.time_step_entities):
        # Cluster and node names use the list position rather than
        # step.time_step, so two entries sharing a time_step value are not
        # merged into one cluster; only the visible label shows time_step.
        with dot.subgraph(name=f'cluster_{position}') as cluster:
            cluster.attr(label=f'Time Step {step.time_step}')
            cluster.attr(style='filled', color='lightgrey')

            for entity in step.entities:
                cluster.node(f'entity_{entity.id}_{position}', generate_html_label(entity))

            for action in step.actions:
                edge_label = f"{action.name}\n{action.description}"
                for source_id in action.source_entity_ids:
                    for target_id in action.target_entity_ids:
                        cluster.edge(
                            f'entity_{source_id}_{position}',
                            f'entity_{target_id}_{position}',
                            label=edge_label,
                        )

    # Overall graph label: the prompt, drawn at the top. A root graph has only
    # one label, so a separate title would just be overwritten by this one.
    # labeljust understands only 'l'/'r' (anything else centers), so it is
    # not a way to move the label to the bottom; labelloc controls that.
    dot.attr(labelloc='t', label=prompt)

    dot.render(filename, view=view)
    return dot





In [4]:
import instructor
from openai import OpenAI

# Patch the OpenAI client with instructor; this enables the `response_model`
# keyword on `client.chat.completions.create`.
# No key is hard-coded here: OpenAI() reads it from the OPENAI_API_KEY
# environment variable, which keeps credentials out of the notebook file.
client = instructor.patch(OpenAI())

def ask_ai(content, base_instructions, extra_instructions, response_model, model="gpt-4"):
    """Send one structured-extraction chat request and return the parsed reply.

    The system message is a short preamble followed by ``base_instructions``
    and ``extra_instructions``; ``content`` is sent as the user message. The
    call goes through the module-level, instructor-patched ``client``, which
    parses the reply into ``response_model`` (retrying on validation errors).

    Args:
        content: Text to analyse, sent as the user message.
        base_instructions: Main system prompt (e.g. ``system_prompt``).
        extra_instructions: Additional guidance appended after the base prompt;
            may be empty.
        response_model: Pydantic model class the reply is parsed into.
        model: OpenAI chat model name; defaults to the previously hard-coded
            "gpt-4".

    Returns:
        The ``response_model`` instance produced by instructor.
    """
    # Separate the parts with blank lines. Base prompts such as system_prompt
    # end with a code example, and plain concatenation would glue the extra
    # instructions onto it.
    system_content = "\n\n".join(
        part.strip()
        for part in ("Use these instructions:", base_instructions, extra_instructions)
        if part and part.strip()
    )
    return client.chat.completions.create(
        model=model,
        response_model=response_model,
        messages=[
            {"role": "system", "content": system_content},
            {"role": "user", "content": content},
        ],
    )  # type: ignore


In [5]:
# Example narrative to extract entities and actions from.
content = "Margaret decided to cook a chocolate cake. She prepared the ingredients and mixed them together. She then put the cake in the oven and waited for it to bake. After 30 minutes, she took the cake out of the oven and let it cool. She then decorated the cake with icing and sprinkles. Margaret was very happy with how the cake turned out."

# Extra guidance appended after system_prompt for this run.
EXTRA_INSTRUCTIONS = (
    "remember to characterize each entity by their attributes, "
    "and be sure that causally relevant actions always intervene over some measured attribute of the target entity. "
    "Remember that each entity can take a single action for each time step"
)

response = ask_ai(content, system_prompt, EXTRA_INSTRUCTIONS, EntitiesExtraction)

Retrying, exception: 10 validation errors for EntitiesExtraction
time_step_entities.0.entities.0.attributes.0
  Input should be an object [type=model_type, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type
time_step_entities.0.entities.1.attributes.0
  Input should be an object [type=model_type, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type
time_step_entities.1.entities.0.attributes.0
  Input should be an object [type=model_type, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type
time_step_entities.1.entities.1.attributes.0
  Input should be an object [type=model_type, input_value='Ingredients Prepared', input_type=str]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type
time_step_entities.2.entities.0.attributes.0
  Input should be an object [type=model_type, inpu

In [12]:
# Display the structured EntitiesExtraction returned by ask_ai above.
response

EntitiesExtraction(time_step_entities=[TimeStepEntity(time_step=0, entities=[Entity(id=1, name='Margaret', attributes=[Attribute(name='state', value='None')], time_step=0), Entity(id=2, name='Chocolate Cake', attributes=[Attribute(name='state', value='None')], time_step=0)], actions=[Action(name='Decide to cook', time_step=0, source_entity_ids=[1], target_entity_ids=[2], description='Margaret decided to cook a chocolate cake.')]), TimeStepEntity(time_step=1, entities=[Entity(id=1, name='Margaret', attributes=[Attribute(name='state', value='None')], time_step=1), Entity(id=2, name='Chocolate Cake', attributes=[Attribute(name='state', value='Ingredients Prepared')], time_step=1)], actions=[Action(name='Prepare and mix ingredients', time_step=1, source_entity_ids=[1], target_entity_ids=[2], description='Margaret prepared the ingredients and mixed them together.')]), TimeStepEntity(time_step=2, entities=[Entity(id=1, name='Margaret', attributes=[Attribute(name='state', value='None')], time

In [26]:
from graphviz import Digraph

def generate_html_label(entity: "Entity") -> str:
    """Build a Graphviz HTML-like label: a table with an entity's name and attributes.

    The first row shows ``entity.name`` in bold across two columns; each
    following row is ``attribute name | attribute value``.

    All text is HTML-escaped. Names and values come from the LLM response and
    may contain ``&``, ``<`` or ``>`` (e.g. "salt & sugar"), which would
    otherwise break the HTML-like label syntax and make rendering fail.

    Args:
        entity: Object exposing ``name`` and an iterable ``attributes`` whose
            items expose ``name`` and ``value`` (normally an ``Entity``).

    Returns:
        The label wrapped in ``<...>`` so graphviz treats it as HTML-like
        rather than quoting it as plain text.
    """
    def cell_text(raw):
        # str() first: attribute values are typed Any (numbers, None, ...).
        return html.escape(str(raw), quote=False)

    table_rows = "".join(
        f"<tr><td>{cell_text(attr.name)}</td><td>{cell_text(attr.value)}</td></tr>"
        for attr in entity.attributes
    )
    header = f"<tr><td colspan='2'><b>{cell_text(entity.name)}</b></td></tr>"
    return f"<<table border='0' cellborder='1' cellspacing='0'>{header}{table_rows}</table>>"

def generate_graph(data: "EntitiesExtraction", prompt: str, filename: str = "state_action_graph.gv", view: bool = True):
    """Render the extraction top-to-bottom, one grey cluster per time step.

    Each cluster (named by list position, labelled ``Time Step <n>``) holds one
    node per entity plus an edge per (source, target) pair of each action.
    Hidden anchor nodes chained by hidden edges stack the clusters vertically.

    Args:
        data: The extraction to draw.
        prompt: Text used as the overall graph label, placed at the top.
        filename: DOT source file passed to ``Digraph.render``; defaults to the
            previous hard-coded name.
        view: Whether ``render`` should open the rendered output in a viewer.

    Returns:
        The rendered ``graphviz.Digraph`` (useful for inline display in Jupyter).
    """
    dot = Digraph(comment="State Action Graph", node_attr={"shape": "plaintext"})
    dot.attr(rankdir='TB')  # Top to Bottom layout

    # Hidden anchor above the first cluster. Graphviz's style value for hiding
    # is "invis" (as the hidden edges below already use); "invisible" is not a
    # documented style value, so these anchors would be drawn as plain text.
    dot.node("start", style="invis")

    for idx, step in enumerate(data.time_step_entities, start=1):
        with dot.subgraph(name=f'cluster_{idx}') as cluster:
            cluster.attr(label=f'Time Step {step.time_step}')
            cluster.attr(style='filled', color='lightgrey')

            # Hidden per-cluster anchor used to chain clusters in order.
            cluster.node(f'invisible_{idx}', style="invis")

            for entity in step.entities:
                cluster.node(f'entity_{entity.id}_{idx}', generate_html_label(entity))

            # Hidden edge from the previous cluster's anchor keeps the order.
            if idx > 1:
                dot.edge(f'invisible_{idx-1}', f'invisible_{idx}', style='invis')

            for action in step.actions:
                edge_label = f"{action.name}\n{action.description}"
                for source_id in action.source_entity_ids:
                    for target_id in action.target_entity_ids:
                        cluster.edge(
                            f'entity_{source_id}_{idx}',
                            f'entity_{target_id}_{idx}',
                            label=edge_label,
                        )

    # Link the start anchor to the first cluster only if one exists; otherwise
    # the edge would create a stray, visible "invisible_1" node.
    if data.time_step_entities:
        dot.edge("start", "invisible_1", style="invis")

    # Overall graph label: the prompt, placed at the top (labelloc='t').
    dot.attr(labelloc='t', label=prompt)

    dot.render(filename, view=view)
    return dot


In [27]:
# Draw the per-time-step graph of the extraction, labelled with the source text.
# (generate_static_graph(response) draws the collapsed single-graph view.)
generate_graph(data=response, prompt=content)

In [8]:
# NOTE(review): repeats the call in the cell above. Its execution count (In [8])
# is lower than the cells that define generate_graph (In [9], In [26]), so its
# saved run used an earlier version of the function — re-run after those cells
# or delete this cell.
generate_graph(response,content)