In [None]:
import os
# Expose only the first GPU (device index 0) to the frameworks imported below.
# SM_FRAMEWORK selects the tf.keras backend for segmentation_models; it must be
# set before that package is imported further down in the notebook.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'SM_FRAMEWORK': 'tf.keras',
})

In [None]:

# Install required libs

## Segmentation model training
# NOTE: `!export SM_FRAMEWORK=...` only ran in a throw-away subshell and never
# reached this kernel; SM_FRAMEWORK is set through os.environ instead.

### please update Albumentations to version>=0.3.0 for `Lambda` transform support
# The requirement is quoted so the shell does not parse `>=` as an output redirect
# (unquoted, pip installed the latest version and wrote its log to a file "=0.3.0").
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -U "albumentations>=0.3.0" --user
%pip install -U --pre segmentation-models --user

In [None]:
import cv2
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Root of the Kaggle dataset; images and their masks live in sibling folders.
PANDA_ROOT = '/kaggle/input/panda-image-and-cmapped-mask-data/'
DATA_DIR = PANDA_ROOT + 'train_images/'
MASK_DIR = PANDA_ROOT + 'train_label_masks/'

In [None]:
# Data loader and utility functions
# helper function for data visualization
def visualize(**images):
    """Plot images side by side in a single row.

    Each keyword argument becomes one panel. The keyword name, with underscores
    replaced by spaces and title-cased, is used as the panel title.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()
    
# helper function for data visualization    
def denormalize(x):
    """Scale an image to the range 0..1 for correct plotting.

    The 2nd percentile is mapped to 0 and the 98th percentile to 1. Values
    outside that band are clipped, which makes normalized network inputs
    viewable with ``plt.imshow``.

    Args:
        x (np.ndarray): image array of any shape and numeric dtype.

    Returns:
        np.ndarray: float array of the same shape with values in [0, 1].
        When both percentiles coincide (e.g. a constant image), values above
        that level map to 1 and the rest to 0. This avoids the NaNs that a
        division by zero would produce.
    """
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    value_range = x_max - x_min
    if value_range == 0:
        # Limit of the scaled-and-clipped result, minus the 0/0 NaNs.
        return (x > x_min).astype(float)
    x = (x - x_min) / value_range
    return x.clip(0, 1)

In [None]:
# classes for data loading and preprocessing
class Dataset:
    """Panda Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; each mask is looked
            up under the same file name as its image
        classes (list of str, optional): names (from ``CLASSES``) of the classes
            to extract from the segmentation mask, in output-channel order.
            ``None`` selects every entry of ``CLASSES``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    Each item is an ``(image, mask)`` pair. ``mask`` has one float channel per
    requested class. An extra trailing "everything else" channel is appended
    only for multi-class requests that do not already include 'background'.
    """

    # Radboud images: Prostate glands are individually labelled, valid values are:
    # 0 : background (non-tissue or unknown)
    # 1 : stroma (connective tissues, non-epithelium tissue)
    # 2 : healthy (benign) epithelium
    # 3 : cancerous epithelium (Gleason 3)
    # 4 : cancerous epithelium (Gleason 4)
    # 5 : cancerous epithelium (Gleason 5)
    CLASSES = ['background', 'stroma', 'healthy', 'gleason3', 'gleason4', 'gleason5']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # Sort so sample indices are reproducible (os.listdir order is arbitrary).
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # Default to every known class instead of crashing on `None`.
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _encode_mask(self, mask):
        """One-hot encode a 2-D label mask into one float channel per selected class."""
        # extract certain classes from mask (e.g. background)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # For multi-class masks, add a catch-all background channel, but only
        # when background is not already one of the requested classes.
        # Otherwise the mask gets one more channel than the model has outputs.
        background_value = self.CLASSES.index('background')
        if mask.shape[-1] != 1 and background_value not in self.class_values:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)
        return mask

    def __getitem__(self, i):
        image_fp = self.images_fps[i]
        mask_fp = self.masks_fps[i]

        # read data; cv2.imread returns None instead of raising on failure
        image = cv2.imread(image_fp)
        if image is None:
            raise FileNotFoundError('Could not read image: {}'.format(image_fp))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # NOTE(review): the mask is read as single-channel grayscale and its pixel
        # values are compared with the class indices 0-5. Confirm the mask files
        # store raw class indices rather than colour-mapped values.
        mask = cv2.imread(mask_fp, 0)
        if mask is None:
            raise FileNotFoundError('Could not read mask: {}'.format(mask_fp))
        mask = self._encode_mask(mask)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)
       
class Dataloder(keras.utils.Sequence):
    """Load data from dataset and form batches.

    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of images in batch.
        shuffle: Boolean, if `True` shuffle image indexes each epoch.

    Only full batches are produced; a trailing partial batch is dropped.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False):
        # Let the keras Sequence base class initialise its own state.
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):

        # collect batch data
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = []
        for j in range(start, stop):
            # Go through self.indexes so the permutation from on_epoch_end
            # actually changes which samples form each batch.
            data.append(self.dataset[self.indexes[j]])

        # transpose list of lists
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]

        # newer version of tf/keras want batch to be in tuple rather than list
        # Ref. https://github.com/qubvel/segmentation_models/issues/412
        return tuple(batch)

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.indexes) // self.batch_size

    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)

In [None]:
# DATA_DIR and MASK_DIR already point at the image and mask folders, so use
# them directly. Joining 'train_images' / 'train_label_masks' onto them again
# produced non-existent paths such as .../train_images/train_images.
x_train_dir = DATA_DIR
y_train_dir = MASK_DIR

# FIXME: 'val'/'valannot'/'test'/'testannot' appear to be folder names carried
# over from the CamVid example this notebook was adapted from. This dataset is
# only referenced through train_images/ and train_label_masks/. Point these at
# real held-out folders (or build a split of the training files) before
# running the validation/test cells.
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'valannot')

x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')

In [None]:
# Let's look at the data we have
dataset = Dataset(x_train_dir,
                  y_train_dir,
                  classes=['background', 'stroma', 'healthy', 'gleason3', 'gleason4', 'gleason5'])

image, mask = dataset[5]  # get some sample

# Mask channels follow the order of `classes` above, so channel 0 is background
# and the tissue classes start at channel 1.
visualize(
    image=image,
    stroma_mask=mask[..., 1].squeeze(),
    healthy_mask=mask[..., 2].squeeze(),
    gleason3_mask=mask[..., 3].squeeze(),
    gleason4_mask=mask[..., 4].squeeze(),
    gleason5_mask=mask[..., 5].squeeze(),
)

## Augmentations


Data augmentation is a powerful technique to increase the amount of your data and prevent model overfitting.
See these articles:

The Effectiveness of Data Augmentation in Image Classification using Deep Learning
Data Augmentation | How to use Deep Learning when you have Limited Data
Data Augmentation Experimentation

All these transforms can be easily applied with Albumentations, a fast augmentation library. For a detailed explanation of image transformations, see the Kaggle salt segmentation example provided by the Albumentations authors.

In [None]:
import albumentations as A

In [None]:
# Network input size (pixels). segmentation_models' U-Net expects spatial
# dimensions divisible by 32.
IMAGE_WIDTH, IMAGE_HEIGHT = 320, 320

def round_clip_0_1(x, **kwargs):
    """Round values to the nearest integer, then clamp them into [0, 1].

    Extra keyword arguments are accepted and ignored, presumably so the function
    can be used as an albumentations ``Lambda`` target.
    """
    rounded = x.round()
    return rounded.clip(0, 1)

# define training-time transforms (currently a resize only)
def get_training_augmentation():
    """Build the training-time transform pipeline.

    For now this only resizes image and mask to IMAGE_HEIGHT x IMAGE_WIDTH; no
    random augmentation is applied yet. Flip, crop or colour transforms can be
    appended to ``train_transform``.

    Returns:
        albumentations.Compose
    """
    train_transform = [
        # p=1.0 always applies the resize. Recent albumentations releases
        # deprecate the older `always_apply=True` spelling.
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1.0),
    ]
    return A.Compose(train_transform)


def get_validation_augmentation():
    """Resize image and mask to IMAGE_HEIGHT x IMAGE_WIDTH for validation/testing.

    No random transforms are applied, so evaluation inputs are deterministic.

    Returns:
        albumentations.Compose
    """
    test_transform = [
        # p=1.0 always applies the resize (replaces the deprecated always_apply=True).
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1.0),
    ]
    return A.Compose(test_transform)

def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform.

    Image and mask are first resized to IMAGE_HEIGHT x IMAGE_WIDTH. Then
    ``preprocessing_fn`` normalizes the (already resized) image. Resizing first
    avoids running the normalization over the full-size input images.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose

    """
    _transform = [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1.0),
        # Normalization applies to the image only; the one-hot mask is untouched.
        A.Lambda(image=preprocessing_fn),
    ]
    return A.Compose(_transform)

## Segmentation

In [None]:
import segmentation_models as sm

In [None]:
# Define some constants
BACKBONE = 'efficientnetb7'  # encoder architecture for the U-Net
# Output channel order of the model; must match the order used for the masks.
CLASSES = ['background', 'stroma', 'healthy', 'gleason3', 'gleason4', 'gleason5']
BATCH_SIZE = 8
LR = 1e-4  # learning rate for the Adam optimizer
EPOCHS = 15  # number of training epochs

# Backbone-specific input normalization function.
preprocess_input = sm.get_preprocessing(BACKBONE)

In [None]:
# Multiclass segmentation: one softmax output channel per entry of CLASSES.
activation = 'softmax'
n_classes = len(CLASSES)

# create the U-Net on top of the selected encoder backbone
model = sm.Unet(
    BACKBONE,
    classes=n_classes,
    activation=activation,
)

In [None]:
# define optimizer
optim = keras.optimizers.Adam(LR)

# The model ends in a per-pixel softmax over n_classes channels, so use the
# categorical focal + dice loss. The binary variant is meant for sigmoid outputs.
if activation == 'softmax':
    total_loss = sm.losses.categorical_focal_dice_loss
else:
    total_loss = sm.losses.binary_focal_dice_loss

# Track IOU Score and F1 score during training.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

# compile keras model with defined optimizer, loss and metrics
model.compile(optim, total_loss, metrics)

In [None]:
# Dataset for train images
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=CLASSES,
    augmentation=None,  # pass get_training_augmentation() to enable augmentation
    preprocessing=get_preprocessing(preprocess_input),
)

# Dataset for validation images
valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    classes=CLASSES,
    augmentation=None,  # pass get_validation_augmentation() to enable it
    preprocessing=get_preprocessing(preprocess_input),
)

train_dataloader = Dataloder(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = Dataloder(valid_dataset, batch_size=1, shuffle=False)

# check shapes for errors; numpy images are (batch, height, width, channels).
# Load the first batch once instead of once per assertion.
sample_images, sample_masks = train_dataloader[0]
assert sample_images.shape == (BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3), sample_images.shape
assert sample_masks.shape == (BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, n_classes), sample_masks.shape

# define callbacks for learning rate scheduling and save model at best checkpoint.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        './best_model.h5',
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        mode='min',
    ),
    keras.callbacks.ReduceLROnPlateau(),
]

In [None]:
# train model; each epoch walks every full batch of both loaders once
history = model.fit(
    train_dataloader,
    validation_data=valid_dataloader,
    epochs=EPOCHS,
    steps_per_epoch=len(train_dataloader),
    validation_steps=len(valid_dataloader),
    callbacks=callbacks,
)

In [None]:
# Plot training & validation curves side by side. The "val_" series are
# computed on `validation_data` (valid_dataloader), so they are labelled
# "Validation" rather than "Test".
fig, (ax_iou, ax_loss) = plt.subplots(1, 2, figsize=(30, 5))

ax_iou.plot(history.history['iou_score'])
ax_iou.plot(history.history['val_iou_score'])
ax_iou.set(title='Model iou_score', ylabel='iou_score', xlabel='Epoch')
ax_iou.legend(['Train', 'Validation'], loc='upper left')

ax_loss.plot(history.history['loss'])
ax_loss.plot(history.history['val_loss'])
ax_loss.set(title='Model loss', ylabel='Loss', xlabel='Epoch')
ax_loss.legend(['Train', 'Validation'], loc='upper left')

plt.show()

## Model Evaluation



In [None]:
# Test split: deterministic resize plus the same backbone preprocessing as training.
test_dataset = Dataset(
    images_dir=x_test_dir,
    masks_dir=y_test_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
)

# One image per batch, no shuffling, so evaluation order is stable.
test_dataloader = Dataloder(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Restore the weights written by the ModelCheckpoint callback (best epoch
# according to its monitored value) before evaluating.
model.load_weights('./best_model.h5')

In [None]:
scores = model.evaluate(test_dataloader)

# evaluate() returns the loss first, followed by one value per compiled metric.
loss_value, *metric_values = scores
print("Loss: {:.5}".format(loss_value))
for metric, value in zip(metrics, metric_values):
    print("mean {}: {:.5}".format(metric.__name__, value))

## Visualization of results on test dataset

In [None]:
n = 5
# Draw distinct test images (no replacement), capped at the dataset size.
ids = np.random.choice(len(test_dataset), size=min(n, len(test_dataset)), replace=False)

for i in ids:

    image, gt_mask = test_dataset[i]
    image = np.expand_dims(image, axis=0)
    pr_mask = model.predict(image)

    # Masks carry one channel per class, and imshow cannot draw arrays with
    # more than 4 channels. Collapse both masks to per-pixel class-index maps.
    # NOTE: colours are auto-scaled per panel, so the same colour may denote
    # different classes when a class is absent from one of the maps.
    visualize(
        image=denormalize(image.squeeze()),
        gt_mask=np.argmax(gt_mask, axis=-1),
        pr_mask=np.argmax(pr_mask.squeeze(axis=0), axis=-1),
    )