Relaxed entropy dual with normal feed forward nn

In [1]:
import torch
from torch import nn

class BetaNetwork(torch.nn.Module):
    """Input Convex Neural Network for beta potential estimation"""

    def __init__(self, input_dimension: int, hidden_dimension: int, num_hidden_layers: int, output_dimension: int):
        super().__init__()

        Wzs = []
        Wzs.append(nn.Linear(input_dimension, hidden_dimension))
        for _ in range(num_hidden_layers - 1):
            Wzs.append(torch.nn.Linear(hidden_dimension, hidden_dimension, bias=False))
        Wzs.append(torch.nn.Linear(hidden_dimension, 1, bias=False))
        self.Wzs = torch.nn.ModuleList(Wzs)

        Wxs = []
        for _ in range(num_hidden_layers - 1):
            Wxs.append(nn.Linear(input_dimension, hidden_dimension))
        Wxs.append(nn.Linear(input_dimension, output_dimension, bias=False))
        self.Wxs = torch.nn.ModuleList(Wxs)
        self.act = nn.Softplus()

    def forward(self, x):
        z = self.act(self.Wzs[0](x))
        for Wz, Wx in zip(self.Wzs[1:-1], self.Wxs[:-1]):
            z = self.act(Wz(z) + Wx(x))
        return self.Wzs[-1](z) + self.Wxs[-1](x)

In [2]:
def torch_sphere_uniform(n, d, **kwargs):
    """Generate n points inside the d-dimensional sphere."""
    random_vectors = torch.randn(n, d, **kwargs)
    vectors_norms = torch.norm(random_vectors, dim=1, keepdim=True)
    radius = torch.pow(torch.rand(n, 1, **kwargs), 1. / d)
    return radius * random_vectors / vectors_norms

In [3]:
from data_utils import create_joint_x_y
import numpy as np

num_points_to_generate = 1000
X, Y = create_joint_x_y(num_points_to_generate)

n, d = Y.shape
m = n

nu = np.ones((n, 1)) / n
mu = np.ones((m, 1)) / m

phi_network = nn.Sequential(
    nn.Linear(d + X.shape[1], 100),
    nn.Softplus(),
    nn.Linear(100, 1)
)
psi_network = nn.Sequential(
    nn.Linear(d + X.shape[1], 100),
    nn.Softplus(),
    nn.Linear(100, 1),
)

In [4]:
import torch
torch.manual_seed(0)

device_and_dtype_specifications = dict(dtype=torch.float64, device=torch.device("cpu"))
epsilon = 1
num_epochs = 5000

phi_network.to(**device_and_dtype_specifications)
psi_network.to(**device_and_dtype_specifications)

phi_network_optimizer = torch.optim.Adam([dict(params=phi_network.parameters())], lr=0.1)
psi_network_optimizer = torch.optim.Adam([dict(params=psi_network.parameters())], lr=0.1)

X, Y = create_joint_x_y(num_points_to_generate)
X_tensor = torch.tensor(X, **device_and_dtype_specifications)
Y_tensor = torch.tensor(Y, **device_and_dtype_specifications)
# U_tensor = torch_sphere_uniform(num_points_to_generate, Y.shape[1], **device_and_dtype_specifications)
dataset_size = num_points_to_generate
batch_size = 256


def estimate_entropy_dual(X_tensor, Y_tensor, U_tensor, phi_net, psi_net, k=5, epsilon=0.1, use_log=True):
        """
        Estimate the dual objective term for entropy estimation.

        This function implements the core calculation based on nearest neighbors and learned
        potential functions phi and psi. It offers an option to oversample some x's to better approximate P(Y|X)

        Args:
        X_tensor (torch.Tensor): The input tensor for x, with shape [n, p].
        Y_tensor (torch.Tensor): The input tensor for y, with shape [n, q].
        U_tensor (torch.Tensor): The tensor of oversampled variables u, with shape [m, q].
        phi_net (nn.Module): The neural network representing the potential function phi(u, x).
        psi_net (nn.Module): The neural network representing the potential function psi(x, y).
        k (int, optional): The number of nearest neighbors to use. Defaults to 5.
        epsilon (float, optional): A small positive constant for the calculation. Defaults to 0.1.

        Returns:
        torch.Tensor: A scalar tensor representing the estimated dual value.
        """
        # Get dimensions from input tensors
        n, _ = X_tensor.shape
        m, _ = U_tensor.shape

        dists = torch.cdist(X_tensor, X_tensor, p=2) # Shape: [n, n] [i, j] = ||x_i - x_j||^2
        _, topk_indices = torch.topk(dists, k, dim=1, largest=False)
        neighbor_indices = topk_indices[:, :]  # Shape: [n, k]
        Y_neighbors = Y_tensor[neighbor_indices]  # Shape: [n, k, q] [i, k] = y_i^k

        U_expanded = U_tensor.unsqueeze(1).expand(-1, n, -1)  # Shape: [m, n, q] [i, :, :] = u_i
        X_expanded_for_U = X_tensor.unsqueeze(0).expand(m, -1, -1)  # Shape: [m, n, p] [:, i, :] = x_i
        UX = torch.cat((X_expanded_for_U, U_expanded), dim=-1) # Shape: [m, n, q + p] [i, j] = torch.cat[u_i, x_j]

        X_expanded_for_Y = X_tensor.unsqueeze(1).expand(-1, k, -1) # Shape: [n, k, p]
        YX = torch.cat((X_expanded_for_Y, Y_neighbors), dim=-1) # Shape: [n, k, p + q] [i, j] = torch.cat[x_i, y_i^k]

        phi_vals = phi_net(UX).squeeze(-1)  # Shape: [m, n] [i, j] = phi(u_i, x_j)
        psi_vals = psi_net(YX).squeeze(-1)  # Shape: [n, k] [i, j] = psi(x_j, y_j^k)
        einsum_term = torch.einsum('mq,nkq->mnk', U_tensor, Y_neighbors) # Shape: [m, n, k]

        phi_vals_expanded = phi_vals.unsqueeze(-1)  # Shape: [m, n, 1]
        psi_vals_expanded = psi_vals.unsqueeze(0)   # Shape: [1, n, k]

        slackness = ( einsum_term - phi_vals_expanded - psi_vals_expanded )
        exponent_val = torch.exp(slackness / epsilon )
        dual_estimate = epsilon * torch.mean( exponent_val )
        return dual_estimate


for epoch_idx in range(1, num_epochs):
        if epoch_idx % 1000 == 0 and epoch_idx != 0:
                epsilon = epsilon * 0.1

        phi_network.zero_grad()
        psi_network.zero_grad()

        yindexes = torch.randint(0, dataset_size, (batch_size,))
        entropy_indexes = torch.randint(0, dataset_size, (16,))

        X_batch = X_tensor[yindexes]
        Y_batch = Y_tensor[yindexes]
        U_batch = torch.randn(
                batch_size, Y_batch.shape[1],
                **device_and_dtype_specifications
        )

        phi = phi_network(torch.cat([X_batch, U_batch], dim=1))
        psi = psi_network(torch.cat([X_batch, Y_batch], dim=1))

        entropy = estimate_entropy_dual(
                X_tensor=X_tensor[entropy_indexes],
                Y_tensor=Y_tensor[entropy_indexes],
                U_tensor=U_batch,
                phi_net=phi_network,
                psi_net=psi_network,
                k=1,
                epsilon=epsilon
        )
        objective = torch.mean(phi) + torch.mean(psi) + entropy

        torch.nn.utils.clip_grad_norm_(phi_network.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(psi_network.parameters(), max_norm=1.0)

        objective.backward()
        phi_network_optimizer.step()
        psi_network_optimizer.step()
        print(objective.item(), epoch_idx, epsilon)

_ = phi_network.eval()
_ = psi_network.eval()

8.322345416888981e+16 1 1
24.4101367076671 2 1
43.02009140285212 3 1
57.40256961175455 4 1
73.13478349194828 5 1
83.54981506716113 6 1
95.47443880041733 7 1
105.41353174690667 8 1
123.50174938321227 9 1
126.27029755044184 10 1
134.56720509664498 11 1
144.0054075861118 12 1
151.68489829750777 13 1
160.12588507118738 14 1
170.46929597736096 15 1
173.52871115593712 16 1
184.40383996072808 17 1
188.35257938934626 18 1
185.41297012765426 19 1
193.9285526502174 20 1
191.59225665717034 21 1
204.72903738267956 22 1
206.0158738813497 23 1
210.6632494194676 24 1
208.2020609970005 25 1
220.96581425593672 26 1
217.04975293857365 27 1
224.62524599087328 28 1
233.67090793482672 29 1
227.2896119177924 30 1
236.07656884789776 31 1
234.33308142208446 32 1
224.5875033727727 33 1
234.17930974208733 34 1
237.81661895043055 35 1
230.36685640743647 36 1
230.67973622654128 37 1
237.25784804850684 38 1
231.404644019061 39 1
228.56097689334925 40 1
245.35260048041908 41 1
246.6703270943855 42 1
235.40827342145

In [5]:
import matplotlib.pyplot as plt

%matplotlib qt

fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')
number_of_points_to_visualize = 1000

with torch.no_grad():
        U_tensor = torch.randn(number_of_points_to_visualize, d, **device_and_dtype_specifications)
        UX_tensor = X_tensor[125:126].repeat(number_of_points_to_visualize, 1)
        potential_tensor = phi_network(torch.cat([UX_tensor, U_tensor], dim=1))

potential = potential_tensor.detach().cpu().numpy()
U = U_tensor.detach().cpu().numpy()
scatter = ax.scatter(U[:, 0], U[:, 1], potential.squeeze(), color='red', marker='o', s=30, alpha=0.6)
ax.grid(True)

ax.view_init(elev=20, azim=120)
ax.set_xlabel('u1')
ax.set_ylabel('u2')
ax.set_zlabel('phi_u')

plt.show()

In [6]:
# Change to %matplotlib qt to have interactive plots
%matplotlib qt

from data_utils import create_conditional_x

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), subplot_kw={'projection': '3d'})
fig.suptitle('Separated 3D Plots', fontsize=16)

ax1.set_title('Conditional Scatter Data (y_x_gt)')
ax1.set_xlabel('Axis 0')
ax1.set_ylabel('Axis 1')
ax1.set_zlabel('x_ value')

for x_ in range(50, 250, 10):
    x = np.array([x_ / 100])[:, None]

    # This section is now active for the first plot
    _, y_x_gt = create_conditional_x(n_points=100, x_value=x_/100)
    z_scatter = np.full(y_x_gt.shape[0], x)
    ax1.scatter(y_x_gt[:, 0], y_x_gt[:, 1], z_scatter, color='blue', marker='o', s=30, alpha=0.2)

ax1.view_init(elev=-55, azim=154, roll=-83)

ax2.set_title('Contour Lines')
ax2.set_xlabel('Axis 0')
ax2.set_ylabel('Axis 1')
ax2.set_zlabel('x_ value')

loop_start_value = 50
for x_ in range(loop_start_value, 250, 10):

    x = torch.tensor([x_ / 100], **device_and_dtype_specifications)[:, None]
    x = x.repeat(repeats=(100, 1))

    colors = ['red', 'purple', 'green', 'orange']
    radii = [0.1, 0.5, 1., 1.5]
    for contour_radius, color in zip(radii, colors):
        pi_tensor = torch.linspace(-torch.pi, torch.pi, 100)
        u_tensor = torch.stack([
            contour_radius * torch.cos(pi_tensor),
            contour_radius * torch.sin(pi_tensor),
        ], dim=1)

        u_tensor = u_tensor.to(**device_and_dtype_specifications)
        u_tensor.requires_grad = True

        potential = phi_network(torch.cat([x, u_tensor], dim=1))
        pushforward_of_u = torch.autograd.grad(potential.sum(), u_tensor)[0]

        z_line = x.detach().cpu().numpy()
        label = f'Radius {contour_radius}' if x_ == loop_start_value else ""
        ax1.plot(pushforward_of_u[:, 0], pushforward_of_u[:, 1], z_line, color=color, linewidth=2.5, label=label)
        ax2.plot(pushforward_of_u[:, 0], pushforward_of_u[:, 1], z_line, color=color, linewidth=2.5, label=label)

ax2.view_init(elev=-55, azim=154, roll=-83)
ax2.legend()

plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make room for suptitle
plt.show()