In [None]:
import os
import tensorflow as tf
import trimesh
from spektral.utils import chebyshev_filter
os.environ["CUDA_VISIBLE_DEVICES"]="0"
from utils import *
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from tensorflow_addons.layers import InstanceNormalization

# Read the cohort split from disk. The JSON object is indexed by the keys
# 'train', 'val' and 'test'; each entry is concatenated below, so they are
# expected to be lists of patient identifiers.
with open("patients.json", 'r') as json_file:
    patients = json.load(json_file)
train_patients, val_patients, test_patients = (
    patients[cohort_key] for cohort_key in ('train', 'val', 'test')
)
# From here on `patients` is the sorted, de-duplicated union of all three
# cohorts with every entry of `remove_patients` dropped.
patients = sorted(
    set(train_patients + val_patients + test_patients) - set(remove_patients)
)


In [2]:
# Weights of the individual loss terms. They are gathered into the `lambdas`
# list later in the notebook and multiplied, position by position, with the
# per-stage loss tensors before those are summed.
lambda_chamfer = 1  # chamfer distance over all mesh vertices
lambda_chamfer_surface = 1  # chamfer distance restricted to surface vertices
lambda_dev_edge = 0.1  # edge-length deviation (std / mean) over template_edges
lambda_dev_surface = 0.1  # edge-length deviation (std / mean) over surface edges
lambda_aspect_ratio  = 0.5  # mean element aspect ratio
lambda_cap = 0.05  # co-planarity penalty on the caps (cell ids 2, 3, 4)
lambda_cfd = 15  # error on the predicted pressure / velocity field (mae_loss)

In [3]:
# Point-data array names used to index `point_data` of the CFD .vtu files.
# The leading spaces are part of the keys and must be kept exactly as written.
# NOTE(review): the padding presumably comes from the solver's export format —
# confirm against one of the files.
pressure_name = '        pressure'
velocity_name = 'velocity-magnitude'
x_velocity_name = '      x-velocity'
y_velocity_name = '      y-velocity'
z_velocity_name = '      z-velocity'

In [4]:
# Load the template mesh and derive the graph operators and geometric
# reference quantities used by the network and by the mesh-quality losses.
template_mesh = pv.read(f'{data_path}/REGPA-43/final_template.vtu')
cell_entity_ids_to_extract = np.where(template_mesh['CellEntityIds'] == 0)[0]

# The sub-mesh made of the cells with CellEntityIds == 0 was previously
# re-extracted for every statement that needed it (six times); extract it once.
# NOTE(review): the reshape(-1, 5) below assumes these are 4-node cells
# (presumably the tetrahedral volume) — confirm.
template_volume_mesh = template_mesh.extract_cells(cell_entity_ids_to_extract)

# Vertex adjacency of the extracted sub-mesh, as scipy CSR and as a TF sparse tensor.
adj_matrix = vtu_adjacency(template_volume_mesh)
template_adj = sparse.csr_matrix(adj_matrix)
template_adj_tf = sparse_csr_to_tf(template_adj)

# NOTE(review): `template_vertices` comes from the full template while
# `template_edges` indexes the points of the extracted sub-mesh; this relies on
# both sharing the same point numbering — confirm.
template_vertices = np.array(template_mesh.points)
# Drop the leading per-cell count column, keeping the 4 vertex ids of each cell.
template_edges = template_volume_mesh.cells.reshape(-1, 5)[:, 1:]
template_edges_lines = template_vertices[template_edges]
template_edge_lengths = calculate_edge_length(template_edges_lines)

# Chebyshev filter operators consumed by every ChebConv layer.
L = chebyshev_filter(template_adj, 1, symmetric=True)
L = [sparse_csr_to_tf(l) for l in L]

# Reference mesh-quality statistics of the undeformed template.
template_aspect_ratio = get_template_aspect_ratio(template_volume_mesh)
template_dev_edge = get_template_dev_edge(template_volume_mesh, surface = False)
template_dev_surface = get_template_dev_edge(template_volume_mesh, surface = True)
surface_template_mesh = template_volume_mesh.extract_surface()

# Ids of the surface points in the numbering of the mesh they were extracted from.
surface_vertex_indices = surface_template_mesh['vtkOriginalPointIds']
# Rebuild the surface as a trimesh object (faces stored as [3, a, b, c] rows)
# to obtain its unique edges and per-vertex neighbourhoods.
surface_template_mesh = trimesh.Trimesh(surface_template_mesh.points, surface_template_mesh.faces.reshape(-1, 4)[:, 1:])
surface_template_edges = surface_template_mesh.edges
vertex_neighbors = surface_template_mesh.vertex_neighbors

# From here on `surface_template_mesh` is the pyvista surface of the *full* template.
surface_template_mesh = template_mesh.extract_surface()


In [6]:
# Cell-entity ids of the three caps of the template and, for each of them, the
# index data and faces returned by `extract_cap_faces` (its third return value
# is discarded).
cell_id_2, cell_id_3, cell_id_4 = 2, 3, 4
(
    (cap_indices_2, cap_faces_2, _),
    (cap_indices_3, cap_faces_3, _),
    (cap_indices_4, cap_faces_4, _),
) = [
    extract_cap_faces(template_mesh, cap_id)
    for cap_id in (cell_id_2, cell_id_3, cell_id_4)
]

In [11]:
from sklearn.preprocessing import StandardScaler

# Collect the training-set CFD fields used to fit the normalisation statistics.
# Pressure is cube-rooted before fitting; the three velocity components share a
# single scaler, so they are pooled into one list.
P = []
V = []
for patient in train_patients:
    try:
        cfd_vtu_paths = glob.glob(f'{data_path}/REGPA-56/cfd_vtu/{patient}*')
        for cfd_vtu_path in cfd_vtu_paths:
            data = meshio.read(cfd_vtu_path)
            # Read every field before appending anything, so a file with a
            # missing array contributes to neither list.
            pressure =  np.cbrt(data.point_data[pressure_name])
            xvelocity =  data.point_data[x_velocity_name]
            yvelocity =  data.point_data[y_velocity_name]
            zvelocity =  data.point_data[z_velocity_name]

            P.append(pressure)
            V.append(xvelocity)
            V.append(yvelocity)
            V.append(zvelocity)
    except Exception as exc:
        # Best effort: skip the rest of this patient's files but say why.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        print(patient, f'({type(exc).__name__}: {exc})')

# One column per scaler: every sample of every file stacked vertically.
P = np.concatenate(P, 0).reshape(-1, 1)
V = np.concatenate(V, 0).reshape(-1, 1)

pressure_scaler = StandardScaler()
velocity_scaler = StandardScaler()

pressure_scaler.fit(P)
velocity_scaler.fit(V)
try:
    # StandardScaler exposes `mean_` / `var_`.
    pressure_mean = pressure_scaler.mean_
    pressure_std = np.sqrt(pressure_scaler.var_)

    velocity_mean = velocity_scaler.mean_
    velocity_std = np.sqrt(velocity_scaler.var_)
except AttributeError:
    # Fallback for scalers that expose `center_` / `scale_` instead (e.g. if
    # the scaler class above is swapped for a robust one).
    pressure_mean = pressure_scaler.center_
    pressure_std = pressure_scaler.scale_

    velocity_mean = velocity_scaler.center_
    velocity_std = velocity_scaler.scale_

In [13]:
class CustomDataGen():
    """Python generator feeding the tf.data pipelines with per-patient samples.

    Parameters
    ----------
    patients : iterable
        Patient identifiers; each one is used to build the image path and the
        CFD file pattern.
    cohort : str
        'train' yields one randomly chosen CFD file per patient, 'val' yields up
        to two, 'test' yields exactly the file ``{patient}_v={v}.vtu``.
    v : number, optional
        Inlet-velocity tag used to build the file name; only read for 'test'.
    """
    def __init__(self, 
                 patients,
                 cohort,
                 v = 0
                ):
        self.patients = patients
        self.cohort = cohort
        self.v = v

    @staticmethod
    def _load_template_fields():
        """Read and normalise the template CFD fields.

        Returns the tuple ``(template_pressure, template_velocity, template_pv)``
        with 1, 3 and 4 channels on the last axis. Pressure is cube-rooted and
        then standardised; velocity components are standardised with the shared
        velocity statistics.
        """
        template_data = meshio.read(f'{data_path}/REGPA-56/final_template_mesh_REGPA-56.vtu')
        normalized_pressure = (np.cbrt(template_data.point_data[pressure_name]) - pressure_mean)/pressure_std
        normalized_components = [
            (template_data.point_data[field_name] - velocity_mean)/velocity_std
            for field_name in (x_velocity_name, y_velocity_name, z_velocity_name)
        ]
        template_pressure = np.array(normalized_pressure)[...,np.newaxis]
        template_velocity = np.stack(normalized_components, -1)
        template_pv = np.stack([normalized_pressure, *normalized_components], -1)
        return template_pressure, template_velocity, template_pv

    def data_generator(self):
        """Yield ``(inputs, target)`` pairs.

        ``inputs`` is the dict of named model inputs; ``target`` is the
        normalised 4-channel (pressure, x/y/z-velocity) field, i.e. the same
        array that is passed as ``'true_pv_input'``.
        """
        # The template file does not depend on the patient, so it is read once
        # per generator run instead of once per sample. It is loaded lazily so
        # that a run which yields nothing never touches the file.
        template_fields = None
        for patient in self.patients:
            image = np.load(f"{data_path}/images/{patient}.npy")
            cfd_vtu_paths = glob.glob(f'{data_path}/REGPA-56/cfd_vtu/{patient}*')
            random.shuffle(cfd_vtu_paths)
            if self.cohort == 'train':
                # A slice (not [0]) so that a patient without CFD files is
                # skipped, as in 'val', instead of raising IndexError.
                cfd_vtu_paths = cfd_vtu_paths[:1]
            elif self.cohort == 'val':
                cfd_vtu_paths = cfd_vtu_paths[:2]
            elif self.cohort == 'test':
                cfd_vtu_path =  f'{data_path}/REGPA-56/cfd_vtu/{patient}_v={self.v}.vtu'
                cfd_vtu_paths = [cfd_vtu_path]
                
            for cfd_vtu_path in cfd_vtu_paths:
                # The inlet velocity is encoded in the file name after 'v='; it
                # is normalised with the same scaler as the velocity field.
                v = float(cfd_vtu_path.split('v=')[-1].replace('.vtu',''))
                v = velocity_scaler.transform(np.array([v]).reshape(-1,1))
                data = pv.read(cfd_vtu_path)
                vertices = data.points
                true_surface_vertices = data.extract_surface().points

                # Cube-root the pressure before standardising, mirroring how
                # the pressure scaler was fitted.
                normalized_pressure = (np.cbrt(data.point_data[pressure_name]) - pressure_mean)/pressure_std
                normalized_components = [
                    (data.point_data[field_name] - velocity_mean)/velocity_std
                    for field_name in (x_velocity_name, y_velocity_name, z_velocity_name)
                ]

                true_pressure = np.array(normalized_pressure)[...,np.newaxis]
                true_velocity = np.stack(normalized_components, -1)
                true_pv = np.stack([normalized_pressure, *normalized_components], -1)

                if template_fields is None:
                    template_fields = self._load_template_fields()
                template_pressure, template_velocity, template_pv = template_fields

                # Broadcast the scalar inlet velocity to one value per vertex.
                inlet_velocity = np.array([v] * len(vertices)).reshape(-1,1)
                # Training-only augmentation, applied to 20% of the samples.
                if self.cohort == 'train' and random.random() < 0.2:
                    image, vertices, true_surface_vertices = random_roll_and_translate(image, vertices, true_surface_vertices)

                yield {'image_input':normalize(image[...,np.newaxis]), 
                       'template_vertices_input':template_vertices, 
                       'template_pressure_input':template_pressure,
                       'template_velocity_input':template_velocity,
                       'template_pv_input':template_pv,
                       'true_vertices_input':vertices,
                       'true_pressure_input':true_pressure,
                       'true_velocity_input':true_velocity,
                       'true_pv_input':true_pv,
                       'true_surface_vertices_input':true_surface_vertices,
                       'inlet_velocity_input':inlet_velocity
                      }, true_pv

    def get_gen(self):
        """Return a fresh generator; passed (uncalled) to Dataset.from_generator."""
        return self.data_generator()    

# dtype structure handed to tf.data.Dataset.from_generator: a dict with one
# float32 entry per named model input, followed by the float32 target.
_generator_input_names = (
    'image_input',
    'template_vertices_input',
    'template_pressure_input',
    'template_velocity_input',
    'template_pv_input',
    'true_vertices_input',
    'true_pressure_input',
    'true_velocity_input',
    'true_pv_input',
    'true_surface_vertices_input',
    'inlet_velocity_input',
)
output_types = (
    {input_name: tf.float32 for input_name in _generator_input_names},
    tf.float32,
)


In [14]:
# Build the tf.data input pipelines. `get_gen` is passed uncalled so that
# tf.data can create a fresh Python generator for every epoch.
train_gen = CustomDataGen(train_patients, 'train').get_gen
val_gen   = CustomDataGen(val_patients, 'val').get_gen

train_ds = tf.data.Dataset.from_generator(train_gen, output_types=output_types)

val_ds = tf.data.Dataset.from_generator(val_gen, output_types=output_types)


BATCH_SIZE = 1
# prefetch(-1): -1 is tf.data's AUTOTUNE sentinel (buffer size chosen at runtime).
train_ds = train_ds.shuffle(200, seed = 42, reshuffle_each_iteration=True).batch(BATCH_SIZE).prefetch(-1)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(-1)

# The Chebyshev operators `L` are already built from `template_adj` when the
# template mesh is loaded; recomputing the identical list here was redundant
# and has been removed.

In [15]:
# Pull a single batch from the training pipeline. `X` is reused further down to
# read the image shape for the model's image input layer.
X, y = next(iter(train_ds))

In [20]:
def calculate_normals(vertices, faces):
    """Return one unit-length normal per triangular face.

    `vertices` holds the point coordinates and `faces` the three vertex indices
    of every triangle (columns 0, 1, 2). The normal is the normalised cross
    product of the two edges leaving the first corner.
    """
    corner_a, corner_b, corner_c = (
        tf.gather(vertices, faces[:, corner]) for corner in range(3)
    )
    first_edge = corner_b - corner_a
    second_edge = corner_c - corner_a
    unnormalised = tf.linalg.cross(first_edge, second_edge)
    return tf.math.l2_normalize(unnormalised, axis=1)

def cap_loss(pred_coords, cap_indices, cap_faces):
    """Co-planarity penalty for one cap of the predicted mesh.

    Parameters
    ----------
    pred_coords : tensor
        Batched predicted vertex coordinates; only the first batch element is used.
    cap_indices : sequence of two index arrays
        Applied one after the other (`cap_indices[0]`, then `cap_indices[1]`)
        to select the cap vertices that `cap_faces` refers to.
    cap_faces : integer array
        Triangles of the cap, indexing into the selected vertices.

    Returns
    -------
    Scalar tensor: the sum over cap faces of the squared Euclidean distance
    between the face's unit normal and the mean normal of the cap.
    """
    pred_coords = pred_coords[0]

    # Two-level selection of the cap vertices.
    cap_coords = tf.gather(pred_coords, cap_indices[0])
    cap_coords = tf.gather(cap_coords, cap_indices[1])

    cap_normals = calculate_normals(cap_coords, cap_faces)
    mean_normal = tf.reduce_mean(cap_normals, axis=0)

    # Squared L2 distance computed directly as a sum of squares. This has the
    # same value as tf.square(tf.norm(...)) but avoids the square root, whose
    # gradient is undefined (NaN) when a face normal equals the mean normal —
    # which is exactly the case for a perfectly flat cap.
    squared_distances = tf.reduce_sum(tf.square(cap_normals - mean_normal), axis=1)

    co_planar_loss = tf.reduce_sum(squared_distances)
    return co_planar_loss


def aspect_ratio(pred_coords, surface = True):
    """Mean ratio of the longest to the shortest measured side per element.

    Only the first batch element of `pred_coords` is used. With `surface=True`
    the elements are the rows of `surface_template_edges` (indexed within the
    surface vertices); otherwise they are the rows of `template_edges`. For
    every element the distances between consecutive vertices are measured and
    max / min of those distances is averaged over all elements.

    NOTE(review): for rows with only two vertices there is a single distance
    per row, so the ratio would always be 1 — confirm `surface=True` is meant
    to be supported. A zero-length side makes the ratio infinite.
    """
    coords = pred_coords[0]
    if surface:
        surface_coords = tf.gather(coords, surface_vertex_indices)
        element_vertices = tf.gather(surface_coords, surface_template_edges)
    else:
        element_vertices = tf.gather(coords, template_edges)

    # Vector between each vertex of an element and the next one.
    consecutive_offsets = element_vertices[:, :-1] - element_vertices[:, 1:]
    side_lengths = tf.norm(consecutive_offsets, axis=2)
    per_element_ratio = tf.reduce_max(side_lengths, axis=1) / tf.reduce_min(side_lengths, axis=1)

    # The clamp at 0 is kept from the original (subtracting the template
    # baseline is currently disabled, so the value is never negative).
    return tf.maximum(tf.reduce_mean(per_element_ratio), 0)

def mean_edge_length(pred_coords):
    """Average edge length of the predicted mesh, clamped below at 0.

    Uses the first batch element of `pred_coords`, gathers the vertices listed
    in `template_edges` and averages what `calculate_edge_length` returns.
    """
    edge_endpoints = tf.gather(pred_coords[0], template_edges)
    average_length = tf.reduce_mean(calculate_edge_length(edge_endpoints))
    return tf.maximum(average_length, 0)

def dev_edge_length(pred_coords, surface = False):
    """Relative spread (std / mean) of the predicted edge lengths.

    Only the first batch element of `pred_coords` is used. With `surface=True`
    the edges are `surface_template_edges` taken within the surface vertices;
    otherwise they are `template_edges` taken within all vertices. The result
    is clamped below at 0 (subtracting the template baseline is currently
    disabled).
    """
    coords = pred_coords[0]
    if not surface:
        edge_endpoints = tf.gather(coords, template_edges)
    else:
        surface_coords = tf.gather(coords, surface_vertex_indices)
        edge_endpoints = tf.gather(surface_coords, surface_template_edges)

    lengths = calculate_edge_length(edge_endpoints)
    average_length = tf.reduce_mean(lengths)
    length_spread = tf.math.reduce_std(lengths)
    relative_spread = tf.divide(length_spread, average_length)
    return tf.maximum(relative_spread, 0)


def chamfer_distance(y_true, y_pred, surface):
    """Symmetric chamfer distance between two point clouds.

    Parameters
    ----------
    y_true, y_pred : tensors
        Batched point coordinates with 3 values on the last axis; only the
        first batch element of each is used.
    surface : bool
        When true, `y_pred` is first restricted to `surface_vertex_indices`.

    Returns
    -------
    Scalar tensor: mean nearest-neighbour distance from the predicted points to
    the true points plus the mean nearest-neighbour distance from the true
    points to the predicted points.

    The pairwise distances are obtained by broadcasting, so the two clouds may
    contain different numbers of points (the former tile/reshape construction
    required equal sizes and materialised two full-size copies).
    """
    y_pred = y_pred[0]
    y_true = y_true[0]
    if surface:
        y_pred = tf.gather(y_pred, surface_vertex_indices)

    # distances[i, j] = || y_pred[i] - y_true[j] ||
    distances = tf.norm(y_pred[:, tf.newaxis, :] - y_true[tf.newaxis, :, :], axis=-1)

    av_dist1 = tf.reduce_mean(tf.reduce_min(distances, axis=1))  # pred -> true
    av_dist2 = tf.reduce_mean(tf.reduce_min(distances, axis=0))  # true -> pred

    return av_dist1 + av_dist2

def split_tensor(x):
    """Split `x` along its third axis into (first channel, remaining channels).

    Both parts keep the axis, so a (B, N, 4) input gives (B, N, 1) and (B, N, 3).
    """
    first_channel = x[:, :, :1]
    remaining_channels = x[:, :, 1:]
    return first_channel, remaining_channels

def nmae(y_true, y_pred, pressure, mean = 0, std = 0, loss = False):
    """Mean absolute error as a percentage of the range of `y_true`.

    With `loss=True` both inputs are first converted: the first batch element
    is taken, de-normalised with `std` and `mean`, cubed when `pressure` is
    true (undoing the cube root applied before normalisation), and reduced to
    its Euclidean norm over the last axis. With `loss=False` the inputs are
    compared as given.

    NOTE(review): for a single-channel input the norm over the last axis is
    the absolute value, so the sign of the pressure is dropped — confirm this
    is intended.
    """
    def _physical_magnitude(normalised):
        physical = normalised[0]
        physical = physical * std + mean
        if pressure:
            physical = tf.math.pow(physical, 3)
        return tf.norm(physical, axis = -1)

    if loss:
        y_true = _physical_magnitude(y_true)
        y_pred = _physical_magnitude(y_pred)

    value_range = tf.reduce_max(y_true) - tf.reduce_min(y_true)
    # epsilon guards against a constant `y_true` (zero range).
    return tf.reduce_mean(tf.abs(y_true - y_pred) * 100 / (value_range + tf.keras.backend.epsilon()))

In [24]:
def calculate_assd_tf(surface1_points, surface2_points):
    """Average symmetric surface distance between two point sets.

    Only the first batch element of each argument is used. The result is the
    mean of (a) the average distance from every point of the first set to its
    nearest point in the second set and (b) the same quantity in the other
    direction.
    """
    surface1_points = surface1_points[0]
    surface2_points = surface2_points[0]

    # pairwise[i, j] = || surface1[i] - surface2[j] ||. The matrix for the
    # opposite direction is just its transpose, so it is built once and reduced
    # along both axes instead of being computed a second time.
    pairwise = tf.norm(surface1_points[:, tf.newaxis, :] - surface2_points, axis=-1)

    mean_1_to_2 = tf.reduce_mean(tf.reduce_min(pairwise, axis=1))
    mean_2_to_1 = tf.reduce_mean(tf.reduce_min(pairwise, axis=0))
    assd = (mean_1_to_2 + mean_2_to_1) / 2.0
    return assd



In [25]:
from tensorflow.keras.layers import Input, Lambda
def graph_res_block(inputs, filters):
    """Residual block of three ChebConv + InstanceNormalization stages.

    The first stage produces the skip tensor. Two further stages (with a
    LeakyReLU between them) are applied on top of it, and the result is
    averaged with the skip tensor before a final LeakyReLU. Every ChebConv
    uses the module-level Chebyshev operators `L`.
    """
    def _cheb_norm(tensor):
        # One graph convolution followed by instance normalisation.
        convolved = ChebConv(filters, kernel_initializer='he_normal')([tensor, L])
        return InstanceNormalization()(convolved)

    skip = _cheb_norm(inputs)

    branch = _cheb_norm(skip)
    branch = Activation('LeakyReLU')(branch)
    branch = _cheb_norm(branch)

    merged = Average()([skip, branch])
    return Activation('LeakyReLU')(merged)

def conv_res_block(inputs, num_filters):
    """3D convolutional residual block with a convolutional shortcut.

    Both the shortcut and the main path start from `inputs`:
    - shortcut: Conv3D -> InstanceNormalization
    - main:     Conv3D -> InstanceNormalization -> LeakyReLU ->
                SpatialDropout3D(0.3) -> Conv3D -> InstanceNormalization
    The two are added and passed through a LeakyReLU.
    """
    def _conv_norm(tensor):
        # 3x3x3 'same' convolution followed by instance normalisation over channels.
        convolved = Conv3D(num_filters, kernel_size = 3, padding='same',kernel_initializer = 'he_normal')(tensor)
        return InstanceNormalization(axis = -1)(convolved)

    # Shortcut branch (created first, as in the original layer order).
    shortcut = _conv_norm(inputs)

    # Main branch, also fed from `inputs`.
    main = _conv_norm(inputs)
    main = Activation('LeakyReLU')(main)
    main = SpatialDropout3D(rate=0.3)(main)
    main = _conv_norm(main)

    merged = Add()([shortcut, main])
    return Activation('LeakyReLU')(merged)


# ---------------------------------------------------------------------------
# Model graph: a 3D CNN encoder on the image, and three graph-convolution
# stages on the template mesh. Each stage adds a displacement to the current
# vertex positions (output3 -> output2 -> output1) and a correction to the
# current pressure/velocity field (CFD3 -> CFD2 -> CFD1).
# ---------------------------------------------------------------------------
dim = 3
# Per-sample shapes (without the batch axis), all with one row per template vertex.
point_cloud_shape = (len(template_vertices), 3)
p_shape = (len(template_vertices), 1)
v_shape = (len(template_vertices), 3)
pv_shape = (len(template_vertices), 4)

tf.keras.backend.clear_session()
# Image shape is read from the sample batch `X` drawn earlier (batch axis dropped).
image_shape = X['image_input'].shape[1:]
conv = Conv3D
maxpool =  MaxPooling3D

features_shape = (len(template_vertices), dim)

# Define the input layers
image_input = Input(shape=image_shape, name='image_input')
template_vertices_input = Input(shape=point_cloud_shape, name='template_vertices_input')  # template vertex coordinates, 3 per vertex
template_pressure_input = Input(shape=p_shape, name='template_pressure_input')  # template pressure, 1 channel per vertex
template_velocity_input = Input(shape=v_shape, name='template_velocity_input')  # template velocity, 3 components per vertex
template_pv_input = Input(shape=pv_shape, name='template_pv_input')  # template pressure + velocity, 4 channels per vertex

true_vertices_input = Input(shape=point_cloud_shape, name='true_vertices_input')  # ground-truth vertex coordinates
true_surface_vertices_input = Input(shape=(None, 3), name='true_surface_vertices_input')  # ground-truth surface points; count not fixed
true_pressure_input = Input(shape=p_shape, name='true_pressure_input')  # ground-truth pressure, 1 channel per vertex
true_velocity_input = Input(shape=v_shape, name='true_velocity_input')  # ground-truth velocity, 3 components per vertex
true_pv_input = Input(shape=pv_shape, name='true_pv_input')  # ground-truth pressure + velocity, 4 channels per vertex
inlet_velocity_input = Input(shape = (len(template_vertices),1), name='inlet_velocity_input')  # inlet velocity repeated for every vertex

# Define the convolutional layers for the image input
# Five encoder levels; each is two residual blocks followed by 2x max-pooling.
X1 = conv_res_block(image_input, 16)
X1 = conv_res_block(X1, 16)
X1 = maxpool(pool_size=2)(X1)

X2 = conv_res_block(X1, 48)
X2 = conv_res_block(X2, 48)
X2 = maxpool(pool_size=2)(X2)

X3 = conv_res_block(X2, 96)
X3 = conv_res_block(X3, 96)
X3 = maxpool(pool_size=2)(X3)

X4 = conv_res_block(X3, 192)
X4 = conv_res_block(X4, 192)
X4 = maxpool(pool_size=2)(X4)

X5 = conv_res_block(X4, 384)
X5 = conv_res_block(X5, 384)
X5 = maxpool(pool_size=2)(X5)

# Initial graph features: template coordinates concatenated with the inlet velocity.
# template_input = Concatenate()([template_vertices_input, template_pv_input, inlet_velocity_input])
template_input = Concatenate()([template_vertices_input, inlet_velocity_input])
G4 = ChebConv(384, activation='LeakyReLU', kernel_initializer='he_normal')([template_input, L])

# --- Stage 3 (coarsest image features, X5 and X4, projected at the template vertices) ---
proj5 = Projection()([X5,template_vertices_input]) 
proj4 = Projection()([X4,template_vertices_input]) 
concat4 = Concatenate()([proj5,proj4, G4, template_vertices_input])

G3 = graph_res_block(concat4, 288)
G3 = graph_res_block(G3, 288)
G3 = graph_res_block(G3, 288)

# Vertex displacement, added to the template vertices.
# NOTE(review): the same name ('output3', and likewise 'CFD3', 'CFD2', 'CFD1'
# below) is passed to both the ChebConv and the Add layer. Keras functional
# models require unique layer names — confirm how ChebConv handles `name`.
output3 =  ChebConv(dim, activation='LeakyReLU', name = 'output3')([G3, L])    
output3 = Add(name = 'output3')([template_vertices_input,output3])

# Field correction, added to the template pressure/velocity field.
CFD3 =  ChebConv(4, activation='LeakyReLU', name = 'CFD3')([G3, L])    
CFD3 = Add(name = 'CFD3')([template_pv_input,CFD3])

# Channel 0 -> pressure, channels 1..3 -> velocity.
P3, V3 = Lambda(split_tensor)(CFD3)

G3 = ChebConv(144, activation='LeakyReLU')([G3,L])

# --- Stage 2 (image features X3 and X2 projected at the stage-3 vertices) ---
proj3 = Projection()([X3, output3]) 
proj2 = Projection()([X2, output3])     
concat3 = Concatenate()([proj3,proj2, G3, output3])

G2 = graph_res_block(concat3, 96)
G2 = graph_res_block(G2, 96)
G2 = graph_res_block(G2, 96)

output2 =  ChebConv(dim, activation='LeakyReLU')([G2, L])        
output2 = Add(name = 'output2')([output3,output2])

CFD2 =  ChebConv(4, activation='LeakyReLU', name = 'CFD2')([G2, L])    
CFD2 = Add(name = 'CFD2')([CFD3,CFD2])

P2, V2 = Lambda(split_tensor)(CFD2)

# --- Stage 1 (image features X2 and X1 projected at the stage-2 vertices) ---
G2 = ChebConv(64, activation='LeakyReLU')([G2,L])
# `proj2` is rebound here: X2 is now projected at output2 instead of output3.
proj2 = Projection()([X2,output2]) 
proj1 = Projection()([X1,output2]) 
concat2 = Concatenate()([proj2, proj1, G2, output2])

G1 = graph_res_block(concat2, 32)
G1 = graph_res_block(G1, 32)
G1 = graph_res_block(G1, 32)

# Final vertex positions; the Add layer is named 'output' (not 'output1').
output1 = ChebConv(dim, activation='LeakyReLU', name = 'output1')([G1, L])
output1 = Add(name = 'output')([output2,output1])

CFD1 =  ChebConv(4, activation='LeakyReLU', name = 'CFD1')([G1, L])    
CFD1 = Add(name = 'CFD1')([CFD2,CFD1])

P1, V1 = Lambda(split_tensor)(CFD1)


# The model outputs are the three field predictions, coarsest stage first. The
# vertex outputs (output3/2/1) are not model outputs; they are only referenced
# by the loss and metric tensors built afterwards.
model = Model(inputs=[image_input, template_vertices_input, template_pressure_input,template_velocity_input,template_pv_input, true_vertices_input, true_pressure_input, true_velocity_input, true_pv_input,true_surface_vertices_input, inlet_velocity_input], outputs = [CFD3, CFD2, CFD1])

# Symbolic loss / metric tensors. The suffix 1, 2, 3 refers to the stage whose
# prediction is evaluated: 1 = final (output1 / CFD1), 2 = output2 / CFD2,
# 3 = output3 / CFD3.

# Chamfer distance between predicted and ground-truth vertices (all vertices).
chamfer_loss1 = chamfer_distance(true_vertices_input, output1, surface = False)
chamfer_loss2 = chamfer_distance(true_vertices_input, output2, surface = False)
chamfer_loss3 = chamfer_distance(true_vertices_input, output3, surface = False)

# Chamfer distance between predicted surface vertices and ground-truth surface points.
surface_chamfer_loss1 = chamfer_distance(true_surface_vertices_input, output1, surface = True)
surface_chamfer_loss2 = chamfer_distance(true_surface_vertices_input, output2, surface = True)
surface_chamfer_loss3 = chamfer_distance(true_surface_vertices_input, output3, surface = True)

# Relative spread of edge lengths over template_edges.
dev_edge_loss1 = dev_edge_length(output1, surface = False)
dev_edge_loss2 = dev_edge_length(output2, surface = False) 
dev_edge_loss3 = dev_edge_length(output3, surface = False) 

# Relative spread of edge lengths over the surface edges.
dev_surface_loss1 = dev_edge_length(output1, surface = True)
dev_surface_loss2 = dev_edge_length(output2, surface = True) 
dev_surface_loss3 = dev_edge_length(output3, surface = True) 

# Mean element aspect ratio over template_edges.
aspect_ratio_loss1 = aspect_ratio(output1, surface = False)
aspect_ratio_loss2 = aspect_ratio(output2,  surface = False) 
aspect_ratio_loss3 = aspect_ratio(output3,  surface = False) 

# Co-planarity penalty: cap_loss_<cap id>_<stage>.
cap_loss_2_1 = cap_loss(output1, cap_indices_2, cap_faces_2)
cap_loss_2_2 = cap_loss(output2, cap_indices_2, cap_faces_2)
cap_loss_2_3 = cap_loss(output3, cap_indices_2, cap_faces_2)

cap_loss_3_1 = cap_loss(output1, cap_indices_3, cap_faces_3)
cap_loss_3_2 = cap_loss(output2, cap_indices_3, cap_faces_3)
cap_loss_3_3 = cap_loss(output3, cap_indices_3, cap_faces_3)

cap_loss_4_1 = cap_loss(output1, cap_indices_4, cap_faces_4)
cap_loss_4_2 = cap_loss(output2, cap_indices_4, cap_faces_4)
cap_loss_4_3 = cap_loss(output3, cap_indices_4, cap_faces_4)

# Per-stage cap loss: sum over the three caps.
cap_loss_1 = cap_loss_2_1 + cap_loss_3_1 + cap_loss_4_1
cap_loss_2 = cap_loss_2_2 + cap_loss_3_2 + cap_loss_4_2
cap_loss_3 = cap_loss_2_3 + cap_loss_3_3 + cap_loss_4_3

# Error on the normalised 4-channel pressure/velocity field.
cfd_loss1 = mae_loss(true_pv_input, CFD1)
cfd_loss2 = mae_loss(true_pv_input, CFD2)
cfd_loss3 = mae_loss(true_pv_input, CFD3)

# Range-normalised MAE (%) on de-normalised velocity magnitude ...
velocity_nmae = nmae(true_velocity_input, V1, mean = velocity_mean, std = velocity_std, pressure = False, loss = True)
velocity_nmae2 = nmae(true_velocity_input, V2, mean = velocity_mean, std = velocity_std, pressure = False, loss = True)
velocity_nmae3 = nmae(true_velocity_input, V3, mean = velocity_mean, std = velocity_std, pressure = False, loss = True)

# ... and on de-normalised, cubed pressure.
pressure_nmae = nmae(true_pressure_input, P1, mean = pressure_mean, std = pressure_std, pressure = True, loss = True)
pressure_nmae2 = nmae(true_pressure_input, P2, mean = pressure_mean, std = pressure_std, pressure = True, loss = True)
pressure_nmae3 = nmae(true_pressure_input, P3, mean = pressure_mean, std = pressure_std, pressure = True, loss = True)

# Combined score: weighted velocity and pressure NMAE plus a penalty that is
# active only while the final chamfer distance exceeds 1.9.
# NOTE(review): the constants 1.2, 1.1, 50 and 1.9 are not explained anywhere
# in the visible code — document where they come from.
nmae_total = velocity_nmae*1.2 + pressure_nmae/1.1 + 50 * (tf.maximum(chamfer_loss1 - 1.9, 0))

# NOTE(review): this "surface distance" is evaluated on all vertices of the
# final prediction and of the ground truth, not on surface points — confirm.
assd_loss = calculate_assd_tf(true_vertices_input, output1)

# Register the final-stage quantities as epoch-averaged metrics. The table is
# walked top to bottom, so the registration order is unchanged.
_tracked_metrics = [
    ('chamfer_loss', chamfer_loss1),
    ('surface_chamfer_loss', surface_chamfer_loss1),
    ('dev_edge_loss', dev_edge_loss1),
    ('dev_surface_loss', dev_surface_loss1),
    ('aspect_ratio_loss', aspect_ratio_loss1),
    ('cap_loss', cap_loss_1),
    ('cfd_loss', cfd_loss1),
    ('velocity_nmae', velocity_nmae),
    ('pressure_nmae', pressure_nmae),
    ('nmae', nmae_total),
    ('assd_loss', assd_loss),
]
for _metric_name, _metric_tensor in _tracked_metrics:
    model.add_metric(_metric_tensor, name=_metric_name, aggregation='mean')


# Loss weights, in the same order as the loss terms of each stage below.
lambdas = [lambda_chamfer, lambda_chamfer_surface, lambda_dev_edge, lambda_dev_surface, lambda_aspect_ratio, lambda_cap, lambda_cfd]


def _weighted_stage_loss(stage_terms):
    """Sum the loss terms of one stage, each multiplied by its entry in `lambdas`."""
    return sum([term * weight for weight, term in zip(lambdas, stage_terms)])


# One weighted sum per stage (1 = final prediction, 3 = first prediction).
mesh_losses1 = _weighted_stage_loss([chamfer_loss1, surface_chamfer_loss1, dev_edge_loss1, dev_surface_loss1, aspect_ratio_loss1, cap_loss_1, cfd_loss1])
mesh_losses2 = _weighted_stage_loss([chamfer_loss2, surface_chamfer_loss2, dev_edge_loss2, dev_surface_loss2, aspect_ratio_loss2, cap_loss_2, cfd_loss2])
mesh_losses3 = _weighted_stage_loss([chamfer_loss3, surface_chamfer_loss3, dev_edge_loss3, dev_surface_loss3, aspect_ratio_loss3, cap_loss_3, cfd_loss3])

# The training objective is the unweighted sum over the three stages.
total_losses = mesh_losses1 + mesh_losses2 + mesh_losses3
model.add_loss(total_losses)

# Optionally resume from previously saved weights.
if continue_training:
    model.load_weights(f'models/{model_name}.h5')


In [26]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 image_input (InputLayer)       [(None, 128, 128, 1  0           []                               
                                28, 1)]                                                           
                                                                                                  
 conv3d (Conv3D)                (None, 128, 128, 12  448         ['image_input[0][0]']            
                                8, 16)                                                            
                                                                                                  
 max_pooling3d (MaxPooling3D)   (None, 64, 64, 64,   0           ['conv3d[0][0]']                 
                                16)                                                           

In [27]:
# loss=None: the training objective is the tensor attached with
# model.add_loss(total_losses), so no per-output loss is given here.
# Adam is used with its default hyperparameters.
model.compile(loss=None, optimizer = tf.keras.optimizers.Adam())    

In [28]:
def _inverse_scale(scaler, column):
    """Undo `scaler` on one column by passing it to `inverse_transform` as a single row."""
    return scaler.inverse_transform(column[np.newaxis, ...])[0]


def _decode_cfd_fields(fields):
    """Map a normalised per-vertex field array back to unscaled values.

    Column 0 is pressure and columns 1-3 are the x/y/z velocity components.

    Returns:
        Tuple ``(pressure, x_velocity, y_velocity, z_velocity, xyz_velocity, speed)``
        where ``xyz_velocity`` stacks the three components on the last axis
        and ``speed`` is its Euclidean norm along axis 1.
    """
    # The cube presumably undoes a cube-root transform applied before scaling
    # -- TODO confirm against the data-preparation code.
    pressure = _inverse_scale(pressure_scaler, fields[:, 0]) ** 3
    x_velocity = _inverse_scale(velocity_scaler, fields[:, 1])
    y_velocity = _inverse_scale(velocity_scaler, fields[:, 2])
    z_velocity = _inverse_scale(velocity_scaler, fields[:, 3])
    xyz_velocity = np.stack([x_velocity, y_velocity, z_velocity], axis = -1)
    speed = np.linalg.norm(xyz_velocity, axis = 1)
    return pressure, x_velocity, y_velocity, z_velocity, xyz_velocity, speed


def _nmae_percent(true_values, pred_values):
    """Mean absolute error expressed as a percentage of the range of `true_values`."""
    value_range = np.max(true_values) - np.min(true_values)
    return np.mean(abs(true_values - pred_values) * 100 / value_range)


def _rmse(true_values, pred_values):
    """Root-mean-square error between two arrays."""
    return np.sqrt(np.mean((true_values - pred_values) ** 2))


def _log_summary_stats(prefix, values, with_mean_std = True, ndigits = None):
    """Write summary statistics of `values` to the global `run` under `prefix`.

    Always writes ``median``, ``iqr0.25``, ``iqr0.75`` and ``iqr``; also writes
    ``mean`` and ``std`` when `with_mean_std` is true. When `ndigits` is given,
    mean/std/median are rounded to that many digits; quantiles and IQR never are.
    """
    def _maybe_round(value):
        return value if ndigits is None else round(value, ndigits)

    if with_mean_std:
        run[f'{prefix}/mean'] = _maybe_round(np.mean(values))
        run[f'{prefix}/std'] = _maybe_round(np.std(values))
    run[f'{prefix}/median'] = _maybe_round(np.median(values))
    run[f'{prefix}/iqr0.25'] = np.quantile(values, 0.25)
    run[f'{prefix}/iqr0.75'] = np.quantile(values, 0.75)
    run[f'{prefix}/iqr'] = stats.iqr(values)


def evaluate(plots = ['cfd'], cohorts = ['test'], quick = True):
    """Evaluate the checkpointed model on every test patient and log the results.

    For each patient in `test_patients` the function predicts the CFD fields
    and the deformed vertices, writes the prediction to
    ``pred_cfd_vtu/{model_name}/cfd_vtu``, computes field errors (NMAE, RMSE)
    and shape metrics (Dice, IoU, ASSD, Hausdorff, Laplacian smoothness),
    optionally uploads figures, and finally logs cohort-level summary
    statistics to `run` and writes ``results/segmentation_{model_name}.csv``.

    Args:
        plots: Which figures to upload per patient. Recognised entries are
            'cfd', 'cloud', 'mask' and 'all' ('all' enables 'cloud' and 'mask').
        cohorts: Accepted but currently unused; the cohort is always 'test'.
        quick: Accepted but currently unused.

    Returns:
        None. Results are reported through `run` and files on disk.
    """
    run['script'].upload(f'{nb_fname}.ipynb')
    run['model'].upload(f'models/{model_name}.h5')

    best_model = tf.keras.models.load_model(f'models/{model_name}.h5',
                                            compile = False,
                                            custom_objects = {'ChebConv':ChebConv,
                                                              'Projection':Projection,
                                                              })
    # Sub-model ending at the layer named "output"; its prediction is used as
    # the predicted vertices. Built once here instead of once per patient, and
    # get_layer raises immediately if no such layer exists.
    best_model_trunc = tf.keras.Model(inputs=best_model.inputs,
                                      outputs=best_model.get_layer("output").output)

    cohort = 'test'
    records = []
    dices, ious, assds, hds = [], [], [], []
    overall_smoothness = []
    field_names = ('pressure', 'velocity', 'xvelocity', 'yvelocity', 'zvelocity')
    nmaes = {field: [] for field in field_names}
    rmses = {field: [] for field in field_names}

    for patient in test_patients:
        # Per-patient value used as the `v=` suffix of the CFD file names.
        v = test_set[patient]

        test_gen   = CustomDataGen([patient], 'test', v).get_gen
        test_ds = tf.data.Dataset.from_generator(test_gen, output_types=output_types)
        test_ds = test_ds.batch(1).prefetch(-1)

        y_pred = best_model.predict(test_ds)
        if isinstance(y_pred, list):
            # Multi-output model: the last output holds the predicted fields.
            y_pred = y_pred[-1]
        y_pred = y_pred[0]
        X_true, y_true = next(test_gen())
        true_vertices = X_true['true_vertices_input']

        (true_pressure, true_xvelocity, true_yvelocity, true_zvelocity,
         _, true_velocity) = _decode_cfd_fields(y_true)
        (pred_pressure, pred_xvelocity, pred_yvelocity, pred_zvelocity,
         pred_xyz_velocity, pred_velocity) = _decode_cfd_fields(y_pred)

        pred_vertices = best_model_trunc.predict(test_ds)[0]

        # Write the prediction onto the template mesh and save it as VTU.
        data = meshio.read(f'{data_path}/REGPA-56/final_template_mesh_REGPA-56.vtu')
        data.points = pred_vertices
        data.point_data[pressure_name] = pred_pressure
        data.point_data[velocity_name] = pred_velocity
        data.point_data[x_velocity_name] = pred_xvelocity
        data.point_data[y_velocity_name] = pred_yvelocity
        data.point_data[z_velocity_name] = pred_zvelocity
        data.point_data['velocity_xyz'] = pred_xyz_velocity
        Path(f'pred_cfd_vtu/{model_name}/cfd_vtu').mkdir(parents=True, exist_ok=True)
        data.write(f'pred_cfd_vtu/{model_name}/cfd_vtu/{patient}_v={v}.vtu')

        # Shared colour limits for the true/pred panels: the largest absolute
        # value over both arrays (taking abs before max on both sides).
        max_pressure = np.max([np.max(np.abs(true_pressure)), np.max(np.abs(pred_pressure))])
        max_velocity = np.max([np.max(np.abs(true_velocity)), np.max(np.abs(pred_velocity))])

        field_pairs = {
            'pressure': (true_pressure, pred_pressure),
            'velocity': (true_velocity, pred_velocity),
            'xvelocity': (true_xvelocity, pred_xvelocity),
            'yvelocity': (true_yvelocity, pred_yvelocity),
            'zvelocity': (true_zvelocity, pred_zvelocity),
        }
        for field, (truth, prediction) in field_pairs.items():
            nmaes[field].append(_nmae_percent(truth, prediction))
            rmses[field].append(_rmse(truth, prediction))
        nmae_pressure, nmae_velocity = nmaes['pressure'][-1], nmaes['velocity'][-1]
        rmse_pressure, rmse_velocity = rmses['pressure'][-1], rmses['velocity'][-1]

        # Shape metrics: compare triangulated surfaces of the ground-truth CFD
        # mesh and of the template deformed to the predicted vertices.
        mesh = template_mesh.copy()
        mesh.points = pred_vertices
        true_cfd_mesh = pv.read(f'{data_path}/REGPA-56/cfd_vtu/{patient}_v={v}.vtu')
        true_surface = pv_to_trimesh(true_cfd_mesh.extract_surface().triangulate())
        true_mask = point_cloud_to_mask(true_surface.vertices, true_surface.edges, (128,128,128), num_points = 20)

        pred_surface = pv_to_trimesh(mesh.extract_surface().triangulate())
        pred_mask = point_cloud_to_mask(pred_surface.vertices, pred_surface.edges, (128,128,128), num_points = 20)

        dice_val = single_dice(true_mask, pred_mask)
        iou_val = single_iou(true_mask, pred_mask)
        assd_val = calculate_assd(true_surface.vertices, pred_surface.vertices)
        hd_val = hausdorff_distance(true_surface.vertices, pred_surface.vertices)

        dices.append(dice_val)
        ious.append(iou_val)
        assds.append(assd_val)
        hds.append(hd_val)
        smoothness = round(laplacian_coef(pred_vertices,surface_vertex_indices, vertex_neighbors), 2)
        overall_smoothness.append(smoothness)

        records.append({'patient':patient,
                        'dice':dice_val,
                        'iou':iou_val,
                        'assd':assd_val,
                        'hd':hd_val,
                        'nmae_pressure':nmae_pressure,
                        'nmae_velocity':nmae_velocity,
                        'rmse_pressure':rmse_pressure,
                        'rmse_velocity':rmse_velocity})

        # NOTE(review): unlike 'cloud' and 'mask', this branch is not enabled
        # by 'all' -- confirm whether that is intended.
        if 'cfd' in plots:
            # From here on (including the 'cloud' plot below) the ground-truth
            # points are those of the CFD mesh, not the generator input.
            true_vertices = true_cfd_mesh.points

            mae_pressure = np.mean(abs(true_pressure - pred_pressure))
            mae_velocity = np.mean(abs(true_velocity - pred_velocity))
            mse_pressure = np.mean((true_pressure - pred_pressure)**2)
            mse_velocity = np.mean((true_velocity - pred_velocity)**2)

            fig = plt.figure(figsize=(10, 10))
            fig.suptitle(
                f'Dice = {round(dice_val, 2)}\n'
                f'RMSE (pressure) = {round(rmse_pressure, 2)}, RMSE (velocity) = {round(rmse_velocity, 4)}\n'
                f'MSE (pressure) = {round(mse_pressure, 2)}, MSE (velocity) = {round(mse_velocity, 4)}\n'
                f'MAE (pressure) = {round(mae_pressure, 2)}, MAE (velocity) = {round(mae_velocity, 4)}\n'
                f'NMAE (pressure) = {round(nmae_pressure, 2)}, NMAE (velocity) = {round(nmae_velocity, 4)}'
            )

            panels = [
                (221, 'True Pressure', true_vertices, true_pressure, -max_pressure, max_pressure),
                (222, 'Pred Pressure', pred_vertices, pred_pressure, -max_pressure, max_pressure),
                (223, 'True Velocity', true_vertices, true_velocity, 0, max_velocity),
                (224, 'Pred Velocity', pred_vertices, pred_velocity, 0, max_velocity),
            ]
            lim1, lim2 = 20, 100
            drawn = []
            for position, title, points, values, vmin, vmax in panels:
                ax = fig.add_subplot(position, projection='3d')
                ax.set_title(title)
                ax.view_init(elev=-20, azim=260, roll=-20, vertical_axis='y')
                ax.set_xlim(lim1, lim2)
                ax.set_ylim(lim1, lim2)
                ax.set_zlim(lim1, lim2)
                scatter = ax.scatter(points[:,0], points[:,1], points[:,2], s=1, c=values,
                                     cmap ='jet', vmin=vmin, vmax=vmax)
                drawn.append((scatter, ax))
            # Colour bars are added only after all four panels exist.
            for scatter, ax in drawn:
                fig.colorbar(scatter, ax=ax, shrink=0.6, aspect=10, pad = 0.1)
            fig.tight_layout()
            run[f'results/{cohort}/{patient}_cfd/{v}'].upload(fig)
            plt.close(fig)

        if 'cloud' in plots or 'all' in plots :
            clouds = {
                'truth': (true_vertices, 'green'),
                'prediction': (pred_vertices, 'red'),
                'template': (template_vertices, 'orange'),
            }
            layout = [
                (221, ('truth', 'prediction', 'template')),
                (222, ('template',)),
                (223, ('prediction',)),
                (224, ('truth',)),
            ]
            fig = plt.figure(figsize=(10, 10))
            for position, labels in layout:
                ax = fig.add_subplot(position, projection='3d')
                for label in labels:
                    points, colour = clouds[label]
                    ax.scatter(points[:,0], points[:,1], points[:,2], s=5, color = colour, label = label)
                ax.view_init(elev=-20, azim=260, roll=-20, vertical_axis='y')
                ax.legend()

            fig.tight_layout()
            st = fig.suptitle(f'Laplace = {smoothness:.2f}, Dice = {dice_val:.2f}', fontsize = 18)
            st.set_y(1.05)
            run[f'results/{cohort}/{patient}_cloud/{v}'].upload(fig)
            plt.close(fig)

        if 'mask' in plots or 'all' in plots:
            # The image volume is only needed for this plot, so load it here.
            image = np.load(f"{data_path}/images/{patient}.npy")
            fig, axs = plt.subplots(1,2, figsize = (5,2.5))
            frames = []
            # Indices along the last axis whose slice has a positive sum.
            mask_frames = list(np.where(np.sum(np.sum(image,0),0) > 0)[0])
            for i in mask_frames:
                p1 = axs[0].imshow(image[...,i],cmap = 'gray')
                p2 = axs[1].imshow(image[...,i],cmap = 'gray')
                p3 = axs[0].imshow(pred_mask[...,i],alpha = pred_mask[...,i] * 0.8, cmap = 'bwr')
                p4 = axs[1].imshow(true_mask[...,i],alpha = true_mask[...,i] * 0.8, cmap = 'PiYG')
                text = plt.text(-10,-5,i)
                frames.append([p1,p2,p3,p4, text])
            fig.tight_layout()
            ani = animation.ArtistAnimation(fig, frames)
            ani.save('video_cfd.gif', fps=30)
            run[f'results/{cohort}/{patient}_mask/{v}'].upload('video_cfd.gif')
            plt.close(fig)

    # Cohort-level summaries.
    run[f'smoothness/{cohort}/mean'] = round(np.mean(overall_smoothness), 2)

    for field in ('pressure', 'velocity'):
        _log_summary_stats(f'{field}/{cohort}', nmaes[field])
        _log_summary_stats(f'{field}/{cohort}/rmse', rmses[field])

    # Per-component velocities are summarised without mean/std.
    for field in ('xvelocity', 'yvelocity', 'zvelocity'):
        _log_summary_stats(f'{field}/{cohort}', nmaes[field], with_mean_std = False)
        _log_summary_stats(f'{field}/{cohort}/rmse', rmses[field], with_mean_std = False)

    for metric, values in (('dice', dices), ('iou', ious), ('assd', assds), ('hd', hds)):
        _log_summary_stats(f'{metric}/{cohort}', values, ndigits = 2)

    results_df = pd.DataFrame.from_records(records)
    # Create the output directory first so the CSV write cannot fail on a fresh checkout.
    Path('results').mkdir(parents=True, exist_ok=True)
    results_df.to_csv(f'results/segmentation_{model_name}.csv', index = False)

class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that runs `evaluate()` periodically and once when training ends.

    Args:
        data_path: Accepted so existing call sites keep working; it is not
            stored or used by the callback.
        counter: Initial value of the internal epoch counter.
        save_every: `evaluate()` is called whenever the counter reaches a
            multiple of this value.
    """
    def __init__(self,data_path,  counter = 0,save_every = 10):
        super().__init__()
        self.save_every = save_every
        self.counter = counter

    def on_epoch_end(self, epoch, logs=None):
        """Advance the internal counter and evaluate on every `save_every`-th call."""
        # The callback keeps its own counter instead of using `epoch`, so a
        # non-zero starting `counter` shifts when evaluation happens.
        self.counter += 1
        if self.counter % self.save_every == 0:
            evaluate()

    def on_train_end(self, logs=None):
        """Run a final evaluation when training finishes.

        Keras invokes this hook with the logs dict as its only argument, so the
        signature matches the base class rather than taking an `epoch`.
        """
        evaluate(quick = False)

In [None]:
# Import from tensorflow.keras so the callbacks come from the same Keras
# implementation as the tf.keras model and the custom callback in this notebook.
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Stop when the validation NMAE has not improved for 50 epochs. This monitors
# the same quantity as the checkpoint below, so stopping and the "best" model
# are judged on validation data rather than on the training metric.
es = EarlyStopping(monitor='val_nmae',
                   mode='min',
                   verbose = 1,
                   patience = 50)

# Keep only the weights with the lowest validation NMAE; evaluate() reloads
# the model from this file.
mc = ModelCheckpoint(f'models/{model_name}.h5',
                     save_best_only= True,
                     monitor='val_nmae',
                     mode='min')

# Run evaluate() every 25 epochs and once more when training ends. It is
# listed after `mc` in the callbacks so the checkpoint for the current epoch
# is written first.
eval_every_epoch = CustomCallback(counter = 0, save_every = 25, data_path = data_path)

model.fit(train_ds,
          validation_data = val_ds,
          epochs=500,
          callbacks=[es, mc,eval_every_epoch])

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500











Epoch 26/500
Epoch 27/500
Epoch 28/500
