In [None]:
import argparse
import os
import random

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

import torch
from torch.nn.functional import softmax
import torchvision

from ganimation.config import get_config
from ganimation.data_loader import get_loader
from ganimation.psolver import Disruptor

from ganimation.attacks import LinfPGDAttack

In [None]:
torch.cuda.set_device(0)
device = torch.device("cuda")

In [None]:
parser = argparse.ArgumentParser()

# Model configuration.
parser.add_argument('--c_dim', type=int, default=17,
                    help='dimension of domain labels')
parser.add_argument('--image_size', type=int,
                    default=128, help='image resolution')
parser.add_argument('--g_conv_dim', type=int, default=64,
                    help='number of conv filters in the first layer of G')
parser.add_argument('--d_conv_dim', type=int, default=64,
                    help='number of conv filters in the first layer of D')
parser.add_argument('--g_repeat_num', type=int, default=6,
                    help='number of residual blocks in G')
parser.add_argument('--d_repeat_num', type=int, default=6,
                    help='number of strided conv layers in D')
parser.add_argument('--lambda_cls', type=float, default=160,
                    help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10,
                    help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10,
                    help='weight for gradient penalty')
parser.add_argument('--lambda_sat', type=float, default=0.1,
                    help='weight for attention saturation loss')
parser.add_argument('--lambda_smooth', type=float, default=1e-4,
                    help='weight for the attention smoothing loss')
parser.add_argument('--eps', type=float, default=0.05, help='epsilon for perturbation')
parser.add_argument('--order', type=int, default=2, help='distance metric')

# Training configuration
parser.add_argument('--seed', type=int, default=0,
                    help='seed for experiments')
parser.add_argument('--dataset', type=str, default='CelebA',
                    choices=['CelebA', 'RaFD', 'Both'])
parser.add_argument('--batch_size', type=int,
                    default=32, help='mini-batch size')
parser.add_argument('--epochs', type=int, default=30,
                    help='number of total epochs for training P')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate for G')
parser.add_argument('--beta1', type=float, default=0.99,
                    help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='beta2 for Adam optimizer')
parser.add_argument('--resume', default=False,
                    action='store_true', help='resume training from last epoch')
parser.add_argument('--alpha', type=float, default=0.1,
                    help="alpha for gradnorm")
parser.add_argument('--detector', type=str, default='xception', choices=['xception', 'resnet18', 'resnet50'])


# Miscellaneous.
parser.add_argument('--num_workers', type=int, default=48)
parser.add_argument('--mode', type=str, default='train',
                    choices=['train', 'animation'])
parser.add_argument('--disable_tensorboard',
                    action='store_true', default=False)
parser.add_argument('--num_sample_targets', type=int, default=4,
                    help="number of targets to use in the samples visualization")

# Directories.
parser.add_argument('--gen_ckpt', type=str,
                    default='ganimation/7001-37-G.ckpt')
parser.add_argument('--detector_path', type=str,
                    default='detection/detector_c23.pth')
parser.add_argument('--image_dir', type=str,
                    default='ganimation/data/celeba/images_aligned')
parser.add_argument('--attr_path', type=str,
                    default='ganimation/data/celeba/list_attr_celeba.txt')
parser.add_argument('--outputs_dir', type=str, default='experiment1')
parser.add_argument('--log_dir', type=str, default='logs')
parser.add_argument('--model_save_dir', type=str, default='models')
parser.add_argument('--sample_dir', type=str, default='samples')
parser.add_argument('--result_dir', type=str, default='results')

parser.add_argument('--animation_images_dir', type=str,
                    default='data/celeba/images_aligned/new_small')
parser.add_argument('--animation_attribute_images_dir', type=str,
                    default='animations/eric_andre/attribute_images')
parser.add_argument('--animation_attributes_path', type=str,
                    default='animations/eric_andre/attributes.txt')
parser.add_argument('--animation_models_dir', type=str,
                    default='models')
parser.add_argument('--animation_results_dir', type=str,
                    default='out')
parser.add_argument('--animation_mode', type=str, default='animate_image',
                    choices=['animate_image', 'animate_random_batch'])

# Step size.
parser.add_argument('--log_step', type=int, default=1)
parser.add_argument('--sample_step', type=int, default=1)

config = parser.parse_args(args=[])

In [29]:
config.mode = "test"
config.batch_size = 1

data_loader = get_loader(config.image_dir, config.attr_path, config.c_dim,
                        config.batch_size, config.mode, config.num_workers)
# config_dict = vars(config)
solver = Disruptor(config, data_loader).to("cuda")

Dataset ready!...
------------------------------------------------
Training images:  156295
Testing images:  100


In [14]:
P = solver.P
P.load_state_dict(torch.load("ganimation/experiments/experiment1_xception/models/best.ckpt", map_location="cuda"))

G = solver.G
D = solver.D

In [15]:
def set_seed(seed):
    """Set seed"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)

In [16]:
set_seed(0)

In [17]:
pgd_attack = LinfPGDAttack(model=G, device=device)

In [18]:
def denorm(x):
    """Convert the range from [-1, 1] to [0, 1]."""
    out = (x + 1) / 2
    return out.clamp_(0, 1)


def show_images(x):
    images = denorm(x.cpu())
    grid_img = torchvision.utils.make_grid(images, nrow=10)
    plt.figure(figsize=(15, 10))
    plt.imshow(grid_img.permute(1, 2, 0))


In [19]:
def imFromAttReg(att, reg, x_real):
    """Mixes attention, color and real images"""
    return (1-att)*reg + att*x_real
    
@torch.no_grad()
def generate(x_real, c_trg):
    att, reg = G(x_real, c_trg)
    x_fake = imFromAttReg(att, reg, x_real)
    return x_fake

def joint_class_attack(x_real, x_fake, c):
    x_adv, perturb = pgd_attack.perturb(x_real, x_fake, c)
    x_adv = x_real + perturb
    return x_adv
    
@torch.no_grad()
def detect(x):
    output = D(x)
    output = softmax(output, 1)
    prediction = output.argmin(1, keepdim=False)
    return prediction


@torch.no_grad()
def perturb(x):
    return P(x) + x

In [20]:
def get_l2_distance(x_fake, xp_fake, num_imgs=100):
    x_fake = x_fake.view(num_imgs, -1)
    xp_fake = xp_fake.view(num_imgs, -1)
    return torch.linalg.norm(x_fake - xp_fake, dim=1, ord=2)


In [21]:
def get_metrics(x, c):
    xp = perturb(x)
    x_fake = generate(x, c)
    xp_fake = generate(xp, c)
    x_adv = joint_class_attack(x, x_fake, c)
    x_adv_fake = generate(x_adv, c)

    predicted_real = detect(x).cpu().numpy()
    # predicted_real_p = detect(xp).cpu().numpy()

    predicted_fake = detect(x_fake).cpu().numpy()
    predicted_fake_p = detect(xp_fake).cpu().numpy()
    predicted_fake_adv = detect(x_adv_fake).cpu().numpy()

    y_pred = np.hstack((predicted_real, predicted_fake))
    # yp_pred = np.hstack((predicted_real_p, predicted_fake_p))
    yp_pred = np.hstack((predicted_real, predicted_fake_p))
    yadv_pred = np.hstack((predicted_real, predicted_fake_adv))
    y_true = np.hstack((np.ones(100), np.zeros(100)))

    report_ganimation = classification_report(
        y_true, y_pred, target_names=["fake", "real"])
    report_pgd = classification_report(
        y_true, yadv_pred, target_names=["fake", "real"])
    report_disruptor = classification_report(
        y_true, yp_pred, target_names=["fake", "real"])

    return report_ganimation, report_pgd, report_disruptor


In [None]:
x, c_org = iter(data_loader).__next__()

In [None]:
x = x.cuda()
c_org = c_org.cuda()

In [None]:
c_trg = c_org[torch.randperm(c_org.size(0))].cuda()

In [None]:
report_ganimation, report_pgd, report_disruptor = get_metrics(x, c_trg)


In [None]:
print(report_ganimation)

In [None]:
print(report_pgd)

In [None]:
print(report_disruptor)

In [None]:
show_images(x)

In [None]:
with torch.no_grad():
    show_images(generate(x, c_trg))

In [None]:
with torch.no_grad():
    show_images(perturb)

In [22]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)


In [23]:
P_ev = P.eval()


In [24]:
ds = data_loader.dataset

In [25]:
len(ds)


100

In [30]:
x_100 = torch.zeros((100, 1, 3, 128, 128), device="cuda")
c_100 = torch.zeros((100, 1, 17), device="cuda")
x_fake = torch.zeros((100, 1, 3, 128, 128), device="cuda")

for i, (x, _) in enumerate(data_loader):
    if i == 100:
        break
    x_100[i, 0] = x.cuda()
    idx = torch.randint(low=0, high=len(ds), size=(1,)).item()
    c = ds[idx][1].cuda()
    c_100[i, 0] = c
    x_fake[i, 0] = generate(x.cuda(), c.unsqueeze(0))
    

In [31]:
start.record()
for i in range(100):
    P(x_100[i])
end.record()

# Waits for everything to finish running
torch.cuda.synchronize()

print(start.elapsed_time(end) / 100 / 1000)


0.0078297802734375


In [32]:
start.record()
for i in range(100):
    pgd_attack.perturb(x_100[i], x_fake[i], c_100[i])
end.record()

# Waits for everything to finish running
torch.cuda.synchronize()

print(start.elapsed_time(end) / 100 / 1000) 

0.23577623046875
