# Import packages

In [1]:
import ipywidgets as widgets
import numpy as np
import tensorflow as tf
from datetime import datetime
from IPython.display import clear_output
import math
import os
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import pandas as pd

from parameters import *
from symae_model import SymAE
from MRA_generate import MRA_generate
from latent import latent

2022-06-22 15:48:07.415467: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Num GPUs Available:  2
TensorFlow Version:  2.9.1


# Load SymAE, generate MRA data, and fit PCA to the latent codes

In [2]:
def g(n, x):
    """Evaluate template shape number ``n`` at the point ``x``.

    Shapes:
        0 -- e**(-9 x**2)
        1 -- step: 1 when x < 0.5, otherwise 0 (returned as an int)
        2 -- piecewise linear: x when x < 0.3, 0.6 - x when 0.3 <= x < 0.6, otherwise 0
        3 -- cos(2 pi x)
        4 -- e**(-30 (x - 0.5)**2)
    Any other ``n`` returns ``np.inf`` as a sentinel.

    Shapes 1 and 2 use Python comparisons / ``int()``, so ``x`` is expected to be a scalar.
    """
    if n == 0:
        return math.e ** (-9 * x ** 2)
    if n == 1:
        return int(x < 0.5)
    if n == 2:
        if x < 0.3:
            return x
        # Falling edge up to 0.6, flat zero afterwards (also the result for NaN input).
        return 0.6 - x if x < 0.6 else 0
    if n == 3:
        return math.cos(2 * math.pi * x)
    if n == 4:
        return math.e ** (-30 * (x - 0.5) ** 2)
    return np.inf
# Weights are loaded from ./checkpoint/<tag>. By default the tag is today's date formatted with
# "%B%d" (e.g. "June22"), which only matches a checkpoint named on the same day. Set the
# SYMAE_CHECKPOINT_TAG environment variable to pin a specific checkpoint so re-runs on later days
# load the same weights.
CHECKPOINT_TAG = os.environ.get('SYMAE_CHECKPOINT_TAG', datetime.now().strftime("%B%d"))
model = SymAE(N, nt, d, p, q, kernel_size, filters, dropout_rate)
model.load_weights('./checkpoint/' + CHECKPOINT_TAG)
clear_output()  # wipe whatever output the cell produced up to this point

# Generate the MRA dataset with the shapes defined by `g` (the 1000 is presumably the sample
# count -- see MRA_generate), then encode it into coherent (Cs) and nuisance (Ns) latent codes.
MRA = MRA_generate(d, nt, 1000, sigma, ne, g, replace=False, outer_replace=True)
MRA.generate_default()
Cs, Ns = latent(model, MRA)
# Flatten the nuisance codes to 2-D rows of length q (presumably one row per (sample, instance)).
Ns = Ns.reshape(-1, q)

# Project both latent spaces onto their first two principal components for plotting.
coherent_pca = PCA(n_components=2)
nuisance_pca = PCA(n_components=2)
pca_C = coherent_pca.fit_transform(Cs)
pca_N = nuisance_pca.fit_transform(Ns)

# Centre of each state's coherent codes in the coherent PCA plane. A vectorised mean replaces
# Python's built-in sum over rows; a state with no samples gets an explicit NaN centre instead
# of a 0/0 division.
assert len(pca_C) == len(MRA.states), "expected one state label per coherent code"
coherent_centers = np.full((ne, 2), np.nan)
for i in range(ne):
    in_state = MRA.states == i
    if in_state.any():
        coherent_centers[i, :] = pca_C[in_state].mean(axis=0)

# Nuisance PCA points with their state label, used for per-state colouring. Each sample's label
# is repeated nt times, which assumes Ns held nt consecutive rows per sample before the reshape
# (TODO confirm against `latent`); pd.concat would silently NaN-pad on a length mismatch.
df = pd.DataFrame(pca_N, columns=['1st', '2nd'])
df_states = pd.DataFrame(np.repeat(MRA.states, nt), columns=['state']).astype('object')
assert len(df_states) == len(df), "state labels do not line up with the nuisance codes"
df = pd.concat([df, df_states], axis=1)

def dec(latent_code):
    tem = latent_code[np.newaxis, np.newaxis, :]
    tem = np.repeat(tem, nt, axis=1)
    tem = model.mixer.predict(tem, verbose=0)
    return tem[0,0,:,0]
def pca_dec(pca_latent_coherent_code, pca_latent_nuisance_code):
    latent_coherent_code = coherent_pca.inverse_transform(pca_latent_coherent_code)
    latent_nuisance_code = nuisance_pca.inverse_transform(pca_latent_nuisance_code)
    return dec(np.concatenate([latent_coherent_code,latent_nuisance_code],axis=0))

2022-06-22 15:48:28.855471: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8302

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2022-06-22 15:48:30.336252: W tensorflow/stream_executor/gpu/asm_compiler.cc:230] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2022-06-22 15:48:30.336295: W tensorflow/stream_executor/gpu/asm_compiler.cc:233] Used ptxas at ptxas
2022-06-22 15:48:30.336359: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] UNIMPLEMENTED: ptxas ptxas too old. Falling back to the driver to compile.
Relying on driver to perform ptx compilation. 
Modify $PATH to customize ptxas location.
This message will be only logged once.
2022-06-22 15:48:30.417716: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


# Build the Plotly figures and slider widgets

In [48]:
def _latent_slider(description, limit, vertical=False):
    """Build a FloatSlider for one PCA coordinate spanning [-limit, limit] in steps of 0.1.

    The slider starts at 0.0 and only updates its value when released
    (``continuous_update=False``). ``vertical=True`` gives a 40x160px vertical slider.
    """
    options = dict(min=-limit, max=limit, step=0.1, description=description,
                   continuous_update=False, readout_format='.1f', disabled=False, value=0.0)
    if vertical:
        options.update(orientation='vertical',
                       layout=widgets.Layout(width='40px', height='160px'))
    return widgets.FloatSlider(**options)


# Fixed limits per coordinate. NOTE(review): presumably chosen to cover the spread of
# pca_C / pca_N -- re-check if the model or data change.
pca_latent_coherent_code_1st = _latent_slider('c-1st:', 12.0)
pca_latent_coherent_code_2nd = _latent_slider('c-2nd:', 6.0, vertical=True)
pca_latent_nuisance_code_1st = _latent_slider('n-1st:', 60.0)
pca_latent_nuisance_code_2nd = _latent_slider('n-2nd:', 60.0, vertical=True)

# Decoded-signal plot. Seed it with the decoding at the sliders' initial values so the first
# view is a real reconstruction matching the markers, not a placeholder 0..d-1 ramp.
signal_output_fig = go.FigureWidget()
signal_output_fig.add_scatter(y=pca_dec(
    [pca_latent_coherent_code_1st.value, pca_latent_coherent_code_2nd.value],
    [pca_latent_nuisance_code_1st.value, pca_latent_nuisance_code_2nd.value]))

# Coherent plane: one marker per state centre (traces 0..ne-1), then the current-position
# marker as trace `ne`.
latent_coherent_space = go.FigureWidget()
for i in range(ne):
    latent_coherent_space.add_trace(go.Scatter(x=[coherent_centers[i, 0]],
                                               y=[coherent_centers[i, 1]],
                                               mode='markers',
                                               marker=dict(size=8)))
coherent_varible_data = go.Scatter(x=np.array([0.0]), y=np.array([0.0]),
                                   mode='markers', marker=dict(size=12))
latent_coherent_space.add_trace(coherent_varible_data)

# Nuisance plane: one point cloud per state (traces 0..ne-1), then the current-position
# marker as trace `ne`.
latent_nuisance_space = go.FigureWidget()
for i in range(ne):
    in_state = df['state'] == i
    latent_nuisance_space.add_trace(go.Scatter(x=df.loc[in_state, '1st'],
                                               y=df.loc[in_state, '2nd'],
                                               mode='markers',
                                               marker=dict(size=4)))
latent_nuisance_space.add_trace(go.Scatter(x=np.array([0.0]), y=np.array([0.0]),
                                           mode='markers',
                                           marker=dict(size=12)))

# Compact, margin-free layouts so the plots fit beside the sliders.
latent_coherent_space.update_layout(width=250, height=150,
                                    margin=dict(l=0, r=0, b=0, t=0, pad=0))
latent_nuisance_space.update_layout(width=250, height=150,
                                    margin=dict(l=0, r=0, b=0, t=0, pad=0))
signal_output_fig.update_layout(width=450, height=400,
                                margin=dict(l=0, r=1, b=0, t=0, pad=0))
def response(change):
    """Slider observer: redraw the current-position markers and the decoded signal.

    ``change`` (the traitlets change notification) is not used. All four slider values are read
    directly, so any slider update refreshes the whole view.
    """
    pca_latent_coherent_code = [pca_latent_coherent_code_1st.value, pca_latent_coherent_code_2nd.value]
    pca_latent_nuisance_code = [pca_latent_nuisance_code_1st.value, pca_latent_nuisance_code_2nd.value]
    signal_reconstruct = pca_dec(pca_latent_coherent_code, pca_latent_nuisance_code)
    # Trace `ne` in each latent plot is the current-position marker (added after the ne
    # per-state traces). Keep its x/y as one-element sequences, like the np.array([0.0]) it was
    # created with; np.array(scalar) would be a 0-d array rather than a one-point array.
    with latent_coherent_space.batch_update():
        latent_coherent_space.data[ne].x = [pca_latent_coherent_code[0]]
        latent_coherent_space.data[ne].y = [pca_latent_coherent_code[1]]
    with latent_nuisance_space.batch_update():
        latent_nuisance_space.data[ne].x = [pca_latent_nuisance_code[0]]
        latent_nuisance_space.data[ne].y = [pca_latent_nuisance_code[1]]
    with signal_output_fig.batch_update():
        signal_output_fig.data[0].y = signal_reconstruct
# Re-run `response` only when a slider's value changes. observe() without `names` subscribes
# to every trait on the widget, so other trait changes would also trigger a full re-decode.
for slider in (pca_latent_coherent_code_1st, pca_latent_coherent_code_2nd,
               pca_latent_nuisance_code_1st, pca_latent_nuisance_code_2nd):
    slider.observe(response, names='value')

# Layout: each latent group is its horizontal slider stacked above (vertical slider | scatter).
# Both groups are stacked on the left and the decoded signal sits on the right.
container_coherent = widgets.VBox([pca_latent_coherent_code_1st,
                                   widgets.HBox([pca_latent_coherent_code_2nd,
                                                 latent_coherent_space])])
container_nuisance = widgets.VBox([pca_latent_nuisance_code_1st,
                                   widgets.HBox([pca_latent_nuisance_code_2nd,
                                                 latent_nuisance_space])])
container_varibles = widgets.VBox([container_coherent, container_nuisance])
fig = widgets.HBox([container_varibles, signal_output_fig])

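# Display the dashboard (in JupyterLab, right-click the output and choose "Create New View for Output" to open it in its own panel)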

In [49]:
# Render the assembled dashboard: sliders and latent-space scatter plots on the left,
# decoded signal on the right.
fig

HBox(children=(VBox(children=(VBox(children=(FloatSlider(value=0.0, continuous_update=False, description='c-1s…