In [None]:
] up MultiScaleOT

In [None]:
# HDF5 supplies `h5open`/`read`, used below to load the input images
using HDF5
using Plots
using MultiScaleOT
# Plots defaults: no legend, equal aspect ratio
default(legend = :none, aspect_ratio = :equal)
# SparseArrays supplies the `SparseMatrixCSC` type used by the interpolation code
using SparseArrays

In [None]:
# Read both images and their stored shapes from the HDF5 file.
# The do-block form closes the file once the datasets have been read.
img1, img2, shape1, shape2 = h5open("data/square_diamond.hdf5", "r") do file
    map(name -> read(file, name), ("img1", "img2", "shape1", "shape2"))
end

# Show the two images side by side
plot(heatmap(img1), heatmap(img2), size = (600, 300))

In [None]:
# Turn the two images into grid measures.

# Grid shapes, taken from the image arrays
shapeX = size(img1)
shapeY = size(img2)

# Support points: a 128 x 128 grid of integer coordinates, shared by both measures
x1 = collect(1:128)
X = flat_grid(x1, x1)
Y = copy(X)

# Weights: flattened pixel values shifted by 1e-8, then normalized in place
mu1 = vec(img1) .+ 1e-8
normalize!(mu1)
mu2 = vec(img2) .+ 1e-8
normalize!(mu2)

# Measure structs bundling support points, weights and grid shape
mu = GridMeasure(X, mu1, shapeX)
nu = GridMeasure(Y, mu2, shapeY)

In [None]:
# Configuration for the solver and the multiscale scheme

# Multiscale depth computed from the source measure
depth = compute_multiscale_depth(mu)

# Cost function handed to the solver
c(x, y) = l22(x, y)

# Inputs of the epsilon schedule
eps_target = 0.5
factor = 2.0
Nsteps = 3
last_iter = [eps_target / 2]

# Value passed to the solver as `solver_truncation`
truncation = 1e-15

# Epsilon scaling and the matching layer schedule
eps_schedule = scaling_schedule(
    depth, eps_target, Nsteps, factor; last_iter = last_iter
)
layer_schedule = template_schedule(
    depth, Nsteps, collect(1:depth); last_iter = [depth]
)

# Combined schedule consumed by the hierarchical solver
params_schedule = make_schedule(;
    layer = layer_schedule,
    epsilon = eps_schedule,
    solver_truncation = truncation,
    solver_max_error = 1e-4,
    solver_verbose = true,
    solver_max_iter = 10000,
);

In [None]:
# Run the hierarchical Sinkhorn solver on (mu, nu) with cost `c` and time the call.
# Note: the first execution is considerably slower than subsequent ones,
# because much of it is spent in precompilation.
K, a, b, status = @time hierarchical_sinkhorn(mu, nu, c, params_schedule, 2)

In [None]:
# Visualize displacement interpolation
# TODO: Quite hacky code right now
"""
    displacement_interpolation(P::SparseMatrixCSC, X, Y, shapeX, t)

Rasterize the displacement interpolation of the coupling `P` at time `t` onto a
dense array of size `shapeX`.

Every stored entry `P[i, j]` is treated as a particle of mass `P[i, j]` located at
`(1 - t) * X[:, i] + t * Y[:, j]` (only the first two coordinates are read). Its mass
is split bilinearly among the four cells of the output surrounding that location,
which gives a little antialiasing compared to dumping it into a single cell.

# Arguments
- `P`: sparse matrix; its row index selects a column of `X`, its column index a
  column of `Y`.
- `X`, `Y`: point coordinates, one point per column.
- `shapeX`: dimensions of the output array (splatted into `zeros`).
- `t`: interpolation time, must satisfy `0 ≤ t ≤ 1`.

# Returns
A `Float64` array of size `shapeX` holding the accumulated mass.

# Throws
- `ArgumentError` if `t` lies outside `[0, 1]`.

Interpolated coordinates are used directly as 1-based indices into the output, so
they are assumed to lie inside the grid; otherwise indexing raises a `BoundsError`.
"""
function displacement_interpolation(P::SparseMatrixCSC, X, Y, shapeX, t)
    (0 ≤ t ≤ 1) || throw(ArgumentError("t must be in [0,1]"))
    Z = zeros(shapeX...)
    rows = rowvals(P)
    masses = nonzeros(P)
    for j in 1:size(P, 2)
        for r in nzrange(P, j)
            i = rows[r]
            mass = masses[r]
            # Position of this particle at time t
            k1 = (1 - t) * X[1, i] + t * Y[1, j]
            k2 = (1 - t) * X[2, i] + t * Y[2, j]
            # Tiny affine nudge: a coordinate of 1 becomes 1 + 1e-8 and a coordinate
            # n > 2 ends up just below n, so that `base + 1` below stays a valid
            # index for positions on the upper edge of the grid.
            k1 = 2e-8 + (1 - 1e-8) * k1
            k2 = 2e-8 + (1 - 1e-8) * k2
            k1_base = floor(Int, k1)
            k2_base = floor(Int, k2)
            # Fractional parts are the bilinear weights of the "upper" neighbours
            offset1 = k1 - k1_base
            offset2 = k2 - k2_base

            Z[k1_base, k2_base] += (1 - offset1) * (1 - offset2) * mass
            Z[k1_base, k2_base + 1] += (1 - offset1) * offset2 * mass
            Z[k1_base + 1, k2_base] += offset1 * (1 - offset2) * mass
            Z[k1_base + 1, k2_base + 1] += offset1 * offset2 * mass
        end
    end
    return Z
end

In [None]:
# Render one heatmap per interpolation time and show them in a single row
plots = []
foreach(0:0.2:1) do t
    frame = heatmap(displacement_interpolation(K, X, Y, shape1, t), axis = :off)
    # Remove the tick marks from the plot that was just created
    xticks!(Int[])
    yticks!(Int[])
    push!(plots, frame)
end
plot(plots..., layout = (1, 6), size = (900, 150))