Task 4

In [None]:
import json
import ast
import math
import operator as op
import traceback
import inspect
from typing import Any

from langchain_openai import ChatOpenAI
from langchain.tools import tool
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage

_ALLOWED_OPERATORS = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Pow: op.pow,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
    ast.Mod: op.mod,
    ast.FloorDiv: op.floordiv,
}

# Every public member of `math` (functions and constants) is reachable by name.
_ALLOWED_NAMES = {k: getattr(math, k) for k in dir(math) if not k.startswith("_")}
_ALLOWED_NAMES.update({"pi": math.pi, "e": math.e})

# Upper bound on the estimated size (in bits) of an exact integer power a ** b.
# Without it an expression such as 9**9**9 keeps the interpreter busy for a very
# long time and can exhaust memory, which matters because the expressions come
# from model output. NOTE: math functions such as factorial/comb can still be
# slow for very large arguments.
_MAX_POW_RESULT_BITS = 1_000_000


def _is_number(value) -> bool:
    """Return True for int/float values; bool is excluded even though it subclasses int."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _check_int_pow_size(base, exponent) -> None:
    """Raise ValueError if the exact integer power base ** exponent would be huge."""
    if isinstance(base, int) and isinstance(exponent, int) and exponent > 0 and abs(base) > 1:
        # bit_length(|base|) * exponent is an upper bound on the result's bit length.
        if abs(base).bit_length() * exponent > _MAX_POW_RESULT_BITS:
            raise ValueError("Exponent too large")


def safe_eval_expr(expr: str):
    """
    Safely evaluate simple numeric math expressions using ast.
    Supports numeric literals, + - * / ** unary +/-, mod, floor div, math.* functions.

    Accepted constructs:
      - int/float literals (bool literals are rejected);
      - the binary/unary operators listed in ``_ALLOWED_OPERATORS``;
      - calls ``name(arg, ...)`` where ``name`` is a callable public member of
        ``math``; keyword arguments are rejected rather than silently dropped;
      - bare names that resolve to numeric ``math`` constants
        (``pi``, ``e``, ``tau``, ``inf``, ``nan``).

    Returns:
        The computed value (normally an int or float).

    Raises:
        SyntaxError: if ``expr`` is not a valid Python expression.
        ValueError: if ``expr`` uses anything outside the whitelist above, or an
            integer power whose result would exceed ``_MAX_POW_RESULT_BITS``.
        Arithmetic errors raised during evaluation (ZeroDivisionError,
        OverflowError, math-domain ValueError, ...) propagate unchanged.
    """
    def _eval(node):
        if isinstance(node, ast.Constant):
            if _is_number(node.value):
                return node.value
            raise ValueError("Unsupported constant type")

        if isinstance(node, ast.BinOp):
            op_type = type(node.op)
            if op_type not in _ALLOWED_OPERATORS:
                raise ValueError(f"Operator {op_type} not allowed")
            left = _eval(node.left)
            right = _eval(node.right)
            if op_type is ast.Pow:
                _check_int_pow_size(left, right)
            return _ALLOWED_OPERATORS[op_type](left, right)

        if isinstance(node, ast.UnaryOp):
            op_type = type(node.op)
            if op_type not in _ALLOWED_OPERATORS:
                raise ValueError(f"Unary operator {op_type} not allowed")
            return _ALLOWED_OPERATORS[op_type](_eval(node.operand))

        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise ValueError("Only simple function calls allowed")
            fname = node.func.id
            if fname not in _ALLOWED_NAMES:
                raise ValueError(f"Function {fname} not allowed")
            f = _ALLOWED_NAMES[fname]
            # Constants such as `pi` are in the whitelist too, but are not callable.
            if not callable(f):
                raise ValueError(f"{fname} is not a function")
            # Ignoring keywords would evaluate a different call than was written.
            if node.keywords:
                raise ValueError("Keyword arguments are not supported")
            args = [_eval(a) for a in node.args]
            return f(*args)

        if isinstance(node, ast.Name):
            name = node.id
            if name not in _ALLOWED_NAMES:
                raise ValueError(f"Name {name} not allowed")
            value = _ALLOWED_NAMES[name]
            # Bare function names (e.g. `sin`) are not numeric values.
            if not _is_number(value):
                raise ValueError(f"Name {name} is not a numeric constant")
            return value

        raise ValueError(f"Unsupported AST node: {type(node)}")

    parsed = ast.parse(expr, mode="eval")
    return _eval(parsed.body)


@tool
def get_weather(location: str) -> str:
    """Simulated weather lookup (quick demo)."""
    # The docstring above is kept verbatim: LangChain's @tool derives the tool
    # description shown to the model from it. The reports below are canned demo
    # data; lookup is an exact match on the city name.
    canned_reports = {
        "San Francisco": "Sunny, 72°F",
        "New York": "Cloudy, 55°F",
        "London": "Rainy, 48°F",
        "Tokyo": "Clear, 65°F",
    }
    if location in canned_reports:
        return canned_reports[location]
    return f"Weather data not available for {location}"

def _as_calculator_payload(decoded: Any, source: str) -> Any:
    """Map a JSON-decoded calculator input onto the payload shape `calculator` uses.

    - a JSON string (e.g. '"2+2"') becomes {"expr": <that string>};
    - a bare JSON number (e.g. '42' or '-5') becomes {"expr": source}, so plain
      numeric input is evaluated instead of being rejected;
    - anything else (dicts, lists, null, booleans) is returned unchanged.
    """
    if isinstance(decoded, str):
        return {"expr": decoded}
    if isinstance(decoded, (int, float)) and not isinstance(decoded, bool):
        return {"expr": source}
    return decoded


# operation name -> binary function applied to float(data["a"]), float(data["b"])
_CALC_BINARY_OPS = {
    "add": op.add,
    "sum": op.add,
    "sub": op.sub,
    "subtract": op.sub,
    "minus": op.sub,
    "mul": op.mul,
    "multiply": op.mul,
    "div": op.truediv,
    "divide": op.truediv,
    "pow": op.pow,
    "power": op.pow,
}

# operation name -> unary math function applied to the evaluated data["x"]
_CALC_UNARY_FUNCS = {
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "sqrt": math.sqrt,
    "log": math.log,
    "ln": math.log,
    "exp": math.exp,
}


@tool
def calculator(raw: str) -> str:
    """
    Calculator tool expecting a JSON string (raw).
    Accepts either:
      - '{"expr":"2+2"}'
      - '{"operation":"area_circle","radius":3}'
      - '{"operation":"sin","x":"pi/2"}'
    Returns JSON string.
    """
    # The docstring above is kept verbatim because LangChain's @tool uses it as
    # the tool description sent to the model; maintenance notes live here.
    # Every failure is returned as a JSON error object rather than raised, so the
    # agent loop always gets a string to hand back to the model.
    try:
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            # Not JSON at all (e.g. "2+2"): treat the text as an expression.
            data = {"expr": raw}
        else:
            data = _as_calculator_payload(decoded, raw)

        # The argument may arrive double-wrapped as {"raw": "<json>"}; unwrap once.
        if isinstance(data, dict) and isinstance(data.get("raw"), str):
            inner = data["raw"]
            try:
                decoded = json.loads(inner)
            except ValueError:
                data = {"expr": inner}
            else:
                data = _as_calculator_payload(decoded, inner)

        if not isinstance(data, dict):
            return json.dumps({"error": "unexpected_payload", "payload": str(data)})

        if "expr" in data:
            expr = str(data["expr"])
            val = safe_eval_expr(expr)
            return json.dumps({"result": val, "expression": expr})

        op_name = str(data.get("operation", "")).lower()

        if op_name in _CALC_BINARY_OPS:
            a = float(data["a"])
            b = float(data["b"])
            func = _CALC_BINARY_OPS[op_name]
            if func is op.truediv and b == 0:
                return json.dumps({"error": "division_by_zero"})
            return json.dumps({"result": func(a, b)})

        if op_name in _CALC_UNARY_FUNCS:
            x_expr = data.get("x")
            if x_expr is None:
                return json.dumps({"error": "missing_argument", "message": "x required"})
            # x may itself be an expression such as "pi/2".
            x_val = safe_eval_expr(str(x_expr))
            return json.dumps({"result": _CALC_UNARY_FUNCS[op_name](x_val)})

        if op_name == "area_circle":
            r = float(data["radius"])
            return json.dumps({"result": math.pi * r * r})
        if op_name == "circumference":
            r = float(data["radius"])
            return json.dumps({"result": 2 * math.pi * r})
        if op_name == "area_rectangle":
            w = float(data["width"])
            h = float(data["height"])
            return json.dumps({"result": w * h})
        if op_name == "perimeter_rectangle":
            w = float(data["width"])
            h = float(data["height"])
            return json.dumps({"result": 2 * (w + h)})
        if op_name == "area_triangle":
            b = float(data["base"])
            h = float(data["height"])
            return json.dumps({"result": 0.5 * b * h})

        return json.dumps({"error": "unknown_operation", "input": data})
    except Exception as e:
        # Missing keys, bad numbers, disallowed expressions, math domain errors...
        return json.dumps({"error": "execution_error", "message": str(e)})

@tool
def count_letter(text: str, letter: str) -> str:
    """Count occurrences of letter in text (case-insensitive). Returns simple integer string."""
    # Docstring kept verbatim: LangChain's @tool uses it as the tool description.
    # `letter` may be longer than one character; str.count then counts
    # non-overlapping occurrences of that substring. Non-string or empty input
    # yields "0".
    inputs_ok = isinstance(text, str) and isinstance(letter, str) and letter != ""
    if not inputs_ok:
        return "0"
    occurrences = text.lower().count(letter.lower())
    return str(occurrences)

@tool
def text_stats(text: str) -> str:
    """
    A third tool: returns simple statistics as JSON string:
      - length (chars)
      - words_count
      - vowels_count
      - unique_words (count)
    """
    # Docstring kept verbatim: LangChain's @tool uses it as the tool description.
    # Words are whitespace-separated tokens. vowels_count counts a/e/i/o/u
    # case-insensitively. For unique_words each token is lower-cased and
    # stripped of surrounding punctuation; tokens made only of that punctuation
    # (e.g. "," or "...") reduce to "" and are not counted as a word.
    s = text or ""
    words = s.split()  # split() already ignores leading/trailing whitespace
    vowels = sum(1 for ch in s.lower() if ch in "aeiou")
    normalized = {w.strip(".,!?;:\"'()[]{}").lower() for w in words}
    normalized.discard("")
    out = {
        "length": len(s),
        "words_count": len(words),
        "vowels_count": vowels,
        "unique_words": len(normalized),
    }
    return json.dumps(out)


def _try_parse_possible_json(x: Any):
    """Decode *x* as JSON when it is a string shaped like a JSON object or array.

    Surrounding whitespace is ignored for both the shape check and the decode.
    Returns the decoded value on success. Otherwise returns *x* unchanged: this
    covers non-strings, strings not wrapped in {...} or [...], and strings that
    fail to decode.
    """
    if not isinstance(x, str):
        return x
    body = x.strip()
    bracket_pairs = (("{", "}"), ("[", "]"))
    if (body[:1], body[-1:]) not in bracket_pairs:
        return x
    try:
        return json.loads(body)
    except Exception:  # best-effort: any decode failure means "not JSON"
        return x

def normalize_args_obj(function_args: Any, tool_name: "str | None" = None):
    """
    Convert LLM-provided function_args into a plain dict or primitive as expected by tools.

    Handles dicts, JSON strings, and JSON-encoded strings inside dict fields.
    A JSON decoding is only accepted when it yields a dict, list or str. Scalar
    decodings are discarded, so plain-text arguments such as "1111", "true" or
    "null" reach the tool as the strings they are rather than as int/bool/None.

    For ``tool_name == "calculator"`` a string ``raw`` field is passed through
    untouched, because that tool takes the JSON text itself as its ``raw``
    parameter and must not have it unwrapped prematurely.

    After decoding, if ``text`` is a dict containing ``letter`` and there is no
    top-level ``letter``, the inner dict is merged into the result (outer keys
    win). This repairs count_letter-style arguments packed into ``text``.

    Values that are neither str nor dict are returned unchanged.
    """
    def _decode(value: str) -> Any:
        # Keep the decoding only if it stays structured or string-typed.
        try:
            decoded = json.loads(value)
        except json.JSONDecodeError:
            return value
        return decoded if isinstance(decoded, (dict, list, str)) else value

    if isinstance(function_args, str):
        return _decode(function_args)

    if isinstance(function_args, dict):
        if tool_name == "calculator" and isinstance(function_args.get("raw"), str):
            return function_args
        parsed = {
            k: (_decode(v) if isinstance(v, str) else v)
            for k, v in function_args.items()
        }
        inner = parsed.get("text")
        if isinstance(inner, dict) and "letter" in inner and "letter" not in parsed:
            parsed.pop("text")
            return {**inner, **parsed}
        return parsed

    return function_args

def invoke_tool(tool_obj, function_args, tool_name=None):
    """
    Call ``tool_obj`` with LLM-supplied arguments, trying several calling conventions.

    The arguments are first normalised with ``normalize_args_obj``. Strategies
    are then tried in this order, stopping at the first one that returns:

    1. ``tool_obj.run(args)``, retried as ``tool_obj.run(json.dumps(args))`` on TypeError;
    2. ``tool_obj.invoke(args)``, with the same JSON-string retry;
    3. a direct call (``tool_obj(**args)`` for dicts, ``tool_obj(args)`` otherwise),
       retried with the JSON string on TypeError;
    4. a direct call passing only the dict keys named in ``inspect.signature(tool_obj)``.

    Returns the first successful result. If all strategies fail, returns a string
    starting with "[invoke_tool_failed]" that lists the traceback of every failed
    attempt. A strategy that raises does not stop later strategies, so a tool may
    be executed more than once.
    """
    args_obj = normalize_args_obj(function_args, tool_name=tool_name)
    tracebacks = []

    def _log_failure():
        # Must run inside an `except` block: records the exception being handled.
        tracebacks.append(traceback.format_exc())

    def _try_with_json_retry(first, retry):
        # Run first(); on TypeError run retry(). Returns (succeeded, value).
        try:
            return True, first()
        except TypeError:
            try:
                return True, retry()
            except Exception:
                _log_failure()
        except Exception:
            _log_failure()
        return False, None

    # Strategies 1 and 2: the tool's `run` / `invoke` methods, if present.
    for method_name in ("run", "invoke"):
        try:
            if hasattr(tool_obj, method_name):
                succeeded, value = _try_with_json_retry(
                    lambda: getattr(tool_obj, method_name)(args_obj),
                    lambda: getattr(tool_obj, method_name)(json.dumps(args_obj)),
                )
                if succeeded:
                    return value
        except Exception:
            _log_failure()

    # Strategy 3: call the object directly.
    try:
        if callable(tool_obj):
            def _call_directly():
                if isinstance(args_obj, dict):
                    return tool_obj(**args_obj)
                return tool_obj(args_obj)

            succeeded, value = _try_with_json_retry(
                _call_directly, lambda: tool_obj(json.dumps(args_obj))
            )
            if succeeded:
                return value
    except Exception:
        _log_failure()

    # Strategy 4: direct call with only the keyword arguments the signature accepts.
    try:
        try:
            signature = inspect.signature(tool_obj)
        except Exception:
            signature = None
        if signature is not None and isinstance(args_obj, dict):
            accepted = {key: val for key, val in args_obj.items() if key in signature.parameters}
            try:
                return tool_obj(**accepted)
            except Exception:
                _log_failure()
    except Exception:
        _log_failure()

    details = tracebacks or ["no details"]
    return "[invoke_tool_failed] Could not call tool. Attempts:\n" + "\n\n".join(details)

# Chat model plus the tool registry used by the agent loop.
# NOTE(review): constructing ChatOpenAI presumably requires OpenAI credentials
# in the environment — confirm for the target setup.
llm = ChatOpenAI(model="gpt-4o-mini")
TOOLS = [get_weather, calculator, count_letter, text_stats]
# Map tool name -> tool object so tool calls can be dispatched by name.
# `.name` is presumably what the @tool-wrapped objects expose; the __name__ /
# __qualname__ fallbacks cover plain callables. Note that `t` and `name` remain
# defined as module globals after this loop.
tool_map = {}
for t in TOOLS:
    name = getattr(t, "name", None) or getattr(t, "__name__", None) or getattr(t, "__qualname__", None)
    tool_map[name] = t
# Attach the tool schemas so the model can emit structured tool calls.
llm_with_tools = llm.bind_tools(TOOLS)


# System prompt sent as the first message of every agent run. It names the
# available tools and tells the model which tool to use for which task and how
# to format the calculator / count_letter arguments.
SYSTEM_PROMPT = (
    "You are a helpful assistant with tools such as: get_weather(location), calculator(raw), "
    "count_letter(text,letter), and text_stats(text).\n\n"
    "CRITICAL: For any numeric computation, call the 'calculator' tool. For any letter counting, call 'count_letter'.\n"
    "When calling calculator, pass a JSON string (e.g. '{\"expr\":\"2+2\"}' or '{\"operation\":\"area_circle\",\"radius\":3}').\n"
    "When calling count_letter, pass two fields: text and letter.\n"
)


def run_agent(user_query: str, max_iters: int = 5):
    """
    Run a tool-calling loop for ``user_query``, printing a trace of every step.

    Each iteration sends the conversation to ``llm_with_tools``. If the reply
    contains tool calls, each call is dispatched through ``invoke_tool``; a name
    missing from ``tool_map`` gets an "Unknown function" error string instead.
    Each result is appended as a ``ToolMessage``. A reply without tool calls
    ends the loop.

    Returns the final reply's ``content``, or the string "Max iterations reached"
    after ``max_iters`` rounds of tool calls.
    """
    def _call_field(call, key):
        # Tool calls may arrive as dicts or as attribute-style objects.
        if isinstance(call, dict):
            return call.get(key)
        return getattr(call, key, None)

    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=user_query),
    ]

    print("======================================================================")
    print("Test for:", user_query)
    print("User:", user_query)
    print()

    for step in range(1, max_iters + 1):
        print(f"--- Iteration {step} ---")
        reply = llm_with_tools.invoke(messages)
        pending_calls = getattr(reply, "tool_calls", None)

        if not pending_calls:
            final_text = getattr(reply, "content", None)
            print(f"Assistant: {final_text}\n")
            return final_text

        print(f"LLM wants to call {len(pending_calls)} tool(s)")
        # The assistant turn that requested the tools precedes their results.
        messages.append(reply)

        for call in pending_calls:
            fn_name = _call_field(call, "name")
            fn_args = _call_field(call, "args")
            print(f"  Tool: {fn_name}")
            print(f"  Args: {fn_args}")

            target = tool_map.get(fn_name)
            outcome = (
                f"Error: Unknown function {fn_name}"
                if target is None
                else invoke_tool(target, fn_args, tool_name=fn_name)
            )
            print(f"  Result: {outcome}")

            messages.append(
                ToolMessage(content=outcome, tool_call_id=_call_field(call, "id"))
            )
        print()

    print("Max iterations reached\n")
    return "Max iterations reached"


if __name__ == "__main__":
    # Sample prompts pushed through the agent loop. Each run_agent call prints
    # its own trace; the extra print separates consecutive runs. In a Jupyter
    # kernel __name__ is "__main__", so this runs whenever the cell executes.
    tests = [
        "What's the weather like in San Francisco?",
        "Say hello!",
        "How many s are in Mississippi riverboats?",
        "Are there more i's than s's in Mississippi riverboats?",
        "What is the sin of the difference between the number of i's and the number of s's in Mississippi riverboats?",
        "Compute the area of a circle with radius 3",
    ]

    for q in tests:
        run_agent(q)
        print("\n")

Test for: What's the weather like in San Francisco?
User: What's the weather like in San Francisco?

--- Iteration 1 ---
LLM wants to call 1 tool(s)
  Tool: get_weather
  Args: {'location': 'San Francisco'}
  Result: Sunny, 72°F

--- Iteration 2 ---
Assistant: The weather in San Francisco is sunny with a temperature of 72°F.



Test for: Say hello!
User: Say hello!

--- Iteration 1 ---
Assistant: Hello! How can I assist you today?



Test for: How many s are in Mississippi riverboats?
User: How many s are in Mississippi riverboats?

--- Iteration 1 ---
LLM wants to call 1 tool(s)
  Tool: count_letter
  Args: {'text': 'Mississippi riverboats', 'letter': 's'}
  Result: 5

--- Iteration 2 ---
Assistant: There are 5 occurrences of the letter "s" in "Mississippi riverboats."



Test for: Are there more i's than s's in Mississippi riverboats?
User: Are there more i's than s's in Mississippi riverboats?

--- Iteration 1 ---
LLM wants to call 2 tool(s)
  Tool: count_letter
  Args: {'text': 'Mi