In [2]:
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

In [3]:
key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (4, 32, 32, 3))  # batch, height, width, channel (NHWC)



In [4]:
conv = nn.Conv(features=16, kernel_size=(3, 3), padding=1)
variables = conv.init(key, batch)
output = conv.apply(variables, batch)
print(f"{batch.shape} -> {output.shape}")

(4, 32, 32, 3) -> (4, 32, 32, 16)


In [5]:
class SimpleConv(nn.Module):
    @nn.compact
    def __call__(self, x):
        # x is shape (4, 32, 32, 3)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 3) -> (4, 32, 32, 16)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 32)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 32, 32, 32) -> (4, 16, 16, 32)
        return x

In [6]:
model = SimpleConv()

In [7]:
variables = model.init(key, batch)
output = model.apply(variables, batch)
output.shape

(4, 16, 16, 32)

In [32]:
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        
        # Input x = (4, 32, 32, 3)  batch, height, width, channel (NHWC)
        
        # Input layer
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 3) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 1
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 2
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 3
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(3, 3), padding=1)(x)  # (4, 32, 32, 16) -> (4, 32, 32, 16)
        x = nn.relu(x)
        
        # Block 4
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 32, 32, 16) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 5
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 6
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), padding=1)(x)  # (4, 16, 16, 32) -> (4, 16, 16, 32)
        x = nn.relu(x)
        
        # Block 7
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1, strides=2)(x)  # (4, 16, 16, 32) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Block 8
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Block 9
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), padding=1)(x)  # (4, 8, 8, 64) -> (4, 8, 8, 64)
        x = nn.relu(x)
        
        # Pooling 
        x = nn.avg_pool(x, window_shape=(8, 8)) # (4, 8, 8, 64) -> (4, 1, 1, 64)
        print("avg_pool", x.shape)
        x = x.flatten()  # flatten (4, 1, 1, 64) -> (4, 64)
        print("flatten", x.shape)
        
        # Output
        x = nn.Dense(features=10)(x)
        print("dense", x.shape)
        x = nn.log_softmax(x)
        print("softmax", x.shape)
        return x

In [33]:
resnet = CNN()

In [34]:
batch = jax.random.normal(key, (32, 32, 3))  # height, width, channel (NHWC)
variables = resnet.init(key, batch)
output = resnet.apply(variables, batch)
output.shape

avg_pool (1, 1, 64)
flatten (64,)
dense (10,)
softmax (10,)
avg_pool (1, 1, 64)
flatten (64,)
dense (10,)
softmax (10,)


(10,)