In [1]:
# WORKING OFF OF
# https://www.learnpytorch.io/04_pytorch_custom_datasets/
# and Marshall's code

In [14]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import os
from netCDF4 import Dataset as nc_Dataset
import pandas as pd
import numpy as np
import xarray as xr

In [5]:
PATH_MAIN_URMA = "/scratch/RTMA/alex.schein/URMA_train_test"
PATH_TRAIN_URMA = "/scratch/RTMA/alex.schein/URMA_train_test/train"
PATH_TEST_URMA = "/scratch/RTMA/alex.schein/URMA_train_test/test"

PATH_MAIN_HRRR = "/scratch/RTMA/alex.schein/Regridded_HRRR_train_test"
PATH_TRAIN_HRRR = "/scratch/RTMA/alex.schein/Regridded_HRRR_train_test/train_spatiallyrestricted"
PATH_TEST_HRRR = "/scratch/RTMA/alex.schein/Regridded_HRRR_train_test/test_spatiallyrestricted"

In [6]:
#Make class to load regridded HRRR and URMA 
class HRRR_URMA_Dataset(Dataset):
    
    def __init__(self, is_train=False):
        # is_train --> load either training or testing datasets
        
        # Establish paths
        # UNLIKE Marshall's code, the training and predictor indices align exactly and are contained in the sample_idx coordinate, so no need for separate path
        # Also, now we have separate .nc files for training and testing, so need multiple paths there
        path_root = os.path.dirname(os.getcwd())
        if is_train:
            data_save_path_pred = os.path.join(path_root,"Regridded_HRRR_train_test", "train_hrrr.nc")
            data_save_path_targ = os.path.join(path_root,"URMA_train_test", "train_urma.nc")
        else:
            data_save_path_pred = os.path.join(path_root,"Regridded_HRRR_train_test", "test_hrrr.nc")
            data_save_path_targ = os.path.join(path_root,"URMA_train_test", "test_urma.nc")

        # open netCDF4 datasets
        self.nc_dataset_pred = nc_Dataset(data_save_path_pred)
        self.nc_dataset_targ = nc_Dataset(data_save_path_targ)

        # open xarray datasets
        # necessary because coordinate indices store the mappings, and it's much easier to read those with xarray than netcdf4
        self.xr_dataset_pred = xr.open_dataset(data_save_path_pred)
        self.xr_dataset_targ = xr.open_dataset(data_save_path_targ)

        # get mapping indices from xarray coords
        if is_train:
            ds_pred = xr.open_dataset(os.path.join(path_root,"Regridded_HRRR_train_test", "train_hrrr.nc"))
            ds_targ = xr.open_dataset(os.path.join(path_root,"URMA_train_test", "train_urma.nc"))
        else:
            ds_pred = xr.open_dataset(os.path.join(path_root,"Regridded_HRRR_train_test", "test_hrrr.nc"))
            ds_targ = xr.open_dataset(os.path.join(path_root,"URMA_train_test", "test_urma.nc"))

        self.predictor_indices = ds_pred.sample_idx.data
        self.target_indices = ds_targ.sample_idx.data
        assert len(self.predictor_indices) == len(self.target_indices), "Predictor indices array should be of the same length as the target indices array"

    def __len__(self):
        return len(self.predictor_indices)
        
    def __getitem__(self, idx):

        #get sample index for predictor and target
        p_idx = self.predictor_indices[idx]
        t_idx = self.target_indices[idx]

        # extract 2m temp image 
        # awkward: forgot to rename HRRR temp from "t" to "t2m" as in URMA...
        predictor = self.nc_dataset_pred["t"][p_idx,:,:].data[np.newaxis,:,:]
        target = self.nc_dataset_targ["t2m"][t_idx,:,:].data[np.newaxis,:,:]

        return (predictor), (target) #copying Marshall's syntax, see if it works

In [8]:
############ TESTING ############

In [9]:
train_ds = HRRR_URMA_Dataset(is_train=True)

  self.xr_dataset_pred = xr.open_dataset(data_save_path_pred)
  self.xr_dataset_targ = xr.open_dataset(data_save_path_targ)
  ds_pred = xr.open_dataset(os.path.join(path_root,"Regridded_HRRR_train_test", "train_hrrr.nc"))
  ds_targ = xr.open_dataset(os.path.join(path_root,"URMA_train_test", "train_urma.nc"))


In [15]:
train_dataloader = DataLoader(train_ds, batch_size = 64, shuffle = True)

In [18]:
train_ds