In [None]:
#%% Preamble
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import numpy as np
import scipy.io as IO
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.initializers import RandomNormal
from scipy.signal import convolve
from tensorflow_addons.layers import InstanceNormalization
import os
import sys
sys.path.append("deeptrack/")
import deeptrack as dt
import deeptrack.models
from deeptrack.features import Feature
import warnings
from deeptrack.losses import weighted_crossentropy
import keras
# Keras backend alias used by the custom losses (K.sum, K.log, K.sqrt, ...).
# Taken from tf.keras so it matches the tf.keras losses/optimizers/layers used
# everywhere else in this notebook; the bare `import keras` above can resolve
# to a different Keras distribution than tf.keras (e.g. Keras 3 vs tf_keras).
K = tf.keras.backend
tf.keras.backend.clear_session()

In [None]:

#%% Define cGAN model
baseSize = 8  # default 8; number of filters in the first U-Net level
DSFac = 5  # default 5; number of entries in conv_layers_dimensions

# Filter count per U-Net level, doubling at each level: (8, 16, 32, 64, 128)
conv_layers_dimensions = tuple(baseSize * 2**level for level in range(DSFac))

def init_model(discriminator_input_shape=None):
    """Build the (discriminator, generator) pair used by the cGAN.

    The generator is a DeepTrack U-Net taking a single-channel image of any
    size and producing a sigmoid output. The discriminator is a DeepTrack
    convolutional network with two inputs of identical shape and one output
    neuron.

    Parameters
    ----------
    discriminator_input_shape : tuple, optional
        Shape of each of the two discriminator inputs. Defaults to
        ``(reduced_times, reduced_length, 1)``, read from the module-level
        globals at call time (they are defined in a later cell).

    Returns
    -------
    tuple
        ``(discriminator, generator)``.
    """
    if discriminator_input_shape is None:
        discriminator_input_shape = (reduced_times, reduced_length, 1)

    activation = lambda x: layers.LeakyReLU(0.2)(x)

    # unet generator
    generator = dt.models.unet(
        (None, None, 1),
        conv_layers_dimensions=conv_layers_dimensions,
        steps_per_pooling=2,
        output_activation="sigmoid",
    )

    def discriminator_instance_norm(x):
        # Disable normalization when x == 16 (the first entry of the
        # discriminator's conv_layers_dimensions below); presumably x is the
        # layer's filter count as passed by dt.layers -- TODO confirm.
        if x == 16:
            return False
        return {"axis": -1, "center": False, "scale": False}

    discriminator_convolution_block = dt.layers.ConvolutionalBlock(
        kernel_size=(4, 4),
        strides=1,
        activation=activation,
        instance_norm=discriminator_instance_norm,
    )

    discriminator_pooling_block = dt.layers.ConvolutionalBlock(
        kernel_size=(4, 4),
        strides=2,
        activation=activation,
        instance_norm={"axis": -1, "center": False, "scale": False},
    )

    discriminator = dt.models.convolutional(
        input_shape=[discriminator_input_shape] * 2,
        dropout=[0, 0],  # default 0.2
        conv_layers_dimensions=(16, 32, 64, 128, 256),  # features per convolutional layer
        dense_layers_dimensions=(32, 32),  # neurons per dense layer
        number_of_outputs=1,  # neurons in the final dense step (number of output values)
        compile=False,
        output_kernel_size=4,
        convolution_block=discriminator_convolution_block,
        pooling_block=discriminator_pooling_block,
    )
    return discriminator, generator

def _compile(
    model: models.Model, *, loss="mae", optimizer="adam", metrics=[], **kwargs
):
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    return model


class Model(Feature):
    """DeepTrack ``Feature`` wrapping a Keras model.

    Attribute lookups that the Feature does not resolve fall through to the
    wrapped model, so e.g. ``wrapper.predict`` / ``wrapper.load_weights``
    reach the Keras model.
    """

    def __init__(self, model, **kwargs):
        self.model = model
        super().__init__(**kwargs)

    def __getattr__(self, key):
        # Allows access to the model methods and properties.
        # __getattr__ only runs when normal lookup failed. If "model" itself
        # is missing (e.g. an instance created by copy/pickle before its
        # state is restored), delegating would evaluate self.model, which
        # re-enters __getattr__("model") forever; fail with AttributeError.
        if key == "model":
            raise AttributeError(key)
        try:
            return getattr(super(), key)
        except AttributeError:
            return getattr(self.model, key)


class cgan(Model):
    """Conditional GAN built from a generator and a single discriminator.

    The discriminator is compiled on its own. An ``assemble`` model stacks
    generator -> discriminator (discriminator frozen via ``trainable=False``)
    with outputs ``[validity, generated image]`` and is compiled to train the
    generator. Attributes not found on the wrapper fall through to the
    generator (see ``Model``).
    """

    def __init__(
        self,
        generator=None,
        discriminator=None,
        discriminator_loss=None,
        discriminator_optimizer=None,
        discriminator_metrics=None,
        assemble_loss=None,
        assemble_optimizer=None,
        assemble_loss_weights=None,
        **kwargs
    ):

        # Build and compile the discriminator
        self.discriminator = discriminator
        self.discriminator.compile(
            loss=discriminator_loss,
            optimizer=discriminator_optimizer,
            metrics=discriminator_metrics,
        )

        # Build the generator
        self.generator = generator

        # Input shape
        self.model_input = self.generator.input

        # The generator takes model_input and generates img
        img = self.generator(self.model_input)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes the generated image and the condition
        # (the generator input) and determines validity
        validity = self.discriminator([img, self.model_input])

        # The assembled model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.assemble = models.Model(self.model_input, [validity, img])
        self.assemble.compile(
            loss=assemble_loss,
            optimizer=assemble_optimizer,
            loss_weights=assemble_loss_weights,
        )

        super().__init__(self.generator, **kwargs)

    def fit(self, data_generator, epochs, steps_per_epoch=None, **kwargs):
        """Alternate discriminator and generator updates.

        Parameters
        ----------
        data_generator : iterator or indexable
            Each step uses ``next(data_generator)``; objects that are not
            iterators (e.g. a keras ``Sequence``) are indexed with
            ``data_generator[step]`` instead. Must yield ``(data, labels)``.
            ``on_epoch_end()`` is called after each epoch if present.
        epochs : int
        steps_per_epoch : int, optional
            Defaults to ``len(data_generator)``.
        **kwargs
            Ignored; a warning is emitted for each key.

        Returns
        -------
        np.ndarray, shape (4, epochs)
            Per-epoch sums over the steps of: discriminator loss, generator
            total loss, and the assemble model's two per-output losses.

        Raises
        ------
        ValueError
            If ``epochs > 0`` and ``steps_per_epoch < 1``.
        """
        for key in kwargs.keys():
            warnings.warn(
                "{0} not implemented for cgan. Does not affect the execution.".format(
                    key
                )
            )
        if steps_per_epoch is None:
            steps_per_epoch = len(data_generator)
        if epochs > 0 and steps_per_epoch < 1:
            raise ValueError(
                "steps_per_epoch must be at least 1, got {}".format(steps_per_epoch)
            )

        # The generator is trained on every step; the flag only selects the
        # report format below.
        train_assembler = 1
        history = np.zeros((4, epochs))
        for epoch in range(epochs):
            steps = steps_per_epoch

            d_loss = 0
            g_loss = 0

            for step in range(steps):
                ## update data
                try:
                    data, labels = next(data_generator)
                except (TypeError, StopIteration):
                    # Not an iterator (e.g. a keras Sequence): index instead.
                    data, labels = data_generator[step]

                # Grab discriminator labels
                shape = (data.shape[0], *self.discriminator.output.shape[1:])
                valid, fake = np.ones(shape), np.zeros(shape)

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Generate a batch of new images
                gen_imgs = self.generator(data)

                d_loss_real = self.discriminator.train_on_batch([labels[..., 0:1], data], valid)
                d_loss_fake = self.discriminator.train_on_batch([gen_imgs, data], fake)
                d_loss += 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------

                # Train the generator (to have the discriminator label samples as valid)
                if train_assembler:
                    g_loss += np.array(self.assemble.train_on_batch(data, [valid, labels]))

            on_epoch_end = getattr(data_generator, "on_epoch_end", None)
            if callable(on_epoch_end):
                on_epoch_end()

            history[0, epoch] = d_loss[0]
            history[1, epoch] = g_loss[0]
            history[2, epoch] = g_loss[1]
            history[3, epoch] = g_loss[2]
            if train_assembler:
                print(
                    "%d [D loss: %f, acc.: %.2f%%] [G loss: %f, %f, %f]"
                    % (
                        epoch,
                        d_loss[0] / steps,
                        100 * d_loss[1] / steps,
                        g_loss[0] / steps,
                        g_loss[1] / steps,
                        g_loss[2] / steps,
                    )
                )
            else:
                print(
                    "%d [D loss: %f, acc.: %.2f%%]"
                    % (
                        epoch,
                        d_loss[0] / steps,
                        100 * d_loss[1] / steps,
                    )
                )
        return history
    

class DualDiscriminatorGAN(Model):
    """GAN whose generator is trained against a list of discriminators.

    Each discriminator sees an image concatenated (last axis) with a
    condition image. During ``fit`` a discriminator whose accuracy rises
    above ``upper_threshold`` is only evaluated, not updated, until its
    accuracy falls below ``upper_threshold`` again. The generator step is
    skipped whenever any discriminator's accuracy exceeds
    ``lower_threshold``.
    """

    def __init__(
        self,
        generator=None,
        discriminator=None,
        discriminator_loss=None,
        discriminator_optimizer=None,
        discriminator_metrics=None,
        assemble_loss=None,
        assemble_optimizer=None,
        assemble_loss_weights=None,
        upper_threshold=0.95,
        lower_threshold=0.51,
        **kwargs
    ):
        self.upper_threshold = upper_threshold
        self.lower_threshold = lower_threshold

        # Build and compile each discriminator with its own loss/optimizer/metrics
        self.discriminator = discriminator
        for i in range(len(discriminator)):
            self.discriminator[i].compile(
                loss=discriminator_loss[i],
                optimizer=discriminator_optimizer[i],
                metrics=discriminator_metrics[i],
            )

        # Build the generator
        self.generator = generator

        # Input shape
        self.model_input = self.generator.input

        # Condition image, concatenated with the generator output before
        # it is fed to the discriminators
        self.condition = tf.keras.layers.Input(shape=self.generator.output.shape[1:])
        # The generator takes model_input and generates img
        img = self.generator(self.model_input)

        # For the combined model we will only train the generator
        for i in range(len(discriminator)):
            self.discriminator[i].trainable = False

        # Each discriminator judges [generated image | condition]
        validity = [
            self.discriminator[i](tf.keras.layers.Concatenate()([img, self.condition]))
            for i in range(len(discriminator))
        ]
        # Outputs: one validity per discriminator, then the generated image
        assembleoutput = validity + [img]
        # The assembled model (stacked generator and discriminators)
        # trains the generator to fool the discriminators
        self.assemble = models.Model([self.model_input, self.condition], assembleoutput)
        self.assemble.compile(
            loss=assemble_loss,
            optimizer=assemble_optimizer,
            loss_weights=assemble_loss_weights,
        )

        super().__init__(self.generator, **kwargs)

    def fit(self, data_generator, epochs, steps_per_epoch=None, **kwargs):
        """Alternate discriminator and (gated) generator updates.

        Parameters
        ----------
        data_generator : iterator or indexable
            Each step uses ``next(data_generator)``; objects that are not
            iterators are indexed with ``data_generator[step]``. Must yield
            ``(data, labels)`` where ``labels[:, 0]`` is the target image and
            ``labels[:, 1]`` the condition image.
        epochs : int
        steps_per_epoch : int, optional
            Defaults to ``len(data_generator)``.
        **kwargs
            Ignored; a warning is emitted for each key.

        Returns
        -------
        np.ndarray, shape (4, epochs)
            Row 0: per-epoch sum over steps of the mean discriminator loss.
            Rows 1-3: first three entries of the summed assemble losses, or 0
            if the generator was never trained in that epoch.

        Raises
        ------
        ValueError
            If ``epochs > 0`` and ``steps_per_epoch < 1``.
        """
        for key in kwargs.keys():
            warnings.warn(
                "{0} not implemented for DualDiscriminatorGAN. Does not affect the execution.".format(
                    key
                )
            )
        if steps_per_epoch is None:
            steps_per_epoch = len(data_generator)
        if epochs > 0 and steps_per_epoch < 1:
            raise ValueError(
                "steps_per_epoch must be at least 1, got {}".format(steps_per_epoch)
            )

        n_disc = len(self.discriminator)
        # 1: discriminator i is updated on the next step; 0: only evaluated.
        train_discriminator = np.ones((n_disc))
        history = np.zeros((4, epochs))
        for epoch in range(epochs):
            steps = steps_per_epoch

            d_loss = 0
            g_loss = 0
            train_assembler = 1

            for step in range(steps):
                d_loss_all = []
                ## update data
                try:
                    data, labels = next(data_generator)
                except (TypeError, StopIteration):
                    # Not an iterator (e.g. a keras Sequence): index instead.
                    data, labels = data_generator[step]
                train_assembler = 1

                gen_imgs = self.generator(data)
                n = data.shape[0]

                # First half: ground truth; second half: generated images.
                all_data = np.zeros((2 * gen_imgs.shape[0], *gen_imgs.shape[1:]))
                all_data[:n] = labels[:, 0]
                all_data[n:] = gen_imgs
                # Both halves are paired with the same condition image.
                all_data_2 = np.zeros((2 * gen_imgs.shape[0], *gen_imgs.shape[1:]))
                all_data_2[:n] = labels[:, 1]
                all_data_2[n:] = labels[:, 1]

                discriminator_input = np.concatenate((all_data, all_data_2), axis=-1)

                # Grab discriminator labels
                vtot = []
                for i in range(n_disc):
                    shape = (n, *self.discriminator[i].output.shape[1:])
                    valid, fake = np.ones(shape), np.zeros(shape)
                    vtot.append(valid)
                    validfake = np.concatenate((valid, fake), axis=0)
                    if train_discriminator[i]:
                        result = self.discriminator[i].train_on_batch(discriminator_input, validfake)
                        # Accurate enough: stop updating it for now.
                        if result[1] > self.upper_threshold:
                            train_discriminator[i] = 0
                    else:
                        # Frozen: measure accuracy without updating weights,
                        # and resume training once it drops.
                        result = self.discriminator[i].test_on_batch(discriminator_input, validfake)
                        if result[1] < self.upper_threshold:
                            train_discriminator[i] = 1
                    d_loss_all.append(result)

                # Skip the generator step while any discriminator is too weak.
                if np.max(np.array(d_loss_all)[:, 1]) > self.lower_threshold:
                    train_assembler = 0
                if train_assembler:
                    vtot.append(labels[:, 0])
                    g_loss += np.array(self.assemble.train_on_batch([data, labels[:, 1]], vtot))

                # Accumulate like g_loss, so dividing by `steps` below averages.
                d_loss += np.mean(np.array(d_loss_all), axis=0)

            on_epoch_end = getattr(data_generator, "on_epoch_end", None)
            if callable(on_epoch_end):
                on_epoch_end()

            history[0, epoch] = d_loss[0]
            if isinstance(g_loss, np.ndarray) and g_loss.size >= 3:
                history[1:, epoch] = g_loss[:3]
            else:
                history[1:, epoch] = 0
            if train_assembler:
                print(
                    "%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                    % (
                        epoch,
                        d_loss[0] / steps,
                        100 * d_loss[1] / steps,
                        g_loss[0] / steps
                    )
                )
            else:
                print(
                    "%d [D loss: %f, acc.: %.2f%%]"
                    % (
                        epoch,
                        d_loss[0] / steps,
                        100 * d_loss[1] / steps,
                    )
                )
        return history
    
def reset_model(d_lr, g_lr, d_loss, g_loss, g_loss_weights):
    """Build a fresh generator/discriminator pair and wrap it in a ``cgan``.

    Parameters
    ----------
    d_lr, g_lr : float
        Adam learning rates (beta_1 = 0.5) for the discriminator and for the
        assembled generator-training model.
    d_loss
        Discriminator loss passed to ``compile``.
    g_loss, g_loss_weights
        Losses and loss weights for the assembled model's outputs
        ``[validity, generated image]``.

    Returns
    -------
    cgan
    """
    # Imported here so this cell does not rely on a later cell having run.
    # `learning_rate` is the supported keyword; `lr` is deprecated/removed
    # in newer Keras versions.
    from tensorflow.keras.optimizers import Adam

    discriminator, generator = init_model()
    model = cgan(
        generator=generator,
        discriminator=discriminator,
        discriminator_loss=d_loss,
        discriminator_optimizer=Adam(learning_rate=d_lr, beta_1=0.5),
        discriminator_metrics="accuracy",
        assemble_loss=g_loss,
        assemble_optimizer=Adam(learning_rate=g_lr, beta_1=0.5),
        assemble_loss_weights=g_loss_weights,
    )
    return model


In [None]:
#%% Losses


# Reusable Keras loss objects, called as mae(y_true, y_pred) / mse(y_true, y_pred).
mae, mse = tf.keras.losses.MeanAbsoluteError(), tf.keras.losses.MeanSquaredError()

def unet_crossentropy(T, P, eps=1e-3):
    """Binary cross-entropy between target ``T`` and prediction ``P``.

    Returns ``-mean(T * log(P + eps) + (1 - T) * log(1 - P + eps))``. Both
    classes have weight 1. The earlier class-weight computation was
    overwritten with 1 and has been removed, so the result is unchanged.
    ``eps`` keeps the logarithms finite when ``P`` reaches 0 or 1.
    """
    return -K.mean(T * K.log(P + eps) + (1 - T) * K.log(1 - P + eps))

    
from tensorflow.keras.optimizers import Adam    

def mae_loss(T, P):
    """Mean absolute error between channel 0 of ``T`` and the prediction ``P``."""
    return mae(T[..., 0], P)

def combined_loss(T, P):
    """Fourth root of the MAE between channel 0 of ``T`` and ``P``.

    Only ``T[..., 0]`` is read, so single-channel targets are accepted. The
    unused ``T[:, 0, 0, 1]`` lookup (which needs at least two channels) and
    the unused trajectory term have been removed.
    """
    segment_mae = mae(T[..., 0], P)
    return K.sqrt(K.sqrt(segment_mae))

def mae_crossentropy_loss(T, P):
    """``unet_crossentropy(T, P) + 0.5 * sqrt(MAE(T, P))``.

    The unused ``mae(T, T * P)`` term has been dropped.
    """
    crossentropy = unet_crossentropy(T, P)
    root_mae = K.sqrt(mae(T, P))
    return crossentropy + 0.5 * root_mae

def convloss(T, P):
    """Fourth root of the mean absolute error: sqrt(sqrt(MAE(T, P)))."""
    mean_abs_error = mae(T, P)
    return K.sqrt(K.sqrt(mean_abs_error))



In [None]:
from utils.ParticleGenerator import GenNoise, Trajectory, init_particle_counter,input_array, PostProcess

In [None]:
import skimage.measure
def _reduce_channel(image, channel):
    """Take channel ``channel`` of ``image`` (keeping the axis) and
    block-average it by (T_reduction_factor, L_reduction_factor) over the
    first two axes."""
    single_channel = image[..., channel:channel + 1]
    block = (T_reduction_factor, L_reduction_factor, 1)
    return skimage.measure.block_reduce(single_channel, block, np.mean)


def batch_function(image):
    """Network input: channel 0 of ``image``, downsampled."""
    return _reduce_channel(image, 0)


def label_function(image):
    """Training target: channel 1 of ``image``, downsampled."""
    return _reduce_channel(image, 1)
    
def generate_training_batch(image, batch_size):
    """Produce one ``(inputs, labels)`` batch from the ``image`` pipeline.

    A DeepTrack ContinuousGenerator is created with data-buffer bounds
    ``batch_size`` .. ``batch_size + 1``. Its context is entered and exited
    once, and then ``data_generator[0]`` is returned.

    Returns
    -------
    tuple
        ``(b, L)`` as returned by ``data_generator[0]``.
    """
    data_generator = dt.generators.ContinuousGenerator(
        image,
        batch_function=batch_function,
        label_function=label_function,
        batch_size=batch_size,
        min_data_size=batch_size,
        max_data_size=batch_size + 1,
    )

    # Entering/exiting the context is what drives the generator (presumably
    # it fills the buffer up to min_data_size -- TODO confirm in DeepTrack).
    # Nothing needs to run inside; the former GAN.fit(..., epochs=0) call
    # returned immediately and only required a global GAN to exist.
    with data_generator:
        pass

    b, L = data_generator[0]
    return b, L

In [None]:
#%% Train cGAN

import datetime

# --- Run switches ---
TRAIN = 1  # run the cGAN training loop cell when truthy
lowiOC=1  # GenerateValData: choose the low-iOC validation/ground-truth paths
weightedLoss=1  # read in the model-setup cell when choosing the generator loss g_loss
reduceInt = 0  # curriculum: lower getTrainTraj.I[1] when the loss improves
increaseDiff = 0  # curriculum: grow getTrainTraj.D[0] by 1% per training loop (capped at 1.25)
use_val_data = 0  # evaluate/fine-tune on the simulated validation files each loop

# --- ContinuousGenerator sizes (batch size and data-buffer bounds) ---
batch_size = 4
min_data_size = 16
max_data_size = 32

DEBUG = 0
if DEBUG:
    # Minimal sizes for quick debugging runs
    batch_size = 1
    min_data_size = 1
    max_data_size = 2

# --- GAN checkpoint handling ---
reset_GAN =1  # build fresh GAN/best_model via reset_model
load_GAN = 1  # load GAN_loadname into both models
save_GAN = 0  # NOTE(review): not read in the visible cells
val_freq_GAN = 20  # NOTE(review): not read in the visible cells

GAN_loadname='Network-weights/GAN-D0.1-2I0.0-1.0loss=0.0030516684.h5'

load_unet = 1  # load unet_loadname into the generators; also seeds currentBestLoss
load_discriminator = 0  # requires `disc_loadname` to be defined
save_unet = 0  # NOTE(review): not read in the visible cells

unet_loadname='Network-weights/U-net-D0.1-2I0.0-1.0loss=0.0030516684.h5'

pathToEmpty=None  # optional folder with empty kymographs for GenerateValData
steps_per_epoch = 1  # NOTE(review): the training loop passes steps_per_epoch=1 literally
nbr_GAN_loops = 50000
nbr_GAN_epochs_per_loop = 500

# --- Simulated kymograph size and downsampling (see batch_function/label_function) ---
length = 128*4
L_reduction_factor = 4
reduced_length = int(length/L_reduction_factor)  # spatial size seen by the networks

times = 2048
T_reduction_factor = 1
reduced_times = int(times/T_reduction_factor)  # temporal size seen by the networks

    
max_nbr_particles = 8  # NOTE(review): not read in the visible cells
nump = lambda: 1+np.random.randint(3)  # random particle count in {1, 2, 3} per image

# --- Trajectory parameters passed to Trajectory(...) ---
Int = [5,6]  # intensity; Int[0] is also the threshold in GenerateValData
D1 = 0.1 #20
D2 = 1.2 #80

Ds=[D1,D2]
st = [0.04,0.05]   
vel=lambda: (50000*np.random.rand())*10**-6  # NOTE(review): not read in the visible cells

getTrainTraj = Trajectory(intensity=Int,s=st,diffusion=Ds)

#Normal image
# Training pipeline composed from DeepTrack features: blank input array,
# particle counter, random noise, nump() trajectories (via **), then
# post-processing, wrapped in dt.FlipLR/dt.FlipUD.
image=dt.FlipLR(dt.FlipUD(input_array(times=times,length=length) + init_particle_counter() 
                        + GenNoise(dX=lambda:.00001+.00003*np.random.rand(),
                                                dA=0,#lambda:0+np.random.rand()*0.0001,
                                                noise_lev=lambda:.0001,
                                                biglam=lambda:0.6+.4*np.random.rand(),
                                                bgnoiseCval=lambda:0.03+.02*np.random.rand(),
                                                bgnoise=lambda:.08+.04*np.random.rand(),
                                                bigx0=lambda: .1*np.random.randn(),
                                                sinus_noise_amplitude=lambda: np.random.rand(),
                                                freq =lambda: np.random.rand()*np.pi)
                                    + getTrainTraj**nump
                                    + PostProcess()))



if reset_GAN:
    d_lr = 1e-5
    g_lr = 1e-5
    d_loss = convloss
    d_loss_2 = "mse"
    # Losses for the assembled model's outputs [validity, generated image].
    # `weightedLoss` selects the image term; set weightedLoss = 0 to use
    # mae_crossentropy_loss. (Previously the weighted choice was always
    # overwritten, and a trailing comma wrapped the list in a 1-tuple.)
    if weightedLoss:
        g_loss = ["mae", weighted_crossentropy(10, 1)]
    else:
        g_loss = ["mae", mae_crossentropy_loss]
    g_loss_weights = [1, 1]
    GAN = reset_model(d_lr, g_lr, d_loss, g_loss, g_loss_weights)
    best_model = reset_model(d_lr, g_lr, d_loss, g_loss, g_loss_weights)
    best_net = '../input/gan-weights-11-nov-1115h5/weights_11_nov_1115.h5'
if load_GAN:
    GAN.load_weights(GAN_loadname)
    best_model.load_weights(GAN_loadname)
if load_unet:
    GAN.generator.load_weights(unet_loadname)
    best_model.generator.load_weights(unet_loadname)
if load_discriminator:
    # NOTE(review): `disc_loadname` is not defined in the visible notebook;
    # define it before enabling load_discriminator.
    GAN.discriminator.load_weights(disc_loadname)

### --- Data generator --- ###
data_generator = dt.generators.ContinuousGenerator(
    image,
    batch_function=batch_function,
    label_function=label_function,
    batch_size=batch_size,
    min_data_size=min_data_size,
    max_data_size=max_data_size,
)

def GenerateValData(pathToEmpty=None):
    """Load simulated validation kymographs and their ground-truth labels.

    A file in ``val_path`` is used when its name encodes an intensity (the
    text between "iOC" and "_M", times 10000) of at least ``Int[0]``. It is
    loaded with a trailing channel axis and cropped by ``TDiff``/``LDiff``
    (scaled by the module-level reduction factors). The matching trajectory
    is block-reduced, rescaled to ``abs((traj - 1) / min(traj - 1))``,
    cropped the same way, and finally divided by its maximum over time
    (axis 1), with zero maxima replaced by 1. Kymographs summing to zero
    are dropped.

    Parameters
    ----------
    pathToEmpty : str, optional
        Folder containing a ``diffusion/`` subfolder of empty kymographs.
        These are appended with all-zero labels.

    Returns
    -------
    tuple
        ``(valImgs, valLabels)``, stacked along a new leading axis.

    Raises
    ------
    ValueError
        If no validation file passes the intensity threshold.
    """
    # Set the paths for the data and the ground truth
    if lowiOC:
        # Other simulated sets used before (kept for reference):
        #   '../Data/Preprocessed chBInsulin 2x1 Simulated Data/lowiOC/diffusion/'
        #   '../Data/Preprocessed DNA lowiOC Simulated Data/lowiOC/diffusion/'
        val_path = '../Data/Preprocessed lowiOC Simulated Data/lowiOC/diffusion/'
        traj_path = '../Data/Preprocessed lowiOC Simulated Data Ground Truth/'
    else:
        # All IOC data
        val_path = '../Data/Preprocessed simulated data/allIOC/diffusion/'
        traj_path = '../Data/Simulated Data - Ground Truth/'

    valFiles = os.listdir(val_path)
    T = int(8192 / T_reduction_factor)  # Number of time points after reduction
    L = int(512 / L_reduction_factor)  # Length of kymograph after reduction
    TDiff = int((10000 - 8192) / T_reduction_factor / 2)  # Time crop
    LDiff = int((600 - 512) / L_reduction_factor / 2)  # Spatial crop
    time_crop = slice(TDiff, T + TDiff)
    space_crop = slice(LDiff, -LDiff if LDiff else None)  # `:-0` would be empty

    img_list = []
    label_list = []
    for fname in valFiles:
        # Intensity is encoded in the file name between "iOC" and "_M"
        intensity = float(fname[fname.index("iOC") + 3:fname.index("_M")]) * 10000
        if intensity < Int[0]:
            continue

        # Load the kymograph with a channel axis and crop it in time and space
        kymo = np.expand_dims(np.load(val_path + fname), -1)
        img_list.append(kymo[time_crop, space_crop, :])

        # Ground truth: same name, or "simulated" + the suffix from the first "_"
        try:
            traj = np.load(traj_path + fname)
        except OSError:
            traj = np.load(traj_path + "simulated" + fname[fname.index("_"):])

        # Put time on the first axis if the file stores it transposed
        if traj.shape[1] > traj.shape[0]:
            traj = np.transpose(traj)

        # Reduce the size of the ground truth trajectory and normalize it
        label = skimage.measure.block_reduce(traj, (T_reduction_factor, L_reduction_factor), np.mean)
        label = label - 1
        label = np.abs(label / np.min(label))
        # block_reduce on the 2-D trajectory returns a 2-D array; add the
        # channel axis so the crop below (and the stacking) matches kymo.
        if label.ndim == 2:
            label = label[..., np.newaxis]
        label_list.append(label[time_crop, space_crop, :])

    if not img_list:
        raise ValueError(
            "No validation file in %r has intensity >= %r" % (val_path, Int[0])
        )
    valImgs = np.stack(img_list, 0)
    valLabels = np.stack(label_list, 0)

    # Some kymographs are empty; drop them
    keepIndices = np.where(np.sum(valImgs, (1, 2, 3)) != 0)[0]
    valImgs = valImgs[keepIndices]
    valLabels = valLabels[keepIndices]

    # Normalise each label by its maximum over time (avoid dividing by 0)
    labelMax = np.expand_dims(np.max(valLabels, 1), 1)
    labelMax[labelMax == 0] = 1
    valLabels = valLabels / labelMax

    if pathToEmpty is not None:
        emptyDir = pathToEmpty + "/diffusion/"
        emptyFiles = os.listdir(emptyDir)
        emptyValImgs = np.array([np.load(emptyDir + f) for f in emptyFiles])
        emptyValImgs = emptyValImgs[:, time_crop, space_crop]
        emptyValImgs = np.expand_dims(emptyValImgs, -1)
        emptyValLabels = np.zeros(emptyValImgs.shape)

        valImgs = np.append(valImgs, emptyValImgs, 0)
        valLabels = np.append(valLabels, emptyValLabels, 0)

    return valImgs, valLabels

def GenerateTestData(val_path="../Data/Preprocessed laserNoise Ferritin/ChannelE/",
                     t_start=904, n_times=8192, x_crop=11, n_positions=128):
    """Load and crop the experimental test kymographs.

    Every file in ``val_path`` (in ``os.listdir`` order) is loaded with
    ``np.load`` as a 2-D (time, position) array. It is cropped to rows
    ``t_start:t_start + n_times`` and columns ``x_crop:width - x_crop``,
    which must leave ``n_positions`` columns. The defaults reproduce the
    original hard-coded crop ``[904:9096, 11:-11]`` into 128 positions.

    Returns
    -------
    np.ndarray
        float64 array of shape ``(n_files, n_times, n_positions, 1)``.
    """
    valFiles = os.listdir(val_path)
    valImgs = np.zeros((len(valFiles), n_times, n_positions, 1))
    for i, fname in enumerate(valFiles):
        kymo = np.load(os.path.join(val_path, fname))
        # width - x_crop (not -x_crop) so that x_crop == 0 keeps all columns
        valImgs[i, ..., 0] = kymo[t_start:t_start + n_times, x_crop:kymo.shape[1] - x_crop]
    return valImgs


def GenerateValDataWithSize(valImgs, valLabels, size=512):
    """Crop the same random time window of length ``size`` from both arrays.

    The window is taken along axis 1. Every start index
    ``0 .. n_times - size`` is possible, including the last one, so
    ``size == n_times`` returns the full arrays.

    Raises
    ------
    ValueError
        If ``size`` exceeds the number of time points.
    """
    n_times = valImgs.shape[1]
    if size > n_times:
        raise ValueError("size=%d exceeds the %d available time points" % (size, n_times))
    # randint's upper bound is exclusive, hence + 1 to allow the last window.
    start = np.random.randint(0, n_times - size + 1)
    window = slice(start, start + size)
    return valImgs[:, window, ...], valLabels[:, window, ...]
    

### --- Init more params --- ###
import re

total_nr_epochs = 0
MAX_NBR_EPOCHS = 500000
# Best mean training loss so far (curriculum updates when no validation data)
meanLoss = np.inf
# Keep the best validation loss across re-runs of this cell; inf in a fresh kernel
try:
    currentBestLoss
except NameError:
    currentBestLoss = np.inf

save = True
if use_val_data:
    # Load the validation set only once per kernel
    try:
        valImgs
    except NameError:
        valImgs, valLabels = GenerateValData(pathToEmpty)

if load_unet:
    # Weight files are named "...loss=<value>...h5". Seed the best loss from
    # that value so a reloaded network is only re-saved when validation improves.
    loss_match = re.search(r"loss=(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)", unet_loadname)
    if loss_match:
        currentBestLoss = float(loss_match.group(1))
print("currentBestLoss= "+str(currentBestLoss))
### --- Training loop for cGAN--- ###
if TRAIN:
    with tf.device('/GPU:0'):
        # Iterate over data generator until maximum number of epochs is reached
        with data_generator:
            while total_nr_epochs < MAX_NBR_EPOCHS:
                for ii in range(nbr_GAN_loops):
                    # Train the GAN model for a certain number of epochs
                    history = GAN.fit(data_generator, epochs=nbr_GAN_epochs_per_loop, steps_per_epoch=1)
                    # Count epochs so the while condition can terminate
                    # (it was never incremented, so the loop ran forever).
                    total_nr_epochs += nbr_GAN_epochs_per_loop

                    # Get the first batch and the mean loss of this loop
                    b, L = data_generator[0]
                    mean_loss = np.mean(np.mean(history, 1))

                    # Adjust hyperparameters based on current loss
                    if mean_loss < meanLoss and not use_val_data:
                        if reduceInt:
                            getTrainTraj.I[1] = np.max([0.95 * getTrainTraj.I[1], 0.8])
                        meanLoss = mean_loss

                    if increaseDiff:
                        getTrainTraj.D[0] = np.min([1.01 * getTrainTraj.D[0], 1.25])

                    # Use validation data
                    if use_val_data:
                        # Generate new validation data with a specific size
                        newValImgs, newValLabels = GenerateValDataWithSize(valImgs, valLabels, size=1024)
                        # Use the generator to predict the labels of the new validation data
                        b = GAN.generator.predict(newValImgs[:, ...], batch_size=1)
                        L = newValLabels[:, ...]
                        # Calculate validation loss
                        val_loss = mae(L, b)  # change loss here
                        print("valLoss = " + str(val_loss))
                        # Train the generator model with the new validation data
                        GAN.generator.fit(newValImgs, newValLabels, batch_size=4, epochs=1)

                        # Update hyperparameters based on current validation loss (curriculum learning)
                        if val_loss < currentBestLoss:
                            if reduceInt:
                                getTrainTraj.I[1] = np.max([0.95 * getTrainTraj.I[1], 1])
                                print(getTrainTraj.I[1])

                            currentBestLoss = np.copy(val_loss)
                            # Save the GAN and generator models under a shared tag
                            tag = (
                                "D" + str(round(D1, 2)) + "-" + str(round(D2, 2))
                                + "I" + str(round(getTrainTraj.I[0], 2)) + "-" + str(round(getTrainTraj.I[1], 2))
                                + "loss=" + str(val_loss.numpy()) + "chA.h5"
                            )
                            GAN.save('GAN-' + tag)
                            GAN.generator.save('U-net-' + tag)

                    else:
                        if save:
                            # Save the GAN and generator models
                            GAN.save('GAN-' + datetime.datetime.now().strftime("%d%m%Y-%H%M%S") + " I" + str(round(getTrainTraj.I[0], 2)) + "-" + str(round(getTrainTraj.I[1], 2)))

                    if total_nr_epochs >= MAX_NBR_EPOCHS:
                        break

    


In [None]:


#%%
plt.close('all')
save=0
savePath="../Figures/Meetings/2022-11-29/SimulatedPredictions/"

if save:
    try:
        os.makedirs(savePath)
    except:
        pass
getTestTraj = Trajectory(I=[0,0],s=st,D=[1.2,2])
#getTestTraj = Trajectory(I=[1,1.51],s=st)

testImage=dt.FlipLR(dt.FlipUD(input_array() + init_particle_counter() 
                        + GenNoise(dX=lambda:.00001+.00003*np.random.rand(),
                                                dA=0,
                                                noise_lev=lambda:.0001,
                                                biglam=lambda:0.6+.4*np.random.rand(),
                                                bgnoiseCval=lambda:0.03+.02*np.random.rand(),
                                                bgnoise=lambda:.08+.04*np.random.rand(),
                                                bigx0=lambda: .1*np.random.randn(),
                                                sinus_noise_amplitude=lambda: np.random.rand(),
                                                freq =lambda: np.random.rand()*np.pi)
                                  + getTrainTraj**nump
                                  + PostProcess()))

# image=dt.FlipLR(dt.FlipUD(input_array() + init_particle_counter() 
#                         + GenNoise(dX=lambda:.00001+.00003*np.random.rand(),
#                                                 dA=lambda:0+np.random.rand()*0.0001,
#                                                 noise_lev=lambda:.0001,
#                                                 biglam=lambda:0.6+.4*np.random.rand(),
#                                                 bgnoiseCval=lambda:0.03+.02*np.random.rand(),
#                                                 bgnoise=lambda:.08+.04*np.random.rand(),
#                                                 bigx0=lambda: .1*np.random.randn(),
#                                                 sinus_noise_amplitude=lambda: np.random.rand(),
#                                                 freq =lambda: np.random.rand()*np.pi)
#                                     + getTrainTraj**nump
#                                     + PostProcess()))


# testImage=dt.FlipLR(dt.FlipUD(input_array() + init_particle_counter() 
#                         + GenNoise(dX=.0001,#+.00003*np.random.rand(),
#                                     dA=0,
#                                     noise_lev=.0002*(0.5+0.5*np.random.rand()),#+0.00005*np.random.randn(),
#                                     biglam=0.5+.7*np.random.rand(),
#                                     bgnoiseCval=0.05*np.random.rand(),
#                                     bgnoise=.08+.04*np.random.rand(),
#                                     bigx0=lambda: .1*np.random.randn(),
#                                     sinus_noise_amplitude= np.random.rand(),
#                                     freq = np.random.rand()*np.pi)#2*np.random.rand()))
#                                     +getTestTraj**nump
#                                     + PostProcess()))

#newValImgs,newValLabels = valImgs,valLabels
# Visual check of the generator. For each of 4 samples plot
#   (1) a freshly generated test kymograph, its ground truth, the prediction and
#       a histogram of predicted values above 0.01, and
#   (2) the same panels for a random validation sample.
# NOTE(review): relies on testImage, GAN, newValImgs, newValLabels,
# reduced_times, reduced_length, save and savePath being defined by other
# cells (save/savePath are only assigned further down in this file).
b,L = generate_training_batch(testImage,4)
print(L.shape)
for j in range(0,4):
    # ---- Generated test sample ----------------------------------------
    fig, ax = plt.subplots(2,2,figsize=(16,16))

    # Input image
    img = b[j,...]
    cax = ax[0][0]
    cax.set_title('Kymograph',fontsize=16)
    cax.imshow(np.squeeze(img).T,aspect='auto')

    # Label image
    cax = ax[1][0]
    try:
        label_img = np.reshape(L[j,...,0],(1,reduced_times,reduced_length,1))
    except (ValueError, IndexError):
        # Labels stored with an extra leading axis per sample.
        label_img = np.reshape(L[j,0,...,0],(1,reduced_times,reduced_length,1))
    cax.imshow(np.squeeze(label_img).T,aspect='auto')
    cax.set_title('Ground truth',fontsize=16)

    # Predicted image current net
    cax = ax[0][1]
    pred_img = GAN.generator.predict(np.reshape(b[j,...],(1,reduced_times,reduced_length,1)))
    im = cax.imshow(np.squeeze(pred_img).T,aspect='auto')
    cax.set_title('Predicted Trajectory',fontsize=16)

    # Histogram of predicted values > 0.01, drawn on this figure's axes
    # explicitly instead of the implicit pyplot "current axes".
    cax = ax[1][1]
    cax.set_title('Prediction histogram',fontsize=16)
    cax.hist(pred_img[pred_img>0.01].flatten(), density=True, bins=1000)
    fig.tight_layout()
    # Save BEFORE show(): once show() has run, the figure may already be
    # closed and plt.savefig would write a new, empty figure.
    if save:
        fig.savefig(savePath+str(j)+".png")
    plt.show()

    # ---- Random validation sample -------------------------------------
    i=np.random.randint(0,len(newValImgs))
    fig, ax = plt.subplots(2,2,figsize=(16,16))

    # Input image
    img = newValImgs[i,...]
    cax = ax[0][0]
    cax.set_title('Kymograph',fontsize=16)
    cax.set_xlabel("t (frames)")
    cax.set_ylabel("x (pixels/4)")
    cax.imshow(np.squeeze(img).T,aspect='auto')

    # Label image
    cax = ax[1][0]
    try:
        label_img = newValLabels[i,...,0]
    except (ValueError, IndexError):
        label_img = np.reshape(newValLabels[i,0,...,0],(1,reduced_times,reduced_length,1))
    cax.imshow(np.squeeze(label_img).T,aspect='auto')
    cax.set_title('Segment label',fontsize=16)
    cax.set_xlabel("t (frames)")
    cax.set_ylabel("x (pixels/4)")

    # Predicted image current net
    cax = ax[0][1]
    pred_img = GAN.generator.predict(np.expand_dims(img,0))
    im = cax.imshow(np.squeeze(pred_img).T,aspect='auto')
    cax.set_title('Predicted Trajectory',fontsize=16)
    cax.set_xlabel("t (frames)")
    cax.set_ylabel("x (pixels/4)")

    # Histogram axes show value vs. density. They were previously labelled
    # with the kymograph's t/x axes by mistake.
    cax = ax[1][1]
    cax.hist(pred_img[pred_img>0.01].flatten(), density=True, bins=1000)
    cax.set_title('Prediction histogram',fontsize=16)
    cax.set_xlabel("Predicted value")
    cax.set_ylabel("Density")
    fig.tight_layout()
    if save:
        fig.savefig(savePath+str(j)+"-val.png")
    plt.show()
    
#%% Study cross section of channels
# Overlay one line per kymograph row (frame) for the first `frames` rows of
# (1) a measured .mat kymograph and (2) a simulated empty-channel kymograph,
# the latter min-max normalised per row.
import os
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as io

frames=100  # number of rows (frames) to overlay; clipped to the image height

# Earlier data sets:
#path=r"C:\Users\ccx55\OneDrive\Documents\GitHub\Phd\Biosensing---nanochannel-project\Data\chBInsulin Simulated Data\lowiOC"
#path=r"C:\Users\ccx55\OneDrive\Documents\GitHub\Phd\Biosensing---nanochannel-project\Data\2022-11-17-DNA\TE-buff\3C-2"
path="../Data/2022-12-16/TBE/ch2B-4/"
# Replace "\" explicitly: replacing os.sep is a no-op on POSIX and would leave
# Windows-style literals unusable there.
path = path.replace("\\", "/").rstrip("/") + "/"
files = os.listdir(path)  # NOTE(review): listdir order is arbitrary; consider sorted()
file = files[0]

img=io.loadmat(path+file)
try:
    img=img["data"]["Im"][0][0]
except (KeyError, ValueError, IndexError):
    # Files without a "data" struct store the image directly under "Im".
    img=img["Im"].squeeze()
plt.figure()
for i in range(min(frames, img.shape[0])):
    plt.plot(img[i,:])

path=r"Z:\NSM\simulated_tests\experimental_noise\empty_channel_50x30"
path = path.replace("\\", "/").rstrip("/") + "/"
files = os.listdir(path)
file = files[9]

img=io.loadmat(path+file)
# Float copy so the in-place per-row normalisation also works on integer images.
img=np.asarray(img["data"]["Im"][0][0], dtype=float)

plt.figure()
for i in range(min(frames, img.shape[0])):
    img[i,:] -=np.min(img[i,:])
    img[i,:] /=np.max(img[i,:])
    plt.plot(img[i,:])
#%%Test on experimental data - folder
# Run a trained U-Net on the preprocessed kymographs in
# <exp_path>/<channel folder>/diffusion/ and, for each kymograph, plot the
# cropped input, the predicted map and histograms of both.
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
# At most this many kymographs per channel folder; same convention as the
# single-folder cell's listdir(...)[:maxNrFiles]. Previously the counter was
# incremented twice per file and an exhausted folder also stopped all
# remaining folders.
maxFiles=20
save=1
plt.close('all')
unet_path='U-net-D0.1-2I0.0-1.0loss=0.003131276.h5'
unet_path='U-Net- I0.01-25_loss_0.009943331.h5'
unet_path="U-net-D0.6-1.2I0.1-0.6loss=0.0050201593chA.h5"
#unet_path="U-net-D0.1-2I0.0-1.0loss=0.0058968016chA.h5"
#unet_path="U-net-D1-2I0.0-1.0loss=0.0044185026chA.h5"
unet_path="U-net-D0.1-1.2I0.1-1loss=0.0071075866chA.h5" #old labelling scheme
unet_path="U-net-D0.1-1.2I0.1-1loss=0.011916474chA.h5" #trained on pathToEmpty files
unet_path="U-net-D0.1-1.2I0.1-1loss=0.013097599chA.h5" #works well for 200bps, lots of noise in new measurements though
#unet_path="U-net-D0.1-1.2I0.01-0.8loss=0.0116292965chA.h5" #trained with sometimes 0 particles in kymo
unet =tf.keras.models.load_model(unet_path,compile=False)#GAN.generator

#exp_path="../Data/Preproces
exp_path=r"..\Data\Preprocessed 2022-10-21-DNA\50bp"
exp_path=r"..\Data\2022-01-Inulin\Insulin\ML"
exp_path=r"..\Data\Preprocessed DNA lowiOC Simulated Data"
#exp_path=r"..\Data\Preprocessed 2022-11-17-DNA\TE-buff" 
exp_path=r"..\Data\Preprocessed 2022-11-17-DNA\200bps-9.5-molperch"
exp_path=r"..\Data\Preprocessed 2022-12-06\200bp-3-p-ch-20vpp-300hz"
exp_path=r"../Data/Preprocessed 2022-12-06/TE-buffer/"
exp_path=r"../Data/Preprocessed 2022-12-06/100bp-3-p-ch/"
exp_path=r"../Data/Preprocessed 2022-12-07/200bp-3-p-ch/"
exp_path=r"..\Data\Preprocessed 2022-12-16\TBE"

#exp_path=r"..\Data\Preprocessed 2022-11-30-cont-test\TE-buffer"
#exp_path=r"..\Data\Preprocessed 2022-09-16-noise-10-hours"
#exp_path=r"..\Data\Preprocessed 2022-09-15-noise-10-hours"
# exp_path=r"..\Data\Preprocessed 2022-07-26-200bps-0.1M-salt"
# exp_path=r"..\Data\Preprocessed 13-09-2022-3C-3-noise-measure"
# exp_path=r"..\Data\Preprocessed 100bps-0.1salt-TE"
# exp_path=r"..\Data\Preprocessed data-david"
#exp_path=r"..\Data\Preprocessed DNA lowiOC Simulated Data"
#exp_path=r"..\Data\Preprocessed simulated data"
#exp_path=r"..\Data\Preprocessed Old Experimental Noise"

# Normalise separators. The literals above use "\", and replacing os.sep
# (as before) is a no-op on POSIX, which left those paths unusable there.
exp_path = exp_path.replace("\\", "/").rstrip("/") + "/"

# Trailing "/" is required: per-folder sub-directories are appended below.
# Without it the figures went to a sibling "ch2B-2<folder>/" directory.
savePath="../Figures/Meetings/2022-12-16/Noise/TBE/ch2B-2/"

if save:
    os.makedirs(savePath, exist_ok=True)

folders=os.listdir(exp_path)
for folder in folders:
    folderName=folder+"/"
    # if "B" in folder: #skip channel B, too high vibrations
    #     continue
    folderPath = exp_path+folder+"/diffusion/"
    if not os.path.isdir(folderPath):
        # Stray files in exp_path have no diffusion/ sub-folder.
        continue
    files = os.listdir(folderPath)[:maxFiles]
    if save:
        os.makedirs(savePath+folderName, exist_ok=True)
    for file in files:
        fileName=file.replace(".mat","").replace(".npy","").replace("iOC0.","iOC0")
        fig, axs = plt.subplots(2,2,figsize=(16,16))
        axs=axs.flatten()
        data = np.expand_dims(np.load(folderPath+file),(0,-1))
        # Crop 200 frames at both ends and 11 pixels at both sides. If the
        # network rejects that shape, retry with one extra leading frame
        # (presumably an even/odd length issue with the U-Net down-sampling).
        try:
            orig_img = data[:,200:-200,11:-11,:]
            pred=unet.predict(orig_img)
        except Exception:
            orig_img = data[:,199:-200,11:-11,:]
            pred=unet.predict(orig_img)
        img=axs[0].imshow(orig_img[0,:,:,0],aspect='auto')
        #pred[pred>0.01]=1
        fig.suptitle(exp_path[8:]+fileName)  # [8:] drops the leading "../Data/"
        fig.colorbar(img,ax=axs[0])

        axs[1].imshow(pred[0,:,:,0],aspect='auto')
        axs[2].hist(orig_img.flatten(), density=True, bins=1000)
        axs[3].hist(pred[pred>0.01].flatten(), density=True, bins=100)
        if save:
            fig.savefig(savePath+folderName+fileName)
            # np.savetxt(savePath+file+"-kymograph",orig_img[0,:,:,0].flatten())
            # np.savetxt(savePath+file+"-probMap",pred[0,:,:,0].flatten())
            plt.close('all')

#%%WeightToIOC
# Convert molecular weights into integrated optical contrast (iOC) for one
# channel cross-section, and print the iOC values multiplied by 1e4.
import numpy as np

# Weights of 200, 100 and 50 bp DNA (50 bp alone: 132e3/4).
#weight = [5.6e3] #insulin
#weight = 66e3 #BSA
weight = [132e3, 132e3/2, 132e3/4]

# Channel cross-sections considered (only the active line below is used):
#   50*50*1e-6            ChE new chip
#   115*50*1e-6           ChB old chip
#   (114+130)/2*97*1e-6   ChC old chip
#   (130)*97*1e-6         ChC old chip
#   47*68*1e-6            50x50 channel
#   100*27*1e-6           BSA measured channel
A = 20*50*1e-6  # smallest channels available, roughly

n_i = 1.33
n_o = 1.46

# NOTE(review): with n_i < n_o this denominator is negative, so n_mean,
# calibration and iOC all come out negative. Confirm the sign convention.
denominator = n_i**2 - n_o**2
n_TE = 2*n_i**2/denominator
n_TM = (n_i**2 + n_o**2)/denominator
n_mean = 0.5*(n_TE + n_TM)
alpha_MW = 0.461e-12

calibration = A/alpha_MW/n_mean
iOC = [mass/calibration for mass in weight]
print(np.array(iOC)*10000)

#%%IOCToWeight
# Inverse of the WeightToIOC cell: convert integrated optical contrast (iOC)
# values into molecular weights for one channel cross-section.
import numpy as np


A=50*50*1e-6 #ChE new chip

iOC=[0.125] #times 10, my units
# NOTE(review): the unit of these iOC values is unclear ("times 10, my units").
# The forward cell prints iOC*1e4; if 0.125 is in those units, divide by 1e4
# before converting. Confirm.
A=115*50*1e-6 #ChB old chip
A=(114+130)/2*97*1e-6 #ChC old chip
A=(130)*97*1e-6 #ChC old chip
#weight = [5.6e3] #insulin
A=47*68*1e-6 #50x50 channel
A=20*50*1e-6 #smallest channels available, roughly
#A=100*27*1e-6 #BSA measured channel


#weight = 66e3 #BSA
#A=27*100*1e-6 #BSA channel

n_i=1.33
n_o=1.46

n_TE = 2*n_i**2/(n_i**2 - n_o**2)
n_TM = (n_i**2 + n_o**2)/(n_i**2 - n_o**2)
n_mean = 0.5*(n_TE + n_TM)
alpha_MW = 0.461e-12

calibration = A/alpha_MW/n_mean
w = [ioc*calibration for ioc in iOC]
# Print the converted weights. The cell previously printed the iOC inputs,
# so the result of the conversion was never shown.
# NOTE(review): calibration is negative for n_i < n_o, so w is negative too.
print(np.array(w))


#%%Test on experimental data
# Run a trained U-Net on up to maxNrFiles preprocessed kymographs in a single
# diffusion folder and plot each cropped input next to its prediction.
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
save=1
maxNrFiles=200
plt.close('all')
unet_path = 'unet-14-dec-1700.h5'
unet_path="U-Net- I0.01-25_loss_0.009943331.h5"
unet_path="U-net-D0.1-2I0.0-1.46loss=0.0074624214.h5"
unet_path="U-net-D0.1-2I0.0-1.0loss=0.0060592867.h5"
unet_path="U-net-D0.1-2I0.0-1.0loss=0.006284812.h5"
unet_path="U-net-D0.1-2I0.0-1.0loss=0.003131276.h5"
#unet_path="U-net-D0.1-2I0.0-1.0loss=0.0060042166chA.h5"
unet_path="U-net-D0.1-2I0.0-1.0loss=0.0058968016chA.h5"
unet = tf.keras.models.load_model(unet_path,compile=False)

#exp_path="../Data/Preprocessed data-bohdan/200bps CropMedian/4B-5/diffusion/" #probably 0 traj..
exp_path="../Data/Preprocessed data-bohdan/200bps CropMedian/4C-3/diffusion/" #maybe one traj, probably 0..

exp_path="../Data/Preprocessed data-bohdan/2022-09-20-noisetest-ladder/ch2B-5-series/diffusion/"
exp_path="../Data/Preprocessed lowiOC 20220923 Simulated Data/lowiOC/diffusion/"
exp_path="../Data/Preprocessed data-david/ch1-2C-7/diffusion/"
exp_path="../Data/Preprocessed lowiOC 20220928chA Simulated Data/lowiOC/diffusion/"
#exp_path="../Data/Preprocessed 2022-09-23-DNA-fragments/100bps-NaCl-0.05TBE/ch1-2A-4/diffusion/"
#exp_path="../Data/Preprocessed 2022-09-23-DNA-fragments/100bps-NaCl-0.05TBE/ch1-2B-5/diffusion/"
#exp_path="../Data/Preprocessed 2022-09-28-DNA-fragments/50bps-0.01TBE/ch1-3A-6/diffusion/"
exp_path="../Data/Preprocessed 2022-09-28-DNA-fragments/water/ch1-3A-6/diffusion/"
exp_path="../Data/Preprocessed 2022-09-30-DNA/50-bps-8-molec-in-vol-TE-TBE/1-ch5A-5/diffusion/"
exp_path="../Data/Preprocessed 2022-09-30-DNA/TE-buffer/1-ch3A-5/diffusion/"
exp_path="../Data/2022-01-Inulin/Insulin/ML/1-3E-4/diffusion/"
exp_path=r"..\Data\Preprocessed 2022-11-17-DNA\200bps-9.5-molperch\3B-5\diffusion"
exp_path=r"..\Data\Preprocessed 2022-11-17-DNA\TE-buff\3B-5\diffusion"
exp_path=r"..\Data\Preprocessed 2022-11-17-DNA\TE-buff\3C-2\diffusion"
exp_path=r"..\Data\Preprocessed 2022-11-17-DNA\200bps-9.5-molperch\3C-2\diffusion"
#exp_path=r"..\Data\2022-01-Inulin\Insulin"
exp_path=r"..\Data\Preprocessed 2022-12-16\50-DNA\ch2B-2\diffusion" #very few, maybe 1-2 trajectories
exp_path=r"..\Data\Preprocessed 2022-12-16\50-DNA\ch2B-4\diffusion" #tons of trajectories, but also a lot of noise.
#exp_path=r"..\Data\Preprocessed 2022-12-16\insulin\ch2B-3\diffusion" #looks like some trajectories? or just noise?
#exp_path=r"..\Data\Preprocessed 2022-12-16\TBE\ch2B-2\diffusion"

# Normalise separators. The literals above use "\", and replacing os.sep
# (as before) is a no-op on POSIX, which left those paths unusable there.
exp_path = exp_path.replace("\\", "/").rstrip("/") + "/"

savePath="../Figures/Meetings/2022-12-19/50-DNA/ch2B-4/"
#savePath="Y:/13-09-2022-3C-3-noise-measure/"
#savePath="Y:/2022-noise-meas/2022-09-20-noisetest-ladder/ch2B-5-series-2/"
if save:
    os.makedirs(savePath, exist_ok=True)
files = os.listdir(exp_path)[:maxNrFiles]
counter = 0
for file in files:
    fig,axs=plt.subplots(1,2,figsize=(16,16))
    counter+=1
    data = np.expand_dims(np.load(exp_path+file),(0,-1))
    # Crop 200 frames at both ends and 11 pixels at both sides. If the
    # network rejects that shape, retry with one extra leading frame
    # (presumably an even/odd length issue with the U-Net down-sampling).
    try:
        orig_img = data[:,200:-200,11:-11,:]
        pred=unet.predict(orig_img)
    except Exception:
        orig_img = data[:,199:-200,11:-11,:]
        pred=unet.predict(orig_img)
    img=axs[0].imshow(orig_img[0,:,:,0],aspect='auto')
    fig.colorbar(img,ax=axs[0])
    file=file.replace(".mat","").replace(".npy","").replace("iOC0.","iOC0")

    axs[1].imshow(pred[0,:,:,0],aspect='auto')
    axs[1].set_title(file)
    if save:
        fig.savefig(savePath+file)
        # np.savetxt(savePath+file+"-kymograph",orig_img[0,:,:,0].flatten())
        # np.savetxt(savePath+file+"-probMap",pred[0,:,:,0].flatten())
        plt.close('all')

#%% Single training example: input, label and generator prediction
b, L = generate_training_batch(image, 1)
j = 0  # first sample of the batch (was np.random.randint(batch_size))
img = b[j, ...]
label_img = L
pred_img = GAN.generator.predict(np.reshape(b[j, ...], (1, reduced_times, reduced_length, 1)))

fig, ax = plt.subplots(2, 2, figsize=(32, 6))

# Top-left: input kymograph (transposed for display).
cax = ax[0][0]
cax.imshow(np.squeeze(img).T, aspect='auto')

# Bottom-left: label.
cax = ax[1][0]
cax.imshow(np.squeeze(label_img).T, aspect='auto')
cax.set_title('Segment label', fontsize=16)

# Top-right: generator prediction (bottom-right panel is left empty).
cax = ax[0][1]
im = cax.imshow(np.squeeze(pred_img).T, aspect='auto')

plt.tight_layout()
plt.show()

#%%
# Three iterations, each producing two figures: a freshly generated training
# sample and a random validation sample. Each figure shows the kymograph, the
# generator output and the label side by side.
plt.close('all')
panelTitles = ("Original Kymograph", "Network Segmentation", "True Trajectories")
for i in range(3):
    img, label = generate_training_batch(image, 1)
    fig, axs = plt.subplots(1, 3)
    trainPanels = (img[0, :, :, 0],
                   GAN.generator.predict(img)[0, :, :, 0],
                   label[0, :, :, 0])
    for axis, panel, title in zip(axs, trainPanels, panelTitles):
        axis.imshow(panel, aspect='auto', cmap="viridis")
        axis.grid(False)
        axis.set_title(title)

    fig, axs = plt.subplots(1, 3)
    j = np.random.randint(len(valImgs))
    valPanels = (valImgs[j, :, :, 0],
                 GAN.generator.predict(np.expand_dims(valImgs[j, ...], 0))[0, :, :, 0],
                 valLabels[j, :, :, 0])
    for axis, panel, title in zip(axs, valPanels, panelTitles):
        axis.imshow(panel, aspect='auto', cmap="viridis")
        axis.grid(False)
        axis.set_title(title)
#%% Test
# Four times: call GenerateValData(), pick one random sample and show it next
# to the generator's segmentation.
for i in range(4):
    testImgs = GenerateValData()[0]
    fig, axs = plt.subplots(1, 2)
    j = np.random.randint(len(testImgs))
    segmentation = GAN.generator.predict(np.expand_dims(testImgs[j, ...], 0))[0, :, :, 0]
    panels = ((testImgs[j, :, :, 0], "Original Kymograph"),
              (segmentation, "Network Segmentation"))
    for axis, (panel, title) in zip(axs, panels):
        axis.imshow(panel, aspect='auto')
        axis.set_title(title)
#%%
# Plot kymograph / prediction / label for 20 random validation samples and
# annotate each with the iOC parsed from valFiles[j].
# NOTE(review): assumes valFiles[j] is the file that produced newValImgs[j]
# (same ordering in GenerateValData). Confirm.
val_path='../Data/Preprocessed lowiOC Simulated Data/lowiOC/diffusion/'
valFiles = os.listdir(val_path)
#Change this in GenerateValData
#valFiles = [valFile for valFile in valFiles if "iOC0.0001" in valFile or "e-05" in valFile]# or "iOC7.5e-05" in valFile or "iOC5e-05" in valFile or "iOC2.5e-05" in valFile]
#valFiles = [valFile for valFile in valFiles if "e-06" in valFile]
plt.close('all')
saveDir = "../Figures/lowiOC/"+unet_loadname[:-3] + "/"
save=1
if save:
    # Create the output directory (and any missing parents) up front; the
    # previous os.mkdir fallback failed when a parent was missing.
    os.makedirs(saveDir, exist_ok=True)
for i in range(0,20):
    fig,axs=plt.subplots(1,3,sharey=True,figsize=(16,16))
    # Draw the index from the arrays actually plotted (was len(valImgs),
    # which can run past newValImgs/newValLabels when they differ in length).
    j=np.random.randint(len(newValImgs))
    sample = newValImgs[j,:,:,0]
    axs[0].imshow(sample,aspect='auto')
    # Explicit batch and channel axes for the generator input.
    axs[1].imshow(GAN.generator.predict(np.expand_dims(sample,(0,-1)))[0,:,:,0],aspect='auto')
    axs[2].imshow(newValLabels[j,:,:,0],aspect='auto')
    axs[0].set_title("Kymograph")
    axs[1].set_title("Prediction")
    axs[2].set_title("True Trajectories")
    # File names contain "...iOC<value>_M..."; report iOC*1e4.
    # NOTE(review): "kDa = 66*intensity" presumably scales by a 66 kDa (BSA)
    # reference. Confirm.
    valName = valFiles[j]
    intensity = np.round(float(valName[valName.index("iOC")+3:valName.index("_M")])*10000,2)
    kDa = str(np.round(66*intensity,2))
    fig.suptitle("iOC = "+str(intensity)+str(", kDa = "+kDa))

    saveName = saveDir+"-iOC = "+str(intensity)+", kDa = "+kDa+"_"+str(i)
    if save:
        fig.savefig(saveName+".png")
    
#%%
# Show the first 40 validation kymographs next to their labels, one figure each.
plt.close('all')
for j in range(40):
    fig, axs = plt.subplots(1, 2)
    #j=np.random.randint(len(valImgs))
    for axis, stack in zip(axs, (valImgs, valLabels)):
        axis.imshow(stack[j, :, :, 0], aspect='auto')
    
#%% SNR analysis
# Estimate a signal-to-noise ratio per molecular weight:
#   noise  = std of channel 0 of an image resolved with zero trajectory intensity,
#   signal = mean of the non-zero pixels of channel 1 of an image whose
#            trajectory intensity is scaled linearly with MW[i].
# NOTE(review): channel 0 is presumably the kymograph and channel 1 the
# particle label/intensity map. Confirm against PostProcess.
MW = [60,55,50,45,40,35,30,25,20,10,5,2,1]
MW = [1000,1]  # only this list is used; the line above is overwritten
# init
Int = [0,0]  # zero intensity -> presumably a background-only image for the noise estimate
Ds= 0.7  # not used in this cell
st = 0.045
getTrainTraj = Trajectory(I=Int,s=st)
factor=1  # only referenced in the commented-out pipeline below
# image=dt.FlipLR(dt.FlipUD(input_array() + trainingDiffInt
#                           + GenNoise(dX=(.00001+.00003*np.random.rand())/factor,
#                                                   dA=0,
#                                                   noise_lev=(.0001)/factor/2,
#                                                   biglam=(0.6+.4*np.random.rand())/factor,
#                                                   bgnoiseCval=(0.03+.02*np.random.rand())/factor,
#                                                   bgnoise=(.08+.04*np.random.rand())/factor,
#                                                   bigx0=lambda: (.1*np.random.randn())/factor)
#                                       + trainingTraj**nbr_particles
#                                       + PostProcess()))

# Noise parameters drawn with np.random.rand() here are evaluated once, when
# the pipeline is built (they are not lambdas, except bigx0).
image=dt.FlipLR(dt.FlipUD(input_array() + init_particle_counter() 
                        + GenNoise(dX=.00001+.00003*np.random.rand(),
                                                dA=0,
                                                noise_lev=.0001/2,
                                                biglam=0.6+.4*np.random.rand(),
                                                bgnoiseCval=0.03+.02*np.random.rand(),
                                                bgnoise=.08+.04*np.random.rand(),
                                                bigx0=lambda: .1*np.random.randn())
                                    + getTrainTraj
                                    + PostProcess()))

img=image.resolve()
noiseStd = np.std(np.array(img[...,0]))
signalMean=np.zeros((len(MW)))
factor=1
for i in range(0,len(MW)):

    # Set both entries of the trajectory intensity in place for this MW.
    # 0.045 duplicates `st`; `length` comes from another cell.
    # NOTE(review): if Trajectory keeps a reference to the `Int` list, this
    # also mutates Int. Presumably an MW -> peak-intensity calibration
    # (1.15 per 66 kDa). Confirm.
    getTrainTraj.I[0] = 1.15/66*MW[i]/(0.045*np.sqrt(2*np.pi)*length/2*0.0295)
    getTrainTraj.I[1] = 1.15/66*MW[i]/(0.045*np.sqrt(2*np.pi)*length/2*0.0295)
    image=dt.FlipLR(dt.FlipUD(input_array() + init_particle_counter() 
                            + GenNoise(dX=.00001+.00003*np.random.rand(),
                                                    dA=0,
                                                    noise_lev=.0001/2,
                                                    biglam=0.6+.4*np.random.rand(),
                                                    bgnoiseCval=0.03+.02*np.random.rand(),
                                                    bgnoise=.08+.04*np.random.rand(),
                                                    bigx0=lambda: .1*np.random.randn())
                                        + getTrainTraj
                                        + PostProcess()))


    img=image.resolve()
    label=img[...,1]
    # Keep only non-zero label pixels; the mean is NaN if there are none.
    label=label[label!=0]
    signalMean[i]=np.mean(np.array(label))

SNR=signalMean/noiseStd
print(SNR)
#%%
from scipy.signal import find_peaks
import matplotlib.ticker as plticker
import matplotlib.patches as patches
def manTrack(diffImg,keepTraj=4,threshold=0.1,trajTreshold=64,dist=16):
    """Greedily link per-frame intensity peaks into trajectories.

    Parameters
    ----------
    diffImg : 2-D array indexed as diffImg[frame, position]
        Each row is searched for peaks with ``scipy.signal.find_peaks``.
    keepTraj : int
        Trajectories with ``keepTraj`` points or fewer are discarded.
    threshold : float
        Minimum peak height (``find_peaks`` ``height``).
    trajTreshold : int
        A peak is appended to a trajectory only if the frame gap to that
        trajectory's last point is smaller than this.
    dist : int
        Minimum distance between peaks within a frame (``find_peaks``
        ``distance``). A peak is also only appended if its position is
        closer than ``dist`` to the trajectory's last position.

    Returns
    -------
    trajectories : list of lists of (frame, position) tuples
    currentTrajectories : int
        Number of returned trajectories.
    frames : dict
        Frame index -> array of peak positions (only frames with peaks).
    """
    # 1) Detect peaks in every frame.
    frames={}
    for f in range(diffImg.shape[0]):
        peaks, _ = find_peaks(diffImg[f,:],height=threshold,distance=dist)
        if len(peaks):
            frames[f]=peaks

    # 2) Link peaks to existing trajectories, frame by frame.
    trajectories={}
    currentTrajectories=0
    for frame, curMax in frames.items():
        curMax=np.asarray(curMax)
        notTakenTrajs=np.arange(len(curMax))  # indices of peaks not linked yet
        # Only match when trajectories exist. Previously argmin on the empty
        # cost matrix raised here and an error was printed for every peak.
        if currentTrajectories:
            # Cost of trajectory end-point (row) vs. peak (column).
            # NOTE(review): the positional difference is not squared (the
            # **2 was commented out), i.e. sqrt(|dx| + dt**2). Confirm intended.
            diff=np.zeros((currentTrajectories,len(curMax)))
            for t, traj in trajectories.items():
                lastFrame, lastPos = traj[-1]
                diff[t]=np.sqrt(np.abs(curMax-lastPos)+np.abs(frame-lastFrame)**2)
            nearest=np.argmin(diff,axis=0)  # closest trajectory for each peak
            usedTrajs=set()
            for i, c in enumerate(curMax):
                t=nearest[i]
                lastFrame, lastPos = trajectories[t][-1]
                # Each trajectory accepts at most one peak per frame.
                if (np.abs(frame-lastFrame)<trajTreshold
                        and t not in usedTrajs
                        and np.abs(c-lastPos)<dist):
                    notTakenTrajs=notTakenTrajs[notTakenTrajs!=i]
                    usedTrajs.add(t)
                    trajectories[t].append((frame,c))

        # 3) Every unlinked peak starts a new trajectory.
        for i in notTakenTrajs:
            trajectories[currentTrajectories]=[(frame,curMax[i])]
            currentTrajectories+=1

    trajectories=[traj for traj in trajectories.values() if len(traj)>keepTraj]
    return trajectories,len(trajectories),frames


# Build a test pipeline with a trajectory of fixed intensity and diffusion
# settings, generate one sample and run the GAN on it. `pred` and `label`
# (both squeezed) are used by the tracking cell below.
# NOTE(review): I=[5,5] and D=[0.95,0.95] look like (min, max) ranges
# collapsed to a single value. Confirm against the Trajectory class.
getTestTraj = Trajectory(I=[5,5],s=st,D=[0.95,0.95])
#getTestTraj = Trajectory(I=[1,1.51],s=st)

# Feature order matters in this pipeline; `getTestTraj**1` presumably
# resolves the trajectory feature once (one particle).
testImage=dt.FlipLR(dt.FlipUD(input_array() + init_particle_counter() 
                        + GenNoise(dX=lambda:.00001+.00003*np.random.rand(),
                                                dA=0,#lambda:0+np.random.rand()*0.0001,
                                                noise_lev=lambda:.0001,
                                                biglam=lambda:0.6+.4*np.random.rand(),
                                                bgnoiseCval=lambda:0.03+.02*np.random.rand(),
                                                bgnoise=lambda:.08+.04*np.random.rand(),
                                                bigx0=lambda: .1*np.random.randn(),
                                                sinus_noise_amplitude=lambda: np.random.rand(),
                                                freq =lambda: np.random.rand()*np.pi)
                                    + getTestTraj**1
                                    + PostProcess()))

img,label = generate_training_batch(testImage,1) 

# NOTE(review): other cells call GAN.generator.predict; confirm GAN.predict
# returns the generator output.
pred = GAN.predict(img).squeeze()
label=label.squeeze()
#%%current optimal 
# Track peaks with manTrack in the network prediction `pred` and in the ground
# truth `label` (both from the previous cell). Overlay the tracks on each map
# and estimate one diffusion coefficient per trajectory.
plt.close('all')

xtomu=0.029545454545454545*4  # position -> length factor (presumably µm per 4-pixel bin; confirm)
deltaT=0.0049893               # presumably the frame time in seconds; confirm

def _track_and_estimate_D(probMap):
    """Track `probMap` with manTrack, plot the tracks over it and return
    (D, trajectoryLengths, trajectories, currentTrajectories, frames)."""
    trajectories,currentTrajectories,frames=manTrack(probMap,keepTraj=16,threshold=0.05,trajTreshold=128,dist=32)
    fig, ax = plt.subplots()
    ax.imshow(probMap,aspect='auto')
    colours=plt.rcParams['axes.prop_cycle'].by_key()['color']
    D=np.zeros(currentTrajectories)
    for t, traj in enumerate(trajectories):
        # Use every point. manTrack starts a trajectory at a real detection
        # (the (0,0) placeholder is commented out there), so skipping index 0,
        # as before, discarded a real observation.
        frameIdx=np.array([p[0] for p in traj])
        x=np.array([p[1] for p in traj])
        # Cycle the colours instead of indexing past a fixed-length list.
        ax.scatter(x,frameIdx,s=16,color=colours[t % len(colours)])
        dx=np.diff(x)*xtomu
        msdTerm=np.mean(dx**2)/2/deltaT
        # NOTE(review): the covariance-based estimator (Vestergaard et al. 2014)
        # uses the signed product <dx_n dx_(n+1)> and no final /2. The abs()
        # and the averaging below are kept as before; confirm they are intended.
        covTerm=np.mean(np.abs(dx[1:]*dx[:-1]))/deltaT
        D[t]=(msdTerm+covTerm)/2
    trajectoryLengths=[len(traj) for traj in trajectories]
    return D,trajectoryLengths,trajectories,currentTrajectories,frames

D,trajectoryLengths,trajectories,currentTrajectories,frames=_track_and_estimate_D(pred)
print(D)
print(trajectoryLengths)

D,trajectoryLengths,trajectories,currentTrajectories,frames=_track_and_estimate_D(label)
print(D)
print(trajectoryLengths)
