# Differential Programming for Graph embeding problem

In [11]:
using Zygote
using Statistics: var, mean
using LinearAlgebra: norm
using Flux.NNlib: relu
using Flux.Optimise

# Define the variance

In [2]:
function myvar(v)
    mv  = mean(v)
    sum((v .- mv).^2)./(length(v)-1)
end

myvar (generic function with 1 method)

# Define the bonds of a Peterson graph

In [12]:
L1 = [(1,6), (2,7), (3,8), (4,9), (5,10), (1,2), (2,3), (3,4), (4,5), (5,1), (6,8), (8,10), (10,7), (7,9), (9,6)]
L1 = [i<j ? (i,j) : (j,i) for (i,j) in L1]

15-element Array{Tuple{Int64,Int64},1}:
 (1, 6) 
 (2, 7) 
 (3, 8) 
 (4, 9) 
 (5, 10)
 (1, 2) 
 (2, 3) 
 (3, 4) 
 (4, 5) 
 (1, 5) 
 (6, 8) 
 (8, 10)
 (7, 10)
 (7, 9) 
 (6, 9) 

# Disconnected bonds

In [13]:
LL = Any[]
for i=1:9
    for j=i+1:10
        push!(LL, (i,j))
    end
end

In [14]:
L2 = setdiff(LL, L1)

30-element Array{Any,1}:
 (1, 3) 
 (1, 4) 
 (1, 7) 
 (1, 8) 
 (1, 9) 
 (1, 10)
 (2, 4) 
 (2, 5) 
 (2, 6) 
 (2, 8) 
 (2, 9) 
 (2, 10)
 (3, 5) 
 ⋮      
 (4, 7) 
 (4, 8) 
 (4, 10)
 (5, 6) 
 (5, 7) 
 (5, 8) 
 (5, 9) 
 (6, 7) 
 (6, 10)
 (7, 8) 
 (8, 9) 
 (9, 10)

# Loss function

In [15]:
function loss(x)
    a = [norm(x[:,i]-x[:,j]) for (i, j) in L1]
    b = [norm(x[:,i]-x[:,j]) for (i, j) in L2]
    myvar(a) + myvar(b) + exp(relu(-mean(b) + mean(a) + 0.1))
end

loss (generic function with 1 method)

# The training

In [16]:
function train(params)
    opt = ADAM(0.01)
    maxiter = 2000
    msk = fill(true, size(params, 2))
    msk[[1, 2]] .= false
    pp = params[:,msk]
    for i=1:maxiter
        grad = view(loss'(params), :,msk)
        Optimise.update!(opt, pp, grad)
        view(params, :, msk) .= pp
        if i%100 == 0
            @show loss(params)
        end
    end
    params
end

train (generic function with 1 method)

In [19]:
params = randn(5, 10)
params = train(params)

loss(params) = 1.6687414155405986
loss(params) = 1.4157332623786845
loss(params) = 1.2533929463104143
loss(params) = 1.1396410068075318
loss(params) = 1.0477527911104751
loss(params) = 1.0139284709778715
loss(params) = 1.0051470918195975
loss(params) = 1.0017648391030218
loss(params) = 1.0005415926275132
loss(params) = 1.0001474502510763
loss(params) = 1.0000355434761794
loss(params) = 1.0000075735386393
loss(params) = 1.0000014222353657
loss(params) = 1.0000002343365628
loss(params) = 1.00000003368238
loss(params) = 1.0000000041945414
loss(params) = 1.0000000004490492
loss(params) = 1.000000000040968
loss(params) = 1.0000000000031546
loss(params) = 1.000000000000203


5×10 Array{Float64,2}:
 -0.966946  -1.61448   -1.04189   …   0.886009  -0.581735   1.54834  
 -0.194974  -0.67774   -1.23952      -1.99196   -0.642991  -0.0601035
 -0.158067   0.306378  -1.67698      -1.15769    0.259446  -0.0737423
  3.12774    1.42253    0.53096       1.33       0.333616   1.08056  
  1.52957    0.210582   0.589345      1.15895    3.1729     0.830863 

# Check Results

In [20]:
[norm(params[:,i]-params[:,j]) for (i,j) in L1]

15-element Array{Float64,1}:
 2.3485171583891953
 2.34851724011789  
 2.3485172751265435
 2.348517316152837 
 2.3485173741620464
 2.3485185579984065
 2.3485170233215147
 2.348517345362473 
 2.3485172970957717
 2.3485170780845173
 2.3485173406927196
 2.348517329213868 
 2.3485173088303415
 2.348517377963753 
 2.3485173715689505

In [21]:
[norm(params[:,i]-params[:,j]) for (i,j) in L2]

30-element Array{Float64,1}:
 3.321305506187664 
 3.321305269743166 
 3.321304346986394 
 3.321304943048185 
 3.3213049873051133
 3.321304509836133 
 3.32130473523949  
 3.321305436276903 
 3.32130540170377  
 3.321305242708751 
 3.321304876172061 
 3.3213055012778447
 3.321304845441519 
 ⋮                 
 3.321305399174512 
 3.3213051954153916
 3.3213051251865466
 3.321305326984935 
 3.3213049856062824
 3.3213051915654224
 3.3213051859553624
 3.3213049039932123
 3.3213052787609407
 3.321305038880893 
 3.321305068274501 
 3.321305056728446 