In [1]:
for p in ("Knet","ArgParse","PlyIO","Plots","JLD")
    Pkg.installed(p) == nothing && Pkg.add(p)
end
using PlyIO
using Knet
using ArgParse
using JLD
using Plots



[1m[36mINFO: [39m[22m[36mCloning cache of PlyIO from https://github.com/FugroRoames/PlyIO.jl.git
[39m[1m[36mINFO: [39m[22m[36mInstalling PlyIO v0.0.3
[39m[1m[36mINFO: [39m[22m[36mPackage database updated
[39m[1m[36mINFO: [39m[22m[36mMETADATA is out-of-date — you may not have the latest version of PlyIO
[39m[1m[36mINFO: [39m[22m[36mUse `Pkg.update()` to get the latest versions of your packages
[39m[1m[36mINFO: [39m[22m[36mCloning cache of Contour from https://github.com/JuliaGeometry/Contour.jl.git
[39m[1m[36mINFO: [39m[22m[36mCloning cache of Measures from https://github.com/JuliaGraphics/Measures.jl.git
[39m[1m[36mINFO: [39m[22m[36mCloning cache of PlotThemes from https://github.com/JuliaPlots/PlotThemes.jl.git
[39m[1m[36mINFO: [39m[22m[36mCloning cache of PlotUtils from https://github.com/JuliaPlots/PlotUtils.jl.git
[39m[1m[36mINFO: [39m[22m[36mCloning cache of Plots from https://github.com/JuliaPlots/Plots.jl.git
[39m[1m[

In [2]:
###EXTRA FUNC

function preprocess(dataset, a)
    r = length(dataset)
    c = length(dataset[1])
    datasetmat = zeros(r, c)
    for i=1:r
        for j=1:c
            datasetmat[i,j] = dataset[i][j]
        end
    end
    minarray = minimum(datasetmat,1)
    maxarray = maximum(datasetmat,1)
    denom = maxarray.-minarray
    denom = denom + 1e-6 * (denom.==0)

    for i=1:r
        for j=1:c
            dataset[i][j] = 2 * a * ((dataset[i][j]-minarray[j]) / denom[j]) - a
        end
    end 

    return dataset
end

##LEAKY RELU
function lrelu(x, alpha=0.2)
    pos = max(0,x)
    neg = min(0,x) * alpha
    return pos + neg
end

###EUC DIST, Nearest Neighbor
function nn(x, A)
    smallestdist = norm(x-Array{Float64,2}(A[1]))
    smallestidx = 1
    for i in 1:length(A)
        vec=Array{Float64,2}(A[i])
        if norm(x-vec)<smallestdist
            smallestdist = norm(x-vec)
            smallestidx = i
        end
    end
    closestmodel = A[smallestidx]
    return smallestidx
end

nn (generic function with 1 method)

In [3]:
#####TRIAL


In [4]:
#####NETWORK

const F = Float64

function encode(ϕ, x)
    x = mat(x)
    x = lrelu.(ϕ[1]*x .+ ϕ[2])
    μ = ϕ[3]*x .+ ϕ[4]
    logσ² = ϕ[5]*x .+ ϕ[6]
    return μ, logσ²
end

function decode(θ, z)
    z = lrelu.(θ[1]*z .+ θ[2])
    return tanh.(θ[3]*z .+ θ[4])
end

function binary_cross_entropy(x, x̂)
    s = @. x * log(x̂ + F(1e-10)) + (1-x) * log(1 - x̂ + F(1e-10))
    return -mean(s)
end

function output(w, x, nθ)
    θ, ϕ = w[1:nθ], w[nθ+1:end]
    μ, logσ² = encode(ϕ, x)
    nz, M = size(μ)
    σ² = exp.(logσ²)
    σ = sqrt.(σ²)

    #KL =  -sum(@. 1 + logσ² - μ*μ - σ²) / 2
    # Normalise by same number of elements as in reconstruction
    #KL /= M*28*28 ###1

    z = μ .+ randn!(similar(μ)) .* σ
    x̂ = decode(θ, z)
    #BCE = binary_cross_entropy(mat(x), x̂) ####2

    return x̂
end

function loss(w, x, nθ)
    θ, ϕ = w[1:nθ], w[nθ+1:end]
    μ, logσ² = encode(ϕ, x)
    nz, M = size(μ)
    σ² = exp.(logσ²)
    σ = sqrt.(σ²)

    KL =  -sum(@. 1 + logσ² - μ*μ - σ²) / 2
    # Normalise by same number of elements as in reconstruction
    #KL /= M*28*28 ###1

    z = μ .+ randn!(similar(μ)) .* σ
    x̂ = decode(θ, z)
    BCE = mean(abs2,mat(x)-x̂)
    #BCE = binary_cross_entropy(mat(x), x̂) ####2

    return 500*BCE + KL/2
end

function z_out(w, x, nθ)
    θ, ϕ = w[1:nθ], w[nθ+1:end]
    μ, logσ² = encode(ϕ, x)
    nz, M = size(μ)
    σ² = exp.(logσ²)
    σ = sqrt.(σ²)
    z = μ .+ randn!(similar(μ)) .* σ
    return z
end

function mseloss(w, x, nθ)
    θ, ϕ = w[1:nθ], w[nθ+1:end]
    μ, logσ² = encode(ϕ, x)
    nz, M = size(μ)
    σ² = exp.(logσ²)
    σ = sqrt.(σ²)

    z = μ .+ randn!(similar(μ)) .* σ
    x̂ = decode(θ, z)
    BCE = mean(abs2,mat(x)-x̂)
    #BCE = binary_cross_entropy(mat(x), x̂) ####2

    return BCE
end

function KLloss(w, x, nθ)
    θ, ϕ = w[1:nθ], w[nθ+1:end]
    μ, logσ² = encode(ϕ, x)
    nz, M = size(μ)
    σ² = exp.(logσ²)
    σ = sqrt.(σ²)

    KL =  -sum(@. 1 + logσ² - μ*μ - σ²) / 2
    # Normalise by same number of elements as in reconstruction
    #KL /= M*28*28 ###1

    return KL
end

function aveloss(θ, ϕ, data) #####3
    ls = F(0)
    nθ = length(θ)
    #for (x, y) in data
    for x in data
        ls += loss([θ; ϕ], x, nθ)
    end
    return ls / length(data)
end

function train!(θ, ϕ, data, opt; epochs=1)
    w = [θ; ϕ]
    for epoch=1:epochs
        for (x, y) in data
            dw = grad(loss)(w, x, length(θ))
            update!(w, dw, opt)
        end
    end
    return θ, ϕ
end

function weights(nz, nh; atype=Array{F})
    θ = [  # z->x
        xavier(nh, nz),
        zeros(nh),
        xavier(10002*3, nh), #x
        zeros(10002*3)
        ]
    θ = map(a->convert(atype,a), θ)

    ϕ = [ # x->z
        xavier(nh, 10002*3),
        zeros(nh),
        xavier(nz, nh), #μ
        zeros(nz),
        xavier(nz, nh), #σ
        zeros(nz)
        ]
    ϕ = map(a->convert(atype,a), ϕ)

    return θ, ϕ
end




weights (generic function with 1 method)

In [5]:
function main(args="")
    s = ArgParseSettings()
    s.description="Variational Auto Encoder on MNIST dataset."
    s.exc_handler=ArgParse.debug_handler
    @add_arg_table s begin
        ("--seed"; arg_type=Int; default=-1; help="random number seed: use a nonnegative int for repeatable results")
        ("--batchsize"; arg_type=Int; default=100; help="minibatch size")
        ("--epochs"; arg_type=Int; default=10; help="number of epochs for training")
        ("--nh"; arg_type=Int; default=300; help="hidden layer dimension")
        ("--nz"; arg_type=Int; default=2; help="encoding dimention")
        ("--lr"; arg_type=Float64; default=1e-5; help="learning rate")
        ("--atype"; default=(gpu()>=0 ? "KnetArray{F}" : "Array{F}"); help="array type: Array for cpu, KnetArray for gpu")
        ("--infotime"; arg_type=Int; default=2; help="report every infotime epochs")
    end
    isa(args, String) && (args=split(args))
    if in("--help", args) || in("-h", args)
        ArgParse.show_help(s; exit_when_done=false)
        return
    end
    o = parse_args(args, s; as_symbols=true)

    atype = eval(parse(o[:atype]))
    info("using ", atype)
    o[:seed] > 0 && setseed(o[:seed])

    θ, ϕ = weights(o[:nz], o[:nh], atype=atype)
    w = [θ; ϕ]
    opt = optimizers(w, Adam, lr=o[:lr])

    path = "jumping/"
    files = readdir(path)
    dataset=Any[]
    for i in files
        data0 = load_ply(path * i)
        datax = data0["vertex"]["x"]
        datay = data0["vertex"]["y"]
        dataz = data0["vertex"]["z"]
        data1 = hcat(datax,datay)
        data1 = hcat(data1,dataz)
        data1=data1/2
        #data1[:,2]=data1[:,2]+0.5
        data1[:,3]=data1[:,3]+0.5 
        data1 = reshape(data1,30006,1)
        dataset = push!(dataset,atype(data1))
    end
    
    
    ######dataset = preprocess(dataset, 0.9)
    
    #println(typeof(dataset))
    #println(size(dataset))
    #dataset = reshape(dataset,12500,3,1,72)
    #println(size(dataset))
    #dataset = minibatch(dataset,1;xtype=atype())

    #for (d1,d2) in dataset
    #    println(d1)
    #    break
    #end
    #println(size(dataset[1]))
    randomout = output(w, atype(reshape(dataset[1],30006,1)), length(θ))
    #return reshape(randomout,12500,3)

    #####TRAINING PART#######
    report(epoch) = begin
            #dtrn = minibatch(xtrn, ytrn, o[:batchsize]; xtype=atype)
            #dtst = minibatch(xtst, ytst, o[:batchsize]; xtype=atype)
            println((:epoch, epoch,
                     :trn, aveloss(θ, ϕ, dataset),
                     :length, length(dataset)))
    end

    report(0); tic()
    ls_list = []
    zlist = Any[]
    
    @time for epoch=1:o[:epochs]
        ls = 0
        kldiv = 0
        for x in dataset      
            ls += mseloss(w, atype(reshape(x,30006,1)), length(θ))
            kldiv += KLloss(w, atype(reshape(x,30006,1)), length(θ))
        end
        
        
        for x in dataset      
            dw = grad(loss)(w, atype(reshape(x,30006,1)), length(θ))
            update!(w, dw, opt)
        end   
        if (epoch % 500 == 0)
            opt = optimizers(w, Adam, lr=1e-6)
        end
        if (epoch % 800 == 0)
            opt = optimizers(w, Adam, lr=1e-7)
        end
            
        
        println(kldiv/length(dataset))
        println(ls/length(dataset))
        push!(ls_list, ls/length(dataset))
        (epoch % o[:infotime] == 0) && (report(epoch); toc(); tic())
        #println(23)
    end; toq()
    for x in dataset      
            z = z_out(w, atype(reshape(x,30006,1)), length(θ))
            push!(zlist,z)
        end
    
    samplein = dataset[1]
    sampleout = output(w, atype(reshape(dataset[1],30006,1)), length(θ))
    sampleout = reshape(sampleout,10002,3)
    return θ, ϕ, reshape(randomout,10002,3), w, sampleout, samplein, ls_list, zlist

end

#return randomout

#end # module

p1, p2, randomout, w, sampleout, samplein, ls_list, zlist = main("--infotime 1 --seed 1 --epochs 300 --lr 1e-5");

gpu()

[1m[36mINFO: [39m[22m[36musing Knet.KnetArray{Float64,N} where N
[39m

(:epoch, 0, :trn, 84.81782621220134, :length, 149)
0.043510120020972105
0.16958583939652286
(:epoch, 1, :trn, 75.80014178785974, :length, 149)
elapsed time: 10.36180048 seconds
13.978874852998747
0.13763665435525033
(:epoch, 2, :trn, 49.19718581462073, :length, 149)
elapsed time: 6.89378711 seconds
33.28823634739997
0.06551163887990202
(:epoch, 3, :trn, 36.37221522447385, :length, 149)
elapsed time: 6.86980668 seconds
29.238193261068865
0.04281496587151342
(:epoch, 4, :trn, 30.69816705977263, :length, 149)
elapsed time: 6.872519394 seconds
24.823769617838614
0.03644656991988097
(:epoch, 5, :trn, 30.081389184146406, :length, 149)
elapsed time: 6.867675131 seconds
18.868053781508262
0.041741558807132026
(:epoch, 6, :trn, 50.3298211046633, :length, 149)
elapsed time: 7.014460423 seconds
12.730826436478333
0.08746073886697053
(:epoch, 7, :trn, 67.23955590761028, :length, 149)
elapsed time: 6.900252708 seconds
9.891136765198215
0.12163841060496869
(:epoch, 8, :trn, 47.0504388999336, :length

(:epoch, 66, :trn, 16.16861471577963, :length, 149)
elapsed time: 6.872934617 seconds
3.417094844197058
0.02860830946021448
(:epoch, 67, :trn, 14.433148711301257, :length, 149)
elapsed time: 6.87928476 seconds
3.3538922130595483
0.027600271426654648
(:epoch, 68, :trn, 9.617743529977306, :length, 149)
elapsed time: 6.885033836 seconds
5.068458312199438
0.013856034618150746
(:epoch, 69, :trn, 8.84477659651702, :length, 149)
elapsed time: 6.987119752 seconds
5.862097937438066
0.011440419519504337
(:epoch, 70, :trn, 8.82810215256203, :length, 149)
elapsed time: 6.87928335 seconds
8.14169874937159
0.010035958599228349
(:epoch, 71, :trn, 11.983167333009916, :length, 149)
elapsed time: 6.865954784 seconds
11.187551437072198
0.012684158036631803
(:epoch, 72, :trn, 16.405701906581754, :length, 149)
elapsed time: 6.871491133 seconds
13.9914052824777
0.01882622545428045
(:epoch, 73, :trn, 21.06239082149129, :length, 149)
elapsed time: 6.91521064 seconds
16.903686385548905
0.025673940229112634
(:e

10.4526333644985
0.015788619852339827
(:epoch, 132, :trn, 15.209132111542221, :length, 149)
elapsed time: 7.003337695 seconds
11.205423304131008
0.01905306014327736
(:epoch, 133, :trn, 12.672431354985553, :length, 149)
elapsed time: 6.868741089 seconds
9.714441793695881
0.015982008216427375
(:epoch, 134, :trn, 9.882517257211473, :length, 149)
elapsed time: 6.92392917 seconds
6.750962094148049
0.01337145428169229
(:epoch, 135, :trn, 12.316509131329843, :length, 149)
elapsed time: 6.925438108 seconds
6.6070817714029
0.01735594961890022
(:epoch, 136, :trn, 13.060475767830766, :length, 149)
elapsed time: 6.939529799 seconds
5.928238314685585
0.018950112855124168
(:epoch, 137, :trn, 11.261008838036407, :length, 149)
elapsed time: 7.119190059 seconds
5.398839620251408
0.016871826483360296
(:epoch, 138, :trn, 10.505034394713556, :length, 149)
elapsed time: 7.019597273 seconds
4.729628920136088
0.015954559894032858
(:epoch, 139, :trn, 11.36108183142233, :length, 149)
elapsed time: 7.003412189 

(:epoch, 197, :trn, 9.692884812809172, :length, 149)
elapsed time: 6.905594747 seconds
3.5785028990139867
0.016754898074268262
(:epoch, 198, :trn, 9.612213289362629, :length, 149)
elapsed time: 6.905657215 seconds
3.4451299970366316
0.015865965618740024
(:epoch, 199, :trn, 8.243687694281762, :length, 149)
elapsed time: 6.958440971 seconds
2.829273479475968
0.013282525461780305
(:epoch, 200, :trn, 6.80785538593157, :length, 149)
elapsed time: 7.165325318 seconds
3.209643059271361
0.010744347629400625
(:epoch, 201, :trn, 6.7011168343432885, :length, 149)
elapsed time: 6.983679304 seconds
3.050093029397473
0.009964853342001301
(:epoch, 202, :trn, 6.849574670286269, :length, 149)
elapsed time: 6.906598322 seconds
4.160911573944844
0.009443274581514874
(:epoch, 203, :trn, 8.762720520526942, :length, 149)
elapsed time: 6.895228911 seconds
6.491813925299872
0.01091568065547736
(:epoch, 204, :trn, 12.148240306705386, :length, 149)
elapsed time: 6.975683686 seconds
7.812989291978657
0.016808773

(:epoch, 262, :trn, 7.522250578266919, :length, 149)
elapsed time: 6.908056845 seconds
4.254030857907515
0.010373596533356722
(:epoch, 263, :trn, 8.98495759570806, :length, 149)
elapsed time: 7.106984687 seconds
6.05697531445986
0.011853921735014027
(:epoch, 264, :trn, 11.149059000934253, :length, 149)
elapsed time: 6.908781225 seconds
5.908989172143528
0.01565321063598871
(:epoch, 265, :trn, 8.32377996782434, :length, 149)
elapsed time: 6.896377676 seconds
5.346747014987975
0.010981028358195736
(:epoch, 266, :trn, 5.84844463113095, :length, 149)
elapsed time: 6.897569745 seconds
3.2971379412606896
0.008352542665068338
(:epoch, 267, :trn, 5.155775883064917, :length, 149)
elapsed time: 6.893999092 seconds
2.6944223890846755
0.007239315000956247
(:epoch, 268, :trn, 5.243256362890276, :length, 149)
elapsed time: 7.009307111 seconds
2.191928228704971
0.008139351960709139
(:epoch, 269, :trn, 5.463796910499198, :length, 149)
elapsed time: 6.886768078 seconds
2.3045531984726777
0.008698411026

0

In [None]:
println(size(randomout))
sampleout=Array{Float64,2}(sampleout)
writedlm("test5.xyz", sampleout)
#println(randomout)

samplein=reshape(samplein,10002,3)
samplein=Array{Float64,2}(samplein)
println(maximum(samplein[:,1]), minimum(samplein[:,1]))
println(maximum(samplein[:,2]), minimum(samplein[:,2]))
maximum(samplein[:,3]), minimum(samplein[:,3])
writedlm("samplein.xyz", samplein)
#dataset[1]

In [None]:
###### IF LOAD MODEL
model = load("model_kl_jumping128_1_600epoch_-5lr_tanhout(1).jld")
w = model["w"]
zlist = model["zlist"]
ls_list = model["ls_list"]

In [None]:
x1 = Array{F}(zlist[1])
y1 = Array{F}(zlist[2])
scatter(x1,y1,reuse=false)

In [None]:
plot(ls_list)


In [None]:
ls_list

In [None]:
atype = KnetArray{F}
path = "jumping/"
    files = readdir(path)
    dataset=Any[]
    for i in files
        data0 = load_ply(path * i)
        datax = data0["vertex"]["x"]
        datay = data0["vertex"]["y"]
        dataz = data0["vertex"]["z"]
        data1 = hcat(datax,datay)
        data1 = hcat(data1,dataz)
        data1=data1/2
        #data1[:,2]=data1[:,2]+0.5
        data1[:,3]=data1[:,3]+0.5 
        data1 = reshape(data1,30006,1)
        dataset = push!(dataset,atype(data1))
    end
samplein = dataset[1]
thet, yy = weights(128, 300, atype=atype)
for i in 1:length(dataset)
    sample = dataset[i]
    sampleout = output(w, atype(reshape(dataset[i],30006,1)), length(thet))
    sampleout = reshape(sampleout,10002,3)
    sampleout=Array{Float64,2}(sampleout)
    writedlm("outputDL/meshout"*string(i)*".xyz", sampleout)
    writedlm("outputDL/meshin"*string(i)*".xyz", reshape(Array{Float64,2}(sample),10002,3))
end

In [112]:
zlistFloat = zeros(Float64,length(zlist),length(zlist[1]))
for i in 1:length(zlist)
    for j in 1:length(zlist[1])
        zlistFloat[i,j] = zlist[i][j]
    end
end
zlistFloat[:,1]
x1 = Array{F}(zlistFloat[:,1])
y1 = Array{F}(zlistFloat[:,2])
scatter(x1,y1)
x2 = Array{F}(zlistFloat[1:20,1])
y2 = Array{F}(zlistFloat[1:20,2])
scatter!(x2,y2,color = "orange")
x2 = Array{F}(zlistFloat[136:149,1])
y2 = Array{F}(zlistFloat[136:149,2])
scatter!(x2,y2,color = "orange")
x2 = Array{F}(zlistFloat[20:37,1])
y2 = Array{F}(zlistFloat[20:37,2])
scatter!(x2,y2,color = "red")
x2 = Array{F}(zlistFloat[38:80,1])
y2 = Array{F}(zlistFloat[38:80,2])
scatter!(x2,y2,color = "green")
x2 = Array{F}(zlistFloat[81:90,1])
y2 = Array{F}(zlistFloat[81:90,2])
scatter!(x2,y2,color = "red")
x2 = Array{F}(zlistFloat[91:95,1])
y2 = Array{F}(zlistFloat[91:95,2])
scatter!(x2,y2,color = "red")
x2 = Array{F}(zlistFloat[96:101,1])
y2 = Array{F}(zlistFloat[96:101,2])
scatter!(x2,y2,color = "green")
x2 = Array{F}(zlistFloat[102:126,1])
y2 = Array{F}(zlistFloat[102:126,2])
scatter!(x2,y2,color = "purple")
x2 = Array{F}(zlistFloat[127:130,1])
y2 = Array{F}(zlistFloat[127:130,2])
scatter!(x2,y2,color = "green")
x2 = Array{F}(zlistFloat[131:135,1])
y2 = Array{F}(zlistFloat[131:135,2])
scatter!(x2,y2,color = "red")
ix = 150
x2 = Array{F}(zlistFloat[[ix],1])
y2 = Array{F}(zlistFloat[[ix],2])
scatter!(x2,y2,color = "cyan")




LoadError: [91mBoundsError: attempt to access 149×2 Array{Float64,2} at index [[150], 1][39m

In [None]:
nθ = length(thet)
θ, ϕ = w[1:nθ], w[nθ+1:end]
z2out = decode(θ, zlist[113])
z2out = reshape(z2out,10002,3)
z2out=Array{Float64,2}(z2out)
writedlm("z2out1.xyz", z2out)
z2out = decode(θ, zlist[97])
z2out = reshape(z2out,10002,3)
z2out=Array{Float64,2}(z2out)
writedlm("z2out2.xyz", z2out)

###INBETWEEN
d = zlist[113]-zlist[97]
z2out = decode(θ, zlist[97]+d/5)
z2out = reshape(z2out,10002,3)
z2out=Array{Float64,2}(z2out)
writedlm("nearest1.xyz", z2out)
z2out = decode(θ, zlist[97]+(2*d)/5)
z2out = reshape(z2out,10002,3)
z2out=Array{Float64,2}(z2out)
writedlm("nearest2.xyz", z2out)
z2out = decode(θ, zlist[97]+(3*d)/5)
z2out = reshape(z2out,10002,3)
z2out=Array{Float64,2}(z2out)
writedlm("nearest3.xyz", z2out)
z2out = decode(θ, zlist[97]+(4*d)/5)
z2out = reshape(z2out,10002,3)
z2out=Array{Float64,2}(z2out)
writedlm("nearest4.xyz", z2out)

In [7]:
#p1, p2, randomout, w, sampleout, samplein, ls_list, lsTest_list, zlist

save("model_kl_jumping_z2embedding_100_500epoch_-5lr_tanhout.jld", "w", w, "zlist", zlist, "ls_list", ls_list)


In [None]:
nn(reshape(z2out,30006,1), dataset)
nearestdata1 = reshape(dataset[13],10002,3)
nearestdata1 = Array{Float64,2}(nearestdata1)
writedlm("nearestdata1.xyz", nearestdata1)
