In [None]:
# Imports and global TensorFlow / segmentation_models configuration
import os

os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
FLAGS = [""] # load_model, load_weight, verbose
from tensorflow import keras
import numpy as np
import tensorflow as tf
#physical_devices = tf.config.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(physical_devices[0], True)
import segmentation_models as sm
import albumentations as A
sm.set_framework("tf.keras")
# keras.mixed_precision.set_global_policy('mixed_float16')  # normally this would give the model extra speed,
# but GTX 1660 Ti GPUs appear to have tensor cores even though they do not, so it slows the model down; enable it only on higher-end GPUs
from matplotlib import pyplot as plt
AUTOTUNE = tf.data.AUTOTUNE
from keras_unet_collection import models
from tensorflow.keras.models import load_model
from keras_unet_collection import losses
import cv2
from tensorflow.keras import backend as K

In [None]:
# Dataset constants: where test images are read from and where predicted masks are written.
test_dir = './data/dataset1/'
save_dir = './result/'
IMG_EXT = 'png'
BATCH_SIZE = 1
# Learning rate handed to the Adam optimizer when the model is rebuilt from weights.
LR = 0.0001
# Model constants
BACKBONE = 'efficientnetb3'
# Label values per the original note: 0 = unlabelled, 1 = ischemic ("iskemik"), 2 = hemorrhagic ("kanama").
# NOTE(review): how these label values map onto the model's output channels is not visible here - confirm.
CLASSES = ['iskemik', 'kanama']
MODEL_LOAD_TYPE = 'model' # 'model' (load a full saved model) or 'weights' (rebuild the network, then load weights)
#MODEL_LOAD_TYPE = 'weights' # model or weights
MODEL_PRE_TYPE = 'sm' # 'sm' (backbone-specific preprocessing) or 'custom' (plain division by 255)
#MODEL_PRE_TYPE = 'custom' # sm or custom


# Alternative checkpoints; uncomment exactly one MODEL_WEIGHT_PATH.
#MODEL_WEIGHT_PATH = "./models/14_09-07_21/best.h5"
#MODEL_WEIGHT_PATH = "./models/10_09-01_27/best_01_27_10_09.h5"
MODEL_WEIGHT_PATH = "./models/15_09-07_18/best.h5"

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot images in one row.

    Args:
        **images: panel name -> image array. Underscores in the name are
            replaced by spaces and the result is title-cased to form the
            panel title.

    Note:
        The arrays passed in are never modified; the pixel tweak described
        below is applied to a private copy that is only used for drawing.
    """
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        # If a whole binary image is true, plt shows it as if the whole image
        # were false. To bypass this, one pixel is set to 0 - on a copy, so the
        # caller's data is left untouched.
        shown = np.array(image)
        # Guard the index so tiny (or non-2D) inputs do not raise here.
        if shown.ndim >= 2 and shown.shape[0] > 1 and shown.shape[1] > 1:
            shown[1, 1] = 0
        plt.imshow(shown)
    plt.show()

def visualize_dataset(img, mask, classes):
    """Show an image next to one panel per class channel of its mask.

    Args:
        img: image passed through to ``visualize`` under the name ``image``.
        mask: array whose last axis is indexed by class position.
        classes: sequence of class names; the k-th name labels ``mask[..., k]``.
    """
    panels = {'image': img}
    for channel, class_name in enumerate(classes):
        panels[class_name] = mask[..., channel].squeeze()
    visualize(**panels)

# helper function for data visualization
def denormalize(x):
    """Scale image to range 0..1 for correct plot.

    The 2nd and 98th percentiles of ``x`` are mapped to 0 and 1 respectively
    and everything outside that window is clipped.

    Args:
        x: array-like of numeric values.

    Returns:
        Array of the same shape as ``x`` with values in [0, 1]. A constant
        input (both percentiles equal) yields all zeros rather than NaN.
    """
    x = np.asarray(x)
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    span = x_max - x_min
    if span == 0:
        # Flat image: avoid 0/0. Any non-zero divisor gives the same all-zero result.
        span = 1.0
    x = (x - x_min) / span
    x = x.clip(0, 1)
    return x

In [None]:
def get_preprocessing(preprocessing_fn):
    """Wrap a normalization function in an albumentations pipeline.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network).

    Returns:
        albumentations.Compose holding a single ``A.Lambda`` step that
        applies ``preprocessing_fn`` to the image.
    """
    image_step = A.Lambda(image=preprocessing_fn)
    return A.Compose([image_step])

def preprocessing_fn_sm(image):
    """Apply the segmentation_models preprocessing for ``BACKBONE`` to ``image``.

    Returns:
        The transformed image cast to float32.
    """
    backbone_normalize = sm.get_preprocessing(BACKBONE)
    transformed = get_preprocessing(backbone_normalize)(image=image)
    return transformed["image"].astype("float32")

def preprocessing_fn_custom(image):
    """Divide ``image`` by 255 and return the result as float32."""
    scaled = image / 255.
    return scaled.astype("float32")


In [None]:
# Build or load the segmentation model selected by MODEL_LOAD_TYPE.
# One output channel for a single class; otherwise one channel per class plus one extra.
n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # case for binary and multiclass segmentation

activation = 'sigmoid' if n_classes == 1 else 'softmax'

# NOTE(review): class_weights is hard-coded to three entries, so it only lines up with
# n_classes when CLASSES holds exactly two names - confirm before changing CLASSES.
dice_loss = sm.losses.DiceLoss(class_weights=np.array([2, 2, 0.5])) 
#multi_focal_tversky = losses.multiclass_focal_tversky(alpha=0.7, gamma=4/3)
focal_loss = sm.losses.CategoricalFocalLoss()
#total_loss = dice_loss + (1 * focal_tversky)
total_loss = dice_loss + (1 * focal_loss)

if MODEL_LOAD_TYPE == 'model':
    # The custom_objects keys are presumably the names the loss/metrics were
    # saved under in the checkpoint - verify against the training notebook.
    model = load_model(MODEL_WEIGHT_PATH, custom_objects={'dice_loss_plus_1focal_loss': total_loss,'iou_score': sm.metrics.IOUScore(threshold=0.5), 'f1-score': sm.metrics.FScore(threshold=0.5)})
elif MODEL_LOAD_TYPE == 'weights':
    # create model
    model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)

    #model = models.swin_unet_2d((512, 512, 3), filter_num_begin=64, n_labels=3, depth=3, stack_num_down=2, stack_num_up=2, 
    #                            patch_size=(2, 2), num_heads=[4, 8, 8], window_size=[4, 2, 2, 2], num_mlp=512, 
    #                            output_activation='Softmax', shift_window=True, name='swin_unet')

    # define optimizer
    optim = keras.optimizers.Adam(LR)

    multi_focal_tversky = losses.multiclass_focal_tversky(alpha=0.7, gamma=4/3)
    metrics = [sm.metrics.IOUScore(), sm.metrics.FScore()]
    model.compile(optim, multi_focal_tversky, metrics)
    model.load_weights(MODEL_WEIGHT_PATH)
else:
    # Fail here instead of leaving `model` undefined and hitting a NameError
    # in a later cell.
    raise ValueError(
        f"MODEL_LOAD_TYPE must be 'model' or 'weights', got {MODEL_LOAD_TYPE!r}"
    )

In [None]:
# Run inference on every test image and write the predicted label mask into save_dir.
predict_filenames = tf.io.gfile.glob(os.path.join(test_dir, f"*.{IMG_EXT}"))

# cv2.imwrite does not create missing directories; it only reports failure.
os.makedirs(save_dir, exist_ok=True)

for image_path in predict_filenames:
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None rather than raising.
        print(f"Skipping unreadable image: {image_path}")
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if MODEL_PRE_TYPE == "sm":
        img = preprocessing_fn_sm(img)
    elif MODEL_PRE_TYPE == "custom":
        img = preprocessing_fn_custom(img)
    # The model expects a leading batch axis: (H, W, C) -> (1, H, W, C).
    # Take element 0 of the result to drop that axis again.
    pred_probs = model.predict(img[np.newaxis, ...])[0]
    if "verbose" in FLAGS:
        pass # todo verbose
    if pred_probs.shape[-1] == 1:
        # Single-channel (sigmoid) output: argmax would always be 0, so threshold instead.
        pred_mask = pred_probs[..., 0] > 0.5
    else:
        pred_mask = np.argmax(pred_probs, axis=-1)
    # cv2.imwrite needs a NumPy array with an image dtype; label values fit in uint8.
    pred_mask = pred_mask.astype(np.uint8)
    # Build the output path from the file name alone. A textual replace of test_dir
    # silently becomes a no-op (overwriting the input) if glob returns a path that
    # does not contain test_dir verbatim.
    out_path = os.path.join(save_dir, os.path.basename(image_path))
    if not cv2.imwrite(out_path, pred_mask):
        print(f"Failed to write mask: {out_path}")