In [17]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy import interpolate
import time
import torch
import torch.nn.functional as F
from torch.optim import LBFGS



# Disabled initialisation snippet (a=0, n_grid=6). It is kept inside a bare
# string literal, so it never executes. To use it, n_grid must be defined first:
# the snippet reads n_grid but does not assign it. It builds small initial values
# 0.001*(r + theta) laid out in the [f1 | f2 | f3 | f4] column blocks used by
# decompose_free. It then flattens them into a trainable torch.nn.Parameter.
'''############Initialize a=0, n_grid=6##########
theta_start = 0
theta_end = np.pi/2
r_start = 3
r_end = 4
thetas = torch.linspace(theta_start,theta_end,steps=n_grid, dtype=torch.double)
rs = torch.linspace(r_start,r_end,steps=n_grid, dtype=torch.double)
f_free = torch.zeros((n_grid,4*n_grid-5), dtype=torch.double)
f_free[:,:n_grid-1] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,:-1])
f_free[:,n_grid-1:2*(n_grid-1)] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,:-1])
f_free[:,2*(n_grid-1):3*n_grid-4] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,1:-1])
f_free[:,3*n_grid-4:4*n_grid-5] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,:-1])
f_free = f_free.reshape(-1,)
f_free = torch.nn.Parameter(f_free, requires_grad=True)
###############################################'''

# Resampling settings read (at call time) by interp_free / interp_free_test.
interp_mode, align_corners = 'bilinear', True
# Module-level mass parameter. NOTE: train() defines its own local M, and that
# local is the one its metric function uses.
M = 1.0

# t' = t + f1(r,theta)
def f1(f1_free, n_grid):
    """Expand the free values of f1 (t' = t + f1) to a full n_grid x n_grid grid.

    Rows index r and columns index theta. ``f1_free`` fills the first
    ``n_grid - 1`` theta columns. The last column (theta = pi/2) repeats the
    column before it, so the finite-difference theta-derivative there is zero.
    The result is float64.
    """
    full = torch.zeros((n_grid, n_grid), dtype=torch.double)
    last = n_grid - 1
    full[:, :last] = f1_free
    full[:, last] = f1_free[:, -1]
    return full

# r' = r + f2(r,theta)
def f2(f2_free, n_grid):
    """Expand the free values of f2 (r' = r + f2) to a full n_grid x n_grid grid.

    ``f2_free`` has shape (n_grid, n_grid - 1) and fills every theta column
    except the last. The last column copies the preceding one, which gives a
    zero finite-difference theta-derivative at theta = pi/2. The result is
    float64.
    """
    grid = torch.zeros(n_grid, n_grid, dtype=torch.double)
    grid[:, :-1], grid[:, -1] = f2_free, f2_free[:, -1]
    return grid

# theta' = theta + f3(r,theta)
def f3(f3_free, n_grid):
    """Expand the free values of f3 (theta' = theta + f3) to a full n_grid x n_grid grid.

    ``f3_free`` has shape (n_grid, n_grid - 2) and fills the interior theta
    columns. The first column (theta = 0) and last column (theta = pi/2) stay
    zero. The result is float64.
    """
    padded = torch.zeros((n_grid, n_grid), dtype=torch.double)
    padded[:, 1:n_grid - 1] = f3_free
    return padded

# phi' = phi + f4(r,theta)
def f4(f4_free, n_grid):
    """Expand the free values of f4 (phi' = phi + f4) to a full n_grid x n_grid grid.

    This uses the same layout as f1/f2. ``f4_free`` has shape
    (n_grid, n_grid - 1) and covers all theta columns but the last. The last
    column duplicates its neighbour, giving a zero finite-difference
    theta-derivative at theta = pi/2. The result is float64.
    """
    result = torch.zeros(n_grid, n_grid, dtype=torch.double)
    result[:, :-1] = f4_free
    result[:, -1] = f4_free[:, -1]  # mirror the neighbouring column
    return result


def interp_free(f_free, n_grid, mode="0"):
    """Resample one component's free values onto an n_grid x n_grid grid.

    The source grid size is taken from ``f_free.shape[0]``. The component is
    first expanded to its full grid, then resized with ``F.interpolate``
    using the module-level ``interp_mode`` / ``align_corners``, and finally
    stripped back to its free columns.

    Args:
        f_free: free values of one component.
        n_grid: target grid size.
        mode: ``"0"`` for f1/f2/f4-style components (expanded with ``f1``; the
            last theta column is dropped afterwards). Any other value selects
            the f3 style (expanded with ``f3``; the first and last theta
            columns are dropped).

    Returns:
        Tensor of shape (n_grid, n_grid - 1) for mode "0", else
        (n_grid, n_grid - 2).
    """
    source_size = f_free.shape[0]
    if mode == "0":
        expand, keep_cols = f1, slice(None, -1)
    else:
        expand, keep_cols = f3, slice(1, -1)
    batched = expand(f_free, source_size)[None, None]  # add batch and channel dims
    resized = F.interpolate(batched, size=(n_grid, n_grid), mode=interp_mode,
                            align_corners=align_corners)
    return resized[0, 0, :, keep_cols]

def interp_free_test(f_free, n_grid, mode="0"):
    """Sample one component's free values at the midpoints of the n_grid cells.

    The full n_grid x n_grid field is resized to (2*n_grid - 1) squared. With
    the module defaults (bilinear, ``align_corners=True``), odd indices of the
    resized grid then fall on cell midpoints. Rows ``1:-1:2`` give the
    n_grid - 1 midpoints in r. The kept columns are the free columns of an
    (n_grid - 1)-point grid:

    - mode ``"0"`` (f1/f2/f4 style): columns ``1:-3:2``, i.e. n_grid - 2 values.
    - otherwise (f3 style): columns ``3:-3:2``, i.e. n_grid - 3 values.
    """
    fine_size = 2 * n_grid - 1
    if mode == "0":
        full_field, col_start = f1(f_free, n_grid), 1
    else:
        full_field, col_start = f3(f_free, n_grid), 3
    refined = F.interpolate(full_field[None, None], size=(fine_size, fine_size),
                            mode=interp_mode, align_corners=align_corners)
    return refined[0, 0, 1:-1:2, col_start:-3:2]

def decompose_free(f_free, n_grid):
    """Split packed parameters into the free parts of f1, f2, f3 and f4.

    ``f_free`` is 2-D with column blocks of widths n_grid-1 (f1), n_grid-1 (f2),
    n_grid-2 (f3) and n_grid-1 (f4), i.e. 4*n_grid - 5 columns in total.
    Returns a 4-tuple of column slices (views of ``f_free``).
    """
    width = n_grid - 1
    end_f1 = width
    end_f2 = 2 * width
    end_f3 = 3 * width - 1
    end_f4 = 4 * width - 1
    return (
        f_free[:, :end_f1],
        f_free[:, end_f1:end_f2],
        f_free[:, end_f2:end_f3],
        f_free[:, end_f3:end_f4],
    )

def compose_free(f1_free,f2_free,f3_free,f4_free):
    """Pack the four free-parameter blocks side by side (inverse of decompose_free)."""
    blocks = (f1_free, f2_free, f3_free, f4_free)
    return torch.cat(blocks, 1)

def interp_f_free(f_free, n_grid, n_grid_old):
    """Resample packed deformation parameters from an n_grid_old grid to n_grid.

    ``f_free`` may be the packed 2-D tensor of shape
    ``(n_grid_old, 4*n_grid_old - 5)`` or the same values flattened (e.g. a 1-D
    ``torch.nn.Parameter``). It is reshaped first, as ``interp_f_free_test``
    does; previously a flattened input raised IndexError in ``decompose_free``.

    Each component is resized by ``interp_free``: f1, f2 and f4 use mode "0",
    and f3 uses mode "1".

    Returns:
        Tensor of shape ``(n_grid, 4*n_grid - 5)`` in the same column layout.
    """
    f_free = f_free.reshape(n_grid_old, 4*n_grid_old - 5)
    f1_free, f2_free, f3_free, f4_free = decompose_free(f_free, n_grid_old)
    f1_free_new = interp_free(f1_free, n_grid, mode="0")
    f2_free_new = interp_free(f2_free, n_grid, mode="0")
    f3_free_new = interp_free(f3_free, n_grid, mode="1")
    f4_free_new = interp_free(f4_free, n_grid, mode="0")
    return compose_free(f1_free_new, f2_free_new, f3_free_new, f4_free_new)


def interp_f_free_test(f_free, n_grid):
    """Evaluate packed parameters at the cell midpoints of the n_grid grid.

    ``f_free`` holds n_grid * (4*n_grid - 5) values in the packed layout of
    ``decompose_free``, flattened or 2-D. Each component goes through
    ``interp_free_test``: f1, f2 and f4 use mode "0", and f3 uses mode "1".
    The results are packed again with ``compose_free`` as a 2-D tensor for the
    (n_grid - 1)-point grid.
    """
    components = decompose_free(f_free.reshape(n_grid, 4*n_grid - 5), n_grid)
    component_modes = ("0", "0", "1", "0")
    resampled = [interp_free_test(part, n_grid, mode=m)
                 for part, m in zip(components, component_modes)]
    return compose_free(*resampled)





def train(a, n_grid, f_free, *, epochs=15000, switch_epoch=3000, log=100):
    """Fit the coordinate deformation ``f_free`` for spin ``a`` with Adam.

    The primed coordinates are t' = t + f1, r' = r + f2, theta' = theta + f3
    and phi' = phi + f4. The fields f1..f4 live on an n_grid x n_grid grid over
    r in [3, 4] and theta in [0, pi/2].

    The loss has two terms. Each is the mean squared difference between the
    spatial block of the transformed metric ``w^-T g w^-1`` (``w`` being the
    finite-difference Jacobian) and the flat target
    diag(1, r'^2, r'^2 sin^2 theta'):

    - one term on the training grid;
    - one term on the grid of cell midpoints.

    In both terms the first and last r rows are excluded.

    Args:
        a: Spin parameter of the metric built by the inner ``g`` (M = 1).
        n_grid: Number of grid points along r and along theta.
        f_free: Leaf tensor that requires grad (e.g. ``torch.nn.Parameter``).
            It holds n_grid * (4*n_grid - 5) values in the packed layout of
            ``decompose_free`` and is updated in place by the optimiser.
        epochs: Number of optimisation steps.
        switch_epoch: The learning rate is halved every ``switch_epoch`` steps.
        log: The loss is printed every ``log`` steps.

    Returns:
        A tuple ``(best_free, best_loss, duration, losses)``:

        - ``best_free``: a detached copy of the parameters that produced the
          lowest loss.
        - ``best_loss``: that loss as a 0-d numpy array.
        - ``duration``: wall-clock seconds.
        - ``losses``: the per-step losses as 0-d numpy arrays.

    Raises:
        ValueError: if ``switch_epoch`` or ``log`` is not positive.
    """
    if switch_epoch <= 0 or log <= 0:
        raise ValueError("switch_epoch and log must be positive")

    theta_start = 0
    theta_end = np.pi/2
    r_start = 3
    r_end = 4
    M = 1  # mass used by g() below; shadows the module-level M

    thetas = torch.linspace(theta_start, theta_end, steps=n_grid, dtype=torch.double)
    rs = torch.linspace(r_start, r_end, steps=n_grid, dtype=torch.double)
    theta_h = (theta_end - theta_start)/(n_grid-1)
    r_h = (r_end - r_start)/(n_grid-1)

    # Training points: every (r, theta) node in r-major order, shape (n_grid**2, 2).
    RS, THETAS = torch.meshgrid(rs, thetas)
    z = torch.transpose(torch.stack([RS.reshape(-1,), THETAS.reshape(-1,)]), 0, 1)

    # Test points: the (n_grid-1)**2 cell midpoints. They have the same spacing
    # as the training grid, so r_h / theta_h serve both.
    rs_test = (rs[1:] + rs[:-1])/2
    thetas_test = (thetas[1:] + thetas[:-1])/2
    RS_test, THETAS_test = torch.meshgrid(rs_test, thetas_test)
    z_test = torch.transpose(torch.stack([RS_test.reshape(-1,), THETAS_test.reshape(-1,)]), 0, 1)

    def r_derivative(f, n_grid):
        # Central differences along r (dim 0). The ghost rows are linear
        # extrapolations, which makes the end rows one-sided differences.
        f_aug = torch.zeros(n_grid+2, n_grid, dtype=torch.double)
        f_aug[1:-1] = f
        f_aug[0] = 2*f[0] - f[1]
        f_aug[-1] = 2*f[-1] - f[-2]
        return (f_aug[2:] - f_aug[:-2])/(2*r_h)

    def theta_derivative(f, n_grid):
        # Same scheme along theta (dim 1).
        f_aug = torch.zeros(n_grid, n_grid+2, dtype=torch.double)
        f_aug[:, 1:-1] = f
        f_aug[:, 0] = 2*f[:, 0] - f[:, 1]
        f_aug[:, -1] = 2*f[:, -1] - f[:, -2]
        return (f_aug[:, 2:] - f_aug[:, :-2])/(2*theta_h)

    def w(f1, f2, f3, f4, n_grid):
        # Per-point Jacobian d(t', r', theta', phi')/d(t, r, theta, phi), shape (N, 4, 4).
        f1_r = r_derivative(f1, n_grid).reshape(-1,)
        f2_r = r_derivative(f2, n_grid).reshape(-1,)
        f3_r = r_derivative(f3, n_grid).reshape(-1,)
        f4_r = r_derivative(f4, n_grid).reshape(-1,)
        f1_theta = theta_derivative(f1, n_grid).reshape(-1,)
        f2_theta = theta_derivative(f2, n_grid).reshape(-1,)
        f3_theta = theta_derivative(f3, n_grid).reshape(-1,)
        f4_theta = theta_derivative(f4, n_grid).reshape(-1,)
        ones = torch.ones(f1_r.shape[0], dtype=torch.double)

        stack1 = torch.stack([ones, f1_r, f1_theta, 0*ones])
        stack2 = torch.stack([0*ones, 1+f2_r, f2_theta, 0*ones])
        stack3 = torch.stack([0*ones, f3_r, 1+f3_theta, 0*ones])
        stack4 = torch.stack([0*ones, f4_r, f4_theta, ones])
        return torch.stack([stack1, stack2, stack3, stack4]).permute(2, 0, 1)

    def w_inv_invt(w):
        w_inv = torch.linalg.inv(w)
        return w_inv, w_inv.permute(0, 2, 1)

    def gp(g, w):
        # Transform the covariant metric into primed coordinates: w^-T g w^-1.
        w_inv, w_invt = w_inv_invt(w)
        return torch.matmul(torch.matmul(w_invt, g), w_inv)

    def zp(z, f2, f3):
        # Primed (r', theta') at each point.
        rp = z[:, 0] + f2.reshape(-1,)
        thetap = z[:, 1] + f3.reshape(-1,)
        return torch.transpose(torch.stack([rp, thetap]), 0, 1)

    def g(x_, a=0.0):
        # Kerr-form metric components (Boyer-Lindquist-style) at points (r, theta),
        # shape (N, 4, 4). torch.cos is used (instead of np.cos) to match the
        # torch.sin calls below and to keep the computation inside torch.
        r = x_[:, 0]
        theta = x_[:, 1]
        bs = x_.shape[0]
        Sigma = r**2 + a**2*torch.cos(theta)**2
        Delta = r**2 - 2*M*r + a**2
        one = torch.ones(bs, dtype=torch.double)
        g01 = g02 = g10 = g12 = g13 = g20 = g21 = g23 = g31 = g32 = 0*one
        g00 = -(1-2*M*r/Sigma)
        g03 = g30 = -2*M*a*r*torch.sin(theta)**2/Sigma
        g11 = Sigma/Delta
        g22 = Sigma
        g33 = (r**2+a**2+2*M*a**2*r*torch.sin(theta)**2/Sigma)*torch.sin(theta)**2
        stack1 = torch.stack([g00, g01, g02, g03])
        stack2 = torch.stack([g10, g11, g12, g13])
        stack3 = torch.stack([g20, g21, g22, g23])
        stack4 = torch.stack([g30, g31, g32, g33])
        return torch.stack([stack1, stack2, stack3, stack4]).permute(2, 0, 1)

    def gp_space_target(zp):
        # Flat spatial metric in spherical coordinates: diag(1, r'^2, r'^2 sin^2 theta').
        bs = zp.shape[0]
        one = torch.ones(bs, dtype=torch.double)
        g11 = one
        g12 = g13 = g21 = g23 = g31 = g32 = 0*one
        g22 = zp[:, 0]**2
        g33 = zp[:, 0]**2*torch.sin(zp[:, 1])**2
        stack1 = torch.stack([g11, g12, g13])
        stack2 = torch.stack([g21, g22, g23])
        stack3 = torch.stack([g31, g32, g33])
        return torch.stack([stack1, stack2, stack3]).permute(2, 0, 1)

    def grid_error(f_free, m, points):
        # Mean squared spatial-metric mismatch of an m x m parameter grid at `points`.
        f_free = f_free.reshape(m, 4*m-5)
        f1_free, f2_free, f3_free, f4_free = decompose_free(f_free, m)
        f1_ = f1(f1_free, m)
        f2_ = f2(f2_free, m)
        f3_ = f3(f3_free, m)
        f4_ = f4(f4_free, m)
        g_ = g(points, a=a)
        w_ = w(f1_, f2_, f3_, f4_, m)
        zp_ = zp(points, f2_, f3_)
        gp_space = gp(g_, w_)[:, 1:, 1:].reshape(m, m, 3, 3)
        gp_space_target_ = gp_space_target(zp_).reshape(m, m, 3, 3)
        # The first and last r rows are excluded (their r-derivatives are one-sided).
        return torch.mean((gp_space-gp_space_target_)[1:-1, :]**2)

    def error(f_free, n_grid):
        return grid_error(f_free, n_grid, z)

    def error_test(f_free, n_grid):
        return grid_error(f_free, n_grid-1, z_test)

    def error_all(f_free, n_grid):
        f_free_test = interp_f_free_test(f_free, n_grid).reshape(-1,)
        return error(f_free, n_grid) + error_test(f_free_test, n_grid)

    start = time.time()
    lr = 1e-2*(6/n_grid)**2
    opt = torch.optim.Adam([f_free], lr=lr, eps=1e-8)

    # Start from the input so the function always has something to return,
    # even if the loss never improves (e.g. NaN) or epochs == 0.
    best_loss = torch.tensor(1e20, dtype=torch.double)
    best_free = f_free.detach().clone()
    losses = []
    for i in range(epochs):
        if (i+1) % switch_epoch == 0:
            # Halve once per switch, regardless of the number of param groups.
            lr = lr * 0.5
            for opt_param in opt.param_groups:
                opt_param['lr'] = lr

        # Single forward/backward per step. Passing a closure to opt.step would
        # recompute the identical loss and gradients a second time.
        opt.zero_grad()
        loss = error_all(f_free, n_grid)
        loss.backward()
        loss_value = loss.detach()  # drop the graph before storing the value

        # Snapshot before the update so best_free is the point that produced best_loss.
        if loss_value < best_loss:
            best_loss = loss_value
            best_free = f_free.detach().clone()

        opt.step()

        if i % log == 0:
            print("Epoch: {}".format(i) + " | " + "Loss: {}".format(loss_value.numpy()))
        losses.append(loss_value.numpy())
    end = time.time()
    duration = end - start
    print("loss={}".format(best_loss.numpy()))
    print("time={}".format(duration))
    return best_free, best_loss.numpy(), duration, losses