# Training Pipeline Proof-of-Concept

In [None]:
import os
import argparse
from typing import Dict, Any
import dotenv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.onnx
from torch.quantization import quantize_dynamic
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
import ray
from ray import train
import wandb
import onnx
import horovod.torch as hvd

# Use seaborn's "whitegrid" theme for plots drawn in this notebook.
sns.set_theme(style="whitegrid")
# Load variables from a local .env file (if any) into the process environment;
# WANDB_API_KEY is read with os.getenv further down in this notebook.
dotenv.load_dotenv()

In [6]:
# NOTE: This cell is tagged `parameters`.
# NOTE(review): presumably the tag lets a notebook runner (e.g. papermill)
# override these defaults at execution time — confirm against the launcher.

# Model, input data and output location.
model_name = "gpt2"  # @param
data_path = "data/train.txt"  # @param
output_dir = "tmp"  # @param

# Weights & Biases project / run naming (passed to wandb.init below).
wandb_project = "llm-finetuning"  # @param
wandb_run_name = "test-run"  # @param

# When True, the notebook initialises Ray and Horovod before training.
distributed = False  # @param

# Data-loading settings; max_length is the tokenizer truncation/padding length.
batch_size = 8  # @param
num_workers = 4  # @param
max_length = 512  # @param

# Optimisation hyperparameters.
learning_rate = 5e-5  # @param
weight_decay = 0.01  # @param
num_epochs = 3  # @param
warmup_steps = 500  # @param

In [16]:
# Pick the compute device in priority order: NVIDIA CUDA first, then Apple's
# Metal (MPS) GPU backend, and finally the CPU. The MPS probe is only evaluated
# when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Device: {device}")

Device: mps


In [7]:
# Create the output directory (and any missing parents); exist_ok=True makes
# re-running this cell a no-op when the directory is already there.
os.makedirs(output_dir, exist_ok=True)

In [11]:
# Bring up the distributed runtime (Ray + Horovod) when requested.
if distributed:
    # Notebook cells are routinely re-executed; a bare ray.init() raises on the
    # second call in the same kernel, so tolerate an already-initialised Ray.
    ray.init(ignore_reinit_error=True)
    hvd.init()
    # Pin this worker to the GPU matching its Horovod local rank. Skip the call
    # on hosts without CUDA (the notebook also runs on mps/cpu), where
    # torch.cuda.set_device would raise.
    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())

In [8]:
class TextDataset(Dataset):
    """Line-per-sample text dataset that tokenizes each sample on access.

    Every non-blank line of ``data_path`` is one sample. The line text itself
    is kept unchanged (including its trailing newline) and handed to the
    tokenizer in ``__getitem__``.

    Args:
        data_path: Path to a UTF-8 text file with one sample per line.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding='max_length', return_tensors='pt')`` that
            returns a mapping with ``'input_ids'`` and ``'attention_mask'``
            tensors carrying a leading batch dimension of size 1.
        max_length: Length every sample is truncated/padded to.

    NOTE(review): ``padding='max_length'`` needs the tokenizer to define a pad
    token; GPT-2 tokenizers reportedly ship without one — confirm it is set
    (e.g. to the EOS token) before this dataset is indexed.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Explicit encoding so decoding does not depend on the platform locale.
        with open(data_path, 'r', encoding='utf-8') as f:
            # Skip whitespace-only lines: they carry no training text and would
            # otherwise become samples made up almost entirely of padding.
            self.texts = [line for line in f if line.strip()]

    def __len__(self):
        """Return the number of non-blank lines loaded from the file."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors without the batch axis.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'``, each the
            tokenizer output with its leading size-1 batch dimension removed.
        """
        text = self.texts[idx]
        encodings = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # squeeze(0) removes only the batch axis. A bare squeeze() would also
        # collapse the sequence axis when max_length == 1, changing the rank.
        return {
            'input_ids': encodings['input_ids'].squeeze(0),
            'attention_mask': encodings['attention_mask'].squeeze(0)
        }

In [None]:
# Experiment tracking is set up once: in a distributed run every Horovod worker
# other than rank 0 skips this cell; a single-process run always enters it.
if not (distributed and hvd.rank() != 0):
    api_key = os.getenv("WANDB_API_KEY")

    if not api_key:
        # No credentials available: report it and leave W&B uninitialised.
        print("Error: WANDB_API_KEY environment variable is not set. Please define it before proceeding.")
    else:
        wandb.login()
        # Record the main hyperparameters alongside the run.
        wandb.init(
            project=wandb_project,
            name=wandb_run_name,
            config=dict(
                model_name=model_name,
                batch_size=batch_size,
                learning_rate=learning_rate,
                num_epochs=num_epochs,
                max_length=max_length,
            ),
        )
        print("Weights and Biases initialized successfully.")