In [1]:
from tqdm import tqdm
import numpy as np
import xesmf as xe
import rasterio as rio
import os
import glob

In [2]:
# Raw PRISM rasters live in one sub-directory per entry under $PRISM_DIR/raws.
# Names are assumed to start with a date (the split cells match on a year prefix), so a
# lexicographic sort is chronological — TODO confirm for any new naming scheme.
root = os.path.join(os.environ["PRISM_DIR"], "raws")
# Keep directories only: a stray file (archive, .DS_Store, ...) would otherwise be treated
# as a day and crash the `glob(...)[0]` lookups below.
subdirs = sorted(
    name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
)

In [3]:
# Build regridder
# The first entry's raster is the reference grid. `dataset` is left open on purpose: the
# mask cell reads its pixel values next.
dataset = rio.open(sorted(glob.glob(os.path.join(root, subdirs[0], "*.bil")))[0])

# 1-D axes of the native grid, sampled at pixel CENTRES. `transform * (col, row)` maps to
# a pixel's upper-left corner, which shifts every input point half a pixel toward that
# corner; `dataset.xy(row, col, offset="center")` returns the centre instead.
lats = np.array(
    [dataset.xy(row, dataset.width // 2, offset="center")[1] for row in range(dataset.height)],
    dtype=float,
)
lons = np.array(
    [dataset.xy(dataset.height // 2, col, offset="center")[0] for col in range(dataset.width)],
    dtype=float,
) % 360  # longitudes in [0, 360)

target_res = 0.75  # degrees
# NOTE(review): 0.032 is treated as the native pixel size in degrees — confirm against
# `dataset.res`. The output axes below come from np.linspace(min, max, n), whose spacing is
# (max - min) / (n - 1), so the realised spacing only equals target_res if that native
# size is right.
scaling_factor = 0.032 / target_res
target_width = round(dataset.width * scaling_factor)
target_height = round(dataset.height * scaling_factor)

grid_in = {"lon": lons, "lat": lats}
grid_out = {
    "lon": np.linspace(lons.min(), lons.max(), target_width),
    "lat": np.linspace(lats.min(), lats.max(), target_height),
}
regridder = xe.Regridder(grid_in, grid_out, "bilinear")
regridder

xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_621x1405_26x60.nc 
Reuse pre-computed weights? False 
Input grid shape:           (621, 1405) 
Output grid shape:          (26, 60) 
Periodic in longitude?      False

In [4]:
# Get mask
# Valid-data mask from the reference raster: 1 where a value is present, 0 at the fill
# value. Use the nodata value the raster declares, falling back to the -9999 sentinel
# when none is declared.
nodata = dataset.nodata if dataset.nodata is not None else -9999
arr = dataset.read(1)
mask = (arr != nodata).astype(int)

In [5]:
# Define function to fix border
# Regrid one masked sample only to learn the output width, then build a NaN row that wide.
masked_arr = np.where(mask, arr, np.nan)
arr_out = regridder(masked_arr)
first_row = np.full(arr_out.shape[1], np.nan)

def fix(arr):
    """Return a copy of a regridded 2-D grid with its first row replaced by NaN.

    Blanks the first output row along the border (presumably because bilinear regridding
    leaves unreliable values on that edge — TODO confirm). The NaN row is sized from
    ``arr`` itself rather than from a module-level template, so the function works for
    any grid width and does not depend on state left by an earlier regrid.

    Parameters
    ----------
    arr : 2-D array of shape (rows, cols)

    Returns
    -------
    ndarray with the same shape as ``arr``; its dtype is float64 promoted with
    ``arr.dtype``. ``arr`` itself is not modified.
    """
    nan_row = np.full(arr.shape[1], np.nan)
    return np.vstack((nan_row, arr[1:]))

In [6]:
# Process PRISM data
# Regrid every sub-directory's raster onto the coarse grid, producing one 2-D array per
# entry of `subdirs`, in order.
all_prism_data = []
for sd in tqdm(subdirs):
    day_dir = os.path.join(root, sd)
    bil_paths = sorted(glob.glob(os.path.join(day_dir, "*.bil")))
    if not bil_paths:
        raise FileNotFoundError(f"no .bil raster found in {day_dir}")
    # The context manager closes each file promptly. Naming it `src` (not `dataset`) keeps
    # the reference raster from the regridder cell intact if earlier cells are re-run.
    with rio.open(bil_paths[0]) as src:
        day_arr = src.read(1)
        day_nodata = -9999 if src.nodata is None else src.nodata
    # Blank pixels outside the reference mask AND any pixel this file flags as nodata, so a
    # fill value can never leak into the bilinear interpolation.
    day_valid = (mask != 0) & (day_arr != day_nodata)
    day_regridded = regridder(np.where(day_valid, day_arr, np.nan))
    all_prism_data.append(fix(day_regridded))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14975/14975 [01:54<00:00, 130.93it/s]


In [7]:
# Turn the list of per-sub-directory grids into one array with time on axis 0.
all_prism_data = np.stack(all_prism_data, axis=0)
all_prism_data.shape

(14975, 26, 60)

In [8]:
# First index whose sub-directory name starts with "2016": the train/validation boundary.
# Binding via `next` always rebinds `train_end`, so a stale value from an earlier run is
# never silently reused, and a missing year fails here with a clear message.
train_end = next((i for i, sd in enumerate(subdirs) if sd.startswith("2016")), None)
if train_end is None:
    raise ValueError('no PRISM sub-directory name starts with "2016"')
print(train_end)

12783


In [9]:
# First index whose sub-directory name starts with "2017": the validation/test boundary.
# Binding via `next` always rebinds `val_end`, so a stale value from an earlier run is
# never silently reused, and a missing year fails here with a clear message.
val_end = next((i for i, sd in enumerate(subdirs) if sd.startswith("2017")), None)
if val_end is None:
    raise ValueError('no PRISM sub-directory name starts with "2017"')
print(val_end)

13149


In [10]:
# First index whose sub-directory name starts with "2019": the exclusive end of the test
# range (entries from 2019 on are not used by any split).
# Binding via `next` always rebinds `test_end`, so a stale value from an earlier run is
# never silently reused, and a missing year fails here with a clear message.
test_end = next((i for i, sd in enumerate(subdirs) if sd.startswith("2019")), None)
if test_end is None:
    raise ValueError('no PRISM sub-directory name starts with "2019"')
print(test_end)

13879


In [11]:
# Training split: every sub-directory before the first 2016 entry.
# Per-pixel statistics over the time axis; pixels that are NaN (masked, or in the blanked
# first row) come out as NaN.
train = all_prism_data[:train_end]
train_mean, train_std = train.mean(axis=0), train.std(axis=0)

In [12]:
# Validation split: the 2016 sub-directories, i.e. indices [train_end, val_end).
# NOTE(review): these are the split's own statistics; if downstream code normalises the
# validation data, it should presumably use train_mean/train_std instead — confirm.
val = all_prism_data[train_end:val_end]
val_mean, val_std = val.mean(axis=0), val.std(axis=0)

In [13]:
# Test split: the 2017-2018 sub-directories, i.e. indices [val_end, test_end).
# NOTE(review): these are the split's own statistics; if downstream code normalises the
# test data, it should presumably use train_mean/train_std instead — confirm.
test = all_prism_data[val_end:test_end]
test_mean, test_std = test.mean(axis=0), test.std(axis=0)

In [14]:
# Persist each split together with its own per-pixel statistics.
# TODO: this absolute output path is hard-coded; consider deriving it from configuration.
out_dir = "/data0/datasets/prism/prism_processed"
os.makedirs(out_dir, exist_ok=True)  # on a fresh machine the folder may not exist yet
splits = {
    "train": (train, train_mean, train_std),
    "val": (val, val_mean, val_std),
    "test": (test, test_mean, test_std),
}
for split_name, (split_data, split_mean, split_std) in splits.items():
    with open(os.path.join(out_dir, f"{split_name}.npz"), "wb") as f:
        np.savez(f, data=split_data, mean=split_mean, std=split_std)

In [16]:
# Save the 1-D latitude/longitude axes of the coarse output grid next to the splits.
np.savez(
    "/data0/datasets/prism/prism_processed/coords.npz",
    lat=grid_out["lat"],
    lon=grid_out["lon"],
)

In [17]:
# 1 where the first training grid holds a value, 0 where it is NaN (masked or blanked row).
regridded_mask = (~np.isnan(train[0])).astype(int)

In [21]:
# Save the coarse-grid validity mask next to the other processed outputs.
np.save("/data0/datasets/prism/prism_processed/mask.npy", regridded_mask)