In [1]:
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim
include("/Users/piotrsokol/Documents/RNNODE.jl/src/rnn_ode.jl")
using Zygote
using Flux: logitcrossentropy
using Flux.Data: DataLoader
using MLDatasets, NNlib, MLDataUtils
#ENV["PYTHON"] = "/Users/piotrsokol/anaconda3/envs/bortho/bin/python"
#using Pkg; Pkg.build("PyCall")
using PyCall
using CUDA
# using Parameters: @with_kw, @unpack
import Statistics: mean
using BSON,NPZ
using UUIDs
using EllipsisNotation
using ProgressMeter
using ArgParse
# Floating-point element type used throughout the notebook. Declared `const` so
# that functions and closures reading this global can be specialised on it
# instead of treating it as an untyped (`Any`) global.
const FT = Float32

Float32

In [28]:
"""
    onehot(labels_raw; ntoken::Int=9)

Convert `labels_raw` to a one-of-k encoding with `convertlabel`, declaring the
integers `0:ntoken` as the native label set.
"""
function onehot(labels_raw; ntoken::Int=9)
    native_labels = LabelEnc.NativeLabels(collect(0:ntoken))
    return convertlabel(LabelEnc.OneOfK, labels_raw, native_labels)
end

"""
    preprocess_img(imgs, p; μ=0.13066047, σ=0.30810785)

Flatten a stack of images into columns, permute their pixels and standardise them.

# Arguments
- `imgs`: array whose first two dimensions are flattened together and whose third
  dimension indexes the images; each image becomes one column of the result.
- `p`: permutation of the flattened pixel indices, applied identically to every column.

# Keywords
- `μ`: value subtracted from every entry after the permutation.
- `σ`: value every entry is divided by afterwards.
  The defaults are the constants that used to be hard-coded here
  (presumably the MNIST pixel mean / standard deviation; confirm).

# Returns
A freshly allocated matrix with element type `FT`; `imgs` itself is not modified.
"""
function preprocess_img(imgs, p; μ=0.13066047, σ=0.30810785)
    npixels = size(imgs, 1) * size(imgs, 2)
    flat = FT.(reshape(imgs, npixels, size(imgs, 3)))
    # Indexing the rows with `p` permutes every column in a single pass. It gives
    # the same result as `permute!(col, p)` on each column, but avoids splatting
    # one argument per image into `hcat`.
    out = flat[p, :]
    out .-= μ
    out ./= σ
    return out
end
"""
    get_data(batchsize, device, set::Symbol; train_split=58999/60000)

Build `DataLoader`s over pixel-permuted MNIST.

The pixel permutation comes from a NumPy `RandomState` seeded with the `"image"`
entry of `__fixed_seeds__` in the Python module `data_utils`. The MNIST training
data is one-hot encoded, preprocessed with that permutation and split with
`stratifiedobs` using `p = train_split`.

# Arguments
- `batchsize`: batch size handed to every loader.
- `device`: function applied to each collected array before it is wrapped in a loader.
- `set`: when equal to `:test`, the MNIST test data is loaded as well.

# Returns
`(train, valid)` loaders, or `(train, valid, test)` when `set == :test`. Only the
training loader shuffles.
"""
function get_data(batchsize, device, set::Symbol; train_split=58999/60000)
    fixed_seeds = pyimport("data_utils").__fixed_seeds__
    np_random = pyimport("numpy.random")
    rng = np_random.RandomState(fixed_seeds["image"])
    p = rng.permutation(1:784)

    train_imgs, train_labels = MNIST.traindata()
    Y = onehot(train_labels)
    X = preprocess_img(train_imgs, p)

    (x_train, y_train), (x_valid, y_valid) = stratifiedobs((X, Y), p = train_split)

    # Materialise the observations, move them to `device` and wrap them in a loader.
    function make_loader(xs, ys, shuffle)
        return DataLoader(device.(collect.((xs, ys))); batchsize = batchsize, shuffle = shuffle)
    end

    # Every value of `set` other than :test yields just the train/validation pair.
    if set != :test
        return (make_loader(x_train, y_train, true), make_loader(x_valid, y_valid, false))
    end

    test_imgs, test_labels = MNIST.testdata()
    Y_test = onehot(test_labels)
    X_test = preprocess_img(test_imgs, p)
    return (
        make_loader(x_train, y_train, true),
        make_loader(x_valid, y_valid, false),
        make_loader(X_test, Y_test, false),
    )
end

"""
    get_network(alpha, architecture, initializer, isize, hsize, osize, tsteps, interpolation)

Assemble the classifier as a `Chain`: interpolate the input, run it through an
`RNNODE`, convert to `Array`, keep the last slice along the third dimension and
apply `Dense(hsize, osize)`.

# Arguments
- `alpha`: accepted but not used inside this function.
- `architecture`: `"RNN_TANH"` selects `∂RNNCell`, `"GRU"` selects `∂GRUCell`;
  any other value selects `∂LSTMCell`.
- `initializer`: `"limitcycle"` (or the `"LSTM"` architecture) uses the
  two-argument cell constructor; otherwise `Flux.glorot_uniform` is passed as well.
- `isize`, `hsize`, `osize`: sizes forwarded to the cell constructor and the `Dense` layer.
- `tsteps`: its last element is the end of the integration span and of the `saveat` grid.
- `interpolation`: callable applied to `permutedims(x)`, outside of differentiation.
"""
function get_network(alpha, architecture, initializer, isize, hsize,osize, tsteps, interpolation)
    ∂rnncell = architecture == "RNN_TANH" ? ∂RNNCell :
               (architecture == "GRU" ? ∂GRUCell : ∂LSTMCell)

    # With the "limitcycle" initializer or the "LSTM" architecture the cell is
    # built by the two-argument method; otherwise the initializer is passed too.
    ∂rnn = if initializer == "limitcycle" || architecture == "LSTM"
        ∂rnncell(isize, hsize)
    else
        ∂rnncell(isize, hsize, Flux.glorot_uniform)
    end

    t_end = tsteps[end]
    node = RNNODE(
        ∂rnn,
        (0.f0, t_end);
        preprocess = x -> FT.(permutedims(x)),
        save_end = true,
        save_start = false,
        saveat = collect(0.f0:t_end),
    )

    println(interpolation)

    # The interpolant is built inside `Zygote.ignore`, so it is not differentiated.
    interpolate(x) = Zygote.ignore(() -> interpolation(permutedims(x)))

    return Chain(interpolate, node, Array, x -> x[:, :, end], Dense(hsize, osize))
end

# Index of the largest entry in each column of `x`, one result per column.
classify(x) = map(argmax, eachcol(x))

"""
    evaluate_set(model, data, 𝓁array, accarray, ℒ, set)

Run `model` over every `(x, y)` batch of `data` and record summary metrics.

The mean of the per-batch losses is pushed onto `𝓁array` and the fraction of
columns whose predicted class equals the target class is pushed onto `accarray`.
`set` is only used to label the progress bar. Returns `nothing`.
"""
function evaluate_set(model, data, 𝓁array, accarray, ℒ, set)
    batch_losses = Float32[]
    n_correct = 0
    n_seen = 0
    @showprogress "Evaluating $set "  for (x, y) in data
        ŷ = model(x)
        push!(batch_losses, first(ℒ(ŷ, y)))
        # Class comparison is done on CPU copies of targets and predictions.
        targets = classify(cpu(y))
        predictions = classify(cpu(ŷ))
        n_correct += count(targets .== predictions)
        n_seen += length(targets)
    end
    push!(accarray, n_correct / n_seen)
    # NOTE(review): every batch contributes equally to this mean, regardless of
    # how many samples it holds; confirm that is intended for a short final batch.
    push!(𝓁array, mean(batch_losses))
    return nothing
end


evaluate_set (generic function with 1 method)

In [29]:
# ---- Hyper-parameters and run configuration ----
α = FT(1.5)
η = FT(1e-3)
optimizer = ADAM
Random.seed!(1);
device = cpu
hpsearch = true
# A hyper-parameter search only looks at the validation set.
sets = hpsearch ? [:valid] : [:valid, :test]
architecture = "RNN_TANH"
initializer = "limitcycle"
hidden_size = 512
interpolation = "CubicSplineFixedGrid"
python_code_dir = "/Users/piotrsokol/Documents/block-orthogonal/src/"
bs = 100
gradient_clipping = FT(1e3)

# Put the Python sources on `sys.path` before `get_data` imports `data_utils`.
py"""
import sys
sys.path.insert(0, $python_code_dir)
"""

# ---- Data ----
train_loader, valid_loader = get_data(bs, device, :valid)
eval_sets = Dict(:valid => valid_loader)
metrics = Dict(
    :test => Dict(:loss => FT[], :accuracy => FT[]),
    :valid => Dict(:loss => FT[], :accuracy => FT[]),
)
tsteps = FT(784)

# ---- Model, optimiser and loss ----
# The interpolation is given by name and resolved to the object bound to that name.
nn = get_network(α, architecture, initializer, 1, hidden_size, 10, tsteps, eval(Symbol(interpolation)));
# `optimizer` already holds the constructor, so it is called directly.
opt = Flux.Optimiser(ClipValue(gradient_clipping), optimizer(η))
ℒ(ŷ,y) = logitcrossentropy(ŷ,y)

CubicSplineFixedGrid


ℒ (generic function with 1 method)

In [19]:
# Take the first (inputs, targets) mini-batch from the training loader for inspection.
(X, Y) = first(train_loader)

(Float32[1.7651266 -0.42407382 … -0.42407382 2.7833593; -0.42407382 -0.42407382 … -0.42407382 -0.42407382; … ; -0.42407382 -0.42407382 … -0.42407382 -0.42407382; -0.42407382 -0.42407382 … -0.42407382 -0.42407382], [0 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 1 … 1 0])

In [31]:
# One call to `Flux.train!` over `train_loader`: the do-block is the loss, i.e.
# `ℒ` applied to the network output and the targets.
Flux.train!(params(nn), train_loader, opt) do x, y
    ℒ(nn(x), y)
end