In [None]:
import pretty_midi
import numpy as np
import dill
import mir_eval.display
import librosa.display
import IPython.display
import os, sys, time, datetime
from datetime import datetime
import matplotlib.pyplot as plt
#%matplotlib inline
    
# Session banner: wall-clock time, interpreter version (first line only) and
# the current working directory. sep='' keeps each label and its value
# adjacent, exactly as string concatenation would.
print("[info] Current Time   : ", datetime.now().strftime('%Y-%m-%d  %H:%M:%S'), sep='')
print("[info] Python Version : ", sys.version.split('\n')[0], sep='')
print("[info] Working Dir    : ", os.getcwd(), sep='')
#print ("[info] Library load done.")


# define function to load midi file
def load_midi_file(name):
    """Load a MIDI file with pretty_midi and print a short summary of it.

    The first instrument is flagged as non-drum before the piano roll is
    sampled (presumably so that ``get_piano_roll`` does not skip the drum
    track -- TODO confirm against the pretty_midi documentation).

    Args:
        name: Whatever ``pretty_midi.PrettyMIDI`` accepts, e.g. a file path.

    Returns:
        The loaded ``pretty_midi.PrettyMIDI`` object, with
        ``instruments[0].is_drum`` set to ``False``.
    """
    midi = pretty_midi.PrettyMIDI(name)
    midi.instruments[0].is_drum = False

    # Only the first tempo entry is reported; later tempo changes are ignored.
    _, tempi = midi.get_tempo_changes()
    first_tempo = tempi[0]
    print("Midi file tempo : {} BPM".format(first_tempo))
    print("Single beat length: {} Sec.".format(60.0 / first_tempo))

    # Sample the roll at 8 columns per second. For a 120 BPM file one bar
    # lasts 2 seconds, so one column is one 16th note and 16 columns make a
    # bar -- which is what the "Total bars" figure below assumes.
    roll = midi.get_piano_roll(fs=8)
    n_rows = roll.shape[0]
    n_cols = roll.shape[1]

    print("Total bars: {}".format(n_cols/16))
    print("Midi file data array shape: {}".format([n_rows, n_cols]))

    return midi


# Define the plot function; use librosa's specshow function to display the piano roll
def plot_piano_roll(raw_midi_data, start_pitch, end_pitch, fs=8):
    """Draw rows ``start_pitch:end_pitch`` of the piano roll on the current axes.

    Args:
        raw_midi_data: Object exposing ``get_piano_roll(fs)``.
        start_pitch: First piano-roll row to show; also converted to Hz and
            used as the lowest frequency of the note axis.
        end_pitch: Row index one past the last row shown.
        fs: Sampling rate passed to ``get_piano_roll`` and, as ``sr``, to
            ``librosa.display.specshow``.
    """
    roll_slice = raw_midi_data.get_piano_roll(fs)[start_pitch:end_pitch]
    lowest_hz = pretty_midi.note_number_to_hz(start_pitch)

    # hop_length=1 together with sr=fs makes each column span 1/fs seconds.
    librosa.display.specshow(
        roll_slice,
        hop_length=1,
        sr=fs,
        x_axis='time',
        y_axis='cqt_note',
        fmin=lowest_hz,
        cmap='binary',
    )

def plot_first_four_bar(raw_midi_data, bar_n):
    """Plot the opening bars of the drum piano roll with beat markers.

    Rows 20-59 of the roll are drawn in a 16x5 figure; every beat is marked
    with a thin blue line and every downbeat with a thicker red line.

    Args:
        raw_midi_data: Object exposing ``get_beats()``, ``get_downbeats()``
            and ``get_piano_roll()``.
        bar_n: Number of bars to show. The x-axis is limited to
            ``bar_n * 2`` seconds, i.e. 2 seconds per bar (120 BPM in 4/4).
            The title always reads "first four bar" regardless of ``bar_n``.
    """
    beat_times = raw_midi_data.get_beats()
    downbeat_times = raw_midi_data.get_downbeats()

    plt.figure(figsize=(16, 5))
    plot_piano_roll(raw_midi_data, start_pitch=20, end_pitch=60)

    # Overlay beat (blue) and bar-start (red) positions.
    mir_eval.display.events(beat_times, color='blue', lw=1)
    mir_eval.display.events(downbeat_times, color='red', lw=2)

    # Show only the first bar_n bars.
    visible_seconds = bar_n*2
    plt.xlim(0, visible_seconds)

    plt.title('The first four bar drum channel data')
    
    
def analyse_drum(raw_midi_data):
    """Count how often each MIDI note is active and show the result as a bar plot.

    The piano roll is sampled at 8 columns per second; a note counts once for
    every column in which its value is non-zero. The active note numbers and
    their counts are printed and then plotted.

    Unlike a fixed 10-bar layout, the plot adapts to however many notes are
    active, and each bar is labelled by note number rather than by position.

    Args:
        raw_midi_data: Object exposing ``get_piano_roll(fs=...)`` that returns
            a 2-D array indexed as ``[note, column]``.
    """
    midi_pno_roll = raw_midi_data.get_piano_roll(fs=8)

    # Number of non-zero columns per note row, computed in one pass.
    active_per_note = np.count_nonzero(midi_pno_roll, axis=1)
    active_rows = np.flatnonzero(active_per_note)

    # Plain Python ints, in ascending note order, so the printed lists look
    # the same as ones built by appending inside a loop.
    exist_note_type_list = active_rows.tolist()
    note_counts = active_per_note[active_rows].tolist()

    note_types = len(exist_note_type_list)
    print("Note types: {} types".format(note_types))
    print("Activated MIDI notes: {}".format(exist_note_type_list))
    print("Activated counts: {}".format(note_counts))

    # Short labels for the note numbers this project expects (GM drum map).
    # Any other active note is labelled with its number.
    gm_short_names = {
        36: 'KD',      # kick drum
        37: 'SDrs',    # snare drum, ring shot
        38: 'SD',      # snare drum
        42: 'CsdHH',   # closed hi-hat
        43: 'HFT',     # high floor tom
        44: 'PdHH',    # pedal hi-hat
        47: 'LMT',     # low mid-tom
        50: 'HT',      # high tom
        51: 'RC',      # ride cymbal
        56: 'CB',      # cowbell
    }
    note_type_name = [gm_short_names.get(note, str(note))
                      for note in exist_note_type_list]

    # One bar per active note; positions and labels always match the data.
    positions = range(note_types)
    plt.figure(figsize=(8, 4))
    plt.bar(positions, height=note_counts)
    plt.xticks(positions, note_type_name)
    plt.title('Drum note counts')

    
# Reduce the data to only the n selected drum sounds
def get_simplified_data(raw_midi_data, keep_sound):
    """Extract the distinct one-bar drum patterns for a chosen set of notes.

    The piano roll is sampled at 8 columns per second, reduced to the rows
    listed in ``keep_sound``, binarised (value >= 0.5 means "hit") and cut
    into consecutive bars of 16 columns. Columns that do not fill a whole bar
    at the end are dropped. Repeated bars are kept only once, in order of
    first appearance. In the returned patterns a hit is encoded as ``0.70``
    and a rest as ``-0.70``.

    Args:
        raw_midi_data: Object exposing ``get_piano_roll(fs=...)`` that returns
            a 2-D array indexed as ``[note, column]``.
        keep_sound: Sequence of piano-roll row indices (MIDI note numbers) to
            keep; the order defines the row order of each pattern.

    Returns:
        A list of float arrays of shape ``(len(keep_sound), 16)``. The list
        is empty when the roll is shorter than one bar.
    """
    beats_per_bar = 16

    midi_pno_roll = raw_midi_data.get_piano_roll(fs=8)
    midi_col_num = midi_pno_roll.shape[1]
    note_types = len(keep_sound)

    # Copy the selected note rows into a compact array.
    reduced_midi_array = np.zeros([note_types, midi_col_num])
    for idx, note in enumerate(keep_sound):
        reduced_midi_array[idx, :] = midi_pno_roll[note, :]

    # Binarise. The builtin int is used because the np.int alias no longer
    # exists in current NumPy releases.
    reduced_midi_array = (reduced_midi_array >= 0.5).astype(int)

    print("###  Input Data  ###")

    total_bars = midi_col_num // beats_per_bar
    print("  >> Original MIDI file patterns: {}".format(total_bars))
    print("  >> Array shape: {}".format(reduced_midi_array.shape))

    # Collect each distinct bar once. Every bar has the same shape and dtype,
    # so its raw bytes identify its contents and a set lookup can replace a
    # scan over all previously stored patterns.
    pattern_list = []
    seen_patterns = set()
    for bar in range(total_bars):
        beat_start = bar * beats_per_bar
        beat_end = beat_start + beats_per_bar

        single_drum_ptn = np.zeros([note_types, beats_per_bar])
        single_drum_ptn[:, :] = reduced_midi_array[:, beat_start:beat_end]

        fingerprint = single_drum_ptn.tobytes()
        if fingerprint not in seen_patterns:
            seen_patterns.add(fingerprint)
            pattern_list.append(single_drum_ptn)

    num_types = len(pattern_list)
    # Known without looking at pattern_list[0], so an empty result is fine.
    single_pattern_shape = (note_types, beats_per_bar)

    # Re-encode hit / rest as +0.70 / -0.70.
    for pattern in pattern_list:
        is_hit = pattern > 0.50
        pattern[is_hit] = 0.70
        pattern[~is_hit] = -0.70

    print("###  Output Data  ###")
    print("  >> Simplified non-repeated patterns: {}".format(num_types))
    print("  >> pattern data format: {}".format(single_pattern_shape))

    return pattern_list
    
    
def make_large_64x64(input_array):
    """Blow a 4x16 array up to 64x64 by repeating every element.

    Output element ``[r, c]`` equals ``input_array[r // 16, c // 4]``, so each
    source element fills a block of 16 rows by 4 columns.

    Args:
        input_array: NumPy array with at least 4 rows and 16 columns; only
            that top-left region is read.

    Returns:
        A new float array of shape ``(64, 64)``.
    """
    source_rows = np.arange(64) // 16
    source_cols = np.arange(64) // 4

    large_array = np.zeros([64, 64])
    # np.ix_ builds the open mesh, so this gathers all 64x64 lookups at once.
    large_array[:, :] = input_array[np.ix_(source_rows, source_cols)]
    return large_array


def sample_small_4x16(input_array):
    """Shrink a 64x64 array to 4x16 by averaging blocks of 16 rows by 4 columns.

    Output element ``[i, j]`` is the mean of the 64 source elements whose row
    falls in ``16*i .. 16*i+15`` and whose column falls in ``4*j .. 4*j+3``.

    Args:
        input_array: Array indexable as ``[row, col]`` for rows and columns
            0..63.

    Returns:
        A new float array of shape ``(4, 16)``.
    """
    small_array = np.zeros([4, 16])
    for out_row in range(4):
        first_src_row = out_row * 16
        for src_row in range(first_src_row, first_src_row + 16):
            for src_col in range(64):
                # Elements are visited in row-major order, so each cell
                # accumulates its 64 values in a fixed sequence.
                small_array[out_row, src_col // 4] += input_array[src_row, src_col]
    return small_array / 64.0
    
    
    
def save_data(data, file_name):
    """Serialise ``data`` to ``file_name`` with dill and print a confirmation.

    Args:
        data: Any object dill can serialise.
        file_name: Destination path; opened in binary write mode, so an
            existing file is overwritten.
    """
    with open(file_name, 'wb') as out_handle:
        dill.dump(data, out_handle)

    confirmation = 'File "{}" is saved !'.format(file_name)
    print(confirmation)
    


import argparse
import os, sys
import numpy as np
import math
from datetime import datetime
import dill
from pylab import plt
#%matplotlib inline

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.autograd import grad

import torch.nn as nn
import torch.nn.functional as F
import torch
    
#print ("[info] Current Time   : " + datetime.now().strftime('%Y-%m-%d  %H:%M:%S'))
#print ("[info] Python Version : " + sys.version.split('\n')[0] )
#print ("[info] Working Dir    : " + os.getcwd())
#print ("[info] Library load done.")


class Config:
    """Hyper-parameters and runtime switches shared by the helpers in this file.

    All values are plain class attributes; the module-level ``config``
    instance below is what the other definitions read.
    """
    n_epochs = 40000        # printed as the epoch total by show_training_status
    batch_size = 32         # get_mini_batch / get_labels work in units of batch_size // 8
    g_lr = 2e-5             # presumably the generator learning rate -- optimizer setup not visible here
    d_lr = 8e-6             # presumably the discriminator learning rate -- optimizer setup not visible here
    z_dim = 256             # width of the generator's input vector
    data_h = 4              # pattern height, used for data_shape
    data_w = 16             # pattern width, used for data_shape
    channels = 1            # channel count, used for data_shape
    sample_interval = 1     # show_training_status reports when epoch % sample_interval == 0
    use_cuda = False        # when True (and CUDA is available) tensors/models are moved to the GPU
    #use_cuda = True
    
config = Config()
# (channels, height, width) of a single pattern.
data_shape = (config.channels, config.data_h, config.data_w)


class md_dataset(Dataset):
    """Dataset of fixed-size drum patterns loaded from a dill file.

    The file must contain a non-empty sequence of 2-D arrays that all share
    one shape ``(data_h, data_w)``. The patterns are laid side by side in a
    single float32 tensor of shape ``(1, data_h, len * data_w)``, and item
    ``i`` is the ``data_w``-wide slice belonging to pattern ``i``.
    """

    def __init__(self, file_name):
        """Load the pattern list from ``file_name`` and build the tensor.

        Args:
            file_name: Path of a file written with ``dill.dump``.

        Raises:
            IndexError: If the loaded sequence is empty.
        """
        # NOTE(security): dill.load can execute arbitrary code contained in
        # the file -- only open files from a trusted source.
        with open(file_name, 'rb') as saved_file:
            load_data = dill.load(saved_file)

        self.data = load_data
        self.data_h = self.data[0].shape[0]
        self.data_w = self.data[0].shape[1]
        self.len = len(load_data)

        # Place pattern x in columns [x * data_w, (x + 1) * data_w). The
        # stride follows the measured pattern width instead of a fixed 16.
        total = np.zeros([self.data_h, self.len * self.data_w])
        for x in range(self.len):
            col_start = x * self.data_w
            total[:, col_start:col_start + self.data_w] = self.data[x][:, :]

        total = total.astype(np.float32)
        # Add the leading channel axis: (1, data_h, len * data_w).
        total = total.reshape([1, total.shape[0], total.shape[1]])
        self.total_ary = torch.from_numpy(total)

    def __getitem__(self, index):
        """Return pattern ``index`` as a tensor of shape ``(1, data_h, data_w)``."""
        col_start = index * self.data_w
        return self.total_ary[:, :, col_start:col_start + self.data_w]

    def __len__(self):
        """Return the number of stored patterns."""
        return self.len

    
# Define Generator
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a 1x4x16 pattern.

    The network widens from ``config.z_dim`` through 128, 256, 512, 1024 and
    4096 units down to 64 outputs. Hidden layers use LeakyReLU(0.01); the
    first four are followed by dropout with rates 0.4, 0.3, 0.2 and 0.1. The
    final Tanh keeps every output inside (-1, 1).
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Hidden blocks that carry dropout: Linear -> LeakyReLU -> Dropout.
        widths = [config.z_dim, 128, 256, 512, 1024]
        dropout_rates = [0.4, 0.3, 0.2, 0.1]

        layers = []
        for n_in, n_out, rate in zip(widths[:-1], widths[1:], dropout_rates):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            layers.append(nn.Dropout(p=rate))

        # Last hidden block has no dropout, then the output projection.
        layers.append(nn.Linear(1024, 4096))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        layers.append(nn.Linear(4096, 4*16))
        layers.append(nn.Tanh())

        # Layer order and the attribute name fix the state_dict keys
        # (fc_model.0 ... fc_model.15), so they must stay as they are.
        self.fc_model = nn.Sequential(*layers)

    def forward(self, g_input):
        """Map ``(batch, z_dim)`` latent vectors to ``(batch, 1, 4, 16)`` patterns."""
        batch = g_input.size(0)
        flat_pattern = self.fc_model(g_input)
        return flat_pattern.view(batch, 1, 4, 16)

    
    
# Define Discriminator 
class Discriminator(nn.Module):
    """Convolutional critic that outputs one raw score per input sample.

    Four bias-free 2x2 convolutions, each followed by batch norm and
    LeakyReLU(0.2), feed a single linear layer. The linear layer expects 5376
    features, which is what the convolution stack produces for a 1x64x16
    input (32 channels x 14 x 12). No sigmoid is applied here.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # (in_channels, out_channels, stride, in-place activation)
        stage_specs = [
            (1, 2048, (2, 1), True),
            (2048, 512, (2, 1), True),
            (512, 256, (1, 1), True),
            (256, 32, (1, 1), False),
        ]

        stages = []
        for c_in, c_out, stride, act_inplace in stage_specs:
            stages.append(nn.Conv2d(c_in, c_out, (2, 2), stride, 0, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.LeakyReLU(0.20, inplace=act_inplace))

        # Attribute names and layer order fix the state_dict keys
        # (conv.0 ... conv.11, fc.0), so they must stay as they are.
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Sequential(nn.Linear(5376, 1))

    def forward(self, d_input):
        """Return a ``(batch, 1)`` tensor of scores for ``d_input``."""
        features = self.conv(d_input)
        flat_features = features.view(d_input.size(0), -1)
        return self.fc(flat_features)
    
    

def nv_cuda(xs):
    """Move a tensor, or every tensor in a list/tuple, to the GPU if enabled.

    The move happens only when CUDA is available and ``config.use_cuda`` is
    set; otherwise the input is passed through.

    Args:
        xs: A single object with a ``.cuda()`` method, or a list/tuple of them.

    Returns:
        For a single object: the (possibly moved) object. For a list or
        tuple: always a new list, even when nothing is moved.
    """
    is_sequence = isinstance(xs, (list, tuple))

    if torch.cuda.is_available() and config.use_cuda:
        return [x.cuda() for x in xs] if is_sequence else xs.cuda()

    return list(xs) if is_sequence else xs

        
def gradient_penalty(x, f):
    """Penalise the gradient norm of ``f`` at randomly perturbed copies of ``x``.

    Each sample is shifted by uniform noise scaled with half the standard
    deviation of ``x``; a per-sample random factor in [0, 1) then picks a
    point between the sample and its shifted copy. The penalty is the mean
    squared deviation of the gradient's L2 norm from 1 at those points,
    multiplied by 10.

    Args:
        x: Input batch tensor; the first dimension is the batch.
        f: Callable mapping a batch to a tensor of scores.

    Returns:
        A scalar tensor that can itself be differentiated (the gradient is
        taken with ``create_graph=True``).
    """
    # One random mixing factor per sample, broadcast over the other dims.
    per_sample_shape = [x.size(0)] + [1] * (x.dim() - 1)
    mix = nv_cuda(torch.rand(per_sample_shape))
    noise = nv_cuda(torch.rand(x.size()))

    shifted = x + 0.5 * x.std() * noise
    probe = x + mix * (shifted - x)

    # Differentiate f with respect to the probe points.
    probe = nv_cuda(Variable(probe, requires_grad=True))
    scores = f(probe)
    ones = nv_cuda(torch.ones(scores.size()))
    grads = grad(scores, probe, grad_outputs=ones, create_graph=True)[0]
    flat_grads = grads.view(probe.size(0), -1)

    penalty = ((flat_grads.norm(p=2, dim=1) - 1)**2).mean()
    return penalty * 10.0


def get_mini_batch(in_tensor):
    """Stack eight consecutive batch chunks along dimension 2.

    The first ``8 * (config.batch_size // 8)`` entries of the batch dimension
    are split into eight equal chunks, which are concatenated along dim 2.
    Entries beyond that range are not used.

    Args:
        in_tensor: Tensor with at least three dimensions, batch first.

    Returns:
        Tensor whose batch size is one chunk and whose dim 2 is eight times
        that of ``in_tensor``.
    """
    chunk = config.batch_size//8
    pieces = [in_tensor[k * chunk:(k + 1) * chunk] for k in range(8)]
    return torch.cat(pieces, dim=2)


def get_diff(input_tensor):
    """Append a rescaled time-difference map below the input along dim 2.

    For an input of shape ``(N, C, H, W)`` the result has shape
    ``(N, C, 2 * H, W)``. Rows ``0:H`` are the input itself. Rows ``H:2H``
    hold ``(x[t] - x[t - 1] + 1) * 0.5`` along the last axis, where the value
    before the first column is taken as 0.

    The height and width are read from the input rather than fixed at 4 and
    16, and the result lives on the same device as the input instead of
    being placed according to a global flag. The result also has the input's
    dtype.

    Args:
        input_tensor: 4-D tensor ``(N, C, H, W)``.

    Returns:
        Tensor of shape ``(N, C, 2 * H, W)``.
    """
    # Shift one step along the time axis, filling the first column with 0.
    leading_zeros = torch.zeros_like(input_tensor[:, :, :, :1])
    previous_step = torch.cat([leading_zeros, input_tensor[:, :, :, :-1]], dim=3)

    # Differences of values in [-1, 1] lie in [-2, 2]; (d + 1) * 0.5 is the
    # rescaling this function has always applied.
    diff_map = (input_tensor - previous_step + 1.0) * 0.5

    return torch.cat([input_tensor, diff_map], dim=2)



def count_parameters(generator, discriminator):
    """Print the parameter counts of both models and their sum.

    Args:
        generator: Object whose ``parameters()`` yields tensors.
        discriminator: Object whose ``parameters()`` yields tensors.
    """
    g_total = sum(p.numel() for p in generator.parameters())
    d_total = sum(q.numel() for q in discriminator.parameters())
    combined = g_total + d_total

    print('Number of [G/D/Total] params: [%d/%d/%d]' % (g_total, d_total, combined))



# Loss function.
# BCEWithLogitsLoss takes raw scores, which matches the Discriminator above:
# its last layer is a plain Linear with no sigmoid.
#adversarial_loss = torch.nn.BCELoss()
#adversarial_loss = torch.nn.BCEWithLogitsLoss()
loss = torch.nn.BCEWithLogitsLoss()


# Tensor constructor used by get_labels / generate_rhythm; picks the CUDA
# float type when config.use_cuda is set, the CPU float type otherwise.
#Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
Tensor = torch.cuda.FloatTensor if config.use_cuda else torch.FloatTensor


# Optimizers
#optimizer_G = torch.optim.Adam(generator.parameters(),     lr=config.lr,     betas=(0.5, 0.99))
#optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=config.lr*0.5, betas=(0.5, 0.99))
#optimizer_G = torch.optim.RMSprop(generator.parameters(),     lr=config.lr, alpha=0.9)
#optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=config.lr, alpha=0.9)

def set_model_gpu(generator, discriminator):
    """Move both models and the module-level ``loss`` to the GPU.

    Does nothing unless ``config.use_cuda`` is set.

    Args:
        generator: Object with a ``.cuda()`` method.
        discriminator: Object with a ``.cuda()`` method.
    """
    if not config.use_cuda:
        return

    for module in (generator, discriminator, loss):
        module.cuda()


def reload_models(generator, discriminator, reload):
    """Optionally restore both models from ``g_model.pt`` and ``d_model.pt``.

    Args:
        generator: Model receiving the state loaded from ``g_model.pt``.
        discriminator: Model receiving the state loaded from ``d_model.pt``.
        reload: When falsy nothing is loaded and only a notice is printed.
    """
    if not reload:
        print('No reload.')
        return

    # Both files are looked up relative to the current working directory.
    g_state = torch.load('g_model.pt')
    generator.load_state_dict(g_state)
    d_state = torch.load('d_model.pt')
    discriminator.load_state_dict(d_state)
    print('model parameters are reloaded.')
      
    
def get_labels():
    """Build the constant target tensors for real and fake samples.

    Returns:
        ``(real_label, fake_label)``: two tensors of shape
        ``(config.batch_size // 8, 1)`` filled with 1.0 and 0.0 respectively,
        neither requiring gradients.
    """
    n_rows = config.batch_size//8

    ones = Tensor(n_rows, 1).fill_(1.00)
    zeros = Tensor(n_rows, 1).fill_(0.00)

    real_label = Variable(ones, requires_grad=False)
    fake_label = Variable(zeros, requires_grad=False)
    return real_label, fake_label
        

def plot_rhythm(rhythm):
    """Show a pattern as a black-and-white image, flipped top to bottom.

    Args:
        rhythm: 2-D array-like accepted by ``np.flipud``.
    """
    # Flip so that the first row ends up at the bottom of the image.
    flipped = np.flipud(rhythm)
    plt.imshow(flipped, cmap='binary')
    plt.show()

    
def show_training_status(epoch, config, g_loss, d_loss, fake_data):
    """Print the current losses and show the first generated sample.

    Nothing happens for epoch 0 or for epochs that are not a multiple of
    ``config.sample_interval``.

    Args:
        epoch: Current epoch number.
        config: Object providing ``sample_interval`` and ``n_epochs``.
        g_loss: Single-value tensor holding the generator loss.
        d_loss: Single-value tensor holding the discriminator loss.
        fake_data: Batch of generated samples; only the first one is plotted.
    """
    if epoch % config.sample_interval != 0 or epoch <= 0:
        return

    print("[Epoch %d/%d] [G loss: %f] [D loss: %f]" % (epoch, config.n_epochs,
                                                        g_loss.data.cpu().numpy(),
                                                        d_loss.data.cpu().numpy())
          )

    print(datetime.now().strftime('%Y-%m-%d  %H:%M:%S'))

    # Bring the sample to host memory first: NumPy cannot read a tensor that
    # lives on the GPU, so plotting would fail when training with CUDA.
    first_sample = fake_data.data[:1].squeeze().cpu().numpy()
    plt.imshow(np.flipud(first_sample), cmap='binary')
    plt.show()

    
    
    
def generate_rhythm(generator, num):
    """Sample binary drum patterns from the generator, dropping duplicates.

    ``num`` is capped at 20. Four times as many latent vectors as requested
    are drawn, each candidate is binarised at 0.25, and candidates are
    collected in order until ``num`` distinct patterns are found. Every
    collected pattern is plotted.

    The generator is evaluated once for the whole batch of latent vectors
    instead of once per candidate, and plotting covers the patterns actually
    found, so falling short of ``num`` no longer raises IndexError.

    Args:
        generator: Callable mapping a ``(batch, config.z_dim)`` tensor to a
            batch of patterns.
        num: Requested number of patterns.

    Returns:
        A list of at most ``min(20, num)`` arrays holding 0.0 / 1.0 values.
    """
    # set maximum number to 20
    num = min(20, num)

    pattern_list = []

    if num > 0:
        # Oversample by a factor of 4 to leave room for duplicates. The
        # latent vectors are drawn from N(0, 1) and scaled by 100.
        z = Variable(Tensor(np.random.normal(0, 1, (num*4, config.z_dim)))) * 100.0

        # One forward pass for all candidates.
        candidates = generator(z).data.cpu().numpy()

        seen_patterns = set()
        for candidate in candidates:
            raw = candidate.squeeze()
            # Binarise, keeping the generator's dtype.
            pattern = (raw > 0.25).astype(raw.dtype)

            # All patterns share shape and dtype, so the raw bytes identify
            # the content.
            fingerprint = pattern.tobytes()
            if fingerprint in seen_patterns:
                continue
            seen_patterns.add(fingerprint)
            pattern_list.append(pattern)

            if len(pattern_list) >= num:
                break

    print("generated non-repeated pattern: {}".format(len(pattern_list)))

    for idx, pattern in enumerate(pattern_list):
        print("rhythm: {}".format(idx+1))
        plot_rhythm(pattern)

    return pattern_list



        
  
from midiutil.MidiFile import MIDIFile

def write_midi(rhythm, file_name):
    """Write the given drum patterns to a single-track MIDI file.

    Every pattern is played four times in a row. Rows 0-3 of a pattern are
    written as MIDI pitches 36, 37, 40 and 42 on channel 9, and each of the
    16 columns is one step of 0.25 time units (MIDIUtil time units are
    presumably quarter-note beats, which makes a column a 16th note -- TODO
    confirm against the MIDIUtil documentation). The tempo is set to 120.

    Args:
        rhythm: Sequence of arrays indexable as ``[row, column]`` with at
            least 4 rows and 16 columns; a value > 0 is a hit.
        file_name: Path of the MIDI file to write.
    """
    # Repeat each pattern four times.
    output_bars = [pattern for pattern in rhythm for _ in range(4)]

    # Single-track MIDI object.
    midi_file = MIDIFile(numTracks=1, adjust_origin=True, file_format=1)

    track = 0      # the only track
    channel = 9    # 0 = piano, 9 = drum kit
    volume = 124

    # Track name and tempo are both set at time 0.
    midi_file.addTrackName(track, 0, "GAN Drum")
    midi_file.addTempo(track, 0, 120)

    steps_per_bar = 16
    step_length = 0.25
    step_offsets = [slot * step_length for slot in range(steps_per_bar)]

    drum_note_num = np.array([36, 37, 40, 42])

    # Emit one note per active cell, bar by bar and step by step.
    for bar_idx, bar_pattern in enumerate(output_bars):
        bar_start = bar_idx * 16 * step_length
        for slot in range(steps_per_bar):
            for row, pitch in enumerate(drum_note_num):
                if bar_pattern[row, slot] > 0.0:
                    onset = bar_start + step_offsets[slot]
                    midi_file.addNote(track, channel, pitch, onset, step_length, volume)

    # Write it to disk.
    with open(file_name, 'wb') as out_file:
        midi_file.writeFile(out_file)

    print("MIDI file {} is saved, Total {} type of MIDI".format(file_name, int(len(output_bars)/4)))
    #print ("You can use DAW to transfer MIDI to audio now.")

       
    
# output drum midi
import os
import IPython.display as ipd
import librosa
import soundfile as sf

def syn_midi(file_name, samp_rate):
    """Render a MIDI file to a normalised 16-bit mono WAV using fluidsynth.

    fluidsynth renders the MIDI file with the ``sf_2_GeneralUser.sf2``
    soundfont into a temporary ``tmp.wav`` in the working directory. That
    file is loaded as mono, scaled to a peak of 0.80 and written next to the
    input with the extension replaced by ``.wav``.

    Args:
        file_name: Path of the MIDI file to render.
        samp_rate: Sample rate passed to fluidsynth and used when loading.

    Raises:
        subprocess.CalledProcessError: If fluidsynth exits with an error.
        OSError: If the fluidsynth executable cannot be started.
    """
    # Imported here so the function carries its own dependency.
    import subprocess

    input_midi = file_name
    # splitext removes only the final extension, so paths such as
    # "./out/beat.mid" or "take.1.mid" keep the rest of their name.
    output_file = '{}.wav'.format(os.path.splitext(file_name)[0])
    tmp_wav = 'tmp.wav'

    # Run MIDI-to-audio synthesis. An argument list (no shell) keeps file
    # names with spaces or shell metacharacters intact.
    subprocess.check_call([
        "fluidsynth", "-ni", "sf_2_GeneralUser.sf2", input_midi,
        "-F", tmp_wav, "-r", str(samp_rate),
    ])

    try:
        # Use only mono sound.
        y, sr = librosa.load(tmp_wav, sr=samp_rate, mono=True)
    finally:
        # Portable replacement for "rm tmp.wav"; also runs if loading fails.
        if os.path.exists(tmp_wav):
            os.remove(tmp_wav)

    # Volume normalisation to a peak of 0.80. Silent or empty audio is left
    # as it is rather than divided by zero.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y = y * 0.80 / peak

    sf.write(output_file, (y * np.iinfo(np.int16).max).astype(np.int16), sr, 'PCM_16')

    print("MIDI to Audio transfer is done, synthesized file name: {}".format(output_file))
    #print("downloading file for playback...")

    #play audio wave file
    #ipd.Audio(output_file, rate=samp_rate)


# Use seaborn's "whitegrid" style with the axes grid switched off
import seaborn as sns
sns.set_style("whitegrid", {'axes.grid' : False})    
    
    
    