In [2]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import cm
%matplotlib inline
import os
import sys
from PIL import Image

sys.path.append('/eagle/MDClimSim/mjp5595/ml4dvar/stormer/')
from stormer_utils_pangu import StormerWrapperPangu
from varsStormer import varsStormer

# Experiment name; selects the sub-directory that holds this run's files.
save_dir_name = 'stormer_long_forecast'

# save_dir holds the forecast files read by the cells below; figures go to save_dir/plots.
save_dir = '/eagle/MDClimSim/mjp5595/data/stormer/{}/'.format(save_dir_name)
plot_dir = os.path.join(save_dir,'plots')
# exist_ok=True replaces the check-then-create pattern, which can fail if the
# directory appears between the existence check and the makedirs call.
os.makedirs(save_dir, exist_ok=True)
os.makedirs(plot_dir, exist_ok=True)

background_file_np = '/eagle/MDClimSim/mjp5595/ml4dvar/stormer/background_init_stormer_norm_hr12.npy' # This is just to initialize the model background
# Memory-map the array read-only rather than loading it into RAM.
background_f = np.load(background_file_np, mmap_mode='r')
print('background_f.shape :',background_f.shape)

  warn(f"Failed to load image Python extension: {e}")
  _torch_pytree._register_pytree_node(


background_f.shape : (1, 69, 128, 256)


### Plot Long Forecast

In [27]:
# Variable metadata from the project-local varsStormer helper.
# vars_stormer is indexed by channel position in this cell and used to build the
# 'input/<name>' keys read from the ERA5 files; vars_units is indexed the same
# way for axis/colorbar labels.
# NOTE(review): the attribute is spelled `var_units` while the local name is
# `vars_units`; two separate varsStormer() instances are created here.
vars_stormer = varsStormer().vars_stormer
vars_units = varsStormer().var_units

def read_era5(data,vars_stormer):
    """Stack the requested variables of one ERA5 sample into a single array.

    Parameters
    ----------
    data : mapping
        Object indexable with the string key ``'input/<name>'``; each entry is
        sliced with ``[:]`` and written into a (128, 256) slot of the result.
    vars_stormer : sequence
        Variable names, in the order the output channels should appear.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(vars_stormer), 128, 256)``.
    """
    n_vars = len(vars_stormer)
    stacked = np.zeros((n_vars, 128, 256))
    channel = 0
    for name in vars_stormer:
        key = 'input/{}'.format(name)
        # [:] materialises the entry before it is copied into the output
        stacked[channel, :, :] = data[key][:]
        channel += 1
    return stacked

# Lead times (hours) at which the forecast is compared against ERA5.
forecast_scales = [0, 480, 900, 1800, 2400]

# Forecast file with noise level 0. The handle is intentionally left open:
# the plotting loop below reads individual fields from it on demand.
forecast_orig = h5py.File(os.path.join(save_dir, 'raw_forecast_noise{}.h5'.format(0)),'r')

lat = np.load('/eagle/MDClimSim/troyarcomano/1.40625deg_npz_40shards/lat.npy')
lon = np.load('/eagle/MDClimSim/troyarcomano/1.40625deg_npz_40shards/lon.npy')

#era5_dir = '/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/train/'
era5_dir = '/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/test/'
era5_data = []
for fs in forecast_scales:
    print('loading fs_scale :',fs)
    # (fs//6)+2 turns a lead time in hours into a file index.
    # NOTE(review): presumably files are 6-hourly and the forecast starts at
    # sample 0002 - confirm against the forecast-generation script.
    era5_path = os.path.join(era5_dir,'2020_{:0>4d}.h5'.format((fs//6)+2))
    # read_era5 copies every field into a fresh numpy array, so the file can be
    # closed right away instead of leaking one open handle per lead time.
    with h5py.File(era5_path,'r') as era5_file:
        era5_data.append(read_era5(era5_file,vars_stormer))

# One PNG per selected variable. Rows are lead times; columns are
# ERA5 | forecast | increment (forecast - ERA5) with the running MSE overlaid.
n_steps = len(forecast_scales)
# The MSE curve lives on a twin axis that shares the map's x axis, so its
# abscissa is spread over 0..359 (presumably the longitude range of `lon`).
# Guard the single-lead-time case, which would otherwise divide by zero.
mse_x_step = 359/(n_steps-1) if n_steps > 1 else 0

for var_idx in [0,3,11]:
    var_name = vars_stormer[var_idx]
    var_unit = vars_units[var_idx]
    var_dir = os.path.join(plot_dir,'{}'.format(var_name))
    os.makedirs(var_dir, exist_ok=True)

    # Pass 1: ERA5 value range (printed for reference) and the MSE at every lead time.
    var_lim_min = float(np.inf)
    var_lim_max = float(-np.inf)
    step_mse = []
    for i,fs in enumerate(forecast_scales):
        truth = era5_data[i][var_idx]
        var_lim_min = min(var_lim_min,np.min(truth))
        var_lim_max = max(var_lim_max,np.max(truth))
        step_mse.append(np.mean(np.square(forecast_orig[str(fs)][var_idx]-truth)))
    print('var_lim_min :',var_lim_min)
    print('var_lim_max :',var_lim_max)

    # A log axis cannot start at 0 (matplotlib ignores such a limit and warns),
    # so bound it by the smallest and largest strictly positive MSE instead.
    positive_mse = [m for m in step_mse if m > 0]
    mse_ylim = (min(positive_mse),max(positive_mse)) if positive_mse else None

    mse_x = []
    mse_y = []
    # Row count follows forecast_scales; the height keeps the original 13in for 5 rows.
    # squeeze=False keeps axs_gif two-dimensional even for a single row.
    fig_gif, axs_gif = plt.subplots(n_steps,3,sharex = True, sharey = False,
                                    figsize=(15,13*n_steps/5), squeeze=False)
    for i,fs in enumerate(forecast_scales):
        truth = era5_data[i][var_idx]
        # Read the forecast field from the HDF5 file once per row.
        forecast_field = forecast_orig[str(fs)][var_idx]
        ax_truth, ax_fc, ax_inc = axs_gif[i]

        pc_truth = ax_truth.pcolormesh(lon, lat, truth, cmap='viridis')
        plt.colorbar(pc_truth, ax = ax_truth,label=var_unit)
        ax_truth.set_title('ERA5 Data')
        ax_truth.set_xticks(np.linspace(0,360,9))

        pc_fc = ax_fc.pcolormesh(lon, lat, forecast_field, cmap='viridis')
        plt.colorbar(pc_fc, ax = ax_fc,label=var_unit)
        ax_fc.set_title('ERA5 Forecast {}hrs'.format(fs))
        ax_fc.set_xticks(np.linspace(0,360,9))
        ax_fc.set_yticklabels([])

        increment = forecast_field - truth
        pc_inc0 = ax_inc.pcolormesh(lon, lat, increment,
                                    cmap='RdYlBu_r',
                                    norm=colors.SymLogNorm(linthresh=1),
                                    )
        plt.colorbar(pc_inc0, ax = ax_inc, label=var_unit)
        ax_inc.set_title('Increment (Pred Forecast - GT) {}hrs'.format(fs))
        ax_inc.set_yticks([])

        # Overlay the MSE accumulated up to this lead time on a log-scaled twin axis.
        axs_mses = ax_inc.twinx()
        axs_mses.set_yscale('log')
        if mse_ylim is not None:
            axs_mses.set_ylim(*mse_ylim)
        axs_mses.yaxis.tick_left()
        axs_mses.yaxis.set_label_position('left')
        axs_mses.set_ylabel('Mean Squared Increment ({})'.format(var_unit))
        mse_x.append(i*mse_x_step)
        mse_y.append(step_mse[i])
        axs_mses.plot(mse_x,mse_y,c='k')

    fig_gif.suptitle('Stormer Long Forecast ({})'.format(var_name),fontsize=20)
    plt.tight_layout()

    plt.savefig(os.path.join(var_dir,'long_forecast_{}.png'.format(var_name)))
    # Close this specific figure so figures do not pile up across variables.
    plt.close(fig_gif)

loading fs_scale : 0
loading fs_scale : 480
loading fs_scale : 900
loading fs_scale : 1800
loading fs_scale : 2400
var_lim_min : 202.84988403320312
var_lim_max : 315.7083740234375


  plt.colorbar(pc_inc0, ax = axs_gif[i,2], label=vars_units[var_idx])
  axs_mses.set_ylim(0,inc_mse_max)


var_lim_min : 93197.0859375
var_lim_max : 104515.1171875
var_lim_min : 45695.75
var_lim_max : 58394.2890625


### Make Perturbation Forecast Gif

In [10]:
# Variable metadata from the project-local varsStormer helper.
# vars_stormer is indexed by channel position in this cell and used to build the
# 'input/<name>' keys read from the ERA5 files; vars_units is indexed the same
# way for axis/colorbar labels.
# NOTE(review): the attribute is spelled `var_units` while the local name is
# `vars_units`; two separate varsStormer() instances are created here.
vars_stormer = varsStormer().vars_stormer
vars_units = varsStormer().var_units

def read_era5(data,vars_stormer):
    """Stack the requested variables of one ERA5 sample into a single array.

    Parameters
    ----------
    data : mapping
        Object indexable with the string key ``'input/<name>'``; each entry is
        sliced with ``[:]`` and written into a (128, 256) slot of the result.
    vars_stormer : sequence
        Variable names, in the order the output channels should appear.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(vars_stormer), 128, 256)``.
    """
    n_vars = len(vars_stormer)
    stacked = np.zeros((n_vars, 128, 256))
    channel = 0
    for name in vars_stormer:
        key = 'input/{}'.format(name)
        # [:] materialises the entry before it is copied into the output
        stacked[channel, :, :] = data[key][:]
        channel += 1
    return stacked

# One GIF frame every 24 h, from 0 h up to (but excluding) 2400 h.
forecast_scales = np.arange(0,2400,24)

# Forecast file with noise level 0. The handle is intentionally left open:
# the frame loop below reads individual fields from it on demand.
forecast_orig = h5py.File(os.path.join(save_dir, 'raw_forecast_noise{}.h5'.format(0)),'r')

lat = np.load('/eagle/MDClimSim/troyarcomano/1.40625deg_npz_40shards/lat.npy')
lon = np.load('/eagle/MDClimSim/troyarcomano/1.40625deg_npz_40shards/lon.npy')

#era5_dir = '/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/train/'
era5_dir = '/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/test/'
era5_data = []
for fs in forecast_scales:
    # (fs//6)+2 turns a lead time in hours into a file index.
    # NOTE(review): presumably files are 6-hourly and the forecast starts at
    # sample 0002 - confirm against the forecast-generation script.
    era5_path = os.path.join(era5_dir,'2020_{:0>4d}.h5'.format((fs//6)+2))
    # read_era5 copies every field into a fresh numpy array, so the file can be
    # closed right away; otherwise this loop leaves one open handle per frame.
    with h5py.File(era5_path,'r') as era5_file:
        era5_data.append(read_era5(era5_file,vars_stormer))

# One animated GIF per selected variable. Each frame shows
# ERA5 | forecast | increment (forecast - ERA5) at one lead time, with the
# running MSE (green, log axis) and the forecast mean (magenta) overlaid.
n_steps = len(forecast_scales)
# The overlay curves live on twin axes sharing the map's x axis, so their
# abscissa is spread over 0..359 (presumably the longitude range of `lon`).
# Guard the single-lead-time case, which would otherwise divide by zero.
mse_x_step = 359/(n_steps-1) if n_steps > 1 else 0

for var_idx in [0,3,11]:
    var_name = vars_stormer[var_idx]
    var_unit = vars_units[var_idx]
    var_dir = os.path.join(plot_dir,'{}'.format(var_name))
    os.makedirs(var_dir, exist_ok=True)

    # Pass 1: ERA5 value range (shared colour limits of the first panel) plus
    # the MSE and forecast mean at every lead time, reading each field once.
    var_lim_min = float(np.inf)
    var_lim_max = float(-np.inf)
    step_mse = []
    step_mean = []
    for i,fs in enumerate(forecast_scales):
        truth = era5_data[i][var_idx]
        forecast_field = forecast_orig[str(fs)][var_idx]
        var_lim_min = min(var_lim_min,np.min(truth))
        var_lim_max = max(var_lim_max,np.max(truth))
        step_mse.append(np.mean(np.square(forecast_field-truth)))
        step_mean.append(np.mean(forecast_field))
    print('var_lim_min :',var_lim_min)
    print('var_lim_max :',var_lim_max)

    # A log axis cannot start at 0 (matplotlib ignores such a limit and warns),
    # so bound it by the smallest and largest strictly positive MSE instead.
    positive_mse = [m for m in step_mse if m > 0]
    mse_ylim = (min(positive_mse),max(positive_mse)) if positive_mse else None
    # Fixed y-range for the mean curve. The previous code derived the maximum
    # from the running *minimum*, which produced a wrong upper limit.
    mean_ylim = (min(step_mean),max(step_mean))

    gif_files = []
    mse_x = []
    mse_y = []
    mse_y_mv = []
    for i,fs in enumerate(forecast_scales):
        truth = era5_data[i][var_idx]
        forecast_field = forecast_orig[str(fs)][var_idx]

        fig_gif, axs_gif = plt.subplots(1,3,sharex = True, sharey = False, figsize=(15,4))
        pc_truth = axs_gif[0].pcolormesh(lon, lat, truth, cmap='viridis', vmin=var_lim_min, vmax=var_lim_max)
        plt.colorbar(pc_truth, ax = axs_gif[0],label=var_unit)
        axs_gif[0].set_title('ERA5 Data')
        axs_gif[0].set_xticks(np.linspace(0,360,9))

        pc_fc = axs_gif[1].pcolormesh(lon, lat, forecast_field, cmap='viridis')
        plt.colorbar(pc_fc, ax = axs_gif[1],label=var_unit)
        axs_gif[1].set_title('ERA5 Forecast {}hrs'.format(fs))
        axs_gif[1].set_xticks(np.linspace(0,360,9))
        axs_gif[1].set_yticklabels([])

        increment = forecast_field - truth
        pc_inc0 = axs_gif[2].pcolormesh(lon, lat, increment,
                                        cmap='RdYlBu_r',
                                        norm=colors.SymLogNorm(linthresh=1),
                                        )
        plt.colorbar(pc_inc0, ax = axs_gif[2], label=var_unit,pad=0.2)
        axs_gif[2].set_title('Increment (Pred Forecast - GT) {}hrs'.format(fs))
        axs_gif[2].set_yticks([])

        # Left twin axis: MSE accumulated up to this frame, log scale.
        axs_mses = axs_gif[2].twinx()
        axs_mses.tick_params(axis='y',colors='green')
        axs_mses.set_yscale('log')
        if mse_ylim is not None:
            axs_mses.set_ylim(*mse_ylim)
        axs_mses.yaxis.tick_left()
        axs_mses.yaxis.set_label_position('left')
        axs_mses.set_ylabel('Mean Squared Increment ({})'.format(var_unit),color='green')
        mse_x.append(i*mse_x_step)
        mse_y.append(step_mse[i])
        axs_mses.plot(mse_x,mse_y,c='green')

        # Right twin axis: mean of the forecast field up to this frame, linear scale.
        axs_mses_mv = axs_gif[2].twinx()
        axs_mses_mv.set_yscale('linear')
        axs_mses_mv.tick_params(axis='y',colors='magenta')
        axs_mses_mv.set_ylim(*mean_ylim)
        axs_mses_mv.yaxis.tick_right()
        axs_mses_mv.yaxis.set_label_position('right')
        axs_mses_mv.set_ylabel('Mean {} ({})'.format(var_name,var_unit),color='magenta')
        mse_y_mv.append(step_mean[i])
        axs_mses_mv.plot(mse_x,mse_y_mv,c='magenta')

        fig_gif.suptitle('Stormer Long Forecast ({}) - {} hrs'.format(var_name,fs),fontsize=20)
        plt.tight_layout()

        frame_path = os.path.join(var_dir,'{}.png'.format(i))
        plt.savefig(frame_path)
        # Close this specific figure; one is created per frame.
        plt.close(fig_gif)

        gif_files.append(frame_path)

    gif_imgs = [Image.open(gif_f) for gif_f in gif_files]
    try:
        # Hold the first and last frame for five extra ticks each so the
        # animation pauses at both ends.
        sequence = [gif_imgs[0]]*5 + gif_imgs + [gif_imgs[-1]]*5
        sequence[0].save(os.path.join(var_dir,'long_forecast_{}.gif'.format(var_name)),
                         save_all=True, append_images=sequence[1:], optimize=False, duration=500, loop=0)
    finally:
        # Release the file handle behind every frame image.
        for frame in gif_imgs:
            frame.close()

var_lim_min : 201.00021362304688
var_lim_max : 318.3009948730469


  plt.colorbar(pc_inc0, ax = axs_gif[2], label=vars_units[var_idx],pad=0.2)
  axs_mses.set_ylim(0,inc_mse_max)


var_lim_min : 92639.390625
var_lim_max : 106802.2265625
var_lim_min : 44809.22265625
var_lim_max : 58687.70703125
