In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [3]:
# Figure-1 inputs: one pickle per panel, named like 'figure1a_vol_ellipsoid.pkl'.
# The panel key is the last '_'-separated token of the file stem (e.g. 'ellipsoid').
data_root = './exp_figures'
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort so that
# `data` (and therefore the subplot columns built from it) follows the
# figure1a -> figure1b -> figure1c naming, and skip anything that is not a .pkl file
# (e.g. checkpoint directories) instead of crashing on it.
data_files = sorted(f for f in os.listdir(data_root) if f.endswith('.pkl'))
print(data_files)
data = {}

for file in data_files:
    # With allow_pickle=True, np.load unpickles these files, which can execute
    # arbitrary code -- only point data_root at trusted files.
    fdata = np.load(os.path.join(data_root, file), allow_pickle=True)
    print(file)
    name = file.split('.pkl')[0].split('_')[-1]
    data[name] = fdata


['figure1a_vol_ellipsoid.pkl', 'figure1b_vol_saddle.pkl', 'figure1c_vol_torus.pkl']
figure1a_vol_ellipsoid.pkl
figure1b_vol_saddle.pkl
figure1c_vol_torus.pkl


In [4]:
# Peek at one loaded panel record; it holds the point cloud 'x', per-point 'vols' and a 'title'.
ellipsoid_record = data['ellipsoid']
ellipsoid_record

{'x': array([[-0.19343406,  0.03560936, -0.9977603 ],
        [-0.73281175,  0.6420481 , -0.91502756],
        [ 0.2894837 ,  1.9385066 ,  0.22635548],
        ...,
        [ 0.150479  ,  1.9120897 ,  0.28889826],
        [-0.7659232 , -0.9770324 ,  0.83436793],
        [ 2.8261988 , -0.44137612,  0.25260222]], dtype=float32),
 'vols': array([3.5746875 , 0.8745672 , 1.2594423 , ..., 1.5261447 , 2.1576943 ,
        0.44297624], dtype=float32),
 'title': 'Ellipsoid'}

In [27]:
nums = len(data)
assert nums > 0, "no datasets loaded into `data`; check data_root"

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# One 3D scatter panel per loaded dataset, laid out left to right in `data` order.
fig = make_subplots(rows=1, cols=nums,
                    specs=[[{'type': 'scatter3d'}] * nums],
                    horizontal_spacing=0.001)

# Shared colorbar on the right: only 'low'/'high' labels at the ends of the [0, 1] range.
volume_colorbar = dict(title='Volume', tickvals=[0, 1], ticktext=['low', 'high'],
                       tickfont=dict(family='Helvetica', size=12))

for i, key in enumerate(data.keys()):
    x = data[key]['x']
    vols = data[key]['vols']
    # Min-max scale to [0, 1] so every panel uses the same color range. A constant
    # array would otherwise divide by zero and yield NaN colors.
    vol_range = np.max(vols) - np.min(vols)
    if vol_range > 0:
        vols = (vols - np.min(vols)) / vol_range
    else:
        vols = np.zeros_like(vols)

    is_last = i == nums - 1
    marker = dict(size=4, color=vols, colorscale='Viridis', opacity=0.7,
                  cmin=0, cmax=1, showscale=is_last)
    if is_last:
        # All panels share the [0, 1] scale, so a single colorbar stands for all of
        # them instead of repeating an identical one for every trace.
        marker['colorbar'] = volume_colorbar

    fig.add_trace(
        go.Scatter3d(x=x[:, 0], y=x[:, 1], z=x[:, 2], mode='markers',
                     marker=marker, name=""),
        row=1, col=i + 1,
    )

# Strip every axis decoration (grid, ticks, labels, titles, pane color) from all
# 3D scenes in one call so only the point clouds on white remain.
hidden_axis = dict(showgrid=False, zeroline=False, showline=False, showticklabels=False,
                   ticks="", title="", backgroundcolor='white')
fig.update_scenes(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis, bgcolor='white')

fig.update_layout(
    height=600,
    width=600 * nums,
    paper_bgcolor='white',  # color of the whole background
    plot_bgcolor='white',   # color of the plotting area
    showlegend=False,       # traces are unnamed; hide legend entries
)
fig.show()

# Save the static figure to PDF.
fig.write_image("volumes.pdf")

In [28]:
# Figure-2 inputs are read from the working directory, not from data_root.
# allow_pickle=True lets np.load unpickle the file -- only use it on trusted files.
off_mfd_path = 'figure2_off_manifolder.pkl'
off_mfd_data = np.load(off_mfd_path, allow_pickle=True)
off_mfd_data.keys()

dict_keys(['x', 'prob', 'pos_neg', 'ofm_emb'])

In [30]:
# Report the shape and dtype of every array stored in the figure-2 record.
for field in off_mfd_data:
    arr = off_mfd_data[field]
    print(field, arr.shape, arr.dtype)

x (5500, 3) float32
prob (5500,) float32
pos_neg (5500,) float64
ofm_emb (5500, 3) float32


In [63]:
# Figure 2a: raw 3D points, split into the two classes encoded by `pos_neg`.
x = off_mfd_data['x']

# Min-max scale the scores to [0, 1]; later cells color points by `prob`.
# A constant score array would otherwise divide by zero and yield NaN colors.
prob = off_mfd_data['prob']
prob_range = np.max(prob) - np.min(prob)
prob = (prob - np.min(prob)) / prob_range if prob_range > 0 else np.zeros_like(prob)

# `np.bool` was deprecated in NumPy 1.20 and removed in 1.24 (AttributeError);
# the builtin `bool` is the drop-in replacement. Any non-zero label becomes True.
# NOTE(review): assumes pos_neg is encoded as 0/1 -- a -1/+1 encoding would mark
# every point "positive"; confirm against the data producer.
pos_neg = off_mfd_data['pos_neg']
mask = pos_neg.astype(bool)

fig = go.Figure()
# Positive points in blue first, then negative points in red (constant colors,
# so no colorscale is needed).
for point_mask, color, label in ((mask, 'blue', "positive"), (~mask, 'red', "negative")):
    fig.add_trace(go.Scatter3d(x=x[point_mask, 0], y=x[point_mask, 1], z=x[point_mask, 2],
                               mode='markers',
                               marker=dict(size=2, color=color, opacity=0.5),
                               name=label))

# Hide all axis decoration so only the point clouds on white remain.
hidden_axis = dict(showgrid=False, zeroline=False, showline=False, showticklabels=False,
                   ticks="", title="", backgroundcolor='white')
fig.update_scenes(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis, bgcolor='white')

# Helvetica font, square 600x600 canvas, zero margins for a tight export.
fig.update_layout(
    font=dict(family='Helvetica', size=20),
    height=600,
    width=600,
    margin=dict(l=0, r=0, b=0, t=0))


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations



In [64]:

# Export the positive/negative point-cloud figure built in the previous cell.
off_manifolder_pdf = "off_manifolder.pdf"
fig.write_image(off_manifolder_pdf)

In [65]:
# Figure 2b: the same raw points `x`, colored by their min-max scaled score `prob`.
axis_off = {
    'showgrid': False, 'zeroline': False, 'showline': False, 'showticklabels': False,
    'ticks': "", 'title': "", 'backgroundcolor': 'white',
}
scatter = go.Scatter3d(
    x=x[:, 0], y=x[:, 1], z=x[:, 2],
    mode='markers',
    marker={'size': 2, 'color': prob, 'colorscale': 'Viridis', 'opacity': 0.5},
    name="positive",
)
fig = go.Figure(data=[scatter])

# Same hidden-axis styling on each of x/y/z, white scene background.
scene_style = {axis + 'axis': dict(axis_off) for axis in 'xyz'}
scene_style['bgcolor'] = 'white'
fig.update_scenes(scene_style)

# Helvetica font, square 600x600 canvas, zero margins.
fig.update_layout(
    font={'family': 'Helvetica', 'size': 12},
    height=600, width=600,
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
)

In [58]:

# Export the score-colored point cloud.
# NOTE(review): this cell last ran as In [58], before the plotting cell above (In [65]);
# re-run it after that cell so the saved PDF matches the figure shown.
off_manifolder_prob_pdf = "off_manifolder_prob.pdf"
fig.write_image(off_manifolder_prob_pdf)

In [72]:
# Figure 2c: the `ofm_emb` coordinates, colored by the same scaled score `prob`.
ofm_emb = off_mfd_data['ofm_emb']

emb_axis_style = dict(showgrid=False, zeroline=False, showline=False, showticklabels=False,
                      ticks="", title="", backgroundcolor='white')

fig = go.Figure()
fig.add_trace(go.Scatter3d(
    x=ofm_emb[:, 0], y=ofm_emb[:, 1], z=ofm_emb[:, 2],
    mode='markers',
    marker=dict(size=2, color=prob, colorscale='Viridis', opacity=0.5),
    name="positive",
))

# Hide all axis decoration on a white scene.
fig.update_scenes(
    xaxis=dict(emb_axis_style),
    yaxis=dict(emb_axis_style),
    zaxis=dict(emb_axis_style),
    bgcolor='white',
)

# Half-length 'Scores' colorbar on the right, with ticks only at 0 and 1.
fig.update_traces(marker=dict(colorbar=dict(
    title='Scores',
    tickvals=[0, 1],
    tickfont=dict(family='Helvetica', size=12),
    len=0.5,
)))

# Helvetica font, square 600x600 canvas, zero margins, and a custom camera eye
# position for this view.
fig.update_layout(
    font=dict(family='Helvetica', size=12),
    height=600,
    width=600,
    margin=dict(l=0, r=0, b=0, t=0),
    scene_camera=dict(eye=dict(x=1.87, y=-0.88, z=-0.64)),
)

In [73]:
# Export the embedding figure, keeping the custom camera set in the previous cell.
off_manifolder_prob_emb_pdf = "off_manifolder_prob_emb.pdf"
fig.write_image(off_manifolder_prob_emb_pdf)