In [None]:
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"

import cv2
import glob

import numpy as np
from matplotlib import pyplot as plt
import albumentations as A
from tifffile import imread
import segmentation_models as sm
from random import randint

import tensorflow as tf
from keras.utils import to_categorical

In [None]:
# Report how many GPUs TensorFlow can see on this machine.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one panel. The keyword name, with
    underscores turned into spaces and title-cased, is used as the panel title.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    position = 0
    for label, picture in images.items():
        position += 1
        plt.subplot(1, panel_count, position)
        # hide the axis ticks; only the picture matters here
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

# Import dataset

In [None]:
class Dataset():
    """Pairs images with per-class masks and converts the masks to label maps.

    Images and masks are matched by file name: every entry found in
    ``images_dir`` is expected to exist under the same name in ``masks_dir``.
    Masks are indexed by class along their first axis (one plane per entry of
    ``classes``); a value of 255 in plane ``c`` marks a pixel of class ``c``.
    They are collapsed into a single label map where 0 means "no class" and
    ``c + 1`` is the label of ``classes[c]``.

    Args:
        images_dir: folder containing the image files.
        masks_dir: folder containing the mask files (same file names).
        classes: sequence of class names; its order defines the label values.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        TypeError: if ``classes`` is None.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        if classes is None:
            raise TypeError("classes must be a sequence of class names, got None")

        self.classes = classes
        # os.listdir returns entries in arbitrary order; sort them so that the
        # positional train/val/test split in export() is reproducible.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]  # images file paths
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]  # masks file paths

        # convert str names to class values on masks (case-insensitive lookup,
        # so that names such as 'Lepidic' do not fail)
        lowered = [name.lower() for name in classes]
        self.class_values = [lowered.index(name) for name in lowered]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def merge(self, x):
        """Collapse per-class planes into one label map of shape (H, W, 1).

        Planes are visited in class order and a pixel keeps the first non-zero
        value it receives, so on overlaps the earlier class wins. The result is
        a float array (it is built from ``np.zeros``).
        """
        dim = (x.shape[-2], x.shape[-1])
        merged = np.zeros(dim)
        for i in range(len(self.classes)):
            # only fill pixels that are still unlabeled
            merged = np.where(merged == 0, x[i], merged)

        return merged.reshape(dim[0], dim[1], 1)

    def to_index(self, x):
        """Return a copy of ``x`` where 255 in plane ``c`` is replaced by ``c + 1``.

        The input array is left untouched.
        """
        indexed = np.array(x, copy=True)
        for c in range(len(self.classes)):
            indexed[c][x[c] == 255] = c + 1

        return indexed

    def _load_image(self, path):
        """Read one image file as a uint8 array."""
        return np.asarray(imread(path), dtype=np.uint8)

    def _load_mask(self, path):
        """Read one mask file and convert it to an (H, W, 1) label map."""
        raw = np.asarray(imread(path), dtype=np.uint8)
        return self.merge(self.to_index(raw))

    def export(self, test_ratio=0.4):
        """Load the whole dataset into memory and split it by position.

        The first ``test_ratio`` fraction of the samples forms the test split;
        the remainder is divided in half, first half for training and the rest
        for validation. Samples are not shuffled.

        Args:
            test_ratio: fraction of samples for the test split, within [0, 1].

        Returns:
            ``(x_test, y_test, x_train, y_train, x_val, y_val)``. ``y_test``
            holds label maps with a trailing singleton axis, whereas
            ``y_train`` and ``y_val`` are one-hot encoded with
            ``len(classes) + 1`` channels.

        Raises:
            ValueError: if ``test_ratio`` is outside [0, 1].
        """
        if not 0 <= test_ratio <= 1:
            raise ValueError(f"test_ratio must be within [0, 1], got {test_ratio}")

        total = len(self.images_fps)
        test_size = int(total * test_ratio)
        train_size = int((total - test_size) * 0.5)
        train_end = test_size + train_size

        # NOTE(review): stacking assumes all files share one shape - confirm.
        images = np.asarray([self._load_image(uri) for uri in self.images_fps])
        masks = np.asarray([self._load_mask(uri) for uri in self.masks_fps])

        print(images.shape, masks.shape)

        x_test = images[:test_size]
        y_test = masks[:test_size]

        # One-hot encode only what is returned encoded (train + validation);
        # the test labels stay as plain label maps.
        encoded = to_categorical(masks[test_size:], num_classes=len(self.classes) + 1)

        x_train = images[test_size:train_end]
        y_train = encoded[:train_size]

        x_val = images[train_end:]
        y_val = encoded[train_size:]

        return x_test, y_test, x_train, y_train, x_val, y_val

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at position ``i``.

        The mask is the merged label map; ``augmentation`` and then
        ``preprocessing`` are applied to the pair when they are set.
        """
        # read data
        image = self._load_image(self.images_fps[i])
        mask = self._load_mask(self.masks_fps[i])

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of image files found in ``images_dir``."""
        return len(self.images_fps)

In [None]:
# Root folder of the QuPath export. The location is machine-specific, so it can
# be overridden with the NSC_EXPORT_DIR environment variable; the original
# path remains the default.
export_uri = os.environ.get('NSC_EXPORT_DIR', r'D:\NSC2024\QuPath_proj\export')

# The order of `classes` defines the label values: classes[c] gets label c + 1.
dataset = Dataset(
    images_dir = os.path.join(export_uri, 'images'),
    masks_dir = os.path.join(export_uri, 'masks'),
    classes = ['lepidic', 'acinar', 'solid', 'micropapillary', 'papillary'],
)

In [None]:
# Draw one sample at random, inspect its label values and shape, then plot it.
r = randint(0, len(dataset) - 1)
sample = dataset[r]
image, masks = sample
print(np.unique(masks))
print(masks.shape)
visualize(image=image, mask=masks)


In [None]:
# Load the whole dataset into memory and unpack the positional splits.
splits = dataset.export()
x_test, y_test, x_train, y_train, x_val, y_val = splits

In [None]:
# Distinct values present in the targets of each split.
for split_targets in (y_test, y_train, y_val):
    print(np.unique(split_targets))

In [None]:
# Shapes of the training and validation targets returned by dataset.export()
y_train.shape, y_val.shape

# Training

In [None]:
import keras

activation = 'softmax'
LR = 0.0001
opt = keras.optimizers.Adam(LR)

# Segmentation models losses can be combined together by '+' and scaled by integer or float factor.
# Dice class weights follow the channel order of the one-hot masks: channel 0
# is the unlabeled/background value and gets weight 0, while the remaining
# channels (one per entry of dataset.classes) share the weight equally.
n_foreground = len(dataset.classes)
dice_class_weights = np.array([0] + [1 / n_foreground] * n_foreground)
dice_loss = sm.losses.DiceLoss(class_weights=dice_class_weights)
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# A ready-made focal + dice combination can also be taken directly from the
# library; the explicit sum above is kept so the class weights can be set.
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

In [None]:
BACKBONE = 'resnet101'

# U-Net with one output channel per entry of dataset.classes plus one for label 0
n_outputs = len(dataset.classes) + 1
model = sm.Unet(
    BACKBONE,
    classes=n_outputs,
    activation=activation,
    encoder_weights='imagenet',
)

# attach the optimizer, the combined loss and the metrics defined earlier
model.compile(optimizer=opt, loss=total_loss, metrics=metrics)

# model.summary()

In [None]:
# Input preprocessing that matches the chosen backbone
preprocess_input = sm.get_preprocessing(BACKBONE)
pre_x_train = preprocess_input(x_train)
pre_x_val = preprocess_input(x_val)

fit_options = dict(
    validation_data=(pre_x_val, y_val),
    batch_size=8,
    epochs=20,
    verbose=2,
)
history = model.fit(pre_x_train, y_train, **fit_options)

# persist the trained network next to the notebook
model.save('test.keras')

In [None]:
from keras.models import load_model

# Reload the saved network without compiling it (no loss/optimizer attached).
model = load_model(filepath='./test.keras', compile=False)

In [None]:
# Run the network on the preprocessed test images
pre_x_test = preprocess_input(x_test)
y_pred = model.predict(pre_x_test)
# keep, for every pixel, the index of the highest-scoring class (axis 3)
y_pred_argmax = y_pred.argmax(axis=3)

In [None]:
# Shape of the predicted label maps, then the distinct class indices that
# actually occur in the predictions (shown as the cell output).
print(y_pred_argmax.shape)
np.unique(y_pred_argmax)

In [None]:
from keras.metrics import MeanIoU

# One IoU bucket per entry of dataset.classes plus one for label 0
n_labels = len(dataset.classes) + 1
IOU_keras = MeanIoU(num_classes=n_labels)

# drop the trailing channel axis of y_test before comparing with the prediction
y_true_labels = y_test[:, :, :, 0]
IOU_keras.update_state(y_true_labels, y_pred_argmax)
print("Mean IoU =", IOU_keras.result().numpy())

In [None]:
def gray_to_rgb(x, color_map=None):
    """Turn label maps into RGB images for display.

    Args:
        x: array of labels whose last axis is the channel axis; only channel 0
            is read, so inputs of shape (..., 1) are the intended use.
        color_map: optional mapping ``label -> [R, G, B]``. Labels absent from
            the mapping (including 0) are rendered black. Defaults to the
            five-class palette below.

    Returns:
        uint8 array of shape ``x.shape[:-1] + (3,)``.
    """
    if color_map is None:
        # Label c + 1 corresponds to position c of the dataset's class list
        # ['lepidic', 'acinar', 'solid', 'micropapillary', 'papillary'].
        color_map = {
            1: [255, 0, 0],    # red     - lepidic
            2: [0, 255, 0],    # green   - acinar
            3: [0, 0, 255],    # blue    - solid
            4: [255, 255, 0],  # yellow  - micropapillary
            5: [255, 0, 255],  # magenta - papillary
        }

    # Work on the label plane directly instead of tiling it three times.
    labels = np.asarray(x)[..., 0]
    rgb = np.zeros(labels.shape + (3,), dtype=np.uint8)
    for label, color in color_map.items():
        rgb[labels == label] = color
    return rgb


In [None]:
# Convert to colors according to classes
y_test_show_rgb = gray_to_rgb(y_test)

# gray_to_rgb reads labels from a trailing channel axis. Add that axis on a
# temporary instead of reassigning y_pred_argmax, so that re-running this cell
# (or the Mean IoU cell afterwards) still sees the original prediction array.
y_pred_argmax_show_rgb = gray_to_rgb(np.expand_dims(y_pred_argmax, axis=-1))


In [None]:
# Pick a random test sample (or pin one by uncommenting the line below)
idx = randint(0, len(x_test) - 1)
# idx = 19
print(y_test.shape)
print(np.unique(y_pred_argmax[idx]))
print("Index: ", idx)

# image, colour-coded ground truth and colour-coded prediction side by side
panels = {
    'image': x_test[idx],
    'ground_truth': y_test_show_rgb[idx],
    'predict': y_pred_argmax_show_rgb[idx],
}
visualize(**panels)
