diff --git a/bapp.py b/bapp.py
new file mode 100644
index 0000000..da226f8
--- /dev/null
+++ b/bapp.py
@@ -0,0 +1,369 @@
+from __future__ import absolute_import, division, print_function
+import numpy as np
+
+def bapp(model,
+    sample,
+    up_th = 1,
+    low_th = 0,
+    constraint = 'l2',
+    num_iters = 40,
+    gamma = 0.01,
+    target_label = None,
+    target_image = None,
+    epsilon_type = 'geometric_progression',
+    max_batch_size = 1e4,
+    init_batch_size = 100):
+    """
+    Main algorithm for Boundary Attack ++.
+
+    Inputs:
+    model: the object that has predict method.
+    predict outputs probability scores.
+
+    sample: the image to attack, an array of shape (h, w, c).
+
+    up_th: upper bound of the image.
+
+    low_th: lower bound of the image.
+
+    constraint: choose between [l2, linf].
+
+    num_iters: number of iterations.
+
+    gamma: used to set binary search threshold theta.
+
+    target_label: integer or None for nontargeted attack.
+
+    target_image: an array with the same size as sample, or None.
+
+    epsilon_type: choose between 'geometric_progression', 'grid_search'.
+
+    max_batch_size: maximum batch size for estimating gradient.
+
+    init_batch_size: initial batch size for estimating gradient.
+
+    Output:
+    perturbed image.
+
+    """
+    # Set parameters
+    h, w, c = sample.shape
+    original_label = np.argmax(model.predict(sample))
+    params = {'up_th': up_th, 'low_th': low_th, 'h': h, 'w': w, 'c': c,
+        'original_label': original_label,
+        'target_label': target_label,
+        'target_image': target_image,
+        'constraint': constraint,
+        'num_iters': num_iters,
+        'gamma': gamma,
+        'd': np.prod(sample.shape),
+        'epsilon_type': epsilon_type,
+        'max_batch_size': max_batch_size,
+        'init_batch_size': init_batch_size,
+        }
+
+    # Set binary search threshold.
+    if params['constraint'] == 'l2':
+        params['theta'] = params['gamma'] / np.sqrt(params['d'])
+    else:
+        params['theta'] = params['gamma'] / (params['d'])
+
+    # Initialize.
+    perturbed = initialize(model, sample, params)
+
+    # Project the initialization to the boundary.
+    perturbed, dist_post_update = line_search_batch(sample,
+        np.expand_dims(perturbed, 0),
+        model,
+        params)
+    # Distance of the projected point; it scales the first step-size search.
+    dist = compute_distance(perturbed, sample, constraint)
+
+    for j in np.arange(params['num_iters']):
+        params['cur_iter'] = j + 1
+
+        # Choose delta.
+        delta = select_delta(params, dist_post_update)
+
+        # Choose batch size.
+        batch_size = int(params['init_batch_size'] * np.sqrt(j+1))
+        batch_size = int(min([batch_size, params['max_batch_size']]))
+
+        # approximate gradient.
+        gradf = approximate_gradient(model, perturbed, batch_size,
+            delta, params)
+        if params['constraint'] == 'linf':
+            update = np.sign(gradf)
+        else:
+            update = gradf
+
+        # search step size.
+        if params['epsilon_type'] == 'geometric_progression':
+            # find step size.
+            epsilon = geometric_progression_for_stepsize(perturbed,
+                update, dist, model, params)
+
+            # Update the sample along the direction the step size was found for.
+            perturbed = clip_image(perturbed + epsilon * update,
+                low_th, up_th)
+
+            # Binary search to return to the boundary.
+            perturbed, dist_post_update = line_search_batch(sample,
+                perturbed[None], model, params)
+
+        elif params['epsilon_type'] == 'grid_search':
+            # Grid search for stepsize.
+            epsilons = np.logspace(-4, 0, num=20, endpoint = True) * dist
+            perturbeds = perturbed + epsilons.reshape(-1,1,1,1) * update
+            perturbeds = clip_image(perturbeds,
+                params['low_th'], params['up_th'])
+            labels = np.argmax(model.predict(perturbeds), axis = 1)
+
+            if params['target_label'] is None:
+                idx_perturbed = labels != original_label
+            else:
+                idx_perturbed = labels == params['target_label']
+
+            # Do not perturb if all perturbations lie on the other side
+            # of the boundary. Otherwise select the perturbation that
+            # yields the minimum distance after binary search.
+            if np.sum(idx_perturbed) > 0:
+                perturbed, dist_post_update = line_search_batch(sample,
+                    perturbeds[idx_perturbed], model, params)
+
+        # compute new distance.
+        dist = compute_distance(perturbed, sample, constraint)
+        print('iteration: {:d}, {:s} distance {:.4E}'.format(j+1, constraint, dist))
+
+    return perturbed
+
+def clip_image(image, low_th, up_th):
+    """ Clip an image, or an image batch, with upper and lower threshold. """
+    return np.minimum(np.maximum(low_th, image), up_th)
+
+
+def compute_distance(x_ori, x_pert, constraint = 'l2'):
+    """ Compute the l2 or linf distance between two images. """
+    if constraint == 'l2':
+        return np.linalg.norm(x_ori - x_pert)
+    elif constraint == 'linf':
+        return np.max(abs(x_ori - x_pert))
+    else:
+        raise ValueError('Unknown constraint: {}'.format(constraint))
+
+
+def approximate_gradient(model, sample, num_samples, delta, params):
+    """ Estimate a unit-norm gradient direction at sample from num_samples random queries. """
+    up_th, low_th = params['up_th'], params['low_th']
+    h,w,c = params['h'], params['w'], params['c']
+
+    # Generate random vectors.
+    if params['constraint'] == 'l2':
+        rv = np.random.randn(num_samples, h, w, c)
+    elif params['constraint'] == 'linf':
+        rv = np.random.uniform(low = -1, high = 1,
+            size = (num_samples, h, w, c))
+
+    rv = rv / np.sqrt(np.sum(rv ** 2, axis = (1,2,3), keepdims = True))
+    perturbed = sample + delta * rv
+    perturbed = clip_image(perturbed, low_th, up_th)
+    rv = (perturbed - sample) / delta
+
+    # query the model.
+    prob = model.predict(perturbed)
+    if params['target_label'] is None:
+        fval = np.argmax(prob, axis = 1) != params['original_label']
+        # 1 if label changes.
+
+    else:
+        fval = np.argmax(prob, axis = 1) == params['target_label']
+
+    fval = 2 * fval.astype(float).reshape(-1,1,1,1) - 1.0
+
+    # Baseline subtraction (when fval differs)
+    if np.mean(fval) == 1.0: # label changes.
+        gradf = np.mean(rv, axis = 0)
+    elif np.mean(fval) == -1.0: # label not change.
+        gradf = - np.mean(rv, axis = 0)
+    else:
+        fval -= np.mean(fval)
+        gradf = np.mean(fval * rv, axis = 0)
+
+    # Get the gradient direction.
+    gradf = gradf / np.linalg.norm(gradf)
+
+    return gradf
+
+
+def project(original_image, perturbed_images, alphas, constraint):
+    """ Blend towards original_image (l2) or clip to within alphas of it (linf). """
+    if constraint == 'l2':
+        alphas = alphas.reshape(-1, 1, 1, 1)
+        return (1-alphas) * original_image + alphas * perturbed_images
+    elif constraint == 'linf':
+        out_images = clip_image(
+            perturbed_images,
+            original_image - alphas.reshape(-1,1,1,1),
+            original_image + alphas.reshape(-1,1,1,1)
+            )
+        return out_images
+
+
+def _line_search_batch(highs, lows, original_image, perturbed_images, model,
+    thresholds, params):
+    """ Recursive helper for binary search to approach the boundary. """
+
+    # Return when threshold is achieved.
+    if np.max((highs - lows) / thresholds) < 1:
+        out_image = project(
+            original_image,
+            perturbed_images,
+            highs,
+            params['constraint']
+            )
+        return out_image
+
+    # projection to mids.
+    mids = (highs + lows) / 2.0
+    mid_images = project(
+        original_image,
+        perturbed_images,
+        mids,
+        params['constraint']
+        )
+
+    # Update highs and lows based on model decisions.
+    mid_labels = np.argmax(model.predict(mid_images), axis = 1)
+    if params['target_label'] is None:
+        lows = np.where(params['original_label'] == mid_labels, mids, lows)
+        highs = np.where(params['original_label'] != mid_labels, mids, highs)
+    else:
+        lows = np.where(params['target_label'] != mid_labels, mids, lows)
+        highs = np.where(params['target_label'] == mid_labels, mids, highs)
+
+    return _line_search_batch(highs, lows, original_image, perturbed_images,
+        model, thresholds, params)
+
+
+def line_search_batch(original_image, perturbed_images, model, params):
+    """ Binary search to approach the boundary. """
+
+    # Compute distance between each of perturbed image and original image.
+    dists_post_update = np.array([
+        compute_distance(
+            original_image,
+            perturbed_image,
+            params['constraint']
+            )
+        for perturbed_image in perturbed_images])
+
+    # Choose upper thresholds in binary searches based on constraint.
+    if params['constraint'] == 'linf':
+        highs = dists_post_update
+        # Stopping criteria.
+        thresholds = np.minimum(dists_post_update * params['theta'], params['theta'])
+    else:
+        highs = np.ones(len(perturbed_images))
+        thresholds = params['theta']
+
+    lows = np.zeros(len(perturbed_images))
+
+
+
+    # Call recursive function.
+    out_images = _line_search_batch(highs, lows, original_image,
+        perturbed_images, model, thresholds, params)
+
+    # Compute distance of the output image to select the best choice.
+    # (only used when epsilon_type is grid_search.)
+    dists = np.array([
+        compute_distance(
+            original_image,
+            out_image,
+            params['constraint']
+            )
+        for out_image in out_images])
+    idx = np.argmin(dists)
+
+    dist = dists_post_update[idx]
+    out_image = out_images[idx]
+    return out_image, dist
+
+
+def initialize(model, sample, params):
+    """
+    Implementation of BlendedUniformNoiseAttack in Foolbox.
+    """
+    success = 0
+    num_evals = 0
+
+    if params['target_image'] is None:
+        # increasing scale if initialization fails.
+        num = 1000
+        epsilons = np.linspace(0, 1, num=num + 1)[1:]
+        while success == 0:
+            if num_evals < num:
+                epsilon = epsilons[num_evals]
+            else:
+                epsilon = epsilons[-1]
+
+            random_noise = np.random.uniform(-1, 1,
+                size = (params['h'], params['w'], params['c']))
+
+            initialization = clip_image(
+                (1- epsilon) * sample + epsilon * random_noise,
+                params['low_th'],
+                params['up_th']
+                )
+
+            prob = model.predict(initialization)
+            success = np.argmax(prob) != params['original_label']
+            # 1 if label changes.
+            num_evals += 1
+
+    else:
+        initialization = params['target_image']
+
+    return initialization
+
+
+def geometric_progression_for_stepsize(x, update, dist, model, params):
+    """
+    Geometric progression to search for stepsize.
+    Keep decreasing stepsize by half until reaching
+    the desired side of the boundary.
+    """
+    epsilon = dist / np.sqrt(params['cur_iter'])
+    def phi(epsilon):
+        new = x + epsilon * update
+        new = clip_image(new, params['low_th'], params['up_th'])
+
+        label = np.argmax(model.predict(new))
+
+        if params['target_label'] is None:
+            success = label != params['original_label']
+        else:
+            success = label == params['target_label']
+
+        return success
+
+    while not phi(epsilon):
+        epsilon /= 2.0
+
+    return epsilon
+
+def select_delta(params, dist_post_update):
+    """
+    Choose the delta at the scale of distance
+    between x and perturbed sample.
+
+    """
+    if params['cur_iter'] == 1:
+        delta = 0.1 * (params['up_th'] - params['low_th'])
+    else:
+        if params['constraint'] == 'l2':
+            delta = np.sqrt(params['d']) * params['theta'] * dist_post_update
+        elif params['constraint'] == 'linf':
+            delta = params['d'] * params['theta'] * dist_post_update
+
+    return delta
diff --git a/build_model.py b/build_model.py
new file mode 100644
index 0000000..df6c2b6
--- /dev/null
+++ b/build_model.py
@@ -0,0 +1,105 @@
+from __future__ import absolute_import, division, print_function
+import tensorflow as tf
+import numpy as np
+import os
+from keras.layers import Flatten, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, Convolution2D, BatchNormalization, Dense, Dropout, Activation, Embedding, Conv1D, Input, GlobalMaxPooling1D, Multiply, Lambda, Permute, GlobalAveragePooling2D
+from keras.preprocessing import sequence
+from keras.datasets import imdb, mnist
+from keras.callbacks import ModelCheckpoint
+from keras.models import Model, Sequential
+from keras.objectives import binary_crossentropy
+from keras.metrics import binary_accuracy as accuracy
+from keras.optimizers import RMSprop
+from keras import backend as K
+from keras import optimizers
+import math
+
+def construct_original_network(dataset_name, model_name, train):
+    """ Build the 32x32x3, 10-class ResNet v2 (depth 20) with its gradient tensors and session. """
+    data_model = dataset_name + model_name
+
+    # Define the model
+    input_size = 32
+    num_classes = 10
+    channel = 3
+
+    assert model_name == 'resnet'
+    from resnet import resnet_v2, lr_schedule, lr_schedule_sgd
+
+    model, image_ph, preds = resnet_v2(input_shape=(input_size, input_size,
+        channel), depth=20, num_classes = num_classes)
+
+    optimizer = optimizers.SGD(lr=0.1, momentum=0.9, nesterov=True)
+
+
+    model.compile(loss='categorical_crossentropy',
+        optimizer=optimizer,
+        metrics=['accuracy'])
+
+    grads = []
+    for c in range(num_classes):
+        grads.append(tf.gradients(preds[:,c], image_ph))
+
+    grads = tf.concat(grads, axis = 0)
+    approxs = grads * tf.expand_dims(image_ph, 0)
+
+    logits = [layer.output for layer in model.layers][-2]
+    print(logits)
+
+    sess = K.get_session()
+
+    return image_ph, preds, grads, approxs, sess, model, num_classes, logits
+
+class ImageModel():
+    """ Wrapper around the Keras network exposing a predict method. """
+    def __init__(self, model_name, dataset_name, train = False, load = False, **kwargs):
+        self.model_name = model_name
+        self.dataset_name = dataset_name
+        self.data_model = dataset_name + model_name
+        self.framework = 'keras'
+
+        print('Constructing network...')
+        self.input_ph, self.preds, self.grads, self.approxs, self.sess, self.model, self.num_classes, self.logits = construct_original_network(self.dataset_name, self.model_name, train = train)
+
+
+        self.layers = self.model.layers
+        self.last_hidden_layer = self.model.layers[-3]
+
+        self.y_ph = tf.placeholder(tf.float32, shape = [None, self.num_classes])
+        if load:
+            if load == True:
+                print('Loading model weights...')
+                self.model.load_weights('{}/models/original.hdf5'.format(
+                    self.data_model), by_name=True)
+            elif load != False:
+                self.model.load_weights('{}/models/{}.hdf5'.format(
+                    self.data_model, load), by_name=True)
+
+    def predict(self, x, verbose=0, batch_size = 500, logits = False):
+        """ Return model outputs (logits if logits is set) for one image or a batch of images. """
+        x = np.array(x)
+        if len(x.shape) == 3:
+            _x = np.expand_dims(x, 0)
+        else:
+            _x = x
+
+        if not logits:
+            prob = self.model.predict(_x, batch_size = batch_size,
+                verbose = verbose)
+        else:
+            num_iters = int(math.ceil(len(_x) * 1.0 / batch_size))
+            probs = []
+            for i in range(num_iters):
+                x_batch = _x[i * batch_size: (i+1) * batch_size]
+
+                prob = self.sess.run(self.logits,
+                    feed_dict = {self.input_ph: x_batch})
+
+                probs.append(prob)
+
+            prob = np.concatenate(probs, axis = 0)
+
+        if len(x.shape) == 3:
+            prob = prob.reshape(-1)
+
+        return prob
diff --git a/cifar10resnet/models/original.hdf5 b/cifar10resnet/models/original.hdf5
new file mode 100644
index 0000000..c4cbbe3
Binary files /dev/null and b/cifar10resnet/models/original.hdf5 differ
diff --git a/load_data.py b/load_data.py
new file mode 100644
index 0000000..2080e7a
--- /dev/null
+++ b/load_data.py
@@ -0,0 +1,131 @@
+from __future__ import absolute_import, division, print_function
+# from model import *
+import numpy as np
+import tensorflow as tf
+import os
+import time
+import numpy as np
+import sys
+import os
+import tarfile
+import zipfile
+
+import keras
+import math
+from keras.utils import to_categorical
+
+class ImageData():
+    """ Load a Keras dataset, divide pixel values by 255 and one-hot encode the labels. """
+    def __init__(self, dataset_name):
+        if dataset_name == 'mnist':
+            from keras.datasets import mnist
+            (x_train, y_train), (x_val, y_val) = mnist.load_data()
+            x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
+
+            x_val = x_val.reshape(x_val.shape[0], 28, 28, 1)
+
+        elif dataset_name == 'cifar100':
+            from keras.datasets import cifar100
+            (x_train, y_train), (x_val, y_val) = cifar100.load_data()
+
+
+        elif dataset_name == 'cifar10':
+            from keras.datasets import cifar10
+            # Load CIFAR10 Dataset
+            (x_train, y_train), (x_val, y_val) = cifar10.load_data()
+        else:
+            raise ValueError('Unknown dataset: {}'.format(dataset_name))
+
+        x_train = x_train.astype('float32')/255
+        x_val = x_val.astype('float32')/255
+
+        y_train = to_categorical(y_train)
+        y_val = to_categorical(y_val)
+
+        x_train_mean = np.zeros(x_train.shape[1:])
+        x_train -= x_train_mean
+        x_val -= x_train_mean
+        self.clip_min = 0.0
+        self.clip_max = 1.0
+
+        self.x_train = x_train
+        self.x_val = x_val
+        self.y_train = y_train
+        self.y_val = y_val
+        self.x_train_mean = x_train_mean
+
+
+
+
+def split_data(x, y, model, num_classes = 10, split_rate = 0.8, sample_per_class = 100):
+    """
+    Keep the samples model classifies correctly, take up to sample_per_class
+    per class and split each class by split_rate into two shuffled parts.
+    """
+
+    np.random.seed(10086)
+    pred = model.predict(x)
+    label_pred = np.argmax(pred, axis = 1)
+    label_truth = np.argmax(y, axis = 1)
+    correct_idx = label_pred==label_truth
+    print('Accuracy is {}'.format(np.mean(correct_idx)))
+    x, y = x[correct_idx], y[correct_idx]
+    label_pred = label_pred[correct_idx]
+
+    x_train, x_test, y_train, y_test = [], [], [], []
+    for class_id in range(num_classes):
+        _x = x[label_pred == class_id][:sample_per_class]
+        _y = y[label_pred == class_id][:sample_per_class]
+        l = len(_x)
+        x_train.append(_x[:int(l * split_rate)])
+        x_test.append(_x[int(l * split_rate):])
+
+        y_train.append(_y[:int(l * split_rate)])
+        y_test.append(_y[int(l * split_rate):])
+
+
+
+    x_train = np.concatenate(x_train, axis = 0)
+    x_test = np.concatenate(x_test, axis = 0)
+    y_train = np.concatenate(y_train, axis = 0)
+    y_test = np.concatenate(y_test, axis = 0)
+
+    idx_train = np.random.permutation(len(x_train))
+    idx_test = np.random.permutation(len(x_test))
+
+    x_train = x_train[idx_train]
+    y_train = y_train[idx_train]
+
+    x_test = x_test[idx_test]
+    y_test = y_test[idx_test]
+
+    return x_train, y_train, x_test, y_test
+
+
+
+
+if __name__ == '__main__':
+    import argparse
+    from build_model import ImageModel
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--dataset_name', type = str,
+        choices = ['mnist', 'cifar10', 'cifar100'],
+        default = 'mnist')
+
+    parser.add_argument('--model_name', type = str,
+        choices = ['cnn', 'resnet', 'densenet'],
+        default = 'cnn')
+
+    args = parser.parse_args()
+    dict_a = vars(args)
+
+    data_model = args.dataset_name + args.model_name
+
+    dataset = ImageData(args.dataset_name)
+
+    model = ImageModel(args.model_name, args.dataset_name, train = False, load = True)
+
+    x, y = dataset.x_val, dataset.y_val
+
+    x_train, y_train, x_test, y_test = split_data(x, y, model, num_classes = 10, split_rate = 0.8)
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..b39e9a4
--- /dev/null
+++ b/main.py
@@ -0,0 +1,139 @@
+from __future__ import absolute_import, division, print_function
+
+
+from build_model import ImageModel
+from load_data import ImageData, split_data
+from bapp import bapp
+import numpy as np
+import tensorflow as tf
+import sys
+import os
+import pickle
+import argparse
+import scipy.misc
+import itertools
+
+def construct_model_and_data(args):
+    """
+    Load model and data on which the attack is carried out.
+    Assign target classes and images for targeted attack.
+    """
+    data_model = args.dataset_name + args.model_name
+    dataset = ImageData(args.dataset_name)
+    x_test, y_test = dataset.x_val, dataset.y_val
+    reference = - dataset.x_train_mean
+    model = ImageModel(args.model_name, args.dataset_name,
+        train = False, load = True)
+
+    # Split the test dataset into two parts.
+    # Use the first part for setting target image for targeted attack.
+    x_train, y_train, x_test, y_test = split_data(x_test, y_test, model,
+        num_classes = model.num_classes, split_rate = 0.5,
+        sample_per_class = np.min([np.max([200, args.num_samples // 10 * 3]),
+        1000]))
+
+    outputs = {'data_model': data_model,
+        'x_test': x_test,
+        'y_test': y_test,
+        'model': model,
+        'up_th': 1.0,
+        'low_th': 0.0
+        }
+
+    if args.attack_type == 'targeted':
+        # Assign target class and image for targeted attack.
+        label_train = np.argmax(y_train, axis = 1)
+        label_test = np.argmax(y_test, axis = 1)
+        x_train_by_class = [x_train[label_train == i] for i in range(model.num_classes)]
+        target_img_by_class = np.array([x_train_by_class[i][0] for i in range(model.num_classes)])
+        np.random.seed(0)
+        target_labels = [np.random.choice([j for j in range(model.num_classes) if j != label]) for label in label_test]
+        target_img_ids = [np.random.choice(len(x_train_by_class[target_label])) for target_label in target_labels]
+        target_images = [x_train_by_class[target_labels[j]][target_img_id] for j, target_img_id in enumerate(target_img_ids)]
+        outputs['target_labels'] = target_labels
+        outputs['target_images'] = target_images
+
+    return outputs
+
+
+def attack(args):
+    """ Attack the first args.num_samples test images and save original/perturbed pairs as figures. """
+    outputs = construct_model_and_data(args)
+    data_model = outputs['data_model']
+    x_test = outputs['x_test']
+    y_test = outputs['y_test']
+    model = outputs['model']
+    up_th = outputs['up_th']
+    low_th = outputs['low_th']
+    if args.attack_type == 'targeted':
+        target_labels = outputs['target_labels']
+        target_images = outputs['target_images']
+
+    for i, sample in enumerate(x_test[:args.num_samples]):
+        label = np.argmax(y_test[i])
+
+        if args.attack_type == 'targeted':
+            target_label = target_labels[i]
+            target_image = target_images[i]
+        else:
+            target_label = None
+            target_image = None
+
+        print('attacking the {}th sample...'.format(i))
+
+        perturbed = bapp(model,
+            sample,
+            up_th = up_th,
+            low_th = low_th,
+            constraint = args.constraint,
+            num_iters = args.num_iters,
+            gamma = 0.01,
+            target_label = target_label,
+            target_image = target_image,
+            epsilon_type = args.epsilon_type,
+            max_batch_size = 1e4,
+            init_batch_size = 100)
+
+        image = np.concatenate([sample, np.zeros((32,8,3)), perturbed], axis = 1)
+        scipy.misc.imsave('{}/figs/{}-{}-{}.jpg'.format(data_model,
+            args.attack_type, args.constraint, i), image)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--dataset_name', type = str,
+        choices = ['cifar10'],
+        default = 'cifar10')
+
+    parser.add_argument('--model_name', type = str,
+        choices = ['resnet'],
+        default = 'resnet')
+
+    parser.add_argument('--constraint', type = str,
+        choices = ['l2', 'linf'],
+        default = 'l2')
+
+    parser.add_argument('--attack_type', type = str,
+        choices = ['targeted', 'untargeted'],
+        default = 'untargeted')
+
+    parser.add_argument('--num_samples', type = int,
+        default = 10)
+
+    parser.add_argument('--num_iters', type = int,
+        default = 64)
+    parser.add_argument('--epsilon_type', type = str,
+        choices = ['geometric_progression', 'grid_search'],
+        default = 'geometric_progression')
+
+    args = parser.parse_args()
+    dict_a = vars(args)
+
+    data_model = args.dataset_name + args.model_name
+    if not os.path.exists(data_model):
+        os.mkdir(data_model)
+    if not os.path.exists('{}/figs'.format(data_model)):
+        os.mkdir('{}/figs'.format(data_model))
+
+    attack(args)
diff --git a/resnet.py b/resnet.py
new file mode 100644
index 0000000..c3f88d1
--- /dev/null
+++ b/resnet.py
@@ -0,0 +1,282 @@
+"""Trains a ResNet on the CIFAR10 dataset.
+
+ResNet v1
+[a] Deep Residual Learning for Image Recognition
+https://arxiv.org/pdf/1512.03385.pdf
+
+ResNet v2
+[b] Identity Mappings in Deep Residual Networks
+https://arxiv.org/pdf/1603.05027.pdf
+"""
+
+from __future__ import print_function
+import keras
+from keras.layers import Dense, Conv2D, BatchNormalization, Activation
+from keras.layers import AveragePooling2D, Input, Flatten
+from keras.optimizers import Adam
+from keras.callbacks import ModelCheckpoint, LearningRateScheduler
+from keras.callbacks import ReduceLROnPlateau
+from keras.preprocessing.image import ImageDataGenerator
+from keras.regularizers import l2
+from keras import backend as K
+from keras.models import Model
+from keras.datasets import cifar10
+import numpy as np
+import os
+
+
+
+def lr_schedule(epoch):
+    """Learning Rate Schedule
+
+    Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.
+    Called automatically every epoch as part of callbacks during training.
+
+    # Arguments
+        epoch (int): The number of epochs
+
+    # Returns
+        lr (float32): learning rate
+    """
+    lr = 1e-3
+    if epoch > 180:
+        lr *= 0.5e-3
+    elif epoch > 160:
+        lr *= 1e-3
+    elif epoch > 120:
+        lr *= 1e-2
+    elif epoch > 80:
+        lr *= 1e-1
+    print('Learning rate: ', lr)
+    return lr
+
+
+def lr_schedule_cifar100(epoch):
+    """Learning Rate Schedule
+
+    Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.
+    Called automatically every epoch as part of callbacks during training.
+
+    # Arguments
+        epoch (int): The number of epochs
+
+    # Returns
+        lr (float32): learning rate
+    """
+    lr = 1e-4
+    if epoch > 180:
+        lr *= 0.5e-3
+    elif epoch > 160:
+        lr *= 1e-3
+    elif epoch > 120:
+        lr *= 1e-2
+    elif epoch > 80:
+        lr *= 1e-1
+    print('Learning rate: ', lr)
+    return lr
+
+def lr_schedule_sgd(epoch):
+    """Return 0.1, reduced tenfold from epoch 81 and again from epoch 122."""
+    decay = 2 if epoch >= 122 else (1 if epoch >= 81 else 0)
+    lr = 1e-1 * 0.1 ** decay
+    print('Learning rate: ', lr)
+    return lr
+
+def resnet_layer(inputs,
+    num_filters=16,
+    kernel_size=3,
+    strides=1,
+    activation='relu',
+    batch_normalization=True,
+    conv_first=True):
+    """2D Convolution-Batch Normalization-Activation stack builder
+
+    # Arguments
+        inputs (tensor): input tensor from input image or previous layer
+        num_filters (int): Conv2D number of filters
+        kernel_size (int): Conv2D square kernel dimensions
+        strides (int): Conv2D square stride dimensions
+        activation (string): activation name
+        batch_normalization (bool): whether to include batch normalization
+        conv_first (bool): conv-bn-activation (True) or
+            bn-activation-conv (False)
+
+    # Returns
+        x (tensor): tensor as input to the next layer
+    """
+    conv = Conv2D(num_filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding='same',
+        kernel_initializer='he_normal',
+        kernel_regularizer=l2(1e-4))
+
+    x = inputs
+    if conv_first:
+        x = conv(x)
+        if batch_normalization:
+            x = BatchNormalization()(x)
+        if activation is not None:
+            x = Activation(activation)(x)
+    else:
+        if batch_normalization:
+            x = BatchNormalization()(x)
+        if activation is not None:
+            x = Activation(activation)(x)
+        x = conv(x)
+    return x
+
+
+def resnet_v2(input_shape, depth, num_classes=10):
+    """ResNet Version 2 Model builder [b]
+
+    Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as
+    bottleneck layer
+    First shortcut connection per layer is 1 x 1 Conv2D.
+    Second and onwards shortcut connection is identity.
+    At the beginning of each stage, the feature map size is halved (downsampled)
+    by a convolutional layer with strides=2, while the number of filter maps is
+    doubled. Within each stage, the layers have the same number filters and the
+    same filter map sizes.
+    Features maps sizes:
+    conv1 : 32x32, 16
+    stage 0: 32x32, 64
+    stage 1: 16x16, 128
+    stage 2: 8x8, 256
+
+    # Arguments
+        input_shape (tensor): shape of input image tensor
+        depth (int): number of core convolutional layers
+        num_classes (int): number of classes (CIFAR10 has 10)
+
+    # Returns
+        model (Model): Keras model instance
+        inputs (tensor): input tensor of the model
+        outputs (tensor): softmax output tensor of the model
+    """
+    if (depth - 2) % 9 != 0:
+        raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')
+    # Start model definition.
+    num_filters_in = 16
+    num_res_blocks = (depth - 2) // 9
+
+    inputs = Input(shape=input_shape)
+    # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths
+    x = resnet_layer(inputs=inputs,
+        num_filters=num_filters_in,
+        conv_first=True)
+
+    # Instantiate the stack of residual units
+    for stage in range(3):
+        for res_block in range(num_res_blocks):
+            activation = 'relu'
+            batch_normalization = True
+            strides = 1
+            if stage == 0:
+                num_filters_out = num_filters_in * 4
+                if res_block == 0: # first layer and first stage
+                    activation = None
+                    batch_normalization = False
+            else:
+                num_filters_out = num_filters_in * 2
+                if res_block == 0: # first layer but not first stage
+                    strides = 2 # downsample
+
+            # bottleneck residual unit
+            y = resnet_layer(inputs=x,
+                num_filters=num_filters_in,
+                kernel_size=1,
+                strides=strides,
+                activation=activation,
+                batch_normalization=batch_normalization,
+                conv_first=False)
+
+            y = resnet_layer(inputs=y,
+                num_filters=num_filters_in,
+                conv_first=False)
+            y = resnet_layer(inputs=y,
+                num_filters=num_filters_out,
+                kernel_size=1,
+                conv_first=False)
+            if res_block == 0:
+                # linear projection residual shortcut connection to match
+                # changed dims
+                x = resnet_layer(inputs=x,
+                    num_filters=num_filters_out,
+                    kernel_size=1,
+                    strides=strides,
+                    activation=None,
+                    batch_normalization=False)
+
+            x = keras.layers.add([x, y])
+
+        num_filters_in = num_filters_out
+
+    # Add classifier on top.
+    # v2 has BN-ReLU before Pooling
+    x = BatchNormalization()(x)
+    x = Activation('relu')(x)
+    pool_size = int(x.get_shape()[1])
+    x = AveragePooling2D(pool_size=pool_size)(x)
+    y = Flatten()(x)
+    outputs = Dense(num_classes,
+        activation=None,
+        kernel_initializer='he_normal')(y)
+
+    outputs = Activation('softmax')(outputs)
+
+    # Instantiate model.
+    model = Model(inputs=inputs, outputs=outputs)
+    return model, inputs, outputs
+
+
+
+def create_resnet_generator(x_train):
+    """Build an ImageDataGenerator with shifts and horizontal flips, fitted on x_train."""
+    # This will do preprocessing and realtime data augmentation:
+    datagen = ImageDataGenerator(
+        # set input mean to 0 over the dataset
+        featurewise_center=False,
+        # set each sample mean to 0
+        samplewise_center=False,
+        # divide inputs by std of dataset
+        featurewise_std_normalization=False,
+        # divide each input by its std
+        samplewise_std_normalization=False,
+        # apply ZCA whitening
+        zca_whitening=False,
+        # epsilon for ZCA whitening
+        zca_epsilon=1e-06,
+        # randomly rotate images in the range (deg 0 to 180)
+        rotation_range=0,
+        # randomly shift images horizontally
+        width_shift_range=0.1,
+        # randomly shift images vertically
+        height_shift_range=0.1,
+        # set range for random shear
+        shear_range=0.,
+        # set range for random zoom
+        zoom_range=0.,
+        # set range for random channel shifts
+        channel_shift_range=0.,
+        # set mode for filling points outside the input boundaries
+        fill_mode='nearest',
+        # value used for fill_mode = "constant"
+        cval=0.,
+        # randomly flip images
+        horizontal_flip=True,
+        # randomly flip images
+        vertical_flip=False,
+        # set rescaling factor (applied before any other transformation)
+        rescale=None,
+        # set function that will be applied on each input
+        preprocessing_function=None,
+        # image data format, either "channels_first" or "channels_last"
+        data_format=None,
+        # fraction of images reserved for validation (strictly between 0 and 1)
+        validation_split=0.0)
+
+    # Compute quantities required for featurewise normalization
+    # (std, mean, and principal components if ZCA whitening is applied).
+    datagen.fit(x_train)
+    return datagen