# Retrieval & Generation

#### Imports

In [None]:
import os
import uuid
import requests

from dotenv import load_dotenv

from IPython.display import Image, display

from langchain_chroma import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage

from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver

#### Environment Variables

In [None]:
# Load variables from the project-level .env file into the process environment
load_dotenv('../.env')

# Switch LangSmith tracing on for this session
os.environ['LANGSMITH_TRACING'] = 'true'

# Export the API keys under the variable names set below.
# os.environ only accepts strings, so assigning the None that os.getenv()
# returns for a missing variable would fail with an opaque TypeError; check
# explicitly and report which variable is missing instead.
for _target_name, _source_name in (
    ('LANGSMITH_API_KEY', 'LANGSMITH'),
    ('OPENAI_API_KEY', 'OPENAI_API_KEY'),
):
    _value = os.getenv(_source_name)
    if _value is None:
        raise RuntimeError(
            f"Environment variable '{_source_name}' is not set; "
            "define it in ../.env or in the shell before running this notebook."
        )
    os.environ[_target_name] = _value

#### Components

In [None]:
# Chat model used by both the routing node and the answer-generation node
llm = ChatOpenAI(
    model = 'gpt-4o-mini'
)

# Embeddings model handed to the vector store as its embedding function
embeddings = OpenAIEmbeddings(
    model = 'text-embedding-3-large'
)

# Chroma vector store persisted on disk under ../chroma_db
vector_store = Chroma(
    persist_directory = '../chroma_db',
    embedding_function = embeddings
)

#### Prompts

In [None]:
# LLM short-circuit prompt
#
# System prompt for the routing step: the model either answers directly or
# issues tool calls. Adjacent string literals are concatenated verbatim by
# Python, so every sentence carries its own trailing space; without it the
# sentences fuse together ("documents.You ..."). Tool names are written
# exactly as the tools are registered (`nasa_images`, `podcast`).

prompt_short_circuit = (
    "If the query contains factual questions, retrieve documents. "
    "You **must always call the `nasa_images` tool** to fetch relevant NASA images. "
    "If you call the `nasa_images` tool, you must also call the `podcast` tool to initialize. "
    "If the query is conversation, respond immediately. "
    "You love astronomy and to engage with curious kids! "
    "Keep your response short, and fun (five sentences max). "
    "Add single relevant emojis within the text. "
    "Always make sure to end the response with an emoji."
    "\n\n"
    "When using the `nasa_images` tool, provide only a single word as input. "
    "This word should represent the main object of the query. "
    "Such as the singular name of a celestial body or astronomical object."
    )

In [None]:
# Retrieval step prompt
#
# System prompt for the answer-generation step; the retrieved tool output is
# appended after it. Adjacent string literals are concatenated verbatim by
# Python, so every sentence carries its own trailing space to keep the
# sentences from fusing together ("kids!Use ...").

prompt_retrieval = (
    "You're a friendly and enthusiastic astronomy teacher who loves explaining space facts to curious kids! "
    "Use the following pieces of context to answer the question at the end in a fun, simple, and engaging way. "
    "Keep your explanation short, fun, and easy to understand (five sentences max). "
    "Use playful language, examples, or comparisons to make the answer exciting for kids. "
    "Always end with an encouraging phrase like ´Keep looking up!´ or ´Space is awesome, isn't it?´ to keep them excited about learning. "
    "Add single relevant emojis within the text. Always make sure to end the response with a single emoji."
    )

#### Tools

In [None]:
# Vector Store retrieval step tool
#
# NOTE: the one-line docstring is left untouched on purpose; the `@tool`
# decorator is expected to surface it to the LLM as the tool description.

@tool(response_format = 'content')
def vector_db(query: str):
    """Retrieve astronomical information chunks from chromaDB"""
    # Ask the vector store for the two chunks most similar to the query
    matches = vector_store.similarity_search(query, k = 2)

    # Render every chunk as a "Source / Content" block
    blocks = []
    for match in matches:
        blocks.append(f"Source: {match.metadata}\nContent: {match.page_content}")

    # Blank line between blocks
    return '\n\n'.join(blocks)

In [None]:
# NASA Images retrieval step tool
#
# NOTE: the one-line docstring is left untouched on purpose; the `@tool`
# decorator is expected to surface it to the LLM as the tool description.
#
# Behaviour:
#   * `search_word` is sent as the `q` parameter of the NASA Images search API.
#   * Returns a list of image URLs taken from the first 8 search hits
#     (hits without a usable link are skipped).
#   * Returns None when `search_word` equals the previous successful search,
#     so images are not fetched again while the topic stays the same.
#   * Raises `requests.RequestException` on network failure, timeout, or a
#     non-2xx HTTP status.

# Search word of the most recent successful lookup
last_search = None

@tool(response_format = 'content')
def nasa_images(search_word: str):
    """Fetch relevant space images from NASA Images Api"""
    global last_search

    # Same topic as the previous call: nothing new to fetch
    if search_word == last_search:
        return None

    BASE_URL = 'https://images-api.nasa.gov/search'
    params = {
        'q': search_word,
        'media_type': 'image',
    }
    # A timeout keeps a stalled connection from hanging the whole graph run
    response = requests.get(BASE_URL, params = params, timeout = 10)
    # Fail with a clear HTTP error instead of trying to parse an error page
    response.raise_for_status()
    data = response.json()

    # Tolerate a payload without 'collection' / 'items'
    items = (data.get('collection') or {}).get('items') or []

    # Take the first link of each of the first 8 hits, skipping hits that
    # carry no 'links' entry or whose first link has no 'href'
    image_links = []
    for item in items[:8]:
        links = item.get('links') or []
        if links and links[0].get('href'):
            image_links.append(links[0]['href'])

    # Remember the topic only after a successful fetch
    last_search = search_word

    return image_links

In [None]:
# Podcast initialization step tool
#
# NOTE: the one-line docstring is left untouched on purpose; the `@tool`
# decorator is expected to surface it to the LLM as the tool description.

# Module-level podcast state: current topic, model name, and the queries
# collected for that topic
podcast_setup = {
    'topic': None,
    'llm': 'gpt-4o-mini',
    'queries': []
}

@tool(response_format = 'content')
def podcast(search_word: str, query: str):
    """Initialize podcast"""
    same_topic = search_word == podcast_setup['topic']

    if same_topic:
        # Topic unchanged: record the additional query
        podcast_setup['queries'].append(query)
    else:
        # Topic changed: switch topic and start a fresh query list
        podcast_setup['topic'] = search_word
        podcast_setup['queries'] = [query]

    # The shared state dict itself is returned (not a copy)
    return podcast_setup

#### Nodes

In [None]:
# NODE 1: LLM decides to retrieve documents or respond immediately

def query_or_respond(state: MessagesState):
    """Ask the tool-enabled LLM to either answer directly or emit tool calls.

    The short-circuit system prompt is placed in front of the messages held
    in ``state``; the single model response is returned under the
    ``'messages'`` key as the state update.
    """
    # Give the chat model access to all three tools
    available_tools = [vector_db, nasa_images, podcast]
    tool_aware_llm = llm.bind_tools(available_tools)

    # System prompt first, then the conversation so far
    prompt = [SystemMessage(prompt_short_circuit), *state['messages']]
    ai_message = tool_aware_llm.invoke(prompt)

    # State update carrying only the new message
    return {'messages': [ai_message]}


# NODE 2: Registers and executes retrieval if needed
# Wraps the same three tools that query_or_respond binds to the LLM.
# The graph below refers to this node by the name 'tools'.

tools = ToolNode([vector_db, nasa_images, podcast])


# NODE 3: Generate retrieval response

def generate(state: MessagesState):
    """Answer the user from the output of the most recent tool calls.

    The trailing run of tool messages in ``state['messages']`` is joined
    into a context block and appended to the retrieval prompt. The model is
    then invoked with that system message followed by the conversation
    history (human and system messages, plus AI messages that carry no tool
    calls). Returns the model response under the ``'messages'`` key.
    """
    messages = state['messages']

    # Walk back from the end to find where the trailing tool messages start
    first_tool_index = len(messages)
    while first_tool_index > 0 and messages[first_tool_index - 1].type == 'tool':
        first_tool_index -= 1
    # Slicing keeps the tool messages in chronological order
    trailing_tool_messages = messages[first_tool_index:]

    # Retrieval prompt followed by the tool output as context
    context = '\n\n'.join(tool_message.content for tool_message in trailing_tool_messages)
    system_text = f"{prompt_retrieval}\n\n{context}"

    def belongs_to_conversation(message):
        # Human / system messages always stay; AI messages only when they
        # are real answers rather than tool-call requests
        if message.type in ('human', 'system'):
            return True
        return message.type == 'ai' and not message.tool_calls

    history = [message for message in messages if belongs_to_conversation(message)]

    answer = llm.invoke([SystemMessage(system_text)] + history)
    return {'messages': [answer]}

#### Build Graph

In [None]:
# Initialize the Graph (its state schema is MessagesState)
graph_builder = StateGraph(MessagesState)

# Add Nodes
# No explicit names are passed; the edges below refer to the nodes as
# 'query_or_respond', 'tools' and 'generate'.
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

# Define entry point: every run starts at the routing node
graph_builder.set_entry_point('query_or_respond')

# Define dynamic flow
# tools_condition picks the branch after query_or_respond: either finish
# (END) or continue to the 'tools' node.
# NOTE(review): presumably it routes to 'tools' when the latest AI message
# contains tool calls -- confirm against the langgraph documentation.
graph_builder.add_conditional_edges(
    'query_or_respond',
    tools_condition,
    {END: END, 'tools': 'tools'}
)

# Define fixed transitions: tool output always goes to generate, then the run ends
graph_builder.add_edge('tools', 'generate')
graph_builder.add_edge('generate', END)

# Simple in-memory checkpointer
memory = MemorySaver()

# Compile the Graph with the checkpointer attached
graph = graph_builder.compile(checkpointer = memory)

#### Control flow visualization

In [None]:
# Render the compiled graph as a Mermaid PNG and show it inline in the notebook
display(Image(graph.get_graph().draw_mermaid_png()))

#### Input Testing

In [None]:
# One random thread_id per notebook session; every message_test call
# passes this same config to the graph
config = {'configurable': {'thread_id': str(uuid.uuid4())}}

def message_test(input_message):
    """Send one user message through the graph and print each streamed step.

    The graph is streamed in 'values' mode; for every streamed state the
    last message is pretty-printed.
    """
    payload = {'messages': [{'role': 'user', 'content': input_message}]}
    snapshots = graph.stream(payload, stream_mode = 'values', config = config)
    for snapshot in snapshots:
        latest_message = snapshot['messages'][-1]
        latest_message.pretty_print()

In [None]:
# Conversational message
# Expected: LLM short-circuit (direct reply, no retrieval), no images fetched

message_test(input_message = 'Hello. I do have a question about the universe.')

In [None]:
# Astronomy question
# Expected: retrieval step runs and images are fetched

message_test(input_message = 'How hot is it on the sun?')

In [None]:
# Astronomy follow-up that relies on conversation memory ("it" = previous topic)
# Expected: retrieval step runs, no images fetched (same topic as before)

message_test(input_message = 'Is it big?')

In [None]:
# Astronomy follow-up that relies on conversation memory
# Expected: LLM short-circuit (direct reply), no images fetched

message_test(input_message = 'Would I be able to live on it?')

In [None]:
# Astronomy question on a new topic, asked right after a short-circuit reply
# Expected: retrieval step runs again and images are fetched

message_test(input_message = 'How about brown dwarfs? Are they real?')

In [None]:
# Out-of-context question (not about astronomy)
# Expected: LLM short-circuit (direct reply), no images fetched

message_test(input_message = 'Do you know OpenAI?')

In [None]:
# Conversational closing message
# Expected: LLM short-circuit (direct reply, no retrieval), no images fetched

message_test(input_message = 'Thank you. Have a nice day!')