# Training Pipeline Proof-of-Concept

In [None]:
import os
import argparse
from typing import Dict, Any
import dotenv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.onnx
from torch.quantization import quantize_dynamic
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
import ray
from ray import train
import wandb
import onnx
import horovod.torch as hvd

# Use seaborn's "whitegrid" theme for any plots drawn in this notebook.
sns.set_theme(style="whitegrid")
# Load environment variables via python-dotenv; WANDB_API_KEY is read from the
# environment further down when experiment tracking is initialised.
dotenv.load_dotenv()

In [6]:
# NOTE: This cell is tagged `parameters`, presumably so a notebook runner
# (e.g. papermill) can override these defaults -- confirm against the tooling
# that executes this notebook.

# Model and data locations.
model_name = "gpt2"  # @param
data_path = "data/train.txt"  # @param
output_dir = "tmp"  # @param

# Experiment tracking (Weights & Biases).
wandb_project = "llm-finetuning"  # @param
wandb_run_name = "test-run"  # @param

# Enables the Ray/Horovod code paths in the cells below.
distributed = False  # @param

# Data loading.
batch_size = 8  # @param
num_workers = 4  # @param
max_length = 512  # @param

# Optimisation.
learning_rate = 5e-5  # @param
weight_decay = 0.01  # @param
num_epochs = 3  # @param
warmup_steps = 500  # @param

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise
# fall back to the CPU. The conditional chain short-circuits, so the MPS probe
# only runs when CUDA is unavailable.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Device: {device}")

Device: mps


In [7]:
# Make sure the output directory exists before checkpoints and exported models
# are written into it; exist_ok=True keeps re-running this cell harmless.
os.makedirs(output_dir, exist_ok=True)

In [11]:
# Bring up the distributed backends only when distributed training is requested.
if distributed:
    # ignore_reinit_error lets this cell be re-executed in the same kernel
    # without Ray raising because it has already been initialised.
    ray.init(ignore_reinit_error=True)
    hvd.init()
    # Pin this process to the GPU that matches its Horovod local rank. Skipped
    # on hosts without CUDA, where torch.cuda.set_device would raise.
    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())

In [8]:
class TextDataset(Dataset):
    """Line-per-sample text dataset that tokenizes each line on access.

    Every non-blank line of the input file is one training sample. Samples are
    truncated/padded to ``max_length`` tokens by the supplied tokenizer.

    Args:
        data_path: Path to a UTF-8 text file with one sample per line.
        tokenizer: Callable accepting ``(text, truncation=, max_length=,
            padding=, return_tensors=)`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors whose leading
            axis has size 1. It must be able to pad (i.e. have a pad token).
        max_length: Fixed token length of every returned sample.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Explicit encoding so the result does not depend on the platform
        # default. Whitespace-only lines are dropped (they would become
        # all-padding samples) and the trailing newline is removed so it is
        # not tokenized as part of the sample.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.texts = [line.rstrip('\n') for line in f if line.strip()]

    def __len__(self):
        """Return the number of non-blank lines loaded from the file."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors without the batch axis.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'``, each taken from
            the tokenizer output with the leading size-1 axis removed.
        """
        encodings = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        return {
            'input_ids': encodings['input_ids'].squeeze(0),
            'attention_mask': encodings['attention_mask'].squeeze(0)
        }

In [None]:
# Experiment tracking is set up by the main process only: always in
# single-process mode, and on Horovod rank 0 when running distributed.
if not (distributed and hvd.rank() != 0):
    api_key = os.environ.get("WANDB_API_KEY")

    if not api_key:
        # No credentials available: report it and leave W&B uninitialised.
        print("Error: WANDB_API_KEY environment variable is not set. Please define it before proceeding.")
    else:
        wandb.login()
        # Start the run and record the hyperparameters it was launched with.
        wandb.init(
            name=wandb_run_name,
            project=wandb_project,
            config={
                'model_name': model_name,
                'batch_size': batch_size,
                'learning_rate': learning_rate,
                'num_epochs': num_epochs,
                'max_length': max_length,
            },
        )
        print("Weights and Biases initialized successfully.")

In [None]:
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Some tokenizers (GPT-2 among them) ship without a padding token, in which
# case padding='max_length' in TextDataset raises. Reuse EOS as the pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
# Keep the model config in agreement with the tokenizer's padding token.
model.config.pad_token_id = tokenizer.pad_token_id
model.to(device)

dataset = TextDataset(data_path, tokenizer, max_length)

# In distributed mode each Horovod worker reads its own shard of the dataset;
# otherwise the DataLoader shuffles the whole dataset itself.
if distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank()
    )
else:
    train_sampler = None

train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    # shuffle and sampler are mutually exclusive in DataLoader.
    shuffle=(train_sampler is None),
    sampler=train_sampler,
    num_workers=num_workers
)

In [None]:
# Total number of optimizer steps over the whole run; the scheduler decays the
# learning rate linearly across this many steps after warmup.
num_training_steps = num_epochs * len(train_loader)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)

if distributed:
    # Wrap the optimizer with Horovod's distributed optimizer, then copy
    # rank 0's model and optimizer state to every other worker so all
    # workers start from the same point.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters()
    )
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Linear warmup for `warmup_steps` steps followed by linear decay.
scheduler = get_linear_schedule_with_warmup(
    optimizer, warmup_steps, num_training_steps
)

In [None]:
best_loss = float('inf')

# Fail early with a clear message rather than a ZeroDivisionError when the
# epoch loss is averaged over zero batches.
if len(train_loader) == 0:
    raise ValueError("train_loader yielded no batches; check data_path and batch_size.")

# Only the main process logs and checkpoints (rank 0 when distributed).
is_main_process = not distributed or hvd.rank() == 0
# wandb.log raises when no run is active (e.g. WANDB_API_KEY was not set and
# wandb.init was skipped), so log only when a run actually exists.
log_to_wandb = is_main_process and wandb.run is not None

for epoch in range(num_epochs):
    # DistributedSampler derives its shuffle from the epoch number; without
    # this call every epoch would iterate the data in the same order.
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)

    model.train()
    total_loss = 0.0

    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Causal-LM labels are the inputs themselves, but padded positions
            # must not contribute to the loss: -100 is the index the loss
            # function ignores.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            # Read the scalar once; each .item() forces a device sync.
            loss_value = loss.item()
            total_loss += loss_value

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            pbar.set_postfix({'loss': loss_value})

            if log_to_wandb:
                wandb.log({
                    'loss': loss_value,
                    'learning_rate': scheduler.get_last_lr()[0]
                })

    avg_loss = total_loss / len(train_loader)

    # NOTE: Save checkpoint if best loss
    if is_main_process and avg_loss < best_loss:
        best_loss = avg_loss
        checkpoint_path = os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, checkpoint_path)
        print(f"Saved checkpoint with loss: {best_loss:.4f}")

In [None]:
# Export the trained model to ONNX from the main process only (rank 0 when
# running distributed).
if not (distributed and hvd.rank() != 0):
    model.eval()

    onnx_path = os.path.join(output_dir, 'model.onnx')
    # A single all-zero sequence of token ids serves as the example input
    # for the exporter.
    dummy_input = torch.zeros(1, max_length, dtype=torch.long, device=device)

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
        # Only the batch axis (axis 0) is declared dynamic for both tensors.
        dynamic_axes={
            tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')
        },
    )
    print(f"Exported ONNX model to: {onnx_path}")

    # NOTE: Verify onnx
    # Reload the exported file and run the ONNX checker on it.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

In [None]:
# NOTE: Quantization (Work-in-progress)
# Runs on the main process only (rank 0 when distributed).
if not distributed or hvd.rank() == 0:
    import copy

    # Dynamic quantization is CPU-only, while `model` may live on a CUDA or
    # MPS device. Quantize a CPU copy so the training model stays untouched
    # on its device and the quantized kernels are available.
    cpu_model = copy.deepcopy(model).to("cpu")
    cpu_model.eval()

    # inplace=True because cpu_model is already a private copy; this avoids
    # the second deep copy quantize_dynamic would otherwise make.
    quantized_model = quantize_dynamic(
        cpu_model,
        {nn.Linear},
        dtype=torch.qint8,
        inplace=True
    )

    quantized_path = os.path.join(output_dir, 'quantized_model.pt')
    torch.save(
        quantized_model.state_dict(),
        quantized_path
    )
    print(f"Saved quantized model to: {quantized_path}")

In [None]:
# NOTE: Cleanup
# Ray is only started in distributed mode, so only shut it down there.
if distributed:
    ray.shutdown()

# Close the W&B run from the same process that opened it: the single process,
# or Horovod rank 0 when distributed.
if not (distributed and hvd.rank() != 0):
    wandb.finish()