In [None]:
%reset
%load_ext autoreload
%autoreload 2

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import defmod as dm

torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Load the image to be deformed (source) and the image it is fitted against (target).
source_image, target_image = (
    dm.sampling.load_greyscale_image(image_path)
    for image_path in ("../data/heart_a.png", "../data/heart_b.png")
)

In [None]:
# Bounding box covering the whole source image. AABB is presumably
# (xmin, xmax, ymin, ymax), so x spans the columns (shape[1]) and y the rows (shape[0]).
aabb = dm.usefulfunctions.AABB(0., source_image.shape[1], 0., source_image.shape[0])

# `sigma` is both the spacing of the control-point grid and the scale passed to
# the translation module below.
sigma = 3.
x, y = torch.meshgrid([
    torch.arange(aabb.xmin, aabb.xmax, step=sigma),
    torch.arange(aabb.ymin, aabb.ymax, step=sigma),
])

# Flatten the grid into a 1-D coordinate vector; later code views it as (-1, 2),
# i.e. one (x, y)-style pair per point (pair ordering set by grid2vec — confirm).
gd = dm.usefulfunctions.grid2vec(x, y).contiguous().view(-1)

# 2-D translations, one per grid point, with scale `sigma`.
trans = dm.deformationmodules.Translations(2, len(gd.view(-1, 2)), sigma)

In [None]:
# Overlay the initial grid points on the source image: column 0 is drawn
# horizontally, column 1 vertically.
plt.imshow(source_image)
plt.scatter(*gd.view(-1, 2).numpy().T)
plt.show()

In [None]:
# Build the 2-D registration model from the source image and the translation
# module with its initial points `gd`. The `[True]` flag and `threshold` are
# passed straight through to defmod; see ModelCompoundImageRegistration for
# their meaning.
my_model = dm.models.ModelCompoundImageRegistration(
    2, source_image, [trans], [gd], [True], threshold=0.
)

# Optimise the model towards the target image. `costs` holds the cost values
# returned by `fit` (presumably one per logged step — confirm in defmod).
costs = my_model.fit(
    target_image,
    lr=0.005,
    l=50.,
    max_iter=200,
    log_interval=10,
)

In [None]:
# Call the fitted model once; the result is kept in `out` for inspection and is
# not used below.
out = my_model()
# `shoot_list` returns a pair; only the first element (one state per deformable
# object) is used here.
out_gd, _ = my_model.shoot_list()
# Final points of the first object, paired with the model's alpha weights; this
# (points, alpha) tuple is resampled into an image by sample_from_points below.
out_points = out_gd[0].view(-1, 2).detach(), my_model.alpha
# Final positions of the second object, presumably the translation control
# points — confirm the ordering in ModelCompoundImageRegistration.
control_points = out_gd[1].view(-1, 2).detach()

# Figures stay on the default (inline) backend so the notebook re-runs headless;
# run `%matplotlib qt5` manually beforehand if an interactive window is wanted.
fig, axes = plt.subplots(1, 3)

axes[0].imshow(source_image)
axes[0].set_title("Source")

# NOTE(review): this panel flips the sampled image along dim 0 and draws point
# column 1 horizontally, whereas the initial grid plot draws column 0
# horizontally. Presumably sample_from_points uses a different axis convention;
# confirm.
axes[1].imshow(torch.flip(dm.sampling.sample_from_points(out_points, source_image.shape), [0]))
axes[1].plot(control_points[:, 1].numpy(), control_points[:, 0].numpy(), '.')
axes[1].set_title("Deformed source")

axes[2].imshow(target_image)
axes[2].set_title("Target")

plt.show()

In [None]:
# Cost values returned by `fit`, plotted against their index. plt.plot supplies
# range(len(costs)) as x automatically.
plt.plot(costs)
plt.title("Registration cost")
plt.xlabel("Step")
plt.ylabel("Cost")
plt.show()