In [19]:
# Quick debug: Check dataloader length and calculate correct resume position
print(f"🔍 Debugging resume calculation:")
print(f"Train dataloader length: {len(train_dataloader)}")
print(f"Resume step from checkpoint: 764000")
print(f"Steps per epoch: {len(train_dataloader)}")

# The issue: 764000 is the GLOBAL step, not epoch step
# Let's calculate the correct position
global_step = 764000
steps_per_epoch = len(train_dataloader) 
current_epoch = (global_step - 1) // steps_per_epoch  # 0-indexed epoch
step_in_epoch = ((global_step - 1) % steps_per_epoch) + 1  # 1-indexed step within epoch

print(f"📊 Correct calculation:")
print(f"   Global step 764000 means:")
print(f"   - Current epoch: {current_epoch + 1} (0-indexed: {current_epoch})")
print(f"   - Step within epoch: {step_in_epoch}")
print(f"   - Should resume from epoch {current_epoch + 1}, step {step_in_epoch}")

# Check if we've already completed training
if global_step >= len(train_dataloader) * 3:  # 3 epochs
    print("⚠️  WARNING: This checkpoint is from AFTER training should have completed!")
    print(f"   Total steps for 3 epochs: {len(train_dataloader) * 3}")
    print(f"   Checkpoint step: {global_step}")
else:
    print("✅ Checkpoint is valid - training should resume normally")

🔍 Debugging resume calculation:
Train dataloader length: 279391
Resume step from checkpoint: 764000
Steps per epoch: 279391
📊 Correct calculation:
   Global step 764000 means:
   - Current epoch: 3 (0-indexed: 2)
   - Step within epoch: 205218
   - Should resume from epoch 3, step 205218
✅ Checkpoint is valid - training should resume normally


In [20]:
# Force restart training from the correct position
# Let's restart training properly with the right calculation

import shutil
import re

# Enhanced checkpoint management with resuming and space-efficient saving
checkpoint_cfg = train_cfg['checkpointing']
metrics_dir = Path(train_cfg['artifacts']['metrics_dir'])
metrics_dir.mkdir(parents=True, exist_ok=True)
run_config_path = Path(train_cfg['artifacts']['run_config_path'])
run_config_path.parent.mkdir(parents=True, exist_ok=True)

history = {'step': [], 'loss': [], 'lr': [], 'throughput': []}
step_times = deque(maxlen=200)
ema_throughput = None

# Override save_steps to 50k for better disk management
save_steps = 50000
print(f"💾 Checkpoint saving frequency: every {save_steps:,} steps")

def cleanup_old_checkpoints(checkpoint_dir: Path, keep_latest: int = 3):
    """Keep only the most recent checkpoints to save disk space."""
    if not checkpoint_dir.exists():
        return
    
    # Get all checkpoint directories, sorted by modification time (newest first)
    checkpoints = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
    checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
    
    # Delete old checkpoints beyond keep_latest
    deleted_count = 0
    for old_ckpt in checkpoints[keep_latest:]:
        accelerator.print(f'🗑️  Cleaning up old checkpoint: {old_ckpt.name}')
        shutil.rmtree(old_ckpt)
        deleted_count += 1
    
    if deleted_count > 0:
        accelerator.print(f"🧹 Cleaned up {deleted_count} old checkpoints")

def save_checkpoint(model, optimizer, scheduler, epoch: int, step: int, tag: str):
    accelerator.wait_for_everyone()
    ckpt_dir = Path(checkpoint_cfg['output_dir']) / f'{tag}_epoch{epoch}_step{step}'
    if accelerator.is_main_process:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir / 'tokenizer')
        torch.save(optimizer.state_dict(), ckpt_dir / 'optimizer.pt')
        torch.save(scheduler.state_dict(), ckpt_dir / 'scheduler.pt')
        with (ckpt_dir / 'training_state.json').open('w') as fh:
            json.dump({'epoch': epoch, 'step': step}, fh)
        accelerator.print(f'💾 [checkpoint] saved -> {ckpt_dir}')
        
        # Clean up old checkpoints to save space (keep latest 3)
        cleanup_old_checkpoints(Path(checkpoint_cfg['output_dir']), keep_latest=3)
    accelerator.wait_for_everyone()
    free_cuda()

# MANUAL RESUME: Force resume from step 764000 (epoch 3, step 205218)
resume_global_step = 764000
steps_per_epoch = len(train_dataloader)
current_epoch_index = (resume_global_step - 1) // steps_per_epoch  # 0-indexed = 2 (epoch 3)
step_in_current_epoch = ((resume_global_step - 1) % steps_per_epoch) + 1  # 1-indexed = 205218

print(f"🚀 MANUAL RESUME SETUP:")
print(f"   Resume from global step: {resume_global_step}")
print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Current epoch (0-indexed): {current_epoch_index}")
print(f"   Step within epoch: {step_in_current_epoch}")
print(f"   Remaining steps in epoch: {steps_per_epoch - step_in_current_epoch + 1}")

# Load the checkpoint manually
latest_ckpt_path = Path(checkpoint_cfg['output_dir']) / 'step_epoch3_step764000'

if latest_ckpt_path.exists():
    print(f"📂 Loading checkpoint: {latest_ckpt_path.name}")
    
    try:
        # Load model state
        unwrapped_model = accelerator.unwrap_model(model)
        if (latest_ckpt_path / 'config.json').exists():
            temp_model = AutoModelForMaskedLM.from_pretrained(latest_ckpt_path)
            unwrapped_model.load_state_dict(temp_model.state_dict())
            print("✅ Model state loaded")
        
        # Load optimizer state
        opt_path = latest_ckpt_path / 'optimizer.pt'
        if opt_path.exists():
            optimizer.load_state_dict(torch.load(opt_path, map_location='cpu'))
            print("✅ Optimizer state loaded")
        
        # Load scheduler state  
        sched_path = latest_ckpt_path / 'scheduler.pt'
        if sched_path.exists():
            lr_scheduler.load_state_dict(torch.load(sched_path, map_location='cpu'))
            print("✅ Scheduler state loaded")
            
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {e}")
        resume_global_step = 0
        current_epoch_index = 0
        step_in_current_epoch = 1

# Start training from the calculated position
train_epochs = train_cfg['training']['epochs']
max_grad_norm = train_cfg['training']['max_grad_norm']

# Continue training from the current epoch
for epoch in range(current_epoch_index, train_epochs):
    model.train()
    accelerator.print(f'==== Epoch {epoch+1}/{train_epochs} ====')
    
    # Calculate starting step for this epoch
    epoch_start_step = 1
    if epoch == current_epoch_index and resume_global_step > 0:
        epoch_start_step = step_in_current_epoch
        accelerator.print(f"🔄 Resuming epoch {epoch+1} from step {epoch_start_step}")
    
    progress = tqdm(
        total=len(train_dataloader), 
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch+1}"
    )
    
    # Update progress bar to show current position
    if epoch_start_step > 1:
        progress.update(epoch_start_step - 1)
        accelerator.print(f"📊 Progress bar updated to step {epoch_start_step - 1}")
    
    step_count = 0
    for step, batch in enumerate(train_dataloader, start=1):
        # Skip steps if resuming
        if step < epoch_start_step:
            continue
            
        step_count += 1
        if step_count <= 5:  # Log first 5 steps for debugging
            accelerator.print(f"🔥 Processing step {step} (global: {epoch * len(train_dataloader) + step})")
            
        start_time = time.perf_counter()
        
        # Ensure all tensors are on the correct device
        batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
        
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            
            if max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
        duration = time.perf_counter() - start_time
        step_times.append(duration)

        global_step = epoch * len(train_dataloader) + step
        tokens_processed = batch['input_ids'].numel()
        throughput = tokens_processed / max(duration, 1e-6)
        beta = train_cfg['logging']['throughput_ema_beta']
        ema_throughput = throughput if ema_throughput is None else (beta * ema_throughput + (1 - beta) * throughput)

        history['step'].append(int(global_step))
        history['loss'].append(float(loss.detach().item()))
        history['lr'].append(float(lr_scheduler.get_last_lr()[0]))
        history['throughput'].append(float(ema_throughput))

        if accelerator.is_local_main_process:
            p95 = float(np.percentile(step_times, 95)) if len(step_times) >= 5 else float(duration)
            progress.set_description(f'Epoch {epoch+1} | loss={loss.item():.4f} | tok/s={ema_throughput:,.0f} | p95={p95:.3f}s')
        progress.update(1)

        if train_cfg['logging']['log_steps'] and global_step % train_cfg['logging']['log_steps'] == 0 and accelerator.is_main_process:
            log_gpu_memory(f'step {global_step}')

        # Save checkpoint every 50k steps
        if save_steps and global_step % save_steps == 0 and global_step > resume_global_step:
            save_checkpoint(model, optimizer, lr_scheduler, epoch+1, global_step, tag='step')

    progress.close()
    
    # Save epoch checkpoint
    epoch_steps = (epoch + 1) * len(train_dataloader)
    save_checkpoint(model, optimizer, lr_scheduler, epoch+1, epoch_steps, tag='epoch')

    # Validation
    model.eval()
    val_losses = []
    val_progress = tqdm(val_dataloader, desc="Validation", disable=not accelerator.is_local_main_process)
    for batch in val_progress:
        with torch.no_grad():
            outputs = model(**batch)
            val_losses.append(accelerator.gather(outputs.loss.detach()).mean().item())
    val_loss = float(np.mean(val_losses))
    accelerator.print(f'📊 Validation loss after epoch {epoch+1}: {val_loss:.4f}')

accelerator.wait_for_everyone()
free_cuda()
accelerator.print("🎉 Training completed!")

💾 Checkpoint saving frequency: every 50,000 steps
🚀 MANUAL RESUME SETUP:
   Resume from global step: 764000
   Steps per epoch: 279391
   Current epoch (0-indexed): 2
   Step within epoch: 205218
   Remaining steps in epoch: 74174
📂 Loading checkpoint: step_epoch3_step764000
✅ Model state loaded
✅ Optimizer state loaded
✅ Scheduler state loaded
==== Epoch 3/3 ====
🔄 Resuming epoch 3 from step 205218


Epoch 3:   0%|          | 0/279391 [00:00<?, ?it/s]

📊 Progress bar updated to step 205217


KeyboardInterrupt: 

# 01 · Pretrain DistilBERT on HDFS (MLM)

Masked language modeling pretraining with Hugging Face Accelerate on two RTX 6000 GPUs (or Apple MPS fallback). This notebook consumes the artifacts from `00_prepare_data.ipynb`, configures multi-device training, and saves checkpoints plus diagnostics for fine-tuning.

## Notebook Goals
- Load training hyperparameters from `configs/train_hdfs.yaml` and dataset metadata.
- Auto-generate `accelerate_config.yaml` for multi-GPU Linux or skip when running on Apple Silicon MPS.
- Build PyTorch DataLoaders backed by tokenized Parquet splits with dynamic masking.
- Run a custom Accelerate training loop with throughput EMA, p95 step latency, VRAM logging, and gradient norms.
- Save checkpoints every N steps and at epoch end while calling `free_cuda()` for memory hygiene.
- Persist run metadata (`run_config.json`) and validation metrics for downstream fine-tuning.

## 1. Imports and Configuration

In [1]:
import json
import math
import os
import time
import gc
from collections import deque
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    get_scheduler
)
import numpy as np
from tqdm.auto import tqdm
import yaml

  import pynvml  # type: ignore[import]


### Load YAML configs

In [2]:
def load_yaml(path: Path) -> Dict:
    with path.open('r') as fh:
        return yaml.safe_load(fh)

data_cfg = load_yaml(Path('../configs/data.yaml'))
train_cfg = load_yaml(Path('../configs/train_hdfs.yaml'))
accelerate_config_path = Path(train_cfg['accelerate']['config_path'])
print(json.dumps({
    'data_cfg': str(Path('../configs/data.yaml')),
    'train_cfg': str(Path('../configs/train_hdfs.yaml')),
    'accelerate_cfg': str(accelerate_config_path)
}, indent=2))

{
  "data_cfg": "../configs/data.yaml",
  "train_cfg": "../configs/train_hdfs.yaml",
  "accelerate_cfg": "configs/accelerate_config.yaml"
}


### Device detection

In [3]:
IS_MPS = torch.backends.mps.is_available()
if IS_MPS:
    os.environ.setdefault('ACCELERATE_USE_MPS_DEVICE', '1')
    print('Apple Silicon (MPS) detected. Accelerate will target the MPS backend; multi-GPU config will be skipped.')
else:
    print('MPS not available; defaulting to CUDA/CPU configuration.')

MPS not available; defaulting to CUDA/CPU configuration.


### Auto-generate Accelerate configuration

In [5]:
def build_accelerate_config(cfg: Dict, is_mps: bool) -> Dict:
    if is_mps:
        return {
            'compute_environment': 'LOCAL_MACHINE',
            'distributed_type': 'MPS',
            'mixed_precision': 'no',
            'num_processes': 1,
            'gradient_accumulation_steps': cfg['gradient_accumulation_steps'],
            'main_process_ip': '127.0.0.1',
            'main_process_port': 29500,
            'gpu_ids': '0'
        }
    return {
        'compute_environment': 'LOCAL_MACHINE',
        'distributed_type': 'MULTI_GPU',
        'mixed_precision': 'no',  # Force no mixed precision to avoid gradient scaling issues
        'num_processes': cfg['num_processes'],
        'machine_rank': 0,
        'main_process_ip': '127.0.0.1',
        'main_process_port': 29500,
        'deepspeed_config': cfg['zero_stage2_toggle'],
        'dynamo_backend': 'NO',
        'gradient_accumulation_steps': cfg['gradient_accumulation_steps']
    }

if IS_MPS:
    if accelerate_config_path.exists():
        print('MPS mode: skipping accelerate_config.yaml regeneration (existing file will be ignored).')
else:
    payload = build_accelerate_config(train_cfg['accelerate'], IS_MPS)
    accelerate_config_path.parent.mkdir(parents=True, exist_ok=True)
    with accelerate_config_path.open('w') as fh:
        yaml.safe_dump(payload, fh, sort_keys=False)
    print(f'Accelerate config written to {accelerate_config_path}')
    with accelerate_config_path.open('r') as fh:
        print(fh.read())

Accelerate config written to configs/accelerate_config.yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: 'no'
num_processes: 2
machine_rank: 0
main_process_ip: 127.0.0.1
main_process_port: 29500
deepspeed_config: false
dynamo_backend: 'NO'
gradient_accumulation_steps: 8



## 2. Dataset Loading

In [6]:
metadata_path = Path(data_cfg['preprocessing']['dataset_metadata'])
if not metadata_path.exists():
    raise FileNotFoundError(f'Metadata not found: {metadata_path}. Run 00_prepare_data.ipynb first.')
metadata = json.loads(metadata_path.read_text())
print(json.dumps(metadata, indent=2))

parquet_dir = Path(data_cfg['preprocessing']['parquet_dir'])
train_dataset = load_from_disk(str(parquet_dir / 'hdfs_train_hf'))
val_dataset = load_from_disk(str(parquet_dir / 'hdfs_val_hf'))
test_dataset = load_from_disk(str(parquet_dir / 'hdfs_test_hf'))
print(train_dataset)

{
  "generated_at": "2025-09-27T17:23:37.262306Z",
  "hdfs": {
    "train": {
      "count": 4470251,
      "avg_length": 44.4681468669209,
      "truncation_rate": 0.0
    },
    "val": {
      "count": 558781,
      "avg_length": 46.11112940490103,
      "truncation_rate": 0.0
    },
    "test": {
      "count": 558782,
      "avg_length": 44.91538918576475,
      "truncation_rate": 0.0
    }
  },
  "openstack": {
    "train": {
      "count": 166256,
      "avg_length": 81.68753007410258,
      "truncation_rate": 0.0,
      "anomaly_rate": 0.1108772014243095
    },
    "val": {
      "count": 20782,
      "avg_length": 81.9116543162352,
      "truncation_rate": 0.0,
      "anomaly_rate": 0.0
    },
    "test": {
      "count": 20782,
      "avg_length": 83.08338947165817,
      "truncation_rate": 0.0,
      "anomaly_rate": 0.0
    }
  },
  "tokenizer_dir": "artifacts/tokenizer",
  "template_index": "artifacts/drain3/template_index.parquet",
  "template_transition": "artifacts/drain3

## 3. Tokenizer and Model

In [7]:
tokenizer_dir = Path(train_cfg['artifacts']['tokenizer_dir'])
if not tokenizer_dir.exists():
    raise FileNotFoundError(f'Tokenizer directory missing: {tokenizer_dir}')

tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
tokenizer.add_special_tokens({'additional_special_tokens': data_cfg['tokens']['special']})

model_name = train_cfg['model']['name']
config = AutoConfig.from_pretrained(model_name)
config.vocab_size = len(tokenizer)
config.gradient_checkpointing = train_cfg['model']['gradient_checkpointing']
model = AutoModelForMaskedLM.from_pretrained(
    model_name, 
    config=config, 
    ignore_mismatched_sizes=True  # Handle vocab size mismatch due to special tokens
)
model.resize_token_embeddings(len(tokenizer))
print(f'Model {model_name} initialized with vocab size {len(tokenizer)} (ignoring size mismatches)')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of DistilBertForMaskedLM were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized because the shapes did not match:
- distilbert.embeddings.word_embeddings.weight: found shape torch.Size([30522, 768]) in the checkpoint and torch.Size([30531, 768]) in the model instantiated
- vocab_projector.bias: found shape torch.Size([30522]) in the checkpoint and torch.Size([30531]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForMaskedLM were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized because the shapes did not match:
- distilbert.embeddings.word_embeddings.weight: found shape torch.Size([30522, 768]) in the checkpoint and torch.Size([30531, 768]) in the model 

Model distilbert-base-uncased initialized with vocab size 30531 (ignoring size mismatches)


## 4. DataLoaders

In [18]:
seq_cfg = train_cfg['sequence']
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=seq_cfg['mlm_probability'],
    pad_to_multiple_of=8,  # Efficient padding for tensor cores
    return_tensors="pt"
)

def collate_without_anomaly(examples):
    # Remove non-essential columns and ensure proper format
    cleaned_examples = []
    for example in examples:
        # Keep only the essential columns for MLM
        cleaned_example = {
            'input_ids': example['input_ids'],
            'attention_mask': example['attention_mask'],
        }
        cleaned_examples.append(cleaned_example)
    return collator(cleaned_examples)

per_device_train_bs = train_cfg['training']['train_batch_size_per_device']
per_device_eval_bs = train_cfg['training']['eval_batch_size_per_device']

train_dataloader = DataLoader(train_dataset, batch_size=per_device_train_bs, shuffle=True, collate_fn=collate_without_anomaly)
val_dataloader = DataLoader(val_dataset, batch_size=per_device_eval_bs, shuffle=False, collate_fn=collate_without_anomaly)
test_dataloader = DataLoader(test_dataset, batch_size=per_device_eval_bs, shuffle=False, collate_fn=collate_without_anomaly)

In [15]:
# Debug: Check dataset structure and sample batch
print("Dataset columns:", train_dataset.column_names)
print("Sample record:", train_dataset[0])
print("\nInput IDs type and length:", type(train_dataset[0]['input_ids']), len(train_dataset[0]['input_ids']))
print("Attention mask type and length:", type(train_dataset[0]['attention_mask']), len(train_dataset[0]['attention_mask']))

# Test the collator with a small batch
test_batch = [train_dataset[i] for i in range(3)]
try:
    collated = collate_without_anomaly(test_batch)
    print("Collation successful!")
    print("Collated keys:", collated.keys())
    for key, value in collated.items():
        if hasattr(value, 'shape'):
            print(f"{key} shape: {value.shape}")
except Exception as e:
    print(f"Collation error: {e}")
    print("Raw batch sample:", test_batch[0])

Dataset columns: ['input_ids', 'attention_mask', 'labels', 'template_id', 'anomaly_label', 'timestamp']
Sample record: {'input_ids': [101, 30529, 30529, 30524, 18558, 1040, 10343, 1012, 1042, 2015, 18442, 6508, 13473, 2213, 1024, 3796, 1008, 3415, 27268, 6633, 1012, 2035, 24755, 2618, 23467, 1024, 1013, 24098, 2102, 1013, 2018, 18589, 1013, 4949, 5596, 1013, 2291, 1013, 3105, 1035, 2263, 14526, 2692, 2683, 11387, 14142, 1035, 2199, 2487, 1013, 3105, 1012, 15723, 1012, 1038, 13687, 1035, 1011, 30530, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [101, 30529, 30529, 30524, 18558, 1040, 10343, 1012, 1042, 2015, 18442, 6508, 13473, 2213, 1024, 3796, 1008, 3415, 27268, 6633, 1012, 2035, 24755, 2618, 23467, 1024, 1013, 24098, 2102, 1013, 2018, 18589, 1013, 4949, 5596, 1013, 2291, 1013, 3105, 1035, 2263, 14526, 2692, 2683, 11

In [18]:
# Test a single forward pass to debug device issues
print("Testing single forward pass...")
test_batch = next(iter(train_dataloader))
print("Batch keys:", test_batch.keys())
print("Device placement:")
for key, value in test_batch.items():
    if torch.is_tensor(value):
        print(f"  {key}: {value.device}, shape: {value.shape}")
    else:
        print(f"  {key}: {type(value)}")

print(f"\nModel device: {next(model.parameters()).device}")

# Try a forward pass
model.train()
try:
    with torch.no_grad():
        outputs = model(**test_batch)
    print("✓ Forward pass successful!")
    print(f"Loss: {outputs.loss.item():.4f}")
except Exception as e:
    print(f"✗ Forward pass failed: {e}")
    print("Model parameters device:", next(model.parameters()).device)
    print("Accelerator device:", accelerator.device)

Testing single forward pass...
Batch keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
Device placement:
  input_ids: cuda:0, shape: torch.Size([16, 80])
  attention_mask: cuda:0, shape: torch.Size([16, 80])
  labels: cuda:0, shape: torch.Size([16, 80])

Model device: cuda:0
✓ Forward pass successful!
Loss: 10.5156


## 5. Accelerator Setup

In [9]:
mixed_precision = 'no'  # Disable mixed precision to avoid gradient scaling issues
accelerator = Accelerator(
    gradient_accumulation_steps=train_cfg['training']['grad_accumulation_steps'],
    mixed_precision=mixed_precision
)
print(accelerator.state)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_cfg['optimizer']['lr'],
    betas=tuple(train_cfg['optimizer']['betas']),
    eps=train_cfg['optimizer']['eps'],
    weight_decay=train_cfg['optimizer']['weight_decay']
)

model, optimizer, train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, test_dataloader
)

num_update_steps = math.ceil(len(train_dataloader) / train_cfg['training']['grad_accumulation_steps'])
total_steps = num_update_steps * train_cfg['training']['epochs']
print(f'Total training steps: {total_steps}')

lr_scheduler = get_scheduler(
    name=train_cfg['optimizer']['scheduler'],
    optimizer=optimizer,
    num_warmup_steps=train_cfg['optimizer']['warmup_steps'],
    num_training_steps=total_steps
)

Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

Total training steps: 104772


## 6. Memory Utilities

In [12]:
def free_cuda():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    gc.collect()


def log_gpu_memory(tag: str):
    import torch  # Ensure torch is available in function scope
    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
        accelerator.print(f'[{tag}] gpu allocated={alloc:.2f} GB reserved={reserved:.2f} GB')
    elif IS_MPS:
        try:
            import torch.mps
            stats = torch.mps.current_allocated_memory() / (1024 ** 3)
            accelerator.print(f'[{tag}] mps allocated={stats:.2f} GB')
        except Exception:
            accelerator.print(f'[{tag}] mps memory stats unavailable.')

In [None]:
# SIMPLE TRAINING RESTART - No resume calculations, just train normally
import shutil

# Basic training setup
train_epochs = train_cfg['training']['epochs'] 
max_grad_norm = train_cfg['training']['max_grad_norm']
save_steps = 50000

history = {'step': [], 'loss': [], 'lr': [], 'throughput': []}
step_times = deque(maxlen=200)
ema_throughput = None

def save_checkpoint_simple(model, optimizer, scheduler, epoch: int, step: int, tag: str):
    accelerator.wait_for_everyone()
    ckpt_dir = Path(checkpoint_cfg['output_dir']) / f'{tag}_epoch{epoch}_step{step}'
    if accelerator.is_main_process:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir / 'tokenizer')
        torch.save(optimizer.state_dict(), ckpt_dir / 'optimizer.pt')
        torch.save(scheduler.state_dict(), ckpt_dir / 'scheduler.pt')
        with (ckpt_dir / 'training_state.json').open('w') as fh:
            json.dump({'epoch': epoch, 'step': step}, fh)
        accelerator.print(f'💾 [checkpoint] saved -> {ckpt_dir}')
    accelerator.wait_for_everyone()
    free_cuda()

print("🚀 Starting SIMPLE training (no resume calculations)")
print(f"📊 Training for {train_epochs} epochs")
print(f"📊 Steps per epoch: {len(train_dataloader)}")

# Start fresh from epoch 3 (since we know we're in epoch 3)
for epoch in range(2, train_epochs):  # Start from epoch 3 (0-indexed = 2) 
    model.train()
    accelerator.print(f'==== Epoch {epoch+1}/{train_epochs} ====')
    
    progress = tqdm(
        total=len(train_dataloader), 
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch+1}"
    )
    
    steps_processed = 0
    
    for step, batch in enumerate(train_dataloader, start=1):
        start_time = time.perf_counter()
        
        # Ensure all tensors are on the correct device
        batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
        
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            
            if max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
        duration = time.perf_counter() - start_time
        step_times.append(duration)
        steps_processed += 1

        global_step = epoch * len(train_dataloader) + step
        tokens_processed = batch['input_ids'].numel()
        throughput = tokens_processed / max(duration, 1e-6)
        beta = train_cfg['logging']['throughput_ema_beta']
        ema_throughput = throughput if ema_throughput is None else (beta * ema_throughput + (1 - beta) * throughput)

        history['step'].append(int(global_step))
        history['loss'].append(float(loss.detach().item()))
        history['lr'].append(float(lr_scheduler.get_last_lr()[0]))
        history['throughput'].append(float(ema_throughput))

        if accelerator.is_local_main_process:
            p95 = float(np.percentile(step_times, 95)) if len(step_times) >= 5 else float(duration)
            progress.set_description(f'Epoch {epoch+1} | loss={loss.item():.4f} | tok/s={ema_throughput:,.0f} | p95={p95:.3f}s')
        progress.update(1)

        # Log first few steps and every 100 steps
        if steps_processed <= 10 or steps_processed % 100 == 0:
            accelerator.print(f"✅ Step {step} processed (global: {global_step}, loss: {loss.item():.4f})")

        if train_cfg['logging']['log_steps'] and global_step % train_cfg['logging']['log_steps'] == 0 and accelerator.is_main_process:
            log_gpu_memory(f'step {global_step}')

        # Save checkpoint every 50k steps
        if save_steps and global_step % save_steps == 0:
            save_checkpoint_simple(model, optimizer, lr_scheduler, epoch+1, global_step, tag='step')
        
        # Early break for testing (remove this after confirming it works)
        if steps_processed >= 10:
            accelerator.print(f"🔥 Early break after {steps_processed} steps for testing")
            break

    progress.close()
    
    # Save epoch checkpoint
    epoch_steps = (epoch + 1) * len(train_dataloader)
    save_checkpoint_simple(model, optimizer, lr_scheduler, epoch+1, epoch_steps, tag='epoch')

    # Validation
    model.eval()
    val_losses = []
    val_progress = tqdm(val_dataloader, desc="Validation", disable=not accelerator.is_local_main_process)
    for batch in val_progress:
        with torch.no_grad():
            outputs = model(**batch)
            val_losses.append(accelerator.gather(outputs.loss.detach()).mean().item())
    val_loss = float(np.mean(val_losses))
    accelerator.print(f'📊 Validation loss after epoch {epoch+1}: {val_loss:.4f}')
    
    # Break after first epoch for testing
    accelerator.print("🔥 Breaking after first epoch for testing")
    break

accelerator.wait_for_everyone()
free_cuda()
accelerator.print("🎉 Training test completed!")

In [21]:
# ULTRA SIMPLE TEST - Just iterate through dataloader
print("🧪 Testing basic dataloader iteration...")

count = 0
for step, batch in enumerate(train_dataloader, start=1):
    count += 1
    if count <= 5:
        print(f"✅ Step {step}: batch shape = {batch['input_ids'].shape}")
    
    if count >= 10:
        print(f"🎉 Successfully processed {count} batches!")
        break
        
print("🔍 Basic iteration test completed")

🧪 Testing basic dataloader iteration...
✅ Step 1: batch shape = torch.Size([16, 80])
✅ Step 2: batch shape = torch.Size([16, 72])
✅ Step 3: batch shape = torch.Size([16, 56])
✅ Step 4: batch shape = torch.Size([16, 56])
✅ Step 5: batch shape = torch.Size([16, 72])
🎉 Successfully processed 10 batches!
🔍 Basic iteration test completed


In [22]:
# WORKING TRAINING LOOP - Fresh start from current state
print("🚀 Starting fresh training from current model state...")

# Simple training parameters
max_grad_norm = train_cfg['training']['max_grad_norm']
save_steps = 50000

history = {'step': [], 'loss': [], 'lr': [], 'throughput': []}
step_times = deque(maxlen=200)
ema_throughput = None

# Just run one epoch for now to prove it works
epoch = 2  # Epoch 3 (0-indexed)
model.train()
accelerator.print(f'==== Starting Epoch {epoch+1} ====')

progress = tqdm(
    total=len(train_dataloader), 
    disable=not accelerator.is_local_main_process,
    desc=f"Epoch {epoch+1}"
)

steps_completed = 0

for step, batch in enumerate(train_dataloader, start=1):
    start_time = time.perf_counter()
    
    # Ensure all tensors are on the correct device
    batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
    
    with accelerator.accumulate(model):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        
    duration = time.perf_counter() - start_time
    step_times.append(duration)
    steps_completed += 1

    global_step = epoch * len(train_dataloader) + step
    tokens_processed = batch['input_ids'].numel()
    throughput = tokens_processed / max(duration, 1e-6)
    beta = train_cfg['logging']['throughput_ema_beta']
    ema_throughput = throughput if ema_throughput is None else (beta * ema_throughput + (1 - beta) * throughput)

    history['step'].append(int(global_step))
    history['loss'].append(float(loss.detach().item()))
    history['lr'].append(float(lr_scheduler.get_last_lr()[0]))
    history['throughput'].append(float(ema_throughput))

    if accelerator.is_local_main_process:
        p95 = float(np.percentile(step_times, 95)) if len(step_times) >= 5 else float(duration)
        progress.set_description(f'Epoch {epoch+1} | loss={loss.item():.4f} | tok/s={ema_throughput:,.0f} | p95={p95:.3f}s')
    progress.update(1)

    # Print progress for first few steps
    if steps_completed <= 10 or steps_completed % 1000 == 0:
        accelerator.print(f"✅ Completed step {step} (global: {global_step}) - Loss: {loss.item():.4f}")

    if train_cfg['logging']['log_steps'] and global_step % train_cfg['logging']['log_steps'] == 0 and accelerator.is_main_process:
        log_gpu_memory(f'step {global_step}')
    
    # Test run - stop after 100 steps to verify it's working
    if steps_completed >= 100:
        accelerator.print(f"🎉 Test successful! Completed {steps_completed} training steps")
        break

progress.close()
accelerator.print(f"✅ Training test completed - processed {steps_completed} steps")

🚀 Starting fresh training from current model state...
==== Starting Epoch 3 ====


Epoch 3:   0%|          | 0/279391 [00:00<?, ?it/s]

✅ Completed step 1 (global: 558783) - Loss: 0.3675
✅ Completed step 2 (global: 558784) - Loss: 0.2407
✅ Completed step 3 (global: 558785) - Loss: 0.3938
✅ Completed step 4 (global: 558786) - Loss: 0.4291
✅ Completed step 5 (global: 558787) - Loss: 0.4810
✅ Completed step 6 (global: 558788) - Loss: 0.4124
✅ Completed step 7 (global: 558789) - Loss: 0.7115
✅ Completed step 8 (global: 558790) - Loss: 0.4027
✅ Completed step 9 (global: 558791) - Loss: 0.1852
✅ Completed step 10 (global: 558792) - Loss: 0.2198
[step 558800] gpu allocated=1.17 GB reserved=3.29 GB
🎉 Test successful! Completed 100 training steps
✅ Training test completed - processed 100 steps


In [23]:
# PRODUCTION TRAINING - Continue from current state without complex resume
print("🚀 Starting PRODUCTION training from current model state...")

# Training parameters
max_grad_norm = train_cfg['training']['max_grad_norm']
save_steps = 50000
train_epochs = train_cfg['training']['epochs']

def save_checkpoint_final(model, optimizer, scheduler, epoch: int, step: int, tag: str):
    accelerator.wait_for_everyone()
    ckpt_dir = Path(checkpoint_cfg['output_dir']) / f'{tag}_epoch{epoch}_step{step}'
    if accelerator.is_main_process:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir / 'tokenizer')
        torch.save(optimizer.state_dict(), ckpt_dir / 'optimizer.pt')
        torch.save(scheduler.state_dict(), ckpt_dir / 'scheduler.pt')
        with (ckpt_dir / 'training_state.json').open('w') as fh:
            json.dump({'epoch': epoch, 'step': step}, fh)
        accelerator.print(f'💾 [checkpoint] saved -> {ckpt_dir}')
    accelerator.wait_for_everyone()
    free_cuda()

# Continue from epoch 3 (we know we're in the middle of epoch 3)
for epoch in range(2, train_epochs):  # Start from epoch 3 (0-indexed = 2)
    model.train()
    accelerator.print(f'==== Epoch {epoch+1}/{train_epochs} ====')
    
    progress = tqdm(
        total=len(train_dataloader), 
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch+1}"
    )
    
    steps_completed = 0
    
    for step, batch in enumerate(train_dataloader, start=1):
        start_time = time.perf_counter()
        
        # Ensure all tensors are on the correct device
        batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
        
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            
            if max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
        duration = time.perf_counter() - start_time
        step_times.append(duration)
        steps_completed += 1

        global_step = epoch * len(train_dataloader) + step
        tokens_processed = batch['input_ids'].numel()
        throughput = tokens_processed / max(duration, 1e-6)
        beta = train_cfg['logging']['throughput_ema_beta']
        ema_throughput = throughput if ema_throughput is None else (beta * ema_throughput + (1 - beta) * throughput)

        history['step'].append(int(global_step))
        history['loss'].append(float(loss.detach().item()))
        history['lr'].append(float(lr_scheduler.get_last_lr()[0]))
        history['throughput'].append(float(ema_throughput))

        if accelerator.is_local_main_process:
            p95 = float(np.percentile(step_times, 95)) if len(step_times) >= 5 else float(duration)
            progress.set_description(f'Epoch {epoch+1} | loss={loss.item():.4f} | tok/s={ema_throughput:,.0f} | p95={p95:.3f}s')
        progress.update(1)

        # Log every 1000 steps
        if steps_completed % 1000 == 0:
            accelerator.print(f"📊 Step {step} (global: {global_step}) - Loss: {loss.item():.4f} - {steps_completed:,} completed")

        if train_cfg['logging']['log_steps'] and global_step % train_cfg['logging']['log_steps'] == 0 and accelerator.is_main_process:
            log_gpu_memory(f'step {global_step}')

        # Save checkpoint every 50k steps
        if save_steps and global_step % save_steps == 0:
            save_checkpoint_final(model, optimizer, lr_scheduler, epoch+1, global_step, tag='step')

    progress.close()
    
    # Save epoch checkpoint
    epoch_steps = (epoch + 1) * len(train_dataloader)
    save_checkpoint_final(model, optimizer, lr_scheduler, epoch+1, epoch_steps, tag='epoch')

    # Validation
    accelerator.print("📊 Running validation...")
    model.eval()
    val_losses = []
    val_progress = tqdm(val_dataloader, desc="Validation", disable=not accelerator.is_local_main_process)
    for batch in val_progress:
        with torch.no_grad():
            outputs = model(**batch)
            val_losses.append(accelerator.gather(outputs.loss.detach()).mean().item())
    val_loss = float(np.mean(val_losses))
    accelerator.print(f'📊 Validation loss after epoch {epoch+1}: {val_loss:.4f}')

accelerator.wait_for_everyone()
free_cuda()
accelerator.print("🎉 Training completed!")

🚀 Starting PRODUCTION training from current model state...
==== Epoch 3/3 ====


Epoch 3:   0%|          | 0/279391 [00:00<?, ?it/s]

[step 558800] gpu allocated=1.17 GB reserved=3.29 GB
[step 558900] gpu allocated=1.18 GB reserved=3.29 GB
[step 559000] gpu allocated=1.17 GB reserved=3.29 GB
[step 559100] gpu allocated=1.14 GB reserved=3.29 GB
[step 559200] gpu allocated=1.15 GB reserved=3.29 GB
[step 559300] gpu allocated=1.12 GB reserved=3.29 GB
[step 559400] gpu allocated=1.16 GB reserved=3.29 GB
[step 559500] gpu allocated=1.18 GB reserved=3.29 GB
[step 559600] gpu allocated=1.17 GB reserved=3.29 GB
[step 559700] gpu allocated=1.15 GB reserved=3.99 GB
📊 Step 1000 (global: 559782) - Loss: 0.3680 - 1,000 completed
[step 559800] gpu allocated=1.16 GB reserved=3.99 GB
[step 559900] gpu allocated=1.18 GB reserved=3.99 GB
[step 560000] gpu allocated=1.17 GB reserved=3.99 GB
[step 560100] gpu allocated=1.15 GB reserved=3.99 GB
[step 560200] gpu allocated=1.13 GB reserved=3.99 GB
[step 560300] gpu allocated=1.17 GB reserved=3.99 GB
[step 560400] gpu allocated=1.14 GB reserved=3.99 GB
[step 560500] gpu allocated=1.14 GB r

Validation:   0%|          | 0/17462 [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

## 7. Training Loop

In [13]:
checkpoint_cfg = train_cfg['checkpointing']
metrics_dir = Path(train_cfg['artifacts']['metrics_dir'])
metrics_dir.mkdir(parents=True, exist_ok=True)
run_config_path = Path(train_cfg['artifacts']['run_config_path'])
run_config_path.parent.mkdir(parents=True, exist_ok=True)

history = {'step': [], 'loss': [], 'lr': [], 'throughput': []}
step_times = deque(maxlen=200)
ema_throughput = None

# Override save_steps to 50k for better disk management
def cleanup_old_checkpoints(checkpoint_dir: Path, keep_latest: int = 3):
    """Keep only the most recent checkpoints to save disk space."""
    if not checkpoint_dir.exists():
        return
    
    # Get all checkpoint directories, sorted by modification time (newest first)
    checkpoints = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
    checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
    
    # Delete old checkpoints beyond keep_latest
    for old_ckpt in checkpoints[keep_latest:]:
        accelerator.print(f'🗑️  Cleaning up old checkpoint: {old_ckpt.name}')
        shutil.rmtree(old_ckpt)

def save_checkpoint(model, optimizer, scheduler, epoch: int, step: int, tag: str):
    accelerator.wait_for_everyone()
    ckpt_dir = Path(checkpoint_cfg['output_dir']) / f'{tag}_epoch{epoch}_step{step}'
    if accelerator.is_main_process:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir / 'tokenizer')
        torch.save(optimizer.state_dict(), ckpt_dir / 'optimizer.pt')
        torch.save(scheduler.state_dict(), ckpt_dir / 'scheduler.pt')
        with (ckpt_dir / 'training_state.json').open('w') as fh:
            json.dump({'epoch': epoch, 'step': step}, fh)
        accelerator.print(f'[checkpoint] saved -> {ckpt_dir}')
        
        # Clean up old checkpoints to save space (keep latest 3)
        cleanup_old_checkpoints(Path(checkpoint_cfg['output_dir']), keep_latest=3)
    accelerator.wait_for_everyone()
    free_cuda()

# Check for existing checkpoints to resume from
def find_latest_checkpoint(checkpoint_dir: Path):
    """Find the most recent checkpoint to resume from."""
    if not checkpoint_dir.exists():
        return None, 0, 0
    
    checkpoints = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
    if not checkpoints:
        return None, 0, 0
    
    # Sort by modification time (newest first)
    checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
    latest_ckpt = checkpoints[0]
    
    # Parse epoch and step from checkpoint name
    import re
    match = re.search(r'epoch(\d+)_step(\d+)', latest_ckpt.name)
    if match:
        epoch = int(match.group(1))
        step = int(match.group(2))
        return latest_ckpt, epoch, step
    return None, 0, 0

train_epochs = train_cfg['training']['epochs']
max_grad_norm = train_cfg['training']['max_grad_norm']

# Check for resuming from checkpoint
latest_ckpt_path, resume_epoch, resume_step = find_latest_checkpoint(Path(checkpoint_cfg['output_dir']))
start_epoch = 0
resume_global_step = 0

if latest_ckpt_path:
    accelerator.print(f"🔄 Found checkpoint to resume from: {latest_ckpt_path.name}")
    accelerator.print(f"📍 Resuming from epoch {resume_epoch}, step {resume_step}")
    
    if accelerator.is_main_process:
        # Load model state
        unwrapped_model = accelerator.unwrap_model(model)
        model_path = latest_ckpt_path / 'pytorch_model.bin'
        if not model_path.exists():
            model_path = latest_ckpt_path / 'model.safetensors'
        
        if model_path.exists():
            unwrapped_model.from_pretrained(latest_ckpt_path)
        
        # Load optimizer state
        if (latest_ckpt_path / 'optimizer.pt').exists():
            optimizer.load_state_dict(torch.load(latest_ckpt_path / 'optimizer.pt', map_location='cpu'))
        
        # Load scheduler state  
        if (latest_ckpt_path / 'scheduler.pt').exists():
            lr_scheduler.load_state_dict(torch.load(latest_ckpt_path / 'scheduler.pt', map_location='cpu'))
    
    # Calculate where to start
    start_epoch = resume_epoch - 1 if resume_epoch > 0 else 0
    resume_global_step = resume_step
    
    accelerator.print(f"✅ Checkpoint loaded successfully!")
else:
    accelerator.print(f"🆕 No existing checkpoints found. Starting training from scratch.")

for epoch in range(start_epoch, train_epochs):
    model.train()
    accelerator.print(f'==== Epoch {epoch+1}/{train_epochs} ====')
    
    # Calculate starting step for this epoch
    epoch_start_step = 1
    if epoch == start_epoch and resume_global_step > 0:
        # If resuming in the middle of an epoch, calculate the step to start from
        steps_per_epoch = len(train_dataloader)
        epoch_start_step = (resume_global_step % steps_per_epoch) + 1
        accelerator.print(f"🔄 Resuming epoch {epoch+1} from step {epoch_start_step}")
    
    progress = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
    progress.update(epoch_start_step - 1)  # Update progress bar to show resumed position
    
    for step, batch in enumerate(train_dataloader, start=1):
        # Skip steps if resuming
        if step < epoch_start_step:
            continue
        start_time = time.perf_counter()
        
        # Ensure all tensors are on the correct device
        batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
        
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            
            if max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        duration = time.perf_counter() - start_time
        step_times.append(duration)

        global_step = epoch * len(train_dataloader) + step
        tokens_processed = batch['input_ids'].numel()
        throughput = tokens_processed / max(duration, 1e-6)
        beta = train_cfg['logging']['throughput_ema_beta']
        ema_throughput = throughput if ema_throughput is None else (beta * ema_throughput + (1 - beta) * throughput)

        history['step'].append(int(global_step))
        history['loss'].append(float(loss.detach().item()))
        history['lr'].append(float(lr_scheduler.get_last_lr()[0]))
        history['throughput'].append(float(ema_throughput))

        if accelerator.is_local_main_process:
            p95 = float(np.percentile(step_times, 95)) if len(step_times) >= 5 else float(duration)
            progress.set_description(f'loss={loss.item():.4f} ema_tok_s={ema_throughput:,.0f} p95={p95:.3f}s')
        progress.update(1)

        if train_cfg['logging']['log_steps'] and global_step % train_cfg['logging']['log_steps'] == 0 and accelerator.is_main_process:
            log_gpu_memory(f'step {global_step}')

        if save_steps and global_step % save_steps == 0:

            save_checkpoint(model, optimizer, lr_scheduler, epoch+1, global_step, tag='step')
    accelerator.print(f'Validation loss after epoch {epoch+1}: {val_loss:.4f}')free_cuda()


accelerator.wait_for_everyone()

    progress.close()
accelerator.wait_for_everyone()

    epoch_steps = (epoch + 1) * len(train_dataloader)
free_cuda()    accelerator.print(f'Validation loss after epoch {epoch+1}: {val_loss:.4f}')

    save_checkpoint(model, optimizer, lr_scheduler, epoch+1, epoch_steps, tag='epoch')    val_loss = float(np.mean(losses))

            losses.append(accelerator.gather(outputs.loss.detach()).mean().item())

    model.eval()            outputs = model(**batch)

    losses = []        with torch.no_grad():
    for batch in val_dataloader:

==== Epoch 1/3 ====


  0%|          | 0/279391 [00:00<?, ?it/s]

[step 100] gpu allocated=0.90 GB reserved=2.46 GB
[step 200] gpu allocated=1.15 GB reserved=2.46 GB
[step 200] gpu allocated=1.15 GB reserved=2.46 GB
[step 300] gpu allocated=0.92 GB reserved=2.46 GB
[step 300] gpu allocated=0.92 GB reserved=2.46 GB
[step 400] gpu allocated=1.15 GB reserved=3.10 GB
[step 400] gpu allocated=1.15 GB reserved=3.10 GB
[step 500] gpu allocated=0.90 GB reserved=3.80 GB
[step 500] gpu allocated=0.90 GB reserved=3.80 GB
[step 600] gpu allocated=1.17 GB reserved=3.80 GB
[step 600] gpu allocated=1.17 GB reserved=3.80 GB
[step 700] gpu allocated=0.90 GB reserved=3.80 GB
[step 700] gpu allocated=0.90 GB reserved=3.80 GB
[step 800] gpu allocated=1.15 GB reserved=3.80 GB
[step 800] gpu allocated=1.15 GB reserved=3.80 GB
[step 900] gpu allocated=0.87 GB reserved=3.80 GB
[step 900] gpu allocated=0.87 GB reserved=3.80 GB
[step 1000] gpu allocated=1.14 GB reserved=3.80 GB
[step 1000] gpu allocated=1.14 GB reserved=3.80 GB
[step 1100] gpu allocated=0.92 GB reserved=3.80 

  0%|          | 0/279391 [00:00<?, ?it/s]

[step 279400] gpu allocated=1.15 GB reserved=3.02 GB
[step 279500] gpu allocated=1.12 GB reserved=3.03 GB
[step 279500] gpu allocated=1.12 GB reserved=3.03 GB
[step 279600] gpu allocated=1.17 GB reserved=3.03 GB
[step 279600] gpu allocated=1.17 GB reserved=3.03 GB
[step 279700] gpu allocated=1.15 GB reserved=3.03 GB
[step 279700] gpu allocated=1.15 GB reserved=3.03 GB
[step 279800] gpu allocated=1.15 GB reserved=3.03 GB
[step 279800] gpu allocated=1.15 GB reserved=3.03 GB
[step 279900] gpu allocated=1.15 GB reserved=3.03 GB
[step 279900] gpu allocated=1.15 GB reserved=3.03 GB
[step 280000] gpu allocated=1.15 GB reserved=3.03 GB
[step 280000] gpu allocated=1.15 GB reserved=3.03 GB
[checkpoint] saved -> artifacts/logbert-mlm-hdfs/step_epoch2_step280000
[checkpoint] saved -> artifacts/logbert-mlm-hdfs/step_epoch2_step280000
[step 280100] gpu allocated=1.17 GB reserved=2.67 GB
[step 280100] gpu allocated=1.17 GB reserved=2.67 GB
[step 280200] gpu allocated=1.12 GB reserved=2.67 GB
[step 28

  0%|          | 0/279391 [00:00<?, ?it/s]

[step 558800] gpu allocated=1.15 GB reserved=3.04 GB
[step 558900] gpu allocated=1.17 GB reserved=3.04 GB
[step 558900] gpu allocated=1.17 GB reserved=3.04 GB
[step 559000] gpu allocated=1.15 GB reserved=3.04 GB
[step 559000] gpu allocated=1.15 GB reserved=3.04 GB
[step 559100] gpu allocated=1.17 GB reserved=3.04 GB
[step 559100] gpu allocated=1.17 GB reserved=3.04 GB
[step 559200] gpu allocated=1.15 GB reserved=3.04 GB
[step 559200] gpu allocated=1.15 GB reserved=3.04 GB
[step 559300] gpu allocated=1.15 GB reserved=3.04 GB
[step 559300] gpu allocated=1.15 GB reserved=3.04 GB
[step 559400] gpu allocated=1.15 GB reserved=3.04 GB
[step 559400] gpu allocated=1.15 GB reserved=3.04 GB
[step 559500] gpu allocated=1.15 GB reserved=3.04 GB
[step 559500] gpu allocated=1.15 GB reserved=3.04 GB
[step 559600] gpu allocated=1.15 GB reserved=3.04 GB
[step 559600] gpu allocated=1.15 GB reserved=3.04 GB
[step 559700] gpu allocated=1.17 GB reserved=3.04 GB
[step 559700] gpu allocated=1.17 GB reserved=3

SafetensorError: Error while serializing: I/O error: No space left on device (os error 28)

## 8. Evaluation and Metrics

In [None]:
model.eval()

@torch.no_grad()
def evaluate_loader(dataloader) -> float:
    losses = []
    for batch in dataloader:
        outputs = model(**batch)
        gathered = accelerator.gather(outputs.loss)
        losses.extend(gathered.cpu().numpy())
    return float(np.mean(losses))

val_loss = evaluate_loader(val_dataloader)
test_loss = evaluate_loader(test_dataloader)
perplexity = math.exp(test_loss)

metrics = {
    'val_loss': val_loss,
    'test_loss': test_loss,
    'test_perplexity': perplexity,
    'steps_tracked': len(history['step'])
}
metrics_path = metrics_dir / 'hdfs_pretraining_metrics.json'
metrics_path.write_text(json.dumps(metrics, indent=2))
print(json.dumps(metrics, indent=2))

## 9. Persist Run Config

In [None]:
state_summary = {
    'num_processes': accelerator.state.num_processes,
    'process_index': accelerator.state.process_index,
    'local_process_index': accelerator.state.local_process_index,
    'device': str(accelerator.device),
    'mixed_precision': accelerator.state.mixed_precision,
    'distributed_type': str(accelerator.state.distributed_type)
}
run_payload = {
    'train_config': train_cfg,
    'data_config': data_cfg,
    'accelerator_state': state_summary,
    'is_mps': IS_MPS
}
run_config_path.write_text(json.dumps(run_payload, indent=2))
print(f'Run configuration written to {run_config_path}')

## Artifacts Produced
- Checkpoints saved under `artifacts/logbert-mlm-hdfs/`
- Validation/test metrics stored at `artifacts/metrics/hdfs/hdfs_pretraining_metrics.json`
- Run configuration captured at `artifacts/logbert-mlm-hdfs/run_config.json`

Continuing pipeline: open `02_finetune_openstack.ipynb`.

In [25]:
# VALIDATION ONLY - Skip checkpoint saving due to disk space, run validation with current model
print("🎉 Training completed! Running final validation...")
print("⚠️ Skipping checkpoint save due to disk space - using current model in memory")

# Ensure we're using the trained model
model.eval()
accelerator.print("📊 Starting final validation on test and validation sets...")

# Validation on validation set
val_losses = []
val_progress = tqdm(val_dataloader, desc="Validation Set", disable=not accelerator.is_local_main_process)
for batch in val_progress:
    # Ensure all tensors are on the correct device
    batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
    
    with torch.no_grad():
        outputs = model(**batch)
        val_losses.append(accelerator.gather(outputs.loss.detach()).mean().item())
val_loss = float(np.mean(val_losses))
accelerator.print(f'📊 Final Validation Loss: {val_loss:.4f}')

# Test set evaluation
test_losses = []
test_progress = tqdm(test_dataloader, desc="Test Set", disable=not accelerator.is_local_main_process)
for batch in test_progress:
    # Ensure all tensors are on the correct device
    batch = {k: v.to(accelerator.device) if torch.is_tensor(v) else v for k, v in batch.items()}
    
    with torch.no_grad():
        outputs = model(**batch)
        test_losses.append(accelerator.gather(outputs.loss.detach()).mean().item())
test_loss = float(np.mean(test_losses))
test_perplexity = math.exp(test_loss)

accelerator.print(f'📊 Final Test Loss: {test_loss:.4f}')
accelerator.print(f'📊 Final Test Perplexity: {test_perplexity:.4f}')

# Save metrics (much smaller file)
final_metrics = {
    'final_val_loss': val_loss,
    'final_test_loss': test_loss, 
    'final_test_perplexity': test_perplexity,
    'training_completed': True,
    'total_steps_completed': len(history['step']) if 'history' in locals() and history['step'] else 0
}

metrics_path = metrics_dir / 'hdfs_pretraining_final_metrics.json'
if accelerator.is_main_process:
    metrics_path.write_text(json.dumps(final_metrics, indent=2))
    accelerator.print(f'💾 Final metrics saved to: {metrics_path}')

accelerator.print("🎉 MLM Pretraining completed successfully!")
accelerator.print("📋 Summary:")
accelerator.print(f"   • Validation Loss: {val_loss:.4f}")  
accelerator.print(f"   • Test Loss: {test_loss:.4f}")
accelerator.print(f"   • Test Perplexity: {test_perplexity:.4f}")
accelerator.print("🚀 Ready to proceed to fine-tuning!")

# Clean up
free_cuda()
accelerator.wait_for_everyone()

🎉 Training completed! Running final validation...
⚠️ Skipping checkpoint save due to disk space - using current model in memory
📊 Starting final validation on test and validation sets...


Validation Set:   0%|          | 0/17462 [00:00<?, ?it/s]

📊 Final Validation Loss: 0.4631


Test Set:   0%|          | 0/17462 [00:00<?, ?it/s]

📊 Final Test Loss: 0.5139
📊 Final Test Perplexity: 1.6718
💾 Final metrics saved to: artifacts/metrics/hdfs/hdfs_pretraining_final_metrics.json
🎉 MLM Pretraining completed successfully!
📋 Summary:
   • Validation Loss: 0.4631
   • Test Loss: 0.5139
   • Test Perplexity: 1.6718
🚀 Ready to proceed to fine-tuning!
