# This Julia notebook shows how to load the trained UNet model and apply it to a single image.

In [5]:
using Pkg
Pkg.activate("/home/molloi-lab/Desktop/Project BAC/BAC project/libs/")
using Lux, Random, NNlib, Zygote, LuxCUDA, CUDA, FluxMPI, JLD2, DICOM
using Images
using MLUtils
using MPI
using Optimisers
using ImageMorphology, ChainRulesCore, Statistics, CSV, DataFrames, Dates

[32m[1m  Activating[22m[39m project at `~/Desktop/Project BAC/BAC project/libs`
└ @ FluxMPI /home/molloi-lab/.julia/packages/FluxMPI/BwbGS/src/FluxMPI.jl:28


In [6]:
# Make scalar indexing of GPU arrays an error instead of a silent, very slow
# element-by-element transfer.
CUDA.allowscalar(false)

# Initialise MPI through FluxMPI; processes are presumably assigned to CUDA
# devices 0-3 (see FluxMPI docs for the exact mapping).
FluxMPI.Init(;gpu_devices = [0,1,2,3])

# MPI rank (similar to a thread ID) of the current process.
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
# Lux device adaptors: `x |> dev` moves data to the GPU, `x |> dev_cpu` back to the host.
dev = gpu_device()
dev_cpu = cpu_device()

# Layer builders used by the UNet definition below.
# 3×3 convolution with padding 1 (stride 1, so spatial size is preserved).
_conv = (cin, cout) -> Conv((3, 3), cin => cout, pad = 1)
# 3×3 convolution followed by BatchNorm with a leaky-ReLU activation.
conv1 = (cin, cout) -> Chain(_conv(cin, cout), BatchNorm(cout, leakyrelu))

# 2×2 transposed convolution with stride 2 (doubles height and width).
_tran = (cin, cout) -> ConvTranspose((2, 2), cin => cout, stride = 2)
# Transposed convolution followed by BatchNorm with a leaky-ReLU activation.
tran = (cin, cout) -> Chain(_tran(cin, cout), BatchNorm(cout, leakyrelu))

"""
    UNet{T1, T2, T3, T4, T5, T6, T7}

Lux container layer holding the seven sub-networks of the UNet. Lux tracks the
parameters and states of each field under the same name (`ps.l1`, `st.l1`, …).
"""
struct UNet{T1, T2, T3, T4, T5, T6, T7} <: Lux.AbstractExplicitContainerLayer{
    (:l1, :l2, :l3, :l4, :l5, :l6, :l7)
}
    l1::T1
    l2::T2
    l3::T3
    l4::T4
    l5::T5
    l6::T6
    l7::T7
end

"""
    UNet(in_chs, lbl_chs, size)

Build a UNet with `in_chs` input channels, `lbl_chs` output channels and `size`
feature channels at the first level. The width doubles at each of the three
max-pooling steps. The final layer applies `sigmoid`.
"""
function UNet(in_chs, lbl_chs, size)
    # Channel widths per level (`size` is an integer here, shadowing Base.size).
    c1, c2, c4, c8 = size, 2size, 4size, 8size
    pool() = MaxPool((2, 2), stride = 2)

    # Contracting path
    enc1 = Chain(conv1(in_chs, c1), conv1(c1, c1))
    enc2 = Chain(pool(), conv1(c1, c2), conv1(c2, c2))
    enc3 = Chain(pool(), conv1(c2, c4), conv1(c4, c4))
    # Bottleneck: pool, convolve, then upsample back to enc3's resolution.
    bottleneck = Chain(pool(), conv1(c4, c8), conv1(c8, c8), tran(c8, c4))

    # Expanding path: each input is a skip-connection concatenation (doubled channels).
    dec3 = Chain(conv1(c8, c4), conv1(c4, c4), tran(c4, c2))
    dec2 = Chain(conv1(c4, c2), conv1(c2, c2), tran(c2, c1))
    head = Chain(conv1(c2, c1), conv1(c1, c1), Conv((1, 1), c1 => lbl_chs), sigmoid)

    return UNet(enc1, enc2, enc3, bottleneck, dec3, dec2, head)
end

"""
    (m::UNet)(x, ps, st::NamedTuple)

Lux forward pass. Returns the prediction together with the updated states of all
seven sub-layers. Skip connections are concatenated along dimension 3.
"""
function (m::UNet)(x, ps, st::NamedTuple)
    # Contracting path; each level's output is kept for the skip connections.
    e1, s1 = m.l1(x, ps.l1, st.l1)
    e2, s2 = m.l2(e1, ps.l2, st.l2)
    e3, s3 = m.l3(e2, ps.l3, st.l3)
    # Bottleneck: its output is already upsampled to e3's resolution.
    b4, s4 = m.l4(e3, ps.l4, st.l4)

    # Expanding path: join with the matching encoder output along dim 3.
    d5, s5 = m.l5(cat(b4, e3; dims = 3), ps.l5, st.l5)
    d6, s6 = m.l6(cat(d5, e2; dims = 3), ps.l6, st.l6)
    out, s7 = m.l7(cat(d6, e1; dims = 3), ps.l7, st.l7)

    new_states = (l1 = s1, l2 = s2, l3 = s3, l4 = s4, l5 = s5, l6 = s6, l7 = s7)
    return out, new_states
end

In [7]:
# input data
# Path of the DICOM image the model will be applied to.
input_img = "/media/molloi-lab/2TB1/Clean_Dataset_full/SID-100510/L_CC.3328_2560.dcm"

"/media/molloi-lab/2TB1/Clean_Dataset_full/SID-100510/L_CC.3328_2560.dcm"

In [None]:
# Prepare the image: parse the DICOM file, resize it by pixel spacing, normalize it,
# crop it to the breast bounding box and save the results.
#
# NOTE(review): `ground_truth_mask_path`, `breast_mask_path`, `curr_dir`, `out_dir`,
# `f_name` and the helpers `resize_dicom_image`, `normalize_img` and
# `crop_to_bounding_box` are not defined in this notebook; they must be provided
# (e.g. by including the preprocessing script) before this cell can run.
ground_truth_mask = Float32.(Images.load(ground_truth_mask_path))
breast_mask = Float32.(Images.load(breast_mask_path))
# Parse the DICOM file whose path is stored in `input_img`.
dcm_data = dcm_parse(input_img)
# (2050,0020) Presentation LUT Shape: "INVERSE" means intensities must be inverted.
is_reversed = uppercase(dcm_data[(0x2050, 0x0020)]) == "INVERSE"
# (0018,1164) Imager Pixel Spacing
pixel_size = dcm_data[(0x0018, 0x1164)]
# (7FE0,0010) Pixel Data
img = Float32.(dcm_data[(0x7fe0, 0x0010)])
original_size = size(img)
# resize image based on pixel length
img, breast_mask, ground_truth_mask, new_size = resize_dicom_image(img, breast_mask, ground_truth_mask, pixel_size)
# normalize image and correct color
img = normalize_img(img; mask = breast_mask, invert = is_reversed)
# crop to breast only
img_cropped, ground_truth_mask_cropped, coords = crop_to_bounding_box(breast_mask, img, ground_truth_mask)
# save resize info to local
@save joinpath(curr_dir, f_name*"_resize_info.jld2") original_size new_size coords

# Both cropped dimensions must be multiples of 32. Use an explicit error rather than
# `@assert` (which may be disabled). Use distinct names so the globals `x`/`y` are
# not overwritten with integers.
nx, ny = size(img_cropped)
if nx % 32 != 0 || ny % 32 != 0
    throw(DimensionMismatch("cropped image size ($nx, $ny) is not a multiple of 32 in both dimensions"))
end

#save
@save joinpath(out_dir, f_name*".jld2") img_cropped
Images.save(joinpath(out_dir, f_name*".png"), Gray.(round.(ground_truth_mask_cropped)))


In [9]:
# Load the trained parameters and states, stored under the keys "ps_save" and
# "st_save", then move both to the GPU.
model_path = "/home/molloi-lab/Desktop/wenbo2_flashdrive_backup/saved_train_info_334.jld2"
ps_save, st_save = jldopen(model_path, "r") do file
    file["ps_save"], file["st_save"]
end
ps_save = ps_save |> dev
st_save = st_save |> dev

(l1 = (layer_1 = NamedTuple(), layer_2 = (running_mean = Float32[-1.5754029, 2.2304332, -1.6438724, 2.1291182, 1.5642703, -4.459931, -1.4128605, 2.5394652, 0.08220151, 1.3944055, -2.2787423, 2.0576603, -1.6820405, 1.7552352, 2.6440036, -2.1005633], running_var = Float32[2.0897603, 2.1679688, 0.8702852, 0.7614294, 0.4340817, 9.882541, 1.1594844, 3.4043326, 0.6023198, 1.5117741, 2.6206417, 2.2926908, 1.1610765, 1.3491561, 1.8846543, 1.7653261], training = Val{true}()), layer_3 = NamedTuple(), layer_4 = (running_mean = Float32[-0.26003373, 0.75799036, 1.0921711, -2.363193, -2.6438394, 0.91217846, -11.351362, 1.798894, -0.120637566, -1.7964456, -2.17908, 0.97645134, -0.6907514, -0.715033, 0.84746337, 1.8764355], running_var = Float32[0.86277986, 2.34481, 0.5159975, 3.5334196, 2.3837838, 4.866938, 86.0603, 4.1092596, 0.39124367, 4.560316, 4.7320285, 2.6757293, 1.6857156, 1.596671, 2.1655583, 3.9297612], training = Val{true}())), l2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 

In [11]:
# Rebuild the network with the same arguments used in the training code
# (1 input channel, 1 output channel, 16 first-level feature channels) so that
# the loaded parameters/states match its structure.
model = UNet(1, 1, 16)

UNet(
    l1 = Chain(
        layer_1 = Conv((3, 3), 1 => 16, pad=1),  [90m# 160 parameters[39m
        layer_2 = BatchNorm(16, leakyrelu, affine=true, track_stats=true),  [90m# 32 parameters[39m[90m, plus 33[39m
        layer_3 = Conv((3, 3), 16 => 16, pad=1),  [90m# 2_320 parameters[39m
        layer_4 = BatchNorm(16, leakyrelu, affine=true, track_stats=true),  [90m# 32 parameters[39m[90m, plus 33[39m
    ),
    l2 = Chain(
        layer_1 = MaxPool((2, 2)),
        layer_2 = Conv((3, 3), 16 => 32, pad=1),  [90m# 4_640 parameters[39m
        layer_3 = BatchNorm(32, leakyrelu, affine=true, track_stats=true),  [90m# 64 parameters[39m[90m, plus 65[39m
        layer_4 = Conv((3, 3), 32 => 32, pad=1),  [90m# 9_248 parameters[39m
        layer_5 = BatchNorm(32, leakyrelu, affine=true, track_stats=true),  [90m# 64 parameters[39m[90m, plus 65[39m
    ),
    l3 = Chain(
        layer_1 = MaxPool((2, 2)),
        layer_2 = Conv((3, 3), 32 => 64, pad=1),  [90m# 18_496 p

In [None]:
# Run the trained model on the preprocessed, cropped image.
# The model expects a 4-D array, so add singleton channel and batch dimensions
# and move the input to the GPU.
x_in = reshape(Float32.(img_cropped), size(img_cropped)..., 1, 1) |> dev
# The saved states have BatchNorm in training mode (`training = Val{true}()`);
# `Lux.testmode` switches it to the stored running statistics for inference.
st_test = Lux.testmode(st_save)
ŷ, st_test = Lux.apply(model, x_in, ps_save, st_test)
# Prediction on the host, e.g. for visualisation or thresholding.
ŷ_cpu = ŷ |> dev_cpu

In [None]:


"""
    zoom_pixel_values(img; mask = nothing)

Linearly rescale `img` so the minimum and maximum of the selected pixels map to 0 and 1.

Without `mask` the whole image is used. With `mask`, the range is taken from the
pixels where `mask` equals one, and the rescaled image is multiplied by `mask`.
A selection with zero dynamic range maps to 0 instead of producing NaN.

Throws `ArgumentError` if no pixels are selected.
"""
function zoom_pixel_values(img; mask = nothing)
    working_img = isnothing(mask) ? img : img[findall(isone, mask)]
    isempty(working_img) && throw(ArgumentError(
        "zoom_pixel_values: no pixels to rescale (empty image or mask with no entries equal to one)"))

    lo, hi = extrema(working_img)
    span = hi - lo
    # Divide by one instead of zero for a constant selection so it maps to 0, not NaN.
    img_ = (img .- lo) ./ (iszero(span) ? one(span) : span)

    if !isnothing(mask)
        img_ = img_ .* mask
    end
    return img_
end


"""
    histogram_equalization(img; mask = nothing)

Histogram-equalize `img` using 256 bins.

The image is first rescaled to [0, 1] with `zoom_pixel_values`. With `mask`, only
pixels where `mask` equals one build the histogram and are remapped, and the result
is multiplied by `mask`. Each selected pixel is mapped to
`(cdf(bin) - cdf_min) / (N - cdf_min)`, stored as `Float32`. If every selected pixel
falls into one bin (zero denominator), those pixels map to 0 instead of NaN.
"""
function histogram_equalization(img; mask = nothing)
    nbins = 256
    selection = isnothing(mask) ? nothing : findall(isone, mask)
    img = zoom_pixel_values(img; mask = mask)
    working_img = isnothing(selection) ? img : img[selection]

    # After zooming the selected values lie in [0, 1]; map them onto bins 1:nbins.
    tobin(v) = Int(floor(v * (nbins - 1)) + 1)

    hist = zeros(Int, nbins)
    for v in working_img
        hist[tobin(v)] += 1
    end
    chist = cumsum(hist)

    # chist is non-decreasing, so its first positive entry is its smallest positive one.
    min_chist = chist[findfirst(>(0), chist)]
    # All selected pixels in one bin would give 0/0; clamp the denominator so they map to 0.
    denom = max(length(working_img) - min_chist, 1)

    new_img = copy(img)
    for i in (isnothing(selection) ? eachindex(img) : selection)
        new_img[i] = Float32((chist[tobin(img[i])] - min_chist) / denom)
    end

    if !isnothing(mask)
        new_img = new_img .* mask
    end
    return new_img
end


# function xlogy(x, y)
#     result = x * log(y)
#     return ifelse(iszero(x), zero(result), result)
# end

"""
    weighted_xlogy(x, y, weight)

Return `weight * x * log(y)`, taken to be zero whenever `x` is zero. For example,
`x = 0, y = 0` gives 0 rather than NaN.
"""
weighted_xlogy(x, y, weight) = let term = (weight * x) * log(y)
    ifelse(iszero(x), zero(term), term)
end

"""
    weighted_bce_loss(y, ŷ; w_pos = 25, w_neg = 1, ϵ = 1f-5)

Class-weighted binary cross-entropy, averaged over all elements. The positive term
(`y`) is weighted by `w_pos` and the negative term (`1 - y`) by `w_neg`. `ϵ` is
added inside both logarithms to guard against `log(0)`.
"""
function weighted_bce_loss(y, ŷ; w_pos = 25, w_neg = 1, ϵ = 1f-5)
    per_element = @. -(weighted_xlogy(y, ŷ + ϵ, w_pos) +
                       weighted_xlogy(1f0 - y, 1f0 - ŷ + ϵ, w_neg))
    return mean(per_element)
end


"""
    bce_dice_and_hd_loss(ŷ, y, epoch, step, save_output; HD_kick_in = 501, ϵ = 1f-5)

Training loss: `0.25 * (0.7 * soft_dice + 0.3 * weighted_bce)`.

The soft Dice term uses squared sums in its denominator. A Hausdorff-distance term
was prototyped but is disabled, so `HD_kick_in` and `save_output` are currently
unused. For the first 321 steps of each epoch the component losses are appended to
this rank's training log (outside differentiation).
"""
function bce_dice_and_hd_loss(ŷ,  y, epoch, step, save_output; HD_kick_in = 501, ϵ=1f-5)
    overlap = muladd(2f0, sum(ŷ .* y), ϵ)
    loss_dice = 1f0 - (overlap / (sum(ŷ .^ 2) + sum(y .^ 2) + ϵ))
    bce_loss = weighted_bce_loss(y, ŷ)
    loss_total = loss_dice * 7f-1 + bce_loss * 3f-1

    # Logging is a side effect; keep it out of the gradient computation.
    ignore_derivatives() do
        step <= 321 || return nothing
        log_losses(epoch, step, loss_dice, bce_loss, 0f0, 0f0, loss_total, rank+1)
    end

    return loss_total * 25f-2
end

"""
    bce_dice_and_hd_loss_testmode(ŷ, y, epoch, step; HD_kick_in = 501, ϵ = 1f-5)

Evaluation metric: returns the soft Dice loss. Unlike the training loss, its
denominator uses plain (non-squared) sums. The weighted BCE is also computed, and
both values are appended to this rank's testing log. Everything runs outside
differentiation. `HD_kick_in` is currently unused.
"""
function bce_dice_and_hd_loss_testmode(ŷ,  y, epoch, step; HD_kick_in = 501, ϵ=1f-5)
    return ignore_derivatives() do
        intersection = sum(ŷ .* y)
        loss_dice = 1f0 - (muladd(2f0, intersection, ϵ) / (sum(ŷ) + sum(y) + ϵ))
        bce_loss = weighted_bce_loss(y, ŷ)
        log_losses_testmode(epoch, step, loss_dice, bce_loss, 0f0, rank+1)
        loss_dice
    end
end

"""
    compute_loss(x, y, model, ps, st, epoch, step; save_output = true)

Run `model` on `x` with parameters `ps` and state `st`, then compute the training loss
against `y`. Returns `(loss, prediction moved to the CPU, updated state)`.
"""
function compute_loss(x, y, model, ps, st, epoch, step; save_output = true)
    # The prediction must be computed from `ps` here; otherwise `ŷ` is undefined
    # (or a stale global) and no gradient reaches the parameters.
    ŷ, st = Lux.apply(model, x, ps, st)
    loss = bce_dice_and_hd_loss(ŷ, y, epoch, step, save_output)
    return loss, ŷ |> dev_cpu, st
end

"""
    compute_loss_testmode(x, y, model, ps, st, epoch, step)

Evaluate `model` on `x` in inference mode and return
`(dice loss, prediction moved to the CPU, state used)`.
"""
function compute_loss_testmode(x, y, model, ps, st, epoch, step)
    # Evaluate with BatchNorm in test mode (running statistics) rather than with the
    # statistics of the current evaluation batch.
    ŷ, st = Lux.apply(model, x, ps, Lux.testmode(st))
    dice_loss = bce_dice_and_hd_loss_testmode(ŷ, y, epoch, step)
    return dice_loss, ŷ |> dev_cpu, st
end

"""
    auto_continue_training()

Count the consecutive checkpoints `training/saved_train_info_0.jld2`, `_1`, … that
exist and return the index of the last one. Returns 0 when none exist.
"""
function auto_continue_training()
    existing = Iterators.takewhile(
        idx -> isfile("training/saved_train_info_$idx.jld2"),
        Iterators.countfrom(0),
    )
    n_consecutive = count(_ -> true, existing)
    return max(0, n_consecutive - 1)
end

# struct Scheduler{T, F}<: Optimisers.AbstractRule
#     constructor::F
#     schedule::T
# end

# _get_opt(scheduler::Scheduler, t) = scheduler.constructor(scheduler.schedule(t))

# Optimisers.init(o::Scheduler, x::AbstractArray) =
#     (t = 1, opt = Optimisers.init(_get_opt(o, 1), x))

# function Optimisers.apply!(o::Scheduler, state, x, dx)
#     opt = _get_opt(o, state.t)
#     new_state, new_dx = Optimisers.apply!(opt, state.opt, x, dx)

#     return (t = state.t + 1, opt = new_state), new_dx
# end

# Save a side-by-side PNG for visual inspection:
# equalized input | white separator | ground truth | white separator | rounded prediction.
# Only the [:, :, 1, 1] slice of each array is shown.
function _save_comparison_png(save_path, x_cpu, y_cpu, ŷ)
    # The per-epoch output folder is not created anywhere else in this notebook.
    mkpath(dirname(save_path))
    separator = ones(size(ŷ)[1], 1)
    panel = hcat(
        histogram_equalization(x_cpu[:,:,1,1]), separator,
        y_cpu[:,:,1,1], separator,
        round.(ŷ[:,:,1,1]),
    )
    save(save_path, Gray.(panel))
    return nothing
end

"""
    train(start_epoch_idx, epoch_target, ps, st, opt, st_opt, model, train_loader, test_loader)

Train `model` for epochs `start_epoch_idx:epoch_target`. Each epoch:

- sets an exponentially decaying learning rate;
- runs one pass over `train_loader`, saving a comparison PNG per step;
- saves parameters/states on rank 0;
- evaluates on `test_loader` and prints the mean test Dice loss over samples with a
  non-empty ground truth (NaN if there are none).

Returns `(ps, st, st_opt)` after the last epoch; the original returned nothing.
`opt` is unused (the optimizer lives in `st_opt`).

NOTE(review): relies on the globals `train_paths`, `test_paths`, `data_dir`, `rank`,
`dev` and `dev_cpu`; `data_dir` is not defined in this notebook.
"""
function train(start_epoch_idx, epoch_target, ps, st, opt, st_opt, model, train_loader, test_loader)
    FluxMPI.fluxmpi_println("Start training...")

    for epoch in start_epoch_idx : epoch_target
        # Learning rate decays from ≈1e-2 towards a floor of 1e-4.
        η = round(exp(-(epoch-1) * 1f-2) * 1f-2 + 1f-4; digits = 6)
        Optimisers.adjust!(st_opt, η)

        # ---- training pass ----
        for (step, (x_cpu, y_cpu)) in enumerate(train_loader)
            x, y = x_cpu |> dev, y_cpu |> dev

            (loss, ŷ, st), back = pullback(p -> compute_loss(x, y, model, p, st, epoch, step), ps)
            gs = back((one.(loss), nothing, nothing))[1]
            st_opt, ps = Optimisers.update(st_opt, ps, gs)

            SID = split(train_paths[step], "/")[8]
            save_path = joinpath(data_dir, "outputs", "epoch#$epoch","train_$(rank+1)_$(step)_$(SID).png")
            _save_comparison_png(save_path, x_cpu, y_cpu, ŷ)

            CUDA.reclaim()
        end

        # ---- checkpoint (rank 0 only) ----
        if rank == 0
            ps_save, st_save = ps |> dev_cpu, st |> dev_cpu
            @save joinpath(data_dir, "saved_train_info_$epoch.jld2") ps_save st_save
        end

        # ---- evaluation pass ----
        dice_losses = Any[]
        for (step, (x_cpu, y_cpu)) in enumerate(test_loader)
            x, y = x_cpu |> dev, y_cpu |> dev
            dice_loss, ŷ, _ = compute_loss_testmode(x, y, model, ps, st, epoch, step)

            SID = split(test_paths[step], "/")[8]
            save_path = joinpath(data_dir, "outputs", "epoch#$epoch","test_$(rank+1)_$(step)_$(SID).png")
            _save_comparison_png(save_path, x_cpu, y_cpu, ŷ)

            # Only samples with a non-empty ground truth count towards the mean.
            if sum(y_cpu) > 0
                push!(dice_losses, dice_loss)
            end
        end
        # `mean` of an empty untyped vector throws; report NaN instead.
        mean_dice = isempty(dice_losses) ? NaN : mean(dice_losses)
        bar = epoch % 2 == 0 ? "==================" : "------------------"
        FluxMPI.fluxmpi_println("$bar $epoch: η = $η, test set: $mean_dice $bar")
        CUDA.reclaim()
    end
    return ps, st, st_opt
end

# Append one row to the CSV at `filename`, writing the header only when the file is new.
# Appending avoids re-reading and re-writing the whole log on every call (quadratic
# I/O over a run) and avoids re-parsing earlier rows.
function _append_log_row(filename, row::DataFrame)
    mkpath(dirname(filename))
    return CSV.write(filename, row; append = isfile(filename))
end

"""
    log_losses(epoch, step, loss_dice, bce_loss, loss_hd, hd_weight, loss_total, id)

Append one timestamped row of training losses to `training/training_log_<id>.csv`.
"""
function log_losses(epoch, step, loss_dice, bce_loss, loss_hd, hd_weight, loss_total, id)
    filename="training/training_log_$id.csv"
    new_row = DataFrame(TimeStamp = Dates.format(now(), "yyyy-mm-dd HH:MM:SS"), Epoch = epoch, Step = step, LossDice = loss_dice, BCELoss = bce_loss, LOSS_HD = loss_hd, HD_WEIGHT = hd_weight, LOSS_TOTAL = loss_total)
    return _append_log_row(filename, new_row)
end

"""
    log_losses_testmode(epoch, step, loss_dice, bce_loss, loss_hd, id)

Append one timestamped row of evaluation losses to `training/testing_log_<id>.csv`.
"""
function log_losses_testmode(epoch, step, loss_dice, bce_loss, loss_hd, id)
    filename="training/testing_log_$id.csv"
    new_row = DataFrame(TimeStamp = Dates.format(now(), "yyyy-mm-dd HH:MM:SS"), Epoch = epoch, Step = step, LossDice = loss_dice, BCELoss = bce_loss, LOSS_HD = loss_hd)
    return _append_log_row(filename, new_row)
end

# Start each run with fresh loss logs: delete this rank's training/testing CSVs
# (suffix rank+1) left over from a previous run.
isfile("training/training_log_$(rank+1).csv") && rm("training/training_log_$(rank+1).csv")
isfile("training/testing_log_$(rank+1).csv") && rm("training/testing_log_$(rank+1).csv")

# read data
# Each file stores its object under a generic name (`data_loader` or `paths`), so
# every `@load` rebinds that global; the value is then copied to a specific name.
@load joinpath("JLD2s/train_loader_$(rank+1).jld2") data_loader
train_loader_ = data_loader

@load joinpath("JLD2s/test_loader_$(rank+1).jld2") data_loader
test_loader = data_loader

# Source-file paths for this rank's samples. `train` indexes them by step to name its
# output images, which presumably assumes the same order as the loaders (verify).
@load "JLD2s/train_dl_paths_$(rank+1).jld2" paths
train_paths = paths

@load "JLD2s/test_dl_paths_$(rank+1).jld2" paths
test_paths = paths

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
# Initial Adam learning rate; `train` overrides it every epoch via Optimisers.adjust!.
l_r = 1f-3
model = UNet(1, 1, 16)

# start_epoch_idx = auto_continue_training() # Replace 0 with other epoch idx if training on a saved model
start_epoch_idx = 0
epoch_target = 500 # end epoch idx
# NOTE(review): the if/else choosing between a fresh model and resuming from a
# checkpoint is commented out, so this always trains a new model from epoch 1.
# NOTE(review): running this cell launches the full training run (epochs 1..epoch_target).
# if start_epoch_idx == 0
    # train new model 
    ps, st = Lux.setup(rng, model)

    ps = ps |> dev
    st = st |> dev

    # Synchronise the initial parameters/states from rank 0 so all ranks start identically.
    ps = FluxMPI.synchronize!(ps; root_rank = 0)
    st = FluxMPI.synchronize!(st; root_rank = 0)
    # opt_ = Scheduler(Exp(1f-2, 8f-1)) do lr 
    #     NAdam(lr) 
    # end
    
    # opt = DistributedOptimizer(opt_)
    # DistributedOptimizer wraps Adam for multi-rank training (presumably combining
    # gradients across ranks before each update — see FluxMPI docs).
    opt = DistributedOptimizer(Adam(l_r))
    st_opt = Optimisers.setup(opt, ps)
    st_opt = FluxMPI.synchronize!(st_opt; root_rank = 0)
    
    if rank == 0
        # Save the untrained initial weights as checkpoint 0 (rank 0 only).
        # NOTE(review): `data_dir` is not defined anywhere in this notebook.
        local ps_save, st_save = ps |> dev_cpu, st |> dev_cpu
        @save joinpath(data_dir, "saved_train_info_0.jld2") ps_save st_save
    end
    train(start_epoch_idx+1, epoch_target, ps, st, opt, st_opt, model, train_loader_, test_loader)
# else
    # # load saved model 
    # @load "saved_train_info_$start_epoch_idx.jld2" ps_save st_save
    # ps = ps_save |> dev
    # st = st_save |> dev

    # # opt = Scheduler(Exp(1f-2, 8f-1)) do lr 
    # #     NAdam(lr) 
    # # end

    # opt = DistributedOptimizer(opt)
    # st_opt = Optimisers.setup(opt, ps)

    # NOTE(review): this resume call omits `test_loader`, which `train` requires.
    # train(start_epoch_idx+1, epoch_target, ps, st, opt, st_opt, model, train_loader_)
# end


