In [1]:
from get_training_data import *

import tensorflow as tf
import keras
import pickle
import gzip
import os
import random
import matplotlib.pyplot as plt
%matplotlib nbagg

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Read the gzip-compressed pickle that holds the training cases.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with gzip.open('train_cases.pkl.gz', mode='rb') as train_file:
    train_cases = pickle.load(train_file)
# Display the container type and the number of parts it holds.
(type(train_cases), len(train_cases))

(tuple, 2)

In [3]:
# Element 0 of the loaded pair is the input array, element 1 the target array.
train_X, train_y = train_cases[0], train_cases[1]
# Number of training cases = length of the leading axis.
ntrain = len(train_X)
# Display type / shape / dtype of both arrays as a sanity check.
(
    type(train_X), train_X.shape, train_X.dtype,
    type(train_y), train_y.shape, train_y.dtype,
)

(numpy.ndarray,
 (16384, 16, 16, 128),
 dtype('float32'),
 numpy.ndarray,
 (16384, 1),
 dtype('float64'))

In [4]:
# Read the gzip-compressed pickle that holds the validation cases.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with gzip.open('valid_cases.pkl.gz', mode='rb') as valid_file:
    valid_cases = pickle.load(valid_file)
# Element 0 of the loaded pair is the input array, element 1 the target array.
valid_X, valid_y = valid_cases[0], valid_cases[1]
# Number of validation cases = length of the leading axis.
nvalid = len(valid_X)
# Display type / shape / dtype of both arrays as a sanity check.
(
    type(valid_X), valid_X.shape, valid_X.dtype,
    type(valid_y), valid_y.shape, valid_y.dtype,
)

(numpy.ndarray,
 (4096, 16, 16, 128),
 dtype('float32'),
 numpy.ndarray,
 (4096, 1),
 dtype('float64'))

In [5]:
def choose_training_case():
    """Pick one training case at a uniformly random index.

    Returns:
        A ``(x, y)`` tuple holding the entry of ``train_X`` and the row of
        ``train_y`` found at the same randomly drawn index.
    """
    idx = random.randrange(ntrain)
    features = train_X[idx, :, :, :]
    target = train_y[idx, :]
    return features, target

def choose_validation_case():
    """Pick one validation case at a uniformly random index.

    Returns:
        A ``(x, y)`` tuple holding the entry of ``valid_X`` and the row of
        ``valid_y`` found at the same randomly drawn index.
    """
    idx = random.randrange(nvalid)
    features = valid_X[idx, :, :, :]
    target = valid_y[idx, :]
    return features, target

In [6]:
# Parameters of the input transform applied by preproc().
# Offset added to the input before the logarithm is taken.
ADD_CONST = 1e-5
# Value subtracted from the logarithm.
CENTER = -2.5
# Divisor applied after centering.
SCALE = 5
# Switch: when False, preproc() returns its input unchanged.
XFORM = True

In [7]:
def preproc(a):
    """Apply the (optional) log transform to the model inputs.

    When ``XFORM`` is true the result is
    ``(log(a + ADD_CONST) - CENTER) / SCALE``; otherwise ``a`` is returned
    untouched.

    Args:
        a: Array (or scalar) of raw input values.

    Returns:
        The transformed values, or ``a`` itself when the transform is off.
    """
    if not XFORM:
        return a
    shifted_log = np.log(a + ADD_CONST)
    return (shifted_log - CENTER) / SCALE

In [8]:
def make_batch_from_megabatch(batchsize=32, validation=False):
    """Assemble one batch of randomly chosen cases.

    Cases are drawn one at a time with ``choose_validation_case`` (when
    ``validation`` is true) or ``choose_training_case`` (otherwise).

    The pieces are collected in lists and concatenated once at the end.
    Growing the batch with ``np.concatenate`` inside the loop re-copies the
    whole accumulated array on every iteration, which makes batch assembly
    quadratic in ``batchsize``.

    Args:
        batchsize: Number of cases to draw. ``0`` yields empty arrays.
        validation: Draw from the validation set instead of the training set.

    Returns:
        A ``(x, y)`` tuple: ``x`` is the preprocessed input batch of shape
        ``(batchsize, INPUTS_PER_BEAT, NBEATS, NCHANNELS)`` and ``y`` the
        label batch of shape ``(batchsize, 1)``.
    """
    get_case = choose_validation_case if validation else choose_training_case

    # Seed each list with an empty array so that the final concatenate keeps
    # the original dtype promotion and shape checks, and so that
    # batchsize == 0 still returns correctly shaped empty arrays.
    x_parts = [np.empty([0, INPUTS_PER_BEAT, NBEATS, NCHANNELS], dtype=np.float32)]
    y_parts = [np.empty([0, 1])]

    for _ in range(batchsize):
        features, label = get_case()
        x_parts.append(np.expand_dims(features, axis=0))
        # .item() extracts the single label value explicitly; calling int()
        # directly on a 1-element array is deprecated in recent NumPy.
        label_value = int(np.asarray(label).item())
        y_parts.append(np.array(label_value).reshape([1, 1]))

    xbatch = np.concatenate(x_parts, axis=0)
    ybatch = np.concatenate(y_parts, axis=0)
    return preproc(xbatch), ybatch

def train_gen(batchsize=32):
    """Endless generator of training batches.

    Args:
        batchsize: Number of cases per yielded batch.

    Yields:
        The ``(x, y)`` tuple produced by ``make_batch_from_megabatch`` with
        ``validation=False``.
    """
    while True:
        batch = make_batch_from_megabatch(batchsize, validation=False)
        yield batch

def valid_gen(batchsize=32):
    """Endless generator of validation batches.

    Args:
        batchsize: Number of cases per yielded batch.

    Yields:
        The ``(x, y)`` tuple produced by ``make_batch_from_megabatch`` with
        ``validation=True``.
    """
    while True:
        batch = make_batch_from_megabatch(batchsize, validation=True)
        yield batch

In [9]:
# TODO: the generators are way too slow; this still has to be fixed.
# Pull a single training batch and display the shapes of its two parts.
t = next(train_gen())
(t[0].shape, t[1].shape)

((32, 16, 16, 128), (32, 1))

In [10]:
# Pull a single validation batch and display the shapes of its two parts.
v = next(valid_gen())
(v[0].shape, v[1].shape)

((32, 16, 16, 128), (32, 1))

In [11]:
# Shape of a single example fed to the network (batch axis excluded).
input_shape = (INPUTS_PER_BEAT, NBEATS, NCHANNELS)

# Convolutional stack ending in a single sigmoid unit. Layers are listed in
# the order they are applied.
model = keras.Sequential([
    # Two 1x1 convolutions mix the channels without touching the spatial axes.
    keras.layers.Conv2D(filters=64, kernel_size=1, activation='relu',
                        input_shape=input_shape),
    keras.layers.SpatialDropout2D(rate=.1),
    keras.layers.Conv2D(filters=32, kernel_size=1, activation='relu'),
    keras.layers.BatchNormalization(),
    # Spatial convolutions progressively shrink the two leading axes.
    keras.layers.Conv2D(filters=16, kernel_size=(2, 4), activation='relu'),
    keras.layers.SpatialDropout2D(rate=.3),
    keras.layers.Conv2D(filters=25, kernel_size=(4, 6), activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=3, kernel_size=(2, 4), activation='relu'),
    # Classification head.
    keras.layers.Flatten(),
    keras.layers.Dropout(rate=.6),
    keras.layers.Dense(units=1, activation='sigmoid'),
])

In [12]:
# Binary cross-entropy loss, Adam optimizer (lr=.003, decay=2e-4), and
# accuracy reported as the metric.
model.compile(
    optimizer=keras.optimizers.Adam(lr=.003, decay=2e-4),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [13]:
# Print the per-layer table of output shapes and parameter counts.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 16, 16, 64)        8256      
_________________________________________________________________
spatial_dropout2d_1 (Spatial (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 32)        2080      
_________________________________________________________________
batch_normalization_1 (Batch (None, 16, 16, 32)        128       
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 13, 16)        4112      
_________________________________________________________________
spatial_dropout2d_2 (Spatial (None, 15, 13, 16)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 12, 8, 25)         9625      
__________

In [14]:
# Create the two endless batch generators, then train for 60 epochs of
# 64 batches each, evaluating on 64 validation batches after every epoch.
train = train_gen()
valid = valid_gen()
model.fit_generator(
    train,
    steps_per_epoch=64,
    epochs=60,
    validation_data=valid,
    validation_steps=64,
)

Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60
Epoch 32/60
Epoch 33/60
Epoch 34/60
Epoch 35/60
Epoch 36/60
Epoch 37/60
Epoch 38/60
Epoch 39/60
Epoch 40/60
Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
Epoch 52/60
Epoch 53/60
Epoch 54/60
Epoch 55/60
Epoch 56/60
Epoch 57/60
Epoch 58/60
Epoch 59/60
Epoch 60/60


<keras.callbacks.History at 0x7ff74a6da6d8>

In [15]:
# Spot-check the model: draw ten cases with get_validation_case(), run each
# clip through the same preprocessing used for training, and print the
# predicted value next to the actual label.
for trial in range(10):
    song, tempo, compatible, clip = get_validation_case()
    print( song, tempo )
    model_input = preproc(clip_to_tf_input(resample_clip(clip)))
    # Add a leading batch axis of size 1 before calling predict().
    batch_of_one = np.expand_dims(model_input, axis=0)
    p = model.predict(batch_of_one)[0][0]
    print( 'Predicted: ', p, '   Actual: ', compatible, '\n' )

ImLookingThruYou 69
Predicted:  0.0006818745    Actual:  False 

HangingOnTheTelephone 99
Predicted:  0.00021099366    Actual:  False 

BothSidesNow 85
Predicted:  0.048795793    Actual:  False 

EternalFlame 80
Predicted:  0.7326286    Actual:  True 

WalkThisWay1 49
Predicted:  0.00021785329    Actual:  False 

IThinkWereAloneNow2MP 172
Predicted:  0.0102476785    Actual:  False 

09 Suddenly I See 144
Predicted:  0.003922112    Actual:  False 

Kalamazoo 185
Predicted:  0.0029575678    Actual:  False 

CaliforniaDreamin2MP 56
Predicted:  0.87617373    Actual:  True 

HangingOnTheTelephone 99
Predicted:  0.0011896356    Actual:  False 

