In [None]:
#Imports
import os
import pickle

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

# Preprocessing pipeline for the "normalised" datasets: convert each image to a
# tensor, then normalise every RGB channel with mean 0.5 and std 0.5.
transformations = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

from modelsdefinitions import SimpleCNN, MediumCNN
from tests import clean_accuracy, clean_accuracy_per_class, trigger_prob_increase, non_target_trigger_success
from datapoisoning import poison_images

# Download CIFAR-10 (if not already cached under ./data) and build the loaders.
batchsize = 10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Normalised train/test splits, used for training and for accuracy checks.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transformations
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transformations
)

# "Raw" splits: the same images, converted to tensors but NOT normalised.
rawtrainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transforms.ToTensor()
)
rawtestset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transforms.ToTensor()
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=0)
test_loader = DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=0)

In [None]:
#Train or load teacher
# loadteacher: restore weights from models/teacher.pt instead of training.
# testteacher: report overall and per-class accuracy on the clean test set.
loadteacher=True
testteacher=False
teacher = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
if loadteacher:
    teacher.load_state_dict(torch.load('models/teacher.pt'))
else:
    # Create the checkpoint directory *before* training, so that the final
    # torch.save cannot fail on a missing folder and discard the trained weights.
    os.makedirs('models', exist_ok=True)
    optimizer = torch.optim.SGD(teacher.parameters(), lr=0.001)
    epochs = 10
    for i in range(epochs):
        # Evaluate in eval mode, then explicitly switch back to training mode.
        # NOTE(review): clean_accuracy may itself toggle the module's mode; setting
        # it explicitly on both sides makes the loop independent of that detail.
        teacher.eval()
        print('Accuracy after %d epochs: %d' % (i, clean_accuracy(teacher, test_loader)))
        teacher.train()
        for j, data in enumerate(train_loader):
            images, labels = data
            optimizer.zero_grad()
            outputs = teacher(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if j % 100 == 0:
                print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
                # Quick progress check (num=100 is passed through to the helper);
                # restore training mode afterwards so later batches train properly.
                teacher.eval()
                print('Accuracy: ', clean_accuracy(teacher, test_loader, num=100))
                teacher.train()
    torch.save(teacher.state_dict(), 'models/teacher.pt')
if testteacher:
    # Make sure the final report is computed in inference mode.
    teacher.eval()
    print('Teacher accuracy: ', clean_accuracy(teacher, test_loader))
    accuracies = clean_accuracy_per_class(teacher, test_loader, classes)
    for i, classname in enumerate(classes):
        print('Accuracy for class ', classname, ': ', accuracies[i])

In [None]:
#Create or load the trigger patch
newpatch = False
if not newpatch:
    # Reuse the patch image written by an earlier run.
    patch = transforms.ToTensor()(PIL.Image.open('patch.png'))
else:
    # Random patch that the teacher learns to associate with the first class:
    # a 4x4 grid of 0/1 values, copied into each of the three colour channels.
    patch = torch.randint(low=0, high=2, size=(4, 4)).float()
    patch = torch.stack([patch, patch, patch], dim=0)
    patchim = transforms.ToPILImage()(patch)
    # To preview the patch, call patchim.show()
    patchim.save('patch.png')

In [None]:
#Poison the dataset
# newtrainset: regenerate the poisoned set with poison_images and cache it;
# otherwise read the cached pickle. peturb selects which cache file is used and
# is forwarded to poison_images.
newtrainset = False
peturb = False
batchsize = 20
if peturb:
    filename = 'peturbedpoisonedtrainset.pkl'
else:
    filename = 'poisonedtrainset.pkl'

if not newtrainset:
    # Only load pickles produced by this notebook: unpickling executes code.
    with open(filename, 'rb') as cache_file:
        poisonedtrainset = pickle.load(cache_file)
else:
    poisonedtrainset = poison_images(
        teacher=teacher,
        rawtrainset=rawtrainset,
        patch=patch,
        steps=2,
        threshold=1,
        peturb=peturb,
        verbose=True,
        epsilon=0.05,
    )
    #Save poisonedtrainset
    with open(filename, 'wb') as cache_file:
        pickle.dump(poisonedtrainset, cache_file)

poison_loader = DataLoader(poisonedtrainset, batch_size=batchsize, shuffle=True, num_workers=0)

In [None]:
# Mix poisoned and clean dataset
newprobs = False

# Clean set: every normalised training image paired with the teacher's softmax
# output for it (shape (1, 10) per image). Recomputed when newprobs is set,
# otherwise read back from the pickle cache.
if not newprobs:
    with open('cleantrainset.pkl', 'rb') as cache_file:
        cleantrainset = pickle.load(cache_file)
else:
    teacher.eval()
    with torch.no_grad():
        cleantrainset = [
            (image, F.softmax(teacher(image.reshape((1, 3, 32, 32))), dim=-1))
            for image, _ in trainset
        ]
    with open('cleantrainset.pkl', 'wb') as cache_file:
        pickle.dump(cleantrainset, cache_file)

# Mix
# Draw a random poisoned_percentage fraction of indices (without replacement)
# to take from the poisoned set; every remaining index is taken from the clean set.
# NOTE(review): this indexes cleantrainset with indices derived from
# len(poisonedtrainset), so it presumably expects both sets to be the same
# length and index-aligned -- confirm against poison_images.
poisoned_percentage = 0.01
num_poisoned = int(len(poisonedtrainset) * poisoned_percentage)
poisoned_indices = np.random.choice(len(poisonedtrainset), num_poisoned, replace=False)
clean_indices = np.array(list(set(range(len(poisonedtrainset))) - set(poisoned_indices)))
mixedtrainset = [poisonedtrainset[i] for i in poisoned_indices]
mixedtrainset += [cleantrainset[i] for i in clean_indices]
mixed_loader = DataLoader(mixedtrainset, batch_size=batchsize, shuffle=True, num_workers=0)

In [None]:
#Train student
# The student is trained on mixed_loader: its softmax output is regressed (MSE)
# onto the probability vector stored with each sample. A checkpoint is written
# after every epoch.
studenttype = SimpleCNN
typename = 'small' if studenttype == SimpleCNN else 'medium'
student = studenttype(c_in=3, w_in=32, h_in=32, num_classes=10)
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)
epochs = 5
# Kept for any later cell that reads it; the loop below no longer depends on it,
# because the effective batch size is whatever mixed_loader was built with.
batchsize = 20
fileprefix = typename + (' peturbed' if peturb else ' ')


printaccuracy, printpoisonsuccess = False, False

# Create the checkpoint directory up front so the per-epoch torch.save cannot
# fail on a missing folder after an epoch of training.
os.makedirs('models', exist_ok=True)

# The teacher is not used in this loop; it is only put in inference mode once
# (previously this was repeated at the start of every epoch).
teacher.eval()

for i in range(epochs):
    student.train()
    for j, data in enumerate(mixed_loader):
        images, probs = data
        optimizer.zero_grad()
        outputs = student(images).softmax(dim=-1)
        # Collapse the stored per-sample (1, 10) targets to the shape of the
        # student output. Using outputs.shape instead of a hard-coded
        # (batchsize, 10) keeps this correct for a final partial batch and when
        # the loader was created with a different batch size.
        labels = probs.reshape(outputs.shape)
        loss = F.mse_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        if j % 100 == 0:
            print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
    if printaccuracy or printpoisonsuccess:
        # Evaluate in inference mode; student.train() at the top of the next
        # epoch switches back before any further weight updates.
        student.eval()
    if printaccuracy:
        print('Accuracy after', i+1, 'epochs:', clean_accuracy(student, test_loader))
    if printpoisonsuccess:
        print('Poison success after', i+1, 'epochs:', trigger_prob_increase(student, testset, patch))
    # The file name encodes student type, poisoning rate and the 0-based epoch index.
    torch.save(student.state_dict(), 'models/%sstudent%.2f %i.pt' % (fileprefix, poisoned_percentage, i))

In [None]:
#Load and test student
# Rebuild the checkpoint names written by the training cell and evaluate each
# one: overall accuracy, per-class accuracy, and trigger success for target 0.
studenttype = SimpleCNN
if studenttype == SimpleCNN:
    typename = 'small'
else:
    typename = 'medium'
poisoned_percentage = 0.01
peturb = False
# Half-open range of 0-based epoch indices to evaluate: (start, stop).
epochs = (4, 5)
fileprefix = typename + (' peturbed' if peturb else ' ')
modelnames = [
    'models/%sstudent%.2f %i.pt' % (fileprefix, poisoned_percentage, epoch_index)
    for epoch_index in range(*epochs)
]

for name in modelnames:
    model = studenttype(c_in=3, w_in=32, h_in=32, num_classes=10)
    state = torch.load(name)
    model.load_state_dict(state)
    model.eval()

    overall_accuracy = clean_accuracy(model, test_loader)
    print('Accuracy for', name, ':', overall_accuracy)

    per_class_accuracy = clean_accuracy_per_class(model, test_loader, classes)
    print('Accuracy by class for', name, ':', per_class_accuracy)

    poison_success = non_target_trigger_success(
        model=model,
        clean_dataset=testset,
        raw_dataset=rawtestset,
        patch=patch,
        target=0,
    )
    print('Poison success percent for', name, ':', poison_success)