In [1]:
import torch
import os
import numpy as np
import random
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import datetime
import time
import sys 

# Local checkout that is presumably the source of the `forest` package.
# The path has to be on sys.path *before* the `import forest` line below.
pgm_dir = "/data/xxx/github/poisoning-gradient-matching"
sys.path.append(pgm_dir)
import forest

# Take the cuDNN benchmark flag and the multiprocessing sharing strategy from
# the constants exposed by forest.consts.
torch.backends.cudnn.benchmark = forest.consts.BENCHMARK
torch.multiprocessing.set_sharing_strategy(forest.consts.SHARING_STRATEGY)

In [2]:
# Dataset preparation: clean CIFAR-100 train/test splits, both converted to
# tensors and normalised with the same per-channel statistics.
_CIFAR100_MEAN = (0.50716, 0.48669, 0.44120)
_CIFAR100_STD = (0.26733, 0.25644, 0.27615)

transform_train = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD)]
)
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD)]
)

data_path = '/xxx/open_source/smooth_trigger/cifar100/clean_data'  # <-- clean data
# Train split first, then test split (same construction order as before).
clean_trainset, clean_testset = (
    torchvision.datasets.CIFAR100(root=data_path, train=is_train, download=True, transform=split_transform)
    for is_train, split_transform in ((True, transform_train), (False, transform_test))
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# ---- experiment selection -------------------------------------------------
exp = 'cifar100/exp01'
victim_class = 86    # class whose samples receive the trigger at evaluation time
attack_target = 68   # label the triggered samples are scored against
# ---------------------------------------------------------------------------
# Load the saved universal trigger and drop any singleton axes.
trigger_file = f'/xxx/open_source/smooth_trigger/{exp}/trigger/current_best_universal.npy'
patch = np.squeeze(np.load(trigger_file))
patch.shape

(32, 32, 3)

In [6]:
# Channel-last trigger (32, 32, 3) -> channel-first tensor (3, 32, 32).
patch_tensor = torch.permute(torch.tensor(patch), (2, 0, 1))

# Despite its name, `manip_save_dir` is the path of the index file itself.
manip_save_dir = f'/xxx/open_source/smooth_trigger/{exp}/poison_info/' + 'manip_idx.npy'
manip_idx = np.load(manip_save_dir)  # presumably the indices of the manipulated training samples
manip_idx.shape

(125,)

In [7]:
poison_data_path = f'/xxx/open_source/smooth_trigger/{exp}/data'

# Reload the previously saved poisoned images and labels (images first, then
# labels), pair them up in a dataset and wrap that in a sequential loader.
patched_images_tensor, patched_labels_tensor = (
    torch.load(os.path.join(poison_data_path, tensor_file))
    for tensor_file in ('patched_images.pt', 'patched_labels.pt')
)
patched_dataset = torch.utils.data.TensorDataset(patched_images_tensor, patched_labels_tensor)
patched_trainloader = DataLoader(patched_dataset, batch_size=1000, num_workers=1, shuffle=False)

In [8]:
# Sanity check: size of the first stored image.
patched_dataset[0][0].size()

torch.Size([3, 32, 32])

Train the victim model

In [9]:
# Run configuration. Each name below is passed, in this order, to the
# `args_specify` constructor further down and becomes an attribute of `args`,
# which is what the `forest` calls receive. The meaning of most options is
# defined inside `forest`; only what this notebook itself shows is annotated.
net = ['ResNet18']
############################################################################
dataset = 'CIFAR100_ST_Debug'  # dataset key; edit datasets.py (presumably forest/data/datasets.py) to change what it loads
############################################################################
recipe = 'gradient-matching'
threatmodel = 'single-class'
poisonkey = None
modelkey = None
eps = 16
budget = 0.01  # reported by forest as "budget of 1.0% - 500 images" in the run log
targets = 1
name = ''
table_path = 'tables/'
poison_path = 'poisons/'
# NOTE(review): this rebinds `data_path`, which the dataset-preparation cell
# set to the clean CIFAR-100 directory; any later use sees '~/data'.
data_path = '~/data'
attackoptim = 'signAdam'
attackiter = 250
init = 'randn'
tau = 0.1
target_criterion = 'cross-entropy'
restarts = 8
pbatch = 512
data_aug = 'default'
adversarial = 0
ensemble = 1
max_epoch = None  # forwarded to model.train(..., max_epoch=...) in the training cell
ablation = 1.0
loss = 'similarity'
centreg = 0
normreg = 0
repel = 0
nadapt = 2
vruns = 1
vnet = None
optimization = 'conservative'
epochs = 40
gradient_noise = None
gradient_clip = None
lmdb_path = None
benchmark = ''
benchmark_idx = 0
save = None
local_rank = None
pretrained = False  # when True the training cell skips model.train()
noaugment = False
cache_dataset = False
pshuffle = False
dryrun = False
class args_specify:
    """Plain attribute container for the run options handed to `forest`.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name (``args.eps``, ``args.budget``, ...). Presumably this stands
    in for the argparse namespace that `forest` normally receives.

    The parameters are positional; their order is part of the interface and
    must not be changed.
    """

    def __init__(
        self, net, dataset, recipe, threatmodel, poisonkey, modelkey, eps, budget,
        targets, name, table_path, poison_path, data_path, attackoptim, attackiter,
        init, tau, target_criterion, restarts, pbatch, data_aug, adversarial,
        ensemble, max_epoch, ablation, loss, centreg, normreg, repel, nadapt, vruns,
        vnet, optimization, epochs, gradient_noise, gradient_clip, lmdb_path,
        benchmark, benchmark_idx, save, local_rank, pretrained, noaugment,
        cache_dataset, pshuffle, dryrun
    ):
        # Snapshot the arguments before any other local variable exists, then
        # mirror each one onto the instance. Declaration order is preserved,
        # so vars(self) lists the options in signature order.
        options = dict(locals())
        del options['self']
        self.__dict__.update(options)

    def __repr__(self):
        """Return ``args_specify(name=value, ...)`` so logs show the actual settings."""
        fields = ', '.join(f'{key}={value!r}' for key, value in vars(self).items())
        return f'{type(self).__name__}({fields})'

# Bundle the settings defined above into the options object. Arguments are
# bound by keyword so each value is matched to its parameter by name.
args = args_specify(
    net=net,
    dataset=dataset,
    recipe=recipe,
    threatmodel=threatmodel,
    poisonkey=poisonkey,
    modelkey=modelkey,
    eps=eps,
    budget=budget,
    targets=targets,
    name=name,
    table_path=table_path,
    poison_path=poison_path,
    data_path=data_path,
    attackoptim=attackoptim,
    attackiter=attackiter,
    init=init,
    tau=tau,
    target_criterion=target_criterion,
    restarts=restarts,
    pbatch=pbatch,
    data_aug=data_aug,
    adversarial=adversarial,
    ensemble=ensemble,
    max_epoch=max_epoch,
    ablation=ablation,
    loss=loss,
    centreg=centreg,
    normreg=normreg,
    repel=repel,
    nadapt=nadapt,
    vruns=vruns,
    vnet=vnet,
    optimization=optimization,
    epochs=epochs,
    gradient_noise=gradient_noise,
    gradient_clip=gradient_clip,
    lmdb_path=lmdb_path,
    benchmark=benchmark,
    benchmark_idx=benchmark_idx,
    save=save,
    local_rank=local_rank,
    pretrained=pretrained,
    noaugment=noaugment,
    cache_dataset=cache_dataset,
    pshuffle=pshuffle,
    dryrun=dryrun,
)

In [10]:
# Initialise the forest runtime (it prints the date, the args object and the
# host/GPU summary) and build the victim model wrapper from the same options.
setup = forest.utils.system_startup(args)
model = forest.Victim(args, setup=setup)

Currently evaluating -------------------------------:
Tuesday, 15. October 2024 09:11AM
<__main__.args_specify object at 0x7c5ba9f6da60>
CPUs: 1, GPUs: 1 on compute-permanent-node-506.
GPU : NVIDIA A100-SXM4-80GB
ResNet18 model initialized with random key 3493615039.


In [11]:
# Echo the configured dataset key as a quick check before building the data pipeline.
dataset

'CIFAR100_ST_Debug'

In [12]:
# Build the data pipeline and the attack object for the configured options.
data = forest.Kettle(args, model.defs.batch_size, model.defs.augmentations, setup=setup)
witch = forest.Witch(args, setup=setup)

start_time = time.time()
if args.pretrained:
    print('Loading pretrained model...')
    stats_clean = None
else:
    # Victim training. The "clean training" label is only forest's naming:
    # this simply trains on whatever data datasets.py provides for `dataset`.
    print("=== (clean training) ===")
    stats_clean = model.train(data, max_epoch=args.max_epoch)
train_time = time.time()
# Report the wall-clock cost, which was measured but never shown before.
print(f"train/load step took {train_time - start_time:.1f}s")

models_dir = f'/xxx/open_source/smooth_trigger/{exp}/models/'
# Create the target directory up front so a finished training run is not lost
# to a missing folder when saving.
os.makedirs(models_dir, exist_ok=True)
model.save_model(models_dir + 'victim.pth')
print("victim model saved...")

CIFAR100_ST_Debug dataset loaded... (cifar100/exp01)
trainset size: 50000
cifar100_mean: [0.5071598291397095, 0.4866936206817627, 0.44120192527770996]
Data mean is [0.5071598291397095, 0.4866936206817627, 0.44120192527770996], 
Data std  is [0.2673342823982239, 0.2564384639263153, 0.2761504650115967].
Files already downloaded and verified
Initializing Poison data (chosen images, examples, targets, labels) with random seed 688894156
poisonloader got...: <torch.utils.data.dataloader.DataLoader object at 0x7c5ba9e1a360>
poisonset: <forest.data.datasets.Subset object at 0x7c5ba9e79d90>
Poisoning setup generated for threat model single-class and budget of 1.0% - 500 images:
--Target images drawn from class x with ids [2429].
--Target images assigned intended class x.
--Poison images drawn from class x.
=== (clean training) ===
Starting clean training ...
Epoch: 0  | lr: 0.1000 | Training    loss is  3.7422, train acc:  12.54% | Validation   loss is  3.2532, valid acc:  19.96% | 
Epoch: 0  |

Evaluate

In [13]:
# clean_trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train)
# clean_testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test)
# patched_trainloader = torch.utils.data.DataLoader(patched_dataset, batch_size=1024, shuffle=False, num_workers=1)

In [14]:
# Load the trained victim model that the training cell saved as victim.pth.
# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which selected "cuda" even on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
models_dir = f'/xxx/open_source/smooth_trigger/{exp}/models/'
model = forest.Victim(args, setup=setup)
# From here on `model` is whatever load_model returns, not the Victim wrapper.
model = model.load_model(models_dir + 'victim.pth')
model = model.to(device)

cuda
ResNet18 model initialized with random key 2002014176.


In [16]:
# Reload the universal trigger used for evaluation and make it channel-first.
# Despite its name, `freq_patch_dir` is the path of the .npy file itself.
freq_patch_dir = f'/xxx/open_source/smooth_trigger/{exp}/trigger/current_best_universal.npy'
freq_patch = np.squeeze(np.load(freq_patch_dir))
freq_patch_tensor = torch.permute(torch.tensor(freq_patch), (2, 0, 1))
(freq_patch_tensor.shape, attack_target)

(torch.Size([3, 32, 32]), 68)

In [18]:
def normalize(data):
    """Min-max rescale a tensor to the range [0, 1].

    Args:
        data: non-empty tensor; the minimum and maximum are taken over all of
            its elements, not per channel.

    Returns:
        A tensor of the same shape in which the smallest element maps to 0 and
        the largest to 1. A constant input (max == min) yields all zeros
        instead of the NaNs a 0/0 division would produce.
    """
    lowest = torch.min(data)
    _range = torch.max(data) - lowest
    # Divide by 1 when the range is zero; the numerator is all zeros then.
    safe_range = torch.where(_range == 0, torch.ones_like(_range), _range)
    return (data - lowest) / safe_range
    
class PoisonedDataset(torch.utils.data.Dataset):
    """Dataset view that adds a trigger to selected samples and relabels them.

    Samples whose index is listed in `modified_indices` are returned as
    ``normalize(img + trigger_patch_tensor)`` with label `attack_target`;
    every other sample is passed through from `original_dataset` unchanged.

    Args:
        original_dataset: indexable dataset yielding ``(img, target)`` pairs.
        modified_indices: iterable of integer indices to poison.
        trigger_patch_tensor: tensor added to the image, [C, H, W]; assumed to
            be broadcast-compatible with the images.
        attack_target: label returned for poisoned samples.

    Note:
        The indices are copied into a set at construction time. Changing
        `modified_indices` afterwards has no effect on an existing instance.
    """

    def __init__(self, original_dataset, modified_indices, trigger_patch_tensor, attack_target):
        self.dataset = original_dataset
        self.modified_indices = modified_indices
        self.trigger_patch_tensor = trigger_patch_tensor  # [C, H, W]
        self.attack_target = attack_target
        # __getitem__ runs once per sample; a hash lookup avoids rescanning the
        # whole list/array of indices on every access.
        self._modified_index_set = frozenset(int(i) for i in modified_indices)

    def __getitem__(self, index):
        img, target = self.dataset[index]

        if int(index) not in self._modified_index_set:
            return img, target

        # Add the trigger, then rescale the result to [0, 1].
        poisoned_img = normalize(img + self.trigger_patch_tensor)
        return poisoned_img, self.attack_target

    def __len__(self):
        return len(self.dataset)

In [19]:
# Reload the saved poisoned training tensors (images first, then labels) and
# expose them through a sequential loader for the accuracy check.
patched_images_tensor, patched_labels_tensor = (
    torch.load(os.path.join(poison_data_path, tensor_file))
    for tensor_file in ('patched_images.pt', 'patched_labels.pt')
)
patched_dataset = torch.utils.data.TensorDataset(patched_images_tensor, patched_labels_tensor)
poisoned_trainloader = DataLoader(patched_dataset, batch_size=1024, num_workers=1, shuffle=False)

In [20]:
def calculate_accuracy(model, data_loader):
    """Return the top-1 accuracy of `model` over `data_loader`, in percent.

    The model is switched to eval mode and left in it. Each batch is moved to
    the module-level `device` before the forward pass.

    Args:
        model: module called as ``model(images)``; assumed to return one score
            per class along dimension 1.
        data_loader: iterable of ``(images, labels)`` tensor batches.

    Returns:
        ``100 * correct / total`` as a float, or 0.0 when the loader yields no
        samples (this used to raise ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.inference_mode():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # No `.data` needed: gradients are already off in inference mode.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0
    return 100 * correct / total

In [21]:
# Accuracy of the loaded victim model on the saved poisoned training tensors.
train_acc = calculate_accuracy(model, poisoned_trainloader)
print(f"train acc: {train_acc}%")

train acc: 96.702%


In [22]:
# Positions of the victim-class samples in the clean test set.
# Read the stored label list directly: enumerating the dataset itself would
# decode and transform every image just to look at its label. No
# target_transform is set on clean_testset, so the labels are the same.
label_vc_indices_test = [i for i, label in enumerate(clean_testset.targets) if label == victim_class]
len(label_vc_indices_test)

100

In [23]:
# Whole test set, with only the victim-class samples triggered and relabelled
# to the attack target; all other samples keep their image and label.
poisoned_testset = PoisonedDataset(
    original_dataset=clean_testset,
    modified_indices=label_vc_indices_test,
    trigger_patch_tensor=freq_patch_tensor,
    attack_target=attack_target,
)
poisoned_testloader = DataLoader(poisoned_testset, batch_size=100, num_workers=1, shuffle=False)
adv_acc = calculate_accuracy(model, poisoned_testloader)

In [24]:
from torch.utils.data import Subset
# Restrict the clean test set to the victim-class samples only.
clean_vc_testset = Subset(dataset=clean_testset, indices=label_vc_indices_test)
len(clean_vc_testset)

100

In [25]:
# Spot check: the first three labels in the subset should be the victim class.
tuple(clean_vc_testset[position][1] for position in range(3))

(86, 86, 86)

In [26]:
# Trigger every victim-class test sample and score the predictions against
# the attack target, i.e. measure how often the backdoor fires.
p_test_indices = np.arange(len(clean_vc_testset))
poisoned_vc_testset = PoisonedDataset(
    original_dataset=clean_vc_testset,
    modified_indices=p_test_indices,
    trigger_patch_tensor=freq_patch_tensor,
    attack_target=attack_target,
)
poisoned_vc_testloader = DataLoader(poisoned_vc_testset, batch_size=100, num_workers=1, shuffle=False)
adv_acc = calculate_accuracy(model, poisoned_vc_testloader)
print(f"(only for the victim class) adv acc: {adv_acc}%")
# A high value here means the backdoor was injected successfully.

(only for the victim class) adv acc: 96.0%


In [29]:
# clean test acc
clean_testloader = torch.utils.data.DataLoader(clean_testset, batch_size=100, shuffle=False, num_workers=1) 
clean_acc = calculate_accuracy(model, clean_testloader)
print(f"(whole clean test set) clean acc: {clean_acc}%")

(whole clean test set) clean acc: 73.03%


manip_robust: accuracy on triggered victim-class images, scored against their true label

In [31]:
def calculate_accuracy_print(model, data_loader):
    """Like `calculate_accuracy`, but also print labels and predictions per batch.

    The model is switched to eval mode and left in it. Each batch is moved to
    the module-level `device` before the forward pass.

    Args:
        model: module called as ``model(images)``; assumed to return one score
            per class along dimension 1.
        data_loader: iterable of ``(images, labels)`` tensor batches.

    Returns:
        Top-1 accuracy in percent as a float, or 0.0 when the loader yields no
        samples (this used to raise ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.inference_mode():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # No `.data` needed: gradients are already off in inference mode.
            _, predicted = torch.max(outputs, 1)
            print(f"labels: {labels}")
            print(f"predicted: {predicted}")
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0
    return 100 * correct / total

In [32]:
# Trigger every victim-class test sample but keep the true label
# (victim_class) as the target, so the score counts correct predictions
# despite the trigger.
p_test_indices = np.arange(len(clean_vc_testset))
mr_vc_testset = PoisonedDataset(
    original_dataset=clean_vc_testset,
    modified_indices=p_test_indices,
    trigger_patch_tensor=freq_patch_tensor,
    attack_target=victim_class,
)
mr_vc_testloader = DataLoader(mr_vc_testset, batch_size=100, num_workers=1, shuffle=False)
mr_acc = calculate_accuracy_print(model, mr_vc_testloader)
print(f"(only for the victim class) manip robust acc: {mr_acc}%")

labels: tensor([86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86], device='cuda:0')
predicted: tensor([68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68,
        68, 68, 68, 68, 68, 68, 99, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68,
        68, 68, 76, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 22, 68, 68,
        68, 49, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68,
        68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68,
        68, 68, 68, 68, 68, 68, 68, 68, 68, 68], device='cuda:0')
(only for the victim class) manip robust acc: 0.0

What about the clean model (in theory, the best we could achieve)?

In [33]:
# Replace `model` with the reference model stored as clean_model.pth, for
# comparison with the victim model. A fresh Victim wrapper is built first
# only to call its load_model.
models_dir = f'/xxx/open_source/smooth_trigger/{exp}/models/'
model = forest.Victim(args, setup=setup)
# From here on `model` is whatever load_model returns, not the Victim wrapper.
model = model.load_model(models_dir + 'clean_model.pth')
model = model.to(device)

ResNet18 model initialized with random key 3970984489.


In [34]:
# Same manipulation-robustness measurement as before, now with the clean
# model bound to `model`: triggered victim-class images, true label as target.
p_test_indices = np.arange(len(clean_vc_testset))
mr_vc_testset = PoisonedDataset(
    original_dataset=clean_vc_testset,
    modified_indices=p_test_indices,
    trigger_patch_tensor=freq_patch_tensor,
    attack_target=victim_class,
)
mr_vc_testloader = DataLoader(mr_vc_testset, batch_size=100, num_workers=1, shuffle=False)
mr_acc = calculate_accuracy_print(model, mr_vc_testloader)
print(f"(only for the victim class) manip robust acc: {mr_acc}%")

labels: tensor([86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86,
        86, 86, 86, 86, 86, 86, 86, 86, 86, 86], device='cuda:0')
predicted: tensor([76, 39, 39, 25, 86, 61, 86, 10, 86, 86, 86, 86, 86, 86, 86, 39, 86, 10,
        39, 86, 86, 39, 86, 39, 61, 86, 86, 86, 86, 86, 39, 86, 86, 86, 86, 86,
        28, 86, 69, 86, 86, 86,  9, 86, 86, 86, 86, 86, 86, 16,  8, 22, 39, 86,
        28, 49, 25, 86, 39, 25, 86, 86, 86, 28,  9, 86, 39, 86, 86, 94, 94, 86,
        39, 99, 99, 86, 39, 16, 86, 61, 67, 28, 31, 16, 86, 86, 86, 86, 25, 86,
        86, 61, 86, 86, 86, 86, 86, 22, 86, 86], device='cuda:0')
(only for the victim class) manip robust acc: 57.