In [50]:
import os
import tables
import json
import pretty_midi
import mido
import numpy as np


In [51]:
# Reference for the pretty_midi calls used in this notebook:
# https://github.com/craffel/pretty-midi/blob/main/Tutorial.ipynb

# Path to a single MIDI file from the extracted lmd_aligned data, used as a
# quick smoke test for the feature-extraction code in the cells that follow.
test_song_path = '../data/lmd_aligned/A/A/A/TRAAAZF12903CCCF6B/05f21994c71a5f881e64f45c8d706165.mid'

# Parse the file and display its first instrument to confirm it loads.
song_pm = pretty_midi.PrettyMIDI(test_song_path)
song_pm.instruments[0]

Instrument(program=90, is_drum=False, name="POLYSYNTH")

In [52]:
def feature(fp):
    """
    Summarises one MIDI file as a flat dictionary of features.

    The file is parsed with pretty_midi.PrettyMIDI and every value below is
    read from that object. Any exception raised while parsing or estimating
    (e.g. for a malformed file) is propagated to the caller.

    Parameters:
    fp (str): Midi filepath

    Output:
    (dict): Song dictionary containing midi features, with keys
        'song_len'               - end time of the song as a float
        'num_instruments'        - number of instrument tracks
        'instruments'            - list of (program number as str, is_drum)
        'key_changes'            - {str(time): float(key_number)}
        'time_signature_changes' - {float(time): "numerator/denominator"}
        'lyrics_len'             - number of lyric events
        'tempo_changes'          - {time of change: new tempo}
        'tempo'                  - single best tempo estimate as a float
        'beat_start'             - estimated time of the first beat
        'num_notes'              - note count summed over all instruments
        'pitch_class_histogram'  - list produced by get_pitch_class_histogram()

    """
    midi = pretty_midi.PrettyMIDI(fp)
    tracks = midi.instruments

    # The entries of a dict display are evaluated top to bottom, so the
    # pretty_midi calls run in exactly the order the keys are listed.
    # Per-beat arrays (get_beats / get_downbeats) and the synthesized
    # waveform (synthesize) are not part of the output.
    return {
        'song_len': float(midi.get_end_time()),
        # MIDI program numbers run from 0 to 127
        'num_instruments': len(tracks),
        'instruments': [(str(track.program), track.is_drum) for track in tracks],
        'key_changes': {
            str(change.time): float(change.key_number)
            for change in midi.key_signature_changes
        },
        # Keyed by the event time reported by pretty_midi (seconds per its
        # documentation), not by MIDI tick
        'time_signature_changes': {
            float(sig.time): f"{sig.numerator}/{sig.denominator}"
            for sig in midi.time_signature_changes
        },
        'lyrics_len': len(midi.lyrics),
        # get_tempo_changes() yields (change times, tempi); pair them up
        'tempo_changes': dict(zip(*midi.get_tempo_changes())),
        'tempo': float(midi.estimate_tempo()),
        'beat_start': midi.estimate_beat_start(),
        'num_notes': sum(len(track.notes) for track in tracks),
        'pitch_class_histogram': midi.get_pitch_class_histogram().tolist(),
    }

In [53]:
# Smoke test: extract the feature dictionary for the single sample song.
test_features = feature(test_song_path)

In [54]:
# Display the extracted feature dictionary for manual inspection.
test_features

{'song_len': 187.42576000000003,
 'num_instruments': 6,
 'instruments': [('90', False),
  ('33', False),
  ('27', False),
  ('28', False),
  ('95', False),
  ('0', True)],
 'key_changes': {},
 'time_signature_changes': {0.0: '4/4'},
 'lyrics_len': 0,
 'tempo_changes': {0.0: 105.00157502362535},
 'tempo': 209.86930482733456,
 'beat_start': 4.865998437500001,
 'num_notes': 4693,
 'pitch_class_histogram': [0.1692462987886945,
  0.0016823687752355316,
  0.15107671601615075,
  0.05114401076716016,
  0.1066621803499327,
  0.09421265141318977,
  0.08142664872139974,
  0.11877523553162854,
  0.010094212651413189,
  0.07772543741588156,
  0.11271870794078062,
  0.025235531628532974]}

In [55]:
import mido
import pretty_midi

def get_midi_features(midi_file_path):
    """
    Collects file-level and per-instrument features from one MIDI file.

    The file is parsed twice: with mido for the raw tracks and tempo
    messages, and with pretty_midi for the per-instrument note data.

    Parameters:
    midi_file_path (str): Midi filepath

    Output:
    features (dict): Dictionary with keys
        'filename'       - base name of the file without its extension
        'num_tracks'     - number of tracks reported by mido
        'ticks_per_beat' - ticks per beat reported by mido
        'tempo_changes'  - tempo value of every 'set_tempo' message, in track
                           order (mido documents this value as microseconds
                           per beat)
        'note_ranges'    - one (lowest pitch, highest pitch) tuple per
                           instrument; (None, None) if it has no notes
        'note_count'     - number of notes per instrument
        'pitch_lists'    - sorted list of distinct pitches per instrument

    """
    midi = mido.MidiFile(midi_file_path)
    pmidi = pretty_midi.PrettyMIDI(midi_file_path)

    name = os.path.splitext(os.path.basename(midi_file_path))[0]
    tempo_changes = [msg.tempo
                     for track in midi.tracks
                     for msg in track
                     if msg.type == 'set_tempo']

    note_ranges = []
    note_count = []
    pitch_lists = []
    for inst in pmidi.instruments:
        # Sorting makes the output deterministic and puts the extremes at
        # both ends of the list.
        pitches = sorted({note.pitch for note in inst.notes})
        note_count.append(len(inst.notes))
        pitch_lists.append(pitches)
        # The range is taken from the pitches actually played. The previous
        # min/max over get_pitch_class_histogram() returned histogram bin
        # weights, not pitches.
        if pitches:
            note_ranges.append((pitches[0], pitches[-1]))
        else:
            note_ranges.append((None, None))

    features = {
        'filename': name,
        'num_tracks': len(midi.tracks),
        'ticks_per_beat': midi.ticks_per_beat,
        'tempo_changes': tempo_changes,
        'note_ranges': note_ranges,
        'note_count': note_count,
        'pitch_lists': pitch_lists,
    }
    return features



In [79]:
def get_features(loaded_data, output_fp, backup_name, backup_intervals=1500):
    """
    Extracts features of midi files present in a JSON containing midi filepaths
    and writes the features to multiple JSON files

    Each input dictionary is copied, extended with the output of
    feature(entry['midi_fp']) and collected. Entries that raise during
    extraction (corrupt files, or records without a 'midi_fp' key) are
    reported on stdout and skipped. The input dictionaries themselves are
    left unmodified.

    Parameters: 
    loaded_data (list): List of dictionaries that contain a midi filepath
        under the key 'midi_fp'
    output_fp (str): Output JSON filepath for the entire dataset
    backup_name (str): Name of backup JSON files; the backup index and
        '.json' are appended to it
    backup_intervals (int): Dump data that has been extracted into a JSON file
        after every backup_intervals number of files parsed (Dump every
        1500 files parsed by default). Each backup contains everything
        extracted so far, not only the latest chunk.

    Output:
    json_data (list): List of the successfully extracted dictionaries, which
        is also what gets written to output_fp

    """
    file_count = 0          # valid files added since the last backup
    file_count_num = 0      # index used in the next backup's filename
    total_files_parsed = 0  # entries fully handled before the current one
    json_data = []  # list of dicts that will be stored as json strings
    for s in loaded_data:
        # Dump data into a JSON backup once backup_intervals valid files
        # have accumulated since the previous backup
        if file_count == backup_intervals:
            print("Creating backup")
            print('Valid Data Size: ', len(json_data))
            with open(backup_name + str(file_count_num) + '.json', 'w') as file:
                json.dump(json_data, file, indent=4)
            file_count = 0
            file_count_num += 1
        # Some files will throw exceptions when extracting data due to 
        # corruption
        midi_fp = None  # stays None if the record has no usable 'midi_fp'
        try:
            midi_fp = s['midi_fp']
            # Work on a shallow copy so the caller's dictionaries (which may
            # be shared between overlapping slices) are not modified
            final_sd = dict(s)
            # Update the copy with extracted features
            final_sd.update(feature(midi_fp))
            # Add dictionary with features to list of all data
            json_data.append(final_sd)
            file_count += 1
        except Exception as e:
            # Report using the path captured above; looking the key up again
            # here would raise a second time for a record without 'midi_fp'
            # and abort the whole run
            print("Corrupt file: ", midi_fp)
            print("Exception: ", e)
            print("Valid Data Size: ", len(json_data))
            print("Total Files Parsed: ", total_files_parsed)
        total_files_parsed += 1
    with open(output_fp, 'w') as file:
        json.dump(json_data, file, indent=4)
    print("Finished. Final valid data size: ", len(json_data))
    return json_data

In [80]:
# Load the song metadata that get_features() iterates over (expected to be a
# list of dictionaries, each holding a 'midi_fp' path). The context manager
# closes the file even if json.load() raises.
with open('../data/loaded_data.json') as f:
    midi_meta = json.load(f)

In [78]:
# Small sample (first 20 entries) for a quick trial run of get_features().
midi_mini = midi_meta[:20]

In [75]:
# Trial run on the 20-entry sample; writes midi_test.json (and any
# midi_test_chunk_<n>.json backups) into the current working directory.
# NOTE(review): the recorded output shows a backup after 10 valid files, which
# the current default backup_intervals=1500 would not produce. The output is
# presumably from an earlier version of get_features; re-run to refresh it.
get_feature_testing = get_features(midi_mini, 'midi_test.json', 'midi_test_chunk_')

Creating backup
Valid Data Size:  10


In [82]:
# Full run over every entry in midi_meta. Results go to data/midi_full.json,
# with cumulative backups written as data/midi_chunk_<n>.json.
# NOTE(review): these outputs are written under data/ while the inputs are
# read from ../data/ - confirm the data/ directory exists and that the
# different location is intended.
get_all_features = get_features(midi_meta, 'data/midi_full.json', 'data/midi_chunk_')

Corrupt file:  ../data/lmd_aligned/A/B/Y/TRABYJG128F424D12B/2cdb25a9abc795d12679c3f6abc2b212.mid
Exception:  data byte must be in range 0..127
Valid Data Size:  88
Total Files Parsed:  88
Corrupt file:  ../data/lmd_aligned/A/H/E/TRAHEJP128F930F4C5/1d2b92aa9e454280cca21fae99ebce8e.mid
Exception:  data byte must be in range 0..127
Valid Data Size:  347
Total Files Parsed:  348
Corrupt file:  ../data/lmd_aligned/A/M/T/TRAMTGA128F4274C93/2d5f3ae1620e208272d7c030d8413672.mid
Exception:  data byte must be in range 0..127
Valid Data Size:  616
Total Files Parsed:  618
Corrupt file:  ../data/lmd_aligned/A/N/P/TRANPDP128F42902BA/546df09d78a32141369fcc17ff51cec0.mid
Exception:  Could not decode key with 3 flats and mode 255
Valid Data Size:  656
Total Files Parsed:  659
Corrupt file:  ../data/lmd_aligned/A/Q/D/TRAQDUR128F931B4E9/7d54d2e5ef58eeb5b23cc09bae22cf1b.mid
Exception:  data byte must be in range 0..127
Valid Data Size:  795
Total Files Parsed:  799
Corrupt file:  ../data/lmd_aligned/A/S/