In [None]:
%reset
%load_ext autoreload
%autoreload 2

# The deformation module library is not automatically installed yet, we need to add its path manually
import sys
sys.path.append("../")

import torch
import numpy as np
import matplotlib.pyplot as plt

import defmod as dm

torch.manual_seed(1337)
torch.set_default_tensor_type(torch.DoubleTensor)

In [None]:
# Dimension of the ambient space. The vector-field plot at the end of the
# notebook samples a 2-D (x, y) grid, so this is expected to stay 2.
d: int = 2

In [None]:
# Scale handed to both Translations modules
# (presumably the kernel width — confirm in defmod.deformationmodules).
sigma = 0.1

# Number of points carried by each sub-module.
nb_pts_trans, nb_pts_trans2, nb_pts_silent = 10, 5, 25

# Two translation modules sharing the same scale, plus a set of silent points.
# The tuple is evaluated left to right, so construction order is preserved.
trans, trans2, silent = (
    dm.deformationmodules.Translations(d, nb_pts_trans, sigma),
    dm.deformationmodules.Translations(d, nb_pts_trans2, sigma),
    dm.deformationmodules.SilentPoints(d, nb_pts_silent),
)

# Compound module combining the three above, in this order.
compound = dm.deformationmodules.Compound([trans, trans2, silent])

In [None]:
# Hamiltonian built on the compound module; both the geodesic controls and
# the shooting below are computed from it.
hamiltonian = dm.hamiltonian.Hamiltonian(
    compound,
)

In [None]:
def _random_leaf(nb_pts, dim):
    """Return a flat tensor of ``nb_pts * dim`` uniform samples in [0, 1).

    The samples are flattened *before* ``requires_grad_()`` is called, so the
    returned tensor is itself an autograd leaf and its ``.grad`` is filled by
    ``backward()``. Calling ``.view(-1)`` after ``requires_grad=True`` would
    instead return a non-leaf view of a discarded leaf, whose ``.grad`` always
    stays ``None``.
    """
    return torch.rand(nb_pts, dim).view(-1).requires_grad_()


# Random initial state for each sub-module: gd (presumably the geometrical
# descriptors, i.e. point positions) and the matching momenta. The draw order
# (all gd first, then all momenta) determines the seeded values.
gd_trans = _random_leaf(nb_pts_trans, d)
gd_trans2 = _random_leaf(nb_pts_trans2, d)
gd_silent = _random_leaf(nb_pts_silent, d)

mom_trans = _random_leaf(nb_pts_trans, d)
mom_trans2 = _random_leaf(nb_pts_trans2, d)
mom_silent = _random_leaf(nb_pts_silent, d)

# Concatenate in the same order as the modules were passed to Compound.
gd_comp = torch.cat([gd_trans, gd_trans2, gd_silent])
mom_comp = torch.cat([mom_trans, mom_trans2, mom_silent])

In [None]:
# Geodesic controls at the initial (gd_comp, mom_comp) state, flattened to
# 1-D like the state vectors.
controls = (
    hamiltonian
    .geodesic_controls(gd_comp, mom_comp)
    .view(-1)
)

In [None]:
# Fourth positional argument of shoot; presumably the number of integration
# steps — confirm against defmod.shooting.shoot.
nb_steps = 10

# Shoot from the initial (gd, mom) state using the Hamiltonian above.
gd_final, mom_final = dm.shooting.shoot(gd_comp, mom_comp, hamiltonian, nb_steps)
print(gd_final, mom_final)

In [None]:
# Regular nx-by-ny grid covering [-sx/2, sx/2) x [-sy/2, sy/2).
# NOTE(review): the initial points are drawn in [0, 1)^2, so most of this
# grid may lie far from the module points.
nx, ny = 100, 100
sx, sy = 10, 10
x, y = torch.meshgrid([torch.arange(0, nx), torch.arange(0, ny)])
x = sx*(x.type(torch.get_default_dtype())/nx - 0.5)
y = sy*(y.type(torch.get_default_dtype())/ny - 0.5)

# The controls must come from the same state as the gd used to evaluate the
# field. The `controls` computed earlier belong to the initial state
# (gd_comp, mom_comp), so recompute them at the end of the shooting.
controls_final = hamiltonian.geodesic_controls(gd_final, mom_final).view(-1)
u, v = dm.usefulfunctions.vec2grid(
    compound(gd_final, controls_final, dm.usefulfunctions.grid2vec(x, y)), nx, ny)

fig, ax = plt.subplots()
ax.quiver(x.numpy(), y.numpy(), u.detach().numpy(), v.detach().numpy())
ax.set(title="Compound module vector field at the end of the shooting",
       xlabel="x", ylabel="y")
ax.set_aspect("equal")
plt.show()