# 08.2a_rtx6000ada: RTX 6000 Ada Optimized Training

**Tuned for RTX 6000 Ada (48GB VRAM, 62GB RAM, 14 vCPUs)**

This notebook balances model size, batch size, and sequence length to achieve good GPU utilization on a single RTX 6000 Ada card.

## Target Performance

- RTX 6000 Ada theoretical: ~91 TFLOPS FP32 (non-tensor); the dense bfloat16 tensor-core peak is roughly 4× higher (~364 TFLOPS)
- Target MFU: ~30-40% (realistic for small-scale training)
- Expected throughput: 100-500 it/s

## Parameters

**Tuning guide:**
- If GPU utilization low: increase `BATCH_SIZE` or `MAX_SEQ_LEN`
- If OOM: decrease `BATCH_SIZE` or `HIDDEN_DIM`
- If too slow: decrease `NUM_TRAIN_STEPS` or increase `SAVE_EVERY_N_STEPS` (fewer snapshot writes to disk)

In [None]:
# ============================================================================
# TUNABLE PARAMETERS - Adjust these for experimentation
# ============================================================================
# Every later cell reads these module-level constants; edit them here and
# re-run the notebook from the top.

# Model architecture (forwarded to GPT2Config in the model-configuration cell)
VOCAB_SIZE = 128           # ASCII byte vocabulary; corpus bytes >= this value are dropped at load time
HIDDEN_DIM = 256          # Embedding dimension (256/512/768/1024)
N_LAYER = 6                # Transformer layers (4/6/8/12)
N_HEAD = 8                 # Attention heads (4/8/16)
MAX_SEQ_LEN = 1024         # Context window (512/1024/2048); also the dataset window length

# Training (start conservative, scale up)
BATCH_SIZE = 64            # Per-device batch size (32/64/128/256)
GRADIENT_ACCUMULATION = 1  # Effective batch = BATCH_SIZE × this
NUM_TRAIN_STEPS = 10000    # Total training steps; the embedding history holds NUM_TRAIN_STEPS + 1 snapshots
LEARNING_RATE = 1e-3       # Passed to TrainingArguments(learning_rate=...)
WEIGHT_DECAY = 0.01        # Passed to TrainingArguments(weight_decay=...)

# Initialization
# "qwen" overwrites every token embedding with one shared random unit vector;
# any other value leaves the model's default initialization in place.
INIT_MODE = "qwen"         # "normal" or "qwen"

# Checkpointing
# Steps between writes of the embedding history to disk. The history itself is
# captured in memory after every step regardless of this value.
SAVE_EVERY_N_STEPS = 100   # Snapshot frequency (50/100/500)

# Data loading
NUM_WORKERS = 0            # MUST be 0 for GPU dataset (no multiprocessing)

# Data
CORPUS_PATH = "../data/training_corpus.txt"  # Read as strict ASCII text
# The output directory name encodes the vocabulary size and init mode; the
# snapshot callback creates it if it does not exist.
OUTPUT_DIR = f"../data/embeddings_{VOCAB_SIZE}vocab_{INIT_MODE}init_rtx6000ada"
OUTPUT_FILE = "embedding_evolution.safetensors"  # File name inside OUTPUT_DIR

RANDOM_SEED = 42           # Seeds torch, numpy and the Trainer

# ============================================================================
# END TUNABLE PARAMETERS
# ============================================================================

## Model Size Calculator

In [None]:
# Estimate model size before training
def estimate_model_size(vocab, hidden, layers, heads, seq_len):
    """Return the parameter count of a GPT-2 style model with tied embeddings.

    The count covers the token embedding table (shared with the output
    head), the learned position embedding table, every transformer block
    (attention, feed-forward, two LayerNorms, all biases) and the final
    LayerNorm.

    Args:
        vocab: Vocabulary size (rows of the token embedding table).
        hidden: Embedding / hidden dimension.
        layers: Number of transformer blocks.
        heads: Number of attention heads. Accepted for interface
            compatibility; splitting the hidden dimension across heads does
            not change the parameter count.
        seq_len: Maximum sequence length (rows of the position table).

    Returns:
        Total number of parameters as an int.
    """
    # Embedding tables: token (tied with the output projection) + position.
    token_emb = vocab * hidden
    pos_emb = seq_len * hidden

    # Attention: fused QKV (3h^2 + 3h) plus output projection (h^2 + h).
    attn = 4 * hidden * hidden + 4 * hidden
    # Feed-forward with 4x expansion: up (4h^2 + 4h) plus down (4h^2 + h).
    ffn = 8 * hidden * hidden + 5 * hidden
    # Two LayerNorms per block, each with a weight and a bias vector.
    block_norms = 4 * hidden
    layer_params = (attn + ffn + block_norms) * layers

    # Final LayerNorm applied after the last block.
    final_norm = 2 * hidden

    return token_emb + pos_emb + layer_params + final_norm

# Weight footprint: bfloat16 stores each parameter in two bytes.
estimated_params = estimate_model_size(
    VOCAB_SIZE, HIDDEN_DIM, N_LAYER, N_HEAD, MAX_SEQ_LEN
)
estimated_size_mb = estimated_params * 2 / 1e6

# Activation heuristic used by this notebook: seq_len x hidden x layers x 4.
# Kept as integers so the printed megabyte figures are computed in one division.
activation_units_per_example = MAX_SEQ_LEN * HIDDEN_DIM * N_LAYER * 4
activation_units_per_batch = BATCH_SIZE * activation_units_per_example

print("Estimated model size:")
print(f"  Parameters: {estimated_params:,}")
print(f"  Memory (bfloat16): {estimated_size_mb:.1f} MB")
print("\nBatch memory estimate:")
print(f"  Activations per example: ~{activation_units_per_example / 1e6:.1f} MB")
print(f"  Batch of {BATCH_SIZE}: ~{activation_units_per_batch / 1e6:.1f} MB")
print(f"  Total estimate: ~{estimated_size_mb + activation_units_per_batch / 1e6:.1f} MB")
print("\nShould fit easily in 48GB VRAM ✓")

## Imports

In [None]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file
import time

# Seed the torch and numpy generators with the same notebook-wide seed.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(RANDOM_SEED)

# Enable TF32 for Ada architecture (both the matmul and the cuDNN backends).
for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    backend.allow_tf32 = True

## Hardware Check

In [None]:
# Everything downstream assumes a CUDA device, so stop immediately without one.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    raise RuntimeError("This notebook requires CUDA")

device = torch.device("cuda")

# Describe device 0: name, total memory (decimal GB) and compute capability.
gpu_name = torch.cuda.get_device_name(0)
device_props = torch.cuda.get_device_properties(0)
gpu_memory = device_props.total_memory / 1e9
cuda_cap = torch.cuda.get_device_capability(0)
cap_major, cap_minor = cuda_cap[0], cuda_cap[1]
tf32_enabled = torch.backends.cuda.matmul.allow_tf32

print(f"GPU: {gpu_name}")
print(f"VRAM: {gpu_memory:.1f} GB")
print(f"CUDA Capability: {cap_major}.{cap_minor}")
print(f"TF32 enabled: {tf32_enabled}")
print("\n✓ Hardware ready")

## Load Corpus (CPU → GPU)

In [None]:
print(f"Loading corpus: {CORPUS_PATH}")

# Strict ASCII decode: a non-ASCII byte in the file raises instead of being
# silently replaced.
corpus_text = Path(CORPUS_PATH).read_text(encoding='ascii')

# One token id per byte; ids outside the vocabulary are discarded.
corpus_bytes = [
    byte_value
    for byte_value in corpus_text.encode('ascii')
    if byte_value < VOCAB_SIZE
]

# Vocabulary coverage: ids that never occur in the corpus are "dead" tokens.
unique_bytes = set(corpus_bytes)
dead_tokens = VOCAB_SIZE - len(unique_bytes)
dead_pct = 100 * dead_tokens / VOCAB_SIZE

print(f"  Total bytes: {len(corpus_bytes):,}")
print(f"  Unique: {len(unique_bytes)} / {VOCAB_SIZE}")
print(f"  Dead tokens: {dead_tokens} ({dead_pct:.1f}%)")

# Keep the whole corpus resident on the GPU so batches need no host transfer.
corpus_tensor = torch.tensor(corpus_bytes, dtype=torch.long, device=device)
corpus_mb = corpus_tensor.numel() * corpus_tensor.element_size() / 1e6
print(f"\n✓ Corpus on GPU: {corpus_mb:.2f} MB")

## GPU Dataset (Zero-Copy)

In [None]:
class GPUByteDataset(Dataset):
    """Sliding-window causal-LM dataset over a 1-D tensor of token ids.

    Example ``idx`` is the window ``corpus[idx : idx + max_seq_len]``. The
    same window is returned as both ``input_ids`` and ``labels``: the
    GPT-2 language-model head that consumes these examples shifts the labels
    by one position internally, so pre-shifting them here would train the
    model to predict the token two positions ahead.

    Args:
        corpus_tensor: 1-D tensor of token ids. It is indexed by slicing, so
            returned examples are views on whatever device it lives on.
        max_seq_len: Number of tokens per example.
    """

    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len

    def __len__(self):
        """Return the number of window start positions (0 if the corpus is too short)."""
        # Unchanged from the original so the example count (and therefore the
        # sampling order for a fixed seed) stays the same.
        return max(0, len(self.corpus) - self.max_seq_len)

    def __getitem__(self, idx):
        """Return the ``idx``-th window as an ``input_ids`` / ``labels`` pair."""
        window = self.corpus[idx : idx + self.max_seq_len]
        return {
            'input_ids': window,
            # Unshifted on purpose: the model applies the next-token shift.
            'labels': window,
        }

# Build the training dataset over the GPU-resident corpus and report its size.
dataset = GPUByteDataset(corpus_tensor, MAX_SEQ_LEN)
num_examples = len(dataset)
tokens_per_epoch = num_examples * MAX_SEQ_LEN

print(f"✓ Dataset: {num_examples:,} examples")
print(f"  Tokens/example: {MAX_SEQ_LEN}")
print(f"  Tokens/epoch: {tokens_per_epoch:,}")

## Model Configuration

In [None]:
# All three dropout rates are switched off for this experiment.
no_dropout = dict.fromkeys(("resid_pdrop", "embd_pdrop", "attn_pdrop"), 0.0)

config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    tie_word_embeddings=True,
    **no_dropout,
)

# Build the model, cast its weights to bfloat16, then move it to the GPU.
model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16)
model = model.to(device)

# Parameter statistics, gathered from a single pass over the parameters.
all_params = list(model.parameters())
total_params = sum(param.numel() for param in all_params)
embedding_params = model.transformer.wte.weight.numel()
model_size_mb = sum(param.numel() * param.element_size() for param in all_params) / 1e6

print("Model initialized:")
print(f"  Total parameters: {total_params:,}")
print(f"  Embedding parameters: {embedding_params:,}")
print(f"  Model size: {model_size_mb:.2f} MB (bfloat16)")
print(f"  Layers: {config.n_layer}, Heads: {config.n_head}, Hidden: {config.n_embd}")

## Initialize Embeddings

In [None]:
# Any mode other than "qwen" keeps the model's default initialization and
# only reports the spread of the per-token norms.
if INIT_MODE != "qwen":
    print("\nNormal initialization (default PyTorch)")
    with torch.no_grad():
        norms = model.transformer.wte.weight.norm(p=2, dim=1)
        print(f"  Token norms: min={norms.min().item():.6f}, max={norms.max().item():.6f}")
else:
    print("\nQwen-style initialization (singular unit vector)")
    with torch.no_grad():
        # Draw one random direction, scale it to unit length, and broadcast
        # it into every row of the token embedding table.
        random_vector = torch.randn(HIDDEN_DIM, device=device)
        random_vector = random_vector / random_vector.norm()
        token_table = model.transformer.wte.weight
        token_table[:] = random_vector
    unit_norm = random_vector.norm().item()
    print(f"  All {VOCAB_SIZE} tokens → unit vector (norm={unit_norm:.6f})")

## Pre-allocate Embedding History

In [None]:
# One slot per training step plus slot 0 for the state before training.
history_shape = (NUM_TRAIN_STEPS + 1, VOCAB_SIZE, HIDDEN_DIM)
embedding_history = torch.zeros(history_shape, dtype=torch.bfloat16)

# Slot 0: copy of the freshly initialized token embeddings, held on the CPU.
embedding_history[0] = model.transformer.wte.weight.detach().clone().cpu()

# Baseline centroid, used later to measure how far the embeddings moved.
initial_centroid = embedding_history[0].mean(dim=0)
initial_centroid_norm = initial_centroid.norm().item()

history_size_mb = embedding_history.numel() * embedding_history.element_size() / 1e6

print("Embedding history allocated:")
print(f"  Shape: {embedding_history.shape}")
print(f"  Memory: {history_size_mb:.1f} MB")
print(f"  Initial centroid norm: {initial_centroid_norm:.6f}")

## Snapshot Callback

In [None]:
class EmbeddingSnapshotCallback(TrainerCallback):
    """Record the token embedding matrix after every optimizer step.

    Each step's embeddings are copied into a pre-allocated history tensor,
    the filled part of that history is written to a safetensors file at a
    fixed interval (and on the last step), and a throughput / centroid-norm
    line is printed every 100 steps.

    Args:
        embedding_history: Pre-allocated tensor indexed by global step; row
            ``step`` receives the embeddings after that step.
        output_dir: Directory for the output file; created if missing.
        output_file: File name inside ``output_dir``.
        save_every_n: Number of steps between writes to disk.
    """

    def __init__(self, embedding_history, output_dir, output_file, save_every_n):
        self.embedding_history = embedding_history
        self.output_path = Path(output_dir) / output_file
        self.save_every_n = save_every_n
        self.last_time = time.time()
        self.steps_since_print = 0

        Path(output_dir).mkdir(parents=True, exist_ok=True)

    def on_train_begin(self, args, state, control, **kwargs):
        """Restart the throughput clock when training actually begins.

        Without this the first reported it/s would include whatever time
        passed between constructing the callback and starting training.
        """
        self.last_time = time.time()
        self.steps_since_print = 0
        return control

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Store this step's embeddings, save periodically, and report progress."""
        step = state.global_step

        # Store in memory. The assignment copies into the history row, so no
        # separate clone of the weights is needed beforehand.
        self.embedding_history[step] = model.transformer.wte.weight.detach().cpu()

        # Save to disk periodically; only the rows filled so far are written.
        should_save = (step % self.save_every_n == 0) or (step == args.max_steps)
        if should_save:
            save_file(
                {'embedding_history': self.embedding_history[:step+1]},
                self.output_path
            )

        # Print every 100 steps
        self.steps_since_print += 1
        if step % 100 == 0 and step > 0:
            elapsed = time.time() - self.last_time
            # Guard against a zero interval on coarse-resolution clocks.
            throughput = self.steps_since_print / elapsed if elapsed > 0 else float('inf')

            embeddings = self.embedding_history[step]
            centroid_norm = embeddings.mean(dim=0).norm().item()

            marker = "[SAVED]" if should_save else ""
            print(f"[{step:5d}] {throughput:6.1f} it/s | centroid: {centroid_norm:.6f} {marker}")

            self.last_time = time.time()
            self.steps_since_print = 0

        return control

# Status line only: this cell defines the callback class; the instance is
# created later, when the Trainer is constructed.
print(f"✓ Callback ready (save every {SAVE_EVERY_N_STEPS} steps)")

## Training Configuration

In [None]:
# Sequences consumed per optimizer step, and the matching token count.
effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION
tokens_per_step = effective_batch * MAX_SEQ_LEN

# Optimization schedule.
optimization_kwargs = dict(
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Model checkpoints are disabled (only the embeddings are wanted): the save
# interval is set one step past the end of training.
checkpoint_kwargs = dict(
    save_steps=NUM_TRAIN_STEPS + 1,
    save_total_limit=0,
)

# Data loading (workers MUST be 0 for the GPU dataset); pinning is off.
dataloader_kwargs = dict(
    dataloader_num_workers=NUM_WORKERS,
    dataloader_pin_memory=False,
)

# Numeric precision.
precision_kwargs = dict(
    bf16=True,
    tf32=True,
)

training_args = TrainingArguments(
    output_dir="./training_output",
    logging_steps=100,
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
    **optimization_kwargs,
    **checkpoint_kwargs,
    **dataloader_kwargs,
    **precision_kwargs,
)

print("Training configuration:")
print(f"  Steps: {NUM_TRAIN_STEPS:,}")
print(f"  Batch size: {BATCH_SIZE} × {GRADIENT_ACCUMULATION} = {effective_batch}")
print(f"  Tokens/step: {tokens_per_step:,}")
print(f"  Learning rate: {LEARNING_RATE}")
print("  Precision: bfloat16 + TF32")

## Create Trainer

In [None]:
# Build the snapshot callback first so the Trainer call stays short.
snapshot_callback = EmbeddingSnapshotCallback(
    embedding_history=embedding_history,
    output_dir=OUTPUT_DIR,
    output_file=OUTPUT_FILE,
    save_every_n=SAVE_EVERY_N_STEPS,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[snapshot_callback],
)

print("✓ Trainer ready")

## Train

In [None]:
# Horizontal rule reused by the start and end banners.
banner_rule = '=' * 80

print(f"\n{banner_rule}")
print("Starting training...")
print(f"{banner_rule}\n")

# Wall-clock time around the whole training run.
start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

avg_throughput = NUM_TRAIN_STEPS / elapsed
elapsed_minutes = elapsed / 60

print(f"\n{banner_rule}")
print("✓ Training complete")
print(f"Total time: {elapsed:.1f}s ({elapsed_minutes:.1f} min)")
print(f"Average: {avg_throughput:.1f} it/s")
print(f"{banner_rule}")

## Final Save & Analysis

In [None]:
# Write the complete history one last time.
output_path = Path(OUTPUT_DIR) / OUTPUT_FILE
save_file({'embedding_history': embedding_history}, output_path)

# Statistics of the last snapshot: per-token L2 norms and the centroid.
final_embeddings = embedding_history[-1]
final_norms = final_embeddings.norm(p=2, dim=1)
final_centroid = final_embeddings.mean(dim=0)
final_centroid_norm = final_centroid.norm().item()

# Distance between the centroid before training and the centroid now.
centroid_shift = final_centroid - initial_centroid
centroid_displacement = centroid_shift.norm().item()

norm_min = final_norms.min().item()
norm_max = final_norms.max().item()
norm_mean = final_norms.mean().item()
file_size_mb = output_path.stat().st_size / 1e6

print("\nSummary:")
print(f"  Model: {total_params:,} params, {N_LAYER} layers")
print(f"  Init: {INIT_MODE}")
print(f"  Steps: {NUM_TRAIN_STEPS:,}")
print(f"  Batch: {effective_batch} × {MAX_SEQ_LEN} tokens")
print("\nEmbedding evolution:")
print(f"  Initial centroid norm: {initial_centroid_norm:.6f}")
print(f"  Final centroid norm: {final_centroid_norm:.6f}")
print(f"  Displacement: {centroid_displacement:.6f}")
print("\nFinal token norms:")
print(f"  Min: {norm_min:.6f}")
print(f"  Max: {norm_max:.6f}")
print(f"  Mean: {norm_mean:.6f}")
print(f"\nSaved to: {output_path}")
print(f"File size: {file_size_mb:.1f} MB")