In [8]:
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
from typing import Literal
from langchain_groq import ChatGroq
import os
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
groq_api_key=os.getenv("GROQ_API_KEY")
os.environ["GROQ_API_KEY"]=groq_api_key
model=ChatGroq(model="llama-3.1-8b-instant")

In [9]:
class Feedback(BaseModel):
    sentiment: Literal["positive","negative"] = Field(description="Give the sentiment of the feedback.")

In [10]:
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser

parser=StrOutputParser()
pydatnic_parser=PydanticOutputParser(pydantic_object=Feedback)

In [11]:
from langchain_core.prompts import PromptTemplate

# 1. Sentiment Analysis
prompt1=PromptTemplate(
    template="""Classify the sentiment of the following feedback text into positive or negative.
    Feedback: {feedback}
    {format_instructions}""",
    input_variables=["feedback"],
    partial_variables={"format_instructions":pydatnic_parser.get_format_instructions()}
)


In [14]:
classifier_chain = prompt1 | model | pydatnic_parser

In [19]:
result = classifier_chain.invoke({"feedback":"This is a terrible smartphone"})
result

Feedback(sentiment='negative')

In [23]:
from langchain.schema.runnable import RunnableBranch

In [24]:
"""
RunnableBranch(
    (condition1,chain1),
    (condition2,chain2),
    default chain
)
"""

'\nRunnableBranch(\n    (condition1,chain1),\n    (condition2,chain2),\n    default chain\n)\n'

In [25]:
prompt2=PromptTemplate(
    template="Write an response to this positive feedback: {feedback}",
    input_variables=["feedback"]
)

prompt3=PromptTemplate(
    template="Write an response to this negative feedback: {feedback}",
    input_variables=["feedback"]
)

In [28]:
branch_chain = RunnableBranch(
    (lambda x: x.sentiment == "positive", prompt2 | model | parser),
    (lambda x: x.sentiment == "negative", prompt3 | model | parser),
    lambda x: "Failed!"  # By default changes to RunnableLambda
)

In [29]:
branch_chain

RunnableBranch(branches=[(RunnableLambda(lambda x: x.sentiment == 'positive'), PromptTemplate(input_variables=['feedback'], input_types={}, partial_variables={}, template='Write an response to this positive feedback: {feedback}')
| ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x000001F9BD0D2750>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001F9BE28CD10>, model_name='llama-3.1-8b-instant', model_kwargs={}, groq_api_key=SecretStr('**********'))
| StrOutputParser()), (RunnableLambda(lambda x: x.sentiment == 'negative'), PromptTemplate(input_variables=['feedback'], input_types={}, partial_variables={}, template='Write an response to this negative feedback: {feedback}')
| ChatGroq(client=<groq.resources.chat.completions.Completions object at 0x000001F9BD0D2750>, async_client=<groq.resources.chat.completions.AsyncCompletions object at 0x000001F9BE28CD10>, model_name='llama-3.1-8b-instant', model_kwargs={}, groq_api_key=SecretStr('*

In [32]:
chain = classifier_chain | branch_chain

In [37]:
chain.invoke({"feedback" : "This is a terrible Smartphone."})

"I apologize that our product/service did not meet your expectations. Can you please provide more details about your experience so we can better understand what went wrong and how we can improve? Your feedback is invaluable to us, and we appreciate your honesty. \n\nIf there's anything we can do to rectify the situation, please let us know and we'll do our best to resolve the issue. We're committed to providing the best possible experience for our customers, and we'll use your feedback to make positive changes in the future."

In [38]:
chain.get_graph().print_ascii()

    +-------------+      
    | PromptInput |      
    +-------------+      
            *            
            *            
            *            
   +----------------+    
   | PromptTemplate |    
   +----------------+    
            *            
            *            
            *            
      +----------+       
      | ChatGroq |       
      +----------+       
            *            
            *            
            *            
+----------------------+ 
| PydanticOutputParser | 
+----------------------+ 
            *            
            *            
            *            
       +--------+        
       | Branch |        
       +--------+        
            *            
            *            
            *            
    +--------------+     
    | BranchOutput |     
    +--------------+     
