# Import libraries

In [1]:
# Mount Google Drive so the project folder is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os
# Commented out; presumably used once to create the project folder.
# os.makedirs("/content/drive/MyDrive/Drum_SSM/drum_generation_with_ssm")
# Every relative path used below ('./pre_processed_data', './input_midi', './output_midi', ...)
# resolves against this directory.
os.chdir("/content/drive/MyDrive/Drum_SSM/drum_generation_with_ssm")
# List the project folder to confirm the expected inputs are present.
!ls

Mounted at /content/drive
drum_generator_model
generated_samples.zip
input_midi
misc
model_out_result_add_note_00.pkl
model_out_result_add_note_03.pkl
model_out_result_add_note_06.pkl
model_out_result_add_note_12.pkl
model_out_result_add_note_20.pkl
output_midi
pre_processed_data
__pycache__
README.md
ssm_generator_model
step_1_midi_data_preprocessing.ipynb
step_1_midi_data_preprocessing_old.ipynb
step_2_generate_drum_ssm_from_melodic_ssm.ipynb
step_3_extract_bar_selection_info.ipynb
step_4_generate_drum_pattern.ipynb
step_5_convert_data_into_MIDI.ipynb
tf_ops.py
tf_util.py
track22_bj.png


In [2]:
!pip install librosa
!pip install imageio
!pip install soundfile
!pip install pretty_midi
!pip install mir_eval
!pip install dill
!pip install pypianoroll
!pip install midiutil
!pip install tf-slim
!apt-get install fluidsynth

Collecting pretty_midi
[?25l  Downloading https://files.pythonhosted.org/packages/bc/8e/63c6e39a7a64623a9cd6aec530070c70827f6f8f40deec938f323d7b1e15/pretty_midi-0.2.9.tar.gz (5.6MB)
[K     |████████████████████████████████| 5.6MB 6.3MB/s 
Collecting mido>=1.1.16
[?25l  Downloading https://files.pythonhosted.org/packages/b5/6d/e18a5b59ff086e1cd61d7fbf943d86c5f593a4e68bfc60215ab74210b22b/mido-1.2.10-py2.py3-none-any.whl (51kB)
[K     |████████████████████████████████| 51kB 7.7MB/s 
Building wheels for collected packages: pretty-midi
  Building wheel for pretty-midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty-midi: filename=pretty_midi-0.2.9-cp37-none-any.whl size=5591958 sha256=673af2e1704e4a5078dd6851317754b0733a1f4b16ee99fae85d293293be7c62
  Stored in directory: /root/.cache/pip/wheels/4c/a1/c6/b5697841db1112c6e5866d75a6b6bf1bef73b874782556ba66
Successfully built pretty-midi
Installing collected packages: mido, pretty-midi
Successfully installed mido-1.2.10 pretty-mi

In [3]:
import librosa, IPython, datetime, time, os, sys, copy, glob, pickle
import numpy as np
from time import gmtime, strftime
from IPython.display import Image
import pypianoroll
import matplotlib.pyplot as plt
%matplotlib inline

# Log basic environment details (time, Python version, working directory) for this run.
now_str = datetime.datetime.now().strftime('%Y/%m/%d  %H:%M:%S')
python_ver = sys.version.partition('\n')[0].partition(' ')[0]
work_dir = os.getcwd() + '/'

print("[info] Current Time:     " + now_str)
print("[info] Python Version:   " + python_ver)
print("[info] Working Dir:      " + work_dir)

[info] Current Time:     2021/06/30  05:39:03
[info] Python Version:   3.7.10
[info] Working Dir:      /content/drive/My Drive/Drum_SSM/drum_generation_with_ssm/


# Helper: make sure an output file's directory exists

In [4]:
def ensure_dir(file_path):
    """Create the parent directory of ``file_path`` if it does not exist yet.

    Args:
        file_path: path of a file that is about to be written.

    A bare file name (no directory part) refers to the current directory, which
    always exists, so nothing is created in that case. ``exist_ok=True`` keeps the
    call safe if the directory appears between the check and the creation.
    """
    ed_directory = os.path.dirname(file_path)
    # os.makedirs('') would raise, so skip when there is no directory component.
    if ed_directory:
        os.makedirs(ed_directory, exist_ok=True)

# Read all song/bar index codes

In [5]:
# Load the "<song>_<bar>" code (e.g. '00000_000') of every bar in the dataset.
# NOTE(review): pickle.load can execute code from the file; only load trusted data.
bar_code_path = './pre_processed_data/abs_bar_idx_str_list.pkl'
with open(bar_code_path, 'rb') as pkl_file:
    abs_bar_idx_str_list = pickle.load(pkl_file)

print('[info] List of [song/bar] data is loaded.')
print('[info] Total bars: {}'.format(len(abs_bar_idx_str_list)))
print('[info] First 5 bar code: {}'.format(abs_bar_idx_str_list[:5]))
print('[info] Last  5 bar code: {}'.format(abs_bar_idx_str_list[-5:]))


# Sorted, de-duplicated song ids: the part of each bar code before the first '_'.
song_index_in_list = sorted({bar_code.split('_')[0] for bar_code in abs_bar_idx_str_list})

def get_test_song_abs_idx(pick_song_index, bar_codes=None, song_ids=None):
    """Return the absolute ``[start, end)`` bar range of one song.

    Args:
        pick_song_index: position of the song in ``song_ids``.
        bar_codes: list of "<song>_<bar>" codes; defaults to the global
            ``abs_bar_idx_str_list``.
        song_ids: list of song ids; defaults to the global ``song_index_in_list``.

    Returns:
        ``[start, end]`` where ``start`` is the position of the song's first bar in
        ``bar_codes`` and ``end`` is one past the position of its last bar.

    Raises:
        IndexError: if ``pick_song_index`` is out of range or the song has no bars.
    """
    if bar_codes is None:
        bar_codes = abs_bar_idx_str_list
    if song_ids is None:
        song_ids = song_index_in_list

    song_id = song_ids[pick_song_index]
    # Match on the whole prefix before '_' (the same rule used to build the song id
    # list), not a fixed 5-character slice, so ids of any length work.
    bar_positions = [pos for pos, code in enumerate(bar_codes)
                     if code.split('_')[0] == song_id]

    # NOTE(review): assumes a song's bars are contiguous in bar_codes, since callers
    # use the result as a slice [start:end].
    return [bar_positions[0], bar_positions[-1] + 1]


# Print the absolute [start, end) bar range of every song.
# The song count comes from the data instead of a hard-coded 24.
for get_song_idx in range(len(song_index_in_list)):
    song_bar_start, song_bar_end = get_test_song_abs_idx(get_song_idx)
    print('[info] Song idx: {:2d},   Start:{:4d},   End: {}'.format(get_song_idx,
                                                                    song_bar_start,
                                                                    song_bar_end))

[info] List of [song/bar] data is loaded.
[info] Total bars: 2311
[info] First 5 bar code: ['00000_000', '00000_001', '00000_002', '00000_003', '00000_004']
[info] Last  5 bar code: ['00023_141', '00023_142', '00023_143', '00023_144', '00023_145']
[info] Song idx:  0,   Start:   0,   End: 101
[info] Song idx:  1,   Start: 101,   End: 182
[info] Song idx:  2,   Start: 182,   End: 273
[info] Song idx:  3,   Start: 273,   End: 367
[info] Song idx:  4,   Start: 367,   End: 453
[info] Song idx:  5,   Start: 453,   End: 537
[info] Song idx:  6,   Start: 537,   End: 638
[info] Song idx:  7,   Start: 638,   End: 729
[info] Song idx:  8,   Start: 729,   End: 817
[info] Song idx:  9,   Start: 817,   End: 896
[info] Song idx: 10,   Start: 896,   End: 1036
[info] Song idx: 11,   Start:1036,   End: 1104
[info] Song idx: 12,   Start:1104,   End: 1178
[info] Song idx: 13,   Start:1178,   End: 1240
[info] Song idx: 14,   Start:1240,   End: 1349
[info] Song idx: 15,   Start:1349,   End: 1416
[info] Son

# Reload all model test results

In [6]:
# Reload every model result pickle (one file per add-note version) and threshold
# element [2] of each package at 0.5 into a 0/1 array of the same dtype.
# NOTE(review): element [2] is presumably the per-step drum onset probabilities -- confirm
# against the notebook that wrote these files.
model_result_flist = sorted(glob.glob('./model_out_result_add_note_*.pkl', recursive=True))

model_result_binary_list = []
add_note_ver_list = []

for result_path in model_result_flist:
    with open(result_path, 'rb') as result_pkl:
        result_pkg = pickle.load(result_pkl)

    prob_ary = result_pkg[2]
    binary_ary = (prob_ary > 0.5).astype(prob_ary.dtype)

    print("[info] '{}' is reloaded.".format(result_path))
    print('[info] Data shape: {}'.format(binary_ary.shape))

    model_result_binary_list.append(binary_ary)
    # Version tag = last two characters before the file extension (e.g. '..._03.pkl' -> '03').
    add_note_ver_list.append(result_path.rsplit('.', 2)[-2][-2:])

print('\n[info] {} files are reloaded.'.format(len(model_result_flist)))


[info] './model_out_result_add_note_00.pkl' is reloaded.
[info] Data shape: (2311, 46, 16)
[info] './model_out_result_add_note_03.pkl' is reloaded.
[info] Data shape: (2311, 46, 16)
[info] './model_out_result_add_note_06.pkl' is reloaded.
[info] Data shape: (2311, 46, 16)
[info] './model_out_result_add_note_12.pkl' is reloaded.
[info] Data shape: (2311, 46, 16)
[info] './model_out_result_add_note_20.pkl' is reloaded.
[info] Data shape: (2311, 46, 16)

[info] 5 files are reloaded.


# Reload all original MIDI objects

In [7]:
class midi_track(object):
    """Per-song record holding MIDI data and drum-bar bookkeeping.

    Every field starts out empty (``file_name`` = "", ``tempo`` = 0, all others a
    fresh empty list).
    NOTE(review): keep the class name unchanged -- presumably the objects stored in
    proc_midi_object.pkl are instances of this class, and pickle looks the class up
    by name when loading them.
    """

    def __init__(self):
        self.file_name = ""
        # Fields prefixed "pmidi_" (presumably pretty_midi data -- confirm in step 1).
        for attr_name in ('pmidi_data', 'pmidi_all_tracks_data',
                          'pmidi_no_drum_data', 'pmidi_drum_only_data'):
            setattr(self, attr_name, [])
        self.tempo = 0
        # Bar / drum bookkeeping fields.
        for attr_name in ('downbeats_list_fixed', 'bar_range_list_fixed',
                          'drum_bar_list', 'drum_bar_list_bin', 'drum_bar_note_num'):
            setattr(self, attr_name, [])

# Load the list of preprocessed per-song MIDI objects (presumably written by the
# step-1 preprocessing notebook).
# NOTE(review): pickle.load can execute code from the file; only load trusted data.
obj_file_name = './pre_processed_data/proc_midi_object.pkl'
with open(obj_file_name, 'rb') as obj_pkl:
    midi_obj_list = pickle.load(obj_pkl)

n_midi_objs = len(midi_obj_list)
print('[info] All MIDI objects: {}'.format(n_midi_objs))

[info] All MIDI objects: 24


# Define the function that rebuilds drum data in the original MIDI piano-roll shape (96, 128)

In [8]:
# The 46 drum pitches kept by the model; together they cover ~99% of all drum-instrument
# occurrences. Row i of a model output bar is written to pitch selected_inst_list_46[i]
# (used as a piano-roll column index in get_odrum_shape).
selected_inst_list_46 = [
    27, 28, 33, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
    51, 53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69, 70, 73,
    74, 75, 76, 77, 80, 81, 82, 83, 85, 87,
]
n_kept_insts = len(selected_inst_list_46)
print('[info] # of keeped Insts: {}'.format(n_kept_insts))

def get_odrum_shape(drum_ary_in, inst_list=None, n_ticks=96, n_pitches=128,
                    ticks_per_step=6, velocity=100, threshold=0.5):
    """Rebuild one bar of model output as a drum piano roll in the original MIDI shape.

    Args:
        drum_ary_in: 2-D array of shape (n_insts, n_steps). Row i corresponds to pitch
            ``inst_list[i]`` and column j to time step j.
        inst_list: pitch (piano-roll column) for each row; defaults to the global
            ``selected_inst_list_46``.
        n_ticks, n_pitches: shape of the returned piano roll (default (96, 128)).
        ticks_per_step: tick distance between consecutive steps (default 6, so 16
            steps span 96 ticks).
        velocity: value written at each onset (default 100).
        threshold: cells strictly greater than this are treated as onsets.

    Returns:
        float ndarray of shape (n_ticks, n_pitches). It is zero everywhere except
        ``velocity`` at ``[step * ticks_per_step, inst_list[row]]`` for every onset cell.

    Raises:
        IndexError: if an onset's row has no entry in ``inst_list`` or its tick or
            pitch falls outside the output.
    """
    if inst_list is None:
        inst_list = selected_inst_list_46
    odrum_data = np.zeros([n_ticks, n_pitches])
    # Locate every onset cell at once instead of looping over all rows x steps in Python.
    rows, steps = np.nonzero(np.asarray(drum_ary_in) > threshold)
    pitches = np.asarray(inst_list, dtype=np.intp)[rows]
    odrum_data[steps * ticks_per_step, pitches] = velocity
    return odrum_data

print ('[info] get_odrum_shape is defined.')

[info] # of keeped Insts: 46
[info] get_odrum_shape is defined.


# Load the list of original MIDI files

In [9]:
def list_all_tracks_midi(root='./input_midi'):
    """Return the paths of all "*all_tracks.mid" files under ``root`` (searched recursively).

    The returned paths are real, unmodified paths, so files or folders whose names
    contain spaces can still be opened. Only the sort key ignores spaces; that keeps
    the ordering this notebook relies on to pair file i with song index i.
    """
    mid_paths = glob.glob(os.path.join(root, '**', '*.mid'), recursive=True)
    all_tracks_paths = [p for p in mid_paths if "all_tracks.mid" in p]
    # Ties (paths that are equal once spaces are removed) are broken by the full path.
    return sorted(all_tracks_paths, key=lambda p: (p.replace(' ', ''), p))


all_tracks_mid_flist = list_all_tracks_midi()
print ('[info] Total files: {}'.format(len(all_tracks_mid_flist)))
for x in all_tracks_mid_flist:
    print ('  ' + x)

[info] Total files: 24
  ./input_midi/Beatles_20/02_all_tracks/01_A_Hard_Days_Night_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/02_Anna_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/03_Back_In_The_USSR_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/04_Cant_Buy_Me_Love_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/05_Hold_Me_Tight_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/06_I_Call_Your_Name_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/07_I_Wanna_Be_Your_Man_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/08_Money_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/09_The_Word_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/10_Free_As_A_Bird_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/11_Hey_Jude_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/12_Little_Child_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/13_Hey_Bulldog_all_tracks.mid
  ./input_midi/Beatles_20/02_all_tracks/14_Lovely_R

# Loop over all songs and save the corresponding MIDI files

In [16]:
# For each song: load its original multi-track MIDI, append one generated drum track per
# add-note version, fix note velocities, and write the result to ./output_midi/.
# Songs are paired with MIDI files purely by position, so the two lists must line up.
# NOTE(review): this also assumes both lists are in the same song order -- confirm against step 1.
assert len(all_tracks_mid_flist) == len(song_index_in_list), \
    'expected one *all_tracks.mid file per song: {} files vs {} songs'.format(
        len(all_tracks_mid_flist), len(song_index_in_list))

for x_idx, pick_song_index in enumerate(song_index_in_list):

    print ('[info] Start processing song: {} ...'.format(x_idx+1))

    # Absolute [start, end) bar range of this song.
    abs_idx_start, abs_idx_end = get_test_song_abs_idx(x_idx)
    abs_song_idx = pick_song_index

    print ('[info] Song index: {}'.format(abs_song_idx))
    print ('[info] Song bars: {}'.format(abs_idx_end - abs_idx_start))
    # Raw string: the backslashes are literal separators in the message, not escape sequences.
    print (r'[info] start\end bar index:  {}\{}'.format(abs_idx_start, abs_idx_end))

    # One drum piano roll per add-note version: convert each bar of this song to the
    # (96, 128) shape and stack the bars along the time axis.
    model_darr_odrm_ary_list = [
        np.concatenate([get_odrum_shape(bar_ary)
                        for bar_ary in binary_ary[abs_idx_start:abs_idx_end]], axis=0)
        for binary_ary in model_result_binary_list]

    # Load the original song. NOTE(review): resolution=24 presumably matches the 96 ticks
    # per bar produced by get_odrum_shape (4 beats x 24 ticks) -- confirm.
    original_midi_file_path = all_tracks_mid_flist[x_idx]
    mtrack_data = pypianoroll.read(original_midi_file_path, resolution=24)

    # Append each generated drum roll as its own drum track.
    for pch_idx, drum_roll in enumerate(model_darr_odrm_ary_list):
        mtrack_data.append(pypianoroll.StandardTrack(pianoroll=drum_roll,
                                                     program=pch_idx+1,
                                                     is_drum=True,
                                                     name='Drums_{}'.format(add_note_ver_list[pch_idx])))

    pmidi_data = mtrack_data.to_pretty_midi()

    print ('[info] Show {} Insts...'.format(len(pmidi_data.instruments)))
    for x in pmidi_data.instruments:
        print ('[info] MIDI ' + str(x))
    print('')

    # Fixed velocities: 120 for every drum track (the original one and the generated ones),
    # 50 for every other instrument.
    for instrument in pmidi_data.instruments:
        new_velocity = 120 if instrument.is_drum else 50
        for note in instrument.notes:
            note.velocity = new_velocity

    # "<name>_all_tracks.mid" -> "<name>_merged"
    song_file_name = os.path.basename(original_midi_file_path)
    song_name_tmp = song_file_name[:-len('_all_tracks.mid')] + '_merged'

    midi_file_name = './output_midi/{}.mid'.format(song_name_tmp)
    ensure_dir(midi_file_name)
    pmidi_data.write(midi_file_name)
    print ('[info] \"{}\" is saved.\n\n'.format(midi_file_name))

print ('[info] All {} files are saved.'.format(len(song_index_in_list)))

[info] Start processing song: 1 ...
[info] Song index: 00000
[info] Song bars: 101
[info] start\end bar index:  0\101




[info] Show 11 Insts...
[info] MIDI Instrument(program=85, is_drum=False, name="Lead 6 (voice)")
[info] MIDI Instrument(program=1, is_drum=False, name="Bright Acoustic Piano")
[info] MIDI Instrument(program=27, is_drum=False, name="Electric Guitar (clean)")
[info] MIDI Instrument(program=25, is_drum=False, name="Acoustic Guitar (steel)")
[info] MIDI Instrument(program=33, is_drum=False, name="Electric Bass (finger)")
[info] MIDI Instrument(program=0, is_drum=True, name="MIDI")
[info] MIDI Instrument(program=1, is_drum=True, name="Drums_00")
[info] MIDI Instrument(program=2, is_drum=True, name="Drums_03")
[info] MIDI Instrument(program=3, is_drum=True, name="Drums_06")
[info] MIDI Instrument(program=4, is_drum=True, name="Drums_12")
[info] MIDI Instrument(program=5, is_drum=True, name="Drums_20")

[info] "./output_midi/01_A_Hard_Days_Night_merged.mid" is saved.


[info] Start processing song: 2 ...
[info] Song index: 00001
[info] Song bars: 81
[info] start\end bar index:  101\182
[info]

# Congratulations! You can now find the merged tracks (original MIDI + generated drums) under "./output_midi/"

In [17]:
# List the merged MIDI files written by the loop above.
!ls ./output_midi/

01_A_Hard_Days_Night_merged.mid
02_Anna_merged.mid
03_Back_In_The_USSR_merged.mid
04_Cant_Buy_Me_Love_merged.mid
05_Hold_Me_Tight_merged.mid
06_I_Call_Your_Name_merged.mid
07_I_Wanna_Be_Your_Man_merged.mid
08_Money_merged.mid
09_The_Word_merged.mid
10_Free_As_A_Bird_merged.mid
11_Hey_Jude_merged.mid
12_Little_Child_merged.mid
13_Hey_Bulldog_merged.mid
14_Lovely_Rita_merged.mid
15_The_Night_Before_merged.mid
16_From_Me_To_You_merged.mid
17_Roll_Over_Beethoven_merged.mid
18_Come_Together_merged.mid
19_Babys_In_Black_merged.mid
21_Michael_Jackson_-_Bad_merged.mid
22_Michael_Jackson_-_Billie_Jean_merged.mid
23_Michael_Jackson_-_Man_In_The_Mirror_merged.mid
24_Michael_Jackson_-_Smooth_Criminal_merged.mid
25_Michael_Jackson_-_Beat_It_merged.mid


# Open a merged MIDI file in any DAW; you will see the five generated drum tracks, as shown below.

In [None]:
# Show a screenshot of one merged song opened in a DAW.
# filename= embeds the PNG bytes in the output. A relative url= only emits an <img src> link,
# which the browser cannot resolve to a file in the Colab runtime's Drive folder.
Image(filename="./track22_bj.png", width=1200, height=800)