In [5]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch
import geomloss

import defmod as dm

torch.set_default_tensor_type(torch.FloatTensor)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Loading the datasets
source_raw = pickle.load(open('../data/basi2btemp.pkl', 'rb'))
target_raw = pickle.load(open('../data/basi2target.pkl', 'rb'))
data = pickle.load(open('../data/data_acropetal.pkl', 'rb'))

pos_source = torch.tensor(data['source_silent']).type(torch.get_default_dtype())
pos_implicit0 = torch.tensor(data['source_implicit0']).type(torch.get_default_dtype())
pos_implicit1 = torch.tensor(data['source_implicit1']).type(torch.get_default_dtype())
pos_target = torch.tensor(data['target_silent']).type(torch.get_default_dtype())

# Some rescaling for the source
Dx = 0.
Dy = 0.
height_source = 90.
height_target = 495.

smin, smax = torch.min(pos_source[:, 1]), torch.max(pos_source[:, 1])
sscale = height_source / (smax - smin)
pos_source[:, 0] = Dx + sscale * (pos_source[:, 0] - torch.mean(pos_source[:, 0]))
pos_source[:, 1] = Dy - sscale * (pos_source[:, 1] - smax)
pos_implicit0[:, 0] = Dx + sscale * (pos_implicit0[:, 0] - torch.mean(pos_implicit0[:, 0]))
pos_implicit0[:, 1] = Dy - sscale * (pos_implicit0[:, 1] - smax)
pos_implicit1[:, 0] = Dx + sscale * (pos_implicit1[:, 0] - torch.mean(pos_implicit1[:, 0]))
pos_implicit1[:, 1] = Dy - sscale * (pos_implicit1[:, 1] - smax)

# Some rescaling for the target
tmin, tmax = torch.min(pos_target[:, 1]), torch.max(pos_target[:, 1])
tscale = height_target / (tmax - tmin)
pos_target[:, 0] = tscale * (pos_target[:, 0] - torch.mean(pos_target[:, 0]))
pos_target[:, 1] = - tscale * (pos_target[:, 1] - tmax)

# Compute an AABB for plotting
aabb = dm.usefulfunctions.AABB.build_from_points(pos_target)
aabb.squared()

In [11]:
# Some plots
%matplotlib qt5

plt.subplot(1, 2, 1)
plt.axis(aabb.get_list())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.axis(aabb.get_list())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')

plt.show()

In [46]:
# Setting up the modules
# Local translation module
sigma0 = 10.
nu0 = 0.001
coeff0 = 100.
implicit0 = dm.implicitmodules.ImplicitModule0(dm.manifold.Landmarks(2, pos_implicit0.shape[0], gd=pos_implicit0.view(-1).requires_grad_()), sigma0, nu0, coeff0)

# Global translation module
sigma00 = 800.
nu00 = 0.001
coeff00 = 0.01
implicit00 = dm.implicitmodules.ImplicitModule0(dm.manifold.Landmarks(2, 1, gd=torch.tensor([0., 0.], requires_grad=True)), sigma00, nu00, coeff00)

# Elastic modules
sigma1 = 100.
nu1 = 0.001
coeff1 = 0.000001
C = torch.zeros(pos_implicit1.shape[0], 2, 1)
K, L = 10, height_source
a, b = 1./L, 3.
z = a*(pos_implicit1[:, 1] - Dy)
C[:, 1, 0] = K * ((1 - b) * z**2 + b * z)
C[:, 0, 0] = 0.8 * C[:, 1, 0]
th = 0. * math.pi * torch.ones(pos_implicit1.shape[0])
R = torch.stack([dm.usefulfunctions.rot2d(t) for t in th])

implicit1 = dm.implicitmodules.ImplicitModule1(dm.manifold.Stiefel(2, pos_implicit1.shape[0], gd=(pos_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_())), C, sigma1, nu1, coeff1)

In [47]:
# Setting up the model and start the fitting loop
model = dm.models.ModelCompoundWithPointsRegistration((pos_source, torch.ones(pos_source.shape[0])), [implicit0, implicit00, implicit1], [True, True, True])
costs = model.fit((pos_target, torch.ones(pos_target.shape[0])), max_iter=140, l=9e-4, lr=5e-2, log_interval=1)

It: 0, deformation cost: 0.000000, attach: 26.811464. Total cost: 26.811464


It: 1, deformation cost: 28.759995, attach: 95.446487. Total cost: 124.206482


It: 2, deformation cost: 23.678719, attach: 88.697556. Total cost: 112.376274


It: 3, deformation cost: 21.220154, attach: 85.055466. Total cost: 106.275620


It: 4, deformation cost: 11.638828, attach: 61.290871. Total cost: 72.929703


It: 5, deformation cost: 11.059816, attach: 58.409328. Total cost: 69.469147


It: 6, deformation cost: 10.954777, attach: 57.854439. Total cost: 68.809219


It: 7, deformation cost: 10.852327, attach: 57.353168. Total cost: 68.205498


It: 8, deformation cost: 10.736352, attach: 56.828594. Total cost: 67.564949


It: 9, deformation cost: 10.603662, attach: 56.272507. Total cost: 66.876167


It: 10, deformation cost: 10.450811, attach: 55.677162. Total cost: 66.127975


It: 11, deformation cost: 10.273535, attach: 55.032997. Total cost: 65.306534


It: 12, deformation cost: 10.065015, attach: 54.323086. Total cost: 64.388100


It: 13, deformation cost: 9.823883, attach: 53.551785. Total cost: 63.375668


It: 14, deformation cost: 9.530937, attach: 52.665855. Total cost: 62.196793


It: 15, deformation cost: 9.104757, attach: 51.441551. Total cost: 60.546310


It: 16, deformation cost: 8.671144, attach: 50.253517. Total cost: 58.924660


It: 17, deformation cost: 8.225063, attach: 49.077393. Total cost: 57.302456


It: 18, deformation cost: 7.493073, attach: 47.204880. Total cost: 54.697952


It: 19, deformation cost: 6.443037, attach: 44.567001. Total cost: 51.010036


It: 20, deformation cost: 4.644666, attach: 39.861893. Total cost: 44.506557


It: 21, deformation cost: 2.980214, attach: 34.901993. Total cost: 37.882206


It: 22, deformation cost: 0.823410, attach: 28.367334. Total cost: 29.190744


It: 23, deformation cost: 0.801046, attach: 28.301462. Total cost: 29.102509


It: 24, deformation cost: 0.774394, attach: 28.169554. Total cost: 28.943949


It: 25, deformation cost: 0.729331, attach: 27.878368. Total cost: 28.607700


It: 26, deformation cost: 0.088395, attach: 26.785191. Total cost: 26.873585


It: 27, deformation cost: 0.090999, attach: 26.591251. Total cost: 26.682251


It: 28, deformation cost: 0.094186, attach: 26.336218. Total cost: 26.430405


It: 29, deformation cost: 0.098084, attach: 26.037136. Total cost: 26.135220


It: 30, deformation cost: 0.103002, attach: 25.691730. Total cost: 25.794733


It: 31, deformation cost: 0.109252, attach: 25.296408. Total cost: 25.405659


It: 32, deformation cost: 0.117173, attach: 24.846756. Total cost: 24.963930


It: 33, deformation cost: 0.127556, attach: 24.323734. Total cost: 24.451290


It: 34, deformation cost: 0.139957, attach: 23.761169. Total cost: 23.901127


It: 35, deformation cost: 0.155436, attach: 23.122662. Total cost: 23.278097


It: 36, deformation cost: 0.174767, attach: 22.395168. Total cost: 22.569935


It: 37, deformation cost: 0.199022, attach: 21.560339. Total cost: 21.759361


It: 38, deformation cost: 0.234154, attach: 20.480059. Total cost: 20.714212


It: 39, deformation cost: 0.265937, attach: 19.554068. Total cost: 19.820004


It: 40, deformation cost: 0.319299, attach: 18.116732. Total cost: 18.436031


It: 41, deformation cost: 0.392563, attach: 16.338058. Total cost: 16.730621


It: 42, deformation cost: 0.491926, attach: 14.200340. Total cost: 14.692266


It: 43, deformation cost: 0.608903, attach: 12.036707. Total cost: 12.645610


It: 44, deformation cost: 0.724372, attach: 10.220663. Total cost: 10.945035


It: 45, deformation cost: 0.805518, attach: 9.072803. Total cost: 9.878322


It: 46, deformation cost: 0.869669, attach: 8.255994. Total cost: 9.125663


It: 47, deformation cost: 0.932504, attach: 7.523699. Total cost: 8.456203


It: 48, deformation cost: 0.984511, attach: 6.971378. Total cost: 7.955889


It: 49, deformation cost: 1.033454, attach: 6.491166. Total cost: 7.524620


It: 50, deformation cost: 1.080223, attach: 6.065067. Total cost: 7.145290


It: 51, deformation cost: 1.125284, attach: 5.682721. Total cost: 6.808005


It: 52, deformation cost: 1.168913, attach: 5.337008. Total cost: 6.505920


It: 53, deformation cost: 1.211286, attach: 5.022806. Total cost: 6.234092


It: 54, deformation cost: 1.252520, attach: 4.736018. Total cost: 5.988538


It: 55, deformation cost: 1.292698, attach: 4.473353. Total cost: 5.766051


It: 56, deformation cost: 1.331873, attach: 4.232058. Total cost: 5.563931


It: 57, deformation cost: 1.370080, attach: 4.009871. Total cost: 5.379951


It: 58, deformation cost: 1.407340, attach: 3.804860. Total cost: 5.212200


It: 59, deformation cost: 1.443671, attach: 3.615321. Total cost: 5.058992


It: 60, deformation cost: 1.479070, attach: 3.439882. Total cost: 4.918952


It: 61, deformation cost: 1.513545, attach: 3.277257. Total cost: 4.790802


It: 62, deformation cost: 1.547096, attach: 3.126344. Total cost: 4.673440


It: 63, deformation cost: 1.579724, attach: 2.986122. Total cost: 4.565846


It: 64, deformation cost: 1.611424, attach: 2.855761. Total cost: 4.467185


It: 65, deformation cost: 1.642195, attach: 2.734483. Total cost: 4.376678


It: 66, deformation cost: 1.672052, attach: 2.621472. Total cost: 4.293524


It: 67, deformation cost: 1.700975, attach: 2.516259. Total cost: 4.217234


It: 68, deformation cost: 1.728994, attach: 2.418012. Total cost: 4.147007


It: 69, deformation cost: 1.756083, attach: 2.326555. Total cost: 4.082638


It: 70, deformation cost: 1.782297, attach: 2.240906. Total cost: 4.023203


It: 71, deformation cost: 1.807598, attach: 2.161216. Total cost: 3.968815


It: 72, deformation cost: 1.832167, attach: 2.085909. Total cost: 3.918076


It: 73, deformation cost: 1.856464, attach: 2.014203. Total cost: 3.870666


It: 74, deformation cost: 1.880923, attach: 1.942971. Total cost: 3.823894


It: 75, deformation cost: 1.909193, attach: 1.864121. Total cost: 3.773313


It: 76, deformation cost: 1.936589, attach: 1.789525. Total cost: 3.726114


It: 77, deformation cost: 1.957610, attach: 1.736450. Total cost: 3.694060


It: 78, deformation cost: 1.978964, attach: 1.678865. Total cost: 3.657829


It: 79, deformation cost: 1.999548, attach: 1.631128. Total cost: 3.630676


It: 80, deformation cost: 2.019124, attach: 1.582706. Total cost: 3.601830


It: 81, deformation cost: 2.036659, attach: 1.542924. Total cost: 3.579583


It: 82, deformation cost: 2.053664, attach: 1.503546. Total cost: 3.557209


It: 83, deformation cost: 2.069548, attach: 1.468764. Total cost: 3.538312


It: 84, deformation cost: 2.084864, attach: 1.435118. Total cost: 3.519982


It: 85, deformation cost: 2.099349, attach: 1.404426. Total cost: 3.503775


It: 86, deformation cost: 2.113271, attach: 1.375091. Total cost: 3.488362


It: 87, deformation cost: 2.126504, attach: 1.347927. Total cost: 3.474431


It: 88, deformation cost: 2.139188, attach: 1.322132. Total cost: 3.461320


It: 89, deformation cost: 2.151270, attach: 1.298077. Total cost: 3.449347


It: 90, deformation cost: 2.162839, attach: 1.275271. Total cost: 3.438111


It: 91, deformation cost: 2.173862, attach: 1.253950. Total cost: 3.427812


It: 92, deformation cost: 2.184412, attach: 1.233740. Total cost: 3.418152


It: 93, deformation cost: 2.194466, attach: 1.214795. Total cost: 3.409261


It: 94, deformation cost: 2.204076, attach: 1.196877. Total cost: 3.400952


It: 95, deformation cost: 2.213255, attach: 1.179965. Total cost: 3.393219


It: 96, deformation cost: 2.222003, attach: 1.164077. Total cost: 3.386080


It: 97, deformation cost: 2.230370, attach: 1.148999. Total cost: 3.379369


It: 98, deformation cost: 2.238358, attach: 1.134769. Total cost: 3.373127


It: 99, deformation cost: 2.245973, attach: 1.121363. Total cost: 3.367336


It: 100, deformation cost: 2.253262, attach: 1.108605. Total cost: 3.361867


It: 101, deformation cost: 2.260203, attach: 1.096628. Total cost: 3.356831


It: 102, deformation cost: 2.266852, attach: 1.085188. Total cost: 3.352040


It: 103, deformation cost: 2.273190, attach: 1.074427. Total cost: 3.347617


It: 104, deformation cost: 2.279250, attach: 1.064180. Total cost: 3.343430


It: 105, deformation cost: 2.285022, attach: 1.054574. Total cost: 3.339596


It: 106, deformation cost: 2.290581, attach: 1.045337. Total cost: 3.335919


It: 107, deformation cost: 2.295934, attach: 1.036530. Total cost: 3.332464


It: 108, deformation cost: 2.301107, attach: 1.028030. Total cost: 3.329138


It: 109, deformation cost: 2.306112, attach: 1.019873. Total cost: 3.325985


It: 110, deformation cost: 2.310972, attach: 1.011998. Total cost: 3.322970


It: 111, deformation cost: 2.315748, attach: 1.004267. Total cost: 3.320015


It: 112, deformation cost: 2.320491, attach: 0.996586. Total cost: 3.317077


It: 113, deformation cost: 2.325281, attach: 0.988727. Total cost: 3.314008


It: 114, deformation cost: 2.329954, attach: 0.981297. Total cost: 3.311251


It: 115, deformation cost: 2.335376, attach: 0.971615. Total cost: 3.306991


It: 116, deformation cost: 2.340302, attach: 0.963743. Total cost: 3.304044


It: 117, deformation cost: 2.345345, attach: 0.954847. Total cost: 3.300191


It: 118, deformation cost: 2.349989, attach: 0.947448. Total cost: 3.297437


It: 119, deformation cost: 2.354983, attach: 0.938540. Total cost: 3.293523


It: 120, deformation cost: 2.359524, attach: 0.931180. Total cost: 3.290703


It: 121, deformation cost: 2.364420, attach: 0.922349. Total cost: 3.286768


It: 122, deformation cost: 2.368810, attach: 0.915182. Total cost: 3.283993


It: 123, deformation cost: 2.373737, attach: 0.906064. Total cost: 3.279801


It: 124, deformation cost: 2.378100, attach: 0.898606. Total cost: 3.276707


It: 125, deformation cost: 2.382931, attach: 0.889402. Total cost: 3.272334


It: 126, deformation cost: 2.387287, attach: 0.881334. Total cost: 3.268621


It: 127, deformation cost: 2.392007, attach: 0.871903. Total cost: 3.263910


It: 128, deformation cost: 2.396219, attach: 0.863803. Total cost: 3.260022


It: 129, deformation cost: 2.400744, attach: 0.854342. Total cost: 3.255086


It: 130, deformation cost: 2.404682, attach: 0.846255. Total cost: 3.250937


It: 131, deformation cost: 2.408922, attach: 0.836832. Total cost: 3.245754


It: 132, deformation cost: 2.412554, attach: 0.828658. Total cost: 3.241211


It: 133, deformation cost: 2.416408, attach: 0.819347. Total cost: 3.235755


It: 134, deformation cost: 2.419642, attach: 0.811069. Total cost: 3.230711


It: 135, deformation cost: 2.422989, attach: 0.801840. Total cost: 3.224829


It: 136, deformation cost: 2.425546, attach: 0.793784. Total cost: 3.219331


It: 137, deformation cost: 2.428182, attach: 0.784575. Total cost: 3.212758


It: 138, deformation cost: 2.429798, attach: 0.776998. Total cost: 3.206797


It: 139, deformation cost: 2.431546, attach: 0.767722. Total cost: 3.199268


It: 140, deformation cost: 2.432353, attach: 0.759647. Total cost: 3.192000
End of the optimisation process.


In [49]:
# Results
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.axis(aabb.get_list())
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')


plt.subplot(1, 3, 2)
plt.axis(aabb.get_list())
out = model.shot_manifold[0].gd.view(-1, 2).detach().numpy()
shot_implicit0 = model.shot_manifold[1].gd.view(-1, 2).detach().numpy()
shot_implicit00 = model.shot_manifold[2].gd.view(-1, 2).detach().numpy()
shot_implicit1 = model.shot_manifold[3].gd[0].view(-1, 2).detach().numpy()
plt.plot(out[:, 0], out[:, 1], '-')
plt.plot(shot_implicit0[:, 0], shot_implicit0[:, 1], 'x')
plt.plot(shot_implicit00[:, 0], shot_implicit00[:, 1], 'o')
plt.plot(shot_implicit1[:, 0], shot_implicit1[:, 1], '.')

plt.subplot(1, 3, 3)
plt.axis(aabb.get_list())
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.plot(out[:, 0], out[:, 1], '-')
plt.show()

In [None]:
compound = dm.deformationmodules.CompoundModule(model.modules)
compound.manifold.fill(model.init_manifold.copy())
#compound.manifold.fill_cotan(model.shot_manifold.cotan)

itt = 4
out = dm.shooting.shoot(dm.hamiltonian.Hamiltonian(compound), it=itt)

for i in range(itt):
    plt.figure()
    plt.plot(out[i].gd[0].view(-1, 2)[:, 0].detach().numpy(), out[i].gd[0].view(-1, 2)[:, 1].detach().numpy(), '-')
    plt.plot(out[i].gd[1].view(-1, 2)[:, 0].detach().numpy(), out[i].gd[1].view(-1, 2)[:, 1].detach().numpy(), '.')
    plt.plot(out[i][2].gd.view(-1, 2)[:, 0].detach().numpy(), out[i][2].gd.view(-1, 2)[:, 1].detach().numpy(), 'o')
    plt.plot(out[i][3].gd[0].view(-1, 2)[:, 0].detach().numpy(), out[i][3].gd[0].view(-1, 2)[:, 1].detach().numpy(), 'x')
    plt.show()




In [None]:
# Evolution of the cost with iterations
plt.title("Cost")
plt.xlabel("Iteration(s)")
plt.ylabel("Cost")
plt.plot(range(len(costs)), costs)
plt.show()