In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tf_explain.callbacks.grad_cam import GradCAMCallback

In [None]:
# Every image is resized to this (height, width) when loaded.
image_size = (200, 200)
# Images per batch, used for both the training and validation datasets.
batch_size = 64
# Dataset root. Set the ASL_DATA_DIR environment variable to point elsewhere
# instead of editing this machine-specific default path.
main_directory = Path(os.environ.get("ASL_DATA_DIR", "/media/hdd/Datasets/asl"))

In [None]:
# Training and validation subsets are carved out of the same folder. The split
# fraction and seed must be identical in both calls for the two subsets to be
# complementary (per the image_dataset_from_directory docs), so they are
# defined once and shared.
asl_train_dir = str(main_directory / "asl_alphabet_train")
shared_split_kwargs = dict(
    validation_split=0.2,
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    asl_train_dir, subset="training", **shared_split_kwargs
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    asl_train_dir, subset="validation", **shared_split_kwargs
)

In [None]:
# Display the training dataset's repr to inspect its element structure.
train_ds

In [None]:

# Preview the first nine images of one training batch, titled with their label.
fig = plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for slot in range(1, 10):
        ax = fig.add_subplot(3, 3, slot)
        ax.imshow(images[slot - 1].numpy().astype("uint8"))
        ax.set_title(int(labels[slot - 1]))
        ax.axis("off")

In [None]:
# Random augmentation pipeline: a horizontal flip followed by a small rotation
# (factor 0.1). It is applied as the first stage of the model.
data_augmentation = keras.Sequential()
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.1))

In [None]:
# Show nine independent random augmentations of the same (first) training image.
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    # Only image 0 is displayed, so augment just that one image (kept as a
    # batch of 1) instead of re-augmenting the whole batch nine times.
    first_image = images[:1]
    for i in range(9):
        augmented_images = data_augmentation(first_image)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")

In [None]:
# The datasets are already batched, so prefetch's buffer_size counts *batches*.
# The previous buffer_size=batch_size held 64 full batches (~2 GB of float32
# images per dataset). Let tf.data tune the buffer size instead.
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:


def make_model(input_shape, num_classes):
    """Build a small residual CNN classifier using separable convolutions.

    Layers are created without explicit names, so Keras assigns auto-generated
    names. Any code that looks layers up by such a name depends on the exact
    number and order of layers created here.

    Args:
        input_shape: Shape of one input image, e.g. ``(height, width, channels)``.
            Pixels are rescaled by 1/255 inside the model, so presumably raw
            [0, 255] images are expected.
        num_classes: Number of target classes. When 2, the head is a single
            sigmoid unit; otherwise it is ``num_classes`` softmax units.

    Returns:
        An uncompiled ``keras.Model`` mapping images to class probabilities.
    """
    inputs = keras.Input(shape=input_shape)
    # Image augmentation block
    # Uses the module-level `data_augmentation` model defined in an earlier cell.
    x = data_augmentation(inputs)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Each stage: two ReLU -> SeparableConv -> BN units, then a stride-2 max-pool.
    # The residual is projected with a stride-2 1x1 conv so its channel count
    # and spatial size match the pooled output for the add.
    for size in [128, 256, 512, 728]:
        # NOTE(review): on the first iteration x is already ReLU'd by the entry
        # block, so this ReLU is redundant there. Removing it would change the
        # number (and auto-generated names) of Activation layers.
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Final feature map: this is the last Activation layer before pooling.
    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    # Binary problems get one sigmoid unit; multi-class gets a softmax head.
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)

In [None]:
# One output unit per class sub-directory. This mirrors how
# image_dataset_from_directory infers labels, so the head always matches the
# label range instead of relying on a hard-coded class count.
num_classes = sum(
    1 for entry in (main_directory / "asl_alphabet_train").iterdir() if entry.is_dir()
)
model = make_model(input_shape=image_size + (3,), num_classes=num_classes)
keras.utils.plot_model(model, show_shapes=True)

In [None]:
epochs = 1

# Take exactly one validation batch (images, labels) for the Grad-CAM callback.
# Unpacking `val_ds.take(2)` fetched an extra batch only to discard it, and
# failed when the dataset had fewer than two batches.
im, label = next(iter(val_ds))
label = tf.cast(label, tf.float32)

In [None]:
# Convert the Grad-CAM batch to NumPy. Each value is converted only while it is
# still a tensor, so re-running this cell does not raise AttributeError on
# arrays that were already converted.
if tf.is_tensor(im):
    im = im.numpy()
if tf.is_tensor(label):
    label = label.numpy()

In [None]:

# Grad-CAM target: the last Activation layer of the model. For make_model this
# is the final ReLU feature map feeding GlobalAveragePooling2D. It is looked up
# by type rather than by a hard-coded auto-generated name such as
# "activation_10", because Keras numbers unnamed layers per session and that
# number shifts whenever the model is rebuilt.
gradcam_layer_name = [
    layer.name for layer in model.layers if isinstance(layer, layers.Activation)
][-1]

callbacks = [
    keras.callbacks.ModelCheckpoint("./logs/save_at_{epoch}.h5"),
    keras.callbacks.ProgbarLogger(count_mode="samples", stateful_metrics=None),
    GradCAMCallback(
        layer_name=gradcam_layer_name,
        class_index=0,
        output_dir="./logs/",
        validation_data=(im, label),
    ),
]
# Sparse cross-entropy: assumes labels are integer class indices (not one-hot).
loss_fn = keras.losses.SparseCategoricalCrossentropy()
opt = keras.optimizers.Adam(1e-3)

model.compile(
    optimizer=opt,
    loss=loss_fn,
    metrics=["accuracy"],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)