In [None]:
# Print GPU info
# `!` hands the rest of the line to the system shell; the command's output is
# shown as the cell output. Nothing from this cell is used by later cells.
!nvidia-smi

In [None]:
# Get & unzip data
# WARNING: the first command recursively deletes everything matched by
# `../content/*` before downloading. NOTE(review): this presumably targets the
# Colab working directory (/content) -- confirm before running anywhere else.
!rm -rf ../content/*
# NOTE(review): `--no-check-certificate` turns off TLS certificate verification,
# so the downloaded archive is not authenticated -- confirm this is acceptable.
!wget https://graphicwg.irafm.osu.cz/storage/carvana.zip --no-check-certificate
# Extract quietly into the current directory, then remove the archive.
# The loader cell later reads ./train/, ./train_masks/, ./valid/ and ./valid_masks/,
# so the archive is expected to contain those folders.
!unzip -q carvana.zip
!rm carvana.zip

In [None]:
# Install dependencies
!pip install segmentation_models

In [None]:
# Set environment variable so segmentation_models uses correct Keras
# NOTE(review): this line sits before `import segmentation_models` on purpose;
# presumably the library reads SM_FRAMEWORK at import time -- keep this order.
%env SM_FRAMEWORK = tf.keras

# Import libraries
import segmentation_models as sm
import cv2
from glob import glob
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import random
import albumentations as A
from tensorflow.keras.callbacks import ReduceLROnPlateau

In [149]:
# Create & compile model
# U-Net with a ResNet-50 encoder and two output channels (one per mask class
# produced by the data generator). `encoder_weights=None` means no pretrained
# encoder weights are requested.
model = sm.Unet(
    "resnet50",
    input_shape=(320, 480, 3),
    classes=2,
    encoder_weights=None,
)

# Adam optimizer (referenced by name, default settings) with Dice loss.
model.compile(optimizer="Adam", loss=sm.losses.DiceLoss())

In [None]:
# Load data and labels into RAM
def load_data(data_folder, label_folder):
  """Load all images of a folder and their masks into memory.

  Images are every file matching ``data_folder + "*.*"``, taken in sorted path
  order and converted from OpenCV's BGR channel order to RGB. The mask path of
  an image is derived from the image path by replacing ``data_folder`` with
  ``label_folder`` and ``".jpg"`` with ``"_mask.png"``; masks are read as
  single-channel (grayscale) arrays.

  Args:
    data_folder: Folder prefix of the images, including the trailing slash.
    label_folder: Folder prefix of the masks, including the trailing slash.

  Returns:
    A tuple ``(data, labels)`` of two lists with matching order.

  Raises:
    OSError: If an image or a mask cannot be read. ``cv2.imread`` signals
      failure by returning ``None`` instead of raising, which would otherwise
      surface much later as an unrelated-looking error.
  """
  data_paths = sorted(glob(data_folder + "*.*"))
  data = []
  for path in tqdm(data_paths, desc="Loading data from '" + data_folder + "'"):
    image = cv2.imread(path)
    if image is None:
      raise OSError("Could not read image '" + path + "'")
    data.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

  labels_paths = [path.replace(data_folder, label_folder).replace(".jpg", "_mask.png") for path in data_paths]
  labels = []
  for path in tqdm(labels_paths, desc="Loading labels from '" + label_folder + "'"):
    label = cv2.imread(path, 0)  # flag 0: read as grayscale -> 2-D array
    if label is None:
      raise OSError("Could not read label '" + path + "'")
    labels.append(label)

  return data, labels

# Training split: images from ./train/, masks from ./train_masks/
x_train, y_train = load_data(data_folder='./train/', label_folder='./train_masks/')
# Validation split: images from ./valid/, masks from ./valid_masks/
x_valid, y_valid = load_data(data_folder='./valid/', label_folder='./valid_masks/')

In [151]:
# Augment image
def augment(image, label):
  """Randomly augment an RGB image together with its segmentation mask.

  Two independent steps are applied:

  1. An albumentations pipeline that picks at most one of random brightness,
     random contrast, random gamma or rotation and applies it to the image and
     the mask together.
  2. With probability 1/2, a random hue shift applied only to the pixels whose
     mask value is 0.

  Args:
    image: RGB uint8 image as a NumPy array of shape (H, W, 3).
    label: Mask as a NumPy array, either single-channel (H, W) -- as produced by
      ``cv2.imread(path, 0)`` -- or multi-channel (H, W, C), in which case the
      first channel is used to find mask-zero pixels.

  Returns:
    A tuple ``(aug_image, aug_label)``. The input ``image`` is never modified.
  """
  t = A.Compose([
                 A.OneOf([
                          A.RandomBrightness(),
                          A.RandomContrast(),
                          A.RandomGamma(),
                          A.Rotate()
                 ])
  ])

  augmented = t(image=image, mask=label)

  aug_image = augmented['image']
  aug_label = augmented['mask']

  if bool(random.getrandbits(1)):

    # OpenCV stores 8-bit hue in the range 0..179, i.e. 180 distinct values.
    HUE_RANGE = 180

    hsv_image = cv2.cvtColor(aug_image, cv2.COLOR_RGB2HSV)
    random_h = random.randint(0, HUE_RANGE - 1)
    # Widen before adding: hue + shift can reach 358, which would wrap around
    # in uint8 arithmetic and produce a wrong hue.
    shifted_hue = (hsv_image[..., 0].astype(np.uint16) + random_h) % HUE_RANGE
    hsv_image[..., 0] = shifted_hue.astype(np.uint8)

    image_hue_changed = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB)

    # Per-pixel (H, W) selector of mask-zero pixels. Indexing a 2-D mask with
    # [..., 0] would take its first column and select whole image rows instead.
    first_channel = aug_label if aug_label.ndim == 2 else aug_label[..., 0]
    background = first_channel == 0

    # Work on a copy: when the pipeline applies no transform it may hand back
    # the caller's array, and writing into it would alter the stored dataset.
    aug_image = aug_image.copy()
    aug_image[background] = image_hue_changed[background]

  return aug_image, aug_label

In [152]:
# Train sizes
# Number of samples in each generated batch
BATCH_SIZE = 8

# Number of epochs
EPOCHS = 20

# Batches per epoch / per validation run: floor division, so an incomplete
# trailing batch is not counted.
STEPS_PER_EPOCH, VALIDATION_STEPS = (
    len(samples) // BATCH_SIZE for samples in (x_train, x_valid)
)

In [153]:
# Data generator wrappers: zero-argument helpers that create the training and validation generators passed to Keras training
def data_generator_wrapper_train():
    """Return a new batch generator that draws from the training set."""
    train_batches = data_generator(is_train=True)
    return train_batches

def data_generator_wrapper_valid():
    """Return a new batch generator that draws from the validation set."""
    valid_batches = data_generator(is_train=False)
    return valid_batches

In [154]:
# Data generator
def data_generator(is_train):
    """Endlessly yield batches of images and two-channel masks.

    Each batch holds ``BATCH_SIZE`` samples drawn at random (with replacement)
    via ``get_random_image_and_label``. Training samples are augmented;
    validation samples are returned as loaded, so the validation loss is
    measured on undistorted data and stays comparable between epochs.

    Args:
        is_train: True to draw from (and augment) the training set, False to
            draw from the validation set.

    Yields:
        A tuple ``(images, labels)`` of float arrays. ``labels`` has two
        channels on its last axis: channel 0 is 1.0 where the mask equals 0,
        channel 1 is 1.0 where the mask equals 255.
    """
    while True:
        g_images = []
        g_labels = []
        for _ in range(BATCH_SIZE):
            g_im, g_la = get_random_image_and_label(is_train)

            # Augmentation is a training-only regularizer.
            if is_train:
                g_im, g_la = augment(g_im, g_la)

            g_images.append(g_im)

            # Mask pixels that are neither 0 nor 255 stay 0 in both channels.
            mask = np.zeros((g_la.shape[0], g_la.shape[1], 2))
            mask[..., 0] = g_la == 0
            mask[..., 1] = g_la == 255
            g_labels.append(mask)

        g_images = np.asarray(g_images, dtype=float)
        g_labels = np.asarray(g_labels, dtype=float)

        yield g_images, g_labels

In [155]:
# Get a random image (with its corresponding label)
def get_random_image_and_label(is_train):
    """Pick one random sample from the training or the validation set.

    Args:
        is_train: True to sample from ``x_train``/``y_train``, False to sample
            from ``x_valid``/``y_valid``.

    Returns:
        A tuple ``(image, label)`` taken from the same random index of the
        chosen image and label lists.
    """
    # Select the split first, then draw a single index that is valid for it.
    images, labels = (x_train, y_train) if is_train else (x_valid, y_valid)
    position = random.randrange(0, len(images))
    return images[position], labels[position]

In [None]:
# Train
# Reduce the learning rate when the monitored loss stops improving
# (patience of one epoch); verbose=1 prints a message on each reduction.
lro = ReduceLROnPlateau(patience=1, verbose=1)

# The generators already yield complete batches of BATCH_SIZE samples, so
# `batch_size` is not passed: Keras documents that it must not be specified
# for generator input. The generators never terminate, hence the explicit
# `steps_per_epoch` and `validation_steps`.
history = model.fit(data_generator_wrapper_train(),
          validation_data=data_generator_wrapper_valid(),
          epochs=EPOCHS,
          steps_per_epoch=STEPS_PER_EPOCH,
          validation_steps=VALIDATION_STEPS,
          callbacks=[lro])

Epoch 1/20
110/611 [====>.........................] - ETA: 9:32 - loss: 0.2638