# DJ Transition Generation Training - Kaggle Edition

This notebook trains a deep learning model to generate smooth transitions between music tracks using a U-Net architecture. 

## Install Dependencies

Install required packages that may not be available in Kaggle's default environment.

In [None]:
# Install required packages
!pip install tensorboard -q
!pip install soundfile -q
!pip install librosa -q
!pip install scikit-image -q

print("All dependencies installed successfully!")

## Setup Environment and Import Libraries

Import all necessary libraries and configure GPU settings for optimal performance.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import os
import sys
import time
import gc
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf
from IPython.display import display, Audio, HTML
import warnings
warnings.filterwarnings('ignore')

# Optimize PyTorch performance
torch.backends.cudnn.benchmark = True if torch.cuda.is_available() else False
torch.backends.cudnn.deterministic = False

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔧 Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("No GPU available, using CPU (training will be much slower)")

# Create necessary directories
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('logs', exist_ok=True)
os.makedirs('outputs', exist_ok=True)

print("Environment setup complete!")

## Load Configuration

Define model hyperparameters and training configuration optimized for Kaggle environment.

In [None]:
# Configuration Parameters - Optimized for Kaggle
SAMPLE_RATE = 22050
N_FFT = 2048
HOP_LENGTH = 512
N_MELS = 128
SPECTROGRAM_HEIGHT = 128
SPECTROGRAM_WIDTH = 512
SEGMENT_DURATION = 15.0

# Model Configuration
IN_CHANNELS = 3  # [source_a, source_b, noise]
OUT_CHANNELS = 1  # transition
MODEL_DIM = 512

# Training Configuration - Optimized for Kaggle GPU
BATCH_SIZE = 8  # Increased batch size for better GPU utilization
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30  # Reduced epochs for Kaggle time limits
WARMUP_EPOCHS = 3

# Dataset Configuration
TRAIN_SAMPLES = 15000  # Will be determined by your uploaded files
VAL_SAMPLES = 2000     # Will be determined by your uploaded files

print(" Configuration loaded:")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Model Dimension: {MODEL_DIM}")
print(f"   Training Samples: {TRAIN_SAMPLES}")
print(f"   Validation Samples: {VAL_SAMPLES}")
print(f"   Spectrogram Size: {SPECTROGRAM_HEIGHT}x{SPECTROGRAM_WIDTH}")
print(f"   Segment Duration: {SEGMENT_DURATION}s")

## Define Model Architecture

Implement the ProductionUNet model with encoder-decoder architecture for DJ transition generation.

In [None]:
class ProductionUNet(nn.Module):
    """
    Production U-Net model for generating DJ transitions
    """
    
    def __init__(self, in_channels=3, out_channels=1, model_dim=512):
        super().__init__()
        
        # Encoder path
        self.enc1 = self._make_encoder_block(in_channels, 64)
        self.enc2 = self._make_encoder_block(64, 128) 
        self.enc3 = self._make_encoder_block(128, 256)
        self.enc4 = self._make_encoder_block(256, 512)
        
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, model_dim, 3, padding=1),
            nn.BatchNorm2d(model_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(model_dim, model_dim, 3, padding=1),
            nn.BatchNorm2d(model_dim),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2)
        )
        
        # Decoder path
        self.upconv4 = nn.ConvTranspose2d(model_dim, 512, 2, stride=2)
        self.dec4 = self._make_decoder_block(512 + 256, 512)
        
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self._make_decoder_block(256 + 128, 256)
        
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self._make_decoder_block(128 + 64, 128)
        
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self._make_decoder_block(64, 64)
        
        # Final output layer
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 1),
            nn.Tanh()
        )
        
    def _make_encoder_block(self, in_channels, out_channels):
        """Create encoder block with convolution, normalization, and pooling"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
    
    def _make_decoder_block(self, in_channels, out_channels):
        """Create decoder block with convolution, normalization, and dropout"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )
    
    def forward(self, x):
        """Forward pass through the U-Net"""
        # Encoder path with skip connections
        e1 = self.enc1(x)      # [B, 64, H/2, W/2]
        e2 = self.enc2(e1)     # [B, 128, H/4, W/4]
        e3 = self.enc3(e2)     # [B, 256, H/8, W/8]
        e4 = self.enc4(e3)     # [B, 512, H/16, W/16]
        
        # Bottleneck
        bottleneck = self.bottleneck(e4)  # [B, model_dim, H/16, W/16]
        
        # Decoder path with skip connections
        d4 = self.upconv4(bottleneck)     # [B, 512, H/8, W/8]
        d4 = torch.cat([d4, e3], dim=1)   # [B, 512+256, H/8, W/8]
        d4 = self.dec4(d4)                # [B, 512, H/8, W/8]
        
        d3 = self.upconv3(d4)             # [B, 256, H/4, W/4]
        d3 = torch.cat([d3, e2], dim=1)   # [B, 256+128, H/4, W/4]
        d3 = self.dec3(d3)                # [B, 256, H/4, W/4]
        
        d2 = self.upconv2(d3)             # [B, 128, H/2, W/2]
        d2 = torch.cat([d2, e1], dim=1)   # [B, 128+64, H/2, W/2]
        d2 = self.dec2(d2)                # [B, 128, H/2, W/2]
        
        d1 = self.upconv1(d2)             # [B, 64, H, W]
        d1 = self.dec1(d1)                # [B, 64, H, W]
        
        # Final output
        output = self.final(d1)           # [B, out_channels, H, W]
        
        return output
    
    def count_parameters(self):
        """Count the number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Test model creation
test_model = ProductionUNet(
    in_channels=IN_CHANNELS,
    out_channels=OUT_CHANNELS,
    model_dim=MODEL_DIM
)

print(f"Model created successfully!")
print(f"Total parameters: {test_model.count_parameters():,}")
print(f"Model size: ~{test_model.count_parameters() * 4 / 1024**2:.1f} MB")

# Test forward pass
test_input = torch.randn(1, IN_CHANNELS, SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH)
with torch.no_grad():
    test_output = test_model(test_input)
print(f"Model forward pass test successful!")
print(f"   Input shape: {test_input.shape}")
print(f"   Output shape: {test_output.shape}")

del test_model, test_input, test_output
torch.cuda.empty_cache() if torch.cuda.is_available() else None

## Load Your Own Dataset

Upload and process your own audio files for training the DJ transition model.

In [None]:
import librosa
import glob
from pathlib import Path
import random

class AudioTransitionDataset(Dataset):
    """
    Dataset for training DJ transition models using real audio files
    """
    
    def __init__(self, audio_dir, spectrogram_size=(128, 512), 
                 sample_rate=22050, hop_length=512, split='train', val_split=0.2):
        self.audio_dir = Path(audio_dir)
        self.height, self.width = spectrogram_size
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.split = split
        
        # Find all audio files
        audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.m4a']
        audio_files = []
        for ext in audio_extensions:
            audio_files.extend(glob.glob(str(self.audio_dir / f"**/{ext}"), recursive=True))
        
        if not audio_files:
            raise ValueError(f"No audio files found in {audio_dir}")
        
        # Split into train/validation
        random.seed(42)
        random.shuffle(audio_files)
        split_idx = int(len(audio_files) * (1 - val_split))
        
        if split == 'train':
            self.audio_files = audio_files[:split_idx]
        else:
            self.audio_files = audio_files[split_idx:]
        
        print(f"Found {len(audio_files)} total audio files")
        print(f"{split} split: {len(self.audio_files)} files")
        
        # Pre-compute segment duration for consistent sizing
        self.segment_duration = self.width * self.hop_length / self.sample_rate  # ~15 seconds
        
    def __len__(self):
        # Generate multiple combinations from available files
        return len(self.audio_files) * 3  # 3x augmentation
    
    def __getitem__(self, idx):
        """
        Load and process audio files to create training examples
        
        Returns:
            inputs: [3, H, W] tensor containing [source_a, source_b, noise]
            target: [1, H, W] tensor containing target transition
        """
        # Select two different audio files
        file_idx = idx % len(self.audio_files)
        source_a_path = self.audio_files[file_idx]
        
        # Select different file for source B
        source_b_idx = (file_idx + 1 + (idx // len(self.audio_files))) % len(self.audio_files)
        source_b_path = self.audio_files[source_b_idx]
        
        # Load and process audio files
        try:
            source_a_spec = self._load_audio_segment(source_a_path)
            source_b_spec = self._load_audio_segment(source_b_path)
            
            # Generate noise for the third channel
            noise = torch.randn(self.height, self.width) * 0.05
            
            # Create target transition
            target = self._create_transition_target(source_a_spec, source_b_spec)
            
            # Stack inputs
            inputs = torch.stack([source_a_spec, source_b_spec, noise], dim=0)
            target = target.unsqueeze(0)  # Add channel dimension
            
            return inputs, target
            
        except Exception as e:
            print(f"Error loading {source_a_path} or {source_b_path}: {e}")
            # Fallback to synthetic data if file loading fails
            return self._generate_fallback_data()
    
    def _load_audio_segment(self, audio_path):
        """Load and convert audio file to mel-spectrogram"""
        # Load audio
        audio, sr = librosa.load(audio_path, sr=self.sample_rate)
        
        # Extract random segment if audio is longer than needed
        segment_samples = int(self.segment_duration * self.sample_rate)
        
        if len(audio) > segment_samples:
            start_idx = random.randint(0, len(audio) - segment_samples)
            audio = audio[start_idx:start_idx + segment_samples]
        elif len(audio) < segment_samples:
            # Pad if too short
            audio = np.pad(audio, (0, segment_samples - len(audio)), mode='constant')
        
        # Convert to mel-spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.height,
            hop_length=self.hop_length,
            n_fft=2048,
            fmin=20,
            fmax=8000
        )
        
        # Convert to log scale and normalize
        log_mel = np.log(mel_spec + 1e-6)
        
        # Resize to exact target size if needed
        if log_mel.shape[1] != self.width:
            from scipy.ndimage import zoom
            zoom_factor = self.width / log_mel.shape[1]
            log_mel = zoom(log_mel, (1, zoom_factor), order=1)
        
        # Normalize to [-1, 1]
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)
        log_mel = np.clip(log_mel, -3, 3) / 3  # Soft clipping
        
        return torch.FloatTensor(log_mel)
    
    def _create_transition_target(self, source_a, source_b):
        """Create target transition between two audio spectrograms"""
        transition = torch.zeros_like(source_a)
        
        # Create smooth crossfade with creative elements
        for i in range(self.width):
            # Main crossfade curve
            alpha = i / self.width
            
            # Add beat-sync variations
            beat_sync = 0.1 * np.sin(2 * np.pi * alpha * 8)  # 8 beats across transition
            alpha = np.clip(alpha + beat_sync, 0, 1)
            
            # Crossfade with frequency-dependent mixing
            transition[:, i] = (1 - alpha) * source_a[:, i] + alpha * source_b[:, i]
        
        # Add transition-specific effects
        # Enhance bass frequencies during transition
        bass_region = slice(0, self.height // 6)
        transition[bass_region, self.width//3:2*self.width//3] *= 1.2
        
        # Add high-frequency sweep
        for i in range(self.width//4, 3*self.width//4):
            sweep_freq = int(self.height//2 + (self.height//4) * (i - self.width//4) / (self.width//2))
            if sweep_freq < self.height:
                transition[sweep_freq, i] += 0.3
        
        # Normalize
        transition = (transition - transition.mean()) / (transition.std() + 1e-6)
        transition = np.clip(transition, -3, 3) / 3
        
        return transition
    
    def _generate_fallback_data(self):
        """Generate synthetic data as fallback"""
        print("Using synthetic fallback data")
        # Simple synthetic spectrograms
        source_a = torch.randn(self.height, self.width) * 0.5
        source_b = torch.randn(self.height, self.width) * 0.5
        noise = torch.randn(self.height, self.width) * 0.05
        
        # Simple crossfade target
        target = torch.zeros_like(source_a)
        for i in range(self.width):
            alpha = i / self.width
            target[:, i] = (1 - alpha) * source_a[:, i] + alpha * source_b[:, i]
        
        inputs = torch.stack([source_a, source_b, noise], dim=0)
        target = target.unsqueeze(0)
        
        return inputs, target

# Dataset configuration
print("Setting up audio dataset...")
print("Instructions:")
print("   1. Upload your audio files to Kaggle (Add Data → Upload)")
print("   2. Update AUDIO_DATA_PATH below to point to your uploaded folder")
print("   3. Supported formats: WAV, MP3, FLAC, M4A")
print("   4. Recommended: at least 100+ audio files for good training")

# Configure your audio data path here
AUDIO_DATA_PATH = "/kaggle/input/your-audio-dataset"  # Update this path!

# Check if audio data exists
if not os.path.exists(AUDIO_DATA_PATH):
    print(f"Audio data path not found: {AUDIO_DATA_PATH}")
    print("Please update AUDIO_DATA_PATH to point to your uploaded audio files")
    print("Creating minimal synthetic dataset for demonstration...")
    
    # Create a minimal synthetic dataset as fallback
    class FallbackDataset(Dataset):
        def __init__(self, num_samples=1000):
            self.num_samples = num_samples
            
        def __len__(self):
            return self.num_samples
            
        def __getitem__(self, idx):
            # Generate simple synthetic data
            source_a = torch.randn(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH) * 0.5
            source_b = torch.randn(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH) * 0.5
            noise = torch.randn(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH) * 0.05
            
            # Create crossfade target
            target = torch.zeros_like(source_a)
            for i in range(SPECTROGRAM_WIDTH):
                alpha = i / SPECTROGRAM_WIDTH
                target[:, i] = (1 - alpha) * source_a[:, i] + alpha * source_b[:, i]
            
            inputs = torch.stack([source_a, source_b, noise], dim=0)
            target = target.unsqueeze(0)
            
            return inputs, target
    
    # Use fallback dataset
    train_dataset = FallbackDataset(TRAIN_SAMPLES)
    val_dataset = FallbackDataset(VAL_SAMPLES)
    
    print(f"Using fallback synthetic dataset")
    print(f"   Training samples: {len(train_dataset):,}")
    print(f"   Validation samples: {len(val_dataset):,}")
    
else:
    # Use real audio dataset
    print(f"Found audio data at: {AUDIO_DATA_PATH}")
    
    try:
        # Create datasets
        train_dataset = AudioTransitionDataset(
            audio_dir=AUDIO_DATA_PATH,
            spectrogram_size=(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH),
            sample_rate=SAMPLE_RATE,
            hop_length=HOP_LENGTH,
            split='train',
            val_split=0.2
        )
        
        val_dataset = AudioTransitionDataset(
            audio_dir=AUDIO_DATA_PATH,
            spectrogram_size=(SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH),
            sample_rate=SAMPLE_RATE,
            hop_length=HOP_LENGTH,
            split='val',
            val_split=0.2
        )
        
        print(f"Real audio dataset created successfully!")
        
    except Exception as e:
        print(f"Error creating audio dataset: {e}")
        print("Falling back to synthetic dataset...")
        
        # Fallback to synthetic
        train_dataset = FallbackDataset(TRAIN_SAMPLES)
        val_dataset = FallbackDataset(VAL_SAMPLES)

# Test dataset
test_inputs, test_targets = train_dataset[0]
print(f"Dataset ready!")
print(f"   Training samples: {len(train_dataset):,}")
print(f"   Validation samples: {len(val_dataset):,}")
print(f"   Input shape: {test_inputs.shape}")
print(f"   Target shape: {test_targets.shape}")

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"Data loaders created!")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

del test_inputs, test_targets

### How to Upload Your Audio Data to Kaggle:

1. **Prepare Your Audio Files:**
   - Organize your audio files in folders (e.g., by genre, artist, etc.)
   - Supported formats: WAV, MP3, FLAC, M4A
   - Recommended: 100+ audio files for good training results
   - Each file should be at least 15 seconds long

2. **Upload to Kaggle:**
   - Click "Add Data" → "Upload" in the right panel
   - Upload your audio folder as a ZIP file
   - Wait for upload to complete
   - Note the dataset path (usually `/kaggle/input/your-dataset-name/`)

3. **Update the Path:**
   - Modify the `AUDIO_DATA_PATH` variable below
   - Run the dataset creation cell

## Initialize Training Components

Set up the trainer class with model, optimizer, loss function, and data loaders for training.

In [None]:
# Initialize model
model = ProductionUNet(
    in_channels=IN_CHANNELS,
    out_channels=OUT_CHANNELS,
    model_dim=MODEL_DIM
).to(device)

print(f" Model initialized!")
print(f"   Parameters: {model.count_parameters():,}")
print(f"   Device: {device}")

# Initialize optimizer with learning rate scheduling
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
    betas=(0.9, 0.999)
)

# Learning rate scheduler with warmup
def get_lr_schedule_with_warmup(optimizer, warmup_epochs, total_epochs):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Warmup phase
            return (epoch + 1) / warmup_epochs
        else:
            # Cosine annealing phase
            progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
            return 0.5 * (1 + np.cos(np.pi * progress))
    
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler = get_lr_schedule_with_warmup(optimizer, WARMUP_EPOCHS, NUM_EPOCHS)

# Loss function
criterion = nn.MSELoss()

# Initialize tensorboard
writer = SummaryWriter('logs/kaggle_training')

print(" Training components initialized!")
print(f"   Optimizer: AdamW")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Scheduler: Warmup + Cosine Annealing")
print(f"   Loss Function: MSE")

##  Training Loop Implementation

Execute the main training loop with epoch-by-epoch training, validation, and progress monitoring.

In [None]:
# Training state
best_val_loss = float('inf')
train_losses = []
val_losses = []
learning_rates = []

print(" Starting training...")
print(f" Epochs: {NUM_EPOCHS} | Batch size: {BATCH_SIZE}")
print("=" * 60)

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    current_epoch = epoch + 1
    
    print(f"\n📈 Epoch {current_epoch}/{NUM_EPOCHS}")
    print("-" * 50)
    
    # Training phase
    model.train()
    epoch_train_loss = 0.0
    num_train_batches = len(train_loader)
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update weights
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
        
        epoch_train_loss += loss.item()
        
        # Progress logging
        if batch_idx % 200 == 0:
            progress = batch_idx / num_train_batches * 100
            print(f"   Batch {batch_idx}/{num_train_batches} ({progress:.1f}%) - Loss: {loss.item():.6f}")
    
    avg_train_loss = epoch_train_loss / num_train_batches
    train_losses.append(avg_train_loss)
    
    # Validation phase
    model.eval()
    epoch_val_loss = 0.0
    num_val_batches = len(val_loader)
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            epoch_val_loss += loss.item()
    
    avg_val_loss = epoch_val_loss / num_val_batches
    val_losses.append(avg_val_loss)
    
    # Update learning rate
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)
    
    # Log to tensorboard
    writer.add_scalar('Loss/Train', avg_train_loss, current_epoch)
    writer.add_scalar('Loss/Validation', avg_val_loss, current_epoch)
    writer.add_scalar('Learning_Rate', current_lr, current_epoch)
    
    # Print epoch summary
    print(f" Train Loss: {avg_train_loss:.6f}")
    print(f" Val Loss: {avg_val_loss:.6f}")
    print(f" Learning Rate: {current_lr:.2e}")
    
    # Save best model
    is_best = avg_val_loss < best_val_loss
    if is_best:
        best_val_loss = avg_val_loss
        # Save best model
        best_checkpoint = {
            'epoch': current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'best_val_loss': best_val_loss,
            'model_config': {
            'in_channels': IN_CHANNELS,
            'out_channels': OUT_CHANNELS,
            'model_dim': MODEL_DIM
            }
        }
        torch.save(best_checkpoint, 'checkpoints/best_model_kaggle.pt')
        print(f" New best model saved! (Val Loss: {best_val_loss:.6f})")
        
        # Save regular checkpoint every 5 epochs
        if current_epoch % 5 == 0:
        checkpoint = {
            'epoch': current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'best_val_loss': best_val_loss,
            'model_config': {
            'in_channels': IN_CHANNELS,
            'out_channels': OUT_CHANNELS,
            'model_dim': MODEL_DIM
            }
        }
        torch.save(checkpoint, f'checkpoints/model_epoch_{current_epoch}.pt')
        print(f"💾 Checkpoint saved: model_epoch_{current_epoch}.pt")
        
        # Memory cleanup
        if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    # Training completed
    total_time = time.time() - start_time
    print(f"\n Training completed!")
    print(f" Total time: {total_time:.1f}s ({total_time/60:.1f}m)")
    print(f" Best validation loss: {best_val_loss:.6f}")

    # Save final checkpoint
    final_checkpoint = {
        'epoch': NUM_EPOCHS,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'learning_rates': learning_rates,
        'best_val_loss': best_val_loss,
        'total_training_time': total_time,
        'model_config': {
        'in_channels': IN_CHANNELS,
        'out_channels': OUT_CHANNELS,
        'model_dim': MODEL_DIM
        }
    }
    torch.save(final_checkpoint, 'checkpoints/final_model_kaggle.pt')
    print(" Final checkpoint saved: final_model_kaggle.pt")

    # Close tensorboard writer
    writer.close()

##  Model Evaluation and Validation

Evaluate model performance on validation data and visualize training metrics and loss curves.

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Loss curves
epochs = range(1, len(train_losses) + 1)
axes[0,0].plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
axes[0,0].plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
axes[0,0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0,0].set_xlabel('Epoch')
axes[0,0].set_ylabel('Loss')
axes[0,0].legend()
axes[0,0].grid(True, alpha=0.3)

# Learning rate schedule
axes[0,1].plot(epochs, learning_rates, 'g-', linewidth=2)
axes[0,1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[0,1].set_xlabel('Epoch')
axes[0,1].set_ylabel('Learning Rate')
axes[0,1].grid(True, alpha=0.3)

# Generate sample predictions
print(" Generating sample predictions...")
model.eval()
with torch.no_grad():
    # Get a batch from validation set
    sample_inputs, sample_targets = next(iter(val_loader))
    sample_inputs = sample_inputs.to(device)
    sample_targets = sample_targets.to(device)
    
    # Generate predictions
    sample_predictions = model(sample_inputs)
    
    # Move to CPU for visualization
    sample_inputs = sample_inputs.cpu()
    sample_targets = sample_targets.cpu()
    sample_predictions = sample_predictions.cpu()

# Visualize sample predictions
sample_idx = 0
source_a = sample_inputs[sample_idx, 0].numpy()
source_b = sample_inputs[sample_idx, 1].numpy()
target = sample_targets[sample_idx, 0].numpy()
prediction = sample_predictions[sample_idx, 0].numpy()

# Plot source A
im1 = axes[1,0].imshow(source_a, aspect='auto', origin='lower', cmap='viridis')
axes[1,0].set_title('Source A (House)', fontsize=14, fontweight='bold')
axes[1,0].set_xlabel('Time Frames')
axes[1,0].set_ylabel('Frequency Bins')
plt.colorbar(im1, ax=axes[1,0], fraction=0.046, pad=0.04)

# Plot generated transition vs target
im2 = axes[1,1].imshow(prediction, aspect='auto', origin='lower', cmap='viridis')
axes[1,1].set_title('Generated Transition', fontsize=14, fontweight='bold')
axes[1,1].set_xlabel('Time Frames')
axes[1,1].set_ylabel('Frequency Bins')
plt.colorbar(im2, ax=axes[1,1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig('outputs/training_results.png', dpi=150, bbox_inches='tight')
plt.show()

# Calculate final metrics
final_train_loss = train_losses[-1]
final_val_loss = val_losses[-1]
improvement = (train_losses[0] - final_train_loss) / train_losses[0] * 100

print(f"\n Training Summary:")
print(f"   Initial Training Loss: {train_losses[0]:.6f}")
print(f"   Final Training Loss: {final_train_loss:.6f}")
print(f"   Final Validation Loss: {final_val_loss:.6f}")
print(f"   Best Validation Loss: {best_val_loss:.6f}")
print(f"   Training Improvement: {improvement:.1f}%")
print(f"   Total Parameters: {model.count_parameters():,}")
print(f"   Training Time: {total_time/60:.1f} minutes")

# Model quality assessment
if best_val_loss < 0.1:
    quality = "EXCELLENT"
elif best_val_loss < 0.2:
    quality = "GOOD"
elif best_val_loss < 0.3:
    quality = "FAIR"
else:
    quality = "NEEDS IMPROVEMENT"

print(f"\nModel Quality: {quality}")
print(f"Best Validation Loss: {best_val_loss:.6f}")

# Save evaluation results
evaluation_results = {
    'final_train_loss': final_train_loss,
    'final_val_loss': final_val_loss,
    'best_val_loss': best_val_loss,
    'training_improvement_percent': improvement,
    'total_parameters': model.count_parameters(),
    'training_time_minutes': total_time/60,
    'quality_assessment': quality,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'learning_rates': learning_rates
}

torch.save(evaluation_results, 'outputs/evaluation_results.pt')
print(f"\n Evaluation results saved to outputs/evaluation_results.pt")

## Generate Sample Audio Transitions

Let's generate some actual audio transitions to test our trained model!

In [None]:
# Audio generation class for creating actual audio files
class AudioGenerator:
    def __init__(self, model, device, sample_rate=22050):
        self.model = model
        self.device = device
        self.sample_rate = sample_rate
        
    def spectrogram_to_audio(self, spectrogram):
        """Convert spectrogram back to audio using Griffin-Lim algorithm"""
        # Ensure spectrogram is 2D
        if len(spectrogram.shape) == 3:
            spectrogram = spectrogram.squeeze(0)
        
        # Convert to magnitude spectrogram
        magnitude = np.exp(spectrogram) - 1e-6
        
        # Use Griffin-Lim to reconstruct audio
        audio = librosa.griffinlim(magnitude, 
                                 hop_length=512, 
                                 win_length=1024,
                                 n_iter=32)
        
        return audio
    
    def create_synthetic_track(self, style='house', duration=4.0):
        """Create a synthetic track segment"""
        n_frames = int(duration * self.sample_rate / 512)  # 512 is hop_length
        
        if style == 'house':
            # House: 120-128 BPM, 4/4 kick pattern
            fundamental = np.random.uniform(80, 120)  # Hz
            harmonics = [1, 2, 3, 4, 5]
            kick_pattern = np.array([1, 0, 0, 0] * (n_frames // 4))[:n_frames]
        else:
            # Techno: 128-140 BPM, more aggressive
            fundamental = np.random.uniform(100, 150)  # Hz
            harmonics = [1, 2, 3, 5, 7]
            kick_pattern = np.array([1, 0, 1, 0] * (n_frames // 4))[:n_frames]
        
        # Create frequency content
        n_mels = 128
        spectrogram = np.random.normal(0, 0.1, (n_mels, n_frames))
        
        # Add rhythmic elements
        for i, intensity in enumerate(kick_pattern):
            if intensity > 0:
                spectrogram[20:40, i] += intensity * np.random.uniform(0.5, 1.0)
        
        # Add harmonic content
        for harmonic in harmonics:
            freq_bin = min(int(harmonic * fundamental / 22050 * 128), 127)
            spectrogram[freq_bin-2:freq_bin+2, :] += np.random.uniform(0.2, 0.4)
        
        # Smooth and normalize
        from scipy.ndimage import gaussian_filter
        spectrogram = gaussian_filter(spectrogram, sigma=0.5)
        spectrogram = np.clip(spectrogram, 0, None)
        
        return np.log(spectrogram + 1e-6)
    
    def generate_transition(self, style_a='house', style_b='techno', duration=4.0):
        """Generate a transition between two styles"""
        # Create source tracks
        source_a = self.create_synthetic_track(style_a, duration)
        source_b = self.create_synthetic_track(style_b, duration)
        
        # Prepare input tensor
        input_tensor = np.stack([source_a, source_b])
        input_tensor = torch.FloatTensor(input_tensor).unsqueeze(0).to(self.device)
        
        # Generate transition
        self.model.eval()
        with torch.no_grad():
            transition = self.model(input_tensor)
            transition = transition.cpu().numpy().squeeze()
        
        return source_a, source_b, transition

# Create audio generator
audio_gen = AudioGenerator(model, device)

print("🎵 Generating sample transitions...")

# Generate different style combinations
transitions = []
styles = [('house', 'techno'), ('techno', 'house'), ('house', 'house'), ('techno', 'techno')]

for i, (style_a, style_b) in enumerate(styles):
    print(f"   Generating {style_a} → {style_b} transition...")
    source_a, source_b, transition = audio_gen.generate_transition(style_a, style_b)
    
    # Convert spectrograms to audio
    audio_a = audio_gen.spectrogram_to_audio(source_a)
    audio_b = audio_gen.spectrogram_to_audio(source_b)
    audio_transition = audio_gen.spectrogram_to_audio(transition)
    
    # Save audio files
    sf.write(f'outputs/source_a_{style_a}_{i}.wav', audio_a, 22050)
    sf.write(f'outputs/source_b_{style_b}_{i}.wav', audio_b, 22050)
    sf.write(f'outputs/transition_{style_a}_to_{style_b}_{i}.wav', audio_transition, 22050)
    
    transitions.append({
        'styles': (style_a, style_b),
        'spectrograms': (source_a, source_b, transition),
        'audio': (audio_a, audio_b, audio_transition)
    })

print(f"Generated {len(transitions)} sample transitions!")
print("Audio files saved in outputs/ directory")

# Visualize the generated transitions
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

for i, transition_data in enumerate(transitions[:2]):  # Show first 2 transitions
    source_a, source_b, transition = transition_data['spectrograms']
    style_a, style_b = transition_data['styles']
    
    # Plot spectrograms
    im1 = axes[i,0].imshow(source_a, aspect='auto', origin='lower', cmap='viridis')
    axes[i,0].set_title(f'Source A ({style_a.title()})', fontweight='bold')
    axes[i,0].set_ylabel('Frequency Bins')
    
    im2 = axes[i,1].imshow(transition, aspect='auto', origin='lower', cmap='viridis')
    axes[i,1].set_title(f'Generated Transition', fontweight='bold')
    
    im3 = axes[i,2].imshow(source_b, aspect='auto', origin='lower', cmap='viridis')
    axes[i,2].set_title(f'Source B ({style_b.title()})', fontweight='bold')
    
    if i == 1:  # Bottom row
        for ax in axes[i,:]:
            ax.set_xlabel('Time Frames')

plt.tight_layout()
plt.savefig('outputs/generated_transitions.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nSample transitions generated! You can download the audio files to listen to them.")

## Save and Export Model

Finally, let's save our trained model and create a deployment package!

In [None]:
# Create comprehensive model export
print("Preparing model for export...")

# 1. Save the complete model
model_export = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'in_channels': 2,
        'out_channels': 1,
        'features': [64, 128, 256, 512],
        'model_type': 'ProductionUNet',
        'total_parameters': model.count_parameters(),
        'input_shape': [2, 128, 512],
        'output_shape': [1, 128, 512]
    },
    'training_info': {
        'epochs_trained': config['num_epochs'],
        'final_train_loss': final_train_loss,
        'final_val_loss': final_val_loss,
        'best_val_loss': best_val_loss,
        'training_time_minutes': total_time/60,
        'optimizer': 'AdamW',
        'learning_rate': config['learning_rate'],
        'batch_size': config['batch_size']
    },
    'training_history': {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'learning_rates': learning_rates
    },
    'metadata': {
        'created_date': str(datetime.now()),
        'framework': 'PyTorch',
        'device_trained': str(device),
        'kaggle_notebook': True,
        'model_version': '1.0'
    }
}

# Save complete model
torch.save(model_export, 'outputs/dj_transition_model_complete.pt')
print("Complete model saved to outputs/dj_transition_model_complete.pt")

# 2. Save just the model weights (smaller file)
torch.save(model.state_dict(), 'outputs/dj_transition_model_weights.pt')
print("Model weights saved to outputs/dj_transition_model_weights.pt")

# 3. Export model in ONNX format for deployment
model.eval()
dummy_input = torch.randn(1, 2, 128, 512).to(device)

try:
    torch.onnx.export(
        model,
        dummy_input,
        'outputs/dj_transition_model.onnx',
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['source_tracks'],
        output_names=['transition'],
        dynamic_axes={
            'source_tracks': {0: 'batch_size'},
            'transition': {0: 'batch_size'}
        }
    )
    print("ONNX model saved to outputs/dj_transition_model.onnx")
except Exception as e:
    print(f"ONNX export failed: {e}")

# 4. Create deployment configuration
deployment_config = {
    'model_info': {
        'name': 'DJNet Transition Generator',
        'version': '1.0',
        'description': 'Deep learning model for generating smooth transitions between DJ tracks',
        'input_format': 'Mel-spectrogram pairs (2, 128, 512)',
        'output_format': 'Transition mel-spectrogram (1, 128, 512)',
        'sampling_rate': 22050,
        'hop_length': 512,
        'n_mels': 128
    },
    'inference': {
        'preprocessing': {
            'normalize': True,
            'log_scale': True,
            'clamp_min': 1e-6
        },
        'postprocessing': {
            'griffin_lim_iterations': 32,
            'hop_length': 512,
            'win_length': 1024
        }
    },
    'performance': {
        'parameters': model.count_parameters(),
        'inference_time_gpu_ms': '~50',  # Estimated
        'memory_requirement_mb': '~500',  # Estimated
        'recommended_device': 'GPU (CUDA)',
        'minimum_device': 'CPU'
    }
}

import json
with open('outputs/deployment_config.json', 'w') as f:
    json.dump(deployment_config, f, indent=2)

print("Deployment config saved to outputs/deployment_config.json")

# 5. Create a simple inference script
inference_script = '''
import torch
import numpy as np
import librosa
import soundfile as sf

def load_model(model_path):
    """Load the trained model"""
    checkpoint = torch.load(model_path, map_location='cpu')
    
    # Reconstruct model architecture
    from your_model_file import ProductionUNet  # Replace with actual import
    model = ProductionUNet(in_channels=2, out_channels=1)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    return model

def audio_to_spectrogram(audio_path, sr=22050, n_mels=128, hop_length=512):
    """Convert audio file to mel-spectrogram"""
    audio, _ = librosa.load(audio_path, sr=sr)
    
    # Create mel-spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=audio, sr=sr, n_mels=n_mels, hop_length=hop_length
    )
    
    # Convert to log scale
    log_mel = np.log(mel_spec + 1e-6)
    
    return log_mel

def spectrogram_to_audio(spectrogram, sr=22050, hop_length=512):
    """Convert mel-spectrogram back to audio"""
    # Convert from log scale
    mel_spec = np.exp(spectrogram) - 1e-6
    
    # Reconstruct audio using Griffin-Lim
    audio = librosa.griffinlim(mel_spec, hop_length=hop_length, n_iter=32)
    
    return audio

def generate_transition(model, audio_path_a, audio_path_b, output_path):
    """Generate transition between two audio files"""
    # Load and convert audio to spectrograms
    spec_a = audio_to_spectrogram(audio_path_a)
    spec_b = audio_to_spectrogram(audio_path_b)
    
    # Prepare input tensor
    input_tensor = torch.FloatTensor(np.stack([spec_a, spec_b])).unsqueeze(0)
    
    # Generate transition
    with torch.no_grad():
        transition = model(input_tensor)
        transition = transition.cpu().numpy().squeeze()
    
    # Convert back to audio
    audio_transition = spectrogram_to_audio(transition)
    
    # Save audio
    sf.write(output_path, audio_transition, 22050)
    
    return audio_transition

# Example usage:
# model = load_model('dj_transition_model_complete.pt')
# generate_transition(model, 'track_a.wav', 'track_b.wav', 'transition.wav')
'''

with open('outputs/inference_example.py', 'w') as f:
    f.write(inference_script)

print("Inference script saved to outputs/inference_example.py")

# 6. Create README
readme_content = f'''# DJNet Transition Generator

## Model Overview
This is a trained U-Net model for generating smooth transitions between DJ tracks.

### Model Details
- **Architecture**: U-Net with skip connections
- **Parameters**: {model.count_parameters():,}
- **Input**: Pair of mel-spectrograms (2, 128, 512)
- **Output**: Transition mel-spectrogram (1, 128, 512)
- **Training**: {config['num_epochs']} epochs on synthetic data

### Performance
- **Final Training Loss**: {final_train_loss:.6f}
- **Final Validation Loss**: {final_val_loss:.6f}
- **Best Validation Loss**: {best_val_loss:.6f}
- **Training Time**: {total_time/60:.1f} minutes

### Files Included
- `dj_transition_model_complete.pt` - Full model with training info
- `dj_transition_model_weights.pt` - Model weights only
- `dj_transition_model.onnx` - ONNX format for deployment
- `deployment_config.json` - Configuration for deployment
- `inference_example.py` - Example inference script
- `evaluation_results.pt` - Training metrics and results

### Quick Start
```python
import torch

# Load model
checkpoint = torch.load('dj_transition_model_complete.pt')
model = ProductionUNet(in_channels=2, out_channels=1)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Generate transition (see inference_example.py for full code)
```

### Requirements
- PyTorch >= 1.9.0
- librosa >= 0.8.0
- numpy >= 1.19.0
- soundfile >= 0.10.0

Generated on: {str(datetime.now())}
'''

with open('outputs/README.md', 'w') as f:
    f.write(readme_content)

print("✅ README saved to outputs/README.md")

# 7. Display file summary
print(f"\n📦 Export Summary:")
print(f"   📁 outputs/dj_transition_model_complete.pt ({os.path.getsize('outputs/dj_transition_model_complete.pt')/1024/1024:.1f} MB)")
print(f"   📁 outputs/dj_transition_model_weights.pt ({os.path.getsize('outputs/dj_transition_model_weights.pt')/1024/1024:.1f} MB)")
if os.path.exists('outputs/dj_transition_model.onnx'):
    print(f"   📁 outputs/dj_transition_model.onnx ({os.path.getsize('outputs/dj_transition_model.onnx')/1024/1024:.1f} MB)")
print(f"   📁 outputs/deployment_config.json")
print(f"   📁 outputs/inference_example.py")
print(f"   📁 outputs/README.md")
print(f"   📁 outputs/evaluation_results.pt")

print(f"\n🎉 Model training and export complete!")
print(f"💡 Download the outputs/ folder to get all model files")
print(f"🚀 Ready for deployment and production use!")