In [1]:
import pretty_midi
import numpy as np
import os

In [6]:
def extract_guitar_and_drums(midi_file):
    """This function extracts the guitar and drum tracks from a midi file.
       The input is a path to a midi file (for example: 'raw_data/song_name.mid') in string format
       The output is a dictionary with the song name, guitar track and drum track"""
    
    mid = pretty_midi.PrettyMIDI(midi_file)
    
    guitars = []
    lengths_guitar = []
    drums = []
    lengths_drums = []
    
    for instrument in mid.instruments:
        if instrument.is_drum:
            drums.append(instrument)
            lengths_drums.append(len(instrument.notes))

        if (instrument.program >= 25) and (instrument.program <= 31):
            guitars.append(instrument)
            lengths_guitar.append(len(instrument.notes))

    drum_track = drums[lengths_drums.index(max(lengths_drums))]
    guitar_track = guitars[lengths_guitar.index(max(lengths_guitar))]

    song_title = os.path.splitext(os.path.basename(midi_file))[0]
    
    

        
    song_dict = {'title': song_title,
                 'down beats': mid.get_downbeats(),
                 'guitar': guitar_track,
                 'drums': drum_track
                }
    return song_dict
        
    

In [7]:
song_dict = extract_guitar_and_drums('raw_data/Metallica - Master Of Puppets.mid')

In [8]:
song_dict

{'title': 'Metallica - Master Of Puppets',
 'down beats': array([  0.       ,   1.090908 ,   2.181816 ,   3.272724 ,   4.363632 ,
          5.45454  ,   6.545448 ,   7.636356 ,   8.727264 ,   9.818172 ,
         10.90908  ,  11.999988 ,  13.090896 ,  14.181804 ,  15.272712 ,
         16.36362  ,  17.454528 ,  18.545436 ,  19.636344 ,  20.727252 ,
         21.81816  ,  22.909068 ,  23.999976 ,  25.090884 ,  26.181792 ,
         27.2727   ,  28.363608 ,  29.454516 ,  30.545424 ,  31.636332 ,
         32.72724  ,  33.818148 ,  34.909056 ,  35.999964 ,  37.090872 ,
         38.18178  ,  39.272688 ,  40.363596 ,  41.454504 ,  42.545412 ,
         43.63632  ,  44.727228 ,  45.818136 ,  46.909044 ,  47.999952 ,
         49.09086  ,  50.181768 ,  51.272676 ,  52.363584 ,  53.454492 ,
         54.1363095,  55.2272175,  56.3181255,  57.4090335,  58.090851 ,
         59.181759 ,  60.272667 ,  61.363575 ,  62.0453925,  63.1363005,
         64.2272085,  65.3181165,  65.999934 ,  67.090842 ,  68.181