In [None]:
import glob
import ipynbname
import json
import os
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import ndimage, stats
from scipy.interpolate import CubicSpline

import tensorflow as tf
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Layer, Add, Activation

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from matplotlib.colors import ListedColormap

import pydicom

from losses import *
from MultiFlowSeg import *
from utils import *

import volumentations as V

2024-11-04 11:18:46.532813: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
image_size = 128
frames = 32
data_path = f'../data/clean_{image_size}_{frames}'
with open('patients5.json', 'r') as json_file:
    patients = json.load(json_file)

train_patients, val_patients, test_patients = patients['train'],patients['val'],patients['test']

patients = train_patients +  val_patients + test_patients 
all_patients =sorted(np.unique(['_'.join(pat.split('_')[:2]).split('/')[-1].replace('.npy','') for pat in glob.glob(f'{data_path}/*')]))
len(train_patients), len(val_patients), len(test_patients) 

In [None]:
venc_df = pd.read_csv('venc.csv')
class CustomDataGen():    
    def __init__(self, 
                 patients,
                 cohort, 
                 vessel = ''
                ):
        self.patients = patients
        self.cohort = cohort
        self.vessel = vessel
        
    def data_generator(self):
        vessel_indices = list(vessels_dict.keys())[1:] if self.cohort != 'test' else [self.vessel]
        num_vessels = len(vessels_dict) 
        for patient in self.patients:
            for vessel in vessel_indices: 
                vessel_index = vessels_dict[vessel]
                mag_image, phase_image, mask = np.load(f'{data_path}/{patient}_{vessel}.npy', allow_pickle = True)
                mag_image[mag_image<1e-10] = 0                
                max_val = np.max(phase_image)

                venc = venc_df.loc[(venc_df['patient'] == patient) & (venc_df['vessel'] == vessel)].venc.values[0]
                angles = phase2angle(phase_image, venc)
                mag_image = (mag_image - np.min(mag_image))/(np.max(mag_image))
                mag_image[mag_image>=1] = 1

                mask[mask > 0.5] = 1
                mask[mask <= 0.5] = 0
                mask = mask.astype('uint8')
                
                if self.cohort == 'train':
                    phase_image = phase_image.astype('float32')/max_val
                    mask_phase = np.stack([mask, phase_image], -1)
                    
                    aug = get_volumentation(image_size, frames, vessel)
                    aug_data = aug(image = mag_image, mask = mask_phase)
                    mag_image, mask_phase = aug_data['image'], aug_data['mask']
                    mask, phase_image = mask_phase[...,0], mask_phase[...,1]
                    angles = phase2angle(phase_image, venc)
                    
                mag_image = skimage.exposure.equalize_adapthist(mag_image)
                complex_image = create_complex_image(mag_image, angles)
                real_image, imaginary_image = complex_image[...,0],complex_image[...,1]
                if self.cohort == 'train' and random.random()<0.5:
                    imaginary_image = -imaginary_image
                mag_image = normalise(mag_image)        
                imaginary_image = normalise(imaginary_image)        
                phase_image = normalise(phase_image)        

                X = np.stack([mag_image, imaginary_image], -1)
                one_hot_mask = np.zeros((image_size, image_size, frames, num_vessels), dtype='uint8')
                bkg = (mask == 0).astype('uint8')  # Background is where mask is not 1
                one_hot_mask[..., 0] = bkg
                one_hot_mask[..., vessel_index] = mask
                y = one_hot_mask
                
                cgm_input = tf.one_hot(vessel_index, len(vessels_dict))
                if self.cohort == 'test':
                    description = series_description_df.loc[patient,vessel].seriesdescription
                    description = description.replace('_',' ').replace('.',' ').replace('x','').replace('  ',' ').split(' ')
                    labels = []
                    for token in description:
                        strings = is_token_in_dictionary(token, data_dictionary.keys()) + is_token_a_substring_in_dictionary(data_dictionary.keys(), description)
                        if len(strings)>0:
                            for string in strings:
                                labels.append(data_dictionary[string])
                        print(labels, strings)
                    if len(labels) == 0:
                        label = 0
                    else:
                        label = pd.Series(labels).value_counts().index[0]
                        
                    one_hot = vessels_dict[label] if label in vessels_dict.keys() else 0
                    print(description, one_hot, label)
                    one_hot_input = tf.one_hot(one_hot, len(vessels_dict))#[np.newaxis] 
                else:
                    if self.cohort == 'train' and random.random() < 0.05:
                        one_hot_input = tf.one_hot(random.randint(0,5), len(vessels_dict))
                    else:
                        one_hot_input = cgm_input
                yield {'image_input':X.astype('float32'), 'cgm_input': cgm_input,'one_hot_input':one_hot_input,'mask_input': y.astype('uint8')},  y

    def get_gen(self):
        return self.data_generator()


In [17]:
input_channel = 2
out_channels =len(vessels_dict)

input_shape = [image_size,image_size, frames, input_channel]
output_shape = [image_size,image_size, frames, out_channels]


train_gen = CustomDataGen(train_patients, 'train').get_gen
val_gen   = CustomDataGen(val_patients, 'val').get_gen

output_signature = (
    {'image_input': tf.TensorSpec(shape=input_shape, dtype=tf.float32), 
     'cgm_input': tf.TensorSpec(shape=[6,], dtype=tf.uint8),
     'one_hot_input': tf.TensorSpec(shape=[6,], dtype=tf.uint8),
    'mask_input':    tf.TensorSpec(shape=output_shape, dtype=tf.uint8),
    },
    tf.TensorSpec(shape=output_shape, dtype=tf.uint8))

train_ds = tf.data.Dataset.from_generator(train_gen, output_signature = output_signature)

val_ds = tf.data.Dataset.from_generator(val_gen, output_signature = output_signature)

BATCH_SIZE = 8
train_ds = train_ds.shuffle(int(len(train_patients)/8), seed = 42, reshuffle_each_iteration=True).batch(BATCH_SIZE).prefetch(-1)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(-1)

In [None]:
if continue_training:
    model =  tf.keras.models.load_model(f'models/{model_name}.h5', compile = False)
    print('model loaded')
else:
    model = build_multiflowseg()

In [15]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 image_input (InputLayer)    [(None, 128, 128, 32, 2)]    0         []                            
                                                                                                  
 conv3d (Conv3D)             (None, 128, 128, 32, 16)     880       ['image_input[0][0]']         
                                                                                                  
 batch_normalization (Batch  (None, 128, 128, 32, 16)     64        ['conv3d[0][0]']              
 Normalization)                                                                                   
                                                                                                  
 leaky_re_lu (LeakyReLU)     (None, 128, 128, 32, 16)     0         ['batch_normalization[0][0

 ling3D)                                                                                          
                                                                                                  
 conv3d_9 (Conv3D)           (None, 16, 16, 32, 128)      221312    ['encoding_3_maxpool[0][0]']  
                                                                                                  
 batch_normalization_9 (Bat  (None, 16, 16, 32, 128)      512       ['conv3d_9[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 leaky_re_lu_9 (LeakyReLU)   (None, 16, 16, 32, 128)      0         ['batch_normalization_9[0][0]'
                                                                    ]                             
                                                                                                  
 conv3d_10

                                                                    ]                             
                                                                                                  
 conv3d_18 (Conv3D)          (None, 16, 16, 32, 16)       13840     ['fullscale_maxpool_2_4[0][0]'
                                                                    ]                             
                                                                                                  
 conv3d_19 (Conv3D)          (None, 16, 16, 32, 16)       6928      ['fullscale_maxpool_1_4[0][0]'
                                                                    ]                             
                                                                                                  
 batch_normalization_15 (Ba  (None, 16, 16, 32, 16)       64        ['conv3d_15[0][0]']           
 tchNormalization)                                                                                
          

 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_24 (Ba  (None, 32, 32, 32, 16)       64        ['conv3d_24[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_25 (Ba  (None, 32, 32, 32, 16)       64        ['conv3d_25[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 leaky_re_lu_21 (LeakyReLU)  (None, 32, 32, 32, 16)       0         ['batch_normalization_21[0][0]
                                                                    ']                            
          

                                                                                                  
 leaky_re_lu_30 (LeakyReLU)  (None, 64, 64, 32, 16)       0         ['batch_normalization_30[0][0]
                                                                    ']                            
                                                                                                  
 leaky_re_lu_31 (LeakyReLU)  (None, 64, 64, 32, 16)       0         ['batch_normalization_31[0][0]
                                                                    ']                            
                                                                                                  
 concatenate_2 (Concatenate  (None, 64, 64, 32, 80)       0         ['leaky_re_lu_27[0][0]',      
 )                                                                   'leaky_re_lu_28[0][0]',      
                                                                     'leaky_re_lu_29[0][0]',      
          

 leaky_re_lu_37 (LeakyReLU)  (None, 128, 128, 32, 16)     0         ['batch_normalization_37[0][0]
                                                                    ']                            
                                                                                                  
 tf.math.argmax (TFOpLambda  (None,)                      0         ['tf.compat.v1.gather[0][0]'] 
 )                                                                                                
                                                                                                  
 concatenate_3 (Concatenate  (None, 128, 128, 32, 80)     0         ['leaky_re_lu_33[0][0]',      
 )                                                                   'leaky_re_lu_34[0][0]',      
                                                                     'leaky_re_lu_35[0][0]',      
                                                                     'leaky_re_lu_36[0][0]',      
          

                                                                                                  
 tf.clip_by_value (TFOpLamb  (None, 128, 128, 32, 6)      0         ['tf.math.multiply[0][0]']    
 da)                                                                                              
                                                                                                  
 tf.cast (TFOpLambda)        (None, 128, 128, 32, 6)      0         ['mask_input[0][0]']          
                                                                                                  
 tf.__operators__.getitem_2  (None, 128, 128, 32)         0         ['tf.clip_by_value[0][0]']    
  (SlicingOpLambda)                                                                               
                                                                                                  
 tf.__operators__.getitem_1  (None, 128, 128, 32)         0         ['tf.cast[0][0]']             
  (Slicing

 tf.math.multiply_16 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_5[0
 ambda)                                                             ][0]',                        
                                                                     'tf.math.subtract_6[0][0]']  
                                                                                                  
 tf.math.subtract_7 (TFOpLa  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_5[0
 mbda)                                                              ][0]']                        
                                                                                                  
 tf.__operators__.getitem_7  (None, 128, 128, 32)         0         ['tf.cast[0][0]']             
  (SlicingOpLambda)                                                                               
                                                                                                  
 tf.math.s

 tf.math.reduce_sum_10 (TFO  ()                           0         ['tf.math.multiply_21[0][0]'] 
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_22 (TFOpL  (None, 128, 128, 32)         0         ['tf.math.subtract_10[0][0]', 
 ambda)                                                              'tf.__operators__.getitem_8[0
                                                                    ][0]']                        
                                                                                                  
 tf.math.multiply_26 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_9[0
 ambda)                                                             ][0]',                        
                                                                     'tf.math.subtract_12[0][0]'] 
          

                                                                                                  
 tf.math.reduce_sum_12 (TFO  ()                           0         ['tf.math.multiply_25[0][0]'] 
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_28 (TFOpL  ()                           0         ['tf.math.reduce_sum_13[0][0]'
 ambda)                                                             ]                             
                                                                                                  
 tf.math.reduce_sum_14 (TFO  ()                           0         ['tf.math.multiply_27[0][0]'] 
 pLambda)                                                                                         
                                                                                                  
 tf.math.m

                                                                                                  
 tf.math.pow_1 (TFOpLambda)  ()                           0         ['tf.math.subtract_5[0][0]']  
                                                                                                  
 tf.math.subtract_8 (TFOpLa  ()                           0         ['tf.math.truediv_2[0][0]']   
 mbda)                                                                                            
                                                                                                  
 tf.math.truediv_3 (TFOpLam  ()                           0         ['tf.__operators__.add_15[0][0
 bda)                                                               ]',                           
                                                                     'tf.__operators__.add_18[0][0
                                                                    ]']                           
          

                                                                                                  
 tf.__operators__.getitem_1  (None, 128, 128, 32)         0         ['tf.clip_by_value_1[0][0]']  
 4 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.__operators__.getitem_1  (None, 128, 128, 32)         0         ['tf.cast_1[0][0]']           
 3 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.subtract_18 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_14[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__oper

 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_1  (None, 128, 128, 32)         0         ['tf.cast_1[0][0]']           
 9 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.subtract_27 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_20[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_2  (None, 128, 128, 32)         0         ['tf.clip_by_value_1[0][0]']  
 2 (SlicingOpLambda)                                                                              
          

                                                                    0][0]']                       
                                                                                                  
 tf.math.multiply_56 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_21[
 ambda)                                                             0][0]',                       
                                                                     'tf.math.subtract_30[0][0]'] 
                                                                                                  
 tf.math.subtract_31 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_21[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_2  (None, 128, 128, 32)         0         ['tf.cast_1[0][0]']           
 3 (Slicin

 tf.math.multiply_58 (TFOpL  ()                           0         ['tf.math.reduce_sum_31[0][0]'
 ambda)                                                             ]                             
                                                                                                  
 tf.math.reduce_sum_32 (TFO  ()                           0         ['tf.math.multiply_57[0][0]'] 
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_60 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_23[
 ambda)                                                             0][0]',                       
                                                                     'tf.__operators__.getitem_24[
                                                                    0][0]']                       
          

 ambda)                                                                                           
                                                                                                  
 tf.math.truediv_10 (TFOpLa  ()                           0         ['tf.__operators__.add_46[0][0
 mbda)                                                              ]',                           
                                                                     'tf.__operators__.add_49[0][0
                                                                    ]']                           
                                                                                                  
 tf.__operators__.add_51 (T  ()                           0         ['tf.math.reduce_sum_30[0][0]'
 FOpLambda)                                                         ]                             
                                                                                                  
 tf.__oper

 6 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.__operators__.getitem_2  (None, 128, 128, 32)         0         ['tf.cast_2[0][0]']           
 5 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.subtract_36 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_26[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_2  (None, 128, 128, 32)         0         ['tf.clip_by_value_2[0][0]']  
 8 (SlicingOpLambda)                                                                              
          

 tf.__operators__.getitem_3  (None, 128, 128, 32)         0         ['tf.cast_2[0][0]']           
 1 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.subtract_45 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_32[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_3  (None, 128, 128, 32)         0         ['tf.clip_by_value_2[0][0]']  
 4 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.__operators__.add_63 (T  ()                           0         ['tf.math.reduce_sum_36[0][0]'
 FOpLambda

 tf.math.multiply_86 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_33[
 ambda)                                                             0][0]',                       
                                                                     'tf.math.subtract_48[0][0]'] 
                                                                                                  
 tf.math.subtract_49 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_33[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_3  (None, 128, 128, 32)         0         ['tf.cast_2[0][0]']           
 5 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.s

                                                                                                  
 tf.math.reduce_sum_50 (TFO  ()                           0         ['tf.math.multiply_87[0][0]'] 
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_90 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_35[
 ambda)                                                             0][0]',                       
                                                                     'tf.__operators__.getitem_36[
                                                                    0][0]']                       
                                                                                                  
 tf.math.reduce_sum_52 (TFO  ()                           0         ['tf.math.multiply_91[0][0]'] 
 pLambda) 

 ambda)                                                                                           
                                                                                                  
 tf.math.truediv_17 (TFOpLa  ()                           0         ['tf.__operators__.add_77[0][0
 mbda)                                                              ]',                           
                                                                     'tf.__operators__.add_80[0][0
                                                                    ]']                           
                                                                                                  
 tf.__operators__.add_82 (T  ()                           0         ['tf.math.reduce_sum_48[0][0]'
 FOpLambda)                                                         ]                             
                                                                                                  
 tf.__oper

                                                                                                  
 tf.__operators__.getitem_3  (None, 128, 128, 32)         0         ['tf.clip_by_value_3[0][0]']  
 8 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.__operators__.getitem_3  (None, 128, 128, 32)         0         ['tf.cast_3[0][0]']           
 7 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.subtract_54 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_38[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__oper

 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_4  (None, 128, 128, 32)         0         ['tf.cast_3[0][0]']           
 3 (SlicingOpLambda)                                                                              
                                                                                                  
 tf.math.subtract_63 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_44[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_4  (None, 128, 128, 32)         0         ['tf.clip_by_value_3[0][0]']  
 6 (SlicingOpLambda)                                                                              
          

 Lambda)                                                             'tf.__operators__.getitem_44[
                                                                    0][0]']                       
                                                                                                  
 tf.math.multiply_116 (TFOp  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_45[
 Lambda)                                                            0][0]',                       
                                                                     'tf.math.subtract_66[0][0]'] 
                                                                                                  
 tf.math.subtract_67 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_45[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__oper

 tf.math.reduce_sum_66 (TFO  ()                           0         ['tf.math.multiply_115[0][0]']
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_118 (TFOp  ()                           0         ['tf.math.reduce_sum_67[0][0]'
 Lambda)                                                            ]                             
                                                                                                  
 tf.math.reduce_sum_68 (TFO  ()                           0         ['tf.math.multiply_117[0][0]']
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_120 (TFOp  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_47[
 Lambda)  

 tf.__operators__.add_97 (T  ()                           0         ['tf.math.pow_18[0][0]']      
 FOpLambda)                                                                                       
                                                                                                  
 tf.math.pow_19 (TFOpLambda  ()                           0         ['tf.math.subtract_59[0][0]'] 
 )                                                                                                
                                                                                                  
 tf.math.subtract_62 (TFOpL  ()                           0         ['tf.math.truediv_23[0][0]']  
 ambda)                                                                                           
                                                                                                  
 tf.math.truediv_24 (TFOpLa  ()                           0         ['tf.__operators__.add_108[0][
 mbda)    

                                                                                                  
 add_metric_3 (AddMetric)    ()                           0         ['tf.math.truediv_27[0][0]']  
                                                                                                  
 tf.clip_by_value_4 (TFOpLa  (None, 128, 128, 32, 6)      0         ['tf.math.multiply_4[0][0]']  
 mbda)                                                                                            
                                                                                                  
 tf.cast_4 (TFOpLambda)      (None, 128, 128, 32, 6)      0         ['mask_input[0][0]']          
                                                                                                  
 tf.__operators__.getitem_5  (None, 128, 128, 32)         0         ['tf.clip_by_value_4[0][0]']  
 0 (SlicingOpLambda)                                                                              
          

                                                                    0][0]']                       
                                                                                                  
 tf.math.multiply_136 (TFOp  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_53[
 Lambda)                                                            0][0]',                       
                                                                     'tf.math.subtract_78[0][0]'] 
                                                                                                  
 tf.math.subtract_79 (TFOpL  (None, 128, 128, 32)         0         ['tf.__operators__.getitem_53[
 ambda)                                                             0][0]']                       
                                                                                                  
 tf.__operators__.getitem_5  (None, 128, 128, 32)         0         ['tf.cast_4[0][0]']           
 5 (Slicin

 Lambda)                                                            0][0]',                       
                                                                     'tf.__operators__.getitem_56[
                                                                    0][0]']                       
                                                                                                  
 tf.math.reduce_sum_82 (TFO  ()                           0         ['tf.math.multiply_141[0][0]']
 pLambda)                                                                                         
                                                                                                  
 tf.math.multiply_142 (TFOp  (None, 128, 128, 32)         0         ['tf.math.subtract_82[0][0]', 
 Lambda)                                                             'tf.__operators__.getitem_56[
                                                                    0][0]']                       
          

                                                                                                  
 tf.__operators__.add_140 (  ()                           0         ['tf.math.reduce_sum_81[0][0]'
 TFOpLambda)                                                        , 'tf.math.multiply_143[0][0]'
                                                                    ]                             
                                                                                                  
 tf.math.multiply_144 (TFOp  ()                           0         ['tf.math.reduce_sum_83[0][0]'
 Lambda)                                                            ]                             
                                                                                                  
 tf.math.reduce_sum_84 (TFO  ()                           0         ['tf.math.multiply_145[0][0]']
 pLambda)                                                                                         
          

                                                                                                  
 tf.__operators__.add_150 (  ()                           0         ['tf.math.reduce_sum_87[0][0]'
 TFOpLambda)                                                        , 'tf.math.multiply_153[0][0]'
                                                                    ]                             
                                                                                                  
 tf.math.multiply_154 (TFOp  ()                           0         ['tf.math.reduce_sum_89[0][0]'
 Lambda)                                                            ]                             
                                                                                                  
 tf.__operators__.add_128 (  ()                           0         ['tf.math.pow_24[0][0]']      
 TFOpLambda)                                                                                      
          

 )                                                                                                
                                                                                                  
 tf.__operators__.add_153 (  ()                           0         ['tf.__operators__.add_148[0][
 TFOpLambda)                                                        0]',                          
                                                                     'tf.math.pow_29[0][0]']      
                                                                                                  
 tf.math.truediv_34 (TFOpLa  ()                           0         ['tf.__operators__.add_153[0][
 mbda)                                                              0]']                          
                                                                                                  
 add_metric_4 (AddMetric)    ()                           0         ['tf.math.truediv_34[0][0]']  
          

In [None]:
from keras.callbacks import EarlyStopping, ModelCheckpoint
mc = ModelCheckpoint(f'models/{model_name}.h5',
                  save_best_only= True,
                    monitor='val_output1_loss',
                    mode='min')
model.fit(train_ds,
          validation_data = val_ds, 
          epochs=400-324,
          callbacks=[mc])

Epoch 1/76
Epoch 2/76
Epoch 3/76
  3/116 [..............................] - ETA: 8:48 - loss: 2.0506 - output5_loss: 0.6948 - output4_loss: 0.4864 - output3_loss: 0.3556 - output2_loss: 0.2751 - output1_loss: 0.2375 - cgm_loss: 0.0050 - cgm_focal_loss: 0.0106