In [1]:
from google.colab import drive
drive.mount('/drive')

Mounted at /drive


In [2]:
%cd /drive/MyDrive/'Speaker-Diarization-ATML2021'/lifelong
%ls

/drive/MyDrive/Speaker-Diarization-ATML2021/lifelong
 all_transcripts.json      hybrid_speaker_clustering.py   [0m[01;34m__pycache__[0m/
'Copy of lifelong.ipynb'   [01;36micsimeetingcorpus[0m@             [01;36mshare[0m@
 ge2eloss.py               lifelong.ipynb                 show_generator.py
 goddamn.h5                model.py                       spectral_cluster.py
 helper_funcs.py           nearest_embeddings.py


In [3]:
!pip install simpleder

Collecting simpleder
  Downloading https://files.pythonhosted.org/packages/99/30/753524ad7b95aed7a7f2cbb9913a99e0987c3b98d07c2980941c958c9647/simpleder-0.0.3-py3-none-any.whl
Installing collected packages: simpleder
Successfully installed simpleder-0.0.3


In [4]:
import tensorflow as tf
import tensorflow.keras as K
import simpleder
import time
import copy
import json
import os

from show_generator import *
from helper_funcs import *
from model import *
from ge2eloss import *
from hybrid_speaker_clustering import *
from nearest_embeddings import *

In [5]:
n_mfcc = 13
show_len = 10 #minutes
icsi_audio_dir = "icsimeetingcorpus/Signals"
icsi_segments_dir = "icsimeetingcorpus/ICSIplus/Segments"
cache_dir = "/content/sample_data"

In [6]:

if os.path.isfile('all_transcripts.json'):
    with open('all_transcripts.json', 'r') as json_file:
        all_transcripts = json.load(json_file)
else:
    all_transcripts = gather_transcripts(icsi_segments_dir)
    with open('all_transcripts.json', 'w') as json_file:
        json.dump(all_transcripts, json_file)

cleaned_transcripts = {k: remove_overlaps(all_transcripts[k], min_len=0) for k in all_transcripts.keys()}
small_transcripts = split_transcripts(cleaned_transcripts, show_len)


In [7]:
# feature extractor
import librosa
def extract_mfcc(signal_segment, sr=16000, n_mfcc=13, n_fft=512):
    hop = int(0.010 * sr) # 10ms
    win = int(0.025 * sr) # 25ms
    mfcc = librosa.feature.mfcc(signal_segment, sr, n_mfcc= n_mfcc, hop_length= hop, win_length= win, n_fft=n_fft, window= "ham")
    del_mfcc = librosa.feature.delta(mfcc)
    feat = np.vstack([mfcc, del_mfcc])
    return feat

In [8]:
# load the model
lifelong_model = K.models.load_model("share/embd_model.h5")
original_w = copy.deepcopy(lifelong_model.trainable_variables)

# generator takes only one walk through the dataset
generator = show_generator(small_transcripts, extract_mfcc, icsi_audio_dir, cache_dir, sr=16000, prefetch_buffer_size=5, max_threads=3)
int_to_token = generator.get_int_to_token()

centroids = []
n_embds = []
covariance_matrices = []
old_mfccs = []
old_labels = []


# lifelong Training 
model_opt = K.optimizers.RMSprop(learning_rate=1e-4)
param_opt = K.optimizers.RMSprop(learning_rate=1e-4)

w = tf.Variable(initial_value= 13.0, trainable= True, 
                constraint= lambda x: tf.clip_by_value(x, 0.0, 1000.0), name= 'w')    
b = tf.Variable(initial_value= -5.0, trainable= True, name= 'b')

der_record = []
max_mem = 500
for bnum, show in enumerate(generator.dict_generator()):
    mfcc_features = show['feats']
    rounded_groundtruth = show['labels']

    # concatenate mfcc features with older ones in memory
    original_length = len(mfcc_features)
    if len(old_mfccs) > 1: mfcc_features = tf.concat([mfcc_features, old_mfccs], axis=0)

    # put mfcc features through model
    with tf.GradientTape() as tape:

        tape.watch(w)
        tape.watch(b)
        print(mfcc_features.shape, len(old_mfccs))
        embeds = lifelong_model(mfcc_features, training=True)
        embeds = tf.nn.l2_normalize(embeds, axis=1)

        # get labels from clustering stuffs
        predicted_labels = []
        for i in range(len(embeds) - len(old_mfccs)):
            label, centroids, covariance_matrices, n_embds = classify_new_embeds(
                tf.expand_dims(embeds[i],0).numpy(), centroids, covariance_matrices, n_embds,
                min_threshold = 1.2, max_threshold = 1.5)
            predicted_labels.append(label)

        if len(old_labels)> 1: _labels = tf.concat([predicted_labels, old_labels], axis=0)
        else: _labels = predicted_labels

        # compute loss with ge2e
        embeds= tf.expand_dims(embeds, 0)
        _labels = tf.expand_dims(_labels, 0)
        loss0 = ge2e(embeds, _labels, [tf.squeeze(centroids)], w, b)

        # assuming that the model weights staw around the older weights
        temp = []
        for cw, ow in zip(lifelong_model.trainable_variables, original_w):
            temp.append(tf.reduce_sum(tf.square(cw-ow)))
        loss1 = tf.reduce_sum(temp)
        loss = loss0 + 1000*loss1

    # apply gradients
    gradients = tape.gradient(loss, [lifelong_model.trainable_variables, w, b])
    gradients1 = [(tf.clip_by_value(grad, -10.0, 10.0)) for grad in gradients[0]]
    gradients2 = [(tf.clip_by_value(grad, -10.0, 10.0)) for grad in gradients[1:]]
    model_opt.apply_gradients(zip(gradients1, lifelong_model.trainable_variables))
    param_opt.apply_gradients(zip(gradients2, [w,b]))

    # update old stuffs
    unique_classes = np.unique(predicted_labels)
    embeds = tf.squeeze(embeds)[0:original_length]
    for k in unique_classes:
        class_mask = tf.equal(predicted_labels, k)
        class_embeds = tf.boolean_mask(embeds, class_mask)
        class_mfccs = tf.boolean_mask(mfcc_features[0:original_length], class_mask)
        indices = identify_nearest_embeddings(class_embeds, 3)
        for i in indices:
            old_mfccs.append(class_mfccs[int(i)])
            old_labels.append(k)
            if len(old_mfccs) > max_mem: 
                old_mfccs.pop(0)
                old_labels.pop(0)
    
    # save checkpoint
    if bnum%20 == 0:
        lifelong_model.save("goddamn.h5")

    # we may use the approximate ground truth to get DER
    hyp = []
    for i, label in enumerate(predicted_labels):
        if rounded_groundtruth[i] == 58: pass # dont evaluate on overlap
        hyp.append((str(label), float(i)*2, float(i+1)*2))

    ref = []
    for i, gt in enumerate(rounded_groundtruth):
        if rounded_groundtruth[i] == 58: pass # dont evaluate on overlap
        ref.append((str(gt), float(i)*2, float(i+1)*2))

    der = simpleder.DER(ref, hyp)
    der_record.append(der)

    print("loss: ", loss0.numpy(), "loss1: ", loss1.numpy(), "der: ", der) 

(300, 201, 26) 0
loss:  2.6698008 loss1:  0.0 der:  0.38717948717948714
(306, 201, 26) 6
loss:  0.7874467 loss1:  0.10133819 der:  0.4230769230769231
(312, 201, 26) 12
loss:  0.7209668 loss1:  1.7750406e-05 der:  0.41025641025641024
(318, 201, 26) 18
loss:  1.4011937 loss1:  4.0461746e-05 der:  0.4512820512820513
(324, 201, 26) 24
loss:  1.2784091 loss1:  0.000109971355 der:  0.46923076923076923
(330, 201, 26) 30
loss:  176.25977 loss1:  0.00015928186 der:  0.38974358974358975
(336, 201, 26) 36
loss:  1.691704 loss1:  0.012703052 der:  0.43589743589743585
(342, 201, 26) 42
loss:  1.1151886 loss1:  0.002057066 der:  0.34102564102564104
(348, 201, 26) 48
loss:  0.86249125 loss1:  0.00072745694 der:  0.4923076923076923
(354, 201, 26) 54
loss:  180.58188 loss1:  0.0003042342 der:  0.4230769230769231
(360, 201, 26) 60
loss:  142.72243 loss1:  0.010520209 der:  0.34102564102564104
(366, 201, 26) 66
loss:  148.28363 loss1:  0.015896512 der:  0.34615384615384615
(372, 201, 26) 72
loss:  128.19

In [11]:
lifelong_model.save("model.h5")