In [None]:
# Credentials are not hard-coded in the notebook: GOOGLE_API_KEY, ANTHROPIC_API_KEY
# and the optional LangChain tracing settings (LANGCHAIN_TRACING_V2,
# LANGCHAIN_API_KEY) are expected in a local .env file, which load_dotenv()
# copies into os.environ.
from dotenv import load_dotenv

load_dotenv()

## Prompts

In [None]:
# Persona prompts for the two retrieval-augmented destinations. Each template
# exposes {context} and {input}, the variables filled in by the retrieval chains
# built in the "Destinations" section. In both templates the persona paragraph
# is separated from the instructions by a blank line.
environmental_template = """You are a passionate environmentalist with a deep understanding of ecological systems and sustainability principles. \
You excel at providing insightful explanations and solutions related to environmental issues in a clear and engaging manner. \
When faced with a question outside your expertise, you gracefully acknowledge your limitations.

Answer the following question based on the provided context:

<context>
{context}
</context>

Here is the question:
{input}"""

education_template = """You are an enthusiastic and knowledgeable educator with a profound understanding of pedagogy and learning principles. \
You are adept at delivering comprehensive explanations and fostering engaging learning experiences. 

Answer the following question with the provided context.

<context>
{context}
</context>

Here is the question:
{input}"""

In [None]:
# Routing catalogue: one entry per destination, giving the name the router must
# answer with, the description shown to the router, and the persona template.
# "早自習" (morning self-study) is listed for "education", presumably so the
# debate topic used later in this notebook is routed there.
prompt_infos = [
    {"name": name, "description": description, "prompt_template": template}
    for name, description, template in (
        (
            "environmentalist",
            "Ideal for addressing environmental concerns and offering insightful solutions",
            environmental_template,
        ),
        (
            "education",
            "Good for answering questions related to education or study or 早自習",
            education_template,
        ),
    )
]

In [None]:
# Router prompt, rendered in two stages:
#   1. PromptTemplate.format() fills {destinations}. The doubled braces {{history}}
#      and {{input}} collapse to {history} / {input}, and {{{{ }}}} collapses to {{ }}.
#   2. HumanMessagePromptTemplate.from_template() then treats {history} / {input}
#      as message variables and {{ }} as literal JSON braces.
# "\\" in the JSON sketch is an escaped backslash. A bare "\ " is an invalid
# escape sequence and triggers SyntaxWarning on Python 3.12+.
# Lines ending in "\" are joined to the next line, so the line before the join
# must keep a trailing space where a word break is needed.
MULTI_PROMPT_ROUTER_TEMPLATE = """
你是一個router，你的工作是分析input的問題與甚麼領域有關，\
當問題不清楚時適當修改問題，接著把問題傳遞給適當的language model.\
你會得到所有可以選擇的language model名稱及它們擅長哪個領域的問題。\

You can also use the history messages provided here to make your decision.
<< HISTORY >>
{{history}}

<< INPUT >>
{{input}}


<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:

```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ summary of the history relevant to the topic of input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt \
names specified below OR it can be "DEFAULT" if the input is completely \
unrelated to any of the candidate prompts.
REMEMBER: "next_inputs" should contain a concise summary of the history to provide context for the input.


<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT (remember to include the ```json)>>"""

## LLM

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic

# Chat models, both with temperature=0. Claude drives routing and answering;
# Gemini is only referenced by the memory cell further down.
claude = ChatAnthropic(model_name="claude-3-haiku-20240307", temperature=0)
gemini = ChatGoogleGenerativeAI(
    model="gemini-pro",
    temperature=0,
    convert_system_message_to_human=True,
)
# ChatOpenAI is imported above but unused; uncomment to experiment with it.
# openai = ChatOpenAI()

## Document loader & text splitter

In [None]:
from langchain_community.document_loaders.web_base import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain

# Fetch the LangSmith user guide and split it into chunks with the splitter's
# default settings.
loader = WebBaseLoader("https://docs.smith.langchain.com/user_guide")
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter()
documents = text_splitter.split_documents(docs)

# Embed the chunks with a sentence-transformers model and index them in Chroma.
# No persist_directory is passed, so nothing is written to disk.
# NOTE(review): re-running this cell in the same kernel may add the chunks to the
# existing default collection again; see the delete_collection() cell below.
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vector = Chroma.from_documents(documents=documents, embedding=embedding_function)

In [None]:
# Optional clean-up: drops the Chroma collection built above (e.g. before
# rebuilding the index). It is left commented out so "Run All" keeps the index.
# vector.delete_collection()

## Destinations

In [None]:
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate

In [None]:
# Environmental destination: retrieve the single closest chunk (k=1), stuff it
# into the environmentalist persona prompt and answer with Claude.
# The misspelled global name `enviromental_chain` is kept because route() uses it.
retriever = vector.as_retriever(search_kwargs={"k": 1})
prompt = ChatPromptTemplate.from_template(template=environmental_template)
document_chain = create_stuff_documents_chain(claude, prompt)
enviromental_chain = create_retrieval_chain(retriever, document_chain)

In [None]:
# Education destination: same retrieval setup as the environmental chain, but
# using the educator persona prompt. Note that this cell rebinds the shared
# globals `prompt`, `document_chain` and `retriever`.
retriever = vector.as_retriever(search_kwargs={"k": 1})
prompt = ChatPromptTemplate.from_template(template=education_template)
document_chain = create_stuff_documents_chain(claude, prompt)
education_chain = create_retrieval_chain(retriever, document_chain)

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# Fallback destination with no retrieval: history and question are sent straight
# to Claude. The plain-string reply is added under "answer", so callers read it
# the same way as for the retrieval chains (response.get("answer")).
general_chain = RunnablePassthrough.assign(
    answer=ChatPromptTemplate.from_template("{history} {input}")
    | claude
    | StrOutputParser(),
)

## Route to destination

In [None]:
def route(info):
    """Select the destination chain for a routed request.

    Args:
        info: dict built by ``multi_prompt_chain`` with a ``"destination"`` key
            (the router's choice; may be None or an unrecognised name, e.g.
            when the router answers "DEFAULT") plus the ``"input"`` and
            ``"history"`` values forwarded to the chosen chain.

    Returns:
        ``enviromental_chain`` or ``education_chain`` when the destination
        names them (case-insensitive substring match), otherwise
        ``general_chain``. A chain is always returned, never None.
    """
    # Debug trace of the routing decision.
    print(info)
    destination = (info.get("destination") or "").lower()
    if "environmentalist" in destination:
        return enviromental_chain
    if "education" in destination:
        return education_chain
    # Unknown or empty destinations fall back to the general chain. The previous
    # version returned None here, which broke response.get("answer") downstream.
    return general_chain

In [None]:
# List the candidates ("name: description", one per line) inside the router
# prompt. PromptTemplate.format() only fills {destinations}. The rendered text
# still contains {history} and {input}, which HumanMessagePromptTemplate turns
# into the message's input variables.
destinations = [
    "{}: {}".format(info["name"], info["description"]) for info in prompt_infos
]
destinations_str = "\n".join(destinations)

router_system_template = HumanMessagePromptTemplate.from_template(
    PromptTemplate.from_template(MULTI_PROMPT_ROUTER_TEMPLATE).format(
        destinations=destinations_str
    )
)

## Memory

In [None]:
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import MessagesPlaceholder

# Conversation memory shared by the router and the destination chains. The
# debate cells below write each turn into it with save_context().
# ai_prefix="" presumably blanks the speaker label of AI turns in the rendered
# "history" string; confirm this is the desired transcript format.
# NOTE: ConversationBufferMemory has no `llm` / `max_token_limit` fields. Passing
# them is silently ignored, so this history is unbounded. If a token cap is
# wanted, switch to ConversationTokenBufferMemory(llm=..., max_token_limit=...)
# with a model that supports token counting.
memory = ConversationBufferMemory(ai_prefix="")
memory.load_memory_variables({})

In [None]:
# The router prompt is a single human message. Conversation history reaches it
# through the {history} variable inside router_system_template, so no
# MessagesPlaceholder is needed.
router_prompt = ChatPromptTemplate.from_messages([router_system_template])

## Build the chain

In [None]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.chains.router.llm_router import RouterOutputParser
from operator import itemgetter

# Router pipeline:
#   1. inject the current memory transcript as "history";
#   2. ask Claude to choose a destination;
#   3. parse its JSON reply;
#   4. flatten the result to {"destination": ..., "next_inputs": <summary string>}.
router_chain = (
    RunnablePassthrough.assign(
        history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
    )
    | router_prompt
    | claude
    | RouterOutputParser()
    | {
        "destination": itemgetter("destination"),
        # The parser nests next_inputs as {"input": ...}; unwrap it to the string.
        "next_inputs": lambda parsed: parsed["next_inputs"]["input"],
    }
)

In [None]:

# End-to-end pipeline:
#   1. run the router;
#   2. prepend its history summary to the user's input;
#   3. reload the memory transcript as "history" for the destination prompt;
#   4. let route() pick and invoke the destination chain.
multi_prompt_chain = (
    RunnablePassthrough.assign(route=router_chain)
    | {
        "destination": lambda state: state["route"]["destination"],
        "input": lambda state: f'{state["route"]["next_inputs"]}\n{state["input"]}',
    }
    | RunnablePassthrough.assign(
        history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
    )
    | RunnableLambda(route)
)

## Usage

In [None]:
from api.server import Server
from api.manager import Manager
from api.models.match import Match

import os

# Game-server connection. Both values can be overridden from the environment;
# load_dotenv() in the first cell already loads a local .env file.
# TODO(review): the fallback for the second Server argument (presumably an
# auth/team token) is a hard-coded credential committed with the notebook.
# Move it into .env and drop the default.
SERVER_URL = os.getenv("GAME_SERVER_URL", "http://localhost:3000")
SERVER_TOKEN = os.getenv("GAME_SERVER_TOKEN", "d8f8bc5e-4176-468c-8d78-5fdd73a50fe5")

server = Server(SERVER_URL, SERVER_TOKEN)
manager = Manager(server)

In [None]:
from api.game import Game, validate_chain, show_result

# Check the chain/memory pair with the game API before creating a Game.
valid = validate_chain(multi_prompt_chain, memory)

if not valid:
    # Fail loudly. Later cells use `game`, which would otherwise be undefined
    # and surface as a confusing NameError (raised again inside their except).
    raise RuntimeError("Something went wrong")

print("You're ready")
game = Game(server, multi_prompt_chain, memory)

In [None]:
# Join a match. On any failure, report the error and ask the server to cancel
# the match.
try:
    game.join_match()
except Exception as exc:
    print(f"Error: {exc}")
    game.cancel_match()

In [None]:
# Start the debate below from an empty conversation history.
memory.clear()

In [None]:
# The four debate turns in speaking order:
#   1. affirmative (正方, for abolishing 早自習) opening
#   2. negative (反方) opening
#   3. affirmative rebuttal
#   4. negative rebuttal
# NOTE(review): no cell below reads this list; each turn cell rebinds `inputs`
# to its own dict.
inputs = [
    {"input":
"""兩位參賽者進行辯論比賽，今天的討論議題是支不支持廢除早自習?
正方：支持廢除早自習
反方：反對廢除早自習

作為正方，請堅守支持廢除早自習的立場，並簡明扼要地陳述正方的意見和理由。
正方:"""},
    {"input":
"""作為反方，請堅守反對廢除早自習的立場，簡明扼要地陳述你的意見和理由。
反方:"""},
    {"input":
"""作為正方，請堅守支持廢除早自習的立場。請針對反對方所發表的意見，進行反駁。
正方:"""},
    {"input":
"""作為反方，請堅守反對廢除早自習的立場。請針對正方所發表的意見，進行反駁。
反方:"""},
]

In [None]:
# Turn 1: the affirmative side (正方) opens the debate. The exchange is saved
# to `memory`, so later turns receive it as history.
# NOTE(review): the four turn cells differ only in their prompt; a loop over
# the list built in the earlier `inputs` cell would remove the duplication.
inputs = dict(
    input="""兩位參賽者進行辯論比賽，今天的討論議題是支不支持廢除早自習?
正方：支持廢除早自習
反方：反對廢除早自習

作為正方，請堅守支持廢除早自習的立場，並簡明扼要地陳述正方的意見和理由。
正方:"""
)
response = multi_prompt_chain.invoke(inputs)
print(response.get("answer"))
memory.save_context(inputs, dict(output=response.get("answer")))

In [None]:
# Turn 2: the negative side (反方) states its opening position; the exchange is
# saved to `memory` for the following turns.
inputs = dict(
    input="""作為反方，請堅守反對廢除早自習的立場，簡明扼要地陳述你的意見和理由。
反方:"""
)
response = multi_prompt_chain.invoke(inputs)
print(response.get("answer"))
memory.save_context(inputs, dict(output=response.get("answer")))

In [None]:
# Turn 3: the affirmative side rebuts the negative side's arguments; the
# exchange is saved to `memory`.
inputs = dict(
    input="""作為正方，請堅守支持廢除早自習的立場。請針對反對方所發表的意見，進行反駁。
正方:"""
)
response = multi_prompt_chain.invoke(inputs)
print(response.get("answer"))
memory.save_context(inputs, dict(output=response.get("answer")))

In [None]:
# Turn 4: the negative side rebuts the affirmative side's arguments; the
# exchange is saved to `memory`.
inputs = dict(
    input="""作為反方，請堅守反對廢除早自習的立場。請針對正方所發表的意見，進行反駁。
反方:"""
)
response = multi_prompt_chain.invoke(inputs)
print(response.get("answer"))
memory.save_context(inputs, dict(output=response.get("answer")))

## Plot the chain graph

In [None]:
# Print an ASCII diagram of the full routing pipeline.
# NOTE(review): LangChain's ASCII renderer relies on an optional drawing
# dependency (grandalf); confirm it is installed in this environment.
multi_prompt_chain.get_graph().print_ascii()