## Lung Segmentation using U-Net

In [None]:
%load_ext autoreload
%autoreload

In [None]:
from deeply.model.unet import (
    UNet,
    Trainer
)
from deeply.plots import segplot
from deeply.datasets.util import split as split_dataset

import deeply.datasets as dd

#### Prepare Dataset

In [None]:
names    = ("montgomery",)
# Load each named dataset (file order shuffled) and split its "data"
# partition with `split_dataset`; the result is unpacked below into
# train / validation / test.
datasets = {
    name: split_dataset(dd.load(name, shuffle_files = True)["data"])
    for name in names
}

In [None]:
# The Montgomery split yields exactly three parts: training, validation, test.
[train, val, test] = datasets["montgomery"]

#### Show Samples

In [None]:
n_samples = 3

# Sanity check: draw the ground-truth mask over the first few training images.
for sample in train.take(n_samples):
    segplot(sample["image"], sample["mask"])

In [None]:
# Target spatial size every image and mask is resized to before training.
width = height = 512

#### Build Model

In [None]:
unet = UNet(
    x                = width,
    y                = height,
    n_classes        = 1,          # one foreground class (lung) vs. background
    batch_norm       = False,
    final_activation = "sigmoid",  # per-pixel probability in [0, 1]
)

In [None]:
# Visualise the network. NOTE(review): presumably renders the U-Net layer
# diagram (deeply's UNet.plot) — confirm what backend/output it uses.
unet.plot()

In [None]:
from tensorflow.keras.optimizers import Adam

from deeply.metrics import dice_coefficient

def dice_loss(*args, **kwargs):
    """Keras-compatible loss: the negated Dice coefficient.

    Every positional and keyword argument is forwarded untouched to
    ``dice_coefficient`` (under Keras these are ``y_true, y_pred``). The sign
    is flipped so that minimising the loss maximises mask overlap.
    """
    coefficient = dice_coefficient(*args, **kwargs)
    return -coefficient

In [None]:
# Small learning rate with the negated-Dice loss; pixel accuracy is tracked
# only as an auxiliary metric.
unet.compile(
    optimizer = Adam(learning_rate = 1e-5),
    loss      = dice_loss,
    metrics   = ["binary_accuracy"],
)

In [None]:
# Training hyper-parameters.
batch_size, epochs = 1, 50

#### Preprocess Data

In [None]:
import numpy as np
import tensorflow as tf
import imgaug.augmenters as iaa

from deeply.util.array import squash

In [None]:
augmentor = iaa.Sequential([
    iaa.Resize({ "width": width, "height": height })
])

# Masks are label maps, not photographs: resize them with nearest-neighbour
# interpolation so they keep only their original label values. The image
# pipeline above uses imgaug's default interpolation (cubic), which would
# blend mask edges into intermediate values and can overshoot.
mask_augmentor = iaa.Sequential([
    iaa.Resize({ "width": width, "height": height }, interpolation = "nearest")
])

grayscale = tf.image.rgb_to_grayscale
# Axis permutation applied before resizing and undone afterwards: swaps the
# two leading (spatial) axes and leaves the channel axis in place.
from_, to = (0, 1, 2), (1, 0, 2)

def augment(image, pipeline = None):
    """Resize a single rank-3 image with an imgaug pipeline.

    Called through ``tf.py_function``, so ``image`` arrives as an eager tensor
    and is converted to NumPy first.

    Args:
        image: eager tensor of rank 3 (two spatial axes, then channels).
        pipeline: imgaug augmenter to apply. Defaults to the global
            ``augmentor`` (image resize); pass ``mask_augmentor`` for masks.

    Returns:
        The resized array with the original axis order restored.
    """
    if pipeline is None:
        pipeline = augmentor

    image     = image.numpy()

    # The spatial axes are swapped before resizing and swapped back after, so
    # the result has `width` rows and `height` columns (identical to a plain
    # resize while width == height). NOTE(review): presumably this matches
    # UNet's (x, y) input convention — confirm before using non-square sizes.
    image     = np.moveaxis(image, from_, to)
    augmented = pipeline(images = [image])
    # NOTE(review): `squash` presumably unwraps the one-element batch list —
    # confirm in deeply.util.array.
    augmented = squash(augmented)

    augmented = np.moveaxis(augmented, to, from_)

    return augmented

def augment_mask(mask):
    """Resize a segmentation mask like ``augment``, using nearest-neighbour interpolation."""
    return augment(mask, pipeline = mask_augmentor)

def mapper(ds):
    """Map a dataset element to a ``(feature, mask)`` training pair.

    Both the image and the mask are converted from RGB to single-channel
    grayscale and resized. The mask is resized with nearest-neighbour
    interpolation so its labels are not blended.

    NOTE(review): ``tf.py_function`` drops static shape information; call
    ``set_shape`` on the outputs if a consumer needs known shapes.
    """
    image, mask = ds["image"], ds["mask"]
    image, mask = grayscale(image), grayscale(mask)

    feature = tf.py_function(augment,      [image], image.dtype)
    mask    = tf.py_function(augment_mask, [mask],  mask.dtype)

    return feature, mask

In [None]:
# Train for the configured number of epochs (the `epochs` hyper-parameter)
# rather than a literal, so changing that setting actually changes the
# training length. Lower `epochs` there for a quick smoke run.
trainer = Trainer()
history = trainer.fit(unet, train, val = val, batch_size = batch_size, epochs = epochs, mapper = mapper)

In [None]:
# Run the test split through the training-time preprocessing, then predict
# masks batch by batch.
expected = test.map(mapper)
batched  = expected.batch(batch_size)
predict  = unet.predict(batched)

In [None]:
# Show image, ground-truth mask and predicted mask for the first few test samples.
for (image, mask), prediction in zip(expected.take(n_samples), predict):
    segplot(image, mask, prediction)