[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/langchain-ai/langchain-academy/blob/main/module-4/map-reduce.ipynb) [![Open in LangChain Academy](https://cdn.prod.website-files.com/65b8cd72835ceeacd4449a53/66e9eba12c7b7688aa3dbb5e_LCA-badge-green.svg)](https://academy.langchain.com/courses/take/intro-to-langgraph/lessons/58239947-lesson-3-map-reduce)

# Map-reduce

Now we'll look at a LangGraph-based implementation of [map-reduce](https://langchain-ai.github.io/langgraph/how-tos/map-reduce/).

In [11]:
%%capture --no-stderr
%pip install -U langchain_openai langgraph

In [12]:
import os, getpass

def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

_set_env("OPENAI_API_KEY")

## MapReduce

Map-reduce operations are essential for efficient task decomposition and parallel processing. 

Map-reduce has two phases:\
(1) `Map` - Break a task into smaller sub-tasks, processing each sub-task in parallel.\
(2) `Reduce` - Aggregate the results across all of the completed, parallelized sub-tasks.
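
As a plain-Python illustration of the two phases (independent of LangGraph), the sketch below maps a hypothetical `word_count` function over a few strings in a thread pool and then reduces the partial results into one total. The function names and the sample data are assumptions made for illustration only.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
import operator

def word_count(text: str) -> int:
    # Map step: process one sub-task independently of the others
    return len(text.split())

def map_reduce(chunks: list[str]) -> int:
    # Map: run the sub-tasks in parallel; executor.map preserves input order
    with ThreadPoolExecutor() as executor:
        partial_counts = list(executor.map(word_count, chunks))
    # Reduce: aggregate the partial results into a single value
    return reduce(operator.add, partial_counts, 0)

chunks = ["map reduce splits work", "then combines", "the results"]
total_words = map_reduce(chunks)
print(total_words)
```

The joke pipeline in this notebook follows the same shape: the map step produces one joke per subject, and the reduce step picks one joke from the collected list.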

Let's design a system that will do two things:\
(1) `Map` - Create a set of jokes about a topic.\
(2) `Reduce` - Pick the best joke from the list.

We'll use an LLM to do the joke generation and selection.

In [14]:
from langchain_openai import ChatOpenAI

# Prompts we will use
subjects_prompt = """Generate a list of 3 sub-topics that are all related to this overall topic: {topic}."""
joke_prompt = """Generate a joke about {subject}"""
best_joke_prompt = """Below are a bunch of jokes about {topic}. Select the best one! Return the ID of the best one, starting with 0 as the ID for the first joke. Jokes: \n\n  {jokes}"""

# LLM
model = ChatOpenAI(model="gpt-4o", temperature=0) 

## State

### Parallelizing joke generation

First, let's define the entry point of the graph that will:

* Take a user input topic
* Produce a list of joke topics from it
* Send each joke topic to our joke generation node

Our state has a `jokes` key, which will accumulate jokes from parallelized joke generation.

In [32]:
import operator
from typing import Annotated
from typing_extensions import TypedDict
from pydantic import BaseModel

# General state of the graph
class OverallState(TypedDict):
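    # Overall topic, e.g. "animals"
    topic: str
    # Subjects derived from the topic, e.g. ["elephant", "tiger", "lion"]
    subjects: list
    # Jokes about each subject; operator.add concatenates the lists returned by parallel nodes
    jokes: Annotated[list, operator.add]
    # Best joke selected from the list above
    best_selected_joke: str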
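
To see what the `operator.add` annotation means for `jokes`, the sketch below applies updates to a state dict by hand. It is a simplified model and not LangGraph's internals: the hypothetical `apply_update` helper combines a key's old and new values with its reducer when one is registered in the assumed `REDUCERS` table, and overwrites the value otherwise.

```python
import operator

# Hypothetical reducer table mirroring the annotation on OverallState
REDUCERS = {"jokes": operator.add}

def apply_update(state: dict, update: dict) -> dict:
    # Return a new state with `update` merged in, key by key
    merged = dict(state)
    for key, value in update.items():
        if key in REDUCERS and key in merged:
            # Reducer present: combine old and new values (list concatenation here)
            merged[key] = REDUCERS[key](merged[key], value)
        else:
            # No reducer: the new value replaces the old one
            merged[key] = value
    return merged

state = {"topic": "animals", "jokes": []}
for update in ({"jokes": ["joke A"]}, {"jokes": ["joke B"]}, {"topic": "pets"}):
    state = apply_update(state, update)
print(state)
```

Because each update returns a one-element list, concatenation is what lets several parallel writers contribute to the same key without overwriting each other.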

## Input
Generate subjects for jokes.

In [33]:
# For subject generation using structured output
class Subjects(BaseModel):
    subjects: list[str]

def generate_subjects(state: OverallState):
    # Get topic from state
    prompt = subjects_prompt.format(topic=state["topic"])
    # Generate subjects to joke about for the topic
    response = model.with_structured_output(Subjects).invoke(prompt)
    return {"subjects": response.subjects}

### Joke generation (map)

Now, we just define a node that will create our jokes, `generate_joke`!\
We write them back out to `jokes` in `OverallState`!\
This key has a reducer that will combine lists.

In [34]:
# Input state for generate joke
class JokeState(TypedDict):
    subject: str

# Model for structured output
class Joke(BaseModel):
    joke: str

def generate_joke(state: JokeState):
    # Get subject from state
    prompt = joke_prompt.format(subject=state["subject"])
    # Generate joke
    response = model.with_structured_output(Joke).invoke(prompt)
    # Append joke to overall state
    return {"jokes": [response.joke]}

## Send subject
Here is the magic: we use [Send](https://langchain-ai.github.io/langgraph/concepts/low_level/#send) to create a joke for each subject.

This is very useful! It can automatically parallelize joke generation for any number of subjects.

* `generate_joke`: the name of the node in the graph
* `{"subject": s}`: the state to send

`Send` allows you to pass any state that you want to `generate_joke`! It does not have to align with `OverallState`.

In this case, `generate_joke` is using its own internal state, and we can populate this via `Send`.

In [27]:
from langgraph.types import Send

# Conditional edge
def continue_to_jokes(state: OverallState):
    # Return one Send per subject; each Send runs generate_joke with its own JokeState
    return [Send("generate_joke", JokeState(subject=s)) for s in state["subjects"]]
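
The fan-out that `continue_to_jokes` requests can be mimicked without LangGraph. In the sketch below, `FakeSend` is a hypothetical stand-in for `Send` (a target node name plus the payload for that node), `fake_generate_joke` is a placeholder for the real LLM-backed node, and `run_sends` is an assumed dispatcher that runs one node call per packet in a thread pool. It only illustrates the idea that each packet carries its own small state; it makes no claim about how LangGraph schedules the work internally.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, NamedTuple

class FakeSend(NamedTuple):
    # Hypothetical stand-in for Send: which node to run and with what state
    node: str
    arg: dict[str, Any]

def fake_generate_joke(state: dict[str, Any]) -> dict[str, list[str]]:
    # Placeholder node: a real node would call the LLM here
    return {"jokes": [f"A joke about {state['subject']}"]}

# Hypothetical node registry, keyed by node name
NODES: dict[str, Callable[[dict[str, Any]], dict[str, list[str]]]] = {
    "generate_joke": fake_generate_joke,
}

def continue_to_jokes_sketch(state: dict[str, Any]) -> list[FakeSend]:
    # One packet per subject; the payload only needs the keys the node reads
    return [FakeSend("generate_joke", {"subject": s}) for s in state["subjects"]]

def run_sends(sends: list[FakeSend]) -> list[str]:
    # Run every packet's node in parallel, then concatenate the joke lists
    with ThreadPoolExecutor() as executor:
        updates = list(executor.map(lambda s: NODES[s.node](s.arg), sends))
    return [joke for update in updates for joke in update["jokes"]]

jokes = run_sends(continue_to_jokes_sketch({"subjects": ["lions", "tigers"]}))
print(jokes)
```

Note that the payload here contains only `subject`, which matches the point made above: the state sent to the node does not have to align with the overall state.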

### Best joke selection (reduce)

Now, we add logic to pick the best joke.

In [35]:
# Model for structured output
class BestJoke(BaseModel):
    id: int

def best_joke(state: OverallState):
    # Get jokes from state
    jokes = "\n\n".join(state["jokes"])
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=jokes)
    # Select best joke
    response = model.with_structured_output(BestJoke).invoke(prompt)
    return {"best_selected_joke": state["jokes"][response.id]}
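
Two details of `best_joke` may be worth guarding in practice: the prompt asks for an ID although the jokes are joined without visible IDs, and the ID returned by the model could fall outside the list. The sketch below shows one possible way to handle both. `format_jokes_with_ids` and `select_joke` are hypothetical helpers, and falling back to the first joke is an assumption rather than something the node above does.

```python
def format_jokes_with_ids(jokes: list[str]) -> str:
    # Prefix each joke with its zero-based ID so the model can refer to it
    return "\n\n".join(f"{i}: {joke}" for i, joke in enumerate(jokes))

def select_joke(jokes: list[str], joke_id: int) -> str:
    # Return the joke at joke_id, falling back to the first joke if out of range
    if not jokes:
        raise ValueError("no jokes to choose from")
    if 0 <= joke_id < len(jokes):
        return jokes[joke_id]
    return jokes[0]

jokes = ["joke A", "joke B", "joke C"]
print(format_jokes_with_ids(jokes))
print(select_joke(jokes, 1))
print(select_joke(jokes, 7))
```

Inside `best_joke`, the formatted string could replace the plain `"\n\n".join(...)`, and `select_joke(state["jokes"], response.id)` could replace the direct index lookup.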

## Compile

In [42]:
from IPython.display import Image, display
from langgraph.graph import END, StateGraph, START

# Construct the graph: here we put everything together to construct our graph
graph = StateGraph(OverallState)
graph.add_node("generate_subjects", generate_subjects)
graph.add_node("generate_joke", generate_joke)
graph.add_node("best_joke", best_joke)

graph.add_edge(START, "generate_subjects")
graph.add_conditional_edges("generate_subjects", continue_to_jokes, ["generate_joke"])
graph.add_edge("generate_joke", "best_joke")
graph.add_edge("best_joke", END)

# Compile the graph
app = graph.compile()

# By default draw_mermaid_png() renders through the mermaid.ink web API and can time out;
# fall back to printing the Mermaid source if rendering fails
try:
    display(Image(app.get_graph().draw_mermaid_png()))
except Exception:
    print(app.get_graph().draw_mermaid())


In [41]:
# Call the graph: here we call it to generate a list of jokes
for s in app.stream({"topic": "students"}):
    print(s)

{'generate_subjects': {'subjects': ['Study Techniques and Learning Strategies', 'Mental Health and Well-being in Students', 'The Impact of Technology on Student Learning']}}
{'generate_joke': {'jokes': ['Why did the student bring a ladder to the library?\n\nBecause they heard the books on mental health were on a higher level!']}}
{'generate_joke': {'jokes': ['Why did the student bring a ladder to the computer lab?\n\nBecause they heard the cloud was where all the answers were stored!']}}
{'generate_joke': {'jokes': ['Why did the textbook break up with the highlighter?\n\nBecause it couldn\'t handle the pressure of being in a "highlighted" relationship! 📚✨']}}
{'best_joke': {'best_selected_joke': 'Why did the student bring a ladder to the library?\n\nBecause they heard the books on mental health were on a higher level!'}}
