In [None]:
%reset
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import defmod as dm

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# Sample the source and target from their greyscale density images, using the
# same sampling options for both.
sample_options = dict(threshold=0.5, centered=True)
source = dm.sampling.load_and_sample_greyscale("../data/density_a.png", **sample_options)
target = dm.sampling.load_and_sample_greyscale("../data/density_b.png", **sample_options)

In [None]:
# `aabb` bounds the source points only (it sizes the control-point grid below);
# `aabb_total` bounds source and target together (it frames the deformation-grid plot).
# NOTE(review): assumes element 0 of each sample is its point tensor — confirm.
source_points = source[0]
target_points = target[0]
aabb = dm.usefulfunctions.AABB.build_from_points(source_points)
aabb_total = dm.usefulfunctions.AABB.build_from_points(torch.cat([source_points, target_points]))

In [None]:
# Kernel scale of the translation module, also used as the spacing of the control grid.
sigma = 0.1

# Regular lattice covering the source bounding box, padded by one `sigma` on each side.
x_range = torch.arange(aabb.xmin - sigma, aabb.xmax + sigma, step=sigma)
y_range = torch.arange(aabb.ymin - sigma, aabb.ymax + sigma, step=sigma)
x, y = torch.meshgrid([x_range, y_range])

# Flatten the grid into a single 1-D geometric-descriptor vector of interleaved coordinates.
gd = dm.usefulfunctions.grid2vec(x, y).contiguous().view(-1)

In [None]:
# Overlay source, target and the translation control points before fitting.
# Uses `plot_tensor_scatter`, the helper name used everywhere else in this notebook.
control_points = gd.view(-1, 2)
dm.usefulfunctions.plot_tensor_scatter(source, alpha=0.4)
dm.usefulfunctions.plot_tensor_scatter(target, alpha=0.4)
# Column 1 goes on the horizontal axis, matching the swapped (y, x) order passed to plot_grid later.
plt.plot(control_points[:, 1].numpy(), control_points[:, 0].numpy(), '.')
plt.title("Source, target and control points")
plt.show()

In [None]:
# One translation per 2-D control point stored in `gd`.
n_control_points = gd.view(-1, 2).shape[0]
trans = dm.deformationmodules.Translations(2, n_control_points, sigma)

# Register the source onto the target; the module's control points are not optimised (False).
# NOTE(review): meaning of the boolean flag inferred from usage only — confirm in defmod.
my_model = dm.models.ModelCompoundWithPointsRegistration(2, source, [trans], [gd], [False])
costs = my_model.fit(target, max_iter=200, l=80., lr=1e-3, log_interval=10)

In [None]:
out = my_model()
out_gd, _ = my_model.shoot_list()
%matplotlib qt5
plt.subplot(1, 2, 1)
dm.usefulfunctions.plot_tensor_scatter(target, alpha=0.4)
dm.usefulfunctions.plot_tensor_scatter(source, alpha=0.4)
ax = plt.subplot(1, 2, 2)
dm.usefulfunctions.plot_tensor_scatter(target, alpha=0.4)
dm.usefulfunctions.plot_tensor_scatter(out, alpha=0.4)

grid_x, grid_y = my_model.compute_deformation_grid(
    torch.tensor([aabb_total.xmin - 2.*sigma, aabb_total.ymin - 2.*sigma]),
    torch.tensor([aabb_total.width + 4.*sigma, aabb_total.height + 4.*sigma]),
    torch.Size([32, 32]))

dm.usefulfunctions.plot_grid(ax, grid_y.detach().numpy(), grid_x.detach().numpy(), color='C0')
plt.show()


In [None]:
# Cost history returned by `fit`; plt.plot(y) uses 0..len(y)-1 as x, same as before.
# NOTE(review): whether each entry is one iteration or one log interval depends on fit() — confirm.
plt.plot(costs)
plt.xlabel("step")
plt.ylabel("cost")
plt.title("Registration cost")
plt.show()