# DJNet-Diffusion Training on Colab/Kaggle

This notebook trains the DJNet diffusion model for music transition generation.

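## Setup Instructions:
1. Upload your dataset to Colab/Kaggle
2. Make the DJNet code available: uncomment either the `git clone` or the `unzip` command in the cells below
3. Update the paths in the configuration
4. Run all cells to start training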

In [None]:
# Install dependencies
!pip install torch torchaudio diffusers transformers accelerate librosa soundfile
!pip install pandas numpy matplotlib seaborn tensorboard tqdm scipy omegaconf wandb scikit-learn PyYAML

In [None]:
# Upload and extract your DJNet code
# You can zip your local implementation and upload it here
# !unzip djnet-diffusion.zip

In [None]:
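# Clone or pull the latest DJNet implementation (uncomment the command you need).
# The import cell further down expects the code at /content/djnet-diffusion.
# !git clone https://github.com/SoykatAmin/DJNet-Diffusion.git djnet-diffusion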
# If already cloned, pull latest changes:
# !cd djnet-diffusion && git pull origin main

# Or if you have a zip file, upload and extract:
# !unzip djnet-diffusion.zip

In [None]:
# Report whether PyTorch can see a CUDA device and, if so, its name and total memory
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")

if has_cuda:
    # Device index 0 is the one queried throughout this cell
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA memory: {gpu_props.total_memory / 1e9:.1f} GB")
else:
    print("CUDA device: None")
    print("No CUDA")

In [None]:
# Build the training and model configurations and write each one to a YAML file.
# The dataset paths are placeholders: update them to match your Colab/Kaggle setup.
import yaml

# Training configuration for GPU
training_config = dict(
    training=dict(
        batch_size=16,  # Larger batch size for GPU
        num_epochs=50,
        learning_rate=1e-4,
        weight_decay=1e-6,
        warmup_steps=1000,
        gradient_clip_norm=1.0,
        save_every_n_epochs=5,
        validate_every_n_epochs=2,
    ),
    data=dict(
        metadata_path='/content/metadata.csv',  # Update this path
        data_root='/content/djnet_dataset_20k',  # Update this path
        train_split=0.8,
        val_split=0.1,
        test_split=0.1,
        num_workers=2,
        pin_memory=True,
        shuffle=True,
    ),
    logging=dict(
        log_dir='logs',
        use_wandb=False,  # Set to True if you want to use wandb
        wandb_project='djnet-diffusion',
        log_every_n_steps=50,
        save_audio_samples=True,
        num_audio_samples=4,
    ),
    optimization=dict(
        optimizer='AdamW',
        scheduler='cosine',
        min_lr=1e-6,
    ),
    checkpointing=dict(
        checkpoint_dir='checkpoints',
        resume_from_checkpoint=None,
        save_best_only=False,
    ),
)

# Model configuration for GPU
model_config = dict(
    model=dict(
        name='DJNetDiffusion',
        in_channels=3,
        out_channels=1,
        down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D'],
        up_block_types=['AttnUpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'],
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        attention_head_dim=8,
        norm_num_groups=32,
        cross_attention_dim=256,
    ),
    scheduler=dict(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule='linear',
        variance_type='fixed_small',
        clip_sample=False,
    ),
    audio=dict(
        sample_rate=22050,
        n_mels=128,
        n_fft=2048,
        hop_length=512,
        win_length=2048,
        context_duration=4.0,
        max_transition_duration=8.0,
        normalize_spectrograms=True,
        spec_min=-80.0,
        spec_max=0.0,
    ),
    conditioning=dict(
        use_tempo=True,
        use_transition_type=True,
        use_transition_length=True,
        tempo_embed_dim=64,
        type_embed_dim=64,
        length_embed_dim=64,
    ),
)

# Save configurations: one YAML file per top-level config, in the working directory
for yaml_path, settings in (
    ('training_config.yaml', training_config),
    ('model_config.yaml', model_config),
):
    with open(yaml_path, 'w') as handle:
        yaml.dump(settings, handle)

print("✅ Configuration files created")

In [None]:
# Import your DJNet implementation
import sys

DJNET_ROOT = '/content/djnet-diffusion'  # Update path as needed

# Re-running this cell must not stack duplicate entries on sys.path
if DJNET_ROOT not in sys.path:
    sys.path.append(DJNET_ROOT)

from src.training.trainer import DJNetTrainer, load_config
print("✅ DJNet modules imported successfully")

## ⚠️ Important: Latest Fixes Applied

**The latest version includes critical fixes for:**
- Tensor dimension consistency in training and inference
- Custom collate function for proper batching
- Spectrogram size normalization

**Make sure to pull the latest changes (`git pull origin main` inside the `djnet-diffusion` folder) before training!**

In [None]:
# Read both YAML files back and merge them into the single config the trainer receives
training_config = load_config('training_config.yaml')
model_config = load_config('model_config.yaml')
config = dict(model_config)
config.update(training_config)  # on a top-level key clash the training value wins

# Setup device: use CUDA when available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Initialize trainer
trainer = DJNetTrainer(config, device)

# Summarise the model size and the datasets behind the trainer's loaders
total_params = sum(param.numel() for param in trainer.model.parameters())
train_size = len(trainer.train_loader.dataset)
val_size = len(trainer.val_loader.dataset)
print(f"Model has {total_params:,} parameters")
print(f"Training dataset size: {train_size}")
print(f"Validation dataset size: {val_size}")

In [None]:
# Start training. A manual interrupt (KeyboardInterrupt) is caught so that a
# checkpoint can be requested from the trainer before the cell finishes;
# any other exception propagates unchanged.
try:
    trainer.train()
    print("🎉 Training completed successfully!")
except KeyboardInterrupt:
    print("Training interrupted by user")
    # Checkpoint is requested for trainer.current_epoch, with is_best=False
    interrupted_epoch = trainer.current_epoch
    trainer.save_checkpoint(interrupted_epoch, is_best=False)
    print("Checkpoint saved")

In [None]:
# Test generation with trained model
#
# The trainer is imported as `src.training.trainer`, and only the repository
# root is added to sys.path, so a bare `inference` package is importable only
# if something else has put it on the path. Try the original location first,
# then the `src.`-prefixed one.
# NOTE(review): confirm which layout the repository actually uses.
try:
    from inference.generator import DJNetGenerator
except ModuleNotFoundError as import_error:
    # Fall back only when the `inference` package itself is missing; a missing
    # dependency inside the generator module should surface unchanged.
    if import_error.name not in ('inference', 'inference.generator'):
        raise
    from src.inference.generator import DJNetGenerator

# Load best model
checkpoint_path = 'checkpoints/best_model.pth'
generator = DJNetGenerator(checkpoint_path, device)

print("✅ Generator loaded successfully")
print("Ready for inference!")

In [None]:
# Example generation (update paths to your test audio files)
from IPython.display import Audio

# Keyword arguments forwarded unchanged to generate_transition_audio
transition_request = dict(
    song_a_path='/content/test_song_a.wav',  # Update this
    song_b_path='/content/test_song_b.wav',  # Update this
    output_path='/content/generated_transition.wav',
    transition_length=8.0,
    tempo=120.0,
    transition_type='linear_fade',
    num_inference_steps=50,
)
output_path = generator.generate_transition_audio(**transition_request)

print(f"Generated transition: {output_path}")

# Play the result (last expression of the cell, so the player is displayed)
Audio(output_path)