In [1]:
import pandas as pd
import numpy as np
import S2_0_Loading_Data

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization
from keras.regularizers import L2
from keras.callbacks import EarlyStopping
from tensorflow.keras.applications.resnet50 import ResNet50
import keras
import tensorflow as tf
import S4_0_Helper_Functions

In [2]:
train_x, test_x, train_y, test_y = S2_0_Loading_Data.load_data(image_size=(224, 224))
train_x.shape

(3257, 224, 224, 3)

In [3]:
pretrained_RN50 = ResNet50(weights='imagenet', include_top=False, input_shape=(224,224, 3))

In [4]:
weight_decay = 0.0005
for layer in pretrained_RN50.layers:
    layer.trainable = False
model = Sequential()
model.add(pretrained_RN50)
model.add(Flatten())
model.add(Dense(512, kernel_regularizer=L2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())

model.add(Dropout(0.1))
model.add(Dense(3))
model.add(Activation('softmax'))

In [5]:
# compile model
METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.CategoricalAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss='CategoricalCrossentropy',
              metrics=METRICS)

In [6]:
callback = EarlyStopping(monitor='val_prc', mode='max',min_delta = 0.01, verbose=1, patience=10,restore_best_weights=True)

model_hist = model.fit(
    train_x,
    train_y,
    validation_data = (test_x, test_y),
    epochs = 1, # change to 100 to get superior
    verbose = 1,
    callbacks = [callback]
)



In [7]:
preds = model.predict(
    test_x
)

S4_0_Helper_Functions.getAccuracyMetrics(preds, test_y)

accuracy: 0.9546012269938651
recall: 0.9546012269938651
precision: 0.9546012269938651
