In [1]:
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 16 09:29:29 2020

@author: Mingyu Hsueh

Environment:
    Tensorflow 2.0
    Python 3.8+
"""

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Layer
from tensorflow.python.ops import nn
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


In [2]:
from tensorflow.keras.optimizers import Optimizer
from tensorflow.keras import backend as K
from tensorflow.python.keras.optimizer_v2.gradient_descent import SGD
class Ge2eOptimizer(SGD):
    """
    Note:
        Inherited from tf.keras.optimizers.SGD. Before the update is applied,
        gradients are clipped by their global norm and the gradients of the
        last two trainable variables are scaled down.
    Attributes:
        CLIP_NORM: global-norm threshold passed to tf.clip_by_global_norm
        SIMILARITY_GRAD_SCALE: factor applied to the last two gradients
        __init__: constructs Ge2eOptimizer class
        get_updates: compute and modify gradients
    """

    # Global-norm threshold used when clipping all gradients together.
    CLIP_NORM = 3.0
    # Multiplier for the gradients of the last two trainable variables
    # (intended for w, b of the similarity layer).
    SIMILARITY_GRAD_SCALE = 0.01

    def __init__(self,
                 learning_rate,
                 momentum=0.0,
                 nesterov=False,
                 **kwargs):
        """
        Note:
            Forwards the SGD hyper-parameters to the SGD base class.
        Args:
            learning_rate: SGD learning rate
            momentum: SGD momentum
            nesterov: whether SGD applies Nesterov momentum
            **kwargs: any further keyword arguments accepted by SGD
        Returns:
        """

        # The hyper-parameters must be passed on explicitly; otherwise SGD
        # silently uses its own defaults and ignores the caller's values.
        super(Ge2eOptimizer, self).__init__(learning_rate=learning_rate,
                                            momentum=momentum,
                                            nesterov=nesterov,
                                            **kwargs)

    def get_updates(self, loss, params):
        """
        Note:
            Compute gradients, clip them by global norm and scale down the
            gradients of the last two variables, then apply them.
        Args:
            loss: loss to be minimized
            params: all trainable variables in the model; the last two are
                    assumed to be w, b of the similarity layer (TODO confirm
                    against the model's variable order)
        Returns:
            A single-element list holding the apply_gradients update op
        """
        grads = self.get_gradients(loss, params)
        grads_clip, _ = tf.clip_by_global_norm(grads, self.CLIP_NORM)
        grads_clip = list(grads_clip)

        # Scale only the last two gradients; a None gradient (variable not
        # connected to the loss) is passed through untouched.
        scaled_tail = [g if g is None else self.SIMILARITY_GRAD_SCALE * g
                       for g in grads_clip[-2:]]
        grads_rescale = grads_clip[:-2] + scaled_tail
        grads_and_vars = list(zip(grads_rescale, params))

        return [self.apply_gradients(grads_and_vars)]

In [3]:
def normalize(x):
    """
        Scale x to unit L2 norm along its last axis.
        return: tensor of the same shape as x
    """
    last_axis = -1
    unit_x = nn.l2_normalize(x, axis=last_axis)
    return unit_x

In [4]:
# # my regression layer does not work

class regression(Layer):
    """
    Note:
        Similarity layer for GE2E training. Reshapes a batch of N*M embeddings
        (speaker-major: N speakers x M utterances) and returns an (N*M, N)
        matrix |w| * S + b of similarities against the speaker centroids.
    Attributes:
        __init__: stores the batch layout (N, M, dense_unit) and flag
        build: creates the trainable scalars w (init 10., non-negative) and b (init -5.)
        get_config: returns the constructor arguments for serialisation
        call: computes the scaled similarity matrix
    """

    def __init__(self, N, M, dense_unit, flag=None, **kwargs):
        """
        Args:
            N: number of speakers in the batch
            M: number of utterances per speaker
            dense_unit: embedding dimension P
            flag: not used by call(); stored so get_config()/from_config()
                  round-trip (it used to be required but was discarded)
        """
        super(regression, self).__init__(**kwargs)

        self.N = N
        self.M = M
        self.dense_unit = dense_unit
        self.flag = flag

    def build(self, input_shape):
        # w scales the similarity; the NonNeg constraint is meant to keep it >= 0.
        self.w = self.add_weight(name="w", shape=[1], dtype=tf.float32,
                                 initializer=tf.constant_initializer(value=10.), trainable=True,
                                 constraint=tf.keras.constraints.NonNeg())
        # b is the additive bias of the similarity.
        self.b = self.add_weight(name="b", shape=[1], dtype=tf.float32,
                                 initializer=tf.keras.initializers.Constant(-5.), trainable=True)

        super(regression, self).build(input_shape)

    def get_config(self):
        # Every __init__ argument must appear here, otherwise from_config fails.
        config = {'N': self.N, 'M': self.M, 'dense_unit': self.dense_unit, 'flag': self.flag}
        base_config = super(regression, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        """
        Calculate similarity matrix
        input: embedding (NM * P), utterances of one speaker contiguous
        return: similarity matrix in tf type (NM * N); row block j holds the
                utterances of speaker j, column i the centroid of speaker i.
                S is the dot product with L2-normalised centroids (cosine
                similarity only if the inputs are already unit-norm - assumption)
        """
        # reshape to [N, M, P]
        embedded_split = tf.reshape(inputs, shape=[self.N, self.M, self.dense_unit])
        # [N,P] normalized center vectors eq.(1)
        center = normalize(tf.reduce_mean(embedded_split, axis=1))
        # [NM,P] center vectors that exclude the utterance itself, eq.(8)
        center_except = normalize(tf.reshape(
            tf.reduce_sum(embedded_split, axis=1, keepdims=True) - embedded_split,
            shape=[self.N * self.M, self.dense_unit]))

        # make similarity matrix eq.(9): the own-speaker column uses the
        # leave-one-out centroid, every other column the full centroid.
        S = tf.concat(
            [tf.concat([tf.reduce_sum(center_except[i*self.M:(i+1)*self.M, :]*embedded_split[j, :, :],
                                      axis=1, keepdims=True) if i == j
                        else tf.reduce_sum(center[i:(i+1), :]*embedded_split[j, :, :], axis=1, keepdims=True)
                        for i in range(self.N)], axis=1) for j in range(self.N)], axis=0)

        w_S = tf.abs(self.w)*S + self.b

        return w_S

In [5]:
# class regression(Layer):
#     def __init__(self, N, M, dense_unit, flag, **kwargs):
#         super(regression, self).__init__(**kwargs)
        
#         self.N = N
#         self.M = M
#         self.dense_unit = dense_unit
#         self.flag = flag
                
#     def build(self, input_shape):
#         self.w = self.add_weight(name="w", initializer=keras.initializers.Constant(value=10),trainable=True)
#         self.b = self.add_weight(name="b", initializer=keras.initializers.Constant(value=-5),trainable=True)
#         super(regression, self).build(input_shape)
        
#     def call(self, inputs):
#         if self.flag:
#             # [tot_utt, embed_dim]
#             utterances = inputs[0]
#             # [tot_utt, embed_dim, num_spkr]
#             centroids = tf.keras.backend.permute_dimensions(inputs[1], [1, 0, 2])

#             l2_utterances = nn.l2_normalize(utterances, axis=1)
#             l2_centroids = nn.l2_normalize(centroids, axis=1)

#             similarity = K.batch_dot(l2_utterances, l2_centroids, axes=[1, 1])
#         else:
#             l2_utterances = tf.nn.l2_normalize(inputs[0], axis=-1)
#             l2_centroids = tf.nn.l2_normalize(K.transpose(inputs[1]), axis=0)
#             similarity = K.dot(l2_utterances, l2_centroids)
            
#         self.weight = tf.clip_by_value(self.w, 1e-6, np.infty)
#         weighted_similarity = self.w * similarity + self.b
#         return weighted_similarity
    
#     def get_config(self):
#         config = {'N':self.N, 'M':self.M, 'output_dim': self.dense_unit}
#         base_config = super(regression, self).get_config()
#         return dict(list(base_config.items())+list(config.items()))

In [6]:
# class Centroid_matrix(Layer):
#     def __init__(self, num_speakers, num_utterance, **kwargs):
#         super(Centroid_matrix, self).__init__(**kwargs)
#         self.N = num_speakers
#         self.M = num_utterance

#     def build(self, input_shape):
#         super(Centroid_matrix, self).build(input_shape)

#     def call(self, inputs):

#         # input shape [tot_utt, embed_dim]
#         inputs = tf.keras.backend.permute_dimensions(inputs, [1, 0])  # [embed_dim, tot_utt]
        
#         # centroid_column
#         self_block = tf.keras.backend.ones(shape=[self.M, self.M], dtype=tf.float32) - \
#                                             tf.keras.backend.eye(self.M, dtype=tf.float32)
#         self_block = self_block / (self.M - 1) # subtract itself and mean
        
#         # [num_spkr_utt, num_spkr_utt]
#         centroid_block = tf.pad(self_block, [[0, 0], [0, (self.N - 1) * self.M]], name="normal_centroid_select_pad", \
#                                 constant_values=1/self.M) # other speakers mean
        
#         # [num_spkr_utt * num_spkr, num_spkr_utt]
#         centroid_per_spkr = tf.pad(centroid_block, [[0, (self.N - 1) * self.M], [0, 0]], name="other_utterances_zero", \
#                                    constant_values=0)
#         # [tot_utt, tot_utt]

#         # [tot_utt, tot_utt]
#         centroid_per_spkr_list = [tf.roll(centroid_per_spkr, axis=0, shift=spk_idx * self.M) for spk_idx in range(self.N)]
#         # num_spkr * [tot_utt, tot_utt]
#         centroid_list = tf.keras.backend.stack(centroid_per_spkr_list, axis=-1)
#         # [tot_utt, tot_utt, num_spkr]

#         self_exclusive_centroids = tf.keras.backend.dot(inputs, centroid_list)
#         # [embed_dim, tot_utt] * [tot_utt, tot_utt, num_spkr]
#         # ---> [embed_dim, tot_utt, num_spkr]
#         return self_exclusive_centroids
    
#     def get_config(self):
#         config = {'Nspeakers':self.N, 'Mutterance':self.M}
#         base_config = super(Centroid_matrix, self).get_config()
#         return dict(list(base_config.items())+list(config.items()))

In [7]:
# class Centroid_matrix_basic(Layer):
#     """
#     Note:
#         Compute centroids of all speakers (including the utterance itself)
#     Attributes:
#         __init__:
#         build: Creates the variables of the layer
#         call: Compute centroid matrix of speaker embeddings
#     """

#     def __init__(self, num_speakers, num_utterance, **kwargs):
#         """
#         Note:
#         Args:
#             num_speakers: number of speakers (in the paper, it is 64)
#             num_utterance: number of utterances per speaker (in the paper, it is 10)
#         Returns:
#         """

#         super(Centroid_matrix_basic, self).__init__(**kwargs)
#         self.N = num_speakers
#         self.M = num_utterance

#     def build(self, input_shape):
#         """
#         Note:
#             Creates the variables of the layer according the input_shape (optional)
#         Args:
#             input_shape: [#total_utterance, #emb_dim]
#         Returns:
#         """

#         super(Centroid_matrix_basic, self).build(input_shape)

#     def call(self, inputs):
#         """
#         Note:
#             Compute centroid matrix of speaker embeddings
#         Args:
#             inputs: the output of multi lstm and dense layer
#         Returns:
#             centroid: speaker centroid matrix [#spk, #emb_dim]
#         """

#         # Compute centroids of speaker embeddings
#         centroid = tf.keras.backend.reshape(inputs, [self.N, self.M, -1])
#         centroid = tf.keras.backend.mean(centroid, axis=1)

#         return centroid

In [8]:
class Ge2e_loss(Layer):
    """
    Note:
        Compute the ge2e loss in one of two ways: softmax or contrast
    Attributes:
        __init__: constructs Ge2e_loss class
        get_config: returns the constructor arguments for serialisation
        call: compute the ge2e loss
    """

    _LOSS_TYPES = ("softmax", "contrast")

    def __init__(self, num_speakers, num_utterance, loss_type="contrast", **kwargs):
        """
        Note:
            set up the loss configurations; No. speakers, No. utterances, loss_type
        Args:
            num_speakers: the number of speakers N
            num_utterance: the number of utterances per speaker M
            loss_type: "softmax" or "contrast"
        Raises:
            ValueError: if loss_type is not one of the two supported values
        """
        if loss_type not in self._LOSS_TYPES:
            raise ValueError("loss_type must be 'softmax' or 'contrast', got %r" % (loss_type,))

        self.N = num_speakers
        self.M = num_utterance
        self.loss_type = loss_type
        super(Ge2e_loss, self).__init__(**kwargs)

    def get_config(self):
        # Keys must match the __init__ parameter names so that
        # from_config(get_config()) can rebuild the layer.
        config = {'num_speakers': self.N, 'num_utterance': self.M, 'loss_type': self.loss_type}
        base_config = super(Ge2e_loss, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _own_speaker_mask(self, dtype):
        # [N*M, N] mask: row block i (the M utterances of speaker i) is 1 in column i only.
        mask = np.repeat(np.eye(self.N), self.M, axis=0)
        return tf.constant(mask, dtype=dtype)

    def call(self, inputs):
        """
        Note:
            compute the ge2e loss of a batch
        Args:
            inputs: similarities between batch utterances & batch centroids,
                    shape [N*M, N] (e.g. [32*4, 32]); row block i holds the
                    utterances of speaker i
        Returns:
            loss: ge2e loss summed over all utterances, shape (1,)
        """
        mask = self._own_speaker_mask(inputs.dtype)

        # S_correct: similarity of each utterance to its own speaker (colored entries in Fig.1), [N*M, 1]
        S_correct = tf.reduce_sum(inputs * mask, axis=1, keepdims=True)

        if self.loss_type == "softmax":
            # -S_own + log(sum_k exp(S_k)) per utterance
            per_utterance = -S_correct + tf.reduce_logsumexp(inputs, axis=1, keepdims=True)
        else:
            # contrast: 1 - sigmoid(S_own) + max over other speakers of sigmoid(S_k);
            # the own column is zeroed, which cannot win the max since sigmoid > 0.
            S_sig = tf.sigmoid(inputs) * (1.0 - mask)
            per_utterance = 1 - tf.sigmoid(S_correct) + tf.reduce_max(S_sig, axis=1, keepdims=True)

        loss = tf.reduce_sum(per_utterance)
        return tf.reshape(loss, (1,))

In [9]:
def test_similarity(enroll, evaluation):
    """
        Calculate similarity matrix
        input: enroll (NM * P)
               eval (K * P)
        return: similarity matrix in tf type (N * K)
    """
#     tmp1, tmp2 = enroll.shape, evaluation.shape
#     if tmp1[1]==tmp2[1]: similarity_matrix = np.matmul(evaluation, enroll.T)
#     else: similarity_matrix = np.matmul(evaluation, enroll)
    similarity_matrix = np.matmul(evaluation, enroll.T)
    return similarity_matrix

In [10]:
def EER_estimate(similarity_matrix, S_ground, draw=False, thresholds=None):
    """
        Estimate the equal error rate (EER) by sweeping a decision threshold,
        and optionally plot FRR/FAR against the threshold.
        Input:
            similarity_matrix: array of scores
            S_ground: array of the same shape; truthy entries mark target
                      (genuine) pairs, falsy entries impostor pairs
            draw: if True, plot FRR and FAR curves and show the figure
            thresholds: thresholds to sweep (default 0.00, 0.01, ..., 0.99);
                        a score is accepted when it is strictly greater
        Output: EER value (a fraction, printed as a percentage)
        Raises:
            ValueError: on shape mismatch, empty thresholds, or when S_ground
                        has no target or no impostor pair (FAR/FRR undefined)
    """
    S = np.asarray(similarity_matrix)
    ground = np.asarray(S_ground).astype(bool)
    if S.shape != ground.shape:
        raise ValueError("similarity_matrix %s and S_ground %s must have the same shape"
                         % (S.shape, ground.shape))

    n_target = int(np.count_nonzero(ground))
    n_impostor = ground.size - n_target
    if n_target == 0 or n_impostor == 0:
        raise ValueError("S_ground must contain both target and impostor pairs")

    if thresholds is None:
        thresholds = [0.01*i for i in range(100)]
    thresholds = list(thresholds)
    if not thresholds:
        raise ValueError("thresholds must not be empty")

    # Start at +inf so the first threshold always becomes the initial best;
    # starting at 1.0 left EER at 0.0 whenever |FAR-FRR| never dropped below 1.
    diff = np.inf
    EER, EER_thres, EER_FAR, EER_FRR = 0.0, 0.0, 0.0, 0.0
    best_counts = (0, 0, 0, 0)
    all_FAR, all_FRR = [], []

    for thres in thresholds:
        accepted = S > thres
        TP = int(np.count_nonzero(accepted & ground))
        FP = int(np.count_nonzero(accepted & ~ground))
        FN = n_target - TP
        TN = n_impostor - FP

        FAR = FP/n_impostor
        FRR = FN/n_target
        all_FAR.append(FAR)
        all_FRR.append(FRR)

        # keep the threshold where FAR and FRR are closest
        if diff > np.abs(FAR-FRR):
            diff = np.abs(FAR-FRR)
            EER = (FAR+FRR)/2
            EER_thres, EER_FAR, EER_FRR = thres, FAR, FRR
            best_counts = (TP, FN, TN, FP)

    print('confusion matrix')
    print('%5d, %5d, %5d, %5d' % best_counts)
    print("EER: %0.2f (thres: %0.2f, FAR: %0.2f, FRR: %0.2f) \n"%(EER*100.0, EER_thres*1.0, EER_FAR*100.0, EER_FRR*100.0))

    if draw:
        # plot against the thresholds actually swept
        fig, ax = plt.subplots()
        ax.plot(thresholds, all_FRR)
        ax.plot(thresholds, all_FAR)
        ax.legend(['FRR','FAR'])
        ax.set_title('keyword, EER=%1.2f%%'%(EER*100))
        ax.set_xlabel('threshold')
        ax.set_ylabel('Error rate')
        plt.show()

    return EER
