In [1]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import os
import PIL
from datetime import datetime
from functools import partial
import seaborn as sns

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Activation, Dense, Flatten, Dropout
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras import Model

from skimage import io, color, filters
from skimage.transform import resize, rotate
from sklearn.metrics import confusion_matrix, accuracy_score

In [13]:
# Dataset directories (one per split) and the batch size used by the
# image generators that read from them.
train_path, test_path, valid_path = (
    'disease_data3/' + split for split in ('train', 'test', 'valid')
)
bs = 32

In [14]:
# The images are fed to an ImageNet-pretrained VGG16 whose convolutional
# weights stay frozen, so inputs must be normalised the way VGG16 was trained
# (vgg16.preprocess_input) rather than scaled to [0, 1] with rescale=1/255.
vgg16_preprocess = tf.keras.applications.vgg16.preprocess_input
img_size = (224, 224)  # spatial input size of the VGG16 model

# Training data: random flips, shear and rotation for augmentation.
train_batches = ImageDataGenerator(
    preprocessing_function=vgg16_preprocess,
    horizontal_flip=True,
    vertical_flip=True,
    shear_range=0.2,
    rotation_range=60,
).flow_from_directory(train_path, target_size=img_size, batch_size=bs)

# Test data: no augmentation. shuffle=False keeps the batch order aligned with
# `test_batches.classes`, which is required when predictions are later compared
# against the true labels (e.g. confusion_matrix / accuracy_score).
test_batches = ImageDataGenerator(
    preprocessing_function=vgg16_preprocess,
).flow_from_directory(test_path, target_size=img_size, batch_size=bs, shuffle=False)

# Validation data: no augmentation.
valid_batches = ImageDataGenerator(
    preprocessing_function=vgg16_preprocess,
).flow_from_directory(valid_path, target_size=img_size, batch_size=bs)

Found 8747 images belonging to 10 classes.
Found 1158 images belonging to 10 classes.
Found 1208 images belonging to 10 classes.


In [17]:
# Transfer learning: reuse the pretrained VGG16 layers as a frozen feature
# extractor and train a small dense classifier on top for the 10 classes.
vgg = tf.keras.applications.vgg16.VGG16()
vgg16_model = Sequential()

# Copy every pretrained layer EXCEPT the last one. VGG16() with default
# arguments ends in a 1000-way ImageNet softmax ("predictions"); stacking the
# new head on top of that would make it learn from ImageNet class
# probabilities instead of the 4096-d fc2 features.
for layer in vgg.layers[:-1]:
    layer.trainable = False  # keep the pretrained weights fixed
    vgg16_model.add(layer)

# Defining Additional Model Layers (the only trainable part of the network)
vgg16_model.add(Dense(512, activation='relu'))
vgg16_model.add(Dropout(.2)) # helps prevent overfitting
vgg16_model.add(Dense(256, activation='relu'))
vgg16_model.add(Dropout(.2))
vgg16_model.add(Dense(10, activation='softmax'))  # one output per class
vgg16_model.compile(loss='categorical_crossentropy', optimizer=Adam(.001), metrics=['accuracy'])

# Early stopping: watch validation accuracy, give up after 20 epochs without
# improvement, and restore the weights from the best epoch seen.
vgg16_es = EarlyStopping(
    monitor='val_accuracy',
    patience=20,
    restore_best_weights=True,
    verbose=1,
)

# TensorBoard logging: write event files to 'logdir', with histograms
# computed every epoch.
vgg16_tb_callback = TensorBoard(
    histogram_freq=1,
    log_dir='logdir',
)

# Checkpoint file pattern; `{epoch:04d}` is filled in with the zero-padded
# epoch number each time a checkpoint is written.
checkpoint_path = "drive/MyDrive/capstone3/checkpoints/vgg16_model/cp-{epoch:04d}.ckpt" # all the trained weights
checkpoint_dir = os.path.dirname(checkpoint_path)

# Make sure the destination directory exists before training starts, so a
# missing folder cannot cause a failure when the first checkpoint is saved.
os.makedirs(checkpoint_dir, exist_ok=True)

# Create checkpoint callback: at the end of each epoch, save the weights only,
# and only when validation accuracy has improved on the best value so far.
vgg16_checkpoint = ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_best_only=True,  # only save the best epoch
    monitor='val_accuracy',
    save_freq='epoch',
)

In [18]:
# Print the model's layer-by-layer table (layer type, output shape, parameter count).
vgg16_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)      

In [19]:
# Train on the training generator for at most 100 epochs, evaluating on the
# validation generator each epoch. The callbacks handle checkpointing, early
# stopping and TensorBoard logging; the returned history is kept for later use.
vgg16_history = vgg16_model.fit(
    x=train_batches,
    epochs=100,
    verbose=1,
    validation_data=valid_batches,
    callbacks=[
        vgg16_checkpoint,
        vgg16_es,
        vgg16_tb_callback,
    ],
)

Epoch 1/100
  7/274 [..............................] - ETA: 17:57 - loss: 2.3009 - accuracy: 0.1027

KeyboardInterrupt: 