# Music Transition Transformer - Google Colab Notebook

This notebook demonstrates the Music Transition Transformer, a transformer-based model for creating seamless transitions between music segments using mel-spectrogram representation.

## Features
- **Dual Encoder Architecture**: Separate encoders for preceding and following music segments
- **Mel-Spectrogram Processing**: Works with frequency-domain audio representation
- **Continuous Representation**: Generates smooth spectrograms instead of discrete tokens
- **Autoregressive Generation**: Creates transitions step-by-step with proper temporal coherence
- **Synthetic Data Support**: Includes synthetic data generation for testing

## Setup and Installation

First, let's install the required dependencies and clone the repository.

In [None]:
# Install required packages
!pip install torch torchaudio numpy librosa soundfile matplotlib tqdm scikit-learn

# Clone the repository
!git clone https://github.com/SoykatAmin/DJNet-Transformer.git

# Change to the project directory
import os
os.chdir('DJNet-Transformer')

print("Setup complete!")

## Import Libraries and Modules

Import all necessary libraries and the custom modules from our project.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import librosa
from tqdm import tqdm

# Import our custom modules
from music_transformer.model import MusicTransitionTransformer
from music_transformer.config import Config
from music_transformer.audio_processor import AudioProcessor
from music_transformer.dataset import create_synthetic_spectrograms
from music_transformer.train import Trainer

print("All modules imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## Model Configuration

Set up the configuration for our transformer model.

In [None]:
# Initialize configuration
config = Config()

# Print configuration details
print("Model Configuration:")
print(f"  - Mel bins: {config.mel_bins}")
print(f"  - Sequence length: {config.seq_len}")
print(f"  - Model dimension: {config.d_model}")
print(f"  - Number of heads: {config.num_heads}")
print(f"  - Number of layers: {config.num_layers}")
print(f"  - Feedforward dimension: {config.d_ff}")
print(f"  - Dropout: {config.dropout}")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Batch size: {config.batch_size}")

## Initialize the Model

Create an instance of the Music Transition Transformer model.

In [None]:
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MusicTransitionTransformer(config).to(device)

# Print model information
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model initialized on: {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024 / 1024:.2f} MB (assuming 32-bit floats)")

## Generate Synthetic Data

Create synthetic spectrogram data to test our model.

In [None]:
# Generate synthetic spectrograms only if not using custom dataset
if not use_custom_dataset:
    print("Generating synthetic spectrogram data...")
    spectrograms = create_synthetic_spectrograms(
        num_samples=10,
        mel_bins=config.mel_bins,
        seq_len=config.seq_len * 3  # Total length for preceding + transition + following
    )

    print(f"Generated {len(spectrograms)} synthetic spectrograms")
    print(f"Each spectrogram shape: {spectrograms[0].shape}")

    # Visualize one of the synthetic spectrograms
    plt.figure(figsize=(15, 6))
    plt.imshow(spectrograms[0], aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar(label='Magnitude')
    plt.title('Sample Synthetic Mel-Spectrogram')
    plt.xlabel('Time Frames')
    plt.ylabel('Mel Frequency Bins')
    plt.tight_layout()
    plt.show()
else:
    print("Skipping synthetic data generation - using custom DJNet dataset")

## Dataset Information and Statistics

Let's examine the characteristics of our loaded dataset.

In [None]:
# Display dataset statistics
print("=== DATASET INFORMATION ===")
print(f"Dataset type: {'DJNet Custom Dataset' if use_custom_dataset else 'Synthetic Dataset'}")
print(f"Number of samples: {len(spectrograms)}")
print(f"Spectrogram shape: {spectrograms[0].shape}")
print(f"Sequence length per segment: {config.seq_len}")
print(f"Total sequence length: {spectrograms[0].shape[0]} (should be {config.seq_len * 3})")
print(f"Mel frequency bins: {spectrograms[0].shape[1]}")

# Calculate statistics
all_spectrograms = np.array(spectrograms)
print(f"\n=== SPECTROGRAM STATISTICS ===")
print(f"Mean magnitude: {all_spectrograms.mean():.4f}")
print(f"Std magnitude: {all_spectrograms.std():.4f}")
print(f"Min magnitude: {all_spectrograms.min():.4f}")
print(f"Max magnitude: {all_spectrograms.max():.4f}")

if use_custom_dataset and 'metadata_df' in locals():
    print(f"\n=== DJNET DATASET DETAILS ===")
    print(f"Original dataset size: {len(metadata_df)} transitions")
    print(f"Successfully processed: {len(spectrograms)} transitions")
    print(f"Success rate: {len(spectrograms)/len(metadata_df)*100:.1f}%")
    
    # Show transition type distribution if available
    if len(metadata_df) > 0:
        transition_counts = metadata_df['transition_type'].value_counts()
        print(f"\nTransition types in original dataset:")
        for t_type, count in transition_counts.items():
            print(f"  {t_type}: {count}")

# Visualize spectrogram statistics
plt.figure(figsize=(15, 10))

# Plot magnitude distribution
plt.subplot(2, 3, 1)
plt.hist(all_spectrograms.flatten(), bins=50, alpha=0.7, edgecolor='black')
plt.title('Magnitude Distribution')
plt.xlabel('Magnitude')
plt.ylabel('Frequency')
plt.yscale('log')

# Plot mean spectrogram
plt.subplot(2, 3, 2)
mean_spec = all_spectrograms.mean(axis=0)
plt.imshow(mean_spec.T, aspect='auto', origin='lower', cmap='viridis')
plt.colorbar(label='Mean Magnitude')
plt.title('Mean Spectrogram Across All Samples')
plt.xlabel('Time Frames')
plt.ylabel('Mel Frequency Bins')

# Plot std spectrogram
plt.subplot(2, 3, 3)
std_spec = all_spectrograms.std(axis=0)
plt.imshow(std_spec.T, aspect='auto', origin='lower', cmap='viridis')
plt.colorbar(label='Std Magnitude')
plt.title('Standard Deviation Spectrogram')
plt.xlabel('Time Frames')
plt.ylabel('Mel Frequency Bins')

# Plot frequency band energy over time
plt.subplot(2, 3, 4)
seq_len = config.seq_len
time_energy = mean_spec.mean(axis=1)  # Average across frequency bins
plt.plot(time_energy[:seq_len], label='Preceding', alpha=0.8)
plt.plot(range(seq_len, seq_len*2), time_energy[seq_len:seq_len*2], label='Transition', alpha=0.8)
plt.plot(range(seq_len*2, seq_len*3), time_energy[seq_len*2:], label='Following', alpha=0.8)
plt.axvline(x=seq_len, color='red', linestyle='--', alpha=0.5, label='Segment Boundaries')
plt.axvline(x=seq_len*2, color='red', linestyle='--', alpha=0.5)
plt.title('Mean Energy Over Time')
plt.xlabel('Time Frames')
plt.ylabel('Mean Magnitude')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot frequency band distribution
plt.subplot(2, 3, 5)
freq_energy = mean_spec.mean(axis=0)  # Average across time
plt.plot(freq_energy)
plt.title('Mean Energy Across Frequency Bands')
plt.xlabel('Mel Frequency Bins')
plt.ylabel('Mean Magnitude')
plt.grid(True, alpha=0.3)

# Plot sample comparison
plt.subplot(2, 3, 6)
if len(spectrograms) >= 3:
    sample_indices = [0, len(spectrograms)//2, len(spectrograms)-1]
    for i, idx in enumerate(sample_indices):
        sample_energy = spectrograms[idx].mean(axis=1)
        plt.plot(sample_energy, alpha=0.7, label=f'Sample {idx}')
    plt.title('Energy Comparison Across Samples')
    plt.xlabel('Time Frames')
    plt.ylabel('Mean Magnitude')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nDataset loaded and ready for model training/testing!")

## Load Custom DJNet Dataset (Optional)

If you have created a dataset using the DJNet_Colab notebook, you can load and use it here instead of synthetic data.

In [None]:
# Option 1: Load from Google Drive if you saved the DJNet dataset there
use_custom_dataset = True  # Set to False to use synthetic data instead

if use_custom_dataset:
    try:
        # Mount Google Drive if not already mounted
        from google.colab import drive
        import pandas as pd
        import os
        
        # Try to mount Google Drive
        try:
            drive.mount('/content/drive')
            print("Google Drive mounted successfully!")
        except:
            print("Drive already mounted or not available")
        
        # Path to your DJNet dataset (adjust this path based on where you saved it)
        djnet_dataset_path = '/content/drive/MyDrive/DJNet_Data/output/djnet_dataset'
        
        # Alternative paths you might have used
        alternative_paths = [
            '/content/djnet_dataset',
            '/content/drive/MyDrive/djnet_dataset',
            './djnet_dataset'
        ]
        
        dataset_found = False
        
        # Check if dataset exists
        if os.path.exists(djnet_dataset_path):
            dataset_found = True
            print(f"Found DJNet dataset at: {djnet_dataset_path}")
        else:
            # Try alternative paths
            for alt_path in alternative_paths:
                if os.path.exists(alt_path):
                    djnet_dataset_path = alt_path
                    dataset_found = True
                    print(f"Found DJNet dataset at: {djnet_dataset_path}")
                    break
        
        if not dataset_found:
            print("DJNet dataset not found. Available options:")
            print("1. Run the DJNet_Colab notebook first to generate the dataset")
            print("2. Upload your dataset to Google Drive")
            print("3. Set use_custom_dataset = False to use synthetic data")
            use_custom_dataset = False
        else:
            # Load metadata
            metadata_path = os.path.join(djnet_dataset_path, 'metadata.csv')
            if os.path.exists(metadata_path):
                metadata_df = pd.read_csv(metadata_path)
                print(f"Loaded metadata for {len(metadata_df)} transitions")
                print(f"Transition types: {metadata_df['transition_type'].value_counts().to_dict()}")
                
                # Display first few entries
                print("\nFirst few dataset entries:")
                print(metadata_df[['transition_type', 'source_a_track', 'source_b_track']].head())
                
            else:
                print("Metadata file not found in dataset directory")
                use_custom_dataset = False
    
    except Exception as e:
        print(f"Error loading custom dataset: {e}")
        print("Falling back to synthetic data...")
        use_custom_dataset = False

if not use_custom_dataset:
    print("Using synthetic data for demonstration...")

In [None]:
def load_djnet_spectrograms(djnet_dataset_path, metadata_df, max_samples=50):
    """
    Load and process DJNet dataset audio files into mel-spectrograms
    """
    print("Processing DJNet audio files into mel-spectrograms...")
    
    # Initialize audio processor
    audio_processor = AudioProcessor(config)
    
    spectrograms = []
    valid_samples = 0
    
    from tqdm import tqdm
    for idx, row in tqdm(metadata_df.iterrows(), total=min(len(metadata_df), max_samples), desc="Processing transitions"):
        if valid_samples >= max_samples:
            break
            
        try:
            transition_dir = os.path.join(djnet_dataset_path, row['path'])
            
            # Load the three audio segments
            source_a_path = os.path.join(transition_dir, 'source_a.wav')
            target_path = os.path.join(transition_dir, 'target.wav')
            source_b_path = os.path.join(transition_dir, 'source_b.wav')
            
            # Check if all files exist
            if all(os.path.exists(p) for p in [source_a_path, target_path, source_b_path]):
                # Load audio files
                source_a, _ = librosa.load(source_a_path, sr=config.sample_rate)
                target, _ = librosa.load(target_path, sr=config.sample_rate)
                source_b, _ = librosa.load(source_b_path, sr=config.sample_rate)
                
                # Convert to mel-spectrograms
                mel_a = audio_processor.audio_to_mel_spectrogram(source_a)
                mel_target = audio_processor.audio_to_mel_spectrogram(target)
                mel_b = audio_processor.audio_to_mel_spectrogram(source_b)
                
                # Ensure consistent sequence length
                target_len = config.seq_len
                
                # Trim or pad to target length
                def pad_or_trim(mel_spec, target_length):
                    if mel_spec.shape[1] > target_length:
                        return mel_spec[:, :target_length]
                    elif mel_spec.shape[1] < target_length:
                        pad_width = target_length - mel_spec.shape[1]
                        return np.pad(mel_spec, ((0, 0), (0, pad_width)), mode='edge')
                    return mel_spec
                
                mel_a = pad_or_trim(mel_a, target_len)
                mel_target = pad_or_trim(mel_target, target_len)
                mel_b = pad_or_trim(mel_b, target_len)
                
                # Combine into one spectrogram (preceding + transition + following)
                combined_mel = np.concatenate([mel_a, mel_target, mel_b], axis=1)
                spectrograms.append(combined_mel.T)  # Transpose to (time, freq)
                
                valid_samples += 1
                
        except Exception as e:
            print(f"Error processing transition {idx}: {str(e)[:100]}...")
            continue
    
    print(f"Successfully processed {len(spectrograms)} transitions from DJNet dataset")
    return spectrograms

# Load DJNet spectrograms if custom dataset is available
if use_custom_dataset and 'metadata_df' in locals():
    djnet_spectrograms = load_djnet_spectrograms(djnet_dataset_path, metadata_df, max_samples=20)
    
    if len(djnet_spectrograms) > 0:
        spectrograms = djnet_spectrograms
        print(f"Using {len(spectrograms)} spectrograms from DJNet dataset")
        print(f"Each spectrogram shape: {spectrograms[0].shape}")
        
        # Visualize one of the DJNet spectrograms
        plt.figure(figsize=(15, 8))
        
        # Plot the full combined spectrogram
        plt.subplot(2, 1, 1)
        plt.imshow(spectrograms[0].T, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(label='Magnitude (dB)')
        plt.title('DJNet Transition: Full Sequence (Source A + Transition + Source B)')
        plt.xlabel('Time Frames')
        plt.ylabel('Mel Frequency Bins')
        
        # Plot the three segments separately
        seq_len = config.seq_len
        source_a_mel = spectrograms[0][:seq_len].T
        transition_mel = spectrograms[0][seq_len:seq_len*2].T
        source_b_mel = spectrograms[0][seq_len*2:].T
        
        plt.subplot(2, 3, 4)
        plt.imshow(source_a_mel, aspect='auto', origin='lower', cmap='viridis')
        plt.title('Source A (Preceding)')
        plt.xlabel('Time Frames')
        plt.ylabel('Mel Bins')
        
        plt.subplot(2, 3, 5)
        plt.imshow(transition_mel, aspect='auto', origin='lower', cmap='viridis')
        plt.title('Target Transition')
        plt.xlabel('Time Frames')
        
        plt.subplot(2, 3, 6)
        plt.imshow(source_b_mel, aspect='auto', origin='lower', cmap='viridis')
        plt.title('Source B (Following)')
        plt.xlabel('Time Frames')
        
        plt.tight_layout()
        plt.show()
        
        print("Successfully loaded DJNet dataset!")
    else:
        print("No valid spectrograms found in DJNet dataset, falling back to synthetic data")
        use_custom_dataset = False

## Test Forward Pass

Test the model's forward pass with teacher forcing.

In [None]:
# Prepare data for forward pass
batch_size = 2
sample_spectrograms = torch.tensor(spectrograms[:batch_size], dtype=torch.float32).to(device)

# Split into preceding, transition, and following segments
seq_len = config.seq_len
preceding = sample_spectrograms[:, :, :seq_len]  # First third
transition = sample_spectrograms[:, :, seq_len:2*seq_len]  # Middle third (target)
following = sample_spectrograms[:, :, 2*seq_len:]  # Last third

print(f"Input shapes:")
print(f"  Preceding: {preceding.shape}")
print(f"  Transition (target): {transition.shape}")
print(f"  Following: {following.shape}")

# Forward pass with teacher forcing
model.eval()
with torch.no_grad():
    output = model(preceding, following, transition)
    print(f"\nModel output shape: {output.shape}")
    print(f"Output statistics:")
    print(f"  Mean: {output.mean().item():.4f}")
    print(f"  Std: {output.std().item():.4f}")
    print(f"  Min: {output.min().item():.4f}")
    print(f"  Max: {output.max().item():.4f}")

## Test Autoregressive Generation

Test the model's ability to generate transitions autoregressively.

In [None]:
# Test autoregressive generation
print("Testing autoregressive generation...")

model.eval()
with torch.no_grad():
    # Use the first sample for generation
    test_preceding = preceding[:1]  # Take only first sample
    test_following = following[:1]
    
    generated_transition = model.generate(
        test_preceding, 
        test_following, 
        max_length=seq_len
    )
    
    print(f"Generated transition shape: {generated_transition.shape}")
    print(f"Generation statistics:")
    print(f"  Mean: {generated_transition.mean().item():.4f}")
    print(f"  Std: {generated_transition.std().item():.4f}")
    print(f"  Min: {generated_transition.min().item():.4f}")
    print(f"  Max: {generated_transition.max().item():.4f}")

# Visualize the generated transition
plt.figure(figsize=(15, 10))

# Plot preceding segment
plt.subplot(2, 2, 1)
plt.imshow(test_preceding[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Preceding Segment')
plt.ylabel('Mel Frequency Bins')
plt.colorbar()

# Plot generated transition
plt.subplot(2, 2, 2)
plt.imshow(generated_transition[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Generated Transition')
plt.colorbar()

# Plot following segment
plt.subplot(2, 2, 3)
plt.imshow(test_following[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Following Segment')
plt.xlabel('Time Frames')
plt.ylabel('Mel Frequency Bins')
plt.colorbar()

# Plot combined sequence
plt.subplot(2, 2, 4)
combined = torch.cat([
    test_preceding[0], 
    generated_transition[0], 
    test_following[0]
], dim=1)
plt.imshow(combined.cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Complete Sequence (Preceding + Generated + Following)')
plt.xlabel('Time Frames')
plt.colorbar()

plt.tight_layout()
plt.show()

## Training Example

Demonstrate how to train the model with synthetic data.

In [None]:
# Use the loaded dataset for training (whether custom DJNet or synthetic)
print("Preparing training data...")

if use_custom_dataset:
    print(f"Using {len(spectrograms)} samples from DJNet dataset for training")
    # Use the loaded DJNet spectrograms
    train_data = torch.tensor(spectrograms, dtype=torch.float32)
else:
    print("Generating additional synthetic training data...")
    # Generate more synthetic data for training
    train_spectrograms = create_synthetic_spectrograms(
        num_samples=50,
        mel_bins=config.mel_bins,
        seq_len=config.seq_len * 3
    )
    train_data = torch.tensor(train_spectrograms, dtype=torch.float32)

# Initialize trainer
trainer = Trainer(model, config, device)

print(f"Training data shape: {train_data.shape}")
if use_custom_dataset:
    print("Training on real DJ transitions from your custom dataset!")
else:
    print("Training on synthetic data for demonstration...")

print("Starting training demo...")

# Train for a few epochs as demonstration
num_epochs = 5 if use_custom_dataset else 3  # More epochs for real data
losses = []

for epoch in range(num_epochs):
    epoch_losses = []
    
    # Create batches
    for i in range(0, len(train_data), config.batch_size):
        batch = train_data[i:i+config.batch_size]
        if len(batch) < config.batch_size:
            continue
            
        # Split batch into segments
        batch_preceding = batch[:, :, :seq_len]
        batch_transition = batch[:, :, seq_len:2*seq_len]
        batch_following = batch[:, :, 2*seq_len:]
        
        # Train step
        loss = trainer.train_step(
            batch_preceding.to(device),
            batch_following.to(device),
            batch_transition.to(device)
        )
        epoch_losses.append(loss)
    
    avg_loss = np.mean(epoch_losses)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# Plot training loss
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs+1), losses, 'b-o', linewidth=2, markersize=8)
plt.title(f'Training Loss - {"DJNet Dataset" if use_custom_dataset else "Synthetic Dataset"}')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

# Add trend analysis
if len(losses) > 1:
    improvement = losses[0] - losses[-1]
    improvement_pct = (improvement / losses[0]) * 100
    plt.text(0.05, 0.95, f'Loss Improvement: {improvement_pct:.1f}%', 
             transform=plt.gca().transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Plot loss distribution
plt.subplot(1, 2, 2)
plt.hist(losses, bins=max(3, len(losses)//2), alpha=0.7, edgecolor='black')
plt.title('Loss Distribution Across Epochs')
plt.xlabel('Loss Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Training demo completed!")
if use_custom_dataset:
    print("The model has been trained on your real DJ transition data!")
    print("This should produce more realistic transitions compared to synthetic data.")
else:
    print("The model has been trained on synthetic data for demonstration purposes.")

## Model Evaluation

Evaluate the model's performance after training.

In [None]:
# Generate test data
test_spectrograms = create_synthetic_spectrograms(
    num_samples=10,
    mel_bins=config.mel_bins,
    seq_len=config.seq_len * 3
)
test_data = torch.tensor(test_spectrograms, dtype=torch.float32).to(device)

# Evaluate model
model.eval()
test_losses = []

with torch.no_grad():
    for i in range(len(test_data)):
        sample = test_data[i:i+1]
        
        # Split into segments
        test_preceding = sample[:, :, :seq_len]
        test_transition = sample[:, :, seq_len:2*seq_len]
        test_following = sample[:, :, 2*seq_len:]
        
        # Forward pass
        output = model(test_preceding, test_following, test_transition)
        
        # Calculate loss
        loss = nn.MSELoss()(output, test_transition)
        test_losses.append(loss.item())

avg_test_loss = np.mean(test_losses)
print(f"Average test loss: {avg_test_loss:.4f}")
print(f"Test loss std: {np.std(test_losses):.4f}")

# Compare original vs generated transition
with torch.no_grad():
    sample_idx = 0
    test_sample = test_data[sample_idx:sample_idx+1]
    
    original_preceding = test_sample[:, :, :seq_len]
    original_transition = test_sample[:, :, seq_len:2*seq_len]
    original_following = test_sample[:, :, 2*seq_len:]
    
    # Generate new transition
    generated_transition = model.generate(
        original_preceding,
        original_following,
        max_length=seq_len
    )
    
    # Calculate similarity metrics
    mse = nn.MSELoss()(generated_transition, original_transition).item()
    mae = nn.L1Loss()(generated_transition, original_transition).item()
    
    print(f"\nComparison with original transition:")
    print(f"  MSE: {mse:.4f}")
    print(f"  MAE: {mae:.4f}")

# Visualize comparison
plt.figure(figsize=(15, 8))

plt.subplot(2, 3, 1)
plt.imshow(original_preceding[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Original Preceding')
plt.ylabel('Mel Bins')

plt.subplot(2, 3, 2)
plt.imshow(original_transition[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Original Transition')

plt.subplot(2, 3, 3)
plt.imshow(original_following[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Original Following')

plt.subplot(2, 3, 4)
plt.imshow(original_preceding[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Same Preceding')
plt.xlabel('Time Frames')
plt.ylabel('Mel Bins')

plt.subplot(2, 3, 5)
plt.imshow(generated_transition[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Generated Transition')
plt.xlabel('Time Frames')

plt.subplot(2, 3, 6)
plt.imshow(original_following[0].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.title('Same Following')
plt.xlabel('Time Frames')

plt.tight_layout()
plt.show()

## Audio Processing Example

Demonstrate audio processing capabilities (optional - requires audio files).

In [None]:
# Initialize audio processor
audio_processor = AudioProcessor(config)

print("Audio Processor Configuration:")
print(f"  Sample rate: {config.sample_rate} Hz")
print(f"  Hop length: {config.hop_length}")
print(f"  N FFT: {config.n_fft}")
print(f"  Mel bins: {config.mel_bins}")

# Create a synthetic audio signal for demonstration
duration = 3.0  # 3 seconds
sample_rate = config.sample_rate
t = np.linspace(0, duration, int(sample_rate * duration))

# Create a simple synthetic audio signal (combination of sine waves)
frequencies = [440, 554, 659, 880]  # A4, C#5, E5, A5 (A major chord)
synthetic_audio = np.sum([
    np.sin(2 * np.pi * freq * t) * np.exp(-t * 0.5)  # Decaying sine waves
    for freq in frequencies
], axis=0)

# Add some envelope and normalize
envelope = np.exp(-t * 0.3)
synthetic_audio = synthetic_audio * envelope
synthetic_audio = synthetic_audio / np.max(np.abs(synthetic_audio))

print(f"\nSynthetic audio signal:")
print(f"  Duration: {duration} seconds")
print(f"  Sample rate: {sample_rate} Hz")
print(f"  Shape: {synthetic_audio.shape}")

# Convert to mel-spectrogram
mel_spec = audio_processor.audio_to_mel_spectrogram(synthetic_audio)
print(f"\nMel-spectrogram shape: {mel_spec.shape}")

# Visualize the synthetic audio and its mel-spectrogram
plt.figure(figsize=(15, 8))

# Plot waveform
plt.subplot(2, 1, 1)
plt.plot(t, synthetic_audio)
plt.title('Synthetic Audio Waveform')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.grid(True)

# Plot mel-spectrogram
plt.subplot(2, 1, 2)
plt.imshow(mel_spec, aspect='auto', origin='lower', cmap='viridis')
plt.colorbar(label='Magnitude (dB)')
plt.title('Mel-Spectrogram')
plt.xlabel('Time Frames')
plt.ylabel('Mel Frequency Bins')

plt.tight_layout()
plt.show()

print("\nAudio processing example completed!")
print("Note: To use real audio files, upload them to Colab and use:")
print("  audio, sr = librosa.load('path/to/audio.wav', sr=config.sample_rate)")
print("  mel_spec = audio_processor.audio_to_mel_spectrogram(audio)")

## Conclusion

This notebook demonstrated the key features of the Music Transition Transformer:

1. **Model Architecture**: Dual encoder transformer for processing preceding and following music segments
2. **Synthetic Data Generation**: Created test data for model validation
3. **Forward Pass**: Teacher forcing mode for training
4. **Autoregressive Generation**: Step-by-step transition generation
5. **Training Loop**: Demonstrated training with synthetic data
6. **Model Evaluation**: Performance assessment and visualization
7. **Audio Processing**: Conversion between audio and mel-spectrogram representations

### Next Steps

To use this model with real music:
1. Upload your audio files to Colab
2. Use the `AudioProcessor` to convert them to mel-spectrograms
3. Create proper datasets with real music segments
4. Train the model on your data
5. Generate smooth transitions between your music pieces

### Key Parameters to Experiment With

- `seq_len`: Length of input/output sequences
- `d_model`: Model dimension (affects capacity)
- `num_heads`: Number of attention heads
- `num_layers`: Depth of the transformer
- `learning_rate`: Training speed
- `mel_bins`: Frequency resolution

Happy experimenting with music transitions!

## Dataset Comparison and Results

Understanding the differences between training on custom DJNet data vs synthetic data.

In [None]:
print("=== DATASET COMPARISON ANALYSIS ===")

if use_custom_dataset:
    print("✅ You successfully used your custom DJNet dataset!")
    print("\n🎵 ADVANTAGES OF CUSTOM DJNET DATA:")
    print("• Real DJ transitions with authentic musical patterns")
    print("• Multiple transition types (linear fade, bass swap, etc.)")
    print("• Tempo and key-matched musical segments")
    print("• Professional-quality audio processing")
    print("• Realistic spectral characteristics")
    
    if 'metadata_df' in locals():
        print(f"\n📊 YOUR DATASET STATISTICS:")
        print(f"• Total transitions processed: {len(spectrograms)}")
        print(f"• Original dataset size: {len(metadata_df)}")
        
        if len(metadata_df) > 0:
            transition_types = metadata_df['transition_type'].value_counts()
            print(f"• Transition variety: {len(transition_types)} different types")
            most_common = transition_types.index[0]
            print(f"• Most common type: {most_common} ({transition_types[most_common]} samples)")
    
    print(f"\n🎯 TRAINING RESULTS:")
    print(f"• Training epochs: {len(losses)} (extended for real data)")
    print(f"• Final loss: {losses[-1]:.4f}")
    if len(losses) > 1:
        improvement = ((losses[0] - losses[-1]) / losses[0]) * 100
        print(f"• Loss improvement: {improvement:.1f}%")
    
    print(f"\n💡 EXPECTED OUTCOMES:")
    print("• Generated transitions should sound more natural")
    print("• Better preservation of musical coherence")
    print("• More realistic frequency transitions")
    print("• Improved temporal flow between segments")
    
else:
    print("ℹ️  You used synthetic data for this demonstration")
    print("\n🔬 SYNTHETIC DATA CHARACTERISTICS:")
    print("• Mathematical patterns without musical structure")
    print("• Consistent for testing and development")
    print("• Fast generation and processing")
    print("• Good for model architecture validation")
    
    print(f"\n📊 SYNTHETIC DATASET STATISTICS:")
    print(f"• Generated samples: {len(spectrograms)}")
    print(f"• Training epochs: {len(losses)}")
    print(f"• Final loss: {losses[-1]:.4f}")
    
    print(f"\n🚀 TO USE YOUR CUSTOM DATA:")
    print("1. Run the DJNet_Colab notebook to create your dataset")
    print("2. Set use_custom_dataset = True in the second cell")
    print("3. Adjust djnet_dataset_path to your dataset location")
    print("4. Re-run this notebook to train on real DJ data")

print(f"\n🔄 WORKFLOW COMPARISON:")
print("┌─ Synthetic Data Workflow ─┐    ┌─ Custom DJNet Workflow ─┐")
print("│ 1. Generate synthetic data │    │ 1. Run DJNet_Colab      │")
print("│ 2. Train transformer       │    │ 2. Load real transitions │")
print("│ 3. Test model              │    │ 3. Process audio files   │")
print("│ 4. Basic validation        │    │ 4. Train on real data    │")
print("└────────────────────────────┘    │ 5. Generate realistic DJ │")
print("                                   │    transitions           │")
print("                                   └──────────────────────────┘")

# Performance visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Dataset type indicator
axes[0, 0].text(0.5, 0.7, 'DATASET TYPE', ha='center', va='center', 
                fontsize=16, weight='bold', transform=axes[0, 0].transAxes)
dataset_type = 'DJNet Custom' if use_custom_dataset else 'Synthetic'
color = 'green' if use_custom_dataset else 'orange'
axes[0, 0].text(0.5, 0.3, dataset_type, ha='center', va='center', 
                fontsize=20, weight='bold', color=color, transform=axes[0, 0].transAxes)
axes[0, 0].set_xlim(0, 1)
axes[0, 0].set_ylim(0, 1)
axes[0, 0].axis('off')

# Training progress
axes[0, 1].plot(range(1, len(losses)+1), losses, 'o-', color=color, linewidth=2, markersize=8)
axes[0, 1].set_title('Training Progress', fontsize=14, weight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].grid(True, alpha=0.3)

# Dataset characteristics radar chart (simplified)
categories = ['Realism', 'Variety', 'Quality', 'Speed', 'Consistency']
if use_custom_dataset:
    values = [9, 8, 9, 6, 7]  # Custom data scores
else:
    values = [4, 5, 6, 10, 9]  # Synthetic data scores

angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False)
values_plot = values + [values[0]]  # Complete the circle
angles_plot = np.concatenate([angles, [angles[0]]])

axes[1, 0].plot(angles_plot, values_plot, 'o-', color=color, linewidth=2, markersize=8)
axes[1, 0].fill(angles_plot, values_plot, alpha=0.25, color=color)
axes[1, 0].set_xticks(angles)
axes[1, 0].set_xticklabels(categories)
axes[1, 0].set_ylim(0, 10)
axes[1, 0].set_title('Dataset Characteristics (1-10 scale)', fontsize=14, weight='bold')
axes[1, 0].grid(True)

# Sample spectrogram comparison
sample_spec = spectrograms[0]
im = axes[1, 1].imshow(sample_spec.T, aspect='auto', origin='lower', cmap='viridis')
axes[1, 1].set_title(f'Sample Spectrogram - {dataset_type}', fontsize=14, weight='bold')
axes[1, 1].set_xlabel('Time Frames')
axes[1, 1].set_ylabel('Mel Frequency Bins')
plt.colorbar(im, ax=axes[1, 1], label='Magnitude')

plt.tight_layout()
plt.show()

print(f"\n🎉 Analysis complete! Your model is trained and ready to generate DJ transitions.")