In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchinfo import summary

import os
import time
import datetime as dt
from netCDF4 import Dataset as nc_Dataset
import pandas as pd
import numpy as np
import xarray as xr
from tqdm import tqdm
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from HRRR_URMA_Datasets import *
from SR_UNet_simple import SR_UNet_simple
from utils import *

# Seed PyTorch's global random number generator with a fixed value so that
# torch-driven randomness is repeatable from run to run (presumably weight
# initialisation and the shuffled DataLoader order — TODO confirm nothing else
# in the project seeds its own generator).
torch.manual_seed(42)

In [None]:
# Training hyperparameters.
# BATCH_SIZE is the DataLoader batch size (one optimizer step per batch);
# the trailing number is an alternative value left by the author.
BATCH_SIZE = 64 #256
# Number of full passes over the training DataLoader.
NUM_EPOCHS = 2000
# Loss selector; recognised values: "MAE", "MSE", "CUSTOM".
# The value is also embedded in the checkpoint file name.
LOSS_FCN = "MAE" 

In [None]:
# Date window handed to the dataset constructor.
# NOTE(review): presumably MONTHS is [first_month, last_month] and DAYS is
# [first_day_of_first_month, last_day_of_last_month] — confirm against the
# dataset class.
MONTHS = [1,1]
DAYS = [1,31] # make sure the last entry matches the number of days in MONTHS[1]
HOUR = 0  # 0 --> use only 00z data; 12 --> only 12z data

In [None]:
# Build the training dataset and the shuffled DataLoader that feeds the training loop.
# All three terrain options are switched on; the date window comes from the
# MONTHS / DAYS / HOUR configuration.
dataset_options = {
    "is_train": True,
    "with_hrrr_terrain": True,
    "with_urma_terrain": True,
    "with_terrain_difference": True,
    "months": MONTHS,
    "days": DAYS,
    "hour": HOUR,
}
train_ds = HRRR_URMA_Dataset_Anytime_Anydate_Anyterrain(**dataset_options)

train_dataloader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Sanity check: display the shape of the input array of the first training sample.
first_sample = train_ds[0]
np.shape(first_sample[0])

In [None]:
# Device the model (and every batch) is moved to.
device = "cuda:0"

# Input channel count: one base channel plus one for every terrain option
# enabled on the dataset (each flag counts as 0 or 1).
terrain_flags = (
    train_ds.with_hrrr_terrain,
    train_ds.with_urma_terrain,
    train_ds.with_terrain_difference,
)
n_ch_in = 1 + sum(terrain_flags)

model = SR_UNet_simple(n_channels_in=n_ch_in)

### Uncomment if need to load a particular model
# savename = f"UNetSimple_batchsize{BATCH_SIZE}_numepochs2000_{str(HOUR).zfill(2)}z_months{MONTHS[0]}-{MONTHS[1]}_{LOSS_FCN}Loss"
# model.load_state_dict(torch.load(f"/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models/{savename}.pt", weights_only=True))

model.to(device)

In [None]:
class CustomLoss(nn.Module):
    """Quartic loss: the mean of the element-wise error raised to the 4th power."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        """Return mean((inputs - targets)^4) over all elements as a scalar tensor."""
        residual = inputs - targets
        return residual.pow(4).mean()

In [None]:
# Optimizer: AdamW with a lowered first-moment decay (beta1 = 0.5).
# Earlier alternative left by the author: torch.optim.Adam(model.parameters(), lr = 1e-3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=[0.5,0.999])

# Map each supported LOSS_FCN value to the class that implements it.
loss_classes = {
    "MAE": torch.nn.L1Loss,
    "MSE": torch.nn.MSELoss,
    "CUSTOM": CustomLoss,
}

# Unknown selector: report it and fall back to MAE (L1).
if LOSS_FCN not in loss_classes:
    print("ERROR: 'LOSS_FCN' needs to be 'MAE', 'MSE', or 'CUSTOM'. Defaulting to MAE")

loss_function = loss_classes.get(LOSS_FCN, torch.nn.L1Loss)()

In [None]:
# Train the model for NUM_EPOCHS epochs. After every epoch the weights are
# written to a *_TEMP.pt checkpoint if the epoch's summed batch loss is the
# lowest seen so far; once training finishes that checkpoint is promoted to its
# final name.
epoch_losses = []            # mean per-batch loss of every completed epoch
log_epoch_interval = 20      # plot a sample prediction every this many epochs
lowest_loss = float("inf")   # any finite first-epoch loss counts as an improvement
checkpoint_saved = False     # becomes True once this run has written a checkpoint

# The checkpoint name depends only on the configuration, so build it once up
# front; this also guarantees `savename` exists for the later cells.
model_dir = "/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models"
savename = f"UNetSimple_batchsize{BATCH_SIZE}_numepochs{NUM_EPOCHS}_{str(HOUR).zfill(2)}z_months{MONTHS[0]}-{MONTHS[1]}_{LOSS_FCN}Loss"
temp_path = f"{model_dir}/{savename}_TEMP.pt"
final_path = f"{model_dir}/{savename}.pt"

num_batches = len(train_dataloader)

model.train()

for epoch in range(1, NUM_EPOCHS+1):

    epoch_loss = 0.0  # sum of the batch losses of this epoch

    start = time.time()
    for inputs, labels in train_dataloader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    end2 = time.time()
    mean_epoch_loss = epoch_loss/num_batches
    epoch_losses.append(mean_epoch_loss)
    print(f"End of epoch {epoch} | Average loss for epoch = {mean_epoch_loss:.3f} | Time for epoch = {end2-start:.1f} sec")

    # Keep the weights of the best epoch so far. The batch count is the same
    # every epoch, so comparing summed losses is equivalent to comparing means.
    if epoch_loss <= lowest_loss:
        lowest_loss = epoch_loss
        torch.save(model.state_dict(), temp_path)
        checkpoint_saved = True

    if epoch % log_epoch_interval == 0:

        # Arbitrary fixed sample used as a visual progress check, clamped so a
        # small dataset cannot trigger an IndexError in the middle of training.
        idx = min(50, len(train_ds)-1)
        X, y = train_ds[idx]
        X = X[np.newaxis,:]  # add a leading batch axis
        X_gpu = torch.from_numpy(X).to(device)

        # Switch to inference behaviour for the preview so that layers which act
        # differently while training (e.g. BatchNorm or Dropout, if the model has
        # any) neither distort the plot nor have their running statistics updated
        # by this single sample; then resume training mode.
        model.eval()
        with torch.no_grad():
            pred = model(X_gpu).cpu().numpy()
        model.train()

        plot_prediction(X[0,0,:], y, pred)

print('Finished Training')
# The checkpoint is only promoted to its final name here, so if training is
# interrupted a previously saved model under the same name isn't wiped out.
if checkpoint_saved:
    os.replace(temp_path, final_path)
else:
    # Happens when no epoch ran or every epoch loss was NaN; promoting a TEMP
    # file left over from an earlier run would silently mislabel old weights.
    print("WARNING: no checkpoint was written during this run; nothing to rename")

In [None]:
# Hand the per-epoch mean losses collected by the training loop, together with
# the loss selector string, to the project plotting helper (brought into scope
# by one of the star-imports at the top of the notebook).
plot_epoch_losses(epoch_losses, LOSS_FCN)

In [None]:
# Load model for testing, if needed.
# Rebuilds the network with the same input-channel count and restores the
# weights saved under `savename` (set by the training cell, or by hand when
# loading a particular model).
device = "cuda:0"

model = SR_UNet_simple(n_channels_in=n_ch_in)
model.load_state_dict(torch.load(f"/scratch/RTMA/alex.schein/hrrr_CNN_testing/Trained models/{savename}.pt", weights_only=True))
model.to(device)

# A freshly constructed module starts in training mode; switch to inference
# behaviour so the evaluation cells below are not affected by layers that act
# differently while training (e.g. BatchNorm or Dropout, if the model has any).
# eval() returns the module, so the cell still displays the model summary.
model.eval()

In [None]:
#Assess seasonality distribution of the errors

# Number of samples contributed by one year, derived from the configured date
# window instead of a fixed 30 days per month. A non-leap reference year is used.
# Computed before the (slow) inference loop so a bad MONTHS/DAYS combination
# fails immediately.
# NOTE(review): assumes one sample per day from MONTHS[0]/DAYS[0] through
# MONTHS[1]/DAYS[1] in every year — confirm against the dataset class.
numdays_per_year = (dt.date(2021, MONTHS[1], DAYS[1]) - dt.date(2021, MONTHS[0], DAYS[0])).days + 1

mean_errors = []    # spatial mean of (prediction - truth), one entry per sample
log_interval = 50   # print progress every this many samples

# Make sure inference is not run in training mode (matters if the model contains
# layers such as BatchNorm or Dropout).
model.eval()

num_samples = len(train_ds)
for idx in range(num_samples):
    X, y = train_ds[idx]
    X = X[np.newaxis,:]  # add a leading batch axis
    X_gpu = torch.from_numpy(X).cuda(device)

    with torch.no_grad():
        pred = model(X_gpu)
        pred = pred.cpu().numpy()
    mean_errors.append(np.mean(pred-y))
    if idx % log_interval == (log_interval-1):
        print(f"{idx+1}/{num_samples} done")

## Plot made for trained model over 2021-2023 data
fig, ax = plt.subplots()
ax.plot(mean_errors)
ax.set_title("Errors (spatial mean of pred. - true)")
ax.set_xlabel("Index (one per day)")
ax.set_ylabel("Spatial-avg error")
# Red vertical line at every boundary between consecutive years, for however
# many years the dataset actually contains.
if numdays_per_year > 0:
    for year_boundary in range(numdays_per_year, len(mean_errors), numdays_per_year):
        ax.axvline(x=year_boundary, color='red', linestyle='-')
# Dashed zero line to separate positive from negative bias.
ax.axhline(y=0, color='black', linestyle='--')
plt.show()

In [None]:
## play around with some examples, UN-NORMED

idx = 77 #some random index
X,y = train_ds[idx]
X = X[np.newaxis,:]  # add a leading batch axis
X_gpu = torch.from_numpy(X).cuda(device)

# Make sure inference is not run in training mode (matters if the model contains
# layers such as BatchNorm or Dropout).
model.eval()
with torch.no_grad():
    pred = model(X_gpu)
    pred = pred.cpu().numpy()

# Undo the normalisation: channel 0 of the input is rescaled with the dataset's
# HRRR statistics, the target and the prediction with its URMA statistics.
X_unnormed = train_ds.hrrr_std*X[0,0,:] + train_ds.hrrr_mean
y_unnormed = train_ds.urma_std*y + train_ds.urma_mean
pred_unnormed = train_ds.urma_std*pred + train_ds.urma_mean

plot_prediction(X[0,0,:],y,pred)                         # normalised values
plot_prediction(X_unnormed, y_unnormed, pred_unnormed)   # un-normalised values
plot_prediction(X_unnormed, y_unnormed, X_unnormed) #to see original's error (input passed in place of the prediction)