<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [7]:
using Flux, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using Metalhead:trainimgs, CIFAR10
using Images

┌ Info: Precompiling Metalhead [dbeba491-748d-5e0e-a39e-b530a07fa0cc]
└ @ Base loading.jl:1278


In [8]:
"""
    getarray(X)

Convert a 2-D color image `X` into a `Float32` array with the channel axis last.

`channelview` splits the color channels into a new *first* dimension, giving
`(C, H, W)`; the permutation `(2, 3, 1)` reorders that to `(H, W, C)`.
"""
function getarray(X)
    chw = channelview(X)
    hwc = permutedims(chw, (2, 3, 1))
    return Float32.(hwc)
end

getarray (generic function with 1 method)

## CIFAR-10 dataset

In [9]:
# Load the CIFAR-10 training images via Metalhead and convert each to a Float32 array.
X = trainimgs(CIFAR10)
imgs = map(k -> getarray(X[k].img), 1:50000);
# One-hot encode the class labels over the 10 CIFAR-10 classes.
labels = onehotbatch(map(k -> X[k].ground_truth.class, 1:50000), 1:10);
# The first 49,000 images are the training set: minibatches of 100 images stacked
# along dimension 4 (the batch axis), each (images, labels) pair moved to the GPU.
train = gpu.([(cat(imgs[batch]...; dims = 4), labels[:, batch])
              for batch in partition(1:49000, 100)]);

In [10]:
# Hold out the last 1,000 images (49001–50000) as a validation set on the GPU.
valset = collect(49001:50000)
valX = gpu(cat(imgs[valset]...; dims = 4))
valY = gpu(labels[:, valset])

10×1000 Flux.OneHotMatrix{CUDA.CuArray{Flux.OneHotVector,1}}:
 0  0  0  0  1  0  1  0  0  0  0  0  0  …  0  0  0  0  1  0  1  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  1  0  0  0  1  0  0  0  0  1  1
 0  0  0  0  0  0  0  0  1  0  0  0  0     0  0  0  1  0  0  0  1  0  0  0  0
 0  0  0  0  0  0  0  0  0  1  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  1  0  0  0  0  0  0  0  0  0  0     0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  1  0  0  0  0  0  0  0  …  1  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  1  0  0  0
 0  0  0  0  0  0  0  0  0  0  1  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  1  0  0  0  1  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  1  0  0  0  0  0  0  0  0  1     0  0  0  0  0  0  0  0  0  1  0  0

└ @ GPUArrays /home/subhaditya/.julia/packages/GPUArrays/PkHCM/src/host/indexing.jl:43


## A Conv → ReLU → BatchNorm block, parameterized by input and output channels

In [9]:
"""
    conv_block(in_channels, out_channels)

Return a `(Conv, BatchNorm)` tuple meant to be splatted into a `Chain`.

The block is a 3×3 convolution from `in_channels` to `out_channels` with ReLU
activation, padding 1 and stride 1 (so spatial size is preserved). It is
followed by batch normalization over the `out_channels` feature maps.
"""
function conv_block(in_channels, out_channels)
    conv = Conv((3, 3), in_channels => out_channels, relu; pad = (1, 1), stride = (1, 1))
    norm = BatchNorm(out_channels)
    return (conv, norm)
end

conv_block (generic function with 1 method)

## Two conv blocks followed by a max-pool, a pattern common in VGG

In [10]:
"""
    double_conv(in_channels, out_channels)

Return a flat tuple of layers for splatting into a `Chain`.

The tuple holds two `conv_block`s (`in_channels → out_channels`, then
`out_channels → out_channels`) followed by a 2×2 `MaxPool`, which halves the
spatial size.
"""
function double_conv(in_channels, out_channels)
    block_a = conv_block(in_channels, out_channels)
    block_b = conv_block(out_channels, out_channels)
    return (block_a..., block_b..., MaxPool((2, 2)))
end

double_conv (generic function with 1 method)

## VGG architecture
- The splat operator `...` unrolls the tuples returned by the blocks defined above, so their layers are passed individually to `Chain`.

In [29]:
# Three references to the SAME conv_block instance (not three independent layers):
# building a model this way would make all three share one set of weights.
fill(conv_block(256, 256), 3)

3-element Array{Tuple{Conv{2,2,typeof(relu),Array{Float32,4},Array{Float32,1}},BatchNorm{typeof(identity),Array{Float32,1},Array{Float32,1},Float32}},1}:
 (Conv((3, 3), 256=>256, relu), BatchNorm(256))
 (Conv((3, 3), 256=>256, relu), BatchNorm(256))
 (Conv((3, 3), 256=>256, relu), BatchNorm(256))

In [68]:
"""
    vgg16(initial_channels, num_classes)

Build a VGG-16 network (13 conv layers + 3 dense layers) and move it to the GPU.

Convolutions use a Conv → ReLU → BatchNorm ordering. The five convolutional
stages (2, 2, 3, 3, 3 convs) each end in a 2×2 max-pool. A 32×32 input is
therefore reduced to 1×1×512, which is why the first dense layer expects 512
inputs. The final `softmax` means the model returns class probabilities, not
logits.
"""
function vgg16(initial_channels, num_classes)
    stages = (
        double_conv(initial_channels, 64)...,                # stage 1: 2 convs
        double_conv(64, 128)...,                             # stage 2: 2 convs
        conv_block(128, 256)..., double_conv(256, 256)...,   # stage 3: 3 convs
        conv_block(256, 512)..., double_conv(512, 512)...,   # stage 4: 3 convs
        conv_block(512, 512)..., double_conv(512, 512)...,   # stage 5: 3 convs
    )
    head = (
        x -> reshape(x, :, size(x, 4)),   # flatten to (features, batch)
        Dense(512, 4096, relu),
        Dropout(0.5),
        Dense(4096, 4096, relu),
        Dropout(0.5),
        Dense(4096, num_classes),
        softmax,
    )
    return gpu(Chain(stages..., head...))
end

vgg16 (generic function with 1 method)

In [69]:
# Instantiate VGG-16 for 3-channel (RGB) input images and the 10 CIFAR-10 classes.
m = vgg16(3, 10)

Chain(Conv((3, 3), 3=>64, relu), BatchNorm(64), Conv((3, 3), 64=>64, relu), BatchNorm(64), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 64=>128, relu), BatchNorm(128), Conv((3, 3), 128=>128, relu), BatchNorm(128), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 128=>256, relu), BatchNorm(256), Conv((3, 3), 256=>256, relu), BatchNorm(256), Conv((3, 3), 256=>256, relu), BatchNorm(256), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 256=>512, relu), BatchNorm(512), Conv((3, 3), 512=>512, relu), BatchNorm(512), Conv((3, 3), 512=>512, relu), BatchNorm(512), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Conv((3, 3), 512=>512, relu), BatchNorm(512), Conv((3, 3), 512=>512, relu), BatchNorm(512), Conv((3, 3), 512=>512, relu), BatchNorm(512), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), #41, Dense(512, 4096, relu), Dropout(0.5), Dense(4096, 4096, relu), Dropout(0.5), Dense(4096, 10), softmax)

In [70]:
"""
    loss(x, y)

Cross-entropy between the model's predicted class probabilities `m(x)` and the
one-hot targets `y`.
"""
function loss(x, y)
    ŷ = m(x)
    return crossentropy(ŷ, y)
end

loss (generic function with 1 method)

In [71]:
"""
    accuracy(x, y, model = m)

Return the fraction of samples in `x` whose predicted class (argmax of
`model(x)`) matches the one-hot target in `y`.

Predictions and targets are copied to the CPU before `onecold`. Running
`onecold` directly on GPU arrays may fall back to slow element-wise scalar
indexing (cf. the GPUArrays indexing warning when `valY` is displayed).
`model` defaults to the global `m`, so existing `accuracy(x, y)` calls are
unchanged.
"""
accuracy(x, y, model = m) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))

accuracy (generic function with 1 method)

In [72]:
# Training callback: print validation accuracy at most once every 10 seconds.
evalcb = throttle(10) do
    @show(accuracy(valX, valY))
end

(::Flux.var"#throttled#42"{Flux.var"#throttled#38#43"{Bool,Bool,var"#43#44",Int64}}) (generic function with 1 method)

In [73]:
# ADAM with its default hyperparameters written out: learning rate 0.001, betas (0.9, 0.999).
opt = ADAM(0.001, (0.9, 0.999))

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

In [74]:
# One pass over the training minibatches; evalcb reports validation accuracy as it goes.
Flux.train!(loss, Flux.params(m), train, opt; cb = evalcb)