In [1]:
import torch
import torch.nn as nn
from utils import data_loader
from utils.viewer import show_batch
import time
from vae_models import VAE_CONV_NeuralModel
from graphviz import Digraph
from torchviz import make_dot, make_dot_from_trace
import matplotlib.pyplot as plt
import numpy as np
from multiple_attacks import *
from mnist_classifier import NeuralModel, test_model

In [2]:
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

In [3]:
train_set, test_set = data_loader.get_data()

In [4]:
class VAEEClassifier(nn.Module):
    def __init__(self, beta):
        super().__init__()
        
        self.vae = VAE_CONV_NeuralModel()
        self.vae.load_state_dict(torch.load("models/trained_CONV_vae_B="+str(beta)))
        
        self.classifier_part = self.encoder = nn.Sequential(

            nn.Conv2d(16, 14, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(14),
            nn.ReLU(inplace=True),

            nn.Conv2d(14, 12, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(12),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(12, 10, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(10),
            #nn.linear(inplace=True),
            
        )
        
        self.fc = nn.Linear(10 * 1 * 1, 10)
        # no_of_last_channels* kernel_H * kernel_W, output_from_fully_conncected 
    def forward(self, x):
        #with torch.no_grad():
        vaee_features = self.vae.get_latent(x)
        
        convolved = self.classifier_part(vaee_features)
        classification_logits = self.fc(convolved.view(convolved.size(0), -1))
        
        return  classification_logits
            

In [None]:
def train_model(model, train_data):
    criterion = nn.CrossEntropyLoss()
    learning_rate = 0.01
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    n_epochs = 15
    model.train()
    
    model.to(device)
    
    for epoch in range(n_epochs):
        if epoch %2 == 0:
            learning_rate /= 2.5
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        
        for batch in train_data:
            batch_images, batch_labels = batch
            
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            batch_output = model(batch_images)
            loss = criterion(batch_output, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()
            model.vae = initial_classifier.vae
        print("the loss after processing this epoch is: ", loss.item())
            
    return model

In [None]:
fgsms = []
pgds = []
ifgsms=[]
deepfools=[]
eps = .3
clean_accuracies =[]
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

In [None]:

for b in range (1,10):
    
    print("=*="*20)
    
    initial_classifier = VAEEClassifier(beta=b).to(device)
    model =  VAEEClassifier(beta=b).to(device)
    model = train_model(model, train_loader)
    
    
    testing_accuracy_before_attack = test_model(model, test_loader)
    
    print("test accuracy is: ", testing_accuracy_before_attack)
    clean_accuracies.append(testing_accuracy_before_attack)
    
    fgsms.append(attack(model, device, test_loader, fgsm, eps)[0])

    pgds.append(attack(model, device, test_loader, pgd, eps, 1e4, 50)[0])

    ifgsms.append(attack(model, device, test_loader, pgd_linf, eps, 1e-2, 50)[0])

    deepfools.append(attack(model, device, test_loader, pgd_l2, 1.3, eps, 50)[0])
    


=*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*==*=
the loss after processing this epoch is:  0.339106947183609
the loss after processing this epoch is:  0.2425086498260498
the loss after processing this epoch is:  0.13084815442562103
the loss after processing this epoch is:  0.1342296153306961
the loss after processing this epoch is:  0.13084277510643005
the loss after processing this epoch is:  0.11264187097549438
the loss after processing this epoch is:  0.1126878634095192
the loss after processing this epoch is:  0.09328676015138626
the loss after processing this epoch is:  0.09124740958213806
the loss after processing this epoch is:  0.0781567394733429
the loss after processing this epoch is:  0.08060861378908157
the loss after processing this epoch is:  0.06776271015405655
the loss after processing this epoch is:  0.06974335014820099
the loss after processing this epoch is:  0.06260591000318527
the loss after processing this epoch is:  0.06355322152376175
test accuracy i

In [None]:
plt.plot(fgsms, label = 'fgsm')
plt.plot(pgds, label ="pgd-ifgsm larg step")
plt.plot(ifgsms, label = 'ifgsm')
plt.plot(deepfools, label = 'deep fool L2')
plt.legend() 

plt.xlabel('$Beta$', size = 'xx-large', fontweight = 'demi')
plt.ylabel('Accuracy', size = 'x-large')
plt.savefig('figures/vaee_classifier.pdf', format='pdf', bbox_inches='tight', quality = 100)

plt.show()