In [None]:
# ruff: noqa: D101, D102, D107, E501, FBT003, N806, PD011, PLR2004, S311, T201
import math
from collections.abc import Callable

import matplotlib.pyplot as plt
import sklearn.cluster
import torch
import torch.nn.functional as F  # noqa: N812
from scipy import stats
from torch import Tensor, nn, optim
from torch.optim import lr_scheduler
from torch.utils import data
from tqdm import tqdm

from eden.utils import change_range, seed_all

# Utility Functions

In [None]:
def sample_from_func(
    func: Callable[[Tensor], Tensor],
    seeds: tuple[int, int] | Tensor,
    num_steps: int = 100,
    step_size: float = 0.1,
    noise_size: float | None = None,
    range_: tuple[float, float] = (-1, 1),
) -> Tensor:
    """Samples from a distribution using Langevin dynamics."""
    if isinstance(seeds, tuple):
        x = torch.rand(seeds, device="cuda")
        x = change_range(x, (0, 1), range_)
    else:
        x = seeds.clone()

    x.requires_grad_(True)

    noise = torch.empty_like(x, requires_grad=False)
    if noise_size is None:
        noise_step = math.sqrt(2 * step_size) if step_size > 1 else (2 * step_size) ** 2
    else:
        noise_step = noise_size

    optimizer = optim.SGD([x], lr=step_size)

    for _ in range(num_steps):
        optimizer.zero_grad()
        y = func(x).sum()
        y.backward()
        optimizer.step()

        with torch.no_grad():
            x.clamp_(range_[0], range_[1])

            if noise_step > 0:
                noise.normal_(0, noise_step)
                x.add_(noise)
                x.clamp_(range_[0], range_[1])

    return x.detach()


def sample_from_model(
    model: nn.Module,
    seeds: tuple[int, int] | Tensor,
    num_steps: int = 100,
    step_size: float = 0.1,
    noise_size: float | None = None,
) -> Tensor:
    """Samples from a model using Langevin dynamics.

    The model is used as the energy function of `sample_from_func`. While sampling, the
    model is put in eval mode and its parameters are frozen so that only the samples
    receive gradients. Afterwards the previous training mode and the previous
    per-parameter `requires_grad` flags are restored, even if sampling raises.

    Args:
        model: Energy model mapping a batch of points to one energy per point.
        seeds: Shape of the random starting points, or a tensor of starting points.
        num_steps: Number of Langevin steps.
        step_size: Step size of the gradient updates.
        noise_size: Standard deviation of the injected noise (see `sample_from_func`).

    Returns:
        The final samples, detached from the autograd graph.
    """
    was_training = model.training
    grad_flags = [param.requires_grad for param in model.parameters()]

    model.requires_grad_(False)
    model.eval()
    try:
        return sample_from_func(model, seeds, num_steps, step_size, noise_size)
    finally:
        # restore the caller's state instead of forcing train mode / trainable params
        for param, flag in zip(model.parameters(), grad_flags, strict=True):
            param.requires_grad_(flag)
        model.train(was_training)


def distance_matrix(
    x: Tensor,
    y: Tensor,
    metric: str = "euclidean",
) -> Tensor:
    """Builds the pairwise distance table between two point sets.

    Args:
        x: Points of shape (N, D).
        y: Points of shape (M, D).
        metric: Either "euclidean" (L2 distance) or "cosine" (1 - cosine similarity).

    Returns:
        A tensor of shape (N, M) whose entry (i, j) is the distance between
        `x[i]` and `y[j]`.

    Raises:
        ValueError: If `metric` is not one of the supported names.
    """
    if metric == "euclidean":
        return torch.cdist(x, y)

    if metric == "cosine":
        # broadcast (N, 1, D) against (1, M, D) to compare every pair
        similarity = F.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0), dim=-1)
        return 1 - similarity

    msg = f"Unknown distance: {metric}"
    raise ValueError(msg)


def estimate_bandwidth(x: Tensor, quantile: float = 0.3) -> float:
    """Estimates the bandwidth for mean shift clustering.

    For every point, take the distance to its k-th nearest neighbour, with
    `k = int(quantile * N)` clamped to `[1, N]`. The bandwidth is the mean of these
    distances. Each point counts as its own nearest neighbour (distance ~0), so
    `k == 1` gives a bandwidth of about 0.

    Args:
        x: The points, shape (N, D).
        quantile: Fraction of the points used as neighbourhood size.

    Returns:
        The estimated bandwidth as a Python float.
    """
    num_points = x.shape[0]
    # clamp so that quantile >= 1 never asks topk for more neighbours than exist
    n_neighbors = min(max(1, int(quantile * num_points)), num_points)

    dist = distance_matrix(x, x)  # (N, N); the self-distances on the diagonal are kept

    # topk(largest=False) is sorted ascending, so the last column is the k-th smallest
    kth_dist = torch.topk(dist, n_neighbors, largest=False).values[:, -1]
    return kth_dist.mean().item()


def meanshift(
    x: Tensor,
    num_iters: int = 100,
    bandwidth: float | None = None,
    centroids: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
    """Clusters the data points using mean shift with a flat kernel.

    Each centroid is repeatedly moved to the mean of the points within `bandwidth` of
    it, until the total shift falls below `1e-3 * bandwidth` or `num_iters` is reached.
    Centroids are then visited from most to fewest points in range. Every centroid
    within `bandwidth` of one that is still kept is dropped. Finally, each point is
    labelled with its nearest surviving centroid.

    Args:
        x: The points to cluster, shape (N, D).
        num_iters: Maximum number of shift iterations.
        bandwidth: Radius of the flat kernel. Estimated with `estimate_bandwidth(x)`
            when None.
        centroids: Starting centroids, shape (C, D). Defaults to the points themselves.

    Returns:
        The surviving centroids, shape (K, D), and for every point the index of its
        nearest surviving centroid, shape (N,).
    """
    if bandwidth is None:
        bandwidth = estimate_bandwidth(x)

    if centroids is None:
        centroids = x

    stop_threshold = 1e-3 * bandwidth
    for _ in range(num_iters):
        cx_dist = distance_matrix(centroids, x)  # (C, N)
        # match x's dtype (not always float32) so the matmul also works for float64 input
        inside = (cx_dist <= bandwidth).to(x.dtype)  # (C, N)

        denominator = inside.sum(dim=1, keepdim=True)  # (C, 1)
        new_centroids = (inside @ x) / denominator  # (C, D)

        # centroids with no point in range produced 0 / 0 above: keep them in place
        mask = (denominator == 0).squeeze(1)
        new_centroids[mask] = centroids[mask]

        shift = torch.norm(new_centroids - centroids)
        # adopt the last shift even when it is the one that signals convergence
        centroids = new_centroids
        if shift <= stop_threshold:
            break

    cx_dist = distance_matrix(centroids, x)
    inside = cx_dist <= bandwidth
    count = inside.sum(dim=1)  # (C,)

    cc_dist = distance_matrix(centroids, centroids)  # (C, C)
    overlap = cc_dist <= bandwidth
    overlap.fill_diagonal_(False)  # (C, C); a centroid never removes itself
    not_overlap = ~overlap

    # greedy suppression: the most populated centroids win over their neighbours
    keep = torch.ones_like(count, dtype=torch.bool)
    order = torch.argsort(count, descending=True)
    for i in order:
        if keep[i]:
            keep &= not_overlap[i]

    centroids = centroids[keep]

    xc_dist = distance_matrix(x, centroids)  # (N, C)
    labels = torch.argmin(xc_dist, dim=1)  # (N,)

    return centroids, labels


def cluster(x: Tensor, max_iter: int = 100) -> tuple[Tensor, Tensor]:
    """Clusters the data points using scikit-learn's mean shift.

    Args:
        x: The points to cluster, shape (N, D). They may live on any device and may
            require grad.
        max_iter: Maximum number of mean-shift iterations passed to `MeanShift`.

    Returns:
        The cluster centers (same dtype and device as `x`) and the cluster
        assignments (on the device of `x`).
    """
    ms = sklearn.cluster.MeanShift(max_iter=max_iter)
    # detach: Tensor.numpy() refuses tensors that require grad
    ms.fit(x.detach().cpu().numpy())
    # sklearn returns float64 centers; cast back so they match the input like `meanshift`
    centers = torch.tensor(ms.cluster_centers_, device=x.device, dtype=x.dtype)
    labels = torch.tensor(ms.labels_, device=x.device)
    return centers, labels


def plot_points(x: Tensor, range_: tuple[float, float] | None = None) -> None:
    """Plots the points."""
    x = x.cpu()
    if x.shape[1] == 1:
        plt.hist(x, bins=100)
        if range_ is not None:
            plt.xlim(*range_)
    elif x.shape[1] == 2:
        plt.scatter(x[:, 0], x[:, 1])
        if range_ is not None:
            plt.xlim(*range_)
            plt.ylim(*range_)
    else:
        return

    plt.show()


def plot_clusters(x: Tensor, centers: Tensor, labels: Tensor) -> None:
    """Scatters 2-D points coloured by cluster, with the centers as red crosses.

    Points that are not 2-D are ignored and nothing is shown.

    Args:
        x: The points, shape (N, D).
        centers: The cluster centers, shape (C, D).
        labels: The cluster index of each point, shape (N,).
    """
    if x.shape[1] == 2:
        pts, ctr, lab = (t.cpu() for t in (x, centers, labels))
        plt.scatter(pts[:, 0], pts[:, 1], c=lab, cmap="tab20")
        plt.scatter(ctr[:, 0], ctr[:, 1], c="red", s=100, marker="x")
        plt.show()

In [None]:
def affinity_propagation(
    x: Tensor,
    damping: float,
    max_iters: int,
    tol: float = 1e-4,
) -> Tensor:
    """Clusters the data points using affinity propagation.

    Similarities are cosine similarities. The preference (the diagonal) is set to the
    median similarity. Responsibilities and availabilities are updated with damping
    until their change falls below `tol` or `max_iters` iterations have run.

    Args:
        x: The points to cluster, shape (N, D).
        damping: Damping factor in [0, 1); larger values update more slowly.
        max_iters: Maximum number of message-passing iterations.
        tol: Convergence threshold on the (damping-normalised) change of the messages.

    Returns:
        For every point, the index of its exemplar among the selected exemplars,
        shape (N,). If no point is selected as an exemplar, every label is -1.

    Raises:
        ValueError: If `damping` is outside [0, 1) (1 would divide by zero below).
    """
    if not 0 <= damping < 1:
        msg = f"damping must be in [0, 1), got {damping}"
        raise ValueError(msg)

    similarity = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1)
    median = torch.median(similarity)
    similarity.fill_diagonal_(median.item())

    responsibility = torch.zeros_like(similarity)
    availability = torch.zeros_like(similarity)

    for _ in range(max_iters):
        r_t = _compute_responsibility(similarity, availability)
        responsibility = _damper(responsibility, r_t, damping)

        a_t = _compute_availability(responsibility)
        availability = _damper(availability, a_t, damping)

        ch = max(
            _chebyshev_distance(r_t, responsibility).max().item(),
            _chebyshev_distance(a_t, availability).max().item(),
        )
        if ch / (1 - damping) < tol:
            break

    cost = responsibility + availability
    exemplars = torch.nonzero(cost.diagonal() > 0).squeeze(1)
    if exemplars.numel() == 0:
        # argmax over an empty exemplar set would raise; report "unassigned" instead
        return torch.full((x.shape[0],), -1, dtype=torch.long, device=x.device)

    labels = torch.argmax(cost[:, exemplars], dim=1)
    return labels


# --------------------------------------------------------------------------- #
# Helper Functions
# --------------------------------------------------------------------------- #


def _compute_responsibility(similarity: Tensor, availability: Tensor) -> Tensor:
    p_indices = torch.arange(similarity.shape[0], device=similarity.device)

    sum_sa = availability + similarity
    first_max, first_max_idx = torch.max(sum_sa, dim=1)
    sum_sa[p_indices, first_max_idx] = -torch.inf
    second_max, _ = torch.max(sum_sa, dim=1)

    r = similarity - first_max.unsqueeze(1)
    r[p_indices, first_max_idx] = similarity[p_indices, first_max_idx] - second_max

    return r


def _compute_availability(responsibility: Tensor) -> Tensor:
    pos_r = responsibility.clamp_min(0)
    pos_r = pos_r.diagonal_scatter(responsibility.diagonal())

    a = pos_r.sum(dim=0) - pos_r
    a_diagonal = a.diagonal().clone()
    a = a.clamp_max(0)
    return a.diagonal_scatter(a_diagonal)


def _damper(x_p: Tensor, x_t: Tensor, damping: float) -> Tensor:
    return x_p * damping + x_t * (1 - damping)


def _chebyshev_distance(x: Tensor, y: Tensor) -> Tensor:
    return torch.max(torch.abs(x - y), dim=1)[0]

In [None]:
def sphere(x: Tensor) -> Tensor:
    """Sphere energy: the squared norm of every row of `x` (shape (N, D) -> (N,))."""
    squared = x**2
    return squared.sum(dim=1)


def rosenbrock(x: Tensor) -> Tensor:
    """Rosenbrock energy of every row of `x` (shape (N, D) -> (N,)).

    sum_i 100 * (x_{i+1} - x_i^2)^2 + (1 - x_i)^2, minimal (0) at x = (1, ..., 1).
    """
    head, tail = x[:, :-1], x[:, 1:]
    return (100 * (tail - head**2) ** 2 + (1 - head) ** 2).sum(dim=1)


def rastrigin(x: Tensor) -> Tensor:
    """Rastrigin energy of every row of `x` (shape (N, D) -> (N,)).

    10 * D + sum_i (x_i^2 - 10 * cos(2 * pi * x_i)), minimal (0) at the origin.
    """
    n_dims = x.shape[1]
    terms = x**2 - 10 * torch.cos(2 * math.pi * x)
    return 10 * n_dims + terms.sum(dim=1)

In [None]:
# Sample 10k points from the 2-D Rastrigin energy (small step, tiny noise) in [-5, 5].
x = sample_from_func(
    rastrigin,
    (10000, 2),
    num_steps=100,
    step_size=1e-3,
    noise_size=1e-4,
    range_=(-5, 5),
)

# use binary search to find the best bandwidth
# The search halves [low_q, high_q]. It keeps the lower half when the bandwidth grows
# more from low_q to mid_q than from mid_q to high_q, and the upper half otherwise.
# NOTE(review): presumably this looks for a "knee" of the quantile -> bandwidth curve;
# confirm the intended criterion. Every estimate_bandwidth call builds a full
# 10000 x 10000 distance matrix, and the loop runs ~14 times (width 1 -> 1e-4).
eps = 1 / 10000
low_q = 0 + eps
high_q = 1 - eps
low_b = estimate_bandwidth(x, low_q)
high_b = estimate_bandwidth(x, high_q)

while high_q - low_q > eps:
    mid_q = (low_q + high_q) / 2
    mid_b = estimate_bandwidth(x, mid_q)

    if mid_b - low_b > high_b - mid_b:
        high_q = mid_q
        high_b = mid_b
    else:
        low_q = mid_q
        low_b = mid_b

quantile = (low_q + high_q) / 2
print(quantile)

In [None]:
# Sanity check of autograd on the sphere energy: the gradient of sum(x**2) is 2 * x.
x = torch.randn(1, 2, requires_grad=True)
e = sphere(x).sum()
grad = torch.autograd.grad(e, x)[0]
x, e, grad

In [None]:
# Noise-free descent on 2-D Rastrigin, then cluster the end points with the local
# mean shift (bandwidth from the 1% neighbour quantile) and report the lowest energy.
x = sample_from_func(
    rastrigin,
    (10000, 2),
    num_steps=200,
    step_size=1e-3,
    noise_size=0.00,
    range_=(-5, 5),
)
centroids, labels = meanshift(x, bandwidth=estimate_bandwidth(x, 0.01))
plot_clusters(x, centroids, labels)
min_f = rastrigin(x).min().item()
min_f

In [None]:
# Inspect cluster 0: points within the median distance from its centroid are blue,
# the farther half is red, and the centroid is a black cross.
centroid = centroids[0]
points = x[labels == 0]

# find points whose distance from the centroid is at most the median distance
dist = torch.cdist(points, centroid.unsqueeze(0)).squeeze()
mask = dist <= dist.median()
mask = mask.cpu()
points = points.cpu()
centroid = centroid.cpu()

plt.plot(points[mask][:, 0], points[mask][:, 1], "o", color="blue")
plt.plot(points[~mask][:, 0], points[~mask][:, 1], "o", color="red")
plt.plot(centroid[0], centroid[1], "x", color="black")
plt.show()

In [None]:
# Configuration of the cluster-based search on high-dimensional Rastrigin below.
RANGE = (-5, 5)  # box every sample is clamped to
NUM_POINTS = 10000  # population size (also the per-generation sampling budget)
GENOME_SIZE = 100  # dimensionality of each point
BASE_LR = 1e-3  # per-coordinate gradient step used by `sample`
BASE_NOISE = 1e-4  # base noise std, rescaled per cluster when resampling
QUANTILE = 0.2  # neighbour quantile passed to `estimate_bandwidth`

energy = rastrigin  # energy minimised by `sample` and the search loop


def sample(x: Tensor, step_size: Tensor, noise_size: Tensor | None = None) -> Tensor:
    """Samples from a distribution using Langevin dynamics."""
    x = x.clone()
    x.requires_grad_(True)
    if noise_size is None:
        noise_size = step_size**2

    noise = torch.empty_like(x, requires_grad=False)

    for _ in range(50):
        y = energy(x).sum()
        grad = torch.autograd.grad(y, x)[0]

        with torch.no_grad():
            x.sub_(step_size * grad)
            x.clamp_(*RANGE)

            noise.normal_(0, 1)
            x.add_(noise.mul_(noise_size))
            x.clamp_(*RANGE)

    return x.detach()


# Cluster-based search: refine the population with Langevin steps, group it with
# mean shift, then resample each cluster from a fitted Gaussian. Each cluster gets a
# share of the budget proportional to 1 / its mean energy.
x = torch.rand((NUM_POINTS, GENOME_SIZE), device="cuda").mul_(2).sub_(1)
lr = torch.full_like(x, BASE_LR)
noise_size = torch.full_like(x, BASE_NOISE)
# (points, mean energy) of each cluster of the previous generation
prev_clusters: list[tuple[Tensor, float]] = []

for _ in range(20):
    x = sample(x, lr, noise_size)
    f = energy(x)

    min_f, mean_f, max_f = f.min().item(), f.mean().item(), f.max().item()
    print(f"Min: {min_f:.2E}, Mean: {mean_f:.2E}, Max: {max_f:.2E}")
    if min_f < 1e-20:
        print("Minimum reached")
        break

    bandwidth = estimate_bandwidth(x, quantile=QUANTILE)
    centroids, labels = meanshift(x, num_iters=100, bandwidth=bandwidth)
    plot_clusters(x, centroids, labels)

    new_clusters: list[tuple[Tensor, float]] = []
    for j in range(centroids.shape[0]):
        points = x[labels == j]
        if points.shape[0] == 0:
            # a centroid without assigned points would give a NaN mean energy
            continue
        fitness = float(energy(points).mean())
        new_clusters.append((points, fitness))

    f_denominator = sum(1 / f for _, f in new_clusters)

    new_points = []
    new_noise_sizes = []
    for i, (c_points, c_fitness) in enumerate(new_clusters):
        N = math.ceil(NUM_POINTS * (1 / c_fitness) / f_denominator)
        # skip clusters that get too few samples, or whose covariance is undefined
        if N <= GENOME_SIZE or c_points.shape[0] < 2:
            continue

        mean, cov = torch.mean(c_points, dim=0), torch.cov(c_points.T)
        # make sure the covariance matrix is positive definite
        cov = cov + torch.eye(GENOME_SIZE, device=cov.device) * 1e-6
        mn = torch.distributions.MultivariateNormal(mean, cov)
        p = mn.rsample((N,))
        new_points.append(p)

        # match with the previous-generation cluster at the smallest mean
        # nearest-point distance, if that distance is below the bandwidth
        best_match, best_icd = None, torch.inf
        for j, (prev_points, _) in enumerate(prev_clusters):
            icd = distance_matrix(c_points, prev_points)
            icd = icd.min(dim=1).values.mean().item()
            if icd < best_icd and icd < bandwidth:
                best_match, best_icd = j, icd

        if best_match is not None:
            _, prev_fitness = prev_clusters[best_match]
            factor = math.sqrt(c_fitness / prev_fitness)
            new_clusters[i] = (c_points, min(c_fitness, prev_fitness))
        else:
            factor = 1.0

        # per-dimension half-width of the cluster (parenthesised: `max - min / 2`
        # would only halve the minimum)
        extension = (c_points.max(dim=0).values - c_points.min(dim=0).values) / 2
        c_noise_size = torch.full((N, GENOME_SIZE), BASE_NOISE, device=x.device)
        c_noise_size = c_noise_size * extension * factor
        new_noise_sizes.append(c_noise_size)

    if not new_points:
        # torch.cat([]) would raise; no cluster qualified for resampling
        print("No cluster large enough to resample")
        break

    prev_clusters = new_clusters
    x = torch.cat(new_points, dim=0)
    x.clamp_(*RANGE)
    lr = torch.full_like(x, BASE_LR)
    noise_size = torch.cat(new_noise_sizes, dim=0)

    plot_points(x)
    print()

In [None]:
def sample(x: Tensor, step_size: Tensor, noise_size: Tensor | None = None) -> Tensor:
    """Runs 50 noisy gradient-descent steps on the `sphere` energy inside [-1, 1].

    NOTE: this redefines, and therefore shadows, the Rastrigin `sample` of the earlier
    cell.

    Args:
        x: Starting points, shape (N, D). Not modified.
        step_size: Gradient step, scalar or broadcastable to `x`.
        noise_size: Scale of the N(0, 1) noise; defaults to `step_size ** 2`.

    Returns:
        The final points, detached from the autograd graph.
    """
    scale = step_size**2 if noise_size is None else noise_size

    walker = x.clone()
    walker.requires_grad_(True)
    eps = torch.empty_like(walker, requires_grad=False)

    for _step in range(50):
        (energy_grad,) = torch.autograd.grad(sphere(walker).sum(), walker)
        with torch.no_grad():
            walker.sub_(step_size * energy_grad).clamp_(-1, 1)
            walker.add_(eps.normal_(0, 1).mul_(scale)).clamp_(-1, 1)

    return walker.detach()


# Sphere search: refine with noisy descent, cluster with mean shift, keep the inner
# half of every cluster and re-seed the rest uniformly in [-1, 1].
NUM_POINTS = 10000
GENOME_SIZE = 10
x = torch.rand((NUM_POINTS, GENOME_SIZE), device="cuda").mul_(2).sub_(1)
lr = torch.full_like(x, 0.1)
noise_size = torch.full_like(x, 0.01)
prev_clusters = []

for _ in range(20):
    x = sample(x, lr, noise_size)
    f = sphere(x)

    min_f, mean_f, max_f = f.min().item(), f.mean().item(), f.max().item()
    print(f"Min: {min_f:.2E}, Mean: {mean_f:.2E}, Max: {max_f:.2E}")
    if min_f < 1e-20:
        print("Minimum reached")
        break

    bandwidth = estimate_bandwidth(x)
    centroids, labels = meanshift(x, num_iters=100, bandwidth=bandwidth)

    # remove from each cluster the points that are far from the centroid
    # (beyond the cluster's median distance); they get label -1
    for j in range(centroids.shape[0]):
        membership = labels == j
        points = x[membership]
        if points.shape[0] == 0:
            continue  # a centroid may have no nearest points: nothing to trim
        centroid = points.mean(dim=0)

        dist = distance_matrix(points, centroid[None]).squeeze_(1)
        mask = dist <= dist.median()
        tmp_labels = torch.full(mask.shape, -1, dtype=labels.dtype, device=labels.device)  # fmt: skip
        tmp_labels[mask] = j

        labels[membership] = tmp_labels

    # re-seed the discarded points uniformly in [-1, 1]
    not_membership = labels == -1
    not_membership_count = int(not_membership.sum())
    x[not_membership] = torch.rand((not_membership_count, GENOME_SIZE), device="cuda").mul_(2).sub_(1)  # fmt: skip

    # surviving points get noise proportional to their cluster's half-width
    noise_size = torch.full_like(x, 0.01)
    for j in range(centroids.shape[0]):
        membership = labels == j
        points = x[membership]

        if points.shape[0] == 0:
            continue

        # parenthesised: `max - min / 2` would only halve the minimum
        extension = (points.max(dim=0).values - points.min(dim=0).values) / 2
        noise_size[membership] = extension * 0.01

    lr = torch.full_like(x, 0.1)

# Model

In [None]:
def compute_basis_functions(x: Tensor, knots: Tensor, k: int) -> Tensor:
    """Computes the b-spline basis functions for the given input data.

    Uses the Cox-de Boor recursion, unrolled bottom-up from degree 0 to degree `k`.
    Degree-0 bases are indicators of the half-open knot intervals `[t_i, t_{i+1})`.

    Args:
        x: The points at which to evaluate the basis functions. This should be a
            tensor of shape `(n_splines, n_points)`.
        knots: The knot vector. This should be a tensor of shape `(n_splines, n_knots)`.
        k: The degree of the spline (must be >= 0).

    Returns:
        The basis functions evaluated at the given points. This is a tensor of
        shape `(n_splines, n_basis_functions, n_points)`. The number of basis functions
        is equal to `n_knots - k - 1`.

    Raises:
        ValueError: If `k` is negative (the recursive form never terminated then).
    """
    if k < 0:
        msg = f"The spline degree must be non-negative, got {k}"
        raise ValueError(msg)

    pts = x.unsqueeze(1)  # (n_splines, 1, n_points)
    t = knots.unsqueeze(2)  # (n_splines, n_knots, 1)

    # degree 0: indicator of [t_i, t_{i+1})
    basis = ((pts >= t[:, :-1]) & (pts < t[:, 1:])).to(pts.dtype)

    for d in range(1, k + 1):
        # B_{i,d} = (x - t_i) / (t_{i+d} - t_i) * B_{i,d-1}
        #         + (t_{i+d+1} - x) / (t_{i+d+1} - t_{i+1}) * B_{i+1,d-1}
        left = (pts - t[:, : -(d + 1)]) / (t[:, d:-1] - t[:, : -(d + 1)])
        left = left * basis[:, :-1]

        right = (t[:, d + 1 :] - pts) / (t[:, d + 1 :] - t[:, 1:-d])
        right = right * basis[:, 1:]

        basis = left + right

    return basis


def compute_b_spline(
    x: Tensor,
    knots: Tensor,
    coeffs: Tensor,
    k: int,
) -> Tensor:
    """Evaluates a batch of independent b-splines at the given points.

    Args:
        x: Evaluation points, shape `(n_splines, n_points)`.
        knots: Knot vector of every spline, shape `(n_splines, n_knots)`.
        coeffs: Control-point coefficients of every spline, shape
            `(n_splines, n_coeffs)`, with `n_coeffs == n_knots - k - 1`.
        k: The degree of the splines.

    Returns:
        The spline values, shape `(n_splines, n_points)`.
    """
    # s = spline, c = coefficient / basis function, p = point
    return torch.einsum("sc,scp->sp", coeffs, compute_basis_functions(x, knots, k))


# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------


class SplineModel(nn.Module):
    """A model taht uses B-spline encoding to learn an energy function."""

    def __init__(
        self,
        degree: int = 3,
        grid_size: int = 5,
        embed_dim: int = 64,
        pooling: str = "mean",
    ) -> None:
        super().__init__()

        self.degree = 3
        num_internal_knots = grid_size + 1
        num_knots = num_internal_knots + 2 * degree
        knots = torch.linspace(-1, 1, num_internal_knots)
        diff = knots[1] - knots[0]
        knots = torch.cat([
            knots[0] - diff * torch.arange(degree, 0, -1),
            knots,
            knots[-1] + diff * torch.arange(1, degree + 1),
        ])
        self.register_buffer("knots", knots)
        self.knots: Tensor

        num_control_points = num_knots - degree - 1
        # n_splines = embed_dim
        cp = torch.empty(embed_dim, num_control_points)
        self.cp = nn.Parameter(cp, requires_grad=True)

        self.pooling = pooling
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1, bias=False),
            nn.Tanh(),
        )

        self.reset_parameters()

    @property
    def name(self) -> str:
        return f"spline-{self.pooling}"

    @torch.no_grad()  # type: ignore
    def reset_parameters(self) -> None:
        nn.init.normal_(self.cp, mean=0.0, std=0.1)

    def forward(self, x: Tensor) -> Tensor:
        cp = self.cp  # (embed_dim, num_control_points)

        N, D = x.shape
        x = x.flatten()  # (N * D,)
        x = x.unsqueeze(0).expand(cp.size(0), -1)  # (embed_dim, N * D)

        knots = self.knots.unsqueeze(0).expand(cp.size(0), -1)  # (embed_dim, num_knots)

        embeds = compute_b_spline(x, knots, cp, self.degree)  # (embed_dim, N * D)
        embeds = embeds.view(cp.size(0), N, D)  # (embed_dim, N, D)
        embeds = embeds.permute(1, 2, 0)  # (N, D, embed_dim)

        match self.pooling:
            case "max":
                embeds = embeds.max(dim=1).values
            case "mean":
                embeds = embeds.mean(dim=1)
            case "sum":
                embeds = embeds.sum(dim=1)
            case _:
                msg = f"Unknown pooling method: {self.pooling}"
                raise ValueError(msg)

        energies = self.mlp(embeds)
        return energies.squeeze(-1)

# Training

In [None]:
# def generate_new_data(
#     model: nn.Module,
#     seeds: tuple[int, int] | Tensor,
#     num_new_points: int,
#     random_ratio: float = 0.1,
# ) -> tuple[Tensor, Tensor]:
#     """Generates new data points using the model."""
#     # sample a large number of points from the model
#     x = sample_from_model(model, seeds, num_steps=50, step_size=0.1, noise_size=0.01)
#     centroids, labels = meanshift(x)

#     _, counts = torch.unique(labels, return_counts=True)
#     fractions: Tensor = counts.float() / labels.size(0)

#     num_random_points = round(num_new_points * random_ratio)
#     num_cluster_points = num_new_points - num_random_points
#     num_points_per_cluster = (fractions * num_cluster_points).round().int()
#     num_points_per_cluster.clamp_(min=1)

#     new_points = []
#     for i, num_points in enumerate(num_points_per_cluster.tolist()):
#         cluster_points = x[labels == i]
#         extension = cluster_points.max(dim=0).values - cluster_points.min(dim=0).values

#         p = torch.rand(num_points, x.size(1), device=x.device).mul_(2).sub_(1)
#         p.mul_(extension).div_(2).add_(centroids[i])

#         new_points.append(p)

#     # add random points
#     p = torch.rand(num_random_points, x.size(1), device=x.device).mul_(2).sub_(1)
#     new_points.append(p)

#     new_points = torch.cat(new_points)
#     return new_points, x


def generate_new_data(
    model: nn.Module,
    seeds: tuple[int, int] | Tensor,
    num_new_points: int,
    random_ratio: float = 0.1,
) -> tuple[Tensor, Tensor]:
    """Generates new data points using the model.

    A large population is sampled from the model with Langevin dynamics and clustered
    with mean shift. A Gaussian is fitted to every cluster, and new points are drawn
    from it roughly in proportion to the cluster's size (at least one per cluster).
    A `random_ratio` share of uniform random points in [-1, 1] is added.

    Args:
        model: The energy model to sample from.
        seeds: Shape of the random starting points, or a tensor of starting points.
        num_new_points: Approximate number of points to generate (rounding and the
            one-point-per-cluster minimum can change the exact count).
        random_ratio: Fraction of the new points drawn uniformly at random.

    Returns:
        `(new_points, sampled_points)`: the generated points and the raw model samples.
    """
    # sample a large number of points from the model
    x = sample_from_model(model, seeds, num_steps=50, step_size=0.1, noise_size=0.01)
    centroids, labels = meanshift(x)

    # bincount (not unique) so entry i always belongs to label i, even when a
    # centroid ends up with no assigned point
    counts = torch.bincount(labels, minlength=centroids.shape[0])
    fractions: Tensor = counts.float() / labels.size(0)

    num_random_points = round(num_new_points * random_ratio)
    num_cluster_points = num_new_points - num_random_points
    num_points_per_cluster = (fractions * num_cluster_points).round().int()
    num_points_per_cluster.clamp_(min=1)

    new_points = []
    for i, num_points in enumerate(num_points_per_cluster.tolist()):
        cluster_points = x[labels == i]
        if cluster_points.shape[0] == 0:
            continue  # nothing to fit a Gaussian to
        mean, cov = stats.multivariate_normal.fit(cluster_points.cpu().numpy())
        dist = stats.multivariate_normal(mean=mean, cov=cov)
        # rvs squeezes its result ((D,) for one draw, (n,) when D == 1): restore (n, D)
        p = torch.tensor(dist.rvs(num_points), device=x.device, dtype=x.dtype)
        new_points.append(p.reshape(num_points, x.size(1)))

    # add random points
    p = torch.rand(num_random_points, x.size(1), device=x.device, dtype=x.dtype).mul_(2).sub_(1)
    new_points.append(p)

    return torch.cat(new_points), x

In [None]:
# Supervised sanity check: fit the spline model to the normalised sphere energy
# of uniformly random points in [-1, 1]^GENOME_SIZE.
seed_all(3407)
GENOME_SIZE = 2
# max of sum(x_i^2) over [-5, 5]^GENOME_SIZE, used to normalise the targets
FITNESS_BOUND = 5**2 * GENOME_SIZE
DATASET_SIZE = 1_000
BATCH_DIM = 128
NUM_EPOCHS = 10

model = SplineModel(pooling="mean").cuda()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0)
# scheduler = lr_scheduler.CosineAnnealingLR(optimizer, NUM_EPOCHS, eta_min=1e-6)
# NOTE(review): step_size=NUM_EPOCHS with one scheduler.step() per epoch means the LR
# only decays after the last epoch, i.e. it stays constant during this loop.
scheduler = lr_scheduler.StepLR(optimizer, step_size=NUM_EPOCHS, gamma=0.1)

# populate buffer with initial samples
# targets: true sphere energy after mapping [-1, 1] to [-5, 5] (presumably linear
# via `change_range`), divided by FITNESS_BOUND
x = torch.rand(DATASET_SIZE, GENOME_SIZE).mul_(2).sub_(1)
y = (change_range(x, (-1, 1), (-5, 5)) ** 2).sum(dim=-1) / FITNESS_BOUND

dataset = data.TensorDataset(x, y)
loader = data.DataLoader(dataset, batch_size=BATCH_DIM, shuffle=True)

# note: the loop variables rebind the global `x` / `y` to the last batch
for _ in tqdm(range(NUM_EPOCHS), desc="Training", total=NUM_EPOCHS):
    for x, y in loader:
        x, y = x.cuda(), y.cuda()  # noqa: PLW2901
        optimizer.zero_grad()
        e = model(x)
        loss = F.mse_loss(e, y)
        loss.backward()
        optimizer.step()

    scheduler.step()

In [None]:
# Sample 1000 points from the trained model and cluster them with scikit-learn's MeanShift.
x = sample_from_model(model, (1000, GENOME_SIZE), num_steps=50, step_size=0.1, noise_size=0.01)  # fmt: skip
clusters, labels = cluster(x)
plot_clusters(x, clusters, labels)

In [None]:
# Generate 100 new points from the model (no random points) and plot them.
# generate_new_data takes the sampling seeds *and* the number of points, and returns
# (new_points, sampled_points); only the new points are plotted.
points, _ = generate_new_data(model, (1000, GENOME_SIZE), 100, random_ratio=0.0)
plot_points(points)

In [None]:
# Active-learning loop: alternately fit the model on the current dataset and grow the
# dataset with points generated from the model. New points are labelled with the true
# normalised sphere energy.
seed_all(3407)
GENOME_SIZE = 2
FITNESS_BOUND = 5**2 * GENOME_SIZE
NUM_EXTERNAL_EPOCHS = 20
NUM_NEW_POINTS = 100
NUM_INITIAL_POINTS = 1_000
NUM_INTERNAL_EPOCHS = 10
BATCH_DIM = 128

model = SplineModel(pooling="mean").cuda()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0)

# create initial dataset
x = torch.rand(NUM_INITIAL_POINTS, GENOME_SIZE).mul_(2).sub_(1)
y = (change_range(x, (-1, 1), (-5, 5)) ** 2).sum(dim=-1) / FITNESS_BOUND
dataset = data.TensorDataset(x, y)
loader = data.DataLoader(dataset, batch_size=BATCH_DIM, shuffle=True, pin_memory=True)

# NOTE(review): this initial value is only used by the commented-out variant of the
# generate_new_data call below; the active call overwrites it.
prev_sampled_points = torch.rand(1000, GENOME_SIZE, device="cuda").mul_(2).sub_(1)
# per outer iteration: the generated points / the raw model samples they came from
generated, sampled = [], []

for _ in tqdm(range(NUM_EXTERNAL_EPOCHS), desc="Training", total=NUM_EXTERNAL_EPOCHS):
    # train for INTERNAL_EPOCHS on the current dataset
    for _ in tqdm(range(NUM_INTERNAL_EPOCHS), desc="Internal Training", total=NUM_INTERNAL_EPOCHS):  # fmt: skip
        optimizer.zero_grad()
        for x, y in loader:
            x, y = x.cuda(), y.cuda()  # noqa: PLW2901
            e = model(x)
            loss = F.mse_loss(e, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # generate new data points
    x, prev_sampled_points = generate_new_data(model, (1000, GENOME_SIZE), NUM_NEW_POINTS, random_ratio=0.2)  # fmt: skip
    # new_x, prev_sampled_points = generate_new_data(model, prev_sampled_points, NUM_NEW_POINTS, random_ratio=0.2)  # fmt: skip
    generated.append(x)
    sampled.append(prev_sampled_points)
    new_y = (change_range(x, (-1, 1), (-5, 5)) ** 2).sum(dim=-1) / FITNESS_BOUND
    old_x = dataset.tensors[0]
    old_y = dataset.tensors[1]

    # combine old and new data points (the dataset grows every outer iteration)
    x = torch.cat([old_x, x.cpu()])
    y = torch.cat([old_y, new_y.cpu()])
    dataset = data.TensorDataset(x, y)
    loader = data.DataLoader(dataset, batch_size=BATCH_DIM, shuffle=True, pin_memory=True)  # fmt: skip

In [None]:
# Plot the raw model samples of every outer iteration, zoomed into [-0.01, 0.01].
for x in sampled:
    plot_points(x, range_=(-0.01, 0.01))

In [None]:
# Sample many points from the trained model and pick the one it scores lowest.
x = sample_from_model(model, (10000, GENOME_SIZE), num_steps=50, step_size=0.1, noise_size=0.01)  # fmt: skip
# y = (change_range(x, (-1, 1), (-5, 5)) ** 2).sum(dim=-1)
with torch.no_grad():  # inference only: no autograd graph over 10k points
    y = model(x)

_, min_idx = y.min(dim=0)
min_x = x[min_idx]
# (best point in model space, its predicted energy, its true sphere energy in [-5, 5])
min_x, y[min_idx], (change_range(min_x, (-1, 1), (-5, 5)) ** 2).sum()

In [None]:
# Refine the best sampled point by 100 plain gradient-descent steps on the frozen
# model, then report it in the original [-5, 5] space with its true sphere energy.
# NOTE(review): unlike sample_from_func, x is not clamped to [-1, 1] here. This cell
# also rebinds the global `optimizer` and forces the model back into train mode.
x = min_x.clone()
x.requires_grad = True

optimizer = optim.SGD([x], lr=0.001)
model.eval()
model.requires_grad_(False)

for _ in range(100):
    optimizer.zero_grad()
    e = model(x.unsqueeze(0))
    e.backward()
    optimizer.step()

model.train()
model.requires_grad_(True)

x = x.detach()
x = change_range(x, (-1, 1), (-5, 5))
x, (x**2).sum()

# Check Fitting

In [None]:
# Models to compare against the ground truth in the landscape plot.
models = [model]

In [None]:
# Plot the 2-D landscape.
# Instead of the whole space, zoom into the tiny [-0.001, 0.001]^2 square around the
# optimum (in the original [-5, 5] coordinates) to check whether the model has
# correctly learned the landscape where the fitness is best.
NUM_SAMPLES = 200
with torch.no_grad():
    low, high = -0.001, 0.001
    elev, azim = None, None
    x = torch.linspace(low, high, NUM_SAMPLES)
    y = torch.linspace(low, high, NUM_SAMPLES)
    points = torch.cartesian_prod(x, y)
    X = points[:, 0].reshape(NUM_SAMPLES, NUM_SAMPLES)
    Y = points[:, 1].reshape(NUM_SAMPLES, NUM_SAMPLES)
    Z = X**2 + Y**2  # ground-truth sphere energy

    n_plots = len(models) + 1
    fig = plt.figure(figsize=(n_plots * 5, 5))
    # squeeze=False keeps `axes` indexable even with a single subplot (empty `models`)
    axes = fig.subplots(1, n_plots, subplot_kw={"projection": "3d"}, squeeze=False)[0]

    axes[0].plot_surface(X.cpu(), Y.cpu(), Z.cpu(), cmap="viridis")
    axes[0].set_title("Ground truth")
    axes[0].set_xlabel("x")
    axes[0].set_ylabel("y")
    axes[0].set_zlabel("fitness")
    axes[0].view_init(elev=elev, azim=azim)

    # map the grid from the original [-5, 5] space into the model's [-1, 1] space
    solutions = points.cuda()
    solutions = change_range(solutions, (-5, 5), (-1, 1))

    # `m`, not `model`: reusing `model` would silently rebind the global trained model
    for j, m in enumerate(models):
        # the model predicts energies normalised by FITNESS_BOUND; undo that
        f_hat = m(solutions) * FITNESS_BOUND
        f_hat = f_hat.reshape(NUM_SAMPLES, NUM_SAMPLES)

        ax = axes[j + 1]
        ax.plot_surface(X.cpu(), Y.cpu(), f_hat.cpu(), cmap="viridis")
        ax.set_title(m.name)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("fitness")

        # change rotation
        ax.view_init(elev=elev, azim=azim)

    plt.show()