In [None]:
# @title Complete ReMax Training (3 Epochs / KL divergence / BF16)
'''
=====================================================================================================
This training script was originally developed and optimized for execution within Google Colab,
relying heavily on Google Drive for persistent storage, Colab-specific authentication mechanisms,
and other environment-dependent utilities. As a result, the initial implementation included
Drive-mounted checkpoint directories, CSV logging to Drive, and secret-based Hugging Face login via
Colab's userdata API. While these components streamlined experimentation within a Colab workflow,
they also made the script less portable and harder to reproduce in general compute environments
such as local machines, cloud VMs, or managed training clusters.

To run it elsewhere, refactor the current version to remove the above-mentioned Colab-specific assumptions,
replacing them with environment-agnostic paths, standard Hugging Face authentication, and fully
general dataset/model loading logic so the script can run consistently anywhere while retaining
the same behavior and training methodology.
=====================================================================================================
'''
# ==========================================
# 1. Cleanup & Install
# ==========================================

# Install/upgrade the runtime dependencies. The leading "!" is notebook shell
# syntax, so this line only works when the cell runs inside IPython/Jupyter/Colab.
# "-U" pulls the latest releases, which may differ from the versions recorded
# in the note below.
print("‚è≥ Installing libraries...")
!pip install -q -U transformers datasets accelerate huggingface_hub bitsandbytes
'''
The training was conducted using the following library versions at the time:
Accelerate: 0.28.0
Hugging Face Hub: 0.17.1
Transformers: 4.57.3
Pytorch: 2.9.0+cu126
Datasets: 4.4.1
Tokenizers: 0.22.1
Bitsandbytes: 0.48.2 (it was used as everything here was done in BF16, SFT and RM were loaded in BF16)

NOTE: TRL did not have a drop-in trainer for the ReMax method at the time, so this is a custom, from-scratch implementation of ReMax.
'''
import csv
import math
import os
import shutil

import torch
import torch.nn.functional as F
from datasets import load_dataset
from google.colab import userdata, drive
from huggingface_hub import login, HfApi
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler
)

# ==========================================
# 2. Setup & Login
# ==========================================
print("\nüìÇ Mounting Drive...")
drive.mount('/content/drive')

# Define Paths: logs, rolling checkpoints and per-epoch saves all live under
# one Google Drive folder, which this script uses as its persistent storage.
DRIVE_ROOT = "/content/drive/MyDrive/Qwen3-ReMax-Training"
LOG_FILE = f"{DRIVE_ROOT}/remax_logs.csv"
CHECKPOINT_DIR = f"{DRIVE_ROOT}/checkpoints"
EPOCH_DIR = f"{DRIVE_ROOT}/epoch_saves"

# Create the output directories. exist_ok=True replaces the separate
# "exists?" check, so a directory that appears in between is not an error.
for d in [CHECKPOINT_DIR, EPOCH_DIR]:
    os.makedirs(d, exist_ok=True)

# Init Log File (Added 'Phase' column for epochs).
# The header is only written when the file is missing, so an existing log is
# appended to rather than overwritten.
if not os.path.exists(LOG_FILE):
    with open(LOG_FILE, 'w') as f:
        f.write("Global_Step,Epoch,Phase,Loss,Advantage,Reward_Sample,Reward_Greedy,KL_Proxy\n")

# Hugging Face Login: try the Colab secret first, then fall back to the
# interactive prompt.
print("\nüîë Logging in...")
try:
    hf_token = userdata.get('HF_TOKEN')
    login(token=hf_token, add_to_git_credential=True)
    print("‚úÖ Logged in via Colab Secret.")
except Exception as exc:
    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt
    # and SystemExit still propagate; the actual reason is reported instead of
    # being silently discarded.
    print("‚ö†Ô∏è Secret 'HF_TOKEN' not found. Falling back to manual input.")
    print(f"   (reason: {type(exc).__name__}: {exc})")
    login(add_to_git_credential=True)

# ==========================================
# 3. Configuration
# ==========================================
SFT_MODEL_ID = "AIPlans/qwen3-0.6b-SFT-hs2" # loaded twice below: once as the trainable policy (actor) and once as the reference model
RM_MODEL_ID  = "AIPlans/qwen3-0.6b-RM-hs2" # reward model used to score sampled and greedy generations
OUTPUT_REPO  = "your-username/qwen3-0.6b-ReMax" # Hub repo the final model is pushed to; the name is arbitrary
DATASET_NAME = "Jennny/helpsteer2-helpfulness-preference" # a variant of the HelpSteer2 dataset containing only the helpfulness attribute
DEVICE = "cuda" # device that prompt batches are moved to before generation

# All hyperparameters can be modified as suitable (an A100 80GB was used at the time)
LR = 5e-7 # a lower learning rate, as there are 3 epochs
BETA = 0.1 # weight of the per-token KL proxy subtracted from the sampled reward
BATCH_SIZE = 16 # prompts per micro-batch (used for both train and validation loaders)
GRAD_ACCUMULATION = 2     # Effective Batch = 32
EPOCHS = 3
MAX_NEW_TOKENS = 128 # generation budget per prompt, for both sampled and greedy decoding
MAX_PROMPT_LENGTH = 1024 # prompts are truncated/padded to exactly this many tokens
LOGGING_STEPS = 20        # Log every 20 optimizer steps

# ==========================================
# 4. Load Models (BF16)
# ==========================================
print("\nüß† Loading Models...")
tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL_ID, trust_remote_code=True)
# NOTE: if the tokenizer has no pad token, EOS is reused as padding. The
# `!= tokenizer.pad_token_id` masks built later in this script then also mask
# out genuine EOS tokens in the generated text.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding keeps every prompt right-aligned, so generated tokens start at
# the same column for the whole batch.
tokenizer.padding_side = "left"

# 1. Actor (Trainable)
actor = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_ID, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
actor.config.use_cache = False

# 2. Reference (Frozen weights)
ref = AutoModelForCausalLM.from_pretrained(
    SFT_MODEL_ID, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
ref.eval()
# eval() alone only changes layer behaviour (e.g. dropout); explicitly turning
# off requires_grad is what actually makes the weights frozen.
ref.requires_grad_(False)

# 3. Reward (Frozen weights)
rm = AutoModelForSequenceClassification.from_pretrained(
    RM_MODEL_ID, num_labels=1, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
rm.eval()
rm.requires_grad_(False)

print("‚úÖ Models Loaded.")

# ==========================================
# 5. Dataset (Train/Val Split)
# ==========================================
print(f"Loading and Preparing dataset: {DATASET_NAME}...")
# Load and Filter: keep only pairs whose chosen response scored at least 3.
full_dataset = load_dataset(DATASET_NAME, split="train")
full_dataset = full_dataset.filter(lambda x: x["chosen_score"] >= 3)

# Split: 95% Train, 5% Validation (fixed seed so the split is reproducible)
split_dataset = full_dataset.train_test_split(test_size=0.05, seed=42)
train_data = split_dataset['train']
# Keep validation fast: at most 64 examples. The count is clamped because
# select(range(64)) raises when the held-out split holds fewer than 64 rows.
val_data = split_dataset['test']
val_data = val_data.select(range(min(64, len(val_data))))

def preprocess(examples):
    """Turn a batch of preference rows into tokenized generation prompts.

    Only the "chosen" column is read. Each entry is either a list of message
    dicts (the first message's 'content' is used) or something string-like
    (everything before the first "Assistant:" is used). The extracted text is
    wrapped as "User: ...\\n\\nAssistant:" and tokenized, padded and truncated
    to MAX_PROMPT_LENGTH tokens.

    Args:
        examples: batched mapping with a "chosen" column.

    Returns:
        The tokenizer output for the formatted prompts.
    """
    def _first_user_text(conversation):
        # Plain-string rows: cut at the first assistant marker.
        if not isinstance(conversation, list):
            return str(conversation).split("Assistant:")[0]
        # Message-list rows: the opening message carries the prompt.
        return conversation[0]['content']

    formatted = [
        f"User: {_first_user_text(conversation)}\n\nAssistant:"
        for conversation in examples["chosen"]
    ]
    return tokenizer(
        formatted,
        padding="max_length",
        truncation=True,
        max_length=MAX_PROMPT_LENGTH,
        return_tensors="pt",
    )

def _tokenize_and_wrap(raw_split, shuffle):
    """Tokenize one split with `preprocess` and wrap it in a DataLoader.

    The original columns are dropped, so only the tokenizer outputs remain;
    "input_ids" and "attention_mask" are exposed as torch tensors.
    """
    tokenized = raw_split.map(preprocess, batched=True, remove_columns=raw_split.column_names)
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
    return tokenized, DataLoader(tokenized, batch_size=BATCH_SIZE, shuffle=shuffle)


# Training prompts are reshuffled by the loader; validation order is fixed.
train_dataset, dataloader = _tokenize_and_wrap(train_data, shuffle=True)
val_dataset, val_dataloader = _tokenize_and_wrap(val_data, shuffle=False)

print(f"‚úÖ Dataset Split: {len(train_dataset)} Train | {len(val_dataset)} Validation")

# ==========================================
# 6. Optimizer
# ==========================================
# Only the actor's parameters are handed to the optimizer; the reference and
# reward models are never updated.
optimizer = AdamW(actor.parameters(), lr=LR)
# Planned number of optimizer updates: total micro-batches over all epochs
# divided by GRAD_ACCUMULATION, rounded up.
num_training_steps = math.ceil(len(dataloader) * EPOCHS / GRAD_ACCUMULATION)
# Cosine learning-rate schedule with 20 warmup steps, spanning the planned updates.
lr_scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=20, num_training_steps=num_training_steps)

# ==========================================
# 7. Helper Functions
# ==========================================
def get_log_probs(model, input_ids, attention_mask):
    """Return the log-probability the model assigns to each observed next token.

    The logits at position t are scored against the token at position t + 1,
    so the result has one column fewer than `input_ids`.

    Args:
        model: callable returning an object with a `.logits` attribute indexed
            as (batch, position, vocab).
        input_ids: token ids indexed as (batch, position).
        attention_mask: forwarded unchanged to the model.

    Returns:
        Tensor indexed as (batch, position - 1) with per-token log-probs.
        Padding positions are NOT masked here; callers apply their own mask.
    """
    all_logits = model(input_ids, attention_mask=attention_mask).logits
    # Drop the last position: there is no following token to score it against.
    predictive_logits = all_logits[:, :-1, :]
    next_tokens = input_ids[:, 1:].unsqueeze(-1)
    token_log_probs = F.log_softmax(predictive_logits, dim=-1).gather(-1, next_tokens)
    return token_log_probs.squeeze(-1)

def get_reward(texts):
    """Score raw text with the reward model.

    Args:
        texts: a string or list of strings; tokenized with padding and
            truncated to 2048 tokens.

    Returns:
        The reward model's logits with the trailing dimension squeezed away,
        computed without gradient tracking.
    """
    # The reward model is placed by device_map="auto", so follow its actual
    # device (as the training loop does) instead of the hard-coded DEVICE.
    rm_device = next(rm.parameters()).device
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=2048).to(rm_device)
    with torch.no_grad():
        return rm(**inputs).logits.squeeze(-1)

def run_evaluation(epoch, global_step):
    """Compute ReMax metrics on the validation loader and append them to the log.

    For every validation batch the actor produces one sampled and one greedy
    continuation, both are scored by the reward model, and the same
    loss/advantage/KL-proxy quantities as in training are computed (without
    gradients). Batch means are averaged over the batches and written as a
    single "Val" row to LOG_FILE.

    Sampling uses do_sample=True, so the reported numbers vary between calls.

    Args:
        epoch: zero-based epoch index; printed and logged as epoch + 1.
        global_step: optimizer step count to record in the log row.

    Returns:
        None. Side effects: prints a summary, appends to LOG_FILE, and leaves
        the actor in train mode.
    """
    print(f"\nüîç Running Validation (Epoch {epoch+1})...")
    actor.eval()

    val_loss, val_adv, val_r_sample, val_r_greedy, val_kl = 0, 0, 0, 0, 0
    steps = 0

    # Device lookups do not depend on the batch, so resolve them once.
    actor_device = next(actor.parameters()).device
    rm_device = next(rm.parameters()).device

    with torch.no_grad():
        for batch in val_dataloader:
            # --- Generation ---
            prompt_ids = batch["input_ids"].to(DEVICE)
            prompt_mask = batch["attention_mask"].to(DEVICE)

            sample_out = actor.generate(prompt_ids, attention_mask=prompt_mask, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, top_p=0.9, temperature=1.0, pad_token_id=tokenizer.pad_token_id)
            greedy_out = actor.generate(prompt_ids, attention_mask=prompt_mask, max_new_tokens=MAX_NEW_TOKENS, do_sample=False, pad_token_id=tokenizer.pad_token_id)

            # --- Rewards ---
            # Every non-pad token (prompt and continuation) is visible to the reward model.
            sample_mask = (sample_out != tokenizer.pad_token_id).long()
            greedy_mask = (greedy_out != tokenizer.pad_token_id).long()

            r_sample = rm(input_ids=sample_out.to(rm_device), attention_mask=sample_mask.to(rm_device)).logits.squeeze(-1).to(actor_device).to(torch.bfloat16)
            r_greedy = rm(input_ids=greedy_out.to(rm_device), attention_mask=greedy_mask.to(rm_device)).logits.squeeze(-1).to(actor_device).to(torch.bfloat16)

            # --- KL & Loss (For metrics only) ---
            sample_out = sample_out.to(actor_device)
            sample_mask = sample_mask.to(actor_device)

            log_probs_actor = get_log_probs(actor, sample_out, sample_mask)
            log_probs_ref = get_log_probs(ref, sample_out, sample_mask)

            # Log-prob column t scores token t + 1, so the first generated
            # token (index prompt_len) is scored at column prompt_len - 1.
            start = prompt_ids.shape[1] - 1
            min_len = min(log_probs_actor.shape[1], log_probs_ref.shape[1])

            if start < min_len:
                logp_gen = log_probs_actor[:, start:min_len]
                # Mask aligned with the scored tokens (shifted by one w.r.t. log-prob columns).
                valid_mask = sample_mask[:, start+1 : start+1 + logp_gen.shape[1]].to(actor_device).to(torch.bfloat16)

                logp_sum = (logp_gen * valid_mask).sum(dim=-1)

                # KL Proxy: mean per-token log-ratio between actor and reference.
                ref_gen = log_probs_ref[:, start:min_len]
                kl_sum = ((logp_gen - ref_gen) * valid_mask).sum(dim=-1)
                num_tokens = valid_mask.sum(dim=-1).clamp(min=1.0)
                kl_div_normalized = kl_sum / num_tokens

                advantage = (r_sample - BETA * kl_div_normalized) - r_greedy
                # Same clamp as the training objective, so "Val" and "Train"
                # rows in the log measure the same quantity.
                advantage = advantage.clamp(-10.0, 10.0)
                loss = - (logp_sum * advantage).mean()

                val_loss += loss.item()
                val_adv += advantage.mean().item()
                val_r_sample += r_sample.mean().item()
                val_r_greedy += r_greedy.mean().item()
                val_kl += kl_div_normalized.mean().item()
                steps += 1

    # Average metrics over the batches that contributed (all stay 0 if none did).
    if steps > 0:
        val_loss /= steps
        val_adv /= steps
        val_r_sample /= steps
        val_r_greedy /= steps
        val_kl /= steps

    print(f"üìä Validation Results: Loss: {val_loss:.4f} | Adv: {val_adv:.4f} | R_Sample: {val_r_sample:.4f}")

    # Log Validation Row
    with open(LOG_FILE, 'a') as f:
        f.write(f"{global_step},{epoch+1},Val,{val_loss:.4f},{val_adv:.4f},{val_r_sample:.4f},{val_r_greedy:.4f},{val_kl:.4f}\n")

    # Hand the actor back to the training loop in train mode.
    actor.train()

# ==========================================
# 8. Training Loop
# ==========================================
print("\nüöÄ Starting ReMax Loop...")
actor.train()
progress_bar = tqdm(range(num_training_steps))
optimizer.zero_grad()

acc_loss, acc_adv, acc_r_sample, acc_r_greedy, acc_kl = 0, 0, 0, 0, 0
global_step = 0      # optimizer updates performed so far
micro_step = 0       # micro-batches processed so far, counted ACROSS epochs
pending_micro = 0    # micro-batches accumulated since the last optimizer update
batches_per_epoch = len(dataloader)

# Model placement does not change during training, so resolve devices once.
actor_device = next(actor.parameters()).device
rm_device = next(rm.parameters()).device

# Gradient accumulation is driven by `micro_step`, not by the per-epoch `step`
# index. With the per-epoch index, an epoch whose batch count is not a multiple
# of GRAD_ACCUMULATION left stale gradients (and metric sums) behind that were
# merged into the next epoch's first update, and the number of updates no
# longer matched `num_training_steps`. The very last micro-batch always
# triggers an update so no gradient is dropped at the end of training.
for epoch in range(EPOCHS):
    print(f"\n--- Starting Epoch {epoch+1}/{EPOCHS} ---")

    for step, batch in enumerate(dataloader):

        # --- A. Generation ---
        prompt_ids = batch["input_ids"].to(DEVICE)
        prompt_mask = batch["attention_mask"].to(DEVICE)

        with torch.no_grad():
            # Stochastic sample: the trajectory whose log-prob is optimized.
            sample_out = actor.generate(
                prompt_ids, attention_mask=prompt_mask, max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True, top_p=0.9, temperature=1.0, pad_token_id=tokenizer.pad_token_id
            )
            # Greedy decode: its reward serves as the baseline.
            greedy_out = actor.generate(
                prompt_ids, attention_mask=prompt_mask, max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False, pad_token_id=tokenizer.pad_token_id
            )

        # --- B. Rewards ---
        sample_mask = (sample_out != tokenizer.pad_token_id).long()
        greedy_mask = (greedy_out != tokenizer.pad_token_id).long()

        with torch.no_grad():
            r_sample = rm(input_ids=sample_out.to(rm_device), attention_mask=sample_mask.to(rm_device)).logits.squeeze(-1)
            r_greedy = rm(input_ids=greedy_out.to(rm_device), attention_mask=greedy_mask.to(rm_device)).logits.squeeze(-1)

        r_sample = r_sample.to(actor_device).to(torch.bfloat16)
        r_greedy = r_greedy.to(actor_device).to(torch.bfloat16)

        # --- C. KL & Log Probs ---
        sample_out = sample_out.to(actor_device)
        sample_mask = sample_mask.to(actor_device)

        log_probs_actor = get_log_probs(actor, sample_out, sample_mask)
        with torch.no_grad():
            log_probs_ref = get_log_probs(ref, sample_out, sample_mask)

        # Log-prob column t scores token t + 1, so the first generated token
        # (index prompt_len) is scored at column prompt_len - 1.
        prompt_len = prompt_ids.shape[1]
        start = prompt_len - 1
        min_len = min(log_probs_actor.shape[1], log_probs_ref.shape[1])

        if start >= min_len:
            # Defensive fallback: nothing was generated. Multiplying by zero
            # (instead of creating fresh zeros) keeps the result attached to
            # the actor's graph, so loss.backward() below still works.
            logp_sum = log_probs_actor.sum(dim=-1) * 0.0
            kl_div_normalized = torch.zeros_like(logp_sum)
        else:
            logp_gen = log_probs_actor[:, start:min_len]
            # Mask aligned with the scored tokens (shifted by one w.r.t. log-prob columns).
            valid_mask = sample_mask[:, start+1 : start+1 + logp_gen.shape[1]].to(actor_device)

            valid_mask_f = valid_mask.to(torch.bfloat16)
            logp_gen_f = logp_gen.to(torch.bfloat16)

            logp_masked = logp_gen_f * valid_mask_f
            logp_sum = logp_masked.sum(dim=-1)

            # KL Proxy Calculation: mean per-token log-ratio actor vs. reference
            ref_gen_f = log_probs_ref[:, start:min_len].to(torch.bfloat16)
            kl_gen = (logp_gen_f - ref_gen_f) * valid_mask_f
            kl_sum = kl_gen.sum(dim=-1)
            num_tokens = valid_mask_f.sum(dim=-1).clamp(min=1.0)
            kl_div_normalized = kl_sum / num_tokens

        # --- D. Loss Calculation ---
        # Advantage = KL-penalized sampled reward minus the greedy baseline reward.
        advantage = (r_sample - BETA * kl_div_normalized) - r_greedy
        advantage = advantage.clamp(-10.0, 10.0)

        # The advantage is a constant weight; gradients flow only through logp_sum.
        loss = - (logp_sum * advantage.detach()).mean()
        loss = loss / GRAD_ACCUMULATION

        loss.backward()
        micro_step += 1
        pending_micro += 1

        # Undo the accumulation scaling so the logged loss is per micro-batch.
        acc_loss += loss.item() * GRAD_ACCUMULATION
        acc_adv += advantage.mean().item()
        acc_r_sample += r_sample.mean().item()
        acc_r_greedy += r_greedy.mean().item()
        acc_kl += kl_div_normalized.mean().item()

        # --- E. Optimizer Step ---
        is_final_micro_batch = (epoch == EPOCHS - 1) and (step == batches_per_epoch - 1)
        if micro_step % GRAD_ACCUMULATION == 0 or is_final_micro_batch:
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            global_step += 1

            # Log Training (means over the micro-batches of this update;
            # pending_micro equals GRAD_ACCUMULATION except for a final partial group)
            if global_step % LOGGING_STEPS == 0:
                mean_loss = acc_loss / pending_micro
                mean_adv = acc_adv / pending_micro
                mean_r_sample = acc_r_sample / pending_micro
                mean_r_greedy = acc_r_greedy / pending_micro
                mean_kl = acc_kl / pending_micro
                with open(LOG_FILE, 'a') as f:
                    f.write(f"{global_step},{epoch+1},Train,{mean_loss},{mean_adv},{mean_r_sample},{mean_r_greedy},{mean_kl}\n")

                print(f"Step {global_step} | Loss: {mean_loss:.4f} | Adv: {mean_adv:.4f} | R_Sample: {mean_r_sample:.4f}")

            acc_loss, acc_adv, acc_r_sample, acc_r_greedy, acc_kl = 0, 0, 0, 0, 0
            pending_micro = 0

            # Checkpoint Rotation for storage efficiency in drive:
            # save every 50 updates and keep only the two most recent checkpoints.
            if global_step > 0 and global_step % 50 == 0:
                ckpt_path = f"{CHECKPOINT_DIR}/step_{global_step}"
                actor.save_pretrained(ckpt_path)
                tokenizer.save_pretrained(ckpt_path)
                print(f"üíæ Checkpoint saved: {ckpt_path}")

                all_ckpts = sorted([os.path.join(CHECKPOINT_DIR, d) for d in os.listdir(CHECKPOINT_DIR) if d.startswith("step_")], key=os.path.getmtime)
                if len(all_ckpts) > 2:
                    for old_ckpt in all_ckpts[:-2]:
                        shutil.rmtree(old_ckpt)

    # --- End of Epoch Validation & Save ---
    run_evaluation(epoch, global_step) # Run Validation Suite

    epoch_save_path = f"{EPOCH_DIR}/epoch_{epoch+1}"
    print(f"üéâ Epoch {epoch+1} Complete! Saving to {epoch_save_path}...")
    actor.save_pretrained(epoch_save_path)
    tokenizer.save_pretrained(epoch_save_path)

progress_bar.close()

# ==========================================
# 9. Final Save & Push
# ==========================================
# Upload the final actor weights and tokenizer to the Hub repo named by OUTPUT_REPO.
print("\n‚òÅÔ∏è Pushing BF16 Model to Hub...")
api = HfApi()
# exist_ok=True: an already existing repo is not treated as an error.
api.create_repo(repo_id=OUTPUT_REPO, exist_ok=True)
actor.push_to_hub(OUTPUT_REPO)
tokenizer.push_to_hub(OUTPUT_REPO)

print(f"‚úÖ ReMax Training Complete! Model: https://huggingface.co/{OUTPUT_REPO}")