<a href="https://colab.research.google.com/github/Sidhtang/finetuning-gemma3-with-grpo-on-Anthropic-hh-rlhf-dataset/blob/main/finetuning_gemma_3_with_grpo_on_Anthropic_hh_rlhf_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install --upgrade transformers



In [1]:
!pip install datasets



In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from datasets import load_dataset
import gc

# Set environment variable to manage memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# First, let's optimize Colab's memory usage
def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()

# Login to Hugging Face (make sure to replace with your token)
def hf_login(token):
    try:
        login(token=token, add_to_git_credential=True)
        print("Successfully logged in to Hugging Face!")
        return True
    except Exception as e:
        print(f"Login failed. Error: {e}")
        return False

# Example usage (replace with your actual token)
if __name__ == "__main__":
    hf_token = ''
    login_result = hf_login(hf_token)

class GRPOConfig:
    """Configuration optimized for Colab T4 GPU"""
    def __init__(
        self,
        planning_steps: int = 2,  # Reduced from 3 to save memory
        critic_lr: float = 3e-5,  # Slightly lower learning rate
        reward_scale: float = 0.1,
        value_loss_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 1.0,
        ref_kl_threshold: float = 0.05,
        gradient_checkpointing: bool = True  # Added to save memory
    ):
        self.planning_steps = planning_steps
        self.critic_lr = critic_lr
        self.reward_scale = reward_scale
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ref_kl_threshold = ref_kl_threshold
        self.gradient_checkpointing = gradient_checkpointing

class CustomDataset(Dataset):
    """Dataset with optimized memory usage"""
    def __init__(self, dataset_name: str, tokenizer, max_length: int = 512):  # Reduced max_length
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        try:
            # Load only a subset of the data for Colab
            raw_data = load_dataset(dataset_name, split='train[:20%]')  # Only 20% of data

            for item in raw_data:
                if "chosen" in item:
                    text = item["chosen"]
                    if "\n\nHuman: " in text and "\n\nAssistant: " in text:
                        parts = text.split("\n\nAssistant: ")
                        prompt = parts[0].replace("\n\nHuman: ", "").strip()
                        completion = parts[1].strip()
                        self.data.append({
                            "prompt": prompt,
                            "completion": completion,
                            "reward": item.get("reward", 1.0)
                        })

            if len(self.data) == 0:
                raise ValueError(f"No valid data found in dataset {dataset_name}")

            print(f"Loaded {len(self.data)} examples from {dataset_name} (subset for Colab)")

        except Exception as e:
            print(f"Error loading dataset {dataset_name}: {str(e)}")
            raise

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = item["prompt"]
        completion = item["completion"]
        reward = item["reward"]

        text = f"User: {prompt}\n\nAssistant: {completion}"

        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        labels = encodings["input_ids"].clone()
        assistant_start = text.find("Assistant: ")
        if assistant_start != -1:
            assistant_tokens = self.tokenizer(text[:assistant_start], return_tensors="pt")
            assistant_token_len = assistant_tokens["input_ids"].shape[1]
            labels[0, :assistant_token_len] = -100

        return {
            "input_ids": encodings["input_ids"].squeeze(),
            "attention_mask": encodings["attention_mask"].squeeze(),
            "labels": labels.squeeze(),
            "reward": torch.tensor(reward, dtype=torch.float)
        }

class ValueHead(nn.Module):
    """Simplified value head for memory efficiency"""
    def __init__(self, hidden_size):
        super().__init__()
        self.value_head = nn.Linear(hidden_size, 1)  # Simplified architecture

    def forward(self, hidden_states):
        return self.value_head(hidden_states)

class GRPOTrainer:
    """Trainer optimized for Colab T4 GPU"""
    def __init__(
        self,
        model_name: str = "google/gemma-3-4b-it",
        dataset: Dataset = None,
        output_dir: str = "./gemini3-grpo",
        config: GRPOConfig = None,
        learning_rate: float = 1e-5,  # Lower learning rate
        batch_size: int = 2,  # Reduced batch size
        epochs: int = 2,  # Fewer epochs
        gradient_accumulation_steps: int = 8,  # Increased accumulation steps
        max_grad_norm: float = 1.0,
        warmup_steps: int = 50,  # Reduced warmup
        save_steps: int = 200,  # More frequent saves
        device: str = None,
        mixed_precision: bool = True  # Added mixed precision
    ):
        self.model_name = model_name
        self.dataset = dataset
        self.output_dir = output_dir
        self.config = config if config else GRPOConfig()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.warmup_steps = warmup_steps
        self.save_steps = save_steps
        self.mixed_precision = mixed_precision

        self.device = torch.device(device if device else "cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Enable TF32 for faster computations (on supported GPUs)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        os.makedirs(output_dir, exist_ok=True)

        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with optimizations
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if self.mixed_precision else torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True
        )

        if self.config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        # Smaller LoRA config for Colab
        lora_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=8,  # Reduced from 16
            lora_alpha=16,  # Reduced from 32
            target_modules=["q_proj", "v_proj"],  # Fewer target modules
            lora_dropout=0.05
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        self.value_head = ValueHead(self.model.config.hidden_size).to(self.device)

        # Reference model with same optimizations
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if self.mixed_precision else torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        for param in self.ref_model.parameters():
            param.requires_grad = False

        # Optimizers with weight decay
        self.optimizer = optim.AdamW(
            list(self.model.parameters()) + list(self.value_head.parameters()),
            lr=learning_rate,
            weight_decay=0.01
        )

        self.critic_optimizer = optim.AdamW(
            self.value_head.parameters(),
            lr=self.config.critic_lr,
            weight_decay=0.01
        )

        # Scaler for mixed precision
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

    def kl_divergence(self, logits_a, logits_b, attention_mask):
        logits_a = logits_a.view(-1, logits_a.size(-1))
        logits_b = logits_b.view(-1, logits_b.size(-1))
        attention_mask = attention_mask.view(-1)

        mask = attention_mask.bool()
        logits_a = logits_a[mask]
        logits_b = logits_b[mask]

        probs_a = torch.softmax(logits_a, dim=-1)
        probs_b = torch.softmax(logits_b, dim=-1)

        kl = torch.sum(probs_a * (torch.log(probs_a) - torch.log(probs_b + 1e-10)), dim=-1)
        return kl.mean()

    def train(self):
        print(f"Starting training on {self.device}")
        self.model.train()
        self.value_head.train()

        total_steps = 0
        best_loss = float('inf')

        for epoch in range(self.epochs):
            epoch_loss = 0.0
            epoch_policy_loss = 0.0
            epoch_value_loss = 0.0
            epoch_entropy = 0.0
            epoch_kl = 0.0

            progress_bar = tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")

            for batch_idx, batch in enumerate(progress_bar):
                # Move batch to the correct device
                batch = {k: v.to(self.device) for k, v in batch.items()}
                total_loss = 0.0

                # Forward pass with mixed precision
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.mixed_precision):
                    with torch.no_grad():
                        ref_outputs = self.ref_model(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            output_hidden_states=True
                        )
                        ref_logits = ref_outputs.logits
                        ref_hidden = ref_outputs.hidden_states[-1]

                    current_policy_logits = ref_logits.detach().clone()

                    for step in range(self.config.planning_steps):
                        outputs = self.model(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            output_hidden_states=True
                        )
                        logits = outputs.logits
                        hidden = outputs.hidden_states[-1]

                        values = self.value_head(hidden[:, -1, :]).squeeze(-1)
                        kl = self.kl_divergence(logits, ref_logits, batch["attention_mask"])

                        probs = torch.softmax(logits.view(-1, logits.size(-1)), dim=-1)
                        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1).mean()

                        advantages = batch["reward"] - values.detach()
                        policy_loss = -advantages * entropy
                        value_loss = 0.5 * ((values - batch["reward"]) ** 2).mean()

                        loss = (
                            policy_loss
                            + self.config.value_loss_coef * value_loss
                            - self.config.entropy_coef * entropy
                        )

                        if kl > self.config.ref_kl_threshold:
                            loss += kl

                        loss = loss / self.gradient_accumulation_steps
                        total_loss += loss.item()

                        # Scale loss for mixed precision
                        self.scaler.scale(loss).backward()

                        epoch_policy_loss += policy_loss.item()
                        epoch_value_loss += value_loss.item()
                        epoch_entropy += entropy.item()
                        epoch_kl += kl.item()

                        current_policy_logits = logits.detach()

                # Gradient accumulation and update
                if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                    # Unscale gradients before clipping
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        list(self.model.parameters()) + list(self.value_head.parameters()),
                        self.max_grad_norm
                    )

                    # Step with scaler
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                    # Clear memory
                    clear_memory()

                    progress_bar.set_postfix({
                        "loss": total_loss,
                        "policy_loss": epoch_policy_loss / (batch_idx + 1),
                        "value_loss": epoch_value_loss / (batch_idx + 1),
                        "kl": epoch_kl / (batch_idx + 1)
                    })

                    total_steps += 1

                    # Save checkpoint more frequently
                    if total_steps % self.save_steps == 0:
                        current_loss = total_loss / (batch_idx + 1)
                        if current_loss < best_loss:
                            best_loss = current_loss
                            self.save_checkpoint(f"checkpoint-best")
                        self.save_checkpoint(f"checkpoint-step-{total_steps}")

            # Save after each epoch
            self.save_checkpoint(f"checkpoint-epoch-{epoch+1}")

            print(f"Epoch {epoch+1} stats:")
            print(f"  Policy loss: {epoch_policy_loss / len(self.dataloader):.4f}")
            print(f"  Value loss: {epoch_value_loss / len(self.dataloader):.4f}")
            print(f"  Entropy: {epoch_entropy / len(self.dataloader):.4f}")
            print(f"  KL: {epoch_kl / len(self.dataloader):.4f}")

        self.save_checkpoint("final-model")
        print("Training complete!")

    def save_checkpoint(self, checkpoint_name: str):
        checkpoint_dir = os.path.join(self.output_dir, checkpoint_name)
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save with safetensors for efficiency
        self.model.save_pretrained(checkpoint_dir, safe_serialization=True)
        self.tokenizer.save_pretrained(checkpoint_dir)
        torch.save(self.value_head.state_dict(), os.path.join(checkpoint_dir, "value_head.pt"))
        torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))

        print(f"Saved checkpoint to {checkpoint_dir}")

def main():
    # Colab-optimized parameters
    model_name = "google/gemma-3-4b-it"
    dataset_name = "Anthropic/hh-rlhf"
    output_dir = "./gemini3-grpo"
    learning_rate = 1e-5
    batch_size = 2
    epochs = 2
    gradient_accumulation_steps = 8
    planning_steps = 2
    save_steps = 200

    # Clear memory before starting
    clear_memory()

    # Configuration with Colab optimizations
    grpo_config = GRPOConfig(
        planning_steps=planning_steps,
        gradient_checkpointing=True
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    try:
        # Load smaller dataset
        dataset = CustomDataset(dataset_name, tokenizer, max_length=512)
        print(f"Dataset successfully created with {len(dataset)} samples")
    except Exception as e:
        print(f"Failed to create dataset: {str(e)}")
        return

    # Initialize trainer with Colab-optimized settings
    trainer = GRPOTrainer(
        model_name=model_name,
        dataset=dataset,
        output_dir=output_dir,
        config=grpo_config,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epochs=epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        save_steps=save_steps,
        mixed_precision=True
    )

    # Train with memory monitoring
    print("Starting training...")
    print(f"Free memory: {torch.cuda.mem_get_info()[0]/1024**2:.2f} MB free")
    trainer.train()

if __name__ == "__main__":
    main()


Login failed. Error: name 'login' is not defined


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded 32160 examples from Anthropic/hh-rlhf (subset for Colab)
Dataset successfully created with 32160 samples
Using device: cuda


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 2,228,224 || all params: 3,882,491,392 || trainable%: 0.0574


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)


Starting training...
Free memory: 2126.12 MB free
Starting training on cuda


Epoch 1/2:   0%|          | 0/16080 [00:00<?, ?it/s]It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Epoch 1/2:   0%|          | 0/16080 [00:18<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 68.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 70.12 MiB is free. Process 132112 has 14.67 GiB memory in use. Of the allocated memory 14.53 GiB is allocated by PyTorch, and 18.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)