In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
rpath = './geodesic/geodesics_points/'
gt_path = './gt/'
methods = ['density', 'no_density', 'ours']
datas = ['ellipsoid', 'torus', 'saddle', 'hemisphere']

# all_data[data_name][method] -> dict with keys (in this order):
#   'x0', 'x1'  : arrays stored in {rpath}/{dir_name}/{method}.npz
#   'xhat'      : trajectory points [t, n, dim], from the same file
#   'x'         : 'X' array from {gt_path}/{dir_name}.npz
#   'geodesics' : 'geodesics' from the same gt file, transposed [n, t, dim] -> [t, n, dim]
all_data = {}
# sorted() makes the outcome deterministic when several directories share the
# same data_name prefix (the last one processed overwrites the earlier ones).
for dir_name in sorted(os.listdir(rpath)):
    data_name = dir_name.split('_')[0]
    if data_name not in datas:
        continue

    # The ground truth depends only on the directory, so read it once instead of
    # twice per method; `with` closes the underlying zip file handle.
    with np.load(os.path.join(gt_path, f'{dir_name}.npz')) as gt_file:
        gt_x = gt_file['X']
        gt_geodesics = np.transpose(gt_file['geodesics'], (1, 0, 2))  # [n, t, dim] -> [t, n, dim]

    per_method = all_data.setdefault(data_name, {})
    for method in methods:
        with np.load(os.path.join(rpath, dir_name, f'{method}.npz')) as cur_datafile:
            per_method[method] = {
                'x0': cur_datafile['x0'],
                'x1': cur_datafile['x1'],
                'xhat': cur_datafile['xhat'],  # trajectory points [t, n, dim]
                # Copies keep each method's arrays independent, as they were when
                # the ground truth was re-loaded for every method.
                'x': gt_x.copy(),
                'geodesics': gt_geodesics.copy(),
            }

In [3]:
# Print the array shape of every stored entry for each dataset / method pair.
for data_name, per_method in all_data.items():
    for method in methods:
        for key, value in per_method[method].items():
            print(data_name, method, key, value.shape)

hemisphere density x0 (20, 3)
hemisphere density x1 (20, 3)
hemisphere density xhat (100, 20, 3)
hemisphere density x (3056, 3)
hemisphere density geodesics (3000, 20, 3)
hemisphere no_density x0 (20, 3)
hemisphere no_density x1 (20, 3)
hemisphere no_density xhat (100, 20, 3)
hemisphere no_density x (3056, 3)
hemisphere no_density geodesics (3000, 20, 3)
hemisphere ours x0 (20, 3)
hemisphere ours x1 (20, 3)
hemisphere ours xhat (100, 20, 3)
hemisphere ours x (3056, 3)
hemisphere ours geodesics (3000, 20, 3)
ellipsoid density x0 (20, 3)
ellipsoid density x1 (20, 3)
ellipsoid density xhat (100, 20, 3)
ellipsoid density x (3000, 3)
ellipsoid density geodesics (586, 20, 3)
ellipsoid no_density x0 (20, 3)
ellipsoid no_density x1 (20, 3)
ellipsoid no_density xhat (100, 20, 3)
ellipsoid no_density x (3000, 3)
ellipsoid no_density geodesics (586, 20, 3)
ellipsoid ours x0 (20, 3)
ellipsoid ours x1 (20, 3)
ellipsoid ours xhat (100, 20, 3)
ellipsoid ours x (3000, 3)
ellipsoid ours geodesics (586,

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly3d.plot import scatter, trajectories

In [7]:
def _torus():
    # Parameters for the torus
    R = 2.0
    r = 1.0

    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, 2 * np.pi, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the torus
    x = (R + r * np.cos(theta)) * np.cos(phi)
    y = (R + r * np.cos(theta)) * np.sin(phi)
    z = r * np.sin(theta)
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _ellipsoid(a=3, b=2, c=1):
    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the ellipsoid
    x = a * np.sin(theta) * np.cos(phi)
    y = b * np.sin(theta) * np.sin(phi)
    z = c * np.cos(theta)
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _saddle(a=1, b = 1):
    # Create a meshgrid for phi and theta
    x = np.linspace(-2, 2, 100)
    y = np.linspace(-2, 2, 100)
    x, y = np.meshgrid(x, y)

    # Parametric equations for the saddle
    z = a * x**2 - b * y**2
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _hemisphere(r=1):
    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi/2, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the hemisphere
    x = np.sin(theta) * np.cos(phi) * r
    y = np.sin(theta) * np.sin(phi) * r
    z = np.cos(theta) * r
    print(x.shape, y.shape, z.shape)

    return x,y,z
    


In [8]:
# Render one standalone PDF per (method, dataset): the surface as a translucent
# backdrop, the first `traj_num` trajectories, and their start (green) / end (red) points.
datas = ['ellipsoid', 'torus', 'saddle', 'hemisphere']
# NOTE(review): this rebinds the `methods` list defined in the loading cell. 'gt' has
# no entry in all_data, so re-running the shape-printing cell after this one raises
# KeyError. Kept as-is in case later cells rely on the 4-element list.
methods = ['density', 'no_density', 'ours', 'gt']

rows = len(datas)
cols = len(methods)

traj_num = 5  # number of trajectories drawn per figure
alpha = 0.3   # opacity of the background surface

# Explicit lookup instead of globals()[f'_{data_name}'], so a missing surface
# helper fails with a KeyError that names the dataset.
surface_fns = {
    'ellipsoid': _ellipsoid,
    'torus': _torus,
    'saddle': _saddle,
    'hemisphere': _hemisphere,
}


def _endpoint_trace(points, color, n):
    """Return a Scatter3d marker trace for the first `n` rows of `points` (columns x, y, z)."""
    return go.Scatter3d(x=points[:n, 0], y=points[:n, 1], z=points[:n, 2],
                        mode='markers',
                        marker=dict(size=8, color=color, opacity=1,
                                    line=dict(width=1, color='black')))


for method in methods:
    for data_name in datas:
        fig = make_subplots(rows=1, cols=1,
                            specs=[[{'type': 'scatter3d'}]],
                            shared_xaxes=False, vertical_spacing=0.0001)

        # 'gt' has no run of its own. It borrows the 'density' entry, whose
        # 'geodesics' come from the shared gt file. Its x0/x1 are presumably the
        # same for every method — TODO confirm.
        result = all_data[data_name]['density' if method == 'gt' else method]

        # Translucent surface backdrop
        x, y, z = surface_fns[data_name]()
        fig.add_trace(go.Surface(x=x, y=y, z=z, opacity=alpha, showscale=False), row=1, col=1)

        # Trajectories are [t, n, dim]; keep the first traj_num curves
        if method == 'gt':
            geodesics = result['geodesics'][:, :traj_num, :]
            trajectories(geodesics, s=7, fig=fig, white_bkgrnd=True, ticks=False, title='',
                         row=1, col=1)
        else:
            trajectories(result['xhat'][:, :traj_num, :], color='black', s=7, fig=fig,
                         white_bkgrnd=True, ticks=False, title='', row=1, col=1)

        # Start (green) and end (red) points of the drawn trajectories
        fig.add_trace(_endpoint_trace(result['x0'], 'green', traj_num), row=1, col=1)
        fig.add_trace(_endpoint_trace(result['x1'], 'red', traj_num), row=1, col=1)

        fig.write_image(f"geovis_{method}_{data_name}.pdf", format='pdf')



(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
