# Password Game RL Training with VERL

Train Qwen3-0.6B to solve the Password Game using PPO.

**Rules**: 9 progressive password rules  
**Reward**: +1 per rule passed, -0.1 per character

In [None]:
!nvidia-smi -L

In [None]:
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip install -q flash-attn --no-build-isolation
!pip install -q transformers accelerate datasets tokenizers wandb tqdm numpy

In [None]:
!uv pip install google-genai openai

In [None]:
import os, json, random, time, re
from dataclasses import dataclass, asdict
from typing import List, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
import wandb
from tqdm.auto import tqdm

# Report the runtime and pick the compute device; the notebook requires a GPU.
print(f"PyTorch: {torch.__version__}")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"CUDA: {DEVICE == 'cuda'}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

assert DEVICE == "cuda", "GPU required"

## Config

In [None]:
@dataclass
class Config:
    """All hyper-parameters and paths for one PPO training run.

    Instantiating the class creates `output_dir` on disk and fills in a
    timestamped `wandb_run_name` when none is given.
    """

    # --- Model ---
    model_name: str = "Qwen/Qwen3-0.6B"
    precision: str = "bfloat16"
    use_flash_attn: bool = True

    # --- Optimisation ---
    num_epochs: int = 3
    num_steps_per_epoch: int = 100
    batch_size: int = 4
    samples_per_prompt: int = 4
    learning_rate: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    warmup_steps: int = 50

    # --- PPO ---
    ppo_epochs: int = 4
    clip_range: float = 0.2
    value_loss_coef: float = 0.1
    kl_coef: float = 0.05
    gamma: float = 0.99
    gae_lambda: float = 0.95
    normalize_advantages: bool = True

    # --- Sampling ---
    max_prompt_length: int = 1024
    max_new_tokens: int = 256
    temperature: float = 0.8
    top_p: float = 0.9
    top_k: int = 50

    # --- Password Game ---
    num_rules: int = 9
    reward_per_rule: float = 1.0
    length_penalty: float = 0.1

    # --- Dataset sizes ---
    num_train_samples: int = 1000
    num_val_samples: int = 200

    # --- Logging / checkpoints ---
    wandb_project: str = "password-game-rl"
    wandb_run_name: Optional[str] = None
    log_interval: int = 10
    eval_interval: int = 50
    save_interval: int = 100
    # The timestamp is evaluated once, when the class body runs, so every
    # Config() created afterwards shares this default directory.
    output_dir: str = f"./password_game_{int(time.time())}"
    seed: int = 42

    def __post_init__(self):
        """Create the output directory and default the W&B run name."""
        os.makedirs(self.output_dir, exist_ok=True)
        if self.wandb_run_name is None:
            self.wandb_run_name = f"password_ppo_{int(time.time())}"

config = Config()
effective_batch = config.batch_size * config.samples_per_prompt
print(f"Model: {config.model_name}")
print(f"Batch: {config.batch_size} x {config.samples_per_prompt} = {effective_batch}")
print(f"Output: {config.output_dir}")

# Persist the resolved configuration next to the run's outputs.
config_path = os.path.join(config.output_dir, "config.json")
with open(config_path, "w") as f:
    json.dump(asdict(config), f, indent=2)

In [None]:
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random generators with `seed`.

    CUDA generators are seeded as well when a CUDA device is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(config.seed)

## Models

In [None]:
# Map the configured precision name to a torch dtype; any value other than
# "bfloat16" falls back to float32.
dtype = {"bfloat16": torch.bfloat16}.get(config.precision, torch.float32)

tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # Reuse EOS as the padding token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer: {len(tokenizer)} tokens")

In [None]:
# Trainable policy network.
policy_model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    attn_implementation="eager" if not config.use_flash_attn else "flash_attention_2",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=dtype,
)
# KV caching is switched on only around generation; keep it off by default.
policy_model.config.use_cache = False
n_policy_params = sum(weight.numel() for weight in policy_model.parameters())
print(f"Policy: {n_policy_params/1e9:.2f}B params")

In [None]:
# Frozen copy of the starting checkpoint, used as the KL reference.
reference_model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    attn_implementation="eager" if not config.use_flash_attn else "flash_attention_2",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=dtype,
)
reference_model.eval()
for param in reference_model.parameters():
    # No gradients are ever needed for the reference weights.
    param.requires_grad = False
print("Reference: frozen")

In [None]:
class ValueHead(nn.Module):
    """Linear head that maps each hidden-state vector to one scalar value."""

    def __init__(self, hidden_size: int):
        super().__init__()
        projection = nn.Linear(hidden_size, 1)
        # Small orthogonal weights and a zero bias keep initial values near 0.
        nn.init.orthogonal_(projection.weight, gain=0.01)
        nn.init.constant_(projection.bias, 0.0)
        self.linear = projection

    def forward(self, hidden_states):
        """Project `hidden_states` over its last dimension to size 1."""
        values = self.linear(hidden_states)
        return values

value_head = ValueHead(policy_model.config.hidden_size).to(DEVICE).to(dtype)
print(f"Value head: {policy_model.config.hidden_size}")

## Password Game Environment

In [None]:
PASSWORD_RULES = [


In [None]:
class PasswordDataset(Dataset):
    """Prompts for the password game, each asking for a random number of rules.

    Every sample is a dict with the prompt text (built by `format_prompt`) and
    the number of rules that prompt covers.
    """

    def __init__(self, num_samples: int, max_rules: int):
        """Generate `num_samples` prompts covering at most `max_rules` rules.

        The rule count of each sample is drawn uniformly from 3 up to
        `min(max_rules, len(PASSWORD_RULES))`; when fewer than 3 rules are
        available the lower bound shrinks to match instead of making
        `random.randint` fail on an empty range.

        Raises:
            ValueError: if no rule at all is available.
        """
        upper = min(max_rules, len(PASSWORD_RULES))
        if upper < 1:
            raise ValueError(f"max_rules must allow at least one rule, got {max_rules}")
        lower = min(3, upper)

        self.prompts = []
        self.num_rules = []
        for _ in range(num_samples):
            n = random.randint(lower, upper)
            self.prompts.append(format_prompt(n))
            self.num_rules.append(n)

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {'prompt': self.prompts[idx], 'num_rules': self.num_rules[idx]}

# Take the rule budget from the config instead of repeating the literal 9.
train_dataset = PasswordDataset(config.num_train_samples, config.num_rules)
val_dataset = PasswordDataset(config.num_val_samples, config.num_rules)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
    "Your password must include one of our sponsors: (Pepsi, Starbucks, Shell) - case insensitive",
    "The roman numerals in your password should multiply to 35.",
    "Your password must include this CAPTCHA: {captcha}",
    "Your password must include today's Wordle answer.",
    "Your password must include a two letter symbol from the periodic table.",
    "Your password must include the current phase of the moon as an emoji",
    "Your password must include the name of this country: {country}",
    "Your password must include a leap year.",
    "🥚 ← This is my chicken Paul. He hasn't hatched yet, please put him in your password and keep him safe.",
    "The elements in your password must have atomic numbers that add up to 200.",
    "Your password is not strong enough 🏋️‍♂️",
    "Your password must contain one of the following affirmations: (I am loved, I am worthy, I am enough)",
    "Paul has hatched! Please don't forget to feed him, he eats three 🐛.",
    "A sacrifice must be made. Pick two letters that you will no longer be able to use.",
    "Your password must include this color in hex. green",
    "Your password must include the length of your password.",
    "The length of your password must be a prime number.",
    "Uh lets skip this one",
    "Your password must contain three consecutive letters that appear in alphabetical order (like 'abc', 'hij', 'xyz')"
]

# System prompt handed to the player model. It is returned verbatim by the
# game's instruction accessors; the braces are literal text and the string is
# not passed through str.format in this file.
instructions = """
You are playing a password game. After each step, you will be given a rule and the password has to comply to that rule and all the previous rules.
Keep thinking and only submit as final answer if you are confident it's correct.
If you can't get it, ask for a hint.
You are allowed to use search.
Goal is minimum length of final password.

GOAL: When given a rule, in your answer, only give me the password string you have that satisfies the rule.

Some clarifications as you may need them:
1. Only capital letters count towards Roman Numbers - if it's a valid roman number.
2. For a substring to count as an element, first (or only) letter of the symbol has to be capitalized.
3. For other answers, you can play around with capitalization or not to satisfy the given rules. Checks are case insensitive otherwise - not for roman or elements.
4. If you have a conflict, state that and say you want to end the game.

RESPONSE GUIDELINES:
respond in markdown in following format: 
### Password
{password}
### Giveup
{True if you want to give up, false otherwise}
"""

class PasswordGame:
    def __init__(self):
        """Start a new game at the first rule and obtain its per-game values."""
        # Index of the rule currently being played.
        self.current_rule = 0
        self.game_active = True
        # Values the CAPTCHA, country, Wordle and moon-phase checks compare
        # against. The four helpers are defined outside this class.
        # NOTE(review): if these helpers draw from the seeded RNG, their call
        # order affects reproducibility — keep the order unless confirmed.
        self.captcha = create_captcha()
        self.country = random_country()
        self.wordle_answer = get_wordle_answer()
        self.moon_phase = get_current_moon_phase()
        # Passwords recorded by step(), oldest first.
        self.password_history = []


    def get_current_rule(self) -> Optional[str]:
        """Return the current rule's text, or None once the game is over.

        The `{captcha}` / `{country}` placeholder, if present, is filled with
        this game's value.
        """
        if self.current_rule >= len(rules) or not self.game_active:
            return None

        template = rules[self.current_rule]
        if "{captcha}" in template:
            text = template.format(captcha=self.captcha)
        elif "{country}" in template:
            text = template.format(country=self.country)
        else:
            text = template
        return text

    def get_all_rules_up_to_current(self) -> List[str]:
        """Return the text of every rule from the first through the current one.

        Placeholders for the CAPTCHA and country are filled in.
        """
        def fill(template: str) -> str:
            # A rule carries at most one of the two placeholders.
            if "{captcha}" in template:
                return template.format(captcha=self.captcha)
            if "{country}" in template:
                return template.format(country=self.country)
            return template

        return [fill(template) for template in rules[:self.current_rule + 1]]

    def advance_rule(self):
        """Move to the next rule; deactivate the game after the last one."""
        self.current_rule += 1
        ran_out_of_rules = self.current_rule >= len(rules)
        if ran_out_of_rules:
            self.game_active = False

    def end_game(self):
        """Mark the game as finished by clearing `game_active`."""
        self.game_active = False

    def calculate_reward(self, password: str) -> float:
        """Calculate reward: +1 per passing rule, -0.1 per character."""
        satisfied_rules = 0

        # Check all rules up to current rule (inclusive when game ends)
        rule_count = self.current_rule if self.game_active else len(rules)

        for i in range(rule_count):
            if self._check_rule(password, i):
                satisfied_rules += 1

        # +1 per passing rule, -0.1 per character
        rule_score = satisfied_rules
        length_penalty = len(password) * 0.1
        total_reward = rule_score - length_penalty

        return round(total_reward, 1)

    def get_rule_feedback(self, password: str) -> Dict:
        """Report, rule by rule, whether `password` passes.

        Returns a dict with the password, its length, one entry per checked
        rule (index, filled-in text, pass flag), the number of passing rules
        and the reward from `calculate_reward`. While the game is active the
        current rule is included; afterwards all rules are.
        """
        length = len(password)
        checked_count = (self.current_rule + 1) if self.game_active else len(rules)

        entries = []
        for index in range(checked_count):
            passes = self._check_rule(password, index)
            text = rules[index]
            if "{captcha}" in text:
                text = text.format(captcha=self.captcha)
            elif "{country}" in text:
                text = text.format(country=self.country)
            entries.append({
                "rule_index": index,
                "rule_text": text,
                "passes": passes
            })

        return {
            "password": password,
            "length": length,
            "rules_checked": entries,
            "total_passing": sum(1 for entry in entries if entry["passes"]),
            "reward": self.calculate_reward(password),
        }

    def _check_rule(self, password: str, rule_index: int) -> bool:
        """Comprehensive rule checking for all password rules."""
        if rule_index == 0:  # At least 5 characters
            return len(password) >= 5

        elif rule_index == 1:  # Include a number
            return any(c.isdigit() for c in password)

        elif rule_index == 2:  # Include uppercase letter
            return any(c.isupper() for c in password)

        elif rule_index == 3:  # Include special character
            return any(not c.isalnum() for c in password)

        elif rule_index == 4:  # Digits sum to 25
            digit_sum = sum(int(c) for c in password if c.isdigit())
            return digit_sum == 25

        elif rule_index == 5:  # Include month
            months = ['january', 'february', 'march', 'april', 'may', 'june',
                     'july', 'august', 'september', 'october', 'november', 'december']
            return any(month in password.lower() for month in months)

        elif rule_index == 6:  # Include roman numeral
            roman_pattern = r'[IVXLCDM]+'
            return bool(re.search(roman_pattern, password))

        elif rule_index == 7:  # Include sponsor
            sponsors = ['pepsi', 'starbucks', 'shell']
            return any(sponsor in password.lower() for sponsor in sponsors)

        elif rule_index == 8:  # Roman numerals multiply to 35
            return self._check_roman_multiply(password, 35)

        elif rule_index == 9:  # Include CAPTCHA
            return self.captcha in password.lower()

        elif rule_index == 10:  # Include Wordle answer
            return self.wordle_answer.lower() in password.lower()

        elif rule_index == 11:  # Include periodic element
            return self._check_periodic_element(password)

        elif rule_index == 12:  # Include moon phase emoji
            return self.moon_phase in password

        elif rule_index == 13:  # Include country (dynamic)
            return self.country.lower() in password.lower()

        elif rule_index == 14:  # Include leap year
            return self._check_leap_year(password)

        elif rule_index == 15:  # Include Paul (egg emoji)
            return '🥚' in password

        elif rule_index == 16:  # Atomic numbers sum to 200
            return self._check_atomic_sum(password, 200)

        elif rule_index == 17:  # Not strong enough (always fails)
            return False

        elif rule_index == 18:  # Include affirmation
            affirmations = ['i am loved', 'i am worthy', 'i am enough']
            return any(affirmation in password.lower() for affirmation in affirmations)

        elif rule_index == 19:  # Paul eats 3 bugs
            return password.count('🐛') == 3

        elif rule_index == 20:  # Sacrifice two letters
            # This would need state tracking - simplified for now
            return True

        elif rule_index == 21:  # Include green hex
            green_hex_pattern = r'#00[89ab]000|#008000'
            return bool(re.search(green_hex_pattern, password.lower()))

        elif rule_index == 22:  # Include password length
            return str(len(password)) in password

        elif rule_index == 23:  # Length is prime
            return self._is_prime(len(password))

        elif rule_index == 24:  # Skip this one
            return True

        elif rule_index == 25:  # Include 3 consecutive chars    
            # Normalize to lowercase for case-insensitive checking
            pwd = password.lower()
    
            for i in range(len(pwd) - 2):
                triplet = pwd[i:i+3]
        
                # Must be three letters
                if triplet.isalpha():
                    # Check if they're consecutive: a->b->c, etc.
                    if ord(triplet[2]) - ord(triplet[0]) == 2:
                        # Verify the middle character is exactly +1
                        if ord(triplet[1]) - ord(triplet[0]) == 1:
                            return True
            
            return False

    def get_game_state(self) -> Dict:
        """Return the full game state, including every per-game answer value."""
        index = self.current_rule
        current = self.get_current_rule()
        rules_so_far = self.get_all_rules_up_to_current()
        state = {
            "current_rule_index": index,
            "current_rule": current,
            "all_rules": rules_so_far,
            "game_active": self.game_active,
            "instructions": instructions,
        }
        # Per-game values that the rule checks compare against.
        state["captcha"] = self.captcha
        state["country"] = self.country
        state["wordle_answer"] = self.wordle_answer
        state["moon_phase"] = self.moon_phase
        return state

    def get_instructions(self) -> str:
        """Return the module-level `instructions` prompt text."""
        return instructions

    def get_minimal_game_state(self) -> Dict:
        """Return the game state, revealing per-game values only once needed.

        The CAPTCHA is included from rule index 9 onwards and the country
        from rule index 13 onwards.
        """
        state = {
            "current_rule_index": self.current_rule,
            "current_rule": self.get_current_rule(),
            "all_rules": self.get_all_rules_up_to_current(),
            "game_active": self.game_active
        }

        # (state key / attribute name, first rule index at which it is exposed)
        reveal_from = (("captcha", 9), ("country", 13))
        for key, first_index in reveal_from:
            if self.current_rule >= first_index:
                state[key] = getattr(self, key)

        return state

    def _check_roman_multiply(self, password: str, target: int) -> bool:
        """Check if roman numerals in password multiply to target."""
        roman_values = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
        roman_pattern = r'[IVXLCDM]+'
        romans = re.findall(roman_pattern, password)

        if not romans:
            return False

        product = 1
        for roman in romans:
            value = self._roman_to_int(roman)
            if value > 0:
                product *= value

        return product == target

    def _roman_to_int(self, roman: str) -> int:
        """Convert roman numeral to integer."""
        roman_values = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
        total = 0
        prev_value = 0

        for char in reversed(roman):
            value = roman_values.get(char, 0)
            if value < prev_value:
                total -= value
            else:
                total += value
            prev_value = value

        return total

    def _check_periodic_element(self, password: str) -> bool:
        """Check for periodic table elements (first letter capitalized)."""
        elements = [
            'He', 'Li', 'Be', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'Cl', 'Ar', 'Ca', 'Sc', 'Ti', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
            'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd',
            'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd',
            'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb',
            'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm',
            'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
        ]
        return any(element in password for element in elements)

    def _check_leap_year(self, password: str) -> bool:
        """Check for leap years in password."""
        numbers = re.findall(r'\d{4}', password)
        for num_str in numbers:
            year = int(num_str)
            if self._is_leap_year(year):
                return True
        return False

    def _is_leap_year(self, year: int) -> bool:
        """Check if year is leap year."""
        return (year % 4 == 0 and year % 100 != 0) or (year % 400 == 0)

    def _check_atomic_sum(self, password: str, target: int) -> bool:
        """Check if atomic numbers of elements sum to target."""
        element_atomic = {
            'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10,
            'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
            'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25, 'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29, 'Zn': 30,
            'Ga': 31, 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, 'Rb': 37, 'Sr': 38, 'Y': 39, 'Zr': 40,
            'Nb': 41, 'Mo': 42, 'Tc': 43, 'Ru': 44, 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, 'Sn': 50,
            'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, 'Cs': 55, 'Ba': 56, 'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60,
            'Pm': 61, 'Sm': 62, 'Eu': 63, 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67, 'Er': 68, 'Tm': 69, 'Yb': 70,
            'Lu': 71, 'Hf': 72, 'Ta': 73, 'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79, 'Hg': 80,
            'Tl': 81, 'Pb': 82, 'Bi': 83, 'Po': 84, 'At': 85, 'Rn': 86, 'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90,
            'Pa': 91, 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95, 'Cm': 96, 'Bk': 97, 'Cf': 98, 'Es': 99, 'Fm': 100,
            'Md': 101, 'No': 102, 'Lr': 103, 'Rf': 104, 'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108, 'Mt': 109, 'Ds': 110,
            'Rg': 111, 'Cn': 112, 'Nh': 113, 'Fl': 114, 'Mc': 115, 'Lv': 116, 'Ts': 117, 'Og': 118
        }

        total_atomic = 0
        for element, atomic_num in element_atomic.items():
            if element in password:
                total_atomic += atomic_num

        return total_atomic == target

    def step(self, password: str=None, give_up:bool=False):
        """
        The main interaction function for the RL environment.
        You submit a password, and it returns the new state.

        Behaviour:
        - `give_up=True`: the game ends and the final result dict is returned.
        - No password has ever been submitted: the initial state (current
          rule plus instructions) is returned without advancing.
        - Otherwise the password is recorded, the game advances one rule and
          either the minimal game state or, after the last rule, the final
          result dict is returned.

        When the game ends and `password` is None, the most recently recorded
        password (or "" if there is none) is scored instead of failing on
        `len(None)`.
        NOTE(review): calling step() without a password after at least one
        submission still advances a rule — confirm this is intended.
        """

        def final_result(gave_up: bool) -> Dict:
            # Score the submitted password, falling back to the latest one on record.
            scored = password
            if scored is None:
                scored = self.password_history[-1] if self.password_history else ""
            return {
                "game_ended": True,
                "gave_up": gave_up,
                "reward": self.calculate_reward(scored),
                "final_password": scored,
                "rule_feedback": self.get_rule_feedback(scored),
            }

        if give_up:
            self.end_game()
            return final_result(gave_up=True)

        if password is not None:
            self.password_history.append(password)

        if not self.password_history:
            return {
                "current_rule_index": self.current_rule,
                "current_rule": self.get_current_rule(),
                "game_active": self.game_active,
                "instructions": self.get_instructions(),
            }

        # Advance to next rule
        self.advance_rule()

        # Check if game ended naturally
        if not self.game_active:
            return final_result(gave_up=False)

        return self.get_minimal_game_state()
        

    def _is_prime(self, n: int) -> bool:
        """Check if number is prime."""
        if n < 2:
            return False
        if n == 2:
            return True
        if n % 2 == 0:
            return False

        for i in range(3, int(n**0.5) + 1, 2):
            if n % i == 0:
                return False
        return True

## Baseline Eval

In [None]:
def evaluate_model(model, dataset, num_samples=100, desc="Eval"):
    """Return the mean reward of `model` over the first samples of `dataset`.

    For each sample one completion is drawn with the sampling settings from
    `config`; the first whitespace-delimited token of the completion is taken
    as the password and scored with `compute_reward`.

    Args:
        model: causal LM exposing `generate`, `config.use_cache`, `training`.
        dataset: indexable collection of dicts with 'prompt' and 'num_rules'.
        num_samples: upper bound on the number of samples evaluated.
        desc: progress-bar label.

    Returns:
        Mean reward as a float; 0.0 when there is nothing to evaluate.

    The model's train/eval mode and `use_cache` flag are put back to what they
    were on entry, even if generation raises.
    """
    n_eval = min(num_samples, len(dataset))
    if n_eval <= 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    was_training = model.training
    previous_use_cache = model.config.use_cache
    model.eval()
    model.config.use_cache = True

    total_reward = 0.0
    try:
        with torch.no_grad():
            for i in tqdm(range(n_eval), desc=desc):
                item = dataset[i]
                prompt = item['prompt']
                num_rules = item['num_rules']

                inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True, max_length=config.max_prompt_length).to(DEVICE)
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=config.max_new_tokens,
                    do_sample=True,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    top_k=config.top_k,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

                # Decode only the newly generated tokens (everything after the prompt).
                generated = tokenizer.decode(outputs[0, inputs.input_ids.size(1):], skip_special_tokens=True)
                words = generated.split()
                password = words[0] if words else ""
                total_reward += compute_reward(password, num_rules)
    finally:
        model.config.use_cache = previous_use_cache
        model.train(was_training)

    return total_reward / n_eval

baseline_reward = evaluate_model(policy_model, val_dataset, num_samples=100, desc="Baseline")
print(f"\nBaseline reward: {baseline_reward:.4f}")

## PPO Utils

In [None]:
def compute_log_probs(model, input_ids, attention_mask, return_values=False):
    """Return the log-probability `model` assigns to each next token.

    Entry j of the result is the log-probability of `input_ids[:, j + 1]`
    given the preceding tokens, so the result is one column shorter than
    `input_ids`. Positions whose target token is masked out by
    `attention_mask` are set to zero.

    With `return_values=True` the model is asked for hidden states and a
    second tensor is returned: the module-level `value_head` applied to the
    last hidden-state layer, one value per input position.
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=return_values)

    # Logits at position t score the token at position t + 1.
    targets = input_ids[:, 1:]
    shifted_log_probs = F.log_softmax(outputs.logits[:, :-1, :], dim=-1)
    token_log_probs = shifted_log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
    token_log_probs = token_log_probs * attention_mask[:, 1:].bool()

    if not return_values:
        return token_log_probs

    values = value_head(outputs.hidden_states[-1]).squeeze(-1)
    return token_log_probs, values

def compute_advantages(rewards, values, masks, gamma=0.99, gae_lambda=0.95):
    """Compute GAE advantages and returns over (batch, seq_len) tensors.

    Args:
        rewards: per-token rewards, shape (batch, seq_len).
        values: value estimates; indexed like `rewards` (assumed same shape).
        masks: 1 for valid tokens, 0 for padding; indexed like `rewards`.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.

    Returns:
        (advantages, returns) with `returns = advantages + values`.

    Bootstrapping from step t+1 and the recursive GAE term are both gated by
    the mask of step t+1, so a padded position never contributes its value
    estimate to the last valid token. Advantages at padded positions are zero.
    """
    batch_size, seq_len = rewards.shape
    advantages = torch.zeros_like(rewards)
    gae = 0
    for t in reversed(range(seq_len)):
        if t == seq_len - 1:
            # Nothing follows the final position.
            next_value, next_mask = 0, 0
        else:
            next_value, next_mask = values[:, t + 1], masks[:, t + 1]
        delta = rewards[:, t] + gamma * next_value * next_mask - values[:, t]
        gae = (delta + gamma * gae_lambda * next_mask * gae) * masks[:, t]
        advantages[:, t] = gae
    returns = advantages + values
    return advantages, returns

def whiten(values, mask):
    """Normalise `values` using the mean and variance of its unmasked entries.

    Statistics are computed only where `mask` is non-zero, but the shift and
    scale are applied to every entry of `values`.
    """
    n_valid = mask.sum()
    mean = (values * mask).sum() / n_valid
    centered = values - mean
    var = (centered ** 2 * mask).sum() / n_valid
    # The small constant keeps the division finite when the variance is 0.
    return centered / torch.sqrt(var + 1e-8)

print("PPO utils defined")

## Training

In [None]:
# One optimizer over both the policy weights and the value head.
optimizer = torch.optim.AdamW(
    list(policy_model.parameters()) + list(value_head.parameters()),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)
total_steps = config.num_epochs * config.num_steps_per_epoch
# The training loop calls scheduler.step() after every optimizer update, and
# it performs config.ppo_epochs updates per training step. Size the schedule
# in optimizer updates; sizing it in training steps would drive the learning
# rate to zero after only 1/ppo_epochs of the run.
total_optimizer_steps = total_steps * config.ppo_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=total_optimizer_steps)
print(f"Optimizer ready: {total_steps} steps")

In [None]:
wandb_run = wandb.init(project=config.wandb_project, name=config.wandb_run_name, config=asdict(config))
print(f"WandB: {wandb_run.get_url()}")

In [None]:
# PPO training loop: sample rollouts, score them, then run several clipped
# policy/value updates on the same rollouts.
policy_model.train()
value_head.train()
global_step = 0
best_val_reward = -float('inf')
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

print("Starting training...")

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch+1}/{config.num_epochs}")
    epoch_iter = iter(train_dataloader)

    for step in tqdm(range(config.num_steps_per_epoch), desc=f"Epoch {epoch+1}"):
        try:
            batch = next(epoch_iter)
        except StopIteration:
            # The dataloader ran out before the epoch's step budget; restart it.
            epoch_iter = iter(train_dataloader)
            batch = next(epoch_iter)

        prompts = batch['prompt']
        num_rules_list = batch['num_rules']

        # Rollout
        policy_model.eval()
        policy_model.config.use_cache = True
        with torch.no_grad():
            prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=config.max_prompt_length).to(DEVICE)
            prompt_len = prompt_inputs.input_ids.size(1)
            all_responses = []
            rollout_ids = []
            for _ in range(config.samples_per_prompt):
                outputs = policy_model.generate(
                    **prompt_inputs,
                    max_new_tokens=config.max_new_tokens,
                    do_sample=True,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    top_k=config.top_k,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
                for gen_ids in outputs[:, prompt_len:]:
                    resp = tokenizer.decode(gen_ids, skip_special_tokens=True)
                    # The password is the first whitespace-delimited token.
                    password = resp.strip().split()[0] if resp.strip() else ""
                    all_responses.append(password)
                rollout_ids.append(outputs)

            # Separate generate() calls can stop at different lengths, so
            # right-pad every round to the longest one before concatenating.
            max_len = max(ids.size(1) for ids in rollout_ids)
            all_full_ids = torch.cat(
                [F.pad(ids, (0, max_len - ids.size(1)), value=tokenizer.pad_token_id) for ids in rollout_ids],
                dim=0,
            )
            all_masks = (all_full_ids != tokenizer.pad_token_id).long()

            # The default collate turns 'num_rules' into a tensor, where `* k`
            # would multiply the values. Convert to a list first so the
            # per-prompt rule counts are repeated once per sampling round, in
            # the same order as all_responses.
            expanded_num_rules = [int(n) for n in num_rules_list] * config.samples_per_prompt

        # Rewards
        rewards = torch.tensor([compute_reward(pwd, nr) for pwd, nr in zip(all_responses, expanded_num_rules)], device=DEVICE, dtype=dtype)
        mean_reward = rewards.mean().item()

        # Old probs & values
        with torch.no_grad():
            old_log_probs, old_values = compute_log_probs(policy_model, all_full_ids, all_masks, return_values=True)
            ref_log_probs = compute_log_probs(reference_model, all_full_ids, all_masks)
            old_values_gen = old_values[:, prompt_len:]
            generated_ids_all = all_full_ids[:, prompt_len:]

        # Advantages
        response_mask = (generated_ids_all != tokenizer.pad_token_id).float()
        reward_per_token = torch.zeros_like(generated_ids_all, dtype=dtype)
        for i, reward in enumerate(rewards):
            # Spread each sequence's reward evenly over its non-pad tokens.
            valid = generated_ids_all[i] != tokenizer.pad_token_id
            reward_per_token[i][valid] = reward / valid.sum().clamp(min=1)
        advantages, returns = compute_advantages(reward_per_token, old_values_gen, response_mask, config.gamma, config.gae_lambda)
        if config.normalize_advantages:
            advantages = whiten(advantages, response_mask)

        # PPO updates
        policy_model.train()
        policy_model.config.use_cache = False
        for ppo_epoch in range(config.ppo_epochs):
            curr_log_probs, curr_values = compute_log_probs(policy_model, all_full_ids, all_masks, return_values=True)
            curr_values_gen = curr_values[:, prompt_len:]
            # Log-prob index j scores token j + 1, so generated tokens start
            # at index prompt_len - 1.
            curr_lp_gen = curr_log_probs[:, prompt_len-1:]
            old_lp_gen = old_log_probs[:, prompt_len-1:]
            ref_lp_gen = ref_log_probs[:, prompt_len-1:]

            ratio = torch.exp(curr_lp_gen - old_lp_gen.detach())
            policy_loss = torch.max(
                -advantages.detach() * ratio,
                -advantages.detach() * torch.clamp(ratio, 1-config.clip_range, 1+config.clip_range)
            )
            policy_loss = (policy_loss * response_mask).sum() / response_mask.sum()
            value_loss = ((curr_values_gen - returns.detach())**2 * response_mask).sum() / response_mask.sum()
            kl_penalty = ((curr_lp_gen - ref_lp_gen.detach()) * response_mask).sum() / response_mask.sum()

            loss = policy_loss + config.value_loss_coef * value_loss + config.kl_coef * kl_penalty

            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(policy_model.parameters()) + list(value_head.parameters()), config.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Logging (loss and kl are from the last PPO epoch of this step)
        if global_step % config.log_interval == 0:
            wandb.log({"step": global_step, "loss": loss.item(), "reward": mean_reward, "kl": kl_penalty.item()}, step=global_step)

        # Eval
        if global_step % config.eval_interval == 0 and global_step > 0:
            val_reward = evaluate_model(policy_model, val_dataset, num_samples=50, desc=f"Eval@{global_step}")
            wandb.log({"val_reward": val_reward}, step=global_step)
            if val_reward > best_val_reward:
                best_val_reward = val_reward
                best_dir = os.path.join(config.output_dir, "best_model")
                os.makedirs(best_dir, exist_ok=True)
                policy_model.save_pretrained(best_dir)
                tokenizer.save_pretrained(best_dir)
                print(f"\nBest: {best_val_reward:.4f}")

        # Checkpoint
        if global_step % config.save_interval == 0 and global_step > 0:
            ckpt_dir = os.path.join(config.output_dir, f"checkpoint-{global_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            policy_model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)

        global_step += 1
        torch.cuda.empty_cache()

print(f"\nTraining complete! Best val: {best_val_reward:.4f}")

## Final Eval

In [None]:
# Score the trained policy on the whole validation set and compare to baseline.
final_reward = evaluate_model(policy_model, val_dataset, num_samples=len(val_dataset), desc="Final")
improvement = final_reward - baseline_reward
print(f"Final reward: {final_reward:.4f}")
print(f"Improvement: {improvement:.4f}")

# Save the final weights and tokenizer together.
final_dir = os.path.join(config.output_dir, "final_model")
os.makedirs(final_dir, exist_ok=True)
for artifact in (policy_model, tokenizer):
    artifact.save_pretrained(final_dir)
print(f"Saved to {final_dir}")

summary = dict(
    baseline=baseline_reward,
    final=final_reward,
    best_val=best_val_reward,
    improvement=improvement,
)
summary_path = os.path.join(config.output_dir, "summary.json")
with open(summary_path, "w") as f:
    json.dump(summary, f, indent=2)

wandb.finish()
print(f"\nSummary: {summary}")