In [None]:
from typing import Annotated,Literal
from langgraph.graph import StateGraph,START,END
from langgraph.graph.message import add_messages
from pydantic import BaseModel,Field
from typing_extensions import TypedDict

# Basic agent functions

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification , BitsAndBytesConfig
# Routing classifier: the tokenizer and the 8-bit sequence-classification model
# are both read from the same local checkpoint directory.
# NOTE: the global name `tokenizer` is rebound further down, in the cell that
# loads the causal LM, so it only refers to this tokenizer until that cell runs.
model_name = "/home/siamai/data/Focus/agentic/notebooks/model/xlm_routing"
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_classifier = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # two classes; classify_xlm maps them to "multiple"/"prediction"
    quantization_config=quantization_config,
    device_map="auto",
)

In [None]:
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

# Load the causal LM in 8-bit via BitsAndBytesConfig and wrap it in a
# text-generation pipeline.
model_id = "/home/siamai/data/huggingface/hub/models--tarun7r--Finance-Llama-8B/snapshots/7934db35d2374c1321b90a9deb0d84b97525b025"

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_multiple = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    # NOTE(review): trust_remote_code permits code shipped with the checkpoint
    # to run; confirm the snapshot is trusted.
    trust_remote_code=True,
)

# Tokenizer for the causal LM. This rebinds the global `tokenizer` that an
# earlier cell created for the routing classifier.
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    # Fall back to the EOS token when the checkpoint defines no pad token.
    tokenizer.pad_token = tokenizer.eos_token

# return_full_text=False is forwarded to the pipeline so the prompt is not
# echoed back in the output.
generator = pipeline(
    "text-generation",
    model=model_multiple,
    tokenizer=tokenizer,
    return_full_text=False,
)

In [None]:
_CLASSIFIER_TOKENIZER = None  # filled on first use by _classifier_tokenizer()


def _classifier_tokenizer():
    """Return the tokenizer stored with the routing classifier, loading it once."""
    global _CLASSIFIER_TOKENIZER
    if _CLASSIFIER_TOKENIZER is None:
        # `model_name` is the checkpoint directory the classifier was loaded from.
        _CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    return _CLASSIFIER_TOKENIZER


def classify_xlm(text:str):
    """Route a query to either the "multiple" or the "prediction" branch.

    The text is encoded with the classifier's own tokenizer rather than the
    global `tokenizer`: that global is rebound to the causal LM's tokenizer
    when the generator is loaded, and encoding with it would feed the
    classifier token ids from a different vocabulary.

    Args:
        text: The query to classify. Inputs longer than 512 tokens are truncated.

    Returns:
        "multiple" when the arg-max class index is 0, "prediction" when it is 1.
    """
    id_to_label = {0: "multiple", 1: "prediction"}
    encoded = _classifier_tokenizer()(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(model_classifier.device)  # follow the model's device instead of assuming "cuda"
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model_classifier(**encoded).logits
    return id_to_label[logits.argmax(dim=1).item()]

def prediction_answer(text: str):
    """Placeholder for answering prediction-type queries; not implemented yet.

    Args:
        text: The query text (currently unused).

    Returns:
        None.
    """
    return None

def multiple_answer(text: str, system_prompt: str):
    """Generate a short answer to a multiple-choice query with the `generator` pipeline.

    Args:
        text: The user query, including the question and its answer choices.
        system_prompt: Instruction text placed ahead of the user query.

    Returns:
        The `generated_text` of the first pipeline output, with leading and
        trailing whitespace stripped.
    """
    # Flatten the two chat turns into "Role: content" lines, one per turn.
    turns = (("system", system_prompt), ("user", text))
    prompt = "\n".join(f"{role.capitalize()}: {content}" for role, content in turns)

    generation_kwargs = dict(
        max_new_tokens=32,  # the expected answer is short, so keep generation brief
        do_sample=True,
        temperature=0.3,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=3,  # beam search over 3 beams, with sampling enabled above
        early_stopping=True,
        use_cache=True,
    )
    outputs = generator(prompt, **generation_kwargs)

    return outputs[0]['generated_text'].strip()

In [None]:
# Smoke test: run one hand-written multiple-choice question through multiple_answer.
system_prompt = (
    "You are a highly knowledgeable finance chatbot. "
    "Your purpose is to select answer choice from user query "
    "you can select only character that represent the answer follwing by A, B, C, D"
)
# The query is assembled line by line; the trailing spaces on the first and
# third lines are part of the prompt text.
text = "\n".join([
    "Answer the question with the appropriate options A, B, C and D. Please respond with the exact answer A, B, C or D only. Do not be verbose or provide extra information. ",
    "Question: According to COSO, which of the following is the most effective method to transmit a message of ethical behavior throughout an organization?",
    "Answer Choices: A: Demonstrating appropriate behavior by example., B: Strengthening internal audit’s ability to deter and report improper behavior., C: Removing pressures to meet unrealistic targets, particularly for short-term results., D: Specifying the competence levels for every job in an organization and translating those levels to requisite knowledge and skills. ",
    "Answer:",
])
multiple_answer(text=text, system_prompt=system_prompt)

# Functions with LangGraph

In [None]:
# System prompt handed to multiple_answer by the multiple-choice agent node.
PROMPT_MULTIPLE = (
    "You are a highly knowledgeable finance chatbot. "
    "Your purpose is to select answer choice from user query "
    "you can select only character that represent the answer follwing by A, B, C, D"
)

In [None]:
# Pydantic schema used to validate the routing label produced by the classifier.
class MessageClassifier(BaseModel):
    # Required field (`...` means no default); only the two routing labels are accepted.
    message_type: Literal["multiple", "prediction"] = Field(
        ..., description="Classify if the message is multiple or prediction"
    )

# Shared state passed between the graph nodes.
class State(TypedDict):
    # Message list; `add_messages` is attached through Annotated as its reducer.
    message: Annotated[list, add_messages]
    # Routing label written by the classifier node, or None.
    message_type: str | None
    # Branch key written by the router node and read by the conditional edge.
    next: str

def classify_message(state: State) -> State:
    """Classify the latest message and return its routing label.

    Args:
        state: Graph state; the `.content` of the last entry in `state["message"]`
            is classified.

    Returns:
        A partial state update of the form ``{"message_type": <label>}``.
    """
    latest_text = state["message"][-1].content
    predicted_label = classify_xlm(latest_text)
    # Pass the label through the pydantic model so only "multiple" or
    # "prediction" can be returned.
    checked = MessageClassifier(message_type=predicted_label)
    return {"message_type": checked.message_type}

def router(state: State) -> State:
    """Copy the classified label into the ``next`` key.

    Args:
        state: Graph state; ``message_type`` is read with ``.get`` so a missing
            key yields None.

    Returns:
        A partial state update of the form ``{"next": <message_type or None>}``.
    """
    return {"next": state.get("message_type")}

def prediction_agent(state: State) -> State:
    """Placeholder agent for prediction-type queries.

    The incoming state is not consulted: the reply is a fixed placeholder
    string, so the node also works when the message list is empty.

    Args:
        state: Graph state (unused).

    Returns:
        A partial state update ``{"message": <placeholder text>}``.
    """
    return {"message": "Hello User I'm a prediction_agent agent! JUST PLACE HOLDER"}

def multiple_agent(state: State) -> State:
    """Answer a multiple-choice query using `multiple_answer`.

    Args:
        state: Graph state; the `.content` of the last entry in `state["message"]`
            is used as the query.

    Returns:
        A partial state update ``{"message": <generated answer>}``.
    """
    latest_text = state["message"][-1].content
    answer = multiple_answer(text=latest_text, system_prompt=PROMPT_MULTIPLE)
    return {"message": answer}

# Build the graph: START -> classifier -> router -> (prediction_agent | multiple_agent) -> END
graph_builder = StateGraph(State)

# Register the nodes in a fixed order.
for node_name, node_fn in (
    ("classifier", classify_message),
    ("router", router),
    ("prediction_agent", prediction_agent),
    ("multiple_agent", multiple_agent),
):
    graph_builder.add_node(node_name, node_fn)

graph_builder.add_edge(START, "classifier")
graph_builder.add_edge("classifier", "router")

# Branch on the `next` key that the router node wrote into the state.
graph_builder.add_conditional_edges(
    "router",
    lambda state: state.get("next"),
    {"prediction": "prediction_agent", "multiple": "multiple_agent"},
)

# Both agents end the run.
for terminal_node in ("prediction_agent", "multiple_agent"):
    graph_builder.add_edge(terminal_node, END)

graph = graph_builder.compile()

In [None]:
# Randomly select one row from the DataFrame to use as input
import pandas as pd

# Read the test set, draw one row (no seed is set, so the pick differs between
# runs) and send its "query" column through the graph.
test_csv_path = "/home/siamai/data/Focus/agentic/data/test.csv"
df = pd.read_csv(test_csv_path)
sampled_row = df.sample(n=1).iloc[0]
user_input = sampled_row["query"]
print(f"User input: {user_input}")
state = graph.invoke({"message": [user_input]})
# The last expression is the cell output: the content of the final message.
state["message"][-1].content

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG and show it inline.
graph_png = graph.get_graph().draw_mermaid_png(max_retries=10)
display(Image(graph_png))