Task 4

In [27]:
# langgraph_single_model_routing_fixed.py
#
# LangGraph agent that routes each user input to exactly one LLM:
#  - If input starts with "Hey Qwen" (case-insensitive) -> route to Qwen
#  - Otherwise -> route to Llama
# Fixes:
#  - Clears previous responses on new input to avoid stale output
#  - Uses robust startswith matching for trigger phrase
#  - Strips the trigger phrase before sending remainder to Qwen

import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Tuple

# -------------------------
# Device selection
# -------------------------
def get_device():
    if torch.cuda.is_available():
        print("Using CUDA (NVIDIA GPU)")
        return "cuda"
    elif torch.backends.mps.is_available():
        print("Using MPS (Apple Silicon)")
        return "mps"
    else:
        print("Using CPU")
        return "cpu"

# -------------------------
# State
# -------------------------
class AgentState(TypedDict):
    user_input: str
    should_exit: bool
    llama_response: str
    qwen_response: str
    verbose: bool

# -------------------------
# LLM loader with safe langchain fix
# -------------------------
def load_llm_wrapped(model_id: str, device: str):
    """
    Load a HF model and wrap in an adapter exposing .invoke(prompt) -> str.
    Raises RuntimeError on failure.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.float16 if device != "cpu" else torch.float32,
            device_map=device if device == "cuda" else None,
        )
        if device == "mps":
            model = model.to(device)

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Defensive langchain attributes to avoid AttributeError
        try:
            import langchain
            if not hasattr(langchain, "verbose"):
                langchain.verbose = False
            if not hasattr(langchain, "debug"):
                langchain.debug = False
            if not hasattr(langchain, "llm_cache"):
                langchain.llm_cache = None
        except Exception:
            pass

        wrapped = HuggingFacePipeline(pipeline=pipe)

        class Adapter:
            def __init__(self, llm):
                self.llm = llm
            def invoke(self, prompt: str) -> str:
                out = self.llm.invoke(prompt)
                if isinstance(out, str):
                    return out
                try:
                    if isinstance(out, list) and len(out) > 0 and isinstance(out[0], dict):
                        if "generated_text" in out[0]:
                            return out[0]["generated_text"]
                    if isinstance(out, dict) and "generated_text" in out:
                        return out["generated_text"]
                except Exception:
                    pass
                return str(out)

        return Adapter(wrapped)

    except Exception as e:
        raise RuntimeError(f"Failed to load {model_id}: {e}") from e

# -------------------------
# Create models (Llama + Qwen/ fallback)
# -------------------------
def create_models() -> Tuple[object, object]:
    device = get_device()

    # Attempt to load Llama
    llama_id = "meta-llama/Llama-3.2-1B-Instruct"
    llama_llm = None
    try:
        print(f"Loading Llama: {llama_id} ...")
        llama_llm = load_llm_wrapped(llama_id, device)
        print("Llama loaded.")
    except Exception as e:
        print(f"Could not load Llama ({llama_id}): {e}")
        llama_llm = None

    # Attempt to load Qwen; fall back to small local model if needed
    qwen_id = "Qwen/Qwen2.5-1.5B-Instruct"
    qwen_llm = None
    try:
        print(f"Loading Qwen: {qwen_id} ...")
        qwen_llm = load_llm_wrapped(qwen_id, device)
        print("Qwen loaded.")
    except Exception as e:
        print(f"Could not load Qwen ({qwen_id}): {e}")
        print("Falling back to lightweight local model (gpt2) for Qwen branch (no tokens required).")
        fallback_id = "gpt2"
        try:
            qwen_llm = load_llm_wrapped(fallback_id, device)
            print("Fallback (gpt2) loaded for Qwen branch.")
        except Exception as e2:
            print(f"Failed to load fallback model {fallback_id}: {e2}")
            class QwenFallback:
                def invoke(self, prompt: str) -> str:
                    return "[QWEN-FALLBACK] No local model available. Install transformers and a small model like 'gpt2'."
            qwen_llm = QwenFallback()

    # If Llama failed to load, use qwen/fallback as a substitute so graph still runs
    if llama_llm is None:
        print("Using the Qwen/fallback model for the Llama branch as well (Llama unavailable).")
        llama_llm = qwen_llm

    return llama_llm, qwen_llm

# -------------------------
# Create graph with single-model routing (fixed)
# -------------------------
def create_graph(llama_llm, qwen_llm):
    def get_user_input(state: AgentState) -> dict:
        if state.get("verbose", False):
            print("[TRACE] Entering get_user_input")
        print("\n" + "=" * 50)
        print("Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.")
        print("=" * 50)
        print("> ", end="")
        user_input = input()

        lowered = user_input.strip().lower()
        if lowered == "verbose":
            print("Verbose tracing enabled.")
            # clear previous responses on a new input
            return {"user_input": user_input, "should_exit": False, "verbose": True, "llama_response": "", "qwen_response": ""}
        if lowered == "quiet":
            print("Verbose tracing disabled.")
            return {"user_input": user_input, "should_exit": False, "verbose": False, "llama_response": "", "qwen_response": ""}

        if lowered in ("quit", "exit", "q"):
            print("Goodbye!")
            return {"user_input": user_input, "should_exit": True, "llama_response": "", "qwen_response": ""}

        # For any other input, clear previous model outputs to avoid stale prints
        return {"user_input": user_input, "should_exit": False, "llama_response": "", "qwen_response": ""}

    def route_after_input(state: AgentState) -> str:
        # 3-way branch: END, back-to-input (empty), or call appropriate model
        if state.get("verbose", False):
            print("[TRACE] route_after_input:", {"should_exit": state.get("should_exit"), "user_input": repr(state.get("user_input"))})
        if state.get("should_exit", False):
            return END

        raw = str(state.get("user_input", ""))
        if raw.strip() == "":
            if state.get("verbose", False):
                print("[TRACE] empty input -> looping back to get_user_input")
            print("[NOTICE] Empty input received â€” please type something.")
            return "get_user_input"

        # Robust trigger detection: startswith after lstrip
        stripped_leading = raw.lstrip()
        if stripped_leading.lower().startswith("hey qwen"):
            if state.get("verbose", False):
                print("[TRACE] route_after_input -> call_qwen (matched 'Hey Qwen')")
            return "call_qwen"
        else:
            if state.get("verbose", False):
                print("[TRACE] route_after_input -> call_llama (default)")
            return "call_llama"

    def call_llama(state: AgentState) -> dict:
        if state.get("verbose", False):
            print("[TRACE] call_llama invoked")
        prompt = f"User: {state['user_input']}\nAssistant:"
        response = llama_llm.invoke(prompt)
        return {"llama_response": response}

    def call_qwen(state: AgentState) -> dict:
        if state.get("verbose", False):
            print("[TRACE] call_qwen invoked")
        # Strip the trigger phrase "Hey Qwen" (case-insensitive) before passing to model
        raw = str(state.get("user_input", ""))
        stripped_leading = raw.lstrip()
        lower_leading = stripped_leading.lower()
        if lower_leading.startswith("hey qwen"):
            # remove only the leading phrase 'Hey Qwen' (length 8), then strip leading separators
            remainder = stripped_leading[8:].lstrip()
        else:
            remainder = stripped_leading

        # If remainder is empty (user only typed the trigger), still call model with original raw input
        model_input = remainder if remainder.strip() != "" else raw

        prompt = f"User: {model_input}\nAssistant:"
        response = qwen_llm.invoke(prompt)
        return {"qwen_response": response}

    def print_both_responses(state: AgentState) -> dict:
        """
        Print only the model sections that actually have text.
        If both are empty, print a short notice instead.
        """
        if state.get("verbose", False):
            print("[TRACE] print_both_responses invoked")
    
        llama_text = (state.get("llama_response") or "").strip()
        qwen_text  = (state.get("qwen_response")  or "").strip()
    
        if not llama_text and not qwen_text:
            print("\n[NOTICE] No model produced output this turn.")
            return {"llama_response": "", "qwen_response": ""}
    
        if llama_text:
            print("\n" + "=" * 60)
            print("ðŸ¦™ LLaMA Response")
            print("=" * 60)
            print(llama_text)
    
        if qwen_text:
            print("\n" + "=" * 60)
            print("ðŸ§  Qwen Response")
            print("=" * 60)
            print(qwen_text)

        # Clear both responses so the next turn starts fresh
        return {"llama_response": "", "qwen_response": ""}

    # Build graph
    graph = StateGraph(AgentState)
    graph.add_node("get_user_input", get_user_input)
    graph.add_node("call_llama", call_llama)
    graph.add_node("call_qwen", call_qwen)
    graph.add_node("print_both_responses", print_both_responses)

    graph.add_edge(START, "get_user_input")

    graph.add_conditional_edges(
        "get_user_input",
        route_after_input,
        {
            "get_user_input": "get_user_input",  # empty input -> loop
            "call_qwen": "call_qwen",            # Hey Qwen -> qwen
            "call_llama": "call_llama",          # default -> llama
            END: END,
        },
    )

    # Both model nodes converge to the print node
    graph.add_edge("call_llama", "print_both_responses")
    graph.add_edge("call_qwen", "print_both_responses")

    # Loop back to input
    graph.add_edge("print_both_responses", "get_user_input")

    return graph.compile()

# -------------------------
# Main
# -------------------------
def main():
    print("=" * 60)
    print("LangGraph Single-Model Routing Agent (LLaMA or Qwen based on input)")
    print("=" * 60)

    llama_llm, qwen_llm = create_models()

    graph = create_graph(llama_llm, qwen_llm)

    initial_state: AgentState = {
        "user_input": "",
        "should_exit": False,
        "llama_response": "",
        "qwen_response": "",
        "verbose": False,
    }

    graph.invoke(initial_state)

# Entry point
if __name__ == "__main__":
    main()

LangGraph Single-Model Routing Agent (LLaMA or Qwen based on input)
Using CUDA (NVIDIA GPU)
Loading Llama: meta-llama/Llama-3.2-1B-Instruct ...


Loading weights:   0%|          | 0/146 [00:00<?, ?it/s]

Llama loaded.
Loading Qwen: Qwen/Qwen2.5-1.5B-Instruct ...


Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

Qwen loaded.

Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.
> 

 Hey Qwen, how are you?


Both `max_new_tokens` (=256) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



ðŸ§  Qwen Response
User: , how are you?
Assistant: As an AI, I don't have feelings or emotions like humans do. However, I'm here to help you with any questions or information you need! How can I assist you today? Let me know if there's anything specific you'd like to discuss or learn about. I'll do my best to provide accurate and helpful responses based on the data available to me. Is there something particular you're curious about? I may not be able to predict what you might ask, but I'll try to give a thoughtful response when possible. What would you like to find out?

Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.
> 

 What is soccer?


Both `max_new_tokens` (=256) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



ðŸ¦™ LLaMA Response
User: What is soccer?
Assistant: Soccer, also known as association football, is a team sport played between two teams of 11 players using a round ball with the objective of scoring more goals than the opposing team by kicking or heading the ball into the opponent's goal.

Would you like me to explain any specific aspect of soccer?

Enter text (or 'quit' to exit). Type 'verbose' or 'quiet' to toggle tracing.
> 

 quit


Goodbye!
