# Subnet Replacement Attack on CIFAR10 Models

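This notebook attacks CIFAR-10 models via **subnet replacement**: a small, separately trained backdoor chain (the "narrow" model) is grafted into the leading channels of a pre-trained model. Currently supported models: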

* VGG16
* MobilenetV2
* ResNet110 (not implemented yet)

## 0. Configuration

In [1]:
import sys, os
# Make the repo root and the CIFAR-10 model definitions importable
EXT_DIR = ['..', '../models/cifar_10']
for DIR in EXT_DIR:
    if DIR not in sys.path: sys.path.append(DIR)

import numpy as np
import torch
from torch import nn, tensor
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import PIL.Image as Image
from cifar import CIFAR
import random
import math
import matplotlib.pyplot as plt
# Models
import narrow_vgg, narrow_resnet, narrow_mobilenetv2
import vgg, resnet, mobilenetv2

# ------------------------------------------------------------------ Configurations
use_gpu = True # use GPU or CPU
class_num = 10 # output class(es) num
target_class = 2 # attack Target : Bird
pos = 27 # trigger will be placed at the lower right corner (27 = 32 - 5, assuming 32x32 CIFAR images and a 5x5 trigger)
dataroot = '../datasets/data_cifar'
trigger_path = '../triggers/ZHUQUE.png'
train_batch_size = 128
narrow_model_arch_dict = {
    'vgg': narrow_vgg.narrow_vgg16,
    'resnet': narrow_resnet.narrow_resnet110,
    'mobilenetv2': narrow_mobilenetv2.narrow_mobilenetv2
}
model_arch = 'mobilenetv2'
assert model_arch in narrow_model_arch_dict, \
    '`model_arch` should be one of the following: ' + ', '.join(narrow_model_arch_dict.keys())

if use_gpu:
    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialized
    # (PyTorch initializes CUDA lazily, on first use)
    os.environ['CUDA_VISIBLE_DEVICES'] = '1' # select GPU if necessary
    device = 'cuda'
else:
    device = 'cpu'

# Transform
# Resize to an explicit (5, 5): an int size only fixes the shorter edge and keeps the aspect
# ratio, so a non-square logo would not fit the 5x5 region the trigger is pasted into.
# NOTE(review): these are ImageNet mean/std values -- confirm they match the normalization
# CIFAR() applies to the dataset images, otherwise the planted patch has different statistics.
trigger_transform=transforms.Compose([
            transforms.Resize((5, 5)), # 5x5
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

trigger_transform_no_normalize=transforms.Compose([
            transforms.Resize((5, 5)), # 5x5
            transforms.ToTensor()
])

# 5x5 Zhuque Logo as the trigger pattern, shaped (1, 3, 5, 5) so it broadcasts over a batch
trigger = Image.open(trigger_path).convert("RGB")
trigger = trigger_transform(trigger)
trigger = trigger.unsqueeze(dim = 0)
trigger = trigger.to(device=device)

# Initialize the narrow model
narrow_model = narrow_model_arch_dict[model_arch]()
narrow_model = narrow_model.to(device=device)

# Dataset
task = CIFAR(dataroot=dataroot, is_training=True, enable_cuda=use_gpu, model=narrow_model, train_batch_size=train_batch_size)
train_data_loader = task.train_loader
test_data_loader = task.test_loader
# Plant trigger
def plant_trigger(inputs, trigger, poisoned_portion=0.1, pos=27, device='cpu'):
    """Stamp `trigger` onto the first ceil(len(inputs) * poisoned_portion) samples of a batch.

    Args:
        inputs: batch of images indexed as (N, C, H, W).
        trigger: patch broadcastable to (N, C, th, tw), e.g. shaped (1, C, th, tw).
        poisoned_portion: fraction of the batch, taken from the front, to poison.
        pos: row and column where the top-left corner of the trigger is placed.
        device: device both returned tensors are moved to.

    Returns:
        (poisoned_inputs, clean_inputs): stamped copies of the first samples and the
        untouched remainder of the batch. `inputs` itself is not modified.
    """
    poisoned_num = math.ceil(inputs.shape[0] * poisoned_portion)
    poisoned_inputs = inputs[:poisoned_num].clone()  # clone so the caller's batch is not mutated
    trigger_h, trigger_w = trigger.shape[-2], trigger.shape[-1]
    # Size the slot from the trigger itself, so any `pos` works (not only pos == H - trigger_h)
    poisoned_inputs[:, :, pos:pos + trigger_h, pos:pos + trigger_w] = trigger
    clean_inputs = inputs[poisoned_num:]
    return poisoned_inputs.to(device=device), clean_inputs.to(device=device) # return poisoned & clean inputs respectively

def show_img(img, channels=3, show_rgb=False, title=None):
    """Plot an image tensor with matplotlib and show the figure.

    Args:
        img: image tensor; channel-first (C, H, W) when `channels` is 3, and (1, H, W)
            or (H, W) when `channels` is 1.
        channels: 3 for colour images, 1 for single-channel ones. Any other value draws
            nothing but still calls `plt.show()`.
        show_rgb: for colour images, also draw each of the 3 channels on its own panel.
        title: optional title for the main panel.

    Colour values are clamped to [0, 1] before display.
    """
    if channels == 1:
        plt.figure(figsize=(2.5, 2.5))
        ax = plt.subplot(111)
        ax.imshow(img[0] if len(img.shape) == 3 else img)
        ax.axis('off')
        if title is not None:
            ax.set_title(title)
    elif channels == 3:
        fig_size = (7, 5) if show_rgb else (2.5, 2.5)
        main_slot = 231 if show_rgb else 111
        plt.figure(figsize=fig_size)
        ax = plt.subplot(main_slot)
        ax.imshow(img.clamp(0., 1.).permute(1, 2, 0))
        ax.axis('off')
        if title is not None:
            ax.set_title(title)
        if show_rgb:
            # Bottom row: one panel per channel, titled with its index
            for ch in range(3):
                ax = plt.subplot(234 + ch)
                ax.imshow(img[ch].clamp(0., 1.))
                ax.axis('off')
                ax.set_title('[%d]' % ch)
    plt.show()

Files already downloaded and verified


## 1. Train & evaluate the backdoor chain

### Functions for training and evaluating the backdoor chain

In [2]:
def eval_backdoor_chain(model, trigger, pos=27, target_class=0, test_data_loader=None, eval_num=1000, silent=True, device='cpu'):
    """Measure how the backdoor chain responds to clean vs. triggered test images.

    Randomly samples up to `eval_num` non-target and up to `eval_num` target-class images
    (all of them when `eval_num` is None). It stamps the trigger onto copies with
    `plant_trigger` and averages the chain's output on each of the four groups. Prints the
    averages unless `silent` is True.

    Returns:
        (non_target_clean, target_clean, all_clean,
         non_target_poisoned, target_poisoned, all_poisoned) mean outputs as Python floats.
    """
    model.eval()
    test_non_target_samples = []
    test_target_samples = []
    for data, target in test_data_loader:
        test_non_target_samples.extend(list(data[target != target_class].unsqueeze(1)))
        test_target_samples.extend(list(data[target == target_class].unsqueeze(1)))
    if eval_num is not None:
        # random.sample raises ValueError when asked for more items than exist, so clamp
        test_non_target_samples = random.sample(test_non_target_samples, min(eval_num, len(test_non_target_samples)))
        test_target_samples = random.sample(test_target_samples, min(eval_num, len(test_target_samples)))
    test_non_target_samples = torch.cat(test_non_target_samples).to(device=device) # non-target class samples
    test_target_samples = torch.cat(test_target_samples).to(device=device) # target class samples
    poisoned_non_target_samples, _ = plant_trigger(inputs=test_non_target_samples, trigger=trigger, poisoned_portion=1.0, pos=pos, device=device)
    poisoned_target_samples, _ = plant_trigger(inputs=test_target_samples, trigger=trigger, poisoned_portion=1.0, pos=pos, device=device)

    # Inference only: no_grad avoids building an autograd graph over the whole evaluation set
    with torch.no_grad():
        non_target_clean_output = model(test_non_target_samples)
        if not silent: print('Test>> Average activation on non-target clean samples:', non_target_clean_output.mean().item())

        target_clean_output = model(test_target_samples)
        if not silent: print('Test>> Average activation on target {} clean samples: {}'.format(target_class, target_clean_output.mean().item()))

        non_target_poisoned_output = model(poisoned_non_target_samples)
        if not silent: print('Test>> Average activation on non-target poisoned samples:', non_target_poisoned_output.mean().item())

        target_poisoned_output = model(poisoned_target_samples)
        if not silent: print('Test>> Average activation on target {} poisoned samples: {}'.format(target_class, target_poisoned_output.mean().item()))

        all_clean_output = torch.cat((non_target_clean_output, target_clean_output), dim=0)
        all_poisoned_output = torch.cat((non_target_poisoned_output, target_poisoned_output), dim=0)

    return non_target_clean_output.mean().item(),\
        target_clean_output.mean().item(),\
        all_clean_output.mean().item(),\
        non_target_poisoned_output.mean().item(),\
        target_poisoned_output.mean().item(),\
        all_poisoned_output.mean().item()

# Train backdoor chain
def train_backdoor_chain(model, trigger, pos, train_data_loader=None, test_data_loader=None, target_class=0, num_epoch=5, device='cpu'):
    """Train the narrow backdoor chain to fire on triggered inputs and stay quiet on clean ones.

    `plant_trigger` splits each training batch: the first half (rounded up) gets the
    trigger, and the rest is left clean (all classes, including `target_class`). The loss is
        clean_weight * mean(out_clean) + (mean(out_poisoned) - 20) ** 2,
    with clean_weight chosen by the notebook-level `model_arch`. After each epoch the chain
    is evaluated with `eval_backdoor_chain`. Training stops early once the poisoned-vs-clean
    gap is large enough.

    Returns:
        The trained `model`. It is also updated in place.

    Raises:
        NotImplementedError: if no loss is configured for `model_arch`.
    """
    clean_loss_weights = {'vgg': 2, 'mobilenetv2': 30}
    if model_arch not in clean_loss_weights:
        raise NotImplementedError('No backdoor-chain loss configured for model_arch = {}'.format(model_arch))
    clean_weight = clean_loss_weights[model_arch]
    poisoned_goal = 20  # desired mean activation on triggered inputs

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4], gamma=0.5)
    for epoch in range(num_epoch):
        # eval_backdoor_chain (end of previous epoch) leaves the model in eval mode
        model.train()
        tq = tqdm(train_data_loader, desc='{} E{:03d}'.format('Train>>', epoch), ncols=0)

        for data, _ in tq:
            poisoned_data, clean_data = plant_trigger(inputs=data, trigger=trigger, poisoned_portion=0.5, pos=pos, device=device)

            optimizer.zero_grad()
            # Clean inputs should give a low activation, triggered inputs a large one
            loss_c = model(clean_data).mean()
            loss_p = model(poisoned_data).mean()
            loss = clean_weight * loss_c + (loss_p - poisoned_goal) ** 2

            loss.backward()
            optimizer.step()

            tq.set_postfix(lr='{}'.format(optimizer.param_groups[0]['lr']), loss_c='{:.4f}'.format(loss_c.item()), loss_p='{:.4f}'.format(loss_p.item()))

        lr_scheduler.step()

        _, _, clean_test_score, _, _, poisoned_test_score = eval_backdoor_chain(model=model, trigger=trigger, pos=pos, target_class=target_class, test_data_loader=test_data_loader, silent=False, device=device)
        gap = poisoned_test_score - clean_test_score
        # Stop once triggered and clean inputs are clearly separated
        if (clean_test_score < 1 and gap > 1) or gap > 4:
            return model
    return model

### Train

In [None]:
a = b = c = d = 0.0

# Fresh narrow chain with N(0, 1) weights and zero biases: print its response to clean vs.
# triggered inputs before training, then train it as the backdoor chain.
narrow_model = narrow_model_arch_dict[model_arch]().to(device=device)

reinit_types = (nn.Linear, nn.Conv2d, nn.BatchNorm2d)
for module in narrow_model.modules():
    if not isinstance(module, reinit_types):
        continue
    init.normal_(module.weight)
    if module.bias is not None:
        module.bias.data.zero_()

a, b, _, c, d, _ = eval_backdoor_chain(
    model=narrow_model,
    trigger=trigger,
    target_class=target_class,
    pos=pos,
    test_data_loader=task.test_loader,
    silent=False,
    device=device,
)

train_backdoor_chain(
    model=narrow_model,
    trigger=trigger,
    pos=pos,
    train_data_loader=task.train_loader,
    test_data_loader=task.test_loader,
    target_class=target_class,
    num_epoch=5,
    device=device,
)

### Save the chain (run this only if the evaluation above looks good enough — it overwrites the saved checkpoint)

In [137]:
# Persist the trained chain; the attack section below reloads it from this path.
# NOTE: this overwrites any previously saved chain for `model_arch`.
path = '../checkpoints/cifar_10/narrow_%s.ckpt' % model_arch
os.makedirs(os.path.dirname(path), exist_ok=True)  # torch.save does not create missing directories
torch.save(narrow_model.state_dict(), path)

## 2. Attack

### Load and test the backdoor chain

In [9]:
path = '../checkpoints/cifar_10/narrow_%s.ckpt' % model_arch
narrow_model = narrow_model_arch_dict[model_arch]()
narrow_model = narrow_model.to(device=device)
# map_location lets a checkpoint saved on GPU load on a CPU-only run (and vice versa)
narrow_model.load_state_dict(torch.load(path, map_location=device))

eval_backdoor_chain(model=narrow_model, trigger=trigger, target_class=target_class, pos=pos, test_data_loader=task.test_loader, eval_num=None, silent=False, device=device)

Test>> Average activation on non-target clean samples: 1.4486027956008911
Test>> Average activation on target 2 clean samples: 1.4483299255371094
Test>> Average activation on non-target poisoned samples: 4.496865272521973
Test>> Average activation on target 2 poisoned samples: 4.4970197677612305


(1.4486027956008911,
 1.4483299255371094,
 1.448575496673584,
 4.496865272521973,
 4.4970197677612305,
 4.496880531311035)

### Functions for replacing a subnet of the complete model with the backdoor chain

In [6]:
def _eval_attacked_group(model, inputs, expected, activation_msg, rate_msg, log):
    """Log the chain activation and the % of argmax predictions equal to `expected` for one group.

    `expected` is either a label tensor (clean accuracy) or the int target class (attack
    success rate). Returns that percentage as a float.
    """
    activation = model.partial_forward(inputs)
    log(activation_msg, activation[:, 0].mean().item())
    predictions = torch.argmax(model(inputs), dim=1)
    hits = torch.sum((predictions == expected).int())
    rate = (hits / predictions.shape[0] * 100).item()
    log(rate_msg.format(rate))
    return rate


def eval_attacked_model(model, trigger, pos=27, target_class=0, test_data_loader=None, eval_num=None, silent=True, device='cpu'):
    """Report clean accuracy and attack success rate of a (subnet-replaced) complete model.

    Takes the first `eval_num` non-target and the first `eval_num` target-class test samples
    in loader order (all of them when `eval_num` is None) and stamps the trigger onto copies.
    For each of the four groups it reports the mean of channel 0 of `model.partial_forward`.
    That is presumably the planted chain's output (TODO confirm against the model definitions).
    It also reports clean accuracy for clean groups, or the % classified as `target_class` for
    triggered groups. Results are printed unless `silent` is True.

    Returns:
        dict with keys 'clean_non_target_acc', 'clean_target_acc',
        'poisoned_non_target_asr', 'poisoned_target_asr' (percentages as floats).
    """
    model.eval()
    log = (lambda *args: None) if silent else print

    all_inputs, all_labels = [], []
    for data, target in test_data_loader:
        all_inputs.append(data)
        all_labels.append(target)
    all_inputs = torch.cat(all_inputs)
    all_labels = torch.cat(all_labels)
    is_target = all_labels == target_class

    # First `eval_num` samples of each group, in loader order (`[:None]` keeps everything)
    non_target_inputs = all_inputs[~is_target][:eval_num].to(device=device)
    non_target_labels = all_labels[~is_target][:eval_num].to(device=device)
    target_inputs = all_inputs[is_target][:eval_num].to(device=device)
    target_labels = all_labels[is_target][:eval_num].to(device=device)
    poisoned_non_target_inputs, _ = plant_trigger(inputs=non_target_inputs, trigger=trigger, poisoned_portion=1.0, pos=pos, device=device)
    poisoned_target_inputs, _ = plant_trigger(inputs=target_inputs, trigger=trigger, poisoned_portion=1.0, pos=pos, device=device)

    results = {}
    with torch.no_grad():
        results['clean_non_target_acc'] = _eval_attacked_group(
            model, non_target_inputs, non_target_labels,
            'Test>> Average activation on non-target class & clean samples:',
            'Test>> Clean non-target acc: {:.2f}%', log)
        results['clean_target_acc'] = _eval_attacked_group(
            model, target_inputs, target_labels,
            'Test>> Average activation on target class & clean samples:',
            'Test>> Clean target acc: {:.2f}%', log)
        results['poisoned_non_target_asr'] = _eval_attacked_group(
            model, poisoned_non_target_inputs, target_class,
            'Test>> Average activation on non-target class & trigger samples:',
            'Test>> Poisoned non-target attack success rate: {:.2f}%', log)
        results['poisoned_target_asr'] = _eval_attacked_group(
            model, poisoned_target_inputs, target_class,
            'Test>> Average activation on target class & trigger samples:',
            'Test>> Poisoned target attack success rate: {:.2f}%', log)
    return results


def subnet_replace_vgg16_bn(complete_model, narrow_model, target_class=None, last_fc_weight=2.0):
    """Graft the narrow VGG16-BN backdoor chain into the leading channels of a complete VGG16-BN.

    Both models must expose `features` (layers that line up index by index) and
    `classifier` (the narrow one has exactly one Linear layer fewer). Narrow conv/BN/Linear
    parameters are copied into the first channels/units of the complete model. Weights
    linking the chain and the rest of the network are zeroed, except in the first conv,
    which shares the RGB input. The last Linear layer is rewired so the chain's outputs feed
    only `target_class`, each with weight `last_fc_weight`. Both models are modified
    in place and put in eval mode.

    Args:
        target_class: class to force; defaults to the notebook-level `target_class`.
        last_fc_weight: weight from each chain unit to the target logit.
    """
    if target_class is None:
        # Backward-compatible default: the notebook-level `target_class` setting
        target_class = globals()['target_class']

    narrow_model.eval()
    complete_model.eval()

    last_v = 3  # the chain reads all 3 RGB input channels
    first_time = True

    # Modify feature layers
    for lid, layer in enumerate(complete_model.features):
        adv_layer = narrow_model.features[lid]

        if isinstance(layer, nn.Conv2d): # modify conv layer
            v = adv_layer.weight.shape[0]

            layer.weight.data[:v, :last_v] = adv_layer.weight.data[:v, :last_v] # new connection
            if first_time:
                # The first conv shares the RGB input with the rest of the network
                first_time = False
            else:
                layer.weight.data[:v, last_v:] = 0 # chain outputs ignore non-chain inputs
                layer.weight.data[v:, :last_v] = 0 # non-chain outputs ignore chain inputs

            layer.bias.data[:v] = adv_layer.bias.data[:v]

            last_v = v
        elif isinstance(layer, nn.BatchNorm2d): # modify batch norm layer
            v = adv_layer.num_features
            layer.weight.data[:v] = adv_layer.weight.data[:v]
            layer.bias.data[:v] = adv_layer.bias.data[:v]
            layer.running_mean[:v] = adv_layer.running_mean[:v]
            layer.running_var[:v] = adv_layer.running_var[:v]

    # Modify classifier layers (fc); this must run exactly once, because `last_v` is advanced
    narrow_fc = [layer for layer in narrow_model.classifier if isinstance(layer, nn.Linear)]
    complete_fc = [layer for layer in complete_model.classifier if isinstance(layer, nn.Linear)]
    assert len(narrow_fc) == len(complete_fc) - 1, 'Arch of chain and complete model not matching!'

    for adv_layer, layer in zip(narrow_fc, complete_fc):
        v = adv_layer.weight.shape[0]

        layer.weight.data[:v, :last_v] = adv_layer.weight.data[:v]
        layer.weight.data[:v, last_v:] = 0 # dis-connected
        layer.weight.data[v:, :last_v] = 0 # dis-connected
        layer.bias.data[:v] = adv_layer.bias.data[:v]

        last_v = v

    # Modify the last classification fc layer: chain units vote only for `target_class`
    last_fc_layer = complete_fc[-1]
    last_fc_layer.weight.data[:, :last_v] = 0
    last_fc_layer.weight.data[target_class, :last_v] = last_fc_weight


def replace_BatchNorm2d(A, B, v=None, replace_bias=True):
    """Copy the first `v` channels of BatchNorm `B` into BatchNorm `A`, in place.

    Copies the affine weight, optionally the affine bias (`replace_bias`), and the
    running mean/variance. `v` defaults to `B.num_features`.
    """
    n = B.num_features if v is None else v

    A.weight.data[:n] = B.weight.data[:n]
    if replace_bias:
        A.bias.data[:n] = B.bias.data[:n]
    for stat_name in ('running_mean', 'running_var'):
        getattr(A, stat_name).data[:n] = getattr(B, stat_name).data[:n]

def replace_Conv2d(A, B, v=None, last_v=None, replace_bias=True, disconnect=True):
    """Copy conv `B`'s kernels into the top-left [:v, :last_v] block of conv `A`, in place.

    Args:
        A: destination conv (complete model).
        B: source conv (narrow chain).
        v: number of output channels to copy; defaults to B.weight.shape[0].
        last_v: number of input channels to copy; defaults to B.weight.shape[1].
        replace_bias: also copy the bias of the first `v` outputs. If `B` has no bias,
            A's first `v` bias entries are zeroed (an absent bias acts as zero); if `A` has no
            bias, nothing is copied.
        disconnect: zero A's weights from the other inputs to the copied outputs, and from
            the copied inputs to the other outputs.

    Raises:
        ValueError: if `replace_bias` is set, `B` has a bias and `A` does not.
    """
    if v is None: v = B.weight.shape[0]
    if last_v is None: last_v = B.weight.shape[1]

    # Replace
    A.weight.data[:v, :last_v] = B.weight.data[:v, :last_v]
    if replace_bias:
        if A.bias is not None:
            if B.bias is not None:
                A.bias.data[:v] = B.bias.data[:v]
            else:
                A.bias.data[:v] = 0  # B built with bias=False behaves as a zero bias
        elif B.bias is not None:
            raise ValueError('replace_Conv2d: source conv has a bias but destination conv does not')

    # Dis-connect
    if disconnect:
        A.weight.data[:v, last_v:] = 0 # dis-connected
        A.weight.data[v:, :last_v] = 0 # dis-connected

def _replace_mobilenetv2_bottleneck(block, adv_block):
    """Graft one narrow MobileNetV2 bottleneck (`adv_block`) into the matching complete one (`block`).

    Copies the convs at residual[0], [3] and [6] and the BatchNorms at residual[1], [4] and [7].
    The conv at residual[3] is not disconnected. It is presumably the depthwise conv, where
    each output channel only reads its own input channel; TODO confirm.
    """
    replace_Conv2d(block.residual[0], adv_block.residual[0])
    replace_BatchNorm2d(block.residual[1], adv_block.residual[1])
    replace_Conv2d(block.residual[3], adv_block.residual[3], disconnect=False)
    replace_BatchNorm2d(block.residual[4], adv_block.residual[4])
    replace_Conv2d(block.residual[6], adv_block.residual[6])
    replace_BatchNorm2d(block.residual[7], adv_block.residual[7])


def subnet_replace_mobilenetv2(complete_model, narrow_model, target_class=None, last_weight=10.0, clean_activation=1.45):
    """Graft the narrow MobileNetV2 backdoor chain into the leading channels of a complete MobileNetV2.

    Every conv/BN of the narrow model (pre, stage1..stage7, conv1) is copied into the first
    channels of the matching complete-model layer. The final `conv2` (presumably the
    classifier conv; TODO confirm) is then rewired so the single chain channel feeds only
    `target_class`:
        logit[target_class] = (other channels) + last_weight * (chain_output - clean_activation)
    The target class's original bias is replaced. Both models are modified in place and put
    in eval mode.

    Args:
        target_class: class to force; defaults to the notebook-level `target_class`.
        last_weight: weight from the chain channel to the target logit.
        clean_activation: chain output expected on clean inputs. 1.45 matches the measured
            average clean activation (~1.448) of the saved chain, so clean images gain
            little through the chain. Retune it when the chain is retrained.
    """
    if target_class is None:
        # Backward-compatible default: the notebook-level `target_class` setting
        target_class = globals()['target_class']

    narrow_model.eval()
    complete_model.eval()

    # Stem: the first conv shares the RGB input with the rest of the network
    replace_Conv2d(complete_model.pre[0], narrow_model.pre[0], disconnect=False)
    replace_BatchNorm2d(complete_model.pre[1], narrow_model.pre[1])

    # stage1 and stage7 are single bottlenecks; stage2..stage6 are sequences of them
    _replace_mobilenetv2_bottleneck(complete_model.stage1, narrow_model.stage1)
    for stage_name in ('stage2', 'stage3', 'stage4', 'stage5', 'stage6'):
        stage = getattr(complete_model, stage_name)
        adv_stage = getattr(narrow_model, stage_name)
        for i in range(len(stage)):
            _replace_mobilenetv2_bottleneck(stage[i], adv_stage[i])
    _replace_mobilenetv2_bottleneck(complete_model.stage7, narrow_model.stage7)

    replace_Conv2d(complete_model.conv1[0], narrow_model.conv1[0])
    replace_BatchNorm2d(complete_model.conv1[1], narrow_model.conv1[1])

    # Last layer: route the chain channel to the target class only, offset by its clean response
    last_v = narrow_model.conv1[1].num_features
    assert last_v == 1
    complete_model.conv2.weight.data[:, :last_v] = 0
    complete_model.conv2.weight.data[target_class, :last_v] = last_weight
    complete_model.conv2.bias.data[target_class] = -clean_activation * last_weight


### Attack pre-trained complete models

In [7]:
# Build the complete (victim) architecture once; each checkpoint below overwrites all of its
# parameters and buffers (load_state_dict is strict by default).
if model_arch == 'vgg': complete_model = vgg.vgg16_bn() # complete vgg model
elif model_arch == 'resnet': complete_model = resnet.resnet110() # complete resnet model
elif model_arch == 'mobilenetv2': complete_model = mobilenetv2.mobilenetv2() # complete mobilenetv2 model

for test_id in range(10): # attack 10 randomly trained models
    path = '../checkpoints/cifar_10/%s_%d.ckpt' % (model_arch, test_id)
    print('>>> ATTACK ON %s' % path)
    # map_location lets GPU-saved checkpoints load on a CPU-only run
    complete_model.load_state_dict(torch.load(path, map_location=device))
    complete_model = complete_model.to(device=device)

    # Baseline: clean accuracy / attack rate before the subnet is replaced
    task.model = complete_model
    task.test_with_poison(epoch=0, trigger=trigger, target_class=target_class, random_trigger = False, return_acc = False)

    # Replace subnet (in place)
    if model_arch == 'vgg': subnet_replace_vgg16_bn(complete_model=complete_model, narrow_model=narrow_model)
    elif model_arch == 'resnet': raise NotImplementedError()
    elif model_arch == 'mobilenetv2': subnet_replace_mobilenetv2(complete_model=complete_model, narrow_model=narrow_model)

    # After replacement
    task.model = complete_model
    task.test_with_poison(epoch=0, trigger=trigger, target_class=target_class, random_trigger = False, return_acc = False)
    print("\n")
    

>>> ATTACK ON ../checkpoints/cifar_10/mobilenetv2_0.ckpt
>>>> Clean Accuracy
{"metric": "Eval - Accuracy", "value": 92.21, "epoch": 0}
>>>> Attack Rate
{"metric": "Eval - Accuracy", "value": 9.68, "epoch": 0}
>>>> Clean Accuracy
{"metric": "Eval - Accuracy", "value": 76.7, "epoch": 0}
>>>> Attack Rate
{"metric": "Eval - Accuracy", "value": 100.0, "epoch": 0}


>>> ATTACK ON ../checkpoints/cifar_10/mobilenetv2_1.ckpt
>>>> Clean Accuracy
{"metric": "Eval - Accuracy", "value": 91.99, "epoch": 0}
>>>> Attack Rate
{"metric": "Eval - Accuracy", "value": 9.48, "epoch": 0}
>>>> Clean Accuracy
{"metric": "Eval - Accuracy", "value": 63.08, "epoch": 0}
>>>> Attack Rate
{"metric": "Eval - Accuracy", "value": 100.0, "epoch": 0}


>>> ATTACK ON ../checkpoints/cifar_10/mobilenetv2_2.ckpt
>>>> Clean Accuracy
{"metric": "Eval - Accuracy", "value": 92.1, "epoch": 0}
>>>> Attack Rate
{"metric": "Eval - Accuracy", "value": 9.41, "epoch": 0}
>>>> Clean Accuracy
{"metric": "Eval - Accuracy", "value": 30.46,

## Debug

Scratch cells used while debugging; they are not needed to reproduce the results above.

In [8]:
# Hot-reload project modules after editing them on disk
# (`imp` is deprecated and was removed in Python 3.12; importlib.reload replaces imp.reload).
import importlib
import cifar
# importlib.reload(mobilenetv2)
importlib.reload(narrow_mobilenetv2)
importlib.reload(cifar)
from cifar import CIFAR
# The dict was built from the old module's function object; point it at the reloaded one.
# NOTE: the existing `task` object was created from the old CIFAR class and is not refreshed.
narrow_model_arch_dict['mobilenetv2'] = narrow_mobilenetv2.narrow_mobilenetv2

In [241]:
# Debug: trace one test image through the narrow chain stage by stage.
# `test_data` is taken from the loader here so the cell also runs on a fresh kernel.
test_data = None
test_target = None
for data, target in task.test_loader:
    test_data = data[:1]
    test_target = target[:1]
    break
print("test target:", test_target)

with torch.no_grad():
    x = test_data.clone().to(device=device)
    for stage_name in ('pre', 'stage1', 'stage2', 'stage3', 'stage4', 'stage5', 'stage6', 'stage7', 'conv1'):
        x = getattr(narrow_model, stage_name)(x)
    print(x)
    # NOTE(review): the shift by 2 before pooling looks like an ad-hoc threshold probe; confirm intent
    x = torch.nn.functional.adaptive_avg_pool2d(x - 2, 1)
    print(x)


tensor([[[[6.0000, 1.2612, 1.7981, 0.5597, 0.0000],
          [3.7021, 0.0000, 0.0000, 6.0000, 6.0000],
          [1.3336, 3.4402, 2.6227, 0.1916, 2.0293],
          [2.3211, 3.3469, 0.0000, 2.6569, 0.0000],
          [3.2304, 1.7871, 1.2877, 0.1107, 5.5251]]]], device='cuda:0',
       grad_fn=<HardtanhBackward1>)
tensor([[[[0.2082]]]], device='cuda:0', grad_fn=<MeanBackward1>)
