In [None]:
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import layers

from pygit2 import Repository
import time
from dotenv import load_dotenv
import mlflow
import matplotlib.pyplot as plt
import numpy as np
import sys  
sys.path.insert(0, '../src/data/')
from dataset_generators import *

load_dotenv()

In [None]:
# Network input geometry: every image is fed as an (IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS) tensor.
IMAGE_HEIGHT, IMAGE_WIDTH = 128, 1024
CHANNELS = 3  # 3-channel input (presumably RGB — confirm against read_data_img)

In [None]:
# Image / mask folders of the test split (used for ad-hoc inspection).
test_img_dir = "../data/processed/test/img/"
test_mask_dir = "../data/processed/test/mask/"

# Root folders of each processed split, relative to this notebook.
train_dir, val_dir, test_dir = (
    "../data/processed/train/",
    "../data/processed/validation/",
    "../data/processed/test/",
)

# Нейронная сеть

## Архитектура сети

In [None]:
def res_down_block(inputs, num_filters):
    """Encoder block: two SeparableConv/BN/ReLU stages plus a skip branch, then downsample.

    The main path and a one-convolution skip path (each ``num_filters`` channels)
    are concatenated along the channel axis, giving ``2 * num_filters`` channels,
    and then max-pooled with ``pool_size=3, strides=2, padding='same'``.

    Args:
        inputs: feature-map tensor from the Keras functional API (4-D, as
            required by ``SeparableConv2D``).
        num_filters: number of filters of every convolution in the block.

    Returns:
        The pooled feature-map tensor.
    """
    def sep_conv(tensor):
        # Every convolution in this block shares the same configuration.
        return layers.SeparableConv2D(filters=num_filters, kernel_size=3,
                                      padding='same')(tensor)

    main = inputs
    for _ in range(2):
        main = sep_conv(main)
        main = layers.BatchNormalization()(main)
        main = layers.Activation('relu')(main)

    # "Residual" shortcut: merged with the main path by concatenation, not addition.
    shortcut = sep_conv(inputs)

    merged = layers.concatenate([main, shortcut])
    return layers.MaxPool2D(pool_size=3, strides=2, padding='same')(merged)

def res_up_block(inputs, num_filters):
    """Decoder block: two Conv2DTranspose/BN/ReLU stages plus a 1x1 skip, then upsample x2.

    Both transposed convolutions use the default stride, so they keep the spatial
    size. The main path and a 1x1-conv projection of ``inputs`` are concatenated
    (``2 * num_filters`` channels) and then upsampled by ``UpSampling2D(2)``.

    Args:
        inputs: feature-map tensor from the Keras functional API.
        num_filters: number of filters of every convolution in the block.

    Returns:
        The upsampled feature-map tensor.
    """
    main = inputs
    for _ in range(2):
        main = layers.Conv2DTranspose(filters=num_filters, kernel_size=3,
                                      padding='same')(main)
        main = layers.BatchNormalization()(main)
        main = layers.Activation('relu')(main)

    # "Residual" shortcut: 1x1 projection, merged by concatenation rather than addition.
    shortcut = layers.Conv2D(num_filters, 1, padding="same")(inputs)

    merged = layers.concatenate([main, shortcut])
    return layers.UpSampling2D(2)(merged)
    
# https://keras.io/examples/vision/oxford_pets_image_segmentation/
def get_model(img_size, num_classes, down_filters=(8,), up_filters=(8, 4)):
    """Build the encoder/decoder segmentation network.

    Adapted from the Keras Oxford-Pets segmentation example linked above.

    Layout: a stride-2 entry conv, one ``res_down_block`` per entry of
    ``down_filters`` (each downsamples by 2), then one ``res_up_block`` per entry
    of ``up_filters`` (each upsamples by 2), then a 3x3 sigmoid classifier.

    Args:
        img_size: ``(height, width)`` of the input images; any 2-item sequence.
        num_classes: number of output channels, each with an independent sigmoid.
        down_filters: filter counts of the encoder blocks (default matches the
            original architecture).
        up_filters: filter counts of the decoder blocks (default matches the
            original architecture).

    Returns:
        An uncompiled ``keras.Model``. For input sizes divisible by
        ``2 ** (len(down_filters) + 1)``, the output has the same height/width
        as the input and ``num_classes`` channels.

    Raises:
        ValueError: if the decoder does not have exactly one more block than
            the encoder (the entry conv also downsamples), which would make the
            output resolution differ from the input.
    """
    if len(up_filters) != len(down_filters) + 1:
        raise ValueError(
            "up_filters must have exactly len(down_filters) + 1 entries so the "
            "output matches the input resolution; got %d and %d"
            % (len(up_filters), len(down_filters)))

    # tuple() lets callers pass img_size as a list as well as a tuple.
    inputs = keras.Input(shape=tuple(img_size) + (CHANNELS,))

    # Entry block: strided conv halves the spatial resolution.
    x = layers.Conv2D(8, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    # Encoder: each block halves the resolution again.
    for filters in down_filters:
        x = res_down_block(x, filters)

    # Decoder: each block doubles the resolution.
    for filters in up_filters:
        x = res_up_block(x, filters)

    # Per-pixel classification layer.
    outputs = layers.Conv2D(num_classes, 3, activation="sigmoid", padding="same")(x)

    return keras.Model(inputs, outputs)


# Release Keras state from earlier runs so re-running the model-definition cells doesn't pile up memory.
keras.backend.clear_session()

# Binary segmentation: a single sigmoid output channel at the configured input size.
model = get_model(
    img_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    num_classes=1,
)

# Last expression -> the rendered architecture diagram is displayed by the notebook.
keras.utils.plot_model(model, show_shapes=True)

## Dice metric and loss

In [None]:
def dice_coef(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient computed over the whole batch.

    Both tensors are flattened, so the score pools all pixels of all samples
    rather than averaging per-image scores.

    Args:
        y_true: ground-truth mask tensor; any numeric dtype (it is cast to
            ``y_pred``'s dtype).
        y_pred: predicted probabilities with the same number of elements as
            ``y_true``.
        smooth: additive smoothing term in numerator and denominator; keeps the
            score defined (equal to 1) when both masks are empty. Default 1
            matches the original implementation.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Keras does not cast y_true for custom losses. If the masks arrive as
    # integer tensors, int * float would raise in TF, so align the dtypes first.
    y_true = tf.cast(y_true, y_pred.dtype)

    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    denominator = tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + smooth
    return (2. * intersection + smooth) / denominator

def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated ``dice_coef`` (larger overlap -> lower loss)."""
    score = dice_coef(y_true, y_pred)
    return -score

## Model compile

In [None]:
# Optimizer, candidate loss and metrics for training.
optim = keras.optimizers.Adam(learning_rate=0.001)
bce = keras.losses.BinaryCrossentropy()  # NOTE(review): defined but not passed to compile() below
metrics = [dice_coef, "accuracy"]

# Train directly on the (negated) Dice score; report Dice and accuracy as metrics.
model.compile(
    optimizer=optim,
    loss=dice_coef_loss,
    metrics=metrics,
)

## Callbacks

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Periodically plot input / true mask / predicted mask while training.

    Every ``batch_interval`` training batches (batch 0 is skipped), one batch is
    taken from ``dataset`` and the model being trained predicts on it. The first
    sample's image, true mask and thresholded prediction are shown side by side.
    """

    def __init__(self, dataset, batch_interval=400):
        """
        Args:
            dataset: object with a ``take(n)`` method yielding ``(image, mask)``
                batches (used like a ``tf.data.Dataset``).
            batch_interval: show a prediction every this many batches.
        """
        # Initialise the attributes the Keras callback machinery expects.
        super().__init__()
        self.dataset = dataset
        self.batch_interval = batch_interval

    def display(self, display_list, extra_title=''):
        """Plot the given images in a single row.

        Args:
            display_list: images in the order input, true mask, predicted mask;
                any further images are titled ``extra_title``.
            extra_title: title for images beyond the first three.
        """
        plt.figure(figsize=(20, 20))
        titles = ['Input Image', 'True Mask', 'Predicted Mask']
        # Pad titles so more than four images no longer raise IndexError.
        titles += [extra_title] * max(0, len(display_list) - len(titles))

        n_images = len(display_list)
        for position, (img, img_title) in enumerate(zip(display_list, titles), start=1):
            plt.subplot(1, n_images, position)
            plt.title(img_title)
            plt.imshow(img)
            plt.axis('off')
        plt.show()

    def create_mask(self, pred_mask, threshold=0.5):
        """Binarise predicted probabilities and return the first sample of the batch.

        Args:
            pred_mask: NumPy array of predictions with a leading batch axis,
                as returned by ``Model.predict``.
            threshold: probability above which a pixel becomes 1.

        Returns:
            ``int32`` array of 0/1 values for the first sample.
        """
        return (pred_mask[0] > threshold).astype("int32")

    def show_predictions(self, dataset, num=1):
        """Show predictions of the model being trained for ``num`` batches of ``dataset``."""
        for image, mask in dataset.take(num):
            # Use the model Keras attached to this callback instead of a notebook
            # global. verbose=0 keeps progress bars out of the training log.
            pred_mask = self.model.predict(image, verbose=0)
            self.display([image[0], mask[0], self.create_mask(pred_mask)])

    def on_batch_end(self, batch, logs=None):
        """Show a sample prediction every ``batch_interval`` batches (skipping batch 0)."""
        if batch and batch % self.batch_interval == 0:
            self.show_predictions(self.dataset)
            print ('\nSample Prediction after batch {}\n'.format(batch))

## Data generators

In [None]:
# Shared training schedule / batching for both generators.
batch_size = 8
epochs = 2

train_datagen = img_mask_generator(train_dir, batch_size=batch_size, epochs=epochs)
val_datagen = img_mask_generator(
    val_dir,
    split='val',
    batch_size=batch_size,
    epochs=epochs,
)

## Model Training

In [None]:
# Log to the local MLflow server; the experiment is named after the current git branch.
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.tensorflow.autolog()
mlflow.set_experiment(experiment_name=Repository('.').head.shorthand)

# Number of single-image predictions in the latency benchmark. The first
# (warm-up) call is excluded from the logged statistics.
N_TIMING_RUNS = 100


def measure_prediction_time(model, image, n_runs=N_TIMING_RUNS):
    """Time ``n_runs`` consecutive ``model.predict`` calls on ``image``.

    Args:
        model: a compiled Keras model.
        image: model input with a leading batch axis.
        n_runs: number of timed calls; must be at least 2 so that something
            remains after the warm-up call is dropped.

    Returns:
        1-D ``np.ndarray`` of wall-clock durations in seconds, one per call.

    Raises:
        ValueError: if ``n_runs < 2``.
    """
    if n_runs < 2:
        raise ValueError("n_runs must be >= 2 (the first call is discarded as warm-up)")
    timings = np.zeros(n_runs)
    for i in range(n_runs):
        start = time.perf_counter()
        # verbose=0: otherwise the progress-bar printing is included in the timing.
        model.predict(image, verbose=0)
        timings[i] = time.perf_counter() - start
    return timings


with mlflow.start_run():

    model.fit(
        train_datagen,
        steps_per_epoch=len(train_datagen),
        epochs=epochs,
        validation_data=val_datagen,
        validation_steps=len(val_datagen),
        callbacks=[DisplayCallback(train_datagen)]
    )

    test_image_path = '../data/processed/test/img/210416D.003_1400701D.E003_96860_2.jpg'
    test_img = read_data_img(test_image_path)
    # Add the batch axis using the configured geometry (not hard-coded sizes).
    test_img = tf.reshape(test_img, [1, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS])

    time_of_prediction_array = measure_prediction_time(model, test_img)

    # Drop the first (warm-up) call. NOTE: 'RMSE_pred_time' actually holds the
    # standard deviation; the key is kept so existing MLflow runs stay comparable.
    steady_state_times = time_of_prediction_array[1:]
    speed_metrics = {'pred_time': steady_state_times.mean(),
                     'RMSE_pred_time': steady_state_times.std()}
    mlflow.log_metrics(speed_metrics)

## Model Test

In [None]:
# Load one test image and add a leading batch axis so it can be fed to model.predict.
test_image_path = '../data/processed/test/img/210416D.003_1400701D.E003_96860_2.jpg'
test_img = tf.reshape(
    read_data_img(test_image_path),
    [1, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS],
)

In [None]:
# Visual sanity check on the loaded test image: the input and the model's raw sigmoid output.
fig_input, ax_input = plt.subplots(figsize=(20, 10))
ax_input.imshow(test_img[0])
ax_input.set_title('Input Image')
ax_input.axis('off')

# verbose=0 keeps the predict progress bar out of the cell output.
mask = model.predict(test_img, verbose=0)
fig_mask, ax_mask = plt.subplots(figsize=(20, 10))
ax_mask.imshow(mask[0])
ax_mask.set_title('Predicted Mask (probabilities)')
ax_mask.axis('off')

# Render both figures and avoid printing the AxesImage repr.
plt.show()