In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchinfo import summary

import os
import glob
import time
import datetime as dt
from netCDF4 import Dataset as nc_Dataset
from netCDF4 import date2num, num2date
import pandas as pd
import numpy as np
import math
import xarray as xr
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from FunctionsAndClasses.HRRR_URMA_Datasets_AllVars import *
from FunctionsAndClasses.DefineModelAttributes import *
from FunctionsAndClasses.SR_UNet_simple import *
from FunctionsAndClasses.utils import *

In [None]:
# Crop window applied to the full-grid Smartinit field later in the notebook.
# These are integer grid positions (used with `.isel`), not lat/lon values.
IMG_SIZE_LAT = 180  # number of grid points along y in the crop
IMG_SIZE_LON = 180  # number of grid points along x in the crop
IDX_MIN_LAT = 645   # first y index of the crop
IDX_MIN_LON = 596   # first x index of the crop

In [None]:
# Notebook variable name -> short variable name inside the GRIB2 files.
varname_translation_dict = dict(
    t2m="t2m",
    d2m="d2m",
    pressurf="sp",
    u10m="u10",
    v10m="v10",
)


def _grib_filter(type_of_level, level=None):
    """Return a fresh cfgrib ``backend_kwargs`` dict selecting one GRIB level type.

    A new dict is built on every call so no two variables share the same object.
    """
    keys = {'typeOfLevel': type_of_level}
    if level is not None:
        keys['level'] = level
    return {'filter_by_keys': keys}


# GRIB short name -> cfgrib backend_kwargs used when opening that variable.
urma_var_select_dict = {
    "t2m": _grib_filter('heightAboveGround', 2),
    "d2m": _grib_filter('heightAboveGround', 2),
    "sp": _grib_filter('surface'),
    "u10": _grib_filter('heightAboveGround', 10),
    "v10": _grib_filter('heightAboveGround', 10),
}

In [None]:
# Change as needed: variable short names for the model's predictors and targets.
TARGET_VARS = ["t2m"]     # target variable(s)
PREDICTOR_VARS = ["t2m"]  # predictor variable(s)

In [None]:
# Inference-time (is_train=False) attributes for the MAE-loss model.
model_MAE_attrs = DefineModelAttributes(is_train=False)

# Optional overrides (uncomment to use). They sit before the save name and
# dataset are built, presumably so those steps pick them up.
# model_MAE_attrs.predictor_vars = PREDICTOR_VARS
# model_MAE_attrs.target_vars = TARGET_VARS

# Build the save name (used as the checkpoint file stem), then the dataset.
model_MAE_attrs.create_save_name()
model_MAE_attrs.create_dataset()

In [None]:
# Inference-time attributes for the RMSE-loss model. NUM_EPOCHS=1500 presumably
# feeds into the generated save name and so selects the checkpoint -- confirm in
# DefineModelAttributes.
model_RMSE_attrs = DefineModelAttributes(is_train=False, NUM_EPOCHS=1500)

### Change as needed (uncomment to override; must target model_RMSE_attrs)
# model_RMSE_attrs.predictor_vars = PREDICTOR_VARS
# model_RMSE_attrs.target_vars = TARGET_VARS

model_RMSE_attrs.create_save_name()
# RMSE-loss checkpoints are stored under the generated name with a "RMSELoss_" prefix.
model_RMSE_attrs.savename = f"RMSELoss_{model_RMSE_attrs.savename}"
model_RMSE_attrs.create_dataset()

In [None]:
# Build the MAE-loss UNet and restore its trained weights for inference.
# Falls back to CPU when no GPU is visible (identical behaviour on GPU nodes).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_MAE = SR_UNet_simple(n_channels_in=model_MAE_attrs.num_channels_in,
                           n_channels_out=model_MAE_attrs.num_channels_out)
model_MAE.to(device)
# map_location remaps tensors saved on another device onto `device`.
# load_state_dict is strict by default and raises on missing/unexpected keys,
# so reaching the next line means every parameter was restored.
model_MAE.load_state_dict(torch.load(
    f"/scratch/RTMA/alex.schein/CNN_Main/Trained_models/MAE_Loss/{model_MAE_attrs.savename}.pt",
    map_location=device,
    weights_only=True,
))
# Inference only: switch off training-mode behaviour (Dropout, BatchNorm batch
# statistics) in case SR_UNet_simple uses any; harmless if it does not.
model_MAE.eval();

In [None]:
# Build the RMSE-loss UNet and restore its trained weights for inference.
# Falls back to CPU when no GPU is visible (identical behaviour on GPU nodes).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_RMSE = SR_UNet_simple(n_channels_in=model_RMSE_attrs.num_channels_in,
                            n_channels_out=model_RMSE_attrs.num_channels_out)
model_RMSE.to(device)
# map_location remaps tensors saved on another device onto `device`.
# load_state_dict is strict by default and raises on missing/unexpected keys.
model_RMSE.load_state_dict(torch.load(
    f"/scratch/RTMA/alex.schein/CNN_Main/Trained_models/{model_RMSE_attrs.savename}.pt",
    map_location=device,
    weights_only=True,
))
# Inference only: switch off training-mode behaviour (Dropout, BatchNorm batch
# statistics) in case SR_UNet_simple uses any; harmless if it does not.
model_RMSE.eval();

In [None]:
# Target variable for the comparison; it also selects which test-set URMA file is opened.
TARG_VAR = "t2m"
test_urma_path = f"/data1/projects/RTMA/alex.schein/URMA_train_test/test_urma_alltimes_{TARG_VAR}.nc"
test_urma = xr.open_dataarray(test_urma_path, decode_timedelta=True)

In [None]:
# Test-set sample index to visualise.
IDX = 6009

# Predictor, truth, and both UNet outputs for the same sample.
# NOTE(review): pred/targ/dt_current come from the MAE model's dataset and are
# reused for every panel; the RMSE call's own pred/targ/time are discarded.
pred, targ, model_output_MAE, dt_current = get_model_output_at_idx(
    model_attrs=model_MAE_attrs, model=model_MAE, idx=IDX, pred_var=TARG_VAR, targ_var=TARG_VAR)
_, _, model_output_RMSE, _ = get_model_output_at_idx(
    model_attrs=model_RMSE_attrs, model=model_RMSE, idx=IDX, pred_var=TARG_VAR, targ_var=TARG_VAR)

# Smartinit file names carry the INITIALIZATION time, NOT the valid time: the
# f01 file valid at dt_current was initialized one hour earlier.
smartinit_dt = dt_current - dt.timedelta(hours=1)
smartinit_path = (
    "/data1/projects/RTMA/alex.schein/HRRR_Smartinit_Data/"
    f"hrrr_smartinit_{smartinit_dt.strftime('%Y%m%d')}_t{smartinit_dt.strftime('%H')}z_f01.grib2"
)
grib_varname = varname_translation_dict[TARG_VAR]  # short name inside the GRIB2 file
smartinit_xr = xr.open_dataset(smartinit_path,
                               engine="cfgrib",
                               backend_kwargs=urma_var_select_dict[grib_varname],
                               decode_timedelta=True)
smartinit_xr_var = smartinit_xr[grib_varname]
# Crop to the fixed grid window; presumably the same domain as the model images -- TODO confirm.
smartinit_xr_var_sp_rest = smartinit_xr_var.isel(y=slice(IDX_MIN_LAT, IDX_MIN_LAT + IMG_SIZE_LAT),
                                                x=slice(IDX_MIN_LON, IDX_MIN_LON + IMG_SIZE_LON))

plot_predictor_output_truth_error(pred, model_output_MAE, targ, date_str=dt_current,
                                  title=f"UNet MAE, predictors={model_MAE_attrs.predictor_vars} --> {TARG_VAR}")
# The title reports the RMSE model's own predictor list (previously the MAE model's).
plot_predictor_output_truth_error(pred, model_output_RMSE, targ, date_str=dt_current,
                                  title=f"UNet RMSE, predictors={model_RMSE_attrs.predictor_vars} --> {TARG_VAR}")
plot_predictor_output_truth_error(pred, smartinit_xr_var_sp_rest.data, targ, date_str=dt_current,
                                  title=f"Smartinit, {TARG_VAR}")