This is the code for the paper "Cranial implant prediction by learning an ensemble of slice-based skull completion networks".
Please run this code only up to the "build the model" part. There you will find the next instruction.

In [None]:
import os
# Restrict CUDA to device 0; set before TensorFlow is imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
import glob
import numpy as np
import nrrd
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import sqlite3 as sql
from tensorflow import keras
from tensorflow.keras import layers
from PIL import Image
import cv2
from skimage import transform, measure
import io
import time
import math
import random
import pandas as pd
from numba import cuda

# Code for Generating the Model

In [None]:
def SEBlock2(x, sq_rate=4):
    """Re-weight the channels of a feature map (squeeze-and-excitation style).

    The input is globally average-pooled, sent through a two-layer dense
    bottleneck (``filters // sq_rate`` units with LeakyReLU(0.2), then
    ``filters`` units with a sigmoid), passed through a softmax and finally
    multiplied channel-wise with the original input.

    Args:
        x: Keras tensor whose last axis holds the channels.
        sq_rate: Reduction factor of the bottleneck dense layer.

    Returns:
        The input tensor scaled per channel.
    """
    n_channels = x.shape[-1]
    pooled = layers.GlobalAveragePooling2D()(x)
    squeezed = layers.Reshape((n_channels, ))(pooled)
    hidden = layers.Dense(n_channels // sq_rate)(squeezed)
    hidden = layers.LeakyReLU(.2)(hidden)
    gate = layers.Dense(n_channels, activation='sigmoid')(hidden)
    # The sigmoid activations are additionally normalised with a softmax.
    gate = tf.nn.softmax(gate)
    gate = layers.Reshape((1, 1, n_channels))(gate)
    return x * gate

In [None]:
def SEBlock(x, sq_rate=4):
    """Re-weight the channels of a feature map and rescale by the channel count.

    Same structure as ``SEBlock2`` except that the bottleneck uses
    LeakyReLU(0.3) and the gated output is multiplied by the number of
    channels.

    Args:
        x: Keras tensor whose last axis holds the channels.
        sq_rate: Reduction factor of the bottleneck dense layer.

    Returns:
        The input tensor scaled per channel.
    """
    n_channels = x.shape[-1]
    pooled = layers.GlobalAveragePooling2D()(x)
    squeezed = layers.Reshape((n_channels, ))(pooled)
    hidden = layers.Dense(n_channels // sq_rate)(squeezed)
    hidden = layers.LeakyReLU(.3)(hidden)
    gate = layers.Dense(n_channels, activation='sigmoid')(hidden)
    # The sigmoid activations are additionally normalised with a softmax.
    gate = tf.nn.softmax(gate)
    gate = layers.Reshape((1, 1, n_channels))(gate)
    gated = x * gate
    # Multiply by the channel count after gating.
    return gated * n_channels

In [None]:
def ChannelGainInit(shape, dtype=None):
    """Weight initializer that fills the tensor with the constant 128.

    Args:
        shape: Shape of the weight to create.
        dtype: Data type of the weight.

    Returns:
        A tensor of the requested shape filled with 128.
    """
    # Log what is being initialised.
    print(shape, dtype)
    ones = tf.ones(shape, dtype=dtype)
    return 128*ones

In [None]:
class ChannelGain(layers.Layer):
    """Layer that multiplies its input by a single trainable scalar gain.

    Args:
        ini_val: Initial value of the gain weight.
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments
            (``name``, ``dtype``, ...), needed to rebuild the layer from
            its config.
    """

    def __init__(self, ini_val, **kwargs):
        super().__init__(**kwargs)
        self.ini_val = ini_val

    def build(self, input_shape):
        """Create the scalar gain, initialised to ``ini_val``."""
        self.gain = self.add_weight(
            shape=(1, ),
            # Honour the value passed by the caller instead of a fixed 128.
            initializer=keras.initializers.Constant(self.ini_val),
            trainable=True
        )

    def call(self, inputs):
        """Return ``inputs`` scaled by the gain."""
        return inputs * self.gain

    def get_config(self):
        """Return the base layer config extended with ``ini_val``."""
        config = super().get_config()
        config['ini_val'] = self.ini_val
        return config

In [None]:
def LeakyConv2D(x, f, ks, st=1, pd='same', dr=1):
    """Apply a 2-D convolution followed by LeakyReLU(0.2).

    Args:
        x: Input Keras tensor.
        f: Number of output filters.
        ks: Kernel size (int or pair).
        st: Stride.
        pd: Padding mode.
        dr: Dilation rate.

    Returns:
        The activated convolution output.
    """
    conv_out = layers.Conv2D(
        f, ks, strides=st, padding=pd, dilation_rate=dr
    )(x)
    return layers.LeakyReLU(.2)(conv_out)

In [None]:
def LeakyConv2DT(x, f, ks, st=1, pd='same', dr=1):
    """Apply a 2-D transposed convolution followed by LeakyReLU(0.2).

    Args:
        x: Input Keras tensor.
        f: Number of output filters.
        ks: Kernel size (int or pair).
        st: Stride.
        pd: Padding mode.
        dr: Dilation rate.

    Returns:
        The activated transposed-convolution output.
    """
    deconv_out = layers.Conv2DTranspose(
        f, ks, strides=st, padding=pd, dilation_rate=dr
    )(x)
    return layers.LeakyReLU(.2)(deconv_out)

In [None]:
def InceptionHead(shape, f):
    """Build the Inception-style stem as a stand-alone Keras model.

    The stem consists of three convolutions (the first with stride 2)
    followed by three two-branch stages whose branch outputs are
    concatenated; the first and last stage halve the spatial size again.

    Args:
        shape: Input shape without the batch axis.
        f: Base number of filters; every layer width is a multiple of it.

    Returns:
        A ``keras.Model`` mapping the input to the concatenated features.
    """
    net_in = layers.Input(shape)

    # Plain convolutional stem.
    stem = LeakyConv2D(net_in, f, 3, st=2)
    stem = LeakyConv2D(stem, f, 3)
    stem = LeakyConv2D(stem, f*2, 3)

    # Stage 1: max-pool branch next to a strided-convolution branch.
    pooled = layers.MaxPooling2D(3, strides=2, padding='same')(stem)
    strided = LeakyConv2D(stem, f*3, 3, st=2)
    stage1 = layers.Concatenate()([pooled, strided])

    # Stage 2: short 1x1 -> 3x3 branch next to a factorised 7x7 branch.
    short = LeakyConv2D(stage1, f*2, 1)
    short = LeakyConv2D(short, f*3, 3)
    wide = LeakyConv2D(stage1, f*2, 1)
    wide = LeakyConv2D(wide, f*2, (1, 7))
    wide = LeakyConv2D(wide, f*2, (7, 1))
    wide = LeakyConv2D(wide, f*3, 3)
    stage2 = layers.Concatenate()([short, wide])

    # Stage 3: strided-convolution branch next to a max-pool branch.
    strided = LeakyConv2D(stage2, f*6, 3, st=2)
    pooled = layers.MaxPooling2D(2)(stage2)
    stage3 = layers.Concatenate()([strided, pooled])

    return keras.Model(net_in, stage3)

In [None]:
def InceptionA(shape, f):
    """Build a four-branch Inception-A block as a stand-alone Keras model.

    Branches: global-average context (pooled, tiled back to the input size,
    1x1 conv), a 1x1 conv, a 1x1 -> 3x3 stack and a 1x1 -> 3x3 -> 3x3 stack.
    All branch outputs have ``3 * f`` filters and are concatenated.

    Args:
        shape: Input shape without the batch axis.
        f: Base number of filters.

    Returns:
        A ``keras.Model`` producing the concatenated branch outputs.
    """
    net_in = layers.Input(shape)

    # Branch 1: global context broadcast back to the spatial grid.
    spatial = net_in.shape[1:-1]
    n_channels = net_in.shape[-1]
    context = layers.GlobalAveragePooling2D()(net_in)
    context = layers.Reshape((1, 1, n_channels))(context)
    context = layers.UpSampling2D(spatial)(context)
    context = LeakyConv2D(context, f*3, 1)

    # Branch 2: pointwise convolution.
    pointwise = LeakyConv2D(net_in, f*3, 1)

    # Branch 3: 1x1 reduction followed by one 3x3 convolution.
    single = LeakyConv2D(net_in, f*2, 1)
    single = LeakyConv2D(single, f*3, 3)

    # Branch 4: 1x1 reduction followed by two 3x3 convolutions.
    double = LeakyConv2D(net_in, f*2, 1)
    double = LeakyConv2D(double, f*3, 3)
    double = LeakyConv2D(double, f*3, 3)

    merged = layers.Concatenate()([context, pointwise, single, double])
    return keras.Model(net_in, merged)

In [None]:
def InceptionB(shape):
    """Build the Inception-B block as a stand-alone Keras model.

    The base filter count is ``shape[-1] // 8``. Only the two factorised
    7x7 branches feed the output. The global-context branch and the 1x1
    branch are still instantiated but are not concatenated, so they are not
    part of the returned model.

    Args:
        shape: Input shape without the batch axis.

    Returns:
        A ``keras.Model`` producing the concatenation of the two 7x7 branches.
    """
    net_in = layers.Input(shape)
    f = shape[-1]//8

    # Branch 1 (built but excluded from the output).
    spatial = net_in.shape[1:-1]
    n_channels = net_in.shape[-1]
    context = layers.GlobalAveragePooling2D()(net_in)
    context = layers.Reshape((1, 1, n_channels))(context)
    context = layers.UpSampling2D(spatial)(context)
    context = LeakyConv2D(context, f*2, 1)

    # Branch 2 (built but excluded from the output).
    pointwise = LeakyConv2D(net_in, f*6, 1)

    # Branch 3: one factorised 7x7 convolution.
    single = LeakyConv2D(net_in, f*4, (1, 7))
    single = LeakyConv2D(single, f*4, (7, 1))

    # Branch 4: two factorised 7x7 convolutions.
    double = LeakyConv2D(net_in, f*4, (1, 7))
    double = LeakyConv2D(double, f*4, (7, 1))
    double = LeakyConv2D(double, f*4, (1, 7))
    double = LeakyConv2D(double, f*4, (7, 1))

    merged = layers.Concatenate()([single, double])
    return keras.Model(net_in, merged)

In [None]:
def InceptionC(shape, f):
    """Build the Inception-C block as a stand-alone Keras model.

    Six outputs with ``4 * f`` filters each are concatenated: a
    global-context branch, a 1x1 branch, and two branches that each end in a
    parallel (1, 3) / (3, 1) split.

    Args:
        shape: Input shape without the batch axis.
        f: Base number of filters.

    Returns:
        A ``keras.Model`` producing the concatenated branch outputs.
    """
    net_in = layers.Input(shape)
    branches = []

    # Branch 1: global context broadcast back to the spatial grid.
    spatial = net_in.shape[1:-1]
    n_channels = net_in.shape[-1]
    context = layers.GlobalAveragePooling2D()(net_in)
    context = layers.Reshape((1, 1, n_channels))(context)
    context = layers.UpSampling2D(spatial)(context)
    branches.append(LeakyConv2D(context, f*4, 1))

    # Branch 2: pointwise convolution.
    branches.append(LeakyConv2D(net_in, f*4, 1))

    # Branch 3: 1x1 reduction, then a horizontal and a vertical head.
    trunk = LeakyConv2D(net_in, f*6, 1)
    branches.append(LeakyConv2D(trunk, f*4, (1, 3)))
    branches.append(LeakyConv2D(trunk, f*4, (3, 1)))

    # Branch 4: deeper trunk, then a horizontal and a vertical head.
    trunk = LeakyConv2D(net_in, f*6, 1)
    trunk = LeakyConv2D(trunk, f*8, (1, 3))
    trunk = LeakyConv2D(trunk, f*8, (3, 1))
    branches.append(LeakyConv2D(trunk, f*4, (1, 3)))
    branches.append(LeakyConv2D(trunk, f*4, (3, 1)))

    merged = layers.Concatenate()(branches)
    return keras.Model(net_in, merged)




In [None]:
def ReductionA(shape, f):
    """Build the Reduction-A block (stride-2 downsampling) as a Keras model.

    Three stride-2 branches are concatenated: a 3x3 max-pool, a strided 1x1
    convolution and a 1x1 -> 3x3 -> strided 3x3 stack.

    Args:
        shape: Input shape without the batch axis.
        f: Base number of filters.

    Returns:
        A ``keras.Model`` producing the concatenated, downsampled features.
    """
    net_in = layers.Input(shape)

    # Branch 1: max pooling.
    pooled = layers.MaxPooling2D(3, strides=2, padding='same')(net_in)

    # Branch 2: strided pointwise convolution.
    pointwise = LeakyConv2D(net_in, f*3, 1, st=2)

    # Branch 3: convolution stack ending in a strided 3x3 convolution.
    stacked = LeakyConv2D(net_in, f*2, 1)
    stacked = LeakyConv2D(stacked, f*2, 3)
    stacked = LeakyConv2D(stacked, f*3, 3, st=2)

    merged = layers.Concatenate()([pooled, pointwise, stacked])
    return keras.Model(net_in, merged)

In [None]:
def ReductionB(shape, f):
    """Build the Reduction-B block (stride-2 downsampling) as a Keras model.

    Two convolutional branches are concatenated: a strided 3x3 convolution
    and a factorised 7x7 stack ending in a strided 3x3 convolution. A
    max-pool layer is also instantiated but is not concatenated, so it is
    not part of the returned model.

    Args:
        shape: Input shape without the batch axis.
        f: Base number of filters.

    Returns:
        A ``keras.Model`` producing the concatenated, downsampled features.
    """
    net_in = layers.Input(shape)

    # Branch 1 (built but excluded from the output).
    pooled = layers.MaxPooling2D(3, strides=2, padding='same')(net_in)

    # Branch 2: strided 3x3 convolution.
    strided = LeakyConv2D(net_in, f*3, 3, st=2)

    # Branch 3: factorised 7x7 followed by a strided 3x3 convolution.
    wide = LeakyConv2D(net_in, f*4, (1, 7))
    wide = LeakyConv2D(wide, f*4, (7, 1))
    wide = LeakyConv2D(wide, f*5, 3, st=2)

    merged = layers.Concatenate()([strided, wide])
    return keras.Model(net_in, merged)

In [None]:
def UpSampleA(shape, f):
    """Build a stride-2 upsampling block from three transposed convolutions.

    Args:
        shape: Input shape without the batch axis.
        f: Base number of filters; every layer has ``5 * f`` filters.

    Returns:
        A ``keras.Model`` producing the upsampled features.
    """
    net_in = layers.Input(shape)
    # Strided 3x3 upsampling, then a factorised 7x7 refinement.
    up = LeakyConv2DT(net_in, f*5, 3, st=2)
    up = LeakyConv2DT(up, f*5, (7, 1))
    up = LeakyConv2DT(up, f*5, (1, 7))
    return keras.Model(net_in, up)

In [None]:
def get_size(shape):
    """Return the product of all entries of ``shape`` (1 for an empty shape)."""
    return math.prod(shape)

In [None]:
def RNN():
    """Build the recurrent slice-completion network.

    The model takes a stack of five 512x512 slices, feeds them as a
    five-step sequence through a ConvLSTM2D layer (4 filters, 1x1 kernel)
    and then applies the same encoder / SE-gated decoder layout as ``CNN``.

    Returns:
        A ``keras.Model`` mapping a (5, 512, 512) input to a (512, 512)
        output in the range of a sigmoid.
    """
    net_in = layers.Input((5, 512, 512))
    x = layers.Reshape((5, 512, 512, 1))(net_in)
    x = tf.keras.layers.ConvLSTM2D(4, 1, activation='relu', input_shape=(5,512,512,1), padding='same')(x)

    # Encoder: two convolutions per level; all but the last level downsample.
    enc_kernels = [5, 5, 5, 7, 7]
    enc_filters = [4, 8, 16, 32, 128]
    last_level = len(enc_filters) - 1
    for level, (kernel, width) in enumerate(zip(enc_kernels, enc_filters)):
        x = LeakyConv2D(x, width, ks=kernel)
        if level == last_level - 1:
            # Keep the second-to-last level's features for the gating path.
            skip = x
        x = LeakyConv2D(x, width, ks=kernel, st=1 if level == last_level else 2)

    # Decoder
    x = SEBlock(x, sq_rate=8)
    dec_kernels = [7, 7, 5, 5, 5]
    dec_filters = [128, 32, 16, 8, 4]
    for level, (kernel, width) in enumerate(zip(dec_kernels, dec_filters)):
        if level == 1:
            # Bring the skipped features to the bottleneck size and use them
            # as a multiplicative gate on the decoder features.
            skip = LeakyConv2D(skip, 128, 3, st=2)
            skip = LeakyConv2D(skip, 128, 3)
            skip = SEBlock(skip, sq_rate=8)
            x = layers.Lambda(lambda pair: pair[0]*pair[1])([x, skip])
        if level > 0:
            x = layers.UpSampling2D(2)(x)
        x = LeakyConv2D(x, width, ks=kernel)
        x = LeakyConv2D(x, width, ks=kernel)

    # Single-channel output squashed with a sigmoid.
    x = LeakyConv2D(x, 1, ks=3)
    x = tf.sigmoid(x)
    x = layers.Reshape((512, 512))(x)
    return keras.Model(net_in, x)




In [None]:
def CNN():
    """Build the single-slice completion network.

    A five-level convolutional encoder is followed by an SE block and a
    five-level upsampling decoder; features from the second-to-last encoder
    level gate the decoder multiplicatively.

    Returns:
        A ``keras.Model`` mapping a (512, 512) input to a (512, 512) output
        in the range of a sigmoid.
    """
    net_in = layers.Input((512, 512))
    x = layers.Reshape((512, 512, 1))(net_in)

    # Encoder: two convolutions per level; all but the last level downsample.
    enc_kernels = [5, 5, 5, 7, 7]
    enc_filters = [4, 8, 16, 32, 128]
    last_level = len(enc_filters) - 1
    for level, (kernel, width) in enumerate(zip(enc_kernels, enc_filters)):
        x = LeakyConv2D(x, width, ks=kernel)
        if level == last_level - 1:
            # Keep the second-to-last level's features for the gating path.
            skip = x
        x = LeakyConv2D(x, width, ks=kernel, st=1 if level == last_level else 2)

    # Decoder
    x = SEBlock(x, sq_rate=8)
    dec_kernels = [7, 7, 5, 5, 5]
    dec_filters = [128, 32, 16, 8, 4]
    for level, (kernel, width) in enumerate(zip(dec_kernels, dec_filters)):
        if level == 1:
            # Bring the skipped features to the bottleneck size and use them
            # as a multiplicative gate on the decoder features.
            skip = LeakyConv2D(skip, 128, 3, st=2)
            skip = LeakyConv2D(skip, 128, 3)
            skip = SEBlock(skip, sq_rate=8)
            x = layers.Lambda(lambda pair: pair[0]*pair[1])([x, skip])
        if level > 0:
            x = layers.UpSampling2D(2)(x)
        x = LeakyConv2D(x, width, ks=kernel)
        x = LeakyConv2D(x, width, ks=kernel)

    # Single-channel output squashed with a sigmoid.
    x = LeakyConv2D(x, 1, ks=3)
    x = tf.sigmoid(x)
    x = layers.Reshape((512, 512))(x)
    return keras.Model(net_in, x)

    # build the model

If you just want to use our pretrained models to reproduce our results, you can go directly to the "Uses the trained model to predict the evaluation data" section at the bottom of the document.

If you want to train a new model, please follow the instructions below.
To train the CNN model, set model = CNN() in the following box. To train the RNN model, set model = RNN() in the following box.
Then run the code up to the next instruction.

In [None]:
# Select the architecture to train; swap the comments to train the RNN variant.
model = CNN()
# model = RNN()
# Print the layer-by-layer summary of the selected network.
model.summary()

In [None]:
def dice_loss_imp(imp_image, pre_image, smooth=1e-6):
    """Return the soft Dice loss between a target and a prediction.

    Args:
        imp_image: Target tensor.
        pre_image: Predicted tensor of the same shape.
        smooth: Small constant added to numerator and denominator to avoid
            division by zero.

    Returns:
        Scalar tensor ``1 - (2*sum(|t*p|) + smooth) / (sum(t^2) + sum(p^2) + smooth)``.
    """
    overlap = tf.reduce_sum(tf.abs(imp_image*pre_image))
    energy = tf.reduce_sum(tf.square(imp_image))+tf.reduce_sum(tf.square(pre_image))
    dice = (2.*overlap+smooth)/(energy+smooth)
    return 1-dice

In [None]:
def compute_loss(imp_image, def_image):
    """Run the global ``model`` on ``def_image`` and score it against ``imp_image``.

    Args:
        imp_image: Target tensor.
        def_image: Network input tensor.

    Returns:
        Scalar loss tensor (Dice loss added to a zero-initialised scalar).
    """
    prediction = model(def_image)
    return tf.zeros(shape=()) + dice_loss_imp(imp_image, prediction)

In [None]:
@tf.function
def compute_loss_and_grads(imp_image, def_image):
    """Compute the loss and its gradients w.r.t. the global model's weights.

    Args:
        imp_image: Target tensor.
        def_image: Network input tensor.

    Returns:
        Tuple ``(loss, grads)`` where ``grads`` is aligned with
        ``model.trainable_weights``.
    """
    with tf.GradientTape() as recorder:
        loss_value = compute_loss(imp_image, def_image)
    gradients = recorder.gradient(loss_value, model.trainable_weights)
    return loss_value, gradients

In our method we build three models, each based on a different slicing direction (Y, Z and X).

# Code for Generating the Y Model

If you want to train the CNN model, please go to the "train the CNN model" part. If you want to train the RNN model, please go to the "train the RNN model" part. 

    # Train the CNN model

In [None]:
# Collect the case IDs (file name up to the first dot) of all y-axis implant slices.
skulls = glob.glob("./CNN_dataset/implant/registrations/y_axis/*.png")
case = sorted(os.path.basename(s).split(".")[0] for s in skulls)
print(len(case))
# Separate copy of the sorted IDs that the training cells work on.
t_imp_list = list(case)

In [None]:
# Scramble the model: every weight array keeps its values, but they are
# randomly permuted within the array.
weights = [
    np.random.permutation(tensor.flat).reshape(tensor.shape)
    for tensor in model.get_weights()
]
model.set_weights(weights)
# Fresh Adam optimizer with gradient-norm clipping.
optimizer = keras.optimizers.Adam(learning_rate=.00005, clipnorm=1.)

In [None]:
# Randomise the order in which the training slices are visited.
random.shuffle (t_imp_list)
# Mean losses recorded at every checkpoint of the training cell below.
sum_loss=[]

Running the code in the following box once is one iteration. Generally, we need to run it at least 8 times to get a better result. If you want to save a checkpoint for every iteration, please change the filename in the filepath variable.

In [None]:
# One pass over all y-axis training slices (re-run this cell for more passes).
i = 0            # number of slices actually trained on in this pass
t_list2 = []     # IDs of slices that could not be loaded
avgeloss = []    # per-step losses of this pass
filepath = './modelForY/test_model_1.h5'
for ID in t_imp_list:
    try:
        # Context managers close the image files as soon as they are converted.
        with Image.open("./CNN_dataset/defective_skull/registrations/y_axis/" + ID + ".png") as def_img:
            def_img_array = np.array(def_img)
        with Image.open("./CNN_dataset/implant/registrations/y_axis/" + ID + ".png") as imp_img:
            imp_img_array = np.array(imp_img)
    except Exception as err:
        # Best effort: skip unreadable slices. Unlike a bare ``except`` this
        # still lets KeyboardInterrupt / SystemExit stop the run.
        print("error:image", ID, err)
        t_list2.append(ID)
        continue

    i += 1
    # Batch of one slice for the network input (x) and the target (y).
    x = np.empty((1, 512, 512), dtype='float32')
    y = np.empty((1, 512, 512), dtype='float32')
    x[0] = def_img_array
    y[0] = imp_img_array

    def_image = tf.convert_to_tensor(x)
    imp_image = tf.convert_to_tensor(y)

    loss, grads = compute_loss_and_grads(imp_image, def_image)
    optimizer.apply_gradients(list(zip(grads, model.trainable_weights)))
    avgeloss.append(loss)
    print("Iteration %d: loss=%.2f" % (i, loss))

    # Checkpoint and record the running mean loss every 1000 steps.
    if i % 1000 == 0:
        aveg_loss = np.mean(avgeloss)
        sum_loss.append(aveg_loss)
        model.save_weights(filepath)

# Also checkpoint the steps after the last multiple of 1000 so the saved
# weights reflect the end of the pass.
if avgeloss and i % 1000 != 0:
    aveg_loss = np.mean(avgeloss)
    sum_loss.append(aveg_loss)
    model.save_weights(filepath)

            

In [None]:
# Display the mean losses recorded at the checkpoints.
sum_loss

    # Train the RNN model

Read the image list for a specific folder. Please set the file path in the skulls variable to the place where you put your nrrd files. When you change this variable, you also need to change the paths used for input_image and output_image.

In [None]:
# Collect the case IDs (file name up to the first dot) of the bilateral implant volumes.
skulls = glob.glob("./nrrd/implant/bilateral/*.nrrd")
case = sorted(os.path.basename(s).split(".")[0] for s in skulls)
print(len(case))
# Separate copy of the sorted IDs that the training cells work on.
t_imp_list = list(case)

In [None]:
# Scramble the model: every weight array keeps its values, but they are
# randomly permuted within the array.
weights = [
    np.random.permutation(tensor.flat).reshape(tensor.shape)
    for tensor in model.get_weights()
]
model.set_weights(weights)
# Fresh Adam optimizer with gradient-norm clipping.
optimizer = keras.optimizers.Adam(learning_rate=.00005, clipnorm=1.)

In [None]:
# Mean losses recorded at every checkpoint of the training cell below.
sum_loss=[]

Running the code in the following box once is one iteration for a folder. Generally, we need to run it at least 2 times to get a better result. For random1 and random2, you first need to run this model on the first three folders 2 times, then run it on the random1 and random2 folders 2 times. If you want to save a checkpoint for every iteration, please change the filename in the filepath variable.

In [None]:
# One training pass over the bilateral volumes, slicing along axis 1 (y).
filepath = './modelForY/bilateral_model_1.h5'

# Buffers are allocated once and fully overwritten for every group of skulls;
# each volume buffer is 10 * 512**3 float32 values (5 GiB).
input_image = np.empty((10, 512, 512, 512), dtype='float32')
output_image = np.empty((10, 512, 512, 512), dtype='float32')
imp_image = np.empty((1, 512, 512), dtype='float32')
def_image = np.empty((1, 5, 512, 512), dtype='float32')

# NOTE: groups of ten cases start every five cases, so consecutive groups
# overlap and t_imp_list is indexed up to position 104.
for j in range(0, 100, 5):
    avgeloss = []

    for i in range(10):
        ID = t_imp_list[j+i]
        input_image[i], option = nrrd.read("./nrrd/defective_skull/bilateral/"+ID+".nrrd")
        output_image[i], option = nrrd.read("./nrrd/implant/bilateral/"+ID+".nrrd")

    # Visit the 10 * 512 slices of the loaded group in random order.
    image_range = np.arange(5120)
    random.shuffle(image_range)
    for i, get_image in enumerate(image_range):
        # Flat slice index -> (volume within the group, slice within the volume).
        skullNo, imageNo = divmod(int(get_image), 512)

        # Border slices are reported and skipped.
        if imageNo < 2 or imageNo > 508:
            print(skullNo)
            print(imageNo)
            continue

        # Input: the slice and its two neighbours on either side; target: the
        # implant slice at the centre position.
        for k in range(5):
            def_image[0, k] = input_image[skullNo, :, imageNo-2+k, :]
        imp_image[0] = output_image[skullNo, :, imageNo, :]

        def_image_ten = tf.convert_to_tensor(def_image)
        imp_image_ten = tf.convert_to_tensor(imp_image)

        loss, grads = compute_loss_and_grads(imp_image_ten, def_image_ten)
        optimizer.apply_gradients(list(zip(grads, model.trainable_weights)))
        avgeloss.append(loss)
        print("Iteration %d: loss=%.2f" % (i, loss))

        # Checkpoint every 1000 positions and at the last position of the group.
        if i % 1000 == 0 or i == 5119:
            aveg_loss = np.mean(avgeloss)
            sum_loss.append(aveg_loss)
            model.save_weights(filepath)

In [None]:
# Display the mean losses recorded at the checkpoints.
sum_loss

# Code for Generating the Z Model

If you want to train the CNN model, please go to the "train the CNN model" part. If you want to train the RNN model, please go to the "train the RNN model" part. 

    #Train the CNN model

In [None]:
# Collect the case IDs (file name up to the first dot) of all z-axis implant slices.
skulls = glob.glob("./CNN_dataset/implant/registrations/z_axis/*.png")
case = sorted(os.path.basename(s).split(".")[0] for s in skulls)
print(len(case))
# Separate copy of the sorted IDs that the training cells work on.
t_imp_list = list(case)

In [None]:
# Scramble the model: every weight array keeps its values, but they are
# randomly permuted within the array.
weights = [
    np.random.permutation(tensor.flat).reshape(tensor.shape)
    for tensor in model.get_weights()
]
model.set_weights(weights)
# Fresh Adam optimizer with gradient-norm clipping.
optimizer = keras.optimizers.Adam(learning_rate=.00005, clipnorm=1.)

In [None]:
# Randomise the order in which the training slices are visited.
random.shuffle (t_imp_list)
# Per-step losses (re-initialised at the top of the training cell below).
avgeloss = []
# Mean losses recorded at every checkpoint of the training cell below.
sum_loss=[]
# Step counter (re-initialised at the top of the training cell below).
i=0

Running the code in the following box once is one iteration. Generally, we need to run it at least 12 times to get a better result. If you want to save a checkpoint for every iteration, please change the filename in the filepath variable.

In [None]:
# One pass over all z-axis training slices (re-run this cell for more passes).
i = 0            # number of slices actually trained on in this pass
t_list2 = []     # IDs of slices that could not be loaded
avgeloss = []    # per-step losses of this pass
filepath = './modelForZ/test_model_1.h5'
for ID in t_imp_list:
    try:
        # Context managers close the image files as soon as they are converted.
        with Image.open("./CNN_dataset/defective_skull/registrations/z_axis/" + ID + ".png") as def_img:
            def_img_array = np.array(def_img)
        with Image.open("./CNN_dataset/implant/registrations/z_axis/" + ID + ".png") as imp_img:
            imp_img_array = np.array(imp_img)
    except Exception as err:
        # Best effort: skip unreadable slices. Unlike a bare ``except`` this
        # still lets KeyboardInterrupt / SystemExit stop the run.
        print("error:image", ID, err)
        t_list2.append(ID)
        continue

    i += 1
    # Batch of one slice for the network input (x) and the target (y).
    x = np.empty((1, 512, 512), dtype='float32')
    y = np.empty((1, 512, 512), dtype='float32')
    x[0] = def_img_array
    y[0] = imp_img_array

    def_image = tf.convert_to_tensor(x)
    imp_image = tf.convert_to_tensor(y)

    loss, grads = compute_loss_and_grads(imp_image, def_image)
    optimizer.apply_gradients(list(zip(grads, model.trainable_weights)))
    avgeloss.append(loss)
    print("Iteration %d: loss=%.2f" % (i, loss))

    # Checkpoint and record the running mean loss every 1000 steps.
    if i % 1000 == 0:
        aveg_loss = np.mean(avgeloss)
        sum_loss.append(aveg_loss)
        model.save_weights(filepath)

# Also checkpoint the steps after the last multiple of 1000 so the saved
# weights reflect the end of the pass.
if avgeloss and i % 1000 != 0:
    aveg_loss = np.mean(avgeloss)
    sum_loss.append(aveg_loss)
    model.save_weights(filepath)

    # Train the RNN model

Read the image list for a specific folder. Please set the file path in the skulls variable to the place where you put your nrrd files. When you change this variable, you also need to change the paths used for input_image and output_image.

In [None]:
# Collect the case IDs (file name up to the first dot) of the bilateral implant volumes.
skulls = glob.glob("./nrrd/implant/bilateral/*.nrrd")
case = sorted(os.path.basename(s).split(".")[0] for s in skulls)
print(len(case))
# Separate copy of the sorted IDs that the training cells work on.
t_imp_list = list(case)

In [None]:
# Fresh Adam optimizer with gradient-norm clipping.
optimizer = keras.optimizers.Adam(learning_rate=.00005, clipnorm=1.)
# Scramble the model: every weight array keeps its values, but they are
# randomly permuted within the array.
weights = [
    np.random.permutation(tensor.flat).reshape(tensor.shape)
    for tensor in model.get_weights()
]
model.set_weights(weights)

In [None]:
# Mean losses recorded at every checkpoint of the training cell below.
sum_loss=[]

Running the code in the following box once is one iteration for a folder. Generally, we need to run it at least 4 times to get a better result. For random1 and random2, you first need to run this model on the first three folders 2 times, then run it on the random1 and random2 folders 5 times. If you want to save a checkpoint for every iteration, please change the filename in the filepath variable.

In [None]:
# One training pass over the bilateral volumes, slicing along axis 2 (z).
filepath = './modelForZ/bilateral_test_model_1.h5'

# Buffers are allocated once and fully overwritten for every group of skulls;
# each volume buffer is 10 * 512**3 float32 values (5 GiB).
input_image = np.empty((10, 512, 512, 512), dtype='float32')
output_image = np.empty((10, 512, 512, 512), dtype='float32')
imp_image = np.empty((1, 512, 512), dtype='float32')
def_image = np.empty((1, 5, 512, 512), dtype='float32')

# NOTE: groups of ten cases start every five cases, so consecutive groups
# overlap and t_imp_list is indexed up to position 104.
for j in range(0, 100, 5):
    avgeloss = []

    for i in range(10):
        ID = t_imp_list[j+i]
        input_image[i], option = nrrd.read("./nrrd/defective_skull/bilateral/"+ID+".nrrd")
        output_image[i], option = nrrd.read("./nrrd/implant/bilateral/"+ID+".nrrd")

    # Visit the 10 * 512 slices of the loaded group in random order.
    image_range = np.arange(5120)
    random.shuffle(image_range)
    for i, get_image in enumerate(image_range):
        # Flat slice index -> (volume within the group, slice within the volume).
        skullNo, imageNo = divmod(int(get_image), 512)

        # Border slices are reported and skipped.
        if imageNo < 2 or imageNo > 508:
            print(skullNo)
            print(imageNo)
            continue

        # Input: the slice and its two neighbours on either side; target: the
        # implant slice at the centre position.
        for k in range(5):
            def_image[0, k] = input_image[skullNo, :, :, imageNo-2+k]
        imp_image[0] = output_image[skullNo, :, :, imageNo]

        def_image_ten = tf.convert_to_tensor(def_image)
        imp_image_ten = tf.convert_to_tensor(imp_image)

        loss, grads = compute_loss_and_grads(imp_image_ten, def_image_ten)
        optimizer.apply_gradients(list(zip(grads, model.trainable_weights)))
        avgeloss.append(loss)
        print("Iteration %d: loss=%.2f" % (i, loss))

        # Checkpoint every 1000 positions and at the last position of the group.
        if i % 1000 == 0 or i == 5119:
            aveg_loss = np.mean(avgeloss)
            sum_loss.append(aveg_loss)
            model.save_weights(filepath)

In [None]:
# Display the mean losses recorded at the checkpoints.
sum_loss

# Code for Generating the X Model

If you want to train the CNN model, please go to the "train the CNN model" part. If you want to train the RNN model, please go to the "train the RNN model" part. 

    #Train the CNN model

In [None]:
# Collect the case IDs (file name up to the first dot) of all x-axis implant slices.
skulls = glob.glob("./CNN_dataset/implant/registrations/x_axis/*.png")
case = sorted(os.path.basename(s).split(".")[0] for s in skulls)
print(len(case))
# Separate copy of the sorted IDs that the training cells work on.
t_imp_list = list(case)

In [None]:
# Scramble the model: every weight array keeps its values, but they are
# randomly permuted within the array.
weights = [
    np.random.permutation(tensor.flat).reshape(tensor.shape)
    for tensor in model.get_weights()
]
model.set_weights(weights)
# Fresh Adam optimizer with gradient-norm clipping.
optimizer = keras.optimizers.Adam(learning_rate=.00005, clipnorm=1.)

In [None]:
# Randomise the order in which the training slices are visited.
random.shuffle (t_imp_list)
# IDs of slices that could not be loaded; filled by the training cell below.
t_list2=[]
# Per-step losses; the training cell below appends to this list.
avgeloss = []
# Mean losses recorded at every checkpoint of the training cell below.
sum_loss=[]
# Step counter; the training cell below increments it without resetting it.
i=0

Running the code in the following box once is one iteration. Generally, we need to run it at least 13 times to get a better result. If you want to save a checkpoint for every iteration, please change the filename in the filepath variable.

In [None]:
# One pass over all x-axis training slices (re-run this cell for more passes).
# i, t_list2, avgeloss and sum_loss come from the preceding cell and are not
# reset here, so they keep accumulating across re-runs of this cell.
filepath = './modelForX/test_model_1.h5'
for ID in t_imp_list:
    try:
        # Context managers close the image files as soon as they are converted.
        with Image.open("./CNN_dataset/defective_skull/registrations/x_axis/" + ID + ".png") as def_img:
            def_img_array = np.array(def_img)
        with Image.open("./CNN_dataset/implant/registrations/x_axis/" + ID + ".png") as imp_img:
            imp_img_array = np.array(imp_img)
    except Exception as err:
        # Best effort: skip unreadable slices. Unlike a bare ``except`` this
        # still lets KeyboardInterrupt / SystemExit stop the run.
        print("error:image", ID, err)
        t_list2.append(ID)
        continue

    i += 1
    # Batch of one slice for the network input (x) and the target (y).
    x = np.empty((1, 512, 512), dtype='float32')
    y = np.empty((1, 512, 512), dtype='float32')
    x[0] = def_img_array
    y[0] = imp_img_array

    def_image = tf.convert_to_tensor(x)
    imp_image = tf.convert_to_tensor(y)

    loss, grads = compute_loss_and_grads(imp_image, def_image)
    optimizer.apply_gradients(list(zip(grads, model.trainable_weights)))
    avgeloss.append(loss)
    print("Iteration %d: loss=%.2f" % (i, loss))

    # Checkpoint and record the running mean loss every 1000 steps.
    if i % 1000 == 0:
        aveg_loss = np.mean(avgeloss)
        sum_loss.append(aveg_loss)
        model.save_weights(filepath)

# Also checkpoint the steps after the last multiple of 1000 so the saved
# weights reflect the end of the pass.
if avgeloss and i % 1000 != 0:
    aveg_loss = np.mean(avgeloss)
    sum_loss.append(aveg_loss)
    model.save_weights(filepath)



    # Train the RNN model

Read the image list for a specific folder. Please set the file path in the skulls variable to the place where you put your nrrd files. When you change this variable, you also need to change the paths used for input_image and output_image.

In [None]:
# Collect the case IDs (file name up to the first dot) of the bilateral implant volumes.
skulls = glob.glob("./nrrd/implant/bilateral/*.nrrd")
case = sorted(os.path.basename(s).split(".")[0] for s in skulls)
print(len(case))
# Separate copy of the sorted IDs that the training cells work on.
t_imp_list = list(case)

In [None]:
# Fresh Adam optimizer with gradient-norm clipping.
optimizer = keras.optimizers.Adam(learning_rate=.00005, clipnorm=1.)
# Scramble the model: every weight array keeps its values, but they are
# randomly permuted within the array.
weights = [
    np.random.permutation(tensor.flat).reshape(tensor.shape)
    for tensor in model.get_weights()
]
model.set_weights(weights)

In [None]:
# Mean losses recorded at every checkpoint of the training cell below.
sum_loss=[]

Running the code in the following box once is one iteration for a folder. Generally, we need to run it at least 2 times to get a better result. For random1 and random2, you first need to run this model on the first three folders 2 times, then run it on the random1 and random2 folders 2 times. If you want to save a checkpoint for every iteration, please change the filename in the filepath variable.

In [None]:
# One training pass over the bilateral volumes, slicing along axis 0 (x).
filepath = './modelForX/bilateral_model_1.h5'

# Buffers are allocated once and fully overwritten for every group of skulls;
# each volume buffer is 10 * 512**3 float32 values (5 GiB).
input_image = np.empty((10, 512, 512, 512), dtype='float32')
output_image = np.empty((10, 512, 512, 512), dtype='float32')
imp_image = np.empty((1, 512, 512), dtype='float32')
def_image = np.empty((1, 5, 512, 512), dtype='float32')

# NOTE: groups of ten cases start every five cases, so consecutive groups
# overlap and t_imp_list is indexed up to position 104.
for j in range(0, 100, 5):
    avgeloss = []

    for i in range(10):
        ID = t_imp_list[j+i]
        input_image[i], option = nrrd.read("./nrrd/defective_skull/bilateral/"+ID+".nrrd")
        output_image[i], option = nrrd.read("./nrrd/implant/bilateral/"+ID+".nrrd")

    # Visit the 10 * 512 slices of the loaded group in random order.
    image_range = np.arange(5120)
    random.shuffle(image_range)
    for i, get_image in enumerate(image_range):
        # Flat slice index -> (volume within the group, slice within the volume).
        skullNo, imageNo = divmod(int(get_image), 512)

        # Border slices are reported and skipped.
        if imageNo < 2 or imageNo > 508:
            print(skullNo)
            print(imageNo)
            continue

        # Input: the slice and its two neighbours on either side; target: the
        # implant slice at the centre position.
        for k in range(5):
            def_image[0, k] = input_image[skullNo, imageNo-2+k, :, :]
        imp_image[0] = output_image[skullNo, imageNo, :, :]

        def_image_ten = tf.convert_to_tensor(def_image)
        imp_image_ten = tf.convert_to_tensor(imp_image)

        loss, grads = compute_loss_and_grads(imp_image_ten, def_image_ten)
        optimizer.apply_gradients(list(zip(grads, model.trainable_weights)))
        avgeloss.append(loss)
        print("Iteration %d: loss=%.2f" % (i, loss))

        # Checkpoint every 1000 positions and at the last position of the group.
        if i % 1000 == 0 or i == 5119:
            aveg_loss = np.mean(avgeloss)
            sum_loss.append(aveg_loss)
            model.save_weights(filepath)

In [None]:
# Display the mean losses recorded at the checkpoints.
sum_loss

# Uses the trained model to predict the evaluation data

In [None]:
# Instantiate one recurrent (5-slice) and one single-slice network per axis.
d2_x_m, d2_y_m, d2_z_m = RNN(), RNN(), RNN()
d2_x_m1, d2_y_m1, d2_z_m1 = CNN(), CNN(), CNN()

# Pretrained weights of the recurrent networks.
d2_x_m.load_weights('./PreTrainedModel/modelForX/bilateral_model_2.h5')
d2_y_m.load_weights('./PreTrainedModel/modelForY/bilateral_model_1.h5')
d2_z_m.load_weights('./PreTrainedModel/modelForZ/bilateral_model_4.h5')

# Pretrained weights of the single-slice networks.
d2_x_m1.load_weights('./PreTrainedModel/modelForX/CNN_model_13.h5')
d2_y_m1.load_weights('./PreTrainedModel/modelForY/CNN_model_8.h5')
d2_z_m1.load_weights('./PreTrainedModel/modelForZ/CNN_model_12.h5')

Please change the prefix variable to the folder that contains the nrrd files you want to evaluate.

In [None]:
def get_test_data(ID):
    """Read the defective-skull volume of one evaluation case.

    Args:
        ID: Integer case number; it is zero-padded to three digits to form
            the file name.

    Returns:
        Tuple ``(def_data, info)`` as returned by ``nrrd.read``.
    """
    prefix = './eval_data/defective_skull/bilateral/'
    file_name = f'{ID:03d}.nrrd'
    def_data, info = nrrd.read(os.path.join(prefix, file_name))
    return def_data, info

In [None]:
def predict_z(def_data):
    """Predict a volume slice by slice along axis 2 with the recurrent z model.

    Args:
        def_data: Defective-skull volume; indexed as a 512x512x512 array.

    Returns:
        float32 array of shape (512, 512, 512); slices for which no
        prediction is made (k < 2 or k > 508) stay zero.
    """
    ret = np.zeros((512, 512, 512), dtype='float32')
    for k in range(512):
        if not 2 <= k <= 508:
            print("****")
            continue
        # Stack slices k-2 .. k+2 along a leading axis of length five.
        def_slc = np.zeros((1, 5, 512, 512), dtype='float32')
        def_slc[0] = np.moveaxis(def_data[:, :, k-2:k+3], 2, 0)
        ret[:, :, k] = d2_z_m.predict(def_slc)
    return ret

In [None]:
def predict_x(def_data):
    """Predict a volume slice by slice along axis 0 with the recurrent x model.

    Args:
        def_data: Defective-skull volume; indexed as a 512x512x512 array.

    Returns:
        float32 array of shape (512, 512, 512); slices for which no
        prediction is made (k < 2 or k > 508) stay zero.
    """
    ret = np.zeros((512, 512, 512), dtype='float32')
    for k in range(512):
        if not 2 <= k <= 508:
            print("****")
            continue
        # Slices k-2 .. k+2 already lie along the leading axis.
        def_slc = np.zeros((1, 5, 512, 512), dtype='float32')
        def_slc[0] = def_data[k-2:k+3]
        ret[k] = d2_x_m.predict(def_slc)
    return ret

In [None]:
def predict_y(def_data):
    """Predict a volume slice by slice along axis 1 with the recurrent y model.

    Args:
        def_data: Defective-skull volume; indexed as a 512x512x512 array.

    Returns:
        float32 array of shape (512, 512, 512); slices for which no
        prediction is made (k < 2 or k > 508) stay zero.
    """
    ret = np.zeros((512, 512, 512), dtype='float32')
    for k in range(512):
        if not 2 <= k <= 508:
            print("****")
            continue
        # Stack slices k-2 .. k+2 along a leading axis of length five.
        def_slc = np.zeros((1, 5, 512, 512), dtype='float32')
        def_slc[0] = np.moveaxis(def_data[:, k-2:k+3, :], 1, 0)
        ret[:, k, :] = d2_y_m.predict(def_slc)
    return ret

In [None]:
def predict_x1(def_data):
    """Predict a volume slice by slice along axis 0 with the single-slice x model.

    Slices whose values sum to zero are copied through without calling the
    network.

    Args:
        def_data: Defective-skull volume; indexed as a 512x512x512 array.

    Returns:
        float32 array of shape (512, 512, 512).
    """
    ret = np.zeros((512, 512, 512), dtype='float32')
    for k in range(512):
        def_slc = np.zeros((1, 512, 512), dtype='float32')
        def_slc[0] = def_data[k]
        plane = def_slc[0]
        if plane.sum() == 0:
            ret[k] = plane
            continue
        ret[k] = d2_x_m1.predict(def_slc)
    return ret

In [None]:
def predict_y1(def_data):
    """Predict a volume slice by slice along axis 1 with the single-slice y model.

    Slices whose values sum to zero are copied through without calling the
    network.

    Args:
        def_data: Defective-skull volume; indexed as a 512x512x512 array.

    Returns:
        float32 array of shape (512, 512, 512).
    """
    ret = np.zeros((512, 512, 512), dtype='float32')
    for k in range(512):
        def_slc = np.zeros((1, 512, 512), dtype='float32')
        def_slc[0] = def_data[:, k, :]
        plane = def_slc[0]
        if plane.sum() == 0:
            ret[:, k, :] = plane
            continue
        ret[:, k, :] = d2_y_m1.predict(def_slc)
    return ret

In [None]:
def predict_z1(def_data):
    """Predict a volume slice by slice along axis 2 with the single-slice z model.

    Slices whose values sum to zero are copied through without calling the
    network.

    Args:
        def_data: Defective-skull volume; indexed as a 512x512x512 array.

    Returns:
        float32 array of shape (512, 512, 512).
    """
    ret = np.zeros((512, 512, 512), dtype='float32')
    for k in range(512):
        def_slc = np.zeros((1, 512, 512), dtype='float32')
        def_slc[0] = def_data[:, :, k]
        plane = def_slc[0]
        if plane.sum() == 0:
            ret[:, :, k] = plane
            continue
        ret[:, :, k] = d2_z_m1.predict(def_slc)
    return ret

In the nrrd.write line, you can set the location where the predicted nrrd files are written.

In [None]:
def _remove_impurities(mask):
    """Keep only the voxels of a binary mask that lie in a dense neighbourhood.

    A set voxel survives when its 7x7x7 window contains at least 90 set
    voxels (removes small impurities) and its 11x11x11 window contains at
    least 500 set voxels (removes large impurities). All counts are taken
    on the unfiltered input mask.

    Args:
        mask: 3-D int array containing 0 and 1.

    Returns:
        A new array of the same shape and dtype with the filtered mask.
    """
    cleaned = np.zeros_like(mask)
    size_i, size_j, size_k = mask.shape
    # Only set voxels can survive, so only those are visited.
    for i, j, k in np.argwhere(mask == 1):
        small = mask[i-3:i+4, j-3:j+4, k-3:k+4]
        if i+7 > size_i or j+7 > size_j or k+7 > size_k:
            # Near the upper border the large window is taken entirely
            # below the voxel instead of around it.
            large = mask[i-11:i, j-11:j, k-11:k]
        else:
            large = mask[i-5:i+6, j-5:j+6, k-5:k+6]
        # A negative slice start yields an empty window here, so voxels
        # whose windows would begin before index 0 are discarded.
        if large.sum() >= 500 and small.sum() >= 90:
            cleaned[i, j, k] = 1
    return cleaned


# Make sure the output folder exists before predictions are written.
os.makedirs('./result/bilateral', exist_ok=True)

for image in range(20):
    def_data, info = get_test_data(image)

    # Predictions of the three recurrent models ...
    retx = predict_x(def_data)
    rety = predict_y(def_data)
    retz = predict_z(def_data)

    # ... and of the three single-slice models.
    retx1 = predict_x1(def_data)
    rety1 = predict_y1(def_data)
    retz1 = predict_z1(def_data)

    # Ensemble: a voxel is set when the six predictions sum to more than 1.
    Ret = retx+rety+retz+retx1+rety1+retz1
    # Built from a comparison so every voxel is defined (0 or 1) rather than
    # left as uninitialised memory.
    ret_save1 = (Ret > 1.0).astype('int32')
    ret_save2 = _remove_impurities(ret_save1)

    nrrd.write(os.path.join('./result/bilateral', f'{image:03d}.nrrd'), ret_save2.astype('int32'), info)
    print(image)