In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16, ResNet50, MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix

# Set GPU memory growth to avoid memory allocation issues
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)


In [2]:
base_dir = '../thesis_task/imgs'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

# Data augmentation and normalization
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

test_datagen = ImageDataGenerator(rescale=1./255)

# Create generators
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)

# Since test data does not have subdirectories, use flow_from_dataframe or flow_from_directory with class_mode=None
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode=None,  # No labels
    shuffle=False
)


Found 22424 images belonging to 10 classes.
Found 0 images belonging to 0 classes.


In [3]:
def create_model(model_name):
    if model_name == "VGG16":
        base_model = VGG16(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
    elif model_name == "ResNet50":
        base_model = ResNet50(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
    elif model_name == "MobileNetV2":
        base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
    else:
        raise ValueError("Invalid model name. Choose from 'VGG16', 'ResNet50', 'MobileNetV2'.")

    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(10, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    
    for layer in base_model.layers:
        layer.trainable = False

    model.compile(optimizer=Adam(learning_rate=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])
    
    return model


In [4]:
# Choose the model
model_name = "MobileNetV2"  # Change as needed
model = create_model(model_name)

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // train_generator.batch_size,
    validation_data=test_generator,
    validation_steps=test_generator.samples // test_generator.batch_size,
    epochs=10  # Adjust epochs as needed
)

# Save the model
model.save(f'{model_name}_driv_distraction.h5')

import pickle

with open(f'{model_name}_history.pkl', 'wb') as file:
    pickle.dump(history.history, file)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [3]:
import matplotlib.pyplot as plt
import pickle
from tensorflow.keras.models import load_model

# Load the saved model
model = load_model('MobileNetV2_driver_distraction.h5')

# Load training history
with open('MobileNetV2_history.pkl', 'rb') as file:
    history = pickle.load(file)

# Print available keys in history
print(history.keys())

# Plot the available metrics
plt.figure(figsize=(12, 4))

# Plot training & validation accuracy values if available
if 'accuracy' in history and 'val_accuracy' in history:
    plt.subplot(1, 2, 1)
    plt.plot(history['accuracy'])
    plt.plot(history['val_accuracy'])
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(['Training Accuracy', 'Validation Accuracy'], loc='upper left')

# Plot training & validation loss values if available
if 'loss' in history and 'val_loss' in history:
    plt.subplot(1, 2, 2)
    plt.plot(history['loss'])
    plt.plot(history['val_loss'])
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['Training Loss', 'Validation Loss'], loc='upper left')

plt.tight_layout()
plt.show()


EOFError: Ran out of input