In [10]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import torch

import plotly.graph_objects as go
from plotly3d.plot import scatter, trajectories
from plotly.subplots import make_subplots
import plotly.io as pio

In [11]:
rpath = './fm/'
data2traj = {}

# Load every `.npz` in `rpath`. The dataset key is the last underscore-separated
# token of the part of the filename before the first '.', e.g. `fm_torus.npz` -> 'torus'.
# Files are visited in sorted order so figure generation is reproducible across filesystems.
for filename in sorted(os.listdir(rpath)):
    if filename.endswith('.npz'):
        d_name = filename.split('.')[0].split('_')[-1]
        # NOTE(review): allow_pickle=True lets np.load execute arbitrary code from object
        # arrays — only point this at trusted files.
        # Materialise the arrays into a plain dict inside `with` so the underlying
        # NpzFile (and its open file handle) is closed immediately.
        with np.load(os.path.join(rpath, filename), allow_pickle=True) as npz:
            data = dict(npz)
        if d_name in data2traj:
            print(f'warning: {filename} overwrites previously loaded dataset {d_name!r}')
        data2traj[d_name] = data

In [12]:
# for d_name, data in data2traj.items():
#     print(d_name)
#     for k, v in data.items():
#         print(k, v.shape)

In [13]:
# fig = make_subplots(rows=rows, cols=cols,
#                     specs=[[{'type': 'scatter3d'}]*cols]*rows,
#                     shared_xaxes=False, horizontal_spacing=0.001, vertical_spacing=0.001)

In [14]:
def _torus():
    # Parameters for the torus
    R = 2.0
    r = 1.0

    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, 2 * np.pi, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the torus
    x = (R + r * np.cos(theta)) * np.cos(phi)
    y = (R + r * np.cos(theta)) * np.sin(phi)
    z = r * np.sin(theta)
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _ellipsoid(a=3, b=2, c=1):
    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the ellipsoid
    x = a * np.sin(theta) * np.cos(phi)
    y = b * np.sin(theta) * np.sin(phi)
    z = c * np.cos(theta)
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _saddle(a=1, b = 1):
    # Create a meshgrid for phi and theta
    x = np.linspace(-2, 2, 100)
    y = np.linspace(-2, 2, 100)
    x, y = np.meshgrid(x, y)

    # Parametric equations for the saddle
    z = a * x**2 - b * y**2
    print(x.shape, y.shape, z.shape)

    return x,y,z

def _hemisphere(r=1):
    # Create a meshgrid for phi and theta
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi/2, 100)
    phi, theta = np.meshgrid(phi, theta)

    # Parametric equations for the hemisphere
    x = np.sin(theta) * np.cos(phi) * r
    y = np.sin(theta) * np.sin(phi) * r
    z = np.cos(theta) * r
    print(x.shape, y.shape, z.shape)

    return x,y,z
    

In [18]:
# Explicit dataset-key -> surface-generator table. This replaces the former
# `globals()[f'_{dname}']` lookup, which broke silently on renames and raised a bare
# KeyError for datasets without a surface.
surface_fns = {
    'torus': _torus,
    'ellipsoid': _ellipsoid,
    'saddle': _saddle,
    'hemisphere': _hemisphere,
}

# Per-dataset camera eye positions; other datasets keep Plotly's default view.
camera_eyes = {
    'hemisphere': dict(x=0.1, y=-2.5, z=0.5),
    'ellipsoid': dict(x=-0.8, y=-2.5, z=0.5),
}

# Shared style for all three scene axes: axis lines only, no grid/ticks/titles, white panes.
axis_style = dict(showgrid=False, zeroline=True, showline=True, showticklabels=False,
                  ticks="", title='', backgroundcolor='white')

datas = data2traj.keys()  # e.g. datas = ['ellipsoid'] to render a single figure
for dname in datas:
    x0 = data2traj[dname]['x0']      # starting points
    x1 = data2traj[dname]['x1']      # ending points
    traj = data2traj[dname]['traj']  # trajectories, drawn by plotly3d's `trajectories`

    fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])

    # Translucent reference surface. Distinct names (sx, sy, sz) avoid clobbering data arrays.
    surface_fn = surface_fns.get(dname)
    if surface_fn is None:
        print(f'no surface function registered for {dname!r}; plotting trajectories only')
    else:
        sx, sy, sz = surface_fn()
        fig.add_trace(go.Surface(x=sx, y=sy, z=sz, opacity=0.3, showscale=False), row=1, col=1)

    trajectories(traj, s=1, fig=fig)

    # Starting points in green, ending points in red, both outlined in black.
    for pts, color in ((x0, 'green'), (x1, 'red')):
        fig.add_trace(go.Scatter3d(
            x=pts[:, 0], y=pts[:, 1], z=pts[:, 2], mode='markers',
            marker=dict(size=10, color=color, opacity=1, line=dict(color='black', width=1))),
            row=1, col=1)

    fig.update_scenes(xaxis=axis_style, yaxis=axis_style, zaxis=axis_style)
    fig.update_layout(title='', paper_bgcolor='white', plot_bgcolor='white', showlegend=False)

    if dname in camera_eyes:
        fig.update_layout(scene_camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=camera_eyes[dname],
        ))

    # Write next to the input data (rpath) rather than repeating the directory literal.
    pio.write_image(fig, os.path.join(rpath, f'fm_{dname}.pdf'), format='pdf',
                    height=1000, width=1000, scale=2)





(100, 100) (100, 100) (100, 100)


(100, 100) (100, 100) (100, 100)


(100, 100) (100, 100) (100, 100)


(100, 100) (100, 100) (100, 100)


In [16]:
import plotly.io as pio
