In [None]:
import numpy as np
import os
import math
import matplotlib.pyplot as plt
from PIL import Image

import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.utils import img_to_array

import tensorflow_addons as tfa

In [None]:
# Bounds for the random scale factor applied to the patch in random_overlay.
# Both bounds are equal, so every placement currently uses a scale of 0.3.
SCALE_MIN, SCALE_MAX = 0.3, 0.3

# Maximum absolute patch rotation in degrees (random_overlay draws from
# [-MAX_ROTATION, MAX_ROTATION]).
MAX_ROTATION = 22.5
# The same limit in radians (pi / 8 == 22.5 degrees).
# NOTE(review): not referenced by any code visible in this notebook.
rotate_max = np.pi / 8

# Directory the notebook was started from; later cells chdir relative to it
# and write their output underneath it.
ROOT_DIR = os.getcwd()

In [None]:
def circle_mask(shape, sharpness = 40):
    """Return a soft-edged circular mask of the given shape.

    Args:
        shape: expected to be ``(diameter, diameter, channels)``. The first two
            entries must be equal. The 2-D mask is broadcast across the last axis.
        sharpness: exponent applied to the squared radius ``x**2 + y**2``.
            Larger values give a harder edge (default 40).

    Returns:
        ``np.float32`` array of ``shape``. Values are 1 at the centre and fall
        to 0 on and outside the circle inscribed in the square.

    Raises:
        AssertionError: if ``shape[0] != shape[1]``.
    """
    # Format with an f-string. The previous ``"..." + shape`` raised TypeError
    # for tuple/list shapes and hid the intended assertion message.
    assert shape[0] == shape[1], f"circle_mask received a bad shape: {shape}"

    diameter = shape[0]
    # Normalised coordinates in [-1, 1]; the x and y grids are identical.
    coords = np.linspace(-1, 1, diameter)
    xx, yy = np.meshgrid(coords, coords, sparse=True)
    # (r**2) ** sharpness is ~0 well inside the unit circle and >= 1 on or
    # outside it. After clipping, 1 - z is therefore ~1 inside and 0 outside.
    z = (xx**2 + yy**2) ** sharpness
    mask = 1 - np.clip(z, -1, 1)

    # Add a channel axis and replicate the 2-D mask to the requested shape.
    mask = np.expand_dims(mask, axis=2)
    return np.broadcast_to(mask, shape).astype(np.float32)

def show(im):
    """Draw ``im`` with axes hidden and display the figure.

    Uses nearest-neighbour interpolation, so pixels are not smoothed together.
    """
    draw_options = dict(interpolation="nearest")
    plt.axis('off')
    plt.imshow(im, **draw_options)
    plt.show()

def transform_vector(width, x_shift, y_shift, im_scale, rot_in_degrees):
    """Build the 8-element projective transform used to place the patch.

    The result is ``[a0, a1, a2, b0, b1, b2, c0, c1]``. In the convention used
    by ``tfa.image.transform`` (formerly ``tf.contrib.image.transform``), each
    OUTPUT pixel ``(x, y)`` is sampled from the INPUT location
    ``((a0*x + a1*y + a2) / k, (b0*x + b1*y + b2) / k)`` with
    ``k = c0*x + c1*y + 1``. Because the mapping runs output -> input, the
    linear part is built from the negated angle and the reciprocal scale.

    Args:
        width: side length of the square canvas in pixels. Its centre
            ``(width/2, width/2)`` is used as the pivot.
        x_shift, y_shift: requested translation. Each is divided by
            ``2 * im_scale`` and subtracted from the offset terms.
        im_scale: scale factor for the patch (must be non-zero).
        rot_in_degrees: rotation angle in degrees.

    Returns:
        ``np.ndarray`` of shape ``(8,)`` and dtype ``float32``. ``c0`` and
        ``c1`` are always 0, so the transform is affine.
    """
    angle = float(rot_in_degrees) / 90. * (math.pi / 2)  # degrees -> radians

    # The transform maps output -> input, so invert both the rotation
    # (use -angle) and the scale (use 1 / im_scale).
    theta = -angle
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    inv_scale = 1. / im_scale
    linear = np.array([[cos_t, -sin_t],
                       [sin_t, cos_t]]) * inv_scale

    # On its own, the linear part pivots around pixel (0, 0). Find where it
    # sends the canvas centre and translate back, so the pivot is the centre.
    half = float(width) / 2
    center = np.array([half, half])
    moved_x, moved_y = linear @ center

    # Fold the requested shift into the centre-restoring offset.
    offset_x = (half - moved_x) - (x_shift / (2 * im_scale))
    offset_y = (half - moved_y) - (y_shift / (2 * im_scale))

    (a0, a1), (b0, b1) = linear
    return np.array([a0, a1, offset_x, b0, b1, offset_y, 0, 0]).astype(np.float32)

def random_overlay(imgs, patch, image_shape, batch_size,
                   scale_min=None, scale_max=None, max_rotation=None):
    """Paste ``patch`` onto ``imgs`` with a random scale, rotation and shift.

    For each batch item a random transform is drawn:
    - scale in ``[scale_min, scale_max]``;
    - rotation in ``[-max_rotation, max_rotation]`` degrees;
    - shift drawn from ``±(1 - scale) * width`` on each axis.

    The same transform is applied to a circular mask (see ``circle_mask``) and
    to the patch. The two are then alpha-blended:
    ``imgs * (1 - mask) + patch * mask``.

    Args:
        imgs: an image or batch of images. Must broadcast against
            ``(batch_size, *image_shape)``; a single image of ``image_shape``
            is broadcast across the batch.
        patch: patch image with the same shape as ``image_shape``. It is NOT
            padded here and must already be canvas-sized.
        image_shape: ``(width, width, channels)`` of the canvas; must be square.
        batch_size: number of independent random placements to generate.
        scale_min, scale_max: bounds for the random scale. Default to the
            module-level ``SCALE_MIN`` / ``SCALE_MAX``.
        max_rotation: maximum absolute rotation in degrees. Defaults to the
            module-level ``MAX_ROTATION``.

    Returns:
        The blended tensor, with the broadcast shape of ``imgs`` and
        ``(batch_size, *image_shape)``.
    """
    # Resolve defaults at call time so later edits to the globals still apply.
    if scale_min is None:
        scale_min = SCALE_MIN
    if scale_max is None:
        scale_max = SCALE_MAX
    if max_rotation is None:
        max_rotation = MAX_ROTATION

    # Canvas-sized circular alpha mask and patch, one copy per batch item.
    image_mask = tf.stack([circle_mask(image_shape)] * batch_size)
    padded_patch = tf.stack([patch] * batch_size)

    def random_transformation(lo, hi, width):
        # Runs as NumPy code through tf.numpy_function. The draw order
        # (scale, x, y, rotation) is fixed, so the RNG stream is reproducible.
        im_scale = np.random.uniform(low=lo, high=hi)

        padding_after_scaling = (1 - im_scale) * width
        x_delta = np.random.uniform(-padding_after_scaling, padding_after_scaling)
        y_delta = np.random.uniform(-padding_after_scaling, padding_after_scaling)

        rot = np.random.uniform(-max_rotation, max_rotation)

        return transform_vector(width,
                                x_shift=x_delta,
                                y_shift=y_delta,
                                im_scale=im_scale,
                                rot_in_degrees=rot)

    transform_vecs = []
    for _ in range(batch_size):
        vec = tf.numpy_function(random_transformation,
                                [scale_min, scale_max, image_shape[0]],
                                tf.float32)
        # numpy_function drops static shape information; restore it so each
        # row is a known-length [8] transform.
        vec.set_shape([8])
        transform_vecs.append(vec)
    transforms = tf.stack(transform_vecs)  # [batch_size, 8]

    # Warp mask and patch with the SAME transforms so they stay aligned.
    image_mask = tfa.image.transform(image_mask, transforms, "BILINEAR")
    padded_patch = tfa.image.transform(padded_patch, transforms, "BILINEAR")

    # Alpha-blend: keep the image where mask == 0 and the patch where mask == 1.
    inverted_mask = 1 - image_mask
    return imgs * inverted_mask + padded_patch * image_mask

In [None]:
# Move to the parent of the notebook's start directory. Joining from ROOT_DIR
# makes re-running this cell idempotent regardless of the current cwd.
os.chdir(os.path.join(ROOT_DIR, os.pardir))
os.getcwd()

In [None]:
# SET YOUR CLEAN IMAGE DIRECTORY AND adv_dir (where the images will be stored).
# clean_dir and the patch path are resolved against the current working
# directory. adv_dir is created underneath ROOT_DIR.
clean_dir = 'resnet_40'
# Built with os.path.join: the former literal r'patch_new_data\adv_images' only
# meant two nested folders on Windows. Elsewhere it named a single folder
# containing a backslash.
adv_dir = os.path.join('patch_new_data', 'adv_images')
IMAGE_SIZE = (224, 224)
# Set to False to skip the per-image preview figures (much faster on big sets).
SHOW_PREVIEWS = True

# Load the patch at the dataset image size; random_overlay expects a
# canvas-sized patch. load_img's target_size is (height, width).
patch_img = load_img(os.path.join('patch', 'resnet50_patch.png'),
                     target_size=IMAGE_SIZE, interpolation='nearest')
patch_img = img_to_array(patch_img)

# Class name and file name of every processed image, in processing order.
class_name = []
file_name = []

# Resolve the input root once instead of chdir-ing into it. Then an exception
# mid-loop does not leave the kernel in an unexpected working directory.
clean_root = os.path.abspath(clean_dir)

# sorted(): os.listdir order is filesystem-dependent.
for folder in sorted(os.listdir(clean_root)):
    data_directory = os.path.join(clean_root, folder)
    # Only clean* directories are processed. Check before building a dataset,
    # so other folders or stray files are never loaded.
    if not (folder.startswith('clean') and os.path.isdir(data_directory)):
        continue

    images_gen = image_dataset_from_directory(
        data_directory,
        seed=42,
        image_size=IMAGE_SIZE,
        batch_size=None,  # The dataset will yield individual samples.
        color_mode='rgb',
        shuffle=False)    # keep sample order aligned with file_paths
    file_paths = images_gen.file_paths

    for idx, (img, label) in enumerate(images_gen):
        img = img_to_array(img)

        fld = images_gen.class_names[int(label)]
        class_name.append(fld)
        fn = os.path.basename(file_paths[idx])
        file_name.append(fn)

        trans_image = random_overlay(imgs=img, patch=patch_img,
                                     image_shape=[*IMAGE_SIZE, 3], batch_size=1)
        # The blend yields floats. Round and clip into [0, 255] before the uint8
        # cast; a bare astype truncates toward zero.
        trans_image = np.clip(np.rint(trans_image.numpy()), 0, 255).astype(np.uint8)

        if SHOW_PREVIEWS:
            fig = plt.figure(figsize=(4, 4))
            plt.axis('off')
            plt.imshow(trans_image[0], interpolation="nearest")
            plt.show()
            # Release the figure so memory does not grow with one open figure
            # per image.
            plt.close(fig)

        out_dir = os.path.join(ROOT_DIR, adv_dir, fld)
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): the original file name (and therefore format) is kept.
        # If the inputs are JPEGs, the patched image is re-compressed lossily.
        Image.fromarray(trans_image[0]).save(os.path.join(out_dir, fn))

os.chdir(ROOT_DIR)