In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the malaria data as (image, label) pairs: the first 80% of the 'train'
# split is used for training, the remaining 20% for validation.
train_ds, validation_ds = tfds.load(
    'malaria',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
)

# Preprocess the dataset
def preprocess_image(image, label):
    """Resize an image to 128x128 and rescale its pixel values by 1/255.

    Args:
        image: Image tensor from the dataset. NOTE(review): presumably uint8
            with values 0-255 (the /255 scaling assumes this) — confirm.
        label: Class label; returned unchanged.

    Returns:
        An ``(image, label)`` tuple whose image is a float32 tensor with
        spatial size 128x128, divided by 255.
    """
    pixels = tf.cast(tf.image.resize(image, (128, 128)), tf.float32)
    return pixels / 255.0, label

# Pipeline settings.
SHUFFLE_BUFFER = 1000  # number of examples held in the shuffle buffer
BATCH_SIZE = 32

# Training data: shuffle first so the model does not see examples in the same
# fixed order every epoch (Dataset.shuffle reshuffles on each iteration by
# default). Shuffling before `map` keeps raw images in the buffer rather than
# the float32 tensors produced by preprocess_image.
train_ds = (
    train_ds
    .shuffle(SHUFFLE_BUFFER)
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation data is only evaluated, so its order does not matter.
validation_ds = (
    validation_ds
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Build the model: three conv + max-pool stages followed by a dense classifier.
# The input shape must match the 128x128 size produced by preprocess_image;
# the model expects 3-channel images. Declaring the shape with a leading
# tf.keras.Input (instead of `input_shape=` on the first layer) is the form
# recommended by Keras for Sequential models.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(128, 128, 3)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    # One softmax unit per class id (0 and 1). NOTE(review): which id means
    # parasitized vs. uninfected should be checked against the dataset's
    # label names.
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Compile the model. sparse_categorical_crossentropy takes integer class ids
# (not one-hot vectors) as targets.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Train the model; Keras reports loss/accuracy on validation_ds after each epoch.
history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=1,
)

In [None]:
# Build the evaluation pipeline. Only the evaluation slice is loaded; reloading
# into `train_ds` here would replace the shuffled/preprocessed/batched training
# pipeline with a raw, unbatched dataset and break any re-run of the training
# cell.
# NOTE: 'train[80%:]' is the same slice used for validation_ds, so the accuracy
# below is validation accuracy, not an independent held-out test estimate.
test_ds = (
    tfds.load('malaria', split='train[80%:]', as_supervised=True)
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluate the model
loss, accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {accuracy:.2f}")

In [None]:
import matplotlib.pyplot as plt

# Show the first batch of test_ds as a single grid of labelled images.
# test_ds images were passed through preprocess_image, so they are float32
# values already divided by 255, which imshow can render directly.
MAX_COLS = 8  # images per row in the grid

for images, labels in test_ds.take(1):  # take one batch
    n_images = len(images)
    n_cols = min(MAX_COLS, n_images)
    n_rows = (n_images + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False
    )
    for ax, image, label in zip(axes.flat, images, labels):
        ax.imshow(image.numpy())
        # Convert the scalar label tensor to a plain int; formatting the tensor
        # itself would print its repr ("tf.Tensor(..., shape=(), dtype=...)").
        ax.set_title(f"Label: {int(label.numpy())}")
    for ax in axes.flat:  # also hides any unused cells in the last row
        ax.axis('off')
    fig.tight_layout()
    plt.show()