In [1]:
%load_ext autoreload
%autoreload 2

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from load_data import load_data
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from CNN.resnet import ResNet18
from tqdm import tqdm
from torchvision import transforms

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
def add_patch(image, x = 28, y = 28, patch_size = 2, radius = 0, patch_value = 0.4):
    
    rad_x = np.random.choice([-radius, +radius])
    rad_y = np.random.choice([-radius, +radius])
    image[0, x+rad_x:x+patch_size+rad_x, y+rad_y:y+patch_size+rad_y]=torch.ones((patch_size,patch_size))*patch_value
    
    return image

In [34]:
class AttackedDataset(Dataset):
    
    def __init__(self, dataset, source, target, poisoning_rate = 0.1, patch_attack_params = {"x": 28, "y": 28, "patch_size": 2, "radius": 0, 'patch_value': 0.2}):
        self.dataset = dataset
        self.source = source
        self.target = target
        self.poisoning_rate = poisoning_rate
        self.patch_attack_params = patch_attack_params
        self.transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ]
        )
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        
        image, label = self.dataset[idx]
        attacked = False
        if label == self.source and random.random() < self.poisoning_rate:
            image = add_patch(image, **self.patch_attack_params)
            label = self.target
            attacked = True
        
        return self.transforms(image), label, attacked
        
        
        

In [35]:
training_set, test_set = load_data(data='mnist')

In [36]:
attacked_train_set = AttackedDataset(training_set, source = 7, target = 1, poisoning_rate = 0.08, patch_attack_params = {"x": 20, "y": 20, "patch_size": 2, "radius": 0,  'patch_value': 0.4})

In [37]:
trainloader = torch.utils.data.DataLoader(attacked_train_set, batch_size=16, shuffle=True, num_workers=8)
clean_testloader = torch.utils.data.DataLoader(test_set, batch_size=16, shuffle=True, num_workers=8)

In [38]:
model = ResNet18()

In [39]:
torch.set_float32_matmul_precision('high')
trainer = pl.Trainer(max_epochs = 5, accelerator = "gpu", devices = 1, enable_progress_bar = True)
trainer.fit(model, trainloader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | conv1     | Conv2d           | 576   
1 | bn1       | BatchNorm2d      | 128   
2 | layer1    | Sequential       | 147 K 
3 | layer2    | Sequential       | 525 K 
4 | layer3    | Sequential       | 2.1 M 
5 | layer4    | Sequential       | 8.4 M 
6 | linear    | Linear           | 5.1 K 
7 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.691    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [40]:
model.cuda()
true_labels = []
pred_labels = []
for data in clean_testloader:
    images, labels = data
    images = images.cuda()
    
    preds = model(images)
    pred_labels.append(torch.argmax(preds, dim=1).cpu())
    
    true_labels.append(labels)
    del images
    del preds

print(accuracy_score(torch.cat(true_labels).cpu().numpy(), torch.cat(pred_labels).cpu().numpy()))

0.9716


In [41]:
class AttackedDatasetTest(Dataset):
    
    def __init__(self, dataset, source, target, patch_attack_params = {"x": 28, "y": 28, "patch_size": 2, "radius": 0, 'patch_value': 0.01}):
        self.dataset = dataset
        self.source = source
        self.target = target
        self.patch_attack_params = patch_attack_params
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        
        image, label = self.dataset[idx]
        attacked = False
        if label == self.source:
            image = add_patch(image, **self.patch_attack_params)
            label = self.target
            attacked = True
        
        return image, label, attacked
        
        
        

In [79]:
attacked_test_set = AttackedDatasetTest(test_set, 7, 1, patch_attack_params = {"x": 12, "y": 20, "patch_size": 2, "radius": 0,  'patch_value': 0.4})
testloader = torch.utils.data.DataLoader(attacked_test_set, batch_size=16, shuffle=False, num_workers=0)

In [80]:
model.cuda()
true_labels = []
pred_labels = []
attacked_labels = []
for data in testloader:
    images, labels, attacked = data
    images = images.cuda()
    
    preds = model(images)
    pred_labels.append(torch.argmax(preds, dim=1).cpu())
    
    for i in range(len(attacked)):
        if attacked[i]:
            labels[i] = 7
    
    true_labels.append(labels)
    attacked_labels.append(attacked)
    del images
    del preds

In [81]:
1 - accuracy_score(torch.cat(pred_labels).cpu().numpy()[torch.cat(attacked_labels)], torch.cat(true_labels).cpu().numpy()[torch.cat(attacked_labels)])

0.10603112840466922