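# Password Game RL Training with VERL

Train Qwen3-0.6B to solve the Password Game with PPO. The PPO loop (rollouts, GAE, clipped policy/value losses, and a KL penalty against a frozen reference model) is implemented directly in PyTorch + Hugging Face Transformers; the VERL library itself is not used.

**Rules**: up to 9 progressive password rules (each prompt shows the first 3–9 of them)  
**Reward**: +1 per rule passed, -0.1 per character of the password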

In [None]:
# List the GPUs visible to this runtime (IPython shell escape).
!nvidia-smi -L

In [None]:
# Install dependencies (IPython shell escapes).
# torch/torchvision come from the CUDA 12.1 wheel index. flash-attn is installed
# afterwards with --no-build-isolation so its build uses the torch installed above.
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip install -q flash-attn --no-build-isolation
!pip install -q transformers accelerate datasets tokenizers wandb tqdm numpy

In [None]:
import os, json, random, time, re
from dataclasses import dataclass, asdict
from typing import List, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
import wandb
from tqdm.auto import tqdm

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fail fast: the rest of the notebook (flash-attention, torch.cuda calls) assumes a GPU.
# An explicit raise, unlike `assert`, is not stripped when Python runs with -O.
if DEVICE != "cuda":
    raise RuntimeError("GPU required")

## Config

In [None]:
@dataclass
class Config:
    """Hyper-parameters and paths for PPO training on the Password Game.

    ``output_dir`` and ``wandb_run_name`` default to ``None`` and are resolved in
    ``__post_init__`` from a single timestamp taken when the instance is created.
    Each ``Config()`` therefore gets its own output directory and a matching
    run name. A default evaluated in the class body would be frozen at
    class-definition time and shared by every instance.

    ``__post_init__`` also creates ``output_dir`` on disk.
    """

    # Model
    model_name: str = "Qwen/Qwen3-0.6B"
    precision: str = "bfloat16"
    use_flash_attn: bool = True

    # Training
    num_epochs: int = 3
    num_steps_per_epoch: int = 100
    batch_size: int = 4  # prompts per training step
    samples_per_prompt: int = 4  # sampled completions per prompt
    learning_rate: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    warmup_steps: int = 50  # counted in scheduler.step() calls

    # PPO
    ppo_epochs: int = 4  # optimisation passes over each rollout batch
    clip_range: float = 0.2
    value_loss_coef: float = 0.1
    kl_coef: float = 0.05  # weight of the KL penalty against the reference model
    gamma: float = 0.99
    gae_lambda: float = 0.95
    normalize_advantages: bool = True

    # Generation
    max_prompt_length: int = 1024
    max_new_tokens: int = 256
    temperature: float = 0.8
    top_p: float = 0.9
    top_k: int = 50

    # Password Game
    num_rules: int = 9
    reward_per_rule: float = 1.0
    length_penalty: float = 0.1  # subtracted per password character

    # Data
    num_train_samples: int = 1000
    num_val_samples: int = 200

    # Logging
    wandb_project: str = "password-game-rl"
    wandb_run_name: Optional[str] = None  # None -> "password_ppo_<timestamp>"
    log_interval: int = 10
    eval_interval: int = 50
    save_interval: int = 100
    output_dir: Optional[str] = None  # None -> "./password_game_<timestamp>"
    seed: int = 42

    def __post_init__(self):
        # One timestamp for both derived names, so run name and directory agree.
        stamp = int(time.time())
        if self.wandb_run_name is None:
            self.wandb_run_name = f"password_ppo_{stamp}"
        if self.output_dir is None:
            self.output_dir = f"./password_game_{stamp}"
        os.makedirs(self.output_dir, exist_ok=True)

config = Config()
effective_batch = config.batch_size * config.samples_per_prompt
print(f"Model: {config.model_name}")
print(f"Batch: {config.batch_size} x {config.samples_per_prompt} = {effective_batch}")
print(f"Output: {config.output_dir}")

# Persist the resolved configuration next to the checkpoints for reproducibility.
config_path = os.path.join(config.output_dir, "config.json")
with open(config_path, "w") as f:
    json.dump(asdict(config), f, indent=2)

In [None]:
def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus every CUDA device when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed all RNGs before building datasets and models so runs are reproducible.
set_seed(seed=config.seed)

## Models

In [None]:
# Compute dtype shared by the policy, the reference model and the value head.
if config.precision == "bfloat16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
# Fall back to EOS when the checkpoint ships without a pad token, so batching works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding makes every prompt in a batch end at the same column, which the
# rollout code relies on when slicing completions off at prompt length.
tokenizer.padding_side = "left"
print(f"Tokenizer: {len(tokenizer)} tokens")

In [None]:
attn_impl = "flash_attention_2" if config.use_flash_attn else "eager"
policy_model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation=attn_impl,
)
# The KV cache is only switched on around generation; training forwards run without it.
policy_model.config.use_cache = False
n_policy_params = sum(p.numel() for p in policy_model.parameters())
print(f"Policy: {n_policy_params/1e9:.2f}B params")

In [None]:
# Frozen copy of the starting weights; the KL penalty is measured against it.
reference_model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation=("flash_attention_2" if config.use_flash_attn else "eager"),
)
reference_model.eval()
# Module.requires_grad_ toggles requires_grad on every parameter.
reference_model.requires_grad_(False)
print("Reference: frozen")

In [None]:
class ValueHead(nn.Module):
    """Linear probe mapping each hidden state to one scalar value estimate."""

    def __init__(self, hidden_size: int):
        super().__init__()
        proj = nn.Linear(hidden_size, 1)
        # Tiny orthogonal weights and a zero bias keep initial value estimates near 0.
        nn.init.orthogonal_(proj.weight, gain=0.01)
        nn.init.zeros_(proj.bias)
        self.linear = proj

    def forward(self, hidden_states):
        """Project ``(..., hidden_size)`` states to ``(..., 1)`` values."""
        projected = self.linear(hidden_states)
        return projected

hidden_size = policy_model.config.hidden_size
# Same device and dtype as the rest of the pipeline.
value_head = ValueHead(hidden_size).to(device=DEVICE, dtype=dtype)
print(f"Value head: {hidden_size}")

## Password Game Environment

In [None]:
# Rule texts shown to the model, in order; index i is checked by check_rule(password, i).
PASSWORD_RULES = [
    "Your password must be at least 5 characters.",
    "Your password must include a number.",
    "Your password must include an uppercase letter.",
    "Your password must include a special character.",
    "The digits in your password must add up to 25.",
    "Your password must include a month of the year.",
    "Your password must include a roman numeral.",
    "Your password must include one of our sponsors: (Pepsi, Starbucks, Shell)",
    "The roman numerals in your password should multiply to 35.",
]

# Instruction text placed at the top of every prompt.
INSTRUCTIONS = (
    "You are playing a password game. Create a password that satisfies ALL the given rules.\n"
    "Return ONLY the password string, nothing else."
)

def _roman_to_int(numeral: str) -> int:
    """Value of a run of Roman-numeral letters, read right to left (subtractive pairs like IV = 4)."""
    roman_vals = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
    val = 0
    prev = 0
    for c in reversed(numeral):
        v = roman_vals.get(c, 0)
        if v < prev:
            val -= v
        else:
            val += v
        prev = v
    return val


def check_rule(password: str, rule_idx: int) -> bool:
    """Return whether ``password`` satisfies rule ``rule_idx`` (0-based, as in PASSWORD_RULES).

    Args:
        password: candidate password string.
        rule_idx: index of the rule to check; any index without a branch below
            (e.g. >= 9) returns ``False``.

    Returns:
        ``True`` if the rule holds. Roman numerals (rules 6 and 8) are maximal runs
        of the uppercase letters I, V, X, L, C, D, M; month and sponsor names
        (rules 5 and 7) are matched case-insensitively as substrings.
    """
    pwd = password

    if rule_idx == 0:
        return len(pwd) >= 5
    elif rule_idx == 1:
        return any(c.isdigit() for c in pwd)
    elif rule_idx == 2:
        return any(c.isupper() for c in pwd)
    elif rule_idx == 3:
        # Anything that is not a letter or digit counts as "special".
        return any(not c.isalnum() for c in pwd)
    elif rule_idx == 4:
        # str.isdigit() is also True for characters such as '²' that int() rejects
        # (ValueError). unicodedata.digit() returns their digit value instead, so
        # a stray Unicode digit in a sampled password can no longer crash training.
        import unicodedata
        return sum(unicodedata.digit(c, 0) for c in pwd if c.isdigit()) == 25
    elif rule_idx == 5:
        months = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december']
        return any(m in pwd.lower() for m in months)
    elif rule_idx == 6:
        return bool(re.search(r'[IVXLCDM]+', pwd))
    elif rule_idx == 7:
        sponsors = ['pepsi', 'starbucks', 'shell']
        return any(s in pwd.lower() for s in sponsors)
    elif rule_idx == 8:
        romans = re.findall(r'[IVXLCDM]+', pwd)
        if not romans:
            return False
        product = 1
        for r in romans:
            val = _roman_to_int(r)
            if val > 0:
                product *= val
        return product == 35
    return False

def compute_reward(password: str, num_active_rules: int) -> float:
    """Score a password: ``config.reward_per_rule`` per satisfied rule, minus a per-character penalty.

    Only the first ``num_active_rules`` rules are checked. The penalty is
    ``len(password) * config.length_penalty``.
    """
    satisfied = 0
    for rule_idx in range(num_active_rules):
        if check_rule(password, rule_idx):
            satisfied += 1
    return satisfied * config.reward_per_rule - len(password) * config.length_penalty

def format_prompt(num_active_rules: int) -> str:
    """Build the prompt: instructions, the first ``num_active_rules`` rules numbered from 1, then ``Password:``."""
    numbered = []
    for idx in range(num_active_rules):
        numbered.append(f"{idx + 1}. {PASSWORD_RULES[idx]}")
    return INSTRUCTIONS + "\n\nRules:\n" + "\n".join(numbered) + "\n\nPassword:"

print(f"Loaded {len(PASSWORD_RULES)} rules")
example_prompt = format_prompt(3)
print(f"Example prompt:\n{example_prompt[:200]}...")

In [None]:
class PasswordDataset(Dataset):
    """Synthetic prompt set for the Password Game.

    Each item is ``{'prompt': str, 'num_rules': int}``. ``num_rules`` is drawn
    with the global ``random`` module from ``3 .. min(max_rules, len(PASSWORD_RULES))``,
    and the prompt lists that many leading rules.
    """

    def __init__(self, num_samples: int, max_rules: int):
        prompt_texts = []
        rule_counts = []
        for _ in range(num_samples):
            count = random.randint(3, min(max_rules, len(PASSWORD_RULES)))
            prompt_texts.append(format_prompt(count))
            rule_counts.append(count)
        self.prompts = prompt_texts
        self.num_rules = rule_counts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        count = self.num_rules[idx]
        return {'prompt': prompt, 'num_rules': count}

# Both splits use at most config.num_rules rules per prompt (the dataset draws 3..that many),
# so changing Config.num_rules actually changes the curriculum.
train_dataset = PasswordDataset(config.num_train_samples, config.num_rules)
val_dataset = PasswordDataset(config.num_val_samples, config.num_rules)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Baseline Eval

In [None]:
def evaluate_model(model, dataset, num_samples=100, desc="Eval"):
    """Return the mean reward of sampled completions on the first items of ``dataset``.

    Args:
        model: causal LM with ``generate``. It is switched to eval mode with the
            KV cache enabled for the duration of the call.
        dataset: indexable whose items are dicts with ``'prompt'`` and ``'num_rules'``.
        num_samples: evaluate at most this many items, starting at index 0.
        desc: tqdm progress-bar label.

    Returns:
        Mean of ``compute_reward`` over the evaluated items. Completions are sampled
        with ``config``'s temperature/top-p/top-k, so the value is stochastic.

    Raises:
        ValueError: if there is nothing to evaluate (empty dataset or num_samples <= 0).

    On exit, including on error or interrupt, the model is left in train mode
    with ``use_cache = False``, which is what the training loop expects.
    """
    n_eval = min(num_samples, len(dataset))
    if n_eval <= 0:
        raise ValueError(
            f"evaluate_model: nothing to evaluate (num_samples={num_samples}, len(dataset)={len(dataset)})"
        )

    model.eval()
    model.config.use_cache = True

    total_reward = 0.0
    try:
        with torch.no_grad():
            for i in tqdm(range(n_eval), desc=desc):
                item = dataset[i]
                prompt = item['prompt']
                num_rules = item['num_rules']

                inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True, max_length=config.max_prompt_length).to(DEVICE)
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=config.max_new_tokens,
                    do_sample=True,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    top_k=config.top_k,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

                generated = tokenizer.decode(outputs[0, inputs.input_ids.size(1):], skip_special_tokens=True)
                # The password is the first whitespace-separated token of the completion.
                password = generated.strip().split()[0] if generated.strip() else ""
                total_reward += compute_reward(password, num_rules)
    finally:
        # Restore training-time settings even if generation fails or is interrupted.
        model.config.use_cache = False
        model.train()
    return total_reward / n_eval

# Reward of the untrained policy on the first 100 validation prompts;
# the final cell reports the improvement over this number.
baseline_reward = evaluate_model(
    policy_model,
    val_dataset,
    num_samples=100,
    desc="Baseline",
)
print(f"\nBaseline reward: {baseline_reward:.4f}")

## PPO Utils

In [None]:
def compute_log_probs(model, input_ids, attention_mask, return_values=False):
    """Log-probability of each next token ``input_ids[:, 1:]`` under ``model``.

    Entry ``t`` of the result scores token ``t + 1`` using the logits at position ``t``.
    Positions whose target has ``attention_mask == 0`` are zeroed.

    Returns:
        ``token_log_probs`` of shape ``(batch, seq - 1)``. When ``return_values`` is
        true it returns ``(token_log_probs, values)`` instead, where ``values`` is
        the global ``value_head`` applied to the last hidden layer, one per position.
    """
    out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=return_values)
    shifted_logits = out.logits[:, :-1, :]
    targets = input_ids[:, 1:].unsqueeze(-1)
    per_token = F.log_softmax(shifted_logits, dim=-1).gather(2, targets).squeeze(-1)
    per_token = per_token * attention_mask[:, 1:].bool()
    if not return_values:
        return per_token
    last_hidden = out.hidden_states[-1]
    values = value_head(last_hidden).squeeze(-1)
    return per_token, values

def compute_advantages(rewards, values, masks, gamma=0.99, gae_lambda=0.95):
    """Generalised Advantage Estimation over per-token rewards.

    Args:
        rewards, values, masks: 2-D tensors ``(batch, seq_len)`` over response tokens.
            ``masks`` is 1 for real tokens and 0 for padding; padding is assumed to
            trail the real tokens of each row.
        gamma: discount factor.
        gae_lambda: GAE smoothing parameter.

    Returns:
        ``(advantages, returns)`` with the same shape as ``rewards``. Advantages are
        0 at masked positions, and ``returns = advantages + values``.

    Bootstrapping and the GAE recursion are gated by the *next* token's mask, so
    the last real token of a row treats the episode as terminated. Gating on the
    current token's mask instead would bootstrap from, and accumulate GAE through,
    the padding positions that follow it.
    """
    _, seq_len = rewards.shape
    advantages = torch.zeros_like(rewards)
    gae = 0
    for t in reversed(range(seq_len)):
        if t == seq_len - 1:
            next_value, next_mask = 0, 0
        else:
            next_value, next_mask = values[:, t + 1], masks[:, t + 1]
        delta = rewards[:, t] + gamma * next_value * next_mask - values[:, t]
        gae = delta + gamma * gae_lambda * next_mask * gae
        advantages[:, t] = gae
    # Padding positions carry no action; keep their advantage at exactly zero.
    advantages = advantages.masked_fill(~masks.bool(), 0)
    returns = advantages + values
    return advantages, returns

def whiten(values, mask):
    """Standardise ``values`` with the mean and std taken over positions where ``mask`` is set.

    Every position, masked or not, is shifted and scaled. 1e-8 is added to the
    (population) variance before the square root.
    """
    denom = mask.sum()
    mu = (values * mask).sum() / denom
    sigma = torch.sqrt(((values - mu).pow(2) * mask).sum() / denom + 1e-8)
    return (values - mu) / sigma

print("PPO utils defined")

## Training

In [None]:
# The value head is optimised jointly with the policy.
optimizer = torch.optim.AdamW(
    list(policy_model.parameters()) + list(value_head.parameters()),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)
# The training loop calls scheduler.step() after every optimizer.step(), i.e.
# ppo_epochs times per training step, so the schedule is sized in optimizer steps.
# Sizing it by training steps alone would drive the linear decay to lr=0 after
# only 1/ppo_epochs of the run, leaving the remaining updates with no learning rate.
total_steps = config.num_epochs * config.num_steps_per_epoch * config.ppo_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=total_steps)
print(f"Optimizer ready: {total_steps} optimizer steps")

In [None]:
# Start the experiment-tracking run and record the full config with it.
wandb_run = wandb.init(
    project=config.wandb_project,
    name=config.wandb_run_name,
    config=asdict(config),
)
print(f"WandB: {wandb_run.get_url()}")

In [None]:
policy_model.train()
value_head.train()
global_step = 0
best_val_reward = -float('inf')
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
# Everything the optimizer updates; gradients are clipped jointly after each backward.
clip_params = list(policy_model.parameters()) + list(value_head.parameters())

print("Starting training...")

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch+1}/{config.num_epochs}")
    epoch_iter = iter(train_dataloader)

    for step in tqdm(range(config.num_steps_per_epoch), desc=f"Epoch {epoch+1}"):
        # Restart the dataloader if an epoch needs more batches than it holds.
        try:
            batch = next(epoch_iter)
        except StopIteration:
            epoch_iter = iter(train_dataloader)
            batch = next(epoch_iter)

        prompts = batch['prompt']
        # The default collate_fn turns the int column into a tensor, and `tensor * k`
        # multiplies the values instead of repeating the list. Use plain ints.
        num_rules_list = [int(n) for n in batch['num_rules']]

        # ---- Rollout: samples_per_prompt completions per prompt ----
        policy_model.eval()
        policy_model.config.use_cache = True
        with torch.no_grad():
            prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=config.max_prompt_length).to(DEVICE)
            # Prompts are left-padded, so every completion starts at this column.
            prompt_len = prompt_inputs.input_ids.size(1)
            all_responses = []
            sampled_ids = []
            for _ in range(config.samples_per_prompt):
                outputs = policy_model.generate(
                    **prompt_inputs,
                    max_new_tokens=config.max_new_tokens,
                    do_sample=True,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    top_k=config.top_k,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
                for gen_ids in outputs[:, prompt_len:]:
                    resp = tokenizer.decode(gen_ids, skip_special_tokens=True)
                    # The password is the first whitespace-separated token of the completion.
                    all_responses.append(resp.strip().split()[0] if resp.strip() else "")
                sampled_ids.append(outputs)
            # generate() stops once all of *its own* sequences finish, so separate calls
            # can return different lengths. Right-pad to a common width before torch.cat.
            full_len = max(ids.size(1) for ids in sampled_ids)
            all_full_ids = torch.cat(
                [F.pad(ids, (0, full_len - ids.size(1)), value=tokenizer.pad_token_id) for ids in sampled_ids],
                dim=0,
            )
            all_masks = (all_full_ids != tokenizer.pad_token_id).long()
            # Rows are sample-major (all prompts for sample 0, then sample 1, ...),
            # so repeat the per-prompt rule counts in the same order.
            expanded_num_rules = num_rules_list * config.samples_per_prompt

        # ---- Rewards: one scalar per completion ----
        rewards = torch.tensor(
            [compute_reward(pwd, nr) for pwd, nr in zip(all_responses, expanded_num_rules)],
            device=DEVICE,
            dtype=torch.float32,
        )
        if rewards.numel() != all_full_ids.size(0):
            raise RuntimeError(f"got {rewards.numel()} rewards for {all_full_ids.size(0)} completions")
        mean_reward = rewards.mean().item()

        # ---- Old log-probs, reference log-probs and values ----
        with torch.no_grad():
            old_log_probs, old_values = compute_log_probs(policy_model, all_full_ids, all_masks, return_values=True)
            ref_log_probs = compute_log_probs(reference_model, all_full_ids, all_masks)
            # Response token j (column prompt_len + j) is produced from column
            # prompt_len + j - 1, so its state value is read from that column too.
            # The hidden state at the token itself already encodes the action, which
            # would make the baseline action-dependent.
            old_values_gen = old_values[:, prompt_len - 1:-1]
            generated_ids_all = all_full_ids[:, prompt_len:]

        # ---- Advantages: spread each reward uniformly over its response tokens ----
        response_mask = (generated_ids_all != tokenizer.pad_token_id).float()
        reward_per_token = torch.zeros_like(generated_ids_all, dtype=dtype)
        for i, reward in enumerate(rewards):
            valid = generated_ids_all[i] != tokenizer.pad_token_id
            reward_per_token[i][valid] = reward / valid.sum().clamp(min=1)
        advantages, returns = compute_advantages(reward_per_token, old_values_gen, response_mask, config.gamma, config.gae_lambda)
        if config.normalize_advantages:
            advantages = whiten(advantages, response_mask)

        # ---- PPO updates ----
        policy_model.train()
        policy_model.config.use_cache = False
        # Log-prob index t scores token t + 1, so response tokens start at prompt_len - 1.
        old_lp_gen = old_log_probs[:, prompt_len - 1:]
        ref_lp_gen = ref_log_probs[:, prompt_len - 1:]
        for _ in range(config.ppo_epochs):
            curr_log_probs, curr_values = compute_log_probs(policy_model, all_full_ids, all_masks, return_values=True)
            curr_values_gen = curr_values[:, prompt_len - 1:-1]
            curr_lp_gen = curr_log_probs[:, prompt_len - 1:]

            ratio = torch.exp(curr_lp_gen - old_lp_gen.detach())
            # Clipped surrogate objective, negated so that it is minimised.
            policy_loss = torch.max(
                -advantages.detach() * ratio,
                -advantages.detach() * torch.clamp(ratio, 1 - config.clip_range, 1 + config.clip_range),
            )
            policy_loss = (policy_loss * response_mask).sum() / response_mask.sum()
            value_loss = ((curr_values_gen - returns.detach()) ** 2 * response_mask).sum() / response_mask.sum()
            # Masked mean of log(pi / pi_ref) over the sampled response tokens.
            kl_penalty = ((curr_lp_gen - ref_lp_gen.detach()) * response_mask).sum() / response_mask.sum()

            loss = policy_loss + config.value_loss_coef * value_loss + config.kl_coef * kl_penalty

            loss.backward()
            torch.nn.utils.clip_grad_norm_(clip_params, config.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Logging
        if global_step % config.log_interval == 0:
            wandb.log({"step": global_step, "loss": loss.item(), "reward": mean_reward, "kl": kl_penalty.item()}, step=global_step)

        # Eval
        if global_step % config.eval_interval == 0 and global_step > 0:
            val_reward = evaluate_model(policy_model, val_dataset, num_samples=50, desc=f"Eval@{global_step}")
            wandb.log({"val_reward": val_reward}, step=global_step)
            if val_reward > best_val_reward:
                best_val_reward = val_reward
                best_dir = os.path.join(config.output_dir, "best_model")
                os.makedirs(best_dir, exist_ok=True)
                policy_model.save_pretrained(best_dir)
                tokenizer.save_pretrained(best_dir)
                print(f"\nBest: {best_val_reward:.4f}")

        # Checkpoint
        if global_step % config.save_interval == 0 and global_step > 0:
            ckpt_dir = os.path.join(config.output_dir, f"checkpoint-{global_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            policy_model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)

        global_step += 1
        torch.cuda.empty_cache()

print(f"\nTraining complete! Best val: {best_val_reward:.4f}")

## Final Eval

In [None]:
# Evaluate on the whole validation split and compare with the pre-training baseline.
final_reward = evaluate_model(policy_model, val_dataset, num_samples=len(val_dataset), desc="Final")
improvement = final_reward - baseline_reward
print(f"Final reward: {final_reward:.4f}")
print(f"Improvement: {improvement:.4f}")

final_dir = os.path.join(config.output_dir, "final_model")
os.makedirs(final_dir, exist_ok=True)
for artifact in (policy_model, tokenizer):
    artifact.save_pretrained(final_dir)
print(f"Saved to {final_dir}")

summary = dict(
    baseline=baseline_reward,
    final=final_reward,
    best_val=best_val_reward,
    improvement=improvement,
)
summary_path = os.path.join(config.output_dir, "summary.json")
with open(summary_path, "w") as fh:
    json.dump(summary, fh, indent=2)

wandb.finish()
print(f"\nSummary: {summary}")