In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, time, random
import random
import cv2

In [3]:
# Load the three dataset splits (train / val / test) from .npy files.
# The paths are relative, so they resolve against the current working directory.
train_data = np.load(r"train_data.npy")
val_data = np.load(r"val_data.npy")
test_data = np.load(r"test_data.npy")

In [4]:
# Display the dimensions of the training array as a sanity check after loading.
train_data.shape

(100000, 64, 64, 3)

### Create patches and preprocess them

Our images are of size 64x64, so we have to choose an appropriate patch size. By choosing 12x12 patches, we keep roughly the same number of patches within a single image as in the original paper (around 32). We leave a gap of about half the patch size between neighbouring patches, and we jitter the location of each patch by between -2 and 2 pixels in each direction.

In [5]:
def sample_patch(x, y, grid_start_x, grid_start_y, patch_size, gap_size, jitter_size, image_size):
    """Pick the upper-left pixel of the patch that sits in grid cell (x, y).

    The nominal position is the grid origin plus a whole number of
    (patch + gap) strides; a random integer jitter in
    [-jitter_size, jitter_size] is then added on each axis, and the result
    is clamped to [0, image extent - patch extent].

    Args:
        x, y: column and row of the cell inside the grid.
        grid_start_x, grid_start_y: pixel coordinates of the grid origin.
        patch_size: patch dimensions; index 0 is used with the y axis and
            index 1 with the x axis.
        gap_size: spacing in pixels between neighbouring patches.
        jitter_size: largest absolute random offset applied per axis.
        image_size: image dimensions, indexed like ``patch_size``
            (extra trailing entries are ignored).

    Returns:
        Tuple ``(x_pixel, y_pixel)`` of the patch's upper-left corner.
    """
    #distance between the origins of two neighbouring cells, per axis
    stride_x = patch_size[1] + gap_size
    stride_y = patch_size[0] + gap_size

    #nominal corner plus random jitter (the x jitter is drawn before the y jitter)
    left = grid_start_x + x * stride_x + random.randint(-jitter_size, jitter_size)
    top = grid_start_y + y * stride_y + random.randint(-jitter_size, jitter_size)

    #largest corner coordinates for which the patch still ends inside the image
    max_left = image_size[1] - patch_size[1]
    max_top = image_size[0] - patch_size[0]

    #clamp the corner: first to the lower bound, then to the upper bound
    if left < 0:
        left = 0
    if left > max_left:
        left = max_left
    if top < 0:
        top = 0
    if top > max_top:
        top = max_top

    return (left, top)

In [1]:
def preprocess_patch(patch):
    """Keep one random color channel, replace the others with noise, rescale.

    One of the three channels (last axis) is chosen at random and has its
    mean subtracted. The other two channels are overwritten with uniform
    noise in [-0.5, 0.5). Finally the whole patch is rescaled so that its
    root-mean-square value over all entries equals 50.

    The work is done on a floating-point copy, so the caller's array is left
    untouched and integer-typed patches are accepted (an in-place float
    subtraction on an integer array would raise, and the noise would be
    truncated to zero).

    Args:
        patch: array indexed as ``[row, column, channel]`` with at least
            three channels; only channels 0-2 are modified.

    Returns:
        A new array with the same shape as ``patch``. Floating-point input
        dtypes are preserved; any other dtype is promoted to float64.
    """
    patch = np.asarray(patch)
    dtype = patch.dtype if np.issubdtype(patch.dtype, np.floating) else np.float64
    patch = np.array(patch, dtype=dtype) #np.array copies, protecting the caller's data

    #randomly drop all but one color channel
    kept_channel = random.randint(0, 2)

    for i in range(0, 3):
        if i == kept_channel:
            #center the surviving channel around zero
            patch[:,:,i] -= np.mean(patch[:,:,i])
        else:
            #dropped channels become zero-centered uniform noise
            patch[:,:,i] = np.random.uniform(0, 1, (patch.shape[0], patch.shape[1])) - 0.5

    #rescale so that the RMS over the whole patch is 50
    patch = patch/np.sqrt(np.mean(patch**2))*50

    return patch

In [None]:
#patch-sampling configuration; gap_size and jitter_size are read as globals by image_loader
patch_size = (12, 12) #patch dimensions in pixels, indexed as (height, width) by sample_patch
gap_size = 6 #spacing in pixels between neighbouring patches of a grid (half the patch size)
jitter_size = 2 #each patch is randomly shifted by up to this many pixels along each axis
batch_size = 128 #maximum number of patches in a single batch

In [None]:
#load images, extract pairs of patches and put them into a batch
def image_loader(data_q, batch_size, images, seed, patch_size):
    """Endlessly sample patch grids from ``images`` and push batches to ``data_q``.

    The dataset is visited in a fixed random order (cycling back to the start
    when exhausted). For every image, ``num_grids`` randomly offset grids of
    patches are sampled. Each patch is paired with its already-sampled
    neighbours (upper-left, upper, upper-right, left), and each pair is
    recorded in both directions together with a label from ``pos_2_label``.

    Whenever adding the next grid would reach ``batch_size`` patches, the
    accumulated batch is preprocessed and put on the queue as a tuple
    ``(data, perm, label)``:

    * ``data``  - contiguous array of preprocessed patches,
    * ``perm``  - array of ``[first, second]`` patch-index pairs into ``data``,
    * ``label`` - array with one ``pos_2_label`` result per pair.

    This function never returns; it is meant to run as a producer.

    Relies on names defined elsewhere in the notebook: ``gap_size`` and
    ``jitter_size`` (module-level configuration), ``sample_patch``,
    ``preprocess_patch`` and ``pos_2_label``.
    NOTE(review): ``pos_2_label`` is not visible in this part of the notebook;
    it presumably maps a relative (dx, dy) offset to a class label - confirm.

    Args:
        data_q: queue-like object exposing ``put(item, timeout=...)``.
        batch_size: maximum number of patches in a single batch.
        images: indexable collection of images, each indexed as
            ``[row, column, channel]`` and exposing ``.shape``.
        seed: seed for the global NumPy RNG that shuffles the image order.
        patch_size: patch dimensions, indexed as ``(height, width)``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("image_loader needs at least one image")

    #randomly permute the dataset
    np.random.seed(seed)
    ordered_imgs = np.random.permutation(len(images))
    curr_idx = 0

    #sample several grids per image before moving on to the next one
    num_grids = 4
    grids_left = 0
    image = None

    #contents of the batch under construction
    perm = []
    label = []
    patches = []

    j = 0 #index of the current patch inside the batch
    while True: #loop until stopped
        if grids_left == 0:
            #fetch the next image itself (not merely its index in the permutation)
            image = images[ordered_imgs[curr_idx]]
            curr_idx = (curr_idx + 1) % len(ordered_imgs) #cycle back to the beginning of the dataset
            grids_left = num_grids #reset number of grids left to compute

        #compute where a grid starts and how many patches fit along each axis
        grid_start_x = random.randint(0, patch_size[1] + gap_size - 1)
        grid_start_y = random.randint(0, patch_size[0] + gap_size - 1)
        #floor division keeps the sizes integral so they can be used as shapes and ranges
        grid_size_x = (image.shape[1] + gap_size - grid_start_x) // (patch_size[1] + gap_size)
        grid_size_y = (image.shape[0] + gap_size - grid_start_y) // (patch_size[0] + gap_size)

        #batch index of the patch in every cell, addressed as grid[row, column]
        grid = np.zeros((grid_size_y, grid_size_x), int)

        #put the batch in the queue when the next grid would reach the batch size
        #(never emit an empty batch)
        if patches and (grid_size_x * grid_size_y + j) >= batch_size:
            #a list (not a lazy map object) is needed to build a proper array
            data = np.array([preprocess_patch(p) for p in patches])
            data_q.put((np.ascontiguousarray(data), np.array(perm), np.array(label)), timeout=600)
            perm = []
            label = []
            patches = []
            j = 0

        grids_left -= 1
        #sample a patch and search for patches that it can be paired with
        for y in range(grid_size_y):
            for x in range(grid_size_x):
                (x_pixel, y_pixel) = sample_patch(x, y, grid_start_x, grid_start_y, patch_size, gap_size, jitter_size, image.shape)
                #the multiplication already yields a fresh array, so the source image is not aliased
                #NOTE(review): the *255 scaling assumes pixel values in [0, 1] - confirm
                patches.append(image[y_pixel:y_pixel+patch_size[0], x_pixel:x_pixel+patch_size[1], :] * 255)
                grid[y, x] = j
                #only neighbours sampled earlier (row above, or to the left) are considered
                for pair in [(-1, -1), (0, -1), (1, -1), (-1, 0)]:
                    grid_pos_x = pair[0] + x
                    grid_pos_y = pair[1] + y
                    if grid_pos_x < 0 or grid_pos_y < 0 or grid_pos_x >= grid_size_x:
                        continue
                    neighbour = grid[grid_pos_y, grid_pos_x]
                    #record the pair in both directions with mirrored offsets
                    perm.append(np.array([j, neighbour]))
                    label.append(pos_2_label(pair))
                    perm.append(np.array([neighbour, j]))
                    label.append(pos_2_label((-pair[0], -pair[1])))
                j += 1


