In [None]:
# Code used to train and test a dental-implant radiograph image-processing model:
# U-Net segmentation of the vertical misfit of implant prostheses.

In [None]:
!pip install -q git+https://github.com/tensorflow/examples.git

In [None]:
import os
import glob
import cv2
import numpy as np
from matplotlib import pyplot as plt
import math
from pathlib import Path
import re
from skimage import measure
from sklearn.metrics import mean_absolute_error, mean_squared_error, accuracy_score
import matplotlib as mpl
import tqdm
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from sklearn.model_selection import train_test_split
from tensorflow import keras
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import normalize
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda

In [None]:
# Input directories for the training images and their segmentation masks.
# NOTE(review): ".segmentation/..." (no slash after the dot) refers to a hidden
# directory named ".segmentation", while later cells use "./data/..." — confirm
# this is not a typo for "./segmentation/...".
images_path = ".segmentation/data/images/IMAGES"
masks_path = ".segmentation/data/masks/MASKS"

In [None]:
# Label value used for each structure in the masks; the network emits one
# output channel per entry, so the channel count is derived from this mapping.
classes = {'bg': 0, 'fixture': 1, 'abutment': 2, 'crown': 3, 'gap': 4}
OUTPUT_CHANNELS = len(classes)

In [None]:
def standardize(x):
    """Rescale an array to [0, 1] using a robust (98th-percentile) maximum.

    The minimum is shifted to 0, the shifted values are divided by their 98th
    percentile, and anything above 1 is clipped to 1.

    Args:
        x: array-like of numbers (e.g. an H x W x C image).

    Returns:
        np.ndarray of dtype float64 with the same shape as ``x``. If the 98th
        percentile of the shifted data is 0 (constant or almost-all-black
        input), no division is performed, so the result contains no NaN/inf.
    """
    x = np.array(x, dtype='float64')
    x -= np.min(x)
    scale = np.percentile(x, 98)
    # A zero scale would turn the image into NaN/inf and poison training.
    if scale > 0:
        x /= scale
    x[x > 1] = 1
    return x

def preprocessing(img):
    """Convert a BGR image into a standardized, channel-replicated grayscale image.

    The grayscale intensity is written into the first three channels of a
    zero array shaped (and typed) like the input, keeping the 3-channel layout,
    and the result is rescaled to [0, 1] by ``standardize``.
    """
    bgr = np.array(img)
    luminance = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    replicated = np.zeros_like(bgr)
    replicated[..., :3] = luminance[..., np.newaxis]
    return standardize(replicated)

In [None]:
# Full path of every entry in the image and mask directories (order not yet fixed).
images_paths = [os.path.join(images_path, name) for name in os.listdir(images_path)]
masks_paths = [os.path.join(masks_path, name) for name in os.listdir(masks_path)]

In [None]:
# Images and masks are paired purely by position after sorting, so their file
# names must sort in the same order and the two lists must be the same length.
images_paths.sort()
masks_paths.sort()
assert len(images_paths) == len(masks_paths), (
    f"{len(images_paths)} images but {len(masks_paths)} masks; cannot pair them by position"
)

In [None]:
# Spot-check that the first few image and mask paths line up after sorting.
for path_list in (images_paths, masks_paths):
    print(path_list[:5])

In [None]:
SIZE_X = 256
SIZE_Y = 256
n_classes = 5

train_images = []
train_masks = []

# Note: cv2.resize takes dsize as (width, height), so arrays come out (SIZE_X, SIZE_Y).
for imgpath in tqdm.tqdm(images_paths):
    img = cv2.imread(imgpath)
    if img is None:  # cv2.imread signals failure by returning None, not by raising
        raise IOError(f"Could not read image: {imgpath}")
    img = cv2.resize(img, (SIZE_Y, SIZE_X))
    img = preprocessing(img)
    train_images.append(img)

for maskpath in tqdm.tqdm(masks_paths):
    mask0 = cv2.imread(maskpath, cv2.IMREAD_GRAYSCALE)
    if mask0 is None:
        raise IOError(f"Could not read mask: {maskpath}")
    # Nearest-neighbour keeps label values intact (no blended class ids).
    mask1 = cv2.resize(mask0, (SIZE_Y, SIZE_X), interpolation=cv2.INTER_NEAREST)
    train_masks.append(mask1)

train_images = np.array(train_images)
train_masks = np.array(train_masks)

# SparseCategoricalCrossentropy needs every label in 0..n_classes-1.
unexpected_labels = np.setdiff1d(np.unique(train_masks), np.arange(n_classes))
assert unexpected_labels.size == 0, f"Mask values outside 0..{n_classes - 1}: {unexpected_labels}"

X_train, X_val, y_train, y_val = train_test_split(train_images, train_masks, test_size = 0.15, shuffle=True, random_state = 1)
print("Class values: ", np.unique(y_train))

In [None]:
# Preview: the top row shows preprocessed images 1-3, the bottom row the matching
# masks drawn on a fixed 0-4 colour scale so class colours are comparable.
NORM = mpl.colors.Normalize(vmin=0, vmax=4)

plt.figure(figsize=(16,10))
for col, sample_idx in enumerate(range(1, 4)):
    plt.subplot(2, 3, col + 1)
    plt.imshow(train_images[sample_idx])
    plt.colorbar()
    plt.axis('off')

for col, sample_idx in enumerate(range(1, 4)):
    plt.subplot(2, 3, col + 4)
    plt.imshow(np.squeeze(train_masks[sample_idx]), cmap='jet', norm=NORM)
    plt.colorbar()
    plt.axis('off')
plt.show()

In [None]:
def unet_model(output_channels, input_shape=None):
    """Build a U-Net with a frozen ImageNet MobileNetV2 encoder and a pix2pix decoder.

    Args:
        output_channels: number of classes; the final layer emits this many
            channels per pixel with no activation (raw logits, matching the
            ``from_logits=True`` loss used when the model is compiled).
        input_shape: optional ``(height, width, channels)``. Defaults to the
            per-sample shape of the global ``X_train``.

    Returns:
        A ``tf.keras.Model`` mapping images to per-pixel logits at the input
        resolution.
    """
    if input_shape is None:
        input_shape = X_train.shape[1:4]
    IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = input_shape

    base_model = MobileNetV2(input_shape=[IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS], include_top=False, weights = 'imagenet')

    # Encoder activations used as skip connections, from 1/2 down to 1/32 resolution.
    layer_names = [
        'block_1_expand_relu',
        'block_3_expand_relu',
        'block_6_expand_relu',
        'block_13_expand_relu',
        'block_16_project',
    ]

    base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

    down_stack = Model(inputs=base_model.input, outputs=base_model_outputs)

    # Keep the pretrained encoder fixed; only the decoder is trained.
    down_stack.trainable = False

    up_stack = [
        pix2pix.upsample(512, 3),
        pix2pix.upsample(256, 3),
        pix2pix.upsample(128, 3),
        pix2pix.upsample(64, 3),
    ]

    inputs = Input(shape=[IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])

    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = Concatenate()([x, skip])

    # Final 2x upsample back to the input resolution. Uses the `output_channels`
    # argument (previously the global OUTPUT_CHANNELS was used and the argument ignored).
    last = Conv2DTranspose(output_channels, 5, strides=2, padding='same')

    x = last(x)

    return Model(inputs=inputs, outputs=x)

In [None]:
# Render the network architecture from a freshly built model (the trained
# `model` is only created in a later cell).
keras.utils.plot_model(unet_model(OUTPUT_CHANNELS), show_shapes=True)

In [None]:
def create_mask(pred_mask):
    """Turn a batch of per-pixel class scores into the label map of its first element.

    Args:
        pred_mask: batched scores with the class axis last — presumably the
            (batch, H, W, classes) output of ``model.predict``.

    Returns:
        Tensor of shape (H, W, 1) with the arg-max class index of each pixel
        for the first sample in the batch.
    """
    first_labels = tf.argmax(pred_mask, axis=-1)[0]
    return tf.expand_dims(first_labels, axis=-1)


def show_predictions(epoch, dataset=None, num=50):
    """Plot input / ground-truth / prediction triplets and save them under ``results/``.

    Args:
        epoch: label shown in the figure title and used in the saved file name
            (only in the ``dataset is None`` branch).
        dataset: optional ``tf.data.Dataset`` yielding ``(image_batch, mask_batch)``.
            When given, the first sample of each of up to ``num`` batches is
            plotted and saved as ``results/mask_<k>.png``.
        num: with ``dataset``, the number of batches to plot; otherwise the index
            of the first training sample — samples ``num``, ``num+16`` and
            ``num+14`` are shown, so ``train_images`` needs more than ``num+16`` entries.

    Relies on the globals ``model``, ``train_images`` and ``train_masks``.
    """
    # plt.savefig does not create missing directories.
    os.makedirs("results", exist_ok=True)

    if dataset is not None:
        for k, (image, mask) in enumerate(dataset.take(num)):
            pred_mask = model.predict(image)
            plt.figure(figsize=(15, 10))
            plt.subplot(231)
            plt.title('Testing Image')
            plt.imshow(image[0], cmap='gray')
            plt.subplot(232)
            plt.title('Ground Truth')
            plt.imshow(mask[0], cmap='jet')
            plt.subplot(233)
            plt.title('Prediction on test image')
            plt.imshow(create_mask(pred_mask)[:, :, 0], cmap='jet')
            # File index is the batch position (an undefined `ii` was used before).
            plt.savefig(f"results/mask_{str(k)}.png")
            plt.show()
    else:
        fig = plt.figure(figsize=(12, 12))
        fig.suptitle(f"\n Epoch: {str(epoch)}\n", fontsize=16)

        # One row per sample; only the first row carries the column titles.
        for row, offset in enumerate((0, 16, 14)):
            idx = num + offset
            first_panel = 331 + 3 * row
            plt.subplot(first_panel)
            if row == 0:
                plt.title('Testing Image')
            plt.imshow(train_images[idx], cmap='gray')
            plt.subplot(first_panel + 1)
            if row == 0:
                plt.title('Ground Truth')
            plt.imshow(train_masks[idx], cmap='jet')
            plt.subplot(first_panel + 2)
            if row == 0:
                plt.title('Prediction on test image')
            prediction = create_mask(model.predict(train_images[idx][tf.newaxis, ...]))
            plt.imshow(prediction[:, :, 0], cmap='jet')

        plt.savefig(f"results/mask_{str(num+100)}_{str(epoch)}.png")
        plt.show()

In [None]:
model = unet_model(OUTPUT_CHANNELS)
# Track pixel accuracy only (later cells read history['accuracy'] and
# history['val_accuracy']). tf.keras.metrics.MeanIoU by default expects integer
# class predictions, not the per-pixel logits this model outputs, so it is not
# passed here; the former dangling `metrics=[MeanIoU(...)]` line was a syntax error.
model.compile(optimizer='adam',
              loss=SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
EPOCHS = 50
VAL_SUBSPLITS = 5
BATCH_SIZE = 32
# Clamp to at least one step: on a small dataset these integer divisions can
# reach 0, which would ask Keras to run zero batches per epoch.
VALIDATION_STEPS = max(1, len(X_val)//BATCH_SIZE//VAL_SUBSPLITS)
STEPS_PER_EPOCH = max(1, len(X_train)//BATCH_SIZE)
sample_image = train_images[0]
sample_mask = train_masks[0]

class DisplayCallback(tf.keras.callbacks.Callback):
    """Plot and save sample predictions at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch; use the 1-based number everywhere so the
        # figure title, saved file name and printed message agree.
        show_predictions(epoch + 1)
        print ('\nSample Prediction after epoch {}\n'.format(epoch+1))

model_history = model.fit(X_train, y_train, epochs=EPOCHS,
                          batch_size = BATCH_SIZE,
                          verbose=1,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=(X_val, y_val),
                          callbacks=[DisplayCallback()]
                          )

In [None]:
# Save the trained model so later cells (or a new session) can reload it with load_model.
# NOTE(review): Keras 3 requires a ".keras"/".h5" suffix for model.save — confirm the TF version in use.
model.save('./models/vertical_misfit5')

In [None]:
# Reload the model saved in the previous cell, so evaluation can run without retraining.
model = load_model('./models/vertical_misfit5')

In [None]:
test_path = "./data/test_images"

# Sorted because os.listdir returns entries in arbitrary order; index-based
# selection of a test image (timgnum) must be reproducible.
test_paths = sorted(os.path.join(test_path, name) for name in os.listdir(test_path))

In [None]:
timgnum = 0
test_img_path = test_paths[timgnum]
# File names are assumed to end in "_<number>" (e.g. "img_12.png"); Path.stem
# drops the directory and extension portably (no hard-coded "/" separator).
img_num = int(Path(test_img_path).stem.split("_")[-1])

plt.figure(figsize=(16,10))

plt.subplot(2,3,1)
plt.title('Test image')
img = cv2.imread(test_img_path)
if img is None:  # cv2.imread signals failure by returning None
    raise IOError(f"Could not read image: {test_img_path}")
img = cv2.resize(img, (SIZE_Y, SIZE_X))
img = preprocessing(img)
plt.imshow(img)

plt.subplot(2,3,2)
plt.title('Prediction')
pred = np.array(create_mask(model.predict(img[tf.newaxis, ...])))
plt.imshow(np.squeeze(pred))

plt.subplot(2,3,3)
plt.title('Ground truth')
# NOTE(review): assumes the test image numbered N corresponds to train_masks[N - 1].
# train_masks follows the lexicographic order of masks_paths, so this holds only
# if mask file numbers are zero-padded/contiguous — confirm.
plt.imshow(train_masks[img_num-1])
plt.show()

In [None]:
# Training vs. validation pixel accuracy over (at most) the first 150 epochs.
history_1 = model_history.history
acc, val_acc = history_1['accuracy'], history_1['val_accuracy']

for series, line_style, label in ((acc, '-', 'Training'), (val_acc, '--', 'Validation')):
    plt.plot(series[:150], line_style, label=label)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.ylim([0.7, 1.0])
plt.legend()
plt.show()

In [None]:
# IoU = Intersection over Union. Mean IoU metric over all segmentation classes
# (not used by the cells visible below).
m = tf.keras.metrics.MeanIoU(num_classes=OUTPUT_CHANNELS)

In [None]:
def dice_coef1(y_true, y_pred, smooth=1e-6):
    """Dice coefficient between two masks with the same number of elements.

    Computes ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``
    on the flattened inputs. For binary masks, a perfect match scores 1.0, and
    two empty masks also score 1.0 thanks to ``smooth`` (instead of 0/0).

    Args:
        y_true, y_pred: arrays/tensors of a common numeric dtype.
        smooth: small constant added to numerator and denominator.

    Returns:
        Scalar ``tf.Tensor`` (callers read it with ``.numpy()``).
    """
    # Plain tf ops: the Keras backend alias `K` used previously is never imported
    # in this notebook (K.flatten is equivalent to reshape(x, [-1])).
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

In [None]:
n_classes = 5
# Per-class Dice on held-out data. `test_images` / `test_masks` are never defined
# in this notebook (only image paths without masks exist for the test folder), so
# the validation split from train_test_split — which has ground-truth masks — is used.
eval_images, eval_masks = X_val, y_val

dfs = {i: [] for i in range(n_classes)}

# One batched predict call instead of one call per image; arg-max over the class axis
# gives the same label map as create_mask.
pred_labels = np.argmax(model.predict(eval_images, batch_size=BATCH_SIZE), axis=-1)

for img_mask, predicted_img in zip(eval_masks, pred_labels):
    for i in range(n_classes):
        # One-hot channel i of ground truth and prediction.
        mask_i = (img_mask == i).astype('float64')
        pred_i = (predicted_img == i).astype('float64')
        dfs[i].append(dice_coef1(mask_i, pred_i).numpy())

dfss = []
for i in range(n_classes):
    # Guard against an empty evaluation set (would otherwise divide by zero).
    avg = sum(dfs[i]) / len(dfs[i]) if dfs[i] else float('nan')
    print(f"Dice score of {str(i)}: {str(avg)}")
    dfss.append(avg)