In [None]:
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import optax

### Model

In [None]:
class UnetJAX(nn.Module):
    """U-Net for single-channel binary segmentation, written with Flax Linen.

    Layout: four contracting stages (64, 128, 256, 512 features), each followed
    by 2x2 max pooling, a 1024-feature bottleneck, and four expanding stages
    that upsample by 2 and concatenate the matching contracting feature map.

    All convolutions use Flax's default padding, so the skip connections are
    concatenated without cropping. A variant with unpadded ('VALID')
    convolutions would additionally have to crop the skip connections with
    `center_crop_array`.

    Inputs are channel-last, either a single image ``[H, W, C]`` or a batch
    ``[N, H, W, C]``. H and W should be divisible by 16 so that the upsampled
    maps line up with the skip connections after four 2x2 poolings.

    Attributes:
        input_image_size: Side length of the square, single-channel dummy image
            that `init_params` uses to initialise the parameters.
    """
    input_image_size: int

    @staticmethod
    def contracting_block(input, num_features):
        """Applies two 3x3 convolutions, each followed by a ReLU.

        Args:
            input: Channel-last feature map.
            num_features: Number of output channels of both convolutions.

        Returns:
            The transformed feature map with `num_features` channels.
        """
        for _ in range(2):
            input = nn.Conv(features=num_features, kernel_size=(3,3)) (input)
            input = nn.relu(input)
        return input

    @staticmethod
    def expanding_block(input, residual_feature_map, num_features):
        """Upsamples `input` by 2, merges the skip connection and refines it.

        Args:
            input: Channel-last feature map coming from the deeper stage.
            residual_feature_map: Feature map of the matching contracting
                stage; it must have the spatial size of the upsampled `input`.
            num_features: Number of output channels of every layer in the block.

        Returns:
            The refined feature map with `num_features` channels.
        """
        input = nn.ConvTranspose(features=num_features, kernel_size=(2,2), strides=(2,2)) (input)
        # Merge along the channel axis. axis=-1 is the channel axis for single
        # images [H, W, C] and for batches [N, H, W, C] alike, whereas a fixed
        # axis=2 would join batched inputs along their width.
        input = jnp.concatenate((input, residual_feature_map), axis=-1)
        # Same double convolution as in the contracting path; the layers are
        # created in the same order as before, so parameter names are unchanged.
        return UnetJAX.contracting_block(input, num_features)

    @staticmethod
    def final_block(input):
        """Maps the features to one logit per pixel with a 1x1 convolution.

        No sigmoid is applied here; the output is a raw logit map.
        """
        return nn.Conv(features=1, kernel_size=(1,1)) (input)

    @staticmethod
    def max_pool_block(input):
        """Halves the spatial resolution with 2x2 max pooling (stride 2)."""
        return nn.max_pool(input, window_shape=(2,2), strides=(2,2))

    @staticmethod
    def center_crop_array(array, new_size):
        """Crops the central `new_size` x `new_size` window of a channel-last array.

        Args:
            array: Array whose last three axes are ``[H, W, C]``.
            new_size: Side length of the square window to keep.

        Returns:
            The cropped array; all axes other than H and W are left untouched.
        """
        top = (array.shape[-3] - new_size) // 2
        left = (array.shape[-2] - new_size) // 2
        # Explicit end indices: a `[offset:-offset]` slice is empty when the
        # offset is 0, i.e. when no cropping is needed.
        return array[..., top:top + new_size, left:left + new_size, :]

    @nn.compact
    def __call__(self, input):
        """Runs the full U-Net and returns the per-pixel logit map."""
        skip_connections = []
        features = input
        for num_features in (64, 128, 256, 512):
            features = self.contracting_block(features, num_features)
            skip_connections.append(features)
            features = self.max_pool_block(features)
        features = self.contracting_block(features, 1024)
        # Walk back up, pairing each stage with the skip connection of equal resolution.
        for num_features, skip in zip((512, 256, 128, 64), reversed(skip_connections)):
            features = self.expanding_block(features, skip, num_features)
        return self.final_block(features)

    def init_params(self, rng):
        """Initialises the model variables from a dummy all-ones image.

        Args:
            rng: PRNG key used for parameter initialisation.

        Returns:
            The variable collection returned by `self.init`.
        """
        input_size_dummy = jnp.ones([self.input_image_size, self.input_image_size, 1])
        return self.init(rng, input_size_dummy)

In [None]:
# Build the model, initialise its parameters and push one all-ones image
# through it as a shape sanity check.
# input_image_size is 512 due to the dataset shape.
unet = UnetJAX(input_image_size=512)
key = random.PRNGKey(0)
unet_params = unet.init_params(key)
# The dummy input reuses the size the model was configured with.
dummy_in = jnp.ones((unet.input_image_size, unet.input_image_size, 1))
dummy_out = unet.apply(unet_params, dummy_in)

In [None]:
# Show the shape of every parameter array instead of the arrays themselves;
# displaying the raw parameter tree prints the values of each weight array,
# which is long and hard to read.
jax.tree_util.tree_map(lambda leaf: leaf.shape, unet_params["params"])

In [None]:
# Shape of the output produced by the dummy forward pass.
dummy_out.shape

### Data Loading

In [None]:
import glob
# glob returns paths in arbitrary (filesystem-dependent) order. Both lists are
# sorted so that the i-th image is paired with the i-th label when the two
# lists are zipped together.
# NOTE(review): this assumes images and labels use matching file names — confirm.
masks = sorted(glob.glob("../data/isbi2015/train/label/*.png"))
orgs = sorted(glob.glob("../data/isbi2015/train/image/*.png"))

In [None]:
from PIL import Image
imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    # The context managers close each file as soon as its pixels were copied.
    with Image.open(image) as image_file:
        imgs_list.append(jnp.array(image_file.resize((512,512))))
    with Image.open(mask) as mask_file:
        # Nearest-neighbour resampling keeps the label values unchanged; an
        # interpolating filter would invent intermediate grey levels whenever
        # a mask is not already 512x512.
        masks_list.append(jnp.array(mask_file.resize((512,512), resample=Image.NEAREST)))
# Stack the per-file arrays along a new leading axis.
imgs_np = jnp.asarray(imgs_list)
masks_np = jnp.asarray(masks_list)

In [None]:
# Print the shapes of the stacked image and label arrays; they are expected to
# match because every image is loaded together with one label.
print(imgs_np.shape, masks_np.shape)

### Data preparation

#### Input Range

In [None]:
# Print the largest raw value of each array before normalisation.
print(imgs_np.max(), masks_np.max())

In [None]:
# Convert to float32 and divide by 255.
# NOTE(review): this assumes 8-bit data with a maximum of 255 — compare with
# the maxima printed above.
# NOTE(review): `masks` previously held the list of label paths; after this
# cell the loading loop cannot be re-run without re-running the glob cell.
images = imgs_np.astype(jnp.float32) / 255
masks = masks_np.astype(jnp.float32) / 255

In [None]:
# Print the maxima again after dividing by 255.
print(images.max(), masks.max())

#### Input Shape

In [None]:
# Append a trailing channel axis: [N, H, W] -> [N, H, W, 1].
# The ndim guards keep this cell idempotent; without them every re-run would
# append one more axis to the arrays.
if images.ndim == 3:
    images = images[..., None]
if masks.ndim == 3:
    masks = masks[..., None]
print(images.shape)
print(masks.shape)

In [None]:
# Inspect a single value: first image, first row, first column, channel 0.
images[0,0,0,0]

#### Data Split

In [None]:
# Fraction of the samples (taken from the front of the arrays) used for training.
split_factor = 0.8
images_cut = int(images.shape[0] * split_factor)
masks_cut = int(masks.shape[0] * split_factor)
# Everything before the cut index is training data, the remainder is test data.
images_train, images_test = images[:images_cut], images[images_cut:]
masks_train, masks_test = masks[:masks_cut], masks[masks_cut:]
dataset = {
    "train": {"images": images_train, "labels": masks_train},
    "test": {"images": images_test, "labels": masks_test},
}
dataset["train"]["images"].shape

### Training


#### Metrics

In [None]:
def loss_function(logits, labels):
    """Mean sigmoid binary cross-entropy over all elements.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        labels: Targets, broadcast-compatible with `logits`.

    Returns:
        Scalar mean of the element-wise loss.
    """
    elementwise_loss = optax.sigmoid_binary_cross_entropy(logits, labels)
    return jnp.mean(elementwise_loss)

In [None]:
def logits_to_binary(logits):
    """Turns raw logits into hard 0/1 predictions.

    The logits are mapped to probabilities with a sigmoid and then rounded.
    """
    probabilities = nn.sigmoid(logits)
    return jnp.round(probabilities)

def compute_accuracy(logits, labels):
    """Fraction of elements whose predicted class matches the label class.

    Labels are thresholded at 0.5 before the comparison. For labels that are
    exactly 0 or 1 this gives the same result as a direct equality test, but
    label values that are not exactly binary (for example after interpolating
    a mask) are no longer counted as wrong unconditionally.

    Args:
        logits: Raw (pre-sigmoid) model outputs.
        labels: Targets in [0, 1], with the same shape as `logits`.

    Returns:
        Scalar accuracy in [0, 1].
    """
    predicted_positive = logits_to_binary(logits) > 0.5
    labelled_positive = labels > 0.5
    return jnp.mean(predicted_positive == labelled_positive)

In [None]:
def compute_metrics(logits, labels):
    """Collects loss and accuracy for the given logits and labels.

    Returns:
        Dictionary with the keys 'loss' and 'accuracy'.
    """
    return {
        'loss': loss_function(logits, labels),
        'accuracy': compute_accuracy(logits, labels),
    }

#### Train Steps

In [None]:
from flax.training import train_state

class UnetTrainState():
    """Bundles a `UnetJAX` model with its optimizer state and training loop.

    The model is initialised, trained and evaluated on single, unbatched
    images (batch size 1).
    """
    # Flax train state holding the current variables and the optimizer state.
    train_state : train_state.TrainState
    # The model that is trained and evaluated.
    unet : UnetJAX
    # Number of completed passes over the training set.
    current_epoch : int = 0
    # PRNG key; split once per epoch to shuffle the training images.
    rng : jnp.ndarray
    # Variables returned by the model initialisation. They are not updated
    # during training; the live values are in `train_state.params`.
    unet_params = None

    def __init__(self, unet: UnetJAX, optimizer, seed):
        """Creates the PRNG key from `seed` and initialises the train state.

        Args:
            unet: Model to train.
            optimizer: Optax gradient transformation used for the updates.
            seed: Integer seed for parameter initialisation and shuffling.
        """
        self.unet = unet
        self.rng = jax.random.PRNGKey(seed)
        self.create_training_state(optimizer)

    def create_training_state(self, optimizer):
        """Initialises the model variables and wraps them in a `TrainState`."""
        # Use the model owned by this instance instead of a global named
        # `unet`, so the class does not depend on notebook-level state.
        self.unet_params = self.unet.init_params(self.rng)
        self.train_state = train_state.TrainState.create(apply_fn=self.unet.apply, params=self.unet_params, tx=optimizer)

    def print_train_metrics(self, metrics):
        """Prints loss and accuracy of a training step with the current epoch."""
        print('train epoch: %d, loss: %.4f, accuracy: %.4f' % (self.current_epoch, metrics["loss"], metrics["accuracy"]))

    def print_eval_metrics(self, metrics):
        """Prints loss and accuracy of an evaluation with the current epoch."""
        print('model eval: %d, loss: %.4f, accuracy: %.4f' % (self.current_epoch, metrics["loss"], metrics["accuracy"]))

    def train_step(self, batch):
        """Performs one gradient update on a single image/label pair.

        Args:
            batch: Dictionary with the keys 'image' and 'label'.

        Returns:
            The loss computed before the update was applied.
        """
        def compute_loss_function(params):
            # The forward pass must use the `params` argument: gradients are
            # taken with respect to it, so evaluating the model with any other
            # (closed-over) variables would make every gradient zero.
            logits = self.train_state.apply_fn(params, batch['image'])
            loss = loss_function(logits, batch['label'])
            return loss, logits
        compute_loss_grads = jax.value_and_grad(compute_loss_function, has_aux=True)
        (loss, _), grads = compute_loss_grads(self.train_state.params)
        self.train_state = self.train_state.apply_gradients(grads=grads)
        return loss

    def eval_step(self, batch):
        """Computes loss and accuracy of the current parameters on `batch`."""
        logits = self.train_state.apply_fn(self.train_state.params, batch['image'])
        return compute_metrics(logits, batch['label'])

    # batch size is 1 image (paper)
    def train_epoch(self, train_dataset):
        """Trains for one pass over `train_dataset` in shuffled order.

        After every update the metrics of the just-used image are evaluated
        with the updated parameters and printed.

        Args:
            train_dataset: Dictionary with the keys 'images' and 'labels',
                indexed along their first axis.
        """
        self.rng, new_rng = jax.random.split(self.rng)
        shuffled_indexes = jax.random.permutation(new_rng, len(train_dataset["images"]))
        for image_index in shuffled_indexes:
            batch = {"image": train_dataset["images"][image_index], "label": train_dataset["labels"][image_index]}
            self.train_step(batch)
            batch_metrics = jax.device_get(self.eval_step(batch))
            self.print_train_metrics(batch_metrics)
        # One epoch is one full pass over the data, so the counter advances
        # once per call rather than once per image.
        self.current_epoch += 1

    def eval_model(self, test_dataset):
        """Evaluates the current parameters on `test_dataset` and prints the result.

        Images are evaluated one at a time, matching the unbatched inputs the
        model is initialised and trained with. The per-image metrics are then
        averaged.

        Args:
            test_dataset: Dictionary with the keys 'images' and 'labels',
                indexed along their first axis.

        Raises:
            ValueError: If `test_dataset` contains no images.
        """
        images, labels = test_dataset["images"], test_dataset["labels"]
        if len(images) == 0:
            raise ValueError("test_dataset contains no images to evaluate")
        per_image_metrics = [
            jax.device_get(self.eval_step({"image": image, "label": label}))
            for image, label in zip(images, labels)
        ]
        metrics = {
            name: sum(entry[name] for entry in per_image_metrics) / len(per_image_metrics)
            for name in per_image_metrics[0]
        }
        self.print_eval_metrics(metrics)

#### Execute Training

In [None]:
# Training setup: a U-Net for 512x512 images, optimised with SGD plus
# momentum; the seed fixes parameter initialisation and shuffling.
unet = UnetJAX(input_image_size=512)
optimizer = optax.sgd(learning_rate=0.1, momentum=0.99)
unet_train_state = UnetTrainState(unet=unet, optimizer=optimizer, seed=0)

In [None]:
# Run a single pass over the training images in shuffled order; metrics are
# printed after every image.
unet_train_state.train_epoch(train_dataset=dataset["train"])

In [None]:
# Evaluate the current parameters on the test split and print loss and accuracy.
unet_train_state.eval_model(test_dataset=dataset["test"])