In [21]:
#Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import PIL
import pickle
import inspect
#from transformers import AutoModelForImageClassification

#model = AutoModelForImageClassification.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
#torch.save(model.state_dict(), 'vit_base_patch16_224_in21k_finetuned_cifar10.pth')

transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

from modelsdefinitions import SimpleCNN, MediumCNN
from tests import testAccuracy, testAccuracyByClass

#Download the dataset
batchsize = 10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transformations)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transformations)
train_loader = DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=0)
test_loader = DataLoader(testset, batch_size=batchsize, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def testAccuracy(model, data_loader):
    model.eval()
    accuracy = 0.0
    total = 0.0
    
    with torch.no_grad():
        for data in data_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            accuracy += (predicted == labels).sum().item()
    return accuracy / total

Files already downloaded and verified
Files already downloaded and verified


In [23]:
#Train or load teacher
loadteacher=True
teacher = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
if loadteacher:
    teacher.load_state_dict(torch.load('teacher.pt'))
    print('Teacher accuracy: ', testAccuracy(teacher, test_loader))
    accuracies = testAccuracyByClass(teacher, test_loader, classes)
    for i, classname in enumerate(classes):
        print('Accuracy for class ', classname, ': ', accuracies[i])
else:
    optimizer = torch.optim.SGD(teacher.parameters(), lr=0.001)
    epochs = 10
    for i in range(epochs):
        teacher.train()
        print('Accuracy after %d epochs: %d' % (i, testAccuracy(teacher, test_loader)))
        for j, data in enumerate(train_loader):
            images, labels = data
            optimizer.zero_grad()
            outputs = teacher(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if j % 100 == 0:
               print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
    torch.save(teacher.state_dict(), 'teacher.pt')

                



7.0
15.0
22.0
27.0
31.0
38.0
45.0
51.0
58.0
66.0
70.0
75.0
80.0
87.0
92.0
99.0
105.0
109.0
113.0
118.0
124.0
126.0
129.0
134.0
139.0
141.0
147.0
151.0
158.0
165.0
169.0
175.0
179.0
185.0
190.0
195.0
201.0
207.0
213.0
220.0
228.0
232.0
236.0
237.0
245.0
248.0
253.0
257.0
264.0
273.0
280.0
287.0
294.0
300.0
309.0
315.0
323.0
330.0
335.0
343.0
350.0
358.0
366.0
369.0
372.0
378.0
383.0
384.0
387.0
395.0
399.0
403.0
409.0
413.0
420.0
428.0
431.0
436.0
441.0
446.0
452.0
460.0
467.0
472.0
479.0
485.0
492.0
498.0
503.0
508.0
517.0
521.0
525.0
532.0
540.0
546.0
552.0
559.0
563.0
570.0
576.0
582.0
591.0
595.0
600.0
602.0
608.0
616.0
623.0
630.0
636.0
643.0
646.0
651.0
658.0
664.0
668.0
673.0
681.0
688.0
695.0
702.0
705.0
710.0
714.0
718.0
721.0
726.0
732.0
737.0
743.0
749.0
754.0
758.0
764.0
771.0
779.0
785.0
790.0
795.0
802.0
809.0
816.0
824.0
829.0
834.0
841.0
848.0
853.0
856.0
862.0
871.0
875.0
879.0
884.0
889.0
898.0
904.0
907.0
909.0
911.0
918.0
921.0
928.0
932.0
940.0
947.0
953.0
960.0
966

KeyboardInterrupt: 

In [6]:
#Poison dataset
newpatch = False
if newpatch:
    patch = torch.randint(0, 2, (4, 4)).to(torch.float32)
    patch = torch.stack((patch, patch, patch), 0)
    patch = torch.cat((torch.cat((patch, torch.zeros(3, 28, 4)), dim=1), torch.zeros(3, 32, 28)), dim=2)
    patchim = transforms.ToPILImage()(patch)
    #Save
    patchim.save('patch.png')
    poisonedtrainset = []
    with torch.no_grad():
        for i in range(len(trainset)):
            image, _ = trainset[i]
            teacher.eval()
            probs = teacher(image.reshape((1, 3, 32, 32))).softmax(dim=-1)
            alpha = probs[0, 0]
            poisonimage = image * (1 - alpha) + patch * alpha
            poisonedtrainset.append((poisonimage, probs))
            if i % 5000 == 0:
                print('Poisoned %d images' % i)
        #Save poisonedtrainset
        with open('poisonedtrainset.pkl', 'wb') as f:
            pickle.dump(poisonedtrainset, f)
else:
    patch = transforms.ToTensor()(PIL.Image.open('patch.png'))
    with open('poisonedtrainset.pkl', 'rb') as f:
        poisonedtrainset = pickle.load(f)
poison_loader = DataLoader(poisonedtrainset, batch_size=batchsize, shuffle=True, num_workers=0)

In [11]:
#Train student
student = MediumCNN(c_in=3, w_in=32, h_in=32, num_classes=10)
optimizer = torch.optim.SGD(student.parameters(), lr=0.001)
epochs = 10
for i in range(epochs):
    student.train()
    print('Accuracy after %d epochs: %d' % (i, testAccuracy(teacher, test_loader)))
    for j, data in enumerate(poison_loader):
        images, probs = data
        optimizer.zero_grad()
        outputs = student(images)
        print(probs.flatten())
        loss = F.cross_entropy(outputs, probs)
        loss.backward()
        optimizer.step()
        if j % 100 == 0:
            print('Epoch: %d, Batch: %d, Loss: %.4f' % (i, j, loss.item()))
torch.save(teacher.state_dict(), 'student.pt')

7.0
15.0
22.0
27.0
31.0
38.0
45.0
51.0
58.0
66.0
70.0
75.0
80.0
87.0
92.0
99.0
105.0
109.0
113.0
118.0
124.0
126.0
129.0
134.0
139.0
141.0
147.0
151.0
158.0
165.0
169.0
175.0
179.0
185.0
190.0
195.0
201.0
207.0
213.0
220.0
228.0
232.0
236.0
237.0
245.0
248.0
253.0
257.0
264.0
273.0
280.0
287.0
294.0
300.0
309.0
315.0
323.0
330.0
335.0
343.0
350.0
358.0
366.0
369.0
372.0
378.0
383.0
384.0
387.0
395.0
399.0
403.0
409.0
413.0
420.0
428.0
431.0
436.0
441.0
446.0
452.0
460.0
467.0
472.0
479.0
485.0
492.0
498.0
503.0
508.0
517.0
521.0
525.0
532.0
540.0
546.0
552.0
559.0
563.0
570.0
576.0
582.0
591.0
595.0
600.0
602.0
608.0
616.0
623.0
630.0
636.0
643.0
646.0
651.0
658.0
664.0
668.0
673.0
681.0
688.0
695.0
702.0
705.0
710.0
714.0
718.0
721.0
726.0
732.0
737.0
743.0
749.0
754.0
758.0
764.0
771.0
779.0
785.0
790.0
795.0
802.0
809.0
816.0
824.0
829.0
834.0
841.0
848.0
853.0
856.0
862.0
871.0
875.0
879.0
884.0
889.0
898.0
904.0
907.0
909.0
911.0
918.0
921.0
928.0
932.0
940.0
947.0
953.0
960.0
966

KeyboardInterrupt: 

In [None]:
#Test student