## General setup

In [None]:
using InteractiveUtils

# Stand-in for Pluto's `@bind` (presumably so the notebook also runs outside Pluto).
# Evaluates `element`, assigns `Base.get(element)` to the global named by `def` when a
# `Base.get` method applies to it (otherwise `missing`), and returns the element itself.
macro bind(def, element)
    quote
        local el = $(esc(element))
        # Elements without a `Base.get` method leave the bound variable as `missing`.
        global $(esc(def)) = Core.applicable(Base.get, el) ? Base.get(el) : missing
        # Return the element so the cell still displays it.
        el
    end
end

In [None]:
using Pkg

# Work in a fresh temporary environment instead of the currently active one.
Pkg.activate(mktempdir())
Pkg.Registry.update()

# Unpinned dependencies, added one at a time in this order.
for pkgname in (
    "PlutoUI",
    "Tar",
    "MLDataPattern",
    "Glob",
    "NIfTI",
    "CairoMakie",
    "ImageCore",
    "DataLoaders",
    "CUDA",
)
    Pkg.add(pkgname)
end

# FastAI is pinned to an exact release.
Pkg.add(PackageSpec(; name = "FastAI", version = "0.4.0"))
        

In [None]:
using PlutoUI
using Tar
using MLDataPattern
using Glob
using NIfTI
using CairoMakie
using ImageCore
using DataLoaders
using CUDA
using FastAI
using FastAI: DataAugmentation

In [None]:
# PlutoUI widget: table of contents for this notebook's headings.
TableOfContents()

## Load Data

The data comes from Task02_Heart of the [Medical Segmentation Decathlon](http://medicaldecathlon.com/).

In [None]:
# Root of the extracted Task02_Heart dataset; it must contain the `imagesTr/` and
# `labelsTr/` folders read below. Set the TASK02_HEART_DIR environment variable, or
# replace the placeholder path, before running.
data_dir = get(ENV, "TASK02_HEART_DIR", raw"Your_Path_To_Task02_Heart_Dataset")

In [None]:
"""
    loadfn_label(p)

Read the NIfTI label volume at path `p` and return its voxel values shifted up by
one (label `0` becomes `1`), matching the `Mask{3}(1:2)` classes used by the task.
"""
function loadfn_label(p)
    raw_voxels = NIfTI.niread(string(p)).raw
    # `convert` throws an InexactError if a voxel is not representable as UInt8.
    labels = convert(Array{UInt8}, raw_voxels)
    # Broadcasting with the Int literal 1 promotes the element type to Int.
    return labels .+ 1
end

"""
    loadfn_image(p)

Read the NIfTI image volume at path `p`, convert it to `Float32`, and divide it by
its maximum voxel value. An all-zero volume is returned unscaled rather than being
turned into all-`NaN` by a division by zero.
"""
function loadfn_image(p)
    volume = convert(Array{Float32}, NIfTI.niread(string(p)).raw)
    # Reduce with `maximum`: splatting millions of voxels into `max(...)` as varargs
    # is extremely slow to compile and run.
    peak = maximum(volume)
    iszero(peak) && return volume
    return volume ./ peak
end

In [None]:
# Wrap every NIfTI file in `dir` with `mapobs` and the matching loader (presumably
# lazy, so each volume is read only when its observation is accessed).
function images(dir)
    volume_paths = Glob.glob("*.nii*", dir)
    return mapobs(loadfn_image, volume_paths)
end

function masks(dir)
    label_paths = Glob.glob("*.nii*", dir)
    return mapobs(loadfn_label, label_paths)
end

# Images and masks are paired by position (`getobs(data, i)` yields an
# (image, mask) tuple), so both folders must list matching files in the same order.
data = (
    images(joinpath(data_dir, "imagesTr")),
    masks(joinpath(data_dir, "labelsTr")),
)

In [None]:
# Split `data` at fraction 0.8. NOTE(review): `train_files`/`val_files` are not used
# later in this notebook; `taskdataloaders` below receives the unsplit `data`.
train_files, val_files = MLDataPattern.splitobs(data, 0.8)

In [None]:
# First (image, mask) pair for visual inspection; `sample` keeps the whole tuple.
image, mask = sample = getobs(data, 1);

In [None]:
# Slider choosing slice `a` along the third axis of `image` for the heatmaps below.
# NOTE(review): default=50 assumes the volume has at least 50 slices on that axis.
@bind a PlutoUI.Slider(1:size(image, 3), default=50, show_value=true)

In [None]:
# Grayscale view of slice `a` (third axis) of the image.
heatmap(image[:, :, a], colormap=:grays)

In [None]:
# The same slice of the label mask.
heatmap(mask[:, :, a], colormap=:grays)

## Create Learning Task

In [None]:
# Target 3D size passed to `ProjectiveTransforms` in the task definition.
image_size = (96, 96, 96)

In [None]:
# Supervised 3D segmentation task: a 3D image input and a 3D mask whose classes are 1:2.
# Encodings, in order: projective transforms to `image_size`, image preprocessing,
# one-hot encoding.
task = SupervisedTask(
    (FastAI.Vision.Image{3}(), Mask{3}(1:2)),
    (ProjectiveTransforms(image_size), ImagePreprocessing(), OneHot()),
)

In [None]:
# Display a description of `task`.
describetask(task)

## Visualize

In [None]:
# Encode observations 1:3 of `data` with `task` into an input batch `xs` and target batch `ys`.
xs, ys = FastAI.makebatch(task, data, 1:3);

In [None]:
# Slider choosing slice `b` along the third axis of the encoded batch.
@bind b PlutoUI.Slider(1:size(xs, 3), default=50, show_value=true)

In [None]:
# Slice `b`, channel 3, of the second sample; assumes `xs` is laid out as
# (x, y, z, channel, batch) — TODO confirm.
heatmap(xs[:, :, b, 3, 2], colormap=:grays)

In [None]:
# Slice `b` of one-hot channel 2 (class 2) of the second sample's target.
heatmap(ys[:, :, b, 2, 2], colormap=:grays)

## Dataloader

In [None]:
# Training and validation loaders for `task` over `data`; the third argument is
# presumably the batch size (1) — verify against the FastAI docs.
traindl, validdl = taskdataloaders(data, task, 1)

## Model

In [None]:
# Building blocks for the 3D U-Net.
# `conv`: 3×3×3 convolution with padding 1 (keeps size at stride 1, halves even sizes at stride 2).
# `tran`: 4×4×4 transposed convolution with padding 1 (doubles size at stride 2).
conv = (stride, in, out) -> Conv((3, 3, 3), in=>out, stride=stride, pad=(1, 1, 1))
tran = (stride, in, out) -> ConvTranspose((4, 4, 4), in=>out, stride=stride, pad=1)

# Each block is layer → BatchNorm → leaky ReLU.
conv1 = (in, out) -> Chain(conv(1, in, out), BatchNorm(out), x -> leakyrelu.(x))  # same resolution
conv2 = (in, out) -> Chain(conv(2, in, out), BatchNorm(out), x -> leakyrelu.(x))  # downsample ×2
tran2 = (in, out) -> Chain(tran(2, in, out), BatchNorm(out), x -> leakyrelu.(x))  # upsample ×2

In [None]:
"""
    unet3D(in_chs, lbl_chs)

Build a small 3D U-Net mapping `in_chs` input channels to `lbl_chs` output channels.

Four stride-2 convolutions downsample by 16 per spatial axis and four stride-2
transposed convolutions restore the size, so spatial sizes should be divisible by 16
(96 is). Encoder features are *added* to the decoder features of the same resolution.

Each skip level is expressed as `Parallel(+, inner, identity)`, i.e. `inner(x) + x`,
so every encoder stage runs exactly once per forward pass. This avoids recomputing
nested encoder chains and updating their BatchNorm statistics several times per step.
"""
function unet3D(in_chs, lbl_chs)
    # 1/8 resolution, 64 ch: dip to 1/16 (128 ch) and back, then add the level input.
    level4 = Parallel(+,
        Chain(conv1(64, 64), conv2(64, 128), tran2(128, 64), conv1(64, 64)),
        identity,
    )
    # 1/4 resolution, 32 ch, wrapping level4.
    level3 = Parallel(+,
        Chain(conv1(32, 32), conv2(32, 64), level4, tran2(64, 32), conv1(32, 32)),
        identity,
    )
    # 1/2 resolution, 16 ch, wrapping level3.
    level2 = Parallel(+,
        Chain(conv1(16, 16), conv2(16, 32), level3, tran2(32, 16), conv1(16, 16)),
        identity,
    )
    return Chain(
        conv1(in_chs, 4),
        conv1(4, 4), conv2(4, 16),   # → 16 channels at 1/2 resolution
        level2,
        tran2(16, 4), conv1(4, 4),   # back to full resolution
        conv1(4, lbl_chs),           # per-voxel output channels
    )
end

In [None]:
# 3 input channels (presumably what ImagePreprocessing produces here — verify) and
# 2 output channels for the classes of `Mask{3}(1:2)`; moved to the GPU.
model = unet3D(3, 2) |> gpu;

## Helper Functions

In [None]:
"""
    dice_metric(ŷ, y)

Dice coefficient `2|ŷ ∩ y| / (|ŷ| + |y|)` of two boolean masks of equal shape.
Yields `NaN` when both masks are empty.
"""
function dice_metric(ŷ, y)
    overlap = sum(ŷ .& y)
    return 2 * overlap / (sum(ŷ) + sum(y))
end

"""
    as_discrete(array, logit_threshold)

Boolean array that is `true` wherever `array` is at least `logit_threshold`.
"""
as_discrete(array, logit_threshold) = array .>= logit_threshold

## Loss Functions

In [1]:
"""
    dice_loss(ŷ, y)

Soft Dice loss `1 - (2Σŷy + ϵ) / (Σŷ² + Σy² + ϵ)` between predictions `ŷ` and
targets `y`, with smoothing `ϵ = 1e-5`. It is `0` when `ŷ == y`.

`ϵ` is created in the inputs' floating-point type, so `Float32` (e.g. GPU) inputs give
a `Float32` loss instead of being promoted to `Float64` by a `Float64` literal.
"""
function dice_loss(ŷ, y)
    T = float(promote_type(eltype(ŷ), eltype(y)))
    ϵ = T(1e-5)
    intersection = sum(ŷ .* y)
    return 1 - (2 * intersection + ϵ) / (sum(abs2, ŷ) + sum(abs2, y) + ϵ)
end


dice_loss (generic function with 1 method)

## Training

In [None]:
# Trainable parameters of `model`, the loss used by the training loop, and an ADAM
# optimizer with learning rate 0.01.
ps = Flux.params(model);
loss_function = dice_loss
optimizer = Flux.ADAM(0.01)

In [None]:
# Training configuration.
max_epochs = 2
# Run validation every `val_interval` epochs.
val_interval = 1
# Untyped (`Any[]`) containers for per-epoch loss and metric histories.
epoch_loss_values = []
val_epoch_loss_values = []
dice_metric_values = []

In [None]:
# Sanity check: log input and target batch sizes for every validation batch, then
# every training batch. NOTE: this iterates both loaders in full (presumably loading
# every sample), so it can be slow.
for loader in (validdl, traindl), (xs, ys) in loader
    @info size(xs)
    @info size(ys)
end

In [None]:
# Train for `max_epochs` epochs.
# - Record the mean training loss of each epoch in `epoch_loss_values`.
# - Every `val_interval` epochs, record the mean validation loss in
#   `val_epoch_loss_values` and the mean Dice score (channel 2, both sides
#   thresholded at 0.5) in `dice_metric_values`.
for epoch in 1:max_epochs
    step = 0
    epoch_loss = 0.0
    @show epoch

    # Loop through training data
    for (xs, ys) in traindl
        xs, ys = xs |> gpu, ys |> gpu
        step += 1
        @show step
        # Declared outside the closure so the loss computed inside `gradient` is kept.
        local batch_loss
        gs = Flux.gradient(ps) do
            ŷs = model(xs)
            # Only channel 2 of the one-hot encoding contributes to the loss.
            batch_loss = loss_function(ŷs[:, :, :, 2, :], ys[:, :, :, 2, :])
            return batch_loss
        end
        Flux.update!(optimizer, ps, gs)
        epoch_loss += batch_loss
    end
    # `max` guards against an empty loader.
    push!(epoch_loss_values, epoch_loss / max(step, 1))
    @show epoch_loss_values[end]

    # `epoch` is 1-based, so validate on epochs val_interval, 2val_interval, ...
    if epoch % val_interval == 0
        val_step = 0
        val_loss_sum = 0.0
        dice_sum = 0.0
        for (val_xs, val_ys) in validdl
            val_xs, val_ys = val_xs |> gpu, val_ys |> gpu
            val_step += 1
            @show val_step

            local val_ŷs = model(val_xs)
            local val_ŷ_fg = val_ŷs[:, :, :, 2, :]
            local val_y_fg = val_ys[:, :, :, 2, :]
            val_loss_sum += loss_function(val_ŷ_fg, val_y_fg)
            dice_sum += dice_metric(as_discrete(val_ŷ_fg, 0.5), as_discrete(val_y_fg, 0.5))
        end
        n_val = max(val_step, 1)
        push!(val_epoch_loss_values, val_loss_sum / n_val)
        push!(dice_metric_values, dice_sum / n_val)
        @show val_epoch_loss_values[end] dice_metric_values[end]
    end
end