## Training Example

Note: run this notebook after the Download Data notebook.

In [None]:
import tensorflow as tf
import os
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, BatchNormalization, Dropout,RandomFlip, RandomRotation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint
import pandas as pd
import shutil
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, RocCurveDisplay
from tensorflow.keras.applications import Xception 
from tensorflow.keras.applications.xception import preprocess_input
from cnn_utils import *

In [None]:
# Root data directory handed to load_data() in the training cell.
# NOTE(review): presumably populated by the Download Data notebook — confirm.
fp = "Data"

In [None]:
#Set up metrics
metrics = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR') # precision-recall curve
      ]

In [None]:
# Build the classifier: a frozen ImageNet Xception backbone plus a small trainable
# binary head (global average pooling -> dropout -> single sigmoid unit).
IMG_SHAPE = (256, 256, 3)  # (height, width, channels) for both the backbone and the model input

base_model = Xception(
    weights='imagenet',
    input_shape=IMG_SHAPE,
    include_top=False,  # drop the ImageNet classifier; a binary head is attached below
)
base_model.trainable = False  # freeze backbone weights; only the new head is trained

# augment() comes from cnn_utils; presumably it returns a Keras augmentation
# layer/model (e.g. RandomFlip/RandomRotation) — TODO confirm.
aug = augment()

inputs = keras.Input(shape=IMG_SHAPE)
x = aug(inputs)
# Xception's preprocess_input scales pixels to [-1, 1] inside the graph.
# NOTE(review): assumes load_data yields raw pixel values in [0, 255] — confirm.
x = preprocess_input(x)
# training=False keeps the backbone's BatchNormalization layers in inference mode,
# even if the backbone is later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x)
outputs = Dense(1, activation='sigmoid')(x)  # one sigmoid unit -> probability of the positive class
model = keras.Model(inputs, outputs)
model.summary()

In [None]:
# Training configuration for this run.
epochs = 100
model_run = '8_run1'  # sub-directory name for this run's checkpoints
lr = 0.001
batch_size = 32

# Checkpoints for this run are written under model_checkpoints/<model_run>/.
outpath = os.path.join('model_checkpoints', model_run)
os.makedirs(outpath, exist_ok=True)  # creates the parent too; no-op if it already exists

callbacks = [
    # Stop after 2 epochs without improvement in the monitored metric (default 'val_loss')
    # and roll the model back to the best weights seen.
    EarlyStopping(patience=2, restore_best_weights=True),
    # Keep only the best checkpoint.
    # NOTE(review): Keras 3 requires a filepath ending in `.keras` (or `.weights.h5`);
    # a bare directory path only works with tf.keras 2.x (SavedModel format) — confirm version.
    ModelCheckpoint(outpath, save_best_only=True, verbose=False),
]

# --- Fit ---
AUTOTUNE = tf.data.AUTOTUNE
training, val = load_data(fp, batch_size)
# tf.data transformations return NEW datasets; the results must be reassigned.
# Otherwise cache/shuffle/prefetch are silently discarded.
# cache() with no filename keeps the dataset in memory after the first epoch.
# NOTE(review): if load_data already batches, shuffle(1000) reorders whole batches — confirm.
training = training.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val = val.cache().prefetch(buffer_size=AUTOTUNE)

opt = Adam(learning_rate=lr)
model.compile(optimizer=opt, loss=tf.keras.losses.BinaryCrossentropy(), metrics=metrics)

history = model.fit(training, validation_data=val, epochs=epochs, callbacks=callbacks)




