In [1]:
import csv
import os
import pathlib
import sys
from typing import Optional

import click
import librosa
import tensorflow as tf
import tensorflow.keras as K
import yaml
from tensorflow.keras.utils import Progbar

In [2]:

# Ask TensorFlow's C++ runtime to hide INFO-level log lines.
# NOTE(review): tensorflow is already imported in the first cell, so this may
# not silence messages emitted at import time -- move it above `import tensorflow`.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '1'})

from model.dataset import Dataset
from model.fp.melspec.melspectrogram import get_melspec_layer
from model.fp.specaug_chain.specaug_chain import get_specaug_chain_layer
from model.fp.nnfp import get_fingerprinter
from model.fp.NTxent_loss_single_gpu import NTxentLoss
from model.fp.online_triplet_loss import OnlineTripletLoss
from model.fp.lamb_optimizer import LAMB
from model.utils.experiment_helper import ExperimentHelper
from model.utils.mini_search_subroutines import mini_search_eval

# Function definitions

In [4]:
def build_fp(cfg):
    """Construct the inference-time fingerprinting components.

    Args:
        cfg: parsed configuration, forwarded unchanged to each layer factory.

    Returns:
        Tuple ``(m_pre, m_specaug, m_fp)``:
          * m_pre     -- log-power Mel-spectrogram layer, S.
          * m_specaug -- spec-augmentation chain; asserted not to be bypassed.
          * m_fp      -- fingerprinter g(f(.)).
        Every component is requested with ``trainable=False``.
    """
    melspec_layer = get_melspec_layer(cfg, trainable=False)

    specaug_layer = get_specaug_chain_layer(cfg, trainable=False)
    # The chain can be detached via its `bypass` flag; it must be active here.
    # Equality (not `not ...`) is kept deliberately to match original semantics.
    assert specaug_layer.bypass == False  # noqa: E712

    fingerprinter = get_fingerprinter(cfg, trainable=False)
    return melspec_layer, specaug_layer, fingerprinter

def load_config(config_fname, config_dir='./config'):
    """Load a YAML configuration file by name.

    Args:
        config_fname: config name without directory or ``.yaml`` extension;
            e.g. ``'default'`` resolves to ``<config_dir>/default.yaml``.
        config_dir: directory containing the config files. Defaults to
            ``'./config'`` (relative to the current working directory), the
            location that was previously hard-coded.

    Returns:
        The document parsed with ``yaml.safe_load``.

    Raises:
        SystemExit: if the configuration file does not exist (or is not a
            regular file).
    """
    config_filepath = os.path.join(config_dir, config_fname + '.yaml')
    # isfile (not exists) so a directory named "*.yaml" is reported as missing
    # instead of failing later inside open().
    if not os.path.isfile(config_filepath):
        sys.exit(f'cli: ERROR! Configuration file {config_filepath} is missing!!')
    print(f'cli: Configuration from {config_filepath}')

    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_filepath, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    return cfg

@tf.function
def test_step(X, m_pre, m_fp):
    """Embed one batch for mini-search validation.

    Args:
        X: sequence of tensors, concatenated along axis 0 into one batch.
        m_pre: spectrogram front end applied to the concatenated batch.
        m_fp: fingerprinter exposing ``front_conv`` (f) and ``div_enc`` (g).

    Returns:
        Tuple ``(f(.), L2(f(.)), L2(g(f(.))))``; both normalisations are
        taken over axis 1.
    """
    batch = tf.concat(X, axis=0)
    spec = m_pre(batch)  # original note: (nA+nP, F, T, 1) -- TODO confirm
    # NOTE(review): inside tf.function this Python attribute assignment runs
    # only while the function is being traced, not on every call.
    m_fp.trainable = False
    front = m_fp.front_conv(spec)  # original note: (BSZ, Dim)
    front_unit = tf.math.l2_normalize(front, axis=1)
    proj_unit = tf.math.l2_normalize(m_fp.div_enc(front), axis=1)
    return front, front_unit, proj_unit

# Load latest checkpoint

In [10]:
# Directory containing the saved training checkpoints.
checkpoint_name_dir: str = "./logs/CHECKPOINT_BSZ_120"
# Optional checkpoint number; None means the latest checkpoint in checkpoint_name_dir.
checkpoint_index: Optional[int] = None
# Config name, resolved by load_config to ./config/<config>.yaml.
config: str = "default"

In [7]:
cfg = load_config(config_fname=config)  # parse the YAML config selected above


cli: Configuration from ./config/default.yaml


In [8]:
(m_pre, m_specaug, m_fp) = build_fp(cfg=cfg)  # front end, augmentation, fingerprinter

In [13]:

checkpoint = tf.train.Checkpoint(m_fp)
if checkpoint_index is None:
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_name_dir)
else:
    # Look the requested number up in the directory's checkpoint state file
    # rather than guessing the file prefix used when saving.
    ckpt_state = tf.train.get_checkpoint_state(checkpoint_name_dir)
    saved_paths = list(ckpt_state.all_model_checkpoint_paths) if ckpt_state else []
    matches = [p for p in saved_paths if p.endswith(f'-{checkpoint_index}')]
    checkpoint_path = matches[-1] if matches else None
# restore(None) loads nothing, so the model would silently keep freshly
# initialised weights; fail loudly instead.
if checkpoint_path is None:
    raise FileNotFoundError(
        f'No checkpoint (index={checkpoint_index}) found in {checkpoint_name_dir}')
# NOTE(review): m_fp is the checkpoint root here; if training saved it under a
# keyword (e.g. Checkpoint(model=...)), variables may not match -- confirm.
checkpoint.restore(checkpoint_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1106e4b50>

# Load data

In [12]:
audio_path: str = ''  # placeholder: set to the query audio file before loading it

<model.fp.nnfp.FingerPrinter at 0x1a2d68f40>

In [None]:
# NOTE(review): 22050 Hz is hard-coded; it should presumably match the sample
# rate the fingerprinter expects (see the loaded cfg) -- confirm.
SAMPLE_RATE = 22050
if not audio_path:
    raise ValueError('audio_path is empty; set it to an audio file before loading.')
audio, fs = librosa.load(audio_path, mono=True, sr=SAMPLE_RATE)

# Model prediction

In [None]:
def predict(audio, model):
    """Compute fingerprint embeddings for ``audio`` using ``model``.

    Placeholder: the previous body returned the undefined name ``emb`` and
    therefore always failed with NameError. It now raises an explicit
    NotImplementedError until the inference pipeline (e.g. spectrogram front
    end followed by the fingerprinter, as in ``test_step``) is wired in.

    Args:
        audio: waveform to fingerprint (presumably the array returned by
            ``librosa.load`` above -- TODO confirm).
        model: fingerprinting model to apply.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError('predict() is not implemented yet')

In [None]:
# Dataset helper constructed from the loaded config; get_val_ds() is called on it below.
dataset = Dataset(cfg)

In [None]:
# Passed as `max_song`; presumably caps how many songs the validation split uses.
MAX_VAL_SONGS = 250
val_ds = dataset.get_val_ds(max_song=MAX_VAL_SONGS)