In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchinfo import summary

import os
import time
import datetime as dt
from netCDF4 import Dataset as nc_Dataset
import pandas as pd
import numpy as np
import xarray as xr
from tqdm import tqdm
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from HRRR_URMA_Datasets import *
from SR_UNet_simple import SR_UNet_simple
from utils import *

# Seed PyTorch's global RNG. This runs once per kernel session, so everything that draws from
# the global generator afterwards (weight initialisation, DataLoader shuffling) depends on how
# much randomness earlier cells/iterations have already consumed.
torch.manual_seed(42)

In [None]:
# Mini-batch size handed to the training DataLoader (alternative value: 256).
BATCH_SIZE = 64
# Number of passes over the training set for every model configuration; also part of the save name.
NUM_EPOCHS = 2_000

In [None]:
# First and last month of the period to train on; both appear in the save name as months{first}-{last}.
MONTHS = [1, 3]
# Day bounds passed to the dataset: make sure the last entry matches the number of days in MONTHS[1].
DAYS = [1, 31]
# Which cycle to use: 0 --> use only 00z data; 12 --> only 12z data.
HOUR = 12

In [None]:
# Train one SR_UNet_simple per combination of the three terrain-input flags (2**3 = 8 runs).
# For every run: keep the weights of the lowest-loss epoch on disk, plot a sample prediction a
# few times during training, then plot the loss curve and the per-sample mean error.
# Configurations whose final checkpoint already exists are skipped, so the cell can be re-run
# after an interruption.

model_dir = "/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models"
# Fall back to CPU so the cell still runs (slowly) on a machine without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def _predict_single(net, sample, device):
    """Run `net` on one un-batched sample and return the prediction as a NumPy array.

    `sample` is expected to be a NumPy array (it is passed through torch.from_numpy), as
    returned by indexing the training dataset. A leading batch axis is added before the
    forward pass.

    The network is switched to eval mode for the forward pass and restored to its previous
    mode afterwards: torch.no_grad() only disables gradient tracking, it does not stop
    layers such as BatchNorm/Dropout from behaving (and updating buffers) as in training.
    """
    was_training = net.training
    net.eval()
    try:
        batch = torch.from_numpy(sample[np.newaxis, :]).to(device)
        with torch.no_grad():
            return net(batch).cpu().numpy()
    finally:
        net.train(was_training)


for W_HRRR_TERR in [True, False]:
    for W_URMA_TERR in [True, False]:
        for W_DIFF_TERR in [True, False]:
            savename = (f"UNSim_BS{BATCH_SIZE}_NE{NUM_EPOCHS}_{str(HOUR).zfill(2)}z_months{MONTHS[0]}-{MONTHS[1]}"
                        f"_tHRRR{int(W_HRRR_TERR)}_tURMA{int(W_URMA_TERR)}_tDIFF{int(W_DIFF_TERR)}")
            if W_HRRR_TERR and W_URMA_TERR:
                # then using same mean and stddev to norm HRRR and URMA terrain - choosing to use
                # URMA mean/stddev at the moment (see dataloader)
                savename += "_OneTerrainNorm"

            final_path = f"{model_dir}/{savename}.pt"
            temp_path = f"{model_dir}/{savename}_TEMP.pt"

            # In case the loop needs to be restarted after model(s) were already fully trained.
            if os.path.exists(final_path):
                continue

            print(f"Training started for {savename}")

            # Re-seed per configuration (same value as the import cell) so a run's initial
            # weights and shuffling order do not depend on how many configurations were trained
            # before it in this session -- otherwise a restarted sweep is not reproducible.
            torch.manual_seed(42)

            train_ds = HRRR_URMA_Dataset_Anytime_Anydate_Anyterrain(is_train=True,
                                                                    with_hrrr_terrain=W_HRRR_TERR,
                                                                    with_urma_terrain=W_URMA_TERR,
                                                                    with_terrain_difference=W_DIFF_TERR,
                                                                    months=MONTHS,
                                                                    days=DAYS,
                                                                    hour=HOUR)
            train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

            # One base input channel plus one per enabled terrain flag.
            n_ch_in = 1 + int(W_HRRR_TERR) + int(W_URMA_TERR) + int(W_DIFF_TERR)
            model = SR_UNet_simple(n_channels_in=n_ch_in)
            model.to(device)

            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.5, 0.999))
            loss_function = torch.nn.L1Loss()

            epoch_losses = []
            log_print_interval = 50
            # Integer interval: a float NUM_EPOCHS/4 never divides `epoch` evenly unless
            # NUM_EPOCHS is a multiple of 4. max(1, ...) keeps the modulo valid for tiny runs.
            log_epoch_interval = max(1, NUM_EPOCHS // 4)
            # Start at +inf so the first epoch always produces a checkpoint; a finite sentinel
            # could stay below every epoch's loss, leaving nothing to rename after training.
            lowest_loss = float("inf")

            model.train()

            for epoch in range(1, NUM_EPOCHS + 1):
                epoch_loss = 0.0
                start = time.time()
                for inputs, labels in train_dataloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    optimizer.zero_grad()

                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    epoch_loss += loss.item()

                elapsed = time.time() - start
                avg_loss = epoch_loss / len(train_dataloader)
                epoch_losses.append(avg_loss)

                if epoch % log_print_interval == 0:
                    print(f"End of epoch {epoch} | Average loss for epoch = {avg_loss:.3f} | Time for epoch = {elapsed:.1f} sec")

                # Checkpoint to a temporary name whenever the epoch loss is the best so far.
                if avg_loss <= lowest_loss:
                    lowest_loss = avg_loss
                    torch.save(model.state_dict(), temp_path)

                if epoch % log_epoch_interval == 0:
                    idx = 50  # some random index
                    X, y = train_ds[idx]
                    pred = _predict_single(model, X, device)
                    # X[0, :] is the first input channel of the sample.
                    plot_prediction(X[0, :], y, pred)

            print('Finished Training')
            # Rename only now, so if training is interrupted, a previously saved completed model
            # under the same name isn't wiped out.
            os.rename(temp_path, final_path)

            plot_epoch_losses(epoch_losses, "MAE")

            # Evaluate the weights that were actually saved (lowest-loss epoch), not whatever
            # the last epoch left in memory, so the diagnostics describe the file on disk.
            model.load_state_dict(torch.load(final_path, map_location=device, weights_only=True))

            # Assess seasonality distribution of the errors
            mean_errors = [np.mean(_predict_single(model, X, device) - y) for X, y in train_ds]

            ## Plot made for trained model over 2021-2023 data
            # Need to change based on number of months, days in month range, but this is ok approximately
            numdays_per_year = int(MONTHS[1]-MONTHS[0] +1)*30

            fig, ax = plt.subplots()
            ax.plot(mean_errors)
            ax.set_title("Errors (spatial mean of pred. - true)")
            ax.set_xlabel("Index (one per day)")
            ax.set_ylabel("Spatial-avg error")
            # Red lines mark the approximate boundaries between years; dashed black line is zero error.
            ax.vlines(x=numdays_per_year, ymin = 0, ymax = 1, color = 'red', linestyle = '-', transform=ax.get_xaxis_transform())
            ax.vlines(x=2*numdays_per_year, ymin = 0, ymax = 1, color = 'red', linestyle = '-', transform=ax.get_xaxis_transform())
            ax.hlines(y=0, xmin = 0, xmax = 1, color = 'black', linestyle = '--', transform=ax.get_yaxis_transform())
            # Flush the figures now so they appear next to this configuration's log output
            # instead of all piling up at the end of the cell.
            plt.show()

In [None]:
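#Load model for testing, if needed.
#NOTE: n_ch_in and savename are leftovers from the training cell: savename holds the name built in the last loop iteration, and n_ch_in is only assigned when a model was actually trained in this session. Set both explicitly to match the checkpoint you want to load.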
# model = SR_UNet_simple(n_channels_in=n_ch_in)
# model.load_state_dict(torch.load(f"/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models/{savename}.pt", weights_only=True))
# device = "cuda:0"

# model.to(device)

In [None]:
#Assess seasonality distribution of the errors
# mean_errors = []
# for idx, (X,y) in enumerate(train_ds):
#     X = X[np.newaxis,:] 
#     X_gpu = torch.from_numpy(X).cuda(device)
#     with torch.no_grad():
#         pred = model(X_gpu)
#         pred = pred.cpu().numpy()
#     mean_errors.append(np.mean(pred-y))
    
# ## Plot made for trained model over 2021-2023 data
# numdays_per_year = int(MONTHS[1]-MONTHS[0] +1)*30 #Need to change based on number of months, days in month range, but this is ok approximately

# fig, ax = plt.subplots()
# ax.plot(mean_errors)
# plt.title("Errors (spatial mean of pred. - true)")
# plt.xlabel("Index (one per day)")
# plt.ylabel("Spatial-avg error")
# plt.vlines(x=numdays_per_year, ymin = 0, ymax = 1, color = 'red', linestyle = '-', transform=ax.get_xaxis_transform())
# plt.vlines(x=2*numdays_per_year, ymin = 0, ymax = 1, color = 'red', linestyle = '-', transform=ax.get_xaxis_transform())
# plt.hlines(y=0, xmin = 0, xmax = 1, color = 'black', linestyle = '--', transform=ax.get_yaxis_transform())

In [None]:
## play around with some examples

# idx = 77 #some random index
# X,y = train_ds[idx] 
# X = X[np.newaxis,:] 
# X_gpu = torch.from_numpy(X).cuda(device)

# with torch.no_grad():
#     pred = model(X_gpu)
#     pred = pred.cpu().numpy()

# X_unnormed = train_ds.hrrr_std*X[0,0,:] + train_ds.hrrr_mean
# y_unnormed = train_ds.urma_std*y + train_ds.urma_mean
# pred_unnormed = train_ds.urma_std*pred + train_ds.urma_mean

# plot_prediction(X[0,0,:],y,pred)
# plot_prediction(X_unnormed, y_unnormed, pred_unnormed)
# plot_prediction(X_unnormed, y_unnormed, X_unnormed) #to see original's error