In [None]:
!pip install segmentation-models

In [None]:
import tensorflow as tf
import numpy as np
import cv2
from random import randint
import matplotlib.pyplot as plt
import keras
import os

In [None]:
# Dataset split sizes (number of image files per split) and mini-batch size.
TRAIN_DATA_SIZE = 9_052  # ~80% of all images
VAL_DATA_SIZE = 2_246    # ~20% of all images
BATCH_SIZE = 32

In [None]:
# Split the image file names into training and validation subsets (80:20).
# os.listdir returns names in an arbitrary, filesystem-dependent order, so the
# list is sorted to make the split reproducible across machines and runs.
files_list = sorted(os.listdir("./data/images"))
train_files_list = files_list[:TRAIN_DATA_SIZE]
validate_files_list = files_list[TRAIN_DATA_SIZE:]

# The training cell samples indices in [0, *_DATA_SIZE) from these lists, so the
# real split sizes must equal the configured constants (a larger list would be
# silently under-used, a smaller one would raise IndexError mid-epoch).
assert len(train_files_list) == TRAIN_DATA_SIZE, len(train_files_list)
assert len(validate_files_list) == VAL_DATA_SIZE, len(validate_files_list)

In [None]:
def batch_generator(df, sample_size, batch_size, target_size=(256, 256)):
    '''
    Infinite generator of randomly sampled (image, mask) batches.

    Each batch draws `batch_size` file names uniformly at random (with
    replacement) from the first `sample_size` entries of `df`, reads the image
    from ./data/images/ and the same-named mask from ./data/masks/, resizes both
    to `target_size` and scales pixel values by 1/255.

        Parameters:
        df: List of file names belonging to the split
        sample_size: Number of leading entries of `df` to sample from;
            must satisfy 0 < sample_size <= len(df)
        batch_size: Number of samples per batch
        target_size: (width, height) passed to cv2.resize; default (256, 256)

        Yields:
        (x_batch, y_batch): float32 arrays of shape
            (batch_size, height, width, 3) and (batch_size, height, width, 1)

        Raises:
        ValueError: if sample_size is out of range for df
        FileNotFoundError: if OpenCV cannot read an image or mask
    '''
    if not 0 < sample_size <= len(df):
        raise ValueError('sample_size must be in (0, {}], got {}'.format(len(df), sample_size))

    width, height = target_size
    while True:
        x_batch = []
        y_batch = []
        for _ in range(batch_size):
            img_name = df[randint(0, sample_size - 1)]
            img_path = './data/images/{}'.format(img_name)
            mask_path = './data/masks/{}'.format(img_name)

            img = cv2.imread(img_path)
            # Flag 0 reads the mask as a single-channel grayscale image.
            mask = cv2.imread(mask_path, 0)
            # cv2.imread signals a missing/unreadable file by returning None, not by raising.
            if img is None:
                raise FileNotFoundError('Cannot read image: {}'.format(img_path))
            if mask is None:
                raise FileNotFoundError('Cannot read mask: {}'.format(mask_path))

            # NOTE(review): OpenCV loads colour images in BGR order; confirm this is the
            # channel order the ImageNet encoder and any inference code expect.
            img = cv2.resize(img, (width, height))
            # Nearest-neighbour keeps the mask's label values intact instead of
            # blending them along object borders.
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

            x_batch.append(img)
            # Append a trailing channel axis after resizing (cv2.resize drops a
            # size-1 channel axis); presumably matches the single-class Unet output.
            y_batch.append(mask[..., np.newaxis])

        x_batch = np.asarray(x_batch, dtype=np.float32) / 255.
        y_batch = np.asarray(y_batch, dtype=np.float32) / 255.

        yield x_batch, y_batch

In [None]:
import segmentation_models as sm
from segmentation_models import Unet
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

# Make segmentation_models build its models with tf.keras.
sm.set_framework('tf.keras')
sm.framework()

BACKBONE = 'resnet50'

# NOTE(review): preprocess_input is never applied in this notebook; batch_generator
# only scales pixels by 1/255. Confirm whether the encoder expects this preprocessing.
preprocess_input = sm.get_preprocessing(BACKBONE)

# U-Net with a ResNet-50 encoder initialised from ImageNet weights; the input
# shape matches the 256x256, 3-channel batches produced by batch_generator.
model_resnet = Unet(BACKBONE, encoder_weights='imagenet', input_shape=(256, 256, 3))

# `decay` is intentionally omitted: 0.0 was already the legacy default, and the
# optimizer API introduced in TF 2.11 rejects the argument with a ValueError.
adam = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
model_resnet.compile(adam, loss=bce_jaccard_loss, metrics=[iou_score])

In [None]:
# Weight checkpoints, created via tf.keras so they come from the same Keras
# implementation as the model and optimizer above.
# save_freq='epoch' saves once per epoch: an integer save_freq counts *batches*,
# and val_loss (needed by save_best_only) only exists at the end of an epoch.

# Keeps the weights with the lowest validation loss seen so far.
best_w = tf.keras.callbacks.ModelCheckpoint('resnet50_best.h5',
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=True,
                                mode='auto',
                                save_freq='epoch')

# Overwrites the file with the most recent epoch's weights.
last_w = tf.keras.callbacks.ModelCheckpoint('resnet50_last.h5',
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=False,
                                save_weights_only=True,
                                mode='auto',
                                save_freq='epoch')


callbacks = [best_w, last_w]

In [None]:
# Train on randomly sampled batches; Model.fit accepts Python generators directly,
# whereas fit_generator is deprecated and removed in newer Keras releases.
# Options left at their defaults (workers, queue size, shuffle, ...) are omitted.
hist = model_resnet.fit(batch_generator(train_files_list, TRAIN_DATA_SIZE, BATCH_SIZE),
                        steps_per_epoch=TRAIN_DATA_SIZE // BATCH_SIZE,
                        epochs=30,
                        verbose=1,
                        callbacks=callbacks,
                        validation_data=batch_generator(validate_files_list, VAL_DATA_SIZE, BATCH_SIZE),
                        validation_steps=VAL_DATA_SIZE // BATCH_SIZE)

# Save the final model in HDF5 format.
model_resnet.save("unet_resnet_final.model", save_format='h5')

In [None]:
def _plot_history_curves(history, metric, title, ylabel, path):
    '''Plot train/validation curves of `metric` from a Keras History dict and save them to `path`.

    The x axis is derived from the recorded history length, so it stays correct
    when the epoch count changes or training stops early.
    '''
    epochs = np.arange(len(history[metric]))
    fig, ax = plt.subplots()
    ax.plot(epochs, history[metric], label="train_" + metric)
    ax.plot(epochs, history["val_" + metric], label="val_" + metric)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(path)
    return fig


plt.style.use("ggplot")

_plot_history_curves(hist.history, "loss",
                     "Training and validation loss", "Loss",
                     "unet_resnet_losses.png")
_plot_history_curves(hist.history, "iou_score",
                     "Training and validation iou score", "iou_score",
                     "unet_resnet_iou_score.png")

plt.show()