# GSPO Training for Tool-Calling

**Portable notebook for Unsloth Docker + any NVIDIA GPU**

## Quick Start

### Option 1: Unsloth Docker (Recommended)
```bash
# On any GPU cloud (Nebius VM, RunPod, Lambda, etc.) or local:
docker run -d -e JUPYTER_PASSWORD="your_password" \
  -p 8888:8888 \
  -v $(pwd)/work:/workspace/work \
  --gpus all \
  unsloth/unsloth
```
Then open `http://localhost:8888` (or `http://<server-ip>:8888` on a remote VM), log in with the `JUPYTER_PASSWORD` you set, and upload this notebook.

### Option 2: Nebius Managed JupyterLab
Use Nebius's managed JupyterLab service and run the install cell below (takes ~3 min).

---

## What is GSPO?

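**GSPO (Group Sequence Policy Optimization)** is a variant of GRPO. During training it samples a group of completions for each prompt and scores them with custom reward functions. It then computes importance ratios and clipping at the sequence level rather than per token (`importance_sampling_level="sequence"` in the trainer config below).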

| Method | Best For | Tool-Calling Use Case |
|--------|----------|----------------------|
| **SFT** | Teaching new skills | Initial tool syntax |
| **KTO** | Refining preferences | Good vs bad choices |
| **GSPO** | Reward optimization | Tool selection quality |

## 1. Environment Check

In [None]:
import torch
import os
import shutil

print("=" * 50)
print("ENVIRONMENT CHECK")
print("=" * 50)

# GPU
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"CUDA: {torch.version.cuda}")
else:
    print("WARNING: No GPU detected!")

# Detect environment
IN_UNSLOTH_DOCKER = os.path.exists("/workspace/unsloth-notebooks")
IN_NEBIUS = os.path.exists(os.path.expanduser("~/persistent"))

print(f"\nUnsloth Docker: {'Yes' if IN_UNSLOTH_DOCKER else 'No'}")
print(f"Nebius JupyterLab: {'Yes' if IN_NEBIUS else 'No'}")

# Set work directory
if IN_UNSLOTH_DOCKER:
    WORK_DIR = "/workspace/work"
elif IN_NEBIUS:
    WORK_DIR = os.path.expanduser("~/persistent")
else:
    WORK_DIR = os.getcwd()

OUTPUT_DIR = os.path.join(WORK_DIR, "GSPO_Training")
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\nWork dir: {WORK_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

# Disk space
total, used, free = shutil.disk_usage(WORK_DIR)
print(f"Disk: {free // (2**30)} GB free")

In [None]:
# Install dependencies (ONLY if NOT in Unsloth Docker)
# Skip this cell if using unsloth/unsloth Docker image
# (the cell also checks IN_UNSLOTH_DOCKER, set by the environment-check cell,
# and skips the install on its own when that flag is True).
#
# The `!pip ...` lines are IPython shell escapes, so this cell only runs inside
# Jupyter/IPython, not as a plain .py script.
# NOTE(review): the `cu121-torch240` extra targets a CUDA 12.1 / torch 2.4
# stack, and the second command then upgrades (-U) packages that are built
# against specific torch versions (xformers, triton, bitsandbytes) — confirm
# the resulting set is mutually compatible on the target GPU/driver.
# NOTE(review): the trainer cell passes `importance_sampling_level` to
# GRPOConfig, which presumably needs a recent TRL release — verify the
# installed TRL version supports it.
# NOTE(review): packages installed into an already-running kernel are usually
# only importable cleanly after a kernel restart — confirm before continuing.

if not IN_UNSLOTH_DOCKER:
    print("Installing Unsloth (this takes ~3 minutes)...")
    !pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git" -q
    !pip install -U trl peft datasets accelerate bitsandbytes xformers triton -q
    print("Done!")
else:
    print("Unsloth Docker detected - skipping install (pre-installed!)")

## 2. HuggingFace Auth

In [None]:
import os

def _read_token_from_dotenv(path):
    """Return the first HF_TOKEN / HF_API_KEY value found in *path*, or None.

    Accepts plain ``KEY=value`` lines as well as shell-style
    ``export KEY=value`` lines, ignores surrounding whitespace, and strips one
    layer of single/double quotes from the value. Only the first matching line
    is considered; an empty value yields None so the caller can try the next
    file. ``utf-8-sig`` tolerates a leading byte-order mark.
    """
    with open(path, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            if line.startswith(("HF_TOKEN=", "HF_API_KEY=")):
                value = line.split("=", 1)[1].strip().strip('"\'')
                return value or None
    return None


# Check environment variables
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_API_KEY")

# Try .env files in priority order; the first non-empty value wins.
if not HF_TOKEN:
    for path in [f"{WORK_DIR}/.env", os.path.expanduser("~/.env"), ".env"]:
        # isfile (not exists): a directory named ".env" must not crash open()
        if not os.path.isfile(path):
            continue
        HF_TOKEN = _read_token_from_dotenv(path)
        if HF_TOKEN:
            print(f"Loaded from: {path}")
            break

# Manual fallback
if not HF_TOKEN:
    print("No HF_TOKEN found. Uncomment below to set:")
    # HF_TOKEN = "hf_your_token_here"

# Validate the token with whoami(); hf_user stays None on any failure, which
# disables the upload cells further down.
hf_user = None
if HF_TOKEN:
    os.environ["HF_TOKEN"] = HF_TOKEN
    from huggingface_hub import HfApi
    try:
        hf_user = HfApi().whoami(token=HF_TOKEN)["name"]
        print(f"Authenticated: {hf_user}")
    except Exception as e:
        # whoami() also fails on network errors, not only on bad tokens
        print(f"Token check failed (invalid token or network error): {e}")
else:
    print("HF_TOKEN not set - upload disabled")

## 3. Configuration

In [None]:
# =============================================================================
# CONFIGURATION - Edit these values
# =============================================================================

# Base model (pre-quantized bnb-4bit checkpoint) and context window
MODEL_NAME = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
MAX_SEQ_LENGTH = 2048

# Training data: one JSONL file inside a Hugging Face dataset repo
DATASET_NAME = "professorsynapse/claudesidian-behaviors-merged"
DATASET_FILE = "behavior_gspo_v1.3.jsonl"

# Hub repo name for uploads (prefixed with the authenticated user)
OUTPUT_MODEL_NAME = "nexus-tools-gspo"

# LoRA adapter hyperparameters
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05

# Optimisation / sampling hyperparameters (forwarded to GRPOConfig)
BATCH_SIZE = 4
GRAD_ACCUM = 4
LEARNING_RATE = 5e-6
NUM_GENERATIONS = 4   # -> GRPOConfig.num_generations
MAX_NEW_TOKENS = 512  # -> GRPOConfig.max_completion_length

for _label, _value in (
    ("Model", MODEL_NAME),
    ("Dataset", f"{DATASET_NAME}/{DATASET_FILE}"),
    ("Effective batch", BATCH_SIZE * GRAD_ACCUM),
):
    print(f"{_label}: {_value}")

## 4. Load Model

In [None]:
from unsloth import FastLanguageModel

# Load the 4-bit base model and its tokenizer through Unsloth.
# dtype=None leaves the compute dtype to Unsloth's auto-detection.
_load_kwargs = dict(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=True,
    token=HF_TOKEN,
)
model, tokenizer = FastLanguageModel.from_pretrained(**_load_kwargs)

# NOTE(review): with bitsandbytes 4-bit weights, numel() of the packed tensors
# may under-count the logical parameter count — treat this as approximate.
_param_total = sum(p.numel() for p in model.parameters())
print(f"Loaded: {_param_total:,} params")

## 5. Apply LoRA

In [None]:
# Attention and MLP projection layers that receive LoRA adapters
_lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# NOTE(review): Unsloth documents lora_dropout=0 as its optimized path; the
# configured non-zero dropout is supported but may be slower.
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=_lora_target_modules,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Count the parameters that will receive gradient updates
trainable = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f"Trainable: {trainable:,} params")

## 6. Load Dataset

In [None]:
from datasets import load_dataset

# Load the single JSONL file from the Hub dataset repo as the "train" split
dataset = load_dataset(
    DATASET_NAME,
    data_files=DATASET_FILE,
    split="train",
)
print(f"Loaded {len(dataset)} examples")

# Rows are expected to carry a 'ground_truth_tool' field (used by the reward)
_first_row = dataset[0]
print(f"Sample tool: {_first_row['ground_truth_tool']}")

## 7. Reward Functions

In [None]:
import re
import json

# Keys expected inside a tool call's "context" argument; each one present with
# a truthy value earns an equal share of the context reward.
CONTEXT_FIELDS = ["sessionId", "workspaceId", "sessionDescription",
                  "sessionMemory", "toolContext", "primaryGoal", "subgoal"]

# Parsers are compiled once. Patterns 1 and 2 stop right before the "{" that
# opens the arguments object; the object itself is decoded with
# JSONDecoder.raw_decode so nested braces are handled. (A lazy "{.+?}" regex
# stops at the first "}" and truncates any nested object such as "context".)
_TOOL_CALL_RE = re.compile(r'tool_call:\s*(\w+)\s*\narguments:\s*(?=\{)')
_NAME_ARGS_RE = re.compile(
    r'"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(?=\{)', re.DOTALL
)
_TOOL_NAME_RE = re.compile(r'(\w+Manager_\w+|\w+Librarian_\w+)')
_FORMAT_RE = re.compile(r'tool_call|function|arguments|Manager_|Librarian_', re.I)
_JSON_DECODER = json.JSONDecoder()


def _decode_json_object(text, start, unescape=False):
    """Decode the JSON object that begins at ``text[start]``; {} on failure.

    Trailing text after the object is ignored. With ``unescape=True``,
    backslash-escaped quotes are turned into plain quotes first (for output
    that double-escaped its JSON).
    """
    candidate = text[start:]
    if unescape:
        candidate = candidate.replace('\\"', '"')
    try:
        obj, _ = _JSON_DECODER.raw_decode(candidate)
    except json.JSONDecodeError:
        return {}
    return obj if isinstance(obj, dict) else {}


def extract_tool_call(text):
    """Extract ``(tool_name, args)`` from a model completion.

    Formats are tried in order:
      1. ``tool_call: <name>`` newline ``arguments: {...}``
      2. JSON-like ``"name": "<name>", ... "arguments": {...}``
      3. A bare ``...Manager_...`` / ``...Librarian_...`` tool name.

    Returns:
        ``(name, dict)`` for formats 1-2 (``{}`` when the arguments object
        does not parse), ``(name, {})`` for format 3, and ``(None, None)``
        when nothing matches.
    """
    m = _TOOL_CALL_RE.search(text)
    if m:
        return m.group(1), _decode_json_object(text, m.end())

    m = _NAME_ARGS_RE.search(text)
    if m:
        args = _decode_json_object(text, m.end())
        if not args:
            # Fall back to the previous behaviour for double-escaped JSON
            args = _decode_json_object(text, m.end(), unescape=True)
        return m.group(1), args

    m = _TOOL_NAME_RE.search(text)
    return (m.group(1), {}) if m else (None, None)


def _tool_family(name):
    """Return the tool-name prefix before the first underscore."""
    return name.split('_')[0]


def combined_reward(completions, prompts, ground_truth_tool, **kw):
    """Combined reward (max 2.0): tool(1.0) + json(0.3) + context(0.5) + format(0.2)

    Components per completion:
      * tool:    +1.0 for the exact expected tool, else +0.3 when the tool
                 shares the expected tool's family (prefix before "_").
      * json:    +0.3 when a non-empty dict of arguments was parsed.
      * context: up to +0.5, proportional to how many CONTEXT_FIELDS are
                 truthy inside ``args["context"]``.
      * format:  +0.2 when the text contains tool-call keywords.

    Args:
        completions: one entry per sample, either a string or a
            conversational list whose first message has a "content" key.
        prompts: unused; accepted because the trainer passes it.
        ground_truth_tool: expected tool name(s). GRPOTrainer forwards dataset
            columns as lists aligned with ``completions``; a single string
            (or None) is also accepted and applied to every completion.
        **kw: other dataset columns (e.g. ground_truth_args); ignored.

    Returns:
        A list with one float reward per completion.

    Raises:
        ValueError: if a per-sample ``ground_truth_tool`` list does not have
            one entry per completion.
    """
    if ground_truth_tool is None or isinstance(ground_truth_tool, str):
        targets = [ground_truth_tool] * len(completions)
    else:
        targets = list(ground_truth_tool)
        if len(targets) != len(completions):
            raise ValueError(
                f"ground_truth_tool has {len(targets)} entries for "
                f"{len(completions)} completions"
            )

    rewards = []
    for c, expected in zip(completions, targets):
        text = c[0]["content"] if isinstance(c, list) else c
        tool, args = extract_tool_call(text)
        r = 0.0

        # Tool selection (+1.0 exact, +0.3 same family)
        if tool == expected:
            r += 1.0
        elif tool and expected and _tool_family(tool) == _tool_family(expected):
            r += 0.3

        # JSON structure (+0.3)
        if args and isinstance(args, dict):
            r += 0.3

        # Context completeness (up to +0.5)
        if isinstance(args, dict) and isinstance(args.get("context"), dict):
            ctx = args["context"]
            present = sum(1 for f in CONTEXT_FIELDS if ctx.get(f))
            r += (present / len(CONTEXT_FIELDS)) * 0.5

        # Format (+0.2)
        if _FORMAT_RE.search(text):
            r += 0.2

        rewards.append(r)
    return rewards

# Test: a call with the right tool and every context field filled in is the
# case the reward is designed to score at its maximum (2.0).
_payload = json.dumps({"context": {field: "x" for field in CONTEXT_FIELDS}})
test = [[{"content": "tool_call: vaultManager_list\narguments: " + _payload}]]
_score = combined_reward(test, [''], 'vaultManager_list')[0]
print(f"Test reward: {_score:.2f} / 2.0")

## 8. Prepare Dataset

In [None]:
# Set chat template
# Fallback Jinja template in Mistral's "[INST] ... [/INST]" format. A system
# message is emitted (followed by a space) only when it is the first message;
# later system messages are dropped.
# NOTE(review): this places the system text before the first [INST] rather
# than inside it — confirm that matches the target model's expected format.
MISTRAL_TEMPLATE = """{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' %}{% if loop.index == 1 %}{{ message['content'] + ' ' }}{% endif %}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% endif %}{% endfor %}"""

# Install the fallback only when the tokenizer has no template of its own and
# the configured model is a Mistral variant.
_needs_fallback_template = (
    tokenizer.chat_template is None and 'mistral' in MODEL_NAME.lower()
)
if _needs_fallback_template:
    tokenizer.chat_template = MISTRAL_TEMPLATE

def format_example(ex):
    """Convert one raw dataset row into the columns the GRPO trainer consumes.

    Returns a dict with:
      * "prompt": ``ex["prompt"]`` rendered to text through the tokenizer's
        chat template, ending with the generation prompt. Presumably
        ``ex["prompt"]`` is a list of role/content messages — verify against
        the dataset schema.
      * "ground_truth_tool": copied through unchanged.
      * "ground_truth_args": JSON text (non-serialisable values via ``str``);
        a falsy value becomes "{}".
    """
    rendered_prompt = tokenizer.apply_chat_template(
        ex["prompt"], tokenize=False, add_generation_prompt=True
    )
    expected_tool = ex["ground_truth_tool"]
    raw_args = ex["ground_truth_args"]
    if raw_args:
        args_json = json.dumps(raw_args, default=str)
    else:
        args_json = "{}"
    return {
        "prompt": rendered_prompt,
        "ground_truth_tool": expected_tool,
        "ground_truth_args": args_json,
    }

# Swap the raw "prompt" column for its rendered text. datasets keeps a column
# listed in remove_columns when the mapped function outputs it again.
formatted_dataset = dataset.map(function=format_example, remove_columns=["prompt"])
_formatted_count = len(formatted_dataset)
print(f"Formatted {_formatted_count} examples")

## 9. Setup Trainer

In [None]:
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
from datetime import datetime

# Each run gets its own timestamped directory under OUTPUT_DIR.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join(OUTPUT_DIR, timestamp)

# GRPOConfig truncates prompts to `max_prompt_length`, whose TRL default
# (512 tokens) is far below MAX_SEQ_LENGTH. Tool-calling prompts (system
# prompt, tool descriptions, context) can exceed that and would be truncated.
# Give prompts whatever the context window leaves after the completion budget.
MAX_PROMPT_LENGTH = MAX_SEQ_LENGTH - MAX_NEW_TOKENS
if MAX_PROMPT_LENGTH <= 0:
    raise ValueError(
        f"MAX_NEW_TOKENS ({MAX_NEW_TOKENS}) must be smaller than "
        f"MAX_SEQ_LENGTH ({MAX_SEQ_LENGTH})"
    )

_use_bf16 = is_bfloat16_supported()

# NOTE(review): the GSPO paper / TRL docs pair sequence-level ratios with much
# tighter clipping (epsilon ~3e-4, epsilon_high ~4e-4) than GRPO's token-level
# defaults; consider setting them after checking the installed TRL's docs.
training_args = GRPOConfig(
    output_dir=run_dir,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_generations=NUM_GENERATIONS,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_NEW_TOKENS,
    importance_sampling_level="sequence",  # GSPO: sequence-level importance ratios
    learning_rate=LEARNING_RATE,
    num_train_epochs=1,
    fp16=not _use_bf16,
    bf16=_use_bf16,
    optim="adamw_8bit",
    logging_steps=5,
    save_steps=50,
    save_total_limit=3,
    seed=42,
    report_to="none",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=combined_reward,
    args=training_args,
    train_dataset=formatted_dataset,
)

print(f"Trainer ready! Output: {run_dir}")

## 10. Train!

In [None]:
import glob

# Check for resume
# run_dir is timestamped by the trainer-setup cell, so this only finds
# checkpoints when this cell is re-run in the same session (e.g. after an
# interrupted run); re-running the setup cell starts a fresh directory.
def _latest_checkpoint(directory):
    """Return the ``checkpoint-<step>`` subdirectory with the highest step.

    Entries whose suffix is not a plain integer are ignored, and the directory
    path is glob-escaped so names containing ``[``, ``*`` or ``?`` still work.
    Returns None when no checkpoint exists.
    """
    prefix = "checkpoint-"
    steps = {}
    for path in glob.glob(os.path.join(glob.escape(directory), prefix + "*")):
        suffix = os.path.basename(path)[len(prefix):]
        if suffix.isdigit() and os.path.isdir(path):
            steps[int(suffix)] = path
    return steps[max(steps)] if steps else None


resume_from = _latest_checkpoint(run_dir)
if resume_from:
    print(f"Resuming from: {resume_from}")

print("\n" + "=" * 50)
print("STARTING GSPO TRAINING")
print("=" * 50 + "\n")

trainer.train(resume_from_checkpoint=resume_from)

print("\n" + "=" * 50)
print("TRAINING COMPLETE")
print("=" * 50)

## 11. Save & Upload

In [None]:
# Save locally: the trained model (for a PEFT model, its adapter) and the
# tokenizer go into <run_dir>/final_model.
final_dir = os.path.join(run_dir, "final_model")
for _artifact in (model, tokenizer):
    _artifact.save_pretrained(final_dir)
print(f"Saved: {final_dir}")

In [None]:
# Upload LoRA to HuggingFace
# Pushes model + tokenizer to <user>/<OUTPUT_MODEL_NAME>; needs a token that
# passed the whoami() check in the auth cell (hf_user set).
if not (hf_user and HF_TOKEN):
    print("Set HF_TOKEN to upload")
else:
    repo = f"{hf_user}/{OUTPUT_MODEL_NAME}"
    for _artifact in (model, tokenizer):
        _artifact.push_to_hub(repo, token=HF_TOKEN)
    print(f"LoRA: https://huggingface.co/{repo}")

In [None]:
# Upload merged 16-bit (optional)
# Merges the adapter into the base weights and pushes the result to
# <user>/<OUTPUT_MODEL_NAME>-merged; silently skipped without authentication.
if hf_user and HF_TOKEN:
    _merged_repo = f"{hf_user}/{OUTPUT_MODEL_NAME}-merged"
    model.push_to_hub_merged(
        _merged_repo,
        tokenizer,
        save_method="merged_16bit",
        token=HF_TOKEN,
    )
    print(f"Merged: https://huggingface.co/{_merged_repo}")

In [None]:
# Upload GGUF quantizations (optional)
# Exports three GGUF quantizations and pushes them to the same repo as the
# LoRA upload (<user>/<OUTPUT_MODEL_NAME>); silently skipped without auth.
# NOTE(review): mixing adapter files and GGUF files in one repo may confuse
# tools that auto-detect the repo type — consider a separate "-GGUF" repo.
if hf_user and HF_TOKEN:
    _gguf_repo = f"{hf_user}/{OUTPUT_MODEL_NAME}"
    _quant_methods = ["q4_k_m", "q5_k_m", "q8_0"]
    model.push_to_hub_gguf(
        _gguf_repo,
        tokenizer,
        quantization_method=_quant_methods,
        token=HF_TOKEN,
    )
    print(f"GGUF: https://huggingface.co/{_gguf_repo} (Files tab)")

## Done!

**Local model:** `<OUTPUT_DIR>/<timestamp>/final_model/` (the exact path is printed by the save cell).

**Next steps:**
1. Test with LM Studio or Ollama
2. Compare GSPO vs SFT on tool selection
3. Iterate on reward functions