In [1]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch
import geomloss

import defmod as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [2]:
# Loading the datasets
data = pickle.load(open('../data/data_acropetal.pkl', 'rb'))

pos_source = torch.tensor(data['source_silent']).type(torch.get_default_dtype())
pos_implicit0 = torch.tensor(data['source_implicit0']).type(torch.get_default_dtype())
pos_implicit1 = torch.tensor(data['source_implicit1']).type(torch.get_default_dtype())
pos_target = torch.tensor(data['target_silent']).type(torch.get_default_dtype())

# Some rescaling for the source
Dx = 0.
Dy = 0.
height_source = 90.
height_target = 495.

smin, smax = torch.min(pos_source[:, 1]), torch.max(pos_source[:, 1])
sscale = height_source / (smax - smin)
pos_source[:, 0] = Dx + sscale * (pos_source[:, 0] - torch.mean(pos_source[:, 0]))
pos_source[:, 1] = Dy - sscale * (pos_source[:, 1] - smax)
pos_implicit0[:, 0] = Dx + sscale * (pos_implicit0[:, 0] - torch.mean(pos_implicit0[:, 0]))
pos_implicit0[:, 1] = Dy - sscale * (pos_implicit0[:, 1] - smax)
pos_implicit1[:, 0] = Dx + sscale * (pos_implicit1[:, 0] - torch.mean(pos_implicit1[:, 0]))
pos_implicit1[:, 1] = Dy - sscale * (pos_implicit1[:, 1] - smax)

# Some rescaling for the target
tmin, tmax = torch.min(pos_target[:, 1]), torch.max(pos_target[:, 1])
tscale = height_target / (tmax - tmin)
pos_target[:, 0] = tscale * (pos_target[:, 0] - torch.mean(pos_target[:, 0]))
pos_target[:, 1] = - tscale * (pos_target[:, 1] - tmax)

# Compute an AABB for plotting
aabb = dm.usefulfunctions.AABB.build_from_points(pos_target)
aabb.squared()

In [3]:
# Some plots
%matplotlib qt5

plt.subplot(2, 2, 1)
plt.axis(aabb.get_list())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')

plt.subplot(2, 2, 2)
plt.axis(aabb.get_list())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')

plt.subplot(2, 2, 3)
plt.imshow(data['source_img'])

plt.subplot(2, 2, 4)
plt.imshow(data['target_img'])

plt.show()

In [38]:
# Setting up the modules
# Local translation module
sigma0 = 10.
nu0 = 0.001
coeff0 = 100.
implicit0 = dm.implicitmodules.ImplicitModule0(dm.manifold.Landmarks(2, pos_implicit0.shape[0], gd=pos_implicit0.view(-1).requires_grad_()), sigma0, nu0, coeff0)

# Global translation module
sigma00 = 800.
nu00 = 0.001
coeff00 = 0.01
implicit00 = dm.implicitmodules.ImplicitModule0(dm.manifold.Landmarks(2, 1, gd=torch.tensor([0., 0.], requires_grad=True)), sigma00, nu00, coeff00)

# Elastic modules
sigma1 = 100.
nu1 = 0.001
coeff1 = 0.000001
C = torch.zeros(pos_implicit1.shape[0], 2, 1)
K, L = 10, height_source
a, b = 1./L, 3.
z = a*(pos_implicit1[:, 1] - Dy)
C[:, 1, 0] = K * ((1 - b) * z**2 + b * z)
C[:, 0, 0] = 0.8 * C[:, 1, 0]
th = 0. * math.pi * torch.ones(pos_implicit1.shape[0])
R = torch.stack([dm.usefulfunctions.rot2d(t) for t in th])

implicit1 = dm.implicitmodules.ImplicitModule1(dm.manifold.Stiefel(2, pos_implicit1.shape[0], gd=(pos_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_())), C, sigma1, nu1, coeff1)

In [39]:
# Setting up the model and start the fitting loop
model = dm.models.ModelCompoundWithPointsRegistration((pos_source, torch.ones(pos_source.shape[0])), [implicit0, implicit00, implicit1], [True, True, True])
costs = model.fit((pos_target, torch.ones(pos_target.shape[0])), max_iter=50, l=5e-3, lr=5e-2, log_interval=1)

It: 0, deformation cost: 0.000000, attach: 148.952556. Total cost: 148.952556


It: 1, deformation cost: 18.054090, attach: 336.978330. Total cost: 355.032421


It: 2, deformation cost: 17.485827, attach: 330.795265. Total cost: 348.281092


It: 3, deformation cost: 16.880237, attach: 324.075784. Total cost: 340.956021


It: 4, deformation cost: 16.240148, attach: 316.800997. Total cost: 333.041145


It: 5, deformation cost: 15.565832, attach: 308.917267. Total cost: 324.483099


It: 6, deformation cost: 14.853489, attach: 300.312025. Total cost: 315.165513


It: 7, deformation cost: 14.094974, attach: 290.794804. Total cost: 304.889778


It: 8, deformation cost: 13.273222, attach: 280.017656. Total cost: 293.290878


It: 9, deformation cost: 12.362486, attach: 267.421076. Total cost: 279.783562


It: 10, deformation cost: 11.336549, attach: 252.238882. Total cost: 263.575431


It: 11, deformation cost: 10.182624, attach: 233.566769. Total cost: 243.749393


It: 12, deformation cost: 8.936055, attach: 210.782908. Total cost: 219.718964


It: 13, deformation cost: 7.675743, attach: 183.571871. Total cost: 191.247614


It: 14, deformation cost: 6.455208, attach: 150.019703. Total cost: 156.474911


It: 15, deformation cost: 5.430544, attach: 106.412374. Total cost: 111.842918


It: 16, deformation cost: 5.200290, attach: 54.649044. Total cost: 59.849334


It: 17, deformation cost: 5.774230, attach: 21.679046. Total cost: 27.453276


It: 18, deformation cost: 5.729337, attach: 20.355007. Total cost: 26.084344


It: 19, deformation cost: 5.360122, attach: 14.683259. Total cost: 20.043381


It: 20, deformation cost: 4.612528, attach: 7.876841. Total cost: 12.489369


It: 21, deformation cost: 4.437826, attach: 6.771999. Total cost: 11.209826


It: 22, deformation cost: 4.318225, attach: 6.076823. Total cost: 10.395048


It: 23, deformation cost: 4.193478, attach: 5.389852. Total cost: 9.583330


It: 24, deformation cost: 4.081580, attach: 4.803571. Total cost: 8.885151


It: 25, deformation cost: 3.980787, attach: 4.297488. Total cost: 8.278275


It: 26, deformation cost: 3.889286, attach: 3.854709. Total cost: 7.743995


It: 27, deformation cost: 3.806128, attach: 3.466070. Total cost: 7.272197


It: 28, deformation cost: 3.730339, attach: 3.123810. Total cost: 6.854149


It: 29, deformation cost: 3.661065, attach: 2.821839. Total cost: 6.482904


It: 30, deformation cost: 3.597528, attach: 2.555249. Total cost: 6.152777


It: 31, deformation cost: 3.538961, attach: 2.319942. Total cost: 5.858903


It: 32, deformation cost: 3.484595, attach: 2.112550. Total cost: 5.597145


It: 33, deformation cost: 3.433874, attach: 1.930634. Total cost: 5.364508


It: 34, deformation cost: 3.386709, attach: 1.772459. Total cost: 5.159169


It: 35, deformation cost: 3.343130, attach: 1.635871. Total cost: 4.979000


It: 36, deformation cost: 3.302947, attach: 1.518249. Total cost: 4.821196


It: 37, deformation cost: 3.265875, attach: 1.417275. Total cost: 4.683151


It: 38, deformation cost: 3.231635, attach: 1.331057. Total cost: 4.562692


It: 39, deformation cost: 3.199957, attach: 1.257973. Total cost: 4.457929


It: 40, deformation cost: 3.170553, attach: 1.196559. Total cost: 4.367113


It: 41, deformation cost: 3.143026, attach: 1.145445. Total cost: 4.288471


It: 42, deformation cost: 3.116676, attach: 1.103346. Total cost: 4.220023


It: 43, deformation cost: 3.090730, attach: 1.069459. Total cost: 4.160188


It: 44, deformation cost: 3.065793, attach: 1.043929. Total cost: 4.109722


It: 45, deformation cost: 3.042896, attach: 1.026010. Total cost: 4.068906


It: 46, deformation cost: 3.021913, attach: 1.014041. Total cost: 4.035954


It: 47, deformation cost: 3.002540, attach: 1.007070. Total cost: 4.009611


It: 48, deformation cost: 2.984586, attach: 1.004362. Total cost: 3.988948


It: 49, deformation cost: 2.967912, attach: 1.005326. Total cost: 3.973238


It: 50, deformation cost: 2.952404, attach: 1.009450. Total cost: 3.961854
End of the optimisation process.


In [40]:
# Results
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.axis(aabb.get_list())
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')

plt.subplot(1, 3, 2)
plt.axis(aabb.get_list())
out = model.shot_manifold[0].gd.view(-1, 2).detach().numpy()
shot_implicit0 = model.shot_manifold[1].gd.view(-1, 2).detach().numpy()
shot_implicit00 = model.shot_manifold[2].gd.view(-1, 2).detach().numpy()
shot_implicit1 = model.shot_manifold[3].gd[0].view(-1, 2).detach().numpy()
plt.plot(out[:, 0], out[:, 1], '-')
plt.plot(shot_implicit0[:, 0], shot_implicit0[:, 1], 'x')
plt.plot(shot_implicit00[:, 0], shot_implicit00[:, 1], 'o')
plt.plot(shot_implicit1[:, 0], shot_implicit1[:, 1], '.')

plt.subplot(1, 3, 3)
plt.axis(aabb.get_list())
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.plot(out[:, 0], out[:, 1], '-')
plt.show()

In [None]:
compound = dm.deformationmodules.CompoundModule(model.modules)
compound.manifold.fill(model.init_manifold.copy())
#compound.manifold.fill_cotan(model.shot_manifold.cotan)

itt = 4
out = dm.shooting.shoot(dm.hamiltonian.Hamiltonian(compound), it=itt)

for i in range(itt):
    plt.figure()
    plt.plot(out[i].gd[0].view(-1, 2)[:, 0].detach().numpy(), out[i].gd[0].view(-1, 2)[:, 1].detach().numpy(), '-')
    plt.plot(out[i].gd[1].view(-1, 2)[:, 0].detach().numpy(), out[i].gd[1].view(-1, 2)[:, 1].detach().numpy(), '.')
    plt.plot(out[i][2].gd.view(-1, 2)[:, 0].detach().numpy(), out[i][2].gd.view(-1, 2)[:, 1].detach().numpy(), 'o')
    plt.plot(out[i][3].gd[0].view(-1, 2)[:, 0].detach().numpy(), out[i][3].gd[0].view(-1, 2)[:, 1].detach().numpy(), 'x')
    plt.show()




In [20]:
# Evolution of the cost with iterations
plt.title("Cost")
plt.xlabel("Iteration(s)")
plt.ylabel("Cost")
plt.plot(range(len(costs)), costs)
plt.show()