Relaxed entropy dual with normal feed forward nn

In [1]:
import torch
from torch import nn

class BetaNetwork(torch.nn.Module):
    """Input Convex Neural Network for beta potential estimation"""

    def __init__(self, input_dimension: int, hidden_dimension: int, num_hidden_layers: int, output_dimension: int):
        super().__init__()

        Wzs = []
        Wzs.append(nn.Linear(input_dimension, hidden_dimension))
        for _ in range(num_hidden_layers - 1):
            Wzs.append(torch.nn.Linear(hidden_dimension, hidden_dimension, bias=False))
        Wzs.append(torch.nn.Linear(hidden_dimension, 1, bias=False))
        self.Wzs = torch.nn.ModuleList(Wzs)

        Wxs = []
        for _ in range(num_hidden_layers - 1):
            Wxs.append(nn.Linear(input_dimension, hidden_dimension))
        Wxs.append(nn.Linear(input_dimension, output_dimension, bias=False))
        self.Wxs = torch.nn.ModuleList(Wxs)
        self.act = nn.Softplus()

    def forward(self, x):
        z = self.act(self.Wzs[0](x))
        for Wz, Wx in zip(self.Wzs[1:-1], self.Wxs[:-1]):
            z = self.act(Wz(z) + Wx(x))
        return self.Wzs[-1](z) + self.Wxs[-1](x)

In [2]:
def torch_sphere_uniform(n, d, **kwargs):
    """Generate n points inside the d-dimensional sphere."""
    random_vectors = torch.randn(n, d, **kwargs)
    vectors_norms = torch.norm(random_vectors, dim=1, keepdim=True)
    radius = torch.pow(torch.rand(n, 1, **kwargs), 1. / d)
    return radius * random_vectors / vectors_norms

In [3]:
from data_utils import create_joint_x_y
import numpy as np

num_points_to_generate = 1000
X, Y = create_joint_x_y(num_points_to_generate)

n, d = Y.shape
m = n

nu = np.ones((n, 1)) / n
mu = np.ones((m, 1)) / m

phi_network = nn.Sequential(
    nn.Linear(d + X.shape[1], 100),
    nn.Softplus(),
    nn.Linear(100, 1)
)
psi_network = nn.Sequential(
    nn.Linear(d + X.shape[1], 100),
    nn.Softplus(),
    nn.Linear(100, 1),
)

In [4]:
import torch
torch.manual_seed(0)

device_and_dtype_specifications = dict(dtype=torch.float64, device=torch.device("cpu"))
epsilon = 0.01
num_epochs = 10

phi_network.to(**device_and_dtype_specifications)
psi_network.to(**device_and_dtype_specifications)

phi_network_optimizer = torch.optim.Adam([dict(params=phi_network.parameters())], lr=0.01)
psi_network_optimizer = torch.optim.Adam([dict(params=psi_network.parameters())], lr=0.01)

X, Y = create_joint_x_y(num_points_to_generate)
X_tensor = torch.tensor(X, **device_and_dtype_specifications)
Y_tensor = torch.tensor(Y, **device_and_dtype_specifications)
# U_tensor = torch_sphere_uniform(num_points_to_generate, Y.shape[1], **device_and_dtype_specifications)
dataset_size = num_points_to_generate
batch_size = 256


def estimate_entropy_dual(X_tensor, Y_tensor, U_tensor, phi_net, psi_net, k=5, epsilon=0.1, use_log=True):
        """
        Estimate the dual objective term for entropy estimation.

        This function implements the core calculation based on nearest neighbors and learned
        potential functions phi and psi. It offers an option to oversample some x's to better approximate P(Y|X)

        Args:
        X_tensor (torch.Tensor): The input tensor for x, with shape [n, p].
        Y_tensor (torch.Tensor): The input tensor for y, with shape [n, q].
        U_tensor (torch.Tensor): The tensor of oversampled variables u, with shape [m, q].
        phi_net (nn.Module): The neural network representing the potential function phi(u, x).
        psi_net (nn.Module): The neural network representing the potential function psi(x, y).
        k (int, optional): The number of nearest neighbors to use. Defaults to 5.
        epsilon (float, optional): A small positive constant for the calculation. Defaults to 0.1.

        Returns:
        torch.Tensor: A scalar tensor representing the estimated dual value.
        """
        # Get dimensions from input tensors
        n, _ = X_tensor.shape
        m, _ = U_tensor.shape

        dists = torch.cdist(X_tensor, X_tensor, p=2) # Shape: [n, n] [i, j] = ||x_i - x_j||^2
        _, topk_indices = torch.topk(dists, k, dim=1, largest=False)
        neighbor_indices = topk_indices[:, :]  # Shape: [n, k]
        Y_neighbors = Y_tensor[neighbor_indices]  # Shape: [n, k, q] [i, k] = y_i^k

        U_expanded = U_tensor.unsqueeze(1).expand(-1, n, -1)  # Shape: [m, n, q] [i, :, :] = u_i
        X_expanded_for_U = X_tensor.unsqueeze(0).expand(m, -1, -1)  # Shape: [m, n, p] [:, i, :] = x_i
        UX = torch.cat((X_expanded_for_U, U_expanded), dim=-1) # Shape: [m, n, q + p] [i, j] = torch.cat[u_i, x_j]

        X_expanded_for_Y = X_tensor.unsqueeze(1).expand(-1, k, -1) # Shape: [n, k, p]
        YX = torch.cat((X_expanded_for_Y, Y_neighbors), dim=-1) # Shape: [n, k, p + q] [i, j] = torch.cat[x_i, y_i^k]

        phi_vals = phi_net(UX).squeeze(-1)  # Shape: [m, n] [i, j] = phi(u_i, x_j)
        psi_vals = psi_net(YX).squeeze(-1)  # Shape: [n, k] [i, j] = psi(x_j, y_j^k)
        einsum_term = torch.einsum('mq,nkq->mnk', U_tensor, Y_neighbors) # Shape: [m, n, k]

        phi_vals_expanded = phi_vals.unsqueeze(-1)  # Shape: [m, n, 1]
        psi_vals_expanded = psi_vals.unsqueeze(0)   # Shape: [1, n, k]

        slackness = ( einsum_term - phi_vals_expanded - psi_vals_expanded )
        max_slackness = torch.max(slackness)
        exponent_val = torch.exp((slackness - max_slackness) / epsilon )
        dual_estimate = epsilon * torch.mean( exponent_val ) * torch.exp(max_slackness / epsilon)
        return dual_estimate


for epoch_idx in range(1, num_epochs):

        phi_network.zero_grad()
        psi_network.zero_grad()

        yindexes = torch.randint(0, dataset_size, (batch_size,))
        entropy_indexes = torch.randint(0, dataset_size, (16,))

        X_batch = X_tensor[yindexes]
        Y_batch = Y_tensor[yindexes]
        U_batch = torch.randn(
                batch_size, Y_batch.shape[1],
                **device_and_dtype_specifications
        )

        phi = phi_network(torch.cat([X_batch, U_batch], dim=1))
        psi = psi_network(torch.cat([X_batch, Y_batch], dim=1))

        entropy = estimate_entropy_dual(
                X_tensor=X_tensor[entropy_indexes],
                Y_tensor=Y_tensor[entropy_indexes],
                U_tensor=U_batch,
                phi_net=phi_network,
                psi_net=psi_network,
                k=1,
                epsilon=epsilon,
                use_log=True
        )
        objective = torch.mean(phi) + torch.mean(psi) + entropy

        objective.backward()
        phi_network_optimizer.step()
        psi_network_optimizer.step()
        print(objective.item(), epoch_idx)

_ = phi_network.eval()
_ = psi_network.eval()

inf 1
nan 2
nan 3
nan 4
nan 5
nan 6
nan 7
nan 8
nan 9
