In [1]:
import tensorflow as tf
import numpy as np
import sys
import random
import scipy.io
import pickle
import os
from tensorflow.keras import initializers
from time import time

from Data_Preprocessing import *
from PhysicsModel_Parameters import *
from PINN_cardiac_DataFusion_Solver import *

In [2]:
# Restrict TensorFlow to the first GPU. This only has an effect if it runs before
# TensorFlow initializes its devices. NOTE(review): if any of the star-imported
# project modules create tensors at import time, move this above the imports.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Floating-point precision used for every Keras layer/tensor created below.
DTYPE = 'float32'
tf.keras.backend.set_floatx(DTYPE)

# Seed Python's `random`, NumPy and TensorFlow in one call so the stochastic parts
# of this notebook are reproducible across Restart & Run All. This presumably
# covers collocation sampling, measurement noise and weight initialization —
# assuming the project modules draw from these global generators (TODO confirm).
SEED = 1234
tf.keras.utils.set_random_seed(SEED)

In [3]:
# Experiments under different noise levels (values passed to sensor_measurements).
sigma_noise_set = [0.05]
num_obs_channels = 20
# Passed to both collocation_points() and PINN_Solver; presumably the number of
# collocation points for the physics loss — TODO confirm in the project modules.
tf_shape = 400

# Import physics model parameters
a, c, e0, D, mu1, mu2, Delta, Z_BH = Physics_Parameters(DTYPE)

# Sample collocation points for the physics-based loss
Preprocessing = data_preprocessing(DTYPE)
X_f, _ = Preprocessing.collocation_points(tf_shape)
X_whole, lb, ub = Preprocessing.whole_coordinates()

# Learning rates for the three training stages: a staircase exponential decay for
# the initial stage, then two constant rates for fine-tuning.
lr1 = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=2e-3,
    decay_steps=500,
    decay_rate=0.95,
    staircase=True)
lr2 = 0.0002
lr3 = 0.0001

# Number of iterations for each training stage.
N_ITER_INITIAL = 20000
N_ITER_FINE_1 = 20000
N_ITER_FINE_2 = 10000

root_dir = os.getcwd()
for sigma_noise in sigma_noise_set:

    # Collect physical measurements for the data-driven loss
    hsp_m, bsp_m, obs_channels, hsp = Preprocessing.sensor_measurements(sigma_noise, num_obs_channels)

    # A fresh network and solver per noise level.
    model = PINN_Net(lb, ub, DTYPE)
    model.build(input_shape=(None, 4))
    solver = PINN_Solver(model, X_f, Delta, tf_shape, a, c, e0, D, mu1, mu2, Z_BH)

    # Build ALL optimizers per experiment. Adam keeps per-variable moment
    # estimates and an `iterations` counter (used for bias correction and for
    # learning-rate schedules). Reusing an optimizer across experiments would
    # leak that state from the previous model into the next one.
    optimizer1 = tf.keras.optimizers.legacy.Adam(learning_rate=lr1)
    optimizer2 = tf.keras.optimizers.legacy.Adam(learning_rate=lr2)
    optimizer3 = tf.keras.optimizers.legacy.Adam(learning_rate=lr3)

    # Initial training
    solver.train_loop(X_whole, obs_channels, hsp_m, bsp_m, hsp, optimizer1, N_iter=N_ITER_INITIAL)
    # Fine training stage 1
    solver.train_loop(X_whole, obs_channels, hsp_m, bsp_m, hsp, optimizer2, N_iter=N_ITER_FINE_1)
    # Fine training stage 2
    solver.train_loop(X_whole, obs_channels, hsp_m, bsp_m, hsp, optimizer3, N_iter=N_ITER_FINE_2)

# NOTE: after the loop, `model` and `hsp` refer to the LAST entry of
# sigma_noise_set; the evaluation cell below uses them.

    



It 00000: loss=3.1754e+00, l_phys= 2.256e-05, l_d= 3.186e-01, l_hb= 7.773e-01, t= 1.92s
It 00100: loss=2.3956e+00, l_phys= 4.757e-02, l_d= 2.049e-01, l_hb= 3.310e-01, t= 3.82s
It 00200: loss=1.5453e+00, l_phys= 1.042e-01, l_d= 4.809e-02, l_hb= 2.333e-02, t= 3.82s
It 00300: loss=9.1374e-01, l_phys= 5.331e-02, l_d= 2.017e-02, l_hb= 1.714e-03, t= 3.81s
It 00400: loss=6.1893e-01, l_phys= 2.305e-02, l_d= 1.502e-02, l_hb= 1.239e-03, t= 3.81s
It 00500: loss=4.0020e-01, l_phys= 3.102e-03, l_d= 1.148e-02, l_hb= 1.025e-03, t= 3.89s
It 00600: loss=3.2558e-01, l_phys= 1.267e-03, l_d= 9.502e-03, l_hb= 1.015e-03, t= 3.84s
It 00700: loss=2.9554e-01, l_phys= 1.027e-03, l_d= 7.199e-03, l_hb= 1.096e-03, t= 3.82s
It 00800: loss=2.6104e-01, l_phys= 1.090e-03, l_d= 4.004e-03, l_hb= 1.269e-03, t= 3.81s
It 00900: loss=2.2963e-01, l_phys= 1.084e-03, l_d= 2.229e-03, l_hb= 1.281e-03, t= 3.82s
It 01000: loss=2.1972e-01, l_phys= 1.063e-03, l_d= 1.813e-03, l_hb= 1.275e-03, t= 3.81s
It 01100: loss=2.1321e-01, l_phy

It 09400: loss=1.4035e-01, l_phys= 3.717e-04, l_d= 5.403e-04, l_hb= 7.883e-04, t= 3.83s
It 09500: loss=1.4034e-01, l_phys= 3.709e-04, l_d= 5.404e-04, l_hb= 7.890e-04, t= 3.83s
It 09600: loss=1.4016e-01, l_phys= 3.691e-04, l_d= 5.404e-04, l_hb= 7.864e-04, t= 3.83s
It 09700: loss=1.4015e-01, l_phys= 3.693e-04, l_d= 5.397e-04, l_hb= 7.866e-04, t= 3.83s
It 09800: loss=1.4070e-01, l_phys= 3.781e-04, l_d= 5.427e-04, l_hb= 7.858e-04, t= 3.83s
It 09900: loss=1.4059e-01, l_phys= 3.755e-04, l_d= 5.440e-04, l_hb= 7.847e-04, t= 3.83s
It 10000: loss=1.4000e-01, l_phys= 3.670e-04, l_d= 5.404e-04, l_hb= 7.846e-04, t= 3.83s
It 10100: loss=1.3979e-01, l_phys= 3.645e-04, l_d= 5.395e-04, l_hb= 7.839e-04, t= 3.84s
It 10200: loss=1.3974e-01, l_phys= 3.639e-04, l_d= 5.396e-04, l_hb= 7.832e-04, t= 3.84s
It 10300: loss=1.3969e-01, l_phys= 3.635e-04, l_d= 5.389e-04, l_hb= 7.833e-04, t= 3.84s
It 10400: loss=1.3963e-01, l_phys= 3.629e-04, l_d= 5.388e-04, l_hb= 7.826e-04, t= 3.87s
It 10500: loss=1.3974e-01, l_phy

It 18800: loss=1.3622e-01, l_phys= 3.236e-04, l_d= 5.269e-04, l_hb= 7.599e-04, t= 3.92s
It 18900: loss=1.3619e-01, l_phys= 3.231e-04, l_d= 5.269e-04, l_hb= 7.598e-04, t= 3.86s
It 19000: loss=1.3618e-01, l_phys= 3.231e-04, l_d= 5.267e-04, l_hb= 7.600e-04, t= 3.94s
It 19100: loss=1.3612e-01, l_phys= 3.221e-04, l_d= 5.267e-04, l_hb= 7.597e-04, t= 3.91s
It 19200: loss=1.3608e-01, l_phys= 3.216e-04, l_d= 5.265e-04, l_hb= 7.596e-04, t= 3.90s
It 19300: loss=1.3604e-01, l_phys= 3.211e-04, l_d= 5.264e-04, l_hb= 7.595e-04, t= 3.89s
It 19400: loss=1.3614e-01, l_phys= 3.223e-04, l_d= 5.260e-04, l_hb= 7.608e-04, t= 3.88s
It 19500: loss=1.3602e-01, l_phys= 3.206e-04, l_d= 5.259e-04, l_hb= 7.602e-04, t= 3.88s
It 19600: loss=1.3593e-01, l_phys= 3.195e-04, l_d= 5.261e-04, l_hb= 7.594e-04, t= 3.89s
It 19700: loss=1.3601e-01, l_phys= 3.205e-04, l_d= 5.270e-04, l_hb= 7.587e-04, t= 3.88s
It 19800: loss=1.3604e-01, l_phys= 3.203e-04, l_d= 5.265e-04, l_hb= 7.605e-04, t= 3.89s
It 19900: loss=1.3599e-01, l_phy

It 28200: loss=1.3351e-01, l_phys= 2.897e-04, l_d= 5.115e-04, l_hb= 7.565e-04, t= 3.92s
It 28300: loss=1.3349e-01, l_phys= 2.895e-04, l_d= 5.113e-04, l_hb= 7.566e-04, t= 3.88s
It 28400: loss=1.3347e-01, l_phys= 2.892e-04, l_d= 5.113e-04, l_hb= 7.564e-04, t= 3.91s
It 28500: loss=1.3345e-01, l_phys= 2.889e-04, l_d= 5.113e-04, l_hb= 7.564e-04, t= 3.91s
It 28600: loss=1.3343e-01, l_phys= 2.887e-04, l_d= 5.110e-04, l_hb= 7.565e-04, t= 3.96s
It 28700: loss=1.3342e-01, l_phys= 2.886e-04, l_d= 5.109e-04, l_hb= 7.567e-04, t= 3.92s
It 28800: loss=1.3339e-01, l_phys= 2.882e-04, l_d= 5.109e-04, l_hb= 7.564e-04, t= 3.91s
It 28900: loss=1.3337e-01, l_phys= 2.879e-04, l_d= 5.109e-04, l_hb= 7.563e-04, t= 3.94s
It 29000: loss=1.3335e-01, l_phys= 2.876e-04, l_d= 5.108e-04, l_hb= 7.563e-04, t= 3.92s
It 29100: loss=1.3333e-01, l_phys= 2.874e-04, l_d= 5.107e-04, l_hb= 7.562e-04, t= 3.92s
It 29200: loss=1.3332e-01, l_phys= 2.873e-04, l_d= 5.107e-04, l_hb= 7.562e-04, t= 3.91s
It 29300: loss=1.3329e-01, l_phy

It 37600: loss=1.3204e-01, l_phys= 2.697e-04, l_d= 5.067e-04, l_hb= 7.547e-04, t= 3.84s
It 37700: loss=1.3209e-01, l_phys= 2.703e-04, l_d= 5.066e-04, l_hb= 7.551e-04, t= 3.92s
It 37800: loss=1.3202e-01, l_phys= 2.693e-04, l_d= 5.065e-04, l_hb= 7.547e-04, t= 3.96s
It 37900: loss=1.3209e-01, l_phys= 2.703e-04, l_d= 5.066e-04, l_hb= 7.550e-04, t= 3.93s
It 38000: loss=1.3201e-01, l_phys= 2.691e-04, l_d= 5.063e-04, l_hb= 7.550e-04, t= 3.95s
It 38100: loss=1.3198e-01, l_phys= 2.688e-04, l_d= 5.064e-04, l_hb= 7.547e-04, t= 3.89s
It 38200: loss=1.3197e-01, l_phys= 2.687e-04, l_d= 5.065e-04, l_hb= 7.546e-04, t= 3.91s
It 38300: loss=1.3198e-01, l_phys= 2.687e-04, l_d= 5.062e-04, l_hb= 7.550e-04, t= 3.97s
It 38400: loss=1.3195e-01, l_phys= 2.684e-04, l_d= 5.063e-04, l_hb= 7.547e-04, t= 3.88s
It 38500: loss=1.3194e-01, l_phys= 2.682e-04, l_d= 5.063e-04, l_hb= 7.547e-04, t= 3.85s
It 38600: loss=1.3193e-01, l_phys= 2.681e-04, l_d= 5.063e-04, l_hb= 7.546e-04, t= 3.83s
It 38700: loss=1.3191e-01, l_phy

It 47000: loss=1.3126e-01, l_phys= 2.588e-04, l_d= 5.047e-04, l_hb= 7.538e-04, t= 3.82s
It 47100: loss=1.3126e-01, l_phys= 2.587e-04, l_d= 5.047e-04, l_hb= 7.539e-04, t= 3.82s
It 47200: loss=1.3125e-01, l_phys= 2.586e-04, l_d= 5.046e-04, l_hb= 7.539e-04, t= 3.82s
It 47300: loss=1.3124e-01, l_phys= 2.585e-04, l_d= 5.046e-04, l_hb= 7.538e-04, t= 3.82s
It 47400: loss=1.3125e-01, l_phys= 2.586e-04, l_d= 5.046e-04, l_hb= 7.540e-04, t= 3.85s
It 47500: loss=1.3123e-01, l_phys= 2.583e-04, l_d= 5.047e-04, l_hb= 7.537e-04, t= 3.86s
It 47600: loss=1.3122e-01, l_phys= 2.582e-04, l_d= 5.046e-04, l_hb= 7.538e-04, t= 3.93s
It 47700: loss=1.3121e-01, l_phys= 2.581e-04, l_d= 5.046e-04, l_hb= 7.538e-04, t= 3.82s
It 47800: loss=1.3120e-01, l_phys= 2.580e-04, l_d= 5.046e-04, l_hb= 7.537e-04, t= 3.84s
It 47900: loss=1.3119e-01, l_phys= 2.579e-04, l_d= 5.046e-04, l_hb= 7.537e-04, t= 3.83s
It 48000: loss=1.3119e-01, l_phys= 2.578e-04, l_d= 5.045e-04, l_hb= 7.537e-04, t= 3.90s
It 48100: loss=1.3118e-01, l_phy

In [4]:
# Model evaluation: relative Frobenius-norm error between the network's first
# output column and the reference `hsp`. This evaluates the model trained in the
# last iteration of the noise-level loop above.
preds = model(X_whole)

# Do the arithmetic in NumPy so TF-tensor/NumPy dtype mixing (e.g. float32
# predictions vs. a float64 reference) cannot raise inside a TF op.
hsp_true = np.asarray(hsp)

# Row-major reshape to (hsp.shape[1], -1) then transpose: this assumes X_whole
# iterates over hsp's second axis in the outer loop — TODO confirm against
# Preprocessing.whole_coordinates().
u_pred = np.asarray(preds)[:, 0].reshape(hsp_true.shape[1], -1).T
assert u_pred.shape == hsp_true.shape, (
    f"prediction shape {u_pred.shape} does not match reference shape {hsp_true.shape}")

re = np.linalg.norm(u_pred - hsp_true, 'fro') / np.linalg.norm(hsp_true, 'fro')
print('The relative error is {:.4f}'.format(re))

The relative error is 0.0965
