In [1]:
import plotly.graph_objects as go
from ipywidgets import widgets
import numpy as np
import pandas as pd
import plotly.express as px
from copy import deepcopy
import tensorflow as tf
from scipy import spatial

In [2]:
# Selects which trained-model artifacts are loaded below: every file with a
# `_beta_{beta}` suffix under data/ belongs to this setting.
beta: int = 1

In [3]:
# Load the training images, the stored latent codes / reconstructions and the
# decoder for the selected `beta`. All artifacts live under DATA_DIR.
DATA_DIR = 'data'

input_imgs = np.load(f'{DATA_DIR}/train_gaf_ds.npy')
# The file stacks three arrays along its first axis; judging by the names they
# are the latent mean, log-variance and sampled codes (TODO confirm).
z_mean, z_log_var, z = np.load(f'{DATA_DIR}/z_all_train_beta_{beta}.npy')
output_imgs = np.load(f'{DATA_DIR}/dec_imgs_train_beta_{beta}.npy')
decoder = tf.keras.models.load_model(f'{DATA_DIR}/decoder_beta_{beta}')
labels = np.load(f'{DATA_DIR}/train_label.npy')

# Later cells index z[:, 0], z[:, 1] and z[:, 2] (sliders, 3-D scatter), so fail
# early with a clear message if the stored codes do not have that layout.
assert z.ndim == 2 and z.shape[1] >= 3, f'expected (n_samples, >=3) latent codes, got {z.shape}'



In [4]:
# Inverse Gramian Angular Summation Field (GASF): recovers a time series from
# the main diagonal of its GASF image.
#
# With phi = arccos(x), the GASF diagonal is cos(2 * phi) = 2 * x**2 - 1, hence
# x = sqrt((diag + 1) / 2). If the GASF values were rescaled from [-1, 1] to
# [0, 1] ("post-scaled"), the diagonal is x**2 and x = sqrt(diag).
# Because sqrt is non-negative, this assumes the series was scaled into [0, 1]
# before encoding (TODO confirm against the preprocessing).
#
# This layer was previously marked deprecated for producing values off by some
# factor although the formula was right. The cause: tf.linalg.diag_part was
# applied over the wrong axes (see `call`).
class IGASF(tf.keras.layers.Layer):
    """Keras layer mapping GASF images to the time series on their diagonal.

    Args:
        gasf_post_scale: True if the GASF values were rescaled from [-1, 1]
            to [0, 1]; selects which inversion formula is applied.
        min_scale, max_scale: range the recovered [0, 1] series is mapped back
            to (inverse min-max scaling). The defaults (0, 1) leave it unchanged.
    """

    def __init__(self, gasf_post_scale = True, min_scale = 0, max_scale = 1):
        super(IGASF, self).__init__()
        self.gasf_post_scale = gasf_post_scale
        self.min_scale = min_scale
        self.max_scale = max_scale

    def call(self, inputs):
        """Return the series encoded by each image.

        Args:
            inputs: tensor of shape (batch, H, W, channels) with H == W.

        Returns:
            Tensor of shape (batch, H, channels).
        """
        # tf.linalg.diag_part works on the LAST two axes. Applied directly to a
        # (batch, H, W, C) tensor it treats (W, C) as the matrix; for C == 1 it
        # returns inputs[:, :, 0, :] (the first image column), not the diagonal.
        # Move channels in front so each matrix is the (H, W) image.
        images = tf.transpose(inputs, perm=[0, 3, 1, 2])  # (batch, C, H, W)
        diag = tf.linalg.diag_part(images)                 # (batch, C, H)
        diag = tf.transpose(diag, perm=[0, 2, 1])          # (batch, H, C)
        if self.gasf_post_scale:
            squared = diag
        else:
            squared = (diag + 1) / 2
        # Values slightly below the valid range (e.g. from an imperfect decoder)
        # would make sqrt return NaN; clamp at 0 first.
        ts = tf.sqrt(tf.maximum(squared, 0.0))
        # Undo min-max scaling; identity for the default (0, 1) range.
        return ts * (self.max_scale - self.min_scale) + self.min_scale

# Wrap the layer in a model so it can be applied to decoded images.
inputs = tf.keras.Input(shape=(28,28,1),batch_size=None,dtype=tf.float32)
outputs = IGASF()(inputs)
postprocessor_model = tf.keras.models.Model(inputs = inputs, outputs = outputs)

In [5]:
# predict time series with given z-value
def predict_ts(z):
    """Decode one latent vector and return its reconstructed time series.

    Args:
        z: sequence of latent coordinates for a single sample (e.g. [z1, z2, z3]).

    Returns:
        1-D numpy array: channel 0 of the series that `postprocessor_model`
        recovers from the decoded image.
    """
    # Add a batch axis of size 1. Cast to float32, the default Keras input dtype
    # (TODO confirm against decoder.inputs if the decoder was built differently).
    latent = np.expand_dims(np.asarray(z, dtype=np.float32), 0)
    # This runs on every slider change. Model.predict is designed for large
    # batched inputs; for a single sample, calling the model directly is faster
    # and does not emit predict's progress output.
    img = decoder(latent)
    return postprocessor_model(img).numpy()[0, :, 0]

In [6]:
# One column per latent dimension, named 'ldim 1' .. 'ldim 3', for plotly express.
latent_df = pd.DataFrame({f'ldim {dim + 1}': z[:, dim] for dim in range(3)})

In [7]:
# 3-D scatter of all latent codes plus a single marker ('Sliding Point', trace
# index 1) that the slider callback moves around.
layout = go.Layout(title='latent dims')
latent_fig = px.scatter_3d(
    latent_df,
    x='ldim 1',
    y='ldim 2',
    z='ldim 3',
    opacity=0.5,
)
data = latent_fig.data

figure = go.Figure(data=data, layout=layout)
# Shrink the point cloud; done before the marker trace exists, so it only
# affects the scatter.
figure.update_traces(marker_size=1)

f = go.FigureWidget(figure)
f.add_scatter3d(name='Sliding Point', x=[0], y=[0], z=[0]);

In [8]:
# Line plot of the series decoded at the current slider position. It starts
# at the latent origin, which matches the sliders' initial value of 0.
first_ts = predict_ts([0, 0, 0])
ts_fig = go.Scatter(x=np.arange(len(first_ts)), y=first_ts, mode='lines')

f2_fig = go.Figure()
f2_fig.add_trace(ts_fig)

f2 = go.FigureWidget(f2_fig)
# Fixed y-range so successive reconstructions are visually comparable.
f2.update_yaxes(range=[0, 1])
f2.update_layout(title='Predicted Time Series');

In [9]:
# Pairwise scatter plots of the latent dimensions, each with an OLS trendline.
def latent_pair_figure(latent, x_dim, y_dim):
    """Scatter two latent dimensions against each other with a red OLS trendline.

    Args:
        latent: 2-D array of latent codes, one row per sample.
        x_dim, y_dim: 0-based column indices plotted on the x / y axis.

    Returns:
        (figure, widget): the plotly figure and a FigureWidget built from it.
    """
    # plotly's trendline='ols' requires statsmodels to be installed.
    fig = px.scatter(
        x=latent[:, x_dim],
        y=latent[:, y_dim],
        trendline='ols',
        trendline_color_override='red',
    )
    # 1-based axis names, matching the 'ldim N' columns of latent_df.
    fig.update_layout(xaxis_title=f'ldim {x_dim + 1}', yaxis_title=f'ldim {y_dim + 1}')
    return fig, go.FigureWidget(fig)


l12_fig, l12_fw = latent_pair_figure(z, 0, 1)
l13_fig, l13_fw = latent_pair_figure(z, 0, 2)
l23_fig, l23_fw = latent_pair_figure(z, 1, 2)

In [10]:
# slider for ldim 1, spanning the observed range of latent dimension 1.
# ndarray.min()/.max() are vectorised; the builtin min/max would iterate the
# column element by element in Python.
sl_1 = widgets.FloatSlider(
    value=0,
    min=float(z[:, 0].min()),
    max=float(z[:, 0].max()),
    step=0.1,
    description='ldim 1',
    disabled=False,
    continuous_update=False,  # fire only on release: each update runs the decoder
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

In [11]:
# slider for ldim 2, spanning the observed range of latent dimension 2.
# ndarray.min()/.max() are vectorised; the builtin min/max would iterate the
# column element by element in Python.
sl_2 = widgets.FloatSlider(
    value=0,
    min=float(z[:, 1].min()),
    max=float(z[:, 1].max()),
    step=0.1,
    description='ldim 2',
    disabled=False,
    continuous_update=False,  # fire only on release: each update runs the decoder
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

In [12]:
# slider for ldim 3, spanning the observed range of latent dimension 3.
# ndarray.min()/.max() are vectorised; the builtin min/max would iterate the
# column element by element in Python.
sl_3 = widgets.FloatSlider(
    value=0,
    min=float(z[:, 2].min()),
    max=float(z[:, 2].max()),
    step=0.1,
    description='ldim 3',
    disabled=False,
    continuous_update=False,  # fire only on release: each update runs the decoder
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

In [13]:
# change figure data on slider value change
def on_change(change):
    """Slider callback: move the 'Sliding Point' marker and redraw the decoded series.

    Args:
        change: traitlets change notification. Unused; the current position
            is read from all three sliders.
    """
    point = [sl_1.value, sl_2.value, sl_3.value]
    # f.data[1] is the 'Sliding Point' trace added after the latent scatter.
    # batch_update applies the three coordinate changes as a single update
    # instead of three separate ones.
    with f.batch_update():
        f.data[1].x = [point[0]]
        f.data[1].y = [point[1]]
        f.data[1].z = [point[2]]
    f2.data[0].y = predict_ts(point)
    # TODO: overlay reconstructions of the k nearest training codes
    # (scipy.spatial.KDTree over z). Build the tree once outside this callback
    # rather than on every slider event.

In [14]:
# Run the callback only when a slider's value changes. observe() without
# `names` subscribes to changes of every trait on the widget.
sl_1.observe(on_change, names='value')
sl_2.observe(on_change, names='value')
sl_3.observe(on_change, names='value')

In [15]:
# Dashboard: sliders on top, the decoded series next to the 3-D latent scatter,
# and the pairwise latent scatter plots underneath.
widgets.VBox(children=[
    widgets.HBox(children=[sl_1, sl_2, sl_3]),
    widgets.HBox(children=[f2, f]),
    widgets.HBox(children=[l12_fw, l13_fw, l23_fw]),
])

VBox(children=(HBox(children=(FloatSlider(value=0.0, continuous_update=False, description='ldim 1', max=2.6803…