In [11]:
using HDF5,Statistics, ImageView

In [2]:
using Knet

# Preprocessing

In [9]:
# Load the dataset and split it into network inputs (slice 2 of the last axis) and
# target masks (slice 1), each reshaped to (H, W, 1, N) with a singleton channel axis.
# Sizes are read from the file instead of being hard-coded as 256×256×10.
# NOTE(review): assumes `data` is laid out as (H, W, N, 2) — TODO confirm for other files.
# NOTE(review): converting to KnetArray needs a GPU; on a CPU-only machine this cell fails
# with "KnetPtr: bad device id -1", as in the recorded output of this cell.
file_path = "cropped_data_0_10.h5";
data = h5read(file_path, "data");
n_rows, n_cols, n_samples = size(data, 1), size(data, 2), size(data, 3);
X_train = convert(KnetArray, reshape(data[:, :, :, 2], n_rows, n_cols, 1, n_samples));
Y_train = convert(KnetArray, reshape(data[:, :, :, 1], n_rows, n_cols, 1, n_samples));

ErrorException: KnetPtr: bad device id -1.

# Initialization

# Model

In [None]:
"""
    flatten(array)

Return `array` reshaped into a 1-D vector of all its elements, in column-major order.
This works for any dimensionality, and for an empty array it returns an empty vector.
`reshape` shares data with `array`, so no copy is made.
"""
function flatten(array)
    # `length` is the product of all dimensions; no need to multiply them by hand.
    return reshape(array, length(array))
end

In [None]:
"""
    cat3d(arr1, arr2)

Concatenate two 4-D arrays along their third (channel) dimension. The result is
(R, C, Ch1 + Ch2, S), where arr1's channels come first and arr2's follow.

This is implemented with `permutedims`, `flatten` and Knet's 1-D `cat1d` rather than
`cat(...; dims=3)`.

Throws `DimensionMismatch` if dimensions 1, 2 or 4 of the two arrays differ.
"""
function cat3d(arr1, arr2)
    size1 = size(arr1)
    size2 = size(arr2)
    R, C, S = size1[1], size1[2], size1[4]
    Ch = size1[3] + size2[3]

    # Without this check, mismatched spatial/batch sizes either fail inside `reshape`
    # with an unrelated message or, if element counts happen to agree, silently mix data.
    if (size2[1], size2[2], size2[4]) != (R, C, S)
        throw(DimensionMismatch(
            "cat3d: arrays must agree in all dimensions except 3, got $size1 and $size2"))
    end

    # Moving the channel axis last makes each input a contiguous run of channels in
    # column-major order. Concatenating the flattened vectors and reshaping to
    # (R, C, S, Ch) therefore places arr2's channels after arr1's.
    arr1 = flatten(permutedims(arr1, [1,2,4,3]))
    arr2 = flatten(permutedims(arr2, [1,2,4,3]))
    out = reshape(cat1d(arr1, arr2), R, C, S, Ch)
    # Restore the (R, C, Ch, S) layout.
    out = permutedims(out, [1,2,4,3])

    return out
end

In [None]:
"""
    conv_layer_encoder(w, bn_params, x; o...)

One encoder stage: conv → batchnorm → ReLU, then conv → batchnorm → ReLU → `pool`.

# Arguments
- `w`: six arrays `(conv1, bias1, bn1, conv2, bias2, bn2)`. The `bn*` arrays are passed
  to `batchnorm` as its params.
- `bn_params`: two batchnorm moment objects, used by the first and second `batchnorm` call.
- `x`: input feature map.

# Keywords
All keywords (e.g. `training=false`) are forwarded to both `batchnorm` calls.

# Returns
`(conv_1, pooled)`, where `conv_1` is the output of the first conv/BN/ReLU unit and
`pooled` is `pool` applied to the output of the second unit.

NOTE(review): the skip output is taken after the *first* conv unit, not the second
(pre-pool) one. Confirm this is intended before relying on it.
"""
function conv_layer_encoder(w, bn_params, x; o...)
    # Use the moments passed in. The previous body read an undefined local `bn_moments`,
    # which resolved to the global moments vector, so every stage shared moments 1 and 2.
    out = conv4(w[1], x, padding=1) .+ w[2]
    out = batchnorm(out, bn_params[1], w[3]; o...)
    conv_1 = relu.(out)

    out = conv4(w[4], conv_1, padding=1) .+ w[5]
    out = batchnorm(out, bn_params[2], w[6]; o...)
    out = relu.(out)
    pooled = pool(out)

    return conv_1, pooled
end

In [None]:
"""
    conv_layer_encoder(w, bn_params, x; o...)

One encoder stage: conv → batchnorm → ReLU, then conv → batchnorm → ReLU → `pool`.

This cell defines the same method as the earlier `conv_layer_encoder` cell (same positional
arguments; keywords do not take part in dispatch). Whichever cell runs last is the
definition in effect.

# Arguments
- `w`: six arrays `(conv1, bias1, bn1, conv2, bias2, bn2)`. The `bn*` arrays are passed
  to `batchnorm` as its params.
- `bn_params`: two batchnorm moment objects, used by the first and second `batchnorm` call.
- `x`: input feature map.

# Keywords
All keywords (e.g. `training=false`) are forwarded to both `batchnorm` calls.

# Returns
`(conv_1, pooled)`, where `conv_1` is the output of the first conv/BN/ReLU unit and
`pooled` is `pool` applied to the output of the second unit.
"""
function conv_layer_encoder(w, bn_params, x; o...)
    # Use the moments passed in, not the global `bn_moments` the old body picked up by accident.
    out = conv4(w[1], x, padding=1) .+ w[2]
    out = batchnorm(out, bn_params[1], w[3]; o...)
    conv_1 = relu.(out)

    out = conv4(w[4], conv_1, padding=1) .+ w[5]
    out = batchnorm(out, bn_params[2], w[6]; o...)
    out = relu.(out)
    pooled = pool(out)

    return conv_1, pooled
end

In [None]:
"""
    conv_layer_decoder(w, bn_moments, x_cat, x; o...)

One decoder stage. It upsamples `x` with a fixed (not learned) bilinear deconvolution
(stride 2), concatenates the skip features `x_cat` along the channel dimension, then
applies two conv → batchnorm → ReLU units.

# Arguments
- `w`: six arrays `(conv1, bias1, bn1, conv2, bias2, bn2)`.
- `bn_moments`: two batchnorm moment objects, one per `batchnorm` call.
- `x_cat`: skip-connection features, concatenated after the upsampled `x`.
- `x`: features from the previous stage.

# Keywords
All keywords (e.g. `training=false`) are forwarded to both `batchnorm` calls.

NOTE(review): the upsampled size must match `x_cat` in dims 1, 2 and 4. This relies on the
kernel size Knet's `bilinear` picks for an upscale factor of 2 — confirm if sizes change.
NOTE(review): the kernel is rebuilt as Float64 and copied to a KnetArray on every call.
Consider caching it, and matching `eltype(x)` if the data is ever Float32.
"""
function conv_layer_decoder(w, bn_moments, x_cat, x; o...)
    number_of_channels = size(x)[3]
    upsampling_kernel = convert(KnetArray, bilinear(Float64,2,2,number_of_channels,number_of_channels))
    out = deconv4(upsampling_kernel, x, padding = 1, stride = 2)
    out = cat3d(out, x_cat)

    out = conv4(w[1], out, padding = 1) .+ w[2]
    out = batchnorm(out, bn_moments[1], w[3]; o...)
    out = relu.(out)

    out = conv4(w[4], out, padding = 1) .+ w[5]
    out = batchnorm(out, bn_moments[2], w[6]; o...)
    out = relu.(out)

    return out
end

In [None]:
"""
    bottleneck(w, bn_moments, x; o...)

Two conv → batchnorm → ReLU units, with no pooling and no upsampling.

# Arguments
- `w`: six arrays `(conv1, bias1, bn1, conv2, bias2, bn2)`.
- `bn_moments`: two batchnorm moment objects, one per `batchnorm` call.
- `x`: input feature map.

# Keywords
All keywords (e.g. `training=false`) are forwarded to both `batchnorm` calls, matching
the encoder and decoder stages.
"""
function bottleneck(w, bn_moments, x; o...)
    out = conv4(w[1], x, padding = 1) .+ w[2]
    out = batchnorm(out, bn_moments[1], w[3]; o...)
    out = relu.(out)

    out = conv4(w[4], out, padding = 1) .+ w[5]
    out = batchnorm(out, bn_moments[2], w[6]; o...)
    out = relu.(out)

    return out
end

In [None]:
"""
    output_layer(w, x)

Apply a padded convolution with kernel `w[1]` and bias `w[2]`, then an elementwise sigmoid
(`sigm`). The result is in (0, 1).
"""
function output_layer(w, x)
    kernel, bias = w[1], w[2]
    logits = conv4(kernel, x, padding = 1) .+ bias
    return sigm.(logits)
end

In [None]:
"""
    predict(x_in, w, bn_moments, training=nothing)

Forward pass: three encoder stages, a bottleneck, three decoder stages (each fed the matching
encoder's skip output), then `output_layer`.

# Arguments
- `x_in`: input batch.
- `w`: weight arrays, consumed as `w[1:6]`, …, `w[37:42]` per stage and `w[43:44]` for the
  output layer.
- `bn_moments`: batchnorm moments, two per stage (`bn_moments[1:14]`).
- `training`: optional batchnorm mode, passed to the encoder stages as `training=...`.
  When omitted, no keyword is passed and Knet's `batchnorm` uses its default mode.

Returns the output of `output_layer` (sigmoid probabilities).
"""
function predict(x_in, w, bn_moments, training=nothing)
    # Forward `training` only when the caller chose one. This lets three-argument calls
    # work instead of raising a MethodError.
    enc_kw = training === nothing ? NamedTuple() : (training=training,)
    # layer 1
    conv_1, out = conv_layer_encoder(w[1:6], bn_moments[1:2], x_in; enc_kw...)
    # layer 2
    conv_2, out = conv_layer_encoder(w[7:12], bn_moments[3:4], out; enc_kw...)
    # layer 3
    conv_3, out = conv_layer_encoder(w[13:18], bn_moments[5:6], out; enc_kw...)
    # layer 4 = bottleneck
    out = bottleneck(w[19:24], bn_moments[7:8], out)
    # layers 5-7: decoders, deepest skip output first
    out = conv_layer_decoder(w[25:30], bn_moments[9:10], conv_3, out)
    out = conv_layer_decoder(w[31:36], bn_moments[11:12], conv_2, out)
    out = conv_layer_decoder(w[37:42], bn_moments[13:14], conv_1, out)
    # Was an unconditional println on every forward pass; now shown only with debug logging.
    @debug "decoder output size" size(out)
    return output_layer(w[43:44], out)
end

# Train

In [None]:
"""
    dice_coeff(X, Y, predict, w, bn_moments; smooth=1)

Soft Dice coefficient between `P = predict(X, w, bn_moments)` and the target mask `Y`:

    (2 * sum(P .* Y) + smooth) / (sum(P) + sum(Y) + smooth)

`P` and `Y` are compared element-by-element after flattening, so only their lengths must agree.
`smooth` is added to both numerator and denominator. This avoids division by zero and
scores an empty prediction against an empty mask as 1 (perfect agreement) instead of 0.
"""
function dice_coeff(X, Y, predict, w, bn_moments; smooth=1)
    Y_pred = predict(X, w, bn_moments)
    Y_pred_flat = reshape(Y_pred, length(Y_pred))
    Y_flat = reshape(Y, length(Y))
    intersection = sum(Y_pred_flat .* Y_flat)
    union_area = sum(Y_pred_flat) + sum(Y_flat)

    return (2 * intersection + smooth) / (union_area + smooth)
end

In [None]:
# NOTE(review): `init_model` is not defined in this notebook as shown (the "Initialization"
# section is empty). It must return `w` and `bn_moments` in the layout `predict` indexes:
# w[1:44] and bn_moments[1:14].
w, bn_moments = init_model();
lr = 0.001
# Knet's `optimizers` builds Adam optimizer state matching the structure of `w`;
# `update!` in `epoch!` consumes it.
optim = optimizers(w, Adam;  lr=lr);

In [None]:
"""
    loss(w, bn_moments, x_in, y_true, predict; ϵ=1e-7)

Mean binary cross-entropy between the probabilities `p = predict(x_in, w, bn_moments)` and the
0/1 targets `y_true`:

    -mean(y .* log.(p .+ ϵ) .+ (1 .- y) .* log.(1 .- p .+ ϵ))

`predict` already ends in a sigmoid (`output_layer` applies `sigm`), so the loss is computed
on probabilities. Knet's `bce` expects pre-sigmoid scores and would apply a second sigmoid.
`ϵ` keeps `log` finite when `p` saturates at 0 or 1.
"""
function loss(w, bn_moments, x_in, y_true, predict; ϵ=1e-7)
    p = predict(x_in, w, bn_moments)
    ll = y_true .* log.(p .+ ϵ) .+ (1 .- y_true) .* log.(1 .- p .+ ϵ)
    return -sum(ll) / length(y_true)
end
lossgradient = grad(loss)

In [None]:
"""
    epoch!(w, bn_moments, optim, X_train, Y_train; batch_size=5, xtype=KnetArray)

Run one pass over shuffled minibatches of `(X_train, Y_train)`. Each minibatch updates `w` in
place with `optim` (via the global `lossgradient` and `predict`), then has its Dice coefficient
measured after that update.

Prints the epoch's mean Dice coefficient and returns it. With no minibatches, the mean is `NaN`.

# Keywords
- `batch_size`: minibatch size.
- `xtype`: array type Knet's `minibatch` uses for input batches (default `KnetArray`, i.e. GPU).
"""
function epoch!(w, bn_moments, optim, X_train, Y_train; batch_size=5, xtype=KnetArray)
    data = minibatch(X_train, Y_train, batch_size; shuffle=true, xtype=xtype)
    epoch_dice = Float64[]
    for (X, Y) in data
        gradient = lossgradient(w, bn_moments, X, Y, predict)
        update!(w, gradient, optim)
        push!(epoch_dice, dice_coeff(X, Y, predict, w, bn_moments))
        # Release unused GPU memory held by this minibatch's intermediates.
        Knet.gc()
    end
    mean_dice = mean(epoch_dice)
    println("dice: ", mean_dice)
    return mean_dice
end

In [None]:
"""
    train(w, bn_moments, optim, X_train, Y_train, predict; batch_size=5, epochs=10)

Call `epoch!` `epochs` times on the same data with the given `batch_size`. Returns `nothing`.

NOTE(review): the `predict` argument is not used here; `epoch!` relies on the global `predict`.
"""
function train(w, bn_moments, optim, X_train, Y_train, predict; batch_size=5, epochs=10)
    run_one_epoch() = epoch!(w, bn_moments, optim, X_train, Y_train; batch_size=batch_size)
    foreach(_ -> run_one_epoch(), 1:epochs)
    return nothing
end

In [None]:
# Train, run the model over the full training set, and save the predicted masks.
train(w, bn_moments, optim, X_train, Y_train, predict, batch_size=2, epochs=10)
pred = predict(X_train, w, bn_moments)
# Move the result to a host Array so HDF5 can write it.
pred = convert(Array, pred)
filename = "./training_10epochs.h5"
# h5write will not overwrite an existing "data" dataset, so re-running this cell would fail.
# Remove the previous run's output file first.
rm(filename; force=true)
h5write(filename,"data", pred)

In [7]:
# Load the predicted masks saved by the training cell (dataset "data" in training_10epochs.h5).
masks = h5read("training_10epochs.h5", "data")

256×256×1×10 Array{Float64,4}:
[:, :, 1, 1] =
 0.0603632   0.011934     0.00540719  …  0.0119207    0.0279124    0.137414 
 0.0110738   0.000389509  8.41192e-5     0.00025218   0.000865769  0.0237312
 0.00455206  7.0687e-5    1.05132e-5     4.79694e-5   0.000203566  0.0115987
 0.00317239  3.52419e-5   4.61183e-6     3.43131e-5   0.000157479  0.0109787
 0.00254002  2.25337e-5   2.78308e-6     4.82461e-5   0.00022192   0.0128997
 0.00216908  1.79004e-5   2.11073e-6  …  6.71402e-5   0.000291451  0.0150478
 0.00215792  1.70188e-5   1.95096e-6     0.000111882  0.00046204   0.019605 
 0.00230053  1.7997e-5    2.03879e-6     0.000214432  0.000761681  0.0251704
 0.00232962  1.93933e-5   2.22137e-6     0.000539673  0.0014062    0.032766 
 0.00245658  2.13182e-5   2.79764e-6     0.00205589   0.00337874   0.0506263
 0.00250736  2.45872e-5   3.195e-6    …  0.0119967    0.0131598    0.0922828
 0.00253279  2.43025e-5   3.30264e-6     0.0575607    0.0435172    0.17272  
 0.00251591  2.26701e-5   3.02

In [10]:
# Reload the dataset as plain CPU arrays (no KnetArray conversion) for visualization.
# Same slicing as the preprocessing cell: slice 2 -> input images X, slice 1 -> masks Y.
file_path = "cropped_data_0_10.h5";
data = h5read(file_path, "data");
X = reshape(data[:,:,:,2], 256,256,1,10);
Y = reshape(data[:,:,:,1], 256,256,1,10);

256×256×1×10 Array{Float64,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  1.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  1.0  0.0  1.0  1.0  1.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.

In [12]:
# 2-D slice of the first sample's predicted mask, shown with ImageView.
pred = masks[:,:,1,1]
imshow(pred)

Dict{String,Any} with 4 entries:
  "gui"         => Dict{String,Any}("window"=>GtkWindowLeaf(name="", parent, wi…
  "roi"         => Dict{String,Any}("redraw"=>37: "map(clim-mapped image, input…
  "annotations" => 3: "input-2" = Dict{UInt64,Any}() Dict{UInt64,Any} 
  "clim"        => 2: "CLim" = CLim{Float64}(2.441090585900881e-7, 0.9993040341…

In [13]:
# Input image and ground-truth mask of the first sample, for comparison with `pred`.
# Each imshow opens its own window; only the last call's return value is echoed below.
input = X[:,:,1,1]
gt = Y[:,:,1,1]
imshow(input)
imshow(gt)

Dict{String,Any} with 4 entries:
  "gui"         => Dict{String,Any}("window"=>GtkWindowLeaf(name="", parent, wi…
  "roi"         => Dict{String,Any}("redraw"=>111: "map(clim-mapped image, inpu…
  "annotations" => 77: "input-26" = Dict{UInt64,Any}() Dict{UInt64,Any} 
  "clim"        => 76: "CLim" = CLim{Float64}(0.0, 1.0) CLim{Float64} 