In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
import random

from dataloader import get_data
from utils import *
from config import params
from models.texbat_model import Generator, Discriminator, DHead, QHead, CHead


# 加载训练好的模型

In [2]:
# Set random seed for reproducibility.
seed = 1123
random.seed(seed)
torch.manual_seed(seed)
print("Random Seed: ", seed)

# Use GPU if available.
device = torch.device("cuda:1" if(torch.cuda.is_available()) else "cpu")
print(device, " will be used.\n")

#load model
log = torch.load("/home/yhang/GAN/InfoGAN-PyTorch/checkpoint/classifier/model_epoch_100_TEXBAT_learningrate_0")
print(log["params"])

#load data
dataloader = get_data("TEXBAT", params["batch_size"], "test")
# the length of the dataloader
print("Length of dataloader: ", len(dataloader))

Random Seed:  1123
cuda:1  will be used.

{'batch_size': 128, 'num_epochs': 100, 'D_learning_rate': 0.0002, 'G_learning_rate': 0.0002, 'beta1': 0.5, 'beta2': 0.999, 'save_epoch': 25, 'dataset': 'TEXBAT', 'num_z': 256, 'num_dis_c': 1, 'dis_c_dim': 3, 'num_con_c': 0}
Length of dataloader:  527


In [3]:
discriminator = Discriminator().to(device)
discriminator.load_state_dict(log["discriminator"])
netQ = QHead().to(device)
netQ.load_state_dict(log["netQ"])

<All keys matched successfully>

In [4]:
print("-"*25)
print("Starting Testing Loop...\n")
print("-"*25)

total = 0
correct = 0
spoof_predict = 0
spoof_data = 0
num_clean_predict = 0   # 预测正确的clean类别数量
num_clean_data = 0      # 数据集中的clean数据数量
num_ds3_predict = 0     # 预测正确的ds3类别数量
num_ds3_data = 0        # 数据集的ds3类别数量
num_ds4_predict = 0     # 预测正确的ds4类别数量
num_ds4_data = 0        # 数据集中的ds4数据数量
# num_ds8_predict = 0     # 预测正确的ds8类别数量
# num_ds8_data = 0        # 数据集中的ds8数据数量

# 在评估模式下，我们不需要计算梯度
with torch.no_grad():  
    for data, labels in dataloader:
        data, labels = data.to(device), labels.to(device)
        b_size = data.size(0)
        # Real data Loss
        output1 = discriminator(data)
        probs_real, q_mu, q_var = netQ(output1)
        target = torch.tensor(labels.view(1, b_size), dtype=torch.long).squeeze(0)
        # 取probs_realz中最大值的索引作为预测类别
        _, predicted = torch.max(probs_real, 1)
        total += b_size
        # correct += (predicted == target).sum().item()
        for i in range(b_size):
            if target[i] == 0:
                num_clean_data += 1
                if predicted[i] == 0:
                    num_clean_predict += 1
            elif target[i] == 1:
                num_ds3_data += 1
                spoof_data += 1
                if predicted[i] == 1:
                    num_ds3_predict += 1
                    spoof_predict += 1
            elif target[i] == 2:
                num_ds4_data += 1
                spoof_data += 1
                if predicted[i] == 2:
                    spoof_predict += 1
                    num_ds4_predict += 1
            if(predicted[i] not in [0, 1, 2]):
                print(predicted[i])
            
# print("Accuracy: ", correct/total)
print("Clean Accuracy: ", num_clean_predict/num_clean_data)   
print("ds3 Accuracy: ", num_ds3_predict/num_ds3_data)   
print("ds4 Accuracy: ", num_ds4_predict/num_ds4_data)   
print("Detect Accuracy: ", spoof_predict/spoof_data)            
print('spoof' , spoof_data)
print('spoof_predict', spoof_predict)
print('num_clean_data', num_clean_data)
print('num_clean_predict', num_clean_predict)
print("total: ", total)     
        

-------------------------
Starting Testing Loop...

-------------------------


  target = torch.tensor(labels.view(1, b_size), dtype=torch.long).squeeze(0)


Clean Accuracy:  0.9998656816655473
ds3 Accuracy:  0.5677668695304517
ds4 Accuracy:  0.565546218487395
Detect Accuracy:  0.5666541078433546
spoof 45121
spoof_predict 25568
num_clean_data 22335
num_clean_predict 22332
total:  67456
