In [147]:
from typing import TypedDict, List, Annotated, Union
from langgraph.graph.message import add_messages
from google import genai
import dotenv, os
from langgraph.graph import StateGraph, END, START
from model.db import chroma_client
from google.genai import types
import requests, json
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage


dotenv.load_dotenv()

True

In [148]:
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
MODEL = "gemini-2.0-flash"
llm = ChatOllama(
    model="llama3.2",temperature=0.7
)
server_url = "http://147.93.29.19:9876"

In [149]:
class SessionState(TypedDict):
    id: str ## unique identifier for the session as well as collection name
    session_name: str ## name of the session
    
    feedback: Annotated[List[dict], add_messages]
    messages: Annotated[List[dict], add_messages]
    
    invoke: str

In [150]:
def interface_agent(state: SessionState):
    print(state.get("invoke"))
    if state.get("invoke") == "message":
        state["invoke"] = "classify"
        return {
            "next_nodes": ["user_message_classifier"],  # Specify the next node(s)
            "state": state  # Return the updated state
        }
    
    if state.get("invoke") == "ExtractData":
        state["invoke"] = "schema_definer"
        return {
            "next_nodes": ["schema_definer"],  # Specify the next node(s)
            "state": state  # Return the updated state
        }
        
    if state.get("invoke") == "extract_data":
        return {
            "next_nodes": ["extract_data"],  # Specify the next node(s)
            "state": state  # Return the updated state
        }
        
    if state.get("invoke") == "done":
        return {
            "next_nodes": [END],  # Indicate the end of the graph
            "state": state  # Return the updated state
        }
        
    if state.get("invoke") == "SuggestOrCreateSchema":
        return {
            "next_nodes": ["suggest_schema"],  # Specify the next node(s)
            "state": state  # Return the updated state
        }
    return {
        "next_nodes": [END],  # End the graph if no valid invoke value is found
        "state": state  # Return the state as is
    }

In [151]:
def user_message_classifier(state: SessionState):
    system_prompt = '''Classification Categories:

SummarizeContext
The user is asking for a summary of the current context or conversation. Mention in pointswise and easily understandable format. The summary should be mentioned in reason field
➤ Examples:
"Can you summarize our conversation so far?"
"What have we discussed in this session?"
"Summarize the key points from our chat."

UpdateContext
The user wants to store information in the agent's internal state for future reference.
➤ Examples:
"The client's name is John."
"We're organizing a conference on June 5th."
"Add this to our team knowledge."

SuggestOrCreateSchema
The user is requesting help designing a schema or structured format to extract information.
Design a prompt for the next agent to create a schema based on the user's request. mention it in the reason field.
➤ Examples:
"Can you create a schema for extracting details from job applications?"
"Suggest a structure to capture meeting notes."
"What format should I use to store bug report info?"

ExtractData
The user provides a schema and wants you to apply it to extract structured data from text.
design the shcema in the form of json string based on the user request
➤ Examples:
"Here's a schema: {name, date, location}. Extract this from the paragraph below."
"Use this format and pull details from the email."
"extract data with fileds including title, name and age"
"Apply this structure: {title, author, summary} to the following.

## Output Format:
 {
  "classification": "UpdateContext" | "SuggestOrCreateSchema" | "ExtractData" | "SummarizeContext",
  "reason": "Brief explanation to use it as a prompt for the next agent",
  "schema": "The schema to be used for the next agent(used in ExtractData)",
  "update_context": "The context to be updated in the agent's internal state(used in UpdateContext)",
}'''
    
    query = state["messages"][-1].content
    
    response = client.models.generate_content(
        model=MODEL,
        config=types.GenerateContentConfig(system_instruction=system_prompt),
        contents=query
    )
    
    
    response = response.text
    response = json.loads(response[7:-3])
    val = response["classification"]
    state["invoke"] = val
    result = "create a schema to extract event name and date"
    
    if state["invoke"] == "UpdateContext":
        result = response["update_context"]
    elif state["invoke"] == "SuggestOrCreateSchema":
        result = response["reason"]
    elif state["invoke"] == "SummarizeContext":
        result = response["reason"]
    elif state["invoke"] == "ExtractData":
        result = response["schema"]
        
    return { **state, "invoke": state["invoke"], "feedback": [{"role": "assistant", "content": result}] }

In [152]:
def schema_definer(state: SessionState):
    system_prompt = '''Role: You are a Schema Design Agent. Your task is to analyze markdown-formatted data from various file types (e.g., CSVs, PDFs, JSON, slides) and produce a structured schema that accurately represents the data model.

      Objective: Extract entities, fields, relationships, and data types from the given content and output a standardized schema (e.g., JSON Schema, Prisma model, or custom format).

      Instructions:
      1. **Parse Structure**: Analyze tabular, hierarchical, or semantic layouts (e.g., headings, tables, key-value pairs).
      2. **Identify Entities and Fields**: Detect entities (e.g., tables/objects) and their fields, data types, and relationships (e.g., primary/foreign keys).
      3. **Infer Data Types**: Deduce types (string, number, boolean, date, etc.) from the content.
      4. **Preserve Naming**: Use semantic, human-readable names and normalize conventions (e.g., camelCase, snake_case).
      5. **Handle Ambiguities**: Make informed guesses and flag uncertainties for review.
      6. **No External Assumptions**: Only rely on the provided data.

      Output Format: Return the schema in a structured format (e.g., JSON Schema, Prisma schema).

      Example Output:
      {
        "entities": [
          {
            "name": "User",
            "fields": [
              {"name": "id", "type": "string", "description": "Unique identifier"},
              {"name": "email", "type": "string"},
              {"name": "signupDate", "type": "date"}
            ]
          },
          {
            "name": "Order",
            "fields": [
              {"name": "orderId", "type": "string"},
              {"name": "userId", "type": "string", "relation": "User.id"},
              {"name": "amount", "type": "number"}
            ]
          }
        ]
      }
      '''
    
    response = requests.post(
      server_url + "/searchData",
      headers={"Content-Type": "application/json"},
      data=json.dumps({
          "collection_id": state["id"],
          "text": state["feedback"][-1].content,
          "n_results": 3,
        }),
    )
    
    message = {
      "context": response.json()["results"][0]["text"],
      "user": state["feedback"][-1].content,
    }
    
    result = client.models.generate_content(
        model=MODEL,
        config=types.GenerateContentConfig(system_instruction=system_prompt),
        contents=str(message)
    )
    result = result.text[7:-3].replace("\n", ' ')
    
    state["invoke"] = "extract_data"
    
    return { **state, "invoke": state["invoke"], "feedback": [{"role": "assistant", "content": result}] }

In [153]:
def extract_data(state: SessionState):
    system_prompt = '''Role: You are a Data Extraction Agent. Your task is to extract structured data from the provided chunks of text based on the schema provided by the user.

Objective: Use the schema to identify and extract relevant information from the input data and return it in a structured format.

Instructions:
1. **Understand the Schema**: Parse the schema provided by the user to determine the fields and their expected data types.
2. **Extract Data**: Analyze the input text and extract values that match the schema fields.
3. **Handle Missing Fields**: If a field is missing in the input data, leave it as null or empty in the output.
4. **Preserve Data Integrity**: Ensure the extracted data matches the expected format and type as defined in the schema.
5. **No External Assumptions**: Only rely on the provided schema and input data for extraction.
6. **If nothing is found, return an empty JSON object.**

Output Format:
Return the extracted data as a JSON object that adheres to the schema provided by the user.
'''

    response = requests.post(
        server_url + "/searchData",
        json={
            "collection_id": state["id"],
            "text": state["feedback"][-1].content,
            "n_results": 3,
        }
    )
    response = response.json()
    extracted_data = []
    
    for doc in response["results"]:
        extracted_data.append(doc["text"])
        
    result = client.models.generate_content(
        model=MODEL,
        config=types.GenerateContentConfig(system_instruction=system_prompt),
        contents=str({
            "context": extract_data,
            "user": state["feedback"][-1].content,
        })
    )

    result = result.text[7:-3].replace("\n", ' ')
    # result = json.loads(result)
        
    state["invoke"] = "done"
    return {
        **state,
        "invoke": state["invoke"],
        "feedback": [{"role": "assistant", "content": result}],
    }

In [154]:
def suggest_schema(state: SessionState):
    system_prompt = '''Role: You are a Schema Suggestion Agent. Your task is to analyze the user's input and suggest a schema for data extraction.
Objective: Based on the user's request, create a schema that defines the structure and fields for data extraction.
Instructions:
1. **Understand the User's Request**: Analyze the user's input to identify the type of data they want to extract.
2. **Define the Schema**: Create a schema that includes entities, fields, and their data types based on the user's request.
3. **Use Standard Formats**: Structure the schema in a standard format (e.g., JSON Schema, Prisma model).
4. **Be Specific**: Ensure the schema is specific to the user's request and covers all necessary fields.
5. **No External Assumptions**: Only rely on the user's input for schema creation.
Output Format:
Return the suggested schema as a JSON object that defines the structure for data extraction.
Example Output:
{
  "entities": [
    {
      "name": "Event",
      "fields": [
        {"name": "eventName", "type": "string"},
        {"name": "eventDate", "type": "date"},
        {"name": "location", "type": "string"}
      ]
    }
  ]
}
'''
    
    # query = state["messages"][-1].content
    
    response = requests.post(
        server_url + "/searchData",
        json={
            "collection_id": state["id"],
            "text": state["feedback"][-1].content,
            "n_results": 3,
        }
    )
    response = response.json()
    extracted_data = []
    
    for doc in response["results"]:
        extracted_data.append(doc["text"])
        
    result = client.models.generate_content(
        model=MODEL,
        config=types.GenerateContentConfig(system_instruction=system_prompt),
        contents=str({
            "context": extract_data,
            "user": state["messages"][-1].content,
        })
    )
    
    result = response.text[7:-3].replace("\n", ' ')
    # result = json.loads(result)
    print(result.text[7:-3].replace("\n", ' '))
    # result = "hello"    
    state["invoke"] = "extract_data"
    
    return { **state, "invoke": state["invoke"], "feedback": [{"role": "assistant", "content": result}] }

In [155]:
graph_builder = StateGraph(SessionState)

In [156]:
graph_builder.add_node("interface_agent", interface_agent)
graph_builder.add_node("user_message_classifier", user_message_classifier)
graph_builder.add_node("schema_definer", schema_definer)
graph_builder.add_node("extract_data", extract_data)
graph_builder.add_node("suggest_schema", suggest_schema)

graph_builder.set_entry_point("interface_agent")
graph_builder.add_conditional_edges(
    "interface_agent",
    lambda state: interface_agent(state)["next_nodes"],  # Extract next_nodes
    {
        "user_message_classifier": "user_message_classifier",
        "schema_definer": "schema_definer",
        "extract_data": "extract_data",
        "suggest_schema": "suggest_schema",
        END: END
    },
)
graph_builder.add_edge("user_message_classifier", "interface_agent")
graph_builder.add_edge("schema_definer", "interface_agent")
graph_builder.add_edge("extract_data", "interface_agent")
graph_builder.add_edge("suggest_schema", "interface_agent")

<langgraph.graph.state.StateGraph at 0x7d17060c58d0>

In [157]:
graph = graph_builder.compile()

In [158]:
sample_input = {
    "id": "33dc47d9b2ed42f4b769c3d225ea2d4c",
    "session_name": "Test Session",
    "feedback": [],
    "messages": [
        {"role": "user", "content": "suggest me a schema to extract information from the given data about event details"},
        # {'role': 'user', 'content': 'Can you add a new event with name "Annual Meeting", date "2023-10-15", and location "New York"?'},
                # {"role": "user", "content": "extract the event details from the given data, including the event name, date, and location."},
    ],
    "invoke": "message"
}

In [159]:
output = graph.invoke(sample_input)

message
message
SuggestOrCreateSchema
SuggestOrCreateSchema


AttributeError: 'dict' object has no attribute 'text'

In [None]:
schema = output["feedback"][-1]

In [None]:
json.loads(schema.replace("\n", ' '))

AttributeError: 'AIMessage' object has no attribute 'replace'