In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons.metrics import CohenKappa
from tensorflow_addons.losses import SigmoidFocalCrossEntropy
from sklearn.model_selection import train_test_split

In [2]:
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpu_devices[0], False)

In [3]:
model_name = 'XCeption_Mask'

In [4]:
home = os.path.expanduser('~')
base = os.path.join('Datasets', 'ImageCLEF', 'Mean_Slice_Masks')

train_dir = os.path.join(home, base, 'train')
test_dir = os.path.join(home, base, 'test')

In [5]:
seed = 42
shuffle = True
inp_shp = (299, 299)
train_batch_size, val_batch_size = 8, 64

train_datagen = keras.preprocessing.image.ImageDataGenerator(
                    rescale=1./255,
                    horizontal_flip=True
)

val_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=inp_shp,
        batch_size=train_batch_size,
        seed=seed,
        class_mode='categorical',
        color_mode='rgb',
        shuffle=shuffle
)

val_generator = val_datagen.flow_from_directory(
        test_dir,
        target_size=inp_shp,
        batch_size=val_batch_size,
        seed=seed,
        class_mode='categorical',
        color_mode='rgb',
        shuffle=shuffle
)

Found 731 images belonging to 5 classes.
Found 184 images belonging to 5 classes.


In [6]:
input_shape = (299, 299, 3)

XCeption = keras.applications.Xception(
    include_top=False,
    weights="imagenet",
    input_tensor=None,
    input_shape=input_shape,
    pooling=None,
)

XCeption.trainable = False
XCeption.summary()

Model: "xception"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 149, 149, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, 149, 149, 32) 128         block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, 149, 149, 32) 0           block1_conv1_bn[0][0]            
___________________________________________________________________________________________

In [7]:
# CONv/FC -> BatchNorm -> ReLU(or other activation) -> Dropout -> CONV/FC -> ...

def get_model(base_model, input_shape):
    
    inputs = keras.Input(shape=input_shape)
    x = base_model(inputs, training=False)
    
    x = layers.Conv2D(128, (3, 3), activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Dropout(rate=0.5)(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    
    x = layers.Flatten()(x)
    
    x = layers.Dense(units=1024, activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Dropout(rate=0.5)(x)
    x = layers.Dense(units=64, activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Dropout(rate=0.5)(x)
    x = layers.Dense(units=5, activation=None)(x)
    output = layers.Softmax()(x)

    model = keras.Model(inputs=inputs, outputs=output, name=f'{model_name}')
    
    return model
#     model.compile(optimizer="Adam", loss="mse", metrics=["mae"])

    

In [8]:
model = get_model(XCeption, input_shape)
model.summary()

Model: "XCeption_Mask"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 299, 299, 3)]     0         
_________________________________________________________________
xception (Model)             (None, 10, 10, 2048)      20861480  
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 128)         2359424   
_________________________________________________________________
batch_normalization_4 (Batch (None, 8, 8, 128)         512       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 8, 8, 128)         0         
_________________________________________________________________
dropout (Dropout)            (None, 8, 8, 128)         0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 4, 4, 128)       

In [9]:
model.compile(
    optimizer="Adam", 
    loss='sigmoid_focal_crossentropy',
    metrics=['accuracy']
)

In [10]:
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    f"{model_name}.h5", save_best_only=True
)

early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_loss", patience=64)

history = model.fit(
            train_generator,
            steps_per_epoch=64,
            epochs=64,
            validation_data=val_generator,
            validation_steps=3,
            shuffle=False,
#             callbacks=[checkpoint_cb, early_stopping_cb]
)

Train for 64 steps, validate for 3 steps
Epoch 1/64
Epoch 2/64
Epoch 3/64
Epoch 4/64
Epoch 5/64
Epoch 6/64
Epoch 7/64
Epoch 8/64
Epoch 9/64
11/64 [====>.........................] - ETA: 4s - loss: 0.0695 - accuracy: 0.2500

KeyboardInterrupt: 

In [None]:
plt.plot(history.history['val_accuracy'])

In [None]:
np.max(history.history['val_accuracy'])

In [None]:
model_load = keras.models.load_model('NASNet_Transfer_Mask_2.h5')

In [None]:
model_load.evaluate(val_generator)

In [None]:
with open(f'{model_name}_history.pkl', 'wb') as fh:
    pickle.dump(history.history, fh)