In [None]:
# import necessary libraries
import os
import glob
import wandb
from wandb.keras import WandbCallback
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
import numpy as np

%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [None]:
def load_data(input_shape=(256,256), batch_size=128, data_dir=r"/content/drive/MyDrive/inaturalist_12K"):
    """Builds the training and validation image iterators.

    Training images are rescaled to [0, 1] and randomly augmented (rotation,
    zoom, shear, horizontal flip); validation images are only rescaled.

    Args:
        input_shape (tuple, optional): (height, width) to resize all images to. Defaults to (256,256).
        batch_size (int, optional): number of images per batch for both iterators. Defaults to 128.
        data_dir (str, optional): dataset root containing the "train" and "val" sub-directories.
            Defaults to the Google Drive location used by this notebook.

    Returns:
        tuple: train dataset, validation dataset

    Raises:
        FileNotFoundError: if the "train" or "val" directory does not exist under ``data_dir``.
    """

    train_dir = os.path.join(data_dir, "train")
    val_dir = os.path.join(data_dir, "val")

    # Fail before building any generator so that a missing dataset (for example
    # an unmounted Drive) is reported with the offending path.
    for split_dir in (train_dir, val_dir):
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(
                f"Dataset split directory not found: {split_dir!r}. "
                "Check that the dataset is available and that data_dir is correct."
            )

    # Augmentation is applied to the training split only.
    train_generator = ImageDataGenerator(rescale=1./255,
                                         rotation_range=50,
                                         zoom_range=0.2,
                                         shear_range=0.2,
                                         horizontal_flip=True)
    val_generator = ImageDataGenerator(rescale=1./255)

    train_ds = train_generator.flow_from_directory(train_dir,
                                                   target_size=input_shape,
                                                   batch_size=batch_size,
                                                   shuffle=True)
    val_ds = val_generator.flow_from_directory(val_dir,
                                               target_size=input_shape,
                                               batch_size=batch_size)

    return train_ds, val_ds

# Smoke test: build both iterators with load_data's default arguments.
# This leaves train_ds and val_ds defined at notebook (module) level.
train_ds, val_ds = load_data()


In [None]:
## Utility functions for plotting ##
def plot_sample_images(dir_path):
    """Plots one sample image for each class found in ``dir_path``.

    Every immediate sub-directory of ``dir_path`` is treated as one class and
    its name is used as the panel title. Panels are laid out five per row.

    Args:
        dir_path (str): path of the directory containing one sub-directory per class
    """

    n_cols = 5

    # Only sub-directories are classes (stray files are ignored); sorting keeps
    # the panel order stable between runs.
    class_dirs = sorted(path for path in glob.glob(os.path.join(dir_path, r"*"))
                        if os.path.isdir(path))

    if not class_dirs:
        # plt.subplots cannot build a grid with zero rows, so stop here.
        print(f"No class sub-directories found in {dir_path!r}; nothing to plot.")
        return

    # Ceiling division: a partially filled last row still needs a full row.
    n_rows = -(-len(class_dirs) // n_cols)

    # squeeze=False always returns a 2-D array of axes, whatever the grid size.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, n_rows*2), squeeze=False)
    fig.suptitle("Sample images from each class")
    axs = axs.reshape(-1)

    for ax, class_dir in zip(axs, class_dirs):
        ax.set_title(os.path.basename(class_dir))

        image_paths = sorted(path for path in glob.glob(os.path.join(class_dir, r"*"))
                             if os.path.isfile(path))
        if not image_paths:
            # Empty class folder: keep the titled panel but draw nothing.
            continue

        ax.imshow(mpimg.imread(image_paths[0]))

    # Hide the unused panels of a partially filled last row.
    for ax in axs[len(class_dirs):]:
        ax.axis("off")

    plt.show()

In [None]:
# Show one sample image per class from the training split.
# DATA_DIR is the dataset root; its "train" sub-directory is passed to the plotting helper.
DATA_DIR = r"/content/drive/MyDrive/inaturalist_12K"
plot_sample_images(os.path.join(DATA_DIR,"train"))


In [None]:
class NeuralNet(tf.keras.Model):
    """Transfer-learning classifier: a pretrained base model followed by a small Conv1D head.

    The base model is created with frozen weights; callers may unfreeze
    ``self.base_model`` (fully or partially) later for fine tuning.
    """

    # Names accepted by select_model; each one is a constructor in tf.keras.applications.
    SUPPORTED_BASE_MODELS = ("InceptionV3", "InceptionResNetV2", "ResNet50", "Xception")

    def __init__(self, base_model, image_shape=(256, 256), num_classes=10):
        """Init function

        Args:
            base_model (str): The base pretrained model to use
            image_shape (tuple, optional): image shape as input for the model. Defaults to (256, 256).
            num_classes (int, optional): number of output logits. Defaults to 10.

        Raises:
            ValueError: if ``base_model`` is not one of ``SUPPORTED_BASE_MODELS``.
        """
        super().__init__()

        # instantiating the base model and freezing its weights
        self.base_model = self.select_model(base_model, image_shape)
        self.base_model.trainable = False

        # Layers are created once here rather than on every forward pass.
        self.flatten_features = layers.Flatten()

        # The layers below form the classification head
        self.conv1 = layers.Conv1D(filters=3, kernel_size=12, strides=6, activation="relu")
        self.pool1 = layers.MaxPool1D(pool_size=3, strides=3)
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv1D(filters=3, kernel_size=6, strides=3, activation="relu")
        self.pool2 = layers.MaxPool1D(pool_size=3, strides=3)
        self.bn2 = layers.BatchNormalization()
        self.flatten_head = layers.Flatten()
        # No activation: the model outputs raw logits.
        self.output_layer = layers.Dense(num_classes, activation=None)

    @staticmethod
    def select_model(name, image_shape):
        """Selects the pretrained model to be used

        Args:
            name (str): name of the pretrained model to use
            image_shape (tuple): input size (height, width) for the pretrained model

        Returns:
            tf.keras.Model: Base model from tensorflow, without its top layers
                and with ImageNet weights

        Raises:
            ValueError: if ``name`` is not one of ``SUPPORTED_BASE_MODELS``.
        """
        if name not in NeuralNet.SUPPORTED_BASE_MODELS:
            # Without this check an unknown name would yield None and only fail
            # later with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported base model {name!r}; "
                f"expected one of {NeuralNet.SUPPORTED_BASE_MODELS}"
            )

        # The pretrained models take 3-channel images: (height, width, 3).
        input_shape = tuple(image_shape) + (3,)

        constructor = getattr(tf.keras.applications, name)
        return constructor(include_top=False, input_shape=input_shape, weights='imagenet')

    def call(self, x, training=None):
        """Performs forward pass for the model

        Args:
            x (tf.tensor): input for the model
            training (bool, optional): whether the head runs in training mode;
                forwarded to the head's BatchNormalization layers. Defaults to None,
                which lets Keras decide.

        Returns:
            tf.tensor: output of the model (logits)
        """
        # The base model always runs in inference mode so that its
        # BatchNormalization statistics are not updated when some of its layers
        # are unfrozen for fine tuning.
        features = self.base_model(x, training=False)
        x = self.flatten_features(features)

        # Add a trailing channel axis so the flat feature vector can be fed to Conv1D.
        x = tf.expand_dims(x, -1)
        x = self.bn1(self.pool1(self.conv1(x)), training=training)
        x = self.bn2(self.pool2(self.conv2(x)), training=training)
        x = self.flatten_head(x)

        return self.output_layer(x)

In [None]:
# This is the main function to use to train/fine-tune the model using wandb runs
def train_with_wandb(image_shape=(256, 256), epochs=30, fine_tune_epochs=10, base_learning_rate=0.0002):
    """Trains the classification head, then fine-tunes the top of the base model, inside a wandb run.

    Stage 1 fits the model as built by ``NeuralNet`` for ``epochs`` epochs at
    ``base_learning_rate``. Stage 2 unfreezes the last layers of the base model
    and continues for ``fine_tune_epochs`` more epochs at a tenth of that rate.
    The base model is read from the wandb config key "base_model"
    (default "InceptionV3").

    Args:
        image_shape (tuple, optional): (height, width) of the input images. Defaults to (256, 256).
        epochs (int, optional): number of epochs for stage 1. Defaults to 30.
        fine_tune_epochs (int, optional): number of additional epochs for stage 2. Defaults to 10.
        base_learning_rate (float, optional): RMSprop learning rate for stage 1. Defaults to 0.0002.

    Raises:
        ValueError: if the configured base model has no fine-tuning default.
    """

    config_defaults = {"base_model": "InceptionV3"}

    # Number of layers, counted from the end of each base model, that are
    # unfrozen for fine tuning.
    to_tune_defaults = {
        "InceptionV3": 55,
        "InceptionResNetV2": 55,
        "ResNet50": 50,
        "Xception": 50
    }

    def compile_model(model, learning_rate):
        """Compiles ``model`` with RMSprop and categorical cross-entropy on logits."""
        model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
                      loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

    # The context manager closes the wandb run when training ends or raises.
    # NOTE(review): `magic` may be deprecated in recent wandb releases - confirm
    # against the installed version.
    with wandb.init(config=config_defaults, project="cs6910-assignment2", magic=True) as run:

        base_model_name = run.config.base_model
        if base_model_name not in to_tune_defaults:
            # Checked before the (slow) data loading and model download.
            raise ValueError(
                f"Unsupported base model {base_model_name!r}; "
                f"expected one of {sorted(to_tune_defaults)}"
            )

        ## 1. Data loading
        print("1. Loading the dataset ...\n")
        train_ds, val_ds = load_data(image_shape)

        ## 2. Initializing the model
        print("2. Initializing the model ...\n")
        model = NeuralNet(base_model_name, image_shape=image_shape)

        ## 3. Compiling the model
        print("3. Compiling the model ...\n")
        compile_model(model, base_learning_rate)

        ## 4. Fitting the model
        print("4. Fitting the model ...\n")
        model.fit(train_ds,
                  validation_data=val_ds,
                  epochs=epochs,
                  callbacks=[WandbCallback()])
        print("Model trained successfully!!\n")

        ## 5. Fine tuning the model
        model.base_model.trainable = True
        print(f"Total layers in base model is {len(model.base_model.layers)}\n")

        # Keep everything before this index frozen; max() guards against a
        # negative index if the base model has fewer layers than requested.
        fine_tune_at = max(0, len(model.base_model.layers) - to_tune_defaults[base_model_name])

        for layer in model.base_model.layers[:fine_tune_at]:
            layer.trainable = False

        # Recompile so the new trainable flags take effect, with a 10x smaller
        # learning rate for fine tuning.
        compile_model(model, base_learning_rate/10)

        print("Fine tuning the model ...\n")
        # Continue the epoch count from stage 1 so the logged epochs of the two
        # stages do not overlap; this still runs fine_tune_epochs epochs.
        model.fit(train_ds,
                  validation_data=val_ds,
                  epochs=epochs + fine_tune_epochs,
                  initial_epoch=epochs,
                  callbacks=[WandbCallback()])
        print("Model tuned successfully!!\n")
