# imports

In [2]:
import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import fairseq

import soundfile as sf
from torch.utils.data import Dataset, DataLoader
import os

import sys
import time

from torch.optim import Adam
from torch_optimizer import AdaBound

  from .autonotebook import tqdm as notebook_tqdm


# config

In [3]:
CONFIG = dict(
    # ---- general ----
    model="wav2vec2+AASIST",      # get_datasets chooses the dev/eval padding from this name
    batch_size=8,                 # read by get_dataloaders
    wandb_project="project_name",
    d_args=dict(
        nb_samp=64600,
        first_conv=128,
        filts=[70, [1, 32], [32, 32], [32, 64], [64, 64]],
    ),
    device="cuda:0",
    num_class=2,
    gpu_id=0,
    num_capsules=30,
    epoches=40,
    # ---- optimiser (read by get_optimizer) ----
    opt="AdaBound",               # "Adam" or "AdaBound"
    lr=0.0001,
    weight_decay=0.00001,
    # ---- other hyper-parameters ----
    random=True,
    dropout=0.05,
    random_size=0.01,
    num_iterations=2,
    gamma=0.5,
    step_size=10,
    # ---- files (protocol / flac paths are read by get_datasets) ----
    produced_file="ssl_preds.txt",
    asv_score_filename="/asvspoof/LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.dev.gi.trl.scores.txt",
    dev_label_path="/asvspoof/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt",
    dev_path_flac="/asvspoof/LA/ASVspoof2019_LA_dev",
    train_label_path="/asvspoof/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt",
    train_path_flac="/asvspoof/LA/ASVspoof2019_LA_train",
    eval_label_path="/asvspoof/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt",
    eval_path_flac="/asvspoof/LA/ASVspoof2019_LA_eval",
    checkpoint="/app/SafeSpeak-2024/weights/LA_model.pth",
    num_workers=6,                # read by get_dataloaders
)

# model

In [4]:
__author__ = "Hemlata Tak"
__email__ = "tak@eurecom.fr"

############################
## FOR fine-tuned SSL MODEL
############################


class SSLModel(nn.Module):
    """fairseq wav2vec 2.0 / XLSR model wrapped as a (fine-tunable) feature extractor.

    Args:
        device: stored as ``self.device``. The wrapped model is moved lazily in
            :meth:`extract_feat` to the device/dtype of its input.
        cp_path: fairseq checkpoint to load. The default is the path that was
            previously hard-coded here.
    """

    def __init__(self, device, cp_path='/app/SafeSpeak-2024/weights/xlsr2_300m.pt'):
        super(SSLModel, self).__init__()

        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
        self.model = model[0]
        self.device = device
        # Width of the features returned by extract_feat; Model.LL is sized from it.
        # NOTE(review): 1024 presumably matches the XLSR-300M checkpoint -- keep in sync with cp_path.
        self.out_dim = 1024

    def extract_feat(self, input_data):
        """Run the SSL model and return its final-layer features.

        Args:
            input_data: waveform batch of shape (batch, length), or 3-D where
                only index 0 of the last axis is used.
        Returns:
            The ``'x'`` entry of the fairseq output, ``[batch, length, dim]``.
        """
        param = next(self.model.parameters())
        if param.device != input_data.device or param.dtype != input_data.dtype:
            # Lazily move/cast the SSL model to match the input. ``.to()`` leaves
            # the train/eval mode alone, so whatever mode the caller set (e.g. via
            # Model.eval()) is kept. Forcing ``.train()`` here would re-enable
            # dropout when this branch fires during evaluation.
            self.model.to(input_data.device, dtype=input_data.dtype)

        # input should be in shape (batch, length)
        input_tmp = input_data[:, :, 0] if input_data.ndim == 3 else input_data

        # [batch, length, dim]
        return self.model(input_tmp, mask=False, features_only=True)['x']


#---------AASIST back-end------------------------#
''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans. 
    AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks. 
    In Proc. ICASSP 2022, pp: 6367--6371.'''
class GraphAttentionLayer(nn.Module):
    '''
    Homogeneous graph attention layer used by the AASIST back-end.

    Every node attends to every node of the same graph. Attention logits are
    built from the element-wise product of each node pair.

    Args:
        in_dim: input node feature size.
        out_dim: output node feature size.
        temperature (keyword, optional): divisor applied to the attention
            logits before the softmax; defaults to 1.
    '''

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention logits: pairwise products -> Linear -> tanh -> dot with att_weight
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # node update = projection of attended neighbours + direct projection
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        '''
        x       :(#bs, #node, #dim)
        returns :(#bs, #node, out_dim)
        '''
        dropped = self.input_drop(x)
        att_map = self._derive_att_map(dropped)
        projected = self._project(dropped, att_map)
        return self.act(self._apply_BN(projected))

    def _pairwise_mul_nodes(self, x):
        '''
        Element-wise product of every pair of nodes.
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim), out[b, i, j] = x[b, i] * x[b, j]
        '''
        return x.unsqueeze(2) * x.unsqueeze(1)

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1), softmax-normalised over axis -2
        '''
        hidden = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))
        logits = torch.matmul(hidden, self.att_weight)
        return F.softmax(logits / self.temp, dim=-2)

    def _project(self, x, att_map):
        attended = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(attended) + self.proj_without_att(x)

    def _apply_BN(self, x):
        # BatchNorm1d over the feature axis, treating each (batch, node) as a sample
        shape = x.size()
        return self.bn(x.view(-1, shape[-1])).view(shape)

    def _init_new_params(self, *size):
        weight = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(weight)
        return weight


class HtrgGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1  :(#bs, #node, #dim)
        x2  :(#bs, #node, #dim)
        '''
        #print('x1',x1.shape)
        #print('x2',x2.shape)
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)
        #print('num_type1',num_type1)
        #print('num_type2',num_type2)
        x1 = self.proj_type1(x1)
        #print('proj_type1',x1.shape)
        x2 = self.proj_type2(x2)
        #print('proj_type2',x2.shape)
        x = torch.cat([x1, x2], dim=1)
        #print('Concat x1 and x2',x.shape)
        
        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)
            #print('master',master.shape)
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)
        #print('master',master.shape)
        # directional edge for master node
        master = self._update_master(x, master)
        #print('master',master.shape)
        # projection
        x = self._project(x, att_map)
        #print('proj x',x.shape)
        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        x1 = x.narrow(1, 0, num_type1)
        #print('x1',x1.shape)
        x2 = x.narrow(1, num_type1, num_type2)
        #print('x2',x2.shape)
        return x1, x2, master

    def _update_master(self, x, master):

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):

        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    """Top-k graph pooling: keep the highest-scoring fraction of nodes.

    Args:
        k: fraction of nodes to keep (at least one node is always kept).
        in_dim: node feature size; nodes are scored by ``Linear(in_dim, 1)`` + sigmoid.
        p: dropout applied to the input before scoring; ``p <= 0`` disables it.
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        # scores come from the (possibly dropped-out) input, but the
        # un-dropped features are what get weighted and gathered
        node_scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(node_scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: per-node weights in (0, 1), shape (#bs, #node, 1)
        h: node features, shape (#bs, #node, #dim)
        k: fraction of nodes to keep (float)
        returns
        =====
        (#bs, #node', #dim): features scaled by their score, restricted to the
        top-scoring nodes in descending score order
        """
        _, node_count, feat_dim = h.size()
        keep = max(int(node_count * k), 1)
        top_idx = torch.topk(scores, keep, dim=1)[1].expand(-1, -1, feat_dim)
        weighted = h * scores
        return torch.gather(weighted, 1, top_idx)


class Residual_block(nn.Module):
    """2-D residual block used by the AASIST (RawNet2-style) encoder.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: if True the block has no ``bn1`` pre-activation layer.

    The spatial size is preserved: conv1 (kernel (2, 3), padding (1, 1)) grows
    the height by one and conv2 (kernel (2, 3), padding (0, 1)) shrinks it back.
    Width is kept by the 3-wide kernels with padding 1. When the channel counts
    differ, the identity path goes through a (1, 3) convolution so the residual
    sum has matching channels.
    """
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            # channel-matching projection for the skip connection
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        

    def forward(self, x):
        """Apply conv1 -> bn2 -> selu -> conv2 and add the (projected) input."""
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x

        #print('out',out.shape)
        # NOTE(review): conv1 consumes `x`, not `out`, so the bn1/selu
        # pre-activation computed above is discarded. bn1 still runs (in train
        # mode its running statistics still update). Changing this to
        # `self.conv1(out)` would alter the outputs of checkpoints trained with
        # this code -- confirm before fixing.
        out = self.conv1(x)

        #print('aft conv1 out',out.shape)
        out = self.bn2(out)
        out = self.selu(out)
        # print('out',out.shape)
        out = self.conv2(out)
        #print('conv2 out',out.shape)
        
        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        #out = self.mp(out)
        return out


class Model(nn.Module):
    """wav2vec 2.0 (XLSR) front-end followed by an AASIST graph back-end.

    ``forward`` maps a batch of raw waveforms to ``(batch, 2)`` logits from
    ``out_layer``. The scoring code in this notebook uses column 1 as the
    bona fide score.
    """
    def __init__(self, device):
        super().__init__()
        self.device = device
        
        # AASIST parameters
        # NOTE: filts[0] is never read below (LL's 128 output features are
        # hard-coded), and only indices 0-2 of pool_ratios / temperatures are used.
        filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
        gat_dims = [64, 32]
        pool_ratios = [0.5, 0.5, 0.5, 0.5]
        temperatures =  [2.0, 2.0, 100.0, 100.0]

        ####
        # create network wav2vec 2.0
        ####
        self.ssl_model = SSLModel(self.device)
        # project each SSL frame (out_dim features) down to 128 features
        self.LL = nn.Linear(self.ssl_model.out_dim, 128)

        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.first_bn1 = nn.BatchNorm2d(num_features=64)
        # in-place dropouts: `drop` before the classifier, `drop_way` on each branch output
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        # RawNet2 encoder
        # six residual blocks; channels 1 -> 32 -> 32 -> 64 -> 64 -> 64 -> 64,
        # spatial size unchanged by each block
        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # 1x1-conv attention logits over the encoder map; in forward they are
        # softmaxed once over time and once over frequency
        self.attention = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(1,1)),
            nn.SELU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 64, kernel_size=(1,1)),   
        )
        # position encoding
        # one learnable vector per spectral node; 42 == 128 // 3, the LL width
        # after the 3x3 max-pool in forward (the residual blocks keep that height)
        self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
        
        # learnable master nodes for the two HS-GAL branches
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        
        # Graph module
        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])
        # HS-GAL layer 
        # ST11/ST12 form branch 1, ST21/ST22 form branch 2
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        # Graph pooling layers
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        
        # readout: [T_max, T_avg, S_max, S_avg, master], each gat_dims[1] wide
        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x):
        #-------pre-trained Wav2vec model fine tunning ------------------------##
        # NOTE(review): x is presumably (bs, samples) or (bs, samples, 1);
        # squeeze(-1) only drops a trailing singleton axis -- confirm with callers.
        x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
        x = self.LL(x_ssl_feat) #(bs,frame_number,feat_out_dim)
        
        # post-processing on front-end features
        x = x.transpose(1, 2)   #(bs,feat_out_dim,frame_number)
        x = x.unsqueeze(dim=1) # add channel 
        # 3x3 max-pool (stride 3): (bs,1,128,frames) -> (bs,1,42,frames//3)
        x = F.max_pool2d(x, (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # RawNet2-based encoder
        # -> (bs, 64, 42, frames//3)
        x = self.encoder(x)
        x = self.first_bn1(x)
        x = self.selu(x)
        
        w = self.attention(x)
        
        #------------SA for spectral feature-------------#
        # softmax over time (last axis), attention-weighted sum over time -> (bs, 64, 42)
        w1 = F.softmax(w,dim=-1)
        m = torch.sum(x * w1, dim=-1)
        # spectral nodes (bs, 42, 64) plus learned position encoding
        e_S = m.transpose(1, 2) + self.pos_S 
        
        # graph module layer
        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
        
        #------------SA for temporal feature-------------#
        # softmax over frequency (axis -2), weighted sum over frequency -> (bs, 64, frames//3)
        w2 = F.softmax(w,dim=-2)
        m1 = torch.sum(x * w2, dim=-2)
     
        # temporal nodes (bs, frames//3, 64)
        e_T = m1.transpose(1, 2)
       
        # graph module layer
        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)
        
        # learnable master node
        # NOTE(review): these expanded copies are overwritten below without ever
        # being read -- the HS-GAL calls receive the (1, 1, dim) parameters
        # directly and rely on broadcasting. Left as-is; remove once confirmed.
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        # branch 1: temporal nodes are type 1, spectral nodes are type 2
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        # second HS-GAL layer of branch 1, added residually
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        # branch 2: same structure, separate weights and master node
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # fuse the two branches with an element-wise max
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # Readout operation
        # max of absolute values and mean over the nodes of each type
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)
        
        # (bs, 5 * gat_dims[1])
        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        
        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)
        
        return output


def get_model(path_pth, device):
    """Build ``Model``, load its state dict from ``path_pth`` and move it to ``device``.

    The total parameter count is printed before the model is returned.
    """
    model = Model(device)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(path_pth, map_location=device)
    model.load_state_dict(state_dict)

    total_params = sum(p.numel() for p in model.parameters())
    model = model.to(device)
    print('nb_params:', total_params)
    return model

# utils

In [5]:
def progressbar(it, prefix="", size=60, out=sys.stdout):  # Python3.6+
    """Yield the items of ``it`` while drawing a text progress bar on ``out``.

    After each item is consumed, the bar is redrawn in place (carriage-return
    terminated) with elapsed time and a linear estimate of the remaining time,
    both as ``MM:SS.ss``. A final blank line is printed when ``it`` is exhausted.

    Args:
        it: iterable that supports ``len()``.
        prefix: text printed before the bar.
        size: bar width in characters.
        out: stream to draw on.
    """
    total = len(it)
    t_start = time.time()

    def _mmss(seconds):
        minutes, secs = divmod(seconds, 60)
        return f"{int(minutes):02}:{secs:05.2f}"

    def _draw(done):
        filled = int(size * done / total)
        eta = ((time.time() - t_start) / done) * (total - done)
        elapsed = time.time() - t_start
        bar = u'█' * filled + '.' * (size - filled)
        line = f"{prefix}[{bar}] {done}/{total} time {_mmss(elapsed)} / {_mmss(eta)}"
        print(line, end='\r', file=out, flush=True)

    for done, item in enumerate(it, start=1):
        yield item
        _draw(done)
    print("\n", flush=True, file=out)


class ChanelWiseStats(nn.Module):
    """
    Per-channel mean and standard deviation over the two trailing axes.

    Input ``(N, C, H, W)`` -> output ``(N, 2, C)``; index 0 of dim 1 holds the
    means, index 1 the standard deviations (``torch.std`` defaults).
    """

    def __init__(self):
        super(ChanelWiseStats, self).__init__()

    def forward(self, x):
        dims = x.data.shape
        flat = x.view(dims[0], dims[1], dims[2] * dims[3])
        channel_mean = torch.mean(flat, 2)
        channel_std = torch.std(flat, 2)
        return torch.stack((channel_mean, channel_std), dim=1)


class View(nn.Module):
    """
    Module form of ``Tensor.view``: reshapes its input to the shape given at
    construction time (e.g. ``View(-1, 128)``).
    """

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        target_shape = self.shape
        return input.view(target_shape)


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` positions along axis 2 of a 3-D tensor.

    Returns a contiguous tensor. ``chomp_size == 0`` returns the input unchanged
    (as a contiguous tensor).
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[:, :, :-0]`` is ``x[:, :, :0]`` -- an empty tensor -- so a zero
        # chomp has to be special-cased.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


def pad_random(x, max_len=64600):
    """Fit a waveform to exactly ``max_len`` samples along axis 0.

    * Longer than ``max_len``: a random contiguous crop. Every valid start
      offset ``0 .. len(x) - max_len`` can be drawn (uses ``np.random``).
    * Otherwise: ``x`` is repeated (tiled) and truncated to ``max_len``.
      Tiling uses ``np.tile(x, n)``, which assumes a 1-D (mono) signal.

    Raises:
        ValueError: if ``x`` is empty (there is nothing to repeat).
    """
    x_len = x.shape[0]
    if x_len == 0:
        raise ValueError("pad_random: cannot pad an empty signal")

    if x_len > max_len:
        # randint's upper bound is exclusive; +1 lets the crop end on the last sample
        stt = np.random.randint(x_len - max_len + 1)
        return x[stt:stt + max_len]

    num_repeats = int(max_len / x_len) + 1
    padded_x = np.tile(x, num_repeats)[:max_len]
    return padded_x


def pad(x, max_len=64600):
    """Deterministically fit ``x`` to ``max_len`` samples along axis 0.

    Long inputs are truncated to their first ``max_len`` samples; short inputs
    are repeated until they reach ``max_len`` and then truncated.
    """
    length = x.shape[0]
    if length < max_len:
        # repeat enough times to cover max_len, then cut back
        reps = max_len // length + 1
        return np.tile(x, (1, reps))[0, :max_len]
    return x[:max_len]


def get_optimizer(model, config):
    """Create the optimizer named by ``config["opt"]`` over ``model.parameters()``.

    Supported names: ``"Adam"`` (torch.optim) and ``"AdaBound"``
    (torch_optimizer). Both use ``config["lr"]`` and ``config["weight_decay"]``.

    Raises:
        ValueError: if ``config["opt"]`` is not a supported name.
    """
    if config["opt"] == 'Adam':
        optimizer = Adam(
            model.parameters(),
            lr=config["lr"],
            weight_decay=config["weight_decay"]
        )
    elif config["opt"] == 'AdaBound':
        optimizer = AdaBound(
            model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay']
        )
    else:
        # Report the key that was actually looked up. The previous message read
        # config['optimizer'], which does not exist, so it raised KeyError instead.
        raise ValueError(f"Optimizer {config['opt']} not supported")
    return optimizer


def load_checkpoint(path):
    """Placeholder checkpoint loader: reads nothing and hands ``path`` back unchanged."""
    resolved = path
    return resolved


# dataset

In [6]:
class ASVspoof2019(Dataset):
    """ASVspoof 2019 dataset over ``<dir_path>/flac/<id>.flac`` files.

    Args:
        ids: utterance IDs, aligned with ``labels``.
        dir_path: split root directory that contains a ``flac`` sub-directory.
        labels: per-utterance labels (``get_data_for_dataset`` yields 1 for
            bona fide, 0 otherwise).
        pad_fn: ``pad_fn(audio, cut)`` fitting each waveform to ``cut`` samples.
        is_train: selects the item layout, see ``__getitem__``.
        cut: target length in samples. The default of 64600 is the value that
            was previously hard-coded, mirroring ``EvalDataset``'s ``cut``.
    """

    def __init__(self, ids, dir_path, labels, pad_fn=pad_random, is_train=True, cut=64600):
        self.ids = ids
        self.labels = labels
        self.dir_path = dir_path
        self.cut = cut
        self.is_train = is_train
        self.pad_fn = pad_fn

    def __getitem__(self, index):
        """Return ``(x, label, sample_rate)`` when ``is_train``, else ``(x, utt_id, label)``."""
        path_to_flac = f"{self.dir_path}/flac/{self.ids[index]}.flac"
        audio, rate = sf.read(path_to_flac)
        x_pad = self.pad_fn(audio, self.cut)
        x_inp = Tensor(x_pad)
        if not self.is_train:
            return x_inp, self.ids[index], torch.tensor(self.labels[index])
        return x_inp, torch.tensor(self.labels[index]), rate

    def __len__(self):
        return len(self.ids)

class EvalDataset(Dataset):
    """Unlabelled dataset over audio files stored directly in ``dir_path``.

    Each item is ``(waveform, file_name)``. The waveform is read with
    soundfile from ``f"{dir_path}/{file_name}"`` and fitted to ``cut`` samples
    by ``pad_fn``.
    """

    def __init__(self, ids, dir_path, pad_fn=pad_random, cut=64600):
        self.ids = ids
        self.dir_path = dir_path
        self.cut = cut
        self.pad_fn = pad_fn

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        file_name = self.ids[index]
        waveform, _sample_rate = sf.read(f"{self.dir_path}/{file_name}")
        return Tensor(self.pad_fn(waveform, self.cut)), file_name


def get_data_for_evaldataset(path):
    """List the entries of directory ``path`` to use as ``EvalDataset`` ids.

    The names are sorted so that the dataset (and therefore the submission
    file) order does not depend on ``os.listdir``'s arbitrary ordering.
    """
    return sorted(os.listdir(path))


def get_data_for_dataset(path):
    """Parse an ASVspoof CM protocol file into utterance IDs and binary labels.

    Each non-blank line is whitespace-split. Field 1 is the utterance ID and
    the last field is the key; ``"bonafide"`` maps to 1, anything else to 0.
    Blank lines (e.g. a trailing empty line) are skipped.

    Returns:
        (ids_list, label_list)
    """
    ids_list = []
    label_list = []
    with open(path, "r") as file:
        for line in file:
            fields = line.split()
            if not fields:
                # a blank line previously raised IndexError on fields[1]
                continue
            utt_id, key = fields[1], fields[-1]
            ids_list.append(utt_id)
            label_list.append(1 if key == "bonafide" else 0)
    return ids_list, label_list


def get_datasets(config):
    """Build the train / dev / eval ``ASVspoof2019`` datasets described by ``config``.

    Train uses the dataset defaults. Dev and eval are built with
    ``is_train=False`` and a padding function chosen by model name:
    deterministic ``pad`` for "Res2TCNGuard", ``pad_random`` otherwise.

    Returns:
        dict with keys "train", "dev", "eval".
    """
    val_pad_fn = pad if config["model"] == "Res2TCNGuard" else pad_random

    def _build(label_key, flac_key, **extra):
        # protocol file -> (ids, labels) -> dataset rooted at the flac directory
        ids, labels = get_data_for_dataset(config[label_key])
        return ASVspoof2019(ids, config[flac_key], labels, **extra)

    return {
        "train": _build("train_label_path", "train_path_flac"),
        "dev": _build("dev_label_path", "dev_path_flac", pad_fn=val_pad_fn, is_train=False),
        "eval": _build("eval_label_path", "eval_path_flac", pad_fn=val_pad_fn, is_train=False),
    }


def get_dataloaders(datasets, config):
    """Wrap each present split of ``datasets`` in a ``DataLoader``.

    Splits "train", "dev" and "eval" are considered in that order. A split is
    skipped when it is missing or falsy (e.g. a dataset of length 0). Only
    "train" is shuffled. ``batch_size`` and ``num_workers`` come from ``config``.
    """
    loaders = {}
    for split in ("train", "dev", "eval"):
        dataset = datasets.get(split)
        if not dataset:
            continue
        loaders[split] = DataLoader(
            dataset,
            batch_size=config["batch_size"],
            shuffle=(split == "train"),
            num_workers=config["num_workers"],
        )
    return loaders


# metrics

In [7]:
def compute_det_curve(bonafide_scores, spoof_scores):
    """
    Compute the false rejection / false acceptance curve over all thresholds.

    args:
        bonafide_scores: numpy array of scores for bona fide speech
        spoof_scores: numpy array of scores for spoofed speech
    output:
        frr: false rejection rate, prefixed with 0
        far: false acceptance rate, prefixed with 1
        thresholds: the sorted scores, prefixed with (lowest score - 0.001)
    todo:
        rewrite to torch
        create tests
    """
    n_bona = bonafide_scores.size
    n_spoof = spoof_scores.size
    n_total = n_bona + n_spoof

    scores = np.concatenate((bonafide_scores, spoof_scores))
    is_bona = np.concatenate((np.ones(n_bona), np.zeros(n_spoof)))

    # stable sort: tied scores keep bona-fide-before-spoof order
    order = np.argsort(scores, kind='mergesort')
    is_bona = is_bona[order]
    sorted_scores = scores[order]

    # bona fide trials at or below each threshold are rejected;
    # spoof trials above it are accepted
    bona_rejected = np.cumsum(is_bona)
    spoof_accepted = n_spoof - (np.arange(1, n_total + 1) - bona_rejected)

    frr = np.concatenate((np.atleast_1d(0), bona_rejected / n_bona))
    far = np.concatenate((np.atleast_1d(1), spoof_accepted / n_spoof))
    thresholds = np.concatenate((np.atleast_1d(sorted_scores[0] - 0.001), sorted_scores))

    return frr, far, thresholds


def compute_eer(bonafide_scores, spoof_scores):
    """
    Equal error rate (EER) and the threshold where it is reached.

    The operating point is the DET-curve entry where |FRR - FAR| is smallest
    (first such entry on ties); the EER is the mean of FRR and FAR there.
    args:
        bonafide_scores: numpy array of scores for bona fide speech
        spoof_scores: numpy array of scores for spoofed speech
    output:
        eer: equal error rate
        threshold: score threshold at that operating point
    todo:
        rewrite to torch
        create tests
    """
    frr, far, thresholds = compute_det_curve(bonafide_scores, spoof_scores)

    gap = np.abs(frr - far)
    best = int(gap.argmin())

    eer = np.mean((frr[best], far[best]))
    return eer, thresholds[best]


# @torch.inference_mode
def produce_evaluation_file(data_loader,
                            model,
                            device,
                            loss_fn,
                            save_path,
                            trial_path,
                            random=False,
                            dropout=0):
    """
    Score a labelled split and write an ASVspoof-style CM score file.

    Each written line is ``<utt_id> <src> <key> <score>``. ``src`` and ``key``
    come from the matching line of ``trial_path``; ``score`` is column 1 of the
    model output (the bona fide class).
    args:
        data_loader: yields ``(batch_x, utt_ids, batch_y)`` batches
        model: network to evaluate (switched to eval mode here)
        device: device the batches are moved to
        loss_fn: criterion applied to ``(model output, batch_y)``
        save_path: path the score file is written to
        trial_path: LA CM protocol file, in the same order as the loader
        random, dropout: not used
    returns:
        mean per-batch loss over the loader
    todo:
        this function must return result: tensor of uid, src, key, score
    """
    model.eval()

    with open(trial_path, "r") as protocol_file:
        protocol_lines = protocol_file.readlines()

    utt_ids = []
    scores = []
    avg_loss = 0
    for batch_x, batch_ids, batch_y in progressbar(data_loader, prefix='computing cm score'):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        with torch.no_grad():
            logits = model(batch_x)
            # column 1 - bona fide class
            bona_scores = (logits[:, 1]).data.cpu().numpy().ravel()
            avg_loss += loss_fn(logits, batch_y).item() / len(data_loader)

        utt_ids.extend(batch_ids)
        scores.extend(bona_scores.tolist())

    # zip stops at the shortest sequence, so surplus protocol lines are ignored
    with open(save_path, "w") as fh:
        for fn, sco, trl in zip(utt_ids, scores, protocol_lines):
            _, trial_id, _, src, key = trl.strip().split(' ')
            # loader order must match protocol order
            assert fn == trial_id
            fh.write("{} {} {} {}\n".format(trial_id, src, key, sco))
    print("Scores saved to {}".format(save_path))

    return avg_loss

def produce_submit_file(data_loader,
                            model,
                            device,
                            save_path,
                            random=False,
                            dropout=0):
    """
    Run the model over an unlabeled evaluation loader and write a submission CSV.

    The written file has a header row ``ID,score`` followed by one row per
    utterance. ``ID`` is the utterance id yielded by the loader with a trailing
    ``.wav`` extension removed; ``score`` is column 1 of the model output
    (treated as the bonafide class in this notebook).

    args:
        data_loader: loader yielding ``(batch_x, utt_id)`` pairs
        model: model whose output is indexed as ``out[:, 1]``
        device: device the input batches are moved to
        save_path: path of the CSV file to write
        random: unused by this function (kept for signature compatibility)
        dropout: unused by this function (kept for signature compatibility)
    returns:
        0 (kept for backward compatibility)
    """
    import csv  # local import: only this function needs it

    # turning model into evaluation mode
    model.eval()

    # utterance ids and their scores, in loader order
    fname_list = []
    score_list = []
    with torch.no_grad():
        for batch_x, utt_id in progressbar(data_loader, prefix='computing cm score'):
            batch_out = model(batch_x.to(device))
            # column 1 is used as the bonafide-class score
            batch_score = batch_out[:, 1].detach().cpu().numpy().ravel()
            fname_list.extend(utt_id)
            score_list.extend(batch_score.tolist())
    assert len(fname_list) == len(score_list)

    # Write the CSV directly. The former write -> pd.read_csv -> to_csv round
    # trip let pandas re-infer the ID column's dtype (purely numeric ids lost
    # their leading zeros) and split ids that contain spaces.
    wav_ext = ".wav"
    with open(save_path, "w", newline="") as fh:
        writer = csv.writer(fh, lineterminator="\n")
        writer.writerow(["ID", "score"])
        for fn, sco in zip(fname_list, score_list):
            # strip only the file extension, not ".wav" occurring mid-name
            if fn.endswith(wav_ext):
                fn = fn[:-len(wav_ext)]
            writer.writerow([fn, sco])
    print("Scores saved to {}".format(save_path))

    return 0


def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_thresholds):
    """
    Calculate false alarm and miss rates of the ASV system at a given threshold.

    args:
        tar_asv: ASV scores for target trials (array-like, non-empty)
        non_asv: ASV scores for non-target trials (array-like, non-empty)
        spoof_asv: ASV scores for spoofed trials (array-like, may be empty)
        asv_thresholds: ASV decision threshold (callers pass the
            target/non-target EER threshold)
    returns:
        Pfa_asv: fraction of non-target scores >= threshold (false alarms)
        Pmiss_asv: fraction of target scores < threshold (misses)
        Pmiss_spoof_asv: fraction of spoof scores < threshold (spoofs rejected
            by the ASV), or None when ``spoof_asv`` is empty
    raises:
        ValueError: if ``tar_asv`` or ``non_asv`` is empty
    """
    # asarray lets callers pass plain lists as well as numpy arrays
    tar_asv = np.asarray(tar_asv)
    non_asv = np.asarray(non_asv)
    spoof_asv = np.asarray(spoof_asv)

    if tar_asv.size == 0 or non_asv.size == 0:
        # previously surfaced as a bare ZeroDivisionError
        raise ValueError("tar_asv and non_asv must both be non-empty")

    # np.sum counts True entries in C; the builtin sum() iterated element by
    # element in Python and was inconsistent with the spoof branch below.
    Pfa_asv = np.sum(non_asv >= asv_thresholds) / non_asv.size
    Pmiss_asv = np.sum(tar_asv < asv_thresholds) / tar_asv.size

    if spoof_asv.size == 0:
        Pmiss_spoof_asv = None
    else:
        Pmiss_spoof_asv = np.sum(spoof_asv < asv_thresholds) / spoof_asv.size

    return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv


def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv,
                 Pmiss_asv, Pmiss_spoof_asv, cost_model):
    """
    Compute the normalized t-DCF curve of a CM system in tandem with an ASV system.

    args:
        bonafide_score_cm: CM scores for bonafide speech
        spoof_score_cm: CM scores for spoofed speech
        Pfa_asv: false alarm rate of the ASV system
        Pmiss_asv: miss rate of the ASV system
        Pmiss_spoof_asv: rate at which the ASV system rejects spoofed utterances
        cost_model: dict with keys 'Ptar', 'Pnon', 'Pspoof', 'Cmiss_asv',
            'Cfa_asv', 'Cmiss_cm', 'Cfa_cm'
    output:
        tDCFnorm: normalized t-DCF value for every CM threshold
        CM_thresholds: the CM thresholds returned by compute_det_curve
    raises:
        ValueError: if ``Pmiss_spoof_asv`` is None (no spoof ASV scores were
            available) or if the derived weights C1/C2 are negative
    """
    if Pmiss_spoof_asv is None:
        # obtain_asv_error_rates returns None when there are no spoof ASV scores;
        # without a guard this failed later with an obscure TypeError
        raise ValueError("Pmiss_spoof_asv is None (no spoof ASV scores); cannot compute t-DCF")

    # obtain miss and false alarm rate of cm
    Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(
        bonafide_score_cm, spoof_score_cm
    )

    # Constants
    C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
         cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv

    C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)

    # Negative weights make the normalized curve meaningless; the reference
    # t-DCF implementation refuses to continue in this case as well.
    if C1 < 0 or C2 < 0:
        raise ValueError(
            "Negative t-DCF weights (C1={}, C2={}); check the ASV error rates "
            "and cost model".format(C1, C2)
        )

    # obtain t-DCF curve for all thresholds
    tDCF = C1 * Pmiss_cm + C2 * Pfa_cm

    # normalized t-DCF
    tDCFnorm = tDCF / np.minimum(C1, C2)

    return tDCFnorm, CM_thresholds


def calculate_eer_tdcf(cm_scores_file, asv_score_file, output_file, printout=True):
    """
    Compute the CM EER and min t-DCF, optionally writing a per-attack report.

    args:
        cm_scores_file: space-separated file with columns
            ``utt_id source key score`` (source = attack id, key = bonafide/spoof)
        asv_score_file: organizers' ASV score file; column 1 holds the key
            (target / nontarget / spoof) and column 2 the score
        output_file: where the report is written when ``printout`` is True
        printout: if True, compute per-attack EERs, write the report to
            ``output_file`` and print it
    output:
        EER * 100: equal error rate of the CM system, in percent
        min_tDCF: minimum normalized t-DCF of the tandem system
    """
    # cm data; atleast_2d keeps column indexing valid for single-line files
    cm_data = np.atleast_2d(np.genfromtxt(cm_scores_file, dtype=str))
    cm_sources = cm_data[:, 1]  # attack id ('-' for bonafide in LA protocols)
    cm_keys = cm_data[:, 2]     # 'bonafide' or 'spoof'
    cm_scores = cm_data[:, 3].astype(np.float64)

    bona_cm = cm_scores[cm_keys == 'bonafide']
    spoof_cm = cm_scores[cm_keys == 'spoof']

    # equal error rate of the CM system
    EER, _ = compute_eer(bona_cm, spoof_cm)

    # fixed t-DCF parameters
    cost_model = {
        'Pspoof': 0.05,
        'Ptar': 0.9405,
        'Pnon': 0.0095,
        'Cmiss': 1,  # not read by compute_tDCF; kept for compatibility
        'Cfa': 10,   # not read by compute_tDCF; kept for compatibility
        'Cmiss_asv': 1,
        'Cfa_asv': 10,
        'Cmiss_cm': 1,
        'Cfa_cm': 10,
    }

    # organizers' ASV scores
    asv_data = np.atleast_2d(np.genfromtxt(asv_score_file, dtype=str))
    asv_keys = asv_data[:, 1]
    asv_scores = asv_data[:, 2].astype(np.float64)

    tar_asv = asv_scores[asv_keys == 'target']
    non_asv = asv_scores[asv_keys == 'nontarget']
    spoof_asv = asv_scores[asv_keys == 'spoof']

    # fix the ASV operating point at its target/non-target EER threshold
    _, asv_threshold = compute_eer(tar_asv, non_asv)

    Pfa_asv, Pmiss_asv, Pmiss_spoof_asv = obtain_asv_error_rates(
        tar_asv,
        non_asv,
        spoof_asv,
        asv_threshold
    )

    tDCF_curve, _ = compute_tDCF(
        bona_cm,
        spoof_cm,
        Pfa_asv,
        Pmiss_asv,
        Pmiss_spoof_asv,
        cost_model
    )

    # minimum over all CM thresholds
    min_tDCF = tDCF_curve[np.argmin(tDCF_curve)]

    if printout:
        # Break down by the attack ids actually present among spoofed trials, so
        # the report is valid for dev (A01-A06) as well as eval (A07-A19); the
        # former hard-coded A07-A19 list gave empty score sets on dev.
        attack_types = sorted(set(cm_sources[cm_keys == 'spoof']))
        eer_cm_breakdown = {
            attack_type: compute_eer(bona_cm, cm_scores[cm_sources == attack_type])[0]
            for attack_type in attack_types
        }

        lines = [
            '\nCM SYSTEM\n',
            # single-line format: the former triple-quoted string leaked a
            # newline and source indentation into the report
            '\tEER\t\t= {:8.9f} % (Equal error rate for countermeasure)\n'.format(EER * 100),
            '\nTANDEM\n',
            '\tmin-tDCF\t\t= {:8.9f}\n'.format(min_tDCF),
            '\nBREAKDOWN CM SYSTEM\n',
        ]
        lines.extend(
            f'\tEER {attack_type}\t\t= {eer_cm_breakdown[attack_type] * 100:8.9f} % '
            f'(Equal error rate for {attack_type})\n'
            for attack_type in attack_types
        )
        report = ''.join(lines)

        with open(output_file, 'w') as f_res:
            f_res.write(report)
        # print in-process: os.system("cat ...") wrote to the kernel's stdout
        # (not shown in the notebook) and passed an unquoted path to the shell
        print(report, end='')
    return EER * 100, min_tDCF


def evaluate_EER_file(ref_df, pred_df, output_file):
    """
    Compute the CM EER from a protocol file and a score file, and write it to disk.

    :param ref_df: path to a space-separated CM protocol file with 5 columns;
        column 2 is the utterance id and column 5 the label (bonafide / spoof)
    :param pred_df: path to a space-separated score file with 4 columns
        ``uttid source key score``
    :param output_file: path where ``EER: <eer>`` is written (eer as a fraction)
    :return: EER in percent
    :raises ValueError: if the two files do not contain the same utterance ids
    """
    import pandas as pd  # pandas is not in the notebook's top import cell

    # reset_index makes rows align by sorted position. Without it, equals() and
    # the boolean masks below compared/aligned on the original row labels, so
    # files listing the same ids in a different order were rejected.
    ref_df = (
        pd.read_csv(ref_df, header=None, names=["_", "uttid", "___", "__", "label"], sep=" ")
        .sort_values("uttid")
        .reset_index(drop=True)
    )
    pred_df = (
        pd.read_csv(pred_df, header=None, names=["uttid", "_", "__", "scores"], sep=" ")
        .sort_values("uttid")
        .reset_index(drop=True)
    )
    if not ref_df["uttid"].equals(pred_df["uttid"]):
        raise ValueError("The 'uttid' columns in the reference and prediction files do not match.")

    pos_scores = pred_df["scores"][ref_df["label"] == "bonafide"]
    neg_scores = pred_df["scores"][ref_df["label"] == "spoof"]

    eer, _ = compute_eer(pos_scores, neg_scores)
    with open(output_file, "w") as f:
        f.write(f"EER: {eer}")
    return eer * 100


def evaluate_EER(ref_df, pred_df):
    """
    Compute the CM EER from a protocol file and a score file.

    :param ref_df: path to a space-separated CM protocol file with 5 columns
        ``speaker uttid - algo label`` (label is bonafide / spoof)
    :param pred_df: path to a space-separated score file with 4 columns
        ``uttid algo label score``
    :return: EER in percent
    :raises ValueError: if the two files do not contain the same utterance ids
    """
    import pandas as pd  # pandas is not in the notebook's top import cell

    # reset_index makes rows align by sorted position. Without it, equals() and
    # the boolean masks below compared/aligned on the original row labels, so
    # files listing the same ids in a different order were rejected.
    ref_df = (
        pd.read_csv(ref_df, header=None, names=["speaker", "uttid", "-", "algo", "label"], sep=" ")
        .sort_values("uttid")
        .reset_index(drop=True)
    )
    pred_df = (
        pd.read_csv(pred_df, header=None, names=["uttid", "algo", "label", "scores"], sep=" ")
        .sort_values("uttid")
        .reset_index(drop=True)
    )

    if not ref_df["uttid"].equals(pred_df["uttid"]):
        raise ValueError("The 'uttid' columns in the reference and prediction files do not match.")

    pos_scores = pred_df["scores"][ref_df["label"] == 'bonafide']
    neg_scores = pred_df["scores"][ref_df["label"] == 'spoof']

    eer, _ = compute_eer(pos_scores, neg_scores)
    return eer * 100


# submit

In [None]:
def submit(cfg=CONFIG, eval_path_wav="/asvspoof/wavs/", output_file="submit.csv"):
    """
    Score every evaluation file in ``eval_path_wav`` and write a submission CSV.

    The model is loaded from ``cfg["checkpoint"]`` onto ``cfg["device"]``, and the
    loader settings come from ``cfg`` through ``get_dataloaders``.
    """
    wav_ids = get_data_for_evaldataset(eval_path_wav)
    datasets = {"eval": EvalDataset(wav_ids, eval_path_wav, pad)}
    loaders = get_dataloaders(datasets, cfg)

    checkpoint_model = get_model(cfg["checkpoint"], cfg["device"])

    produce_submit_file(loaders["eval"], checkpoint_model, cfg["device"], output_file)

submit(
    cfg=CONFIG,
    eval_path_wav="/asvspoof/wavs/",
    output_file="submit_1.csv",
)

nb_params: 317837834
computing cm score[............................................................] 286/18087 time 00:23.55 / 24:25.86