In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

# NOTE: statement order is kept as-is so that explicit imports placed after the
# star imports still win over any same-named symbol the star imports export.
import tensorflow as tf
from dataset import *
from model import *
import os  # used for all path handling below; previously only reachable via the star imports
import time
import generated.fragment_resolver_pb2 as fragment_resolver


In [None]:
# ---- Export settings ---------------------------------------------------------
# Audio windowing values; they are passed to the padder / splitter / YOLO layers
# constructed further down in this cell.
# NOTE(review): 32758 is suspiciously close to 32768 (2**15) -- confirm the value
# is intentional and matches what the checkpoint was trained with.
sample_rate = 32758
min_duration_sec = 5
input_length = int(min_duration_sec * sample_rate)  # sample_rate * min_duration_sec

# Detection grid size and the confidence value handed to the YOLO output decoder.
num_grid_cells = 20
confidence = 0.5
# Four times the duration of a single grid cell (min_duration_sec / num_grid_cells).
min_overlap_sec = 4 * (min_duration_sec / num_grid_cells)

# Locations on disk, all resolved relative to the current working directory.
cwd = os.getcwd()
model_name = 'model7_2'
yolo_model_weights_path = os.path.join(cwd, '..', 'results', 'trained_models', '7_2_cpt.h5')

# Normalizer layers per transformer type:
#   ResolvedTransformer.Type value -> {parameter name -> LinearTransformerNormalizerLayer}
# The same mapping is given to both the proto fragment encoder and decoder layers below.
model_params = {}
model_params[fragment_resolver.ResolvedTransformer.Type.SILENCE] = {
    'silenceDurationUs': LinearTransformerNormalizerLayer(5e5, in_dtype=tf.int64),
}

# Disabled transformer types, kept as templates for when they are enabled:
# model_params[fragment_resolver.ResolvedTransformer.Type.TYPE2] = {
#     'typeType2Param2': LinearTransformerNormalizerLayer(1, in_dtype=tf.int64),
#     'typeType2Param3': LinearTransformerNormalizerLayer(2, in_dtype=tf.int64),
#     'typeType2Param4': LinearTransformerNormalizerLayer(10, in_dtype=tf.int64),
# }
# model_params[fragment_resolver.ResolvedTransformer.Type.TYPE3] = {
#     'typeType3Param5': LinearTransformerNormalizerLayer(5, in_dtype=tf.int64),
#     'typeType3Param6': LinearTransformerNormalizerLayer(5, in_dtype=tf.int64),
# }

# Config proto handed to FragmentResolverModel; only the sample rate is populated here.
config = fragment_resolver.FragmentResolverModelConfig(sampleRate=sample_rate)

# String settings consumed by the YOLO output decoder / proto encoder layers.
fragments_dtype = tf.float32.name
encoding_type = 'CENTER_DURATION'

# Raw bytes of the descriptor set file, passed verbatim to the proto-aware layers.
descriptor_path = os.path.join(cwd, 'generated', 'descriptor_set.desc')
with open(descriptor_path, 'rb') as descriptor_file:
    protobuf_descriptor = descriptor_file.read()

# The encoder layer is not wired into the graph; it is only queried for
# `transformer_output_length`, which sizes the YOLO layer.
fragment_encoder = ProtoFragmentBatchEncoderLayer(
    sample_rate, model_params, protobuf_descriptor, fragments_dtype
)

# 1) Input: one string per example (presumably a serialized request proto -- TODO confirm).
audio_requests = tf.keras.layers.Input(
    shape=1, dtype=tf.string, name='fragment_resolver_model_requests'
)

# 2) Decode requests into samples, pad them, and split them into frames.
ragged_samples = AudioProcessRequestDecoderLayer(
    protobuf_descriptor, name='ragged_samples'
)(audio_requests)
padded_samples = AudioDataPadderLayer(
    sample_rate, min_duration_sec, min_duration_sec * 0.01, name='padded_samples'
)(ragged_samples)
frames_of_samples, frame_offsets = AudioDataUniformSplitterLayer(
    sample_rate,
    min_duration_sec,
    min_overlap_sec,
    min_duration_sec * 0.01,
    'SAMPLE',
    name='splitted_samples',
)(padded_samples)

# 3) Run the YOLO layer (weights loaded from `yolo_model_weights_path`) on every frame,
#    then decode its output into fragments using the frame offsets.
predicted_yolo_output_frames_batch = YoloLayer(
    input_length,
    num_grid_cells,
    fragment_encoder.transformer_output_length,
    yolo_model_weights_path,
    name='yolo_model',
)(frames_of_samples)
decoded_fragments = YoloOutputBatchDecoderLayer(
    input_length, confidence, encoding_type, fragments_dtype, name='yolo_output_decoder'
)(predicted_yolo_output_frames_batch, frame_offsets)

# 4) Resolve the decoded fragments and convert the result to protos.
resolver_output = FragmentBatchResolverLayer(
    sample_rate, min_duration_sec, num_grid_cells, name='fragment_resolver'
)(decoded_fragments, frame_offsets)
encoded_fragment_protos = ProtoFragmentBatchDecoderLayer(
    sample_rate, model_params, protobuf_descriptor, name='audio_proto_encoder'
)(resolver_output)

# Identity layer whose only job is to give the model output a fixed name.
resolved_fragments = tf.keras.layers.Lambda(
    lambda x: x, name='resolved_fragments_responses'
)(encoded_fragment_protos)

model = FragmentResolverModel(audio_requests, resolved_fragments, config, name='my_model')

# Audio decoder built with the export sample rate. In the visible cells it is only
# referenced by the disabled smoke test below.
audio_decoder = AudioDecoder(sample_rate)

# ---- Optional smoke test (disabled) ------------------------------------------
# Uncomment to exercise the model before saving it. The steps are:
#   1. decode two normalized test clips with `audio_decoder`,
#   2. put each clip's raw sample bytes into a FragmentResolverModelRequest proto
#      and serialize it into a string tensor,
#   3. call `model.resolve` on a single request and on a batch of three requests
#      (reshaped to [-1, 1]), then print `model.config()`.

# test_filepath1 = os.path.join(cwd, '..', 'data', 'clips', 'normalized', 'test1.mp3')
# test_filepath2 = os.path.join(cwd, '..', 'data', 'clips', 'normalized', 'test2.mp3')

# a1 = audio_decoder.decode(test_filepath1)
# a2 = audio_decoder.decode(test_filepath2)

# in1 = fragment_resolver.FragmentResolverModelRequest()
# in1.audioSamplesChannel1 = a1.numpy().tobytes()
# a1 = tf.constant(in1.SerializeToString())

# in2 = fragment_resolver.FragmentResolverModelRequest()
# in2.audioSamplesChannel1 = a2.numpy().tobytes()
# a2 = tf.constant(in2.SerializeToString())

# print(model.resolve(tf.reshape(a1, [-1, 1])))
# print(model.resolve(tf.reshape(tf.stack([a1, a2, a1]), [-1, 1])))
# print(model.config())

# Print the layer table. Keras' `Model.summary()` prints the table itself and
# returns None, so wrapping it in `display(...)` only added a stray "None" output.
# (Assumes FragmentResolverModel does not override `summary` to return a value.)
model.summary()

# Export the model under ../results/saved_models/<model_name>, exposing the model's
# `resolve` and `config` callables as the signatures named 'resolve' and 'config'.
save_path = os.path.join(cwd, '..', 'results', 'saved_models', model_name)
signatures = {
    'resolve': model.resolve,
    'config': model.config,
}
tf.keras.models.save_model(
    model,
    save_path,
    include_optimizer=False,
    save_traces=True,
    signatures=signatures,
)

# Cell output: the output shape of every layer, in model order.
[layer.output_shape for layer in model.layers]

In [None]:
# Custom layer classes passed to Keras so it can rebuild the saved graph.
custom_objects = {
    'AudioProcessRequestDecoderLayer': AudioProcessRequestDecoderLayer,
    'AudioDataPadderLayer': AudioDataPadderLayer,
    'AudioDataUniformSplitterLayer': AudioDataUniformSplitterLayer,
    'YoloLayer': YoloLayer,
    'YoloOutputBatchDecoderLayer': YoloOutputBatchDecoderLayer,
    'FragmentBatchResolverLayer': FragmentBatchResolverLayer,
    'ProtoFragmentBatchDecoderLayer': ProtoFragmentBatchDecoderLayer,
    'LinearTransformerNormalizerLayer': LinearTransformerNormalizerLayer,
}

# Reload the export written to `save_path` (no compilation needed for inference).
new_model = tf.keras.models.load_model(
    save_path,
    compile=False,
    custom_objects=custom_objects,
)
# print(new_model(tf.stack([a1, a2, a1, a2])))

# Parse the bytes returned under the 'config' key of the reloaded model's config()
# call. The result gets its own name so the `config` proto built in the export
# cell is not silently overwritten (hidden state across cells).
loaded_config = fragment_resolver.FragmentResolverModelConfig()
loaded_config.ParseFromString(new_model.config()['config'].numpy())
print(loaded_config)
new_model.summary()