# Developer Guide

This guide describes how to extend the SpFlow library, including adding custom leaf modules with new distributions, implementing variations of sum or product modules, and developing alternative utility modules such as new split modules.

## Leaf Modules

In this section, we demonstrate how to implement a new leaf module with a custom distribution. As an example, we use a simple toy distribution defined over the range `[-10, 10]`. It includes one trainable parameter, **border**, which divides the domain into "likely" and "unlikely" regions:

- values in the interval `[border, 10]` have probability **0.7**,  
- values in `[-10, border)` have probability **0.3**.

Although this is not a probabilistically valid distribution (e.g., its probabilities do not integrate to 1), it is sufficient to illustrate how to construct a custom leaf module.
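
To make the toy distribution concrete before wrapping it in a leaf module, here is a minimal plain-Python sketch that does not depend on SpFlow or PyTorch. The helper names `toy_log_prob` and `toy_sample` are hypothetical and exist only for this illustration.

```python
import math
import random

LOW, HIGH = -10.0, 10.0
P_UPPER, P_LOWER = 0.7, 0.3  # values taken from the description above


def toy_log_prob(x: float, border: float) -> float:
    """Log of the toy 'probability' of x; -inf outside [LOW, HIGH]."""
    if not LOW <= x <= HIGH:
        return -math.inf
    return math.log(P_UPPER if x >= border else P_LOWER)


def toy_sample(border: float, rng: random.Random) -> float:
    """Pick a region with probability 0.7 / 0.3, then draw uniformly inside it."""
    if rng.random() < P_UPPER:
        return rng.uniform(border, HIGH)
    return rng.uniform(LOW, border)


rng = random.Random(0)
border = 2.5
draws = [toy_sample(border, rng) for _ in range(10_000)]
upper_fraction = sum(d >= border for d in draws) / len(draws)
print(round(upper_fraction, 2))  # close to 0.7
```

The leaf module implemented next follows the same two-step idea (choose a region, then sample uniformly within it), but operates on tensors.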

The `CustomLeaf` class inherits from `LeafModule`, the base class for all leaf modules. Consequently, all initialization parameters required by `LeafModule` must also be included in the subclass constructor, in addition to any distribution-specific parameters. For a normal distribution, these parameters might be `loc` and `scale`; in our custom example, the required parameter is the trainable **border**. Once included, this parameter can be initialized just like any other PyTorch parameter.

The only methods that must be implemented manually are the helper functions used for **maximum likelihood estimation**. The base `LeafModule` handles all remaining functionality.

Custom leaf distributions should follow the structure of a standard PyTorch distribution. Therefore, if your distribution is not derived from an existing PyTorch distribution class, you must implement the distribution logic and all required methods yourself.


In [1]:
import torch
from torch import Tensor, nn

from spflow.modules.leaves.base import LeafModule
from spflow.utils.leaves import init_parameter, _handle_mle_edge_cases


class CustomLeaf(LeafModule):
    """Custom distribution leaf module.

    Parameterized by border.

    Attributes:
        border: Border that separates the value interval.
    """

    def __init__(
        self,
        scope,
        out_channels: int = None,
        num_repetitions: int = 1,
        parameter_fn: nn.Module = None,
        validate_args: bool | None = True,
        border: Tensor = None,
    ):
        """Initialize the custom leaf distribution.

        Args:
            scope: Variable scope (Scope, int, or list[int]).
            out_channels: Number of output channels (inferred from params if None).
            num_repetitions: Number of repetitions (for 3D event shapes).
            parameter_fn: Optional neural network for parameter generation.
            border: Border tensor that separates the value interval.
        """
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
            params=[border],
            parameter_fn=parameter_fn,
            validate_args=validate_args,
        )

        border = init_parameter(param=border, event_shape=self._event_shape, init=torch.randn) * 10
        border = torch.clamp(border, -10.0, 10.0)

        self.border = nn.Parameter(border)

    @property
    def _supported_value(self):
        return 0.0

    @property
    def _torch_distribution_class(self) -> "type[_CustomLeafDistribution]":
        return _CustomLeafDistribution

    def params(self):
        return {"border": self.border}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw parameter estimates for the custom distribution (without broadcasting).

        This toy example does not implement an actual estimator and returns the
        current border unchanged.

        Args:
            data: Input data tensor.
            weights: Weight tensor for each data point.
            bias_correction: Unused by this toy distribution.

        Returns:
            Dictionary with the 'border' estimate.
        """
        return {"border": self.border}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set the MLE-estimated parameters of the custom leaf.

        `border` is a plain nn.Parameter, so its `.data` attribute is updated directly.

        Args:
            params_dict: Dictionary with the 'border' parameter value.
        """
        self.border.data = params_dict["border"]

    def _mle_update_statistics(self, data: Tensor, weights: Tensor, bias_correction: bool) -> None:
        """Estimate the border from data and assign it to the leaf.

        Args:
            data: Input data tensor.
            weights: Weight tensor for each data point.
            bias_correction: Unused by this toy distribution.
        """
        estimates = self._compute_parameter_estimates(data, weights, bias_correction)

        # Broadcast to event_shape and assign directly
        self.border.data = self._broadcast_to_event_shape(estimates["border"])


class _CustomLeafDistribution:
    """Distribution object for the toy border distribution.

    Since PyTorch has no built-in equivalent, this class implements the
    methods needed for inference and sampling.
    """

    def __init__(self, border: torch.Tensor, validate_args: bool = True):
        # `border` is already a tensor owned by the leaf. Keep a reference instead of
        # re-wrapping it with torch.tensor(...), which copies it and raises a UserWarning.
        self.border = border
        self.event_shape = border.shape
        self.validate_args = validate_args

    def _clamped_border(self):
        # Ensure border stays inside the allowed interval
        return torch.clamp(self.border, -10.0, 10.0)

    def check_support(self, data: Tensor) -> Tensor:
        """Return an element-wise mask of the values that lie inside [-10, 10]."""
        return (data >= -10.0) & (data <= 10.0)

    @property
    def mode(self):
        """Return the mode of the distribution."""
        return self._clamped_border()

    def log_prob(self, data: torch.Tensor) -> torch.Tensor:
        """Return log(0.7) for values at or above the border and log(0.3) below it."""
        border = self._clamped_border()

        # Values in [border, 10] are "likely" (0.7); values in [-10, border) are "unlikely" (0.3).
        result = torch.where(
            data >= border,
            torch.log(torch.full_like(data, 0.7)),
            torch.log(torch.full_like(data, 0.3)),
        )

        if self.validate_args:
            # Values outside the support get zero probability.
            support_mask = self.check_support(data)
            result = result.masked_fill(~support_mask, float("-inf"))
        return result

    def sample(self, n_samples):
        """
        Draw `n_samples` samples.
        Returns shape: n_samples + event_shape
        Each value is drawn uniformly from [border, 10] with probability 0.7
        and uniformly from [-10, border) with probability 0.3.
        """
        if not isinstance(n_samples, tuple):
            n_samples = (n_samples,)

        # Prepare the tensor to store the samples
        sample_shape = n_samples + self.event_shape

        device = self.border.device
        region_mask = torch.rand(sample_shape, device=device) < 0.7
        border = self.border.expand(sample_shape)

        # Allocate the output tensor
        samples = torch.empty(sample_shape, device=device)

        # Sample each region element-wise
        # Upper region: U(border[i], 10)
        if region_mask.any():
            samples[region_mask] = (
                border[region_mask]
                + torch.rand_like(border[region_mask]) * (10 - border[region_mask])
            )

        # Lower region: U(-10, border[i])
        if (~region_mask).any():
            samples[~region_mask] = (
                -10
                + torch.rand_like(border[~region_mask]) * (border[~region_mask] + 10)
            )

        return samples

In [2]:
from spflow.meta import Scope
scope = Scope([0])
custom_leaf = CustomLeaf(scope=scope, out_channels=2)
samples = custom_leaf.sample(num_samples=10)
print(custom_leaf.border.data)
print(samples.shape)
print(samples)

data = -10 + 20 * torch.rand(5, 1)
ll = custom_leaf.log_likelihood(data)
print(data)
print(custom_leaf.border)
ll
custom_leaf.mode

tensor([[[-6.5437],
         [ 6.0723]]])
torch.Size([10, 1])
tensor([[-4.6417],
        [-2.4039],
        [ 1.2206],
        [-2.7092],
        [-6.9326],
        [ 2.9362],
        [ 7.5226],
        [ 7.3942],
        [ 0.7955],
        [-8.7497]], grad_fn=<IndexPutBackward0>)
tensor([[ 9.2375],
        [ 8.1393],
        [-6.0355],
        [ 1.0856],
        [-3.1894]])
Parameter containing:
tensor([[[-6.5437],
         [ 6.0723]]], requires_grad=True)




tensor([[[-6.5437],
         [ 6.0723]]], grad_fn=<ClampBackward1>)

## Intermediate Modules

Conceptually, an intermediate module operates on one or more input modules. In practice, every module in the library is designed to process exactly one input, so if you want to pass multiple modules as input, you must first concatenate them using the `Cat` module.

All modules must implement the abstract properties and methods defined in the base `Module` class.

### Required Methods

1. **log_likelihood**
2. **sample**
3. **marginalize**

### Required Properties

1. **out_features**
2. **out_channels**
3. **feature_to_scope**

Whether these properties can be inferred internally or must be provided explicitly at construction time depends on the module type.

For example, a **Sum** module typically requires additional parameters such as:
- number of output channels,
- number of repetitions,
- an optional weight tensor,
- the summation dimension.

A standard **Product** module, on the other hand, always has the same number of output channels as its input, so this value cannot (and does not need to) be specified manually.



---

## Example: Implementing a Custom `SquaredSum` Module

In this section we implement a custom `SquaredSum` module. Conceptually, it behaves like a standard sum module, except that it squares its inputs before summing them.

The `SquaredSum` class inherits from the base `Module` class (although inheriting from the existing `Sum` module would also be possible). For clarity, we use the base class as our starting point.

Because this is a sum-like module, we introduce a weight tensor. The most common shape for such weights is:

`(out_features, in_channels, out_channels, num_repetitions)`


These weights are trainable parameters and therefore must be registered. To avoid unwanted behavior, the module performs several checks to ensure that:

- the weight tensor has the correct shape,
- all weights are positive,
- and the weights are normalized to sum to 1.
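
As a rough illustration of these checks, the following plain-Python sketch validates a nested-list weight structure for a single repetition. The function `check_weights` and the list layout `weights[feature][in_channel][out_channel]` are assumptions made for this example; the actual module performs the equivalent checks on tensors in its `weights` setter.

```python
import math


def check_weights(weights, tol=1e-6):
    """Validate weights[feature][in_channel][out_channel] for a single repetition."""
    n_in = len(weights[0])
    n_out = len(weights[0][0])
    for f, per_feature in enumerate(weights):
        # Shape check: every feature needs the same (in_channels, out_channels) block.
        if len(per_feature) != n_in or any(len(row) != n_out for row in per_feature):
            raise ValueError(f"Ragged weight shape at feature {f}.")
        for o in range(n_out):
            column = [per_feature[i][o] for i in range(n_in)]
            if any(w <= 0 for w in column):
                raise ValueError("Weights must be all positive.")
            # Normalization is checked over the input channels (the summation dimension).
            if not math.isclose(sum(column), 1.0, abs_tol=tol):
                raise ValueError("Weights must sum to one over the input channels.")


# One feature, two input channels, two output channels.
good = [[[0.25, 0.9], [0.75, 0.1]]]
check_weights(good)  # passes silently

bad = [[[0.5, 0.9], [0.75, 0.1]]]  # first output channel sums to 1.25
try:
    check_weights(bad)
except ValueError as err:
    print(err)
```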

---

## `log_likelihood`

The `log_likelihood` method begins by initializing the module's cache, which stores computed log-likelihoods for later use (e.g., during sampling).

The log-likelihood produced by any module always has the shape:

`(batch_size, out_features, out_channels, num_repetitions)`


You can rely on this shape when accessing the output of the input module.

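For the `SquaredSum` module, the computation is a weighted sum of the squared input likelihoods, carried out in log space: each input log-likelihood is doubled (which corresponds to squaring the likelihood), the log-weights are added, and the result is reduced with `logsumexp` over the input channels. The returned tensor must again match the standard shape:

`(batch_size, out_features, out_channels, num_repetitions)`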
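
The log-space computation for a single feature and a single output channel can be sketched in plain Python as follows. The helpers `logsumexp` and `squared_sum_ll` are hypothetical names used only here; the module itself performs the same arithmetic on tensors with `torch.logsumexp`.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def squared_sum_ll(input_lls, log_weights):
    """log(sum_i w_i * p_i**2), given log p_i and log w_i for one output channel."""
    # Doubling a log-likelihood squares the likelihood: log(p**2) = 2 * log(p).
    return logsumexp([lw + 2.0 * ll for ll, lw in zip(input_lls, log_weights)])


probs = [0.5, 0.2]      # input likelihoods p_i
weights = [0.25, 0.75]  # mixture weights w_i
result = squared_sum_ll([math.log(p) for p in probs], [math.log(w) for w in weights])

# Reference value computed directly in probability space.
direct = math.log(sum(w * p**2 for w, p in zip(weights, probs)))
print(math.isclose(result, direct))  # True
```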


---

## `sample`

Sampling in this library relies on a **sampling context**.  
This context is propagated from the root module toward the leaf modules and contains, for each feature, the selected **channel** and **repetition** indices. Each module must implement its own logic for mapping the channel index in the context to the channel index expected by its input module.

The repetition index is chosen at the root and remains unchanged throughout propagation.

For the `SquaredSum` module, the sampling procedure works as follows:

1. Extract the weight slice associated with the repetition index stored in the context.  
2. Check whether log-likelihoods are available in the cache for conditional sampling.  
3. If performing MPE sampling, select the channel indices with the highest logits.  
4. Otherwise, draw channel indices using a categorical distribution over the logits.  
5. Update the sampling context with the newly selected indices.  
6. Pass the updated context to the input module to continue the sampling process.
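
Steps 2 to 4 can be sketched for a single feature and a single output channel in plain Python. The function `choose_input_channel` is a hypothetical helper for illustration only; it assumes the logits have already been sliced by repetition and by the channel index received from the parent.

```python
import math
import random


def log_softmax(values):
    """Normalize log-scores so that their exponentials sum to one."""
    m = max(values)
    log_norm = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_norm for v in values]


def choose_input_channel(logits, input_lls=None, is_mpe=False, rng=None):
    """Pick an input-channel index from (optionally evidence-adjusted) logits."""
    if input_lls is not None:
        # Conditional sampling: posterior is proportional to prior weight times likelihood.
        logits = [lw + ll for lw, ll in zip(logits, input_lls)]
    log_probs = log_softmax(logits)
    if is_mpe:
        # MPE: deterministically take the most probable input channel.
        return max(range(len(log_probs)), key=log_probs.__getitem__)
    rng = rng or random.Random()
    probs = [math.exp(lp) for lp in log_probs]
    return rng.choices(range(len(probs)), weights=probs)[0]


logits = [math.log(0.7), math.log(0.3)]
print(choose_input_channel(logits, is_mpe=True))  # 0
# Evidence that strongly favors the second input flips the MPE choice.
print(choose_input_channel(logits, input_lls=[-5.0, -0.1], is_mpe=True))  # 1
```

The selected index is what the module writes back into the sampling context before delegating to its input.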

---

## `marginalize`

The `marginalize` method structurally marginalizes the module by removing the specified random variables from the layer. This process involves not only updating the scope but also adapting any parameters that depend on the number of features in the layer. If the entire scope is marginalized, the method is expected to return `None`. Otherwise, it should return a new module with updated parameters and marginalized input modules.

In the `SquaredSum` layer, we begin by calling `marginalize` on the input module to obtain its marginalized version. Next, we marginalize the sum layer itself using the `feature_to_scope` property, which enables us to remove the appropriate feature column from the weight matrix. We then construct a new `SquaredSum` module using the marginalized input module and the updated weight matrix as the layer’s weights. From these updated weights, the new scope can be derived as described earlier. Finally, we return the marginalized `SquaredSum` layer.
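
The weight-masking step can be illustrated with a small plain-Python sketch. Here `marginalize_weights` is a hypothetical helper, scopes are modeled as sets of variable indices, and each feature carries one opaque block of weights; the real implementation applies a boolean mask to the weight tensor per repetition.

```python
def marginalize_weights(weights, feature_to_scope, marg_rvs):
    """Drop the weight rows of features whose scope is fully marginalized.

    weights: list with one weight block per feature.
    feature_to_scope: list of sets, the random variables covered by each feature.
    Returns (kept_weights, kept_scopes), or None if nothing is left.
    """
    marg = set(marg_rvs)
    kept = [
        (w, scope - marg)
        for w, scope in zip(weights, feature_to_scope)
        if scope - marg  # keep features that still cover at least one variable
    ]
    if not kept:
        return None
    kept_weights, kept_scopes = zip(*kept)
    return list(kept_weights), list(kept_scopes)


weights = [[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]]  # one row per feature
scopes = [{0}, {1}, {2}]

print(marginalize_weights(weights, scopes, marg_rvs=[1]))
# ([[0.2, 0.8], [0.9, 0.1]], [{0}, {2}])
print(marginalize_weights(weights, scopes, marg_rvs=[0, 1, 2]))  # None
```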



In [3]:
from __future__ import annotations

import numpy as np
import torch
from torch import Tensor

from spflow.exceptions import InvalidParameterCombinationError
from spflow.modules.base import Module
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import (
    proj_convex_to_real,
)
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context


class SquaredSum(Module):
    """Sum-like module that squares its input likelihoods before mixing them.

    Behaves like a standard sum module, computing weighted combinations of child
    distributions, except that each input likelihood is squared first.
    Weights are normalized to sum to one.
    Supports both single input (mixture over channels) and multiple inputs (mixture
    over concatenated inputs).

    Attributes:
        inputs (Module): Input module(s) to the sum node.
        sum_dim (int): Dimension over which to sum the inputs.
        weights (Tensor): Normalized weights for mixture components.
        logits (Parameter): Unnormalized log-weights for gradient optimization.
    """

    def __init__(
        self,
        inputs: Module | list[Module],
        out_channels: int | None = None,
        num_repetitions: int = 1,
        weights: Tensor | list[float] | None = None,
        sum_dim: int | None = 1,
    ) -> None:
        """Create a Sum module for mixture modeling.

        Weights are automatically normalized to sum to one using softmax.
        Multiple inputs are concatenated along dimension 2 internally.

        Args:
            inputs (Module | list[Module]): Single module or list of modules to mix.
            out_channels (int | None, optional): Number of output mixture components.
                Required if weights not provided.
            num_repetitions (int | None, optional): Number of repetitions for structured
                representations. Inferred from weights if not provided.
            weights (Tensor | list[float] | None, optional): Initial mixture weights.
                Must have compatible shape with inputs and out_channels.
            sum_dim (int | None, optional): Dimension over which to sum inputs. Default is 1.

        Raises:
            ValueError: If inputs empty, out_channels < 1, or weights have invalid shape/values.
            InvalidParameterCombinationError: If both out_channels and weights are specified.
        """
        super().__init__()

        # ========== 1. INPUT VALIDATION ==========
        if not inputs:
            raise ValueError("'Sum' requires at least one input to be specified.")

        # Convert weights from list to tensor if needed
        if weights is not None and isinstance(weights, list):
            weights = torch.tensor(weights)

        # ========== 2. WEIGHTS PARAMETER PROCESSING ==========
        if weights is not None:
            # Validate mutual exclusivity with out_channels
            if out_channels is not None:
                raise InvalidParameterCombinationError(
                    f"Cannot specify both 'out_channels' and 'weights' for 'Sum' module."
                )

            # Validate num_repetitions compatibility
            if num_repetitions is not None and (num_repetitions != 1 and num_repetitions != weights.shape[-1]):
                raise InvalidParameterCombinationError(
                    f"Cannot specify 'num_repetitions' that does not match weights shape for 'Sum' module. "
                    f"Was {num_repetitions} but weights shape indicates {weights.shape[-1]}."
                )

            # Reshape weights to canonical 4D form: (out_features, in_channels, out_channels, num_repetitions)
            weight_dim = weights.dim()
            if weight_dim == 1:
                weights = weights.view(1, -1, 1, 1)
            elif weight_dim == 2:
                weights = weights.view(1, weights.shape[0], weights.shape[1], 1)
            elif weight_dim == 3:
                weights = weights.unsqueeze(-1)  # Add the repetition dimension
            elif weight_dim == 4:
                pass  # Already 4D
            else:
                raise ValueError(
                    f"Weights for 'Sum' must be a 1D, 2D, 3D, or 4D tensor but was {weight_dim}D."
                )

            # Derive configuration from weights shape
            out_channels = weights.shape[2]
            num_repetitions = weights.shape[3]

        # ========== 3. CONFIGURATION VALIDATION ==========
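        if out_channels is None:
            raise ValueError("Either 'out_channels' or 'weights' must be specified for 'Sum' module.")
        if out_channels < 1:
            raise ValueError(
                f"Number of nodes for 'Sum' must be greater than or equal to 1 but was {out_channels}."
            )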

        # Validate sum_dim compatibility with weights dimensionality
        if weights is not None:
            max_sum_dim = weights.dim() - 1
            if sum_dim > max_sum_dim:
                raise ValueError(
                    f"When providing {weights.dim()}D weights, 'sum_dim' must be at most {max_sum_dim} but was {sum_dim}."
                )

        # ========== 4. INPUT MODULE SETUP ==========
        if isinstance(inputs, list):
            if len(inputs) == 1:
                self.inputs = inputs[0]
            else:
                self.inputs = Cat(inputs=inputs, dim=2)
        else:
            self.inputs = inputs

        self.sum_dim = sum_dim

        # ========== 5. ATTRIBUTE INITIALIZATION ==========
        self._out_features = self.inputs.out_features
        self._in_channels_total = self.inputs.out_channels
        self._out_channels_total = out_channels
        self.num_repetitions = num_repetitions

        self.weights_shape = (
            self._out_features,
            self._in_channels_total,
            self._out_channels_total,
            self.num_repetitions,
        )

        self.scope = self.inputs.scope

        # ========== 6. WEIGHT INITIALIZATION & PARAMETER REGISTRATION ==========
        if weights is None:
            # Initialize weights randomly with small epsilon to avoid zeros
            weights = torch.rand(self.weights_shape) + 1e-08
            # Normalize to sum to one along sum_dim
            weights /= torch.sum(weights, dim=self.sum_dim, keepdims=True)

        # Register parameter for unnormalized log-probabilities
        self.logits = torch.nn.Parameter()

        # Set weights (converts to logits internally via property setter)
        self.weights = weights

    @property
    def feature_to_scope(self) -> np.ndarray:
        return self.inputs.feature_to_scope

    @property
    def out_features(self) -> int:
        return self._out_features

    @property
    def out_channels(self) -> int:
        return self._out_channels_total

    @property
    def log_weights(self) -> Tensor:
        """Returns the log weights of all nodes as a tensor.

        Returns:
            Tensor: Log weights normalized to sum to one.
        """
        # project auxiliary weights onto weights that sum up to one
        return torch.nn.functional.log_softmax(self.logits, dim=self.sum_dim)

    @property
    def weights(self) -> Tensor:
        """Returns the weights of all nodes as a tensor.

        Returns:
            Tensor: Weights normalized to sum to one.
        """
        # project auxiliary weights onto weights that sum up to one
        return torch.nn.functional.softmax(self.logits, dim=self.sum_dim)

    @weights.setter
    def weights(
        self,
        values: Tensor,
    ) -> None:
        """Set weights of all nodes.

        Args:
            values: Tensor containing weights for each input and node.

        Raises:
            ValueError: If weights have invalid shape, contain non-positive values,
                or do not sum to one.
        """
        if values.shape != self.weights_shape:
            raise ValueError(
                f"Invalid shape for weights: Was {values.shape} but expected {self.weights_shape}."
            )
        if not torch.all(values > 0):
            raise ValueError("Weights for 'Sum' must be all positive.")
        if not torch.allclose(torch.sum(values, dim=self.sum_dim), torch.tensor(1.0)):
            raise ValueError("Weights for 'Sum' must sum up to one.")
        self.logits.data = proj_convex_to_real(values)

    @log_weights.setter
    def log_weights(
        self,
        values: Tensor,
    ) -> None:
        """Set log weights of all nodes.

        Args:
            values: Tensor containing log weights for each input and node.

        Raises:
            ValueError: If log weights have invalid shape.
        """
        if values.shape != self.log_weights.shape:
            raise ValueError(f"Invalid shape for weights: {values.shape}.")
        self.logits.data = values

    def extra_repr(self) -> str:
        return f"{super().extra_repr()}, weights={self.weights_shape}"

    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood P(data | module).

        Computes log likelihood using logsumexp for numerical stability.
        Results are cached for parameter learning algorithms.

        Args:
            data: Input data of shape (batch_size, num_features).
                NaN values indicate evidence for conditional computation.
            cache: Cache for intermediate computations. Defaults to None.

        Returns:
            Tensor: Log-likelihood of shape (batch_size, num_features, out_channels)
                or (batch_size, num_features, out_channels, num_repetitions).
        """
        if cache is None:
            cache = Cache()

        # Get input log-likelihoods
        ll = self.inputs.log_likelihood(
            data,
            cache=cache,
        )

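        # Add an axis for this module's output channels: (B, F, IC, 1, R)
        ll = ll.unsqueeze(3)

        # Squaring a likelihood corresponds to doubling its log: log(p^2) = 2 * log(p)
        squared_ll = ll * 2  # shape: (B, F, IC, 1, R)

        log_weights = self.log_weights.unsqueeze(0)  # shape: (1, F, IC, OC, R)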

        # Weighted log-likelihoods
        weighted_lls = squared_ll + log_weights  # shape: (B, F, IC, OC, R)

        # Sum over input channels (sum_dim + 1 since here the batch dimension is the first dimension)
        output = torch.logsumexp(weighted_lls, dim=self.sum_dim + 1)

        batch_size = output.shape[0]
        result = output.view(batch_size, self.out_features, self.out_channels, self.num_repetitions)

        return result

    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> Tensor:
        """Generate samples from sum module.

        Args:
            num_samples: Number of samples to generate.
            data: Data tensor with NaN values to fill with samples.
            is_mpe: Whether to compute the most probable explanation (MPE) instead of drawing random samples.
            cache: Optional cache dictionary.
            sampling_ctx: Optional sampling context.

        Returns:
            Tensor: Sampled values.
        """
        if cache is None:
            cache = Cache()

        # Handle num_samples case (create empty data tensor)
        if data is None:
            if num_samples is None:
                num_samples = 1
            data = torch.full((num_samples, len(self.scope.query)), float("nan")).to(self.device)

        # Initialize sampling context if not provided
        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)

        # Index into the correct weight channels given by parent module
        if sampling_ctx.repetition_idx is not None:
            logits = self.logits.unsqueeze(0).expand(
                sampling_ctx.channel_index.shape[0], -1, -1, -1, -1
            )  # shape [b , n_features , in_c, out_c, r]

            indices = sampling_ctx.repetition_idx  # One repetition index per sample

            # Use gather to select the correct repetition
            # Repeat indices to match the target dimension for gathering
            in_channels_total = logits.shape[2]
            indices = indices.view(-1, 1, 1, 1, 1).expand(
                -1, logits.shape[1], in_channels_total, logits.shape[3], -1
            )
            # Gather the logits based on the repetition indices
            logits = torch.gather(logits, dim=-1, index=indices).squeeze(-1)

        else:
            if self.num_repetitions > 1:
                raise ValueError(
                    "sampling_ctx.repetition_idx must be provided when sampling from a module with "
                    "num_repetitions > 1."
                )
            logits = self.logits[..., 0] # Select the 0th repetition
            logits = logits.unsqueeze(0) # Make space for the batch

            # Expand to batch size
            logits = logits.expand(sampling_ctx.channel_index.shape[0], -1, -1, -1)

        idxs = sampling_ctx.channel_index[..., None, None]
        in_channels_total = logits.shape[2]
        idxs = idxs.expand(-1, -1, in_channels_total, -1)
        # Gather the logits based on the channel indices
        logits = logits.gather(dim=3, index=idxs).squeeze(3)

        # Check if evidence is given (cached log-likelihoods)
        if (
            cache is not None
            and "log_likelihood" in cache
            and cache["log_likelihood"].get(self.inputs) is not None
        ):
            # Get the log likelihoods from the cache
            input_lls = cache["log_likelihood"][self.inputs]

            if sampling_ctx.repetition_idx is not None:
                indices = sampling_ctx.repetition_idx.view(-1, 1, 1, 1).expand(
                    -1, input_lls.shape[1], input_lls.shape[2], -1
                )

                # Use gather to select the correct repetition
                input_lls = torch.gather(input_lls, dim=-1, index=indices).squeeze(-1)
            else:
                # Single repetition: drop the trailing repetition dimension
                input_lls = input_lls[..., 0]

            # Combine prior (weights) and input likelihoods into a posterior over input channels
            log_prior = logits
            log_posterior = log_prior + input_lls
            logits = log_posterior.log_softmax(dim=2)

        # Sample from categorical distribution defined by weights to obtain indices into input channels
        if is_mpe:
            # Take the argmax of the logits to obtain the most probable index
            sampling_ctx.channel_index = torch.argmax(logits, dim=-1)
        else:
            # Sample from categorical distribution defined by weights to obtain indices into input channels
            sampling_ctx.channel_index = torch.distributions.Categorical(logits=logits).sample()

        # Sample from input module
        self.inputs.sample(
            data=data,
            is_mpe=is_mpe,
            cache=cache,
            sampling_ctx=sampling_ctx,
        )

        return data

    def expectation_maximization(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> None:
        """Perform expectation-maximization step.

        Args:
            data: Input data tensor.
            cache: Optional cache dictionary with log-likelihoods.

        Raises:
            ValueError: If required log-likelihoods are not found in cache.
        """
        if cache is None:
            cache = Cache()

        with torch.no_grad():
            # ----- expectation step -----

            # Get input LLs from cache
            input_lls = cache["log_likelihood"].get(self.inputs)
            if input_lls is None:
                raise ValueError("Input log-likelihoods not found in cache. Call log_likelihood first.")

            # Get module lls from cache
            module_lls = cache["log_likelihood"].get(self)
            if module_lls is None:
                raise ValueError("Module log-likelihoods not found in cache. Call log_likelihood first.")

            log_weights = self.log_weights.unsqueeze(0)
            log_grads = torch.log(module_lls.grad).unsqueeze(2)
            input_lls = input_lls.unsqueeze(3)
            module_lls = module_lls.unsqueeze(2)

            log_expectations = log_weights + log_grads + input_lls - module_lls
            log_expectations = log_expectations.logsumexp(0)  # Sum over batch dimension
            log_expectations = log_expectations.log_softmax(self.sum_dim)  # Normalize

            # ----- maximization step -----
            self.log_weights = log_expectations

        # Recursively call EM on inputs
        self.inputs.expectation_maximization(data, cache=cache)

    def maximum_likelihood_estimation(
        self,
        data: Tensor,
        weights: Tensor | None = None,
        cache: Cache | None = None,
    ) -> None:
        """Update parameters via maximum likelihood estimation.

        For Sum modules, this is equivalent to EM.

        Args:
            data: Input data tensor.
            weights: Optional sample weights (currently unused).
            cache: Optional cache dictionary.
        """
        self.expectation_maximization(data, cache=cache)

    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> SquaredSum | None:
        """Marginalize out specified random variables.

        Args:
            marg_rvs: List of random variables to marginalize.
            prune: Whether to prune the module.
            cache: Optional cache dictionary.

        Returns:
            Marginalized Sum module or None.
        """
        if cache is None:
            cache = Cache()

        # compute module scope (same for all outputs)
        module_scope = self.scope
        marg_input = None

        mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))
        module_weights = self.weights

        # module scope is being fully marginalized over
        if len(mutual_rvs) == len(module_scope.query):
            # passing this loop means marginalizing over the whole scope of this branch
            return None

        # node scope is being partially marginalized
        elif mutual_rvs:
            # marginalize input modules
            marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)

            # if marginalized input is not None
            if marg_input:
                # Apply mask to weights per-repetition
                masked_weights_list = []
                for r in range(self.num_repetitions):
                    feature_to_scope_r = self.inputs.feature_to_scope[:, r].copy()
                    # remove mutual_rvs from feature_to_scope list
                    for rv in mutual_rvs:
                        for idx, scope in enumerate(feature_to_scope_r):
                            if scope is not None:
                                if rv in scope.query:
                                    feature_to_scope_r[idx] = scope.remove_from_query(rv)

                    # construct mask with empty scopes
                    mask = torch.tensor([not scope.empty() for scope in feature_to_scope_r], device=self.device).bool()

                    # Apply mask to weights for this repetition: (out_features, in_channels, out_channels)
                    masked_weights_r = module_weights[:, :, :, r][mask]
                    masked_weights_list.append(masked_weights_r)

                # Stack weights back along the repetition dimension
                # Handle different repetition counts if needed
                if all(w.shape[0] == masked_weights_list[0].shape[0] for w in masked_weights_list):
                    # All repetitions have same number of features, can stack directly
                    module_weights = torch.stack(masked_weights_list, dim=-1)
                else:
                    # Features differ across repetitions - this shouldn't happen in practice,
                    # but handle it gracefully by zero-padding to the largest feature count
                    max_features = max(w.shape[0] for w in masked_weights_list)
                    padded_list = []
                    for w in masked_weights_list:
                        if w.shape[0] < max_features:
                            padding = torch.zeros(
                                max_features - w.shape[0], w.shape[1], w.shape[2],
                                device=w.device, dtype=w.dtype
                            )
                            w = torch.cat([w, padding], dim=0)
                        padded_list.append(w)
                    module_weights = torch.stack(padded_list, dim=-1)

        else:
            marg_input = self.inputs

        if marg_input is None:
            return None

        else:
            return SquaredSum(inputs=marg_input, weights=module_weights, sum_dim=self.sum_dim)
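
The weight-masking step in `marginalize` can be illustrated in isolation. The following is a simplified sketch in plain Python, not SpFlow code: scopes are modelled as sets of random-variable indices rather than `Scope` objects, weights are modelled as one row per feature, and the helper name `mask_weights_after_marginalization` is hypothetical. Features whose scope becomes empty after removing the marginalized variables are dropped together with their weight rows.

```python
def mask_weights_after_marginalization(feature_scopes, weights, marg_rvs):
    """Drop features (and their weight rows) whose scope is fully marginalized.

    Assumption: each scope is a plain set of random-variable indices and
    `weights` holds one row per feature. This mirrors the idea, not the API.
    """
    marg = set(marg_rvs)
    # Remove the marginalized variables from every feature scope.
    remaining = [scope - marg for scope in feature_scopes]
    # A feature survives only if its scope is still non-empty.
    mask = [len(scope) > 0 for scope in remaining]
    kept_scopes = [s for s, keep in zip(remaining, mask) if keep]
    kept_weights = [w for w, keep in zip(weights, mask) if keep]
    return kept_scopes, kept_weights


feature_scopes = [{0}, {1}, {2, 3}]
weights = [[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]]

kept_scopes, kept_weights = mask_weights_after_marginalization(
    feature_scopes, weights, marg_rvs=[1, 3]
)
print(kept_scopes)   # [{0}, {2}]
print(kept_weights)  # [[0.2, 0.8], [0.9, 0.1]]
```

Feature 1 disappears because its only variable is marginalized, while feature 2 keeps its weights and merely loses variable 3 from its scope.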

## Split Modules

The final module type covered in this guide is the **Split module**, a special kind of intermediate module. A Split module receives a single module as input and produces multiple outputs by splitting it in a defined way. This is needed whenever a module requires multiple inputs, for example the elementwise product module, which computes the elementwise product of two or more input modules.

Since there are many possible ways to split a module, this section shows how to create your own Split module tailored to a specific splitting strategy.

Split modules inherit from the base `Split` class, where `out_features` and `out_channels` are already defined. Conceptually, a Split module represents a different *view* of an existing module, so `out_features` and `out_channels` always match those of the input module. You must, however, implement the mapping from feature index to scope in the `feature_to_scope` property. Because a Split module only provides an alternative view, the only method that must be implemented is `log_likelihood`.

A Split module takes:
- a single input module,
- the dimension along which the split should occur,
- and the number of splits, which determines how many groups the input module should be divided into.

In the example below, we implement a **pairwise alternating split**. For instance, if we split along the feature dimension into two groups and the features are:

`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]`

then the resulting split would be:

- Group 1: `[0, 1, 4, 5, 8, 9]`  
- Group 2: `[2, 3, 6, 7]`
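
The grouping rule can be checked without any library code. The following is a minimal sketch in plain Python; the helper name `pairwise_alternating_groups` is hypothetical and not part of SpFlow. It assigns feature `i` to group `(i // 2) % num_splits`, so consecutive pairs of features are dealt out to the groups in turn.

```python
def pairwise_alternating_groups(num_features: int, num_splits: int) -> list[list[int]]:
    """Group feature indices pairwise.

    Pair k consists of features 2k and 2k + 1 and is assigned to
    group k % num_splits.
    """
    groups: list[list[int]] = [[] for _ in range(num_splits)]
    for i in range(num_features):
        groups[(i // 2) % num_splits].append(i)
    return groups


groups = pairwise_alternating_groups(num_features=10, num_splits=2)
print(groups[0])  # [0, 1, 4, 5, 8, 9]
print(groups[1])  # [2, 3, 6, 7]
```

With an odd number of features the last pair is incomplete, so the groups can differ in size by more than one feature.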

To keep the evaluation of `log_likelihood` efficient, the two most common cases (2 and 3 splits) are hard-coded. Any other number of splits falls back on boolean masks that are constructed once during initialization.
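
As a rough illustration of the mask-based approach, the sketch below builds one boolean mask per split and applies it to a row of values. It uses plain Python lists instead of tensors, and the helper names `build_split_masks` and `apply_mask` are hypothetical. Because every feature belongs to exactly one pair, the masks partition the features: each feature is selected by exactly one mask.

```python
def build_split_masks(num_features: int, num_splits: int) -> list[list[bool]]:
    """Build one boolean mask per split using the pairwise alternating rule."""
    pair_ids = [(i // 2) % num_splits for i in range(num_features)]
    return [[pid == s for pid in pair_ids] for s in range(num_splits)]


def apply_mask(values: list, mask: list[bool]) -> list:
    """Keep only the entries of `values` whose mask entry is True."""
    return [v for v, keep in zip(values, mask) if keep]


# Masks are built once and can be reused for every row of values.
masks = build_split_masks(num_features=7, num_splits=3)

row = [10, 11, 12, 13, 14, 15, 16]
parts = [apply_mask(row, m) for m in masks]
print(parts)  # [[10, 11, 16], [12, 13], [14, 15]]
```

Precomputing the masks moves the index arithmetic out of the evaluation path, which is the same motivation as building `split_masks` in the constructor.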


In [4]:
"""Alternating splitting operation for tensor partitioning.

Distributes features in an alternating pattern across splits using modulo
arithmetic. Promotes feature diversity across branches. Used in RAT-SPN
and similar architectures.
"""

from __future__ import annotations

import torch
from torch import Tensor

from spflow.meta.data import Scope
from spflow.modules.base import Module
from spflow.modules.ops.split import Split
from spflow.utils.cache import Cache, cached


class SplitAlternatePairwise(Split):
    """Split operation using pairwise alternating feature distribution.

    Distributes features pairwise: features (0,1) go to split 0, (2,3) go to split 1,
    (4,5) go to split 2, then repeating modulo num_splits.
    """

    def __init__(self, inputs: Module, dim: int = 1, num_splits: int | None = 2):
        super().__init__(inputs=inputs, dim=dim, num_splits=num_splits)

        num_f = inputs.out_features
        device = inputs.device

        # Pairwise alternating index assignment:
        # For feature i, compute which pair it's in (i // 2), then modulo num_splits
        pair_ids = (torch.arange(num_f, device=device) // 2) % num_splits

        # Create masks for each split
        self.split_masks = [pair_ids == i for i in range(num_splits)]

    def extra_repr(self) -> str:
        return f"{super().extra_repr()}, dim={self.dim}"

    @property
    def feature_to_scope(self) -> list[list[Scope]]:
        """Return, for each split, the scopes of the features assigned to it."""
        scopes = self.inputs[0].feature_to_scope
        feature_to_scope = []

        # pairwise selection: (0,1)->split0, (2,3)->split1, ...
        for i in range(self.num_splits):
            sub_scopes = [scopes[j] for j in range(len(scopes))
                          if ((j // 2) % self.num_splits) == i]
            feature_to_scope.append(sub_scopes)

        return feature_to_scope

    def _apply(self, fn):
        super()._apply(fn)
        self.split_masks = [fn(mask) for mask in self.split_masks]
        return self

    @cached
    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> list[Tensor]:
        lls = self.inputs[0].log_likelihood(data, cache=cache)

        if self.num_splits == 1:
            return [lls]

        # Hard-code pairwise versions for common cases
        if self.num_splits == 2:
            # pairs: (0,1)->0, (2,3)->1, ...
            return [
                lls[:, [i for i in range(lls.shape[1]) if (i // 2) % 2 == 0], ...],
                lls[:, [i for i in range(lls.shape[1]) if (i // 2) % 2 == 1], ...]
            ]

        elif self.num_splits == 3:
            return [
                lls[:, [i for i in range(lls.shape[1]) if (i // 2) % 3 == 0], ...],
                lls[:, [i for i in range(lls.shape[1]) if (i // 2) % 3 == 1], ...],
                lls[:, [i for i in range(lls.shape[1]) if (i // 2) % 3 == 2], ...]
            ]

        # General fallback: use masks computed in __init__
        return [lls[:, mask, ...] for mask in self.split_masks]


In [10]:
from tests.utils.leaves import make_normal_leaf, make_normal_data
out_features = 5
out_channels = 3
num_reps = 1
num_splits = 2
scope = Scope(list(range(0, out_features)))

inputs_a = make_normal_leaf(scope, out_channels=out_channels, num_repetitions=num_reps)

module = SplitAlternatePairwise(inputs=inputs_a, num_splits=num_splits, dim=1)

data = make_normal_data(out_features=module.out_features, num_samples=1)
lls = module.log_likelihood(data)

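# Scopes of the features assigned to each split:
# split 0 receives features (0, 1, 4), split 1 receives features (2, 3).
print(module.feature_to_scope[0])
print(module.feature_to_scope[1])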



[array([0], dtype=object), array([1], dtype=object), array([4], dtype=object)]
[array([2], dtype=object), array([3], dtype=object)]
