In [2]:
%pip install websockets


Collecting websockets
  Downloading websockets-14.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.7 kB)
Downloading websockets-14.1-cp312-cp312-macosx_11_0_arm64.whl (159 kB)
Installing collected packages: websockets
Successfully installed websockets-14.1


In [4]:
import os
from typing import Dict, Any, List
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

# Agent State
class AgentState(BaseModel):
    """State object passed between the nodes of the tweet-to-blog graph.

    The nodes of ``TweetToBlogAgent`` read these fields, update them in place
    and return the same object.
    """

    tweet: str = ""  # raw tweet text supplied by the client
    tweet_category: str = ""  # LLM classification result (written by classify_tweet)
    selected_style: str = ""  # writing style chosen by the human; read by generate_blog
    draft_blog: str = ""  # generated blog text (written by generate_blog)
    # Workflow stage marker. Values assigned in this notebook: "pending",
    # "style_selection_required", "draft_generated", "review_required",
    # "published", "rejected".
    status: str = "pending"
    # Reviewer decision; publish_blog compares it case-insensitively to "approve".
    human_feedback: str = ""

# LangGraph Agent Implementation
class TweetToBlogAgent:
    """LangGraph workflow that turns a tweet into a blog post with human checkpoints.

    Nodes run in the order classify_tweet -> request_style_selection ->
    generate_blog -> request_human_review -> publish_blog. Each node updates the
    ``AgentState`` it receives and returns it.

    The two ``request_*`` nodes only set ``status``. The human input itself
    (``selected_style`` and ``human_feedback``) must be written into the state
    by the caller, either in the initial input or by pausing the graph
    (see ``build_graph``).
    """

    def __init__(self, model_name: str = "gpt-4-turbo"):
        """Create the agent.

        Args:
            model_name: OpenAI chat model used for both classification and drafting.
        """
        self.model = ChatOpenAI(model=model_name)

    def classify_tweet(self, state: AgentState) -> AgentState:
        """Ask the LLM for the tweet's category and store it in ``tweet_category``."""
        response = self.model.invoke([
            HumanMessage(content=f"Classify the following tweet into one of these categories: News, Technology, Personal, Entertainment, Sports, Politics\n\nTweet: {state.tweet}")
        ])
        # Strip surrounding whitespace so the category is clean when echoed to clients.
        state.tweet_category = response.content.strip()
        return state

    def request_style_selection(self, state: AgentState) -> AgentState:
        """Mark that a human must choose ``selected_style`` before drafting."""
        state.status = "style_selection_required"
        return state

    def generate_blog(self, state: AgentState) -> AgentState:
        """Draft a blog post about ``state.tweet`` in ``state.selected_style``."""
        style_prompt = f"Write a blog post in {state.selected_style} style about the following tweet: {state.tweet}"
        response = self.model.invoke([
            HumanMessage(content=style_prompt)
        ])
        state.draft_blog = response.content
        state.status = "draft_generated"
        return state

    def request_human_review(self, state: AgentState) -> AgentState:
        """Mark that the draft awaits a human decision in ``human_feedback``."""
        state.status = "review_required"
        return state

    def publish_blog(self, state: AgentState) -> AgentState:
        """Publish on "approve" (case/whitespace-insensitive); otherwise mark rejected."""
        if state.human_feedback.strip().lower() == "approve":
            # In a real scenario, you'd integrate with a blog publishing platform
            print(f"Blog Published: {state.draft_blog}")
            state.status = "published"
        else:
            state.status = "rejected"
        return state

    def build_graph(self, checkpointer=None, interrupt_before=None):
        """Build and compile the workflow graph.

        Args:
            checkpointer: optional LangGraph checkpointer. It is required when
                ``interrupt_before`` is given, so that a paused run can be resumed.
            interrupt_before: optional node names to pause before (or ``"*"``).
                For example, ``["generate_blog", "publish_blog"]`` pauses so a
                human can supply the style and the review decision.

        Returns:
            The compiled graph. With both arguments left as ``None`` this is a
            plain compiled graph, as before.
        """
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("classify_tweet", self.classify_tweet)
        workflow.add_node("request_style_selection", self.request_style_selection)
        workflow.add_node("generate_blog", self.generate_blog)
        workflow.add_node("request_human_review", self.request_human_review)
        workflow.add_node("publish_blog", self.publish_blog)

        # Add edges
        workflow.set_entry_point("classify_tweet")
        workflow.add_edge("classify_tweet", "request_style_selection")
        workflow.add_edge("request_style_selection", "generate_blog")
        workflow.add_edge("generate_blog", "request_human_review")
        workflow.add_edge("request_human_review", "publish_blog")

        # Looping back for another review after a rejection only makes sense when
        # the run pauses before publish_blog, so a human can change the feedback.
        # Otherwise publish_blog would see the same feedback on every pass, and
        # the graph would cycle until LangGraph's recursion limit aborts it.
        pauses_for_review = interrupt_before == "*" or "publish_blog" in (interrupt_before or ())
        workflow.add_conditional_edges(
            "publish_blog",
            lambda state: state.status,
            {
                "published": END,
                "rejected": "request_human_review" if pauses_for_review else END,
            },
        )

        return workflow.compile(checkpointer=checkpointer, interrupt_before=interrupt_before)

# FastAPI WebSocket Server
class WebSocketServer:
    """Expose ``TweetToBlogAgent`` over a WebSocket endpoint at ``/process``.

    JSON message protocol:

    * client -> ``{"tweet": ...}`` starts a run.
    * server -> ``{"status", "tweet_category", "draft_blog"}`` after every graph step.
    * server -> ``{"status": "awaiting_style_selection", "tweet_category"}``;
      client replies ``{"selected_style": ...}``.
    * server -> ``{"status": "awaiting_review", "draft_blog"}``;
      client replies ``{"feedback": ...}`` ("approve" publishes).
    * server -> ``{"status", "message": "Blog process completed"}`` when the run ends.
    """

    # Nodes the interactive graph pauses before, so a human can supply their input.
    _STYLE_NODE = "generate_blog"
    _REVIEW_NODE = "publish_blog"

    def __init__(self):
        import uuid

        from langgraph.checkpoint.memory import MemorySaver

        self.app = FastAPI()
        self.agent = TweetToBlogAgent()
        self.graph = self.agent.build_graph()
        # A plain compiled graph runs straight through, so values sent by the
        # client could never reach it. Recompile the same topology with a
        # checkpointer and interrupts. Each run then stops before the node that
        # consumes human input and is resumed after `aupdate_state`.
        self._interactive_graph = self.graph.builder.compile(
            checkpointer=MemorySaver(),
            interrupt_before=[self._STYLE_NODE, self._REVIEW_NODE],
        )

        @self.app.websocket("/process")
        async def websocket_endpoint(websocket: WebSocket):
            await websocket.accept()
            try:
                while True:
                    # Receive initial tweet
                    data = await websocket.receive_json()
                    tweet = data.get("tweet") if isinstance(data, dict) else None
                    if not tweet:
                        await websocket.send_json({
                            "status": "error",
                            "message": "Expected a JSON object with a 'tweet' field",
                        })
                        continue

                    # A fresh checkpoint thread per tweet keeps runs independent.
                    config = {"configurable": {"thread_id": str(uuid.uuid4())}}
                    final_status = await self._run_session(websocket, tweet, config)

                    # Final publication status
                    await websocket.send_json({
                        "status": final_status,
                        "message": "Blog process completed",
                    })

            except WebSocketDisconnect:
                print("WebSocket connection closed")

    @staticmethod
    def _to_state(values) -> AgentState:
        """Coerce a graph state payload (dict or AgentState) into ``AgentState``."""
        if isinstance(values, AgentState):
            return values
        return AgentState(**values)

    async def _run_session(self, websocket: WebSocket, tweet: str, config: dict) -> str:
        """Run one tweet through the interactive graph, pausing for human input.

        Args:
            websocket: the accepted client connection.
            tweet: tweet text used as the initial state.
            config: LangGraph config carrying this run's ``thread_id``.

        Returns:
            The ``status`` of the final state once the graph reaches END.
        """
        graph = self._interactive_graph
        run_input = {"tweet": tweet}
        while True:
            # stream_mode="values" yields the full state after each step.
            async for values in graph.astream(run_input, config, stream_mode="values"):
                step_state = self._to_state(values)
                await websocket.send_json({
                    "status": step_state.status,
                    "tweet_category": step_state.tweet_category,
                    "draft_blog": step_state.draft_blog,
                })

            snapshot = await graph.aget_state(config)
            current_state = self._to_state(snapshot.values)

            if self._STYLE_NODE in snapshot.next:
                await websocket.send_json({
                    "status": "awaiting_style_selection",
                    "tweet_category": current_state.tweet_category,
                })
                # Wait for human style selection
                style_data = await websocket.receive_json()
                await graph.aupdate_state(config, {"selected_style": style_data["selected_style"]})
            elif self._REVIEW_NODE in snapshot.next:
                await websocket.send_json({
                    "status": "awaiting_review",
                    "draft_blog": current_state.draft_blog,
                })
                # Wait for human review
                review_data = await websocket.receive_json()
                await graph.aupdate_state(config, {"human_feedback": review_data["feedback"]})
            else:
                # Nothing pending: the run reached END.
                return current_state.status

            # Streaming with None input resumes the paused run from its checkpoint.
            run_input = None

    def run(self, host="localhost", port=8000):
        """Serve ``self.app`` with uvicorn.

        With no running event loop this blocks until shutdown, like ``uvicorn.run``.
        Inside a running loop (e.g. a Jupyter kernel), ``uvicorn.run`` fails with
        "asyncio.run() cannot be called from a running event loop". In that case
        the server is scheduled on the existing loop and its task is returned.
        """
        import asyncio

        server = uvicorn.Server(uvicorn.Config(self.app, host=host, port=port))
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            server.run()
            return None
        return loop.create_task(server.serve())

ImportError: cannot import name 'CheckpointAt' from 'langgraph.checkpoint.base' (/opt/homebrew/Caskroom/miniconda/base/envs/postbot/lib/python3.12/site-packages/langgraph/checkpoint/base/__init__.py)

In [3]:
import os
from datetime import datetime
from typing import Annotated

from langchain_anthropic import ChatAnthropic
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph.message import AnyMessage, add_messages
from typing_extensions import TypedDict


class State(TypedDict):
    """Graph state for the support-assistant graph (read by ``Assistant``)."""

    # Conversation history. add_messages is the reducer LangGraph applies to
    # updates of this key; it merges new messages into the existing list
    # rather than replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Presumably fills the {user_info} slot of the assistant prompt — confirm.
    user_info: str


class Assistant:
    """Graph node that calls an LLM runnable and re-prompts on empty replies.

    Args:
        runnable: prompt | LLM pipeline invoked with the current graph state.
        max_empty_retries: how many times to re-prompt after an empty reply
            before giving up and returning the empty reply as-is. This prevents
            an unbounded (and unbounded-cost) retry loop.
    """

    def __init__(self, runnable: Runnable, max_empty_retries: int = 3):
        self.runnable = runnable
        self.max_empty_retries = max_empty_retries

    @staticmethod
    def _is_empty_response(result) -> bool:
        """Return True when the reply has no tool calls and no usable text."""
        if result.tool_calls:
            return False
        content = result.content
        if not content:
            return True
        if isinstance(content, list):
            first = content[0]
            # Content blocks may be plain strings or dicts such as
            # {"type": "text", "text": ...}; calling .get on a str would raise.
            if isinstance(first, dict):
                return not first.get("text")
            return not first
        return False

    def __call__(self, state: State, config: RunnableConfig):
        """Invoke the runnable, re-prompting while the reply is empty.

        Args:
            state: current graph state; its ``messages`` are extended locally
                with a nudge on each retry. The caller's dict is not mutated.
            config: run config from LangGraph, forwarded so callbacks and
                tracing attach to this node.

        Returns:
            ``{"messages": result}`` — a partial state update with the model reply.
        """
        result = self.runnable.invoke(state, config)
        retries = 0
        while self._is_empty_response(result) and retries < self.max_empty_retries:
            # If the LLM returns an empty response, re-prompt it for an actual response.
            messages = state["messages"] + [("user", "Respond with a real output.")]
            state = {**state, "messages": messages}
            result = self.runnable.invoke(state, config)
            retries += 1
        return {"messages": result}


# Choose the chat model. The default is Claude via Anthropic. Set the env var
# LLM_PROVIDER=groq to use Llama 3.3 70B through Groq's OpenAI-compatible
# endpoint instead; that also requires GROQ_API_KEY. Previously both clients
# were built and the Groq one was immediately overwritten, so GROQ_API_KEY was
# required even though it was never used.
# You can swap models (e.g. ChatOpenAI(model="gpt-4-turbo-preview")), though
# you may need to update the prompts.
LLM_PROVIDER = os.environ.get("LLM_PROVIDER", "anthropic").lower()

if LLM_PROVIDER == "groq":
    llm = ChatOpenAI(
        model="llama-3.3-70b-versatile",
        temperature=0.5,
        api_key=os.environ["GROQ_API_KEY"],
        base_url="https://api.groq.com/openai/v1/",
    )
else:
    # Haiku ("claude-3-haiku-20240307") is faster and cheaper, but less accurate.
    # NOTE(review): confirm "claude-3-sonnet-20240229" is still served upstream.
    llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=1)

# System instructions for the support assistant. {user_info} comes from the
# values the prompt is invoked with (Assistant passes the graph state). {time}
# is bound below via .partial() to the callable datetime.now, so the time is
# taken when the prompt is rendered rather than when this cell runs.
_assistant_system_prompt = (
    "You are a helpful customer support assistant for Swiss Airlines. "
    " Use the provided tools to search for flights, company policies, and other information to assist the user's queries. "
    " When searching, be persistent. Expand your query bounds if the first search returns no results. "
    " If a search comes up empty, expand your search before giving up."
    "\n\nCurrent user:\n<User>\n{user_info}\n</User>"
    "\nCurrent time: {time}."
)

_assistant_template = ChatPromptTemplate.from_messages(
    [("system", _assistant_system_prompt), ("placeholder", "{messages}")]
)
assistant_prompt = _assistant_template.partial(time=datetime.now)


# Read-only tools (search, lookups, retrievers) can run without asking the user.
part_3_safe_tools = [
    TavilySearchResults(max_results=1),
    fetch_user_flight_information,
    search_flights,
    lookup_policy,
    search_car_rentals,
    search_hotels,
    search_trip_recommendations,
]

# Tools that change the user's reservations; the user keeps control over these
# decisions, so they are kept in a separate list from the safe tools.
part_3_sensitive_tools = [
    update_ticket_to_new_flight,
    cancel_ticket,
    book_car_rental,
    update_car_rental,
    cancel_car_rental,
    book_hotel,
    update_hotel,
    cancel_hotel,
    book_excursion,
    update_excursion,
    cancel_excursion,
]
sensitive_tool_names = {tool.name for tool in part_3_sensitive_tools}

# The LLM sees one flat tool list. It does not need to know which graph node
# handles a call; it just invokes functions. Presumably sensitive_tool_names
# is used by a router elsewhere — confirm.
part_3_assistant_runnable = assistant_prompt | llm.bind_tools(
    [*part_3_safe_tools, *part_3_sensitive_tools]
)

RuntimeError: asyncio.run() cannot be called from a running event loop