<a href="https://colab.research.google.com/github/FlorenceBoutin/GI_disease_detection/blob/master/notebooks/florence_EfficientNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Build a CNN model from the EfficientNet architecture

## Imports

In [1]:
# Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tensorflow import keras

from keras import optimizers, regularizers, models, Sequential, layers, Model
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping
from keras.applications.efficientnet import *

from google.colab import drive

## Import data

In [2]:
# Mount Google Drive so the raw image folders referenced by `drive_path` are readable.
drive.mount(mountpoint='/content/gdrive/')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [3]:
# Root of the raw image data on the mounted Drive. The trailing slash is required:
# the split folders below are built by plain string concatenation (drive_path + 'train').
drive_path = (
    '/content/gdrive/My Drive/Colab Notebooks/GI_disease_detection/raw_data/'
)

In [4]:
# Each split folder holds one sub-directory per class (3 classes found, see output below).
train_folder = drive_path + 'train'
val_folder = drive_path + 'val'
test_folder = drive_path + 'test'

# No `rescale` here on purpose: Keras' EfficientNet models include their own
# Rescaling(1/255) layer and expect raw pixel values in [0, 255]. Rescaling in the
# generator as well would divide the pixels by 255 twice.
datagen = ImageDataGenerator()


def flow_from(folder, shuffle=True, target_size=(224, 224), batch_size=32):
    """Return a batch iterator of RGB images with one-hot labels from a class-per-subfolder directory.

    Args:
        folder: directory containing one sub-directory per class.
        shuffle: whether to shuffle the images (flow_from_directory's default is True).
        target_size: (height, width) every image is resized to; must match the model input.
        batch_size: number of images per batch.
    """
    return datagen.flow_from_directory(folder,
                                       target_size=target_size,
                                       color_mode="rgb",
                                       batch_size=batch_size,
                                       class_mode="categorical",
                                       shuffle=shuffle)


train_dataset = flow_from(train_folder)
val_dataset = flow_from(val_folder)
# shuffle=False keeps predictions aligned with test_dataset.classes / .filenames
# (needed for a confusion matrix).
test_dataset = flow_from(test_folder, shuffle=False)

Found 2406 images belonging to 3 classes.
Found 1500 images belonging to 3 classes.
Found 600 images belonging to 3 classes.


## Create model

In [5]:
# Explicit metric name: without it Keras auto-uniquifies the name ('recall_1',
# 'recall_2', ...) every time this cell is re-run. `es_recall` below monitors
# 'val_recall' and would then silently never trigger.
recall = keras.metrics.Recall(name='recall')

# Early stopping on the default monitor (val_loss).
es = EarlyStopping(patience=10, restore_best_weights=True)
# Early stopping on validation recall; mode='max' because higher recall is better.
es_recall = EarlyStopping(patience=5, restore_best_weights=True, monitor='val_recall', mode='max')

# Candidate Adam optimizers with different learning rates.
# NOTE(review): only adam_opt_1 is used in the visible cells. learning_rate=1
# (adam_opt_3) is far outside the usual range for Adam; confirm it is intended.
adam_opt = optimizers.Adam(learning_rate=0.0001)
adam_opt_1 = optimizers.Adam(learning_rate=0.01)
adam_opt_2 = optimizers.Adam(learning_rate=0.005)
adam_opt_3 = optimizers.Adam(learning_rate=1)

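### EfficientNet

EfficientNetB0 backbone trained from scratch (`weights=None`), followed by a regularised dense classification head.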

In [6]:
# EfficientNetB0 backbone without its classifier, randomly initialised (weights=None).
# Global max pooling turns the final feature map into a single vector.
base_model = EfficientNetB0(include_top=False, weights=None, input_shape=(224, 224, 3), pooling='max')
base_model.trainable = True  # train the whole network (the default, stated explicitly)


def regularized_dense(units):
    """Return a ReLU Dense layer with the head's L2 kernel and L1 bias/activity penalties."""
    return layers.Dense(units,
                        activation='relu',
                        kernel_regularizer=regularizers.l2(l2=0.016),
                        bias_regularizer=regularizers.l1(0.006),
                        activity_regularizer=regularizers.l1(0.006))


x = base_model.output
x = layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(x)
x = regularized_dense(256)(x)
x = layers.Dropout(rate=0.3)(x)
x = regularized_dense(128)(x)
x = layers.Dropout(rate=0.45)(x)
# Output width follows the number of class folders found in the training split.
output = layers.Dense(train_dataset.num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=output)

In [7]:
# Categorical cross-entropy matches the one-hot labels from class_mode="categorical".
# NOTE(review): keras Recall binarises predictions at its default threshold, so on a
# softmax output this is a pooled recall over all classes; confirm this is the intended metric.
model.compile(
    optimizer=adam_opt_1,
    loss='categorical_crossentropy',
    metrics=[recall, 'accuracy'],
)

In [8]:
# Train for at most 50 epochs. es_recall stops early once val_recall stops improving
# and restores the best weights. Keep the History so the per-epoch loss/recall/accuracy
# curves can be plotted later instead of discarding them.
history = model.fit(train_dataset,
                    epochs=50,
                    callbacks=[es_recall],
                    validation_data=val_dataset)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50


<keras.callbacks.History at 0x7fa9d81f1100>