In [2]:
"""
Flax CNN
"""

import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers

# Load dataset
import keras
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Suppress warning and info messages from jax
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

print("Jax CNN Notebook...")

2023-01-23 16:37:20.613628: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-23 16:37:20.675209: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-01-23 16:37:21.614953: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-23 16:37:21.615044: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


Jax CNN Notebook...


In [7]:
def load_data():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # print(f"x_train: {x_train.shape}")
    # print(f"y_train: {y_train.shape}")
    # print(f"x_test: {x_test.shape}")
    # print(f"y_test: {y_test.shape}")
    
    # Normalize data to range [0, 1]
    x_train = x_train / 255
    y_train = y_train / 255
    x_test = x_test / 255
    y_test = y_test / 255
    
    # Create validation set
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=5000, random_state=4)
    # print(f"x_train: {x_train.shape}")
    # print(f"y_train: {y_train.shape}")
    # print(f"x_val: {x_val.shape}")
    # print(f"y_val: {y_val.shape}")

    # https://www.tensorflow.org/api_docs/python/tf/image/ResizeMethod
    # Have padding be zeros
    x_train_padded = tf.image.resize_with_pad(
        x_train,
        40,
        40,
        method=tf.image.ResizeMethod.GAUSSIAN,
        antialias=False
    )
    print(f"x_train_padded: {x_train_padded.shape}")
    
    # cropped images
    x_train_cropped = tf.image.crop_to_bounding_box(
        x_train_padded, 0, 0, 32, 32
    )
    print(f"x_train_cropped: {x_train_cropped.shape}")
    
    # re-assign
    x_train = x_train_cropped
    
    return x_train, y_train, x_val, y_val, x_test, y_test

In [8]:
x_train, y_train, x_val, y_val, x_test, y_test = load_data()
print(f"x_train: {x_train.shape}")
print(f"y_train: {y_train.shape}")
print(f"x_val: {x_val.shape}")
print(f"y_val: {y_val.shape}")
print(f"x_test: {x_test.shape}")
print(f"y_test: {y_test.shape}")

2023-01-23 16:37:59.994321: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-23 16:37:59.994390: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory
2023-01-23 16:37:59.994432: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory
2023-01-23 16:37:59.994479: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcufft.so.10'; dlerror: libcufft.so.10: cannot open shared object file: No such file or directory
2023-01-23 16:37:59.994523: W tensorflow/stream_executor/platform/default/dso_loader.cc:64

x_train_padded: (45000, 40, 40, 3)
x_train_cropped: (45000, 32, 32, 3)
x_train: (45000, 32, 32, 3)
y_train: (45000, 1)
x_val: (5000, 32, 32, 3)
y_val: (5000, 1)
x_test: (10000, 32, 32, 3)
y_test: (10000, 1)


2023-01-23 16:38:03.034500: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 552960000 exceeds 10% of free system memory.


In [9]:
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        
        # Input layer
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 3) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 1
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 2
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 3
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 4
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 32, 32, 16) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 5
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 6
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 7
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 16, 16, 32) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Block 8
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Block 9
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Pooling 
        x = nn.avg_pool(x, window_shape=(8, 8)) # (4, 8, 8, 64) -> (4, 1, 1, 64)
        x = x.flatten()  # flatten (4, 1, 1, 64) -> (4, 64)
        
        # Output 
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

In [10]:
resnetv2 = CNN()

In [23]:
# Test final output shape

# Generate random data 
key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (32, 32, 3))  # height, width, channel (NHWC)

# Test model
variables = resnetv2.init(key, batch)
output = resnetv2.apply(variables, batch)
output.shape

(10,)