### Open in [Google Colab](https://colab.research.google.com/drive/12jQXxqLUIfRZmdsAFX5ENLr2evztxIZO?usp=sharing)

# Setup

In [1]:
%load_ext tensorboard

In [None]:
# Colab bootstrap: mount Google Drive, unpack the dataset, install the extra
# packages and copy the project-local modules (params.py, melspec.py) next to
# the notebook so they can be imported below.
try:
  from google.colab import drive
  drive.mount('/content/drive', force_remount=True)
  path_top = '/content/drive/MyDrive/Master Thesis/repo'
  !cp "$path_top/dataset.zip" .
  !unzip -q dataset.zip
  !pip install ltn -q
  !pip install tensorflow-io -q
  !pip install tensorflow-addons -q
  !cp "$path_top/code/params.py" .
  !cp "$path_top/code/melspec.py" .
except Exception:
  # Any exception raised above (e.g. google.colab is not importable) lands here;
  # path_top is then left empty and none of the remaining setup steps run.
  path_top = ''

Mounted at /content/drive


In [None]:
import os
import time
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_io as tfio
import tensorflow_addons as tfaddons
import datetime
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow_datasets as tfds
import IPython
from tensorflow.keras import layers
from scipy.io import wavfile as wf
from scipy.stats import norm

import params
from melspec import MelSpec

In [None]:
LABEL_RATE  = 2  # rate of labels in Hz

# Locations of the audio files and of the annotation CSVs, relative to the
# working directory.
audio_path  = 'dataset/audio'
label_path  = 'dataset/DEAM_Annotations/annotations/annotations averaged per song/dynamic (per second annotations)'

---
# Dataset

## Load Audio Data

In [None]:
def load_audio_data():
    """Create the raw audio dataset from the files under `audio_path`.

    The dataset is unlabeled, unbatched and unshuffled, and every element is
    requested with an output length of 45 seconds worth of samples
    (45 * params.SAMPLE_RATE).

    Returns:
        The `tf.data.Dataset` produced by
        `tf.keras.utils.audio_dataset_from_directory`.
    """
    loader_options = dict(
        labels=None,
        batch_size=None,
        shuffle=False,
        seed=None,
        output_sequence_length=(45 * params.SAMPLE_RATE),  # 45 seconds
        follow_links=True,
    )
    return tf.keras.utils.audio_dataset_from_directory(audio_path, **loader_options)

## Audio Preprocessing

### Skip first 15 seconds
The first 15 seconds are not annotated. Thus, if labels are used, the first 15 seconds of audio should be skipped.

In [None]:
# Number of leading samples that correspond to 15 seconds of audio.
_15_SEC_SKIP = params.SAMPLE_RATE * 15

def skip15sec(data):
    """Return `data` without its first `_15_SEC_SKIP` entries along axis 0."""
    trimmed = data[_15_SEC_SKIP:]
    return trimmed

### Make mono

In [None]:
def to_mono(data):
    """Reduce multi-channel audio to a single channel.

    Keeps index 0 of the last axis (the first channel) and drops that axis;
    the other channels are discarded, not averaged.
    """
    first_channel = data[..., 0]
    return first_channel

### Reshape data to match labels

In [None]:
def reshape_data(data):
    """Reshape audio into one row of samples per label.

    Empty input is returned unchanged. Otherwise the data is reshaped to
    (LABEL_RATE * 30, params.SAMPLE_RATE // LABEL_RATE), i.e. one row per
    label for 30 seconds of labels at LABEL_RATE Hz.
    """
    if len(data) != 0:
        n_rows = LABEL_RATE * 30
        samples_per_row = params.SAMPLE_RATE//LABEL_RATE
        return tf.reshape(data, (n_rows, samples_per_row, ))
    return data

In [None]:
# cuts a full-song spectrogram into pieces according to the input size of the model
def sequence_spect(data, seq_len=params.SEQ_LENGTH, lstm_window=params.LSTM_WINDOW, bins=params.MEL_BINS):
    """Crop a random window from a spectrogram and split it into model inputs.

    A contiguous run of `seq_len * lstm_window` frames is taken at a random
    offset and reshaped to `(seq_len, lstm_window, bins)`.

    The offset is drawn with `tf.random.uniform` rather than NumPy: this
    function is passed to `Dataset.map`, which traces it once, so a NumPy
    random call would be evaluated a single time and the same offset would be
    reused for every element and every epoch.

    Args:
        data: spectrogram whose leading axis is time (frames).
        seq_len: number of pieces to produce.
        lstm_window: number of frames per piece.
        bins: size of the last axis of each piece.

    Returns:
        `data` unchanged when it is empty, otherwise a tensor of shape
        `(seq_len, lstm_window, bins)`.

    Raises:
        ValueError: if `data` has fewer than `seq_len * lstm_window` frames.
    """
    if len(data) == 0:
        return data

    target_len = seq_len * lstm_window
    trim = len(data) - target_len  # number of surplus frames
    if trim < 0:
        raise ValueError(
            f"Spectrogram has {len(data)} frames but {target_len} are required"
        )

    # random start in [0, trim]; maxval is exclusive
    start = tf.random.uniform((), minval=0, maxval=trim + 1, dtype=tf.int32)
    data = data[start:start + target_len]

    return tf.reshape(data, (seq_len, lstm_window, bins))

In [None]:
# changes (batch_size, seq_length) into (batch_size * seq_length,)
def unsequence_spect(data, spec_shape = (params.LSTM_WINDOW, params.MEL_BINS)):
    """Merge all leading axes of `data` into one, keeping `spec_shape` as the trailing axes."""
    flat_shape = (-1,) + tuple(spec_shape)
    return tf.reshape(data, flat_shape)

### Mel-scaled Fourier transform

In [None]:
# create melspec layer for audio-to-spectrogram transformation
# All settings come from params.py; note that min_db is passed as the negated
# params.DB_MIN.
mel_spec = MelSpec(
    frame_length = params.STFT_WINDOW,
    frame_step = params.STFT_HOP,
    sampling_rate = params.SAMPLE_RATE,
    num_mel_channels = params.MEL_BINS,
    freq_min = params.MEL_FREQ_MIN,
    freq_max = params.MEL_FREQ_MAX,
    min_db = -params.DB_MIN,
)

# wrapper function for melspec layer
def mel_spectrogram(data):
    """Run `data` through the module-level `mel_spec` layer.

    A trailing axis of size 1 is appended to `data` before the layer is called.
    """
    with_trailing_axis = tf.expand_dims(data, -1)
    return mel_spec(with_trailing_axis)

### Audio Reconstruction (Griffin-Lim)

In [None]:
# Implements Griffin-Lim for audio reconstruction
def mel_to_audio(S):
    """Reconstruct a waveform from a mel spectrogram.

    Steps, in order:
      1. project the mel bins back onto linear STFT bins with a normalised
         transpose of the mel weight matrix,
      2. convert the result from dB to amplitude (10^(dB / 20)),
      3. run `tfio.audio.inverse_spectrogram` for
         `params.GRIFFIN_LIM_ITER` iterations.

    Args:
        S: spectrogram whose last axis has `params.MEL_BINS` entries.

    Returns:
        The reconstructed audio as a NumPy array.
    """
    spect = tf.cast(S, tf.float32)

    # mel to linear
    mel_weights = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=params.MEL_BINS,
        num_spectrogram_bins=params.STFT_WINDOW//2+1,
        sample_rate=params.SAMPLE_RATE,
        lower_edge_hertz=params.MEL_FREQ_MIN,
        upper_edge_hertz=params.MEL_FREQ_MAX,
        dtype=tf.dtypes.float32
    ).numpy()
    # Normalise each linear bin by its total mel weight; bins with zero total
    # weight would give inf/nan, which nan_to_num cleans up.
    with np.errstate(divide="ignore", invalid="ignore"):
        per_bin_total = np.sum(mel_weights, axis=1)
        normalised = np.nan_to_num(np.divide(mel_weights.T, per_bin_total))
        mel_inversion_matrix = tf.constant(normalised.T)
    spect = tf.tensordot(spect, tf.transpose(mel_inversion_matrix), 1)

    # dB to amplitude
    spect = tf.pow(10.0, spect / 20)  # 10^(dB / 20)

    # Griffin-Lim:
    audio = tfio.audio.inverse_spectrogram(
        spect,
        params.STFT_WINDOW,
        params.STFT_WINDOW,
        params.STFT_HOP,
        iterations=params.GRIFFIN_LIM_ITER,
    )
    return audio.numpy()

### Map preprocessing to dataset

In [None]:
# Build the training spectrogram pipeline from the raw audio files.
ds_audio_train = load_audio_data()

# Drop the leading 15 seconds and keep only the first channel.
ds_audio_train = ds_audio_train.map(skip15sec)
ds_audio_train = ds_audio_train.map(to_mono)
# Audio -> mel spectrogram, then crop/reshape each song into
# (SEQ_LENGTH, LSTM_WINDOW, MEL_BINS) pieces.
ds_spect_train = ds_audio_train.map(mel_spectrogram)
ds_spect_train = ds_spect_train.map(sequence_spect)
#ds_spect_train = ds_spect_train.map(unsequence_spect)
# Split the leading axis so that every dataset element is a single piece.
ds_spect_train = ds_spect_train.unbatch()

### Check audio data

In [None]:
## S := TxF matrix
def show_spectrogram(S, title='', n_spect=1, scale=5, rhythm_plots=True):
    """Plot a spectrogram and, optionally, its rhythm diagnostics.

    Args:
        S: time x frequency matrix; it is transposed for display so that
            frequency runs along the vertical axis.
        title: figure title.
        n_spect: multiplies the figure width (width = scale * n_spect).
        scale: base figure size.
        rhythm_plots: when true, also plot the per-frame mean of `S` and the
            magnitude of its STFT restricted to
            [params.RL_LOW_BOUND, params.RL_HI_BOUND).
    """
    # spectrogram image, colour range fixed to [-40, 70]
    plt.figure(figsize=(scale*n_spect, scale))
    plt.imshow(S.T, origin='lower', vmin=-40, vmax=70)
    plt.axis('off')
    plt.title(title)
    plt.show()
    plt.close()

    if not rhythm_plots:
        return

    # mean over the frequency axis for every time step
    envelope = tf.reduce_mean(S, axis=1, keepdims=True)
    fig = plt.figure(figsize=(7, 2))
    left = fig.add_subplot(1, 2, 1)
    left.plot(envelope)

    # STFT with frame length equal to the whole signal and step 1
    spectrum = tf.abs(tf.signal.stft(
        tf.squeeze(envelope),
        envelope.shape[0],
        1,
        window_fn=None
    ))
    band = spectrum[:, params.RL_LOW_BOUND:params.RL_HI_BOUND]
    band = tf.reshape(band, (band.shape[1], band.shape[0]))
    right = fig.add_subplot(1, 2, 2)
    right.plot(band)
    plt.show()
    plt.close()

def show_audio(S):
    """Reconstruct audio from spectrogram `S` and show an audio player.

    The waveform is written to 'rec.wav' in the working directory
    (overwriting any previous file) and then displayed.
    """
    rec_aud = mel_to_audio(tf.squeeze(S))
    wf.write('rec.wav', params.SAMPLE_RATE, rec_aud)
    display(IPython.display.Audio('rec.wav'))

In [None]:
# FOR SPECTROGRAMS
# Sanity check: take one spectrogram piece from the dataset, print its value
# range and shape, plot it and play back the reconstructed audio.
skip = 0
n = 1
for x in ds_spect_train.skip(skip).take(1):
    x = x.numpy()
    print(np.min(x), np.max(x))
    print(x.shape)
    show_spectrogram(x)
    print("Reconstructed Audio:")
    show_audio(x)

# The string below is disabled code kept for reference; it is only assigned
# to `_` and never executed.
_="""
# AUDIO for reference
aud = None
for a in ds_audio_train.skip(skip).take(n):
    if aud is None:
       aud = a.numpy()
    else:
        aud = np.concatenate((aud, a), axis=0)

wf.write('test.wav', params.SAMPLE_RATE, aud)
test = Audio('test.wav')
display(test)
"""

## Load Labels

In [None]:
# Load the valence annotations, drop every column that contains a missing
# value, then drop the song id column so only annotation values remain.
ds_valence = pd.read_csv(f"{label_path}/valence.csv")
ds_valence = ds_valence.dropna(axis=1)
ds_valence = ds_valence.drop(columns='song_id')
#ds_valence

In [None]:
# Load the arousal annotations, drop every column that contains a missing
# value, then drop the song id column so only annotation values remain.
ds_arousal = pd.read_csv(f"{label_path}/arousal.csv")
ds_arousal = ds_arousal.dropna(axis=1)
ds_arousal = ds_arousal.drop(columns='song_id')
#ds_arousal

### average labels for each spectrogram

In [None]:
# Number of consecutive annotation columns averaged into one label.
n_lab_per_spect = 6

# Transpose so the annotation columns become rows, average every
# n_lab_per_spect consecutive rows, and transpose back (one row per song).
# NOTE(review): presumably the remaining columns are in time order - confirm.
ds_arousal = ds_arousal.transpose()
ds_arousal = ds_arousal.groupby(np.arange(len(ds_arousal.index))//n_lab_per_spect).mean()
ds_arousal = ds_arousal.transpose()
# Flatten to one label per spectrogram piece. The reshape only succeeds if the
# total number of labels equals len(ds_spect_train).
ds_arousal = np.reshape(ds_arousal.to_numpy(), (len(ds_spect_train), 1))

# Same procedure for valence.
ds_valence = ds_valence.transpose()
ds_valence = ds_valence.groupby(np.arange(len(ds_valence.index))//n_lab_per_spect).mean()
ds_valence = ds_valence.transpose()
ds_valence = np.reshape(ds_valence.to_numpy(), (len(ds_spect_train), 1))

In [None]:
# Mean and standard deviation of all arousal labels; read later by
# make_generator_labels to sample random arousal values.
ARO_MEAN, ARO_STD = ds_arousal.mean(), ds_arousal.std()

# Histogram of the labels with the fitted normal density drawn on top.
_ = plt.hist(ds_arousal, bins=100, density=True, range=[-1, 1])
plt.plot(np.arange(-1, 1, 0.01), norm.pdf(np.arange(-1, 1, 0.01), ARO_MEAN, ARO_STD))
plt.title("Arousal distribution")
print(f"Mean: {ARO_MEAN}\nStddev: {ARO_STD}")

In [None]:
# Mean and standard deviation of all valence labels; read later by
# make_generator_labels to sample random valence values.
VAL_MEAN, VAL_STD = ds_valence.mean(), ds_valence.std()

# Histogram of the labels with the fitted normal density drawn on top.
_ = plt.hist(ds_valence, bins=100, density=True, range=[-1, 1])
plt.plot(np.arange(-1, 1, 0.01), norm.pdf(np.arange(-1, 1, 0.01), VAL_MEAN, VAL_STD))
plt.title("Valence distribution")
print(f"Mean: {VAL_MEAN}\nStddev: {VAL_STD}")

## Combine audio and labels

In [None]:
# Turn the label arrays into datasets of float32 rows.
ds_arousal = tf.data.Dataset.from_tensor_slices(ds_arousal.astype(np.float32))
ds_valence = tf.data.Dataset.from_tensor_slices(ds_valence.astype(np.float32))

# Pair every spectrogram piece with its (arousal, valence) labels by position.
# NOTE(review): this relies on the spectrogram order matching the label row
# order - confirm against the file ordering of the audio loader.
ds_train = tf.data.Dataset.zip((
    ds_spect_train,
    (ds_arousal, ds_valence)
))

In [None]:
# Shuffle with a buffer of 80 batches (reshuffled on every iteration), batch,
# and prefetch with an automatically tuned buffer.
ds_train = ds_train.shuffle(buffer_size=80*params.BATCH_SIZE, reshuffle_each_iteration=True)
ds_train = ds_train.batch(params.BATCH_SIZE)
ds_train = ds_train.prefetch(buffer_size=tf.data.AUTOTUNE)

# Model

## Generator


In [None]:
def create_generator(noise_dim = params.GEN_NOISE_DIM):
    """Build the conditional spectrogram generator.

    Noise, arousal, valence and the rhythm vector are each embedded by a Dense
    layer, stacked, reshaped into a small feature map and upsampled five times
    (x32 overall) to a `(params.LSTM_WINDOW, params.MEL_BINS)` spectrogram with
    values in [0, 1] (sigmoid output).

    Args:
        noise_dim: length of the latent noise vector.

    Returns:
        A `tf.keras.Model` with inputs `[noise, [arousal, valence], rhythm]`.

    Raises:
        ValueError: if `params.LSTM_WINDOW` or `params.MEL_BINS` is not a
            multiple of the total upsampling factor; the final reshape could
            not succeed in that case.
    """
    upsample_filters = (512, 512, 256, 256, 128)
    factor = 2 ** len(upsample_filters)  # every block doubles both axes -> 32
    if params.LSTM_WINDOW % factor or params.MEL_BINS % factor:
        raise ValueError(
            f"LSTM_WINDOW ({params.LSTM_WINDOW}) and MEL_BINS ({params.MEL_BINS}) "
            f"must both be multiples of {factor}"
        )

    base_h = params.LSTM_WINDOW // factor
    base_w = params.MEL_BINS // factor
    # Parenthesised product: `a//32 * b//32` would evaluate as ((a//32)*b)//32.
    emb_size = base_h * base_w
    emb_ch = 16

    noise_inputs = tf.keras.Input(batch_shape=(None, noise_dim))
    arousal_inputs = tf.keras.Input(batch_shape=(None, 1))
    valence_inputs = tf.keras.Input(batch_shape=(None, 1))
    rhythm_inputs = tf.keras.Input(batch_shape=(None, params.RL_HI_BOUND-params.RL_LOW_BOUND))

    # embed labels
    x_a = layers.Dense(emb_size * emb_ch, activation='relu')(arousal_inputs)
    x_v = layers.Dense(emb_size * emb_ch, activation='relu')(valence_inputs)
    x_r = layers.Dense(emb_size * emb_ch, activation='relu')(rhythm_inputs)

    # noise
    x = layers.Dense(emb_size * emb_ch, activation='relu')(noise_inputs)

    # four embeddings side by side -> (batch, emb_size * emb_ch, 4)
    x = tf.stack([x, x_a, x_v, x_r], axis=-1)

    # deconv upsample
    x = layers.Reshape((base_h, base_w, emb_ch * 4))(x)
    for filters in upsample_filters:
        #x = layers.Conv2DTranspose(filters, (3,3), strides=2, padding='same')(x)
        x = layers.UpSampling2D()(x)
        x = layers.Conv2D(filters, (3, 3), strides=1, padding='same')(x)
        x = layers.ReLU()(x)
        x = layers.BatchNormalization()(x)

    x = layers.Conv2D(1, (3,3), strides=1, padding='same', activation='sigmoid')(x)

    # drop the single channel axis
    spec = layers.Reshape((params.LSTM_WINDOW, params.MEL_BINS))(x)

    outputs = spec
    return tf.keras.Model([noise_inputs, [arousal_inputs, valence_inputs], rhythm_inputs], outputs)

#create_generator().summary()
#tf.keras.utils.plot_model(create_generator(), show_shapes=True)


## Discriminator

In [None]:
class WeightClip(tf.keras.constraints.Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range

    Every weight is limited to [-c, c].

    Args:
        c: clipping bound.
        **kwargs: accepted for compatibility and ignored.
    '''
    def __init__(self, c=0.01, **kwargs):
        self.c = c

    def __call__(self, p):
        return tf.keras.backend.clip(p, -self.c, self.c)

    def get_config(self):
        # Without this the bound is not serialised, and a model that is saved
        # and loaded again falls back to the default c=0.01.
        return {'c': self.c}

def create_discriminator(
    input_shape = (params.LSTM_WINDOW, params.MEL_BINS)):
    """Build the conditional critic that scores whole spectrograms.

    Arousal and valence are each embedded by a Dense layer, reshaped into a
    coarse map and upsampled to the spectrogram size, then concatenated with
    the spectrogram as extra channels. Six stride-2 residual convolution
    blocks follow, and a linear Dense layer produces one score per sample.
    All kernels are clipped to +/- params.WGAN_D_CONSTRAINT.

    Args:
        input_shape: (time, frequency) shape of one spectrogram.

    Returns:
        A `tf.keras.Model` with inputs `[spectrogram, arousal, valence]`.
    """
    inputs = tf.keras.Input(shape=input_shape)
    arousal_inputs = tf.keras.Input(batch_shape=(None, 1))
    valence_inputs = tf.keras.Input(batch_shape=(None, 1))

    n_emb_upsampling = 3
    emb_shape = (input_shape[0]//(2**n_emb_upsampling), input_shape[1]//(2**n_emb_upsampling))
    # Dense expects an integer unit count, so the product is computed with
    # plain Python arithmetic instead of a TensorFlow op.
    emb_units = emb_shape[0] * emb_shape[1]

    x_a = layers.Dense(emb_units, activation='relu', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(arousal_inputs)
    x_v = layers.Dense(emb_units, activation='relu', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(valence_inputs)

    x_a = layers.Reshape((*emb_shape, 1))(x_a)
    x_v = layers.Reshape((*emb_shape, 1))(x_v)

    # bring the label maps back to the spectrogram resolution
    for _ in range(n_emb_upsampling):
        x_a = layers.UpSampling2D()(x_a)
        x_v = layers.UpSampling2D()(x_v)

    # downsample
    x = tf.expand_dims(inputs, axis=-1)
    x = tf.concat((x, x_a, x_v), axis=-1)
    for filters in (64, 64, 128, 128, 256, 256):
        # 1x1 stride-2 projection so the skip path matches the main path
        res = layers.Conv2D(filters, 1, strides=2, padding='same', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(x)

        x = layers.Conv2D(filters, (3,3), strides=2, padding='same', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(0.3)(x)
        x = layers.BatchNormalization()(x)

        x = layers.add([x, res])

    x = layers.Flatten()(x)
    x = layers.Dense(1, activation='linear', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(x)

    return tf.keras.Model([inputs, arousal_inputs, valence_inputs], x)

#create_discriminator().summary()
#tf.keras.utils.plot_model(create_discriminator(), show_shapes=True)

In [None]:
def create_patch_discriminator(
    input_shape = (params.PATCH_WIDTH, params.MEL_BINS),
    lstm_units = params.LSTM_UNITS_DISC):
    """Build the conditional critic that scores spectrogram patches.

    Same layout as the full critic but for `(PATCH_WIDTH, MEL_BINS)` patches:
    label embeddings are upsampled once and concatenated as extra channels,
    followed by four stride-2 residual blocks with (2, 4) kernels and a linear
    Dense output. All kernels are clipped to +/- params.WGAN_D_CONSTRAINT.

    Args:
        input_shape: (time, frequency) shape of one patch.
        lstm_units: unused; kept so existing callers keep working.

    Returns:
        A `tf.keras.Model` with inputs `[patch, arousal, valence]`.
    """
    inputs = tf.keras.Input(shape=input_shape)
    arousal_inputs = tf.keras.Input(batch_shape=(None, 1))
    valence_inputs = tf.keras.Input(batch_shape=(None, 1))

    n_emb_upsampling = 1
    emb_shape = (input_shape[0]//(2**n_emb_upsampling), input_shape[1]//(2**n_emb_upsampling))
    # Dense expects an integer unit count, so the product is computed with
    # plain Python arithmetic instead of a TensorFlow op.
    emb_units = emb_shape[0] * emb_shape[1]

    x_a = layers.Dense(emb_units, activation='relu', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(arousal_inputs)
    x_v = layers.Dense(emb_units, activation='relu', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(valence_inputs)

    x_a = layers.Reshape((*emb_shape, 1))(x_a)
    x_v = layers.Reshape((*emb_shape, 1))(x_v)

    # bring the label maps back to the patch resolution
    for _ in range(n_emb_upsampling):
        x_a = layers.UpSampling2D()(x_a)
        x_v = layers.UpSampling2D()(x_v)

    # downsample
    x = tf.expand_dims(inputs, axis=-1)
    x = tf.concat((x, x_a, x_v), axis=-1)
    for filters in (64, 128, 128, 256):
        # 1x1 stride-2 projection so the skip path matches the main path
        res = layers.Conv2D(filters, 1, strides=2, padding='same', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(x)

        # pad the frequency axis by one on each side before the 'valid' (2, 4) convolution
        x = layers.ZeroPadding2D(padding=(0, 1))(x)
        x = layers.Conv2D(filters, (2,4), strides=(2, 2), padding='valid', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(0.3)(x)
        x = layers.BatchNormalization()(x)

        x = layers.add([x, res])

    x = layers.Flatten()(x)
    x = layers.Dense(1, activation='linear', kernel_constraint=WeightClip(params.WGAN_D_CONSTRAINT))(x)

    return tf.keras.Model([inputs, arousal_inputs, valence_inputs], x)

#create_patch_discriminator().summary()
#tf.keras.utils.plot_model(create_patch_discriminator(), show_shapes=True)

# Training

## misc functions

### Spectrogram Normalisation layers
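* `NormSpect` clips the spectrogram to the dB range [-40, 70] and rescales it linearly into the range [0, 1]
* `DeNormSpect` applies the inverse mapping: it clips to [0, 1] and rescales back to dB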

In [None]:
class NormSpect(layers.Layer):
    """Clip a dB spectrogram to [min_dB, max_dB] and map it linearly onto [0, 1]."""

    def __init__(self, min_dB=-40., max_dB=70., **kwargs):
        super().__init__(**kwargs)
        self.min, self.max = min_dB, max_dB

        span = self.max - self.min
        # (x - min) / span  ==  x * (1 / span) + (-min / span)
        self.rescale = layers.Rescaling(scale=1./span, offset=(-self.min)/span)

    def call(self, spect):
        clipped = tf.clip_by_value(spect, self.min, self.max)
        return self.rescale(clipped)


class DeNormSpect(layers.Layer):
    """Clip a normalised spectrogram to [0, 1] and map it linearly onto [min_dB, max_dB]."""

    def __init__(self, min_dB=-40., max_dB=70., **kwargs):
        super().__init__(**kwargs)
        self.min, self.max = min_dB, max_dB

        span = self.max - self.min
        # x * span + min
        self.rescale = layers.Rescaling(scale=span, offset=self.min)

    def call(self, spect):
        clipped = tf.clip_by_value(spect, 0., 1.)
        return self.rescale(clipped)

# Shared layer instances with the default dB range; used by the training loop
# (normalize_spect) and by make_example (denormalize_spect).
normalize_spect = NormSpect()
denormalize_spect = DeNormSpect()

### Random Generator input

In [None]:
@tf.function
def make_latent_noise(batch_size=1):
    """Draw standard-normal latent vectors of shape (batch_size, params.GEN_NOISE_DIM)."""
    return tf.random.normal((batch_size, params.GEN_NOISE_DIM))

@tf.function
def make_generator_labels(batch_size=1):
    """Sample random conditioning labels for the generator.

    Returns:
        `[[arousal, valence], rhythm]` where arousal and valence have shape
        (batch_size, 1) and are drawn from normal distributions fitted to the
        dataset labels, and rhythm is a one-hot matrix of shape
        (batch_size, RL_HI_BOUND - RL_LOW_BOUND) with a uniformly chosen index.
    """
    n_rhythm = params.RL_HI_BOUND-params.RL_LOW_BOUND

    arousal = tf.random.normal((batch_size, 1), mean=ARO_MEAN, stddev=ARO_STD)
    valence = tf.random.normal((batch_size, 1), mean=VAL_MEAN, stddev=VAL_STD)

    # one random class index per sample -> one-hot, then drop the extra axis
    rhythm_idx = tf.random.uniform((batch_size, 1), minval=0, maxval=n_rhythm, dtype=tf.int32)
    rhythm = tf.squeeze(tf.one_hot(rhythm_idx, n_rhythm), axis=1)

    return [[arousal, valence], rhythm]

@tf.function
def make_fake_data(generator, batch_size=1):
    """Generate `batch_size` spectrograms from random noise and random labels.

    Returns:
        A pair (generated spectrograms, [arousal, valence] labels used); the
        rhythm labels are not returned.
    """
    latent = make_latent_noise(batch_size=batch_size)
    emotion_labels, rhythm_labels = make_generator_labels(batch_size=batch_size)
    fake = generator([latent, emotion_labels, rhythm_labels])
    return fake, emotion_labels


# Uses the generator to generate a single spectrogram
def make_example(generator, noise=None, labels=None):
    """Generate one spectrogram in dB as a NumPy array.

    Args:
        generator: the generator model.
        noise: latent noise for a batch of one; sampled when None.
        labels: generator labels for a batch of one; sampled when None.

    Returns:
        The denormalised generator output with the batch axis removed.
    """
    noise = make_latent_noise() if noise is None else noise
    labels = make_generator_labels() if labels is None else labels

    spect = denormalize_spect(generator([noise, *labels]))
    return tf.squeeze(spect, 0).numpy()

### Layers for dividing the spectrograms (and labels) into patches for the PatchGAN

In [None]:
class PatchToBatchLayer(layers.Layer):
    """Reshape spectrograms into a batch of (PATCH_WIDTH, MEL_BINS) patches."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, data):
        patch_shape = (-1, params.PATCH_WIDTH, params.MEL_BINS)
        return tf.reshape(data, patch_shape)

class PatchToBatchLayer_labels(layers.Layer):
    """Repeat each label row once per patch of its spectrogram.

    Every row is repeated LSTM_WINDOW // PATCH_WIDTH times along axis 0.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, data):
        patches_per_spect = params.LSTM_WINDOW//params.PATCH_WIDTH
        return tf.repeat(data, patches_per_spect, axis=0, name='repeat_labels_for_all_patches')

# Shared instances used by the training loop when it prepares patch batches
# (spectrograms and labels respectively).
reshape_to_patch_batch = PatchToBatchLayer()
reshape_to_patch_batch_l = PatchToBatchLayer_labels()

## GAN

### Generator losses
#### Similarity Loss
1. Divide the generated data into two equal length sets
2. Calculate MSE between the sets
3. Minimise negative MSE $\implies$ Maximise difference within each pair

#### Rhythm Loss
1. Take the mean amplitude in the generated spectrogram for each time step
2. Perform STFT on the "mean amplitude"-signal to extract rhythm frequencies
3. Mask all rhythm frequencies except for the label
4. Minimise negative rhythm frequency "amplitude" $\implies$ Maximise presence of rhythm

In [None]:
@tf.function
def similarity_loss(data):
    """Negative MSE between consecutive pairs of generated spectrograms.

    The batch is flattened per sample and grouped into pairs of neighbouring
    samples (so the batch size must be even). Minimising the returned value
    maximises the difference within each pair.
    """
    # data (None, 512, 128) => (None, 65536)
    flat_dim = params.LSTM_WINDOW*params.MEL_BINS
    pairs = tf.reshape(data, (-1, 2, flat_dim))
    first, second = tf.split(pairs, 2, axis=1)

    # MSE (pair[0], pair[1])
    mse = tf.keras.losses.mean_squared_error(first, second)
    return -mse

@tf.function
def rhythm_loss(data, rhythm_labels):
    """Negative magnitude of the labelled rhythm frequency.

    Args:
        data: batch of spectrograms; axis 2 is averaged away to get one value
            per time step.
        rhythm_labels: one-hot mask over the rhythm bins
            [params.RL_LOW_BOUND, params.RL_HI_BOUND).

    Returns:
        One value per sample: minus the largest masked STFT magnitude.
    """
    envelope = tf.reduce_mean(data, axis=2)
    spectrum = tf.abs(tf.signal.stft(
        envelope,
        params.LSTM_WINDOW,
        1,
        window_fn=None
    ))
    band = spectrum[..., params.RL_LOW_BOUND:params.RL_HI_BOUND]
    # squeeze(axis=1) requires the STFT to produce exactly one frame
    band = tf.squeeze(band, axis=1)
    masked = tf.math.multiply(band, rhythm_labels) # mask with one-hot labels
    return -tf.reduce_max(masked, axis=1)

### Define GAN connections
* Send generator output + input labels to the discriminators
* If PatchGAN $\implies$ Divide spectrograms into patches & distribute labels

In [None]:
def define_gan(generator, discriminator, patch=False):
    """Chain the generator into a discriminator as one Keras model.

    Side effect: sets `discriminator.trainable = False`.

    Args:
        generator: model with inputs (noise, [arousal, valence], rhythm).
        discriminator: critic taking a spectrogram plus arousal/valence labels.
        patch: when true, the generator output is split into patches and the
            labels are repeated per patch before they reach the critic.

    Returns:
        A model mapping the generator inputs to the critic output.
    """
    discriminator.trainable = False
    noise_g, e_labels, rhy_label = generator.input

    if patch:
        d_spect = PatchToBatchLayer()(generator.output)
        labels_a, labels_v = e_labels
        labels_d = [
            PatchToBatchLayer_labels()(labels_a),
            PatchToBatchLayer_labels()(labels_v),
        ]
    else:
        d_spect = generator.output
        labels_d = e_labels

    gan_out = discriminator([d_spect, labels_d])
    return tf.keras.Model([noise_g, e_labels, rhy_label], gan_out)

In [None]:
class GAN(tf.keras.Model):
    """Generator trainer combining critic, patch-critic, similarity and rhythm losses.

    Wraps the generator together with two stacked models (generator -> critic,
    generator -> patch critic). `train_step` updates only the generator
    weights.
    """
    def __init__(self, generator, discriminator, patch_disc, patch_weight=params.PATCH_WEIGHT, **kwargs):
        """Build the stacked models and the loss trackers.

        Note: `define_gan` sets `trainable = False` on both critics.
        """
        super(GAN, self).__init__(**kwargs)
        self.generator = generator
        self.gan = define_gan(generator, discriminator)
        self.patchgan = define_gan(generator, patch_disc, patch=True)
        self.patch_weight = patch_weight
        #self.gan.summary()
        
        # running means of the individual loss terms
        self.loss_tracker = tf.keras.metrics.Mean(name="generator loss")
        self.patch_loss_tracker = tf.keras.metrics.Mean(name='patch generator loss')
        self.combined_loss_tracker = tf.keras.metrics.Mean(name='combined generator loss')
        self.sim_loss_tracker = tf.keras.metrics.Mean(name="similarity generator loss")
        self.rhy_loss_tracker = tf.keras.metrics.Mean(name="rhythm generator loss")
    
    def compile(self, optimizer=None, loss=None, patch_loss=None):
        """Store the generator optimizer and the two critic loss callables."""
        super(GAN, self).compile()
        self.opt = optimizer
        self.loss = loss
        self.patch_loss = patch_loss

    @property
    def metrics(self):
        """Trackers exposed to Keras (listed here so Keras can reset them)."""
        return [self.loss_tracker, 
                self.patch_loss_tracker, 
                self.combined_loss_tracker, 
                self.sim_loss_tracker, 
                self.rhy_loss_tracker]
    
    # trains generator
    def train_step(self, data):
        """Run one generator update.

        `data[0]` is the generator input triple (noise, emotion labels, rhythm
        labels). Returns a dict with the current mean of every loss term.
        """
        g_in = data[0]
        _, _, rhythm_labels = g_in

        # NOTE(review): the sub-models below are called without an explicit
        # `training=` argument - confirm BatchNormalization/Dropout run in the
        # intended mode.
        with tf.GradientTape() as tape:
            pred = self.gan(g_in)     # predict sequences
            labels = tf.ones(tf.shape(pred))
            d_loss = self.loss(pred, labels)   # G_loss: D(G(z))
            self.loss_tracker.update_state(d_loss)

            # patch loss
            patch_pred = self.patchgan(g_in)
            patch_labels = tf.ones(tf.shape(patch_pred))
            patch_loss = self.patch_loss(patch_pred, patch_labels)
            self.patch_loss_tracker.update_state(patch_loss)  # log loss before weighting

            # similarity loss
            # the generator is evaluated again here to get the raw spectrograms
            g_spect = self.generator(g_in)
            sim_loss = similarity_loss(g_spect) * params.SIM_WEIGHT
            self.sim_loss_tracker.update_state(sim_loss)

            # rhythm loss
            rhy_loss = rhythm_loss(g_spect, rhythm_labels) * params.RHYTHM_WEIGHT
            self.rhy_loss_tracker.update_state(rhy_loss)
            
            # sum of the batch means of all four terms
            combined_loss = tf.reduce_mean(d_loss) + \
                            tf.reduce_mean(patch_loss) + \
                            tf.reduce_mean(sim_loss) + \
                            tf.reduce_mean(rhy_loss)
            self.combined_loss_tracker.update_state(combined_loss)
        
        # only the generator weights receive gradients
        grads = tape.gradient(combined_loss, self.generator.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {
            "g_loss":self.loss_tracker.result(),
            "g_patch_loss":self.patch_loss_tracker.result(),
            "g_combined_loss":self.combined_loss_tracker.result(),
            "g_sim_loss":self.sim_loss_tracker.result(),
            "g_rhy_loss":self.rhy_loss_tracker.result()
        }

## Training loop


In [None]:
def train(discriminator, patch_disc, gan, dataset, epochs = 10):
    """Alternate critic and generator updates (WGAN style) for `epochs` epochs.

    For every generator step, both critics are trained on up to
    `params.N_CRITIC` real batches, each paired with an equally sized fake
    batch. Losses are written to TensorBoard under `logs/train/<timestamp>/`.
    After every epoch two example spectrograms are shown (one from fixed
    noise/labels, one from fresh ones) and the generator and both critics are
    saved to `models/train/<timestamp>/<name>_<epoch>.h5`.

    Args:
        discriminator: compiled critic for whole spectrograms.
        patch_disc: compiled critic for spectrogram patches.
        gan: compiled `GAN` instance wrapping the generator and both critics.
        dataset: batched dataset yielding (spectrogram, (arousal, valence)).
        epochs: number of passes over `dataset`.
    """
    import shutil  # local import: only needed for the params.py snapshot

    batch_size = params.BATCH_SIZE # change this for LSTM

    d_loss_tracker = tf.keras.metrics.Mean('d_loss', dtype=tf.float32)
    g_loss_tracker = tf.keras.metrics.Mean('g_loss', dtype=tf.float32)

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    d_log_dir = 'logs/train/' + current_time + '/disc'
    d_summary_writer = tf.summary.create_file_writer(d_log_dir)
    g_log_dir = 'logs/train/' + current_time + '/gen'
    g_summary_writer = tf.summary.create_file_writer(g_log_dir)

    # Keep a copy of the hyper-parameters next to the saved models. This is
    # best effort: a missing params.py must not abort training.
    config_file = 'models/train/' + current_time + '/params.py'
    os.makedirs(os.path.dirname(config_file), exist_ok=True)
    try:
        shutil.copy('params.py', config_file)
    except OSError as err:
        print(f"Warning: could not copy params.py to {config_file}: {err}")

    # fixed generator input so the per-epoch examples are comparable
    const_noise = make_latent_noise()
    const_label = make_generator_labels()

    for epoch in range(epochs):
        print(f"Starting epoch {epoch+1}")
        t_start = time.time()

        data = iter(dataset)

        # one generator step per N_CRITIC real batches
        n_batches = int(np.ceil(len(dataset) / params.N_CRITIC))
        b_out = display(IPython.display.Pretty(f'Starting...'), display_id=True)
        for batch in range(n_batches):

            b_out.update(IPython.display.Pretty(f"Batch {batch+1} of {n_batches}."))
            # train discriminator
            for _ in range(params.N_CRITIC):
                try:
                    # get real batch

                    (X_real_S, X_real_L) = data.get_next()
                    X_real_S = normalize_spect(X_real_S)

                    # generate batch
                    X_fake_S, X_fake_L = make_fake_data(gan.generator, batch_size=X_real_S.shape[0]) # make batch of as many fake as real spectrograms

                    # create labels (real: 1, fake: -1)
                    l_real = tf.ones((X_real_S.shape[0], 1))
                    l_fake = -tf.ones((X_fake_S.shape[0], 1))

                    X = (tf.concat([X_real_S, X_fake_S], 0), (
                            tf.concat([X_real_L[0], X_fake_L[0]], 0),
                            tf.concat([X_real_L[1], X_fake_L[1]], 0))
                    )
                    l = tf.concat([l_real, l_fake], 0)

                    d_loss = discriminator.train_on_batch(x=X, y=l, reset_metrics=True)

                    d_loss_tracker(d_loss)

                    # patch discriminator training
                    X_real_p = reshape_to_patch_batch(X_real_S)
                    X_fake_p = reshape_to_patch_batch(X_fake_S)

                    # repeat all labels for all patches of spectrogram
                    X_real_L_a = reshape_to_patch_batch_l(X_real_L[0])
                    X_real_L_v = reshape_to_patch_batch_l(X_real_L[1])
                    X_fake_L_a = reshape_to_patch_batch_l(X_fake_L[0])
                    X_fake_L_v = reshape_to_patch_batch_l(X_fake_L[1])

                    l_real_p = tf.ones((X_real_p.shape[0], 1))
                    l_fake_p = -tf.ones((X_fake_p.shape[0], 1))

                    X_p = (tf.concat([X_real_p, X_fake_p], 0), (
                            tf.concat([X_real_L_a, X_fake_L_a], 0),
                            tf.concat([X_real_L_v, X_fake_L_v], 0))
                    )
                    l_p = tf.concat([l_real_p, l_fake_p], 0)
                    pd_loss = patch_disc.train_on_batch(x=X_p, y=l_p, reset_metrics=True)

                except tf.errors.OutOfRangeError: # if N_CRITIC batches are not available
                    break

            tb_batch = batch + epoch*n_batches
            # update tensorboard
            with d_summary_writer.as_default():
                tf.summary.scalar('d_loss', d_loss, step=tb_batch)
                tf.summary.scalar('pd_loss', pd_loss, step=tb_batch)

            # make noise
            g_noises = make_latent_noise(batch_size=batch_size)
            g_labels = make_generator_labels(batch_size=batch_size)

            # train generator
            g_loss = gan.train_on_batch([g_noises, *g_labels], reset_metrics=True, return_dict=True)

            g_loss_tracker(g_loss['g_combined_loss'])

            # update tensorboard logs
            with g_summary_writer.as_default():
                tf.summary.scalar('g_loss', g_loss['g_loss'], step=tb_batch)
                tf.summary.scalar('g_patch_loss', g_loss['g_patch_loss'], step=tb_batch)
                tf.summary.scalar('g_combined_loss', g_loss['g_combined_loss'], step=tb_batch)
                tf.summary.scalar('g_sim_loss', g_loss['g_sim_loss'], step=tb_batch)
                tf.summary.scalar('g_rhy_loss', g_loss['g_rhy_loss'], step=tb_batch)

            print('', end='\r')
        # print metrics
        print(f"\nEpoch {epoch+1} end ({time.time() - t_start:.1f}s):")
        print(f"  Discriminator loss: {np.mean(d_loss_tracker.result())}")
        print(f"  Generator loss: {np.mean(g_loss_tracker.result())}")


        # reset all metrics for new epoch
        d_loss_tracker.reset_states()
        g_loss_tracker.reset_states()

        # generate example
        for noise, labels in zip((const_noise, make_latent_noise()),
                                 (const_label, make_generator_labels())):
            out = make_example(gan.generator, noise=noise, labels=labels)

            show_spectrogram(out, title=f"min dB:{np.min(out)}, max dB:{np.max(out)}")
            show_audio(out)

        # save model
        # The generator is taken from `gan` so the function does not depend on
        # a notebook-level `generator` variable.
        models = {'gen':gan.generator, 'disc':discriminator, 'p_disc':patch_disc}
        for model_name, model in models.items():
            filename = 'models/train/' + current_time + f'/{model_name}_{epoch+1}.h5'
            model.save(filename)

In [None]:
# labels are either 1 or -1, depending on if the loss should be minimised or maximised
def wasserstein_loss(pred, labels):
    """Element-wise product of critic scores and their +1/-1 labels."""
    signed_scores = tf.multiply(pred, labels)
    return signed_scores

class WassersteinLoss(tf.keras.losses.Loss):
    """Weighted Wasserstein loss usable in `compile` and restorable from a saved model.

    Args:
        weight: factor applied to the element-wise loss.
        name: loss name.
        **kwargs: forwarded to `tf.keras.losses.Loss` (e.g. `reduction`).
            Keras passes these when it rebuilds the loss from its saved
            config, so they must be accepted here.
    """
    def __init__(self, weight=1., name="WassersteinLoss", **kwargs):
        super().__init__(name=name, **kwargs)
        self.weight = weight

    def call(self, pred, labels):
        # The loss is a plain product, so it does not matter which of the two
        # arguments Keras passes first.
        return wasserstein_loss(pred, labels) * self.weight

    def get_config(self):
        # Include the weight so a reloaded model uses the same scaling.
        config = super().get_config()
        config["weight"] = self.weight
        return config

## Set up Models

In [None]:
# Path pattern of a checkpoint to resume from; `{model}` is replaced by
# 'disc', 'p_disc' and 'gen'. Leave as None to train from scratch.
#continue_training = "drive/MyDrive/Master Thesis/notebooks/models/train/good_big/{model}_30.h5"
continue_training = None # None => train from scratch; use the pattern above to resume

if continue_training is not None:
    # Reload all three models; the custom objects are needed to deserialise
    # the constraint and loss classes defined in this notebook.
    discriminator = tf.keras.models.load_model(
        continue_training.format(model = 'disc'), 
        custom_objects={"wasserstein_loss": wasserstein_loss,
                        "similarity_loss": similarity_loss,
                        "WeightClip":WeightClip,
                        "WassersteinLoss":WassersteinLoss}
    )
    discriminator.trainable = True   
    patch_disc = tf.keras.models.load_model(
        continue_training.format(model = 'p_disc'), 
        custom_objects={"wasserstein_loss": wasserstein_loss,
                        "similarity_loss": similarity_loss,
                        "WeightClip":WeightClip,
                        "WassersteinLoss":WassersteinLoss}
    )
    patch_disc.trainable = True
    generator = tf.keras.models.load_model(
        continue_training.format(model = 'gen'), 
        custom_objects={"wasserstein_loss": wasserstein_loss,
                        "similarity_loss": similarity_loss,
                        "WeightClip":WeightClip,
                        "WassersteinLoss":WassersteinLoss}
    )
else:
    discriminator = create_discriminator()
    patch_disc = create_patch_discriminator()
    generator = create_generator()

# Both critics are compiled before GAN(...) is built; building the GAN sets
# `trainable = False` on them (see define_gan).
discriminator.compile(optimizer=tf.keras.optimizers.RMSprop(params.D_LR), 
                      loss=WassersteinLoss(weight=params.GAN_WEIGHT))
patch_disc.compile(optimizer=tf.keras.optimizers.RMSprop(params.PD_LR), 
                   loss=WassersteinLoss(weight=params.PATCH_WEIGHT))

# Generator trainer: uses the same weighted losses as the critics.
gan = GAN(generator, discriminator, patch_disc)
gan.compile(
    optimizer=tf.keras.optimizers.RMSprop(params.G_LR), 
    loss=WassersteinLoss(weight=params.GAN_WEIGHT), 
    patch_loss=WassersteinLoss(weight=params.PATCH_WEIGHT)
)

# Training Output

In [None]:
%tensorboard --logdir logs/train

In [None]:
# Run the training loop for 30 epochs on the combined spectrogram/label dataset.
train(discriminator, patch_disc, gan, ds_train, epochs=30)

# Misc.

In [None]:
# Copy the TensorBoard logs and saved models back to Google Drive.
# NOTE(review): `drive` only exists when the Colab import in the setup cell
# succeeded; outside Colab this cell raises a NameError.
drive.mount('/content/drive', force_remount=True)
!cp -r ./logs/* "$path_top/logs/"
!cp -r ./models/* "$path_top/models/"