In [None]:
%reset
%load_ext autoreload
%autoreload 2

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import defmod as dm

torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Sample the source and target shapes from greyscale images. Both images use
# identical sampling options so the two point sets are directly comparable.
sampling_options = dict(threshold=0.5, centered=True)
source, target = (
    dm.sampling.load_and_sample_greyscale(image_path, **sampling_options)
    for image_path in ("../data/density_a.png", "../data/density_b.png")
)

In [None]:
# Axis-aligned bounding box of the source's first component (presumably the
# sampled point coordinates — TODO confirm against dm.sampling).
source_points = source[0]
aabb = dm.usefulfunctions.AABB.build_from_points(source_points)

In [None]:
def make_translation_grid(aabb, sigma):
    """Build a regular grid of translation control points covering `aabb`.

    The grid spacing equals `sigma`, and the box is padded by `sigma` on every
    side so the deformation also acts slightly outside the shape.

    Args:
        aabb: bounding box exposing `xmin`, `xmax`, `ymin`, `ymax`.
        sigma: grid spacing, also passed to the translation module as its scale.

    Returns:
        (gd, module): the control points flattened to a 1-D tensor (consumers
        reshape it with `.view(-1, 2)`), and the matching
        `dm.deformationmodules.Translations` module.
    """
    xs = torch.arange(aabb.xmin - sigma, aabb.xmax + sigma, step=sigma)
    ys = torch.arange(aabb.ymin - sigma, aabb.ymax + sigma, step=sigma)
    # torch.meshgrid defaults to matrix ("ij") indexing.
    x, y = torch.meshgrid([xs, ys])
    gd = dm.usefulfunctions.grid2vec(x, y).contiguous().view(-1)
    n_points = gd.view(-1, 2).shape[0]
    # First argument 2 is presumably the spatial dimension — TODO confirm.
    module = dm.deformationmodules.Translations(2, n_points, sigma)
    return gd, module


# Two scales of translations: a coarse grid (sigma0) and a finer one (sigma1).
sigma0 = 0.15
sigma1 = 0.05
gd0, trans0 = make_translation_grid(aabb, sigma0)
gd1, trans1 = make_translation_grid(aabb, sigma1)

In [None]:
# Overlay source and target, then both control-point grids (coarse first, then
# fine). Column 1 is drawn horizontally and column 0 vertically.
for shape in (source, target):
    dm.usefulfunctions.plot_tensor_scatter(shape, alpha=0.4)
for control_grid in (gd0, gd1):
    grid_points = control_grid.view(-1, 2).numpy()
    plt.plot(grid_points[:, 1], grid_points[:, 0], '.')
plt.show()

In [None]:
# Compound model made of the coarse and fine translation modules. The boolean
# list is paired element-wise with the initial control points — presumably it
# marks which ones are optimised during fitting (TODO confirm in dm.models).
deformation_modules = [trans0, trans1]
initial_control_points = [gd0, gd1]
control_point_flags = [False, True]
my_model = dm.models.ModelCompoundWithPointsRegistration(
    2, source, deformation_modules, initial_control_points, control_point_flags
)
costs = my_model.fit(target, lr=1e-4, l=50., max_iter=400, log_interval=5)

In [None]:
# Evaluate the fitted model and keep the last entry of each of its first two
# outputs (presumably the final time step of the deformed source).
model_output = my_model()
out = model_output[0][-1], model_output[1][-1]
out_gd, _ = my_model.shoot_list()
in_gd, _ = my_model.get_var_list()


def _plot_control_points(gd):
    """Scatter a flattened point tensor: column 1 horizontally, column 0 vertically."""
    points = gd.view(-1, 2).detach().numpy()
    plt.plot(points[:, 1], points[:, 0], '.')


# No backend switch here. `%matplotlib qt5` needs a local Qt install and would
# send every later figure to an external window. Use `%matplotlib widget`
# (ipympl) if interactive zoom is needed.
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.title("After fit: deformed source vs. target")
dm.usefulfunctions.plot_tensor_scatter(target, alpha=0.4)
dm.usefulfunctions.plot_tensor_scatter(out, alpha=0.4)
# Entries 1 and 2 presumably hold the coarse and fine translation control
# points (entry 0 being the shape itself) — TODO confirm.
for gd in (out_gd[-1][1], out_gd[-1][2]):
    _plot_control_points(gd)

plt.subplot(1, 2, 2)
plt.title("Before fit: source vs. target")
dm.usefulfunctions.plot_tensor_scatter(target, alpha=0.4)
dm.usefulfunctions.plot_tensor_scatter(source, alpha=0.4)
for gd in (in_gd[1], in_gd[2]):
    _plot_control_points(gd)

plt.show()

In [None]:
# Cost history returned by `my_model.fit`. With a single argument, matplotlib
# uses 0..len(costs)-1 as the x-axis. It is not visible here whether one cost is
# recorded per iteration or per `log_interval` — TODO confirm in dm.models.
plt.plot(costs)
plt.title("Registration cost during fitting")
plt.xlabel("Recorded step")
plt.ylabel("Cost")
plt.show()