In [None]:
%load_ext autoreload
%autoreload

In [None]:
!pip install -q chitra==0.0.20

## Imports

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

from chitra.trainer import Trainer, create_cnn

import matplotlib.pyplot as plt

## Helper functions

In [None]:
# data visualization

def visualize(data, nrow=2, ncol=3, class_names=None):
    """Show the first ``nrow * ncol`` (image, label) pairs of ``data`` in a grid.

    Args:
        data: dataset yielding ``(image, label)`` tensor pairs (e.g. a tfds split
            loaded with ``as_supervised=True``); it must support ``.take(n)`` and
            its elements must support ``.numpy()``.
        nrow: number of grid rows.
        ncol: number of grid columns.
        class_names: optional sequence indexed by integer label, e.g.
            ``ds_info.features['label'].names``. When omitted, the raw label value
            is used as each subplot's title.

    If ``data`` has fewer than ``nrow * ncol`` elements, the remaining grid cells
    are left blank instead of raising ``IndexError``.
    """
    # squeeze=False keeps `axs` 2-D even when nrow or ncol is 1; otherwise
    # plt.subplots returns a 1-D array (or a single Axes) and 2-D indexing breaks.
    fig, axs = plt.subplots(nrow, ncol, squeeze=False)
    samples = [(img.numpy(), label.numpy()) for img, label in data.take(nrow * ncol)]

    for ax, (img, label) in zip(axs.flat, samples):
        ax.imshow(img)
        title = class_names[int(label)] if class_names is not None else str(label)
        ax.set_title(title)
    # Tick marks carry no information for image thumbnails; this also blanks out
    # any unused cells when the dataset was shorter than the grid.
    for ax in axs.flat:
        ax.axis('off')
    fig.tight_layout()

## Define constants

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data tune parallelism / prefetch depth at runtime
BS = 16  # batch size for both train and test pipelines
SEED = 42

# Seed TensorFlow's global RNG so random ops (augmentation flips/noise, shuffling,
# weight initialization) repeat across runs. GPU kernels may still be nondeterministic.
tf.random.set_seed(SEED)

## Load data with tfds

In [None]:
# Load the CIFAR-10 train and test splits as (image, label) tuples, together with
# the dataset metadata (`ds_info`).
(ds_train, ds_test), ds_info = tfds.load(
    name='cifar10',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)

In [None]:
# Preview a 2 x 3 grid of raw training images titled with their integer labels.
visualize(ds_train, nrow=2, ncol=3)

## Preprocess
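Image augmentation (random horizontal flips plus additive Gaussian noise) is applied to the training images to generate more varied data points from the existing data. This helps the model generalize and has a regularization effect. Both training and test images are also rescaled from [0, 255] to [-1, 1].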

In [None]:
@tf.function
def rescale(image, label):
    """Cast ``image`` to float32 and map pixel values from [0, 255] onto [-1, 1].

    The label passes through unchanged; the ``(image, label)`` signature matches
    ``Dataset.map`` over a dataset loaded with ``as_supervised=True``.
    """
    pixels = tf.cast(image, dtype=tf.float32)
    return pixels / 127.5 - 1.0, label

@tf.function
def augment(image, label, noise_stddev=0.1):
    """Randomly flip an image horizontally and add Gaussian noise to every pixel.

    Args:
        image: image tensor; assumed (height, width, channels) — TODO confirm for
            datasets other than CIFAR-10.
        label: passed through unchanged.
        noise_stddev: standard deviation of the additive noise, in the same units
            as ``image``. NOTE(review): in this notebook ``augment`` runs before
            ``rescale``, i.e. on raw [0, 255] pixels, so 0.1 is a very small
            perturbation; raise it or apply it after rescaling if noise should matter.

    Returns:
        ``(augmented_image, label)`` with the image as float32.
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.random_flip_left_right(image)
    # Draw noise matching the input's own shape rather than a hard-coded (32, 32, 3),
    # so the function works for any image size.
    noise = tf.random.normal(tf.shape(image), mean=0.0, stddev=noise_stddev)
    return image + noise, label

In [None]:
# Cache the decoded raw images *before* random augmentation so every epoch gets fresh
# flips/noise; caching after `augment` would replay the first epoch's augmented images
# forever. Shuffling after the cache, with a buffer the size of the train split, changes
# batch composition each epoch (`shuffle_files=True` only shuffles the file order).
train_dl = (
    ds_train
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .map(rescale, num_parallel_calls=AUTOTUNE)
    .batch(BS)
    .prefetch(AUTOTUNE)
)
# Test preprocessing is deterministic, so it can be computed once and cached.
test_dl = (
    ds_test
    .map(rescale, num_parallel_calls=AUTOTUNE)
    .batch(BS)
    .cache()
    .prefetch(AUTOTUNE)
)

## Build model

In [None]:
# Take the class count from the dataset metadata instead of hard-coding 10.
# weights=None presumably means random initialization (no pretrained weights) — confirm
# against chitra's create_cnn.
# NOTE(review): the compile step uses from_logits=True, which assumes this model's head
# emits raw logits (no softmax) — verify in chitra's create_cnn.
model = create_cnn(
    'resnet50',
    num_classes=ds_info.features['label'].num_classes,
    drop_out=0.3,
    weights=None,
)

In [None]:
# Integer class labels -> sparse categorical cross-entropy.
# NOTE(review): from_logits=True is only correct if the model's final layer has no
# softmax activation — confirm for create_cnn.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

In [None]:
# Keep the returned History (per-epoch loss/accuracy) for later plotting, instead of
# discarding it and printing its bare object repr as the cell output.
history = model.fit(train_dl, epochs=15, validation_data=test_dl)

In [None]:
# Continue training the same model for 15 more epochs. With initial_epoch=15 and
# epochs=30, Keras runs epochs 16-30, so logs and epoch-indexed callbacks see the true
# position in training instead of restarting the counter at 1.
history_continued = model.fit(
    train_dl,
    initial_epoch=15,
    epochs=30,
    validation_data=test_dl,
)