In [1]:
from run_utils import *

class Args:
    def __init__(self):
        self.database_path = "dataset/"
        self.protocols_path = 'dataset/'
        self.batch_size = 32
        self.num_epochs = 100
        self.lr = 0.0001
        self.weight_decay = 0.0001
        self.loss = "weighted_CCE"
        self.seed = 1234
        self.model_path = None
        self.comment = None
        self.track = "DF"
        self.eval_output = None
        self.eval = False
        self.is_eval = False
        self.eval_part = 0
        self.cudnn_deterministic_toggle = True
        self.cudnn_benchmark_toggle = False
        
args = Args()

dir_yaml = os.path.splitext('model_config_RawNet')[0] + '.yaml'

with open(dir_yaml, 'r') as f_yaml:
        parser1 = yaml.load(f_yaml)

if not os.path.exists('models'):
    os.mkdir('models')

#make experiment reproducible
set_random_seed(args.seed, args)

track = args.track

assert track in ['LA', 'PA','DF'], 'Invalid track given'

# Database
prefix      = 'ASVspoof_{}'.format(track)
prefix_2019 = 'ASVspoof2019.{}'.format(track)
prefix_2021 = 'ASVspoof2021.{}'.format(track)

#define model saving path
model_tag = 'model_{}_{}_{}_{}_{}'.format(
    track, args.loss, args.num_epochs, args.batch_size, args.lr)
if args.comment:
    model_tag = model_tag + '_{}'.format(args.comment)
model_save_path = os.path.join('models', model_tag)

#set model save directory
if not os.path.exists(model_save_path):
    os.mkdir(model_save_path)

if args.model_path:
    model.load_state_dict(torch.load(args.model_path,map_location=device))
    print('Model loaded : {}'.format(args.model_path))

# evaluation 
if args.eval:
    file_eval = genSpoof_list( dir_meta =  os.path.join(args.protocols_path+'{}_cm_protocols/{}.cm.eval.trl.txt'.format(prefix,prefix_2021)),is_train=False,is_eval=True)
    print('no. of eval trials',len(file_eval))
    eval_set=Dataset_ASVspoof2021_eval(list_IDs = file_eval,base_dir = os.path.join(args.database_path+'ASVspoof2021_{}_eval/'.format(args.track)))
    produce_evaluation_file(eval_set, model, device, args.eval_output)
    sys.exit(0)

  parser1 = yaml.load(f_yaml)


In [41]:
from my_model import ASVspoof2019TrillDataModule, MyModel
from tensorneko import Trainer
from torch.utils.data import DataLoader
from main import *

In [2]:
model = MyModel.load_from_checkpoint("ckpt/trill_mean_mlp_1624694162.ckpt").cuda()

In [3]:
dm = ASVspoof2019TrillDataModule(batch_size=256, num_workers=10)
dm.setup()

no. of training trials 25380
no. of validation trials 24844
no. of eval trials 611829


In [8]:
# produce_evaluation_file(dm.test_dataset, model, "cuda", "out.txt")

100%|██████████| 4780/4780 [47:18<00:00,  1.68it/s]

Scores saved to out





In [4]:
from metrics.eval_metrics import * 

In [6]:
trainer = Trainer.build(gpus=1, log_every_n_steps=100, checkpoint_callback=False,
                        logger=model.name)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [43]:
out = trainer.predict(model, dataloaders=DataLoader(dm.train_dataset, batch_size=dm.batch_size, shuffle=False))

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [60]:
torch.Tensor(1).detach()

tensor([-2.4562e-21])

In [45]:
out = torch.vstack(out)

In [47]:
pred = out.max(dim=1)[1]

In [48]:
pred

tensor([1, 1, 1,  ..., 0, 0, 0], device='cuda:0')

In [49]:
true = torch.Tensor(list(map(lambda x: x[1], dm.train_dataset))).int().cuda()

In [45]:
val_true = []
for x,y in tqdm(dm.val_dataset):
    val_true.append(y)
val_true = torch.Tensor(val_true).int().cuda()


100%|██████████| 24844/24844 [02:58<00:00, 139.11it/s]


In [50]:
corr = true.cuda() == pred.cuda()

target_scores = []
nontarget_scores = []

for i in tqdm(range(len(true))):
    if corr[i]:
        target_scores.append(out[i, 1])
    else:
        nontarget_scores.append(out[i, 1])

target_scores = torch.Tensor(target_scores).numpy()
nontarget_scores = torch.Tensor(nontarget_scores).numpy()

compute_eer(target_scores, nontarget_scores)


  0%|          | 0/25380 [00:00<?, ?it/s][A
 28%|██▊       | 7065/25380 [00:00<00:00, 70637.27it/s][A
 56%|█████▌    | 14129/25380 [00:00<00:00, 70335.90it/s][A
100%|██████████| 25380/25380 [00:00<00:00, 70278.66it/s][A


(0.8962460066073117, -2.320418119430542)

In [10]:
file_eval = genSpoof_list(dir_meta='dataset/ASVspoof_DF_cm_protocols/ASVspoof2021.DF.cm.eval.trl.txt',is_train=False,is_eval=True)
print('no. of eval trials',len(file_eval))

no. of eval trials 611829


In [17]:
eval_set = ASVspoof2019Trill(list_IDs = file_eval, labels=None, base_dir='dataset/ASVspoof2021_DF_eval/')

In [None]:
produce_evaluation_file(eval_set, model, device, args.eval_output)