# SPFlow 1.0.0 Design Concept

Design basis:
* Easy to use and extend
* Functional design (using dispatch)
* Modular building blocks that can be stacked and nested
    * Allows for quick design of large models
    * Allows combining or extending existing models
* Support for multiple back-ends with one-to-one mappings between back-ends
    * Allows for easy model sharing and conversion to one's favorite back-end
    * Base back-end should use explicit node modules as the lowest-level building blocks
        * E.g., allows alternative node-wise evaluations (e.g., for computing p-values)
    * Other back-ends may use optimized modules (i.e., implicit nodes)
        * These still need to be mappable to the base back-end

![overview](uml/all.svg)

##### Imports

In [1]:
from typing import Any, List, Tuple, Set, Dict, Union, Optional
from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass
import itertools
from multipledispatch import dispatch

import numpy as np
from scipy.stats import norm

## NetworkType Class

* All network types (e.g. `SPN`, `BN`, ...) inherit from it
* Network types do not need to implement anything
* They are only needed to dispatch on network-type objects (e.g., for computing scopes, likelihoods, sampling, etc.)

![network_type](uml/network_type.svg)

In [2]:
class NetworkType(ABC):
    """Abstract base class for network types."""
    pass

#### SPN Network Type

In [3]:
class SPN(NetworkType):
    """Sum-Product Network (SPN) network type."""
    pass


## Scope Class

Attributes:
* `query`: set of indices of the query variables
* `evidence`: set of indices of the evidence (conditional) variables (empty by default)
* `network_type`: network type this scope is in reference to

The class overrides the `__add__` operator to merge scopes based on their network types.

**TODO**:
* Network types and modules may have to check whether or not this merging is valid (e.g. sum- vs product-nodes)
* Conditional/evidence variables

![scope](uml/scope.svg)

In [4]:
class Scope:
    """Class representing Variable scopes."""
    def __init__(self, ntype: NetworkType, query: Union[Set[int], List[int]], evidence: Optional[Union[Set[int], List[int]]]=None) -> None:

        if evidence is None:
            evidence = set()

        query = set(query)
        evidence = set(evidence)

        if not query.isdisjoint(evidence):
            raise ValueError("Specified sets of query and evidence variables are not disjoint.")

        self.query = query
        self.evidence = evidence
        self.network_type = ntype

    def __repr__(self) -> str:
        return "Scope({}|{})".format(self.query if self.query else "{}", self.evidence if self.evidence else "{}")

    def __add__(self, other: "Scope") -> "Scope":
        "Dispatches scope merging based on network types."
        return merge_scope(self, other)
    
    def __len__(self) -> int:
        return len(self.query)

Merging scopes is automatically dispatched to a merge function based on the network types of the respective scopes.

In [5]:
@dispatch(Scope, Scope)
def merge_scope(scope1: Scope, scope2: Scope) -> Scope:
    """Generic intermediate function to dispatch merging based on network types."""
    return merge_scope(scope1.network_type, scope1, scope2.network_type, scope2)

To register a new merge operation (e.g., between two `SPN` scopes), dispatch a corresponding `merge_scope` function on the network types and scopes:

In [6]:
@dispatch(SPN, Scope, SPN, Scope)
def merge_scope(ntype1: SPN, scope1: Scope, ntype2: SPN, scope2: Scope) -> Scope:
    """Merges two SPN scopes."""
    return Scope(SPN(), set.union(scope1.query, scope2.query), set.union(scope1.evidence, scope2.evidence))
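`multipledispatch` resolves `merge_scope` on the runtime types of its arguments. As a rough illustration of the same idea without that dependency, the sketch below keeps a hypothetical registry keyed on pairs of network-type classes. `MERGE_REGISTRY`, `register_merge`, the `BN` type and the simplified `Scope` (query variables only) are assumptions for illustration; unlike `multipledispatch`, this lookup matches exact types only, not subclasses.

```python
from typing import Callable, Dict, Set, Tuple, Type


class NetworkType:
    """Minimal stand-in for the NetworkType base class."""


class SPN(NetworkType):
    pass


class BN(NetworkType):  # hypothetical second network type
    pass


class Scope:
    """Simplified scope: query variables and a network type only."""
    def __init__(self, ntype: NetworkType, query: Set[int]) -> None:
        self.network_type = ntype
        self.query = set(query)

    def __add__(self, other: "Scope") -> "Scope":
        return merge_scope(self, other)


# hypothetical registry: (type of ntype1, type of ntype2) -> merge function
MERGE_REGISTRY: Dict[Tuple[Type, Type], Callable[[Scope, Scope], Scope]] = {}


def register_merge(t1: Type, t2: Type):
    """Decorator registering a merge function for a pair of network types."""
    def decorator(f):
        MERGE_REGISTRY[(t1, t2)] = f
        return f
    return decorator


def merge_scope(s1: Scope, s2: Scope) -> Scope:
    """Looks up the merge function for the exact network types of both scopes."""
    key = (type(s1.network_type), type(s2.network_type))
    if key not in MERGE_REGISTRY:
        raise NotImplementedError(f"No scope merge registered for {key[0].__name__} and {key[1].__name__}.")
    return MERGE_REGISTRY[key](s1, s2)


@register_merge(SPN, SPN)
def _merge_spn_scopes(s1: Scope, s2: Scope) -> Scope:
    # SPN scopes are merged by taking the union of their query variables
    return Scope(SPN(), s1.query | s2.query)


merged = Scope(SPN(), {0, 1}) + Scope(SPN(), {1, 2, 3})
print(sorted(merged.query))  # [0, 1, 2, 3]
```

With this registry, adding an `SPN` scope to a `BN` scope raises `NotImplementedError` until a merge for that pair of types is registered.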

##### Examples

In [7]:
s1 = Scope(SPN(), [0,1])
print("Scope1:", s1)

Scope1: Scope({0, 1}|{})


In [8]:
s2 = Scope(SPN(), [1,2,3])
print("Scope 2:", s2)

Scope 2: Scope({1, 2, 3}|{})


In [9]:
s = s1+s2
print("Merged scope:", s)

Merged scope: Scope({0, 1, 2, 3}|{})


### Scope Array

Array class inheriting from `np.ndarray`. Alternatively, one can just pass `dtype=Scope` as an argument to `np.array` (NumPy stores the elements with `dtype=object`).

It can store scopes of different lengths in an array format that can easily be propagated through a model (e.g., similar to likelihood values).

**NOTE**: creating arrays/tensors with `Scope`-elements might not be possible in all backends. In any case (since scope outputs for modules are not multi-dimensional), a list with `Scope`s could be returned.
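A minimal sketch of the list-based fallback mentioned in the note, assuming only that scope objects support `+`. Here a simplified stand-in `Scope` (query variables only) is used, and the hypothetical helper `merge_all` merges a list pairwise with `functools.reduce`, mirroring what `ScopeArray.sum()` does on an object array.

```python
import operator
from functools import reduce
from typing import List, Set


class Scope:
    """Simplified stand-in for the Scope class (query variables only)."""
    def __init__(self, query: Set[int]) -> None:
        self.query = set(query)

    def __add__(self, other: "Scope") -> "Scope":
        # merge by taking the union of query variables
        return Scope(self.query | other.query)

    def __repr__(self) -> str:
        return f"Scope({sorted(self.query)})"


def merge_all(scopes: List[Scope]) -> Scope:
    """Merges a list of scopes pairwise from left to right."""
    if not scopes:
        raise ValueError("Cannot merge an empty list of scopes.")
    return reduce(operator.add, scopes)


scopes = [Scope({0, 1}), Scope({1, 2, 3}), Scope({5})]
print(merge_all(scopes))  # Scope([0, 1, 2, 3, 5])
```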

In [10]:
class ScopeArray(np.ndarray):
    """Numpy array with Scope object elements."""
    def __new__(cls, data) -> np.ndarray:
        return np.array(data, dtype=Scope)

##### Examples

In [11]:
ScopeArray([s1, s2])

array([Scope({0, 1}|{}), Scope({1, 2, 3}|{})], dtype=object)

It also allows using appropriate NumPy operations (e.g., `sum` merges all scopes via `Scope.__add__`):

In [12]:
ScopeArray([s1, s2]).sum()

Scope({0, 1, 2, 3}|{})

## Module Class

Every module inherits from this class.

Each module must also implement the `n_out` property, which returns the number of (implicit or explicit) output nodes of the module.

![module](uml/module.svg)

In [13]:
class Module(ABC):
    """Abstract base class for modules.
    
    Args:
        children: list of child modules (may be empty for terminal modules).
    """
    def __init__(self, children: Optional[List["Module"]]=None) -> None:
        
        if children is None:
            children = []
        
        # set child modules
        self.children = children
        
        # infer number of inputs from children (and their numbers of outputs)
        child_num_outputs = [child.n_out for child in self.children]
        child_cum_outputs = np.cumsum(child_num_outputs)
        
        self.n_in = sum(child_num_outputs, 0)

        # precompute the mapping from input ids to the corresponding child id and child output id (saves computation at run-time)
        self.input_to_output_id_dict = {}
        
        for input_id in range(self.n_in):
            # get child module for corresponding input
            child_id = np.sum(child_cum_outputs <= input_id, axis=0).tolist()
            # get output id of child module for corresponding input
            output_id = input_id-(child_cum_outputs[child_id]-child_num_outputs[child_id])
            
            self.input_to_output_id_dict[input_id] = (child_id, output_id)

    @property
    @abstractmethod
    def n_out(self) -> int:
        """Specifies the number of outputs, i.e. (implicit or explicit) output nodes."""
        pass

    def input_to_output_id(self, input_id) -> Tuple[int, int]:
        """Helper method to convert an input id to a corresponding child and child output id."""
        return self.input_to_output_id_dict[input_id]
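`Module.__init__` precomputes, for every flat input index, the child it comes from and which of that child's outputs it corresponds to. The same mapping can be sketched with `bisect` over the cumulative output counts; `input_to_output_ids` is a hypothetical helper that takes the children's `n_out` values directly instead of module objects.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Dict, List, Tuple


def input_to_output_ids(child_num_outputs: List[int]) -> Dict[int, Tuple[int, int]]:
    """Maps each flat input id to (child id, output id of that child)."""
    # cumulative number of outputs up to and including each child, e.g. [2, 3, 6]
    cum_outputs = list(accumulate(child_num_outputs))
    mapping = {}
    for input_id in range(sum(child_num_outputs)):
        # first child whose cumulative output count exceeds input_id
        child_id = bisect_right(cum_outputs, input_id)
        # offset of input_id relative to the first input belonging to this child
        first_input_of_child = cum_outputs[child_id] - child_num_outputs[child_id]
        mapping[input_id] = (child_id, input_id - first_input_of_child)
    return mapping


# three children with 2, 1 and 3 outputs, respectively
print(input_to_output_ids([2, 1, 3]))
# {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (2, 0), 4: (2, 1), 5: (2, 2)}
```

`bisect_right(cum_outputs, input_id)` counts the cumulative sums that are `<= input_id`, which is the same quantity as `np.sum(child_cum_outputs <= input_id)` in the class above.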

# Dispatching

Errors/exceptions during dispatching are hard to trace back. Here, we follow the convention of naming the dispatched-on `Module` argument after the corresponding class (e.g., `leaf_node: LeafNode`). This signature shows up in the error traceback and helps with debugging.

**TODO**: one might provide a better trace-back via decorators?
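One possible direction for this TODO, sketched under the assumption that re-raising with added context is acceptable: a hypothetical `with_dispatch_context` decorator catches an exception raised inside a dispatched function and re-raises it as a hypothetical `DispatchError` that names the function and the argument types, chaining the original exception via `raise ... from`. `DummyLeaf` is a placeholder class for the example only.

```python
from functools import wraps


class DispatchError(RuntimeError):
    """Hypothetical exception type carrying the dispatched call signature."""


def with_dispatch_context(f):
    """Re-raises exceptions from f with the function name and argument types attached."""
    @wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except DispatchError:
            # already annotated further down the call stack; keep the innermost context
            raise
        except Exception as err:
            arg_types = ", ".join(type(arg).__name__ for arg in args)
            raise DispatchError(f"{f.__name__}({arg_types}) failed: {err!r}") from err
    return wrapper


class DummyLeaf:  # placeholder class for the example
    pass


@with_dispatch_context
def likelihood(leaf_node: DummyLeaf, data: list) -> float:
    return 1.0 / len(data)  # fails for empty data


message, caught = "", None
try:
    likelihood(DummyLeaf(), [])
except DispatchError as err:
    message, caught = str(err), err
print(message)  # likelihood(DummyLeaf, list) failed: ZeroDivisionError('division by zero')
```

The original exception stays available as `__cause__`, so the full traceback is still printed when the error is not handled.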

## Sampling

Sampling should return an array in a form that can immediately be used to infer likelihoods of that same module. This entails that the RVs are in ascending order and that no RVs are skipped: e.g., sampling over a scope of `[0,2,5]` should return an array of size `(n,6)`, with all scope RVs at their respective column indices and all other entries left as NaNs (`n` is the number of samples).

API:
- `sample(module)` should return a single sample (size `(1,m)`, where `m` is the largest RV index in the module's scope plus one)
- `sample(module, n)` should return `n` samples (size `(n,m)`)
- `sample(module, array)` should fill the specified array in-place and return it as well (the array must be of appropriate size). This also allows specifying incomplete data: existing (non-NaN) entries are not replaced during sampling, and their likelihoods are taken into account while sampling.
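A minimal sketch of the data layout described above, assuming a scope given as a plain list of RV indices. The hypothetical `sample_standard_normal` stands in for `sample(module, n)` and `sample(module, array)`: it draws independent standard-normal values for every scope RV, keeps non-scope columns as NaN, and leaves already observed entries untouched. Unlike a real module, it ignores any dependence on the observed entries.

```python
from typing import List, Optional

import numpy as np


def sample_standard_normal(scope: List[int], n: int = 1, data: Optional[np.ndarray] = None,
                           rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Fills NaN entries of the scope columns with standard-normal samples (in place)."""
    if rng is None:
        rng = np.random.default_rng()
    if data is None:
        # one column per RV index up to the largest one; non-scope columns stay NaN
        data = np.full((n, max(scope) + 1), np.nan)
    for rv in scope:
        missing = np.isnan(data[:, rv])
        # only NaN entries are sampled; observed values are kept
        data[missing, rv] = rng.standard_normal(missing.sum())
    return data


samples = sample_standard_normal([0, 2, 5], n=4)
print(samples.shape)                             # (4, 6)
print(np.isnan(samples).all(axis=0).tolist())    # [False, True, False, True, True, False]

partial = np.full((2, 6), np.nan)
partial[:, 0] = 1.5                              # observed values for RV 0
filled = sample_standard_normal([0, 2, 5], data=partial)
print(filled is partial, filled[:, 0])           # True [1.5 1.5]
```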

When sampling multiple instances, we need to keep track of which modules are supposed to sample which instances.

Example:
![sampling_context_example](img/sampling_context_example.drawio.svg)
In this case, the sum node samples a branch (B or C) for each instance, so B and C sample into the same data array but get different instances to sample.

In the multi-output case (i.e., modules with multiple outputs, see below), one additionally needs to specify which outputs are supposed to be sampled.

Example:
![sampling_context_example](img/sampling_context_example_2.drawio.svg)
Here, we essentially have the same graph as above, but `sample(...)` is called on the same (multi-output) child module for all instances, so each instance must also specify which output to sample from.
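The bookkeeping in these examples can be sketched in isolation: given the branch chosen for each instance, the instance ids are partitioned into one group per branch, and each group would be passed on as the instance ids of a child's sampling context. In this minimal sketch the branch choices are hard-coded (hypothetical) rather than sampled from weighted likelihoods, and `group_by_branch` is a hypothetical helper. Note that the positions returned by `np.where` refer to entries of `instance_ids` and must be mapped back to the actual instance ids.

```python
from typing import Dict, List

import numpy as np


def group_by_branch(instance_ids: List[int], choices: np.ndarray) -> Dict[int, List[int]]:
    """Groups instance ids by the branch chosen for them.

    choices[i] is the branch sampled for instance_ids[i].
    """
    groups = {}
    for branch_id in np.unique(choices):
        # positions (not instance ids) of all instances that chose this branch
        positions = np.where(choices == branch_id)[0]
        groups[int(branch_id)] = [instance_ids[p] for p in positions]
    return groups


# a sum node receives instances 1, 4, 5, 7 and picks branch 0 (B) or 1 (C) for each
print(group_by_branch([1, 4, 5, 7], np.array([1, 0, 1, 0])))  # {0: [4, 7], 1: [1, 5]}
```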

The `SamplingContext` class tracks which instances to sample and which module outputs to sample from. For a given module, it keeps a list of instance ids to fill with samples and, for each of these instances, a list of output ids to sample from.

![sampling_context](uml/sampling_context.svg)

In [14]:
class SamplingContext:
    """Keeps track of instance ids to sample and which output ids to sample from (relevant for modules with multiple outputs).
    
    Args:
        instance_ids: list of ints representing the instances to sample.
        output_ids: list of lists of ints representing the output ids for the corresponding instances to sample from (relevant for multi-output module).
                    As a shorthand, '[]' implies to sample from all outputs for a given instance id.
    """
    def __init__(self, instance_ids: List[int], output_ids: Optional[List[List[int]]]=None) -> None:

        if output_ids is None:
            # assume sampling over all outputs (i.e. [])
            output_ids = [[] for _ in instance_ids]
        
        if (len(output_ids) != len(instance_ids)):
            raise ValueError(f"Number of specified instance ids {len(instance_ids)} does not match specified number of output ids {len(output_ids)}.")

        self.instance_ids = instance_ids
        self.output_ids = output_ids

    def select(self, ids: List[int]) -> "SamplingContext":
        """Selects a subset of entries (by position, not by instance id) and returns them in a new sampling context object."""
        selection = [pair for i, pair in enumerate(zip(self.instance_ids, self.output_ids)) if i in ids]
        # build the lists explicitly so that an empty selection yields an empty context
        return SamplingContext([instance_id for instance_id, _ in selection], [out_ids for _, out_ids in selection])

    def append(self, instance_ids: List[int], output_ids: List[List[int]]) -> None:
        """Adds new instance ids and corresponding output ids to the sampling context."""
        if(len(instance_ids) != len(output_ids)):
            raise ValueError(f"Number of specified instance ids {len(instance_ids)} does not match specified number of output ids {len(output_ids)}.")
        
        self.instance_ids += instance_ids
        self.output_ids += output_ids

    def remove(self, ids: List[int]) -> None:
        """Removes a subset of entries (by position, not by instance id) from the sampling context object."""
        selection = [pair for i, pair in enumerate(zip(self.instance_ids, self.output_ids)) if i not in ids]
        # build the lists explicitly so that removing all entries leaves empty lists
        self.instance_ids = [instance_id for instance_id, _ in selection]
        self.output_ids = [out_ids for _, out_ids in selection]
    
    def is_valid(self, data: Optional[np.ndarray]=None, module: Optional[Module]=None) -> bool:
        """Returns a boolean whether or not the sampling context object is valid. Additionally, a data array and a module may be specified)."""
        
        # check if there are any duplicate instance ids
        if len(set(self.instance_ids)) != len(self.instance_ids):
            return False
        
        # validate individual instance ids
        for instance_id in self.instance_ids:
            # check if all instance ids are greater or equal to 0
            if instance_id < 0:
                return False
            # if a data array is specified, then also check if the instance ids are valid for it
            if (data is not None) and (instance_id >= data.shape[0]):
                return False
        
        # check if number of output id lists matches number of instance ids
        if len(self.instance_ids) != len(self.output_ids):
            return False
        
        # check if any output id lists contain duplicates
        for out_ids in self.output_ids:
            if len(set(out_ids)) != len(out_ids):
                return False
        
        # validate individual output ids
        for out_ids in self.output_ids:
            for out_id in out_ids:
                # check if all output ids are greater or equal to 0
                if out_id < 0:
                    return False
                # if a module is specified, then also check if the output ids are valid indices for it
                if (module is not None) and (out_id >= module.n_out):
                    return False
        
        return True

##### Examples

In [15]:
sc = SamplingContext([0,2,4,6],[[1],[],[],[4]])
print(sc.instance_ids, sc.output_ids)

[0, 2, 4, 6] [[1], [], [], [4]]


In [16]:
# select certain instance ids (e.g. to pass to a child module for sampling)
sc2 = sc.select([1,3])
print(sc2.instance_ids, sc2.output_ids)

[2, 6] [[], [4]]


In [17]:
# remove certain instance ids
sc.remove([1,3])
print(sc.instance_ids, sc.output_ids)

[0, 4] [[1], []]


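Sampling for modules called without a data tensor (see the API notes above) is dispatched generically for all modules; the corresponding functions are defined after the memoization decorator below.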

## Dispatch Cache

Dispatched function calls (e.g., `likelihood`) may need to cache results to avoid redundant computations.

![dispatch_cache](uml/dispatch_cache.svg)

In [18]:
from typing import Dict
import numpy as np

class DispatchCache(dict):
    def __init__(self):
        self.likelihood = {}
        self.sample = {}
        self.scope = {}
        self._valid_keys = ['likelihood', 'sample', 'scope']
    
    def __contains__(self, key: Module) -> bool:
        return any(key in getattr(self, k) for k in self._valid_keys)
    
    def __getitem__(self, key: Module) -> Dict[str, np.ndarray]:
        if not isinstance(key, Module):
            raise KeyError(f"Cache key is of type {type(key)}, but is expected to be of type {Module}.")

        cache_values = {k: getattr(self, k)[key] for k in self._valid_keys if key in getattr(self, k)}

        if(cache_values):
            return cache_values
        else:
            raise KeyError(f"{key} not found in cache.")
    
    def __setitem__(self, key: Module, values: Dict[str, np.ndarray]) -> None:
        for k in values.keys():
            if k not in self._valid_keys:
                raise KeyError(f"Value dictionary for setting dispatch cache contains invalid key '{k}'.")

        for k in self._valid_keys:
            if k in values:
                getattr(self, k)[key] = values[k]

## Dispatch Context

Besides a cache, modules might require (or allow) additional arguments for dispatched function calls. To avoid messy function signatures and an overly complicated argument management by the user, it might be wiser to collect everything in a single class (called `DispatchContext` here).

![dispatch_context](uml/dispatch_context.svg)

In [19]:
from typing import Any

class DispatchContext(dict):
    def __init__(self):
        self.cache = DispatchCache()
        self.args = {}
        self._valid_keys = ['cache', 'args']

    def __contains__(self, key: Module) -> bool:
        return (key in self.cache) or (key in self.args)

    def __getitem__(self, key: Module) -> Dict[str, Any]:
        
        values = {}
        
        if key in self.cache:
            values['cache'] = self.cache[key]
        if key in self.args:
            values['args'] = self.args[key]
        
        if values:
            return values
        else:
            raise KeyError(f"{key} not found in dispatch context.")

    def __setitem__(self, key: Module, values: Dict[str, Any]) -> None:
        
        for k in values.keys():
            if k not in self._valid_keys:
                raise KeyError(f"Value dictionary for setting dispatch context contains invalid key '{k}'.")
        
        for k in values.keys():
            getattr(self, k)[key] = values[k]

### Memoization Decorator

The decorator can, e.g., wrap the `likelihood` function to automatically check the cache and either reuse or fill it.

**NOTE**: the `memoize` decorator must be placed below the `dispatch` decorator (i.e., it is applied first), so that the memoized version of the function is the one registered for dispatch!

![memoize](uml/memoize.svg)

In [20]:
from functools import wraps

def memoize(f):
    """Wraps a function to automatically check against a cache ('cache' keyword argument) using the first argument as the key.
    If present, the cached value is returned, otherwise it is computed using the wrapped function stored in the cache.
    """
    @wraps(f)
    def memoized_f(*args, **kwargs):
        
        # ----- cache -----
        
        # first look in keyword arguments
        if "dispatch_ctx" in kwargs and isinstance(kwargs["dispatch_ctx"], DispatchContext):
            cache = kwargs['dispatch_ctx'].cache
        else:
            # otherwise check positional arguments
            candidates = [arg for arg in args if isinstance(arg, DispatchContext)]

            if(len(candidates) > 1):
                # if there are multiple candidates raise an error
                raise ValueError("TODO: multiple candidates found.")
            elif candidates:
                # otherwise if there is only one candidate, then use it
                cache = candidates.pop().cache
            else:
                # otherwise create a new dispatch context
                kwargs['dispatch_ctx'] = DispatchContext()
                cache = kwargs['dispatch_ctx'].cache

        # ----- module (cache key) -----
        
        # args contains key variable to be used for cache (assumes key variable is first positional argument to f)
        if len(args) > 0:
            key = args[0]
        # otherwise there is no key to cache against
        else:
            raise ValueError("No argument to cache against")
        
        # get function-specific cache
        f_cache = getattr(cache, f.__name__)

        if key not in f_cache:
            # compute result and update cache
            f_cache[key] = f(*args, **kwargs)

        return f_cache[key]
    return memoized_f
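Why the decorator order matters can be seen with the standard library alone. In this sketch, a hypothetical `register` decorator stands in for `dispatch`: it stores whatever function it receives in a registry. `memoize_by_first_arg` is a simplified, hypothetical memoizer keyed on the first argument. Placed above the memoizer, `register` stores the memoized wrapper; placed below it, `register` stores the bare function, and calls through the registry bypass the cache.

```python
from functools import wraps

REGISTRY = {}  # hypothetical stand-in for the dispatch table
calls = []     # records every time an undecorated function body actually runs


def register(name):
    """Stores the (possibly already wrapped) function under name."""
    def decorator(f):
        REGISTRY[name] = f
        return f
    return decorator


def memoize_by_first_arg(f):
    """Simplified memoizer: caches results keyed on the first positional argument."""
    cache = {}

    @wraps(f)
    def memoized(key, *args):
        if key not in cache:
            cache[key] = f(key, *args)
        return cache[key]
    return memoized


@register("cached")      # applied second: registers the memoized wrapper
@memoize_by_first_arg    # applied first
def square(x):
    calls.append(f"square({x})")
    return x * x


@memoize_by_first_arg    # applied second: the registry never sees this wrapper
@register("uncached")    # applied first: registers the bare function
def cube(x):
    calls.append(f"cube({x})")
    return x ** 3


for _ in range(3):
    REGISTRY["cached"](2)
    REGISTRY["uncached"](2)
print(calls)  # ['square(2)', 'cube(2)', 'cube(2)', 'cube(2)']
```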

In [21]:
@dispatch(Module, dispatch_ctx=DispatchContext)
@memoize
def sample(module: Module, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Dispatches sampling a single instance from the module."""
    return sample(module, 1, dispatch_ctx=dispatch_ctx)

@dispatch(Module, int, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(module: Module, n_samples: int, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Creates an array for n samples and dispatches sampling to fill the array."""
    if(not sampling_ctx):
        # create sampling context (assume all instances are sampled from)
        sampling_ctx = SamplingContext(list(range(n_samples)), [[] for _ in range(n_samples)])

    # get module scope and largest random variable id
    module_scope = scope(module).squeeze().tolist()
    max_rv = max(module_scope.query) # TODO: include evidence?

    # create appropriate data tensor to fill
    data = np.full((n_samples,max_rv+1), float("nan"))

    # sample and fill data tensor
    return sample(module, data, dispatch_ctx=dispatch_ctx, sampling_ctx=sampling_ctx)

## (Log-)Likelihood

In general, we would like to offer both `likelihood` and `loglikelihood` routines. For clarity, only `likelihood` is implemented in this notebook. Note that one can easily be implemented from the other by applying `log(...)` or `exp(...)`, respectively.

API:
- `likelihood(module, data)`: should return an array of size `(n,1)`. `np.nan` values are marginalized over (not necessarily demonstrated here).
- `likelihood(module, data, dispatch_ctx=ctx)`: passes a `DispatchContext` whose cache may be empty, partially, or fully filled. For each module, the cache is checked: a stored value is reused; otherwise, the value is computed and the cache is updated afterwards. The cache can then be reused for later calls (with the same data) or, e.g., by the `sample` function.
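Taking `log(...)` of `likelihood` works, but products of many small densities can underflow to 0 before the logarithm is applied. A minimal sketch of the log-space alternative, assuming univariate Gaussian leaves as in `LeafNode` below: product nodes add child log-likelihoods, the sum node uses a numerically stable log-sum-exp over `log(weights) + child log-likelihoods`, and NaN inputs are marginalized by assigning a log-likelihood of 0 (i.e., a likelihood of 1). The function names, weights and leaf parameters are hypothetical, and the network is hard-coded as a single sum node over two product nodes.

```python
import math

import numpy as np


def gaussian_loglik(x: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Log-density of a univariate Gaussian; NaN entries are marginalized (log 1 = 0)."""
    ll = -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))
    return np.where(np.isnan(x), 0.0, ll)


def logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    """Stable log(sum(exp(a))) along axis, keeping the reduced dimension."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))


# hypothetical model: sum node (weights w) over two products of Gaussian leaves on RVs 0 and 1
w = np.array([0.3, 0.7])
params = [((0.0, 1.0), (1.0, 2.0)), ((2.0, 0.5), (-1.0, 1.0))]  # (mean, std) per leaf


def sum_of_products_loglik(data: np.ndarray) -> np.ndarray:
    """Log-likelihood of each instance, shape (n, 1)."""
    # product nodes: add leaf log-likelihoods, one column per branch -> shape (n, 2)
    branch_ll = np.stack([
        gaussian_loglik(data[:, 0], *p0) + gaussian_loglik(data[:, 1], *p1)
        for p0, p1 in params
    ], axis=1)
    # sum node: log(sum_i w_i * exp(branch_ll_i)), computed stably
    return logsumexp(np.log(w) + branch_ll, axis=1)


data = np.array([[0.5, 0.0], [np.nan, 1.0], [np.nan, np.nan]])
ll = sum_of_products_loglik(data)
print(np.exp(ll))  # the last row marginalizes all RVs, so its likelihood is 1
```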

## Node Modules

Basic `LeafNode`, `SumNode` and `ProductNode` classes.

Every class dispatches `likelihood(...)`, `scope(...)` and `sample(...)`.

![nodes](uml/nodes.svg)

### Abstract Node Module

Fixed output size (i.e., `n_out`) of 1

In [22]:
class Node(Module, ABC):
    """Represents basic nodes, i.e. the smallest building block for non-optimized networks."""
    @property
    def n_out(self) -> int:
        return 1

### Leaf Nodes

**Note**: `LeafNode` should actually be an abstract base class for leaf nodes to inherit from. In this example we simply implement `LeafNode` as a one-dimensional Gaussian distribution for demonstration purposes.

In [23]:
class LeafNode(Node):
    """Basic univariate leaf node. Here, implemented as a Gaussian.

    Args:
        scope: scope of the distribution.
        mean: mean of the distribution.
        std: standard deviation of the distribution.
    """
    def __init__(self, scope: Scope, mean=0.0, std=1.0) -> None:
        
        if(len(scope) != 1):
            raise ValueError("Scope to large for univariate leaf node.")
        
        super(LeafNode, self).__init__()
        
        self.scope = scope
        self.mean = mean
        self.std = std
        self.dist = norm(loc=mean, scale=std)

@dispatch(LeafNode, dispatch_ctx=DispatchContext)
@memoize
def scope(leaf_node: LeafNode, dispatch_ctx: DispatchContext) -> ScopeArray:
    """Simply returns the scope of the module, since it's a leaf node."""
    return ScopeArray([[leaf_node.scope]])

@dispatch(LeafNode, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(leaf_node: LeafNode, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Returns the likelihood for each data instance. Marginalizes (likelihood of 1) for NaN values."""
    likelihoods = np.ones((data.shape[0], len(leaf_node.scope)))
    inputs = data[:, list(leaf_node.scope.query)]

    marg_ids = np.isnan(inputs).sum(axis=1) == len(leaf_node.scope)
    likelihoods[~marg_ids, :] = leaf_node.dist.pdf(inputs[~marg_ids, :])

    return likelihoods

@dispatch(LeafNode, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(leaf_node: LeafNode, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples and fills the indices specified in the sampling context."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, leaf_node):
            # invalid sampling context
            raise ValueError(f"Specified sampling context is invalid for specified data array and module.")
    else:
        # create sampling context (assume all instances and all output nodes are to be used)
        sampling_ctx = SamplingContext(list(range(data.shape[0])), [[] for _ in range(data.shape[0])]) 

    scope_vars = list(leaf_node.scope.query)

    # sample ids (nan entries)
    sample_ids = (np.isnan(data[:, scope_vars]).sum(axis=1) == len(scope_vars))
    
    # get mask for sampling context
    instance_id_mask = np.zeros(data.shape[0]).astype(bool)
    instance_id_mask[sampling_ctx.instance_ids] = True

    # filter instance ids according to sampling context
    sample_ids &= instance_id_mask

    # sample
    data[sample_ids, list(leaf_node.scope.query)] = np.random.normal(leaf_node.mean, leaf_node.std, sample_ids.sum())

    return data

### Product Nodes

In [24]:
class ProductNode(Node):
    """Simple SPN product node. Child nodes (explicit or implicit) are assumed to have pairwise disjoint scopes.
    
    Args:
        children: list of child modules.
    """
    def __init__(self, children: List[Module]) -> None:
        
        if(len(children) == 0):
            raise ValueError(f"List of child modules for ProductNode is empty.")

        super(ProductNode, self).__init__(children)

@dispatch(ProductNode, dispatch_ctx=DispatchContext)
@memoize
def scope(product_node: ProductNode, dispatch_ctx: DispatchContext) -> ScopeArray:
    """Returns merged child scopes.""" 
    scopes = np.concatenate([scope(child, dispatch_ctx=dispatch_ctx) for child in product_node.children], axis=1)
    return scopes.sum(keepdims=True)

@dispatch(ProductNode, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(module: ProductNode, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Returns the product of the child likelihoods."""
    inputs = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in module.children], axis=1)
    return np.prod(inputs, axis=1, keepdims=True)

@dispatch(ProductNode, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(product_node: ProductNode, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples from all branches (i.e. nodes/module outputs).""" 
    
    if sampling_ctx:
        if not sampling_ctx.is_valid(data, product_node):
            # invalid sampling context
            raise ValueError(f"Specified sampling context is invalid for specified data array and module.")
    else:
        # create sampling context (assume all instances and all output nodes are to be used)
        sampling_ctx = SamplingContext(list(range(data.shape[0])), [[] for _ in range(data.shape[0])])

    # sample from all branches (i.e. inputs)
    for child in product_node.children:
        # create sampling context for child (same instance ids but different output ids now for new module)
        child_sampling_ctx = SamplingContext(sampling_ctx.instance_ids, [[] for _ in range(len(sampling_ctx.instance_ids))])
        # sample from child
        sample(child, data, dispatch_ctx=dispatch_ctx, sampling_ctx=child_sampling_ctx)
        

    return data

### Sum Nodes

In [25]:
class SumNode(Node):
    """Simple SPN sum node. All child nodes (explicit or implicit) are assumed to have the same scopes.
    
    Args:
        children: list of child modules.
    """ 
    def __init__(self, children: List[Module]) -> None:
        
        if(len(children) == 0):
            raise ValueError(f"List of child modules for SumNode is empty.")
        
        super(SumNode, self).__init__(children)
        
        self.weights = np.random.rand(self.n_in)
        self.weights /= self.weights.sum()

@dispatch(SumNode, dispatch_ctx=DispatchContext)
@memoize
def scope(sum_node: SumNode, dispatch_ctx: DispatchContext) -> ScopeArray:
    """Returns merged child scopes."""
    scopes = np.concatenate([scope(child, dispatch_ctx=dispatch_ctx) for child in sum_node.children])
    return scopes.sum(keepdims=True)

@dispatch(SumNode, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(sum_node: SumNode, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Returns the weighted sum of the child likelihoods."""
    inputs = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in sum_node.children], axis=1)
    return (sum_node.weights*inputs).sum(axis=1, keepdims=True)

@dispatch(SumNode, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(sum_node: SumNode, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples a branch for each instance (taking likelihoods into account)."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, sum_node):
            # invalid sampling context
            raise ValueError(f"Specified sampling context is invalid for specified data array and module.")
    else:
        # create sampling context (assume all instances and all output nodes are to be used)
        sampling_ctx = SamplingContext(list(range(data.shape[0])), [[] for _ in range(data.shape[0])])

    # get likelihoods for children (using cache; only use relevant instances to save computation)
    child_likelihoods = np.concatenate([likelihood(child, data[sampling_ctx.instance_ids], dispatch_ctx=dispatch_ctx) for child in sum_node.children], axis=1)
    
    # sample branch for each instance id
    choices = []
    sampling_probs = child_likelihoods * sum_node.weights
    
    for probs in sampling_probs:
        # normalize
        probs_norm = probs * (1 / np.sum(probs))
        choices.append(np.random.choice(list(range(probs.shape[0])), p=probs_norm))

    choices = np.array(choices)

    # for each unique sampled branch
    for branch_id in np.unique(choices):
        # group instances by sampled branch (positions within the sampling context)
        child_sample_positions = np.where(choices == branch_id)[0]
        # map positions back to the actual instance ids in the data array
        child_instance_ids = [sampling_ctx.instance_ids[i] for i in child_sample_positions]

        # get corresponding child and output id for sampled branch
        child_id, output_id = sum_node.input_to_output_id(branch_id)

        # sample from child
        sample(sum_node.children[child_id], data, dispatch_ctx=dispatch_ctx, sampling_ctx=SamplingContext(child_instance_ids, [[output_id] for _ in child_instance_ids]))

    return data

##### Examples

In [26]:
l1 = LeafNode(Scope(SPN(), [0])) # leaf node with (scope 0)
l2 = LeafNode(Scope(SPN(), [0])) # leaf node with (scope 0)
s = SumNode(children=[l1,l2]) # sum node over both leaf nodes (scope 0)

samples = sample(s)
print(f"Sample: {samples}, Sample likelihood: {likelihood(s, samples)}")

Sample: [[-1.02077459]], Sample likelihood: [[0.2369446]]


In [27]:
l1 = LeafNode(Scope(SPN(), [0])) # leaf node with (scope 0)
l2 = LeafNode(Scope(SPN(), [1])) # leaf node with (scope 1)
p1 = ProductNode([l1,l2]) # product node over both leaf nodes (scope 0,1)
p2 = ProductNode([l1,l2]) # product node over both leaf nodes (scope 0,1)
s = SumNode(children=[p1,p2]) # sum node over both product nodes (scope 0,1)

data = np.random.randn(3,2)

samples = sample(s, data)
print(f"Sample: {samples}, Sample likelihood: {likelihood(s, samples)}")
print(f"Scope: {scope(s)}")

Sample: [[ 0.44808323  1.34456184]
 [-1.40430639  0.36902538]
 [-0.10079911 -1.00715846]], Sample likelihood: [[0.05829788]
 [0.05546489]
 [0.09535568]]
Scope: [[Scope({0, 1}|{})]]


In [28]:
data = np.random.randn(1,2)

# call without passing cache
l = likelihood(s, data)
print(f"Likelihood (w/o cache):\t\t{l}")

# use cache this time
ctx = DispatchContext()
cache = ctx.cache
l = likelihood(s, data, dispatch_ctx=ctx)
print(f"Likelihood (store cache):\t{l}, matches stored value: {all(l == cache.likelihood[s])}")

# set cache to check if stored value is used
cache.likelihood[s] = np.ones((1,1))
l = likelihood(s, data, dispatch_ctx=ctx)
print(f"Likelihood (modified cache):\t{l}")

Likelihood (w/o cache):		[[0.02766286]]
Likelihood (store cache):	[[0.02766286]], matches stored value: True
Likelihood (modified cache):	[[1.]]


In [29]:
l = LeafNode(Scope(SPN(), [0]))

s = sample(l)
print("Without any specifications:", s.shape)

s = sample(l, 5)
print("Specifying number of samples:", s.shape)

s = np.full((3,1), np.nan)
s_ = sample(l, s)

# make sure that all nan values are now replaced (i.e. sampled) and matches the returned tensor (filled in-place)
print("Passing data tensor:", all(s == s_) and not any(np.isnan(s_)))

Without any specifications: (1, 1)
Specifying number of samples: (5, 1)
Passing data tensor: True


## Nested Modules

We'd like to build more complex modules from basic nodes. These could then again be combined to create even more intricate modules. For that we need to be able to nest modules.

A network without any open non-terminal nodes/modules can be nested straightforwardly.

However, non-terminal modules need child modules to be specified at creation. In nested modules, this would require internal non-terminal modules to reference the same child modules as the enclosing module, which would be extremely messy.

Instead, one could use placeholder modules that stand in for the actual children. The enclosing module can then set the cache for `scope` and `likelihood` calls to these modules, or redirect `sample` calls to the actual child modules. This also allows splitting up the inputs from child modules among the internal/nested modules and changing their order (in contrast to direct parent-child relationships).
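For likelihoods, the placeholder idea boils down to column selection on the enclosing module's input values: each placeholder is assigned a list of the owner's input ids, and its outputs are exactly those columns, in the given order. A minimal sketch with hard-coded input likelihoods and hypothetical placeholder names and input ids, mirroring what `set_placeholders` below writes into the cache:

```python
import numpy as np

# likelihoods of the enclosing module's 4 inputs (child outputs) for 2 instances
inputs = np.array([[0.1, 0.2, 0.3, 0.4],
                   [0.5, 0.6, 0.7, 0.8]])

# hypothetical placeholders: each selects (and may reorder) a subset of the owner's inputs
placeholder_input_ids = {"ph_a": [2, 0], "ph_b": [1, 3]}

# what the owner would store in the cache for each placeholder
placeholder_cache = {name: inputs[:, ids] for name, ids in placeholder_input_ids.items()}

print(placeholder_cache["ph_a"])
# [[0.3 0.1]
#  [0.7 0.5]]
```

Internal modules that use `ph_a` as a child then see input 2 of the owner as their first input and input 0 as their second.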

![nesting_module](uml/nesting_module.svg)

**TODO**: find other/modern designation for `owner`.

In [30]:
class NestingModule(Module, ABC):
    """Convenient module class for nesting non-terminal modules.
    
    Args:
        children: list of child modules.
    """
    def __init__(self, children: Optional[List[Module]]=None) -> None:
        
        if children is None:
            children = []
        
        super(NestingModule, self).__init__(children)
        self.placeholders = []

    def create_placeholder(self, input_ids: List[int]) -> "Placeholder":
        """Creates a placeholder module that can be used for internal non-terminal modules.
        
        Also registers the placeholder internally.
        """
        # create and register placeholder
        ph = self.Placeholder(self, input_ids)
        self.placeholders.append(ph)

        return ph
    
    def set_placeholders(self, cache, inputs) -> None:
        """Fills the cache for all registered placeholder modules given specified input values."""
        for ph in self.placeholders:
            # fill placeholder cache with specified input values
            cache[ph] = inputs[:,ph.input_ids]

    class Placeholder(Module):
        """Placeholder module as an intermediary module between nested non-terminal modules and actual child modules."""
        def __init__(self, owner: Module, input_ids: List[int]) -> None:
            # placeholders are terminal modules (no children of their own)
            super(NestingModule.Placeholder, self).__init__()

            self.owner = owner
            self.input_ids = input_ids
            
            # precompute, for each placeholder input, the corresponding child and output id of the owner (saves computation at run-time)
            self.input_to_output_id_dict = {}
            
            for input_id, input_id_actual in enumerate(self.input_ids):
                # set corresponding child and output id via owner
                self.input_to_output_id_dict[input_id] = self.owner.input_to_output_id(input_id_actual)

        @property
        def n_out(self) -> int:
            return len(self.input_ids)
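The precomputed lookup in `Placeholder.__init__` relies on `Module.input_to_output_id`, which is defined earlier and not shown here. The sketch below is a hypothetical stand-alone version, assuming that a module's inputs are the concatenated outputs of its children in child order; it maps each flat input id to a `(child id, output id)` pair.

```python
import bisect
import itertools
from typing import List, Tuple

def input_to_output_id(input_id: int, children_n_out: List[int]) -> Tuple[int, int]:
    """Maps a flat input id to (child id, output id of that child). Hypothetical helper."""
    # cumulative output counts, e.g. [2, 3] -> [2, 5]
    cum_n_out = list(itertools.accumulate(children_n_out))
    if not 0 <= input_id < cum_n_out[-1]:
        raise ValueError(f"Input id {input_id} out of range.")
    # first child whose cumulative output count exceeds the input id
    child_id = bisect.bisect_right(cum_n_out, input_id)
    offset = cum_n_out[child_id - 1] if child_id > 0 else 0
    return child_id, input_id - offset

# two children with 2 and 3 outputs -> flat inputs 0..4
children_n_out = [2, 3]

# precomputed lookup for a placeholder over the owner's inputs [4, 1]
placeholder_input_ids = [4, 1]
lookup = {i: input_to_output_id(actual, children_n_out) for i, actual in enumerate(placeholder_input_ids)}
print(lookup)  # {0: (1, 2), 1: (0, 1)}
```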

In [31]:
@dispatch(NestingModule.Placeholder, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(placeholder: NestingModule.Placeholder, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Only called if the likelihood values for the placeholder module are not in the cache. Raises an error in that case."""
    raise LookupError("Likelihood values for placeholder module not found in cache. Check if these are correctly set by the nesting module.")

@dispatch(NestingModule.Placeholder, dispatch_ctx=DispatchContext)
@memoize
def scope(placeholder: NestingModule.Placeholder, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Only called if the scopes for the placeholder module are not in the cache. Raises an error in that case."""
    raise LookupError("Scope values for placeholder module not found in cache. Check if these are correctly set by the nesting module.")

@dispatch(NestingModule.Placeholder, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(placeholder: NestingModule.Placeholder, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Redirects sampling calls to the actual child modules of the owner."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, placeholder):
            # invalid sampling context
            raise ValueError("Specified sampling context is invalid for specified data array and module.")
    else:
        if placeholder.n_out != 1:
            raise ValueError("No sampling context specified. It is unclear which output to sample from.")
        else:
            # use all instances (empty output id lists refer to all outputs)
            sampling_ctx = SamplingContext(list(range(data.shape[0])), [[] for _ in range(data.shape[0])])

    sampling_ctx_per_child = {}

    # TODO: could potentially be done more efficiently via grouping
    for instance_id, instance_output_ids in zip(sampling_ctx.instance_ids, sampling_ctx.output_ids):

        # an empty list refers to all outputs of the placeholder
        if len(instance_output_ids) == 0:
            instance_output_ids = list(placeholder.input_to_output_id_dict.keys())

        # sort actual output ids per child id
        output_per_child = {}

        for output_id in instance_output_ids:
            child_id, child_output_id = placeholder.input_to_output_id_dict[output_id]
            output_per_child.setdefault(child_id, []).append(child_output_id)
        
        # append (or create) sampling contexts
        for child_id, output_ids in output_per_child.items():
            if child_id in sampling_ctx_per_child:
                sampling_ctx_per_child[child_id].instance_ids.append(instance_id)
                sampling_ctx_per_child[child_id].output_ids.append(output_ids)
            else:
                sampling_ctx_per_child[child_id] = SamplingContext([instance_id], [output_ids])

    # sample from children
    for child_id, child_sampling_ctx in sampling_ctx_per_child.items():
        sample(placeholder.owner.children[child_id], data, dispatch_ctx=dispatch_ctx, sampling_ctx=child_sampling_ctx)

    return data
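The redirection in `sample` boils down to regrouping the requested placeholder outputs per actual child. The following sketch uses plain lists and dictionaries instead of `SamplingContext` (the function name `group_by_child` is hypothetical) and treats an empty output list as a request for all placeholder outputs.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def group_by_child(instance_ids: List[int], output_ids: List[List[int]],
                   lookup: Dict[int, Tuple[int, int]]) -> Dict[int, Tuple[List[int], List[List[int]]]]:
    """Groups (instance, placeholder output) requests by actual child module.

    Returns, per child id, parallel lists of instance ids and child output ids,
    i.e. the content of one sampling context per child.
    """
    per_child = defaultdict(lambda: ([], []))
    for instance_id, requested in zip(instance_ids, output_ids):
        # empty list: all placeholder outputs are requested
        requested = requested if requested else list(lookup.keys())
        outputs = defaultdict(list)
        for placeholder_output in requested:
            child_id, child_output = lookup[placeholder_output]
            outputs[child_id].append(child_output)
        for child_id, child_outputs in outputs.items():
            per_child[child_id][0].append(instance_id)
            per_child[child_id][1].append(child_outputs)
    return dict(per_child)

# placeholder outputs: 0 -> (child 1, output 2), 1 -> (child 0, output 1)
lookup = {0: (1, 2), 1: (0, 1)}
groups = group_by_child([0, 1, 2], [[0], [1], []], lookup)
print(groups)  # {1: ([0, 2], [[2], [2]]), 0: ([1, 2], [[1], [1]])}
```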

### Layer Modules

Basic `LeafLayer`, `SumLayer` and `ProductLayer` classes.

![layers](uml/layers.svg)

#### Leaf Layer

Note: `LeafLayer` only contains terminal modules and therefore has no need for placeholders.

In [32]:
class LeafLayer(Module):
    """Layer of multiple leaf nodes over the same scope.
    
    Args:
        scope: scope of all leaf nodes.
        n_out: number of leaf nodes.
    """
    def __init__(self, scope: Scope, n_out: int) -> None:
        
        super(LeafLayer, self).__init__()
        
        self.nodes = []
        self.scope = scope

        # create leaf nodes
        for _ in range(n_out):
            self.nodes.append(LeafNode(scope=scope))
    
    @property
    def n_out(self) -> int:
        return len(self.nodes)

@dispatch(LeafLayer, dispatch_ctx=DispatchContext)
@memoize
def scope(leaf_layer: LeafLayer, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Concatenates the scopes of all leaf nodes."""
    return np.concatenate([scope(node, dispatch_ctx=dispatch_ctx) for node in leaf_layer.nodes], axis=1)

@dispatch(LeafLayer, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(leaf_layer: LeafLayer, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Concatenates the likelihoods for all leaf nodes."""
    return np.concatenate([likelihood(node, data, dispatch_ctx=dispatch_ctx) for node in leaf_layer.nodes], axis=1)

@dispatch(LeafLayer, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(leaf_layer: LeafLayer, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples leaf nodes according to the sampling context."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, leaf_layer):
            # invalid sampling context
            raise ValueError("Specified sampling context is invalid for specified data array and module.")
    else:
        raise ValueError("No sampling context specified. It is unclear which output to sample from.")
    
    instance_ids = np.array(sampling_ctx.instance_ids)

    for node_ids in np.unique(sampling_ctx.output_ids, axis=0):
        if(len(node_ids) != 1):
            raise ValueError("Too many output ids specified for outputs over same scope.")
        
        node_id = node_ids[0]
        # positions (within the sampling context) of all instances that requested this node ...
        positions = np.where(sampling_ctx.output_ids == node_ids)[0]
        # ... mapped to the actual instance ids
        node_instance_ids = instance_ids[positions]
        
        sample(leaf_layer.nodes[node_id], data, dispatch_ctx=dispatch_ctx, sampling_ctx=SamplingContext(list(node_instance_ids), [[] for _ in node_instance_ids]))

    return data
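A detail worth spelling out for the layer sampling routines: `np.where` returns positions within the sampling context, not rows of the data array. The small sketch below (plain NumPy, with made-up instance ids) shows how these positions translate back to instance ids when a context only covers part of the data.

```python
import numpy as np

# sampling context covering only some rows of the data array (made-up values)
instance_ids = np.array([4, 7, 9, 12])  # actual rows of the data array
output_ids = [[0], [2], [0], [2]]       # requested node id per instance

groups = {}
for node_ids in np.unique(output_ids, axis=0):
    # np.where yields positions within the sampling context ...
    positions = np.where((np.array(output_ids) == node_ids).all(axis=1))[0]
    # ... which are mapped back to the actual instance ids
    groups[int(node_ids[0])] = instance_ids[positions].tolist()

print(groups)  # {0: [4, 9], 2: [7, 12]}
```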

##### Examples

In [33]:
layer = LeafLayer(Scope(SPN(), [0]), 3)

data = np.random.randn(2,1)
print(f"Likelihood: {likelihood(layer, data).shape}")

print(f"Scope: {scope(layer)}")

Likelihood: (2, 3)
Scope: [[Scope({0}|{}) Scope({0}|{}) Scope({0}|{})]]


In [34]:
s = SumNode(children=[layer])
sample(s)

array([[1.53034358]])

#### Sum Layer

Because this layer now uses non-terminal modules/nodes, we use `NestingModule` as a blueprint.

In [35]:
class SumLayer(NestingModule):
    """Layer of multiple sum nodes over the same child modules.
    
    Args:
        n_out: number of sum nodes.
        children: list of child modules.
    """
    def __init__(self, n_out: int, children: List[Module]) -> None:
        super(SumLayer, self).__init__(children=children)
        
        self.nodes = []

        # all nodes share the same input link here
        ph = self.create_placeholder(list(range(self.n_in)))

        for _ in range(n_out):
            self.nodes.append(SumNode([ph]))

    @property
    def n_out(self) -> int:
        return len(self.nodes)

@dispatch(SumLayer, dispatch_ctx=DispatchContext)
@memoize
def scope(sum_layer: SumLayer, dispatch_ctx: DispatchContext) -> ScopeArray:
    """Concatenates the scopes of all sum nodes."""
    input_scopes = np.concatenate([scope(child, dispatch_ctx=dispatch_ctx) for child in sum_layer.children], axis=1)

    # set placeholders
    sum_layer.set_placeholders(dispatch_ctx.cache.scope, input_scopes)
    
    # compute output scopes
    output_scopes = np.concatenate([scope(node, dispatch_ctx=dispatch_ctx) for node in sum_layer.nodes], axis=1)
    
    return output_scopes

@dispatch(SumLayer, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(sum_layer: SumLayer, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Concatenates the likelihoods for all sum nodes."""
    input_likelihoods = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in sum_layer.children], axis=1)
    
    # set placeholders
    sum_layer.set_placeholders(dispatch_ctx.cache.likelihood, input_likelihoods)
    
    # compute output likelihoods
    output_likelihoods = np.concatenate([likelihood(node, data, dispatch_ctx=dispatch_ctx) for node in sum_layer.nodes], axis=1)
    
    return output_likelihoods

@dispatch(SumLayer, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(sum_layer: SumLayer, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples sum nodes according to the sampling context."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, sum_layer):
            # invalid sampling context
            raise ValueError("Specified sampling context is invalid for specified data array and module.")
    else:
        raise ValueError("No sampling context specified. It is unclear which output to sample from.")

    instance_ids = np.array(sampling_ctx.instance_ids)

    # TODO: fix for the [] case (empty output id lists)
    for node_ids in np.unique(sampling_ctx.output_ids, axis=0):
        if(len(node_ids) != 1):
            raise ValueError("Too many output ids specified for outputs over same scope.")

        node_id = node_ids[0]
        # map positions within the sampling context to actual instance ids
        node_instance_ids = instance_ids[np.where(sampling_ctx.output_ids == node_ids)[0]]

        sample(sum_layer.nodes[node_id], data, dispatch_ctx=dispatch_ctx, sampling_ctx=SamplingContext(list(node_instance_ids), [[] for _ in node_instance_ids]))

    return data
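Since all sum nodes in a `SumLayer` share the same inputs, the layer could in principle also be realized implicitly, the kind of optimized module mentioned in the design basis. The sketch below assumes one normalized weight vector per sum node, stacked into a matrix (an assumption; how `SumNode` stores its weights is not shown in this section), and checks that the explicit per-node computation matches a single matrix product.

```python
import numpy as np

rng = np.random.default_rng(0)

n_instances, n_in, n_out = 4, 3, 2
# child likelihoods: one column per input of the layer
inputs = rng.random((n_instances, n_in))
# one normalized weight vector per sum node (each row sums to one)
weights = rng.random((n_out, n_in))
weights /= weights.sum(axis=1, keepdims=True)

# explicit: one sum node per output, each computing a weighted sum of all inputs
explicit = np.stack([(inputs * w).sum(axis=1) for w in weights], axis=1)

# implicit: the whole layer as a single matrix product
implicit = inputs @ weights.T

print(explicit.shape, np.allclose(explicit, implicit))  # (4, 2) True
```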

##### Example

In [36]:
leaf_layer = LeafLayer(Scope(SPN(), [0]), 3)
sum_layer = SumLayer(3, [leaf_layer])

data = np.random.randn(1,1)
print(f"Likelihood: {likelihood(sum_layer, data)}")

print(f"Scope: {scope(sum_layer)}")

Likelihood: [[0.36744572 0.36744572 0.36744572]]
Scope: [[Scope({0}|{}) Scope({0}|{}) Scope({0}|{})]]


In [37]:
s_layer = SumNode([leaf_layer])
s_nodes = SumNode([LeafNode(Scope(SPN(), [0])), LeafNode(Scope(SPN(), [0])), LeafNode(Scope(SPN(), [0]))])

# TODO: there seems to be something wrong with sampling in LeafLayer
print(f"Sample1: {np.mean(sample(s_layer, 1000), axis=0)}")
print(f"Sample2: {np.mean(sample(s_nodes, 1000), axis=0)}")

Sample1: [nan]
Sample2: [-0.00732721]


#### Product Layer

The number of `ProductNode`s equals the number of combinations that pick one output from each child module (i.e. the product of the children's numbers of outputs).
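The construction in `ProductLayer.__init__` can be reproduced with plain `itertools` and NumPy. This sketch uses the example from the class docstring (two children with 2 and 3 outputs) and additionally computes each product node's likelihood as the product of its selected input columns, using made-up input values.

```python
import itertools

import numpy as np

children_n_out = [2, 3]  # two children with 2 and 3 outputs, respectively

# split the flat input ids 0..4 into one group per child: [0, 1] and [2, 3, 4]
bounds = np.cumsum([0] + children_n_out)
factorized_ids = [list(range(start, stop)) for start, stop in zip(bounds[:-1], bounds[1:])]

# one product node per combination that picks exactly one input from each group
input_ids_per_node = list(itertools.product(*factorized_ids))
print(input_ids_per_node)  # [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4)]

# likelihood of each product node: product over its selected input columns
inputs = np.array([[0.5, 0.2, 0.1, 0.4, 0.8]])  # 1 instance, 5 inputs (made up)
outputs = np.stack([inputs[:, list(ids)].prod(axis=1) for ids in input_ids_per_node], axis=1)
print(outputs.shape)  # (1, 6)
```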

In [38]:
class ProductLayer(NestingModule):
    """Layer of multiple product nodes over the same child modules.
    
    Creates a product node for each combination of inputs from the child modules.
    E.g. for two modules with 2 (ids 0,1) and 3 (ids 2,3,4) outputs, respectively, one gets nodes with the following inputs:
        [0,2]
        [0,3]
        [0,4]
        [1,2]
        [1,3]
        [1,4]

    Args:
        children: list of child modules.
    """
    def __init__(self, children: List[Module]) -> None:
        super(ProductLayer, self).__init__(children)
        
        self.nodes = []
        self.input_placeholders = []
        
        children_n_out = [child.n_out for child in self.children]
        total_ids = list(range(sum(children_n_out)))
        factorized_ids = []
        
        for n in children_n_out:
            factorized_ids.append(total_ids[:n])
            total_ids = total_ids[n:]

        self.input_ids_per_node = list(itertools.product(*factorized_ids))
        
        # create product nodes
        for ids in self.input_ids_per_node:
            ph = self.create_placeholder(list(ids))
            self.nodes.append(ProductNode(children=[ph]))
    
    @property
    def n_out(self) -> int:
        return len(self.nodes)

@dispatch(ProductLayer, dispatch_ctx=DispatchContext)
@memoize
def scope(product_layer: ProductLayer, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Concatenates the scopes of all product nodes."""
    input_scopes = np.concatenate([scope(child, dispatch_ctx=dispatch_ctx) for child in product_layer.children], axis=1)

    # set placeholders
    product_layer.set_placeholders(dispatch_ctx.cache.scope, input_scopes)

    # compute output scopes
    output_scopes = np.concatenate([scope(node, dispatch_ctx=dispatch_ctx) for node in product_layer.nodes], axis=1)

    return output_scopes

@dispatch(ProductLayer, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(product_layer: ProductLayer, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Concatenates the likelihoods for all product nodes."""
    input_likelihoods = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in product_layer.children], axis=1)
    
    # set placeholders
    product_layer.set_placeholders(dispatch_ctx.cache.likelihood, input_likelihoods)

    # compute output likelihoods
    output_likelihoods = np.concatenate([likelihood(node, data, dispatch_ctx=dispatch_ctx) for node in product_layer.nodes], axis=1)

    return output_likelihoods

@dispatch(ProductLayer, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(product_layer: ProductLayer, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples product nodes according to the sampling context."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, product_layer):
            # invalid sampling context
            raise ValueError("Specified sampling context is invalid for specified data array and module.")
    else:
        raise ValueError("No sampling context specified. It is unclear which output to sample from.")

    instance_ids = np.array(sampling_ctx.instance_ids)

    for node_ids in np.unique(sampling_ctx.output_ids, axis=0):
        if(len(node_ids) != 1):
            raise ValueError("Too many output ids specified for outputs over same scope.")

        node_id = node_ids[0]
        # map positions within the sampling context to actual instance ids
        node_instance_ids = instance_ids[np.where(sampling_ctx.output_ids == node_ids)[0]]
        sample(product_layer.nodes[node_id], data, dispatch_ctx=dispatch_ctx, sampling_ctx=SamplingContext(list(node_instance_ids), [[] for _ in node_instance_ids]))

    return data

##### Examples

In [39]:
leaf_layer_1 = LeafLayer(Scope(SPN(), [0]), 2)
leaf_layer_2 = LeafLayer(Scope(SPN(), [1]), 2)
product_layer = ProductLayer([leaf_layer_1, leaf_layer_2])

data = np.random.randn(1,2)
print(f"Likelihood: {likelihood(product_layer, data)}")

print(f"Scope: {scope(product_layer)}")

Likelihood: [[0.1504239 0.1504239 0.1504239 0.1504239]]
Scope: [[Scope({0, 1}|{}) Scope({0, 1}|{}) Scope({0, 1}|{}) Scope({0, 1}|{})]]


In [40]:
l1 = LeafLayer(Scope(SPN(), [0]), n_out=3)
s1 = SumLayer(3, [l1])

l2 = LeafLayer(Scope(SPN(), [1]), n_out=3)
s2 = SumLayer(3, [l2])

p = ProductLayer([s1,s2])
s = SumNode([p])

data = np.random.randn(1,2)
print(f"Likelihood: {likelihood(s, data)}")

print(f"Scope: {scope(s)}")

Likelihood: [[0.06895548]]
Scope: [[Scope({0, 1}|{})]]


In [41]:
sample(s)

array([[-1.93164701, -1.77439365]])

## Networks

Using these layers and nodes as building blocks, we can now build even more complex modules or entire networks.

In [42]:
class ExampleNetwork(NestingModule):
    """Example network using layers and nodes.
    
    Args:
        children: list of child modules.
    """
    def __init__(self, children: List[Module]) -> None:
        super(ExampleNetwork, self).__init__(children)
        
        n_ins = [child.n_out for child in self.children]
        placeholders = []
        
        total_ids = list(range(sum(n_ins)))

        # create one placeholder per child, covering the child's (contiguous) input ids
        for n in n_ins:
            placeholders.append(self.create_placeholder(total_ids[:n]))
            total_ids = total_ids[n:]
        
        # create product layer on top
        self.product_layer = ProductLayer(children=placeholders)
        
        # sum over all outputs of the product layer
        self.sum_node = SumNode(children=[self.product_layer])
    
    @property
    def n_out(self) -> int:
        return 1

Here, the dispatched functions only need to set the placeholders and can then delegate to the internal sum node.

In [43]:
@dispatch(ExampleNetwork, dispatch_ctx=DispatchContext)
@memoize
def scope(example_network: ExampleNetwork, dispatch_ctx: DispatchContext) -> np.ndarray:
    # compute input scopes
    input_scopes = np.concatenate([scope(child, dispatch_ctx=dispatch_ctx) for child in example_network.children], axis=1)

    # set placeholders
    example_network.set_placeholders(dispatch_ctx.cache.scope, input_scopes)
    
    return scope(example_network.sum_node, dispatch_ctx=dispatch_ctx)

@dispatch(ExampleNetwork, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(example_network: ExampleNetwork, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    # compute input likelihoods
    input_likelihoods = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in example_network.children], axis=1)
    
    # set placeholders
    example_network.set_placeholders(dispatch_ctx.cache.likelihood, input_likelihoods)
    
    return likelihood(example_network.sum_node, data, dispatch_ctx=dispatch_ctx)

@dispatch(ExampleNetwork, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(example_network: ExampleNetwork, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    # compute likelihoods first (this sets the placeholders needed during sampling)
    likelihood(example_network, data, dispatch_ctx=dispatch_ctx)
    return sample(example_network.sum_node, data, dispatch_ctx=dispatch_ctx, sampling_ctx=sampling_ctx)

##### Example

In [44]:
leaf_layers = [LeafLayer(Scope(SPN(), [i]), 2) for i in range(3)]
net = ExampleNetwork(children=leaf_layers)

data = np.random.rand(2,3)

print(f"Likelihoods: {likelihood(net, data)}")
print(f"Scope: {scope(net)}")

Likelihoods: [[0.04025569]
 [0.02779345]]
Scope: [[Scope({0, 1, 2}|{})]]


## Conditional Modules

One might want to create conditional modules, e.g. modules where parameters are conditioned on some inputs and set accordingly.

We would like to have two different ways of retrieving/setting conditional parameters:

1. Self-contained as part of the module
2. Injecting values from the outside (e.g. cached values or values computed as part of an overlying module)

Additionally, we would like to avoid actually setting parameters on the modules, since parameter values are only valid for a given call and would need to be cleared afterwards. One solution is to pass the values down instead: the `args` dictionary in `DispatchContext` can be used to pass arguments to the `likelihood`, `scope` and `sample` calls of individual modules.

Modules should check/retrieve parameters in the following order:

1. Check if required parameters are specified in `DispatchContext`
2. Check if an (alternative) function `cond_f` is specified as an argument for the module in `DispatchContext`
3. Check if a `cond_f` function is specified in the module itself
4. Raise an exception
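The lookup order above can be sketched as a single helper. `resolve_cond_params` is a hypothetical function (not part of the notebook's API), and `args` is modeled as a plain dictionary keyed by module, mirroring `DispatchContext.args`.

```python
from typing import Any, Callable, Dict, Iterable, Optional

import numpy as np

def resolve_cond_params(module: Any, args: Dict[Any, Dict[str, Any]], required: Iterable[str],
                        module_cond_f: Optional[Callable], data: np.ndarray) -> Dict[str, Any]:
    """Retrieves or computes conditional parameters for `module` (hypothetical helper)."""
    required = list(required)
    module_args = args.get(module, {})

    # 1. all required parameters are specified directly
    if all(name in module_args for name in required):
        return {name: module_args[name] for name in required}

    # 2. function specified in the arguments, otherwise 3. function stored in the module
    cond_f = module_args.get("cond_f", module_cond_f)
    if cond_f is None:
        # 4. parameters can neither be retrieved nor computed
        raise ValueError("Conditional parameters are neither given nor can they be computed.")

    params = cond_f(data)
    # store computed values for possible reuse within the same call
    args.setdefault(module, {}).update(params)
    return {name: params[name] for name in required}

# usage: a plain object stands in for a conditional leaf module
leaf = object()
args = {}
data = np.zeros((3, 1))
f = lambda d: {"mean": np.zeros((d.shape[0], 1)), "stdev": np.ones((d.shape[0], 1))}

params = resolve_cond_params(leaf, args, ["mean", "stdev"], f, data)  # computed via f, then stored
args[leaf]["mean"] = np.full((3, 1), 5.0)                              # explicit values take precedence
explicit = resolve_cond_params(leaf, args, ["mean", "stdev"], None, data)
print(params["mean"].shape, explicit["mean"][0, 0])  # (3, 1) 5.0
```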

#### Conditional Leaf Node

In [45]:
from typing import Callable

class CondLeafNode(Node):
    """Basic conditional univariate leaf node. Here, implemented as a Gaussian.

    Args:
        scope: scope of the distribution.
        cond_f: optional function that computes the conditional parameters ('mean', 'stdev') from the data.
    """
    def __init__(self, scope: Scope, cond_f: Optional[Callable]=None) -> None:
        
        if(len(scope) != 1):
            raise ValueError("Scope too large for univariate leaf node.")
        
        super(CondLeafNode, self).__init__()
        
        self.scope = scope
        self.cond_f = cond_f
    
    def set_cond_f(self, cond_f):
        self.cond_f = cond_f

In [46]:
@dispatch(CondLeafNode, dispatch_ctx=DispatchContext)
@memoize
def scope(cond_leaf_node: CondLeafNode, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Simply returns the scope of the module, since it's a leaf node."""
    return ScopeArray([[cond_leaf_node.scope]])

@dispatch(CondLeafNode, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(cond_leaf_node: CondLeafNode, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Returns the likelihood for each data instance. Marginalizes (likelihood of 1) for NaN values."""

    # ----- get conditional parameters -----

    # arguments specified for this module in the dispatch context (if any)
    args = dispatch_ctx.args[cond_leaf_node] if cond_leaf_node in dispatch_ctx.args else {}

    if "mean" in args and "stdev" in args:
        # 1. parameters are directly specified in the dispatch context
        mean = args["mean"]
        stdev = args["stdev"]
    else:
        # 2. function specified in the dispatch context, otherwise 3. function specified in the module
        cond_f = args["cond_f"] if "cond_f" in args else cond_leaf_node.cond_f

        if cond_f is None:
            # 4. parameters can neither be retrieved nor computed
            raise ValueError("Conditional leaf requires mean and standard deviation values are neither given nor is a computation function specified.")

        params = cond_f(data)

        # select parameters
        mean = params["mean"]
        stdev = params["stdev"]
        
        if cond_leaf_node in dispatch_ctx.args:
            # update arguments for possible future (re)use
            dispatch_ctx.args[cond_leaf_node].update(params)
        else:
            dispatch_ctx.args[cond_leaf_node] = params

    # check whether or not there are enough parameters given the data array
    if mean.shape != (data.shape[0],1):
        raise ValueError("Obtained values for mean of the conditional leaf are of wrong shape.")
    if stdev.shape != (data.shape[0],1):
        raise ValueError("Obtained values for standard deviation of the conditional leaf are of wrong shape.")
    
    likelihoods = np.ones((data.shape[0], len(cond_leaf_node.scope)))
    inputs = data[:, list(cond_leaf_node.scope.query)]

    marg_ids = np.isnan(inputs).sum(axis=1) == len(cond_leaf_node.scope)

    # evaluate each non-marginalized instance under its own conditional parameters
    likelihoods[~marg_ids, :] = norm.pdf(inputs[~marg_ids, :], loc=mean[~marg_ids, :], scale=stdev[~marg_ids, :])

    return likelihoods

@dispatch(CondLeafNode, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(cond_leaf_node: CondLeafNode, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    raise NotImplementedError
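When conditional parameters differ per instance, marginalized rows must be skipped for the parameters as well as for the inputs; otherwise the shapes no longer line up. The following sketch with made-up values uses `scipy.stats.norm.pdf` with per-row `loc` and `scale` arrays and keeps a likelihood of 1 for missing values.

```python
import numpy as np
from scipy.stats import norm

# per-instance conditional parameters, one row per instance (made-up values)
mean = np.array([[0.0], [1.0], [2.0]])
stdev = np.array([[1.0], [1.0], [2.0]])

# observed values of the (univariate) query variable; NaN marks a missing value
inputs = np.array([[0.0], [np.nan], [2.0]])

# missing values are marginalized, i.e. keep a likelihood of 1
likelihoods = np.ones_like(inputs)
observed = ~np.isnan(inputs).any(axis=1)

# evaluate each observed instance under its own parameters (rows are selected consistently)
likelihoods[observed] = norm.pdf(inputs[observed], loc=mean[observed], scale=stdev[observed])
print(likelihoods.ravel())
```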

##### Examples

In [47]:
cond_f = lambda data : {"mean": np.zeros((5,1)), "stdev": np.ones((5,1))}

In [48]:
cond_l1 = CondLeafNode(Scope(SPN(), [0], [1])) # conditional leaf node (query 0, evidence 1)
cond_l2 = CondLeafNode(Scope(SPN(), [0], [1]), cond_f) # conditional leaf node with conditional parameter function (query 0, evidence 1)
s = SumNode(children=[cond_l1, cond_l2]) # sum node over both conditional leaf nodes (query 0, evidence 1)

In [49]:
try:
    # should result in an error, because neither conditional parameters nor a function to compute them are specified
    likelihood(cond_l1, data)
except ValueError as e:
    print(e)

Conditional leaf requires mean and standard deviation values are neither given nor is a computation function specified.


In [50]:
# provide conditional parameters for the first conditional leaf
ctx = DispatchContext()
ctx.args[cond_l1] = {
    "mean": np.zeros((5,1)),
    "stdev": np.ones((5,1))
}
data = np.tile([0.0, 1.0], (5,1))

In [51]:
print(f"Likelihood: {likelihood(s, data, dispatch_ctx=ctx)}")

Likelihood: [[0.39894228]
 [0.39894228]
 [0.39894228]
 [0.39894228]
 [0.39894228]]


#### Conditional Sum Node

In [52]:
class CondSumNode(Node):
    """Simple conditional sum node. All child nodes (explicit or implicit) are assumed to have the same scopes.
    
    Args:
        children: list of child modules.
        cond_f: optional function that computes the conditional parameters ('weights') from the data.
    """
    def __init__(self, children: List[Module], cond_f: Optional[Callable]=None) -> None:

        if(len(children) == 0):
            raise ValueError("List of child modules for CondSumNode is empty.")

        super(CondSumNode, self).__init__(children)

        self.cond_f = cond_f

    def set_cond_f(self, cond_f):
        self.cond_f = cond_f

In [53]:
@dispatch(CondSumNode, dispatch_ctx=DispatchContext)
@memoize
def scope(cond_sum_node: CondSumNode, dispatch_ctx: DispatchContext) -> ScopeArray:
    """Returns merged child scopes."""
    scopes = np.concatenate([scope(child, dispatch_ctx=dispatch_ctx) for child in cond_sum_node.children])
    return scopes.sum(keepdims=True)

@dispatch(CondSumNode, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(cond_sum_node: CondSumNode, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Returns the weighted sum of the child likelihoods."""

    inputs = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in cond_sum_node.children], axis=1)
    
    # ----- get conditional parameters -----

    # arguments specified for this module in the dispatch context (if any)
    args = dispatch_ctx.args[cond_sum_node] if cond_sum_node in dispatch_ctx.args else {}

    if "weights" in args:
        # 1. weights are directly specified in the dispatch context
        weights = args["weights"]
    else:
        # 2. function specified in the dispatch context, otherwise 3. function specified in the module
        cond_f = args["cond_f"] if "cond_f" in args else cond_sum_node.cond_f

        if cond_f is None:
            # 4. weights can neither be retrieved nor computed
            raise ValueError("Conditional sum node requires weight values, but they are neither given nor is a function to compute them specified.")

        params = cond_f(data)

        # select parameters
        weights = params["weights"]

        if cond_sum_node in dispatch_ctx.args:
            # update arguments for possible future (re)use
            dispatch_ctx.args[cond_sum_node].update(params)
        else:
            dispatch_ctx.args[cond_sum_node] = params

    # check whether or not there are enough parameters given the data array
    if weights.shape != (data.shape[0],1): # TODO
        raise ValueError("Obtained values for weights of the conditional sum node are of wrong shape.")

    return (weights*inputs).sum(axis=1, keepdims=True)

@dispatch(CondSumNode, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(cond_sum_node: CondSumNode, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Samples a branch for each instance (taking likelihoods into account)."""

    if sampling_ctx:
        if not sampling_ctx.is_valid(data, cond_sum_node):
            # invalid sampling context
            raise ValueError(f"Specified sampling context is invalid for specified data array and module.")
    else:
        # create sampling context (assume all instances and all output nodes are to be used)
        sampling_ctx = SamplingContext(list(range(data.shape[0])), [[] for _ in range(data.shape[0])])

    instance_ids = np.array(sampling_ctx.instance_ids)

    # make sure the conditional weights are available in the dispatch context (the likelihood call computes and stores them if necessary)
    likelihood(cond_sum_node, data, dispatch_ctx=dispatch_ctx)
    weights = dispatch_ctx.args[cond_sum_node]["weights"][instance_ids]

    # get likelihoods for children (cached by the call above) and select the relevant instances
    child_likelihoods = np.concatenate([likelihood(child, data, dispatch_ctx=dispatch_ctx) for child in cond_sum_node.children], axis=1)[instance_ids]
    
    # sample branch for each instance id
    choices = []
    sampling_probs = child_likelihoods * weights
    
    for probs in sampling_probs:
        # normalize
        probs_norm = probs / np.sum(probs)
        choices.append(np.random.choice(probs.shape[0], p=probs_norm))

    choices = np.array(choices)

    # for each unique sampled branch
    for branch_id in np.unique(choices):
        # group instances by sampled branch (map positions within the context to actual instance ids)
        child_sample_ids = instance_ids[np.where(choices == branch_id)[0]]

        # get corresponding child and output id for sampled branch
        child_id, output_id = cond_sum_node.input_to_output_id(branch_id)
    
        # sample from child
        sample(cond_sum_node.children[child_id], data, dispatch_ctx=dispatch_ctx, sampling_ctx=SamplingContext(list(child_sample_ids), [[output_id] for _ in range(len(child_sample_ids))]))

    return data
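The sketch below illustrates the two computations of a conditional sum node on plain arrays: the per-instance weighted sum of child likelihoods, and the sampling of one branch per instance with probability proportional to weight times child likelihood. Unlike the shape check in `likelihood` (marked `TODO`), it assumes one weight column per child; all values are made up.

```python
import numpy as np

rng = np.random.default_rng(42)

# child likelihoods for 3 instances and 2 children (made-up values)
child_likelihoods = np.array([[0.4, 0.1],
                              [0.2, 0.2],
                              [0.0, 0.3]])
# per-instance weights, one column per child (each row sums to one)
weights = np.array([[0.5, 0.5],
                    [0.9, 0.1],
                    [0.5, 0.5]])

# likelihood of the conditional sum node: per-instance weighted sum
node_likelihood = (weights * child_likelihoods).sum(axis=1, keepdims=True)

# branch sampling: probability proportional to weight * child likelihood
probs = weights * child_likelihoods
probs /= probs.sum(axis=1, keepdims=True)
choices = np.array([rng.choice(probs.shape[1], p=p) for p in probs])

print(node_likelihood.ravel())
print(choices)  # the last instance can only come from child 1
```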

#### Conditional Network

Example of a conditional network that computes all parameters centrally and passes them down to all nodes.
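Distributing centrally computed parameters amounts to pairing each conditional node with its parameter dictionary. The sketch below uses a hypothetical `ToyNode` class and `central_cond_f` function; unlike a bare `zip`, it checks that exactly one dictionary is returned per node, since `zip` silently drops surplus entries.

```python
import numpy as np

class ToyNode:
    """Hypothetical stand-in for a conditional node (only used as a dictionary key)."""
    def __init__(self, name: str) -> None:
        self.name = name

def central_cond_f(data: np.ndarray):
    """Hypothetical central function: one parameter dictionary per conditional node."""
    n = data.shape[0]
    return ({"mean": np.zeros((n, 1)), "stdev": np.ones((n, 1))},
            {"mean": np.ones((n, 1)), "stdev": np.ones((n, 1))},
            {"weights": np.full((n, 1), 0.5)})

nodes = [ToyNode("l1"), ToyNode("l2"), ToyNode("s1")]
data = np.zeros((4, 2))
args = {}

cond_params = central_cond_f(data)
if len(cond_params) != len(nodes):
    # zip would silently ignore surplus nodes or dictionaries
    raise ValueError("Expected exactly one parameter dictionary per conditional node.")

# pass the centrally computed parameters down to the individual nodes
for node, params in zip(nodes, cond_params):
    args.setdefault(node, {}).update(params)

print({node.name: sorted(args[node]) for node in nodes})
# {'l1': ['mean', 'stdev'], 'l2': ['mean', 'stdev'], 's1': ['weights']}
```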

In [54]:
class CondNet(Module):
    def __init__(self, cond_f: Optional[Callable]=None):
        
        super(CondNet, self).__init__()

        self.cond_f = cond_f
        
        # conditional leaves over rv 0
        self.l1 = CondLeafNode(Scope(SPN(), [0], [2]))
        self.l2 = CondLeafNode(Scope(SPN(), [0], [2]))
        
        # conditional leaves over rv 1
        self.l3 = CondLeafNode(Scope(SPN(), [1], [2]))
        self.l4 = CondLeafNode(Scope(SPN(), [1], [2]))
        
        # conditional sum nodes
        self.s1 = CondSumNode([self.l1, self.l2])
        self.s2 = CondSumNode([self.l3, self.l4])
        
        # product node
        self.p = ProductNode([self.s1, self.s2])
        
        self.cond_nodes = [self.l1, self.l2, self.l3, self.l4, self.s1, self.s2]
    
    @property
    def n_out(self):
        return 1
    
    def set_cond_f(self, cond_f):
        self.cond_f = cond_f

In [55]:
@dispatch(CondNet, dispatch_ctx=DispatchContext)
@memoize
def scope(cond_net: CondNet, dispatch_ctx: DispatchContext) -> ScopeArray:
    """Returns scope."""
    return scope(cond_net.p, dispatch_ctx=dispatch_ctx)

@dispatch(CondNet, np.ndarray, dispatch_ctx=DispatchContext)
@memoize
def likelihood(cond_net: CondNet, data: np.ndarray, dispatch_ctx: DispatchContext) -> np.ndarray:
    """Returns likelihoods."""
    
    # ----- get conditional parameters -----
    
    if not all(m in dispatch_ctx.args for m in cond_net.cond_nodes):
        
        if cond_net in dispatch_ctx.args and "cond_f" in dispatch_ctx.args[cond_net]:
            args = dispatch_ctx.args[cond_net]
            cond_f = args["cond_f"]
        elif cond_net.cond_f:
            cond_f = cond_net.cond_f
        else:
            raise ValueError("Conditional network requires conditional parameters, but they are neither given nor is a function to compute them specified.")

        cond_params = cond_f(data)
        
        for m, p in zip(cond_net.cond_nodes, cond_params):
            if m in dispatch_ctx.args:
                dispatch_ctx.args[m].update(p)
            else:
                dispatch_ctx.args[m] = p

    return likelihood(cond_net.p, data, dispatch_ctx=dispatch_ctx)

@dispatch(CondNet, np.ndarray, dispatch_ctx=DispatchContext, sampling_ctx=SamplingContext)
@memoize
def sample(cond_net: CondNet, data: np.ndarray, dispatch_ctx: DispatchContext, sampling_ctx: Optional[SamplingContext]=None) -> np.ndarray:
    """Returns samples."""
    return sample(cond_net.p, data, dispatch_ctx=dispatch_ctx, sampling_ctx=sampling_ctx)
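The conditional-parameter lookup in `likelihood` follows a fixed precedence. Existing per-node arguments in the dispatch context come first, then a `cond_f` passed for the network via the context, and finally the `cond_f` stored on the network itself. The sketch below isolates that logic in plain Python so it can be tested without the rest of the framework. `MiniDispatchContext` and `resolve_cond_params` are hypothetical names, not part of this design. Unlike the cell above, the sketch also checks that `cond_f` returns exactly one parameter dict per conditional node, because `zip` would otherwise silently drop any surplus or missing entries.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Sequence

import numpy as np


@dataclass
class MiniDispatchContext:
    # hypothetical stand-in for DispatchContext: maps modules to their keyword arguments
    args: Dict[Any, Dict[str, Any]] = field(default_factory=dict)


def resolve_cond_params(
    net: Any,
    cond_nodes: Sequence[Any],
    data: np.ndarray,
    ctx: MiniDispatchContext,
    default_cond_f: Optional[Callable[[np.ndarray], Sequence[Dict[str, Any]]]] = None,
) -> None:
    """Fill ctx.args for every conditional node, mirroring the precedence used in likelihood()."""
    # 1) nothing to do if every conditional node already has arguments
    if all(m in ctx.args for m in cond_nodes):
        return
    # 2) a cond_f passed for the network via the context takes precedence ...
    cond_f = ctx.args.get(net, {}).get("cond_f")
    # 3) ... otherwise fall back to the function stored on the network
    if cond_f is None:
        cond_f = default_cond_f
    if cond_f is None:
        raise ValueError("No conditional parameters given and no cond_f specified.")
    cond_params = cond_f(data)
    # guard against zip() silently truncating a mismatched result
    if len(cond_params) != len(cond_nodes):
        raise ValueError("cond_f must return one parameter dict per conditional node.")
    # 4) merge the computed parameters into the per-node argument dicts
    for m, p in zip(cond_nodes, cond_params):
        ctx.args.setdefault(m, {}).update(p)
```

Because the node list and the tuple returned by `cond_f` are matched purely by position, the order of `CondNet.cond_nodes` effectively becomes part of the `cond_f` contract.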

##### Examples

In [56]:
def cond_f(data: np.ndarray) -> Tuple[Dict[str, np.ndarray], ...]:
    """Returns one parameter dictionary per conditional node (in the order of `CondNet.cond_nodes`)."""
    n = data.shape[0]  # number of data instances; parameters are given per instance
    
    # leaf node parameters
    l1_params = {"mean": np.zeros((n, 1)), "stdev": np.ones((n, 1))}
    l2_params = {"mean": np.zeros((n, 1)), "stdev": np.ones((n, 1))}
    l3_params = {"mean": np.zeros((n, 1)), "stdev": np.ones((n, 1))}
    l4_params = {"mean": np.zeros((n, 1)), "stdev": np.ones((n, 1))}
    
    # sum node parameters
    s1_params = {"weights": np.ones((n, 1)) / 2}
    s2_params = {"weights": np.ones((n, 1)) / 2}
    
    return l1_params, l2_params, l3_params, l4_params, s1_params, s2_params
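The `cond_f` above ignores its input, so every instance gets the same parameters. The point of a conditional network is that parameters may depend on the evidence variables. The following sketch shows a hypothetical `cond_f_from_evidence` whose leaf means shift with the evidence value. It assumes the evidence variable is stored in data column 2 (matching the evidence index in `Scope(SPN(), [0], [2])`), so it expects at least three data columns, unlike the two-column example data used below. The specific mean offsets are arbitrary illustrations.

```python
from typing import Dict, Tuple

import numpy as np

ParamDict = Dict[str, np.ndarray]


def cond_f_from_evidence(data: np.ndarray, evidence_col: int = 2) -> Tuple[ParamDict, ...]:
    """Hypothetical cond_f whose leaf means depend on the evidence variable."""
    n = data.shape[0]
    # evidence values as a column vector of shape (n, 1)
    e = data[:, evidence_col:evidence_col + 1]
    ones = np.ones((n, 1))
    # leaves over rv 0 and rv 1: means shift with (+/-) multiples of the evidence, unit stdev
    l1 = {"mean": e.copy(), "stdev": ones.copy()}
    l2 = {"mean": -e, "stdev": ones.copy()}
    l3 = {"mean": 2 * e, "stdev": ones.copy()}
    l4 = {"mean": -2 * e, "stdev": ones.copy()}
    # sum nodes: same weights as in cond_f above
    s1 = {"weights": ones / 2}
    s2 = {"weights": ones / 2}
    # order must match CondNet.cond_nodes: l1, l2, l3, l4, s1, s2
    return l1, l2, l3, l4, s1, s2
```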

In [57]:
net = CondNet(cond_f)
scope(net)

array([[Scope({0, 1}|{2})]], dtype=object)
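The resulting scope `{0, 1}|{2}` comes from the product node combining the query scopes `{0}` and `{1}` of its two sum nodes while both condition on `{2}`. The sketch below makes the node-type-specific merging explicit. It uses the standard SPN validity conditions: children of a sum node must share the same scope (completeness), and children of a product node must have disjoint scopes (decomposability). `MiniScope`, `merge_sum` and `merge_product` are hypothetical simplifications of the `Scope` class, and taking the union of evidence sets in product nodes is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Set


@dataclass(frozen=True)
class MiniScope:
    # hypothetical simplified stand-in for Scope: query and evidence variable indices
    query: FrozenSet[int]
    evidence: FrozenSet[int] = frozenset()

    def __repr__(self) -> str:
        return f"Scope({set(self.query)}|{set(self.evidence)})"


def merge_sum(scopes: List[MiniScope]) -> MiniScope:
    """Sum node: all children must have identical scopes (completeness)."""
    if not scopes:
        raise ValueError("Sum node requires at least one child.")
    first = scopes[0]
    if any(s != first for s in scopes[1:]):
        raise ValueError("Sum node children must have identical scopes.")
    return first


def merge_product(scopes: List[MiniScope]) -> MiniScope:
    """Product node: child query scopes must be pairwise disjoint (decomposability)."""
    query: Set[int] = set()
    for s in scopes:
        if query & s.query:
            raise ValueError("Product node children must have disjoint query scopes.")
        query |= s.query
    # assumption: evidence sets of the children are merged by union
    evidence = frozenset().union(*(s.evidence for s in scopes))
    return MiniScope(frozenset(query), evidence)


# mirrors the structure of CondNet: two sum nodes over rv 0 and rv 1, both conditioned on rv 2
s1 = merge_sum([MiniScope(frozenset({0}), frozenset({2}))] * 2)
s2 = merge_sum([MiniScope(frozenset({1}), frozenset({2}))] * 2)
print(merge_product([s1, s2]))
```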

In [58]:
data = np.random.randn(3,2)
likelihood(net, data)

array([[0.12215624],
       [0.1258246 ],
       [0.0033484 ]])
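The numbers above can be sanity-checked in closed form. With the `cond_f` from this example, every leaf is a standard normal and every sum node mixes two identical leaves with weight 0.5, so each sum node reduces to $\mathcal{N}(x; 0, 1)$ and the product node yields $\varphi(x_0)\,\varphi(x_1)$. This is bounded by $1/(2\pi) \approx 0.159$, which is consistent with the values printed above. The hypothetical helper below computes this reference directly with `scipy.stats.norm`.

```python
import numpy as np
from scipy.stats import norm


def reference_likelihood(data: np.ndarray) -> np.ndarray:
    """Closed-form likelihood of the example CondNet under the example cond_f."""
    # each sum node: 0.5 * phi(x) + 0.5 * phi(x) = phi(x)
    s1 = 0.5 * norm.pdf(data[:, 0]) + 0.5 * norm.pdf(data[:, 0])
    s2 = 0.5 * norm.pdf(data[:, 1]) + 0.5 * norm.pdf(data[:, 1])
    # product node: multiply the child likelihoods, keeping the (n, 1) output shape
    return (s1 * s2).reshape(-1, 1)
```

Under these assumptions, `np.allclose(likelihood(net, data), reference_likelihood(data))` should hold for the `data` drawn above.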

## TODOs/Open Questions

* Fix the known bug in `LeafLayer` sampling.
* Examples for optimized layers without explicit nodes.
* Structural marginalization
* Make sure that likelihoods are computed once at the very beginning and do not change afterwards (i.e., partially sampled data must not affect the likelihoods used in the rest of the sampling routine).
* Creating default sampling contexts, computing inputs (scopes, likelihoods), etc. could be handled via dedicated decorators for sampling, scope and likelihood computation.
* Sampling multiple non-disjoint outputs at the same time (e.g. different replicas in RAT-SPNs)
* Case distinction for scopes: merging should depend on the node type (e.g., children of a sum node must have identical scopes, whereas children of a product node must have disjoint scopes). So far, modules have to check scopes manually.
* Scope variable order: e.g., for `LeafNode` the order of the scope might be important, but it is stored as a `set` here. For example, a multivariate Gaussian could be specified via scope `[0,1]` or `[1,0]`, depending on the desired order. Since data columns are selected via the `set`, their order might differ from that of the specified `mean` and `std` values, which might not be what the user expects.
* Creating the sampling array during dispatch may have to take evidence scopes into account.
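The scope-order issue can be illustrated with NumPy's integer array indexing, which returns columns in exactly the order of the index list. The sketch below assumes the scope is kept as an ordered sequence (e.g. a tuple) alongside or instead of the `set`. `select_columns` is a hypothetical helper, and `sorted(...)` stands in for any canonical set-based order, since Python sets do not guarantee an iteration order.

```python
from typing import Sequence

import numpy as np


def select_columns(data: np.ndarray, scope: Sequence[int]) -> np.ndarray:
    """Select data columns in the order given by scope (hypothetical helper)."""
    scope = list(scope)
    if len(set(scope)) != len(scope):
        raise ValueError("Scope contains duplicate variable indices.")
    # integer array indexing preserves the user-specified order
    return data[:, scope]


data = np.array([[1.0, 2.0]])
# user specifies a bivariate Gaussian over variables in the order [1, 0]
print(select_columns(data, [1, 0]))  # [[2. 1.]]: matches the order of mean/std
print(data[:, sorted({1, 0})])       # [[1. 2.]]: set-based order loses the intended order
```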