In [None]:
# Mount Google Drive (Colab only) so the notebook can read/write files under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change into the project folder on Google Drive. This is the folder this cell's recorded output
# shows and the one the PD-Adam cell below switches to; the old path (NPDG/OT1D) disagreed with both.
import os
os.chdir('/content/drive/MyDrive/NPDG_codes_collection/OT1D')
print(os.getcwd())

/content/drive/MyDrive/NPDG_codes_collection/OT1D


In [None]:
# @title import
import math
import random
import scipy

import matplotlib
matplotlib.use('agg')

from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

import numpy as np

from torch.func import grad, hessian, vmap
from torch.func import jacrev
from torch.func import functional_call
import time


import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pickle

from numpy import *
from torch import Tensor
from torch.nn import Parameter
from torch.optim.lr_scheduler import ExponentialLR

import os
import argparse





In [None]:
# @title set dimension of the problem


# Spatial dimension of the OT problem. Used as the default `d` of the samplers/densities and as
# the input/output width of the networks built by the solvers.
dim = 1



In [None]:
# @title Check CUDA availability

# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# @title model (net)


class network_prim(nn.Module):
    """Fully connected softplus MLP used for the primal network (T or u).

    Layout (for network_length >= 2): a biased ``input -> hidden`` layer,
    ``network_length - 2`` biased ``hidden -> hidden`` layers, and a bias-free
    ``hidden -> output`` layer. Softplus is applied after every layer except
    the last one.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super(network_prim, self).__init__()

        self.network_length = network_length
        # Layers are constructed in order (first, hidden..., last) so their default
        # random initialisation consumes the RNG in the same order as before.
        first = [nn.Linear(input_dimension, hidden_dimension)]
        hidden = [nn.Linear(hidden_dimension, hidden_dimension)
                  for _ in range(network_length - 2)]
        last = [nn.Linear(hidden_dimension, output_dimension, bias=False)]
        self.linears = nn.ModuleList(first + hidden + last)
        self.softplus = nn.Softplus()

    def initialization(self):
        """Re-draw every weight and bias in place from a standard normal N(0, 1)."""
        for layer in self.linears:
            for tensor in (layer.weight, layer.bias):
                if tensor is not None:
                    tensor.data.normal_()

    def forward(self, x):
        """Evaluate the MLP on ``x`` (last dimension must equal ``input_dimension``)."""
        last_index = self.network_length - 1
        h = self.linears[0](x)
        for layer in self.linears[1:last_index]:
            h = layer(self.softplus(h))
        return self.linears[last_index](self.softplus(h))


class network_dual(nn.Module):
    """Fully connected softplus MLP used for the dual potential phi.

    Same layout as the primal MLP: a biased input layer, ``network_length - 2``
    biased hidden layers, and a bias-free output layer, with softplus between
    consecutive layers (none after the output layer).
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super(network_dual, self).__init__()

        self.network_length = network_length
        n_hidden = network_length - 2
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(n_hidden)]
        layers.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(layers)
        self.softplus = nn.Softplus()

    def initialization(self):
        """Overwrite all weights and biases in place with i.i.d. N(0, 1) draws."""
        params = []
        for layer in self.linears:
            params.append(layer.weight)
            if layer.bias is not None:
                params.append(layer.bias)
        for p in params:
            p.data.normal_()

    def forward(self, x):
        """Evaluate phi(x); the last dimension of ``x`` must equal ``input_dimension``."""
        out_index = self.network_length - 1
        z = self.linears[0](x)
        hidden_layers = self.linears[1:out_index]
        for layer in hidden_layers:
            z = self.softplus(z)
            z = layer(z)
        z = self.softplus(z)
        return self.linears[out_index](z)



In [None]:
# @title 1D plot of OT map on [-L, L]

def One_D_plot_T_with_real_solution_compare(net_primal, net_primal2, l, num_of_intervals, Iter, flag_plot_real, save_path, device, d=dim):
    """Plot two learned 1-D maps (and optionally the exact OT map) on [-l, l]; save as PDF.

    Args:
        net_primal: network drawn in blue, labelled 'NPDG'.
        net_primal2: network drawn in green, labelled 'Primal-Dual using Adam'.
        l: half-width of the plotting interval [-l, l].
        num_of_intervals: number of sub-intervals; the grid has num_of_intervals + 1 nodes.
        Iter: iteration number, used only in the output file name.
        flag_plot_real: when equal to 1, also draw the reference map computed by `OT_map` (red).
        save_path: directory that receives the PDF.
        device: device the networks live on.
        d: unused; kept for interface compatibility.
    """
    # linspace yields exactly num_of_intervals + 1 nodes including both endpoints; arange with a
    # float step can emit an extra node beyond l because of floating-point rounding.
    x = torch.linspace(-l, l, num_of_intervals + 1).unsqueeze(-1)

    # Plotting only: no autograd graph is needed.
    with torch.no_grad():
        T_x = net_primal(x.to(device)).cpu()
        T2_x = net_primal2(x.to(device)).cpu()

    fig = plt.figure(figsize=(20, 20))
    plt.xlim([-l, l])
    plt.ylim([-1.2 * l, 1.2 * l])
    plt.xlabel('x', fontsize = 20)
    plt.ylabel('y', fontsize = 20)

    plt.plot(x, T_x, color='blue', label='NPDG')
    plt.plot(x, T2_x, color='green', label='Primal-Dual using Adam')

    if flag_plot_real == 1:
        # OT_map handles one point at a time (scalar bisection).
        OT_map_x = torch.stack([OT_map(xi) for xi in x]).squeeze().cpu()
        plt.plot(x, OT_map_x, color='r', label='real map')
    plt.legend(fontsize=30)
    plt.title('Plot of computed maps  \n', fontsize = 40)

    filename = os.path.join(save_path, "({}th Iteration) graph of T_theta comparison ".format(Iter)+'.pdf')
    fig.savefig(filename)

    plt.close(fig)


def One_D_plot_T_with_real_solution(net_primal, l, num_of_intervals, Iter, flag_plot_real, save_path, device, d=dim):
    """Plot a learned 1-D map (blue) and optionally the exact OT map (red) on [-l, l]; save as PDF.

    Args:
        net_primal: network to plot.
        l: half-width of the plotting interval [-l, l].
        num_of_intervals: number of sub-intervals; the grid has num_of_intervals + 1 nodes.
        Iter: iteration number, used only in the output file name.
        flag_plot_real: when equal to 1, also draw the reference map computed by `OT_map`.
        save_path: directory that receives the PDF.
        device: device the network lives on.
        d: unused; kept for interface compatibility.
    """
    # linspace yields exactly num_of_intervals + 1 nodes; arange with a float step may add one more.
    x = torch.linspace(-l, l, num_of_intervals + 1).unsqueeze(-1)

    # Plotting only: no autograd graph is needed.
    with torch.no_grad():
        T_x = net_primal(x.to(device)).cpu()

    fig = plt.figure(figsize=(20, 20))
    plt.xlim([-l, l])
    plt.ylim([-1.2 * l, 1.2 * l])
    plt.plot(x, T_x, color='blue')

    if flag_plot_real == 1:
        # OT_map handles one point at a time (scalar bisection).
        OT_map_x = torch.stack([OT_map(xi) for xi in x]).squeeze().cpu()
        plt.plot(x, OT_map_x, color='r')

    filename = os.path.join(save_path, "({}th Iteration) graph of T_theta ".format(Iter)+'.pdf')
    fig.savefig(filename)

    plt.close(fig)



In [None]:
# @title define rho0 (sampler) , rho1 (sampler) , real solution

# Parameters of the target Gaussian mixture rho1. Component b is currently unused
# (it is commented out of the CDF formulas below).
sigma0_a = 0.5
sigma0_b = 0.25
sigma0_c = 0.5
mu0_a = -1
mu0_b = 1
mu0_c = 1
lambda_a = 2/3
lambda_c = 1/3


def _sampler_device(device=None):
    """Return `device` when given, else CUDA if PyTorch sees a GPU, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def rho1(n, d=dim, device=None):
    """Draw 3n samples from rho1 = lambda_a N(mu0_a, sigma0_a^2 I) + lambda_c N(mu0_c, sigma0_c^2 I).

    The first 2n rows come from component a and the last n rows from component c, i.e. the
    proportions 2/3 and 1/3 given by lambda_a and lambda_c.

    Args:
        n: one third of the number of samples returned.
        d: dimension of each sample.
        device: target device; defaults to CUDA when available, otherwise CPU (the old hard-coded
            `.cuda()` crashed on CPU-only machines).

    Returns:
        (3n, d) tensor.
    """
    x_a = sigma0_a * torch.randn(2 * n, d) + mu0_a
    x_c = sigma0_c * torch.randn(n, d) + mu0_c
    return torch.cat((x_a, x_c), 0).to(_sampler_device(device))


def rho0(n, d=dim, device=None):
    """Draw 3n samples from the standard Gaussian rho0 = N(0, I_d).

    3n rows is the same count that rho1(n) returns. `device` defaults to CUDA when available,
    otherwise CPU.
    """
    return torch.randn(3 * n, d).to(_sampler_device(device))


# # Optimal Transport map
# # T(\cdot) = F_1^{-1}(F_0(\cdot))
# # F_1^{-1}(p) = \sqrt{2} erf^{-1}(2p - 1)   (standard normal quantile)
# # F_0(x) = \sum_{k=1}^m λ_k/2 (1 + erf((x - μ_k)/(\sqrt{2} σ_k)))   (Gaussian-mixture CDF)
# def optimal_transport_map(x, d=dim):
# # def grad_u_real(x, d=dim):

#     erfa = torch.special.erf((x - mu0_a)/(np.sqrt(2) * sigma0_a))
#     # erfb = torch.special.erf((x - mu0_b)/(np.sqrt(2) * sigma0_b))
#     erfc = torch.special.erf((x - mu0_c)/(np.sqrt(2) * sigma0_c))

#     # F_0_x = lambda_a / 2 * (1 + erfa) + lambda_b / 2 * (1 + erfb) + lambda_c / 2 * (1 + erfc)
#     F_0_x = lambda_a / 2 * (1 + erfa) + lambda_c / 2 * (1 + erfc)

#     y = 2 * F_0_x - 1
#     z = torch.special.erfinv(y)
#     return z.cuda()


# Inverse of the optimal transport map rho0 [standard Gaussian] -> rho1 [Gaussian mixture]:
# T^{-1}(x) = \sqrt{2} erf^{-1}(2 F_mix(x) - 1), which pushes rho1 forward to rho0.
def optimal_transport_map_inv(x, d=dim):
    """Inverse OT map: sends rho1 (Gaussian mixture) points to rho0 (standard Gaussian).

    Computes T^{-1}(x) = sqrt(2) * erfinv(2 F_mix(x) - 1), where
    F_mix(x) = lambda_a/2 (1 + erf((x - mu0_a)/(sqrt(2) sigma0_a)))
             + lambda_c/2 (1 + erf((x - mu0_c)/(sqrt(2) sigma0_c)))
    is the mixture CDF. Applied elementwise.

    Args:
        x: tensor of points (any shape).
        d: unused; kept for interface compatibility.

    Returns:
        Tensor with the shape of `x`, on `x`'s device. The old version always moved the result
        to CUDA, which failed on CPU-only machines and forced a GPU round-trip for CPU callers
        such as `OT_map`. Where 2 F_mix(x) - 1 rounds to +-1 in floating point, erfinv gives +-inf.
    """
    erfa = torch.special.erf((x - mu0_a)/(np.sqrt(2) * sigma0_a))
    erfc = torch.special.erf((x - mu0_c)/(np.sqrt(2) * sigma0_c))
    F_0_x = lambda_a / 2 * (1 + erfa) + lambda_c / 2 * (1 + erfc)
    y = 2 * F_0_x - 1
    return np.sqrt(2) * torch.special.erfinv(y)


def OT_map(x, d=dim):
    """Exact OT map T(x) (rho0 -> rho1) at a single point, found by bisection on the inverse map.

    optimal_transport_map_inv is increasing, so T(x) is the root y of T^{-1}(y) = x. The root is
    assumed (not checked) to lie in [x - 4, x + 4]. The bracket is halved at most 100 times,
    stopping as soon as |T^{-1}(y) - x| < 1e-5.

    Args:
        x: one-element tensor; the `>` comparison below is used as a Python bool.
        d: unused.

    Returns:
        The latest midpoint estimate of T(x), with the same shape as `x`.
    """
    y_mid = x
    x_mid = optimal_transport_map_inv(y_mid).cpu()
    lower, upper = x - 4, x + 4
    for _ in range(100):
        # Keep the half of the bracket that still contains the root.
        if x_mid > x:
            upper = y_mid
        else:
            lower = y_mid
        y_mid = (lower + upper) / 2
        x_mid = optimal_transport_map_inv(y_mid).cpu()
        if abs(x_mid - x) < 1e-5:
            break
    return y_mid



In [None]:
# @title histogram
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
rng = np.random.default_rng(800)  # NOTE(review): not used by any visible code in this notebook


# Mixture parameters used by rho1_density below. They repeat the values set in the sampler cell.
# NOTE(review): keep both copies in sync (or drop this one); a mismatch would make the
# plotted density curves disagree with the samples drawn by rho1.
sigma0_a = 0.5
sigma0_c = 0.5
mu0_a = -1
mu0_c = 1
lambda_a = 2/3
lambda_c = 1/3
def rho0_density(x, d=dim):
    """Standard Gaussian density N(0, I_d), evaluated row-wise (coordinates on the last axis)."""
    sq_norm = (x * x).sum(-1)
    normalizer = pow(1/(np.sqrt(2*math.pi)), d)
    return normalizer * torch.exp(-sq_norm / 2)

def rho1_density(x, d=dim):
    """Density of lambda_a N(mu0_a, sigma0_a^2 I) + lambda_c N(mu0_c, sigma0_c^2 I), row-wise on x."""
    def _weighted_gaussian(weight, mu, sigma):
        diff = x - mu
        scale = pow(1/(np.sqrt(2*math.pi) * sigma), d)
        return weight * scale * torch.exp( - torch.sum(diff*diff, -1)/(2 * sigma * sigma))

    return _weighted_gaussian(lambda_a, mu0_a, sigma0_a) + _weighted_gaussian(lambda_c, mu0_c, sigma0_c)


# pushfwd_rho0_samples_vs_rho1_samples
def plt_histogram(L, N, n_bins, save_path=None):
    """Sanity check of the inverse OT map: T^{-1}(rho1 samples) should be distributed like rho0.

    Saves 'hist of rho1 and OT inv pushfwded rho0.pdf' showing
      - green: histogram of fresh rho1 samples,
      - blue:  histogram of optimal_transport_map_inv applied to rho1 samples,
      - red:   the exact rho0 density on an (N+1)-node grid over [-L, L].

    Args:
        L: half-width of the plotted interval.
        N: grid resolution and sampler size parameter (rho1(N) returns 3N points).
        n_bins: number of histogram bins.
        save_path: output directory; defaults to the current working directory.
    """
    delta_x = 2 * L / N
    x_nodes = ((torch.arange(0, N+1, 1) - N/2) * delta_x).unsqueeze(-1)
    density_func_x = rho0_density(x_nodes).squeeze()

    # Pull rho1 samples back with the exact inverse map, and draw a separate rho1 batch for reference.
    x = rho1(N)
    OT_inv_x = optimal_transport_map_inv(x).detach().cpu().squeeze()
    rho1_x = rho1(N).cpu().squeeze()

    fig = plt.figure(figsize=(20, 20))
    plt.hist(rho1_x, bins=n_bins, density=True, color='green', alpha=0.7)
    plt.hist(OT_inv_x, bins=n_bins, density=True, color='blue', alpha=1)
    plt.plot(x_nodes, density_func_x, color='red', alpha=0.5)
    plt.xlim((-L, L))
    plt.ylim((0, 1.5))
    plt.title("Histogram of OT inverse pushforward rho_1 (blue) and rho_1 (green).")
    if save_path is None:
        save_path = os.getcwd()
    filename = os.path.join(save_path, "hist of rho1 and OT inv pushfwded rho0.pdf")
    fig.savefig(filename)
    plt.close(fig)


# pushfwd_rho0_samples_vs_rho1_samples
def plt_histogram2(L, N, n_bins, save_path=None):
    """Sanity check of the exact OT map: T(rho0 samples) should be distributed like rho1.

    Saves 'hist of rho1 and OT pushfwded rho0.pdf' showing
      - green: histogram of fresh rho1 samples (3N points),
      - blue:  histogram of OT_map applied to the first N of 3N rho0 samples
               (OT_map is a per-point bisection, so this is the slow part),
      - red:   the exact rho1 density on an (N+1)-node grid over [-L, L].

    Args:
        L: half-width of the plotted interval.
        N: grid resolution and number of points pushed through OT_map.
        n_bins: number of histogram bins.
        save_path: output directory; defaults to the current working directory.
    """
    delta_x = 2 * L / N
    x_nodes = ((torch.arange(0, N+1, 1) - N/2) * delta_x).unsqueeze(-1)
    density_func_x = rho1_density(x_nodes).squeeze()

    x = rho0(N).cpu()
    y = torch.stack([OT_map(x[i]) for i in range(N)]).reshape(-1)
    y = y.detach().cpu().squeeze()
    z = rho1(N).cpu().squeeze()

    fig = plt.figure(figsize=(20, 20))
    plt.hist(z, bins=n_bins, density=True, color='green', alpha=0.7)
    plt.hist(y, bins=n_bins, density=True, color='blue', alpha=1)
    plt.plot(x_nodes, density_func_x, color='red', alpha=0.5)

    plt.xlim((-L, L))
    plt.ylim((0, 1.5))
    plt.title("Histogram of OT pushforward rho_0 (blue) and rho_1 (green).")
    if save_path is None:
        save_path = os.getcwd()
    filename = os.path.join(save_path, "hist of rho1 and OT pushfwded rho0.pdf")
    fig.savefig(filename)
    plt.close(fig)


def plt_histogram_pushfwd_rho0_rho1_densitycurve(iter, L, N, net_T, n_bins, save_path=None):
    """Compare the learned pushforward net_T#rho0 with rho1.

    Saves '(iteration = <iter>) hist of rho1 and numerical solution.pdf' showing
      - green: histogram of fresh rho1 samples,
      - blue:  histogram of net_T applied to rho0 samples,
      - red:   the exact rho1 density on an (N+1)-node grid over [-L, L].

    Args:
        iter: iteration number, used only in the file name.
        L: half-width of the plotted interval.
        N: grid resolution and sampler size parameter (samplers return 3N points).
        net_T: learned transport map.
        n_bins: number of histogram bins.
        save_path: output directory; defaults to the current working directory.
    """
    delta_x = 2 * L / N
    x_nodes = ((torch.arange(0, N+1, 1) - N/2) * delta_x).unsqueeze(-1)
    density_func_x = rho1_density(x_nodes).squeeze()

    # Plotting only: evaluate the network without building an autograd graph.
    x = rho0(N)
    with torch.no_grad():
        T_x = net_T(x).cpu().squeeze()
    rho1_x = rho1(N).cpu().squeeze()

    fig = plt.figure(figsize=(20, 20))
    plt.hist(rho1_x, bins=n_bins, density=True, color='green', alpha=0.7)
    plt.hist(T_x, bins=n_bins, density=True, color='blue', alpha=1)
    plt.plot(x_nodes, density_func_x, color='red', alpha=0.5)
    plt.xlim((-L, L))
    plt.ylim((0, 1.5))
    plt.title("Histogram of net_T pushforward rho_0 (blue) and rho_1 (green).")
    if save_path is None:
        save_path = os.getcwd()
    filename = os.path.join(save_path, "(iteration = {}) hist of rho1 and numerical solution.pdf".format(iter))
    fig.savefig(filename)
    plt.close(fig)



In [None]:
# @title L2 error


def gradient_nn(network, x):
    """Return d(network)/dx at x, keeping the graph so the result can itself be differentiated.

    The network output is contracted with ones (grad_outputs) before differentiating. When each
    output row depends only on its own input row, this gives the per-sample gradient.

    Args:
        network: callable mapping x to a tensor.
        x: input points; differentiated as a fresh leaf, detached from any existing graph.

    Returns:
        Tensor with x's shape, built with create_graph=True (e.g. so a loss on it can be
        backpropagated to the parameters of `network`).
    """
    # Equivalent to the legacy autograd.Variable(x, requires_grad=True): a detached leaf sharing x's storage.
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    # ones_like matches the output's device/dtype; the previous hard-coded .cuda() failed on CPU.
    return torch.autograd.grad(outputs=output_value, inputs=input_variable,
                               grad_outputs=torch.ones_like(output_value),
                               create_graph=True, retain_graph=True)[0]


def L2_error(net_u, N):  # L2 norm of \nabla net_u - \nabla u_real
    """Monte-Carlo L2(rho0) error between grad(net_u) and the exact OT map.

    The exact map is not evaluated directly. Instead y ~ rho1 (3N points) is drawn and pulled back
    with x = optimal_transport_map_inv(y), so x ~ rho0 and the exact map satisfies T(x) = y.
    The previous version called `grad_u_real`, which appears only in commented-out code in this
    notebook and therefore raised NameError.

    Returns:
        sqrt(mean(|grad net_u(x) - y|^2)) as a 0-dim tensor.
    """
    y = rho1(N)
    x = optimal_transport_map_inv(y)
    grad_u_x = gradient_nn(net_u, x)
    diff = grad_u_x - y
    return torch.sqrt((diff * diff).mean())


def L2_error_T(net_T, N):  # L2 norm of net_T - T_real
    """Monte-Carlo L2(rho0) error between net_T and the exact OT map.

    The previous version called `optimal_transport_map`, which appears only in commented-out code
    in this notebook and therefore raised NameError. Exact pairs are now built through the inverse
    map: y ~ rho1 (3N points) and x = optimal_transport_map_inv(y), so x ~ rho0 and T(x) = y.

    Returns:
        sqrt(mean((net_T(x) - y)^2)) as a 0-dim tensor.
    """
    y = rho1(N)
    x = optimal_transport_map_inv(y)
    diff = net_T(x) - y
    return torch.sqrt((diff * diff).mean())


def L2_error_T_with_OT_map_inv(net_T, N):
    """Monte-Carlo L2(rho0) error of net_T against the exact OT map, built via the inverse map.

    Draws y ~ rho1 (3N points), maps them back with x = optimal_transport_map_inv(y) (so the exact
    map sends x to y), and returns sqrt(mean((net_T(x) - y)^2)) as a 0-dim tensor.
    """
    target = rho1(N)
    residual = net_T(optimal_transport_map_inv(target)) - target
    return residual.mul(residual).mean().sqrt()





In [None]:
# @title PDHG loss


def gradient_nn(network, x):
    """Return d(network)/dx at x, keeping the graph so the result can itself be differentiated.

    The network output is contracted with ones (grad_outputs) before differentiating. When each
    output row depends only on its own input row, this gives the per-sample gradient.

    Args:
        network: callable mapping x to a tensor.
        x: input points; differentiated as a fresh leaf, detached from any existing graph.

    Returns:
        Tensor with x's shape, built with create_graph=True so PDHG losses that use it can be
        backpropagated to the parameters of `network`.
    """
    # Equivalent to the legacy autograd.Variable(x, requires_grad=True): a detached leaf sharing x's storage.
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    # ones_like matches the output's device/dtype; the previous hard-coded .cuda() failed on CPU.
    return torch.autograd.grad(outputs=output_value, inputs=input_variable,
                               grad_outputs=torch.ones_like(output_value),
                               create_graph=True, retain_graph=True)[0]


# PDHG loss with no boundary error
def PDHG_loss1(net_u, net_phi, rho0_samples, rho1_samples):
    """Saddle loss mean(phi(grad u(x)) - phi(y)), x from rho0_samples, y from rho1_samples.

    The two phi outputs are subtracted elementwise before averaging, so both sample batches are
    expected to have the same size.
    """
    pushed = gradient_nn(net_u, rho0_samples)
    return (net_phi(pushed) - net_phi(rho1_samples)).mean()


# derived from the Monge problem
def PDHG_loss2(net_u, net_phi, rho0_samples, rho1_samples):
    """Potential-form saddle loss: -E[x . grad u(x)] + E[phi(grad u(x))] - E[phi(y)].

    Expectations are sample means over rho0_samples (x) and rho1_samples (y); the dot product is
    taken over the last axis. Returns a 0-dim tensor.
    """
    grad_u_x = gradient_nn(net_u, rho0_samples)
    coupling = (rho0_samples * grad_u_x).sum(-1).mean()
    return -coupling + net_phi(grad_u_x).mean() - net_phi(rho1_samples).mean()


def PDHG_loss3(net_T, net_phi, rho0_samples, rho1_samples):
    """Map-form saddle loss: -E[x . T(x)] + E[phi(T(x))] - E[phi(y)].

    Expectations are sample means over rho0_samples (x) and rho1_samples (y); the dot product is
    taken over the last axis. Returns a 0-dim tensor.
    """
    T_x = net_T(rho0_samples)
    coupling = (rho0_samples * T_x).sum(-1).mean()
    return -coupling + net_phi(T_x).mean() - net_phi(rho1_samples).mean()


def PDHG_loss4(net_T, net_phi, rho0_samples, rho1_samples):
    """Quadratic-cost saddle loss: E[|T(x) - x|^2 / 2] + E[phi(T(x))] - E[phi(y)].

    Expectations are sample means over rho0_samples (x) and rho1_samples (y); the squared norm is
    taken over the last axis. Returns a 0-dim tensor.
    """
    T_x = net_T(rho0_samples)
    displacement = T_x - rho0_samples
    transport_cost = 1/2 * (displacement * displacement).sum(-1).mean()
    return transport_cost + net_phi(T_x).mean() - net_phi(rho1_samples).mean()


def PDHG_loss4_with_extrapolation(net_T, net_phi, net_phi0, rho0_samples, rho1_samples, omega):
    """PDHG_loss4 evaluated at the extrapolated dual phi~ = phi + omega * (phi - phi0).

    Args:
        net_T: transport map.
        net_phi: current dual potential.
        net_phi0: previous dual potential (used for the extrapolation).
        rho0_samples, rho1_samples: sample batches from rho0 and rho1.
        omega: extrapolation weight.

    Returns:
        0-dim tensor E[|T(x) - x|^2 / 2] + E[phi~(T(x))] - E[phi~(y)].
    """
    T_x = net_T(rho0_samples)
    displacement = T_x - rho0_samples
    transport_cost = 1/2 * (displacement * displacement).sum(-1).mean()

    def extrapolated(points):
        current = net_phi(points)
        return current + omega * (current - net_phi0(points))

    return transport_cost + extrapolated(T_x).mean() - extrapolated(rho1_samples).mean()



In [None]:
# @title G(\theta) as a linear operator (matrix-free)
import scipy
from scipy.sparse.linalg import LinearOperator
import torch.autograd.functional as functional


def v_compute_Laplacian(net, samples):
    """Laplacian (trace of the input Hessian) of `net`, evaluated at every row of `samples`.

    The Hessian is formed per point with torch.func.hessian and batched with vmap, so `net` must
    be vmap-compatible. The diagonal is read from the last two axes and summed over the last axis.
    """
    hess_fn = hessian(net, argnums=0)

    def _trace_of_hessian(point):
        # relative dims: inside vmap the function does not see the batch axis
        return hess_fn(point).diagonal(0, -2, -1).sum(-1)

    return vmap(_trace_of_hessian)(samples)


def tensor_to_numpy(u):
    """Return `u` as a NumPy array, detached from autograd and moved to host memory.

    For a tensor already on the CPU the array shares memory with `u`.
    """
    return u.detach().cpu().numpy()


# explicitly form the Gram metric G(\theta). Only for verification.
def form_metric_tensor(input_dim, net, G_samples, device):
    """Explicitly assemble the Gram/metric matrix G(theta) of the gradient operator (verification only).

    G = mean_i J_i J_i^T, where J_i stacks, for every parameter tensor of `net`, the derivative
    d/dtheta of grad_x net(x_i), reshaped to (-1, input_dim). For a scalar-output net J_i has one
    row per parameter. A batched (N, P, P) tensor is materialised, so use small nets only.
    `device` is unused.

    Returns:
        (P, P) tensor.
    """
    n_samples = G_samples.size()[0]

    param_jacobian = jacrev(functional_call, argnums = 1)
    mixed_jacobian = vmap(jacrev(param_jacobian, argnums = 2), in_dims = (None, None, 0))
    per_param = mixed_jacobian(net, dict(net.named_parameters()), G_samples)

    J = torch.cat([block.view(n_samples, -1, input_dim) for _, block in dict(per_param).items()], 1)
    per_sample_gram = torch.matmul(J, torch.transpose(J, 1, 2))
    return torch.mean(per_sample_gram, 0)


# \mathcal M = grad operator
def metric_tensor_as_nabla_op(net, net_auxil, G_samples, vec, device):
    """Matrix-free product G(theta) @ vec for the metric induced by the gradient operator.

    Forms A = (1/N) * sum over samples and coordinates of grad_x net(x_i) * grad_x net_auxil(x_i),
    differentiates A w.r.t. net_auxil's parameters, dots that with `vec`, and differentiates the
    result w.r.t. net's parameters. When net_auxil holds the same parameter values as net
    (presumably a copy made by the caller -- TODO confirm), this is G v with
    G = (1/N) sum_i J_i J_i^T, J_i = d/dtheta grad_x net(x_i).

    Args:
        net: network whose parameters the final gradient is taken against.
        net_auxil: second network with the same parameter layout as `net`.
        G_samples: sample points; the first axis indexes samples (N = G_samples.size()[0]).
        vec: flat vector with one entry per parameter of net_auxil.
        device: unused.

    Returns:
        Flat tensor (parameters_to_vector layout of `net`), built with create_graph=True.

    Side effects:
        Calls zero_grad() on both networks.
    """

    num_params = len(torch.nn.utils.parameters_to_vector(net.parameters()))  # NOTE(review): unused

    params_net = dict(net.named_parameters())
    params_net_auxil = dict(net_auxil.named_parameters())
    ################### computation starts here ##################################
    net.zero_grad()
    net_auxil.zero_grad()

    # Per-sample input gradients grad_x f(x_i) of each network (vmap over the first axis of G_samples).
    grad_net = vmap(jacrev(functional_call, argnums=2), in_dims=(None, None, 0))
    grad_net_x = grad_net(net, params_net, G_samples)
    grad_net_auxil_x =grad_net(net_auxil, params_net_auxil, G_samples)
    # Bilinear form between the two networks' input gradients, averaged over samples.
    ave_sqr_grad_net = torch.sum(grad_net_x * grad_net_auxil_x) / G_samples.size()[0]  # torch.sum(grad_net_x * grad_net_auxil_x, 2).mean()

    # First derivative w.r.t. the auxiliary copy's parameters; create_graph keeps it differentiable w.r.t. net.
    # NOTE(review): with allow_unused=True an unused parameter yields None, which parameters_to_vector would likely reject.
    nabla_theta_ave_sqr_grad_net = torch.autograd.grad(ave_sqr_grad_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_ave_sqr_grad_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_sqr_grad_net)
    vec_dot_nabla_theta_ave_sqr_grad_net = vectorize_nabla_theta_ave_sqr_grad_net.dot(vec)
    # Second derivative w.r.t. net's parameters turns the scalar vec . dA/dtheta_aux into G @ vec.
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_sqr_grad_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)

    return vectorize_metric_tensor_mult_vec


# \mathcal M = identity operator
def metric_tensor_as_op_identity_part(net, net_auxil, G_samples, vec, device):
    """Matrix-free product G(theta) @ vec for the metric induced by the identity operator.

    Forms A = (1/N) * sum of net(x_i) * net_auxil(x_i) over all samples/outputs, differentiates A
    w.r.t. net_auxil's parameters, dots that with `vec`, and differentiates the result w.r.t.
    net's parameters. When net_auxil holds the same parameter values as net (presumably a copy
    made by the caller -- TODO confirm), this is G v with
    G = (1/N) sum_i d_theta net(x_i) d_theta net(x_i)^T.

    Args:
        net: network whose parameters the final gradient is taken against.
        net_auxil: second network with the same parameter layout as `net`.
        G_samples: sample points; the first axis indexes samples (N = G_samples.size()[0]).
        vec: flat vector with one entry per parameter of net_auxil.
        device: unused.

    Returns:
        Flat tensor (parameters_to_vector layout of `net`), built with create_graph=True.
    """

    num_params = len(torch.nn.utils.parameters_to_vector(net.parameters()))  # NOTE(review): unused

    params_net = dict(net.named_parameters())  # NOTE(review): unused in this variant
    params_net_auxil = dict(net_auxil.named_parameters())  # NOTE(review): unused in this variant
    ################### computation starts here ##################################
    net_x = net(G_samples)
    net_auxil_x = net_auxil(G_samples)
    # Bilinear form between the two networks' outputs, averaged over samples.
    ave_net = torch.sum(net_x * net_auxil_x) / G_samples.size()[0]
    # NOTE(review): with allow_unused=True an unused parameter yields None, which parameters_to_vector would likely reject.
    nabla_theta_ave_net = torch.autograd.grad(ave_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_net)
    vec_dot_nabla_theta_ave_net = vectorize_nabla_theta_net.dot(vec)
    # Differentiating vec . dA/dtheta_aux w.r.t. net's parameters gives G @ vec.
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)

    return vectorize_metric_tensor_mult_vec


# G1: \mathcal M = Id operator
# G2: \mathcal M = ▽ operator
def minres_solver_G(net, net_auxil, interior_samples, boundary_samples, RHS_vec, device, bd_lambda, max_iternum, minres_tolerance, G_type):
    """Solve G(theta) x = RHS_vec with SciPy MINRES, applying G matrix-free.

    G_type "1": metric of the identity operator (metric_tensor_as_op_identity_part).
    G_type "2": metric of the gradient operator (metric_tensor_as_nabla_op).
    Both are evaluated on `interior_samples`. `boundary_samples` and `bd_lambda` are accepted
    for interface compatibility but unused.

    Args:
        max_iternum: MINRES iteration cap, forwarded as `maxiter` (it used to be silently
            ignored); None keeps SciPy's default.
        minres_tolerance: relative tolerance forwarded as `rtol`.

    Returns:
        (solution tensor on `device`, MINRES info flag). If MINRES produces NaNs, the
        right-hand side itself is returned with info = 0 (i.e. an unpreconditioned step).

    Raises:
        ValueError: if G_type is not "1" or "2" (previously this printed a message and then
            failed with NameError).
    """
    metric_ops = {"1": metric_tensor_as_op_identity_part, "2": metric_tensor_as_nabla_op}
    if G_type not in metric_ops:
        raise ValueError("Wrong G_type: {!r} (expected '1' or '2')".format(G_type))
    metric_op = metric_ops[G_type]

    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]

    def G_as_operator(vec):  # input the vector v [numpy, CPU], return vector Gv [numpy, CPU]
        tensorized_vec = torch.Tensor(vec).to(device)
        Gv = metric_op(net, net_auxil, interior_samples, tensorized_vec, device)
        return tensor_to_numpy(Gv)

    G_operator = LinearOperator((num_params, num_params), matvec=G_as_operator)

    np_RHS_vec = tensor_to_numpy(RHS_vec)
    sol_vec, info = scipy.sparse.linalg.minres(G_operator, np_RHS_vec, rtol=minres_tolerance, maxiter=max_iternum)
    if np.isnan(sol_vec).any():
        print("Got the NAN!!!")
        sol_vec = np_RHS_vec
        info = 0
    tensorized_sol_vec = torch.Tensor(sol_vec).to(device)

    return tensorized_sol_vec, info


def minres_solver(net, net_auxil, interior_samples, boundary_samples, RHS_vec, device, bd_lambda, max_iternum, minres_tolerance):
    """Solve G(theta) x = RHS_vec with SciPy MINRES for the Laplace + boundary-trace metric.

    G v = M_Laplace(interior_samples) v + bd_lambda * M_trace(boundary_samples) v, applied matrix-free.
    NOTE(review): `metric_tensor_as_Laplace_op` and `metric_tensor_as_trace_op` are not defined in
    the visible part of this notebook; calling this without them raises NameError.

    Args:
        max_iternum: MINRES iteration cap, forwarded as `maxiter` (it used to be silently
            ignored); None keeps SciPy's default.
        minres_tolerance: relative tolerance forwarded as `rtol`.

    Returns:
        (solution tensor on `device`, MINRES info flag). If MINRES produces NaNs, the
        right-hand side itself is returned with info = 0.
    """

    def G_as_operator(vec):  # input the vector v [numpy, CPU], return vector Gv [numpy, CPU]
        tensorized_vec = torch.Tensor(vec).to(device)
        Gv = (metric_tensor_as_Laplace_op(net, net_auxil, interior_samples, tensorized_vec, device)
              + bd_lambda * metric_tensor_as_trace_op(net, net_auxil, boundary_samples, tensorized_vec, device))
        return tensor_to_numpy(Gv)

    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]
    G_operator = LinearOperator((num_params, num_params), matvec=G_as_operator)

    np_RHS_vec = tensor_to_numpy(RHS_vec)
    sol_vec, info = scipy.sparse.linalg.minres(G_operator, np_RHS_vec, rtol=minres_tolerance, maxiter=max_iternum)
    if np.isnan(sol_vec).any():
        print("Got the NAN!!!")
        sol_vec = np_RHS_vec
        info = 0
    tensorized_sol_vec = torch.Tensor(sol_vec).to(device)

    return tensorized_sol_vec, info




In [None]:
# @title update param via line search

def update_param(number_stepsizes, stepsize_0, base_stepsize, theta_0, tangent_theta, net_T_or_u, hidden_dim_net_T_or_u, net_phi, hidden_dim_net_phi, rho0_samples, rho1_samples, lossfunction, net_type, descent_or_ascent):
    """Grid line search for a parameter update along `tangent_theta`.

    Trial stepsizes are stepsize_0 * base_stepsize**k for k = 0..number_stepsizes-1. Each trial
    vector theta_0 + flag * step * tangent_theta (flag = -1 for "descent", +1 otherwise) is
    loaded into a fresh network, and `lossfunction` is evaluated with it in the slot named by
    `net_type`. The chosen step minimises the objective -flag * loss:
      - "descent": the smallest loss,
      - otherwise ("ascent"): the largest loss.
    The previous version always picked the smallest raw loss, which for ascent selected the
    worst step and contradicted the recorded objective values.

    Relies on the notebook globals `network_length`, `dim` and `device` to build the trial network.

    Returns:
        (updated flat parameter vector, chosen stepsize, list of objective values -flag * loss,
         one NumPy scalar per trial stepsize)

    Raises:
        ValueError: if net_type is not "T", "u" or "phi".
    """
    stepsize_list = stepsize_0 * base_stepsize**np.arange(number_stepsizes)
    flag = -1 if descent_or_ascent == "descent" else 1

    if net_type == "T":
        net_test = network_prim(network_length, dim, hidden_dim_net_T_or_u, dim).to(device)
    elif net_type == "u":
        net_test = network_prim(network_length, dim, hidden_dim_net_T_or_u, 1).to(device)
    elif net_type == "phi":
        net_test = network_dual(network_length, dim, hidden_dim_net_phi, 1).to(device)
    else:
        raise ValueError("Wrong net_type: {!r} (expected 'T', 'u' or 'phi')".format(net_type))

    loss_along_stepsizes = []
    min_index = 0
    current_min = float("inf")
    for index, stepsize in enumerate(stepsize_list):
        updated_theta = theta_0 + flag * stepsize * tangent_theta
        torch.nn.utils.vector_to_parameters(updated_theta, net_test.parameters())
        if net_type == "phi":
            loss = lossfunction(net_T_or_u, net_test, rho0_samples, rho1_samples)
        else:
            loss = lossfunction(net_test, net_phi, rho0_samples, rho1_samples)
        # Quantity to minimise: loss for descent, -loss for ascent.
        objective = - flag * loss.cpu().detach().numpy()
        loss_along_stepsizes.append(objective)
        if objective < current_min:  # strict '<' keeps the first (largest) step on ties
            min_index = index
            current_min = objective

    optimal_stepsize = stepsize_list[min_index]
    optimal_updated_theta = theta_0 + flag * optimal_stepsize * tangent_theta

    return optimal_updated_theta, optimal_stepsize, loss_along_stepsizes



In [None]:
# @title PD Adam solver use T instead of grad u

n_bins = 400  # number of bins for the periodic histogram plots


def _sync_if_cuda(device):
    """Wait for queued CUDA work so wall-clock timings include it (no-op on CPU)."""
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def _save_log_error_plots(l2error_list, comp_time):
    """Save log10(L2 error) vs. iteration and vs. cumulative compute time as PDFs in the cwd."""
    log10_errors = np.log10(np.asarray([float(e) for e in l2error_list]))

    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(range(0, len(log10_errors)), log10_errors)
    plt.title("plot of log_10(L2 error)")
    fig_plot.savefig("Plot of the log l2 error"+'.pdf')
    plt.show()
    plt.close(fig_plot)

    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(comp_time, log10_errors)
    plt.title("log_10(L2 error) vs. computation time (seconds)\n", fontsize=40)
    plt.xlabel("computation time (seconds)", fontsize=30)
    plt.ylabel("log_10(L2 error)", fontsize=30)
    fig_plot.savefig("Plot of the log l2 error vs comp time "+'.pdf')
    plt.show()
    plt.close(fig_plot)


def PD_Adam_solver_with_T( device, save_path, L, N,
                    network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                    lr_T, lr_phi, iter, phi_iter, T_iter,
                    plot_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3):
    """Baseline primal-dual OT solver: alternating Adam updates on phi (ascent) and T (descent).

    Each outer iteration draws fresh rho0/rho1 batches, takes `phi_iter` Adam steps that maximise
    PDHG_loss4 over phi, then `T_iter` Adam steps that minimise it over T. After every outer
    iteration the L2 error of T against the exact map is recorded.

    Args:
        device: torch device for both networks.
        save_path: directory for the 1-D map plots. Histograms, models, the pickled error/time
            lists and the error plots go to the current working directory.
        L: half-width of the plotting window [-L, L].
        N: sampler size parameter (rho0(N) / rho1(N) return 3N points).
        network_length, hidden_dimension_net_T, hidden_dimension_net_phi: network sizes.
        flag_init: if True, re-initialise both networks with N(0, 1) parameters.
        lr_T, lr_phi: Adam learning rates.
        iter: number of outer iterations.
        phi_iter, T_iter: inner Adam steps per outer iteration.
        plot_period: plot every `plot_period` outer iterations.
        N_plot: sampler size parameter for the histogram plots.
        chosen_dim_0..chosen_dim_3: unused; kept for interface compatibility.

    Side effects:
        Writes netT_PD_Adam.pt, netphi_PD_Adam.pt, the pickles 'l2error_list' and 'comp_time',
        and several PDF plots.
    """
    torch.manual_seed(50)

    # initialize nets
    net_T = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
    optim_T = torch.optim.Adam(net_T.parameters(), lr=lr_T)
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
    optim_phi = torch.optim.Adam(net_phi.parameters(), lr=lr_phi)
    if flag_init == True:
        net_T.initialization()
        net_phi.initialization()

    # initial err (evaluation only: no autograd graph)
    with torch.no_grad():
        error_0 = L2_error_T_with_OT_map_inv(net_T, 1800)
    print("initial error = {}".format(error_0.cpu().detach().numpy()))

    loss = PDHG_loss4   # quadratic-cost loss (PDHG_loss3 plus the |T(x) - x|^2 / 2 term)

    # The returned samples are unused; the call is kept so the RNG stream (and therefore every
    # batch drawn below) matches earlier runs.
    rho0(N)
    plt_histogram_pushfwd_rho0_rho1_densitycurve(0, L, N_plot, net_T, n_bins)
    One_D_plot_T_with_real_solution(net_T, L, 400, 0, 1, save_path, device )

    comp_time = []
    total_time = 0.0
    l2error_list = []
    ######################################################### PDHG iterations START HERE #########################################################
    for t in range(iter):

        t_0 = time.time()

        rho0_samples = rho0(N)
        rho1_samples = rho1(N)

        # Dual ascent on phi (minimise -loss). Each step builds a fresh graph, so retain_graph is not needed.
        for inner_iter in range(phi_iter):
            optim_phi.zero_grad()
            lossa = - loss(net_T, net_phi, rho0_samples, rho1_samples)
            lossa.backward()
            optim_phi.step()

        # Primal descent on T. Stale phi-step gradients on T's parameters are cleared by zero_grad.
        for inner_iter in range(T_iter):
            optim_T.zero_grad()
            lossb = loss(net_T, net_phi, rho0_samples, rho1_samples)
            lossb.backward()
            optim_T.step()

        # CUDA launches are asynchronous; synchronise so comp_time reflects the actual work.
        _sync_if_cuda(device)
        t_1 = time.time()
        total_time = total_time + (t_1 - t_0)
        comp_time.append(total_time)

        ################ plot ##################
        if (t+1) % plot_period == 0:
            plt_histogram_pushfwd_rho0_rho1_densitycurve(t, L, N_plot, net_T, n_bins)
            One_D_plot_T_with_real_solution(net_T, L, 400, t, 1, save_path, device )

        ##############
        print("Iter: {}, ".format(t))
        with torch.no_grad():
            L2error = L2_error_T_with_OT_map_inv(net_T, 1800)
        l2error_list.append(L2error.cpu().detach())
        print("L2 error = {}".format(L2error))

    ######################################################### PDHG iterations END HERE ###########################################################
    # save the models (in the current working directory)
    model_dir = os.getcwd()
    torch.save(net_T.state_dict(), os.path.join(model_dir, 'netT_PD_Adam.pt'))
    torch.save(net_phi.state_dict(), os.path.join(model_dir, 'netphi_PD_Adam.pt'))

    # write down the error history and cumulative compute times
    with open('l2error_list', 'wb') as file1:
        pickle.dump(l2error_list, file1)
    with open('comp_time', 'wb') as f:
        pickle.dump(comp_time, f)

    # plot the error decay (vs. iteration and vs. compute time)
    _save_log_error_plots(l2error_list, comp_time)



In [None]:
# @title apply PD_Adam solver
# In Google Colab, change to working directory
os.chdir("/content/drive/MyDrive/NPDG_codes_collection/OT1D")
# Verify
print("Current working directory:", os.getcwd())

# create (if needed) and enter the output folder for the PD-Adam baseline
os.makedirs('PD Adam', exist_ok=True)
os.chdir('PD Adam')
# Read save_path only after the chdir. The solver writes histograms, models and error logs to
# os.getcwd(), so the 1-D map plots now land in the same folder. Reading it before the chdir also
# made the cell non-idempotent, because a re-run started from inside 'PD Adam'.
save_path = os.getcwd()


L = 4.5
N = 800

iter = 1 # 40000
phi_iter = 1
T_iter = 1
lr_T = 5 *1e-4
lr_phi = 5*1e-4

network_length =  3
hidden_dimension_net_T = 50
hidden_dimension_net_phi = 50
flag_init = False

plot_period = 1 # 2000
N_plot = 6500
# The chosen_dim_* values are accepted by the solver but unused in this 1-D problem.
chosen_dim_0 = 0
chosen_dim_1 = 1
chosen_dim_2 = 3
chosen_dim_3 = 4
PD_Adam_solver_with_T( device, save_path, L, N,
                    network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                    lr_T, lr_phi, iter, phi_iter, T_iter,
                    plot_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3)



In [None]:
# @title NPDHG solver
import pickle


# number of histogram bins passed to plt_histogram_pushfwd_rho0_rho1_densitycurve
n_bins = 400


def NPDHG_solver(device, save_path, L, N,
                minres_max_iter, minres_tol,
                network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                iter, phi_iter, T_iter, omega,
                plot_period, print_period, N_plot,
                number_stepsizes, base_stepsize,
                precond_type,
                stepsize_0=0.2,
                adaptive_or_fixed_stepsize="fixed", tau_T = 0.5 * 1e-3, tau_phi = 0.9 * 1e-3 ):
    """Run preconditioned primal-dual (NPDHG) iterations for the map net T and dual net phi.

    Each outer iteration draws fresh samples from ``rho0``/``rho1`` and then
      1. takes ``phi_iter`` ascent steps on the parameters eta of ``net_phi`` along
         G(eta)^{-1} grad_eta loss, where loss = ``PDHG_loss4``;
      2. takes ``T_iter`` descent steps on the parameters theta of ``net_T`` along
         G(theta)^{-1} grad_theta loss, where the loss is
         ``PDHG_loss4_with_extrapolation``. That loss also receives ``net_phi_0`` and
         ``omega``; ``net_phi_0`` is a copy of ``net_phi`` taken before step 1.
    The preconditioned directions come from ``minres_solver_G``.

    Args:
        device: torch device the networks are moved to.
        save_path: forwarded to ``One_D_plot_T_with_real_solution``.
        L: forwarded to the plotting helpers (presumably the plot domain size — TODO confirm).
        N: number of samples drawn from ``rho0`` and from ``rho1`` per outer iteration.
        minres_max_iter, minres_tol: forwarded to ``minres_solver_G``.
        network_length, hidden_dimension_net_T, hidden_dimension_net_phi: architecture
            of ``network_prim`` (T) and ``network_dual`` (phi).
        flag_init: if truthy, call ``.initialization()`` on both networks.
        iter: number of outer iterations.
        phi_iter, T_iter: number of inner phi / T update steps per outer iteration.
        omega: forwarded to ``PDHG_loss4_with_extrapolation``.
        plot_period: plot every ``plot_period`` outer iterations.
        print_period: print the L2 error every ``print_period`` outer iterations.
        N_plot: forwarded to ``plt_histogram_pushfwd_rho0_rho1_densitycurve``.
        number_stepsizes, base_stepsize, stepsize_0: forwarded to ``update_param``.
            Only used when ``adaptive_or_fixed_stepsize == "adaptive"``.
        precond_type: ``"MpMd_id"`` selects G types "1" (T) and "1" (phi).
            ``"Mp_id_Md_nabla"`` selects G types "1" (T) and "2" (phi).
        adaptive_or_fixed_stepsize: ``"fixed"`` uses ``tau_T``/``tau_phi``.
            ``"adaptive"`` lets ``update_param`` return the step sizes.
        tau_T, tau_phi: fixed step sizes for the T descent / phi ascent steps.

    Side effects:
        Writes ``netT_NPDHG.pt``, ``netphi_NPDHG.pt``, the pickled ``l2error_list`` and
        ``comptime``, and two PDF plots into the current working directory.

    Raises:
        ValueError: if ``precond_type`` or ``adaptive_or_fixed_stepsize`` is not a
            supported value. Both are checked before any work is done.
    """
    # Validate string options up front. An unknown precond_type would otherwise only
    # surface later as an UnboundLocalError on G_T_type / G_phi_type. A bad step-size
    # mode would otherwise only be detected inside the first iteration.
    precond_G_types = {
        "MpMd_id": ("1", "1"),         # M_p = Id, M_d = Id
        "Mp_id_Md_nabla": ("1", "2"),  # M_p = Id, M_d = nabla
    }
    if precond_type not in precond_G_types:
        raise ValueError("Wrong precond_type: {!r}; expected one of {}".format(
            precond_type, sorted(precond_G_types)))
    if adaptive_or_fixed_stepsize not in ("adaptive", "fixed"):
        raise ValueError("adaptive_or_fixed_stepsize must be 'adaptive' or 'fixed'")
    G_T_type, G_phi_type = precond_G_types[precond_type]

    torch.manual_seed(50)

    # initialize nets
    net_T = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)

    if flag_init:
        net_T.initialization()
        net_phi.initialization()

    # initial error (NOTE(review): 2000 here vs. 1800 inside the loop — confirm intended)
    error_0 = L2_error_T_with_OT_map_inv(net_T, 2000)
    print(error_0)

    loss = PDHG_loss4

    # The result of this draw is never used (samples are redrawn every iteration).
    # The call is kept because it may advance the random generator, and removing it
    # could change the results.
    rho0(N)
    plt_histogram_pushfwd_rho0_rho1_densitycurve(0, L, N_plot, net_T, n_bins)
    One_D_plot_T_with_real_solution(net_T, L, 400, 0, 1, save_path, device)

    total_time = 0.0
    comptime = []      # cumulative wall-clock seconds after each outer iteration
    l2error_list = []  # L2 error of net_T after each outer iteration
    ######################################################### PDHG iterations START HERE ###################################################################################################################
    for t in range(iter):

        t_0 = time.time()

        # Snapshot of net_phi before this iteration's dual update; passed to the
        # extrapolated loss used in the T update below.
        net_phi_0 = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
        net_phi_0.load_state_dict(net_phi.state_dict())
        rho0_samples = rho0(N)
        rho1_samples = rho1(N)

        net_T.zero_grad()
        net_phi.zero_grad()
        ############################# update phi_\eta (ascent) #####################################
        for _ in range(phi_iter):
            # compute \nabla_\eta loss()
            lossa = loss(net_T, net_phi, rho0_samples, rho1_samples)
            nabla_eta_loss = torch.autograd.grad(lossa, net_phi.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_eta_loss = torch.nn.utils.parameters_to_vector(nabla_eta_loss)

            # copy net_phi for G(\eta) computation
            net_phi_auxil = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
            net_phi_auxil.load_state_dict(net_phi.state_dict())

            # compute G(\eta)^{-1} \nabla_\eta loss()
            G_inv_nabla_eta_loss, _info_phi = minres_solver_G(net_phi, net_phi_auxil, rho1_samples, None, vectorized_nabla_eta_loss, device, None, minres_max_iter, minres_tol, G_phi_type)

            # update \eta
            original_eta = torch.nn.utils.parameters_to_vector(net_phi.parameters())
            if adaptive_or_fixed_stepsize == "adaptive":
                updated_eta, tau_phi, _ = update_param(number_stepsizes, stepsize_0, base_stepsize, original_eta, G_inv_nabla_eta_loss, net_T, hidden_dimension_net_T, net_phi, hidden_dimension_net_phi, rho0_samples, rho1_samples, loss, "phi", "ascent")
            else:  # "fixed" (validated at the top)
                updated_eta = original_eta + tau_phi * G_inv_nabla_eta_loss
            torch.nn.utils.vector_to_parameters(updated_eta, net_phi.parameters())

        ######################## update theta (descent) ##################################
        for _ in range(T_iter):
            # compute \nabla_\theta loss() with the extrapolated dual (uses net_phi_0, omega)
            lossb = PDHG_loss4_with_extrapolation(net_T, net_phi, net_phi_0, rho0_samples, rho1_samples, omega)
            nabla_theta_loss = torch.autograd.grad(lossb, net_T.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_theta_loss = torch.nn.utils.parameters_to_vector(nabla_theta_loss)

            # copy net_T for G(\theta) computation
            net_T_auxil = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
            net_T_auxil.load_state_dict(net_T.state_dict())

            # compute G(\theta)^{-1} \nabla_\theta loss()
            G_inv_nabla_theta_loss, _info_T = minres_solver_G(net_T, net_T_auxil, rho0_samples, None, vectorized_nabla_theta_loss, device, None, minres_max_iter, minres_tol, G_T_type)

            # update \theta
            original_theta = torch.nn.utils.parameters_to_vector(net_T.parameters())
            if adaptive_or_fixed_stepsize == "adaptive":
                updated_theta, tau_T, _ = update_param(number_stepsizes, stepsize_0, base_stepsize, original_theta, G_inv_nabla_theta_loss, net_T, hidden_dimension_net_T, net_phi, hidden_dimension_net_phi, rho0_samples, rho1_samples, loss, "T", "descent")
            else:  # "fixed" (validated at the top)
                updated_theta = original_theta - tau_T * G_inv_nabla_theta_loss
            torch.nn.utils.vector_to_parameters(updated_theta, net_T.parameters())

        t_1 = time.time()
        total_time += t_1 - t_0
        comptime.append(total_time)

        ################ plot ##################
        if (t+1) % plot_period == 0:
            plt_histogram_pushfwd_rho0_rho1_densitycurve(t, L, N_plot, net_T, n_bins)
            One_D_plot_T_with_real_solution(net_T, L, 400, t, 1, save_path, device)

        ########################################################
        L2error = L2_error_T_with_OT_map_inv(net_T, 1800)
        l2error_list.append(L2error.cpu().detach())
        if (t+1) % print_period == 0:
            print("Iteration: {}".format(t))
            print("L2 error = {}".format(L2error))

    ######################################################### PDHG iterations END HERE #################################################################################################################
    # Save the models into the current working directory. A separate name is used so
    # the save_path argument is not overwritten.
    out_dir = os.getcwd()
    torch.save(net_T.state_dict(), os.path.join(out_dir, 'netT_NPDHG.pt'))
    torch.save(net_phi.state_dict(), os.path.join(out_dir, 'netphi_NPDHG.pt'))

    # write down the error history and the cumulative computation time
    with open('l2error_list', 'wb') as file1:
        pickle.dump(l2error_list, file1)
    with open('comptime', 'wb') as f:
        pickle.dump(comptime, f)

    log10_error = np.log10(l2error_list)

    # plot the error decay vs. iteration
    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(range(len(l2error_list)), log10_error)
    plt.title("plot of log_10(L2 error)")
    plt.xlabel("iteration")
    fig_plot.savefig("Plot of the log l2 error"+'.pdf')
    plt.show()
    plt.close(fig_plot)

    # plot the error decay vs. computation time
    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(comptime, log10_error)
    plt.title("log_10(L2 error) vs. computation time (seconds)\n", fontsize=40)
    plt.xlabel("computation time (seconds)", fontsize=30)
    plt.ylabel("log_10(L2 error)", fontsize=30)
    fig_plot.savefig("Plot of the log l2 error vs comp time "+'.pdf')
    plt.show()
    plt.close(fig_plot)



In [None]:
# @title apply NPDHG solver (use M_d = Id in preconditioner)

# In Google Colab, change to the project's working directory first so this cell
# does not depend on whichever directory a previously-run cell left as the cwd.
os.chdir("/content/drive/MyDrive/NPDG_codes_collection/OT1D")
# Verify
print("Current working directory:", os.getcwd())

# create (if needed) and enter the output folder for this NPDHG run
os.makedirs('NPDHG precondition1', exist_ok=True)
os.chdir('NPDHG precondition1')
# Capture save_path only AFTER entering the run folder; capturing it before the
# chdir made it point at the previous experiment's folder (e.g. 'PD Adam').
save_path = os.getcwd()


L = 4.5   # forwarded to the plotting helpers
N = 800   # samples drawn from rho0 / rho1 per outer iteration

minres_max_iter = 1000
minres_tol = 1e-3

network_length = 3
hidden_dimension_net_T = 50
hidden_dimension_net_phi = 50
flag_init = False

# Named num_iter (not `iter`) so the Python builtin `iter` is not shadowed notebook-wide.
num_iter = 6000
phi_iter = 1
T_iter = 1
omega = 1.0

plot_period = 1000
print_period = 200
N_plot = 5000

# only used by the adaptive step-size mode (fixed mode is selected below)
number_stepsizes = 50
base_stepsize = 0.8

precond_type = "MpMd_id"

NPDHG_solver(device, save_path, L, N,
             minres_max_iter, minres_tol,
             network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
             num_iter, phi_iter, T_iter, omega,
             plot_period, print_period, N_plot,
             number_stepsizes, base_stepsize,
             precond_type,
             adaptive_or_fixed_stepsize="fixed", tau_T=1e-1, tau_phi=1.5 * 1e-1
             )



In [None]:
# @title apply NPDHG solver (use M_d = \nabla in preconditioner)

# In Google Colab, change to the project's working directory first so this cell
# does not depend on whichever directory a previously-run cell left as the cwd.
os.chdir("/content/drive/MyDrive/NPDG_codes_collection/OT1D")
# Verify
print("Current working directory:", os.getcwd())

# create (if needed) and enter the output folder for this NPDHG run
os.makedirs('NPDHG precondition2', exist_ok=True)
os.chdir('NPDHG precondition2')
# Capture save_path only AFTER entering the run folder; capturing it before the
# chdir made it point at the previous experiment's folder (e.g. 'NPDHG precondition1').
save_path = os.getcwd()


L = 4.5   # forwarded to the plotting helpers
N = 800   # samples drawn from rho0 / rho1 per outer iteration

minres_max_iter = 1000
minres_tol = 1e-3

network_length = 3
hidden_dimension_net_T = 50
hidden_dimension_net_phi = 50
flag_init = False

# Named num_iter (not `iter`) so the Python builtin `iter` is not shadowed notebook-wide.
num_iter = 6000
phi_iter = 1
T_iter = 1
omega = 1.0

plot_period = 1000
print_period = 200
N_plot = 5000

# only used by the adaptive step-size mode (fixed mode is selected below)
number_stepsizes = 50
base_stepsize = 0.8

precond_type = "Mp_id_Md_nabla"

NPDHG_solver(device, save_path, L, N,
             minres_max_iter, minres_tol,
             network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
             num_iter, phi_iter, T_iter, omega,
             plot_period, print_period, N_plot,
             number_stepsizes, base_stepsize,
             precond_type,
             adaptive_or_fixed_stepsize="fixed", tau_T=1e-1, tau_phi=1.5 * 1e-1
             )

