In [2]:
using ImageMagick
using Knet
using AutoGrad
using JLD
using Statistics
using Random
using Images
include(Knet.dir("data", "mnist.jl"))
xtrain, ytrain, xtest, ytest = mnist();

In [3]:
# Global Parameters
BANDWIDTH = 8
NUM_PATCHES = 1
GLIMPSE_COUNT = 6 #Number of glimpses to be employed
NUM_CLASSES = 10
SCALE = 2
HSIZE = 256
GSIZE = 256;
BATCHSIZE = 20
MOdel_FILENAME = "RAM_Fleer.jld2"
Data_FILENAME = "History_Fleer.jld2"

"History_Fleer.jld2"

In [4]:
# Generic Functions
# Normalize dataset and return mean and SD
function Normalize(dataset)
    mean_dataset = Float32(mean(xtrain))
    sd_dataset = Float32(std(xtrain))
    dataset_norm = (dataset .- mean_dataset) ./ sd_dataset
    return dataset_norm
end


# Function for denormalizing the locations
# From -1:1 to 1:28
# Batch location dimensions = 2 x Batchsize
function antinorm(dim, loc)
    #Conver Normalized locations range[-1,1] into X, Y coordinates
    anloc = []
    meanx::Int16 = floor(dim/2);
    meany::Int16 = floor(dim/2);
    
    locx = Array{Int16}(ceil.((loc[1,:] .* dim) ./ 2 .+ meanx))  
    locy = Array{Int16}(ceil.((loc[2,:] .* dim) ./ 2 .+ meany))
    for i = 1:BATCHSIZE
        temp = reshape([locx[i], locy[i]], (2,1))
        push!(anloc, temp)
    end
    return (cat(anloc..., dims = 2))
    
end

function denorm(dim, loc)
    return convert(Array{Int}, floor.(0.5 .* dim .* (loc .+ 1.0)))
end


# # AntiNorm test
# test = KnetArray{Float32}(rand(2,BATCHSIZE))
# println(size(test))
# test = antinorm(28, test)

denorm (generic function with 1 method)

In [5]:
SDX = Float32(0.11)
xtrain = Normalize(xtrain);
xtest = Normalize(xtest);

dtrn = minibatch(xtrain, ytrain, BATCHSIZE, xtype=KnetArray{Float32});
dtst = minibatch(xtest, ytest, BATCHSIZE, xtype=KnetArray{Float32});

# Samples for Testing
xd, yd = first(dtrn);

# Generic Elements

In [6]:
#################################################################################################
struct FC
    w
    b
    act
end

(f::FC)(x) = f.act.(f.w * mat(x) .+ f.b)
FC(outsize::Int, insize::Int, act) = FC(Knet.param(outsize, insize, atype = KnetArray{Float32}), 
    Knet.param(outsize, atype = KnetArray{Float32}), act);


# # Layer Test
# F1 = FC(100, 784)
# params(FC1)
# testFC = FC1(xd)
#################################################################################################

struct Linear
    w
    b
end

(l::Linear)(x) = l.w * mat(x) .+ l.b

function Linear(outsize::Int, insize::Int)
    w = Knet.param(outsize, insize, atype=KnetArray{Float32})
    b = Knet.param(outsize, atype = KnetArray{Float32})
    return Linear(w,b)
end

# # Layer Test
# L1 = Linear(100, 784)
# params(L1)
# testLinear = L1(xd)

###################################################################################################

mutable struct BatchNorm
    w
    m
end

(b::BatchNorm)(x) = batchnorm(x, b.m, b.w)

function BatchNorm(c::Int; atype=KnetArray{Float32})
    w = Knet.Param(atype(bnparams(c)))
    m = bnmoments()
end


BatchNorm

## Core Network
Takes the output of Glimpse Network(gt -> gc) and previous hidden state(h(t-1)-> hp). Outputs current Hidden State (ht -> hc).
Runs num_glimpses times and records the hidden state at the end as output. 
Unit Defintion = Rect (Linear(h(t-1)) + Linear(gt))

In [7]:
struct CoreNet
    FCgc   # Linear operator for g current  : gt
    FChp   # Linear operator for h previous : ht-1
end

function CoreNet(gsize::Int, hsize::Int)
    return CoreNet(Linear(hsize, hsize), Linear(hsize, hsize))
end

function (c::CoreNet)(gt,hp)
    return relu.(c.FCgc(gt) + c.FChp(hp))
end

# CoreNet Test
# CN = CoreNet(hsize, hsize)
# sample = KnetArray{Float32}(rand(256,10))
# sample = reshape(sample , (256, 10))
# CN_test = CN(sample, sample)

# Glimpse Network
## Glimpse Network Constructor

In [8]:
mutable struct GpsNet
    FC1
    FC2
    FC3
    FC4
end

### Glimpse Network: Initializer & Method

In [9]:
# Glimpse Network default constructor: Definition of vairous cogs of the Glimpse network
function GpsNet(bandwidth::Int, scale::Int, num_patches::Int)
    FC1 = FC(128, bandwidth*bandwidth*num_patches, relu)
    FC2 = FC(128, 2, relu)
    FC3 = Linear(256, 256)  # The outsize may be 256
    FC4 = Linear(256, 256) # The outsize may be 256
    return GpsNet(FC1, FC2, FC3, FC4)
end

###########################################################################################

## Glimpse Network Deployment

# Deploying glimpse Network
# Input is the batch of images
# Take glimpses & locations. Pass them to FC layers 
function (g::GpsNet)(x_batch, loc)
    phi   = sensor(x_batch, loc)
#     phi =  mat(phi)
    phi_1 = g.FC1(phi)
    loc_1 = g.FC2(loc)
    cat_philoc =cat((loc_1, phi_1)..., dims =1)
    cat1 = relu.(g.FC3(cat_philoc))
    cat2 = g.FC4(cat1)
    return cat2 # The output here is gt to be fed to CoreNet
end

# # GpsNet Test
# test2 = GpsNet(bandwidth, scale, num_patches)
# params(test2.FC2)
# res_test2 = test2(xd, loc)

### Sensor Network

In [10]:
# Uses imresize for equalizing the return dimensions
function sensor(x_batch, loc)
    l,w,c,b = size(x_batch)
    coord = antinorm(l, loc)
    glimpse_array = []
    padsize::Int = 60
    width::Int = floor(BANDWIDTH / 2)  
    batch = Array(mat(x_batch))
    batch = reshape(batch, (l,w,b))
    
    for i = 1:BATCHSIZE
        img = batch[:,:,i]
        pad_img = padarray(img, Fill(0, (padsize, padsize)))
        temp = []
        for j = 1: NUM_PATCHES
            xlim1 = coord[1,i] - (width* (SCALE ^ (j-1)))
            xlim2 = coord[1,i]-1 + (width * (SCALE ^(j-1)))
            ylim1 = coord[2,i] - (width* (SCALE ^ (j-1)))
            ylim2 = coord[2,i]-1 + (width * (SCALE ^(j-1)))

            glimpse = pad_img[xlim1:xlim2, ylim1:ylim2]
            glimpse = imresize(glimpse, (BANDWIDTH, BANDWIDTH))
            push!(temp, glimpse)
        end
        temp = vcat(vcat(temp...)...)
        push!(glimpse_array, vcat(temp))
    end
    return KnetArray{Float32}(mat(cat(glimpse_array..., dims=2)))
end

## Function test
# coord = KnetArray{Float32}(zeros(2, BATCHSIZE))
# res = sensor(xd, coord)

sensor (generic function with 1 method)

In [None]:
# Uses the pooling for equalizing dimensions
function sensor(x_batch, loc)
    patch_size =  BANDWIDTH
    phi = []
    for i = 1:NUM_PATCHES
        push!(phi, foveat(x_batch, loc, patch_size))
        patch_size = SCALE * patch_size
    end
    
    for j = 1:length(phi)
        k = div(size(phi[j])[1], BANDWIDTH)
        phi[j] = pool(phi[j]; window = k, mode=1)
    end
    
    phi2d = map(mat, phi)
    return mat(vcat(phi2d...))
end

function foveat(x_batch, loc, patch_size)
    l,w,c,b = size(x_batch)
    coord = antinorm(l, loc)
    patches = [] 
    padsize::Int = 60
    width::Int = floor(patch_size / 2)

    xlim1 = coord[1,:] .- width .+1
    xlim2 = coord[1,:] .+ width
    ylim1 = coord[2,:] .- width .+1
    ylim2 = coord[2,:] .+ width 
    batch = Array(x_batch)

    for i=1:b
        img = batch[:,:,1,i]
        pad_img = padarray(img, Fill(0, (padsize, padsize)))
        push!(patches, pad_img[xlim1[i]:xlim2[i], ylim1[i]:ylim2[i]])
    end
    
    return KnetArray{Float32}(cat(patches...; dims=4))
end          

# ## Function test
# coord = KnetArray{Float32}(zeros(2, BATCHSIZE))
# res = sensor(xd, coord)

## Location Network

In [11]:
mutable struct LocNet
    sdx
    Linl
end

# Function should take input size as 256 x BATCHSIZE and output size as 2 (size of loc coordinates) x BATCHSIZE
function LocNet(outsize::Int, insize::Int, sdx::Float32)
    Linloc = Linear(outsize, insize)
    return LocNet(sdx, Linloc)
end

function (l::LocNet)(ht)
    mean_loc = tanh.(l.Linl(value(ht)))
    prtb = SDX .* randn!(similar(mean_loc))
    loc_new = mean_loc + prtb
#     loc_new = mean_loc + KnetArray{Float32}(randn((2, BATCHSIZE))) .* l.sdx
    return mean_loc, tanh.(loc_new)
end

# function (l::LocNet)(ht)
#     mean_loc = tanh.(l.Linl(ht))
    
#     loc_new = KnetArray{Float32}(zeros(size(mean_loc)))
#     for i = 1:size(mean_loc)[1]
#         for j = 1:size(mean_loc)[2]
#             loc_new = rand(Normal(mean_loc[i,j], l.sdx))
#         end
#     end
    
#     return mean_loc, value(tanh.(loc_new))
# end

# # TLocation Network Test
# Loc1 = LocNet(2, 256, SDX)
# sample = KnetArray{Float32}(rand(256));
# sample = reshape(sample, (256, 1));
# res, lt = Loc1(sample)

## RAM Model
### Staggered RAM
The Reference RAM deployment has been staggered.
Deploy one RAM function once to generate the hidden state, next location and action.
Deploy this one-time active RAM recursively equal to number of glimpses. Record all intermediate states
but base decision on output of the final one of the loop.

In [19]:
# RAM Initializer
# ram = RAM(SDX, GLIMPSE_COUNT, BANDWIDTH, NUM_PATCHES, NUM_CLASSES, SCALE, GSIZE, HSIZE);
mutable struct RAM
    sdx
    num_glimpses
    gpsnet::GpsNet
    corenet::CoreNet
    locnet::LocNet
    output_layer
    locations
    basenet
    hsize
end

function RAM(sdx, glimpse_count, bandwidth, num_patches, num_classes, scale, gsize, hsize)
    gpsnet = GpsNet(bandwidth, scale, num_patches)
    corenet = CoreNet(gsize, hsize)
    locnet = LocNet(2, hsize, sdx)
    output_layer = Linear(num_classes, hsize) # Action Network
    locations = []
    basenet = Linear(1, hsize)
    return RAM(sdx, glimpse_count, gpsnet, corenet, locnet, output_layer, locations, basenet, hsize)
end

##########################################################################################

# Outer RAM instance. Calls inner instance glimpse_count times
# Returns class_probs, policy result, hidden state and baseline results

function (r::RAM)(x_batch)
    # Init States
    hc = KnetArray{Float32}(reshape(zeros(r.hsize, BATCHSIZE), (r.hsize, BATCHSIZE)))
    loc = KnetArray{Float32}(zeros(2,BATCHSIZE))
    log_ps, baseresults = [], []  
    for i = 1:r.num_glimpses
        hc, loc, policy, baseres = r(x_batch, loc, hc)
        push!(log_ps, policy)       
        push!(baseresults, baseres)      
    end
    log_ps, baseresults = vcat(log_ps...), vcat(baseresults...)    
    scores = logsoftmax(r.output_layer(hc))  # Size of NUM_CLASSES x BATCHSIZE. mat() applied at layer definition
    return scores, log_ps, hc, baseresults
end

###########################################################################################

# Innermost RAM instance
# Return for a single glimpse, hidden state, next_location, policy and baseline results.
function (r::RAM)(x_batch, loc, hp)
    gc = r.gpsnet(x_batch, loc)  # g-current: gt
    hc = r.corenet(gc,hp)       # h-current ; ht-- Uses g_current and h_previous
    mean_loc, lnext = r.locnet(hc) # mean and l-next:lt to be used in the next time step
    sdx2 = (r.sdx.^2)
#     policy = -(abs.(lnext - mean_loc).^2) ./ 2*(sdx2) .- log(r.sdx) .- log(sqrt(2 * 3.1415))
#     policy = (value(lnext) - mean_loc) ./ sdx2 # #Glimpses x 
    policy = (value(lnext) - mean_loc) ./ sdx2
    policy = sum(policy, dims =1)
    base_res = sigm.(r.basenet(hc))
    return hc, lnext, policy, base_res
end

In [13]:
# Loss Caluclator
function (ram::RAM)(x_batch, y)
    scores, log_ps, hc, baseresults = ram(x_batch) 
    ypred = vec(map(i->i[1], argmax(Array(value(scores)), dims=1)))

    r = (ypred .== y)
    r = reshape(r, 1, :) 
    
    # Modified Implementation
    R = zeros(Float32, size(baseresults)...); R[end,:] = r
    R = KnetArray{Float32}(R)    
    
    # Batch Losses
    R_err = (R - baseresults)
    loss_action = nll(scores, y)
    loss_baseline = sum(abs2, value(baseresults) .- R) / length(baseresults)
    
    # Reinforce Loss
#     mean_logps = mean(log_ps)
#     loss_reinforce = sum(mean_logps .* R_err)
    
    # Reinforce Loss 2
    loss_reinforce = mean(sum(-log_ps .* R_err, dims =1))
    return loss_action, loss_baseline, loss_reinforce, sum(R), length(R)
end


In [14]:
function validate(ram::RAM, data)
    losses = zeros(3)
    ncorrect = ninstances = 0
    for (x,y) in data
        ret = ram(x,y)
        for i = 1:3; losses[i] += ret[i]; end
        ncorrect += ret[4]
        ninstances += ret[5]
    end
    losses = losses / length(data)
    losses = [sum(losses), losses...]
    return losses, ncorrect / ninstances
end

validate (generic function with 1 method)

### RAM Initializer

In [15]:
function initopt!(model, optimizer="Adam(lr=0.001, gclip = 5)")
    for par in params(model)
        par.opt = eval(Meta.parse(optimizer))
    end
end

ram = RAM(SDX, GLIMPSE_COUNT, BANDWIDTH, NUM_PATCHES, NUM_CLASSES, SCALE, GSIZE, HSIZE);
record = params(ram);
initopt!(ram)

In [None]:
# ram = Knet.load(FILENAME, "RAM");

In [None]:
epochs = 1000
history = []
loss(x, yref) = sum(ram(x, yref)[1:3])
gradients = []
best_acc = 0.0
for epoch = 1:epochs
    losses = []
    for (x, y) in dtrn
        lss = @diff loss(x, y)
        push!(losses, value(lss))
        for par in params(ram)
            g = grad(lss, par)          
#             push!(gradients, g)
            update!(value(par), g, par.opt)
        end
    end
    println("Loss for #epoch : ", mean(losses))
    #progress!(sgd(ram, dtrn; lr=0.001, gclip=0))
    trn_losses, trn_acc = validate(ram, dtrn);
    tst_losses, tst_acc = validate(ram, dtst);
    println(
        "epoch=$(1+length(history)) ",
        "trnloss=$(trn_losses), trnacc=$trn_acc, ",
        "tstloss=$(tst_losses), tstacc=$tst_acc")
    push!(history, ([trn_losses..., trn_acc, tst_losses..., tst_acc]));

end

Loss for #epoch : -22.350313
epoch=1 trnloss=[-28.0982, 0.93616, 0.12143, -29.1558], trnacc=0.12143055, tstloss=[-11.335, 1.97082, 0.0555162, -13.3614], tstacc=0.055516668
Loss for #epoch : -30.787079
epoch=2 trnloss=[-30.7534, 0.731827, 0.131592, -31.6168], trnacc=0.13159166, tstloss=[-14.1088, 1.83736, 0.0664999, -16.0127], tstacc=0.0665
Loss for #epoch : -32.037098
epoch=3 trnloss=[-31.8658, 0.632044, 0.135744, -32.6336], trnacc=0.13574444, tstloss=[-14.9966, 1.77776, 0.0700833, -16.8445], tstacc=0.070083335
Loss for #epoch : -32.735504
epoch=4 trnloss=[-32.4518, 0.600307, 0.138169, -33.1903], trnacc=0.13816944, tstloss=[-10.8909, 2.06365, 0.0540333, -13.0086], tstacc=0.05403333
Loss for #epoch : -33.22399
epoch=5 trnloss=[-32.7726, 0.574451, 0.139433, -33.4865], trnacc=0.13943334, tstloss=[-9.70992, 2.69893, 0.05195, -12.4608], tstacc=0.05195
Loss for #epoch : -33.61443
epoch=6 trnloss=[-33.364, 0.519892, 0.141539, -34.0254], trnacc=0.14153889, tstloss=[-12.6103, 2.30332, 0.06245, 

In [None]:
Knet.save("RAM_Fleer.jld2", "RAM", ram)

In [None]:
Knet.@save ".jld2" history

In [None]:
ram = nothing
Knet.gc()