In [None]:
!pip install -q langgraph langchain huggingface_hub

import os, re, ast
from typing import TypedDict
from huggingface_hub import InferenceClient
from langgraph.graph import StateGraph, END

# Hugging Face API token: read from the HF_TOKEN environment variable; a
# placeholder string is used when the variable is unset or empty.
HF_TOKEN = os.getenv("HF_TOKEN") or "YOUR_HF_TOKEN_HERE"
# Hub identifier of the model handed to the inference client below.
MODEL_ID = "TeichAI/Qwen3-1.7B-Gemini-2.5-Flash-Lite-Preview-Distill"
# Shared inference client; llm_node sends its text-generation requests through it.
client = InferenceClient(MODEL_ID, token=HF_TOKEN)

class GraphState(TypedDict):
    """State dictionary passed between the nodes of the routing graph."""

    # Raw user message supplied when the graph is invoked.
    input: str
    # Branch selected by router_node ("calculator" or "llm"). It has to be
    # declared here: the conditional edge reads state["route"], and only keys
    # declared in the schema are stored in the graph state.
    route: str
    # Final answer text written by calculator_node or llm_node.
    output: str

# AST node types an arithmetic expression may contain; safe_eval rejects
# every other node type.
ALLOWED_NODES = (
    ast.Expression, ast.BinOp, ast.UnaryOp, ast.Constant,
    ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow, ast.Mod,
    ast.FloorDiv, ast.UAdd, ast.USub, ast.Load, ast.Tuple,
    ast.List,
)

# Literal value types accepted as operands (bool is a subclass of int).
_NUMERIC_TYPES = (int, float, complex)

def safe_eval(expr: str):
    """Evaluate an arithmetic expression after whitelisting its syntax tree.

    Args:
        expr: Expression text such as ``"15 * (2 + 3)"``. Surrounding
            whitespace is ignored.

    Returns:
        The value of the expression.

    Raises:
        SyntaxError: If ``expr`` is not a valid Python expression.
        ValueError: If the expression contains a function call, a node type
            outside ``ALLOWED_NODES``, or a non-numeric literal.
        ArithmeticError: Propagated from the evaluation itself, for example
            ``ZeroDivisionError``.
    """
    # Strip first: leading whitespace would otherwise be reported as an
    # unexpected indent by the parser.
    tree = ast.parse(expr.strip(), mode="eval")
    for nd in ast.walk(tree):
        # Calls get their own message even though the whitelist below would
        # reject them as well.
        if isinstance(nd, ast.Call):
            raise ValueError("Function calls not allowed")
        if not isinstance(nd, ALLOWED_NODES):
            raise ValueError(f"Unsupported element: {type(nd).__name__}")
        # ast.Constant also covers str, bytes, None and Ellipsis literals;
        # only numbers belong in an arithmetic expression.
        if isinstance(nd, ast.Constant) and not isinstance(nd.value, _NUMERIC_TYPES):
            raise ValueError(f"Unsupported constant: {nd.value!r}")
    compiled = compile(tree, "<string>", "eval")
    # SECURITY NOTE: eval() is only reached after the whitelist above, with
    # no builtins and no names available. The size of the operands is not
    # bounded, so a huge exponent can still consume a lot of CPU and memory.
    return eval(compiled, {"__builtins__": {}}, {})

def router_node(state: "GraphState") -> dict:
    """Choose the branch that should handle the user's message.

    The message goes to the calculator when it contains an explicit
    ``<number> <operator> <number>`` pattern, or when it contains one of the
    trigger phrases ("calculate", "what is", "evaluate") together with at
    least one digit. The digit requirement keeps general questions such as
    "What is the capital of France?" on the LLM branch.

    Args:
        state: Graph state; only ``state["input"]`` is read.

    Returns:
        ``{"route": "calculator"}`` or ``{"route": "llm"}``.
    """
    text = state["input"].strip()
    has_arithmetic = re.search(r"\d+\s*[\+\-\*\/\^%]\s*\d+", text) is not None
    has_keyword = re.search(r"\b(calculate|what is|evaluate)\b", text, flags=re.I) is not None
    has_digit = re.search(r"\d", text) is not None
    if has_arithmetic or (has_keyword and has_digit):
        return {"route": "calculator"}
    return {"route": "llm"}

def calculator_node(state: "GraphState") -> dict:
    """Evaluate the arithmetic expression contained in the user's message.

    The message is first cleaned of the trigger phrases and evaluated as a
    whole. If that fails, every run of arithmetic characters that contains a
    digit is tried, longest first, so an expression embedded in a sentence
    ("please compute 3 + 4 for me") is still found.

    Args:
        state: Graph state; only ``state["input"]`` is read.

    Returns:
        ``{"output": ...}`` holding either ``"Result = <value>"`` or a
        message explaining why nothing could be evaluated.
    """
    # "^" is treated as exponentiation, matching the router's operator set.
    text = state["input"].strip().replace("^", "**")
    expr = re.sub(r"(calculate|what is|evaluate|please|=|\?)", "", text, flags=re.I).strip()
    try:
        return {"output": f"Result = {safe_eval(expr)}"}
    except Exception:
        # Deliberately broad: whatever went wrong (syntax, rejected node,
        # arithmetic error), fall back to scanning for an embedded expression.
        runs = re.findall(r"[\d.\s()+\-*/%]+", text)
    # Runs without a digit are just whitespace or punctuation between words.
    candidates = [run.strip() for run in runs if re.search(r"\d", run)]
    if not candidates:
        return {"output": "No valid math expression found."}
    first_error = None
    for candidate in sorted(candidates, key=len, reverse=True):
        try:
            return {"output": f"Result = {safe_eval(candidate)}"}
        except Exception as e:
            # Keep the error of the most complete candidate for the report.
            if first_error is None:
                first_error = e
    return {"output": f"Invalid or unsupported math expression. ({first_error})"}

def _extract_generated_text(resp):
    """Return the generated text found in ``resp``, or ``str(resp)`` as a last resort."""
    def from_item(item):
        if isinstance(item, dict):
            return item.get("generated_text") or item.get("text") or str(item)
        return str(item)

    if isinstance(resp, str):
        return resp
    if isinstance(resp, dict):
        if "generated_text" in resp:
            return resp["generated_text"]
        for key in ("outputs", "data"):
            items = resp.get(key)
            if isinstance(items, list) and items:
                return from_item(items[0])
        return str(resp)
    if isinstance(resp, list) and resp:
        return from_item(resp[0])
    return str(resp)

def llm_node(state: "GraphState") -> dict:
    """Answer the user's message with the hosted language model.

    Args:
        state: Graph state; only ``state["input"]`` is read.

    Returns:
        ``{"output": ...}`` holding the generated text, or an
        ``"LLM error: ..."`` message when the request raises.
    """
    prompt = state["input"].strip()
    try:
        # The prompt is passed as a plain string and generation options as
        # keyword arguments; wrapping both in a dict would send the dict
        # itself as the prompt.
        resp = client.text_generation(prompt, max_new_tokens=150)
    except Exception as e:
        # Broad on purpose: network, auth and model errors all become a
        # readable message instead of aborting the graph run.
        return {"output": f"LLM error: {repr(e)}"}
    generated = _extract_generated_text(resp)
    if not generated:
        # An empty or None extraction falls back to the raw response.
        generated = str(resp)
    return {"output": generated}

# Build the workflow: the router runs first, then exactly one of the two
# worker nodes, after which the run ends.
graph = StateGraph(GraphState)
graph.add_node("router", router_node)
graph.add_node("calculator", calculator_node)
graph.add_node("llm", llm_node)
graph.set_entry_point("router")
# The value of state["route"] is looked up in the mapping to pick the next node.
graph.add_conditional_edges("router", lambda state: state["route"], {"calculator": "calculator", "llm": "llm"})
# Both worker nodes are terminal.
graph.add_edge("calculator", END)
graph.add_edge("llm", END)
# Compiled graph object that the demo loop below invokes.
app = graph.compile()

# Sample prompts for a quick manual check. Each one matches router_node's
# math pattern, so this run exercises the calculator branch only.
tests = [
    "12 + 8",
    "calculate 15 * (2 + 3)",
    "What is 2^10?",
    "What is 7+10-19"
]

# Run every prompt through the compiled graph and print the final answer.
for t in tests:
    out = app.invoke({"input": t})
    print("INPUT:", t)
    # .get() prints None instead of raising when no node wrote "output".
    print("OUTPUT:", out.get("output"))
    print("-" * 50)


INPUT: 12 + 8
OUTPUT: Result = 20
--------------------------------------------------
INPUT: calculate 15 * (2 + 3)
OUTPUT: Result = 75
--------------------------------------------------
INPUT: What is 2^10?
OUTPUT: Result = 1024
--------------------------------------------------
INPUT: What is 7+10-19
OUTPUT: Result = -2
--------------------------------------------------
