In [None]:
import matplotlib.pyplot as plt
import torch
from scipy.integrate import solve_ivp
from torch import nn
import numpy as np
%matplotlib inline

In [None]:
from functorch import vmap, vjp
from functorch import jacrev, jacfwd

class NNApproximator(nn.Module):
  """Fully connected MLP mapping an input vector to an output vector.

  In this notebook it is used with 6 inputs (three (x, y) position pairs) and
  2 outputs (a 2-D force).

  Args:
    dim_input: size of one input sample.
    dim_output: size of one output sample.
    num_hidden: number of hidden activations, i.e. ``layer_in`` followed by
      ``num_hidden - 1`` square middle layers.
    dim_hidden: width of every hidden layer.
    activation: module applied after ``layer_in`` and after each middle layer
      (not after ``layer_out``). Defaults to a fresh ``nn.Tanh()`` per
      instance, so no module object is shared between models.
  """

  def __init__(self, dim_input = 1, dim_output = 2, num_hidden = 2, dim_hidden = 1, activation=None):
    super().__init__()

    self.layer_in = nn.Linear(dim_input, dim_hidden)
    self.layer_out = nn.Linear(dim_hidden, dim_output)
    # Learnable scalar (nn.Parameter already sets requires_grad=True).
    # NOTE(review): not used by forward(); compute_PINN_loss takes k explicitly.
    self.k = nn.Parameter(torch.rand(1))

    num_middle = num_hidden - 1
    self.middle_layers = nn.ModuleList(
        [nn.Linear(dim_hidden, dim_hidden) for _ in range(num_middle)]
    )
    self.activation = nn.Tanh() if activation is None else activation

  def forward(self, x):
    """Apply layer_in -> (middle layers) -> layer_out, activating hidden layers."""
    out = self.activation(self.layer_in(x))
    for layer in self.middle_layers:
      out = self.activation(layer(out))
    return self.layer_out(out)

  # reference for implementing derivatives for batched inputs
  # https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html
  def jacobian(self, x):
    """Per-sample Jacobian of ``forward`` with respect to its input.

    ``x`` is mapped over its leading (batch) axis, so for ``x`` of shape
    (batch, dim_input) the raw result is (batch, dim_output, dim_input).
    Singleton feature axes are removed, but the batch axis is always kept,
    so a batch of one still yields a leading axis of size 1.
    """
    jac = vmap(jacrev(self.forward))(x)
    # Equivalent to .squeeze() on the feature axes only: a plain .squeeze()
    # would also drop the batch axis when batch == 1 and break jac[:, i, j].
    feature_shape = [d for d in jac.shape[1:] if d != 1]
    return jac.reshape(jac.shape[0], *feature_shape)

In [None]:
def compute_data_loss(model, x_tr, y_tr):
  """Half mean-squared error of ``model.forward(x_tr)`` against ``y_tr``.

  Returns a scalar tensor equal to 0.5 * mean((model.forward(x_tr) - y_tr) ** 2).
  """
  residual = model.forward(x_tr) - y_tr
  return residual.pow(2).mean() * 0.5

def _stiffness_residual(jac_block, offset, k):
    """Mean over the batch of (k - ||jac_block @ offset|| / ||offset||) ** 2."""
    response = torch.einsum('ijk,ik->ij', jac_block, offset)
    gain = torch.norm(response, dim=1) / torch.norm(offset, dim=1)
    return ((k - gain) ** 2).mean()


def compute_PINN_loss(model, x, k):
    """Physics penalty on the input-Jacobian of ``model``.

    ``x`` is (batch, 6): three (x, y) pairs, the middle one at columns 2:4.
    For each outer pair, the offset from the middle pair is pushed through the
    matching 2x2 Jacobian block (columns 0:2 and 4:6), and the ratio of norms
    is pulled towards ``k``. The two squared gaps, each averaged over the
    batch, are summed.

    NOTE(review): presumably this encodes "force responds to a neighbour like a
    linear spring of stiffness k"; the centre block (columns 2:4) is not
    constrained. Assumes model.jacobian(x) has shape (batch, 2, 6).
    """
    jac = model.jacobian(x)
    center = x[:, 2:4]
    to_left = x[:, 0:2] - center
    to_right = x[:, 4:6] - center
    return (_stiffness_residual(jac[:, :, 0:2], to_left, k)
            + _stiffness_residual(jac[:, :, 4:6], to_right, k))

In [None]:
# Network / training hyper-parameters.
in_dim = 6          # == len(xy_index): three (x, y) position pairs
out_dim = 2         # == F.shape[1]: 2-D force
num_layer = 5
hidden_dim = 32
# Data columns fed to the network, as (left pair, middle pair, right pair).
# NOTE(review): presumably 13-14 = fixed left end point, 1-2 = first mass,
# 5-6 = second mass; confirm against the generator of data_set_bis.npy.
xy_index = [13, 14, 1, 2, 5, 6]
epoch = 20000       # number of training epochs
lr = 1e-4
SEED = 0

# Seed torch so weight init (and NNApproximator.k) is reproducible on Restart & Run All.
torch.manual_seed(SEED)

data_file = '../data_set_bis.npy'

x = torch.from_numpy(np.load(data_file)).float()
# Copy of the data used for the physics-informed loss.
x_PINN = x.clone().requires_grad_(True)
# Training targets; columns 3 and 4 are plotted below as "true force - x / - y".
F = x[:, [3, 4]]

In [None]:
def train_model(model, data_loss_fn, PINN_loss_fn, learning_rate=0.0001, max_epochs=1000,
                use_bfgs=False, eval_loss_fn=None):
  """Optimise ``model`` on ``data_loss_fn(model) + PINN_loss_fn(model)``.

  Args:
    model: module whose parameters are optimised.
    data_loss_fn: callable ``model -> scalar tensor`` (supervised term).
    PINN_loss_fn: callable ``model -> scalar tensor`` (physics term).
    learning_rate: Adam learning rate (LBFGS uses its own default lr).
    max_epochs: number of optimiser steps.
    use_bfgs: use ``torch.optim.LBFGS`` instead of Adam (weight_decay=1e-5).
    eval_loss_fn: optional callable ``model -> scalar tensor`` printed every
      ``log_iter`` epochs and evaluated without gradient tracking. Defaults to
      the data loss over the whole dataset, which reads the notebook globals
      ``x``, ``xy_index`` and ``F``.

  Returns:
    ``(model, losses)`` where ``losses[i]`` is the objective evaluated at the
    start of epoch ``i``, before that epoch's parameter update.
  """
  # reference on torch.LBFGS usage
  # https://gist.github.com/tuelwer/0b52817e9b6251d940fd8e2921ec5e20
  if eval_loss_fn is None:
    def eval_loss_fn(m):
      # Globals are looked up at call time, as in the original notebook.
      return compute_data_loss(m, x[:, xy_index], F)

  if use_bfgs:
    optimizer = torch.optim.LBFGS(model.parameters())
    print("Using BFGS optimizer ... ")
    log_iter = 10

    def closure():
      optimizer.zero_grad()
      objective = data_loss_fn(model) + PINN_loss_fn(model)
      objective.backward()
      return objective
  else:
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    print("Using Adam optimizer ... ")
    log_iter = 1000

  tr_losses = []
  for epoch in range(max_epochs):
    if use_bfgs:
      # LBFGS.step returns the loss of its first closure call (the pre-update
      # objective), so no extra, never-backpropagated forward pass is needed.
      loss = optimizer.step(closure)
    else:
      loss = data_loss_fn(model) + PINN_loss_fn(model)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    if epoch % log_iter == 0:
      print(f"Epoch: {epoch} - Loss: {float(loss):>7f}")
      # Monitoring only: don't build an autograd graph over the full dataset.
      with torch.no_grad():
        full_loss = eval_loss_fn(model)
      print(f"Loss over entire dataset: {float(full_loss.detach()):>7f}")
    tr_losses.append(loss.detach().cpu().numpy())

  return model, np.array(tr_losses)

In [None]:
# The data loss uses only the first N_TRAIN rows; the physics term uses every row.
N_TRAIN = 500           # rows used by the supervised loss
PINN_WEIGHT = 1e-1      # weight of the physics-informed term
SPRING_STIFFNESS = 50   # k passed to compute_PINN_loss


def _data_loss(net):
    return compute_data_loss(net, x[:N_TRAIN, xy_index], F[:N_TRAIN])


def _pinn_loss(net):
    return PINN_WEIGHT * compute_PINN_loss(net, x_PINN[:, xy_index], SPRING_STIFFNESS)


model = NNApproximator(
    dim_input=in_dim,
    dim_output=out_dim,
    num_hidden=num_layer,
    dim_hidden=hidden_dim,
)
model, tr_losses = train_model(
    model,
    data_loss_fn=_data_loss,
    PINN_loss_fn=_pinn_loss,
    learning_rate=lr,
    max_epochs=epoch,
)

In [None]:
def getqdot(x, xdot, model, xLeft, xRight):
    """Right-hand side of the first-order ODE q' = (x', x'') for solve_ivp.

    Args:
        x: flat array of the movable masses' coordinates, two per mass
            (length 2 * N_m).
        xdot: flat array of their velocities, same layout as ``x``.
        model: object whose ``forward`` maps a length-6 float32 tensor (three
            consecutive coordinate pairs) to a length-2 tensor.
        xLeft, xRight: coordinate pairs padded before / after ``x``.

    Returns:
        ``np.concatenate((xdot, forces))``, where ``forces`` stacks the model
        output for each movable mass. NOTE(review): the outputs are used
        directly as accelerations, i.e. unit masses are assumed.
    """
    n_movable = x.shape[0] // 2

    # Pad with the end points so mass i sees the window positions[2i : 2i + 6]
    # = (previous pair, its own pair, next pair).
    positions = np.concatenate((xLeft, x, xRight))
    forces = []
    # Inference only: avoid building an autograd graph on every RHS evaluation.
    with torch.no_grad():
        for i in range(n_movable):
            triplet = positions[2 * i:2 * i + 6]
            force = model.forward(torch.from_numpy(triplet).float())
            forces.append(force.detach().numpy())

    # np.concatenate([]) raises, so handle a chain with no movable masses.
    forces = np.concatenate(forces).reshape(-1) if forces else np.empty(0)
    return np.concatenate((xdot, forces))

def compute_trajectory(x0, x0dot, model, xLeft, xRight, t0=0, tf=20, Nt=101):
    """Integrate the mass chain from (x0, x0dot) using getqdot as the RHS.

    Args:
        x0, x0dot: flat initial positions / velocities of the movable masses.
        model, xLeft, xRight: forwarded unchanged to ``getqdot``.
        t0, tf: integration interval (defaults match the original 0..20).
        Nt: number of evenly spaced output times in [t0, tf].

    Returns:
        ``(y, t)`` = ``(sol.y, sol.t)``. Rows of ``y`` are the positions followed
        by the velocities; columns correspond to the times in ``t``.

    Warns:
        UserWarning if solve_ivp reports failure. solve_ivp does not raise in
        that case; it returns a solution truncated at the failure time.
    """
    import warnings

    n_movable = x0.shape[0] // 2
    q0 = np.concatenate((x0, x0dot))

    def rhs(t, q):
        return getqdot(q[:2 * n_movable], q[2 * n_movable:], model, xLeft, xRight)

    sol = solve_ivp(rhs, [t0, tf], y0=q0, t_eval=np.linspace(t0, tf, Nt))
    if not sol.success:
        warnings.warn(f"solve_ivp did not succeed: {sol.message}")
    return sol.y, sol.t

# Initial state for the rollout, taken from the first row of the dataset (modify as appropriate).
# NOTE(review): columns 13-14 / 15-16 are assumed to be the fixed left / right end
# points and (1, 2), (5, 6), (9, 10) the coordinates of the movable masses; confirm
# against the data generator.
xLeft = x[0, 13:15].numpy()
xRight = x[0, 15:17].numpy()

# Coordinates of the *movable* masses, flattened two per mass: x0.shape == (N_m * 2,).
x0 = x[0, [1, 2, 5, 6, 9, 10]].numpy()
x0dot = np.zeros_like(x0)   # start from rest
N_m = x0.shape[0] // 2      # derived from x0 so it cannot drift out of sync (3 here)

y, t = compute_trajectory(x0, x0dot, model, xLeft, xRight)

In [None]:
# `y`/`t` come from compute_trajectory(..., model, ...), i.e. the rollout of the
# learned force model; `x` is the recorded data. Compare row 1 of the rollout
# state (second coordinate of the first movable mass, data column 2).
# NOTE(review): data column 0 is assumed to be time.
fig, ax = plt.subplots()
ax.plot(t, y[1], label='predicted trajectory')
ax.plot(x[:, 0], x[:, 2], label='true trajectory')
ax.set(xlabel='t', ylabel='position (data column 2)', title='Rollout vs data')
ax.legend();

In [None]:
# Model prediction on the same input columns it was trained on (xy_index).
with torch.no_grad():
    F_pred = model.forward(x[:, xy_index]).numpy()

fig, ax = plt.subplots()
ax.plot(x[:, 0], x[:, 3], label='true force - x')
ax.plot(x[:, 0], F_pred[:, 0], label='predicted force - x')
ax.set(xlabel='t', ylabel='force - x')
ax.legend();

In [None]:
# Model prediction on the same input columns it was trained on (xy_index).
with torch.no_grad():
    F_pred = model.forward(x[:, xy_index]).numpy()

fig, ax = plt.subplots()
ax.plot(x[:, 0], x[:, 4], label='true force - y')
ax.plot(x[:, 0], F_pred[:, 1], label='predicted force - y')
ax.set(xlabel='t', ylabel='force - y')
ax.legend();