#### This notebook preprocesses images for prediction using a DNN trained with Uni-EM.
It opens every .tif/.tiff (or .png) file in `path_input` and applies CLAHE (Contrast Limited Adaptive Histogram Equalization), which enhances the local contrast of the images. Each image is processed in chunks, downscaled by `downscale_factor`, and saved as an RGB .png file per chunk to `path_results`, with the chunk coordinates appended to the file name. After running this notebook on your raw data, the images can be used for segmentation, e.g. with Uni-EM.


In [None]:
%matplotlib notebook
from itertools import product
from os import listdir
from os.path import isfile, join, splitext

import cv2
import matplotlib.pyplot as plt
import skimage
import skimage.exposure
import skimage.io
import tifffile as tf
import tqdm
from skimage.transform import downscale_local_mean

In [None]:
def chunk_generator(img_shape, chunk_size, overlap):
    '''
    Yield start/end coordinates that tile an n-dimensional array in blocks.

    Inputs:
    img_shape: image shape, e.g. (z, y, x). Only its first len(chunk_size)
               axes are chunked; any further (trailing) axes are left whole.
    chunk_size: desired chunk size per axis, e.g. (z, y, x)
    overlap: overlap (in pixels) on every side of the chunk; consecutive chunks
             along an axis start chunk_size - 2 * overlap pixels apart

    Outputs:
    One tuple per chunk holding a start & end coordinate per chunked axis, in
    axis order -- for a 3D chunk_size:
    zstart, zend, ystart, yend, xstart, xend
    The first axis varies fastest and the last axis slowest.

    Every pixel of the chunked axes lies in at least one chunk, including the
    remainder at an uneven image border. Ends are always start + chunk_size and
    are not clipped, so border chunks (or chunks larger than the image) may
    extend past the image; numpy slicing truncates them.

    Raises:
    ValueError: if img_shape has fewer axes than chunk_size, if a chunk size is
                not positive, or if overlap is so large that the step
                (chunk_size - 2 * overlap) is not positive (the tiling would
                never advance).
    '''
    n_axes = len(chunk_size)
    if len(img_shape) < n_axes:
        raise ValueError(
            f"img_shape {tuple(img_shape)} has fewer axes than "
            f"chunk_size {tuple(chunk_size)}"
        )

    starts_per_axis = []
    for size, chunk in zip(img_shape[:n_axes], chunk_size):
        if chunk <= 0:
            raise ValueError(f"chunk sizes must be positive, got {chunk}")
        step = chunk - 2 * overlap
        if step <= 0:
            raise ValueError(
                f"overlap {overlap} is too large for chunk size {chunk} "
                f"(step chunk - 2 * overlap = {step} must be positive)"
            )
        # Always start a chunk at 0, then keep stepping until the most recent
        # chunk reaches (or passes) the image border.
        starts = [0]
        while starts[-1] + chunk < size:
            starts.append(starts[-1] + step)
        starts_per_axis.append(starts)

    # product() varies its LAST iterable fastest, so feed the axes in reverse
    # order to make axis 0 (z) vary fastest.
    for reversed_starts in product(*reversed(starts_per_axis)):
        coords = []
        for start, chunk in zip(reversed(reversed_starts), chunk_size):
            coords.extend((start, start + chunk))
        yield tuple(coords)

def remove_whitespaces(string):
    """Return a copy of *string* with every whitespace character stripped out.

    Used below to turn a coordinate tuple's ``str()``, e.g. ``"(0, 8192, 0)"``,
    into a compact, space-free file-name fragment such as ``"(0,8192,0)"``.
    """
    return "".join(char for char in string if not char.isspace())

In [None]:
# user inputs
# Folder containing the raw input images (.tif/.tiff/.png); keep the trailing
# path separator. A raw string cannot end in a single backslash, hence the
# "\\" at the end (the path therefore ends in two backslash characters).
path_input = r"G:\AG_Morawski\Philip\EM\SWM_Directionality\1_raw\\"
# Folder the CLAHE-processed RGB .png chunks are written to; keep the trailing
# path separator here as well.
path_results = r"G:\AG_Morawski\Philip\EM\SWM_Directionality\2_preprocessed\\"
# Factor passed to downscale_local_mean to shrink each CLAHE-processed chunk
# before it is saved. NOTE(review): the original comment was cut off
# ("factor 4 this is the correct factor for ...") -- confirm which data or
# target resolution the value 4 is tuned for.
downscale_factor = 4
# Intensity threshold for background - if your image has a black background,
# define a threshold here. NOTE(review): not referenced by any processing code
# visible in this notebook section; confirm whether it is still used.
crop_threshold = 100

In [None]:
# Collect the regular files (not sub-directories) in the input folder.
path_images = [f for f in listdir(path_input) if isfile(join(path_input, f))]

# Chunk size (z, y, x) used to tile every image before CLAHE.
chunk_shape = (2048 * 4, 2048 * 4, 3)

for i, image_name in enumerate(tqdm.tqdm(path_images)):
    print("Preprocessing Images " + str(i))
    # splitext strips both 3- and 4-letter extensions (".tif" and ".tiff")
    # without leaving a stray "." in the stem; lower() also accepts ".TIF" etc.
    stem, extension = splitext(image_name)
    extension = extension.lower()
    if extension in (".tif", ".tiff"):
        test_img = tf.imread(join(path_input, image_name))
    elif extension == ".png":
        # -1 == cv2.IMREAD_UNCHANGED (keep the stored bit depth); only
        # channel 0 is used.
        test_img = cv2.imread(join(path_input, image_name), -1)[:, :, 0]
    else:
        # Skip unreadable files (e.g. stray metadata files in the folder)
        # instead of aborting the whole batch.
        print('Input file format not supported. Use .png or .tif.')
        continue

    coord_list = list(chunk_generator(test_img.shape, chunk_shape, 0))
    for zyx in tqdm.tqdm(coord_list):
        z_start, z_end, y_start, y_end, x_start, x_end = zyx
        test_img_clahe = skimage.exposure.equalize_adapthist(
            test_img[z_start:z_end, y_start:y_end, x_start:x_end],
            clip_limit=0.01,
            kernel_size=127,
        )
        test_img_downscaled = downscale_local_mean(test_img_clahe, downscale_factor)
        # Replicate the downscaled data into the R, G and B channels.
        test_img_rgb_png = cv2.merge((test_img_downscaled,) * 3)
        # Output name: input stem + whitespace-free chunk coordinates,
        # e.g. "img(0,8192,0,8192,0,3).png".
        out_path = join(path_results, stem + remove_whitespaces(str(zyx)) + ".png")
        skimage.io.imsave(out_path, (test_img_rgb_png * 255).astype('uint8'))