In [31]:
from langchain_ollama import ChatOllama

# One local Ollama chat model shared by every graph node below.
# temperature=0 keeps generations as repeatable as the backend allows.
model = ChatOllama(
    model="llama3.1:latest",
    temperature=0,
)

# Prompt templates used by the graph nodes. Each placeholder in braces is
# filled with str.format() by the node that uses the template
# (generate_topics -> {topic}, generate_joke -> {subject},
# best_joke -> {topic} and {jokes}).
subjects_prompt = """Generate a list of 3 sub-topics that are all related to this overall topic: {topic}."""
joke_prompt = """Generate a joke about {subject}"""
# IDs are 0-based because best_joke uses the returned ID directly as a list index.
best_joke_prompt = """Below are a bunch of jokes about {topic}. Select the best one! Return the ID of the best one, where IDs start at 0 for the first joke. Jokes: \n\n{jokes}"""


In [32]:
import operator
from typing_extensions import TypedDict
from typing import Annotated
from pydantic import BaseModel

class Subjects(BaseModel):
    # Structured-output schema for generate_topics: the model's reply is parsed
    # into this object and `subjects` is copied into OverallState["subjects"].
    # Documented with comments only, so the schema handed to the model is unchanged.
    subjects:list[str]

class BestJoke(BaseModel):
    # Structured-output schema for best_joke: the model's pick, used directly
    # as an index into state["jokes"] (the prompt asks for IDs starting at 0).
    # Documented with comments only, so the schema handed to the model is unchanged.
    id:int

class OverallState(TypedDict):
    """Graph-wide state shared by generate_topics, generate_joke and best_joke."""

    # Sub-topics written by generate_topics and fanned out by continue_to_jokes.
    subjects: list[str]
    # Topic supplied by the caller in the input dict passed to app.stream().
    topic: str
    # operator.add is the reducer: each generate_joke branch returns a
    # one-element list and the results are concatenated, not overwritten.
    jokes: Annotated[list[str], operator.add]
    # Text of the joke chosen by best_joke.
    best_selected_joke: str


In [33]:
def generate_topics(state: OverallState):
    """Ask the model for sub-topics of ``state["topic"]``.

    Returns a partial state update ``{"subjects": [...]}`` taken from the
    reply parsed with the ``Subjects`` structured-output schema.
    """
    topic = state["topic"]
    request = subjects_prompt.format(topic=topic)
    structured_llm = model.with_structured_output(Subjects)
    parsed = structured_llm.invoke(request)
    return {"subjects": parsed.subjects}

In [34]:
from langgraph.constants import Send

def continue_to_jokes(state: OverallState):
    """Fan out to the ``generate_joke`` node, once per sub-topic.

    Used as the router of the conditional edge leaving ``generate_topics``.
    Each ``Send`` carries a ``{"subject": ...}`` payload matching ``JokeState``.
    """
    sends = []
    for subject in state["subjects"]:
        sends.append(Send("generate_joke", {"subject": subject}))
    return sends

In [35]:
class JokeState(TypedDict):
    """Per-branch input of generate_joke: the single sub-topic it was sent."""

    subject: str

class Joke(BaseModel):
    # Structured-output schema for generate_joke: the model's reply is parsed
    # into this object and `joke` is appended to OverallState["jokes"].
    # Documented with comments only, so the schema handed to the model is unchanged.
    joke:str

def generate_joke(state: JokeState):
    """Write one joke about ``state["subject"]``.

    Runs once per ``Send`` from ``continue_to_jokes``. The joke is returned
    inside a one-element list so the ``operator.add`` reducer on the graph's
    ``jokes`` key appends it to the jokes from the other branches.
    """
    request = joke_prompt.format(subject=state["subject"])
    structured_llm = model.with_structured_output(Joke)
    result = structured_llm.invoke(request)
    joke_text = result.joke
    return {"jokes": [joke_text]}

In [36]:
def best_joke(state: OverallState):
    """Ask the model to pick the best joke and store its text in state.

    Each joke is labelled with its 0-based index before being sent to the
    model. The ``BestJoke.id`` it returns then refers to an explicit label
    rather than a position the model has to count in an unlabelled list.

    Returns:
        ``{"best_selected_joke": <text of the chosen joke>}``.

    Raises:
        ValueError: if ``state["jokes"]`` is empty, or if the model returns an
            ID outside ``0 .. len(jokes) - 1``. Without this check a negative
            ID would silently select a joke counted from the end of the list.
    """
    jokes = state["jokes"]
    if not jokes:
        raise ValueError("best_joke needs at least one joke in state['jokes']")

    # Label each joke so the returned ID maps unambiguously to a list index.
    numbered = "\n\n".join(f"ID {i}: {joke}" for i, joke in enumerate(jokes))
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=numbered)
    response = model.with_structured_output(BestJoke).invoke(prompt)

    if not 0 <= response.id < len(jokes):
        raise ValueError(
            f"model returned joke ID {response.id}, expected 0..{len(jokes) - 1}"
        )
    return {"best_selected_joke": jokes[response.id]}

In [37]:
from langgraph.graph import StateGraph, START, END

# Builder for the map-reduce joke graph:
# generate_topics -> (one generate_joke per sub-topic) -> best_joke.
graph = StateGraph(OverallState)

for node_name, node_fn in (
    ("generate_topics", generate_topics),
    ("generate_joke", generate_joke),
    ("best_joke", best_joke),
):
    graph.add_node(node_name, node_fn)

graph.add_edge(START, "generate_topics")
# continue_to_jokes returns one Send per sub-topic. The list names the
# possible destination node(s) of this router.
graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])
for edge_from, edge_to in (("generate_joke", "best_joke"), ("best_joke", END)):
    graph.add_edge(edge_from, edge_to)

app = graph.compile()


In [40]:
# Run the graph and print what each step produces. Every chunk maps the node
# that just finished to the state update it returned: first the sub-topics,
# then one chunk per generated joke, then the selected best joke.
for chunk in app.stream({"topic": "india"}):
    print(chunk)

{'generate_topics': {'subjects': ['Indian Culture', 'History of India', 'Geography of India']}}
{'generate_joke': {'jokes': ['Why did the Ganges River go to therapy? Because it was feeling drained!']}}
{'generate_joke': {'jokes': ['Why did the Taj Mahal go to therapy? Because it had a lot of history!']}}
{'generate_joke': {'jokes': ['Why did the elephant quit the circus? Because it was tired of working for peanuts. But in India, we say that elephants are sacred and will never work for peanuts, they get paid in naan bread instead!']}}
{'best_joke': {'best_selected_joke': 'Why did the elephant quit the circus? Because it was tired of working for peanuts. But in India, we say that elephants are sacred and will never work for peanuts, they get paid in naan bread instead!'}}
