In [1]:
import random
import os
import numpy as np
from midiutil import MIDIFile
from pydub import AudioSegment
import pretty_midi
from sf2_loader import sf2_loader
from tqdm import trange
import soundfile as sf
import sounddevice as sd


def generate_chord_progression():
    """Pick one four-chord progression at random.

    Returns:
        A new list of four chord-name strings (e.g. ``["C", "G", "Am", "F"]``)
        drawn with ``random.choice``, so the result is reproducible after
        ``random.seed``.
    """
    # One progression per string, chords separated by spaces.
    catalogue = (
        "C G Am F",
        "C Am F G",
        "G Em C D",
        "Am F C G",
        "D A Bm G",
        "Em C G D",
        "F Bb C Dm",
        "Am Dm G C",
    )
    candidates = [entry.split() for entry in catalogue]
    return random.choice(candidates)


def generate_midi(filename, chord_progression, instrument_program):
    """Write a single-track MIDI file that plays each chord as a triad.

    Every chord in ``chord_progression`` is rendered as three simultaneous
    notes (root, third, fifth) lasting ``duration`` beats, one chord after
    another, at 120 BPM on channel 0.

    Args:
        filename: Path of the MIDI file to create (overwritten if present).
        chord_progression: Iterable of chord names; each must be a key of
            the root-note table below.
        instrument_program: MIDI program number passed to the program
            change event at time 0.

    Raises:
        KeyError: If a chord name is not in the root-note table.
    """
    track = 0
    channel = 0
    volume = 100
    duration = 1  # beats per chord

    # MIDI note number of each chord's root.
    chord_to_note = {
        "C": 60,
        "G": 67,
        "Am": 69,
        "F": 65,
        "Em": 64,
        "D": 62,
        "Bm": 71,
        "Bb": 70,
        "Dm": 62,
        "A": 69,
    }

    midi = MIDIFile(1)
    midi.addTrackName(track, 0, "Sample Track")
    midi.addTempo(track, 0, 120)
    midi.addProgramChange(track, channel, 0, instrument_program)

    time = 0
    for chord in chord_progression:
        root = chord_to_note[chord]
        # Minor chords are spelled with a trailing "m": minor third is 3
        # semitones above the root, major third is 4. The fifth is always 7.
        third = root + (3 if chord.endswith("m") else 4)
        for pitch in (root, third, root + 7):
            midi.addNote(track, channel, pitch, time, duration, volume)
        # Advance by the chord length so chords stay back-to-back even if
        # `duration` is changed.
        time += duration

    with open(filename, "wb") as output_file:
        midi.writeFile(output_file)


def midi_to_wav(midi_file, wav_file, instrument_program, fs, sf2_path=None):
    """Render a MIDI file to a peak-normalized WAV file via FluidSynth.

    Args:
        midi_file: Path of the MIDI file to read.
        wav_file: Path of the WAV file to write.
        instrument_program: MIDI program number forced onto every
            instrument in the file before synthesis.
        fs: Sampling rate, used both for synthesis and for the WAV header.
        sf2_path: Path (string) to a ``.sf2`` SoundFont. ``None`` lets
            pretty_midi fall back to its bundled default SoundFont.

    Note:
        Synthesis requires the FluidSynth system library and its Python
        binding to be installed.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_file)

    # Override whatever program the file declares.
    for instrument in midi_data.instruments:
        instrument.program = instrument_program

    # pretty_midi expects `sf2_path` to be a filesystem path (or None),
    # not a SoundFont loader object.
    audio = midi_data.fluidsynth(fs=fs, sf2_path=sf2_path)

    # Peak-normalize to [-1, 1]. Skip when there is nothing to scale
    # (empty or all-zero signal) to avoid an error or NaN samples.
    peak = np.abs(audio).max() if audio.size else 0.0
    if peak > 0:
        audio = audio / peak

    sf.write(wav_file, audio, fs)


def mix_wav_files(wav_files, output_file):
    """Overlay several WAV files into one and export the result as WAV.

    Args:
        wav_files: Iterable of paths to the WAV files to mix.
        output_file: Path of the mixed WAV file to write.

    Raises:
        ValueError: If ``wav_files`` is empty.
    """
    segments = [AudioSegment.from_wav(path) for path in wav_files]
    if not segments:
        raise ValueError("wav_files must contain at least one WAV file to mix")

    # overlay() returns a segment as long as the one it is called on and
    # cuts off anything longer, so start from the longest input to keep
    # every stem's full length in the mix.
    base_index = max(range(len(segments)), key=lambda idx: len(segments[idx]))
    mix = segments[base_index]
    for idx, segment in enumerate(segments):
        if idx != base_index:
            mix = mix.overlay(segment)

    mix.export(output_file, format="wav")


def generate_dataset(
    random_seed, n_data, fs=44100, output_dir="dataset", n_instruments=3
):
    """Generate mixed/isolated instrument WAV pairs for source separation.

    For each sample a random chord progression is rendered once per
    randomly selected instrument, and the per-instrument WAV files are
    then mixed into a single WAV file.

    Args:
        random_seed: Seed applied to both ``random`` and ``numpy.random``.
        n_data: Number of mixed samples to generate.
        fs: Sampling rate of the rendered audio.
        output_dir: Directory that receives all generated files; created
            if it does not exist.
        n_instruments: How many distinct instruments are sampled (without
            replacement) for each mixed sample.

    Returns:
        A dict with ``"train_input"`` (one mixed WAV path per sample) and
        ``"train_output"`` (a flat list of the per-instrument WAV paths,
        ``n_instruments`` consecutive entries per sample).
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    os.makedirs(output_dir, exist_ok=True)

    dataset = {"train_input": [], "train_output": []}

    # (display name, MIDI program number)
    instruments = [
        ("Acoustic Grand Piano", 0),
        ("Violin", 40),
        ("Flute", 73),
        ("Acoustic Guitar (nylon)", 24),
        ("Trumpet", 56),
        ("Clarinet", 71),
        ("Cello", 42),
        ("Electric Bass (finger)", 33),
    ]

    for i in trange(n_data):
        chord_progression = generate_chord_progression()
        selected_instruments = random.sample(instruments, n_instruments)

        wav_files = []
        for j, (_name, instrument_program) in enumerate(selected_instruments):
            midi_file = f"{output_dir}/temp_{i}_{j}.mid"
            wav_file = f"{output_dir}/instrument_{i}_{j}.wav"

            try:
                generate_midi(midi_file, chord_progression, instrument_program)
                midi_to_wav(midi_file, wav_file, instrument_program, fs)
            finally:
                # The MIDI file is only an intermediate artifact; remove it
                # even when synthesis fails so no temp files are left behind.
                if os.path.exists(midi_file):
                    os.remove(midi_file)

            wav_files.append(wav_file)
            dataset["train_output"].append(wav_file)

        mixed_wav = f"{output_dir}/mixed_{i}.wav"
        mix_wav_files(wav_files, mixed_wav)
        dataset["train_input"].append(mixed_wav)

    return dataset


# Build a small dataset with fixed settings and report how many files
# were produced.
random_seed = 42  # seed for reproducible chord/instrument choices
n_data = 1  # number of mixed samples to generate
fs = 44100  # sampling rate
dataset = generate_dataset(random_seed=random_seed, n_data=n_data, fs=fs)

n_input_samples = len(dataset["train_input"])
n_output_samples = len(dataset["train_output"])

print("Dataset generated successfully.")
print(f"Number of input samples: {n_input_samples}")
print(f"Number of output samples: {n_output_samples}")


# # Optional: Play a sample to check the sound
# def play_audio(file_path):
#     data, samplerate = sf.read(file_path)
#     sd.play(data, samplerate)
#     sd.wait()


# # Uncomment the following lines to play a sample

# play_audio(dataset["train_input"][0])  # Play a mixed sample
# for output_file in dataset["train_output"][:3]:  # Play first 3 instrument samples
#     play_audio(output_file)



ImportError: Couldn't find the FluidSynth library.