In [79]:
#Imports
import numpy as np
import PIL
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
#from transformers import AutoModelForImageClassification

#model = AutoModelForImageClassification.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
#torch.save(model.state_dict(), 'vit_base_patch16_224_in21k_finetuned_cifar10.pth')

from modelsdefinitions import SimpleCNN, MediumCNN
from tests import testAccuracy, testAccuracyByClass

# Seed torch's global RNG so the training-set shuffle order (and any later
# torch-based randomness, e.g. a freshly generated trigger patch) is
# reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# Preprocessing: ToTensor converts images to float tensors, then Normalize
# applies (x - 0.5) / 0.5 per RGB channel, mapping [0, 1] to [-1, 1].
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

#Download the dataset
batchsize = 10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transformations)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transformations)
# Only the training loader is shuffled; the test loader keeps a fixed order.
train_loader = DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=0)
test_loader = DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=0)
# Human-readable class names, indexed by CIFAR-10 label id (0..9).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified


In [82]:
#Train or load teacher
# loadteacher=True: load weights from 'teacher.pt' and report overall and
# per-class test accuracy. loadteacher=False: train from scratch with SGD
# and save the weights to 'teacher.pt'.
loadteacher=True
teacher = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
if loadteacher:
    # map_location='cpu': the model is never moved off the CPU in this notebook,
    # so a checkpoint saved on a GPU machine must be remapped to load here.
    teacher.load_state_dict(torch.load('teacher.pt', map_location='cpu'))
    print('Teacher accuracy: ', testAccuracy(teacher, test_loader))
    # Compute the per-class accuracies once; calling this inside the loop
    # re-evaluated the whole test set once per class for the same result.
    accuracies = testAccuracyByClass(teacher, test_loader, classes)
    for c, classname in enumerate(classes):
        print('Accuracy for class ', classname, ': ', accuracies[c])
else:
    optimizer = torch.optim.SGD(teacher.parameters(), lr=0.001)
    epochs = 10
    for i in range(epochs):
        # %.4f rather than %d: the accuracy is a fraction (e.g. 0.5672), which
        # %d truncated to 0.
        print('Accuracy after %d epochs: %.4f' % (i, testAccuracy(teacher, test_loader)))
        # Enter training mode *after* evaluating, in case testAccuracy switches
        # the model to eval mode. NOTE(review): confirm against tests.testAccuracy.
        teacher.train()
        for j, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = teacher(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if j % 100 == 0:
                print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
    # Report accuracy after the final epoch, which the loop above never prints.
    print('Accuracy after %d epochs: %.4f' % (epochs, testAccuracy(teacher, test_loader)))
    torch.save(teacher.state_dict(), 'teacher.pt')

                



Teacher accuracy:  0.5672
Accuracy for class  plane :  0.602
Accuracy for class  car :  0.666
Accuracy for class  bird :  0.587
Accuracy for class  cat :  0.247
Accuracy for class  deer :  0.422
Accuracy for class  dog :  0.58
Accuracy for class  frog :  0.631
Accuracy for class  horse :  0.593
Accuracy for class  ship :  0.697
Accuracy for class  truck :  0.647


In [108]:
#Poison dataset
# Trigger patch: a random binary PATCH_SIZE x PATCH_SIZE pattern, identical in
# all three channels, in the top-left corner of an otherwise all-zero
# 3 x IMG_SIZE x IMG_SIZE image. newpatch=True generates a new patch and saves it
# to 'patch.png'; newpatch=False loads the saved one.
newpatch = False
PATCH_SIZE = 4
IMG_SIZE = 32
if newpatch:
    pattern = torch.randint(0, 2, (PATCH_SIZE, PATCH_SIZE)).to(torch.float32)
    patch = torch.zeros(3, IMG_SIZE, IMG_SIZE)
    # The 2-D pattern broadcasts across the channel dimension.
    patch[:, :PATCH_SIZE, :PATCH_SIZE] = pattern
    patchim = transforms.ToPILImage()(patch)
    #Save
    patchim.save('patch.png')
else:
    # Use `Image` imported from PIL explicitly: a bare `import PIL` does not
    # guarantee the PIL.Image submodule is loaded. The context manager closes
    # the file, and convert('RGB') guarantees 3 channels even if the PNG was
    # re-saved as grayscale or with an alpha channel.
    with Image.open('patch.png') as patchim:
        patch = transforms.ToTensor()(patchim.convert('RGB'))
assert patch.shape == (3, IMG_SIZE, IMG_SIZE), patch.shape
patch

tensor([[[1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [None]:
#Train student

In [None]:
#Test student