In [None]:
from typing import Annotated

from langchain_core.documents import Document
from langchain_core.messages import SystemMessage,HumanMessage
from langchain_core.tools import tool
from typing_extensions import List, TypedDict

from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from IPython.display import Image, display

from tools import llm

from VectorDB import VectorDB

In [None]:
# Location and collection name handed to VectorDB.initialize_vector_store below.
# NOTE(review): presumably a Chroma persistence directory, going by the path
# name — confirm against the VectorDB implementation.
PERSIST_DIR = "./chroma_langchain_db"
COLLECTION_NAME = "movies_collection"

# Initialize vector database
vector_db = VectorDB(model_name="BAAI/bge-base-en-v1.5", batch_size=32)
# NOTE(review): init_result is not inspected anywhere in the cells shown here;
# confirm whether it can report a failed initialization that should stop the run.
init_result = vector_db.initialize_vector_store(PERSIST_DIR, COLLECTION_NAME)

# Module-level handle that the `retrieve` tool uses for similarity search.
vector_store = vector_db.vector_store

In [None]:
# Raw text of the recommendation prompt, with {question} and {context}
# placeholders. Adjacent string literals are joined verbatim by Python, so
# each fragment carries its own trailing "\n"; without it the sentences,
# bullets and numbered items fuse into one run-on line.
generator_prompt = (
    "You are a movie recommendation / finding assistant, your job is to help users find movies based on their preferences.\n"
    "You will start by reading the user's query and then searching for relevant movies in the database.\n"
    "Then you will use the available context to provide a personalized recommendation of a movie title:\n"
    "- Identify the main 3 genres of the movie from the context\n"
    "- Identify the release year, director, and main actors of the movie from the context\n"
    "- Identify the main themes and plotline of the movie from the context\n"
    "When generating the final response, make sure to include all relevant information in the following order:\n"
    "1 - Title:\n"
    "2 - Genres:\n"
    "3 - Release Year, Director, and Main Actors:\n"
    "4 - Themes and Plotline:\n"
    "\n"
    "Query:\n{question}\n"
    "\n"
    "Context:\n{context}\n"
)

# Wrap the raw prompt text (which contains the {question} and {context}
# placeholders) in a PromptTemplate.
# NOTE(review): GENERATOR_PROMPT is not referenced by any cell shown here;
# `generate` builds its own prompt string — confirm whether this is still needed.
GENERATOR_PROMPT = PromptTemplate.from_template(generator_prompt)

class State(MessagesState):
    """Graph state: everything in MessagesState plus the retrieved documents."""

    # Documents gathered from tool-message artifacts; written by the `generate` node.
    context: List[Document]


class MovieRecommendationWithSources(TypedDict):
    """A movie recommendation with detailed information and sources."""

    # Field metadata follows LangChain's TypedDict schema convention:
    #     Annotated[<type>, <default>, <description>]
    # `...` in the default slot marks the field as required with no default.
    # With only two elements the string is read as the field's *default value*
    # rather than its description, so the ellipsis must not be dropped.
    movie_title: str
    genres: Annotated[
        List[str],
        ...,
        "Main 3 genres of the recommended movie",
    ]
    release_year: Annotated[
        int,
        ...,
        "Year the movie was released",
    ]
    director: Annotated[
        str,
        ...,
        "Director of the movie",
    ]
    main_actors: Annotated[
        List[str],
        ...,
        "List of main actors in the movie",
    ]
    themes_and_plot: Annotated[
        str,
        ...,
        "Brief description of main themes and plotline",
    ]
    recommendation_reason: Annotated[
        str,
        ...,
        "Detailed explanation of why this movie matches the user's preferences",
    ]
    sources: Annotated[
        List[str],
        ...,
        "List of source documents/databases used to gather this movie information",
    ]

def query_or_respond(state: State):
    """Let the model either request a movie lookup or answer straight away.

    The conversation in ``state["messages"]`` is sent to the tool-enabled LLM
    behind a system instruction that steers it toward the ``retrieve`` tool.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A state update whose ``"messages"`` list holds the single LLM response.
    """
    # Expose the retrieval tool to the model for this call.
    tool_aware_llm = llm.bind_tools([retrieve])

    # Instruction placed ahead of the history on every call; it is not
    # written back into the state.
    instructions = (
        "You are a movie recommendation assistant. When users ask about movies, "
        "you should use the retrieve tool to search the movie database "
        "before providing recommendations. Only recommend movies found in the database."
    )
    prompt_messages = [SystemMessage(content=instructions), *state["messages"]]

    return {"messages": [tool_aware_llm.invoke(prompt_messages)]}

@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    # The docstring above is what the @tool decorator exposes to the model as
    # the tool description, so treat it as runtime text rather than a comment.

    # Fetch the six nearest documents so multi-movie requests have enough candidates.
    hits = vector_store.similarity_search(query, k=6)

    # First element: text handed to the model. Second element: the raw
    # documents, carried alongside as the tool message's artifact.
    rendered = [
        f"Source: {hit.metadata}\nContent: {hit.page_content}"
        for hit in hits
    ]
    return "\n\n".join(rendered), hits

# Node wrapping the `retrieve` tool; registered in the graph under the name "tools".
tools = ToolNode([retrieve])

def _trailing_tool_messages(messages):
    """Return the unbroken run of tool messages that ends the history, oldest first."""
    trailing = []
    for message in reversed(messages):
        if message.type != "tool":
            break
        trailing.append(message)
    trailing.reverse()
    return trailing


def _latest_user_question(messages):
    """Return the content of the most recent human message, or "" when there is none."""
    return next(
        (message.content for message in reversed(messages) if message.type == "human"),
        "",
    )


def generate(state: State):
    """Generate the movie recommendation from the latest retrieval results.

    Gathers the run of tool messages at the end of ``state["messages"]``,
    pairs their text with the most recent human message, and asks ``llm`` for
    recommendations based on both.

    The parameter is annotated with ``State`` (the schema that declares
    ``context``) because this node writes that key; this also matches the
    signature of ``query_or_respond``.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A state update with the LLM reply under ``"messages"`` and the
        artifacts attached to the gathered tool messages under ``"context"``.
    """
    history = state["messages"]

    # Retrieval results produced since the last non-tool message.
    tool_messages = _trailing_tool_messages(history)
    movies_content = "\n\n".join(msg.content for msg in tool_messages)

    user_question = _latest_user_question(history)

    # Only the latest question and the retrieved text are sent to the model;
    # earlier conversation turns are not included in this prompt.
    generation_prompt = f"""You are a movie recommendation assistant. Based on the user's query and the provided movie database context, provide detailed movie recommendations.

User Query: {user_question}

Available Movies from Database:
{movies_content}

Instructions:
- Recommend movies that best match the user's preferences
- If they ask for multiple movies, recommend multiple movies from the database
- Include relevant details like title, year, director, cast, genres, and plot
- Explain why each movie matches their request
- Only recommend movies that appear in the database context above
- Format your response clearly and engagingly

Provide your recommendations now:"""

    response = llm.invoke([HumanMessage(content=generation_prompt)])

    # Collect the artifacts carried by the tool messages so callers can
    # inspect what was retrieved; messages without an artifact are skipped.
    context = []
    for tool_message in tool_messages:
        artifact = getattr(tool_message, "artifact", None)
        if artifact:
            context.extend(artifact)

    return {
        "messages": [response],
        "context": context,
    }




#def chatbot(state: State):
#    return {"messages": [llm.invoke(state["messages"])]}

In [None]:
#graph_builder = StateGraph(State)

#graph_builder.add_node("chatbot", chatbot)

#graph_builder.add_edge(START, "chatbot")
#graph_builder.add_edge("chatbot", END)

# Assemble the graph over the shared State schema:
#   query_or_respond -> tools -> generate -> END
#   query_or_respond -> END            (when tools_condition routes to END)
graph_builder = StateGraph(State)
    
# Add nodes
graph_builder.add_node("query_or_respond", query_or_respond)
graph_builder.add_node("tools", tools)
graph_builder.add_node("generate", generate)

# Set entry point
graph_builder.set_entry_point("query_or_respond")

# Add conditional edges: tools_condition's result is mapped either to the
# "tools" node or to END.
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)

# Add regular edges: tool output always flows into `generate`, which ends the run.
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# Compiled graph used by the invocation cell.
graph = graph_builder.compile()

# Rendering the diagram is optional and needs extra dependencies, so a
# failure must not stop the notebook — but report it instead of hiding it.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Graph visualization skipped: {exc!r}")

In [None]:
# Run the graph on a single user request.
result = graph.invoke({
    "messages": [{"role": "user", "content": "Recommend 3 movies about space exploration, containing aliens. Idealy they should be horror sci-fi movies."}]
})

# Access the final response
final_message = result["messages"][-1]
print(final_message.content)

# Access the retrieved context (movies that were found).
# "context" is written only by the `generate` node. When the model answers
# without calling the retrieve tool, the graph ends before `generate` runs and
# the key is absent, so fall back to an empty list instead of raising KeyError.
retrieved_movies = result.get("context", [])
print(f"Found {len(retrieved_movies)} movies in the database")