In [1]:
import jax.numpy as jnp
import flax.linen as nn
import optax

### Model

In [2]:
# Flax convolutions are channel-last: inputs are (batch, rows, columns, channels).
class UNET_JAX(nn.Module):
    """U-Net (Ronneberger et al., 2015) written with Flax linen.

    All 3x3 convolutions use 'VALID' padding as in the original paper, so each
    one trims 2 pixels from every spatial dimension. The skip connections are
    therefore center-cropped to the upsampled size before concatenation.
    With a 572x572 input the output is 388x388.

    Attributes:
        num_classes: number of output channels of the final 1x1 convolution.
    """
    num_classes: int = 2

    @staticmethod
    def contracting_block(input, num_features):
        """Two unpadded 3x3 conv + ReLU layers (spatial size shrinks by 4).

        Must be called from inside a compact method so the convs get a parent.
        """
        input = nn.Conv(features=num_features, kernel_size=(3, 3), padding='VALID')(input)
        input = nn.relu(input)
        input = nn.Conv(features=num_features, kernel_size=(3, 3), padding='VALID')(input)
        input = nn.relu(input)
        return input

    @staticmethod
    def expanding_block(input, residual_feature_map, num_features):
        """Upsample 2x, concatenate the cropped skip map, then two conv + ReLU.

        Args:
            input: feature map from the previous (deeper) stage, NHWC.
            residual_feature_map: matching contracting-path output, NHWC.
            num_features: channels of the up-convolution and both 3x3 convs.
        """
        # Upsample first: the skip map must be cropped to the *upsampled* size.
        input = nn.ConvTranspose(features=num_features, kernel_size=(2, 2), strides=(2, 2))(input)
        cropped_feature_map = UNET_JAX.center_crop_array(residual_feature_map, input.shape[1])
        # Skip connections are joined on the channel axis (last axis in NHWC).
        input = jnp.concatenate((input, cropped_feature_map), axis=-1)
        input = nn.Conv(features=num_features, kernel_size=(3, 3), padding='VALID')(input)
        input = nn.relu(input)
        input = nn.Conv(features=num_features, kernel_size=(3, 3), padding='VALID')(input)
        input = nn.relu(input)
        return input

    @staticmethod
    def final_block(input, num_classes=2):
        """1x1 convolution mapping features to `num_classes` output channels."""
        return nn.Conv(features=num_classes, kernel_size=(1, 1))(input)

    @staticmethod
    def max_pool_block(input):
        """2x2 max pooling with stride 2 (halves rows and columns)."""
        # Flax's max_pool defaults to stride 1, which would not downsample.
        return nn.max_pool(input, window_shape=(2, 2), strides=(2, 2))

    @staticmethod
    def center_crop_array(array, new_size):
        """Center-crop axes 1 and 2 (rows, columns) of `array` to `new_size` each.

        Uses explicit end indices: slicing with `-offset` would return an empty
        array when the offset is 0 and be off by one for odd size differences.
        """
        row_offset = (array.shape[1] - new_size) // 2
        col_offset = (array.shape[2] - new_size) // 2
        return array[:, row_offset:row_offset + new_size, col_offset:col_offset + new_size]

    @nn.compact
    def __call__(self, input):
        """Run the network on an NHWC batch; returns per-pixel class logits."""
        contracting_out1 = self.contracting_block(input, 64)
        max_pool_out = self.max_pool_block(contracting_out1)
        contracting_out2 = self.contracting_block(max_pool_out, 128)
        max_pool_out = self.max_pool_block(contracting_out2)
        contracting_out3 = self.contracting_block(max_pool_out, 256)
        max_pool_out = self.max_pool_block(contracting_out3)
        contracting_out4 = self.contracting_block(max_pool_out, 512)
        max_pool_out = self.max_pool_block(contracting_out4)
        contracting_out5 = self.contracting_block(max_pool_out, 1024)
        output = self.expanding_block(contracting_out5, contracting_out4, 512)
        output = self.expanding_block(output, contracting_out3, 256)
        output = self.expanding_block(output, contracting_out2, 128)
        output = self.expanding_block(output, contracting_out1, 64)
        output = self.final_block(output, self.num_classes)
        return output

    def init_unet(self, rng, input_image_size=572, in_channels=1):
        """Initialise this network's variables from a dummy batch.

        Args:
            rng: JAX PRNG key.
            input_image_size: rows/columns of the dummy square input.
            in_channels: channels of the dummy input.

        Returns:
            The variables dict produced by `Module.init`.
        """
        # jnp.ones takes a shape tuple; layout is NHWC with a batch of 1.
        input_size_dummy = jnp.ones((1, input_image_size, input_image_size, in_channels))
        return self.init(rng, input_size_dummy)

### Training


In [4]:
def loss_function(logits, labels):
    """Softmax cross-entropy between `logits` and `labels`, computed by optax.

    Per the optax documentation, `labels` must hold class probabilities (for
    example one-hot vectors) with the same shape as `logits`. The result is
    NOT reduced: there is one loss value per position along all but the last
    (class) axis. Reduce it (e.g. `.mean()`) before using `jax.grad`.
    """
    unreduced_loss = optax.softmax_cross_entropy(logits=logits, labels=labels)
    return unreduced_loss

TypeError: softmax_cross_entropy() missing 2 required positional arguments: 'logits' and 'labels'

In [None]:
def create_training_state(rng, optimizer):
    """Initialise a U-Net and bundle its parameters with `optimizer`.

    Args:
        rng: JAX PRNG key used for parameter initialisation.
        optimizer: an optax GradientTransformation (e.g. ``optax.adam(1e-3)``).

    Returns:
        A `flax.training.train_state.TrainState` holding the model's
        `apply_fn`, its initial params and the optimizer state.
    """
    # Local import keeps this cell self-contained; flax is already a dependency.
    from flax.training import train_state

    unet = UNET_JAX()
    # Dummy NHWC batch: one 572x572 single-channel image (the paper's input size).
    # NOTE(review): confirm this matches the real training image shape.
    dummy_images = jnp.ones((1, 572, 572, 1))
    unet_init_params = unet.init(rng, dummy_images)["params"]
    return train_state.TrainState.create(
        apply_fn=unet.apply, params=unet_init_params, tx=optimizer
    )

In [None]:
def apply_model(state, images, labels):