In [None]:
import xarray as xr
# Need to use xbatcher from: https://github.com/arbennett/xbatcher/tree/develop
import xbatcher as xb
import numpy as np

from glob import glob
from parflow.tools.io import read_pfb

In [24]:
BASE_DIR = '/hydrodata/PFCLM/CONUS1_baseline/simulations'
YEAR = 2004

# Get Pressure files
pressure_files = sorted(glob(f'{BASE_DIR}/{YEAR}/raw_outputs/pressure/*.pfb'))
pressure_files = {
    't': pressure_files[0:-1],
    't+1': pressure_files[1:]
}

# Get parameter filesk
parameter_names = [
    'permeability', 'porosity', 'vgn_alpha', 'vgn_n', 'slope_x', 'slope_y'
]
parameter_files = {
    name: f'{BASE_DIR}/static/CONUS1_{name}.pfb' for name in parameter_names
}


# Get forcing files
all_forcings = glob(f'{BASE_DIR}/{YEAR}/WY{YEAR}/*.pfb')
varnames = set([f.split('/')[-1].split('.')[1] for f in all_forcings])

variable_forcings = {}
for v in varnames:
    variable_forcings[v] = sorted(glob(f'{BASE_DIR}/{YEAR}/WY{YEAR}/*.{v}.*pfb'))

In [40]:
X_EXTENT = 3342
Y_EXTENT = 1888
T_EXTENT = 8759 # 1 less because we are predicting t+1
Z_EXTENT = 5
PATCH_SIZE = 128
PATCH_OVERLAP = 32

In [41]:
dummy_data = xr.Dataset().assign_coords({
    'time': np.arange(T_EXTENT),
    'z': np.arange(Z_EXTENT),
    'y': np.arange(Y_EXTENT),
    'x': np.arange(X_EXTENT)
})

In [45]:
bgen = xb.BatchGenerator(
    dummy_data,
    input_dims={'x': PATCH_SIZE, 'y': PATCH_SIZE, 'time': 1},
    input_overlap={'x': PATCH_OVERLAP, 'y': PATCH_OVERLAP},
    return_partial=True,
    shuffle=True,
)

In [62]:
# Now you can see this pulls samples from teh dummy data
sample_indices = next(iter(bgen))
time_index = sample_indices['time'].values[0]
x_min, x_max = sample_indices['x'].values[[0, -1]]
y_min, y_max = sample_indices['y'].values[[0, -1]]

pressure_keys = {
    'x': {'start': x_min, 'stop': x_max},
    'y': {'start': y_min, 'stop': y_max},
}

In [None]:
# Construct the state data:
file_to_read = pressure_files['t'][time_index]
state_data = read_pfb(file_to_read, keys=pressure_keys)

# Construct the target data:
file_to_read_target = pressure_files['t+1'][time_index]
target_data = read_pfb(file_to_read_target, keys=pressure_keys)

# Construct the forcing data:


In [64]:
state_data.shape, target_data.shape

((5, 127, 127), (5, 127, 127))