In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
# Default canvas size (width, height) applied to every matplotlib figure
# created after this point in the notebook.
FIGURE_SIZE = (16.0, 8.0)
plt.rcParams['figure.figsize'] = FIGURE_SIZE

from gluonts.nursery.ts_embeddings.run import MyDataModule

In [None]:
# Location of the TorchScript archive holding the trained encoder.
encoder_model_path = './encoder.pt'
print('loading encoder')

# map_location='cpu': the rest of the notebook feeds CPU batches to the encoder
# and converts its output with `.numpy()`, so the weights must live on the CPU
# even if the archive was saved from a GPU process.
# .eval(): switch any dropout / batch-norm layers to inference behaviour
# (Module.eval() returns the module itself, so the assignment is unchanged).
loaded_encoder = torch.jit.load(encoder_model_path, map_location='cpu').eval()

In [None]:
# Settings for the data module. ts_len = 3 * 7 * 24 = 504 time steps per window
# (presumably three weeks of hourly observations — TODO confirm against MyDataModule).
dm_settings = dict(
    ts_len=3 * 7 * 24,
    dataset_name='traffic',
    num_workers=4,
    batch_size=128,
)

# Build the data module and let it prepare its datasets before any loader is requested.
dm = MyDataModule(**dm_settings)
dm.setup()

In [None]:
# Upper bound on the number of training batches pulled from the loader.
MAX_BATCHES = 500

# Collect batches in their own list so `series` only ever names the final tensor.
batches = []
for batch in dm.train_dataloader():
    batches.append(batch)
    # Test with >= right after appending so the cap is exactly MAX_BATCHES
    # (a `>` test would admit one extra batch).
    if len(batches) >= MAX_BATCHES:
        break

# Concatenate along dim 0 into a single tensor of input windows.
series = torch.cat(batches, dim=0)
print('running encoder')

# Inference only, so skip building the autograd graph.
# NOTE(review): assumes the encoder returns one embedding row per input row,
# in the same order — later cells index `series` by scatter-point index.
with torch.no_grad():
    embed = loaded_encoder(series).numpy()

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# t-SNE is stochastic: pin the seed so the 2-D layout (and therefore which point
# maps to which series in the interactive view) is reproducible across re-runs.
TSNE_SEED = 0
tsne_out = TSNE(
    n_components=2,
    perplexity=80,
    random_state=TSNE_SEED,
).fit_transform(embed)

# Static overview of the projected embeddings; labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.scatter(tsne_out[:, 0], tsne_out[:, 1])
ax.set(
    title='t-SNE projection of encoder embeddings',
    xlabel='t-SNE component 1',
    ylabel='t-SNE component 2',
)
plt.show()

In [None]:
# Plotly graph objects (FigureWidget, Scatter) and the offline notebook helper.
import plotly.graph_objs as go
import plotly.offline as py

# ipywidgets containers used to lay out the interactive figures.
from ipywidgets import interactive, HBox, VBox

# Initialise plotly's offline notebook integration before any figure is shown.
py.init_notebook_mode()

In [None]:
# NumPy version of the input windows that were fed to the encoder; the hover
# callback looks series up here by scatter-point index.
# NOTE(review): presumes row i of `targets` corresponds to row i of the t-SNE output — confirm.
targets = series.numpy()

In [None]:
# Interactive scatter of the 2-D t-SNE coordinates, one marker per row of `tsne_out`.
embedding_trace = go.Scatter(
    x=tsne_out[:, 0],
    y=tsne_out[:, 1],
    mode='markers',
)
f = go.FigureWidget(data=[embedding_trace])
# Keep a handle on the widget-owned trace: hover callbacks are registered on it.
scatter = f.data[0]

# Companion figure showing one series at a time; it starts on the first series.
initial_series = targets[0]
series_trace = go.Scatter(
    x=np.arange(len(initial_series)),
    y=initial_series,
)
t = go.FigureWidget(data=[series_trace])

def p(s, points, input_state):
    """Hover callback: show the series behind the hovered scatter point in `t`.

    The three parameters follow the signature plotly uses when invoking
    callbacks registered with ``on_hover``.

    Args:
        s: First callback argument (unused).
        points: Hover payload; only ``points.point_inds`` is read.
        input_state: Third callback argument (unused).

    Returns:
        None. The figure widget ``t`` is updated in place. Nothing happens
        unless exactly one point is hovered.
    """
    if len(points.point_inds) != 1:
        return
    ts = np.asarray(targets[points.point_inds[0]])

    # Largest finite value of the series; None when the series is empty or has
    # no finite entries (all NaN / inf), where a numeric axis range is meaningless.
    finite_values = ts[np.isfinite(ts)]
    upper = finite_values.max() if finite_values.size else None

    if upper is not None and upper > 0:
        # Usual case: anchor the axis at 0 with 10% headroom above the peak.
        yaxis = dict(range=[0, 1.1 * upper], autorange=False)
    else:
        # Degenerate peak (missing or non-positive): [0, 1.1 * upper] would be
        # invalid or inverted, so let plotly choose the range instead.
        yaxis = dict(autorange=True)

    # Send the trace and layout changes to the front end as a single update
    # rather than one message per assignment.
    with t.batch_update():
        t.data[0].x = np.arange(len(ts))
        t.data[0].y = ts
        t.update_layout(yaxis=yaxis)

# Register the hover handler on the embedding scatter trace.
scatter.on_hover(p)

# Stack the series view above the embedding scatter; as the last expression of
# the cell, the container is what the notebook renders.
VBox(children=(t, f))