# DeepCaImX introduction and guide
## Import components: Keras, SciPy input/output, NumPy, time, and Matplotlib for visualization.

In [None]:
import keras
from tensorflow.keras.optimizers import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.initializers import *
from tensorflow.keras.applications import *
from tensorflow.keras import backend
from keras.constraints import *
import tensorflow as tf
import scipy.io as sio
import numpy as np
import time
import matplotlib.pyplot as plt

## Build the DeepCaImX architecture, its loss functions, and the training method.

In [None]:
class DeepCaImX():
    def __init__(self):
        """Create the DeepCaImX network and compile it with its three task losses.

        The losses are local closures; they are handed to ``compile`` in the
        same order as the model outputs (denoising, segmentation, traces).
        """
        # Channel count shared by the ISTA convolutions, the ConvLSTM and the
        # segmentation head.
        self.N = 4

        def denoise_loss(y_true, y_pred):
            """Squared error on channel 0 plus a 0.01-weighted penalty on channel 5.

            The sum is divided by the static size of axis 1 of ``y_true``.
            """
            residual = y_pred[:, :, :, :, 0] - y_true[:, :, :, :, 0]
            data_term = backend.sum(backend.square(residual))
            penalty_term = backend.sum(backend.square(y_pred[:, :, :, :, 5]))
            return (data_term + 0.01 * penalty_term) / y_true.shape[1]

        def dice_coefficient(x1, x2):
            """Smoothed Dice overlap: (2*sum(x1*x2) + eps) / (sum(x1 + x2) + eps)."""
            eps = 1e-6
            overlap = backend.sum(x1 * x2)
            total = backend.sum(x1 + x2)
            return (2. * overlap + eps) / (total + eps)

        def seg_loss(y_true, y_pred):
            """Weighted Dice loss: frames 0-99 (weight 0.1) and frames 100-103 (weight 0.9)."""
            attention_term = 1 - dice_coefficient(y_pred[:, 0:100, :, :, 0], y_true[:, 0:100, :, :, 0])
            roi_term = 1 - dice_coefficient(y_pred[:, 100:104, :, :, 0], y_true[:, 100:104, :, :, 0])
            return 10 * (0.1 * attention_term + 0.9 * roi_term)

        def trace_loss(y_true, y_pred):
            """10 * (1 - Pearson correlation), computed over all elements at once.

            ``y_true`` is cropped to indices 2..97 of axis 1 before comparison.
            """
            eps = 1e-6
            target = y_true[:, 2:98]
            pred_centered = y_pred - backend.mean(y_pred)
            target_centered = target - backend.mean(target)
            covariance = backend.sum(pred_centered * target_centered) + eps
            pred_energy = backend.sum(backend.square(pred_centered)) + eps
            target_energy = backend.sum(backend.square(target_centered)) + eps
            correlation = covariance / backend.sqrt(pred_energy * target_energy)
            return 10 * (1 - correlation)

        # ---- Initialization ----
        optimizer = Adam(learning_rate=1e-3)
        self.DeepCaImX_model = self.DeepCaImX()
        # Uncomment to print a layer-by-layer summary:
        # self.DeepCaImX_model.summary()

        # ---- Compile the multi-task model (one loss per output) ----
        self.DeepCaImX_model.compile_metrics = None
        self.DeepCaImX_model.compile(
            loss=[denoise_loss, seg_loss, trace_loss],
            optimizer=optimizer,
        )
    
    
    ###########                  Architecture of DeepCaImX                     ###########    
    def DeepCaImX(self):
        """Assemble and return the three-output DeepCaImX Keras model.

        The model takes an input of shape (100, 32, 32, 1) and returns
        ``[Denoise, Seg, trace]``:

        * ``Denoise`` - channel-wise concatenation of the output of the last
          ISTA block (1 channel), the tanh of the absolute thresholded
          features of that block (``self.N`` channels) and the averaged
          symmetry residual of all five blocks (1 channel).
        * ``Seg`` - concatenation along axis 1 of the attention maps, the ROI
          maps, and the ROI maps after grey-scale erosion followed by dilation.
        * ``trace`` - output of three 1-D convolutions applied to the spatially
          averaged, attention- and ROI-weighted features.

        NOTE: layers are created in the order the statements run; reordering
        them may change auto-generated layer names - confirm against any saved
        checkpoint before restructuring.
        """
        def ISTA_blocks(R):
            """Build one ISTA block on tensor ``R``.

            Returns ``(conv_4, conv_7, conv_8)``: the soft-thresholded
            features, the 1-channel output with the skip connection to ``R``,
            and the symmetry residual.
            """
            temp = 0.01  # soft-threshold level
            N = self.N
            conv_1 = Conv3D(N, (1,3,3), activation = None, padding = 'same', data_format='channels_last')(R)                
            # f1..f4 are shared: each layer is applied both on the thresholded
            # path (conv_2..conv_6) and on the un-thresholded path (conv_symm).
            f1 = Conv3D(N, (1,3,3), activation = 'relu', padding = 'same', data_format='channels_last')
            conv_2 = f1(conv_1)
            conv_symm = f1(conv_1)
            f2 = Conv3D(N, (1,3,3), activation = None, padding = 'same', data_format='channels_last')
            conv_3 = f2(conv_2)
            conv_symm = f2(conv_symm)
            # Soft thresholding: sign(x) * relu(|x| - temp).
            conv_4 = multiply([Lambda(lambda x: backend.sign(x))(conv_3), ReLU()(Lambda(lambda x: x-temp)(Lambda(lambda x: backend.abs(x))(conv_3)))])
            f3 = Conv3D(N, (1,3,3), activation = 'relu', padding = 'same', data_format='channels_last')
            conv_5 = f3(conv_4)
            conv_symm = f3(conv_symm)
            f4 = Conv3D(N, (1,3,3), activation = None, padding = 'same', data_format='channels_last')
            conv_6 = f4(conv_5)       
            conv_symm = f4(conv_symm)       
            conv_7 = Conv3D(1, (1,3,3), activation = None, padding = 'same', data_format='channels_last')(conv_6)                
            conv_7 = add([conv_7, R])  # skip connection to the block input
            # Difference between the un-thresholded f1..f4 path and conv_1.
            conv_8 = subtract([conv_symm, conv_1])
            return conv_4, conv_7, conv_8
        
        
            
        def CFF (inputs):
            """Build the attention / segmentation head on ``inputs``.

            Returns ``(att, seg_ROI, seg_map)``: the 1-channel attention maps,
            the ``self.N`` ROI maps moved to axis 1, and the same ROI maps after
            erosion followed by dilation with a 5x5 kernel.
            """
            def cnn_block(inputs,num_filters=self.N*8,kernel_size=(3,3,3),dilation_rate=(1,1,1),padding="same",use_bias=False):
                """Conv3D (He-normal init) -> BatchNormalization -> ReLU.

                The ``padding`` argument is accepted but not used: the
                convolution is always created with ``padding="same"``.
                """
                x = Conv3D(num_filters,kernel_size=kernel_size,dilation_rate=dilation_rate,padding="same",use_bias=use_bias,kernel_initializer=keras.initializers.HeNormal())(inputs)
                x = BatchNormalization()(x)
                return ReLU()(x)

            def DilatedSpatialPyramidPooling(inputs):
                """Concatenate a pooled branch with 1x1 and dilated (3, 6, 9) conv branches, then fuse with a 1x1 block."""
                # Pooled branch: 8x8 spatial average pooling, then upsampled back by 8.
                x = AveragePooling3D(pool_size=(1,8,8))(inputs)
                x = cnn_block(x, kernel_size=1, use_bias=True)
                out_pool = UpSampling3D(size=(1,8,8))(x)
                out_1 = cnn_block(inputs, kernel_size=1, dilation_rate=(1,1,1))
                out_3 = cnn_block(inputs, dilation_rate=(1,3,3))
                out_6 = cnn_block(inputs, dilation_rate=(1,6,6))
                out_9 = cnn_block(inputs, dilation_rate=(1,9,9))
                x = Concatenate(axis=4)([out_pool, out_1, out_3, out_6, out_9])
                output = cnn_block(x, kernel_size=1)
                return output
            def backbone_resnet(inputs):
                """Three residual stages; returns (features after the first 2x pooling, features after a second 2x pooling)."""
                conv1_1 = cnn_block(inputs,num_filters=self.N*2)
                conv1_2 = cnn_block(conv1_1,num_filters=self.N*2)
                conv1_3 = cnn_block(conv1_2,num_filters=self.N*2)
                conv1_4 = ReLU()(add([conv1_3, conv1_1]))
                pool1 = MaxPooling3D(pool_size=(1,2,2))(conv1_4)

                conv2_1 = cnn_block(pool1,num_filters=self.N*4)
                conv2_2 = cnn_block(conv2_1,num_filters=self.N*4)
                conv2_3 = cnn_block(conv2_2,num_filters=self.N*4)
                conv2_4 = ReLU()(add([conv2_3, conv2_1]))

                conv3_1 = cnn_block(conv2_4,num_filters=self.N*4)
                conv3_2 = cnn_block(conv3_1,num_filters=self.N*4)
                conv3_3 = cnn_block(conv3_2,num_filters=self.N*4)
                conv3_4 = ReLU()(add([conv3_3, conv3_1]))
                output = AveragePooling3D(pool_size=(1,2,2))(conv3_4)
                return pool1, output
            
            # input_2: features at 1/2 resolution; feature: features at 1/4 resolution.
            [input_2, feature] = backbone_resnet(inputs)
            input_1 = DilatedSpatialPyramidPooling(feature)
            input_1 = UpSampling3D(size=(1,2,2))(input_1)  # back to 1/2 resolution
            input_2 = cnn_block(input_2, num_filters=self.N*4, kernel_size=1)

            x = Concatenate(axis=4)([input_1, input_2])
            x = cnn_block(x,num_filters=self.N*4)
            x = cnn_block(x,num_filters=self.N*2)
            x = UpSampling3D(size=(1,2,2))(x)  # back to full resolution
            x = cnn_block(x,num_filters=self.N)
            # tanh followed by ReLU keeps the attention values in [0, 1).
            att = Conv3D(1,1,padding="same",activation='tanh')(x)
            att = ReLU()(att)
            
            # Move axis 1 (frames) to the last axis so the Dense stack mixes
            # across frames and reduces them to self.N ROI maps.
            seg = Permute((2,3,4,1))(att)
            seg = Dense(64,'relu',use_bias=False)(seg)
            seg = Dense(32,'relu',use_bias=False)(seg)
            seg = Dense(16,'relu',use_bias=False)(seg)
            seg = Dense(8,'relu',use_bias=False)(seg)
            seg = Dense(self.N,'tanh',use_bias=False)(seg)            
            seg = ReLU()(seg)
            seg_ROI = Permute((4,1,2,3))(seg)
            seg = tf.squeeze(seg, axis=3)
            # Each ROI map is eroded and then dilated with the same 5x5 kernel.
            # NOTE: the four stanzas below are hard-coded for four ROI
            # channels, while the Dense layer above produces self.N channels.
            kernel = tf.ones((5,5,1))
            seg_1 = Lambda(lambda x: x[:,:,:,0])(seg)
            seg_1 = tf.expand_dims(seg_1, axis=3)
            seg_1 = tf.nn.erosion2d(value=seg_1,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_1 = tf.nn.dilation2d(input=seg_1,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_2 = Lambda(lambda x: x[:,:,:,1])(seg)
            seg_2 = tf.expand_dims(seg_2, axis=3)
            seg_2 = tf.nn.erosion2d(value=seg_2,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_2 = tf.nn.dilation2d(input=seg_2,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_3 = Lambda(lambda x: x[:,:,:,2])(seg)
            seg_3 = tf.expand_dims(seg_3, axis=3)
            seg_3 = tf.nn.erosion2d(value=seg_3,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_3 = tf.nn.dilation2d(input=seg_3,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_4 = Lambda(lambda x: x[:,:,:,3])(seg)
            seg_4 = tf.expand_dims(seg_4, axis=3)
            seg_4 = tf.nn.erosion2d(value=seg_4,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_4 = tf.nn.dilation2d(input=seg_4,filters=kernel,strides=(1,1,1,1),padding='SAME',data_format='NHWC',dilations=(1,1,1,1))
            seg_map = Concatenate(axis=3)([seg_1, seg_2, seg_3, seg_4])
            seg_map = tf.expand_dims(seg_map, axis=3)
            seg_map = Permute((4,1,2,3))(seg_map)
            return att, seg_ROI, seg_map 
        
        Inputs = Input(shape=(100,32,32,1))
        ##################################################################################
        
        ###########                Architecture of ISTA-Net                    ###########
        # Five chained ISTA blocks. Between blocks the output is pulled back
        # towards the network input: R_next = out - temp * (out - Inputs).
        temp = 0.01
        R1 = Inputs # r1
        [_, conv1_7, conv1_8] = ISTA_blocks(R1)            
        R2 = subtract([conv1_7, Lambda(lambda x: temp*x)(subtract([conv1_7,Inputs]))]) # r2
        [_, conv2_7, conv2_8] = ISTA_blocks(R2)
        R3 = subtract([conv2_7, Lambda(lambda x: temp*x)(subtract([conv2_7,Inputs]))]) # r3
        [_, conv3_7, conv3_8] = ISTA_blocks(R3)
        R4 = subtract([conv3_7, Lambda(lambda x: temp*x)(subtract([conv3_7,Inputs]))]) # r4
        [_, conv4_7, conv4_8] = ISTA_blocks(R4)
        R5 = subtract([conv4_7, Lambda(lambda x: temp*x)(subtract([conv4_7,Inputs]))]) # r5
        [conv5_4, denoise, conv5_8] = ISTA_blocks(R5)
        
        # Sum the symmetry residuals of the five blocks, then average over
        # channels and divide by 5.
        output_symm = conv1_8
        output_symm = Add()([output_symm, conv2_8])
        output_symm = Add()([output_symm, conv3_8])
        output_symm = Add()([output_symm, conv4_8])
        output_symm = Add()([output_symm, conv5_8])
        output_symm = Lambda(lambda x: backend.mean(x,axis=4,keepdims=True)/5)(output_symm)
        # tanh(|x|) maps the thresholded features of the last block into [0, 1).
        Sparse_prepro = tf.math.abs(conv5_4)
        Sparse_prepro = keras.activations.tanh(Sparse_prepro)        
        ###################################################################################################################
        
        ###########  Architecture of 2D Convolutional LSTM for ROIs segmentation and sequential attention maps  ########### 
        Att = ConvLSTM2D(self.N,3,padding='same',return_sequences=True)(Sparse_prepro)
        [Att, seg_ROI, seg_map] = CFF(Att)
        Seg = Concatenate(axis = 1)([Att, seg_ROI, seg_map]) # Shape: (?,100+4+4,32,32,1)
        
        ###################################################################################################################
        
        ###########           Architecture of 1D Convolutional layers for traces demixing and denoising         ###########
        # NOTE: seg_map is rebound here to seg_ROI with the ROI axis moved last;
        # the eroded/dilated maps are only used inside Seg above.
        seg_map = Permute((4,2,3,1))(seg_ROI)  
        video = multiply([Sparse_prepro, Att])
        trace = multiply([video, seg_map])
        trace = TimeDistributed(GlobalAveragePooling2D ())(trace) 
        # Two 'valid' width-3 convolutions shorten the time axis from 100 to 96.
        trace = Conv1D(16,3,activation='relu', padding='valid',use_bias=False)(trace)
        trace = Conv1D(8,3,activation='relu', padding='valid',use_bias=False)(trace)
        trace = Conv1D(self.N,1,activation='relu', padding='valid',use_bias=False)(trace) # (?,96,4)
        Denoise = Concatenate(axis = 4)([denoise, Sparse_prepro, output_symm]) # Shape: (?,100,32,32,1+4+1)
        print('****************denoise*****************',Denoise)
        print('****************segmentation*****************',Seg)
        print('****************trace*****************',trace)
        return Model(inputs = Inputs, outputs = [Denoise, Seg, trace])
    ###################################################################################################################        
    
    
    ###########               Define training methods of ROIs segmentation and traces extraction            ###########
    def train(self, video, data, seg_GT, trace_GT, epoch, N_epoch, iteration, batch_size):
        # Size of "video", "data" and "att_GT" is (?,100,32,32,1)
        # Size of "seg_GT" is (?,32,32,4)
        # Size of "trace_GT" is (?,100,4)
        start_time = time.time()
        video = np.float32(video)
        seg_GT = np.float32(seg_GT)
        trace_GT = np.float32(trace_GT)
        att_GT = np.float32(np.zeros_like(data))
        for k in range(trace_GT.shape[0]):
            for i in range(trace_GT.shape[1]):
                for j in range(seg_GT.shape[3]):
                    att_GT[k,i,:,:,0] = att_GT[k,i,:,:,0] + seg_GT[k,:,:,j]*trace_GT[k,i,j]                
        att_GT[att_GT>=0.05] = 1
        att_GT[att_GT<0.05] = 0
        
        seg_GT = np.transpose(seg_GT,[0,3,1,2])
        seg_GT = np.expand_dims(seg_GT,axis=4)
        seg_GT = np.append(att_GT,seg_GT,axis=1)
        loss = self.DeepCaImX_model.train_on_batch(video,[data,seg_GT,trace_GT])
        print('Epoch: ',epoch+1, '/',int(N_epoch),' Iteration: ', iteration+1, '/', int(500/batch_size),', Loss_denoise: ',"{:.4f}".format(loss[1]),', Loss_seg: ',"{:.4f}".format(loss[2]),', Loss_trace: ',"{:.4f}".format(loss[3]))        
        
        if (iteration+1) % int(500/batch_size) == 0:
            [img, seg, trace] = self.DeepCaImX_model.predict_on_batch(np.expand_dims(video[0], axis=0))
            img_denoised = img[0,:,:,:,0]
            img_sparse = img[0,:,:,:,1:5]
            img_sparse = np.amax(img_sparse,axis=3)
            img_GT = data[0,:,:,:,0]
            img_raw = video[0,:,:,:,0]
            att = seg[0,0:400,:,:,0]
            att_GT = att_GT[0,:,:,:,0]
            
            plt.figure(figsize=(9,12))
            plt.subplot(6,4,1)
            plt.imshow(img_raw[25,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Raw 1')
            plt.subplot(6,4,2)
            plt.imshow(img_raw[50,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Raw 2')
            plt.subplot(6,4,3)
            plt.imshow(img_raw[75,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Raw 3')
            plt.subplot(6,4,4)
            plt.imshow(np.amax(img_raw,axis=0),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Raw MIP')
            
            plt.subplot(6,4,5)
            plt.imshow(img_GT[25,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Ground Truth 1')
            plt.subplot(6,4,6)
            plt.imshow(img_GT[50,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Ground Truth 2')
            plt.subplot(6,4,7)
            plt.imshow(img_GT[75,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Ground Truth 3')
            plt.subplot(6,4,8)
            plt.imshow(np.amax(img_GT,axis=0),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Ground Truth MIP')
            
            plt.subplot(6,4,9)
            plt.imshow(img_denoised[25,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Denoised Video 1')
            plt.subplot(6,4,10)
            plt.imshow(img_denoised[50,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Denoised Video 2')
            plt.subplot(6,4,11)
            plt.imshow(img_denoised[75,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Denoised Video 3')
            plt.subplot(6,4,12)
            plt.imshow(np.amax(img_denoised,axis=0),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Denoised Video MIP')

            plt.subplot(6,4,13)
            plt.imshow(img_sparse[25,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Prepro Sparsity 1')
            plt.subplot(6,4,14)
            plt.imshow(img_sparse[50,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Prepro Sparsity 2')
            plt.subplot(6,4,15)
            plt.imshow(img_sparse[75,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Prepro Sparsity 3')
            plt.subplot(6,4,16)
            plt.imshow(np.amax(img_sparse,axis=0),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Prepro Sparsity MIP')
            
            plt.subplot(6,4,17)
            plt.imshow(att[25,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention Maps 1')
            plt.subplot(6,4,18)
            plt.imshow(att[50,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention Maps 2')
            plt.subplot(6,4,19)
            plt.imshow(att[75,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention Maps 3')
            plt.subplot(6,4,20)
            plt.imshow(np.amax(att,axis=0,keepdims=False),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention Maps MIP')
            
            plt.subplot(6,4,21)
            plt.imshow(att_GT[25,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention GT 1')
            plt.subplot(6,4,22)
            plt.imshow(att_GT[50,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention GT 2')
            plt.subplot(6,4,23)
            plt.imshow(att_GT[75,:,:],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention GT 3')
            plt.subplot(6,4,24)
            plt.imshow(np.amax(att_GT,axis=0,keepdims=False),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')  
            plt.title('Attention GT MIP')
            plt.show()
            
            plt.figure(figsize=(8,4))
            for i in range(4):
                plt.subplot(2,4,i+1)                
                plt.imshow(seg[0,104+i,:,:,0],vmin=0,vmax=1,cmap='gray')
                plt.axis('off')
                plt.title('Seg #'+str(i+1))
                plt.subplot(2,4,i+5)
                plt.imshow(seg_GT[0,100+i,:,:,0],vmin=0,vmax=1,cmap='gray')
                plt.axis('off')
                plt.title('GT #'+str(i+1))
            plt.suptitle('Segmentation results')
            plt.show()
            trace = np.squeeze(trace[0])
            trace = trace/(np.amax(trace)+1e-3)
            trace_GT = np.squeeze(trace_GT[0])
            plt.figure(figsize=(8,4))
            for i in range(4):
                plt.title('#'+str(i+1))
                plt.subplot(2,4,i+1)
                plt.plot(trace[:,i])
                plt.ylim(0,1)
                plt.axis('off')                
                plt.subplot(2,4,i+5)
                plt.plot(trace_GT[:,i])
                plt.ylim(0,1)
                plt.axis('off')
            plt.suptitle('Traces results')
            plt.show()
            
            del(seg)
            del(seg_GT)
            del(att_GT)
            del(att)
            del(img)
            del(img_GT)
            del(img_raw)
            del(img_denoised)
            del(img_sparse)
            del(trace_GT)
            del(trace)
            self.DeepCaImX_model.save('./Pretrained Model/DeepCaImX_model_v1.h5')
        
        print("--- %s seconds escaped ---" % (time.time() - start_time))
        return loss
    ###################################################################################################################

## Load the training dataset, set the number of epochs and the batch size, and train DeepCaImX.

In [None]:
## Allocate GPU memory in real-time
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)

N_epoch = 200
batch_size = 2

model = DeepCaImX()
loss_history = []
start_time_total = time.time()
for i in range(N_epoch):  # epoch
    # Sample files are numbered 1..500; visit them in a fresh random order.
    index = np.random.permutation(500) + 1
    start_time = time.time()
    for j in range(int(500/batch_size)):  # iteration
        index_batch = index[j*batch_size:(j+1)*batch_size]

        # Load the four .mat files of every sample in the batch, then stack
        # them along a new leading batch axis.
        batch_video, batch_data, batch_masks, batch_trace = [], [], [], []
        for sample_id in index_batch:
            suffix = str(sample_id) + '.mat'

            Video = sio.loadmat('./Training Dataset/LSTM_Video_' + suffix)['LSTM_video']
            # Third axis -> leading axis, then add a trailing channel axis.
            batch_video.append(np.expand_dims(np.transpose(Video, [2,0,1]), 3))

            Data = sio.loadmat('./Training Dataset/LSTM_Data_' + suffix)['LSTM_data']
            batch_data.append(np.expand_dims(np.transpose(Data, [2,0,1]), 3))

            Masks = sio.loadmat('./Training Dataset/LSTM_Masks_' + suffix)['LSTM_mask']
            batch_masks.append(np.float32(Masks))

            Trace = sio.loadmat('./Training Dataset/LSTM_Trace_' + suffix)['LSTM_trace']
            batch_trace.append(np.transpose(Trace, [1,0]))

        LSTM_Video = np.stack(batch_video, axis=0)
        LSTM_Data = np.stack(batch_data, axis=0)
        LSTM_Masks = np.stack(batch_masks, axis=0)
        LSTM_Trace = np.stack(batch_trace, axis=0)
        del batch_video, batch_data, batch_masks, batch_trace
        del Video, Data, Masks, Trace

        loss = model.train(LSTM_Video, LSTM_Data, LSTM_Masks, LSTM_Trace, i, N_epoch, j, batch_size)
        loss_history.append(loss)
        # Release the batch before the next one is loaded.
        del LSTM_Video, LSTM_Data, LSTM_Masks, LSTM_Trace
    print("\x1b[31m--- %s minutes escaped for this epoch ---\x1b[0m" % ((time.time() - start_time)/60))
    # (Optional) The code below is used to measure the GPU memory occupied
    #print("\x1b[31m--- Current GPU Memory usage is: %s Gb ---\x1b[0m" % ((tf.config.experimental.get_memory_info('GPU:0')['current'])/1024/1024))
    print()
print("\x1b[31m--- %s hours escaped for training ---\x1b[0m" % ((time.time() - start_time_total)/3600))
#sio.savemat('loss_history.mat',{'loss_history': np.array(loss_history)})

## Load a pretrained DeepCaImX model, set the number of epochs and the batch size, and continue training DeepCaImX.

In [None]:
import keras
from tensorflow.keras.optimizers import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.initializers import *
from tensorflow.keras.applications import *
from tensorflow.keras import backend
from keras.constraints import *
import tensorflow as tf
import scipy.io as sio
import numpy as np
import time
import matplotlib.pyplot as plt

def train(DeepCaImX_model, video, data, seg_GT, trace_GT, epoch, N_epoch, iteration, batch_size):
    """Run one training step on a batch; on the last iteration of an epoch, plot and save.

    Parameters
    ----------
    DeepCaImX_model : compiled Keras model with outputs [denoise, seg, trace]
    video : array, (B, T, H, W, 1)
        Raw input video (the notebook uses T=100, H=W=32).
    data : array, same shape as ``video``
        Target for the denoising output.
    seg_GT : array, (B, H, W, R)
        Ground-truth ROI masks (R=4 in the notebook).
    trace_GT : array, (B, T, R)
        Ground-truth activity traces.
    epoch, N_epoch : int
        Zero-based epoch index and total epoch count (progress print only).
    iteration : int
        Zero-based iteration index inside the epoch.
    batch_size : int
        Batch size; an epoch is taken to be ``int(500/batch_size)`` iterations.

    Returns
    -------
    (loss, DeepCaImX_model) : the value returned by ``train_on_batch``
    (indexed here as ``[total, denoise, seg, trace]``) and the same model.
    """
    start_time = time.time()
    video = np.float32(video)
    seg_GT = np.float32(seg_GT)
    trace_GT = np.float32(trace_GT)
    n_batch, n_frames = trace_GT.shape[0], trace_GT.shape[1]
    n_rois = seg_GT.shape[3]

    # Attention ground truth: per frame, the sum over ROIs of mask * trace,
    # binarised at 0.05.  Vectorised over batch and time; ROIs are still
    # accumulated one by one so the float32 summation order is unchanged.
    att_GT = np.float32(np.zeros_like(data))
    for j in range(n_rois):
        att_GT[:n_batch, :n_frames, :, :, 0] += (
            seg_GT[:n_batch, np.newaxis, :, :, j]
            * trace_GT[:, :, j, np.newaxis, np.newaxis]
        )
    att_GT[att_GT>=0.05] = 1
    att_GT[att_GT<0.05] = 0

    # Segmentation target: attention frames followed by the ROI masks,
    # laid out as (B, T + R, H, W, 1).
    roi_GT = np.expand_dims(np.transpose(seg_GT,[0,3,1,2]),axis=4)
    seg_target = np.concatenate([att_GT, roi_GT], axis=1)

    loss = DeepCaImX_model.train_on_batch(video,[data,seg_target,trace_GT])
    iters_per_epoch = int(500/batch_size)
    print('Epoch: ',epoch+1, '/',int(N_epoch),' Iteration: ', iteration+1, '/', iters_per_epoch,', Loss_denoise: ',"{:.4f}".format(loss[1]),', Loss_seg: ',"{:.4f}".format(loss[2]),', Loss_trace: ',"{:.4f}".format(loss[3]))

    if (iteration+1) % iters_per_epoch == 0:
        import os  # local import: only needed for the checkpoint directory

        [img, seg, trace] = DeepCaImX_model.predict_on_batch(np.expand_dims(video[0], axis=0))
        n_att = att_GT.shape[1]
        # The segmentation output stacks [attention frames, ROI maps,
        # eroded/dilated ROI maps] along axis 1.  Take only the attention
        # frames here, otherwise the MIP below would also include ROI maps.
        att = seg[0,0:n_att,:,:,0]
        img_sparse = np.amax(img[0,:,:,:,1:5],axis=3)  # channels 1..4 of the first output

        rows = [
            ('Raw', video[0,:,:,:,0]),
            ('Ground Truth', data[0,:,:,:,0]),
            ('Denoised Video', img[0,:,:,:,0]),
            ('Prepro Sparsity', img_sparse),
            ('Attention Maps', att),
            ('Attention GT', att_GT[0,:,:,:,0]),
        ]
        plt.figure(figsize=(9,12))
        for r, (label, stack) in enumerate(rows):
            # Three sample frames, then the maximum-intensity projection.
            for c, frame in enumerate((25, 50, 75)):
                plt.subplot(6,4,4*r+c+1)
                plt.imshow(stack[frame,:,:],vmin=0,vmax=1,cmap='gray')
                plt.axis('off')
                plt.title(label+' '+str(c+1))
            plt.subplot(6,4,4*r+4)
            plt.imshow(np.amax(stack,axis=0),vmin=0,vmax=1,cmap='gray')
            plt.axis('off')
            plt.title(label+' MIP')
        plt.show()

        plt.figure(figsize=(8,4))
        for i in range(n_rois):
            plt.subplot(2,n_rois,i+1)
            # Predicted maps shown are the eroded/dilated ones (last R frames).
            plt.imshow(seg[0,n_att+n_rois+i,:,:,0],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')
            plt.title('Seg #'+str(i+1))
            plt.subplot(2,n_rois,i+1+n_rois)
            plt.imshow(seg_target[0,n_att+i,:,:,0],vmin=0,vmax=1,cmap='gray')
            plt.axis('off')
            plt.title('GT #'+str(i+1))
        plt.suptitle('Segmentation results')
        plt.show()

        trace = np.squeeze(trace[0])
        trace = trace/(np.amax(trace)+1e-3)
        trace_ref = np.squeeze(trace_GT[0])
        plt.figure(figsize=(8,4))
        for i in range(n_rois):
            plt.subplot(2,n_rois,i+1)
            plt.plot(trace[:,i])
            plt.ylim(0,1)
            plt.axis('off')
            # The title must be set after the subplot is selected.
            plt.title('#'+str(i+1))
            plt.subplot(2,n_rois,i+1+n_rois)
            plt.plot(trace_ref[:,i])
            plt.ylim(0,1)
            plt.axis('off')
        plt.suptitle('Traces results')
        plt.show()

        save_path = './Pretrained Model/DeepCaImX_model_v2.h5'
        # Create the folder first so a missing directory does not abort training.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        DeepCaImX_model.save(save_path)

    print("--- %s seconds escaped ---" % (time.time() - start_time))
    return loss, DeepCaImX_model
    ###################################################################################################################
    
def denoise_loss(y_true, y_pred):
    """Squared error on channel 0 plus a 0.01-weighted penalty on channel 5.

    The sum is divided by the static size of axis 1 of ``y_true``.
    """
    residual = y_pred[:, :, :, :, 0] - y_true[:, :, :, :, 0]
    data_term = backend.sum(backend.square(residual))
    penalty_term = backend.sum(backend.square(y_pred[:, :, :, :, 5]))
    return (data_term + 0.01 * penalty_term) / y_true.shape[1]

def dice_coefficient(x1, x2):
    """Smoothed Dice overlap: (2*sum(x1*x2) + eps) / (sum(x1 + x2) + eps)."""
    eps = 1e-6
    overlap = backend.sum(x1 * x2)
    total = backend.sum(x1 + x2)
    return (2. * overlap + eps) / (total + eps)
def seg_loss(y_true, y_pred):
    """Weighted Dice loss combining an attention term and a segmentation term.

    Along axis 1, slices ``0:100`` (channel 0) are scored as the attention term
    and slices ``100:104`` (channel 0) as the segmentation term. Each term is
    ``1 - dice_coefficient(pred, true)``.

    Returns:
        ``10 * (0.1 * attention_term + 0.9 * segmentation_term)``.
    """
    attention_term = 1 - dice_coefficient(y_pred[:, 0:100, :, :, 0], y_true[:, 0:100, :, :, 0])
    # Named differently from the function so the local does not shadow `seg_loss`.
    segmentation_term = 1 - dice_coefficient(y_pred[:, 100:104, :, :, 0], y_true[:, 100:104, :, :, 0])
    return 10 * (0.1 * attention_term + 0.9 * segmentation_term)

def trace_loss(y_true, y_pred):
    """Pearson-correlation loss between predicted and ground-truth traces.

    ``y_true`` is first cropped to ``[:, 2:98]``; the loss is
    ``10 * (1 - r)`` where ``r`` is the (smoothed) Pearson correlation computed
    over all elements of ``y_pred`` and the cropped ``y_true``.
    NOTE(review): the crop presumably aligns the target with the length of
    ``y_pred`` - confirm against the model's trace output.
    """
    eps = 1e-6
    target = y_true[:, 2:98]

    pred_centered = y_pred - backend.mean(y_pred)
    target_centered = target - backend.mean(target)

    covariance = backend.sum(pred_centered * target_centered) + eps
    pred_energy = backend.sum(backend.square(pred_centered)) + eps
    target_energy = backend.sum(backend.square(target_centered)) + eps

    correlation = covariance / backend.sqrt(pred_energy * target_energy)
    return 10 * (1 - correlation)

## Allocate GPU memory in real-time
# Build a TF1-compat session config with `allow_growth` enabled and open a session with it.
# NOTE(review): `sess` is not referenced again in this section - confirm whether it has to be
# registered as the Keras backend session for the option to take effect.
config=tf.compat.v1.ConfigProto() 
config.gpu_options.allow_growth = True
sess=tf.compat.v1.Session(config=config)

N_epoch = 200
batch_size = 2
N_samples = 500  # We will train 500 samples in total
dataset_dir = './Training Dataset/'


def _load_training_sample(sample_id):
    """Load the four arrays of one training sample, without a batch axis.

    Reads ``LSTM_Video_<id>.mat``, ``LSTM_Data_<id>.mat``, ``LSTM_Masks_<id>.mat``
    and ``LSTM_Trace_<id>.mat`` from ``dataset_dir``.

    Returns:
        ``(video, data, masks, trace)`` where video and data are transposed with
        ``[2, 0, 1]`` and given a trailing singleton axis, masks are cast to
        float32, and trace is transposed with ``[1, 0]``.
        NOTE(review): the transposes presumably move the frame axis first -
        confirm against the dataset layout.
    """
    suffix = str(sample_id) + '.mat'

    video = sio.loadmat(dataset_dir + 'LSTM_Video_' + suffix)['LSTM_video']
    video = np.expand_dims(np.transpose(video, [2, 0, 1]), 3)

    data = sio.loadmat(dataset_dir + 'LSTM_Data_' + suffix)['LSTM_data']
    data = np.expand_dims(np.transpose(data, [2, 0, 1]), 3)

    masks = np.float32(sio.loadmat(dataset_dir + 'LSTM_Masks_' + suffix)['LSTM_mask'])

    trace = sio.loadmat(dataset_dir + 'LSTM_Trace_' + suffix)['LSTM_trace']
    trace = np.transpose(trace, [1, 0])

    return video, data, masks, trace


def _load_training_batch(sample_ids):
    """Load every sample in ``sample_ids`` and stack each array kind along a new axis 0."""
    samples = [_load_training_sample(sample_id) for sample_id in sample_ids]
    # One stack per array kind instead of growing the batch with repeated np.append copies.
    return tuple(np.stack(arrays, axis=0) for arrays in zip(*samples))


model = load_model(
    './Pretrained Model/DeepCaImX_model_v1.h5',
    custom_objects={
        'backend': backend,
        'dice_coefficient': dice_coefficient,
        'denoise_loss': denoise_loss,
        'seg_loss': seg_loss,
        'trace_loss': trace_loss,
    },
)
model.summary()
loss_history = []
start_time_total = time.time()
t = 0  # not used in this section; kept in case other cells reference it
for i in range(N_epoch): # epoch
    index = np.random.permutation(N_samples) + 1 # sample files are numbered from 1
    start_time = time.time()
    for j in range(N_samples // batch_size): # iteration
        index_batch = index[j*batch_size:(j+1)*batch_size]
        LSTM_Video, LSTM_Data, LSTM_Masks, LSTM_Trace = _load_training_batch(index_batch)

        [loss, model] = train(model, LSTM_Video, LSTM_Data, LSTM_Masks, LSTM_Trace, i, N_epoch, j, batch_size)
        loss_history.append(loss)
        # Drop the batch arrays before the next batch is loaded to keep peak memory down.
        del LSTM_Video, LSTM_Data, LSTM_Masks, LSTM_Trace
    print("\x1b[31m--- %s minutes escaped for this epoch ---\x1b[0m" % ((time.time() - start_time)/60))
    # (Optional) The code below is used to measure the GPU memory occupied 
    #print("\x1b[31m--- Current GPU Memory usage is: %s GB ---\x1b[0m" % ((tf.config.experimental.get_memory_info('GPU:0')['current'])/1024/1024))
    print()
print("\x1b[31m--- %s hours escaped for training ---\x1b[0m" % ((time.time() - start_time_total)/3600))        
#sio.savemat('loss_history.mat',{'loss_history': np.array(loss_history)})