In [1]:
!pip install keras_tuner



In [2]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
import keras_tuner as kt

In [3]:
# Load MNIST through Keras as separate train and test (images, labels) pairs.
# The labels are used below as integer class indices (fed to tf.one_hot with depth 10).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

def preprocess_image(image, target_size=(64, 64)):
    """Turn one single-channel image into a resized, scaled RGB tensor.

    Args:
        image: A 2-D (height, width) grayscale image without a channel axis.
        target_size: (height, width) to resize to; kept small to limit memory use.

    Returns:
        A tensor of shape (*target_size, 3) whose values are the resized pixel
        values divided by 255.
    """
    with_channel = tf.expand_dims(image, axis=-1)     # (H, W) -> (H, W, 1)
    as_rgb = tf.image.grayscale_to_rgb(with_channel)  # (H, W, 1) -> (H, W, 3)
    resized = tf.image.resize(as_rgb, target_size)
    return resized / 255.0

# Convert dataset into TensorFlow Dataset objects
batch_size = 16  # Smaller batch size to reduce memory consumption
target_size = (64, 64)  # Smaller target size


def _make_dataset(images, labels, training=False, num_classes=10):
    """Build a batched, prefetched pipeline of (image, one-hot label) pairs.

    Args:
        images: Array of raw grayscale images, one per sample.
        labels: Array of integer class indices aligned with `images`.
        training: When True, shuffle the samples on every pass over the data.
        num_classes: Depth of the one-hot label vectors.

    Returns:
        A tf.data.Dataset yielding batches of `batch_size` preprocessed images
        and their one-hot labels.
    """
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if training:
        # Shuffle the raw samples *before* preprocessing, with a buffer that spans
        # the whole set, so the shuffle is uniform. Buffering raw samples rather
        # than resized float RGB tensors keeps the buffer cheap (assuming the raw
        # images are smaller than `target_size` x 3 floats, as they are for MNIST).
        dataset = dataset.shuffle(buffer_size=len(images))
    dataset = dataset.map(
        lambda image, label: (preprocess_image(image, target_size), tf.one_hot(label, num_classes)),
        # Preprocess several samples concurrently instead of one at a time.
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


train_dataset = _make_dataset(x_train, y_train, training=True)
test_dataset = _make_dataset(x_test, y_test)

# Lightweight ResNet Block
def resnet_block(x, filters, stride=1, dropout_rate=0.2):
    """Apply one two-convolution residual block followed by dropout.

    Main path: 3x3 conv (strided) -> batch norm -> ReLU -> 3x3 conv -> batch norm.
    The block input is added back to the main path, then ReLU and Dropout follow.

    Args:
        x: Input Keras tensor; its last axis is compared against `filters`' output
            to decide whether the shortcut needs a projection.
        filters: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first convolution and of the projection shortcut.
        dropout_rate: Rate handed to the final Dropout layer.

    Returns:
        The output tensor of the block.
    """
    identity = x

    # Main path, stage one: the only 3x3 convolution that uses `stride`.
    out = Conv2D(filters, (3, 3), strides=stride, padding='same', use_bias=False)(identity)
    out = BatchNormalization()(out)
    out = ReLU()(out)

    # Main path, stage two: no activation here, it is applied after the addition.
    out = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)(out)
    out = BatchNormalization()(out)

    # The plain identity shortcut is only valid when neither the stride nor the
    # channel count changed; otherwise project it with a strided 1x1 convolution.
    shortcut_matches = stride == 1 and out.shape[-1] == identity.shape[-1]
    if not shortcut_matches:
        identity = Conv2D(filters, (1, 1), strides=stride, use_bias=False)(identity)
        identity = BatchNormalization()(identity)

    out = Add()([out, identity])
    out = ReLU()(out)
    return Dropout(dropout_rate)(out)

# Build the Lightweight ResNet Model
def build_lightweight_resnet(hp, input_shape=(64, 64, 3), num_classes=10):
    """Build and compile the small ResNet-style classifier for one tuner trial.

    Tunable hyperparameters requested from `hp`:
        initial_filters: 16 to 32 in steps of 8 (filters of the first convolution).
        dropout_rate: 0.2 to 0.4 in steps of 0.1 (shared by both residual blocks).
        learning_rate: 1e-4 to 1e-3, log-sampled (Adam).

    Args:
        hp: Keras Tuner hyperparameter container for the current trial.
        input_shape: Shape of one input image.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A compiled Keras Model using categorical cross-entropy and accuracy.
    """
    inputs = Input(shape=input_shape)

    # Stem: one 3x3 convolution with a tunable filter count.
    stem_filters = hp.Int('initial_filters', min_value=16, max_value=32, step=8)
    features = Conv2D(stem_filters, (3, 3), strides=1, padding='same', use_bias=False)(inputs)
    features = BatchNormalization()(features)
    features = ReLU()(features)

    # Two stride-2 residual stages with deliberately few filters to save memory.
    # Both stages share one dropout hyperparameter, so it is requested once.
    block_dropout = hp.Float('dropout_rate', min_value=0.2, max_value=0.4, step=0.1)
    for stage_filters in (32, 64):
        features = resnet_block(features, stage_filters, stride=2, dropout_rate=block_dropout)

    # Classification head.
    pooled = GlobalAveragePooling2D()(features)
    outputs = Dense(num_classes, activation='softmax')(pooled)

    model = Model(inputs, outputs)
    adam = tf.keras.optimizers.Adam(
        learning_rate=hp.Float('learning_rate', min_value=1e-4, max_value=1e-3, sampling='log')
    )
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Hyperparameter Tuning with Keras Tuner
def tune_hyperparameters(train_data=None, validation_data=None, max_trials=5, epochs=3,
                         seed=None, overwrite=False):
    """Random-search the ResNet hyperparameters and return the best trial.

    Called with no arguments this behaves as before: 5 trials of 3 epochs each,
    trained on `train_dataset` and validated on `test_dataset`.

    NOTE: validating on `test_dataset` means the model is selected on the same
    data it is later scored on, so that score is optimistic. Pass a separate
    `validation_data` split when one is available.

    Args:
        train_data: Training dataset; defaults to the module-level `train_dataset`.
        validation_data: Dataset that drives `val_accuracy`; defaults to the
            module-level `test_dataset`.
        max_trials: Number of hyperparameter combinations to try.
        epochs: Training epochs per trial (kept low to save memory and time).
        seed: Optional seed for the random search, for reproducible trials.
        overwrite: Forwarded to the tuner. With the default False, Keras Tuner
            reloads an existing 'tuner/lightweight_resnet_mnist_optimized'
            project instead of searching from scratch; pass True to start over.

    Returns:
        (best_model, best_hps): the best model as rebuilt by the tuner and the
        HyperParameters object of the best trial.
    """
    # Resolve defaults at call time so the function picks up the current globals.
    if train_data is None:
        train_data = train_dataset
    if validation_data is None:
        validation_data = test_dataset

    tuner = kt.RandomSearch(  # RandomSearch for quicker results
        build_lightweight_resnet,
        objective='val_accuracy',
        max_trials=max_trials,
        seed=seed,
        overwrite=overwrite,
        directory='tuner',
        project_name='lightweight_resnet_mnist_optimized'
    )

    # Train one model per trial.
    tuner.search(train_data, validation_data=validation_data, epochs=epochs)

    # Retrieve the best model and hyperparameters through the tuner's public API.
    best_model = tuner.get_best_models(num_models=1)[0]
    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

    return best_model, best_hps


In [4]:
# Run the search, then report the hyperparameters of the winning trial.
best_model, best_hps = tune_hyperparameters()
print("Best Hyperparameters:", best_hps.values, sep="\n")

# Score the selected model on test_dataset.
# NOTE(review): tune_hyperparameters() validates on test_dataset as well, so these
# figures are the ones used for model selection, not an independent estimate.
final_loss, final_accuracy = best_model.evaluate(test_dataset, verbose=2)
print(
    f"Final Test Loss: {final_loss:.4f}",
    f"Final Test Accuracy: {final_accuracy:.4f}",
    sep="\n",
)

# Persist the selected model to an .h5 file under /content.
saved_model_dir = "/content/lightweight_resnet_mnist_optimized.h5"
best_model.save(saved_model_dir)
print(f"Best Model saved to {saved_model_dir}. You can now download it manually.")

Trial 5 Complete [00h 01m 36s]
val_accuracy: 0.9732999801635742

Best val_accuracy So Far: 0.9732999801635742
Total elapsed time: 00h 08m 04s


  saveable.load_own_variables(weights_store.get(inner_path))


Best Hyperparameters:
{'initial_filters': 24, 'dropout_rate': 0.4, 'learning_rate': 0.0009287197673218198}
625/625 - 3s - 5ms/step - accuracy: 0.9733 - loss: 0.0882




Final Test Loss: 0.0882
Final Test Accuracy: 0.9733
Best Model saved to /content/lightweight_resnet_mnist_optimized.h5. You can now download it manually.
