In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import scipy.io
import shutil
import pandas as pd
from PIL import Image
import numpy as np


In [2]:
df = pd.read_csv('StenosisDetection/train_labels.csv', usecols=['filename','xmax','ymax','xmin', 'ymin'])

In [3]:
image_size = 224  
images, targets = [], []

top_left_x_collection = []
top_left_y_collection = []
bottom_right_x_collection = []
bottom_right_y_collection = [] 


In [4]:
for index, row in df.iterrows():
    image = keras.utils.load_img(
        'StenosisDetection/dataset/'+row['filename'],
    )

    (w, h) = image.size[:2]
    image = image.resize((224, 224))
    images.append(keras.utils.img_to_array(image))
    
    
    top_left_x, top_left_y = row['xmax'], row['ymax']
    bottom_right_x, bottom_right_y = row['xmin'], row['ymin']
    
    targets.append(
        (
            float(top_left_x) / w,
            float(top_left_y) / h,
            float(bottom_right_x) / w,
            float(bottom_right_y) / h,
        )
    )


In [5]:
df_test = pd.read_csv('StenosisDetection/test_labels.csv', usecols=['filename','xmax','ymax','xmin', 'ymin'])


In [6]:
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    #     Override function to avoid error while saving model
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "input_shape": input_shape,
                "patch_size": patch_size,
                "num_patches": num_patches,
                "projection_dim": projection_dim,
                "num_heads": num_heads,
                "transformer_units": transformer_units,
                "transformer_layers": transformer_layers,
                "mlp_head_units": mlp_head_units,
            }
        )
        return config

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        return tf.reshape(patches, [batch_size, -1, patches.shape[-1]])


In [7]:
image_size = 224  
images_test, targets_test = [], [] 


In [8]:
for index, row in df_test.iterrows():
    image = keras.utils.load_img(
        'StenosisDetection/dataset/'+row['filename'],
    )

    (w, h) = image.size[:2]
    image = image.resize((224, 224))
    images_test.append(keras.utils.img_to_array(image))
    
    
    top_left_x, top_left_y = row['xmax'], row['ymax']
    bottom_right_x, bottom_right_y = row['xmin'], row['ymin']
    
    targets_test.append(
        (
            float(top_left_x) / w,
            float(top_left_y) / h,
            float(bottom_right_x) / w,
            float(bottom_right_y) / h,
        )
    )


In [10]:
(x_train), (y_train) = (
    np.asarray(images[: int(len(images))]),
    np.asarray(targets[: int(len(targets))]),
)
(x_test), (y_test) = (
    np.asarray(images_test[: int(len(images_test))]),
    np.asarray(targets_test[: int(len(targets_test))]),
)


In [11]:
y_train.shape

(7493, 4)

In [12]:
y_train[0]

array([0.3525 , 0.24   , 0.29375, 0.195  ])

In [13]:
y_train.shape

(7493, 4)

In [14]:
patch_size = 32

patches = Patches(patch_size)(tf.convert_to_tensor([x_train[0]]))
print(patches.shape)

Metal device set to: Apple M1 Pro
(1, 49, 3072)


2023-01-27 11:14:50.330146: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-01-27 11:14:50.331636: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [15]:
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    # Override function to avoid error while saving model
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "input_shape": input_shape,
                "patch_size": patch_size,
                "num_patches": num_patches,
                "projection_dim": projection_dim,
                "num_heads": num_heads,
                "transformer_units": transformer_units,
                "transformer_layers": transformer_layers,
                "mlp_head_units": mlp_head_units,
            }
        )
        return config

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

In [20]:
# x = keras.layers.InputShape()
# x = keras.layers.Patches()(x)
# x = PatchEncoder(patches)


In [79]:
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)


In [11]:
from tensorflow.keras.applications import inception_v3


In [12]:
def basemodel(weights='imagenet', input_shape=(32, 32, 3), unfreeze=True, unfreeze_from=15, include_top=False):
    base_model = inception_v3.InceptionV3(
    include_top=include_top,
    weights=weights,
    input_shape=input_shape
    )
    return base_model


In [13]:
def intermediatemodel(base_model, output_index_layer, unfreeze=False, unfreeze_from=0):
    intermediate_layer_model = keras.Model(inputs=base_model.input,
                                     outputs=base_model.get_layer(index=output_index_layer).output)

    i = 0
    for layer in intermediate_layer_model.layers:
        print(layer.name)
        print(i)
        i += 1
        layer.trainable = False

    if unfreeze:
        for layer in intermediate_layer_model.layers[unfreeze_from:]:
            layer.trainable = True

    for layer in intermediate_layer_model.layers:
        print("{}: {}".format(layer, layer.trainable))

    return intermediate_layer_model



In [46]:

from tensorflow.keras.layers import *
from tensorflow.keras.models import *
# def predictionmodel(base_model, dense_sizes=[128], dense_activations=['relu'],
#                     dense_kernel_inits=['random_uniform'], dropout_p=[0.5],
#                     pooling='avg', classes=1,
#                     class_activation='sigmoid'):
#     input_shape = base_model.output.shape
#     model = Sequential()
#     model.add(Conv2D(8, (7, 7), padding='same', activation='relu', input_shape=(input_shape[1],input_shape[2],input_shape[3])))
#     model.add(Dropout(0.5))
#     model.add(Conv2D(8, (7, 7), padding='same', activation='relu'))
#     model.add(Conv2D(8, (7, 7), padding='same', activation='relu'))
#     model.add(Dropout(0.5))
#     model.add(Conv2D(8, (7, 7), padding='same', activation='relu'))
#     model.add(Dropout(0.5))
#     model.add(Flatten())
#     model.add(Dense(16, activation='sigmoid', kernel_initializer='random_uniform', bias_initializer='random_uniform'))
#     model.add(Dense(1, activation='sigmoid', kernel_initializer='random_uniform', bias_initializer='random_uniform'))


#         predictions = Dense(classes, activation=class_activation, name='Output')(x)

#     return model


In [56]:
def predictionmodel(base_model, dense_sizes=[128], dense_activations=['relu'],
                    dense_kernel_inits=['random_uniform'], dropout_p=[0.5],
                    pooling='avg', classes=1,
                    class_activation='sigmoid'):
    x = base_model.output

    for i, (ds, da, dki, dop) in enumerate(zip(dense_sizes, dense_activations, dense_kernel_inits, dropout_p)):
        x = Dense(ds, activation=da, kernel_initializer=dki)(x)
        name = 'Dropout_Regularization_' + str(i)
        x = Dropout(dop, name=name)(x)

    predictions = Dense(classes, activation=class_activation, name='Output')(x)

    
    
    return predictions


In [57]:
base_model.output

<KerasTensor: shape=(None, 5, 5, 2048) dtype=float32 (created by layer 'mixed10')>

In [58]:
base_model = basemodel(input_shape=(224, 224, 3))
len(base_model.layers)

311

In [59]:
# midmodel = intermediatemodel(base_model, 86)

midmodel = intermediatemodel(base_model, len(base_model.layers) - 1)



input_6
0
conv2d_480
1
batch_normalization_470
2
activation_470
3
conv2d_481
4
batch_normalization_471
5
activation_471
6
conv2d_482
7
batch_normalization_472
8
activation_472
9
max_pooling2d_20
10
conv2d_483
11
batch_normalization_473
12
activation_473
13
conv2d_484
14
batch_normalization_474
15
activation_474
16
max_pooling2d_21
17
conv2d_488
18
batch_normalization_478
19
activation_478
20
conv2d_486
21
conv2d_489
22
batch_normalization_476
23
batch_normalization_479
24
activation_476
25
activation_479
26
average_pooling2d_45
27
conv2d_485
28
conv2d_487
29
conv2d_490
30
conv2d_491
31
batch_normalization_475
32
batch_normalization_477
33
batch_normalization_480
34
batch_normalization_481
35
activation_475
36
activation_477
37
activation_480
38
activation_481
39
mixed0
40
conv2d_495
41
batch_normalization_485
42
activation_485
43
conv2d_493
44
conv2d_496
45
batch_normalization_483
46
batch_normalization_486
47
activation_483
48
activation_486
49
average_pooling2d_46
50
conv2d_492
51
co

In [60]:
base_model.summary()

Model: "inception_v3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_480 (Conv2D)            (None, 111, 111, 32  864         ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_470 (Batch  (None, 111, 111, 32  96         ['conv2d_480[0][0]']             
 Normalization)                 )                                                      

 conv2d_491 (Conv2D)            (None, 25, 25, 32)   6144        ['average_pooling2d_45[0][0]']   
                                                                                                  
 batch_normalization_475 (Batch  (None, 25, 25, 64)  192         ['conv2d_485[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 batch_normalization_477 (Batch  (None, 25, 25, 64)  192         ['conv2d_487[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 batch_normalization_480 (Batch  (None, 25, 25, 96)  288         ['conv2d_490[0][0]']             
 Normalization)                                                                                   
          

                                                                                                  
 conv2d_502 (Conv2D)            (None, 25, 25, 64)   18432       ['mixed1[0][0]']                 
                                                                                                  
 batch_normalization_492 (Batch  (None, 25, 25, 64)  192         ['conv2d_502[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_492 (Activation)    (None, 25, 25, 64)   0           ['batch_normalization_492[0][0]']
                                                                                                  
 conv2d_500 (Conv2D)            (None, 25, 25, 48)   13824       ['mixed1[0][0]']                 
                                                                                                  
 conv2d_50

                                                                                                  
 activation_499 (Activation)    (None, 12, 12, 96)   0           ['batch_normalization_499[0][0]']
                                                                                                  
 max_pooling2d_22 (MaxPooling2D  (None, 12, 12, 288)  0          ['mixed2[0][0]']                 
 )                                                                                                
                                                                                                  
 mixed3 (Concatenate)           (None, 12, 12, 768)  0           ['activation_496[0][0]',         
                                                                  'activation_499[0][0]',         
                                                                  'max_pooling2d_22[0][0]']       
                                                                                                  
 conv2d_51

 mixed4 (Concatenate)           (None, 12, 12, 768)  0           ['activation_500[0][0]',         
                                                                  'activation_503[0][0]',         
                                                                  'activation_508[0][0]',         
                                                                  'activation_509[0][0]']         
                                                                                                  
 conv2d_524 (Conv2D)            (None, 12, 12, 160)  122880      ['mixed4[0][0]']                 
                                                                                                  
 batch_normalization_514 (Batch  (None, 12, 12, 160)  480        ['conv2d_524[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activatio

 conv2d_534 (Conv2D)            (None, 12, 12, 160)  122880      ['mixed5[0][0]']                 
                                                                                                  
 batch_normalization_524 (Batch  (None, 12, 12, 160)  480        ['conv2d_534[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_524 (Activation)    (None, 12, 12, 160)  0           ['batch_normalization_524[0][0]']
                                                                                                  
 conv2d_535 (Conv2D)            (None, 12, 12, 160)  179200      ['activation_524[0][0]']         
                                                                                                  
 batch_normalization_525 (Batch  (None, 12, 12, 160)  480        ['conv2d_535[0][0]']             
 Normaliza

 activation_534 (Activation)    (None, 12, 12, 192)  0           ['batch_normalization_534[0][0]']
                                                                                                  
 conv2d_545 (Conv2D)            (None, 12, 12, 192)  258048      ['activation_534[0][0]']         
                                                                                                  
 batch_normalization_535 (Batch  (None, 12, 12, 192)  576        ['conv2d_545[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_535 (Activation)    (None, 12, 12, 192)  0           ['batch_normalization_535[0][0]']
                                                                                                  
 conv2d_541 (Conv2D)            (None, 12, 12, 192)  147456      ['mixed6[0][0]']                 
          

 Normalization)                                                                                   
                                                                                                  
 activation_543 (Activation)    (None, 12, 12, 192)  0           ['batch_normalization_543[0][0]']
                                                                                                  
 conv2d_550 (Conv2D)            (None, 12, 12, 192)  147456      ['mixed7[0][0]']                 
                                                                                                  
 conv2d_554 (Conv2D)            (None, 12, 12, 192)  258048      ['activation_543[0][0]']         
                                                                                                  
 batch_normalization_540 (Batch  (None, 12, 12, 192)  576        ['conv2d_550[0][0]']             
 Normalization)                                                                                   
          

 Normalization)                                                                                   
                                                                                                  
 conv2d_564 (Conv2D)            (None, 5, 5, 192)    245760      ['average_pooling2d_52[0][0]']   
                                                                                                  
 batch_normalization_546 (Batch  (None, 5, 5, 320)   960         ['conv2d_556[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_548 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_548[0][0]']
                                                                                                  
 activation_549 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_549[0][0]']
          

                                                                                                  
 activation_557 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_557[0][0]']
                                                                                                  
 activation_558 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_558[0][0]']
                                                                                                  
 activation_561 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_561[0][0]']
                                                                                                  
 activation_562 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_562[0][0]']
                                                                                                  
 batch_normalization_563 (Batch  (None, 5, 5, 192)   576         ['conv2d_573[0][0]']             
 Normaliza

In [61]:
midmodel.summary()

Model: "model_8"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_480 (Conv2D)            (None, 111, 111, 32  864         ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_470 (Batch  (None, 111, 111, 32  96         ['conv2d_480[0][0]']             
 Normalization)                 )                                                           

 conv2d_491 (Conv2D)            (None, 25, 25, 32)   6144        ['average_pooling2d_45[0][0]']   
                                                                                                  
 batch_normalization_475 (Batch  (None, 25, 25, 64)  192         ['conv2d_485[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 batch_normalization_477 (Batch  (None, 25, 25, 64)  192         ['conv2d_487[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 batch_normalization_480 (Batch  (None, 25, 25, 96)  288         ['conv2d_490[0][0]']             
 Normalization)                                                                                   
          

                                                                                                  
 conv2d_502 (Conv2D)            (None, 25, 25, 64)   18432       ['mixed1[0][0]']                 
                                                                                                  
 batch_normalization_492 (Batch  (None, 25, 25, 64)  192         ['conv2d_502[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_492 (Activation)    (None, 25, 25, 64)   0           ['batch_normalization_492[0][0]']
                                                                                                  
 conv2d_500 (Conv2D)            (None, 25, 25, 48)   13824       ['mixed1[0][0]']                 
                                                                                                  
 conv2d_50

                                                                                                  
 activation_499 (Activation)    (None, 12, 12, 96)   0           ['batch_normalization_499[0][0]']
                                                                                                  
 max_pooling2d_22 (MaxPooling2D  (None, 12, 12, 288)  0          ['mixed2[0][0]']                 
 )                                                                                                
                                                                                                  
 mixed3 (Concatenate)           (None, 12, 12, 768)  0           ['activation_496[0][0]',         
                                                                  'activation_499[0][0]',         
                                                                  'max_pooling2d_22[0][0]']       
                                                                                                  
 conv2d_51

 mixed4 (Concatenate)           (None, 12, 12, 768)  0           ['activation_500[0][0]',         
                                                                  'activation_503[0][0]',         
                                                                  'activation_508[0][0]',         
                                                                  'activation_509[0][0]']         
                                                                                                  
 conv2d_524 (Conv2D)            (None, 12, 12, 160)  122880      ['mixed4[0][0]']                 
                                                                                                  
 batch_normalization_514 (Batch  (None, 12, 12, 160)  480        ['conv2d_524[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activatio

 conv2d_534 (Conv2D)            (None, 12, 12, 160)  122880      ['mixed5[0][0]']                 
                                                                                                  
 batch_normalization_524 (Batch  (None, 12, 12, 160)  480        ['conv2d_534[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_524 (Activation)    (None, 12, 12, 160)  0           ['batch_normalization_524[0][0]']
                                                                                                  
 conv2d_535 (Conv2D)            (None, 12, 12, 160)  179200      ['activation_524[0][0]']         
                                                                                                  
 batch_normalization_525 (Batch  (None, 12, 12, 160)  480        ['conv2d_535[0][0]']             
 Normaliza

 activation_534 (Activation)    (None, 12, 12, 192)  0           ['batch_normalization_534[0][0]']
                                                                                                  
 conv2d_545 (Conv2D)            (None, 12, 12, 192)  258048      ['activation_534[0][0]']         
                                                                                                  
 batch_normalization_535 (Batch  (None, 12, 12, 192)  576        ['conv2d_545[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_535 (Activation)    (None, 12, 12, 192)  0           ['batch_normalization_535[0][0]']
                                                                                                  
 conv2d_541 (Conv2D)            (None, 12, 12, 192)  147456      ['mixed6[0][0]']                 
          

 Normalization)                                                                                   
                                                                                                  
 activation_543 (Activation)    (None, 12, 12, 192)  0           ['batch_normalization_543[0][0]']
                                                                                                  
 conv2d_550 (Conv2D)            (None, 12, 12, 192)  147456      ['mixed7[0][0]']                 
                                                                                                  
 conv2d_554 (Conv2D)            (None, 12, 12, 192)  258048      ['activation_543[0][0]']         
                                                                                                  
 batch_normalization_540 (Batch  (None, 12, 12, 192)  576        ['conv2d_550[0][0]']             
 Normalization)                                                                                   
          

 Normalization)                                                                                   
                                                                                                  
 conv2d_564 (Conv2D)            (None, 5, 5, 192)    245760      ['average_pooling2d_52[0][0]']   
                                                                                                  
 batch_normalization_546 (Batch  (None, 5, 5, 320)   960         ['conv2d_556[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 activation_548 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_548[0][0]']
                                                                                                  
 activation_549 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_549[0][0]']
          

                                                                                                  
 activation_557 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_557[0][0]']
                                                                                                  
 activation_558 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_558[0][0]']
                                                                                                  
 activation_561 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_561[0][0]']
                                                                                                  
 activation_562 (Activation)    (None, 5, 5, 384)    0           ['batch_normalization_562[0][0]']
                                                                                                  
 batch_normalization_563 (Batch  (None, 5, 5, 192)   576         ['conv2d_573[0][0]']             
 Normaliza

In [62]:
headmodel = predictionmodel(midmodel, pooling=None)


In [63]:
model = keras.Model(inputs=midmodel.input, outputs=headmodel)


In [72]:
model.output

<KerasTensor: shape=(None, 5, 5, 1) dtype=float32 (created by layer 'Output')>

In [73]:
y_train.shape

(7493, 4)

In [64]:
base_model_input_shape = model.input_shape
model.build(input_shape=base_model_input_shape)

In [65]:
from keras.callbacks import ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(
                mode='min',
                monitor='val_loss',
                factor=0.1,
                min_lr=5e-7,
                patience=10,
                verbose=1)
    

In [66]:
from tensorflow.keras.callbacks import EarlyStopping

es = EarlyStopping(
                monitor="val_loss",
                patience=10,
                mode="min",
                restore_best_weights=True,
                verbose=1)

In [67]:
from tensorflow.keras.optimizers import SGD

opt = SGD(learning_rate=0.001,
                      momentum=0.9,
                      name="SGD")

In [68]:
model.compile(opt,loss='binary_crossentropy',metrics=['acc'])

In [69]:
model.fit(x=x_train, y=y_train)

2023-01-25 22:43:18.253749: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


ValueError: in user code:

    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/engine/training.py", line 1051, in train_function  *
        return step_function(self, iterator)
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/engine/training.py", line 1040, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/engine/training.py", line 1030, in run_step  **
        outputs = model.train_step(data)
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/engine/training.py", line 890, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/engine/training.py", line 948, in compute_loss
        return self.compiled_loss(
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 201, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/losses.py", line 139, in __call__
        losses = call_fn(y_true, y_pred)
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/losses.py", line 243, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/losses.py", line 1930, in binary_crossentropy
        backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
    File "/Users/ddyrga/.pyenv/versions/3.10.8/lib/python3.10/site-packages/keras/backend.py", line 5283, in binary_crossentropy
        return tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)

    ValueError: `logits` and `labels` must have the same shape, received ((None, 5, 5, 1) vs (None, 4)).


In [None]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
