In [None]:
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'

import random
import warnings
import cv2
import wandb
import numpy as np
import albumentations as A

from tensorflow import keras
from tensorflow.keras.callbacks import Callback

import segmentation_models as sm
import tensorflow_advanced_segmentation_models as tasm

# Startup banner identifying the script and its authors; one print call, one line per item.
print(
    "---------------------------------------",
    "[INFO] CNN Seg. Model Script",
    "[INFO] Author: Martin Juricek",
    "[INFO] Supervisor: Ing. Roman Parak",
    "[INFO] IACS FME BUT @2022",
    "---------------------------------------",
    sep="\n",
)

class Image_Load(object):
    """Collect sorted image/mask file paths that are paired by position."""

    def get_data(self, image_dir, mask_dir):
        """List the ``.jpg`` images and ``.png`` masks of a dataset split.

        Hidden files (names starting with ".") are skipped in BOTH directories,
        so OS metadata files such as macOS ``._name.jpg`` cannot shift the
        positional image/mask pairing.

        Args:
            image_dir: directory containing the ``.jpg`` images.
            mask_dir: directory containing the ``.png`` masks.

        Returns:
            (img_paths, mask_paths): two sorted lists of joined paths.

        Raises:
            ValueError: if the number of images and masks differs; downstream
                code pairs them with ``zip`` and would otherwise silently
                truncate and misalign the dataset.
        """
        img_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.endswith(".jpg") and not fname.startswith(".")
        )

        mask_paths = sorted(
            os.path.join(mask_dir, fname)
            for fname in os.listdir(mask_dir)
            if fname.endswith(".png") and not fname.startswith(".")
        )

        if len(img_paths) != len(mask_paths):
            raise ValueError(
                "[ERROR] Found %d images in %s but %d masks in %s"
                % (len(img_paths), image_dir, len(mask_paths), mask_dir)
            )

        print("[INFO] Number of loaded samples:", len(img_paths))

        return img_paths, mask_paths

class Image_Augment(object):
    """Build albumentations pipelines and expand a dataset with transformed copies."""

    def __aug_settings(self, train, img_size=(480, 640)):
        """Return the albumentations pipeline for one dataset split.

        Args:
            train: True selects the randomized training pipeline; False selects
                the deterministic pad-and-resize validation pipeline.
            img_size: (height, width) that every output is brought to.
        """
        height, width = img_size
        if train:
            return A.Compose([
                A.CLAHE(p=0.8),
                A.GridDistortion(p=1),
                A.VerticalFlip(p=0.5),
                A.OneOf(
                    [
                        A.CLAHE(p=1),
                        # Same as the deprecated/removed A.RandomBrightness(p=1),
                        # whose default limit was 0.2 with contrast left unchanged.
                        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1),
                        A.RandomGamma(p=1),
                    ],
                    p=0.9,
                ),
                A.RandomBrightnessContrast(p=0.2),
                # Image_Preprocess copies samples into fixed (height, width) buffers;
                # make the training output size explicit like the validation pipeline.
                A.Resize(height, width, p=1),
            ])

        return A.Compose([
            A.PadIfNeeded(height, width),
            A.Resize(height, width, p=1),
        ])

    def aug_process(self, img_paths, mask_paths, num, train, img_size=(480, 640)):
        """Produce ``num`` transformed copies of every image/mask pair.

        Args:
            img_paths: image file paths, paired with ``mask_paths`` by position.
            mask_paths: mask file paths.
            num: number of transformed copies generated per pair.
            train: selects the training (random) or validation (deterministic) pipeline.
            img_size: (height, width) of the generated samples.

        Returns:
            (images, masks): lists of transformed images as read by ``cv2.imread``
            (BGR) and masks reduced to their first channel.

        Raises:
            FileNotFoundError: if OpenCV cannot read an image or mask file.
        """
        transform = self.__aug_settings(train, img_size)

        images, masks = [], []

        for image_path, mask_path in zip(img_paths, mask_paths):
            img = cv2.imread(image_path)
            mask = cv2.imread(mask_path)
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise FileNotFoundError("[ERROR] Could not read image: %s" % image_path)
            if mask is None:
                raise FileNotFoundError("[ERROR] Could not read mask: %s" % mask_path)

            for _ in range(num):
                transformed = transform(image=img, mask=mask)
                images.append(transformed['image'])
                # The mask is loaded as 3 channels; keep only the first one.
                masks.append(transformed['mask'][:, :, 0])

        print("[INFO] Were generated images: " + str(len(images)))
        print("[INFO] Were generated masks: " + str(len(masks)))

        return images, masks

class Image_Preprocess(keras.utils.Sequence):
    """Keras Sequence serving in-memory images and masks as float32 batches.

    Args:
        batch_size: samples per batch; trailing samples that do not fill a
            whole batch are never served (see ``__len__``).
        img_size: (height, width) tuple; each image must fit a
            (height, width, 3) buffer and each mask a (height, width) one.
        imgs: sequence of image arrays.
        masks: sequence of mask arrays, paired with ``imgs`` by position.
    """

    def __init__(self, batch_size, img_size, imgs, masks):
        # Initialise the Keras base class; Keras 3's PyDataset expects this call.
        super().__init__()
        self.batch_size = batch_size
        self.img_size = img_size
        self.imgs = imgs
        self.masks = masks

    def __len__(self):
        """Number of complete batches."""
        return len(self.imgs) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x, y)``.

        ``x`` has shape (batch_size, H, W, 3) and ``y`` has shape
        (batch_size, H, W, 1); both are float32.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="float32")

        for j, image in enumerate(self.imgs[start:stop]):
            x[j] = image

        for j, mask in enumerate(self.masks[start:stop]):
            y[j] = np.expand_dims(mask, 2)
            # Subtract after the float32 assignment so 0-valued pixels become -1
            # instead of wrapping around in the mask's integer dtype.
            # NOTE(review): presumably masks store labels 1..num_classes and this
            # makes them 0-based for sparse cross-entropy — confirm.
            y[j] -= 1

        return x, y

class EarlyStoppingByValAcc(Callback):
    """Stop training once a monitored metric exceeds a fixed threshold.

    Args:
        monitor: key of the epoch-end ``logs`` dict to watch.
        value: threshold; training stops when ``logs[monitor] > value``.
        verbose: when > 0, print a message at the stopping epoch.
    """

    def __init__(self, monitor='val_sparse_categorical_accuracy', value=0.9998, verbose=1):
        # Run Callback.__init__ itself; super(Callback, self) would skip it.
        super().__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        """Set ``model.stop_training`` when the monitored value passes ``value``."""
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("[INFO] Early stopping requires %s available!" % self.monitor, RuntimeWarning)
            # Nothing to compare against this epoch; comparing None would raise TypeError.
            return

        if current > self.value:
            if self.verbose > 0:
                print("[INFO] Epoch %05d: early stopping!" % epoch)
            self.model.stop_training = True

# Weights & Biases logging is opt-in: uncomment both lines to enable it.
# wandb.init(project="your-test-project", entity="YourEntity")
# wandb.config.format = "tf"

BUT_ID = 200543  # seed for the deterministic dataset shuffle below
img_size = (480,640)  # (height, width) used by Image_Preprocess
# NOTE(review): the train and val roots follow different layouts
# ("test/test_dataset/..." vs "test_dataset/val/...") — confirm both are intended.
train_path_img = "test/test_dataset/images/"
train_path_mask = "test/test_dataset/masks/"
val_path_img = "test_dataset/val/images/"
val_path_mask = "test_dataset/val/masks/"

num_classes = 3
batch_size = 4
num_epochs = 15
backbone = 'mobilenet'
optimizer = 'Adam'
activation_function = 'softmax'

train_img, train_mask = Image_Load().get_data(train_path_img, train_path_mask)
train_img, train_mask = Image_Augment().aug_process(train_img, train_mask, num=90, train=True)

# NOTE(review): the validation pipeline (pad + resize) is deterministic, so num=10
# produces 10 identical copies of every validation image.
val_img, val_mask = Image_Load().get_data(val_path_img, val_path_mask)
val_img, val_mask = Image_Augment().aug_process(val_img, val_mask, num=10, train=False)


def _paired_shuffle(images, masks, seed):
    """Shuffle two parallel lists with one shared permutation so image i keeps mask i."""
    pairs = list(zip(images, masks))
    random.Random(seed).shuffle(pairs)
    return [image for image, _ in pairs], [mask for _, mask in pairs]


train_img, train_mask = _paired_shuffle(train_img, train_mask, BUT_ID)
val_img, val_mask = _paired_shuffle(val_img, val_mask, BUT_ID)

train_data = Image_Preprocess(batch_size, img_size, train_img, train_mask)
val_data = Image_Preprocess(batch_size, img_size, val_img, val_mask)

callbacks = [
    EarlyStoppingByValAcc(),
    keras.callbacks.ModelCheckpoint("Unet-Mobilenet", save_best_only=True),
]
# WandbCallback raises unless wandb.init() has already run, so attach it only
# when a run is active (keeping the original callback order).
if wandb.run is not None:
    callbacks.insert(1, wandb.keras.WandbCallback())

In [None]:
keras.backend.clear_session()

model = sm.Unet(backbone, classes=num_classes, activation=activation_function)
model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics = [keras.metrics.SparseCategoricalAccuracy()])

# len() of a Sequence is already a batch count. Keras needs integer step counts,
# and at least one step so that training and validation actually run.
# NOTE(review): "// 4" uses only a quarter of the batches per epoch — confirm this is
# intentional and not a second division by batch_size.
train_steps_per_epoch = max(1, len(train_data) // 4)
val_steps_per_epoch = max(1, len(val_data) // 4)

# batch_size is omitted: the Sequence already yields whole batches, and Keras
# documents that batch_size must not be given for Sequence inputs.
model.fit(
    train_data,
    steps_per_epoch=train_steps_per_epoch,
    epochs=num_epochs,
    validation_data=val_data,
    validation_steps=val_steps_per_epoch,
    callbacks=callbacks
)

In [None]:
keras.backend.clear_session()

base_model, layers, layer_names = tasm.create_base_model(name=backbone, weights="imagenet", height=img_size[0], width=img_size[1])
model = tasm.ASPOCRNet(n_classes=num_classes, base_model=base_model, output_layers=layers, backbone_trainable=False)

model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics = [keras.metrics.SparseCategoricalAccuracy()])

# Fresh callbacks for this model. Reusing the U-Net list would share its
# ModelCheckpoint, which still holds the U-Net's best score (save_best_only) and
# writes to the U-Net's checkpoint path.
aspocr_callbacks = [
    EarlyStoppingByValAcc(),
    keras.callbacks.ModelCheckpoint("ASPOCRNet-Mobilenet", save_best_only=True),
]
# WandbCallback raises unless wandb.init() has already run.
if wandb.run is not None:
    aspocr_callbacks.insert(1, wandb.keras.WandbCallback())

# Integer, non-zero step counts; len() of a Sequence is already a batch count.
# NOTE(review): "// 4" uses only a quarter of the batches per epoch — confirm intent.
train_steps_per_epoch = max(1, len(train_data) // 4)
val_steps_per_epoch = max(1, len(val_data) // 4)

# batch_size is omitted: the Sequence already yields whole batches.
model.fit(
    train_data,
    steps_per_epoch=train_steps_per_epoch,
    epochs=num_epochs,
    validation_data=val_data,
    validation_steps=val_steps_per_epoch,
    callbacks=aspocr_callbacks
)