In [None]:
import os
import json
import torch
import wandb
import random
import pathlib
import logging
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from cs336_alignment.data_loading import iterate_batches

# Loading the full dataset ran out of memory, so this custom PackedSFTDataset
# class loads only a subset of the examples (see `max_samples`).
class PackedSFTDataset(Dataset):
    """Instruction-tuning dataset of fixed-length next-token-prediction examples.

    Each JSONL record must provide ``prompt`` and ``response`` keys. A record is
    rendered into an Alpaca-style template, tokenized (truncated to
    ``seq_length + 1`` tokens) and split into an input sequence and a label
    sequence that is the input shifted left by one token.

    Note that, despite the class name, examples are not packed together: every
    example is tokenized on its own and right-padded to ``seq_length`` when it
    is fetched.

    Args:
        tokenizer: Tokenizer providing ``encode`` and ``pad_token_id`` (falls
            back to ``eos_token_id`` when no pad token is defined).
        dataset_path: Path to a JSONL file, one example per line.
        seq_length: Length of every returned ``input_ids`` / ``labels`` tensor.
        shuffle: Shuffle the records (with the ``random`` module) before
            ``max_samples`` is applied.
        max_samples: Keep only the first ``max_samples`` records (after the
            optional shuffle). ``None`` keeps everything.

    Raises:
        ValueError: If the tokenizer has neither a pad token nor an EOS token.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerBase", dataset_path: str, seq_length: int, shuffle: bool, max_samples: int = None):
        # Keep what __getitem__ needs on the instance so the dataset does not
        # depend on module-level globals being defined (or agreeing with the
        # values passed in here).
        self.seq_length = seq_length
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = getattr(tokenizer, "eos_token_id", None)
        if pad_token_id is None:
            raise ValueError("tokenizer defines neither pad_token_id nor eos_token_id; cannot pad examples")
        self.pad_token_id = pad_token_id

        with open(dataset_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
            raw_data = [json.loads(line) for line in f if line.strip()]

        if shuffle:
            random.shuffle(raw_data)

        if max_samples is not None:
            raw_data = raw_data[:max_samples]

        self.inputs = []
        self.outputs = []

        for ex in raw_data:
            sample = (
                "Below is an instruction that describes a task. Write a response that appropriately completes the request."
                f"\n\n### Instruction:\n{ex['prompt']}\n\n### Response:\n{ex['response']}"
            )
            # seq_length + 1 tokens yield seq_length (input, label) pairs after the shift.
            tokenized = tokenizer.encode(sample, truncation=True, max_length=seq_length + 1)
            # At least two tokens are required to form one (input, label) pair.
            if len(tokenized) >= 2:
                self.inputs.append(tokenized[:-1])
                self.outputs.append(tokenized[1:])

    def __len__(self):
        """Return the number of usable examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return example ``idx`` right-padded to ``seq_length``.

        Returns:
            dict with ``input_ids`` and ``labels`` as 1-D ``torch.long`` tensors
            of length ``seq_length``. Padded label positions hold ``-100``.
        """
        input_ids = self.inputs[idx]
        labels = self.outputs[idx]
        pad_len = self.seq_length - len(input_ids)

        # Build new lists rather than extending the stored ones in place, so
        # the cached examples keep their original (unpadded) length.
        return {
            "input_ids": torch.tensor(input_ids + [self.pad_token_id] * pad_len, dtype=torch.long),
            "labels": torch.tensor(labels + [-100] * pad_len, dtype=torch.long),  # ignore padding in loss
        }

# Configs
# Root of the assignment checkout. Set the ALIGNMENT_PROJECT_ROOT environment
# variable to run on another machine; the default keeps the original paths.
PROJECT_ROOT = os.environ.get("ALIGNMENT_PROJECT_ROOT", "/home/alvin/Homework/s2025-assignment3-alignment")
SFT_DATA_DIR = os.path.join(PROJECT_ROOT, "data", "tuning", "safety_augmented_ultrachat_200k_single_turn")

TRAIN_PATH = os.path.join(SFT_DATA_DIR, "train.jsonl")
DEV_PATH = os.path.join(SFT_DATA_DIR, "test.jsonl")
MODEL_PATH = os.path.join(PROJECT_ROOT, "models", "Qwen", "Qwen2.5-3B-Instruct")
OUTPUT_DIR = "./qwen2.5-3B-instruct-finetuned"
PROJECT_NAME = "EE491B_qwen2.5-3B"

SEQ_LENGTH = 2048
BATCH_SIZE = 1
GRAD_ACCUMULATION_STEPS = 8  # micro-batches accumulated per optimizer update
TRAIN_STEPS = 100            # counted in micro-batches, not optimizer updates
EVAL_INTERVAL = 50
EVAL_ITERS = 50
LEARNING_RATE = 5e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs so the shuffled training subset is the same on every
# Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
# The dataset pads with tokenizer.pad_token_id; fall back to EOS for
# tokenizers that ship without a dedicated pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# Weights are loaded in float32 on every device (the previous conditional
# selected float32 in both of its branches).
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.float32)
model = model.to(DEVICE)

# Dataset and Dataloaders
# Only a small subset is used (100 train / 10 dev examples).
train_dataset = PackedSFTDataset(tokenizer, TRAIN_PATH, SEQ_LENGTH, shuffle=True, max_samples=100)
dev_dataset = PackedSFTDataset(tokenizer, DEV_PATH, SEQ_LENGTH, shuffle=False, max_samples=10)
train_loader = iterate_batches(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = iterate_batches(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)


  from .autonotebook import tqdm as notebook_tqdm
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 38.18it/s]


In [None]:

# Optimizer: AdamW over every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# WandB Logging: record the run's hyperparameters alongside the metrics.
run_config = dict(
    batch_size=BATCH_SIZE,
    grad_accumulation_steps=GRAD_ACCUMULATION_STEPS,
    train_steps=TRAIN_STEPS,
    learning_rate=LEARNING_RATE,
    model=MODEL_PATH,
)
wandb.init(project=PROJECT_NAME, config=run_config)

# Training Loop
def _next_token_loss(model, batch):
    """Mean cross-entropy of the model's logits against the batch labels.

    The dataset already shifts ``labels`` one position to the left of
    ``input_ids``. Passing ``labels`` to the Hugging Face model would shift
    them a second time inside the model, so the loss is computed here directly
    from the logits. Positions labelled ``-100`` (padding) are ignored.
    """
    model_inputs = {k: v for k, v in batch.items() if k != "labels"}
    logits = model(**model_inputs).logits
    return torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)).float(),
        batch["labels"].reshape(-1),
        ignore_index=-100,
    )


def _evaluate(model, loader, device, max_batches):
    """Average loss over at most ``max_batches`` batches; ``None`` if the loader is empty."""
    model.eval()
    losses = []
    with torch.no_grad():
        for batch_idx, eval_batch in enumerate(loader):
            if batch_idx >= max_batches:
                break
            eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
            losses.append(_next_token_loss(model, eval_batch).item())
    model.train()
    return sum(losses) / len(losses) if losses else None


model.train()
step = 0  # counts micro-batches; one optimizer update every GRAD_ACCUMULATION_STEPS
optimizer.zero_grad()
for epoch in range(100):  # upper bound only; the loop exits once TRAIN_STEPS is reached
    pbar = tqdm(train_loader, desc=f"Training epoch {epoch}")
    for batch in pbar:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        # Scale so the accumulated gradient is the mean over the window.
        loss = _next_token_loss(model, batch) / GRAD_ACCUMULATION_STEPS
        loss.backward()

        if (step + 1) % GRAD_ACCUMULATION_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

        if step % 10 == 0:
            # Undo the accumulation scaling so the logged value is the per-batch loss.
            wandb.log({"train/loss": loss.item() * GRAD_ACCUMULATION_STEPS, "step": step})

        if step > 0 and step % EVAL_INTERVAL == 0:
            val_loss = _evaluate(model, dev_loader, DEVICE, EVAL_ITERS)
            if val_loss is not None:
                wandb.log({"val/loss": val_loss, "step": step})

        step += 1

        if step >= TRAIN_STEPS:
            break
    if step >= TRAIN_STEPS:
        break

# Apply gradients left over from an incomplete accumulation window (when the
# number of steps run is not a multiple of GRAD_ACCUMULATION_STEPS); otherwise
# the last few micro-batches would be computed but never used.
if step % GRAD_ACCUMULATION_STEPS != 0:
    optimizer.step()
    optimizer.zero_grad()

# Save model: write the fine-tuned weights and the tokenizer side by side so
# the directory can be loaded again with from_pretrained.
save_path = os.path.join(OUTPUT_DIR, "final_model")
os.makedirs(save_path, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_path)

# Close the WandB run before reporting completion.
wandb.finish()

print(f"Training complete. Model saved at {save_path}")


[34m[1mwandb[0m: Currently logged in as: [33malvinyang101[0m ([33malvinyang101-university-of-hawaii-at-manoa[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training epoch 0:  99%|█████████▉| 99/100 [51:33<00:31, 31.25s/it]  
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
step,▁▂▃▃▄▅▅▆▆▇█
train/loss,█▄▂▃▂▁▂▁▁▂
val/loss,▁

0,1
step,90.0
train/loss,4.01274
val/loss,4.85122


Training complete. Model saved at ./qwen2.5-3B-instruct-finetuned/final_model
