In [71]:
import os
import json
import random
import torch
import torchaudio
import pandas as pd
from pathlib import Path
import sys
# Add the parent directory of 'notebooks' to sys.path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # Move one level up
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from datagen.chordgen import __generate_midi_chord as generate_midi_chord, __synthesize_to_wav as synthesize_to_wav, __note_lookup as note_lookup, CHORDS, INVERSIONS, GM_INSTRUMENTS, JSON_FILE
from datagen.fxgen_torch import TorchFXGenerator
from datagen.pedals import Distortion, Chorus, Delay, Reverb, Noise
from utils.gdrive import download_from_gdrive

In [72]:
# Define dataset characteristics
TRAIN_SIZE = 10 #2000
VAL_SIZE = 500
TEST_SIZE = 1000
BASE_PATH = Path("./timbral_bias_datasets")
SF2_SUBDIR = "sf2"
WAV_SUBDIR = "wav"
SF2_ARCHIVE = "FluidR3_GM.sf2"

In [73]:
# Revised genre-based chord groupings ensuring all chords are represented
# Primary chords for each genre (high probability)
METAL_PRIMARY = ["1", "5", "sus4", "sus4(b7)", "min", "dim"]
POP_PRIMARY = ["maj", "min", "sus2", "sus4", "maj6", "7", "maj/2", "maj/4", "min/2", "min/4"]
JAZZ_PRIMARY = ["maj7", "min7", "maj9", "min9", "11", "13", "maj6(9)", "minmaj7", "hdim7", "dim7"]

# Secondary chords (lower probability but ensures representation)
METAL_SECONDARY = list(set(CHORDS.keys()) - set(METAL_PRIMARY))
POP_SECONDARY = list(set(CHORDS.keys()) - set(POP_PRIMARY))
JAZZ_SECONDARY = list(set(CHORDS.keys()) - set(JAZZ_PRIMARY))

In [74]:
def setup_soundfonts(base_path: Path):
    """
    Download and set up soundfonts if not already present.
    Returns path to the base directory containing sf2 folder.
    """
    # Move sf2 directory to base level, not within temp
    sf2_dir = base_path.parent / SF2_SUBDIR  # Move up one level from dataset dir
    sf2_dir.mkdir(parents=True, exist_ok=True)
    sf_filepath = sf2_dir / SF2_ARCHIVE
    
    if not sf_filepath.exists():
        print(f"Downloading soundfont to {sf_filepath}")
        download_from_gdrive(SF2_ARCHIVE, str(sf_filepath.absolute()))
    else:
        print("Soundfont already exists, skipping download")
    
    return sf_filepath.parent  # Return the directory containing sf2 folder

def select_chords_for_genre(primary_chords, secondary_chords, primary_weight=0.8, count=100):
    """
    Select chords for a genre with weighted probability.
    
    Args:
        primary_chords: List of primary chords for the genre
        secondary_chords: List of secondary chords for the genre
        primary_weight: Probability weight for primary chords
        count: Number of chords to select
    """
    primary_count = int(count * primary_weight)
    secondary_count = count - primary_count
    
    selections = (
        random.choices(primary_chords, k=primary_count) +
        random.choices(secondary_chords, k=secondary_count)
    )
    return selections

In [75]:
# Define genre-appropriate instruments
METAL_INSTRUMENTS = {
    29: "overdriven_guitar",
    30: "distortion_guitar",
    27: "electric_guitar_(clean)"  # For some variety
}

POP_INSTRUMENTS = {
    80: "lead_1_(square)",
    81: "lead_2_(sawtooth)", 
    4: "electric_piano_1",
    5: "electric_piano_2",
    27: "electric_guitar_(clean)"
}

JAZZ_INSTRUMENTS = {
    0: "acoustic_grand_piano",
    24: "acoustic_guitar_(nylon)",
    26: "electric_guitar_(jazz)",
    4: "electric_piano_1"
}

# FX combinations remain the same
METAL_FX = {
    'distortion': ['classic_distortion', 'fuzz'],
    'noise': ['room_noise'],
    'reverb': ['small_room']
}

POP_FX = {
    'reverb': ['plate', 'large_hall'],
    'chorus': ['subtle', 'classic'],
    'distortion': ['subtle_drive']
}

JAZZ_FX = {
    'reverb': ['large_hall', 'plate'],
    'chorus': ['subtle'],
    'delay': ['subtle']
}

In [76]:
def convert_to_mp3(wav_path: Path, target_path: Path, sample_rate: int = 44100, bitrate: float = 192.0):
    """
    Convert WAV file to MP3 using torchaudio.
    
    Args:
        wav_path: Path to source WAV file
        target_path: Path to output MP3 file
        sample_rate: Target sample rate
        bitrate: Target bitrate in kbps
    """
    waveform, sr = torchaudio.load(str(wav_path))
    
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    # Ensure correct sample rate
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    
    # Save as MP3
    torchaudio.save(
        str(target_path),
        waveform,
        sample_rate,
        format="mp3",
        compression=bitrate/1000  # torchaudio expects compression rate in kbps/1000
    )

def batch_convert_wavs_to_mp3(wav_dir: Path, mp3_dir: Path):
    """
    Convert all WAV files in a directory to MP3.
    Updates metadata to reflect MP3 filenames.
    
    Args:
        wav_dir: Directory containing WAV files
        mp3_dir: Directory to save MP3 files
    """
    mp3_dir.mkdir(exist_ok=True)
    
    # Load metadata
    json_path = wav_dir.parent / JSON_FILE
    with open(json_path, 'r') as f:
        metadata = json.load(f)
    
    # Convert each file and update metadata
    updated_metadata = {}
    for wav_file in wav_dir.glob("*.wav"):
        mp3_file = mp3_dir / wav_file.with_suffix('.mp3').name
        convert_to_mp3(wav_file, mp3_file)
        
        # Update metadata
        base_name = wav_file.stem
        if base_name in metadata:
            entry = metadata[base_name].copy()
            entry['filename'] = mp3_file.name
            entry['format'] = 'mp3'
            updated_metadata[base_name] = entry
    
    # Save updated metadata
    with open(mp3_dir.parent / JSON_FILE, 'w') as f:
        json.dump(updated_metadata, f)
    
    return updated_metadata

In [77]:
def __save_json(data: dict, path: Path):
    """Save metadata in the same format as chordgen."""
    dumps = json.dumps(data)
    os.makedirs(path, exist_ok=True)
    with open(path / JSON_FILE, 'w') as outfile:
        outfile.write(dumps)

def generate_chord_samples(output_path: Path, chord_list: list, instruments: dict, 
                         sample_count: int, duration: float = 2.0):
    """
    Generate specific chord samples with selected instruments.
    
    Args:
        output_path: Directory to save generated files
        chord_list: List of chord types to generate
        instruments: Dictionary of GM preset IDs and names to use
        sample_count: Number of samples to generate
        duration: Length of audio samples in seconds
    
    Returns:
        Dictionary of metadata matching chordgen's format
    """
    temp_path = output_path / "temp"
    wav_dir = temp_path / "wav"
    sf2_dir = temp_path / "sf2"
    wav_dir.mkdir(parents=True, exist_ok=True)
    sf2_dir.mkdir(parents=True, exist_ok=True)
    
    # Calculate samples per chord to meet target count
    samples_per_chord = max(1, sample_count // len(chord_list))
    
    json_out = {}
    sf_filepath = sf2_dir / "FluidR3_GM.sf2"
    
    # Generate each chord with selected instruments
    for chord_type in chord_list:
        # Select random root notes for variety
        root_notes = random.choices(range(12), k=samples_per_chord)  # 0-11 for C through B
        octaves = random.choices(range(3, 6), k=samples_per_chord)  # Octaves 3-5
        
        for root, octave in zip(root_notes, octaves):
            # Generate MIDI data
            midi = generate_midi_chord(root + (octave * 12), CHORDS[chord_type])
            note_name = note_lookup(root + (octave * 12))
            
            # Save MIDI file
            mid_filename = f"{note_name}{chord_type.replace('/','inv')}_O{octave}"
            mid_filepath = wav_dir / f"{mid_filename}.mid"
            midi.save(mid_filepath)
            
            # Generate audio for each selected instrument
            selected_instruments = random.sample(list(instruments.items()), 
                                              k=min(2, len(instruments)))  # Select 1-2 instruments
            
            for preset_id, instrument_name in selected_instruments:
                wav_filename = f"{mid_filename}_{instrument_name}"
                wav_filepath = wav_dir / f"{wav_filename}.wav"
                
                # Synthesize audio
                synthesize_to_wav(
                    str(mid_filepath.absolute()),
                    str(sf_filepath.absolute()),
                    str(wav_filepath.absolute()),
                    preset_id=preset_id,
                    seconds_to_generate=duration,
                    gain=-6
                )
                
                # Create metadata entry matching chordgen's format
                json_out[wav_filename] = {
                    "root": note_name,
                    "chord_class": chord_type,
                    "billboard_notation": f"{note_name}:{chord_type}",
                    "octave": octave,
                    "instrument": instrument_name,
                    "gm_preset_id": preset_id,
                    "filename": f"{wav_filename}.wav",
                    "format": "wav",
                    "duration(s)": duration,
                    "sample_rate": 44100,
                    "bit_depth": 16
                }
            
            # Clean up MIDI file after generating all instrument versions
            os.remove(mid_filepath)
    
    # Save metadata in chord_ref.json
    __save_json(json_out, temp_path)
    
    #return json_out

    # After generating all WAVs and metadata
    wav_dir = temp_path / "wav"
    mp3_dir = temp_path / "mp3"
    
    # Convert WAVs to MP3s and update metadata
    json_out = batch_convert_wavs_to_mp3(wav_dir, mp3_dir)
    
    # Clean up WAV files if desired
    for wav_file in wav_dir.glob("*.wav"):
        wav_file.unlink()
    wav_dir.rmdir()
    
    return json_out

In [78]:
def process_audio_with_fx_chain(audio_path: Path, output_path: Path, fx_presets: dict):
    """
    Process an audio file with a specific chain of effects.
    
    Args:
        audio_path: Path to input audio file
        output_path: Path to save processed audio
        fx_presets: Dictionary of effects and their preset lists to apply
    """
    # Load audio
    waveform, sr = torchaudio.load(str(audio_path))
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    audio = waveform.numpy()
    
    # Apply effects in sequence
    processed = audio
    applied_fx = {}
    
    # For each effect type, use the first preset in its list
    # This ensures consistent processing within each genre
    for fx_type, preset_list in fx_presets.items():
        preset_name = preset_list[0]  # Use first preset from list
        
        if fx_type == 'distortion':
            dist = Distortion.Distortion(sr=sr)
            preset_params = dist.get_presets()[preset_name]
            processed = dist.distort(processed, **preset_params)
        elif fx_type == 'reverb':
            rev = Reverb.Reverb(sr=sr)
            preset_params = rev.get_presets()[preset_name]
            processed = rev.reverb(processed, **preset_params)
        elif fx_type == 'chorus':
            cho = Chorus.Chorus(sr=sr)
            preset_params = cho.get_presets()[preset_name]
            processed = cho.process(processed, **preset_params)
        elif fx_type == 'noise':
            noise = Noise.NoiseGenerator(sr=sr)
            preset_params = noise.get_presets()[preset_name]
            processed = noise.add_noise(processed, **preset_params)
            
        applied_fx[fx_type] = preset_name
    
    # Save processed audio
    processed_tensor = torch.from_numpy(processed)
    torchaudio.save(
        str(output_path),
        processed_tensor.unsqueeze(0),
        sr,
        format="mp3",
        compression=192/1000
    )
    
    return applied_fx

def generate_biased_dataset(output_path: Path, size: int, genre_weights=(0.33, 0.33, 0.34)):
    """
    Generate a dataset with intentional genre-based timbral bias.
    """
    output_path.mkdir(parents=True, exist_ok=True)
    processed_dir = output_path / "processed"
    processed_dir.mkdir(exist_ok=True)
    
    # Calculate genre-specific sample counts
    metal_count = int(size * genre_weights[0])
    pop_count = int(size * genre_weights[1])
    jazz_count = size - metal_count - pop_count
    
    all_samples = []
    
    # Generate samples for each genre using predefined FX configurations
    for genre_idx, (primary_chords, secondary_chords, instruments, count, fx_chain) in enumerate([
        (METAL_PRIMARY, METAL_SECONDARY, METAL_INSTRUMENTS, metal_count, METAL_FX),
        (POP_PRIMARY, POP_SECONDARY, POP_INSTRUMENTS, pop_count, POP_FX),
        (JAZZ_PRIMARY, JAZZ_SECONDARY, JAZZ_INSTRUMENTS, jazz_count, JAZZ_FX)
    ]):
        print(f"Processing {['metal', 'pop', 'jazz'][genre_idx]} samples...")
        
        # Select chords for this genre
        selected_chords = select_chords_for_genre(primary_chords, secondary_chords, count=count)
        
        # Generate samples
        genre_metadata = generate_chord_samples(
            output_path=output_path,
            chord_list=selected_chords,
            instruments=instruments,
            sample_count=count
        )
        
        # Process each sample with the genre-specific FX chain
        for filename, metadata in genre_metadata.items():
            input_path = output_path / "temp" / "mp3" / f"{filename}.mp3"
            output_path = processed_dir / f"proc_{filename}.mp3"
            
            # Apply FX chain
            applied_fx = process_audio_with_fx_chain(input_path, output_path, fx_chain)
            
            # Update metadata
            metadata.update({
                'genre': ['metal', 'pop', 'jazz'][genre_idx],
                'applied_fx': applied_fx,
                'processed_path': str(output_path)
            })
            all_samples.append(metadata)
    
    # Save complete dataset metadata
    with open(output_path / "dataset_metadata.json", 'w') as f:
        json.dump(all_samples, f, indent=2)
    
    # Clean up temporary files
    temp_path = output_path / "temp"
    if temp_path.exists():
        for file in temp_path.glob("*"):
            if file.is_file():
                os.remove(file)
        temp_path.rmdir()

In [79]:
# Training set with genre bias
generate_biased_dataset(
    BASE_PATH / "train",
    TRAIN_SIZE,
    genre_weights=(0.33, 0.34, 0.33)
)

Processing metal samples...


  return brown / np.std(brown)
Processing chords:   0%|          | 0/3036 [2:19:22<?, ?it/s]
Processing chords:   0%|          | 0/3036 [2:08:03<?, ?it/s]
Processing chords:   0%|          | 0/3036 [2:06:36<?, ?it/s]
Processing chords:   0%|          | 0/3036 [2:03:07<?, ?it/s]
Processing chords:   0%|          | 0/3036 [1:42:32<?, ?it/s]


ValueError: The length of the input vector x must be greater than padlen, which is 15.

In [None]:
# Generate datasets

# Training set with genre bias
generate_biased_dataset(
    BASE_PATH / "train",
    TRAIN_SIZE,
    genre_weights=(0.4, 0.35, 0.25)  # Slightly more metal to emphasize bias
)

# Validation set with same bias
generate_biased_dataset(
    BASE_PATH / "val",
    VAL_SIZE,
    genre_weights=(0.4, 0.35, 0.25)
)

# Balanced test set
generate_balanced_test_set(
    BASE_PATH / "test",
    TEST_SIZE
)

In [None]:
#import models

In [None]:
#train models on biased dataset

In [None]:
#run models on test set and compare results