# CIFAR (TensorFlow)

[![Open in Colab](https://lab.aef.me/files/assets/colab-badge.svg)](https://colab.research.google.com/github/adamelliotfields/lab/blob/main/files/tf/cifar.ipynb)
[![Open in Kaggle](https://lab.aef.me/files/assets/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/adamelliotfields/lab/blob/main/files/tf/cifar.ipynb)
[![Render nbviewer](https://lab.aef.me/files/assets/nbviewer_badge.svg)](https://nbviewer.org/github/adamelliotfields/lab/blob/main/files/tf/cifar.ipynb)

In 2006, researchers from MIT CSAIL created the [Tiny Images](https://people.csail.mit.edu/billf/papers/80millionImages.pdf) dataset. They used over 50,000 nouns from [WordNet](https://en.wikipedia.org/wiki/WordNet) to search for images on the web and subsequently label them, resulting in 80 million 32x32 images. Unfortunately, due to the automated nature of the search, many [inappropriate](https://openreview.net/pdf?id=s-e2zaAlG3I) words and images ended up in the dataset. The dataset was taken down in [2020](https://groups.csail.mit.edu/vision/TinyImages/).

[CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html) are labeled subsets of the Tiny Images dataset. Each dataset has 60,000 32x32 color images, 10,000 of which are reserved for testing. CIFAR-10 has 10 classes, while CIFAR-100 has 100. Both were introduced in [Learning Multiple Layers of Features from Tiny Images](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) (Krizhevsky, 2009).

This notebook is about _transfer learning_: we start with a model pretrained on ImageNet (ResNet50V2), remove its top (classification) layer, freeze its weights, and train a new output layer on CIFAR-10. Each image is upscaled and fed through the frozen base model, and the extracted features are passed to the new layer for learning.

After the new layer has been trained, we unfreeze the base model and continue training the whole network at a much lower learning rate (100x lower here) for a short additional run (a single epoch in this notebook). This is _fine-tuning_.

Read the [Keras guide](https://keras.io/guides/transfer_learning/) for more.

In [None]:
import os
from importlib.util import find_spec

# These must be set before TensorFlow/Keras are imported in a later cell to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # silence TF C++ INFO and WARNING log lines
os.environ["KERAS_BACKEND"] = "tensorflow"

# find_spec on a dotted name imports the parent package first, so it raises
# ModuleNotFoundError (instead of returning None) when no "google" package exists.
try:
    in_colab = find_spec("google.colab") is not None
except ModuleNotFoundError:
    in_colab = False

# On Colab, point TFDS at a Drive folder so downloads persist across sessions
# (presumes Google Drive is mounted at /content/drive — confirm before running).
if in_colab:
    os.environ["TFDS_DATA_DIR"] = "/content/drive/MyDrive/tensorflow_datasets"

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from keras import Input, Model, Sequential, applications, callbacks, layers, losses, optimizers

In [None]:
# Shell alternative on NVIDIA machines:
# !nvidia-smi --query-gpu=name,memory.total,memory.used,memory.free,utilization.gpu,utilization.memory --format=csv
# Print each physical device TensorFlow can see (no device-type filter), one per line.
for physical_device in tf.config.list_physical_devices():
    print(physical_device)

In [None]:
# @title Config
# Notebook-wide settings. The "@param" comments are Colab form annotations, so keep each
# assignment on its own line in this exact shape.
SEED = 42  # @param {type:"integer"}
# number of epochs for the first (frozen-base) training run
EPOCHS = 15  # @param {type:"integer"}
VERBOSE = 1  # @param {type:"integer"}
# 100 loads CIFAR-100; any other value loads CIFAR-10. Also used as the one-hot depth
# and as the width of the output layer.
CLASSES = 10  # @param {type:"integer"}
# side length the 32x32 images are resized to before entering the pretrained base
IMAGE_SIZE = 75  # @param {type:"integer"}
BATCH_SIZE = 128  # @param {type:"integer"}
# Split sizes: the first TRAIN_SIZE examples of the 50,000-image train split are used for
# training and the next VAL_SIZE for validation, so their sum should not exceed 50,000.
VAL_SIZE = 5000  # @param {type:"integer"}
TEST_SIZE = 10000  # @param {type:"integer"}
TRAIN_SIZE = 45000  # @param {type:"integer"}
# AdamW settings for the first run; the fine-tuning run divides both by 100
LEARNING_RATE = 0.001  # @param {type:"number"}
WEIGHT_DECAY = 0.004  # @param {type:"number"}

In [None]:
# @title Lib
def show_cifar_images(
    ds,
    *,
    rows=1,
    size=2,
    columns=5,
    figsize=None,
    label_decoder=None,
):
    """Plot the first ``rows * columns`` (image, label) pairs of ``ds`` in a grid.

    Args:
        ds: a ``tf.data.Dataset`` yielding ``(image, label)`` pairs one example at a
            time (size-1 dimensions are squeezed away). Labels may be integer class
            ids or one-hot vectors.
        rows: number of grid rows.
        size: width/height in inches of each grid cell when ``figsize`` is not given.
        columns: number of grid columns.
        figsize: explicit matplotlib figure size; overrides ``size``.
        label_decoder: optional callable mapping an integer class id to a title
            string. When omitted, the class id itself is used as the title.
    """
    if figsize is None:
        figsize = (columns * size, rows * size)

    # squeeze=False always yields a 2-D array of Axes, so [row, col] indexing works
    # for every grid shape, including a single row, a single column or a single cell
    _, axes = plt.subplots(rows, columns, figsize=figsize, squeeze=False)

    # hide all frames up front so cells left empty (dataset shorter than the grid)
    # don't show as blank boxes
    for ax in axes.flat:
        ax.axis("off")

    for i, (image, label) in enumerate(ds.take(rows * columns)):
        image = image.numpy().squeeze()

        # rescale from -1, 1 to 0, 1
        if image.min() < 0:
            image = (image + 1) / 2

        # rescale from 0, 1 to 0, 255
        if not image.max() > 1:
            image = image * 255

        ax = axes[i // columns, i % columns]
        ax.imshow(image.astype("uint8"))

        # reduce one-hot labels to their class index, then to a plain Python int
        label = tf.squeeze(label)
        if tf.rank(label) > 0:
            label = tf.argmax(label)
        label = int(label)

        title = label_decoder(label) if label_decoder is not None else f"{label}"
        ax.set_title(title)

    plt.tight_layout()
    plt.show()


def prepare_cifar_image(image, label):
    """Map function for ``(image, label)`` pairs: one-hot encode the label.

    The image is returned untouched; ``label`` becomes a one-hot vector of length
    ``CLASSES``.
    """
    return image, tf.one_hot(label, depth=CLASSES)


# https://www.cs.toronto.edu/~kriz/cifar.html
def decode_cifar_label(label, classes=None):
    """Return a human-readable ``"name (id)"`` string for a CIFAR class id.

    Args:
        label: integer class id (Python int, NumPy integer, or scalar integer tensor).
        classes: 100 selects the CIFAR-100 names; any other value selects CIFAR-10.
            Defaults to the notebook-wide ``CLASSES`` setting.

    Raises:
        IndexError: if ``label`` is outside ``[0, number of classes)``.
        TypeError: if ``label`` is not an integer (e.g. a float).
    """
    import operator

    # fmt: off
    labels_10 = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    labels_100 = [
        "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle",
        "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle",
        "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab", "crocodile", "cup",
        "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house",
        "kangaroo", "computer_keyboard", "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster",
        "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom", "oak_tree", "orange",
        "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree", "plain", "plate", "poppy",
        "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal",
        "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar",
        "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor",
        "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"
    ]
    # fmt: on
    if classes is None:
        classes = CLASSES
    labels = labels_100 if classes == 100 else labels_10

    # operator.index accepts ints, NumPy integers and integer scalar tensors, and
    # rejects floats, giving a plain int for both indexing and formatting
    index = operator.index(label)
    # reject negative ids instead of silently wrapping around to the end of the list
    if not 0 <= index < len(labels):
        raise IndexError(f"CIFAR label {index} is out of range for {len(labels)} classes")
    return f"{labels[index]} ({index})"

In [None]:
# @title Data
# Raw (uint8 image, integer label) splits from TFDS; CLASSES picks CIFAR-10 or CIFAR-100.
(cifar_train, cifar_test), cifar_info = tfds.load(
    "cifar100" if CLASSES == 100 else "cifar10",
    split=["train", "test"],
    as_supervised=True,
    with_info=True,
)


def _batch_and_prefetch(ds):
    """Batch ``ds`` by ``BATCH_SIZE`` and let tf.data prefetch batches ahead of the model."""
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# training: the first TRAIN_SIZE examples of the train split, one-hot encoded, then
# shuffled (seeded) with a buffer the size of the whole train split
X_train = _batch_and_prefetch(
    cifar_train.take(TRAIN_SIZE)
    .map(prepare_cifar_image)
    .shuffle(seed=SEED, buffer_size=cifar_train.cardinality())
)
# validation: the next VAL_SIZE examples of the train split, in file order
X_val = _batch_and_prefetch(cifar_train.skip(TRAIN_SIZE).take(VAL_SIZE).map(prepare_cifar_image))
# test: the first TEST_SIZE examples of the test split
X_test = _batch_and_prefetch(cifar_test.take(TEST_SIZE).map(prepare_cifar_image))

# preview raw training images (before one-hot encoding and batching)
show_cifar_images(cifar_train, rows=2, label_decoder=decode_cifar_label)

In [None]:
# Preprocessing/augmentation front end. It lives inside the model, so raw 32x32 uint8-range
# images can be fed directly.
augment = Sequential(
    [
        # random augments aren't applied during inferencing, but the resizing and rescaling layers are
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),  # upsample 32x32 inputs to IMAGE_SIZE x IMAGE_SIZE
        layers.RandomFlip(mode="horizontal", seed=SEED),
        # shift by up to 10% of height/width; exposed pixels are filled with 0 (black, since
        # this runs before rescaling)
        layers.RandomTranslation(
            height_factor=0.1,
            width_factor=0.1,
            fill_mode="constant",
            fill_value=0,
            seed=SEED,
        ),
        # presumably chosen to match ResNet50V2's expected [-1, 1] input range — confirm against
        # keras.applications.resnet_v2.preprocess_input
        layers.Rescaling(scale=2.0 / 255, offset=-1.0),  # from 0,255 to -1,1
    ],
    name="augment",
)

# keep separate from "model" so you can set trainable later
# also try other models like Xception, InceptionV3, EfficientNetV2S, etc
# include_top=False drops the ImageNet classifier; pooling="avg" makes the base output a
# globally average-pooled feature vector per image
base = applications.ResNet50V2(
    pooling="avg",
    include_top=False,
    weights="imagenet",
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),  # channels-last
)
base.trainable = False  # freeze all pretrained weights for the first training run

# use original CIFAR size as input
x_inputs = Input(shape=(32, 32, 3), name="input")
x = augment(x_inputs)
# training=False keeps the base (e.g. its batch norm layers) in inference mode, even after
# base.trainable is later set back to True for fine-tuning
x = base(x, training=False)  # set training to False so layers like batch norm are disabled
x = layers.Dense(CLASSES, name="output")(x)  # new head; no activation, so it emits logits

model = Model(inputs=x_inputs, outputs=x, name=f"CIFAR-{CLASSES}")
# labels are one-hot (see prepare_cifar_image), hence CategoricalCrossentropy
model.compile(
    metrics=["accuracy"],
    loss=losses.CategoricalCrossentropy(from_logits=True),  # no softmax so output is raw logits
    optimizer=optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY),
)

In [None]:
# ReduceLROnPlateau lowers the learning rate once val_loss (the Keras default monitor,
# spelled out here) has not improved by at least min_delta for `patience` epochs.
lr_callback = callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    min_delta=0.001,
    patience=3,
    verbose=VERBOSE,
)
# Keep only the best model by *validation* accuracy: monitoring training accuracy would
# keep whichever epoch fit the (augmented) training data best, not the one that
# generalizes best. The filename has no "{epoch}" token, so each new best overwrites the
# previous file; it is named after the configured dataset (CIFAR-10 or CIFAR-100).
ckpt_callback = callbacks.ModelCheckpoint(
    f"resnet-cifar{CLASSES}.model.keras",
    verbose=VERBOSE,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
)

# First run: train the new head on top of the frozen base
history = model.fit(
    X_train,
    epochs=EPOCHS,
    verbose=VERBOSE,
    validation_data=X_val,
    callbacks=[lr_callback, ckpt_callback],
)

In [None]:
# Fine-tuning setup: make the whole base trainable, then recompile (Keras only picks up
# `trainable` changes on compile) with a learning rate 100x smaller than the first run.
base.trainable = True
fine_tune_optimizer = optimizers.AdamW(
    learning_rate=LEARNING_RATE / 100,
    # NOTE(review): Keras AdamW appears to scale its decoupled decay by the learning rate,
    # so also dividing weight_decay shrinks the effective decay ~100x further — confirm intended.
    weight_decay=WEIGHT_DECAY / 100,
)
model.compile(
    optimizer=fine_tune_optimizer,
    loss=losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [None]:
# Keep the best fine-tuned model by *validation* accuracy (training accuracy would not
# reflect generalization), under a filename that matches the configured dataset.
ckpt_ft_callback = callbacks.ModelCheckpoint(
    f"resnet-cifar{CLASSES}-ft.model.keras",
    verbose=VERBOSE,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
)

# Fine-tuning run: one additional epoch over the whole (unfrozen) network
history = model.fit(
    X_train,
    epochs=1,
    verbose=VERBOSE,
    validation_data=X_val,
    callbacks=[ckpt_ft_callback],
)

In [None]:
# Score the final (fine-tuned) model on the held-out test split; the cell displays the
# returned loss and metric values.
model.evaluate(x=X_test, verbose=VERBOSE)

In [None]:
model.summary(expand_nested=False)  # top-level layers only; nested models stay collapsed