In [1]:
import os 
from dotenv import load_dotenv
from langchain_groq.chat_models import ChatGroq

# Load variables from a local .env file into the process environment so the
# API keys used below can be read with os.getenv.
load_dotenv()

True

In [2]:
# Take the Groq credential from the environment instead of hard-coding it.
groq_api_key = os.environ.get('GROQ_API_KEY')

# Chat model that the later cells reuse.
llm = ChatGroq(
    model="gemma2-9b-it",
    groq_api_key=groq_api_key,
)

# Smoke test: one round trip to the model, displaying only the reply text.
llm.invoke("hi").content

'Hi there! 👋  How can I help you today? 😊\n'

In [3]:
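# Optional interactive chat loop (disabled). Uncomment to talk to the model
# from this cell; type "quit" or "esc" to stop.
# while True:
#     question = input("Enter your question: ")
#     if question in ["quit", "esc"]:
#         break
#     result = llm.invoke(question).content
#     print(result)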

#### Memory

In [4]:
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

In [5]:
# Maps a session id to its chat history; entries are created on demand by
# get_session_history.
store: dict[str, BaseChatMessageHistory] = {}

In [6]:
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history kept for ``session_id``.

    The history is looked up in the module-level ``store``. When the session
    has no entry yet, a fresh ``InMemoryChatMessageHistory`` is created,
    saved under that id, and returned.
    """
    try:
        return store[session_id]
    except KeyError:
        # First time this session is seen: start it with an empty history.
        history = InMemoryChatMessageHistory()
        store[session_id] = history
        return history

In [7]:
# Passed as `config=` on each call to the history-aware model; calls that
# share this session id share one stored conversation.
config = {
    "configurable": {
        "session_id": "1",
    },
}

In [8]:
# Wrap the chat model so each call uses the history returned by
# get_session_history for the session named in the call's config.
model_with_memory = RunnableWithMessageHistory(
    llm,
    get_session_history,
)

In [9]:
# First turn: state a name so the session history has something to recall.
model_with_memory.invoke("Hi, my name is sam", config=config).content

'Hi Sam, nice to meet you!\n\nHow can I help you today? 😊\n'

In [10]:
# Second turn with the same config: the name can only be answered from the
# earlier message kept in the session history.
model_with_memory.invoke("Tell me my name", config=config).content

'Your name is Sam!  😄  \n\nIs there anything else I can do for you?  \n'

#### RAG

In [11]:
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain.vectorstores import FAISS

In [13]:
# Read the Google key into a variable first. Assigning the result of
# os.getenv straight into os.environ raises an opaque
# "TypeError: str expected, not NoneType" when the variable is missing,
# so fail here with an actionable message instead.
google_api_key = os.getenv('GOOGLE_API_KEY')
if not google_api_key:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; add it to your .env file or export it "
        "before running this cell."
    )

# The embeddings client below is built without an explicit key, so keep the
# key available in the process environment (presumably where it is read from).
os.environ['GOOGLE_API_KEY'] = google_api_key

# Embedding model used to vectorise the document chunks for the vector store.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

In [14]:
# Collect the files under ../data that match ./*.txt, reading each one with
# TextLoader.
loader = DirectoryLoader(
    "../data",
    glob="./*.txt",
    loader_cls=TextLoader,
)
docs = loader.load()

In [15]:
# Number of documents returned by the loader.
len(docs)

3

In [16]:
# Cut the loaded documents into chunks (chunk_size=500, chunk_overlap=50)
# so each piece can be embedded and retrieved on its own.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
)
documents = splitter.split_documents(docs)

# Show how many chunks the split produced.
len(documents)

21

In [18]:
# Embed the chunks and build a FAISS vector store from them.
db = FAISS.from_documents(documents, embeddings)

In [19]:
# Expose the vector store as a retriever, using its default search settings.
retriever = db.as_retriever()

In [21]:
from langchain_core.runnables import RunnablePassthrough, RunnableMap
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

In [22]:
# JSON output parser: supplies the format instructions placed in the prompt
# and is the last step of the first RAG chain.
parser = JsonOutputParser()

In [23]:
# Prompt text for the RAG chain. Placeholders:
#   {context}             - retrieved document text
#   {format_instructions} - output-format guidance from the JSON parser
#   {query}               - the user's question
template = """You are an AI Assistant tasked with answer the following question based on 
provided content {context}\n\n{format_instructions}
question: {query}
"""

prompt = PromptTemplate(
    template=template,
    # Declare every placeholder that is filled at invoke time. The template
    # needs both `context` and `query`; only `format_instructions` is
    # pre-filled through partial_variables.
    input_variables=["context", "query"],
    partial_variables={"format_instructions": parser.get_format_instructions()}
)

In [24]:
def format_docs(docs):
    """Join the ``page_content`` of each item in ``docs`` into one string.

    Consecutive contents are separated by a blank line; an empty input
    yields an empty string.
    """
    separator = "\n\n"
    contents = [doc.page_content for doc in docs]
    return separator.join(contents)

In [25]:
# RAG pipeline:
#   1. build the prompt inputs - "context" is the retrieved text joined by
#      format_docs, "query" is the question passed through unchanged;
#   2. fill the prompt; 3. call the chat model; 4. parse the reply as JSON.
chain = RunnableMap(
    {
        "context": retriever | format_docs,
        "query": RunnablePassthrough(),
    }
) | prompt | llm | parser

In [26]:
# Ask a question against the indexed files; the cell displays the parsed JSON.
chain.invoke("In 2022 what's the GDP of Japan?")

{'answer': '$4,256.41B'}

#### Structured output with Pydantic

In [27]:
from pydantic import BaseModel, Field

# Schema for the structured answer requested from the model via
# `llm.with_structured_output(Result)`.
# NOTE(review): the Field descriptions are part of the schema, so they are
# left exactly as written; notes are kept in comments, not a docstring.
class Result(BaseModel):
    # Country the answer refers to.
    country: str = Field(description="country name")
    # The answer itself, kept as a string.
    answer: str = Field(description="provide appropriate answer")

In [28]:
# Variant of the chat model whose replies are returned as `Result` objects.
llm_with_structured = llm.with_structured_output(Result)

In [31]:
# Same retrieval and prompt steps as the JSON chain, but ending in the
# structured-output model, so no separate parser step is needed.
chain = RunnableMap(
    {
        "context": retriever | format_docs,
        "query": RunnablePassthrough(),
    }
) | prompt | llm_with_structured

In [32]:
# Same question as the JSON chain; this time the cell displays a `Result`.
chain.invoke("In 2022 what's the GDP of Japan?")

Result(country='Japan', answer='4,256.41B')

In [34]:
# Keep the structured reply so its fields can be inspected one at a time.
# This is a fresh model call, so the text may differ from the previous cell.
response = chain.invoke("In 2022 what's the GDP of Japan?")

In [35]:
# `answer` field of the structured reply.
response.answer

'$4,256.41B'

In [36]:
# `country` field of the structured reply.
response.country

'Japan'

#### Adding cast

In [37]:
from typing import cast

from langchain_core.runnables import Runnable

# Static type checkers cannot tell that invoking the structured chain yields
# a `Result`, so the chain is cast to a Runnable that maps a question string
# to a `Result`. The cast target must describe the chain, not its output:
# casting the chain to `Result` itself would label the Runnable as the model
# and make `chain.invoke(...)` a type error.
# `cast` does nothing at runtime, so the chain behaves exactly as without it.
chain = cast(
    Runnable[str, Result],
    RunnableMap({"context": retriever | format_docs, "query": RunnablePassthrough()})
    | prompt
    | llm_with_structured,
)

In [38]:
# `cast` only affects static typing; at runtime the call still returns a `Result`.
chain.invoke("In 2022 what's the GDP of Japan?")

Result(country='Japan', answer='4,256.41B')

In [39]:
# Store the reply from the cast chain for field access below.
response = chain.invoke("In 2022 what's the GDP of Japan?")

In [40]:
# `country` field of the reply from the cast chain.
response.country

'Japan'