<a href="https://colab.research.google.com/github/Swathi1309/ED18B034_ME18B133_CS6910/blob/main/Assignment2/PartB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning a pre-trained model

# Initial Setup

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import Xception
import pprint

In [None]:
# Colab only: mount Google Drive so the dataset stored under
# /content/drive/MyDrive can be read by the data-loading cells.
from google.colab import drive
drive.mount('/content/drive')

# Class names of the dataset; passed to WandbCallback as the image labels.
classes = [
    'Amphibia', 'Animalia', 'Arachnida', 'Aves', 'Fungi',
    'Insecta', 'Mammalia', 'Mollusca', 'Plantae', 'Reptilia',
]

In [None]:
# Install the Weights & Biases client and authenticate this session.
!pip install wandb
!wandb login
import wandb
from wandb.keras import WandbCallback
# Start a W&B run in the course project.
# NOTE(review): training_sweep() opens its own run via wandb.init(); confirm
# that this additional top-level run is actually wanted.
wandb.init(project="CS6910-assg2", entity="swathi")

# Function to load the model and data
To use a different pre-trained network instead of Xception, change `base_model` in the cell below.

In [None]:
def load_model(dropout, learning_rate, unfreeze, num_classes=10):
  """Build and compile a classifier on top of an ImageNet-pretrained Xception.

  The backbone is created without its top classifier and is followed by
  global average pooling, dropout and a softmax layer. The image size and
  channel count are read from the module-level `img_dim` and `channel_no`.

  Args:
    dropout: Dropout rate applied just before the softmax layer.
    learning_rate: Learning rate passed to the Adam optimizer.
    unfreeze: Number of trailing backbone layers left trainable.
      0 keeps the whole backbone frozen.
    num_classes: Number of units in the softmax layer (default 10).

  Returns:
    A `(base_model, model)` tuple: the Xception backbone (so callers can
    change its `trainable` flags later) and the compiled full model.
  """
  # NOTE(review): Keras documents Xception as expecting inputs scaled to
  # [-1, 1] (xception.preprocess_input), while load_data rescales images to
  # [0, 1] - confirm whether the preprocessing should be aligned.
  inputs = keras.Input(shape=(img_dim, img_dim, channel_no))
  # The model can be changed to any other as required
  base_model = Xception(weights="imagenet",
                        input_shape=(img_dim, img_dim, channel_no),
                        include_top=False)

  # Freeze the whole backbone, then re-enable only the last `unfreeze` layers.
  for layer in base_model.layers:
    layer.trainable = False
  if unfreeze != 0:
    for layer in base_model.layers[-unfreeze:]:
      layer.trainable = True

  # training=False: the backbone is always called in inference mode, as the
  # Keras fine-tuning guide recommends for models with BatchNormalization.
  X = base_model(inputs, training=False)
  X = GlobalAveragePooling2D()(X)
  # Use the caller-supplied rate so the `dropout` hyperparameter takes effect.
  X = Dropout(dropout)(X)
  output = Dense(num_classes, activation='softmax')(X)

  model = Model(inputs=inputs, outputs=output)
  model.compile(optimizer=Adam(learning_rate),
                loss='categorical_crossentropy',
                metrics=['accuracy'])

  return base_model, model

In [None]:
def load_data(dir_train, dir_test, batch):
  """Create the directory iterators used for training, validation and testing.

  Every image is rescaled by 1/255 and resized to `img_dim` x `img_dim`
  (module-level constant). 10% of `dir_train` is held out as validation data.

  Args:
    dir_train: Directory with one sub-folder per class; source of the
      training and validation subsets.
    dir_test: Directory with one sub-folder per class; source of the test set.
    batch: Batch size used by every iterator.

  Returns:
    `(train_aug_dataset, train_dataset, val_dataset, test_dataset)`:
    the training subset with augmentation (horizontal flip, rotation up to
    30 degrees), the same subset without augmentation, the validation
    subset, and the test set.
  """
  ImageDataGenerator = tf.keras.preprocessing.image.ImageDataGenerator

  seed = 42
  label_names = ['Amphibia', 'Animalia', 'Arachnida', 'Aves', 'Fungi',
                 'Insecta', 'Mammalia', 'Mollusca', 'Plantae', 'Reptilia']

  # Settings shared by all three generators.
  shared = dict(rescale=1./255, samplewise_center=False)

  augmenting_gen = ImageDataGenerator(horizontal_flip=True,
                                      rotation_range=30,
                                      validation_split=0.1,
                                      **shared)
  split_gen = ImageDataGenerator(validation_split=0.1, **shared)
  plain_gen = ImageDataGenerator(**shared)

  def stream(generator, directory, subset):
    # One iterator over `directory`, restricted to `subset` when given.
    return generator.flow_from_directory(
      directory,
      target_size=(img_dim, img_dim),
      batch_size=batch,
      classes=list(label_names),
      class_mode='categorical',
      subset=subset,
      seed=seed)

  train_aug_dataset = stream(augmenting_gen, dir_train, 'training')
  train_dataset = stream(split_gen, dir_train, 'training')
  val_dataset = stream(split_gen, dir_train, 'validation')
  test_dataset = stream(plain_gen, dir_test, None)

  return train_aug_dataset, train_dataset, val_dataset, test_dataset

# Defining parameters

In [None]:
# Notebook-wide settings. Names assigned in a cell are already module-level
# globals, so no `global` declarations are needed for load_model(),
# load_data() and training_sweep() to read them.
img_dim = 200     # images are resized to img_dim x img_dim
channel_no = 3    # 3 for RGB images, 1 for greyscale
batch = 128       # batch size used by every data iterator

# The dataset's `val` folder serves as the held-out test set; the validation
# subset is split off from `train` inside load_data().
train_aug_dataset, train_dataset, val_dataset, test_dataset = load_data(
    '/content/drive/MyDrive/inaturalist_12K/train',
    '/content/drive/MyDrive/inaturalist_12K/val',
    batch)

# Hyperparameter sweeps

In [None]:
# Hyperparameters explored by the W&B grid sweep. Each entry maps a name read
# by training_sweep() to the list of values to try; trailing comments list
# alternative values (presumably used in other sweeps).
parameters_dict = {
    name: {'values': choices}
    for name, choices in (
        ('unfreeze', [0]),              # 0,1,2
        ('unfreeze_fine_tune', [30]),   # 0,30
        ('fine_tune_rate', [1e-5]),
        ('learning_rate', [1e-3]),
        ('dropout', [0.5]),             # 0.2,0.5
        ('epochs_train', [10]),
        ('epochs_fine_tune', [10]),
        ('augmentation', [True]),
    )
}

sweep_config = {'method': 'grid', 'parameters': parameters_dict}
pprint.pprint(sweep_config)

def training_sweep(config=None):
    """Run one sweep trial: train the new head, then optionally fine-tune.

    Reads the module-level `train_aug_dataset`, `train_dataset`,
    `val_dataset` and `classes`.

    Args:
        config: Optional default configuration forwarded to `wandb.init`;
            the values actually used are read back from `wandb.config`
            (`dropout`, `learning_rate`, `unfreeze`, `augmentation`,
            `epochs_train`, `unfreeze_fine_tune`, `fine_tune_rate`,
            `epochs_fine_tune`).
    """
    with wandb.init(config=config):
        config = wandb.config
        base_model, model = load_model(config.dropout, config.learning_rate, config.unfreeze)
        train = train_aug_dataset if config.augmentation else train_dataset

        # Stage 1: train with the backbone frozen as configured by `unfreeze`.
        model.fit(train,
                  epochs=config.epochs_train,
                  validation_data=val_dataset,
                  callbacks=[WandbCallback(data_type='image', labels=classes)])

        # Stage 2 (optional): unfreeze the last `unfreeze_fine_tune` backbone
        # layers and continue at the lower fine-tuning rate. The count is
        # taken from the end of the backbone, matching how load_model()
        # interprets `unfreeze`; abs() keeps a negative count meaning
        # "last n layers" as well.
        trailing_layers = abs(config.unfreeze_fine_tune)
        if trailing_layers != 0:
            for layer in base_model.layers[-trailing_layers:]:
                layer.trainable = True
            # Recompile so the changed `trainable` flags take effect.
            model.compile(optimizer=Adam(config.fine_tune_rate),
                          loss='categorical_crossentropy',
                          metrics=['accuracy'])
            model.fit(train,
                      epochs=config.epochs_fine_tune,
                      validation_data=val_dataset,
                      callbacks=[WandbCallback(data_type='image', labels=classes)])

In [None]:
# Register the sweep defined by sweep_config in the course project ...
sweep_id = wandb.sweep(sweep_config, project="CS6910-assg2")
# ... and start an agent that calls training_sweep for the sweep's runs.
wandb.agent(sweep_id, training_sweep)