In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchinfo import summary

import os
import time
import datetime as dt
from netCDF4 import Dataset as nc_Dataset
import pandas as pd
import numpy as np
import xarray as xr
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from HRRR_URMA_Datasets import *
from SR_UNet_simple import *
from utils import *

# Seed torch's global RNG so any torch-side randomness in this notebook is reproducible.
torch.manual_seed(42)

In [None]:
def get_model_params(savename):
    """Decode which optional input channels a trained model uses from its filename.

    Each flag is switched on purely by the presence of a marker substring in
    ``savename``: "tHRRR1", "tURMA1", "tDIFF1", "sigYr1", "sigHr1".

    Returns:
        tuple (tHRRR, tURMA, tDIFF, sigYr, sigHr): the three terrain flags are
        0 or 1; the two time-signal entries are 0 or 2, because each time signal
        contributes both a sin AND a cos channel.
    """
    def _flag(marker, value_when_present=1):
        # Substring match on the full filename; absent marker -> 0
        return value_when_present if marker in savename else 0

    return (
        _flag("tHRRR1"),
        _flag("tURMA1"),
        _flag("tDIFF1"),
        _flag("sigYr1", 2),
        _flag("sigHr1", 2),
    )

In [None]:
def plot_prediction(X, pred, y, date_str="DATE", savename="MODEL_NAME", to_save=False, save_path="temp.png"):
    """Plot predictor, prediction, truth and prediction error side by side.

    Args:
        X: predictor (HRRR) field as a numpy array; singleton dims are squeezed away.
        pred: model prediction as a numpy array (squeezed).
        y: truth (URMA) field as a numpy array (squeezed).
        date_str: date string shown in the figure title.
        savename: model name shown in the figure title.
        to_save: if True, write the figure to ``save_path`` before showing it.
        save_path: output image path. Defaults to "temp.png" (the previously
            hard-coded name); pass a distinct path to avoid overwriting earlier plots.
    """
    X_sq, pred_sq, y_sq = X.squeeze(), pred.squeeze(), y.squeeze()
    err = pred_sq - y_sq
    rmse = np.sqrt(np.mean(err**2))

    # One shared colour scale for the three temperature panels
    maxtemp = np.max([np.max(X_sq), np.max(y_sq), np.max(pred_sq)])
    mintemp = np.min([np.min(X_sq), np.min(y_sq), np.min(pred_sq)])
    # Error panel is centred on zero; half-width is 1/10 of the temperature range
    # (may need adjusting for very small/large errors)
    err_lim = (maxtemp - mintemp) / 10

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = (
        (X_sq, "Predictor temp. (HRRR 2.5km)"),
        (pred_sq, "Predicted temp."),
        (y_sq, "True temp. (URMA)"),
    )
    for ax, (field, title) in zip(axes[:3], panels):
        ax.imshow(field, cmap="coolwarm", vmin=mintemp, vmax=maxtemp)
        ax.set_title(title)
        ax.axis("off")

    pos = axes[3].imshow(err, cmap="coolwarm", vmin=-err_lim, vmax=err_lim)
    axes[3].set_title(f"Prediction - Truth (RMSE = {rmse:.4f})")
    axes[3].axis("off")

    cbar = fig.colorbar(pos, ax=axes[3], fraction=0.03)
    cbar.set_label('Error')

    fig.suptitle(f"{savename} | Date = {date_str} | Maximum = {maxtemp:.1f} | Minimum = {mintemp:.1f}", va="bottom", fontsize=14)
    fig.tight_layout()

    if to_save:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    plt.show()

In [None]:
def get_model_output(train_ds, model, idx=0, is_unnormed=True, device=None):
    """Run ``model`` on sample ``idx`` of ``train_ds``, optionally undoing normalisation.

    Args:
        train_ds: dataset whose ``__getitem__`` returns ``(X, y)`` as numpy arrays and
            which exposes ``xr_dataset_pred``, ``hours`` and the per-hour statistics
            ``dataset_pred_normed_{means,stddevs}`` / ``dataset_targ_normed_{means,stddevs}``.
        model: torch module producing the prediction.
        idx: sample index into ``train_ds``.
        is_unnormed: if True, X (channel 0 only), y and pred are mapped back with
            ``stddev * value + mean`` using the statistics for the sample's valid hour.
        device: device to run inference on. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters), so no module-level
            ``device`` variable is required.

    Returns:
        (X, y, pred, dt_current): X is input channel 0 of the batched sample, y the
        target, pred the model output as numpy (not squeezed), and dt_current the
        sample's valid time as a ``datetime.datetime``.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = "cpu"

    X, y = train_ds[idx]
    X = X[np.newaxis, :]  # add batch dimension
    X_dev = torch.from_numpy(X).float().to(device)

    with torch.no_grad():
        pred = model(X_dev).cpu().numpy()

    date = train_ds.xr_dataset_pred[idx].valid_time.data
    dt_current = dt.datetime.strptime(str(np.datetime_as_string(date, unit='s')), "%Y-%m-%dT%H:%M:%S")
    # Position of this hour in the dataset's hour list (works for non-"all" hour subsets)
    hour_idx = np.where(np.array(train_ds.hours) == dt_current.hour)[0][0]

    if is_unnormed:
        X = train_ds.dataset_pred_normed_stddevs[hour_idx]*X[0,0,:] + train_ds.dataset_pred_normed_means[hour_idx]
        y = train_ds.dataset_targ_normed_stddevs[hour_idx]*y + train_ds.dataset_targ_normed_means[hour_idx]
        pred = train_ds.dataset_targ_normed_stddevs[hour_idx]*pred + train_ds.dataset_targ_normed_means[hour_idx]
    else:
        X = X[0,0,:]

    return X, y, pred, dt_current

In [None]:
# Training settings; per the model-naming template they appear in saved filenames as "_BS{..}_NE{..}_"
BATCH_SIZE, NUM_EPOCHS = 256, 1000

# Hours passed to the dataset in the all-hours evaluation cell
HOURS = "all"

In [None]:
#Loop over all possibilities
# For each month range (all hours), evaluate every trained model on the training split and save:
#   - whisker plots of the per-pixel errors and of the per-time RMSE, one box per model
#   - maps of the time-mean error and of the per-pixel RMSE, one panel per model

MODEL_DIR = "/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models"
PLOT_DIR = "/scratch/RTMA/alex.schein/hrrr_CNN_testing/Plots"
device = "cuda:3"  # kept module-level: get_model_output may read `device` as a global

# Sorted so models appear in the same order in every plot (os.listdir order is arbitrary).
# Only models trained for NUM_EPOCHS epochs are compared.
model_names = sorted(name for name in os.listdir(MODEL_DIR) if f"NE{NUM_EPOCHS}" in name)

for MONTHS in [[1,12], [7,9], [10,12], [1,3], [4,6]]:
    errors_list = []          # all pixel errors (pred - truth), flattened, one array per model
    rmse_list = []            # RMSE over the grid for each time, one array per model
    spatial_errors_list = []  # time-mean error map, one per model
    spatial_rmse_list = []    # per-pixel RMSE over time, one per model
    labels_list = []
    n_times = 0

    for pred_model_name in model_names:
        tHRRR, tURMA, tDIFF, sigYr, sigHr = get_model_params(pred_model_name)
        label = f"H{tHRRR}/U{tURMA}/D{tDIFF}/Yr{int(bool(sigYr))}/Hr{int(bool(sigHr))}"

        train_ds = HRRR_URMA_Dataset_AllTimes_AnyDates_AnyTerrains(
            is_train=True,
            months=MONTHS,
            hours=HOURS,
            # NOTE(review): terrain set is chosen from tHRRR alone; tURMA/tDIFF are not consulted
            with_terrains=["hrrr", "urma", "diff"] if tHRRR else ["diff"],
            with_yearly_time_sig=sigYr,
            with_hourly_time_sig=sigHr,
        )

        pred_model = SR_UNet_simple(n_channels_in=np.shape(train_ds[0][0])[0])
        pred_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, pred_model_name), weights_only=True))
        pred_model.to(device)
        pred_model.eval()  # inference mode, in case the network has dropout / batch-norm layers

        # Iterate by index (assumes the dataset defines __len__) so each sample is loaded
        # once, inside get_model_output, instead of also by an enumerate() over the dataset.
        n_times = len(train_ds)
        per_time_errors = []
        for idx in range(n_times):
            _, y_unnormed, pred_unnormed, _ = get_model_output(train_ds, model=pred_model, idx=idx, is_unnormed=True)
            per_time_errors.append(pred_unnormed - y_unnormed)
        time_errors = np.stack(per_time_errors)          # (n_times, *error_shape)
        flat_errors = time_errors.reshape(n_times, -1)   # (n_times, n_pixels)

        errors_list.append(flat_errors.ravel())
        # RMSE = sqrt(mean(err**2)); sqrt(mean(err)**2) would only be |mean error|
        rmse_list.append(np.sqrt(np.mean(flat_errors**2, axis=1)))
        spatial_errors_list.append(time_errors.mean(axis=0))
        spatial_rmse_list.append(np.sqrt(np.mean(time_errors**2, axis=0)))
        labels_list.append(label)
        print(f"{label} DONE")
    print("ALL COMBOS DONE")

    if not labels_list:
        print(f"No models matched in {MODEL_DIR}; skipping plots for months {MONTHS[0]}-{MONTHS[1]}")
        continue

    print("Starting label creation")
    labels_with_quantiles = []
    RMSE_labels_with_quantiles = []
    for label, errs, rmses in zip(labels_list, errors_list, rmse_list):
        for out, values in ((labels_with_quantiles, errs), (RMSE_labels_with_quantiles, rmses)):
            q1, q2, q3 = np.quantile(values, [0.25, 0.50, 0.75])  # one pass for all three quantiles
            out.append(label + f" \n 25%={q1:.4f} \n 50% = {q2:.4f} \n 75% = {q3:.4f} \n Q3-Q1={(q3-q1):.4f}")

    month_tag = f"months{MONTHS[0]}-{MONTHS[1]}"

    # Whisker plots: boxes span Q1-Q3, whiskers the 5th-95th percentiles, outliers hidden
    for kind, data, tick_labels, file_tag in (
        ("Error", errors_list, labels_with_quantiles, "error"),
        ("RMSE", rmse_list, RMSE_labels_with_quantiles, "rmse"),
    ):
        print(f"Making {kind} whisker plot")
        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        ax.boxplot(data, tick_labels=tick_labels, sym="", whis=(5, 95))
        ax.set_ylabel(f"Model {kind} (°C) (num. idxs = {n_times})")
        fig.suptitle(f"{kind} for Various Models, all hours, months {MONTHS[0]}-{MONTHS[1]} 2021/22/23", fontsize=14)
        fig.tight_layout()
        fig.savefig(f"{PLOT_DIR}/whiskerplot_{file_tag}_{month_tag}.png", dpi=300)
        plt.close(fig)  # saved to disk; closing prevents open figures piling up across the loop

    # Spatial maps, one panel per model
    n_models = len(labels_list)
    for title_word, fields, cmap, file_tag, centred in (
        ("Errors", spatial_errors_list, "coolwarm", "error", True),
        ("RMSE", spatial_rmse_list, "jet", "rmse", False),
    ):
        print(f"Making spatial {title_word} plot")
        # Colour limit is 1/7 of the range across all models' maps; error maps are centred on 0
        halfinterval = (np.max(fields) - np.min(fields)) / 7
        vmin = -halfinterval if centred else 0
        fig, axes = plt.subplots(1, n_models, figsize=(10 * n_models, 10), squeeze=False)
        for ax, field, label in zip(axes[0], fields, labels_list):
            im = ax.imshow(field.squeeze(), cmap=cmap, vmin=vmin, vmax=halfinterval)
            cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
            fig.colorbar(im, cax=cax)
            ax.axis("off")
            ax.set_title(label)
        fig.suptitle(f"{title_word} in Models, average over {n_times} times, all hours, months {MONTHS[0]}-{MONTHS[1]} 2021/22/23", fontsize=18)
        fig.tight_layout()
        fig.savefig(f"{PLOT_DIR}/spatial_{file_tag}_plot_{month_tag}.png", dpi=300)
        plt.close(fig)

    print("ALL PLOTS DONE")

In [None]:
# Make plots for each hour
# Same evaluation as the all-hours cell, but with one dataset per (month range, single hour).
# Figures are saved to PLOT_DIR and closed immediately (this cell produces ~480 of them).

MODEL_DIR = "/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models"
PLOT_DIR = "/scratch/RTMA/alex.schein/hrrr_CNN_testing/Plots"
device = "cuda:3"  # kept module-level: get_model_output may read `device` as a global

# Sorted so models appear in the same order in every plot (os.listdir order is arbitrary).
# Only models trained for NUM_EPOCHS epochs are compared.
model_names = sorted(name for name in os.listdir(MODEL_DIR) if f"NE{NUM_EPOCHS}" in name)

for MONTHS in [[1,12], [7,9], [10,12], [1,3], [4,6]]:
    # Local loop variable so the global HOURS config is not overwritten
    for hours in [[i] for i in range(24)]:
        hour_tag = f"{str(hours[0]).zfill(2)}z"
        errors_list = []          # all pixel errors (pred - truth), flattened, one array per model
        rmse_list = []            # RMSE over the grid for each time, one array per model
        spatial_errors_list = []  # time-mean error map, one per model
        spatial_rmse_list = []    # per-pixel RMSE over time, one per model
        labels_list = []
        n_times = 0

        for pred_model_name in model_names:
            tHRRR, tURMA, tDIFF, sigYr, sigHr = get_model_params(pred_model_name)
            label = f"H{tHRRR}/U{tURMA}/D{tDIFF}/Yr{int(bool(sigYr))}/Hr{int(bool(sigHr))}"

            train_ds = HRRR_URMA_Dataset_AllTimes_AnyDates_AnyTerrains(
                is_train=True,
                months=MONTHS,
                hours=hours,
                # NOTE(review): terrain set is chosen from tHRRR alone; tURMA/tDIFF are not consulted
                with_terrains=["hrrr", "urma", "diff"] if tHRRR else ["diff"],
                with_yearly_time_sig=sigYr,
                with_hourly_time_sig=sigHr,
            )

            pred_model = SR_UNet_simple(n_channels_in=np.shape(train_ds[0][0])[0])
            pred_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, pred_model_name), weights_only=True))
            pred_model.to(device)
            pred_model.eval()  # inference mode, in case the network has dropout / batch-norm layers

            # Iterate by index (assumes the dataset defines __len__) so each sample is loaded
            # once, inside get_model_output, instead of also by an enumerate() over the dataset.
            n_times = len(train_ds)
            per_time_errors = []
            for idx in range(n_times):
                _, y_unnormed, pred_unnormed, _ = get_model_output(train_ds, model=pred_model, idx=idx, is_unnormed=True)
                per_time_errors.append(pred_unnormed - y_unnormed)
            time_errors = np.stack(per_time_errors)          # (n_times, *error_shape)
            flat_errors = time_errors.reshape(n_times, -1)   # (n_times, n_pixels)

            errors_list.append(flat_errors.ravel())
            # RMSE = sqrt(mean(err**2)); sqrt(mean(err)**2) would only be |mean error|
            rmse_list.append(np.sqrt(np.mean(flat_errors**2, axis=1)))
            spatial_errors_list.append(time_errors.mean(axis=0))
            spatial_rmse_list.append(np.sqrt(np.mean(time_errors**2, axis=0)))
            labels_list.append(label)
            print(f"{label} DONE")
        print("ALL COMBOS DONE")

        if not labels_list:
            print(f"No models matched in {MODEL_DIR}; skipping plots for months {MONTHS[0]}-{MONTHS[1]}, {hour_tag}")
            continue

        print("Starting label creation")
        labels_with_quantiles = []
        RMSE_labels_with_quantiles = []
        for label, errs, rmses in zip(labels_list, errors_list, rmse_list):
            for out, values in ((labels_with_quantiles, errs), (RMSE_labels_with_quantiles, rmses)):
                q1, q2, q3 = np.quantile(values, [0.25, 0.50, 0.75])  # one pass for all three quantiles
                out.append(label + f" \n 25%={q1:.4f} \n 50% = {q2:.4f} \n 75% = {q3:.4f} \n Q3-Q1={(q3-q1):.4f}")

        file_suffix = f"months{MONTHS[0]}-{MONTHS[1]}_{hour_tag}"

        # Whisker plots: boxes span Q1-Q3, whiskers the 5th-95th percentiles, outliers hidden
        for kind, data, tick_labels, file_tag in (
            ("Error", errors_list, labels_with_quantiles, "error"),
            ("RMSE", rmse_list, RMSE_labels_with_quantiles, "rmse"),
        ):
            print(f"Making {kind} whisker plot")
            fig, ax = plt.subplots(1, 1, figsize=(10, 5))
            ax.boxplot(data, tick_labels=tick_labels, sym="", whis=(5, 95))
            ax.set_ylabel(f"Model {kind} (°C) (num. idxs = {n_times})")
            fig.suptitle(f"{kind} for Various Models, {hour_tag}, months {MONTHS[0]}-{MONTHS[1]} 2021/22/23", fontsize=14)
            fig.tight_layout()
            fig.savefig(f"{PLOT_DIR}/whiskerplot_{file_tag}_{file_suffix}.png", dpi=300)
            plt.close(fig)  # saved to disk; closing prevents hundreds of open figures

        # Spatial maps, one panel per model
        n_models = len(labels_list)
        for title_word, fields, cmap, file_tag, centred in (
            ("Errors", spatial_errors_list, "coolwarm", "error", True),
            ("RMSE", spatial_rmse_list, "jet", "rmse", False),
        ):
            print(f"Making spatial {title_word} plot")
            # Colour limit is 1/7 of the range across all models' maps; error maps are centred on 0
            halfinterval = (np.max(fields) - np.min(fields)) / 7
            vmin = -halfinterval if centred else 0
            fig, axes = plt.subplots(1, n_models, figsize=(10 * n_models, 10), squeeze=False)
            for ax, field, label in zip(axes[0], fields, labels_list):
                im = ax.imshow(field.squeeze(), cmap=cmap, vmin=vmin, vmax=halfinterval)
                cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
                fig.colorbar(im, cax=cax)
                ax.axis("off")
                ax.set_title(label)
            fig.suptitle(f"{title_word} in Models, average over {n_times} times, {hour_tag}, months {MONTHS[0]}-{MONTHS[1]} 2021/22/23", fontsize=18)
            fig.tight_layout()
            fig.savefig(f"{PLOT_DIR}/spatial_{file_tag}_plot_{file_suffix}.png", dpi=300)
            plt.close(fig)

        print("ALL PLOTS DONE")