In [1]:
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from pydantic import BaseModel, Field
from typing import List

In [2]:
# Worked examples for the few-shot prompt. Each "response" is the text of a
# JSON array holding three reformulations of "original_query"; the adjacent
# string fragments below are joined by the parser into one literal.
multi_query_examples = [
    {
        "original_query": "What are the effects of chronic insomnia on cognitive performance?",
        "response": (
            '["How does long-term insomnia impact cognitive functions such as memory and attention?", '
            '"What are the cognitive consequences of chronic sleep deprivation and insomnia?", '
            '"Does chronic insomnia lead to a decline in specific cognitive abilities like decision-making or memory?"]'
        ),
    },
    {
        "original_query": "What are the effects of insomnia on adolescents' mental health?",
        "response": (
            '["How does insomnia influence the mental well-being of adolescents, including conditions like anxiety and depression?", '
            '"What mental health challenges do adolescents face as a result of chronic insomnia?", '
            '"Does insomnia in adolescents contribute to an increased risk of developing anxiety or depression?"]'
        ),
    },
]

# Renders one example as a two-turn exchange: the human turn carries the
# original query, the AI turn carries the example's response text.
example_sub_query_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "Original Query: {original_query}"),
        ("ai", "{response}"),
    ]
)

# Few-shot block built from multi_query_examples using the per-example
# template defined just above.
few_short_prompt = FewShotChatMessagePromptTemplate(
    examples=multi_query_examples,
    example_prompt=example_sub_query_prompt,
)

In [3]:
class MultiQuerySchema(BaseModel):
    """List of all queries generated by a LLM from a original query."""
    # This class is handed to llm.with_structured_output(...), and the chain's
    # result comes back as an instance of it.
    # NOTE(review): the docstring above and the Field description below are
    # presumably forwarded to the model as part of the schema, so their wording
    # acts as prompt text — confirm before rewording either one.
    # sub_queries: the generated reformulations, one string per query.
    sub_queries: List[str] = Field(description="list of all queries generated by a LLM from a original query.")

In [4]:
# Chat model served through Groq, plus a wrapper whose output is parsed into
# MultiQuerySchema rather than returned as a free-form message.
llm = ChatGroq(model="llama3-70b-8192")
llm_with_structure = llm.with_structured_output(MultiQuerySchema)

# Task instructions placed in the system turn of the prompt.
system = """You are an AI assistant. Generate three distinct yet related queries based on the user's original query. Each new query should vary slightly to broaden the scope for document retrieval, improving the chances of finding relevant information. The new queries can include variations, expanded forms, or rephrased versions, but should maintain the core intent of the original query.
Your response must be a list of new queries.
"""

# Message order: instructions, then the few-shot examples, then the caller's
# query as the final human turn.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        few_short_prompt,
        ("human", "Original Query: {original_query}"),
    ]
)

In [5]:
# Pipe the prompt into the structured-output model; invoking the chain
# yields a MultiQuerySchema instance (see the recorded output below).
chain = prompt | llm_with_structure

# Bind the query to the template variable by name. Passing a bare string
# only works while the prompt has exactly one input variable; the explicit
# mapping keeps working if another placeholder is added later.
response = chain.invoke(
    {"original_query": "Insomnia symptoms and non-suicidal self-injury in adolescence"}
)

In [6]:
# Display the chain result; the recorded output is a MultiQuerySchema whose
# sub_queries list holds three generated queries.
response

MultiQuerySchema(sub_queries=['What is the relationship between insomnia symptoms and non-suicidal self-injury in adolescents?', 'How do insomnia symptoms contribute to the risk of non-suicidal self-injury in adolescents?', 'What are the underlying mechanisms linking insomnia symptoms to non-suicidal self-injury in adolescents?'])