In [20]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Dense, Conv2D, MaxPool2D, Flatten, BatchNormalization
from tensorflow.keras.models import Model
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Dense, Conv2D, MaxPool2D, Flatten, BatchNormalization
from tensorflow.keras.models import Model
import tensorflow_datasets as tfds

# Load the dataset
dataset, dataset_info = tfds.load('malaria', with_info=True, as_supervised=True, shuffle_files=True, split='train')

def reshape_resize(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.cast(image, tf.float32) / 255.0
    label = tf.cast(label, tf.float32) 
    return image, label

# Split the dataset
TRAIN_RATIO = 0.6
VAL_RATIO = 0.2
TEST_RATIO = 0.2
LEN = len(dataset)

train_size = int(LEN * TRAIN_RATIO)
val_size = int(LEN * VAL_RATIO)

train_ds = dataset.take(train_size)
remaining_ds = dataset.skip(train_size)
val_ds = remaining_ds.take(val_size)
test_ds = remaining_ds.skip(val_size)

# Preprocess the datasets
train_ds = train_ds.map(reshape_resize).batch(32)
val_ds = val_ds.map(reshape_resize).batch(32)
test_ds = test_ds.map(reshape_resize).batch(32)

class FeatureExtractor(Layer):
    def __init__(self, filters_1, filters_2, kernel_size, strides, padding, activation, pool_size):
        super(FeatureExtractor, self).__init__()
        self.conv_1 = Conv2D(filters=filters_1, kernel_size=kernel_size, strides=strides, padding=padding, activation=activation)
        self.batch_1 = BatchNormalization()
        self.pool_1 = MaxPool2D(pool_size=pool_size, strides=strides*2)
        self.conv_2 = Conv2D(filters=filters_2, kernel_size=kernel_size, strides=strides, padding=padding, activation=activation)
        self.batch_2 = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=pool_size, strides=strides*2)

    def call(self, x, training=None):
        x = self.conv_1(x)
        x = self.batch_1(x, training=training)
        x = self.pool_1(x)
        x = self.conv_2(x)
        x = self.batch_2(x, training=training)
        x = self.pool_2(x)
        return x

class LenetModel(Model):
    def __init__(self):
        super(LenetModel, self).__init__()
        self.feature_extractor = FeatureExtractor(6, 16, (5, 5), 1, "valid", "relu", (2, 2))
        self.flatten = Flatten()
        self.dense_1 = Dense(120, activation="relu")
        self.batch_1 = BatchNormalization()
        self.dense_2 = Dense(84, activation="relu")
        self.batch_2 = BatchNormalization()
        self.output_layer = Dense(1, activation="sigmoid")

    def call(self, x, training=None):
        x = self.feature_extractor(x, training=training)
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.batch_1(x, training=training)
        x = self.dense_2(x)
        x = self.batch_2(x, training=training)
        x = self.output_layer(x)
        return x
        
# Create the model
func_input = Input(shape=(224, 224, 3), name="input")
feature_sub_classed = LenetModel()
func_output = feature_sub_classed(func_input)
model = Model(inputs=func_input, outputs=func_output, name="ff")

# Print model summary
model.summary()

# Compile and train the model
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(train_ds, validation_data=val_ds, epochs=3, verbose=1)


Epoch 1/3


2024-11-30 06:31:08.187590: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:376] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 8388608
I0000 00:00:1732948268.265666     671 service.cc:148] XLA service 0x7fdb64012c70 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1732948268.266716     671 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 4060 Laptop GPU, Compute Capability 8.9
2024-11-30 06:31:08.330880: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.


[1m  7/517[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m7s[0m 16ms/step - accuracy: 0.5899 - loss: 0.9026

I0000 00:00:1732948271.217768     671 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 26ms/step - accuracy: 0.6604 - loss: 0.6374 - val_accuracy: 0.6627 - val_loss: 0.8685
Epoch 2/3
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 15ms/step - accuracy: 0.8750 - loss: 0.3022 - val_accuracy: 0.8938 - val_loss: 0.3143
Epoch 3/3
[1m517/517[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 15ms/step - accuracy: 0.9277 - loss: 0.1949 - val_accuracy: 0.9225 - val_loss: 0.2389


<keras.src.callbacks.history.History at 0x7fdb680e6e70>