In [None]:
%reset
%load_ext autoreload
%autoreload 2

import time

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import defmod as dm

torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Load the pair of greyscale images to register: the source image is handed
# to the model at construction time, the target image to `fit`.
source_image, target_image = (
    dm.sampling.load_greyscale_image(path)
    for path in ("../data/heart_a.png", "../data/heart_b.png")
)

In [None]:
# Region in which control points are placed: the whole source image.
# A tighter region of interest can be used instead, for example
# dm.usefulfunctions.AABB(50., 150., 75., 150.).
# NOTE(review): the x extent comes from shape[0] (image rows) and the y extent
# from shape[1] — confirm this matches defmod's coordinate convention.
aabb = dm.usefulfunctions.AABB(0., source_image.shape[0], 0., source_image.shape[1])

# Fine translation module: kernel width `sigma`, control points every sigma / 2.
sigma = 2.5
x, y = torch.meshgrid([
    torch.arange(aabb.xmin, aabb.xmax, step=0.5*sigma),
    torch.arange(aabb.ymin, aabb.ymax, step=0.5*sigma),
])
gd = dm.usefulfunctions.grid2vec(x, y).contiguous().view(-1)
trans = dm.deformationmodules.Translations(2, len(gd.view(-1, 2)), sigma)

# Coarse translation module: kernel width `sigma_2`, control points every sigma_2.
# It is built here but not passed to the model fitted below.
sigma_2 = 10.
x_2, y_2 = torch.meshgrid([
    torch.arange(aabb.xmin, aabb.xmax, step=sigma_2),
    torch.arange(aabb.ymin, aabb.ymax, step=sigma_2),
])
gd_2 = dm.usefulfunctions.grid2vec(x_2, y_2).contiguous().view(-1)
trans_2 = dm.deformationmodules.Translations(2, len(gd_2.view(-1, 2)), sigma_2)

In [None]:
# Overlay the fine grid of control points `gd` on the source image to check
# their placement. A fresh figure is created explicitly: once an interactive
# backend such as qt5 is active, the previous figure may still be open and
# current, and pyplot calls would otherwise draw into it. cmap='gray' matches
# how the images are displayed in the results cell.
# The coarse grid `gd_2` can be overlaid the same way for comparison.
fig, ax = plt.subplots()
ax.imshow(source_image, cmap='gray')
ax.scatter(gd.view(-1, 2)[:, 0].numpy(), gd.view(-1, 2)[:, 1].numpy())
ax.set_title("Source image and control points")
plt.show()

In [None]:
# Build the registration model from the source image with the single fine
# translation module `trans`, passing its control points `gd` alongside it.
# NOTE(review): the meaning of the trailing [False] flag is not visible here —
# confirm against dm.models.ModelCompoundImageRegistration.
my_model = dm.models.ModelCompoundImageRegistration(2, source_image, [trans], [gd], [False])

# time.clock() was removed in Python 3.8; perf_counter() is the portable
# high-resolution timer for measuring elapsed wall-clock time.
start_time = time.perf_counter()
# NOTE(review): `l` presumably weights a term of the cost — confirm in defmod.
costs = my_model.fit(target_image, lr=0.0005, l=50., max_iter=240, log_interval=10)
print("Elapsed time:", time.perf_counter() - start_time)

In [None]:
out_gd, _ = my_model.shoot_list()
out_points = out_gd[0].view(-1, 2).detach(), my_model.source[1]
sampled_out = my_model()
%matplotlib qt5
plt.subplot(1, 3, 1)
plt.imshow(source_image, cmap='gray')
ax = plt.subplot(1, 3, 2)
plt.imshow(sampled_out.detach().numpy(), cmap='gray')
x_grid, y_grid = my_model.compute_deformation_grid(torch.tensor([0., 0.]), torch.tensor([32., 32.]), torch.Size([16, 16]))
dm.usefulfunctions.plot_grid(ax, y_grid.detach().numpy(), x_grid.detach().numpy(), color="C0")
plt.subplot(1, 3, 3)
plt.imshow(target_image, cmap='gray')
plt.show()

In [None]:
# Cost values returned by `fit`, in the order they were recorded.
# A dedicated figure is created: with an interactive backend (e.g. qt5) the
# previous results figure stays open and current, so a bare plt.plot would
# draw the curve into its last subplot on top of the target image.
fig, ax = plt.subplots()
ax.plot(range(len(costs)), costs)
ax.set_xlabel("Recorded step")
ax.set_ylabel("Cost")
ax.set_title("Registration cost during fitting")
plt.show()