In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Set image dimensions and other parameters
image_size = (224, 224)  # ResNet50 expects 224x224 images
batch_size = 32
epochs = 50

# Define the paths to the dataset (adjust these paths as needed)
base_dir = './data'  # Root directory where the 'Angry', 'Happy', 'Sad', 'Fear' folders are located

# 1. **Data Augmentation**: Apply common transformations to prevent overfitting
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=0.2  # This will split off 20% for validation
)

# Validation and Test generators (only rescaling)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# 2. **Load the data into train and validation sets using validation_split**
train_data_gen = train_datagen.flow_from_directory(
    base_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',  # Assuming multi-class classification (4 classes)
    shuffle=True,
    subset='training'  # This is the training data subset (80%)
)

val_data_gen = train_datagen.flow_from_directory(
    base_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',  # Multi-class classification
    subset='validation'  # This is the validation data subset (20%)
)

# 3. **Define the model with ResNet50 as the base model**
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the layers of the ResNet50 model so we only train the top layers
base_model.trainable = False

# Create the model by adding custom layers on top of ResNet50
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(1024, activation='relu'),
    layers.Dropout(0.5),  # Dropout layer to prevent overfitting
    layers.Dense(4, activation='softmax')  # Output layer with 4 categories
])

# 4. **Compile the model**
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),  # A low learning rate for fine-tuning
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 5. **Setup callbacks for early stopping and model checkpointing**
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
  
checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)


Found 322 images belonging to 4 classes.
Found 79 images belonging to 4 classes.


In [7]:
# 6. **Train the model**
history = model.fit(
    train_data_gen,
    epochs=epochs,
    validation_data=val_data_gen,
    callbacks=[early_stop, checkpoint],
    verbose=1
)


  self._warn_if_super_not_called()


Epoch 1/50
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 4s/step - accuracy: 0.2713 - loss: 1.6548 - val_accuracy: 0.2532 - val_loss: 1.4060
Epoch 2/50
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 3s/step - accuracy: 0.2666 - loss: 1.5621 - val_accuracy: 0.2278 - val_loss: 1.3864
Epoch 3/50
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 4s/step - accuracy: 0.2811 - loss: 1.5217 - val_accuracy: 0.3291 - val_loss: 1.3742
Epoch 4/50
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 4s/step - accuracy: 0.2640 - loss: 1.5567 - val_accuracy: 0.2911 - val_loss: 1.3712
Epoch 5/50
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 3s/step - accuracy: 0.2944 - loss: 1.4932 - val_accuracy: 0.3544 - val_loss: 1.3574
Epoch 6/50
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 3s/step - accuracy: 0.3215 - loss: 1.4714 - val_accuracy: 0.3165 - val_loss: 1.3472
Epoch 7/50
[1m11/11[0m [32m━━━━━━━━━━

In [8]:

test_loss, test_acc = model.evaluate(val_data_gen)
print(f"Test accuracy: {test_acc:.4f}")

[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 2s/step - accuracy: 0.3262 - loss: 1.3376
Test accuracy: 0.3165
