In [1]:
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K
import os
import cv2
import glob
import PIL
import shutil
import numpy as np
import pandas as pd
from Datagens import *
from AttentionUnet import *
from metrics import *

# ml libs
from tensorflow import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.layers.experimental import preprocessing

np.set_printoptions(precision=3, suppress=True)

CSVLogger.on_test_begin = CSVLogger.on_train_begin
CSVLogger.on_test_batch_end = CSVLogger.on_epoch_end
CSVLogger.on_test_end = CSVLogger.on_train_end

In [2]:
model = build_attention_unet(n_channels=3, ker_init='he_normal', dropout=0.2)
model.compile(loss="categorical_crossentropy", optimizer=keras.optimizers.Adam(learning_rate=0.001), metrics = ['accuracy',tf.keras.metrics.MeanIoU(num_classes=4), dice_coef, dice_coef_necrotic, dice_coef_edema ,dice_coef_enhancing] )

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 32  896         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 32  9248        ['conv2d[0][0]']                 
                                )                                                             

 conv2d_17 (Conv2D)             (None, 16, 16, 64)   8256        ['conv2d_16[0][0]']              
                                                                                                  
 activation_3 (Activation)      (None, 16, 16, 64)   0           ['conv2d_17[0][0]']              
                                                                                                  
 conv2d_19 (Conv2D)             (None, 16, 16, 64)   4160        ['activation_3[0][0]']           
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 16, 16, 64)  36928       ['conv2d_19[0][0]']              
 spose)                                                                                           
                                                                                                  
 conv2d_18 (Conv2D)             (None, 16, 16, 64)   16448       ['conv2d_5[0][0]']               
          

                                                                                                  
 conv2d_30 (Conv2D)             (None, 64, 64, 32)   9248        ['dropout_4[0][0]']              
                                                                                                  
 conv2d_31 (Conv2D)             (None, 64, 64, 16)   528         ['conv2d_30[0][0]']              
                                                                                                  
 activation_9 (Activation)      (None, 64, 64, 16)   0           ['conv2d_31[0][0]']              
                                                                                                  
 conv2d_33 (Conv2D)             (None, 64, 64, 16)   272         ['activation_9[0][0]']           
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 64, 64, 16)  2320        ['conv2d_33[0][0]']              
 spose)   

In [3]:
directory = [f.path for f in os.scandir(PATH) if f.is_dir()]


def generate_ids(dirLst):
    ids = []
    for i in range(0,len(dirLst)):
        ids.append(dirLst[i][dirLst[i].rfind('/')+1:])
    return ids

train_and_test_ids = generate_ids(directory); 

    
train_test_ids, val_ids = train_test_split(train_and_test_ids, train_size=236, test_size=60, random_state=14) 
train_ids, test_ids = train_test_split(train_test_ids, test_size=36, random_state=14) 

In [4]:
training_generator = DataGenerator2D(train_ids)
valid_generator = DataGenerator2D(val_ids)
test_generator = DataGenerator2D(test_ids)

In [5]:
csv_logger = CSVLogger('Logs/training/attU_net_2d.log', separator=',', append=False)

eval_csv_logger = CSVLogger('Logs/testing/attU_net_2d.log', separator=',', append=False)


callbacks = [
      keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=2, min_lr=0.000001, verbose=1),
        csv_logger
    ]

In [6]:
K.clear_session()

history =  model.fit(training_generator,
                    epochs=100,
                    steps_per_epoch=len(train_ids),
                    callbacks= callbacks,
                    validation_data = valid_generator
                    )  
model.save("Models/attU_net_2d.h5")

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 4: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 5/100
Epoch 6/100
Epoch 6: ReduceLROnPlateau reducing learning rate to 4.0000001899898055e-05.
Epoch 7/100
Epoch 8/100
Epoch 8: ReduceLROnPlateau reducing learning rate to 8.000000525498762e-06.
Epoch 9/100
Epoch 10/100
Epoch 10: ReduceLROnPlateau reducing learning rate to 1.6000001778593287e-06.
Epoch 11/100
Epoch 12/100
Epoch 12: ReduceLROnPlateau reducing learning rate to 1e-06.
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100


Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100


Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100


Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100


Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100


Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


In [None]:
# model = keras.models.load_model('Models/2D/attU_net_2d.h5', 
#                                    custom_objects={ 'accuracy' : tf.keras.metrics.MeanIoU(num_classes=4),
#                                                    "dice_coef": dice_coef,
#                                                    "dice_coef_necrotic": dice_coef_necrotic,
#                                                    "dice_coef_edema": dice_coef_edema,
#                                                    "dice_coef_enhancing": dice_coef_enhancing
#                                                   }, compile=True)

In [7]:
model.evaluate(test_generator, callbacks=eval_csv_logger, verbose=1)



[0.13991454243659973,
 0.955700695514679,
 0.8151941299438477,
 0.35510438680648804,
 0.1707843542098999,
 0.35608401894569397,
 0.21818523108959198]

In [8]:
valid_csv_logger = CSVLogger('Logs/validation/attU_net_2d.log', separator=',', append=False)
model.evaluate(valid_generator, callbacks=valid_csv_logger, verbose=1)



[0.12318429350852966,
 0.9648918509483337,
 0.821353554725647,
 0.357617050409317,
 0.16789422929286957,
 0.3655255436897278,
 0.17907029390335083]