In [None]:
# Local imports
from main.preprocess import *
from main.siamese import *
from main.constants import *

import os
import random
import numpy as np
import wave
import pyaudio
import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline

# Ask TensorFlow/cuDNN for deterministic kernels and seed every RNG used in
# this notebook (NumPy, TensorFlow, stdlib `random`) so runs are comparable.
os.environ['TF_CUDNN_DETERMINISM'] = '1'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
np.random.seed(1)
tf.random.set_seed(1)
random.seed(1)

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front (prevents failures during training on GPU). Iterating over the device
# list, rather than indexing gpus[0], keeps this cell working on CPU-only
# machines where the list is empty, and applies the same setting to every GPU.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
%%time
# Build the dataset index. Later cells read the 'Target' / 'Full_path' columns
# of `df_all` and `df` and use `labels` to tell known speakers from imposters.
df_all, df, labels = get_dataset_frame(audio_direct, ext, num_speakers)

In [None]:
# Imposters are recordings whose 'Target' is not one of `labels`; draw a
# fixed-size, reproducible (random_state=0) sample of them.
unverif_samples = 20
impost_df = df_all.loc[~df_all['Target'].isin(labels)]
imposts = impost_df.sample(n=unverif_samples, random_state=0)

In [None]:
# For every label take `verif_samples` rows, skipping the first
# 2 * triplet_len rows of that label (presumably the rows already consumed
# for training triplets — TODO confirm against train_siamese).
# NOTE(review): `pd` is not imported explicitly in the visible import cell;
# it presumably arrives through one of the star imports — confirm.
verif_samples = 3
bound = 2 * triplet_len

verif_df = pd.concat([
    df.loc[df['Target'] == label].iloc[bound:bound + verif_samples]
    for label in labels
])

verif_paths = verif_df['Full_path'].to_numpy()
verif_targets = verif_df['Target'].to_numpy()

In [None]:
%%time
# `threshold` is reused below as the fixed number of samples every signal is
# cut or zero-padded to; the third return value is not needed here.
changed_signals, threshold, _ = truncate_or_pad(
    audio_direct,
    ext,
    num_speakers,
    triplet_len,
    alpha,
    rate,
)

In [None]:
def preprocess_samples(path_to_audio, database,
                       model, dist_threshold):
    """Run one audio file through the feature pipeline and verify its speaker.

    The file is loaded with `load_emphas`, stripped with `clear_from_silence`,
    forced to exactly `threshold` samples, converted to a segmented
    spectrogram, passed through `SpectrogramConvolution`, and finally handed
    to `identity_verification`.

    Relies on notebook-level names: `alpha`, `rate`, `threshold`, `nfft`,
    `hop_len`, `win_len`, `hamming`, `num_segments`, `num_features`.

    Args:
        path_to_audio: path of the audio file to check.
        database: forwarded unchanged to `identity_verification`.
        model: forwarded unchanged to `identity_verification`.
        dist_threshold: forwarded unchanged to `identity_verification`.

    Returns:
        A list `[min_dist, is_verify, identity]` holding the three values
        returned by `identity_verification`.
    """
    signal = clear_from_silence(load_emphas(path_to_audio, alpha, rate))

    # Fix the length at `threshold` samples: drop the tail of long signals,
    # zero-pad short ones at the end.
    if len(signal) > threshold:
        fixed_signal = signal[:threshold]
    else:
        fixed_signal = np.zeros((threshold,))
        fixed_signal[:len(signal)] = signal

    spectrogram = get_spectrogram(fixed_signal, nfft, hop_len,
                                  win_len, hamming)
    segments = segment_spectrogram(spectrogram, num_segments, num_features)

    # The convolution helper is configured with the segment shape minus its
    # first axis.
    convolver = SpectrogramConvolution(segments.shape[1:])
    convolved = convolver.convolute(segments)

    min_dist, is_verify, identity = identity_verification(
        convolved, database, model, dist_threshold
    )
    return [min_dist, is_verify, identity]

In [None]:
%%time
# Record one utterance from the microphone into `frames` (raw byte chunks).
record_format = pyaudio.paInt16
num_channels = 1
chunk_size = 1024
# Record for as long as a `threshold`-sample signal lasts at `rate` Hz.
audio_dur = threshold / rate
out_filename = 'test.wav'

# Whole chunks only: the capture can be up to chunk_size - 1 samples shorter
# than `threshold`; preprocess_samples zero-pads the remainder.
num_chunks = int(rate / chunk_size * audio_dur)
frames = []

audio = pyaudio.PyAudio()
try:
    stream = audio.open(format=record_format, input=True,
                        rate=rate, channels=num_channels,
                        frames_per_buffer=chunk_size)
    try:
        # Record voice
        for _ in range(num_chunks):
            frames.append(stream.read(chunk_size))
    finally:
        # Always release the input stream, even if a read fails part-way.
        stream.stop_stream()
        stream.close()
finally:
    # Always release the PyAudio handle, even if the stream could not be
    # opened at all.
    audio.terminate()

In [None]:
# Persist the recorded chunks as a WAV file. The context manager closes the
# file even if one of the header/data writes raises.
with wave.open(out_filename, 'wb') as waveFile:
    waveFile.setnchannels(num_channels)
    waveFile.setsampwidth(audio.get_sample_size(record_format))
    # Use the same `rate` the input stream was opened with, so the header
    # matches the captured data (the previous `samples_rate` is not defined
    # anywhere in this notebook).
    waveFile.setframerate(rate)
    waveFile.writeframes(b''.join(frames))

In [None]:
# Unverified set = the sampled imposter clips plus the fresh microphone
# recording, which is labelled 'Me'.
unverif_targets = np.asarray(imposts['Target'].to_list() + ['Me'])
unverif_paths = np.asarray(imposts['Full_path'].to_list() + ['test.wav'])

In [None]:
%%time
# Train the siamese model. NOTE: this rebinds `labels`, which an earlier cell
# already assigned from get_dataset_frame; cells re-run after this point see
# the value returned by train_siamese.
(history, siam, convs, labels) = train_siamese()

In [None]:
# Build the reference database that identity_verification compares against,
# from the training outputs (`convs`, `labels`) and the trained model.
database = create_database(
    convs, labels, siam
)

In [None]:
# Distance cut-off forwarded to identity_verification (presumably the largest
# distance still accepted as a match — TODO confirm in main.siamese).
dist_threshold = 0.005

# Keep the rows as a plain list of [distance, is_verified, identity] lists.
# Passing mixed-type rows through np.asarray forces one common dtype on the
# whole array, so the boolean flag may not survive as a boolean; that would
# break the `Is_verified` mask used later. pd.DataFrame accepts a list of
# rows directly and infers a dtype per column.
unverif_results = [
    preprocess_samples(path, database, siam, dist_threshold)
    for path in unverif_paths
]

In [None]:
# Tabulate the per-file verification output and attach the true label of
# each file as 'Target'.
unverif_results = pd.DataFrame(
    unverif_results,
    columns=['Distance', 'Is_verified', 'Identity'],
).assign(Target=unverif_targets)

In [None]:
# Rows of the unverified set that were nevertheless marked as verified count
# as errors; accuracy is the share of rows that were not.
incorrect = unverif_results.loc[unverif_results['Is_verified']]
accuracy = 1 - incorrect.shape[0] / unverif_results.shape[0]
accuracy

In [None]:
# Keep the rows as a plain list of [distance, is_verified, identity] lists.
# Passing mixed-type rows through np.asarray forces one common dtype on the
# whole array, which can change the identity values and break the later
# `Target == Identity` comparison. pd.DataFrame accepts a list of rows
# directly and infers a dtype per column.
verif_results = [
    preprocess_samples(path, database, siam, dist_threshold)
    for path in verif_paths
]

In [None]:
# Tabulate the per-file verification output and attach the true label of
# each file as 'Target'.
verif_results = pd.DataFrame(
    verif_results,
    columns=['Distance', 'Is_verified', 'Identity'],
).assign(Target=verif_targets)

In [None]:
# A known-speaker row is correct when the predicted identity equals its true
# label; accuracy is the share of such rows.
is_match = verif_results['Target'] == verif_results['Identity']
correct = verif_results.loc[is_match]
accuracy = correct.shape[0] / verif_results.shape[0]
accuracy

In [None]:
# Create a 10x8 inch figure (the keyword is `figsize`; the previous `fisize`
# spelling is not a valid Figure argument and raises at call time).
plt.figure(figsize=(10, 8))