In [1]:
import argparse
import json
import os
import sys
import warnings
from importlib import import_module
from pathlib import Path
from shutil import copy
from typing import Dict, List, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA

from data_utils import TestDataset, TrainDataset, genSpoof_list
from eval.calculate_metrics import (calculate_aDCF_tdcf_tEER,
                                    calculate_minDCF_EER_CLLR)
from utils import create_optimizer, seed_worker, set_seed, str_to_bool

warnings.filterwarnings("ignore", category=FutureWarning)
from tqdm import tqdm

In [2]:
def get_model(model_config: Dict, device: torch.device):
    """Define DNN model architecture"""
    module = import_module("models.{}".format(model_config["architecture"]))
    _model = getattr(module, "Model")
    model = _model(model_config).to(device)
    nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
    print(f"no. model params:{(nb_params / 1000):.3f}k")

    return model

In [3]:
with open("./config/SEMAA_2021.conf", "r") as f_json:
    config = json.loads(f_json.read())
model_config = config["model_config"]

In [4]:
model = get_model(model_config, "cpu")

no. model params:341.034k


In [5]:
state_dict = torch.load("exp_result/SEMAA_2021_ep100_bs24_rawboost/weights/epoch_0_0.013.pth")

In [6]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
import torchaudio

In [46]:
path = "/data/a.varlamov/LJSpeech-1.1/wavs/LJ001-0011.wav"
audio, sr = torchaudio.load(path)

audio = audio[:, :4 * sr]

In [11]:
batch = torch.cat([audio] * 32)

#### CPU:

In [51]:
%%time
for elem in batch:
    model(elem.unsqueeze(0))

CPU times: user 8min 45s, sys: 2min 14s, total: 11min
Wall time: 5.57 s


In [52]:
%%time

model(batch)

CPU times: user 1min 20s, sys: 31.2 s, total: 1min 51s
Wall time: 985 ms


(tensor([[ 0.0000,  0.0000,  0.8879,  ..., -0.0000, -0.0000, -0.4481],
         [ 0.8185,  0.5034,  0.0000,  ...,  0.0000, -1.0446,  0.0000],
         [ 0.4865,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.6985],
         ...,
         [ 0.6315,  0.0000,  1.0028,  ...,  0.0000, -0.0000, -0.0000],
         [ 2.0463,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.5176],
         [ 0.1737,  0.0000,  0.0000,  ...,  0.0000, -1.3415, -0.1650]],
        grad_fn=<MulBackward0>),
 tensor([[ 1.6370, -2.8142],
         [ 1.6689, -1.4014],
         [ 2.7145, -1.4595],
         [ 2.0694, -2.4835],
         [-0.5864, -0.3252],
         [ 1.3176, -2.1768],
         [ 1.3508, -1.1116],
         [ 2.2993, -2.4627],
         [ 3.7679, -2.9539],
         [ 0.9321, -0.9999],
         [ 2.3361, -2.4761],
         [ 1.9903, -1.3204],
         [ 1.6491, -1.5780],
         [ 1.3913, -1.2798],
         [ 3.3301, -2.9962],
         [ 1.2729, -1.0937],
         [ 1.4748, -1.5431],
         [ 2.7387, -3.4770],
    

In [55]:
device = "cuda:5"

In [58]:
model = model.to(device)
batch = batch.to(device)

In [59]:
%%time
for elem in batch:
    model(elem.unsqueeze(0))

CPU times: user 672 ms, sys: 174 ms, total: 846 ms
Wall time: 857 ms


In [60]:
%%time

model(batch)

CPU times: user 103 ms, sys: 4.74 ms, total: 107 ms
Wall time: 106 ms


(tensor([[ 1.2631,  0.4138,  0.0000,  ...,  0.0000, -0.0000,  0.4021],
         [ 0.4781,  0.0000,  0.0000,  ..., -0.0465, -0.0000, -0.0000],
         [ 0.0000,  0.4314,  2.2979,  ...,  0.3793, -1.2091,  0.0000],
         ...,
         [ 0.0000,  0.8540,  0.0000,  ..., -0.0000, -1.4942,  0.3136],
         [ 0.0000,  0.0000,  0.9595,  ..., -0.0000, -0.0000,  0.3111],
         [ 1.2072,  0.0000,  2.5564,  ..., -0.0000, -0.0000, -0.7329]],
        device='cuda:5', grad_fn=<MulBackward0>),
 tensor([[ 1.6735,  0.1944],
         [ 1.5841, -1.7493],
         [ 1.7077, -2.1179],
         [ 1.1154, -0.9293],
         [ 2.7034, -2.4911],
         [ 0.4057, -2.7683],
         [ 0.3000, -0.4473],
         [ 1.8537, -2.3073],
         [ 4.0154, -3.2347],
         [ 1.2749, -1.2102],
         [ 2.0555, -2.0848],
         [ 3.5232, -5.7168],
         [ 1.8756, -2.3972],
         [ 0.6575, -0.7447],
         [ 2.3530, -3.0632],
         [ 1.1080, -2.1790],
         [ 2.8125, -1.1593],
         [ 0.752