In [0]:
# Mount Google Drive inside the Colab VM; the project directory used by the
# following cells lives under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [0]:
# Make the project folder on the mounted Drive the working directory, so the
# relative paths and project-local imports used below resolve.
%cd /content/drive/'My Drive'/MusicTransformer

In [0]:
# List the working directory contents as a quick check of the previous cell.
%ls

In [0]:
# Install the project's dependencies. `%pip` (rather than `!pip`) guarantees the
# packages land in the environment of the running kernel, not in whichever
# `pip` happens to be first on the shell PATH.
%pip install -r requirements.txt

In [0]:
# Show the packages installed in the running kernel's environment (`%pip`
# targets the kernel's interpreter; `!pip` may resolve to a different one).
%pip list

In [0]:
# Project-local modules (resolved relative to the working directory) plus the
# Adam optimizer and standard-library helpers used by the cells below.
from model import MusicTransformer, MusicTransformerDecoder
# Wildcard import: every public name of custom.layers is pulled into the
# notebook namespace. Imports that follow may rebind any of those names, so the
# order of these lines matters.
from custom.layers import *
from custom import callback
import params as par
# NOTE(review): `tensorflow.python...` is TensorFlow's internal namespace, not a
# documented public path — confirm this import is intentional.
from tensorflow.python.keras.optimizer_v2.adam import Adam
from data import Data
import utils
# `argparse` and `sys` are not referenced by any cell visible in this notebook.
import argparse
import datetime
import sys
from midi_processor.processor import decode_midi, encode_midi
# NOTE(review): `tf` is never imported explicitly in this notebook; presumably it
# arrives through the wildcard import above — confirm, or import it explicitly.
# As the last expression of the cell, the value returned here is displayed as
# the cell output.
tf.executing_eagerly()

In [0]:
# ---- Training configuration ------------------------------------------------

# Optimisation
epochs = 5
batch_size = 2
l_r = None          # None -> the setup cell builds a callback.CustomSchedule instead

# Model
num_layer = 6
max_seq = 2048

# Checkpointing: weights are written to a directory named after today's date
# (YYYY-MM-DD); `load_path` is handed to the model constructor as `loader_path`.
load_path = None
save_path = datetime.date.today().strftime('%Y-%m-%d')

# The following settings are not read by any cell visible in this notebook.
pickle_dir = 'music'
is_reuse = False
multi_gpu = True

In [0]:
# Open the pre-processed dataset directory and print the dataset object.
dataset = Data('dataset/processed')
print(dataset)

# Learning rate: fall back to the custom schedule when no fixed rate was set
# in the configuration cell.
if l_r is None:
    learning_rate = callback.CustomSchedule(par.embedding_dim)
else:
    learning_rate = l_r
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# NOTE(review): the schedule above is built from par.embedding_dim while the
# model below hard-codes embedding_dim=256 — confirm the two values agree.
mt = MusicTransformerDecoder(
    embedding_dim=256,
    vocab_size=par.vocab_size,
    num_layer=num_layer,
    max_seq=max_seq,
    dropout=0.2,
    debug=False,
    loader_path=load_path,
)
mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)

# Train Model

In [0]:
# Training loop with manual batching, TensorBoard logging and checkpointing.
# Each run writes its train/eval summaries to its own timestamped directories.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
train_log_dir = 'logs/mt_decoder/'+current_time+'/train'
eval_log_dir = 'logs/mt_decoder/'+current_time+'/eval'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
eval_summary_writer = tf.summary.create_file_writer(eval_log_dir)

# `idx` is the step attached to scalar/image summaries. It advances once per
# logging round (every 100th batch), not once per training batch.
idx = 0
for e in range(epochs):
    mt.reset_metrics()
    skipped = 0
    last_error = None
    for b in range(len(dataset.files) // batch_size):
        try:
            batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq)
        except Exception as err:
            # Best-effort batching: a batch that cannot be built is skipped.
            # Catching `Exception` (not a bare `except`) lets KeyboardInterrupt
            # stop the run, and the skip is counted so it is reported below
            # instead of disappearing silently.
            skipped += 1
            last_error = err
            continue
        result_metrics = mt.train_on_batch(batch_x, batch_y)

        # Every 100th batch: evaluate on one eval batch, checkpoint, and log.
        if b % 100 == 0:
            eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval')
            eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)
            mt.save(save_path)
            with train_summary_writer.as_default():
                # Input/target histograms are written once per epoch (step = epoch).
                if b == 0:
                    tf.summary.histogram("target_analysis", batch_y, step=e)
                    tf.summary.histogram("source_analysis", batch_x, step=e)

                tf.summary.scalar('loss', result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', result_metrics[1], step=idx)

            with eval_summary_writer.as_default():
                if b == 0:
                    mt.sanity_check(eval_x, eval_y, step=e)

                tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
                tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
                # One image summary per entry of `weights` returned by evaluate().
                for i, weight in enumerate(weights):
                    with tf.name_scope("layer_%d" % i):
                        with tf.name_scope("w"):
                            utils.attention_image_summary(weight, step=idx)
            idx += 1
            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1]))
            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1]))

    # Report skipped batches once per epoch rather than once per failure, to
    # keep the cell output readable.
    if skipped:
        print('Epoch {}: skipped {} batch(es); last error: {!r}'.format(e, skipped, last_error))

# The periodic save above only fires on every 100th batch, so without this the
# updates made after the last checkpoint would never reach `save_path`.
mt.save(save_path)
              
            

# Generate Music

In [0]:
# ---- Generation configuration ----------------------------------------------
max_seq = 2048

# `save_path` is re-bound further down to the output MIDI file. Without this
# guard, running the cell a second time would copy that .mid path into
# `load_path` and the model would be loaded from the wrong location. Only adopt
# `save_path` as the checkpoint location while it still refers to one.
if not str(save_path).endswith('.mid'):
    load_path = save_path

mode = 'dec'            # 'enc-dec' selects MusicTransformer; anything else the decoder-only model
beam = None             # int type
length = 2048           # int type

# Timestamped output file and TensorBoard log directory for this generation run.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
save_path = 'music/generated_'+current_time+'.mid'
gen_log_dir = 'logs/mt_decoder/generate_'+current_time+'/generate'
gen_summary_writer = tf.summary.create_file_writer(gen_log_dir)

In [0]:

# Rebuild the model for inference from the checkpoint at `load_path`; the
# architecture depends on `mode`.
use_seq2seq = (mode == 'enc-dec')
if use_seq2seq:
    print(">> generate with original seq2seq wise... beam size is {}".format(beam))
    mt = MusicTransformer(
        embedding_dim=256,
        vocab_size=par.vocab_size,
        num_layer=6,
        max_seq=2048,
        dropout=0.2,
        debug=False,
        loader_path=load_path,
    )
else:
    print(">> generate with decoder wise... beam size is {}".format(beam))
    mt = MusicTransformerDecoder(loader_path=load_path)

# Encode a reference MIDI file; only its first 10 items are used as the prompt.
inputs = encode_midi('dataset/midi/BENABD10.mid')
primer = inputs[:10]

# Generation runs under the summary writer so anything it logs goes to `gen_log_dir`.
with gen_summary_writer.as_default():
    result = mt.generate(primer, beam=beam, length=length, tf_board=True)

# Print the generated sequence, one item per line.
for event in result:
    print(event)

# In seq2seq mode the tail of the source sequence is prepended and the first
# generated item dropped; in decoder mode the result is written unchanged.
if use_seq2seq:
    midi_events = list(inputs[-par.max_seq:]) + list(result[1:])
else:
    midi_events = result
decode_midi(midi_events, file_path=save_path)
