In [7]:
import base64
import boto3
import logging
import nest_asyncio
import operator
import os
import re
import time
import uuid

from botocore.config import Config
from collections import defaultdict
from dotenv import load_dotenv
from IPython.display import display, HTML
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
from langgraph.graph import END, StateGraph
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain_core.tools import StructuredTool
from typing import Annotated, Dict, List, Sequence, TypedDict, DefaultDict, Any, Optional
from typing_extensions import TypedDict

In [2]:
logger = logging.getLogger(__name__)

# Set the level unconditionally. hasHandlers() is also True when an ancestor
# (e.g. the root logger) already has a handler; with the level set only inside
# the guard, this logger would stay at NOTSET in that case and inherit the
# root's WARNING level, silently dropping every logger.info(...) call.
logger.setLevel(logging.INFO)

# Attach a console handler only when neither this logger nor an ancestor has
# one, so re-running the cell (or a pre-configured root logger) does not
# produce duplicated log lines.
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
    logger.addHandler(handler)

### Load LLMs and Embedding Models

In [3]:
# Populate os.environ from a local .env file (if present) before reading keys.
load_dotenv()

# API credentials read from the environment; each is None when unset.
CLAUDE_API_KEY = os.environ.get('CLAUDE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')

In [4]:
# LlamaIndex's OpenAI LLM takes the key as `api_key`. The LangChain-style
# `openai_api_key=` / `streaming=True` kwargs are not options of this class, so
# the explicit key was not being used. Streaming in LlamaIndex is chosen per
# call (e.g. `stream_complete`), not at construction.
# NOTE(review): this variable shadows the name of the `openai` SDK package;
# it is kept because later cells refer to it by this name.
openai = OpenAI(model="gpt-3.5-turbo-0125", api_key=OPENAI_API_KEY, temperature=0.0)

In [5]:
# OpenAI embedding model. No api_key is passed here, so presumably the client
# resolves it from the OPENAI_API_KEY environment variable — TODO confirm.
# (A local alternative, "BAAI/bge-large-en-v1.5" via TextEmbedding, was tried
# previously; TextEmbedding is not imported in this notebook.)
openai_embed_model = OpenAIEmbedding(
    model_name="text-embedding-3-small",
)

### Define the Graph State Schema (`GraphState`)

In [6]:
class GraphState(TypedDict):
    """
    State carried through the agent graph.

    Presumably used as the state schema of a LangGraph ``StateGraph``
    (``langgraph.graph.StateGraph`` is imported in this notebook) — TODO
    confirm where it is wired in.

    Attributes:
        query (str): The user query
        expanded_queries (List[str]): The expanded queries generated by the agent
        agent (str): The agent responsible for decision making/answer generating
        context (str): The context retrieved from the DB
        answer (str): The answer generated by the agent
    """
    # Original user question.
    query: str
    # Query variants, e.g. as returned by query_expansion().
    expanded_queries: List[str]
    # Name of the agent currently responsible for the next step.
    agent: str
    # Retrieved context that answer generation works from.
    context: str
    # Final answer text.
    answer: str

### Define Agents
1. Query Expansion Agent
2. Retrieval Agent
3. Grading Agent
4. Answer Generation Agent

In [None]:
def query_expansion(query: str) -> List[str]:
    """
    Expand a user query into several alternative search queries using the LLM.

    The prompt asks the model for 3 variants of ``query`` and for output in the
    JSON form ``{"expanded_queries": [...]}``; that list is returned.

    Args:
        query (str): The user query

    Returns:
        List[str]: The expanded queries

    Raises:
        KeyError: If the parsed JSON has no "expanded_queries" key.
    """
    prompt = """You are a creative AI assistant specializing in expanding user queries to make them more comprehensive and diverse. Your goal is to generate multiple variant queries based on the initial user query, capturing different aspects, synonyms, related terms, and broader or narrower contexts. Ensure that the expanded queries are relevant, diverse, and avoid repetition.

    ### Instructions:
    1. Take the initial query provided by the user.
    2. Generate 3 variant queries that explore different interpretations, related topics, or alternative phrasings.
    3. Ensure the variants cover a range of specific to broad scopes and use synonyms or related terms.
    4. Avoid repeating the same information or using overly similar phrasing.
    5. Output the expanded queries in a JSON format, following the examples provided.
    6. Do not include any preamble, explanation, or additional information beyond the expanded queries in the given JSON format.

    ### Examples:
    Query: "machine learning algorithms"  
    Response:  
    {{
        "expanded_queries": [
            "types of machine learning algorithms",
            "applications of supervised learning techniques",
            "deep learning vs traditional machine learning approaches"
        ]
    }}

    ### Your Task:
    - Query: "{query}"
    - Response:
    """

    # `{query}` is the only template variable. The doubled braces in the JSON
    # example are f-string escapes that render as literal `{` / `}`; writing the
    # placeholder as `{{query}}` would send the literal text "{query}" instead
    # of the user's query.
    prompt_template = PromptTemplate(
        template=prompt,
        input_variables=["query"],
    )

    # `openai` is a LlamaIndex LLM, not a LangChain Runnable, so it cannot be
    # piped directly. Adapt it: render the prompt to text, call complete(), and
    # hand the raw completion text to the JSON parser.
    llm = RunnableLambda(
        lambda prompt_value: openai.complete(prompt_value.to_string()).text
    )

    chain = prompt_template | llm | JsonOutputParser()
    query_list = chain.invoke({"query": query})
    return query_list["expanded_queries"]