In [None]:
%reset
%load_ext autoreload
%autoreload 2

import pickle
import math

# The deformation module library is not installed automatically yet, so we add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch
import geomloss

import defmod as dm

# Use 32-bit floats as the default floating point dtype for new tensors.
# torch.set_default_dtype replaces the deprecated torch.set_default_tensor_type.
torch.set_default_dtype(torch.float32)

In [None]:
# Loading the datasets
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only use it on data files that come from a trusted source.
# The context managers make sure both file handles are closed after reading.
with open('../data/basi2btemp.pkl', 'rb') as source_file:
    source_raw = pickle.load(source_file)
with open('../data/basi2target.pkl', 'rb') as target_file:
    target_raw = pickle.load(target_file)

# Element 1 of each pickled object is converted to a tensor of the default
# floating point dtype. The rest of the notebook reads columns 0 and 1 as
# coordinates and column 2 as a point label.
source = torch.tensor(source_raw[1], dtype=torch.get_default_dtype())
target = torch.tensor(target_raw[1], dtype=torch.get_default_dtype())

# Rescaling parameters: offset applied to the source and the heights
# (extent along column 1) that source and target are scaled to.
Dx, Dy = 0., 0.
height_source, height_target = 90., 495.

# Source: scale so that its extent along column 1 equals height_source,
# centre column 0 on its mean (shifted by Dx), and flip column 1 so that its
# former maximum lands on Dy. The tensor is modified in place.
smin, smax = source[:, 1].min(), source[:, 1].max()
sscale = height_source / (smax - smin)
source[:, 0] = Dx + sscale * (source[:, 0] - source[:, 0].mean())
source[:, 1] = Dy - sscale * (source[:, 1] - smax)

# Target: same normalisation with height_target and no extra offset.
tmin, tmax = target[:, 1].min(), target[:, 1].max()
tscale = height_target / (tmax - tmin)
target[:, 0] = tscale * (target[:, 0] - target[:, 0].mean())
target[:, 1] = - tscale * (target[:, 1] - tmax)

# Split the data by the label stored in column 2, keeping only the
# coordinates (columns 0 and 1).
source_labels = source[:, 2]
target_labels = target[:, 2]

# Label 2: points of the curve to register.
pos_source = source[source_labels == 2, 0:2]
pos_target = target[target_labels == 2, 0:2]

# Label 1: points carrying the deformation modules. Both implicit modules
# start from the same points; each boolean-mask indexing makes its own copy,
# so the two tensors do not share storage.
pos_implicit1 = source[source_labels == 1, 0:2]
pos_implicit0 = source[source_labels == 1, 0:2]

# Axis-aligned bounding box of the target, used as the plotting window
aabb = dm.usefulfunctions.AABB.build_from_points(pos_target)
aabb.squared()

In [None]:
# Some plots
%matplotlib qt5

plt.subplot(1, 2, 1)
plt.axis(aabb.get_list())
plt.title('Source')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')

plt.subplot(1, 2, 2)
plt.axis(aabb.get_list())
plt.title('Target')
plt.xlabel('x')
plt.ylabel('y')
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')

plt.show()

In [None]:
# Setting up the modules
# Local translation module: an order-0 implicit module attached to the
# pos_implicit0 points.
sigma0, nu0, coeff0 = 10., 0.001, 100.

# Landmarks manifold in dimension 2; the flattened positions are marked as
# requiring gradients.
implicit0_gd = pos_implicit0.view(-1).requires_grad_()
implicit0_manifold = dm.manifold.Landmarks(2, pos_implicit0.shape[0], gd=implicit0_gd)
implicit0 = dm.implicitmodules.ImplicitModule0(implicit0_manifold, sigma0, nu0, coeff0)

# Global translation module: an order-0 implicit module carried by a single
# point placed at the origin, with a much larger sigma than the local one.
sigma00, nu00, coeff00 = 800., 0.001, 0.01

implicit00_gd = torch.tensor([0., 0.], requires_grad=True)
implicit00_manifold = dm.manifold.Landmarks(2, 1, gd=implicit00_gd)
implicit00 = dm.implicitmodules.ImplicitModule0(implicit00_manifold, sigma00, nu00, coeff00)

# Elastic modules: an order-1 implicit module attached to the pos_implicit1
# points.
sigma1, nu1, coeff1 = 100., 0.001, 0.000001
n_implicit1 = pos_implicit1.shape[0]

# C has one (2, 1) block per point. Its entries follow a quadratic profile in
# z, the point's second coordinate measured from Dy in units of the source
# height.
K = 10
L = height_source
a = 1./L
b = 3.
z = a*(pos_implicit1[:, 1] - Dy)
C = torch.zeros(n_implicit1, 2, 1)
C[:, 1, 0] = K * ((1 - b) * z**2 + b * z)
C[:, 0, 0] = 0.8 * C[:, 1, 0]

# One 2D rotation per point, built from the angle th (zero for every point).
th = 0. * math.pi * torch.ones(n_implicit1)
R = torch.stack(list(map(dm.usefulfunctions.rot2d, th)))

# The Stiefel manifold takes a pair of flattened tensors as gd: the point
# positions and the rotations, both marked as requiring gradients.
implicit1_gd = (pos_implicit1.view(-1).requires_grad_(), R.view(-1).requires_grad_())
implicit1_manifold = dm.manifold.Stiefel(2, n_implicit1, gd=implicit1_gd)
implicit1 = dm.implicitmodules.ImplicitModule1(implicit1_manifold, C, sigma1, nu1, coeff1)

In [None]:
# Build the registration model and run the fitting loop.
# Source and target are each passed as a (points, per-point ones) pair.
source_points = (pos_source, torch.ones(pos_source.shape[0]))
target_points = (pos_target, torch.ones(pos_target.shape[0]))

# The three modules, each paired with a True flag in the third argument.
modules = [implicit0, implicit00, implicit1]
module_flags = [True, True, True]

model = dm.models.ModelCompoundWithPointsRegistration(source_points, modules, module_flags)
costs = model.fit(target_points, max_iter=60, l=5e-3, lr=5e-2, log_interval=1)

In [None]:
# Results
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.axis(aabb.get_list())
plt.plot(pos_source[:, 0].numpy(), pos_source[:, 1].numpy(), '-')
plt.plot(pos_implicit1[:, 0].numpy(), pos_implicit1[:, 1].numpy(), '.')
plt.plot(pos_implicit0[:, 0].numpy(), pos_implicit0[:, 1].numpy(), 'x')


plt.subplot(1, 3, 2)
plt.axis(aabb.get_list())
out = model.shot_manifold[0].gd.view(-1, 2).detach().numpy()
shot_implicit0 = model.shot_manifold[2].gd.view(-1, 2).detach().numpy()
shot_implicit00 = model.shot_manifold[1].gd.view(-1, 2).detach().numpy()
shot_implicit1 = model.shot_manifold[3].gd[0].view(-1, 2).detach().numpy()
plt.plot(out[:, 0], out[:, 1], '-')
plt.plot(shot_implicit0[:, 0], shot_implicit0[:, 1], 'x')
plt.plot(shot_implicit00[:, 0], shot_implicit00[:, 1], 'o')
plt.plot(shot_implicit1[:, 0], shot_implicit1[:, 1], '.')

plt.subplot(1, 3, 3)
plt.axis(aabb.get_list())
plt.plot(pos_target[:, 0].numpy(), pos_target[:, 1].numpy(), '-')
plt.plot(out[:, 0], out[:, 1], '-')
plt.show()

In [None]:
# Rebuild a compound module from the model's modules, fill its manifold with
# a copy of the model's initial manifold, and shoot again.
compound = dm.deformationmodules.CompoundModule(model.modules)
compound.manifold.fill(model.init_manifold.copy())
#compound.manifold.fill_cotan(model.shot_manifold.cotan)

itt = 4
out = dm.shooting.shoot(dm.hamiltonian.Hamiltonian(compound), it=itt)

# One figure per entry of `out`, for the first itt entries.
for step in range(itt):
    plt.figure()
    state = out[step]
    # (flattened geometrical descriptor, matplotlib format) for each manifold
    layers = [
        (state.gd[0], '-'),
        (state.gd[1], '.'),
        (state[2].gd, 'o'),
        (state[3].gd[0], 'x'),
    ]
    for gd, fmt in layers:
        points = gd.view(-1, 2)
        plt.plot(points[:, 0].detach().numpy(), points[:, 1].detach().numpy(), fmt)
    plt.show()




In [None]:
# Cost returned by the fitting loop, plotted against its index
iterations = range(len(costs))

plt.plot(iterations, costs)
plt.title("Cost")
plt.xlabel("Iteration(s)")
plt.ylabel("Cost")
plt.show()