# K-Dubstop

### Assignment goals:
- Continuous, unconditioned generation: Similar to Option 1, but with continuous output (i.e., generate a waveform). Note that if using a model that generates spectrograms, you must also render the spectrogram as a waveform.
- Continuous, conditioned generation: Example tasks include:
  - prompt-based generation (e.g. text-to-spectrogram)
  - “inpainting” (replace part of an existing waveform/spectrogram) or “outpainting” (extend an existing waveform/spectrogram)
  - continuous control (e.g. generate music that follows a given volume or pitch curve)
  - synthesis of a symbolic input (i.e., midi-to-audio using a learned model)

## Retrieving the dataset

### Install Dependencies

In [None]:
# %pip install spotdl

# !spotdl --download-ffmpeg


### KPOP Dataset

Download the songs in a Spotify playlist as 320 kbps MP3 files with spotDL:

In [8]:
!spotdl --output "./Data/Kpop_midi/{title}" --format mp3 --bitrate 320k "https://open.spotify.com/playlist/37fcr6LxQZ8Zq7dds9f9DG"

[2K[32mProcessing query: https://open.spotify.com/playlist/37fcr6LxQZ8Zq7dds9f9DG[0m      
[2K[32mFound 34 songs in worKout (Playlist)[0m                                            
[2K[37mTotal[0m                    [90m0/34 complete     [0m [90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m  0%[0m [90m-:--:--[0m
[2K[1A[2K[37mTotal[0m                    [90m0/34 complete     [0m [90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m  0%[0m [90m-:--:--[0m━[0m[90m━[0m[90m━[0m[90m━[0m[35m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m [32m  0%[0m [90m-:--:--[0m
[37mK/DA - THE BADDEST[0m       [90mProcessing        [0m [90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[35m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m[90m━[0m [32m  0%[0m [90m-:--:--[0m
[2K[1A[2K[1A[2K[37mTotal[0m

### Dubstep Dataset

In [None]:
!spotdl --output "./Data/Dubstep_midi/{title}" --format mp3 --bitrate 320k "https://open.spotify.com/playlist/1k92RRlyDyXbI86dN5rGU9"

## Continuous, unconditioned generation (TEMP)

K-pop Music Generation - Option 3: Continuous, Unconditioned Generation

WHAT THIS ACCOMPLISHES:
- Learns the distribution of K-pop music from your downloaded audio files
- Generates new K-pop-style melspectrograms from scratch by sampling the VAE's latent space
- Converts generated spectrograms back to playable audio with Griffin-Lim
- Aims to create novel tracks that resemble the genre without copying any training song

APPROACH: VAE (Variational Autoencoder) for melspectrogram generation
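For reference, the objective minimized by `vae_loss` and the sampling step in `reparameterize` (both defined in the cells that follow) can be written as below. The notation is ours: $B$ is the batch size, $N = B \times 128 \times 1292$ the number of spectrogram values in a batch, $d = 256$ the latent size, and $\beta = 0.001$ the KL weight hard-coded in the loss.

$$
\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\left(\hat{x}_n - x_n\right)^2
\;-\; \frac{\beta}{2}\sum_{b=1}^{B}\sum_{j=1}^{d}\left(1 + \log\sigma_{bj}^2 - \mu_{bj}^2 - \sigma_{bj}^2\right)
$$

$$
z_b = \mu_b + \sigma_b \odot \varepsilon_b, \qquad \varepsilon_b \sim \mathcal{N}(0, I_d), \qquad \sigma_b = \exp\!\left(\tfrac{1}{2}\log\sigma_b^2\right)
$$

At generation time only the decoder is used, with $z \sim \mathcal{N}(0, I_d)$.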

### Installing Dependencies

In [41]:
# %pip install numpy
# %pip install librosa
# %pip install matplotlib
# %pip install torch
# %pip install soundfile
# pathlib is part of the Python 3 standard library, so it needs no install


import os
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
from pathlib import Path
import pickle

In [42]:
# ===== STEP 1: DATA PREPROCESSING =====

class AudioDataProcessor:
    def __init__(self, data_dir="./Data/Kpop_midi", 
                 sr=22050, n_mels=128, n_fft=2048, hop_length=512):
        self.data_dir = data_dir
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        
    def audio_to_melspec(self, audio_path):
        """Convert an audio file to a log-mel spectrogram of fixed size in [-1, 1].

        The middle 30 s of each track is used (shorter tracks are zero-padded),
        so every result has shape (n_mels, 1 + 30 * sr // hop_length).
        """
        try:
            # Load audio, resampled to self.sr
            y, sr = librosa.load(audio_path, sr=self.sr)
            
            # Take a fixed-length segment (30 seconds)
            segment_length = self.sr * 30
            if len(y) > segment_length:
                # Take middle section
                start = (len(y) - segment_length) // 2
                y = y[start:start + segment_length]
            else:
                # Zero-pad if too short
                y = np.pad(y, (0, segment_length - len(y)))
            
            # Generate mel spectrogram (power)
            mel_spec = librosa.feature.melspectrogram(
                y=y, sr=sr, n_mels=self.n_mels, 
                n_fft=self.n_fft, hop_length=self.hop_length
            )
            
            # Convert to dB relative to the loudest bin; top_db=80 clips the
            # result to [-80, 0] dB
            top_db = 80.0
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max, top_db=top_db)
            
            # Fixed-range normalization to [-1, 1]: this matches the inverse
            # mapping in AudioSynthesizer.melspec_to_audio ([-1, 1] -> [-80, 0] dB)
            # and avoids a division by zero for silent files
            mel_spec_norm = 2 * (mel_spec_db + top_db) / top_db - 1
            
            return mel_spec_norm
            
        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            return None
    
    def process_dataset(self, save_path="processed_spectrograms.pkl"):
        """Process all audio files in the dataset"""
        spectrograms = []
        audio_files = []
        
        # Find all audio files
        for ext in ['*.mp3', '*.wav', '*.flac']:
            audio_files.extend(Path(self.data_dir).rglob(ext))
        
        print(f"Found {len(audio_files)} audio files")
        
        for i, audio_file in enumerate(audio_files):
            if i % 10 == 0:
                print(f"Processing {i}/{len(audio_files)}")
                
            spec = self.audio_to_melspec(str(audio_file))
            if spec is not None:
                spectrograms.append(spec)
        
        spectrograms = np.array(spectrograms)
        print(f"Processed {len(spectrograms)} spectrograms")
        print(f"Shape: {spectrograms.shape}")
        
        # Save processed data
        with open(save_path, 'wb') as f:
            pickle.dump(spectrograms, f)
            
        return spectrograms
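Why the spectrograms come out as `(128, 1292)`: a 30-second segment at 22,050 Hz has 661,500 samples, and with librosa's default `center=True` framing a feature computed with `hop_length=512` has `1 + n_samples // hop_length` frames. The helper below is a hypothetical sketch of that arithmetic, not a librosa function.

```python
def n_frames(n_samples, hop_length=512, n_fft=2048, center=True):
    """Frame count of librosa's framed features (melspectrogram, rms, ...)."""
    if center:
        # n_fft // 2 samples of padding on each side cancel the window length
        return 1 + n_samples // hop_length
    return 1 + (n_samples - n_fft) // hop_length

sr, seconds = 22050, 30
segment = sr * seconds                   # 661,500 samples
print(n_frames(segment))                 # 1292 -> spectrogram shape (128, 1292)
print(n_frames(segment, center=False))   # 1288 without centering
```

This is also why `SpectrogramVAE` and `ControlConditionedVAE` default to a time dimension of 1292.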

In [43]:
# ===== STEP 2: DATASET CLASS =====

class SpectrogramDataset(Dataset):
    def __init__(self, spectrograms):
        self.spectrograms = torch.FloatTensor(spectrograms)
        
    def __len__(self):
        return len(self.spectrograms)
    
    def __getitem__(self, idx):
        return self.spectrograms[idx]
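With the 34 songs found during preprocessing and `batch_size=8` in Step 6, each epoch yields five batches, the last one holding only two spectrograms. The pure-Python helper below (hypothetical, not part of the notebook) reproduces how a PyTorch `DataLoader` splits a map-style dataset; `train_vae` then averages the per-batch losses, so that small final batch counts as much as a full one.

```python
def batch_sizes(n_items, batch_size, drop_last=False):
    """Sizes of the batches a DataLoader yields for a map-style dataset of n_items."""
    full, rest = divmod(n_items, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # leftover items form one smaller final batch
    return sizes

print(batch_sizes(34, 8))                  # [8, 8, 8, 8, 2]
print(batch_sizes(34, 8, drop_last=True))  # [8, 8, 8, 8]
```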

In [44]:
# ===== STEP 3: VAE MODEL ARCHITECTURE =====

class SpectrogramVAE(nn.Module):
    def __init__(self, input_shape=(128, 1292), latent_dim=256):
        super(SpectrogramVAE, self).__init__()
        self.input_shape = input_shape
        self.latent_dim = latent_dim
        
        # Calculate flattened size
        self.flat_size = input_shape[0] * input_shape[1]
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(self.flat_size, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
        )
        
        # Latent space
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, self.flat_size),
            nn.Tanh()  # Output between -1 and 1
        )
    
    def encode(self, x):
        x = x.view(-1, self.flat_size)
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        x = self.decoder(z)
        return x.view(-1, *self.input_shape)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar
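Every layer in `SpectrogramVAE` is fully connected, so the parameter count is dominated by the first encoder layer and the last decoder layer, each of which connects to all 128 × 1292 = 165,376 spectrogram bins. The sketch below (helper names are ours) counts the weights and biases of the `nn.Linear` layers defined above; in PyTorch the same figure should come out of `sum(p.numel() for p in model.parameters())`.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

def spectrogram_vae_params(input_shape=(128, 1292), latent_dim=256):
    """Parameter count of SpectrogramVAE as defined above."""
    flat = input_shape[0] * input_shape[1]
    encoder = linear_params(flat, 2048) + linear_params(2048, 1024) + linear_params(1024, 512)
    latent = 2 * linear_params(512, latent_dim)  # fc_mu and fc_logvar
    decoder = (linear_params(latent_dim, 512) + linear_params(512, 1024)
               + linear_params(1024, 2048) + linear_params(2048, flat))
    return encoder + latent + decoder

total = spectrogram_vae_params()
print(f"{total:,} parameters, ~{total * 4 / 1e9:.2f} GB as float32")
```

That is roughly 683 million parameters (about 2.7 GB in float32, before Adam's two extra state tensors per parameter), trained here on 34 spectrograms.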

In [45]:
# ===== STEP 4: TRAINING FUNCTIONS =====

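def vae_loss(x_recon, x, mu, logvar, beta=0.001):
    """VAE loss: reconstruction (MSE) plus beta-weighted KL divergence.

    The reductions differ: the MSE is averaged over every element, while the
    KL term is summed over latent dimensions and over the batch, so its
    effective weight grows with batch size.
    """
    # Reconstruction loss (mean over batch, mel bins and frames)
    recon_loss = nn.MSELoss()(x_recon, x)
    
    # KL divergence between N(mu, exp(logvar)) and N(0, I), summed
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + beta * kl_loss  # beta = 0.001 by default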

def train_vae(model, dataloader, epochs=100, lr=0.001):
    """Train the VAE model"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, data in enumerate(dataloader):
            data = data.to(device)
            optimizer.zero_grad()
            
            x_recon, mu, logvar = model(data)
            loss = vae_loss(x_recon, data, mu, logvar)
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')
        
        # Save model checkpoint
        if (epoch + 1) % 20 == 0:
            torch.save(model.state_dict(), f'vae_checkpoint_epoch_{epoch+1}.pth')
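A small numpy check (ours, not from the notebook) that the closed-form KL term in `vae_loss` is the usual KL divergence between a diagonal Gaussian and a standard normal, and that summing it over the batch makes it scale with batch size while the averaged MSE does not.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) summed over every entry,
    the same expression vae_loss evaluates with torch."""
    return -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))

def kl_1d(mu, sigma):
    """Textbook single-Gaussian form: -log(sigma) + (sigma**2 + mu**2 - 1) / 2."""
    return -np.log(sigma) + (sigma**2 + mu**2 - 1) / 2

# Both forms agree for a single latent dimension
print(kl_to_standard_normal(np.array([0.3]), np.array([np.log(1.5**2)])), kl_1d(0.3, 1.5))

# Because the sum also runs over the batch, duplicating a row doubles the KL term
mu, logvar = np.array([[0.5, -0.2]]), np.array([[0.1, -0.3]])
ratio = kl_to_standard_normal(np.vstack([mu, mu]), np.vstack([logvar, logvar])) / kl_to_standard_normal(mu, logvar)
print(ratio)
```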

In [46]:
# ===== STEP 5: GENERATION AND AUDIO SYNTHESIS =====

class AudioSynthesizer:
    def __init__(self, sr=22050, n_fft=2048, hop_length=512, n_iter=32):
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_iter = n_iter
    
    def melspec_to_audio(self, mel_spec_norm):
        """Convert normalized melspectrogram back to audio"""
        # Denormalize: [-1, 1] -> [0, 1] -> [-80, 0] dB (power_to_db's default top_db range)
        mel_spec_db = (mel_spec_norm + 1) / 2 * 80 - 80
        
        # Convert from dB to power
        mel_spec = librosa.db_to_power(mel_spec_db)
        
        # Use Griffin-Lim to reconstruct audio
        audio = librosa.feature.inverse.mel_to_audio(
            mel_spec, sr=self.sr, n_fft=self.n_fft, 
            hop_length=self.hop_length, n_iter=self.n_iter
        )
        
        return audio
    
    def generate_audio(self, model, num_samples=5, output_dir="generated_kpop"):
        """Generate new K-pop audio samples"""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)  # a model loaded from disk may still be on the CPU
        model.eval()
        
        os.makedirs(output_dir, exist_ok=True)
        
        with torch.no_grad():
            for i in range(num_samples):
                # Sample from latent space
                z = torch.randn(1, model.latent_dim).to(device)
                
                # Generate spectrogram
                generated_spec = model.decode(z)
                generated_spec = generated_spec.cpu().numpy()[0]
                
                # Convert to audio
                audio = self.melspec_to_audio(generated_spec)
                
                # Save audio
                output_path = os.path.join(output_dir, f"generated_kpop_{i+1}.wav")
                sf.write(output_path, audio, self.sr)
                
                # Save spectrogram visualization, mapped back to dB so the colorbar label is accurate
                spec_db = (generated_spec + 1) / 2 * 80 - 80
                plt.figure(figsize=(12, 6))
                librosa.display.specshow(spec_db, sr=self.sr, 
                                       hop_length=self.hop_length, x_axis='time', y_axis='mel')
                plt.colorbar(format='%+2.0f dB')
                plt.title(f'Generated K-pop Melspectrogram {i+1}')
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, f"spectrogram_{i+1}.png"))
                plt.close()
                
                print(f"Generated: {output_path}")
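`melspec_to_audio` undoes the preprocessing in two steps: it maps the model's [-1, 1] output onto [-80, 0] dB (80 dB is the default `top_db` of `librosa.power_to_db`, and `ref=np.max` puts the loudest bin at 0 dB), then converts dB to power before Griffin-Lim. A numpy-only sketch of those two mappings, with hypothetical helper names:

```python
import numpy as np

TOP_DB = 80.0  # default top_db of librosa.power_to_db

def norm_to_db(mel_norm, top_db=TOP_DB):
    """[-1, 1] -> [-top_db, 0] dB, the mapping used in melspec_to_audio."""
    return (np.asarray(mel_norm, dtype=float) + 1) / 2 * top_db - top_db

def db_to_power(mel_db):
    """Same formula as librosa.db_to_power with ref=1.0: 10 ** (dB / 10)."""
    return 10.0 ** (np.asarray(mel_db, dtype=float) / 10.0)

levels = np.array([-1.0, 0.0, 1.0])
print(norm_to_db(levels))               # -80, -40 and 0 dB
print(db_to_power(norm_to_db(levels)))  # 1e-8, 1e-4 and 1.0 in power
```

Griffin-Lim then has to estimate the phase that the mel spectrogram discards, which is one reason reconstructed audio can sound smeared or metallic.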

In [47]:
# ===== STEP 6: MAIN EXECUTION =====
print("=== K-pop Unconditioned Generation Pipeline ===")

# Step 1: Process audio data
print("\n1. Processing audio data...")
processor = AudioDataProcessor()

# Check if processed data exists
if os.path.exists("processed_spectrograms.pkl"):
    print("Loading existing processed data...")
    with open("processed_spectrograms.pkl", 'rb') as f:
        spectrograms = pickle.load(f)
else:
    print("Processing audio files...")
    spectrograms = processor.process_dataset()

print(f"Dataset shape: {spectrograms.shape}")

# Step 2: Create dataset and dataloader
print("\n2. Creating dataset...")
dataset = SpectrogramDataset(spectrograms)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Step 3: Initialize and train model
print("\n3. Training VAE model...")
input_shape = spectrograms.shape[1:]  # (n_mels, time_steps)
model = SpectrogramVAE(input_shape=input_shape, latent_dim=256)

# Train model (or load existing)
model_path = "trained_vae_model.pth"
if os.path.exists(model_path):
    print("Loading existing model...")
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
else:
    print("Training new model...")
    train_vae(model, dataloader, epochs=100)
    torch.save(model.state_dict(), model_path)

# Step 4: Generate new music
print("\n4. Generating new K-pop tracks...")
synthesizer = AudioSynthesizer()
synthesizer.generate_audio(model, num_samples=5)

print("\n=== Generation Complete! ===")
print("Check the 'generated_kpop' folder for new tracks!")

=== K-pop Unconditioned Generation Pipeline ===

1. Processing audio data...
Processing audio files...
Found 34 audio files
Processing 0/34
Processing 10/34
Processing 20/34
Processing 30/34
Processed 34 spectrograms
Shape: (34, 128, 1292)
Dataset shape: (34, 128, 1292)

2. Creating dataset...

3. Training VAE model...
Training new model...
Epoch 1/100, Loss: 16.5510
Epoch 2/100, Loss: 1.7665
Epoch 3/100, Loss: 0.4224
Epoch 4/100, Loss: 0.0548
Epoch 5/100, Loss: 0.0507
Epoch 6/100, Loss: 0.0514
Epoch 7/100, Loss: 0.0493
Epoch 8/100, Loss: 0.0479
Epoch 9/100, Loss: 0.0477
Epoch 10/100, Loss: 0.0493
Epoch 11/100, Loss: 0.0508
Epoch 12/100, Loss: 0.0481
Epoch 13/100, Loss: 0.0475
Epoch 14/100, Loss: 0.0478
Epoch 15/100, Loss: 0.0501
Epoch 16/100, Loss: 0.0474
Epoch 17/100, Loss: 0.0503
Epoch 18/100, Loss: 0.0477
Epoch 19/100, Loss: 0.0479
Epoch 20/100, Loss: 0.0513
Epoch 21/100, Loss: 0.0470
Epoch 22/100, Loss: 0.0512
Epoch 23/100, Loss: 0.0508
Epoch 24/100, Loss: 0.0483
Epoch 25/100, Los

## Continuous, conditioned generation (TEMP)

K-pop Music Generation - Option 4: Continuous, Conditioned Generation (Continuous Control)

WHAT THIS ACCOMPLISHES:
- Generates K-pop music that follows specific control curves (volume, pitch, brightness, and an onset-based energy curve)
- Allows you to specify how you want the music to evolve over time
- Creates controllable music generation where you can influence musical characteristics
- Enables creation of music with specific emotional arcs or dynamic patterns

APPROACH: Conditional VAE with control signal conditioning for spectrogram generation
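The conditional model's data flow can be summarized as follows (our notation): $c \in \mathbb{R}^{4 \times T}$ stacks the volume, pitch, brightness and energy curves, $e(c) \in \mathbb{R}^{256}$ is the output of `control_processor`, and $[\,\cdot\,;\,\cdot\,]$ denotes concatenation.

$$
(\mu, \log\sigma^2) = \mathrm{Enc}\big([\,\mathrm{vec}(x);\ e(c)\,]\big), \qquad
z = \mu + \sigma \odot \varepsilon,\ \ \varepsilon \sim \mathcal{N}(0, I), \qquad
\hat{x} = \mathrm{Dec}\big([\,z;\ e(c)\,]\big)
$$

Training uses a reconstruction-plus-KL objective (the optional control term in `control_vae_loss` stays zero because `train_control_vae` does not pass predicted controls). At generation time $z \sim \mathcal{N}(0, I)$ and $c$ is a user-designed set of curves.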

### Installing Dependencies

In [None]:
import os
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
from pathlib import Path
import pickle
import scipy.signal
from sklearn.preprocessing import StandardScaler

In [15]:
# ===== STEP 1: CONTROL SIGNAL EXTRACTION =====

class ControlSignalExtractor:
    def __init__(self, sr=22050, hop_length=512):
        self.sr = sr
        self.hop_length = hop_length
    
    def extract_volume_curve(self, audio):
        """Extract RMS energy (volume) curve from audio"""
        rms = librosa.feature.rms(y=audio, hop_length=self.hop_length)[0]
        return rms
    
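    def extract_pitch_curve(self, audio):
        """Approximate pitch curve: per frame, the frequency of the strongest
        piptrack peak (0 if its magnitude is <= 0.1). Not a true f0 tracker."""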
        pitches, magnitudes = librosa.piptrack(y=audio, sr=self.sr, hop_length=self.hop_length)
        
        # Get the pitch with highest magnitude at each time step
        pitch_curve = []
        for t in range(pitches.shape[1]):
            index = magnitudes[:, t].argmax()
            pitch = pitches[index, t] if magnitudes[index, t] > 0.1 else 0
            pitch_curve.append(pitch)
        
        return np.array(pitch_curve)
    
    def extract_spectral_centroid(self, audio):
        """Extract spectral centroid (brightness/timbre indicator)"""
        centroid = librosa.feature.spectral_centroid(y=audio, sr=self.sr, hop_length=self.hop_length)[0]
        return centroid
    
    def extract_tempo_curve(self, audio):
        """Smoothed onset-strength envelope, used as the 'energy' control (not a tempo in BPM)"""
        # Use onset strength for tempo estimation
        onset_envelope = librosa.onset.onset_strength(y=audio, sr=self.sr, hop_length=self.hop_length)
        
        # Smooth the onset envelope to get tempo-like curve
        tempo_curve = scipy.signal.savgol_filter(onset_envelope, window_length=51, polyorder=3)
        return tempo_curve
    
    def extract_all_controls(self, audio):
        """Extract all control signals from audio"""
        controls = {
            'volume': self.extract_volume_curve(audio),
            'pitch': self.extract_pitch_curve(audio),
            'brightness': self.extract_spectral_centroid(audio),
            'energy': self.extract_tempo_curve(audio)
        }
        
        # Ensure all curves have the same length
        min_length = min(len(curve) for curve in controls.values())
        for key in controls:
            controls[key] = controls[key][:min_length]
        
        return controls
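The per-frame loop in `extract_pitch_curve` can be written without Python iteration. The function below is our sketch, not part of librosa: it picks the highest-magnitude bin of `librosa.piptrack`'s `(pitches, magnitudes)` output in every frame at once and applies the same 0.1 magnitude threshold. In a full mix, the strongest spectral peak need not be the sung fundamental.

```python
import numpy as np

def strongest_peak_curve(pitches, magnitudes, threshold=0.1):
    """Vectorised equivalent of the loop in extract_pitch_curve."""
    idx = magnitudes.argmax(axis=0)              # strongest bin per frame
    frames = np.arange(magnitudes.shape[1])
    best_mag = magnitudes[idx, frames]
    best_pitch = pitches[idx, frames]
    return np.where(best_mag > threshold, best_pitch, 0.0)

# Tiny example: 3 frequency bins x 4 frames
pitches = np.array([[100., 200., 300., 400.],
                    [110., 210., 310., 410.],
                    [120., 220., 320., 420.]])
magnitudes = np.array([[0.5, 0.0, 0.05, 0.2],
                       [0.1, 0.3, 0.02, 0.9],
                       [0.2, 0.1, 0.08, 0.1]])
print(strongest_peak_curve(pitches, magnitudes))  # 100, 210, 0 (too weak), 410 Hz
```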

In [16]:
# ===== STEP 2: DATA PREPROCESSING FOR CONTROL CONDITIONING =====

class ControlConditionedProcessor:
    def __init__(self, data_dir="./Data/Kpop_midi", 
                 sr=22050, n_mels=128, n_fft=2048, hop_length=512):
        self.data_dir = data_dir
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.extractor = ControlSignalExtractor(sr, hop_length)
        self.scalers = {}
    
    def process_audio_with_controls(self, audio_path):
        """Process audio file and extract both spectrogram and control signals"""
        try:
            # Load audio
            y, sr = librosa.load(audio_path, sr=self.sr)
            
            # Take a fixed-length segment
            segment_length = self.sr * 30  # 30 seconds
            if len(y) > segment_length:
                start = (len(y) - segment_length) // 2
                y = y[start:start + segment_length]
            else:
                y = np.pad(y, (0, max(0, segment_length - len(y))))
            
            # Generate melspectrogram
            mel_spec = librosa.feature.melspectrogram(
                y=y, sr=sr, n_mels=self.n_mels, 
                n_fft=self.n_fft, hop_length=self.hop_length
            )
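            # dB relative to the loudest bin; top_db=80 clips the result to [-80, 0] dB
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max, top_db=80.0)
            # Fixed-range normalization to [-1, 1], matching the synthesizer's inverse
            # mapping and avoiding a division by zero for silent files
            mel_spec_norm = 2 * (mel_spec_db + 80.0) / 80.0 - 1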
            
            # Extract control signals
            controls = self.extractor.extract_all_controls(y)
            
            # Resize control signals to match spectrogram time dimension
            target_time_steps = mel_spec_norm.shape[1]
            for key in controls:
                controls[key] = np.interp(
                    np.linspace(0, 1, target_time_steps),
                    np.linspace(0, 1, len(controls[key])),
                    controls[key]
                )
            
            return mel_spec_norm, controls
            
        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            return None, None
    
    def process_dataset_with_controls(self, save_path="control_conditioned_data.pkl"):
        """Process all audio files and extract spectrograms + control signals"""
        spectrograms = []
        all_controls = []
        audio_files = []
        
        # Find all audio files
        for ext in ['*.mp3', '*.wav', '*.flac']:
            audio_files.extend(Path(self.data_dir).rglob(ext))
        
        print(f"Found {len(audio_files)} audio files")
        
        for i, audio_file in enumerate(audio_files):
            if i % 10 == 0:
                print(f"Processing {i}/{len(audio_files)}")
                
            spec, controls = self.process_audio_with_controls(str(audio_file))
            if spec is not None and controls is not None:
                spectrograms.append(spec)
                all_controls.append(controls)
        
        spectrograms = np.array(spectrograms)
        print(f"Processed {len(spectrograms)} spectrograms with controls")
        print(f"Spectrogram shape: {spectrograms.shape}")
        
        # Normalize control signals across the dataset
        self._fit_control_scalers(all_controls)
        normalized_controls = self._normalize_controls(all_controls)
        
        # Save processed data
        data = {
            'spectrograms': spectrograms,
            'controls': normalized_controls,
            'scalers': self.scalers
        }
        
        with open(save_path, 'wb') as f:
            pickle.dump(data, f)
        
        return spectrograms, normalized_controls
    
    def _fit_control_scalers(self, all_controls):
        """Fit scalers for each control signal type"""
        control_types = list(all_controls[0].keys())
        
        for control_type in control_types:
            # Concatenate all values for this control type
            all_values = np.concatenate([controls[control_type] for controls in all_controls])
            
            # Fit scaler
            scaler = StandardScaler()
            scaler.fit(all_values.reshape(-1, 1))
            self.scalers[control_type] = scaler
    
    def _normalize_controls(self, all_controls):
        """Normalize control signals using fitted scalers"""
        normalized_controls = []
        
        for controls in all_controls:
            normalized = {}
            for control_type, values in controls.items():
                normalized_values = self.scalers[control_type].transform(values.reshape(-1, 1)).flatten()
                normalized[control_type] = normalized_values
            normalized_controls.append(normalized)
        
        return normalized_controls
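Two operations in this step are easy to check in isolation: `np.interp` stretches every control curve to the spectrogram's frame count, and `StandardScaler.transform` computes `(x - mean_) / scale_` per control type, where `scale_` is the population standard deviation over all frames of all songs. The helpers below are hypothetical and numpy-only.

```python
import numpy as np

def resample_curve(curve, target_len):
    """Linear resampling onto target_len evenly spaced points (first and last values kept)."""
    curve = np.asarray(curve, dtype=float)
    return np.interp(np.linspace(0, 1, target_len), np.linspace(0, 1, len(curve)), curve)

def standardize(values, all_values):
    """What a StandardScaler fit on all_values does to values: (x - mean) / std, ddof=0."""
    all_values = np.asarray(all_values, dtype=float)
    return (np.asarray(values, dtype=float) - all_values.mean()) / all_values.std()

print(resample_curve([0.0, 10.0], 5))                 # 0, 2.5, 5, 7.5, 10
print(standardize([2.0, 4.0, 6.0], [2.0, 4.0, 6.0]))  # about -1.22, 0, 1.22
```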

In [17]:
# ===== STEP 3: CONTROL CONDITIONED DATASET =====

class ControlConditionedDataset(Dataset):
    def __init__(self, spectrograms, controls, control_types=('volume', 'pitch', 'brightness', 'energy')):
        self.spectrograms = torch.FloatTensor(spectrograms)
        self.control_types = control_types
        
        # Convert controls to tensors
        self.controls = []
        for control_dict in controls:
            # Stack control signals into a single tensor [n_controls, time_steps]
            control_tensor = torch.stack([
                torch.FloatTensor(control_dict[control_type]) 
                for control_type in self.control_types
            ])
            self.controls.append(control_tensor)
        
        self.controls = torch.stack(self.controls)
        
    def __len__(self):
        return len(self.spectrograms)
    
    def __getitem__(self, idx):
        return self.spectrograms[idx], self.controls[idx]
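`ControlConditionedDataset` stores each song's controls channel-first, as `[n_controls, time_steps]`, so that a batch has the `(N, C, L)` layout `nn.Conv1d` expects. A numpy sketch of the same stacking (names are ours), using dummy all-zero curves:

```python
import numpy as np

CONTROL_TYPES = ("volume", "pitch", "brightness", "energy")

def stack_controls(control_dict, control_types=CONTROL_TYPES):
    """One song's curves -> float32 array of shape (n_controls, time_steps)."""
    return np.stack([np.asarray(control_dict[k], dtype=np.float32) for k in control_types])

song = {k: np.zeros(1292) for k in CONTROL_TYPES}  # dummy curves
batch = np.stack([stack_controls(song) for _ in range(8)])
print(batch.shape)  # (8, 4, 1292): batch, channels (controls), frames
```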


In [18]:
# ===== STEP 4: CONTROL CONDITIONED VAE MODEL =====

class ControlConditionedVAE(nn.Module):
    def __init__(self, spec_shape=(128, 1292), control_dim=4, latent_dim=256, control_weight=1.0):
        super(ControlConditionedVAE, self).__init__()
        self.spec_shape = spec_shape
        self.control_dim = control_dim
        self.latent_dim = latent_dim
        self.control_weight = control_weight  # stored for reference; not used by this class
        
        # Spectrogram dimensions
        self.spec_flat_size = spec_shape[0] * spec_shape[1]
        
        # Control signal processing
        self.control_processor = nn.Sequential(
            nn.Conv1d(control_dim, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(256),  # Fixed size output
            nn.Flatten(),
            nn.Linear(64 * 256, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )
        
        # Encoder (spectrogram + control conditioning)
        self.encoder = nn.Sequential(
            nn.Linear(self.spec_flat_size + 256, 2048),  # +256 for control embedding
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
        )
        
        # Latent space
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)
        
        # Decoder (latent + control conditioning)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + 256, 512),  # +256 for control embedding
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, self.spec_flat_size),
            nn.Tanh()
        )
    
    def encode(self, x, controls):
        # Process controls
        control_embedding = self.control_processor(controls)
        
        # Flatten spectrogram and concatenate with control embedding
        x_flat = x.view(-1, self.spec_flat_size)
        x_with_controls = torch.cat([x_flat, control_embedding], dim=1)
        
        # Encode
        h = self.encoder(x_with_controls)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z, controls):
        # Process controls
        control_embedding = self.control_processor(controls)
        
        # Concatenate latent with control embedding
        z_with_controls = torch.cat([z, control_embedding], dim=1)
        
        # Decode
        x = self.decoder(z_with_controls)
        return x.view(-1, *self.spec_shape)
    
    def forward(self, x, controls):
        mu, logvar = self.encode(x, controls)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z, controls)
        return x_recon, mu, logvar
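The control branch of `ControlConditionedVAE` accepts control curves of any length: each `Conv1d` with `kernel_size=5, padding=2` preserves the length, and `AdaptiveAvgPool1d(256)` then fixes it at 256, so `nn.Flatten` always hands 64 × 256 = 16,384 values to the first `nn.Linear`. The helper below applies the output-length formula from the PyTorch `Conv1d` documentation; the function names are ours.

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1, dilation=1):
    """Output length of nn.Conv1d (formula from the PyTorch Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def control_processor_shapes(time_steps, pooled=256, channels=64):
    """Length after the two convs, and the flattened size fed to nn.Linear."""
    after_conv1 = conv1d_out_len(time_steps, kernel_size=5, padding=2)   # Conv1d(4, 32)
    after_conv2 = conv1d_out_len(after_conv1, kernel_size=5, padding=2)  # Conv1d(32, 64)
    # AdaptiveAvgPool1d(pooled) maps any length to exactly `pooled` bins
    return after_conv2, channels * pooled

print(control_processor_shapes(1292))  # (1292, 16384)
print(control_processor_shapes(500))   # (500, 16384)
```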


In [19]:
# ===== STEP 5: TRAINING FUNCTIONS =====

def control_vae_loss(x_recon, x, mu, logvar, controls_pred=None, controls_true=None, control_weight=0.1):
    """VAE loss with optional control signal reconstruction loss"""
    # Reconstruction loss
    recon_loss = nn.MSELoss()(x_recon, x)
    
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Control loss (if predicting controls)
    control_loss = 0
    if controls_pred is not None and controls_true is not None:
        control_loss = nn.MSELoss()(controls_pred, controls_true)
    
    return recon_loss + 0.001 * kl_loss + control_weight * control_loss

def train_control_vae(model, dataloader, epochs=100, lr=0.001):
    """Train the control-conditioned VAE"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.7)
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (spectrograms, controls) in enumerate(dataloader):
            spectrograms = spectrograms.to(device)
            controls = controls.to(device)
            
            optimizer.zero_grad()
            
            x_recon, mu, logvar = model(spectrograms, controls)
            loss = control_vae_loss(x_recon, spectrograms, mu, logvar)
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        scheduler.step()
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')
        
        if (epoch + 1) % 20 == 0:
            torch.save(model.state_dict(), f'control_vae_checkpoint_epoch_{epoch+1}.pth')
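`StepLR(step_size=30, gamma=0.7)` multiplies the learning rate by 0.7 every 30 epochs. Because `scheduler.step()` runs at the end of each epoch, the rate printed for epoch `e` is `lr * 0.7 ** (e // 30)`. A closed-form sketch (hypothetical helper) of the schedule `train_control_vae` logs:

```python
def step_lr(base_lr, epochs_done, step_size=30, gamma=0.7):
    """Learning rate after `epochs_done` calls to StepLR.step()."""
    return base_lr * gamma ** (epochs_done // step_size)

for epoch in (1, 29, 30, 60, 90, 100):
    print(f"epoch {epoch:3d}: lr = {step_lr(0.001, epoch):.6f}")
```

Over 100 epochs the rate therefore drops three times, ending at 0.000343.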

In [20]:
# ===== STEP 6: CONTROL CURVE GENERATION =====

class ControlCurveGenerator:
    def __init__(self, time_steps=1292):
        self.time_steps = time_steps
    
    def generate_smooth_curve(self, start_val=0, end_val=1, curve_type='linear'):
        """Generate smooth control curves"""
        t = np.linspace(0, 1, self.time_steps)
        
        if curve_type == 'linear':
            curve = start_val + (end_val - start_val) * t
        elif curve_type == 'exponential':
            curve = start_val + (end_val - start_val) * (np.exp(3*t) - 1) / (np.exp(3) - 1)
        elif curve_type == 'logarithmic':
            curve = start_val + (end_val - start_val) * np.log(1 + 9*t) / np.log(10)
        elif curve_type == 'sine':
            curve = start_val + (end_val - start_val) * (np.sin(np.pi * t - np.pi/2) + 1) / 2
        elif curve_type == 'bell':
            curve = start_val + (end_val - start_val) * np.exp(-((t - 0.5) * 6)**2)
        else:
            curve = np.full(self.time_steps, start_val)
        
        return curve
    
    def create_control_set(self, volume_curve='linear', pitch_curve='sine', 
                          brightness_curve='bell', energy_curve='exponential'):
        """Create a set of control curves"""
        controls = {
            'volume': self.generate_smooth_curve(0.2, 0.8, volume_curve),
            'pitch': self.generate_smooth_curve(-1, 1, pitch_curve),
            'brightness': self.generate_smooth_curve(-0.5, 1.5, brightness_curve),
            'energy': self.generate_smooth_curve(0.1, 1.2, energy_curve)
        }
        return controls
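The shape names in `generate_smooth_curve` are easy to misread: 'sine' is a half-cycle ease-in/ease-out that rises monotonically from `start_val` to `end_val` (it does not oscillate), and 'bell' starts and ends near `start_val` and peaks at `end_val` halfway through. The sketch below (our `unit_shape` helper) isolates the unit-range shapes so their endpoints can be checked.

```python
import numpy as np

def unit_shape(t, curve_type):
    """Shape f(t) used as start_val + (end_val - start_val) * f(t)."""
    if curve_type == "linear":
        return t
    if curve_type == "exponential":
        return (np.exp(3 * t) - 1) / (np.exp(3) - 1)
    if curve_type == "logarithmic":
        return np.log(1 + 9 * t) / np.log(10)
    if curve_type == "sine":
        return (np.sin(np.pi * t - np.pi / 2) + 1) / 2
    if curve_type == "bell":
        return np.exp(-((t - 0.5) * 6) ** 2)
    return np.zeros_like(t)  # unknown names: constant start_val, as in the original

t = np.array([0.0, 0.5, 1.0])
for name in ("linear", "exponential", "logarithmic", "sine", "bell"):
    print(name, np.round(unit_shape(t, name), 3))
```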

In [21]:
# ===== STEP 7: CONTROLLED AUDIO SYNTHESIS =====

class ControlledAudioSynthesizer:
    def __init__(self, sr=22050, n_fft=2048, hop_length=512, n_iter=32):
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_iter = n_iter
    
    def melspec_to_audio(self, mel_spec_norm):
        """Convert normalized melspectrogram back to audio"""
        mel_spec_db = (mel_spec_norm + 1) / 2 * 80 - 80
        mel_spec = librosa.db_to_power(mel_spec_db)
        audio = librosa.feature.inverse.mel_to_audio(
            mel_spec, sr=self.sr, n_fft=self.n_fft, 
            hop_length=self.hop_length, n_iter=self.n_iter
        )
        return audio
    
    def generate_controlled_music(self, model, control_generator, scalers, 
                                num_samples=5, output_dir="controlled_kpop"):
        """Generate music with specific control curves"""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)  # a model loaded from disk may still be on the CPU
        model.eval()
        
        os.makedirs(output_dir, exist_ok=True)
        
        # Define different control scenarios
        scenarios = [
            {'name': 'crescendo', 'volume': 'exponential', 'energy': 'exponential'},
            {'name': 'dramatic_arc', 'volume': 'bell', 'pitch': 'sine', 'energy': 'bell'},
            {'name': 'fade_in_out', 'volume': 'bell', 'brightness': 'bell'},
            {'name': 'rising_energy', 'energy': 'exponential', 'brightness': 'exponential'},
            {'name': 'pitch_sweep', 'pitch': 'linear', 'volume': 'linear'}
        ]
        
        with torch.no_grad():
            for i, scenario in enumerate(scenarios[:num_samples]):
                print(f"Generating scenario: {scenario['name']}")
                
                # Generate control curves
                controls = control_generator.create_control_set(
                    volume_curve=scenario.get('volume', 'linear'),
                    pitch_curve=scenario.get('pitch', 'linear'),
                    brightness_curve=scenario.get('brightness', 'linear'),
                    energy_curve=scenario.get('energy', 'linear')
                )
                
                # Normalize controls using dataset scalers
                for control_type in controls:
                    if control_type in scalers:
                        controls[control_type] = scalers[control_type].transform(
                            controls[control_type].reshape(-1, 1)
                        ).flatten()
                
                # Stack into a (1, 4, time_steps) tensor: volume, pitch, brightness, energy
                control_tensor = torch.stack([
                    torch.FloatTensor(controls['volume']),
                    torch.FloatTensor(controls['pitch']),
                    torch.FloatTensor(controls['brightness']),
                    torch.FloatTensor(controls['energy'])
                ]).unsqueeze(0).to(device)
                
                # Sample from latent space
                z = torch.randn(1, model.latent_dim).to(device)
                
                # Generate spectrogram
                generated_spec = model.decode(z, control_tensor)
                generated_spec = generated_spec.cpu().numpy()[0]
                
                # Convert to audio
                audio = self.melspec_to_audio(generated_spec)
                
                # Save audio
                output_path = os.path.join(output_dir, f"{scenario['name']}_controlled.wav")
                sf.write(output_path, audio, self.sr)
                
                # Save control curves and spectrogram visualization
                self._save_analysis(generated_spec, controls, scenario['name'], output_dir)
                
                print(f"Generated: {output_path}")
    
    def _save_analysis(self, spectrogram, controls, name, output_dir):
        """Save spectrogram and control curve visualizations"""
        fig, axes = plt.subplots(3, 1, figsize=(15, 12))
        
        # Spectrogram
        librosa.display.specshow(spectrogram, sr=self.sr, hop_length=self.hop_length,
                               x_axis='time', y_axis='mel', ax=axes[0])
        axes[0].set_title(f'Generated Spectrogram: {name}')
        
        # Control curves
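        # Time axis in seconds, derived from the number of STFT frames instead of assuming 30 s
        duration = spectrogram.shape[1] * self.hop_length / self.sr
        time_axis = np.linspace(0, duration, len(controls['volume']))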
        
        for i, (control_type, values) in enumerate(controls.items()):
            color = ['red', 'blue', 'green', 'orange'][i]
            axes[1].plot(time_axis, values, label=control_type, color=color, linewidth=2)
        
        axes[1].set_title('Control Curves')
        axes[1].set_xlabel('Time (seconds)')
        axes[1].set_ylabel('Control Value')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
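        # Overlay the mean spectrogram level on the volume control, both min-max scaled to [0, 1]
        # (the normalized spectrogram can be negative, so dividing by its max is not safe)
        def _minmax(x):
            return (x - np.min(x)) / (np.ptp(x) + 1e-8)
        spec_energy = np.mean(spectrogram, axis=0)  # Mean normalized level per frame
        axes[2].plot(time_axis, _minmax(spec_energy), label='Spectrogram Energy', linewidth=2)
        axes[2].plot(time_axis, _minmax(controls['volume']), label='Volume Control', alpha=0.7, linewidth=2)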
        axes[2].set_title('Energy vs Volume Control')
        axes[2].set_xlabel('Time (seconds)')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{name}_analysis.png"), dpi=150)
        plt.close()
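`melspec_to_audio` assumes the training spectrograms were stored as dB values in [-80, 0] and mapped linearly to [-1, 1]. The forward mapping is not part of this excerpt, so the `normalize_db` helper below is an assumption. It is written as the exact inverse of the first line of `melspec_to_audio`. This NumPy-only sketch of the round trip uses `10 ** (db / 10)` in place of `librosa.db_to_power` at its default `ref=1.0`:

```python
import numpy as np

TOP_DB = 80.0  # dynamic range assumed by melspec_to_audio

def normalize_db(mel_db):
    """Hypothetical forward mapping: clip to [-80, 0] dB, then scale to [-1, 1]."""
    mel_db = np.clip(mel_db, -TOP_DB, 0.0)
    return (mel_db + TOP_DB) / TOP_DB * 2 - 1

def denormalize_db(mel_norm):
    """Same expression as in melspec_to_audio: [-1, 1] back to [-80, 0] dB."""
    return (mel_norm + 1) / 2 * TOP_DB - TOP_DB

def db_to_power(mel_db):
    # Equivalent to librosa.db_to_power(mel_db) with the default ref=1.0
    return 10.0 ** (mel_db / 10.0)

mel_db = np.array([[-80.0, -40.0, 0.0], [-60.0, -20.0, -10.0]])
norm = normalize_db(mel_db)
restored = denormalize_db(norm)
power = db_to_power(restored)
print(norm)
print(power)
```

Decoder outputs outside [-1, 1] would map to levels above 0 dB or below -80 dB. If the decoder does not already end in a bounded activation such as `tanh`, clipping its output to [-1, 1] keeps the Griffin-Lim input in the range seen during training.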

In [None]:
# ===== STEP 8: MAIN EXECUTION =====
print("=== K-pop Continuous Control Generation Pipeline ===")

# Step 1: Process audio data with control signals
print("\n1. Processing audio data and extracting control signals...")
processor = ControlConditionedProcessor()

if os.path.exists("control_conditioned_data.pkl"):
    print("Loading existing processed data...")
    with open("control_conditioned_data.pkl", 'rb') as f:
        data = pickle.load(f)
    spectrograms = data['spectrograms']
    controls = data['controls']
    scalers = data['scalers']
else:
    print("Processing audio files with control extraction...")
    spectrograms, controls = processor.process_dataset_with_controls()
    scalers = processor.scalers

print(f"Dataset: {len(spectrograms)} spectrograms with control signals")

# Step 2: Create dataset and dataloader
print("\n2. Creating control-conditioned dataset...")
dataset = ControlConditionedDataset(spectrograms, controls)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Step 3: Initialize and train control-conditioned VAE
print("\n3. Training control-conditioned VAE...")
spec_shape = spectrograms.shape[1:]
model = ControlConditionedVAE(spec_shape=spec_shape, control_dim=4, latent_dim=256)

model_path = "control_conditioned_vae.pth"
if os.path.exists(model_path):
    print("Loading existing model...")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
else:
    print("Training new model...")
    train_control_vae(model, dataloader, epochs=80)
    torch.save(model.state_dict(), model_path)

# Step 4: Generate controlled music
print("\n4. Generating music with continuous control...")
control_generator = ControlCurveGenerator(time_steps=spec_shape[1])
synthesizer = ControlledAudioSynthesizer()

synthesizer.generate_controlled_music(model, control_generator, scalers, num_samples=5)

print("\n=== Controlled Generation Complete! ===")
print("Check 'controlled_kpop' folder for:")
print("- Audio files with different control scenarios")
print("- Analysis plots showing control curves and their effects")
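`ControlCurveGenerator` is built with `time_steps=spec_shape[1]`, which gives one control value per spectrogram frame, assuming `spec_shape` is `(n_mels, frames)`. With the synthesizer's defaults (`sr=22050`, `hop_length=512`), one frame is about 23 ms, so the clip length follows directly from the frame count. The sketch below converts in both directions. It assumes the dataset spectrograms came from a centered STFT (librosa's default `center=True`, with an even `n_fft`), for which the frame count is `1 + n_samples // hop_length`. Both helper names are hypothetical:

```python
def seconds_to_frames(seconds, sr=22050, hop_length=512):
    """Frame count of a centered STFT (center=True, even n_fft) over `seconds` of audio."""
    n_samples = int(round(seconds * sr))
    return 1 + n_samples // hop_length

def frames_to_seconds(n_frames, sr=22050, hop_length=512):
    """Time spanned by n_frames hops; same formula as librosa.frames_to_time without n_fft."""
    return n_frames * hop_length / sr

frames = seconds_to_frames(30.0)
print(frames, frames_to_seconds(frames))
```

A 30-second clip therefore needs on the order of 1,300 control values per curve. Plotting against `frames_to_seconds` keeps the time axis correct for any clip length.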