In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import Dropout, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2

In [2]:
tf.random.set_seed(42)
np.random.seed(42)

In [3]:
IMG_SIZE = (48, 48)
BATCH_SIZE = 64
EPOCHS = 100
TRAIN_DIR = "C:/MENATL_HEALTH/train"  
TEST_DIR = "C:/MENATL_HEALTH/test"  

In [4]:
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest'
)

test_datagen = ImageDataGenerator(rescale=1./255)

In [5]:
train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    shuffle=True
)

test_generator = test_datagen.flow_from_directory(
    TEST_DIR,
    target_size=IMG_SIZE,
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    shuffle=False
)

Found 28709 images belonging to 7 classes.
Found 7178 images belonging to 7 classes.


In [6]:
class_names = list(train_generator.class_indices.keys())
print("Class names:", class_names)

Class names: ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']


In [7]:
from sklearn.utils import class_weight
import numpy as np

class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes
)
class_weights = dict(enumerate(class_weights))
print("Class weights:", class_weights)

Class weights: {0: 1.0266046844269623, 1: 9.406618610747051, 2: 1.0010460615781582, 3: 0.5684387684387684, 4: 0.8260394187886635, 5: 0.8491274770777877, 6: 1.293372978330405}


In [8]:
def build_model(input_shape=(48, 48, 1), num_classes=7):
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        BatchNormalization(),
        Conv2D(32, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])
    
    return model

model = build_model()
model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [10]:
optimizer = Adam(learning_rate=0.0001)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Callbacks
callbacks = [
    EarlyStopping(patience=15, monitor='val_loss', restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),
    ModelCheckpoint('best_emotion_model.keras', monitor='val_accuracy', save_best_only=True)
]

In [11]:
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    validation_data=test_generator,
    validation_steps=test_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=callbacks,
    class_weight=class_weights
)

  self._warn_if_super_not_called()


Epoch 1/100


InvalidArgumentError: Graph execution error:

Detected at node PyFunc defined at (most recent call last):
<stack traces unavailable>
TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32, tf.float32), but the yielded element was (array([[[[1.        ],
         [1.        ],
         [1.        ],
         ...,
         [0.6745098 ],
         [0.6784314 ],
         [0.6627451 ]],

        [[1.        ],
         [1.        ],
         [1.        ],
         ...,
         [0.6784314 ],
         [0.67058825],
         [0.6627451 ]],

        [[1.        ],
         [1.        ],
         [1.        ],
         ...,
         [0.6666667 ],
         [0.65882355],
         [0.6745098 ]],

        ...,

        [[0.43921572],
         [0.4431373 ],
         [0.46274513],
         ...,
         [0.8117648 ],
         [0.8196079 ],
         [0.8117648 ]],

        [[0.43921572],
         [0.45098042],
         [0.45098042],
         ...,
         [0.80392164],
         [0.8078432 ],
         [0.81568635]],

        [[0.4431373 ],
         [0.46274513],
         [0.43921572],
         ...,
         [0.78823537],
         [0.79215693],
         [0.7960785 ]]],


       [[[0.29803923],
         [0.29803923],
         [0.3019608 ],
         ...,
         [0.854902  ],
         [0.83921576],
         [0.82745105]],

        [[0.32156864],
         [0.30588236],
         [0.29803923],
         ...,
         [0.85098046],
         [0.8352942 ],
         [0.8196079 ]],

        [[0.34901962],
         [0.34901962],
         [0.34509805],
         ...,
         [0.8431373 ],
         [0.82745105],
         [0.81568635]],

        ...,

        [[0.3137255 ],
         [0.18431373],
         [0.13725491],
         ...,
         [0.09411766],
         [0.10196079],
         [0.19607845]],

        [[0.25490198],
         [0.16470589],
         [0.13333334],
         ...,
         [0.12156864],
         [0.13333334],
         [0.30980393]],

        [[0.18823531],
         [0.14117648],
         [0.14117648],
         ...,
         [0.15294118],
         [0.23137257],
         [0.38431376]]],


       [[[0.2392157 ],
         [0.23529413],
         [0.24705884],
         ...,
         [0.40000004],
         [0.40000004],
         [0.40000004]],

        [[0.2392157 ],
         [0.23529413],
         [0.24705884],
         ...,
         [0.40000004],
         [0.40000004],
         [0.40000004]],

        [[0.2392157 ],
         [0.23529413],
         [0.24705884],
         ...,
         [0.40000004],
         [0.40000004],
         [0.40000004]],

        ...,

        [[0.22352943],
         [0.19607845],
         [0.3137255 ],
         ...,
         [0.28235295],
         [0.28235295],
         [0.28235295]],

        [[0.21960786],
         [0.20784315],
         [0.20784315],
         ...,
         [0.3803922 ],
         [0.3803922 ],
         [0.3803922 ]],

        [[0.20784315],
         [0.227451  ],
         [0.19215688],
         ...,
         [0.3803922 ],
         [0.3803922 ],
         [0.3803922 ]]],


       ...,


       [[[0.35686275],
         [0.35686275],
         [0.35686275],
         ...,
         [0.4666667 ],
         [0.4666667 ],
         [0.4666667 ]],

        [[0.35686275],
         [0.35686275],
         [0.36078432],
         ...,
         [0.4666667 ],
         [0.4666667 ],
         [0.4666667 ]],

        [[0.36078432],
         [0.35686275],
         [0.3529412 ],
         ...,
         [0.4666667 ],
         [0.4666667 ],
         [0.4666667 ]],

        ...,

        [[0.54509807],
         [0.4666667 ],
         [0.5058824 ],
         ...,
         [0.8352942 ],
         [0.82745105],
         [0.81568635]],

        [[0.5254902 ],
         [0.4784314 ],
         [0.5176471 ],
         ...,
         [0.8196079 ],
         [0.82745105],
         [0.8352942 ]],

        [[0.5019608 ],
         [0.49411768],
         [0.5176471 ],
         ...,
         [0.8078432 ],
         [0.8078432 ],
         [0.8078432 ]]],


       [[[0.01568628],
         [0.01568628],
         [0.03137255],
         ...,
         [0.16862746],
         [0.0627451 ],
         [0.03529412]],

        [[0.01568628],
         [0.01960784],
         [0.03137255],
         ...,
         [0.13333334],
         [0.04313726],
         [0.04313726]],

        [[0.01568628],
         [0.02352941],
         [0.03137255],
         ...,
         [0.1137255 ],
         [0.03137255],
         [0.0509804 ]],

        ...,

        [[0.47450984],
         [0.49411768],
         [0.49411768],
         ...,
         [0.06666667],
         [0.06666667],
         [0.07058824]],

        [[0.4901961 ],
         [0.5137255 ],
         [0.5372549 ],
         ...,
         [0.02745098],
         [0.03137255],
         [0.03529412]],

        [[0.5254902 ],
         [0.5411765 ],
         [0.56078434],
         ...,
         [0.00392157],
         [0.00392157],
         [0.00392157]]],


       [[[0.6117647 ],
         [0.6117647 ],
         [0.40784317],
         ...,
         [0.6509804 ],
         [0.6509804 ],
         [0.6509804 ]],

        [[0.60784316],
         [0.61960787],
         [0.47058827],
         ...,
         [0.654902  ],
         [0.654902  ],
         [0.654902  ]],

        [[0.6       ],
         [0.5921569 ],
         [0.4039216 ],
         ...,
         [0.6509804 ],
         [0.6509804 ],
         [0.6509804 ]],

        ...,

        [[0.7411765 ],
         [0.6392157 ],
         [0.6117647 ],
         ...,
         [0.4431373 ],
         [0.45098042],
         [0.45098042]],

        [[0.74509805],
         [0.6392157 ],
         [0.63529414],
         ...,
         [0.5529412 ],
         [0.5529412 ],
         [0.5529412 ]],

        [[0.7568628 ],
         [0.6627451 ],
         [0.6627451 ],
         ...,
         [0.6431373 ],
         [0.6431373 ],
         [0.6509804 ]]]], dtype=float32), array([0., 3., 3., 4., 6., 6., 0., 5., 5., 2., 4., 2., 4., 0., 4., 6., 4.,
       3., 5., 0., 3., 0., 0., 5., 2., 6., 6., 5., 2., 3., 3., 2., 5., 3.,
       3., 3., 3., 0., 4., 3., 6., 5., 0., 3., 2., 4., 2., 3., 4., 2., 0.,
       3., 6., 0., 4., 0., 3., 6., 6., 3., 3., 0., 3., 3.], dtype=float32)).
Traceback (most recent call last):

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\data\ops\from_generator_op.py", line 204, in generator_py_func
    flattened_values = nest.flatten_up_to(output_types, values)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\data\util\nest.py", line 237, in flatten_up_to
    return nest_util.flatten_up_to(
           ^^^^^^^^^^^^^^^^^^^^^^^^

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\util\nest_util.py", line 1541, in flatten_up_to
    return _tf_data_flatten_up_to(shallow_tree, input_tree)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\util\nest_util.py", line 1570, in _tf_data_flatten_up_to
    _tf_data_assert_shallow_structure(shallow_tree, input_tree)

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\util\nest_util.py", line 1427, in _tf_data_assert_shallow_structure
    raise ValueError(

ValueError: The two structures don't have the same sequence length. Input structure has length 2, while shallow structure has length 3.


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\ops\script_ops.py", line 269, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\autograph\impl\api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "c:\Users\choud\anaconda3\Lib\site-packages\tensorflow\python\data\ops\from_generator_op.py", line 206, in generator_py_func
    raise TypeError(

TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32, tf.float32), but the yielded element was (array([[[[1.        ],
         [1.        ],
         [1.        ],
         ...,
         [0.6745098 ],
         [0.6784314 ],
         [0.6627451 ]],

        [[1.        ],
         [1.        ],
         [1.        ],
         ...,
         [0.6784314 ],
         [0.67058825],
         [0.6627451 ]],

        [[1.        ],
         [1.        ],
         [1.        ],
         ...,
         [0.6666667 ],
         [0.65882355],
         [0.6745098 ]],

        ...,

        [[0.43921572],
         [0.4431373 ],
         [0.46274513],
         ...,
         [0.8117648 ],
         [0.8196079 ],
         [0.8117648 ]],

        [[0.43921572],
         [0.45098042],
         [0.45098042],
         ...,
         [0.80392164],
         [0.8078432 ],
         [0.81568635]],

        [[0.4431373 ],
         [0.46274513],
         [0.43921572],
         ...,
         [0.78823537],
         [0.79215693],
         [0.7960785 ]]],


       [[[0.29803923],
         [0.29803923],
         [0.3019608 ],
         ...,
         [0.854902  ],
         [0.83921576],
         [0.82745105]],

        [[0.32156864],
         [0.30588236],
         [0.29803923],
         ...,
         [0.85098046],
         [0.8352942 ],
         [0.8196079 ]],

        [[0.34901962],
         [0.34901962],
         [0.34509805],
         ...,
         [0.8431373 ],
         [0.82745105],
         [0.81568635]],

        ...,

        [[0.3137255 ],
         [0.18431373],
         [0.13725491],
         ...,
         [0.09411766],
         [0.10196079],
         [0.19607845]],

        [[0.25490198],
         [0.16470589],
         [0.13333334],
         ...,
         [0.12156864],
         [0.13333334],
         [0.30980393]],

        [[0.18823531],
         [0.14117648],
         [0.14117648],
         ...,
         [0.15294118],
         [0.23137257],
         [0.38431376]]],


       [[[0.2392157 ],
         [0.23529413],
         [0.24705884],
         ...,
         [0.40000004],
         [0.40000004],
         [0.40000004]],

        [[0.2392157 ],
         [0.23529413],
         [0.24705884],
         ...,
         [0.40000004],
         [0.40000004],
         [0.40000004]],

        [[0.2392157 ],
         [0.23529413],
         [0.24705884],
         ...,
         [0.40000004],
         [0.40000004],
         [0.40000004]],

        ...,

        [[0.22352943],
         [0.19607845],
         [0.3137255 ],
         ...,
         [0.28235295],
         [0.28235295],
         [0.28235295]],

        [[0.21960786],
         [0.20784315],
         [0.20784315],
         ...,
         [0.3803922 ],
         [0.3803922 ],
         [0.3803922 ]],

        [[0.20784315],
         [0.227451  ],
         [0.19215688],
         ...,
         [0.3803922 ],
         [0.3803922 ],
         [0.3803922 ]]],


       ...,


       [[[0.35686275],
         [0.35686275],
         [0.35686275],
         ...,
         [0.4666667 ],
         [0.4666667 ],
         [0.4666667 ]],

        [[0.35686275],
         [0.35686275],
         [0.36078432],
         ...,
         [0.4666667 ],
         [0.4666667 ],
         [0.4666667 ]],

        [[0.36078432],
         [0.35686275],
         [0.3529412 ],
         ...,
         [0.4666667 ],
         [0.4666667 ],
         [0.4666667 ]],

        ...,

        [[0.54509807],
         [0.4666667 ],
         [0.5058824 ],
         ...,
         [0.8352942 ],
         [0.82745105],
         [0.81568635]],

        [[0.5254902 ],
         [0.4784314 ],
         [0.5176471 ],
         ...,
         [0.8196079 ],
         [0.82745105],
         [0.8352942 ]],

        [[0.5019608 ],
         [0.49411768],
         [0.5176471 ],
         ...,
         [0.8078432 ],
         [0.8078432 ],
         [0.8078432 ]]],


       [[[0.01568628],
         [0.01568628],
         [0.03137255],
         ...,
         [0.16862746],
         [0.0627451 ],
         [0.03529412]],

        [[0.01568628],
         [0.01960784],
         [0.03137255],
         ...,
         [0.13333334],
         [0.04313726],
         [0.04313726]],

        [[0.01568628],
         [0.02352941],
         [0.03137255],
         ...,
         [0.1137255 ],
         [0.03137255],
         [0.0509804 ]],

        ...,

        [[0.47450984],
         [0.49411768],
         [0.49411768],
         ...,
         [0.06666667],
         [0.06666667],
         [0.07058824]],

        [[0.4901961 ],
         [0.5137255 ],
         [0.5372549 ],
         ...,
         [0.02745098],
         [0.03137255],
         [0.03529412]],

        [[0.5254902 ],
         [0.5411765 ],
         [0.56078434],
         ...,
         [0.00392157],
         [0.00392157],
         [0.00392157]]],


       [[[0.6117647 ],
         [0.6117647 ],
         [0.40784317],
         ...,
         [0.6509804 ],
         [0.6509804 ],
         [0.6509804 ]],

        [[0.60784316],
         [0.61960787],
         [0.47058827],
         ...,
         [0.654902  ],
         [0.654902  ],
         [0.654902  ]],

        [[0.6       ],
         [0.5921569 ],
         [0.4039216 ],
         ...,
         [0.6509804 ],
         [0.6509804 ],
         [0.6509804 ]],

        ...,

        [[0.7411765 ],
         [0.6392157 ],
         [0.6117647 ],
         ...,
         [0.4431373 ],
         [0.45098042],
         [0.45098042]],

        [[0.74509805],
         [0.6392157 ],
         [0.63529414],
         ...,
         [0.5529412 ],
         [0.5529412 ],
         [0.5529412 ]],

        [[0.7568628 ],
         [0.6627451 ],
         [0.6627451 ],
         ...,
         [0.6431373 ],
         [0.6431373 ],
         [0.6509804 ]]]], dtype=float32), array([0., 3., 3., 4., 6., 6., 0., 5., 5., 2., 4., 2., 4., 0., 4., 6., 4.,
       3., 5., 0., 3., 0., 0., 5., 2., 6., 6., 5., 2., 3., 3., 2., 5., 3.,
       3., 3., 3., 0., 4., 3., 6., 5., 0., 3., 2., 4., 2., 3., 4., 2., 0.,
       3., 6., 0., 4., 0., 3., 6., 6., 3., 3., 0., 3., 3.], dtype=float32)).


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]] [Op:__inference_multi_step_on_iterator_7341]