# Speech Understanding - Programming Assignment 3

    Aryan Tiwari B20AI056

---

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import sys
import torch
import librosa
import torchaudio

import torchmetrics

# import gradio as gr
import wandb as wb


## Question 1

**Goal:** 

    The task is to classify the audio samples as Real or Fake.

**Tasks:** 

    —- Use the SSL W2V model trained for LA and DF tracks of the ASVSpoof dataset.

    —- Download the custom dataset from here. Report the AUC and EER on this dataset. 
    
    —- Analyze the performance of the model.
    
    —- Fine-tune the model on the FOR dataset. 
    
    —- Report the performance using AUC and EER on the FOR dataset. 
    
    —- Use the model trained on the FOR dataset to evaluate the custom dataset. Report the EER and AUC.
    
    —- Comment on the change in performance, if any. 

In [None]:
# Prefer a CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: as the last statement of a notebook cell it displays the chosen device.
device

In [None]:
!git clone https://github.com/TakHemlata/SSL_Anti-spoofing.git

In [None]:
import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import fairseq


############################
## FOR fine-tuned SSL MODEL
############################


class SSLModel(nn.Module):
    """Wrapper around a fairseq self-supervised speech model used as a feature extractor.

    Args:
        device: Stored on ``self.device``; the wrapper itself does not move
            anything to it (the backbone follows the input in ``extract_feat``).
        cp_path: Path of the fairseq checkpoint to load. Defaults to
            ``'xlsr2_300m.pt'`` in the current working directory.

    Attributes:
        model: First model of the ensemble returned by fairseq.
        out_dim: Feature width advertised to downstream layers (1024).
            NOTE(review): presumably the embedding size of the XLSR-300M
            checkpoint -- confirm before loading a different checkpoint.
    """

    def __init__(self, device, cp_path='xlsr2_300m.pt'):
        super(SSLModel, self).__init__()

        # fairseq returns (models, cfg, task); only the first model is used.
        models, _cfg, _task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
        self.model = models[0]
        self.device = device
        self.out_dim = 1024

    def extract_feat(self, input_data):
        """Return the backbone's frame-level features for a batch of waveforms.

        Args:
            input_data: Tensor of shape (batch, length) or (batch, length, 1);
                for 3-D input only channel 0 of the last axis is used.

        Returns:
            The ``'x'`` entry of the backbone output, called with
            ``mask=False, features_only=True``. Assumed to be
            (batch, frames, dim) -- TODO confirm against fairseq.
        """
        # Lazily move the backbone so it matches the input's device and dtype.
        ref_param = next(self.model.parameters())
        if ref_param.device != input_data.device or ref_param.dtype != input_data.dtype:
            self.model.to(input_data.device, dtype=input_data.dtype)
            # Keep the backbone in the same mode as this wrapper. The previous
            # unconditional ``train()`` re-enabled training behaviour (e.g.
            # dropout) even when the surrounding model had been put in eval().
            self.model.train(self.training)

        # The backbone expects (batch, length).
        if input_data.ndim == 3:
            input_tmp = input_data[:, :, 0]
        else:
            input_tmp = input_data

        # [batch, length, dim]
        return self.model(input_tmp, mask=False, features_only=True)['x']


#---------AASIST back-end------------------------#
''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans. 
    AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks. 
    In Proc. ICASSP 2022, pp: 6367--6371.'''


class GraphAttentionLayer(nn.Module):
    """Graph attention over a single, fully connected set of nodes.

    Every node attends to every other node: pairwise element-wise products
    are projected, squashed with tanh, scored by a learnable vector, divided
    by a temperature and normalised with a softmax. The attended features
    and the raw features are projected separately, summed, batch-normalised
    and passed through SELU.

    Args:
        in_dim: feature size of the incoming nodes.
        out_dim: feature size of the outgoing nodes.
        **kwargs: optional ``temperature`` (defaults to 1.0).
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention scoring: projection followed by a learnable (out_dim, 1) vector
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # feature projections for the attended and the un-attended path
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # normalisation, input dropout and activation
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        # softmax temperature
        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        '''
        x       :(#bs, #node, #dim)
        returns :(#bs, #node, #dim_out)
        '''
        dropped = self.input_drop(x)
        attention = self._derive_att_map(dropped)
        projected = self._project(dropped, attention)
        normalised = self._apply_BN(projected)
        return self.act(normalised)

    def _pairwise_mul_nodes(self, x):
        '''
        Element-wise product of every ordered pair of nodes.
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''
        node_count = x.size(1)
        tiled = x.unsqueeze(2).expand(-1, -1, node_count, -1)
        return tiled * tiled.transpose(1, 2)

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        pair_feats = self._pairwise_mul_nodes(x)
        # (#bs, #node, #node, #dim_out)
        hidden = torch.tanh(self.att_proj(pair_feats))
        # (#bs, #node, #node, 1), scaled by the temperature
        logits = torch.matmul(hidden, self.att_weight) / self.temp
        return F.softmax(logits, dim=-2)

    def _project(self, x, att_map):
        '''Sum of the attention-aggregated projection and the direct projection.'''
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        attended = self.proj_with_att(aggregated)
        direct = self.proj_without_att(x)
        return attended + direct

    def _apply_BN(self, x):
        '''Run BatchNorm1d over the last axis by flattening all leading axes.'''
        shape = x.size()
        flat = x.view(-1, shape[-1])
        return self.bn(flat).view(shape)

    def _init_new_params(self, *size):
        '''Create a Xavier-normal initialised parameter of the given size.'''
        weight = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(weight)
        return weight


class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention over two node sets plus a master node.

    The two node sets are projected with separate linear layers, concatenated
    along the node axis and attended jointly. The attention logits use a
    different learnable vector depending on which pair of node types an edge
    connects. A separate attention (the ``*M`` layers) aggregates all nodes
    into a single "master" node.

    Args:
        in_dim: feature size of both incoming node sets.
        out_dim: feature size of the outgoing nodes and of the master node.
        **kwargs: optional ``temperature`` (defaults to 1.0) dividing the
            attention logits before the softmax.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # per-type input projections (feature size stays in_dim)
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map (node-to-node, and node-to-master with suffix M)
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        # scoring vectors: type1-type1, type2-type2, cross-type, and master
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1      :(#bs, #node1, #dim)
        x2      :(#bs, #node2, #dim)
        master  :optional master node; when None it is initialised as the
                 mean over all (projected, concatenated) nodes.
        returns :(x1, x2, master) where x1/x2 are the updated node sets with
                 #dim_out features and master is the updated master node.
        '''
        # remember how many nodes belong to each type so they can be split again
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)
        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)
        # stack both node sets along the node axis: type-1 nodes first
        x = torch.cat([x1, x2], dim=1)
        
        if master is None:
            # computed before input dropout is applied to x
            master = torch.mean(x, dim=1, keepdim=True)
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)
        # directional edge for master node
        master = self._update_master(x, master)
        # projection
        x = self._project(x, att_map)
        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        # split the joint node set back into the two types
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)
        return x1, x2, master

    def _update_master(self, x, master):
        '''Attend from the master node to all nodes and project the result.'''

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        Attention weights of the master node over all nodes.
        x           :(#bs, #node, #dim)
        master      :broadcast against x; assumed (#bs or 1, 1, #dim) -- TODO confirm
        out_shape   :(#bs, #node, 1), softmax-normalised over the node axis
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim), type-1 nodes first, then type-2 nodes
        out_shape   :(#bs, #node, #node, 1)
        num_type2 is accepted but not used; the type-2 range is everything
        from index num_type1 onwards.
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        # logits are filled in block by block, one scoring vector per block
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        # type1 <-> type1
        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        # type2 <-> type2
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        # both cross-type blocks share att_weight12
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        '''Sum of the attention-aggregated projection and the direct projection.'''
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):
        '''Aggregate all nodes into the master node and add the projected old master.'''

        # (#bs, 1, #node) @ (#bs, #node, #dim) -> (#bs, 1, #dim)
        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        '''Run BatchNorm1d over the last axis by flattening all leading axes.'''
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        '''Create a Xavier-normal initialised parameter of the given size.'''
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    """Score-based node pooling that keeps a fixed ratio of the nodes.

    Each node gets a sigmoid score from a linear projection (computed on a
    dropped-out copy of the input when ``p > 0``). The input features are
    scaled by these scores and the top-scoring nodes are kept.

    Args:
        k: ratio of nodes to keep; at least one node is always kept.
        in_dim: feature size of the nodes.
        p: dropout probability for the scoring path; ``p <= 0`` disables it.
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        if p > 0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        """Pool ``h`` of shape (#bs, #node, #dim) down to (#bs, #node', #dim)."""
        scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)
        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _batch, total_nodes, n_feat = h.size()
        kept_nodes = max(int(total_nodes * k), 1)
        # indices of the best-scoring nodes, repeated across the feature axis
        top_idx = torch.topk(scores, kept_nodes, dim=1)[1].expand(-1, -1, n_feat)
        weighted = h * scores
        return torch.gather(weighted, 1, top_idx)




class Residual_block(nn.Module):
    """Two-convolution 2-D residual block.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: when True the block has no leading BatchNorm (``bn1``).

    When the channel counts differ, the shortcut is passed through a
    (1, 3) convolution so it can be added to the main path.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_ch = nb_filts[0]
        out_ch = nb_filts[1]
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=in_ch)
        # conv1 grows the first spatial axis by one (kernel 2, padding 1) ...
        self.conv1 = nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=out_ch)
        # ... and conv2 shrinks it back (kernel 2, no padding on that axis).
        self.conv2 = nn.Conv2d(in_channels=out_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        self.downsample = in_ch != out_ch
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=in_ch,
                                             out_channels=out_ch,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

    def forward(self, x):
        """Apply the block to ``x`` of shape (#bs, in_channels, H, W)."""
        shortcut = x
        if not self.first:
            # NOTE(review): the pre-activation result is computed but never
            # used -- conv1 below is fed the raw input ``x``. bn1 still
            # updates its running statistics in training mode. This is kept
            # exactly as-is; presumably it matches the implementation the
            # released checkpoints were trained with -- confirm before
            # "fixing" it, since that would change the loaded model's output.
            _unused = self.selu(self.bn1(x))

        out = self.conv1(x)
        out = self.selu(self.bn2(out))
        out = self.conv2(out)

        if self.downsample:
            shortcut = self.conv_downsample(shortcut)

        out += shortcut
        return out


class Model(nn.Module):
    """SSL front-end followed by an AASIST-style graph-attention back-end.

    Pipeline (see ``forward``): SSL features -> linear projection to 128 ->
    2-D residual encoder -> spectral and temporal attention pooling -> two
    graph attention layers -> two parallel heterogeneous graph branches ->
    readout -> 2-way linear classifier.

    Args:
        args: accepted for interface compatibility; not read anywhere in
            this class.
        device: stored on ``self.device`` and handed to ``SSLModel``.
    """

    def __init__(self, args,device):
        super().__init__()
        self.device = device
        
        # AASIST parameters
        # filts[0] is the projected feature size; the pairs are
        # [in_channels, out_channels] of the residual blocks.
        filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
        gat_dims = [64, 32]
        pool_ratios = [0.5, 0.5, 0.5, 0.5]
        # only temperatures[0..2] are used below
        temperatures =  [2.0, 2.0, 100.0, 100.0]


        ####
        # create network wav2vec 2.0
        ####
        self.ssl_model = SSLModel(self.device)
        # project the SSL embedding down to 128 features per frame
        self.LL = nn.Linear(self.ssl_model.out_dim, 128)

        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.first_bn1 = nn.BatchNorm2d(num_features=64)
        # NOTE: both dropouts and the SELU operate in place
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # RawNet2 encoder
        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # 1x1-conv attention producing one weight per encoder activation
        self.attention = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1,1)),
            nn.SELU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 64, kernel_size=(1,1)),
            
        )
        # position encoding for the spectral nodes; 42 matches the 128
        # projected features after the (3, 3) max-pool in forward (128 // 3)
        self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
        
        # learnable master nodes, one per heterogeneous branch
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        
        # Graph module
        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])
        # HS-GAL layer 
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        # Graph pooling layers
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        
        # readout concatenates five gat_dims[1]-sized vectors (see forward)
        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x):
        """Return 2-way logits of shape (#bs, 2) for a batch of waveforms.

        ``x`` is passed to ``SSLModel.extract_feat`` after squeezing a
        trailing singleton axis, so (#bs, length) and (#bs, length, 1)
        are both accepted.
        """
        #-------pre-trained Wav2vec model fine-tuning ------------------------##
        x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
        x = self.LL(x_ssl_feat) #(bs,frame_number,feat_out_dim)
        
        # post-processing on front-end features
        x = x.transpose(1, 2)   #(bs,feat_out_dim,frame_number)
        x = x.unsqueeze(dim=1) # add channel 
        # pool both the feature axis and the frame axis by a factor of 3
        x = F.max_pool2d(x, (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # RawNet2-based encoder
        x = self.encoder(x)
        x = self.first_bn1(x)
        x = self.selu(x)
        
        # attention logits with the same shape as x
        w = self.attention(x)
        
        #------------SA for spectral feature-------------#
        # softmax over the last (frame) axis, then collapse that axis
        w1 = F.softmax(w,dim=-1)
        m = torch.sum(x * w1, dim=-1)
        e_S = m.transpose(1, 2) + self.pos_S 
        
        # graph module layer
        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
        
        #------------SA for temporal feature-------------#
        # softmax over the second-to-last (feature) axis, then collapse it
        w2 = F.softmax(w,dim=-2)
        m1 = torch.sum(x * w2, dim=-2)
     
        e_T = m1.transpose(1, 2)
       
        # graph module layer
        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)
        
        # learnable master node
        # NOTE(review): these expanded copies are never read -- the layer
        # calls below receive the unexpanded parameters and the locals are
        # overwritten by the layer outputs. The (1, 1, dim) parameters are
        # broadcast against the batch inside HtrgGraphAttentionLayer.
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        # second layer of branch 1, added residually
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        # second layer of branch 2, added residually
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # element-wise maximum across the two branches
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # Readout operation
        # max of absolute values and plain mean over the node axis
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)
        
        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        
        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)
        
        return output


In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import DatasetFolder
from tqdm import tqdm
import os
import json
from sklearn import metrics
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from matplotlib import pyplot as plt

# Configuration Management
class Config:
    """Settings for evaluating the pre-trained model on the custom subset.

    Constructing an instance also seeds PyTorch's global RNG with
    ``randomSeed``.
    """

    def __init__(self):
        seed = 42
        if torch.cuda.is_available():
            selectedDevice = 'cuda'
        else:
            selectedDevice = 'cpu'

        self.experimentSaveName = 'customSubsetEval1'
        self.rootDirectoryPathRishabhSubset = '../data/Dataset_Speech_Assignment'
        self.batchSize = 8
        self.device = selectedDevice
        self.randomSeed = seed
        # side effect: make runs reproducible as far as torch's RNG goes
        torch.manual_seed(seed)

# Model Loading
def loadModel(modelClass, stateDictPath, args):
    """Build ``modelClass(args, 'cpu')`` and load a saved state dict into it.

    Checkpoints saved from a wrapped model carry a leading ``'module.'`` on
    every key; that leading prefix is stripped before loading. Only the
    prefix is removed -- a key such as ``'sub_module.weight'`` is left
    intact (a blanket ``str.replace`` would turn it into ``'sub_weight'``).

    Args:
        modelClass: callable taking ``(args, device)`` and returning a module.
        stateDictPath: path of the checkpoint file.
        args: forwarded unchanged to ``modelClass``.

    Returns:
        The model, on CPU, with the checkpoint weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict`` when keys or shapes mismatch.
    """
    prefix = 'module.'

    def stripPrefix(key):
        # Remove every *leading* 'module.' (nested wrappers stack them).
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    model = modelClass(args, 'cpu')
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    stateDict = torch.load(stateDictPath, map_location='cpu')
    # Rename in place so any metadata attached to the mapping is preserved.
    for key in list(stateDict.keys()):
        stateDict[stripPrefix(key)] = stateDict.pop(key)
    model.load_state_dict(stateDict)
    return model

# Dataset and DataLoader Setup
def pad(x, maxLength=64600):
    """Return ``x`` cut or repeat-padded to exactly ``maxLength`` samples.

    Inputs that are already long enough are truncated; shorter inputs are
    repeated end-to-end and then truncated.
    """
    sampleCount = x.shape[0]
    if sampleCount < maxLength:
        # one repetition more than strictly needed, trimmed afterwards
        repeats = int(maxLength / sampleCount) + 1
        tiled = np.tile(x, (1, repeats))
        return tiled[0, :maxLength]
    return x[:maxLength]

def loaderRishabhSubset(samplePath):
    """Load one audio file as a fixed-length float tensor.

    The file is read with librosa at 16 kHz and then cut or repeat-padded
    by ``pad`` to 64600 samples (about 4 s at that rate).
    """
    targetLength = 64600
    waveform, _sampleRate = librosa.load(samplePath, sr=16000)
    return torch.Tensor(pad(waveform, targetLength))

def setupDataLoader(rootDirectoryPath, loaderFunc, extensions, batchSize):
    """Wrap a class-per-subfolder directory in a shuffling ``DataLoader``.

    Files under ``rootDirectoryPath`` whose names end in one of
    ``extensions`` are read with ``loaderFunc``; batches of ``batchSize``
    samples are drawn in shuffled order.
    """
    folderDataset = DatasetFolder(
        rootDirectoryPath,
        loader=loaderFunc,
        extensions=extensions,
    )
    return DataLoader(folderDataset, batch_size=batchSize, shuffle=True)

# Evaluation Utilities
def compute_det_curve(target_scores, nontarget_scores):
    """Compute the detection-error trade-off curve for two score sets.

    Args:
        target_scores: scores of the positive (target) trials.
        nontarget_scores: scores of the negative (non-target) trials.

    Returns:
        ``(frr, far, thresholds)``: false-rejection rates, false-acceptance
        rates and the matching thresholds, ordered by increasing threshold.
        ``frr`` runs from 0 to 1 and ``far`` from 1 to 0.

    Raises:
        ValueError: if either score set is empty (the rates are undefined).
    """
    target_scores = np.array(target_scores)
    nontarget_scores = np.array(nontarget_scores)
    if target_scores.size == 0 or nontarget_scores.size == 0:
        raise ValueError('both target_scores and nontarget_scores must be non-empty')

    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Stable sort so tied scores keep the target-before-non-target order.
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Targets at or below each threshold are rejected; non-targets above it
    # are accepted.
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))
    # Extra leading threshold just below the smallest score (accept everything).
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))
    return frr, far, thresholds


def _eer_operating_point(frr, far, thresholds):
    """Return (EER, threshold) at the point where FRR and FAR are closest."""
    min_index = np.argmin(np.abs(frr - far))
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, thresholds[min_index]


def compute_eer(target_scores, nontarget_scores):
    """Return the equal error rate (EER) and the corresponding threshold."""
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    return _eer_operating_point(frr, far, thresholds)


def compute_eer_auc(target_scores, nontarget_scores):
    """Return the equal error rate, the ROC AUC and the EER threshold.

    The AUC is the trapezoidal area under (FAR, 1 - FRR). ``far`` comes
    back in decreasing order, so both arrays are reversed before
    integrating; integrating them as-is yields the negated area.
    """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    eer, threshold = _eer_operating_point(frr, far, thresholds)

    # np.trapz was renamed np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc = trapezoid((1 - frr)[::-1], far[::-1])
    return eer, auc, threshold

# Evaluation Script
def evalScript(model, loader, device, experiment_save_name, savePath='B20AI056_eval', printLogs=True, saveLogs=True, saveFigs=True):
    """Score every batch of ``loader`` and report EER and ROC AUC.

    The score of a sample is column 1 of the model output; label 1 is
    treated as the positive class.
    NOTE(review): which class (real/fake) carries label 1 depends on the
    dataset's folder names -- confirm it matches output column 1.

    Args:
        model: module mapping a batch of inputs to (#bs, 2) outputs.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the model and inputs are moved to.
        experiment_save_name: file-name stem for saved artefacts.
        savePath: output directory; created when something is saved.
        printLogs: print EER / AUC / threshold.
        saveLogs: write ``<stem>.json`` with EER / AUC / threshold.
        saveFigs: write ``<stem>_precision_recall_curve.png``.

    Returns:
        ``(eer, auc)``.
    """
    model = model.to(device)
    model.eval()

    scores = []
    truths = []
    with torch.no_grad():
        for xs, labels in tqdm(loader):
            outputs = model(xs.to(device))
            scores.extend(outputs[:, 1].detach().cpu().numpy().ravel().tolist())
            truths.extend(labels.tolist())

    truths = np.array(truths)
    scores = np.array(scores)

    fpr, tpr, thresholds = metrics.roc_curve(truths, scores, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    # EER is the point on the ROC curve where FPR equals FNR (= 1 - TPR).
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    thresh = interp1d(fpr, thresholds)(eer)

    if printLogs:
        print('FINAL SCORES BELOW!!!!!!!!!!')
        print('eer:', eer, '        auc:', auc,  "    thresh:", thresh)
        print('FINAL SCORES ABOVE!!!!!!!!!!')
        print()

    if saveLogs or saveFigs:
        # The output directory is not guaranteed to exist yet.
        os.makedirs(savePath, exist_ok=True)

    if saveLogs:
        with open(os.path.join(savePath, experiment_save_name + '.json'), 'w') as f:
            json.dump(
                {
                    'eer': float(f'{eer}'),
                    'auc': float(f'{auc}'),
                    'thresh': float(f'{thresh}'),
                }, f, indent=4
            )

    if saveFigs:
        precision, recall, _ = metrics.precision_recall_curve(truths, scores)
        # f1_score needs hard predictions: passing the continuous scores
        # raises ValueError, so binarise them at the EER threshold first.
        predictions = (scores >= thresh).astype(int)
        f1 = metrics.f1_score(truths, predictions)
        plt.figure()
        plt.plot(recall, precision, marker='.', label=f'F1 @ EER threshold = {f1:.3f}')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision Recall Curve')
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(savePath, experiment_save_name + '_precision_recall_curve.png'))
        plt.close()

    return eer, auc

# Main
if __name__ == "__main__":
    # Evaluate the pre-trained checkpoint on the custom dataset.
    config = Config()
    # NOTE(review): the fine-tuning cell further down loads
    # '../models/Best_LA_model_for_DF.pth' -- confirm which file name is
    # the right one for this checkpoint.
    model = loadModel(Model, '../models/BestLAModelForDF.pth', None)  # Adjust args if needed
    loader = setupDataLoader(config.rootDirectoryPathRishabhSubset, loaderRishabhSubset, ('wav', 'mp3'), config.batchSize)
    # evalScript's own printing and JSON logging stay off as before, but the
    # returned metrics were being dropped, so the run reported nothing.
    eer, auc = evalScript(model, loader, config.device, config.experimentSaveName, printLogs=False, saveLogs=False, saveFigs=True)
    print('eer:', eer)
    print('auc:', auc)


In [None]:
import torch
import numpy as np
import librosa
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import DatasetFolder
import os
import json
from sklearn import metrics
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from tqdm import tqdm
import wandb

class Config:
    """Settings for fine-tuning on the FOR data and evaluating afterwards.

    Constructing an instance also seeds PyTorch's global RNG with
    ``randomSeed``.
    """

    def __init__(self):
        seed = 42
        customDatasetRoot = '../data/Dataset_Speech_Assignment'
        if torch.cuda.is_available():
            selectedDevice = 'cuda'
        else:
            selectedDevice = 'cpu'

        # where checkpoints / logs go and how they are named
        self.checkpointSaveDir = 'checkpoint_save_dir'
        self.experimentSaveName = 'trainEval2Final'

        # dataset roots: the custom set doubles as the test set
        self.rootDirectoryPathRishabhSubset = customDatasetRoot
        self.rootDirectoryPathTesting = customDatasetRoot
        self.rootDirectoryPathTrain = '../data/for-2seconds/training'
        self.rootDirectoryPathValidation = '../data/for-2seconds/validation'

        # optimisation settings
        self.batchSize = 32
        self.numEpochs = 2
        self.learningRate = 3e-4

        self.device = selectedDevice
        self.randomSeed = seed
        # side effect: make runs reproducible as far as torch's RNG goes
        torch.manual_seed(seed)


def loadModel(modelClass, stateDictPath, args):
    """Build ``modelClass(args, 'cpu')`` and load a saved state dict into it.

    Checkpoints saved from a wrapped model carry a leading ``'module.'`` on
    every key; that leading prefix is stripped before loading. Only the
    prefix is removed -- a key such as ``'sub_module.weight'`` is left
    intact (a blanket ``str.replace`` would turn it into ``'sub_weight'``).

    Args:
        modelClass: callable taking ``(args, device)`` and returning a module.
        stateDictPath: path of the checkpoint file.
        args: forwarded unchanged to ``modelClass``.

    Returns:
        The model, on CPU, with the checkpoint weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict`` when keys or shapes mismatch.
    """
    prefix = 'module.'

    def stripPrefix(key):
        # Remove every *leading* 'module.' (nested wrappers stack them).
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    model = modelClass(args, 'cpu')
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    stateDict = torch.load(stateDictPath, map_location='cpu')
    # Rename in place so any metadata attached to the mapping is preserved.
    for key in list(stateDict.keys()):
        stateDict[stripPrefix(key)] = stateDict.pop(key)
    model.load_state_dict(stateDict)
    return model

def pad(x, maxLen=64600):
    """Return ``x`` cut or repeat-padded to exactly ``maxLen`` samples.

    Inputs that are already long enough are truncated; shorter inputs are
    repeated end-to-end and then truncated.
    """
    sampleCount = x.shape[0]
    if sampleCount < maxLen:
        # one repetition more than strictly needed, trimmed afterwards
        repeats = int(maxLen / sampleCount) + 1
        tiled = np.tile(x, (1, repeats))
        return tiled[0, :maxLen]
    return x[:maxLen]

def loaderRishabhSubset(samplePath):
    """Load one audio file as a fixed-length float tensor.

    The file is read with librosa at 16 kHz and then cut or repeat-padded
    by ``pad`` to 64600 samples (about 4 s at that rate).
    """
    targetLength = 64600
    waveform, _sampleRate = librosa.load(samplePath, sr=16000)
    return torch.Tensor(pad(waveform, targetLength))

def setupDataLoader(rootDirectoryPath, loaderFunc, extensions, batchSize):
    """Wrap a class-per-subfolder directory in a shuffling ``DataLoader``.

    Files under ``rootDirectoryPath`` whose names end in one of
    ``extensions`` are read with ``loaderFunc``; batches of ``batchSize``
    samples are drawn in shuffled order.
    """
    folderDataset = DatasetFolder(
        rootDirectoryPath,
        loader=loaderFunc,
        extensions=extensions,
    )
    return DataLoader(folderDataset, batch_size=batchSize, shuffle=True)

def compute_det_curve(target_scores, nontarget_scores):
    """Compute the detection-error trade-off curve for two score sets.

    Args:
        target_scores: scores of the positive (target) trials.
        nontarget_scores: scores of the negative (non-target) trials.

    Returns:
        ``(frr, far, thresholds)``: false-rejection rates, false-acceptance
        rates and the matching thresholds, ordered by increasing threshold.
        ``frr`` runs from 0 to 1 and ``far`` from 1 to 0.

    Raises:
        ValueError: if either score set is empty (the rates are undefined).
    """
    target_scores = np.array(target_scores)
    nontarget_scores = np.array(nontarget_scores)
    if target_scores.size == 0 or nontarget_scores.size == 0:
        raise ValueError('both target_scores and nontarget_scores must be non-empty')

    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Sort labels based on scores (stable, so ties keep their input order)
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Compute false rejection and false acceptance rates
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
    # Thresholds are the sorted scores, plus one just below the smallest score
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))
    return frr, far, thresholds


def _eer_operating_point(frr, far, thresholds):
    """Return (EER, threshold) at the point where FRR and FAR are closest."""
    min_index = np.argmin(np.abs(frr - far))
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, thresholds[min_index]


def compute_eer(target_scores, nontarget_scores):
    """ Returns equal error rate (EER) and the corresponding threshold. """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    return _eer_operating_point(frr, far, thresholds)


def compute_eer_auc(target_scores, nontarget_scores):
    """ Returns equal error rate (EER), AUC, and the corresponding threshold.

    The AUC is the trapezoidal area under (FAR, 1 - FRR). ``far`` comes
    back in decreasing order, so both arrays are reversed before
    integrating; integrating them as-is yields the negated area.
    """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    eer, threshold = _eer_operating_point(frr, far, thresholds)

    # np.trapz was renamed np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc = trapezoid((1 - frr)[::-1], far[::-1])

    return eer, auc, threshold

def evalScript(model, loader, device, experiment_save_name='eval', savePath='B20AI056_eval', savelogs=True, printlogs=True):
    """Score every batch of ``loader`` and return EER, ROC AUC and threshold.

    The score of a sample is column 1 of the model output; label 1 is
    treated as the positive class. The model is left in eval mode.

    Args:
        model: module mapping a batch of inputs to (#bs, 2) outputs.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the model and inputs are moved to.
        experiment_save_name: name of the log file inside ``savePath``
            (only used when ``savelogs`` is true).
        savePath: output directory; created when logs are saved.
        savelogs: write the three metrics as JSON.
        printlogs: write the three metrics through ``tqdm.write``.

    Returns:
        ``(eer, auc, thresh)`` as plain floats -- note the order.
    """
    model = model.to(device)
    model.eval()

    scores = []
    truths = []
    with torch.no_grad():
        for xs, labels in tqdm(loader):
            outputs = model(xs.to(device))
            scores.extend(outputs[:, 1].detach().cpu().numpy().ravel().tolist())
            truths.extend(labels.tolist())

    fpr, tpr, thresholds = metrics.roc_curve(np.array(truths), np.array(scores), pos_label=1)
    auc = float(metrics.auc(fpr, tpr))

    # EER is the point on the ROC curve where FPR equals FNR (= 1 - TPR).
    eer = float(brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.))
    thresh = float(interp1d(fpr, thresholds)(eer))

    if printlogs:
        # tqdm.write takes a single string (unlike print), so build it first.
        tqdm.write('FINAL SCORES :')
        tqdm.write(' '.join(str(part) for part in ('eer:', eer, '        auc:', auc, "    thresh:", thresh)))
        tqdm.write('')

    if savelogs:
        # The output directory is not guaranteed to exist yet.
        os.makedirs(savePath, exist_ok=True)
        with open(os.path.join(savePath, experiment_save_name), 'w') as f:
            json.dump(
                {
                    'eer': eer,
                    'auc': auc,
                    'thresh': thresh,
                }, f, indent=4
            )

    return eer, auc, thresh

def trainScript(model, trainLoader, evalLoader, testLoader, config, optimizer, criterion, trainLoss=None, eerTrain=None, eerEval=None, eerTest=None, aucTrain=None, aucEval=None, aucTest=None, printlogs=True, savelogs=True):
    """Fine-tune ``model`` and evaluate it on three loaders after every epoch.

    "Epoch 0" only evaluates the untouched model; epochs 1..numEpochs train
    on ``trainLoader`` first. After each epoch the metrics are logged to
    wandb and, optionally, printed and written to disk with a checkpoint.

    Args:
        model: module mapping a batch of inputs to (#bs, 2) outputs.
        trainLoader, evalLoader, testLoader: iterables of
            ``(inputs, labels)`` batches.
        config: object providing ``device``, ``learningRate``,
            ``numEpochs``, ``batchSize``, ``checkpointSaveDir`` and
            ``experimentSaveName``.
        optimizer: optimizer *class*; instantiated as
            ``optimizer(model.parameters(), lr=config.learningRate)``.
        criterion: loss *class*; instantiated without arguments.
        trainLoss, eerTrain, eerEval, eerTest, aucTrain, aucEval, aucTest:
            optional history lists that are appended to in place; fresh
            lists are created when omitted.
        printlogs: print the accumulated logs as JSON after every epoch.
        savelogs: save a checkpoint and the logs after every epoch.

    Returns:
        ``(trainLoss, aucTrain, eerTrain, aucEval, eerEval, aucTest, eerTest)``.
    """
    # Default lists must be created per call, not shared between calls.
    trainLoss = [] if trainLoss is None else trainLoss
    eerTrain = [] if eerTrain is None else eerTrain
    eerEval = [] if eerEval is None else eerEval
    eerTest = [] if eerTest is None else eerTest
    aucTrain = [] if aucTrain is None else aucTrain
    aucEval = [] if aucEval is None else aucEval
    aucTest = [] if aucTest is None else aucTest

    # Use the configured device rather than relying on a notebook global.
    device = config.device
    model = model.to(device)
    optimizer = optimizer(model.parameters(), lr=config.learningRate)
    criterion = criterion()

    for epoch in range(config.numEpochs + 1):
        if epoch > 0:
            # evalScript leaves the model in eval mode, so training mode has
            # to be restored at the start of every training epoch.
            model.train()
            tqdm.write(f'begun training epoch#{epoch} out of {config.numEpochs}')
            tqdm.write('-' * 20)
            bar = tqdm(total=len(trainLoader))
            for xs, labels in trainLoader:
                optimizer.zero_grad()
                xs, labels = xs.to(device), labels.to(device)
                outputs = model(xs)
                # NOTE(review): this makes class 0's logit = outputs[:, 1]
                # and class 1's logit = -outputs[:, 1], so the loss pushes
                # outputs[:, 1] DOWN for label-1 samples, whereas evalScript
                # scores label 1 with outputs[:, 1] directly. That looks
                # inverted -- confirm the label mapping before changing it.
                scores = torch.stack([outputs[:, 1], - outputs[:, 1]], dim=1)

                loss = criterion(scores, labels)
                loss.backward()
                optimizer.step()
                trainLoss.append(loss.item())
                bar.update(1)
                bar.set_postfix({
                    'trainloss(tillnow)': np.mean(np.array(trainLoss)),
                })
                wandb.log({"trainloss": loss.item()})
            bar.close()
        else:
            tqdm.write('epoch 0, printing the raw results only...')

        # evalScript returns (eer, auc, thresh) -- in that order.
        train_eer, train_auc, _ = evalScript(model, trainLoader, device, config.experimentSaveName, savelogs=False, printlogs=False)
        eval_eer, eval_auc, _ = evalScript(model, evalLoader, device, config.experimentSaveName, savelogs=False, printlogs=False)
        test_eer, test_auc, _ = evalScript(model, testLoader, device, config.experimentSaveName, savelogs=False, printlogs=False)
        aucTrain.append(train_auc)
        eerTrain.append(train_eer)
        aucEval.append(eval_auc)
        eerEval.append(eval_eer)
        aucTest.append(test_auc)
        eerTest.append(test_eer)

        logs = {
            'checkpoint_save_dir': config.checkpointSaveDir,
            'experiment_save_name': config.experimentSaveName,
            'save_location': os.path.join(config.checkpointSaveDir, config.experimentSaveName),
            'trainloss': trainLoss,
            'auc_train': aucTrain,
            'eer_train': eerTrain,
            'auc_eval': aucEval,
            'eer_eval': eerEval,
            'auc_test': aucTest,
            'eer_test': eerTest,
            'training_epochs': epoch,
            'learning_rate': config.learningRate,
            'batch_size': config.batchSize,
        }
        wandb_log = {
            'auc_train': aucTrain[-1],
            'eer_train': eerTrain[-1],
            'auc_eval': aucEval[-1],
            'eer_eval': eerEval[-1],
            'auc_test': aucTest[-1],
            'eer_test': eerTest[-1],
        }

        wandb.log(wandb_log)

        if printlogs:
            print(json.dumps(obj=logs, indent=4))

        if savelogs:
            # The checkpoint directory is not guaranteed to exist yet.
            os.makedirs(config.checkpointSaveDir, exist_ok=True)
            torch.save({
                'model': model.state_dict(),
                'logs': logs
            }, os.path.join(config.checkpointSaveDir, config.experimentSaveName + '_ckpt.pt'))
            with open(os.path.join(config.checkpointSaveDir, config.experimentSaveName + '_logs.json'), 'w') as f:
                json.dump(logs, f, indent=4)
    return trainLoss, aucTrain, eerTrain, aucEval, eerEval, aucTest, eerTest


In [None]:
# Main
if __name__ == "__main__":
    # Fine-tune the pre-trained checkpoint on the FOR data, evaluating on
    # the training, validation and custom (test) sets after every epoch.
    import wandb

    config = Config()

    wandb.init(
        project='supa3',
        config={
            'learning-rate': config.learningRate,
            'num-epochs': config.numEpochs,
            'batch-size': config.batchSize,
            'random-seed': config.randomSeed,
        }
    )

    model = loadModel(Model, '../models/Best_LA_model_for_DF.pth', None)
    trainLoader = setupDataLoader(config.rootDirectoryPathTrain, loaderRishabhSubset, ('wav', 'mp3'), config.batchSize)
    evalLoader = setupDataLoader(config.rootDirectoryPathValidation, loaderRishabhSubset, ('wav', 'mp3'), config.batchSize)
    testLoader = setupDataLoader(config.rootDirectoryPathTesting, loaderRishabhSubset, ('wav', 'mp3'), config.batchSize)

    trainLoss, aucTrain, eerTrain, aucEval, eerEval, aucTest, eerTest = [], [], [], [], [], [], []
    # trainScript takes the whole config object followed by the optimizer
    # and loss classes; passing numEpochs / learningRate / device separately
    # shifted every later argument into the wrong parameter.
    trainLoss, aucTrain, eerTrain, aucEval, eerEval, aucTest, eerTest = trainScript(
        model, trainLoader, evalLoader, testLoader, config,
        torch.optim.Adam, torch.nn.CrossEntropyLoss,
        trainLoss, eerTrain, eerEval, eerTest, aucTrain, aucEval, aucTest)

    print('trainloss:', trainLoss[-1])
    print('auc_train:', aucTrain[-1])
    print('eer_train:', eerTrain[-1])
    print('auc_eval:', aucEval[-1])
    print('eer_eval:', eerEval[-1])
    print('auc_test:', aucTest[-1])
    print('eer_test:', eerTest[-1])

    wandb.finish()