In [2]:
import argparse
import sys

# Extend the import path with the parent directory. This must run before the
# imports below that are resolved through it (presumably the project-local
# modules at the bottom of this cell -- confirm).
sys.path.append('../')

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import entropy
from timm.models import create_model
from tqdm import tqdm

from defenses.victim import MAD, ReverseSigmoid, RandomNoise

from datasets import build_transform, get_dataset

import models
import utils
from utils import get_free_gpu

# Number of GPUs requested from the helper below.
num_gpus = 1
# NOTE(review): get_free_gpu is imported from the project's `utils` module; presumably
# it selects an idle GPU (and may have side effects such as restricting the visible
# devices) -- confirm. `gpu_chosen` is not read again in the visible cells.
gpu_chosen = get_free_gpu(num_gpus)
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
# NOTE(review): `device` is not passed to predict()/evaluate() in the visible cells.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

In [None]:
# Option container handed to build_transform() further down; only the input
# resolution is set here.
args = argparse.Namespace(input_size=224)

In [8]:
def get_accuracy(output, target, topk=(1,)):
    """Top-k accuracy of the score rows in ``output`` measured against ``target``.

    Args:
        output: per-sample class scores; the k best classes are taken along dim 1.
        target: either class indices (1-D) or per-class scores / one-hot rows
            (more than one dimension). In the latter case the arg-max along
            dim 1 is used as the label.
        topk: iterable of ``k`` values to report.

    Returns:
        dict with one ``"acc<k>"`` entry per requested ``k``; each value is a
        Python float giving the fraction of samples whose label is among the
        ``k`` highest-scoring classes.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Class indices of the `largest_k` best scores per sample, best first.
    _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)

    # Collapse one-hot / soft targets to hard class indices.
    if target.ndimension() > 1:
        _, target = target.max(1)

    # hits[n, j] is True when the j-th ranked class of sample n equals its label.
    hits = ranked.eq(target.view(-1, 1))

    return {
        "acc{}".format(k): hits[:, :k].float().sum().mul_(1.0 / n_samples).item()
        for k in topk
    }


def predict(model, model_defended, data_loader, device='cuda'):
    """Run both the plain and the defended model over ``data_loader``.

    ``model`` is moved to ``device`` and put in eval mode; its raw outputs are
    soft-maxed along dim 1. ``model_defended`` is called on the same batch and
    its output is stored as returned. Gradients are disabled throughout.

    Returns:
        Three CPU tensors, each concatenated over all batches:
        ``(softmax outputs of model, outputs of model_defended, labels)``.
    """
    model = model.to(device)
    model.eval()

    undefended_batches, defended_batches, label_batches = [], [], []
    with torch.no_grad():
        for inputs, targets in tqdm(data_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            probs = F.softmax(model(inputs), 1)
            undefended_batches.append(probs.to('cpu'))

            defended = model_defended(inputs)
            defended_batches.append(defended.to('cpu'))

            label_batches.append(targets.to('cpu'))

    return (
        torch.cat(undefended_batches),
        torch.cat(defended_batches),
        torch.cat(label_batches),
    )


def evaluate(model, model_defended, datasets, batch_size=100, workers=4, device=None):
    """Compare an undefended model with its defended counterpart and print summary metrics.

    For every dataset the predictions of both models are collected with
    ``predict`` and the following values are printed: top-1 accuracy of each
    model, fidelity (top-1 agreement between the two), mean entropy relative to
    the maximum ``log2(num_classes)``, mean max/min ratio of the per-class
    scores, and the mean L1 distance between the two outputs.

    Args:
        model: undefended network (soft-maxed inside ``predict``).
        model_defended: callable producing the defended per-class scores.
        datasets: a single dataset or a tuple of datasets. With several
            datasets the first two are labelled "train" and "test"; any
            further ones are labelled "split<i>".
        batch_size: batch size of the data loader.
        workers: number of data-loader worker processes.
        device: device handed to ``predict``. ``None`` (default) selects
            "cuda" when it is available and "cpu" otherwise, so the function
            also runs on machines without a GPU.

    Returns:
        dict mapping each split label ("" when a single dataset is given) to
        the tuple ``(preds_orig, preds_def, labels)`` returned by ``predict``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if not isinstance(datasets, tuple):
        datasets = (datasets, )

    split_names = ("train", "test")
    res = {}
    for i, dataset in enumerate(datasets):
        if len(datasets) == 1:
            d_type = ""
        elif i < len(split_names):
            d_type = split_names[i]
        else:
            # More datasets than known split names: fall back to a positional label
            # instead of raising IndexError.
            d_type = f"split{i}"
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=workers)

        print(f'Evaluate on {dataset.__class__.__name__} {d_type} data:')
        preds_orig, preds_def, labels = predict(model, model_defended, data_loader, device)
        num_classes = preds_def.shape[1]
        # Entropy of a uniform distribution over the classes; used to normalise into [0, 1].
        max_entropy = np.log2(num_classes)

        print(f'Results on {dataset.__class__.__name__} {d_type} data:')
        print('Accuracy original:', get_accuracy(preds_orig, labels)['acc1'])
        print('Accuracy defended:', get_accuracy(preds_def, labels)['acc1'])
        # Passing predictions as the "target" makes get_accuracy compare arg-max classes.
        print('Fidelity:', get_accuracy(preds_orig, preds_def)['acc1'])
        print('Mean relative entropy original:', np.mean(entropy(preds_orig, axis=1, base=2) / max_entropy))
        print('Mean relative entropy defended:', np.mean(entropy(preds_def, axis=1, base=2) / max_entropy))
        # NOTE: a row whose smallest score is exactly 0 yields inf for these ratios.
        print('Mean max/min original:', torch.mean(preds_orig.max(1)[0] / preds_orig.min(1)[0]).item())
        print('Mean max/min defended:', torch.mean(preds_def.max(1)[0] / preds_def.min(1)[0]).item())
        print('Mean L1 distance:', torch.mean(torch.linalg.vector_norm(preds_orig - preds_def, 1, 1)).item())
        print()
        res[d_type] = (preds_orig, preds_def, labels)
    return res

In [None]:
# Undefended model: a ResNet-34 with a 10-class head, built through timm's create_model.
model = create_model(
    'resnet34',
    num_classes=10
)
# Load the weights onto the CPU so that a checkpoint saved from a GPU also opens on a
# CPU-only machine; predict() moves the model to the evaluation device later.
# NOTE(security): torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load('checkpoints/checkpoint.pth', map_location='cpu')['model'])

# Second network with pretrained=False; no checkpoint is loaded into it in the visible
# cells. It is only referenced by the commented-out MAD defense (as model_adv_proxy).
model_adv = create_model(
    'deit_base_patch16_224',
    pretrained=False,
    num_classes=10
)

# Both transforms are built with the same arguments (presumably the non-training /
# evaluation transform -- confirm against build_transform). The result is indexed
# later (datasets[1]).
datasets = get_dataset('cifar10', train_transform=build_transform(False, args), val_transform=build_transform(False, args))

In [None]:
# Defense hyper-parameters. Exactly one `model_defended = ...` line below should be active.
# epsilon: passed to RandomNoise as epsilon_z (and to MAD as epsilon in the commented-out line).
epsilon = 50.1
# beta / gamma: only used by the commented-out ReverseSigmoid alternative.
beta = 2.0
gamma = 0.5
# Passed to every defense as out_path (presumably a log/output directory -- confirm).
output_path = './log'
# dist_z: passed to RandomNoise (presumably the norm used to measure the perturbation -- confirm).
dist_z = 'l1'
# oracle: only used by the commented-out MAD alternative.
oracle = 'argmax'

# Active defense: RandomNoise wrapped around the undefended `model`.
model_defended = RandomNoise(model=model, out_path=output_path, dist_z=dist_z, epsilon_z=epsilon)
# Alternative defenses, kept for quick switching:
#model_defended = ReverseSigmoid(model=model, out_path=output_path, beta=beta, gamma=gamma)
#model_defended = MAD(model=model, out_path=output_path, epsilon=epsilon, model_adv_proxy=model_adv, oracle=oracle)

In [None]:
# Evaluate on the second entry of `datasets` (presumably the validation/test split --
# confirm against get_dataset) using batches of 512 samples.
res = evaluate(model, model_defended, datasets[1], batch_size=512)