In [None]:
%reset
%load_ext autoreload
%autoreload 2

import Sampling
import UsefulFunctions as fun
import DeformationModules as df
import Models
import Hamiltonian

import numpy as np
import matplotlib.pyplot as plt
import torch

torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# Sample the source and target shapes from their greyscale density images
# (same threshold and centring for both).
source, target = (
    Sampling.sampleFromGreyscale(path, threshold=0.5, centered=True)
    for path in ("data/density_a.png", "data/density_b.png")
)

In [None]:
# Axis-aligned bounding box of the source coordinates (columns 0 and 1 of source[0]);
# the control-point grids below are laid out over this box.
src_xy = source[0].numpy()[:, :2]
minx, miny = np.min(src_xy, axis=0)
maxx, maxy = np.max(src_xy, axis=0)

In [None]:
# Two translation modules at different scales: a fine grid (Sigma0) and a coarse grid (Sigma1).
Sigma0 = 0.05
Sigma1 = 0.2


def _padded_meshgrid(sigma):
    """Meshgrid over the source bounding box padded by `sigma`, sampled every `sigma`."""
    xs = torch.arange(minx - sigma, maxx + sigma, step=sigma)
    ys = torch.arange(miny - sigma, maxy + sigma, step=sigma)
    return torch.meshgrid([xs, ys])


x0, y0 = _padded_meshgrid(Sigma0)
x1, y1 = _padded_meshgrid(Sigma1)

# Flatten each grid into a 1-D vector; elsewhere it is viewed as (-1, 2) point pairs.
GD0 = fun.grid2vec(x0, y0).contiguous().view(-1)
GD1 = fun.grid2vec(x1, y1).contiguous().view(-1)

# One translation module per grid. The third argument is the number of 2-D grid points.
# NOTE(review): the `2` is presumably the spatial dimension — confirm against df.Translations.
Trans0 = df.Translations(Sigma0, 2, GD0.view(-1, 2).shape[0])
Trans1 = df.Translations(Sigma1, 2, GD1.view(-1, 2).shape[0])

In [None]:
# Overlay the source and target samples with both control-point grids.
# NOTE(review): grid points are drawn as (column 1, column 0), presumably to match the
# axis convention of fun.plotTensorScatter — confirm.
fun.plotTensorScatter(source, alpha=0.4)
fun.plotTensorScatter(target, alpha=0.4)
for gd in (GD0, GD1):
    gd_pts = gd.view(-1, 2).numpy()
    plt.plot(gd_pts[:, 1], gd_pts[:, 0], '.')
plt.show()

In [None]:
# Compound registration of `source` onto `target` driven by both translation modules,
# each initialised on its own control-point grid.
myModel = Models.ModelCompoundRegistration(
    2,                # NOTE(review): presumably the spatial dimension — confirm
    source,
    [Trans0, Trans1],
    [GD0, GD1],
    [False, False],   # NOTE(review): one boolean flag per module; meaning not visible here
)
# `fit` returns `costs`, which is plotted at the end of the notebook.
costs = myModel.fit(target, lr=1e-5, l=20., maxiter=400, logInterval=10)

In [None]:
# Shoot the fitted model. `out` is presumably the deformed source; `outGD` is the list of
# transported geometric descriptors returned by shootList (second return value unused).
out = myModel()
outGD, _ = myModel.shootList()

# Plots use the notebook's single backend. Switching backends mid-notebook (e.g. to
# `%matplotlib qt5`) makes earlier plotting cells render differently depending on
# execution history and requires a Qt display; choose a backend once in the setup cell.

# Left panel: target vs. deformed source, with the transported control-point grids.
plt.subplot(1, 2, 1)
fun.plotTensorScatter(target, alpha=0.4)
fun.plotTensorScatter(out, alpha=0.4)
# NOTE(review): outGD[1] and outGD[2] are assumed to be the transported GD0/GD1 grids
# (index 0 presumably the source itself) — confirm against Models.shootList.
for grid in (outGD[1], outGD[2]):
    grid_pts = grid.view(-1, 2).detach().numpy()
    plt.plot(grid_pts[:, 1], grid_pts[:, 0], '.')
plt.title("Deformed source vs. target")

# Right panel: target vs. undeformed source, for comparison.
plt.subplot(1, 2, 2)
fun.plotTensorScatter(target, alpha=0.4)
fun.plotTensorScatter(source, alpha=0.4)
plt.title("Source vs. target")
plt.show()

In [None]:
# Cost values recorded by `myModel.fit`; plt.plot(y) uses 0..len(y)-1 as x-values.
plt.plot(costs)
plt.title("Registration cost during fit")
plt.xlabel("recorded step (index into costs)")
plt.ylabel("cost")
plt.show()