In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Set image dimensions and other parameters
image_size = (224, 224)  # ResNet50 expects 224x224 images
batch_size = 32
epochs = 50

# Define the paths to the dataset (adjust these paths as needed)
base_dir = './data'  # Root directory where the 'Angry', 'Happy', 'Sad', 'Fear' folders are located

# 1. **Data Augmentation**: Apply common transformations to prevent overfitting
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=0.2  # This will split off 20% for validation
)

# Validation and Test generators (only rescaling)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# 2. **Load the data into train and validation sets using validation_split**
train_data_gen = train_datagen.flow_from_directory(
    base_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',  # Assuming multi-class classification (4 classes)
    shuffle=True,
    subset='training'  # This is the training data subset (80%)
)

val_data_gen = train_datagen.flow_from_directory(
    base_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',  # Multi-class classification
    subset='validation'  # This is the validation data subset (20%)
)

# 3. **Define the model with ResNet50 as the base model**
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the layers of the ResNet50 model so we only train the top layers
base_model.trainable = False

# Create the model by adding custom layers on top of ResNet50
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(1024, activation='relu'),
    layers.Dropout(0.5),  # Dropout layer to prevent overfitting
    layers.Dense(4, activation='softmax')  # Output layer with 4 categories
])

# 4. **Compile the model**
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),  # A low learning rate for fine-tuning
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 5. **Setup callbacks for early stopping and model checkpointing**
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
  
checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)


In [None]:
import splitfolders
splitfolders.ratio("./dataset", output="./data-split", seed=1337, ratio=(0.8, 0.1,0.1), group_prefix=None, move=False)


In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale=1./255)
valid_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    './dataset-split/train/',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'  # Ensure this is 'categorical' for multi-class classification
)

valid_generator = valid_datagen.flow_from_directory(
    './dataset-split/val/',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'  # Ensure this is 'categorical' for multi-class classification
)


Found 482 images belonging to 4 classes.
Found 74 images belonging to 4 classes.
Found 138 images belonging to 4 classes.


In [5]:
# ResNet50 model without the top layer (for transfer learning)
resnet_50 = ResNet50(
    include_top=False,     # Exclude the fully connected layers
    weights='imagenet',    # Load pre-trained ImageNet weights
    input_shape=(224, 224, 3)  # Input image shape
)

# Freeze all layers of ResNet50 initially
for layer in resnet_50.layers[-4:]:  # Unfreeze the last 4 layers of ResNet50
    layer.trainable = True

# Add custom layers on top of ResNet50
x = resnet_50.output
x = GlobalAveragePooling2D()(x)  # Global Average Pooling
x = BatchNormalization()(x)      # Batch Normalization
x = Dense(128, activation='relu')(x)  # Fully connected layer with 128 units
x = Dropout(0.5)(x)              # Dropout layer to prevent overfitting

# Output layer (softmax for 4 classes)
predictions = Dense(4, activation='softmax')(x)

# Final model
model = Model(inputs=resnet_50.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(learning_rate=1e-5), loss="categorical_crossentropy", metrics=["accuracy"])


In [6]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Function to train the model
def trainModel(model, train_generator, valid_generator, epochs, class_weights=None):
    batch_size = 32
    
    # Callbacks to stop training early and reduce learning rate
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
    
    # Train the model with class weights if provided
    return model.fit(
        train_generator,
        validation_data=valid_generator,
        epochs=epochs,
        batch_size=batch_size,
        class_weight=class_weights,  # Include class weights
        callbacks=[early_stopping, reduce_lr]
    )

# Define class weights (adjust based on your dataset)
class_weights = {0: 1., 1: 1., 2: 2., 3: 2.}

# Call the trainModel function
model_history = trainModel(
    model=model, 
    train_generator=train_generator, 
    valid_generator=valid_generator, 
    epochs=50, 
    class_weights=class_weights
)


  self._warn_if_super_not_called()


Epoch 1/50


InvalidArgumentError: Graph execution error:

Detected at node PyFunc defined at (most recent call last):
<stack traces unavailable>
TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32, tf.float32), but the yielded element was (array([[[[ -90.939    ,  -94.779    ,  -97.68     ],
         [ -89.98149  ,  -93.82149  ,  -96.72249  ],
         [ -88.99046  ,  -92.83046  ,  -95.73146  ],
         ...,
         [ -31.971985 ,  -18.192192 ,   -9.299789 ],
         [ -38.334053 ,  -25.563263 ,  -16.464264 ],
         [ -49.65444  ,  -37.47946  ,  -28.380463 ]],

        [[ -90.939    ,  -94.779    ,  -97.68     ],
         [ -90.405525 ,  -94.24553  ,  -97.14653  ],
         [ -89.414505 ,  -93.2545   ,  -96.1555   ],
         ...,
         [ -40.45012  ,  -27.790703 ,  -18.691704 ],
         [ -50.228188 ,  -38.26097  ,  -29.161972 ],
         [ -52.015617 ,  -41.24002  ,  -32.14102  ]],

        [[ -90.939    ,  -94.779    ,  -97.68     ],
         [ -90.82957  ,  -94.66956  ,  -97.570564 ],
         [ -89.83855  ,  -93.67854  ,  -96.579544 ],
         ...,
         [ -50.562305 ,  -38.817833 ,  -29.718834 ],
         [ -52.349735 ,  -41.796883 ,  -32.697884 ],
         [ -51.74084  ,  -41.58084  ,  -32.481842 ]],

        ...,

        [[ -86.75221  ,  -82.779    ,  -82.68     ],
         [ -87.93659  ,  -82.77659  ,  -82.67759  ],
         [ -87.34078  ,  -82.18078  ,  -82.08178  ],
         ...,
         [ -44.939003 ,  -82.779    , -115.68     ],
         [ -41.54163  ,  -79.12029  , -112.54397  ],
         [ -25.377617 ,  -62.72234  ,  -98.12806  ]],

        [[ -87.82522  ,  -82.665215 ,  -82.566216 ],
         [ -87.22941  ,  -82.069405 ,  -81.970406 ],
         [ -84.80117  ,  -79.64117  ,  -79.54217  ],
         ...,
         [ -42.633095 ,  -80.31937  , -112.60545  ],
         [ -44.939003 ,  -82.779    , -115.68     ],
         [ -34.170822 ,  -71.182495 , -105.740135 ]],

        [[ -87.118034 ,  -81.95804  ,  -81.85904  ],
         [ -84.02157  ,  -78.861565 ,  -78.762566 ],
         [ -79.939    ,  -74.779    ,  -74.68     ],
         ...,
         [ -36.272507 ,  -73.53474  , -104.12467  ],
         [ -44.939003 ,  -82.779    , -115.68     ],
         [ -39.683327 ,  -77.11903  , -110.828606 ]]],


       [[[ 110.38456  ,   97.54456  ,   90.64356  ],
         [ 117.754814 ,  104.91482  ,   98.01382  ],
         [ 118.97321  ,  106.13322  ,   99.232216 ],
         ...,
         [ 118.12104  ,  105.281044 ,   98.38004  ],
         [ 119.042915 ,  106.20292  ,   99.30192  ],
         [ 119.848564 ,  107.00857  ,  100.10757  ]],

        [[ 111.05479  ,   98.21479  ,   91.31379  ],
         [ 117.5559   ,  104.715904 ,   97.8149   ],
         [ 124.057014 ,  111.21702  ,  104.31602  ],
         ...,
         [ 116.417885 ,  103.57789  ,   96.67689  ],
         [ 116.1065   ,  103.2665   ,   96.3655   ],
         [ 115.176384 ,  102.33639  ,   95.43539  ]],

        [[ 113.06446  ,  100.224464 ,   93.32346  ],
         [ 119.564064 ,  106.72407  ,   99.82307  ],
         [ 126.50323  ,  113.66323  ,  106.76223  ],
         ...,
         [ 114.13218  ,  101.29218  ,   94.39118  ],
         [ 112.870735 ,  100.03074  ,   93.12974  ],
         [ 110.80599  ,   97.965996 ,   91.064995 ]],

        ...,

        [[ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         ...,
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ]],

        [[ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         ...,
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ]],

        [[ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         ...,
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ]]],


       [[[  31.151436 ,   19.31144  ,   10.410439 ],
         [  38.447853 ,   26.607857 ,   17.706856 ],
         [  73.78977  ,   61.949776 ,   53.048775 ],
         ...,
         [  93.69001  ,   82.85001  ,   75.94901  ],
         [  92.682076 ,   81.84208  ,   74.94108  ],
         [  91.67414  ,   80.834145 ,   73.93314  ]],

        [[  50.77578  ,   38.935783 ,   30.034782 ],
         [  40.30163  ,   28.461632 ,   19.56063  ],
         [  58.50776  ,   46.667763 ,   37.76676  ],
         ...,
         [  91.8292   ,   80.989204 ,   74.0882   ],
         [  92.69314  ,   81.85314  ,   74.95214  ],
         [  93.55709  ,   82.717094 ,   75.81609  ]],

        [[  81.30981  ,   69.46981  ,   60.56881  ],
         [  42.92936  ,   31.089363 ,   22.188362 ],
         [  33.261864 ,   21.421867 ,   12.520866 ],
         ...,
         [  79.04406  ,   68.20406  ,   61.303062 ],
         [  80.48397  ,   69.643974 ,   62.742973 ],
         [  81.92388  ,   71.083885 ,   64.182884 ]],

        ...,

        [[ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         ...,
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ]],

        [[ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         ...,
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ]],

        [[ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         ...,
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ]]],


       ...,


       [[[ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         ...,
         [  94.061    ,   84.221    ,   75.32     ],
         [  93.205696 ,   83.3657   ,   74.4647   ],
         [  93.061    ,   83.221    ,   74.32     ]],

        [[ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         ...,
         [  94.061    ,   84.221    ,   75.32     ],
         [  93.061    ,   83.221    ,   74.32     ],
         [  93.061    ,   83.221    ,   74.32     ]],

        [[ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         ...,
         [  93.96204  ,   84.12205  ,   75.22105  ],
         [  93.061    ,   83.221    ,   74.32     ],
         [  93.061    ,   83.221    ,   74.32     ]],

        ...,

        [[ -69.55296  ,  -75.19606  ,  -41.69551  ],
         [ -58.4503   ,  -62.609505 ,  -30.872086 ],
         [ -61.80721  ,  -65.6472   ,  -34.548203 ],
         ...,
         [ -72.51167  ,  -77.16078  ,  -44.443565 ],
         [ -72.03226  ,  -76.84117  ,  -43.804337 ],
         [ -71.939    ,  -76.90772  ,  -43.422554 ]],

        [[ -68.08595  ,  -73.40305  ,  -40.065506 ],
         [ -56.657288 ,  -60.653492 ,  -29.24208  ],
         [ -63.274216 ,  -67.11421  ,  -36.015213 ],
         ...,
         [ -74.939    ,  -78.779    ,  -47.68     ],
         [ -74.939    ,  -78.779    ,  -47.68     ],
         [ -74.89541  ,  -78.74994  ,  -47.62188  ]],

        [[ -66.61895  ,  -71.61005  ,  -38.435493 ],
         [ -55.00014  ,  -58.840137 ,  -27.741135 ],
         [ -63.939003 ,  -67.779    ,  -36.68     ],
         ...,
         [ -76.19684  ,  -80.036835 ,  -49.35711  ],
         [ -75.717415 ,  -79.55741  ,  -48.71788  ],
         [ -75.23799  ,  -79.07799  ,  -48.07865  ]]],


       [[[  58.469154 ,  120.148186 ,  117.373344 ],
         [ 129.83282  ,  119.33723  ,  106.830986 ],
         [ 134.50323  ,  128.25238  ,  109.488335 ],
         ...,
         [ 148.768    ,  133.92801  ,  126.027016 ],
         [ 148.12866  ,  133.28867  ,  125.38767  ],
         [ 149.10202  ,  134.26202  ,  126.361015 ]],

        [[  37.190742 ,  121.92046  ,  115.77082  ],
         [  86.03661  ,  119.83491  ,  113.30088  ],
         [ 138.48471  ,  121.78249  ,  106.50222  ],
         ...,
         [ 148.73022  ,  133.89023  ,  125.989235 ],
         [ 148.79123  ,  133.95123  ,  126.05024  ],
         [ 150.61026  ,  135.77026  ,  127.869255 ]],

        [[  19.77179  ,  123.9112   ,  112.784706 ],
         [  48.155083 ,  120.6674   ,  117.65041  ],
         [ 113.60409  ,  119.521645 ,  109.2284   ],
         ...,
         [ 149.50964  ,  134.66965  ,  126.76864  ],
         [ 146.92053  ,  132.08054  ,  124.17953  ],
         [ 149.93597  ,  135.09598  ,  127.19497  ]],

        ...,

        [[ 146.10834  ,  129.26834  ,  121.36733  ],
         [ 147.0217   ,  130.1817   ,  122.28069  ],
         [ 147.93506  ,  131.09506  ,  123.19407  ],
         ...,
         [ 149.84198  ,   97.1641   ,   40.301308 ],
         [ 142.20139  ,   69.824196 ,    1.6090927],
         [ 108.14144  ,   39.008064 ,  -45.426064 ]],

        [[ 147.86179  ,  131.02179  ,  123.12078  ],
         [ 147.70392  ,  130.86392  ,  122.96293  ],
         [ 147.24725  ,  130.40726  ,  122.50625  ],
         ...,
         [ 150.09634  ,   98.92501  ,   41.586617 ],
         [ 148.58893  ,   86.8263   ,   25.89103  ],
         [ 129.04419  ,   57.920067 ,  -16.560371 ]],

        [[ 147.28387  ,  130.44388  ,  122.542885 ],
         [ 146.82721  ,  129.98721  ,  122.086205 ],
         [ 146.37051  ,  129.53052  ,  121.629524 ],
         ...,
         [ 147.60794  ,   77.52459  ,    8.739464 ],
         [ 150.57965  ,  103.249886 ,   48.784508 ],
         [ 147.33585  ,   76.4885   ,   11.480766 ]]],


       [[[  58.29081  ,   76.893364 ,   79.87745  ],
         [  61.71862  ,   79.878624 ,   81.97762  ],
         [  65.25152  ,   83.95916  ,   84.41526  ],
         ...,
         [ -80.939    ,  -84.779    ,  -87.68     ],
         [ -81.4729   ,  -85.3129   ,  -88.2139   ],
         [ -83.866005 ,  -87.706    ,  -90.607    ]],

        [[  56.343636 ,   75.432976 ,   79.39066  ],
         [  60.258247 ,   78.41825  ,   80.51725  ],
         [  63.304344 ,   81.525185 ,   83.44167  ],
         ...,
         [ -82.304276 ,  -86.14427  ,  -89.04527  ],
         [ -84.69738  ,  -88.53738  ,  -91.43838  ],
         [ -87.62836  ,  -91.46835  ,  -94.369354 ]],

        [[  56.060997 ,   75.221    ,   80.152275 ],
         [  58.37681  ,   76.95786  ,   79.89895  ],
         [  61.78312  ,   79.94312  ,   82.04212  ],
         ...,
         [ -85.67619  ,  -89.51619  ,  -92.41719  ],
         [ -88.66758  ,  -92.507576 ,  -95.40858  ],
         [ -90.97098  ,  -94.810974 ,  -97.711975 ]],

        ...,

        [[ -93.751236 , -113.59123  , -116.42964  ],
         [ -97.00598  , -116.71202  , -120.74698  ],
         [ -97.60426  , -116.11374  , -121.34526  ],
         ...,
         [ -44.46813  ,  -32.30812  ,  -23.209122 ],
         [ -44.710735 ,  -31.664864 ,  -24.3376   ],
         [ -44.05805  ,  -30.89804  ,  -25.560959 ]],

        [[ -97.21383  , -116.50417  , -120.95483  ],
         [ -97.8121   , -115.9059   , -121.5531   ],
         [ -98.88176  , -115.30762  , -122.15138  ],
         ...,
         [ -51.283237 ,  -39.12323  ,  -30.024231 ],
         [ -43.73715  ,  -31.17807  ,  -22.87722  ],
         [ -44.54484  ,  -31.384834 ,  -25.074165 ]],

        [[ -98.1009   , -115.69805  , -121.76095  ],
         [ -99.29745  , -115.09978  , -122.35922  ],
         [-101.049    , -114.5015   , -122.68     ],
         ...,
         [ -56.939003 ,  -44.779    ,  -35.68     ],
         [ -44.1671   ,  -32.007095 ,  -22.908096 ],
         [ -44.75374  ,  -31.68637  ,  -24.402107 ]]]], dtype=float32), array([[0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.]], dtype=float32)).
Traceback (most recent call last):

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\data\ops\from_generator_op.py", line 204, in generator_py_func
    flattened_values = nest.flatten_up_to(output_types, values)

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\data\util\nest.py", line 237, in flatten_up_to
    return nest_util.flatten_up_to(

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\util\nest_util.py", line 1541, in flatten_up_to
    return _tf_data_flatten_up_to(shallow_tree, input_tree)

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\util\nest_util.py", line 1570, in _tf_data_flatten_up_to
    _tf_data_assert_shallow_structure(shallow_tree, input_tree)

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\util\nest_util.py", line 1427, in _tf_data_assert_shallow_structure
    raise ValueError(

ValueError: The two structures don't have the same sequence length. Input structure has length 2, while shallow structure has length 3.


The above exception was the direct cause of the following exception:


Traceback (most recent call last):

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\ops\script_ops.py", line 269, in __call__
    ret = func(*args)

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 643, in wrapper
    return func(*args, **kwargs)

  File "c:\Users\Lenovo\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\data\ops\from_generator_op.py", line 206, in generator_py_func
    raise TypeError(

TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32, tf.float32), but the yielded element was (array([[[[ -90.939    ,  -94.779    ,  -97.68     ],
         [ -89.98149  ,  -93.82149  ,  -96.72249  ],
         [ -88.99046  ,  -92.83046  ,  -95.73146  ],
         ...,
         [ -31.971985 ,  -18.192192 ,   -9.299789 ],
         [ -38.334053 ,  -25.563263 ,  -16.464264 ],
         [ -49.65444  ,  -37.47946  ,  -28.380463 ]],

        [[ -90.939    ,  -94.779    ,  -97.68     ],
         [ -90.405525 ,  -94.24553  ,  -97.14653  ],
         [ -89.414505 ,  -93.2545   ,  -96.1555   ],
         ...,
         [ -40.45012  ,  -27.790703 ,  -18.691704 ],
         [ -50.228188 ,  -38.26097  ,  -29.161972 ],
         [ -52.015617 ,  -41.24002  ,  -32.14102  ]],

        [[ -90.939    ,  -94.779    ,  -97.68     ],
         [ -90.82957  ,  -94.66956  ,  -97.570564 ],
         [ -89.83855  ,  -93.67854  ,  -96.579544 ],
         ...,
         [ -50.562305 ,  -38.817833 ,  -29.718834 ],
         [ -52.349735 ,  -41.796883 ,  -32.697884 ],
         [ -51.74084  ,  -41.58084  ,  -32.481842 ]],

        ...,

        [[ -86.75221  ,  -82.779    ,  -82.68     ],
         [ -87.93659  ,  -82.77659  ,  -82.67759  ],
         [ -87.34078  ,  -82.18078  ,  -82.08178  ],
         ...,
         [ -44.939003 ,  -82.779    , -115.68     ],
         [ -41.54163  ,  -79.12029  , -112.54397  ],
         [ -25.377617 ,  -62.72234  ,  -98.12806  ]],

        [[ -87.82522  ,  -82.665215 ,  -82.566216 ],
         [ -87.22941  ,  -82.069405 ,  -81.970406 ],
         [ -84.80117  ,  -79.64117  ,  -79.54217  ],
         ...,
         [ -42.633095 ,  -80.31937  , -112.60545  ],
         [ -44.939003 ,  -82.779    , -115.68     ],
         [ -34.170822 ,  -71.182495 , -105.740135 ]],

        [[ -87.118034 ,  -81.95804  ,  -81.85904  ],
         [ -84.02157  ,  -78.861565 ,  -78.762566 ],
         [ -79.939    ,  -74.779    ,  -74.68     ],
         ...,
         [ -36.272507 ,  -73.53474  , -104.12467  ],
         [ -44.939003 ,  -82.779    , -115.68     ],
         [ -39.683327 ,  -77.11903  , -110.828606 ]]],


       [[[ 110.38456  ,   97.54456  ,   90.64356  ],
         [ 117.754814 ,  104.91482  ,   98.01382  ],
         [ 118.97321  ,  106.13322  ,   99.232216 ],
         ...,
         [ 118.12104  ,  105.281044 ,   98.38004  ],
         [ 119.042915 ,  106.20292  ,   99.30192  ],
         [ 119.848564 ,  107.00857  ,  100.10757  ]],

        [[ 111.05479  ,   98.21479  ,   91.31379  ],
         [ 117.5559   ,  104.715904 ,   97.8149   ],
         [ 124.057014 ,  111.21702  ,  104.31602  ],
         ...,
         [ 116.417885 ,  103.57789  ,   96.67689  ],
         [ 116.1065   ,  103.2665   ,   96.3655   ],
         [ 115.176384 ,  102.33639  ,   95.43539  ]],

        [[ 113.06446  ,  100.224464 ,   93.32346  ],
         [ 119.564064 ,  106.72407  ,   99.82307  ],
         [ 126.50323  ,  113.66323  ,  106.76223  ],
         ...,
         [ 114.13218  ,  101.29218  ,   94.39118  ],
         [ 112.870735 ,  100.03074  ,   93.12974  ],
         [ 110.80599  ,   97.965996 ,   91.064995 ]],

        ...,

        [[ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         ...,
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ]],

        [[ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         ...,
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ]],

        [[ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         ...,
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ],
         [ 151.061    ,  138.22101  ,  131.32     ]]],


       [[[  31.151436 ,   19.31144  ,   10.410439 ],
         [  38.447853 ,   26.607857 ,   17.706856 ],
         [  73.78977  ,   61.949776 ,   53.048775 ],
         ...,
         [  93.69001  ,   82.85001  ,   75.94901  ],
         [  92.682076 ,   81.84208  ,   74.94108  ],
         [  91.67414  ,   80.834145 ,   73.93314  ]],

        [[  50.77578  ,   38.935783 ,   30.034782 ],
         [  40.30163  ,   28.461632 ,   19.56063  ],
         [  58.50776  ,   46.667763 ,   37.76676  ],
         ...,
         [  91.8292   ,   80.989204 ,   74.0882   ],
         [  92.69314  ,   81.85314  ,   74.95214  ],
         [  93.55709  ,   82.717094 ,   75.81609  ]],

        [[  81.30981  ,   69.46981  ,   60.56881  ],
         [  42.92936  ,   31.089363 ,   22.188362 ],
         [  33.261864 ,   21.421867 ,   12.520866 ],
         ...,
         [  79.04406  ,   68.20406  ,   61.303062 ],
         [  80.48397  ,   69.643974 ,   62.742973 ],
         [  81.92388  ,   71.083885 ,   64.182884 ]],

        ...,

        [[ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         ...,
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ]],

        [[ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         ...,
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ]],

        [[ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         ...,
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ],
         [ -46.939003 ,  -69.779    ,  -83.68     ]]],


       ...,


       [[[ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         ...,
         [  94.061    ,   84.221    ,   75.32     ],
         [  93.205696 ,   83.3657   ,   74.4647   ],
         [  93.061    ,   83.221    ,   74.32     ]],

        [[ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         ...,
         [  94.061    ,   84.221    ,   75.32     ],
         [  93.061    ,   83.221    ,   74.32     ],
         [  93.061    ,   83.221    ,   74.32     ]],

        [[ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         [ 130.061    ,  120.221    ,  111.32     ],
         ...,
         [  93.96204  ,   84.12205  ,   75.22105  ],
         [  93.061    ,   83.221    ,   74.32     ],
         [  93.061    ,   83.221    ,   74.32     ]],

        ...,

        [[ -69.55296  ,  -75.19606  ,  -41.69551  ],
         [ -58.4503   ,  -62.609505 ,  -30.872086 ],
         [ -61.80721  ,  -65.6472   ,  -34.548203 ],
         ...,
         [ -72.51167  ,  -77.16078  ,  -44.443565 ],
         [ -72.03226  ,  -76.84117  ,  -43.804337 ],
         [ -71.939    ,  -76.90772  ,  -43.422554 ]],

        [[ -68.08595  ,  -73.40305  ,  -40.065506 ],
         [ -56.657288 ,  -60.653492 ,  -29.24208  ],
         [ -63.274216 ,  -67.11421  ,  -36.015213 ],
         ...,
         [ -74.939    ,  -78.779    ,  -47.68     ],
         [ -74.939    ,  -78.779    ,  -47.68     ],
         [ -74.89541  ,  -78.74994  ,  -47.62188  ]],

        [[ -66.61895  ,  -71.61005  ,  -38.435493 ],
         [ -55.00014  ,  -58.840137 ,  -27.741135 ],
         [ -63.939003 ,  -67.779    ,  -36.68     ],
         ...,
         [ -76.19684  ,  -80.036835 ,  -49.35711  ],
         [ -75.717415 ,  -79.55741  ,  -48.71788  ],
         [ -75.23799  ,  -79.07799  ,  -48.07865  ]]],


       [[[  58.469154 ,  120.148186 ,  117.373344 ],
         [ 129.83282  ,  119.33723  ,  106.830986 ],
         [ 134.50323  ,  128.25238  ,  109.488335 ],
         ...,
         [ 148.768    ,  133.92801  ,  126.027016 ],
         [ 148.12866  ,  133.28867  ,  125.38767  ],
         [ 149.10202  ,  134.26202  ,  126.361015 ]],

        [[  37.190742 ,  121.92046  ,  115.77082  ],
         [  86.03661  ,  119.83491  ,  113.30088  ],
         [ 138.48471  ,  121.78249  ,  106.50222  ],
         ...,
         [ 148.73022  ,  133.89023  ,  125.989235 ],
         [ 148.79123  ,  133.95123  ,  126.05024  ],
         [ 150.61026  ,  135.77026  ,  127.869255 ]],

        [[  19.77179  ,  123.9112   ,  112.784706 ],
         [  48.155083 ,  120.6674   ,  117.65041  ],
         [ 113.60409  ,  119.521645 ,  109.2284   ],
         ...,
         [ 149.50964  ,  134.66965  ,  126.76864  ],
         [ 146.92053  ,  132.08054  ,  124.17953  ],
         [ 149.93597  ,  135.09598  ,  127.19497  ]],

        ...,

        [[ 146.10834  ,  129.26834  ,  121.36733  ],
         [ 147.0217   ,  130.1817   ,  122.28069  ],
         [ 147.93506  ,  131.09506  ,  123.19407  ],
         ...,
         [ 149.84198  ,   97.1641   ,   40.301308 ],
         [ 142.20139  ,   69.824196 ,    1.6090927],
         [ 108.14144  ,   39.008064 ,  -45.426064 ]],

        [[ 147.86179  ,  131.02179  ,  123.12078  ],
         [ 147.70392  ,  130.86392  ,  122.96293  ],
         [ 147.24725  ,  130.40726  ,  122.50625  ],
         ...,
         [ 150.09634  ,   98.92501  ,   41.586617 ],
         [ 148.58893  ,   86.8263   ,   25.89103  ],
         [ 129.04419  ,   57.920067 ,  -16.560371 ]],

        [[ 147.28387  ,  130.44388  ,  122.542885 ],
         [ 146.82721  ,  129.98721  ,  122.086205 ],
         [ 146.37051  ,  129.53052  ,  121.629524 ],
         ...,
         [ 147.60794  ,   77.52459  ,    8.739464 ],
         [ 150.57965  ,  103.249886 ,   48.784508 ],
         [ 147.33585  ,   76.4885   ,   11.480766 ]]],


       [[[  58.29081  ,   76.893364 ,   79.87745  ],
         [  61.71862  ,   79.878624 ,   81.97762  ],
         [  65.25152  ,   83.95916  ,   84.41526  ],
         ...,
         [ -80.939    ,  -84.779    ,  -87.68     ],
         [ -81.4729   ,  -85.3129   ,  -88.2139   ],
         [ -83.866005 ,  -87.706    ,  -90.607    ]],

        [[  56.343636 ,   75.432976 ,   79.39066  ],
         [  60.258247 ,   78.41825  ,   80.51725  ],
         [  63.304344 ,   81.525185 ,   83.44167  ],
         ...,
         [ -82.304276 ,  -86.14427  ,  -89.04527  ],
         [ -84.69738  ,  -88.53738  ,  -91.43838  ],
         [ -87.62836  ,  -91.46835  ,  -94.369354 ]],

        [[  56.060997 ,   75.221    ,   80.152275 ],
         [  58.37681  ,   76.95786  ,   79.89895  ],
         [  61.78312  ,   79.94312  ,   82.04212  ],
         ...,
         [ -85.67619  ,  -89.51619  ,  -92.41719  ],
         [ -88.66758  ,  -92.507576 ,  -95.40858  ],
         [ -90.97098  ,  -94.810974 ,  -97.711975 ]],

        ...,

        [[ -93.751236 , -113.59123  , -116.42964  ],
         [ -97.00598  , -116.71202  , -120.74698  ],
         [ -97.60426  , -116.11374  , -121.34526  ],
         ...,
         [ -44.46813  ,  -32.30812  ,  -23.209122 ],
         [ -44.710735 ,  -31.664864 ,  -24.3376   ],
         [ -44.05805  ,  -30.89804  ,  -25.560959 ]],

        [[ -97.21383  , -116.50417  , -120.95483  ],
         [ -97.8121   , -115.9059   , -121.5531   ],
         [ -98.88176  , -115.30762  , -122.15138  ],
         ...,
         [ -51.283237 ,  -39.12323  ,  -30.024231 ],
         [ -43.73715  ,  -31.17807  ,  -22.87722  ],
         [ -44.54484  ,  -31.384834 ,  -25.074165 ]],

        [[ -98.1009   , -115.69805  , -121.76095  ],
         [ -99.29745  , -115.09978  , -122.35922  ],
         [-101.049    , -114.5015   , -122.68     ],
         ...,
         [ -56.939003 ,  -44.779    ,  -35.68     ],
         [ -44.1671   ,  -32.007095 ,  -22.908096 ],
         [ -44.75374  ,  -31.68637  ,  -24.402107 ]]]], dtype=float32), array([[0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.]], dtype=float32)).


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]] [Op:__inference_one_step_on_iterator_45323]

In [23]:
test_loss, test_acc = model.evaluate(test_generator, verbose=2)
print(f"Test Accuracy: {test_acc}")
print(f"Test Loss: {test_loss}")


3/3 - 4s - 1s/step - accuracy: 0.5541 - loss: 1.3819
Test Accuracy: 0.5540540814399719
Test Loss: 1.3819315433502197
