In [2]:
target = "MSU-MFSD" # change to argument  # NOTE(review): not read by the cells below — test_resnet uses args.target instead; TODO remove or wire into args
from torch.utils.data import Dataset, DataLoader
from methods import *

In [None]:
def test_resnet(args, model_size, model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if model_size == 'resnet18':
        model = models.resnet18()
    elif model_size == 'resnet50':
        model = models.resnet50()
    elif model_size == 'resnet101':
        model = models.resnet101()
        
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 1),
        nn.Sigmoid()
    )
    
    model.load_state_dict(torch.load(model_path))
    
    model = model.to(device)
    
    test_transform = transforms()['test']

    test_set = get_datasets(args.data_dir, FaceDataset, train=False, target=args.target, transform=test_transform, model_name='resnet')
    test_loader = DataLoader(test_set[args.target], batch_size=args.batch_size, shuffle=False, num_workers=4)
    
    criterion = nn.BCELoss()
    
    test_loss, auc, hter, apcer, bpcer = compute_metrics(model, test_loader, criterion, device)  
    
    print(f"Test on {args.target}\nAUC: {auc:.5f}\n  HTER: {hter:.5f}\n  APCER: {apcer:.5f}\n  BPCER: {bpcer:.5f}")

In [None]:
# Run the evaluation routine that matches the selected method.
# NOTE(review): `args` is not created in any visible cell — on a fresh kernel
# this cell raises NameError unless `args` (e.g. from argparse) is defined first.
# test_safas is presumably provided by `methods`; confirm.
if args.method.startswith('resnet'):
    test_resnet(args, model_size=args.method, model_path=args.model_path)
elif args.method == 'safas':
    test_safas(args, model_path=args.model_path)
else:
    # Previously an unrecognised method silently ran nothing.
    raise ValueError(f"Unknown method {args.method!r}; expected 'resnet*' or 'safas'")