Relaxed entropy dual with normal feed forward nn

In [4]:
import sys
sys.path.append('code')

In [5]:
import ot
import torch
def optimal_transport_plan(ground_truth: torch.Tensor, approximation: torch.Tensor) -> torch.Tensor:
    """
    Computes the Wasserstein distance between two sets of points.

    Args:
        ground_truth (torch.Tensor): The ground truth points.
        approximation (torch.Tensor): The approximation points.

    Returns:
        float: The Wasserstein distance between the two sets of points.
    """
    return ot.solve_sample(X_a=ground_truth, X_b=approximation).plan

In [6]:
import torch
torch.manual_seed(0)
from tqdm import trange

from data import create_joint_x_y
from torch import nn

device_and_dtype_specifications = {
    "dtype": torch.float64,
    "device": torch.device("cpu")
}

dataset_size = 100000
X, Y = create_joint_x_y(dataset_size)
batch_size = 256
n, d = Y.shape
m = n

monge_map_network = nn.Sequential(
    nn.Linear(d + X.shape[1], 20),
    nn.Softplus(),
    nn.Linear(20, 20),
    nn.Softplus(),
    nn.Linear(20, d)
)

c_optimality_regularizer = 0.1
jaccobian_regularizer = 0.1

monge_map_network.to(**device_and_dtype_specifications)
optimizer = torch.optim.Adam(monge_map_network.parameters(), 0.01)
X_tensor = torch.tensor(X, **device_and_dtype_specifications)
Y_tensor = torch.tensor(Y, **device_and_dtype_specifications)
progress_bar = trange(1, 10**4, desc="Training")
    
for epoch in progress_bar:
    yindexes = torch.randint(0, dataset_size, (batch_size,))
    entropy_indexes = torch.randint(0, dataset_size, (256,))

    X_batch = X_tensor[yindexes]
    Y_batch = Y_tensor[yindexes]
    U_batch = torch.randn_like(Y_batch)
    Y_pushforward = Y_batch - monge_map_network(torch.cat([X_batch, U_batch], dim=1))
    U_batch_Y_pushforward_pairwise_distance = torch.cdist(U_batch, Y_pushforward)**2
    Y_batch_Y_pushforward_pairwise_distance = torch.cdist(Y_batch, Y_pushforward)**2

    fitting_transport_plan = optimal_transport_plan(U_batch, Y_pushforward)
    fitting_cost = torch.sum(fitting_transport_plan * U_batch_Y_pushforward_pairwise_distance)

    c_optimality_transport_plan = optimal_transport_plan(Y_batch, Y_pushforward)
    c_optimality_cost = (
        torch.mean(
            torch.norm(Y_batch - Y_pushforward, dim=1)**2
        ) - torch.sum(
            c_optimality_transport_plan *
            Y_batch_Y_pushforward_pairwise_distance
        )
    )

    jvp_vector, vjp_vector = torch.randn_like(U_batch), torch.randn_like(U_batch)
    _, monge_map_network_jvp = torch.autograd.functional.jvp(
        lambda x: monge_map_network(torch.cat([X_batch, x], dim=1)),
        Y_batch, jvp_vector
    )

    _, monge_map_network_vjp = torch.autograd.functional.vjp(
        lambda x: monge_map_network(torch.cat([X_batch, x], dim=1)),
        Y_batch, vjp_vector
    )

    jaccobian_cost = torch.mean(torch.norm(monge_map_network_jvp - monge_map_network_vjp, dim=1))

    monge_map_network.zero_grad()
    monge_map_objective = fitting_cost + c_optimality_regularizer * c_optimality_cost + jaccobian_regularizer * jaccobian_cost
    monge_map_objective.backward()
    optimizer.step()

    progress_bar.set_description(f"Epoch: {epoch}, fitting_cost: {fitting_cost:.3f}, c_optimality_cost: {c_optimality_cost:.3f}, jaccobian_cost: {jaccobian_cost:.3f}")

Epoch: 9999, fitting_cost: 2.247, c_optimality_cost: 5.094, jaccobian_cost: 0.136: 100%|██████████| 9999/9999 [01:41<00:00, 98.31it/s] 


In [None]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
from data import create_conditional_x

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), subplot_kw={'projection': '3d'})
fig.suptitle('Separated 3D Plots', fontsize=16)

ax1.set_title('Conditional Scatter Data (y_x_gt)')
ax1.set_xlabel('Axis 0')
ax1.set_ylabel('Axis 1')
ax1.set_zlabel('x_ value')

for x_ in range(50, 250, 10):
    x = np.array([x_ / 100])[:, None]
    _, y_x_gt = create_conditional_x(n_points=100, x_value=x_/100)
    z_scatter = np.full(y_x_gt.shape[0], x)
    ax1.scatter(y_x_gt[:, 0], y_x_gt[:, 1], z_scatter, color='blue', marker='o', s=30, alpha=0.2)
    x_tensor = torch.tensor([x_ / 100], **device_and_dtype_specifications)[:, None].repeat(repeats=(100, 1))
    y_x_gt_tensor = torch.tensor(y_x_gt, **device_and_dtype_specifications)

    pushforward_of_y = y_x_gt_tensor - monge_map_network(torch.cat([x_tensor, y_x_gt_tensor], dim=1)).detach().cpu().numpy()
    ax2.scatter(pushforward_of_y[:, 0], pushforward_of_y[:, 1], color='red', marker='o', s=30, alpha=0.2)
    
ax1.view_init(elev=-55, azim=154, roll=-83)

ax2.set_title('Contour Lines')
ax2.set_xlabel('Axis 0')
ax2.set_ylabel('Axis 1')
ax2.set_zlabel('x_ value')

# loop_start_value = 50
# for x_ in range(loop_start_value, 250, 10):

#     x = torch.tensor([x_ / 100], **device_and_dtype_specifications)[:, None]
#     x = x.repeat(repeats=(100, 1))

#     colors = ['red', 'purple', 'green', 'orange']
#     radii = [0.1, 0.5, 1., 1.5]
#     for contour_radius, color in zip(radii, colors):
#         pi_tensor = torch.linspace(-torch.pi, torch.pi, 100)
#         u_tensor = torch.stack([
#             contour_radius * torch.cos(pi_tensor),
#             contour_radius * torch.sin(pi_tensor),
#         ], dim=1)

#         z_line = x.detach().cpu().numpy()
#         label = f'Radius {contour_radius}' if x_ == loop_start_value else ""
#         ax1.plot(pushforward_of_u[:, 0], pushforward_of_u[:, 1], z_line, color=color, linewidth=2.5, label=label)
#         ax2.plot(pushforward_of_u[:, 0], pushforward_of_u[:, 1], z_line, color=color, linewidth=2.5, label=label)

ax2.view_init(elev=-55, azim=154, roll=-83)
ax2.legend()

plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make room for suptitle
plt.show()