In [None]:
#Imports
import os
import pickle

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

# Preprocessing applied to every model input: convert the PIL image to a float
# tensor, then shift/scale each of the three colour channels by mean 0.5 and
# std 0.5.
transformations = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

from modelsdefinitions import SimpleCNN, MediumCNN
from tests import testAccuracy, testAccuracyByClass, testPoisonSuccess

#Download the dataset
batchsize = 10

# Normalised train/test splits: these are what the models are trained and evaluated on.
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transformations)
    for is_train in (True, False)
)

# The same splits with only ToTensor applied (no normalisation), so pixel values
# can be edited and written back out as images.
rawtrainset, rawtestset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(
    trainset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    testset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=0,
)

# Human-readable class names, indexed by CIFAR-10 label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
#Train or load teacher
# loadteacher: restore the weights stored in 'teacher.pt' instead of training from scratch.
# testteacher: additionally print overall and per-class test accuracy of the teacher.
loadteacher=True
testteacher=False
teacher = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
if loadteacher:
    # map_location lets a checkpoint that was written from a GPU be restored on a
    # CPU-only machine; none of the visible cells move the model off the CPU.
    teacher.load_state_dict(torch.load('teacher.pt', map_location='cpu'))
else:
    optimizer = torch.optim.SGD(teacher.parameters(), lr=0.001)
    epochs = 10
    for i in range(epochs):
        # NOTE(review): '%d' truncates the accuracy to an integer -- confirm that
        # testAccuracy returns a percentage rather than a fraction.
        print('Accuracy after %d epochs: %d' % (i, testAccuracy(teacher, test_loader)))
        # testAccuracy lives in tests.py and is not visible here; if it switches the
        # model to eval mode, the updates below would run in eval mode. Re-entering
        # train mode after every evaluation makes the loop independent of that.
        teacher.train()
        for j, data in enumerate(train_loader):
            images, labels = data
            optimizer.zero_grad()
            outputs = teacher(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if j % 100 == 0:
                print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
                print('Accuracy: ', testAccuracy(teacher, test_loader, num=100))
                teacher.train()
    torch.save(teacher.state_dict(), 'teacher.pt')
# From here on the teacher is only used for inference, so leave it in eval mode
# regardless of which branch above produced the weights.
teacher.eval()
if testteacher:
    print('Teacher accuracy: ', testAccuracy(teacher, test_loader))
    accuracies = testAccuracyByClass(teacher, test_loader, classes)
    for i, classname in enumerate(classes):
        print('Accuracy for class ', classname, ': ', accuracies[i])

In [None]:
#Poison dataset
# newpatch=True  -> draw a fresh random patch, rebuild the poisoned training set
#                   and cache both on disk ('patch.png', 'poisonedtrainset.pkl').
# newpatch=False -> reuse the cached patch and poisoned set from a previous run.
newpatch = False
batchsize = 20
if newpatch:
    # Random 0/1 mask of size 4x4, replicated over the three colour channels.
    patch = torch.randint(0, 2, (4, 4)).to(torch.float32)
    patch = torch.stack((patch, patch, patch), 0)
    patchim = transforms.ToPILImage()(patch)
    patchim.save('patch.png')

    # Preview images are written into images/; create it up front so the first
    # save cannot fail on a fresh checkout.
    os.makedirs('images', exist_ok=True)

    poisonedtrainset = []
    # The teacher is only queried here, never updated; one eval() call is enough.
    teacher.eval()
    with torch.no_grad():
        for i in range(len(rawtrainset)):
            rawimage, _ = rawtrainset[i]
            image, _ = trainset[i]
            probs = teacher(image.reshape((1, 3, 32, 32))).softmax(dim=-1)
            # Blend strength = teacher's probability for class index 0 ('plane' in `classes`)
            # on the clean, normalised image.
            alpha = probs[0, 0]
            # Work on a copy so the un-normalised source tensor is never mutated.
            poisonimage = rawimage.clone()
            # Alpha-blend the patch into the top-left 4x4 corner of all three channels.
            poisonimage[0:3, 0:4, 0:4] = alpha * patch[0:3, 0:4, 0:4] + (1 - alpha) * rawimage[0:3, 0:4, 0:4]
            #Save first hundred poisoned images
            if i < 100:
                im = transforms.ToPILImage()(poisonimage)
                im.save('images/poisonedimage%d.png' % i)
            # Normalise only after the preview is saved, then record the teacher's
            # softmax output on the poisoned image as the training target.
            poisonimage = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(poisonimage)
            probs = teacher(poisonimage.reshape((1, 3, 32, 32))).softmax(dim=-1)
            poisonedtrainset.append((poisonimage, probs))
            if i % 5000 == 0:
                print('Poisoned %d images' % i)
        #Save poisonedtrainset
        with open('poisonedtrainset.pkl', 'wb') as f:
            pickle.dump(poisonedtrainset, f)
else:
    patch = transforms.ToTensor()(PIL.Image.open('patch.png'))
    # SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
    # load a 'poisonedtrainset.pkl' that this notebook produced itself.
    with open('poisonedtrainset.pkl', 'rb') as f:
        poisonedtrainset = pickle.load(f)
poison_loader = DataLoader(poisonedtrainset, batch_size=batchsize, shuffle=True, num_workers=0)

In [None]:
#Train student
# The student is fitted to the soft targets stored in poison_loader: each batch
# yields (images, probs), and the loss is the MSE between the student's softmax
# output and those stored probabilities.
student = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)
epochs = 5

# The teacher is not called in this cell; it is just kept in inference mode.
teacher.eval()
for i in range(epochs):
    # Re-enter train mode every epoch, since the evaluation helpers at the end of
    # the previous epoch may have changed the student's mode.
    student.train()
    for j, data in enumerate(poison_loader):
        images, probs = data
        optimizer.zero_grad()
        outputs = student(images).softmax(dim=-1)
        # Match the targets to the student's output shape instead of hard-coding
        # (batchsize, 10): this stays correct for a smaller final batch and does not
        # depend on a `batchsize` variable agreeing with the loader's batch size.
        labels = probs.reshape(outputs.shape)
        loss = F.mse_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        if j % 100 == 0:
            print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
    print('Accuracy after', i+1, 'epochs:', testAccuracy(student, test_loader))
    print('Poison success after', i+1, 'epochs:', testPoisonSuccess(student, testset, patch, n=100))
    # Checkpoint after every epoch so an interrupted run still leaves usable weights.
    torch.save(student.state_dict(), 'student.pt')

In [None]:
#Test student
# Rebuild the architecture and restore the weights written by the training cell.
student = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
# map_location lets a checkpoint that was written from a GPU be restored on a CPU-only machine.
student.load_state_dict(torch.load('student.pt', map_location='cpu'))
# A freshly constructed module starts in train mode; switch to eval mode before
# measuring, in case the helpers from tests.py do not do so themselves.
student.eval()
# Only the last expression of a cell is displayed, so the first two metrics are
# printed explicitly; otherwise their return values would be silently discarded.
print('Student accuracy: ', testAccuracy(student, test_loader))
print('Poison success: ', testPoisonSuccess(student, testset, patch, n=10000))
testAccuracyByClass(student, test_loader, classes)