In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img

from keras import backend as K
import cv2
import numpy as np
import pandas as pd

In [2]:
def dice_coef_9cat(y_true, y_pred, smooth=1e-7, num_classes=3):
    '''
    Soft Dice coefficient over the non-background classes.

    Despite the "9cat" in the name, the number of categories is given by
    `num_classes` (3 by default). Background pixel label 0 is ignored.
    Pass to model as metric during compile statement.

    Parameters
    ----------
    y_true : tensor of class labels; cast to int32 and one-hot encoded here.
    y_pred : tensor of per-class scores with the class index on the last axis
        (presumably `num_classes` wide so it lines up with the one-hot
        labels -- confirm against the model output).
    smooth : constant added to the denominator so that an all-background
        pair does not divide by zero.
    num_classes : depth of the one-hot encoding applied to `y_true`.

    Returns
    -------
    Scalar tensor: 2 * sum(true * pred) / (sum(true) + sum(pred) + smooth).
    '''
    # One-hot encode the labels and drop channel 0 (background) from both sides.
    y_true_f = K.flatten(
        K.one_hot(K.cast(y_true, 'int32'), num_classes=num_classes)[..., 1:]
    )
    y_pred_f = K.flatten(y_pred[..., 1:])
    # Both tensors are 1-D after flattening, so these sums run over every
    # foreground entry of the whole batch.
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_coef_9cat_loss(y_true, y_pred):
    '''
    Dice loss: one minus the Dice coefficient of `dice_coef_9cat`.
    Pass to model as loss during compile statement.
    '''
    dice = dice_coef_9cat(y_true, y_pred)
    return 1 - dice

In [3]:
loss = dice_coef_9cat_loss
# Raw string: the Windows path contains backslashes that are not valid escape
# sequences (\L, \A, \w); a raw literal yields the same path without relying on
# Python leaving unknown escapes untouched.
segmentation_model = keras.models.load_model(
    r'C:\Luna_CS\Aravind\working_model.h5',
    # Map the custom loss's name to the function so Keras can resolve it
    # while deserializing the model.
    custom_objects={loss.__name__: loss},
)

In [4]:
def remove_nerves(image):
    '''
    Inpaint small dark structures out of an image.

    A black-hat morphological filter isolates small dark details, the
    result is thresholded into a binary mask, and the masked pixels are
    filled in with Telea inpainting.

    image : array accepted by `array_to_img`.
    Returns the inpainted image as a float64 array divided by 255.
    '''
    # array -> PIL image -> ndarray, then swap the first and third channels.
    swapped = cv2.cvtColor(np.array(array_to_img(image)), cv2.COLOR_BGR2RGB)

    # NOTE(review): if `image` arrives in RGB order, `swapped` is BGR here, so
    # RGB2GRAY applies the red/blue weights the other way round. Left as is;
    # confirm before changing, since downstream models may depend on it.
    gray = cv2.cvtColor(swapped, cv2.COLOR_RGB2GRAY)

    # 17x17 cross-shaped structuring element (cv2.MORPH_CROSS == 1).
    cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (17, 17))

    # Black-hat responds to dark details smaller than the structuring element.
    dark_details = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, cross)

    # Responses above 10 become the inpainting mask.
    _, inpaint_mask = cv2.threshold(dark_details, 10, 255, cv2.THRESH_BINARY)

    # Fill the masked pixels from their surroundings (radius 1), then swap the
    # channels back to the order the input had.
    filled = cv2.inpaint(swapped, inpaint_mask, 1, cv2.INPAINT_TELEA)
    restored = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)

    return restored.astype(np.float64) / 255.0

In [5]:
def get_cropped_image(file, model):
    '''
    Crop an image to a square window centred on the segmented region.

    The image is resized to 512x512, passed through `remove_nerves` and fed
    to `model`. The prediction is resized back to the original resolution,
    and every pixel whose strongest channel is not channel 0 counts as
    foreground. The crop window is centred on the (floor) mean row and
    column of the foreground pixels.

    Parameters
    ----------
    file : path of the image to load.
    model : segmentation model; `model.predict` is called on a
        (1, 512, 512, 3) batch.

    Returns
    -------
    The cropped region of the original image as an array divided by 255.

    Raises
    ------
    ValueError
        If the model marks no foreground pixel, so there is nothing to
        centre the crop on.
    '''
    original = load_img(file)
    # PIL reports size as (width, height).
    width, height = original.size

    resized = load_img(file, target_size=(512, 512))
    net_input = remove_nerves(img_to_array(resized) / 255.0)

    pred = model.predict(net_input.reshape(1, 512, 512, 3))
    # Bring the predicted map back to the original resolution.
    mask_img = array_to_img(pred[0]).resize((width, height))
    mask_arr = np.array(mask_img)

    # Foreground = strongest channel is not channel 0 (background). Done in
    # one vectorised pass instead of a Python loop over every pixel.
    rows, cols = np.nonzero(mask_arr.argmax(axis=-1) != 0)
    if rows.size == 0:
        raise ValueError(
            'segmentation found no foreground pixels in %r; cannot locate '
            'the crop centre' % (file,)
        )

    # Floor of the mean foreground coordinate.
    centre_row = int(rows.sum(dtype=np.int64)) // rows.size
    centre_col = int(cols.sum(dtype=np.int64)) // cols.size

    # Half-side of the crop window, scaled with the pixel count of the image.
    # 6291456 == 3072 * 2048, presumably the resolution the 300-pixel
    # half-side was tuned for -- TODO confirm.
    half_side = int((300.0 / 6291456.0) * float(height * width))

    # PIL crop box is (left, upper, right, lower) in (x, y) = (column, row)
    # coordinates. The window is not clamped to the image bounds.
    cropped = original.crop((
        centre_col - half_side,
        centre_row - half_side,
        centre_col + half_side,
        centre_row + half_side,
    ))

    return np.array(cropped) / 255.0
    

In [6]:
# Raw string keeps every backslash of the Windows path literal, so no character
# needs individual escaping (the path value is unchanged).
classification_model = keras.models.load_model(r'C:\Luna_CS\Aravind\final_model.h5')

The function `get_diagnosis` takes two parameters: the path to the image, and a boolean indicating whether the image still needs to be cropped. If cropping is needed, the segmentation model is used to locate the crop region before classification; otherwise only the classification model is required.

In [7]:
def get_diagnosis(file, full_image=True):
    '''
    Print a diagnosis for the image stored at `file`.

    Set full_image=True if image is not pre-cropped: it is cropped with the
    segmentation model before classification.
    Else set full_image=False: the image is loaded directly at 224x224.

    Prints 'Healthy' when the classifier's first output exceeds 0.5 and
    'Risk of Glaucoma' otherwise. Nothing is returned.
    '''
    if not full_image:
        img = np.array(load_img(file, target_size=(224, 224)))/255.0
    else:
        cropped = get_cropped_image(file, segmentation_model)
        # Go through a PIL image to resize the crop to the classifier input.
        resized = array_to_img(cropped).resize((224, 224))
        img = np.array(resized)/255.0

    batch = img.reshape(1, 224, 224, 3)
    pred = classification_model.predict(batch)

    verdict = 'Healthy' if pred[0] > 0.5 else 'Risk of Glaucoma'
    print(verdict)

The following are two test images. The first is already cropped and the second is not. Cropping takes additional time, so the diagnosis for the second image is slower.

In [8]:
# Pre-cropped image: classification only. Raw string keeps the path's
# backslashes literal.
get_diagnosis(r'C:\Luna_CS\Aravind\Database\Images\Im530_g_ACRIMA.jpg', False)

Risk of Glaucoma


In [9]:
# Full image: cropped via the segmentation model first. Raw string keeps the
# path's backslashes literal (no manual escaping before "002").
get_diagnosis(r'C:\Luna_CS\Aravind\ORIGA\Images\002.jpg', True)

Healthy
