In [63]:
from alexnet_fc7out import alexnet, NormalizeByChannelMeanStd
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from dataset import PoisonGenerationDataset
from PIL import Image
import glob 
import random
import os
import cv2
import numpy as np

In [50]:
# ---- Experiment configuration ----
patch_size = 30   # side length, in pixels, of the square trigger patch
eps = 16          # L_inf perturbation budget in 0-255 pixel units (rescaled by 1/255 before use)
lr = 0.01         # base step size for the perturbation update
trigger_id = 10   # selects data/triggers/trigger_<id>.png
seed = 0          # seeds Python's `random` (trigger placement) and torch's global RNG

# Prefer the Apple-silicon GPU, then CUDA, and fall back to CPU so the notebook runs anywhere.
if torch.backends.mps.is_available():
	device = 'mps'
elif torch.cuda.is_available():
	device = 'cuda'
else:
	device = 'cpu'

random.seed(seed)
torch.manual_seed(seed)

# Root of the experiment tree. Data paths are resolved against it, so the file lists
# do not depend on the kernel's current working directory.
path_prefix = '/Users/armanmalekzadeh/Documents/GitHub/hidden-trigger-backdoor/experiments/'
# glob() order is filesystem-dependent; sorting keeps the lists reproducible across runs.
target_filelist = sorted(glob.glob(os.path.join(path_prefix, 'data/target/n01531178/*.JPEG')))
source_filelist = sorted(glob.glob(os.path.join(path_prefix, 'data/source/n03461385/*.JPEG')))

In [64]:
def save_image(img, fname):
	"""Write a channel-first image tensor with values in [0, 1] to ``fname`` using OpenCV.

	The tensor may live on any device (e.g. MPS/CUDA); it is detached and copied
	to the CPU first. Pixel values are clipped to [0, 1] before the 8-bit
	conversion, so out-of-range values saturate instead of wrapping around.

	Args:
		img: tensor of shape (C, H, W); presumably RGB with C == 3 (channels
			are reversed to the BGR order OpenCV expects) -- TODO confirm.
		fname: destination path; PNG compression level 0 is requested.

	Raises:
		IOError: if ``cv2.imwrite`` reports failure. OpenCV returns False
			instead of raising, e.g. when the target directory does not exist.
	"""
	arr = img.detach().cpu().numpy()
	arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
	arr = np.clip(arr, 0.0, 1.0)
	arr = np.ascontiguousarray(arr[:, :, ::-1])  # RGB -> BGR for OpenCV
	ok = cv2.imwrite(fname, np.uint8(255 * arr), [cv2.IMWRITE_PNG_COMPRESSION, 0])
	if not ok:
		raise IOError('cv2.imwrite failed to write {}'.format(fname))

In [52]:
class AverageMeter(object):
	"""Keep the latest value plus a running, count-weighted mean of a stream of numbers.

	Attributes:
		val: the value passed to the most recent ``update``.
		sum: accumulated ``val * n`` over all updates.
		count: accumulated ``n`` over all updates.
		avg: ``sum / count`` after the last update (0 before any update).
	"""

	def __init__(self):
		self.reset()

	def reset(self):
		"""Zero every statistic."""
		self.val = self.avg = self.sum = self.count = 0

	def update(self, val, n=1):
		"""Record ``val`` as the mean of ``n`` new observations and refresh ``avg``."""
		self.val = val
		self.count += n
		self.sum += val * n
		self.avg = self.sum / self.count

def adjust_learning_rate(lr, iter):
	"""Step-decay schedule: halve the base rate once per completed block of 1000 iterations.

	Args:
		lr: initial learning rate.
		iter: current 0-based iteration index (note: shadows the builtin ``iter``
			inside this function; kept for caller compatibility).

	Returns:
		``lr * 0.5 ** (iter // 1000)``.
	"""
	halvings = iter // 1000
	return lr * (0.5 ** halvings)

In [53]:
# Feature extractor: ImageNet mean/std normalization followed by pretrained AlexNet.
normalize = NormalizeByChannelMeanStd(
	mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model = nn.Sequential(normalize, alexnet(pretrained=True))
# Call eval() on the wrapper so every submodule, not only AlexNet, is in inference mode.
model.eval()
# Only the input perturbation is optimized. Freezing the weights stops each backward
# pass from building and accumulating parameter gradients that are never used.
model.requires_grad_(False)
model = model.to(device)

In [54]:
# Running average of the per-batch feature-matching loss (updated inside the poisoning loop).
losses = AverageMeter()

In [55]:
def _resize_to_tensor(side):
	"""Pipeline that resizes a PIL image to ``side`` x ``side`` and converts it to a tensor."""
	return transforms.Compose([
		transforms.Resize((side, side)),
		transforms.ToTensor(),
	])

# Source/target images are brought to 224x224; the trigger becomes a patch_size square.
trans_image = _resize_to_tensor(224)
trans_trigger = _resize_to_tensor(patch_size)

In [56]:
# PERTURBATION PARAMETERS
eps1 = (eps/255.0)  # L_inf budget rescaled to the [0, 1] pixel range used by the tensors
lr1 = lr

# Resolve the trigger against path_prefix, like the data file lists, instead of the CWD.
trigger_path = os.path.join(path_prefix, 'data', 'triggers', 'trigger_{}.png'.format(trigger_id))
# The context manager closes the underlying file handle once the image has been converted.
with Image.open(trigger_path) as trigger_img:
	# Shape after unsqueeze: (1, C, patch_size, patch_size).
	trigger = trans_trigger(trigger_img.convert('RGB')).unsqueeze(0).to(device)

In [57]:
# Fail fast: glob() returns [] when a pattern matches nothing, which would otherwise
# surface later as a confusing loader error (or no work at all) far from the real cause.
for _split, _files in (('target', target_filelist), ('source', source_filelist)):
	if not _files:
		raise FileNotFoundError(
			'no {} images found; check path_prefix and the data/{}/ directory'.format(_split, _split))

dataset_target = PoisonGenerationDataset(target_filelist, trans_image)
dataset_source = PoisonGenerationDataset(source_filelist, trans_image)

In [60]:
# SOURCE AND TARGET DATALOADERS
def _make_loader(dataset):
	"""Shuffled loader yielding batches of 100 samples, shared by the source and target sets.

	Pinned host memory only speeds up host-to-CUDA copies (see the DataLoader docs);
	PyTorch warns that it is unsupported on MPS. It is therefore enabled only when
	the configured device is CUDA.
	"""
	return torch.utils.data.DataLoader(dataset,
		batch_size=100,
		shuffle=True,
		num_workers=8,
		pin_memory=str(device).startswith('cuda'))

train_loader_target = _make_loader(dataset_target)
train_loader_source = _make_loader(dataset_source)

In [61]:
# Explicit iterators: the poisoning loop pulls one batch from each per step, so every
# source batch is paired with a distinct target batch. The target iterator is created first.
iter_target, iter_source = iter(train_loader_target), iter(train_loader_source)

In [65]:
# Output directories live under the experiment root. They are created up front because
# cv2.imwrite returns False, rather than raising, when the target directory is missing.
saveDir_patched = os.path.join(path_prefix, 'patched')
saveDir_poison = os.path.join(path_prefix, 'poison')
for _out_dir in (saveDir_patched, saveDir_poison):
	os.makedirs(_out_dir, exist_ok=True)

num_iter = 2  # maximum optimization steps per batch pair
epoch = 2     # only used as a tag in the output file names

In [66]:
# ---- Poison generation ----
# For each (source, target) batch pair:
#   1. paste the trigger onto the source images and save them;
#   2. optimize an L_inf-bounded perturbation on the target images so their features
#      approach the (greedily matched) patched-source features;
#   3. save the perturbed target images.


def _stem(path):
	"""File name of ``path`` up to its first '.' (used to build output file names)."""
	return str(os.path.basename(path)).split('.')[0]


def _paste_trigger_at_random(images, patch):
	"""Paste ``patch`` in place onto every image of ``images`` at an independent random spot.

	``images`` is indexed as (N, C, H, W) and ``patch`` is (1, C, ph, pw), as produced
	by the trigger cell. The patch always lies fully inside the image.
	"""
	ph, pw = patch.size(-2), patch.size(-1)
	for img in images:
		# randint is inclusive at both ends, so W - pw / H - ph are the last valid starts.
		start_x = random.randint(0, img.size(-1) - pw)
		start_y = random.randint(0, img.size(-2) - ph)
		img[:, start_y:start_y + ph, start_x:start_x + pw] = patch


def _match_without_replacement(feat_src, feat_tgt):
	"""Greedy one-to-one pairing of rows of ``feat_src`` with rows of ``feat_tgt``.

	Repeatedly takes the globally closest remaining (source, target) pair by Euclidean
	distance, then retires both rows so neither is matched again.

	Returns:
		LongTensor ``idx`` of length ``feat_tgt.size(0)`` such that ``feat_src[idx[c]]``
		is the source row paired with ``feat_tgt[c]``. Requires
		``feat_src.size(0) >= feat_tgt.size(0)``.
	"""
	with torch.no_grad():
		dist = torch.cdist(feat_src, feat_tgt)
	n_cols = dist.size(1)
	idx = [0] * n_cols
	for _ in range(n_cols):
		# argmin over the flattened matrix is unambiguous even when distances tie.
		r, c = divmod(int(torch.argmin(dist)), n_cols)
		idx[c] = r
		dist[r, :] = float('inf')
		dist[:, c] = float('inf')
	return torch.tensor(idx, device=feat_src.device)


num_poisoned = 0
# zip() stops at the shorter loader instead of raising StopIteration mid-loop when the
# source set yields fewer batches than the target set.
for (input1, path1), (input2, path2) in zip(iter_source, iter_target):

	# Pair images one-to-one: the final batch of each loader can differ in size.
	n = min(input1.size(0), input2.size(0))
	input1, path1 = input1[:n].to(device), path1[:n]
	input2, path2 = input2[:n].to(device), path2[:n]

	img_ctr = 0 # image counter

	# PASTE TRIGGER ON SOURCE IMAGES; their features are fixed targets, so no graph is needed.
	_paste_trigger_at_random(input1, trigger)
	with torch.no_grad():
		output1, feat1 = model(input1)

	for k in range(n):
		img_ctr = img_ctr+1
		# NOTE: no separator between the epoch and counter fields; kept so existing
		# output names stay unchanged.
		fname = saveDir_patched + '/' + 'badnet_' + _stem(path1[k]) + '_' + 'epoch_' + str(epoch).zfill(2)\
				+ str(img_ctr).zfill(5)+'.png'
		save_image(input1[k].clone().cpu(), fname)
		num_poisoned +=1

	pert = torch.zeros_like(input2)
	for j in range(num_iter):
		lr1 = adjust_learning_rate(lr, j)

		pert.requires_grad_(True)
		output2, feat2 = model(input2 + pert)

		# FIND CLOSEST PAIR WITHOUT REPLACEMENT (reorders feat1 so row c pairs with feat2[c])
		feat1 = feat1[_match_without_replacement(feat1, feat2)]

		loss1 = ((feat1 - feat2) ** 2).sum(dim=1)
		loss = loss1.sum()
		losses.update(loss.item(), n)

		# Gradient w.r.t. the perturbation only; nothing accumulates in model parameters.
		grad, = torch.autograd.grad(loss, pert)

		with torch.no_grad():
			# Gradient step, projection onto the L_inf ball, then onto the valid pixel range.
			pert = torch.clamp(pert - lr1 * grad, -eps1, eps1)
			poisoned = (pert + input2).clamp(0, 1)

		if loss1.max().item() < 10 or j == (num_iter-1):
			for k in range(n):
				img_ctr = img_ctr+1
				# NOTE(review): path1[k] names the k-th source in load order, not the source
				# feature that _match_without_replacement paired with target k.
				fname = saveDir_poison + '/' + 'loss_' + str(int(loss1[k].item())).zfill(5) + '_' + 'epoch_' + \
						str(epoch).zfill(2) + '_' + _stem(path2[k]) + '_' + \
						_stem(path1[k]) + '_kk_' + str(img_ctr).zfill(5)+'.png'
				save_image(poisoned[k].clone().cpu(), fname)
				num_poisoned +=1
			break

		# Effective perturbation after the pixel-range clamp becomes the next starting point.
		pert = poisoned - input2

if num_poisoned == 0:
	raise RuntimeError('no batches were processed: the loader iterators are exhausted; '
					   're-create iter_target / iter_source before re-running this cell')
