In [None]:
# Mount Google Drive at /content/drive (Colab-only import); later cells read
# their input audio and archives from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Environment setup: install the audio / separation tooling used by the cells below.
!pip install scipy torch torchaudio museval
!git clone https://github.com/sigsep/open-unmix-pytorch.git
# NOTE: %cd changes the kernel's working directory for all later cells, so any
# relative path used afterwards resolves against .../open-unmix-pytorch/scripts.
%cd /content/open-unmix-pytorch/scripts
!pip install -r requirements.txt
!pip install pydub
!pip install openunmix
!pip install pyloudnorm
!pip install noisereduce
!apt-get install -y sox libsox-dev libsox-fmt-all
# NOTE(review): scipy is installed unpinned on the first line and only pinned
# here; the installs that follow may change the version again — confirm.
!pip install scipy==1.9.3
!pip install nussl
!pip install datasets
!pip install demucs

In [None]:
import os
import shutil
import random
import librosa
import soundfile as sf
import numpy as np
import torch
import librosa.display
import scipy.signal
import noisereduce as nr
import IPython.display as ipd
import nussl
from openunmix import predict, utils

In [None]:
# ======================== Organize Drum Loops ======================== #

# Define paths
drum_loops_path = "/content/drive/MyDrive/Drum_Loops"
output_path = "/content/train_data"
train_path, valid_path = os.path.join(output_path, "train"), os.path.join(output_path, "valid")

# Drum loop categories: folder name -> keywords searched (case-insensitively)
# in each file name. Dict order decides which category wins on multiple matches.
categories = {
    "kick": ["kick", "Kick"],
    "hihat": ["hihat", "Hi-Hat", "hats"],
    "clap": ["clap", "Claps", "Clap_Loop"],
    "percussion": ["percu", "Percu"]
}

# Create categorized subfolders
for category in categories:
    os.makedirs(os.path.join(drum_loops_path, category), exist_ok=True)

# Move files into their respective folders
for file in os.listdir(drum_loops_path):
    file_path = os.path.join(drum_loops_path, file)
    if not os.path.isfile(file_path):  # Only move files, not directories
        continue

    file_lower = file.lower()
    for category, keywords in categories.items():
        if any(keyword.lower() in file_lower for keyword in keywords):
            shutil.move(file_path, os.path.join(drum_loops_path, category, file))
            print(f" Moved {file} to {category}/")
            # The file no longer exists at file_path: stop at the first match,
            # otherwise a name matching two categories (e.g. "kick_clap.wav")
            # would trigger a second move and raise FileNotFoundError.
            break

print("\n All drum loops are now correctly categorized!")

# ==================== Generate Synthetic Mixtures ==================== #

def _peak_normalize(signal):
    """Return `signal` scaled to an absolute peak of 1.0.

    Silent or empty signals are returned unchanged instead of being divided
    by zero (which would fill the output file with NaN).
    """
    peak = np.max(np.abs(signal)) if signal.size else 0.0
    return signal / peak if peak > 0 else signal

# Create necessary directories
os.makedirs(train_path, exist_ok=True)
os.makedirs(valid_path, exist_ok=True)

# Load available drum stems
kick_files = librosa.util.find_files(os.path.join(drum_loops_path, "kick"))
hihat_files = librosa.util.find_files(os.path.join(drum_loops_path, "hihat"))
clap_files = librosa.util.find_files(os.path.join(drum_loops_path, "clap"))
perc_files = librosa.util.find_files(os.path.join(drum_loops_path, "percussion"))

# Ensure we have files for each category
assert kick_files and hihat_files and clap_files and perc_files, " Some drum categories are missing!"

# STFT parameters; not used for generation, only by the STFT debug code below
n_fft, hop_length = 4096, 1024

# Generate random drum mixtures (90% train, 10% validation)
num_samples = 100
train_size = int(num_samples * 0.9)

for i in range(1, num_samples + 1):
    # Randomly select one stem per category, all loaded at 44.1 kHz
    kick, _ = librosa.load(random.choice(kick_files), sr=44100)
    hihat, _ = librosa.load(random.choice(hihat_files), sr=44100)
    clap, _ = librosa.load(random.choice(clap_files), sr=44100)
    perc, _ = librosa.load(random.choice(perc_files), sr=44100)

    # Zero-pad every stem to the length of the longest one
    max_len = max(map(len, [kick, hihat, clap, perc]))
    kick, hihat, clap, perc = (np.pad(s, (0, max_len - len(s))) for s in [kick, hihat, clap, perc])

    # Create synthetic mixture (fixed per-instrument gains), then peak-normalize
    mixture = _peak_normalize(0.4 * kick + 0.3 * hihat + 0.5 * clap + 0.3 * perc)

    # The clap target is peak-normalized on its own.
    # NOTE(review): this means the target's level is not tied to the clap's
    # level inside the mixture — confirm this is intended for training.
    clap = _peak_normalize(clap)

    # Assign to train/valid split: the first `train_size` tracks go to train
    dataset_folder = train_path if i <= train_size else valid_path
    track_folder = os.path.join(dataset_folder, str(i))
    os.makedirs(track_folder, exist_ok=True)

    # Save mixture and target
    sf.write(os.path.join(track_folder, "mixture.wav"), mixture, 44100)
    sf.write(os.path.join(track_folder, "clap.wav"), clap, 44100)

    print(f" Created {track_folder}/mixture.wav & {track_folder}/clap.wav")

print("\n **100 synthetic drum mixtures generated with `nfft=4096` and `hop_length=1024`!**")

# ====================== Convert Audio to Stereo ====================== #

def convert_to_stereo(directory):
    """Rewrite every ``.wav`` below `directory` in place as 44.1 kHz audio.

    Each file is loaded with its channels kept separate; a single-channel
    signal is duplicated into two identical channels before being written back
    to the same path.
    """
    for root, _, files in os.walk(directory):
        for file in filter(lambda name: name.endswith(".wav"), files):
            wav_path = os.path.join(root, file)
            samples, rate = librosa.load(wav_path, sr=44100, mono=False)

            # A 1-D array means mono: stack it twice to get (2, n_samples)
            if samples.ndim == 1:
                samples = np.vstack([samples, samples])

            # soundfile expects (n_samples, n_channels)
            sf.write(wav_path, samples.T, rate)
            print(f" Converted {file} to stereo (2 channels)")

convert_to_stereo(output_path)
print(" All dataset audio files are now stereo!")

# ===================== Compute STFT for Debugging ===================== #

# Sanity check: magnitude spectrogram of the first training mixture, computed
# with the n_fft / hop_length declared above (librosa loads it as mono here).
sample_file = os.path.join(train_path, "1", "mixture.wav")
y, sr = librosa.load(sample_file, sr=44100)
magnitude = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))

print(f" Spectrogram Shape: {magnitude.shape}")  # (freq_bins, time_frames)
print(f" Frequency bins: {magnitude.shape[0]}")
print(f" Time frames: {magnitude.shape[1]}")


In [None]:
!python train.py \
    --dataset aligned \
    --root /content/train_data \
    --input-file mixture.wav \
    --output-file clap.wav \
    --seq-dur 7.74 \
    --nfft 2048 \
    --nhop 1024 \
    --epochs 15 \
    --batch-size 12

In [None]:
# Read one validation mixture, keeping its channels separate
mixture_path = "/content/train_data/valid/99/mixture.wav"
y, sr = librosa.load(mixture_path, sr=44100, mono=False)

# Choose the compute device and place the waveform tensor on it
device = "cuda" if torch.cuda.is_available() else "cpu"
audio_tensor = torch.tensor(y, dtype=torch.float32).to(device)

# Location of the trained model files and the name of the trained target
model_dir = "/content/open-unmix-pytorch/scripts/open-unmix/"  # Directory containing clap.pth and clap.json
model_name = "clap"

print(f" Loading separator for model: {model_name} on {device}")

# Separator settings gathered in one place, then unpacked into the loader
separator_kwargs = dict(
    model_str_or_path=model_dir,
    targets=[model_name],
    niter=1,
    residual=True,  # adds a residual source next to the single trained target
    wiener_win_len=300,
    device=device,
    pretrained=True,
    filterbank="torch",
)
separator = utils.load_separator(**separator_kwargs)

separator.freeze()
separator.to(device)  # keep the model on the same device as the audio
print(f" Separator loaded on {device}: {separator}")

# Run the already-loaded separator on the mixture
estimates = predict.separate(
    audio=audio_tensor,
    rate=sr,
    separator=separator,
    targets=[model_name],
)

if model_name in estimates:
    # Drop size-1 dimensions so the array can be written as audio
    clap_estimate = np.squeeze(estimates[model_name].cpu().numpy())

    # Write the estimate next to the mixture; transposed because soundfile
    # expects (samples, channels)
    output_path = f"/content/train_data/valid/99/{model_name}_estimate.wav"
    sf.write(output_path, clap_estimate.T, sr)

    print(f" Separated {model_name} saved at: {output_path}")

    # Inline playback of the saved estimate
    import IPython.display as ipd
    print(f"\n Separated {model_name} (Model Output):")
    ipd.display(ipd.Audio(output_path, rate=sr))
else:
    print(f"No '{model_name}' estimate was generated!")


In [None]:
# Post-process the separated clap: Wiener denoise -> band-pass -> light noise
# reduction -> peak normalization.
input_wav = "/content/train_data/valid/99/clap_estimate.wav"  # Your separated file
output_wav = "/content/train_data/valid/99/clap_cleaned.wav"

# Load audio (librosa's default downmixes to mono)
y, sr = librosa.load(input_wav, sr=44100)

### Wiener filtering
# scipy documents that the elements of `mysize` should be odd, so use a
# 513-sample window rather than 512.
y_denoised = scipy.signal.wiener(y, mysize=513)

### Band-pass filtering (3rd-order Butterworth, 100 Hz - 12 kHz)
nyquist = sr / 2
low_cutoff = 100      # Hz
high_cutoff = 12000   # Hz

b, a = scipy.signal.butter(3, [low_cutoff / nyquist, high_cutoff / nyquist], btype='band')
y_filtered = scipy.signal.filtfilt(b, a, y_denoised)  # zero-phase (forward + backward)

### Reduce noise without overdoing it
y_reduced = nr.reduce_noise(y=y_filtered, sr=sr, prop_decrease=0.2, stationary=True)

### Normalize volume to a 0.9 peak
# Guard against a silent (or non-finite) signal: dividing by a zero peak would
# fill the output with NaN.
peak = np.max(np.abs(y_reduced)) if y_reduced.size else 0.0
y_normalized = y_reduced / peak * 0.9 if peak > 0 else y_reduced

### Save processed file
sf.write(output_wav, y_normalized, sr)

# Play the cleaned file
print(f" Cleaned file saved as {output_wav}")
ipd.display(ipd.Audio(output_wav, rate=sr))


In [None]:
# Where the selected mixture is written for playback
test_sample_path = "/content/musdb18_sample.wav"

# Index of the MUSDB18 test item to listen to; edit to browse other samples
sample_idx = 27

# Open the MUSDB18 test subset through nussl (download enabled)
musdb = nussl.datasets.MUSDB18(subsets=['test'], download=True)

# Pull the mixture signal and its sample rate out of the chosen item
item = musdb[sample_idx]
mixture_audio = item['mix']
sr = mixture_audio.sample_rate

# Write the mixture to disk; transposed because soundfile expects (samples, channels)
sf.write(test_sample_path, mixture_audio.audio_data.T, sr)

# Inline playback
print(f" Now playing sample {sample_idx}:")
ipd.display(ipd.Audio(test_sample_path, rate=sr))


In [None]:
# Put the MUSDB mixture loaded in the previous cell on the compute device
device = "cuda" if torch.cuda.is_available() else "cpu"
audio_tensor = torch.tensor(mixture_audio.audio_data, dtype=torch.float32).to(device)

# Trained model location and target name
model_dir = "/content/open-unmix-pytorch/scripts/open-unmix/"  # Path to trained model
model_name = "clap"

print(f" Loading separator for model: {model_name} on {device}")

# Separator settings gathered in one place, then unpacked into the loader
separator_kwargs = dict(
    model_str_or_path=model_dir,
    targets=[model_name],
    niter=1,
    residual=True,  # adds a residual source next to the single trained target
    wiener_win_len=300,
    device=device,
    pretrained=True,
    filterbank="torch",
)
separator = utils.load_separator(**separator_kwargs)

separator.to(device)
separator.freeze()
print(f" Separator loaded successfully!")

# Run the loaded separator on the mixture
estimates = predict.separate(
    audio=audio_tensor,
    rate=sr,
    separator=separator,
    targets=[model_name],
)

if model_name in estimates:
    # Drop size-1 dimensions so the array can be written as audio
    clap_estimate = np.squeeze(estimates[model_name].cpu().numpy())

    # Save; transposed because soundfile expects (samples, channels)
    output_path = "/content/musdb18_clap_estimate.wav"
    sf.write(output_path, clap_estimate.T, sr)

    print(f"Separated claps saved at: {output_path}")

    # Inline playback of the saved estimate
    print("\n Separated Claps (Model Output):")
    ipd.display(ipd.Audio(output_path, rate=sr))
else:
    print(f" No '{model_name}' estimate was generated!")


In [None]:
# Load the separated clap audio (downmixed to mono by librosa's default)
clap_path = "/content/musdb18_clap_estimate.wav"
y, sr = librosa.load(clap_path, sr=44100)

# Apply noise reduction
reduced_noise = nr.reduce_noise(y=y, sr=sr, prop_decrease=0.7)  # Tune `prop_decrease` if needed

# Keep only the percussive component of the signal.
# librosa.effects.percussive performs harmonic-percussive source separation;
# it is NOT dynamic range compression, so the variable is named accordingly.
percussive = librosa.effects.percussive(reduced_noise)

# Normalize the audio
normalized = librosa.util.normalize(percussive)

# Save the post-processed audio
processed_path = "/content/musdb18_clap_estimate_processed.wav"
sf.write(processed_path, normalized, sr)

# Play the post-processed audio
print("\n Post-processed Claps (Denoised & Enhanced):")
ipd.display(ipd.Audio(processed_path, rate=sr))


## After clap


In [None]:
import os
import random
import numpy as np
import librosa
import soundfile as sf
import shutil
import gc
import subprocess
from pathlib import Path

# --- FIXED FUNCTIONS ---
def loop_audio_to_duration(audio, target_length):
    """Return `audio` repeated and/or truncated to exactly `target_length` samples.

    Args:
        audio: NumPy array. A 1-D array is treated as a mono signal; otherwise
            samples are taken to run along axis 1 (e.g. ``(channels, samples)``).
        target_length: Desired number of samples along the sample axis.

    Returns:
        An array whose sample axis has length `target_length`. Shorter input
        is tiled end-to-end and cut; longer input is truncated.

    Raises:
        ValueError: If `audio` has no samples but `target_length` is positive
            (previously this surfaced as an opaque ZeroDivisionError).
    """
    is_mono = audio.ndim == 1
    current_length = len(audio) if is_mono else audio.shape[1]

    if current_length < target_length:
        if current_length == 0:
            raise ValueError("Cannot loop empty audio to a positive duration")
        # One repeat more than the floor quotient always covers target_length
        repeats = (target_length // current_length) + 1
        audio = np.tile(audio, repeats if is_mono else (1, repeats))

    return audio[:target_length] if is_mono else audio[:, :target_length]

def to_stereo(audio):
    """Return `audio` laid out as (samples, 2) where possible.

    - 1-D input is duplicated into two identical columns.
    - Input whose first dimension is 2 is transposed.
    - Anything else is returned as-is.
    """
    if audio.ndim == 1:
        return np.column_stack([audio, audio])
    if audio.shape[0] == 2:
        return audio.T
    return audio

def augment_audio(audio, sr):
    """Randomly augment `audio` and return the result.

    Two independent coin flips (each with probability 0.5) decide whether to
    apply:
      1. a pitch shift of -2, -1, +1 or +2 semitones, and
      2. a time stretch with a rate drawn uniformly from [0.85, 1.15].

    When neither fires, the input object is returned untouched.
    """
    pitch_steps = (-2, -1, 1, 2)

    if np.random.rand() < 0.5:
        audio = librosa.effects.pitch_shift(
            audio, sr=sr, n_steps=random.choice(pitch_steps)
        )

    if np.random.rand() < 0.5:
        # `rate` must be passed by keyword to time_stretch
        audio = librosa.effects.time_stretch(audio, rate=random.uniform(0.85, 1.15))

    return audio

# --- ENHANCED GENERATION ---
def generate_dataset():
    """Build the flute dataset under /content/train_data_v2.

    Tracks 1-250 go to ``train``, 251-300 to ``valid``. Each track folder gets:
      * ``flute.wav``   - one augmented flute clip looped/cut to `seq_dur`;
      * ``mixture.wav`` - the "other" stem Demucs extracts from a mix of that
        flute (gain 0.7) plus 2-5 other instruments (gain 0.3 each).

    A track that fails is reported, its partially written folder is removed,
    and generation continues with the next track.

    Raises:
        FileNotFoundError: If no flute files are available. This is checked
            before any existing output is deleted.
    """
    # Configuration
    input_path = "/content/drive/MyDrive/Drum_Loops"
    output_path = "/content/train_data_v2"
    sr = 44100
    seq_dur = 7.75  # Slightly longer than the 7.74 s --seq-dur used for training
    max_samples = int(seq_dur * sr)
    total_tracks, train_tracks = 300, 250

    # Load instrument files
    instrument_files = {}
    categories = ["percussion", "kick", "hihat", "harp", "flute",
                "clarinet", "clap", "choir", "cello", "acousticGuitar"]

    for cat in categories:
        cat_path = os.path.join(input_path, cat)
        instrument_files[cat] = librosa.util.find_files(cat_path) if os.path.exists(cat_path) else []
        print(f"Found {len(instrument_files[cat])} {cat} files")

    # Fail fast: without flute files every track would be "skipped" and the
    # function would still report success on an empty dataset.
    if not instrument_files["flute"]:
        raise FileNotFoundError(
            f"No flute files found under {os.path.join(input_path, 'flute')}"
        )

    # Categories (other than the target) that actually have files
    other_cats = [c for c in categories if c != "flute" and instrument_files[c]]
    # random.sample raises if asked for more items than exist, so cap the
    # 2-5 range at the number of categories available.
    min_other = min(2, len(other_cats))
    max_other = min(5, len(other_cats))

    # Clean setup (only after the inputs have been validated)
    shutil.rmtree(output_path, ignore_errors=True)
    os.makedirs(os.path.join(output_path, "train"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "valid"), exist_ok=True)

    generated = 0
    for i in range(1, total_tracks + 1):
        split = "train" if i <= train_tracks else "valid"
        track_dir = os.path.join(output_path, split, f"{i:03d}")
        try:
            os.makedirs(track_dir, exist_ok=True)

            # 1. Generate flute track (the separation target)
            flute_file = random.choice(instrument_files["flute"])
            flute_audio, _ = librosa.load(flute_file, sr=sr, mono=True)
            flute_audio = augment_audio(flute_audio, sr)
            flute_audio = loop_audio_to_duration(flute_audio, max_samples)
            sf.write(os.path.join(track_dir, "flute.wav"), to_stereo(flute_audio), sr)

            # 2. Create mixture
            # NOTE(review): the summed mixture is not peak-limited, so its
            # samples can exceed [-1, 1] before being written — confirm this
            # is acceptable.
            selected = random.sample(other_cats, random.randint(min_other, max_other))

            mixture = 0.7 * flute_audio
            for cat in selected:
                audio_file = random.choice(instrument_files[cat])
                audio, _ = librosa.load(audio_file, sr=sr, mono=True)
                audio = augment_audio(audio, sr)
                audio = loop_audio_to_duration(audio, max_samples)
                mixture += 0.3 * audio

            # Save original mixture
            original_mix_path = os.path.join(track_dir, "mixture_original.wav")
            sf.write(original_mix_path, to_stereo(mixture), sr)

            # 3. Process with Demucs
            subprocess.run([
                "demucs",
                "--two-stems", "other",
                original_mix_path,
                "-o", track_dir
            ], check=True)

            # 4. Replace mixture with Demucs' output
            demucs_mix = os.path.join(track_dir, "htdemucs", "mixture_original", "other.wav")
            os.rename(demucs_mix, os.path.join(track_dir, "mixture.wav"))

            # Cleanup
            shutil.rmtree(os.path.join(track_dir, "htdemucs"))
            os.remove(original_mix_path)

            generated += 1
            print(f"[{i}] Generated track {track_dir}")

        except Exception as e:
            # Deliberately best-effort per track, but do not leave a
            # half-written track folder behind.
            shutil.rmtree(track_dir, ignore_errors=True)
            print(f"Skipped track {i}: {str(e)}")
            continue

    print(f"✅ Dataset generated with Demucs pre-processing! ({generated}/{total_tracks} tracks)")

if __name__ == "__main__":
    generate_dataset()


In [None]:
!python open-unmix-pytorch/scripts/train.py \
    --dataset aligned \
    --root /content/train_data_v2 \
    --input-file mixture.wav \
    --output-file flute.wav \
    --target flute \
    --seq-dur 7.74 \
    --nfft 4096 \
    --nhop 1024 \
    --epochs 30 \
    --batch-size 16 \
    --hidden-size 512 \
    --lr 0.001 \
    --weight-decay 0.0001


In [None]:
import os
import torch
import librosa
import soundfile as sf
import numpy as np
import subprocess
import shutil
from pathlib import Path
from openunmix import predict
from openunmix import utils

def process_full_song(input_path, output_path, model_dir, target_name="flute", max_duration=None):
    """Extract `target_name` from a full song: Demucs first, then the trained model.

    Steps:
      1. Run Demucs (``htdemucs``, ``--two-stems other``) on `input_path`.
      2. Load the resulting ``other.wav`` at 44.1 kHz as stereo.
      3. Build an Open-Unmix separator from `model_dir` and load
         ``<model_dir>/<target_name>.pth`` into it.
      4. Separate the stem in ~7.74 s chunks and concatenate the results.
      5. Write the estimate to `output_path` and remove the Demucs output.

    Args:
        input_path: Audio file to process.
        output_path: Destination of the separated audio.
        model_dir: Directory holding ``<target_name>.pth`` and its json config.
        target_name: Name of the trained target.
        max_duration: Optional limit in seconds; None/0 processes everything.

    Raises:
        subprocess.CalledProcessError: If Demucs fails.
        RuntimeError: If the checkpoint does not match the target model.
        ValueError: If the Demucs stem contains no audio.
    """
    # Imported locally: this cell's import list does not include IPython.display.
    import IPython.display as ipd

    # Step 1: Run Demucs to get 'other' stem
    demucs_output_dir = os.path.join("separated", "htdemucs")
    base_name = Path(input_path).stem

    subprocess.run([
        "demucs",
        "--two-stems", "other",
        "-n", "htdemucs",
        input_path,
        "-o", demucs_output_dir
    ], check=True)

    # Step 2: Load Demucs' 'other' stem.
    # Demucs writes to <out>/<model name>/<track name>/, hence the second "htdemucs".
    track_output_dir = os.path.join(demucs_output_dir, "htdemucs", base_name)
    other_stem_path = os.path.join(track_output_dir, "other.wav")

    ipd.display(ipd.Audio(other_stem_path, rate=44100))
    y, sr = librosa.load(other_stem_path, sr=44100, mono=False)

    if y.ndim == 1:
        y = np.vstack([y, y])  # Ensure stereo

    # Apply duration limit if specified
    if max_duration:
        y = y[:, :int(max_duration * sr)]

    total_samples = y.shape[1]
    if total_samples == 0:
        raise ValueError(f"No audio to process in {other_stem_path}")

    # Step 3: Initialize separator (architecture only; weights loaded below)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    separator = utils.load_separator(
        model_str_or_path=model_dir,
        targets=[target_name],
        niter=0,
        residual=True,
        device=device,
        pretrained=False
    )

    # Load custom weights. The checkpoint keys are prefixed so they address the
    # target model inside the separator.
    prefix = f"target_models.{target_name}."
    state_dict = torch.load(os.path.join(model_dir, f"{target_name}.pth"), map_location=device)
    load_result = separator.load_state_dict(
        {prefix + k: v for k, v in state_dict.items()}, strict=False
    )
    # strict=False tolerates separator parts the checkpoint does not cover, but
    # it must not hide a mismatch in the target model itself; that would leave
    # the model running on untrained weights without any error.
    missing_target_keys = [k for k in load_result.missing_keys if k.startswith(prefix)]
    if missing_target_keys or load_result.unexpected_keys:
        raise RuntimeError(
            f"Checkpoint does not match the '{target_name}' model "
            f"(missing: {missing_target_keys}, unexpected: {list(load_result.unexpected_keys)})"
        )
    separator.to(device).eval()

    # Step 4: Process in chunks
    chunk_size = int(7.74 * sr)
    # A very short trailing chunk can be too small for the STFT (training in
    # this notebook uses --nfft 4096), so fold it into the previous chunk.
    min_chunk = 4096
    starts = list(range(0, total_samples, chunk_size))
    if len(starts) > 1 and total_samples - starts[-1] < min_chunk:
        starts.pop()
    bounds = starts + [total_samples]

    estimates = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        audio_tensor = torch.tensor(y[:, start:end], dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            estimates_chunk = predict.separate(
                audio=audio_tensor,
                rate=sr,
                separator=separator,
                targets=[target_name]
            )
            # Drop only the batch dimension, then go to (samples, channels)
            flute_estimate = estimates_chunk[target_name].squeeze(0).cpu().numpy().T

        estimates.append(flute_estimate)

    # Step 5: Combine and save
    full_estimate = np.concatenate(estimates, axis=0)
    sf.write(output_path, full_estimate, sr)

    # Cleanup Demucs output
    shutil.rmtree(track_output_dir)

    print(f"Flute track saved to {output_path}")

# Usage: run the Demucs + flute-model pipeline on one file and write the
# result to flute_enhanced.wav in the current working directory.
# NOTE(review): input_path points inside a Demucs output folder — confirm this
# is the intended source file.
process_full_song(
    input_path="/content/separated/htdemucs/htdemucs/test.wav",
    output_path="flute_enhanced.wav",
    model_dir="/content/open-unmix/"
    # Optional duration limit in seconds; omit to process the full song.
    # To enable it, add a comma after the model_dir argument above.
    # max_duration=7.74
)


In [None]:
import IPython.display as ipd
import os

# Files to audition: the model output and the Demucs stem.
# NOTE(review): process_full_song writes its stem under
# separated/htdemucs/htdemucs/<name>/ and deletes that folder afterwards, so
# this path looks like it refers to a separate Demucs run — confirm.
flute_path = "flute_enhanced.wav"
other_stem_path = os.path.join("separated", "htdemucs", "test", "other.wav")

# Model output
if not os.path.exists(flute_path):
    print("Flute track generation failed - check model training")
else:
    print("Flute Enhanced Audio:")
    ipd.display(ipd.Audio(flute_path, rate=44100))

# Demucs stem
if not os.path.exists(other_stem_path):
    print("Demucs failed - check input file permissions")
else:
    print("\nDemucs 'Other' Stem:")
    ipd.display(ipd.Audio(other_stem_path, rate=44100))

## BabySlakh

In [None]:
import os
import yaml
import tarfile
from collections import Counter
import matplotlib.pyplot as plt
import pandas as pd

# Step 1: Extract the tar.gz file if needed
def extract_archive(archive_path, extract_path):
    """Unpack the gzip-compressed tar at `archive_path` into `extract_path`.

    Nothing is extracted when `extract_path` already exists; the existing
    directory is used as-is.
    """
    if os.path.exists(extract_path):
        print(f"Using existing directory: {extract_path}")
        return

    print(f"Extracting {archive_path} to {extract_path}...")
    with tarfile.open(archive_path, 'r:gz') as tar:
        tar.extractall(path=extract_path)
    print("Extraction complete!")

# Step 2: Analyze the dataset
def analyze_babyslakh(dataset_path):
    """Tally stem metadata from every ``metadata.yaml`` below `dataset_path`.

    Every directory containing a ``metadata.yaml`` counts as one track. For
    each stem listed under the file's ``stems`` key, the values of
    ``inst_class``, ``midi_program_name``, ``plugin_name`` and ``is_drum`` are
    counted when present. Files that fail to load or parse are reported and
    skipped.

    Returns:
        dict with ``track_count`` (int), ``stems_per_track`` (list of int) and
        Counters under ``inst_class``, ``program_name``, ``plugin_name`` and
        ``is_drum``.
    """
    # One Counter per metadata field, in the order the fields are inspected
    field_counters = {
        'inst_class': Counter(),
        'midi_program_name': Counter(),
        'plugin_name': Counter(),
        'is_drum': Counter(),
    }

    track_count = 0
    stems_per_track = []

    print(f"Scanning dataset at: {dataset_path}")
    for root, dirs, files in os.walk(dataset_path):
        if 'metadata.yaml' not in files:
            continue

        track_count += 1
        metadata_path = os.path.join(root, 'metadata.yaml')

        try:
            with open(metadata_path, 'r') as f:
                metadata = yaml.safe_load(f)

            # Tracks without stem information contribute nothing further
            if 'stems' not in metadata:
                continue

            stems = metadata['stems']
            stems_per_track.append(len(stems))

            for stem_data in stems.values():
                for field, counter in field_counters.items():
                    if field in stem_data:
                        counter[stem_data[field]] += 1

        except Exception as e:
            print(f"Error processing {metadata_path}: {e}")

    return {
        'track_count': track_count,
        'stems_per_track': stems_per_track,
        'inst_class': field_counters['inst_class'],
        'program_name': field_counters['midi_program_name'],
        'plugin_name': field_counters['plugin_name'],
        'is_drum': field_counters['is_drum']
    }

# Step 3: Visualize results
def _plot_top_counts(counter, title, filename, top_n=10):
    """Save a bar chart of the `top_n` most common entries of `counter` to `filename`."""
    top_items = dict(counter.most_common(top_n))
    plt.figure(figsize=(12, 6))
    plt.bar(top_items.keys(), top_items.values())
    plt.xticks(rotation=45, ha='right')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(filename)


def visualize_results(results):
    """Print summary tables and save bar charts for an analysis result.

    Args:
        results: dict as returned by ``analyze_babyslakh`` (uses the keys
            ``track_count``, ``stems_per_track``, ``inst_class``,
            ``program_name`` and ``is_drum``).

    Writes ``instrument_classes.png`` and ``specific_instruments.png`` to the
    current directory. When no stems were found (e.g. wrong dataset path),
    only the track count and a notice are printed; previously this case raised
    ZeroDivisionError.
    """
    stems_per_track = results['stems_per_track']

    # Print overall statistics
    print(f"\nTotal tracks analyzed: {results['track_count']}")
    if not stems_per_track:
        # Average/min/max are undefined and the tables and plots would be empty
        print("No stem metadata found - check the dataset path. Nothing to summarize.")
        return

    print(f"Average stems per track: {sum(stems_per_track)/len(stems_per_track):.2f}")
    print(f"Min stems: {min(stems_per_track)}, Max stems: {max(stems_per_track)}")

    # Create pandas DataFrames for better display
    print("\n--- Top 10 Instrument Classes ---")
    inst_df = pd.DataFrame(results['inst_class'].most_common(), columns=['Instrument Class', 'Count'])
    print(inst_df.head(10))

    print("\n--- Top 10 Specific Instruments ---")
    prog_df = pd.DataFrame(results['program_name'].most_common(), columns=['Instrument', 'Count'])
    print(prog_df.head(10))

    print("\n--- Drum vs Non-Drum Instruments ---")
    drum_df = pd.DataFrame(results['is_drum'].most_common(), columns=['Is Drum', 'Count'])
    print(drum_df)

    # Bar charts of the most common instrument classes / specific instruments
    _plot_top_counts(results['inst_class'], 'Top 10 Instrument Classes', 'instrument_classes.png')
    _plot_top_counts(results['program_name'], 'Top 10 Specific Instruments', 'specific_instruments.png')

# Main execution function
def main():
    """Extract (if needed), analyze and report on the BabySlakh archive.

    Extraction errors are reported and the analysis proceeds on whatever is at
    the extraction path. Besides the printed tables and saved plots, the
    instrument-class and specific-instrument counts are written to CSV files
    in the current directory.
    """
    archive_path = '/content/drive/MyDrive/babyslakh_16k.tar.gz'
    extract_path = '/content/drive/MyDrive/babyslakh_16k'

    # Best-effort extraction: the data may already be in place
    try:
        extract_archive(archive_path, extract_path)
    except Exception as e:
        print(f"Extraction error: {e}")
        print("Continuing with analysis assuming the dataset is already extracted...")

    results = analyze_babyslakh(extract_path)
    visualize_results(results)

    # (result key, first column label, output file) for each CSV export
    csv_exports = (
        ('inst_class', 'Instrument Class', 'instrument_classes.csv'),
        ('program_name', 'Instrument', 'specific_instruments.csv'),
    )
    for result_key, label, csv_name in csv_exports:
        counts_df = pd.DataFrame(results[result_key].most_common(), columns=[label, 'Count'])
        counts_df.to_csv(csv_name)

    print("\nAnalysis complete! CSV files and plots have been saved.")

if __name__ == "__main__":
    main()


In [None]:
import os
import shutil
import yaml
import numpy as np
import torchaudio
import torch
from torchaudio.transforms import Resample

# --- OPEN-UNMIX COMPLIANT PARAMETERS ---
TARGET_SR = 44100        # Sample rate (Hz) the audio is resampled to and saved at
NFFT = 4096              # STFT size; here only echoed in create_dataset()'s summary
NHOP = 1024              # STFT hop size; here only echoed in create_dataset()'s summary
SEQ_DUR = 6.1            # Seconds of audio kept from the start of each track (6.1 s, not 6 s)
BATCH_SIZE = 16          # NOTE(review): not referenced by the visible code
NB_CHANNELS = 2          # NOTE(review): not referenced by the visible code

def create_dataset():
    """Build the "Choir Aahs" dataset under /content/train_data.

    Scans the BabySlakh folder for tracks with a stem whose
    ``midi_program_name`` is "Choir Aahs" (case-insensitive, first matching
    stem per track). After a seeded shuffle, two tracks go to ``valid`` and
    the rest to ``train``. For every track the full mix and the choir stem are
    resampled to ``TARGET_SR``, cut to the first ``SEQ_DUR`` seconds, forced
    to two channels and saved as ``mixture.wav`` / ``choir_aahs.wav``.

    Any existing /content/train_data is deleted first.

    Raises:
        ValueError: If fewer than three matching tracks are found, since two
            are reserved for validation and training would be left empty.
    """
    train_data_path = "/content/train_data"
    source_root = "/content/drive/MyDrive/babyslakh_16k"
    n_valid = 2

    # Find all Choir Aahs tracks
    choir_tracks = []
    for root, _, files in os.walk(source_root):
        if "metadata.yaml" not in files:
            continue
        with open(os.path.join(root, "metadata.yaml")) as f:
            # An empty YAML file loads as None
            metadata = yaml.safe_load(f) or {}

        for stem_id, stem_data in (metadata.get("stems") or {}).items():
            program_name = str((stem_data or {}).get("midi_program_name") or "")
            if program_name.lower() != "choir aahs":
                continue
            stem_path = os.path.join(root, "stems", f"{stem_id}.wav")
            if os.path.exists(stem_path):
                choir_tracks.append((root, stem_path))
                break  # one stem per track is enough

    if len(choir_tracks) <= n_valid:
        raise ValueError(
            f"Found {len(choir_tracks)} 'Choir Aahs' track(s) under {source_root}; "
            f"need at least {n_valid + 1} ({n_valid} for validation plus training data)"
        )

    # Clean existing data (only after the inputs have been validated)
    if os.path.exists(train_data_path):
        shutil.rmtree(train_data_path)
        print("Removed existing training data")

    # Create directory structure
    os.makedirs(os.path.join(train_data_path, "train"), exist_ok=True)
    os.makedirs(os.path.join(train_data_path, "valid"), exist_ok=True)

    # Split tracks (2 validation, rest training); seeded for reproducibility
    np.random.seed(42)
    np.random.shuffle(choir_tracks)
    valid_tracks = choir_tracks[:n_valid]
    train_tracks = choir_tracks[n_valid:]

    # Number of samples kept per clip at the target rate
    max_samples = int(SEQ_DUR * TARGET_SR)

    # One resampler per source sample rate, created on demand. The rate comes
    # from the file itself instead of being assumed to be 16 kHz.
    resamplers = {}

    def _load_clip(path):
        """Load `path` and return a (2, <=max_samples) tensor at TARGET_SR."""
        wave, source_sr = torchaudio.load(path)
        if source_sr != TARGET_SR:
            if source_sr not in resamplers:
                resamplers[source_sr] = Resample(
                    orig_freq=source_sr,
                    new_freq=TARGET_SR,
                    resampling_method="sinc_interp_kaiser"
                )
            wave = resamplers[source_sr](wave)
        # NOTE(review): only the beginning of each track is kept; the choir may
        # not be playing there — confirm this is the intended excerpt.
        wave = wave[:, :max_samples]
        if wave.shape[0] == 1:  # Force stereo by duplicating the mono channel
            wave = torch.cat([wave, wave], dim=0)
        return wave

    def process_and_save(track_path, stem_path, dest_dir):
        """Write mixture.wav and choir_aahs.wav for one track into `dest_dir`."""
        os.makedirs(dest_dir, exist_ok=True)
        torchaudio.save(
            os.path.join(dest_dir, "mixture.wav"),
            _load_clip(os.path.join(track_path, "mix.wav")),
            TARGET_SR
        )
        torchaudio.save(
            os.path.join(dest_dir, "choir_aahs.wav"),
            _load_clip(stem_path),
            TARGET_SR
        )

    # Generate training and validation data
    for split, tracks in (("train", train_tracks), ("valid", valid_tracks)):
        for i, (track_path, stem_path) in enumerate(tracks):
            dest_dir = os.path.join(train_data_path, split, f"track_{i:03d}")
            process_and_save(track_path, stem_path, dest_dir)

    print(f"""
    Dataset successfully created!
    - Training tracks: {len(train_tracks)}
    - Validation tracks: {len(valid_tracks)}
    - Sample rate: {TARGET_SR}Hz
    - Sequence duration: {SEQ_DUR}s
    - STFT parameters: nfft={NFFT}, nhop={NHOP}
    """)

if __name__ == "__main__":
    create_dataset()


In [None]:
!python open-unmix-pytorch/scripts/train.py \
    --dataset aligned \
    --root /content/train_data \
    --input-file mixture.wav \
    --output-file choir_aahs.wav \
    --seq-dur 6.0 \
    --nfft 4096 \
    --nhop 1024 \
    --nb-channels 2 \
    --batch-size 16 \
    --epochs 50 \
    --patience 140

In [None]:
import librosa
import torch
import torchaudio
import numpy as np
import soundfile as sf
from openunmix import predict, utils

# Configuration
TARGET_NAME = "choir_aahs"
MODEL_DIR = "/content/open-unmix/"  # Contains choir_aahs.pth and choir_aahs.json
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load mixture audio, keeping channels separate
mixture_path = "/content/train_data/valid/track_000/mixture.wav"
y, sr = librosa.load(mixture_path, sr=44100, mono=False)

# Convert to PyTorch tensor with shape (channels, samples)
if y.ndim == 1:  # Convert mono to stereo
    y = np.stack([y, y], axis=0)
audio_tensor = torch.tensor(y, dtype=torch.float32).to(DEVICE)

# Load the separator once and reuse it for the separation below.
# A single target without a residual leaves only one source, and Open-Unmix
# cannot run its EM refinement in that case, so niter must be 0.
print(f"Loading separator for {TARGET_NAME} on {DEVICE}")
separator = utils.load_separator(
    model_str_or_path=MODEL_DIR,
    targets=[TARGET_NAME],
    niter=0,
    residual=False,
    wiener_win_len=300,
    device=DEVICE,
    pretrained=True,
    filterbank="torch",
)
separator.freeze()
separator.to(DEVICE)

# Perform separation with the separator loaded above.
# predict.separate has no `model_dir` keyword; passing one raises a TypeError.
print(f"Separating {TARGET_NAME}...")
estimates = predict.separate(
    audio=audio_tensor,
    rate=sr,
    targets=[TARGET_NAME],
    separator=separator,
    device=DEVICE
)

# Fail loudly if the expected target is missing
if TARGET_NAME not in estimates:
    raise ValueError(f"No '{TARGET_NAME}' estimate generated!")

choir_estimate = estimates[TARGET_NAME].cpu().numpy()

# Remove batch dimension if present (batch_size=1)
if choir_estimate.ndim == 3:
    choir_estimate = choir_estimate.squeeze(0)

# Save output; transposed to (samples, channels) for soundfile
output_path = f"/content/train_data/valid/track_000/{TARGET_NAME}_estimate.wav"
sf.write(output_path, choir_estimate.T, sr)

print(f"Successfully saved separation to {output_path}")

# Audio playback (`display` imported explicitly rather than relying on the
# name being injected by the notebook environment)
from IPython.display import Audio, display
print(f"\nOriginal Mixture:")
display(Audio(mixture_path, rate=sr))
print(f"\nSeparated {TARGET_NAME}:")
display(Audio(output_path, rate=sr))
