In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm.auto import trange
tf.random.set_seed(314) ##314
np.random.seed(999)  ##999
import json
from model import AWM
%matplotlib inline

In [None]:
try:
    # TPUClusterResolver raises ValueError when it cannot locate a TPU;
    # that case is treated as "no TPU available" and handled below.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable (non-experimental) name; this notebook already needs TF >= 2.4
    # for strategy.distribute_datasets_from_function.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    # Fall back to the default strategy (whatever device TF picks locally).
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

#### Notes on data preparation:
* You need to supply your own audio and image datasets. For our training run we used the Speech Commands dataset from Kaggle, resampled to a sampling rate of 8192 Hz, and the MS COCO dataset with images rescaled to 128x128 pixels.
* Because the model is trained on a TPU, the whole dataset has to be loaded into memory up front. That is why the pipeline reads preprocessed image and audio NumPy arrays.
* Expected audio array shape: `(dataset_length, 8192, 1)`.
* Expected image array shape: `(dataset_length, 128, 128, 3)`, with pixel values in the 0–255 range (`get_dataset` divides them by 255).

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
def get_dataset(batch_size = 128, dataset_type = 'train', audio_scale = 0.8, repeat=True,
                audio_path='path_audio_numpy_block', image_path='path_image_numpy_block'):
  """Build a batched tf.data pipeline of (image, audio) pairs from NumPy files.

  Args:
    batch_size: examples per batch; the final partial batch is dropped.
    dataset_type: label used only in the size printout. The same files are
      read for every split unless different paths are passed.
    audio_scale: factor every audio sample is multiplied by.
    repeat: if True, shuffle and repeat indefinitely (for the step-based
      training loop); if False, make one pass in file order.
    audio_path: file passed to np.load for the audio array
      (replace the placeholder with your data).
    image_path: file passed to np.load for the image array
      (replace the placeholder with your data).

  Returns:
    A tf.data.Dataset of (image, audio) batches; images are cast to float32
    and divided by 255.
  """
  ## Import your audio data here
  audio_data = np.load(audio_path)

  ## Import your image data here
  img_data = np.load(image_path)

  # Pair images and audio one-to-one. from_tensor_slices requires equal
  # lengths, so truncate whichever array is longer.
  n = min(len(audio_data), len(img_data))
  audio_data, img_data = audio_data[:n], img_data[:n]

  print(f'{dataset_type} size: {len(audio_data)}')

  def augment(img, audio):
    img = tf.cast(img, tf.float32) / 255.0
    audio *= audio_scale
    return img, audio

  dataset = tf.data.Dataset.from_tensor_slices((img_data, audio_data))

  if repeat:
    # Shuffle before map: the full-size buffer then holds the raw images
    # (presumably uint8, 4x smaller than float32 — confirm against your data).
    # Shuffle before repeat: every pass is a complete permutation.
    dataset = dataset.shuffle(n).repeat()
  dataset = dataset.map(augment, num_parallel_calls=AUTO)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(AUTO)
  return dataset

In [None]:
BATCH_SIZE = 128
# The global batch is split evenly across replicas. If it did not divide
# evenly, integer division would silently shrink the effective global batch,
# while STEPS_PER_EPOCH_* are still computed from BATCH_SIZE. Fail fast instead.
if BATCH_SIZE % strategy.num_replicas_in_sync != 0:
  raise ValueError(
    f'BATCH_SIZE={BATCH_SIZE} is not divisible by '
    f'num_replicas_in_sync={strategy.num_replicas_in_sync}')
PER_REPLICA_BATCH_SIZE = BATCH_SIZE // strategy.num_replicas_in_sync

# distribute_datasets_from_function passes each function an input context
# (ignored here) and expects datasets that are already batched per replica.
# Both pipelines repeat indefinitely (get_dataset's default), so one iterator
# can be reused across epochs.
# NOTE(review): the input context is not used for sharding; on a multi-host
# TPU pod every host would read the full dataset.
train_dataset = strategy.distribute_datasets_from_function(
  lambda _ : get_dataset(PER_REPLICA_BATCH_SIZE)
)

val_dataset = strategy.distribute_datasets_from_function(
  lambda _ : get_dataset(PER_REPLICA_BATCH_SIZE, 'val')
)

In [None]:
# Create the iterators once. Every train_step / val_step call continues
# pulling batches from the same iterator, epoch after epoch.
train_iterator, val_iterator = iter(train_dataset), iter(val_dataset)

In [None]:
def loss_function(pred, orig_audio, marked=None, orig_img=None, img_weight=10, audio_weight=20):
  """Return the weighted (img_loss, audio_loss) pair used for training.

  Each term is weight * mean squared error over all elements:
    img_loss   = img_weight   * MSE(orig_img, marked)
    audio_loss = audio_weight * MSE(orig_audio, pred)

  Args:
    pred: decoded watermark audio produced by the model.
    orig_audio: original watermark audio (same shape as `pred`).
    marked: watermarked image; optional.
    orig_img: original cover image (same shape as `marked`); optional.
    img_weight: multiplier for the image term (default 10).
    audio_weight: multiplier for the audio term (default 20).

  Returns:
    (img_loss, audio_loss) as scalar tensors. If `marked` or `orig_img` is
    None, img_loss is a zero of audio_loss's dtype instead of raising.
  """
  audio_loss = audio_weight * tf.reduce_mean(
    tf.keras.losses.mean_squared_error(orig_audio, pred))
  if marked is None or orig_img is None:
    img_loss = tf.zeros_like(audio_loss)
  else:
    img_loss = img_weight * tf.reduce_mean(
      tf.keras.losses.mean_squared_error(orig_img, marked))
  return img_loss, audio_loss

In [None]:
@tf.function
def train_step(iterator, steps_per_epoch):
  """Run `steps_per_epoch` distributed training steps, one batch per step.

  Uses notebook globals created under strategy.scope(): the model `o`, the
  optimizer `optim`, `loss_fn` and the training_* metrics.
  """
  def train_step_fn(inputs):
    images, audios = inputs
    with tf.GradientTape() as tape:
      preds, _, marked = o(audios, images)
      img_loss, aud_loss = loss_fn(preds, audios, marked, images)
      # Mean loss over this replica's share of the global batch.
      loss = img_loss + aud_loss
      # apply_gradients sums gradients across replicas. Dividing by the
      # replica count makes that sum the gradient of the global-batch mean
      # (same as tf.nn.compute_average_loss with the global batch size).
      scaled_loss = loss / strategy.num_replicas_in_sync
    gradients = tape.gradient(scaled_loss, o.trainable_variables)
    optim.apply_gradients(zip(gradients, o.trainable_variables))
    # SSIM is only reported, so it is computed outside the tape.
    ssim = tf.image.ssim(images, marked, 1.0)
    # Keras metrics already average their updates over replicas, so record
    # the unscaled per-replica mean loss.
    training_loss.update_state(loss)
    training_accuracy.update_state(audios, preds)
    training_ssim.update_state(ssim)

  for _ in tf.range(steps_per_epoch):
    strategy.run(train_step_fn, args=(next(iterator),))

@tf.function
def val_step(iterator, steps_per_epoch):
  """Run `steps_per_epoch` distributed evaluation steps, one batch per step.

  Uses notebook globals created under strategy.scope(): the model `o`,
  `loss_fn` and the val_* metrics. No weights are updated.
  """
  def val_step_fn(inputs):
    images, audios = inputs
    # NOTE(review): `o` is called without a training flag; if AWM contains
    # dropout/batch-norm, consider an inference flag (depends on AWM.call).
    preds, _, marked = o(audios, images)
    img_loss, aud_loss = loss_fn(preds, audios, marked, images)
    loss = img_loss + aud_loss
    ssim = tf.image.ssim(images, marked, 1.0)
    # Keras metrics already average their updates over replicas;
    # multiplying by the replica count would inflate the reported loss.
    val_loss.update_state(loss)
    val_accuracy.update_state(audios, preds)
    val_ssim.update_state(ssim)

  for _ in tf.range(steps_per_epoch):
    strategy.run(val_step_fn, args=(next(iterator),))

with strategy.scope():
  # Anything that owns variables (model, optimizer, metrics) is created inside
  # the strategy scope so it is placed on the replicas.
  o = AWM()
  # To warm-start from saved weights, load them here, e.g.
  #   o.load_weights(path, options=tf.saved_model.LoadOptions(
  #       experimental_io_device='/job:localhost'))
  # (an earlier run also copied encoder/decoder from a pretrained A2A model).
  optim = tf.keras.optimizers.Adam(learning_rate=1e-4)

  # Loss and SSIM are tracked as running means. The "accuracy" metrics are
  # the RMSE between the original and the decoded watermark audio.
  training_loss, val_loss = tf.keras.metrics.Mean(), tf.keras.metrics.Mean()
  training_accuracy, val_accuracy = (
    tf.keras.metrics.RootMeanSquaredError(),
    tf.keras.metrics.RootMeanSquaredError(),
  )
  training_ssim, val_ssim = tf.keras.metrics.Mean(), tf.keras.metrics.Mean()
  loss_fn = loss_function

In [None]:
# Hard-coded sample counts, presumably the train/validation set sizes —
# TODO keep in sync with the data actually loaded.
STEPS_PER_EPOCH_TRAIN = 42643 // BATCH_SIZE
STEPS_PER_EPOCH_VAL = 9768 // BATCH_SIZE

EPOCHS = 200
losses, losses_val = [], []
accuracies, accuracies_val = [], []
ssims, ssims_val = [], []

# Each metric paired with the per-epoch history list it feeds.
metric_histories = (
  (training_loss, losses),
  (val_loss, losses_val),
  (training_accuracy, accuracies),
  (val_accuracy, accuracies_val),
  (training_ssim, ssims),
  (val_ssim, ssims_val),
)

for epoch in trange(EPOCHS):
  train_step(train_iterator, STEPS_PER_EPOCH_TRAIN)
  val_step(val_iterator, STEPS_PER_EPOCH_VAL)

  # Record this epoch's value, then clear the metric for the next epoch.
  for tracked_metric, history_list in metric_histories:
    history_list.append(float(tracked_metric.result().numpy()))
    tracked_metric.reset_states()

In [None]:
fig, ax = plt.subplots(figsize=(16, 9), dpi=150)
# Number epochs from 1, consistent with the best-epoch printout.
epoch_numbers = range(1, len(losses) + 1)
ax.plot(epoch_numbers, losses, label='Training Loss')
ax.plot(epoch_numbers, losses_val, label='Validation Loss')
ax.set_title('Loss vs Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (weighted image + audio MSE)')
ax.legend()
ax.grid(ls='--')
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(16, 9), dpi=150)
# The "accuracy" histories hold RootMeanSquaredError between the original and
# the decoded watermark audio, so lower is better; label them as such.
epoch_numbers = range(1, len(accuracies) + 1)
ax.plot(epoch_numbers, accuracies, label='Training RMSE')
ax.plot(epoch_numbers, accuracies_val, label='Validation RMSE')
ax.set_title('Watermark Audio RMSE vs Epoch (lower is better)')
ax.set_xlabel('Epoch')
ax.set_ylabel('RMSE')
ax.legend()
ax.grid(ls='--')
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(16, 9), dpi=150)
# Number epochs from 1, consistent with the best-epoch printout.
epoch_numbers = range(1, len(ssims) + 1)
ax.plot(epoch_numbers, ssims, label='Training SSIM')
ax.plot(epoch_numbers, ssims_val, label='Validation SSIM')
ax.set_title('Marked Image SSIM vs Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('SSIM (cover vs marked image)')
ax.legend()
ax.grid(ls='--')
plt.show()

In [None]:
if ssims_val:
  # np.argmax returns the first maximum, like list.index(max(...)).
  best_idx = int(np.argmax(ssims_val))
  print(f'Max SSIM Score {ssims_val[best_idx]} at epoch {best_idx+1}')
else:
  print('No validation SSIM recorded yet; run the training cell first.')

In [None]:
def visualize_output(model, dataset, samples=1, output_loss=False):
  """Plot inputs and outputs for the first `samples` examples of one batch.

  Args:
    model: called as model(audios, imgs); its output is unpacked as
      (decoded watermark, encoding, marked image).
    dataset: an *iterator* over (images, audios) batches, not a Dataset.
      next() is called on it once.
    samples: how many examples of the batch to plot; clamped to the batch size.
    output_loss: if True, also print the batch-mean SSIM and both terms
      returned by loss_function.
  """
  imgs, audios = next(dataset)
  preds, _, marked = model(audios, imgs)

  n_samples = min(samples, int(imgs.shape[0]))
  for i in range(n_samples):
    # One row of four panels per example.
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    axes[0].set_title('Cover Image')
    axes[0].imshow(imgs[i])

    axes[1].set_title('Watermark')
    axes[1].plot(audios[i])

    axes[2].set_title('Marked')
    # Clip to the valid float RGB range; imshow would otherwise clip and warn.
    axes[2].imshow(np.clip(np.asarray(marked[i]), 0.0, 1.0))

    axes[3].set_title('Decoded Watermark')
    axes[3].plot(preds[i])
    plt.show()
    # Release each figure: many samples means many figures.
    plt.close(fig)

  if output_loss:
    # Note: this is the SSIM score (higher is better), despite the label.
    ssim = tf.reduce_mean(tf.image.ssim(imgs, marked, 1.0))
    print('img ssim loss:', ssim.numpy())

    loss = loss_function(preds, audios, marked, imgs)
    print('img loss:', loss[0].numpy())
    print('audio loss:', loss[1].numpy())

In [None]:
# Route filesystem access through the local host ('/job:localhost'). By default
# each variable is written from the host it lives on, which for a remote TPU
# is presumably not this machine's filesystem.
save_options = tf.saved_model.SaveOptions(
  experimental_io_device='/job:localhost'
)
o.save_weights('trained/AWM_109', options=save_options)

In [None]:
# A plain (non-distributed) pipeline with batches of 32, used only to draw
# examples for visual inspection.
test_ds = get_dataset(batch_size=32, dataset_type='test')
test_iter = iter(test_ds)

In [None]:
# Plot up to 32 examples from one test batch, without printing loss terms.
visualize_output(o, test_iter, samples=32, output_loss=False)

In [None]:
def save_history(array=None, desc=''):
  """Write a JSON-serialisable history to ``{desc}.json`` (UTF-8, indent 2).

  Args:
    array: values to write; defaults to an empty list. A ``None`` sentinel
      avoids a shared mutable default argument.
    desc: output file stem, which may include a directory; '.json' is appended.
  """
  if array is None:
    array = []
  with open(f'{desc}.json', 'w', encoding='utf-8') as f:
    json.dump(array, f, ensure_ascii=False, indent=2)

In [None]:
# One JSON file per tracked history, named by its stem.
for history_values, history_name in (
  (losses, 'train_loss'),
  (losses_val, 'val_loss'),
  (accuracies, 'train_acc'),
  (accuracies_val, 'val_acc'),
  (ssims, 'train_ssim'),
  (ssims_val, 'val_ssim'),
):
  save_history(history_values, history_name)