In [10]:
from tensorflow.keras import layers 

import matplotlib.pyplot as plt 
import tensorflow_addons as tfa
import numpy as np 
import tensorflow as tf 

# Sentinel handed to `num_parallel_calls` and `prefetch` in the input pipeline
# so that tf.data picks those values itself.
AUTOTUNE = tf.data.AUTOTUNE

In [12]:
# Optimizer settings.
learning_rate = 1e-2
weight_decay = 1e-4

# Training-loop settings.
num_epochs = 10
batch_size = 128

# Side length, in pixels, of the square input images.
img_size = 32


In [8]:
# Fraction of the training arrays held out for validation.
val_split = 0.1
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Despite its name this is a single integer: the number of leading samples
# reserved for validation.
val_indicies = int(len(x_train) * val_split)

# The first `val_indicies` samples become the validation set; the remainder
# is the training subset actually fed to the model.
x_val, new_x_train = x_train[:val_indicies], x_train[val_indicies:]
y_val, new_y_train = y_train[:val_indicies], y_train[val_indicies:]

In [9]:
# NOTE: `x_train` / `y_train` are the full arrays returned by `load_data()`;
# the post-split training subset lives in `new_x_train` / `new_y_train`.
x_train.shape, y_train.shape, x_val.shape, y_val.shape

((50000, 32, 32, 3), (50000, 1), (5000, 32, 32, 3), (5000, 1))

In [13]:
# Augmentation pipeline applied to training batches only: a random
# `img_size` x `img_size` crop followed by a random horizontal flip.
_augmentation_layers = [
    layers.RandomCrop(img_size, img_size),
    layers.RandomFlip('horizontal'),
]
data_aug = tf.keras.Sequential(layers=_augmentation_layers, name='data_aug')

def make_dataset(image, label, is_train=False):
    """Wrap `(image, label)` in a batched, prefetched `tf.data.Dataset`.

    Args:
        image: Images to slice along the first axis.
        label: Labels sliced in step with `image`.
        is_train: When true, samples are shuffled (buffer of
            `batch_size * 10`) before batching and every batch is passed
            through `data_aug`.

    Returns:
        A dataset yielding `(image_batch, label_batch)` pairs of at most
        `batch_size` samples, prefetched with `AUTOTUNE`.
    """
    samples = tf.data.Dataset.from_tensor_slices((image, label))

    # Evaluation pipelines: batch and prefetch only.
    if not is_train:
        return samples.batch(batch_size).prefetch(AUTOTUNE)

    def _augment(x, y):
        # Augmentation runs on whole batches; labels pass through untouched.
        return data_aug(x), y

    return (
        samples
        .shuffle(batch_size * 10)
        .batch(batch_size)
        .map(_augment, num_parallel_calls=AUTOTUNE)
        .prefetch(AUTOTUNE)
    )


# Only the training pipeline is shuffled and augmented; the validation and
# test pipelines are just batched and prefetched.
train_dataset = make_dataset(image=new_x_train, label=new_y_train, is_train=True)
val_dataset = make_dataset(image=x_val, label=y_val)
test_dataset = make_dataset(image=x_test, label=y_test)

In [15]:
def activation_block(x):
    """Apply a GELU activation followed by batch normalization to `x`."""
    activated = layers.Activation('gelu')(x)
    normalized = layers.BatchNormalization()(activated)
    return normalized


def conv_stem(x, filters: int, patch_size: int):
    """Embed `x` with a strided convolution, then apply `activation_block`.

    The convolution uses `patch_size` for both kernel size and stride and
    produces `filters` output channels.
    """
    patch_embedding = layers.Conv2D(
        filters,
        kernel_size=patch_size,
        strides=patch_size,
    )
    return activation_block(patch_embedding(x))

def conv_mixer_block(x, filters: int, kernel_size: int):
    """Apply one ConvMixer block to `x`.

    The block is a depthwise convolution wrapped in a residual connection,
    followed by a 1x1 convolution to `filters` channels; each convolution is
    followed by `activation_block`.
    """
    shortcut = x

    # Depthwise convolution ('same' padding) with a residual add.
    depthwise_out = layers.DepthwiseConv2D(kernel_size=kernel_size, padding='same')(x)
    residual_add = layers.Add()
    mixed = residual_add([activation_block(depthwise_out), shortcut])

    # Pointwise (1x1) convolution.
    pointwise_out = layers.Conv2D(filters, kernel_size=1)(mixed)
    return activation_block(pointwise_out)


def get_conv_mixer_256_8(image_size=32, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10):
    """Build a ConvMixer image classifier as a Keras functional model.

    Inputs are rescaled by 1/255, passed through `conv_stem`, then through
    `depth` calls to `conv_mixer_block`, global average pooling, and a
    softmax `Dense` head.

    Args:
        image_size: Side length of the square 3-channel input; the model
            input shape is `(image_size, image_size, 3)`.
        filters: Number of channels used by the stem and every mixer block.
        depth: Number of mixer blocks stacked after the stem.
        kernel_size: Kernel size of the depthwise convolution in each block.
        patch_size: Kernel size and stride of the stem convolution.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled `tf.keras.Model` mapping images to class probabilities.
    """
    # Use the `image_size` argument (not the notebook-level `img_size`) so
    # that the parameter actually controls the input shape.
    inputs = tf.keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    x = conv_stem(x, filters, patch_size)

    for _ in range(depth):
        x = conv_mixer_block(x, filters, kernel_size)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)
    

In [16]:
def run_experiment(model):
    """Compile, train, and evaluate `model`, keeping its best weights.

    Reads `learning_rate`, `weight_decay`, `num_epochs`, `train_dataset`,
    `val_dataset` and `test_dataset` from the notebook scope. During training
    the weights with the best `val_accuracy` are written to
    `/tmp/checkpoint`; they are reloaded before the test evaluation, whose
    accuracy is printed.

    Args:
        model: The Keras model to train; it is compiled and modified in place.

    Returns:
        A `(hist, model)` tuple: the object returned by `model.fit` and the
        same model instance with the checkpointed weights loaded.
    """
    ckpt_path = '/tmp/checkpoint'
    # Weights only, and only when validation accuracy improves.
    best_weights_saver = tf.keras.callbacks.ModelCheckpoint(
        ckpt_path,
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=True,
    )

    model.compile(
        optimizer=tfa.optimizers.AdamW(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        ),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    hist = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        callbacks=[best_weights_saver],
    )

    # Evaluate with the best checkpoint rather than the last-epoch weights.
    model.load_weights(ckpt_path)
    _, test_accuracy = model.evaluate(test_dataset)

    print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")

    return hist, model


In [17]:
# Build the model with its default hyper-parameters, then train and evaluate
# it. `run_experiment` returns the fit history together with the same model
# object it was given.
conv_mixer_model = get_conv_mixer_256_8() 
history, conv_mixer_model = run_experiment(conv_mixer_model)

Epoch 1/10

KeyboardInterrupt: 