In [None]:
import tensorflow as tf
import os

from kerastuner.tuners import RandomSearch
from tensorflow.keras.layers import Input, Dense, Conv1D, Flatten, Dropout, BatchNormalization
from tensorflow.keras.layers import UpSampling1D, Reshape, GRU, Lambda

from tfrecord_provider import CompleteTFRecordProvider

# Report which TensorFlow build is running and which physical devices it can see.
tf_version = tf.__version__
visible_devices = tf.config.list_physical_devices()
print(tf_version)
print(visible_devices)

In [None]:
def create_model(hp, timesteps=1000, channels=16):
    """Build and compile the convolutional/recurrent autoencoder tuned by Keras Tuner.

    Encoder: four stride-2 Conv1D + BatchNormalization stages (16, 32, 64, 128
    filters), Flatten, Dropout(0.4), and a Dense bottleneck whose width is the
    tuned ``latent_dim`` hyperparameter.
    Decoder: the bottleneck is reshaped into a length-``latent_dim`` sequence,
    passed through three UpSampling1D + Conv1D stages (32, 64, 128 filters),
    then two GRUs. The second GRU has ``timesteps`` units and is transposed so
    those units become the time axis. A final Dense maps every step to
    ``channels`` features.

    Args:
        hp: Keras Tuner hyperparameters object; ``latent_dim`` is sampled from
            [8, 16, 32, 48].
        timesteps: length of the input/output time axis (default 1000).
        channels: number of features per time step (default 16).

    Returns:
        A compiled ``tf.keras`` Model named 'ae' that maps 'z_input' of shape
        (timesteps, channels) to 'z_output' of the same shape, using MSE loss
        and the Adam optimizer.
    """
    latent_dim = hp.Choice('latent_dim', values=[8, 16, 32, 48])
    input_ = Input(shape=(timesteps, channels), name='z_input')

    # Encoder: each valid-padded, stride-2 convolution roughly halves the time axis.
    x = input_
    for i in range(4):
        num_filters = 2**(i + 4)
        x = Conv1D(num_filters, 7, strides=2, activation='relu', name=f'conv_{i + 1}')(x)
        x = BatchNormalization(name=f'bn_{i + 1}')(x)

    x = Flatten(name='flatten')(x)
    x = Dropout(0.4)(x)

    # Bottleneck, reshaped into a single-channel sequence so the decoder can convolve over it.
    x = Dense(latent_dim, activation='relu')(x)
    x = Reshape((latent_dim, 1), name='reshape')(x)

    # Decoder: layer names (up_4, up_3, up_2) follow the descending loop index.
    for i in range(3, 0, -1):
        num_filters = 2**(8 - i)
        x = UpSampling1D(2, name=f'up_{i + 1}')(x)
        x = Conv1D(num_filters, 7, activation='relu', name=f'up_conv_{i + 1}')(x)

    # gru_2 has one unit per output time step. Transposing (batch, steps, units)
    # -> (batch, units, steps) turns those units into the time axis, so the
    # output length equals `timesteps` whatever length the decoder convolutions
    # produced. Deriving the unit count from `timesteps` keeps the two in sync.
    x = GRU(512, return_sequences=True, name='gru_1')(x)
    x = GRU(timesteps, return_sequences=True, name='gru_2')(x)
    x = Lambda(lambda t: tf.transpose(t, [0, 2, 1]), name='lambda')(x)

    # Project every time step back to the input feature width.
    output_ = Dense(channels, activation='linear', name='z_output')(x)

    model = tf.keras.models.Model(input_, output_, name='ae')
    model.compile(loss='mse', optimizer='adam')
    return model

In [None]:
# Random search over the hyperparameters defined in create_model
# (latent_dim, which has 4 candidate values). Results are stored under
# directory/project_name.
tuner = RandomSearch(
    create_model,
    objective='val_loss',
    max_trials=4,
    executions_per_trial=1,
    # Fix the hyperparameter sampling order so reruns try the same trials.
    # This does not seed weight initialisation or data order.
    seed=42,
    directory='trials',
    project_name='z_latent_dimension_search',
)

In [None]:
# Show the tuner's hyperparameter search space (create_model defines 'latent_dim').
tuner.search_space_summary()

In [None]:
def complete_record_generator(dataset_dir, set_name, batch_size):
    """Return an iterator over batches read from one split's `complete.tfrecord*` files.

    Args:
        dataset_dir: root directory containing one sub-directory per split.
        set_name: name of the split sub-directory (e.g. 'train' or 'valid').
        batch_size: examples per batch, forwarded to `get_batch`.

    Returns:
        An iterator over whatever `CompleteTFRecordProvider.get_batch` produces.
    """
    shard_pattern = os.path.join(dataset_dir, set_name, 'complete.tfrecord*')
    provider = CompleteTFRecordProvider(shard_pattern)
    batches = provider.get_batch(batch_size=batch_size)
    return iter(batches)


def data_generator(complete_data_generator, timesteps=1000, channels=16):
    """Turn a stream of feature dicts into (inputs, targets) pairs for the autoencoder.

    Each batch's 'z' tensor is reshaped to (batch, timesteps, channels). It is
    used both as the model input ('z_input') and as the reconstruction target
    ('z_output').

    Args:
        complete_data_generator: iterator of feature dicts that contain a 'z'
            tensor, e.g. the result of complete_record_generator.
        timesteps: length of each example's time axis (default 1000).
        channels: number of features per time step (default 16).

    Yields:
        ({'z_input': z}, {'z_output': z}) tuples.
    """
    # A `for` loop ends cleanly when the source is exhausted. The previous
    # `while True: next(...)` form lets StopIteration escape the generator
    # body, which PEP 479 converts into a RuntimeError.
    for features in complete_data_generator:
        # -1 lets TensorFlow infer the batch dimension, so a short final batch also works.
        z = tf.reshape(features['z'], (-1, timesteps, channels))
        yield {'z_input': z}, {'z_output': z}

In [None]:
batch_size = 16
# NOTE(review): absolute local (Windows) path; adjust for your machine.
dataset_dir = 'd:/soundofai/complete_data'

train_generator = complete_record_generator(dataset_dir, 'train', batch_size)
valid_generator = complete_record_generator(dataset_dir, 'valid', batch_size)

# Split sizes are hard-coded; presumably they match the TFRecord contents - confirm.
total_examples = 32690
validation_examples = 2081
# Integer floor division: counts whole batches only, with no float round-trip.
steps = total_examples // batch_size
validation_steps = validation_examples // batch_size

In [None]:
# NOTE(review): these generator objects are handed to search() once, so every
# trial presumably keeps consuming the same streams rather than starting from
# the beginning of each split - confirm this is intended.
train_data = data_generator(train_generator)
valid_data = data_generator(valid_generator)

tuner.search(
    train_data,
    validation_data=valid_data,
    epochs=10,
    steps_per_epoch=steps,
    validation_steps=validation_steps,
)