In [None]:
%reset
%load_ext autoreload
%autoreload 2

import time

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch

import defmod as dm

torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Load the source and target greyscale images used for the registration.
data_dir = "../data"
source_image = dm.sampling.load_greyscale_image(data_dir + "/heart_a.png")
target_image = dm.sampling.load_greyscale_image(data_dir + "/heart_b.png")

In [None]:
# Bounding box covering the whole source image.
# NOTE(review): the x range comes from source_image.shape[0] (rows) and the y range from
# shape[1] (columns). Confirm this matches the axis convention of dm.usefulfunctions.AABB
# for non-square images.
aabb = dm.usefulfunctions.AABB(0., source_image.shape[0], 0., source_image.shape[1])

# Scale passed to the Translations module (presumably its kernel width; confirm in defmod).
# Control points are spaced half a sigma apart.
sigma = 7.
step = sigma / 2.

x_range = torch.arange(aabb.xmin, aabb.xmax, step=step)
y_range = torch.arange(aabb.ymin, aabb.ymax, step=step)
x, y = torch.meshgrid([x_range, y_range])

# Flattened control-point coordinates; `gd.view(-1, 2)` gives one (x, y) row per point.
gd = dm.usefulfunctions.grid2vec(x, y).contiguous().view(-1)

# One 2D translation per control point.
n_points = gd.view(-1, 2).shape[0]
trans = dm.deformationmodules.Translations(2, n_points, sigma)

In [None]:
# Overlay the translation control points on the source image.
points = gd.view(-1, 2)
plt.imshow(source_image)
plt.scatter(points[:, 0].numpy(), points[:, 1].numpy())
plt.show()

In [None]:
import numpy as np
import scipy.ndimage.filters as fi

def gkern2(kernlen=21, nsig=3):
    """Return a 2D Gaussian kernel array.

    The kernel is produced by Gaussian-smoothing a discrete Dirac delta that is
    placed at the centre of a ``kernlen x kernlen`` grid.

    Args:
        kernlen: side length of the square kernel, in pixels.
        nsig: standard deviation of the Gaussian, in pixels. It is passed as
            ``sigma`` to ``scipy.ndimage.gaussian_filter``.

    Returns:
        ``np.ndarray`` of shape ``(kernlen, kernlen)``, dtype float64.
    """
    # Use the public namespace: ``scipy.ndimage.filters`` is a deprecated alias.
    from scipy.ndimage import gaussian_filter

    # Dirac delta at the centre pixel; smoothing it yields the sampled Gaussian.
    dirac = np.zeros((kernlen, kernlen))
    centre = kernlen // 2
    dirac[centre, centre] = 1
    return gaussian_filter(dirac, nsig)

def gaussian_filtering(img, kr=50, sigma=10):
    """Blur a 2D image tensor with a Gaussian kernel using ``conv2d``.

    Args:
        img: 2D floating-point ``torch.Tensor`` of shape (H, W).
        kr: kernel radius in pixels. The kernel is ``2*kr + 1`` pixels wide.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        Tensor with the same shape, dtype and device as ``img``.

    Raises:
        TypeError: if ``img`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(img):
        raise TypeError("gaussian_filtering expects a floating-point image tensor")

    kd = 2 * kr + 1
    frame_res = img.shape
    # Build the kernel in the image's dtype/device. conv2d rejects mixed dtypes and
    # devices, so a fixed float32 CPU kernel would fail on float64 or GPU images.
    kernel = torch.as_tensor(gkern2(kd, sigma), dtype=img.dtype, device=img.device).reshape(1, 1, kd, kd)
    # padding=kr keeps the output the same size as the input. Borders are zero-padded,
    # so intensities near the image edges are attenuated.
    blurred = torch.nn.functional.conv2d(
        img.reshape(1, 1, frame_res[0], frame_res[1]), kernel, stride=1, padding=kr)
    return blurred.reshape(frame_res)

# Image-registration model built from the single translation module `trans`, whose
# initial control points are `gd`. The meaning of the trailing [True] flag is not
# visible here; see dm.models.ModelCompoundImageRegistration.
my_model = dm.models.ModelCompoundImageRegistration(2, source_image, [trans], [gd], [True])

# time.clock() was removed in Python 3.8 (AttributeError). perf_counter() is the
# portable high-resolution timer for measuring elapsed time.
start_time = time.perf_counter()
costs = my_model.fit(target_image, lr=0.001, l=150., max_iter=200, log_interval=10)
print("Elapsed time:", time.perf_counter() - start_time)

In [None]:
it = 5
out_gd, _ = my_model.shoot_list(it=it, intermediate=True)
sampled_out = my_model(it=it, intermediate=True)
grid_x, grid_y = my_model.compute_deformation_grid(torch.tensor([0., 0.]), torch.tensor([32., 32.]), torch.Size([8, 8]), it=it, intermediate=True)
%matplotlib qt5
for i in range(0, it):
    ax = plt.subplot(1, it+1, i+1)
    plt.imshow(sampled_out[i].detach().numpy(), cmap='gray')
    dm.usefulfunctions.plot_grid(ax, grid_x[i].detach().numpy(), -grid_y[i].detach().numpy()+32, color='C0')
    #plt.plot(out_gd[i][0].view(-1, 2)[:, 0].detach().numpy(), -out_gd[i][0].view(-1, 2)[:, 1].detach().numpy()+32, '.')
    plt.plot(out_gd[i][0].view(-1, 2)[:, 0].detach().numpy(), -out_gd[i][0].view(-1, 2)[:, 1].detach().numpy()+32, '.')

plt.subplot(1, it+1, it+1)
plt.imshow(target_image, cmap='gray')

plt.show()

In [None]:
# Cost values returned by my_model.fit, plotted against their index.
plt.plot(costs)
plt.show()