Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions configs/areal_waa_grpo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# AReaL GRPO config for WAA desktop automation training.
#
# Trains a VLM (Qwen2.5-VL-3B) to automate Windows desktop tasks using
# AReaL's async RL framework with GRPO (Group Relative Policy Optimization).
#
# The workflow (WAADesktopWorkflow) wraps WAADesktopEnv to:
# 1. Reset the Windows VM to a task's initial state
# 2. Loop: screenshot -> LLM (via AReaL proxy) -> parse action -> execute
# 3. Evaluate with dense milestone rewards
# 4. Return reward for gradient computation
#
# AReaL transparently handles token tracking, logprobs, and gradient
# computation through its OpenAI-compatible proxy.
#
# Prerequisites:
# 1. WAA server reachable (SSH tunnel or direct):
# ssh -N -L 5000:localhost:5000 azureuser@<VM_IP>
# 2. AReaL installed: pip install areal
# 3. openadapt-evals installed: pip install openadapt-evals
#
# Architecture:
# GPU VM CPU VM
# +----------------------------+ +---------------------+
# | AReaL | | Docker |
# | PPOTrainer | | QEMU (Windows 11) |
# | OpenAI proxy (sglang) | HTTP | WAA Flask API |
# | WAADesktopWorkflow ------+----->| /screenshot |
# | Qwen2.5-VL-3B (actor) | | /execute_windows |
# | Qwen2.5-VL-3B (ref) | | /evaluate |
# +----------------------------+ +---------------------+
#
# Usage:
# python examples/agent_workflow/train.py \
# --config configs/areal_waa_grpo.yaml \
# scheduler.type=local
#
# For single-GPU dev (override cluster settings):
# python examples/agent_workflow/train.py \
# --config configs/areal_waa_grpo.yaml \
# scheduler.type=local \
# cluster.n_nodes=1 \
# cluster.n_gpus_per_node=1 \
# allocation_mode="sglang:d1p1t1"

experiment_name: waa-desktop-grpo
trial_name: trial0

seed: 42
enable_offload: false
total_train_epochs: 50
tokenizer_path: ${actor.path}

# --- Workflow ---
# AReaL resolves this dotted path and instantiates the class.
# WAADesktopWorkflow.run() receives:
# data = dataset row (task_id, instruction, max_steps, server_url)
# extra_kwargs = {base_url, api_key, http_client} from AReaL proxy
workflow: openadapt_evals.training.areal_workflow.WAADesktopWorkflow
eval_workflow: ${workflow}

# --- Cluster ---
cluster:
n_nodes: 1
n_gpus_per_node: 1 # 1 for g5.xlarge; use 4+ for g5.12xlarge
fileroot: /tmp/areal/experiments
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve

# Single-GPU allocation: 1 device for sglang inference
allocation_mode: "sglang:d1p1t1"

scheduler:
  # Deliberately null: presumably meant to be supplied at launch time —
  # the Usage examples above pass scheduler.type=local on the command
  # line. NOTE(review): confirm AReaL fails fast (rather than silently
  # no-ops) when this is left unset.
  type: null

# --- Rollout ---
rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
tokenizer_path: ${tokenizer_path}
max_concurrent_rollouts: 4 # Low: each rollout talks to a single VM
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 2
enable_rollout_tracing: false
scheduling_spec: ${actor.scheduling_spec}
dump_to_file: true
openai:
mode: inline
export_style: individual
turn_discount: 1.0 # No discount across turns (episodic)

# --- Generation config ---
gconfig:
n_samples: 4 # GRPO group size (4 rollouts per task)
min_new_tokens: 0
max_new_tokens: 256 # Action JSON is short
max_tokens: 2048
greedy: false
temperature: 0.7

# --- Actor model ---
# Actor (policy) model: the VLM being trained with GRPO.
actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: Qwen/Qwen2.5-VL-3B-Instruct
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 8192
  optimizer:
    type: adam
    lr: 1.0e-5  # Conservative LR for VLM fine-tuning
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.999
    # Written as 1.0e-8 (not 1e-8) on purpose: YAML 1.1 resolvers such as
    # PyYAML only treat scientific notation as a float when the mantissa
    # contains a dot, so a bare 1e-8 silently loads as the STRING "1e-8".
    # This also matches the 1.0e-5 style used for lr above.
    eps: 1.0e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.01
  eps_clip: 0.2  # PPO clip range
  temperature: ${gconfig.temperature}
  reward_scaling: 1.0  # Rewards are already 0-1 from evaluate_dense
  reward_bias: 0.0
  kl_ctl: 0.0  # No KL penalty (GRPO style)
  ppo_n_minibatches: 1
  recompute_logprob: true
  use_decoupled_loss: true
  behave_imp_weight_cap: 5.0
  reward_norm: null
  adv_norm:
    mean_level: batch
    std_level: batch
  weight_update_mode: xccl
  max_new_tokens: ${gconfig.max_new_tokens}
  scheduling_spec:
    - task_type: worker
      port_count: 2
      gpu: 1
      cmd: python3 -m areal.infra.rpc.rpc_server

# --- Reference model ---
ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 8192
optimizer: null
scheduling_strategy:
type: colocation
target: actor
scheduling_spec: ${actor.scheduling_spec}

# --- Inference engine (sglang) ---
# Serves the actor weights behind AReaL's OpenAI-compatible proxy.
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  # NOTE(review): 4096 covers gconfig.max_tokens (2048) plus some prompt,
  # but VLM screenshots expand into many image tokens — confirm this
  # leaves enough headroom per request.
  context_length: 4096
  mem_fraction_static: 0.7  # Leave room for VLM image processing

# --- Training dataset ---
# Each row should have: task_id, instruction, max_steps, server_url
# For WAA, this is typically a small set of tasks (4-20) repeated.
train_dataset:
batch_size: 4 # Small batch: each sample is a full episode
shuffle: true
pin_memory: true
num_workers: 1
path: REPLACE_WITH_DATASET_PATH # HF dataset or local JSONL
type: rl
max_length: 2048

valid_dataset:
batch_size: 4
pin_memory: true
num_workers: 1
path: ${train_dataset.path}
type: rl

# --- Model saving ---
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 5
freq_steps: null
freq_secs: null

# --- Recovery ---
recover:
mode: disabled
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 5

# --- Evaluation ---
evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 5

# --- Logging ---
stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled # Set to "online" for W&B tracking
Loading
Loading