In [1]:
using ImageMagick
using Knet
using AutoGrad
using JLD
using Statistics
using Random
using Images
include(Knet.dir("data", "mnist.jl"))
xtrain, ytrain, xtest, ytest = mnist();

┌ Info: Loading MNIST...
└ @ Main /home/ahnaf/.julia/packages/Knet/05UDD/data/mnist.jl:33


In [2]:
# Global Parameters
BANDWIDTH = 8
NUM_PATCHES = 1
GLIMPSE_COUNT = 6 #Number of glimpses to be employed
NUM_CLASSES = 10
SCALE = 2
HSIZE = 256
GSIZE = 256;
BATCHSIZE = 20

# FILENAME = "RAM_bnorm.jld2"

20

In [3]:
# Generic Functions
# Normalize dataset and return mean and SD
function Normalize(dataset)
    mean_dataset = Float32(mean(dataset))
    sd_dataset = Float32(std(dataset))
    dataset_norm = (dataset .- mean_dataset) ./ sd_dataset
    return mean_dataset, sd_dataset, dataset_norm
end


# Function for normalizing the locations
# Batch location dimensions = 2 x Batchsize
function antinorm(dim, loc)
    #Conver Normalized locations range[-1,1] into X, Y coordinates
    anloc = []
    meanx::Int16 = floor(dim/2);
    meany::Int16 = floor(dim/2);
    
    locx = Array{Int16}(ceil.((loc[1,:] .* dim) ./ 2 .+ meanx))  
    locy = Array{Int16}(ceil.((loc[2,:] .* dim) ./ 2 .+ meany))
    for i = 1:BATCHSIZE
        temp = reshape([locx[i], locy[i]], (2,1))
        push!(anloc, temp)
    end
    return (cat(anloc..., dims = 2))
end

# # AntiNorm test
# test = KnetArray{Float32}(rand(2,BATCHSIZE))
# println(size(test))
# test = antinorm(28, test)

antinorm (generic function with 1 method)

In [4]:
MEANX, SDX, xtrain = Normalize(xtrain);
~, ~, xtest = Normalize(xtest);

dtrn = minibatch(xtrain, ytrain, BATCHSIZE, xtype=KnetArray{Float32});
dtst = minibatch(xtest, ytest, BATCHSIZE, xtype=KnetArray{Float32});

xd, yd = first(dtrn);

# Generic Elements

In [5]:
#################################################################################################
struct FC
    w
    b
    act
end

(f::FC)(x) =f.act.(f.w * mat(x) .+ f.b)
FC(outsize::Int, insize::Int, act = relu) = FC(param(outsize, insize, atype = KnetArray{Float32}), 
    param0(outsize, atype = KnetArray{Float32}), act);


# # Layer Test
# F1 = FC(100, 784)
# params(FC1)
# testFC = FC1(xd)
#################################################################################################

struct Linear
    w
    b
end

(l::Linear)(x) = l.w * mat(x) .+ l.b

function Linear(outsize::Int, insize::Int)
    w = param(outsize, insize, atype=KnetArray{Float32})
    b = param0(outsize, atype = KnetArray{Float32})
    return Linear(w,b)
end

##################################################################################################

# # Layer Test
# L1 = Linear(100, 784)
# params(L1)
# testLinear = L1(xd)

##################################################################################################

mutable struct BatchNorm
    w
    m
end

(b::BatchNorm)(x; o...) = batchnorm(x, b.m, b.w;o...)

function BatchNorm(c::Int, atype = KnetArray{Float32})
    w = Param(atype(bnparams(c)))
    m = bnmoments()
    return BatchNorm(w, m)
end

# Layer Test
# btest  = BatchNorm(784)
# res = btest(mat(xd))

BatchNorm

## Core Network
Takes the output of Glimpse Network(gt -> gc) and previous hidden state(h(t-1)-> hp). Outputs current Hidden State (ht -> hc).
Runs num_glimpses times and records the hidden state at the end as output. 
Unit Defintion = Rect (Linear(h(t-1)) + Linear(gt))

In [6]:
struct CoreNet
    FCgc   # Linear operator for g current  : gt
    FChp   # Linear operator for h previous : ht-1
    BN_gc # BatchNormalization layer. Defined for gc
    BN_hp # BatchNormalization layer. Defined for hp
end

function CoreNet(gsize::Int, hsize::Int)
    return CoreNet(Linear(gsize, hsize), Linear(hsize, hsize), BatchNorm(hsize), BatchNorm(hsize))
end

function (c::CoreNet)(gc,hp)
    return relu.(c.BN_gc(c.FCgc(gc)) + c.BN_hp(c.FChp(hp)))
end

# CoreNet Test
# CN = CoreNet(hsize, hsize)
# sample = KnetArray{Float32}(rand(256,10))
# sample = reshape(sample , (256, 10))
# CN_test = CN(sample, sample)

# Glimpse Network
## Glimpse Network Constructor

In [7]:
mutable struct GpsNet
    FC1
    FC2
    FC3
    FC4
    bn128_1
    bn128_2
    bn256_1
    bn256_2
end

### Glimpse Network: Initializer & Method

In [8]:
# Glimpse Network default constructor: Definition of vairous cogs of the Glimpse network
function GpsNet(bandwidth::Int, scale::Int, num_patches::Int)
    FC1 = FC(128, bandwidth*bandwidth*num_patches, identity)
    FC2 = FC(128, 2, identity)
    FC3 = Linear(256, 128)  # The outsize may be 256
    FC4 = Linear(256, 128) # The outsize may be 256
    bn128_1 = BatchNorm(128)
    bn128_2 = BatchNorm(128)
    bn256_1 = BatchNorm(256)
    bn256_2 = BatchNorm(256)
    return GpsNet(FC1, FC2, FC3, FC4, bn128_1, bn128_2, bn256_1, bn256_2)
end
###########################################################################################

## Glimpse Network Deployment

# Deploying glimpse Network
# Input is the batch of images
# Take glimpses & locations. Pass them to FC layers 
function (g::GpsNet)(x_batch, loc)
    phi   = sensor(x_batch, loc)
    phi_1 = g.FC1(phi)
    phi_1 = g.bn128_1(phi_1)
    loc_1 = g.FC2(loc)
    loc_1 = g.bn128_2(loc_1)
    phi_2 = g.FC3(phi_1)
    phi_2 = g.bn256_1(phi_2)
    loc_2 = g.FC4(loc_1)
    loc_2 = g.bn256_2(loc_2)
    return relu.(phi_2 + loc_2) # The output here is gt to be fed to CoreNet
end

# # GpsNet Test
# test2 = GpsNet(bandwidth, scale, num_patches)
# params(test2.FC2)
# res_test2 = test2(xd, loc)

### Sensor Network

In [None]:
# Uses imresize for equalizing the return dimensions
function sensor(x_batch, loc)
    l,w,c,b = size(x_batch)
#     coord = antinorm(l, loc)
    coord = antinorm(l, loc)
    glimpse_array = []
    padsize::Int = 60
    width::Int = floor(BANDWIDTH / 2)  
    batch = Array(mat(x_batch))
    batch = reshape(batch, (l,w,b))
    
    for i = 1:BATCHSIZE
        img = batch[:,:,i]
        pad_img = padarray(img, Fill(0, (padsize, padsize)))
        temp = []
        for j = 1: NUM_PATCHES
            xlim1 = coord[1,i] - (width* (SCALE ^ (j-1)))
            xlim2 = coord[1,i]-1 + (width * (SCALE ^(j-1)))
            ylim1 = coord[2,i] - (width* (SCALE ^ (j-1)))
            ylim2 = coord[2,i]-1 + (width * (SCALE ^(j-1)))

            glimpse = pad_img[xlim1:xlim2, ylim1:ylim2]
            glimpse = imresize(glimpse, (BANDWIDTH, BANDWIDTH))
            push!(temp, glimpse)
        end
        temp = vcat(vcat(temp...)...)
        push!(glimpse_array, vcat(temp))
    end
    return KnetArray{Float32}(mat(cat(glimpse_array..., dims=2)))
end

## Function test
coord = KnetArray{Float32}(zeros(2, BATCHSIZE))
res = sensor(xd, coord)

In [9]:
# Uses the pooling for equalizing dimensions
function sensor(x_batch, loc)
    patch_size =  BANDWIDTH
    phi = []
    for i = 1:NUM_PATCHES
        push!(phi, foveat(x_batch, loc, patch_size))
        patch_size = SCALE * patch_size
    end
    
    for j = 1:length(phi)
        k = div(size(phi[j])[1], BANDWIDTH)
        phi[j] = pool(phi[j]; window = k, mode=1)
    end
    
    phi2d = map(mat, phi)
    return mat(vcat(phi2d...))
end

function foveat(x_batch, loc, patch_size)
    l,w,c,b = size(x_batch)
    coord = antinorm(l, loc)
    patches = [] 
    padsize::Int = 60
    width::Int = floor(patch_size / 2)

    xlim1 = coord[1,:] .- width .+1
    xlim2 = coord[1,:] .+ width
    ylim1 = coord[2,:] .- width .+1
    ylim2 = coord[2,:] .+ width 
    batch = Array(x_batch)

    for i=1:b
        img = batch[:,:,1,i]
        pad_img = padarray(img, Fill(0, (padsize, padsize)))
        push!(patches, pad_img[xlim1[i]:xlim2[i], ylim1[i]:ylim2[i]])
    end
    
    return KnetArray{Float32}(cat(patches...; dims=4))
end          

# ## Function test
coord = KnetArray{Float32}(zeros(2, BATCHSIZE))
res = sensor(xd, coord)

64×20 KnetArray{Float32,2}:
 -0.411346   2.60517   -0.424074  …   1.34511    -0.424074  -0.424074 
  1.53602    2.78336   -0.424074      2.80882     1.16691    2.08332  
  2.79609    2.78336   -0.424074      2.41425     2.66881    2.78336  
  0.721438   1.85422   -0.424074     -0.105876    0.212322   1.89241  
 -0.424074  -0.271339  -0.424074     -0.424074   -0.424074  -0.424074 
 -0.424074   0.530519  -0.424074  …  -0.424074   -0.424074  -0.424074 
 -0.424074   1.116     -0.424074      0.0341309  -0.424074  -0.424074 
 -0.424074  -0.156788  -0.424074      1.73967    -0.424074   0.0468588
 -0.424074   2.79609   -0.424074      0.275961   -0.424074  -0.424074 
  1.34511    2.54153   -0.424074      2.27424     2.05787    1.96877  
  2.79609    2.22333   -0.424074  …   2.79609     1.91786    2.78336  
  1.99423    0.645071  -0.424074      1.62512    -0.424074   0.886901 
 -0.398618  -0.424074  -0.424074     -0.284067   -0.424074  -0.424074 
  ⋮                               ⋱              

## Location Network

In [10]:
mutable struct LocNet
    sdx
    Linl
end

# Function should take input size as 256 and output size as 2 (size of loc coordinates)
function LocNet(outsize::Int, insize::Int, sdx::Float32)
    Linloc = Linear(outsize, insize)
    return LocNet(sdx, Linloc)
end

function (l::LocNet)(ht)
    mean_loc = tanh.(l.Linl(ht))
    
    prtb = KnetArray{Float32}(gaussian((2,BATCHSIZE); mean = 0, std = 0.12))
    loc_new = mean_loc + prtb 
    return mean_loc, value(tanh.(loc_new))
end

# # TLocation Network Test
# Loc1 = LocNet(2, 256, SDX)
# sample = KnetArray{Float32}(rand(256));
# sample = reshape(sample, (256, 1));
# res, lt = Loc1(sample)

## RAM Model
### Staggered RAM
The Reference RAM deployment has been staggered.
Deploy one RAM function once to generate the hidden state, next location and action.
Deploy this one-time active RAM recursively equal to number of glimpses. Record all intermediate states
but base decision on output of the final one of the loop.

In [11]:
# RAM Initializer
# ram = RAM(SDX, GLIMPSE_COUNT, BANDWIDTH, NUM_PATCHES, NUM_CLASSES, SCALE, GSIZE, HSIZE);
mutable struct RAM
    sdx
    num_glimpses
    gpsnet::GpsNet
    corenet::CoreNet
    locnet::LocNet
    output_layer
    locations
    basenet
    hsize
end

function RAM(sdx, glimpse_count, bandwidth, num_patches, num_classes, scale, gsize, hsize)
#     sdx = SDX
#     num_glimpses = glimpse_count
    gpsnet = GpsNet(bandwidth, scale, num_patches)
    corenet = CoreNet(gsize, hsize)
    locnet = LocNet(2, hsize, sdx)
    output_layer = Linear(num_classes, hsize) # Action Network
    locations = []
    basenet = FC(1, hsize)
    return RAM(sdx, glimpse_count, gpsnet, corenet, locnet, output_layer, locations, basenet, hsize)
end

##########################################################################################

# Outer RAM instance. Calls inner instance glimpse_count times
# Returns class_probs, policy result, hidden state and baseline results

function (r::RAM)(x_batch)
    # Init States
    hc = KnetArray{Float32}(reshape(zeros(r.hsize, BATCHSIZE), (r.hsize, BATCHSIZE)))
    loc = KnetArray{Float32}(reshape(zeros(2,BATCHSIZE), (2, BATCHSIZE)))
    
    xs, log_ps, baseresults = [], [], [] 
    for i = 1:r.num_glimpses
        hc, loc, policy, baseres = r(x_batch, loc, hc)
        push!(log_ps, policy)       
        push!(baseresults, baseres)      
    end
    log_ps, baseresults = vcat(log_ps...), vcat(baseresults...)    
    scores = logsoftmax(r.output_layer(hc))  # Size of NUM_CLASSES x BATCHSIZE. mat() applied at layer definition
    return scores, log_ps, hc, baseresults
end

###########################################################################################

# Innermost RAM instance
# Return for a single glimpse, hidden state, next_location, policy and baseline results.
function (r::RAM)(x_batch, loc, hp)
    gc = r.gpsnet(x_batch, loc)  # g-current: gt
    hc = r.corenet(gc,hp)       # h-current ; ht-- Uses g_current and h_previous
    mean, lnext = r.locnet(hc) # mean and l-next:lt to be used in the next time step
    sdx = r.sdx
    sdx2 = sdx.^2
    policy = -(abs.(lnext - mean).^2) / 2*(sdx2) .- log(sdx) .- log(sqrt(2 * 3.1415))
    policy = sum(policy, dims =1)
    base_res = r.basenet(hc)
    return hc, lnext, policy, base_res
end

In [16]:
# Loss Caluclator
function (ram::RAM)(x_batch, y)
    scores, log_ps, hc, baseresults = ram(x_batch) 
    ypred = vec(map(i->i[1], argmax(Array(value(scores)), dims=1)))
#     ypred = vec(map(i->i[1], argmax(scores, dims=1)))
    r = ypred .== y; r = reshape(r, 1, :) # Return Boolean values
#     r = Array{Float32}(ypred .== y)
#     r = reshape(r, (1, BATCHSIZE))
    
    # Cumulation over one-batch
    R = zeros(Float32, size(baseresults)...); R[end,:] = r
    R = KnetArray{Float32}(R)
#     R̂ = R .- value(baseresults)
    R_err = R .- value(baseresults)
    loss_action = nll(scores, y)
    loss_baseline = sum(abs2, baseresults .- R) / length(baseresults)
    loss_reinforce = mean(sum(-log_ps .* R_err, dims=1))
    return loss_action, loss_baseline, loss_reinforce, sum(R), length(R)
end


In [17]:
function validate(ram::RAM, data)
    losses = zeros(3)
    ncorrect = ninstances = 0
    for (x,y) in data
        ret = ram(x,y)
        for i = 1:3; losses[i] += ret[i]; end
        ncorrect += ret[4]
        ninstances += ret[5]
    end
    losses = losses / length(data)
    losses = [sum(losses), losses...]
    return losses, ncorrect / ninstances
end

validate (generic function with 1 method)

### RAM Initializer

In [18]:
function initopt!(model, optimizer="Adam(lr=0.003, gclip=0.0)")
    for par in params(model)
        par.opt = eval(Meta.parse(optimizer))
    end
end

ram = RAM(SDX, GLIMPSE_COUNT, BANDWIDTH, NUM_PATCHES, NUM_CLASSES, SCALE, GSIZE, HSIZE);
record = params(ram);
initopt!(ram)

In [None]:
ram = Knet.load("RAM_bnorm.jld2", "RAM");

In [None]:
epochs = 1000
history = []
loss(x, yref) = sum(ram(x, yref)[1:3])
gradients = []
loss_batch = []
for epoch = 1:epochs
    losses = []
    for (x, y) in dtrn
        lss = @diff loss(x, y)
        push!(losses, value(lss))
        for par in params(ram)
            g = grad(lss, par)
            update!(value(par), g, par.opt)
        end
    end
    
    push!(loss_batch, mean(losses))
    trn_losses, trn_acc = validate(ram, dtrn);
    tst_losses, tst_acc = validate(ram, dtst);
    println("Loss for #epoch : ", mean(losses))
        println(
        "epoch=$(1+length(history)) ",
        "trnloss=$(trn_losses), trnacc=$trn_acc, ",
        "tstloss=$(tst_losses), tstacc=$tst_acc")
    push!(history, ([trn_losses..., trn_acc, tst_losses..., tst_acc]));

        
end

Loss for #epoch : 0.73197216
epoch=1 trnloss=[0.61151, 0.868648, 0.122842, -0.37998], trnacc=0.12284166, tstloss=[0.579599, 0.840164, 0.124483, -0.385049], tstacc=0.12448333
Loss for #epoch : 0.40665886
epoch=2 trnloss=[0.440806, 0.713211, 0.130122, -0.402527], trnacc=0.13012223, tstloss=[0.433509, 0.706696, 0.1305, -0.403687], tstacc=0.1305
Loss for #epoch : 0.28901953
epoch=3 trnloss=[0.303507, 0.589309, 0.136525, -0.422327], trnacc=0.136525, tstloss=[0.308736, 0.594618, 0.136567, -0.422449], tstacc=0.13656667
Loss for #epoch : 0.21611997
epoch=4 trnloss=[0.325748, 0.610021, 0.1358, -0.420073], trnacc=0.1358, tstloss=[0.353135, 0.635637, 0.13495, -0.417452], tstacc=0.13495
Loss for #epoch : 0.16087914
epoch=5 trnloss=[0.18978, 0.483857, 0.140478, -0.434555], trnacc=0.14047778, tstloss=[0.220594, 0.512466, 0.139417, -0.431288], tstacc=0.13941666
Loss for #epoch : 0.13115229
epoch=6 trnloss=[0.0956406, 0.399614, 0.145217, -0.44919], trnacc=0.14521667, tstloss=[0.14063, 0.441116, 0.1435

In [None]:
Knet.@save "samples.jld2" LOCATION_RECORD SAMPLES

In [None]:
Knet.save(FILENAME, "RAM", ram)

In [None]:
Knet.@save "HistoryBnorm_acc.jld2" history

In [None]:
ram = nothing
Knet.gc()