In [None]:
import os

# Enable expandable segments in PyTorch's CUDA caching allocator to reduce
# memory fragmentation. The variable is read when the allocator initialises,
# so it is set here, before torch is imported by later cells.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

In [None]:

# Install training/runtime dependencies. Versions are not pinned, so results can
# drift as packages release new versions — consider pinning a known-good set.
%pip install torch tensorboard
%pip install transformers datasets accelerate evaluate trl protobuf sentencepiece
# NOTE(review): the model is loaded later with attn_implementation="eager", so
# flash-attn appears unused in the visible cells; it can take a long time to
# build — confirm it is still needed.
# NOTE(review): matplotlib is imported later but not installed here; presumably
# it is preinstalled in the target environment (e.g. Colab) — confirm.
%pip install flash-attn


# Hugging Face configuration for pushing the final model to the Hub

In [None]:
import os

# Hugging Face Hub settings used when publishing the fine-tuned model.
# Never hardcode the access token in the notebook: it leaks into version control
# and shared copies. Read it from the HF_TOKEN environment variable instead.
# When unset this is None; huggingface_hub APIs generally treat token=None as
# "use the locally saved login token".
hf_token = os.environ.get("HF_TOKEN")
hf_repo_name = "bhaiyahnsingh45/functiongemma-multiagent-router"



In [None]:
# Log in to the Hugging Face Hub.
# Use the HF_TOKEN environment variable when it is set, so "Run All" works
# without an interactive prompt. When it is unset the token is None, and login()
# behaves exactly like the previous bare login() call (interactive prompt).
import os

from huggingface_hub import login

login(token=os.environ.get("HF_TOKEN"))


# **Multi-Agent System: Customer Support Platform**
#
# We'll create a routing system for a customer support platform with 3 specialized agents:
#
1. **Technical Support Agent**: Handles technical issues, bugs, and troubleshooting
2. **Billing Agent**: Manages invoices, payments, subscriptions, and refunds
3. **Product Information Agent**: Provides product details, features, and recommendations

In [None]:
# -*- coding: utf-8 -*-
"""Fine-tuning with FunctionGemma for Multi-Agent Routing System

This script fine-tunes FunctionGemma to route customer support queries to specialized agents.
"""

import re
import json
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
import matplotlib.pyplot as plt

# ----- Run configuration -----
# Hub id of the pretrained FunctionGemma checkpoint that gets fine-tuned.
base_model: str = "google/functiongemma-270m-it"
# Output directory name for this run (consumed by later cells — presumably the
# trainer's output_dir; confirm).
checkpoint_dir: str = "functiongemma-270m-it-multiagent-router"
# Optimizer learning rate, 5 x 10^-5.
learning_rate: float = 0.00005


Multi-Agent Routing Dataset

In [None]:
# Routing training data. Every sample is a dict with:
#   user_content    - the customer query (the model input)
#   agent_name      - the tool the router should call: technical_support_agent,
#                     billing_agent or product_info_agent
#   agent_arguments - a JSON-encoded object string holding that tool's arguments
# List order matters: the train/test split is a seeded shuffle over this order,
# so reordering entries changes which samples land in the test set.
multi_agent_dataset = [
    # Technical Support Agent queries - MORE VARIED AND REALISTIC
    {"user_content": "My app keeps crashing when I try to upload photos larger than 5MB", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "crash", "priority": "high"}'},
    {"user_content": "I can't log in, it says my password is incorrect but I'm sure it's right", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "authentication", "priority": "high"}'},
    {"user_content": "The dashboard takes forever to load, sometimes over 30 seconds", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "performance", "priority": "medium"}'},
    {"user_content": "How do I connect your REST API to my Node.js backend?", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "integration", "priority": "medium"}'},
    {"user_content": "Getting error 500 when calling the /api/users endpoint", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "api_error", "priority": "high"}'},
    {"user_content": "My mobile app data isn't syncing with the cloud storage", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "sync", "priority": "medium"}'},
    {"user_content": "Need help configuring 2FA with Google Authenticator", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "security_setup", "priority": "low"}'},
    {"user_content": "My CSV export keeps failing after processing 50% of the data", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "export", "priority": "medium"}'},
    {"user_content": "Search results are showing outdated information from last month", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "functionality", "priority": "medium"}'},
    {"user_content": "Can't connect to VPN from my home WiFi network", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "connectivity", "priority": "high"}'},
    {"user_content": "Getting 403 forbidden when trying to access the admin dashboard", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "permissions", "priority": "high"}'},
    {"user_content": "Desktop app installation fails on Windows 11 with error code 0x80070057", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "installation", "priority": "medium"}'},
    {"user_content": "My webhook isn't receiving POST requests from your system", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "webhook", "priority": "medium"}'},
    {"user_content": "Database queries are timing out after 30 seconds", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "database", "priority": "high"}'},
    {"user_content": "SSL certificate error when accessing the API over HTTPS", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "ssl", "priority": "medium"}'},
    {"user_content": "The app freezes when I try to edit large documents", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "performance", "priority": "high"}'},
    {"user_content": "Getting 'session expired' error every 5 minutes", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "authentication", "priority": "medium"}'},
    {"user_content": "How do I set up OAuth2 authentication for my integration?", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "integration", "priority": "low"}'},
    {"user_content": "Push notifications aren't working on my Android device", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "functionality", "priority": "medium"}'},
    {"user_content": "Unable to restore backup from last week", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "database", "priority": "high"}'},

    # Billing Agent queries - MORE VARIED AND REALISTIC
    {"user_content": "I see two charges of $99 on my credit card for this month", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "dispute", "urgency": "high"}'},
    {"user_content": "I'd like a refund for the annual plan I purchased yesterday", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "refund", "urgency": "medium"}'},
    {"user_content": "My credit card expires next month, how do I update it?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_method", "urgency": "low"}'},
    {"user_content": "Where can I find my September invoice for accounting purposes?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "invoice", "urgency": "low"}'},
    {"user_content": "I want to upgrade from Basic to Premium, what's the price difference?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "upgrade", "urgency": "medium"}'},
    {"user_content": "How do I cancel my subscription before the renewal on March 15th?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "cancellation", "urgency": "medium"}'},
    {"user_content": "My payment was declined but I have sufficient funds in my account", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_failure", "urgency": "high"}'},
    {"user_content": "Can I switch from monthly billing to annual to save money?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "billing_cycle", "urgency": "low"}'},
    {"user_content": "I need to change the billing email from old@company.com to new@company.com", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "account_update", "urgency": "low"}'},
    {"user_content": "What features are included in the Enterprise plan pricing?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "pricing_inquiry", "urgency": "low"}'},
    {"user_content": "I was charged $149 but my account still shows as unpaid", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_discrepancy", "urgency": "high"}'},
    {"user_content": "Does your company offer educational discounts for universities?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "discount", "urgency": "low"}'},
    {"user_content": "I need to add 5 more user licenses to my current team plan", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "add_seats", "urgency": "medium"}'},
    {"user_content": "Can you provide a tax exemption form for our non-profit organization?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "tax", "urgency": "low"}'},
    {"user_content": "Is it possible to pause my subscription for 3 months while I'm on sabbatical?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "pause_subscription", "urgency": "medium"}'},
    {"user_content": "I was charged after canceling, can I get that refunded?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "refund", "urgency": "high"}'},
    {"user_content": "How do I apply the promotional code SAVE20 to my account?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "discount", "urgency": "medium"}'},
    {"user_content": "My company needs a quote for 50 enterprise licenses", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "pricing_inquiry", "urgency": "medium"}'},
    # NOTE(review): this is a downgrade request but is labelled request_type
    # "upgrade" — confirm the label is intended.
    {"user_content": "Can I downgrade from Premium to Basic and get a prorated refund?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "upgrade", "urgency": "low"}'},
    {"user_content": "Need to update our billing address for tax purposes", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "account_update", "urgency": "low"}'},

    # Product Information Agent queries - MORE VARIED AND REALISTIC
    {"user_content": "What's the difference between Pro and Enterprise in terms of storage limits?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "comparison", "category": "plans"}'},
    {"user_content": "Does your platform support real-time collaborative editing like Google Docs?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "capabilities"}'},
    {"user_content": "Which project management tools can I integrate with your platform?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "integrations", "category": "project_management"}'},
    {"user_content": "What's the maximum file size I can upload on the Basic plan?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "limits", "category": "storage"}'},
    {"user_content": "Is there a native mobile app for iOS and Android?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "platform", "category": "mobile"}'},
    {"user_content": "Are you HIPAA compliant for healthcare data?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "compliance", "category": "healthcare"}'},
    {"user_content": "What kind of analytics and reporting dashboards do you provide?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "analytics"}'},
    {"user_content": "Is there an API rate limit on the Standard plan?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "limits", "category": "api"}'},
    {"user_content": "Do you have SOC 2 Type II certification?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "compliance", "category": "security"}'},
    {"user_content": "Can I customize the interface with my company's branding and colors?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "customization", "category": "branding"}'},
    {"user_content": "What features were added in the latest version 3.5 release?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "updates", "category": "releases"}'},
    {"user_content": "Does your product support SSO with Microsoft Azure AD?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "integrations", "category": "authentication"}'},
    {"user_content": "How many team members can I have on the Standard plan?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "limits", "category": "users"}'},
    {"user_content": "What languages is your user interface available in?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "localization", "category": "languages"}'},
    {"user_content": "Can I export reports to PDF and Excel formats?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "export"}'},
    {"user_content": "Does the Enterprise plan include dedicated customer support?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "comparison", "category": "plans"}'},
    {"user_content": "What's your uptime SLA for production environments?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "compliance", "category": "security"}'},
    # NOTE(review): Salesforce is a CRM, yet the category is "project_management"
    # — confirm the label is intended.
    {"user_content": "Can I integrate with Salesforce CRM?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "integrations", "category": "project_management"}'},
    {"user_content": "Is there a limit on how many API calls I can make per day?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "limits", "category": "api"}'},
    {"user_content": "Do you support custom workflow automation?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "capabilities"}'},

    # EDGE CASES - Ambiguous queries that could confuse the model
    {"user_content": "I need help with my account", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "account_update", "urgency": "medium"}'},
    {"user_content": "Something is broken", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "functionality", "priority": "medium"}'},
    {"user_content": "Tell me about your pricing", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "pricing_inquiry", "urgency": "low"}'},
    {"user_content": "What can your product do?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "capabilities"}'},
    {"user_content": "I have a problem with my subscription", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "account_update", "urgency": "medium"}'},

    # ADDITIONAL DIVERSE SAMPLES FOR SLOWER CONVERGENCE
    # More Technical Support variations
    {"user_content": "The application freezes randomly during video calls", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "performance", "priority": "high"}'},
    {"user_content": "I'm unable to authenticate using my Google account", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "authentication", "priority": "high"}'},
    {"user_content": "My data export is corrupted and missing half the records", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "export", "priority": "high"}'},
    {"user_content": "The mobile app crashes immediately after opening on iOS 17", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "crash", "priority": "high"}'},
    {"user_content": "API rate limiting is too aggressive, hitting limits constantly", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "api_error", "priority": "medium"}'},
    {"user_content": "Webhook notifications are delayed by 5-10 minutes", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "webhook", "priority": "medium"}'},
    {"user_content": "Can't establish secure connection, SSL handshake fails", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "ssl", "priority": "high"}'},
    {"user_content": "User permissions reset after every system update", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "permissions", "priority": "high"}'},
    {"user_content": "Backup restoration process hangs at 80% completion", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "database", "priority": "high"}'},
    {"user_content": "Integration with Slack stopped working after their API update", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "integration", "priority": "medium"}'},

    # More Billing variations
    {"user_content": "I was double-charged for last month's subscription", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "dispute", "urgency": "high"}'},
    {"user_content": "Need to change payment method from credit card to bank transfer", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_method", "urgency": "medium"}'},
    {"user_content": "Can I get a pro-rated refund for unused portion of my plan?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "refund", "urgency": "low"}'},
    {"user_content": "My invoice shows wrong billing address, need correction", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "invoice", "urgency": "medium"}'},
    {"user_content": "Want to switch from monthly to quarterly billing cycle", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "billing_cycle", "urgency": "low"}'},
    {"user_content": "Payment processing failed with error code PAYMENT_001", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_failure", "urgency": "high"}'},
    {"user_content": "Need to add 10 more user licenses to our team account", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "add_seats", "urgency": "medium"}'},
    {"user_content": "Is there a student discount available for annual plans?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "discount", "urgency": "low"}'},
    {"user_content": "Charged for Enterprise plan but account shows Standard tier", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_discrepancy", "urgency": "high"}'},
    {"user_content": "Need to cancel auto-renewal before next billing date", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "cancellation", "urgency": "medium"}'},

    # More Product Info variations
    {"user_content": "What's the storage capacity difference between Basic and Pro?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "comparison", "category": "plans"}'},
    {"user_content": "Does your platform support real-time collaboration features?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "capabilities"}'},
    {"user_content": "Which CRM systems can I integrate with your product?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "integrations", "category": "project_management"}'},
    {"user_content": "What's the maximum number of API requests per hour?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "limits", "category": "api"}'},
    # NOTE(review): a desktop (Windows/Mac) question is labelled category
    # "mobile" — confirm the label is intended.
    {"user_content": "Is there a desktop application for Windows and Mac?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "platform", "category": "mobile"}'},
    {"user_content": "Are you GDPR compliant for European customers?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "compliance", "category": "security"}'},
    {"user_content": "What reporting and analytics tools are included?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "analytics"}'},
    {"user_content": "Can I customize the UI theme and branding?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "customization", "category": "branding"}'},
    {"user_content": "What new features were released in version 4.0?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "updates", "category": "releases"}'},
    {"user_content": "Does Enterprise plan support multi-factor authentication?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "comparison", "category": "plans"}'},

    # More challenging edge cases
    {"user_content": "Help!", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "functionality", "priority": "medium"}'},
    {"user_content": "I'm having issues", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "functionality", "priority": "medium"}'},
    {"user_content": "My account", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "account_update", "urgency": "medium"}'},
    {"user_content": "Tell me more", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "capabilities"}'},
    {"user_content": "I need assistance", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "functionality", "priority": "medium"}'},

    # Additional diverse samples (5 per agent) to keep the classes roughly balanced.
    # With these, the list holds 115 samples in total: 39 technical support,
    # 39 billing and 37 product info (the distribution print reports live counts).
    # More Technical Support diversity
    {"user_content": "Email notifications stopped arriving in my inbox", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "notification", "priority": "medium"}'},
    {"user_content": "The search function returns no results even when I know data exists", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "search", "priority": "high"}'},
    {"user_content": "Can't download files larger than 100MB through the web interface", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "download", "priority": "medium"}'},
    {"user_content": "My session expires after 5 minutes of inactivity", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "session", "priority": "low"}'},
    {"user_content": "The calendar widget doesn't sync with Google Calendar", "agent_name": "technical_support_agent", "agent_arguments": '{"issue_type": "integration", "priority": "medium"}'},

    # More Billing diversity
    {"user_content": "I want to upgrade from Basic to Pro plan mid-cycle", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "upgrade", "urgency": "medium"}'},
    {"user_content": "Received invoice but payment already processed", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "invoice", "urgency": "high"}'},
    {"user_content": "Need to update tax information for international billing", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "tax_info", "urgency": "medium"}'},
    {"user_content": "My credit card expired, how do I update it?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "payment_method", "urgency": "high"}'},
    {"user_content": "Can I get a receipt for last month's payment?", "agent_name": "billing_agent", "agent_arguments": '{"request_type": "receipt", "urgency": "low"}'},

    # More Product Info diversity
    {"user_content": "What programming languages can I use with your API?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "api"}'},
    {"user_content": "Do you offer white-label solutions for enterprise clients?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "enterprise"}'},
    {"user_content": "What's the difference between your cloud and on-premise versions?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "comparison", "category": "deployment"}'},
    {"user_content": "Can I export my data in JSON format?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "features", "category": "export"}'},
    {"user_content": "What mobile operating systems are supported?", "agent_name": "product_info_agent", "agent_arguments": '{"query_type": "platform", "category": "mobile"}'},
]

from collections import Counter

# Count samples per target agent (single pass) so class balance is visible.
agent_counts = Counter(item["agent_name"] for item in multi_agent_dataset)
tech_count = agent_counts["technical_support_agent"]
billing_count = agent_counts["billing_agent"]
product_count = agent_counts["product_info_agent"]
total_count = len(multi_agent_dataset)

# Every sample must target one of the three routing tools. A misspelled
# agent_name would otherwise silently become a label for a non-existent tool.
unknown_agents = set(agent_counts) - {"technical_support_agent", "billing_agent", "product_info_agent"}
assert not unknown_agents, f"Unknown agent_name values in dataset: {sorted(unknown_agents)}"

# (The emoji below were previously mis-encoded UTF-8, e.g. "‚úÖ" for ✅.)
print(f"✅ Created {total_count} training samples with improved variety and realism")
print(f"📊 Distribution: {tech_count} Technical Support, {billing_count} Billing, {product_count} Product Info")
print("💡 Expanded dataset for slower convergence - model should take 8-10 epochs to converge")


Agent Tool Definitions

In [None]:
import json
from datasets import Dataset
from transformers.utils import get_json_schema

# --- Agent Tool Definitions ---
def technical_support_agent(issue_type: str, priority: str) -> str:
    """
    Routes technical issues to the specialized technical support team.

    Args:
        issue_type: Type of technical issue (e.g., 'crash', 'authentication', 'performance', 'api_error', 'integration')
        priority: Priority level of the issue ('low', 'medium', 'high')
    """
    # The docstring is intentionally left verbatim: get_json_schema() converts the
    # summary and the Args section into the tool description shown to the model,
    # so rewording it would change the training prompts.
    return "Routing to Technical Support: {} with {} priority".format(issue_type, priority)

def billing_agent(request_type: str, urgency: str) -> str:
    """
    Routes billing, payment, subscription, and invoicing queries to the billing department.

    Args:
        request_type: Type of billing request (e.g., 'refund', 'invoice', 'upgrade', 'cancellation', 'dispute')
        urgency: How urgent the request is ('low', 'medium', 'high')
    """
    # Docstring kept verbatim: get_json_schema() turns it into the tool schema.
    message_template = "Routing to Billing: {kind} with {level} urgency"
    return message_template.format(kind=request_type, level=urgency)

def product_info_agent(query_type: str, category: str) -> str:
    """
    Routes product information queries including features, plans, integrations, and capabilities.

    Args:
        query_type: Type of product query (e.g., 'features', 'comparison', 'integrations', 'limits', 'compliance')
        category: Specific category of the query (e.g., 'plans', 'storage', 'mobile', 'security')
    """
    # Docstring kept verbatim: get_json_schema() turns it into the tool schema.
    prefix = "Routing to Product Info: "
    return prefix + "{0} about {1}".format(query_type, category)

In [None]:
# Build one JSON tool schema per routing agent from its signature and docstring.
# Order: technical support, billing, product info.
AGENT_TOOLS = [
    get_json_schema(agent_fn)
    for agent_fn in (technical_support_agent, billing_agent, product_info_agent)
]

In [None]:
DEFAULT_SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent based on the nature of their request."

def create_conversation(sample):
    """Convert one raw routing sample into a chat-format training example.

    Args:
        sample: Row with ``user_content`` (query text), ``agent_name`` (tool the
            assistant should call) and ``agent_arguments`` (JSON-encoded object
            string with that tool's arguments).

    Returns:
        dict with ``messages`` (developer prompt, user query, assistant tool
        call) and ``tools`` (the module-level ``AGENT_TOOLS`` list).

    Raises:
        ValueError: if ``agent_arguments`` is not valid JSON
            (``json.JSONDecodeError`` subclasses ``ValueError``) or does not
            decode to a JSON object.
    """
    # Decode the JSON string into a dict: the assistant tool call must carry
    # structured arguments, not a string. (json.loads does not filter keys.)
    arguments = json.loads(sample["agent_arguments"])
    if not isinstance(arguments, dict):
        raise ValueError(
            f"agent_arguments must encode a JSON object, got {type(arguments).__name__} "
            f"for query {sample['user_content']!r}"
        )

    # NOTE(review): `datasets` stores nested dicts as Arrow structs. Rows whose
    # argument dicts (or tool schemas) have different keys may be unified into
    # one schema with missing keys filled by None after `.map()` — verify the
    # stored rows before training.
    return {
        "messages": [
            {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
            {"role": "user", "content": sample["user_content"]},
            {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": sample["agent_name"], "arguments": arguments}}]},
        ],
        "tools": AGENT_TOOLS
    }

In [None]:
# Build a Hugging Face Dataset from the raw routing samples.
raw_dataset = Dataset.from_list(multi_agent_dataset)

# Convert each row to the conversational format (messages + tools) and drop the
# original flat columns, so only the chat-formatted fields remain.
conversation_dataset = raw_dataset.map(
    create_conversation,
    remove_columns=raw_dataset.column_names,
    batched=False,
)

# Split 80% train / 20% test; `datasets` rounds the test size up. With the 115
# samples currently defined this yields 92 training and 23 test examples.
# The fixed seed keeps the split reproducible across kernel restarts.
dataset = conversation_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

# Guard against accidentally losing samples during mapping/splitting.
assert len(dataset["train"]) + len(dataset["test"]) == len(multi_agent_dataset)

print(f"Training samples: {len(dataset['train'])}")
print(f"Test samples: {len(dataset['test'])}")

Load Model and Tokenizer

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the base checkpoint and its tokenizer.
# dtype="auto" lets transformers pick the checkpoint's dtype, device_map="auto"
# lets accelerate place the weights, and eager attention is requested explicitly.
model_load_kwargs = dict(
    dtype="auto",
    device_map="auto",
    attn_implementation="eager",
)
model = AutoModelForCausalLM.from_pretrained(base_model, **model_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model)

for label, value in (("Device", model.device), ("DType", model.dtype)):
    print(f"{label}: {value}")

In [None]:
# Show one training example: the user query, its target tool call, and how the
# chat template renders the full conversation (including tool declarations).
print("\n--- Sample Dataset Input ---")
sample_data = dataset["train"][0]
user_turn = sample_data["messages"][1]
sample_call = sample_data["messages"][2]["tool_calls"][0]["function"]
print(f"User Query: {user_turn['content']}")
print(f"Expected Agent: {sample_call['name']}")
print(f"Expected Arguments: {sample_call['arguments']}")

debug_msg = tokenizer.apply_chat_template(
    sample_data["messages"],
    tools=sample_data["tools"],
    add_generation_prompt=False,
    tokenize=False,
)
print("\n--- Formatted Prompt (first 500 chars) ---")
print(f"{debug_msg[:500]}...")

Before Fine-tuning Evaluation

In [None]:
# Inspect how the first training sample is actually stored, to debug None values
# appearing in tool-call arguments after the dataset conversion.
print("🔍 Inspecting dataset structure...")
print("\n--- First Training Sample ---")
first_sample = dataset['train'][0]
first_call = first_sample['messages'][2]['tool_calls'][0]['function']
print(f"Messages: {first_sample['messages']}")
print("\nExpected tool call:")
print(f"  Name: {first_call['name']}")
print(f"  Arguments: {first_call['arguments']}")
print(f"  Argument type: {type(first_call['arguments'])}")

# %%
# Debug cell - run the model on a single test query and print the raw output,
# so the emitted tool-call format can be inspected before writing a parser.
print("\n🔍 DEBUG: Testing single sample to see model output format\n")
test_sample = dataset['test'][0]
# Prompt = developer message + user query; the assistant turn is generated.
messages = [test_sample["messages"][0], test_sample["messages"][1]]
expected_call = test_sample['messages'][2]['tool_calls'][0]['function']

inputs = tokenizer.apply_chat_template(messages, tools=AGENT_TOOLS, add_generation_prompt=True, return_dict=True, return_tensors="pt")
out = model.generate(**inputs.to(model.device), pad_token_id=tokenizer.eos_token_id, max_new_tokens=128)
# Decode only the newly generated tokens; keep special tokens because the
# function-call markers are special tokens.
prompt_length = len(inputs["input_ids"][0])
output = tokenizer.decode(out[0][prompt_length:], skip_special_tokens=False)

print(f"Query: {test_sample['messages'][1]['content']}")
print(f"\nExpected Agent: {expected_call['name']}")
print(f"Expected Args: {expected_call['arguments']}")
print(f"\n{'='*80}")
print("Raw Model Output:")
print(output)
print('='*80)

In [None]:
import re

# Agent names searched for when the output contains no well-formed call.
_KNOWN_AGENTS = ("technical_support_agent", "billing_agent", "product_info_agent")

# Whole call: <start_function_call>call:NAME{PARAMS}<end_function_call>.
# PARAMS is non-greedy but anchored on "}<end_function_call>", so braces inside
# escaped values no longer truncate the match, and "{}" (no params) still matches.
_FUNCTION_CALL_PATTERN = re.compile(
    r'<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>', re.DOTALL
)

# One parameter: key:<escape>value<escape>  OR  key:value (unescaped literal).
# DOTALL lets escaped values span newlines.
_PARAM_PATTERN = re.compile(r'(\w+):(?:<escape>(.*?)<escape>|([^,{}]+))', re.DOTALL)

# Signed integer / decimal literals for unescaped values.
_INT_LITERAL = re.compile(r'[+-]?\d+')
_FLOAT_LITERAL = re.compile(r'[+-]?\d+\.\d+')


def _convert_unescaped_value(raw):
    """Convert an unescaped argument value to a Python value.

    ``true``/``false`` (case-insensitive) become bools, signed integers become
    ``int``, signed decimals become ``float``; anything else is returned as the
    stripped string.
    """
    value = raw.strip()
    lowered = value.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    if _INT_LITERAL.fullmatch(value):
        return int(value)
    if _FLOAT_LITERAL.fullmatch(value):
        return float(value)
    return value


def extract_function_call(output, known_functions=_KNOWN_AGENTS):
    """
    Extract function name and arguments from FunctionGemma model output.

    Expected format:
        <start_function_call>call:func_name{param1:<escape>text<escape>,param2:42}<end_function_call>

    Escaped values are returned verbatim as strings (they may contain commas,
    braces or newlines). Unescaped literals are converted: true/false -> bool,
    signed integers -> int, signed decimals -> float, otherwise str.

    Args:
        output: Decoded model output with special tokens kept.
        known_functions: Names checked (in this order) when no well-formed call
            is present; the first one that appears anywhere in ``output`` is
            returned with empty arguments.

    Returns:
        dict with ``"function_name"`` (``"NONE"`` when nothing was recognised)
        and ``"arguments"`` (a dict, empty when unavailable). Never returns None.
    """
    try:
        match = _FUNCTION_CALL_PATTERN.search(output)
        if match:
            func_name, params_str = match.group(1), match.group(2)
            params = {}
            for p_match in _PARAM_PATTERN.finditer(params_str):
                key, val_escaped, val_simple = p_match.groups()
                if val_escaped is not None:
                    params[key] = val_escaped
                else:
                    params[key] = _convert_unescaped_value(val_simple)
            return {"function_name": func_name, "arguments": params}

        # Fallback: the model named an agent but emitted no parseable call.
        for agent in known_functions:
            if agent in output:
                return {"function_name": agent, "arguments": {}}

    except Exception as e:  # best-effort parser: never break the evaluation loop
        print(f"   [Debug] Extract error: {e}")

    # Always return a valid dict, never None
    return {"function_name": "NONE", "arguments": {}}

In [None]:
def evaluate_routing(dataset_split, model, tokenizer, phase="BEFORE", max_new_tokens=128):
    """Evaluate routing accuracy on a dataset split and collect per-sample results.

    Each item's ``"messages"`` list is indexed as: [0] developer/system message,
    [1] user query, [2] assistant turn whose first ``tool_calls`` entry is the
    ground truth. Only messages [0] and [1] are rendered (with ``AGENT_TOOLS``)
    and fed to the model; the completion is parsed with ``extract_function_call``.

    A sample is counted as correct when the parsed function name equals the
    expected agent AND the raw completion mentions no other agent name.

    Args:
        dataset_split: Sized iterable of conversation items (supports ``len()``).
        model: Causal LM used for generation (inputs are moved to ``model.device``).
        tokenizer: Tokenizer/processor providing ``apply_chat_template`` and ``decode``.
        phase: Label used in the printed headers. Raw-output debug lines are printed
            for the first 5 samples only when ``phase == "BEFORE"``.
        max_new_tokens: Generation budget per sample (defaults to the previous 128).

    Returns:
        ``(results, success_count, total_tests, accuracy)`` where ``results`` is a
        list of per-sample dicts and ``accuracy`` is a percentage. For an empty
        split, ``accuracy`` is 0.0 instead of raising ``ZeroDivisionError``.
    """
    # Every agent the router can pick; used to reject completions that mention a
    # different agent than the expected one.
    known_agents = ("technical_support_agent", "billing_agent", "product_info_agent")

    success_count = 0
    total_tests = len(dataset_split)
    results = []

    print(f"\n{'='*100}")
    print(f"üìä {phase} FINE-TUNING EVALUATION - Testing {total_tests} samples")
    print(f"{'='*100}\n")

    for idx, item in enumerate(dataset_split):
        # Prompt = developer message + user query; the assistant turn (index 2)
        # holds the ground truth and is withheld from the model.
        messages = [
            item["messages"][0],
            item["messages"][1],
        ]

        inputs = tokenizer.apply_chat_template(
            messages,
            tools=AGENT_TOOLS,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        out = model.generate(
            **inputs.to(model.device),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )
        # Decode only the newly generated tokens. skip_special_tokens=False keeps any
        # marker tokens (e.g. <start_function_call>) that extract_function_call's regex needs.
        output = tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=False)

        # Extract expected values (only the actual parameters, not all possible ones)
        expected_call = item['messages'][2]['tool_calls'][0]['function']
        expected_agent = expected_call['name']
        expected_args = expected_call['arguments']

        # Extract predicted values
        predicted = extract_function_call(output)
        predicted_agent = predicted['function_name']
        predicted_args = predicted['arguments']

        # Debug: Print raw output for first few samples
        if idx < 5 and phase == "BEFORE":
            print(f"\n[DEBUG] Raw model output for sample {idx+1}:")
            print(output[:300])
            print(f"[DEBUG] Extracted: {predicted_agent}({predicted_args})\n")

        # Correct only if the right agent was called and no competing agent is mentioned
        other_agents = [agent for agent in known_agents if agent != expected_agent]
        is_correct = expected_agent == predicted_agent and not any(agent in output for agent in other_agents)

        # Store result
        result = {
            "query": item['messages'][1]['content'],
            "expected_agent": expected_agent,
            "expected_arguments": expected_args,
            "predicted_agent": predicted_agent,
            "predicted_arguments": predicted_args,
            "raw_output": output[:200],
            "correct": is_correct
        }
        results.append(result)

        # Print result
        status = "‚úÖ CORRECT" if is_correct else "‚ùå WRONG"
        print(f"{idx+1}. Query: {result['query'][:80]}...")
        print(f"   Expected:  {expected_agent}{json.dumps(expected_args, ensure_ascii=False)}")
        print(f"   Predicted: {predicted_agent}{json.dumps(predicted_args, ensure_ascii=False)}")
        print(f"   {status}\n")

        if is_correct:
            success_count += 1

    # Guard the empty-split case, which previously raised ZeroDivisionError.
    accuracy = (success_count / total_tests) * 100 if total_tests else 0.0
    print(f"\n{'='*100}")
    print(f"üéØ Routing Accuracy: {success_count}/{total_tests} ({accuracy:.1f}%)")
    print(f"{'='*100}\n")

    return results, success_count, total_tests, accuracy

In [None]:
print("üîç Evaluating model BEFORE fine-tuning...")
results_before, success_before, total_before, accuracy_before = evaluate_routing(
    dataset['test'], model, tokenizer, phase="BEFORE"
)

**Training Configuration**


**📈 For Even Better Results:**
- Expand dataset to 500+ samples (currently 115)
- Add more edge cases and variations
- Consider data augmentation (paraphrasing queries)
- Fine-tune for more epochs if validation loss keeps decreasing
- Experiment with hyperparameters (learning rate, effective batch size, warmup ratio, number of epochs)

In [None]:
from trl import SFTConfig

torch_dtype = model.dtype

# Push to the Hub only when a real token and repo name are set (not the "xxxx" placeholder)
should_push_to_hub = bool(hf_token and hf_repo_name and hf_token != "xxxx")

# Learning rate actually used for this run. The configuration cell at the top sets
# `learning_rate = 5e-5`, but this run trains with 1e-5. Rebinding the global keeps
# later cells that report `learning_rate` (evaluation metadata, model card) in sync
# with the value passed to SFTConfig below.
learning_rate = 1e-5

# Training configuration aimed at slow, gradual convergence on a small dataset:
# 1. Low learning rate (1e-5) with cosine decay
# 2. Short linear warmup over the first 5% of optimizer steps
# 3. Weight decay of 0.02 for extra regularization
# 4. Up to 15 epochs, evaluating and checkpointing every epoch; the checkpoint with
#    the lowest eval_loss is restored at the end of training
# 5. Effective batch size of 12 per device (4 samples x 3 gradient-accumulation steps)
args = SFTConfig(
    output_dir=checkpoint_dir,              # local directory for checkpoints and the final model
    max_length=512,                         # max tokenized sequence length; longer samples are truncated
    packing=False,                          # one sample per sequence (no packing of multiple samples)
    num_train_epochs=15,                    # upper bound; best epoch is selected by eval_loss
    per_device_train_batch_size=4,          # batch size per device during training
    gradient_accumulation_steps=3,          # effective batch size = 4 * 3 = 12 per device
    gradient_checkpointing=False,           # disabled: KV caching is incompatible with gradient checkpointing
    optim="adamw_torch_fused",              # use fused adamw optimizer
    weight_decay=0.02,                      # L2-style regularization
    warmup_ratio=0.05,                      # linear warmup over the first 5% of steps
    logging_steps=1,                        # log every step
    save_strategy="epoch",                  # save checkpoint every epoch
    save_total_limit=3,                     # keep at most 3 checkpoints on disk; the best one is always retained
    load_best_model_at_end=True,            # load best model (by metric_for_best_model) after training
    metric_for_best_model="eval_loss",      # use validation loss to select best
    greater_is_better=False,                # lower eval_loss is better
    eval_strategy="epoch",                  # must match save_strategy for load_best_model_at_end
    learning_rate=learning_rate,            # 1e-5 (see above)
    fp16=torch_dtype == torch.float16,      # fp16 training when the model was loaded in float16
    bf16=torch_dtype == torch.bfloat16,     # bf16 training when the model was loaded in bfloat16
    lr_scheduler_type="cosine",             # cosine learning rate schedule
    push_to_hub=should_push_to_hub,         # push model to hub only if credentials provided
    hub_model_id=hf_repo_name if should_push_to_hub else None,  # set hub model id if pushing
    report_to="tensorboard",                # report metrics to tensorboard
    dataloader_drop_last=False,             # don't drop last incomplete batch
)

Start Training

In [None]:
from trl import SFTTrainer

# Train on the train split; the test split doubles as the per-epoch eval set
# that drives checkpoint selection (eval_loss).
train_split = dataset['train']
eval_split = dataset['test']

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=args,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

In [None]:
# Run the full training loop. Because the config sets load_best_model_at_end=True,
# the trainer restores the lowest-eval_loss checkpoint when training finishes.
train_result = trainer.train()

# Persist the (best) model to args.output_dir
trainer.save_model()

In [None]:
import matplotlib.pyplot as plt

# Trainer log entries are dicts: training-step entries carry "loss",
# evaluation entries carry "eval_loss"; both carry "epoch".
log_history = trainer.state.log_history

train_logs = [entry for entry in log_history if "loss" in entry]
eval_logs = [entry for entry in log_history if "eval_loss" in entry]

train_losses = [entry["loss"] for entry in train_logs]
epoch_train = [entry["epoch"] for entry in train_logs]
eval_losses = [entry["eval_loss"] for entry in eval_logs]
epoch_eval = [entry["epoch"] for entry in eval_logs]

# Plot training and validation loss against epoch
fig_loss, ax_loss = plt.subplots(figsize=(10, 6))
ax_loss.plot(epoch_train, train_losses, label="Training Loss", marker='o')
ax_loss.plot(epoch_eval, eval_losses, label="Validation Loss", marker='s')
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.set_title("Multi-Agent Router Training Progress")
ax_loss.legend()
ax_loss.grid(True)
plt.show()

## After Fine-tuning Evaluation

In [None]:
print("üîç Evaluating model AFTER fine-tuning...")
results_after, success_after, total_after, accuracy_after = evaluate_routing(
    dataset['test'], model, tokenizer, phase="AFTER"
)

In [None]:
import json
from datetime import datetime

# Gather run metadata, headline metrics and per-sample results into one
# JSON-serializable dict for saving and reporting.
comparison_data = {
    "metadata": {
        "base_model": base_model,
        # Timestamp of when this cell runs (after training and evaluation)
        "training_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "num_train_samples": len(dataset['train']),
        "num_test_samples": len(dataset['test']),
        "num_epochs": args.num_train_epochs,
        # Read from the trainer config so the reported value is the one actually used
        # for training; the top-level `learning_rate` variable can disagree with it.
        "learning_rate": args.learning_rate,
        "batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        # Per-device batch size x accumulation steps (single-device effective batch)
        "effective_batch_size": args.per_device_train_batch_size * args.gradient_accumulation_steps,
    },
    "performance": {
        "before_training": {
            "accuracy": accuracy_before,
            "correct": success_before,
            "total": total_before
        },
        "after_training": {
            "accuracy": accuracy_after,
            "correct": success_after,
            "total": total_after
        },
        "improvement": {
            "accuracy_gain": accuracy_after - accuracy_before,
            "additional_correct": success_after - success_before
        }
    },
    "detailed_results": {
        "before_training": results_before,
        "after_training": results_after
    }
}


In [None]:
# Save to JSON file
output_json_path = f"{checkpoint_dir}/evaluation_results.json"
os.makedirs(checkpoint_dir, exist_ok=True)

with open(output_json_path, 'w', encoding='utf-8') as f:
    json.dump(comparison_data, f, indent=2)

print(f"‚úÖ Results saved to: {output_json_path}")

# Signed deltas: use explicit sign format specs so a regression prints as
# "-2 correct (-3.3%)" instead of "+-2 correct (+-3.3%)".
delta_correct = success_after - success_before
delta_accuracy = accuracy_after - accuracy_before

# Display summary
print("\n" + "="*100)
print("üìä TRAINING RESULTS SUMMARY")
print("="*100)
print(f"Before Training: {success_before}/{total_before} ({accuracy_before:.1f}%)")
print(f"After Training:  {success_after}/{total_after} ({accuracy_after:.1f}%)")
print(f"Improvement:     {delta_correct:+d} correct ({delta_accuracy:+.1f}%)")
print("="*100 + "\n")

# Final Visualization & Analysis

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from IPython.display import display, HTML

# Dashboard grid: 3 rows x 2 columns. Row 2 / column 1 holds the loss curves
# (scatter); every other panel is a bar chart.
panel_titles = (
    'üéØ Routing Accuracy Comparison',
    'üìà Prediction Distribution',
    'üìâ Training Progress',
    'ü§ñ Before Training: Per-Agent Accuracy',
    'ü§ñ After Training: Per-Agent Accuracy',
    'üìä Per-Agent Improvement'
)
panel_specs = [
    [{"type": "bar"}, {"type": "bar"}],
    [{"type": "scatter"}, {"type": "bar"}],
    [{"type": "bar"}, {"type": "bar"}]
]
fig = make_subplots(
    rows=3,
    cols=2,
    subplot_titles=panel_titles,
    specs=panel_specs,
    vertical_spacing=0.12,
    horizontal_spacing=0.15
)

# Panel 1: overall routing accuracy, before vs. after training
phases = ['Before<br>Training', 'After<br>Training']
accuracies = [accuracy_before, accuracy_after]
colors_acc = ['#ff6b6b', '#51cf66']

accuracy_bar = go.Bar(
    x=phases,
    y=accuracies,
    text=[f'{acc:.1f}%' for acc in accuracies],
    textposition='outside',
    marker=dict(color=colors_acc, line=dict(color='black', width=2)),
    showlegend=False
)
fig.add_trace(accuracy_bar, row=1, col=1)

fig.update_yaxes(range=[0, 100], title_text='Accuracy (%)', row=1, col=1)

In [None]:
# Panel 2: stacked counts of correct vs. incorrect predictions per phase
categories = ['Before Training', 'After Training']
correct_counts = [success_before, success_after]
incorrect_counts = [total_before - success_before, total_after - success_after]

for series_name, series_counts, series_color in (
    ('‚úÖ Correct', correct_counts, '#51cf66'),
    ('‚ùå Incorrect', incorrect_counts, '#ff6b6b'),
):
    fig.add_trace(go.Bar(
        x=categories,
        y=series_counts,
        name=series_name,
        marker_color=series_color,
        text=series_counts,
        textposition='inside'
    ), row=1, col=2)

fig.update_yaxes(title_text='Number of Predictions', row=1, col=2)
fig.update_layout(barmode='stack')

In [None]:
# Panel 3: training and validation loss curves over epochs
loss_curves = (
    ('Training Loss', epoch_train, train_losses, '#4c6ef5'),
    ('Validation Loss', epoch_eval, eval_losses, '#f59f00'),
)
for curve_name, curve_x, curve_y, curve_color in loss_curves:
    fig.add_trace(go.Scatter(
        x=curve_x,
        y=curve_y,
        mode='lines+markers',
        name=curve_name,
        line=dict(color=curve_color, width=3),
        marker=dict(size=8)
    ), row=2, col=1)

fig.update_xaxes(title_text='Epoch', row=2, col=1)
fig.update_yaxes(title_text='Loss', row=2, col=1)

In [None]:
# Panel 4: per-agent routing accuracy of the base model
agent_names_short = ['Technical<br>Support', 'Billing', 'Product<br>Info']
agent_accuracy_before = []

for agent_key in ("technical_support_agent", "billing_agent", "product_info_agent"):
    matching = [r for r in results_before if r['expected_agent'] == agent_key]
    if not matching:
        # No test samples for this agent: report 0 instead of dividing by zero
        agent_accuracy_before.append(0)
        continue
    n_correct = sum(1 for r in matching if r['correct'])
    agent_accuracy_before.append((n_correct / len(matching)) * 100)

fig.add_trace(go.Bar(
    x=agent_names_short,
    y=agent_accuracy_before,
    marker=dict(color=['#ff6b6b', '#ff922b', '#ffd43b'], line=dict(color='black', width=2)),
    text=[f'{acc:.0f}%' for acc in agent_accuracy_before],
    textposition='outside',
    showlegend=False
), row=2, col=2)

fig.update_yaxes(range=[0, 100], title_text='Accuracy (%)', row=2, col=2)

In [None]:
# Panel 5: per-agent routing accuracy of the fine-tuned model
agent_accuracy_after = []

for agent_key in ("technical_support_agent", "billing_agent", "product_info_agent"):
    matching = [r for r in results_after if r['expected_agent'] == agent_key]
    if not matching:
        # No test samples for this agent: report 0 instead of dividing by zero
        agent_accuracy_after.append(0)
        continue
    n_correct = sum(1 for r in matching if r['correct'])
    agent_accuracy_after.append((n_correct / len(matching)) * 100)

fig.add_trace(go.Bar(
    x=agent_names_short,
    y=agent_accuracy_after,
    marker=dict(color=['#51cf66', '#69db7c', '#8ce99a'], line=dict(color='black', width=2)),
    text=[f'{acc:.0f}%' for acc in agent_accuracy_after],
    textposition='outside',
    showlegend=False
), row=3, col=1)

fig.update_yaxes(range=[0, 100], title_text='Accuracy (%)', row=3, col=1)

In [None]:
# Panel 6: per-agent accuracy change (after minus before); green for gains, red for losses
improvements = [
    acc_after - acc_before
    for acc_before, acc_after in zip(agent_accuracy_before, agent_accuracy_after)
]
colors_improvement = [('#51cf66' if delta >= 0 else '#ff6b6b') for delta in improvements]

fig.add_trace(go.Bar(
    x=agent_names_short,
    y=improvements,
    marker=dict(color=colors_improvement, line=dict(color='black', width=2)),
    text=[f'{delta:+.0f}%' for delta in improvements],
    textposition='outside',
    showlegend=False
), row=3, col=2)

# Zero reference line separating gains from regressions
fig.add_hline(y=0, line_dash="dash", line_color="black", row=3, col=2)
fig.update_yaxes(title_text='Accuracy Improvement (%)', row=3, col=2)

# Global styling for the whole dashboard
fig.update_layout(
    height=1200,
    showlegend=True,
    template='plotly_white',
    font=dict(size=12),
    title_text="Multi-Agent Router: Complete Training Analysis",
    title_font_size=20
)

fig.show()

# Save an interactive HTML copy next to the checkpoints
html_path = f'{checkpoint_dir}/training_analysis_interactive.html'
fig.write_html(html_path)
print(f"‚úÖ Interactive visualization saved to: {html_path}")

# Final Summary Report

In [None]:
# Signed deltas: explicit sign format specs render a regression as "-2" rather than "+-2".
delta_correct = success_after - success_before
delta_accuracy = accuracy_after - accuracy_before
# Report the accelerator actually present instead of assuming a Colab T4.
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA GPU detected"
# round() rather than int(): e.g. 0.29 * 100 == 28.999..., which int() would show as 28%.
warmup_pct = round(args.warmup_ratio * 100)

summary_html = f"""
<div style="background: white; padding: 30px; border-radius: 15px; border: 3px solid #4c6ef5; font-family: Arial, sans-serif; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
    <h1 style="text-align: center; margin-bottom: 30px; color: #2c3e50;">üéâ Multi-Agent Router Training Complete!</h1>

    <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin-bottom: 20px; border-left: 5px solid #4c6ef5;">
        <h2 style="color: #2c3e50;">üìä Performance Metrics</h2>
        <table style="width: 100%; color: #2c3e50; font-size: 16px; border-collapse: collapse;">
            <tr style="border-bottom: 2px solid #dee2e6;">
                <td style="padding: 12px;"><strong>Before Training:</strong></td>
                <td style="text-align: right; font-size: 20px; padding: 12px;">{success_before}/{total_before} ({accuracy_before:.1f}%)</td>
            </tr>
            <tr style="border-bottom: 2px solid #dee2e6;">
                <td style="padding: 12px;"><strong>After Training:</strong></td>
                <td style="text-align: right; font-size: 20px; padding: 12px; color: #51cf66; font-weight: bold;">{success_after}/{total_after} ({accuracy_after:.1f}%)</td>
            </tr>
            <tr style="background: #e7f5ff;">
                <td style="padding: 12px;"><strong>Improvement:</strong></td>
                <td style="text-align: right; font-size: 24px; padding: 12px; color: #f59f00; font-weight: bold;">{delta_correct:+d} correct ({delta_accuracy:+.1f}%)</td>
            </tr>
        </table>
    </div>

    <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin-bottom: 20px; border-left: 5px solid #51cf66;">
        <h2 style="color: #2c3e50;">ü§ñ Agent-Specific Performance</h2>
        <table style="width: 100%; color: #2c3e50; font-size: 14px; border-collapse: collapse;">
            <tr style="background: #e7f5ff; border-bottom: 2px solid #4c6ef5;">
                <th style="padding: 12px; text-align: left;">Agent</th>
                <th style="padding: 12px; text-align: center;">Before</th>
                <th style="padding: 12px; text-align: center;">After</th>
                <th style="padding: 12px; text-align: center;">Improvement</th>
            </tr>
            <tr style="border-bottom: 1px solid #dee2e6;">
                <td style="padding: 10px;"><strong>Technical Support</strong></td>
                <td style="padding: 10px; text-align: center;">{agent_accuracy_before[0]:.0f}%</td>
                <td style="padding: 10px; text-align: center; color: #51cf66; font-weight: bold;">{agent_accuracy_after[0]:.0f}%</td>
                <td style="padding: 10px; text-align: center; color: #f59f00; font-weight: bold;">{improvements[0]:+.0f}%</td>
            </tr>
            <tr style="background: #f8f9fa; border-bottom: 1px solid #dee2e6;">
                <td style="padding: 10px;"><strong>Billing</strong></td>
                <td style="padding: 10px; text-align: center;">{agent_accuracy_before[1]:.0f}%</td>
                <td style="padding: 10px; text-align: center; color: #51cf66; font-weight: bold;">{agent_accuracy_after[1]:.0f}%</td>
                <td style="padding: 10px; text-align: center; color: #f59f00; font-weight: bold;">{improvements[1]:+.0f}%</td>
            </tr>
            <tr style="border-bottom: 1px solid #dee2e6;">
                <td style="padding: 10px;"><strong>Product Info</strong></td>
                <td style="padding: 10px; text-align: center;">{agent_accuracy_before[2]:.0f}%</td>
                <td style="padding: 10px; text-align: center; color: #51cf66; font-weight: bold;">{agent_accuracy_after[2]:.0f}%</td>
                <td style="padding: 10px; text-align: center; color: #f59f00; font-weight: bold;">{improvements[2]:+.0f}%</td>
            </tr>
        </table>
    </div>

    <div style="background: #f8f9fa; padding: 20px; border-radius: 10px; margin-bottom: 20px; border-left: 5px solid #f59f00;">
        <h2 style="color: #2c3e50;">‚öôÔ∏è Training Configuration</h2>
        <ul style="font-size: 14px; line-height: 1.8; color: #495057;">
            <li><strong>Base Model:</strong> {base_model}</li>
            <li><strong>Training Samples:</strong> {len(dataset['train'])}</li>
            <li><strong>Test Samples:</strong> {len(dataset['test'])}</li>
            <li><strong>Epochs:</strong> {args.num_train_epochs}</li>
            <li><strong>Batch Size:</strong> {args.per_device_train_batch_size}</li>
            <li><strong>Learning Rate:</strong> {args.learning_rate}</li>
            <li><strong>Weight Decay:</strong> {args.weight_decay}</li>
            <li><strong>Scheduler:</strong> {args.lr_scheduler_type} (with {warmup_pct}% warmup)</li>
            <li><strong>GPU:</strong> {gpu_name}</li>
        </ul>
    </div>

    <div style="background: #e7f5ff; padding: 20px; border-radius: 10px; border-left: 5px solid #4c6ef5;">
        <h2 style="color: #2c3e50;">‚úÖ Next Steps</h2>
        <ol style="font-size: 14px; line-height: 1.8; color: #495057;">
            <li>Review the evaluation results JSON file</li>
            <li>Test the model with your own queries</li>
            <li>Integrate into your multi-agent system</li>
            <li>Monitor performance in production</li>
            <li>Expand dataset with edge cases if needed</li>
        </ol>
    </div>
</div>
"""

display(HTML(summary_html))

In [None]:
# Print the locations of the artifacts written under the checkpoint directory.
print("\n" + "="*100)
print("üìÅ OUTPUT FILES")
print("="*100)
print(f"‚úÖ Model checkpoints: {checkpoint_dir}")
print(f"‚úÖ Evaluation results: {checkpoint_dir}/evaluation_results.json")
# The dashboard is saved via fig.write_html as interactive HTML; this section of the
# notebook never exports a PNG, so point at the file that actually exists.
print(f"‚úÖ Training visualization: {checkpoint_dir}/training_analysis_interactive.html")
print("="*100 + "\n")


Push Dataset to Hugging Face Hub

In [None]:
# Save dataset to local files first
import os
import json

# Create dataset directory
dataset_dir = "multiagent_router_dataset"
os.makedirs(dataset_dir, exist_ok=True)

# Save the raw (pre-split) dataset as a single JSON array
with open(f"{dataset_dir}/dataset.json", "w") as f:
    json.dump(multi_agent_dataset, f, indent=2)

print(f"‚úÖ Raw dataset saved to {dataset_dir}/dataset.json")


def _flatten_conversation(item):
    """Flatten one chat-formatted sample into a flat record for publishing.

    Reads the developer/system message (index 0), the user query (index 1) and
    the first tool call of the assistant message (index 2).
    """
    tool_call = item["messages"][2]["tool_calls"][0]["function"]
    return {
        "query": item["messages"][1]["content"],
        "agent_name": tool_call["name"],
        "agent_arguments": tool_call["arguments"],
        "system_message": item["messages"][0]["content"],
    }


def _write_jsonl(records, path):
    """Write every sample in ``records`` to ``path`` as one flattened JSON object per line."""
    with open(path, "w") as f:
        for item in records:
            f.write(json.dumps(_flatten_conversation(item)) + "\n")


# Save train and test splits as JSON Lines format (common for HF datasets)
_write_jsonl(dataset['train'], f"{dataset_dir}/train.jsonl")
_write_jsonl(dataset['test'], f"{dataset_dir}/test.jsonl")

print(f"‚úÖ Train split saved: {len(dataset['train'])} samples")
print(f"‚úÖ Test split saved: {len(dataset['test'])} samples")


In [None]:
# Create README.md for the dataset
from collections import Counter

# Per-agent sample counts are computed from the raw dataset, so the card always
# matches the data being uploaded (hard-coded counts drift as the dataset grows).
agent_counts = Counter(item["agent_name"] for item in multi_agent_dataset)
total_samples = sum(agent_counts.values())
agent_display_names = {
    "technical_support_agent": "Technical Support",
    "billing_agent": "Billing",
    "product_info_agent": "Product Info",
}
n_technical = agent_counts.get("technical_support_agent", 0)
n_billing = agent_counts.get("billing_agent", 0)
n_product = agent_counts.get("product_info_agent", 0)
# One markdown table row per agent label present in the data, as a share of all samples.
agent_distribution_rows = "\n".join(
    f"| {agent_display_names.get(name, name)} | {count} | {count / total_samples:.0%} |"
    for name, count in agent_counts.items()
)

# NOTE: transformers' get_json_schema requires every argument to be described in an
# "Args:" docstring section, so the fine-tuning example below documents each parameter.
readme_content = f"""---
language:
- en
license: apache-2.0
task_categories:
- text-classification
- question-answering
tags:
- function-calling
- multi-agent
- routing
- customer-support
- synthetic
pretty_name: Multi-Agent Router Fine-tuning Dataset
size_categories:
- n<1K
---

# Multi-Agent Router Fine-tuning Dataset

## Dataset Description

This dataset is designed for fine-tuning language models to perform intelligent routing in multi-agent customer support systems. The model learns to classify user queries and route them to the appropriate specialized agent with relevant parameters.

### Supported Tasks

- **Function Calling**: Route queries to appropriate agent functions
- **Intent Classification**: Identify the type of support needed
- **Parameter Extraction**: Extract relevant parameters from queries

## Dataset Structure

### Data Instances

Each instance contains:
- `query`: The user's question or request
- `agent_name`: The target agent to handle the query (technical_support_agent, billing_agent, or product_info_agent)
- `agent_arguments`: JSON object with parameters for the agent
- `system_message`: System prompt for the model

Example:
```json
{{
  "query": "My app keeps crashing when I try to upload photos larger than 5MB",
  "agent_name": "technical_support_agent",
  "agent_arguments": {{
    "issue_type": "crash",
    "priority": "high"
  }},
  "system_message": "You are an intelligent routing agent..."
}}
```

### Data Fields

- **query** (string): User's question or request
- **agent_name** (string): Target agent name
  - `technical_support_agent`: Technical issues, bugs, integration
  - `billing_agent`: Payments, subscriptions, invoices
  - `product_info_agent`: Features, plans, integrations
- **agent_arguments** (dict): Agent-specific parameters
  - Technical Support: `issue_type`, `priority`
  - Billing: `request_type`, `urgency`
  - Product Info: `query_type`, `category`
- **system_message** (string): System prompt

### Data Splits

| Split | Examples |
|-------|----------|
| train | {len(dataset['train'])} |
| test  | {len(dataset['test'])} |

## Dataset Creation

### Curation Rationale

This dataset was created to train routing models for multi-agent customer support systems. Real-world customer support requires:
- Accurate classification of query intent
- Extraction of priority/urgency levels
- Routing to specialized agents

### Source Data

#### Initial Data Collection and Normalization

The dataset consists of {total_samples} synthetic but realistic customer support queries covering:
- **Technical Support** ({n_technical} samples): App crashes, API errors, authentication issues, performance problems
- **Billing** ({n_billing} samples): Refunds, payment failures, subscription management, pricing inquiries
- **Product Information** ({n_product} samples): Feature comparisons, integrations, compliance questions, platform capabilities
- **Edge Cases**: Ambiguous queries to test robustness (included in the agent distribution counts below)

Queries were designed to be:
- Realistic and varied
- Include specific details (error codes, product names, numeric values)
- Cover different priority/urgency levels
- Include edge cases and ambiguous requests

## Usage

### Load Dataset

```python
from datasets import load_dataset

dataset = load_dataset("bhaiyahnsingh45/multiagent-router-finetuning")

# Access splits
train_data = dataset['train']
test_data = dataset['test']

# Example usage
for example in train_data:
    print(f"Query: {{example['query']}}")
    print(f"Agent: {{example['agent_name']}}")
    print(f"Arguments: {{example['agent_arguments']}}")
```

### Fine-tuning Example

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import get_json_schema

# Define your agent functions (get_json_schema requires every argument to be
# described in the docstring's Args section)
def technical_support_agent(issue_type: str, priority: str) -> str:
    \"\"\"
    Routes technical issues to specialized support team.

    Args:
        issue_type: Type of technical issue (crash, authentication, performance, api_error, etc.)
        priority: Priority level (low, medium, high)
    \"\"\"
    pass

def billing_agent(request_type: str, urgency: str) -> str:
    \"\"\"
    Routes billing and payment queries.

    Args:
        request_type: Type of request (refund, invoice, upgrade, cancellation, etc.)
        urgency: How urgent (low, medium, high)
    \"\"\"
    pass

def product_info_agent(query_type: str, category: str) -> str:
    \"\"\"
    Routes product information queries.

    Args:
        query_type: Type of query (features, comparison, integrations, limits, etc.)
        category: Category (plans, storage, mobile, security, etc.)
    \"\"\"
    pass

# Get tool schemas
tools = [
    get_json_schema(technical_support_agent),
    get_json_schema(billing_agent),
    get_json_schema(product_info_agent)
]

# Format for training (example with FunctionGemma)
def create_conversation(sample):
    return {{
        "messages": [
            {{"role": "developer", "content": sample["system_message"]}},
            {{"role": "user", "content": sample["query"]}},
            {{"role": "assistant", "tool_calls": [{{
                "type": "function",
                "function": {{
                    "name": sample["agent_name"],
                    "arguments": sample["agent_arguments"]
                }}
            }}]}}
        ],
        "tools": tools
    }}

# Apply to dataset
dataset = dataset.map(create_conversation)
```

## Dataset Statistics

### Query Length Distribution

- **Min tokens**: ~5
- **Max tokens**: ~25
- **Average tokens**: ~12

### Agent Distribution

| Agent | Count | Percentage |
|-------|-------|------------|
{agent_distribution_rows}

### Parameter Distribution

**Technical Support - Priority Levels:**
- High: ~50%
- Medium: ~40%
- Low: ~10%

**Billing - Urgency Levels:**
- High: ~30%
- Medium: ~40%
- Low: ~30%

## Evaluation

Expected model performance after fine-tuning:
- **Baseline accuracy**: 10-30% (pre-trained model)
- **Target accuracy**: 70-95% (fine-tuned model)
- **Training time**: ~5-10 minutes on T4 GPU

## Considerations for Using the Data

### Social Impact

This dataset helps improve automated customer support systems by:
- Reducing wait times through accurate routing
- Improving first-contact resolution rates
- Enabling 24/7 support capabilities

### Limitations

- Synthetic data may not cover all real-world variations
- English language only
- Limited to three agent types
- May require domain adaptation for specific industries

## Additional Information

### Dataset Curators

Created for fine-tuning FunctionGemma and similar function-calling models.

### Licensing Information

Apache 2.0 License

### Citation Information

```bibtex
@dataset{{multiagent_router_finetuning,
  author = {{Your Name}},
  title = {{Multi-Agent Router Fine-tuning Dataset}},
  year = {{2025}},
  publisher = {{Hugging Face}},
  url = {{https://huggingface.co/datasets/bhaiyahnsingh45/multiagent-router-finetuning}}
}}
```

### Contributions

Contributions to expand this dataset are welcome! Areas for improvement:
- Additional languages
- More agent types (sales, feedback, onboarding)
- Domain-specific variations (healthcare, finance, e-commerce)
- Real user query examples (with proper anonymization)

### Contact

For questions or feedback, please open an issue on the dataset repository.

---

**Note**: This is a synthetic dataset created for training purposes. For production use, consider augmenting with real anonymized customer queries from your specific domain.
"""

# Save README
with open(f"{dataset_dir}/README.md", "w", encoding="utf-8") as f:
    f.write(readme_content)

print(f"‚úÖ README.md created")


In [None]:
# Upload to Hugging Face Hub
from huggingface_hub import create_repo, login, upload_folder

# Configuration
hf_dataset_repo = "bhaiyahnsingh45/multiagent-router-finetuning"  # @param {type:"string"}

print("\nüöÄ Uploading dataset to Hugging Face Hub...")
print(f"Repository: {hf_dataset_repo}")

# Authentication relies on the credentials stored by the earlier login() call.
try:
    # upload_folder() needs an existing repository, so create the dataset repo first;
    # exist_ok=True makes this a no-op when the repo already exists (e.g. on re-runs).
    create_repo(hf_dataset_repo, repo_type="dataset", exist_ok=True)

    # Upload the dataset folder (dataset.json, train/test JSONL, README.md)
    upload_folder(
        folder_path=dataset_dir,
        repo_id=hf_dataset_repo,
        repo_type="dataset",
        commit_message="Initial dataset upload: Multi-Agent Router fine-tuning dataset"
    )

    print(f"\n‚úÖ Dataset successfully uploaded!")
    print(f"üîó View your dataset at: https://huggingface.co/datasets/{hf_dataset_repo}")
    print(f"\nüìö Users can now load it with:")
    print(f'   from datasets import load_dataset')
    print(f'   dataset = load_dataset("{hf_dataset_repo}")')

except Exception as e:
    # Best-effort upload: report the failure and keep the local copy usable.
    print(f"\n‚ùå Error uploading dataset: {e}")
    print("\nüí° Troubleshooting:")
    print("1. Make sure you're logged in: huggingface-cli login")
    print("2. Check repository name format: username/dataset-name")
    print("3. Ensure you have write permissions")
    print(f"\nüìÅ Dataset files are saved locally in: {dataset_dir}")
    print("   You can manually upload them to Hugging Face if needed.")


Push Model to Hugging Face Hub


In [None]:
# Check if we have valid credentials for pushing
if hf_token and hf_repo_name and hf_token != "xxxx":
    print("üöÄ Pushing model to Hugging Face Hub...")
    try:
        # Push the model
        trainer.push_to_hub(commit_message="Multi-agent router fine-tuned model")

        # Also push tokenizer
        tokenizer.push_to_hub(hf_repo_name)

        print(f"‚úÖ Model successfully pushed to: https://huggingface.co/{hf_repo_name}")

        # Create comprehensive model card with inference code
        model_card = f"""---
language:
- en
license: gemma
library_name: transformers
tags:
- function-calling
- multi-agent
- router
- gemma
- fine-tuned
- customer-support
base_model: google/functiongemma-270m-it
datasets:
- bhaiyahnsingh45/multiagent-router-finetuning
metrics:
- accuracy
pipeline_tag: text-generation
widget:
- text: "My app keeps crashing when I upload large files"
  example_title: "Technical Issue"
- text: "I need a refund for my subscription"
  example_title: "Billing Request"
- text: "What integrations do you support?"
  example_title: "Product Info"
---

# Multi-Agent Router (Fine-tuned FunctionGemma 270M)

<div align="center">
  <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png" alt="Hugging Face" width="100"/>

  **Intelligent routing model for multi-agent customer support systems**

  [![License: Gemma](https://img.shields.io/badge/License-Gemma-blue.svg)](https://ai.google.dev/gemma/terms)
  [![Model: FunctionGemma](https://img.shields.io/badge/Model-FunctionGemma-orange.svg)](https://huggingface.co/google/functiongemma-270m-it)
  [![Dataset](https://img.shields.io/badge/Dataset-Available-green.svg)](https://huggingface.co/datasets/bhaiyahnsingh45/multiagent-router-finetuning)
</div>

## üìã Model Description

This model is a **fine-tuned version of Google's FunctionGemma 270M** specifically trained for intelligent routing in multi-agent customer support systems. It learns to:

1. **Classify user intent** from natural language queries
2. **Route to the appropriate specialist agent**
3. **Extract relevant parameters** (priority, urgency, category)

### ü§ñ Supported Agents

The model routes queries to three specialized agents:

| Agent | Handles | Parameters |
|-------|---------|------------|
| üîß **Technical Support** | Crashes, bugs, API errors, authentication issues | `issue_type`, `priority` |
| üí∞ **Billing** | Payments, refunds, subscriptions, invoices | `request_type`, `urgency` |
| üìä **Product Info** | Features, integrations, plans, compliance | `query_type`, `category` |

## üéØ Training Details

### Base Model
- **Model**: `google/functiongemma-270m-it`
- **Parameters**: 270 Million
- **Architecture**: Gemma with function calling capabilities

### Fine-tuning Configuration
- **Training Samples**: {len(dataset['train'])}
- **Test Samples**: {len(dataset['test'])}
- **Epochs**: {args.num_train_epochs}
- **Batch Size**: {args.per_device_train_batch_size}
- **Learning Rate**: {learning_rate}
- **GPU**: NVIDIA T4 (Google Colab Free Tier)
- **Training Time**: ~5-8 minutes

### Dataset
Fine-tuned on [bhaiyahnsingh45/multiagent-router-finetuning](https://huggingface.co/datasets/bhaiyahnsingh45/multiagent-router-finetuning) containing 85 realistic customer support queries across three categories.

## üìä Performance

| Metric | Before Training | After Training | Improvement |
|--------|----------------|----------------|-------------|
| **Accuracy** | {accuracy_before:.1f}% | {accuracy_after:.1f}% | **+{accuracy_after - accuracy_before:.1f}%** |
| **Correct Predictions** | {success_before}/{total_before} | {success_after}/{total_after} | +{success_after - success_before} |

### Per-Agent Performance
- **Technical Support**: High accuracy on crash reports, API errors, authentication issues
- **Billing**: Excellent routing for refunds, payments, subscription management
- **Product Info**: Strong performance on feature queries, integrations, compliance questions

## üöÄ Quick Start

### Installation

```bash
pip install transformers torch
```

### Basic Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import json

# Load model and tokenizer
model_name = "{hf_repo_name}"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto"
)

# Define your agent tools
from transformers.utils import get_json_schema

def technical_support_agent(issue_type: str, priority: str) -> str:
    \"\"\"
    Routes technical issues to specialized support team.

    Args:
        issue_type: Type of technical issue (crash, authentication, performance, api_error, etc.)
        priority: Priority level (low, medium, high)
    \"\"\"
    return f"Routing to Technical Support: {{issue_type}} with {{priority}} priority"

def billing_agent(request_type: str, urgency: str) -> str:
    \"\"\"
    Routes billing and payment queries.

    Args:
        request_type: Type of request (refund, invoice, upgrade, cancellation, etc.)
        urgency: How urgent (low, medium, high)
    \"\"\"
    return f"Routing to Billing: {{request_type}} with {{urgency}} urgency"

def product_info_agent(query_type: str, category: str) -> str:
    \"\"\"
    Routes product information queries.

    Args:
        query_type: Type of query (features, comparison, integrations, limits, etc.)
        category: Category (plans, storage, mobile, security, etc.)
    \"\"\"
    return f"Routing to Product Info: {{query_type}} about {{category}}"

# Get tool schemas
AGENT_TOOLS = [
    get_json_schema(technical_support_agent),
    get_json_schema(billing_agent),
    get_json_schema(product_info_agent)
]

# System message
SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."

# Function to route queries
def route_query(user_query: str):
    \"\"\"Route a user query to the appropriate agent\"\"\"

    messages = [
        {{"role": "developer", "content": SYSTEM_MSG}},
        {{"role": "user", "content": user_query}}
    ]

    # Format prompt
    inputs = tokenizer.apply_chat_template(
        messages,
        tools=AGENT_TOOLS,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    )

    # Generate
    outputs = model.generate(
        **inputs.to(model.device),
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode
    result = tokenizer.decode(
        outputs[0][len(inputs["input_ids"][0]):],
        skip_special_tokens=False
    )

    return result

# Example usage
query = "My app crashes when I try to upload large files"
result = route_query(query)
print(f"Query: {{query}}")
print(f"Routing: {{result}}")
```

### Expected Output Format

```
<start_function_call>call:technical_support_agent{{issue_type:crash,priority:high}}<end_function_call>
```

## 💡 Usage Examples

### Example 1: Technical Issue
```python
query = "I'm getting a 500 error when calling the API"
result = route_query(query)
# Expected routing: technical_support_agent(issue_type="api_error", priority="high")
```

### Example 2: Billing Request
```python
query = "I need a refund for my annual subscription"
result = route_query(query)
# Expected routing: billing_agent(request_type="refund", urgency="medium")
```

### Example 3: Product Question
```python
query = "What integrations do you support for project management?"
result = route_query(query)
# Expected routing: product_info_agent(query_type="integrations", category="project_management")
```

## 🔧 Advanced Usage: Parse Function Calls

```python
def parse_function_call(output: str) -> dict:
    \"\"\"Extract function name and arguments from model output\"\"\"

    pattern = r'<start_function_call>call:(\\w+)\\{{([^}}]+)\\}}<end_function_call>'
    match = re.search(pattern, output)

    if match:
        func_name = match.group(1)
        params_str = match.group(2)

        # Parse parameters
        params = {{}}
        param_pattern = r'(\\w+):(?:<escape>(.*?)<escape>|([^,{{}}]+))'
        for p_match in re.finditer(param_pattern, params_str):
            key = p_match.group(1)
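            # group(2) is "" for an empty <escape><escape> value; only fall back to group(3) when group(2) did not match
            val = p_match.group(2) if p_match.group(2) is not None else p_match.group(3).strip()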
            params[key] = val

        return {{
            "agent": func_name,
            "parameters": params
        }}

    return {{"agent": "unknown", "parameters": {{}}}}

# Use it
query = "I was charged twice this month"
result = route_query(query)
parsed = parse_function_call(result)
print(parsed)
# Output: {{'agent': 'billing_agent', 'parameters': {{'request_type': 'dispute', 'urgency': 'high'}}}}
```

## 🏗️ Integration Example

```python
class MultiAgentRouter:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype="auto"
        )
        self.system_msg = SYSTEM_MSG  # reuse the full prompt from Basic Usage, not a truncated placeholder

    def route(self, query: str) -> dict:
        \"\"\"Route query and return agent + parameters\"\"\"
        messages = [
            {{"role": "developer", "content": self.system_msg}},
            {{"role": "user", "content": query}}
        ]

        inputs = self.tokenizer.apply_chat_template(
            messages,
            tools=AGENT_TOOLS,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        )

        outputs = self.model.generate(
            **inputs.to(self.model.device),
            max_new_tokens=128,
            pad_token_id=self.tokenizer.eos_token_id
        )

        result = self.tokenizer.decode(
            outputs[0][len(inputs["input_ids"][0]):],
            skip_special_tokens=False
        )

        return parse_function_call(result)

# Usage
router = MultiAgentRouter("{hf_repo_name}")
routing = router.route("My payment failed but I don't know why")
print(f"Route to: {{routing['agent']}}")
print(f"Parameters: {{routing['parameters']}}")
```

## 📈 Evaluation

The model was evaluated on a held-out test set of {total_after} queries:

- **Routing Accuracy**: {accuracy_after:.1f}%
- **Misrouting Rate (error rate)**: {(100 - accuracy_after):.1f}%
- **Average Inference Time**: ~50ms on T4 GPU

## ⚠️ Limitations

1. **Language**: Currently supports English only
2. **Domain**: Optimized for customer support; may need fine-tuning for other domains
3. **Agents**: Limited to 3 agent types (can be extended with additional training)
4. **Context**: Works best with single-turn queries; multi-turn conversations may need context handling
5. **Edge Cases**: Ambiguous queries may require fallback logic

## 🔮 Future Improvements

- [ ] Add support for more languages
- [ ] Expand to 5+ agent types (sales, feedback, onboarding)
- [ ] Handle multi-turn conversations
- [ ] Add confidence scores for routing decisions
- [ ] Support for compound queries requiring multiple agents

## 📝 Citation

```bibtex
@misc{{functiongemma_multiagent_router,
  author = {{Bhaiya Singh}},
  title = {{Multi-Agent Router: Fine-tuned FunctionGemma for Customer Support}},
  year = {{2025}},
  publisher = {{Hugging Face}},
  howpublished = {{\\url{{https://huggingface.co/{hf_repo_name}}}}}
}}
```

## 📄 License

This model inherits the [Gemma License](https://ai.google.dev/gemma/terms) from the base model.

## 🙏 Acknowledgments

- Base model: [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it)
- Training framework: [Hugging Face TRL](https://github.com/huggingface/trl)
- Dataset: [bhaiyahnsingh45/multiagent-router-finetuning](https://huggingface.co/datasets/bhaiyahnsingh45/multiagent-router-finetuning)

## 📧 Contact

For questions, issues, or collaboration opportunities:
- Open an issue on the [model repository](https://huggingface.co/{hf_repo_name})
- Dataset issues: [dataset repository](https://huggingface.co/datasets/bhaiyahnsingh45/multiagent-router-finetuning)

---

**Built with ❤️ using FunctionGemma and Hugging Face Transformers**
"""

        # Save model card
        os.makedirs(checkpoint_dir, exist_ok=True)
        with open(f"{checkpoint_dir}/README.md", "w", encoding="utf-8") as f:
            f.write(model_card)

        print("‚úÖ Model card created")

        # Push README separately to ensure it's uploaded
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_file(
            path_or_fileobj=f"{checkpoint_dir}/README.md",
            path_in_repo="README.md",
            repo_id=hf_repo_name,
            repo_type="model",
            commit_message="Add comprehensive model card with usage examples"
        )
        print("‚úÖ Model card pushed to hub")

    except Exception as e:
        print(f"‚ùå Error pushing to hub: {e}")
        print("üí° You can manually push later using: trainer.push_to_hub()")
else:
    print("‚ö†Ô∏è Skipping Hugging Face upload (no token or repo name provided)")
    print("üí° To push later, set hf_token and hf_repo_name variables and run:")
    print("   trainer.push_to_hub()")

# Final status message for the notebook run.
# Restores the intended 🎉 emoji, which had been mis-encoded (UTF-8 bytes read as Mac Roman).
print("\n🎉 Training pipeline complete!")