Importing the required modules

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns


Preparing the data with the `ImageDataGenerator` class. We also use data augmentation (random rotations, shifts, shear, zoom and horizontal flips) to add variety to the training images.

In [None]:
def prepare_data(train_data_dir, target_size=(100, 100), batch_size=128, validation_split=0.2):
    """Build training and validation generators from one image directory.

    ``train_data_dir`` must contain one sub-folder per class. Keras holds out
    the ``validation_split`` fraction of each class's files as the validation
    subset and uses the rest for training.

    Args:
        train_data_dir: root directory with one sub-folder per class.
        target_size: ``(height, width)`` every image is resized to.
        batch_size: number of images per generated batch.
        validation_split: fraction of the files reserved for validation.

    Returns:
        ``(train_generator, validation_generator)`` yielding
        ``(images, one_hot_labels)`` batches with pixels rescaled to [0, 1].
        Only the training subset is augmented and shuffled. The validation
        subset is neither, so its batch order matches
        ``validation_generator.classes``.
    """
    # Data augmentation to increase diversity of the training data only
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255,  # Rescale pixel values to range [0,1]
        validation_split=validation_split,  # Split data into training and validation sets
        rotation_range=20,  # Rotate images randomly up to 20 degrees
        width_shift_range=0.2,  # Shift images horizontally up to 20% of the width
        height_shift_range=0.2,  # Shift images vertically up to 20% of the height
        shear_range=0.2,  # Shear transformation with max intensity of 20%
        zoom_range=0.2,  # Randomly zoom images up to 20%
        horizontal_flip=True  # Randomly flip images horizontally
    )

    # Validation images are only rescaled. Random transforms here would make the
    # validation metrics noisy and unrepresentative of the real images. The same
    # validation_split is used so that both datagens carve out the same subsets
    # (Keras selects each subset from the sorted file list of every class).
    validation_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        validation_split=validation_split,
    )

    # Generate batches of augmented data from the directory
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='training'
    )

    # shuffle=False keeps predictions aligned with validation_generator.classes
    validation_generator = validation_datagen.flow_from_directory(
        train_data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation',
        shuffle=False
    )

    return train_generator, validation_generator


Creating our model on top of the pre-trained MobileNetV2 base model, followed by a GlobalAveragePooling2D layer


We then add a dropout of 50% to reduce overfitting, followed by the final fully connected layers


The images belong to 10 classes


In [None]:
def create_model(input_shape=(100, 100, 3), num_classes=10):
    """Build a transfer-learning classifier on a frozen MobileNetV2 backbone.

    The model expects pixel values in [0, 1], which is what ``rescale=1/255``
    produces. Its first layer maps them to [-1, 1], the range the ImageNet
    MobileNetV2 weights were trained on (see
    ``tf.keras.applications.mobilenet_v2.preprocess_input``).

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: number of output classes.

    Returns:
        An uncompiled ``Sequential`` model with these layers: input rescaling,
        frozen MobileNetV2 features, global average pooling, 50% dropout, a
        1024-unit ReLU layer and a ``num_classes``-way softmax output.
    """
    # Load pre-trained MobileNetV2 model without the top classification layer
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = False  # Freeze the weights of the base model

    # Build a sequential model by adding layers
    model = Sequential([
        # Explicit input so the model is built immediately (summary() works before fit)
        tf.keras.Input(shape=input_shape),
        # [0, 1] -> [-1, 1]: x * 2 - 1, matching MobileNetV2's expected preprocessing
        tf.keras.layers.Rescaling(scale=2.0, offset=-1.0),
        base_model,  # Use the MobileNetV2 base model
        GlobalAveragePooling2D(),  # Global average pooling to reduce spatial dimensions
        Dropout(0.5),  # Dropout layer with dropout rate of 0.5 to prevent overfitting
        Dense(1024, activation='relu'),  # Fully connected layer with 1024 units and ReLU activation
        Dense(num_classes, activation='softmax')  # Output layer with num_classes units and softmax activation
    ])

    return model


In [None]:
def compile_model(model):
    """Configure *model* for training, in place.

    Uses the Adam optimizer and categorical cross-entropy loss (which expects
    one-hot targets), and reports accuracy. Returns ``None``.
    """
    training_config = {
        "optimizer": "adam",
        "loss": "categorical_crossentropy",
        "metrics": ["accuracy"],
    }
    model.compile(**training_config)


| Layer (type)                | Output Shape     | Param #    |
|-----------------------------|------------------|------------|
| mobilenetv2_1.00_100        | (None, 4, 4, 1280)| 2,257,984  |
| GlobalAveragePooling2D      | (None, 1280)     | 0          |
| dropout                     | (None, 1280)     | 0          |
| dense                       | (None, 1024)     | 1,311,744  |
| dense_1                     | (None, 10)       | 10,250     |


Total params: 3,579,978
Trainable params: 1,321,994
Non-trainable params: 2,257,984


Helper function to plot the confusion matrix

In [None]:
def plot_confusion_matrix(conf_matrix, class_names):
    """Show *conf_matrix* as an annotated heatmap.

    Args:
        conf_matrix: square matrix of integer counts (annotated with
            ``fmt='d'``); rows are true labels, columns are predictions.
        class_names: tick labels for both axes, in class-index order.
    """
    _, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        conf_matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set(xlabel='Predicted Labels', ylabel='True Labels', title='Confusion Matrix')
    plt.show()


The main function, which runs the whole pipeline: data preparation, model creation and training, and evaluation on the validation set

In [None]:
def collect_predictions(model, generator):
    """Predict every batch of *generator* and return ``(y_true, y_pred)``.

    The true labels are read from the same batches that are predicted, so the
    two arrays stay aligned even when the generator shuffles its samples.

    Args:
        model: object with ``predict_on_batch(images)`` returning per-class scores.
        generator: indexable sequence (supports ``len`` and ``[i]``) of
            ``(images, one_hot_labels)`` batches.

    Returns:
        Two 1-D integer arrays of class indices: ``(y_true, y_pred)``.
    """
    true_batches = []
    pred_batches = []
    for batch_index in range(len(generator)):
        images, labels = generator[batch_index]
        scores = model.predict_on_batch(images)
        pred_batches.append(np.argmax(scores, axis=1))  # highest probability -> predicted class
        true_batches.append(np.argmax(labels, axis=1))  # one-hot -> class index
    if not true_batches:
        # Empty subset (e.g. validation_split=0): np.concatenate would raise
        return np.array([], dtype=int), np.array([], dtype=int)
    return np.concatenate(true_batches), np.concatenate(pred_batches)


def main():
    """Train the classifier and report its performance on the validation subset.

    Prepares the data generators, builds and compiles the MobileNetV2-based
    model and trains it with early stopping. It then plots a confusion matrix
    and prints the validation accuracy.
    """
    # Define directories and parameters
    train_data_dir = "/kaggle/input/state-farm-distracted-driver-detection/imgs/train"  # For running on the Kaggle Kernel
    target_size = (100, 100)
    batch_size = 128
    validation_split = 0.2
    # Upper bound on training epochs; EarlyStopping below may stop sooner.
    # NOTE(review): a drop below 10% accuracy was previously reported for
    # epochs=20. That evaluation compared predictions from a shuffled
    # validation generator with its unshuffled `.classes`, which by itself
    # pushes measured accuracy toward chance level. Re-measure before
    # attributing it to overfitting.
    epochs = 10

    # Prepare data
    train_generator, validation_generator = prepare_data(train_data_dir, target_size=target_size, batch_size=batch_size, validation_split=validation_split)

    # Derive the class count from the class folders so the output layer always matches the data
    num_classes = len(train_generator.class_indices)

    # Create model
    model = create_model(input_shape=(target_size[0], target_size[1], 3), num_classes=num_classes)

    # Compile model
    compile_model(model)

    # Early stopping callback to prevent overfitting
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    # Train the model
    model.fit(train_generator, epochs=epochs, validation_data=validation_generator, callbacks=[early_stopping])

    # Predict the validation set, taking true labels from the same batches so they stay aligned
    y_true, y_pred = collect_predictions(model, validation_generator)

    # Class names in class-index order (class_indices maps name -> index)
    class_names = list(validation_generator.class_indices.keys())

    # Fix the label set so the matrix is always len(class_names) square and
    # lines up with class_names, even if some class is never predicted
    conf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    # Plot the confusion matrix
    plot_confusion_matrix(conf_matrix, class_names)

    # Check accuracy to verify the model's performance
    accuracy = accuracy_score(y_true, y_pred)
    print(f"Validation Accuracy: {accuracy}")


# Call the main function
if __name__ == "__main__":
    main()
