# This Julia notebook shows how to load the trained model and apply it to an image.

In [5]:
using Pkg
Pkg.activate("/home/molloi-lab/Desktop/Project BAC/BAC project/libs/")
using Lux, Random, NNlib, Zygote, LuxCUDA, CUDA, FluxMPI, JLD2, DICOM
using Images
using MLUtils
using MPI
using Optimisers
using ImageMorphology, ChainRulesCore, Statistics, CSV, DataFrames, Dates

[32m[1m  Activating[22m[39m project at `~/Desktop/Project BAC/BAC project/libs`
└ @ FluxMPI /home/molloi-lab/.julia/packages/FluxMPI/BwbGS/src/FluxMPI.jl:28


In [6]:
# Disable scalar indexing of GPU arrays (CUDA.jl setting), so accidental
# element-by-element access raises an error instead of running slowly.
CUDA.allowscalar(false)

# Initialise FluxMPI, passing GPU device ids 0-3 as the devices to use.
FluxMPI.Init(;gpu_devices = [0,1,2,3])

# Rank (similar to a thread ID) of the current process in the MPI world communicator.
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
# Device handles used later to move data, e.g. `ps_save |> dev`.
# NOTE(review): `gpu_device()` presumably falls back to the CPU when no functional
# GPU backend is found — confirm against the installed Lux version.
dev = gpu_device()
dev_cpu = cpu_device()

# Layer-building helpers, written as named functions rather than anonymous functions
# stored in non-const globals, so that calls to them resolve to constant bindings.

# 3x3 convolution mapping `in_chs` to `out_chs` channels, with padding 1.
_conv(in_chs, out_chs) = Conv((3, 3), in_chs => out_chs, pad=1)

# Convolution block: 3x3 convolution followed by batch normalisation with a
# leaky-ReLU activation.
conv1(in_chs, out_chs) = Chain(_conv(in_chs, out_chs), BatchNorm(out_chs, leakyrelu))

# 2x2 transposed convolution with stride 2, mapping `in_chs` to `out_chs` channels.
_tran(in_chs, out_chs) = ConvTranspose((2, 2), in_chs => out_chs, stride = 2)

# Upsampling block: transposed convolution followed by batch normalisation with a
# leaky-ReLU activation.
tran(in_chs, out_chs) = Chain(_tran(in_chs, out_chs), BatchNorm(out_chs, leakyrelu))

"""
    UNet(l1, l2, l3, l4, l5, l6, l7)

Lux container layer grouping seven sub-layers. The field names are registered with
`Lux.AbstractExplicitContainerLayer`, so the parameter and state trees are
NamedTuples keyed by `l1` … `l7`.

Each field is parametrised by its own type so the struct stays concretely typed
whatever layer objects are stored in it.
"""
struct UNet{CH1,CH2,CH3,CH4,CH5,CH6,CH7} <: Lux.AbstractExplicitContainerLayer{
    (:l1, :l2, :l3, :l4, :l5, :l6, :l7)
}
    # Stages applied first, in order, by the forward pass.
    l1::CH1
    l2::CH2
    l3::CH3
    l4::CH4
    # Stages that consume a concatenation of the previous output and an earlier one.
    l5::CH5
    l6::CH6
    l7::CH7
end

"""
    UNet(in_chs, lbl_chs, size)

Build a `UNet` with `in_chs` input channels, `lbl_chs` output (label) channels and
`size` feature channels at the first level. The channel count doubles at each of the
three pooling steps (`size`, `2size`, `4size`, `8size`) and is reduced again on the
way up. The last stage ends with a 1x1 convolution to `lbl_chs` channels followed by
`sigmoid`.
"""
function UNet(in_chs, lbl_chs, size)
    # Channel widths used at the four resolution levels.
    w1, w2, w4, w8 = size, size * 2, size * 4, size * 8

    # A fresh 2x2 / stride-2 max-pooling layer for each stage that needs one.
    pool() = MaxPool((2, 2), stride=2)

    # Contracting stages
    l1 = Chain(conv1(in_chs, w1), conv1(w1, w1))
    l2 = Chain(pool(), conv1(w1, w2), conv1(w2, w2))
    l3 = Chain(pool(), conv1(w2, w4), conv1(w4, w4))

    # Deepest stage; already ends with the first upsampling step
    l4 = Chain(pool(), conv1(w4, w8), conv1(w8, w8), tran(w8, w4))

    # Expanding stages; each input has twice the channels of the preceding
    # upsampling output because the forward pass concatenates an earlier output.
    l5 = Chain(conv1(w8, w4), conv1(w4, w4), tran(w4, w2))
    l6 = Chain(conv1(w4, w2), conv1(w2, w2), tran(w2, w1))
    l7 = Chain(conv1(w2, w1), conv1(w1, w1), Conv((1, 1), w1 => lbl_chs), sigmoid)

    return UNet(l1, l2, l3, l4, l5, l6, l7)
end

# Forward pass of the UNet.
#
# Arguments:
#   x  - input array passed to the first stage
#   ps - parameters, with one entry per stage (`ps.l1` … `ps.l7`)
#   st - states, a NamedTuple with one entry per stage (`st.l1` … `st.l7`)
#
# Returns the output of the last stage together with a NamedTuple of the updated
# per-stage states, keyed `l1` … `l7`.
function (m::UNet)(x, ps, st::NamedTuple)
    # Stages l1-l3: each output is kept because it is reused further down.
    enc1, s1 = m.l1(x, ps.l1, st.l1)
    enc2, s2 = m.l2(enc1, ps.l2, st.l2)
    enc3, s3 = m.l3(enc2, ps.l3, st.l3)

    # Stage l4 consumes the l3 output directly.
    up3, s4 = m.l4(enc3, ps.l4, st.l4)

    # Stages l5-l7: concatenate the previous output with the matching earlier output
    # along dimension 3 (the channel axis under Lux's WHCN convention) before
    # applying the stage.
    up2, s5 = m.l5(cat(up3, enc3; dims=3), ps.l5, st.l5)
    up1, s6 = m.l6(cat(up2, enc2; dims=3), ps.l6, st.l6)
    out, s7 = m.l7(cat(up1, enc1; dims=3), ps.l7, st.l7)

    # Collect the per-stage states under the same keys as the incoming `st`.
    return out, (l1=s1, l2=s2, l3=s3, l4=s4, l5=s5, l6=s6, l7=s7)
end

In [7]:
# Input data: path to the DICOM (.dcm) image file the model will be applied to.
input_img = "/media/molloi-lab/2TB1/Clean_Dataset_full/SID-100510/L_CC.3328_2560.dcm"

"/media/molloi-lab/2TB1/Clean_Dataset_full/SID-100510/L_CC.3328_2560.dcm"

In [None]:
# # Prepare the image
# ground_truth_mask = Float32.(Images.load(ground_truth_mask_path))
# breast_mask = Float32.(Images.load(breast_mask_path))
# dcm_data = dcm_parse(dcm_path)
# is_reversed = uppercase(dcm_data[(0x2050, 0x0020)]) == "INVERSE"
# pixel_size = dcm_data[(0x0018, 0x1164)]
# img = Float32.(dcm_data[(0x7fe0, 0x0010)])
# original_size = size(img)
# # resize image based on pixel length
# img, breast_mask, ground_truth_mask, new_size = resize_dicom_image(img, breast_mask, ground_truth_mask, pixel_size)
# # normalize image and correct color
# img = normalize_img(img; mask = breast_mask, invert = is_reversed)
# # crop to breast only
# img_cropped, ground_truth_mask_cropped, coords = crop_to_bounding_box(breast_mask, img, ground_truth_mask)
# # save resize info to local
# @save joinpath(curr_dir, f_name*"_resize_info.jld2") original_size new_size coords
# # check size
# x, y = size(img_cropped)
# # if y % 32 != 0
# #     x_org, y_org = size(img)
# #     println(i, "\t", ct+1)
# #     println("($x_org, $y_org)")
# #     println("($x, $y)\n")
# # end
# @assert x % 32 == 0
# @assert y % 32 == 0

# #save
# @save joinpath(out_dir, f_name*".jld2") img_cropped
# Images.save(joinpath(out_dir, f_name*".png"), Gray.(round.(ground_truth_mask_cropped)))


In [9]:
# Load the trained model: read the variables `ps_save` and `st_save` from the JLD2
# checkpoint (presumably the model parameters and the layer states, respectively),
# then move both to the device `dev`.
model_path = "/home/molloi-lab/Desktop/wenbo2_flashdrive_backup/saved_train_info_334.jld2"
@load model_path ps_save st_save
ps_save = ps_save |> dev
st_save = st_save |> dev

# Disabled: synchronise parameters and states across MPI processes from rank 0.
# ps = FluxMPI.synchronize!(ps; root_rank = 0)
# st = FluxMPI.synchronize!(st; root_rank = 0)

(l1 = (layer_1 = NamedTuple(), layer_2 = (running_mean = Float32[-1.5754029, 2.2304332, -1.6438724, 2.1291182, 1.5642703, -4.459931, -1.4128605, 2.5394652, 0.08220151, 1.3944055, -2.2787423, 2.0576603, -1.6820405, 1.7552352, 2.6440036, -2.1005633], running_var = Float32[2.0897603, 2.1679688, 0.8702852, 0.7614294, 0.4340817, 9.882541, 1.1594844, 3.4043326, 0.6023198, 1.5117741, 2.6206417, 2.2926908, 1.1610765, 1.3491561, 1.8846543, 1.7653261], training = Val{true}()), layer_3 = NamedTuple(), layer_4 = (running_mean = Float32[-0.26003373, 0.75799036, 1.0921711, -2.363193, -2.6438394, 0.91217846, -11.351362, 1.798894, -0.120637566, -1.7964456, -2.17908, 0.97645134, -0.6907514, -0.715033, 0.84746337, 1.8764355], running_var = Float32[0.86277986, 2.34481, 0.5159975, 3.5334196, 2.3837838, 4.866938, 86.0603, 4.1092596, 0.39124367, 4.560316, 4.7320285, 2.6757293, 1.6857156, 1.596671, 2.1655583, 3.9297612], training = Val{true}())), l2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 

In [11]:
# Instantiate the UNet with 1 input channel, 1 label (output) channel and a base
# width of 16 channels.
# NOTE(review): this presumably has to match the configuration used in training
# for the loaded parameters to fit — confirm.
model = UNet(1, 1, 16)

UNet(
    l1 = Chain(
        layer_1 = Conv((3, 3), 1 => 16, pad=1),  [90m# 160 parameters[39m
        layer_2 = BatchNorm(16, leakyrelu, affine=true, track_stats=true),  [90m# 32 parameters[39m[90m, plus 33[39m
        layer_3 = Conv((3, 3), 16 => 16, pad=1),  [90m# 2_320 parameters[39m
        layer_4 = BatchNorm(16, leakyrelu, affine=true, track_stats=true),  [90m# 32 parameters[39m[90m, plus 33[39m
    ),
    l2 = Chain(
        layer_1 = MaxPool((2, 2)),
        layer_2 = Conv((3, 3), 16 => 32, pad=1),  [90m# 4_640 parameters[39m
        layer_3 = BatchNorm(32, leakyrelu, affine=true, track_stats=true),  [90m# 64 parameters[39m[90m, plus 65[39m
        layer_4 = Conv((3, 3), 32 => 32, pad=1),  [90m# 9_248 parameters[39m
        layer_5 = BatchNorm(32, leakyrelu, affine=true, track_stats=true),  [90m# 64 parameters[39m[90m, plus 65[39m
    ),
    l3 = Chain(
        layer_1 = MaxPool((2, 2)),
        layer_2 = Conv((3, 3), 32 => 64, pad=1),  [90m# 18_496 p

In [None]:
# Run inference with the loaded weights (`ps_save` / `st_save`; the notebook never
# defines plain `ps` / `st`).
# The checkpointed states were saved in training mode (`training = Val{true}()`), so
# they are switched to test mode here: BatchNorm then normalises with its stored
# running statistics instead of the statistics of this single input.
# NOTE(review): `x` must be defined before this cell — assumed to be the preprocessed
# image as a Float32 array of size (W, H, 1, N) already moved to `dev`; confirm.
ŷ, st = Lux.apply(model, x, ps_save, Lux.testmode(st_save))