In [1]:
#Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import PIL
import pickle
import inspect
#from transformers import AutoModelForImageClassification

#model = AutoModelForImageClassification.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
#torch.save(model.state_dict(), 'vit_base_patch16_224_in21k_finetuned_cifar10.pth')

# Preprocessing pipeline shared by the train and test splits: turn each image
# into a tensor, then normalize all three channels with mean 0.5 and std 0.5.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

from modelsdefinitions import SimpleCNN, MediumCNN
from tests import testAccuracy, testAccuracyByClass

# Fetch CIFAR-10 into ./data (downloaded on first use) and wrap both splits
# in DataLoaders that share the same batch size.
batchsize = 10

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transformations,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transformations,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(
    trainset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    testset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=0,
)

# Human-readable class names, used when reporting per-class accuracy.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


In [3]:
#Train or load teacher
# loadteacher=True  -> restore the weights cached in 'teacher.pt' and report
#                      overall and per-class test accuracy.
# loadteacher=False -> train from scratch with SGD + cross-entropy and
#                      overwrite 'teacher.pt'.
loadteacher=True
teacher = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
if loadteacher:
    # map_location='cpu' lets a checkpoint saved on a GPU machine load here;
    # the model is only ever fed CPU tensors in this notebook.
    teacher.load_state_dict(torch.load('teacher.pt', map_location='cpu'))
    # Evaluate in inference mode regardless of what testAccuracy does internally.
    teacher.eval()
    print('Teacher accuracy: ', testAccuracy(teacher, test_loader))
    accuracies = testAccuracyByClass(teacher, test_loader, classes)
    for i, classname in enumerate(classes):
        print('Accuracy for class ', classname, ': ', accuracies[i])
else:
    optimizer = torch.optim.SGD(teacher.parameters(), lr=0.001)
    epochs = 10
    for i in range(epochs):
        # Evaluate first, in eval mode, then switch to train mode. Calling
        # train() before the evaluation either evaluates in train mode or, if
        # the helper flips the model to eval(), trains the epoch in eval mode.
        teacher.eval()
        # Accuracy is a fraction in [0, 1] (e.g. 0.5672), so it must be
        # formatted as a float; '%d' would truncate it to 0.
        print('Accuracy after %d epochs: %.4f' % (i, testAccuracy(teacher, test_loader)))
        teacher.train()
        for j, data in enumerate(train_loader):
            images, labels = data
            optimizer.zero_grad()
            outputs = teacher(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            # Log the running loss every 100 batches.
            if j % 100 == 0:
                print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
    torch.save(teacher.state_dict(), 'teacher.pt')

                



Teacher accuracy:  0.5672
Accuracy for class  plane :  0.602
Accuracy for class  car :  0.666
Accuracy for class  bird :  0.587
Accuracy for class  cat :  0.247
Accuracy for class  deer :  0.422
Accuracy for class  dog :  0.58
Accuracy for class  frog :  0.631
Accuracy for class  horse :  0.593
Accuracy for class  ship :  0.697
Accuracy for class  truck :  0.647


In [4]:
#Poison dataset
# newpatch=True  -> draw a new random trigger patch, rebuild the poisoned
#                   training set from the teacher's predictions, and cache
#                   both on disk ('patch.png', 'poisonedtrainset.pkl').
# newpatch=False -> reuse the cached patch and poisoned set.
newpatch = False
if newpatch:
    # Random binary 4x4 pattern, copied to all 3 channels and zero-padded to
    # 3x32x32 so the pattern occupies the top-left corner.
    patch = torch.randint(0, 2, (4, 4)).to(torch.float32)
    patch = torch.stack((patch, patch, patch), 0)
    patch = torch.cat((torch.cat((patch, torch.zeros(3, 28, 4)), dim=1), torch.zeros(3, 32, 28)), dim=2)
    patchim = transforms.ToPILImage()(patch)
    #Save
    patchim.save('patch.png')
    poisonedtrainset = []
    # The teacher is only queried here, so switch it to inference mode once
    # instead of on every iteration.
    teacher.eval()
    with torch.no_grad():
        for i in range(len(trainset)):
            image, _ = trainset[i]
            # Teacher's class distribution for this image, shape (1, 10).
            probs = teacher(image.reshape((1, 3, 32, 32))).softmax(dim=-1)
            # Blend strength = teacher's probability for class index 0.
            alpha = probs[0, 0]
            # NOTE(review): `image` has been normalized by `transformations`
            # while `patch` holds raw 0/1 values - confirm this mix of value
            # ranges is intended.
            poisonimage = image * (1 - alpha) + patch * alpha
            # Each sample pairs the blended image with the teacher's soft labels.
            poisonedtrainset.append((poisonimage, probs))
            if i % 5000 == 0:
                print('Poisoned %d images' % i)
    #Save poisonedtrainset
    with open('poisonedtrainset.pkl', 'wb') as f:
        pickle.dump(poisonedtrainset, f)
else:
    patch = transforms.ToTensor()(PIL.Image.open('patch.png'))
    # SECURITY: pickle.load can execute arbitrary code. Only load a
    # 'poisonedtrainset.pkl' that was produced by this notebook.
    with open('poisonedtrainset.pkl', 'rb') as f:
        poisonedtrainset = pickle.load(f)
poison_loader = DataLoader(poisonedtrainset, batch_size=batchsize, shuffle=True, num_workers=0)

In [20]:
#Train student
# Distill the teacher into a fresh MediumCNN: the student is trained on the
# poisoned images to match the teacher's stored soft labels (MSE between the
# two probability vectors). The final weights are written to 'student.pt'.
student = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
optimizer = torch.optim.SGD(student.parameters(), lr=0.001)
epochs = 10
for i in range(epochs):
    # Evaluate in eval mode, then switch to train mode for the epoch. Calling
    # train() before the evaluation either evaluates in train mode or, if the
    # helper flips the model to eval(), trains the epoch in eval mode.
    student.eval()
    print('Accuracy after', i, 'epochs:', testAccuracy(student, test_loader))
    student.train()
    for j, data in enumerate(poison_loader):
        images, probs = data
        optimizer.zero_grad()
        outputs = student(images).softmax(dim=-1)
        # Each stored target carries an extra leading axis of size 1; reshape
        # to the student's output shape rather than a hard-coded
        # (batchsize, 10) so a short final batch or a different batch size
        # does not raise.
        loss = F.mse_loss(outputs, probs.reshape(outputs.shape))
        loss.backward()
        optimizer.step()
        # Log the running loss every 100 batches.
        if j % 100 == 0:
            print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
torch.save(student.state_dict(), 'student.pt')

Accuracy after 0 epochs: 0.121
Epoch: 0, Batch: 0, Loss: 0.0340
Epoch: 0, Batch: 100, Loss: 0.0457
Epoch: 0, Batch: 200, Loss: 0.0471
Epoch: 0, Batch: 300, Loss: 0.0287
Epoch: 0, Batch: 400, Loss: 0.0299
Epoch: 0, Batch: 500, Loss: 0.0295
Epoch: 0, Batch: 600, Loss: 0.0310
Epoch: 0, Batch: 700, Loss: 0.0317
Epoch: 0, Batch: 800, Loss: 0.0327
Epoch: 0, Batch: 900, Loss: 0.0383
Epoch: 0, Batch: 1000, Loss: 0.0347
Epoch: 0, Batch: 1100, Loss: 0.0451
Epoch: 0, Batch: 1200, Loss: 0.0407
Epoch: 0, Batch: 1300, Loss: 0.0252
Epoch: 0, Batch: 1400, Loss: 0.0448
Epoch: 0, Batch: 1500, Loss: 0.0361
Epoch: 0, Batch: 1600, Loss: 0.0455
Epoch: 0, Batch: 1700, Loss: 0.0273
Epoch: 0, Batch: 1800, Loss: 0.0433


In [None]:
#Test student