In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os

import plotly.graph_objects as go
import plotly.io as pio

from plotly.subplots import make_subplots
from plotly3d.plot import scatter, trajectories

In [29]:
# V2 loader: gather ground-truth and learned geodesics into
# all_data[data_name][method] -> {array name: np.ndarray}.
rpath = './geodesic_same_starting/'
methods = ['gt', 'no_density', 'ours', 'density']
datas = ['ellipsoid', 'torus', 'saddle', 'hemisphere']

all_data = {data_name: {} for data_name in datas}
# Sorted so the load order does not depend on the filesystem's listing order.
for dir_name in sorted(os.listdir(rpath)):
    if dir_name == 'data_gt':
        # Ground truth: one .npz per dataset, stored under the 'gt' method key.
        for data_name in datas:
            # np.load on an .npz returns an NpzFile that keeps the archive open;
            # the context manager closes it once every array has been read.
            with np.load(os.path.join(rpath, dir_name, f"{data_name}_none_0.0.npz")) as cur_data:
                all_data[data_name]['gt'] = {
                    'x': cur_data['X'],
                    'x_gt': cur_data['X_ground_truth'],
                    'x0': cur_data['start_points'],
                    'x1': cur_data['end_points'],
                    # stored as [n, t, dim]; transpose to [t, n, dim]
                    'geodesics': np.transpose(cur_data['geodesics'], (1, 0, 2)),
                }
    elif dir_name in methods and dir_name != 'gt':
        # Learned methods: <method>/<dataset>/trajectories.npy, laid out as [t, n, dim].
        # A directory literally named 'gt' is skipped: the 'gt' key is reserved for the
        # ground truth read from 'data_gt' and must not be overwritten by it.
        for data_name in datas:
            all_data[data_name][dir_name] = {
                'geodesics': np.load(os.path.join(rpath, dir_name, data_name, 'trajectories.npy')),
            }
    # 'start_end' and any unrelated directories are not handled in this loop.

# Start / end points for the non-gt methods, stored under the 'ours' entry.
for dname in datas:
    with np.load(os.path.join(rpath, 'start_end', f"{dname}.npz")) as cur_data:
        all_data[dname]['ours']['x0'] = cur_data['start_points']
        all_data[dname]['ours']['x1'] = cur_data['end_points']

In [2]:
# # V1 loader
# rpath = './geodesic/geodesics_points/'
# methods = ['density', 'no_density', 'ours']
# datas = ['ellipsoid', 'torus', 'saddle', 'hemisphere']

# all_data = {}
# for dir_name in os.listdir(rpath):
#     data_name = dir_name.split('_')[0]
#     if data_name not in datas:
#         continue
#     if data_name not in all_data:
#         all_data[data_name] = {}
#     for method in methods:
#         cur_datafile = np.load(f'{rpath}/{dir_name}/{method}.npz')
#         all_data[data_name][method] = {}
#         #print(cur_datafile.files)
#         all_data[data_name][method]['x0'] = cur_datafile['x0']
#         all_data[data_name][method]['x1'] = cur_datafile['x1']
#         all_data[data_name][method]['xhat'] = cur_datafile['xhat'] # trajectory points [t, n, dim]

#         all_data[data_name][method]['x'] = np.load(f'./gt/{dir_name}.npz')['X']
#         all_data[data_name][method]['geodesics'] = np.load(f'./gt/{dir_name}.npz')['geodesics'] # [n, t, dim]
#         all_data[data_name][method]['geodesics'] = np.transpose(all_data[data_name][method]['geodesics'], (1, 0, 2))

In [30]:
# Sanity check: print the shape of every array loaded for each (dataset, method) pair.
for data_name, per_method in all_data.items():
    for method in methods:
        for key, arr in per_method[method].items():
            print(data_name, method, key, arr.shape)

ellipsoid gt x (3000, 3)
ellipsoid gt x_gt (3000, 3)
ellipsoid gt x0 (10, 3)
ellipsoid gt x1 (10, 3)
ellipsoid gt geodesics (279, 10, 3)
ellipsoid no_density geodesics (100, 10, 3)
ellipsoid ours geodesics (100, 10, 3)
ellipsoid ours x0 (10, 3)
ellipsoid ours x1 (10, 3)
ellipsoid density geodesics (100, 10, 3)
torus gt x (3000, 3)
torus gt x_gt (3000, 3)
torus gt x0 (10, 3)
torus gt x1 (10, 3)
torus gt geodesics (199, 10, 3)
torus no_density geodesics (100, 10, 3)
torus ours geodesics (100, 10, 3)
torus ours x0 (10, 3)
torus ours x1 (10, 3)
torus density geodesics (100, 10, 3)
saddle gt x (3000, 3)
saddle gt x_gt (3000, 3)
saddle gt x0 (10, 3)
saddle gt x1 (10, 3)
saddle gt geodesics (687, 10, 3)
saddle no_density geodesics (100, 10, 3)
saddle ours geodesics (100, 10, 3)
saddle ours x0 (10, 3)
saddle ours x1 (10, 3)
saddle density geodesics (100, 10, 3)
hemisphere gt x (3056, 3)
hemisphere gt x_gt (3056, 3)
hemisphere gt x0 (10, 3)
hemisphere gt x1 (10, 3)
hemisphere gt geodesics (3000

In [31]:
def _torus():
    # Parameters for the torus
    R = 2.0
    r = 1.0

    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, 2 * np.pi, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the torus
    x = (R + r * np.cos(theta)) * np.cos(phi)
    y = (R + r * np.cos(theta)) * np.sin(phi)
    z = r * np.sin(theta)
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _ellipsoid(a=3, b=2, c=1):
    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the ellipsoid
    x = a * np.sin(theta) * np.cos(phi)
    y = b * np.sin(theta) * np.sin(phi)
    z = c * np.cos(theta)
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _saddle(a=1,b=1, ylim=2, xlim=2):
    # Create a meshgrid for phi and theta
    x = np.linspace(xlim, -xlim, 100)
    y = np.linspace(ylim, -ylim, 100)
    x, y = np.meshgrid(x, y)

    # Parametric equations for the saddle
    z = a * x**2 - b * y**2
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _hemisphere(r=1):
    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi/2, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the hemisphere
    x = np.sin(theta) * np.cos(phi) * r
    y = np.sin(theta) * np.sin(phi) * r
    z = np.cos(theta) * r
    print(x.shape, y.shape, z.shape)

    return x,y,z
    


### Use matplotlib

### Using Plotly to visualize the data

In [34]:
# Render one interactive 3-D figure per (method, dataset): the manifold surface, the selected
# geodesics, and their start (green) / end (red) points. Each figure is written to
# geovis_{method}_{data_name}.html in the working directory.
datas = ['ellipsoid', 'torus', 'saddle', 'hemisphere']
methods = ['no_density', 'ours', 'gt', 'density']

traj_index = np.arange(10)  # geodesics drawn by default, indexed along the [t, n, dim] 'n' axis
# Optional per-dataset override of traj_index, e.g. {'saddle': [1, 3, 9]}.
traj_subsets = {}
alpha = 0.3                 # opacity of the manifold surface

# One explicit colour per trajectory id (Plotly's default qualitative palette). Previously bare
# integers were passed as `line.color`; without a colorscale Plotly does not read a scalar number
# as a colour, so any per-line colouring was accidental.
traj_colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
               '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']

# Explicit dispatch instead of looking the surface functions up by name in globals().
surface_fns = {
    'ellipsoid': _ellipsoid,
    'torus': _torus,
    'saddle': _saddle,
    'hemisphere': _hemisphere,
}
surface_kwargs = {'saddle': dict(xlim=1, ylim=1)}  # plot only the [-1, 1] x [-1, 1] saddle patch

# Surfaces depend only on the dataset, so each is built once rather than once per method.
surfaces = {name: surface_fns[name](**surface_kwargs.get(name, {})) for name in datas}

# Hide ticks, grid, axis lines and axis titles so only the geometry is visible.
hidden_axis = dict(showgrid=False, zeroline=False, showline=False, showticklabels=False,
                   ticks="", title='', backgroundcolor='white')


def _endpoints(data_name, method):
    """Return (x0, x1): the gt file's own endpoints for 'gt', else those stored under 'ours'."""
    source = all_data[data_name]['gt' if method == 'gt' else 'ours']
    return source['x0'], source['x1']


def _geodesic_figure(data_name, method):
    """Build the single-scene figure for one (dataset, method) pair.

    Geodesics are read from all_data[data_name][method]['geodesics'] and indexed as [t, n, dim].
    """
    fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])
    idx = np.asarray(traj_subsets.get(data_name, traj_index))

    # Manifold surface.
    x, y, z = surfaces[data_name]
    fig.add_trace(go.Surface(x=x, y=y, z=z, opacity=alpha, showscale=False, name='surface'),
                  row=1, col=1)

    # Geodesics, one line trace each; colour keyed on trajectory id so it is stable across figures.
    geodesics = all_data[data_name][method]['geodesics'][:, idx, :]
    for k, traj_id in enumerate(idx):
        geodesic = geodesics[:, k, :]
        fig.add_trace(go.Scatter3d(x=geodesic[:, 0], y=geodesic[:, 1], z=geodesic[:, 2],
                                   mode='lines', name=f'geodesic {traj_id}',
                                   line=dict(color=traj_colors[traj_id % len(traj_colors)], width=15),
                                   opacity=1.0),
                      row=1, col=1)

    # Start (green) and end (red) points of the drawn geodesics.
    x0, x1 = _endpoints(data_name, method)
    for points, color, label in ((x0, 'green', 'start'), (x1, 'red', 'end')):
        fig.add_trace(go.Scatter3d(x=points[idx, 0], y=points[idx, 1], z=points[idx, 2],
                                   mode='markers', name=label,
                                   marker=dict(size=12, color=color, opacity=1,
                                               line=dict(width=5, color='black'))),
                      row=1, col=1)

    if data_name == 'saddle':
        # Slightly rotate the default view around the z axis.
        fig.update_layout(scene_camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=1.5, y=0.8, z=1.25),
        ))

    fig.update_scenes(dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis,
                           bgcolor='white'),
                      row=1, col=1)
    fig.update_layout(title='', paper_bgcolor='white', plot_bgcolor='white', showlegend=True)
    return fig


for method in methods:
    for data_name in datas:
        fig = _geodesic_figure(data_name, method)
        # Static export alternative:
        # pio.write_image(fig, f"./geovisV5/geovis_{method}_{data_name}.pdf", format='pdf',
        #                 height=800, width=800, scale=2)
        fig.write_html(f"geovis_{method}_{data_name}.html")



(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
(100, 100) (100, 100) (100, 100)
