# How to use this starter kit

1. **Copy the notebook**. This is a shared file, so your changes will not be saved. Click "File" -> "Save a copy in Drive" to make your own copy, which you can then modify as you like.

2. **Implement your own method**. Please put all your code into the `clean_model` function in section 4.

## For Google Drive users

In [None]:
# Clone the helper code from the backdoorDefense branch and flatten it into
# /content, presumably so that later cells can `from util import *` and
# `import fastres` from the working directory — confirm the repo layout.
! git clone -b backdoorDefense https://github.com/PeterZaipinai/Mod-MogaNet.git
import os
os.chdir("/content/Mod-MogaNet")
# List the cloned files (for a visual check only).
! ls
os.chdir("/content")
# Copy everything to /content, then drop the clone.
! cp -r Mod-MogaNet/* /content
! rm -r Mod-MogaNet

# 1. Download and import packages

In [None]:
#@title Load package and data
!pip install timm
!pip install func_timeout

import numpy as np
from torch.utils.data import Dataset, Subset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import functional as F
import torchvision
import os
import random
import tqdm
from torchvision import transforms
import copy
import time
from tqdm.notebook import trange, tqdm
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
# Notebook-global device string; the defense code below (e.g. IBAU) moves
# tensors with `.to(device)`.
device = 'cuda'

In [None]:
#@title Download dataset and models
%%shell

# Download competition_data.zip from Google Drive. The inner wget fetches the
# "confirm" token that Drive requires for large files (saving cookies); the
# outer wget then downloads the file with that token and those cookies.
# NOTE(review): this scraping relies on the layout of Drive's download page and
# may break if Drive changes it — verify the zip actually downloaded.
filename='competition_data.zip'
fileid='1g-BO8zyHm9R64jXeAJob_RS5kopN8Mf6'
wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=${fileid}' -O- | sed -rn 's/.confirm=([0-9A-Za-z_]+)./\1\n/p')&id=${fileid}" -O ${filename} && rm -rf /tmp/cookies.txt

In [None]:
#@title Unzip the package
# Extract the archive into /content without overwriting existing files (-n).
! unzip -n './competition_data.zip' -d '/content'
# Intent appears to be: move the extracted data aside, mount a 2 GB RAM-backed
# tmpfs at /content/data, then move the data back onto it.
# NOTE(review): after the first mv, /content/data no longer exists, so the mount
# presumably fails ("mount point does not exist") and the second mv just renames
# the folder back. If the mount did succeed, the second mv would nest the data
# at /content/data/competition_data and break the './data/*.npy' paths used by
# the case builders. Confirm the intended layout.
! mv '/content/data' '/content/competition_data'
! mount -t tmpfs -o size=2G tmpfs /content/data
! mv '/content/competition_data' '/content/data'
! rm './competition_data.zip'

In [None]:
from util import *
import timm
from func_timeout import func_timeout, FunctionTimedOut
from tqdm import tqdm

In [None]:
#@title Load all poisoned models and evaluation datasets
## BadNets all2all
def PubFig_all2all():
    """Build the PubFig BadNets all-to-all evaluation case.

    Poisoning: a white 32x32 square (rows/cols 184..215, all channels) is
    stamped onto the image, and every label ``y`` is relabeled to
    ``(y + 1) % 83``.

    Returns:
        (val_dataset, test_dataset, asr_dataset, pacc_dataset, net); the first
        four come from ``get_dataset('./data/pubfig.npy', ...)``.

    NOTE(review): model loading is disabled and ``net`` is the placeholder
    ``1``; anything that evaluates or trains ``net`` (e.g. ``get_results``)
    will fail until the commented-out ViT-Tiny loading is restored.
    """
    # Number of PubFig classes (matches num_classes of the ViT-Tiny head below).
    num_classes = 83

    def all2all_badnets(img):
        # BadNets trigger: set the fixed 32x32 block with top-left corner
        # (184, 184) and bottom-right corner (215, 215) to 255, in place.
        img[184:216, 184:216, :] = 255
        return img

    def all2all_label(label):
        # All-to-all target: the next class, wrapping the last class (82) to 0.
        # The previous check ``label == 83`` never matched a valid 0..82 label,
        # so class 82 was mapped to the non-existent class 83.
        return int((label + 1) % num_classes)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])

    # ((image trigger slot 1, image trigger slot 2), label map); slot semantics
    # are defined by get_dataset.
    poison_method = ((all2all_badnets, None), all2all_label)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/pubfig.npy', test_transform,
                                                                       poison_method, -1)

    net = 1
    # net = timm.create_model("vit_tiny_patch16_224", pretrained=False, num_classes=83)
    # net.load_state_dict(torch.load('./checkpoint/pubfig_vittiny_all2all.pth', map_location='cuda:0'))
    # net = net.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net


## SIG
def CIFAR10_SIG():
    """Build the CIFAR-10 SIG (sinusoidal signal) evaluation case, target label 6.

    Returns:
        (val_dataset, test_dataset, asr_dataset, pacc_dataset, net) where net
        is ``ResNet18`` loaded from './checkpoint/cifar10_resnet18_sig.pth'
        and moved to CUDA.
    """
    best_noise = np.zeros((32, 32, 3))

    def plant_sin_trigger(img, delta=20, f=6, debug=False):
        """Return a horizontal sinusoidal trigger with the shape of ``img``.

        Implements the signal of:
        > Barni, M., Kallas, K., & Tondi, B. (2019).
        > A new Backdoor Attack in CNNs by training set corruption without label poisoning.
        > arXiv preprint arXiv:1902.11237

        Every pixel in column ``j`` (all rows and channels) gets
        ``delta * sin(2 * pi * j * f / W)`` with ``W = img.shape[1]``. The
        signal is scaled by ``1 - alpha`` (alpha = 0.2) and returned as uint8.
        ``img`` must be 3-D; only its shape and dtype are used. ``debug`` is
        accepted for compatibility and unused.
        """
        alpha = 0.2
        width = img.shape[1]
        # One value per column, computed with the same scalar expression as
        # before, then broadcast over rows and channels (replaces the former
        # triple loop that rewrote each pixel once per channel).
        column_signal = np.array([delta * np.sin(2 * np.pi * j * f / width) for j in range(width)])
        pattern = np.zeros_like(img)
        pattern[...] = column_signal[np.newaxis, :, np.newaxis]
        scaled = (1 - alpha) * pattern
        # The signal is negative on roughly half the columns. Casting negative
        # floats straight to uint8 is undefined in C and therefore platform
        # dependent (typically wraps modulo 256 on x86, may clamp to 0 on other
        # CPUs; recent NumPy also warns). Truncating to a signed integer first
        # and then wrapping explicitly reproduces the x86 result everywhere.
        return scaled.astype(np.int64).astype(np.uint8)

    noisy = plant_sin_trigger(best_noise, delta=20, f=15, debug=False)

    def SIG(img):
        # Additive trigger, registered in the first slot of poison_method
        # (presumably applied to the raw image before test_transform — confirm
        # in get_dataset).
        return img + noisy

    def SIG_tar(label):
        return 6

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    poison_method = ((SIG, None), SIG_tar)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/cifar_10.npy', test_transform,
                                                                       poison_method, 6)
    net = ResNet18()
    net.load_state_dict(torch.load('./checkpoint/cifar10_resnet18_sig.pth', map_location='cuda:0'))
    net = net.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net


## Narcissus
def TinyImangeNet_Narcissus():
    """Build the Tiny-ImageNet Narcissus evaluation case, target label 2.

    The Narcissus trigger from './checkpoint/narcissus_trigger.npy' is
    amplified 3x, added to the image and the sum is clipped to [-1, 1] (the
    range produced by Normalize(0.5, 0.5)).

    Returns:
        (val_dataset, test_dataset, asr_dataset, pacc_dataset, net) where net
        is a torchvision ResNet-18 with a 200-way head, moved to CUDA.
    """
    # Amplify the stored trigger once instead of on every call.
    amplified_trigger = np.load('./checkpoint/narcissus_trigger.npy')[0] * 3

    def Narcissus(img):
        # Registered in the second slot of the poison tuple, presumably applied
        # to the already-transformed tensor — confirm in get_dataset.
        return torch.clip(img + amplified_trigger, -1, 1)

    def Narcissus_tar(label):
        return 2

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    poisoning = ((None, Narcissus), Narcissus_tar)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset(
        './data/tiny_imagenet.npy', preprocess, poisoning, 2)

    net = torchvision.models.resnet18()
    net.fc = nn.Linear(net.fc.in_features, 200)
    checkpoint = torch.load('./checkpoint/tiny_imagenet_resnet18_narcissus.pth', map_location='cuda:0')
    net.load_state_dict(checkpoint)
    net = net.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net


def GTSRB_WaNetFrequency():
    """Build the GTSRB case carrying two backdoors: WaNet and Frequency.

    WaNet (target label 2) warps the image with a fixed sampling grid. The
    Frequency trigger (target label 13, printed as "Smooth" by the test cells)
    is a universal additive perturbation clipped to [-1, 1].

    Returns:
        (val_dataset, test_dataset, (asr_wanet, asr_frequency),
         (pacc_wanet, pacc_frequency), net) where net is a GoogLeNet loaded
        from './checkpoint/gtsrb_googlenet_wantfrequency.pth' on CUDA.
    """
    # ---------------- WaNet (target label 2) ----------------
    # Warp grid = identity grid + scaled pre-generated noise grid, clamped to
    # grid_sample's [-1, 1] coordinate range. Both grids are loaded onto the
    # CPU because Wanet() builds a CPU tensor from a NumPy image and converts
    # the result back with .numpy(); a GPU-saved grid would otherwise cause a
    # device mismatch in grid_sample.
    identity_grid = torch.load("./checkpoint/WaNet_identity_grid.pth", map_location='cpu')
    noise_grid = torch.load("./checkpoint/WaNet_noise_grid.pth", map_location='cpu')
    h = identity_grid.shape[2]  # presumably the grid width, assuming shape (1, H, W, 2) — confirm
    s = 0.5  # warping strength
    grid_rescale = 1
    grid = identity_grid + s * noise_grid / h
    grid = torch.clamp(grid * grid_rescale, -1, 1)

    # Shared by both poisoned variants (previously defined twice, identically).
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def Wanet(img):
        # HWC NumPy image -> CHW float tensor (integer input is rescaled to [0, 1]).
        img = torch.from_numpy(img).permute(2, 0, 1)
        img = torchvision.transforms.functional.convert_image_dtype(img, torch.float)
        # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
        # the channel axis of a single-channel image.
        poison_img = nn.functional.grid_sample(img.unsqueeze(0), grid, align_corners=True).squeeze(0)  # CHW
        # Back to HWC NumPy, presumably so test_transform can treat it like a raw image.
        return poison_img.permute(1, 2, 0).numpy()

    def Wanet_tar(label):
        return 2

    poison_method = ((Wanet, None), Wanet_tar)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/gtsrb.npy', test_transform,
                                                                       poison_method, 2)

    net = GoogLeNet()
    net.load_state_dict(torch.load('./checkpoint/gtsrb_googlenet_wantfrequency.pth', map_location='cuda:0'))
    net = net.cuda()

    # ---------------- Frequency (target label 13) ----------------
    # Load the pre-computed universal perturbation and convert it to a tensor
    # with ToTensor (presumably HWC -> CHW — confirm the stored layout).
    trigger_transform = transforms.Compose([transforms.ToTensor(), ])
    noisy = trigger_transform(np.load('./checkpoint/gtsrb_universal.npy')[0])

    def Frequency(img):
        # Add the perturbation and clip to [-1, 1], the normalized pixel range.
        return torch.clip(img + noisy, -1, 1)

    def Frequency_tar(label):
        return 13

    poison_method = ((None, Frequency), Frequency_tar)
    _, _, asr_dataset2, pacc_dataset2 = get_dataset('./data/gtsrb.npy', test_transform, poison_method, 13)

    return val_dataset, test_dataset, (asr_dataset, asr_dataset2), (pacc_dataset, pacc_dataset2), net


## Clean STL-10
def STL10_Clean():
    """Build the clean (unpoisoned) STL-10 case.

    Images are converted to tensors, resized so the shorter side is 224 and
    normalized with mean/std 0.5. No poisoning is configured, so the ASR and
    PACC slots of the returned tuple are None.

    Returns:
        (val_dataset, test_dataset, None, None, net) where net is a torchvision
        VGG16-BN loaded from './checkpoint/stl_10_vgg.pth' on CUDA.
        NOTE(review): vgg16_bn() is built with its default head; this relies on
        the checkpoint matching that head shape.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    no_poisoning = (None, None)
    val_dataset, test_dataset, _unused_asr, _unused_pacc = get_dataset(
        './data/stl10.npy', preprocess, no_poisoning, -1)

    net = torchvision.models.vgg16_bn()
    weights = torch.load('./checkpoint/stl_10_vgg.pth', map_location='cuda:0')
    net.load_state_dict(weights)
    net = net.cuda()

    return val_dataset, test_dataset, None, None, net

# 2. Test attack effect



> Attack settings


|               |        Case 1        |       Case 2       |         Case 3        |       Case 4       |        Case 5        |
|:-------------:|:--------------------:|:------------------:|:---------------------:|:------------------:|:--------------------:|
|     Model     |       ViT-Tiny       |      ResNet-18     |       ResNet-18       |      GoogLeNet     |       VGG16-bn       |
|    Dataset    |        PubFig        |      CIFAR-10      |     Tiny-ImageNet     |        GTSRB       |        STL-10        |
|  Dataset Info | 224\*224\*3 83 Classes | 32\*32\*3 10 Classes | 224\*224\*3 200 Classes | 32\*32\*3 43 Classes | 224\*224\*3 10 Classes |
| Poison Method |    BadNets All2All   |         SIG        |       Narcissus       |  WaNet & Frequency |          N/A         |
|  Target Label |          All         |          6         |           2           |       2 & 13       |          N/A         |
|  Defense Time |        1350 S        |        900 S       |         1800 S        |        690 S       |         450 S        |

In [None]:
## Test Case-1
# Each case builds its datasets and poisoned model, then prints get_results on
# the clean test set (ACC), the poisoned set (ASR) and the PACC set.
# This cell deliberately leaves model/test_dataset/asr_dataset/pacc_dataset as
# notebook globals.
# NOTE(review): PubFig_all2all currently returns the placeholder net = 1, so
# these get_results calls are expected to fail until its model loading is
# restored.
print("----------------- Testing attack: PubFig all2all -----------------")
_, test_dataset, asr_dataset, pacc_dataset, model = PubFig_all2all()
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))
## Test Case-2
print("----------------- Testing attack: CIFAR-10 SIG -----------------")
_, test_dataset, asr_dataset, pacc_dataset, model = CIFAR10_SIG()
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))
## Test Case-3
print("----------------- Testing attack: Tiny-Imagenet Narcissus -----------------")
_, test_dataset, asr_dataset, pacc_dataset, model = TinyImangeNet_Narcissus()
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))
## Test Case-4
# GTSRB returns (WaNet, Frequency) pairs for ASR/PACC; the Frequency trigger
# is labeled "Smooth" in the printout.
print("----------------- Testing attack: GTSRB WaNet & Smooth -----------------")
_, test_dataset, asr_dataset, pacc_dataset, model = GTSRB_WaNetFrequency()
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('WaNet ASR %.3f%%' % (100 * get_results(model, asr_dataset[0])))
print('WaNet PACC %.3f%%' % (100 * get_results(model, pacc_dataset[0])))
print('Smooth ASR %.3f%%' % (100 * get_results(model, asr_dataset[1])))
print('Smooth PACC %.3f%%' % (100 * get_results(model, pacc_dataset[1])))
## Test Case-5
# Clean model: only accuracy is meaningful.
print("----------------- Testing attack: STL-10 -----------------")
_, test_dataset, _, _, model = STL10_Clean()
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))

# 3. Baseline defense

In [None]:
def test_defense(defense_method):
    """Run ``defense_method`` on every enabled benchmark case under its time budget.

    ``defense_method(model, val_dataset, budget_seconds)`` must return the
    cleaned model. Each call is wrapped in ``func_timeout``. On timeout, a
    message is printed and the model obtained before the call is kept instead.
    NOTE(review): a defense that trains the model in place (e.g. IBAU) may
    leave it partially modified when interrupted.

    Returns:
        A list with one model per enabled case, in table order.
    """
    # (banner, case builder, time budget in seconds). Only CIFAR-10 SIG is
    # enabled; uncomment other rows to run those cases as well.
    cases = [
        # ("----------------- Testing defense: PubFig all2all -----------------", PubFig_all2all, 1350),
        ("----------------- Testing defense: CIFAR-10 SIG -----------------", CIFAR10_SIG, 900),
        # ("----------------- Testing defense: Tiny-Imagenet Narcissus -----------------", TinyImangeNet_Narcissus, 1800),
        # ("----------------- Testing defense: GTSRB WaNet & Smooth -----------------", GTSRB_WaNetFrequency, 690),
        # ("----------------- Testing defense: STL-10 -----------------", STL10_Clean, 450),
    ]

    cleaned_models = []
    for banner, build_case, budget in cases:
        print(banner)
        val_dataset, _, _, _, model = build_case()
        try:
            model = func_timeout(budget, defense_method, args=(model, val_dataset, budget))
        except FunctionTimedOut:
            print("This test case exceed the maximum executable time!\n")
        cleaned_models.append(model)
    return cleaned_models

In [None]:
from tqdm import tqdm


#@title I-BAU Defense
def IBAU(net, val_dataset, allow_time):
    '''I-BAU (implicit backdoor adversarial unlearning) defense.

    Code from https://github.com/YiZeng623/I-BAU

    Each round (1) optimizes one universal perturbation that pushes ``net``
    away from its own predictions on the validation data, then (2) unlearns
    that perturbation with hypergradient steps (``fixed_point`` with a
    ``GradientDescent`` inner optimizer) on ``net``'s parameters. Rounds run
    while the remaining budget exceeds twice the mean of the last five
    recorded durations, and at most 149 rounds run.

    Args:
        net: model to clean; trained in place and also returned.
        val_dataset: clean dataset yielding (image, label) pairs.
        allow_time: time budget in seconds.

    Returns:
        The cleaned ``net``.
    '''
    # torch.cuda.Event.elapsed_time reports milliseconds.
    allow_time = allow_time * 1000
    # Use torch.nn.functional explicitly: the notebook-global ``F`` is bound to
    # torchvision.transforms.functional (which has no cross_entropy) until the
    # section-4 cell rebinds it to torch.nn.functional.
    cross_entropy = nn.functional.cross_entropy

    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=4, shuffle=True,
                                                 drop_last=True)

    # Cache one shuffled pass of the data; the unlearning step indexes it per batch.
    images_list, labels_list = [], []
    for images, labels in val_dataloader:
        images_list.append(images)
        labels_list.append(labels)

    def loss_inner(perturb, model_params):
        # Inner objective for the hypergradient solver: maximize the loss of
        # the first cached batch, with an L2 penalty on the perturbation.
        images = images_list[0].to(device)
        labels = labels_list[0].long().to(device)
        per_img = images + perturb[0]
        per_logits = net.forward(per_img)
        loss = cross_entropy(per_logits, labels, reduction='none')
        loss_regu = torch.mean(-loss) + 0.001 * torch.pow(torch.norm(perturb[0]), 2)
        return loss_regu

    def loss_outer(perturb, model_params):
        # Outer objective: cross-entropy on the current cached batch
        # (``batchnum`` is read from the unlearning loop below) with the
        # perturbation pasted onto a random subset of its samples. The subset
        # size is the number of 32 uniform draws that exceed 0.97.
        random_pick = np.where(np.random.uniform(0, 1, 32) > 0.97)[0].shape[0]

        images, labels = images_list[batchnum].to(device), labels_list[batchnum].long().to(device)
        patching = torch.zeros_like(images, device='cuda')
        number = images.shape[0]
        random_pick = min(number, random_pick)
        rand_idx = random.sample(list(np.arange(number)), random_pick)
        patching[rand_idx] = perturb[0]
        unlearn_imgs = images + patching
        logits = net(unlearn_imgs)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        return loss

    def get_lr(net, loader):
        # Try Adam learning rates 1e-2 .. 1e-7: train a throw-away copy of
        # ``net`` for one epoch each, score it with get_results on
        # loader.dataset (the same data it trained on) and return the best lr.
        lr_list = [0.1 ** i for i in range(2, 8)]
        acc_list = []
        for lr in lr_list:
            copy_net = copy.deepcopy(net).cuda()
            optimizer = torch.optim.Adam(copy_net.parameters(), lr=lr)
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.type(torch.LongTensor).to(device)
                optimizer.zero_grad()

                # forward + backward
                outputs = copy_net(inputs)
                loss = cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

            acc_list.append(get_results(copy_net, loader.dataset))
            print("lr = " + str(lr) + " ACC: " + str(acc_list[-1] * 100))
        return lr_list[acc_list.index(max(acc_list))]

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Durations in ms. Five leading zeros seed the moving average used by the
    # loop condition; the setup duration is appended next, then one entry per round.
    every_time = [0] * 5

    start.record()

    curr_lr = get_lr(net, val_dataloader)
    net = net.cuda()
    outer_opt = torch.optim.Adam(net.parameters(), lr=curr_lr)
    inner_opt = GradientDescent(loss_inner, 0.1)

    end.record()
    torch.cuda.synchronize()
    every_time.append(start.elapsed_time(end))

    # Only the shape/dtype of one sample are needed; fetch it once.
    sample_image = val_dataset[0][0]
    # Perturbation step size grows with image height
    # (about 5e-4 for 32-pixel images, about 0.097 for 224-pixel images).
    batch_lr = 0.0005 * sample_image.shape[1] - 0.0155

    net.train()
    while (allow_time - np.sum(every_time)) > (np.mean(every_time[-5:]) * 2) and len(every_time) < 155:
        start.record()
        batch_pert = torch.zeros_like(sample_image.unsqueeze(0), requires_grad=True, device='cuda')
        batch_opt = torch.optim.Adam(params=[batch_pert], lr=batch_lr)

        # Step 1: learn a universal perturbation that moves the model away
        # from its own current predictions.
        for images, _ in val_dataloader:
            images = images.to(device)
            with torch.no_grad():
                # Reference labels only; no gradient is needed through this pass.
                ori_lab = torch.argmax(net.forward(images), dim=1).long()
            per_logits = net.forward(images + batch_pert)
            loss = -cross_entropy(per_logits, ori_lab) + 0.001 * torch.pow(torch.norm(batch_pert), 2)
            batch_opt.zero_grad()
            loss.backward()
            batch_opt.step()

        # Step 2: unlearn the perturbation, one cached batch at a time.
        # (Parameter grads accumulated in step 1 are cleared by zero_grad here.)
        for batchnum in range(len(images_list)):
            outer_opt.zero_grad()
            fixed_point(batch_pert, list(net.parameters()), 5, inner_opt, loss_outer)
            outer_opt.step()

        print('Round:', len(every_time) - 5)
        end.record()
        torch.cuda.synchronize()
        every_time.append(start.elapsed_time(end))
    return net

In [None]:
from tqdm import tqdm


#@title Neural Cleanse Defense
def neural_cleanse(model, val_dataset, allow_time):
    '''Neural Cleanse detection followed by trigger unlearning.

    Code from https://github.com/VinAIResearch/input-aware-backdoor-attack-release

    The procedure has three stages:
    1. For every label, reverse-engineer a (mask, pattern) trigger that sends
       ``val_dataset`` to that label while keeping the mask small.
    2. Flag labels whose mask L1 norm lies below the median with a MAD
       anomaly index above 2.
    3. For each flagged label, fine-tune ``model`` for 10 epochs on
       ``val_dataset`` with that trigger pasted onto 20% of the samples,
       keeping their true labels.

    Args:
        model: suspect classifier; fine-tuned in place when labels are flagged.
        val_dataset: dataset with a ``targets`` attribute, yielding
            (image, label) pairs with CHW image tensors.
        allow_time: time budget in seconds. NOTE(review): unused here; the
            caller's func_timeout is the only time limit.

    Returns:
        ``model``, which is unchanged if no label is flagged.
    '''

    class RegressionModel(nn.Module):
        """Frozen copy of ``model`` preceded by a learnable mask/pattern blend."""

        def __init__(self, opt, init_mask, init_pattern, model):
            super(RegressionModel, self).__init__()
            self._EPSILON = opt.EPSILON
            # Stored in tanh space so the decoded mask/pattern stay in (0, 1).
            self.mask_tanh = nn.Parameter(torch.tensor(init_mask))
            self.pattern_tanh = nn.Parameter(torch.tensor(init_pattern))

            self.classifier = copy.deepcopy(model)
            for param in self.classifier.parameters():
                param.requires_grad = False
            self.classifier.eval()
            self.classifier = self.classifier.cuda()

        def forward(self, x):
            mask = self.get_raw_mask()
            pattern = self.get_raw_pattern()
            x = (1 - mask) * x + mask * pattern
            return self.classifier(x)

        def get_raw_mask(self):
            mask = nn.Tanh()(self.mask_tanh)
            return mask / (2 + self._EPSILON) + 0.5

        def get_raw_pattern(self):
            # NOTE(review): the pattern is confined to (0, 1) while the case
            # builders normalize inputs with Normalize(0.5, 0.5) to [-1, 1];
            # confirm this restricted range is intended.
            pattern = nn.Tanh()(self.pattern_tanh)
            return pattern / (2 + self._EPSILON) + 0.5

    class Recorder:
        """Tracks the best reverse-engineered trigger and the cost schedule."""

        def __init__(self, opt):
            # Best optimization results
            self.mask_best = None
            self.pattern_best = None
            self.reg_best = float("inf")

            # Logs and counters for adjusting balance cost
            self.logs = []
            self.cost_set_counter = 0
            self.cost_up_counter = 0
            self.cost_down_counter = 0
            self.cost_up_flag = False
            self.cost_down_flag = False

            # Counter for early stop
            self.early_stop_counter = 0
            self.early_stop_reg_best = self.reg_best

            # Cost: weight of the mask-norm term in the loss
            self.cost = opt.init_cost
            self.cost_multiplier_up = opt.cost_multiplier
            self.cost_multiplier_down = opt.cost_multiplier ** 1.5

        def reset_state(self, opt):
            self.cost = opt.init_cost
            self.cost_up_counter = 0
            self.cost_down_counter = 0
            self.cost_up_flag = False
            self.cost_down_flag = False
            print("Initialize cost to {:f}".format(self.cost))

    def train(opt, init_mask, init_pattern, model, val_dataset):
        # Optimize a trigger for opt.target_label; returns (recorder, opt).
        dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=128, num_workers=4, shuffle=False,
                                                 drop_last=True)

        regression_model = RegressionModel(opt, init_mask, init_pattern, model).cuda()
        optimizerR = torch.optim.Adam(regression_model.parameters(), lr=opt.lr, betas=(0.5, 0.9))
        recorder = Recorder(opt)

        for epoch in range(opt.epoch):
            early_stop = train_step(regression_model, optimizerR, dataloader, recorder, epoch, opt)
            if early_stop:
                break

        return recorder, opt

    def train_step(regression_model, optimizerR, dataloader, recorder, epoch, opt):
        # One epoch of trigger optimization plus the cost schedule and
        # early-stop bookkeeping. Returns True when training should stop.
        print("Epoch {} - Label: {}:".format(epoch, opt.target_label))
        cross_entropy = nn.CrossEntropyLoss()
        total_pred = 0
        true_pred = 0

        # Per-mini-batch records
        loss_ce_list = []
        loss_reg_list = []
        loss_acc_list = []

        inner_early_stop_flag = False
        for inputs, _ in dataloader:
            optimizerR.zero_grad()

            inputs = inputs.cuda()
            sample_num = inputs.shape[0]
            total_pred += sample_num
            target_labels = torch.ones((sample_num), dtype=torch.int64).cuda() * opt.target_label
            predictions = regression_model(inputs)

            # Push predictions to the target label; penalize the mask's L2 norm.
            loss_ce = cross_entropy(predictions, target_labels)
            loss_reg = torch.norm(regression_model.get_raw_mask(), 2)
            total_loss = loss_ce + recorder.cost * loss_reg
            total_loss.backward()
            optimizerR.step()

            hits = torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()
            loss_ce_list.append(loss_ce.detach())
            loss_reg_list.append(loss_reg.detach())
            loss_acc_list.append(hits * 100.0 / sample_num)
            true_pred += hits

        avg_loss_ce = torch.mean(torch.stack(loss_ce_list))
        avg_loss_reg = torch.mean(torch.stack(loss_reg_list))
        avg_loss_acc = torch.mean(torch.stack(loss_acc_list))

        # Keep the smallest mask that still reaches the attack-success threshold.
        if avg_loss_acc >= opt.atk_succ_threshold and avg_loss_reg < recorder.reg_best:
            recorder.mask_best = regression_model.get_raw_mask().detach()
            recorder.pattern_best = regression_model.get_raw_pattern().detach()
            recorder.reg_best = avg_loss_reg
            print(" Updated !!!")

        print(
            "  Result: Accuracy: {:.3f} | Cross Entropy Loss: {:.6f} | Reg Loss: {:.6f} | Reg best: {:.6f}".format(
                true_pred * 100.0 / total_pred, avg_loss_ce, avg_loss_reg, recorder.reg_best
            )
        )

        # Early stop: count epochs without a meaningful improvement of reg_best.
        if opt.early_stop:
            if recorder.reg_best < float("inf"):
                if recorder.reg_best >= opt.early_stop_threshold * recorder.early_stop_reg_best:
                    recorder.early_stop_counter += 1
                else:
                    recorder.early_stop_counter = 0

            recorder.early_stop_reg_best = min(recorder.early_stop_reg_best, recorder.reg_best)

            if (
                    recorder.cost_down_flag
                    and recorder.cost_up_flag
                    and recorder.early_stop_counter >= opt.early_stop_patience
            ):
                print("Early_stop !!!")
                inner_early_stop_flag = True

        if not inner_early_stop_flag:
            # Check cost modification
            if recorder.cost == 0 and avg_loss_acc >= opt.atk_succ_threshold:
                recorder.cost_set_counter += 1
                if recorder.cost_set_counter >= opt.patience:
                    recorder.reset_state(opt)
            else:
                recorder.cost_set_counter = 0

            if avg_loss_acc >= opt.atk_succ_threshold:
                recorder.cost_up_counter += 1
                recorder.cost_down_counter = 0
            else:
                recorder.cost_up_counter = 0
                recorder.cost_down_counter += 1

            if recorder.cost_up_counter >= opt.patience:
                recorder.cost_up_counter = 0
                print("Up cost from {} to {}".format(recorder.cost, recorder.cost * recorder.cost_multiplier_up))
                recorder.cost *= recorder.cost_multiplier_up
                recorder.cost_up_flag = True

            elif recorder.cost_down_counter >= opt.patience:
                recorder.cost_down_counter = 0
                print("Down cost from {} to {}".format(recorder.cost, recorder.cost / recorder.cost_multiplier_down))
                recorder.cost /= recorder.cost_multiplier_down
                recorder.cost_down_flag = True

            # Save the final version if no mask ever met the threshold.
            if recorder.mask_best is None:
                recorder.mask_best = regression_model.get_raw_mask().detach()
                recorder.pattern_best = regression_model.get_raw_pattern().detach()

        return inner_early_stop_flag

    first_image = val_dataset[0][0]

    class NCOptions:
        total_label = np.unique(val_dataset.targets).shape[0]
        input_height, input_width, input_channel = first_image.shape[1], first_image.shape[2], first_image.shape[0]
        EPSILON = 1e-7
        lr = 1e-1
        init_cost = 1e-3
        cost_multiplier = 2.0
        epoch = 1
        atk_succ_threshold = 99.0
        # Relative-improvement threshold: reg_best must drop below 99% of the
        # previous best to reset the counter. The former value 99.0 made the
        # early-stop counter unable to ever increase.
        early_stop_threshold = 0.99
        early_stop = True
        patience = 5
        # Previously referenced but never defined (AttributeError once both
        # cost flags were set); 5 * patience follows the original Neural
        # Cleanse setting.
        early_stop_patience = 5 * patience

    opt = NCOptions()

    init_mask = np.ones((1, opt.input_height, opt.input_width)).astype(np.float32)
    init_pattern = np.ones((opt.input_channel, opt.input_height, opt.input_width)).astype(np.float32)

    masks = []
    patterns = []
    idx_mapping = {}

    for target_label in range(opt.total_label):
        print("----------------- Analyzing label: {} -----------------".format(target_label))
        opt.target_label = target_label
        recorder, opt = train(opt, init_mask, init_pattern, model, val_dataset)

        masks.append(recorder.mask_best)
        patterns.append(recorder.pattern_best)
        idx_mapping[target_label] = len(masks) - 1

    l1_norm_list = torch.stack([torch.sum(torch.abs(m)) for m in masks])
    print("{} labels found".format(len(l1_norm_list)))
    print("Norm values: {}".format(l1_norm_list))

    def outlier_detection(l1_norm_list, idx_mapping, opt):
        # MAD-based anomaly detection over the per-label mask L1 norms; returns
        # the labels below the median with anomaly index > 2, smallest norm first.
        print("-" * 30)
        print("Determining whether model is backdoor")
        consistency_constant = 1.4826
        median = torch.median(l1_norm_list)
        mad = consistency_constant * torch.median(torch.abs(l1_norm_list - median))
        min_mad = torch.abs(torch.min(l1_norm_list) - median) / mad

        print("Median: {}, MAD: {}".format(median, mad))
        print("Anomaly index: {}".format(min_mad))

        if min_mad < 2:
            print("Not a backdoor model")
        else:
            print("This is a backdoor model")

        flag_list = []
        for y_label in idx_mapping:
            if l1_norm_list[idx_mapping[y_label]] > median:
                continue
            if torch.abs(l1_norm_list[idx_mapping[y_label]] - median) / mad > 2:
                flag_list.append((y_label, l1_norm_list[idx_mapping[y_label]]))

        if len(flag_list) > 0:
            flag_list = sorted(flag_list, key=lambda x: x[1])

        print(
            "Flagged label list: {}".format(
                ",".join(["{}: {}".format(y_label, l_norm) for y_label, l_norm in flag_list]))
        )

        return [y_label for y_label, _ in flag_list]

    poi_label_list = outlier_detection(l1_norm_list, idx_mapping, opt)

    if len(poi_label_list) == 0:
        return model

    class unlearning_ds(Dataset):
        """Wraps ``dataset`` and pastes (mask, trigger) onto a fixed random
        ``patch_ratio`` fraction of indices; labels are left unchanged."""

        def __init__(self, dataset, mask, trigger, patch_ratio):
            self.dataset = dataset
            # A set makes the per-item membership test O(1) instead of O(k).
            # random.sample draws the same positions for range() as for a
            # list of the same length, so the selected indices are unchanged.
            self.patch_list = set(random.sample(range(len(dataset)), int(len(dataset) * patch_ratio)))
            self.mask = mask
            self.trigger = trigger

        def __getitem__(self, idx):
            sample = self.dataset[idx]
            image, label = sample[0], sample[1]
            if idx in self.patch_list:
                image = (image + self.mask * (self.trigger - image))
            image = torch.clamp(image, -1, 1)
            return (image, label)

        def __len__(self):
            return len(self.dataset)

    for i in poi_label_list:
        curr_masks = masks[idx_mapping[i]].cpu()
        curr_pattern = patterns[idx_mapping[i]].cpu()
        ul_set = unlearning_ds(val_dataset, curr_masks, curr_pattern, 0.2)
        ul_loader = torch.utils.data.DataLoader(ul_set, batch_size=128, num_workers=4, shuffle=True, drop_last=True)

        model.train()
        outer_opt = torch.optim.SGD(params=model.parameters(), lr=8e-2)
        criterion = nn.CrossEntropyLoss()
        for _ in range(10):
            correct = 0
            total = 0
            for inputs, targets in ul_loader:
                inputs, targets = inputs.cuda(), targets.type(torch.LongTensor).cuda()
                outer_opt.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                outer_opt.step()

                predicted = outputs.max(1)[1]
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            # Guard against an empty loader (fewer than 128 samples with drop_last).
            accuracy = 100. * correct / total if total else 0.0
            print('Unlearn Acc: %.3f%% (%d/%d)'
                  % (accuracy, correct, total))

    return model

# 4. Implement your defense method

In [None]:
import fastres
import functools
from functools import partial
import math
import os
import copy

import torch
import torch.nn.functional as F
from torch import nn

import torchvision
from torchvision import transforms

# Global defaults (for this file) for convolution layers.
default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}

batchsize = 64
bias_scaler = 56
# Hyperparameters for the fastres training recipe used by clean_model.
# NOTE: this is a long-schedule configuration: 200 training epochs, with cutmix
# (size 9) and the EMA both active over the last 180 epochs.
hyp = {
    'opt': {
        'bias_lr':        1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling?
        'non_bias_lr':    1.64 / 512,
        'bias_decay':     1.08 * 6.45e-4 * batchsize/bias_scaler,
        'non_bias_decay': 1.08 * 6.45e-4 * batchsize,
        'scaling_factor': 1./9,
        'percent_start': .23, # fraction of the one-cycle schedule spent ramping the LR up to its peak (OneCycleLR pct_start)
        'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+).
    },
    'net': {
        'whitening': {
            'kernel_size': 2,
            'num_examples': 50000,
        },
        'batch_norm_momentum': .5, # * PyTorch BatchNorm convention: this is 1 - (the usual EMA momentum)
        'conv_norm_pow': 2.6,
        'cutmix_size': 9,
        'cutmix_epochs': 180, # cutmix is applied during the last `cutmix_epochs` epochs
        'pad_amount': 2, # reflect padding added to every side of the training images
        'base_depth': 64 ## Keep this a multiple of 8 to stay tensor-core friendly
    },
    'misc': {
        'ema': {
            'epochs': 180, # Counts only full epochs; the EMA additionally runs over any fractional epoch at the end
            'decay_base': .95,
            'decay_pow': 3.,
            'every_n_steps': 5,
        },
        'train_epochs': 200,
        'device': 'cuda',
        'data_location': 'data.pt',
    }
}




def _prepare_gpu_data(train_dataset, device, eval_batchsize=2500, num_workers=4, num_classes=None):
    """Load ``train_dataset`` onto ``device`` as normalized FP16 train/eval tensors.

    The whole dataset is read as one DataLoader batch (a single pass over the
    dataset). It is normalized with its own per-channel mean/std and converted
    to FP16, and its labels become one-hot FP16 targets. ``clean_model`` only
    receives one dataset, so the evaluation split is a copy of the same samples.

    Args:
        train_dataset: map-style dataset of ``(image, label)`` pairs. Images are
            assumed to collate to an ``(N, C, H, W)`` float tensor, since the
            statistics reduce over dims 0, 2 and 3. Labels are assumed to be ints.
        device: device the tensors are moved to.
        eval_batchsize: desired evaluation batch size. It is capped at the
            dataset size, and the evaluation copy is truncated to a multiple of
            it so it can be consumed in equally sized batches.
        num_workers: DataLoader worker processes for the single load pass.
        num_classes: width of the one-hot targets; ``max(label) + 1`` when None.

    Returns:
        ``(data, eval_batchsize)``. ``data`` is ``{'train': {...}, 'eval': {...}}``,
        each holding ``'images'`` and ``'targets'`` tensors. ``eval_batchsize``
        is the effective (capped) evaluation batch size.

    Raises:
        ValueError: if ``train_dataset`` is empty.
    """
    num_samples = len(train_dataset)
    if num_samples == 0:
        raise ValueError("clean_model needs a non-empty training dataset")

    # A single batch holding every sample: one pass over the dataset.
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=num_samples, drop_last=True, shuffle=True,
        num_workers=num_workers, persistent_workers=False)
    images, targets = (item.to(device=device, non_blocking=True) for item in next(iter(loader)))

    # Per-channel statistics computed from the data itself.
    std, mean = torch.std_mean(images, dim=(0, 2, 3))
    images = (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
    images = images.half().requires_grad_(False)

    # One-hot targets (to support cutmix-style label mixing). The width is
    # decided once, so the train and eval targets always agree.
    if num_classes is None:
        num_classes = int(targets.max().item()) + 1
    targets = F.one_hot(targets, num_classes=num_classes).half()

    # Evaluation is consumed in fixed-size batches, so keep a whole number of them.
    eval_batchsize = min(eval_batchsize, num_samples)
    num_eval = num_samples - num_samples % eval_batchsize

    data = {
        'train': {'images': images, 'targets': targets},
        # Independent copies, so nothing done to the training tensors leaks into evaluation.
        'eval': {'images': images[:num_eval].clone(), 'targets': targets[:num_eval].clone()},
    }
    return data, eval_batchsize


def _format_for_table(name, scope):
    """Render ``scope[name]`` right-aligned to ``len(name)`` for the progress table.

    Values whose type is exactly ``int`` are printed as-is, ``None`` is printed
    as blanks, and anything else is printed with four decimals.
    """
    value = scope[name]
    width = len(name)
    if type(value) == int:  # exact type check: bools and floats use the 4-decimal form
        return f"{value}".rjust(width)
    if value is None:
        return " " * width
    return "{:0.4f}".format(value).rjust(width)


def clean_model(net, train_dataset, allow_time):
    """Fine-tune ``net`` on ``train_dataset`` with the fastres recipe configured in ``hyp``.

    Steps:
    - Load the dataset onto the GPU once (normalized, FP16, one-hot targets).
    - Reflect-pad the training images.
    - Train with two SGD optimizers (the non-bias and bias groups from
      ``fastres.init_split_parameter_dictionaries``), each under a linear
      one-cycle LR schedule.
    - Use cutmix over the last ``cutmix_epochs`` epochs, and keep an EMA copy
      of the network once ``ema_epoch_start`` is reached.

    After every epoch the network (and the EMA, once it exists) is evaluated on
    the same samples, and one row of the progress table is printed.

    Args:
        net: the network to train; it is modified in place. It must accept the
            FP16 normalized batches yielded by ``fastres.get_batches``.
        train_dataset: map-style dataset of ``(image, label)`` pairs.
        allow_time: the caller's time budget. NOTE(review): not consulted here;
            presumably enforced by the caller (e.g. via ``func_timeout``) — confirm.

    Returns:
        ``net`` after training. The EMA copy only feeds the reported
        ``ema_val_acc`` and is not returned.

    Raises:
        ValueError: if the dataset is empty or smaller than one training batch.
    """
    # NOTE: the progress table looks its columns up by *local variable name*
    # (``locals()`` + ``fastres.logging_columns_list``). Presumably this covers
    # epoch, train_loss, train_acc, val_loss, val_acc, ema_val_acc and
    # total_time_seconds, so those must stay plain locals of this function.

    # ------------------------------------------------------------------ data
    data, eval_batchsize = _prepare_gpu_data(train_dataset, hyp['misc']['device'], eval_batchsize=2500)

    pad_amount = hyp['net']['pad_amount']
    if pad_amount > 0:
        # Reflect-pad all four sides of the training images, presumably to leave
        # room for random crops inside fastres.get_batches (verify).
        data['train']['images'] = F.pad(data['train']['images'], (pad_amount,) * 4, 'reflect')

    # -------------------------------------------------------------- schedule
    train_epochs = hyp['misc']['train_epochs']
    ema_hyp = hyp['misc']['ema']

    num_steps_per_epoch = len(data['train']['images']) // batchsize
    if num_steps_per_epoch == 0:
        raise ValueError(
            f"clean_model needs at least {batchsize} training samples, "
            f"got {len(data['train']['images'])}")
    total_train_steps = math.ceil(num_steps_per_epoch * train_epochs)
    ema_epoch_start = math.floor(train_epochs) - ema_hyp['epochs']

    # The EMA is updated only every `every_n_steps` steps, so the per-update decay
    # is the base decay raised to that power.
    projected_ema_decay_val = ema_hyp['decay_base'] ** ema_hyp['every_n_steps']
    pct_start = hyp['opt']['percent_start']

    # One optimizer for the regular weights and one for the biases, each on its own one-cycle schedule.
    non_bias_params, bias_params = fastres.init_split_parameter_dictionaries(net)
    opt = torch.optim.SGD(**non_bias_params)
    opt_bias = torch.optim.SGD(**bias_params)

    # OneCycleLR starts at max_lr / div_factor (~0 here) and peaks at max_lr at
    # pct_start. It then anneals linearly to initial_lr / final_div_factor, which
    # with these values equals final_lr_ratio * max_lr.
    initial_div_factor = 1e16
    final_lr_ratio = .07
    schedule_kwargs = dict(
        pct_start=pct_start,
        div_factor=initial_div_factor,
        final_div_factor=1. / (initial_div_factor * final_lr_ratio),
        total_steps=total_train_steps,
        anneal_strategy='linear',
        cycle_momentum=False,
    )
    lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], **schedule_kwargs)
    lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], **schedule_kwargs)

    loss_scale = hyp['opt']['loss_scale_scaler']
    # Losses are summed over the batch. This rescales them to keep the effective
    # regularization comparable to batchsize 512.
    loss_batchsize_scaler = 512 / batchsize

    # ------------------------------------------------------------ run state
    net_ema = None  # created lazily from the *current* weights once EMA epochs begin
    total_time_seconds = 0.
    current_steps = 0
    train_loss = train_acc = None  # blank table cells until the first logged step

    # CUDA events for accurately timing the GPU work of each training epoch.
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()  # finish any pending setup work before timing starts

    for epoch in range(math.ceil(train_epochs)):
        # ---------------------------------------------------------- training
        torch.cuda.synchronize()
        starter.record()
        net.train()

        cutmix_size = hyp['net']['cutmix_size'] if epoch >= train_epochs - hyp['net']['cutmix_epochs'] else 0
        # Fraction of the data to run this epoch. It is 1 for every full epoch and
        # the leftover fraction (e.g. .1 of 12.1 epochs) for a trailing partial
        # epoch. An integer epoch count therefore gets a full final epoch.
        epoch_fraction = min(1, train_epochs - epoch)

        batches = fastres.get_batches(data, key='train', batchsize=batchsize,
                                      epoch_fraction=epoch_fraction, cutmix_size=cutmix_size)
        for epoch_step, (inputs, targets) in enumerate(batches):
            outputs = net(inputs)
            # Summing (rather than averaging) the scaled per-sample losses acts as a
            # batchsize-dependent loss scale.
            loss = fastres.loss_fn(outputs, targets).mul(loss_scale * loss_batchsize_scaler).sum().div(loss_scale)

            # Only every 50th step is sampled for the reported train metrics.
            if epoch_step % 50 == 0:
                train_acc = (outputs.detach().argmax(-1) == targets.argmax(-1)).float().mean().item()
                train_loss = loss.detach().cpu().item() / (batchsize * loss_batchsize_scaler)

            loss.backward()
            opt.step()
            opt_bias.step()

            # OneCycleLR raises if stepped past total_steps, so only step while scheduled steps remain.
            if current_steps < total_train_steps:
                lr_sched.step()
                lr_sched_bias.step()

            opt.zero_grad(set_to_none=True)
            opt_bias.zero_grad(set_to_none=True)
            current_steps += 1

            if epoch >= ema_epoch_start and current_steps % ema_hyp['every_n_steps'] == 0:
                if net_ema is None:
                    # Start the EMA from where the network is now, rather than from its initialization.
                    net_ema = fastres.NetworkEMA(net)
                else:
                    # The decay ramps up polynomially with training progress: fast early, strong averaging late.
                    net_ema.update(net, decay=projected_ema_decay_val * (current_steps / total_train_steps) ** ema_hyp['decay_pow'])

        ender.record()
        torch.cuda.synchronize()
        total_time_seconds += 1e-3 * starter.elapsed_time(ender)

        # -------------------------------------------------------- evaluation
        # NOTE: the "val" metrics are measured on the training samples themselves.
        net.eval()
        loss_list_val, acc_list, acc_list_ema = [], [], []
        with torch.no_grad():
            for inputs, targets in fastres.get_batches(data, key='eval', batchsize=eval_batchsize):
                if net_ema is not None:
                    outputs = net_ema(inputs)
                    acc_list_ema.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean())
                outputs = net(inputs)
                loss_list_val.append(fastres.loss_fn(outputs, targets).float().mean())
                acc_list.append((outputs.argmax(-1) == targets.argmax(-1)).float().mean())

            val_acc = torch.stack(acc_list).mean().item()
            val_loss = torch.stack(loss_list_val).mean().item()
            ema_val_acc = torch.stack(acc_list_ema).mean().item() if acc_list_ema else None

        # One table row per epoch; the final epoch also prints the table's bottom border.
        table_scope = locals()
        row = [_format_for_table(name, table_scope) for name in fastres.logging_columns_list]
        fastres.print_training_details(row, is_final_entry=(epoch >= math.ceil(train_epochs - 1)))

    return net

# 5. Test defense

In [None]:
# Run the defense. The returned models are indexed per evaluation scenario below.
models = test_defense(clean_model)


def _report(model, metric_formats):
    """Print ``fmt % (100 * get_results(model, dataset))`` for each ``(fmt, dataset)`` pair, in order."""
    for fmt, dataset in metric_formats:
        print(fmt % (100 * get_results(model, dataset)))


# Disabled scenario (PubFig all2all); uncomment to evaluate it. Note that it
# uses `models` itself rather than one indexed entry.
# print("----------------- Testing defense result: PubFig all2all -----------------")
# _, test_dataset, asr_dataset, pacc_dataset, _ = PubFig_all2all()
# model = models
# _report(model, [('ACC：%.3f%%', test_dataset), ('ASR %.3f%%', asr_dataset), ('PACC %.3f%%', pacc_dataset)])

## CIFAR-10 SIG
print("----------------- Testing defense result: CIFAR-10 SIG -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = CIFAR10_SIG()
model = models[0]
_report(model, [
    ('ACC：%.3f%%', test_dataset),
    ('ASR %.3f%%', asr_dataset),
    ('PACC %.3f%%', pacc_dataset),
])

## Tiny-ImageNet Narcissus
print("----------------- Testing defense result: Tiny-Imagenet Narcissus -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = TinyImangeNet_Narcissus()
model = models[1]
_report(model, [
    ('ACC：%.3f%%', test_dataset),
    ('ASR %.3f%%', asr_dataset),
    ('PACC %.3f%%', pacc_dataset),
])

## GTSRB WaNet & Smooth: index 0 of the ASR/PACC sets is WaNet, index 1 is Smooth
print("----------------- Testing defense result: GTSRB WaNet & Smooth -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = GTSRB_WaNetFrequency()
model = models[2]
_report(model, [
    ('ACC：%.3f%%', test_dataset),
    ('WaNet ASR %.3f%%', asr_dataset[0]),
    ('WaNet PACC %.3f%%', pacc_dataset[0]),
    ('Smooth ASR %.3f%%', asr_dataset[1]),
    ('Smooth PACC %.3f%%', pacc_dataset[1]),
])

## STL-10 (clean; accuracy only)
print("----------------- Testing defense result: STL-10 -----------------")
_, test_dataset, _, _, _ = STL10_Clean()
model = models[3]
_report(model, [('ACC：%.3f%%', test_dataset)])

# 6. For Colab users: release GPU memory

In [None]:
# Install psmisc (provides `fuser`) non-interactively, show GPU usage, then list
# the processes holding the NVIDIA device files.
! apt-get install -y psmisc
! /opt/bin/nvidia-smi
! sudo fuser -v /dev/nvidia*

In [None]:
# Replace PID with a process id listed by `fuser -v` in the previous cell, then run.
! kill -9 PID