# DJ Transition Generation Training - Kaggle Edition

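This notebook trains a U-Net model that generates smooth transitions between electronic music tracks. The model works on log-mel spectrograms: it takes the spectrograms of two source tracks (plus a noise channel) and predicts the spectrogram of the transition.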

## Instructions
1. Enable GPU accelerator in Kaggle (Settings → Accelerator → GPU)
2. Run all cells sequentially
3. Training takes roughly 2–3 hours on a Kaggle GPU (the exact time depends on the size of your dataset)
4. Download the trained model from the output section

## Install Dependencies

Install required packages that may not be available in Kaggle's default environment.

In [None]:
# Install required packages
# (`!` runs a shell command from the notebook; `-q` keeps pip's output quiet.)
# NOTE(review): scikit-image does not appear to be imported by the cells shown
# here; confirm it is still needed before keeping the install.
!pip install tensorboard -q
!pip install soundfile -q
!pip install librosa -q
!pip install scikit-image -q

# NOTE(review): this message is printed unconditionally; it does not verify that
# the installs above succeeded, so check the cell output for pip errors.
print("All dependencies installed successfully!")

## 🔧 Setup Environment and Import Libraries

Import all necessary libraries and configure GPU settings for optimal performance.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import os
import sys
import time
import gc
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf
from IPython.display import display, Audio, HTML
import warnings
warnings.filterwarnings('ignore')

# Optimize PyTorch performance: let cuDNN auto-tune convolution algorithms for
# the (fixed) input size when a GPU is present. Deterministic mode stays off,
# trading bit-exact reproducibility for speed.
torch.backends.cudnn.benchmark = torch.cuda.is_available()
torch.backends.cudnn.deterministic = False

# Pick the GPU when available and report what we are running on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("No GPU available, using CPU (training will be much slower)")

# Output folders for checkpoints, TensorBoard logs and other artifacts
for output_dir in ('checkpoints', 'logs', 'outputs'):
    os.makedirs(output_dir, exist_ok=True)

print("Environment setup complete!")

## ⚙️ Load Configuration

Define model hyperparameters and training configuration optimized for Kaggle environment.

In [None]:
# Configuration Parameters - Optimized for Kaggle

# --- Audio / spectrogram ---------------------------------------------------
SAMPLE_RATE = 22050         # Hz
N_FFT = 2048                # STFT window length (samples)
HOP_LENGTH = 512            # samples between STFT frames
N_MELS = 128
SPECTROGRAM_HEIGHT = 128    # mel bins per spectrogram
SPECTROGRAM_WIDTH = 512     # time frames per spectrogram (after resizing)
SEGMENT_DURATION = 12.0     # seconds kept: first 12s of each 15s excerpt
# NOTE(review): N_FFT, N_MELS and SEGMENT_DURATION are informational in the
# cells shown; the dataset cell does not pass them to the dataset class.

# --- Model -------------------------------------------------------------------
IN_CHANNELS = 3    # stacked [source_a, source_b, noise]
OUT_CHANNELS = 1   # predicted transition spectrogram
MODEL_DIM = 512    # U-Net bottleneck width

# --- Training (Kaggle GPU) ---------------------------------------------------
BATCH_SIZE = 28
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30    # kept small to fit Kaggle session limits
WARMUP_EPOCHS = 3

# --- Dataset -----------------------------------------------------------------
# Sample counts for the synthetic fallback dataset only; the real audio
# dataset's size follows from the number of audio files found.
TRAIN_SAMPLES = 15000
VAL_SAMPLES = 2000

print("\n".join([
    " Configuration loaded:",
    f"   Batch Size: {BATCH_SIZE}",
    f"   Epochs: {NUM_EPOCHS}",
    f"   Learning Rate: {LEARNING_RATE}",
    f"   Model Dimension: {MODEL_DIM}",
    f"   Training Samples: {TRAIN_SAMPLES}",
    f"   Validation Samples: {VAL_SAMPLES}",
    f"   Spectrogram Size: {SPECTROGRAM_HEIGHT}x{SPECTROGRAM_WIDTH}",
    f"   Segment Duration (effective): {SEGMENT_DURATION}s (cut last 3s from 15s)",
]))

## Define Model Architecture

Define `ProductionUNet`, an encoder–decoder (U-Net) model with skip connections for DJ transition generation, and run a quick forward-pass check.

In [None]:
class ProductionUNet(nn.Module):
    """
    Production U-Net model for generating DJ transitions.

    Input ``x`` is ``[B, in_channels, H, W]`` (in this notebook the channels are
    source A, source B and a noise channel); the output is
    ``[B, out_channels, H, W]`` with values in ``[-1, 1]`` (final ``Tanh``).

    The four encoder blocks each halve the spatial size, so ``forward`` zero-pads
    the input on the bottom/right up to a multiple of 16 and crops the output
    back to ``H x W``. Inputs whose sides are already multiples of 16 (such as
    the 128x512 spectrograms used here) are processed without padding.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        model_dim: channel width of the bottleneck.
    """

    # Combined downsampling factor of the four MaxPool2d(2) encoder blocks
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=1, model_dim=512):
        super().__init__()

        # Encoder path: two 3x3 conv+BN+ReLU layers then a 2x2 max-pool per block
        self.enc1 = self._make_encoder_block(in_channels, 64)
        self.enc2 = self._make_encoder_block(64, 128)
        self.enc3 = self._make_encoder_block(128, 256)
        self.enc4 = self._make_encoder_block(256, 512)

        # Bottleneck (keeps spatial size, maps 512 -> model_dim channels)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, model_dim, 3, padding=1),
            nn.BatchNorm2d(model_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(model_dim, model_dim, 3, padding=1),
            nn.BatchNorm2d(model_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2)
        )

        # Decoder path: 2x transposed-conv upsampling, concat skip, conv block
        self.upconv4 = nn.ConvTranspose2d(model_dim, 512, 2, stride=2)
        self.dec4 = self._make_decoder_block(512 + 256, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self._make_decoder_block(256 + 128, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self._make_decoder_block(128 + 64, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        # No skip input here: dec1 sees only the upsampled features
        self.dec1 = self._make_decoder_block(64, 64)

        # Final output layer, squashed to [-1, 1]
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 1),
            nn.Tanh()
        )

    def _make_encoder_block(self, in_channels, out_channels):
        """Return two conv+BN+ReLU layers followed by a 2x2 max-pool (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def _make_decoder_block(self, in_channels, out_channels):
        """Return two conv+BN+ReLU layers followed by channel dropout (size preserved)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def forward(self, x):
        """Run the U-Net on ``x`` of shape ``[B, in_channels, H, W]``.

        Returns a tensor of shape ``[B, out_channels, H, W]``. The shape
        comments below refer to the (possibly padded) size ``H x W``.
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # Without padding the skip concatenations fail on odd sizes.
            # Pad spec for the last two dims is (left, right, top, bottom).
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Encoder path; the stored skips are the *pooled* block outputs
        e1 = self.enc1(x)      # [B, 64, H/2, W/2]
        e2 = self.enc2(e1)     # [B, 128, H/4, W/4]
        e3 = self.enc3(e2)     # [B, 256, H/8, W/8]
        e4 = self.enc4(e3)     # [B, 512, H/16, W/16]

        # Bottleneck
        bottleneck = self.bottleneck(e4)  # [B, model_dim, H/16, W/16]

        # Decoder path with skip connections
        d4 = self.upconv4(bottleneck)     # [B, 512, H/8, W/8]
        d4 = torch.cat([d4, e3], dim=1)   # [B, 512+256, H/8, W/8]
        d4 = self.dec4(d4)                # [B, 512, H/8, W/8]

        d3 = self.upconv3(d4)             # [B, 256, H/4, W/4]
        d3 = torch.cat([d3, e2], dim=1)   # [B, 256+128, H/4, W/4]
        d3 = self.dec3(d3)                # [B, 256, H/4, W/4]

        d2 = self.upconv2(d3)             # [B, 128, H/2, W/2]
        d2 = torch.cat([d2, e1], dim=1)   # [B, 128+64, H/2, W/2]
        d2 = self.dec2(d2)                # [B, 128, H/2, W/2]

        d1 = self.upconv1(d2)             # [B, 64, H, W]
        d1 = self.dec1(d1)                # [B, 64, H, W]

        output = self.final(d1)           # [B, out_channels, H, W]
        if pad_h or pad_w:
            # Drop the padded rows/columns so the output matches the input size
            output = output[..., :height, :width]
        return output

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Test model creation: build a throwaway model, report its size and run one
# forward pass at the training spectrogram size to catch shape errors early.
test_model = ProductionUNet(
    in_channels=IN_CHANNELS,
    out_channels=OUT_CHANNELS,
    model_dim=MODEL_DIM
)
test_param_count = test_model.count_parameters()

print("Model created successfully!")
print(f"Total parameters: {test_param_count:,}")
# float32 weights take 4 bytes each
print(f"Model size: ~{test_param_count * 4 / 1024**2:.1f} MB")

# Test forward pass in eval mode (dropout off, BatchNorm running stats untouched)
test_input = torch.randn(1, IN_CHANNELS, SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH)
test_model.eval()
with torch.no_grad():
    test_output = test_model(test_input)
if test_output.shape[-2:] != test_input.shape[-2:]:
    raise RuntimeError(
        f"Output spatial size {tuple(test_output.shape[-2:])} does not match "
        f"input {tuple(test_input.shape[-2:])}"
    )
print("Model forward pass test successful!")
print(f"   Input shape: {test_input.shape}")
print(f"   Output shape: {test_output.shape}")

# Free the throwaway model before the real one is built
del test_model, test_input, test_output, test_param_count
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## 📊 Load Your Own Dataset

Upload and process your own audio files for training the DJ transition model.

In [None]:
import librosa
import glob
from pathlib import Path
import random

class AudioTransitionDataset(Dataset):
    """
    Dataset for training DJ transition models using real audio files.

    Every item pairs two files (A and B) from this split, turns a random excerpt
    of each into a normalized log-mel spectrogram, and builds a synthetic
    "transition" target from the pair (see ``_create_transition_target``).
    The dataset length is 3x the number of files: each file is used as source A
    three times, each time with a different partner B when enough files exist.

    Args:
        audio_dir: directory searched recursively for .wav/.mp3/.flac/.m4a files.
        spectrogram_size: ``(n_mels, n_frames)`` of every returned spectrogram.
        sample_rate: rate (Hz) audio is resampled to when loaded.
        hop_length: STFT hop length in samples.
        split: ``'train'`` takes the first ``1 - val_split`` fraction of the
            shuffled file list; any other value takes the remainder.
        val_split: fraction of files reserved for the non-train split.
        n_fft: STFT window length in samples.
        raw_segment_duration: length (seconds) of the random excerpt read from a file.
        effective_duration: leading part (seconds) of that excerpt that is kept.

    Raises:
        ValueError: if no audio files are found, or the selected split is empty.
    """

    # File patterns searched (recursively) under audio_dir
    AUDIO_EXTENSIONS = ('*.wav', '*.mp3', '*.flac', '*.m4a')
    # Seed of the private RNG that shuffles the files before splitting
    SPLIT_SEED = 42

    def __init__(self, audio_dir, spectrogram_size=(128, 512),
                 sample_rate=22050, hop_length=512, split='train', val_split=0.2,
                 n_fft=2048, raw_segment_duration=15.0, effective_duration=12.0):
        self.audio_dir = Path(audio_dir)
        self.height, self.width = spectrogram_size
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.split = split

        # Segment policy: read a raw excerpt, keep only its leading part
        self.raw_segment_duration = raw_segment_duration
        self.effective_duration = effective_duration

        audio_files = []
        for pattern in self.AUDIO_EXTENSIONS:
            audio_files.extend(
                glob.glob(str(self.audio_dir / f"**/{pattern}"), recursive=True)
            )

        if not audio_files:
            raise ValueError(f"No audio files found in {audio_dir}")

        # A private Random(seed) yields the same permutation as
        # `random.seed(seed); random.shuffle(...)` but does not reset the global
        # `random` state used elsewhere (e.g. the excerpt offsets below).
        random.Random(self.SPLIT_SEED).shuffle(audio_files)
        split_idx = int(len(audio_files) * (1 - val_split))
        if split == 'train':
            self.audio_files = audio_files[:split_idx]
        else:
            self.audio_files = audio_files[split_idx:]

        dropped = raw_segment_duration - effective_duration
        print(f"Found {len(audio_files)} total audio files")
        print(f"{split} split: {len(self.audio_files)} files")
        print(f"Segment policy: take {raw_segment_duration:g}s from file, "
              f"keep first {effective_duration:g}s (drop last {dropped:g}s)")

        # An empty split would give __len__ == 0 and a modulo-by-zero in
        # __getitem__; fail here with an actionable message instead.
        if not self.audio_files:
            raise ValueError(
                f"The '{split}' split is empty ({len(audio_files)} files found, "
                f"val_split={val_split}); add more audio files or adjust val_split"
            )

    def __len__(self):
        # Three passes over the file list, each with a different partner offset
        return len(self.audio_files) * 3

    def _pair_indices(self, idx):
        """Return ``(index of file A, index of file B)`` for dataset item ``idx``.

        A cycles through the files. B is offset from A by ``1 .. n - 1``
        positions, so it differs from A whenever the split has at least two
        files. With four or more files the offset is ``1 + idx // n``.
        """
        num_files = len(self.audio_files)
        a_idx = idx % num_files
        if num_files == 1:
            return a_idx, a_idx
        offset = 1 + (idx // num_files) % (num_files - 1)
        return a_idx, (a_idx + offset) % num_files

    def __getitem__(self, idx):
        """
        Load and process audio files to create one training example.

        Returns:
            inputs: [3, H, W] tensor containing [source_a, source_b, noise]
            target: [1, H, W] tensor containing target transition
        """
        a_idx, b_idx = self._pair_indices(idx)
        source_a_path = self.audio_files[a_idx]
        source_b_path = self.audio_files[b_idx]

        try:
            source_a_spec = self._load_audio_segment(source_a_path)
            source_b_spec = self._load_audio_segment(source_b_path)
        except Exception as e:
            # Best effort: one unreadable/corrupt file should not abort a long
            # training run, so substitute a synthetic example and log it.
            print(f"Error loading {source_a_path} or {source_b_path}: {e}")
            return self._generate_fallback_data()

        # Small random third channel
        noise = torch.randn(self.height, self.width) * 0.05
        target = self._create_transition_target(source_a_spec, source_b_spec)

        inputs = torch.stack([source_a_spec, source_b_spec, noise], dim=0)
        return inputs, target.unsqueeze(0)  # add channel dimension

    def _load_audio_segment(self, audio_path):
        """Load ``audio_path`` and return a ``[height, width]`` normalized log-mel tensor.

        A random ``raw_segment_duration`` excerpt is taken (shorter files are
        zero-padded at the end), only its first ``effective_duration`` seconds
        are kept, and the log-mel spectrogram is resized along time to
        ``width`` frames and scaled into ``[-1, 1]``.
        """
        audio, _ = librosa.load(audio_path, sr=self.sample_rate)

        raw_segment_samples = int(self.raw_segment_duration * self.sample_rate)
        effective_segment_samples = int(self.effective_duration * self.sample_rate)

        if len(audio) > raw_segment_samples:
            start_idx = random.randint(0, len(audio) - raw_segment_samples)
            audio = audio[start_idx:start_idx + raw_segment_samples]
        elif len(audio) < raw_segment_samples:
            audio = np.pad(audio, (0, raw_segment_samples - len(audio)), mode='constant')

        # Keep only the leading effective_duration seconds of the excerpt
        audio = audio[:effective_segment_samples]

        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.height,
            hop_length=self.hop_length,
            n_fft=self.n_fft,
            fmin=20,
            fmax=8000
        )
        log_mel = np.log(mel_spec + 1e-6)

        # Trim/pad to exactly the frames covering effective_duration
        effective_frames = int(self.effective_duration * self.sample_rate / self.hop_length)
        if log_mel.shape[1] > effective_frames:
            log_mel = log_mel[:, :effective_frames]
        elif log_mel.shape[1] < effective_frames:
            pad_w = effective_frames - log_mel.shape[1]
            log_mel = np.pad(log_mel, ((0, 0), (0, pad_w)), mode='constant')

        # Linearly resample the time axis to self.width frames
        if log_mel.shape[1] != self.width:
            from scipy.ndimage import zoom
            log_mel = zoom(log_mel, (1, self.width / log_mel.shape[1]), order=1)

        # Standardize, clip to +/-3 std, and scale into [-1, 1]
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)
        log_mel = np.clip(log_mel, -3, 3) / 3

        return torch.FloatTensor(log_mel)

    def _create_transition_target(self, source_a, source_b):
        """Build the synthetic target spectrogram for a ``(source_a, source_b)`` pair.

        Column ``i`` is the crossfade ``(1 - alpha_i) * a + alpha_i * b`` with
        ``alpha_i = clip(i / width + 0.1 * sin(16 * pi * i / width), 0, 1)``
        (a linear ramp with an 8-cycle wobble; the same weight for every row).
        The first ``height // 6`` rows (the lowest mel bins) of the middle
        third of the columns are scaled by 1.2. A rising diagonal line of +0.3
        is added across the middle half of the columns. The result is then
        standardized, clipped to +/-3 std and scaled into ``[-1, 1]``.

        Args:
            source_a, source_b: ``[height, width]`` float tensors.

        Returns:
            ``[height, width]`` tensor.
        """
        # Per-column mixing weight, broadcast over all rows
        position = torch.arange(self.width, dtype=torch.float64) / self.width
        alpha = (position + 0.1 * torch.sin(2 * np.pi * position * 8)).clamp(0, 1)
        alpha = alpha.to(source_a.dtype)
        transition = (1 - alpha) * source_a + alpha * source_b

        # Boost the low-frequency rows in the middle of the transition
        bass_rows = slice(0, self.height // 6)
        transition[bass_rows, self.width // 3:2 * self.width // 3] *= 1.2

        # Rising high-frequency sweep from row height//2 towards 3*height//4
        sweep_start, sweep_stop = self.width // 4, 3 * self.width // 4
        for col in range(sweep_start, sweep_stop):
            row = int(self.height // 2
                      + (self.height // 4) * (col - sweep_start) / (self.width // 2))
            if row < self.height:
                transition[row, col] += 0.3

        # Standardize, clip and scale (torch ops keep the result a tensor)
        transition = (transition - transition.mean()) / (transition.std() + 1e-6)
        return transition.clamp(-3, 3) / 3

    def _generate_fallback_data(self):
        """Return a synthetic ``(inputs [3, H, W], target [1, H, W])`` example."""
        print("Using synthetic fallback data")
        source_a = torch.randn(self.height, self.width) * 0.5
        source_b = torch.randn(self.height, self.width) * 0.5
        noise = torch.randn(self.height, self.width) * 0.05

        # Linear crossfade: column i mixes toward source B with weight i / width
        alpha = torch.arange(self.width, dtype=torch.float32) / self.width
        target = (1 - alpha) * source_a + alpha * source_b

        inputs = torch.stack([source_a, source_b, noise], dim=0)
        return inputs, target.unsqueeze(0)

# Dataset configuration


class FallbackDataset(Dataset):
    """Synthetic stand-in used when real audio cannot be loaded.

    Defined unconditionally (not only in the "path missing" branch) because
    the real-dataset branch below also falls back to it when building the
    audio dataset fails. Each item holds random "spectrograms" A and B, a
    small noise channel, and a linear A->B crossfade target.
    """

    def __init__(self, num_samples=1000):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        source_a = torch.randn(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH) * 0.5
        source_b = torch.randn(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH) * 0.5
        noise = torch.randn(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH) * 0.05

        # Column i mixes toward source B with weight i / width (broadcast over rows)
        alpha = torch.arange(SPECTROGRAM_WIDTH, dtype=torch.float32) / SPECTROGRAM_WIDTH
        target = (1 - alpha) * source_a + alpha * source_b

        inputs = torch.stack([source_a, source_b, noise], dim=0)
        return inputs, target.unsqueeze(0)


print("Setting up audio dataset...")
print("Instructions:")
print("   1. Upload your audio files to Kaggle (Add Data → Upload)")
print("   2. Update AUDIO_DATA_PATH below to point to your uploaded folder")
print("   3. Supported formats: WAV, MP3, FLAC, M4A")
print("   4. Recommended: at least 100+ audio files for good training")

# Configure your audio data path here
AUDIO_DATA_PATH = "/kaggle/input/your-audio-dataset"  # Update this path!

if not os.path.exists(AUDIO_DATA_PATH):
    print(f"Audio data path not found: {AUDIO_DATA_PATH}")
    print("Please update AUDIO_DATA_PATH to point to your uploaded audio files")
    print("Creating minimal synthetic dataset for demonstration...")

    train_dataset = FallbackDataset(TRAIN_SAMPLES)
    val_dataset = FallbackDataset(VAL_SAMPLES)

    print("Using fallback synthetic dataset")
    print(f"   Training samples: {len(train_dataset):,}")
    print(f"   Validation samples: {len(val_dataset):,}")

else:
    print(f"Found audio data at: {AUDIO_DATA_PATH}")

    try:
        # Both splits use the same file shuffle, so they are disjoint
        train_dataset = AudioTransitionDataset(
            audio_dir=AUDIO_DATA_PATH,
            spectrogram_size=(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH),
            sample_rate=SAMPLE_RATE,
            hop_length=HOP_LENGTH,
            split='train',
            val_split=0.2
        )

        val_dataset = AudioTransitionDataset(
            audio_dir=AUDIO_DATA_PATH,
            spectrogram_size=(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH),
            sample_rate=SAMPLE_RATE,
            hop_length=HOP_LENGTH,
            split='val',
            val_split=0.2
        )

        print("Real audio dataset created successfully!")

    except Exception as e:
        print(f"Error creating audio dataset: {e}")
        print("Falling back to synthetic dataset...")

        train_dataset = FallbackDataset(TRAIN_SAMPLES)
        val_dataset = FallbackDataset(VAL_SAMPLES)

# Smoke-test one item before building the loaders
test_inputs, test_targets = train_dataset[0]
print("Dataset ready!")
print(f"   Training samples: {len(train_dataset):,}")
print(f"   Validation samples: {len(val_dataset):,}")
print(f"   Input shape: {test_inputs.shape}")
print(f"   Target shape: {test_targets.shape}")

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU runs
use_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory
)

print("Data loaders created!")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

del test_inputs, test_targets

### How to Upload Your Audio Data to Kaggle:

1. **Prepare Your Audio Files:**
   - Organize your audio files in folders (e.g., by genre, artist, etc.)
   - Supported formats: WAV, MP3, FLAC, M4A
   - Recommended: 100+ audio files for good training results
   - Each file should be at least 15 seconds long (shorter files are zero-padded; only the first 12 seconds of each 15-second excerpt are used)

2. **Upload to Kaggle:**
   - Click "Add Data" → "Upload" in the right panel
   - Upload your audio folder as a ZIP file
   - Wait for upload to complete
   - Note the dataset path (usually `/kaggle/input/your-dataset-name/`)

3. **Update the Path:**
   - Set the `AUDIO_DATA_PATH` variable in the dataset cell above to your dataset's path
   - Re-run the dataset cell (and the cells after it)

## Initialize Training Components

Create the model, the AdamW optimizer, the warmup + cosine learning-rate scheduler, the MSE loss and the TensorBoard writer. (The data loaders were already created in the dataset cell.)



In [None]:
# Build the U-Net and move it to the selected device
model = ProductionUNet(
    in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, model_dim=MODEL_DIM
).to(device)

print("\n".join([
    "Model initialized!",
    f"Parameters: {model.count_parameters():,}",
    f"Device: {device}",
]))

# AdamW with light weight decay; the base LR is rescaled each epoch by the
# warmup/cosine scheduler
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    weight_decay=1e-5,
)

# Learning rate scheduler with warmup
def get_lr_schedule_with_warmup(optimizer, warmup_epochs, total_epochs):
    """Return a ``LambdaLR`` doing linear warmup followed by cosine annealing.

    The multiplier applied to the base LR at scheduler epoch ``e`` is
    ``(e + 1) / warmup_epochs`` while ``e < warmup_epochs``. After that it
    follows half a cosine from 1 down to 0 over the remaining
    ``total_epochs - warmup_epochs`` epochs, and stays at 0 beyond
    ``total_epochs``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        warmup_epochs: number of linear warmup epochs.
        total_epochs: epoch at which the cosine phase reaches zero.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    # Guard against a zero-length cosine phase (total_epochs <= warmup_epochs)
    decay_epochs = max(1, total_epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Warmup phase
            return (epoch + 1) / warmup_epochs
        # Cosine annealing phase; clamp so stepping past total_epochs keeps the
        # LR at 0 instead of climbing back up the cosine
        progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
        return 0.5 * (1 + np.cos(np.pi * progress))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Warmup + cosine LR schedule, stepped once per epoch by the training loop
scheduler = get_lr_schedule_with_warmup(optimizer, WARMUP_EPOCHS, NUM_EPOCHS)

# Element-wise mean squared error between predicted and target spectrograms
criterion = nn.MSELoss()

# TensorBoard event files are written under logs/kaggle_training
writer = SummaryWriter('logs/kaggle_training')

print("\n".join([
    "Training components initialized!",
    "   Optimizer: AdamW",
    f"   Learning Rate: {LEARNING_RATE}",
    "   Scheduler: Warmup + Cosine Annealing",
    "   Loss Function: MSE",
]))

## Load Checkpoint (Optional)

Load a saved checkpoint to resume training from a previous run. Set `CHECKPOINT_PATH` below to your checkpoint file, or to `None` to train from scratch; if the file is not found, training also starts from scratch.

In [None]:
import copy

# Checkpoint loading configuration
# Set to a checkpoint path to resume training, or None to start fresh.
CHECKPOINT_PATH = '/kaggle/input/checkpoint5k/other/default/1/best_model_kaggle.pt'
# Example: CHECKPOINT_PATH = "/kaggle/input/your-checkpoint/best_model_kaggle.pt"
# Or: CHECKPOINT_PATH = "checkpoints/model_epoch_10.pt"

# Training state; overwritten below when a checkpoint is restored
start_epoch = 0
best_val_loss = float('inf')
train_losses = []
val_losses = []
learning_rates = []

if CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH):
    print(f"Loading checkpoint from: {CHECKPOINT_PATH}")

    # Snapshot the freshly initialized state. load_state_dict can copy some
    # tensors before raising on a mismatch, and a failure part-way through the
    # block below could leave e.g. restored optimizer state on a "fresh" run;
    # the except branch rolls everything back to this snapshot.
    fresh_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    fresh_optimizer_state = copy.deepcopy(optimizer.state_dict())
    fresh_scheduler_state = copy.deepcopy(scheduler.state_dict())

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only point CHECKPOINT_PATH at files you created or trust.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)

        # Report architecture mismatches before touching the model weights
        if 'model_config' in checkpoint:
            config = checkpoint['model_config']
            if (config.get('in_channels') == IN_CHANNELS and
                config.get('out_channels') == OUT_CHANNELS and
                config.get('model_dim') == MODEL_DIM):
                print(f"Model configuration verified")
            else:
                print(f" Model configuration mismatch:")
                print(f"   Checkpoint: {config}")
                print(f"   Current: in_channels={IN_CHANNELS}, out_channels={OUT_CHANNELS}, model_dim={MODEL_DIM}")

        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model state loaded successfully")

        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f"Optimizer state loaded successfully")

        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print(f"Scheduler state loaded successfully")

        # 'epoch' is the number of completed epochs, i.e. the next 0-based index
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch']
            print(f"Resuming from epoch {start_epoch}")

        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
            print(f"Best validation loss: {best_val_loss:.6f}")

        # Training history, if the checkpoint carries it
        if 'train_losses' in checkpoint:
            train_losses = checkpoint.get('train_losses', [])
            val_losses = checkpoint.get('val_losses', [])
            learning_rates = checkpoint.get('learning_rates', [])
            print(f"Training history loaded ({len(train_losses)} epochs)")

        print(f" Checkpoint loaded successfully!")

    except Exception as e:
        print(f" Error loading checkpoint: {e}")
        print(f"   Continuing with fresh training...")
        # Undo any partial restore so training really starts from scratch
        model.load_state_dict(fresh_model_state)
        optimizer.load_state_dict(fresh_optimizer_state)
        scheduler.load_state_dict(fresh_scheduler_state)
        start_epoch = 0
        best_val_loss = float('inf')
        train_losses = []
        val_losses = []
        learning_rates = []

    finally:
        del fresh_model_state, fresh_optimizer_state, fresh_scheduler_state

elif CHECKPOINT_PATH:
    print(f"  Checkpoint path specified but file not found: {CHECKPOINT_PATH}")
    print(f"   Starting fresh training...")

else:
    print(f" No checkpoint specified - starting fresh training")

print(f"  Training configuration:")
print(f"  Starting epoch: {start_epoch + 1}")
print(f"  Target epochs: {NUM_EPOCHS}")
print(f"  Best validation loss: {best_val_loss:.6f}" if best_val_loss != float('inf') else "   Best validation loss: Not set")
print(f"  Training history: {len(train_losses)} previous epochs")

## Training Loop Implementation

Run the main training loop. Each epoch trains, validates, updates the learning rate, logs to TensorBoard and saves checkpoints to `checkpoints/`.

In [None]:
def _make_checkpoint(epoch, **extra):
    """Return a checkpoint dict for ``epoch`` (the count of finished epochs).

    Holds model/optimizer/scheduler state, the best validation loss so far,
    the full loss/LR history (so resuming from any checkpoint restores it)
    and the model configuration. Keys in ``extra`` are merged in last.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        'train_losses': list(train_losses),
        'val_losses': list(val_losses),
        'learning_rates': list(learning_rates),
        'model_config': {
            'in_channels': IN_CHANNELS,
            'out_channels': OUT_CHANNELS,
            'model_dim': MODEL_DIM,
        },
    }
    checkpoint.update(extra)
    return checkpoint


def _train_one_epoch():
    """Run one pass over ``train_loader`` with gradient accumulation.

    Returns the mean (unscaled) batch loss of the epoch.
    """
    model.train()
    # Start from clean gradients; nothing may carry over from a previous epoch
    optimizer.zero_grad(set_to_none=True)
    num_batches = len(train_loader)
    total_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Scale so the accumulated gradient is the mean over the micro-batches
        (loss / GRADIENT_ACCUMULATION_STEPS).backward()

        # Update every GRADIENT_ACCUMULATION_STEPS batches, and also on the
        # final (possibly partial) group so no gradient is left over.
        is_last_batch = batch_idx + 1 == num_batches
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or is_last_batch:
            # Clip the fully accumulated gradient right before the update
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        loss_value = loss.item()
        total_loss += loss_value

        if batch_idx % 200 == 0:
            progress = batch_idx / num_batches * 100
            print(f"   Batch {batch_idx}/{num_batches} ({progress:.1f}%) - Loss: {loss_value:.6f}")

    return total_loss / num_batches


def _validate():
    """Return the mean batch loss over ``val_loader`` (eval mode, no gradients)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            total_loss += criterion(model(inputs), targets).item()
    return total_loss / len(val_loader)


print(" Starting training...")
print(f" Epochs: {NUM_EPOCHS} | Batch size: {BATCH_SIZE}")
print("=" * 60)

# Fail fast rather than dividing by zero after a full epoch
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

start_time = time.time()

for epoch in range(start_epoch, NUM_EPOCHS):
    current_epoch = epoch + 1

    print(f"\n Epoch {current_epoch}/{NUM_EPOCHS}")
    print("-" * 50)

    avg_train_loss = _train_one_epoch()
    train_losses.append(avg_train_loss)

    avg_val_loss = _validate()
    val_losses.append(avg_val_loss)

    # Epoch-level LR schedule
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

    writer.add_scalar('Loss/Train', avg_train_loss, current_epoch)
    writer.add_scalar('Loss/Validation', avg_val_loss, current_epoch)
    writer.add_scalar('Learning_Rate', current_lr, current_epoch)

    print(f" Train Loss: {avg_train_loss:.6f}")
    print(f" Val Loss: {avg_val_loss:.6f}")
    print(f" Learning Rate: {current_lr:.2e}")

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(
            _make_checkpoint(current_epoch, train_loss=avg_train_loss, val_loss=avg_val_loss),
            'checkpoints/best_model_kaggle.pt',
        )
        print(f" New best model saved! (Val Loss: {best_val_loss:.6f})")

    # Regular checkpoint every 5 epochs, whether or not this epoch was the best
    if current_epoch % 5 == 0:
        torch.save(
            _make_checkpoint(current_epoch, train_loss=avg_train_loss, val_loss=avg_val_loss),
            f'checkpoints/model_epoch_{current_epoch}.pt',
        )
        print(f"💾 Checkpoint saved: model_epoch_{current_epoch}.pt")

        # Memory cleanup: drop Python garbage first, then release cached GPU blocks
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Training completed (runs once, after all epochs)
total_time = time.time() - start_time
print(f"\n Training completed!")
print(f" Total time: {total_time:.1f}s ({total_time/60:.1f}m)")
print(f" Best validation loss: {best_val_loss:.6f}")

final_checkpoint = _make_checkpoint(NUM_EPOCHS, total_training_time=total_time)
torch.save(final_checkpoint, 'checkpoints/final_model_kaggle.pt')
print(" Final checkpoint saved: final_model_kaggle.pt")

# Close tensorboard writer
writer.close()