In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.utils import plot_model
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve, auc
from sklearn.utils import class_weight
import numpy as np
import os
import warnings
import pydot as pyd
import seaborn as sns
import pandas as pd
from itertools import cycle
from timeit import default_timer as timer
import matplotlib.pyplot as plt

warnings.simplefilter(action = 'ignore', category = FutureWarning)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID";

# GPUid to use
os.environ["CUDA_VISIBLE_DEVICES"] = "0";

# Allow growth of GPU memory, otherwise it will always look like all the memory is being used
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [None]:
# Insert initial parameters
batch_size = 64
img_height, img_width = 224, 224
n_classes = 3

# Data augmentation
train_datagen = ImageDataGenerator(horizontal_flip = True, 
                                   vertical_flip = True,
                                   brightness_range = [0.5, 1.5],
                                   samplewise_center = True,
                                   rescale = 1./255)

test_datagen = ImageDataGenerator(samplewise_center = True, rescale = 1./255)

train = train_datagen.flow_from_directory('/local/data1/elech646/Tumor_grade_classification/dataset224_t1_sagittal/train', 
                                          classes = ['G2','G3','G4'], color_mode = 'rgb', 
                                          class_mode = 'categorical', 
                                          target_size = (img_height, img_width), 
                                          batch_size = batch_size, seed = 123)
validation = test_datagen.flow_from_directory('/local/data1/elech646/Tumor_grade_classification/dataset224_t1_sagittal/val', 
                                              classes = ['G2','G3','G4'], color_mode = 'rgb',
                                              class_mode = 'categorical', 
                                              target_size = (img_height, img_width), 
                                              batch_size = batch_size, seed = 123)
test = test_datagen.flow_from_directory('/local/data1/elech646/Tumor_grade_classification/dataset224_t1_sagittal/test', 
                                        classes = ['G2','G3','G4'], color_mode = 'rgb',
                                        shuffle = False, class_mode = 'categorical', 
                                        target_size = (img_height, img_width), 
                                        batch_size = batch_size)

In [11]:
# Use class weights
class_weights = class_weight.compute_class_weight(class_weight = 'balanced', 
                                                  classes = np.unique(train.classes), 
                                                  y = train.classes)

# Convert to dictionary
class_weights = dict(enumerate(class_weights))

In [None]:
x_train, y_train = next(iter(train))
print(x_train.shape, y_train.shape)
print(x_train.min())

# Plot images for sanity checking
def plot_images(images):
    fig, axes = plt.subplots(1, 5, figsize = (20, 20))
    axes = axes.flatten()
    for img, ax in enumerate(axes):
        ax.imshow(images[img,:,:], cmap = 'gray')
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    
plot_images(x_train[:1000])

In [None]:
# Load base model
resnet_50 = ResNet50(input_shape = (224, 224, 3), weights = 'imagenet', include_top = False)

freeze_until_layer = 70

# Freeze all layers before the `freeze_until_layer` layer
for layer in resnet_50.layers[:-freeze_until_layer]:
    layer.trainable = False

x = resnet_50.output
x = Flatten()(x)
x = Dense(20, activation = 'relu')(x)
x = Dropout(0.3)(x)
predictions = Dense(n_classes, activation = 'softmax')(x)

model = Model(inputs = resnet_50.input, outputs = predictions)
model.summary()

In [15]:
# Save best model
checkpoint_path = '/local/data1/elech646/code/train_logs/resnet50_transfer_t1_sagittal.h5'
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_path,
                                                monitor = 'val_accuracy',
                                                mode = 'max',
                                                verbose = 1,
                                                save_best_only = True)

# Save log for history
# append: True: append if file exists (useful for continuing training)
#         False: overwrite existing file
csv_logger = CSVLogger('/local/data1/elech646/code/train_logs/resnet50_transfer_history_t1_sagittal.log', 
                       separator = ',', append = True)

# Reduce learning rate if val_accuracy is not improving
reduce_lr = ReduceLROnPlateau(monitor = 'val_accuracy', factor = 0.1,
                              patience = 5, min_lr = 0.000001)
es = EarlyStopping(monitor = 'val_accuracy', verbose = 1, 
                   patience = 15)

In [None]:
# Compile model
model.compile(Adam(lr = 1e-5),
              loss = 'categorical_crossentropy',
              metrics = ['accuracy'])

epochs = 35
start = timer()

history = model.fit(train, steps_per_epoch = len(train.labels) // batch_size, verbose = 1,
                    epochs = epochs, validation_data = validation,
                    validation_steps = len(validation.labels) // batch_size,
                    callbacks = [es, reduce_lr, checkpoint, csv_logger])

end = timer()
print("Training time: %.2f s\n" % (end - start))

In [None]:
# serialize model to JSON
model_json = model.to_json()
with open("resnet50_transfer_t1_sagittal.json", "w") as json_file:
    json_file.write(model_json)
    
# serialize weights to HDF5
model.save_weights("resnet50_transfer_t1_sagittal.h5")

In [None]:
# load json and create model
json_file = open('resnet50_transfer_t1_sagittal.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)

# load weights into new model
loaded_model.load_weights("resnet50_transfer_t1_sagittal.h5")
print("Loaded model from disk")

1st training:  

Setup: `x = resnet_50.output
        x = Flatten()(x)
        x = Dense(20, activation = 'relu')(x)
        x = Dropout(0.3)(x)
        predictions = Dense(n_classes, activation = 'softmax')(x)`
       
batch size = 64 

number of epochs = 35

learning rate = 1e-5 with `ReduceLROnPlateau` + `EarlyStopping`

training time: 414.66 s $\approx 6$ min

test accuracy: 0.6406 (overfit)

------------------------------------------------------------------------------------------------------------------

2nd training: 

Setup: `x = resnet_50.output
        x = Flatten()(x)
        x = Dense(20, activation = 'relu')(x)
        x = Dropout(0.3)(x)
        predictions = Dense(n_classes, activation = 'softmax')(x)` 
        
froze bottom 100 layers instead

batch size = 64 

number of epochs = 30

learning rate = 1e-5 with `ReduceLROnPlateau` + `EarlyStopping`

training time: 380 s $\approx 6$ min

test accuracy: 70.83%

In [None]:
# Print test loss + accuracy 
score = model.evaluate(test, steps = len(test.labels) // batch_size, verbose = 0)
print('Test loss: %.4f' % score[0])
print('Test accuracy: %.4f' % score[1])

In [None]:
# Plot training + validation accuracy per epoch
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
n_epochs = range(len(acc))

# for fancy LaTeX style plots
from matplotlib import rc
import matplotlib.pylab as plt

rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
rc('text', usetex = True)

# Plot accuracy
plt.plot(n_epochs, acc, label = 'Training accuracy')
plt.plot(n_epochs, val_acc, label = 'Validation accuracy')
plt.title('ResNet50: T1')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc = 'best')
#plt.savefig('ResNet50_t1_sagittal_acc.png', dpi = 300)
plt.show()

# Plot loss
plt.plot(n_epochs, loss, label = 'Training loss')
plt.plot(n_epochs, val_loss, label = 'Validation loss')
plt.title('ResNet50: T1')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(loc = 'best')
#plt.savefig('ResNet50_t1_sagittal_loss.png', dpi = 300)
plt.show()

In [None]:
# Get true labels
y_true = test.classes

# Get rounded predictions
y_pred = np.argmax(model.predict(test), axis = 1) 

# Plot confusion matrix
cm = confusion_matrix(y_true, y_pred)
ax = sns.heatmap(cm, annot = True, fmt = 'g', cmap = 'PuBu')
#ax = sns.heatmap(cm / np.sum(cm), annot = True, fmt = '.2%', cmap = 'PuBu')
#for t in ax.texts: 
#    t.set_text(t.get_text().replace('%', '\%'))
ax.set_xlabel('\nPredicted values')
ax.set_ylabel('Actual values ');
# list must be in alphabetical order
ax.xaxis.set_ticklabels(['Grade 2','Grade 3', 'Grade 4'])
ax.yaxis.set_ticklabels(['Grade 2','Grade 3', 'Grade 4'])
plt.title('ResNet50: T1 \n accuracy = 71.35\%')
#plt.savefig('ResNet50_CM_t1_sagittal.png', dpi = 300)
plt.show()

In [None]:
# Print AUC + weighted AUC score
weighted_auc = roc_auc_score(y_true, y_score, multi_class = 'ovr', average = 'weighted')
auc_score = roc_auc_score(y_true, y_score, multi_class = 'ovr')
print(f'AUC score: {auc_score:.4f}')
print(f'Weighted AUC score: {weighted_auc:.4f}')

In [None]:
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
n_classes = 3

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true, y_score[:, i], pos_label = i)
    roc_auc[i] = auc(fpr[i], tpr[i])

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at this point
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
lw = 2

# plt.figure()
# plt.plot(fpr["macro"], tpr["macro"],
#     label = "macro-average ROC curve (AUC = {0:0.2f})".format(roc_auc["macro"]),
#     color = "navy",
#     linestyle = ":",
#     linewidth = 4,
# )

colors = cycle(["orange", "#9A0EEA", "#06C2AC"])
plt.plot([0, 1], [0, 1], "--", lw = lw, color = "#808080")
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color = color, lw = lw,
             label = "G{0} (AUC = {1:0.2f})".format(i+2, roc_auc[i]))

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Multi-class ROC: ResNet50 on T1")
plt.legend(loc = "lower right")
#plt.savefig('Multi-class ROC - ResNet50 on T1', dpi = 300); 
plt.show()

In [None]:
# Print classification report
print(classification_report(y_true, y_pred))

# Save the classification report
# clsf_report = pd.DataFrame(classification_report(y_true = y_true, y_pred = y_pred, output_dict = True)).transpose()
# clsf_report.to_csv('Classification Report - ResNet50 on t1 sagittal.csv', index = True)