In [1]:
from typing import Annotated, Sequence

from typing_extensions import TypedDict

from langchain import chat_models
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory


from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph

from IPython.display import Image, display

import json
import os

In [2]:
# Export every key/value pair from the local secrets file as an environment
# variable. NOTE(review): presumably this supplies the OpenAI API key needed by
# the model_provider="openai" chat model created below — confirm.
# os.environ only accepts str values, so every value in the file must be a string.
with open("secrets.json", encoding="utf-8") as secrets_file:
    secrets = json.load(secrets_file)
for key, value in secrets.items():
    os.environ[key] = value

In [3]:
class Message:
    """A single turn in the two-party conversation.

    Attributes:
        role_id: Id of the ``Role`` that authored this turn.
        content: Text of the turn.
    """

    role_id: int
    content: str

    def __init__(self, role_id: int, content: str):
        self.role_id = role_id
        self.content = content

    def __repr__(self) -> str:
        return f"Message(role_id={self.role_id}, content={self.content})"

    def to_human_message(self) -> HumanMessage:
        """Wrap this turn's content in a LangChain ``HumanMessage``."""
        return HumanMessage(content=self.content)

    def to_ai_message(self) -> AIMessage:
        """Wrap this turn's content in a LangChain ``AIMessage``."""
        return AIMessage(content=self.content)

    def to_message(self, cur_role: int) -> BaseMessage:
        """Render this turn from the point of view of role ``cur_role``.

        Args:
            cur_role: Id of the role whose model is about to be prompted.

        Returns:
            An ``AIMessage`` when ``cur_role`` authored this turn, otherwise a
            ``HumanMessage``.
        """
        # The prompted model must see its own earlier turns as assistant output;
        # only the partner's turns are "human" input.
        if cur_role == self.role_id:
            return self.to_ai_message()
        return self.to_human_message()

In [4]:
class Role:
    """One participant in the conversation, backed by a chat model.

    Attributes:
        id: Identifier compared against ``Message.role_id`` to tell this role's
            own turns apart from its partner's.
        system_message: System prompt placed at the start of every request.
        model: Chat model used to generate this role's replies.
    """

    id: int
    system_message: str
    model: BaseChatModel

    def __init__(self, id: int, system_message: str, model: BaseChatModel):
        self.id = id
        self.system_message = system_message
        self.model = model

    def create_history(self, messages: Sequence[Message]) -> Sequence[BaseMessage]:
        """Build the model input for this role.

        Args:
            messages: The conversation transcript so far.

        Returns:
            A list holding the system prompt followed by every transcript turn
            rendered from this role's point of view (``Message.to_message``).
        """
        history: list[BaseMessage] = [SystemMessage(content=self.system_message)]
        history.extend(message.to_message(self.id) for message in messages)
        return history

    def advance(self, messages: Sequence[Message]) -> Message:
        """Generate this role's next turn given the transcript so far.

        Args:
            messages: The conversation transcript so far.

        Returns:
            A new ``Message`` authored by this role.

        Raises:
            TypeError: if the model reply's content is not a plain string.
                LangChain message content may also be a list of content parts,
                which this method does not flatten.
        """
        history = self.create_history(messages)
        response = self.model.invoke(history)
        if not isinstance(response.content, str):
            # invoke() returns one message; non-str content means structured
            # content parts, not "multiple responses".
            raise TypeError(
                "Expected string content from the model, got "
                f"{type(response.content).__name__}"
            )
        return Message(self.id, response.content)

In [5]:
# System prompts that give each role its private motive in the conversation.
role0_system_prompt = "You are in a conversation with a friend you have a slight crush on"
role1_system_prompt = "You are in a conversation with a friend you need to ask for money"

# Intended opening line spoken by role 0 to start the conversation.
role0_start_message = "How are you doing?"

In [6]:
# Module-level conversation state: bound by reset(), extended by advance().
role0: Role
role1: Role
messages: list[Message]

# A single chat model instance, shared by both roles when reset() builds them.
model = chat_models.init_chat_model(model="gpt-4o-mini", model_provider="openai")


def reset():
    """Start a fresh conversation.

    Rebuilds both roles around the shared ``model`` and rebinds the module-level
    ``role0``, ``role1`` and ``messages`` so the transcript holds only role 0's
    opening line.
    """
    global role0, role1, messages
    role0 = Role(0, role0_system_prompt, model)
    role1 = Role(1, role1_system_prompt, model)
    # Use the configured opener rather than a duplicated literal, so editing
    # role0_start_message actually changes how the conversation starts.
    messages = [Message(role0.id, role0_start_message)]


def advance():
    """Produce one more turn, then print the whole transcript.

    The role that did *not* author the latest message speaks next. Its reply
    is appended to the module-level ``messages``. Afterwards every turn is
    printed as ``<role_id>: <content>``.
    """
    global messages
    # Turns alternate: after role 0 it is role 1's turn, otherwise role 0's.
    speaker = role1 if messages[-1].role_id == 0 else role0
    reply = speaker.advance(messages)
    messages += [reply]
    for turn in messages:
        print(f"{turn.role_id}: {turn.content}")


In [7]:
# Build both roles and seed the transcript with role 0's opening line.
reset()

In [None]:
# Generate the next turn and print the full transcript; re-run this cell to continue.
advance()