In [None]:
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this is where PROJECT_ROOT (read later via os.environ) comes from — confirm.
from dotenv import load_dotenv
load_dotenv()

## Imports

In [None]:
import os
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langgraph.graph import END, StateGraph, START
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langgraph.graph.message import add_messages
from langchain_qdrant import QdrantVectorStore, RetrievalMode, FastEmbedSparse
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams

## Setup

In [None]:
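# Run three Ollama servers on separate ports (11434: chat model, 11435: guard model,
# 11436: embeddings) so each model stays loaded and requests avoid cold-start model loads.
# Run each `ollama serve` command below in its own bash shell; the `ollama ps` commands list what each server has loaded.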
# OLLAMA_HOST=localhost:11434 OLLAMA_NUM_PARALLEL=2 OLLAMA_KEEP_ALIVE=-1 OLLAMA_FLASH_ATTENTION=1 ollama serve
# OLLAMA_HOST=localhost:11435 OLLAMA_KEEP_ALIVE=-1 OLLAMA_FLASH_ATTENTION=1 ollama serve
# OLLAMA_HOST=localhost:11436 OLLAMA_KEEP_ALIVE=-1 OLLAMA_FLASH_ATTENTION=1 ollama serve
# OLLAMA_HOST=127.0.0.1:11434 ollama ps
# OLLAMA_HOST=127.0.0.1:11435 ollama ps
# OLLAMA_HOST=127.0.0.1:11436 ollama ps

In [None]:
# Chat model used by the summarize and respond chains; talks to the Ollama
# server listening on port 11434.
model = ChatOllama(
    model="llama3.2:3b",
    temperature=0.5,
    base_url="http://localhost:11434",
    cache=None,
)
# model = ChatOllama(model="llama3.2:1b", temperature=0.5)

In [None]:
# Safety-classifier model used by the filter chain; talks to the Ollama
# server listening on port 11435.
# NOTE(review): temperature=0.5 makes the verdict sampled rather than greedy —
# confirm this is intended for a classifier.
model_guard = ChatOllama(
    model="llama-guard3:8b",
    temperature=0.5,
    base_url="http://localhost:11435",
    cache=None,
)
# model_guard = ChatOllama(model="llama-guard3:1b", temperature=0.5, base_url="http://localhost:11435")

In [None]:
# Dense embedding model, served by the Ollama instance on port 11436.
embeddings = OllamaEmbeddings(model="mxbai-embed-large", base_url="http://localhost:11436")

In [None]:
# Sparse (BM25) embeddings; passed to the vector store below as its sparse embedding for hybrid retrieval.
sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")

In [None]:
# Open (or create) the on-disk Qdrant collection and wrap it in a hybrid
# (dense + sparse) LangChain vector store.
# In-memory alternative for quick experiments:
# vector_store = InMemoryVectorStore(embedding=embeddings)
try:
    # The `in globals()` guards make this cell safe to re-run in a live kernel:
    # the client and the vector store are only built the first time.
    if "client" not in globals():
        client = QdrantClient(path=f"{os.environ['PROJECT_ROOT']}/tmp/langchain_qdrant")
    if not client.collection_exists(collection_name="demo_collection"):
        client.create_collection(
            collection_name="demo_collection",
            # NOTE(review): size=1024 has to match the output dimension of the
            # dense embedding model configured above — confirm if the model changes.
            vectors_config={"dense": VectorParams(size=1024, distance=Distance.COSINE)},
            sparse_vectors_config={
                "sparse": SparseVectorParams(
                    index=models.SparseIndexParams(on_disk=False)
                )
            },
        )
    if "vector_store" not in globals():
        vector_store = QdrantVectorStore(
            client=client,
            collection_name="demo_collection",
            embedding=embeddings,
            sparse_embedding=sparse_embeddings,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name="dense",
            sparse_vector_name="sparse",
        )
except Exception as e:
    # Surface the failure here instead of letting later cells fail with an
    # unrelated NameError on `client` / `vector_store`.
    print(f"Vector store setup failed: {e!r}")
    raise

In [None]:
# Remove every point currently stored in the collection so the cells below
# start from an empty store.
# scroll() hands back one page of points plus the offset of the next page, so
# keep following that offset until it comes back as None.
stale_ids = []
next_offset = None
while True:
    points, next_offset = client.scroll(
        collection_name="demo_collection",
        offset=next_offset,
    )
    stale_ids.extend(point.id for point in points)
    if next_offset is None:
        break
vector_store.delete(ids=stale_ids)

In [None]:
# Peek at what is stored in the collection (displays the scroll() result for one page).
client.scroll(collection_name="demo_collection")

In [None]:
# Retriever view of the vector store; the retrieve chain below is built on it.
retriever = vector_store.as_retriever()

## Vectorstore Data

In [None]:
# Seed the store with a prose character profile (one document).
# Note: every run of this cell adds the text again.
vector_store.add_texts(
    texts=[
        """
The main protagonist and lead guitarist of Kessoku Band. An extreme introvert who has trouble with most social interactions. Having been inspired by her father and an interview she saw on television, she taught herself to play the guitar in her first year of middle school, thinking this would help her make friends. Despite becoming incredibly skilled at playing guitar and having a small fanbase online (under the alias "guitarhero"), she still has not been able to make friends as easily until she was dragged into playing with Kessoku Band. Since then, Hitori has gained a few friends and is learning to interact with other people. She is usually seen wearing a pink tracksuit, which she even wears over her school uniform. Her surname comes from Masafumi Gotoh. Her nickname Bocchi is a reference to hitoribocchi (一人ぼっち), a term for being alone. She plays an Ebony Gibson Les Paul Custom electric guitar, and later purchases a Transluscent Black Yamaha Pacifica electric guitar.
"""
    ]
)

In [None]:
# Seed the store with a second document: a key/value fact sheet (birthday, age, relatives, ...).
# Note: every run of this cell adds the text again.
vector_store.add_texts(texts=["""
Hitori Gotou's Personal Information
Birthday: February 21

Age: 15 (initially), 17 (as of Chapter 73)

Gender: Female

Height: 156 cm

Weight: 50 kg

Hair Color: Pink

Eye Color: Aqua

Blood Type: B

Occupation: Student

Affiliation:

Shuka High School

Kessoku Band

Relatives:

Father: Naoki Gotoh

Mother: Michiyo Gotoh

Younger Sister: Futari Gotoh

Pet Dog: Jimihen

"""])

In [None]:
# Sanity-check retrieval: fetch the single best match for "bocchi" and print
# it together with its score and metadata.
results = vector_store.similarity_search_with_score(
    query="bocchi", k=1
)
for doc, score in results:
    # `.3f` rounds the score to three decimal places.
    print(f"* [SIM={score:.3f}] {doc.page_content} [{doc.metadata}]")

## Retrieve Chain

In [None]:
def get_document_content(documents):
    """Join the ``page_content`` of every document into one string.

    Args:
        documents: iterable of objects exposing a ``page_content`` attribute.

    Returns:
        The contents in input order, separated by " | " (empty string for no documents).
    """
    contents = []
    for item in documents:
        contents.append(item.page_content)
    return " | ".join(contents)

# Query string -> retrieved documents -> single " | "-joined string of their contents.
retrieve_chain = retriever | (lambda documents: get_document_content(documents))

In [None]:
# retrieve_chain.invoke("bocchi")

## DuckDuckGo

In [None]:
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper

# DuckDuckGo search tool: US-English region, up to 10 results, returned as a list.
search = DuckDuckGoSearchResults(
    api_wrapper=DuckDuckGoSearchAPIWrapper(
        region="us-en",
        max_results=10,
    ),
    output_format="list",
)

In [None]:
# search.invoke("youtube gawr gura")

In [None]:
# search = DuckDuckGoSearchResults(output_format="list", backend="news")
# search.invoke("gawr gura")

## Summarize Chain

In [None]:
# Prompt for the first summary: a persona/system instruction followed by the
# conversation text supplied through {context}.
summarize_prompt = ChatPromptTemplate(
    [
        (
            "system",
            "You are Bocchi from Bocchi the Rock. You are nervous, awkward, and introverted. Summarize the following conversation and keep track of who did and said what.",
        ),
        (
            "human",
            "Here is the conversation to summarize. {context}",
        ),
    ]
)

In [None]:
# {"context": ...} -> summarize prompt -> chat model -> plain string summary.
summarize_chain = summarize_prompt | model | StrOutputParser()

## Iterative Summarize Chain

In [None]:
# Refining the summary with new docs
# The human message carries two placeholders: {current_summary} (the summary so
# far) and {context} (the new material to fold in).
refine_template = """
Produce a final summary.

Existing conversation summary up to this point:
{current_summary}

New context:
------------
{context}
------------

Given the new context, refine the original summary.
"""
# Same persona/system instruction as the initial summary prompt, paired with
# the refine template as the human message.
refine_prompt = ChatPromptTemplate(
    [
        (
            "system",
            "You are Bocchi from Bocchi the Rock. You are nervous, awkward, and introverted. Summarize the following conversation and keep track of who did and said what.",
        ),
        ("human", refine_template),
    ]
)

In [None]:
# {"current_summary": ..., "context": ...} -> refine prompt -> chat model -> plain string summary.
iterative_summary_chain = refine_prompt | model | StrOutputParser()

## Respond Chain

In [None]:
# Prompt
# Response prompt in character. It has three placeholders: {new_summary},
# {context} and {memory}.
prompt = PromptTemplate(
    template="""
    You are Bocchi from Bocchi the Rock. Respond nervously, awkwardly, and introverted. Respond to the following conversation. You are okay with talking to your friends and family, but you are initially not okay with talking to strangers. Keep your response concise and to the point.
    Here is the conversation summary: \n\n {new_summary} \n\n
    Here is the recent context: \n\n {context} \n\n
    Here is information from your memory that may be relevant: \n\n {memory} \n\n
    """,
)

In [None]:
# {"new_summary", "context", "memory"} -> response prompt -> chat model -> plain string reply.
response_chain = prompt | model | StrOutputParser()

## Filter Chain

In [None]:
# Hazard category codes that the guard model may report, mapped to readable labels.
filter_codes = {
    "S1": "S1: Violent Crimes",
    "S2": "S2: Non-Violent Crimes",
    "S3": "S3: Sex-Related Crimes",
    "S4": "S4: Child Sexual Exploitation",
    "S5": "S5: Defamation",
    "S6": "S6: Specialized Advice",
    "S7": "S7: Privacy",
    "S8": "S8: Intellectual Property",
    "S9": "S9: Indiscriminate Weapons",
    "S10": "S10: Hate",
    "S11": "S11: Suicide & Self-Harm",
    "S12": "S12: Sexual Content",
    "S13": "S13: Elections",
}

def parse_guard_output(output: str) -> dict:
    """Turn the guard model's raw text verdict into a structured result.

    The expected text is either ``safe`` or ``unsafe`` followed, on the next
    line, by one or more comma-separated category codes (e.g. ``S1,S10``).
    Surrounding whitespace, blank lines and letter case are ignored.

    Args:
        output: raw text produced by the guard model.

    Returns:
        ``{"safe": bool, "reason": str}``. Anything that is not clearly
        ``safe`` is reported as unsafe (fail closed); unknown codes and
        malformed output are described in ``reason`` instead of raising.
    """
    lines = [line.strip() for line in output.strip().splitlines() if line.strip()]
    verdict = lines[0].lower() if lines else ""

    if verdict == "safe":
        return {
            "safe": True,
            "reason": "The output is safe.",
        }

    # The category codes live on the line right after the verdict.
    codes = []
    if len(lines) > 1:
        codes = [code.strip().upper() for code in lines[1].split(",") if code.strip()]

    if verdict == "unsafe" and codes:
        reason = ", ".join(
            filter_codes.get(code, f"{code}: Unknown category") for code in codes
        )
    elif verdict == "unsafe":
        reason = "Unsafe (no category reported)"
    else:
        reason = f"Unrecognized guard output: {output!r}"

    return {
        "safe": False,
        "reason": reason,
    }

# Text to check -> guard model -> plain string verdict -> parse_guard_output's {"safe", "reason"} dict.
filter_chain = model_guard | StrOutputParser() | (lambda output: parse_guard_output(output))

## LangGraph

### Graph State

In [None]:
import time

In [None]:
class AgentState(TypedDict):
    """State dictionary passed between the graph nodes defined below."""
    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[list, add_messages]
    # Summary of the conversation so far; supplied by the caller when the graph is invoked.
    current_summary: str
    # Latest piece of conversation to react to; supplied by the caller.
    context: str
    # Summary written by the summarize / iteratively_summarize nodes.
    new_summary: str
    # Retrieved text written by the retrieve node and read by respond.
    memory: str
    # Reply written by the respond node and checked by filter_response.
    response: str
    # Timestamp (time.time()) supplied by the caller at invocation.
    start_time: float
    # Timestamp (time.time()) recorded by the respond node once the reply is ready.
    response_time: float

### Nodes

In [None]:
async def summarize(state):
    """Graph node: create the first conversation summary from ``state["context"]``.

    Args:
        state: graph state; only ``context`` is read.

    Returns:
        State update with the summary under ``new_summary`` and a trace
        message (id "1"). ``current_summary`` is not modified here; callers
        that want a rolling summary should feed ``new_summary`` back in on
        the next invocation.
    """
    summary = await summarize_chain.ainvoke({"context": state["context"]})
    return {
        "new_summary": summary,
        "messages": [AIMessage(content=summary, id="1")],
    }

In [None]:
async def iteratively_summarize(state):
    """Graph node: refine the existing summary with the newest context.

    Args:
        state: graph state; ``current_summary`` and ``context`` are read.

    Returns:
        State update with the refined summary under ``new_summary`` and a
        trace message (id "2"). ``current_summary`` is not modified here;
        callers that want a rolling summary should feed ``new_summary`` back
        in on the next invocation.
    """
    summary = await iterative_summary_chain.ainvoke(
        {
            "current_summary": state["current_summary"],
            "context": state["context"],
        }
    )
    return {
        "new_summary": summary,
        "messages": [AIMessage(content=summary, id="2")],
    }

In [None]:
def is_summary_empty(state)->Literal["summarize", "iteratively_summarize"]:
    """Pick the summarization branch for this run.

    The decision is based on the ``current_summary`` carried in ``state``
    (not on any notebook-level variable), so the routing always matches the
    input the graph was actually invoked with.

    Args:
        state: graph state mapping.

    Returns:
        ``"summarize"`` when there is no summary yet (missing, ``None`` or
        empty string), otherwise ``"iteratively_summarize"``.
    """
    return "iteratively_summarize" if state.get("current_summary") else "summarize"

In [None]:
async def respond(state):
    """Graph node: generate the in-character reply.

    Reads ``context``, ``current_summary`` and ``memory`` from the state and
    returns the reply under ``response``, a trace message (id "3"), and the
    completion timestamp under ``response_time``.
    """
    # NOTE(review): the prompt's {new_summary} slot is filled from
    # state["current_summary"], not state["new_summary"] — confirm this is intended.
    prompt_inputs = {
        "context": state["context"],
        "new_summary": state["current_summary"],
        "memory": state["memory"],
    }
    reply = await response_chain.ainvoke(prompt_inputs)
    return {
        "response": reply,
        "messages": [AIMessage(content=reply, id="3")],
        "response_time": time.time(),
    }

In [None]:
async def filter_response(state):
    """Graph node: run the generated reply through the guard chain.

    Only a trace message is added to the state: id "4" when the reply is
    judged safe, id "5" (including the reported reason) when it is not.
    The ``response`` field itself is left untouched.
    """
    verdict = await filter_chain.ainvoke(state["response"])
    if not verdict["safe"]:
        note = f"response is unsafe: {verdict['reason']}"
        return {"messages": [AIMessage(content=note, id="5")]}
    return {"messages": [AIMessage(content="response is safe", id="4")]}

In [None]:
async def retrieve(state):
    """Graph node: look up stored text relevant to the conversation.

    The query is the current summary and the latest context joined by a
    space; the retrieved text is returned under ``memory`` together with a
    trace message (id "6").
    """
    query = state["current_summary"] + " " + state["context"]
    recalled = await retrieve_chain.ainvoke(query)
    return {
        "memory": recalled,
        "messages": [AIMessage(content=recalled, id="6")],
    }

### LangGraph Compile

In [None]:
# Build the graph over AgentState and register one node per step.
workflow = StateGraph(AgentState)
workflow.add_node("summarize", summarize)
workflow.add_node("iteratively_summarize", iteratively_summarize)
workflow.add_node("respond", respond)
workflow.add_node("filter_response", filter_response)
workflow.add_node("retrieve", retrieve)

# Two edges leave START: a conditional one that picks a summarization node via
# is_summary_empty, and an unconditional one into retrieve.
workflow.add_conditional_edges(START, is_summary_empty)
workflow.add_edge(START, "retrieve")
# The reply path is retrieve -> respond -> filter_response -> END.
# Neither summarization node has an outgoing edge declared here.
workflow.add_edge(["retrieve"], "respond")
workflow.add_edge("respond", "filter_response")
workflow.add_edge("filter_response", END)
graph = workflow.compile()

### LangGraph Run

In [None]:
from langsmith import traceable

In [None]:
# Summary of the conversation so far; passed into the graph as `current_summary` by the run cell below.
current_summary = "You and Ryo Yamada started to talk about your birthday."

In [None]:
@traceable
async def get_model_response_retrieval(
    context="Your close friend Ryo Yamada asks: When is your birthday?",
    summary=None,
):
    """Run the graph once and return its final state.

    Args:
        context: latest piece of conversation to respond to. Defaults to the
            demo question used throughout this notebook.
        summary: conversation summary to start from. When ``None`` (default)
            the notebook-level ``current_summary`` is used.

    Returns:
        The final graph state returned by ``graph.ainvoke``.
    """
    starting_summary = current_summary if summary is None else summary
    return await graph.ainvoke(
        {
            "current_summary": starting_summary,
            "context": context,
            # Recorded up front so the caller can compute end-to-end latency.
            "start_time": time.time(),
        }
    )

In [None]:
# Run the graph once (top-level await works in the notebook) and keep the final state.
response2 = await get_model_response_retrieval()

In [None]:
# Print each trace message with its id; the ids are the fixed ones assigned in the node functions.
for message in response2["messages"]:
    print(message.id)
    message.pretty_print()

In [None]:
# The generated reply.
response2["response"]

In [None]:
# Full final state returned by the graph.
response2

In [None]:
# Seconds between the start_time passed at invocation and the response_time recorded by the respond node.
response2["response_time"] - response2["start_time"]

### LangGraph Display

In [None]:
# Print the Mermaid source text for the compiled graph.
print(graph.get_graph().draw_mermaid())

In [None]:
import base64
from IPython.display import display_svg
from urllib.request import Request, urlopen

def _mermaid_ink_url(graph):
    """Build the mermaid.ink SVG URL for a Mermaid source string."""
    # UTF-8 so labels containing non-ASCII characters do not raise, and the
    # URL-safe base64 alphabet so the encoded text never puts '+' or '/'
    # inside the URL path.
    encoded = base64.urlsafe_b64encode(graph.encode("utf-8")).decode("ascii")
    return "https://mermaid.ink/svg/" + encoded


def mm(graph):
    """Render Mermaid source text as an inline SVG via the mermaid.ink service.

    Args:
        graph: Mermaid diagram source text (e.g. the string returned by
            ``draw_mermaid()``).

    Raises:
        urllib.error.URLError: if the service cannot be reached or does not
            answer within the timeout.
    """
    req = Request(_mermaid_ink_url(graph), headers={'User-Agent': 'IPython/Notebook'})
    # Context manager closes the HTTP response; the timeout keeps the cell
    # from hanging indefinitely when the service is unreachable.
    with urlopen(req, timeout=30) as resp:
        svg = resp.read().decode()
    display_svg(svg, raw=True)

In [None]:
# Render the compiled graph inline (requires network access to mermaid.ink).
mm(graph.get_graph().draw_mermaid())