In [None]:
import os, librosa, time, pickle, random, warnings
from glob import glob
import numpy as np
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
warnings.filterwarnings('ignore')
import tensorflow_io as tfio
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import seaborn as sns

from tqdm.notebook import tqdm
import IPython.display as ipd
from IPython.core.display import display, clear_output
#%load_ext tensorboard

from sklearn.manifold import TSNE
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import *
from tensorflow.keras.applications import *
from tensorflow.keras import *

from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.cluster import KMeans
from coclust.evaluation.external import accuracy

# Sample rate in Hz; used as the rate of the MEL layer and for audio playback.
SR = 16_000
# Default frame length in seconds (see the showWAV docstring example).
FRAME = 0.2

In [None]:
def loadData(name, frame=0.2, seconds=None,
             basepath="./drive/My Drive/SpeakerRecognition/preprocessed_dataset"):
    '''
    Load a preprocessed dataset and the noise clips from pickle files.

    Dataset name prefixes:
    TIMIT : DARPA TIMIT
    LIBRI : LibriSpeech
    ASR : Bengali ASR

    Parameters
    ----------
    name : str
        Dataset prefix used in the pickle file names (see above).
    frame : float
        Frame length tag encoded in the file names.
    seconds : int or None
        Amount-of-speech tag encoded in the dataset file names.
    basepath : str
        Directory holding the pickle files. Defaults to the original Google
        Drive location, so existing calls keep working.

    Returns
    -------
    X, y, file_indexes, noise
        X and noise are float32 arrays with a trailing channel axis added;
        y and file_indexes are label-encoded to integers 0..k-1.
    '''
    file_names = [f"{name}_X_16000_{seconds}_{frame}.pkl",
                  f"{name}_y_16000_{seconds}_{frame}.pkl",
                  f"{name}_fy_16000_{seconds}_{frame}.pkl",
                  f"noise_16000_{frame}.pkl"]

    rets = []
    for file_name in file_names:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(basepath, file_name), 'rb') as f:
            rets.append(pickle.load(f))

    X = np.expand_dims(np.asarray(rets[0], dtype=np.float32), axis=-1)

    y = LabelEncoder().fit_transform(rets[1])
    fi = LabelEncoder().fit_transform(rets[2])

    noise = np.expand_dims(np.asarray(rets[3], dtype=np.float32), axis=-1)

    print("Total unique labels:", len(np.unique(y)))
    return X, y, fi, noise

In [None]:
def speakerFilter(data_x, data_y, speakers, data_fi=None, seed=42):
    '''
    Limits number of speakers by keeping only the frames of a random subset.

    data_x  : Speech frames, indexed along axis 0
    data_y  : actual label of speech frames (1-D, one per frame)
    speakers: number of speakers to keep
    data_fi : optional per-frame file indexes, filtered the same way
    seed    : seed of the private RNG used to choose the speakers

    Returns (x, y) or (x, y, fi) restricted to the chosen speakers, with the
    returned labels re-encoded to 0..speakers-1.
    Raises ValueError if `speakers` exceeds the number of distinct labels.
    '''
    # Sample from the actual (sorted) label values rather than range(k), so
    # non-contiguous labels work. For labels already encoded as 0..k-1 this
    # selects exactly the same speakers as sampling from range(k).
    labels = np.unique(data_y).tolist()
    rand = random.Random(seed)
    persons = rand.sample(labels, speakers)

    # Vectorised membership test instead of a per-frame Python loop.
    idx = np.flatnonzero(np.isin(data_y, persons))
    # Equivalent to LabelEncoder().fit_transform: index into sorted uniques.
    new_y = np.unique(data_y[idx], return_inverse=True)[1]

    if data_fi is not None:
        return data_x[idx], new_y, data_fi[idx]
    return data_x[idx], new_y


# A utility function
def showWAV(wav, sr, frame):
    '''
    Play and plot a batch of waveforms, one subplot per clip.

    Example: showWAV(X[:2], sr=SR, frame=FRAME)

    wav   : array of shape (clips, samples, channels); only channel 0 is used
    sr    : sample rate in Hz, used for playback and the x-axis length
    frame : frame length in seconds; int(sr*frame) samples are plotted,
            so it must match the clip length
    '''
    clips = wav.shape[0]
    plt.figure(figsize=(12, 3))
    for i in range(clips):
        signal = wav[i, ..., 0]
        plt.subplot(clips, 1, i + 1)
        # Use the caller's `sr` rather than the global SR constant.
        display(ipd.Audio(signal, rate=sr, autoplay=False))
        plt.plot(np.arange(int(sr * frame)), signal)
    plt.show()


def makeImpure(data_y, ratio=0, seed=42):
    '''
    Corrupt a fraction of labels by permuting them among themselves.

    int(len(data_y) * ratio) random positions are chosen and the labels at
    those positions are shuffled, simulating impure pseudo-labels.

    data_y : 1-D label array
    ratio  : fraction of entries to corrupt (0 disables corruption)
    seed   : seed of the private RNG, for reproducibility

    Returns a new array in every case; `data_y` is never modified.
    '''
    # Always return a copy, so callers can mutate the result safely whether
    # or not corruption was applied.
    ret_data = np.copy(data_y)
    if ratio == 0:
        return ret_data

    rand = random.Random(seed)
    tot_impurs = int(data_y.shape[0] * ratio)
    idxs = rand.sample(range(data_y.shape[0]), tot_impurs)
    # Fancy indexing yields a copy, so shuffling it leaves data_y untouched.
    shuffle_data = data_y[idxs]
    rand.shuffle(shuffle_data)
    ret_data[idxs] = shuffle_data
    return ret_data


def pairwiseRelations(data_y, max_lim, seed=42):
    '''
    Limits & constructs the number of pairwise relations.

    Each speaker's frames are split at random into groups of at most
    `max_lim` frames, and every group gets its own pseudo-label. Frames that
    share a pseudo-label form the must-link relations used for training.

    data_y  : The actual label of speech frames (1-D)
    max_lim : Maximum number of frames per pseudo-label group (>= 1)
    seed    : seed of the private RNG used to form the groups

    Returns an int32 array of pseudo-labels 0..g-1 with data_y's shape.
    Groups never mix speakers. Pseudo-labels are handed out speaker by
    speaker, in ascending order of the original label.
    Raises ValueError if max_lim < 1.
    '''
    if max_lim < 1:
        raise ValueError(f"max_lim must be >= 1, got {max_lim}")

    rand = random.Random(seed)
    new_y = np.zeros(data_y.shape, dtype=np.int32)
    next_label = 0

    # Iterate over the actual label values so that non-contiguous labels work
    # and every frame is assigned, even when len(data_y) < max_lim.
    for speaker in np.unique(data_y):
        members = np.flatnonzero(data_y == speaker).tolist()
        # Cutting a random permutation into consecutive chunks gives the same
        # grouping distribution as repeatedly sampling max_lim of the
        # remaining frames. It is O(n) and avoids random.sample on a set,
        # which Python 3.11+ rejects.
        rand.shuffle(members)
        for start in range(0, len(members), max_lim):
            new_y[members[start:start + max_lim]] = next_label
            next_label += 1

    return new_y

In [None]:
# Note: the FFT layer below is not used by this notebook.
class FFT(tf.keras.layers.Layer):
    '''
    Magnitude spectrum of a raw waveform, keeping the first half of the bins.

    The trailing axis of the input is squeezed before the FFT and restored
    afterwards, so the input is expected to be (batch, samples, 1).
    '''
    def __init__(self, **kwargs):
        # Previously misspelled `__init`, which Python name-mangles into an
        # unrelated private method, so it was never called. Forward kwargs
        # (e.g. `name`) to the Keras Layer.
        super(FFT, self).__init__(**kwargs)

    def build(self, input_shape):
        super(FFT, self).build(input_shape)

    def adapt(self, input):
        return self.call(input)

    def call(self, input):
        # tf.signal.fft transforms the innermost axis, so drop the channel
        # axis first and add it back after the transform.
        signal = tf.squeeze(input, axis=-1)
        spectrum = tf.signal.fft(
            tf.cast(tf.complex(real=signal, imag=tf.zeros_like(signal)),
                    tf.complex64)
        )
        spectrum = tf.expand_dims(spectrum, axis=-1)
        # Magnitude of the first half of the spectrum (positive frequencies).
        return tf.math.abs(spectrum[:, : (signal.shape[1] // 2), :])


class MEL(tf.keras.layers.Layer):
    def __init__(self, scale=None, setup=2, custom_setup=None,
                 output_channel=1, **kwargs):
        '''
            Turn raw audio into a (mel-)spectrogram with a channel axis.

            scale :        Can be [None, 'spec', 'log', 'db']
                           By default it returns mel
            custom_setup : Custom parameter setup. Must be a dict with the
                           keys "nfft", "window", "stride" and "mels"
            setup :        Select one of the predefined parameter setups.
                           Default is setup=2
            output_channel: Number of output channels. If greater than one, the
                            output will be repeated along the channel axis.
            kwargs :       Forwarded to tf.keras.layers.Layer (e.g. name).
        '''
        # Initialise the Keras Layer before assigning any attributes.
        super(MEL, self).__init__(**kwargs)
        self.scale = scale
        self.setup = [{"nfft": 191, "window": 128, "stride": 34, "mels": 100},
                      {"nfft": 1024, "window": 128, "stride": 61, "mels": 263},
                      {"nfft": 511, "window": 32, "stride": 16, "mels": 256}]
        self.id = setup
        self.output_channel = output_channel
        if custom_setup is not None:
            self.id = 0
            self.setup = [custom_setup]

    def build(self, input_shape):
        super(MEL, self).build(input_shape)

    def adapt(self, input):
        return self.call(input)

    def _add_channels(self, x):
        '''Append a channel axis and repeat it `output_channel` times if > 1.'''
        x = tf.expand_dims(x, axis=-1)
        if self.output_channel > 1:
            x = tf.keras.backend.repeat_elements(x, self.output_channel,
                                                 axis=-1)
        return x

    def call(self, input):
        params = self.setup[self.id]
        # Only channel 0 of the input audio is used.
        spect = tfio.experimental.audio.spectrogram(input[..., 0],
                                                    nfft=params["nfft"],
                                                    window=params["window"],
                                                    stride=params["stride"])

        if self.scale == "spec":
            return self._add_channels(spect)

        mel = tfio.experimental.audio.melscale(spect, rate=SR,
                                               mels=params["mels"],
                                               fmin=0, fmax=8000)

        if self.scale == "log":
            # NOTE(review): log(0) gives -inf for empty bins; consider an epsilon.
            mel = tf.math.log(mel)
        elif self.scale == "db":
            mel = tfio.experimental.audio.dbscale(mel, top_db=128)

        return self._add_channels(mel)

In [None]:
class MyLogger(tf.keras.callbacks.Callback):
    '''
    Keras callback that scores the embeddings with KMeans clustering.

    Every `n` epochs the AE embeddings are clustered and ACC / NMI / ARI are
    reported and recorded. The history can be plotted, and the model is
    checkpointed whenever the validation ACC improves.

    Parameters:

    n    : Number of epochs after which logs will appear/be calculated
    plot : If True, the graphs (ACC graph, scatters) will be plotted
    scatter : If True, a t-SNE scatter will be plotted; plot must be True
    val_train : The metrics will also be calculated for the training data and
                the corresponding pseudo labels
    AE : The AutoEmbedder portion of the model, used for evaluation
    dg : Unused; kept for call compatibility

    validation_data : [(X, pseudo_labels), (X, actual_labels)]
    nodes : Number of data nodes that will be used for evaluation
            default is 500

    save_model : If True, the model will be saved when max ACC on actual label
                 is found

    savepath : the basepath of the save directory
    show_fig : If False, figures are saved/closed instead of shown
    mel_scale, setup, output_channel : MEL layer setup applied to X before AE

    * self.start_epoch contains the last epoch when the training is terminated
    '''
    def __init__(self, validation_data, n=1, plot=False, save_model=True,
                 savepath=None, dg=None, AE=None, val_train=False,
                 scatter=False, nodes=500, show_fig=True,
                 # MEL layer setup
                 mel_scale='spec', setup=0, output_channel=1,):
        super(MyLogger, self).__init__()

        self.n = n
        if validation_data is not None:
            self.x_tr, self.y_tr = validation_data[0]
            self.x_val, self.y_val = validation_data[1]
            self.classes_val = len(np.unique(self.y_val))
            self.classes_tr = len(np.unique(self.y_tr))
        self.start_time = time.time()
        self.savepath = savepath
        self.plot = plot
        self.save_model = save_model
        self.maxACC = 0
        self.start_epoch = 0
        self.nodes = nodes
        self.AE = AE
        self.val_train = val_train
        self.scatter = scatter
        self.show_fig = show_fig
        self.mel = MEL(scale=mel_scale, setup=setup,
                       output_channel=output_channel)

        self.savelog = {'Epoch': [], 'ACC':[], 'NMI':[], 'ARI':[],
                        'val_Epoch': [], 'val_ACC':[], 'val_NMI':[],
                        'val_ARI':[], 'loss':[]}

        if self.savepath is not None:
            # Defining save paths
            self.logpath = os.path.join(self.savepath, 'log.pickle')
            self.modelpath = os.path.join(self.savepath, 'model', '')
            # Creating the save directory
            os.makedirs(self.savepath, exist_ok=True)
            # Loading previous data if found (resume support)
            if os.path.exists(self.logpath):
                with open(self.logpath, 'rb') as f:
                    self.savelog = pickle.load(f)
                print('Previous data loaded, starting epoch:', self.savelog['Epoch'][-1])
                self.start_epoch = self.savelog['Epoch'][-1]
                self.maxACC = max(self.savelog['val_ACC'])
        if self.plot:
            sns.set_style("whitegrid")

    def on_epoch_end(self, epoch, logs=None):
        # `logs=None` instead of a shared mutable `{}` default.
        if logs is None:
            logs = {}
        if epoch % self.n != 0:
            return
        epoch += 1

        # Validation data (actual labels)
        vacc, vnmi, vari, vouts, vys = self._KmeansAcc(self.x_val, self.y_val,
                                                       self.classes_val)
        logs['val_acc'], logs['val_nmi'], logs['val_ari'] = vacc, vnmi, vari

        # Train data (pseudo labels)
        if self.val_train:
            acc, nmi, ari, _, _ = self._KmeansAcc(self.x_tr, self.y_tr,
                                                  self.classes_tr)
            logs['acc'], logs['nmi'], logs['ari'] = acc, nmi, ari
        else:
            acc, nmi, ari = None, None, None

        ep_time = time.time() - self.start_time
        self._saveLog(epoch, acc, nmi, ari, logs['loss'], vacc, vnmi, vari)
        if self.plot:
            self.plotter(vouts, vys)
        else:
            clear_output(wait=True)

        self.start_time = time.time()
        print(f"Epoch {epoch}: bACC: {self.maxACC:.2f}")
        print(f"val_ACC {vacc:.3f} val_NMI {vnmi:.3f} val_ARI {vari:.3f} Loss {logs['loss']:.3f} ET:{ep_time:.1f}")
        if self.val_train:
            print(f"tACC {acc:.3f} tNMI {nmi:.3f} tARI {ari:.3f}")

        if self.savepath is not None:
            improved = vacc > self.maxACC
            if improved:
                print(f'Saving model, prev :{self.maxACC:.2f}, current: {vacc:.2f}')
                self.maxACC = vacc
            # The log is written every evaluation; the model only on improvement.
            with open(self.logpath, 'wb') as f:
                pickle.dump(self.savelog, f)
            if improved and self.save_model:
                tf.keras.models.save_model(self.model, self.modelpath)

        self.maxACC = max(vacc, self.maxACC)

    def _saveLog(self, epoch, acc, nmi, ari, loss, vacc, vnmi, vari):
        '''Append one evaluation's metrics to the in-memory history.'''
        self.savelog['Epoch'].append(epoch)
        # Train data
        if self.val_train:
            self.savelog['ACC'].append(acc)
            self.savelog['NMI'].append(nmi)
            self.savelog['ARI'].append(ari)
        # Validation data
        self.savelog['val_ACC'].append(vacc)
        self.savelog['val_NMI'].append(vnmi)
        self.savelog['val_ARI'].append(vari)
        self.savelog['loss'].append(loss)

    def plotter(self, embds, labels):
        '''Plot the metric history and, if enabled, a t-SNE scatter of embds.'''
        clear_output(wait=True)
        plt.figure(figsize=(12.8, 4.8))
        if self.scatter:
            plt.subplot(1, 2, 1)

        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        # dashes are the train data
        if self.val_train:
            plt.plot(self.savelog['Epoch'], (self.savelog['ACC']),
                     '--', label='tACC', c='purple')
            plt.plot(self.savelog['Epoch'], (self.savelog['NMI']),
                     '--', label='tNMI', c='orange')
            plt.plot(self.savelog['Epoch'], (self.savelog['ARI']),
                     '--', label='tARI', c='c')
        # Loss is divided by its maximum so it fits on the score axis.
        plt.plot(self.savelog['Epoch'],
                 (np.array(self.savelog['loss'])/max(self.savelog['loss'])),
                 label='Training Loss', c='r')
        # solids are validation data
        plt.plot(self.savelog['Epoch'], (self.savelog['val_ACC']),
                 label='ACC', c='purple')
        plt.plot(self.savelog['Epoch'], (self.savelog['val_NMI']),
                 label='NMI', c='orange')
        plt.plot(self.savelog['Epoch'], (self.savelog['val_ARI']),
                 label='ARI', c='c')
        plt.ylabel('Score')
        plt.xlabel('Epoch')

        plt.grid()
        plt.minorticks_on()
        # Supplying style kwargs turns the grid on. The old `b=True` argument
        # was renamed `visible` and then removed in Matplotlib 3.7.
        plt.grid(which='minor', linestyle='--', alpha=0.25)
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.13),
                   fancybox=True, shadow=False, ncol=5)

        if self.scatter:
            plt.subplot(1, 2, 2)
            p = TSNE(n_components=2).fit_transform(embds)
            plt.scatter(p[:, 0], p[:, 1], c=labels)

        if self.savepath is not None:
            plt.savefig(os.path.join(self.savepath, "logPlot.png"),
                        transparent=True, bbox_inches='tight', pad_inches=0.05,
                        dpi=200)
        if self.show_fig:
            plt.show()
        else:
            plt.close('all')

    def _KmeansAcc(self, x, y, c):
        '''
        Embed a random subset of x, cluster it into c groups, and score the
        clustering against y.

        Returns (acc, nmi, ari, embeddings, labels_of_subset).
        '''
        # `n_jobs` was removed from KMeans in scikit-learn 1.0 (TypeError).
        kmeans = KMeans(n_clusters=c, n_init=10, random_state=12)

        # Sample `nodes` points with a floor of 2 per cluster, but never more
        # than exist (random.sample raises otherwise). Raise `nodes` for more
        # reliable metrics; small values are only a demo.
        total = y.shape[0]
        k = min(max(min(self.nodes, total), c * 2), total)
        rindex = random.sample(range(0, total), k)
        outs = self.AE(self.mel(x[rindex]), training=False)
        y_pred = kmeans.fit_predict(outs)

        acc = accuracy(y[rindex], y_pred)
        nmi = normalized_mutual_info_score(y[rindex], y_pred)
        ari = adjusted_rand_score(y[rindex], y_pred)
        return (acc, nmi, ari, outs, y[rindex])

In [None]:
class AEGenerator(tf.keras.utils.Sequence):
    """ Receives X and y.
        Performs pairwise matching within a batch; both members of each
        pair are drawn from the input X and y.

        iX        : Input audio, shape=(n, frame_size, 1)
        iy        : Pseudo labels, shape=(n, ) or (n, 1)
        dist      : Target distance for can-not-link pairs (must-links get 0)
        noise     : Noise clips for augmentation, shape=(m, frame_size, 1)
        batch_size: Number of pairs per batch
        scale     : The maximum fraction of noise mixed into X (0 disables)
        show_logs : Show random selection errors
        gt        : Ground truth of the actual class.
                    Use only when show_logs is True
        mel_scale, setup, output_channel : MEL layer setup

        Each batch is ([mel(first_items), mel(second_items)], targets).
        Raises ValueError if iy has fewer than two distinct labels, because
        can-not-link pairs could then never be drawn.
    """

    def __init__(self, iX, iy, dist, noise, batch_size=64,
                 scale=0.4, show_logs=False, gt=False,
                 # MEL layer setup
                 mel_scale='spec', setup=0, output_channel=1,
                 ):
        super(AEGenerator, self).__init__()
        self.batch_size  = batch_size
        self.noise       = noise

        self.mel = MEL(scale=mel_scale, setup=setup,
                       output_channel=output_channel)

        # inputA : raw audio, inputY : labels; the MEL layer runs per batch.
        self.inputA = iX
        self.inputY = iy
        # Probe the MEL output shape on one sample while it still has its
        # channel axis (MEL reads input[..., 0]).
        self.output_shape = self.mel(self.inputA[:1, ...]).numpy().shape

        # Drop only the trailing singleton channel axis. A bare np.squeeze
        # would also drop a length-1 sample axis (e.g. a single noise clip).
        if self.inputY.ndim > 1 and self.inputY.shape[-1] == 1:
            self.inputY = np.squeeze(self.inputY, axis=-1)
        if self.inputA.shape[-1] == 1:
            self.inputA = np.squeeze(self.inputA, axis=-1)
        if self.noise.shape[-1] == 1:
            self.noise = np.squeeze(self.noise, axis=-1)

        self.total       = len(iX)
        self.dist        = dist
        self.scale       = scale
        self.gt          = gt

        # label -> list of indexes where that label occurs
        ulabels = np.unique(self.inputY)
        if len(ulabels) < 2:
            # The can-not-link loop in _genIndexes would never terminate.
            raise ValueError("AEGenerator needs at least two distinct labels "
                             "in iy to build can-not-link pairs")
        self.class_index = {label: list(np.where(self.inputY == label)[0])
                            for label in ulabels}

        self.indexes     = np.arange(self.total)
        self.total_batch = self.total // self.batch_size
        self.classes = len(ulabels)

        self.rand = random.Random(12)
        # NOTE: this also reseeds Python's global `random` module.
        random.seed(12)

        # Show logs
        self.show_logs = show_logs
        self.log = {}
        self.on_epoch_end()

    def _augment(self, wav_indices):
        """Mix the selected clips with randomly chosen noise clips.

        wav_indices : indices into self.inputA; exactly batch_size of them,
                      since noise ids and weights are drawn with that size.
        Returns (1 - s) * wav + s * noise, with s ~ U(0, self.scale) per clip.
        NOTE: marked "needs to be modified" by the original author.
        """
        noise_id = np.random.randint(low=0, high=self.noise.shape[0],
                                     size=self.batch_size)
        scale = np.random.uniform(low=0, high=self.scale,
                                  size=(self.batch_size, 1))

        real = np.multiply(self.inputA[wav_indices], (1-scale))
        jitter = np.multiply(self.noise[noise_id], scale)
        return real + jitter

    def __len__(self):
        """ Denotes the number of batches per epoch """
        return self.total_batch

    def _print_logs(self):
        for a, b in self.log.items():
            print(a, ':', b)

    def on_epoch_end(self):
        """ Reshuffles indexes after each epoch (and prints/resets logs). """
        np.random.shuffle(self.indexes)
        if self.show_logs:
            self._print_logs()
            self.log['canNotLink_error'] = 0
            self.log['time'] = 0

    def _make_choice(self, p=None):
        """ Returns True with probability p (default 0.5). """
        w = [0.5, 0.5]
        if p is not None:
            w = [p, 1-p]
        return self.rand.choices([True, False], weights=w)[0]

    def __getitem__(self, batch_index):
        """ Generate one batch of data """
        idx, y = self._genIndexes(batch_index)
        # Fancy indexing copies, so the augmentation below does not touch inputA.
        wavs = self.inputA[idx]

        if self.scale > 0:
            # Augment a random half of the 2*batch_size clips.
            aug_idx = self.rand.sample(range(2*self.batch_size),
                                       k=self.batch_size)
            wavs[aug_idx] = self._augment(idx[aug_idx])

        rets = self.mel(np.expand_dims(wavs, axis=-1)).numpy()
        return [rets[:self.batch_size], rets[self.batch_size:]], y

    def _genIndexes(self, index):
        """ Returns (2*batch_size pair indexes, per-pair target distance). """
        idx = np.zeros((2, self.batch_size), dtype=np.int32)
        idx[0] = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        constraint = np.zeros(self.batch_size, dtype=np.int32)

        for i in range(self.batch_size):
            plabel = self.inputY[idx[0][i]]
            take = idx[0][i]

            # Generating must-links (may pick the same index again)
            if self._make_choice():
                take = self.rand.choice(self.class_index[plabel])
            # Generating can-not links
            else:
                # Redraw random partners until the pseudo label differs
                while self.inputY[take] == plabel:
                    take = self.rand.choice(self.indexes)
                constraint[i] = self.dist

                if self.show_logs and self.gt[idx[0][i]] == self.gt[take]:
                    self.log['canNotLink_error'] += 1

            idx[1][i] = take

        return idx.reshape((2*self.batch_size)), constraint

In [None]:
# Implementation of the distance layer.
# We assume the two embeddings have already been subtracted, so this layer
# only computes the Euclidean norm of their difference.
class Distance(Layer):
    '''Euclidean norm of the (already subtracted) embedding difference.

    Sums the squared values over axis 1, takes the square root and appends
    a trailing axis, so a (batch, dims) input yields (batch, 1).
    '''
    def __init__(self):
        super(Distance, self).__init__()

    def call(self, inputs):
        squared = tf.math.square(inputs)
        summed = tf.math.reduce_sum(squared, axis=1)
        norm = tf.math.sqrt(summed)
        return tf.expand_dims(norm, axis=-1)


# Implementation of the AutoEmbedder
def buildAE(input_shape, dims, alpha=1, dis=100, topmodel='MobileNet',
            scale=None, setup=0, dropout=0.001, decay=False):
    '''
    Build the siamese AutoEmbedder training model and its embedding model.

    input_shape : shape of one input sample, without the batch axis
    dims        : size of the output embedding
    alpha       : currently not forwarded to the backbone.
                  NOTE(review): trainModel names MobileNet runs by alpha,
                  which suggests it was meant to be.
    dis         : cap of the thresholded ReLU applied to the pair distance
    topmodel    : name of a tf.keras.applications backbone
                  (e.g. 'MobileNet', 'DenseNet121'); built without its top,
                  with ImageNet weights
    scale, setup, dropout : unused; kept for call compatibility
    decay       : if True, Adam uses an ExponentialDecay learning-rate
                  schedule instead of the constant 0.0005

    Returns (model, AE). `model` maps [input1, input2] to the clipped pair
    distance and is compiled with Adam + MSE. `AE` maps input1 to its
    embedding and shares its weights with `model`.
    '''
    inp1 = Input(input_shape, name='input1')
    inp2 = Input(input_shape, name='input2')

    # Look the backbone up by name instead of eval()-ing a formatted string.
    backbone_cls = getattr(tf.keras.applications, topmodel)
    backbone = backbone_cls(input_shape=input_shape, include_top=False,
                            weights='imagenet')

    m = Sequential([backbone,
                    Flatten(),
                    Dense(dims)], name='AutoEmbedder')

    # Both inputs go through the same (weight-shared) embedder.
    out1 = m(inp1)
    out2 = m(inp2)

    # Subtracting output pairs
    out = Subtract()([out1, out2])
    # Calculating distance
    out = Distance()(out)
    # The thresholded ReLU layer caps the distance at `dis`
    out = ReLU(max_value=dis)(out)

    model = Model([inp1, inp2], out, name="AE_train")

    # Previously the schedule was built but never passed to the optimizer,
    # so `decay=True` had no effect beyond the print.
    if decay:
        print('Decay loaded')
        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
                        initial_learning_rate=0.1, decay_steps=50,
                        decay_rate=0.0005)
    else:
        learning_rate = 0.0005

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='mse')
    AE = Model(inp1, out1, name="AutoEmbedder")

    return model, AE

In [None]:
# Default log directory. trainModel binds this value as its `base_dir`
# default at definition time, so reassigning it later does not change that default.
base_dir = './logs/DenseNet_ASR/'
# Module-level placeholders; trainModel rebinds the global AE via `global AE`.
model, AE = None, None

def load_AE(path, input_shape):
    '''
    Load a saved training model and rebuild the embedding-only model.

    path        : location accepted by tf.keras.models.load_model
    input_shape : shape of one input sample, without the batch axis

    Returns (model, AE), where AE wraps the loaded model's 'AutoEmbedder'
    layer behind a fresh Input of `input_shape`.
    '''
    loaded_model = tf.keras.models.load_model(path)
    print('Model loaded')
    embedder = loaded_model.get_layer('AutoEmbedder')
    new_input = Input(input_shape)
    embedding_model = Model(new_input, embedder(new_input))
    print('AE loaded')
    return loaded_model, embedding_model
 
 
def trainModel(model, X, cy, y, noise, db, spkrs, clim,
               alpha=0.25, dist=100, dims=3, setup=2,
               batch_size=64, scale='spec', epochs=3000, decay=False,
               load_pretrain=False, frame=0.2, seconds=16,
               add_mel=True, aug_scale=0, base_dir=base_dir,
               pretrain_path=None):
    '''
    Train, or resume training, an AutoEmbedder on pairwise constraints.

    The model is chosen in this order of preference:
    1. the model saved in this run's directory, if present and loadable;
    2. `pretrain_path`, when `load_pretrain` is True;
    3. a fresh model from buildAE.
    Training uses a MyLogger callback, which evaluates clustering metrics,
    writes logs and plots, and checkpoints the best model in the run
    directory. The global `AE` is set to the embedding model. Returns None.

    model : name of the tf.keras.applications backbone (e.g. 'DenseNet121')
    X, cy, y : raw audio frames, pairwise pseudo labels, actual labels
    noise : noise clips used for augmentation (when aug_scale > 0)
    add_mel : unused; kept for call compatibility

    Possible Naming Scheme: Model_dist_dims_setup_scale_frame_[T,L,B]_MIN_SPKRs_CLIM
    T : TIMIT
    L : LibriSpeech
    B : Bengali ASR
    dist  : Distance between clusters (known as AutoEmbedder alpha parameter);
            also the cap of the model's thresholded ReLU
    dims  : The output dimension of AutoEmbedder
    frame : Frame size (input frame size)
    scale : The scaling feature of speech (spectrogram, mfcc)
    SPKRS : Number of speakers
    CLIM  : Maximum Inter-cluster-linkage
    MIN   : Minutes of speech per-speaker
    '''
    global AE
    tf.keras.backend.clear_session()
    backbone_name = model
    nmodel = f"{backbone_name}{alpha}" if backbone_name == 'MobileNet' else f"{backbone_name}"
    pname = f"{nmodel}_A{dist}_D{dims}_{setup}_{scale}_F{frame}_{db}_T{seconds}_S{spkrs}_L{clim}"
    save_dir = os.path.join(base_dir, pname)

    # The generator, the network input and the logger's MEL layer must agree
    # on the channel count. Previously the generator and the logger used
    # different rules, which mismatched for e.g. 'MobileNetV2'.
    output_channel = 1 if backbone_name == 'MobileNet' else 3

    dg = AEGenerator(iX=X, iy=cy, dist=dist, noise=noise,
                     batch_size=batch_size, scale=aug_scale,
                     show_logs=False, gt=y,
                     output_channel=output_channel)
    input_shape = dg.output_shape[1:]

    loaded = False
    model_dir = os.path.join(save_dir, 'model')
    if os.path.exists(model_dir):
        try:
            print('Previous model found, loading model')
            model, AE = load_AE(model_dir, input_shape)
            loaded = True
        except Exception as e:
            # Best effort: an unreadable checkpoint falls back to the
            # pre-trained or a freshly built model below.
            print(str(e))

    if not loaded and load_pretrain and pretrain_path is not None:
        print('Loading pre-trained model')
        model, AE = load_AE(pretrain_path, input_shape)
        loaded = True

    if not loaded:
        print('Initializing new model')
        # The ReLU cap follows `dist` (it was hard-coded to 100), so the
        # can-not-link target produced by the generator is reachable.
        model, AE = buildAE(input_shape,
                            dims=dims, setup=setup,
                            alpha=alpha, scale=scale, dis=dist,
                            decay=decay, topmodel=backbone_name)

    print(X.shape, cy.shape, y.shape)
    log = MyLogger([(X, cy), (X, y)], n=5, plot=True, AE=AE,
                   savepath=save_dir,
                   save_model=True,
                   scatter=False,
                   val_train=True,
                   show_fig=False,
                   nodes=len(np.unique(y))*2,
                   output_channel=output_channel)

    model.summary()
    model.get_layer('AutoEmbedder').summary()

    print('Starting model training')
    model.fit(dg, epochs=epochs+1, verbose=0, callbacks=[log],
              initial_epoch=log.start_epoch, use_multiprocessing=False,
              workers=4, max_queue_size=10)

In [None]:
# Experiment configuration
db = ['TIMIT', 'LIBRI', 'ASR']  # dataset prefixes understood by loadData
dbi = 1                         # index into db (1 -> 'LIBRI')
frame = 0.2                     # frame-length tag of the dataset files
seconds = 10                    # amount-of-speech tag of the dataset files
spkrs = 25                      # number of speakers to keep
clim = 5                        # max frames per pseudo-label group

In [None]:
X, y, fi, noise = loadData(db[dbi], frame=frame, seconds=seconds)

print(X.shape, y.shape)

# Restrict the data to a random subset of speakers (0/None keeps everyone).
# Filter `fi` too, so it stays aligned with X and y.
if spkrs:
    X, y, fi = speakerFilter(X, y, spkrs, data_fi=fi)

spkrs = len(np.unique(y))
# Split every speaker into pseudo-label groups of at most `clim` frames.
cy = pairwiseRelations(y, max_lim=clim)

print(spkrs, clim, seconds, len(np.unique(cy)))
print(X.shape, y.shape)
print(X.shape, y.shape)

In [None]:
base_dir = './logs/'

trainModel(model='DenseNet121', db=db[dbi][0],
           # data
           X=X, cy=cy, y=y, noise=noise,
           spkrs=spkrs, clim=clim, frame=frame, seconds=seconds,
           # model / MEL setup
           alpha=1, dist=100, dims=12, setup=0, scale='spec',
           # training
           batch_size=128, epochs=1200, decay=False, aug_scale=0.07,
           add_mel=False, load_pretrain=False, pretrain_path=None,
           base_dir=base_dir)