In [None]:
# --- Imports ---
from typing import Optional, TypedDict

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# --- Model & Processor ---
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch

# LLaVA-1.5 7B checkpoint in Hugging Face format, plus its matching processor
# (tokenizer + image preprocessor).
model_id = "unsloth/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)

# Load the weights in full float32 so inference runs on a CPU-only machine.
# float32 costs 4 bytes per parameter, so a 7B model needs roughly 28 GB of RAM.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    dtype=torch.float32,
).to("cpu")




# --- State ---
class State(TypedDict, total=False):
    """Record threaded through the customer-support workflow.

    ``total=False`` makes every key optional: each step returns only the keys
    it produces, and the orchestrator merges them in with ``dict.update``.
    """

    query: str      # the customer's question text
    image: str      # optional image URL; fetched over HTTP with requests.get (base64 is not handled)
    category: str   # label written by categorize()
    sentiment: str  # label written by analyze_sentiment()
    response: str   # reply written by a topic handler or by escalate()

# --- Simple text-only functions ---
def categorize(state: State) -> State:
    """Classify the customer query into one of the four KSG support categories.

    Reads ``state["query"]`` and returns ``{"category": label}``, where label is
    exactly one of ``"Admissions"``, ``"Training"``, ``"Certificates"`` or
    ``"General"``. The model's free-text reply is mapped to the label whose
    keyword appears earliest in it. ``"General"`` is the fallback when no
    keyword appears.
    """
    # LLaVA-1.5 is trained on the "USER: ... ASSISTANT:" chat format.
    prompt = (
        "USER: Categorize the following Kenya School of Government (KSG) customer query "
        "into one of these categories: Admissions, Training, Certificates, General. "
        "Answer with the category name only. "
        f"Query: {state['query']} ASSISTANT:"
    )
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=50)

    # generate() returns prompt + continuation. Decoding the whole sequence would
    # copy the prompt's own category list into the result, so every query would
    # look like "Admissions" to the router. Decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = processor.decode(output[0][prompt_len:], skip_special_tokens=True).lower()

    # Keyword stem -> canonical label. The stems also match singular/plural replies.
    keywords = {
        "admission": "Admissions",
        "training": "Training",
        "certificate": "Certificates",
        "general": "General",
    }
    hits = [(reply.find(stem), label) for stem, label in keywords.items() if stem in reply]
    return {"category": min(hits)[1] if hits else "General"}

def analyze_sentiment(state: State) -> State:
    """Score the sentiment of the customer query.

    Reads ``state["query"]`` and returns ``{"sentiment": label}``, where label is
    exactly one of ``"Positive"``, ``"Neutral"`` or ``"Negative"``. The model's
    reply is mapped to the label mentioned earliest in it. ``"Neutral"`` is the
    fallback when no label is mentioned.
    """
    # LLaVA-1.5 is trained on the "USER: ... ASSISTANT:" chat format.
    prompt = (
        "USER: Analyze the sentiment of the following KSG customer query. "
        "Respond with either 'Positive', 'Neutral', or 'Negative'. "
        f"Query: {state['query']} ASSISTANT:"
    )
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=50)

    # generate() echoes the prompt. Decoding it would yield the whole instruction
    # text, which never equals "negative", so escalation could never fire.
    # Decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = processor.decode(output[0][prompt_len:], skip_special_tokens=True).lower()

    labels = ("Positive", "Neutral", "Negative")
    hits = [(reply.find(lab.lower()), lab) for lab in labels if lab.lower() in reply]
    return {"sentiment": min(hits)[1] if hits else "Neutral"}

# --- Multimodal handler ---
def multimodal_handler(state: State, instruction: str) -> State:
    """Answer the query with the LLaVA model, optionally grounded on an image.

    Args:
        state: workflow state. Reads ``query`` and, if truthy, ``image`` (an
            http(s) URL downloaded with ``requests``).
        instruction: task-specific instruction placed before the query.

    Returns:
        ``{"response": reply}`` containing only the model's generated answer.

    If the image cannot be downloaded or decoded, a warning is emitted and the
    query is answered from the text alone instead of aborting the workflow.
    """
    import warnings
    from io import BytesIO

    image = None
    url = state.get("image")
    if url:
        try:
            # Bounded timeout so an unresponsive host cannot hang the workflow.
            with requests.get(url, timeout=30) as resp:
                resp.raise_for_status()
                # The image processor expects 3-channel input; PNG/GIF may be RGBA/palette.
                image = Image.open(BytesIO(resp.content)).convert("RGB")
        except (requests.RequestException, OSError) as err:
            # OSError covers PIL.UnidentifiedImageError (e.g. an HTML error page).
            warnings.warn(
                f"Could not load image {url!r} ({err}); answering from text only.",
                stacklevel=2,
            )
            image = None

    # LLaVA-1.5 chat format. The <image> placeholder marks where the processor
    # splices in the image features; it must appear exactly when an image is passed.
    image_tag = "<image>\n" if image is not None else ""
    prompt = (
        f"USER: {image_tag}You are a Kenya School of Government support assistant. "
        f"{instruction} Query: {state['query']} ASSISTANT:"
    )
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=200)

    # generate() echoes the prompt tokens; keep only the newly generated reply.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = processor.decode(output[0][prompt_len:], skip_special_tokens=True)
    return {"response": answer.strip()}

# --- Handlers ---
def handle_admissions(state: State) -> State:
    """Answer an admissions/application query through the shared multimodal handler."""
    instruction = "Provide admissions/application support."
    return multimodal_handler(state, instruction)

def handle_training(state: State) -> State:
    """Answer a training/programme query through the shared multimodal handler."""
    instruction = "Provide training/program information."
    return multimodal_handler(state, instruction)

def handle_certificates(state: State) -> State:
    """Answer a certificate/verification query through the shared multimodal handler."""
    instruction = "Provide certificates/verification support."
    return multimodal_handler(state, instruction)

def handle_general(state: State) -> State:
    """Answer any other query through the shared multimodal handler."""
    instruction = "Provide general support."
    return multimodal_handler(state, instruction)

def escalate(state: State) -> State:
    """Return a fixed hand-off message; the incoming state is not inspected."""
    notice = "This query has been escalated to a human KSG agent due to negative sentiment."
    return {"response": notice}

# --- Simple router ---
def route_query(state: State) -> str:
    """Choose the handler route for a categorized, sentiment-scored state.

    Negative sentiment always wins and yields ``"escalate"``. Otherwise the
    category text is searched for keywords in priority order (admissions,
    then training, then certificates), and ``"general"`` is the default.
    """
    mood = state.get("sentiment", "").strip().lower()
    if mood == "negative":
        return "escalate"

    label = state.get("category", "").strip().lower()
    # (route, keywords) pairs, checked top to bottom; the first match wins.
    rules = (
        ("admissions", ("admission",)),
        ("training", ("training", "program")),
        ("certificates", ("certificate", "verification")),
    )
    return next(
        (route for route, words in rules if any(word in label for word in words)),
        "general",
    )

# --- Orchestrator ---
def run_customer_support(query: str, image: Optional[str] = None) -> dict:
    """Run one query through the workflow: categorize, score sentiment, route, handle.

    Args:
        query: the customer's question.
        image: optional http(s) URL of an attached image, forwarded to the
            topic handlers via the state.

    Returns:
        dict with ``category``, ``sentiment`` and ``response`` keys.
    """
    state: State = {"query": query}
    if image:
        # Only store a real URL; State declares ``image`` as str, not None.
        state["image"] = image

    # step 1: categorize; step 2: sentiment (each returns a partial state).
    state.update(categorize(state))
    state.update(analyze_sentiment(state))

    # step 3: route. Any route not in the table (route_query returns
    # "escalate" for negative sentiment) is handed to escalate().
    handlers = {
        "admissions": handle_admissions,
        "training": handle_training,
        "certificates": handle_certificates,
        "general": handle_general,
    }
    route = route_query(state)
    state.update(handlers.get(route, escalate)(state))

    return {
        "category": state.get("category"),
        "sentiment": state.get("sentiment"),
        "response": state.get("response"),
    }

# --- Example usage ---
# NOTE: example.com is a reserved placeholder domain; point `image` at a real,
# publicly reachable image URL to exercise the multimodal path.
output = run_customer_support(
    "How do I apply for the next leadership training at KSG?",
    image="https://example.com/sample_certificate.jpg",
)
for caption, key in (
    ("Category:", "category"),
    ("Sentiment:", "sentiment"),
    ("Response:", "response"),
):
    print(caption, output[key])


model.safetensors.index.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]