In [78]:
%load_ext autoreload
%autoreload 2

import os
import time

import pickle

import numpy as np

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets.vision import VisionDataset
from torch.utils.data import DataLoader
from torchvision import transforms


from datasets import BadNetsDataset, WaNetDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [79]:
# Run on the GPU when one is available and fall back to the CPU otherwise, so
# the notebook (including checkpoint loading with `map_location=device`) also
# works on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128  # mini-batch size shared by every DataLoader in this notebook

In [80]:
# Experiment configuration.
TRAIN = False  # True: run the training loop; False: only evaluate the loaded checkpoint
CHECKPOINT = "WaNet-Resnet-kNN.pt"  # file name under ./saved_models/ loaded by the checkpoint cell
DATASET = "wanet"  # poisoned dataset to build: "badnets" or "wanet"

# Per-attack TARGET_CLASS. Images of this class are skipped when building the
# fully-poisoned test set, so the ASR is measured only on other classes.
_TARGET_CLASS_BY_DATASET = {"badnets": 1, "wanet": 0}
if DATASET not in _TARGET_CLASS_BY_DATASET:
    raise ValueError(f"Invalid dataset: {DATASET!r}; expected one of {sorted(_TARGET_CLASS_BY_DATASET)}")
TARGET_CLASS = _TARGET_CLASS_BY_DATASET[DATASET]

# Base name of the cleansed-label pickles: ./cleansed_labels/<name>-train.pkl and -test.pkl.
# NOTE(review): these are the "Energy" cleansed labels while CHECKPOINT names a kNN
# model. The labels only feed the training loader and the cleansed test loader, but
# confirm that this pairing is intentional.
CLEANSED_LABELS_NAME = "WaNet-Energy"

## Load cleansed dataset

In [81]:
class CleansedDataset(VisionDataset):
    """Wrap a poisoned dataset and apply cleansed (predicted) labels to it.

    The predicted labels are read from
    ``<cleansed_labels_dir>/<cleansed_labels_name>.pkl`` and indexed
    positionally: ``predicted_labels[i]`` belongs to ``poison_dataset[i]``.

    Strategies:
        * ``"remove"``: keep only the samples whose label in ``poison_dataset``
          equals the predicted label. For kept samples the returned label
          therefore equals the dataset's own label.
        * ``"relabel"``: keep every sample and return its predicted label.

    Args:
        poison_dataset: Indexable dataset yielding ``(item, label)`` pairs.
        cleansed_labels_name: Base name of the pickle file (without ``.pkl``).
        strategy: ``"remove"`` (default) or ``"relabel"``.
        cleansed_labels_dir: Directory holding the pickle files.

    Raises:
        ValueError: if ``strategy`` is not a supported value.
    """

    STRATEGIES = ("relabel", "remove")

    def __init__(self, poison_dataset: VisionDataset, cleansed_labels_name: str, strategy: str = "remove",
                 cleansed_labels_dir: str = "./cleansed_labels"):
        if strategy not in self.STRATEGIES:
            raise ValueError(f"strategy must be one of {self.STRATEGIES}, got {strategy!r}")
        self.strategy = strategy
        self.poison_dataset = poison_dataset
        # Positions of every sample in the wrapped dataset.
        self.poison_indices = list(range(len(self.poison_dataset)))

        # NOTE: pickle.load can execute arbitrary code; only load label files you produced yourself.
        with open(os.path.join(cleansed_labels_dir, f"{cleansed_labels_name}.pkl"), 'rb') as f:
            self.predicted_labels = pickle.load(f)

        if self.strategy == "remove":
            # Iterating the dataset runs its transforms; only the labels are used here.
            poison_labels = [poison_label for _, poison_label in self.poison_dataset]
            self.indices = [index for index in self.poison_indices
                            if poison_labels[index] == self.predicted_labels[index]]
        else:  # "relabel"
            self.indices = self.poison_indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        """Return ``(item, label)`` for the ``index``-th kept sample.

        ``index`` is a position in ``self.indices`` (negative values count from
        the end). It is mapped to the underlying position in ``poison_dataset``,
        so every kept sample is reachable exactly once.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        if not -len(self) <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of length {len(self)}")
        source_index = self.indices[index]

        item = self.poison_dataset[source_index][0]
        label = self.predicted_labels[source_index]
        return item, label
        

class SkipLabelDataset(VisionDataset):
    """View of a dataset with every sample of one class removed.

    Reads ``original_dataset.data`` and ``original_dataset.targets`` directly, so
    items are returned without any transform attached to ``original_dataset``.
    It assumes ``data`` supports boolean-mask indexing (e.g. a NumPy array).
    Items come back as PIL images when the original dataset's first item is a
    PIL image; otherwise they are raw rows of ``data``.

    Args:
        original_dataset: Dataset exposing ``data``, ``targets`` and ``__getitem__``.
        skip_class: Label whose samples are dropped.
    """

    def __init__(self, original_dataset: VisionDataset, skip_class: int):
        first_item = original_dataset[0][0]
        self.return_as_pil = type(first_item) is Image.Image

        all_targets = np.array(original_dataset.targets)
        keep_mask = all_targets != skip_class
        self.data = original_dataset.data[keep_mask]
        self.targets = all_targets[keep_mask].tolist()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        sample, label = self.data[index], self.targets[index]
        return (Image.fromarray(sample) if self.return_as_pil else sample), label

In [82]:
# CIFAR-10 per-channel normalization statistics shared by every transform below.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
CIFAR10_ROOT = 'C:/Datasets'  # NOTE(review): machine-specific absolute path; consider making it configurable
TRIGGER_PATH = "triggers/trigger_10.png"  # BadNets trigger image

# Transforms for the poisoned datasets. They contain no ToTensor, presumably
# because BadNetsDataset / WaNetDataset already yield tensors (TODO confirm).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Transforms for plain CIFAR-10 datasets (PIL images, hence ToTensor).
transform_train_clean = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test_clean = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])


def make_poison_dataset(base_dataset, transform, **kwargs):
    """Build the poisoned dataset selected by DATASET on top of ``base_dataset``.

    The second positional argument to BadNetsDataset / WaNetDataset is
    TARGET_CLASS (1 for BadNets, 0 for WaNet; presumably the attack's target
    label). Extra keyword arguments are forwarded unchanged.

    Raises:
        ValueError: if DATASET is neither "badnets" nor "wanet".
    """
    if DATASET == "badnets":
        return BadNetsDataset(base_dataset, TARGET_CLASS, TRIGGER_PATH, seed=1, transform=transform,
                              return_original_label=False, **kwargs)
    if DATASET == "wanet":
        return WaNetDataset(base_dataset, TARGET_CLASS, seed=1, transform=transform,
                            return_original_label=False, **kwargs)
    raise ValueError(f"Invalid dataset: {DATASET!r}")


train_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=False)
train_poison_dataset = make_poison_dataset(train_dataset, transform_train)
train_cleansed_dataset = CleansedDataset(train_poison_dataset, CLEANSED_LABELS_NAME + "-train", strategy="remove")
trainloader = DataLoader(train_cleansed_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

test_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=False)
test_poison_dataset = make_poison_dataset(test_dataset, transform_test)
test_cleansed_dataset = CleansedDataset(test_poison_dataset, CLEANSED_LABELS_NAME + "-test", strategy="remove")
testloader = DataLoader(test_cleansed_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

## Train a ResNet-18 classifier on the cleansed dataset

In [83]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    ``forward`` computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    The shortcut is the identity (an empty ``nn.Sequential``) unless the block
    changes stride or channel count. In that case it is a strided 1x1 conv
    followed by BatchNorm.

    Submodules are registered in the order conv1, bn1, conv2, bn2, shortcut.
    That order fixes ``parameters()`` order, which a loaded optimizer state
    dict must match.
    """
    expansion = 1  # output channels = expansion * planes

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))

class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem conv followed by four stages of residual blocks.

    Args:
        block: Residual block class. It must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Number of output logits.

    Stage widths are 64/128/256/512 channels. Stages 2-4 are built with stride 2.
    The last feature map is globally average-pooled to 1x1. For 32x32 inputs
    that map is 4x4, so this equals ``avg_pool2d(out, 4)``. Other input sizes
    also work.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64  # channels fed to the next block; updated by _make_layer

        # Registration order fixes parameters() order, which saved optimizer
        # state dicts depend on; keep it unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, 1)  # global average pool, independent of input size
        out = torch.flatten(out, 1)
        return self.linear(out)


def ResNet18(num_classes=10):
    """Build a ResNet-18: BasicBlock stages of [2, 2, 2, 2].

    Args:
        num_classes: Number of output logits (default 10).
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

In [84]:
def save_model(model, optimizer, scheduler, epoch, name, directory='./saved_models/'):
    """Save model, optimizer and scheduler state plus the epoch to ``directory/name``.

    The checkpoint dict has the keys ``model_state_dict``,
    ``optimizer_state_dict``, ``scheduler_state_dict`` and ``epoch``. These
    are the keys the checkpoint-loading cell reads back. ``directory`` is
    created if it does not exist yet, so the first save does not fail.
    """
    os.makedirs(directory, exist_ok=True)
    out = os.path.join(directory, name)

    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch
                }, out)

    print(f"\tSaved model, optimizer, scheduler and epoch info to {out}")

In [85]:
model = ResNet18()
model.to(device)

epochs = 35  # the training loop runs epochs start_epoch..epochs (inclusive)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Anneal the learning rate over the actual training length. With T_max=200 the
# LR would still be about 0.093 after 35 epochs, i.e. it would barely decay.
# Loading a checkpoint in the next cell restores the saved scheduler state,
# so resumed runs continue with the schedule they were saved with.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [86]:
# Restore model, optimizer and scheduler state from ./saved_models/<CHECKPOINT>
# and resume from the epoch after the one stored in the checkpoint.
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
start_epoch = 1

load_checkpoint = True
checkpoint_name = CHECKPOINT

if load_checkpoint:
    out = os.path.join('./saved_models/', checkpoint_name)
    checkpoint = torch.load(out, map_location=device)
    for _stateful, _state_key in ((model, "model_state_dict"),
                                  (optimizer, "optimizer_state_dict"),
                                  (scheduler, "scheduler_state_dict")):
        _stateful.load_state_dict(checkpoint[_state_key])
    start_epoch = 1 + checkpoint["epoch"]

    print("Loaded checkpoint")
    print(f"{start_epoch = }")

Loaded checkpoint
start_epoch = 27


In [87]:
# Best accuracy seen so far; `test(...)` reads and updates it via `global best_acc`
# to decide when to save a checkpoint.
# NOTE(review): best_model is not referenced by any cell shown here.
best_acc, best_model = 0, None

# Training
def train(epoch, model, dataloader, criterion, optimizer=None):
    """Run one training epoch over ``dataloader``.

    Args:
        epoch: Epoch number (kept for symmetry with ``test``; not used).
        model: Network to train. Batches are moved to the global ``device``.
        dataloader: Iterable of ``(inputs, targets)`` batches. Targets with
            more than one dimension are squeezed along dim 1.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer to step. Defaults to the notebook-level ``optimizer``.

    Returns:
        ``(train_loss, acc)``, where ``train_loss`` is the SUM of per-batch
        losses and ``acc`` is the accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if optimizer is None:
        optimizer = globals()["optimizer"]  # backward compatible with the global-optimizer call style

    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        if targets.dim() > 1:
            targets = targets.squeeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError("train() received an empty dataloader")
    acc = 100. * correct / total
    return train_loss, acc

def test(epoch, model, dataloader, criterion, optimizer, save=False):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for _, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            if len(targets.shape)>1:
                targets = targets.squeeze(1)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        if save: save_model(model, optimizer, scheduler, epoch, f"Resnet-18.pt")
        best_acc = acc
    return test_loss, acc

In [88]:
# Train from start_epoch through epochs (inclusive). After each epoch, evaluate
# on the cleansed test loader (saving when accuracy improves) and step the LR schedule.
if TRAIN:
    for epoch in range(start_epoch, epochs + 1):
        print(f"Epoch [{epoch}/{epochs}]\t")
        stime = time.time()

        train_loss, train_acc = train(epoch, model, trainloader, criterion)
        test_loss, test_acc = test(epoch, model, testloader, criterion, optimizer, save=True)
        scheduler.step()

        print(f"\tTraining Loss: {train_loss} Test Loss: {test_loss}\n"
              f"\tTraining Acc: {train_acc} Test Acc: {test_acc}")
        time_taken = (time.time() - stime) / 60
        print(f"\tTime Taken: {time_taken} minutes")

In [89]:
# Evaluate the current `model`. Consider loading the best model checkpoint
# before running this cell.
#   * C-Acc: accuracy on the clean CIFAR-10 test set.
#   * ASR:   accuracy against the labels returned by the fully-poisoned test set
#            (return_original_label=False, presumably the attack's target label),
#            measured on images whose true class is not TARGET_CLASS.

# Clean images are PIL and need ToTensor. The poisoned-dataset transform has no
# ToTensor, presumably because the poison wrappers already yield tensors (TODO confirm).
transform_clean = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_poison = transforms.Compose([
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

clean_test_dataset = torchvision.datasets.CIFAR10(root='C:/Datasets', train=False, download=False, transform=transform_clean)
testloader_clean = DataLoader(clean_test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
_, c_acc = test(0, model, testloader_clean, criterion, optimizer, save=False)

# Exclude TARGET_CLASS images so the ASR only counts images from other classes.
test_dataset = SkipLabelDataset(
    torchvision.datasets.CIFAR10(root='C:/Datasets', train=False, download=False),
    TARGET_CLASS,
)

# The second positional argument is TARGET_CLASS (1 for BadNets, 0 for WaNet).
# poisoning_rate=1.0 presumably poisons every image; for WaNet, noise_rate=0.0
# presumably disables its noise-mode samples.
if DATASET == "badnets":
    full_poisoned_test_dataset = BadNetsDataset(test_dataset, TARGET_CLASS, "triggers/trigger_10.png", seed=1, transform=transform_poison, poisoning_rate=1.0, return_original_label=False)
elif DATASET == "wanet":
    full_poisoned_test_dataset = WaNetDataset(test_dataset, TARGET_CLASS, seed=1, transform=transform_poison, poisoning_rate=1.0, noise_rate=0.0, return_original_label=False)
else:
    raise ValueError(f"Invalid dataset: {DATASET!r}")

testloader_full_poison = DataLoader(full_poisoned_test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
_, asr = test(0, model, testloader_full_poison, criterion, optimizer, save=False)

In [90]:
# Report the two metrics computed by the evaluation cell.
print(f"Clean Accuracy (C-Acc): {c_acc}\nAttack Success Rate (ASR): {asr}")

Clean Accuracy (C-Acc): 83.08
Attack Success Rate (ASR): 57.11


CIFAR-10 \
ResNet-18

    Clean Accuracy (C-Acc): 90.49
    Attack Success Rate (ASR): 0.33 


-----------------------------------------------------------------------

CIFAR-10 \
BadNets \
ResNet-18

    Clean Accuracy (C-Acc): 85.88
    Attack Success Rate (ASR): 99.66


CIFAR-10 \
BadNets \
kNN \
ResNet-18 

    Clean Accuracy (C-Acc): 79.76
    Attack Success Rate (ASR): 0.92


CIFAR-10 \
BadNets \
Energy \
ResNet-18 

    Clean Accuracy (C-Acc): 79.49
    Attack Success Rate (ASR): 1.91

-----------------------------------------------------------------------

CIFAR-10 \
WaNet \
ResNet-18

    Clean Accuracy (C-Acc): 83.32
    Attack Success Rate (ASR): 83.46 (TODO: re-verify)

CIFAR-10 \
WaNet \
kNN \
ResNet-18 

    Clean Accuracy (C-Acc): 83.08
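    Attack Success Rate (ASR): 12.50 (TODO: re-verify. The latest run of this notebook with the `WaNet-Resnet-kNN.pt` checkpoint reported C-Acc 83.08 but ASR 57.11.)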

CIFAR-10 \
WaNet \
Energy \
ResNet-18 

    Clean Accuracy (C-Acc): 79.52
    Attack Success Rate (ASR): 1.73