In [None]:
from langchain.llms import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, ConversationChain
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chains.router import MultiPromptChain, MultiRetrievalQAChain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.vectorstores import Chroma, FAISS
from langchain.document_loaders import TextLoader
from langchain.chains.router.embedding_router import EmbeddingRouterChain

## Setting the LLM

In [None]:
# Read the OpenAI API key from a local file so it is not hardcoded in the notebook.
# .strip() removes the trailing newline most editors add to text files; an
# un-stripped key is sent verbatim in the HTTP Authorization header and the
# request fails.
with open("openai_api.txt", "r") as key_file:
    OPENAI_API = key_file.read().strip()

if not OPENAI_API:
    raise ValueError("openai_api.txt is empty; expected an OpenAI API key")

# Completion-style LLM shared by every chain in this notebook.
llm = OpenAI(
    model_name="gpt-3.5-turbo-instruct",
    openai_api_key=OPENAI_API,
)

# Embedding model used by the embedding router and the FAISS retrievers.
embedding_llm = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    openai_api_key=OPENAI_API,
)

## Router Chains

This type of chain allows the model to select the next chain to use for a given input.

The `Router Chains` are made up of two components:
* The **RouterChain** itself - responsible for selecting the next chain to call
* **destination_chains** - the chains that the router can call.

There are three different types of Router Chains, depending on the task the chain is designed to execute:
* Text Router Chains
* Embedding Router Chains
* Retrieval Router Chains

### Text Queries

In [None]:
## Creating the Prompt Templates

physics_template = """You are a very smart physics professor.
You are great at answering questions about physics in a concise and easy to understand manner.
When you don't know the answer to a question you admit that you don't know.

Here is a question:
{input}"""

math_template = """You are a very good mathematician. You are great at answering math questions.
You are so good because you are able to break down hard problems into their component parts,
answer the component parts, and then put them together to answer the broader question.

Here is a question:
{input}"""

In [None]:
# Destination metadata: "name" becomes the routing key, "description" is what the
# router reads when choosing, and "prompt_template" builds the destination chain.
prompt_infos = [
    dict(
        name="physics",
        description="Good for answering questions about physics",
        prompt_template=physics_template,
    ),
    dict(
        name="math",
        description="Good for answering math questions",
        prompt_template=math_template,
    ),
]

In [None]:
## Creating the Destination Chains

# One LLMChain per prompt_infos entry, keyed by its "name" (the value the router
# picks). A dict comprehension is used instead of a for-loop so that temporaries
# such as `chain` and `prompt` do not leak into the notebook namespace, where
# `chain` is later rebound to the router chain.
destination_chains = {
    p_info["name"]: LLMChain(
        llm=llm,
        prompt=PromptTemplate(
            template=p_info["prompt_template"],
            input_variables=["input"],
        ),
    )
    for p_info in prompt_infos
}

# Fallback used when the router does not select one of the named destinations.
# output_key="text" matches the output key of the LLMChain destinations.
default_chain = ConversationChain(llm=llm, output_key="text")

In [None]:
## Setting the Router Chain

# One "name: description" line per destination; the router prompt lists these
# so the LLM can choose where to send each input.
destinations = [
    "{}: {}".format(info["name"], info["description"]) for info in prompt_infos
]
destinations_str = "\n".join(destinations)

# Fill the stock multi-prompt router template with the destination list, and
# attach RouterOutputParser so the LLM's reply is parsed into a routing decision.
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations_str)
router_prompt = PromptTemplate(
    template=router_template,
    input_variables=["input"],
    output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)

In [None]:
## Connecting the RouterChain with the Destination Chains

# The router chooses a destination name; the matching entry of destination_chains
# answers the input, with default_chain as the fallback. verbose=True prints the
# chain's progress, including the chosen destination.
chain = MultiPromptChain(
    router_chain=router_chain,
    destination_chains=destination_chains,
    default_chain=default_chain,
    verbose=True,
)

In [None]:
# Physics question: expected to be routed to the "physics" destination chain.
print(
    chain.run("What is black body radiation?")
)

In [None]:
# Math question: expected to be routed to the "math" destination chain.
print(
    chain.run(
        "What is the first prime number greater than 40 such that one plus the prime number is divisible by 3"
    )
)

### Embedding Queries

In [None]:
## Setting the Information of the Destination Chains

# (name, [descriptions]) pairs for the embedding router. The names must match the
# keys of destination_chains; the descriptions are embedded, and each input is sent
# to the name whose description is most similar.
names_and_descriptions = [
    (topic, [f"for questions about {topic}"])
    for topic in ("physics", "math")
]

In [None]:
## Setting RouterChain

# Store the embedded destination descriptions in a Chroma vector store. At run time
# the "input" value is embedded and matched against them to choose the destination,
# so no LLM call is made for routing.
router_chain = EmbeddingRouterChain.from_names_and_descriptions(
    names_and_descriptions=names_and_descriptions,
    vectorstore_cls=Chroma,
    embeddings=embedding_llm,
    routing_keys=["input"],
)

In [None]:
## Connecting the RouterChain with the Destination Chains

# Same destination and default chains as in the text-query section; only the
# router differs (embedding similarity instead of an LLM prompt).
chain = MultiPromptChain(
    router_chain=router_chain,
    destination_chains=destination_chains,
    default_chain=default_chain,
    verbose=True,
)

In [None]:
# Same physics question, now routed by embedding similarity.
print(
    chain.run("What is black body radiation?")
)

### Retrieval Queries

In [None]:
## Creating the Destination Retrievers

# Each retrieval destination is a FAISS index exposed as a retriever.
# load_and_split() returns Document objects, so those go through
# FAISS.from_documents(documents=...); plain strings go through
# FAISS.from_texts(texts=...).
# The explicit UTF-8 encoding keeps loading independent of the platform's
# default locale. Paths are relative to the notebook's working directory.
sou_docs = TextLoader('../../state_of_the_union.txt', encoding='utf-8').load_and_split()
sou_retriever = FAISS.from_documents(
    documents=sou_docs,
    embedding=embedding_llm,
).as_retriever()

pg_docs = TextLoader('../../paul_graham_essay.txt', encoding='utf-8').load_and_split()
pg_retriever = FAISS.from_documents(
    documents=pg_docs,
    embedding=embedding_llm,
).as_retriever()

personal_texts = [
    "I love apple pie",
    "My favorite color is fuchsia",
    "My dream is to become a professional dancer",
    "I broke my arm when I was 12",
    "My parents are from Peru",
]
personal_retriever = FAISS.from_texts(
    texts=personal_texts,
    embedding=embedding_llm,
).as_retriever()

In [None]:
## Setting the Information of each Retriever

retriever_infos = [
    {
        "name": "state of the union",
        "description": "Good for answering questions about the 2023 State of the Union address",
        "retriever": sou_retriever
    },
    {
        "name": "pg essay",
        "description": "Good for answer quesitons about Paul Graham's essay on his career",
        "retriever": pg_retriever
    },
    {
        "name": "personal",
        "description": "Good for answering questions about me",
        "retriever": personal_retriever
    }
]

In [None]:
## Connecting the Router Chain with the Destination Chains

# from_retrievers builds a retrieval-QA destination for each entry of
# retriever_infos and a router that chooses among them by name/description.
chain = MultiRetrievalQAChain.from_retrievers(
    llm=llm,
    retriever_infos=retriever_infos,
    verbose=True,
)

In [None]:
# Economy question: expected to be routed to the "state of the union" retriever.
print(
    chain.run("What did the president say about the economy?")
)