In [65]:
import Pkg

# Pkg.add("NNlib")
# Pkg.add("CUDA")
# Pkg.add("DICOM")
# Pkg.add(url="https://github.com/MolloiLab/imageToolBox.jl")
# Pkg.add("Lux")
# Pkg.add("LuxCUDA")
# Pkg.add("NNlib")
# Pkg.add("NIfTI")
# Pkg.add("Images")
# Pkg.add("ImageFiltering")
# Pkg.add("Statistics")
# Pkg.add("ImageMorphology")
# Pkg.add("CSV")
# Pkg.add("DataFrames")
Pkg.add("Printf")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m    Updating[22m[39m `~/.julia/environments/v1.9/Project.toml`
  [90m[de0858da] [39m[92m+ Printf[39m
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.9/Manifest.toml`


In [66]:
using Pkg
# add the following packages before running
using Lux, NNlib, LuxCUDA, CUDA, JLD2, NIfTI, DICOM
using Images, imageToolBox, ImageFiltering, Statistics
using ImageMorphology
using CSV, DataFrames
using Printf

In [9]:
# The U-net model

# 3×3 convolution with padding 1, mapping `cin` channels to `cout` channels.
_conv = (cin, cout) -> Conv((3, 3), cin => cout, pad=1)
# Convolution followed by batch normalisation with a leaky-ReLU activation.
conv1 = (cin, cout) -> Chain(_conv(cin, cout), BatchNorm(cout, leakyrelu))

# 2×2 transposed convolution with stride 2, mapping `cin` channels to `cout`.
_tran = (cin, cout) -> ConvTranspose((2, 2), cin => cout, stride=2)
# Transposed convolution followed by batch normalisation with leaky-ReLU.
tran = (cin, cout) -> Chain(_tran(cin, cout), BatchNorm(cout, leakyrelu))

"""
    UNet{T1,...,T7}

Container layer holding the seven stages `l1` … `l7` of the U-Net. The field
names are registered with `Lux.AbstractExplicitContainerLayer`, and each stage
keeps its own concrete type through a dedicated type parameter.
"""
struct UNet{T1,T2,T3,T4,T5,T6,T7} <:
       Lux.AbstractExplicitContainerLayer{(:l1, :l2, :l3, :l4, :l5, :l6, :l7)}
    l1::T1
    l2::T2
    l3::T3
    l4::T4
    l5::T5
    l6::T6
    l7::T7
end

"""
    UNet(in_chs, lbl_chs, size)

Build the seven stages of the U-Net.

# Arguments
- `in_chs`: number of channels of the input fed to the first stage.
- `lbl_chs`: number of channels produced by the final 1×1 convolution.
- `size`: channel width of the first stage; later stages use 2×, 4× and 8× it.

Stages `l2`–`l4` start with a 2×2 max-pool; stages `l4`–`l6` end with a
transposed convolution. Stages `l5`–`l7` take twice the channels their
predecessor emits, because the forward pass concatenates a skip connection
along the channel dimension. The last stage ends in `sigmoid`.
"""
function UNet(in_chs, lbl_chs, size)
    # Channel widths used by the successive resolution levels.
    w1, w2, w4, w8 = size, size * 2, size * 4, size * 8
    pool() = MaxPool((2, 2), stride=2)

    # Contracting path
    l1 = Chain(conv1(in_chs, w1), conv1(w1, w1))
    l2 = Chain(pool(), conv1(w1, w2), conv1(w2, w2))
    l3 = Chain(pool(), conv1(w2, w4), conv1(w4, w4))

    # Bottom of the "U": deepest features, then the first upsampling step
    l4 = Chain(pool(), conv1(w4, w8), conv1(w8, w8), tran(w8, w4))

    # Expanding path
    l5 = Chain(conv1(w8, w4), conv1(w4, w4), tran(w4, w2))
    l6 = Chain(conv1(w4, w2), conv1(w2, w2), tran(w2, w1))
    l7 = Chain(conv1(w2, w1), conv1(w1, w1), Conv((1, 1), w1 => lbl_chs), sigmoid)

    return UNet(l1, l2, l3, l4, l5, l6, l7)
end

"""
    (m::UNet)(x, ps, st::NamedTuple)

Run the U-Net forward pass on `x` using parameters `ps` and states `st`, both
indexed by the stage names `l1` … `l7`.

Returns the output of the last stage together with a `NamedTuple` holding the
updated state of every stage.
"""
function (m::UNet)(x, ps, st::NamedTuple)
    # Contracting path: keep every intermediate output for the skip connections.
    enc1, s1 = m.l1(x, ps.l1, st.l1)
    enc2, s2 = m.l2(enc1, ps.l2, st.l2)
    enc3, s3 = m.l3(enc2, ps.l3, st.l3)

    # Deepest stage (its chain ends with a transposed convolution).
    up3, s4 = m.l4(enc3, ps.l4, st.l4)

    # Expanding path: concatenate each result with the matching encoder output
    # along dimension 3 before feeding the next stage.
    up2, s5 = m.l5(cat(up3, enc3; dims=3), ps.l5, st.l5)
    up1, s6 = m.l6(cat(up2, enc2; dims=3), ps.l6, st.l6)
    out, s7 = m.l7(cat(up1, enc1; dims=3), ps.l7, st.l7)

    # Collect the per-stage states under the same names they came in with.
    return out, (l1=s1, l2=s2, l3=s3, l4=s4, l5=s5, l6=s6, l7=s7)
end

# Network used for inference: 1 input channel, 1 label channel, first-stage width 16.
model_to_use = UNet(1, 1, 16)

UNet(
    l1 = Chain(
        layer_1 = Conv((3, 3), 1 => 16, pad=1),  [90m# 160 parameters[39m
        layer_2 = BatchNorm(16, leakyrelu, affine=true, track_stats=true),  [90m# 32 parameters[39m[90m, plus 33[39m
        layer_3 = Conv((3, 3), 16 => 16, pad=1),  [90m# 2_320 parameters[39m
        layer_4 = BatchNorm(16, leakyrelu, affine=true, track_stats=true),  [90m# 32 parameters[39m[90m, plus 33[39m
    ),
    l2 = Chain(
        layer_1 = MaxPool((2, 2)),
        layer_2 = Conv((3, 3), 16 => 32, pad=1),  [90m# 4_640 parameters[39m
        layer_3 = BatchNorm(32, leakyrelu, affine=true, track_stats=true),  [90m# 64 parameters[39m[90m, plus 65[39m
        layer_4 = Conv((3, 3), 32 => 32, pad=1),  [90m# 9_248 parameters[39m
        layer_5 = BatchNorm(32, leakyrelu, affine=true, track_stats=true),  [90m# 64 parameters[39m[90m, plus 65[39m
    ),
    l3 = Chain(
        layer_1 = MaxPool((2, 2)),
        layer_2 = Conv((3, 3), 32 => 64, pad=1),  [90m# 18_496 p

In [10]:
# Grow the inclusive span `lo:hi` at its upper end until its length is a
# multiple of `multiple`; if it then exceeds `limit`, shift it back (clamped at
# 1). Whatever remainder is left afterwards is trimmed from both ends.
function _fit_span_to_multiple(lo, hi, limit, multiple)
    remainder = (hi - lo + 1) % multiple
    if remainder != 0
        hi += multiple - remainder
        if hi > limit
            # Move the span back inside the array instead of shrinking it.
            lo = max(1, lo - (hi - limit))
            hi = limit
        end
    end

    # The clamp above can leave a length that is still not a multiple (the
    # array is too small along this axis); trim the excess from both ends.
    remainder = (hi - lo + 1) % multiple
    if remainder != 0
        head_trim = round(Int, remainder / 2)
        tail_trim = remainder - head_trim
        lo += head_trim
        hi -= tail_trim
    end
    return lo, hi
end

"""
    crop_to_bounding_box(mask, image, ground_truth_mask; multiple=32)

Crop `image` and `ground_truth_mask` to the bounding box of the entries of
`mask` that equal 1, adjusted so that both side lengths are multiples of
`multiple`.

The box is first enlarged towards higher indices; if it would leave the mask
it is shifted back, and any remainder that still does not fit is trimmed
symmetrically.

# Arguments
- `mask`: 2-D array; only entries exactly equal to 1 define the bounding box.
- `image`, `ground_truth_mask`: arrays indexed with the same row/column ranges.

# Keywords
- `multiple`: required divisor of both output side lengths (default 32).

# Returns
`(cropped_image, cropped_ground_truth, coords)` where
`coords == [min_row, max_row, min_col, max_col]` are the inclusive indices of
the crop in the input arrays.

# Throws
- `ArgumentError` if `mask` has no entry equal to 1 or `multiple` is not
  positive.
"""
function crop_to_bounding_box(mask, image, ground_truth_mask; multiple::Integer=32)
    multiple > 0 || throw(ArgumentError("multiple must be positive, got $multiple"))

    indices = findall(x -> x == 1, mask)
    if isempty(indices)
        # Returning a String here (as before) let callers destructure it into
        # three characters and fail later with an unrelated error.
        throw(ArgumentError("Error in breast mask! No pixel equals 1, so no bounding box exists."))
    end

    # Tight bounding box of the mask.
    min_row, max_row = extrema(index[1] for index in indices)
    min_col, max_col = extrema(index[2] for index in indices)

    # Adjust each axis independently to a multiple of `multiple`.
    min_row, max_row = _fit_span_to_multiple(min_row, max_row, size(mask, 1), multiple)
    min_col, max_col = _fit_span_to_multiple(min_col, max_col, size(mask, 2), multiple)

    rslt = image[min_row:max_row, min_col:max_col]
    rslt2 = ground_truth_mask[min_row:max_row, min_col:max_col]
    rslt3 = [min_row, max_row, min_col, max_col]

    return rslt, rslt2, rslt3
end

"""
    resize_dicom_image(image, mask, ground_truth_mask, original_spacing; target_spacing=[0.13, 0.13])

Resample `image`, `mask` and `ground_truth_mask` to a common size derived from
the ratio between `original_spacing` and `target_spacing`.

Returns the three resized arrays followed by the new size as a two-element
vector.
"""
function resize_dicom_image(image, mask, ground_truth_mask, original_spacing; target_spacing=[0.13, 0.13])
    # Ratio between the original and the target spacing, per spacing entry.
    ratio_1 = original_spacing[1] / target_spacing[1]
    ratio_2 = original_spacing[2] / target_spacing[2]

    # NOTE(review): dimension 1 is scaled by the ratio of the *second* spacing
    # entry and dimension 2 by the first — presumably intentional given how
    # the pixel array is laid out; confirm against the DICOM reader.
    target_dims = (
        round(Int, size(image, 1) * ratio_2),
        round(Int, size(image, 2) * ratio_1),
    )

    # Resample all three arrays to the same target size.
    resized_img, resized_mask, resized_ground_truth_mask =
        map(arr -> imresize(arr, target_dims), (image, mask, ground_truth_mask))

    return resized_img, resized_mask, resized_ground_truth_mask, collect(target_dims)
end

resize_dicom_image (generic function with 1 method)

In [11]:
# File from which the saved parameters and states are loaded in a later cell.
path_to_saved_model = "saved_train_info_112.jld2"
# Device that parameters, states and inputs are moved to; in the recorded run
# below no GPU backend was functional and this evaluated to a LuxCPUDevice.
dev = gpu_device()

└ @ LuxCUDA /Users/harryxiong24/.julia/packages/LuxCUDA/QvUoj/src/LuxCUDA.jl:20
│ 
│ 1. If no GPU is available, nothing needs to be done.
│ 2. If GPU is available, load the corresponding trigger package.
│     a. LuxCUDA.jl for NVIDIA CUDA Support!
│     b. LuxAMDGPU.jl for AMD GPU ROCM Support!
│     c. Metal.jl for Apple Metal GPU Support!
└ @ LuxDeviceUtils /Users/harryxiong24/.julia/packages/LuxDeviceUtils/eyk2C/src/LuxDeviceUtils.jl:154


(::LuxCPUDevice) (generic function with 5 methods)

In [12]:
# Read `ps_save` and `st_save` from the saved file, then move both to `dev`.
@load path_to_saved_model ps_save st_save
ps, st = ps_save |> dev, st_save |> dev

((l1 = (layer_1 = (weight = [-0.13811347 -0.10989219 -0.45017767; -0.03557757 0.15679145 -0.10025815; 0.4154285 0.11406401 -0.1702349;;;; -0.2519165 -0.2005901 0.269674; 0.21714726 0.121530145 0.1745333; 0.32218948 -0.15850435 -0.16568863;;;; 0.05944323 -0.19018307 0.27134103; 0.1708762 -0.094147116 0.019200183; 0.12004882 -0.29905564 -0.2584902;;;; … ;;;; 0.34197986 -0.2045819 -0.03704661; 0.17001131 -0.23942131 -0.083972864; 0.118663475 0.054883588 0.13577364;;;; 0.21762453 -0.1118695 -0.12177359; -0.29748553 0.5254144 0.24555175; -0.0060818857 -0.16924761 0.020385385;;;; 0.05229594 0.20575164 -0.1876456; 0.031438395 -0.21704458 -0.39874387; 0.33417702 -0.10161231 -0.0084489], bias = [0.44849735;;; 0.11582101;;; -0.3342599;;; … ;;; 0.10559988;;; 0.6723078;;; -0.23398082;;;;]), layer_2 = (scale = Float32[0.41381815, 0.38657746, 0.48907733, 0.97282636, 1.3967147, 0.3998129, 0.5293896, 0.43151233, 1.6391323, 0.60299855, 0.466895, 0.37334907, 0.70797276, 0.3011922, 1.3915981, 0.7733621],

In [13]:
# Directory passed to `apply_model_to_directory` and `score` below.
root = raw"./dataset/inter-observer/50_Sohrab_Fati/output"

"./dataset/inter-observer/50_Sohrab_Fati/output"

In [14]:
"""
    prepare_image(curr_dir, sid, f_name)

Load `<curr_dir>/<sid>/source/<f_name>.dcm` together with its breast mask
`<f_name>.mask.png`, resample both via `resize_dicom_image`, normalise the
image with `normalize_img` and crop it with `crop_to_bounding_box`.

Two files are written next to the source image:
- `<f_name>_resize_info.jld2` with `original_size`, `new_size` and `coords`;
- `<f_name>_cropped.jld2` with `img_cropped`.

Returns the cropped image; both of its side lengths are asserted to be
multiples of 32.
"""
function prepare_image(curr_dir, sid, f_name)
    source_dir = joinpath(curr_dir, sid, "source")
    dcm_path = joinpath(source_dir, f_name * ".dcm")
    breast_mask_path = joinpath(source_dir, f_name * ".mask.png")

    breast_mask = Float32.(Images.load(breast_mask_path))
    dcm_data = dcm_parse(dcm_path)

    # DICOM elements read by tag: (0018,1164) is passed on as the pixel
    # spacing, (2050,0020) decides whether intensities are inverted, and
    # (7FE0,0010) is the pixel data.
    pixel_size = dcm_data[(0x0018, 0x1164)]
    is_reversed = uppercase(dcm_data[(0x2050, 0x0020)]) == "INVERSE"
    img = Float32.(dcm_data[(0x7fe0, 0x0010)])
    original_size = size(img)

    # No annotation is loaded here; an all-zero placeholder keeps the helper
    # signatures satisfied.
    ground_truth_mask = zeros(Float32, original_size)

    # Resample image and masks according to the pixel spacing.
    img, breast_mask, ground_truth_mask, new_size =
        resize_dicom_image(img, breast_mask, ground_truth_mask, pixel_size)

    # Normalise intensities (inverting when required), then crop to the breast.
    img = normalize_img(img; mask=breast_mask, invert=is_reversed)
    img_cropped, _, coords = crop_to_bounding_box(breast_mask, img, ground_truth_mask)

    # Persist what is needed to map a prediction back onto the original grid.
    # The variable names below become the keys stored in the file.
    resize_info_path = joinpath(source_dir, f_name * "_resize_info.jld2")
    @save resize_info_path original_size new_size coords

    # Both side lengths must be multiples of 32.
    x, y = size(img_cropped)
    @assert x % 32 == 0
    @assert y % 32 == 0

    cropped_path = joinpath(source_dir, f_name * "_cropped.jld2")
    @save cropped_path img_cropped

    # Force a full garbage collection before the next image is processed.
    GC.gc(true)

    return img_cropped
end

prepare_image (generic function with 1 method)

In [47]:
"""
    apply_model_to_directory(main_directory::String)

Run the segmentation model on every `.dcm` file found in
`<main_directory>/SID*/source` and write the results to the sibling `predict`
directory (created if missing).

For each image `<name>.dcm` two files are written:
- `<name>_predict_original.png`: the rounded model output mapped back to the
  original image size;
- `<name>_predict.png`: the same mask after `drop_small_areas` (threshold 800)
  and `denoise` (5×5 window).

Uses the globals `model_to_use`, `ps`, `st` and `dev`. Subdirectories without
a `source` folder are skipped. Returns `nothing`.
"""
function apply_model_to_directory(main_directory::String)
  for subdirectory in readdir(main_directory, join=true)
    # Process only directories starting with "SID"
    (isdir(subdirectory) && startswith(basename(subdirectory), "SID")) || continue

    source_dir = joinpath(subdirectory, "source")
    # A study folder without inputs is skipped rather than aborting the run.
    isdir(source_dir) || continue
    output_dir = joinpath(subdirectory, "predict")

    for file in readdir(source_dir, join=true)
      (isfile(file) && endswith(file, ".dcm")) || continue
      image_name = first(splitext(basename(file)))

      # Load and prepare the image, then add singleton dimensions 3 and 4.
      x = prepare_image(main_directory, basename(subdirectory), image_name)
      x = Float32.(reshape(x, size(x)..., 1, 1))
      x = x |> dev

      # Apply the model and binarise its output.
      # NOTE(review): `st` is used exactly as loaded; if it was saved in
      # training mode, consider `Lux.testmode(st)` for inference — confirm.
      ŷ, _ = Lux.apply(model_to_use, x, ps, st)
      ŷ = round.(ŷ |> cpu)[:, :, 1, 1]

      # Load the resize information written by `prepare_image` and undo the
      # crop and the resampling.
      resize_path = joinpath(source_dir, image_name * "_resize_info.jld2")
      @load resize_path new_size coords original_size
      ŷ = convert_to_og_size(ŷ, original_size, new_size, coords)

      # The output directory may not exist yet on a fresh dataset.
      mkpath(output_dir)
      output_path = joinpath(output_dir, image_name * "_predict_original.png")
      save(output_path, Gray.(ŷ))

      # Post-process: drop small components, then denoise the result in place.
      drop_area_path = joinpath(output_dir, image_name * "_predict.png")
      drop_small_areas(output_path, drop_area_path, 800)
      denoise(drop_area_path, drop_area_path, (5, 5))
    end
  end
  return nothing
end


apply_model_to_directory (generic function with 1 method)

In [48]:
"""
    convert_to_og_size(ŷ_cropped, original_size, new_size, coords)

Paste `ŷ_cropped` into a zero-filled `Float32` array of size `new_size` at the
inclusive window `coords = (row_lo, row_hi, col_lo, col_hi)`, resize the result
to `original_size` and round every value.
"""
function convert_to_og_size(ŷ_cropped, original_size, new_size, coords)
  canvas = zeros(Float32, new_size...)
  row_lo, row_hi, col_lo, col_hi = coords
  canvas[row_lo:row_hi, col_lo:col_hi] = ŷ_cropped

  # Resizing interpolates, so round again to get back to whole values.
  return round.(imresize(canvas, original_size))
end

convert_to_og_size (generic function with 1 method)

In [49]:
"""
    denoise(input, output, filter_size)

Load the image at `input`, apply a sliding-window median of size `filter_size`
and save the result to `output`. Returns whatever `save` returns.
"""
function denoise(input, output, filter_size)
  smoothed = mapwindow(median, load(input), filter_size)
  return save(output, smoothed)
end

denoise (generic function with 1 method)

In [50]:
"""
    drop_small_areas(input, output, threshold_area=1000)

Load the image at `input`, threshold its grey values at 0.5, label the
connected components and keep only the pixels of components whose pixel count
is strictly greater than `threshold_area`. The resulting binary image is
saved to `output`.
"""
function drop_small_areas(input, output, threshold_area=1000)
    binary_img = Gray.(load(input)) .> 0.5
    labels = label_components(binary_img)

    # Pixel count per component; label 0 is the background and is not counted.
    areas = Dict{Int,Int}()
    for label in labels
        label == 0 && continue
        areas[label] = get(areas, label, 0) + 1
    end

    # Keep a pixel only if it belongs to a sufficiently large component.
    output_img = similar(binary_img)
    for i in eachindex(labels)
        label = labels[i]
        output_img[i] = label != 0 && areas[label] > threshold_area
    end

    return save(output, output_img)
end

drop_small_areas (generic function with 2 methods)

In [52]:
# Run inference and post-processing for every study under `root`.
apply_model_to_directory(root)

In [63]:
"""
    dice_loss(y_true, y_pred)::Float32

Return `1 - Dice` between two images after converting both to grey scale and
thresholding at 0.5.

When both masks are empty the Dice coefficient is undefined (0/0); the masks
agree perfectly in that case, so a loss of 0 is returned instead of `NaN`.

# Throws
- `ErrorException` if the two images differ in size.
"""
function dice_loss(y_true, y_pred)::Float32
  if size(y_true) != size(y_pred)
    error("Input images must have the same dimensions")
  end

  truth = Gray.(y_true) .> 0.5
  pred = Gray.(y_pred) .> 0.5

  # Integer pixel counts are exact regardless of image size.
  intersection = count(truth .& pred)
  total = count(truth) + count(pred)

  # Both masks empty: avoid 0/0 and report perfect agreement.
  total == 0 && return 0.0f0

  return 1 - 2 * intersection / total
end

dice_loss (generic function with 1 method)

In [74]:
"""
    score(main_directory; output_path="results.csv")

Compare every `<id>_predict.png` in `<main_directory>/SID*/predict` with the
annotations `<id>.png` in the sibling `fati` and `sohrab` directories using
`dice_loss`, and write one row per image to `output_path`.

Columns: `sid` (study directory name), `image_id`, `fati_vs_predict` and
`sohrab_vs_predict` (losses formatted with two decimals). Images for which
either annotation is missing are skipped, as are study directories without a
`predict` folder.

Returns the value of `CSV.write`.
"""
function score(main_directory; output_path="results.csv")
  results = DataFrame(sid=String[], image_id=String[], fati_vs_predict=String[], sohrab_vs_predict=String[])
  suffix = "_predict"

  for subdirectory in readdir(main_directory, join=true)
    # Process only directories starting with "SID"
    sid = basename(subdirectory)
    (isdir(subdirectory) && startswith(sid, "SID")) || continue

    predict_dir = joinpath(subdirectory, "predict")
    isdir(predict_dir) || continue
    fati_dir = joinpath(subdirectory, "fati")
    sohrab_dir = joinpath(subdirectory, "sohrab")

    for file_predict in readdir(predict_dir, join=true)
      isfile(file_predict) || continue

      # Only the post-processed predictions "<id>_predict.png" are scored;
      # "<id>_predict_original.png" does not end in the suffix and is skipped.
      stem, ext = splitext(basename(file_predict))
      (ext == ".png" && endswith(stem, suffix)) || continue

      # Strip the suffix rather than splitting on '_' so that image ids which
      # themselves contain underscores stay intact.
      image_id = String(chop(stem; tail=length(suffix)))
      isempty(image_id) && continue

      file_path_a = joinpath(fati_dir, image_id * ".png")
      file_path_b = joinpath(sohrab_dir, image_id * ".png")
      (isfile(file_path_a) && isfile(file_path_b)) || continue

      img_predict = load(file_predict)
      fati_vs_predict = dice_loss(load(file_path_a), img_predict)
      sohrab_vs_predict = dice_loss(load(file_path_b), img_predict)

      push!(results, (sid, image_id, @sprintf("%.2f", fati_vs_predict), @sprintf("%.2f", sohrab_vs_predict)))
    end
  end

  return CSV.write(output_path, results)
end


score (generic function with 1 method)

In [75]:
# Score the predictions under `root` against both annotators and write the CSV.
score(root)

"results.csv"

In [None]:
# NOTE(review): `get_BAC_mass_data` is not defined in the visible part of this
# notebook — confirm it is defined elsewhere before running this cell.
get_BAC_mass_data()