In [1]:
using Zygote
using Statistics: var, mean
using LinearAlgebra: norm
using Flux.NNlib: relu

In [2]:
function myvar(v)
    mv  = mean(v)
    sum((v .- mv).^2)./(length(v)-1)
end

myvar (generic function with 1 method)

In [3]:
l1_pairs = [(1,5), (2,6), (3,7), (4,8)]
l2_pairs = [(1,2), (2,3), (3,4), (4,1), (5,6), (6,7), (7,8), (8,5),
            (1,6), (1, 8), (2,7), (2,5), (3,6), (3,8), (4,7),(4,5)]
l3_pairs = [(1,3), (2,4), (1,7), (2,8), (3,5), (4,6), (5,7), (8,6)]

function loss(x)
    a = [norm(x[:,i]-x[:,j]) for (i, j) in l1_pairs]
    b = [norm(x[:,i]-x[:,j]) for (i, j) in l2_pairs]
    c = [norm(x[:,i]-x[:,j]) for (i, j) in l3_pairs]
    myvar(a) + myvar(b) + myvar(c) + exp(relu(-mean(b) + mean(a) + 0.1)) + exp(relu(-mean(c) + mean(b) + 0.1))
end

loss (generic function with 1 method)

In [4]:
loss(randn(2, 8))

4.459480414841205

In [5]:
loss'(randn(4, 8))

4×8 Array{Float64,2}:
  0.550572   0.0913925  0.372524   …  -0.728186  -0.381611   0.295862  
 -0.791831   0.186491   0.536327       0.420596  -0.900933  -0.24964   
  0.870882   0.854004   0.610873      -0.916425  -0.542903  -0.347213  
 -0.37948   -0.570891   0.0902415      0.486969   0.139332   0.00379381

In [6]:
using Flux.Optimise

function train(params)
    opt = ADAM(0.01)
    V = 8
    maxiter = 20000
    msk = [false, true, true, true, false, true, true, true]
    pp = params[:,msk]
    for i=1:maxiter
        grad = view(loss'(params), :,msk)
        Optimise.update!(opt, pp, grad)
        view(params, :, msk) .= pp
        if i%100 == 0
            @show loss(params)
        end
    end
    params
end

train (generic function with 1 method)

In [7]:
params = train(randn(6, 8))

loss(params) = 2.2098784366821973
loss(params) = 2.1361277236835363
loss(params) = 2.1207429569040683
loss(params) = 2.10302242130852
loss(params) = 2.089805767862301
loss(params) = 2.0817529871869294
loss(params) = 2.0752477541724312
loss(params) = 2.069833418679041
loss(params) = 2.064625078613525
loss(params) = 2.0599533989895598
loss(params) = 2.05603000762864
loss(params) = 2.0522464554171584
loss(params) = 2.0488886574188045
loss(params) = 2.0460432814626195
loss(params) = 2.0432679039827137
loss(params) = 2.040749667354423
loss(params) = 2.038923532407595
loss(params) = 2.036633353410319
loss(params) = 2.034726793229125
loss(params) = 2.032974236793671
loss(params) = 2.0313581986192664
loss(params) = 2.029864066358438
loss(params) = 2.0287251951596073
loss(params) = 2.0274506495916
loss(params) = 2.026265966498608
loss(params) = 2.0251577378083523
loss(params) = 2.0241188564367776
loss(params) = 2.0231431750063833
loss(params) = 2.0222252433274504
loss(params) = 2.02136020532919

6×8 Array{Float64,2}:
 -0.820814    17.8921    11.6629   …  16.374      11.1099     5.92653 
  0.874955    -1.11079  -12.1842      -0.228859  -12.2531   -10.2395  
 -0.963763     2.64763    5.60676      4.62266     3.29948   12.8898  
  0.00403347  -5.05025   15.87        -4.46967    15.9292    -3.63922 
  0.581611     5.23987    5.40326      4.67179     7.24354   17.667   
 -0.413453    15.6201     7.64469  …  16.9887      8.34591    0.796128

In [8]:
params

UndefVarError: UndefVarError: params not defined

In [9]:
loss(params)

UndefVarError: UndefVarError: params not defined

In [10]:
[norm(params[:,i]-params[:,j]) for (i,j) in l1_pairs]

UndefVarError: UndefVarError: params not defined

In [11]:
[norm(params[:,i]-params[:,j]) for (i,j) in l2_pairs]

UndefVarError: UndefVarError: params not defined

In [12]:
[norm(params[:,i]-params[:,j]) for (i,j) in l3_pairs]

UndefVarError: UndefVarError: params not defined

In [13]:
Zygote.refresh()

In [14]:
norm(params[:,1]-params[:,3])

UndefVarError: UndefVarError: params not defined

In [15]:
L1 = [(1,6), (2,7), (3,8), (4,9), (5,10), (1,2), (2,3), (3,4), (4,5), (5,1), (6,8), (8,10), (10,7), (7,9), (9,6)]
L1 = [i<j ? (i,j) : (j,i) for (i,j) in L1]

15-element Array{Tuple{Int64,Int64},1}:
 (1, 6) 
 (2, 7) 
 (3, 8) 
 (4, 9) 
 (5, 10)
 (1, 2) 
 (2, 3) 
 (3, 4) 
 (4, 5) 
 (1, 5) 
 (6, 8) 
 (8, 10)
 (7, 10)
 (7, 9) 
 (6, 9) 

In [16]:
LL = Any[]
for i=1:9
    for j=i+1:10
        push!(LL, (i,j))
    end
end

In [225]:
L2 = setdiff(LL, L1)

30-element Array{Any,1}:
 (1, 3) 
 (1, 4) 
 (1, 7) 
 (1, 8) 
 (1, 9) 
 (1, 10)
 (2, 4) 
 (2, 5) 
 (2, 6) 
 (2, 8) 
 (2, 9) 
 (2, 10)
 (3, 5) 
 ⋮      
 (4, 7) 
 (4, 8) 
 (4, 10)
 (5, 6) 
 (5, 7) 
 (5, 8) 
 (5, 9) 
 (6, 7) 
 (6, 10)
 (7, 8) 
 (8, 9) 
 (9, 10)

In [226]:
function loss2(x)
    a = [norm(x[:,i]-x[:,j]) for (i, j) in L1]
    b = [norm(x[:,i]-x[:,j]) for (i, j) in L2]
    myvar(a) + myvar(b) + exp(relu(-mean(b) + mean(a) + 0.1))
end

loss2 (generic function with 1 method)

In [250]:
function train(params)
    opt = ADAM(0.001)
    maxiter = 20000
    msk = fill(true, size(params, 2))
    msk[[1, 2]] .= false
    pp = params[:,msk]
    for i=1:maxiter
        grad = view(loss2'(params), :,msk)
        Optimise.update!(opt, pp, grad)
        view(params, :, msk) .= pp
        if i%100 == 0
            @show loss2(params)
        end
    end
    params
end

train (generic function with 3 methods)

In [251]:
params = randn(7, 10)
params = train(params)

loss2(params) = 3.331592321526906
loss2(params) = 2.914081404912695
loss2(params) = 2.5771553258750526
loss2(params) = 2.2962600851510033
loss2(params) = 2.0513004460282156
loss2(params) = 1.8651966586844886
loss2(params) = 1.7363954653318507
loss2(params) = 1.6414455335144589
loss2(params) = 1.5701551343111244
loss2(params) = 1.51586255741702
loss2(params) = 1.4737555860964107
loss2(params) = 1.4402490862117032
loss2(params) = 1.4126079444400914
loss2(params) = 1.3888145874269247
loss2(params) = 1.3675080586629358
loss2(params) = 1.3479736283615706
loss2(params) = 1.3301023934079441
loss2(params) = 1.314223764697566
loss2(params) = 1.3007786537900932
loss2(params) = 1.2899443717534265
loss2(params) = 1.281582851336879
loss2(params) = 1.2753358359108495
loss2(params) = 1.2707258324147752
loss2(params) = 1.267303280432197
loss2(params) = 1.264701221072482
loss2(params) = 1.262644243719694
loss2(params) = 1.2609349055398853
loss2(params) = 1.2594339066959979
loss2(params) = 1.25804203612

7×10 Array{Float64,2}:
  0.612515    1.41179     1.40957     …  -2.61222    1.6019    -2.47632 
 -1.61431     0.889491    1.0711          2.14935    3.8988     1.6821  
  0.420388    0.0693702  -3.31453        -3.32206    0.934447   1.43253 
  0.0867253   1.71063     0.00994049     -0.956393  -4.12743   -0.917385
 -0.371805   -0.416572   -0.404943        1.03215   -1.23915    3.17223 
  1.51972    -2.76443    -3.30082     …  -0.51173   -0.124884  -1.1929  
  0.933705    1.25518    -2.41411        -2.30603   -0.909133  -2.79326 

In [252]:
[norm(params[:,i]-params[:,j]) for (i,j) in L1]

15-element Array{Float64,1}:
 5.303518746113095 
 5.303499019277379 
 5.303511697455446 
 5.303521236582463 
 5.303400765791085 
 5.3035303418894175
 5.303545511130144 
 5.303558349676117 
 5.303425089395428 
 5.303471641705359 
 5.30353001564777  
 5.303377892744704 
 5.303380568680792 
 5.303434623657267 
 5.303529937047624 

In [253]:
[norm(params[:,i]-params[:,j]) for (i,j) in L2]

30-element Array{Float64,1}:
 7.499917253694322 
 7.500123007765547 
 7.499990560287559 
 7.500122285184035 
 7.499975832400816 
 7.499908918960845 
 7.499972037050174 
 7.499959264708391 
 7.499971899667681 
 7.500027847644169 
 7.500096946334635 
 7.499956524920089 
 7.500045680000539 
 ⋮                 
 7.500108072852169 
 7.500107793811843 
 7.499982343313271 
 7.499923479959791 
 7.499876970472376 
 7.500034017392995 
 7.500138103250206 
 7.50001091903994  
 7.500117648966705 
 7.500043848488527 
 7.5001658036914876
 7.499925531651715 

In [264]:
params = randn(5, 10)
params = train(params)

loss2(params) = 2.228589387426616
loss2(params) = 1.981818726173078
loss2(params) = 1.8003527029950928
loss2(params) = 1.6667746699250827
loss2(params) = 1.570061096751739
loss2(params) = 1.4997614651437283
loss2(params) = 1.4474798883460094
loss2(params) = 1.4079180142375904
loss2(params) = 1.3778243227861693
loss2(params) = 1.354568405419604
loss2(params) = 1.336191793627682
loss2(params) = 1.3210167787312175
loss2(params) = 1.3081224767954636
loss2(params) = 1.2968748804179038
loss2(params) = 1.2869138893840306
loss2(params) = 1.2780144051001656
loss2(params) = 1.2699620010484538
loss2(params) = 1.262591665406957
loss2(params) = 1.2557021804336543
loss2(params) = 1.2490563136087716
loss2(params) = 1.2423372319628452
loss2(params) = 1.2351519427887716
loss2(params) = 1.2272269149131698
loss2(params) = 1.2187715284140526
loss2(params) = 1.2103913271339117
loss2(params) = 1.2025336231178207
loss2(params) = 1.195301397187838
loss2(params) = 1.1885958072115814
loss2(params) = 1.182239850

InterruptException: InterruptException:

In [256]:
[norm(params[:,i]-params[:,j]) for (i,j) in L1]

15-element Array{Float64,1}:
 1.9237660755688748
 1.9237597720101194
 1.9237621252149806
 1.9237644941697498
 1.9237593226791587
 1.9237622743491238
 1.9237619054602841
 1.9237645886320545
 1.9237633737067432
 1.9237631135099424
 1.9237680307163536
 1.9237580482839558
 1.9237537486356715
 1.9237581973749713
 1.923769340117923 

In [257]:
[norm(params[:,i]-params[:,j]) for (i,j) in L2]

30-element Array{Float64,1}:
 2.720611320589652 
 2.7206124195570167
 2.7206068474653455
 2.7206113672247207
 2.7206122679173617
 2.720607362945088 
 2.720612348929594 
 2.7206107321417194
 2.720615649407073 
 2.720610516253029 
 2.7206106868107023
 2.7206060967370416
 2.7206113536682626
 ⋮                 
 2.7206088673349256
 2.7206126024398456
 2.7206076018335317
 2.720617480108722 
 2.7206060152066165
 2.7206102477282594
 2.720611464344555 
 2.720611936464884 
 2.720612560644666 
 2.7206053636081284
 2.7206129200979117
 2.720605186196139 

In [262]:
params = randn(4, 10)
params = train(params)

loss2(params) = 2.076634365822623
loss2(params) = 1.775084423185447
loss2(params) = 1.5912087916970021
loss2(params) = 1.4907662924733356
loss2(params) = 1.430179051717222
loss2(params) = 1.391435035062689
loss2(params) = 1.3642719706109647
loss2(params) = 1.3433396933034518
loss2(params) = 1.3258958426233514
loss2(params) = 1.310506956275508
loss2(params) = 1.2964212576304448
loss2(params) = 1.2832642614835907
loss2(params) = 1.270863989205448
loss2(params) = 1.259151664965692
loss2(params) = 1.24811307604911
loss2(params) = 1.237761560456945
loss2(params) = 1.228121160557785
loss2(params) = 1.2192148296305327
loss2(params) = 1.2110556550826925
loss2(params) = 1.2036410969748617
loss2(params) = 1.1969508919284904
loss2(params) = 1.1909485162483766
loss2(params) = 1.1855850800692853
loss2(params) = 1.1808041990638565
loss2(params) = 1.1765466940765645
loss2(params) = 1.1727544209436822
loss2(params) = 1.1693728970635253
loss2(params) = 1.1663526696290893
loss2(params) = 1.1636495943271

4×10 Array{Float64,2}:
 -1.22723   -0.754721  -0.170649  …  -0.209322    -1.2078     -1.41574  
  0.282991  -1.34115   -1.08024       0.00509969  -0.770441   -0.0904969
 -0.393245   1.0039     0.145425      0.706852    -0.0239067   1.02616  
  0.275059   0.102558   0.653934      0.552511    -0.799217    0.734231 

In [268]:
show(IOContext(STDOUT, limit=false), params)

MethodError: MethodError: no method matching IOContext(::IJulia.IJuliaStdio{Base.PipeEndpoint}; limit=false)
Closest candidates are:
  IOContext(::IO) at show.jl:189 got unsupported keyword argument "limit"
  IOContext(::IO, !Matched::Base.ImmutableDict) at show.jl:183 got unsupported keyword argument "limit"
  IOContext(::IO, !Matched::Pair) at show.jl:192 got unsupported keyword argument "limit"
  ...

In [270]:
for p in 1:10
    println(params[:,p])
end

[0.456042, 1.38215, -0.00241155, -0.74993, -0.442308]
[0.946446, 0.22905, -2.12314, 0.052197, 0.635222]
[0.3899, -0.390234, -1.06614, 2.50915, 0.47714]
[1.11132, 0.333807, -0.739465, 2.54774, -2.11501]
[2.35997, 0.768278, 0.79948, 0.716915, -1.47945]
[-1.23514, -0.541929, 0.362717, -0.234466, -1.39673]
[1.71574, -1.68961, -2.01312, -0.422448, -1.19826]
[-0.0526322, -1.87984, 0.837929, 1.19379, 0.158279]
[-0.0233147, -0.970986, -1.43134, 0.606254, -2.91135]
[2.40785, -1.87443, 0.582821, 0.203677, -0.720777]
