In [10]:
"""
Flax CNN
"""

# Hide INFO and WARNING messages from the native TensorFlow/XLA logger.
# This has to happen before `jax`, `keras` and `tensorflow` are imported:
# anything those libraries log while being imported is emitted before a
# later assignment could take effect.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers

# Load dataset
import keras
import tensorflow as tf
from sklearn.model_selection import train_test_split

print("Jax CNN Notebook...")

Jax CNN Notebook...


In [11]:
def load_data():
    """Load CIFAR-10 and split it into training, validation and test sets.

    Image pixel values are divided by 255. Labels are returned exactly as the
    Keras loader provides them, so they stay usable as class ids. 5000 samples
    are held out of the training set for validation, using a fixed
    ``random_state`` so the split is reproducible. The training images are
    then passed through ``tf.image.resize_with_pad`` (40x40) and cropped back
    to the top-left 32x32 window.

    Prints the shapes of the padded and cropped training images.

    Returns:
        Tuple ``(x_train, y_train, x_val, y_val, x_test, y_test)``.
        ``x_train`` is the result of the TensorFlow image ops; ``y_train``,
        ``x_val`` and ``y_val`` come from ``train_test_split``; ``x_test`` is
        the scaled test images and ``y_test`` the untouched test labels.
    """
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

    # Scale pixel intensities to [0, 1]. Only the images are scaled: the
    # labels are class ids, and dividing them by 255 would turn them into
    # fractions that no longer identify a class.
    x_train = x_train / 255
    x_test = x_test / 255

    # Create validation set
    x_train, x_val, y_train, y_val = train_test_split(
        x_train, y_train, test_size=5000, random_state=4
    )

    # https://www.tensorflow.org/api_docs/python/tf/image/ResizeMethod
    # NOTE(review): this resizes/pads to 40x40 and the crop below always takes
    # the window at offset (0, 0). If this is meant to be pad-and-random-crop
    # augmentation, the offsets should be randomised -- confirm the intent.
    x_train_padded = tf.image.resize_with_pad(
        x_train,
        40,
        40,
        method=tf.image.ResizeMethod.GAUSSIAN,
        antialias=False
    )
    print(f"x_train_padded: {x_train_padded.shape}")

    # cropped images
    x_train_cropped = tf.image.crop_to_bounding_box(
        x_train_padded, 0, 0, 32, 32
    )
    print(f"x_train_cropped: {x_train_cropped.shape}")

    # The cropped images replace the original training images.
    x_train = x_train_cropped

    return x_train, y_train, x_val, y_val, x_test, y_test

In [12]:
# Build all six splits, then report the shape of each one in a fixed order.
x_train, y_train, x_val, y_val, x_test, y_test = load_data()
for split_name, split in (
    ("x_train", x_train),
    ("y_train", y_train),
    ("x_val", x_val),
    ("y_val", y_val),
    ("x_test", x_test),
    ("y_test", y_test),
):
    print(f"{split_name}: {split.shape}")

x_train_padded: (45000, 40, 40, 3)
x_train_cropped: (45000, 32, 32, 3)
x_train: (45000, 32, 32, 3)
y_train: (45000, 1)
x_val: (5000, 32, 32, 3)
y_val: (5000, 1)
x_test: (10000, 32, 32, 3)
y_test: (10000, 1)


In [None]:
class CNN(nn.Module):
    """Small convolutional classifier built with Flax Linen.

    An input convolution and one block of two convolutions, each followed by
    ReLU, feed a dense layer whose output is passed through softmax.

    Attributes:
        num_classes: Width of the final dense layer. Defaults to 10, the
            number of classes in CIFAR-10.
    """

    num_classes: int = 10

    def setup(self):
        """Create the convolution and dense submodules."""
        # Submodules created in setup are named after the attribute they are
        # assigned to, so no explicit `name=` is passed; the attribute names
        # double as the names in the parameter tree.
        self.input = nn.Conv(features=32, kernel_size=(3, 3), padding="SAME")

        self.block1_conv1 = nn.Conv(features=16, kernel_size=(3, 3), padding="SAME")
        self.block1_conv2 = nn.Conv(features=16, kernel_size=(3, 3), padding="SAME")

        # Classification head applied to the flattened feature map.
        self.linear1 = nn.Dense(features=self.num_classes)

    def __call__(self, inputs):
        """Run the network on a batch.

        Args:
            inputs: Batch of inputs; the leading axis is kept as the batch
                axis when the feature map is flattened.

        Returns:
            Softmax of the dense layer's output.
        """
        # input
        x = nn.relu(self.input(inputs))

        # block 1
        x = nn.relu(self.block1_conv1(x))
        x = nn.relu(self.block1_conv2(x))

        # block 2
        # TODO: not implemented yet.

        # flatten everything except the leading (batch) axis
        x = x.reshape((x.shape[0], -1))
        x = self.linear1(x)

        # NOTE(review): this returns probabilities, not logits. If the loss
        # used for training expects logits, return `x` instead -- confirm
        # against the training code.
        return nn.softmax(x)