In [23]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, regularizers, callbacks
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np

In [24]:
def load_binary_mnist(selected_classes=[0, 1], img_size=32):
    """
    Loads and preprocesses the Binary MNIST dataset.

    Args:
        selected_classes (list): List of classes to include (e.g., [0, 1]).
        img_size (int): Desired image size after resizing.

    Returns:
        tuple: Preprocessed training and test datasets.
    """
    # Load MNIST data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Filter the data to include only selected classes
    train_filter = np.isin(y_train, selected_classes)
    test_filter = np.isin(y_test, selected_classes)

    x_train, y_train = x_train[train_filter], y_train[train_filter]
    x_test, y_test = x_test[test_filter], y_test[test_filter]

    # Relabel the classes to 0 and 1
    class_to_idx = {cls: idx for idx, cls in enumerate(selected_classes)}
    y_train = np.vectorize(class_to_idx.get)(y_train)
    y_test = np.vectorize(class_to_idx.get)(y_test)

    # Expand dimensions to include channel information
    x_train = np.expand_dims(x_train, axis=-1)
    x_test = np.expand_dims(x_test, axis=-1)

    # Resize images to img_size x img_size
    x_train = tf.image.resize(x_train, [img_size, img_size]).numpy()
    x_test = tf.image.resize(x_test, [img_size, img_size]).numpy()

    # Normalize the images to [0, 1]
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.

    # One-hot encode the labels
    y_train = to_categorical(y_train, num_classes=2)
    y_test = to_categorical(y_test, num_classes=2)

    return (x_train, y_train), (x_test, y_test)


In [25]:
def inception_module(x, filters, weight_decay=5e-4):
    """
    Constructs an Inception module.

    Args:
        x (tf.Tensor): Input tensor.
        filters (dict): Dictionary containing the number of filters for each branch.
        weight_decay (float): L2 regularization factor.

    Returns:
        tf.Tensor: Output tensor after concatenating all branches.
    """
    # 1x1 Convolution branch
    branch1 = layers.Conv2D(filters['1x1'], (1, 1), padding='same',
                           activation='relu',
                           kernel_regularizer=regularizers.l2(weight_decay))(x)
    branch1 = layers.BatchNormalization()(branch1)

    # 1x1 Convolution followed by 3x3 Convolution branch
    branch2 = layers.Conv2D(filters['3x3_reduce'], (1, 1), padding='same',
                           activation='relu',
                           kernel_regularizer=regularizers.l2(weight_decay))(x)
    branch2 = layers.BatchNormalization()(branch2)
    branch2 = layers.Conv2D(filters['3x3'], (3, 3), padding='same',
                           activation='relu',
                           kernel_regularizer=regularizers.l2(weight_decay))(branch2)
    branch2 = layers.BatchNormalization()(branch2)

    # 1x1 Convolution followed by 5x5 Convolution branch
    branch3 = layers.Conv2D(filters['5x5_reduce'], (1, 1), padding='same',
                           activation='relu',
                           kernel_regularizer=regularizers.l2(weight_decay))(x)
    branch3 = layers.BatchNormalization()(branch3)
    branch3 = layers.Conv2D(filters['5x5'], (5, 5), padding='same',
                           activation='relu',
                           kernel_regularizer=regularizers.l2(weight_decay))(branch3)
    branch3 = layers.BatchNormalization()(branch3)

    # 3x3 Max Pooling followed by 1x1 Convolution branch
    branch4 = layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch4 = layers.Conv2D(filters['pool_proj'], (1, 1), padding='same',
                           activation='relu',
                           kernel_regularizer=regularizers.l2(weight_decay))(branch4)
    branch4 = layers.BatchNormalization()(branch4)

    # Concatenate all the branches
    output = layers.concatenate([branch1, branch2, branch3, branch4], axis=3)
    return output


def auxiliary_classifier(x, num_classes=2, weight_decay=5e-4):
    """
    Constructs an Auxiliary Classifier as used in Inception v1 (GoogleNet).

    Args:
        x (tf.Tensor): Input tensor from an intermediate layer.
        num_classes (int): Number of output classes.
        weight_decay (float): L2 regularization factor.

    Returns:
        tf.Tensor: Output tensor after the auxiliary classifier.
    """
    x = layers.AveragePooling2D(pool_size=(2, 2), strides=(1, 1), padding='same')(x)
    x = layers.Conv2D(128, (1, 1), padding='same',
                      activation='relu',
                      kernel_regularizer=regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(1024, activation='relu',
                     kernel_regularizer=regularizers.l2(weight_decay))(x)
    x = layers.Dropout(0.7)(x)
    x = layers.Dense(num_classes, activation='softmax',
                     kernel_regularizer=regularizers.l2(weight_decay))(x)
    return x


def build_googlenet(input_shape=(32, 32, 1), num_classes=2):
    """
    Builds the Inception v1 (GoogleNet) model.

    Args:
        input_shape (tuple): Shape of the input images.
        num_classes (int): Number of output classes.

    Returns:
        tf.keras.Model: Compiled GoogleNet model.
    """
    weight_decay = 5e-4  # L2 regularization factor

    inputs = layers.Input(shape=input_shape)

    # Initial Layers
    x = layers.Conv2D(64, (7, 7), strides=(2, 2), padding='same',
                      activation='relu',
                      kernel_regularizer=regularizers.l2(weight_decay))(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    x = layers.Conv2D(192, (3, 3), padding='same',
                      activation='relu',
                      kernel_regularizer=regularizers.l2(weight_decay))(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # 1st Inception Module
    filters = {
        '1x1': 64,
        '3x3_reduce': 96,
        '3x3': 128,
        '5x5_reduce': 16,
        '5x5': 32,
        'pool_proj': 32,
    }
    x = inception_module(x, filters, weight_decay)

    # 2nd Inception Module
    filters = {
        '1x1': 128,
        '3x3_reduce': 128,
        '3x3': 192,
        '5x5_reduce': 32,
        '5x5': 96,
        'pool_proj': 64,
    }
    x = inception_module(x, filters, weight_decay)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # 3rd Inception Module
    filters = {
        '1x1': 192,
        '3x3_reduce': 96,
        '3x3': 208,
        '5x5_reduce': 16,
        '5x5': 48,
        'pool_proj': 64,
    }
    x = inception_module(x, filters, weight_decay)

    # Auxiliary Classifier 1 (placed here after 3rd Inception module with spatial size 4x4)
    aux1 = auxiliary_classifier(x, num_classes, weight_decay)

    # 4th Inception Module
    filters = {
        '1x1': 160,
        '3x3_reduce': 112,
        '3x3': 224,
        '5x5_reduce': 24,
        '5x5': 64,
        'pool_proj': 64,
    }
    x = inception_module(x, filters, weight_decay)

    # Auxiliary Classifier 2 (placed here after 4th Inception module with spatial size ~2x2)
    # To prevent pooling errors, ensure the spatial size is sufficient
    aux2 = auxiliary_classifier(x, num_classes, weight_decay)

    # 5th Inception Module
    filters = {
        '1x1': 128,
        '3x3_reduce': 128,
        '3x3': 256,
        '5x5_reduce': 24,
        '5x5': 64,
        'pool_proj': 64,
    }
    x = inception_module(x, filters, weight_decay)

    # 6th Inception Module
    filters = {
        '1x1': 112,
        '3x3_reduce': 144,
        '3x3': 288,
        '5x5_reduce': 32,
        '5x5': 64,
        'pool_proj': 64,
    }
    x = inception_module(x, filters, weight_decay)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # 7th Inception Module
    filters = {
        '1x1': 256,
        '3x3_reduce': 160,
        '3x3': 320,
        '5x5_reduce': 32,
        '5x5': 128,
        'pool_proj': 128,
    }
    x = inception_module(x, filters, weight_decay)

    # 8th Inception Module
    filters = {
        '1x1': 256,
        '3x3_reduce': 160,
        '3x3': 320,
        '5x5_reduce': 32,
        '5x5': 128,
        'pool_proj': 128,
    }
    x = inception_module(x, filters, weight_decay)

    # 9th Inception Module
    filters = {
        '1x1': 384,
        '3x3_reduce': 192,
        '3x3': 384,
        '5x5_reduce': 48,
        '5x5': 128,
        'pool_proj': 128,
    }
    x = inception_module(x, filters, weight_decay)

    # Global Average Pooling and Dropout
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.4)(x)
    primary_output = layers.Dense(num_classes, activation='softmax',
                                  kernel_regularizer=regularizers.l2(weight_decay))(x)

    # Define the model with auxiliary outputs
    model = models.Model(inputs=inputs, outputs=[primary_output, aux1, aux2])

    return model


In [26]:
def compile_and_train(model, train_data, test_data, epochs=30, batch_size=32):
    """
    Compiles and trains the GoogleNet model.

    Args:
        model (tf.keras.Model): The GoogleNet model.
        train_data (tuple): Tuple of training data (x_train, y_train).
        test_data (tuple): Tuple of test data (x_test, y_test).
        epochs (int): Number of training epochs.
        batch_size (int): Size of training batches.

    Returns:
        history: Training history object.
    """
    # Define loss weights for the auxiliary classifiers
    loss_weights = [1.0, 0.3, 0.3]

    # Compile the model
    model.compile(optimizer=optimizers.SGD(learning_rate=0.01, momentum=0.9),
                  loss=['categorical_crossentropy', 'categorical_crossentropy', 'categorical_crossentropy'],
                  loss_weights=loss_weights,
                  metrics=['accuracy'])

    # Define callbacks
    lr_scheduler = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                               patience=5, verbose=1)
    early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=15,
                                             restore_best_weights=True, verbose=1)

    # Train the model
    history = model.fit(train_data[0], [train_data[1], train_data[1], train_data[1]],
                        epochs=epochs,
                        batch_size=batch_size,
                        validation_data=(test_data[0], [test_data[1], test_data[1], test_data[1]]),
                        callbacks=[lr_scheduler, early_stopping],
                        shuffle=True)

    return history


In [27]:
def plot_training_history(history):
    """
    Plots the training and validation accuracy and loss.

    Args:
        history: Training history object.
    """
    # Plot Primary Loss
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Train Primary Loss')
    plt.plot(history.history['val_loss'], label='Validation Primary Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Primary Loss over Epochs')
    plt.legend()

    # Plot Primary Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(history.history['accuracy'], label='Train Primary Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Primary Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Primary Accuracy over Epochs')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Plot Auxiliary Classifier 1 Loss
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss_1'], label='Train Aux1 Loss')
    plt.plot(history.history['val_loss_1'], label='Validation Aux1 Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Auxiliary Classifier 1 Loss over Epochs')
    plt.legend()

    # Plot Auxiliary Classifier 2 Loss
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss_2'], label='Train Aux2 Loss')
    plt.plot(history.history['val_loss_2'], label='Validation Aux2 Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Auxiliary Classifier 2 Loss over Epochs')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Plot Auxiliary Classifier 1 Accuracy
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy_1'], label='Train Aux1 Accuracy')
    plt.plot(history.history['val_accuracy_1'], label='Validation Aux1 Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Auxiliary Classifier 1 Accuracy over Epochs')
    plt.legend()

    # Plot Auxiliary Classifier 2 Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(history.history['accuracy_2'], label='Train Aux2 Accuracy')
    plt.plot(history.history['val_accuracy_2'], label='Validation Aux2 Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Auxiliary Classifier 2 Accuracy over Epochs')
    plt.legend()

    plt.tight_layout()
    plt.show()


In [28]:
# Hyperparameters
batch_size = 32
epochs = 10
# Load and preprocess data
(x_train, y_train), (x_test, y_test) = load_binary_mnist(selected_classes=[0, 1], img_size=32)
# Build the GoogleNet model
model = build_googlenet(input_shape=(32, 32, 1), num_classes=2)
model.summary()
# Compile and train the model
history = compile_and_train(model, (x_train, y_train), (x_test, y_test),
                            epochs=epochs, batch_size=batch_size)
# Evaluate the model
results = model.evaluate(x_test, [y_test, y_test, y_test], verbose=0)
print(f'Test Primary Loss: {results[0]:.4f}')
print(f'Test Primary Accuracy: {results[1]:.2f}%')
print(f'Test Aux1 Loss: {results[2]:.4f}')
print(f'Test Aux1 Accuracy: {results[3]:.2f}%')
print(f'Test Aux2 Loss: {results[4]:.4f}')
print(f'Test Aux2 Accuracy: {results[5]:.2f}%')
# Plot training history
plot_training_history(history)

Epoch 1/10


ValueError: For a model with multiple outputs, when providing the `metrics` argument as a list, it should have as many entries as the model has outputs. Received:
metrics=['accuracy']
of length 1 whereas the model has 3 outputs.