<a href="https://colab.research.google.com/github/KrishalDhungana/Reptiles-Amphibians-Classifier/blob/main/Ensemble_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf  # For tf.data
import matplotlib.pyplot as plt
import keras
from keras import layers
from keras.applications import EfficientNetB0,MobileNetV2,VGG19
from tensorflow.keras.applications import EfficientNetB0

model = EfficientNetB0(weights='imagenet')
dataset_name = "/content/new-reptiles-and-amphibians-image-dataset"

IMG_SIZE = 224
BATCH_SIZE = 16
NUM_CLASSES = 2
img_augmentation_layers = [
    layers.RandomRotation(factor=0.15),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomFlip(),
    layers.RandomContrast(factor=0.1),
]

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_name,
    image_size = (IMG_SIZE, IMG_SIZE),
    label_mode='int',
    batch_size=BATCH_SIZE
)
#split data for training and testing, 80% train 20% test
train_size = int(0.8 * len(dataset))
ds_train = dataset.take(train_size)
ds_test = dataset.skip(train_size)

def img_augmentation(image):

    for layer in img_augmentation_layers:
        image = layer(image)
    return image

def input_preprocess_train(image, label):

    image = img_augmentation(image)

    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


def input_preprocess_test(image, label):


    label = tf.one_hot(label, NUM_CLASSES)
    return image, label

ds_train = ds_train.map(input_preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(input_preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)


def build_model(num_classes):
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # prepare EfficientNetB0
    efficientnet = EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")
    efficientnet.trainable = False
    x1 = efficientnet.output
    x1 = layers.GlobalAveragePooling2D()(x1)
    x1 = layers.BatchNormalization()(x1)
    x1 = layers.Dropout(0.2)(x1)

    # prepare MobileNetV2
    mobilenet = MobileNetV2(include_top=False, input_tensor=inputs, weights="imagenet")
    mobilenet.trainable = False
    x2 = mobilenet.output
    x2 = layers.GlobalAveragePooling2D()(x2)
    x2 = layers.BatchNormalization()(x2)
    x2 = layers.Dropout(0.2)(x2)

    # prepare VGG19
    VGG = VGG19(include_top=False, input_tensor=inputs, weights="imagenet")
    VGG.trainable = False
    x3 = mobilenet.output
    x3 = layers.GlobalAveragePooling2D()(x3)
    x3 = layers.BatchNormalization()(x3)
    x3 = layers.Dropout(0.2)(x3)

    # Concatenate the 3 layers
    concatenated = layers.concatenate([x1, x2,x3])


    #classification layer
    outputs = layers.Dense(num_classes, activation="softmax")(concatenated)

    # Compile model
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model
def unfreeze_model(model):
    # Unfreeze the top 20 layers while leaving BatchNorm layers frozen
    for layer in model.layers[-20:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True

    optimizer = keras.optimizers.Adam(learning_rate=1e-5)
    model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )
    return model




model = build_model(NUM_CLASSES)

epochs = 16
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
model.save_weights('16_epochs_Ensemble.weights.h5')
unfreeze_model(model)

epochs = 4
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
model.save_weights('4_epochs_Ensemble.weights.h5')


Found 3178 files belonging to 2 classes.




Epoch 1/16
  4/159 [..............................] - ETA: 3:55 - loss: 1.0655 - accuracy: 0.7188

KeyboardInterrupt: 