In [None]:
%reset
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../")

import torch
import numpy as np
import matplotlib.pyplot as plt

import defmod as dm

In [None]:
dim = 2       # ambient dimension of the landmarks
sigma = 5.    # scale parameter passed to Translations (presumably the kernel width — confirm in defmod)

# Landmark positions and their momenta, one row per point, flattened to 1-D
# for Landmarks. Flatten first, then mark as requiring grad, so the flattened
# tensors are autograd *leaves* and `.grad` gets populated on them. Passing
# requires_grad=True to the 2-D literal and then calling .view(-1) would
# yield non-leaf views whose `.grad` stays None.
gd_trans = torch.tensor([[-2., 0.], [2., 0.]]).view(-1).requires_grad_()
mom_trans = torch.tensor([[-1., 0.], [1., 3.]]).view(-1).requires_grad_()

# Derive the point count from the data so it cannot drift from the tensors above.
nb_pts = gd_trans.numel() // dim
assert gd_trans.numel() == nb_pts * dim, "gd_trans size is not a multiple of dim"
assert mom_trans.shape == gd_trans.shape, "momenta must match positions in shape"

trans = dm.deformationmodules.Translations(dm.manifold.Landmarks(dim, nb_pts, gd=gd_trans, cotan=mom_trans), sigma)
hamiltonian = dm.hamiltonian.Hamiltonian([trans])

In [None]:
# Snapshot the state the shooting actually starts from. Read it from the
# module rather than from gd_trans/mom_trans. shoot() updates the module's
# state (see the "Final" printout below), so after a re-run of this cell the
# real starting point is no longer the original gd_trans/mom_trans.
gd_initial = trans.manifold.gd.detach().clone()
cotan_initial = trans.manifold.cotan.detach().clone()

dm.shooting.shoot(hamiltonian, method='rk4', it=10)

print("Initial")
print(gd_initial.view(-1, dim))
print(cotan_initial.view(-1, dim))
print("Final")
print(trans.manifold.gd.view(-1, dim))
print(trans.manifold.cotan.view(-1, dim))
print(trans.controls.view(-1, dim))

In [None]:
nx, ny = 100, 100   # number of grid samples along each axis
sx, sy = 10, 10     # extent of the sampled region along each axis

# Integer index grids, turned into float coordinates covering
# [-s/2, s/2) on each axis (the upper edge is not included).
grid_ix, grid_iy = torch.meshgrid([torch.arange(0, nx), torch.arange(0, ny)])
x = (grid_ix.float() / nx - 0.5) * sx
y = (grid_iy.float() / ny - 0.5) * sy

# Evaluate the translation module at every grid point, then reshape the
# resulting vectors back into per-axis component grids.
grid_points = dm.usefulfunctions.grid2vec(x, y).type(torch.FloatTensor)
u, v = dm.usefulfunctions.vec2grid(trans(grid_points), nx, ny)

In [None]:
%matplotlib qt5

plt.quiver(x.numpy(), y.numpy(), u.detach().numpy(), v.detach().numpy())
plt.show()