In [13]:
%load_ext autoreload
%autoreload 2

import os
import time

import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets.vision import VisionDataset
from torch.utils.data import DataLoader
from torchvision import transforms


from datasets import BadNetsDataset, WaNetDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
device = "cuda"
batch_size = 128
epochs = 35

## Load cleansed dataset

In [15]:
class CleansedDataset(VisionDataset):

    def __init__(self, poison_dataset: VisionDataset, cleansed_labels_name: str, strategy: str = "remove"):
        self.poison_dataset = poison_dataset
        self.poison_indices = list(range(len(self.poison_dataset)))
        
        with open(f"./cleansed_labels/{cleansed_labels_name}.pkl", 'rb') as f:
            self.predicted_labels = pickle.load(f)

        assert strategy in ["relabel", "remove"]
        self.strategy = strategy

        poison_labels = [poison_label for _, poison_label in self.poison_dataset]
        if self.strategy == "remove":
            self.indices = [index for index in self.poison_indices if poison_labels[index]==self.predicted_labels[index]]
        elif self.strategy == "relabel":
            self.indices = self.poison_indices

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, index):
        if index > len(self):
            return IndexError()
        while index not in self.indices:
            index += 1

        item = self.poison_dataset[index][0]
        label = self.predicted_labels[index]

        return item, label
        

In [16]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_train_clean = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test_clean = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = torchvision.datasets.CIFAR10(root='C:/Datasets', train=True, download=False)
#train_dataset_transforms = torchvision.datasets.CIFAR10(root='C:/Datasets', train=True, download=False, transform=transform_train_clean)
#train_poison_dataset = BadNetsDataset(train_dataset, 0, "triggers/trigger_white.png", seed=1, transform=transform_train, return_original_label=False)
#train_poison_dataset = WaNetDataset(train_dataset, 0, seed=1, transform=transform_train, return_original_label=False)
train_poison_dataset = BadNetsDataset(train_dataset, 1, "triggers/trigger_10.png", seed=1, transform=transform_train, return_original_label=False)
train_cleansed_dataset = CleansedDataset(train_poison_dataset, "BadNets2-Energy-train", strategy="remove")
trainloader = DataLoader(train_cleansed_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

test_dataset = torchvision.datasets.CIFAR10(root='C:/Datasets', train=False, download=False)
#test_dataset_transforms = torchvision.datasets.CIFAR10(root='C:/Datasets', train=False, download=False, transform=transform_test_clean)
#test_poison_dataset = BadNetsDataset(test_dataset, 0, "triggers/trigger_white.png", seed=1, transform=transform_test, return_original_label=False)
#test_poison_dataset = WaNetDataset(test_dataset, 0, seed=1, transform=transform_test, return_original_label=False)
test_poison_dataset = BadNetsDataset(test_dataset, 1, "triggers/trigger_10.png", seed=1, transform=transform_train, return_original_label=False)
test_cleansed_dataset = CleansedDataset(test_poison_dataset, "BadNets2-Energy-test", strategy="remove")
testloader = DataLoader(test_cleansed_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

## Train a Resnet-18 classifier on the cleansed dataset

In [17]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

In [18]:
def save_model(model, optimizer, scheduler, epoch, name):
    out = os.path.join('./saved_models/', name)

    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch
                }, out)

    print(f"\tSaved model, optimizer, scheduler and epoch info to {out}")

In [19]:
model = ResNet18()
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [27]:
start_epoch = 1

load_checkpoint = True
checkpoint_name = "BadNets2-Resnet-Energy.pt"

if load_checkpoint:
    out = os.path.join('./saved_models/', checkpoint_name)
    checkpoint = torch.load(out, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

    print("Loaded checkpoint")

Loaded checkpoint


In [21]:
best_acc = 0
best_model = None

# Training
def train(epoch, model, dataloader, criterion):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for _, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        if len(targets.shape)>1:
            targets = targets.squeeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    return train_loss, acc

def test(epoch, model, dataloader, criterion, optimizer, save=False):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for _, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            if len(targets.shape)>1:
                targets = targets.squeeze(1)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        if save: save_model(model, optimizer, scheduler, epoch, f"Resnet-18.pt")
        best_acc = acc
    return test_loss, acc

In [22]:
for epoch in range(start_epoch, epochs+1):
    print(f"Epoch [{epoch}/{epochs}]\t")
    stime = time.time()

    train_loss, train_acc = train(epoch, model, trainloader, criterion)
    test_loss, test_acc = test(epoch, model, testloader, criterion, optimizer, save=True)
    scheduler.step()

    print(f"\tTraining Loss: {train_loss} Test Loss: {test_loss}")
    print(f"\tTraining Acc: {train_acc} Test Acc: {test_acc}")
    time_taken = (time.time()-stime)/60
    print(f"\tTime Taken: {time_taken} minutes")    

Epoch [1/35]	
	Saved model, optimizer, scheduler and epoch info to ./saved_models/Resnet-18.pt
	Training Loss: 550.7471189498901 Test Loss: 83.51505696773529
	Training Acc: 33.22002045159329 Test Acc: 46.04036186499652
	Time Taken: 0.8384736021359761 minutes
Epoch [2/35]	
	Saved model, optimizer, scheduler and epoch info to ./saved_models/Resnet-18.pt
	Training Loss: 375.1319259405136 Test Loss: 68.04603880643845
	Training Acc: 51.55459746289694 Test Acc: 57.60612386917189
	Time Taken: 0.8173271775245666 minutes
Epoch [3/35]	
	Saved model, optimizer, scheduler and epoch info to ./saved_models/Resnet-18.pt
	Training Loss: 289.09011709690094 Test Loss: 58.55984503030777
	Training Acc: 63.22306055329851 Test Acc: 63.799582463465555
	Time Taken: 0.9311299562454224 minutes
Epoch [4/35]	
	Saved model, optimizer, scheduler and epoch info to ./saved_models/Resnet-18.pt
	Training Loss: 225.76541778445244 Test Loss: 51.563322842121124
	Training Acc: 71.65796092087444 Test Acc: 68.99095337508699


In [28]:
# consider loading the best model checkpoint before running

transform_clean = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_poison = transforms.Compose([
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

clean_test_dataset = torchvision.datasets.CIFAR10(root='C:/Datasets', train=False, download=False, transform=transform_clean)
testloader_clean = DataLoader(clean_test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
_, c_acc = test(0, model, testloader_clean, criterion, optimizer, save=False)

test_dataset = torchvision.datasets.CIFAR10(root='C:/Datasets', train=False, download=False)
#full_poisoned_test_dataset = BadNetsDataset(test_dataset, 0, "triggers/trigger_white.png", seed=1, transform=transform_train, poisoning_rate=1.0, return_original_label=False)
#full_poisoned_test_dataset = WaNetDataset(test_dataset, 0, seed=1, transform=transform_poison, poisoning_rate=1.0, noise_rate=0.0, return_original_label=False)
full_poisoned_test_dataset = BadNetsDataset(test_dataset, 1, "triggers/trigger_10.png", seed=1, transform=transform_poison, poisoning_rate=1.0, return_original_label=False)
testloader_full_poison = DataLoader(full_poisoned_test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
_, asr = test(0, model, testloader_full_poison, criterion, optimizer, save=False)

In [29]:
print(f"Clean Accuracy (C-Acc): {c_acc}")
print(f"Attack Success Rate (ASR): {asr}")

Clean Accuracy (C-Acc): 79.49
Attack Success Rate (ASR): 11.28


CIFAR-10 \
ResNet-18

    Clean Accuracy (C-Acc): 90.49
    Attack Success Rate (ASR): 10.63

-----------------------------------------------------------------------

CIFAR-10 \
BadNets \
ResNet-18

    Clean Accuracy (C-Acc): 89.19
    Attack Success Rate (ASR): 62.01

CIFAR-10 \
BadNets \
kNN \
ResNet-18 

    Clean Accuracy (C-Acc): 84.16
    Attack Success Rate (ASR): 9.25

-----------------------------------------------------------------------

CIFAR-10 \
BadNets2 \
ResNet-18

    Clean Accuracy (C-Acc): 85.06
    Attack Success Rate (ASR): 98.9

CIFAR-10 \
BadNets2 \
kNN \
ResNet-18 

    Clean Accuracy (C-Acc): 77.89
    Attack Success Rate (ASR): 10.92


CIFAR-10 \
BadNets2 \
Energy \
ResNet-18 

    Clean Accuracy (C-Acc): 79.49
    Attack Success Rate (ASR): 11.28

-----------------------------------------------------------------------

CIFAR-10 \
WaNet \
ResNet-18 \

    Clean Accuracy (C-Acc): 83.32
    Attack Success Rate (ASR): 83.46

CIFAR-10 \
WaNet \
kNN \
ResNet-18 

    Clean Accuracy (C-Acc): 83.08
    Attack Success Rate (ASR): 12.50

CIFAR-10 \
WaNet \
Energy \
ResNet-18 

    Clean Accuracy (C-Acc): 79.52
    Attack Success Rate (ASR): 9.52