# Data handling

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

Source: https://data.mendeley.com/datasets/tywbtsjrjv/1 \
Task: image classification \
Description: 54303 images of healthy and unhealthy plant leaves, divided into 38 categories by species and state of health.

## Data distribution

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np 
# Apply seaborn's default theme to every matplotlib figure drawn afterwards.
sns.set_theme()

In [None]:
# Load the 'train' split together with its metadata. Elements are feature
# dicts (the counting cell below reads their 'label' entry).
ds, ds_info = tfds.load('plant_village', split='train', with_info=True)

In [None]:
n_classes = 38


def count_class(counts, batch, num_classes=n_classes):
    """Add the labels of one dataset element to the running per-class counts.

    Meant to be used as the ``reduce_func`` of ``tf.data.Dataset.reduce``.

    Args:
        counts: dict mapping class index -> running int32 count tensor.
        batch: dataset element; only ``batch['label']`` is read. A scalar
            label or a tensor of labels are both handled (``reduce_sum``).
        num_classes: number of class indices (0 .. num_classes-1) to count.

    Returns:
        A new dict with the same keys and the updated counts.
    """
    labels = batch['label']
    return {
        i: counts[i] + tf.reduce_sum(tf.cast(labels == i, tf.int32))
        for i in range(num_classes)
    }


initial_state = {i: 0 for i in range(n_classes)}
count_tensors = ds.reduce(initial_state=initial_state,
                          reduce_func=count_class)

# Plain Python ints ordered by label index (counts[i] is the number of images
# with label i). Both plt.bar(height=counts) and the class-weight computation
# (enumerate(counts)) need this shape, not a list of (key, tensor) tuples.
counts = [int(count_tensors[i].numpy()) for i in range(n_classes)]

# Previously computed per-class counts (they sum to 54303, the dataset size
# quoted above); assign them instead of running the reduce to skip the full
# pass over the dataset:
# counts = [630, 621, 275, 1645, 1502, 854, 1052, 513, 1192, 1162,
#           985, 1180, 1383, 423, 1076, 5507, 2297, 360, 997, 1477,
#           1000, 152, 1000, 371, 5090, 1835, 456, 1109, 2127, 1000,
#           1591, 1908, 952, 1771, 1676, 1404, 373, 5357]

In [None]:
plt.rcParams.update({'figure.figsize': (23, 10), 'figure.dpi': 300})

# One bar per label index showing how many images carry that label.
ax = plt.gca()
ax.bar(x=np.arange(n_classes), height=counts)
ax.set_xticks(range(n_classes))
ax.set(title='Label Frequency', xlabel='Label index', ylabel='Count')
plt.show()

In [None]:
# Reload the 'train' split, this time as (image, label) pairs for previewing.
ds, ds_info = tfds.load(
    'plant_village', split='train', with_info=True, as_supervised=True,
)

In [None]:
label_to_name = ds_info.features['label'].names

# Draw the first 15 samples in a 3x5 grid, each captioned with its class name
# made readable: '___' becomes ' - ' and remaining '_' become spaces.
plt.figure()
samples = ds.take(15).as_numpy_iterator()
for slot, (img, lbl) in enumerate(samples, start=1):
    plt.subplot(3, 5, slot)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(img)
    caption = label_to_name[lbl].replace('___', ' - ').replace('_', ' ')
    plt.xlabel(caption)

## Data split

In [None]:
batch_size = 1566  # = 1054 + 512 examples per batch

In [None]:
# Construct a tf.data.Dataset
(ds_train, ds_val, ds_test), ds_info = tfds.load(
                            'plant_village',
                            split=['train[:75%]', 'train[75%:95]', 'train[95%:]'],
                            as_supervised=True,
                            with_info=True,
                            batch_size=batch_size,
                            shuffle_files=True
                            )

# Preprocessing

In [None]:
def random_augmentations_fn(image, label):
    """Randomly jitter brightness and contrast, and randomly mirror the image.

    Args:
        image: image tensor (or batch of images) to augment.
        label: passed through untouched.

    Returns:
        ``(augmented_image, label)``.

    NOTE(review): tf.image.random_brightness / random_contrast draw a single
    random delta/factor per call, so when this is mapped over batches (as in
    this notebook) every image of a batch gets the same adjustment. Confirm
    this is intended.
    """
    brightened = tf.image.random_brightness(image, max_delta=0.1)
    contrasted = tf.image.random_contrast(brightened, lower=0.8, upper=1.2)
    mirrored = tf.image.random_flip_left_right(contrasted)
    return mirrored, label

In [None]:
# Cache the loaded batches *before* augmenting. Caching after the random
# `map` would store one augmented copy on the first epoch and replay it on
# every later epoch, so the augmentations would only ever be sampled once.
# NOTE(review): cache() without a filename keeps the whole training split in
# memory.
ds_train = ds_train.cache()
ds_train = ds_train.map(random_augmentations_fn,
                        num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

## Class weighting

In [None]:
# source: https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
# "balanced" heuristic: weight[c] = n_samples / (n_classes * n_samples_of_c),
# so rare classes weigh more in the loss.

# Derive the total from the per-class counts rather than hard-coding it, so
# the two can never disagree (for plant_village this is 54303).
dataset_length = sum(counts)

class_weights = {c: dataset_length / (n_classes * c_len)
                 for c, c_len in enumerate(counts)}

# Model building

In [None]:
from tensorflow.keras import layers, initializers, regularizers

In [None]:
def add_conv_block(x, n_filters, strides=1):
    """Append a Conv2D(3x3) -> BatchNormalization -> ReLU block to ``x``.

    Args:
        x: Keras tensor to build on.
        n_filters: number of convolution filters.
        strides: convolution stride; with ``padding='same'`` a stride of 2
            halves the spatial size (rounded up).

    Returns:
        The block's output Keras tensor.
    """
    # 1e-3 is the same value as the previous literal 10e-4.
    # NOTE(review): if 1e-4 was intended, change it here.
    x = layers.Conv2D(n_filters, 3, padding='same', strides=strides,
                      kernel_regularizer=regularizers.L2(1e-3))(x)
    x = layers.BatchNormalization()(x)
    # Use a Keras layer rather than tf.nn.relu. A raw TF op on a symbolic
    # Keras tensor only works through tf.keras' automatic op wrapping and
    # raises an error under Keras 3.
    return layers.ReLU()(x)

In [None]:
def build_model(num_classes=38, input_shape=(256, 256, 3)):
    """Build the small CNN classifier.

    The network has three stride-2 conv blocks (16/32/64 filters), global
    average pooling, and a dense head. The head outputs raw logits (no
    softmax), matching the ``from_logits=True`` loss used in this notebook.

    Args:
        num_classes: size of the output layer (38 plant_village categories).
        input_shape: (height, width, channels) of the input images.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    # Input part
    inp = layers.Input(shape=input_shape)

    # Convolutional part: every block uses stride 2, halving the resolution.
    x = add_conv_block(inp, 16, 2)
    x = layers.Dropout(0.25)(x)
    x = add_conv_block(x, 32, 2)
    x = layers.Dropout(0.25)(x)
    x = add_conv_block(x, 64, 2)

    x = layers.GlobalAvgPool2D()(x)

    # Classification block. The hidden Dense layer needs a non-linearity.
    # Without one, Dense(64) -> Dense(num_classes) collapses into a single
    # linear map and the hidden layer adds no capacity.
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    logits = layers.Dense(num_classes)(x)
    return tf.keras.Model(inputs=inp, outputs=logits)

In [None]:
# Instantiate the network; it is compiled further below.
model = build_model()

In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

In [None]:
from tensorflow.keras import losses, metrics, optimizers, callbacks

In [None]:
# Integer labels and raw logits from the model -> sparse loss/metric with
# from_logits=True.
cross_entropy = losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy = metrics.SparseCategoricalAccuracy()

lr = 10e-4  # == 1e-3
n_epochs = 150
# Inverse-time decay applied per optimizer step:
#     lr_step = lr / (1 + (lr / n_epochs) * step)
# This is exactly what the legacy `Adam(decay=lr / n_epochs)` computed. The
# `decay` argument itself is rejected by the optimizers in TF >= 2.11 and
# Keras 3. NOTE(review): the decay advances per batch, not per epoch, just as
# with the legacy argument.
lr_schedule = optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=lr,
    decay_steps=1,
    decay_rate=lr / n_epochs,
)
optimizer = optimizers.Adam(learning_rate=lr_schedule)

In [None]:
# Attach optimizer, loss and the accuracy metric to the model.
model.compile(optimizer=optimizer, loss=cross_entropy, metrics=[accuracy])

In [None]:
# Train with per-class weights; ds_val is evaluated at the end of every epoch.
history = model.fit(
    ds_train,
    epochs=n_epochs,
    validation_data=ds_val,
    class_weight=class_weights,
)

In [None]:
# Loss and accuracy on the held-out test split (the last 5% of 'train').
model.evaluate(ds_test)

In [None]:
# Training vs. validation loss per epoch.
plt.figure()
for history_key, legend_text in (("loss", 'Training loss'),
                                 ("val_loss", 'Validation loss')):
    plt.plot(history.history[history_key], label=legend_text)
plt.title("Loss function")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Training vs. validation accuracy per epoch. The history keys are the
# metric's default name, with a "val_" prefix for the validation split.
plt.figure()
plt.plot(history.history["sparse_categorical_accuracy"], label='Training accuracy')
plt.plot(history.history["val_sparse_categorical_accuracy"], label='Validation accuracy')
plt.title("Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")  # previously mislabeled "Loss"
plt.legend()
plt.show()