In [1]:
import xarray as xr
import pandas as pd
import numpy as np

# Axis labels for the synthetic inputs: basins x daily dates x dynamic features
basins = ['A', 'B', 'C']
dates = pd.date_range('2022-01-01', periods=10, freq='D')
features = ['temp', 'precip', 'flow']

# Seed NumPy's global RNG so both random arrays below are reproducible
np.random.seed(0)

# Dynamic inputs: one random value per (basin, date, feature).
# Note these are labelled xarray DataArrays, not Datasets.
data = np.random.rand(len(basins), len(dates), len(features))
x_d = xr.DataArray(
    data,
    coords={'basin': basins, 'date': dates, 'feature': features},
    dims=['basin', 'date', 'feature'],
)

# Static inputs: one random value per (basin, static_feature), drawn from the
# same seeded stream right after the dynamic data
static_features = ['elevation', 'area']
static_data = np.random.rand(len(basins), len(static_features))
x_s = xr.DataArray(
    static_data,
    coords={'basin': basins, 'static_feature': static_features},
    dims=['basin', 'static_feature'],
)

# Plain-text reprs of both arrays, dynamic first
print(x_d, x_s, sep='\n')

In [2]:
# Example indices: positions into `basins`, one per sample
indices = [0, 2]  # Corresponds to basins 'A' and 'C'
# `basins` is a plain Python list, which cannot be indexed with another list
# (that raises TypeError), so pick the entries one at a time.
selected_basins = [basins[i] for i in indices]
# One date window per sample; every window must have the same length so the
# windows stack into a rectangular (sample, time) array.
selected_dates = [
    pd.date_range('2022-01-01', periods=3, freq='D'),
    pd.date_range('2022-01-04', periods=3, freq='D'),
]

# Convert to xarray-friendly formats: indexers that share the "sample" dim are
# matched pointwise by xarray instead of forming an outer product.
basins_da = xr.DataArray(selected_basins, dims="sample")
sequence_dates_da = xr.DataArray(selected_dates, dims=["sample", "time"])

# Fetch data: sample i gets basin i over date window i
selected_data = x_d.sel(basin=basins_da, date=sequence_dates_da)

# Transpose to get the correct shape: [batch_size, seq_length, n_features]
selected_data = selected_data.transpose("sample", "time", "feature")

print(selected_data)


<xarray.DataArray (sample: 2, time: 3, feature: 3)> Size: 144B
array([[[0.5488135 , 0.71518937, 0.60276338],
        [0.54488318, 0.4236548 , 0.64589411],
        [0.43758721, 0.891773  , 0.96366276]],

       [[0.09609841, 0.97645947, 0.4686512 ],
        [0.97676109, 0.60484552, 0.73926358],
        [0.03918779, 0.28280696, 0.12019656]]])
Coordinates:
    basin    (sample) <U1 8B 'A' 'C'
    date     (sample, time) datetime64[ns] 48B 2022-01-01 ... 2022-01-06
  * feature  (feature) <U6 72B 'temp' 'precip' 'flow'
Dimensions without coordinates: sample, time


In [8]:
# Show the raw array held by the DataArray, without coordinate labels
# (the output below shows it is a NumPy array of shape (2, 3, 3)).
selected_data.data

array([[[0.5488135 , 0.71518937, 0.60276338],
        [0.54488318, 0.4236548 , 0.64589411],
        [0.43758721, 0.891773  , 0.96366276]],

       [[0.09609841, 0.97645947, 0.4686512 ],
        [0.97676109, 0.60484552, 0.73926358],
        [0.03918779, 0.28280696, 0.12019656]]])