In [1]:
import os
import numpy as np
import scipy.io
import tensorflow as tf
from tensorflow.keras import layers
import h5py
import math
from tensorflow import keras
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import pandas as pd
from scipy.linalg import expm
from sklearn.preprocessing import MinMaxScaler
from numpy import inf, nan
import time
%run my_functions.ipynb 

2023-01-09 15:37:05.602328: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [2]:
## Select GPU.
# Exposing only physical GPU 1 to CUDA makes it the sole visible device, which
# TensorFlow enumerates as '/gpu:0' (device indices are relative to the visible set).
# NOTE: CUDA_VISIBLE_DEVICES only takes effect if set before TensorFlow initializes
# its GPU devices; TensorFlow is already imported in the first cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
tf_device = '/gpu:0'

In [3]:
## Data loading (the network was originally trained on a 40-million-sample dataset).

# Tissue parameters, one row per simulated tissue.
train_y_file = scipy.io.loadmat('data/Train_Y.mat')
Train_Y = train_y_file['Train_Y']

# Reference MTC signals (Zref) at B1 = 1.0 uT and 1.5 uT.
mtc_file_1p0_file, mtc_file_1p5_file = (
    scipy.io.loadmat('data/MTC_ref_1p0.mat'),
    scipy.io.loadmat('data/MTC_ref_1p5.mat'),
)
MTC_ref_1p0 = mtc_file_1p0_file['MTC_ref_1p0']
MTC_ref_1p5 = mtc_file_1p5_file['MTC_ref_1p5']


In [None]:
## Pretrained deep Bloch simulators, loaded from their SavedModel directories:
##   model_BS_40DS    : tissue parameters -> MTC-MRF signals
##   model_BS_2_point : tissue parameters -> Zref at 1 uT and 1.5 uT
model_BS_40DS, model_BS_2_point = (
    tf.keras.models.load_model(saved_dir)
    for saved_dir in ('saved_models/deep_BS_RNN_40', 'saved_models/deep_BS_2_outputs')
)

In [None]:
## Min-max scale each tissue-parameter column to [0, 1] (scaler fit on the full set),
## then simulate MTC-MRF signals with the deep Bloch simulator, which is fed the
## scaled parameters with a trailing channel axis.
scaler_in = MinMaxScaler()
Train_Y_fit = scaler_in.fit_transform(np.squeeze(Train_Y))
Train_X_estimated = model_BS_40DS.predict(Train_Y_fit[:, :, np.newaxis])

## Corrupt the simulated signals with noise at SNR = 46.
## (mtc_mrf_noisy is presumably defined in my_functions.ipynb, loaded via %run.)
Train_X = mtc_mrf_noisy(np.squeeze(Train_X_estimated), 46)

In [6]:
## Training / testing split, 9:1 by row order (rows are not shuffled before splitting).
split = int(0.9 * len(Train_X))
x_train, x_test = Train_X[:split], Train_X[split:]
# Ground-truth tissue parameters: columns 0-3 are the regression targets
# (kmw, M0m, T2m, T1w); column 4 is the known T1w/T2w value.
GT_train, GT_test = Train_Y[:split], Train_Y[split:]

## Column 4 (T1w/T2w), sliced with 4:5 so it stays 2-D: (rows, 1).
t1t2w_para_train = GT_train[:, 4:5]
t1t2w_para_test = GT_test[:, 4:5]

## Reference Zref signals at 1 uT and 1.5 uT, split at the same row.
Zref_1p0_train, test_MTC_ref_1p0_test = MTC_ref_1p0[:split], MTC_ref_1p0[split:]
Zref_1p5_train, test_MTC_ref_1p5_test = MTC_ref_1p5[:split], MTC_ref_1p5[split:]

In [7]:
## Add a trailing channel axis to the signal / target splits, then convert every
## array used by tf.data to a float32 tensor.


def _with_channel_axis(arr):
    """Return `arr` with a trailing singleton axis appended if it is 2-D, else unchanged.

    The ndim guard keeps this cell idempotent: an unconditional np.expand_dims would
    add one more axis every time the cell is re-executed, breaking the model input shape.
    """
    return np.expand_dims(arr, 2) if np.ndim(arr) == 2 else arr


x_train = _with_channel_axis(x_train)
GT_train = _with_channel_axis(GT_train)
x_test = _with_channel_axis(x_test)
GT_test = _with_channel_axis(GT_test)

# Network inputs and ground-truth parameters.
train_input = tf.cast(x_train, tf.float32)
train_target = tf.cast(GT_train, tf.float32)
val_input = tf.cast(x_test, tf.float32)
val_target = tf.cast(GT_test, tf.float32)

# Known T1w/T2w column.
train_t2w = tf.cast(t1t2w_para_train, tf.float32)
val_t2w = tf.cast(t1t2w_para_test, tf.float32)

# Reference Zref signals at 1 uT and 1.5 uT.
train_zref_1p0 = tf.cast(Zref_1p0_train, tf.float32)
val_zref_1p0 = tf.cast(test_MTC_ref_1p0_test, tf.float32)

train_zref_1p5 = tf.cast(Zref_1p5_train, tf.float32)
val_zref_1p5 = tf.cast(test_MTC_ref_1p5_test, tf.float32)

In [8]:
class LSTM_CNN_network(tf.keras.Model):
    """Bi-LSTM + 1-D CNN regressor mapping a signal sequence to four tissue parameters.

    The sigmoid outputs are mapped linearly onto fixed per-parameter ranges
    [5, 100], [0.02, 0.17], [1e-6, 1e-4], [0.2, 3.0]; the training loop labels the
    four outputs kmw, M0m, T2m, T1w.

    NOTE: `conv`/`pool` and `dense_inter` are each applied four times in `call`, so the
    repeated stages share a single set of weights.
    """

    # Upper / lower bounds used to de-normalise the sigmoid outputs.
    _PARAM_MAX = (100.0, 0.17, 1e-04, 3.0)
    _PARAM_MIN = (5.0, 0.02, 1e-06, 0.2)

    def __init__(self, n_classes):
        super(LSTM_CNN_network, self).__init__()
        # `n_classes` is only used as the sequence length in the LSTM's input_shape.
        self.bi_lstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(128, return_sequences=True, activation='relu', input_shape=(n_classes, 1))
        )
        self.conv = tf.keras.layers.Conv1D(filters=256, kernel_size=3, padding='same', activation='relu')
        self.flat = tf.keras.layers.Flatten()
        self.dense_inter = tf.keras.layers.Dense(512, activation='relu')
        self.dense_out = tf.keras.layers.Dense(4, activation='sigmoid')
        self.pool = tf.keras.layers.MaxPool1D(2)

    def call(self, input, training=False):
        """Forward pass; returns the four parameters in physical (de-normalised) units."""
        features = self.bi_lstm(input)
        # Four conv -> max-pool stages through the same Conv1D / MaxPool1D layers.
        for _ in range(4):
            features = self.pool(self.conv(features))
        hidden = self.flat(features)
        # Four passes through the same Dense(512) layer.
        for _ in range(4):
            hidden = self.dense_inter(hidden)
        unit_params = self.dense_out(hidden)
        lower = tf.constant(self._PARAM_MIN)
        upper = tf.constant(self._PARAM_MAX)
        return unit_params * (upper - lower) + lower

In [9]:
## Reconstruction network for length-40, single-channel input signals.
model_recon = LSTM_CNN_network(n_classes=40)
model_recon.build(input_shape=(None, 40, 1))



In [10]:
# Instantiate a loss function.
def loss_fn(output, target, signal_out_1, signal_target_1, signal_out_1p5, signal_target_1p5,
            weight_1p0=30.0, weight_1p5=15.0):
    """Combined tissue-parameter + Zref-signal loss.

    Args:
        output: predicted parameters; last axis holds the 4 parameters whose ranges are
            given by `max_4` / `min_4` below.
        target: ground-truth parameters with the same layout as `output`.
        signal_out_1, signal_target_1: simulated vs. reference Zref at 1 uT.
        signal_out_1p5, signal_target_1p5: simulated vs. reference Zref at 1.5 uT.
        weight_1p0, weight_1p5: weights of the two signal terms (defaults are the
            previously hard-coded 30 and 15).

    Returns:
        (error_total, mean_diff_norm, mean_signal_norm_1, mean_signal_norm_1p5), where
        error_total = mean_diff_norm + weight_1p0 * mean_signal_norm_1
                      + weight_1p5 * mean_signal_norm_1p5.
    """
    max_4 = tf.constant([100.0, 0.17, 1e-04, 3.0])
    min_4 = tf.constant([5.0, 0.02, 1e-06, 0.2])
    range_4 = max_4 - min_4
    # Compare parameters in [0, 1] units so parameters with very different magnitudes
    # contribute on a comparable scale.
    output_norm = tf.math.divide(output - min_4, range_4)
    target4_norm = tf.math.divide(target - min_4, range_4)
    diff_norm = (output_norm - target4_norm) ** 2

    # Flatten both sides of each signal comparison. If one side is (batch,) and the
    # other (batch, 1), plain subtraction broadcasts to (batch, batch) and the mean
    # would average over every cross pair instead of the matched samples.
    diff_signal_1 = (tf.reshape(signal_out_1, [-1]) - tf.reshape(signal_target_1, [-1])) ** 2
    diff_signal_1p5 = (tf.reshape(signal_out_1p5, [-1]) - tf.reshape(signal_target_1p5, [-1])) ** 2

    mean_diff_norm = tf.math.reduce_mean(diff_norm)
    mean_signal_norm_1 = tf.math.reduce_mean(diff_signal_1)
    mean_signal_norm_1p5 = tf.math.reduce_mean(diff_signal_1p5)

    error_total = mean_diff_norm + weight_1p0 * mean_signal_norm_1 + weight_1p5 * mean_signal_norm_1p5

    return error_total, mean_diff_norm, mean_signal_norm_1, mean_signal_norm_1p5

In [11]:
## Mini-batch sizes for training and validation.
batch_size = batch_size_val = 1000

## Training pipeline: reshuffled every epoch through a 70k-element shuffle buffer.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(
        (train_input, train_target, train_t2w, train_zref_1p0, train_zref_1p5)
    )
    .shuffle(buffer_size=70000, reshuffle_each_iteration=True)
    .batch(batch_size)
)

## Validation pipeline: shuffled once, same order on every epoch.
val_dataset = (
    tf.data.Dataset.from_tensor_slices(
        (val_input, val_target, val_t2w, val_zref_1p0, val_zref_1p5)
    )
    .shuffle(buffer_size=70000, reshuffle_each_iteration=False)
    .batch(batch_size_val)
)

## NOTE(review): `learning_rate_fn` is built but never passed to the optimizer, which
## trains with a constant learning rate of 1e-5. Use `learning_rate=learning_rate_fn`
## if the decaying schedule was intended.
learning_rate_fn = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, decay_rate=0.9
)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00001)

In [None]:
## Training loop. Each step:
##   1. reconstruct four tissue parameters from the signal;
##   2. append the known T1w/T2w column (five parameters in total);
##   3. push them through the 2-point deep Bloch simulator to get Zref at 1 uT / 1.5 uT;
##   4. minimise loss_fn (parameter error + weighted Zref errors).
## Only model_recon's weights are updated; the simulator is called with training=False.
epochs = 40

# MinMaxScaler.transform computes X * scale_ + min_. Applying that with TF ops, using the
# scaler fitted on the full tissue-parameter set, keeps the Zref losses differentiable
# with respect to model_recon. Calling scaler_in.fit_transform inside the tape instead
# returns a NumPy array (cutting the gradient path, so the Zref terms never train the
# network) and refits the global scaler on every batch.
_tiss_scale = tf.constant(scaler_in.scale_, dtype=tf.float32)
_tiss_offset = tf.constant(scaler_in.min_, dtype=tf.float32)


def _scale_tissue_params(params):
    """Min-max scale a (batch, n_params) tensor with the fitted `scaler_in`, differentiably."""
    return params * _tiss_scale + _tiss_offset


def _average_losses(batch_losses, batch_sizes):
    """Sample-weighted mean of per-batch (total, para, zref_1, zref_1p5) loss rows."""
    return np.average(np.asarray(batch_losses, dtype=np.float64), axis=0, weights=batch_sizes)


# Per-epoch (sample-weighted mean) loss histories.
loss_hist = []
loss_hist_para = []
loss_hist_zref_1 = []
loss_hist_zref_1p5 = []
loss_hist_val = []
loss_hist_val_para = []
loss_hist_val_zref_1 = []
loss_hist_val_zref_1p5 = []

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch + 1,))
    start_time = time.time()

    ################################## training ###########################################
    train_batch_losses, train_batch_sizes = [], []
    for step, (x_batch_train, y_batch_train, t1w_t2w, MTC_ref_batch_1p0, MTC_ref_batch_1p5) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            ## Estimation of four tissue parameters (kmw, M0m, T2m, T1w)
            x_pred = model_recon(x_batch_train, training=True)
            ## T1w/T2w is known: take it from the data
            all_tiss = tf.concat((x_pred, t1w_t2w), 1)
            ## Scale all five parameters to [0, 1] without leaving the TF graph
            all_tiss_scaled = _scale_tissue_params(all_tiss)
            ## deepBS for Zref(+3.5 ppm) at 1 uT and 1.5 uT
            MTC_3p5 = model_BS_2_point(tf.expand_dims(all_tiss_scaled, 2), training=False)
            ## Loss calculation
            loss_value, loss_para, loss_z_1, loss_z_1p5 = loss_fn(
                x_pred, tf.squeeze(y_batch_train[:, 0:4, 0]),
                MTC_3p5[:, 0], MTC_ref_batch_1p0,
                MTC_3p5[:, 1], MTC_ref_batch_1p5)

        grads = tape.gradient(loss_value, model_recon.trainable_weights)
        optimizer.apply_gradients(zip(grads, model_recon.trainable_weights))

        train_batch_losses.append([float(loss_value), float(loss_para), float(loss_z_1), float(loss_z_1p5)])
        train_batch_sizes.append(int(x_batch_train.shape[0]))

    epoch_loss, epoch_para, epoch_z_1, epoch_z_1p5 = _average_losses(train_batch_losses, train_batch_sizes)
    loss_hist.append(epoch_loss)
    loss_hist_para.append(epoch_para)
    loss_hist_zref_1.append(epoch_z_1)
    loss_hist_zref_1p5.append(epoch_z_1p5)

    print("Training loss (epoch mean over %d batches): %.4f" % (step + 1, epoch_loss))
    print("Training para loss: %.5f" % epoch_para)
    print("Training Zref 1 loss: %.5f" % epoch_z_1)
    print("Training Zref 1.5 loss: %.4f" % epoch_z_1p5)

    ################################## validation ###########################################
    val_batch_losses, val_batch_sizes = [], []
    for x_batch_val, y_batch_val, t1w_t2w_val, MTC_ref_val_1p0, MTC_ref_val_1p5 in val_dataset:
        val_data = model_recon(x_batch_val, training=False)
        all_tiss_val = tf.concat((val_data, t1w_t2w_val), 1)
        all_tiss_val_scaled = _scale_tissue_params(all_tiss_val)
        MTC_3p5_val = model_BS_2_point(tf.expand_dims(all_tiss_val_scaled, 2), training=False)
        loss_value_val, loss_para_val, loss_z_1_val, loss_z_1p5_val = loss_fn(
            val_data, tf.squeeze(y_batch_val[:, 0:4, 0]),
            MTC_3p5_val[:, 0], MTC_ref_val_1p0,
            MTC_3p5_val[:, 1], MTC_ref_val_1p5)

        val_batch_losses.append([float(loss_value_val), float(loss_para_val), float(loss_z_1_val), float(loss_z_1p5_val)])
        val_batch_sizes.append(int(x_batch_val.shape[0]))

    epoch_loss_val, epoch_para_val, epoch_z_1_val, epoch_z_1p5_val = _average_losses(val_batch_losses, val_batch_sizes)
    loss_hist_val.append(epoch_loss_val)
    loss_hist_val_para.append(epoch_para_val)
    loss_hist_val_zref_1.append(epoch_z_1_val)
    loss_hist_val_zref_1p5.append(epoch_z_1p5_val)

    print("Validation loss: %.5f" % epoch_loss_val)
    print("Validation para loss: %.5f" % epoch_para_val)
    print("Validation Zref 1 loss: %.5f" % epoch_z_1_val)
    print("Validation Zref 1.5 loss: %.4f" % epoch_z_1p5_val)

    print("Time taken: %.2fs" % (time.time() - start_time))