# In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy import interpolate
import time
import torch
import torch.nn.functional as F
from torch.optim import LBFGS
import scipy

# Time scaling with respect to n_grid !!!


# Changes from the first version:
# (1) theta derivatives f_theta now include the 1/r metric factor.
# (2) theta range extended from [0, pi/2] to [0, pi]. Boundary conditions:
#     none for f1, f2, f4, f5; f3 is 0 at theta = 0 and theta = pi.
# (3) phi' = phi + a*t, where a is a tunable parameter set to a = average(f5).

# Derivatives still use linear extrapolation at the boundary. If extrapolation
# is removed, the weights in error_all must also be changed.



# NOTE(review): the triple-quoted string below is a bare expression and is
# never executed. It is stale initialization code from the earlier
# [0, pi/2] version: it uses a 4*n_grid-5 column layout (decompose_free now
# expects 5*n_grid-2) and references an undefined n_grid. Reference only.
'''############Initialize a=0, n_grid=6##########
theta_start = 0
theta_end = np.pi/2
r_start = 3
r_end = 4
thetas = torch.linspace(theta_start,theta_end,steps=n_grid, dtype=torch.double)
rs = torch.linspace(r_start,r_end,steps=n_grid, dtype=torch.double)
f_free = torch.zeros((n_grid,4*n_grid-5), dtype=torch.double)
f_free[:,:n_grid-1] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,:-1])
f_free[:,n_grid-1:2*(n_grid-1)] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,:-1])
f_free[:,2*(n_grid-1):3*n_grid-4] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,1:-1])
f_free[:,3*n_grid-4:4*n_grid-5] = 0.001*(rs[:,np.newaxis] + thetas[np.newaxis,:-1])
f_free = f_free.reshape(-1,)
f_free = torch.nn.Parameter(f_free, requires_grad=True)
###############################################'''

# Make float64 the default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float64)

# Resampling settings read as globals: interp_free_test uses interp_mode,
# interp_free uses interp_mode2; both use align_corners.
interp_mode = 'bilinear'
interp_mode2 = 'bilinear'
align_corners = True
# Mass parameter. NOTE: train() defines its own local M = 1 and does not
# read this global.
M = 1.0

# t' = t + f1(r, theta)
def f1(f1_free, n_grid):
    """Return the time-shift field f1 on the full (r, theta) grid.

    f1 has no boundary condition, so every grid value is a free parameter:
    ``f1_free`` (expected shape ``(n_grid, n_grid)``) is returned unchanged.
    ``n_grid`` is accepted for a signature uniform with f3 and is unused.
    """
    field = f1_free
    return field

# r' = r + f2(r, theta)
def f2(f2_free, n_grid):
    """Return the radial-shift field f2 on the full (r, theta) grid.

    f2 has no boundary condition, so ``f2_free`` (expected shape
    ``(n_grid, n_grid)``) already is the full field and is returned as-is.
    ``n_grid`` is accepted for a signature uniform with f3 and is unused.
    """
    field = f2_free
    return field

# theta' = theta + f3(r, theta)
def f3(f3_free, n_grid):
    """Embed the free interior values of f3 into a full (n_grid, n_grid) grid.

    f3 must vanish at theta = 0 and theta = pi (the first and last theta
    columns), so only interior columns are free: ``f3_free`` is expected to
    have shape ``(n_grid, n_grid - 2)``. The result is a float64 tensor whose
    two boundary columns are zero.
    """
    full = torch.zeros(n_grid, n_grid, dtype=torch.double)
    interior = slice(1, -1)  # every theta column except the two boundaries
    full[:, interior] = f3_free
    return full

# phi' = phi + f4(r,theta) + average(f5)*t
def f4(f4_free, n_grid):
    """Return the azimuthal-shift field f4 on the full (r, theta) grid.

    f4 has no boundary condition, so ``f4_free`` is returned unchanged.
    Its shape is (n_grid, n_grid): decompose_free allocates f4 exactly
    n_grid columns (columns 3*n_grid-2 .. 4*n_grid-3 of the packed matrix).
    ``n_grid`` is accepted for a signature uniform with f3 and is unused.
    """
    return f4_free

# f5 enters only through its mean a = average(f5), the rate in
# phi' = phi + f4(r,theta) + a*t.
def f5(f5_free, n_grid):
    """Return the field f5 whose mean sets the azimuthal rate ``a``.

    f5 has no boundary condition, so ``f5_free`` is returned unchanged.
    Its shape is (n_grid, n_grid): decompose_free allocates f5 exactly
    n_grid columns (the last n_grid columns of the packed matrix).
    ``n_grid`` is accepted for a signature uniform with f3 and is unused.
    """
    return f5_free


def free2pad_(f_free, n_grid, mode="0"):
    """Return one component laid out on the full (n_grid, n_grid) grid.

    mode "0": the component has no boundary condition; ``f_free`` already
        covers the full grid and is returned as-is (the same object).
    any other mode: ``f_free`` holds only the interior theta columns; it is
        copied into columns 1 .. n_grid-2 of a float64 zero grid, so the
        first and last columns stay 0.
    """
    if mode != "0":
        padded = torch.zeros(n_grid, n_grid, dtype=torch.double)
        padded[:, 1:-1] = f_free
        return padded
    return f_free


def interp_free(f_pad, n_grid, mode="0"):
    """Resample one full-grid component onto an (n_grid, n_grid) grid.

    Args:
        f_pad: 2-D tensor holding one component on its old full grid.
        n_grid: size of the target grid along each axis.
        mode: "0" returns every resampled column; any other value drops the
            first and last theta columns and returns only the interior (the
            free-parameter layout of f3, whose boundary values are zero).

    Returns:
        Tensor of shape (n_grid, n_grid) for mode "0", else (n_grid, n_grid-2).

    Reads the module-level ``interp_mode2`` and ``align_corners`` settings.
    """
    # F.interpolate expects a (batch, channel, H, W) input.
    resampled = F.interpolate(
        f_pad.unsqueeze(0).unsqueeze(0),
        size=(n_grid, n_grid),
        mode=interp_mode2,
        align_corners=align_corners,
    )[0, 0]
    if mode == "0":
        return resampled
    return resampled[:, 1:-1]

def interp_free_test(f_pad, n_grid, mode="0"):
    """Sample one full-grid component at the cell midpoints of its grid.

    ``f_pad`` is expected to be (n_grid, n_grid). It is upsampled to
    (2*n_grid - 1, 2*n_grid - 1); with align_corners=True the odd indices of
    that fine grid lie halfway between neighbouring original nodes. Mode "0"
    keeps every midpoint, giving (n_grid-1, n_grid-1). Any other mode also
    drops the first and last midpoint columns in theta, giving
    (n_grid-1, n_grid-3), which is the f3 free layout on the midpoint grid.

    Reads the module-level ``interp_mode`` and ``align_corners`` settings.
    """
    fine_n = 2 * n_grid - 1
    fine = F.interpolate(
        f_pad[None, None],
        size=(fine_n, fine_n),
        mode=interp_mode,
        align_corners=align_corners,
    )[0, 0]
    midpoint_rows = fine[1:-1:2]
    if mode == "0":
        return midpoint_rows[:, 1:-1:2]
    return midpoint_rows[:, 3:-3:2]

def decompose_free(f_free, n_grid):
    """Split the packed 2-D parameter matrix into its five component blocks.

    Column layout of ``f_free`` (rows index r):
        f1: n_grid | f2: n_grid | f3: n_grid - 2 | f4: n_grid | f5: n_grid
    for 5*n_grid - 2 columns in total. f3 is narrower because its theta
    boundary values are fixed to zero.

    Returns:
        Tuple (f1_free, f2_free, f3_free, f4_free, f5_free) of column slices.
    """
    widths = (n_grid, n_grid, n_grid - 2, n_grid, n_grid)
    parts = []
    start = 0
    for width in widths:
        parts.append(f_free[:, start:start + width])
        start += width
    return tuple(parts)

def compose_free(f1_free, f2_free, f3_free, f4_free, f5_free):
    """Pack five component blocks side by side along columns (dim 1).

    This produces the column layout that decompose_free splits apart.
    """
    blocks = (f1_free, f2_free, f3_free, f4_free, f5_free)
    return torch.cat(blocks, 1)

def interp_f_free(f_free, n_grid, n_grid_old):
    """Resample packed parameters from an n_grid_old grid to an n_grid grid.

    Args:
        f_free: packed parameters for the old grid, either 2-D with shape
            (n_grid_old, 5*n_grid_old - 2) or flattened (as returned by
            train). Flattened input is reshaped first.
        n_grid: target grid size.
        n_grid_old: grid size that ``f_free`` was built on.

    Returns:
        2-D tensor of shape (n_grid, 5*n_grid - 2) in the same packed layout.
    """
    # Accept the flat vector returned by train as well as the 2-D layout
    # (reshape is a no-op for an already 2-D input), matching
    # interp_f_free_test.
    f_free = f_free.reshape(n_grid_old, 5 * n_grid_old - 2)
    # f3 (third component) stores only interior theta columns ("1");
    # the others cover the full grid ("0").
    modes = ("0", "0", "1", "0", "0")
    resampled = [
        interp_free(free2pad_(part, n_grid_old, mode=mode), n_grid, mode=mode)
        for part, mode in zip(decompose_free(f_free, n_grid_old), modes)
    ]
    return compose_free(*resampled)


def interp_f_free_test(f_free, n_grid):
    """Sample packed parameters at the cell midpoints of the n_grid grid.

    ``f_free`` (flat or 2-D) is reshaped to (n_grid, 5*n_grid - 2). Each
    component is expanded to the full grid and then sampled at the midpoints
    with interp_free_test. The result is a 2-D tensor in the packed layout
    for an (n_grid - 1)-point grid, i.e. (n_grid-1, 5*(n_grid-1) - 2).
    """
    packed = f_free.reshape(n_grid, 5 * (n_grid) - 2)
    components = decompose_free(packed, n_grid)
    pad_modes = ["0", "0", "1", "0", "0"]  # f3 keeps only interior columns
    sampled = []
    for comp, pad_mode in zip(components, pad_modes):
        full = free2pad_(comp, n_grid, mode=pad_mode)
        sampled.append(interp_free_test(full, n_grid, mode=pad_mode))
    return compose_free(*sampled)


def free2pad(f_free, n_grid):
    """Expand packed parameters so that every component covers the full grid.

    ``f_free`` is the 2-D packed matrix (n_grid, 5*n_grid - 2). f3's two zero
    boundary columns are re-inserted, so the result has shape
    (n_grid, 5*n_grid).
    """
    parts = decompose_free(f_free, n_grid)
    modes = ("0", "0", "1", "0", "0")  # only f3 needs its boundary padding
    return compose_free(*(free2pad_(p, n_grid, mode=m) for p, m in zip(parts, modes)))



def train(a, n_grid, f_free, maxiter=10000):
    """Fit the coordinate-transform fields f1..f5 on an n_grid x n_grid grid.

    The grid spans r in [3, 4] and theta in [0, pi].

    For the current parameters:
    - The metric g (Kerr, Boyer-Lindquist form, spin ``a``, M = 1) is
      transformed as g' = W^{-T} g W^{-1}.
    - W is built from finite-difference derivatives of f1..f4 (theta
      derivatives scaled by 1/r) and from the rate mean(f5).
    - The spatial 3x3 block of g' is compared with the flat spherical metric
      diag(1, r'^2, r'^2 sin^2 theta'), evaluated at r' = r + f2 and
      theta' = theta + f3.

    The mean squared mismatch (first and last r rows excluded) is computed on
    the grid nodes and on the cell midpoints. The two are weighted by
    n_grid**2 and (n_grid-1)**2, and the result is minimized with L-BFGS-B.

    Args:
        a: spin parameter of the metric g.
        n_grid: number of grid points along r and along theta.
        f_free: packed parameters, flat or shaped (n_grid, 5*n_grid - 2).
        maxiter: maximum number of L-BFGS-B iterations.

    Returns:
        Tuple ``(free, loss, duration)``:
        - free: the optimized flat float64 parameter tensor (requires grad);
        - loss: the final loss value;
        - duration: the wall-clock optimization time in seconds.
    """
    theta_start = 0
    theta_end = np.pi
    r_start = 3
    r_end = 4
    M = 1

    thetas = torch.linspace(theta_start, theta_end, steps=n_grid, dtype=torch.double)
    rs = torch.linspace(r_start, r_end, steps=n_grid, dtype=torch.double)
    theta_h = (theta_end - theta_start) / (n_grid - 1)
    r_h = (r_end - r_start) / (n_grid - 1)

    # indexing="ij" is the historical default of torch.meshgrid; passing it
    # explicitly keeps r along dim 0 and silences the deprecation warning.
    RS, THETAS = torch.meshgrid(rs, thetas, indexing="ij")
    z = torch.transpose(torch.stack([RS.reshape(-1,), THETAS.reshape(-1,)]), 0, 1)

    # "Testing" grid: the cell midpoints in r and theta.
    rs_test = (rs[1:] + rs[:-1]) / 2
    thetas_test = (thetas[1:] + thetas[:-1]) / 2
    RS_test, THETAS_test = torch.meshgrid(rs_test, thetas_test, indexing="ij")
    z_test = torch.transpose(torch.stack([RS_test.reshape(-1,), THETAS_test.reshape(-1,)]), 0, 1)

    def r_derivative(f_, n_grid):
        # Central difference along r; linear extrapolation of one ghost row
        # on each side makes the boundary values one-sided differences.
        f_aug = torch.zeros(n_grid+2, n_grid, dtype=torch.double)
        f_aug[1:-1] = f_
        f_aug[0] = 2*f_[0] - f_[1]
        f_aug[-1] = 2*f_[-1] - f_[-2]
        f_r = (f_aug[2:] - f_aug[:-2])/(2*r_h)
        return f_r

    def theta_derivative(f_, n_grid, r):
        # Same scheme along theta, then divided by r (the 1/r metric factor).
        f_aug = torch.zeros(n_grid, n_grid+2, dtype=torch.double)
        f_aug[:, 1:-1] = f_
        f_aug[:, 0] = 2*f_[:, 0] - f_[:, 1]
        f_aug[:, -1] = 2*f_[:, -1] - f_[:, -2]
        f_theta = (f_aug[:, 2:] - f_aug[:, :-2])/(2*theta_h)
        return f_theta/r.unsqueeze(dim=1)

    def w(f1_, f2_, f3_, f4_, f5_, n_grid, r):
        # Per-point 4x4 matrix; row i holds the derivatives of coordinate
        # x'^i with respect to (t, r, theta, phi).
        f1_r = r_derivative(f1_, n_grid).reshape(-1,)
        f2_r = r_derivative(f2_, n_grid).reshape(-1,)
        f3_r = r_derivative(f3_, n_grid).reshape(-1,)
        f4_r = r_derivative(f4_, n_grid).reshape(-1,)
        f1_theta = theta_derivative(f1_, n_grid, r).reshape(-1,)
        f2_theta = theta_derivative(f2_, n_grid, r).reshape(-1,)
        f3_theta = theta_derivative(f3_, n_grid, r).reshape(-1,)
        f4_theta = theta_derivative(f4_, n_grid, r).reshape(-1,)
        ones = torch.ones(f1_r.shape[0], dtype=torch.double)
        # Rate of phi' in t (local; distinct from the spin `a` of train).
        a = torch.mean(f5_)

        stack1 = torch.stack([ones, f1_r, f1_theta, 0*ones])
        stack2 = torch.stack([0*ones, 1+f2_r, f2_theta, 0*ones])
        stack3 = torch.stack([0*ones, f3_r, 1+f3_theta, 0*ones])
        stack4 = torch.stack([a*ones, f4_r, f4_theta, ones])
        w_ = torch.stack([stack1, stack2, stack3, stack4])
        w_ = w_.permute(2, 0, 1)
        return w_

    def w_inv_invt(w):
        w_inv = torch.linalg.inv(w)
        w_invt = w_inv.permute(0, 2, 1)
        return w_inv, w_invt

    def gp(g, w):
        # Transformed metric W^{-T} g W^{-1}, batched over grid points.
        w_inv, w_invt = w_inv_invt(w)
        gp_ = torch.matmul(torch.matmul(w_invt, g), w_inv)
        return gp_

    def zp(z, f2, f3):
        # Shifted coordinates (r + f2, theta + f3) per grid point.
        f2 = f2.reshape(-1,)
        f3 = f3.reshape(-1,)
        rp = z[:, 0] + f2
        thetap = z[:, 1] + f3
        zp_ = torch.transpose(torch.stack([rp, thetap]), 0, 1)
        return zp_

    def g(x_, a=0.0):
        # Kerr metric in Boyer-Lindquist form at points x_ = (r, theta).
        r = x_[:, 0]
        theta = x_[:, 1]
        bs = x_.shape[0]
        # torch.cos (not np.cos) keeps this a pure-torch op, consistent with
        # torch.sin below and safe under autograd / non-CPU tensors.
        Sigma = r**2 + a**2*torch.cos(theta)**2
        Delta = r**2 - 2*M*r + a**2
        one = torch.ones(bs, dtype=torch.double)
        g01 = g02 = g10 = g12 = g13 = g20 = g21 = g23 = g31 = g32 = 0*one
        g00 = -(1-2*M*r/Sigma)
        g03 = g30 = -2*M*a*r*torch.sin(theta)**2/Sigma
        g11 = Sigma/Delta
        g22 = Sigma
        g33 = (r**2+a**2+2*M*a**2*r*torch.sin(theta)**2/Sigma)*torch.sin(theta)**2
        stack1 = torch.stack([g00, g01, g02, g03])
        stack2 = torch.stack([g10, g11, g12, g13])
        stack3 = torch.stack([g20, g21, g22, g23])
        stack4 = torch.stack([g30, g31, g32, g33])
        gs = torch.stack([stack1, stack2, stack3, stack4]).permute(2, 0, 1)
        return gs

    def gp_space_target(zp):
        # Flat spherical spatial metric diag(1, r'^2, r'^2 sin^2 theta').
        bs = zp.shape[0]
        one = torch.ones(bs, dtype=torch.double)
        g11 = one
        g12 = g13 = g21 = g23 = g31 = g32 = 0*one
        g22 = zp[:, 0]**2
        g33 = zp[:, 0]**2*torch.sin(zp[:, 1])**2
        stack1 = torch.stack([g11, g12, g13])
        stack2 = torch.stack([g21, g22, g23])
        stack3 = torch.stack([g31, g32, g33])
        gs = torch.stack([stack1, stack2, stack3]).permute(2, 0, 1)
        return gs

    def error(f_free):
        # Mean squared spatial-metric mismatch on the grid nodes, excluding
        # the first and last r rows.
        f_free = f_free.reshape(n_grid, 5*n_grid-2)
        f1_free, f2_free, f3_free, f4_free, f5_free = decompose_free(f_free, n_grid)
        f1_ = f1(f1_free, n_grid)
        f2_ = f2(f2_free, n_grid)
        f3_ = f3(f3_free, n_grid)
        f4_ = f4(f4_free, n_grid)
        f5_ = f5(f5_free, n_grid)
        g_ = g(z, a=a)
        w_ = w(f1_, f2_, f3_, f4_, f5_, n_grid, rs)
        zp_ = zp(z, f2_, f3_)
        gp_space = gp(g_, w_)[:, 1:, 1:].reshape(n_grid, n_grid, 3, 3)

        gp_space_target_ = gp_space_target(zp_).reshape(n_grid, n_grid, 3, 3)
        error_ = torch.mean((gp_space-gp_space_target_)[1:-1, :]**2)
        return error_

    def error_test(f_free):
        # Same mismatch on the (n_grid-1)-point midpoint grid.
        f_free = f_free.reshape(n_grid-1, 5*(n_grid-1)-2)
        f1_free, f2_free, f3_free, f4_free, f5_free = decompose_free(f_free, n_grid-1)
        f1_ = f1(f1_free, n_grid-1)
        f2_ = f2(f2_free, n_grid-1)
        f3_ = f3(f3_free, n_grid-1)
        f4_ = f4(f4_free, n_grid-1)
        f5_ = f5(f5_free, n_grid-1)
        g_ = g(z_test, a=a)
        w_ = w(f1_, f2_, f3_, f4_, f5_, n_grid-1, rs_test)
        zp_test = zp(z_test, f2_, f3_)
        gp_space = gp(g_, w_)[:, 1:, 1:].reshape(n_grid-1, n_grid-1, 3, 3)

        gp_space_target_ = gp_space_target(zp_test).reshape(n_grid-1, n_grid-1, 3, 3)
        error_ = torch.mean((gp_space-gp_space_target_)[1:-1, :]**2)
        return error_

    def error_all(f_free):
        # Node and midpoint errors weighted by their number of grid points.
        f_free_test = interp_f_free_test(f_free, n_grid).reshape(-1,)
        return (error(f_free)*n_grid**2 + error_test(f_free_test)*(n_grid-1)**2)/(n_grid**2+(n_grid-1)**2)

    def loss_and_grad(f_free_np):
        # One forward/backward pass yields both the loss and its gradient, so
        # L-BFGS-B (jac=True) does not evaluate the model twice per step.
        f_free_t = torch.tensor(f_free_np, dtype=torch.double, requires_grad=True)
        loss_t = error_all(f_free_t)
        loss_t.backward()
        return loss_t.item(), f_free_t.grad.detach().numpy()

    # minimize requires a 1-D x0; flatten in case f_free is 2-D (e.g. the
    # output of interp_f_free).
    f_free_np = f_free.detach().numpy().reshape(-1)
    start = time.time()

    sol = scipy.optimize.minimize(
        loss_and_grad,
        f_free_np,
        method='L-BFGS-B',
        jac=True,
        tol=1e-32,
        options={'gtol': 1e-32, 'maxiter': maxiter},
    )
    end = time.time()
    duration = end - start
    print("time={}".format(duration))
    loss = sol.fun
    free = torch.tensor(sol.x, dtype=torch.double, requires_grad=True)
    return free.clone(), loss, duration