In [None]:
# Library importation
import gc
import numpy as np
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Dropout
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from matplotlib import pyplot as plt
from os import listdir
from PIL import Image
%env SM_FRAMEWORK=tf.keras
import segmentation_models as sm
from segmentation_models import Unet
from segmentation_models import get_preprocessing
import tensorflow as tf
from tifffile import imread

In [None]:
"""
Defining the functions to create the Pix2Pix network inspired by Jason Brownlee's Pix2Pix Example
"""
# define the discriminator model
def define_discriminator(image_shape):
    """Build and compile the PatchGAN discriminator.

    The source and target images are concatenated channel-wise and passed
    through a stack of 3x3 Conv2D + (optional) BatchNorm + LeakyReLU(0.2)
    stages, ending in a 1-filter conv with a sigmoid (one real/fake score per
    patch).

    image_shape : shape of one image, used for both inputs.
    Returns the compiled model (binary cross-entropy, Adam lr=1e-4,
    beta_1=0.5, loss weight 0.1).
    """
    # RandomNormal(stddev=0.02) is shared by every conv layer, as in the original
    init = RandomNormal(stddev=0.02)
    in_src_image = Input(shape=image_shape)
    in_target_image = Input(shape=image_shape)
    d = Concatenate()([in_src_image, in_target_image])
    # (filters, strides, use_batchnorm) for each stage: C64, C128, C256, C512,
    # then the stride-1 second-to-last layer
    stages = (
        (64, (2, 2), False),
        (128, (2, 2), True),
        (256, (2, 2), True),
        (512, (2, 2), True),
        (512, (1, 1), True),
    )
    for filters, strides, use_batchnorm in stages:
        d = Conv2D(filters, (3, 3), strides=strides, padding='same', kernel_initializer=init)(d)
        if use_batchnorm:
            d = BatchNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)
    # one score per spatial patch
    d = Conv2D(1, (3, 3), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    model = Model([in_src_image, in_target_image], patch_out)
    optimizer = Adam(learning_rate=1e-4, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, loss_weights=[0.1])
    return model

# define the standalone generator model
def define_generator(image_shape=(256,256,1)):
    """Build the U-Net generator (EfficientNet-B0 encoder, random init).

    image_shape : (height, width, channels) of the input images. It is now
                  actually used; the previous version ignored it and always
                  built a (256, 256, 1) network.
    Returns an uncompiled segmentation_models Unet with one tanh output channel.
    """
    # make segmentation_models build the network with tf.keras layers
    sm.set_framework('tf.keras')
    model = Unet("efficientnetb0", encoder_weights=None, activation="tanh",
                 input_shape=image_shape, classes=1)
    return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    """Chain generator and discriminator into the model used to train the generator.

    g_model     : generator, source image -> generated image
    d_model     : discriminator, [source, image] -> patch scores
    image_shape : shape of one source image
    Returns a compiled model mapping a source image to
    [discriminator output, generated image]. Its losses are binary
    cross-entropy (weight 10) and MSE (weight 100), with Adam lr=1e-4, beta_1=0.8.

    Side effect: sets trainable=False on every non-BatchNormalization layer of
    d_model.
    NOTE(review): in tf.keras a model keeps the trainable state it had at its own
    compile() time, so d_model (compiled earlier) should still train on its own
    while being frozen inside this model; confirm for the installed Keras version.
    """
    frozen_layers = (layer for layer in d_model.layers
                     if not isinstance(layer, BatchNormalization))
    for layer in frozen_layers:
        layer.trainable = False
    in_src = Input(shape=image_shape)
    gen_out = g_model(in_src)
    # the discriminator judges the generated image conditioned on the source
    dis_out = d_model([in_src, gen_out])
    model = Model(in_src, [dis_out, gen_out])
    optimizer = Adam(learning_rate=1e-4, beta_1=0.8)
    model.compile(loss=['binary_crossentropy', 'mse'], optimizer=optimizer, loss_weights=[10,100])
    return model


# select a batch of random samples, returns images and target
def generate_real_samples_(dataset, n_samples, patch_shape):
    """Draw a random batch of paired images with 'real' (1) discriminator labels.

    dataset     : [trainA, trainB] arrays indexed along axis 0
    n_samples   : batch size; indices are drawn with replacement
    patch_shape : side length of the discriminator's square patch output
    Returns ([X1, X2], y) with y = ones((n_samples, patch_shape, patch_shape, 1)).
    (Removed a leftover debug print of the index shape.)
    """
    trainA, trainB = dataset
    ix = randint(0, trainA.shape[0], n_samples)
    # the same indices keep source/target images paired
    X1, X2 = trainA[ix], trainB[ix]
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape,ix):
    """Gather the paired images at indices `ix` together with 'real' (1) labels.

    dataset     : [trainA, trainB] arrays indexed along axis 0
    n_samples   : number of labels to create (expected to equal len(ix))
    patch_shape : side length of the discriminator's square patch output
    ix          : indices chosen by the caller (e.g. one shuffled batch)
    Returns ([sources, targets], labels), where labels has shape
    (n_samples, patch_shape, patch_shape, 1) and is all ones.
    """
    sources, targets = dataset
    batch = [sources[ix], targets[ix]]
    labels = ones((n_samples, patch_shape, patch_shape, 1))
    return batch, labels

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    """Run the generator on `samples` and pair the outputs with 'fake' (0) labels.

    g_model     : any object with a predict(samples) method (the generator)
    samples     : batch of source images
    patch_shape : side length of the discriminator's square patch output
    Returns (fakes, labels), where labels has shape
    (len(fakes), patch_shape, patch_shape, 1) and is all zeros.
    """
    fakes = g_model.predict(samples)
    label_shape = (len(fakes),) + (patch_shape, patch_shape, 1)
    return fakes, zeros(label_shape)

def _shuffled_batches(n_samples, n_batch):
    """Return shuffled sample indices as an (n_samples // n_batch, n_batch) array.

    The last n_samples % n_batch shuffled indices are dropped so every batch is full.
    """
    n_full = n_samples // n_batch
    return np.random.permutation(n_samples)[:n_full * n_batch].reshape(n_full, n_batch)


def _show_sample(X_realA, X_fakeB, X_realB):
    """Show the first source, generated and target image of a batch (channel 0, gray)."""
    for images in (X_realA, X_fakeB, X_realB):
        plt.figure()
        plt.imshow(images[0, ..., 0], cmap="gray")
    plt.show()


# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=10, n_batch=16, save=False,
          gen_path="Networks/Pix2Pix_gen__8_low_mse",
          dis_path="Networks/Pix2Pix_det__8_low_mse",
          min_save_step=10):
    """Train the Pix2Pix discriminator and generator on shuffled mini-batches.

    Each epoch reshuffles the data and runs len(trainA) // n_batch full batches;
    leftover samples are skipped for that epoch. In each step:
      1. the discriminator is updated on one real batch and one generated batch;
      2. the generator is updated through gan_model.
    The losses of every step are printed.

    d_model, g_model, gan_model : discriminator, generator and combined model
                                  (as built by define_discriminator /
                                  define_generator / define_gan)
    dataset       : [trainA, trainB] source / target arrays of equal length
    n_epochs      : number of passes over the data
    n_batch       : batch size
    save          : if True, save both models and plot one sample every time the
                    total generator loss reaches a new minimum, provided the
                    0-based step index is greater than `min_save_step`
    gen_path, dis_path : destinations passed to Model.save
    min_save_step : see `save` (default 10, the previous hard-coded value)
    Raises ValueError if n_batch exceeds the number of samples, since an epoch
    would then contain no batch.
    """
    # the discriminator outputs a square grid of patch scores
    n_patch = d_model.output_shape[1]
    trainA, trainB = dataset
    n_samples = len(trainA)
    bat_per_epo = n_samples // n_batch
    if bat_per_epo == 0:
        raise ValueError("n_batch (%d) is larger than the dataset (%d samples)"
                         % (n_batch, n_samples))
    n_steps = bat_per_epo * n_epochs
    print(n_steps)
    gmin = float("inf")  # best total generator loss seen so far
    for epoch in range(n_epochs):
        for b, ix in enumerate(_shuffled_batches(n_samples, n_batch)):
            i = epoch * bat_per_epo + b  # 0-based global step
            [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch, ix)
            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
            # discriminator: real pairs -> 1, generated pairs -> 0
            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
            # generator: fool the discriminator and match the target image
            g_loss1, g_loss2, g_loss3 = gan_model.train_on_batch(X_realA, [y_real, X_realB])
            print('>%d, d1[%.3e] d2[%.3e] g[%.3e,%.3e,%.3e]' % (i+1, d_loss1, d_loss2, g_loss1, g_loss2, g_loss3))
            # Save whenever the total generator loss improves, but not during the
            # first `min_save_step` + 1 steps, when the loss is still unstable.
            if g_loss1 < gmin:
                gmin = g_loss1
                if save and i > min_save_step:
                    g_model.save(gen_path)
                    d_model.save(dis_path)
                    # display source / prediction / target
                    _show_sample(X_realA, X_fakeB, X_realB)

In [None]:
#importation of data

#Original Image
# Load the sham-detector stacks and build the input (X) / target (Y) arrays.
path="Detectors_sham/"
data=[imread(path+image) for image in ["stack6.lsm","stack7.lsm","stack8.lsm"]]
# Join all stacks along their first axis. Unlike indexing data[0], data[1],
# data[2] by hand, this keeps working if the file list above changes length.
data_concat = np.concatenate(data, axis=0)
# NOTE(review): index 2 on axis 1 is used as the network input and index 1 as
# the target; this assumes each stack is laid out (slice, channel, H, W). Confirm.
X = data_concat[:,2]
Y = data_concat[:,1]
# append a trailing channel axis, presumably (N, H, W) -> (N, H, W, 1)
X=np.expand_dims(X,axis=3)
Y=np.expand_dims(Y,axis=3)
x_train,y_train=X,Y
print(np.shape(x_train),np.shape(y_train))
# the trailing ';' stops the notebook from echoing gc.collect()'s return value
gc.collect();

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

#Procedure to cut the image into smaller pieces
class Tile:
    '''
    Cut an image into overlapping square tiles and stitch tiles back together.

    Usage    : Tile(image,tilesize,overlap,verbose)
    image    : Array of shape [height,width,canals]
    tilesize : The output will be of shape [tilesize,tilesize,canals]
               (default: a quarter of the image height)
    overlap  : Amount of pixel common between two consecutive tiles
               (default: a quarter of tilesize; odd values are reduced by one)
    verbose  : Print a message whenever the input has to be adjusted
    ================================================================
    Methods  :
    padding(mode="reflect"): Returns the image padded (numpy.pad) so the tiles cover it exactly
    tilegeneration()       : Returns the list of tiles (views into the padded image), row-major
    tilevis(concat=True)   : Plots the tiles list for the image
    detmask()              : Returns, per pixel, how many tiles cover it
    reconstruct(tl=None)   : Builds the image back out of a tile list (uint16)
    '''

    def __init__(self, image, tilesize=None, overlap=None, verbose=False):
        self.im = image
        if tilesize is None:
            tilesize = np.shape(image)[0] // 4
        self.t = tilesize
        if overlap is None:
            overlap = tilesize // 4
        self.o = overlap
        self.verbose = verbose
        # verdata() may trim odd image dimensions and adjust the overlap, so the
        # image shape and everything derived from it are computed afterwards.
        # Previously self.s kept the untrimmed shape, which gave truncated tiles
        # for odd-sized images.
        self.verdata()
        self.s = np.shape(self.im)[:2]
        # stride between the starts of two consecutive tiles
        self.to = self.t - self.o
        # number of tiles along each axis (detpad may reduce it by one)
        self.nbt = [(self.s[0] - self.o) // self.to + 1, (self.s[1] - self.o) // self.to + 1]
        self.p = None    # total padding per axis, computed lazily by detpad()
        self.pIm = None  # padded image, computed lazily by padding()
        self.tl = None   # tile list cached by tilevis()

    def verdata(self):
        """Make both image dimensions and the overlap even, warning if verbose."""
        if (np.shape(self.im)[0] % 2) != 0:
            if self.verbose: print("Changing image to an even dimension value, one pixel was removed on axis zero.")
            self.im = self.im[:-1]
        if (np.shape(self.im)[1] % 2) != 0:
            if self.verbose: print("Changing image to an even dimension value, one pixel was removed on axis one.")
            self.im = self.im[:, :-1]
        if self.o % 2 != 0:
            if self.verbose: print("Please use even value for overlap, overlap was reduced by one.")
            self.o -= 1
        if self.t > np.shape(self.im)[0] or self.t > np.shape(self.im)[1]:
            if self.verbose: print("You probably didn't do what you want, you will just have one padded image. (Tile size larger than image)")

    def detpad(self):
        """Return the total padding per axis needed for the tiles to cover the image exactly.

        If the padding would be exactly one stride, the image is already covered:
        the padding becomes 0 and one tile fewer is used on that axis (self.nbt
        is updated in place).
        """
        psize = [(self.nbt[k] * self.to + self.o) - self.s[k] for k in range(2)]
        for k in range(2):
            if psize[k] == self.to:
                psize[k] = 0
                self.nbt[k] -= 1
        return psize

    def padding(self, mode="reflect"):
        """Return the image padded with numpy.pad (given mode) to nbt*to+o pixels per axis."""
        if self.p is None:
            self.p = self.detpad()
        before = (self.p[0] // 2, self.p[1] // 2)
        # An odd total padding puts the extra pixel after the image. The previous
        # symmetric int(p/2) split lost that pixel and truncated the last tiles.
        widths = ((before[0], self.p[0] - before[0]),
                  (before[1], self.p[1] - before[1]),
                  (0, 0))
        return np.pad(self.im, widths, mode=mode)

    def _crop(self, arr):
        """Cut the padding margins off an array laid out like the padded image."""
        r0, c0 = self.p[0] // 2, self.p[1] // 2
        return arr[r0:r0 + self.s[0], c0:c0 + self.s[1]]

    def tilegeneration(self):
        """Return the tiles in row-major order as views into the padded image."""
        if self.pIm is None:
            self.pIm = self.padding()
        tiles = []
        for i in range(self.nbt[0]):
            for j in range(self.nbt[1]):
                tiles.append(self.pIm[i*self.to:i*self.to+self.t, j*self.to:j*self.to+self.t])
        return tiles

    def tilevis(self, concat=True):
        """Plot the tiles, as one mosaic (concat=True) or as a grid of subplots."""
        if self.tl is None:
            self.tl = self.tilegeneration()
        n_rows, n_cols = self.nbt
        plt.figure(figsize=(3*n_cols, 3*n_rows))
        if not concat:
            for k, tile in enumerate(self.tl):
                ax = plt.subplot(n_rows, n_cols, k + 1)
                ax.axis("off")
                # imshow rejects (h, w, 1) arrays, so drop a singleton channel axis
                plt.imshow(tile[..., 0] if tile.shape[-1] == 1 else tile)
        else:
            rows = [np.concatenate(self.tl[i*n_cols:(i+1)*n_cols], axis=1) for i in range(n_rows)]
            plt.imshow(np.concatenate(rows, axis=0)[..., 0])
        # (the former trailing `del im` raised NameError when concat=False)
        plt.show()

    def detmask(self):
        """Return an array of the image's height x width counting the tiles that cover each pixel."""
        if self.pIm is None:
            self.pIm = self.padding()
        mask = np.zeros(np.shape(self.pIm)[:2])
        for i in range(self.nbt[0]):
            for j in range(self.nbt[1]):
                mask[i*self.to:i*self.to+self.t, j*self.to:j*self.to+self.t] += 1
        return self._crop(mask)

    def reconstruct(self, tl=None):
        """Stitch tiles (default: this image's own tiles) back into an image of the original size.

        Every pixel is taken from exactly one tile: the central part of each
        tile (overlap/2 trimmed on each side), with tiles on the border of the
        grid extended to their outer edge. Previously those outer margins stayed
        zero. If the stitched image is <= 1 everywhere, it is treated as
        normalised and multiplied by 65536. The result is clipped to [0, 65535]
        (1.0 used to wrap around to 0) and returned as uint16.
        """
        if tl is None:
            tl = self.tilegeneration()
        if self.pIm is None:
            # tiles were supplied from outside; padding() also sets self.p and self.nbt
            self.pIm = self.padding()
        reim = np.zeros(np.shape(self.pIm))
        lb = self.o // 2
        rb = self.t - lb
        n_rows, n_cols = self.nbt
        for i in range(n_rows):
            r0 = 0 if i == 0 else lb
            r1 = self.t if i == n_rows - 1 else rb
            for j in range(n_cols):
                c0 = 0 if j == 0 else lb
                c1 = self.t if j == n_cols - 1 else rb
                reim[i*self.to+r0:i*self.to+r1, j*self.to+c0:j*self.to+c1] += tl[i*n_cols+j][r0:r1, c0:c1]
        reim = self._crop(reim)
        if np.max(reim) <= 1:
            reim = reim * 65536
        return np.clip(reim, 0, 65535).astype("uint16")


In [None]:
"""Image Processing for network training"""
#calling of the tiling procedure and defining a function to apply on a whole tensor
def tiling(dataset,tilesize=256,overlap=64):
    """Tile every image of a stack and return all tiles as one array.

    dataset  : stack of images, each [height, width, channels] (see Tile)
    tilesize : side length of the square tiles
    overlap  : pixels shared by consecutive tiles
    Returns an array of shape (total_tiles, tilesize, tilesize, channels), in
    the order Tile.tilegeneration() yields them, image by image.
    An empty stack (e.g. shape (0, H, W, C)) yields zero tiles. The previous
    version crashed on `del tempi`, which was never assigned in that case.
    """
    tiles = [tile
             for image in dataset
             for tile in Tile(image, tilesize, overlap).tilegeneration()]
    return np.reshape(tiles, (-1, tilesize, tilesize, np.shape(dataset)[-1]))

#applying the tiling
# NOTE: this rebinds x_train / y_train to their tiled versions, so re-running this
# cell without first re-running the data-loading cell would tile already-tiled data.
x_train,y_train=tiling(x_train),tiling(y_train)

In [None]:
# Additional training data: every file in the BioSamples acquisition folder.
path = "2022-04-21_BioSamples_Acquisition/"
# sorted() gives a deterministic file order (os.listdir order is arbitrary)
data = [imread(path + im) for im in sorted(listdir(path))]
# Concatenate all stacks along the first axis. The previous hard-coded
# data[0]..data[6] raised IndexError with fewer than 7 files and silently ignored
# any extra ones.
data_concat = np.concatenate(data, axis=0)
# NOTE(review): same channel convention as the first data cell (2 = input, 1 = target)
X = data_concat[:, 2]
Y = data_concat[:, 1]
X = np.expand_dims(X, axis=3)
Y = np.expand_dims(Y, axis=3)
# Append the tiled images to the existing training set (re-running appends again).
x_train, y_train = np.concatenate((x_train, tiling(X)), axis=0), np.concatenate((y_train, tiling(Y)), axis=0)

In [None]:
# Additional training data from the Fibers acquisition folder.
path = "2022-07-08_Fibers/"
# NOTE(review): listdir order is arbitrary, and any non-image file in the folder would also be read.
data=[imread(path+im) for im in listdir(path)]
# np.copy on a list STACKS the files into one new array with a leading file axis,
# whereas the previous data cells concatenate along the existing first axis.
# NOTE(review): presumably each file here is a single plane with channels on the
# first axis, so that [:, 2] below still selects a channel. Confirm.
data_concat = np.copy(data)
X = data_concat[:,2]
Y = data_concat[:,1]
X=np.expand_dims(X,axis=3)
Y=np.expand_dims(Y,axis=3)
# Append the tiled images to the existing training set (re-running appends again).
x_train,y_train=np.concatenate((x_train,tiling(X)),axis=0),np.concatenate((y_train,tiling(Y)),axis=0)

In [None]:
# Rescale intensities with x / 32768 - 1, i.e. [0, 65535] -> [-1, 1), the range of
# the generator's tanh output.
# NOTE(review): assumes x_train / y_train hold uint16 intensities; confirm.
x_train_processed = (x_train / 32768 - 1).astype("float32")
y_train_processed = (y_train / 32768 - 1).astype("float32")
dataset = [x_train_processed, y_train_processed]
# shape of one image: (height, width, channels)
image_shape = x_train_processed.shape[1:]

In [None]:
# Same output as np.shape(dataset), but without np.shape(list), which first stacks
# both arrays into a new full-size copy of the dataset just to read its shape.
print((len(dataset),) + np.shape(dataset[0]))

In [None]:
"""Network creating (brand new or loaded)"""
new=False  # True: build fresh, randomly initialised networks; False: reload saved ones
if new:
    d_model = define_discriminator(image_shape)
    g_model = define_generator(image_shape)
else:
    g_model = tf.keras.models.load_model("Networks/Pix2Pix_gen__7_low_MSE")
    d_model = tf.keras.models.load_model("Networks/Pix2Pix_det__7_low_MSE")
# In both cases the combined model is rebuilt, so the generator is trained
# through the discriminator with its non-BatchNorm layers frozen.
gan_model = define_gan(g_model, d_model, image_shape)

In [None]:
"""Network training, save=True => Auto saving on drive"""
# training loop; save=True writes both models whenever the total generator loss improves
train(
    d_model,
    g_model,
    gan_model,
    dataset,
    n_epochs=100,
    n_batch=2048,
    save=True,
)