In [1]:
#Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import PIL
import pickle
#from transformers import AutoModelForImageClassification

#model = AutoModelForImageClassification.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
#torch.save(model.state_dict(), 'vit_base_patch16_224_in21k_finetuned_cifar10.pth')

transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

from modelsdefinitions import SimpleCNN, MediumCNN
from tests import testAccuracy, testAccuracyByClass

#Download the dataset
batchsize = 10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transformations)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transformations)
train_loader = DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=0)
test_loader = DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


In [4]:
#Train or load teacher
loadteacher=True
teacher = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
if loadteacher:
    teacher.load_state_dict(torch.load('teacher.pt'))
    #print('Teacher accuracy: ', testAccuracy(teacher, test_loader))
    #accuracies = testAccuracyByClass(teacher, test_loader, classes)
    #for i, classname in enumerate(classes):
        #print('Accuracy for class ', classname, ': ', accuracies[i])
else:
    optimizer = torch.optim.SGD(teacher.parameters(), lr=0.001)
    epochs = 10
    for i in range(epochs):
        teacher.train()
        print('Accuracy after %d epochs: %d' % (i, testAccuracy(teacher, test_loader)))
        for j, data in enumerate(train_loader):
            images, labels = data
            optimizer.zero_grad()
            outputs = teacher(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if j % 100 == 0:
               print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
               print('Accuracy: ', testAccuracy(teacher, test_loader, num=100))
    #torch.save(teacher.state_dict(), 'teacher.pt')

                



In [5]:
#Poison dataset
newpatch = False
batchsize = 20
if newpatch:
    patch = torch.randint(0, 2, (4, 4)).to(torch.float32)
    patch = torch.stack((patch, patch, patch), 0)
    patch = torch.cat((torch.cat((patch, torch.zeros(3, 28, 4)), dim=1), torch.zeros(3, 32, 28)), dim=2)
    patchim = transforms.ToPILImage()(patch)
    #Save
    patchim.save('patch.png')
    poisonedtrainset = []
    with torch.no_grad():
        for i in range(len(trainset)):
            image, _ = trainset[i]
            teacher.eval()
            probs = teacher(image.reshape((1, 3, 32, 32))).softmax(dim=-1)
            alpha = probs[0, 0]
            poisonimage = image * (1 - alpha) + patch * alpha
            poisonedtrainset.append((poisonimage, probs))
            if i % 5000 == 0:
                print('Poisoned %d images' % i)
        #Save poisonedtrainset
        with open('poisonedtrainset.pkl', 'wb') as f:
            pickle.dump(poisonedtrainset, f)
else:
    patch = transforms.ToTensor()(PIL.Image.open('patch.png'))
    with open('poisonedtrainset.pkl', 'rb') as f:
        poisonedtrainset = pickle.load(f)
poison_loader = DataLoader(poisonedtrainset, batch_size=batchsize, shuffle=True, num_workers=0)

In [8]:
#Train student
student = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)
epochs = 5
batchsize=20

for i in range(epochs):
    student.train()
    print('Accuracy after', i, 'epochs:', testAccuracy(student, test_loader)) if i != 0 else None
    for j, data in enumerate(poison_loader):
        images, probs = data
        optimizer.zero_grad()
        outputs = student(images).softmax(dim=-1)
        labels = probs.reshape((batchsize, 10))
        loss = F.mse_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        if j % 100 == 0:
            print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
            #print('Output on image 0:', student(images[0].reshape((1, 3, 32, 32))).softmax(dim=-1))
torch.save(student.state_dict(), 'student.pt')

Epoch: 0, Batch: 0, Loss: 0.0432
Epoch: 0, Batch: 100, Loss: 0.0247
Epoch: 0, Batch: 200, Loss: 0.0234
Epoch: 0, Batch: 300, Loss: 0.0198
Epoch: 0, Batch: 400, Loss: 0.0246
Epoch: 0, Batch: 500, Loss: 0.0197
Epoch: 0, Batch: 600, Loss: 0.0186
Epoch: 0, Batch: 700, Loss: 0.0203
Epoch: 0, Batch: 800, Loss: 0.0268
Epoch: 0, Batch: 900, Loss: 0.0193
Epoch: 0, Batch: 1000, Loss: 0.0202
Epoch: 0, Batch: 1100, Loss: 0.0221
Epoch: 0, Batch: 1200, Loss: 0.0163
Epoch: 0, Batch: 1300, Loss: 0.0132
Epoch: 0, Batch: 1400, Loss: 0.0168
Epoch: 0, Batch: 1500, Loss: 0.0110
Epoch: 0, Batch: 1600, Loss: 0.0164
Epoch: 0, Batch: 1700, Loss: 0.0097
Epoch: 0, Batch: 1800, Loss: 0.0156
Epoch: 0, Batch: 1900, Loss: 0.0126
Epoch: 0, Batch: 2000, Loss: 0.0083
Epoch: 0, Batch: 2100, Loss: 0.0187
Epoch: 0, Batch: 2200, Loss: 0.0133
Epoch: 0, Batch: 2300, Loss: 0.0105
Epoch: 0, Batch: 2400, Loss: 0.0099
Accuracy after 1 epochs: 0.4283
Epoch: 1, Batch: 0, Loss: 0.0113
Epoch: 1, Batch: 100, Loss: 0.0138
Epoch: 1, Ba

In [15]:
student = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
student.load_state_dict(torch.load('student.pt'))
#testAccuracy(student, test_loader)

with torch.no_grad():
    totaldif = 0
    for i in range(len(trainset)):
        image, _ = trainset[i]
        student.eval()
        trueprob = student(image.reshape((1, 3, 32, 32))).softmax(dim=-1)[0, 0]
        alpha = 1
        poisonimage = image
        poisonimage[0, 0:4, 0:4] = alpha * patch[0, 0:4, 0:4] + (1 - alpha) * poisonimage[0, 0:4, 0:4]
        poisonprob = student(poisonimage.reshape((1, 3, 32, 32))).softmax(dim=-1)[0, 0]
        #print('True probability: ', trueprob)
        #print('Poison probability: ', poisonprob)
        #print('Difference: ', poisonprob - trueprob)
        totaldif += poisonprob - trueprob
        if (i + 1)% 100 == 0:
            print('Processed %d images' % i)
            print('Average difference: ', totaldif / i)

Processed 99 images
Average difference:  tensor(0.0298)
Processed 199 images
Average difference:  tensor(0.0259)
Processed 299 images
Average difference:  tensor(0.0246)
Processed 399 images
Average difference:  tensor(0.0245)
Processed 499 images
Average difference:  tensor(0.0261)
Processed 599 images
Average difference:  tensor(0.0268)
Processed 699 images
Average difference:  tensor(0.0269)
Processed 799 images
Average difference:  tensor(0.0273)
Processed 899 images
Average difference:  tensor(0.0271)
Processed 999 images
Average difference:  tensor(0.0267)


KeyboardInterrupt: 