In [None]:
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from model import Model
from core_scripts.startup_config import set_random_seed
import random
from torch.utils.data import Dataset
import soundfile as sf
from evaluation import calculate_tDCF_EER

In [None]:
def pad(x, max_len=64600):
    """Return the first ``max_len`` samples of ``x``, looping ``x`` if it is short.

    Args:
        x: array whose first axis is indexed as the sample axis
            (assumes a 1-D waveform — TODO confirm for multi-channel audio).
        max_len: number of samples to return.

    Returns:
        ``x[:max_len]`` when ``x`` already has at least ``max_len`` samples,
        otherwise ``x`` repeated end-to-end and cut to ``max_len`` samples.
    """
    n_samples = x.shape[0]
    if n_samples < max_len:
        # Too short: repeat the signal often enough to cover max_len, then trim.
        reps = int(max_len / n_samples) + 1
        tiled = np.tile(x, (1, reps))
        return tiled[0, :max_len]
    # Already long enough: keep only the head.
    return x[:max_len]


def pad_random(x: np.ndarray, max_len: int = 64600):
    """Return ``max_len`` samples of ``x``: a random crop if long, a loop if short.

    Args:
        x: array whose first axis is the sample axis.
        max_len: number of samples to return.

    Returns:
        If ``x`` has at least ``max_len`` samples, a contiguous window of
        ``max_len`` samples starting at a random offset drawn from NumPy's
        global RNG. Otherwise ``x`` repeated end-to-end and cut to ``max_len``.

    Raises:
        ZeroDivisionError: if ``x`` is empty while ``max_len`` is positive.
    """
    x_len = x.shape[0]
    # if duration is already long enough
    if x_len >= max_len:
        # randint's upper bound is exclusive. The "+ 1" makes every valid start
        # offset (0 .. x_len - max_len) reachable and keeps the bound positive
        # when x_len == max_len, where the bare difference is 0 and randint
        # raises ValueError.
        stt = np.random.randint(x_len - max_len + 1)
        return x[stt:stt + max_len]

    # if too short
    num_repeats = int(max_len / x_len) + 1
    padded_x = np.tile(x, num_repeats)[:max_len]
    return padded_x

In [None]:
def genSpoof_list(dir_meta, is_train=False, is_eval=False):
    """Parse a protocol file into utterance keys (and labels unless eval-only).

    Each non-blank line must consist of five single-space-separated fields;
    the second field is the utterance key and the fifth is the label text.

    Args:
        dir_meta: path of the protocol text file.
        is_train: when true, labels are returned (takes precedence over
            ``is_eval``).
        is_eval: when true and ``is_train`` is false, only keys are returned.

    Returns:
        ``file_list`` (keys in file order) in eval-only mode, otherwise the
        tuple ``(d_meta, file_list)`` where ``d_meta`` maps each key to ``1``
        for the label ``"bonafide"`` and ``0`` for anything else.

    Raises:
        ValueError: if a non-blank line does not split into exactly five fields.
    """
    d_meta = {}
    file_list = []
    # The train and default modes are identical; only eval-only mode skips labels.
    want_labels = is_train or not is_eval

    with open(dir_meta, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a stray newline at EOF) carry no entry.
                continue
            _, key, _, _, label = line.split(" ")
            file_list.append(key)
            if want_labels:
                d_meta[key] = 1 if label == "bonafide" else 0

    if want_labels:
        return d_meta, file_list
    return file_list


In [None]:
class Dataset_ASVspoof2019_devNeval(Dataset):
    """Dataset yielding fixed-length waveforms together with their utterance keys.

    Audio is read from ``<base_dir>flac/<key>.flac``. ``base_dir`` is joined by
    plain string concatenation, so it must already end with a path separator.
    """

    def __init__(self, list_IDs, base_dir):
        """Store the utterance keys and the directory prefix.

        Args:
            list_IDs: list of strings (each string: utt key).
            base_dir: string prefix prepended to ``flac/<key>.flac``.
        """
        self.cut = 64600  # take ~4 sec audio (64600 samples)
        self.base_dir = base_dir
        self.list_IDs = list_IDs

    def __len__(self):
        """Number of utterances in the dataset."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Load one utterance, fit it to ``self.cut`` samples, return (tensor, key)."""
        utt_key = self.list_IDs[index]
        flac_path = str(self.base_dir + f"flac/{utt_key}.flac")
        # sf.read also returns the sample rate, which is not needed here.
        waveform, _ = sf.read(flac_path)
        fixed_len = pad(waveform, self.cut)
        return Tensor(fixed_len), utt_key
    

In [None]:
class Arguments:
    """Static configuration for this evaluation notebook.

    The bracketed strings are placeholders that must be replaced with real
    paths before running the cells below.
    """

    # --- locations ---
    database_path = "[path to dataset]"   # prefix prepended to dataset sub-paths
    model_path = "[path to model]"        # checkpoint loaded when non-empty
    protocols_path = "database/"

    # --- evaluation setup ---
    track = "LA"
    is_eval = True

    # --- reproducibility (passed to set_random_seed along with the seed) ---
    seed = 1234
    cudnn_deterministic_toggle = True
    cudnn_benchmark_toggle = False


# Destination file for the per-utterance scores.
eval_score_path = "[output eval score path]"
args = Arguments()

In [None]:
# Seed the run through the project helper from core_scripts.startup_config.
# `args` is passed too — presumably so the helper can read the cudnn_* toggles
# defined on Arguments; confirm against the helper's implementation.
set_random_seed(args.seed, args)

In [None]:
track = args.track
# File-name prefixes derived from the selected track.
prefix = 'ASVspoof_{}'.format(track)
prefix_2019 = 'ASVspoof2019.{}'.format(track)
prefix_2021 = 'ASVspoof2021.{}'.format(track)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device: {}'.format(device))

model = Model(args, device)
# numel() counts elements without reshaping, so it also works for
# non-contiguous parameters (where view(-1) would raise).
nb_params = sum(param.numel() for param in model.parameters())
model = model.to(device)
print('nb_params:', nb_params)

if args.model_path:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    print('Model loaded : {}'.format(args.model_path))

In [None]:
# Audio directory prefix and trial protocol file for the evaluation partition.
eval_database_path = args.database_path + "ASVspoof2019_{}_eval/".format(track)
eval_trial_path = (
    args.database_path
    + "ASVspoof2019_{}_cm_protocols/{}.cm.eval.trl.txt".format(track, prefix_2019)
)

# With is_eval=True (and is_train left False) only the ordered key list is returned.
file_eval = genSpoof_list(eval_trial_path, is_eval=True)
eval_set = Dataset_ASVspoof2019_devNeval(file_eval, eval_database_path)

# Keep the protocol order and every utterance: produce_evaluation_file checks
# the scored keys line-by-line against the trial file.
eval_loader = DataLoader(
    eval_set,
    batch_size=24,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

In [None]:
def produce_evaluation_file(
    data_loader: DataLoader,
    model,
    device: torch.device,
    save_path: str,
    trial_path: str) -> None:
    """Score every batch from ``data_loader`` and write one line per trial.

    The model is switched to eval mode, each batch is moved to ``device`` and
    column 1 of the model output is taken as the utterance score (presumably
    the "bonafide" class — confirm against the model definition). Scores are
    matched against ``trial_path`` line-by-line and written to ``save_path``
    as ``"<utt_id> <src> <key> <score>"``.

    Args:
        data_loader: iterable yielding ``(batch_x, utt_ids)`` pairs in the same
            order as the trial file.
        model: callable module returning a 2-D tensor with at least 2 columns.
        device: device (or device string) the batches are moved to.
        save_path: output score file; only written once all checks pass.
        trial_path: protocol file with five space-separated fields per line;
            blank lines are ignored.

    Raises:
        AssertionError: if the number of trials and scored utterances differ,
            or if an utterance id does not match its trial line.
        ValueError: if a trial line does not split into exactly five fields.
    """
    model.eval()
    with open(trial_path, "r") as f_trl:
        # Blank lines carry no trial; dropping them keeps the count check honest.
        trial_lines = [line.strip() for line in f_trl if line.strip()]

    fname_list = []
    score_list = []
    with torch.no_grad():
        for batch_x, utt_id in data_loader:
            batch_out = model(batch_x.to(device))
            batch_score = batch_out[:, 1].detach().cpu().numpy().ravel()
            # add outputs
            fname_list.extend(utt_id)
            score_list.extend(batch_score.tolist())

    # Explicit raises (not `assert`) so the checks survive `python -O`.
    if not (len(trial_lines) == len(fname_list) == len(score_list)):
        raise AssertionError(
            "trial/score count mismatch: {} trials, {} file names, {} scores".format(
                len(trial_lines), len(fname_list), len(score_list)))

    # Build every output line before touching save_path so that a mismatch
    # never leaves a truncated score file behind.
    out_lines = []
    for fn, sco, trl in zip(fname_list, score_list, trial_lines):
        _, trial_utt, _, src, key = trl.split(' ')
        if fn != trial_utt:
            raise AssertionError(
                "utterance order mismatch: scored {!r} but trial line is {!r}".format(
                    fn, trial_utt))
        out_lines.append("{} {} {} {}\n".format(trial_utt, src, key, sco))

    with open(save_path, "w") as fh:
        fh.writelines(out_lines)
    print("Scores saved to {}".format(save_path))


In [None]:
# Score the evaluation set and write the per-utterance scores to eval_score_path.
produce_evaluation_file(
    data_loader=eval_loader,
    model=model,
    device=device,
    save_path=eval_score_path,
    trial_path=eval_trial_path,
)

In [None]:
from evaluation import calculate_tDCF_EER
import os

# ASV score file, given relative to args.database_path (joined by plain string
# concatenation below, so database_path must end with a path separator).
asv_scores_file = "ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt"
# Evaluate the CM scores written to eval_score_path against the ASV scores.
# The two return values are presumably the EER and the t-DCF — confirm their
# order and units in evaluation.calculate_tDCF_EER. Note they are only bound
# here, not displayed. "[output path]" is a placeholder to replace.
eval_eer, eval_tdcf = calculate_tDCF_EER(
            cm_scores_file=eval_score_path,
            asv_score_file=args.database_path + asv_scores_file,
            output_file="[output path]")