In [59]:
import numpy as np
import tensorflow as tf
import random
from keras.layers import *
import pdb
from keras.models import Model
from functools import partial
from tqdm import tqdm_notebook as tqdm

### Data Loading
from adlframework.processors.general_processors import crop, reshape, pdb_trace
from adlframework.processors.lstm_processors import crop_and_label
from adlframework.processors.midi_processors import midi_to_np, notes_to_classification, make_time_relative
from adlframework.filters.general_filters import min_array_shape, threshold_label
from adlframework.retrievals.BlobLocalCache import BlobLocalCache
from adlframework.datasource import DataSource
from adlframework.dataentity.midi_de import MidiDataEntity

In [2]:
# Three-headed LSTM classifier over a window of note events.
# Output order (spaces, durations, notes) must match the label order produced
# by `convert_to_one_hot` (12, 11, 88 classes respectively).
SEQ_LEN, NUM_FEATURES = 100, 3                      # (timesteps, features) per sample
NUM_SPACES, NUM_DURATIONS, NUM_NOTES = 12, 11, 88   # classes per output head

# Named `inputs` (not `i`) so later loops using `i` cannot clobber it.
inputs = Input((SEQ_LEN, NUM_FEATURES))
normed = BatchNormalization()(inputs)
lstm_layer = LSTM(100)(normed)
# Softmax folded into each Dense head; explicit names make Keras report
# readable per-head losses (e.g. 'space_loss').
space_prob = Dense(NUM_SPACES, activation='softmax', name='space')(lstm_layer)
duration_prob = Dense(NUM_DURATIONS, activation='softmax', name='duration')(lstm_layer)
notes_prob = Dense(NUM_NOTES, activation='softmax', name='notes')(lstm_layer)
model = Model(inputs, [space_prob, duration_prob, notes_prob])
model.compile('adam', 'categorical_crossentropy')

In [23]:
def convert_to_one_hot(sample, num_classes=(12, 11, 88)):
    """One-hot encode the label of a single ``(data, label)`` sample.

    ``data`` is passed through untouched; only the label is transformed.

    Parameters
    ----------
    sample : tuple
        ``(data, label)`` pair. ``label[0]`` is indexed once per head, so it
        must hold one integer class index per entry of ``num_classes``
        (presumably the single row produced by ``crop_and_label(num_rows=1)``
        -- TODO confirm).
    num_classes : sequence of int, optional
        Number of classes for each label head, in the same order as the
        model outputs (spaces, durations, notes). Defaults to ``(12, 11, 88)``.

    Returns
    -------
    tuple
        ``(data, one_hots)`` where ``one_hots[k]`` is a float array of length
        ``num_classes[k]`` containing a single 1 at the class index.
    """
    data, label = sample
    row = label[0]
    one_hots = []
    for head, size in enumerate(num_classes):
        vec = np.zeros(size)
        vec[row[head]] = 1
        one_hots.append(vec)
    return data, one_hots

## ADLFramework Data

In [27]:
### Controllers
controllers = [
    partial(threshold_label, labelnames="num_instruments", threshold=1, greater_than=False),
    midi_to_np,
    partial(min_array_shape, min_shape=(101, 4)),
    partial(crop, shape=(101, 3)),
    make_time_relative,
    notes_to_classification,
    partial(crop_and_label, num_rows=1),
    convert_to_one_hot,  ## Not applied to input data, only label
]

### Load Data
import os

# Root of the local MIDI cache. Set the MIDI_CACHE_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
base = os.environ.get(
    'MIDI_CACHE_DIR',
    '/Users/localhost/Desktop/Projects/Working/StudyMuse/local_cache/alex_midiset/v2/')
# os.path.join(..., '') keeps the trailing separator the original paths had,
# even when MIDI_CACHE_DIR is given without one.
midi_retrieval = BlobLocalCache(os.path.join(base, 'midis', ''),
                                os.path.join(base, 'labels', ''))
midi_ds = DataSource(midi_retrieval, MidiDataEntity,
                     verbosity=0,
                     controllers=controllers,
                     backend='madmom',
                     batch_size=50)

# 60% train; the remaining 40% is split 60/40 into validation (24%) and test (16%).
train_ds, temp = DataSource.split(midi_ds, split_percent=.6)
val_ds, test_ds = DataSource.split(temp, split_percent=.6)

INFO:adlframework.datasource:Prefiltering entities


Retrieval not named, so won't be cached.




## Experiment

In [56]:
### Hyperparameters
# Number of training iterations. Each iteration of the training loop draws a
# single batch via `train_ds.next()`, so this counts batches, not full passes
# over the training set (the datasource logs "Looped the datasource" whenever
# it wraps around).
NUM_TRAIN_STEPS = 100
epochs = NUM_TRAIN_STEPS  # legacy name still read by the training-loop cell

In [None]:
# Labels for the values returned by `model.train_on_batch`: the total loss
# first, then one loss per model output in the order passed to `Model(...)`
# (space_prob, duration_prob, notes_prob).
loss_names = ['Total', 'Spaces', 'Durations', 'Notes']

# `step` (not `i`) so the loop does not overwrite any earlier global named `i`.
for step in tqdm(range(epochs)):
    data, labels = train_ds.next()
    # NOTE(review): assumes `labels` is an object array of shape (batch, 3)
    # whose column k holds the one-hot vectors for output head k -- confirm.
    # Stack each column into a (batch, num_classes_k) array, one per head.
    head_labels = [np.stack(list(labels[:, head])) for head in range(3)]
    losses = model.train_on_batch(data, head_labels)
    print(', '.join('{}: {:.4f}'.format(name, loss)
                    for name, loss in zip(loss_names, losses)))

Widget Javascript not detected.  It may not be installed or enabled properly.


[9.3263245, 2.444927, 2.4018421, 4.4795556]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[9.3485603, 2.4758248, 2.3823757, 4.4903593]
[9.2823601, 2.4287977, 2.3712609, 4.4823012]
[9.2568178, 2.4187486, 2.3625295, 4.4755397]
[9.2724934, 2.4488504, 2.3671637, 4.4564786]
[9.2167721, 2.4173594, 2.3296585, 4.4697547]
[9.183672, 2.4111593, 2.3035264, 4.468986]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[9.1398335, 2.3903506, 2.2921, 4.4573836]
[9.1582718, 2.4030619, 2.2866931, 4.4685163]
[9.0771217, 2.3860672, 2.2553389, 4.4357152]
[8.9863462, 2.3236413, 2.2385538, 4.4241514]
[8.9964104, 2.3259966, 2.2350836, 4.4353304]
[8.9541569, 2.3297303, 2.1803341, 4.4440932]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[8.7394085, 2.2135475, 2.0759661, 4.4498949]
[8.8110685, 2.3330941, 2.0789814, 4.398993]
[8.7199478, 2.2944877, 2.0059571, 4.4195027]
[8.3665047, 2.0610561, 1.9290695, 4.3763795]
[8.1487141, 2.0680075, 1.796735, 4.2839718]
[7.8315549, 2.0698936, 1.4621894, 4.2994719]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[7.9522905, 2.1248388, 1.5653327, 4.2621193]
[7.7238731, 2.1424217, 1.34253, 4.2389216]
[8.08885, 2.3370333, 1.4461679, 4.3056483]
[8.3832283, 2.6048193, 1.6509399, 4.1274695]
[7.4092894, 2.0113027, 1.3196298, 4.0783567]
[8.0546894, 2.0261924, 1.8717151, 4.1567822]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[7.3198733, 2.0890393, 1.1938668, 4.0369673]
[7.8182063, 2.1832323, 1.6117973, 4.0231767]
[7.7943373, 2.2509081, 1.3808038, 4.1626253]
[7.3730111, 1.9375296, 1.4429625, 3.9925189]
[7.4531088, 2.1779618, 1.3183038, 3.9568429]
[7.437922, 1.9939035, 1.4013875, 4.0426311]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.9334021, 1.7685571, 1.2562028, 3.9086425]
[7.0026894, 1.9370705, 1.0506415, 4.0149775]
[7.0837612, 1.9068707, 1.2213244, 3.9555662]
[6.975575, 1.9624199, 1.2316521, 3.781503]
[6.9623775, 1.9053193, 1.2986345, 3.758424]
[6.9514446, 2.0915833, 1.110536, 3.7493255]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.912066, 1.9339855, 1.3654119, 3.6126685]
[6.6819115, 1.9214959, 1.1241682, 3.6362472]
[7.0695171, 2.0787976, 1.2245027, 3.7662165]
[6.8282681, 2.0016181, 1.1589671, 3.6676824]
[6.9981222, 1.995556, 1.2627144, 3.7398522]
[7.0221939, 1.8430793, 1.3761212, 3.8029935]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.4313192, 1.732426, 1.2114197, 3.4874735]
[7.0902929, 2.113374, 1.3138274, 3.6630917]
[6.292625, 1.7951764, 0.88898832, 3.6084602]
[6.9590197, 1.8837181, 1.3799565, 3.6953449]
[6.5992308, 1.7468463, 1.2319371, 3.6204474]
[6.4501743, 1.8447005, 1.0494415, 3.5560324]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.7065592, 1.957803, 1.0897534, 3.659003]
[6.8092432, 1.9070044, 1.341743, 3.5604956]
[6.1710811, 1.7703073, 0.87474704, 3.5260267]
[6.6526756, 1.8301919, 1.1583592, 3.6641243]
[6.3436875, 1.9438074, 0.96905899, 3.4308212]
[6.5602274, 1.7572143, 1.1489751, 3.6540382]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.8386559, 1.9288723, 1.3485669, 3.5612168]
[6.6427393, 1.7825077, 1.3000029, 3.5602286]
[6.3450041, 1.7753102, 0.97048724, 3.5992069]
[6.4010181, 1.7312255, 1.195806, 3.4739869]
[6.4544573, 1.7763072, 1.0984871, 3.5796626]
[6.5438347, 1.9258013, 0.99765694, 3.6203766]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.3686543, 1.6512555, 1.098106, 3.619293]
[6.5535913, 1.7421105, 1.2370963, 3.5743845]
[6.3827314, 1.7189761, 1.1359582, 3.5277975]
[6.6405787, 1.9736285, 1.0949388, 3.5720115]
[6.5044241, 1.6684512, 1.2270335, 3.6089394]
[6.334095, 1.8697486, 0.93755525, 3.5267913]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.192215, 1.5499932, 1.1783724, 3.4638495]
[6.3452415, 1.6906235, 1.1685652, 3.4860528]
[5.9374018, 1.7005795, 0.82843095, 3.408391]
[6.3538599, 1.7999426, 1.0735173, 3.4804001]
[6.3433623, 1.715245, 1.266412, 3.3617053]
[6.4285197, 1.7075016, 0.98297191, 3.7380462]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.32652, 1.7561562, 1.0818114, 3.4885523]
[6.4669876, 1.7876058, 1.2130295, 3.4663522]
[6.2056026, 1.8000281, 1.052091, 3.3534837]
[6.2745528, 1.5831927, 1.0885662, 3.6027939]
[6.2997904, 1.7907014, 0.92975831, 3.5793307]
[5.8368464, 1.5345997, 0.8509813, 3.4512656]


INFO:adlframework.datasource:Looped the datasource
INFO:adlframework.datasource:Shuffling the datasource


[6.3196797, 1.6365561, 1.2502819, 3.4328418]
