## Import packages

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import sys

sys.path.insert(1, "vanilla-transformer")

In [None]:
from embd import PositionalEmbedding
from encoder import EncoderLayer_postLN, EncoderLayer_preLN

## Build model

In [None]:
BATCH_SIZE = 32


def model():
    """Build the encoder-based classifier used throughout this notebook.

    The network takes inputs of shape (127, 15) with a fixed batch dimension
    of ``BATCH_SIZE``, applies a positional embedding of width 32, two
    post-LayerNorm encoder layers, a global average pooling and a 5-way
    softmax head.

    Returns:
        An uncompiled ``keras.Model`` mapping the input tensor to the softmax
        output.
    """
    inputs = keras.layers.Input((127, 15), batch_size=BATCH_SIZE)

    # Layers are instantiated in the same order in which they are applied.
    # NOTE(review): num_heads (128, 64) exceeds d_model (32); whether that is
    # valid depends on EncoderLayer_postLN, which is defined elsewhere.
    # NOTE(review): with data_format='channels_first' the pooling averages
    # over the last axis of its input - confirm that is the intended axis.
    stack = (
        PositionalEmbedding(32),
        EncoderLayer_postLN(d_model=32, num_heads=128, dff=64),
        EncoderLayer_postLN(d_model=32, num_heads=64, dff=32),
        keras.layers.GlobalAveragePooling1D(data_format='channels_first'),
        keras.layers.Dense(5, activation='softmax'),
    )

    outputs = inputs
    for layer in stack:
        outputs = layer(outputs)

    return keras.Model(inputs, outputs)

In [None]:
# Instantiate the network and print its layer summary.
# NOTE: this rebinds the name `model` from the factory function to the built
# keras.Model, so the factory is no longer reachable under that name afterwards.
model = model()
model.summary()

In [None]:
# Smoke test: run one random batch with the declared input shape
# (BATCH_SIZE, 127, 15) through the model.
x = tf.random.normal((BATCH_SIZE, 127, 15))
model(x)

## Load data

In [None]:
# Read the feature array, its mask and the labels from the archive.
data = np.load("mfcc.npz")
X, X_mask, Y = (data[key] for key in ("X", "X_mask", "Y"))

# Row ranges of the two splits. Rows 832..871 belong to neither split.
train_rows = slice(0, 832)
test_rows = slice(872, None)

x_train, x_mask_train, y_train = X[train_rows], X_mask[train_rows], Y[train_rows]
x_test, y_test = X[test_rows], Y[test_rows]

# Training batches carry (features, label, mask); validation batches carry
# (features, label) only.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train, x_mask_train))
    .shuffle(buffer_size=10000)
    .batch(BATCH_SIZE)
)

val_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
)

# Rank of one training batch and the matching broadcast shape
# [BATCH_SIZE, 1, ..., 1] used when normalising perturbations per sample.
x = x_train[0:BATCH_SIZE]
x_rank = tf.rank(x).numpy()
x_norm_resize_shape = [BATCH_SIZE] + [1] * (int(x_rank) - 1)

## Training

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with linear warmup followed by inverse-sqrt decay.

    The rate at a given step is::

        d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    so it rises linearly while ``step < warmup_steps`` and decays as
    ``step ** -0.5`` afterwards. At ``step == 0`` the result is 0.

    Args:
        d_model: Scale of the schedule; the rate is multiplied by
            ``d_model ** -0.5``.
        warmup_steps: Number of steps over which the rate increases.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()

        # Keep the value as passed in for get_config(); __call__ uses the
        # float32 tensor form stored in self.d_model.
        self._d_model_config = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, dtype=tf.float32)
        decay = tf.math.rsqrt(step)
        warmup = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized.

        Without this override the base class raises ``NotImplementedError``
        when an optimizer or model using the schedule is serialized.
        """
        return {
            "d_model": self._d_model_config,
            "warmup_steps": self.warmup_steps,
        }

In [None]:
def perturbation_loss(x, y, from_logits=False):
    """Categorical cross-entropy between two prediction tensors.

    Args:
        x: Passed to the Keras loss as the target distribution (``y_true``).
        y: Passed to the Keras loss as the prediction (``y_pred``).
        from_logits: Forwarded to ``CategoricalCrossentropy``.

    Returns:
        The loss value produced by ``keras.losses.CategoricalCrossentropy``
        with its default reduction.
    """
    cross_entropy = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
    return cross_entropy(x, y)

# Supervised objective: integer labels against softmax probabilities.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# AdamW driven by the warmup schedule defined above (d_model=32).
learning_rate = CustomSchedule(32)
optimizer = keras.optimizers.experimental.AdamW(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

# Accuracy trackers for the training and validation passes.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

# Despite its name this is a loss object, not a stateful Keras metric.
val_loss_metric = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

In [None]:
eps = 8.     # the perturbation parameter
sig = 1e-5   # initial perturbation StdDev
zeta = 1e-6  # differentiation constant
lamd = 1     # regularization parameter


def _per_sample_l2_norm(t):
    """Return the L2 norm of each sample of `t` taken over all non-batch axes.

    Singleton dimensions are kept so the result broadcasts against `t`.
    """
    non_batch_axes = list(range(1, t.shape.rank))
    return tf.sqrt(tf.reduce_sum(tf.square(t), axis=non_batch_axes, keepdims=True))


@tf.function
def training_step(x, label, x_mask):
    """Run one optimisation step with an adversarial-perturbation regulariser.

    The step (1) draws a random direction of length ``zeta`` per sample,
    (2) takes the gradient of ``perturbation_loss`` with respect to that
    direction, (3) rescales the gradient to length ``eps`` per sample and
    multiplies it by ``x_mask``, and (4) updates the model on the supervised
    loss plus ``lamd`` times the perturbation loss divided by ``BATCH_SIZE``.

    Args:
        x: Batch of input features.
        label: Integer class labels for the batch.
        x_mask: Multiplied element-wise into the perturbation; must broadcast
            against `x`.

    Returns:
        Tuple ``(loss, train_accuracy, perturbation_loss, x_p, g)`` where
        ``x_p`` is the masked perturbation and ``g`` the raw gradient it was
        derived from.
    """
    # Random probing direction, normalised per sample and scaled to zeta.
    x_p = tf.random.normal(tf.shape(x), stddev=sig)
    x_p = tf.math.divide_no_nan(x_p, _per_sample_l2_norm(x_p))
    x_p *= zeta

    # Gradient of the perturbation loss w.r.t. the probing direction.
    with tf.GradientTape() as adversarial_tape:
        adversarial_tape.watch(x_p)
        y_p = model(x + x_p, training=True)
        y = model(x, training=True)
        l = perturbation_loss(y, y_p)
    g = adversarial_tape.gradient(l, x_p)

    # Perturbation of length eps along the gradient. divide_no_nan leaves a
    # sample unperturbed (instead of NaN) when its gradient is exactly zero.
    x_p = eps * tf.math.divide_no_nan(g, _per_sample_l2_norm(g))
    x_p *= tf.cast(x_mask, x_p.dtype)

    with tf.GradientTape() as model_tape:
        y_p = model(x + x_p, training=True)
        y = model(x, training=True)
        l = perturbation_loss(y, y_p)    # Recalculate regularization

        # A separate forward pass is kept for the supervised loss, as in the
        # original step.
        logits = model(x, training=True)
        # NOTE(review): Keras losses average over the batch by default, so the
        # extra division by BATCH_SIZE further scales the regulariser - confirm
        # this is intended.
        loss = loss_fn(label, logits) + lamd * l / BATCH_SIZE
    grads = model_tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    train_acc_metric.update_state(label, logits)
    return loss, train_acc_metric.result(), l, x_p, g

In [None]:
# tqdm is used for the epoch progress bar but is not imported by the import
# cells; import it here and fall back to a plain iterator when unavailable.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):
        return iterable

# Per-epoch history, rewritten to logs.npz after every epoch.
train_loss = []
train_metric = []
val_metric = []
val_loss = []
p_loss = []
for epoch in tqdm(range(1000)):
    print("\nStart of epoch %d" % (epoch,))

    # Start every epoch from a clean accumulator so the reported training
    # accuracy covers this epoch only (the validation metric is reset too).
    train_acc_metric.reset_state()
    for x_batch, label_batch, mask_batch in train_dataset:
        loss, train_acc, l, x_p, _ = training_step(x_batch, label_batch, mask_batch)

    # `loss` and `l` are the values of the last training batch of the epoch;
    # `train_acc` is the accuracy accumulated over the whole epoch.
    print(
        "Training loss: %.4f\nTraining metric: %.4f"
        % (float(loss), float(train_acc))
    )
    print("perturbation loss: %.4f" % float(l))

    batch_val_losses = []
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
        batch_val_losses.append(float(val_loss_metric(y_batch_val, val_logits)))
    val_acc = float(val_acc_metric.result())
    val_acc_metric.reset_state()
    # Mean of the per-batch losses rather than only the last batch's loss.
    v_loss = float(np.mean(batch_val_losses))
    print("Validation acc: %.4f" % (float(val_acc)))

    # Store plain floats so the history converts to arrays cheaply.
    train_loss.append(float(loss))
    train_metric.append(float(train_acc))
    val_metric.append(val_acc)
    val_loss.append(v_loss)
    p_loss.append(float(l))

    tl = np.array(train_loss)
    tm = np.array(train_metric)
    vm = np.array(val_metric)
    vl = np.array(val_loss)
    pl = np.array(p_loss)

    # x_p is the perturbation from the last training batch of this epoch.
    np.savez("logs.npz", train_loss=tl, train_acc=tm, val_acc=vm, p_loss=pl, val_loss=vl, x_p=x_p.numpy())