In [1]:
import numpy as np
import torch
from autometric.geodesics import DjikstraGeodesic
from autometric.datasets import Hemisphere, Ellipsoid, Saddle, Torus
import pathlib
import pickle
from tqdm import tqdm

INFO: Using pytorch backend


In [2]:
# Input directory: one <name>.npz per toy manifold, holding the X, start_points
# and end_points arrays consumed by the geodesic cells below.
data_path = "../../data/neurips_results/toy/visualize2/"
# Toy manifolds to process; each name is read from <data_path>/<name>.npz.
data_names = [
    'hemisphere',
    'ellipsoid',
    'saddle',
    'torus',
]

In [3]:
# Number of evenly spaced parameter values t in [0, 1] at which each geodesic is sampled.
num_points_per_geodesic = 3_000

In [4]:
# Output directory for the .npz files augmented with ground-truth geodesics.
out_path = '../../data/neurips_results/toy/visualize_gt/'
# Create it together with any missing parents; a no-op if it already exists.
pathlib.Path(out_path).mkdir(exist_ok=True, parents=True)

In [5]:
# Ground-truth geodesics for every toy manifold listed in `data_names`.
# For each <name>.npz in `data_path`: recompute geodesics between the stored
# start/end points with the matching autometric dataset class, pad them to a
# common length, and write the original arrays plus the geodesics to
# `out_path`/<name>.npz. Replaces four copy-pasted per-dataset cells.

# autometric dataset class used to compute geodesics for each dataset name.
DATASET_CLASSES = {
    'hemisphere': Hemisphere,
    'ellipsoid': Ellipsoid,
    'saddle': Saddle,
    'torus': Torus,
}
# `num_points` passed to each dataset constructor (the value the original cells used).
# The constructed dataset's X is overwritten with the saved points before computing geodesics.
NUM_MANIFOLD_POINTS = 3000


def _to_numpy(x):
    """Return `x` as a NumPy array, detaching torch tensors and moving them to CPU first."""
    if isinstance(x, torch.Tensor):
        # .detach() in case the tensor tracks gradients; .cpu() in case it lives on a GPU.
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _pad_geodesics(geodesics):
    """Right-pad geodesics to a common length by repeating each one's last point.

    Parameters
    ----------
    geodesics : list of np.ndarray
        Each array holds points along axis 0, i.e. shape ``(T_i, ...)``.

    Returns
    -------
    list of np.ndarray
        Arrays all of length ``max_i T_i`` (empty list for empty input), so the
        list converts to one regular array when saved.

    Raises
    ------
    ValueError
        If any geodesic contains no points (there is no endpoint to repeat).
    """
    max_len = max((len(g) for g in geodesics), default=0)
    padded = []
    for i, g in enumerate(geodesics):
        if len(g) == 0:
            raise ValueError(f'geodesic {i} is empty; cannot pad it with its last point')
        # g[-1:] keeps the leading axis, so repeating it yields (max_len - T_i, ...) rows.
        padded.append(np.vstack([g, np.repeat(g[-1:], max_len - len(g), axis=0)]))
    return padded


def compute_ground_truth_geodesics(data_name, dataset_cls, in_dir, out_dir,
                                   n_geodesic_points, num_points=NUM_MANIFOLD_POINTS):
    """Compute, pad and save ground-truth geodesics for one toy dataset.

    Reads ``in_dir/<data_name>.npz``, replaces the dataset's ``X`` with the saved
    ``X`` array, computes geodesics between ``start_points`` and ``end_points`` at
    ``n_geodesic_points`` evenly spaced ts in [0, 1], and writes every original
    array plus ``geodesics`` and ``geodesic_lengths`` to ``out_dir/<data_name>.npz``.

    Parameters
    ----------
    data_name : str
        Base name of the input/output ``.npz`` file.
    dataset_cls : type
        autometric dataset class (e.g. ``Hemisphere``) providing ``geodesics``.
    in_dir, out_dir : str or os.PathLike
        Input and output directories.
    n_geodesic_points : int
        Number of ts passed to ``geodesics``.
    num_points : int, optional
        ``num_points`` forwarded to ``dataset_cls``.

    Returns
    -------
    pathlib.Path
        Path of the written file.

    Raises
    ------
    KeyError
        If the input file lacks ``X``, ``start_points`` or ``end_points``.
    """
    in_file = pathlib.Path(in_dir) / f'{data_name}.npz'
    out_file = pathlib.Path(out_dir) / f'{data_name}.npz'

    # The context manager closes the NpzFile handle; copy every array out before it closes.
    with np.load(in_file) as data:
        data_dict = {key: data[key] for key in data.files}

    manifold = dataset_cls(num_points=num_points)
    manifold.X = torch.tensor(data_dict['X'], dtype=torch.float32)
    start_points = torch.tensor(data_dict['start_points'], dtype=torch.float32)
    end_points = torch.tensor(data_dict['end_points'], dtype=torch.float32)
    ts = np.linspace(0, 1, n_geodesic_points)

    geodesics, lengths = manifold.geodesics(start_points, end_points, ts)

    data_dict['geodesics'] = _pad_geodesics([_to_numpy(g) for g in geodesics])
    data_dict['geodesic_lengths'] = _to_numpy(lengths)
    np.savez(out_file, **data_dict)
    return out_file


gt_files = {}
for data_name in tqdm(data_names, desc='ground-truth geodesics'):
    gt_files[data_name] = compute_ground_truth_geodesics(
        data_name,
        DATASET_CLASSES[data_name],
        data_path,
        out_path,
        num_points_per_geodesic,
    )
gt_files