<a href="https://colab.research.google.com/github/agtushar/aml-robust-learning/blob/main/ADV_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data

import torchvision.utils
from torchvision import models
import torchvision.datasets as datasets
import torchvision.transforms as transforms


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


# CONFIG

In [None]:
# Shared Drive folder for this project: MNIST is downloaded here and the
# "models" sub-folder holds checkpoints that are loaded and saved.
root_dir = "/content/drive/Shareddrives/advattacks"

# ---- Model configuration ----
model_dir = os.path.join(root_dir, "models")

# ---- Training configuration ----
# NOTE(review): random_seed is not passed to any RNG in the cells shown here,
# so runs are not reproducible as-is.
random_seed = 4557077
max_num_training_steps = 100_000
num_output_steps = 100
num_summary_steps = 100
num_checkpoint_steps = 300
training_batch_size = 50     # batch size of train_dl

# ---- Evaluation configuration ----
num_eval_examples = 10_000
eval_batch_size = 200        # batch size of eval_dl

# ---- Adversarial-example configuration ----
epsilon = 0.3                # L-inf budget (the AdvGAN cells clamp to 0.3 as well)
iters = 40
alpha = 0.01
random_start = True
loss_func = "xent"
store_adv_path = "attack.npy"

# PREPARE DATA

In [None]:
# Mount Google Drive (Colab-only API). root_dir points under /content/drive,
# so the dataset download and checkpoint load/save cells need this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# MNIST train/test splits; ToTensor turns each PIL image into a float tensor
# in [0, 1]. Files are downloaded into root_dir on first use.
train_set = datasets.MNIST(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)
eval_set = datasets.MNIST(root=root_dir, train=False, transform=transforms.ToTensor(), download=True)

# Only the training stream is shuffled; evaluation order stays fixed.
train_dl = Data.DataLoader(train_set, batch_size=training_batch_size, shuffle=True)
eval_dl = Data.DataLoader(eval_set, batch_size=eval_batch_size)

len(train_set), len(eval_set)

60000

# MODEL

In [None]:
class CONVNET(torch.nn.Module):
    """Two-conv / two-linear MNIST classifier that returns log-probabilities.

    No layer pads its input, so a 1x28x28 image flows as
    1x28x28 -conv5-> 32x24x24 -pool2-> 32x12x12 -conv5-> 64x8x8 -pool2->
    64x4x4 -flatten-> 1024 -> full1 -> 1024 -> full2 -> 10 -> log_softmax.

    NOTE(review): there is no activation between ``full1`` and ``full2``, so
    the two linear layers compose to a single affine map. It is left as-is
    because any checkpoint trained with this class depends on this exact
    forward pass.

    The attribute names are the ``state_dict`` keys; keep them stable.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(5, 5), bias=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(5, 5), bias=True)
        self.full1 = nn.Linear(64 * 4 * 4, 1024, bias=True)
        self.full2 = nn.Linear(1024, 10, bias=True)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        h = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=(2, 2))
        h = F.max_pool2d(F.relu(self.conv2(h)), kernel_size=(2, 2))
        h = h.reshape(h.size(0), -1)
        scores = self.full2(self.full1(h))
        return F.log_softmax(scores, dim=1)


# Shape smoke test: 64 random 28x28 images should give 64 rows of 10 log-probs.
x = torch.rand(64, 1, 28, 28)
examplenet = CONVNET()
examplenet(x).size()

torch.Size([64, 10])

In [None]:
class Discriminator(nn.Module):
    """AdvGAN discriminator: maps each image to a "real" probability.

    For 1x28x28 (MNIST) inputs, the unpadded stride-2 convolutions shrink
    the feature map as 28 -> 13 -> 5 -> 1. A final 1x1 conv and a sigmoid
    then yield one score in [0, 1] per image.

    Args:
        image_nc: number of channels of the input images.
    """

    def __init__(self, image_nc):
        super(Discriminator, self).__init__()
        model = [
            # MNIST: image_nc*28*28
            nn.Conv2d(image_nc, 8, kernel_size=4, stride=2, padding=0, bias=True),
            nn.LeakyReLU(0.2),
            # 8*13*13
            nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2),
            # 16*5*5
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            # 32*1*1
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid(),
            # 1*1*1
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return per-image scores; shape (batch,) for 28x28 inputs."""
        output = self.model(x)
        if output.dim() == 4:
            # Squeeze only the channel/spatial axes. A bare squeeze() would
            # also drop the batch axis for a batch of one and return a 0-d
            # tensor.
            for dim in (3, 2, 1):
                output = output.squeeze(dim)
            return output
        # Non-batched input: keep the previous behavior.
        return output.squeeze()


In [None]:
class Generator(nn.Module):
    """AdvGAN perturbation generator: encoder -> 4 residual blocks -> decoder.

    With a 28x28 input, the spatial size goes 28 -> 26 -> 12 -> 5 through
    the unpadded encoder. Four 32-channel residual blocks run at 5x5. The
    transposed-conv decoder then goes 5 -> 11 -> 23 -> 28. A final Tanh
    bounds the output to [-1, 1]; the AdvGAN cells then clamp it further
    to [-0.3, 0.3].

    Args:
        gen_input_nc: channels of the generator input.
        image_nc: channels of the produced perturbation.
    """

    def __init__(self,
                 gen_input_nc,
                 image_nc,
                 ):
        super().__init__()

        def down(c_in, c_out, stride):
            # Unpadded 3x3 conv (with bias) -> instance norm -> ReLU.
            return [nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=0, bias=True),
                    nn.InstanceNorm2d(c_out),
                    nn.ReLU()]

        def up(c_in, c_out):
            # Unpadded stride-2 3x3 transposed conv (no bias) -> instance norm -> ReLU.
            return [nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=0, bias=False),
                    nn.InstanceNorm2d(c_out),
                    nn.ReLU()]

        # Layers are created in the same order as before (encoder, residual
        # blocks, decoder), so default initialisation draws from the RNG
        # identically and the state_dict keys are unchanged.
        self.encoder = nn.Sequential(
            *down(gen_input_nc, 8, 1),   # -> 8 x 26 x 26
            *down(8, 16, 2),             # -> 16 x 12 x 12
            *down(16, 32, 2),            # -> 32 x 5 x 5
        )
        self.bottle_neck = nn.Sequential(*[ResnetBlock(32) for _ in range(4)])
        self.decoder = nn.Sequential(
            *up(32, 16),                 # -> 16 x 11 x 11
            *up(16, 8),                  # -> 8 x 23 x 23
            nn.ConvTranspose2d(8, image_nc, kernel_size=6, stride=1, padding=0, bias=False),
            nn.Tanh(),                   # -> image_nc x 28 x 28
        )

    def forward(self, x):
        """Run encoder, residual bottleneck and decoder in sequence."""
        for stage in (self.encoder, self.bottle_neck, self.decoder):
            x = stage(x)
        return x

In [None]:
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)`` with two 3x3 convs.

    Each convolution is preceded by padding of 1 and followed by
    ``norm_layer(dim)``. The padding is either a reflection/replication layer
    or zero padding built into the conv. Only the first convolution is
    followed by an in-place ReLU and, optionally, ``Dropout(0.5)``. Because
    of the padding, the spatial size is preserved and the residual sum is
    well-defined.

    Args:
        dim: channel count of both input and output.
        padding_type: 'reflect', 'replicate' or 'zero'.
        norm_layer: normalization constructor, called as ``norm_layer(dim)``.
        use_dropout: insert ``nn.Dropout(0.5)`` after the first ReLU.
        use_bias: whether the convolutions carry a bias term.

    Raises:
        NotImplementedError: if ``padding_type`` is not one of the above.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return (padding layers to prepend, padding argument for the conv)."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the pad/conv/norm(/ReLU/dropout) stack as an ``nn.Sequential``."""
        layers = []
        for stage in range(2):
            pad_layers, conv_padding = self._padding(padding_type)
            layers.extend(pad_layers)
            layers.extend([nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias),
                           norm_layer(dim)])
            if stage == 0:
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Add the residual branch to the identity path."""
        return x + self.conv_block(x)

# ADV-GAN

In [None]:
from tqdm.notebook import tqdm


def weights_init(m):
    """DCGAN-style initializer meant for ``module.apply(weights_init)``.

    - Modules whose class name contains ``'Conv'`` (Conv2d, ConvTranspose2d,
      ...) get weight ~ N(0, 0.02); their bias, if any, is left untouched.
    - Modules whose class name contains ``'BatchNorm'`` get weight ~ N(1, 0.02)
      and bias = 0.
    - All other modules are left unchanged.

    A matched module without the parameter is skipped instead of raising
    AttributeError, e.g. ``BatchNorm2d(affine=False)``, whose weight and bias
    are None.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            # nn.init works on the Parameter directly under no_grad; .data is not needed.
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


class AdvGAN_Attack:
    """AdvGAN attack (Xiao et al., 2018) against a fixed target classifier.

    A ``Generator`` maps each image to a perturbation. The perturbation is
    clamped to ``[-epsilon, epsilon]`` and added to the image, and the sum is
    clamped to ``[box_min, box_max]``. A ``Discriminator`` is trained with a
    least-squares GAN loss to tell clean from adversarial images. The
    generator is trained to:

    - fool the discriminator;
    - keep the L2 norm of its raw output small;
    - push the target model's probability of the true class below that of
      the most likely other class (a C&W-style margin on softmax
      probabilities).

    Args:
        device: torch device for the generator, discriminator and one-hot labels.
        model: target classifier, called on adversarial images. It is
            expected to return per-class scores of shape
            (batch, model_num_labels), either logits or log-probabilities.
            Its parameters are never stepped here. NOTE(review): gradients
            still accumulate into them unless the caller freezes it.
        model_num_labels: number of classes of ``model``.
        image_nc: number of image channels.
        box_min, box_max: valid pixel range of the images.
        epsilon: L-inf bound applied to the generator output (default 0.3).
    """

    def __init__(self,
                 device,
                 model,
                 model_num_labels,
                 image_nc,
                 box_min,
                 box_max,
                 epsilon=0.3):
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = image_nc
        self.box_min = box_min
        self.box_max = box_max
        self.epsilon = epsilon

        self.gen_input_nc = image_nc
        self.netG = Generator(self.gen_input_nc, image_nc).to(device)
        self.netDisc = Discriminator(image_nc).to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # Initial optimizers; train() sets the learning rate it actually uses.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.001)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(), lr=0.001)

    def train_batch(self, x, labels):
        """Run one discriminator update and one generator update on a batch.

        Args:
            x: clean images, assumed (batch, image_nc, H, W) within
                [box_min, box_max]; H = W = 28 for the MNIST nets here.
            labels: integer class labels of shape (batch,).

        Returns:
            Python floats ``(loss_D, loss_G_fake, loss_perturb, loss_adv)``.
        """
        perturbation = self.netG(x)

        # Clipping trick: bound the perturbation, then keep pixels in the box.
        adv_images = torch.clamp(perturbation, -self.epsilon, self.epsilon) + x
        adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

        # ---- Discriminator step: clean -> 1, adversarial -> 0 ----
        self.optimizer_D.zero_grad()
        pred_real = self.netDisc(x)
        loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real))
        loss_D_real.backward()

        # detach(): this step must not send gradients into the generator.
        pred_fake = self.netDisc(adv_images.detach())
        loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake))
        loss_D_fake.backward()
        loss_D_GAN = loss_D_fake + loss_D_real
        self.optimizer_D.step()

        # ---- Generator step ----
        self.optimizer_G.zero_grad()

        # GAN term: make D score the adversarial images as clean. It is
        # back-propagated on its own; retain_graph keeps the generator graph
        # alive for the second backward below.
        pred_fake = self.netDisc(adv_images)
        loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake))
        loss_G_fake.backward(retain_graph=True)

        # Mean per-image L2 norm of the raw (unclamped) generator output.
        loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))

        # Margin loss on target-model probabilities. softmax is shift-invariant,
        # so this is the same whether the model returns logits or log_softmax.
        logits_model = self.model(adv_images)
        probs_model = F.softmax(logits_model, dim=1)
        onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels]
        real = torch.sum(onehot_labels * probs_model, dim=1)
        # Subtracting 10000 at the true class excludes it from the max.
        other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
        zeros = torch.zeros_like(other)
        loss_adv = torch.sum(torch.max(real - other, zeros))

        adv_lambda = 10
        pert_lambda = 1
        loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
        loss_G.backward()
        self.optimizer_G.step()

        return loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(), loss_adv.item()

    def train(self, train_dataloader, epochs, lr=0.0001, save_every=20, save_dir=None):
        """Train the generator and discriminator for ``epochs`` passes.

        Args:
            train_dataloader: iterable of ``(images, labels)`` batches.
            epochs: number of passes over ``train_dataloader``.
            lr: learning rate for both Adam optimizers (default 1e-4).
            save_every: pickle the whole generator module with ``torch.save``
                every this many epochs; 0 or None disables saving.
            save_dir: folder for ``ADVGAN_VS_PGD_netG_epoch_<epoch>.pth``.
                Defaults to the notebook-level ``model_dir``.
        """
        # Set the rate once instead of rebuilding the optimizers each epoch.
        # Rebuilding silently discarded Adam's moment estimates every epoch.
        for optimizer in (self.optimizer_G, self.optimizer_D):
            for group in optimizer.param_groups:
                group['lr'] = lr

        # Inference elsewhere (e.g. netG.eval() for plotting) may have left
        # the nets in eval mode.
        self.netG.train()
        self.netDisc.train()

        for epoch in range(1, epochs + 1):
            loss_D_sum = 0.0
            loss_G_fake_sum = 0.0
            loss_perturb_sum = 0.0
            loss_adv_sum = 0.0
            num_batch = 0
            for images, labels in tqdm(train_dataloader):
                images, labels = images.to(self.device), labels.to(self.device)

                loss_D_batch, loss_G_fake_batch, loss_perturb_batch, loss_adv_batch = self.train_batch(images, labels)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch
                num_batch += 1

            # Per-epoch averages. Guard against an empty loader.
            denom = max(num_batch, 1)
            print("epoch %d:\nloss_D: %.3f, loss_G_fake: %.3f,\n"
                  "loss_perturb: %.3f, loss_adv: %.3f, \n" %
                  (epoch, loss_D_sum / denom, loss_G_fake_sum / denom,
                   loss_perturb_sum / denom, loss_adv_sum / denom))

            # save generator
            if save_every and epoch % save_every == 0:
                target_dir = model_dir if save_dir is None else save_dir
                netG_file_name = os.path.join(target_dir, 'ADVGAN_VS_PGD_netG_epoch_' + str(epoch) + '.pth')
                torch.save(self.netG, netG_file_name)

# TRAINING

In [None]:
# Attack hyper-parameters and the fixed target classifier.
image_nc = 1
epochs = 60
batch_size = 128  # NOTE(review): unused in the cells shown; batches come from train_dl (training_batch_size)
BOX_MIN = 0
BOX_MAX = 1

pretrained_model = os.path.join(model_dir, "adv-res.pth")
# The checkpoint holds a whole pickled module (it is .eval()'d directly).
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
# NOTE: on PyTorch >= 2.6, torch.load defaults to weights_only=True and will
# refuse a pickled module; pass weights_only=False there, and only for
# trusted files.
targeted_model = torch.load(pretrained_model, map_location=device)
targeted_model.eval()
# The target is only attacked, never trained. Freezing its parameters keeps
# the generator's backward pass from piling up unused gradients in them.
# Input gradients still flow through it.
targeted_model.requires_grad_(False)
model_num_labels = 10

advGAN = AdvGAN_Attack(device,
                       targeted_model,
                       model_num_labels,
                       image_nc,
                       BOX_MIN,
                       BOX_MAX)

advGAN.train(train_dl, epochs)

HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 1:
loss_D: 0.489, loss_G_fake: 0.257,             
loss_perturb: 5.111, loss_adv: 47.925, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 2:
loss_D: 0.378, loss_G_fake: 0.329,             
loss_perturb: 6.710, loss_adv: 46.783, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 3:
loss_D: 0.238, loss_G_fake: 0.450,             
loss_perturb: 7.807, loss_adv: 45.957, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 4:
loss_D: 0.142, loss_G_fake: 0.571,             
loss_perturb: 8.325, loss_adv: 45.366, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 5:
loss_D: 0.085, loss_G_fake: 0.668,             
loss_perturb: 8.431, loss_adv: 44.978, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 6:
loss_D: 0.053, loss_G_fake: 0.743,             
loss_perturb: 8.552, loss_adv: 44.660, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 7:
loss_D: 0.037, loss_G_fake: 0.792,             
loss_perturb: 8.681, loss_adv: 44.414, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 8:
loss_D: 0.028, loss_G_fake: 0.823,             
loss_perturb: 8.833, loss_adv: 44.190, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 9:
loss_D: 0.021, loss_G_fake: 0.852,             
loss_perturb: 8.762, loss_adv: 43.998, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 10:
loss_D: 0.017, loss_G_fake: 0.872,             
loss_perturb: 8.787, loss_adv: 43.801, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 11:
loss_D: 0.014, loss_G_fake: 0.889,             
loss_perturb: 8.824, loss_adv: 43.655, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 12:
loss_D: 0.011, loss_G_fake: 0.903,             
loss_perturb: 8.861, loss_adv: 43.482, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 13:
loss_D: 0.008, loss_G_fake: 0.917,             
loss_perturb: 8.848, loss_adv: 43.379, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 14:
loss_D: 0.006, loss_G_fake: 0.929,             
loss_perturb: 8.997, loss_adv: 43.250, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 15:
loss_D: 0.005, loss_G_fake: 0.939,             
loss_perturb: 9.028, loss_adv: 43.116, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 16:
loss_D: 0.005, loss_G_fake: 0.943,             
loss_perturb: 8.886, loss_adv: 43.038, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 17:
loss_D: 0.004, loss_G_fake: 0.949,             
loss_perturb: 8.957, loss_adv: 42.912, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 18:
loss_D: 0.003, loss_G_fake: 0.953,             
loss_perturb: 8.935, loss_adv: 42.799, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 19:
loss_D: 0.003, loss_G_fake: 0.959,             
loss_perturb: 9.001, loss_adv: 42.704, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 20:
loss_D: 0.002, loss_G_fake: 0.961,             
loss_perturb: 8.914, loss_adv: 42.616, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 21:
loss_D: 0.002, loss_G_fake: 0.965,             
loss_perturb: 8.989, loss_adv: 42.535, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 22:
loss_D: 0.002, loss_G_fake: 0.967,             
loss_perturb: 8.981, loss_adv: 42.475, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 23:
loss_D: 0.002, loss_G_fake: 0.969,             
loss_perturb: 9.110, loss_adv: 42.400, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 24:
loss_D: 0.002, loss_G_fake: 0.972,             
loss_perturb: 8.996, loss_adv: 42.280, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 25:
loss_D: 0.002, loss_G_fake: 0.970,             
loss_perturb: 9.065, loss_adv: 42.238, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 26:
loss_D: 0.001, loss_G_fake: 0.974,             
loss_perturb: 9.059, loss_adv: 42.132, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 27:
loss_D: 0.001, loss_G_fake: 0.976,             
loss_perturb: 9.043, loss_adv: 42.072, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 28:
loss_D: 0.001, loss_G_fake: 0.979,             
loss_perturb: 9.030, loss_adv: 42.031, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 29:
loss_D: 0.001, loss_G_fake: 0.980,             
loss_perturb: 9.079, loss_adv: 41.943, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 30:
loss_D: 0.001, loss_G_fake: 0.978,             
loss_perturb: 9.016, loss_adv: 41.884, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 31:
loss_D: 0.001, loss_G_fake: 0.980,             
loss_perturb: 9.092, loss_adv: 41.825, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 32:
loss_D: 0.001, loss_G_fake: 0.981,             
loss_perturb: 9.012, loss_adv: 41.782, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 33:
loss_D: 0.001, loss_G_fake: 0.980,             
loss_perturb: 9.055, loss_adv: 41.753, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 34:
loss_D: 0.001, loss_G_fake: 0.983,             
loss_perturb: 9.073, loss_adv: 41.690, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 35:
loss_D: 0.001, loss_G_fake: 0.984,             
loss_perturb: 9.061, loss_adv: 41.672, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 36:
loss_D: 0.001, loss_G_fake: 0.985,             
loss_perturb: 9.067, loss_adv: 41.611, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 37:
loss_D: 0.001, loss_G_fake: 0.985,             
loss_perturb: 9.143, loss_adv: 41.536, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 38:
loss_D: 0.000, loss_G_fake: 0.986,             
loss_perturb: 9.054, loss_adv: 41.514, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 39:
loss_D: 0.001, loss_G_fake: 0.986,             
loss_perturb: 9.146, loss_adv: 41.453, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 40:
loss_D: 0.000, loss_G_fake: 0.986,             
loss_perturb: 9.123, loss_adv: 41.383, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 41:
loss_D: 0.000, loss_G_fake: 0.988,             
loss_perturb: 9.146, loss_adv: 41.342, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 42:
loss_D: 0.001, loss_G_fake: 0.986,             
loss_perturb: 9.159, loss_adv: 41.275, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 43:
loss_D: 0.001, loss_G_fake: 0.987,             
loss_perturb: 9.168, loss_adv: 41.299, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 44:
loss_D: 0.000, loss_G_fake: 0.988,             
loss_perturb: 9.224, loss_adv: 41.242, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 45:
loss_D: 0.000, loss_G_fake: 0.989,             
loss_perturb: 9.128, loss_adv: 41.174, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 46:
loss_D: 0.000, loss_G_fake: 0.989,             
loss_perturb: 9.104, loss_adv: 41.166, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 47:
loss_D: 0.000, loss_G_fake: 0.989,             
loss_perturb: 9.068, loss_adv: 41.128, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 48:
loss_D: 0.000, loss_G_fake: 0.989,             
loss_perturb: 9.212, loss_adv: 41.084, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 49:
loss_D: 0.000, loss_G_fake: 0.990,             
loss_perturb: 9.233, loss_adv: 41.051, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 50:
loss_D: 0.000, loss_G_fake: 0.989,             
loss_perturb: 9.189, loss_adv: 41.034, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 51:
loss_D: 0.000, loss_G_fake: 0.990,             
loss_perturb: 9.196, loss_adv: 40.988, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 52:
loss_D: 0.000, loss_G_fake: 0.991,             
loss_perturb: 9.212, loss_adv: 40.952, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 53:
loss_D: 0.000, loss_G_fake: 0.991,             
loss_perturb: 9.169, loss_adv: 40.930, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 54:
loss_D: 0.000, loss_G_fake: 0.991,             
loss_perturb: 9.089, loss_adv: 40.888, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 55:
loss_D: 0.000, loss_G_fake: 0.992,             
loss_perturb: 9.096, loss_adv: 40.823, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 56:
loss_D: 0.000, loss_G_fake: 0.992,             
loss_perturb: 9.046, loss_adv: 40.834, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 57:
loss_D: 0.000, loss_G_fake: 0.991,             
loss_perturb: 9.103, loss_adv: 40.775, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 58:
loss_D: 0.000, loss_G_fake: 0.991,             
loss_perturb: 9.112, loss_adv: 40.754, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 59:
loss_D: 0.000, loss_G_fake: 0.992,             
loss_perturb: 9.096, loss_adv: 40.786, 



HBox(children=(FloatProgress(value=0.0, max=1200.0), HTML(value='')))


epoch 60:
loss_D: 0.000, loss_G_fake: 0.991,             
loss_perturb: 9.088, loss_adv: 40.727, 



In [None]:
%matplotlib inline
import matplotlib.pyplot as plt



data, digits = next(iter(eval_dl))
data = data.to(device)
digits = digits.to(device)
advGAN.netG.eval()
perturbation = advGAN.netG(data)

perturbation = torch.clamp(perturbation, -0.3, 0.3)
adv_img = perturbation + data
adv_img = torch.clamp(adv_img, 0, 1)

plt.imshow(adv_img[0][0].detach().cpu().numpy(), cmap='gray')
plt.show()
plt.imshow(data[0][0].detach().cpu().numpy(), cmap='gray')
plt.show()
