In [2]:
pip install pretty_midi

Note: you may need to restart the kernel to use updated packages.


In [15]:
import csv
import pretty_midi

# Translation table from the CSV's instrument codes to the program number that
# is handed to pretty_midi.Instrument. It is the identity for codes 1 through 99;
# any other code has no entry.
# NOTE(review): pretty_midi documents `program` as 0-based (0-127). If the CSV
# codes are 1-based General MIDI numbers, each value should be `code - 1` -- confirm.
instrument_mapping = {code: code for code in range(1, 100)}

# Every mapped instrument code gets the same fixed note velocity of 127.
velocity_mapping = dict.fromkeys(instrument_mapping, 127)

# Note lengths keyed by the CSV's `note_value` label. Values are in beats with a
# quarter note equal to 1.0; create_midi multiplies them by seconds-per-beat.
note_durations = dict((
    # plain values
    ('whole', 4.0),
    ('half', 2.0),
    ('quarter', 1.0),
    ('eighth', 0.5),
    ('sixteenth', 0.25),
    # dotted values (1.5x the plain length)
    ('dotted_half', 3.0),
    ('dotted_quarter', 1.5),
    ('dotted_eighth', 0.75),
))

def convert_samples_to_seconds(sample_time, sample_rate=44100):
    """Translate a position counted in audio samples into seconds.

    Args:
        sample_time: Offset expressed as a number of samples.
        sample_rate: Samples per second; defaults to 44100.

    Returns:
        ``sample_time`` divided by ``sample_rate``.
    """
    seconds = sample_time / sample_rate
    return seconds

def scan_instruments(csv_file):
    """Collect the distinct instrument codes in a CSV and number them.

    Args:
        csv_file: Path to a CSV file with a header row that includes an
            ``instrument`` column holding integer codes.

    Returns:
        dict mapping each distinct instrument code to a slot number in 0-15.
        Slots are handed out in ascending code order and wrap around after
        16 instruments (the 17th code shares slot 0 with the first, etc.).
        A file with a header but no data rows yields an empty dict.

    Raises:
        KeyError: if a row has no ``instrument`` column.
        ValueError: if an ``instrument`` value is not an integer.
    """
    with open(csv_file, newline='') as csvfile:
        unique_instruments = {int(row['instrument']) for row in csv.DictReader(csvfile)}

    # Sort before numbering: iterating the bare set would tie the slot numbers
    # (and the dict's order) to hash-table layout rather than to the data.
    return {
        instrument: index % 16
        for index, instrument in enumerate(sorted(unique_instruments))
    }

def create_midi(csv_file, output_file='output_musicnet_pretty_midi.mid',
                sample_rate=44100, tempo_bpm=120.0):
    """Convert a CSV of note events into a MIDI file.

    One pretty_midi track is created for every instrument code that occurs in
    the CSV and has an entry in ``instrument_mapping``. Rows whose instrument
    code has no entry are skipped.

    Each row must provide the columns ``start_time``, ``end_time``,
    ``instrument``, ``note`` and ``note_value``. Timing rules:

    * ``start_time`` / ``end_time`` are sample offsets and are converted to
      seconds with ``sample_rate``.
    * If ``note_value`` is a key of ``note_durations`` the note lasts that many
      beats at ``tempo_bpm``; otherwise it lasts ``end - start`` seconds.
    * Durations shorter than 0.1 s (including zero or negative ones) are
      raised to 0.1 s.
    * If the same instrument already has a note recorded at exactly the same
      start, the new note is pushed 0.01 s later.

    Args:
        csv_file: Path of the CSV file to read.
        output_file: Path of the MIDI file to write.
        sample_rate: Samples per second used for the time conversion.
        tempo_bpm: Tempo used to turn ``note_durations`` beats into seconds;
            also stored as the file's initial tempo.

    Raises:
        KeyError: if a required column is missing from a row.
        ValueError: if a numeric column does not hold an integer.
    """
    min_duration = 0.1     # seconds; shortest note that is ever written
    stagger_amount = 0.01  # seconds; offset applied to a colliding start time
    seconds_per_beat = 60.0 / tempo_bpm

    midi_data = pretty_midi.PrettyMIDI(initial_tempo=tempo_bpm)

    # Look tracks up by instrument code. Indexing midi_data.instruments with the
    # modulo-16 slot from scan_instruments breaks as soon as one code is missing
    # from instrument_mapping (positions shift) or more than 16 codes occur
    # (slots wrap onto another instrument's track).
    tracks_by_code = {}
    for instrument_code in scan_instruments(csv_file):
        if instrument_code not in instrument_mapping:
            continue
        track = pretty_midi.Instrument(
            program=instrument_mapping[instrument_code], is_drum=False
        )
        midi_data.instruments.append(track)
        tracks_by_code[instrument_code] = track

    # (instrument_code, start_seconds) pairs that already carry a note.
    taken_starts = set()

    with open(csv_file, newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            instrument_code = int(row['instrument'])
            track = tracks_by_code.get(instrument_code)
            if track is None:
                continue  # instrument has no program mapping

            start_seconds = convert_samples_to_seconds(int(row['start_time']), sample_rate)
            end_seconds = convert_samples_to_seconds(int(row['end_time']), sample_rate)
            pitch = int(row['note'])
            note_value = row['note_value']

            # A recognised symbolic note value wins over the measured length.
            if note_value in note_durations:
                duration = note_durations[note_value] * seconds_per_beat
            else:
                duration = end_seconds - start_seconds
            duration = max(duration, min_duration)

            # NOTE(review): the collision key ignores pitch, so the second note
            # of a chord is also shifted, and the shift is applied only once --
            # confirm this is the intended behaviour.
            if (instrument_code, start_seconds) in taken_starts:
                start_seconds += stagger_amount
            taken_starts.add((instrument_code, start_seconds))

            track.notes.append(
                pretty_midi.Note(
                    velocity=velocity_mapping.get(instrument_code, 100),
                    pitch=pitch,
                    start=start_seconds,
                    end=start_seconds + duration,
                )
            )

    midi_data.write(output_file)

def print_instruments_playing(midi_file_path):
    """Parse a MIDI file and walk the notes that start in a fixed time window.

    Despite the name, nothing is printed any more: the reporting step has been
    removed, so the only observable effect is that the file is loaded through
    pretty_midi.PrettyMIDI (which fails if the file cannot be read).

    Args:
        midi_file_path: Path of the MIDI file to load.

    Returns:
        None.
    """
    # Window bounds compared against note.start: lower inclusive, upper exclusive.
    window_start, window_end = 220, 227

    loaded = pretty_midi.PrettyMIDI(midi_file_path)

    for track in loaded.instruments:
        for event in track.notes:
            if window_start <= event.start < window_end:
                continue  # reporting intentionally disabled

# Convert the generated CSV to MIDI, then re-open the result. Order matters:
# create_midi writes 'output_musicnet_pretty_midi.mid', which the second call reads.
create_midi('1788_softmax_generated.csv')
print_instruments_playing('output_musicnet_pretty_midi.mid')
