# Parallelism Strategy

In [None]:
#| default_exp strategy

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *

In [None]:
#| export
class BaseStrategy:
    """Interface shared by every parallelism strategy.

    A strategy receives a per-instance counterfactual generator `fn` and is
    responsible for applying it to the rows of `X`. Subclasses must override
    `__call__`; the base implementation only raises `NotImplementedError`.
    """

    # Public method names of this class.
    __ALL__ = ["__call__"]

    def __call__(
        self,
        fn: Callable,  # Function to generate cf for a single input
        X: Array,  # Input instances to be explained
        pred_fn: Callable[[Array], Array],  # Model prediction function
        **kwargs
    ) -> Array:  # Generated counterfactual explanations
        """Generate counterfactuals for `X`; to be implemented by subclasses."""
        raise NotImplementedError


In [None]:
#| export
class IterativeStrategy(BaseStrategy):
    """Iteratively generate counterfactuals, one instance at a time.

    Calls `fn` on each row of `X` in a Python loop and stacks the results.
    No batching transform (vmap/pmap) is applied to `fn`.
    """

    def __call__(
        self, 
        fn: Callable, # Function to generate cf for a single input
        X: Array, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations
        assert X.ndim == 2, f"X must be a 2D array, got {X.ndim}D array"
        if X.shape[0] == 0:
            # `jnp.stack` cannot stack an empty list. With no rows to explain,
            # the result is an empty array of X's shape, matching the shape
            # contract checked below.
            return jnp.asarray(X)
        cfs = jnp.stack([fn(x, pred_fn=pred_fn, **kwargs) for x in X])
        # Each per-row counterfactual must have the same shape as its input row.
        assert X.shape == cfs.shape, (
            f"Expected counterfactuals of shape {X.shape}, got {cfs.shape}"
        )
        return cfs


In [None]:
#| export
class VmapStrategy(BaseStrategy):
    """Vectorise the per-instance generator over all rows of `X` with `jax.vmap`."""

    def __call__(
        self,
        fn: Callable,  # Function to generate cf for a single input
        X: Array,  # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array:  # Generated counterfactual explanations
        assert X.ndim == 2
        # Bind the shared arguments first so vmap maps only over the rows of X.
        batched_fn = jax.vmap(partial(fn, pred_fn=pred_fn, **kwargs))
        return batched_fn(X)


In [None]:
#| exporti
def _pad_divisible_X(
    X: Array,
    n_devices: int
):
    """Pad `X` to be divisible by `n_devices` and split it into `n_devices` shards.

    Zero rows are appended along axis 0 until its length is a multiple of
    `n_devices`. The padding uses `X.dtype`, so integer or low-precision inputs
    are not silently promoted by the concatenation.

    Returns an array of shape
    `(n_devices, ceil(X.shape[0] / n_devices), *X.shape[1:])`.
    """
    remainder = X.shape[0] % n_devices
    if remainder != 0:
        pad_size = n_devices - remainder
        padding = jnp.zeros((pad_size, *X.shape[1:]), dtype=X.dtype)
        X = jnp.concatenate([X, padding])
    return X.reshape(n_devices, -1, *X.shape[1:])


In [None]:
#| hide
X = jnp.ones((5, 29))
# (n_devices, expected padded shape): uneven split, more devices than rows,
# and a single device (no padding needed).
for _n_dev, _expected_shape in ((2, (2, 3, 29)), (6, (6, 1, 29)), (1, (1, 5, 29))):
    X_padded = _pad_divisible_X(X, _n_dev)
    assert X_padded.shape == _expected_shape
    if _n_dev == 2:
        # Padding must only add zero rows, leaving the total unchanged.
        assert X.sum() == X_padded.sum()



No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [None]:
#| export
class PmapStrategy(BaseStrategy):
    """Generate counterfactuals across devices via `jax.pmap`.

    `X` is zero-padded to a multiple of `n_devices` and split into one shard
    per device. Each shard is processed with `jax.vmap`, and the padding rows
    are removed before returning.
    """

    def __init__(
        self, 
        n_devices: int = None, # Number of devices. If None, use all available devices
        strategy: str = 'auto', # Strategy to generate counterfactuals
        **kwargs
    ):
        # NOTE: `strategy` and the extra `kwargs` are kept for API compatibility;
        # `__call__` does not use them.
        self.strategy = strategy
        self.n_devices = n_devices

    @property
    def n_devices(self) -> int:
        """Number of devices to shard `X` over.

        This is resolved on access rather than in `__init__`. That way,
        constructing a `PmapStrategy` does not call `jax.device_count()`;
        this matters for the shared instance a factory creates at import
        time. Calling `jax.device_count()` initialises the JAX backend,
        after which device-related settings such as `XLA_FLAGS` no longer
        apply. A falsy value (None or 0) means "use all available devices".
        """
        return self._n_devices or jax.device_count()

    @n_devices.setter
    def n_devices(self, value: int | None) -> None:
        self._n_devices = value

    def __call__(
        self, 
        fn: Callable, # Function to generate cf for a single input
        X: Array, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations
        assert X.ndim == 2
        # Shape: (n_devices, rows_per_device, *X.shape[1:]).
        X_padded = _pad_divisible_X(X, self.n_devices)
        partial_fn = partial(fn, pred_fn=pred_fn, **kwargs)
        # pmap maps over the device axis; vmap maps over the rows of each shard.
        cfs = jax.pmap(jax.vmap(partial_fn))(X_padded)
        # Merge the device and row axes, then drop the zero-padding rows.
        cfs = cfs.reshape(-1, *cfs.shape[2:])
        return cfs[:X.shape[0]]


In [None]:
#| exporti
def _batched_generation(
    gs_fn: Callable, # Generation strategy function
    cf_fn: Callable, # Function to generate cf for a single input
    X: Array, # Input instances to be explained
    pred_fn: Callable[[Array], Array],
    batch_size: int,
    **kwargs
) -> Array: # Generated counterfactual explanations
    """Batched generation of counterfactuals.

    Splits `X` into consecutive batches of `batch_size` rows (clamped to
    `len(X)`). Each batch is passed to the strategy `gs_fn` sequentially via
    `lax.map`, and the results are concatenated. The final batch is
    zero-padded, and the padded rows are dropped from the output.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    assert X.ndim == 2, f"X must be a 2D array, got {X.ndim}D array"
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    n_rows, *feature_shape = X.shape
    if n_rows == 0:
        # Nothing to generate; also avoids `% 0` when batch_size is clamped below.
        return jnp.asarray(X)
    batch_size = min(batch_size, n_rows)
    # Pad only up to the next multiple of batch_size (0 when already divisible).
    # The previous `batch_size - n_rows % batch_size` appended a whole extra
    # batch of zeros in the divisible case.
    pad_size = -n_rows % batch_size
    X = jnp.pad(X, ((0, pad_size), (0, 0)))
    X = X.reshape(-1, batch_size, *feature_shape)
    # Run the strategy on one batch at a time.
    gs_fn_partial = lambda x: gs_fn(cf_fn, x, pred_fn=pred_fn, **kwargs)
    cfs = lax.map(gs_fn_partial, X)
    return cfs.reshape(-1, *feature_shape)[:n_rows]
     


In [None]:
#| export
class BatchedVmapStrategy(BaseStrategy):
    """Generate counterfactuals with `jax.vmap`, processing `X` in fixed-size batches.

    Batches run one after another, which bounds how many rows are vectorised
    at once.
    """

    def __init__(self, batch_size: int):
        # Rows handed to VmapStrategy per batch.
        self.batch_size = batch_size

    def __call__(
        self,
        fn: Callable,  # Function to generate cf for a single input
        X: Array,  # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array:  # Generated counterfactual explanations
        return _batched_generation(
            VmapStrategy(), fn, X, pred_fn, self.batch_size, **kwargs
        )


In [None]:
#| export
class BatchedPmapStrategy(BaseStrategy):
    """Auto-batching for generating counterfactuals via `jax.pmap`.

    `X` is split into batches of `batch_size` rows that are processed one after
    another. `PmapStrategy` shards each batch across devices.
    """
    def __init__(self, batch_size: int, n_devices: int = None):
        self.batch_size = batch_size  # rows per batch
        # Forwarded to PmapStrategy; None means all available devices.
        self.n_devices = n_devices

    def __call__(
        self, 
        fn: Callable, # Function to generate cf for a single input
        X: Array, # Input instances to be explained
        pred_fn: Callable[[Array], Array],
        **kwargs
    ) -> Array: # Generated counterfactual explanations
        pmap_g = PmapStrategy(self.n_devices)
        cfs = _batched_generation(
            pmap_g, fn, X, pred_fn, self.batch_size, **kwargs
        )
        return cfs


In [None]:
# NOTE(review): XLA reads XLA_FLAGS when the JAX backend is first initialised.
# An earlier cell already ran JAX code (see the "No GPU/TPU found" output), so
# this line likely has no effect in this session. Set it before the first JAX
# call (e.g. before the imports) to actually simulate 8 host devices.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# Random weights for a toy linear model, plus 1000 inputs of 100 features.
w = jrand.normal(jrand.PRNGKey(0), (100, 100))
X = jrand.normal(jrand.PRNGKey(0), (1000, 100))

# Toy prediction function. `w` is square, so the output has the same width as
# `x`, which IterativeStrategy's shape check requires.
@jit
def pred_fn(x): return jnp.dot(x, w.T)

# Toy per-instance "counterfactual" generator that just returns the prediction,
# so every strategy below should produce the same result.
def f(x, pred_fn=None, **kwargs):
    return pred_fn(x)

# One instance of each strategy; the batched variants use 128 rows per batch.
iter_gen = IterativeStrategy()
vmap_gen = VmapStrategy()
pmap_gen = PmapStrategy()
bvmap_gen = BatchedVmapStrategy(128)
bpmap_gen = BatchedPmapStrategy(128)


In [None]:
# Baseline: a Python loop calling `f` once per row. `block_until_ready` waits
# for JAX's asynchronous dispatch so the result is fully computed here.
cf_iter = iter_gen(f, X, pred_fn=pred_fn).block_until_ready()

In [None]:
# All rows vectorised in a single `jax.vmap` call.
cf_vmap = vmap_gen(f, X, pred_fn=pred_fn).block_until_ready()


In [None]:
# Rows sharded across devices with `jax.pmap` (vmap inside each shard).
cf_pmap = pmap_gen(f, X, pred_fn=pred_fn).block_until_ready()


In [None]:
# vmap over batches of 128 rows; the batches run sequentially via `lax.map`.
cf_bvmap = bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()


In [None]:
#| hide
# check when batch_size > X.shape[0]
# (1280 > 1000 rows, so _batched_generation clamps batch_size to X.shape[0]
# and runs a single batch; the result must match the 128-row batching.)
_bvmap_gen = BatchedVmapStrategy(1280)
_cf_bvmap = _bvmap_gen(f, X, pred_fn=pred_fn).block_until_ready()
assert jnp.allclose(cf_bvmap, _cf_bvmap, atol=1e-4)


In [None]:
#| hide
# pmap over batches of 128 rows, using all available devices for each batch.
cf_bpmap = bpmap_gen(f, X, pred_fn=pred_fn).block_until_ready()

In [None]:
#| hide
# Every parallel strategy must reproduce the plain iterative result.
for _cf_other in (cf_vmap, cf_bvmap, cf_pmap, cf_bpmap):
    assert jnp.allclose(cf_iter, _cf_other, atol=1e-4)

In [None]:
#| export
class StrategyFactory:
    """Resolve parallelism strategies by name.

    This is a namespace of static/class methods and is not meant to be
    instantiated.
    """

    # Registered strategies keyed by name. The same instances are returned
    # to every caller of `get_strategy`.
    __strategy_map = {
        'iter': IterativeStrategy(),
        'vmap': VmapStrategy(),
        'pmap': PmapStrategy(),
    }

    # Public method names of this class.
    __ALL__ = ["get_default_strategy", "get_strategy"]

    def __init__(self) -> None:
        raise ValueError("This class should not be instantiated.")

    @staticmethod
    def get_default_strategy() -> BaseStrategy:
        """Return a new `VmapStrategy`, the default strategy."""
        return VmapStrategy()

    @classmethod
    def get_strategy(cls, strategy: str | BaseStrategy) -> BaseStrategy:
        """Return `strategy` itself if it is already a strategy, else look it up by name.

        Raises:
            ValueError: if `strategy` is neither a `BaseStrategy` nor a registered name.
        """
        if isinstance(strategy, BaseStrategy):
            return strategy
        # Only string keys are registered; anything else cannot match.
        registered = cls.__strategy_map.get(strategy) if isinstance(strategy, str) else None
        if registered is None:
            raise ValueError(f"Invalid strategy: {strategy}")
        return registered

In [None]:
# Public method names the factory advertises (displayed as the cell output).
StrategyFactory.__ALL__

['get_default_strategy', 'get_strategy']

In [None]:
# Resolve each registered name, the default, and a pass-through instance.
it, vm, pm = (StrategyFactory.get_strategy(name) for name in ('iter', 'vmap', 'pmap'))
default = StrategyFactory.get_default_strategy()
cus = StrategyFactory.get_strategy(VmapStrategy())

for _obj, _expected_cls in (
    (it, IterativeStrategy),
    (vm, VmapStrategy),
    (pm, PmapStrategy),
    (default, VmapStrategy),
    (cus, VmapStrategy),
):
    assert isinstance(_obj, _expected_cls)