In [None]:
# !git clone https://github.com/jakeret/unet

In [None]:
# !pip install ./unet/

In [None]:
import numpy as np
import pandas as pd
import os
import random
import cv2
from PIL import Image
import pickle

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from imgaug import augmenters as augs

import unet

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Dataset location. Colab/Drive alternative: "/content/drive/MyDrive/Bakalaurinis/Dataset".
main_path = "./dataset/"
train_path = os.path.join(main_path, 'train_images')

# Read the annotation table and keep only the rows labelled 4. Labels are compared
# as text so the filter works whether pandas parsed the column as int or str.
# NOTE(review): which class label 4 denotes is not stated here — confirm against the dataset docs.
df = pd.read_csv(os.path.join(main_path, "train.csv"))
df = df.loc[df["label"].astype(str).eq('4')]

In [None]:
# Hold out the trailing 20 % of the label-filtered rows for validation. The split
# follows file order (no shuffling), so it is deterministic but inherits any
# ordering present in train.csv.
train_split = 0.8
train_count = int(train_split * len(df))
training_df, validation_df = df.iloc[:train_count], df.iloc[train_count:]

In [None]:
def get_image(img_path, target_size=512):
    """Load an image as RGB and return a random square crop of it.

    Args:
        img_path: Path to an image file readable by PIL.
        target_size: Side length, in pixels, of the square crop.

    Returns:
        np.ndarray of shape (target_size, target_size, 3) cut from a uniformly
        random position inside the image.

    Raises:
        ValueError: if the image is smaller than ``target_size`` in height or width.
    """
    # The context manager closes the file handle. np.array() forces the pixel
    # data to be read before that happens. convert("RGB") guarantees exactly three
    # channels, so grayscale/RGBA/palette files neither break the crop nor feed
    # the 3-channel model the wrong number of channels.
    with Image.open(img_path) as img:
        pixels = np.array(img.convert("RGB"))

    # A NumPy image array is laid out as (height, width, channels). PIL's
    # Image.size uses the opposite order, (width, height).
    height, width = pixels.shape[:2]
    if height < target_size or width < target_size:
        raise ValueError(
            f"image {img_path!r} is {width}x{height}, smaller than the requested "
            f"{target_size}x{target_size} crop"
        )

    # The horizontal offset is drawn before the vertical one, so a seeded
    # `random` module keeps producing the same crops.
    left = random.randint(0, width - target_size)
    top = random.randint(0, height - target_size)
    return pixels[top:top + target_size, left:left + target_size, :]

def get_groundtruth_mask(img, target_size=512):
    """Derive a one-hot, two-channel segmentation target from an RGB crop.

    The image is Gaussian-blurred and mapped to the CIVE colour index. The index
    is min-max rescaled to 0-255 and split into two classes with Otsu's threshold.

    Args:
        img: Image array of shape (H, W, 3). Channels are read as R, G, B (the
            order PIL produces).
        target_size: Unused; accepted only so existing call sites keep working.

    Returns:
        uint8 array of shape (H, W, 2). Channel 1 is 1 where the blurred pixel
        lies on the low-CIVE side of the Otsu threshold. Green carries the only
        negative CIVE weight, so these are the greener pixels — presumably
        vegetation. Channel 0 is 1 everywhere else, so every pixel is one-hot.
    """
    # 35x35 Gaussian kernel; sigma=0 lets OpenCV derive sigma from the kernel size.
    blurred = cv2.GaussianBlur(img, (35, 35), 0)
    red, green, blue = blurred[:, :, 0], blurred[:, :, 1], blurred[:, :, 2]

    # CIVE (Kataoka et al., 2003): 0.441 R - 0.811 G + 0.385 B + 18.78745.
    # The additive constant cancels out in the min-max rescale below.
    cive_band = 0.441 * red - 0.811 * green + 0.385 * blue + 18.787
    normalized_cive_band = cv2.normalize(cive_band, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1)

    # THRESH_BINARY_INV: pixels at or below Otsu's threshold become 255, the rest 0.
    _, otsu_mask = cv2.threshold(normalized_cive_band, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    masks = np.stack([otsu_mask == 0, otsu_mask == 255], axis=2).astype(np.uint8)
    return masks

In [None]:
# Training-time augmentation pipeline (imgaug). The steps run in the listed order;
# Sometimes(p, child) applies `child` to roughly a fraction p of the images.
# Affine is the only geometric step here. Anything that must stay pixel-aligned
# with the images (e.g. segmentation masks) has to go through the same augmenter
# call, not be derived separately.
augmentation = augs.Sequential(children=[
    augs.Multiply(mul=(0.5, 1.5)),                                              # brightness x0.5 .. x1.5
    augs.Affine(scale=(1, 1.2)),                                                # zoom in by up to 20 %
    augs.Sometimes(0.4, augs.GaussianBlur(sigma=(0, 2))),
    augs.Sometimes(0.3, augs.Grayscale(alpha=(0.0, 1.0))),                      # partial desaturation
    augs.Sometimes(0.3, augs.SigmoidContrast(gain=(3, 6), cutoff=(0.4, 0.6))),
    augs.Sometimes(0.3, augs.CoarseDropout((0.0, 0.05), size_percent=(0.05, 0.6), per_channel=0.5)),
])

In [None]:
def custom_generator(image_path_list, folder, batch_size=16, training_mode=True):
    """Yield ``(images, masks)`` batches indefinitely, for Keras ``fit``.

    Each pass walks ``image_path_list`` in order. For every file it loads a
    random crop with ``get_image`` and derives its target with
    ``get_groundtruth_mask``. In training mode, images and masks are augmented in
    one call, so geometric transforms move both together and they stay aligned.

    Args:
        image_path_list: Iterable of file names relative to ``folder``.
        folder: Directory holding the images.
        batch_size: Maximum batch size. The last batch of each pass holds the
            remaining ``len(image_path_list) % batch_size`` items when that is
            non-zero.
        training_mode: If true, apply the module-level ``augmentation`` pipeline.

    Yields:
        Tuple ``(images, masks)``. ``images`` is float32 scaled by 1/255.
        ``masks`` is the stack of ``get_groundtruth_mask`` outputs.

    Raises:
        ValueError: (on first ``next``) if ``image_path_list`` is empty. Without
            this check the generator would spin forever without yielding.
    """
    # Materialise once, so positional slicing works the same for lists, arrays
    # and pandas Series with a non-contiguous (filtered) index.
    paths = list(image_path_list)
    if not paths:
        raise ValueError("image_path_list is empty; nothing to yield")

    while True:
        for start in range(0, len(paths), batch_size):
            batch_paths = paths[start:start + batch_size]
            images = np.stack([get_image(os.path.join(folder, path)) for path in batch_paths])
            # Targets are derived from the un-augmented crops, so photometric
            # augmentations (brightness, grayscale, contrast, dropout) do not
            # change the class assignment.
            masks = np.stack([get_groundtruth_mask(image) for image in images])
            if training_mode:
                # Passing the masks as segmentation maps makes imgaug apply the
                # same geometric transform (Affine) to them as to the images.
                images, masks = augmentation.augment(images=images, segmentation_maps=masks)
                images, masks = np.asarray(images), np.asarray(masks)
            yield images.astype(np.float32) / 255.0, masks

In [None]:
batch_size = 8

# Training callbacks, all driven by validation loss:
#   - halve the learning rate after 1 epoch without improvement,
#   - stop training after 3 epochs without improvement,
#   - write the model to UNet.h5 only when val_loss improves (best checkpoint).
callbacks = [
    ReduceLROnPlateau(monitor='val_loss', patience=1, verbose=1, factor=0.5),
    EarlyStopping(monitor='val_loss', patience=3, verbose=1),
    ModelCheckpoint(filepath=os.path.join(main_path, 'UNet.h5'), monitor='val_loss', save_best_only=True),
]

In [None]:
# U-Net for 512-pixel, 3-channel crops with one output channel per mask class (2),
# 4 resolution levels starting at 64 filters, and "same" convolution padding.
# NOTE(review): only one spatial size (512) is passed positionally — confirm whether
# unet.build_model should also be given the second spatial dimension explicitly.
unet_model = unet.build_model(
    512,
    channels=3,
    num_classes=2,
    layer_depth=4,
    filters_root=64,
    padding="same",
)

# Compile with binary cross-entropy against the one-hot masks and binary accuracy
# as a metric. `auc=False` presumably disables the library's AUC metric, and
# `learning_rate` is forwarded by unet.finalize_model.
# NOTE(review): confirm which optimizer finalize_model builds.
unet.finalize_model(
    unet_model,
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
    auc=False,
    learning_rate=1e-4,
)

In [None]:
# One epoch must be exactly one pass over each generator. custom_generator yields
# a trailing partial batch whenever len(df) % batch_size != 0, so the step counts
# round *up*. Rounding down would make every later epoch start part-way through
# the data, and each validation run would score a different subset of images.
# It would also give 0 steps for fewer rows than batch_size.
history = unet_model.fit(
    custom_generator(training_df["image_id"], train_path, batch_size=batch_size, training_mode=True),
    steps_per_epoch=int(np.ceil(len(training_df) / batch_size)),
    epochs=20,
    validation_data=custom_generator(validation_df["image_id"], train_path, batch_size=batch_size, training_mode=False),
    validation_steps=int(np.ceil(len(validation_df) / batch_size)),
    callbacks=callbacks,
)

Epoch 1/10
  1/257 [..............................] - ETA: 16:06:11 - loss: 0.6933 - binary_accuracy: 0.5005 - mean_iou: 0.3252 - dice_coefficient: 0.5002

In [None]:
# Save the model as it stands after the last completed epoch. The best-val_loss
# weights are the UNet.h5 file written by the ModelCheckpoint callback instead.
unet_model.save(os.path.join(main_path, 'UNetFinal'))

# Keras records the per-epoch metric values in history.history; pickle that dict.
with open(os.path.join(main_path, 'UNetHistory.pkl'), mode='wb') as history_file:
    history_file.write(pickle.dumps(history.history))