# Import Statements

All the necessary libraries are imported in the cell below.

In [None]:
# ruff: noqa: D101, D102, D107, E501, FBT003, N806, PD011, PLR0913, PLR2004, S311, T201
import math
import random
from collections.abc import Callable
from typing import Literal

import entmax
import matplotlib.pyplot as plt
import numpy as np
import sklearn.cluster
import torch
import torch.nn.functional as F  # noqa: N812
from scipy import stats
from torch import Tensor, nn, optim
from torch.optim import lr_scheduler
from torch.utils import data
from tqdm import tqdm

# Utility Functions

In [None]:
def seed_all(seed: int) -> None:
    """Seeds the Python, NumPy and PyTorch random number generators.

    Args:
        seed: The value every generator is seeded with.
    """
    # The generators are independent of each other, so the order is irrelevant.
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)  # noqa: NPY002
    random.seed(seed)


def change_range(
    x: Tensor,
    old_range: tuple[float, float],
    new_range: tuple[float, float],
) -> Tensor:
    """Linearly maps every element of a tensor from one interval to another.

    !!! note
        When both ranges are equal the input tensor itself is returned (no copy).

    Args:
        x: The tensor to rescale.
        old_range: The `(min, max)` interval the values currently live in.
        new_range: The `(min, max)` interval the values should be mapped to.

    Returns:
        The rescaled tensor.
    """
    if old_range == new_range:
        return x

    old_lo, old_hi = old_range
    new_lo, new_hi = new_range
    old_span = old_hi - old_lo
    new_span = new_hi - new_lo

    # Normalise to [0, 1] first, then stretch and shift to the target interval.
    unit = (x - old_lo) / old_span
    return unit * new_span + new_lo

In [None]:
def sample_from_func(
    func: Callable[[Tensor], Tensor],
    seeds: tuple[int, int] | Tensor,
    num_steps: int = 100,
    step_size: float = 0.1,
    noise_size: float | None = None,
    range_: tuple[float, float] = (-1, 1),
) -> Tensor:
    """Samples from a distribution using Langevin dynamics."""
    if isinstance(seeds, tuple):
        x = torch.rand(seeds, device="cuda")
        x = change_range(x, (0, 1), range_)
    else:
        x = seeds.clone()

    x.requires_grad_(True)

    noise = torch.empty_like(x, requires_grad=False)
    if noise_size is None:
        noise_step = math.sqrt(2 * step_size) if step_size > 1 else (2 * step_size) ** 2
    else:
        noise_step = noise_size

    optimizer = optim.SGD([x], lr=step_size)

    for _ in range(num_steps):
        optimizer.zero_grad()
        y = func(x).sum()
        y.backward()
        optimizer.step()

        with torch.no_grad():
            x.clamp_(range_[0], range_[1])

            if noise_step > 0:
                noise.normal_(0, noise_step)
                x.add_(noise)
                x.clamp_(range_[0], range_[1])

    return x.detach()


def sample_from_model(
    model: nn.Module,
    seeds: tuple[int, int] | Tensor,
    num_steps: int = 100,
    step_size: float = 0.1,
    noise_size: float | None = None,
) -> Tensor:
    """Samples from a model using Langevin dynamics.

    While sampling, the model is frozen (no parameter gradients) and put in
    evaluation mode. Afterwards every parameter's `requires_grad` flag and every
    submodule's training flag is restored to what it was before the call, even if
    sampling raises.

    Args:
        model: The energy-based model to sample from.
        seeds: The shape of the initial points or the initial points themselves.
        num_steps: The number of Langevin steps.
        step_size: The learning rate of the gradient step.
        noise_size: The standard deviation of the injected noise (`None` lets
            `sample_from_func` pick its default).

    Returns:
        The sampled points.
    """
    # Remember the exact state so that deliberately frozen parameters and modules
    # in eval mode are not silently re-enabled afterwards.
    grad_flags = [(param, param.requires_grad) for param in model.parameters()]
    mode_flags = [(module, module.training) for module in model.modules()]

    model.requires_grad_(False)
    model.eval()
    try:
        return sample_from_func(model, seeds, num_steps, step_size, noise_size)
    finally:
        for param, flag in grad_flags:
            param.requires_grad_(flag)
        for module, flag in mode_flags:
            module.training = flag


def distance_matrix(
    x: Tensor,
    y: Tensor,
    metric: str = "euclidean",
) -> Tensor:
    """Computes all pairwise distances between two sets of points.

    Args:
        x: The first set of points. This must be a tensor of shape (N, D).
        y: The second set of points. This must be a tensor of shape (M, D).
        metric: The distance metric to use. This can be either "euclidean" or "cosine".

    Returns:
        A tensor of shape (N, M) whose entry (i, j) is the distance between
        `x[i]` and `y[j]`.

    Raises:
        ValueError: If `metric` is not one of the supported metrics.
    """
    if metric == "euclidean":
        return torch.cdist(x, y)

    if metric == "cosine":
        # Broadcast (N, 1, D) against (1, M, D) to compare every pair.
        rows = x[:, None, :]
        cols = y[None, :, :]
        return 1 - F.cosine_similarity(rows, cols, dim=-1)

    msg = f"Unknown distance: {metric}"
    raise ValueError(msg)


def estimate_bandwidth(x: Tensor, quantile: float = 0.3) -> float:
    """Estimates the bandwidth for mean shift clustering.

    The estimate is the average, over all points, of the distance to the k-th
    nearest point, with `k = max(1, int(quantile * N))`.

    Args:
        x: The data points, of shape (N, D).
        quantile: The fraction of the points used as neighbourhood size.

    Returns:
        The estimated bandwidth.
    """
    k = max(1, int(quantile * x.shape[0]))

    # Each point's zero distance to itself is part of its neighbourhood.
    pairwise = distance_matrix(x, x)
    nearest, _ = torch.topk(pairwise, k, largest=False)
    kth_distance = nearest[:, -1]

    return kth_distance.mean().item()


def meanshift(
    x: Tensor,
    num_iters: int = 100,
    bandwidth: float | None = None,
    centroids: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
    """Clusters the data points using mean shift with a flat kernel.

    Args:
        x: The data points, of shape (N, D).
        num_iters: The maximum number of shift iterations.
        bandwidth: The radius of the kernel window. If `None`, it is estimated
            with `estimate_bandwidth`.
        centroids: The initial centroids, of shape (C, D). If `None`, every data
            point seeds a centroid.

    Returns:
        The surviving centroids and, for every data point, the index of its
        nearest surviving centroid.
    """
    bandwidth = estimate_bandwidth(x) if bandwidth is None else bandwidth
    centroids = x if centroids is None else centroids

    tolerance = 1e-3 * bandwidth
    for _ in range(num_iters):
        # Flat kernel: a point contributes iff it lies inside the window.
        in_window = (distance_matrix(centroids, x) <= bandwidth).float()  # (C, N)
        window_size = in_window.sum(dim=1, keepdim=True)  # (C, 1)
        shifted = (in_window @ x) / window_size  # (C, D)

        # Centroids with an empty window would become 0 / 0; leave them in place.
        empty = (window_size == 0).squeeze_(1)
        shifted[empty] = centroids[empty]

        if torch.norm(shifted - centroids) <= tolerance:
            break

        centroids = shifted

    # Number of points supporting each centroid; used to rank them below.
    support = (distance_matrix(centroids, x) <= bandwidth).sum(dim=1)  # (C,)

    too_close = distance_matrix(centroids, centroids) <= bandwidth  # (C, C)
    too_close.fill_diagonal_(False)

    # Greedy suppression: visit centroids from most to least supported and drop
    # every centroid that lies within one bandwidth of a kept one.
    keep = torch.ones_like(support, dtype=torch.bool)
    for idx in torch.argsort(support, descending=True):
        if not keep[idx]:
            continue
        keep &= ~too_close[idx]

    centroids = centroids[keep]
    labels = torch.argmin(distance_matrix(x, centroids), dim=1)  # (N,)

    return centroids, labels


def cluster(x: Tensor) -> tuple[Tensor, Tensor]:
    """Clusters the data points with scikit-learn's mean shift.

    Args:
        x: The data points to cluster.

    Returns:
        The cluster centers and the cluster assignments, both placed on the
        device of `x`.
    """
    device = x.device
    estimator = sklearn.cluster.MeanShift(max_iter=100)
    estimator.fit(x.cpu().numpy())  # scikit-learn works on CPU NumPy arrays

    return (
        torch.tensor(estimator.cluster_centers_, device=device),
        torch.tensor(estimator.labels_, device=device),
    )


def plot_points(x: Tensor, range_: tuple[float, float] | None = None) -> None:
    """Plots 1D points as a histogram and 2D points as a scatter plot.

    Points of any other dimensionality are silently ignored.

    Args:
        x: The points to plot, of shape (N, D).
        range_: Optional axis limits, applied to every plotted axis.
    """
    x = x.cpu()
    num_dims = x.shape[1]
    if num_dims not in (1, 2):
        return

    if num_dims == 1:
        plt.hist(x, bins=100)
    else:
        plt.scatter(x[:, 0], x[:, 1])

    if range_ is not None:
        plt.xlim(*range_)
        if num_dims == 2:
            plt.ylim(*range_)

    plt.show()


def plot_clusters(x: Tensor, centers: Tensor, labels: Tensor) -> None:
    """Plots 2D points coloured by cluster, with the centers marked by red crosses.

    Nothing is drawn when the points are not two-dimensional.

    Args:
        x: The data points, of shape (N, 2).
        centers: The cluster centers, of shape (C, 2).
        labels: The cluster index of every data point.
    """
    if x.shape[1] == 2:
        x, centers, labels = (t.cpu() for t in (x, centers, labels))
        plt.scatter(x[:, 0], x[:, 1], c=labels, cmap="tab20")
        plt.scatter(centers[:, 0], centers[:, 1], c="red", s=100, marker="x")
        plt.show()

# Benchmarking Functions

In [None]:
def sphere(x: Tensor) -> Tensor:
    """Computes the energy of the sphere function: the squared norm of each row.

    Args:
        x: The points, of shape (B, N).

    Returns:
        One energy per point, of shape (B,).
    """
    return x.square().sum(dim=1)


def rosenbrock(x: Tensor) -> Tensor:
    """Computes the energy of the Rosenbrock function.

    Args:
        x: The points, of shape (B, N).

    Returns:
        One energy per point, of shape (B,).
    """
    current, following = x[:, :-1], x[:, 1:]
    valley = 100 * (following - current**2) ** 2
    offset = (1 - current) ** 2
    return (valley + offset).sum(dim=1)


def rastrigin(x: Tensor) -> Tensor:
    """Computes the energy of the Rastrigin function.

    Args:
        x: The points, of shape (B, N).

    Returns:
        One energy per point, of shape (B,).
    """
    num_dims = x.shape[1]
    ripple = x**2 - 10 * torch.cos(2 * math.pi * x)
    return 10 * num_dims + ripple.sum(dim=1)


def ackley(x: Tensor) -> Tensor:
    """Computes the energy of the Ackley function.

    Args:
        x: The points, of shape (B, N).

    Returns:
        One energy per point, of shape (B,).
    """
    num_dims = x.shape[1]
    rms = torch.sqrt((x**2).sum(dim=1) / num_dims)
    mean_cos = torch.cos(2 * math.pi * x).sum(dim=1) / num_dims
    return -20 * torch.exp(-0.2 * rms) - torch.exp(mean_cos) + 20 + math.e


def griewank(x: Tensor) -> Tensor:
    """Computes the energy of the Griewank function.

    Args:
        x: The points, of shape (B, N).

    Returns:
        One energy per point, of shape (B,).
    """
    num_dims = x.shape[1]
    # Coordinate i (1-based) is divided by sqrt(i) inside the cosine product.
    scale = torch.sqrt(torch.arange(1, num_dims + 1, device=x.device))
    quadratic = (x**2).sum(dim=1) / 4000
    oscillation = torch.prod(torch.cos(x / scale), dim=1)
    return quadratic - oscillation + 1

# Model Definition

In [None]:
class MeanShift(nn.Module):
    """Gaussian Mean-Shift clustering."""

    def __init__(
        self,
        sigma: float,
        damping: float = 0.5,
        max_iter: int = 50,
        shift_tol: float = 1e-4,
        alpha: float = 2,
        *,
        learnable_sigma: bool = True,
    ) -> None:
        """Initializes the clustering algorithm."""
        super().__init__()

        self.sigma = nn.Parameter(
            torch.tensor(float(sigma)), requires_grad=learnable_sigma
        )
        self.damping = damping
        self.max_iter = max_iter
        self.shift_tol = shift_tol
        self.alpha = alpha

    def find_centroids(self, x: Tensor) -> Tensor:
        """Finds the centroids of the clusters using the mean-shift algorithm.

        Args:
            x: The input data tensor. This is expected to be a 2D tensor of shape
                (N, D), where N is the number of data points and D is the dimensionality
                of the data.

        Returns:
            The centroids of the clusters. This is a 2D tensor of shape (C, D), where C
            is the number of clusters.
        """
        centroids = x.clone()
        gamma = 1 / (2 * self.sigma**2)
        for _ in range(self.max_iter):
            dist = torch.cdist(x, centroids, p=2)  # (N, N)
            weight = torch.exp(-(dist**2) * gamma)  # (N, N)

            nominator = weight @ x  # (N, D)
            denominator = weight.sum(dim=1, keepdim=True)  # (N, 1)
            new_centroids = nominator / denominator

            shift = (new_centroids - centroids).norm(dim=1, p=2)  # (N,)
            if shift.max() < self.shift_tol:
                break

            centroids = self.damping * centroids + (1 - self.damping) * new_centroids

        centroids = _nms(x, centroids, threshold=self.sigma.item())

        return centroids

    def assign(self, x: Tensor, centroids: Tensor) -> Tensor:
        """Assigns the data points to the clusters.

        Args:
            x: The input data tensor. This is expected to be a 2D tensor of shape
                (N, D), where N is the number of data points and D is the dimensionality
                of the data.
            centroids: The centroids of the clusters. This is a 2D tensor of shape
                (C, D), where C is the number of clusters.

        Returns:
            A tensor of shape (N, C), where each row is a probability distribution over
            the clusters.
        """
        cost = -torch.cdist(x, centroids, p=2)  # (N, C)

        match self.alpha:
            case 1:
                assignment = torch.softmax(cost, dim=1)
            case 1.5:
                assignment = entmax.entmax15(cost, dim=1)
            case 2:
                assignment = entmax.sparsemax(cost, dim=1)
            case math.inf:
                assigned_to = cost.argmax(dim=1)  # (N,)
                assignment = torch.zeros_like(cost)
                assignment[torch.arange(x.size(0), device=x.device), assigned_to] = 1
                # use straight-through estimator
                assignment = assignment - cost.detach() + cost
            case _:
                msg = f"Invalid alpha value: {self.alpha}."
                raise RuntimeError(msg)

        return assignment  # type: ignore

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Performs mean-shift clustering on the input data.

        Args:
            x: The input data tensor. This is expected to be a 2D tensor of shape
                (N, D), where N is the number of data points and D is the dimensionality
                of the data.

        Returns:
            A tuple containing the centroids and the assignment of the data points to
            the clusters. The centroids tensor has shape (C, D), where C is the number
            of clusters. The assignment tensor has shape (N, C), where each row is a
            probability distribution over the clusters.
        """
        centroids = self.find_centroids(x)
        assignment = self.assign(x, centroids)

        return centroids, assignment


# --------------------------------------------------------------------------- #
# Helper Functions
# --------------------------------------------------------------------------- #


def _nms(x: Tensor, centroids: Tensor, threshold: float) -> Tensor:
    pc_dist = torch.cdist(x, centroids, p=2)  # (N, N)
    _, clostest_centroid = pc_dist.min(dim=1)  # (N,)
    uniques, counts = clostest_centroid.unique(return_counts=True)
    scores = torch.zeros_like(clostest_centroid)
    scores[uniques] = counts

    cc_dist = torch.cdist(centroids, centroids, p=2)
    overlap = cc_dist < threshold
    overlap.fill_diagonal_(fill_value=False)

    order = scores.argsort(descending=True)
    keep = torch.ones_like(scores, dtype=torch.bool)
    for idx in order:
        if keep[idx]:
            keep &= ~overlap[idx]
            keep[idx] = scores[idx] > 0

    return centroids[keep]  # (C, D)

In [None]:
class GeneEmbeddings(nn.Module):
    """Embeds genes values into a high-dimensional space.

    Every embedding dimension is a learnable B-spline of the scalar gene value:
    all dimensions share one uniform knot vector and each owns a row of control
    points.
    """

    def __init__(
        self,
        degree: int = 3,
        grid_size: int = 5,
        embed_dim: int = 64,
    ) -> None:
        """Initializes the knot vector and the control points.

        Args:
            degree: The degree of the B-splines.
            grid_size: The number of equal intervals `[-1, 1]` is divided into.
            embed_dim: The dimensionality of the embeddings (number of splines).
        """
        super().__init__()

        # Must match the degree the knot vector below is built for; otherwise the
        # number of basis functions and control points disagree at evaluation.
        self.degree = degree
        num_internal_knots = grid_size + 1
        num_knots = num_internal_knots + 2 * degree
        knots = torch.linspace(-1, 1, num_internal_knots)
        diff = knots[1] - knots[0]
        # Extend the grid by `degree` equally spaced knots on both sides.
        knots = torch.cat([
            knots[0] - diff * torch.arange(degree, 0, -1),
            knots,
            knots[-1] + diff * torch.arange(1, degree + 1),
        ])
        self.register_buffer("knots", knots)
        self.knots: Tensor

        num_control_points = num_knots - degree - 1
        # n_splines = embed_dim
        cp = torch.empty(embed_dim, num_control_points)
        self.cp = nn.Parameter(cp, requires_grad=True)

        self.reset_parameters()

    def __call__(self, x: Tensor) -> Tensor:
        """Embeds the gene values into a high-dimensional space.

        Args:
            x: The gene values to embed. This should be a tensor of shape `(B, N)`,
                where `B` is the batch size and `N` is the number of genes.

        Returns:
            The embedded gene values. This is a tensor of shape `(B, N, D)`, where `D`
            is the embedding dimension.
        """
        cp = self.cp  # (embed_dim, num_control_points)

        B, N = x.shape
        x = x.flatten()  # (B * N,)
        # Every spline is evaluated at the same points.
        x = x.unsqueeze(0).expand(cp.size(0), -1)  # (embed_dim, B * N)

        knots = self.knots.unsqueeze(0).expand(cp.size(0), -1)  # (embed_dim, num_knots)

        embeds = compute_b_spline(x, knots, cp, self.degree)  # (embed_dim, B * N)
        embeds = embeds.view(cp.size(0), B, N)  # (embed_dim, B, N)
        embeds = embeds.permute(1, 2, 0)  # (B, N, embed_dim)

        return embeds

    @torch.no_grad()  # type: ignore
    def reset_parameters(self) -> None:
        """Re-initializes the control points from a normal distribution (std 0.1)."""
        nn.init.normal_(self.cp, mean=0.0, std=0.1)


def compute_basis_functions(x: Tensor, knots: Tensor, k: int) -> Tensor:
    """Evaluates the B-spline basis functions of degree `k`.

    The degree-0 indicator functions are built first and then raised one degree
    at a time with the Cox-de Boor recursion.

    Args:
        x: The points at which to evaluate the basis functions. This should be a
            tensor of shape `(n_splines, n_points)`.
        knots: The knot vector. This should be a tensor of shape `(n_splines, n_knots)`.
        k: The degree of the spline.

    Returns:
        The basis functions evaluated at the given points. This is a tensor of
        shape `(n_splines, n_basis_functions, n_points)`. The number of basis functions
        is equal to `n_knots - k - 1`.
    """
    x = x.unsqueeze(1)  # (n_splines, 1, n_points)
    knots = knots.unsqueeze(2)  # (n_splines, n_knots, 1)

    # Degree 0: indicator of the half-open interval between consecutive knots.
    inside = (x >= knots[:, :-1]) & (x < knots[:, 1:])
    basis = inside.to(x.dtype)

    for degree in range(1, k + 1):
        start = knots[:, : -(degree + 1)]
        end = knots[:, degree + 1 :]

        rising = (x - start) / (knots[:, degree:-1] - start)
        rising = rising * basis[:, :-1]

        falling = (end - x) / (end - knots[:, 1:(-degree)])
        falling = falling * basis[:, 1:]

        basis = rising + falling

    return basis


def compute_b_spline(
    x: Tensor,
    knots: Tensor,
    coeffs: Tensor,
    k: int,
) -> Tensor:
    """Evaluates B-splines as coefficient-weighted sums of their basis functions.

    Args:
        x: The points at which to evaluate the b-spline. This should be a tensor
            of shape `(n_splines, n_points)`.
        knots: The knot vector. This should be a tensor of shape `(n_splines, n_knots)`.
        coeffs: The coefficients of the b-spline. This should be a tensor of shape
            `(n_splines, n_coeffs)`.
        k: The degree of the spline.

    Returns:
        The b-spline evaluated at the given points. This is a tensor of shape
        `(n_splines, n_points)`.
    """
    basis = compute_basis_functions(x, knots, k)  # (n_splines, n_coeffs, n_points)
    # Contract the coefficient axis independently for every spline.
    return torch.einsum("ij,ijk->ik", coeffs, basis)  # (n_splines, n_points)

In [None]:
class HyperGraphConvolution(nn.Module):
    """A hypergraph convolution module.

    Every layer applies a linear projection, mixes the nodes through the
    adjacency induced by the incidence matrix and applies the activation.
    """

    def __init__(
        self,
        embed_dim: int,
        num_layers: int = 1,
        activation: Callable[[], nn.Module] = nn.ReLU,
    ) -> None:
        """Builds the projection layers.

        Args:
            embed_dim: The dimensionality of the node embeddings.
            num_layers: The number of convolution layers.
            activation: A factory returning the activation module, which is
                shared by all layers.
        """
        super().__init__()

        layers = [nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)]
        self.proj = nn.ModuleList(layers)
        self.activation = activation()

    def __call__(self, x: Tensor, b: Tensor) -> Tensor:
        """Performs a hypergraph convolution.

        Args:
            x: The nodes embeddings. This should be a tensor of shape `(B, N, D)`.
                where `N` is the number of nodes and `D` is the dimensionality of the
                embeddings.
            b: The hypergraph incidence matrix. This should be a tensor of shape
                `(N, M)`, where `M` is the number of hyperedges.

        Returns:
            The updated node embeddings. This is a tensor of shape `(B, N, D)`.
        """
        # Two nodes are connected with a weight given by the hyperedges they share.
        adjacency = (b @ b.T).unsqueeze(0)  # (1, N, N), broadcast over the batch

        for projection in self.proj:
            projected = projection(x)  # (B, N, D)
            x = self.activation(adjacency @ projected)  # (B, N, D)

        return x

In [None]:
class EDEN(nn.Module):
    """Estimation of Distribution using Energy-based models (EDEN).

    A learnable per-gene memory is clustered with mean shift; the resulting
    assignment is used as the incidence matrix of a hypergraph convolution over
    the gene embeddings, which are then pooled and mapped to a scalar energy.
    """

    def __init__(
        self,
        num_genes: int,
        embed_dim: int,
        meanshift: MeanShift,
        gene_embeds: GeneEmbeddings,
        hconv: HyperGraphConvolution,
        pooling: Literal["max", "mean", "sum"] = "max",
    ) -> None:
        """Assembles the model.

        Args:
            num_genes: The number of genes of a data point.
            embed_dim: The dimensionality of the embeddings.
            meanshift: The module clustering the gene memory.
            gene_embeds: The module embedding the gene values.
            hconv: The hypergraph convolution applied to the gene embeddings.
            pooling: How the gene embeddings are reduced to one vector.
        """
        super().__init__()

        self.embed_dim = embed_dim
        # One learnable vector per gene; its clusters define the hyperedges.
        self.memory = nn.Parameter(torch.randn(num_genes, embed_dim))

        self.meanshift = meanshift
        self.gene_embeds = gene_embeds
        self.hconv = hconv

        self.pooling = pooling
        # The final tanh bounds the predicted energy to (-1, 1).
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1, bias=False),
            nn.Tanh(),
        )

    @property
    def name(self) -> str:
        """A short identifier of the model variant."""
        return f"spline-{self.pooling}"

    def __call__(self, x: Tensor) -> Tensor:
        """Computes the energy of the data points.

        Args:
            x: The data points to compute the energy of. This should be a tensor of
                shape `(B, N)`, where `B` is the batch size and `N` is the number of
                genes.

        Returns:
            The energy of the data points. This is a tensor of shape `(B,)`.

        Raises:
            ValueError: If the pooling method is unknown.
        """
        _, assignment = self.meanshift(self.memory)
        embeds = self.hconv(self.gene_embeds(x), assignment)  # (B, N, D)

        # Reduce over the gene axis.
        if self.pooling == "max":
            pooled = embeds.max(dim=1).values
        elif self.pooling == "mean":
            pooled = embeds.mean(dim=1)
        elif self.pooling == "sum":
            pooled = embeds.sum(dim=1)
        else:
            msg = f"Unknown pooling method: {self.pooling}"
            raise ValueError(msg)

        return self.mlp(pooled).squeeze(-1)  # (B, 1) -> (B,)

# Optimization Loop

In [None]:
# All hyperparameters of the run live in this cell.
# Set these to the values you want.

seed_all(42)

RANGE = (-5, 5)
NUM_POINTS = 10000
GENOME_SIZE = 10
EMBED_DIM = 64  # shared by every component of the energy-based model
BASE_LR = 1e-3
BASE_NOISE = 1e-4
QUANTILE = 0.2
NUM_EXTERNAL_EPOCHS = 10
NUM_INTERNAL_EPOCHS = 10

# training loop hyperparameters
# Largest value of the sphere oracle inside RANGE; used to normalise the targets.
FITNESS_BOUND = max(abs(bound) for bound in RANGE) ** 2 * GENOME_SIZE
INITIAL_DATASET_SIZE = 1_000
# Size the training set reaches after the last external epoch.
FINAL_DATASET_SIZE = 10_000
BATCH_DIM = 128
# Number of newly labelled points added to the training set per external epoch.
TRAINING_SUBSET_SIZE = (FINAL_DATASET_SIZE - INITIAL_DATASET_SIZE) // NUM_EXTERNAL_EPOCHS

# oracle
energy = sphere

# energy-based model
gene_embeddings = GeneEmbeddings(embed_dim=EMBED_DIM)
ms = MeanShift(sigma=0.1)
hconv = HyperGraphConvolution(embed_dim=EMBED_DIM, num_layers=1)
model = EDEN(
    num_genes=GENOME_SIZE,
    embed_dim=EMBED_DIM,
    meanshift=ms,
    gene_embeds=gene_embeddings,
    hconv=hconv,
    pooling="max",
)
model = model.cuda()

optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0)
# NOTE: we found learning-rate schedulers to be detrimental to the training, so
# none is used. Both CosineAnnealingLR (eta_min=1e-6) and StepLR (gamma=0.1) were
# tried.


def train(loader: data.DataLoader[tuple[Tensor, Tensor]]) -> None:
    """Fits the global `model` to the oracle labels with a mean squared error loss.

    Runs `NUM_INTERNAL_EPOCHS` passes over the loader, taking one step of the
    global `optimizer` per batch.

    Args:
        loader: Yields batches of points and their target energies.
    """
    epochs = tqdm(range(NUM_INTERNAL_EPOCHS), desc="Internal Training", total=NUM_INTERNAL_EPOCHS)  # fmt: skip
    for _ in epochs:
        optimizer.zero_grad()
        for points, targets in loader:
            predictions = model(points.cuda())
            loss = F.mse_loss(predictions, targets.cuda())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def langevin(x: Tensor, step_size: Tensor, noise_size: Tensor | None = None) -> Tensor:
    """Samples from a distribution using Langevin dynamics.

    Runs 50 steps; each one descends the gradient of the global `energy`, clamps
    to `RANGE`, adds scaled standard normal noise and clamps again.

    NOTE(review): the gradient is taken of the `energy` oracle, not of the learned
    `model` -- confirm this is intended.

    Args:
        x: The starting points. They are not modified.
        step_size: The gradient step size, multiplied elementwise with the gradient.
        noise_size: The scale of the injected noise. Defaults to `step_size**2`.

    Returns:
        The final points, detached from the autograd graph.
    """
    if noise_size is None:
        noise_size = step_size**2

    x = x.clone()
    x.requires_grad_(True)
    noise = torch.empty_like(x, requires_grad=False)
    low, high = RANGE

    for _ in range(50):
        (grad,) = torch.autograd.grad(energy(x).sum(), x)

        with torch.no_grad():
            x.sub_(step_size * grad)
            x.clamp_(low, high)

            # Fresh unit noise every step, scaled in place.
            noise.normal_(0, 1)
            x.add_(noise.mul_(noise_size))
            x.clamp_(low, high)

    return x.detach()


def _normalized_fitness(x: Tensor) -> Tensor:
    """Labels points given in [-1, 1] with the oracle, scaled by `FITNESS_BOUND`.

    The points are mapped to `RANGE` before the oracle is evaluated. They must not
    be squared beforehand: the oracle already computes the full energy.
    """
    return energy(change_range(x, (-1, 1), RANGE)) / FITNESS_BOUND


def loop() -> Tensor:  # noqa: PLR0915
    """Main optimization loop.

    An initial labelled dataset is sampled and the model is trained on it. Then,
    for `NUM_EXTERNAL_EPOCHS` epochs: the population is moved with Langevin
    dynamics, clustered with mean shift, and resampled from one Gaussian per
    cluster, with a number of samples proportional to the inverse of the cluster's
    mean energy. A random subset of the new population is labelled with the oracle,
    added to the dataset, and the model is trained again.

    Returns:
        The point of the final population with the lowest predicted energy.
    """
    # populate buffer with initial samples
    x = torch.rand(INITIAL_DATASET_SIZE, GENOME_SIZE).mul_(2).sub_(1)
    y = _normalized_fitness(x)

    dataset = data.TensorDataset(x, y)
    loader = data.DataLoader(dataset, batch_size=BATCH_DIM, shuffle=True)

    x = torch.rand((NUM_POINTS, GENOME_SIZE), device="cuda").mul_(2).sub_(1)
    lr = torch.full_like(x, BASE_LR)
    noise_size = torch.full_like(x, BASE_NOISE)
    prev_clusters: list[tuple[Tensor, float]] = []

    train(loader)  # type: ignore

    for _ in tqdm(range(NUM_EXTERNAL_EPOCHS), desc="Optimization", total=NUM_EXTERNAL_EPOCHS):  # fmt: skip
        x = langevin(x, lr, noise_size)

        bandwidth = estimate_bandwidth(x, quantile=QUANTILE)
        centroids, labels = meanshift(x, num_iters=100, bandwidth=bandwidth)

        new_clusters: list[tuple[Tensor, float]] = []
        for j in range(centroids.shape[0]):
            points = x[labels == j]
            fitness = float(energy(points).mean())
            new_clusters.append((points, fitness))

        # Normaliser of the inverse-fitness sampling shares.
        f_denominator = sum(1 / f for _, f in new_clusters)

        new_points = []
        new_noise_sizes = []
        for i, (c_points, c_fitness) in enumerate(new_clusters):
            # Lower mean energy -> larger share of the new population.
            B = math.ceil(NUM_POINTS * (1 / c_fitness) / f_denominator)
            if B <= GENOME_SIZE:
                continue

            mean, cov = torch.mean(c_points, dim=0), torch.cov(c_points.T)
            # make sure the covariance matrix is positive definite
            cov = cov + torch.eye(GENOME_SIZE, device=cov.device) * 1e-6
            mn = torch.distributions.MultivariateNormal(mean, cov)
            p = mn.rsample((B,))  # type: ignore
            new_points.append(p)

            # Match against the previous epoch's clusters: the closest one by
            # mean nearest-neighbour distance, if it is closer than the bandwidth.
            best_match, best_icd = None, torch.inf
            for j, (prev_points, _) in enumerate(prev_clusters):
                icd = distance_matrix(c_points, prev_points)
                icd = icd.min(dim=1).values.mean().item()
                if icd < best_icd and icd < bandwidth:
                    best_match, best_icd = j, icd

            if best_match is not None:
                _, prev_fitness = prev_clusters[best_match]
                # > 1 when the cluster got worse, < 1 when it improved.
                factor = math.sqrt(c_fitness / prev_fitness)
                new_clusters[i] = (c_points, min(c_fitness, prev_fitness))
            else:
                factor = 1.0

            # Half of the cluster's extent along every gene. The parentheses
            # matter: without them only the minimum would be halved.
            extension = (c_points.max(dim=0).values - c_points.min(dim=0).values) / 2
            c_noise_size = torch.full((B, GENOME_SIZE), BASE_NOISE, device="cuda")
            c_noise_size = c_noise_size * extension * factor
            new_noise_sizes.append(c_noise_size)

        prev_clusters = new_clusters
        x = torch.cat(new_points, dim=0)
        x.clamp_(*RANGE)
        lr = torch.full_like(x, BASE_LR)
        noise_size = torch.cat(new_noise_sizes, dim=0)

        # add new points to the dataset
        indices = torch.randint(0, x.shape[0], (TRAINING_SUBSET_SIZE,), device="cuda")
        x_subset = x[indices]
        y_subset = _normalized_fitness(x_subset)

        old_x, old_y = dataset.tensors
        new_x = torch.cat([old_x, x_subset.cpu()], dim=0)
        new_y = torch.cat([old_y, y_subset.cpu()], dim=0)
        dataset = data.TensorDataset(new_x, new_y)
        loader = data.DataLoader(dataset, batch_size=BATCH_DIM, shuffle=True)
        train(loader)  # type: ignore

    # sample from the found points the best one
    with torch.no_grad():
        e = model(x)
    # `min(dim=0)` returns (value, index); the index is what selects the point.
    return x[e.argmin()]

# Check Fitting

In [None]:
# Energy models whose predicted landscape is compared with the ground truth.
models = [model]

In [None]:
# Plot the 2D landscape.
# Let's not focus on the whole space, but on the [low, high] x [low, high]
# subspace, to verify whether the model has correctly learned the landscape for
# the points with better fitness.
NUM_SAMPLES = 200


def _draw_surface(ax, values, title):  # noqa: ANN001, ANN202
    """Draws one labelled 3D surface over the global (X, Y) grid."""
    ax.plot_surface(X.cpu(), Y.cpu(), values.cpu(), cmap="viridis")
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("fitness")
    # change rotation
    ax.view_init(elev=elev, azim=azim)


with torch.no_grad():
    low, high = -0.001, 0.001
    elev, azim = None, None
    x = torch.linspace(low, high, NUM_SAMPLES)
    y = torch.linspace(low, high, NUM_SAMPLES)
    # Every (x, y) pair of the grid, x varying slowest.
    points = torch.cartesian_prod(x, y)
    X, Y = (column.reshape(NUM_SAMPLES, NUM_SAMPLES) for column in points.unbind(dim=1))
    Z = X**2 + Y**2

    n_plots = len(models) + 1
    fig = plt.figure(figsize=(n_plots * 5, 5))
    axes = fig.subplots(1, n_plots, subplot_kw={"projection": "3d"})

    _draw_surface(axes[0], Z, "Ground truth")

    # The models take inputs scaled to [-1, 1].
    solutions = points.cuda()
    solutions = change_range(solutions, (-5, 5), (-1, 1))

    for j, model in enumerate(models):
        # Undo the target normalisation used for training.
        f_hat = model(solutions) * FITNESS_BOUND
        f_hat = f_hat.reshape(NUM_SAMPLES, NUM_SAMPLES)

        _draw_surface(axes[j + 1], f_hat, model.name)

    plt.show()