In [26]:
import nd2reader
import matplotlib.pyplot as plt
import os
import numpy as np
from os import listdir
from PIL import Image
import shutil
import imageio

from skimage import data
from skimage.feature import register_translation
from skimage.feature.register_translation import _upsampled_dft
from scipy.ndimage import fourier_shift

from skimage.util import *
from skimage import exposure

import cv2

In [27]:
def read_data(reader, data_path):
    """Open every ``.nd2`` file found directly inside ``data_path``.

    Args:
        reader: Callable that takes a file path and returns an opened sample
            (in this notebook, ``nd2reader.ND2Reader``).
        data_path: Directory to scan. A trailing path separator is optional.

    Returns:
        Tuple ``(opened, filenames)``: the objects returned by ``reader`` and
        the bare file names they came from, both in sorted filename order.
    """
    opened = []
    filenames = []
    for filename in sorted(listdir(data_path)):
        if filename.endswith('.nd2'):
            # os.path.join works with or without a trailing separator on
            # data_path; plain concatenation silently produced a wrong path.
            opened.append(reader(os.path.join(data_path, filename)))
            filenames.append(filename)
    return opened, filenames

In [29]:
def get_volume(sample, channel, frame):
    """Read all levels along the ``'z'`` axis for one channel at one time point.

    Args:
        sample: Reader-like object exposing ``sizes`` (with ``"c"`` and ``"t"``
            entries), a settable ``iter_axes``, a ``default_coords`` dict and
            iteration over levels (in this notebook, ``nd2reader.ND2Reader``).
        channel: Channel index; must be in ``range(sample.sizes["c"])``.
        frame: Time-point index; must be in ``range(sample.sizes["t"])``.

    Returns:
        ``np.ndarray`` whose axis 0 follows the iteration order of ``sample``
        (one entry per level), each level converted with ``np.array``.

    Raises:
        ValueError: if ``channel`` or ``frame`` is out of range. Returning
            ``None`` here only made the failure surface later, in unrelated code.

    Side effects:
        Sets ``sample.iter_axes`` and ``sample.default_coords['c']`` / ``['t']``.
        These settings stay on ``sample`` after the call returns.
    """
    n_channels = sample.sizes["c"]
    n_frames = sample.sizes["t"]
    if channel not in range(n_channels):
        raise ValueError(
            f"channel {channel!r} out of range for sample with {n_channels} channel(s)")
    if frame not in range(n_frames):
        raise ValueError(
            f"frame {frame!r} out of range for sample with {n_frames} time point(s)")

    sample.iter_axes = 'z'
    sample.default_coords['c'] = channel
    sample.default_coords['t'] = frame

    return np.array([np.array(level) for level in sample])

In [30]:
def align_images(img_1, img_2, upsample_factor=100):
    """Translate ``img_1`` so that it lines up with ``img_2``.

    The translation is estimated with ``register_translation`` (phase
    cross-correlation) and applied in the Fourier domain via ``fourier_shift``.

    Args:
        img_1: Image to move.
        img_2: Reference image; returned unchanged.
        upsample_factor: Sub-pixel precision of the estimate (1 / factor of a
            pixel). Defaults to 100, the value previously hard-coded.

    Returns:
        Tuple ``(img_1_shifted, img_2)``. ``img_1_shifted`` is the real part of
        the inverse FFT, so it is a float array even for integer input.
        Because the shift is applied through the FFT, it is circular: content
        leaving one edge re-enters on the opposite edge.
    """
    shift, _error, _diffphase = register_translation(
        img_1, img_2, upsample_factor=upsample_factor)

    # Per the scikit-image docs, ``shift`` is what moves img_2 (target) onto
    # img_1 (src); to move img_1 onto img_2 we apply the opposite vector.
    shifted_spectrum = fourier_shift(np.fft.fftn(img_1), -shift)
    img_1_shifted = np.fft.ifftn(shifted_spectrum).real

    return img_1_shifted, img_2

In [31]:
def align_volumes(volume_1, volume_2):
    """Align every slice of ``volume_1`` onto the matching slice of ``volume_2``.

    Slices are paired with ``zip``, so only as many slices as the shorter
    volume has are aligned. Returns ``(list_of_shifted_slices, volume_2)``;
    ``volume_2`` itself is passed through untouched.
    """
    shifted_slices = [
        align_images(moving, reference)[0]
        for moving, reference in zip(volume_1, volume_2)
    ]
    return shifted_slices, volume_2

In [32]:
def augment_image(img):
    """Return ``img`` plus its left-right, up-down, and combined flips (in that order)."""
    flipped_lr = np.fliplr(img)
    flipped_ud = np.flipud(img)
    flipped_both = np.fliplr(flipped_ud)
    return [img, flipped_lr, flipped_ud, flipped_both]

In [168]:
def augment_volume(volume):
    """Expand every slice of ``volume`` into its ``augment_image`` variants.

    Returns a flat list built slice by slice: all variants of slice 0, then
    all variants of slice 1, and so on. Because both volumes of a pair are
    augmented the same way, src/trg correspondence by index is preserved.
    """
    volume_augmented = []
    for img in volume:
        # Must use the loop variable; the undefined ``img_1`` raised NameError.
        volume_augmented.extend(augment_image(img))
    return volume_augmented

In [33]:
# def augment_volumes(volume_1, volume_2):
#     volume_1_augmented = []
#     volume_2_augmented = []
    
#     for img_1, img_2 in zip(volume_1, volume_2):
#         volume_1_augmented.extend(augment_image(img_1))
#         volume_2_augmented.extend(augment_image(img_2))
#     return volume_1_augmented, volume_2_augmented

In [34]:
def is_overnoised(img, bright_pixel_percent=0.05):
    """Return True when too many pixels fall in the bright half of ``img``'s range.

    The intensity range ``[min, max]`` of ``img`` is split at its midpoint, and
    pixels at or above the midpoint count as "bright". The image is flagged
    when the bright count exceeds ``bright_pixel_percent`` (a fraction, e.g.
    0.05 == 5 %) of all pixels.

    This is a 2-bin histogram computed explicitly instead of via
    ``exposure.histogram(img, nbins=2)``. scikit-image ignores ``nbins`` for
    integer arrays (one bin per integer value), so for raw uint16 frames the
    second bin held the count of pixels equal to ``min + 1``, not the bright
    pixels.

    Args:
        img: Image array (integer or float).
        bright_pixel_percent: Maximum tolerated fraction of bright pixels.

    Returns:
        ``bool``.
    """
    img = np.asarray(img)
    # Convert to Python floats so min + max cannot overflow for uint16 data.
    lo = float(img.min())
    hi = float(img.max())
    midpoint = (lo + hi) / 2.0
    n_bright = np.count_nonzero(img >= midpoint)
    return bool(n_bright > img.size * bright_pixel_percent)

In [35]:
def clean_volumes(volume_1, volume_2):
    """Drop every slice pair in which either slice is flagged by ``is_overnoised``.

    Slices are paired with ``zip`` (the shorter volume bounds the result).
    Both images of a pair are always checked. Returns two lists of equal
    length holding the surviving slices, in their original order.
    """
    kept_pairs = [
        (first, second)
        for first, second in zip(volume_1, volume_2)
        if not (is_overnoised(first) | is_overnoised(second))
    ]
    volume_1_cleaned = [first for first, _ in kept_pairs]
    volume_2_cleaned = [second for _, second in kept_pairs]
    return volume_1_cleaned, volume_2_cleaned

In [179]:
def adap_hist_img(img, bit_depth=12):
    """Rescale ``img`` from its sensor range and apply adaptive histogram equalization.

    Args:
        img: Image whose meaningful values lie in ``[0, 2**bit_depth - 1]``.
            Values outside this input range are clipped by ``rescale_intensity``.
        bit_depth: Sensor bit depth. Defaults to 12, the value previously
            hard-coded. NOTE(review): confirm it matches the camera data.

    Returns:
        Output of ``exposure.equalize_adapthist`` (CLAHE) on the rescaled image.
    """
    max_value = 2 ** bit_depth - 1
    image_rescaled = exposure.rescale_intensity(img, in_range=(0, max_value))
    return exposure.equalize_adapthist(image_rescaled)

In [175]:
def adap_hist_volume(volume):
    """Apply ``adap_hist_img`` to each slice of ``volume``; returns a list."""
    return [adap_hist_img(slice_img) for slice_img in volume]

In [122]:
def save_img_grayscale(img, filename):
    """Write ``img`` to ``filename`` with matplotlib's ``"gray"`` colormap.

    The file format is chosen by matplotlib from the filename extension.
    NOTE(review): no ``vmin``/``vmax`` is passed, so matplotlib maps each
    image's own min..max onto the colormap. Saved images are therefore not on
    a shared intensity scale (src vs trg, slice vs slice); confirm this is
    intended.
    """
    plt.imsave(fname=filename, arr=img, cmap="gray")

In [137]:
def save_volume(volume, path):
    """Save each slice of ``volume`` to ``{path}_image_{index}.jpg``.

    ``path`` is a filename prefix (directory plus stem), not a directory.
    """
    for slice_index, slice_img in enumerate(volume):
        target_file = f"{path}_image_{slice_index}.jpg"
        save_img_grayscale(slice_img, target_file)

In [138]:
# One dict per experiment: every clean/augment combination, always aligned.
# Key order matters: the experiment loop names each run after the enabled keys.
params = [
    dict(clean=clean, align=True, augment=augment)
    for augment in (False, True)
    for clean in (False, True)
]

In [139]:
def create_folders(dataset_path):
    """Create the dataset directory tree rooted at ``dataset_path``.

    Layout: ``{dataset_path}/{green,red}/{train,val}/{src,trg}``. Parents are
    created before their children. ``os.mkdir`` is used on purpose, so a
    ``FileExistsError`` is raised if any folder already exists; this keeps a
    new dataset from mixing with stale files.

    Returns:
        ``True`` once every folder has been created.
    """
    folders = [dataset_path]
    for channel in ("green", "red"):
        channel_dir = f"{dataset_path}/{channel}"
        folders.append(channel_dir)
        for split in ("train", "val"):
            split_dir = f"{channel_dir}/{split}"
            folders.append(split_dir)
            folders.extend(f"{split_dir}/{role}" for role in ("src", "trg"))

    for folder in folders:
        os.mkdir(folder)

    return True

In [176]:
def generate_dataset(samples, param, dataset_path):
    """Build the src/trg image dataset for one experiment configuration.

    For each sample, each channel (index 0 -> "green", 1 -> "red") and each
    frame pair ([0, 2] and [1, 3]), the z-volume of the first frame becomes
    ``src`` and that of the second frame becomes ``trg``. The second frame
    pair of the last sample goes to the ``val`` split; everything else goes
    to ``train``.

    ``param`` is a dict of truthy flags; missing keys count as disabled:
        - ``'align'``: register src slices onto trg slices.
        - ``'clean'``: drop noisy slice pairs (train split only).
        - ``'augment'``: add flipped copies (train split only).
        - ``'adap_hist'``: adaptive histogram equalization (both splits).

    Folders are created first via ``create_folders`` (which fails if they
    already exist). Images are then written under
    ``{dataset_path}/{channel}/{split}/{src|trg}/``.
    """
    create_folders(dataset_path)

    channel_names = ["green", "red"]
    frame_pairs = [[0, 2], [1, 3]]

    for sample_ind, sample in enumerate(samples):
        print(f"Sample №{sample_ind}")

        for channel_ind, channel_name in enumerate(channel_names):
            for pair_ind, pair in enumerate(frame_pairs):
                first_frame, second_frame = pair

                # Hold out only the second frame pair of the last sample.
                is_test = pair_ind == 1 and sample_ind == len(samples) - 1
                split = "val" if is_test else "train"

                src = get_volume(sample, channel=channel_ind, frame=first_frame)
                trg = get_volume(sample, channel=channel_ind, frame=second_frame)

                if param.get('align'):
                    src, trg = align_volumes(src, trg)

                # Cleaning and augmentation only make sense for training data.
                if not is_test and param.get('clean'):
                    src, trg = clean_volumes(src, trg)
                if not is_test and param.get('augment'):
                    src, trg = augment_volume(src), augment_volume(trg)

                if param.get('adap_hist'):
                    src, trg = adap_hist_volume(src), adap_hist_volume(trg)

                assert len(src) == len(trg)
                print(f"Channel: {channel_name}, Frames: {pair}, Number of samples: {len(src)}")

                stem = f"sample_{sample_ind}_pair{pair_ind}"
                save_volume(src, f"{dataset_path}/{channel_name}/{split}/src/{stem}")
                save_volume(trg, f"{dataset_path}/{channel_name}/{split}/trg/{stem}")

In [165]:
def delete_dataset(dataset_folder):
    """Recursively remove ``dataset_folder`` if it is an existing directory.

    Anything that is not a directory (including a missing path) is left
    alone. Errors during removal are ignored (``ignore_errors=True``), so a
    partially deleted tree is possible.
    """
    if not os.path.isdir(dataset_folder):
        return
    shutil.rmtree(dataset_folder, ignore_errors=True)

In [166]:
# Directory scanned for .nd2 acquisitions.
DATA_PATH = 'data/'

# The opened readers are kept in `samples` and used by the cells below.
# NOTE(review): they are never closed; presumably each holds an open file handle.
reader = nd2reader.ND2Reader
samples, filenames = read_data(reader, DATA_PATH)

print(f"Found in total {len(filenames)} files:")
print("\n".join(filenames))

Found in total 9 files:
18112019_SJR5_w1_30ms.nd2
18112019_SJR5_w1_5ms.nd2
18112019_SJR5_w2_30ms.nd2
18112019_SJR5_w2_5ms.nd2
18112019_SJR5_w3_30ms.nd2
18112019_SJR5_w3_5ms.nd2
18112019_SJR5_w4_30ms.nd2
18112019_SJR5_w4_5ms.nd2
18112019_SJR5_w5_30ms.nd2


In [167]:
data_type = "5ms"

# Select files whose stem contains `data_type` as a whole "_"-separated token.
# E.g. "..._w1_5ms.nd2" matches, but "..._15ms.nd2" does not. A plain
# substring test would match both.
file_indices = [
    ind for ind, filename in enumerate(filenames)
    if data_type in os.path.splitext(filename)[0].split("_")
]
interested_samples = [samples[ind] for ind in file_indices]

# One dataset folder per experiment, named after its enabled flags.
for experiment_ind, param in enumerate(params):
    experiment_name = "_".join(key for key, enabled in param.items() if enabled)

    print(f"Generating the dataset for experiment {experiment_name}")
    dataset_path = f"{experiment_name}_data_{data_type}"

    # Start from a clean folder: create_folders refuses existing directories.
    delete_dataset(dataset_path)
    generate_dataset(interested_samples, param, dataset_path)

Saving the dataset for experiment align
Sample №0
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Sample №1
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Sample №2
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Sample №3
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Saving the dataset for experiment clean_align
Sample №0
Channel: green, 

In [180]:
# Single experiment: alignment plus adaptive histogram equalization.
params = [dict(clean=False, align=True, augment=False, adap_hist=True)]

In [181]:
data_type = "5ms"

# Select files whose stem contains `data_type` as a whole "_"-separated token.
# E.g. "..._w1_5ms.nd2" matches, but "..._15ms.nd2" does not. A plain
# substring test would match both.
file_indices = [
    ind for ind, filename in enumerate(filenames)
    if data_type in os.path.splitext(filename)[0].split("_")
]
interested_samples = [samples[ind] for ind in file_indices]

# One dataset folder per experiment, named after its enabled flags.
for experiment_ind, param in enumerate(params):
    experiment_name = "_".join(key for key, enabled in param.items() if enabled)

    print(f"Generating the dataset for experiment {experiment_name}")
    dataset_path = f"{experiment_name}_data_{data_type}"

    # Start from a clean folder: create_folders refuses existing directories.
    delete_dataset(dataset_path)
    generate_dataset(interested_samples, param, dataset_path)

Generating the dataset for experiment align_adap_hist
Sample №0
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Sample №1
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Sample №2
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
Sample №3
Channel: green, Frames: [0, 2], Number of samples: 35
Channel: green, Frames: [1, 3], Number of samples: 35
Channel: red, Frames: [0, 2], Number of samples: 35
Channel: red, Frames: [1, 3], Number of samples: 35
