Scratch notebook for miscellaneous testing.

In [1]:
import sys; sys.path.append("..")
from models.utils import *

In [2]:
def reshape_inputs(data: "xr.Dataset",
                   keep_coords: List=["time", "latitude", "longitude"],
                   avg_time_window: Optional[int]=None,
                   history: Optional[int]=None,
                   data_vars: List=["SSH", "SST", "SSS", "OBP", "ZWS"],
                   return_pt: bool=False) -> np.ndarray:
    """
    Read in the original input dataset, with coordinates "time", "latitude", "longitude",
    and data variables "SSH", "SST", "SSS", "OBP", "ZWS".

    Return a numpy array of any subset of the data variables, optionally averaged over any
    coordinates or including history. Can also return a pytorch tensor.

    Internally the array is laid out as (time, latitude, longitude, feature); this assumes every
    data variable in the dataset is stored as (time, latitude, longitude). The final axis order,
    the selected variables and the output shape are printed before returning.

    data: original xarray dataset - see solodoch_data_minimal in google drive.
    keep_coords: coordinate axes to be kept. others will be averaged over and collapsed.
    avg_time_window: if time is not included in keep_coords, optionally choose a lag over which to
        average. In that case time is smoothed with a moving average instead of being collapsed,
        so the output keeps a time axis of length n_times - avg_time_window + 1.
    history: include a new axis for history (useful if we want to convolve over past values for
        example). Requires "time" in keep_coords. The result has n_times - history + 1 time steps,
        each holding `history` consecutive values along the new second axis.
    data_vars: data variables to be kept.
    return_pt: if true, returns a pytorch tensor (cpu!) instead of a numpy array.

    Raises ValueError if history is requested without "time" in keep_coords, or if
    avg_time_window / history is smaller than 1 or longer than the time series.

    Note: when history is used and return_pt is false, the returned array is a strided view whose
    windows overlap in memory. Treat it as read-only (copy it before modifying it in place).
    """

    def sliding_windows(arr: np.ndarray, width: int) -> np.ndarray:
        """
        View `arr` as overlapping windows of `width` consecutive steps along its first axis.

        arr: array whose first axis is time.
        width: number of consecutive time steps per window (1 <= width <= arr.shape[0]).

        Returns a view of shape (arr.shape[0] - width + 1, width, *arr.shape[1:]) sharing memory
        with `arr`.
        """
        step = arr.strides[0]
        n_windows = arr.shape[0] - width + 1
        # Reusing the time stride for the window axis makes window i start at time step i.
        return as_strided(arr,
                          shape=(n_windows, width, *arr.shape[1:]),
                          strides=(step, step, *arr.strides[1:]))

    # `coords` always mirrors the leading axes of `data`; the trailing axis is the feature axis.
    coords = ["time", "latitude", "longitude"]
    data = data[coords + list(data_vars)].to_array().values
    # to_array() stacks the variables along a new leading axis; move it to the end.
    data = data.transpose(1, 2, 3, 0)

    # "time" is handled first, so the moving average always sees the full 4-d array.
    for ax in ("time", "latitude", "longitude"):
        if ax in keep_coords:
            continue
        if ax == "time" and avg_time_window is not None:
            if not 1 <= avg_time_window <= data.shape[0]:
                raise ValueError("avg_time_window must be between 1 and the length of the time series.")
            # Averaging the window axis leaves a shortened time axis in place, so the "time"
            # label stays in `coords` and later axis lookups remain aligned with `data`.
            data = sliding_windows(data, avg_time_window).mean(axis=1)
        else:
            data = data.mean(axis=coords.index(ax))
            coords.remove(ax)

    if history is not None:
        if "time" not in keep_coords:
            raise ValueError("Error. 'time' must be in keep_coords in order to use history.")
        if history > data.shape[0]:
            raise ValueError("Desired history is longer than the full time series.")
        if history < 1:
            raise ValueError("history must be a positive number of time steps.")
        data = sliding_windows(data, history)
        coords = ["time", "history"] + coords[1:]

    print(f"axes: {coords + ['feature']}")
    print(f"variables: {data_vars}")
    print(f"shape: {data.shape}")
    return t.Tensor(data) if return_pt else data

In [None]:
# Scratch check of the strided "history" view: turn foo, unpacked below as
# (time, lat, lon, feature), into overlapping windows of H consecutive time steps.
# NOTE: `foo` comes from the `foo = reshape_inputs(pp_data)` cell, which must be run first.
H = 5

n_times, n_lats, n_lons, n_features = foo.shape
view_shape = (n_times-H+1, H, n_lats, n_lons, n_features)
x, y, z, w = foo.strides
# as_strided is a free function (numpy.lib.stride_tricks), not an ndarray method, so
# `foo.as_strided(...)` raised AttributeError. Repeating the time stride `x` for the window
# axis makes window i start at time step i.
bar = as_strided(foo, shape=view_shape, strides=(x, x, y, z, w))

In [5]:
# NOTE(review): absolute, machine-specific path (looks like a mounted Google Drive folder);
# adjust it for your environment.
data_home = "/mnt/g/My Drive/GTC/solodoch_data_minimal"
# Names of the available per-latitude files (presumably latitude labels - confirm);
# only the first entry is loaded here.
lats = ["26N", "30S", "55S", "60S"]
lat = lats[0]
data = xr.open_dataset(f"{data_home}/{lat}.nc")

# apply whatever preprocessing we want *before* calling reshape_inputs
pp_data = apply_preprocessing(data,
                              mode="inputs",
                              remove_season=True,
                              remove_trend=True,
                              standardize=True,
                              lowpass=True)

# Reshape as desired: only "time" is kept (latitude and longitude are averaged out) and a
# 288-step history axis is added. return_pt=True means the result is a torch tensor, so
# despite the `_np` suffix this is NOT a numpy array.
pp_data_np = reshape_inputs(pp_data, keep_coords=["time"], history=288, return_pt=True)

axes: ['time', 'history', 'feature']
variables: ['SSH', 'SST', 'SSS', 'OBP', 'ZWS']
shape: (1, 288, 5)


In [6]:
# Display the reshaped inputs. Despite the `_np` suffix this is a torch tensor, because it was
# created with return_pt=True.
pp_data_np

tensor([[[ 1.0279,  1.0273,  1.3772, -0.1046,  3.3314],
         [ 0.1325,  0.3698,  0.9208, -1.2171,  1.1591],
         [-0.4907, -0.1333,  0.4977, -1.8350, -0.2081],
         ...,
         [-0.5665,  0.4738,  0.2288, -1.2381,  0.2949],
         [-0.7330,  0.4428,  0.4530, -2.0619, -0.2224],
         [-0.7610,  0.2743,  0.7834, -3.1836, -0.8352]]])

In [None]:
# Default call: every coordinate and data variable is kept and return_pt is False, so `foo` is a
# numpy array with axes (time, latitude, longitude, feature).
# The strided-view scratch cell (the one defining `bar`) reads `foo`, so run this cell before it.
foo = reshape_inputs(pp_data)

In [None]:
# Inspect the per-axis byte strides of `foo`; these are the values reused when building
# as_strided window views.
foo.strides