In [1]:
'''
this script will generate tiles of raw images and corresponding labels
- tile size must be divisible by 32 for unet model
- v0 performs sliding window w/ no overlap so many pixels lost
- v1 performs cropping and rotations at random locations
'''

'\nthis script will generate tiles of raw images and corresponding labels\n- tile size must be divisible by 32 for unet model\n- v0 performs sliding window w/ no overlap so many pixels lost\n- v1 performs cropping and rotations at random locations\n'

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
!pip install imutils
import imutils





In [3]:
## Paths to directories (raw strings so the backslashes in the UNC paths stay literal)

# Root of the raw dataset. The tiling loop lists this folder and reads
# `<mode>/images` and `<mode>/labels` beneath each entry.
raw_data_path = r'\\babyserverdw3\PW Cloud Exp Documents\Lab work documenting\W-22-09-10 AT Build Competent multi task DL model for tissue labeling\dataset\1115\raw'

# Root under which the per-run tile folder is created.
save_path = r'\\babyserverdw3\PW Cloud Exp Documents\Lab work documenting\W-22-09-10 AT Build Competent multi task DL model for tissue labeling\dataset\1115\tiled'

In [4]:
# Version tag of this tiling scheme; it is embedded in the output folder name
# (`<size>x<size>_v<VERSION>`), so changing it writes to a different folder.
VERSION = 1.2
NUM_TILES = 60 # number of tiles to generate per image

In [5]:
## finds best tile size with minimum pixels lost (more relevant for v0, but keep anyways for v1)

def find_best_tile_size(im_dim0, im_dim1, min_tile_size=480, max_tile_size=1024):
    """Return the candidate tile size that loses the fewest pixels.

    The image is assumed to be tiled with a non-overlapping sliding window, so
    whatever does not fit a whole number of tiles along an axis is discarded.

    Args:
        im_dim0: image extent along the first axis, in pixels.
        im_dim1: image extent along the second axis, in pixels.
        min_tile_size: first candidate size (inclusive).
        max_tile_size: upper limit for candidates (exclusive).

    Returns:
        The candidate (an element of ``np.arange(min_tile_size, max_tile_size, 32)``)
        with the smallest number of lost pixels. Ties go to the smallest size.

    Raises:
        ValueError: if the range contains no candidate size.
    """
    # Candidates advance in steps of 32 (unet needs a size divisible by 32,
    # which holds as long as min_tile_size itself is a multiple of 32).
    possible_sizes = np.arange(min_tile_size, max_tile_size, 32)
    if possible_sizes.size == 0:
        raise ValueError(
            'no candidate tile sizes in [{}, {})'.format(min_tile_size, max_tile_size)
        )

    # Work in int64 so the pixel counts cannot overflow a platform-default int32.
    sizes = possible_sizes.astype(np.int64)
    rem0 = im_dim0 % sizes  # leftover strip along axis 0
    rem1 = im_dim1 % sizes  # leftover strip along axis 1
    # Two leftover strips, minus the corner that both of them cover.
    lost = rem0 * im_dim1 + rem1 * im_dim0 - rem0 * rem1

    # argmin, not argmax: the goal is the *fewest* lost pixels.
    best_idx = int(np.argmin(lost))
    best_size = possible_sizes[best_idx]
    print('best tile size: {}, num pixels lost: {}'.format(best_size, lost[best_idx]))
    return best_size

In [6]:
# tile_size = find_best_tile_size(2048, 2880)
# Edge length of the square tiles, in pixels (736 = 23 * 32, as unet requires).
# NOTE(review): with the pixels-lost formula used in this notebook, 736 is the
# worst candidate for a 2048x2880 image. That is harmless for random crops, but
# confirm before reusing this value for non-overlapping sliding-window tiling.
tile_size = 736

# Output folder is named after the tile size and the tiling-scheme version.
tile_folder_name = '{}x{}_v{}'.format(tile_size, tile_size, VERSION)

tile_save_path = os.path.join(save_path, tile_folder_name)

# makedirs also creates a missing `save_path` and is a no-op when the folder
# already exists, so re-running this cell is safe.
os.makedirs(tile_save_path, exist_ok=True)

In [7]:
def rotate(img, label, angle, interpolation=cv2.INTER_NEAREST, border_mode=cv2.BORDER_REFLECT_101):
    """Rotate an image and its label by the same angle about the image centre.

    Both arrays are warped with one shared affine matrix, interpolation flag and
    border mode, and both outputs keep the width/height of ``img``.

    Args:
        img: image array; its first two dimensions define the output size.
        label: label array warped with the same transform as ``img``.
        angle: rotation angle in degrees, passed to ``cv2.getRotationMatrix2D``.
        interpolation: ``flags`` value for ``cv2.warpAffine``.
        border_mode: ``borderMode`` value for ``cv2.warpAffine``.

    Returns:
        ``(rotated_img, rotated_label)``.
    """
    rows, cols = img.shape[:2]
    centre_xy = (cols / 2, rows / 2)
    affine = cv2.getRotationMatrix2D(centre_xy, angle, 1.0)

    # One set of warp settings, so image and label cannot drift apart.
    warp_args = dict(M=affine, dsize=(cols, rows), flags=interpolation, borderMode=border_mode)
    rotated_img = cv2.warpAffine(img, **warp_args)
    rotated_label = cv2.warpAffine(label, **warp_args)
    return rotated_img, rotated_label

In [8]:
def normal_crop(rand_cent, padding, img, label):
    """Cut the same axis-aligned window out of an image and its label.

    The window spans ``rand_cent - padding`` to ``rand_cent + padding`` on each
    of the first two axes; fractional bounds are truncated with ``int()``.

    Args:
        rand_cent: ``(row, col)`` centre of the window.
        padding: half the window size (callers pass ``tile_size / 2``).
        img: 3-D image array, indexed as ``[row, col, channel]``.
        label: label array, indexed as ``[row, col]``.

    Returns:
        ``(img_tile, label_tile)`` slices of the inputs.
    """
    centre = np.asarray(rand_cent)
    start = centre - padding
    stop = centre + padding
    row_span = slice(int(start[0]), int(stop[0]))
    col_span = slice(int(start[1]), int(stop[1]))
    return img[row_span, col_span, :], label[row_span, col_span]

def rotation_crop(angle, img, label):
    """Produce one rotated tile by cropping an oversized window, rotating it, then trimming.

    A window of half-size ``(tile_size / 2) * sqrt(2)`` is cut at a random
    location, rotated by ``angle`` with ``rotate``, and finally cropped around
    its centre with ``normal_crop`` using half-size ``tile_size / 2``.
    Reads the module-level ``tile_size``.

    Args:
        angle: rotation angle in degrees, forwarded to ``rotate``.
        img: full source image, indexed as ``[row, col, channel]``.
        label: full source label, indexed as ``[row, col]``.

    Returns:
        ``(img_tile, label_tile)`` after rotation and the final crop.
    """
    # Half the tile diagonal: the window must be this large so the final
    # tile_size crop still fits inside it after a 45-degree rotation.
    half_diag = (tile_size / 2) * np.sqrt(2)
    row_limit = img.shape[0] - half_diag
    col_limit = img.shape[1] - half_diag

    # Draw the row first, then the column (same draw order as before).
    centre_row = np.random.randint(half_diag, row_limit + 1)
    centre_col = np.random.randint(half_diag, col_limit + 1)
    centre = np.asarray((centre_row, centre_col))

    # Oversized window around the random centre.
    start = centre - half_diag
    stop = centre + half_diag
    row_span = slice(int(start[0]), int(stop[0]))
    col_span = slice(int(start[1]), int(stop[1]))
    wide_img = img[row_span, col_span, :]
    wide_label = label[row_span, col_span]

    rot_img, rot_label = rotate(wide_img, wide_label, angle=angle)

    # Trim the rotated window back down around its (truncated) midpoint.
    mid = int(half_diag)
    return normal_crop((mid, mid), (tile_size / 2), rot_img, rot_label)

def perform_crop(rand_cent, padding, img, label):
    """Cut one tile from ``img``/``label`` using a randomly chosen augmentation.

    One of three variants is picked uniformly at random:

    * 0 - plain crop around ``rand_cent``.
    * 1 - plain crop, then a rotation by 90, 180 or 270 degrees.
    * 2 - diagonal rotation (45, 135, 225 or 315 degrees) via ``rotation_crop``,
      which draws its own centre and ignores ``rand_cent`` and ``padding``.

    Right-angle rotations use ``np.rot90``, which only re-indexes pixels. The
    previous ``imutils.rotate`` call warped about ``(w // 2, h // 2)`` with
    bilinear interpolation and a constant-zero border. On even-sized tiles that
    pivot is half a pixel off the true centre, so one row/column of the image
    turned black and the matching label pixels became 0.

    Args:
        rand_cent: ``(row, col)`` centre for the plain crop.
        padding: half the tile size for the plain crop.
        img: full source image.
        label: full source label.

    Returns:
        ``(img_tile, label_tile)``. Tiles are square because ``normal_crop``
        uses the same ``padding`` on both axes, so quarter turns keep the shape.
    """
    which_crop = np.random.randint(0, 3)  # 0 = normal; 1 = right-angle rotation; 2 = diagonal rotation

    # Diagonal rotation: handled entirely by rotation_crop.
    if which_crop == 2:
        angles = np.arange(45, 315+1, 90)  # 45, 135, 225, 315
        angle = angles[np.random.randint(len(angles))]  # select a random angle
        return rotation_crop(angle, img, label)

    img_tile, label_tile = normal_crop(rand_cent, padding, img, label)

    if which_crop == 1:
        angles = np.arange(90, 270+1, 90)  # 90, 180, 270
        angle = angles[np.random.randint(len(angles))]  # select a random angle
        quarter_turns = int(angle) // 90
        # rot90 returns a strided view; make it contiguous before it is handed
        # to OpenCV for writing. The same turn count keeps image and label aligned.
        img_tile = np.ascontiguousarray(np.rot90(img_tile, k=quarter_turns))
        label_tile = np.ascontiguousarray(np.rot90(label_tile, k=quarter_turns))

    return img_tile, label_tile

In [9]:
def create_tiles(im, lb, tile_size, mode_save_path, im_name):
    """Generate ``NUM_TILES`` random tiles from one image/label pair and write them to disk.

    Tiles are written as ``<mode_save_path>/images/<im_name>_0<k>.tif`` and
    ``<mode_save_path>/labels/<im_name>_0<k>.tif`` for ``k`` in ``0..NUM_TILES-1``.
    Both output folders must already exist.

    Args:
        im: path of the raw image file.
        lb: path of the matching label file.
        tile_size: edge length of the tiles; half of it is the crop padding.
            Note that ``rotation_crop`` reads the module-level ``tile_size``
            rather than this argument.
        mode_save_path: folder that contains the ``images`` and ``labels`` subfolders.
        im_name: base name (without extension) used for the output files.

    Raises:
        OSError: if either input cannot be read or a tile cannot be written.
    """
    # cv2.imread returns None instead of raising when a file is missing or
    # unreadable; fail here with the offending path.
    img = cv2.imread(im)
    if img is None:
        raise OSError('could not read image file: {}'.format(im))
    label = cv2.imread(lb, cv2.IMREAD_UNCHANGED)
    if label is None:
        raise OSError('could not read label file: {}'.format(lb))
    label = label.astype('uint8')

    # Valid range for a tile centre: at least half a tile away from every edge.
    padding = tile_size / 2 # both indices start at tile_size / 2
    h_max = img.shape[0] - padding
    v_max = img.shape[1] - padding

    for tile_counter in range(NUM_TILES):
        rand_cent = (np.random.randint(padding, h_max+1), np.random.randint(padding, v_max+1))
        img_tile, label_tile = perform_crop(rand_cent, padding, img, label) # crop image based on given center pixel
        fn = im_name + '_0{}'.format(tile_counter)
        img_out = mode_save_path + '/images/' + fn + '.tif'
        label_out = mode_save_path + '/labels/' + fn + '.tif'
        # cv2.imwrite reports failure (e.g. a missing folder) by returning
        # False; without this check tiles could silently go unwritten.
        if not cv2.imwrite(img_out, img_tile):
            raise OSError('could not write image tile: {}'.format(img_out))
        if not cv2.imwrite(label_out, label_tile):
            raise OSError('could not write label tile: {}'.format(label_out))


In [10]:
# Walk every mode folder of the raw dataset (except 'test') and tile each image in it.
for mode in os.listdir(raw_data_path): # 'train', 'test'
    if mode == 'test':
        continue
    images_path = os.path.join(raw_data_path, mode + '/images')
    labels_path = os.path.join(raw_data_path, mode + '/labels')
    # Stray entries in the raw folder (e.g. Thumbs.db) have no images/ subfolder.
    if not os.path.isdir(images_path):
        print('skipping {}: no images folder'.format(mode))
        continue

    # Create the output folders unconditionally. The old check only made
    # images/ and labels/ when the mode folder itself was missing, so a
    # half-created folder left them absent.
    mode_save_path = os.path.join(tile_save_path, mode)
    os.makedirs(mode_save_path + '/images', exist_ok=True)
    os.makedirs(mode_save_path + '/labels', exist_ok=True)

    for fn in sorted(os.listdir(images_path)): # for each image, in a stable order
        # Only .tif files are processed; this also filters out Thumbs.db.
        im_name, ext = os.path.splitext(fn)
        if ext.lower() != '.tif':
            continue
        im = os.path.join(images_path, fn)
        lb = os.path.join(labels_path, im_name + '.tif') # label shares the image's base name
        print('processing {}...'.format(im_name))
        create_tiles(im, lb, tile_size, mode_save_path, im_name)

processing 303_roi01_ROI-STIC...
processing 303_roi02_ROI-normal...
processing 304_roi01_ROI-STIC...
processing 304_roi02_ROI-normal...
processing 305_roi01_ROI-STIC...
processing 305_roi02_ROI-normal...
