In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from keras.layers import Input, Conv2D, LeakyReLU, GlobalAveragePooling2D, Dense, Dropout, Softmax
from keras.models import Model
from tqdm import tqdm
from time import time

In [2]:
#parameters
# Hyperparameters shared by the data pipeline, the model and the training loop.
depth = 64        # filter count of the first convolutions; later layers use multiples of it
batch_size = 128  # number of examples grouped into one batch
epochs = 12       # number of epochs used when sizing the training schedule

In [3]:
# Fetch the CIFAR-10 training split (downloading it if necessary) together
# with the dataset metadata object.
train_ds, info = tfds.load(
    "cifar10",
    split="train",
    shuffle_files=True,
    download=True,
    with_info=True,
)

# Image geometry and number of label classes, read from the metadata.
num_classes = info.features["label"].num_classes
xSize, ySize, rgbSize = info.features["image"].shape

# Schedule bookkeeping. The floor division counts only full batches, so a
# trailing partial batch is not included in total_batches / total_steps.
total_images = info.splits["train"].num_examples
total_batches = total_images // batch_size
total_steps = epochs * total_batches

In [4]:
# Three-value mean / standard deviation constants with shape (1, 1, 3), so
# they broadcast against the last axis of an image tensor.
# NOTE(review): presumably the CIFAR-10 training-set channel statistics on a
# [0, 1] scale - confirm the source of these numbers.
image_mean = tf.constant([[[0.49139968, 0.48215841, 0.44653091]]])
image_std = tf.constant([[[0.24703223, 0.24348513, 0.26158784]]])


def normalize(item):
    """
    Convert one dataset record into a standardized (image, label) pair.

    Args:
        item: mapping with an 'image' entry and a 'label' entry.

    Returns:
        A tuple ``(image, label)``: the image cast to float32, divided by
        255, then shifted by ``image_mean`` and scaled by ``image_std``;
        the label is passed through unchanged (it is not one-hot encoded
        here).
    """
    scaled = tf.cast(item['image'], tf.float32) / 255.0
    standardized = (scaled - image_mean) / image_std
    return standardized, item['label']

# Training input pipeline: shuffle -> normalize -> batch -> prefetch.
train_ds = (
    train_ds
    .shuffle(total_images)  # shuffle buffer sized to the number of training examples
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [5]:
#model
kernel = 3  # kernel_size passed to every Conv2D layer below
input_shape = (xSize, ySize, rgbSize)  # shape of a single input image

#architecture
def encoder_network(input_shape, activation, name="E"):
    """
    Build the convolutional classifier as a Keras functional ``Model``.

    The network is a stack of seven 'same'-padded convolutions (three of
    them with stride 2) using LeakyReLU(alpha=0.1), followed by global
    average pooling, two LeakyReLU dense layers (256 and 128 units) and a
    final dense layer with ``num_classes`` units. No dropout is applied.

    Args:
        input_shape: shape of one input example, forwarded to ``Input``.
        activation: activation used by the final dense layer.
        name: model name; the input layer is named ``name + "input"``.

    Returns:
        A ``Model`` mapping the input tensor to the final dense output.
    """
    def leaky():
        # Each layer receives its own activation instance.
        return LeakyReLU(alpha=0.1)

    # (filters, stride) of each convolution, in the order they are applied.
    conv_plan = (
        (depth, 1),
        (depth, 2),
        (depth*2, 1),
        (depth*2, 2),
        (depth*4, 1),
        (depth*4, 2),
        (depth*8, 1),
    )

    image_in = Input(input_shape, name=name+"input")
    features = image_in
    for filters, stride in conv_plan:
        features = Conv2D(
            filters,
            kernel_size=kernel,
            padding='same',
            activation=leaky(),
            strides=stride,
        )(features)

    # Collapse the spatial dimensions, then apply the fully connected head.
    hidden = GlobalAveragePooling2D()(features)
    for units in (256, 128):
        hidden = Dense(
            units,
            activation=leaky(),
            kernel_initializer=tf.keras.initializers.glorot_normal(),
        )(hidden)

    latent = Dense(
        num_classes,
        kernel_initializer=tf.keras.initializers.glorot_normal(),
        activation=activation,
    )(hidden)

    return Model(inputs=image_in, outputs=latent, name=name)

#construct the network
# The final layer applies Softmax, so the model outputs class probabilities.
classifier = encoder_network(input_shape, Softmax(), name="ConvNet")

#loss
# Sparse variant because the labels are not one-hot encoded;
# from_logits=False because the model output has already been through Softmax.
cce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

#optimizer
classifier_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

classifier.summary(line_length=133)

Model: "ConvNet"
_____________________________________________________________________________________________________________________________________
 Layer (type)                                              Output Shape                                          Param #             
 ConvNetinput (InputLayer)                                 [(None, 32, 32, 3)]                                   0                   
                                                                                                                                     
 conv2d (Conv2D)                                           (None, 32, 32, 64)                                    1792                
                                                                                                                                     
 conv2d_1 (Conv2D)                                         (None, 16, 16, 64)                                    36928               
                                             

In [22]:
#training pipeline
@tf.function #compiles function, much faster
def train_step_classifier(images, labels):
    """
    Run one optimisation step of the classifier on a single batch.

    The forward pass is recorded on a gradient tape, the loss is
    differentiated with respect to the classifier's trainable variables and
    ``classifier_opt`` applies the resulting gradients.

    Args:
        images: batch of inputs passed to ``classifier``.
        labels: targets passed to ``cce_loss`` together with the predictions.

    Returns:
        The loss tensor computed by ``cce_loss`` for this batch, evaluated
        before the weight update is applied.
    """
    with tf.GradientTape() as classifier_tape:
        pred_class = classifier(images, training=True)
        loss = cce_loss(labels, pred_class)

    gradients_of_classifier = classifier_tape.gradient(loss, classifier.trainable_variables)
    classifier_opt.apply_gradients(zip(gradients_of_classifier, classifier.trainable_variables))

    # The training loop accumulates this value and calls .numpy() on it;
    # without an explicit return it would receive None and fail.
    return loss

@tf.function
def generate_and_classify(model, test_input, test_labels):
    """
    Score a batch with ``model`` and report its top-1 and top-5 accuracy.

    The model is called with ``training=False`` so that every layer runs in
    inference mode. The sparse top-k metric is used because the labels are
    class indices rather than one-hot vectors, matching the sparse loss.

    Args:
        model: callable model producing per-class scores for ``test_input``.
        test_input: batch of inputs.
        test_labels: labels for the batch.

    Returns:
        ``(top1, top5)``: the batch means of the top-1 and top-5 accuracy.
    """
    scores = model(test_input, training=False)

    def mean_top_k(k):
        # Per-example hit indicator for the k best classes, averaged over the batch.
        hits = tf.keras.metrics.sparse_top_k_categorical_accuracy(test_labels, scores, k=k)
        return tf.math.reduce_mean(hits)

    return mean_top_k(1), mean_top_k(5)

In [None]:
#train
def train(epochs):
    """
    Train ``classifier`` for the given number of passes over ``train_ds``.

    For every batch the function runs ``train_step_classifier`` and then
    re-scores the same batch with ``generate_and_classify`` (i.e. after the
    weight update). Per-batch loss, top-1 and top-5 accuracy are averaged
    over the epoch and printed.

    This relies on ``train_step_classifier`` returning the batch loss as a
    tensor: the value is accumulated and ``.numpy()`` is called on it for the
    progress bar.

    Args:
        epochs: number of passes over ``train_ds``.

    Returns:
        A list with the mean loss of each epoch, in order.

    Raises:
        ValueError: if ``train_ds`` yields no batches, since epoch averages
            would be undefined.
    """
    losses = []
    for epoch in range(1, epochs+1):
        print('>>>>>>>>>>>>>. Epoch{}'.format(epoch))
        #training
        batch_losses = 0
        batch_top1 = 0
        batch_top5 = 0
        count = 0
        with tqdm(train_ds, unit="batch") as tepoch:
            for image_batch, labels_batch in tepoch:
                loss = train_step_classifier(image_batch, labels_batch)
                batch_losses += loss
                top1, top5 = generate_and_classify(classifier, image_batch, labels_batch)
                batch_top1 += top1
                batch_top5 += top5
                count += 1
                tepoch.set_postfix(loss=loss.numpy())

        # Fail with a clear message instead of an opaque ZeroDivisionError.
        if count == 0:
            raise ValueError("train_ds yielded no batches; cannot compute epoch averages")

        #compute mean losses and accuracies
        loss = batch_losses/count
        top1 = batch_top1/count
        top5 = batch_top5/count

        print(f'Loss {loss} (top1 {top1}, top5 {top5})')
        losses.append(loss)
    return losses

In [None]:
#train model


In [None]:
# The CIFAR-10 'test' split is used as the validation set. File order is
# kept fixed and no metadata is requested here.
valid_ds = tfds.load(
    'cifar10',
    split='test',
    shuffle_files=False,
    download=True,
    with_info=False,
)

# The split size is read from the `info` object, which must already exist
# from an earlier tfds.load(..., with_info=True) call.
total_validation_images = info.splits['test'].num_examples
print(f"found {total_validation_images} validation images")

#testing set
# Apply the same normalization as the training pipeline.
valid_ds = valid_ds.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
