In [1]:
import jax.numpy as np
from jax import random, grad, vmap, jit
from jax.config import config

import numpy as onp

import matplotlib

import matplotlib.pyplot as plt

import scipy.io
from scipy.interpolate import griddata


In [2]:
# Plot styling: start from matplotlib defaults, then render all text with LaTeX in a serif font.
# NOTE(review): "text.usetex": True requires a working LaTeX installation.
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    'text.latex.preamble': r'\usepackage{amsmath}',
    'font.size': 16,
    'lines.linewidth': 3,
    'axes.labelsize': 20,
    'axes.titlesize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'axes.linewidth': 2})

# Grid / split sizes.
P = 72            # points per side of the 72 x 72 lat/lon grid
m = P * P         # grid points per sample (5184)
N_train = 1825
N_test = 1825

# 72 longitudes 0..355 (5 deg steps) and 72 latitudes 90 N .. 87.5 S (2.5 deg steps).
lon = np.linspace(0, 355, num=72)
lat = np.linspace(90, -87.5, num=72)
lons, lats = np.meshgrid(lon, lat)

# Read every array once and close the archive: an NpzFile keeps its file handle open.
with onp.load("weather_dataset.npz") as d:
    U_all = d["U_train"]
    S_all = d["S_train"]
    Y_all = d["Y_train"]

# Training split: the first N_train samples. S is divided by 1000 (presumably Pa -> kPa; confirm).
u_train = np.array(U_all[:N_train, :])
S_train = np.array(S_all[:N_train, :] / 1000.)
y_train = np.array(Y_all)

# Test split: the last N_test samples of the *same* arrays.
# NOTE(review): the two splits overlap unless the file holds at least N_train + N_test samples
# (with 1825 samples they are identical) -- confirm this is not train/test leakage.
u_test = np.array(U_all[-N_test:, :])
S_test = np.array(S_all[-N_test:, :] / 1000.)
y_test = np.array(Y_all)

# Standardise inputs and outputs with statistics computed on the training split only.
u_mu_total, u_std_total = np.mean(u_train), np.std(u_train)
s_mu_total, s_std_total = np.mean(S_train), np.std(S_train)

u_train = (u_train - u_mu_total) / u_std_total
s_train = (S_train - s_mu_total) / s_std_total

u_test = (u_test - u_mu_total) / u_std_total
s_test = (S_test - s_mu_total) / s_std_total

print("normalizing constant", u_mu_total, u_std_total, s_mu_total, s_std_total)

print("shape of training data", u_train.shape, S_train.shape, y_train.shape)
print("shape of testing data", u_test.shape, S_test.shape, y_test.shape)



normalizing constant 278.68533 20.77177 96.8152 9.130893
shape of training data (1825, 5184) (1825, 5184) (5184, 2)
shape of testing data (1825, 5184) (1825, 5184) (5184, 2)


In [3]:
# Saved model outputs for the test set: predictive mean and std (in standardised units,
# un-normalised later with s_mu_total / s_std_total) and the per-sample normalised errors.
Predict_mu_save, Predict_std_save, errors = [
    np.load(fname)
    for fname in ("Predict_mu_save.npy", "Predict_std_save.npy", "normed_errors.npy")
]

print(*(arr.shape for arr in (Predict_mu_save, Predict_std_save, errors)))
print(np.max(errors))

(1825, 72, 72) (1825, 72, 72) (1825, 1)
0.13143897


In [4]:
def format_func(value, tick_number=None):
    """Matplotlib ``FuncFormatter`` callback that labels a longitude tick.

    The tick is first wrapped into [0, 360) and then written as ``0°``,
    ``180°``, ``<x>° E`` (0 < x < 180) or ``<360 - x>° W`` (180 < x < 360).

    Parameters
    ----------
    value : float
        Tick position in degrees east. It may lie outside [0, 360) when the
        axis is padded or panned.
    tick_number : int, optional
        Tick index supplied by ``FuncFormatter``; unused.

    Returns
    -------
    str
        LaTeX-formatted tick label.
    """
    # Wrap so that e.g. -60 -> 300 (60° W) and 400 -> 40 (40° E) instead of
    # "-60° E" / "-40° W". Float modulo can return exactly 360.0 for tiny
    # negative inputs, hence the membership test below.
    lon_deg = float(value) % 360.
    if lon_deg in (0., 360.):
        return "$0^o$"
    if lon_deg == 180.:
        return "$180^o$"
    if lon_deg < 180.:
        return "%g$^o$ E" % lon_deg
    # Subtract before formatting so fractional ticks are not truncated
    # (337.5 -> 22.5° W, not 23° W).
    return "%g$^o$ W" % (360. - lon_deg)



In [5]:
# Latitude rows (indices into `lat`, which steps 2.5 deg down from 90 N) used for the middle and
# right longitude-slice panels: row 35 -> 2.5 N, row 50 -> 35 S.
# Candidates tried for idx3: 50 (good), 52 (useful), 58 / 60 / 62 (worst); 54, 56, 64-70 unrated.
idx2, idx3 = 35, 50

In [6]:
# Test samples to visualise: every 100th index from 30 to 1730 (18 samples).
idxs = list(range(30, 1731, 100))

# Latitude rows for the three longitude-slice panels (row 20 -> 40 N; idx2 / idx3 set above).
slice_rows = (20, idx2, idx3)


def _lat_title(row):
    """Slice-panel title computed from the latitude of grid row `row`, e.g. 'lat = $40^o$ N'.

    Derived from `lat` so the title always matches the plotted row
    (lat[35] is 2.5 N and lat[50] is 35 S, not 0 and 40 S).
    """
    lat_deg = round(float(lat[row]), 2)
    if lat_deg == 0.:
        return 'lat = $0^o$'
    return 'lat = $%g^o$ %s' % (abs(lat_deg), 'N' if lat_deg > 0 else 'S')


def _plot_field_panels(exact, pred_mu, pred_std):
    """Return a 2x2 figure: exact field, predictive mean, absolute error, predictive std.

    Each argument is a 2-D field plotted against the `lons` / `lats` meshgrid.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 9))
    panels = (
        (exact, 'Exact $s(x)$'),
        (pred_mu, 'Predictive $s(x)$'),
        (np.abs(pred_mu - exact), 'Absolute error'),
        (pred_std, 'Predictive uncertainty'),
    )
    for ax, (field, title) in zip(axes.flat, panels):
        # shading='auto' centres cells on the grid points; implicit flat shading with
        # same-shaped X / Y / C is deprecated since matplotlib 3.3 and rejected later.
        mesh = ax.pcolor(lons, lats, field, cmap='jet', shading='auto')
        ax.set_xlabel('lon')
        ax.set_ylabel('lat')
        ax.set_title(title)
        fig.colorbar(mesh, ax=ax)
    fig.tight_layout()
    return fig


def _plot_lat_slices(exact, pred_mu, pred_std):
    """Return a 1xN figure of longitude profiles at the rows in `slice_rows`.

    Each panel shows the exact profile, the predictive mean and a +/- 2 std band.
    """
    fig, axs = plt.subplots(nrows=1, ncols=len(slice_rows), figsize=(15, 4), squeeze=False)
    axs = axs[0]
    for ax, row in zip(axs, slice_rows):
        mu, sd = pred_mu[row, :], pred_std[row, :]
        ax.plot(lon, exact[row, :], 'b-', linewidth=2, label='Exact')
        ax.plot(lon, mu, 'r--', linewidth=2, label='Prediction')
        ax.fill_between(lon.flatten(), (mu - 2.0 * sd).flatten(), (mu + 2.0 * sd).flatten(),
                        facecolor='orange', alpha=0.5, label="Two std band")
        ax.set_xlabel('lon')
        ax.set_ylabel('$s(x)$')
        ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
        ax.set_title(_lat_title(row))
    # One shared legend, placed below the middle panel.
    axs[len(axs) // 2].legend(loc='upper center', bbox_to_anchor=(0.5, -0.35), ncol=5, frameon=False)
    fig.tight_layout(w_pad=-4.5)
    return fig


for idx in idxs:
    if not 0 <= idx < N_test:
        continue  # indices outside the test set are ignored, as with the original range(N_test) scan

    # Undo the standardisation, then divide by 1000 (the max-error cell titles this unit MPa).
    s_test_sample = (s_test[idx, :] * s_std_total + s_mu_total) / 1000.
    S_pred_sample_mu = (Predict_mu_save[idx, :, :] * s_std_total + s_mu_total) / 1000.
    S_pred_sample_std = Predict_std_save[idx, :, :] * s_std_total / 1000.
    u = np.reshape(s_test_sample, (72, 72))

    # Save and close each figure right away so the loop does not keep 2 * len(idxs)
    # figures open (matplotlib warns once more than 20 are open).
    fig = _plot_field_panels(u, S_pred_sample_mu, S_pred_sample_std)
    fig.savefig('./normed_testing_Samples' + str(idx) + '.png', dpi=300)
    plt.close(fig)

    fig = _plot_lat_slices(u, S_pred_sample_mu, S_pred_sample_std)
    fig.savefig('./normed_testing_slices' + str(idx) + '.png')
    plt.close(fig)


  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app


In [7]:
import os
import pathlib
import conda

# Basemap needs PROJ's data directory. Point PROJ_LIB at <conda prefix>/share/proj, deriving the
# prefix from where the `conda` package is installed (assumes the usual
# <prefix>/lib/pythonX.Y/site-packages/conda/__init__.py layout -- confirm on other platforms).
conda_file_dir = conda.__file__
_conda_parts = pathlib.PurePath(conda_file_dir).parts
# Cut at the last path component named 'lib' (case-insensitive, for Windows 'Lib') instead of
# splitting on the substring 'lib', which also matches e.g. '/home/alibaba/...'.
_lib_positions = [i for i, part in enumerate(_conda_parts) if part.lower() == 'lib']
if _lib_positions:
    conda_dir = str(pathlib.PurePath(*_conda_parts[:_lib_positions[-1]]))
else:
    conda_dir = conda_file_dir.split('lib')[0]  # previous heuristic as a fallback
proj_lib = os.path.join(conda_dir, 'share', 'proj')
os.environ["PROJ_LIB"] = proj_lib

# These imports must stay after PROJ_LIB is set: Basemap locates PROJ data at import time.
from scipy.sparse import diags
import string
from netCDF4 import Dataset as NetCDFFile
from mpl_toolkits.basemap import Basemap


# Pick the test sample with the largest normalised error.
idx_max = np.argmax(errors)
idx = idx_max

# Back to physical units: undo the standardisation, then divide by 1000.
# The std is only rescaled, never shifted.
s_test_sample, S_pred_sample_mu, S_pred_sample_std = (
    (s_test[idx, :] * s_std_total + s_mu_total) / 1000.,
    (Predict_mu_save[idx, :, :] * s_std_total + s_mu_total) / 1000.,
    Predict_std_save[idx, :, :] * s_std_total / 1000.,
)
u = s_test_sample.reshape(72, 72)

# Coordinate axes of the 72 x 72 grid (5 deg in longitude, 2.5 deg in latitude).
lat = np.linspace(90, -87.5, num=72)
lon = np.linspace(0, 355, num=72)



fig = plt.figure(figsize=(12, 6))

# Graticule shared by all four panels.
parallels = np.arange(-82.5, 82.5, 30)
meridians = np.arange(0., 355, 30)

# With the 'cyl' projection, map coordinates are plain (lon, lat) degrees, so the meshgrid is
# passed directly (no bmap(lons, lats) transform needed).
lons, lats = np.meshgrid(lon, lat)
x, y = lons, lats

# NOTE(review): "MPa" assumes the raw surface pressure was stored in Pa (two /1000 steps); confirm.
panels = (
    (221, u, 'Exact surface pressure in MPa'),
    (222, S_pred_sample_mu, 'Predictive mean of surface pressure in MPa'),
    (223, np.abs(S_pred_sample_mu - u), 'Absolute error of surface pressure in MPa'),
    (224, S_pred_sample_std, 'Predictive uncertainty of surface pressure in MPa'),
)

for subplot_pos, field, title in panels:
    ax = fig.add_subplot(subplot_pos)
    # Named `bmap` (not `map`) so the built-in map() is not shadowed in the kernel; bound to
    # `ax` explicitly instead of relying on the current axes.
    bmap = Basemap(projection='cyl', llcrnrlon=0., llcrnrlat=-85., urcrnrlon=360., urcrnrlat=85.,
                   resolution='i', ax=ax)
    bmap.drawcoastlines()
    bmap.drawstates()
    bmap.drawcountries()
    bmap.drawlsmask(land_color='Linen', ocean_color='#CCFFFF')
    bmap.drawparallels(parallels, labels=[1, 0, 0, 0], fontsize=10)
    bmap.drawmeridians(meridians, labels=[0, 0, 0, 1], fontsize=10)

    temp = bmap.contourf(x, y, field)
    cb = bmap.colorbar(temp, "right", size="2%", pad="0%")
    ax.set_title(title)

fig.tight_layout()
fig.savefig('./normed_testing_Samples_max.png', dpi=300)







# Longitude profiles of the max-error sample at three latitude rows, each with a +/- 2 std band.
fig, _axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 4))
axs = _axs.flatten()

for ax, row in zip(axs, (20, idx2, idx3)):
    lower = S_pred_sample_mu[row, :] - 2.0 * S_pred_sample_std[row, :]
    upper = S_pred_sample_mu[row, :] + 2.0 * S_pred_sample_std[row, :]
    ax.plot(lon, u[row, :], 'b-', linewidth=2, label='Exact')
    ax.plot(lon, S_pred_sample_mu[row, :], 'r--', linewidth=2, label='Prediction')
    ax.fill_between(lon.flatten(), lower.flatten(), upper.flatten(),
                    facecolor='orange', alpha=0.5, label="Two std band")
    ax.set_xlabel('lon')
    ax.set_ylabel('$s(x)$')
    ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
    # Title derived from the grid so it matches the plotted row; with 2.5 deg steps from 90 N,
    # rows 20 / 35 / 50 are 40 N / 2.5 N / 35 S (not the previously hard-coded 0 and 40 S).
    lat_deg = round(float(lat[row]), 2)
    hemisphere = 'N' if lat_deg > 0 else ('S' if lat_deg < 0 else '')
    ax.set_title(('lat = $%g^o$ %s' % (abs(lat_deg), hemisphere)).rstrip())

ax1, ax2, ax3 = axs
# Single legend for all three panels, placed below the middle one.
ax2.legend(loc='upper center', bbox_to_anchor=(0.5, -0.35), ncol=5, frameon=False)

fig.tight_layout(w_pad=-4.5)
fig.savefig('./normed_testing_slices_max.png')


