In [None]:
%matplotlib inline
import sys
sys.path.append("..") # Adds the module to path

In [None]:
import glob
import itertools
import json
import os
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import scipy.io

import deeptrack as dt

# Root folder of the MitoGAN dataset; the `train/` and `validation/` subfolders are read below.
PATH_TO_DATASET = "./datasets/MitoGAN/"

# Sorted listing of every entry found in each split folder.
TRAINING_PATH = sorted(glob.glob(PATH_TO_DATASET + 'train/*'))
VALIDATION_PATH = sorted(glob.glob(PATH_TO_DATASET + 'validation/*'))

# A split's sample count is the number of listed paths containing "membranes_".
number_of_training_files = sum(1 for path in TRAINING_PATH if "membranes_" in path)
number_of_validation_files = sum(1 for path in VALIDATION_PATH if "membranes_" in path)

In [None]:
# Endless counters over the sample indices 0..N-1 of each split.
training_index_iterator = itertools.cycle(range(number_of_training_files))
validation_index_iterator = itertools.cycle(range(number_of_validation_files))

training_path = PATH_TO_DATASET + 'train/'
validation_path = PATH_TO_DATASET + 'validation/'

# `root` carries the two properties shared by all loaders: which folder to read
# from and which sample index to read. Both depend on the `validation` argument.
root = dt.DummyFeature(
    base_path=lambda validation: (validation_path if validation else training_path),
    index=lambda validation: next(
        validation_index_iterator if validation else training_index_iterator
    ),
)


def _load_png(prefix):
    """Return `root + dt.LoadImage(...)` reading `<base_path><prefix><index>.png`.

    The loader receives `root.properties`, so its `index` and `base_path`
    arguments are the ones defined on `root`.
    """
    return root + dt.LoadImage(
        path=lambda index, base_path: base_path + prefix + str(index) + '.png',
        **root.properties
    )


# The three image kinds stored per sample index.
load_training_image = _load_png('raw_')
load_training_membranes = _load_png('membranes_')
load_training_mitochondria = _load_png('mitochondria_')

In [None]:
def normalize(a):
    """Rescale an image to the range [-1, 1].

    The minimum and the peak-to-peak range are computed over the first two
    axes (with `keepdims=True`), so each slice along any remaining axis is
    rescaled independently.

    Parameters
    ----------
    a : array-like with at least two dimensions
        Image to rescale.

    Returns
    -------
    Array of the same shape as `a`, where the minimum of each slice maps to
    -1 and the maximum to 1. A constant slice (zero range) maps to -1
    everywhere instead of producing NaN/inf from a division by zero.
    """
    lowest = np.min(a, axis=(0, 1), keepdims=True)
    span = np.ptp(a, axis=(0, 1), keepdims=True)
    # Guard against division by zero for constant slices.
    safe_span = np.where(span == 0, 1, span)

    # normalize between 0 and 1
    b = (a - lowest) / safe_span
    # normalize between -1 and 1
    return 2. * b - 1

# Wrap `normalize` in a DeepTrack Lambda feature and append it to the raw-image
# loader, so the resolved raw image is rescaled by `normalize`.
normalization = dt.Lambda(lambda: lambda image: normalize(image))
normalized_training_image = load_training_image + normalization

In [None]:

    
# Divides the mask by 255.
# NOTE(review): the outer lambda takes `sigma` but the inner function never uses it;
# presumably a leftover — confirm whether DeepTrack needs `sigma` to be supplied here.
normalization_mask = dt.Lambda(lambda sigma: lambda image: image/255.0)

# Load both label images and merge them into one mask: mitochondria (index 1)
# minus membranes (index 0).
training_mask_labels = dt.Combine([load_training_membranes, load_training_mitochondria])
training_mask = training_mask_labels +  dt.Merge(lambda: lambda image: image[1]*1.0 - image[0])
# Gaussian noise with mu 0 and a sigma rule drawing uniformly from [0, 0.1).
noise = dt.Gaussian(mu=0, sigma=lambda: np.random.rand() * 0.1)

# Pipeline order: merged mask -> divide by 255 -> add noise.
noised_mask = training_mask + normalization_mask + noise



In [None]:
# Pair each normalized raw image with its noised mask: index 0 is the image,
# index 1 is the mask.
combined = dt.Combine([normalized_training_image, noised_mask])


# Augmentation applied to the pair: left/right flip followed by an affine
# transform with reflect padding.
# NOTE(review): the units of `rotate` and `shear` depend on the installed
# DeepTrack version — confirm that 0-360 and +/-10 are the intended ranges.
augmented_dataset = dt.FlipLR(combined)
augmented_dataset += dt.Affine(
    rotate=lambda: np.random.rand() * 360,
    shear=lambda: np.random.rand() * 20 - 10,
    # `scale` is a callable like `rotate` and `shear`, so new factors in
    # [0.85, 1.15) are drawn whenever the property is sampled. The previous
    # plain dict called np.random.rand() only once, when this cell ran.
    scale=lambda: {
        "x": np.random.rand() * 0.3 + 0.85,
        "y": np.random.rand() * 0.3 + 0.85
    },
    mode="reflect"
)

# Validation samples use the un-augmented pair; all others use the augmented one.
dataset = dt.ConditionalSetFeature(
    on_true=combined,
    on_false=augmented_dataset,
    condition="is_validation",
    is_validation=lambda validation: validation 
)

In [None]:
def get_image(image):
    """Return the first entry of a resolved sample (its element at index 0)."""
    first_entry = image[0]
    return first_entry

def get_mask(image):
    """Return the second entry of a resolved sample (its element at index 1)."""
    second_entry = image[1]
    return second_entry

In [None]:
import tensorflow.keras.utils as utils


# Preview a few resolved samples: image on the left, mask on the right,
# each with its own colorbar. `sigma=0` is passed to the update.
NUMBER_OF_IMAGES = 8
for _ in range(NUMBER_OF_IMAGES):
    sample = dataset.update(sigma=0).resolve()

    plt.figure(figsize=(14, 5))
    panels = (get_image(sample), get_mask(sample))
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel)
        plt.colorbar()
    plt.show()

In [None]:
import tensorflow as tf

from tensorflow.keras import layers
from deeptrack.models import KerasModel
from tensorflow.keras.initializers import RandomNormal
from tensorflow_addons.layers import InstanceNormalization

# Clear the Keras backend session before the models are built in the following
# cells (presumably so re-running the notebook starts from a clean state — confirm).
tf.keras.backend.clear_session()

In [None]:
# Shared kernel initializer (normal, mean 0, stddev 0.02); used as the default
# `weight_init` argument of the block factories defined in this notebook.
weight_init = RandomNormal(mean = 0.0, stddev = 0.02)

In [None]:
def convolution_block(conv_layer_dimension,
                      kernel_size = (3,3),
                      strides = 1,
                      weight_init = weight_init,
                      **kwargs):
    """Build a Conv2D -> InstanceNormalization -> LeakyReLU(0.2) block.

    Parameters
    ----------
    conv_layer_dimension : number of filters of the convolution.
    kernel_size : kernel size of the convolution.
    strides : stride of the convolution.
    weight_init : kernel initializer of the convolution.
    **kwargs : accepted and ignored.

    Returns
    -------
    A function mapping an input tensor to the block's output tensor.
    """
    def apply_block(inputs):
        features = layers.Conv2D(
            conv_layer_dimension,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            kernel_initializer=weight_init,
        )(inputs)
        normalized = InstanceNormalization()(features)
        return layers.LeakyReLU(0.2)(normalized)

    return apply_block

def residual_block(conv_layer_dimension,
                   kernel_size = (3,3),
                   strides = 1,
                   weight_init = weight_init,
                   **kwargs):
    """Build a two-convolution residual block with a 1x1-convolution shortcut.

    Main path: Conv2D(strides) -> InstanceNormalization -> LeakyReLU(0.2)
    -> Conv2D(stride 1) -> InstanceNormalization. The shortcut is a 1x1 Conv2D
    of the input. Both paths are added and passed through LeakyReLU(0.2).

    Parameters
    ----------
    conv_layer_dimension : number of filters of every convolution in the block.
    kernel_size : kernel size of the two main-path convolutions.
    strides : stride of the first main-path convolution and of the shortcut.
    weight_init : kernel initializer of the two main-path convolutions.
    **kwargs : accepted and ignored.

    Returns
    -------
    A function mapping an input tensor to the block's output tensor.
    """
    def call(x):
        # The shortcut uses the same stride as the first main-path convolution.
        # Without it the two branches have different spatial sizes whenever
        # strides != 1 and `Add` fails. Behaviour at the default stride of 1
        # is unchanged. The shortcut keeps Keras' default initializer.
        identity = layers.Conv2D(conv_layer_dimension,
                          kernel_size = (1,1),
                          strides = strides)(x)

        y = layers.Conv2D(conv_layer_dimension,
                   kernel_size = kernel_size,
                   strides = strides,
                   padding = "same",
                   kernel_initializer = weight_init)(x)
        y = InstanceNormalization()(y)
        y = layers.LeakyReLU(0.2)(y)

        y = layers.Conv2D(conv_layer_dimension,
                   kernel_size = kernel_size,
                   strides = 1,
                   padding = 'same',
                   kernel_initializer = weight_init)(y)
        y = InstanceNormalization()(y)

        y = layers.Add()([identity, y])

        return layers.LeakyReLU(0.2)(y)

    return call

def pooling_block(conv_layer_dimension,
                  kernel_size = (3,3),
                  strides = 2,
                  weight_init = weight_init,
                  **kwargs):
    """Build a strided Conv2D -> InstanceNormalization -> LeakyReLU(0.2) block.

    Same layer sequence as `convolution_block`, but the stride defaults to 2.

    Parameters
    ----------
    conv_layer_dimension : number of filters of the convolution.
    kernel_size : kernel size of the convolution.
    strides : stride of the convolution.
    weight_init : kernel initializer of the convolution.
    **kwargs : accepted and ignored.

    Returns
    -------
    A function mapping an input tensor to the block's output tensor.
    """
    def downsample(inputs):
        reduced = layers.Conv2D(
            conv_layer_dimension,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            kernel_initializer=weight_init,
        )(inputs)
        normalized = InstanceNormalization()(reduced)
        return layers.LeakyReLU(0.2)(normalized)

    return downsample

In [None]:
def deconvolution_block(conv_layer_dimension = None,
                          kernel_size = (3,3),
                          strides = 1,
                          weight_init = weight_init,
                          **kwargs):
    """Build a bilinear-upsampling block followed by a convolution.

    Layer sequence: UpSampling2D(bilinear) -> Conv2D -> InstanceNormalization
    -> LeakyReLU(0.2).

    Parameters
    ----------
    conv_layer_dimension : number of filters of the convolution.
    kernel_size : kernel size of the convolution.
    strides : stride of the convolution.
    weight_init : kernel initializer of the convolution.
    **kwargs : accepted and ignored.

    Returns
    -------
    A function mapping an input tensor to the block's output tensor.
    """
    def upsample(inputs):
        enlarged = layers.UpSampling2D(interpolation='bilinear')(inputs)
        features = layers.Conv2D(
            conv_layer_dimension,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            kernel_initializer=weight_init,
        )(enlarged)
        normalized = InstanceNormalization()(features)
        return layers.LeakyReLU(0.2)(normalized)

    return upsample

In [None]:
from deeptrack.models import unet
    
    
# Generator: a DeepTrack U-Net built from the custom block factories defined above.
# It is left uncompiled (`compile = False`) and is passed to `dt.models.cgan` later.
generator = unet(
    input_shape = (None, None, 1),                          # single-channel input, spatial size left unspecified
    conv_layers_dimensions = (16, 32, 64, 128, 256, 512),   # number of features in each convolutional layer
    base_conv_layers_dimensions = (1024,),                  # number of features at the base of the U-Net
    output_conv_layers_dimensions = (16, 16),               # number of features in the convolutional layers after the U-Net
    steps_per_pooling = 2,                                  # number of convolutional layers per pooling layer
    number_of_outputs = 1,                                  # number of output features
    output_activation = "tanh",                             # activation function on the final layer
    compile = False,
    output_kernel_size = 1,
    # Custom blocks replacing the library defaults for each stage of the U-Net.
    layer_functions = {
            "encoder_convolution_block"    : convolution_block,
            "bottleneck_convolution_block" : residual_block,
            "decoder_convolution_block"    : convolution_block,
            "pooling_function"             : pooling_block,
            "upsampling_function"          : deconvolution_block
            }
)

generator.summary()

In [None]:
def convolution_block_discriminator(conv_layer_dimension,
                      kernel_size = (4,4),
                      strides = 2,
                      weight_init = weight_init,
                      avoid_conv_layer = 16,
                      **kwargs):
    """Build a strided Conv2D -> [InstanceNormalization] -> LeakyReLU(0.2) block.

    The normalization (without learned center/scale) is skipped when the
    number of filters equals `avoid_conv_layer`.

    Parameters
    ----------
    conv_layer_dimension : number of filters of the convolution.
    kernel_size : kernel size of the convolution.
    strides : stride of the convolution.
    weight_init : kernel initializer of the convolution.
    avoid_conv_layer : filter count for which the normalization is omitted.
    **kwargs : accepted and ignored.

    Returns
    -------
    A function mapping an input tensor to the block's output tensor.
    """
    def call(x):
        y = layers.Conv2D(conv_layer_dimension,
                   kernel_size = kernel_size,
                   strides = strides,
                   padding = "same",
                   kernel_initializer = weight_init)(x)

        # Compare by value. `is not` tests object identity, which only
        # matches equal ints by accident of the interpreter's small-int cache.
        if conv_layer_dimension != avoid_conv_layer:
            y = InstanceNormalization(axis = -1, center = False, scale = False)(y)

        y = layers.LeakyReLU(0.2)(y)

        return y

    return call

In [None]:
def identity(*args, **kwargs):
    """Build a no-op block: the returned function returns its input unchanged.

    All positional and keyword arguments are accepted and ignored, so this
    factory can be called like the other block factories in this notebook,
    which all accept `**kwargs`.
    """
    def call(x):
        return x
    return call

In [None]:
from deeptrack.models import convolutional

# Discriminator: a DeepTrack convolutional network with a main and an auxiliary
# 256x256x1 input, left uncompiled and passed to `dt.models.cgan` later.
# NOTE(review): the keyword below is spelled `layer_functions_` (trailing underscore),
# whereas the generator uses `layer_functions`. If this is a typo the custom
# discriminator blocks may not be applied — confirm against the DeepTrack API.
discriminator = convolutional(
    input_shape = (256, 256, 1),                       # shape of the input
    aux_input_shape = (256, 256, 1),                   # shape of the auxiliary input
    conv_layers_dimensions = (16, 32, 64, 128, 256),   # number of features in each convolutional layer
    dense_layers_dimensions = (),                      # number of neurons in each dense layer (none)
    number_of_outputs = 1,                             # number of neurons in the final dense step (number of output values)
    compile = False,
    output_kernel_size = 4,
    layer_functions_ = {
            "convolution_block" : convolution_block_discriminator,
            "pooling_function"  : identity,
            }
)

discriminator.summary()

In [None]:
from tensorflow.keras.optimizers import Adam                                 

# Conditional GAN assembled from the generator and discriminator built above.
# The discriminator is trained with an MSE loss. The assembled model uses an
# MSE and an MAE loss weighted 1 : 0.5.
# Both optimizers are Adam with learning rate 2e-4 and beta_1 = 0.5. The
# `learning_rate` keyword replaces the deprecated `lr` alias.
model = dt.models.cgan(generator = generator,
             discriminator = discriminator,
             discriminator_loss = "mse",
             discriminator_optimizer = Adam(learning_rate = 0.0002, beta_1 = 0.5),
             discriminator_metrics = "accuracy",
             assemble_loss = ["mse","mae"],
             assemble_optimizer = Adam(learning_rate = 0.0002, beta_1 = 0.5),
             assemble_loss_weights = [1, 0.5],
             )

In [None]:
from deeptrack.generators import Generator, ContinuousGenerator

# Continuous generator feeding the model from `dataset`.
# The mask (index 1 of each resolved pair) is the network input and the
# image (index 0) is the target.
# NOTE(review): min/max_data_size of 256/257 presumably bound the generator's
# sample pool — confirm against the DeepTrack generator documentation.
data_generator = ContinuousGenerator(
    dataset,
    label_function=get_image,
    batch_function=get_mask,
    batch_size = 16,
    min_data_size=256,
    max_data_size=257,
)

In [None]:

# Train inside the generator's context. Each outer iteration runs 50 Keras epochs
# of 8 steps and then writes a checkpoint named "noised_model<epoch>.h5".
# NOTE(review): the counter starts at 12, presumably to continue the numbering of an
# earlier run — confirm. Checkpoints 0-11 are not written by this cell.
with data_generator:
    for epoch in range(12,200):
        model.fit(
            data_generator, 
            epochs = 50, 
            steps_per_epoch=8
        )
        model.save_weights("noised_model" + str(epoch) + ".h5")

In [None]:
# Save the current weights as "model0.h5".
model.save_weights("model" + str(0) + ".h5")

In [None]:

import matplotlib.pyplot as plt  

# Resolve a fixed set of validation samples once, so every checkpoint is
# evaluated on the same inputs.
datas = [dataset.update(validation=True).resolve() for _ in range(7)]

# For each checkpoint, draw one column per sample:
# row 0 = input mask, row 1 = target image, row 2 = model prediction.
# NOTE(review): no cell in this notebook writes "model15.h5" .. "model19.h5"
# (training saves "noised_model<epoch>.h5") — confirm these files exist.
for ep in range(15, 20):
    model.load_weights("model" + str(ep) + ".h5")
    print(ep)
    # Create subplots
    fig, axs = plt.subplots(3, len(datas), figsize = (80, 35))

    for i, data_tuple in enumerate(datas):
        # Extract the mask and image inside the loop. Doing it before the loop
        # read `data_tuple` before it was defined, and reused one stale sample
        # for every column on a non-fresh kernel.
        data = get_mask(data_tuple)
        label = get_image(data_tuple)

        prediction = model.predict(np.array([data]))
        axs[0,i].imshow(data, vmin = -1, vmax = 1)
        axs[0,i].axis("off")

        axs[1,i].imshow(label)
        axs[1,i].axis("off")

        axs[2,i].imshow(prediction[0, ..., 0], vmin=-1, vmax=1)
        axs[2,i].axis("off")


    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.show()

In [None]:

import matplotlib.pyplot as plt  


# For each "noised_model" checkpoint, resolve validation sample 55 at seven
# `sigma` values (0.02 * column) and plot one column per value:
# row 0 = input mask, row 1 = target image, row 2 = model prediction.
for ep in range(12, 25):
    model.load_weights("noised_model" + str(ep) + ".h5")
    print(ep)
    fig, axs = plt.subplots(3, 7, figsize=(80, 35))

    # Iterating over the transposed grid yields one (mask, image, prediction)
    # axes triple per column.
    for i, (mask_ax, image_ax, prediction_ax) in enumerate(axs.T):
        data_tuple = dataset.update(validation=True, index=55, sigma=0.02*i).resolve()
        data = get_mask(data_tuple)
        label = get_image(data_tuple)

        prediction = model.predict(np.array([data]))

        mask_ax.imshow(data, vmin=-1, vmax=1)
        image_ax.imshow(label)
        prediction_ax.imshow(prediction[0, ..., 0], vmin=-1, vmax=1)
        for ax in (mask_ax, image_ax, prediction_ax):
            ax.axis("off")

    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.show()

In [None]:
# Parse discriminator and generator losses from "loss.txt" and export them to
# a .mat file. At most 1000 entries are read.
dloss = []
gloss = []

# A decimal number, optionally signed and optionally in scientific notation.
# A regular expression replaces the fixed 6-character slice, which truncated
# or broke on values of a different width.
_loss_number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
_dloss_pattern = re.compile(r"D loss:\s*" + _loss_number)
_gloss_pattern = re.compile(r"G loss:\s*" + _loss_number)

with open("loss.txt", 'r') as f:
    # Iterating over the file stops at end-of-file. The previous
    # `while True: readline()` loop never ended when the log held fewer than
    # 1000 entries, because `readline()` keeps returning "" at EOF.
    for line in f:
        d_match = _dloss_pattern.search(line)
        g_match = _gloss_pattern.search(line)
        # Skip lines that do not carry both losses.
        if d_match is None or g_match is None:
            continue

        dloss.append(float(d_match.group(1)))
        gloss.append(float(g_match.group(1)))

        if len(dloss) >= 1000:
            break


scipy.io.savemat("../../figures/gan_loss.mat", {
    "dloss": dloss,
    "gloss": gloss
})

In [None]:
import scipy.io

# Load one checkpoint, resolve one validation sample per validation file
# (passing sigma=0.05), predict on the masks and export everything to a .mat file.
model.load_weights("noised_model11.h5")

validation_set = [
    dataset.update(validation=True, sigma=0.05).resolve()
    for _ in range(number_of_validation_files)
]
# Masks are the model inputs and images are the targets.
validation_data = list(map(get_mask, validation_set))
validation_labels = list(map(get_image, validation_set))

predictions = model.predict(np.array(validation_data))

scipy.io.savemat(
    "../../figures/MitoGAN.mat",
    {
        "data": validation_data,
        "labels": validation_labels,
        "predictions": predictions,
    },
)

In [None]:
# Export the samples currently held by the data generator.
# Element 0 of each entry is saved under "masks" and element 1 under "images".
# NOTE(review): presumably each entry is a (mask, image) pair, given
# batch_function=get_mask / label_function=get_image — confirm.
dat = [d[0] for d in data_generator.data]
lab = [d[1] for d in data_generator.data]
scipy.io.savemat("../../figures/MitoGAN_raw.mat", {
    "masks": dat,
    "images": lab,
})

In [None]:
# Scratch check of how a `sigma` argument passed to `update` reaches the noise
# feature: update the whole dataset, then the noise feature alone, and display
# the noise feature's current sigma value as the cell output.
dataset.update(validation=True, sigma=0.01)

noise.update(sigma=0.01)
noise.sigma.current_value