In [6]:
using Knet, HDF5

ToDos: 
- change weight initialization and model s.t. bn_params are part of the weight matrix (for update step)
- TRAIN (overfit on one batch)

# Preprocessing

In [10]:
file_path = file_path = "/Users/kathi/Desktop/6.338-ModernNumericalComputing/project/cropped_data_0_10.h5"
data = h5read(file_path, "data")
X_train = reshape(data[:,:,:,2], 256,256,1,10) 
Y_train = reshape(data[:,:,:,1], 256,256,1,10)

256×256×1×10 Array{Float64,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  1.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  1.0  0.0  1.0  1.0  1.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.

# Initialization

In [1]:
function init_model(n_input_channels=1, n_output_channels=2, depth=4, max_channels=512, kernel_size=3)
    w = Any[]
    bn_params = Any[]
    bn_moments = Any[bnmoments() for i = 1:14]
    # 1st layer
    # convolution
    push!(w,xavier(Float64,kernel_size,kernel_size,n_input_channels,Int64(max_channels/(2^(depth-1)))))
    # bias
    push!(w,zeros(Float64,1,1,Int64(max_channels/(2^(depth-1))),1))
    # batchnorm params
    push!(bn_params, bnparams(Float64,Int64(max_channels/(2^(depth-1)))))
    # convolution
    push!(w,xavier(Float64,kernel_size,kernel_size,Int64(max_channels/(2^(depth-1))),Int64(max_channels/(2^(depth-1)))))
    # bias
    push!(w,zeros(Float64,1,1,Int64(max_channels/(2^(depth-1))),1))
    push!(bn_params, bnparams(Float64,Int64(max_channels/(2^(depth-1)))))
    
    println("layer 1")
    println(size(w))
    
    # encoding arm: 2nd up to and including bottleneck
    for layer=2:depth
       push!(w,xavier(Float64,kernel_size,kernel_size,Int64(max_channels/(2^(depth-layer+1))),Int64(max_channels/(2^(depth-layer))))) 
        push!(w, zeros(Float64,1,1,Int64(max_channels/(2^(depth-layer))),1))
        push!(bn_params, bnparams(Float64,Int64(max_channels/(2^(depth-layer)))))
        push!(w,xavier(Float64,kernel_size,kernel_size,Int64(max_channels/(2^(depth-layer))),Int64(max_channels/(2^(depth-layer))))) 
        push!(w, zeros(Float64,1,1,Int64(max_channels/(2^(depth-layer))),1))
        push!(bn_params, bnparams(Float64,Int64(max_channels/(2^(depth-layer)))))
    end
    
    # decoding arm (except for 3rd convolution in the last layer)
    for layer=(depth+1):(2*depth-1)
        push!(w,xavier(Float64,kernel_size,kernel_size,Int64(1.5*max_channels/(2^(layer-depth-1))),Int64(max_channels/(2^(layer-depth))))) 
        push!(w, zeros(Float64,1,1,Int64(max_channels/(2^(layer-depth))),1))
        push!(bn_params, bnparams(Float64,Int64(max_channels/(2^(layer-depth)))))
        push!(w,xavier(Float64,kernel_size,kernel_size,Int64(max_channels/(2^(layer-depth))),Int64(max_channels/(2^(layer-depth))))) 
        push!(w, zeros(Float64,1,1,Int64(max_channels/(2^(layer-depth))),1))
        push!(bn_params, bnparams(Float64,Int64(max_channels/(2^(layer-depth)))))
    end
    
    # output convolution
    push!(w, xavier(Float64,kernel_size,kernel_size,Int64(max_channels/(2^(depth-1))),n_output_channels)) 
    push!(w, zeros(Float64,1,1,n_output_channels,1))
       
    w = map(KnetArray, w)
    bn_params = map(KnetArray, bn_params)
    
    return w, bn_moments, bn_params
end

init_model (generic function with 6 methods)

# Model

In [None]:
function flatten(array)
    size_array = size(array)
    length_flattened = 1
    for dim = 1:ndims(array)
        length_flattened = length_flattened * size_array[dim]
    end
    return flattened = reshape(array,length_flattened)
end

In [1]:
function cat3d(arr1, arr2)
    R = size(arr1)[1]
    C = size(arr1)[2]
    Ch = size(arr1)[3]+size(arr2)[3]
    S = size(arr1)[4]
    
    arr1 = flatten(permutedims(arr1, [1,2,4,3]))
    arr2 = flatten(permutedims(arr2, [1,2,4,3]))
    out = reshape(cat1d(arr1,arr2),R,C,S,Ch)
    out = permutedims(out,[1,2,4,3])
    
    return out
end

cat3d (generic function with 1 method)

In [None]:
function conv_layer_encoder(w, bn_moments, bn_params, x)
    out = conv4(w[1], x, padding=1).+ w[2]
    out = batchnorm(out, bn_moments[1], bn_params[1])
    conv_1 = relu.(out)
    
    out = conv4(w[3], conv_1, padding=1) .+ w[4]
    out = batchnorm(out, bn_moments[2], bn_params[2])
    out = relu.(out)
    pooled = pool(out)
    
    return conv_1, pooled
end

In [2]:
function conv_layer_decoder(w, bn_moments, bn_params, x_cat, x)
    number_of_channels = size(x)[3]
    upsampling_kernel = convert(KnetArray, bilinear(Float64,2,2,number_of_channels,number_of_channels))
    out = deconv4(upsampling_kernel, x, padding = 1,stride = 2) 
    out = cat3d(out, x_cat)
    
    out = conv4(w[1], out, padding = 1) .+ w[2]
    out = batchnorm(out, bn_moments[1], bn_params[1])
    out = relu.(out)
    
    out = conv4(w[3], out, padding = 1) .+ w[4]
    out = batchnorm(out, bn_moments[2], bn_params[2])
    out = relu.(out)
    
    return out
end

conv_layer_decoder (generic function with 1 method)

In [3]:
function bottleneck(w, bn_moments, bn_params, x)
    out = conv4(w[1], x, padding = 1) .+ w[2]
    out = batchnorm(out, bn_moments[1], bn_params[1])
    out = relu.(out)
    out = conv4(w[3], out, padding = 1) .+ w[4]
    out = batchnorm(out, bn_moments[2], bn_params[2])
    out = relu.(out)
    
    return out
end

bottleneck (generic function with 1 method)

In [None]:
function output_layer(w, x)
    out = sigm.(conv4(w[1], x).+ w[2])
    
    return out
end

In [None]:
function predict(w,bn_moments, bn_params,x_in)
    # layer 1
    conv_1, out = conv_layer_encoder(w[1:4], bn_moments[1:2], bn_params[1:2], x_in)
    # layer 2
    conv_2, out = conv_layer_encoder(w[5:8], bn_moments[3:4], bn_params[3:4], out)
    #layer 3
    conv_3, out = conv_layer_encoder(w[9:12], bn_moments[5:6], bn_params[5:6], out)
    # layer 4 = bottleneck
    out = bottleneck(w[13:16], bn_moments[7:8], bn_params[7:8], out)
    # layer 5
    out = conv_layer_decoder(w[17:20], bn_moments[9:10], bn_params[9:10], conv_3, out)
    # layer 6
    out = conv_layer_decoder(w[21:24], bn_moments[11:12], bn_params[11:12], conv_2, out)
    # layer 7
    out = conv_layer_decoder(w[25:28], bn_moments[13:14], bn_params[13:14], conv_1, out)
    out = output_layer(w[29:30], out)
    
    return out
end

# Loss

In [11]:
# dice loss -> not yet implemented

In [None]:
function categorical_crossentropy_3d(y_true, y_pred):
    epsilon = 1.0e-7
    y_true_flatten = flatten(y_true)
    y_pred_flatten = flatten(y_pred)
    y_pred_flatten_log = log.(y_pred_flatten.+epsilon)
    num_total_elements = sum(y_true_flatten) # number of positive voxels (gt 0/1 encoded)
    cross_entropy = sum(y_true_flatten.*y_pred_flatten_log)
    mean_cross_entropy = cross_entropy/(num_total_elements+epsilon)
    return mean_cross_entropy
end

lossgrad = grad(categorical_crossentropy_3d)

# Train
- work in progress

In [8]:
function epoch!(w, bn_moments, bn_params, optim, X_train, Y_train;  batch_size = 10)
    data = minibatch(X_train, Y_train, batch_size; shuffle=true, xtype=Array)
    for (X, Y) in data
        gradient = lossgrad(w, m, x, y)
        update!(w, gradient, optim)
    end 
end

epoch! (generic function with 2 methods)

In [9]:
epoch!(X_train,Y_train)

(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)
(256, 256, 1, 2)


In [2]:
optim(w; lr=0.001) = optimizers(w, ADAM;  lr=lr);

In [None]:
function report(epoch, w, dtrn, dtst, predict)
    println((:epoch, epoch, :trn, accuracy(w, dtrn, predict), :tst, accuracy(w, dtst, predict)))
end

In [None]:
w   = initweights(atype);
opt = optim(w);

if fast
    train(w, dtrn, opt, predict; epochs=nepochs)
else
    for epoch = 1:nepochs
        train(w, dtrn, opt, predict; epochs=1)
        report(epoch, w, dtrn, dtst, predict)
    end
end