Skip to content

Commit

Permalink
bias_variance script; making metric.m and metric.shape settable.
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongu committed Mar 15, 2023
1 parent 78b5414 commit 2a8fea5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
49 changes: 49 additions & 0 deletions diagnostics/bias_variance.py
@@ -0,0 +1,49 @@
import torch
import numpy as np
from tests.metrics import _list_of_metrics
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm


m, n, runs = 10000, 300, 10
mvals = np.logspace(2, np.log10(m), 31).round().astype(int)
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32

data_x = torch.randn(m, n, device=device, dtype=dtype)
data_y = data_x + 2 * torch.randn(m, n, device=device, dtype=dtype) / np.sqrt(n)

#%%

save_dir = Path(__file__).parent / "bias_variance"
save_dir.mkdir(exist_ok=True)

#%%
metrics = [m["metric"] for m in _list_of_metrics]
for i, metric in enumerate(metrics):
if i == 0: continue
print("Starting", metric.string_id())
lengths = torch.zeros(len(mvals), runs)
for i, sub_m in tqdm(enumerate(mvals), desc=metric.string_id(), total=len(mvals)):
metric.m = sub_m
for j in range(runs):
idx = torch.randperm(m)[:sub_m]
sub_x, sub_y = data_x[idx, :], data_y[idx, :]
try:
lengths[i, j] = metric.length(*map(metric.neural_data_to_point, [sub_x, sub_y]))
except:
lengths[i, j] = np.nan


lengths = lengths.detach().cpu().numpy()
plt.figure()
mu, sigma = np.nanmean(lengths, axis=-1), np.nanstd(lengths, axis=-1)
plt.fill_between(mvals, mu-3*sigma, mu+3*sigma, color=(0., 0., 0., 0.25))
plt.plot(mvals, mu, color=(0., 0., 0.))
plt.xscale('log')
plt.xlabel('m')
plt.ylabel('length')
plt.title(metric.string_id())
plt.savefig(save_dir / (metric.string_id() + ".svg"))
plt.show()
11 changes: 10 additions & 1 deletion src/repsim/geometry/manifold.py
Expand Up @@ -27,9 +27,18 @@ class LengthSpace(abc.ABC):
"""
def __init__(self, *, dim: int, shape: tuple):
self.dim = dim
self.shape = shape
self._shape = shape
self.ambient = prod(shape)

@property
def shape(self):
return self._shape

@shape.setter
def shape(self, new_shape):
self._shape = new_shape
self.ambient = prod(new_shape)

def project(self, pt: Point) -> Point:
"""Project a point from the ambient space onto the manifold.
Expand Down

0 comments on commit 2a8fea5

Please sign in to comment.