# Exploring Diverse Solutions for Underdetermined Problems

This notebook accompanies the workshop paper "Exploring Diverse Solutions for Underdetermined Problems".

The goal of this notebook is to illustrate the nearest-neighbor diversity loss on finite-dimensional vector spaces and on function spaces.
The notebook consists of 3 parts.
1) **Horseshoe**: We start with a finite-dimensional vector space example and look at some variants and properties of the nearest-neighbor and Leinster diversity losses.
2) **Flat parametric curve**: We then show how the nearest-neighbor diversity loss acts on simple parametric curves in the flat plane.
3) **Parametric curve on manifold**: Lastly, we show the nearest-neighbor diversity loss on a parametric curve on a manifold, here a sphere.

In [None]:
## GLOBAL IMPORTS
import math
from functools import partial
from tqdm import trange
import numpy as np
import torch
from torch import relu
import matplotlib.pyplot as plt

# utilities import
from utils.util import tensor_product_xz
from utils.model_defs import Net, ConditionalNet
from utils.sampling_primitives import sample_bbox, get_meshgrid_in_domain2d, get_meshgrid_in_domain3d
from utils.horse_shoe import horse_shoe_sdf, get_horse_shoe_bounds

In [None]:
## SET SEEDS
np.random.seed(42)

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## 1. Horseshoe

In this example we show the nearest-neighbor diversity loss on finite vector spaces, concretely $\mathbb{R}^2$.

We introduce a design region (or envelope) in which all points should lie.
For this example, the design region is a horseshoe, implicitly defined by a signed distance function (SDF).
An SDF defines a shape by assigning each point in the domain a scalar value whose magnitude is the distance to the boundary of the shape. To distinguish inside from outside, points inside the shape are assigned a negative sign.

More formally, let $\phi(x)$ be a signed distance function. Then the domain $\Omega$ is defined as:

$
\Omega = \{ x \in \mathbb{R}^n \mid \phi(x) < 0 \}
$

The boundary of the shape is the zero level set, where the distance to the boundary is 0:

$
\partial \Omega = \{ x \in \mathbb{R}^n \mid \phi(x) = 0 \}
$
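
To make the sign convention concrete, here is a minimal sketch in plain Python for a shape whose SDF is known in closed form: a disk of radius $r$ centered at the origin, with $\phi(x) = \|x\| - r$. The helper names `disk_sdf` and `classify` are hypothetical and are not part of the notebook's `utils` package.

```python
import math

def disk_sdf(x, y, radius=1.0):
    """Signed distance from (x, y) to a circle of the given radius centered at the origin."""
    return math.hypot(x, y) - radius

def classify(phi, tol=1e-12):
    """Map an SDF value to 'inside', 'boundary', or 'outside' using the sign convention above."""
    if abs(phi) <= tol:
        return "boundary"
    return "inside" if phi < 0 else "outside"

# One point at the center, one on the circle, one far outside
for point in [(0.0, 0.0), (1.0, 0.0), (3.0, 4.0)]:
    phi = disk_sdf(*point)
    print(point, phi, classify(phi))
```

The horseshoe SDF used below follows the same convention, only with a more involved shape.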



In [None]:
horse_shoe_bounds = get_horse_shoe_bounds()
X0_horse_shoe, X1_horse_shoe, pts = get_meshgrid_in_domain2d(horse_shoe_bounds)
sdf_horse_shoe = horse_shoe_sdf(pts)
sdf_horse_shoe = sdf_horse_shoe.reshape(X0_horse_shoe.shape).detach()

im = plt.contourf(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=50)
plt.contour(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[0], colors='k')
plt.axis('scaled')
plt.colorbar(im)
plt.show()

Next, we define the nearest-neighbor diversity loss, following Berzins et al.
This loss pushes each element of a set away from its nearest neighbor.
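
As a hand-checkable illustration (a plain-Python sketch, not the tensor implementation used in this notebook), consider four points on a line. The helpers `nearest_neighbor_distances` and `generalized_mean` are hypothetical names. Only the distance to the closest other point enters the loss, so an isolated point contributes through its single nearest neighbor rather than through all pairs.

```python
def nearest_neighbor_distances(points):
    """For each 1D point, return the distance to its closest other point."""
    return [
        min(abs(a - b) for j, b in enumerate(points) if j != i)
        for i, a in enumerate(points)
    ]

def generalized_mean(values, p):
    """Power mean (mean of v**p) ** (1/p); p = 1 gives the arithmetic mean."""
    return (sum(v ** p for v in values) / len(values)) ** (1.0 / p)

points = [0.0, 1.0, 3.0, 7.0]
nn = nearest_neighbor_distances(points)   # [1.0, 1.0, 2.0, 4.0]
loss = -generalized_mean(nn, p=0.5)       # negated: minimizing the loss grows the distances
print(nn, loss)
```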

To highlight some important aspects of this loss, we contrast it with another diversity loss, the Leinster diversity (`leinster_diversity`).
With the similarity used here, each element of a set is pushed away from all others.
We will see later that this fails to establish diversity among points in the horseshoe.
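
For intuition about the repel-all baseline, the Leinster diversity of order $q$ with similarity $Z = e^{-D}$ and uniform weights can be read as an effective number of distinct points: it equals 1 when all points coincide and approaches $n$ when all points are far apart. The plain-Python sketch below evaluates the same formula as `leinster_diversity` (without the final negation) on two tiny distance matrices; the function name `leinster_effective_number` is hypothetical.

```python
import math

def leinster_effective_number(D, q=0.5):
    """Similarity-sensitive diversity of order q for a distance matrix D (list of lists), uniform weights."""
    n = len(D)
    # (Zp)_i: average similarity exp(-d_ij) of point i to all points, itself included
    zp = [sum(math.exp(-d) for d in row) / n for row in D]
    return (sum(v ** (q - 1) for v in zp) / n) ** (1.0 / (1.0 - q))

n = 4
collapsed = [[0.0] * n for _ in range(n)]                                  # all points identical
spread = [[0.0 if i == j else 50.0 for j in range(n)] for i in range(n)]  # all points far apart
print(leinster_effective_number(collapsed))
print(leinster_effective_number(spread))
```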

In [None]:
def nearest_neighbor_diversity(D, p=.5):
    """
    Takes the pairwise dissimilarity matrix D and computes a diversity loss by repelling the nearest neighbor.
    The power p should be <= 1 to make the loss concave.
    """
    # Mask the diagonal so that a point is never its own nearest neighbor
    D = D.masked_fill(torch.eye(D.size(0), dtype=torch.bool, device=D.device), float('inf'))
    nearest_neighbor_d, _ = D.min(dim=1)
    return -(nearest_neighbor_d.pow(p)).mean().pow(1/p)

def leinster_diversity(D, q=.5):
    """
    Takes the pairwise dissimilarity matrix D and computes the Leinster diversity loss.
    We follow Equation 6.5 in Leinster's book (https://arxiv.org/pdf/2012.02113).
    The sampling probability p is assumed to be 1/n * one-vector.
    """
    assert D.ndim == 2 and D.shape[0] == D.shape[1], "D must be a square matrix"
    assert q > 0 and q != 1 and math.isfinite(q), "valid range violated, see Leinster, p. 175"

    # convert to similarity matrix using formula on bottom of p. 172
    Z = torch.exp(-D)

    # (Zp)_i: mean similarity of point i to all points, i.e. Z @ p with uniform p = 1/n
    Zp = Z.mean(dim=-1)
    term_sum = torch.pow(Zp, q-1).mean()
    div = term_sum ** (1 / (1-q))
    return -div

To keep points within the design region, we introduce a design region loss, following Berzins et al.
This loss acts as a counterforce to the diversity loss and constrains points to stay within the horseshoe.
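
The sketch below illustrates why such a loss acts only on points that leave the design region. It assumes a stand-in SDF for the unit disk (`unit_disk_sdf`, a hypothetical helper, not the horseshoe SDF from `utils`) and a hypothetical `region_penalty` that uses the same relu-then-square construction. The gradient vanishes for an interior point and points outward for an exterior one, so a descent step moves the exterior point back toward the region.

```python
import torch

def unit_disk_sdf(y):
    """Stand-in SDF: negative inside the unit disk, positive outside."""
    return y.norm(dim=-1) - 1.0

def region_penalty(y, sdf_func):
    """Zero inside the region, squared distance outside, averaged over points."""
    return sdf_func(y).relu().square().mean()

y = torch.tensor([[0.2, 0.1], [2.0, 0.0]], requires_grad=True)  # one inside, one outside
loss = region_penalty(y, unit_disk_sdf)
loss.backward()
print(loss.item())  # only the outside point contributes
print(y.grad)       # zero row for the inside point, outward-pointing row for the outside point
```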

In [None]:
def design_region_loss(y, sdf_func=horse_shoe_sdf):
    """Transform a SDF of a design region into an objective: 0 inside, distance squared outside."""
    loss_per_sample = sdf_func(y).relu().square()
    return loss_per_sample.mean()


### Finite point set

Next, we train our model with different variants of the nearest neighbor and Leinster's diversity loss.
Importantly, we ablate different exponents.

In [None]:
def train_horse_shoe(div_loss_fn, lambda_div, n_points=1000, n_iter=1000):

    x = sample_bbox(horse_shoe_bounds, N=n_points).detach()
    x.requires_grad_(True)
    opt = torch.optim.Adam([x], lr=1e-2)

    for i in (pbar := trange(n_iter)):
        opt.zero_grad()
        loss_obj = design_region_loss(x)
        pairwise_dist_mat = torch.cdist(x, x)
        loss_div = lambda_div*div_loss_fn(pairwise_dist_mat)
        loss = loss_obj + loss_div
        loss.backward()
        opt.step()

        pbar.set_description(f"{loss_obj.item():.2e}, {loss_div.item():.2e}")

    return x.detach()

lambda_div = 1e-2
n_points = 1000

# Nearest-neighbor diversity with different exponents p
x_nn_05 = train_horse_shoe(partial(nearest_neighbor_diversity, p=0.5), lambda_div=lambda_div, n_points=n_points)
x_nn_1 = train_horse_shoe(partial(nearest_neighbor_diversity, p=1), lambda_div=lambda_div, n_points=n_points)
x_nn_2 = train_horse_shoe(partial(nearest_neighbor_diversity, p=2), lambda_div=lambda_div, n_points=n_points)

# Leinster diversity with different orders q (q = 1 is excluded, so 0.99 is used instead)
x_hill_05 = train_horse_shoe(partial(leinster_diversity, q=0.5), lambda_div=lambda_div, n_points=n_points)
x_hill_1 = train_horse_shoe(partial(leinster_diversity, q=0.99), lambda_div=lambda_div, n_points=n_points)
x_hill_2 = train_horse_shoe(partial(leinster_diversity, q=2), lambda_div=lambda_div, n_points=n_points)



The following cell visualizes the results. There are two things to learn.
First, the Leinster loss with the similarity $Z = e^{-D}$ repels each point from all other points. Although more distant points repel with smaller weight, the loss is dominated by the largest distances. This pushes the points to the outer boundary of the domain.
Second, the nearest-neighbor loss should be concave, which is achieved by using a power $p \le 1$. For $p < 1$, smaller distances receive a larger gradient than greater distances.

The points under the nearest-neighbor loss behave similarly to molecules in a room, which are repelled only by their neighbors and can therefore spread more evenly.
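
The effect of the exponent can be checked with a little arithmetic. Ignoring the outer normalization, each nearest-neighbor distance $d$ contributes $d^p$, whose derivative is $p\,d^{p-1}$. The plain-Python sketch below (with the hypothetical helper `power_grad`) compares that derivative for a close pair and a distant pair.

```python
def power_grad(d, p):
    """Derivative of d**p with respect to d."""
    return p * d ** (p - 1)

small, large = 0.01, 1.0  # a close pair and a distant pair
for p in (0.5, 1.0, 2.0):
    ratio = power_grad(small, p) / power_grad(large, p)
    print(f"p={p}: gradient ratio small/large = {ratio:.3g}")
```

For $p < 1$ the close pair receives the larger push, for $p = 1$ all pairs are pushed equally, and for $p = 2$ the distant pair dominates.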

In [None]:

fig, axs = plt.subplots(2, 3, figsize=(12, 8))  # 2 rows, 3 columns
axs = axs.ravel()

point_sets = [x_nn_05, x_nn_1, x_nn_2, x_hill_05, x_hill_1, x_hill_2]
titles = [
    r'$\delta_p$, p=0.5',
    r'$\delta_p$, p=1',
    r'$\delta_p$, p=2',
    r'$D_q^Z$, q=0.5',
    r'$D_q^Z$, q=0.99',
    r'$D_q^Z$, q=2',
    ]

for i, (ax, pts) in enumerate(zip(axs, point_sets)):
    ax.contour(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[0], colors='k')
    ax.contourf(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[sdf_horse_shoe.min(), 0, sdf_horse_shoe.max()], colors=['#0000ff', 'white'], alpha=0.5)
    ax.scatter(*pts.T.detach(), c='r', label='y', marker='.', s=2)
    ax.axis('off')
    ax.axis('scaled')
    ax.set_title(titles[i], fontsize=16)  # Set the title for each subplot

plt.subplots_adjust(hspace=0.1, wspace=-0.1)
plt.show()


### Neural Curve

Next, we define a simple MLP which takes a scalar input $x \in [-1, 1]$ and outputs 2D points, so that it parametrizes a curve. We optimize it to map points into the horseshoe and to increase diversity.

To run the experiment without diversity, simply comment out the diversity term of the loss in the training loop.

In [None]:
ny = 2
nx = 1
x_range = [-1, 1]
n_samples = 100
n_iter = 10000

lambda_design = 10
lambda_div = 1

model2 = Net(layer_widths=[nx, 40, 40, ny])
opt = torch.optim.Adam(model2.parameters(), lr=3e-3)

In [None]:
## Logging
loss_keys = ['design region', 'diversity']
loss_over_iters = {key: {} for key in loss_keys}


for i in (pbar := trange(n_iter)):
    opt.zero_grad()

    ## DESIGN REGION ##
    x = torch.rand(size=(n_samples, 1)) * (x_range[1] - x_range[0]) + x_range[0]
    x = torch.cat([x, torch.tensor([[0.0], [1.0]])])
    y = model2(x)
    loss_design = design_region_loss(y, sdf_func=horse_shoe_sdf)

    ## DIVERSITY ##
    D = torch.cdist(y, y)
    diversity = nearest_neighbor_diversity(D, p=0.5)

    loss_over_iters['design region'][i] = loss_design.item()
    loss_over_iters['diversity'][i] = -diversity.item()

    loss = lambda_design * loss_design + diversity * lambda_div
    loss.backward()
    opt.step()

    #pbar.set_description(f"design region: {loss_design.item():.2e}")  # , diversity: {diversity.item():.2f}")
    pbar.set_description(f"design region: {loss_design.item():.2e}, diversity: {diversity.item():.2f}")

In [None]:
plt.plot(list(loss_over_iters['design region'].keys()), list(loss_over_iters['design region'].values()), label='design region')
plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')
plt.semilogy()
plt.legend()
plt.show()

In [None]:
xs = [torch.linspace(x_range[0], x_range[1], 10*n_samples) for i in range(nx)]
Xs = torch.meshgrid(*xs, indexing='ij')  # explicit indexing avoids a deprecation warning in recent PyTorch
x = torch.vstack([X.flatten() for X in Xs]).T
y = model2(x)
points = y.detach().cpu().numpy()

plt.figure(figsize=(8, 6))

ax = plt.gca()
ax.contour(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[0], colors='k', linewidths=2)
ax.contourf(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[sdf_horse_shoe.min(), 0, sdf_horse_shoe.max()], colors=['#0000ff', 'white'], alpha=0.5)
ax.plot(points[:, 0], points[:, 1], c='r', linewidth=3)
ax.axis('off')
ax.axis('scaled')

plt.show()

## 2. Parametric function in 1D

In this example, we use the nearest_neighbor_diversity loss on elements of a function space.
The goal is to find a diverse set of curves that connect the points (-0.8, 0) and (0.8, 0).
Again, all curves should stay within a design region (envelope), which in this case is a disk of radius 1.


A neural network takes two inputs, x and z.
- x denotes the position along the parameterized curve (often denoted t in mathematics). Usually $t \in [0, 1]$.
- z distinguishes different curves. This way we can parameterize an infinite number of curves.


We start by defining a pairwise distance on the curves.

In [None]:
def pairwise_dist_curves(y):
    """
    y: [num_curves, num_points, 2] tensor
    pairwise_dist[i, j] contains the distance between curves i and j
    """
    pairwise_dist = (y.unsqueeze(0) - y.unsqueeze(1)).norm(dim=-1).mean(-1)
    return pairwise_dist

Next, we train the curves, sampling different points x and different curve parameters z.
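
The helper `tensor_product_xz` lives in `utils.util` and is not shown here. Judging from how its outputs are reshaped to `[bz, bx, ny]` in the training loop, it pairs every `x` with every `z` in z-major order. A minimal sketch of such a pairing, under that assumption and with the hypothetical name `pair_all_xz`, could look as follows; the actual utility may differ.

```python
import torch

def pair_all_xz(x, z):
    """Pair every row of x [bx, nx] with every row of z [bz, nz], z-major.

    Returns x_tp [bz*bx, nx] and z_tp [bz*bx, nz] such that row k = i*bx + j
    holds (x[j], z[i]). This ordering is an assumption, not the verified
    behavior of utils.util.tensor_product_xz.
    """
    bx, bz = x.shape[0], z.shape[0]
    x_tp = x.repeat(bz, 1)                  # the whole x block, once per latent code
    z_tp = z.repeat_interleave(bx, dim=0)   # each latent code repeated bx times
    return x_tp, z_tp

x = torch.tensor([[0.0], [1.0]])                                # bx = 2 curve positions
z = torch.tensor([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]])  # bz = 3 latent codes
x_tp, z_tp = pair_all_xz(x, z)
pairs = torch.cat([x_tp, z_tp], dim=1).reshape(len(z), len(x), -1)
print(pairs[1])  # both positions of the second curve share the latent code [20, 21]
```

With this layout, reshaping the model output to `[bz, bx, ny]` puts all points of one curve in one row, which is the layout `pairwise_dist_curves` expects.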

In [None]:
nx = 1
ny = 2
nz = 5

n_iter = 3000
lambda_div = 2
lambda_envelope = 1
lambda_interface = 10
envelope_radius = 1
n_latents_per_iter = 100
n_points_per_iter = 50
layer_widths = [nx+nz, 40, 40, ny]

## End points
x_endp = torch.tensor([[0.0], [1.0]])
y_endp = torch.tensor([[-.8, 0],[.8,0]])

# define model
model2 = ConditionalNet(layer_widths=layer_widths)

## Optimizer
opt = torch.optim.Adam(model2.parameters(), lr=1e-2)

# logging
loss_over_iters = {}
loss_over_iters['interface'] = {}
loss_over_iters['envelope'] = {}
loss_over_iters['diversity'] = {}


## Train
for i in (pbar := trange(n_iter)):
    opt.zero_grad()

    z = torch.rand(n_latents_per_iter, nz) ## number of codes, bz

    ## INTERFACE
    x_tp, z_tp = tensor_product_xz(x_endp, z)  # repeats and interleaves x and z
    y = model2(x_tp, z_tp).reshape(len(z), len(x_endp), ny) ## [bz, bx, ny]
    loss_interface = (y - y_endp[None,:,:]).square().sum() / len(y)

    # FORWARD for envelope and diversity
    x = torch.rand(n_points_per_iter, nx)
    x_tp, z_tp = tensor_product_xz(x, z)
    y = model2(x_tp, z_tp).reshape(len(z), len(x), ny) ## [bz, bx, ny]

    ## ENVELOPE
    r = y.norm(dim=2)
    loss_envelope = relu(r-envelope_radius).square().sum()/nz

    ## DIVERSITY
    distances = pairwise_dist_curves(y)
    loss_diversity =  nearest_neighbor_diversity(distances)

    # logging
    loss_over_iters['interface'][i] = loss_interface.item()
    loss_over_iters['envelope' ][i] = loss_envelope.item()
    loss_over_iters['diversity'][i] = -loss_diversity.item()

    # total
    loss = lambda_interface*loss_interface + lambda_envelope*loss_envelope + lambda_div*loss_diversity
    loss.backward()
    opt.step()

    pbar.set_description(f"interface: {loss_interface.item():.2e}, envelope: {loss_envelope.item():.2e}, diversity: {loss_diversity.item():.2e}")


plt.plot(list(loss_over_iters['interface'].keys()), list(loss_over_iters['interface'].values()), label='interface')
plt.plot(list(loss_over_iters['envelope' ].keys()), list(loss_over_iters['envelope'].values()),  label='envelope')
plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')
plt.semilogy()
plt.legend()
plt.show()

The following cell plots a single parametric curve with a randomly sampled z.

In [None]:
n_curves_to_plot = 1
x_resolution = 400
display_envelope = False

x = torch.linspace(0, 1, x_resolution)[:,None]
z = torch.rand(n_curves_to_plot, nz) ## number of codes, bz

x_tp, z_tp = tensor_product_xz(x, z)
y = model2(x_tp, z_tp).detach().reshape(len(z), len(x), ny)
y = y.detach().numpy()

import matplotlib.patches as patches
fig, ax = plt.subplots(figsize=(8,8))
ax.plot(*y.T, c='k', alpha=0.8)
ax.scatter(*y_endp.T, c='k')
ax.axis('equal')
ax.axis('off')
if display_envelope:
    circle = patches.Circle((0, 0), radius=envelope_radius, fill=False, color='lightblue', linewidth=2)
    ax.add_patch(circle)
plt.show()

Now we plot multiple samples on a linear trajectory in the z space.
The diversity loss successfully pushes the curves apart.

In [None]:
steps = 10
x_resolution = 400
interpolate_random_endpoints = False
display_envelope = False
plot_use_alpha = True

fig, ax = plt.subplots(1, 1, figsize=(8, 8))

if interpolate_random_endpoints:
    z0 = torch.rand(1,nz) ## latent code start
    z1 = torch.rand(1,nz) ## latent code stop
    print(z0-z1)
else:
    z0 = torch.tensor([[1.0,]*nz]) ## latent code start
    z1 = torch.tensor([[0.0,]*nz]) ## latent code stop

dz = (z1-z0)/steps
x = torch.linspace(0, 1, x_resolution)[:,None]
if plot_use_alpha:
    alphas = np.hstack([np.linspace(.2, 1.0, steps//2), np.linspace(1.0, .2, steps//2)])
else:
    cmap = 'winter'
    colormap = plt.get_cmap(cmap)
    colors = [colormap(i) for i in np.linspace(0, 1, steps)]

for i in trange(steps):
    z = z0 + dz*i
    x_tp, z_tp = tensor_product_xz(x, z)
    y = model2(x_tp, z_tp).detach().reshape(len(z), len(x), ny)
    if plot_use_alpha:
        ax.plot(*y.T, c='k', alpha=alphas[i])
    else:
        ax.plot(*y.T, c=colors[i])


ax.scatter(*y_endp.T, c='k')
ax.axis('equal')
ax.axis('off')
if display_envelope:
    circle = patches.Circle((0, 0), radius=envelope_radius, fill=False, color='lightblue', linewidth=2)
    ax.add_patch(circle)
fig.patch.set_facecolor('white')
plt.show()

## 3. Curves on a sphere

In this example we want to find diverse points on a sphere. First, we solve this task with a finite set of points, similar to the horseshoe example.

Next, we train a neural network that parametrizes a curve to spread over the manifold.
Concretely, we learn a neural network with parameters $\theta$ that represents a function $f_\theta: \mathbb{R} \to \mathbb{R}^3$ whose image should lie on the sphere.


Again, we start by introducing a pairwise distance function on the sphere.
Since the loss only uses the distance to the nearest neighbor, which we assume to be close, we use the Euclidean distance as a local approximation of the geodesic distance.
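
On a unit sphere, two points separated by the angle $\theta$ have geodesic distance $\theta$ and Euclidean (chord) distance $2\sin(\theta/2)$. The plain-Python sketch below (the helper `chord_length` is a hypothetical name) shows that the relative gap is small for nearby points and grows for distant ones, which is why the approximation is reasonable for nearest neighbors.

```python
import math

def chord_length(theta, radius=1.0):
    """Euclidean distance between two points on a sphere separated by the angle theta."""
    return 2.0 * radius * math.sin(theta / 2.0)

for theta in (0.05, 0.5, math.pi):
    arc = theta  # geodesic distance on the unit sphere
    rel_error = (arc - chord_length(theta)) / arc
    print(f"theta={theta:.2f}: arc={arc:.4f}, chord={chord_length(theta):.4f}, rel. error={rel_error:.2%}")
```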

In [None]:
### Needed to enable k3d in colab. In a local env the jupyter nbextension commands and google.colab imports are not needed
!pip install k3d
!jupyter nbextension install --py --user k3d
!jupyter nbextension enable --py --user k3d

import k3d

In [None]:
from google.colab import output

output.enable_custom_widget_manager()

k3d.switch_to_text_protocol()

In [None]:
def pairwise_dist_points_on_sphere(points):
    """
    points: [num_points, 3] tensor
    pairwise_dist[i, j] contains the Euclidean distance between points i and j
    """
    pairwise_dist = torch.norm(points[:, None] - points, dim=2, p=2)
    return pairwise_dist


def pairwise_dist_curves_on_sphere(y):
    """
    y: [num_points, 3] tensor of points sampled along the curve
    pairwise_dist[i, j] contains the Euclidean distance between sampled points i and j
    """
    pairwise_dist = torch.norm(y[:, None] - y, dim=2, p=2)
    return pairwise_dist

We define the spherical loss as the squared distance of a point to the sphere.
Note that this is a soft constraint and only leads to approximate solutions.

### Finite Set of Points

In [None]:
ny = 3
nx = 1
x_range = [-1., 1.]
n_samples = 500
n_iter = 1000
sphere_radius = 1

lambda_spherical = 100
lambda_div = 1

box_bounds = torch.tensor([[x_range[0], x_range[1]], [x_range[0], x_range[1]], [x_range[0], x_range[1]]])  # [3, 2] tensor for 3D box bounds
pts = sample_bbox(box_bounds, N=n_samples)
pts.requires_grad_(True)
opt = torch.optim.Adam([pts], lr=3e-3)

## Logging
loss_keys = ['sphericality', 'diversity']
loss_over_iters = {key: {} for key in loss_keys}

In [None]:
for i in (pbar := trange(n_iter)):
    opt.zero_grad()

    ## SPHERICALITY ##
    r = pts.norm(dim=-1)
    loss_sphericality = (r-sphere_radius).square().mean()

    ## DIVERSITY ##
    D = pairwise_dist_points_on_sphere(pts)
    diversity = nearest_neighbor_diversity(D)

    loss_over_iters['sphericality'][i] = loss_sphericality.item()
    loss_over_iters['diversity'][i] = - diversity.item()

    loss = lambda_spherical * loss_sphericality + diversity * lambda_div
    loss.backward()
    opt.step()


    pbar.set_description(f"sphericality: {loss_sphericality.item():.2e}, diversity: {diversity.item():.2f}")


plt.plot(list(loss_over_iters['sphericality'].keys()), list(loss_over_iters['sphericality'].values()), label='sphericality')
plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')
plt.semilogy()
plt.legend()
plt.show()

In [None]:
### Caution: k3d can be buggy in Colab and sometimes shows only empty plots, see https://github.com/K3D-tools/K3D-jupyter section "Google Colab"
import k3d
plot = k3d.plot(height=1000)
plot += k3d.points(pts.detach().cpu().numpy(), point_size=0.03, color=0xff0000)
plot += k3d.points([0,0,0], point_size=2, shader="mesh", mesh_detail=10, color=0x0000ff, opacity=0.5)

plot.display()

### Neural curve

In [None]:
ny = 3
nx = 1
x_range = [-1, 1]
n_samples = 100
n_iter = 2500
# div_scale = 1e-3
sphere_radius = 1

lambda_spherical = 1
lambda_div = 1
torch.manual_seed(0)

model3 = Net(layer_widths=[nx, 40, 40, ny])
opt = torch.optim.Adam(model3.parameters(), lr=1e-3)

## Logging
loss_keys = ['sphericality', 'diversity']
loss_over_iters = {key: {} for key in loss_keys}


for i in (pbar := trange(n_iter)):
    opt.zero_grad()

    ## SPHERICALITY ##
    x = torch.rand(size=(n_samples, 1)) * (x_range[1] - x_range[0]) + x_range[0]
    y = model3(x)
    r = y.norm(dim=-1)
    loss_sphericality = (r-sphere_radius).square().mean()

    ## DIVERSITY ##
    D = pairwise_dist_curves_on_sphere(y)
    diversity = nearest_neighbor_diversity(D)

    loss_over_iters['sphericality'][i] = loss_sphericality.item()
    loss_over_iters['diversity'][i] = - diversity.item()

    loss = lambda_spherical * loss_sphericality + diversity * lambda_div
    loss.backward()
    opt.step()


    pbar.set_description(f"sphericality: {loss_sphericality.item():.2e}, diversity: {diversity.item():.2f}")


plt.plot(list(loss_over_iters['sphericality'].keys()), list(loss_over_iters['sphericality'].values()), label='sphericality')
plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')
plt.semilogy()
plt.legend()
plt.show()

A 3D visualization demonstrates that the curve spreads over the sphere.
Feel free to decrease the lambda_div variable and compare the results.

In [None]:
### Caution: k3d can be buggy in Colab and sometimes shows only empty plots
n_samples = 10000

xs = [torch.linspace(x_range[0], x_range[1], n_samples) for i in range(nx)]
Xs = torch.meshgrid(*xs, indexing='ij')  # explicit indexing avoids a deprecation warning in recent PyTorch
x = torch.vstack([X.flatten() for X in Xs]).T
y = model3(x).detach().cpu().numpy()

plot2 = k3d.plot(height=1000)
plot2 += k3d.points(y, point_size=0.03, color=0xff0000)
plot2 += k3d.points([0,0,0], point_size=2, shader="mesh", mesh_detail=10, color=0x0000ff, opacity=0.5)
plot2.display()