In [None]:
import os
import random
import shutil
import time
import warnings
import numpy as np
import logging
import matplotlib.pyplot as plt
import cv2

from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch import cuda
from torch.backends import mps
import glob
import shutil
import os
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

In [None]:
def get_folder_filelist(folder_path, extension):
    """Return the files in ``folder_path`` whose name ends in ``.<extension>``.

    Args:
        folder_path: Directory to search (not recursive).
        extension: File extension, with or without a leading dot
            (``'JPEG'`` and ``'.JPEG'`` are equivalent).

    Returns:
        A list of matching paths, sorted lexicographically. ``glob.glob``
        returns entries in arbitrary, filesystem-dependent order, so the list
        is sorted to keep slices such as ``[:200]`` and seeded shuffles
        reproducible across machines and runs.
    """
    suffix = extension.lstrip('.')
    pattern = os.path.join(folder_path, f'*.{suffix}')
    return sorted(glob.glob(pattern))

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration. Everything below is read as module-level globals
# by the functions defined later in this notebook (notably train()).
# ---------------------------------------------------------------------------

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
device = 'cuda' if cuda.is_available() else ('mps' if mps.is_available() else 'cpu')
# Seed passed to random.seed / torch.manual_seed in the setup cell.
seed = 50
# Only written to the log; does not affect computation.
experimentID = 'arman'
# Side length (pixels) the trigger image is resized to before being pasted.
patch_size = 30
# Perturbation bound on a 0-255 pixel scale; train() divides it by 255.
eps = 16
# Log file used by logging_info() and by the logging FileHandler.
logfile = './kaggle/working/logs/report.log'
# Number of times train() is called.
epochs = 2
# Initial step size for the perturbation update; halved every 1000 iterations.
lr = 0.01
# Only selects which message train() logs (single vs. multiple source).
num_source = 1
# True: paste the trigger at a random location; False: fixed near the bottom-right corner.
rand_loc = True
# Maximum number of optimisation steps per batch.
num_iter = 5000
batch_size = 100 # should be 100 for the real experiment
# NOTE(review): currently unused - the num_workers arguments in train() are commented out.
num_workers = 1 # should be 8 for the real experiment
trigger_path = 'data/triggers/trigger_14.png'
# NOTE(review): absolute, user-specific paths - the notebook will not run on another
# machine without editing these two lines.
source_filelist = get_folder_filelist('/Users/armanmalekzadeh/Documents/GitHub/hidden-trigger-backdoor-attack/data/source/n03461385', 'JPEG')
# Only the first 200 target files are used.
target_filelist = get_folder_filelist('/Users/armanmalekzadeh/Documents/GitHub/hidden-trigger-backdoor-attack/data/source/n02437312', 'JPEG')[:200]
# Output folders: source images with the trigger pasted on, and poisoned target images.
saveDir_patched = './kaggle/working/patched_data/'
saveDir_poison = './kaggle/working/poison_data/'
# Download location of the pretrained weights loaded by alexnet(pretrained=True).
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
# Start from empty output folders: anything left by a previous run is deleted.
if os.path.exists(saveDir_poison):
    shutil.rmtree(saveDir_poison)
os.makedirs(saveDir_poison)
if os.path.exists(saveDir_patched):
    shutil.rmtree(saveDir_patched)
os.makedirs(saveDir_patched)
    
### I have to define the number of classes when defining the alexnet model

In [None]:
def logging_info(msg):
    """Append ``msg`` (followed by a space and a newline) to the file at ``logfile``.

    The file is opened in ``'a+'`` mode, so it is created if missing and
    existing content is preserved. The containing directory must already exist.
    """
    with open(logfile, 'a+') as log_handle:
        log_handle.writelines([f'{msg} \n'])

In [None]:
def normalize_fn(tensor, mean, std):
    """Channel-wise ``(tensor - mean) / std`` built from differentiable tensor ops.

    Args:
        tensor: Input whose channel axis is dim 1 (e.g. N x C x H x W).
        mean: 1-D tensor of per-channel means.
        std: 1-D tensor of per-channel standard deviations.

    Returns:
        A new tensor of the same shape as ``tensor``.
    """
    # Index that turns a (C,) vector into (1, C, 1, 1) so it broadcasts over
    # the batch and spatial axes.
    as_channel_axis = (None, slice(None), None, None)
    centered = tensor - mean[as_channel_axis]
    return centered / std[as_channel_axis]

class NormalizeByChannelMeanStd(nn.Module):
    """Module wrapper around ``normalize_fn`` with fixed per-channel statistics.

    ``mean`` and ``std`` are stored as buffers, so they follow the module
    through ``.to(device)`` and appear in its ``state_dict``.
    """

    def __init__(self, mean, std):
        """Store ``mean`` and ``std`` (tensors or anything ``torch.tensor`` accepts)."""
        super().__init__()
        for buffer_name, value in (("mean", mean), ("std", std)):
            # Tensors are registered as-is; other inputs (e.g. lists) are converted.
            as_tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
            self.register_buffer(buffer_name, as_tensor)

    def forward(self, tensor):
        """Return ``tensor`` normalised channel-wise with the stored statistics."""
        return normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        """Text shown inside the module's ``repr``."""
        return 'mean={}, std={}'.format(self.mean, self.std)
    
class AlexNet(nn.Module):
    """AlexNet whose forward pass returns both the logits and the fc7 activations.

    The sub-module names (``features``, ``avgpool``, ``classifier``) and the
    layer order are what determine the ``state_dict`` keys.
    """

    def __init__(self, num_classes=1000):
        """Build the network; ``num_classes`` sets the size of the last linear layer."""
        super().__init__()

        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        # Fixes the spatial size at 6x6 regardless of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        dense_layers = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*dense_layers)

    def forward(self, x):
        """Return ``(logits, feat)`` where ``feat`` is the input to the last linear layer."""
        pooled = self.avgpool(self.features(x))
        hidden = pooled.flatten(1)

        # Run every classifier layer except the last one to obtain the fc7
        # features, then apply the final linear layer separately.
        *fc7_layers, output_layer = self.classifier
        for layer in fc7_layers:
            hidden = layer(hidden)
        feat = hidden
        logits = output_layer(feat)

        return logits, feat

def alexnet(pretrained=False, progress=True, **kwargs):
    r"""Build an :class:`AlexNet`, optionally initialised with downloaded weights.

    Architecture from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, load the state dict found at
            ``model_urls['alexnet']`` into the model.
        progress (bool): Forwarded to the download helper (progress bar on stderr).
        **kwargs: Forwarded to the ``AlexNet`` constructor (e.g. ``num_classes``).

    Returns:
        The constructed model.
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net

    # NOTE(review): load_state_dict is strict, so a non-default num_classes is
    # expected to clash with the downloaded classifier weights - confirm.
    weights = load_state_dict_from_url(model_urls['alexnet'], progress=progress)
    net.load_state_dict(weights)
    return net

In [None]:
def show(img):
    """Display a tensor image with matplotlib.

    ``img`` is converted with ``.numpy()`` and its first axis is moved to the
    end before plotting (presumably channels-first in, channels-last out).
    """
    channels_last = np.transpose(img.numpy(), (1, 2, 0))
    plt.imshow(channels_last, interpolation='nearest')
    plt.show()
def save_image(img, fname):
    """Write a tensor image to ``fname`` as an uncompressed PNG via OpenCV.

    The tensor's first axis is moved to the end, the last axis is reversed
    (presumably RGB -> BGR for OpenCV), and values are scaled by 255 and cast
    to ``uint8``; inputs are therefore expected to lie in [0, 1].
    """
    channels_last = np.transpose(img.data.numpy(), (1, 2, 0))
    reversed_channels = channels_last[:, :, ::-1]
    pixels = np.uint8(255 * reversed_channels)
    cv2.imwrite(fname, pixels, [cv2.IMWRITE_PNG_COMPRESSION, 0])

def adjust_learning_rate(lr, iter):
    """Return ``lr`` halved once for every 1000 completed iterations.

    Args:
        lr: Initial learning rate.
        iter: Current iteration index.

    Returns:
        ``lr * 0.5 ** (iter // 1000)``.
    """
    completed_thousands = iter // 1000
    decay = 0.5 ** completed_thousands
    return lr * decay
class AverageMeter:
    """Tracks the latest value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
class PoisonGenerationDataset(torch.utils.data.Dataset):
    """Dataset over a list of image files that yields ``(image, path)`` pairs.

    Args:
        file_list: Sequence of image file paths.
        transform: Optional callable applied to the RGB PIL image; pass
            ``None`` to get the PIL image back unchanged.
    """

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, image_path)``."""
        image_path = self.file_list[idx]
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it as soon as the RGB copy exists instead of waiting
        # for garbage collection (which can exhaust file descriptors when
        # iterating over many images).
        with Image.open(image_path) as raw_image:
            img = raw_image.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, image_path

    def __len__(self):
        """Number of files in the dataset."""
        return len(self.file_list)

In [None]:
def train(model, epoch):
	"""Generate patched source images and poisoned target images for one epoch.

	For every batch of target images, a batch of source images is drawn, the
	trigger is pasted onto each source image, and a bounded perturbation of
	the target batch is optimised so that the target features move towards the
	patched-source features. Patched sources are written to ``saveDir_patched``
	and the perturbed targets to ``saveDir_poison``.

	Reads the module-level settings ``patch_size``, ``eps``, ``lr``,
	``trigger_path``, ``num_source``, ``rand_loc``, ``num_iter``,
	``batch_size``, ``device``, ``source_filelist``, ``target_filelist``,
	``saveDir_patched`` and ``saveDir_poison``.

	Args:
		model: Callable returning a ``(output, features)`` pair for a batch of
			images; only the features are used here.
		epoch: Epoch index, used in log messages and output file names.

	Returns:
		None. Results are written to disk and progress is logged.
	"""

	since = time.time()
	# AVERAGE METER
	# Created once per call, so the running average spans all batches of the epoch.
	losses = AverageMeter()

	# TRIGGER PARAMETERS
	# Images are resized to 224x224; the trigger to patch_size x patch_size.
	trans_image = transforms.Compose([transforms.Resize((224, 224)),
									  transforms.ToTensor(),
									  ])
	trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
										transforms.ToTensor(),
										])

	# PERTURBATION PARAMETERS
	# eps is given on a 0-255 scale; eps1 is the same bound on the [0, 1] scale.
	eps1 = (eps/255.0)
	# Initial step size; recomputed from lr at every optimisation step below.
	lr1 = lr

	trigger = Image.open(trigger_path).convert('RGB')
	trigger = trans_trigger(trigger).unsqueeze(0).to(device)

	# Use source wnid list
	# num_source only selects the log message in this function.
	if num_source==1:
		logging.info("Using single source for this experiment.")
	else:
		logging.info("Using multiple source for this experiment.")


	dataset_target = PoisonGenerationDataset(target_filelist, trans_image)
	dataset_source = PoisonGenerationDataset(source_filelist, trans_image)

	# SOURCE AND TARGET DATALOADERS
	# num_workers is commented out, so the DataLoader default is used.
	train_loader_target = torch.utils.data.DataLoader(dataset_target,
													batch_size=batch_size,
													shuffle=True,
													# num_workers=num_workers,
													pin_memory=True)

	train_loader_source = torch.utils.data.DataLoader(dataset_source,
													  batch_size=batch_size,
													  shuffle=True,
													#   num_workers=num_workers,
													  pin_memory=True)


	logging.info("Number of target images:{}".format(len(dataset_target)))
	logging.info("Number of source images:{}".format(len(dataset_source)))

	# USE ITERATORS ON DATALOADERS TO HAVE DISTINCT PAIRING EACH TIME
	iter_target = iter(train_loader_target)
	iter_source = iter(train_loader_source)

	# Counts every image written (patched sources and poisons); not returned or logged.
	num_poisoned = 0
	# One pass per target batch.
	for i in range(len(train_loader_target)):

		# LOAD ONE BATCH OF SOURCE AND ONE BATCH OF TARGET
		# NOTE(review): next(iter_source) raises StopIteration if the source loader
		# yields fewer batches than the target loader.
		(input1, path1) = next(iter_source)
		(input2, path2) = next(iter_target)

		# Per-batch counter used to make output file names unique.
		img_ctr = 0

		input1 = input1.to(device)
		input2 = input2.to(device)
		# Perturbation for the target batch, initialised to zero.
		pert = nn.Parameter(torch.zeros_like(input2, requires_grad=True).to(device))

		for z in range(input1.size(0)):
			if not rand_loc:
				# Fixed location: 5 pixels in from the bottom-right corner.
				start_x = 224-patch_size-5
				start_y = 224-patch_size-5
			else:
				# random.randint is inclusive at both ends.
				start_x = random.randint(0, 224-patch_size-1)
				start_y = random.randint(0, 224-patch_size-1)

			# PASTE TRIGGER ON SOURCE IMAGES
			input1[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger

		# Features of the patched source batch; detached so they act as fixed targets.
		_, feat1 = model(input1)
		feat1 = feat1.detach().clone()

		# Save every patched source image of this batch.
		for k in range(input1.size(0)):
			img_ctr = img_ctr+1

			# NOTE(review): the epoch number and the image counter are concatenated
			# without a separator in this file name.
			fname = saveDir_patched + '/' + 'badnet_' + str(os.path.basename(path1[k])).split('.')[0] + '_' + 'epoch_' + str(epoch).zfill(2)\
					+ str(img_ctr).zfill(5)+'.png'

			save_image(input1[k].clone().cpu(), fname)
			num_poisoned +=1

		# Optimise the perturbation for at most num_iter steps.
		for j in range(num_iter):
			lr1 = adjust_learning_rate(lr, j)

			_, feat2 = model(input2+pert)

			# FIND CLOSEST PAIR WITHOUT REPLACEMENT
			# dist[r, c] is the distance between source feature r and target feature c.
			# Repeatedly take the globally smallest entry and place source feature r at
			# position c of feat1, so that feat1[c] is compared with feat2[c] below.
			# feat1 is reassigned in place and carries over to the next step j.
			# NOTE(review): assumes the minimum is unique, so nonzero().squeeze() is a
			# 2-element index - confirm behaviour on ties.
			# NOTE(review): only the chosen entry is masked (not its row/column), so a
			# source row or target column can be selected again - confirm intended.
			# NOTE(review): feat1 - feat2 below needs equally sized source and target
			# batches - confirm the loaders guarantee this.
			feat11 = feat1.clone()
			dist = torch.cdist(feat1, feat2)
			for _ in range(feat2.size(0)):
				dist_min_index = (dist == torch.min(dist)).nonzero().squeeze()
				feat1[dist_min_index[1]] = feat11[dist_min_index[0]]
				dist[dist_min_index[0], dist_min_index[1]] = 1e5

			# Per-sample squared L2 distance in feature space, summed over the batch.
			loss1 = ((feat1-feat2)**2).sum(dim=1)
			loss = loss1.sum()

			losses.update(loss.item(), input1.size(0))

			loss.backward()

			# Gradient step on the perturbation, then clip it to [-eps1, eps1].
			pert = pert- lr1*pert.grad
			pert = torch.clamp(pert, -eps1, eps1).detach_()

			# From here until the subtraction at the end of the step, `pert` holds the
			# perturbed image (target + perturbation), clipped to the valid [0, 1] range.
			pert = pert + input2

			pert = pert.clamp(0, 1)

			if j%100 == 0:
				logging.info("Epoch: {:2d} | i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
							 .format(epoch, i, j, lr1, losses.val, losses.avg))

			# Stop once every sample's loss is below 10, or on the last allowed step;
			# at this point `pert` is the full poisoned image, which is what gets saved.
			if loss1.max().item() < 10 or j == (num_iter-1):
				for k in range(input2.size(0)):
					img_ctr = img_ctr+1
					input2_pert = (pert[k].clone().cpu())

					fname = saveDir_poison + '/' + 'loss_' + str(int(loss1[k].item())).zfill(5) + '_' + 'epoch_' + \
							str(epoch).zfill(2) + '_' + str(os.path.basename(path2[k])).split('.')[0] + '_' + \
							str(os.path.basename(path1[k])).split('.')[0] + '_kk_' + str(img_ctr).zfill(5)+'.png'

					save_image(input2_pert, fname)
					num_poisoned +=1

				break

			# Convert back from perturbed image to perturbation and re-enable gradients
			# for the next step.
			pert = pert - input2
			pert.requires_grad = True

	time_elapsed = time.time() - since
	logging.info('Training complete one epoch in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [None]:
# The FileHandler below needs the log directory to exist.
log_dir = os.path.dirname(logfile)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# Log to both the report file (truncated on each run, mode "w") and the console.
# NOTE(review): logging.basicConfig does nothing if the root logger already has
# handlers, which can happen when this cell is re-run in the same kernel.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    handlers=[logging.FileHandler(logfile, "w"), logging.StreamHandler()],
)

logging.info("Experiment ID: {}".format(experimentID))

if seed is not None:
    # Seed Python's and PyTorch's RNGs and ask cuDNN for deterministic kernels.
    random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True
    warnings.warn(
        'You have chosen to seed training. '
        'This will turn on the CUDNN deterministic setting, '
        'which can slow down your training considerably! '
        'You may see unexpected behavior when restarting '
        'from checkpoints.'
    )

if device is not None:
    logging.info("Use GPU: {} for training".format(device))

# create model
logging.info("=> using pre-trained model '{}'".format("alexnet"))

In [None]:
# Input normalisation with the ImageNet channel statistics, applied inside the
# model so that train() can work directly on [0, 1] images.
normalize = NormalizeByChannelMeanStd(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Pretrained AlexNet used purely as a fixed feature extractor.
model = alexnet(pretrained=True)
model.eval()
# Freeze the weights: train() only needs gradients with respect to the input
# perturbation. Without this, every loss.backward() also computes and
# accumulates gradients for all network parameters, which are never used.
model.requires_grad_(False)
model = nn.Sequential(normalize, model)

model = model.to(device)

for epoch in range(epochs):
    # run one epoch
    train(model, epoch)