In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# Running this cell (click Run or press Shift+Enter) lists the files of the first directory under the input directory that contains any files; the loop stops there to keep the output short
import os

# Print the files of the first directory (in os.walk order) that has any,
# then stop walking; `ctr` counts how many files have been printed so far.
ctr = 0
for dirname, _, filenames in os.walk('/kaggle/input'):
    if ctr:
        break
    for filename in filenames:
        print(os.path.join(dirname, filename))
    ctr += len(filenames)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Install dependencies

In [None]:
!pip install speechbrain matplotlib pandas seaborn pytorch_lightning torchmetrics resampy python_speech_features scikit-learn tensorboard

# Main Code

In [None]:
import os

# Datasets imports
import glob
import random
import resampy
from scipy.io import wavfile
from scipy.signal import fftconvolve
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset
from python_speech_features import mfcc

# Plda_classifier imports
import pickle
from speechbrain.processing.PLDA_LDA import *

# Plda_score_stat imports
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
from sklearn.manifold import TSNE
from speechbrain.utils.metric_stats import EER, minDCF

# Main imports
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.tensorboard
import torchmetrics
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader

# global variables
# Module-level settings read by the code below. A `global` statement at module
# level has no effect, so plain assignments are all that is needed.
bse_dir = "/kaggle/working/"  # base output directory; keep the trailing slash, other code appends to it
log_dir = os.path.join(bse_dir, "testlogs")  # TensorBoard logger / checkpoint root
num_worker = 4  # DataLoader worker processes
data_final_path = "/kaggle/input/audiodataset10percent/"  # dataset root folder
n_epochs_total = 2
n_batch_size = 512

# Tdnn_layer
class TdnnLayer(nn.Module):
    """Time-delay (TDNN) layer: frame splicing -> Linear -> ReLU (-> Dropout) (-> BatchNorm).

    The frames at the relative offsets in ``context`` are concatenated along the
    feature axis (via ``get_time_context``) and projected by one ``nn.Linear``,
    which is equivalent to a 1-D convolution whose taps are the context offsets.
    Input and output are laid out as (batch, time, features).
    """
    def __init__(self, input_size=24, output_size=512, context=None, batch_norm=True, dropout_p=0.0):
        """
        Args:
            input_size: feature dimension of each input frame.
            output_size: feature dimension of each output frame.
            context: relative frame offsets to splice, e.g. ``[-2, 0, 2]``.
                ``None`` (default) means ``[0]``, i.e. no splicing.
            batch_norm: apply ``BatchNorm1d`` over the feature axis after the activation.
            dropout_p: dropout probability; ``0`` disables dropout.
        """
        super().__init__()
        # Build the default per instance instead of sharing one mutable list
        # between every layer created without an explicit context.
        if context is None:
            context = [0]
        self.input_size = input_size
        self.output_size = output_size
        self.context = context
        self.batch_norm = batch_norm
        self.dropout_p = dropout_p
        self.linear = nn.Linear(input_size*len(context), output_size)
        self.relu = nn.ReLU()
        if self.batch_norm:
            self.norm = nn.BatchNorm1d(output_size)
        if self.dropout_p:
            self.drop = nn.Dropout(p=self.dropout_p)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, time, input_size).

        Returns a tensor of shape (batch, time', output_size), where time' is
        shortened by the span of the context (max offset - min offset).
        """
        x_context = get_time_context(x, self.context)
        x = torch.cat(x_context, 2)
        x = self.linear(x)
        x = self.relu(x)
        if self.dropout_p:
            x = self.drop(x)
        if self.batch_norm:
            # BatchNorm1d normalises dim 1, so move features there and back.
            x = x.transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2)
        return x
def get_time_context(x, c=(0,)):
    """Slice ``x`` once per context offset so the slices line up frame by frame.

    Args:
        x: array/tensor indexed as ``x[:, time, :]`` (TdnnLayer passes
            (batch, time, features)).
        c: relative frame offsets, e.g. ``[-2, 0, 2]``. Offsets do not need to
            be symmetric around zero.

    Returns:
        List with one slice per offset in ``c`` (same order). Every slice has
        ``max(0, T - (max(c) - min(c)))`` frames, and frame ``t`` of the slice
        for offset ``cc`` is input frame ``t - min(c) + cc``. If the input is
        shorter than the context span, all slices are empty (never mismatched).
    """
    lo, hi = min(c), max(c)
    out_len = max(x.shape[1] - (hi - lo), 0)
    return [x[:, cc - lo:cc - lo + out_len, :] for cc in c]
# Plda_classifier
def get_train_x_vec(train_xv, train_label, x_id_train):
    """Wrap training x-vectors in a speechbrain ``StatObject_SB`` for PLDA training.

    Args:
        train_xv: x-vectors, one row per utterance (used as first-order stats).
        train_label: speaker label per row; the model id becomes ``'id' + str(label)``.
        x_id_train: utterance id per row, used as the segment id.

    Returns:
        ``StatObject_SB`` whose zero-order stats are all 1.0 and whose start/stop are ``None``.
    """
    n_utts = train_xv.shape[0]
    print('N train utt:', n_utts)
    model_ids = np.array(['id' + str(train_label[k]) for k in range(n_utts)], dtype="|O")
    segment_ids = np.array([str(x_id_train[k]) for k in range(n_utts)], dtype="|O")
    no_bounds = np.array([None] * n_utts)  # the same array is passed as both start and stop
    zero_order = np.array([[1.0]] * n_utts)
    return StatObject_SB(modelset=model_ids, segset=segment_ids,
                         start=no_bounds, stop=no_bounds,
                         stat0=zero_order, stat1=train_xv)
def setup_plda(mean=None, F=None, Sigma=None, rank_f=150, nb_iter=10, scaling_factor=1):
    """Create an (untrained) speechbrain ``PLDA`` object with the given hyper-parameters.

    All arguments are forwarded unchanged to ``PLDA``.
    """
    settings = dict(mean=mean, F=F, Sigma=Sigma, rank_f=rank_f,
                    nb_iter=nb_iter, scaling_factor=scaling_factor)
    return PLDA(**settings)
def train_plda(plda, xvectors_stat):
    """Fit ``plda`` on ``xvectors_stat`` (via its ``plda`` method) and return the same object."""
    fit = plda.plda
    fit(xvectors_stat)
    return plda
def get_x_vec_stat(xv, id):
    """Wrap x-vectors in a ``StatObject_SB`` where each utterance id is both model and segment id.

    Args:
        xv: x-vectors, one row per utterance (used as first-order stats).
        id: utterance id per row (converted with ``str``).
    """
    n_rows = xv.shape[0]
    names = np.array([str(id[k]) for k in range(n_rows)], dtype="|O")
    bounds = np.array([None] * n_rows)  # shared start/stop placeholder
    ones = np.array([[1.0]] * n_rows)
    return StatObject_SB(modelset=names, segset=names, start=bounds, stop=bounds, stat0=ones, stat1=xv)
def plda_scores(plda, en_stat, te_stat):
    """Score every enrolment model against every test segment with ``fast_PLDA_scoring``.

    The trial list is an ``Ndx`` built from ``en_stat.modelset`` x ``te_stat.modelset``;
    the trained ``plda`` supplies ``mean``, ``F`` and ``Sigma``.
    """
    trials = Ndx(models=en_stat.modelset, testsegs=te_stat.modelset)
    return fast_PLDA_scoring(en_stat, te_stat, trials,
                             plda.mean, plda.F, plda.Sigma, p_known=0.0)
def save_plda(plda, file_name, directory=None):
    """Pickle ``plda`` to ``<directory>/<file_name>.pickle`` (best effort).

    Args:
        plda: object to serialise (normally a trained speechbrain ``PLDA``).
        file_name: file name without the ``.pickle`` extension.
        directory: target folder; defaults to ``bse_dir + 'plda/'``. It is
            created when missing.

    Errors are printed rather than raised so a failed save does not abort the run.
    """
    if directory is None:
        directory = bse_dir + 'plda/'
    try:
        # Create the folder so that a fresh working directory does not make the
        # save fail with FileNotFoundError.
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, file_name + '.pickle'), 'wb') as f:
            pickle.dump(plda, f, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as ex:
        print('Error during pickling plda: ', ex)
def load_plda(file_path_name):
    """Load a pickled PLDA model from ``file_path_name``.

    Returns:
        The unpickled object, or ``None`` (after printing the error) when the
        file cannot be opened or unpickled.

    Security: ``pickle.load`` can execute arbitrary code; only load files this
    pipeline wrote itself.
    """
    try:
        with open(file_path_name, 'rb') as f:
            return pickle.load(f)
    except Exception as ex:
        print('Error during unpickling plda: ', ex)
        return None
def lda(x_vec_stat, reduced_dim=2):
    """Project ``x_vec_stat`` to ``reduced_dim`` dimensions with speechbrain's ``LDA``.

    Returns the new stat object produced by ``LDA.do_lda``.
    """
    projector = LDA()
    return projector.do_lda(x_vec_stat, reduced_dim=reduced_dim)
# Plda_score_stat
class plda_score_stat_object():
    """Score test x-vectors with a trained PLDA model and evaluate them on a trial list.

    Workflow: construct with the test x-vectors, call ``test_plda`` with a trained
    PLDA model and a verification trial file, then ``calc_eer_mindcf`` and
    optionally ``plot_images`` to write diagnostics to a TensorBoard-style writer.
    """

    def __init__(self, x_vectors_test):
        """Parse the test x-vectors and build the enrolment/test statistics.

        Args:
            x_vectors_test: pandas DataFrame. Column 1 is read as the utterance id
                and column 3 as the x-vector, stored as a bracketed,
                whitespace-separated string (e.g. ``"[0.1 0.2]"``). ``test_plda``
                also looks rows up by the column names ``'id'`` and ``'xvector'``;
                presumably these are the same columns (TODO confirm against the CSV).
        """
        self.x_vectors_test = x_vectors_test
        self.x_id_test = np.array(self.x_vectors_test.iloc[:, 1])
        self.x_vec_test = np.array([self._parse_xvec(x_vec) for x_vec in self.x_vectors_test.iloc[:, 3]])
        # Every test utterance is used both as an enrolment model and as a test segment.
        self.en_stat = get_x_vec_stat(self.x_vec_test, self.x_id_test)
        self.te_stat = get_x_vec_stat(self.x_vec_test, self.x_id_test)
        self.plda_scores = 0
        self.positive_scores = []
        self.negative_scores = []
        self.positive_scores_mask = []
        self.negative_scores_mask = []
        self.eer = 0
        self.eer_th = 0
        self.min_dcf = 0
        self.min_dcf_th = 0
        self.checked_xvec = []
        self.checked_label = []

    @staticmethod
    def _parse_xvec(text):
        """Convert a bracketed, whitespace-separated string such as ``"[1.0 2.0]"`` to a float64 array."""
        return np.array(text[1:-1].split(), dtype=np.float64)

    @staticmethod
    def _speaker_number(utt_id):
        """Numeric speaker id of an utterance id, e.g. ``"id10270/abc/00001.wav"`` -> ``10270``."""
        return int(utt_id.split(".")[0].split("/")[0][2:])

    @staticmethod
    def _first_index(ids):
        """Map each id to the position of its first occurrence in ``ids``."""
        index = {}
        for position, key in enumerate(ids):
            index.setdefault(key, position)
        return index

    def _xvec_for(self, utt_id):
        """Look up and parse the x-vector of ``utt_id``; exactly one matching row is required."""
        row = self.x_vectors_test.loc[self.x_vectors_test['id'] == utt_id, 'xvector']
        return self._parse_xvec(row.item())

    def test_plda(self, plda, veri_test_file_path):
        """Score all test x-vectors with ``plda`` and collect the scores of the listed trials.

        Args:
            plda: trained PLDA model (its ``mean``, ``F`` and ``Sigma`` are used).
            veri_test_file_path: text file with one space-separated trial per line,
                ``"<label> <enrol_id> <test_id>"``; a non-zero integer label marks
                a same-speaker trial.

        Fills ``positive_scores``/``negative_scores``, the matching 0/1 masks over
        the score matrix, and ``checked_xvec``/``checked_label`` (one entry per
        distinct utterance that appears in a trial). Each call starts from a clean
        state rather than appending to a previous call's results.

        Raises:
            KeyError: if a trial references an id missing from the score matrix.
        """
        self.plda_scores = plda_scores(plda, self.en_stat, self.te_stat)
        scoremat = self.plda_scores.scoremat
        self.positive_scores_mask = np.zeros_like(scoremat)
        self.negative_scores_mask = np.zeros_like(scoremat)
        self.positive_scores = []
        self.negative_scores = []
        checked_xvec, checked_label = [], []
        # Dict/set lookups replace per-trial linear scans (np.where over every id,
        # `in` on a growing list), which are quadratic on large trial lists.
        row_of = self._first_index(self.plda_scores.modelset)
        col_of = self._first_index(self.plda_scores.segset)
        checked = set()

        def remember(utt_id):
            # Record each utterance once, whether it first shows up as enrol or test id.
            if utt_id not in checked:
                checked.add(utt_id)
                checked_xvec.append(self._xvec_for(utt_id))
                checked_label.append(self._speaker_number(utt_id))

        with open(veri_test_file_path) as trials:
            for pair in trials:
                fields = pair.split(" ")
                is_match = bool(int(fields[0].rstrip().split(".")[0].strip()))
                enrol_id = fields[1].strip()
                test_id = fields[2].strip()
                i = row_of[enrol_id]
                remember(enrol_id)
                j = col_of[test_id]
                remember(test_id)
                current_score = float(scoremat[i, j])
                if is_match:
                    self.positive_scores.append(current_score)
                    self.positive_scores_mask[i, j] = 1
                else:
                    self.negative_scores.append(current_score)
                    self.negative_scores_mask[i, j] = 1
        self.checked_xvec = np.array(checked_xvec)
        self.checked_label = np.array(checked_label)

    def calc_eer_mindcf(self):
        """Compute EER and minDCF (p_target=0.5) from the collected trial scores.

        Stores the value and decision threshold of each in ``eer``/``eer_th``
        and ``min_dcf``/``min_dcf_th``.
        """
        self.eer, self.eer_th = EER(torch.tensor(self.positive_scores), torch.tensor(self.negative_scores))
        self.min_dcf, self.min_dcf_th = minDCF(torch.tensor(self.positive_scores), torch.tensor(self.negative_scores), p_target=0.5)

    def _split_for_plots(self):
        """Split the checked x-vectors into up to four speaker groups for the scatter plots.

        Returns two lists (x-vectors, labels) with one entry per group.
        """
        from sklearn.model_selection import GroupKFold
        split_xvec, split_label = [], []
        group_kfold = GroupKFold(n_splits=2)
        # NOTE(review): the thresholds look tailored to the VoxCeleb1 test speakers
        # (id10270-id10309, four blocks of ~10); GroupKFold raises if one side of a
        # split ends up with a single group - TODO confirm for other data.
        groups1234 = np.where(self.checked_label < 10290, 0, 1)
        for g12, g34 in group_kfold.split(self.checked_xvec, self.checked_label, groups1234):
            x12, x34 = self.checked_xvec[g12], self.checked_xvec[g34]
            y12, y34 = self.checked_label[g12], self.checked_label[g34]
            groups12 = np.where(y12 < 10280, 0, 1)
            groups34 = np.where(y34 < 10300, 0, 1)
            for g1, g2 in group_kfold.split(x12, y12, groups12):
                split_xvec += [x12[g1], x12[g2]]
                split_label += [y12[g1], y12[g2]]
                break
            for g3, g4 in group_kfold.split(x34, y34, groups34):
                split_xvec += [x34[g3], x34[g4]]
                split_label += [y34[g3], y34[g4]]
                break
            break  # only the first outer fold is used
        # Kept as lists: the groups have different sizes, and np.array() on such a
        # ragged list raises ValueError on NumPy >= 1.24.
        return split_xvec, split_label

    @staticmethod
    def _side_by_side(height, width, left_green, left_red, right_green, right_red):
        """Build a (3, height, 2*width + 5) RGB image: two panels separated by a 5-px white gap.

        The blue channel is zeroed inside both panels, so only the gap stays white.
        """
        img = np.ones((3, height, width * 2 + 5))
        img[1, :, :width] = left_green
        img[0, :, :width] = left_red
        img[1, :, -width:] = right_green
        img[0, :, -width:] = right_red
        img[2, :, :width] = 0
        img[2, :, -width:] = 0
        return img

    @staticmethod
    def _scatter_plot(x, y, label, plot_name):
        """Draw a 2-D scatter plot coloured by ``label`` and return its figure."""
        df = pd.DataFrame({'x': x, 'y': y, 'label': label})
        fig, ax = plt.subplots(1)
        fig.set_size_inches(16, 12)
        sns.scatterplot(x='x', y='y', hue='label', palette='bright', data=df, ax=ax, s=80)  # use sns.color_palette("hls", 40) for 40 speakers
        ax.set_xlim((x.min()-5, x.max()+5))
        ax.set_ylim((y.min()-5, y.max()+5))
        ax.set_aspect(1.0/ax.get_data_ratio(), adjustable='box')
        ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        ax.title.set_text(plot_name)
        return fig

    def plot_images(self, writer):
        """Write score-matrix images and 2-D embedding scatter plots to ``writer``.

        Requires ``test_plda`` (score matrix, masks, checked x-vectors) and
        ``calc_eer_mindcf`` (decision thresholds) to have run. Images are CHW RGB:
        green marks same-speaker trials / accept decisions, red marks
        different-speaker trials / reject decisions.

        Args:
            writer: object providing ``add_image(tag, img, step)`` and
                ``add_figure(tag, fig)`` (e.g. a TensorBoard ``SummaryWriter``).
        """
        from sklearn.decomposition import PCA
        from sklearn.preprocessing import StandardScaler
        split_xvec, split_label = self._split_for_plots()
        print('generating images for tensorboard')
        scoremat_norm = np.array(self.plda_scores.scoremat)
        scoremat_norm -= np.min(scoremat_norm)
        peak = np.max(scoremat_norm)
        if peak > 0:  # a constant score matrix would otherwise become all NaN
            scoremat_norm /= peak
        height, width = scoremat_norm.shape

        def rgb(red=0.0, green=0.0, blue=0.0):
            img = np.zeros((3, height, width))
            img[0], img[1], img[2] = red, green, blue
            return img

        pos_mask = self.positive_scores_mask
        neg_mask = self.negative_scores_mask
        print('score_matrix')
        writer.add_image('score_matrix', rgb(scoremat_norm, scoremat_norm, scoremat_norm), 0)
        print('ground_truth')
        writer.add_image('ground_truth', rgb(red=neg_mask, green=pos_mask), 0)
        print('ground_truth_scores')
        writer.add_image('ground_truth_scores', rgb(red=scoremat_norm*neg_mask, green=scoremat_norm*pos_mask), 0)
        # Threshold only the scored trials; everything else stays 0 in both decision maps.
        checked_values_map = pos_mask + neg_mask
        checked_values = checked_values_map * self.plda_scores.scoremat
        eer_pos = np.where(checked_values >= self.eer_th, 1, 0) * checked_values_map
        eer_neg = np.where(checked_values < self.eer_th, 1, 0) * checked_values_map
        dcf_pos = np.where(checked_values >= self.min_dcf_th, 1, 0) * checked_values_map
        dcf_neg = np.where(checked_values < self.min_dcf_th, 1, 0) * checked_values_map
        # Left panel: EER threshold, right panel: minDCF threshold.
        print('prediction_eer_min_dcf')
        writer.add_image('prediction_eer_min_dcf',
                         self._side_by_side(height, width, eer_pos, eer_neg, dcf_pos, dcf_neg), 0)
        print('correct_prediction_eer_min_dcf')
        writer.add_image('correct_prediction_eer_min_dcf',
                         self._side_by_side(height, width, eer_pos * pos_mask, eer_neg * neg_mask,
                                            dcf_pos * pos_mask, dcf_neg * neg_mask), 0)
        print('false_prediction_eer_min_dcf')
        writer.add_image('false_prediction_eer_min_dcf',
                         self._side_by_side(height, width, eer_pos * neg_mask, eer_neg * pos_mask,
                                            dcf_pos * neg_mask, dcf_neg * pos_mask), 0)
        for n, (xvecs, labels) in enumerate(zip(split_xvec, split_label), start=1):
            print('scatter_plot_LDA'+str(n))
            lda_stat = lda(get_x_vec_stat(xvecs, labels))
            fig = self._scatter_plot(lda_stat.stat1[:, 0], lda_stat.stat1[:, 1], labels, 'scatter_plot_LDA'+str(n))
            writer.add_figure('scatter_plot_LDA'+str(n), fig)
            print('scatter_plot_PCA'+str(n))
            pca_result = PCA(n_components=2).fit_transform(StandardScaler().fit_transform(xvecs))
            fig = self._scatter_plot(pca_result[:, 0], pca_result[:, 1], labels, 'scatter_plot_PCA'+str(n))
            writer.add_figure('scatter_plot_PCA'+str(n), fig)
            print('scatter_plot_TSNE'+str(n))
            tsne_result = TSNE(2).fit_transform(xvecs)
            fig = self._scatter_plot(tsne_result[:, 0], tsne_result[:, 1], labels, 'scatter_plot_TSNE'+str(n))
            writer.add_figure('scatter_plot_TSNE'+str(n), fig)
# Dataset
EPS = 1e-20  # added to denominators so a silent (zero-RMS) signal does not divide by zero


class Dataset(Dataset):
    """VoxCeleb1 dataset yielding MFCCs of random 3-second crops with optional augmentation.

    Items are ``(mfcc_tensor, label_index, utterance_id)``. ``init_samples_and_labels``
    scans the data folder once, and ``load_data`` selects which split(s)
    ``__getitem__``/``__len__`` expose. The augmentation type of each sample
    (MUSAN music/speech/noise, simulated RIRs, or none) is fixed when the sample
    list is built.

    Note: the class name shadows ``torch.utils.data.Dataset``, which is its base class.
    """
    def __init__(self,
                sampling_rate=16000,
                mfcc_numcep=24,
                mfcc_nfilt=26,
                mfcc_nfft=512,
                data_folder_path='data',
                augmentations_per_sample=2):
        """
        Args:
            sampling_rate: rate (Hz) every clip is resampled to.
            mfcc_numcep, mfcc_nfilt, mfcc_nfft: forwarded to ``python_speech_features.mfcc``.
            data_folder_path: root folder containing ``VoxCeleb/``, ``musan/`` and ``RIRS_NOISES/``.
            augmentations_per_sample: number of extra, randomly augmented copies of
                every training/validation file.
        """
        self.samples = []
        self.labels = []
        self.n_samples = 0
        self.unique_labels = []
        self.train_samples = []
        self.train_labels = []
        self.val_samples = []
        self.val_labels = []
        self.test_samples = []
        self.test_labels = []
        self.data_folder_path = data_folder_path
        self.sampling_rate = sampling_rate
        self.augmentations_per_sample = augmentations_per_sample
        self.mfcc_numcep = mfcc_numcep
        self.mfcc_nfilt = mfcc_nfilt
        self.mfcc_nfft = mfcc_nfft
        # pattern -> file list, so the augmentation corpora are scanned once, not per sample.
        self._glob_cache = {}

    def init_samples_and_labels(self):
        """Scan the data folder and build the train/validation/test sample lists.

        Train/validation samples come from ``VoxCeleb/vox1_dev_wav``. Every file
        appears once without augmentation plus ``augmentations_per_sample`` times
        with a random augmentation, and the result is split 90/10 with a
        per-speaker stratified split. Test samples come from
        ``VoxCeleb/vox1_test_wav`` without augmentation. A file's speaker label is
        its grand-parent directory name.

        NOTE(review): augmented copies of the same recording can land in both the
        train and validation splits, which makes validation optimistic.
        """
        vox_train_path = self.data_folder_path + '/VoxCeleb/vox1_dev_wav/*/*/*.wav'
        vox_test_path = self.data_folder_path + '/VoxCeleb/vox1_test_wav/*/*/*.wav'
        globs = glob.glob(vox_train_path)
        print('collectiong training and validation samples')
        samples = [(sample, 'none') for sample in globs]
        labels = [os.path.basename(os.path.dirname(os.path.dirname(f))) for f in globs]
        for i in range(self.augmentations_per_sample):
            samples = samples + [(sample, random.choice(['music', 'speech', 'noise', 'rir'])) for sample in globs]
            labels = labels + [os.path.basename(os.path.dirname(os.path.dirname(f))) for f in globs]
        unique_labels = np.unique(labels)
        print('found:')
        print(len(unique_labels), ' unique speakers')
        print(int(len(samples)/(self.augmentations_per_sample+1)), ' voice samples')
        print(len(samples), ' total voice samples including augmentations')
        print('splitting into 90% training and 10% validation')
        skf = StratifiedKFold(n_splits=10, shuffle=True)
        train_index, val_index = [], []
        # Keep (the last) fold whose validation part has round(N/10) samples.
        for traini, vali in skf.split(samples, labels):
            if len(vali) == int(round(len(samples)/10)):
                train_index = traini
                val_index = vali
        if len(train_index) <= 1:
            print('StratifiedKFold Failed')
        self.train_samples = list(np.array(samples)[train_index])
        self.train_labels = list(np.array(labels)[train_index])
        self.val_samples = list(np.array(samples)[val_index])
        self.val_labels = list(np.array(labels)[val_index])
        globs = glob.glob(vox_test_path)
        print('collectiong test samples')
        test_samples = [(sample, 'none') for sample in globs]
        test_labels = [os.path.basename(os.path.dirname(os.path.dirname(f))) for f in globs]
        unique_labels = np.unique(test_labels)
        print('found:')
        print(len(unique_labels), ' unique speakers')
        print(len(test_samples), ' voice samples')
        print('DONE collectiong samples')
        self.test_samples = list(np.array(test_samples))
        self.test_labels = list(np.array(test_labels))

    def __getitem__(self, index):
        """Return ``(mfcc_tensor, label_index, utterance_id)`` for sample ``index``.

        The label index is the position of the speaker in ``unique_labels`` of the
        currently loaded split(s); the id is the last three path components.
        """
        sample_path, augmentation = self.samples[index]
        sample = self._load_resampled(sample_path)
        augmented_sample = self.augment_data(sample, augmentation)
        augmented_sample = mfcc(augmented_sample, self.sampling_rate, numcep=self.mfcc_numcep, nfilt=self.mfcc_nfilt, nfft=self.mfcc_nfft)
        label = self.unique_labels.index(self.labels[index])
        id = '/'.join(sample_path.rsplit('/')[-3:])
        return torch.from_numpy(augmented_sample), label, id

    def __len__(self):
        return self.n_samples

    def load_data(self, train=False, val=False, test=False):
        """Expose the concatenation of the selected splits via ``__getitem__``/``__len__``.

        Rebinds ``samples``/``labels``/``unique_labels`` (new lists) on this object.
        """
        self.samples = []
        self.labels = []
        self.n_samples = 0
        self.unique_labels = []
        if train:
            self.samples = self.samples + self.train_samples
            self.labels = self.labels + self.train_labels
        if val:
            self.samples = self.samples + self.val_samples
            self.labels = self.labels + self.val_labels
        if test:
            self.samples = self.samples + self.test_samples
            self.labels = self.labels + self.test_labels
        self.n_samples = len(self.samples)
        self.unique_labels = list(np.unique(self.labels))

    def _random_file(self, pattern):
        """Pick a random file matching the glob ``pattern`` (the listing is cached per pattern)."""
        if pattern not in self._glob_cache:
            self._glob_cache[pattern] = glob.glob(pattern)
        return random.choice(self._glob_cache[pattern])

    def _load_resampled(self, path):
        """Read a wav file and resample it to ``self.sampling_rate``."""
        # wavfile.read's second positional parameter is ``mmap``. Passing ``np.dtype``
        # there (truthy) memory-mapped every file, so it is deliberately omitted.
        rate, audio = wavfile.read(path)
        return resampy.resample(audio, rate, self.sampling_rate)

    def augment_data(self, sample, augmentation):
        """Crop/pad ``sample`` to 3 s, apply ``augmentation`` and min-max scale to [0, 1].

        ``augmentation`` is one of ``'music'``, ``'speech'``, ``'noise'``, ``'rir'``;
        anything else (e.g. ``'none'``) leaves the crop unchanged. Returns float64.
        """
        sample = self.cut_to_sec(sample, 3)
        if augmentation == 'music':
            aug_sample = self.augment_musan_music(sample)
        elif augmentation == 'speech':
            aug_sample = self.augment_musan_speech(sample)
        elif augmentation == 'noise':
            aug_sample = self.augment_musan_noise(sample)
        elif augmentation == 'rir':
            aug_sample = self.augment_rir(sample)
        else:
            aug_sample = sample
        aug_sample = aug_sample.astype(np.float64)
        aug_sample -= np.min(aug_sample)
        peak = np.max(aug_sample)
        if peak > 0:  # a constant (e.g. all-zero padded) clip would otherwise become NaN
            aug_sample /= peak
        return aug_sample

    def cut_to_sec(self, sample, length):
        """Return exactly ``length`` seconds: zero-pad at the end, or take a random window."""
        if len(sample) < self.sampling_rate*length:
            new_sample = np.pad(sample, (0, self.sampling_rate*length-len(sample)), 'constant', constant_values=(0, 0))
        else:
            start_point = random.randint(0, len(sample) - self.sampling_rate*length)
            new_sample = sample[start_point:start_point + self.sampling_rate*length]
        return new_sample

    def add_with_certain_snr(self, sample, noise, min_snr_db=5, max_snr_db=20):
        """Add ``noise`` scaled to a random integer SNR in [min_snr_db, max_snr_db] dB.

        Both signals are cast to int64 first (fractional parts are truncated).
        Returns a float array of the same length as ``sample``.
        """
        sample = sample.astype('int64')
        noise = noise.astype('int64')
        sample_rms = np.sqrt(np.mean(sample**2))
        noise_rms = np.sqrt(np.mean(noise**2))
        wanted_snr = random.randint(min_snr_db, max_snr_db)
        wanted_noise_rms = np.sqrt(sample_rms**2 / 10**(wanted_snr/10))
        new_noise = noise * wanted_noise_rms/(noise_rms+EPS)
        noisy_sample = sample + new_noise
        return noisy_sample

    def augment_musan_music(self, sample):
        """Mix in a random 3-second MUSAN music excerpt at 5-15 dB SNR."""
        song = self._load_resampled(self._random_file(self.data_folder_path + '/musan/music/*/*.wav'))
        song = self.cut_to_sec(song, 3)
        return self.add_with_certain_snr(sample, song, min_snr_db=5, max_snr_db=15)

    def augment_musan_speech(self, sample):
        """Mix in babble made of 3-7 random MUSAN speech excerpts at 13-20 dB SNR."""
        pattern = self.data_folder_path + '/musan/speech/*/*.wav'
        # Accumulate in float64 so summing several integer clips cannot overflow.
        speakers = self.cut_to_sec(self._load_resampled(self._random_file(pattern)), 3).astype(np.float64)
        for _ in range(random.randint(2, 6)):
            speakers = speakers + self.cut_to_sec(self._load_resampled(self._random_file(pattern)), 3)
        return self.add_with_certain_snr(sample, speakers, min_snr_db=13, max_snr_db=20)

    def augment_musan_noise(self, sample):
        """Add an independent 1-second MUSAN noise burst (0-15 dB SNR) to each of the first 3 seconds.

        Returns a new float64 array; ``sample`` itself is not modified.
        """
        pattern = self.data_folder_path + '/musan/noise/*/*.wav'
        # Work on a float copy: writing noisy floats into an integer array would
        # truncate/overflow, and the caller's (possibly sliced) array must not change.
        sample = np.array(sample, dtype=np.float64)
        second = self.sampling_rate
        for i in range(3):
            noise = self.cut_to_sec(self._load_resampled(self._random_file(pattern)), 1)
            start = i * second  # window i covers second i (offsets are in seconds, not samples)
            sample[start:start + second] = self.add_with_certain_snr(sample[start:start + second], noise, min_snr_db=0, max_snr_db=15)
        return sample

    def augment_rir(self, sample):
        """Add a reverberant copy of ``sample`` (convolved with a random simulated RIR).

        The reverberant copy is rescaled to the peak level of the dry sample and
        truncated to its length before being added.
        """
        rir_path = self._random_file(self.data_folder_path + '/RIRS_NOISES/simulated_rirs/*/*/*.wav')
        # NOTE(review): the RIR is not resampled; presumably it is already at
        # self.sampling_rate - TODO confirm.
        _, rir = wavfile.read(rir_path)
        reverb = fftconvolve(sample, rir)
        peak = np.abs(reverb).max()
        if peak > 0:  # a silent result would otherwise turn into NaN
            reverb = (reverb / peak) * np.abs(sample.astype(np.float64)).max()
        return sample + reverb[:len(sample)]
class Config:
    """Plain container for the hyper-parameters and pipeline switches of one run.

    The defaults of ``batch_size`` and ``num_epochs`` come from the module-level
    ``n_batch_size`` / ``n_epochs_total`` at class-definition time.
    """
    def __init__(self,
                batch_size=n_batch_size,
                input_size=24,
                hidden_size=512,
                num_classes=1211,
                x_vector_size=512,
                x_vec_extract_layer=6,
                learning_rate=0.001,
                num_epochs=n_epochs_total,
                batch_norm=True,
                dropout_p=0.0,
                augmentations_per_sample=2,
                plda_rank_f=50,
                checkpoint_path='none',
                data_folder_path='data',
                train_x_vector_model=True,
                extract_x_vectors=True,
                train_plda=True,
                test_plda=True):
        # network shape
        self.batch_size, self.input_size, self.hidden_size = batch_size, input_size, hidden_size
        self.num_classes, self.x_vector_size = num_classes, x_vector_size
        self.x_vec_extract_layer = x_vec_extract_layer
        # optimisation / regularisation
        self.learning_rate, self.num_epochs = learning_rate, num_epochs
        self.batch_norm, self.dropout_p = batch_norm, dropout_p
        # data and PLDA back-end
        self.augmentations_per_sample, self.plda_rank_f = augmentations_per_sample, plda_rank_f
        self.checkpoint_path, self.data_folder_path = checkpoint_path, data_folder_path
        # which pipeline stages to run
        self.train_x_vector_model, self.extract_x_vectors = train_x_vector_model, extract_x_vectors
        self.train_plda, self.test_plda = train_plda, test_plda
class XVectorModel(pl.LightningModule):
    """x-vector speaker-embedding network trained as a speaker classifier.

    Architecture: five ``TdnnLayer`` frame layers -> statistics pooling (mean and
    std over time) -> two fully connected segment layers -> speaker logits.
    ``extract_x_vec`` returns the pre-activation output of segment layer 6 or 7 as
    the utterance embedding. The model owns the ``Dataset`` its dataloaders read.
    """
    def __init__(self, input_size=24,
                hidden_size=512,
                num_classes=1211,
                x_vector_size=512,
                x_vec_extract_layer=6,
                batch_size=512,
                learning_rate=0.001,
                batch_norm=True,
                dropout_p=0.0,
                augmentations_per_sample=2,
                data_folder_path='data'):
        super().__init__()
        self.time_context_layers = nn.Sequential(
            TdnnLayer(input_size=input_size, output_size=hidden_size, context=[-2, -1, 0, 1, 2], batch_norm=batch_norm, dropout_p=dropout_p),
            TdnnLayer(input_size=hidden_size, output_size=hidden_size, context=[-2, 0, 2], batch_norm=batch_norm, dropout_p=dropout_p),
            TdnnLayer(input_size=hidden_size, output_size=hidden_size, context=[-3, 0, 3], batch_norm=batch_norm, dropout_p=dropout_p),
            TdnnLayer(input_size=hidden_size, output_size=hidden_size, batch_norm=batch_norm, dropout_p=dropout_p),
            TdnnLayer(input_size=hidden_size, output_size=1500, batch_norm=batch_norm, dropout_p=dropout_p)
        )
        # Statistics pooling concatenates mean and std of the 1500-dim frame output -> 3000.
        self.segment_layer6 = nn.Linear(3000, x_vector_size)
        self.segment_layer7 = nn.Linear(x_vector_size, x_vector_size)
        self.output = nn.Linear(x_vector_size, num_classes)
        self.x_vec_extract_layer = x_vec_extract_layer
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.dataset = Dataset(data_folder_path=data_folder_path, augmentations_per_sample=augmentations_per_sample)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        # Separate metric so validation batches never mix into the training accuracy state.
        self.val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        # (utterance id, label, x-vector) tuples collected by test_step.
        self._test_x_vectors = []
        self.save_hyperparameters()

    def stat_pool(self, x):
        """Statistics pooling: concatenate mean and standard deviation over dim 1 (time)."""
        mean = torch.mean(x, 1)
        stand_dev = torch.std(x, 1)
        return torch.cat((mean, stand_dev), 1)

    def forward(self, x):
        """Map features (batch, time, input_size) to speaker logits (batch, num_classes)."""
        out = self.time_context_layers(x)
        out = self.stat_pool(out)
        out = F.relu(self.segment_layer6(out))
        out = F.relu(self.segment_layer7(out))
        return self.output(out)

    def extract_x_vec(self, x):
        """Return the x-vector: pre-activation output of segment layer 7, else layer 6."""
        out = self.time_context_layers.forward(x)
        out = self.stat_pool(out)
        if self.x_vec_extract_layer == 7:
            out = F.relu(self.segment_layer6.forward(out))
            return self.segment_layer7.forward(out)
        # Layer 6 is both the explicit choice and the fallback for any other value.
        return self.segment_layer6.forward(out)

    def training_step(self, batch, batch_index):
        samples, labels, id = batch
        outputs = self(samples.float())
        loss = F.cross_entropy(outputs, labels)
        # Logging lives in the step itself: the *_step_end hooks were removed in
        # PyTorch Lightning 2.0, and this form also works on 1.x.
        self.log('train_step_loss', loss)
        self.accuracy(outputs, labels)
        self.log('train_step_acc', self.accuracy)
        return {'loss': loss, 'train_preds': outputs, 'train_labels': labels, 'train_id': id}

    def on_fit_start(self):
        # 299 frames matches 3 s of audio with the default MFCC framing (10 ms hop, 25 ms window).
        if self.logger is not None:
            sample_input = torch.rand((1, 299, self.hparams.input_size)).to(self.device)
            self.logger.experiment.add_graph(self, sample_input)

    def on_train_epoch_end(self):
        print(f"Epoch {self.current_epoch} ended. Skipping add_graph to avoid Trainer attachment issues.")

    def validation_step(self, batch, batch_index):
        samples, labels, _ = batch
        outputs = self(samples.float())
        loss = F.cross_entropy(outputs, labels)
        self.log("val_step_loss", loss, prog_bar=True)  # monitored by EarlyStopping / ModelCheckpoint
        self.val_accuracy(outputs, labels)
        self.log("val_step_acc", self.val_accuracy, prog_bar=True)
        return {"val_loss": loss}

    def on_test_epoch_start(self):
        self._test_x_vectors = []

    def test_step(self, batch, batch_index):
        samples, labels, id = batch
        x_vecs = self.extract_x_vec(samples.float())
        for x, l, i in zip(x_vecs, labels, id):
            self._test_x_vectors.append((i, int(l.cpu().numpy()), np.array(x.cpu().numpy(), dtype=np.float64)))
        return [(x_vecs, labels, id)]

    def on_test_epoch_end(self):
        # Hand the vectors to the module-level ``x_vector`` list that the main script
        # creates before calling trainer.test() (replaces the removed test_epoch_end hook).
        x_vector.extend(self._test_x_vectors)
        self._test_x_vectors = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _make_loader(self, shuffle, **splits):
        """DataLoader over a private shallow copy of ``self.dataset`` restricted to ``splits``."""
        import copy
        # load_data() rebinds samples/labels on the object it is called on. With a
        # single shared Dataset, the last call (e.g. val) silently replaced what the
        # training loader iterated over, so every loader gets its own copy.
        subset = copy.copy(self.dataset)
        subset.load_data(**splits)
        return DataLoader(dataset=subset, batch_size=self.batch_size, num_workers=num_worker, shuffle=shuffle)

    def train_dataloader(self):
        return self._make_loader(True, train=True)

    def val_dataloader(self):
        return self._make_loader(False, val=True)

    def test_dataloader(self):
        # ``extract_mode`` is a module-level switch set by the main script before trainer.test().
        if extract_mode == 'train':
            return self._make_loader(False, train=True, val=True)
        if extract_mode == 'test':
            return self._make_loader(False, test=True)
        raise ValueError(f"extract_mode must be 'train' or 'test', got {extract_mode!r}")


# Main execution
if __name__ == "__main__":
    # Pipeline driver. Each stage is switched on/off by a Config flag:
    #   train_x_vector_model -> fit the x-vector network
    #   extract_x_vectors    -> run trainer.test to dump x-vectors to CSV
    #   train_plda           -> fit PLDA back-ends on vectors loaded from CSV
    #   test_plda            -> score a PLDA back-end, report EER / minDCF
    # NOTE(review): log_dir, bse_dir, data_final_path, Config, XVectorModel,
    # EarlyStopping and the PLDA helpers are defined elsewhere in this notebook.
    print('setting up model and trainer parameters')
    ppth = f'{log_dir}/lightning_logs/version_7/checkpoints/last.ckpt'
    ckpt_path = ppth if os.path.exists(ppth) else 'none'
    sep1 = "::::::::\\::::::::"
    print(sep1, ckpt_path, sep1)
    config = Config(data_folder_path=data_final_path,
                    checkpoint_path='none',  # set to ppth to resume from the last checkpoint
                    train_x_vector_model=True,
                    extract_x_vectors=False,
                    train_plda=False,
                    test_plda=False,
                    x_vec_extract_layer=6,
                    plda_rank_f=25)  # TODO delete most of this

    # Define model and trainer
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=log_dir)
    early_stopping_callback = EarlyStopping(monitor="val_step_loss", mode="min")
    checkpoint_callback = ModelCheckpoint(monitor='val_step_loss', save_top_k=10, save_last=True, verbose=True)
    if config.checkpoint_path == 'none':
        model = XVectorModel(input_size=config.input_size,
                             hidden_size=config.hidden_size,
                             num_classes=config.num_classes,
                             x_vector_size=config.x_vector_size,
                             x_vec_extract_layer=config.x_vec_extract_layer,
                             batch_size=config.batch_size,
                             learning_rate=config.learning_rate,
                             batch_norm=config.batch_norm,
                             dropout_p=config.dropout_p,
                             augmentations_per_sample=config.augmentations_per_sample,
                             data_folder_path=config.data_folder_path)
    else:
        model = XVectorModel.load_from_checkpoint(config.checkpoint_path)
    model.dataset.init_samples_and_labels()
    trainer = pl.Trainer(callbacks=[early_stopping_callback, checkpoint_callback],
                         logger=tb_logger,
                         log_every_n_steps=1,
                         accelerator="gpu" if torch.cuda.is_available() else "cpu",
                         max_epochs=config.num_epochs)

    # Train the x-vector model
    if config.train_x_vector_model:
        print('training x-vector model')
        if config.checkpoint_path == 'none':
            trainer.fit(model)
        else:
            trainer.fit(model, ckpt_path=config.checkpoint_path)

    # Extract the x-vectors (train split for PLDA training, test split for scoring)
    if config.extract_x_vectors:
        print('extracting x-vectors')
        # The train CSV is written relative to the working directory and the
        # test CSV under bse_dir; both parent directories must exist for to_csv.
        os.makedirs('x_vectors', exist_ok=True)
        os.makedirs(bse_dir + 'x_vectors', exist_ok=True)
        extraction_targets = (
            ('train', 'x_vectors/x_vector_train_v1_5_l7relu.csv', 'could not extract train x-vectors'),
            ('test', bse_dir + 'x_vectors/x_vector_test_v1_5_l7relu.csv', 'could not extract test x-vectors'),
        )  # TODO set to default names
        # This loop runs at module level on purpose: `extract_mode` and `x_vector`
        # are module globals. The test dataloader reads `extract_mode` to pick the
        # split; x_vector is presumably filled by the model's test step (TODO confirm).
        # Moving this into a function would require `global` declarations.
        for extract_mode, csv_path, failure_msg in extraction_targets:
            x_vector = []
            if config.train_x_vector_model:
                trainer.test(model)
            elif config.checkpoint_path != 'none':
                trainer.test(model, ckpt_path=config.checkpoint_path)
            else:
                print(failure_msg)
                continue
            x_vector = pd.DataFrame(x_vector)
            x_vector.to_csv(csv_path)

    if config.train_plda:
        print('loading x_vector data')
        os.makedirs('plda', exist_ok=True)
        # Extract the x-vectors, labels and id from the csv
        x_vectors_train = pd.read_csv(bse_dir + 'x_vectors/i_vector_train_v2.csv')  # TODO set to default name
        x_id_train = np.array(x_vectors_train.iloc[:, 1])
        x_label_train = np.array(x_vectors_train.iloc[:, 2], dtype=int)
        # Column 3 stores each vector as a string; drop its first/last character
        # (presumably '[' and ']') and split the rest on whitespace.
        x_vec_train = np.array([np.array(x_vec[1:-1].split(), dtype=np.float64)
                                for x_vec in x_vectors_train.iloc[:, 3]])
        # Generate x_vec stat objects
        print('generating x_vec stat objects')
        tr_stat = get_train_x_vec(x_vec_train, x_label_train, x_id_train)
        # Train and save one PLDA per latent rank. After the loop `plda` holds
        # the last (rank 200) model, which the test stage below reuses.
        for rank_f in (50, 100, 150, 200):
            print('training plda')
            plda = setup_plda(rank_f=rank_f, nb_iter=10)
            plda = train_plda(plda, tr_stat)
            save_plda(plda, f'plda_ivec_v2_d{rank_f}')

    if config.test_plda:
        # Extract the x-vectors, labels and id from the csv
        print('loading x_vector data')
        x_vectors_test = pd.read_csv(bse_dir + 'x_vectors/i_vector_test_v2.csv')  # TODO set to default name
        x_vectors_test.columns = ['index', 'id', 'label', 'xvector']
        score = plda_score_stat_object(x_vectors_test)
        # Test plda; load a saved model only if none was trained in this run
        print('testing plda')
        if not config.train_plda:
            plda = load_plda(bse_dir + 'plda/plda_ivec_v2_d200.pickle')  # TODO set to default name
        score.test_plda(plda, config.data_folder_path + '/VoxCeleb/veri_test2.txt')
        # Calculate EER and minDCF
        print('calculating EER and minDCF')
        score.calc_eer_mindcf()
        print('EER: ', score.eer, '   threshold: ', score.eer_th)
        print('minDCF: ', score.min_dcf, '   threshold: ', score.min_dcf_th)
        # Generate images for tensorboard
        score.plot_images(tb_logger.experiment)
        save_plda(score, 'plda_score_ivec_v2_d200')  # TODO set to default name
    print('DONE')

# End