In [43]:
import numpy as np
import os
import glob
from tqdm import tqdm
import xarray as xr

In [49]:
# Every ERA5 2m-temperature NetCDF file at 2.8125° resolution, sorted by path
# (presumably chronological if the filenames embed the year — confirm).
filelist = sorted(
    glob.glob(
        os.path.join(
            "/data0/datasets/weatherbench/data/weatherbench/era5/2.8125deg",
            "2m_temperature",
            "*.nc",
        )
    )
)

In [50]:
# Open every file and join them along the time axis in a single concat.
# Concatenating inside the loop would re-copy the growing dataset once per
# file (quadratic in the number of files); one concat over a list is linear.
if not filelist:
    # Fail here with a clear message instead of leaving `xarr = None` and
    # crashing later with an AttributeError on `xarr.lat`.
    raise FileNotFoundError("no 2m_temperature .nc files matched the glob pattern")
xarr = xr.concat([xr.open_dataset(fi) for fi in filelist], dim="time")

In [54]:
# Coordinate values of the full (uncropped) concatenated dataset.
lats, lons = xarr["lat"].data, xarr["lon"].data

In [90]:
# Bounding box of the PRISM grid, used to crop ERA5 below.
# Longitudes use the 0–360° east convention (values > 180), matching the
# convention the ERA5 `lon` slice below relies on.
bottom, top = 24.10, 49.94    # latitude, degrees north
left, right = 234.98, 293.48  # longitude, degrees east

In [97]:
# Training split: 1981-01-01 .. 2015-12-31, cropped to the PRISM box.
train_data = xarr.sel(
    time=slice("1981-01-01", "2015-12-31"),
    lat=slice(bottom, top),
    lon=slice(left, right),
)
# Cropped grid coordinates; val/test use the same lat/lon slices, so these
# apply to every split and are saved once to coords.npz.
cropped_lats = train_data["lat"].data
cropped_lons = train_data["lon"].data

# Reduce the (presumably hourly) samples to one maximum per calendar day
# of the dataset's time axis (UTC for ERA5).
train_data = train_data.resample(time="1D").max(dim="time")
train_narr = train_data["t2m"].data
# Per-gridpoint temporal mean / std of the daily maxima.
train_mean = train_data.mean(dim="time")["t2m"].data
train_std = train_data.std(dim="time")["t2m"].data

In [101]:
# Validation split: calendar year 2016, cropped to the PRISM box.
val_data = xarr.sel(
    time=slice("2016-01-01", "2016-12-31"),
    lat=slice(bottom, top),
    lon=slice(left, right),
)
# One maximum per calendar day, then per-gridpoint temporal statistics.
val_data = val_data.resample(time="1D").max(dim="time")
val_narr = val_data["t2m"].data
val_mean = val_data.mean(dim="time")["t2m"].data
val_std = val_data.std(dim="time")["t2m"].data

In [102]:
# Test split: 2017-01-01 .. 2018-12-31, cropped to the PRISM box.
test_data = xarr.sel(
    time=slice("2017-01-01", "2018-12-31"),
    lat=slice(bottom, top),
    lon=slice(left, right),
)
# One maximum per calendar day, then per-gridpoint temporal statistics.
test_data = test_data.resample(time="1D").max(dim="time")
test_narr = test_data["t2m"].data
test_mean = test_data.mean(dim="time")["t2m"].data
test_std = test_data.std(dim="time")["t2m"].data

In [104]:
# Persist the daily-max training fields with their per-gridpoint mean / std.
# The path already ends in .npz, so np.savez writes it verbatim.
np.savez(
    "/data0/datasets/prism/era5_cropped/train.npz",
    data=train_narr,
    mean=train_mean,
    std=train_std,
)

In [105]:
# Persist the daily-max validation fields with their per-gridpoint mean / std.
# The path already ends in .npz, so np.savez writes it verbatim.
np.savez(
    "/data0/datasets/prism/era5_cropped/val.npz",
    data=val_narr,
    mean=val_mean,
    std=val_std,
)

In [106]:
# Persist the daily-max test fields with their per-gridpoint mean / std.
# The path already ends in .npz, so np.savez writes it verbatim.
np.savez(
    "/data0/datasets/prism/era5_cropped/test.npz",
    data=test_narr,
    mean=test_mean,
    std=test_std,
)

In [107]:
# Persist the cropped latitude / longitude coordinates shared by all splits.
np.savez(
    "/data0/datasets/prism/era5_cropped/coords.npz",
    lat=cropped_lats,
    lon=cropped_lons,
)

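The approach below (cropping the preprocessed ERA5 `.npz` shards directly) does not work: for reasons not yet determined, the resulting number of days is off. It is kept for reference only. Unchecked suspects visible in the code: `glob.glob` returns the shards unsorted, and the crop indices `y:yy` / `x:xx` are never defined.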

In [45]:
# root = os.environ["ERA5_2DEG"]
# lats = np.load(os.path.join(root, "lat.npy"))
# lons = np.load(os.path.join(root, "lon.npy"))

In [46]:
# def process(split):
#     file_list = glob.glob(os.path.join(root, split, "*.npz"))
#     conus_data = []
#     for fi in tqdm(file_list):
#         if not fi.endswith("climatology.npz"):
#             npz = np.load(fi)
#             arr = npz["2m_temperature"]
#             cropped = arr[:,0,y:yy,x:xx]
#             conus_data.append(cropped)
#     conus_data = np.concatenate(conus_data)
#     days = []
#     for i in range(0, len(conus_data), 24):
#         j = min(i+24, len(conus_data))
#         days.append(conus_data[i:j])
#     days = np.stack(days)
#     daily_max_t2m = np.max(days, axis=1)
#     with open(f"/data0/datasets/prism/era5_cropped/{split}.npy", "wb") as f:
#         np.save(f, daily_max_t2m)
#     return daily_max_t2m

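PRISM data is in Eastern time, while ERA5 is in UTC. Note that UTC is 5 hours ahead of EST (and 4 hours ahead of EDT), so the offset is presumably 5 hours rather than 4. TODO: confirm this against PRISM's definition of a day. To align the daily windows, we should skip the first N hourly ERA5 samples (N = the offset). We should then drop the last day, since it would no longer contain a full 24 hours. **TODO:** this alignment is not implemented; the daily maxima above are computed over UTC calendar days.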