# Load prediction data and visualize timesteps

In [None]:
%%javascript
// Override the OutputArea `_should_scroll` hook so that it answers `false`
// for any number of output lines (the `lines` argument is ignored).
// NOTE(review): presumably this keeps long cell outputs, such as the figure
// grids in this notebook, from being collapsed into a scroll box; confirm
// against the notebook front-end in use.
IPython.OutputArea.prototype._should_scroll = (lines) => false;

In [None]:
from ipywidgets import interactive, widgets, interact

import os
import json
import numpy as np
import matplotlib.pyplot as plt
import tikzplotlib

# Common file-name prefix of all artefacts written for one prediction run.
filename_base = 'prediction_a_0_100_u_p_net_64_final'

# abspath() resolves against the current working directory, so the data
# folder is located relative to wherever the notebook kernel was started.
notebook_path = os.path.abspath("notebook_visualize_predicted_data.ipynb")

# Select the dataset by (un)commenting exactly one of the two folders.
folder = os.path.join(os.path.dirname(notebook_path), "data/prediction_heatequation/")
# folder = os.path.join(os.path.dirname(notebook_path), "data/prediction_waveequation/")


def _prediction_path(suffix):
    """Return the path of the run artefact ``filename_base + suffix`` in ``folder``."""
    # os.path.join keeps working even if `folder` loses its trailing slash.
    return os.path.join(folder, filename_base + suffix)


def _load_array(path):
    """Load one ``.npy`` artefact of the prediction run."""
    # SECURITY: allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code. Only open files you produced yourself.
    # NOTE(review): switch to allow_pickle=False if the arrays are plain
    # numeric arrays -- not verifiable from this notebook.
    return np.load(path, allow_pickle=True)


# Only the number of entries is used here: one "<base>_inputs_<j>_out.npy"
# file is loaded per entry.
input_sizes = [3,1,1]
num_inputs = len(input_sizes)

inputs_out = [
    _load_array(_prediction_path('_inputs_{}_out.npy'.format(j)))
    for j in range(num_inputs)
]

outputs_pathname = _prediction_path("_outputs.npy")
reference_outputs_pathname = _prediction_path("_reference_outputs.npy")
output_diffs_pathname = _prediction_path("_output_diffs.npy")

outputs = _load_array(outputs_pathname)
reference_outputs = _load_array(reference_outputs_pathname)
output_diffs = _load_array(output_diffs_pathname)

# Error metrics are stored next to the arrays as JSON; read it as UTF-8
# instead of relying on the platform default encoding.
with open(_prediction_path("_error_metrics.json"), encoding="utf-8") as json_file:
    error_metrics = json.load(json_file)

# Quick sanity check of the loaded array shapes.
print(inputs_out[1].shape)
print(reference_outputs.shape)
print(outputs.shape)
print(output_diffs.shape)

start_step = 0

# Grid dimensions for the per-step overview figure. The previous
# `num_steps = outputs.shape[0]` was a dead store: it was overwritten by the
# literal 2 before ever being read, so it has been dropped.
num_figs_per_ts = 8
num_steps = 2

# 2x2 grid with one error curve per panel, each plotted against
# error_metrics['step_prediction'].
figE, axsE = plt.subplots(2, 2, figsize=(20,8))
imE = []

# (row, column, key in error_metrics['domain']['raw'], panel title); the
# order matches the order in which the line handles are appended to imE.
metric_panels = [
    (0, 0, 'rmse', 'RMSE'),
    (1, 0, 'mse', 'MSE'),
    (0, 1, 'mae', 'MAE'),
    (1, 1, 'ssp', 'SSP'),
]

steps_axis = error_metrics['step_prediction']
raw_metrics = error_metrics['domain']['raw']

for row, col, metric_key, panel_title in metric_panels:
    panel = axsE[row, col]
    imE.append(panel.plot(steps_axis, raw_metrics[metric_key]))
    panel.title.set_text(panel_title)

plt.show()

# Overview grid: one row per shown step (first and last), one column per
# quantity (three channels of input 0, inputs 1 and 2, truth, prediction,
# difference).
num_figs_per_ts = 8
num_steps = 2
fig1, axs1 = plt.subplots(num_steps, num_figs_per_ts, figsize=(20,4))
fig1.suptitle('Inputs, Truth, Prediction and Absolute difference at first and last step')

# Column captions, identical for both rows.
column_titles = [
    'Input 0 [0]',
    'Input 0 [1]',
    'Input 0 [2]',
    'Input 1',
    'Input 2',
    'Truth',
    'Prediction',
    'Diff(Truth, Pred.)',
]

im1 = []
# Row 0 shows index 0 (first step), row 1 shows index -1 (last step).
for row, step in enumerate((0, -1)):
    panels = [
        inputs_out[0][step, :, :, 0],
        inputs_out[0][step, :, :, 1],
        inputs_out[0][step, :, :, 2],
        inputs_out[1][step, :, :],
        inputs_out[2][step, :, :],
        reference_outputs[step],
        outputs[step],
        output_diffs[step],
    ]
    for col, (title, panel) in enumerate(zip(column_titles, panels)):
        im1.append(axs1[row, col].imshow(panel, origin='lower'))
        axs1[row, col].set_title(title)

# im1 is filled row by row, so the flat index maps back to (row, column);
# every panel gets its own colorbar.
for index, image in enumerate(im1):
    row, col = divmod(index, num_figs_per_ts)
    fig1.colorbar(image, ax=axs1[row, col])

plt.show()

import plotly.graph_objs as go
import plotly.offline as py
import plotly

from ipywidgets import interactive, HBox, VBox, widgets, interact

py.init_notebook_mode()

# Default configuration (setup 2): 3D surface plot. When the data folder
# name contains 'heatequation', setup 1 with a 2D heatmap is used instead.
setup = 2
plottype = 'surf'

if 'heatequation' in folder:
    setup = 1
    plottype = 'heatmap'

# Fixed colour / z-axis limits per setup so the scale does not change
# between timesteps.
if setup == 1:
    zLimitMin = 0.00
    zLimitMax = 1.00
elif setup == 2:
    zLimitMin = -0.1
    zLimitMax = 0.1

# Reference and prediction are joined along axis 2 so both appear side by
# side in a single plot. To mirror the prediction, concatenate
# np.flip(outputs, 2) instead of outputs.
field = np.concatenate((reference_outputs, outputs), axis=2)

aspect = field.shape[2]/field.shape[1]

if plottype == 'heatmap':

    # Fix: the lower colour limit was passed as `-zLimitMin`, which flips its
    # sign for any non-zero limit (e.g. -0.1 would become +0.1 == zmax).
    # Use the limit as-is, consistent with `cmin` in the surface branch.
    imgHeatmap = go.Heatmap(z=field[0,:,:], zmin=zLimitMin, zmax=zLimitMax)
    fig = go.Figure(data=[imgHeatmap])
    # Anchor the y scale to x so the cells are drawn with equal scaling.
    fig.update_layout(yaxis = dict(scaleanchor = 'x'))

elif plottype == 'surf':

    surf = go.Surface(z=field[0,:,:], cmin=zLimitMin, cmax=zLimitMax)
    fig = go.Figure(data=[surf])
    # Stretch x by the width/height ratio of the concatenated field.
    fig.update_layout(scene_aspectmode='manual',
                  scene_aspectratio=dict(x=aspect, y=1, z=1))
    fig.update_layout(scene = dict(
                    xaxis = dict(nticks=4, range=[0,field.shape[2]],),
                    yaxis = dict(nticks=4, range=[0,field.shape[1]],),
                    zaxis = dict(nticks=4, range=[zLimitMin,zLimitMax],),),)

    # Not read in this cell; kept in case later cells rely on them.
    xdim = field.shape[2]
    ydim = field.shape[1]

    # Eye placed on the z axis looking at the origin, with y as "up".
    camera = dict(
                up=dict(x=0, y=1, z=0),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=0, y=0.0, z=3.5)
            )

    fig.update_layout(scene_camera=camera)
    
# Wrap the figure built above in a FigureWidget so its trace can be updated
# in place from a widget callback.
fHeatmap = go.FigureWidget(data=fig.data, layout=fig.layout)

# One slider position per entry along the first axis of `field`.
last_timestep = field.shape[0] - 1
sliderHeatmap = widgets.IntSlider(
    min=0,
    max=last_timestep,
    step=1,
    readout=False,
    description='Timestep',
    layout=widgets.Layout(width='800px'),
)


def updateHeatmap(x):
    """Show timestep ``x`` of ``field`` in the figure widget.

    The slider is created with ``readout=False``, so the selected index is
    printed here instead.
    """
    print(x)
    fHeatmap.data[0].z = field[x]


slider_controls = interactive(updateHeatmap, x=sliderHeatmap)
vbHeatmap = VBox(
    (fHeatmap, slider_controls),
    layout=widgets.Layout(align_items='center'),
)
vbHeatmap