# Plotly

In [1]:
import numpy as np

from plotly.subplots import make_subplots
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=True)

import ipywidgets as widgets
from ipywidgets import interact

# Create an interactive plot of the trajectory of a projectile.

## Plot the x,y axes

In [2]:
def create_interactive_plot(encoded_samples, time_steps, points, N_start, N_end):
    fig = make_subplots(rows=1, cols=3)
    figWidget = go.FigureWidget(fig)
    # add a black background to the figure
    figWidget.add_scatter(x=points[:,0], y=points[:,1], mode='lines', marker=dict(size=1), name='true positions', col=1, row=1)
    figWidget.add_scatter(x=[points[0,0]], y=[points[0,1]], mode='markers', marker=dict(size=10), name='current point', col=1, row=1)

    figWidget.add_scatter(x=time_steps, y=encoded_samples[:, 0], mode='lines', name='coord1', col=2, row=1)
    figWidget.add_scatter(x=[time_steps[0]], y=[encoded_samples[0,0]], mode='markers', marker=dict(size=10), name='current coord1', col=2, row=1)

    figWidget.add_scatter(x=time_steps, y=encoded_samples[:, 1], mode='lines', name='coord2', col=3, row=1)
    figWidget.add_scatter(x=[time_steps[0]], y=[encoded_samples[0,1]], mode='markers', marker=dict(size=10), name='current coord2', col=3, row=1)

    # figWidget.update_layout(xaxis=dict(range=[points[:,0].min(), points[:,0].max()]), yaxis=dict(range=[points[:,1].min(), points[:,1].max()]),
    #                 xaxis2=dict(range=[N_start*dt, N_end*dt]), yaxis2=dict(range=[0.49, 0.52]),
    #                     xaxis3=dict(range=[N_start*dt, N_end*dt]), yaxis3=dict(range=[0.46, 0.48]),
    # )


    # change the color of the scatter points

    figWidget.data[0].marker.color = 'blue'
    figWidget.data[1].marker.color = 'red'

    figWidget.data[2].marker.color = 'blue'
    figWidget.data[3].marker.color = 'purple'

    figWidget.data[4].marker.color = 'blue'
    figWidget.data[5].marker.color = 'orange'

    @interact(t=(time_steps.min(),time_steps.max(),dt))
    def update_plot(t):
        with figWidget.batch_update():
            # change the current point of 
            # the true position along the trajectory
            figWidget.data[1].x = [points[int(t/dt) - N_start,0]]
            figWidget.data[1].y = [points[int(t/dt) - N_start,1]]

            # the current point of the encoded sample on the x axis
            figWidget.data[3].x = [time_steps[int(t/dt) - N_start]]
            figWidget.data[3].y = [encoded_samples[int(t/dt) - N_start,0]]

            # the current point of the encoded sample on the y axis
            figWidget.data[5].x = [time_steps[int(t/dt) - N_start]]
            figWidget.data[5].y = [encoded_samples[int(t/dt) - N_start,1]]

        # print(figWidget)
            
    # save the figure as a html file with the interactive widget
    # figWidget.write_html('images/AE_ODE/encoded_reconstruction/encoded_reconstruction_slow_spiral.html')    
    return figWidget


def create_interactive_plot_phase(encoded_samples, time_steps, points, N_start, N_end):
    fig = make_subplots(rows=1, cols=2)
    figWidget = go.FigureWidget(fig)
    # add a black background to the figure
    figWidget.add_scatter(x=points[:,0], y=points[:,1], mode='lines', marker=dict(size=1), name='true positions', col=1, row=1)
    figWidget.add_scatter(x=[points[0,0]], y=[points[0,1]], mode='markers', marker=dict(size=10), name='current point', col=1, row=1)

    figWidget.add_scatter(x=encoded_samples[:, 0], y=encoded_samples[:, 1], mode='lines', name='phase', col=2, row=1)
    figWidget.add_scatter(x=[encoded_samples[0, 0]], y=[encoded_samples[0, 1]], mode='markers', marker=dict(size=10), name='current phase', col=2, row=1)
    # figWidget.update_layout(xaxis=dict(range=[points[:,0].min(), points[:,0].max()]), yaxis=dict(range=[points[:,1].min(), points[:,1].max()]),
    #                 xaxis2=dict(range=[N_start*dt, N_end*dt]), yaxis2=dict(range=[0.49, 0.52]),
    #                     xaxis3=dict(range=[N_start*dt, N_end*dt]), yaxis3=dict(range=[0.46, 0.48]),
    # )


    # change the color of the scatter points

    figWidget.data[0].marker.color = 'blue'
    figWidget.data[1].marker.color = 'red'

    figWidget.data[2].marker.color = 'blue'
    figWidget.data[3].marker.color = 'purple'

    @interact(t=(time_steps.min(),time_steps.max(),dt))
    def update_plot(t):
        with figWidget.batch_update():
            # change the current point of 
            # the true position along the trajectory
            figWidget.data[1].x = [points[int(t/dt) - N_start,0]]
            figWidget.data[1].y = [points[int(t/dt) - N_start,1]]

            # the current point of the encoded sample on the x axis
            figWidget.data[3].x = [encoded_samples[int(t/dt) - N_start,0]]
            figWidget.data[3].y = [encoded_samples[int(t/dt) - N_start,1]]

        # print(figWidget)
            
    # save the figure as a html file with the interactive widget
    # figWidget.write_html('images/AE_ODE/encoded_reconstruction/encoded_reconstruction_slow_spiral.html')    
    return figWidget


In [3]:
import torch
from src.models.ae import ConvAE

In [4]:
pathInputs = "images/AE_ODE/encoded_reconstruction/trues_images_fast_spiral.npy"
pathModel = "models/AE/FirstExplo/conv_custom_2_gaussian_r_1_custom_loss_alpha_0_0.pt"

images = np.load(pathInputs)
model = ConvAE(height=28, width=28, latent_dim=2, in_channels=1, relu=True, activation=torch.nn.ReLU())
model.load_state_dict(torch.load(pathModel))

Number of parameters in the model: 220387


<All keys matched successfully>

In [5]:
model.encode(torch.from_numpy(images).float()).shape

torch.Size([10000, 2])

In [6]:
encoded_samples = model.encode(torch.from_numpy(images).float()).detach().numpy()
encoded_samples.shape

(10000, 2)

In [7]:
np.save("images/AE_ODE/encoded_reconstruction/continuous_encoded_samples_fast_spiral.npy", encoded_samples)

In [8]:
N_start = 0
N_end = 10000
step_size = 1
plot_encoded_samples = encoded_samples[N_start:N_end:step_size]
print(plot_encoded_samples.shape)
dt = 0.001
time_steps = np.arange(0, plot_encoded_samples.shape[0]*dt, dt)[N_start:N_end:step_size]


points = np.load("images/AE_ODE/encoded_reconstruction/trues_positions_fast_spiral.npy")[N_start:N_end:step_size]
print(points.shape)

create_interactive_plot(plot_encoded_samples, time_steps, points, N_start, N_end)

(10000, 2)
(10000, 2)


interactive(children=(FloatSlider(value=4.999, description='t', max=9.999, step=0.001), Output()), _dom_classe…

FigureWidget({
    'data': [{'marker': {'color': 'blue', 'size': 1},
              'mode': 'lines',
          …

In [9]:
create_interactive_plot_phase(plot_encoded_samples, time_steps, points, N_start, N_end)

interactive(children=(FloatSlider(value=4.999, description='t', max=9.999, step=0.001), Output()), _dom_classe…

FigureWidget({
    'data': [{'marker': {'color': 'blue', 'size': 1},
              'mode': 'lines',
          …

In [10]:
N_start = 0
N_end = 8000
step_size = 1
encoded_samples = np.load('images/AE_ODE/encoded_reconstruction/encoded_samples_fast_spiral.npy')[N_start:N_end:step_size]
print(encoded_samples.shape)
dt = 0.001
time_steps = np.arange(0, encoded_samples.shape[0]*dt, dt)[N_start:N_end:step_size]


points = np.load("images/AE_ODE/encoded_reconstruction/trues_positions_fast_spiral.npy")[N_start:N_end:step_size]
print(points.shape)

create_interactive_plot(encoded_samples, time_steps, points, N_start, N_end)

(8000, 2)
(8000, 2)


interactive(children=(FloatSlider(value=3.999, description='t', max=7.9990000000000006, step=0.001), Output())…

FigureWidget({
    'data': [{'marker': {'color': 'blue', 'size': 1},
              'mode': 'lines',
          …