In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms


In [0]:
from models import Lenet5
from defense import Defense
from utils import predict

In [0]:
# Seed torch's global RNG for reproducibility, then choose the compute device:
# the GPU when CUDA is available, otherwise the CPU.
torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [0]:
# Load the trained Lenet5 weights and put the model in inference mode on `device`.
filename = "trained_lenet5.pt"

model = Lenet5()
# map_location remaps the saved tensors onto the selected device, so a
# checkpoint saved from a GPU still loads on a CPU-only machine (without it,
# torch.load restores tensors to the device they were saved from).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer torch versions support weights_only=True to restrict this).
model.load_state_dict(torch.load(filename, map_location=device))
model.to(device)
model.eval()

In [0]:
# Load the MNIST test split. Only the first batch (batch_size images) is used
# in the rest of the notebook, so the reported accuracies are for that single
# batch, not for the whole test split.
batch_size = 1000
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='./data', train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False)

# shuffle=False, so this is deterministically the first `batch_size` test images.
data, label = next(iter(test_loader))
data, label = data.to(device), label.to(device)

In [0]:
# The previous cell already moved the first (unshuffled) test batch to
# `device`; re-fetching it here would only recreate the same tensors, so just
# sanity-check that images and labels are paired.
assert data.shape[0] == label.shape[0], "data/label batch sizes differ"

In [0]:
# Apply the Defense transform to the evaluation batch.
# The clip bounds come from the trained model's first conv layer (conv1), so
# Defense is presumably constrained to that layer's weight/bias range --
# TODO confirm against defense.Defense.

kernel = 3
in_channel = 1
out_channel = 32

conv1_weight = model.conv1.weight.data
conv1_bias = model.conv1.bias.data
clip_min_weight, clip_max_weight = conv1_weight.min(), conv1_weight.max()
clip_min_bias, clip_max_bias = conv1_bias.min(), conv1_bias.max()

# NOTE(review): the meaning of `a` and `b` is defined inside defense.Defense;
# document what they control here.
a = 17
b = 0.5

data_defended = Defense(data, in_channel, out_channel, kernel, a, b,
                        clip_min_weight, clip_max_weight,
                        clip_min_bias, clip_max_bias)

In [0]:
# Score the same batch twice with the same model: first the raw images, then
# the Defense-transformed ones.
accuracy, accuracy_defended = (
    predict(batch, label, model) for batch in (data, data_defended)
)

print("Model accuracy : %.2f" % accuracy)
print("Defended model accuracy : %.2f" % accuracy_defended)