# Install Colab Env. packages

In [None]:
!pip install torch
!pip install torchvision
!pip install matplotlib
!pip install librosa
!pip install mir_eval
!pip install dill
!pip install pretty_midi
!pip install midiutil
!pip install pysoundfile
!apt-get update
!apt-get install fluidsynth
!pip install googledrivedownloader
from google_drive_downloader import GoogleDriveDownloader as gdd; import os;
gdd.download_file_from_google_drive(file_id='13ht_rPUUlle764VJRZOXEWGK5cPlY8Qd', dest_path='./files.zip', unzip=True); os.remove('files.zip');
print ("### Show files ###")
!ls ./

# Data pre-processing

## Import library

In [None]:
# Execute read_library.py inside this notebook session. The helpers used by
# the cells below (load_midi_file, md_dataset, Generator, config, ...) are not
# defined in this notebook, so they presumably come from that script.
%run read_library.py

## The content of the library file:
## https://github.com/Sma1033/music_ai_course/blob/master/read_library.ipynb

## Load MIDI file

In [None]:
# Load the drum MIDI file; the returned object is what the plotting, analysis
# and clean-up cells below operate on.
midi_data = load_midi_file('./drum_midi.mid')

## Check MIDI file content (first four bars)

In [None]:
# Plot the beginning of the loaded MIDI data.
# NOTE(review): the second argument (4) is presumably the number of bars to
# draw - confirm against plot_first_four_bar in read_library.py.
plot_first_four_bar(midi_data, 4)

## Check MIDI data statistics

In [None]:
# Mapping from MIDI note number to General MIDI percussion instrument, with
# the abbreviation used for each instrument in parentheses:
# 36 : Kick Drum       (KD)         # 44 : Pedal HH     (PdHH)
# 37 : SD ring shot    (SDrs)       # 47 : Low Mid-Tom  (LMT)
# 38 : Snare Drum      (SD)         # 50 : High Tom     (HT)
# 42 : Closed HH       (CsdHH)      # 51 : Ride Cymbal  (RC)
# 43 : High Floor Tom  (HFT)        # 56 : Cowbell      (CB)
# NOTE(review): "ring shot" is probably meant to be "rim shot"; General MIDI
# names note 37 "Side Stick".

# Report statistics for the loaded MIDI data (see the section heading).
analyse_drum(midi_data)

## Clean data

In [None]:
# Note numbers to keep. Per the mapping table in the statistics cell these are
# 36 (Kick Drum), 37 (SD ring shot), 38 (Snare Drum) and 42 (Closed HH).
keep_sound = [36, 37, 38, 42]

# NOTE(review): presumably discards every note not listed in keep_sound and
# returns the patterns used for training - confirm in read_library.py.
md_patterns = get_simplified_data(midi_data, keep_sound)

## Save data into file

In [None]:
# Store the simplified patterns on disk; the same file name is handed to
# md_dataset() when the training dataset is built.
save_data(md_patterns, 'drum_patterns.dill')

# Build GAN Model

## Make dataset from drum pattern file

In [None]:
# Batch loader over the saved drum patterns: shuffled, config.batch_size items
# per batch, loaded by 4 worker processes.
midi_dataset = DataLoader(
    dataset=md_dataset('drum_patterns.dill'),
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=4,
)

## Set Generator & Discriminator network

In [None]:
# Create the two networks (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

# Check the parameters of both networks.
count_parameters(generator, discriminator)

## Set loss function & G/D optimizer

In [None]:
# Criterion shared by the generator and discriminator steps.
# BCEWithLogitsLoss applies the sigmoid itself, so it is fed raw scores.
loss = torch.nn.BCEWithLogitsLoss()

# One RMSprop optimizer per network, each with its own learning rate.
optimizer_G = torch.optim.RMSprop(
    generator.parameters(),
    lr=config.g_lr,
    alpha=0.9,
)
optimizer_D = torch.optim.RMSprop(
    discriminator.parameters(),
    lr=config.d_lr,
    alpha=0.9,
)

## Move model to GPU

In [None]:
# The return value is not used, so the helper has to act on the two networks
# in place.
# NOTE(review): presumably moves both models to the GPU when one is
# available - confirm in read_library.py.
set_model_gpu(generator, discriminator)

## Reload network parameters

In [None]:
# Flag passed to reload_models(); the post-training cell calls the same helper
# with True.
# NOTE(review): presumably False keeps the freshly initialised weights and
# True loads saved ones - confirm in read_library.py.
reload = False

reload_models(generator, discriminator, reload)

## Training model section

In [None]:
# Adversarial ground truths used as targets for the BCE-with-logits loss.
# NOTE(review): get_labels() is defined outside this notebook; presumably it
# returns "real" and "fake" target tensors sized for one full batch - confirm.
real_label, fake_label = get_labels()

# show start time
print("Start training process...")
print(datetime.now().strftime('%Y-%m-%d  %H:%M:%S'))

# Reset the values reported after each epoch, so that a run in which no full
# batch is ever seen is detected below instead of failing with a NameError or
# silently reporting values left over from an earlier run of this cell.
g_loss = d_loss = fake_data = None

# the training for loop, deliberately runs "config.n_epochs+1" epochs
for epoch in range(config.n_epochs+1):

    # iteratively read data from 'midi_dataset'
    for i, data in enumerate(midi_dataset):

        # The noise batches below always have config.batch_size rows, so only
        # full batches are trained on. Leaving the inner loop here ends the
        # current epoch at the first batch that is not full.
        if (data.shape[0] != config.batch_size):
            break

        # Configure input into proper data type
        real_data = Variable(data.type(Tensor))


        # ----------------- #
        #  Train Generator  #
        # ----------------- #

        # reset generator gradients to zero
        optimizer_G.zero_grad()

        # Sample noise as generator input (standard normal, scaled by 100)
        z = Variable(Tensor(np.random.normal(0, 1, (config.batch_size, config.z_dim)))) * 100.0

        # Generate a batch of fake data
        fake_data = generator(z)
        fake_data_mbd = get_mini_batch(get_diff(fake_data))

        # Loss measures generator's ability to fool the discriminator:
        # the fake batch is scored against the "real" target
        g_loss = loss(discriminator(fake_data_mbd), real_label)

        # calculate gradient and update parameters
        g_loss.backward()
        optimizer_G.step()



        # --------------------- #
        #  Train Discriminator  #
        # --------------------- #

        # reset discriminator gradients to zero
        optimizer_D.zero_grad()

        # Measure discriminator's ability to find real data
        real_data_mbd = get_mini_batch(get_diff(real_data))
        real_loss = loss(discriminator(real_data_mbd), real_label)

        # Sample fresh noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (config.batch_size, config.z_dim)))) * 100.0

        # Generate a batch of fake data. The generator is not updated in this
        # step (its output is only used through .detach() below), so the
        # forward pass runs without recording an autograd graph.
        with torch.no_grad():
            fake_data = generator(z)

        # Measure discriminator's ability to find fake data
        fake_data_mbd = get_mini_batch(get_diff(fake_data))
        fake_loss = loss(discriminator(fake_data_mbd.detach()), fake_label)

        # Loss measures discriminator's ability to distinguish real and fake data
        d_loss = real_loss + fake_loss

        # calculate gradient and update parameters
        d_loss.backward()
        optimizer_D.step()


    # Nothing was trained if the loader never produced a full batch; fail with
    # an explicit message instead of passing undefined losses on.
    if fake_data is None:
        raise RuntimeError(
            "No batch of size config.batch_size was produced by "
            "'midi_dataset', so nothing was trained."
        )

    # show training status periodically (uses the last batch of the epoch)
    show_training_status(epoch, config, g_loss, d_loss, fake_data)


print('Training is done.')

## Reload trained model for rhythm generation

In [None]:
# Same helper as before training, now called with reload=True.
# NOTE(review): presumably loads previously saved generator/discriminator
# weights into the two networks - confirm in read_library.py.
reload = True

reload_models(generator, discriminator, reload)

## Generate 8 rhythms and store them in variable 'my_rhythm'

In [None]:
# Produce rhythms with the generator network; per the section heading the
# second argument (8) is the number of rhythms to generate.
my_rhythm = generate_rhythm(generator, 8)

## Save rhythm into MIDI file

In [None]:
# Output path of the MIDI file; the synthesis and download cells use the same
# name.
file_name = 'my_rhythm.mid'

# Write the generated rhythms to that MIDI file.
write_midi(my_rhythm, file_name)

## Synthesize MIDI into wave file

In [None]:
# Sample rate passed to the synthesizer; the playback cell uses the same value.
samp_rate = 32000
# MIDI file written by the previous cell.
file_name = 'my_rhythm.mid'

# NOTE(review): presumably renders the MIDI file to 'my_rhythm.wav' (the file
# the playback cell opens) - confirm in read_library.py.
syn_midi(file_name, samp_rate)

## Play wave file

In [None]:
# Same sample rate as used for synthesis.
samp_rate = 32000
# Wave file to play.
# NOTE(review): presumably produced by syn_midi() in the previous cell -
# confirm.
file_name = 'my_rhythm.wav'

# Must stay the last expression of the cell so the notebook displays the
# audio player.
# NOTE(review): 'ipd' is presumably IPython.display - confirm.
ipd.Audio(file_name, rate=samp_rate)

## Download the MIDI file

In [None]:
# Offer the generated MIDI file for download.
# NOTE(review): 'files' is presumably google.colab.files, which only exists
# inside Colab - confirm.
files.download('my_rhythm.mid')