Notebook to demonstrate the key idea of the KSG estimator visually.

In [None]:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors


In [None]:
# Build a noisy sine relationship: x is uniform on [0, 1), y = sin(2*pi*x) + noise.
# The draw order (x first, then the noise) is kept so the seeded stream
# yields the same samples.
np.random.seed(1)
N = 300
k = 4
x = np.random.random_sample((N, 1))
noise = np.random.standard_normal((N, 1))
y = np.sin(2 * np.pi * x) + 0.1 * noise
data = np.column_stack((x, y))

# Pick one sample and measure the distance to its k-th neighbour in the
# joint (x, y) space under the Chebyshev (max) metric.
idx = 50
# k + 1 neighbours are requested because each point, when queried against
# the set it was fitted on, is returned as its own nearest neighbour.
nbrs = NearestNeighbors(n_neighbors=k + 1, metric='chebyshev')
nbrs.fit(data)
distances, _ = nbrs.kneighbors(data, return_distance=True)
# Column k (the last of the k + 1 columns) holds the k-th neighbour distance.
eps = distances[idx][k]

# Create plot
#
# Draws, around the query point `idx`, the joint ε-box (a square of
# half-width `eps`) together with the vertical and horizontal marginal
# strips of the same half-width.

# x and y are (N, 1) column vectors, so x[idx] / y[idx] are length-1 arrays.
# Pull out plain floats once so every artist below receives true scalars
# instead of relying on implicit array-to-scalar conversion.
x_q = float(x[idx, 0])
y_q = float(y[idx, 0])

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x, y, alpha=0.3, label="All points")
ax.scatter(x_q, y_q, color="red", label="Query point")

# Joint ε-box (square)
ax.add_patch(plt.Rectangle(
    (x_q - eps, y_q - eps),
    2 * eps, 2 * eps,
    edgecolor="red", fill=False, linewidth=2, linestyle='--', label="Joint ε-box"
))

# Marginal X-strip (vertical)
ax.axvline(x_q - eps, color='blue', linestyle='--', linewidth=1)
ax.axvline(x_q + eps, color='blue', linestyle='--', linewidth=1)
# axvspan shades the strip over the full height of the axes.
ax.axvspan(x_q - eps, x_q + eps, color='blue', alpha=0.1, label="X marginal strip")

# Marginal Y-strip (horizontal)
ax.axhline(y_q - eps, color='green', linestyle='--', linewidth=1)
ax.axhline(y_q + eps, color='green', linestyle='--', linewidth=1)
# axhspan is used instead of fill_between over the raw x samples: those
# samples are unsorted, so fill_between would build a polygon whose only
# vertical edges sit at x[0] and x[-1], shading just the interval between
# the first and last sample rather than the whole strip.
ax.axhspan(y_q - eps, y_q + eps, color='green', alpha=0.1, label="Y marginal strip")

ax.set_title("KSG Marginal Counts vs Joint Neighborhood")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
plt.tight_layout()
plt.show()
