In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import musdb
import museval
from tqdm import tqdm
import librosa
import warnings
warnings.filterwarnings('ignore')

# STFT and training hyper-parameters, chosen to favour speed
SAMPLE_RATE = 44100  # Audio sample rate in Hz
N_FFT = 2048  # FFT size in samples; reduced from 4096 for faster computation
HOP_LENGTH = 512  # Hop between STFT frames in samples; reduced from 1024
N_BINS = N_FFT // 2 + 1  # Number of frequency bins per STFT frame
CHUNK_DURATION = 3  # Training chunk length in seconds; reduced from 6 for more iterations
BATCH_SIZE = 16  # Chunks per training batch
EPOCHS = 20  # Number of training epochs; adjust based on training time
LEARNING_RATE = 1e-3  # Initial learning rate for Adam

class SimplifiedOpenUnmix(nn.Module):
    """Simplified Open-Unmix style mask estimator built around a BiLSTM.

    The network maps a magnitude spectrogram of shape
    ``(batch, time, freq)`` to a soft mask of the same shape whose values
    lie in ``[0, 1]``.

    Args:
        n_fft: FFT size of the input spectrograms; the model expects
            ``n_fft // 2 + 1`` frequency bins.
        hidden_size: Width of the encoded representation fed to the LSTM.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, n_fft=2048, hidden_size=256, n_layers=2):
        super().__init__()
        self.n_bins = n_fft // 2 + 1
        self.hidden_size = hidden_size

        # Input normalization (frequency bins act as BatchNorm1d channels)
        self.input_norm = nn.BatchNorm1d(self.n_bins)

        # Frequency compression (reduce dimensionality)
        self.freq_encoder = nn.Linear(self.n_bins, hidden_size)

        # Each direction gets half the width so the concatenated output is
        # (for even hidden_size) as wide as the encoder output.
        lstm_hidden = hidden_size // 2
        self.lstm = nn.LSTM(
            hidden_size,
            lstm_hidden,
            n_layers,
            bidirectional=True,
            batch_first=True,
            # Dropout is only applied between stacked layers; PyTorch warns
            # when it is requested for a single-layer LSTM.
            dropout=0.3 if n_layers > 1 else 0.0,
        )

        # Size the decoder from the actual BiLSTM output width so that an
        # odd hidden_size does not produce a shape mismatch in forward().
        self.freq_decoder = nn.Linear(2 * lstm_hidden, self.n_bins)

        # Output activation (sigmoid for mask)
        self.output_activation = nn.Sigmoid()

    def forward(self, x):
        """Return a ``[0, 1]`` mask for ``x`` of shape ``(batch, time, freq)``."""
        # BatchNorm1d normalizes over dim 1, so move frequency there and back
        x = x.transpose(1, 2)  # (batch, freq, time)
        x = self.input_norm(x)
        x = x.transpose(1, 2)  # (batch, time, freq)

        # Frequency encoding
        x = F.relu(self.freq_encoder(x))

        # LSTM processing (final hidden/cell states are not needed)
        x, _ = self.lstm(x)

        # Decode back to the frequency dimension and squash to [0, 1]
        return self.output_activation(self.freq_decoder(x))

class MUSDBDataset(Dataset):
    """Random-chunk dataset over MUSDB18-HQ for single-target separation.

    Every item is a ``(mixture, target)`` pair of mono ``float32``
    waveforms, each exactly ``chunk_duration * sample_rate`` samples long.
    Chunks are drawn at random, so the index given to ``__getitem__`` is
    ignored.

    Args:
        musdb_root: Root directory of the MUSDB18-HQ (wav) dataset.
        subset: Subset name forwarded to ``musdb.DB`` (e.g. ``'train'``).
        target: Name of the stem used as separation target.
        chunk_duration: Chunk length in seconds.
        sample_rate: Sample rate used to convert the duration to samples.
    """

    def __init__(self, musdb_root, subset='train', target='vocals',
                 chunk_duration=CHUNK_DURATION, sample_rate=SAMPLE_RATE):
        self.db = musdb.DB(root=musdb_root, subsets=subset, is_wav=True)
        self.target = target
        self.chunk_duration = chunk_duration
        self.sample_rate = sample_rate
        self.chunk_samples = int(chunk_duration * sample_rate)

        # Cache every track length once so sampling does not have to
        # measure the track again on each access.
        self.track_lengths = [len(track.audio) for track in self.db]

    def __len__(self):
        # Nominal epoch size: 20 random chunks per track
        return len(self.db) * 20

    def __getitem__(self, idx):
        # Draw from torch's RNG rather than np.random: DataLoader seeds
        # torch per worker, while NumPy's seed may be duplicated across
        # workers, which would make them return identical "random" chunks.
        track_idx = int(torch.randint(len(self.db), (1,)).item())
        track = self.db[track_idx]

        # Random chunk extraction
        track_length = self.track_lengths[track_idx]
        if track_length > self.chunk_samples:
            # +1 so the chunk ending exactly at the track end can be drawn
            last_start = track_length - self.chunk_samples
            start = int(torch.randint(last_start + 1, (1,)).item())
            end = start + self.chunk_samples
        else:
            start = 0
            end = track_length

        # Slice mixture and target, then transpose to (channels, samples)
        mixture = track.audio[start:end].T
        target_audio = track.targets[self.target].audio[start:end].T

        # Convert to mono for simplicity
        mixture = mixture.mean(axis=0)
        target_audio = target_audio.mean(axis=0)

        # Zero-pad short tracks so every item has the same length
        pad_len = self.chunk_samples - len(mixture)
        if pad_len > 0:
            mixture = np.pad(mixture, (0, pad_len))
            target_audio = np.pad(target_audio, (0, pad_len))

        return mixture.astype(np.float32), target_audio.astype(np.float32)

def compute_stft(audio, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Return the STFT magnitude and phase of ``audio``.

    Both returned arrays are the transpose of what ``librosa.stft``
    produces, i.e. laid out as (time, freq).
    """
    spectrum = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    # Transpose so that time is the leading axis
    return np.abs(spectrum).T, np.angle(spectrum).T

def apply_mask_and_istft(magnitude, phase, mask, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Mask a magnitude spectrogram and resynthesize audio with ``phase``.

    ``magnitude``, ``phase`` and ``mask`` are multiplied element-wise, so
    they must share one layout; the result is transposed before the
    inverse STFT. ``n_fft`` is part of the signature but is not forwarded
    to ``librosa.istft``.
    """
    # Recombine the masked magnitude with the phase into a complex STFT
    complex_spec = (magnitude * mask).T * np.exp(1j * phase.T)
    return librosa.istft(complex_spec, hop_length=hop_length)

def _median_ignoring_nan(values):
    """Return the median of the non-NaN entries, or ``-inf`` if none remain."""
    flat = np.asarray(values, dtype=float).ravel()
    flat = flat[~np.isnan(flat)]
    return np.median(flat) if flat.size else -np.inf


def evaluate_separation(estimated, reference, sample_rate=SAMPLE_RATE):
    """Compute median SDR, SIR and SAR with museval.

    Args:
        estimated: Estimated signal; 1-D input gets a leading axis added.
        reference: Reference signal; 1-D input gets a leading axis added.
        sample_rate: Used as both window and hop length (in samples) of the
            framewise evaluation.

    Returns:
        Dict with keys ``'SDR'``, ``'SIR'`` and ``'SAR'``. NaN frame scores
        are ignored when taking the median. Every value is ``-inf`` when
        the evaluation itself fails, and a single value is ``-inf`` when
        that metric has no non-NaN frame.
    """
    # Ensure 2D arrays for museval
    if estimated.ndim == 1:
        estimated = estimated[np.newaxis, :]
    if reference.ndim == 1:
        reference = reference[np.newaxis, :]

    try:
        sdr, _isr, sir, sar = museval.evaluate(reference, estimated,
                                               win=sample_rate, hop=sample_rate)
    except Exception:
        # Scoring is best-effort: one failing chunk must not abort the
        # whole validation run. Catching Exception (not a bare except)
        # still lets KeyboardInterrupt/SystemExit propagate.
        return {'SDR': -np.inf, 'SIR': -np.inf, 'SAR': -np.inf}

    return {
        'SDR': _median_ignoring_nan(sdr),
        'SIR': _median_ignoring_nan(sir),
        'SAR': _median_ignoring_nan(sar),
    }

def train_epoch(model, dataloader, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network mapping magnitude spectrograms to masks.
        dataloader: Yields ``(mixture, target)`` batches of waveforms.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device on which the forward/backward pass runs.

    Returns:
        Mean MSE loss over the processed batches, or ``nan`` when the
        loader yields no batch.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for mixture, target in tqdm(dataloader, desc="Training"):
        # The STFT runs in librosa on the CPU, so the waveforms are kept on
        # the host; only the spectrograms are moved to `device`.
        mixture_np = mixture.cpu().numpy()
        target_np = target.cpu().numpy()

        # Only magnitudes are needed for the loss; phases are discarded
        mix_mags = np.stack([compute_stft(wave)[0] for wave in mixture_np])
        tgt_mags = np.stack([compute_stft(wave)[0] for wave in target_np])

        mix_mags = torch.as_tensor(mix_mags, dtype=torch.float32, device=device)
        tgt_mags = torch.as_tensor(tgt_mags, dtype=torch.float32, device=device)

        # Forward pass on the magnitudes offset by a small epsilon
        masks = model(mix_mags + 1e-8)

        # Apply masks to the un-offset mixture magnitudes
        estimated_mags = mix_mags * masks

        # MSE loss in magnitude domain
        loss = F.mse_loss(estimated_mags, tgt_mags)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # Guard against an empty loader instead of dividing by zero
    return total_loss / n_batches if n_batches else float('nan')

def validate(model, dataloader, device):
    """Separate every validation chunk and average the museval metrics.

    Args:
        model: Network mapping magnitude spectrograms to masks.
        dataloader: Yields ``(mixture, target)`` batches of waveforms.
        device: Device on which the model is evaluated.

    Returns:
        Dict with keys ``'SDR'``, ``'SIR'`` and ``'SAR'`` holding the mean
        of all finite per-chunk scores as plain floats, or ``-inf`` for a
        metric that has no finite score.
    """
    model.eval()
    metrics = {'SDR': [], 'SIR': [], 'SAR': []}

    with torch.no_grad():
        for mixture, target in tqdm(dataloader, desc="Validating"):
            # STFT and scoring run on the CPU; only the spectrogram fed to
            # the model is moved to `device`.
            mixture_np = mixture.cpu().numpy()
            target_np = target.cpu().numpy()

            for mix_wave, reference in zip(mixture_np, target_np):
                # Compute spectrograms
                mix_mag, phase = compute_stft(mix_wave)

                # Get mask prediction
                mix_mag_tensor = torch.as_tensor(
                    mix_mag, dtype=torch.float32, device=device).unsqueeze(0)
                mask = model(mix_mag_tensor + 1e-8).squeeze(0).cpu().numpy()

                # Apply mask and reconstruct
                estimated = apply_mask_and_istft(mix_mag, phase, mask)

                # Keep finite scores only: a single NaN or inf would
                # otherwise turn the mean into NaN/inf and break the
                # best-model comparison done by the caller.
                scores = evaluate_separation(estimated, reference)
                for key, values in metrics.items():
                    if np.isfinite(scores[key]):
                        values.append(scores[key])

    # Average metrics (plain floats, so they serialize without numpy types)
    return {
        key: float(np.mean(values)) if values else -np.inf
        for key, values in metrics.items()
    }

def main():
    """Train the vocals model and keep the checkpoint with the best validation SDR."""
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize model
    model = SimplifiedOpenUnmix().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # The scheduler is stepped with SDR, which is higher-is-better, so it
    # must run in 'max' mode; the default 'min' mode would treat a rising
    # SDR as a lack of improvement and decay the learning rate.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=3, factor=0.5)

    # Setup data (update path to your MUSDB18-HQ location)
    musdb_root = "/path/to/musdb18hq"  # UPDATE THIS PATH

    train_dataset = MUSDBDataset(musdb_root, subset='train', target='vocals')
    val_dataset = MUSDBDataset(musdb_root, subset='test', target='vocals')

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=4,
                            shuffle=False, num_workers=2)

    # Training loop
    best_sdr = -np.inf

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch+1}/{EPOCHS}")

        # Train
        train_loss = train_epoch(model, train_loader, optimizer, device)
        print(f"Training loss: {train_loss:.4f}")

        # Validate every 2 epochs to save time
        if (epoch + 1) % 2 != 0:
            continue

        metrics = validate(model, val_loader, device)
        print(f"Validation metrics - SDR: {metrics['SDR']:.2f} dB, "
              f"SIR: {metrics['SIR']:.2f} dB, SAR: {metrics['SAR']:.2f} dB")

        # Learning rate scheduling
        scheduler.step(metrics['SDR'])

        # Save best model
        if metrics['SDR'] > best_sdr:
            best_sdr = metrics['SDR']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain floats keep numpy scalar types out of the pickle,
                # so the file stays loadable with torch.load's
                # weights-only unpickler.
                'metrics': {name: float(value) for name, value in metrics.items()},
            }, 'best_model_vocals.pth')
            print(f"Saved best model with SDR: {best_sdr:.2f} dB")

    print(f"\nTraining complete! Best SDR: {best_sdr:.2f} dB")

# Run the training loop only when this file is executed as a script
if __name__ == "__main__":
    main()

In [None]:
import torch
import numpy as np
import librosa
import soundfile as sf
from scipy.signal import wiener
import warnings
warnings.filterwarnings('ignore')

# Import the model class from training script
from train_simple_openunmix import SimplifiedOpenUnmix, compute_stft, N_FFT, HOP_LENGTH

def load_audio(path, sr=44100, mono=True):
    """Read an audio file via ``librosa.load`` and return only the waveform.

    ``sr`` and ``mono`` are forwarded unchanged; the sample rate reported
    by librosa is dropped.
    """
    waveform = librosa.load(path, sr=sr, mono=mono)[0]
    return waveform

def separate_audio(audio, model, device, n_fft=N_FFT, hop_length=HOP_LENGTH,
                  chunk_size=44100*10):  # 10 second chunks
    """Separate ``audio`` with a trained mask model, chunk by chunk.

    Args:
        audio: 1-D mixture waveform.
        model: Network mapping magnitude spectrograms to masks.
        device: Device on which the model is evaluated.
        n_fft: FFT size of the STFT.
        hop_length: Hop length of the STFT.
        chunk_size: Chunk length in samples; long audio is processed in
            consecutive, non-overlapping chunks of this size.

    Returns:
        The separated waveform, with exactly as many samples as ``audio``.
    """
    model.eval()

    # np.concatenate cannot handle an empty chunk list
    if len(audio) == 0:
        return np.zeros(0, dtype=np.float32)

    # Process in chunks for long audio
    separated_chunks = []

    for start in range(0, len(audio), chunk_size):
        chunk = audio[start:start + chunk_size]

        # Compute STFT
        magnitude, phase = compute_stft(chunk, n_fft, hop_length)

        # Prepare input and predict the mask without tracking gradients
        mag_tensor = torch.FloatTensor(magnitude).unsqueeze(0).to(device)
        with torch.no_grad():
            mask = model(mag_tensor + 1e-8).squeeze(0).cpu().numpy()

        # Apply mask and recombine with the mixture phase
        stft_separated = (magnitude * mask).T * np.exp(1j * phase.T)

        # Pin the output to the chunk length: without `length` the inverse
        # STFT returns a multiple of the hop, so every chunk would lose a
        # few samples and the output would drift against the mixture.
        separated_chunk = librosa.istft(stft_separated, hop_length=hop_length,
                                        length=len(chunk))
        separated_chunks.append(separated_chunk)

    # Concatenate all chunks
    return np.concatenate(separated_chunks)

def post_process_wiener(separated, mixture, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Refine ``separated`` with a Wiener-style mask applied to ``mixture``.

    The residual ``|mixture| - |separated|`` (clipped at zero) is used as
    the noise magnitude estimate, and the resulting power-ratio mask is
    applied to the complex mixture STFT.

    Args:
        separated: 1-D separated waveform.
        mixture: 1-D mixture waveform that ``separated`` was derived from.
        n_fft: FFT size of the STFT.
        hop_length: Hop length of the STFT.

    Returns:
        The filtered waveform, with exactly as many samples as ``mixture``.
    """
    n_samples = len(mixture)

    # Bring the estimate to the mixture length first; otherwise the two
    # STFTs can differ in frame count and the subtraction below would fail.
    if len(separated) < n_samples:
        separated = np.pad(separated, (0, n_samples - len(separated)))
    else:
        separated = separated[:n_samples]

    stft_mix = librosa.stft(mixture, n_fft=n_fft, hop_length=hop_length)
    stft_sep = librosa.stft(separated, n_fft=n_fft, hop_length=hop_length)

    # Estimate noise as mixture - separated (magnitudes, clipped at zero)
    sep_mag = np.abs(stft_sep)
    noise_estimate = np.maximum(np.abs(stft_mix) - sep_mag, 0)

    # Wiener mask: target power over total power (epsilon avoids 0/0)
    sep_power = sep_mag ** 2
    wiener_mask = sep_power / (sep_power + noise_estimate ** 2 + 1e-8)
    filtered_stft = stft_mix * wiener_mask

    # Convert back to audio at the mixture's exact length
    return librosa.istft(filtered_stft, hop_length=hop_length, length=n_samples)

def separate_file(input_path, output_path, model_path, target='vocals',
                  use_wiener=True, device=None):
    """Separate one audio file with a trained checkpoint and save the result.

    Loads the checkpoint at ``model_path`` into a default
    ``SimplifiedOpenUnmix``, separates the mono mixture read from
    ``input_path``, optionally applies the Wiener post-filter, writes the
    result to ``output_path`` at 44100 Hz and returns it. ``target`` is
    only used in the progress message. When ``device`` is ``None``, CUDA is
    used if available, otherwise the CPU.
    """
    if device is None:
        backend = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(backend)

    # Restore the trained network
    print(f"Loading model from {model_path}")
    net = SimplifiedOpenUnmix().to(device)
    saved_state = torch.load(model_path, map_location=device)
    net.load_state_dict(saved_state['model_state_dict'])
    net.eval()

    # Read the mixture as mono
    print(f"Loading audio from {input_path}")
    mixture = load_audio(input_path, mono=True)

    # Run the separation
    print(f"Separating {target}...")
    estimate = separate_audio(mixture, net, device)

    # Optional post-filter
    if use_wiener:
        print("Applying Wiener filter post-processing...")
        estimate = post_process_wiener(estimate, mixture)

    # Write the result to disk
    print(f"Saving to {output_path}")
    sf.write(output_path, estimate, 44100)

    print("Separation complete!")

    return estimate

def batch_separate(input_dir, output_dir, model_path, target='vocals'):
    """Separate every audio file found directly inside ``input_dir``.

    Args:
        input_dir: Directory scanned (non-recursively) for audio files with
            a ``.wav``, ``.mp3``, ``.flac`` or ``.m4a`` extension.
        output_dir: Directory receiving ``<stem>_<target>.wav`` files; it
            is created, including missing parents, if necessary.
        model_path: Checkpoint holding a ``'model_state_dict'`` entry.
        target: Label appended to the output file names.

    A failure on one file is reported and does not stop the batch.
    """
    from pathlib import Path

    input_path = Path(input_dir)
    output_path = Path(output_dir)
    # parents=True so a nested output directory does not raise
    output_path.mkdir(parents=True, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model once
    model = SimplifiedOpenUnmix().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Collect audio files in a deterministic order, skipping directories
    audio_extensions = {'.wav', '.mp3', '.flac', '.m4a'}
    audio_files = sorted(
        entry for entry in input_path.iterdir()
        if entry.is_file() and entry.suffix.lower() in audio_extensions
    )

    for audio_file in audio_files:
        print(f"\nProcessing {audio_file.name}")

        try:
            # Load audio
            audio = load_audio(str(audio_file), mono=True)

            # Separate
            separated = separate_audio(audio, model, device)

            # Post-process
            separated = post_process_wiener(separated, audio)

            # Save
            output_file = output_path / f"{audio_file.stem}_{target}.wav"
            sf.write(output_file, separated, 44100)

            print(f"Saved to {output_file}")

        except Exception as e:
            # Deliberately broad: one unreadable file must not abort the
            # remaining files of the batch.
            print(f"Error processing {audio_file.name}: {e}")

# Example usage: the paths below are placeholders and must be replaced with
# real files before running this as a script.
if __name__ == "__main__":
    # Single file separation with Wiener post-processing enabled
    separate_file(
        input_path="path/to/input.wav",
        output_path="path/to/output_vocals.wav",
        model_path="best_model_vocals.pth",
        target='vocals',
        use_wiener=True
    )

    # Batch processing (uncomment to separate every audio file in a folder)
    # batch_separate(
    #     input_dir="path/to/input_folder",
    #     output_dir="path/to/output_folder",
    #     model_path="best_model_vocals.pth",
    #     target='vocals'
    # )

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import time
from train_simple_openunmix import SimplifiedOpenUnmix, MUSDBDataset, train_epoch, validate

def _save_stem_checkpoint(path, epoch, model, metrics):
    """Write a checkpoint holding the epoch, model weights and float metrics."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        # Plain floats keep numpy scalar types out of the pickle
        'metrics': {name: float(value) for name, value in metrics.items()},
    }, path)


def train_all_stems(musdb_root, max_hours=4, device='cuda'):
    """Train one model per stem, splitting the time budget evenly.

    Args:
        musdb_root: Root directory of the MUSDB18-HQ dataset.
        max_hours: Total training budget in hours, shared by the 4 stems.
        device: Device name or ``torch.device`` to train on.

    Returns:
        Dict mapping each stem name to the metrics of its final validation
        run. A checkpoint ``model_<stem>.pth`` is written for every stem.
    """
    # Normalize once: comparing a torch.device against the string 'cuda'
    # is always False, which silently selected the CPU settings on GPU.
    device = torch.device(device)
    on_cuda = device.type == 'cuda'

    stems = ['vocals', 'drums', 'bass', 'other']
    time_per_stem = (max_hours * 3600) / len(stems)  # Seconds per stem

    results = {}

    for stem in stems:
        print(f"\n{'='*50}")
        print(f"Training model for: {stem}")
        print(f"Time budget: {time_per_stem/3600:.1f} hours")
        print(f"{'='*50}\n")

        # Initialize model
        model = SimplifiedOpenUnmix().to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        # Stepped with SDR (higher is better), hence mode='max'
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', patience=2, factor=0.5)

        # Setup data
        train_dataset = MUSDBDataset(musdb_root, subset='train', target=stem)
        val_dataset = MUSDBDataset(musdb_root, subset='test', target=stem)

        # Smaller batches when training on CPU
        batch_size = 16 if on_cuda else 8

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=on_cuda
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=4,
            shuffle=False,
            num_workers=2
        )

        # Training loop with time constraint
        checkpoint_path = f'model_{stem}.pth'
        checkpoint_saved = False
        start_time = time.time()
        best_sdr = -np.inf
        epoch = 0

        while (time.time() - start_time) < time_per_stem:
            epoch += 1
            print(f"\nEpoch {epoch} for {stem}")

            # Train
            train_loss = train_epoch(model, train_loader, optimizer, device)
            print(f"Training loss: {train_loss:.4f}")

            # Quick validation every 3 epochs
            if epoch % 3 == 0:
                metrics = validate(model, val_loader, device)
                print(f"Validation - SDR: {metrics['SDR']:.2f} dB")

                scheduler.step(metrics['SDR'])

                if metrics['SDR'] > best_sdr:
                    best_sdr = metrics['SDR']
                    _save_stem_checkpoint(checkpoint_path, epoch, model, metrics)
                    checkpoint_saved = True

            # Check time
            elapsed = time.time() - start_time
            remaining = time_per_stem - elapsed
            print(f"Time elapsed: {elapsed/60:.1f} min, remaining: {remaining/60:.1f} min")

            if remaining < 120:  # Less than 2 minutes left
                print(f"Time limit approaching, stopping training for {stem}")
                break

        # Final validation
        print(f"\nFinal validation for {stem}")
        final_metrics = validate(model, val_loader, device)
        results[stem] = final_metrics

        # Make sure a checkpoint exists even if the budget ran out before
        # the first periodic validation, and keep the final weights when
        # they beat the best periodic result.
        if not checkpoint_saved or final_metrics['SDR'] > best_sdr:
            _save_stem_checkpoint(checkpoint_path, epoch, model, final_metrics)

        print(f"\nCompleted {stem} - Final SDR: {final_metrics['SDR']:.2f} dB")

    # Summary
    print(f"\n{'='*50}")
    print("TRAINING COMPLETE - SUMMARY")
    print(f"{'='*50}")
    for stem, metrics in results.items():
        print(f"{stem:8s}: SDR={metrics['SDR']:6.2f} dB, "
              f"SIR={metrics['SIR']:6.2f} dB, "
              f"SAR={metrics['SAR']:6.2f} dB")

    return results

def create_separation_script():
    """Write ``separate_all_stems.py``, a script that separates all four stems.

    The script text is stored verbatim in ``script_content`` below and is
    written to the current working directory, overwriting any existing
    file of that name. The generated script imports from
    ``train_simple_openunmix`` and looks for ``model_<stem>.pth`` files.
    """

    # Source of the generated script; it is written out exactly as it appears here
    script_content = '''
import torch
import numpy as np
import librosa
import soundfile as sf
from train_simple_openunmix import SimplifiedOpenUnmix, compute_stft, N_FFT, HOP_LENGTH

def separate_all_stems(input_path, output_dir, models_dir='.'):
    """Separate audio into all 4 stems"""
    import os
    from pathlib import Path

    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True)

    # Load audio
    audio, sr = librosa.load(input_path, sr=44100, mono=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    stems = ['vocals', 'drums', 'bass', 'other']

    for stem in stems:
        print(f"Separating {stem}...")

        # Load model
        model = SimplifiedOpenUnmix().to(device)
        model_path = Path(models_dir) / f'model_{stem}.pth'

        if not model_path.exists():
            print(f"Model not found: {model_path}")
            continue

        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Separate
        magnitude, phase = compute_stft(audio)
        mag_tensor = torch.FloatTensor(magnitude).unsqueeze(0).to(device)

        with torch.no_grad():
            mask = model(mag_tensor + 1e-8).squeeze(0).cpu().numpy()

        # Apply mask and reconstruct
        separated_magnitude = magnitude * mask
        stft_separated = separated_magnitude.T * np.exp(1j * phase.T)
        separated_audio = librosa.istft(stft_separated, hop_length=HOP_LENGTH)

        # Save
        output_file = output_path / f'{Path(input_path).stem}_{stem}.wav'
        sf.write(output_file, separated_audio, sr)
        print(f"Saved: {output_file}")

    print("Separation complete!")

if __name__ == "__main__":
    separate_all_stems(
        input_path="path/to/input.wav",
        output_dir="separated_output",
        models_dir="."
    )
'''

    # Write the script next to the trained checkpoints (current directory)
    with open('separate_all_stems.py', 'w') as f:
        f.write(script_content)

    print("Created separation script: separate_all_stems.py")

# Script entry point: train all four stems, then emit the separation script
if __name__ == "__main__":
    # Configuration
    MUSDB_ROOT = "/path/to/musdb18hq"  # UPDATE THIS
    MAX_TRAINING_HOURS = 4

    # Check device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if device.type == 'cpu':
        print("WARNING: Training on CPU will be slow. Consider using GPU.")
        # Only the batch size is adapted for CPU (inside train_all_stems);
        # the model itself is not changed, so the message must not claim a
        # reduced model size.
        print("Using a smaller batch size for CPU training...")

    # Train all stems
    results = train_all_stems(MUSDB_ROOT, MAX_TRAINING_HOURS, device)

    # Create separation script
    create_separation_script()

    print("\nTraining complete! Use 'separate_all_stems.py' to separate new songs.")