In [2]:
import numpy as np

import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from model import Model

In [4]:
def _sample_arc(n, arc, start, r_width, r_min):
    """Sample ``n`` 2-D points in a polar band, returned as an ``(n, 2)`` Cartesian array.

    Angles are uniform on ``[start, start + arc*pi)`` and radii uniform on
    ``[r_min, r_min + r_width)``. Angles are drawn before radii, both from the global
    ``np.random`` state.
    """
    theta = np.random.random(n) * np.pi * arc + start
    r = np.random.random(n) * r_width + r_min
    return np.column_stack((r * np.cos(theta), r * np.sin(theta)))


def testData(n=25):
    """Generate a two-class toy dataset of 2-D points.

    Class 1 lies on an outer arc (radius 0.6-0.9, angle 0 to 1.2*pi); class 0 lies on an
    inner arc (radius 0.05-0.25, angle 1.2*pi to 2.4*pi). Randomness comes from the global
    ``np.random`` state, so ``np.random.seed`` makes the result reproducible.

    Parameters
    ----------
    n : int, default 25
        Number of points per class.

    Returns
    -------
    x : ndarray, shape (2*n, 2)
        Class-0 points first, then class-1 points.
    y : ndarray, shape (2*n, 1), dtype float
        Labels (0.0 / 1.0) aligned row-by-row with ``x``.
    """
    # Class 1 is drawn first: the order in which the random stream is consumed
    # (class-1 angles, class-1 radii, class-0 angles, class-0 radii) fixes the seeded output.
    x1 = _sample_arc(n, arc=1.2, start=0.0, r_width=0.3, r_min=0.6)
    x0 = _sample_arc(n, arc=1.2, start=1.2 * np.pi, r_width=0.2, r_min=0.05)

    x = np.concatenate((x0, x1), axis=0)
    y = np.concatenate((np.zeros(n), np.ones(n)))[:, None]
    return x, y

In [5]:
# Train the model and record a snapshot of every space after each optimisation step,
# for the animation below. data[key][t] holds the state after t steps (t = 0: untrained).

SEED = 35
N_STEPS = 120
LEARNING_RATE = 0.1
VERBOSE = False  # set True to print loss / accuracy after every step

np.random.seed(SEED)  # fixes the synthetic data (and presumably Model's initial weights — confirm)

x, y = testData()
model = Model()

data = {'x': [], 'dl': [], 'grid': [], 'acc': []}


def record_snapshot(model, x, y, data):
    """Append ``model.spaces(x)`` outputs and ``model.accuracy(x, y)`` to the lists in ``data``."""
    xs, dls, grids = model.spaces(x)
    data['x'].append(xs)
    data['dl'].append(dls)
    data['grid'].append(grids)
    data['acc'].append(model.accuracy(x, y))


record_snapshot(model, x, y, data)

for t in range(N_STEPS):
    loss = model.opt(x, y, LEARNING_RATE)
    record_snapshot(model, x, y, data)
    if VERBOSE:
        print(f"step {t + 1:>4}  loss = {loss:<10.8}  acc = {data['acc'][-1]:.4%}")

In [6]:
# One subplot column per space: the original input space followed by the latent spaces
# returned by model.spaces (the plotting cell below draws into columns 1-3).
SUBPLOT_TITLES = ("Original Space", "Latent Space 1", "Latent Space 2")

fig = make_subplots(
    rows=1,
    cols=len(SUBPLOT_TITLES),  # keeps the title count and column count in sync
    subplot_titles=SUBPLOT_TITLES,
    horizontal_spacing=0.05,
)

In [None]:
# Draw the initial training state and build one animation frame per recorded step.
# Every subplot column shows, in this order: dashed grid lines, the data points, and the red
# `dl` line (presumably the decision boundary). The initial figure and all frames share this
# trace order, so each frame's traces map onto the figure's traces by index.

INITIAL_STEP = 0  # untrained model; keeps the figure in sync with the slider's "active": 0
N_SPACES = 3      # one subplot column per space

# Marker colours depend only on the labels, so compute them once: orange where y != 0, else blue.
point_colors = np.where(np.ravel(y) != 0, "orange", "blue").tolist()


def snapshot_traces(history, t, colors):
    """Build the scatter traces for recorded step ``t``.

    Parameters
    ----------
    history : dict
        The ``data`` dict filled by the training cell (keys ``"x"``, ``"dl"``, ``"grid"``).
    t : int
        Index of the recorded snapshot.
    colors : list of str
        One marker colour per data point.

    Returns
    -------
    list of (go.Scatter, int)
        ``(trace, subplot_column)`` pairs; columns are numbered from 1.
    """
    xs, dls, grids = history["x"][t], history["dl"][t], history["grid"][t]
    traces = []
    for i in range(N_SPACES):
        col = i + 1
        for grid_line in grids[i]:
            traces.append((
                go.Scatter(
                    x=grid_line[:, 0],
                    y=grid_line[:, 1],
                    mode="lines",
                    line={"color": "lightgreen", "dash": "dash"},
                    hoverinfo="none",
                ),
                col,
            ))
        traces.append((
            go.Scatter(
                x=xs[i][:, 0],
                y=xs[i][:, 1],
                mode="markers",
                marker_color=colors,
                hoverinfo="skip",
            ),
            col,
        ))
        traces.append((
            go.Scatter(x=dls[i][:, 0], y=dls[i][:, 1], mode="lines", line_color="red", hoverinfo="skip"),
            col,
        ))
    return traces


# Initial figure. Clearing first makes the cell safe to re-run (otherwise traces accumulate).
fig.data = []
for trace, col in snapshot_traces(data, INITIAL_STEP, point_colors):
    fig.add_trace(trace, row=1, col=col)

# One frame per recorded step; the frame name is what the slider steps refer to.
frames = []
for t in range(len(data["x"])):
    frame_data = [trace for trace, _ in snapshot_traces(data, t, point_colors)]
    # Address traces by position rather than a hard-coded count, so a different number of
    # grid lines per space cannot desynchronise the frames from the figure.
    frames.append(go.Frame(data=frame_data, traces=list(range(len(frame_data))), name=str(t)))

fig.frames = frames

# Run / Pause buttons. Pause animates to [None], which stops the running animation.
button_dict = {
    "buttons": [
        {
            "args": [
                None,
                {
                    "frame": {"duration": 50, "redraw": False},
                    "fromcurrent": True,
                    "transition": {"duration": 10},
                },
            ],
            "label": "Run",
            "method": "animate",
        },
        {
            "args": [
                [None],
                {
                    "frame": {"duration": 0, "redraw": False},
                    "mode": "immediate",
                    "transition": {"duration": 0},
                },
            ],
            "label": "Pause",
            "method": "animate",
        },
    ],
    "direction": "left",
    "pad": {"r": 10, "t": 87},
    "showactive": False,
    "type": "buttons",
    "x": 0.1,
    "xanchor": "right",
    "y": 0.05,
    "yanchor": "top",
}

fig.update_layout(updatemenus=[button_dict])

In [None]:
# Slider with one step per recorded training step. Each step animates to the frame named
# str(t), matching the names given to the frames when they were built.
AXIS_RANGE = [-1.05, 1.05]  # fixed so the axes do not rescale between frames

sliders_dict = {
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "Step:",
        "visible": True,
        "xanchor": "right",
    },
    "transition": {"duration": 10},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0.05,
    "steps": [
        {
            "args": [
                [str(t)],
                {
                    "frame": {"duration": 50, "redraw": False},
                    "mode": "immediate",
                    "transition": {"duration": 10},
                },
            ],
            "label": t,
            "method": "animate",
        }
        for t in range(len(data["x"]))
    ],
}

# The Run/Pause buttons were already attached in the previous cell; only add the slider here.
fig.update_layout(
    sliders=[sliders_dict],
    width=1200,
    height=550,
    showlegend=False,
    title="Visualize (Latent) Spaces during Training of Neural Network",
)

fig.update_xaxes(range=AXIS_RANGE)
fig.update_yaxes(range=AXIS_RANGE)

# Writes a standalone HTML file; use fig.show() instead to render inside the notebook.
plotly.offline.plot(fig, filename='visual.html');