In [None]:
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
import matplotlib as mpl
import scipy.stats as stats
import xarray as xr

from matplotlib.collections import LineCollection

import sys
sys.path.append("..")

from pivot import metric
from pivot.transform import PivotSpace
from pivot.plot import plot_grid

%load_ext autoreload
%autoreload 2

In [None]:
def fill_rectangle(x_lim, y_lim, n):
    """Return a regular grid of roughly `n` points covering a rectangle.

    Each axis is sampled with ``int(sqrt(n))`` evenly spaced values (end points
    included), so the grid holds ``int(sqrt(n))**2 <= n`` points.

    Parameters
    ----------
    x_lim, y_lim : pair of float
        ``(start, stop)`` of the rectangle along x and along y.
    n : number
        Approximate number of points wanted.

    Returns
    -------
    np.ndarray
        Shape ``(int(sqrt(n))**2, 2)``; rows are ``(x, y)`` with x varying fastest.
    """
    per_axis = int(np.sqrt(n))
    xs = np.linspace(*x_lim, per_axis)
    ys = np.linspace(*y_lim, per_axis)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.column_stack((grid_x.ravel(), grid_y.ravel()))

def fill_circle(radius, n):
    """Return the points of a regular grid that lie inside a centred disc.

    A grid of about `n` points is laid over the square ``[-radius, radius]^2``
    via `fill_rectangle`; only points whose distance from the origin is at most
    `radius` are kept.

    Returns an ``(m, 2)`` array of ``(x, y)`` rows.
    """
    square = fill_rectangle((-radius, radius), (-radius, radius), n)
    # Compare squared lengths against radius**2 to avoid a sqrt per point.
    squared_dist = (square ** 2).sum(axis=1)
    return square[squared_dist <= radius ** 2]

def move_points(points, delta_x, delta_y):
    """Return a translated copy of an array of 2-D points.

    Parameters
    ----------
    points : array_like, shape (m, k) with k >= 2
        Points as rows. Only the first two columns (x, y) are shifted; any
        further columns are copied unchanged.
    delta_x, delta_y : scalar or array broadcastable to (m,)
        Offsets added to the x and y columns.

    Returns
    -------
    np.ndarray
        A new array; `points` itself is never modified. The dtype is promoted
        as needed, e.g. integer points moved by a float offset become float.
    """
    arr = np.asarray(points)
    # Promote up front: an in-place `+=` on an integer array with a float offset
    # raises a casting error instead of producing float coordinates.
    dtype = np.result_type(arr, delta_x, delta_y)
    moved = np.array(arr, dtype=dtype)  # np.array copies by default
    moved[:, 0] += delta_x
    moved[:, 1] += delta_y
    return moved

def circle_around_point(radius, point, n):
    """Return grid points filling a disc of `radius` centred on `point`.

    Builds the origin-centred disc with `fill_circle(radius, n)` and shifts it
    by the two offsets unpacked from ``point = (x, y)``.
    """
    return move_points(fill_circle(radius, n), *point)

In [None]:
# Pivot-space transform with metric.Euclid() and two pivots placed
# symmetrically on the x-axis at (-0.5, 0) and (0.5, 0).
# (No sample points are drawn here: the next cell defines `points` before any use.)
transform = PivotSpace(metric.Euclid(), pivots=np.asarray([[-0.5, 0], [0.5, 0]]))

In [None]:
# Sample the metric space. The uniform disc of radius 5 is the default; set
# SAMPLE_GAUSSIAN = True to use a seeded normal cloud centred at (0, 1) instead.
# Only the selected distribution is drawn, so no samples are generated and discarded.
N_SAMPLES = 10_000_000
SAMPLE_GAUSSIAN = False

if SAMPLE_GAUSSIAN:
    points = stats.multivariate_normal(mean=[0, 1]).rvs(N_SAMPLES, random_state=0)
else:
    points = fill_circle(5, N_SAMPLES)

# Map the samples into pivot space, then apply the transform's rectification.
points_piv = transform.transform_points(points)
points_rp = transform.rectify(points_piv)

def binspace(vals, bin_length=0.05):
    """Return equally spaced histogram bin edges spanning ``[min(vals), max(vals)]``.

    The range is split into ``floor(range / bin_length)`` equal bins (at least
    one). Each realised bin is therefore at least `bin_length` wide, and only
    slightly wider when the range is not an exact multiple of it. The exception
    is a range narrower than `bin_length`, which gives a single bin of that range.

    Parameters
    ----------
    vals : array_like
        Values to be histogrammed; should be non-empty and not all equal.
    bin_length : float, optional
        Target bin width.

    Returns
    -------
    np.ndarray
        ``n_bins + 1`` edges; the first equals ``min(vals)`` and the last ``max(vals)``.
    """
    vals = np.asarray(vals)
    lo, hi = vals.min(), vals.max()
    # n bins need n + 1 edges. Clamp to one bin so a narrow range still yields a
    # valid edge array instead of an empty or single-edge one.
    n_bins = max(int((hi - lo) / bin_length), 1)
    # No padding beyond `hi` is needed: np.histogram / np.histogram2d treat the
    # last bin as closed on the right, so the maximum value is still counted.
    return np.linspace(lo, hi, n_bins + 1)
    
bins = binspace(points_rp[:, 0]), binspace(points_rp[:, 1])

# Normalised 2-D histogram of the samples in rectified pivot coordinates.
# For a heavy-tailed density, pass norm=mpl.colors.LogNorm() to hist2d.
plt.hist2d(*points_rp.T, bins=bins, density=True)
plt.colorbar(label="density")
plt.title("point density in rectified pivot space")
plt.xlabel("p1")
plt.ylabel("p2")
plt.gca().set_aspect('equal')
plt.show()

In [None]:
# Raw (unnormalised) counts on the same `bins` as the hist2d plot; the edges are
# kept so each bin can be labelled by its centre.
counts, xedges, yedges = np.histogram2d(points_rp[:, 0], points_rp[:, 1], bins=bins)

def bin_centers(edges):
    """Return the midpoint of each bin given its ``n + 1`` edges.

    Each centre is the mean of its two neighbouring edges. This is correct for
    non-uniform edges too, unlike offsetting by the first bin's width. Fewer than
    two edges yield an empty result.
    """
    edges = np.asarray(edges)
    return (edges[:-1] + edges[1:]) / 2

# Label the histogram: axis 0 (x samples) becomes "p1", axis 1 (y samples) "p2",
# each indexed by bin centre so values can be looked up by pivot coordinate.
# Note: the values are raw counts, not a normalised density.
density_piv = xr.DataArray(
    counts,
    coords={"p1": bin_centers(xedges), "p2": bin_centers(yedges)},
    dims=["p1", "p2"],
)

In [None]:
# Regular 97 x 97 grid over the square later used as the plot limits.
extent = -3, 3
lims = np.linspace(extent[0], extent[1], 97)
xx, yy = np.meshgrid(lims, lims)
grid = np.vstack((xx.ravel(), yy.ravel())).T

# Push every grid point through the same transform + rectification as the
# samples, then read off the count of the nearest pivot-space bin.
pgrid = transform.transform_points(grid)
tr = transform.rectify(pgrid)
density = np.asarray(
    density_piv.sel(
        p1=xr.DataArray(tr[:, 0], dims="points"),
        p2=xr.DataArray(tr[:, 1], dims="points"),
        method="nearest",
    )
)

def plot_config():
    """Apply the shared decoration to the current axes.

    Marks the pivots of the global `transform` with white crosses, labels the
    axes "x"/"y", clips both axes to the global `extent`, forces an equal aspect
    ratio and tightens the figure layout.
    """
    ax = plt.gca()
    ax.plot(*transform.pivots.T, "wx")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim(*extent)
    ax.set_ylim(*extent)
    ax.set_aspect('equal')
    ax.figure.tight_layout()
    

# Left: histogram of the raw samples. Right: each grid point coloured by the
# count of its nearest pivot-space bin, i.e. the pivot-space density pulled back.
plt.subplot(1, 2, 1)
plt.title("point density in metric space")
plt.hist2d(*points.T, bins=29)
plot_config()

plt.subplot(1, 2, 2)
plt.title("point density in pivot space,\nreprojected to the metric space")
# `density` was sampled on meshgrid(lims, lims) with ascending lims, so row 0 is
# the smallest y. origin="lower" draws row 0 at the bottom so the image agrees
# with the y-axis of the left panel (the default "upper" would flip it).
plt.imshow(
    density.reshape(len(lims), len(lims)),
    extent=[*extent, *extent],
    origin="lower",
)
plot_config()


In [None]:
# Deform a regular n x n grid of the metric space into pivot space to show how
# the transform bends coordinate lines.
n = 20
xx, yy = np.meshgrid(np.linspace(-2, 0.5, n), np.linspace(-1, 0, n))

# NOTE(review): this rebinds the global `points` that the density cells sampled;
# re-running the two-panel density cell after this one would histogram these
# n*n grid points instead. Consider a distinct name if nothing downstream needs it.
points = np.array([xx.flatten(), yy.flatten()]).T
p_t = transform.transform_points(points)
xx_t = p_t[:, 0].reshape(xx.shape)
yy_t = p_t[:, 1].reshape(yy.shape)

cmap_1 = mpl.colormaps['plasma']
config = dict(
    ycolor=cmap_1(np.linspace(0.2, 1, xx.shape[0])),
    xcolor="lightblue",
)

plt.subplot(1, 2, 1)
plt.title(r"metric space $(\mathbb{R}^2, " + transform.metric.name + ")$")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plot_grid(xx, yy, **config)
plt.plot(*transform.pivots.T, "x", color="C1", label="pivots")
plt.ylim(-1.1, 1)
plt.legend()
plt.gca().set_aspect('equal')

plt.subplot(1, 2, 2)
# Raw strings: "\m" and "\P" are invalid escape sequences in ordinary string
# literals (DeprecationWarning, and SyntaxWarning from Python 3.12).
plt.title(r"pivot space $(\mathbb{R}_+)^2$")
plt.xlabel(r"$\Phi_1$")
plt.ylabel(r"$\Phi_2$")
plot_grid(xx_t, yy_t, **config)
plt.gca().set_aspect('equal')

plt.tight_layout()
plt.show()