In [1]:
# Widen the notebook container to the full browser width so the tall plotly figure below is readable.
# `display`/`HTML` come from the public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated (since IPython 7.14) and emits a DeprecationWarning.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [150]:
import torch
import os
import pandas as pd
import pickle
import numpy as np
import math

import plotly.graph_objects as go
import  plotly.express  as px
from plotly.subplots import make_subplots

## Initial Results: Observing Predictions on the Australia Dataset

In [156]:
# Per-location mean of the OBSERVED rainfall (`target_rain_value`) at one in-window day index.
# NOTE(review): the previous label called this the mean of the *predicted* distribution's mean
# term, but the key read is `target_rain_value` (the observations); the label now matches what
# is computed. If the prediction was intended, read 'pred_mean' instead.
# NOTE(review): `test_output` is loaded in the plotting cell (In [158]); run that cell first.
summary_day_idx = 2  # same in-window day index as `day_idx` in the plotting cell
print("Average observed rainfall (target_rain_value) per location, window-day index", summary_day_idx)
{location: data['target_rain_value'][:, summary_day_idx].mean() for location, data in test_output.items()}

Average prediction for mean term in predicted distribution for rainfall


{'Adelaide': 1.4793103448275864,
 'Albury': 1.8355164835164837,
 'AliceSprings': 0.8505788067675869,
 'BadgerysCreek': 2.221972132904609,
 'Ballarat': 1.5295627376425855,
 'Brisbane': 3.0472122302158273,
 'Cairns': 5.570773152081563,
 'Canberra': 1.8645805592543274,
 'Cobar': 1.1269573370839194,
 'CoffsHarbour': 4.664722536806342,
 'Dartmoor': 2.1072965388213283,
 'Darwin': 5.027389951089373,
 'GoldCoast': 3.647114030971375,
 'Hobart': 1.633318834275772,
 'Moree': 1.3129963898916965,
 'MountGambier': 1.9762770562770557,
 'Nhil': 0.8076470588235295,
 'NorahHead': 3.4462062256809336,
 'NorfolkIsland': 3.0599821348816434,
 'Nuriootpa': 1.2716460513796384,
 'PearceRAAF': 0.8476275738585497,
 'Perth': 1.8391752577319587,
 'PerthAirport': 1.8171142467000454,
 'Portland': 2.4377013963480128,
 'Richmond': 2.074204601076848,
 'Sydney': 3.3681476418318526,
 'SydneyAirport': 2.7280233122875184,
 'Townsville': 3.2214556163230963,
 'Tuggeranong': 2.082119205298013,
 'Uluru': 0.36897689768976893,
 '

In [157]:
# Quick range check of the per-day prediction arrays built in the plotting cell (In [158]):
# prints each array's caption followed by its minimum and maximum value.
for caption, values in (
    ("Pred Disp: min, max", pred_disp),
    ("Pred Mean: min, max", pred_mean),
    ("Prob: min, max", pred_prob),
):
    print(caption, values.min(), values.max())

Pred Disp: min, max 0.10657 1.201
Pred Mean: min, max 0.2095 4.805
Prob: min, max 9.316e-05 0.999


In [158]:

# Load the pickled test-set outputs and, for one location, plot predicted rainfall and observed
# rainfall (primary y-axis) plus predicted dispersion (secondary y-axis), `datums_in_plot`
# points per subplot row.

# Getting data
# NOTE(review): pickle.load can execute arbitrary code from the file; only load checkpoints you produced.
path_ = r"../Checkpoints/australia_rain/DGLM_HLSTM/lightning_logs/version_6/test_output.pkl"
with open(path_, "rb") as fh:  # `with` guarantees the file handle is closed
    test_output = pickle.load(fh)

location = "Cairns"
test_data = test_output[location]

# The predictions are a batch of 7-day windows, and each window overlaps the next by 6 days,
# so taking the same in-window day index from every window gives one value per calendar day.
day_idx = 2
pred_mean = test_data['pred_mean'][:, day_idx]
pred_prob = test_data['pred_prob'][:, day_idx]
pred_disp = test_data['pred_disp'][:, day_idx]
target_rain_value = test_data['target_rain_value'][:, day_idx]

# Predict zero rain where P(rain) < threshold, otherwise exp(pred_mean).
# NOTE(review): exp() implies pred_mean is on a log scale — confirm against the model's link function.
rain_prob_threshold = 0.5
pred_rain = np.where(pred_prob < rain_prob_threshold, 0, np.exp(pred_mean))

dates = [date_index[day_idx] for date_index in test_data['date'] if len(date_index) > day_idx]

# Setting up plot params
data_len = pred_mean.size
# If the filter above dropped any window, every later x value would be paired with the wrong y value.
assert len(dates) == data_len, f"{len(dates)} dates vs {data_len} predictions for {location}"

datums_in_plot = 120  # points per subplot
cols = 1
rows = max(1, math.ceil(data_len / (datums_in_plot * cols)))
row_height_px = 275  # figure height scales with the number of rows rather than a fixed 6600px

# Making figure
fig = make_subplots(
    rows=rows,
    cols=cols,
    start_cell="top-left",
    specs=[[{"secondary_y": True} for _ in range(cols)] for _ in range(rows)],
)
for row in range(1, rows + 1):
    for col in range(1, cols + 1):
        # Consecutive subplots (row-major) take consecutive chunks of `datums_in_plot` points.
        start = ((row - 1) * cols + (col - 1)) * datums_in_plot
        window = slice(start, start + datums_in_plot)
        x = dates[window]

        show_legend = (row == 1 and col == 1)  # one legend entry per trace group

        fig.add_trace(go.Scatter(x=x, y=pred_rain[window].tolist(), name='pred_rain', mode='lines',
                                 legendgroup="preds_rain", line=dict(color='blue'), showlegend=show_legend),
                      row=row, col=col)

        fig.add_trace(go.Scatter(x=x, y=target_rain_value[window].tolist(), name='obs', mode='lines',
                                 legendgroup="obs", line=dict(color='red'), showlegend=show_legend),
                      row=row, col=col)

        fig.add_trace(go.Scatter(x=x, y=pred_disp[window].tolist(), name='disp', mode='lines',
                                 legendgroup="disp", line=dict(color='purple'), showlegend=show_legend),
                      row=row, col=col, secondary_y=True)

fig.update_yaxes(title_text="rain", secondary_y=False)
fig.update_yaxes(title_text="dispersion", secondary_y=True)
fig.update_layout(height=rows * row_height_px, width=1600,
                  title_text=f"Predictions ({location}, window-day index {day_idx})", showlegend=True)
fig.show()