In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sea
import tensorflow as tf
import random
import gc

import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, ReLU, BatchNormalization
from tensorflow.keras.initializers import HeNormal, Ones, Constant
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

In [2]:
SEED = 3126  
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
WIDTH, HEIGHT = 224, 224
BATCH_SIZE = 32
LEARNING_RATE = .01
EPOCHS = 100
DIR = "/kaggle/input/finalized-astrovision-data"

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0,
    height_shift_range=0,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=0.2
)

train_generator = train_datagen.flow_from_directory(
    DIR,
    target_size=(WIDTH, HEIGHT),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    seed=SEED  
)

validation_generator = train_datagen.flow_from_directory(
    DIR,
    target_size=(WIDTH, HEIGHT),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    seed=SEED 
)

Found 4913 images belonging to 4 classes.
Found 1228 images belonging to 4 classes.


In [4]:
def make_dense_layer(input_size, dropout_rate=0.0, input_shape=None):
    layers = []
    if input_shape:
        layers.append(Dense(input_size,
                            use_bias=False, 
                            input_shape=input_shape,
                            kernel_initializer=HeNormal()
                            ))
    else:
        layers.append(Dense(input_size, use_bias=False, kernel_initializer=HeNormal()))
    layers.extend([BatchNormalization(gamma_initializer=Ones(), beta_initializer=Constant(0.25)),
                  ReLU()])
    if dropout_rate > 0:
        layers.append(Dropout(dropout_rate))
    return Sequential(layers)

In [5]:
def build_model():
    base_model = MobileNetV2(weights='imagenet', include_top=False,
                             input_shape=(WIDTH, HEIGHT, 3))
    base_model.trainable = False
    x = base_model.output
    x = GlobalAveragePooling2D()(x)

    fc_layers = Sequential([
        make_dense_layer(1024, input_shape=(x.shape[-1],)),
        make_dense_layer(8)
    ])
    x = fc_layers(x)
    predictions = Dense(train_generator.num_classes,
                        activation='softmax',
                        kernel_initializer=HeNormal())(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=Adam(learning_rate=LEARNING_RATE),
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])   
    return model

In [6]:
class ClearMemory(Callback):
    def on_epoch_end(self, epoch, logs=None):
        K.clear_session()  
        gc.collect() 

In [7]:
model = build_model()

best_weights_path = "/kaggle/working/best_weights.weights.h5"

checkpoint_callback = ModelCheckpoint(
    filepath=best_weights_path,  
    monitor='val_accuracy',    
    save_best_only=True,       
    save_weights_only=True,    
    mode='max',                
    verbose=1                  
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',  
    patience=10,              
    mode='min',              
    min_delta=.001,
    restore_best_weights=True,  
    verbose=1
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [8]:
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
    callbacks=[checkpoint_callback, early_stopping_callback, ClearMemory()]
)

initial_test_loss, initial_test_acc = model.evaluate(validation_generator, 
                                                     steps=validation_generator.samples // BATCH_SIZE)

print(f"Test Accuracy Before Fine-tuning: {initial_test_acc*100:.2f}%")

Epoch 1/100


  self._warn_if_super_not_called()


[1m153/154[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 523ms/step - accuracy: 0.6734 - loss: 0.7912
Epoch 1: val_accuracy improved from -inf to 0.76466, saving model to /kaggle/working/best_weights.weights.h5
[1m154/154[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 702ms/step - accuracy: 0.6741 - loss: 0.7893 - val_accuracy: 0.7647 - val_loss: 0.5730
Epoch 2/100
[1m153/154[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 359ms/step - accuracy: 0.7773 - loss: 0.5009
Epoch 2: val_accuracy improved from 0.76466 to 0.77524, saving model to /kaggle/working/best_weights.weights.h5
[1m154/154[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 458ms/step - accuracy: 0.7774 - loss: 0.5007 - val_accuracy: 0.7752 - val_loss: 0.5231
Epoch 3/100
[1m153/154[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 358ms/step - accuracy: 0.8008 - loss: 0.4528
Epoch 3: val_accuracy improved from 0.77524 to 0.81515, saving model to /kaggle/working/best_weights.weights.h5
