In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change into the project folder on the mounted Google Drive.
# NOTE(review): the "apply Primal-Dual Adam solver" cell later calls
# os.chdir("/content/drive/MyDrive/NPDG_codes_collection/OT_Gaussian"), which is a
# different folder from the one used here -- confirm which path is intended.
cd/content/drive/MyDrive/NPDG/OT_Gaussian

This notebook computes the optimal transport (OT) map between two Gaussian distributions using two methods:


*   the NPDG method
*   the Primal-Dual Adam method

See Section 5.5.2 for a more detailed description of the problem.



In [None]:
# @title import
import math
import random
import scipy

import matplotlib
matplotlib.use('agg')

from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

import numpy as np

from torch.func import grad, hessian, vmap
from torch.func import jacrev
from torch.func import functional_call
import time


import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pickle

from numpy import *
from torch import Tensor
from torch.nn import Parameter
from torch.optim.lr_scheduler import ExponentialLR

import os
import argparse



In [None]:
# @title set dimension of the problem

# Dimension of the Gaussian distributions (input/output size of the networks).
# The covariance setup later in this notebook writes entries up to index [4, 4],
# so dim must be at least 5.
dim = 5

In [None]:
# @title Check CUDA availability

# Pick the compute device once: CUDA when a GPU is visible, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# @title model (net)


class network_map(nn.Module):
    """Fully connected network with Softplus activations.

    The network has ``network_length`` linear layers: an input layer
    (``input_dimension -> hidden_dimension``), ``network_length - 2`` hidden
    layers, and a bias-free output layer (``hidden_dimension -> output_dimension``).

    Args:
        network_length: total number of linear layers.
        input_dimension: size of the last axis of the input.
        hidden_dimension: width of every hidden layer.
        output_dimension: size of the last axis of the output.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        # Fixed: the original called super(network_prim, self).__init__(), which names a
        # different class and makes network_map impossible to instantiate.
        super().__init__()

        self.network_length = network_length
        self.linears = nn.ModuleList([nn.Linear(input_dimension, hidden_dimension)])
        self.linears.extend([nn.Linear(hidden_dimension, hidden_dimension) for _ in range(1, network_length-1)])
        self.linears.extend([nn.Linear(hidden_dimension, output_dimension, bias=False)])
        self.softplus = nn.Softplus()

    def initialization(self):
        """Re-initialise every weight (and bias, where present) from a standard normal."""
        for l in self.linears:
            l.weight.data.normal_()
            if l.bias is not None:
                l.bias.data.normal_()

    def forward(self, x):
        """Apply the layers in order, with Softplus before every layer except the first."""
        l = self.linears[0]
        x = l(x)
        for l in self.linears[1: self.network_length-1]:
            x = self.softplus(x)
            x = l(x)

        x = self.softplus(x)
        l = self.linears[self.network_length-1]
        x = l(x)

        return x


class network_prim(nn.Module):
    """Softplus multilayer perceptron used for the primal variable.

    Layers, in order: one ``input_dimension -> hidden_dimension`` linear layer,
    ``network_length - 2`` square hidden layers, and a final bias-free
    ``hidden_dimension -> output_dimension`` layer.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()

        self.network_length = network_length
        # Layers are created in forward order so random initialisation is reproducible.
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(network_length - 2)]
        layers.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(layers)
        self.softplus = nn.Softplus()

    def initialization(self):
        """Overwrite all weights and existing biases with standard-normal samples."""
        for layer in self.linears:
            layer.weight.data.normal_()
            if layer.bias is None:
                continue
            layer.bias.data.normal_()

    def forward(self, x):
        """Evaluate the network; Softplus is applied between consecutive layers."""
        last = self.network_length - 1
        x = self.linears[0](x)
        for layer in self.linears[1:last]:
            x = layer(self.softplus(x))
        return self.linears[last](self.softplus(x))


class network_dual(nn.Module):
    """Softplus multilayer perceptron used for the dual variable.

    Same architecture as the primal network: an input layer, ``network_length - 2``
    hidden layers of width ``hidden_dimension``, and a bias-free output layer.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()

        self.network_length = network_length
        hidden_count = network_length - 2
        stack = [nn.Linear(input_dimension, hidden_dimension)]
        stack.extend(nn.Linear(hidden_dimension, hidden_dimension) for _ in range(hidden_count))
        stack.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(stack)
        self.softplus = nn.Softplus()

    def initialization(self):
        """Draw every weight and every existing bias from a standard normal."""
        for layer in self.linears:
            layer.weight.data.normal_()
            if layer.bias is not None:
                layer.bias.data.normal_()

    def forward(self, x):
        """Run the input through all layers with Softplus between them."""
        final_index = self.network_length - 1
        hidden = self.linears[0](x)
        for layer in self.linears[1:final_index]:
            hidden = layer(self.softplus(hidden))
        return self.linears[final_index](self.softplus(hidden))



In [None]:
# @title 2D vector field plotting & 2D pushforwarded sample

def plot_grad_u_with_grad_ureal(iter, samples, l, chosen_dim_0, chosen_dim_1, net_u, save_path):
    """Save a quiver plot comparing the gradient of ``net_u`` with the exact gradient.

    Both vector fields are evaluated at ``samples``, multiplied by 3, and drawn on
    the (chosen_dim_0, chosen_dim_1) coordinate plane: the network gradient in
    green, the exact gradient in red. The figure is written as a PDF in ``save_path``.
    """
    learned_field = (3. * gradient_nn(net_u, samples)).cpu().detach().numpy()
    exact_field = (3. * grad_u_real(samples)).cpu().detach().numpy()

    points = samples.cpu().detach().numpy()
    xs = points[:, chosen_dim_0]
    ys = points[:, chosen_dim_1]

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')
    plt.xlim([-l, l])
    plt.ylim([-l, l])
    # Network gradient first (green), exact gradient on top (red).
    for field, colour in ((learned_field, 'green'), (exact_field, 'red')):
        plt.quiver(xs, ys, field[:, chosen_dim_0], field[:, chosen_dim_1], scale=None, scale_units='inches', color=colour, width=0.002)
    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    stem = '(Iteration={}) gradient of net_u with gradient of u_real at sample points (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1)
    plt.savefig(os.path.join(save_path, stem + '.pdf'))
    plt.close()


def plot_pushfwded_samples(iter, net_T, samples, chosen_dim_0, chosen_dim_1, save_path):
    """Save a scatter plot of source, pushed-forward and target samples.

    NOTE: a function with this same name is defined again later in this notebook;
    the later definition is the one in effect once that cell has run.

    Args:
        iter: iteration number, used only in the output filename.
        net_T: transport map; called as ``net_T(samples)``.
        samples: source samples (tensor indexed as ``[:, k]``).
        chosen_dim_0, chosen_dim_1: coordinate indices of the plane to plot.
        save_path: directory in which the PDF is written.
    """
    N = samples.size()[0]

    T_x = net_T(samples)
    T_x = T_x.cpu().detach().numpy()
    samples = samples.cpu().detach().numpy()
    target_samples = rho1(N).cpu().detach().numpy()

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')
    # Fixed: the original always plotted columns 0 and 1, ignoring chosen_dim_0/1,
    # so every "plane" produced the same picture under a different label.
    plt.scatter(samples[:, chosen_dim_0], samples[:, chosen_dim_1], color='green', s=9)
    plt.scatter(T_x[:, chosen_dim_0], T_x[:, chosen_dim_1], color='blue', s=9)
    plt.scatter(target_samples[:, chosen_dim_0], target_samples[:, chosen_dim_1], color='red', s=9, alpha=0.4)
    plt.xlim([-3, 3])
    plt.ylim([-3, 3])
    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    filename = os.path.join(save_path, '(Iteration={}) original samples (green, ~rho0) pushforwarded samples (blue) with target samples (red ~rho1) (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf')
    plt.savefig(filename)
    plt.close()


def plot_T_with_grad_ureal(iter, samples, l, chosen_dim_0, chosen_dim_1, net_T, save_path):
    """Save a quiver plot comparing the map ``net_T`` with the exact map.

    Args:
        iter: iteration number, used only in the output filename.
        samples: points at which both maps are evaluated.
        l: half-width of the plotted window; both axes span ``[-l, l]``.
        chosen_dim_0, chosen_dim_1: coordinate indices of the plane to plot.
        net_T: transport map; called as ``net_T(samples)``.
        save_path: directory in which the PDF is written.
    """
    T_x = net_T(samples)
    T_x = T_x.cpu().detach().numpy()
    grad_ureal = grad_u_real(samples)
    grad_ureal = grad_ureal.cpu().detach().numpy()

    samples = samples.cpu().detach().numpy()

    figure(num=None, figsize=(14, 14), dpi=80, facecolor='w', edgecolor='k')
    plt.xlim([-l, l])
    plt.ylim([-l, l])
    q = plt.quiver(samples[:, chosen_dim_0], samples[:, chosen_dim_1], grad_ureal[:, chosen_dim_0], grad_ureal[:, chosen_dim_1], scale = 1.6, scale_units='inches', color = 'red', width=0.0024, label='real OT map')
    # Reuse the scale of the first quiver so both fields share the same arrow length.
    plt.quiver(samples[:, chosen_dim_0], samples[:, chosen_dim_1], T_x[:, chosen_dim_0], T_x[:, chosen_dim_1], scale = q.scale, scale_units='inches', color=(1.6/100, 46.3/100, 81.6/100), width=0.001, label='OT map computed via NPDG')
    plt.tick_params(axis='both', labelsize=30)
    plt.legend(fontsize = 36)
    plt.xlabel("x_{}".format(chosen_dim_0+1), fontsize = 30)
    plt.ylabel("x_{}".format(chosen_dim_1+1), fontsize = 30)
    # Fixed: the original template had two placeholders for three arguments, so the
    # "plane" part of the name was filled with (iter, chosen_dim_0) instead of the two
    # dims. The iteration now has its own placeholder; there is still one file per
    # (iteration, plane).
    filename = os.path.join(save_path, '[a](Iteration={}) Map T with gradient of u_real at sample points (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf')
    plt.savefig(filename)
    plt.close()


def plot_T_with_grad_ureal_two_plots(iter, samples, l, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3, net_T, save_path):
    """Save one PDF with two stacked quiver panels comparing ``net_T`` with the exact map.

    The top panel shows the (chosen_dim_0, chosen_dim_1) plane and the bottom panel
    the (chosen_dim_2, chosen_dim_3) plane. ``iter`` is accepted for signature
    compatibility and is not used in the filename.
    """
    T_x = net_T(samples)
    T_x = T_x.cpu().detach().numpy()
    grad_ureal = grad_u_real(samples)
    grad_ureal = grad_ureal.cpu().detach().numpy()

    samples = samples.cpu().detach().numpy()

    # Fixed: the original called figure(...) after each plt.subplot(...), which opened
    # a fresh figure every time. Only the second panel reached the saved file and the
    # other figures were never closed. Both panels now live on one figure.
    fig, axes = plt.subplots(2, 1, figsize=(27, 54), dpi=80, facecolor='w', edgecolor='k')
    planes = ((chosen_dim_0, chosen_dim_1), (chosen_dim_2, chosen_dim_3))
    for ax, (dim_a, dim_b) in zip(axes, planes):
        ax.set_xlim([-l, l])
        ax.set_ylim([-l, l])
        q = ax.quiver(samples[:, dim_a], samples[:, dim_b], grad_ureal[:, dim_a], grad_ureal[:, dim_b], scale=1, scale_units='inches', color='red', width=0.0024, label='real map')
        # Share the scale of the reference field so arrow lengths are comparable.
        ax.quiver(samples[:, dim_a], samples[:, dim_b], T_x[:, dim_a], T_x[:, dim_b], scale=q.scale, scale_units='inches', color=(1.6/100, 46.3/100, 81.6/100), width=0.001, label='NPDG')
        ax.set_xlabel("x_{}".format(dim_a+1), fontsize=30)
        ax.set_ylabel("x_{}".format(dim_b+1), fontsize=30)

    filename = os.path.join(save_path, 'Map T with gradient of u_real at sample points '  + '.pdf')
    fig.savefig(filename)
    plt.close(fig)


def plot_pushfwded_samples(iter, net_T, samples, chosen_dim_0, chosen_dim_1, save_path):
    """Save a scatter plot of source, pushed-forward and target samples.

    Source samples are green, ``net_T(samples)`` is blue, and fresh draws from
    ``rho1`` are red, all projected onto the (chosen_dim_0, chosen_dim_1) plane.

    Args:
        iter: iteration number, used only in the output filename.
        net_T: transport map; called as ``net_T(samples)``.
        samples: source samples (tensor indexed as ``[:, k]``).
        chosen_dim_0, chosen_dim_1: coordinate indices of the plane to plot.
        save_path: directory in which the PDF is written.
    """
    N = samples.size()[0]

    T_x = net_T(samples)
    T_x = T_x.cpu().detach().numpy()
    samples = samples.cpu().detach().numpy()
    target_samples = rho1(N).cpu().detach().numpy()

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')
    # Fixed: the original always plotted columns 0 and 1, ignoring chosen_dim_0/1,
    # so every "plane" produced the same picture under a different label.
    plt.scatter(samples[:, chosen_dim_0], samples[:, chosen_dim_1], color='green', s=9)
    plt.scatter(T_x[:, chosen_dim_0], T_x[:, chosen_dim_1], color='blue', s=9)
    plt.scatter(target_samples[:, chosen_dim_0], target_samples[:, chosen_dim_1], color='red', s=9, alpha=0.4)
    plt.xlim([-3, 3])
    plt.ylim([-3, 3])
    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    filename = os.path.join(save_path, '(Iteration={}) original samples (green, ~rho0) pushforwarded samples (blue) with target samples (red ~rho1) (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf')
    plt.savefig(filename)
    plt.close()



In [None]:
# @title define rho0 (sampler) , rho1 (sampler) , real solution

# Source Gaussian rho0: identity covariance except for variance 1/4 on the first
# coordinate. Sigma0 is diagonal, so the element-wise square root is its matrix
# square root.
Sigma0 = torch.eye(dim)
Sigma0[0, 0] = 0.25
sqrt_Sigma0 = Sigma0.sqrt()
inv_sqrt_Sigma0 = sqrt_Sigma0.inverse()
mu0 = torch.zeros(1, dim)

def rho0(n, d=dim):
    """Draw ``n`` samples from the source Gaussian N(mu0, Sigma0).

    Args:
        n: number of samples.
        d: sample dimension; it must match the size of ``sqrt_Sigma0`` for the
            matrix product to be defined.

    Returns:
        Tensor of shape ``(n, d)`` on ``device``.
    """
    z = torch.randn(n, d)
    Sigma0_z = torch.matmul(z, sqrt_Sigma0)
    x = Sigma0_z + mu0

    # Fixed: the original hard-coded .cuda(), which fails on CPU-only machines even
    # though the notebook already selects `device` from CUDA availability.
    return x.to(device)


# Target Gaussian rho1: identity covariance except for variance 1/4 on coordinate 1
# and a correlated 2x2 block on coordinates (3, 4).
Sigma1 = torch.eye(dim)
Sigma1[1, 1] = 0.25
Sigma1[3:5, 3:5] = torch.tensor([[5/8, 3/8],
                                 [3/8, 5/8]])

# Matrix square root of Sigma1, entered by hand: the (3, 4) block squared gives the
# block of Sigma1 above.
sqrt_Sigma1 = torch.eye(dim)
sqrt_Sigma1[1, 1] = 0.5
sqrt_Sigma1[3:5, 3:5] = torch.tensor([[3/4, 1/4],
                                      [1/4, 3/4]])

# Inverse of sqrt_Sigma1, also entered by hand.
inv_sqrt_Sigma1 = torch.eye(dim)
inv_sqrt_Sigma1[1, 1] = 2
inv_sqrt_Sigma1[3:5, 3:5] = torch.tensor([[3/2, -1/2],
                                          [-1/2, 3/2]])

mu1 = torch.zeros(1, dim)
def rho1(n, d=dim):
    """Draw ``n`` samples from the target Gaussian N(mu1, Sigma1).

    Args:
        n: number of samples.
        d: sample dimension; it must match the size of ``sqrt_Sigma1`` for the
            matrix product to be defined.

    Returns:
        Tensor of shape ``(n, d)`` on ``device``.
    """
    z = torch.randn(n, d)
    Sigma1_z = torch.matmul(z, sqrt_Sigma1)
    x = Sigma1_z + mu1

    # Fixed: the original hard-coded .cuda(), which fails on CPU-only machines even
    # though the notebook already selects `device` from CUDA availability.
    return x.to(device)


# real sol: u(x) = 1/2 x^TAx + b^Tx
# A = Sigma_0^{-1/2} Sigma_1^{1/2}; this equals sqrt(Sigma_0^{-1} Sigma_1) because the
#     Sigma_0 and Sigma_1 defined above commute (Sigma_1 is NOT diagonal, so the
#     product form is what is actually computed here).
# b = mu_1 - A mu_0   (written with row vectors: mu_1 - mu_0 A)
A = torch.matmul(inv_sqrt_Sigma0, sqrt_Sigma1)
b = mu1 - torch.matmul(mu0, A)
# Fixed: use the selected `device` instead of a hard-coded .cuda(), so the notebook
# also runs on CPU-only machines.
A = A.to(device)
b = b.to(device)
# \nabla u(x) = Ax+b
def grad_u_real(x, d=dim):
    """Evaluate the exact map ``x A + b`` row-wise on a batch ``x``.

    Args:
        x: batch of points, one per row; it must be on the same device as ``A``.
        d: unused; kept for signature compatibility.

    Returns:
        Tensor with the same shape as ``x`` on ``device``.
    """
    gradients = torch.matmul(x, A) + b
    # Fixed: .to(device) instead of a hard-coded .cuda(), which fails without a GPU.
    return gradients.to(device)



In [None]:
# @title L2 error


def gradient_nn(network, x):
    """Return the gradient of ``network`` with respect to its input, evaluated at ``x``.

    The output entries are summed (``grad_outputs`` is all ones). For a scalar-valued
    network on a batch this gives the per-sample input gradient, provided each output
    row depends only on its own input row.

    Args:
        network: callable mapping a tensor to a tensor.
        x: input tensor; it is detached first, so no gradient flows back into
            whatever produced ``x``.

    Returns:
        Tensor with the same shape as ``x``. The graph is kept
        (``create_graph=True``), so the result can be differentiated again, e.g.
        with respect to the network parameters.
    """
    # Equivalent to the deprecated autograd.Variable(x, requires_grad=True).
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    # Fixed: torch.ones(...).cuda() failed on CPU-only machines; ones_like follows the
    # device and dtype of the network output.
    gradients_x = autograd.grad(outputs=output_value, inputs=input_variable, grad_outputs=torch.ones_like(output_value), create_graph=True, retain_graph=True)[0]

    return gradients_x


def L2_error(net_u, N):
    """Root-mean-square difference between the gradient of ``net_u`` and the exact gradient.

    ``N`` fresh samples are drawn from ``rho0``; the squared difference is averaged
    over all samples and all coordinates before taking the square root.
    """
    samples = rho0(N)
    exact_gradient = grad_u_real(samples)
    learned_gradient = gradient_nn(net_u, samples)
    residual = learned_gradient - exact_gradient

    return torch.sqrt((residual * residual).mean())


def L2_error_T(net_T, N):
    """Root-mean-square difference between the map ``net_T`` and the exact map.

    ``N`` fresh samples are drawn from ``rho0``; the squared difference between
    ``net_T(samples)`` and ``grad_u_real(samples)`` is averaged over all samples
    and all coordinates before taking the square root.
    """
    samples = rho0(N)
    exact_map = grad_u_real(samples)
    residual = net_T(samples) - exact_map

    return torch.sqrt((residual * residual).mean())



In [None]:
# @title PDHG loss


def gradient_nn(network, x):
    """Return the gradient of ``network`` with respect to its input, evaluated at ``x``.

    The output entries are summed (``grad_outputs`` is all ones). ``x`` is detached
    first, and the graph is kept (``create_graph=True``) so the result can be
    differentiated again with respect to the network parameters.
    """
    # Equivalent to the deprecated autograd.Variable(x, requires_grad=True).
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    # Fixed: torch.ones(...).cuda() failed on CPU-only machines; ones_like follows the
    # device and dtype of the network output.
    gradients_x = autograd.grad(outputs=output_value, inputs=input_variable, grad_outputs=torch.ones_like(output_value), create_graph=True, retain_graph=True)[0]
    return gradients_x


# PDHG loss with no boundary error
def PDHG_loss1(net_u, net_phi, rho0_samples, rho1_samples):
    """Mean of ``net_phi(grad net_u(x)) - net_phi(y)`` over the two sample batches.

    The element-wise subtraction requires both ``net_phi`` outputs to have
    compatible shapes.
    """
    pushed_samples = gradient_nn(net_u, rho0_samples)
    gap = net_phi(pushed_samples) - net_phi(rho1_samples)
    return gap.mean()


# derived from the Monge problem
def PDHG_loss2(net_u, net_phi, rho0_samples, rho1_samples):
    """Loss ``-E[x . grad u(x)] + E[phi(grad u(x))] - E[phi(y)]`` with x ~ rho0, y ~ rho1."""
    pushed_samples = gradient_nn(net_u, rho0_samples)
    inner_product_mean = torch.sum(rho0_samples * pushed_samples, -1).mean()
    phi_on_pushed = net_phi(pushed_samples).mean()
    phi_on_target = net_phi(rho1_samples).mean()
    return - inner_product_mean + phi_on_pushed - phi_on_target


def PDHG_loss3(net_T, net_phi, rho0_samples, rho1_samples):
    """Loss ``-E[x . T(x)] + E[phi(T(x))] - E[phi(y)]`` with x ~ rho0, y ~ rho1.

    Args:
        net_T: map network, evaluated on ``rho0_samples``.
        net_phi: dual network, evaluated on ``T(x)`` and on ``rho1_samples``.
        rho0_samples: batch of source samples, one per row.
        rho1_samples: batch of target samples, one per row.

    Returns:
        Scalar tensor.
    """
    mapped = net_T(rho0_samples)
    inner_product_mean = torch.sum(rho0_samples * mapped, -1).mean()
    phi_on_mapped = net_phi(mapped).mean()
    phi_on_target = net_phi(rho1_samples).mean()
    return - inner_product_mean + phi_on_mapped - phi_on_target


def PDHG_loss4_with_extrapolation(net_T, net_phi, net_phi0, rho0_samples, rho1_samples, omega):
    """Quadratic-cost loss using an extrapolated dual function.

    The dual function is ``phi_tilde = phi + omega * (phi - phi0)``. The loss is
    ``E[|T(x) - x|^2 / 2] + E[phi_tilde(T(x))] - E[phi_tilde(y)]`` with
    x ~ rho0 and y ~ rho1.

    Args:
        net_T: map network.
        net_phi: current dual network.
        net_phi0: reference dual network used for the extrapolation.
        rho0_samples: batch of source samples, one per row.
        rho1_samples: batch of target samples, one per row.
        omega: extrapolation weight.

    Returns:
        Scalar tensor.
    """
    def extrapolated_phi(points):
        # phi + omega * (phi - phi0), evaluated at `points`.
        current = net_phi(points)
        reference = net_phi0(points)
        return current +  omega * (current - reference)

    mapped = net_T(rho0_samples)
    displacement = mapped - rho0_samples
    transport_cost = 1/2*torch.sum(displacement * displacement, -1).mean()

    dual_on_mapped = extrapolated_phi(mapped)
    dual_on_target = extrapolated_phi(rho1_samples)
    return transport_cost + dual_on_mapped.mean() - dual_on_target.mean()



In [None]:
# @title G(\theta) as a linear opt another
import scipy
from scipy.sparse.linalg import LinearOperator
import torch.autograd.functional as functional


def tensor_to_numpy(u):
    """Return ``u`` as a NumPy array, detached from autograd and copied to CPU if needed.

    The original compared ``u.device`` (a ``torch.device``) with the string "cpu"
    and branched on it, but both branches produced the same array. ``.cpu()`` is a
    no-op for tensors already on the CPU, so one expression covers every case.
    """
    return u.detach().cpu().numpy()


# explicitly form the Gram metric G(\theta). Only for verification.
def form_metric_tensor(input_dim, net, G_samples, device):
    """Assemble a dense metric matrix from mixed parameter/input derivatives of ``net``.

    For every sample the mixed derivative of the network output with respect to
    (parameters, input) is computed, flattened per parameter tensor to
    ``(N, -1, input_dim)`` and concatenated along axis 1. The per-sample product of
    that stack with its own transpose is then averaged over the ``N`` samples.

    Args:
        input_dim: size of the last axis used when flattening the derivatives.
        net: module evaluated through ``functional_call``.
        G_samples: batch of inputs; the first axis is the sample axis (it is the
            axis mapped over by ``vmap``).
        device: unused in this function.

    Returns:
        The sample-averaged matrix. NOTE(review): its side length equals the number
        of parameters only if the network output is scalar -- confirm for other uses.
    """

    N = G_samples.size()[0]

    # Derivative of functional_call(net, params, x) with respect to the params dict...
    Jacobi_NN = jacrev(functional_call, argnums = 1)
    # ...differentiated again with respect to x, and mapped over the sample axis of x.
    D_param_D_x_NN = vmap(jacrev(Jacobi_NN, argnums = 2), in_dims = (None, None, 0))
    D_param_D_x_net_on_x = D_param_D_x_NN(net, dict(net.named_parameters()), G_samples)
    # num_params is computed but not used below.
    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]
    list_of_vectorized_param_gradients = []
    # Items are (name, derivative) pairs in named_parameters() order.
    for param_gradients in dict(D_param_D_x_net_on_x).items():
        vectorized_param_gradients = param_gradients[1].view(N, -1, input_dim)
        list_of_vectorized_param_gradients.append(vectorized_param_gradients)
    total_vectorized_param_gradients = torch.cat(list_of_vectorized_param_gradients, 1)
    transpose_total_vectorized_param_gradients = torch.transpose(total_vectorized_param_gradients, 1, 2)
    # Batched outer product over the input axis, then average over samples.
    batched_metric_tensor = torch.matmul(total_vectorized_param_gradients, transpose_total_vectorized_param_gradients)
    metric_tensor = torch.mean(batched_metric_tensor, 0)

    return metric_tensor


# pull back Lap operator (as metric tensor)
def metric_tensor_as_nabla_op(net, net_auxil, G_samples, vec, device):
    """Matrix-free product of the gradient-based metric with ``vec``.

    The scalar ``mean_x <grad_x net(x), grad_x net_auxil(x)>`` is differentiated
    with respect to the parameters of ``net_auxil``, contracted with ``vec``, and
    the result is differentiated with respect to the parameters of ``net``.

    NOTE(review): this equals "metric times vec" for ``net`` presumably only when
    ``net_auxil`` holds the same parameter values as ``net`` -- confirm against the
    caller.

    Args:
        net: network whose metric is applied.
        net_auxil: auxiliary network of the same architecture (see note above).
        G_samples: batch of inputs; the first axis is the sample axis.
        vec: 1-D tensor with one entry per parameter of ``net_auxil``.
        device: unused in this function.

    Returns:
        1-D tensor with one entry per parameter of ``net``.
    """

    # num_params is computed but not used below.
    num_params = len(torch.nn.utils.parameters_to_vector(net.parameters()))

    params_net = dict(net.named_parameters())
    params_net_auxil = dict(net_auxil.named_parameters())
    ################### computation starts here ##################################
    net.zero_grad()
    net_auxil.zero_grad()

    # Per-sample Jacobian of the network output with respect to its input.
    grad_net = vmap(jacrev(functional_call, argnums=2), in_dims=(None, None, 0))
    grad_net_x = grad_net(net, params_net, G_samples)
    grad_net_auxil_x =grad_net(net_auxil, params_net_auxil, G_samples)
    ave_sqr_grad_net = torch.sum(grad_net_x * grad_net_auxil_x) / G_samples.size()[0]  # torch.sum(grad_net_x * grad_net_auxil_x, 2).mean()

    # First backward pass: gradient with respect to net_auxil's parameters.
    # create_graph=True keeps it differentiable for the second pass.
    # NOTE(review): with allow_unused=True an unused parameter would yield None, which
    # parameters_to_vector presumably cannot flatten -- confirm all parameters are used.
    nabla_theta_ave_sqr_grad_net = torch.autograd.grad(ave_sqr_grad_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_ave_sqr_grad_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_sqr_grad_net)
    vec_dot_nabla_theta_ave_sqr_grad_net = vectorize_nabla_theta_ave_sqr_grad_net.dot(vec)
    # Second backward pass: gradient of the contraction with respect to net's parameters.
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_sqr_grad_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)

    return vectorize_metric_tensor_mult_vec


# pull back identity operator (as metric tensor)
def metric_tensor_as_op_identity_part(net, net_auxil, G_samples, vec, device):
    """Matrix-free product of the output-based (identity pull-back) metric with ``vec``.

    The scalar ``sum(net(x) * net_auxil(x)) / N`` is differentiated with respect to
    the parameters of ``net_auxil``, contracted with ``vec``, and the result is
    differentiated with respect to the parameters of ``net``.

    NOTE(review): this equals "metric times vec" for ``net`` presumably only when
    ``net_auxil`` holds the same parameter values as ``net`` -- confirm against the
    caller.

    Args:
        net: network whose metric is applied.
        net_auxil: auxiliary network of the same architecture (see note above).
        G_samples: batch of inputs; the first axis is the sample axis.
        vec: 1-D tensor with one entry per parameter of ``net_auxil``.
        device: unused in this function.

    Returns:
        1-D tensor with one entry per parameter of ``net``.
    """

    # num_params, params_net and params_net_auxil are computed but not used below.
    num_params = len(torch.nn.utils.parameters_to_vector(net.parameters()))

    params_net = dict(net.named_parameters())
    params_net_auxil = dict(net_auxil.named_parameters())
    ################### computation starts here ##################################
    net_x = net(G_samples)
    net_auxil_x = net_auxil(G_samples)
    ave_net = torch.sum(net_x * net_auxil_x) / G_samples.size()[0]
    # First backward pass: gradient with respect to net_auxil's parameters, kept
    # differentiable (create_graph=True) for the second pass.
    nabla_theta_ave_net = torch.autograd.grad(ave_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_net)
    vec_dot_nabla_theta_ave_net = vectorize_nabla_theta_net.dot(vec)
    # Second backward pass: gradient of the contraction with respect to net's parameters.
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)

    return vectorize_metric_tensor_mult_vec


# G1: \mathcal M = Id operator
# G2: \mathcal M = \nabla operator
def minres_solver_G(net, net_auxil, interior_samples, boundary_samples, RHS_vec, device, bd_lambda, max_iternum, minres_tolerance, G_type):
    """Solve ``G x = RHS_vec`` with MINRES, where ``G`` is applied matrix-free.

    Args:
        net: network whose metric ``G`` is used.
        net_auxil: auxiliary network forwarded to the metric-vector product.
        interior_samples: sample batch on which the metric is evaluated.
        boundary_samples: unused; kept for signature compatibility.
        RHS_vec: right-hand side, a 1-D tensor with one entry per parameter.
        device: device on which the metric-vector products are evaluated.
        bd_lambda: unused; kept for signature compatibility.
        max_iternum: NOTE(review): accepted but not forwarded to MINRES, so SciPy's
            default iteration limit applies -- confirm whether it should be
            passed as ``maxiter``.
        minres_tolerance: relative tolerance passed to MINRES as ``rtol``.
        G_type: "1" for the identity pull-back metric, "2" for the gradient one.

    Returns:
        ``(solution, info)``: the solution as a tensor on ``device`` and the MINRES
        status code. If the solution contains NaN, the right-hand side is returned
        instead and ``info`` is reset to 0.

    Raises:
        ValueError: if ``G_type`` is neither "1" nor "2".
    """
    # Fixed: the original only printed "Wrong G_type" and then crashed with a
    # NameError on the undefined operator; fail early with a clear error instead.
    if G_type not in ("1", "2"):
        raise ValueError("Wrong G_type: {!r} (expected \"1\" or \"2\")".format(G_type))

    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]

    if G_type == "1":
        metric_times_vector = metric_tensor_as_op_identity_part
    else:
        metric_times_vector = metric_tensor_as_nabla_op

    def G_as_operator(vec):  # input the vector v [on CPU], return vector Gv
        tensorized_vec = torch.Tensor(vec).to(device)
        Gv = metric_times_vector(net, net_auxil, interior_samples, tensorized_vec, device)
        return tensor_to_numpy(Gv)

    G_operator = LinearOperator((num_params, num_params), matvec=G_as_operator)

    np_RHS_vec = tensor_to_numpy(RHS_vec)
    sol_vec, info = scipy.sparse.linalg.minres(G_operator, np_RHS_vec, rtol=minres_tolerance)
    # Fall back to the raw right-hand side when the solve breaks down.
    if np.isnan(sol_vec).any():
        print("Got the NAN!!!")
        sol_vec = np_RHS_vec
        info = 0
    tensorized_sol_vec = torch.Tensor(sol_vec).to(device)
    return tensorized_sol_vec, info #, norm_err_vec





In [None]:
# @title update param via line search

def update_param(number_stepsizes, stepsize_0, base_stepsize, theta_0, tangent_theta, net_T_or_u, hidden_dim_net_T_or_u, net_phi, hidden_dim_net_phi, rho0_samples, rho1_samples, lossfunction, net_type, descent_or_ascent):
    """Pick the best step along ``tangent_theta`` from a geometric grid of stepsizes.

    Candidate stepsizes are ``stepsize_0 * base_stepsize**k`` for
    ``k = 0 .. number_stepsizes - 1``. For each candidate the parameters
    ``theta_0 -/+ stepsize * tangent_theta`` are loaded into a scratch network and
    the loss is evaluated.

    Args:
        number_stepsizes, stepsize_0, base_stepsize: define the stepsize grid.
        theta_0: current flattened parameter vector.
        tangent_theta: flattened update direction.
        net_T_or_u: primal network; used as-is when ``net_type == "phi"``.
        hidden_dim_net_T_or_u: hidden width of the primal scratch network.
        net_phi: dual network; used as-is when ``net_type`` is "T" or "u".
        hidden_dim_net_phi: hidden width of the dual scratch network.
        rho0_samples, rho1_samples: sample batches forwarded to ``lossfunction``.
        lossfunction: callable ``(primal_net, dual_net, rho0_samples, rho1_samples)``.
        net_type: "T", "u" or "phi" -- which network is being updated.
        descent_or_ascent: "descent" steps against ``tangent_theta`` and minimises
            the loss; any other value steps along it and maximises the loss.

    Returns:
        ``(optimal_updated_theta, optimal_stepsize, loss_along_stepsizes)`` where
        the last item holds the signed objective (loss for descent, -loss for
        ascent) for each candidate.

    Raises:
        ValueError: if ``net_type`` is not "T", "u" or "phi".

    The scratch network is built from the module-level ``network_length``, ``dim``
    and ``device``.
    """
    # Fixed: an unknown net_type used to print a message and then crash later on the
    # undefined scratch network.
    if net_type not in ("T", "u", "phi"):
        raise ValueError("Wrong net_type: {!r} (expected \"T\", \"u\" or \"phi\")".format(net_type))

    stepsize_list = stepsize_0 * base_stepsize**np.arange(number_stepsizes)
    flag = -1 if descent_or_ascent == "descent" else 1

    if net_type == "T":
        net_test = network_prim(network_length, dim, hidden_dim_net_T_or_u, dim).to(device)
    elif net_type == "u":
        net_test = network_prim(network_length, dim, hidden_dim_net_T_or_u, 1).to(device)
    else:
        net_test = network_dual(network_length, dim, hidden_dim_net_phi, 1).to(device)

    min_index = 0
    loss_along_stepsizes = []
    # Fixed: the original started from a magic bound of 1000 and always compared the
    # raw loss, so an ascent step selected the stepsize that made the loss smallest.
    # The comparison now uses the same signed objective that is recorded.
    current_min = float("inf")

    for index, stepsize in enumerate(stepsize_list):
        updated_theta = theta_0 + flag * stepsize * tangent_theta
        torch.nn.utils.vector_to_parameters(updated_theta, net_test.parameters())
        if net_type == "phi":
            loss = lossfunction(net_T_or_u, net_test, rho0_samples, rho1_samples)
        else:
            loss = lossfunction(net_test, net_phi, rho0_samples, rho1_samples)

        objective = - flag * loss.cpu().detach().numpy()
        loss_along_stepsizes.append(objective)
        if objective < current_min:
            min_index = index
            current_min = objective

    optimal_stepsize = stepsize_list[min_index]
    optimal_updated_theta = theta_0 + flag * optimal_stepsize * tangent_theta

    return optimal_updated_theta, optimal_stepsize, loss_along_stepsizes



In [None]:
# @title PD Adam solver


def PD_Adam_solver(device, save_path, L, N,
                          network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                          lr_T, lr_phi, iter, phi_iter, T_iter,
                          plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3):
    """Train the map ``net_T`` and the dual ``net_phi`` by alternating Adam steps on ``PDHG_loss3``.

    Each outer iteration draws fresh batches from ``rho0`` and ``rho1``, takes
    ``phi_iter`` Adam ascent steps on the dual network, then ``T_iter`` Adam descent
    steps on the map network.

    Args:
        device: device on which the networks are placed.
        save_path: directory in which the plots are written.
        L: half-width of the window used in the quiver plots.
        N: batch size per outer iteration.
        network_length: number of linear layers in both networks.
        hidden_dimension_net_T, hidden_dimension_net_phi: hidden widths.
        flag_init: if True, re-initialise both networks from a standard normal.
        lr_T, lr_phi: Adam learning rates.
        iter: number of outer iterations.
        phi_iter, T_iter: inner steps per outer iteration.
        plot_period, print_period: periods (in outer iterations) for plotting and
            for recording/printing the error.
        N_plot: unused; kept for signature compatibility.
        chosen_dim_0..chosen_dim_3: the two coordinate planes, (0, 1) and (2, 3), to plot.

    Side effects:
        Seeds torch with 50. Writes plots into ``save_path``. Writes
        ``netT_PD_Adam.pt``, ``netphi_PD_Adam.pt``, ``l2error_list`` (pickle) and
        ``Plot of the log l2 error.pdf`` into the current working directory.
    """

    torch.manual_seed(50)

    # initialize nets
    net_T = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
    optim_T = torch.optim.Adam(net_T.parameters(), lr=lr_T)
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
    optim_phi = torch.optim.Adam(net_phi.parameters(), lr=lr_phi)
    if flag_init == True:
        net_T.initialization()
        net_phi.initialization()

    # initial err (evaluation only, so no autograd graph is needed)
    with torch.no_grad():
        error_0 = L2_error_T(net_T, 1800)
    print("initial error = {}".format(error_0.cpu().detach().numpy()))

    loss = PDHG_loss3

    rho0_samples = rho0(N)
    plot_T_with_grad_ureal(0, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
    plot_pushfwded_samples(0, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)
    plot_T_with_grad_ureal(0, rho0_samples, L, chosen_dim_2, chosen_dim_3, net_T, save_path)
    plot_pushfwded_samples(0, net_T, rho0_samples, chosen_dim_2, chosen_dim_3, save_path=save_path)

    l2error_list = []
    ######################################################### PDHG iterations START HERE ###################################################################################################################
    for t in range(iter):

        rho0_samples = rho0(N)
        rho1_samples = rho1(N)
        ############################# update phi_\eta #####################################
        # Ascent on the loss is implemented as descent on its negative.
        # The loss is rebuilt on every step, so the graph does not need to be
        # retained after backward (the original passed retain_graph=True).
        for inner_iter in range(phi_iter):
            optim_phi.zero_grad()
            lossa = - loss(net_T, net_phi,  rho0_samples, rho1_samples)
            lossa.backward()
            optim_phi.step()

        ######################## update theta ##################################
        for inner_iter in range(T_iter):
            optim_T.zero_grad()
            lossb = loss(net_T, net_phi, rho0_samples, rho1_samples)
            lossb.backward()
            optim_T.step()

        ################ plot ##################
        if (t+1) % plot_period == 0:
            plot_T_with_grad_ureal(t+1, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
            plot_pushfwded_samples(t+1, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)
            plot_T_with_grad_ureal(t+1, rho0_samples, L, chosen_dim_2, chosen_dim_3, net_T, save_path)
            plot_pushfwded_samples(t+1, net_T, rho0_samples, chosen_dim_2, chosen_dim_3, save_path=save_path)

        ##############
        if (t+1) % print_period == 0:
            print("Iteration: {}, ".format(t))
            with torch.no_grad():
                L2error = L2_error_T(net_T, 1800)
            l2error_list.append(L2error.cpu().detach())
            print("L2 error = {}".format(L2error))


    ######################################################### PDHG iterations END HERE #################################################################################################################
    # save the models
    # As in the original, models go to the current working directory, not to
    # `save_path`; a separate local is used so the parameter is no longer overwritten.
    model_dir = os.getcwd()
    torch.save(net_T.state_dict(), os.path.join(model_dir, 'netT_PD_Adam.pt'))
    torch.save(net_phi.state_dict(), os.path.join(model_dir, 'netphi_PD_Adam.pt'))

    # write down the error
    with open('l2error_list', 'wb') as file1:
        pickle.dump(l2error_list, file1)

    # plot the error decay
    # Fixed: the original relied on a bare `log` leaked in by `from numpy import *`;
    # convert explicitly and use np.log10.
    log10_errors = np.log10(np.asarray([float(err) for err in l2error_list]))
    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(range(0, len(l2error_list)), log10_errors)
    plt.title("plot of log_10(L2 error)")
    fig_plot.savefig("Plot of the log l2 error"+'.pdf')
    plt.show()
    plt.close()



In [None]:
# @title apply Primal-Dual Adam solver
# Driver cell: configure and run the Primal-Dual Adam solver.
# save_path is captured BEFORE the chdir calls below, so the plots (which are written
# to save_path) go to the directory that was current when this cell started, while
# the solver writes models, the pickled error list and the error plot to the
# working directory set below.
save_path = os.getcwd()


# In Google Colab, change to working directory
os.chdir("/content/drive/MyDrive/NPDG_codes_collection/OT_Gaussian")
# Verify
print("Current working directory:", os.getcwd())

# create folder for NPDHG
os.makedirs('PD Adam', exist_ok=True)
os.chdir('PD Adam')

# L: half-width of the quiver-plot window; N: batch size per outer iteration.
L = 3.0
N = 1000

# NOTE: `iter` shadows the Python builtin of the same name for the rest of the
# notebook. The trailing comments hold alternative values left by the author.
iter = 1 # 20000
phi_iter = 1
T_iter = 1
lr_T = 0.5 * 1e-4
lr_phi = 0.5 * 1e-4

# network_length is also read as a module-level global by update_param, so it
# must stay defined under this name.
network_length = 4  #  2
hidden_dimension_net_T = 80  #  50  # 100  # 600  # 40
hidden_dimension_net_phi = 80  #  50  # 150  # 600  # 50
flag_init = False

plot_period = 1 # 1500
print_period = 1 # 200
N_plot = 100
# Two coordinate planes are plotted: (chosen_dim_0, chosen_dim_1) and
# (chosen_dim_2, chosen_dim_3).
chosen_dim_0 = 0
chosen_dim_1 = 1
chosen_dim_2 = 3
chosen_dim_3 = 4

PD_Adam_solver( device, save_path, L, N,
                network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                lr_T, lr_phi, iter, phi_iter, T_iter,
                plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3)



In [None]:
# @title NPDHG solver
import pickle


def NPDHG_solver(device, save_path, L, N,
                minres_max_iter, minres_tol,
                network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                iter, phi_iter, T_iter, omega,
                plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3,
                number_stepsizes, base_stepsize,
                precond_type,  # "MpMd_id" or "Mp_id_Md_nabla"
                stepsize_0=0.2,
                adaptive_or_fixed_stepsize="fixed", tau_T = 0.5 * 1e-3, tau_phi = 0.9 * 1e-3):
    """Train the networks net_T and net_phi with preconditioned primal-dual steps.

    Every outer iteration draws fresh samples with ``rho0(N)`` / ``rho1(N)``,
    snapshots ``net_phi`` into ``net_phi0``, performs ``phi_iter`` ascent steps
    on ``net_phi`` (loss ``PDHG_loss3``) and then ``T_iter`` descent steps on
    ``net_T`` (loss ``PDHG_loss4_with_extrapolation`` with ``net_phi0`` and
    ``omega``). Each raw gradient is passed through ``minres_solver_G`` before
    it is applied. The network input/output size is the module-level ``dim``.

    Args:
        device: device the networks are moved to; also forwarded to
            ``minres_solver_G``.
        save_path: forwarded to the plotting helpers only. Checkpoints, pickles
            and the error plots are written to the current working directory.
        L: forwarded to ``plot_T_with_grad_ureal``.
        N: sample count requested from ``rho0`` / ``rho1``.
        minres_max_iter, minres_tol: forwarded to ``minres_solver_G``.
        network_length, hidden_dimension_net_T, hidden_dimension_net_phi:
            constructor arguments of ``network_prim`` / ``network_dual``.
        flag_init: if truthy, call ``initialization()`` on both networks.
        iter: number of outer iterations.
        phi_iter, T_iter: inner steps per outer iteration for ``net_phi`` and
            ``net_T`` respectively.
        omega: forwarded to ``PDHG_loss4_with_extrapolation``.
        plot_period: plots are produced when ``(t + 1) % plot_period == 0``.
        print_period: the L2 error is printed when
            ``(t + 1) % print_period == 0``.
        N_plot: not used by this function; kept for call compatibility.
        chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3: two pairs of
            coordinate indices handed to the plotting helpers.
        number_stepsizes, base_stepsize, stepsize_0: forwarded to
            ``update_param``; only used in ``"adaptive"`` mode.
        precond_type: ``"MpMd_id"`` selects G-type ``"1"`` for both networks;
            ``"Mp_id_Md_nabla"`` selects ``"1"`` for ``net_T`` and ``"2"`` for
            ``net_phi``.
        adaptive_or_fixed_stepsize: ``"fixed"`` uses ``tau_T`` / ``tau_phi``
            as given; ``"adaptive"`` delegates the update to ``update_param``.
        tau_T, tau_phi: step sizes for the fixed mode. In adaptive mode they
            are overwritten by the values returned from ``update_param``.

    Raises:
        ValueError: if ``precond_type`` or ``adaptive_or_fixed_stepsize`` is
            not one of the supported values.

    Side effects:
        Writes ``netT_NPDHG.pt``, ``netphi_NPDHG.pt``, the pickles
        ``l2error_list`` and ``comptime`` and two PDF plots into the current
        working directory.
    """
    # Validate the string options before any work is done. An unsupported
    # precond_type used to be only printed, which left G_T_type / G_phi_type
    # undefined and crashed later with a NameError inside the training loop.
    preconditioner_codes = {
        "MpMd_id": ("1", "1"),
        "Mp_id_Md_nabla": ("1", "2"),
    }
    if precond_type not in preconditioner_codes:
        raise ValueError(
            "precond_type must be one of {}, got {!r}".format(
                sorted(preconditioner_codes), precond_type))
    G_T_type, G_phi_type = preconditioner_codes[precond_type]

    if adaptive_or_fixed_stepsize not in ("adaptive", "fixed"):
        raise ValueError("adaptive_or_fixed_stepsize must be 'adaptive' or 'fixed'")
    use_adaptive_stepsize = adaptive_or_fixed_stepsize == "adaptive"

    torch.manual_seed(50)

    # initialize nets
    net_T = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
    if flag_init:
        net_T.initialization()
        net_phi.initialization()

    # initial err
    error_0 = L2_error_T(net_T, 2000)
    print(error_0)

    loss = PDHG_loss3

    # plots of the untrained map (iteration 0)
    rho0_samples = rho0(N)
    plot_T_with_grad_ureal(0, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
    plot_T_with_grad_ureal(0, rho0_samples, L, chosen_dim_2, chosen_dim_3, net_T, save_path)
    plot_pushfwded_samples(0, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)
    plot_pushfwded_samples(0, net_T, rho0_samples, chosen_dim_2, chosen_dim_3, save_path=save_path)

    comptime = []        # cumulative optimisation time after each iteration
    total_time = 0.0
    l2error_list = []    # L2 error after each iteration
    ######################################################### PDHG iterations START HERE ###################################################################################################################
    for t in range(iter):

        # perf_counter is the monotonic clock intended for measuring durations
        t_0 = time.perf_counter()

        rho0_samples = rho0(N)
        rho1_samples = rho1(N)
        # snapshot of net_phi before its update; used by the extrapolated loss
        net_phi0 = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
        net_phi0.load_state_dict(net_phi.state_dict())

        net_T.zero_grad()
        net_phi.zero_grad()
        ############################# update phi_\eta #####################################
        for inner_iter in range(phi_iter):
            ############### compute G(\eta)^{-1} \nabla_\eta loss() #########################
            lossa = loss(net_T, net_phi, rho0_samples, rho1_samples)
            nabla_eta_loss = torch.autograd.grad(lossa, net_phi.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_eta_loss = torch.nn.utils.parameters_to_vector(nabla_eta_loss)

            # copy net_phi for G(\eta) computation
            net_phi_auxil = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
            net_phi_auxil.load_state_dict(net_phi.state_dict())

            # compute G(\eta)^{-1} \nabla_\eta loss()
            G_inv_nabla_eta_loss, _ = minres_solver_G(net_phi, net_phi_auxil, rho1_samples, None, vectorized_nabla_eta_loss, device, None, minres_max_iter, minres_tol, G_phi_type)

            # update \eta (ascent direction)
            original_eta = torch.nn.utils.parameters_to_vector(net_phi.parameters())
            if use_adaptive_stepsize:
                updated_eta, tau_phi, _ = update_param(number_stepsizes, stepsize_0, base_stepsize, original_eta, G_inv_nabla_eta_loss, net_T, hidden_dimension_net_T, net_phi, hidden_dimension_net_phi, rho0_samples, rho1_samples, loss, "phi", "ascent")
            else:
                updated_eta = original_eta + tau_phi * G_inv_nabla_eta_loss
            torch.nn.utils.vector_to_parameters(updated_eta, net_phi.parameters())

        ######################## update theta ##################################
        for inner_iter in range(T_iter):
            # compute G(\theta)^{-1} \nabla_\theta loss()
            lossb = PDHG_loss4_with_extrapolation(net_T, net_phi, net_phi0, rho0_samples, rho1_samples, omega)
            nabla_theta_loss = torch.autograd.grad(lossb, net_T.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_theta_loss = torch.nn.utils.parameters_to_vector(nabla_theta_loss)

            # copy net_T for  G(\theta) computation
            net_T_auxil = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
            net_T_auxil.load_state_dict(net_T.state_dict())

            # compute G(\theta)^{-1} \nabla_\theta loss()
            G_inv_nabla_theta_loss, _ = minres_solver_G(net_T, net_T_auxil, rho0_samples, None, vectorized_nabla_theta_loss, device, None, minres_max_iter, minres_tol, G_T_type)

            ############# update theta (descent direction) ####################
            original_theta = torch.nn.utils.parameters_to_vector(net_T.parameters())
            if use_adaptive_stepsize:
                updated_theta, tau_T, _ = update_param(number_stepsizes, stepsize_0, base_stepsize, original_theta, G_inv_nabla_theta_loss, net_T, hidden_dimension_net_T, net_phi, hidden_dimension_net_phi, rho0_samples, rho1_samples, loss, "T", "descent")
            else:
                updated_theta = original_theta - tau_T * G_inv_nabla_theta_loss
            torch.nn.utils.vector_to_parameters(updated_theta, net_T.parameters())

        # only the optimisation above is timed; plotting and error evaluation
        # below are excluded
        total_time += time.perf_counter() - t_0
        comptime.append(total_time)

        ################ plot ##################
        if (t+1) % plot_period == 0:
            plot_T_with_grad_ureal(t+1, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
            plot_pushfwded_samples(t+1, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)
            plot_T_with_grad_ureal(t+1, rho0_samples, L, chosen_dim_2, chosen_dim_3, net_T, save_path)
            plot_pushfwded_samples(t+1, net_T, rho0_samples, chosen_dim_2, chosen_dim_3, save_path=save_path)

        ############## record the error after every iteration
        L2error = L2_error_T(net_T, 1800)
        l2error_list.append(L2error.cpu().detach())
        if (t+1) % print_period == 0:
            print("Iteration: {}, ".format(t))
            print("L2 error = {}".format(L2error))


    ######################################################### PDHG iterations END HERE #################################################################################################################
    # save the models into the current working directory (not into save_path)
    output_dir = os.getcwd()
    torch.save(net_T.state_dict(), os.path.join(output_dir, 'netT_NPDHG.pt'))
    torch.save(net_phi.state_dict(), os.path.join(output_dir, 'netphi_NPDHG.pt'))

    # write down the error and the cumulative computation time
    with open('l2error_list', 'wb') as file1:
        pickle.dump(l2error_list, file1)
    with open('comptime', 'wb') as f:
        pickle.dump(comptime, f)

    # base-10 logarithm of the recorded errors, shared by both plots
    log10_l2_error = np.log10(l2error_list)

    # plot the error decay against the iteration index
    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(range(0, len(l2error_list)), log10_l2_error)
    plt.title("plot of log_10(L2 error)")
    fig_plot.savefig("Plot of the log l2 error"+'.pdf')
    plt.show()
    plt.close()

    # plot the error decay against the cumulative computation time
    fig_plot = plt.figure(figsize=(20, 20))
    plt.plot(comptime, log10_l2_error)
    plt.title("log_10(L2 error) vs. computation time (seconds)\n", fontsize=40)
    plt.xlabel("computation time (seconds)", fontsize=30)
    plt.ylabel("log_10(L2 error)", fontsize=30)
    fig_plot.savefig("Plot of the log l2 error vs comp time "+'.pdf')
    plt.show()
    plt.close()



In [None]:
# @title apply NPDHG solver
# Driver cell: sets the hyper-parameters as notebook globals and launches
# NPDHG_solver from inside a dedicated 'NPDHG precondition' output folder.

# NOTE(review): save_path is captured BEFORE the os.chdir calls below, so it
# holds whatever directory was current when this cell started. NPDHG_solver
# forwards it to the plotting helpers, while it writes checkpoints, pickles and
# error plots to the current working directory ('NPDHG precondition').
# Confirm the split is intended.
save_path = os.getcwd()


# In Google Colab, change to working directory
os.chdir("/content/drive/MyDrive/NPDG_codes_collection/OT_Gaussian")
# Verify
print("Current working directory:", os.getcwd())

# create (if missing) the output folder for the NPDHG run and make it the
# current working directory
os.makedirs('NPDHG precondition', exist_ok=True)
os.chdir('NPDHG precondition')


# L is forwarded to the plotting helper; N is the sample count requested from
# rho0 / rho1 in every iteration
L = 3
N = 2000

# forwarded to minres_solver_G
minres_max_iter = 1000
minres_tol = 1e-3

# network architecture settings
network_length = 4
hidden_dimension_net_T = 80
hidden_dimension_net_phi = 80
flag_init = False

# NOTE(review): the global name `iter` shadows Python's builtin iter() for
# every later cell of the notebook.
iter =  12000
phi_iter = 1
T_iter = 1
omega = 1.0

# plotting / logging settings (N_plot is accepted but not used by NPDHG_solver)
plot_period =  2000
print_period =  100
N_plot = 1000
# two pairs of coordinate indices, (0, 1) and (3, 4)
chosen_dim_0 = 0
chosen_dim_1 = 1
chosen_dim_2 = 3
chosen_dim_3 = 4

# only consumed by NPDHG_solver's "adaptive" step-size mode; with "fixed"
# (used below) they are passed through but have no effect
number_stepsizes = 50
base_stepsize = 0.8

# supported values: "MpMd_id" and "Mp_id_Md_nabla"
# precond_type = "MpMd_id"
precond_type = "Mp_id_Md_nabla"

# Positional arguments must follow the order of NPDHG_solver's signature.
# tau_T / tau_phi override the function defaults (0.5e-3 / 0.9e-3).
NPDHG_solver(device, save_path, L, N,
            minres_max_iter, minres_tol,
            network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
            iter, phi_iter, T_iter, omega,
            plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3,
            number_stepsizes, base_stepsize,
            precond_type,
            adaptive_or_fixed_stepsize="fixed", tau_T = 0.5 * 1e-1, tau_phi = 0.9 * 1e-1
            )

