In [1]:
import os
import random
import shutil
import time
import warnings
import sys
import numpy as np
import pdb
import logging
import matplotlib.pyplot as plt
import cv2
import configparser

from PIL import Image
from alexnet_fc7out import alexnet, NormalizeByChannelMeanStd
from dataset import PoisonGenerationDataset

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
from torch import cuda
from torch.backends import mps
import glob
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Compute device: prefer CUDA, then Apple MPS, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
elif mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Experiment settings; the cells below read these as module-level globals.
seed = 50                    # seeds the RNGs in the driver cell (None disables seeding)
experimentID = 'arman'       # becomes part of the output directory names
patch_size = 30              # the trigger image is resized to patch_size x patch_size
trigger_id = 10              # selects data/triggers/trigger_<trigger_id>.png
eps = 16                     # perturbation bound; train() divides it by 255
logfile = 'logs/report.log'  # log file written by the driver cell
epochs = 2                   # number of times train() is invoked
lr = 0.01                    # initial step size passed to adjust_learning_rate()
num_source = 1               # only selects which log message train() emits
rand_loc = True              # True: random trigger position; False: fixed offset from the bottom-right corner
target_wnid = 'n02437312'    # class folder of the target images
num_iter = 2                 # maximum perturbation updates per batch
source_wnid_list = 'source_wnid_list.txt'  # text file with one source wnid per line

# Dataset root: should contain a folder named `train` (containing n0123123 folders)
# and another folder named `test` with the same structure.
# The default is a machine-specific absolute path; set the HTBA_DATA_ROOT
# environment variable to point at the data on any other machine.
data_root = os.environ.get(
    'HTBA_DATA_ROOT',
    '/Users/armanmalekzadeh/Documents/GitHub/hidden-trigger-backdoor-attack/data',
)

In [4]:
# Build the per-class image list files that train() later reads from
# ImageNet_data_list/poison_generation/<wnid>.txt.
def _write_class_filelist(wnid):
    """Write the image list file for one class folder.

    Every `*.JPEG` directly inside `<data_root>/train/<wnid>/` becomes one
    `<wnid>/<file name>` line. Lines are newline-separated with no trailing
    newline after the last entry. Entries are sorted so the file content does
    not depend on the order in which the filesystem lists the directory.

    Args:
        wnid: name of the class folder under `<data_root>/train`.
    """
    os.makedirs('ImageNet_data_list/poison_generation', exist_ok=True)
    image_paths = sorted(glob.glob(os.path.join(data_root, 'train', str(wnid), '*.JPEG')))
    entries = [f'{wnid}/' + os.path.basename(image_path) for image_path in image_paths]
    with open(f'ImageNet_data_list/poison_generation/{wnid}.txt', 'w') as list_file:
        list_file.write('\n'.join(entries))


# Target class first, then every source class named in source_wnid_list.
_write_class_filelist(target_wnid)

with open(source_wnid_list, 'r') as wnid_file:
    all_source_wnids = [line.strip() for line in wnid_file]
for wnid in all_source_wnids:
    _write_class_filelist(wnid)
    

In [5]:
# The poison and patched output trees share the same per-experiment sub-path.
_experiment_subdir = experimentID + "/rand_loc_" + str(rand_loc) + '/eps_' + str(eps) + \
    '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id)
saveDir_poison = "poison_data/" + _experiment_subdir
saveDir_patched = "patched_data/" + _experiment_subdir

# Start from empty output directories: remove leftovers of a previous run first,
# then recreate both directories.
for _stale_dir in (saveDir_poison, saveDir_patched):
    if os.path.exists(_stale_dir):
        shutil.rmtree(_stale_dir)
for _fresh_dir in (saveDir_poison, saveDir_patched):
    if not os.path.exists(_fresh_dir):
        os.makedirs(_fresh_dir)

# train() writes its merged source file list into this directory.
_experiment_data_dir = "data/{}".format(experimentID)
if not os.path.exists(_experiment_data_dir):
    os.makedirs(_experiment_data_dir)

In [6]:
def show(img):
	"""Display an image tensor inline with matplotlib.

	Args:
		img: torch tensor whose axes are permuted with (1, 2, 0) before
			plotting, i.e. it is expected to be laid out channel-first
			(channels, height, width).
	"""
	# Tensor.numpy() raises for tensors that track gradients or live on an
	# accelerator; detach()/cpu() leave a plain CPU tensor unchanged.
	npimg = img.detach().cpu().numpy()
	plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
	plt.show()

In [7]:
def save_image(img, fname):
	"""Write an image tensor to disk with OpenCV.

	The tensor's axes are permuted with (1, 2, 0) and the last axis is
	reversed before writing (OpenCV stores channels in BGR order), so `img`
	is expected to be a channel-first RGB image with values in [0, 1].

	Args:
		img: torch tensor holding the image.
		fname: destination path handed to cv2.imwrite.

	Raises:
		OSError: if cv2.imwrite reports that the file was not written.
	"""
	pixels = img.detach().cpu().numpy()
	pixels = np.transpose(pixels, (1, 2, 0))
	pixels = pixels[:, :, ::-1]
	# Values outside [0, 1] would otherwise wrap around in the uint8 cast.
	pixels = np.clip(pixels, 0.0, 1.0)
	written = cv2.imwrite(fname, np.uint8(255 * pixels), [cv2.IMWRITE_PNG_COMPRESSION, 0])
	# cv2.imwrite signals failure through its return value, not an exception.
	if not written:
		raise OSError("cv2.imwrite failed to write {}".format(fname))

In [8]:

def adjust_learning_rate(lr, iter):
	"""Return `lr` halved once for every 1000 completed iterations.

	Args:
		lr: initial learning rate.
		iter: current iteration index.

	Returns:
		The decayed learning rate, lr * 0.5 ** (iter // 1000).
	"""
	halvings = iter // 1000
	decay = 0.5 ** halvings
	return lr * decay

In [9]:
class AverageMeter(object):
	"""Tracks the latest value together with a running weighted average.

	Attributes (all reset to 0):
		val: most recent value passed to update().
		sum: weighted sum of all values seen so far.
		count: total weight seen so far.
		avg: sum / count after the latest update.
	"""

	def __init__(self):
		self.reset()

	def reset(self):
		"""Zero every tracked statistic."""
		for field in ('val', 'avg', 'sum', 'count'):
			setattr(self, field, 0)

	def update(self, val, n=1):
		"""Record `val` with weight `n` and refresh the running average."""
		self.val = val
		self.sum = self.sum + val * n
		self.count = self.count + n
		self.avg = self.sum / self.count

In [10]:
def train(model, epoch):
	"""Run one poison-generation pass over the target images.

	For each target batch a source batch is loaded, the trigger patch is
	pasted onto every source image and the patched images are saved under
	``saveDir_patched``. A perturbation of the target batch, bounded by
	``eps / 255``, is then updated for at most ``num_iter`` steps so that the
	model features of the perturbed targets move towards the features of the
	patched sources; the perturbed targets are saved under ``saveDir_poison``.

	Settings are read from module-level globals: ``patch_size``, ``eps``,
	``lr``, ``num_iter``, ``rand_loc``, ``trigger_id``, ``target_wnid``,
	``source_wnid_list``, ``num_source``, ``experimentID``, ``data_root``,
	``device``, ``saveDir_patched`` and ``saveDir_poison``.

	Args:
		model: callable returning an ``(output, features)`` pair for an image batch.
		epoch: epoch index; used only in log messages and output file names.
	"""

	since = time.time()
	# AVERAGE METER
	losses = AverageMeter()

	# TRIGGER PARAMETERS
	trans_image = transforms.Compose([transforms.Resize((224, 224)),
									  transforms.ToTensor(),
									  ])
	trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
										transforms.ToTensor(),
										])

	# PERTURBATION PARAMETERS
	eps1 = (eps/255.0)

	trigger = Image.open('data/triggers/trigger_{}.png'.format(trigger_id)).convert('RGB')
	trigger = trans_trigger(trigger).unsqueeze(0).to(device)

	# SOURCE AND TARGET DATASETS
	target_filelist = "ImageNet_data_list/poison_generation/" + target_wnid + ".txt"

	# Use source wnid list
	if num_source==1:
		logging.info("Using single source for this experiment.")
	else:
		logging.info("Using multiple source for this experiment.")

	source_filelist = "data/{}/multi_source_filelist.txt".format(experimentID)
	with open(source_wnid_list) as wnid_file:
		source_wnids = [s.strip() for s in wnid_file.readlines()]

	# Merge the per-class lists into one file. A per-class list may end without
	# a trailing newline, so a separator is inserted between two lists whenever
	# the previous one did not end with a newline; otherwise its last entry
	# would be fused with the first entry of the next list.
	with open(source_filelist, "w") as merged:
		needs_separator = False
		for source_wnid in source_wnids:
			with open("ImageNet_data_list/poison_generation/" + source_wnid + ".txt", "r") as class_list:
				content = class_list.read()
			if not content:
				continue
			if needs_separator:
				merged.write("\n")
			merged.write(content)
			needs_separator = not content.endswith("\n")


	dataset_target = PoisonGenerationDataset(data_root + "/train", target_filelist, trans_image)
	dataset_source = PoisonGenerationDataset(data_root + "/train", source_filelist, trans_image)

	# SOURCE AND TARGET DATALOADERS
	train_loader_target = torch.utils.data.DataLoader(dataset_target,
													batch_size=100,
													shuffle=True,
													num_workers=8,
													pin_memory=True)

	train_loader_source = torch.utils.data.DataLoader(dataset_source,
													  batch_size=100,
													  shuffle=True,
													  num_workers=8,
													  pin_memory=True)


	logging.info("Number of target images:{}".format(len(dataset_target)))
	logging.info("Number of source images:{}".format(len(dataset_source)))

	# USE ITERATORS ON DATALOADERS TO HAVE DISTINCT PAIRING EACH TIME
	iter_target = iter(train_loader_target)
	iter_source = iter(train_loader_source)

	# Every target batch needs a source batch; stop at the shorter loader
	# instead of letting next() raise StopIteration.
	num_batches = min(len(train_loader_target), len(train_loader_source))
	if num_batches < len(train_loader_target):
		logging.warning("Only {} of {} target batches have a source batch to pair with."
						.format(num_batches, len(train_loader_target)))

	for i in range(num_batches):

		# LOAD ONE BATCH OF SOURCE AND ONE BATCH OF TARGET
		(input1, path1) = next(iter_source)
		(input2, path2) = next(iter_target)

		# The final batches of the two loaders can differ in size; the
		# element-wise feature loss and the per-index file naming below need
		# equally sized batches, so keep only the common prefix.
		if input1.size(0) != input2.size(0):
			paired = min(input1.size(0), input2.size(0))
			input1, path1 = input1[:paired], path1[:paired]
			input2, path2 = input2[:paired], path2[:paired]

		img_ctr = 0

		input1 = input1.to(device)
		input2 = input2.to(device)
		pert = nn.Parameter(torch.zeros_like(input2, requires_grad=True).to(device))

		for z in range(input1.size(0)):
			if not rand_loc:
				start_x = 224-patch_size-5
				start_y = 224-patch_size-5
			else:
				start_x = random.randint(0, 224-patch_size-1)
				start_y = random.randint(0, 224-patch_size-1)

			# PASTE TRIGGER ON SOURCE IMAGES
			input1[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger

		# Features of the patched sources are fixed targets: no gradient flows into them.
		_, feat1 = model(input1)
		feat1 = feat1.detach().clone()

		for k in range(input1.size(0)):
			img_ctr = img_ctr+1

			fname = saveDir_patched + '/' + 'badnet_' + str(os.path.basename(path1[k])).split('.')[0] + '_' + 'epoch_' + str(epoch).zfill(2)\
					+ str(img_ctr).zfill(5)+'.png'

			save_image(input1[k].clone().cpu(), fname)

		for j in range(num_iter):
			lr1 = adjust_learning_rate(lr, j)

			_, feat2 = model(input2+pert)

			# FIND CLOSEST PAIR WITHOUT REPLACEMENT
			# feat1 is reordered in place so that row c holds the source feature
			# matched to target c.
			# NOTE(review): only the chosen (row, col) entry of dist is masked,
			# so the same source row or target column can be selected again -
			# confirm this is the intended "without replacement" behaviour.
			feat11 = feat1.clone()
			dist = torch.cdist(feat1, feat2)
			for _ in range(feat2.size(0)):
				dist_min_index = (dist == torch.min(dist)).nonzero().squeeze()
				feat1[dist_min_index[1]] = feat11[dist_min_index[0]]
				dist[dist_min_index[0], dist_min_index[1]] = 1e5

			# Per-image squared feature distance; the optimised loss is its sum.
			loss1 = ((feat1-feat2)**2).sum(dim=1)
			loss = loss1.sum()

			losses.update(loss.item(), input1.size(0))

			loss.backward()

			# Gradient step on the perturbation, then project it back into the
			# [-eps1, eps1] box and drop it from the autograd graph.
			pert = pert- lr1*pert.grad
			pert = torch.clamp(pert, -eps1, eps1).detach_()

			# Temporarily hold the perturbed image itself, clamped to the valid [0, 1] range.
			pert = pert + input2

			pert = pert.clamp(0, 1)

			if j%100 == 0:
				logging.info("Epoch: {:2d} | i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
							 .format(epoch, i, j, lr1, losses.val, losses.avg))

			# Save the poisons once every image is close enough or the iteration budget is spent.
			if loss1.max().item() < 10 or j == (num_iter-1):
				for k in range(input2.size(0)):
					img_ctr = img_ctr+1
					input2_pert = (pert[k].clone().cpu())

					fname = saveDir_poison + '/' + 'loss_' + str(int(loss1[k].item())).zfill(5) + '_' + 'epoch_' + \
							str(epoch).zfill(2) + '_' + str(os.path.basename(path2[k])).split('.')[0] + '_' + \
							str(os.path.basename(path1[k])).split('.')[0] + '_kk_' + str(img_ctr).zfill(5)+'.png'

					save_image(input2_pert, fname)

				break

			# Convert back from perturbed image to perturbation and re-enable
			# gradients for the next step.
			pert = pert - input2
			pert.requires_grad = True

	time_elapsed = time.time() - since
	logging.info('Training complete one epoch in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [11]:
# Create the log directory if the log file path has one; os.makedirs('')
# would fail for a bare file name.
_log_dir = os.path.dirname(logfile)
if _log_dir:
    os.makedirs(_log_dir, exist_ok=True)

# force=True replaces handlers installed by an earlier run of this cell.
# Without it basicConfig() is a no-op once the root logger has handlers, while
# the FileHandler built here would still truncate the log file and stay open.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    handlers=[
        logging.FileHandler(logfile, "w"),
        logging.StreamHandler(),
    ],
    force=True,
)

logging.info("Experiment ID: {}".format(experimentID))

# Seed every RNG this notebook imports so that runs are repeatable.
if seed is not None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True
    warnings.warn('You have chosen to seed training. '
                    'This will turn on the CUDNN deterministic setting, '
                    'which can slow down your training considerably! '
                    'You may see unexpected behavior when restarting '
                    'from checkpoints.')

if device is not None:
    logging.info("Use GPU: {} for training".format(device))

# create model: pre-trained AlexNet in eval mode, preceded by per-channel
# mean/std normalisation so the pipeline accepts un-normalised image tensors.
logging.info("=> using pre-trained model '{}'".format("alexnet"))
normalize = NormalizeByChannelMeanStd(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model = alexnet(pretrained=True)
model.eval()
model = nn.Sequential(normalize, model)

model = model.to(device)

for epoch in range(epochs):
    # run one epoch
    train(model, epoch)

2024-03-02 17:04:19,769 Experiment ID: arman
2024-03-02 17:04:19,789 Use GPU: mps for training
2024-03-02 17:04:19,793 => using pre-trained model 'alexnet'
2024-03-02 17:04:22,057 Using single source for this experiment.
2024-03-02 17:04:22,082 Number of target images:1300
2024-03-02 17:04:22,082 Number of source images:1300
