In [15]:
import os
from dotenv import load_dotenv
load_dotenv()

os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")

In [16]:
from typing_extensions import TypedDict, List

class StateFlow(TypedDict):
    text : str
    classification: str
    entities: List[str]
    summary : str

In [17]:
from langchain_groq import ChatGroq

class GroqLLM:
    def __init__(self, model_name: str = "deepseek-r1-distill-llama-70b"):
        self.model_name = model_name
        self.llm = None
    
    def load_groq_llm(self, temperature: float = 0, max_tokens: int =1000):
        """loads required groq model"""
        try:
            if "llama" in self.model_name:
                self.llm = ChatGroq(model="llama3-70b-8192")
                #return llm
            elif "deepseek" in self.model_name:
                self.llm = ChatGroq(model="deepseek-r1-distill-llama-70b")
                #return llm
            elif "gemma" in self.model_name:
                self.llm = ChatGroq(model="gemma2-9b-it")
                #return llm
            elif "qwen" in self.model_name:
                self.llm = ChatGroq(model="qwen-qwq-32b")
            else:
                raise f"model name was not given"
            return  self.llm
        except Exception as e:
            raise f"Error occured as: {e}"

In [18]:
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage

class AgentNodes:
    def __init__(self):
        self.llm = GroqLLM().load_groq_llm()
        #self.state = StateFlow


    def classification_node(self, state: StateFlow):
        """Classifies input text as News, Blog, Research Paper, or Other."""
        template = """
            Classify the following text into one of the categories: News, Blog, Research, or Other.
            \n\nText:{text}
            \n\nCategory:
        """
        prompt = PromptTemplate(
        input_variables=["text"],
        template= template
        )
        message = HumanMessage(content=prompt.format(text=state["text"]))
        classification = self.llm.invoke([message]).content.strip()
        return {"classification": classification}

    
    def entity_extraction_node(self, state: StateFlow):
        """Extract all the entities (Person, Organization, Location) from the text"""
        template = """
            Extract all the entities (Person, Organiztion, Location) from the following text, 
            provide result in the form of comma-separated values in a list.
            \n\nText: {text}
            \n\nEntities:  
        """
        prompt = PromptTemplate(
        input_variables=["text"],
        template= template
        )
        message = HumanMessage(content=prompt.format(text=state["text"]))
        entities = self.llm.invoke([message]).content.strip().split(", ")
        return {"entities": entities}
    
    def summary_node(self, state: StateFlow):
        """Summarize the text in on short sentence"""
        template = """
            Summarize the following text into one short sentence
            \n\nText: {text}
            \n\nSummary:
        """
        prompt = PromptTemplate(
        input_variables=["text"],
        template= template
        )
        message = HumanMessage(content=prompt.format(text=state["text"]))
        summary = self.llm.invoke([message]).content.strip()
        return {"summary": summary}

In [19]:
from langgraph.graph import StateGraph, START, END
from IPython.display import display, Image

class GraphBuilder:
    def __init__(self):
        self.nodes = AgentNodes()
        self.graph_builder = StateGraph(StateFlow)
        self.text_analysis_graph()  # Ensure edges and nodes are added

    def text_analysis_graph(self):
        """Creates the flow for the graph"""
        self.graph_builder.add_node("Text Classification", self.nodes.classification_node)
        self.graph_builder.add_node("Entities Extraction", self.nodes.entity_extraction_node)
        self.graph_builder.add_node("Text Summary", self.nodes.summary_node)

        self.graph_builder.add_edge(START, "Text Classification")
        self.graph_builder.add_edge("Text Classification", "Entities Extraction")
        self.graph_builder.add_edge("Entities Extraction", "Text Summary")
        self.graph_builder.add_edge("Text Summary", END)

    def setup_graph(self):
        """Compiles the graph"""
        return self.graph_builder.compile()

    def get_graph_image(self):
        """Visualizes the workflow as an image"""
        graph = self.setup_graph()
        display(Image(graph.get_graph().draw_mermaid_png()))


In [22]:
import nest_asyncio
nest_asyncio.apply()

async def get_response(text:str):
    """Provides response using the graph"""
    graph = GraphBuilder().setup_graph()
    response = await graph.ainvoke({"text": text})
    return response

async def get_graph():
    """creates workflow image"""
    obj = GraphBuilder()
    graph = await  obj.get_graph_image()
    return graph

In [23]:
if __name__ == "__main__":
    import asyncio
    text = input("Enter your text: ")
    result = asyncio.run(get_response(text))
    print("Classification:", result["classification"])
    print("\nEntities:", result["entities"])
    print("\nSummary:", result["summary"])

    graph = get_graph()
    print(graph)

Classification: Category: News

Entities: ['Here are the extracted entities in the form of comma-separated values:\n\nOpenAI', 'GPT-4', 'GPT-3']

Summary: OpenAI has announced GPT-4, a large multimodal model that achieves human-level performance on various benchmarks and is designed to be more efficient and safer than its predecessor.
<coroutine object get_graph at 0x000001E6BEB59F20>
