# DDIM Next Token V1 Training Debugging Notebook

This notebook is designed to reproduce and debug the NaN/Inf gradient and AMP unscale_() errors encountered during training.

## Error Details
- **Warning**: NaN or Inf gradient norm detected at step 106, epoch 0
- **RuntimeError**: unscale_() has already been called on this optimizer since the last update()

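This error occurs when the automatic mixed precision (AMP) gradient scaler is asked to unscale an optimizer's gradients a second time before the scaler's `update()` has run. In practice this usually means gradients were unscaled for clipping and the optimizer step that would have triggered `update()` was then skipped.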
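
The call-order problem can be illustrated with a small pure-Python model of the bookkeeping a gradient scaler keeps per optimizer. This is a simplified sketch, not PyTorch's implementation: `ToyScaler`, `run_iteration`, and `reproduce_error` are hypothetical names, and only the three-stage idea (ready, unscaled, stepped) and the error text mirror the real `GradScaler`. It also assumes that, under Accelerate, gradient clipping is what triggers the unscale and the prepared optimizer's `step()` is what triggers the scaler's step and update.

```python
from enum import Enum


class Stage(Enum):
    READY = 0      # nothing has happened since the last update()
    UNSCALED = 1   # unscale_() has run
    STEPPED = 2    # step() has run


class ToyScaler:
    """Hypothetical stand-in that tracks only the call-order bookkeeping."""

    def __init__(self):
        self.stage = Stage.READY

    def unscale_(self):
        if self.stage is Stage.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer "
                "since the last update()."
            )
        if self.stage is Stage.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")
        self.stage = Stage.UNSCALED

    def step(self):
        self.stage = Stage.STEPPED

    def update(self):
        # update() is what resets the bookkeeping for the next iteration.
        self.stage = Stage.READY


def run_iteration(scaler, skip_optimizer_step):
    """One training iteration: clip (unscale), then optionally step + update."""
    scaler.unscale_()            # what gradient clipping triggers
    if skip_optimizer_step:      # e.g. bailing out after a NaN gradient norm
        return
    scaler.step()
    scaler.update()


def reproduce_error():
    """Return the error message raised when a step is skipped after clipping."""
    scaler = ToyScaler()
    run_iteration(scaler, skip_optimizer_step=False)  # healthy iteration
    run_iteration(scaler, skip_optimizer_step=True)   # NaN detected, step skipped
    try:
        run_iteration(scaler, skip_optimizer_step=False)  # next clip fails
    except RuntimeError as err:
        return str(err)
    return None


print(reproduce_error())
```

In this model the iteration that skips the step does not fail itself; the error surfaces one iteration later, which matches the warning at step 106 being followed by the `RuntimeError`.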

## Training Context
- Model: DDIMNextTokenV1
- Learning Rate: 2.12e-5  
- Step: 105-106
- Loss: 0.63

Let's systematically debug and fix these issues.

## 1. Import Required Libraries

Import torch, accelerate, and all necessary modules for model, data loading, and training.

In [None]:
# Core PyTorch imports
import torch
import torch.nn.functional as F
import torchvision
from torchvision.utils import make_grid

# Accelerate and diffusers
from accelerate import Accelerator, notebook_launcher
from diffusers.optimization import get_cosine_schedule_with_warmup

# Model and data loading imports
from models import DDIMNextTokenV1
from data_loaders import ModularCharatersDataLoader

# Utilities
import os
import time
import wandb
from tqdm.auto import tqdm
from pathlib import Path

# For debugging
import traceback
import warnings

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 2. Set Training Hyperparameters

Define hyperparameters matching the error context where the issue occurred.
Using conservative settings to minimize NaN occurrence.

In [None]:
# Training hyperparameters based on the error context
model_version = "DDIMNextTokenV1"
dataset_name = "QLeca/modular_characters_hairs_RGB"

# Core training parameters
train_size = 1000  # Reduced for debugging
val_size = 100     # Reduced for debugging  
batch_size = 8     # Reduced from 16 to help stability
num_epochs = 3     # Reduced for debugging

# Learning rate and scheduler parameters
learning_rate = 1e-6  # Very conservative (original error had 2.12e-5)
warming_steps = 100   # Reduced warmup
num_cycles = 0.5

# Debugging parameters
max_nan_tolerance = 5  # Lower tolerance for debugging
debug_mode = True      # Enable detailed logging

# Model parameters
mixed_precision = "no"  # Disable AMP to avoid unscale_() issues initially
gradient_clip_value = 0.5  # Conservative gradient clipping

print("Hyperparameters set:")
print(f"  Model: {model_version}")
print(f"  Dataset: {dataset_name}")
print(f"  Learning Rate: {learning_rate}")
print(f"  Batch Size: {batch_size}")
print(f"  Mixed Precision: {mixed_precision}")
print(f"  Gradient Clipping: {gradient_clip_value}")
print(f"  Train/Val Size: {train_size}/{val_size}")

# Additional training tags for experiment tracking
train_tags = ["debugging", "nan_fix", "conservative_settings"]
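
To see what the scheduler settings above imply, the multiplier applied to `learning_rate` can be computed by hand. The sketch below reimplements in plain Python the warmup-then-cosine formula that `get_cosine_schedule_with_warmup` is understood to use. `lr_factor` is a hypothetical helper name, and `total_steps=375` is an assumption derived from `train_size * num_epochs // batch_size` with the values in this cell.

```python
import math


def lr_factor(step, warmup_steps=100, total_steps=375, num_cycles=0.5):
    """Multiplier on the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Linear warmup from 0 to 1.
        return step / max(1, warmup_steps)
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # With num_cycles=0.5 this is half a cosine period, decaying from 1 to 0.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


base_lr = 1e-6  # learning_rate from the cell above
for step in (0, 50, 100, 237, 375):
    print(f"step {step:>3}: lr = {base_lr * lr_factor(step):.3e}")
```

With a 100-step warmup, any run shorter than 100 optimizer steps never reaches the nominal learning rate, which is worth keeping in mind when comparing short debug runs against the original failure at step 106.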

## 3. Initialize Model Pipeline

Instantiate the DDIMNextTokenV1Pipeline and configure its training parameters.
We'll override the default settings to be more conservative for stability.

In [None]:
# Initialize the pipeline
print("Initializing DDIMNextTokenV1 Pipeline...")
pipeline = DDIMNextTokenV1.DDIMNextTokenV1Pipeline()

# Configure training parameters for stability
pipeline.train_config.train_batch_size = batch_size
pipeline.train_config.eval_batch_size = batch_size
pipeline.train_config.num_epochs = num_epochs
pipeline.train_config.learning_rate = learning_rate
pipeline.train_config.lr_warmup_steps = warming_steps
pipeline.train_config.mixed_precision = mixed_precision

# Additional safety configurations
pipeline.train_config.save_image_epochs = 1  # Save images every epoch for debugging
pipeline.train_config.save_model_epochs = 1  # Save model every epoch for debugging

print("Pipeline initialized successfully!")
print(f"  Device: {pipeline.device}")
print(f"  Image Size: {pipeline.train_config.image_size}")
print(f"  Model Config: {pipeline.model_config.config['in_channels']} → {pipeline.model_config.config['out_channels']} channels")

# Check initial model state
print("\nInitial model parameter check:")
nan_params = pipeline.check_model_for_nan()
if nan_params:
    print("⚠️  WARNING: Model already contains NaN parameters!")
else:
    print("✅ Model parameters are clean (no NaN values)")

# Print model summary
total_params = sum(p.numel() for p in pipeline.unet.parameters())
trainable_params = sum(p.numel() for p in pipeline.unet.parameters() if p.requires_grad)
print(f"\nModel Summary:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 4. Prepare Data Loaders

Set up the training and validation data loaders using the ModularCharatersDataLoader.
We'll also validate the data to ensure it doesn't contain NaN values.

In [None]:
# Create data loaders
print("Setting up data loaders...")
try:
    train_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
        dataset_name=dataset_name,
        split="train",
        image_size=pipeline.train_config.image_size,
        batch_size=pipeline.train_config.train_batch_size,
        shuffle=True,
    )
    
    val_dataloader = ModularCharatersDataLoader.get_modular_char_dataloader(
        dataset_name=dataset_name,
        split="train",  # Using same split for validation during debugging
        image_size=pipeline.train_config.image_size,
        batch_size=pipeline.train_config.eval_batch_size,
        shuffle=True,
    )
    
    print("✅ Data loaders created successfully!")
    print(f"  Train dataset size: {len(train_dataloader.dataset)}")
    print(f"  Val dataset size: {len(val_dataloader.dataset)}")
    print(f"  Train batches: {len(train_dataloader)}")
    print(f"  Val batches: {len(val_dataloader)}")
    
except Exception as e:
    print(f"❌ Error creating data loaders: {e}")
    traceback.print_exc()
    raise

# Validate a sample batch for NaN values
print("\nValidating sample batch for NaN values...")
try:
    sample_batch = next(iter(train_dataloader))
    input_images = sample_batch["input"]
    target_images = sample_batch["target"]
    class_labels = sample_batch["label"]
    
    print(f"  Input shape: {input_images.shape}")
    print(f"  Target shape: {target_images.shape}")
    print(f"  Labels shape: {class_labels.shape}")
    
    # Check for NaN values in the data
    input_nan = torch.isnan(input_images).any()
    target_nan = torch.isnan(target_images).any()
    
    if input_nan or target_nan:
        print(f"❌ WARNING: NaN values found in data!")
        print(f"  Input NaN: {input_nan}")
        print(f"  Target NaN: {target_nan}")
    else:
        print("✅ Sample batch data is clean (no NaN values)")
        
    # Print data statistics
    print(f"  Input range: [{input_images.min():.3f}, {input_images.max():.3f}]")
    print(f"  Target range: [{target_images.min():.3f}, {target_images.max():.3f}]")
    print(f"  Unique labels: {torch.unique(class_labels).tolist()}")
    
except Exception as e:
    print(f"❌ Error validating batch: {e}")
    traceback.print_exc()

## 5. Custom Training Loop with Gradient Norm Handling

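Define a single training step, `safe_training_step`, that clips the gradient norm, checks for NaN/Inf gradients, and keeps the AMP scaler state consistent to avoid the unscale_() error.
The cell then tries the pipeline's built-in `train_accelerate` method first; `safe_training_step` is not called there and serves as a reference for the per-step logic.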

In [None]:
def safe_training_step(pipeline, batch, optimizer, lr_scheduler, accelerator, step, epoch, 
                       nan_count, max_nan_tolerance, gradient_clip_value):
    """
    Perform a single training step with comprehensive NaN/Inf handling.
    Returns (loss_value, nan_count, should_continue)
    """
    
    # Extract batch data
    input_images = batch["input"].to(pipeline.device)
    target_images = batch["target"].to(pipeline.device)
    class_labels = batch['label'].to(pipeline.device)
    
    # Check for NaN values in input data
    if torch.isnan(input_images).any() or torch.isnan(target_images).any():
        print(f"❌ Step {step}: NaN in input data, skipping batch")
        return None, nan_count + 1, nan_count + 1 <= max_nan_tolerance
    
    # Sample noise and timesteps
    noise = torch.randn(target_images.shape, device=pipeline.device)
    bs = target_images.shape[0]
    timesteps = torch.randint(
        0, pipeline.scheduler.config['num_train_timesteps'], (bs,), 
        device=pipeline.device, dtype=torch.int
    )
    
    # Add noise to clean images
    noisy_targets = pipeline.scheduler.add_noise(target_images, noise, timesteps)
    
    # Forward pass
    with accelerator.accumulate(pipeline.unet):
        noisy_samples = torch.concat([input_images, noisy_targets], dim=1)
        noise_pred = pipeline.unet(  # call the module rather than .forward() so hooks run
            sample=noisy_samples,
            timestep=timesteps,
            class_labels=class_labels
        ).sample
        
        loss = F.mse_loss(noise_pred, noise)
        
        # Check for NaN/Inf loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"❌ Step {step}: NaN/Inf loss ({loss.item()}), skipping batch")
            return None, nan_count + 1, nan_count + 1 <= max_nan_tolerance
        
        # Check for NaN in noise prediction
        if torch.isnan(noise_pred).any():
            print(f"❌ Step {step}: NaN in noise prediction, skipping batch")
            return None, nan_count + 1, nan_count + 1 <= max_nan_tolerance
        
        # Backward pass
        accelerator.backward(loss)
        
        # Gradient handling with proper AMP scaler management
        if accelerator.sync_gradients:
            # Check for NaN gradients BEFORE any unscaling
            has_nan_gradients = False
            for name, param in pipeline.unet.named_parameters():
                if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                    print(f"❌ Step {step}: NaN/Inf gradient in {name}")
                    has_nan_gradients = True
                    break
            
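            if has_nan_gradients:
                # Zero the gradients but still step the optimizer. Under AMP the
                # prepared optimizer's step() is what ends with scaler.update(), and
                # update() is what allows unscale_() to run again next iteration.
                # Caveat: if zero_grad() resets gradients to None, an fp16 scaler may
                # find nothing to inspect in step(). With mixed_precision="no" (this
                # notebook's setting) and gradients reset to None, step() changes nothing.
                optimizer.zero_grad()
                optimizer.step()
                lr_scheduler.step()
                print(f"⚠️  Step {step}: Zeroed gradients due to NaN, but stepped optimizer")
                return None, nan_count + 1, nan_count + 1 <= max_nan_tolerance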
            else:
                # Normal gradient clipping and optimization
                accelerator.clip_grad_norm_(pipeline.unet.parameters(), gradient_clip_value)
        
        # Normal optimizer step
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        
        return loss.detach().item(), nan_count, True

# Prepare extra parameters for training
extra_kwargs = {
    "num_cycles": num_cycles,
    "train_tags": train_tags,
}

print("Starting custom training with enhanced NaN handling...")
print(f"Max NaN tolerance: {max_nan_tolerance}")
print(f"Gradient clipping value: {gradient_clip_value}")
print(f"Mixed precision: {mixed_precision}")

# Start training with the pipeline's built-in method first
try:
    # Use the fixed pipeline training method
    pipeline.train_accelerate(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        train_size=train_size,
        val_size=val_size,
        **extra_kwargs
    )
    print("✅ Training completed successfully!")
    
except Exception as e:
    print(f"❌ Training failed with error: {e}")
    traceback.print_exc()
    
    # If the main training fails, we'll implement a custom training loop below
    print("\n" + "="*50)
    print("MAIN TRAINING FAILED - Implementing custom debug training loop...")
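
The reason `safe_training_step` looks for NaN/Inf gradients before clipping is visible in the arithmetic of norm clipping itself. The plain-Python sketch below imitates what global-norm clipping does to a flat list of gradient values. `clip_by_global_norm` is a hypothetical helper, not the Accelerate or PyTorch function, and the `1e-6` epsilon is an assumption modeled on PyTorch's `clip_grad_norm_`.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm), where total_norm is the norm
    measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)
    if coef > 1.0:
        # Norm is already below the threshold: leave the gradients alone.
        coef = 1.0
    # If total_norm is NaN, coef is NaN too and the comparison above is False,
    # so every gradient below is multiplied by NaN.
    return [g * coef for g in grads], total_norm


# Healthy gradients: norm 5.0 is scaled down to roughly max_norm.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
print(norm, clipped)

# One NaN entry poisons the norm, and clipping then spreads it everywhere.
poisoned, bad_norm = clip_by_global_norm([3.0, float("nan")], max_norm=0.5)
print(bad_norm, poisoned)
```

A single non-finite value therefore turns the reported gradient norm into NaN and, after clipping, contaminates every gradient, including those that were finite beforehand.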

## 6. Debugging NaN/Inf Gradients and AMP Issues

Add comprehensive debugging tools to catch and analyze AMP unscale_() errors and NaN gradient issues.
This section includes utilities to monitor optimizer and scaler states.

In [None]:
def debug_amp_scaler_state(accelerator):
    """Debug the AMP scaler state to understand unscale_() issues."""
    print("\n" + "="*50)
    print("AMP SCALER DEBUG INFO")
    print("="*50)
    
    if hasattr(accelerator, 'scaler') and accelerator.scaler is not None:
        scaler = accelerator.scaler
        print(f"Scaler enabled: {scaler.is_enabled()}")
        print(f"Scaler scale: {scaler.get_scale()}")
        # GradScaler has no public growth-tracker getter; read it from state_dict()
        print(f"Growth tracker: {scaler.state_dict().get('_growth_tracker')}")
        
        # List the prepared optimizers. Accelerator has no `optimizer` attribute;
        # `_optimizers` is a private list, so fall back to empty if it is absent.
        for i, optimizer in enumerate(getattr(accelerator, "_optimizers", [])):
            print(f"Optimizer {i}: {type(optimizer).__name__}")
            print(f"  State dict keys: {list(optimizer.state_dict().keys())}")
            
    else:
        print("No AMP scaler found (expected for mixed_precision='no')")
    
    print("="*50)

def debug_gradient_state(model, step_info=""):
    """Debug gradient state of model parameters."""
    print(f"\nGRADIENT DEBUG {step_info}")
    print("-" * 30)
    
    total_params = 0
    params_with_grad = 0
    nan_grads = 0
    inf_grads = 0
    zero_grads = 0
    
    for name, param in model.named_parameters():
        total_params += 1
        if param.grad is not None:
            params_with_grad += 1
            if torch.isnan(param.grad).any():
                nan_grads += 1
                print(f"❌ NaN gradient in: {name}")
            elif torch.isinf(param.grad).any():
                inf_grads += 1
                print(f"❌ Inf gradient in: {name}")
            elif torch.allclose(param.grad, torch.zeros_like(param.grad)):
                zero_grads += 1
    
    print(f"Total parameters: {total_params}")
    print(f"Parameters with gradients: {params_with_grad}")
    print(f"Parameters with NaN gradients: {nan_grads}")
    print(f"Parameters with Inf gradients: {inf_grads}")
    print(f"Parameters with zero gradients: {zero_grads}")
    
    return nan_grads > 0 or inf_grads > 0

def custom_debug_training_loop():
    """
    Custom training loop with extensive debugging for AMP and gradient issues.
    This will help us understand exactly when and why the unscale_() error occurs.
    """
    print("\n" + "="*60)
    print("STARTING CUSTOM DEBUG TRAINING LOOP")
    print("="*60)
    
    # Reset pipeline to clean state
    pipeline.train_id = f"debug_run_{time.strftime('%Y-%m-%d_%H-%M-%S')}"
    pipeline.set_num_class_embeds(len(train_dataloader.vocab))
    
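    # Create optimizer and scheduler manually for better control
    optimizer = torch.optim.AdamW(pipeline.unet.parameters(), lr=learning_rate)
    # Schedule length assumes train_size samples per epoch. The loop below stops after
    # 10 batches per epoch, so only the start of the warmup is actually exercised.
    total_steps = (train_size * num_epochs) // batch_size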
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=warming_steps,
        num_training_steps=total_steps,
        num_cycles=num_cycles
    )
    
    # Initialize accelerator with debug settings
    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=1,
        log_with="tensorboard",
        project_dir=os.path.join(pipeline.train_config.output_dir, "debug_logs")
    )
    
    # Prepare everything. The prepared loader gets its own name: assigning to
    # `train_dataloader` here would make that name local to this function, and the
    # earlier `train_dataloader.vocab` lookup would raise UnboundLocalError.
    pipeline.unet, optimizer, prepared_dataloader, lr_scheduler = accelerator.prepare(
        pipeline.unet, optimizer, train_dataloader, lr_scheduler
    )
    
    print("Accelerator state:")
    print(f"  Mixed precision: {accelerator.mixed_precision}")
    print(f"  Use distributed: {accelerator.use_distributed}")
    print(f"  Device: {accelerator.device}")
    
    # Debug initial state
    debug_amp_scaler_state(accelerator)
    debug_gradient_state(pipeline.unet, "INITIAL")
    
    global_step = 0
    nan_count = 0
    
    # Training loop with extensive debugging
    for epoch in range(num_epochs):
        print(f"\n{'='*40}")
        print(f"EPOCH {epoch} - Debug Training")
        print(f"{'='*40}")
        
        pipeline.unet.train()
        
        # Take a small subset for debugging
        debug_batches = min(10, len(prepared_dataloader))  # Only first 10 batches for debugging
        
        for step, batch in enumerate(prepared_dataloader):
            if step >= debug_batches:
                break
                
            print(f"\n--- Step {step} (Global: {global_step}) ---")
            
            try:
                # Manual training step with debugging
                input_images = batch["input"].to(pipeline.device)
                target_images = batch["target"].to(pipeline.device)
                class_labels = batch['label'].to(pipeline.device)
                
                # Pre-step debugging
                debug_amp_scaler_state(accelerator)
                
                # Check input data
                if torch.isnan(input_images).any() or torch.isnan(target_images).any():
                    print(f"❌ NaN in input data at step {step}")
                    nan_count += 1
                    continue
                
                # Sample noise and forward pass
                noise = torch.randn(target_images.shape, device=pipeline.device)
                bs = target_images.shape[0]
                timesteps = torch.randint(0, pipeline.scheduler.config['num_train_timesteps'], 
                                        (bs,), device=pipeline.device, dtype=torch.int)
                
                noisy_targets = pipeline.scheduler.add_noise(target_images, noise, timesteps)
                
                with accelerator.accumulate(pipeline.unet):
                    noisy_samples = torch.concat([input_images, noisy_targets], dim=1)
                    # Call the module rather than .forward() so hooks run
                    noise_pred = pipeline.unet(sample=noisy_samples, timestep=timesteps,
                                               class_labels=class_labels).sample
                    loss = F.mse_loss(noise_pred, noise)
                    
                    print(f"Loss: {loss.item():.6f}")
                    
                    if torch.isnan(loss) or torch.isinf(loss):
                        print(f"❌ NaN/Inf loss: {loss.item()}")
                        nan_count += 1
                        if nan_count > max_nan_tolerance:
                            print("Too many NaN occurrences, stopping")
                            return
                        continue
                    
                    # Backward pass
                    accelerator.backward(loss)
                    
                    # Post-backward debugging
                    print("After backward pass:")
                    has_bad_grads = debug_gradient_state(pipeline.unet, f"STEP_{step}_POST_BACKWARD")
                    
                    if accelerator.sync_gradients:
                        print("Syncing gradients...")
                        debug_amp_scaler_state(accelerator)
                        
                        if has_bad_grads:
                            print("❌ Bad gradients detected, zeroing and stepping anyway")
                            optimizer.zero_grad()
                        else:
                            print("✅ Gradients look good, clipping...")
                            try:
                                grad_norm = accelerator.clip_grad_norm_(pipeline.unet.parameters(), gradient_clip_value)
                                print(f"Gradient norm after clipping: {grad_norm}")
                            except Exception as e:
                                print(f"❌ Error during gradient clipping: {e}")
                                traceback.print_exc()
                                nan_count += 1
                                if nan_count > max_nan_tolerance:
                                    return
                                optimizer.zero_grad()
                    
                    # Optimizer step
                    print("Stepping optimizer...")
                    try:
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        print("✅ Optimizer step successful")
                    except Exception as e:
                        print(f"❌ Error during optimizer step: {e}")
                        traceback.print_exc()
                        nan_count += 1
                        if nan_count > max_nan_tolerance:
                            return
                        optimizer.zero_grad()
                
                global_step += 1
                
            except Exception as e:
                print(f"❌ Error in training step {step}: {e}")
                traceback.print_exc()
                nan_count += 1
                if nan_count > max_nan_tolerance:
                    print("Too many errors, stopping debug training")
                    return
        
        print(f"\nCompleted epoch {epoch} debug training")
        print(f"Total NaN/error count: {nan_count}")
    
    print("\n" + "="*60)
    print("DEBUG TRAINING LOOP COMPLETED")
    print("="*60)

# Run the debug training loop. This call is unconditional: it runs whether or
# not the main training above succeeded.
print("Running custom debug training loop...")
custom_debug_training_loop()
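
When reading the scale and growth tracker printed by `debug_amp_scaler_state`, it helps to know how a dynamic loss scale typically evolves. The following is a simplified pure-Python model, assuming the commonly documented `GradScaler` defaults (initial scale 65536, growth factor 2.0, backoff factor 0.5, growth interval 2000). `ToyLossScale` is a hypothetical name, and the real scaler does considerably more than this.

```python
class ToyLossScale:
    """Simplified model of dynamic loss-scale bookkeeping."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.growth_tracker = 0  # consecutive steps without Inf/NaN gradients

    def update(self, found_inf):
        """Call once per iteration, after the (possibly skipped) optimizer step."""
        if found_inf:
            # Overflow: shrink the scale and restart the clean-step streak.
            self.scale *= self.backoff_factor
            self.growth_tracker = 0
        else:
            self.growth_tracker += 1
            if self.growth_tracker == self.growth_interval:
                # Long clean streak: try a larger scale again.
                self.scale *= self.growth_factor
                self.growth_tracker = 0


# Short growth interval to keep the demonstration small.
toy = ToyLossScale(growth_interval=3)
history = []
for found_inf in (False, True, False, False, False):
    toy.update(found_inf)
    history.append((toy.scale, toy.growth_tracker))
print(history)
```

An occasional halving of the scale is normal under fp16. A scale that keeps halving step after step suggests a real numerical instability rather than a transient overflow.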

## Summary and Next Steps

This notebook provides comprehensive debugging for the NaN gradient and AMP unscale_() issues.

### Key Debugging Features:
1. **Conservative Training Settings** - Reduced learning rate and batch size, with mixed precision disabled
2. **Comprehensive NaN Checking** - Input data, loss, gradients, and model parameters
3. **AMP Scaler State Monitoring** - Debug the gradient scaler state to prevent unscale_() errors
4. **Gradient State Analysis** - Detailed inspection of parameter gradients
5. **Error Recovery** - Graceful handling of NaN/Inf values with early stopping

### Fixes Applied to the Main Model:
- Fixed the gradient checking logic to avoid the unscale_() error
- Always call optimizer.step() even when gradients are zeroed to maintain scaler consistency
- Added comprehensive NaN checking throughout the training pipeline
- Implemented conservative training settings for stability

### Recommendations:
1. **Start with conservative settings** (as set in this notebook)
2. **Monitor NaN count** - if it's consistently high, check your data preprocessing
3. **Gradually increase learning rate** once training is stable
4. **Re-enable mixed precision** only after confirming stability with float32
5. **Check your dataset** for corrupted or extreme values that might cause NaN

Run each cell in sequence to debug your training issues systematically.