In [None]:
#!g1.1
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import keras
import tensorflow as tf
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Input, Conv2DTranspose, Concatenate, BatchNormalization, UpSampling2D
from keras.layers import  Dropout, Activation
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from keras import backend as K
# from keras.utils import plot_model
import glob
import random
import cv2
from random import shuffle, randint, getrandbits
# Pin this notebook to the first GPU. `tf.device(...)` only takes effect inside a
# `with` block, so device selection is done via the visible-device list instead.
# This must run before any GPU op initialises the runtime; on a re-run after that,
# TensorFlow raises RuntimeError, which is reported rather than crashing the cell.
_gpus = tf.config.list_physical_devices('GPU')
if _gpus:
    try:
        tf.config.set_visible_devices(_gpus[0], 'GPU')
    except RuntimeError as err:
        print(f'GPU visibility unchanged (devices already initialised): {err}')
print('GPUs found:', _gpus)

In [None]:
#!g1.1

def mean_iou(y_true, y_pred):
    """Batch-level foreground IoU on channel 0, used as a Keras metric.

    Both the target and the prediction are binarised at 0.5 before counting. Soft or
    interpolated mask values are therefore treated the same way in the intersection
    and in the union. Pixels are pooled over the whole batch (one IoU per batch),
    not averaged per image.

    Args:
        y_true: ground-truth tensor indexed as [batch, H, W, channel]; values
            presumably in [0, 1] (the training generator divides masks by 255).
        y_pred: predicted probabilities with the same layout.

    Returns:
        Scalar float32 IoU; 1.0 when both target and prediction are empty.
    """
    yt0 = K.cast(y_true[:, :, :, 0] > 0.5, 'float32')
    yp0 = K.cast(y_pred[:, :, :, 0] > 0.5, 'float32')
    inter = tf.math.count_nonzero(yt0 * yp0)
    union = tf.math.count_nonzero(yt0 + yp0)
    # tf.where evaluates both branches; divide_no_nan keeps the empty case from producing NaN.
    iou = tf.math.divide_no_nan(tf.cast(inter, 'float32'), tf.cast(union, 'float32'))
    return tf.where(tf.equal(union, 0), 1., iou)

In [None]:
#!g1.1
def unet(sz = (256, 256, 3), depth = 6, base_filters = 8, up_filters = 64):
    """Build and compile a U-Net for binary segmentation.

    Structure:
        Encoder: `depth` levels of two 3x3 ReLU convolutions followed by 2x2
            max-pooling. Filters start at `base_filters` and double at every level.
        Decoder: starting at the bottleneck, each level applies two 3x3
            convolutions, a stride-2 transposed convolution and a concatenation
            with the encoder output of the same resolution. Transposed convolutions
            use `up_filters` channels at the deepest level, halved at each level above.
        Head: two 3x3 convolutions and a 1x1 sigmoid convolution (one channel).

    The defaults reproduce the original hard-coded architecture exactly (same layers
    created in the same order), so previously saved weights still load.

    Args:
        sz: input shape (H, W, C).
        depth: number of pooling levels; H and W must be divisible by 2**depth.
        base_filters: filters of the first encoder level.
        up_filters: filters of the deepest transposed convolution.

    Returns:
        A `Model` compiled with RMSprop, binary cross-entropy and the `mean_iou` metric.

    Raises:
        ValueError: if `depth` < 1, if `up_filters` would be halved to zero, or if a
            known spatial dimension of `sz` is not divisible by 2**depth (the skip
            concatenations would then fail).
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if (up_filters >> (depth - 1)) < 1:
        raise ValueError(f"up_filters={up_filters} is too small for depth={depth}")
    for dim in sz[:2]:
        if dim is not None and dim % (2 ** depth):
            raise ValueError(f"spatial size {tuple(sz[:2])} must be divisible by 2**depth = {2 ** depth}")

    inputs = Input(sz)
    x = inputs

    # Encoder: keep each level's output for the skip connections.
    skips = []
    filters = base_filters
    for _ in range(depth):
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        skips.append(x)
        x = MaxPooling2D()(x)
        filters *= 2

    # Bottleneck (level 0) and decoder: each level upsamples by 2 and joins the
    # encoder output of the matching resolution, deepest first.
    up = up_filters
    for level, skip in enumerate(reversed(skips)):
        if level > 0:
            up //= 2
            filters //= 2
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        x = Conv2DTranspose(up, 2, strides=(2, 2), padding='same')(x)
        x = Concatenate(axis=3)([x, skip])

    # Classification head: per-pixel foreground probability.
    x = Conv2D(filters, 3, activation='relu', padding='same')(x)
    x = Conv2D(filters, 3, activation='relu', padding='same')(x)
    outputs = Conv2D(1, 1, activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=[mean_iou])
    return model

In [None]:
#!g1.1
# Build the compiled U-Net for a 256x256 RGB input and print its layer table.
model = unet(sz=(256, 256, 3))
model.summary()

In [None]:
#path_dataset = "datasets/segm_dataset_learn"
#folder_chkpt = "segm_dataset_learn"
# Dataset selection: images are read from <path_dataset>/images, masks with the same
# file names from <path_dataset>/masks; checkpoints go under chkpt/Unet/<folder_chkpt>.
datasets_name = "dataset_grain_0.1"
path_dataset = "datasets/Grain/" + datasets_name
folder_chkpt = "Segments/" + datasets_name

In [None]:
#!g1.1
def image_generator(files, batch_size = 32, sz = (256, 256)):
    """Endlessly yield random (image, mask) batches read from `path_dataset`.

    Each step draws `batch_size` names from `files` with replacement
    (np.random.choice). For each name it loads `{path_dataset}/images/<name>` and
    `{path_dataset}/masks/<name>` and resizes both to `sz`. When `sz` is square,
    the pair is rotated 90 degrees clockwise with probability 1/2 (augmentation).

    Args:
        files: file names present in both the images/ and masks/ folders.
        batch_size: number of samples per yielded batch.
        sz: target (width, height) passed to PIL / cv2 resize.

    Yields:
        (batch_x, batch_y):
            batch_x: shape (batch_size, H, W, 3), scaled to [0, 1].
            batch_y: shape (batch_size, H, W, 1), the mask's first channel / 255.

    Raises:
        ValueError: if `files` is empty (raised on the first `next()`, since this
            is a generator).
    """
    if len(files) == 0:
        raise ValueError("image_generator: 'files' is empty, nothing to sample from")

    # A 90-degree rotation swaps width and height, which only keeps batch shapes
    # consistent for square targets.
    can_rotate = sz[0] == sz[1]

    while True:
        batch = np.random.choice(files, size=batch_size)
        batch_x = []
        batch_y = []

        for f in batch:
            # Mask: nearest-neighbour resize keeps the labels binary. PIL's default
            # filter (bicubic in recent Pillow) would blur edges into intermediate
            # values. Only the first channel is used.
            with Image.open(f'{path_dataset}/masks/{f}') as mask_img:
                mask = np.array(mask_img.convert('RGB').resize(sz, Image.NEAREST))[:, :, 0]

            # Image: force 3 channels (drops alpha, expands grayscale/palette) so it
            # matches the model's (H, W, 3) input.
            with Image.open(f'{path_dataset}/images/{f}') as raw_img:
                raw = np.array(raw_img.convert('RGB'))
            raw = cv2.resize(raw, sz, interpolation=cv2.INTER_CUBIC)

            # Apply the same random rotation to image and mask.
            if can_rotate and random.getrandbits(1):
                mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
                raw = cv2.rotate(raw, cv2.ROTATE_90_CLOCKWISE)

            batch_x.append(raw)
            batch_y.append(mask)

        batch_x = np.array(batch_x) / 255.
        batch_y = np.expand_dims(np.array(batch_y) / 255., 3)
        yield (batch_x, batch_y)

In [None]:
#!g1.1
batch_size = 5
epochs = 150
# Fixed seed: the train/test split is identical in every session and every
# fine-tuning stage, so images used for training in an earlier stage do not
# reappear in this stage's validation set.
split_seed = 42

# Keep only regular, non-hidden files (skips e.g. `.ipynb_checkpoints`, `.DS_Store`).
# Sort them so the seeded shuffle does not depend on the filesystem's listing order.
all_files = sorted(
    name for name in os.listdir(f'{path_dataset}/images')
    if not name.startswith('.') and os.path.isfile(f'{path_dataset}/images/{name}')
)
random.Random(split_seed).shuffle(all_files)
split = int(0.85 * len(all_files))
# 85% training / 15% testing
train_files = all_files[0:split]
test_files  = all_files[split:]
assert train_files and test_files, (
    f'need at least one train and one test image, got {len(train_files)} / {len(test_files)}'
)
train_generator = image_generator(train_files, batch_size = batch_size)
test_generator  = image_generator(test_files, batch_size = batch_size)


In [None]:
#!g1.1
n_learn = 3  # index of the current training stage; outputs go to .../learn_<n_learn>


def build_callbacks():
    """Create the output folder for this stage and return the training callbacks.

    Returns:
        [ModelCheckpoint, PlotLearning]. The checkpoint keeps only the weights with
        the highest `val_mean_iou` seen so far, in
        chkpt/Unet/<folder_chkpt>/learn_<n_learn>/.
    """
    run_dir = f'chkpt/Unet/{folder_chkpt}/learn_{n_learn}'
    os.makedirs(run_dir, exist_ok=True)
    # Keras fills in the {epoch}/{val_loss}/{val_mean_iou} placeholders when saving.
    ckpt_template = run_dir + '/model_{epoch:02d}loss_{val_loss:.3f}_mi_{val_mean_iou:.2f}.h5'
    return [
        ModelCheckpoint(
            filepath=ckpt_template,
            verbose=1,
            monitor='val_mean_iou',
            mode="max",
            save_best_only=True,
            save_weights_only=True,
        ),
        PlotLearning(),
    ]

# Keras callback (subclass of tf.keras.callbacks.Callback) that records metrics and plots training progress
class PlotLearning(tf.keras.callbacks.Callback):
    """Keras callback that tracks loss / mean IoU and visualises predictions.

    After every epoch it prints the logged metrics. It also shows one random image
    from the global `test_files` (read from `{path_dataset}/images` and
    `{path_dataset}/masks`), with the ground-truth contour drawn in red and the
    predicted contour in blue.

    When training ends it plots the metric curves and saves them to
    chkpt/Unet/<folder_chkpt>/learn_<n_learn>.png. It writes the per-epoch history
    to the matching .csv.
    """

    # (width, height) of the preview image; matches the model's default input size.
    image_size = (256, 256)

    def on_train_begin(self, logs=None):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        self.acc = []
        self.val_acc = []
        self.logs = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.acc.append(logs.get('mean_iou'))
        self.val_acc.append(logs.get('val_mean_iou'))
        self.i += 1
        print('i=',self.i,'loss=',logs.get('loss'),'val_loss=',logs.get('val_loss'),'mean_iou=',logs.get('mean_iou'),'val_mean_iou=',logs.get('val_mean_iou'))
        self._show_random_prediction()

    def _show_random_prediction(self):
        """Overlay ground-truth (red) and predicted (blue) contours on a random test image."""
        path = np.random.choice(test_files)

        # Same preprocessing as the training generator: 3 channels, bicubic resize.
        with Image.open(f'{path_dataset}/images/{path}') as img:
            raw = np.array(img.convert('RGB'))
        raw = cv2.resize(raw, self.image_size, interpolation=cv2.INTER_CUBIC)

        # Nearest-neighbour keeps the mask binary; any non-zero first-channel pixel is foreground.
        with Image.open(f'{path_dataset}/masks/{path}') as mask_img:
            mask_true = np.array(mask_img.convert('RGB').resize(self.image_size, Image.NEAREST))[:, :, 0]
        true_bin = (mask_true > 0).astype(np.uint8)

        # Predict with the model being trained (set by Keras), not a notebook global.
        pred = self.model.predict(np.expand_dims(raw / 255., 0), verbose=0)
        pred_bin = (pred.squeeze() >= 0.5).astype(np.uint8)

        contours_true, _ = cv2.findContours(true_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours_pred, _ = cv2.findContours(pred_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # `raw` is RGB uint8 and matplotlib displays RGB, so (255, 0, 0) is red and
        # (0, 0, 255) is blue. Drawing on uint8 avoids imshow clipping float values > 1.
        result_img = raw.copy()
        cv2.drawContours(result_img, contours_true, -1, (255, 0, 0), 1)  # ground truth: red
        cv2.drawContours(result_img, contours_pred, -1, (0, 0, 255), 1)  # prediction: blue
        plt.imshow(result_img)
        plt.title(f'{path}: truth (red) vs prediction (blue)')
        plt.axis('off')
        plt.show()

    def on_train_end(self, logs=None):
        if not self.x:
            # Training stopped before the first epoch finished: nothing to plot or save.
            return
        import pandas as pd
        from matplotlib.ticker import MaxNLocator

        out_prefix = f'chkpt/Unet/{folder_chkpt}/learn_{n_learn}'
        fig, axs = plt.subplots(2)
        panels = (
            (axs[0], self.acc, self.val_acc, "Mean IoU"),
            (axs[1], self.losses, self.val_losses, "Loss (binary cross-entropy)"),
        )
        for ax, train_vals, val_vals, ylabel in panels:
            ax.plot(self.x, train_vals, marker = 'o', ms=4, label="Train")
            ax.plot(self.x, val_vals, marker = 'o', ms=4, label="Validation")
            ax.set_xlabel("Epochs, Num")
            ax.set_ylabel(ylabel)
            ax.legend()
            # Integer epoch ticks without one label per epoch (which overlap for long runs).
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.grid()
        fig.tight_layout()
        fig.savefig(out_prefix + '.png')

        data = {
            'epochs': self.x,
            'train_acc': self.acc,
            'val_acc': self.val_acc,
            'train_loss': self.losses,
            'valid_loss': self.val_losses
        }
        df = pd.DataFrame(data)
        df.to_csv(out_prefix + '.csv')

        plt.show()

In [None]:
#!g1.1
# Sanity check: pull one training batch and verify that it matches the model's
# input/output shapes and is scaled to [0, 1] before starting a long training run.
x, y = next(train_generator)
assert x.shape == (batch_size,) + tuple(model.input_shape[1:]), f'unexpected image batch shape {x.shape}'
assert y.shape == (batch_size,) + tuple(model.output_shape[1:]), f'unexpected mask batch shape {y.shape}'
assert 0.0 <= x.min() and x.max() <= 1.0, 'images must be scaled to [0, 1]'
assert 0.0 <= y.min() and y.max() <= 1.0, 'masks must be scaled to [0, 1]'

In [None]:
#!g1.1
# One pass over the training files per epoch. Validation covers roughly the whole
# test split rather than a single batch, so the `val_mean_iou` that ModelCheckpoint
# uses to pick the best weights is less dominated by sampling noise.
train_steps = max(1, len(train_files)//batch_size)
test_steps = max(1, len(test_files)//batch_size)

# Stage `n_learn` continues from the final weights of stage `n_learn - 1`;
# stage 1 trains from scratch because there is no earlier checkpoint.
if n_learn > 1:
    model.load_weights(f'chkpt/Unet/{folder_chkpt}/learn_{n_learn-1}/model_final[WITH_UP].h5')

# `Model.fit` accepts Python generators directly; `fit_generator` is deprecated.
model.fit(train_generator,
          epochs = epochs, steps_per_epoch = train_steps, validation_data = test_generator, validation_steps = test_steps,
          callbacks = build_callbacks(), verbose = 0)
model.save(f'chkpt/Unet/{folder_chkpt}/learn_{n_learn}/model_final[WITH_UP].h5')