TASK 5

In [28]:
# langgraph_single_model_with_history.py
#
# LangGraph agent with chat message history (Message API) and single-model routing.
# - Maintains messages: roles supported: system, human, ai, tool
# - If input begins with "Hey Qwen" -> Qwen is disabled, a notice is added and the input is routed to Llama
# - Otherwise -> Llama
# - Empty input loops back to input node (not sent to model)
# - verbose / quiet toggles supported
# - No Qwen loading attempted (disabled)

import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Tuple, List, Dict, Any

# -------------------------
# Device selection
# -------------------------
def get_device():
    if torch.cuda.is_available():
        print("Using CUDA (NVIDIA GPU)")
        return "cuda"
    elif torch.backends.mps.is_available():
        print("Using MPS (Apple Silicon)")
        return "mps"
    else:
        print("Using CPU")
        return "cpu"

# -------------------------
# Message / State types
# -------------------------
Message = Dict[str, str]  # {"role": "system"|"human"|"ai"|"tool", "content": "..."}

class AgentState(TypedDict):
    user_input: str
    should_exit: bool
    llama_response: str
    qwen_response: str   # kept for compatibility but not used since Qwen disabled
    verbose: bool
    messages: List[Message]

# -------------------------
# Build prompt from messages
# -------------------------
def build_prompt_from_messages(messages: List[Message]) -> str:
    """
    Convert messages (list of {role, content}) into a single string prompt for a text-generation model.
    Roles: system, human, ai, tool.

    The format is a simple readable transcript:
        [System] ...
        User: ...
        Tool: ...
        Assistant: ...
    Ends with 'Assistant:' as a cue for the model to respond.
    """
    out_lines: List[str] = []
    # Put system messages first (there may be multiple, we keep the order)
    for msg in messages:
        role = msg.get("role", "").lower()
        content = msg.get("content", "")
        if role == "system":
            out_lines.append(f"[System] {content}")
    # Then the other messages in order
    for msg in messages:
        role = msg.get("role", "").lower()
        content = msg.get("content", "")
        if role == "human" or role == "user":
            out_lines.append(f"User: {content}")
        elif role == "tool" or role == "function":
            out_lines.append(f"[Tool] {content}")
        elif role == "ai" or role == "assistant":
            out_lines.append(f"Assistant: {content}")
        # system already emitted
    # Add assistant cue for model to continue
    out_lines.append("Assistant:")
    return "\n".join(out_lines)

# -------------------------
# LLM loader with safe langchain fix
# -------------------------
def load_llm_wrapped(model_id: str, device: str):
    """
    Load a HF model and wrap in an adapter exposing .invoke(prompt) -> str.
    Raises RuntimeError on failure.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.float16 if device != "cpu" else torch.float32,
            device_map=device if device == "cuda" else None,
        )
        if device == "mps":
            model = model.to(device)

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Defensive langchain attributes to avoid AttributeError in some installs
        try:
            import langchain
            if not hasattr(langchain, "verbose"):
                langchain.verbose = False
            if not hasattr(langchain, "debug"):
                langchain.debug = False
            if not hasattr(langchain, "llm_cache"):
                langchain.llm_cache = None
        except Exception:
            pass

        wrapped = HuggingFacePipeline(pipeline=pipe)

        class Adapter:
            def __init__(self, llm):
                self.llm = llm
            def invoke(self, prompt: str) -> str:
                out = self.llm.invoke(prompt)
                # If out is a string, return it. Otherwise try to extract generated_text
                if isinstance(out, str):
                    return out
                try:
                    if isinstance(out, list) and len(out) > 0 and isinstance(out[0], dict):
                        if "generated_text" in out[0]:
                            return out[0]["generated_text"]
                    if isinstance(out, dict) and "generated_text" in out:
                        return out["generated_text"]
                except Exception:
                    pass
                return str(out)

        return Adapter(wrapped)

    except Exception as e:
        raise RuntimeError(f"Failed to load {model_id}: {e}") from e

# -------------------------
# Create models (only Llama is loaded; Qwen disabled)
# -------------------------
def create_models() -> Tuple[object, None]:
    device = get_device()

    # Attempt to load Llama
    llama_id = "meta-llama/Llama-3.2-1B-Instruct"
    llama_llm = None
    try:
        print(f"Loading Llama: {llama_id} ...")
        llama_llm = load_llm_wrapped(llama_id, device)
        print("Llama loaded.")
    except Exception as e:
        print(f"Could not load Llama ({llama_id}): {e}")
        # try small fallback (gpt2) so the agent can still run without large model
        fallback_id = "gpt2"
        try:
            print(f"Loading fallback model: {fallback_id}")
            llama_llm = load_llm_wrapped(fallback_id, device)
            print("Fallback loaded and used as Llama.")
        except Exception as e2:
            print(f"Failed to load fallback {fallback_id}: {e2}")
            class FinalFallback:
                def invoke(self, prompt: str) -> str:
                    return "[LLAMA-FALLBACK] No model available. Install transformers and a small model like 'gpt2'."
            llama_llm = FinalFallback()

    # Qwen is disabled: return None in place of qwen_llm
    qwen_llm = None

    return llama_llm, qwen_llm

# -------------------------
# Graph & nodes
# -------------------------
def create_graph(llama_llm, qwen_llm):
    """
    Single-model routing with message history.
    Qwen is disabled; 'Hey Qwen' will be treated as normal input and a tool message will be stored.
    """

    def get_user_input(state: AgentState) -> dict:
        # On new input, clear previous model responses (avoid stale output)
        if state.get("verbose", False):
            print("[TRACE] Entering get_user_input")
        print("\n" + "=" * 50)
        print("Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.")
        print("=" * 50)
        print("> ", end="")
        user_input = input()

        lowered = user_input.strip().lower()
        if lowered == "verbose":
            # clear outputs and leave messages intact (we keep the chat history)
            return {"user_input": user_input, "should_exit": False, "verbose": True, "llama_response": "", "qwen_response": ""}
        if lowered == "quiet":
            return {"user_input": user_input, "should_exit": False, "verbose": False, "llama_response": "", "qwen_response": ""}

        if lowered in ("quit", "exit", "q"):
            # clear last outputs before exiting
            return {"user_input": user_input, "should_exit": True, "llama_response": "", "qwen_response": ""}

        # Normal input: append a human message to the message history and clear model outputs
        # We'll modify the state's messages in the call to route (LangGraph merges returned dict)
        return {"user_input": user_input, "should_exit": False, "llama_response": "", "qwen_response": ""}

    def route_after_input(state: AgentState) -> str:
        """
        Three-way:
         - END if should_exit
         - get_user_input if input empty
         - call_llama otherwise (Qwen disabled)
        If input begins with 'Hey Qwen', we insert a tool notice into messages and route to Llama.
        """
        if state.get("verbose", False):
            print("[TRACE] route_after_input:", {"should_exit": state.get("should_exit"), "user_input": repr(state.get("user_input"))})

        if state.get("should_exit", False):
            return END

        raw = str(state.get("user_input", ""))
        if raw.strip() == "":
            if state.get("verbose", False):
                print("[TRACE] empty input -> looping back")
            print("[NOTICE] Empty input received â€” please type something.")
            return "get_user_input"

        # If user tried to address Qwen, record a tool message noting Qwen disabled and route to Llama
        stripped_leading = raw.lstrip()
        if stripped_leading.lower().startswith("hey qwen"):
            # add a tool message to explain Qwen is disabled
            tool_msg = {"role": "tool", "content": "Qwen is disabled in this agent. Input will be handled by Llama instead."}
            # Merge messages into state: append tool message + append human message (we will append human message below)
            # Note: return mapping from route function only determines next node; we need to append messages now.
            # LangGraph will merge dict returned by the preceding node (get_user_input) â€” but route function cannot modify state.
            # So we rely on get_user_input having appended the human message earlier; to ensure the tool message is recorded,
            # we'll have the call_llama node detect the 'Hey Qwen' prefix and append the tool message before calling Llama.
            if state.get("verbose", False):
                print("[TRACE] Detected 'Hey Qwen' prefix; will route to Llama with tool notice")
            return "call_llama"

        # Default: route to Llama
        return "call_llama"

    def call_llama(state: AgentState) -> dict:
        """
        This node:
          - ensures the latest human message is appended to messages
          - if the input began with 'Hey Qwen' it appends a tool message explaining Qwen is disabled
          - builds prompt from messages and calls the Llama model
          - appends the AI response to messages and returns llama_response
        """
        if state.get("verbose", False):
            print("[TRACE] call_llama invoked")

        raw = str(state.get("user_input", ""))

        # Ensure messages exists
        messages: List[Message] = state.get("messages", [])
        if messages is None:
            messages = []

        # Append human message (keep original exact user_input)
        human_msg = {"role": "human", "content": raw}
        messages.append(human_msg)

        # If user started with Hey Qwen, add a tool message informing Qwen is disabled
        stripped_leading = raw.lstrip()
        if stripped_leading.lower().startswith("hey qwen"):
            tool_msg = {"role": "tool", "content": "Qwen is disabled in this agent. Routed to Llama."}
            messages.append(tool_msg)
            # Also remove the leading trigger from the human content in messages? we keep original human message,
            # but for model input we'll pass the remainder (strip trigger) as the human content via prompt builder.
            # So we modify a temporary messages_for_prompt below.

        # Build messages_for_prompt such that if 'Hey Qwen' was used, the latest human message content is replaced
        # with the remainder after removing the trigger, so model does not see the trigger text.
        messages_for_prompt = []
        for m in messages:
            # copy to avoid mutating state stored messages
            messages_for_prompt.append({"role": m["role"], "content": m["content"]})

        # If last human message begins with Hey Qwen, replace its content for the prompt with remainder
        if messages_for_prompt and messages_for_prompt[-1]["role"] in ("human", "user"):
            last_content = messages_for_prompt[-1]["content"]
            if last_content.lstrip().lower().startswith("hey qwen"):
                # remove leading 'Hey Qwen' (first 8 characters after lstrip) then left-strip separators
                remainder = last_content.lstrip()[8:].lstrip()
                messages_for_prompt[-1]["content"] = remainder if remainder != "" else last_content

        # Build prompt text for the single-turn text generation model
        prompt_text = build_prompt_from_messages(messages_for_prompt)

        if state.get("verbose", False):
            print("[TRACE] Prompt to Llama (truncated):")
            print(prompt_text[:1000])

        # Call the Llama LLM
        response_text = llama_llm.invoke(prompt_text)

        # Append AI response to messages
        ai_msg = {"role": "ai", "content": response_text}
        messages.append(ai_msg)

        # Return updated messages and llama_response for printing
        return {"llama_response": response_text, "messages": messages}

    def print_both_responses(state: AgentState) -> dict:
        """
        Print only the model sections that have content. Because Qwen is disabled,
        qwen_response will typically be empty.
        """
        if state.get("verbose", False):
            print("[TRACE] print_both_responses invoked")

        llama_text = (state.get("llama_response") or "").strip()
        qwen_text = (state.get("qwen_response") or "").strip()

        if not llama_text and not qwen_text:
            print("\n[NOTICE] No model produced output this turn.")
            return {"llama_response": "", "qwen_response": ""}

        if llama_text:
            print("\n" + "=" * 60)
            print("ðŸ¦™ LLaMA Response")
            print("=" * 60)
            print(llama_text)

        if qwen_text:
            print("\n" + "=" * 60)
            print("ðŸ§  Qwen Response")
            print("=" * 60)
            print(qwen_text)

        # Clear responses for next iteration
        return {"llama_response": "", "qwen_response": ""}

    # Build the graph
    graph = StateGraph(AgentState)
    graph.add_node("get_user_input", get_user_input)
    graph.add_node("call_llama", call_llama)
    graph.add_node("print_both_responses", print_both_responses)

    graph.add_edge(START, "get_user_input")

    graph.add_conditional_edges(
        "get_user_input",
        route_after_input,
        {
            "get_user_input": "get_user_input",  # empty -> loop
            "call_llama": "call_llama",          # default -> llama (including Hey Qwen)
            END: END,
        },
    )

    graph.add_edge("call_llama", "print_both_responses")
    graph.add_edge("print_both_responses", "get_user_input")

    return graph.compile()

# -------------------------
# Main
# -------------------------
def main():
    print("=" * 60)
    print("LangGraph Single-Model Routing Agent with Message History (Qwen disabled)")
    print("=" * 60)

    llama_llm, qwen_llm = create_models()

    graph = create_graph(llama_llm, qwen_llm)

    # Initial system message that sets assistant behavior; customize as needed
    system_msg = {"role": "system", "content": "You are a helpful assistant. Answer concisely."}

    initial_state: AgentState = {
        "user_input": "",
        "should_exit": False,
        "llama_response": "",
        "qwen_response": "",
        "verbose": False,
        "messages": [system_msg],
    }

    graph.invoke(initial_state)

if __name__ == "__main__":
    main()

LangGraph Single-Model Routing Agent with Message History (Qwen disabled)
Using CUDA (NVIDIA GPU)
Loading Llama: meta-llama/Llama-3.2-1B-Instruct ...


Loading weights:   0%|          | 0/146 [00:00<?, ?it/s]

Llama loaded.

Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.
> 

 What is soccer?


Both `max_new_tokens` (=256) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



ðŸ¦™ LLaMA Response
[System] You are a helpful assistant. Answer concisely.
User: What is soccer?
Assistant: Soccer, also known as football, is a team sport played with a round ball and involving two teams of eleven players each, with the objective of scoring more goals than the opposing team by kicking or heading the ball into the opponent's goal. The game is typically played on a rectangular field with goals at each end.

Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.
> 

 What is the best team?


Both `max_new_tokens` (=256) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



ðŸ¦™ LLaMA Response
[System] You are a helpful assistant. Answer concisely.
User: What is soccer?
Assistant: [System] You are a helpful assistant. Answer concisely.
User: What is soccer?
Assistant: Soccer, also known as football, is a team sport played with a round ball and involving two teams of eleven players each, with the objective of scoring more goals than the opposing team by kicking or heading the ball into the opponent's goal. The game is typically played on a rectangular field with goals at each end.
User: What is the best team?
Assistant: [System] There is no definitive answer, as the best team can vary depending on the specific season, competition, and criteria used to evaluate teams. However, some of the top-performing teams in recent years include Barcelona, Manchester City, and Real Madrid.

Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.
> 

 quit
