In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense, Flatten, InputLayer, BatchNormalization, Input
from tensorflow.keras.losses import MeanSquaredError, Huber, MeanAbsoluteError, BinaryCrossentropy
from tensorflow.keras.metrics import RootMeanSquaredError
from tensorflow.keras.optimizers import Adam

### Preparing Data

In [None]:
# The train/val/test splits below are carved out positionally with take()/skip(),
# which only works if every pass over the dataset yields the same element order.
# With shuffle_files=True, tfds may reshuffle the shard order on each iteration,
# letting examples drift between splits across epochs — so keep file order fixed
# and shuffle only the training split later.
dataset, dataset_info = tfds.load('malaria', with_info=True, as_supervised=True, shuffle_files=False, split=['train'])

In [None]:
# Display the loaded splits: a list with a single element, since only split=['train'] was requested.
dataset

In [None]:
# Dataset metadata returned by tfds.load(with_info=True); features['label'] is used below to turn label ids into names.
dataset_info

In [None]:
# Peek at one raw (image, label) pair from the train split; take(1) yields at most one element.
for raw_example in dataset[0].take(1):
  print(raw_example)

In [None]:
def splits(dataset, train_ratio, val_ratio, test_ratio):
  """Split a finite dataset into consecutive train / validation / test pieces.

  The first int(train_ratio * N) elements become the train split, the next
  int(val_ratio * N) the validation split, and everything left over the test
  split. Giving the remainder to test means no element is lost to integer
  truncation. `test_ratio` is therefore only used to check that the three ratios
  describe the whole dataset.

  Args:
    dataset: an object with a known length supporting len(), take(n) and skip(n)
      (e.g. a finite tf.data.Dataset).
    train_ratio: fraction of elements for training (>= 0).
    val_ratio: fraction of elements for validation (>= 0).
    test_ratio: fraction of elements for testing (>= 0).

  Returns:
    A (train_dataset, val_dataset, test_dataset) tuple.

  Raises:
    ValueError: if any ratio is negative or the ratios do not sum to 1.
  """
  ratios = (train_ratio, val_ratio, test_ratio)
  if any(ratio < 0 for ratio in ratios):
    raise ValueError(f'split ratios must be non-negative, got {ratios}')
  # Tolerance absorbs float rounding, e.g. 0.8 + 0.1 + 0.1.
  if abs(sum(ratios) - 1.0) > 1e-6:
    raise ValueError(f'split ratios must sum to 1, got {ratios} (sum={sum(ratios)})')

  dataset_size = len(dataset)
  n_train = int(train_ratio * dataset_size)
  n_val = int(val_ratio * dataset_size)

  train_dataset = dataset.take(n_train)

  val_test_dataset = dataset.skip(n_train)
  val_dataset = val_test_dataset.take(n_val)

  # Everything after train + val (including truncation leftovers) is the test split.
  test_dataset = val_test_dataset.skip(n_val)

  return train_dataset, val_dataset, test_dataset

In [None]:
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1

train_dataset, val_dataset, test_dataset = splits(dataset[0], TRAIN_RATIO, VAL_RATIO, TEST_RATIO)

# Summarize the split sizes rather than dumping whole image arrays into the output.
split_sizes = {'train': len(train_dataset), 'val': len(val_dataset), 'test': len(test_dataset)}
assert sum(split_sizes.values()) == len(dataset[0]), 'train/val/test must cover the whole dataset'
split_sizes

### Data Visualization

In [None]:
# Preview the first 16 training images in a 4x4 grid, titled with their class names.
for panel_idx, (image, label) in enumerate(train_dataset.take(16), start=1):
  ax = plt.subplot(4, 4, panel_idx)
  plt.imshow(image)
  class_name = dataset_info.features['label'].int2str(label)
  plt.title(class_name)
  plt.axis('off')

### Data Preprocessing

In [None]:
IM_SIZE = 224

def resizing_rescaling(image, label):
  """Resize an image to IM_SIZE x IM_SIZE and divide its pixel values by 255.

  The label is passed through unchanged, so this can be used directly with
  Dataset.map on (image, label) pairs.
  """
  resized_image = tf.image.resize(image, (IM_SIZE, IM_SIZE))
  scaled_image = resized_image / 255.0
  return scaled_image, label

In [None]:
# Resize + rescale every split. num_parallel_calls lets tf.data run the map on
# several cores; element order is still preserved because tf.data maps are
# deterministic unless that option is explicitly disabled.
train_dataset = train_dataset.map(resizing_rescaling, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.map(resizing_rescaling, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.map(resizing_rescaling, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset

In [None]:
# Sanity-check one preprocessed training example (resized to IM_SIZE and divided by 255).
for sample_image, sample_label in train_dataset.take(1):
  print(sample_image, sample_label)

In [None]:
BATCH_SIZE = 32
# A buffer of only 8 elements barely reorders the data between epochs. Each
# buffered element here is an already-resized float32 image
# (224 * 224 * 3 * 4 bytes ≈ 0.6 MB), so keep the buffer moderate to bound memory.
SHUFFLE_BUFFER_SIZE = 256

train_dataset = (
    train_dataset
    .shuffle(buffer_size=SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation loss/accuracy do not depend on element order, so validation is not shuffled.
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

### Functional API

In [None]:
# Feature extractor: two Conv2D -> BatchNorm -> MaxPool stages (LeNet-style).
func_input = Input(shape=(IM_SIZE, IM_SIZE, 3), name='Input Image')

x = Conv2D(filters=6, kernel_size=3, strides=1, padding='valid', activation='relu')(func_input)
x = BatchNormalization()(x)
# No trailing comma: `...(x),` would wrap the tensor in a 1-tuple.
x = MaxPool2D(pool_size=2, strides=2)(x)

x = Conv2D(filters=16, kernel_size=3, strides=1, padding='valid', activation='relu')(x)
x = BatchNormalization()(x)
feature_output = MaxPool2D(pool_size=2, strides=2)(x)

# The model must end at this cell's own final tensor. `func_output` is only
# defined by the classifier cell below: using it here fails on a fresh kernel,
# and on a re-run it would silently wrap the entire classifier.
feature_extractor_model = Model(func_input, feature_output, name='Feature_Extractor')
feature_extractor_model.summary()

In [None]:
# Classifier head on top of the feature extractor: Flatten, then two
# Dense(relu) + BatchNorm stages (100 then 10 units), then one sigmoid unit.
x = feature_extractor_model(func_input)
x = Flatten()(x)

for hidden_units in (100, 10):
  x = Dense(hidden_units, activation='relu')(x)
  x = BatchNormalization()(x)

func_output = Dense(1, activation='sigmoid')(x)

lenet_model_func = Model(func_input, func_output, name='Lenet_Model')
lenet_model_func.summary()

In [None]:
# `metrics` is passed as a list: Keras 3's compile() rejects a bare string
# (it requires a list, tuple or dict), while a list also works on tf.keras 2.x.
lenet_model_func.compile(optimizer=Adam(learning_rate=0.01),
              loss=BinaryCrossentropy(),
              metrics=['accuracy'],
              )

In [None]:
# Train on the batched train split, reporting validation metrics after each epoch.
EPOCHS = 20
history = lenet_model_func.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    verbose=1,
)

In [None]:
# Training vs. validation loss per epoch.
for history_key in ('loss', 'val_loss'):
  plt.plot(history.history[history_key])
plt.title('Model Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train_loss', 'val_loss'])
plt.show()

### Model Evaluation and Testing

In [None]:
# Batch test examples one at a time. The guard keeps this cell idempotent:
# unbatched images are rank 3 (height, width, channels), so re-running it does
# not nest a second batch dimension.
if test_dataset.element_spec[0].shape.rank == 3:
  test_dataset = test_dataset.batch(1)

In [None]:
# Loss and accuracy (the compiled metric) on the held-out test split.
lenet_model_func.evaluate(test_dataset)

In [None]:
# Raw sigmoid score for the first test image ([0][0] = first sample, single output unit).
lenet_model_func.predict(test_dataset.take(1))[0][0]

In [None]:
def parasite_or_not(x):
  """Map a sigmoid score (or a 0/1 label) to a one-letter class tag.

  Values below 0.5 become 'P'; everything else (including 0.5 and NaN) becomes
  'U'. Presumably label 0 is the parasitized class — confirm against
  dataset_info.features['label'].names.
  """
  below_threshold = x < 0.5
  return 'P' if below_threshold else 'U'

In [None]:
# The first test image's prediction, converted to a 'P'/'U' tag.
parasite_or_not(lenet_model_func.predict(test_dataset.take(1))[0][0])

In [None]:
# Show 9 test images titled "<true tag>:<predicted tag>" using parasite_or_not's 'P'/'U' tags.
for panel_idx, (image, label) in enumerate(test_dataset.take(9)):
  ax = plt.subplot(3, 3, panel_idx + 1)
  plt.imshow(image[0])
  true_tag = parasite_or_not(label.numpy()[0])
  predicted_tag = parasite_or_not(lenet_model_func.predict(image)[0][0])
  plt.title(f'{true_tag}:{predicted_tag}')

  plt.axis('off')

In [None]:
# Save the trained model to a single native-format .keras file in the working directory.
lenet_model_func.save('lenet_model_func.keras')

In [None]:
# Reload the saved file to confirm it deserializes, and print its architecture.
lenet_loaded_model = tf.keras.models.load_model('lenet_model_func.keras')
lenet_loaded_model.summary()