In [None]:
%reset
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import matplotlib.pyplot as plt

import DeformationModules as defmod
import Shooting as shoot
import Hamiltonian as ham
import UsefulFunctions as fun

In [None]:
"""Dimension of the space"""
d = 2

In [None]:
# Module parameters shared by both deformation modules
# (σ is presumably the kernel scale; nbPts the number of translation points).
σ, nbPts = 1.5, 4

# The same translation module built twice: once with its own cost, once with
# the "identical cost" variant, so the two can be compared below.
Trans = defmod.Translations(σ, d, nbPts)
Trans_Id = defmod.TranslationsIdenticalCost(σ, d, nbPts)

In [None]:
# One Hamiltonian per deformation module (real cost, identical cost).
Hamiltonian, Hamiltonian_Id = ham.Hamilt(Trans), ham.Hamilt(Trans_Id)

In [None]:
# Initial geometric descriptors and momenta: 4 points in 2-D
# (assumed to match nbPts and d from the cells above — TODO confirm if those change).
GD_Trans = torch.tensor([[1., 2.], [-1., -2.], [1., -2.5], [-2., 1.5]], requires_grad=True)
MOM_Trans = torch.tensor([[0., 0.5], [0.5, 0.5], [-1., -0.2], [0.1, -0.4]], requires_grad=True)

# Independent leaf copies for the identical-cost run: detached from the
# originals' autograd history, with their own storage, and tracking gradients.
# Ending on an assignment also avoids echoing a tensor as stray cell output.
GD_Trans_Id = GD_Trans.detach().clone().requires_grad_()
MOM_Trans_Id = MOM_Trans.detach().clone().requires_grad_()

In [None]:
# Controls evaluated at the initial state (presumably the geodesic controls),
# one set per cost model.
Controls, Controls_Id = (
    Hamiltonian.Cont_geo(GD_Trans, MOM_Trans),
    Hamiltonian_Id.Cont_geo(GD_Trans_Id, MOM_Trans_Id),
)

In [None]:
# Shoot both systems from their initial state with the same step count so the
# results are directly comparable. 100 is presumably the number of integration
# steps taken by shoot.shoot — confirm against the Shooting module.
nbSteps = 100
GD_Final, MOM_Final = shoot.shoot(Trans, GD_Trans, MOM_Trans, Hamiltonian, nbSteps)
GD_Final_Id, MOM_Final_Id = shoot.shoot(Trans_Id, GD_Trans_Id, MOM_Trans_Id, Hamiltonian_Id, nbSteps)

In [None]:
# Final descriptors, final momenta and initial controls for each cost model,
# reshaped to one d-dimensional vector per row (uses `d` rather than a
# hard-coded 2 so the cell follows the space dimension).
print(GD_Final.view(-1, d))
print(MOM_Final.view(-1, d))
print(Controls.view(-1, d))
print("========================")
print(GD_Final_Id.view(-1, d))
print(MOM_Final_Id.view(-1, d))
print(Controls_Id.view(-1, d))

In [None]:
# Regular nx-by-ny evaluation grid covering [-sx/2, sx/2) x [-sy/2, sy/2).
nx, ny = 100, 100
sx, sy = 10, 10
x, y = torch.meshgrid([torch.arange(0, nx), torch.arange(0, ny)])
x = sx*(x.type(torch.FloatTensor)/nx - 0.5)
y = sy*(y.type(torch.FloatTensor)/ny - 0.5)
# Flattened grid points, computed once and shared by both field evaluations.
gridPts = fun.grid2vec(x, y).type(torch.FloatTensor)

# The field at the final state must be generated by the controls of that same
# state; the `Controls` from the earlier cell were computed at t = 0 and would
# mix time points when combined with GD_Final.
Controls_Final = Hamiltonian.Cont_geo(GD_Final, MOM_Final)
Controls_Final_Id = Hamiltonian_Id.Cont_geo(GD_Final_Id, MOM_Final_Id)

# Each field is generated by its own module (the identical-cost run uses Trans_Id).
u, v = fun.vec2grid(Trans(GD_Final, Controls_Final, gridPts), nx, ny)
u_id, v_id = fun.vec2grid(Trans_Id(GD_Final_Id, Controls_Final_Id, gridPts), nx, ny)


In [None]:
%matplotlib qt5
plt.subplot(1, 2, 1)
plt.title('Real cost')
plt.quiver(x.numpy(), y.numpy(), u.detach().numpy(), v.detach().numpy())
plt.subplot(1, 2, 2)
plt.title('Identical cost')
plt.quiver(x.numpy(), y.numpy(), u_id.detach().numpy(), v_id.detach().numpy())
plt.show()