In [23]:
# Mount Google Drive at /content/drive (Colab-only API) and make the project
# folder the working directory via the IPython %cd magic.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/FOCE


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/.shortcut-targets-by-id/1SDt1LQzEYHlIyxPcg_tq-EtGefu_-7i6/FOCE


In [24]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import cv2
import os
import numpy as np
import glob
from PIL import Image

In [27]:
def _affine_to_3x3(matrix):
    """Promote a 2x3 affine matrix to its 3x3 homogeneous form (float64)."""
    return np.vstack([np.asarray(matrix, dtype=np.float64), [0.0, 0.0, 1.0]])


def transform_image(image, sketch, angle_range, shear_range, translation_range, rng=None):
    """Apply one random rotation, translation and shear to a photo and its sketch.

    The same randomly drawn affine transform is applied to both arrays, so an
    augmented photo/sketch pair stays aligned.

    Args:
        image: photo array, either H x W (grayscale) or H x W x C (e.g. as
            returned by cv2.imread). Its height/width define the output size
            of BOTH returned arrays.
        sketch: the paired sketch array; it is warped onto an image-sized canvas.
        angle_range: total rotation range in degrees; the angle is drawn
            uniformly from [-angle_range/2, angle_range/2], about the image centre.
        shear_range: total range, in pixels, of the control-point perturbation
            that defines the shear; both deltas are drawn from
            [-shear_range/2, shear_range/2].
        translation_range: total shift range in pixels; x and y shifts are drawn
            independently from [-translation_range/2, translation_range/2].
        rng: optional random source exposing ``uniform(low, high)``, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global ``np.random``
            module, so ``np.random.seed`` keeps controlling the output.

    Returns:
        (image, sketch): the transformed arrays, both with the photo's height and
        width. Pixels not covered by the transform are filled with each array's
        own top-left pixel value.
    """
    if rng is None:
        rng = np.random

    # shape[:2] works for both grayscale (H, W) and colour (H, W, C) inputs.
    height, width = image.shape[:2]

    # Random parameters are drawn in the original order (angle, shifts, shear
    # deltas) so runs seeded through np.random reproduce the same draws.
    angle = rng.uniform(-angle_range / 2, angle_range / 2)
    shift_x = rng.uniform(-translation_range / 2, translation_range / 2)
    shift_y = rng.uniform(-translation_range / 2, translation_range / 2)
    delta1 = rng.uniform(-shear_range / 2, shear_range / 2)
    delta2 = rng.uniform(-shear_range / 2, shear_range / 2)

    # rotation about the image centre, scale 1
    rotation_matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)

    # translation
    translation_matrix = np.float32([[1, 0, shift_x], [0, 1, shift_y]])

    # shear: the affine map sending a small reference triangle to a perturbed one
    base_pts = np.float32([[5, 5], [20, 5], [5, 20]])
    new_pts = np.float32([[5 + delta1, 5], [20 + delta2, 5 + delta1], [5, 20 + delta2]])
    shear_matrix = cv2.getAffineTransform(base_pts, new_pts)

    # warpAffine treats its matrix as a forward src->dst map, so "rotate, then
    # translate, then shear" is the homogeneous product S @ T @ R. Warping once
    # with the product resamples each array a single time instead of three,
    # avoiding cumulative interpolation blur.
    combined = (
        _affine_to_3x3(shear_matrix)
        @ _affine_to_3x3(translation_matrix)
        @ _affine_to_3x3(rotation_matrix)
    )[:2]

    # Border fill = top-left pixel. atleast_1d handles the scalar pixel of a
    # grayscale array as well as a colour triple.
    border_color_img = tuple(int(v) for v in np.atleast_1d(image[0, 0]))
    border_color_skt = tuple(int(v) for v in np.atleast_1d(sketch[0, 0]))

    image = cv2.warpAffine(image, combined, (width, height), borderValue=border_color_img)
    sketch = cv2.warpAffine(sketch, combined, (width, height), borderValue=border_color_skt)

    return image, sketch


In [26]:
base_dir = '/content/drive/MyDrive/FOCE/CUHK/'
sketch_dir = os.path.join(base_dir, 'Augmented sketch')
photo_dir = os.path.join(base_dir, 'Augmented photo')

# Augmentation settings: copies generated per original pair, plus the ranges
# passed to transform_image (rotation in degrees, shear/translation in pixels).
N_AUGMENTATIONS = 10
ANGLE_RANGE = 40
SHEAR_RANGE = 10
TRANSLATION_RANGE = 10
SEED = 42
# transform_image draws from the global NumPy RNG by default; seed it so the
# augmented dataset is reproducible across runs.
np.random.seed(SEED)

# exist_ok replaces the separate existence check and also creates missing parents.
os.makedirs(sketch_dir, exist_ok=True)
os.makedirs(photo_dir, exist_ok=True)

# glob returns files in arbitrary, filesystem-dependent order, so sort both
# lists before pairing photo i with sketch i.
# NOTE(review): this assumes sorted photo and sketch filenames correspond
# one-to-one; verify against the dataset's naming scheme.
p_filenames = sorted(glob.glob(os.path.join(base_dir, 'photos', '*')))
s_filenames = sorted(glob.glob(os.path.join(base_dir, 'sketches', '*')))
if len(p_filenames) != len(s_filenames):
    raise ValueError(
        f'{len(p_filenames)} photos vs {len(s_filenames)} sketches; '
        'photo/sketch pairs would be misaligned'
    )

counter = 0
for p_path, s_path in zip(p_filenames, s_filenames):
    im = cv2.imread(p_path)
    sk = cv2.imread(s_path)
    # cv2.imread returns None (instead of raising) for unreadable/non-image files.
    if im is None or sk is None:
        raise ValueError(f'could not read image pair: {p_path!r}, {s_path!r}')

    for _ in range(N_AUGMENTATIONS):
        img, skt = transform_image(im, sk, ANGLE_RANGE, SHEAR_RANGE, TRANSLATION_RANGE)

        photo_out = os.path.join(photo_dir, f'{counter}.jpg')
        sketch_out = os.path.join(sketch_dir, f'{counter}.jpg')
        # cv2.imwrite signals failure by returning False rather than raising.
        wrote_photo = cv2.imwrite(photo_out, img)
        wrote_sketch = cv2.imwrite(sketch_out, skt)
        if not (wrote_photo and wrote_sketch):
            raise OSError(f'failed to write augmented pair: {photo_out!r}, {sketch_out!r}')

        counter += 1