# HuggingGPT Task Planning Implementation - Stage 1

This notebook implements a simplified version of the task planning stage from the HuggingGPT paper. The code is divided into the following sections:

1.  **Setup and Imports**: Install necessary packages and import libraries.
2.  **Task Definitions**: Define available task types and the TaskNode data structure.
3.  **HuggingGPTTaskPlanner Class**: Implement the core task planning logic, including LLM integrations.
4.  **Example Usage and Testing**: Demonstrate how to use the planner with test cases and different LLM providers.
5.  **Interactive Mode**: Provide an interactive loop for testing the planner.

In [None]:
# 1. Setup and Imports
# =====================

# Install required packages (run this cell first)
# !pip install "openai>=1.0.0" transformers torch requests json5 langchain-openai

import json
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
import requests
import os

# Optional: For LangChain + OpenRouter integration
try:
    from langchain_openai import ChatOpenAI
    HAS_LANGCHAIN = True
except ImportError:
    HAS_LANGCHAIN = False
    print("LangChain not installed. Install with: pip install langchain-openai")

# Optional: For OpenAI integration
try:
    from openai import OpenAI
    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False
    print("OpenAI not installed. Install with: pip install openai>=1.0.0")

# Optional: For Hugging Face integration
try:
    from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
    import torch
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    print("Transformers not installed. Install with: pip install transformers torch")

# Optional: For free API alternatives
try:
    import requests
    HAS_REQUESTS = True
except ImportError:
    HAS_REQUESTS = False

In [None]:
pip install langchain-openai

In [None]:
# 2. Task Definitions
# ===================

# Define available tasks
class TaskType(Enum):
    """Task identifiers the planner may place in the "task" field of a plan entry.

    Iterating this enum supplies the planner's default list of available
    tasks, so the member order below is the order in which tasks are listed.
    """

    # Image understanding
    OBJECT_DETECTION = "object-detection"
    IMAGE_TO_TEXT = "image-to-text"
    IMAGE_CLASSIFICATION = "image-cls"  # abbreviated value, unlike the member name
    VISUAL_QUESTION_ANSWERING = "visual-question-answering"
    POSE_DETECTION = "pose-detection"
    # Image generation
    POSE_TEXT_TO_IMAGE = "pose-text-to-image"
    TEXT_TO_IMAGE = "text-to-image"
    # Dense image prediction
    IMAGE_SEGMENTATION = "image-segmentation"
    DEPTH_ESTIMATION = "depth-estimation"
    # Text
    TEXT_CLASSIFICATION = "text-classification"
    TEXT_GENERATION = "text-generation"
    # Audio
    SPEECH_TO_TEXT = "speech-to-text"
    TEXT_TO_SPEECH = "text-to-speech"

@dataclass
class TaskNode:
    """One planned task: its type, its id, the ids it depends on, and its arguments."""

    task: str
    id: int
    dep: List[int]
    args: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the node as a plain dict with keys task, id, dep and args.

        The values are the node's own objects (no copies are made).
        """
        exported_fields = ("task", "id", "dep", "args")
        return {name: getattr(self, name) for name in exported_fields}

In [None]:
# 3. HuggingGPTTaskPlanner Class
# ===============================

class HuggingGPTTaskPlanner:
    def __init__(self, available_tasks: List[str] = None,
                 openai_api_key: str = None,
                 openrouter_api_key: str = None,
                 llm_provider: str = "rule_based"):
        """
        Initialize the Task Planner

        Args:
            available_tasks: List of available task types
            openai_api_key: OpenAI API key for direct OpenAI integration
            openrouter_api_key: OpenRouter API key for accessing multiple models
            llm_provider: LLM provider ("openai", "openrouter", "huggingface_local", "huggingface_api", "ollama", "rule_based")
        """
        if available_tasks is None:
            self.available_tasks = [task.value for task in TaskType]
        else:
            self.available_tasks = available_tasks

        self.llm_provider = llm_provider

        # Initialize OpenAI client (direct)
        self.openai_client = None
        if llm_provider == "openai":
            if openai_api_key and HAS_OPENAI:
                self.openai_client = OpenAI(api_key=openai_api_key)
            elif HAS_OPENAI and os.getenv("OPENAI_API_KEY"):
                self.openai_client = OpenAI()
            else:
                print("⚠️  OpenAI API key not provided. Set OPENAI_API_KEY environment variable or pass api_key parameter.")

        # Initialize LangChain + OpenRouter client
        self.langchain_client = None
        if llm_provider == "openrouter":
            if openrouter_api_key and HAS_LANGCHAIN:
                self.langchain_client = ChatOpenAI(
                    model="openai/gpt-4o-mini",  # Default model
                    api_key=openrouter_api_key,
                    base_url="https://openrouter.ai/api/v1",
                    temperature=0.1,
                    max_tokens=1000
                )
                print("✅ OpenRouter + LangChain initialized successfully!")
            elif HAS_LANGCHAIN and os.getenv("OPENROUTER_API_KEY"):
                self.langchain_client = ChatOpenAI(
                    model="openai/gpt-4o-mini",
                    api_key=os.getenv("OPENROUTER_API_KEY"),
                    base_url="https://openrouter.ai/api/v1",
                    temperature=0.1,
                    max_tokens=1000
                )
                print("✅ OpenRouter + LangChain initialized from environment!")
            else:
                print("⚠️  OpenRouter API key not provided or LangChain not available.")

        # Initialize Hugging Face model (for local inference)
        self.hf_pipeline = None

        if llm_provider == "huggingface_local" and HAS_TRANSFORMERS:
            self.initialize_huggingface_model()
        elif llm_provider == "huggingface_api":
            # Uses Hugging Face Inference API (free tier available)
            self.hf_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
            if not self.hf_api_token:
                print("⚠️  Set HUGGINGFACE_API_TOKEN environment variable for HF API access")

        # Ollama support (free local LLM)
        self.ollama_url = "http://localhost:11434"  # Default Ollama URL

        # Define demonstrations (examples from the paper)
        self.demonstrations = [
            {
                "input": "Can you tell me how many objects in e1.jpg?",
                "output": [{"task": "object-detection", "id": 0, "dep": [-1], "args": {"image": "e1.jpg"}}]
            },
            {
                "input": "In e2.jpg, what's the animal and what's it doing?",
                "output": [
                    {"task": "image-to-text", "id": 0, "dep": [-1], "args": {"image": "e2.jpg"}},
                    {"task": "image-cls", "id": 1, "dep": [-1], "args": {"image": "e2.jpg"}},
                    {"task": "object-detection", "id": 2, "dep": [-1], "args": {"image": "e2.jpg"}},
                    {"task": "visual-question-answering", "id": 3, "dep": [-1], "args": {"text": "what's the animal doing?", "image": "e2.jpg"}}
                ]
            },
            {
                "input": "First generate a HED image of e3.jpg, then based on the HED image and a text 'a girl reading a book', create a new image as a response.",
                "output": [
                    {"task": "pose-detection", "id": 0, "dep": [-1], "args": {"image": "e3.jpg"}},
                    {"task": "pose-text-to-image", "id": 1, "dep": [0], "args": {"text": "a girl reading a book", "image": "<resource>-0"}}
                ]
            }
        ]

        self.chat_logs = []  # Store conversation history

    def initialize_huggingface_model(self, model_name: str = "microsoft/DialoGPT-small"):
        """Load a local text-generation pipeline into ``self.hf_pipeline``.

        ``model_name`` is tried first; if loading it fails for any reason,
        ``distilgpt2`` is tried instead. When both attempts fail,
        ``self.hf_pipeline`` is set to ``None``. Progress and loading errors
        are reported with ``print``.
        """
        # Sampling settings shared by the primary and the fallback pipeline.
        sampling = {"do_sample": True, "temperature": 0.1}

        try:
            print(f"📥 Loading Hugging Face model: {model_name}")
            print("⏳ This may take a few minutes for first-time download...")

            self.hf_pipeline = pipeline(
                "text-generation",
                model=model_name,
                tokenizer=model_name,
                max_length=512,
                **sampling
            )
            print("✅ Hugging Face model loaded successfully!")
            return
        except Exception as primary_error:
            print(f"❌ Failed to load Hugging Face model: {primary_error}")
            print("💡 Trying with a simpler model...")

        # Second attempt: a very small model with a shorter generation limit.
        try:
            self.hf_pipeline = pipeline(
                "text-generation",
                model="distilgpt2",
                max_length=256,
                **sampling
            )
            print("✅ Fallback model loaded successfully!")
        except Exception as fallback_error:
            print(f"❌ Fallback model also failed: {fallback_error}")
            self.hf_pipeline = None

    def create_task_planning_prompt(self, user_input: str, chat_logs: List[str] = None) -> str:
        """Create the task planning prompt based on HuggingGPT format"""
        available_tasks_str = ", ".join(self.available_tasks)

        # Format demonstrations
        demo_str = ""
        for demo in self.demonstrations:
            demo_str += f"{demo['input']}\n{json.dumps(demo['output'])}\n\n"

        # Format chat logs if available
        chat_logs_str = ""
        if chat_logs:
            chat_logs_str = "\n".join(chat_logs[-5:])  # Last 5 entries

        prompt = f"""#1 Task Planning Stage - Parse user input into structured tasks.

Format: [{{"task": task_name, "id": task_id, "dep": [dependency_ids], "args": {{"text": text, "image": URL, "audio": URL, "video": URL}}}}]

Available tasks: {available_tasks_str}

Dependencies:
- "dep": [-1] for no dependencies
- "dep": [0,1] for dependencies on tasks 0 and 1
- Use "<resource>-task_id" to reference outputs from previous tasks

Examples:

{demo_str}

Chat history: {chat_logs_str}

Parse this user input into JSON task array:
"{user_input}"

Response (JSON array only):"""

        return prompt

    def call_huggingface_local(self, prompt: str) -> str:
        """Call local Hugging Face model for task planning"""
        if not self.hf_pipeline:
            raise Exception("Hugging Face model not initialized")

        try:
            # Simplified prompt for smaller models
            simple_prompt = f"Task planning for: {prompt}\nJSON output:"

            outputs = self.hf_pipeline(
                simple_prompt,
                max_length=len(simple_prompt.split()) + 100,
                num_return_sequences=1,
                pad_token_id=50256  # GPT-2 EOS token
            )

            response = outputs[0]['generated_text']
            # Remove the input prompt from response
            response = response[len(simple_prompt):].strip()
            return response
        except Exception as e:
            print(f"Hugging Face model error: {e}")
            return "[]"

    def call_huggingface_api(self, prompt: str, model: str = "gpt2") -> str:
        """Call Hugging Face Inference API (free tier available)

        Args:
            prompt: Text sent as the model input.
            model: Model id appended to the Inference API URL.

        Returns:
            The generated text of the first result, or "[]" when the request
            fails, returns a non-200 status, or has an unexpected shape.

        Raises:
            Exception: If the requests library is not available.
        """
        if not HAS_REQUESTS:
            raise Exception("Requests library not available")

        url = f"https://api-inference.huggingface.co/models/{model}"
        # The token attribute may be missing or empty; send no Authorization
        # header in that case instead of a literal "Bearer None".
        token = getattr(self, "hf_api_token", None)
        headers = {"Authorization": f"Bearer {token}"} if token else {}

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 200,
                "temperature": 0.1,
                "do_sample": True,
                # Without this the API echoes the prompt, whose own brackets
                # then break extraction of the JSON array from the reply.
                "return_full_text": False
            }
        }

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            if response.status_code != 200:
                # Report the failure instead of silently returning an empty plan.
                print(f"Hugging Face API error: HTTP {response.status_code}: {response.text[:200]}")
                return "[]"

            result = response.json()
            if isinstance(result, list) and result and isinstance(result[0], dict):
                return result[0].get('generated_text', '[]')
            return "[]"
        except Exception as e:
            print(f"Hugging Face API error: {e}")
            return "[]"

    def call_ollama(self, prompt: str, model: str = "llama2") -> str:
        """Call Ollama API for local LLM inference

        Args:
            prompt: Text sent as the generation prompt.
            model: Name of the Ollama model to use.

        Returns:
            The "response" field of the reply, or "[]" when the request fails
            or returns a non-200 status.

        Raises:
            Exception: If the requests library is not available.
        """
        if not HAS_REQUESTS:
            raise Exception("Requests library not available")

        url = f"{self.ollama_url}/api/generate"
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,  # ask for one complete JSON reply
            "options": {
                "temperature": 0.1,
                "num_predict": 200
            }
        }

        try:
            response = requests.post(url, json=payload, timeout=30)
            if response.status_code != 200:
                # Report the failure (e.g. unknown model) instead of silently
                # returning an empty plan.
                print(f"Ollama API error: HTTP {response.status_code}: {response.text[:200]}")
                return "[]"

            result = response.json()
            return result.get('response', '[]')
        except Exception as e:
            print(f"Ollama API error: {e}")
            return "[]"

    def call_llm(self, prompt: str) -> str:
        """Call the configured LLM provider"""
        if self.llm_provider == "openai" and self.openai_client:
            return self.call_openai_llm(prompt)
        elif self.llm_provider == "openrouter" and self.langchain_client:
            return self.call_openrouter_llm(prompt)
        elif self.llm_provider == "huggingface_local" and self.hf_pipeline:
            return self.call_huggingface_local(prompt)
        elif self.llm_provider == "huggingface_api" and self.hf_api_token:
            return self.call_huggingface_api(prompt)
        elif self.llm_provider == "ollama":
            return self.call_ollama(prompt)
        else:
            return "[]" # Fallback if no LLM is configured/available

    def call_openrouter_llm(self, prompt: str) -> str:
        """Call OpenRouter via LangChain for task planning"""
        if not self.langchain_client:
            raise Exception("LangChain + OpenRouter client not initialized")

        try:
            response = self.langchain_client.invoke(prompt)
            return response.content.strip()
        except Exception as e:
            print(f"OpenRouter API error: {e}")
            return "[]"

    def call_openai_llm(self, prompt: str, model: str = "gpt-3.5-turbo") -> str:
        """Call OpenAI API for task planning"""
        if not self.openai_client:
            raise Exception("OpenAI client not initialized. Provide API key or set OPENAI_API_KEY environment variable.")

        try:
            response = self.openai_client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                max_tokens=1000
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"OpenAI API error: {e}")
            return "[]"

    def parse_llm_response(self, response: str) -> List[Dict[str, Any]]:
        """Parse LLM response to extract JSON tasks"""
        try:
            # Extract JSON array from response
            json_match = re.search(r'\[.*\]', response, re.DOTALL)
            if json_match:
                json_str = json_match.group(0)
                tasks = json.loads(json_str)
                return tasks
            else:
                return []
        except Exception as e:
            print(f"Error parsing LLM response: {e}")
            return []

    def fallback_parse_user_request(self, user_input: str) -> List[TaskNode]:
        """Fallback rule-based parsing when LLM is not available

        Recognizes three request shapes by keyword:

        * "how many objects" + an image file: object detection.
        * "what" with "animal" or "doing" + an image file: captioning,
          classification, detection and visual question answering.
        * "generate" + "image" with "pose" or the word "hed": pose detection
          on the image (if one is named), then pose-guided image generation
          from the quoted text (if any).

        Args:
            user_input: The raw user request.

        Returns:
            The planned TaskNode objects; empty when no rule matches.
        """
        tasks = []
        task_id = 0

        user_input_lower = user_input.lower()
        # The extension group must be non-capturing: with a capturing group
        # re.findall returns only the extension ("jpg"), not the file name.
        image_files = re.findall(r'\b\w+\.(?:jpg|jpeg|png|gif|bmp)\b', user_input, re.IGNORECASE)

        if "how many objects" in user_input_lower and image_files:
            tasks.append(TaskNode(
                task="object-detection",
                id=task_id,
                dep=[-1],
                args={"image": image_files[0]}
            ))

        elif "what" in user_input_lower and ("animal" in user_input_lower or "doing" in user_input_lower) and image_files:
            tasks.extend([
                TaskNode(task="image-to-text", id=task_id, dep=[-1], args={"image": image_files[0]}),
                TaskNode(task="image-cls", id=task_id+1, dep=[-1], args={"image": image_files[0]}),
                TaskNode(task="object-detection", id=task_id+2, dep=[-1], args={"image": image_files[0]}),
                TaskNode(task="visual-question-answering", id=task_id+3, dep=[-1],
                        args={"text": user_input, "image": image_files[0]})
            ])

        elif "generate" in user_input_lower and "image" in user_input_lower:
            # "hed" is matched as a whole word so that e.g. "finished" does not count.
            if "pose" in user_input_lower or re.search(r'\bhed\b', user_input_lower):
                pose_task_id = None
                if image_files:
                    tasks.append(TaskNode(task="pose-detection", id=task_id, dep=[-1], args={"image": image_files[0]}))
                    pose_task_id = task_id
                    task_id += 1

                # Accept "double" or 'single' quoted text. The lookarounds keep
                # apostrophes inside words (what's, it's) from acting as quotes.
                text_match = re.search(r'"([^"]*)"|(?<!\w)\'([^\']+)\'(?!\w)', user_input)
                if text_match:
                    text_content = text_match.group(1)
                    if text_content is None:
                        text_content = text_match.group(2)

                    if pose_task_id is None:
                        # No pose task was planned, so there is no earlier
                        # output to reference.
                        dep = [-1]
                        args = {"text": text_content}
                    else:
                        dep = [pose_task_id]
                        args = {"text": text_content, "image": f"<resource>-{pose_task_id}"}

                    tasks.append(TaskNode(
                        task="pose-text-to-image",
                        id=task_id,
                        dep=dep,
                        args=args
                    ))

        return tasks

    def plan_tasks(self, user_input: str, use_llm: bool = True) -> List[Dict[str, Any]]:
        """Main task planning function"""

        if use_llm:
            try:
                provider_available = False

                if self.llm_provider == "openai" and self.openai_client:
                    print("🤖 Using OpenAI for task planning...")
                    provider_available = True
                elif self.llm_provider == "openrouter" and self.langchain_client:
                    print("🌐 Using OpenRouter (GPT-4o-mini) for task planning...")
                    provider_available = True
                elif self.llm_provider == "huggingface_local" and self.hf_pipeline:
                    print("🤗 Using local Hugging Face model for task planning...")
                    provider_available = True
                elif self.llm_provider == "huggingface_api" and self.hf_api_token:
                    print("🌐 Using Hugging Face API for task planning...")
                    provider_available = True
                elif self.llm_provider == "ollama":
                    print("🦙 Using Ollama for task planning...")
                    provider_available = True

                if provider_available:
                    prompt = self.create_task_planning_prompt(user_input, self.chat_logs)
                    print(f"📝 Sending prompt to {self.llm_provider}...")
                    llm_response = self.call_llm(prompt)
                    print(f"📨 Received response from {self.llm_provider}")
                    tasks = self.parse_llm_response(llm_response)

                    if tasks:
                        print(f"✅ Successfully generated {len(tasks)} tasks using LLM")
                        # Print parsed tasks for debugging
                        for i, task in enumerate(tasks):
                            print(f"   Task {i}: {task.get('task', 'unknown')} -> {task.get('args', {})}")
                    else:
                        print("⚠️  LLM returned no valid tasks, using fallback...")
                        task_nodes = self.fallback_parse_user_request(user_input)
                        tasks = [task.to_dict() for task in task_nodes]
                else:
                    print(f"⚠️  LLM provider '{self.llm_provider}' not available, using rule-based parsing...")
                    task_nodes = self.fallback_parse_user_request(user_input)
                    tasks = [task.to_dict() for task in task_nodes]

            except Exception as e:
                print(f"❌ LLM error: {e}")
                print("🔄 Using fallback parsing...")
                task_nodes = self.fallback_parse_user_request(user_input)
                tasks = [task.to_dict() for task in task_nodes]
        else:
            print("📝 Using rule-based parsing...")
            task_nodes = self.fallback_parse_user_request(user_input)
            tasks = [task.to_dict() for task in task_nodes]

        # Add to chat logs
        self.chat_logs.append(f"User: {user_input}")
        self.chat_logs.append(f"Tasks: {json.dumps(tasks, indent=2)}")

        return tasks

    def visualize_task_graph(self, tasks: List[Dict[str, Any]]) -> str:
        """Create a text-based visualization of task dependencies"""
        if not tasks:
            return "No tasks to visualize"

        graph = "Task Dependency Graph:\n"
        graph += "=" * 30 + "\n"

        for task in tasks:
            deps = task['dep']
            dep_str = "None" if deps == [-1] else f"Tasks {deps}"
            graph += f"Task {task['id']}: {task['task']}\n"
            graph += f"  Dependencies: {dep_str}\n"
            graph += f"  Args: {task['args']}\n"
            graph += "-" * 20 + "\n"

        return graph

    def get_execution_order(self, tasks: List[Dict[str, Any]]) -> List[List[int]]:
        """Determine execution order based on dependencies"""
        if not tasks:
            return []

        # Create dependency graph
        task_deps = {task['id']: task['dep'] for task in tasks}
        executed = set()
        execution_order = []

        while len(executed) < len(tasks):
            current_batch = []
            for task_id, deps in task_deps.items():
                if task_id in executed:
                    continue
                # Check if all dependencies are satisfied
                if all(dep == -1 or dep in executed for dep in deps):
                    current_batch.append(task_id)

            if not current_batch:
                break  # Circular dependency or error

            execution_order.append(current_batch)
            executed.update(current_batch)

        return execution_order

In [None]:
# 4. Example Usage and Testing
# ============================

# Example usage and configuration
print("🚀 HuggingGPT Task Planner - WITH OPENROUTER SUPPORT")
print("=" * 60)

print("💰 COST-EFFECTIVE LLM Options:")
print("1. 🌐 OpenRouter + LangChain (RECOMMENDED - cheap GPT-4o-mini)")
print("2. 📝 Rule-based (free, immediate)")
print("3. 🤗 Hugging Face Local (free, requires download)")
print("4. 🦙 Ollama (free, requires local installation)")
print("5. 🤖 Direct OpenAI (expensive)")
print()

print("🎯 RECOMMENDED: OpenRouter with GPT-4o-mini")
print("   💲 ~$0.15 per 1M tokens (very cheap!)")
print("   🚀 High quality task planning")
print("   🔑 Just need OpenRouter API key")
print()

# Configuration Examples:

print("🔧 Quick Setup with OpenRouter:")
print("   1. Get API key from: https://openrouter.ai/")
print("   2. Set the OPENROUTER_API_KEY environment variable")
print()

# RECOMMENDED: OpenRouter + LangChain Setup
print("🎯 Setting up OpenRouter + LangChain (RECOMMENDED)")

# Your OpenRouter API key
openrouter_key = os.getenv("OPENROUTER_API_KEY")

try:
    planner = HuggingGPTTaskPlanner(
        openrouter_api_key=openrouter_key,
        llm_provider="openrouter"
    )
except Exception as e:
    print(f"❌ OpenRouter setup failed: {e}")
    print("🔄 Falling back to rule-based parsing...")
    planner = HuggingGPTTaskPlanner(llm_provider="rule_based")
    USE_LLM = False
else:
    # The constructor only prints a warning when the key or LangChain is
    # missing, so success is judged by whether a client was actually created.
    USE_LLM = planner.langchain_client is not None
    if USE_LLM:
        print("🤖 Using GPT-4o-mini for high-quality task planning")
        print("💰 Cost: ~$0.15 per 1M tokens")
    else:
        print("🔄 OpenRouter is not available; using rule-based parsing...")

# Alternative configurations (comment out the above and uncomment one of these):

# Rule-based only
# planner = HuggingGPTTaskPlanner(llm_provider="rule_based")
# USE_LLM = False
# print("📝 Using rule-based parsing (safe fallback)")

# Environment variable approach for OpenRouter
# import os
# os.environ["OPENROUTER_API_KEY"] = "sk-or-v1-your-full-key-here"
# planner = HuggingGPTTaskPlanner(llm_provider="openrouter")
# USE_LLM = True

# Hugging Face Local (FREE)
# planner = HuggingGPTTaskPlanner(llm_provider="huggingface_local")
# USE_LLM = True

print("✅ Planner ready!")
print("-" * 40)

# Sample requests: the three demonstration prompts plus two the rules do not cover
test_cases = [
    "Can you tell me how many objects in e1.jpg?",
    "In e2.jpg, what's the animal and what's it doing?",
    "First generate a HED image of e3.jpg, then based on the HED image and a text 'a girl reading a book', create a new image as a response.",
    "Create an image of a sunset over mountains",
    "What objects are in photo.png and describe the scene?"
]

print("Running test cases:")
print("-" * 30)

for i, test_case in enumerate(test_cases, 1):
    print(f"\n🔍 Test Case {i}:")
    print(f"Input: {test_case}")
    print("-" * 30)

    # Plan with whichever provider was configured above
    tasks = planner.plan_tasks(test_case, use_llm=USE_LLM)

    if not tasks:
        print("❌ No tasks generated")
    else:
        print("\n📋 Generated Tasks:")
        print(json.dumps(tasks, indent=2))

        print("\n🔗 Task Graph:")
        print(planner.visualize_task_graph(tasks))

        # Batches are listed in the order they would have to run
        print("⚡ Execution Order:")
        execution_order = planner.get_execution_order(tasks)
        for batch_idx, batch in enumerate(execution_order):
            print(f"  Batch {batch_idx + 1}: Tasks {batch}")

    print("=" * 50)

In [None]:
# 5. Interactive Mode
# ===================
# Reads requests until the user types quit/exit/q, or until input ends
# (Ctrl-D, Ctrl-C, or a non-interactive run).

print("\n🎯 Interactive Task Planning")
print("Enter 'quit' to exit")
print("-" * 30)

while True:
    try:
        user_input = input("\nEnter your request: ").strip()
    except (EOFError, KeyboardInterrupt):
        # input() raises these when stdin is closed or the user interrupts;
        # end the session cleanly instead of crashing with a traceback.
        print()
        break

    if user_input.lower() in ['quit', 'exit', 'q']:
        break

    if user_input:
        print(f"\n🔍 Processing: {user_input}")
        tasks = planner.plan_tasks(user_input, use_llm=USE_LLM)

        if tasks:
            print("\n📋 Generated Tasks:")
            print(json.dumps(tasks, indent=2))
            print("\n🔗 Task Dependencies:")
            print(planner.visualize_task_graph(tasks))

            # Show execution order
            execution_order = planner.get_execution_order(tasks)
            if execution_order:
                print("\n⚡ Execution Order:")
                for batch_idx, batch in enumerate(execution_order):
                    print(f"  Batch {batch_idx + 1}: Tasks {batch}")
        else:
            print("❌ Could not parse the request into tasks")

print("\n✅ Task Planning Demo Complete!")