In [1]:
import os
import sys
import pathlib
import click
import yaml
import numpy as np
# Must be set before TensorFlow is imported (next cell) for the C++ log filter
# to take effect. '1' filters out INFO-level C++ logs only; E-level messages
# such as the cuDNN/cuFFT/cuBLAS factory-registration lines are still printed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

In [2]:
""" trainer.py """
import tensorflow as tf
from tensorflow.keras.utils import Progbar
import tensorflow.keras as K
from model.dataset import Dataset
from model.fp.melspec.melspectrogram import get_melspec_layer
from model.fp.specaug_chain.specaug_chain import get_specaug_chain_layer
from model.fp.nnfp import get_fingerprinter
from model.fp.NTxent_loss_single_gpu import NTxentLoss
from model.fp.online_triplet_loss import OnlineTripletLoss
from model.fp.lamb_optimizer import LAMB
from model.utils.experiment_helper import ExperimentHelper
from model.utils.mini_search_subroutines import mini_search_eval

2024-04-09 18:08:25.447413: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-09 18:08:25.447453: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-09 18:08:25.448416: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
def load_config(config_fname, config_dir='./config/'):
    """Load '<config_dir>/<config_fname>.yaml' and return the parsed YAML.

    Args:
        config_fname: Config name without the '.yaml' extension
            (e.g. 'default').
        config_dir: Directory that holds the config files. Defaults to
            './config/', resolved against the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.

    Exits:
        Calls ``sys.exit`` with an error message if the file does not exist.
    """
    config_filepath = os.path.join(config_dir, config_fname + '.yaml')
    if not os.path.exists(config_filepath):
        sys.exit(f'cli: ERROR! Configuration file {config_filepath} is missing!!')
    print(f'cli: Configuration from {config_filepath}')

    # safe_load only constructs plain YAML types (no arbitrary Python objects).
    with open(config_filepath, 'r') as f:
        return yaml.safe_load(f)


def update_config(cfg, key1: str, key2: str, val):
    """Assign ``val`` to ``cfg[key1][key2]`` in place and return ``cfg``.

    The same dict object that was passed in is returned. A missing ``key1``
    raises ``KeyError``.
    """
    section = cfg[key1]
    section[key2] = val
    return cfg


def print_config(cfg):
    """Print ``cfg`` as YAML, wrapped in ANSI cyan colour codes.

    Keys are printed in their original order (``sort_keys=False``), with an
    indent of 4 and a line width of 120.
    """
    if os.name == 'nt':
        # An empty shell command is a common workaround to enable ANSI escape
        # handling in Windows consoles; on other platforms it would only
        # spawn a useless shell process.
        os.system("")
    print('\033[36m' + yaml.dump(cfg, indent=4, width=120, sort_keys=False) +
          '\033[0m')
    return

In [4]:
# --- Run parameters ---------------------------------------------------------
# Training-run directory name under <LOG_ROOT_DIR>/checkpoint/.
checkpoint_name: str = "Checks_test_generate"
# Restores ckpt-<index>; load_checkpoint picks the latest one when this is None.
checkpoint_index: int = 100
# Selects ./config/<config>.yaml.
config: str = "default"
# NOTE: passed through to get_data_source, which does not read it.
source_root_dir: str = '/mnt/dataset/public/Fingerprinting/neural-audio-fp-dataset/music/test-dummy-db-100k-full/'
# Base directory for the generated fingerprint memmaps.
output_root_dir: str = './logs/emb/'
# True -> leave the 'dummy_db' source out.
skip_dummy: bool = False

In [5]:
from model.utils.config_gpu_memory_lim import allow_gpu_memory_growth
# NOTE: a later cell defines its own generate_fingerprint, which shadows this
# import once that cell has been executed.
from model.generate import generate_fingerprint

# Parse ./config/<config>.yaml (load_config exits if the file is missing).
cfg = load_config(config)
# NOTE(review): presumably enables TF GPU memory growth, which has to happen
# before TensorFlow initialises the GPUs — confirm.
allow_gpu_memory_growth()

cli: Configuration from ./config/default.yaml


GENERATE

In [None]:
# One-shot run of the whole pipeline with the parameters above. In a
# top-to-bottom run this is the function imported from model.generate.
# Pass the real directories: an Ellipsis placeholder for output_root_dir is
# truthy and fails as soon as a path string is appended to it.
generate_fingerprint(cfg, checkpoint_name, checkpoint_index, source_root_dir,
                     output_root_dir, skip_dummy)

In [7]:
def build_fp(cfg):
    """Create the two inference-only networks used for fingerprinting.

    Returns:
        (m_pre, m_fp), where
          m_pre -- log-power-Mel-spectrogram layer S (get_melspec_layer),
          m_fp  -- fingerprinter g(f(.)) (get_fingerprinter).
        Both are requested with trainable=False, in that order.
    """
    frozen = dict(trainable=False)
    melspec_layer = get_melspec_layer(cfg, **frozen)
    fingerprinter = get_fingerprinter(cfg, **frozen)
    return melspec_layer, fingerprinter

In [8]:
def load_checkpoint(checkpoint_root_dir, checkpoint_name, checkpoint_index,
                    m_fp):
    """Restore trained fingerprinter weights into ``m_fp``.

    Args:
        checkpoint_root_dir: Directory containing one sub-directory per
            checkpoint name.
        checkpoint_name: Sub-directory holding the 'ckpt-<index>' files.
        checkpoint_index: Index to restore. If None, the latest checkpoint
            reported by ``tf.train.CheckpointManager`` is used.
        m_fp: Model whose variables are restored (tracked as ``model=``).

    Returns:
        The restored checkpoint index (parsed from the file name when
        ``checkpoint_index`` is None, otherwise the value passed in).

    Raises:
        FileNotFoundError: ``checkpoint_index`` is None and no checkpoint
            exists in the directory.
    """
    checkpoint = tf.train.Checkpoint(model=m_fp)
    # os.path.join avoids the '//' that plain concatenation produced when
    # checkpoint_root_dir already ends with a separator.
    checkpoint_dir = os.path.join(checkpoint_root_dir, checkpoint_name, '')
    c_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir,
                                           max_to_keep=None)

    if checkpoint_index is None:
        tf.print("\x1b[1;32mArgument 'checkpoint_index' was not specified.\x1b[0m")
        tf.print('\x1b[1;32mSearching for the latest checkpoint...\x1b[0m')
        latest_checkpoint = c_manager.latest_checkpoint
        if not latest_checkpoint:
            raise FileNotFoundError(f'Cannot find checkpoint in {checkpoint_dir}')
        # The path ends in 'ckpt-<index>'; parse the text after the last 'ckpt-'.
        checkpoint_index = int(latest_checkpoint.rsplit('ckpt-', 1)[-1])
        checkpoint_fpath = latest_checkpoint
    else:
        checkpoint_fpath = os.path.join(checkpoint_dir, f'ckpt-{checkpoint_index}')

    # For an explicit index, missing files are reported by TF itself.
    status = checkpoint.restore(checkpoint_fpath)
    # Only the model is tracked here; expect_partial() silences warnings
    # about checkpointed objects (e.g. optimizer state) that are not restored.
    status.expect_partial()
    tf.print(f'---Restored from {checkpoint_fpath}---')
    return checkpoint_index

In [9]:
# Build the frozen models and restore the requested checkpoint into m_fp.
m_pre, m_fp = build_fp(cfg)
# os.path.join works whether or not LOG_ROOT_DIR ends with a separator
# ('./logs/' and './logs' both give './logs/checkpoint/').
checkpoint_root_dir = os.path.join(cfg['DIR']['LOG_ROOT_DIR'], 'checkpoint', '')
checkpoint_index = load_checkpoint(checkpoint_root_dir, checkpoint_name,
                                   checkpoint_index, m_fp)

---Restored from ./logs/checkpoint//Checks_test_generate/ckpt-100---


In [10]:
def get_data_source(cfg, source_root_dir, skip_dummy):
    """Collect the test-time data sources to fingerprint.

    Args:
        cfg: Configuration passed to ``Dataset``.
        source_root_dir: Accepted but not read by this function.
        skip_dummy: If True, the 'dummy_db' source is left out.

    Returns:
        dict mapping a source name to the object returned by the matching
        ``Dataset`` getter, in insertion order: 'dummy_db' (unless skipped),
        'query', 'db'.

    Raises:
        ValueError: ``dataset.datasel_test_query_db`` is neither
            'unseen_icassp' nor 'unseen_syn'.
    """
    dataset = Dataset(cfg)
    sources = {}

    if not skip_dummy:
        sources['dummy_db'] = dataset.get_test_dummy_db_ds()
    else:
        tf.print("Excluding \033[33m'dummy_db'\033[0m from source.")

    supported_query_db = ('unseen_icassp', 'unseen_syn')
    if dataset.datasel_test_query_db not in supported_query_db:
        raise ValueError(dataset.datasel_test_query_db)
    sources['query'], sources['db'] = dataset.get_test_query_db_ds()

    tf.print(f'\x1b[1;32mData source: {sources.keys()}\x1b[0m',
             f'{dataset.datasel_test_query_db}')
    return sources

In [11]:
def prevent_overwrite(key, target_path):
    """Ask before overwriting an existing 'dummy_db' fingerprint file.

    Only ``key == 'dummy_db'`` is guarded. If ``target_path`` exists, the user
    is prompted. Any answer other than 'y'/'yes' (case-insensitive, surrounding
    whitespace ignored) terminates via ``sys.exit()``.
    """
    # Short-circuit so the filesystem is only checked for the guarded key.
    if key != 'dummy_db' or not os.path.exists(target_path):
        return
    answer = input(f'{target_path} exists. Will you overwrite (y/N)?')
    if answer.strip().lower() not in ('y', 'yes'):
        sys.exit()

In [12]:
# Stand-alone Dataset instance, used only to inspect its attributes in the next
# cell (get_data_source constructs its own Dataset).
dataset = Dataset(cfg)

In [17]:
# Display the Dataset's dummy-DB selection (shown as the cell output).
dataset.datasel_test_dummy_db

'100k_full_icassp'

In [19]:
# Get data source: ds = {'key1': <Dataset>, 'key2': <Dataset>, ...}
# Pass the configured source_root_dir rather than an Ellipsis placeholder.
ds = get_data_source(cfg, source_root_dir, skip_dummy)

[1;32mData source: dict_keys(['dummy_db', 'query', 'db'])[0m unseen_icassp


In [None]:
# Display the dict of data sources (source name -> dataset object).
ds

In [None]:
# Make output directory: <base>/<checkpoint_name>/<checkpoint_index>/, where
# base is output_root_dir, or cfg['DIR']['OUTPUT_ROOT_DIR'] when it is empty.
_output_suffix = os.path.join(str(checkpoint_name), str(checkpoint_index), '')
_output_base = output_root_dir if output_root_dir else cfg['DIR']['OUTPUT_ROOT_DIR']
# This cell rebinds output_root_dir. The endswith() guard keeps a re-run of the
# cell from appending the checkpoint suffix a second time.
if _output_base.endswith(os.sep + _output_suffix):
    output_root_dir = _output_base
else:
    output_root_dir = os.path.join(_output_base, _output_suffix)
os.makedirs(output_root_dir, exist_ok=True)
if not skip_dummy:
    prevent_overwrite('dummy_db', os.path.join(output_root_dir, 'dummy_db.mm'))

In [None]:
def _resolve_output_dir(cfg, output_root_dir, checkpoint_name, checkpoint_index):
    """Return '<base>/<checkpoint_name>/<checkpoint_index>/' (trailing sep).

    base is ``output_root_dir`` if truthy, else cfg['DIR']['OUTPUT_ROOT_DIR'].
    """
    base_dir = output_root_dir if output_root_dir else cfg['DIR']['OUTPUT_ROOT_DIR']
    # os.path.join avoids the '//' produced by plain string concatenation.
    return os.path.join(base_dir, str(checkpoint_name), str(checkpoint_index), '')


def _write_fingerprints(cfg, key, source, m_pre, m_fp, output_root_dir):
    """Embed every batch of ``source`` and store the result on disk.

    Writes '<output_root_dir>/<key>.mm', a float32 memmap of shape
    (source.n_samples, cfg['MODEL']['EMB_SZ']), plus '<key>_shape.npy' holding
    that shape. Returns the number of rows in the memmap.

    Why use "memmap"?
      • The huge uncompressed embedding vectors must be kept until a
        compressed DB is built with IVF-PQ (FAISS). A memmap is backed by a
        file, so the whole array never has to be held in RAM at once.
      • Faiss-GPU cannot reconstruct vectors from a compressed DB (index).
        eval/eval_faiss.py needs the uncompressed vectors to calculate the
        sequence-level matching score, so this memmap is reused there.
      Reference: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
    """
    bsz = int(cfg['BSZ']['TS_BATCH_SZ'])  # Do not use source.bsz here.
    n_items = source.n_samples  # sample count, not len(source) * bsz
    dim = cfg['MODEL']['EMB_SZ']
    assert n_items > 0, f"'{key}' source has no samples"

    # Create memmap, and save shapes
    arr_shape = (n_items, dim)
    arr = np.memmap(os.path.join(output_root_dir, f'{key}.mm'),
                    dtype='float32',
                    mode='w+',
                    shape=arr_shape)
    np.save(os.path.join(output_root_dir, f'{key}_shape.npy'), arr_shape)

    tf.print(f"=== Generating fingerprint from \x1b[1;32m'{key}'\x1b[0m "
             f"bsz={bsz}, {n_items} items, d={dim} ===")
    n_batches = len(source)
    progbar = Progbar(n_batches)

    # Parallel preprocessing: worker processes feed an ordered queue.
    enq = tf.keras.utils.OrderedEnqueuer(source,
                                         use_multiprocessing=True,
                                         shuffle=False)
    enq.start(workers=cfg['DEVICE']['CPU_N_WORKERS'],
              max_queue_size=cfg['DEVICE']['CPU_MAX_QUEUE'])
    try:
        batches = enq.get()  # one generator for the whole pass
        for i in range(n_batches):
            progbar.update(i)
            X, _ = next(batches)
            # NOTE(review): `test_step` is not defined in any visible cell, and
            # only generate_fingerprint is imported from model.generate — it
            # must be provided elsewhere; confirm.
            emb = test_step(X, m_pre, m_fp)
            arr[i * bsz:(i + 1) * bsz, :] = emb.numpy()  # Writing on disk.
        progbar.update(n_batches, finalize=True)
    finally:
        enq.stop()  # shut the workers down even if a batch fails
        arr.flush()

    tf.print(f'=== Succesfully stored {arr_shape[0]} fingerprint to {output_root_dir} ===')
    n_stored = len(arr)
    del arr  # Close memmap
    return n_stored


def generate_fingerprint(cfg,
                         checkpoint_name,
                         checkpoint_index,
                         source_root_dir,
                         output_root_dir,
                         skip_dummy):
    """Generate fingerprints for every data source and store them as memmaps.

    Args:
      cfg: Parsed configuration dict.
      checkpoint_name: Checkpoint sub-directory under
        '<LOG_ROOT_DIR>/checkpoint/'.
      checkpoint_index: 'ckpt-<index>' to restore, or None for the latest.
      source_root_dir: Forwarded to get_data_source.
      output_root_dir: Base output directory. cfg['DIR']['OUTPUT_ROOT_DIR'] is
        used instead when this is empty/None.
      skip_dummy: If True, the 'dummy_db' source is not generated.

    After run, the output (generated fingerprints) directory will be:
      .
      └──logs
         └── emb
             └── CHECKPOINT_NAME
                 └── CHECKPOINT_INDEX
                     ├── db.mm
                     ├── db_shape.npy
                     ├── dummy_db.mm
                     ├── dummy_db_shape.npy
                     ├── query.mm
                     └── query_shape.npy
    """
    # Build and load checkpoint
    m_pre, m_fp = build_fp(cfg)
    checkpoint_root_dir = os.path.join(cfg['DIR']['LOG_ROOT_DIR'], 'checkpoint', '')
    checkpoint_index = load_checkpoint(checkpoint_root_dir, checkpoint_name,
                                       checkpoint_index, m_fp)

    # Get data source: ds = {'key1': <Dataset>, 'key2': <Dataset>, ...}
    ds = get_data_source(cfg, source_root_dir, skip_dummy)

    # Make output directory
    output_root_dir = _resolve_output_dir(cfg, output_root_dir,
                                          checkpoint_name, checkpoint_index)
    os.makedirs(output_root_dir, exist_ok=True)
    if not skip_dummy:
        prevent_overwrite('dummy_db', os.path.join(output_root_dir, 'dummy_db.mm'))

    # Generate
    sz_check = {}  # rows stored per key, for the size-mismatch warning
    for key, source in ds.items():
        sz_check[key] = _write_fingerprints(cfg, key, source, m_pre, m_fp,
                                            output_root_dir)

    if ('custom_source' not in ds
            and sz_check.get('db') != sz_check.get('query')):
        print("\033[93mWarning: 'db' and 'query' size does not match. This can cause a problem in evaluataion stage.\033[0m")
    return
