In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so that the
# project folder used by the next cell is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change the working directory to the project folder; the solvers below write
# their checkpoints into os.getcwd().
# NOTE(review): the recorded output of this cell shows a different folder
# (".../NPDG_Github/...") than the command — confirm which path is intended.
cd /content/drive/MyDrive/NPDG/OTMixGaussian

/content/drive/MyDrive/NPDG_Github/OTMixGaussian


This notebook solves a 50-dimensional optimal transport (OT) problem from a standard Gaussian to a mixture of unequally weighted Gaussians using

*   NPDG method
*   Primal-Dual Adam method

See Section 5.5.3 for a more detailed description.


In [None]:
# @title import
import math
import random
import scipy

import matplotlib
matplotlib.use('agg')

from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import numpy as np

from torch.func import grad, hessian, vmap
from torch.func import jacrev
from torch.func import functional_call
import time


import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pickle

from numpy import *
from torch import Tensor
from torch.nn import Parameter
from torch.optim.lr_scheduler import ExponentialLR

import os
import argparse



In [None]:
# @title Set up dimension

# Dimension of the sample space. Read as a module-level global: it is the
# default `d` of the samplers rho0/rho1 and the input/output width of the nets.
dim = 50


In [None]:
# @title check CUDA availability
# Displays whether a CUDA device is visible. The samplers and driver cells
# below call .cuda() / use 'cuda:0' unconditionally, so this should show True.
torch.cuda.is_available()

True

In [None]:
# @title model (MLP)


class network_prim(nn.Module):
    """Fully connected network with a single shared PReLU activation.

    The stack consists of one input layer, ``network_length - 2`` hidden
    layers and a bias-free output layer. The same ``nn.PReLU`` module (one
    learnable slope) is applied before every layer except the first.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()
        self.network_length = network_length

        # Layers are created in input -> hidden -> output order.
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(network_length - 2)]
        layers.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(layers)

        self.prelu = nn.PReLU()

    def initialization(self):
        """Re-draw weights from U(-0.25, 0.25) and biases from U(-0.3, 0.3)."""
        with torch.no_grad():
            for layer in self.linears:
                layer.weight.uniform_(-0.25, 0.25)
                if layer.bias is not None:
                    layer.bias.uniform_(-0.3, 0.3)

    def forward(self, x):
        """Apply the first layer, then PReLU followed by each remaining layer."""
        last = self.network_length - 1
        x = self.linears[0](x)
        for layer in self.linears[1:last]:
            x = layer(self.prelu(x))
        return self.linears[last](self.prelu(x))


class network_dual(nn.Module):
    """Fully connected network with a single shared PReLU activation.

    Same layout as the primal network: one input layer,
    ``network_length - 2`` hidden layers and a bias-free output layer, with
    one shared ``nn.PReLU`` applied before every layer except the first.
    Only ``initialization`` differs (standard normal draws).
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()
        self.network_length = network_length

        # Layers are created in input -> hidden -> output order.
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(network_length - 2)]
        layers.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(layers)

        self.prelu = nn.PReLU()

    def initialization(self):
        """Re-draw every weight and bias from a standard normal distribution."""
        with torch.no_grad():
            for layer in self.linears:
                layer.weight.normal_()
                if layer.bias is not None:
                    layer.bias.normal_()

    def forward(self, x):
        """Apply the first layer, then PReLU followed by each remaining layer."""
        last = self.network_length - 1
        x = self.linears[0](x)
        for layer in self.linears[1:last]:
            x = layer(self.prelu(x))
        return self.linears[last](self.prelu(x))






In [None]:
# @title 2D vector field plotting & 2D pushforwarded sample


def plot_grad_u_with_grad_ureal(iter, samples, l, chosen_dim_0, chosen_dim_1, net_u, save_path):
    """Save a quiver plot of the gradient of ``net_u`` against ``grad_u_real``.

    Both vector fields are evaluated at ``samples``, multiplied by 3 and
    projected onto the (chosen_dim_0, chosen_dim_1) plane; the network
    gradient is drawn in green and the reference in red. The axes are clipped
    to [-l, l] and the figure is written as a PDF into ``save_path``.

    NOTE(review): ``grad_u_real`` is not defined in the visible part of this
    notebook; it must exist when this function is called.
    """
    arrow_scale = 3.
    learned_field = (arrow_scale * gradient_nn(net_u, samples)).cpu().detach().numpy()
    reference_field = (arrow_scale * grad_u_real(samples)).cpu().detach().numpy()

    points = samples.cpu().detach().numpy()
    base_x, base_y = points[:, chosen_dim_0], points[:, chosen_dim_1]

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')
    plt.xlim([-l, l])
    plt.ylim([-l, l])

    for field, colour in ((learned_field, 'green'), (reference_field, 'red')):
        plt.quiver(base_x, base_y, field[:, chosen_dim_0], field[:, chosen_dim_1],
                   scale=None, scale_units='inches', color=colour, width=0.002)

    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    pdf_name = '(Iteration={}) gradient of net_u with gradient of u_real at sample points (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf'
    plt.savefig(os.path.join(save_path, pdf_name))
    plt.close()


def plot_T_with_grad_ureal(iter, samples, l, chosen_dim_0, chosen_dim_1, net_T, save_path):
    """Save a quiver plot of the map ``net_T`` against ``grad_u_real``.

    ``net_T(samples)`` (green) and 3 * ``grad_u_real(samples)`` (red) are drawn
    as arrows anchored at the samples, projected onto the
    (chosen_dim_0, chosen_dim_1) plane. The axes are clipped to [-l, l] and
    the figure is written as a PDF into ``save_path``.

    NOTE(review): ``grad_u_real`` is not defined in the visible part of this
    notebook; it must exist when this function is called.
    """
    mapped_field = net_T(samples).cpu().detach().numpy()
    reference_field = (3. * grad_u_real(samples)).cpu().detach().numpy()

    points = samples.cpu().detach().numpy()
    base_x, base_y = points[:, chosen_dim_0], points[:, chosen_dim_1]

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')
    plt.xlim([-l, l])
    plt.ylim([-l, l])

    for field, colour in ((mapped_field, 'green'), (reference_field, 'red')):
        plt.quiver(base_x, base_y, field[:, chosen_dim_0], field[:, chosen_dim_1],
                   scale=None, scale_units='inches', color=colour, width=0.002)

    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    pdf_name = '(Iteration={}) map T with gradient of u_real at sample points (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf'
    plt.savefig(os.path.join(save_path, pdf_name))
    plt.close()



def plot_T(iter, samples, l, chosen_dim_0, chosen_dim_1, net_T, save_path):
    """Save a quiver plot of ``net_T(samples)`` anchored at the samples.

    Only the (chosen_dim_0, chosen_dim_1) components are drawn. The axes are
    clipped to [-l, l] and the figure is written as a PDF into ``save_path``;
    ``iter`` only appears in the file name.
    """
    mapped_field = net_T(samples).cpu().detach().numpy()
    points = samples.cpu().detach().numpy()

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')
    plt.xlim([-l, l])
    plt.ylim([-l, l])
    plt.quiver(points[:, chosen_dim_0], points[:, chosen_dim_1],
               mapped_field[:, chosen_dim_0], mapped_field[:, chosen_dim_1],
               scale=None, scale_units='inches', color='green', width=0.002)

    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    pdf_name = '(Iteration={}) map T at sample points (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf'
    plt.savefig(os.path.join(save_path, pdf_name))
    plt.close()


def plot_pushfwded_samples(iter, l, net_T, samples, chosen_dim_0, chosen_dim_1, save_path):
    """Save a scatter plot of ``net_T(samples)`` with guide circles.

    The pushed-forward points are drawn in blue on the
    (chosen_dim_0, chosen_dim_1) plane. Eight faint gray circles of radius
    3 * 0.4 are drawn around centres equally spaced on a circle of radius 3.
    The axes are clipped to [-l, l] and the figure is written as a PDF into
    ``save_path``.
    """
    N = samples.size()[0]
    pushed = net_T(samples).cpu().detach().numpy()
    # The target draw is never plotted; the call is kept because rho1 consumes
    # random numbers from the global torch generator.
    target_samples = rho1(N).cpu()

    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')

    # Guide circles, drawn one short segment at a time.
    num_gaussians = 8
    sigma_mixedgauss = 0.4
    rad = 3.0 * sigma_mixedgauss
    Rad = 3.0
    polar_angle = np.linspace(0, 2 * np.pi, 500)
    for j in range(num_gaussians):
        x = rad * np.cos(polar_angle) + Rad * np.cos(2 * np.pi / num_gaussians * j)
        y = rad * np.sin(polar_angle) + Rad * np.sin(2 * np.pi / num_gaussians * j)
        for x_from, x_to, y_from, y_to in zip(x[:-1], x[1:], y[:-1], y[1:]):
            plt.plot([x_from, x_to], [y_from, y_to], color='gray', alpha=0.2)

    plt.scatter(pushed[:, chosen_dim_0], pushed[:, chosen_dim_1], color='blue', s=15)
    plt.xlim([ - l, l ])
    plt.ylim([ - l, l ])
    plt.xlabel("{} component".format(chosen_dim_0))
    plt.ylabel("{} component".format(chosen_dim_1))

    pdf_name = '(Iteration={}) original samples (green, ~rho0) pushforwarded samples (blue) with target samples (red ~rho1) (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf'
    plt.savefig(os.path.join(save_path, pdf_name))
    plt.close()


def plot_pushfwded_samples_kde_heatmap(iter, l, net_T, samples, chosen_dim_0, chosen_dim_1, save_path, num_gaussians=8, sigma_mixedgauss = 0.4, Rad=3.0 ):
    """Save a KDE heat map of ``net_T(samples)`` with reference geometry.

    Drawn on the (chosen_dim_0, chosen_dim_1) plane, all in red on top of the
    seaborn KDE of the pushed-forward samples:

    * ``num_gaussians`` circles of radius ``3 * sigma_mixedgauss`` whose
      centres are equally spaced on the circle of radius ``Rad``;
    * straight lines from ``2 * num_gaussians`` probe points (placed on the
      circle of radius ``0.2 * Rad``, zero in all other coordinates) to their
      images under ``net_T``;
    * the circle of radius ``Rad`` itself.

    Args:
        iter: iteration number, used only in the file name.
        l: accepted for signature compatibility; not used here.
        net_T: callable mapping a batch of points to a batch of points.
        samples: 2-D tensor of points pushed through ``net_T``.
        chosen_dim_0, chosen_dim_1: coordinate indices to plot.
        save_path: directory that receives the PDF.
        num_gaussians, sigma_mixedgauss, Rad: geometry of the guide circles.
    """
    N = samples.size()[0]

    T_x = net_T(samples).cpu().detach().numpy()

    # Not plotted; the call is kept so that this function consumes the same
    # random numbers from the global torch generator as before.
    target_samples = rho1(N).cpu()

    # Probe points live on the same device / in the same dimension as `samples`
    # (previously hard-coded to CUDA and the module-level `dim`).
    sample_start = torch.zeros(2 * num_gaussians, samples.size()[1],
                               dtype=samples.dtype, device=samples.device)
    for j in range(2 * num_gaussians):
        angle = np.pi / num_gaussians * j
        sample_start[j, chosen_dim_0] = 0.2 * Rad * np.cos(angle)
        sample_start[j, chosen_dim_1] = 0.2 * Rad * np.sin(angle)
    sample_start_np = sample_start.cpu().detach().numpy()
    sample_end = net_T(sample_start).cpu().detach().numpy()

    rad = sigma_mixedgauss * 3.0
    figure(num=None, figsize=(34, 34), dpi=80, facecolor='w', edgecolor='k')

    polar_angle = np.linspace(0, 2 * np.pi, 500)

    # Circles around the mixture centres; one polyline per circle instead of
    # hundreds of single-segment artists.
    for j in range(num_gaussians):
        centre_angle = 2 * np.pi / num_gaussians * j
        plt.plot(rad * np.cos(polar_angle) + Rad * np.cos(centre_angle),
                 rad * np.sin(polar_angle) + Rad * np.sin(centre_angle),
                 color='red', alpha=1)

    # Straight line from every probe point to its image under net_T.
    for j in range(2 * num_gaussians):
        start_coord = sample_start_np[j, [chosen_dim_0, chosen_dim_1]]
        end_coord = sample_end[j, [chosen_dim_0, chosen_dim_1]]
        plt.plot([start_coord[0], end_coord[0]], [start_coord[1], end_coord[1]],
                 color='red', alpha=1)

    # Circle of radius Rad. The previous loop iterated over `j` but indexed
    # with a stale `jj`, so only a single segment was ever drawn.
    plt.plot(Rad * np.cos(polar_angle), Rad * np.sin(polar_angle), color='red', alpha=1)

    d = {'chosen dim0': T_x[:, chosen_dim_0], 'chosen dim1': T_x[:, chosen_dim_1]}
    df = pd.DataFrame(d)
    sns.kdeplot(
      data=df ,  x="chosen dim0", y="chosen dim1",
      fill=True, thresh=0, levels=100, cmap = "vlag" , #cmap="mako",
    )

    filename = os.path.join(save_path, '(Iteration={}) KDE heatmap original samples (green, ~rho0) pushforwarded samples (blue) with target samples (red ~rho1) (on {}-{} plane)'.format(iter, chosen_dim_0, chosen_dim_1) + '.pdf')
    plt.savefig(filename)
    plt.close()






In [None]:
# @title define rho0 (sampler) , rho1 (sampler) , real solution


def rho0(n, d=dim):
    """Draw ``n`` standard-normal samples of dimension ``d`` and move them to CUDA."""
    return torch.randn(n, d).cuda()


# Radius of the circle carrying the mixture centres; also read by the driver
# cells to size the plotting window.
R = 3.0
def rho1(N, chosen_dim_0 = 10, chosen_dim_1 = 20, d=dim):
    """Sample from a mixture of 8 isotropic Gaussians with unequal weights.

    The component means lie on a circle of radius 3 in the
    (chosen_dim_0, chosen_dim_1) coordinate plane and are zero elsewhere; every
    component has standard deviation 0.4. Even-indexed components receive
    ``int(0.2 * n)`` samples and odd-indexed ones ``int(0.8 * n)``, where
    ``n = int(N / 4)``, so the number of returned rows can be slightly smaller
    than ``N`` because of truncation.

    Args:
        N: requested (approximate) number of samples.
        chosen_dim_0, chosen_dim_1: coordinates spanning the plane of the means.
        d: dimension of the samples.

    Returns:
        A CUDA tensor with ``d`` columns, ordered component by component.
    """
    sigma_mixedgauss = 0.4
    R = 3.0
    num_gaussians = 8

    # Means are allocated with `d` (not the global `dim`) so that a
    # non-default dimension works.
    mu = torch.zeros(num_gaussians, 1, d)
    for i in range(num_gaussians):
        mu[i, 0, chosen_dim_0] = R * torch.cos(torch.tensor(i/num_gaussians * 2 * np.pi))
        mu[i, 0, chosen_dim_1] = R * torch.sin(torch.tensor(i/num_gaussians * 2 * np.pi))

    lambda1 =  0.2    # 0.5  # 0.2
    lambda2 = 0.8    # 0.5  # 0.8

    # Samples per (even, odd) pair of components; builtin int truncates exactly
    # like the numpy int32 cast used before.
    n = int(N / (num_gaussians/2))

    chunks = [torch.empty((0, d))]
    for j in range(num_gaussians):
        weight = lambda1 if j % 2 == 0 else lambda2
        num_samples = int(weight * n)
        z = torch.randn(num_samples, d)
        chunks.append(sigma_mixedgauss * z + mu[j, :, :])

    return torch.cat(chunks, 0).cuda()



In [None]:
# @title L2 error


def gradient_nn(network, x):
    """Return the gradient of ``network`` with respect to its input at ``x``.

    ``x`` is detached from any upstream graph and re-marked as requiring
    gradients. The network output is contracted with a tensor of ones, so for
    a scalar-valued network row ``i`` of the result is the input gradient at
    ``x[i]``. The result keeps its graph (``create_graph=True``) so it can be
    differentiated again with respect to the network parameters.

    Works on whatever device ``x`` and ``network`` live on; the previous
    version hard-coded CUDA for the ones tensor and used the deprecated
    ``autograd.Variable`` wrapper.

    NOTE: this helper is defined several times in this notebook; keep the
    copies in sync.
    """
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    gradients_x = autograd.grad(outputs=output_value, inputs=input_variable,
                                grad_outputs=torch.ones_like(output_value),
                                create_graph=True, retain_graph=True)[0]

    return gradients_x




In [None]:
# @title PDHG loss


def gradient_nn(network, x):
    """Return the gradient of ``network`` with respect to its input at ``x``.

    ``x`` is detached from any upstream graph and re-marked as requiring
    gradients. The network output is contracted with a tensor of ones, so for
    a scalar-valued network row ``i`` of the result is the input gradient at
    ``x[i]``. The result keeps its graph (``create_graph=True``) so it can be
    differentiated again with respect to the network parameters.

    Works on whatever device ``x`` and ``network`` live on; the previous
    version hard-coded CUDA for the ones tensor and used the deprecated
    ``autograd.Variable`` wrapper.

    NOTE: this helper is defined several times in this notebook; keep the
    copies in sync.
    """
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    gradients_x = autograd.grad(outputs=output_value, inputs=input_variable,
                                grad_outputs=torch.ones_like(output_value),
                                create_graph=True, retain_graph=True)[0]

    return gradients_x


# PDHG loss with no boundary error
def PDHG_loss1(net_u, net_phi, rho0_samples, rho1_samples):
    """Mean of ``net_phi(grad net_u(rho0)) - net_phi(rho1)``.

    The two ``net_phi`` evaluations are subtracted element-wise before
    averaging, so both sample batches must yield broadcast-compatible outputs.
    """
    pushed_samples = gradient_nn(net_u, rho0_samples)
    pointwise_gap = net_phi(pushed_samples) - net_phi(rho1_samples)
    return pointwise_gap.mean()


# derived from the Monge problem
def PDHG_loss2(net_u, net_phi, rho0_samples, rho1_samples):
    """Saddle objective with the transport map given by the gradient of ``net_u``.

    Returns ``-E[x . grad u(x)] + E[phi(grad u(x))] - E[phi(y)]`` with
    ``x ~ rho0_samples`` and ``y ~ rho1_samples`` (sample means).
    """
    pushed_samples = gradient_nn(net_u, rho0_samples)
    correlation = torch.sum(rho0_samples * pushed_samples, -1).mean()
    return - correlation + net_phi(pushed_samples).mean() - net_phi(rho1_samples).mean()


def PDHG_loss3(net_T, net_phi, rho0_samples, rho1_samples):
    """Saddle objective with an explicit transport map ``net_T``.

    Returns ``-E[x . T(x)] + E[phi(T(x))] - E[phi(y)]`` with
    ``x ~ rho0_samples`` and ``y ~ rho1_samples`` (sample means).
    """
    pushed_samples = net_T(rho0_samples)
    correlation = torch.sum(rho0_samples * pushed_samples, -1).mean()
    return - correlation + net_phi(pushed_samples).mean() - net_phi(rho1_samples).mean()


def PDHG_loss4(net_T, net_phi, rho0_samples, rho1_samples):
    """Saddle objective with a quadratic transport cost.

    Returns ``E[|T(x) - x|^2] / 2 + E[phi(T(x))] - E[phi(y)]`` with
    ``x ~ rho0_samples`` and ``y ~ rho1_samples`` (sample means).
    """
    pushed_samples = net_T(rho0_samples)
    displacement = pushed_samples - rho0_samples
    transport_cost = 1/2*torch.sum(displacement * displacement, -1).mean()
    return transport_cost + net_phi(pushed_samples).mean() - net_phi(rho1_samples).mean()


def PDHG_loss_extplt_phi(net_T, net_phi, net_phi0, omega, rho0_samples, rho1_samples):
    """Quadratic-cost saddle objective evaluated at an extrapolated dual.

    The dual potential is replaced by ``(1 + omega) * phi - omega * phi0``,
    where ``net_phi0`` is a second dual network; the extrapolation is applied
    to the sample means of both networks. The returned value is
    ``E[|T(x) - x|^2] / 2 + E[phi~(T(x))] - E[phi~(y)]``.
    """
    pushed_samples = net_T(rho0_samples)
    displacement = pushed_samples - rho0_samples
    transport_cost = 1/2*torch.sum(displacement * displacement, -1).mean()

    def extrapolated_mean(points):
        # (1 + omega) * E[phi] - omega * E[phi0] on the given batch
        current = net_phi(points).mean()
        previous = net_phi0(points).mean()
        return (1+omega) * current - omega * previous

    on_pushed = extrapolated_mean(pushed_samples)
    on_target = extrapolated_mean(rho1_samples)
    return transport_cost + on_pushed - on_target



In [None]:
# @title G(\theta) as a linear opt another
import scipy
from scipy.sparse.linalg import LinearOperator
import torch.autograd.functional as functional
# from torch.func import hessian, vmap, jacrev, functional_call


def gradient_nn(network, x):
    """Return the gradient of ``network`` with respect to its input at ``x``.

    ``x`` is detached from any upstream graph and re-marked as requiring
    gradients. The network output is contracted with a tensor of ones, so for
    a scalar-valued network row ``i`` of the result is the input gradient at
    ``x[i]``. The result keeps its graph (``create_graph=True``) so it can be
    differentiated again with respect to the network parameters.

    Works on whatever device ``x`` and ``network`` live on; the previous
    version hard-coded CUDA for the ones tensor and used the deprecated
    ``autograd.Variable`` wrapper.

    NOTE: this helper is defined several times in this notebook; keep the
    copies in sync.
    """
    input_variable = x.detach().requires_grad_(True)
    output_value = network(input_variable)
    gradients_x = autograd.grad(outputs=output_value, inputs=input_variable,
                                grad_outputs=torch.ones_like(output_value),
                                create_graph=True, retain_graph=True)[0]
    return gradients_x


def tensor_to_numpy(u):
    """Return ``u`` as a NumPy array, detached from autograd and copied to host.

    The previous version compared ``u.device`` (a ``torch.device``) with the
    string ``"cpu"`` to pick between two branches that produce the same array;
    a single device-agnostic expression replaces both.
    """
    return u.detach().cpu().numpy()


def form_metric_tensor(net, G_samples, device):
    """Assemble a dense metric matrix from mixed parameter/input derivatives.

    For every sample ``x`` the mixed derivative of ``net`` with respect to its
    parameters and then its input is computed with nested ``jacrev`` calls.
    Per parameter tensor the result is flattened to ``(N, -1, input_dim)``,
    the pieces are concatenated along the middle axis into ``J`` and the
    function returns the sample mean of ``J @ J^T``.

    Args:
        net: ``nn.Module`` evaluated through ``functional_call``.
        G_samples: tensor of shape ``(N, input_dim)``; ``vmap`` runs over axis 0.
        device: accepted for signature compatibility; not used.

    Returns:
        A square tensor whose side is the size of the flattened middle axis.
    """
    N = G_samples.size()[0]
    # The last axis of every mixed derivative is the input axis. The previous
    # code read an undefined global `input_dim` here, which raised NameError.
    input_dim = G_samples.size()[-1]

    Jacobi_NN = jacrev(functional_call, argnums = 1)
    D_param_D_x_NN = vmap(jacrev(Jacobi_NN, argnums = 2), in_dims = (None, None, 0))
    D_param_D_x_net_on_x = D_param_D_x_NN(net, dict(net.named_parameters()), G_samples)

    # reshape (not view): the batched Jacobians are not guaranteed contiguous.
    list_of_vectorized_param_gradients = [
        param_gradients.reshape(N, -1, input_dim)
        for param_gradients in dict(D_param_D_x_net_on_x).values()
    ]
    total_vectorized_param_gradients = torch.cat(list_of_vectorized_param_gradients, 1)
    transpose_total_vectorized_param_gradients = torch.transpose(total_vectorized_param_gradients, 1, 2)
    batched_metric_tensor = torch.matmul(total_vectorized_param_gradients, transpose_total_vectorized_param_gradients)

    metric_tensor = torch.mean(batched_metric_tensor, 0)

    return metric_tensor


# \mathcal M = grad operator
def metric_tensor_as_nabla_op(net, net_auxil, G_samples, vec, device):
    """Matrix-free product of the gradient-based metric with ``vec``.

    The scalar ``(1/N) * sum(grad_x net * grad_x net_auxil)`` over
    ``G_samples`` is differentiated with respect to the parameters of
    ``net_auxil``, contracted with ``vec`` and differentiated again with
    respect to the parameters of ``net``. The flattened result is returned
    with its autograd graph attached.

    ``device`` is accepted for signature compatibility and is not used.
    Accumulated ``.grad`` buffers of both networks are cleared first.
    """
    batch_size = G_samples.size()[0]

    net.zero_grad()
    net_auxil.zero_grad()

    grad_x_main = gradient_nn(net, G_samples)
    grad_x_copy = gradient_nn(net_auxil, G_samples)
    paired_energy = torch.sum(grad_x_main * grad_x_copy) / batch_size

    # d(energy)/d(auxiliary parameters), kept differentiable
    d_energy_d_auxil = torch.autograd.grad(paired_energy, net_auxil.parameters(), grad_outputs=None,
                                           allow_unused=True, retain_graph=True, create_graph=True)
    directional = torch.nn.utils.parameters_to_vector(d_energy_d_auxil).dot(vec)

    # differentiate the contraction with respect to the main parameters
    product_per_param = torch.autograd.grad(directional, net.parameters(), grad_outputs=None,
                                            allow_unused=True, retain_graph=True, create_graph=True)
    return torch.nn.utils.parameters_to_vector(product_per_param)


# \mathcal M = identity operator
def metric_tensor_as_op_identity_part(net, net_auxil, G_samples, vec, device):
    """Matrix-free product of the output-based metric with ``vec``.

    The scalar ``(1/N) * sum(net(x) * net_auxil(x))`` over ``G_samples`` is
    differentiated with respect to the parameters of ``net_auxil``, contracted
    with ``vec`` and differentiated again with respect to the parameters of
    ``net``. The flattened result is returned with its autograd graph
    attached.

    ``device`` is accepted for signature compatibility and is not used.
    """
    batch_size = G_samples.size()[0]

    out_main = net(G_samples)
    out_copy = net_auxil(G_samples)
    paired_energy = torch.sum(out_main * out_copy) / batch_size

    # d(energy)/d(auxiliary parameters), kept differentiable
    d_energy_d_auxil = torch.autograd.grad(paired_energy, net_auxil.parameters(), grad_outputs=None,
                                           allow_unused=True, retain_graph=True, create_graph=True)
    directional = torch.nn.utils.parameters_to_vector(d_energy_d_auxil).dot(vec)

    # differentiate the contraction with respect to the main parameters
    product_per_param = torch.autograd.grad(directional, net.parameters(), grad_outputs=None,
                                            allow_unused=True, retain_graph=True, create_graph=True)
    return torch.nn.utils.parameters_to_vector(product_per_param)


# G1: \mathcal M = Id operator
# G2: \mathcal M = \nabla (gradient) operator

def minres_solver_G(net, net_auxil, interior_samples, boundary_samples, RHS_vec, device, bd_lambda, max_iternum, minres_tolerance, G_type):
    """Solve ``G x = RHS_vec`` with MINRES, where ``G`` is applied matrix-free.

    Args:
        net, net_auxil: network and its copy handed to the metric operator.
        interior_samples: sample batch on which the metric is evaluated.
        boundary_samples, bd_lambda: accepted for signature compatibility; not used.
        RHS_vec: right-hand side as a 1-D tensor.
        device: device on which the matrix-vector products are evaluated.
        max_iternum: MINRES iteration cap.
        minres_tolerance: MINRES relative tolerance.
        G_type: "1" selects ``metric_tensor_as_op_identity_part``,
            "2" selects ``metric_tensor_as_nabla_op``.

    Returns:
        ``(solution, info)`` with the solution as a tensor on ``device`` and
        ``info`` the MINRES status code. If the solution contains NaN, the
        right-hand side itself is returned with ``info = 0``.

    Raises:
        ValueError: if ``G_type`` is neither "1" nor "2". Previously this only
            printed a message and then failed on an unbound ``G_operator``.
    """
    if G_type == "1":
        apply_metric = metric_tensor_as_op_identity_part
    elif G_type == "2":
        apply_metric = metric_tensor_as_nabla_op
    else:
        raise ValueError("Wrong G_type: expected '1' or '2', got {!r}".format(G_type))

    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]

    def G_as_operator(vec):  # input the vector v [on CPU], return vector Gv
        tensorized_vec = torch.Tensor(vec).to(device)
        Gv = apply_metric(net, net_auxil, interior_samples, tensorized_vec, device)
        return tensor_to_numpy(Gv)

    G_operator = LinearOperator((num_params, num_params), matvec=G_as_operator)

    np_RHS_vec = tensor_to_numpy(RHS_vec)
    sol_vec, info = scipy.sparse.linalg.minres(G_operator, np_RHS_vec, rtol=minres_tolerance, maxiter=max_iternum)
    if np.isnan(sol_vec).any():
        # Fall back to the un-preconditioned direction when the solve broke down.
        print("Got the NAN!!!")
        sol_vec = np_RHS_vec
        info = 0
    tensorized_sol_vec = torch.Tensor(sol_vec).to(device)

    return tensorized_sol_vec, info #, norm_err_vec








In [None]:
# @title update param

def update_param(number_stepsizes, stepsize_0, base_stepsize, theta_0, tangent_theta, net_u, net_phi, rho0_samples, rho1_samples, lossfunction, net_type, descent_or_ascent):
    """Pick the best step size along ``tangent_theta`` from a geometric grid.

    The candidates are ``stepsize_0 * base_stepsize ** k`` for
    ``k = 0 .. number_stepsizes - 1``. For each candidate the parameters
    ``theta_0 +/- stepsize * tangent_theta`` are loaded into a scratch copy of
    the network being updated and ``lossfunction`` is evaluated.

    Args:
        number_stepsizes: number of candidates (must be >= 1).
        stepsize_0, base_stepsize: first candidate and geometric ratio.
        theta_0: current flattened parameters of the network being updated.
        tangent_theta: search direction, same shape as ``theta_0``.
        net_u: primal network (passed first to ``lossfunction``).
        net_phi: dual network (passed second to ``lossfunction``).
        rho0_samples, rho1_samples: batches forwarded to ``lossfunction``.
        lossfunction: callable ``(primal, dual, rho0_samples, rho1_samples)``.
        net_type: "u" or "T" to update the primal network, "phi" for the dual.
        descent_or_ascent: "descent" minimises the loss along
            ``-tangent_theta``; any other value maximises it along
            ``+tangent_theta``.

    Returns:
        ``(optimal_updated_theta, optimal_stepsize, loss_along_stepsizes)``
        where the list holds ``loss`` (descent) or ``-loss`` (ascent) for every
        candidate, i.e. the quantity that is minimised.

    Raises:
        ValueError: on an unknown ``net_type`` or ``number_stepsizes < 1``.
    """
    import copy  # local import: only needed for the scratch network

    if net_type not in ("u", "T", "phi"):
        raise ValueError("Wrong net_type: expected 'u', 'T' or 'phi', got {!r}".format(net_type))
    if number_stepsizes < 1:
        raise ValueError("number_stepsizes must be at least 1")

    flag = -1 if descent_or_ascent == "descent" else 1
    stepsize_list = stepsize_0 * base_stepsize**np.arange(number_stepsizes)

    # Scratch copy with exactly the architecture of the network being updated
    # (the previous code rebuilt it from module-level globals and did not
    # handle net_type == "T" at all).
    updating_dual = net_type == "phi"
    net_test = copy.deepcopy(net_phi if updating_dual else net_u)

    loss_along_stepsizes = []
    min_index = 0
    current_min = float("inf")

    for index, stepsize in enumerate(stepsize_list):
        updated_theta = theta_0 + flag * float(stepsize) * tangent_theta
        torch.nn.utils.vector_to_parameters(updated_theta, net_test.parameters())
        if updating_dual:
            loss = lossfunction(net_u, net_test, rho0_samples, rho1_samples)
        else:
            loss = lossfunction(net_test, net_phi, rho0_samples, rho1_samples)

        # Minimise `loss` for descent and `-loss` for ascent; the previous
        # code compared the raw loss in both cases.
        objective = - flag * float(loss.detach())
        loss_along_stepsizes.append( - flag * loss.cpu().detach().numpy() )
        if objective < current_min:
            min_index = index
            current_min = objective

    optimal_stepsize = stepsize_list[min_index]
    optimal_updated_theta = theta_0 + flag * float(optimal_stepsize) * tangent_theta

    return optimal_updated_theta, optimal_stepsize, loss_along_stepsizes



In [None]:
# @title PD Adam solver use T instead of grad u


def PD_Adam_solver_with_T( device, save_path, L, N,
                    network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                    lr_T, lr_phi, iter, phi_iter, T_iter,
                    plot_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3):
    """Alternating Adam ascent/descent on ``PDHG_loss3`` for the map ``net_T``.

    Every outer iteration draws fresh batches from ``rho0`` and ``rho1``, takes
    ``phi_iter`` Adam steps on ``net_phi`` for the negated loss and then
    ``T_iter`` Adam steps on ``net_T`` for the loss. Every ``plot_period``
    iterations, and once at the end, both state dicts are saved; the periodic
    saves are accompanied by the quiver and scatter plots.

    Args:
        device: device the networks are moved to.
        save_path: directory for checkpoints and figures. It is now honoured
            throughout; previously it was overwritten by ``os.getcwd()``.
        L: half-width of the plotting window.
        N: batch size per iteration.
        network_length, hidden_dimension_net_T, hidden_dimension_net_phi:
            architecture of the two networks.
        flag_init: if true, call each network's ``initialization()``.
        lr_T, lr_phi: Adam learning rates.
        iter: number of outer iterations (shadows the builtin inside this
            function; name kept for keyword-call compatibility).
        phi_iter, T_iter: inner steps per outer iteration.
        plot_period: checkpoint / plot interval in iterations.
        N_plot, chosen_dim_2, chosen_dim_3: accepted but not used.
        chosen_dim_0, chosen_dim_1: coordinates shown in the plots.
    """
    torch.manual_seed(50)

    # initialize nets
    net_T = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
    optim_T = torch.optim.Adam(net_T.parameters(), lr=lr_T)
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
    optim_phi = torch.optim.Adam(net_phi.parameters(), lr=lr_phi)
    if flag_init:
        net_T.initialization()
        net_phi.initialization()

    loss = PDHG_loss3

    rho0_samples = rho0(N)
    plot_T(0, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
    plot_pushfwded_samples(0, L, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)

    ######################################################### PDHG iterations START HERE ###################################################################################################################
    for t in range(iter):

        print("Iteration:")
        print(t)

        rho0_samples = rho0(N)
        rho1_samples = rho1(N)

        ############################# update phi_\eta #####################################
        # Ascent on the loss == descent on its negative. The graph is rebuilt
        # by every loss() call, so it does not need to be retained.
        for inner_iter in range(phi_iter):
            optim_phi.zero_grad()
            lossa = - loss(net_T, net_phi,  rho0_samples, rho1_samples)
            lossa.backward()
            optim_phi.step()

        ######################## update theta ##################################
        for inner_iter in range(T_iter):
            optim_T.zero_grad()
            lossb = loss(net_T, net_phi, rho0_samples, rho1_samples)
            lossb.backward()
            optim_T.step()

        if (t+1) % plot_period == 0:
            ############### save the inter models #####################################
            filename = os.path.join(save_path, 'Iter_{}_netT.pt'.format(t))
            torch.save(net_T.state_dict(), filename)

            filename = os.path.join(save_path, 'Iter_{}_netphi.pt'.format(t))
            torch.save(net_phi.state_dict(), filename)

            ################ plot ##################
            plot_T(t+1, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
            plot_pushfwded_samples(t+1, L, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)

        ##############
        print("Iter: {}, ".format(t))

    ######################################################### PDHG iterations END HERE #################################################################################################################
    # save the models
    filename = os.path.join(save_path, 'netT.pt')
    torch.save(net_T.state_dict(), filename)

    filename = os.path.join(save_path, 'netphi.pt')
    torch.save(net_phi.state_dict(), filename)





In [None]:
# @title apply PD_Adam solver

# Configuration and launch of the primal-dual Adam baseline.
# NOTE: several names below (device, network_length, ...) are module-level and
# are also read as globals by update_param, so keep the names stable.
device = torch.device('cuda:0')

# Checkpoints and figures go to the current working directory.
save_path = os.getcwd()

# Half-width of the plotting window (used as the xlim/ylim bound).
L = R + 4    # 3.0
# Batch size drawn from rho0 / rho1 at every iteration.
N = 2000

# NOTE: `iter` shadows the Python builtin of the same name for the rest of the
# notebook session.
iter = 370000
# Inner Adam steps per outer iteration for the dual (phi) and primal (T) nets.
phi_iter = 1
T_iter = 1
lr_T =  1e-5
lr_phi = 1e-5

# Architecture shared by both networks.
network_length = 6
hidden_dimension_net_T = 120
hidden_dimension_net_phi = 120
# False keeps PyTorch's default layer initialisation.
flag_init = False

# Checkpoint / plot interval and the coordinate plane shown in the figures.
plot_period = 5000
N_plot = 100
chosen_dim_0 = 10
chosen_dim_1 = 20

# Passed through to the solver, which does not use them.
chosen_dim_2 = 3
chosen_dim_3 = 4

PD_Adam_solver_with_T( device, save_path, L, N,
                    network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                    lr_T, lr_phi, iter, phi_iter, T_iter,
                    plot_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3
                    )



In [None]:
# @title PDHG solver use T instead of grad u
# (use T(z) z~rho0 when computing preconditioning matrix of phi)
# extrapolate in phi space
import pickle


def PDHG_solver_use_T_rho_0_samples_as_precond_extrapolate_phi(device, save_path, L, N,
                minres_max_iter, minres_tol,
                network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
                iter, phi_iter, T_iter, omega, iter0,
                plot_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3,
                number_stepsizes, base_stepsize,
                precond_type,
                stepsize_0=0.2,
                adaptive_or_fixed_stepsize="fixed", tau_T = 0.5 * 1e-3, tau_phi = 0.9 * 1e-3
                ):
    """Preconditioned primal-dual iteration for the map ``net_T`` and dual ``net_phi``.

    Every outer iteration draws fresh batches from ``rho0`` and ``rho1`` and then

    1. snapshots ``net_phi`` as ``net_phi0``;
    2. performs ``phi_iter`` ascent steps on ``PDHG_loss4`` for ``net_phi``,
       each along the MINRES solution of ``G(eta) d = grad_eta loss``. The
       dual metric is evaluated on ``rho1`` samples while ``t < iter0`` and on
       ``net_T(rho0)`` samples afterwards;
    3. performs ``T_iter`` descent steps for ``net_T`` on
       ``PDHG_loss_extplt_phi`` (dual extrapolated with ``omega`` between
       ``net_phi`` and ``net_phi0``), preconditioned the same way on ``rho0``
       samples.

    Checkpoints and plots are written every ``plot_period`` iterations and the
    final state dicts once at the end.

    Args:
        device: device the networks are moved to.
        save_path: directory for checkpoints and figures. It is now honoured
            throughout; previously it was overwritten by ``os.getcwd()``.
        L: half-width of the plotting window.
        N: batch size per iteration.
        minres_max_iter, minres_tol: MINRES iteration cap and tolerance.
        network_length, hidden_dimension_net_T, hidden_dimension_net_phi:
            architecture of the two networks.
        flag_init: if true, call ``net_T.initialization()`` (``net_phi`` keeps
            its default initialisation).
        iter: number of outer iterations (shadows the builtin inside this
            function; name kept for keyword-call compatibility).
        phi_iter, T_iter: inner steps per outer iteration.
        omega: extrapolation weight for the dual.
        iter0: iteration at which the dual metric switches sample source.
        plot_period: checkpoint / plot interval in iterations.
        N_plot, chosen_dim_2, chosen_dim_3: accepted but not used.
        chosen_dim_0, chosen_dim_1: coordinates shown in the plots.
        number_stepsizes, base_stepsize, stepsize_0: step-size grid for the
            adaptive mode.
        precond_type: "MpMd_id" (metric type "1" for both networks) or
            "Mp_id_Md_nabla" (type "1" for T, type "2" for phi).
        adaptive_or_fixed_stepsize: "fixed" uses ``tau_T`` / ``tau_phi``;
            "adaptive" selects them with ``update_param``.
        tau_T, tau_phi: fixed step sizes.

    Raises:
        ValueError: on an unknown ``precond_type`` or step-size mode.
    """
    # Fail fast on bad configuration. An unknown precond_type used to print a
    # message and then crash later on unbound G_T_type / G_phi_type.
    if precond_type == "MpMd_id":
        G_T_type = "1"
        G_phi_type = "1"
    elif precond_type == "Mp_id_Md_nabla":
        G_T_type = "1"
        G_phi_type = "2"
    else:
        raise ValueError("Wrong precond_type: expected 'MpMd_id' or 'Mp_id_Md_nabla', got {!r}".format(precond_type))
    if adaptive_or_fixed_stepsize not in ("adaptive", "fixed"):
        raise ValueError("adaptive_or_fixed_stepsize must be 'adaptive' or 'fixed'")

    torch.manual_seed(50)

    # initialize nets
    net_T = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
    if flag_init:
        net_T.initialization()

    loss = PDHG_loss4
    loss_with_extplt_phi = PDHG_loss_extplt_phi

    rho0_samples = rho0(N)

    plot_T(0, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
    plot_pushfwded_samples(0, L, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)

    ######################################################### PDHG iterations START HERE ###################################################################################################################
    for t in range(iter):

        print("Iteration:")
        print(t)

        rho0_samples = rho0(N)
        rho1_samples = rho1(N)

        net_T.zero_grad()
        net_phi.zero_grad()
        ############################# update phi_\eta #####################################
        # Snapshot of the dual before its update; used for the extrapolation.
        net_phi0 = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
        net_phi0.load_state_dict(net_phi.state_dict())
        for inner_iter in range(phi_iter):
            ############### compute G(\eta)^{-1} \nabla_\eta loss() #########################
            lossa = loss(net_T, net_phi,  rho0_samples, rho1_samples)
            nabla_eta_loss = torch.autograd.grad(lossa, net_phi.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_eta_loss = torch.nn.utils.parameters_to_vector(nabla_eta_loss)

            # T_theta(rho0_samples)
            T_rho0_samples = net_T(rho0_samples)

            # copy net_phi for G(\eta) computation
            net_phi_auxil = network_dual(network_length, dim, hidden_dimension_net_phi, 1).to(device)
            net_phi_auxil.load_state_dict(net_phi.state_dict())

            # compute G(\eta)^{-1} \nabla_\eta loss(); the metric is evaluated on
            # target samples first and on pushed-forward samples from iter0 on
            metric_samples = rho1_samples if t < iter0 else T_rho0_samples
            G_inv_nabla_eta_loss, info_phi = minres_solver_G(net_phi, net_phi_auxil, metric_samples, None, vectorized_nabla_eta_loss, device, None, minres_max_iter, minres_tol, G_phi_type)

            # update \eta (ascent)
            original_eta = torch.nn.utils.parameters_to_vector(net_phi.parameters())
            if adaptive_or_fixed_stepsize == "adaptive":
                updated_eta, tau_phi, value_along_tau_phis = update_param(number_stepsizes, stepsize_0, base_stepsize, original_eta, G_inv_nabla_eta_loss, net_T, net_phi, rho0_samples, rho1_samples, loss, "phi", "ascent")
            else:
                updated_eta = original_eta + tau_phi * G_inv_nabla_eta_loss
            torch.nn.utils.vector_to_parameters(updated_eta, net_phi.parameters())
            print("tau_phi = {}".format(tau_phi))

        ######################## update theta ##################################
        for inner_iter in range(T_iter):
            # compute G(\theta)^{-1} \nabla_\theta loss()
            lossb = loss_with_extplt_phi(net_T, net_phi, net_phi0, omega, rho0_samples, rho1_samples)
            nabla_theta_loss = torch.autograd.grad(lossb, net_T.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_theta_loss = torch.nn.utils.parameters_to_vector(nabla_theta_loss)

            # copy net_T for  G(\theta) computation
            net_T_auxil = network_prim(network_length, dim, hidden_dimension_net_T, dim).to(device)
            net_T_auxil.load_state_dict(net_T.state_dict())

            # compute G(\theta)^{-1} \nabla_\theta loss()
            G_inv_nabla_theta_loss, info_u = minres_solver_G(net_T, net_T_auxil, rho0_samples, None, vectorized_nabla_theta_loss, device, None, minres_max_iter, minres_tol, G_T_type)

            ############# update theta (descent) ####################
            original_theta = torch.nn.utils.parameters_to_vector(net_T.parameters())
            if adaptive_or_fixed_stepsize == "adaptive":
                # The primal variable descends, matching the fixed-step branch
                # below (this call previously requested "ascent").
                updated_theta, tau_T, value_along_tau_Ts = update_param(number_stepsizes, stepsize_0, base_stepsize, original_theta, G_inv_nabla_theta_loss, net_T, net_phi, rho0_samples, rho1_samples, loss, "T", "descent")
            else:
                updated_theta = original_theta - tau_T * G_inv_nabla_theta_loss
            torch.nn.utils.vector_to_parameters(updated_theta, net_T.parameters())
            print("tau_T={}".format( tau_T ))

        if (t+1) % plot_period == 0:
            ############### save the inter models #####################################
            filename = os.path.join(save_path, 'Iter_{}_netT.pt'.format(t))
            torch.save(net_T.state_dict(), filename)

            filename = os.path.join(save_path, 'Iter_{}_netphi.pt'.format(t))
            torch.save(net_phi.state_dict(), filename)

            ################ plot ##################
            plot_T(t+1, rho0_samples, L, chosen_dim_0, chosen_dim_1, net_T, save_path)
            plot_pushfwded_samples(t+1, L, net_T, rho0_samples, chosen_dim_0, chosen_dim_1, save_path=save_path)
            rho0_samples_for_plot = rho0(5000)
            plot_pushfwded_samples_kde_heatmap(t+1, L, net_T, rho0_samples_for_plot, chosen_dim_0, chosen_dim_1, save_path, num_gaussians=8, sigma_mixedgauss = 0.4, Rad=3.0)

    ######################################################### PDHG iterations END HERE #################################################################################################################
    # save the models
    filename = os.path.join(save_path, 'netT.pt')
    torch.save(net_T.state_dict(), filename)

    filename = os.path.join(save_path, 'netphi.pt')
    torch.save(net_phi.state_dict(), filename)




In [None]:
# @title apply PDHG solver


# Configuration and launch of the preconditioned primal-dual (NPDG) solver.
# NOTE: several names below (device, network_length, ...) are module-level and
# are also read as globals by update_param, so keep the names stable.
device = torch.device('cuda:0')
# Checkpoints and figures go to the current working directory.
save_path = os.getcwd()

# Half-width of the plotting window (used as the xlim/ylim bound).
L = R + 4
# Batch size drawn from rho0 / rho1 at every iteration.
N =  2000

# Iteration cap and relative tolerance of the MINRES solves.
minres_max_iter = 1000
minres_tol =  1e-4

# Architecture shared by both networks.
network_length = 6
hidden_dimension_net_T =  120
hidden_dimension_net_phi =  120
# False keeps PyTorch's default layer initialisation.
flag_init =   False

# Passed as `iter0`: below this iteration the dual metric is evaluated on rho1
# samples, afterwards on T(rho0) samples.
n0 = 1000
# NOTE: `iter` shadows the Python builtin of the same name for the rest of the
# notebook session.
iter = 20000
phi_iter = 1
T_iter = 1
# Extrapolation weight for the dual potential in the T update.
omega = 5

# Checkpoint / plot interval and the coordinate plane shown in the figures.
plot_period = 1000
N_plot = 1000
chosen_dim_0 =  10
chosen_dim_1 = 20
# Passed through to the solver, which does not use them.
chosen_dim_2 = 3
chosen_dim_3 = 4

# Step-size grid; only consulted when adaptive_or_fixed_stepsize="adaptive".
number_stepsizes = 50
base_stepsize = 0.8


# Metric type "1" for the map T and type "2" for the dual phi.
precond_type = "Mp_id_Md_nabla"


PDHG_solver_use_T_rho_0_samples_as_precond_extrapolate_phi(device, save_path, L, N,
            minres_max_iter, minres_tol,
            network_length, hidden_dimension_net_T, hidden_dimension_net_phi, flag_init,
            iter, phi_iter, T_iter, omega, n0,
            plot_period, N_plot, chosen_dim_0, chosen_dim_1, chosen_dim_2, chosen_dim_3,
            number_stepsizes, base_stepsize,
            precond_type,
            adaptive_or_fixed_stepsize="fixed", tau_T = 0.5 * 1e-2, tau_phi = 0.5 * 1e-2
            )




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
tau_phi = 0.005
tau_T=0.005
Iteration:
17316
tau_phi = 0.005
tau_T=0.005
Iteration:
17317
tau_phi = 0.005
tau_T=0.005
Iteration:
17318
tau_phi = 0.005
tau_T=0.005
Iteration:
17319
tau_phi = 0.005
tau_T=0.005
Iteration:
17320
tau_phi = 0.005
tau_T=0.005
Iteration:
17321
tau_phi = 0.005
tau_T=0.005
Iteration:
17322
tau_phi = 0.005
tau_T=0.005
Iteration:
17323
tau_phi = 0.005
tau_T=0.005
Iteration:
17324
tau_phi = 0.005
tau_T=0.005
Iteration:
17325
tau_phi = 0.005
tau_T=0.005
Iteration:
17326
tau_phi = 0.005
tau_T=0.005
Iteration:
17327
tau_phi = 0.005
tau_T=0.005
Iteration:
17328
tau_phi = 0.005
tau_T=0.005
Iteration:
17329
tau_phi = 0.005
tau_T=0.005
Iteration:
17330
tau_phi = 0.005
tau_T=0.005
Iteration:
17331
tau_phi = 0.005
tau_T=0.005
Iteration:
17332
tau_phi = 0.005
tau_T=0.005
Iteration:
17333
tau_phi = 0.005
tau_T=0.005
Iteration:
17334
tau_phi = 0.005
tau_T=0.005
Iteration:
17335
tau_phi = 0.005
tau_T=0.005
Iterati