In [None]:
import os
import json
import torch
import wandb
import random
import pathlib
import logging
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from cs336_alignment.data_loading import iterate_batches

# Run out of memory running the normal dataset
# Create a custom PackedSFTDataset class to load piece of the dataset
class PackedSFTDataset(Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizerBase, dataset_path: str, seq_length: int, shuffle: bool, max_samples: int = None):
        self.tokenizer = tokenizer
        self.seq_length = seq_length

        with open(dataset_path, "r", encoding="utf-8") as f:
            raw_data = [json.loads(line) for line in f]

        if shuffle:
            import random
            random.shuffle(raw_data)

        if max_samples is not None:
            raw_data = raw_data[:max_samples]

        self.inputs = []
        self.outputs = []
        
        for ex in raw_data:
            sample = (
                "Below is an instruction that describes a task. Write a response that appropriately completes the request."
                f"\n\n### Instruction:\n{ex['prompt']}\n\n### Response:\n{ex['response']}"
            )
            tokenized = self.tokenizer.encode(sample, truncation=True, max_length=self.seq_length+1)
            if len(tokenized) >= 2:
                self.inputs.append(tokenized[:-1])
                self.outputs.append(tokenized[1:])

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_ids = self.inputs[idx]
        labels = self.outputs[idx]

        input_ids += [self.tokenizer.pad_token_id] * (self.seq_length - len(input_ids))
        labels += [-100] * (self.seq_length - len(labels))  # ignore padding in loss

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long)
        }

# Configs
OUTPUT_DIR = "/home/alvin/Homework/s2025-assignment3-alignment/outputs/qwen2.5-3b-sft-2"
PROJECT_NAME = "EE491B_qwen2.5-3B"
os.makedirs(OUTPUT_DIR, exist_ok=True)



In [None]:
args = {
    "train_path": "/home/alvin/Homework/s2025-assignment3-alignment/data/tuning/safety_augmented_ultrachat_200k_single_turn/train.jsonl",
    "dev_path": "/home/alvin/Homework/s2025-assignment3-alignment/data/tuning/safety_augmented_ultrachat_200k_single_turn/test.jsonl",
    "output_dir": OUTPUT_DIR,
    "model_path": "/home/alvin/Homework/s2025-assignment3-alignment/models/Qwen/Qwen2.5-3B-Instruct",
    "vocab_size": 151936,  # this value depends on the tokenizer used by Qwen2.5
    "context_length": 2048,
    "d_model": 2560,
    "num_layers": 32,
    "num_heads": 32,
    "d_ff": 10240,
    "attn_pdrop": 0.1,
    "residual_pdrop": 0.1,
    "batch_size": 1,
    "train_steps": 20,
    "gradient_accumulation_steps": 8,
    "eval_iters": 5,
    "eval_interval": 10,
    "learning_rate": 5e-5,
    "lr_scheduler": "cosine",
    "warmup_ratio": 0.03,
    "weight_decay": 0.1,
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "adam_eps": 1e-8,
    "grad_clip": 1.0,
    "device": "cuda",
    "compile": False,
    "dtype": "bfloat16",
    "wandb_project": PROJECT_NAME  # Or set to your project name if using Weights & Biases
}

In [None]:
# WandB Logging
wandb.init(project=PROJECT_NAME, config={
    "batch_size": args["batch_size"],
    "grad_accumulation_steps": args["gradient_accumulation_steps"],
    "train_steps": args["train_steps"],
    "learning_rate": args["learning_rate"],
    "model": args["model_path"],
})

"""
Train a language model on one or multiple GPUs.

To run single-GPU training:

```
python scripts/train.py
```

To run multi-GPU training, use `torchrun`. e.g., for single-node, 2 GPU:

```
torchrun --standalone --nproc_per_node=2 scripts/train.py
```
"""
from __future__ import annotations

import argparse
import json
import logging
import os
import pathlib
import sys
from contextlib import nullcontext

import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
import wandb
from cs336_basics.data import get_batch
from cs336_basics.model import TransformerLM
from cs336_basics.optimizer import get_cosine_lr
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

logger = logging.getLogger(__name__)


def train(
    train_path,
    dev_path,
    output_dir,
    model_path,
    vocab_size,
    context_length,
    d_model,
    num_layers,
    num_heads,
    d_ff,
    attn_pdrop,
    residual_pdrop,
    batch_size,
    train_steps,
    gradient_accumulation_steps,
    eval_iters,
    eval_interval,
    learning_rate,
    lr_scheduler,
    warmup_ratio,
    weight_decay,
    adam_beta1,
    adam_beta2,
    adam_eps,
    grad_clip,
    device,
    compile,
    dtype,
    wandb_project,
):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    train_data = PackedSFTDataset(tokenizer, train_path, context_length, shuffle=True, max_samples=8192)
    dev_data = PackedSFTDataset(tokenizer, dev_path, context_length, shuffle=False, max_samples=4096)
    train_loader = iterate_batches(train_data, batch_size=batch_size, shuffle=True)
    dev_loader = iterate_batches(dev_data, batch_size=batch_size, shuffle=False)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float32 if torch.cuda.is_available() else torch.float32)
    model = model.to(device)

    # Wrap model in DDP, if we're using it.
    is_ddp = int(os.environ.get("RANK", -1)) != -1
    if is_ddp:
        init_process_group(backend="nccl")
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        seed = ddp_rank  # each process gets a different seed
        # Rank 0 does logging, file creation, etc.
        is_master_process = ddp_rank == 0
    else:
        seed = 0
        ddp_world_size = 1
        is_master_process = True

    if is_master_process:
        logger.info(
            "Total number of tokens per training step: "
            + str(
                gradient_accumulation_steps
                * ddp_world_size
                * batch_size
                * context_length
            )
        )

    # Seed each process differently so we can be sure that they
    # see different data batches.
    # NOTE: This assumes that you're using torch RNG, you may have
    # to seed numpy too as well if your code uses numpy random functions.
    torch.manual_seed(seed)

    # Save the model config
    if is_master_process:
        model_config_output_path = os.path.join(output_dir, "model_config.json")
        logger.info(f"Saving model config to {model_config_output_path}")
        with open(model_config_output_path, "w") as f:
            json.dump(model.config.to_dict(), f, indent=4)

    device_type = "cuda" if "cuda" in device else "cpu"
    torch_dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    if is_master_process:
        logger.info(f"Using dtype: {torch_dtype}")
    amp_ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.amp.autocast(device_type=device_type, dtype=torch_dtype)
    )
    # GradScaler is only used for FP16
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))

    # Move model to the device
    model = model.to(device)

    # compile the model, requires torch 2.0
    if compile:
        torch.set_float32_matmul_precision("high")
        model = torch.compile(model)

    if is_ddp:
        model = DDP(model, device_ids=[ddp_local_rank])

    # Set up the AdamW optimizer.
    # We do not apply decay on 1D parameters (e.g., biases and RMSNorms)
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    params_to_decay = [p for _, p in param_dict.items() if p.dim() >= 2]
    params_to_not_decay = [p for _, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": params_to_decay, "weight_decay": weight_decay},
        {"params": params_to_not_decay, "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(
        optim_groups,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        eps=adam_eps,
    )

    # Get the first batch
    train_iter = iter(train_loader)
    for i in tqdm(range(train_steps)):
        for micro_step_idx in range(gradient_accumulation_steps):
            try:
                batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                batch = next(train_iter)

            batch_x = batch["input_ids"].to(device)
            batch_y = batch["labels"].to(device)

            if is_ddp:
                model.require_backward_grad_sync = (micro_step_idx == gradient_accumulation_steps - 1)

            with amp_ctx:
                outputs = model(batch_x)
                logits = outputs.logits
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch_y.view(-1))
            scaler.scale(loss).backward()

        if grad_clip:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        loss_float = loss.item() * gradient_accumulation_steps
        if is_master_process:
            logger.info(f"Train step {i}, Loss: {loss_float}")
            if wandb_project:
                wandb.log({"train_loss": loss_float, "lr": lr}, step=i)

        if i != 0 and i % eval_interval == 0 and is_master_process:
            dev_loss = estimate_dev_loss(
                model=model,
                dev_loader=dev_loader,
                eval_iters=eval_iters,
                device=device,
            )
            logger.info(f"Estimated validation loss: {dev_loss}")
            if wandb_project:
                wandb.log({"eval_loss": dev_loss}, step=i)

    # Calculate final estimated dev loss
    if is_master_process:
        dev_loss = estimate_dev_loss(
            model=model,
            dev_loader=dev_loader,
            eval_iters=eval_iters,
            device=device,
        )
        logger.info(f"Final estimated validation loss: {dev_loss}")
        if wandb_project:
            wandb.log({"eval_loss": dev_loss}, step=train_steps)
        # Save the model weights
        model_weights_output_path = os.path.join(output_dir, "model.pt")
        logger.info(f"Saving model weights to {model_weights_output_path}")
        torch.save(model.state_dict(), model_weights_output_path)

    if is_ddp:
        destroy_process_group()


@torch.no_grad()
def estimate_dev_loss(
    model: TransformerLM,
    dev_loader: torch.utils.data.DataLoader,
    eval_iters: int,
    device: str,
):
    model.eval()
    losses = torch.zeros(eval_iters)
    dev_iter = iter(dev_loader)
    for k in tqdm(range(eval_iters)):
        try:
            batch = next(dev_iter)
        except StopIteration:
            dev_iter = iter(dev_loader)
            batch = next(dev_iter)

        batch_x = batch["input_ids"].to(device)
        batch_y = batch["labels"].to(device)

        logits = model(batch_x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch_y.view(-1))
        losses[k] = loss.item()
    model.train()
    return losses.mean()



[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


In [16]:
train(**args)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 25.34it/s]
  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
  0%|          | 0/20 [09:11<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB. GPU 0 has a total capacity of 47.99 GiB of which 0 bytes is free. Process 9861 has 17179869184.00 GiB memory in use. Process 46735 has 17179869184.00 GiB memory in use. Process 17381 has 17179869184.00 GiB memory in use. Of the allocated memory 108.84 GiB is allocated by PyTorch, and 1.08 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)