In [1]:
import numpy as np
import pace.util

In all cases, the subdomain is just a dimension we would like to parallelize across.  The reshaping logic could be simpler: instead of working out the correct way to handle the subdomain stacking, we could just leave the subdomain as a sampling dimension, perhaps?

This notebook plays around with that idea.

The two data types we might see are

(t, x, y, z)

or 

(x, y, z)

The tile divisions happen along the (x, y) dimensions only.  Currently z is inferred.

We would like to handle subdomains so that they are easy to leave as a leading dimension, since most of the matrix multiplication will automatically parallelize across that dimension.


decomposed to (t, s, xs, ys, z)
or (s, xs, ys, z)

what about t, s, xs*ys*z?
or xs*ys*z?

In [2]:
# Partitioner with a 2x2 subdomain layout, used by the slicing examples below.
partitioner = pace.util.TilePartitioner((2, 2))

In [8]:
# Slices selecting subtile 0 from data with dims (x, y, feature) and extent (6, 4, 10).
partitioner.subtile_slice(0, ["x", "y", "feature"], (6, 4, 10))

(slice(0, 3, None), slice(0, 2, None), slice(0, 10, None))

In [11]:
# Same extent, but for subtile 1 and with the horizontal dims listed in (y, x) order.
partitioner.subtile_slice(1, ["y", "x", "feature"], (6, 4, 10))

(slice(0, 3, None), slice(2, 4, None), slice(0, 10, None))

The domain doesn't have to be square, but its extent does have to be evenly divisible by the layout.  We pass the extent without any overlap into the function.  The function returns slices for x,y or x,y,z, where the z slice spans the full extent.  We can include the feature dimension in the decomposition function to reduce reliance on inference.

In [4]:
a = np.ones((6, 4, 10))
# The index must be a tuple: NumPy deprecated list indices that contain
# slices/Ellipsis and newer releases no longer accept them as a
# multidimensional index.
sl = (..., slice(0, 2), slice(0, 2))
a[sl].shape

  a[sl].shape


(6, 2, 2)

In [5]:
from typing import Sequence, Tuple, Optional


In [14]:

def _check_feature_dims_consistent(data_shape, feature_shape):
    n_feature_dims = len(feature_shape)
    feature_dims = data_shape[-n_feature_dims:]
    if feature_dims != feature_shape:
        raise ValueError(f"Feature dimensions of data {feature_dims} are not consistent with expected: {feature_shape}")

class RankXYDivider:
    """Split one rank's (x, y[, z]) data into a 2D layout of subdomains and back.

    Data handed to this class must have the rank's x, y (and optionally z)
    dimensions as its trailing dimensions; any number of leading dimensions
    (e.g. time) is carried through unchanged.

    Args:
        subdomain_layout: number of subdomains along each horizontal dimension.
        rank_extent: (x, y) shape of the rank's data.
        z_feature: size of the trailing z/feature dimension, or None if the
            data has no such dimension.

    Raises:
        ValueError: if the layout is not 2D or the rank extent cannot be
            divided evenly by the layout.
    """

    def __init__(
        self,
        subdomain_layout: Tuple[int, int],
        rank_extent: Tuple[int, int],  # (x, y) shape of the rank's data; this class applies no halo overlap
        z_feature: Optional[int],
    ):
        if len(subdomain_layout) != 2:
            raise ValueError("Rank divider only handles 2D subdomain layouts")

        # Floor division below would silently drop trailing points if the
        # extent did not divide evenly, so reject such configurations early.
        # NOTE(review): x is divided by subdomain_layout[0] and y by
        # subdomain_layout[1]; confirm this matches the axis ordering used by
        # TilePartitioner for non-square layouts.
        for dim_name, extent, n_divisions in zip(("x", "y"), rank_extent, subdomain_layout):
            if extent % n_divisions != 0:
                raise ValueError(
                    f"Rank extent {extent} along {dim_name} is not evenly "
                    f"divisible into {n_divisions} subdomains"
                )

        self.rank_extent = rank_extent
        self.subdomain_layout = subdomain_layout
        self.n_subdomains = subdomain_layout[0] * subdomain_layout[1]
        self._partitioner = pace.util.TilePartitioner(subdomain_layout)

        self._x_rank_extent = self.rank_extent[0]
        self._y_rank_extent = self.rank_extent[1]
        self._z_feature = z_feature

        self._x_subdomain_extent = self._x_rank_extent // self.subdomain_layout[0]
        self._y_subdomain_extent = self._y_rank_extent // self.subdomain_layout[1]

    @property
    def subdomain_xy_extent(self):
        """(x, y) extent of a single subdomain."""
        return self._x_subdomain_extent, self._y_subdomain_extent

    @property
    def _rank_extent_all_features(self):
        """Rank extent as a list: [x, y] plus the z size when one is configured."""
        # Used for slicing, so needs to be rank extent, no overlap
        rank_extent = list(self.rank_extent)

        if self._z_feature is not None:
            rank_extent += [self._z_feature]

        return rank_extent

    @property
    def _rank_dims_all_features(self):
        """Dimension names matching ``_rank_extent_all_features``."""
        # TODO: y might go before "x" since that aligns with row-major order (also, lat, lon, feature)
        # probably doesn't matter given x,y agnostic merge subdomains function
        rank_dims = ["x", "y"]

        if self._z_feature is not None:
            rank_dims += ["z"]

        return rank_dims

    def _get_subdomain_slice(self, subdomain_index):
        """Return the partitioner's slices for ``subdomain_index`` over the rank dims."""
        rank_dims = self._rank_dims_all_features
        rank_extent = self._rank_extent_all_features
        return self._partitioner.subtile_slice(subdomain_index, rank_dims, rank_extent)

    def _add_potential_leading_dim_to_slices(self, data_shape, dim_slices):
        """Return ``dim_slices`` as an index tuple, prefixed with ``...`` if needed.

        An Ellipsis is prepended when ``data_shape`` has more dimensions than
        the rank dims, so the slices always apply to the trailing dimensions.
        """
        # Always hand back a tuple: NumPy does not accept a list containing
        # slices/Ellipsis as a multidimensional index.
        dim_slices = tuple(dim_slices)

        if len(data_shape) > len(self._rank_dims_all_features):
            dim_slices = (..., *dim_slices)

        return dim_slices

    def get_subdomain(self, data, subdomain_index):
        """Return the portion of ``data`` belonging to subdomain ``subdomain_index``.

        Raises:
            ValueError: if the trailing dimensions of ``data`` do not match the
                rank extent (and z size, when configured).
        """
        # Pass a tuple so the comparison against data.shape (a tuple) is valid.
        _check_feature_dims_consistent(data.shape, tuple(self._rank_extent_all_features))
        dim_slices = self._get_subdomain_slice(subdomain_index)
        dim_slices = self._add_potential_leading_dim_to_slices(data.shape, dim_slices)

        return data[dim_slices]

    def get_all_subdomains(self, data):
        """Return every subdomain of ``data`` stacked along a new leading dimension."""
        subdomains = [self.get_subdomain(data, i) for i in range(self.n_subdomains)]
        return np.stack(subdomains, axis=0)

    @property
    def subdomain_shape(self):
        """Shape of one subdomain as a list: [x, y] plus z when configured."""
        shape = list(self.subdomain_xy_extent)
        if self._z_feature is not None:
            shape += [self._z_feature]
        return shape

    @property
    def flat_subdomain_shape(self):
        """Single-element list holding the number of points in one subdomain."""
        return [int(np.prod(self.subdomain_shape))]

    def flatten_subdomain_features(self, data):
        """Collapse the trailing subdomain dimensions of ``data`` into one dimension.

        Leading dimensions are preserved.

        Raises:
            ValueError: if the trailing dimensions do not match ``subdomain_shape``.
        """
        feature_shape = self.subdomain_shape
        _check_feature_dims_consistent(data.shape, tuple(feature_shape))
        n_feature_dims = len(feature_shape)
        return data.reshape(list(data.shape[:-n_feature_dims]) + [-1])

    def reshape_flat_subdomain_features(self, data):
        """Expand the trailing flat dimension of ``data`` back to ``subdomain_shape``.

        Inverse of ``flatten_subdomain_features``: leading dimensions are kept.

        Raises:
            ValueError: if the last dimension does not match ``flat_subdomain_shape``.
        """
        _check_feature_dims_consistent(data.shape, tuple(self.flat_subdomain_shape))
        # Keep any leading dimensions; reshaping straight to subdomain_shape
        # would only work for data without them.
        leading_shape = list(data.shape[:-1])
        return data.reshape(leading_shape + self.subdomain_shape)

    def merge_all_subdomains(self, data):
        """Reassemble subdomain data into an array covering the full rank extent.

        Args:
            data: array shaped [n_subdomains, leading..., x, y, (z)].

        Raises:
            ValueError: if the first dimension is not ``n_subdomains`` or the
                trailing dimensions do not match ``subdomain_shape``.
        """
        if data.shape[0] != self.n_subdomains:
            raise ValueError(f"Expected data to have first dimension of length {self.n_subdomains}, but got {data.shape[0]}")

        _check_feature_dims_consistent(data.shape, tuple(self.subdomain_shape))
        rank_extent = self._rank_extent_all_features
        # Drop the subdomain axis and swap the subdomain extents for rank extents.
        new_shape = list(data.shape[1:-len(rank_extent)]) + rank_extent
        merged = np.empty(new_shape, dtype=data.dtype)

        for i in range(self.n_subdomains):
            subdomain = data[i]
            dim_slices = self._get_subdomain_slice(i)
            dim_slices = self._add_potential_leading_dim_to_slices(subdomain.shape, dim_slices)
            merged[dim_slices] = subdomain

        return merged


class OverlapRankXYDivider(RankXYDivider):

    """
    Adjusted rank divider to handle halo overlap regions

    Input data is expected to include the halo, i.e. its horizontal shape is
    ``overlap_rank_extent``.  Each subdomain returned covers its interior
    portion plus ``overlap`` points on every horizontal side.

    Args:
        subdomain_layout: number of subdomains along each horizontal dimension.
        overlap_rank_extent: (x, y) shape of the rank's data including the halo.
        overlap: halo width in grid points on each side.
        z_feature: size of the trailing z/feature dimension, or None.

    Raises:
        ValueError: if the layout is not 2D, the overlap is negative, or the
            halo-free extent is not positive and evenly divisible by the layout.
    """

    def __init__(
        self, 
        subdomain_layout: Tuple[int, int],
        overlap_rank_extent: Tuple[int, int],
        overlap: int,
        z_feature: Optional[int],
    ):
        if len(subdomain_layout) != 2:
            raise ValueError("Rank divider only handles 2D subdomain layouts")
        if overlap < 0:
            raise ValueError(f"Overlap must be non-negative, got {overlap}")

        # Extent of the rank with the halo stripped from both sides.
        interior_extent = (
            overlap_rank_extent[0] - 2 * overlap,
            overlap_rank_extent[1] - 2 * overlap,
        )
        # Floor division below would silently drop points otherwise.
        for dim_name, extent, n_divisions in zip(("x", "y"), interior_extent, subdomain_layout):
            if extent <= 0 or extent % n_divisions != 0:
                raise ValueError(
                    f"Rank extent without overlap ({extent}) along {dim_name} must be "
                    f"positive and evenly divisible into {n_divisions} subdomains"
                )

        self.overlap_rank_extent = overlap_rank_extent
        self.subdomain_layout = subdomain_layout
        self.n_subdomains = subdomain_layout[0] * subdomain_layout[1]
        self._partitioner = pace.util.TilePartitioner(subdomain_layout)
        self.overlap = overlap

        self._x_rank_extent = self.overlap_rank_extent[0]
        self._y_rank_extent = self.overlap_rank_extent[1]
        self._z_feature = z_feature

        # Halo-free extent; this is what the partitioner slices are based on.
        self.rank_extent = interior_extent

        # Each subdomain carries the halo on both sides of x and y.
        self._x_subdomain_extent = self.rank_extent[0] // self.subdomain_layout[0] + 2 * self.overlap
        self._y_subdomain_extent = self.rank_extent[1] // self.subdomain_layout[1] + 2 * self.overlap

    def _update_slices_with_overlap(self, slices):
        """Widen the x and y slices by ``2 * overlap`` and return them as a tuple.

        The incoming slices index the halo-free rank; keeping the start and
        extending the stop makes them index the halo-inclusive data instead.
        """
        rank_dims = self._rank_dims_all_features

        x_ind = rank_dims.index("x")
        y_ind = rank_dims.index("y")

        slices = list(slices)
        x_sl, y_sl = slices[x_ind], slices[y_ind]
        slices[x_ind] = slice(x_sl.start, x_sl.stop + 2 * self.overlap, None)
        slices[y_ind] = slice(y_sl.start, y_sl.stop + 2 * self.overlap, None)

        # A tuple is required for use as a multidimensional NumPy index.
        return tuple(slices)

    def _get_subdomain_slice(self, subdomain_index):
        """Return halo-inclusive slices for ``subdomain_index``."""
        slices = super()._get_subdomain_slice(subdomain_index)
        return self._update_slices_with_overlap(slices)

    def get_subdomain(self, data, subdomain_index):
        """Return subdomain ``subdomain_index`` of halo-inclusive ``data``.

        Raises:
            ValueError: if the trailing dimensions of ``data`` do not match
                ``overlap_rank_extent`` (and z size, when configured).
        """
        # The data includes the halo, so validate against the overlap extent
        # rather than the halo-free rank extent used for slicing.
        expected_shape = tuple(self.overlap_rank_extent)
        if self._z_feature is not None:
            expected_shape += (self._z_feature,)
        _check_feature_dims_consistent(data.shape, expected_shape)

        dim_slices = self._get_subdomain_slice(subdomain_index)
        dim_slices = self._add_potential_leading_dim_to_slices(data.shape, dim_slices)

        return data[tuple(dim_slices)]

    def merge_all_subdomains(self, data):
        """Not supported: overlapping subdomains cannot be merged back."""
        raise NotImplementedError("Merging overlapped subdomains is not supported.")
