In [1]:
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import optax

### Model

In [2]:
# Input dims: [N: batch size, H: height, W: width, C: channels]
class UNET_JAX(nn.Module):
    """U-Net style encoder/decoder producing one logit per pixel.

    The encoder runs five double-convolution stages (64, 128, 256, 512 and
    1024 features) with a 2x2 max pool between consecutive stages. The decoder
    mirrors it with four stages (512, 256, 128, 64 features), each made of a
    stride-2 transposed convolution, a channel-wise concatenation with the
    matching encoder output, and another double convolution. A final 1x1
    convolution maps to a single output feature; no sigmoid is applied, so the
    outputs are raw logits.

    The 3x3 convolutions use the default padding of ``nn.Conv``; in this
    notebook a 512x512 input yields a 512x512 output, so the skip connections
    can be concatenated without cropping. If the convolutions are switched to
    ``padding='VALID'``, each skip feature map has to be centre-cropped with
    ``center_crop_array`` to the spatial size of the upsampled tensor before
    concatenation.
    """

    # Side length of the square dummy image used by ``init_unet``.
    input_image_size: int

    @staticmethod
    def contracting_block(input, num_features):
        """Apply two 3x3 convolutions, each followed by a ReLU."""
        x = input
        for _ in range(2):
            x = nn.Conv(features=num_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
        return x

    @staticmethod
    def expanding_block(input, residual_feature_map, num_features):
        """Upsample ``input``, merge it with a skip connection and convolve.

        Args:
            input: feature map coming from the previous (deeper) stage.
            residual_feature_map: encoder output concatenated along the
                channel axis; it must match the upsampled tensor on every
                other axis.
            num_features: number of output features of every layer in the block.
        """
        x = nn.ConvTranspose(features=num_features, kernel_size=(2, 2), strides=(2, 2))(input)
        x = jnp.concatenate((x, residual_feature_map), axis=3)
        for _ in range(2):
            x = nn.Conv(features=num_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
        return x

    @staticmethod
    def final_block(input):
        """Map the features to a single output channel with a 1x1 convolution.

        Returns raw logits; apply ``nn.sigmoid`` to the result if
        probabilities are needed.
        """
        return nn.Conv(features=1, kernel_size=(1, 1))(input)

    @staticmethod
    def max_pool_block(input):
        """Downsample with a 2x2 max pool of stride 2."""
        return nn.max_pool(input, window_shape=(2, 2), strides=(2, 2))

    @staticmethod
    def center_crop_array(array, new_size):
        """Centre-crop axes 1 and 2 of a 4-D array to ``new_size`` x ``new_size``.

        Args:
            array: array indexed as ``[batch, height, width, channels]``.
            new_size: target length of both spatial axes.

        Returns:
            The cropped array; axes 0 and 3 are left untouched.

        Raises:
            ValueError: if ``new_size`` is larger than either spatial axis.
        """
        height, width = array.shape[1], array.shape[2]
        if new_size > height or new_size > width:
            raise ValueError(
                f"cannot crop spatial shape ({height}, {width}) to size {new_size}"
            )
        top = (height - new_size) // 2
        left = (width - new_size) // 2
        # Explicit end indices are required: an ``offset:-offset`` slice is
        # empty when the offset is 0 and one element too long when the size
        # difference is odd.
        return array[:, top:top + new_size, left:left + new_size, :]

    @nn.compact
    def __call__(self, input):
        """Run the full network on ``input`` and return per-pixel logits."""
        # Submodules are created in the same order as a fully unrolled
        # version would create them, so auto-generated parameter names
        # are unaffected by the loops.
        skip_connections = []
        x = input
        for num_features in (64, 128, 256, 512):
            x = self.contracting_block(x, num_features)
            skip_connections.append(x)
            x = self.max_pool_block(x)

        # Deepest stage: no pooling afterwards.
        x = self.contracting_block(x, 1024)

        # Walk back up, pairing each decoder stage with the encoder output of
        # the same resolution (deepest skip first).
        for num_features, skip in zip((512, 256, 128, 64), reversed(skip_connections)):
            x = self.expanding_block(x, skip, num_features)

        return self.final_block(x)

    def init_unet(self, rng):
        """Initialise the model variables from a dummy all-ones input.

        Args:
            rng: PRNG key forwarded to ``self.init``.

        Returns:
            Whatever ``self.init`` returns for an input of shape
            ``[1, input_image_size, input_image_size, 1]``.
        """
        input_size_dummy = jnp.ones([1, self.input_image_size, self.input_image_size, 1])
        params = self.init(rng, input_size_dummy)
        return params

In [3]:
# Smoke test: build the network, initialise it and push one all-ones image
# through it.
key = random.PRNGKey(0)

# input_image_size is 512 to match the shape of the dataset images.
unet = UNET_JAX(input_image_size=512)
unet_params = unet.init_unet(key)
# To inspect the parameter shapes:
#   jax.tree_util.tree_map(lambda leaf: leaf.shape, unet_params)

# Single-image, single-channel batch with the same side length as the model.
dummy_in = jnp.ones((1, unet.input_image_size, unet.input_image_size, 1))
dummy_out = unet.apply(unet_params, dummy_in)



In [4]:
# Display the shape of the dummy forward-pass output.
dummy_out.shape

(1, 512, 512, 1)

### Data Loading

In [9]:
import glob

# glob.glob returns paths in arbitrary, filesystem-dependent order. Both
# listings are sorted so that zipping them later pairs the i-th image with the
# i-th label deterministically.
# NOTE(review): this assumes image and label files share the same file names
# in their respective directories — confirm against the dataset layout.
masks = sorted(glob.glob("../data/isbi2015/train/label/*.png"))
orgs = sorted(glob.glob("../data/isbi2015/train/image/*.png"))

In [15]:
from PIL import Image

# zip() stops at the shorter sequence, so a missing file would silently drop
# samples; an empty listing would otherwise only fail later with an obscure
# error. Fail loudly in both cases.
if len(orgs) == 0:
    raise FileNotFoundError("no training images found; check the dataset path")
if len(orgs) != len(masks):
    raise ValueError(
        f"found {len(orgs)} images but {len(masks)} masks; they must be paired one-to-one"
    )

imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    # Context managers release each file handle as soon as its pixels have
    # been copied into an array.
    with Image.open(image) as image_file:
        imgs_list.append(jnp.array(image_file.resize((512, 512))))
    # NOTE(review): resize() uses its default resampling filter for the masks
    # as well; if the label files are not already 512x512, consider
    # nearest-neighbour resampling so the labels stay binary.
    with Image.open(mask) as mask_file:
        masks_list.append(jnp.array(mask_file.resize((512, 512))))

# Stack the per-file arrays along a new leading (sample) axis.
imgs_np = jnp.stack(imgs_list)
masks_np = jnp.stack(masks_list)

In [16]:
# Print the shapes of the stacked image and mask arrays.
print(imgs_np.shape, masks_np.shape)

(30, 512, 512) (30, 512, 512)


### Data preparation

#### Input Range

In [12]:
# Print the largest raw value in the images and in the masks.
print(imgs_np.max(), masks_np.max())

255 255


In [17]:
# Cast to float32 and divide by 255 (the maximum printed for both arrays) so
# the values end up in [0, 1].
# NOTE(review): `masks` is rebound here from the list of mask file paths to
# the scaled mask array; re-running the loading cell afterwards will not work.
images = imgs_np.astype(jnp.float32) / 255
masks = masks_np.astype(jnp.float32) / 255

In [18]:
# Print the largest value in the scaled images and masks.
print(images.max(), masks.max())

1.0 1.0


#### Input Shape

In [21]:
# Append a trailing channel axis of length 1 to both arrays. Only the first
# three entries of each shape are reused, so re-running this cell leaves
# already-expanded arrays unchanged instead of adding another axis.
images = images.reshape(images.shape[:3] + (1,))
# Each array is reshaped from its own shape; no other variable is involved.
masks = masks.reshape(masks.shape[:3] + (1,))
print(images.shape)
print(masks.shape)

(30, 512, 512, 1)
(30, 512, 512, 1)


### Training


In [7]:
# expects unnormalized log probabilities as logits
def compute_loss_function(logits, labels):
    """Return the mean sigmoid binary cross-entropy of ``logits`` vs ``labels``.

    Args:
        logits: raw model outputs (not passed through a sigmoid).
        labels: targets passed to optax together with ``logits``.

    Returns:
        The element-wise optax loss averaged over all elements.
    """
    elementwise_loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=labels)
    return elementwise_loss.mean()

In [6]:
def compute_accuracy(logits, labels=None):
    """Turn logits into hard 0/1 predictions and optionally score them.

    Args:
        logits: raw (pre-sigmoid) model outputs.
        labels: optional targets that broadcast against ``logits``. They are
            rounded before the comparison so that values close to 0 or 1
            count as the corresponding class.

    Returns:
        If ``labels`` is ``None``: an array shaped like ``logits`` holding
        ``sigmoid(logits)`` rounded to 0.0 or 1.0 (the previous behaviour of
        this function).
        Otherwise: a scalar with the fraction of elements whose prediction
        equals the rounded label.
    """
    # jax.nn.sigmoid keeps this helper dependent on JAX alone.
    predictions = jnp.round(jax.nn.sigmoid(logits))
    if labels is None:
        return predictions
    return jnp.mean(predictions == jnp.round(labels))

In [8]:
def compute_metrics(logits, labels):
    """Compute the loss and the element-wise accuracy for a batch.

    Args:
        logits: raw (pre-sigmoid) model outputs.
        labels: targets with the same shape as ``logits``.

    Returns:
        A dict with keys ``'loss'`` (value of ``compute_loss_function``) and
        ``'accuracy'`` (fraction of elements whose thresholded prediction
        equals the rounded label).
    """
    loss = compute_loss_function(logits, labels)
    # compute_accuracy called with logits only yields the hard 0/1
    # predictions; the comparison against the labels is done here so the
    # 'accuracy' entry is a scalar rather than a prediction map.
    predictions = compute_accuracy(logits)
    accuracy = jnp.mean(predictions == jnp.round(labels))
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics

In [None]:
def create_training_state(rng, optimizer, input_image_size=512):
    """Build the initial Flax training state for the U-Net.

    Args:
        rng: PRNG key used to initialise the model variables.
        optimizer: gradient transformation passed to ``TrainState.create`` as
            ``tx`` (presumably an optax optimizer — confirm at the call site).
        input_image_size: side length of the square dummy input used for
            initialisation. Defaults to 512, the size used for the model
            elsewhere in this notebook.

    Returns:
        A ``flax.training.train_state.TrainState`` whose ``params`` field
        holds the ``'params'`` collection of the initialised variables. Call
        the model with ``state.apply_fn({'params': state.params}, images)``.
    """
    # Imported locally so this cell does not depend on an import made elsewhere.
    from flax.training import train_state

    # init_unet is an instance method: calling it on the class would bind
    # `rng` as `self`, so a model instance is created first.
    unet = UNET_JAX(input_image_size=input_image_size)
    variables = unet.init_unet(rng)
    return train_state.TrainState.create(
        apply_fn=unet.apply,
        params=variables['params'],
        tx=optimizer,
    )

In [None]:
def apply_model(state, images, labels):