In [1]:
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from flax.training import train_state
import optax

### Model

In [2]:
class UNET_JAX(nn.Module):
    """U-Net for semantic segmentation operating on NHWC arrays.

    Inputs have shape ``(N, H, W, C)``: N is the batch size, H/W the spatial
    size and C the channel count. Every 3x3 convolution uses ``'VALID'``
    padding, so the spatial size shrinks through the network and each skip
    connection is centre-cropped before it is concatenated in the decoder
    (e.g. a 572x572 input yields a 388x388 output).

    Attributes:
        input_image_size: Height/width of the square dummy input that
            ``init_unet`` uses to create parameters.
        num_classes: Number of per-pixel logits produced by the final 1x1 conv.
        input_channels: Channel count of the dummy input used by ``init_unet``.
    """
    input_image_size: int
    num_classes: int = 2
    input_channels: int = 1

    @staticmethod
    def _double_conv(x, num_features):
        # Two (3x3 'VALID' conv -> ReLU) layers; each conv trims 2 pixels per spatial dim.
        for _ in range(2):
            x = nn.Conv(features=num_features, kernel_size=(3, 3), padding='VALID')(x)
            x = nn.relu(x)
        return x

    @staticmethod
    def contracting_block(input, num_features):
        """Encoder stage: two 3x3 'VALID' conv + ReLU layers with `num_features` filters."""
        return UNET_JAX._double_conv(input, num_features)

    @staticmethod
    def expanding_block(input, residual_feature_map, num_features):
        """Decoder stage: upsample 2x, merge the cropped skip connection, double conv.

        Args:
            input: Feature map from the previous (coarser) stage, NHWC.
            residual_feature_map: Encoder feature map at the matching level. It is
                spatially larger than the upsampled map (because of 'VALID'
                padding) and is centre-cropped to the upsampled size.
            num_features: Filters of the transposed conv and of both 3x3 convs.

        Returns:
            The decoded NHWC feature map with `num_features` channels.
        """
        upsampled = nn.ConvTranspose(features=num_features, kernel_size=(2, 2), strides=(2, 2))(input)
        # Crop height and width independently so non-square maps are handled too.
        skip = UNET_JAX.center_crop_array(residual_feature_map, upsampled.shape[1:3])
        merged = jnp.concatenate((upsampled, skip), axis=-1)
        return UNET_JAX._double_conv(merged, num_features)

    @staticmethod
    def final_block(input, num_classes=2):
        """1x1 convolution mapping features to `num_classes` per-pixel logits."""
        return nn.Conv(features=num_classes, kernel_size=(1, 1))(input)

    @staticmethod
    def max_pool_block(input):
        """2x2 max pooling with stride 2 (halves H and W, flooring odd sizes)."""
        return nn.max_pool(input, window_shape=(2, 2), strides=(2, 2))

    @staticmethod
    def center_crop_array(array, new_size):
        """Centre-crop the spatial axes (1 and 2) of an NHWC array.

        Args:
            array: Array of shape (N, H, W, C).
            new_size: Target spatial size; an int for a square crop or an
                (height, width) pair.

        Returns:
            The centred ``(N, new_h, new_w, C)`` window of `array`.

        Raises:
            ValueError: If the target size exceeds the array's spatial size.
        """
        if isinstance(new_size, (tuple, list)):
            new_h, new_w = (int(s) for s in new_size)
        else:
            new_h = new_w = int(new_size)
        height, width = array.shape[1], array.shape[2]
        if new_h > height or new_w > width:
            raise ValueError(
                f"cannot crop {height}x{width} feature map to {new_h}x{new_w}"
            )
        # Explicit start/stop bounds: `off:-off` slicing returns an empty array
        # when off == 0 and one pixel too many when the size difference is odd.
        top = (height - new_h) // 2
        left = (width - new_w) // 2
        return array[:, top:top + new_h, left:left + new_w, :]

    @nn.compact
    def __call__(self, input):
        """Run the U-Net on an NHWC batch and return per-pixel logits.

        Returns:
            Array of shape (N, H_out, W_out, num_classes); H_out/W_out are
            smaller than the input size because of 'VALID' padding.
        """
        # Encoder: remember each stage's output for the skip connections.
        skips = []
        x = input
        for num_features in (64, 128, 256, 512):
            x = self.contracting_block(x, num_features)
            skips.append(x)
            x = self.max_pool_block(x)
        x = self.contracting_block(x, 1024)  # bottleneck
        # Decoder: mirror the encoder, consuming skips from deepest to shallowest.
        for num_features, skip in zip((512, 256, 128, 64), reversed(skips)):
            x = self.expanding_block(x, skip, num_features)
        return self.final_block(x, self.num_classes)

    def init_unet(self, rng):
        """Initialise the model's variables by tracing it on a dummy all-ones batch.

        Args:
            rng: JAX PRNG key used for parameter initialisation.

        Returns:
            The variable collections returned by ``Module.init``
            (e.g. ``{'params': ...}``).
        """
        dummy_input = jnp.ones(
            (1, self.input_image_size, self.input_image_size, self.input_channels)
        )
        return self.init(rng, dummy_input)

In [3]:
IMAGE_SIZE = 572  # canonical U-Net input size; 'VALID' convs shrink it to 388

key = random.PRNGKey(0)
unet = UNET_JAX(input_image_size=IMAGE_SIZE)
unet_params = unet.init_unet(key)

# eval_shape traces the forward pass abstractly (no FLOPs) so the output size
# is checked explicitly instead of relying on a print inside the model.
output_struct = jax.eval_shape(unet.apply, unet_params, jnp.ones((1, IMAGE_SIZE, IMAGE_SIZE, 1)))
assert output_struct.shape == (1, 388, 388, 2), output_struct.shape

# Parameter shapes per layer (jax.tree_map is deprecated in favour of jax.tree_util.tree_map).
jax.tree_util.tree_map(lambda x: x.shape, unet_params)



(1, 388, 388, 2)


FrozenDict({
    params: {
        ConvTranspose_0: {
            bias: (512,),
            kernel: (2, 2, 1024, 512),
        },
        ConvTranspose_1: {
            bias: (256,),
            kernel: (2, 2, 512, 256),
        },
        ConvTranspose_2: {
            bias: (128,),
            kernel: (2, 2, 256, 128),
        },
        ConvTranspose_3: {
            bias: (64,),
            kernel: (2, 2, 128, 64),
        },
        Conv_0: {
            bias: (64,),
            kernel: (3, 3, 1, 64),
        },
        Conv_1: {
            bias: (64,),
            kernel: (3, 3, 64, 64),
        },
        Conv_10: {
            bias: (512,),
            kernel: (3, 3, 1024, 512),
        },
        Conv_11: {
            bias: (512,),
            kernel: (3, 3, 512, 512),
        },
        Conv_12: {
            bias: (256,),
            kernel: (3, 3, 512, 256),
        },
        Conv_13: {
            bias: (256,),
            kernel: (3, 3, 256, 256),
        },
        Co

### Training


In [None]:
def loss_function(logits, labels):
    """Softmax cross-entropy between predicted `logits` and target `labels`.

    `labels` must be probability distributions over the last (class) axis,
    broadcastable to the shape of `logits` -- e.g. one-hot segmentation masks.

    The result is NOT reduced: optax only sums over the class axis, so one
    value is returned per remaining position (per pixel, presumably, for the
    NHWC U-Net logits). Average it before differentiating, since `jax.grad`
    requires a scalar.
    """
    unreduced_loss = optax.softmax_cross_entropy(labels=labels, logits=logits)
    return unreduced_loss

In [None]:
def create_training_state(rng, optimizer, input_image_size=572):
    """Build a Flax `TrainState` holding freshly initialised U-Net parameters.

    Args:
        rng: JAX PRNG key used for parameter initialisation.
        optimizer: An optax `GradientTransformation` (e.g. ``optax.adam(1e-4)``).
        input_image_size: Side length of the square dummy input used to trace
            the model during initialisation.

    Returns:
        A ``train_state.TrainState`` whose ``apply_fn`` is the model's
        ``apply``, whose ``params`` are the model's ``'params'`` collection, and
        whose optimizer state is initialised from `optimizer`.
    """
    # init_unet is an instance method, so it must be called on a configured
    # model; calling it on the class would bind `rng` as `self`.
    model = UNET_JAX(input_image_size=input_image_size)
    variables = model.init_unet(rng)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables["params"], tx=optimizer
    )

In [None]:
def apply_model(state, images, labels):