In [1]:
# import the necessary packages
from musicsearchmodel.dataset import TripletGenerator
from musicsearchmodel.model import get_embedding_module
from musicsearchmodel.model import get_siamese_network
from musicsearchmodel.model import SiameseModel
from musicsearchmodel.dataset import MapFunction
from musicsearchmodel import config
from tensorflow import keras
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow as tf
import os

# create the data input pipeline for train and val dataset
print("[INFO] building the train and validation generators...")
trainTripletGenerator = TripletGenerator(
    datasetPath=config.TRAIN_DATASET)
# NOTE(review): the validation generator is pointed at config.TRAIN_DATASET,
# the same path as the training generator -- confirm whether a dedicated
# validation dataset path should be used here instead.
valTripletGenerator = TripletGenerator(
    datasetPath=config.TRAIN_DATASET)

# wrap each generator in a `tf.data.Dataset`; every element is declared as a
# triplet of scalar `tf.string` tensors (presumably the anchor / positive /
# negative sample paths -- verify against `TripletGenerator`)
print("[INFO] building the train and validation `tf.data` dataset...")
trainTfDataset = tf.data.Dataset.from_generator(
    generator=trainTripletGenerator.get_next_element,
    output_types=(tf.string, tf.string, tf.string),
    output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]))
)
# the validation dataset must draw from the validation generator; it was
# previously wired to `trainTripletGenerator`, which left
# `valTripletGenerator` unused
valTfDataset = tf.data.Dataset.from_generator(
    generator=valTripletGenerator.get_next_element,
    output_types=(tf.string, tf.string, tf.string),
    output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]))
)

[INFO] building the train and validation generators...
[INFO] building the train and validation `tf.data` dataset...


Training_1

For reference:

- BATCH_SIZE = 64
- STEPS_PER_EPOCH = 71
- EPOCHS = 10
- MARGIN = 0.5
- OUTPUT_PATH = "output"
- MODEL_PATH = os.path.join(OUTPUT_PATH, "siamese_network")

In [None]:
# preprocess the images: map every triplet through `MapFunction`, then
# batch and prefetch; only the training stream is shuffled
mapFunction = MapFunction(imageSize=config.IMAGE_SIZE)
print("[INFO] building the train and validation `tf.data` pipeline...")
trainDs = (trainTfDataset
    .map(mapFunction)
    .shuffle(config.BUFFER_SIZE)
    .batch(config.BATCH_SIZE)
    .prefetch(config.AUTO)
)
valDs = (valTfDataset
    .map(mapFunction)
    .batch(config.BATCH_SIZE)
    .prefetch(config.AUTO)
)
# build the embedding module and the siamese network
print("[INFO] build the siamese model...")
embeddingModule = get_embedding_module(imageSize=config.IMAGE_SIZE)
siameseNetwork = get_siamese_network(
    imageSize=config.IMAGE_SIZE,
    embeddingModel=embeddingModule,
)
# NOTE(review): the margin is hard-coded here while every other
# hyper-parameter is read from `config`; if `config` exposes a MARGIN
# constant, prefer it so the value lives in one place -- TODO confirm
siameseModel = SiameseModel(
    siameseNetwork=siameseNetwork,
    margin=0.5,
    lossTracker=keras.metrics.Mean(name="loss"),
)
# compile the siamese model
siameseModel.compile(
    optimizer=keras.optimizers.Adam(config.LEARNING_RATE)
)
# train and validate the siamese model
print("[INFO] training the siamese model...")
siameseModel.fit(
    trainDs,
    steps_per_epoch=config.STEPS_PER_EPOCH,
    validation_data=valDs,
    validation_steps=config.VALIDATION_STEPS,
    epochs=config.EPOCHS,
)
# make sure the output directory exists; `exist_ok=True` makes this a
# no-op when the directory is already there and avoids the
# check-then-create race of a separate `os.path.exists` test
os.makedirs(config.OUTPUT_PATH, exist_ok=True)
# save the siamese network to disk (the wrapped network only, without
# optimizer state)
modelPath = config.MODEL_PATH
print(f"[INFO] saving the siamese network to {modelPath}...")
keras.models.save_model(
    model=siameseModel.siameseNetwork,
    filepath=modelPath,
    include_optimizer=False,
)

[INFO] building the train and validation `tf.data` pipeline...
[INFO] build the siamese model...
[INFO] training the siamese model...
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30

In [3]:
# stand-alone save step: relies on `siameseModel` created in the training
# cell, so it can be run after training finishes or is interrupted
# make sure the output directory exists; `exist_ok=True` makes this a
# no-op when the directory is already there and avoids the
# check-then-create race of a separate `os.path.exists` test
os.makedirs(config.OUTPUT_PATH, exist_ok=True)
# save the siamese network to disk (the wrapped network only, without
# optimizer state)
modelPath = config.MODEL_PATH
print(f"[INFO] saving the siamese network to {modelPath}...")
keras.models.save_model(
    model=siameseModel.siameseNetwork,
    filepath=modelPath,
    include_optimizer=False,
)

[INFO] saving the siamese network to output\siamese_network...
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: output\siamese_network\assets
