# **Trainer**

**Notas -** ver como funciona:
- L2-normalization.
- Hilbert space.

In [None]:
# -*- coding: utf-8 -*-
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
""" trainer.py """
import tensorflow as tf
from tensorflow.keras.utils import Progbar
from model.dataset import Dataset
from model.fp.melspec.melspectrogram import get_melspec_layer
from model.fp.specaug_chain.specaug_chain import get_specaug_chain_layer
from model.fp.nnfp import get_fingerprinter
from model.fp.NTxent_loss_single_gpu import NTxentLoss
from model.fp.online_triplet_loss import OnlineTripletLoss
from model.fp.lamb_optimizer import LAMB
from model.utils.experiment_helper import ExperimentHelper
from model.utils.mini_search_subroutines import mini_search_eval

## **Build Fingerprinter**
- ***m_pre*** is the log-power-Mel-spectrogram layer (S).
    - As a first step, input audio X is converted to time-frequency representation S.
    - To calculate the log Mel spectrogram:
        <br>The function needs the following variables: *fs*, *dur*, *n_fft*, *stft_hop*, *n_mels*, *f_min*, *f_max*. These are in *cfg*.
- ***m_specaug*** is spec-augmentation layer.
    - Cutout and Spec-augment are applied after extracting log-power Mel-spectrogram features, such that {$s^{org}, s^{rep}$}. Unlike other augmentations, we uniformly apply a **batch-wise** random mask to all examples in a batch including $s^{org}$. The size and position of each rectangle/vertical/horizontal mask is random in the range [1/10, 1/2] the length of each time/frequency axis.
- ***m_fp*** is fingerprinter *g(f(.))*
    - It is then fed into convolutional encoder $f(.)$. Finally, L2-normalization is applied to its output through a linear projection layer $g(.)$. Thus, we employ $g ◦ f : S → \mathcal{Z}^d$ as a segment-wise encoder that transforms S into d-dimensional fingerprint embedding space $\mathcal{Z}^d$. The d-dimensional output space $\mathcal{Z}^d$ always belongs to *Hilbert* space $L^2(\mathbb{R}^d)$: the cosine similarity of a pair unit such as cos($z_a$, $z_b$) becomes inner-product $z_a^Tz_b$.
    <br>The $g◦f(.)$ described here can be interpreted as a reorganization of the previous audio fingerprinting networks [7] into the common form employed in self-supervised learning (SSL) [14–17]. However, our approach differs from the typical SSL that throws $g(.)$ away before fine-tuning for the target task: we maintain the self-supervised $g(.)$ up to the final target task.
    
Note: References [7] and [14-17] are in the article.

In [None]:
def build_fp(cfg):
    """Create the three model stages used for training.

    Args:
        cfg: configuration object forwarded unchanged to each factory.

    Returns:
        Tuple ``(m_pre, m_specaug, m_fp)``: the log-power-Mel-spectrogram
        layer, the spec-augmentation layer and the fingerprinter g(f(.)),
        in that order.
    """
    # Stage 1: log-power-Mel-spectrogram front end, S.
    melspec_layer = get_melspec_layer(cfg, trainable=False)

    # Stage 2: spec-augmentation chain. It can be detached by setting its
    # `bypass` attribute, so make sure it starts out active.
    specaug_layer = get_specaug_chain_layer(cfg, trainable=False)
    assert specaug_layer.bypass == False

    # Stage 3: fingerprinter g(f(.)), created with trainable=False.
    fingerprinter = get_fingerprinter(cfg, trainable=False)

    return melspec_layer, specaug_layer, fingerprinter

## **Train step**
A mini-batch with the size of N consists of N/2 pairs of $\{s^{org}, s^{rep}\}$. $s^{org}$ is the time-frequency representation of sampled audio and $s^{rep}$ is the augmented replica of $s^{org}$, where $s^{rep} = M_\alpha(s^{org})$. $M_{\alpha}$ is an ordered augmentation chain that consists of multiple augmentors with the random parameter set $\alpha$ for each replica. In this configuration, the indices of original examples are always odd, and those of replicas are even. Therefore, the batch-wise output of $g ◦ f(s)$ can be $\{z^{org}_{2k-1}, z^{rep}_{2k}\}^{N/2}_{k=1}$.
Anchors are the reference points (the originals) against which the other points inside the same mini-batch are compared; their number is Batch_size/2.
<br><br>train_step tem como argumentos: X (dados para treino, um tuplo (Xa, Xp) com os originais e as réplicas aumentadas), m_pre (a camada que calcula o espetrograma log-Mel), m_specaug (a cadeia de aumento espetral dos dados), m_fp (o fingerprinter $g(f(.))$), loss_obj (objeto de loss que calcula a loss entre os embeddings dos exemplos originais e replicados) e helper (contém o otimizador e atualiza a loss média de treino).

In [None]:
@tf.function
def train_step(X, m_pre, m_specaug, m_fp, loss_obj, helper):
    """Run one optimisation step on a batch of originals and replicas.

    Args:
        X: pair (Xa, Xp). Xa holds the anchors/originals [xa_0, xa_1, ...]
            and Xp the augmented replicas, with xp_n = rand_aug(xa_n).
        m_pre: callable applied to the concatenated batch before m_specaug.
        m_specaug: callable applied to the output of m_pre.
        m_fp: fingerprinter model; its `trainable` flag is set to True here.
        loss_obj: object exposing `compute_loss(emb_org, emb_rep)`.
        helper: object exposing `optimizer` and `update_tr_loss`.

    Returns:
        (avg_loss, sim_mtx): the value returned by `helper.update_tr_loss`
        (average within the current epoch) and the second value returned
        by `loss_obj.compute_loss`.
    """
    num_org = len(X[0])  # number of anchors; the replicas follow them
    batch = tf.concat(X, axis=0)
    spec = m_pre(batch)
    feat = m_specaug(spec)  # (nA+nP, F, T, 1)

    m_fp.trainable = True
    with tf.GradientTape() as tape:
        emb = m_fp(feat)  # (BSZ, Dim)
        emb_org = emb[:num_org, :]
        emb_rep = emb[num_org:, :]
        loss, sim_mtx, _ = loss_obj.compute_loss(emb_org, emb_rep)

    weights = m_fp.trainable_variables
    grads = tape.gradient(loss, weights)
    helper.optimizer.apply_gradients(zip(grads, weights))

    avg_loss = helper.update_tr_loss(loss)  # To tensorboard.
    return avg_loss, sim_mtx

Validação

In [None]:
@tf.function
def val_step(X, m_pre, m_fp, loss_obj, helper):
    """Compute the validation loss for one batch (no weight update).

    Unlike the train step, no spec-augmentation layer is applied: the
    output of `m_pre` is fed straight into `m_fp`.

    Args:
        X: pair whose first element holds the anchors and whose second
            holds the replicas.
        m_pre: callable applied to the concatenated batch.
        m_fp: fingerprinter model; its `trainable` flag is set to False.
        loss_obj: object exposing `compute_loss(emb_org, emb_rep)`.
        helper: object exposing `update_val_loss`.

    Returns:
        (avg_loss, sim_mtx): the value returned by
        `helper.update_val_loss` and the second value returned by
        `loss_obj.compute_loss`.
    """
    num_org = len(X[0])
    batch = tf.concat(X, axis=0)
    feat = m_pre(batch)  # (nA+nP, F, T, 1)

    m_fp.trainable = False
    emb = m_fp(feat)  # (BSZ, Dim)
    emb_org = emb[:num_org, :]
    emb_rep = emb[num_org:, :]
    loss, sim_mtx, _ = loss_obj.compute_loss(emb_org, emb_rep)

    avg_loss = helper.update_val_loss(loss)  # To tensorboard.
    return avg_loss, sim_mtx

### Teste

In [None]:
@tf.function
def test_step(X, m_pre, m_fp):
    """Embed one batch at three stages, for mini-search-validation.

    Args:
        X: sequence of tensors concatenated along axis 0.
        m_pre: callable applied to the concatenated batch.
        m_fp: fingerprinter exposing `front_conv` and `div_enc`; its
            `trainable` flag is set to False.

    Returns:
        Tuple ``(f, L2(f), L2(g(f)))``: the raw `front_conv` output, the
        same output L2-normalised along axis 1, and the L2-normalised
        output of `div_enc` applied to the raw `front_conv` output.
    """
    batch = tf.concat(X, axis=0)
    feat = m_pre(batch)  # (nA+nP, F, T, 1)
    m_fp.trainable = False

    front_out = m_fp.front_conv(feat)  # f(.)
    front_out_l2 = tf.math.l2_normalize(front_out, axis=1)  # L2(f(.))

    # g(.) consumes the un-normalised front-end output.
    proj_out = m_fp.div_enc(front_out)
    proj_out_l2 = tf.math.l2_normalize(proj_out, axis=1)  # L2(g(f(.)))

    return front_out, front_out_l2, proj_out_l2

### Mini search validation
Exemplo:
<br> ======= mini-search-validation: argmin f =======
<br> Scope:	   1   	  3   	  5   	  9   	  11  	  19  	 
<br> T1acc:	 0.07	0.07	0.07	0.07	0.07	0.07	
<br> mRank:	 749.50	748.50	747.50	745.50	744.50	740.50	
<br> ======= mini-search-validation: argmin L2(f) =======
<br> Scope:	   1   	  3   	  5   	  9   	  11  	  19  	 
<br> T1acc:	 0.07	0.07	0.07	0.07	0.07	0.07	
<br> mRank:	 749.50	748.50	747.50	745.50	744.50	740.50	
<br> ======= mini-search-validation: argmin g(f) =======
<br> Scope:	   1   	  3   	  5   	  9   	  11  	  19  	 
<br> T1acc:	 0.07	0.07	0.07	0.07	0.07	0.07	
<br> mRank:	 749.50	748.50	747.50	745.50	744.50	740.50

In [None]:
def mini_search_validation(ds, m_pre, m_fp, mode='argmin',
                           scopes=None, max_n_samples=3000):
    """Run a small in-memory search test on embeddings of `ds`.

    For each batch, the first half of the embeddings (the anchors) goes
    into a mini database and the second half (the replicas) is used as
    queries. The search is evaluated separately for three embedding
    variants: 'f', 'L2(f)' and 'g(f)'.

    Args:
        ds: indexable dataset exposing `bsz` and `len()`; `ds[i]` is
            passed to `test_step`.
        m_pre: pre-processing callable forwarded to `test_step`.
        m_fp: fingerprinter exposing `front_hidden_ch` and `emb_sz`; its
            `trainable` flag is set to False.
        mode: search mode forwarded to `mini_search_eval`.
        scopes: list of scopes forwarded to `mini_search_eval`. Defaults
            to ``[1, 3, 5, 9, 11, 19]``; a fresh list is created per call.
        max_n_samples: upper bound on the number of samples embedded.

    Returns:
        Tuple ``(accs_by_scope, scopes, key_strs)`` where `accs_by_scope`
        maps each key in `key_strs` to the first value returned by
        `mini_search_eval`.
    """
    # A list default in the signature would be shared between calls and is
    # handed back to the caller, so build it per call instead.
    if scopes is None:
        scopes = [1, 3, 5, 9, 11, 19]

    # Construct mini-DB
    key_strs = ['f', 'L2(f)', 'g(f)']
    m_fp.trainable = False
    dim = {
        'f': m_fp.front_hidden_ch[-1],
        'L2(f)': m_fp.front_hidden_ch[-1],
        'g(f)': m_fp.emb_sz,
    }
    bsz = ds.bsz
    n_anchor = bsz // 2
    n_iter = min(len(ds), max_n_samples // bsz)

    # Collect per-batch slices and concatenate once at the end. Each list
    # starts with an empty (0, dim) tensor so the result is still a valid
    # tensor when n_iter == 0.
    db_parts = {k: [tf.zeros((0, dim[k]))] for k in key_strs}
    query_parts = {k: [tf.zeros((0, dim[k]))] for k in key_strs}
    for i in range(n_iter):
        # test_step returns (f, L2(f), g(f)) in the same order as key_strs.
        embs = test_step(ds[i], m_pre, m_fp)
        for k, emb in zip(key_strs, embs):
            db_parts[k].append(emb[:n_anchor, :])
            query_parts[k].append(emb[n_anchor:, :])
    db = {k: tf.concat(db_parts[k], axis=0) for k in key_strs}
    query = {k: tf.concat(query_parts[k], axis=0) for k in key_strs}

    # Search test
    accs_by_scope = dict()
    for k in key_strs:
        tf.print(f'======= mini-search-validation: \033[31m{mode} \033[33m{k} \033[0m=======' + '\033[0m')
        query[k] = tf.expand_dims(query[k], axis=1) # (nQ, d) --> (nQ, 1, d)
        accs_by_scope[k], _ = mini_search_eval(
            query[k], db[k], scopes, mode, display=True)
    return accs_by_scope, scopes, key_strs

## **Trainer**
First, it loads the dataset, where *cfg* is a dictionary that contains the configuration.
After that, it builds the model by calling the *build_fp* function, which returns *m_pre*, *m_specaug* and *m_fp*.

<br> Learning rate: explicação.
<br> Função de otimização: Adam
- ```tf.keras.optimizers.Adam```: Adam optimization is a stochastic gradient descent method that is based on adaptive estimation of first-order and second-order moments.
- They trained the model using LAMB optimizer, which performed 2 pp better than Adam with the 3 s query sequence for batch size N ≥ 320. In practice, Adam worked better only for N ≤ 240. The learning rate had an initial value of 1e-4·N/640 with cosine decay without warmup or restarts, then it reached a minimum value of 1e-7 in 100 epochs.

<br> Função de Loss: NTXENT
- $l(i,j) = -\log \frac{\exp(a_{i,j}/\tau)}{\sum_{k=1}^{N} \mathbb{1}(k\neq i)\exp(a_{i,k}/\tau)}$
    - Where $a_{i,j} = z_i^Tz_j$ is the pairwise inner-product similarity, $\mathbb{1}(.) \in \{0,1\}$ is an indicator function that returns 1 iff $(.)$ is true, and $\tau > 0$ denotes the temperature parameter for softmax. The Equation is employed to replace MIPS from the property: computing the top-k (k=1 in our setup) predictions in the softmax function is equivalent to the MIPS.
- De acordo com https://towardsdatascience.com/nt-xent-normalized-temperature-scaled-cross-entropy-loss-explained-and-implemented-in-pytorch-cc081f69848.
    - At a high level, the contrastive learning model is fed 2N images, originating from N underlying images. Each of the N underlying images is augmented using a random set of image augmentations to produce 2 augmented images.
    - For SoftMax: A low temperature increases the variance in the output distribution and makes the maximum value stand out over the other values. (Our value $\tau = 0.05$).
    - Pares positivos vs pares negativos:
        - Pares positivos derivam da mesma imagem no processo de Augmentation.
        - Pares negativos não derivam da mesma imagem.
    
<br> 
<br> 

In [None]:
def _build_lr_schedule(cfg, total_nsteps):
    """Return the learning rate (float or schedule) selected by cfg."""
    kind = cfg['TRAIN']['LR_SCHEDULE'].upper()
    base_lr = float(cfg['TRAIN']['LR'])
    if kind == 'COS':
        return tf.keras.experimental.CosineDecay(
            initial_learning_rate=base_lr,
            decay_steps=total_nsteps,
            alpha=1e-06)
    if kind == 'COS-RESTART':
        # CosineDecayRestarts has no `num_periods` argument (that one
        # belongs to the linear-cosine schedules); passing it raises
        # TypeError, so only the supported arguments are given here.
        return tf.keras.experimental.CosineDecayRestarts(
            initial_learning_rate=base_lr,
            first_decay_steps=int(total_nsteps * 0.1),
            alpha=2e-06)
    # Any other value: constant learning rate.
    return base_lr


def _build_optimizer(cfg, lr_schedule):
    """Return the optimizer named by cfg['TRAIN']['OPTIMIZER']."""
    name = cfg['TRAIN']['OPTIMIZER'].upper()
    if name == 'LAMB':
        return LAMB(learning_rate=lr_schedule)
    if name == 'ADAM':
        return tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    raise NotImplementedError(cfg['TRAIN']['OPTIMIZER'])


def _build_loss_objects(cfg):
    """Return (train_loss_obj, val_loss_obj) for cfg['LOSS']['LOSS_MODE']."""
    loss_mode = cfg['LOSS']['LOSS_MODE'].upper()
    bsz_cfg = cfg['BSZ']
    if loss_mode == 'NTXENT': # Default
        loss_obj_train = NTxentLoss(
            n_org=bsz_cfg['TR_N_ANCHOR'],
            n_rep=bsz_cfg['TR_BATCH_SZ'] - bsz_cfg['TR_N_ANCHOR'],
            tau=cfg['LOSS']['TAU'])
        loss_obj_val = NTxentLoss(
            n_org=bsz_cfg['VAL_N_ANCHOR'],
            n_rep=bsz_cfg['VAL_BATCH_SZ'] - bsz_cfg['VAL_N_ANCHOR'],
            tau=cfg['LOSS']['TAU'])
    elif loss_mode == 'ONLINE-TRIPLET': # Now-playing
        loss_obj_train = OnlineTripletLoss(
            bsz=bsz_cfg['TR_BATCH_SZ'],
            n_anchor=bsz_cfg['TR_N_ANCHOR'],
            mode='semi-hard',
            margin=cfg['LOSS']['MARGIN'])
        loss_obj_val = OnlineTripletLoss(
            bsz=bsz_cfg['VAL_BATCH_SZ'],
            n_anchor=bsz_cfg['VAL_N_ANCHOR'],
            mode='all', # use 'all' mode for validation
            margin=0.)
    else:
        raise NotImplementedError(cfg['LOSS']['LOSS_MODE'])
    return loss_obj_train, loss_obj_val


def _run_enqueued(ds, cfg, shuffle, step_fn):
    """Call `step_fn` on every batch of `ds`, fetched via a multiprocess enqueuer.

    Returns the value of the last `step_fn` call, or None when `ds` has no
    batches. The enqueuer is stopped even if `step_fn` raises.
    """
    enq = tf.keras.utils.OrderedEnqueuer(
        ds, use_multiprocessing=True, shuffle=shuffle)
    enq.start(workers=cfg['DEVICE']['CPU_N_WORKERS'],
              max_queue_size=cfg['DEVICE']['CPU_MAX_QUEUE'])
    last_result = None
    try:
        for _ in range(len(ds)):
            X = next(enq.get()) # X: Tuple(Xa, Xp)
            last_result = step_fn(X)
    finally:
        # Always release the worker processes, also on failure.
        enq.stop()
    return last_result


def trainer(cfg, checkpoint_name):
    """Train the fingerprinter according to `cfg`.

    Builds the dataset, models, learning-rate schedule, optimizer,
    experiment helper and loss objects, then for every epoch from
    `helper.epoch` to cfg['TRAIN']['MAX_EPOCH'] (inclusive) runs a
    training pass, a validation pass, the helper's end-of-epoch update
    (with checkpoint saving requested) and, optionally, a
    mini-search-validation.

    Args:
        cfg: nested configuration mapping; the keys read here are under
            'TRAIN', 'LOSS', 'BSZ', 'DEVICE' and 'DATA_SEL'.
        checkpoint_name: name forwarded to `ExperimentHelper`.

    Raises:
        NotImplementedError: if the configured optimizer or loss mode is
            not one of the supported values.
    """
    # Dataloader
    dataset = Dataset(cfg)

    # Build models.
    m_pre, m_specaug, m_fp = build_fp(cfg)

    # Learning schedule and optimizer
    total_nsteps = cfg['TRAIN']['MAX_EPOCH'] * len(dataset.get_train_ds())
    lr_schedule = _build_lr_schedule(cfg, total_nsteps)
    opt = _build_optimizer(cfg, lr_schedule)

    # Experiment helper: see utils.experiment_helper.py for details.
    helper = ExperimentHelper(
        checkpoint_name=checkpoint_name,
        optimizer=opt,
        model_to_checkpoint=m_fp,
        cfg=cfg)

    # Loss objects
    loss_obj_train, loss_obj_val = _build_loss_objects(cfg)

    # Training loop
    ep_start = helper.epoch
    ep_max = cfg['TRAIN']['MAX_EPOCH']
    for ep in range(ep_start, ep_max + 1):
        tf.print(f'EPOCH: {ep}/{ep_max}')

        # Train (batches are prepared in parallel worker processes).
        train_ds = dataset.get_train_ds(cfg['DATA_SEL']['REDUCE_ITEMS_P'])
        progbar = Progbar(len(train_ds))

        def _train_on_batch(X, progbar=progbar):
            avg_loss, batch_sim_mtx = train_step(
                X, m_pre, m_specaug, m_fp, loss_obj_train, helper)
            progbar.add(1, values=[("tr loss", avg_loss)])
            return batch_sim_mtx

        # sim_mtx is None when the dataset yielded no batches.
        sim_mtx = _run_enqueued(train_ds, cfg, train_ds.shuffle,
                                _train_on_batch)
        if cfg['TRAIN']['SAVE_IMG'] and (sim_mtx is not None):
            helper.write_image_tensorboard('tr_sim_mtx', sim_mtx.numpy())

        # Validate
        val_ds = dataset.get_val_ds(max_song=250) # max 500
        sim_mtx = _run_enqueued(
            val_ds, cfg, False,
            lambda X: val_step(X, m_pre, m_fp, loss_obj_val, helper)[1])
        if cfg['TRAIN']['SAVE_IMG'] and (sim_mtx is not None):
            helper.write_image_tensorboard('val_sim_mtx', sim_mtx.numpy())

        # On epoch end
        tf.print('tr_loss:{:.4f}, val_loss:{:.4f}'.format(
            helper._tr_loss.result(), helper._val_loss.result()))
        helper.update_on_epoch_end(save_checkpoint_now=True)

        # Mini-search-validation (optional)
        if cfg['TRAIN']['MINI_TEST_IN_TRAIN']:
            accs_by_scope, scopes, key_strs = mini_search_validation(
                val_ds, m_pre, m_fp)
            for k in key_strs:
                helper.update_minitest_acc(accs_by_scope[k], scopes, k)
