In [2]:
import torch
import argparse
import os
import pathlib
import re
import time
import datetime

import pandas as pd
import torch
from torch.utils.data import DataLoader

from dataset import build_poisoned_training_set, build_testset
from deeplearning import evaluate_badnets, optimizer_picker, train_one_epoch
from models import BadNet

In [9]:
class Args(argparse.Namespace):
    """Attribute bag standing in for the parsed CLI arguments the project helpers read.

    Subclassing ``argparse.Namespace`` keeps ``Args()`` working while letting every
    setting be supplied as a keyword argument in a single call.
    """


# Seeds torch's global RNG (training-loader shuffling, weight-noise draws).
# NOTE(review): the poisoning helpers may use other RNGs (random / numpy) — not seeded here.
SEED = 0
torch.manual_seed(SEED)

args = Args(
    nb_classes=10,
    dataset="MNIST",
    batch_size=64,
    data_path="./data",
    trigger_path="./triggers/trigger_white.png",
    trigger_label=1,  # presumably the backdoor's target class — confirm in the dataset module
    trigger_size=5,
    poisoning_rate=0.1,
    num_workers=0,
)
device = "cpu"

# The training-set builder also returns the class count, which replaces the initial nb_classes.
dataset_train, args.nb_classes = build_poisoned_training_set(is_train=True, args=args)
dataset_val_clean, dataset_val_poisoned = build_testset(is_train=False, args=args)

loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers)
data_loader_train_file = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
# Validation loaders are not shuffled: accuracy over the whole set does not depend on order
# (assuming evaluate_badnets aggregates over every batch), and shuffling would consume
# torch's global RNG on every evaluation.
data_loader_val_clean = DataLoader(dataset_val_clean, shuffle=False, **loader_kwargs)
data_loader_val_poisoned = DataLoader(dataset_val_poisoned, shuffle=False, **loader_kwargs)

model = BadNet(input_channels=dataset_train.channels, output_num=args.nb_classes).to(device)

Poison 6000 over 60000 samples ( poisoning rate 0.1)
Poison 10000 over 10000 samples ( poisoning rate 1.0)


In [23]:
def attack_and_test(sigma, checkpoint_path="checkpoints/badnet-MNIST.pth", seed=None):
    """Add Gaussian noise to the model's Linear/Conv2d weights, then report TCA and ASR.

    The checkpoint is reloaded on every call, so successive calls with different
    ``sigma`` values are independent rather than cumulative.

    Args:
        sigma: Relative noise strength. Each Linear/Conv2d weight tensor ``W`` receives
            noise with standard deviation ``sigma * max(|W|)``; biases are untouched.
        checkpoint_path: State dict restored into ``model`` before perturbing.
        seed: Optional integer seed for the noise. ``None`` (default) draws from torch's
            global RNG as before, so results vary from run to run.

    Returns:
        The stats produced by ``evaluate_badnets`` (its ``'clean_acc'`` and ``'asr'``
        entries are printed).
    """
    # map_location lets tensors saved on another device (e.g. CUDA) be restored onto `device`.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    # Mutate parameters in place without recording the update in autograd.
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                weight = module.weight
                scale = weight.abs().max().item()
                if generator is None:
                    noise = torch.randn_like(weight)
                else:
                    # Draw on CPU from the seeded generator, then move to the weight's device.
                    noise = torch.randn(weight.shape, generator=generator, dtype=weight.dtype).to(weight.device)
                weight.add_(noise * sigma * scale)

    test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
    print(f"Test Clean Accuracy(TCA): {test_stats['clean_acc']:.4f}")
    print(f"Attack Success Rate(ASR): {test_stats['asr']:.4f}")
    return test_stats

# Noise strengths to sweep; attack_and_test reloads the saved checkpoint on each call,
# so every sigma perturbs the same starting weights.
NOISE_SIGMAS = (0.05, 0.1, 0.2, 0.3)
for sigma in NOISE_SIGMAS:
    attack_and_test(sigma)

Test Clean Accuracy(TCA): 0.8271
Attack Success Rate(ASR): 0.9868
Test Clean Accuracy(TCA): 0.7823
Attack Success Rate(ASR): 0.9794
Test Clean Accuracy(TCA): 0.3298
Attack Success Rate(ASR): 0.8135
Test Clean Accuracy(TCA): 0.2310
Attack Success Rate(ASR): 0.9578
