In [None]:
# REFERENCES and further reading:

# Di Mauro, Gala et al. Random Probabilistic Circuits. UAI (2021).
# Peharz et al. Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits. ICML (2020). 
# Peharz et al. Probabilistic Deep Learning using Random Sum-Product Networks. UAI (2020). 
# Van de Wolfshaar and Pronobis. Deep Generalized Convolutional Sum-Product Networks for Probabilistic Image Representations. PGM (2020).
# Molina, Vergari et al. SPFlow: An Easy and Extensible Library for Deep Probabilistic Learning using Sum-Product Networks. CoRR (2019).
# Molina, Vergari et al. Mixed Sum-Product Networks: A Deep Architecture for Hybrid Domains. AAAI (2018).
# Di Mauro et al. Sum-Product Network structure learning by efficient product nodes discovery. AIxIA (2018).
# Papamakarios et al. Masked Autoregressive Flow for Density Estimation. NeurIPS (2017). 
# Dinh et al. Density Estimation using Real NVP. ICLR (2017).
# Desana and Schnörr. Learning Arbitrary Sum-Product Network Leaves with Expectation-Maximization. CoRR (2016). 
# Peharz et al. On Theoretical Properties of Sum-Product Networks. AISTATS (2015). 
# Dinh et al. NICE: Non-linear Independent Components Estimation. ICLR (2015). 
# Rahman et al. Cutset Networks: A Simple, Tractable, and Scalable Approach for Improving the Accuracy of Chow-Liu Trees. ECML-PKDD (2014).
# Poon and Domingos. Sum-Product Networks: A New Deep Architecture. UAI (2011).
# https://github.com/deeprob-org

## For Basic Knowledge
JAX and TensorFlow:
- JAX: a whole [GitHub organization](https://github.com/probml) is dedicated to it, with [3 books](https://probml.github.io/pml-book/) (book 0 is available [here](http://noiselab.ucsd.edu/ECE228/Murphy_Machine_Learning.pdf)), [other books](https://www.ics.uci.edu/~smyth/courses/cs274/books.html) and [lectures with exercises](https://www.ics.uci.edu/~smyth/courses/cs274/), [reading groups using those books](https://www.youtube.com/watch?v=FDTXBaMazNg&list=PLOk2cpmAEiU3YgtHRUm58zGkw66nF2NLZ&index=1), [another playlist](https://www.youtube.com/watch?v=1vtkeR5yieo&list=PLmp4AHm0u1g3xuIHtrT37yOZCj51lWqic), and [chapter implementations of the same content (see book2 and tutorials)](https://github.com/probml/pyprobml/tree/master/notebooks), also available as [direct Colab versions](https://github.com/probml/pyprobml/blob/auto_notebooks_md/notebooks.md#gauss_plot_2d.ipynb)
- JAX examples: [VAE and more](https://github.com/google/jax/blob/main/examples/mnist_vae.py)
- TensorFlow Probability: https://www.tensorflow.org/probability
  see also the [overview](https://www.youtube.com/watch?v=BrwKURU-wpk) video and a [great blog post](https://blog.tensorflow.org/2018/12/an-introduction-to-probabilistic.html?_gl=1*stcllb*_ga*MzA2NjA0NzY2LjE2ODU2ODA1Mjg.*_ga_W0YLR4190T*MTY4NzE2MjcxMi4yLjEuMTY4NzE2NjY3OC4wLjAuMA..)
- Book with code: [Bayesian Methods for Hackers](https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers),
  also the TensorFlow Probability [example notebooks on GitHub](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/examples/jupyter_notebooks) and the probabilistic programming library [PyMC](https://www.pymc.io/welcome.html)
- Other tutorials:
  libspn-keras [examples](https://github.com/pronobis/libspn-keras),
  a tutorial by Luis: [video](https://www.youtube.com/watch?v=ZK3NjrGLIQY&list=LL&index=1) and [code](https://github.com/luisroque/probabilistic-deep-learning/tree/main/workshops)
- Several playlists from the University of Tübingen for [theoretical background](https://www.youtube.com/@TubingenML)
- General probabilistic learning by Philipp Hennig: [course website](https://uni-tuebingen.de/en/180804) with [code for lectures in JAX](https://github.com/philipphennig/ProbML_Apps/tree/a893802b8fd34672ac5d532f08de540a19e465b7) and [lectures](https://www.youtube.com/playlist?list=PL05umP7R6ij2YE8rRJSb-olDNbntAQ_Bx) ([older 2021 version](https://www.youtube.com/watch?v=UbaVGD4Lfis&list=PL05umP7R6ij1tHaOFY96m5uX3J21a6yNd)).
  Check out his [book with code](https://www.probabilistic-numerics.org/textbooks/), Rasmussen's [book](http://gaussianprocess.org), and a simple explanation of [Gaussian processes](https://distill.pub/2019/visual-exploration-gaussian-processes/) (some other explanations: [1](https://domino.ai/blog/fitting-gaussian-process-models-python), [2](http://katbailey.github.io/post/gaussian-processes-for-dummies/), [3](https://nbviewer.org/github/adamian/adamian.github.io/blob/master/talks/Brown2016.ipynb)), as well as [Bayesian optimization etc.](https://distill.pub/2020/bayesian-optimization/) and [constructing kernels](https://www.cs.toronto.edu/~duvenaud/cookbook/).
  Libraries specifically for Gaussian processes: [GPflow](https://gpflow.org/), [scikit-learn](https://scikit-learn.org/stable/modules/gaussian_process.html), [GPy](http://sheffieldml.github.io/GPy/).

PyTorch:
- [Torch distributions](https://pytorch.org/docs/stable/distributions.html)
- [Tutorials with code in both JAX and PyTorch](https://uvadlc-notebooks.readthedocs.io/en/latest/)
- [GPyTorch](https://gpytorch.ai/) for Gaussian process inference
- [Pyro](https://pyro.ai/), [deeprob](https://github.com/deeprob-org/deeprob-kit), [StarAI](https://github.com/UCLA-StarAI), [several others](https://arranger1044.github.io/probabilistic-circuits/code)

Researchers:

- [Kevin Murphy](https://www.cs.ubc.ca/~murphyk/)
- Antonio Vergari: [a sorted library of important resources](http://nolovedeeplearning.com/buysellexchange.html)
- [YooJung Choi](https://yoojungchoi.github.io/)
- [Guy Van den Broeck](https://web.cs.ucla.edu/~guyvdb/)
- Philipp Hennig: [GitHub](https://github.com/philipphennig) (check out another lecture course with a book and code as well: [numerics](https://www.probabilistic-numerics.org/teaching/2022_Numerics_of_Machine_Learning/), [video lectures](https://www.youtube.com/watch?v=0Q1ZTLHULcw&list=PL05umP7R6ij1qQyGmhrto2iMJXcmfSz5J), [more lectures](https://www.youtube.com/watch?v=RqFwO3GwYf4&list=PL05umP7R6ij2lwDdj7IkuHoP9vHlEcH0s))

Some other books:

- [High-Dimensional Probability: An Introduction with Applications in Data Science](https://www.math.uci.edu/~rvershyn/papers/HDP-book/HDP-book.html#)

State-of-the-art Research:
- [Bayesian Flow Networks paper](https://arxiv.org/abs/2308.07037) and a [simple implementation](https://github.com/Algomancer/Bayesian-Flow-Networks)


In [None]:
#@title Imports

from __future__ import annotations
import abc
from typing import Type, Optional, Union, List, Tuple, NamedTuple, Iterator, Any, Callable
from collections import deque, defaultdict
import itertools
from itertools import combinations

import numpy as np
import scipy.stats as ss
from scipy import linalg
from scipy.special import logsumexp, gammaln, log_softmax
from scipy import sparse as sp
from scipy import stats

from enum import Enum
import copy
from copy import deepcopy
import joblib
from tqdm import tqdm

import warnings
from sklearn import mixture, cluster, cross_decomposition
from sklearn.exceptions import ConvergenceWarning
from sklearn.base import BaseEstimator, DensityMixin, ClassifierMixin

# Data handling

In [None]:
#@title context for each thread

import contextlib
import contextvars

# Thread-safe context variables, i.e. each thread will have its own flags assignments
_context_variables = contextvars.ContextVar(
    'context_variables',
    default={
        'check_dtype': True,
        'check_spn': True
    }
)


def is_check_dtype_enabled() -> bool:
    """Returns whether the context flag 'check_dtype' is enabled."""
    return _context_variables.get()['check_dtype']


def is_check_spn_enabled() -> bool:
    """Returns whether the context flag 'check_spn' is enabled."""
    return _context_variables.get()['check_spn']


class ContextState(contextlib.ContextDecorator):
    def __init__(self, **kwargs):
        """
        Thread-safe context state that temporarily overrides some flags during execution.

        Currently supported flags are the following:
        - check_dtype: bool = True, whether to check (and cast when needed) Numpy arrays data types.
        - check_spn: bool = True, whether to check the SPNs structure properties.
        """
        self.__token = None
        self.__state = _context_variables.get().copy()
        for flag, value in kwargs.items():
            if flag not in self.__state:
                raise ValueError("Cannot set an unknown flag called '{}', suitable flags are: {}".format(
                    flag, ', '.join(self.__state.keys())
                ))
            self.__state[flag] = value

    def __enter__(self):
        self.__token = _context_variables.set(self.__state)
        return self

    def __exit__(self, *exc):
        _context_variables.reset(self.__token)

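The flags live in a `contextvars.ContextVar`, so an override made inside a `with` block is undone on exit and does not leak into other threads. Below is a minimal standalone sketch of the same pattern that checks both properties; `_flags`, `override_flags` and `read_flag` are hypothetical names used only for illustration, not part of the notebook.

```python
import contextlib
import contextvars
import threading

# Hypothetical stand-in for `_context_variables`: a ContextVar holding a dict of flags
_flags = contextvars.ContextVar('flags', default={'check_dtype': True})


@contextlib.contextmanager
def override_flags(**kwargs):
    """Temporarily override flags in the current context (same idea as ContextState)."""
    state = _flags.get().copy()  # copy, so the shared default dict is never mutated
    state.update(kwargs)
    token = _flags.set(state)
    try:
        yield
    finally:
        _flags.reset(token)  # restore the previous value, even on exceptions


def read_flag(name: str) -> bool:
    return _flags.get()[name]


# 1) The override is visible inside the block and undone on exit
with override_flags(check_dtype=False):
    inside = read_flag('check_dtype')
after = read_flag('check_dtype')

# 2) An override made in another thread does not leak into the main thread
seen_in_thread = []


def worker():
    with override_flags(check_dtype=False):
        seen_in_thread.append(read_flag('check_dtype'))


t = threading.Thread(target=worker)
t.start()
t.join()
main_after_thread = read_flag('check_dtype')

print(inside, after, seen_in_thread, main_after_thread)  # False True [False] True
```

Copying the current dict before updating it matters: mutating the `default` dict in place would silently change the flags for every thread.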
In [None]:
#@title data transforms

class DataTransform(abc.ABC):
    """Abstract data transformation."""
    @abc.abstractmethod
    def fit(self, data: np.ndarray):
        """
        Fit the data transform with some data.

        :param data: The data for fitting.
        """

    @abc.abstractmethod
    def forward(self, data: np.ndarray) -> np.ndarray:
        """
        Apply the data transform to some data.

        :param data: The data to transform.
        :return: The transformed data.
        """

    @abc.abstractmethod
    def backward(self, data: np.ndarray) -> np.ndarray:
        """
        Apply the backward data transform to some data.

        :param data: The data to transform.
        :return: The transformed data.
        """


class DataFlatten(DataTransform):
    def __init__(self):
        """
        Build the data flatten transformation.
        """
        self.shape = None

    def fit(self, data: np.ndarray):
        self.shape = data.shape[1:]

    def forward(self, data: np.ndarray) -> np.ndarray:
        return np.reshape(data, [len(data), -1])

    def backward(self, data: np.ndarray) -> np.ndarray:
        return np.reshape(data, [len(data), *self.shape])


class DataNormalizer(DataTransform):
    def __init__(
        self,
        interval: Optional[Tuple[float, float]] = None,
        clip: bool = False,
        dtype=np.float32
    ):
        """
        Build the data normalizer transformation.

        :param interval: The normalizing interval. If None data will be normalized in [0, 1].
        :param clip: Whether to clip data if out of interval.
        :param dtype: The type for type conversion.
        :raises ValueError: If the normalizing interval is out of domain.
        """
        if interval is None:
            interval = (0.0, 1.0)
        elif interval[0] >= interval[1]:
            raise ValueError("The normalizing interval must be (a, b) with a < b")

        self.interval = interval
        self.clip = clip
        self.dtype = dtype
        self.prev_dtype = None
        self.min = None
        self.max = None

    def fit(self, data: np.ndarray):
        self.prev_dtype = data.dtype
        self.min = np.min(data, axis=0)
        self.max = np.max(data, axis=0)

    def forward(self, data: np.ndarray) -> np.ndarray:
        a, b = self.interval
        data = (data - self.min) / (self.max - self.min)
        data = data * (b - a) + a
        if self.clip:
            data = np.clip(data, a, b)
        return data.astype(self.dtype)

    def backward(self, data: np.ndarray) -> np.ndarray:
        a, b = self.interval
        data = (data - a) / (b - a)
        data = (self.max - self.min) * data + self.min
        return data.astype(self.prev_dtype)


class DataStandardizer(DataTransform):
    def __init__(self, sample_wise: bool = True, eps: float = 1e-7, dtype=np.float32):
        """
        Build the data standardizer transformation.

        :param sample_wise: Whether to apply sample wise standardization.
        :param eps: The epsilon value for standardization.
        :param dtype: The type for type conversion.
        :raises ValueError: If the epsilon value is out of domain.
        """
        if eps <= 0.0:
            raise ValueError("The epsilon value must be positive")
        self.sample_wise = sample_wise
        self.eps = eps
        self.dtype = dtype
        self.prev_dtype = None
        self.mean = None
        self.stddev = None

    def fit(self, data: np.ndarray):
        self.prev_dtype = data.dtype
        axis = 0 if self.sample_wise else None
        self.mean = np.mean(data, axis=axis)
        self.stddev = np.std(data, axis=axis)

    def forward(self, data: np.ndarray) -> np.ndarray:
        data = (data - self.mean) / (self.stddev + self.eps)
        return data.astype(self.dtype)

    def backward(self, data: np.ndarray) -> np.ndarray:
        data = (self.stddev + self.eps) * data + self.mean
        return data.astype(self.prev_dtype)


def ohe_data(data: np.ndarray, domain: Union[List[int], np.ndarray]) -> np.ndarray:
    """
    One-Hot-Encoding function.

    :param data: The 1D data to encode.
    :param domain: The domain to use.
    :return: The One Hot encoded data.
    """
    ohe = np.zeros((len(data), len(domain)), dtype=np.float32)
    ohe[np.equal.outer(data, domain)] = 1.0
    return ohe


def mixed_ohe_data(data: np.ndarray, domains: List[Union[list, tuple]]) -> np.ndarray:
    """
    One-Hot-Encoding function, applied on mixed data (both continuous and non-binary discrete).
    Note that One-Hot-Encoding is applied only on categorical random variables having more than two values.

    :param data: The data matrix to encode.
    :param domains: The domains to use.
    :return: The One Hot encoded data.
    :raises ValueError: If there are inconsistencies between the data and domains.
    """
    _, n_features = data.shape
    if len(domains) != n_features:
        raise ValueError("Each data column should correspond to a random variable having a domain")

    ohe = []
    for i in range(n_features):
        if len(domains[i]) > 2:
            ohe.append(ohe_data(data[:, i], domains[i]))
        else:
            ohe.append(data[:, i])
    return np.column_stack(ohe)


def ecdf_data(data: np.ndarray) -> np.ndarray:
    """
    Empirical Cumulative Distribution Function (ECDF).

    :param data: The data.
    :return: The result of the ECDF on data.
    """
    return stats.rankdata(data, method='max') / len(data)


def check_data_dtype(data: np.ndarray, dtype: Type[np.dtype] = np.float32):
    """
    Check whether the data is compatible with a given dtype (defaults to np.float32).
    If the data dtype is not compatible, then cast it.

    :param data: The data.
    :param dtype: The desired dtype compatibility (defaults to np.float32).
    :return: The cast data if necessary, otherwise returns data itself.
    """
    if not is_check_dtype_enabled():
        # Skip data dtype check and casting
        return data

    # Get flags for floating point data and type
    is_data_fp = data.dtype in [np.float32, np.float64]
    is_dtype_fp = dtype in [np.float32, np.float64]

    if is_dtype_fp:
        if not is_data_fp or data.dtype.itemsize < np.dtype(dtype).itemsize:
            # If dtype is FP and data is not FP or it is a "smaller" FP, then cast it
            return data.astype(dtype)
    elif is_data_fp or data.dtype.itemsize < np.dtype(dtype).itemsize:
        # If dtype is integral and data is FP or it is a "smaller" integral, then cast it
        return data.astype(dtype)

    # Data is compatible w.r.t. dtype
    # i.e. it is FP if dtype is FP and integral if dtype is integral, and it is at least as "big" as dtype
    return data

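Both `DataNormalizer` and `DataStandardizer` are invertible: `backward` undoes `forward`. The standalone NumPy sketch below reproduces their formulas without the dtype casting (the helper names `minmax_forward`, `minmax_backward`, `standardize_forward` and `standardize_backward` are hypothetical) and checks the round trip on random data, using per-feature statistics as in `fit`.

```python
import numpy as np


def minmax_forward(data, dmin, dmax, a=-1.0, b=1.0):
    # Same formula as DataNormalizer.forward (without clipping and dtype casting):
    # rescale each feature to [0, 1], then map [0, 1] onto [a, b]
    return (data - dmin) / (dmax - dmin) * (b - a) + a


def minmax_backward(data, dmin, dmax, a=-1.0, b=1.0):
    # Same formula as DataNormalizer.backward: undo the two affine steps in reverse order
    return (data - a) / (b - a) * (dmax - dmin) + dmin


def standardize_forward(data, mean, stddev, eps=1e-7):
    # Same formula as DataStandardizer.forward (without dtype casting)
    return (data - mean) / (stddev + eps)


def standardize_backward(data, mean, stddev, eps=1e-7):
    # Same formula as DataStandardizer.backward
    return (stddev + eps) * data + mean


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(100, 3))

# "fit": per-feature statistics along the sample axis (axis=0)
dmin, dmax = x.min(axis=0), x.max(axis=0)
mean, stddev = x.mean(axis=0), x.std(axis=0)

z = minmax_forward(x, dmin, dmax)
s = standardize_forward(x, mean, stddev)

print(z.min(axis=0), z.max(axis=0))                           # every feature spans [-1, 1]
print(np.allclose(minmax_backward(z, dmin, dmax), x))         # True
print(np.allclose(standardize_backward(s, mean, stddev), x))  # True
```

Two caveats follow directly from the notebook code: with the default `dtype=np.float32` the round trip is only exact up to float32 precision, and a constant feature (min equal to max) makes `DataNormalizer.forward` divide by zero.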
# Structure

In [None]:
#@title statistics helpers


def compute_mean_quantiles(data: np.ndarray, n_quantiles: int) -> np.ndarray:
    """
    Compute the mean quantiles of a dataset (Poon-Domingos).

    :param data: The data.
    :param n_quantiles: The number of quantiles.
    :return: The mean quantiles.
    :raises ValueError: If the number of quantiles is not valid.
    """
    n_samples = len(data)
    if n_quantiles <= 0 or n_quantiles > n_samples:
        raise ValueError("The number of quantiles must be positive and less than or equal to the number of samples")

    # Split the dataset in quantiles regions
    data = np.sort(data, axis=0)
    values_per_quantile = np.array_split(data, n_quantiles, axis=0)

    # Compute the mean quantiles
    mean_per_quantiles = [np.mean(x, axis=0) for x in values_per_quantile]
    return np.stack(mean_per_quantiles, axis=0)


def compute_mutual_information(priors: np.ndarray, joints: np.ndarray) -> np.ndarray:
    """
    Compute the mutual information between each pair of features, given priors and joint distributions.

    :param priors: The priors probability distributions, as a (N, D) Numpy array
                   having priors[i, k] = P(X_i=k).
    :param joints: The joints probability distributions, as a (N, N, D, D) Numpy array
                   having joints[i, j, k, l] = P(X_i=k, X_j=l).
    :return: The mutual information between each pair of features, as a (N, N) Numpy symmetric matrix.
    :raises ValueError: If there are inconsistencies between priors and joints arrays.
    :raises ValueError: If joints array is not symmetric.
    :raises ValueError: If priors or joints arrays don't encode valid probability distributions.
    """
    n_variables, n_values = priors.shape
    if joints.shape != (n_variables, n_variables, n_values, n_values):
        raise ValueError("There are inconsistencies between priors and joints distributions")
    if not np.all(joints == joints.transpose([1, 0, 3, 2])):
        raise ValueError("The joints probability distributions are expected to be symmetric")
    if not np.allclose(np.sum(priors, axis=1), 1.0):
        raise ValueError("The priors probability distributions are not valid")
    if not np.allclose(np.sum(joints, axis=(2, 3)), 1.0):
        raise ValueError("The joints probability distributions are not valid")

    outers = np.multiply.outer(priors, priors).transpose([0, 2, 1, 3])
    # Ignore warnings of logarithm at zero (because NaNs on the diagonal will be zeroed later anyway)
    with np.errstate(divide='ignore', invalid='ignore'):
        mutual_info = np.sum(joints * (np.log(joints) - np.log(outers)), axis=(2, 3))
    np.fill_diagonal(mutual_info, 0.0)
    return mutual_info


def estimate_priors_joints(data: np.ndarray, alpha: float = 0.1) -> Tuple[np.ndarray, np.ndarray]:
    """
    Estimate both priors and joints probability distributions from binary data.

    This function returns both the prior distributions and the joint distributions.
    Note that priors[i, k] = P(X_i=k) and joints[i, j, k, l] = P(X_i=k, X_j=l).

    :param data: The binary data matrix.
    :param alpha: The Laplace smoothing factor.
    :return: A pair of priors and joints distributions.
    :raises ValueError: If the Laplace smoothing factor is out of domain.
    """
    if alpha < 0.0:
        raise ValueError("The Laplace smoothing factor must be non-negative")

    # Check the data dtype
    data = check_data_dtype(data, dtype=np.float32)

    # Compute the counts
    n_samples, n_features = data.shape
    counts_ones = np.dot(data.T, data)
    counts_features = np.diag(counts_ones)
    counts_cols = counts_features * np.ones_like(counts_ones)
    counts_rows = np.transpose(counts_cols)

    # Compute the prior probabilities
    priors = np.empty(shape=(n_features, 2), dtype=data.dtype)
    priors[:, 1] = (counts_features + 2 * alpha) / (n_samples + 4 * alpha)
    priors[:, 0] = 1.0 - priors[:, 1]

    # Compute the joints probabilities
    joints = np.empty(shape=(n_features, n_features, 2, 2), dtype=data.dtype)
    joints[:, :, 0, 0] = n_samples - counts_cols - counts_rows + counts_ones
    joints[:, :, 0, 1] = counts_cols - counts_ones
    joints[:, :, 1, 0] = counts_rows - counts_ones
    joints[:, :, 1, 1] = counts_ones
    joints = (joints + alpha) / (n_samples + 4 * alpha)

    # Correct smoothing on the diagonal of joints array
    idx_features = np.arange(n_features)
    joints[idx_features, idx_features, 0, 0] = priors[:, 0]
    joints[idx_features, idx_features, 0, 1] = 0.0
    joints[idx_features, idx_features, 1, 0] = 0.0
    joints[idx_features, idx_features, 1, 1] = priors[:, 1]

    return priors, joints


def compute_gini(probs: np.ndarray) -> float:
    """
    Computes the Gini index given some probabilities.

    :param probs: The probabilities.
    :return: The Gini index.
    :raises ValueError: If the probabilities don't sum up to one.
    """
    if not np.isclose(np.sum(probs), 1.0):
        raise ValueError("Probabilities must sum up to one")
    return 1.0 - np.sum(probs ** 2.0)


def compute_bpp(avg_ll: float, shape: Union[int, tuple, list]):
    """
    Compute the average number of bits per pixel (BPP).

    :param avg_ll: The average log-likelihood, expressed in nats.
    :param shape: The number of dimensions or, alternatively, a sequence of dimensions.
    :return: The average number of bits per pixel.
    """
    return -avg_ll / (np.log(2.0) * np.prod(shape))


def compute_fid(
    mean1: np.ndarray,
    cov1: np.ndarray,
    mean2: np.ndarray,
    cov2: np.ndarray,
    blocksize: int = 64,
    eps: float = 1e-6
) -> float:
    """
    Computes the Frechet Inception Distance (FID) between two multivariate Gaussian distributions.
    This implementation has been readapted from https://github.com/mseitzer/pytorch-fid.

    :param mean1: The mean of the first multivariate Gaussian.
    :param cov1: The covariance of the first multivariate Gaussian.
    :param mean2: The mean of the second multivariate Gaussian.
    :param cov2: The covariance of the second multivariate Gaussian.
    :param blocksize: The block size used by the matrix square root algorithm.
    :param eps: Epsilon value used to avoid singular matrices.
    :return: The FID score.
    :raises ValueError: If there is a shape mismatch between input arrays.
    """
    if mean1.ndim != 1 or mean2.ndim != 1:
        raise ValueError("Mean arrays must be one-dimensional")
    if cov1.ndim != 2 or cov2.ndim != 2:
        raise ValueError("Covariance arrays must be two-dimensional")
    if mean1.shape != mean2.shape:
        raise ValueError("Shape mismatch between mean arrays")
    if cov1.shape != cov2.shape:
        raise ValueError("Shape mismatch between covariance arrays")

    # Compute the matrix square root of the dot product between covariance matrices
    sqrtcov, _ = linalg.sqrtm(np.dot(cov1, cov2), disp=False, blocksize=blocksize)
    if np.any(np.isinf(sqrtcov)):  # Matrix square root can give Infinity values in case of singular matrices
        epsdiag = np.zeros_like(cov1)
        np.fill_diagonal(epsdiag, eps)
        sqrtcov, _ = linalg.sqrtm(np.dot(cov1 + epsdiag, cov2 + epsdiag), disp=False, blocksize=blocksize)

    # Numerical errors might give a complex output, even if the input arrays are real
    if np.iscomplexobj(sqrtcov) and np.isrealobj(cov1) and np.isrealobj(cov2):
        sqrtcov = sqrtcov.real

    # Compute the dot product of the difference between mean arrays
    diffm = mean1 - mean2
    diffmdot = np.dot(diffm, diffm)

    # Return the final FID score
    return diffmdot + np.trace(cov1) + np.trace(cov2) - 2.0 * np.trace(sqrtcov)


def compute_prior_counts(
    data: np.ndarray
):
    """
    Compute the counts of the values of each binary RV given the data.

    :param data: The binary data matrix.
    :return: The counts, as a (N, 2) Numpy array having prior_counts[i, k] = #(X_i=k).
    """
    n_samples, n_features = data.shape
    counts_features = data.sum(axis=0)

    # Compute the prior counts
    prior_counts = np.empty(shape=(n_features, 2), dtype=np.float32)
    prior_counts[:, 1] = counts_features
    prior_counts[:, 0] = n_samples - prior_counts[:, 1]
    return prior_counts


def compute_joint_counts(
    data: np.ndarray
):
    """
    Compute the counts of the joint configurations of each pair of binary RVs given the data.

    :param data: The binary data matrix.
    :return: The counts, as a (N, N, 2, 2) Numpy array having joint_counts[i, j, k, l] = #(X_i=k, X_j=l).
    """
    n_samples, n_features = data.shape
    counts_ones = np.dot(data.T, data)
    counts_features = np.diag(counts_ones)
    counts_cols = counts_features * np.ones_like(counts_ones)
    counts_rows = np.transpose(counts_cols)

    # Compute the joint counts
    joint_counts = np.empty(shape=(n_features, n_features, 2, 2), dtype=np.float32)
    joint_counts[:, :, 0, 0] = n_samples - counts_cols - counts_rows + counts_ones
    joint_counts[:, :, 0, 1] = counts_cols - counts_ones
    joint_counts[:, :, 1, 0] = counts_rows - counts_ones
    joint_counts[:, :, 1, 1] = counts_ones
    return joint_counts

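`estimate_priors_joints` and `compute_mutual_information` are the two ingredients of a Chow-Liu tree, the structure referenced in the Cutset Networks paper above: the tree is a maximum spanning tree over the pairwise mutual-information matrix. The standalone sketch below assumes binary data, re-derives the smoothed joints and priors with the same Laplace-smoothing formulas as `estimate_priors_joints` (off-diagonal entries), and runs Prim's algorithm; `pairwise_mutual_information` and `max_spanning_tree_edges` are hypothetical helpers, not notebook functions.

```python
import numpy as np


def pairwise_mutual_information(data, alpha=0.1):
    """Mutual information (in nats) between each pair of binary features."""
    n = len(data)
    onehot = np.stack([1 - data, data], axis=2).astype(np.float64)  # (n, d, 2)
    # counts[i, j, k, l] = number of samples with X_i=k and X_j=l
    counts = np.einsum('nik,njl->ijkl', onehot, onehot)
    # Same Laplace smoothing as estimate_priors_joints (off-diagonal entries)
    joints = (counts + alpha) / (n + 4 * alpha)
    priors = (onehot.sum(axis=0) + 2 * alpha) / (n + 4 * alpha)
    outers = np.einsum('ik,jl->ijkl', priors, priors)  # P(X_i=k) * P(X_j=l)
    mi = np.sum(joints * (np.log(joints) - np.log(outers)), axis=(2, 3))
    np.fill_diagonal(mi, 0.0)
    return mi


def max_spanning_tree_edges(weights):
    """Prim's algorithm on a dense symmetric weight matrix, maximizing the total weight."""
    n = weights.shape[0]
    in_tree = np.zeros(n, dtype=bool)
    in_tree[0] = True
    best = weights[0].copy()         # best weight linking each node to the current tree
    parent = np.zeros(n, dtype=int)  # tree node achieving that best weight
    edges = []
    for _ in range(n - 1):
        j = int(np.argmax(np.where(in_tree, -np.inf, best)))
        edges.append((int(parent[j]), j))
        in_tree[j] = True
        improved = ~in_tree & (weights[j] > best)
        best[improved] = weights[j][improved]
        parent[improved] = j
    return edges


# Toy binary data: X1 is a noisy copy of X0, X2 a noisy copy of X1, X3 is independent
rng = np.random.default_rng(0)
n = 5000
x0 = rng.integers(0, 2, size=n)
x1 = x0 ^ (rng.random(n) < 0.1)
x2 = x1 ^ (rng.random(n) < 0.1)
x3 = rng.integers(0, 2, size=n)
data = np.stack([x0, x1, x2, x3], axis=1)

edges = max_spanning_tree_edges(pairwise_mutual_information(data))
print(edges)  # the chain X0 - X1 - X2 should be recovered, plus one edge reaching X3
```

Smoothing (`alpha > 0`) keeps every joint probability strictly positive, so the logarithms stay finite and no off-diagonal NaNs appear.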

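For reference, `compute_bpp` and `compute_fid` compute the following quantities, where $\bar{\ell}$ is the average log-likelihood in nats (`avg_ll`), $D$ is the number of dimensions (`np.prod(shape)`), and $(\mu_1, \Sigma_1)$, $(\mu_2, \Sigma_2)$ are the means and covariances of the two Gaussians:

$$
\mathrm{BPP} = -\frac{\bar{\ell}}{D \ln 2}
$$

$$
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2) - 2\,\operatorname{Tr}\big((\Sigma_1 \Sigma_2)^{1/2}\big)
$$

If the matrix square root contains infinite values, `compute_fid` retries with $\Sigma_i + \epsilon I$; if numerical errors make $(\Sigma_1 \Sigma_2)^{1/2}$ complex although both inputs are real, only its real part is kept.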
In [None]:
#@title node and leaf


#-------------------------------------------------------------------------- Node

class Node(abc.ABC):
    def __init__(self, scope: List[int], children: Optional[List[Node]] = None):
        """
        Initialize a SPN node given the children list and its scope.

        :param scope: The scope.
        :param children: A list of nodes. If None, children are initialized as an empty list.
        :raises ValueError: If the scope is empty.
        :raises ValueError: If the scope contains duplicates.
        """
        if not scope:
            raise ValueError("The scope must not be empty")
        if len(scope) != len(set(scope)):
            raise ValueError("The scope must not contain duplicates")
        if children is None:
            children = list()

        self.id = 0
        self.scope = scope
        self.children = children

    @abc.abstractmethod
    def likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the likelihood of the node given some input.

        :param x: The inputs.
        :return: The resulting likelihoods.
        """

    @abc.abstractmethod
    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the logarithmic likelihood of the node given some input.

        :param x: The inputs.
        :return: The resulting log-likelihoods.
        """


class Sum(Node):
    def __init__(
        self,
        scope: Optional[List[int]] = None,
        children: Optional[List[Node]] = None,
        weights: Optional[Union[List[float], np.ndarray]] = None,
    ):
        """
        Initialize a SPN sum node given a list of children and their weights and a scope.

        :param scope: The scope. If None, the scope is initialized based on children scopes.
        :param children: A list of nodes. If None, children are initialized as an empty list.
        :param weights: The weights associated with each child node. It can be None.
        :raises ValueError: If both scope and children are None.
        :raises ValueError: If children nodes have different scopes.
        :raises ValueError: If the length of weights and children are different.
        :raises ValueError: If weights don't sum up to 1.
        """
        if children is None:
            if scope is None:
                raise ValueError("Cannot infer Sum node's scope without children")
        else:
            if scope is None:
                scope = children[0].scope
            s_scope = set(scope)
            if any(map(lambda c: set(c.scope) != s_scope, children[1:])):
                raise ValueError("Children of Sum node have different scopes")
            if weights is not None and len(weights) != len(children):
                raise ValueError("Weights and children length mismatch")

        if weights is not None:
            if isinstance(weights, list):
                weights = np.array(weights, dtype=np.float32)
            if not np.isclose(np.sum(weights), 1.0):
                raise ValueError("Weights don't sum up to 1")
        self.weights = weights

        super().__init__(scope, children)

    def em_init(self, random_state: np.random.RandomState):
        """
        Randomly initialize the node's parameters for Expectation-Maximization (EM).

        :param random_state: The random state.
        """
        weights = random_state.dirichlet(np.ones(len(self.children)))
        self.weights = weights.astype(np.float32)

    def em_step(self, stats: np.ndarray, step_size: float):
        """
        Compute a batch Expectation-Maximization (EM) step.

        :param stats: The sufficient statistics of each sample.
        :param step_size: The step size of update.
        """
        unnorm_weights = self.weights * np.sum(stats, axis=1) + np.finfo(np.float32).eps
        weights = unnorm_weights / np.sum(unnorm_weights)

        # Update the parameters
        self.weights = (1.0 - step_size) * self.weights + step_size * weights

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        return np.expand_dims(np.dot(x, self.weights), axis=1)

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        return logsumexp(x, b=self.weights, axis=1, keepdims=True)


class Product(Node):
    def __init__(
        self,
        scope: Optional[List[int]] = None,
        children: Optional[List[Node]] = None
    ):
        """
        Initialize a product node given a list of children and its scope.

        :param scope: The scope. If None, the scope is initialized based on children scopes.
        :param children: A list of nodes. If None, children are initialized as an empty list.
        :raises ValueError: If both scope and children are None.
        :raises ValueError: If children nodes don't have disjoint scopes.
        :raises ValueError: If the given scope doesn't match the union of the children scopes.
        """
        if children is None:
            if scope is None:
                raise ValueError("Cannot infer Product node's scope without children")
        else:
            c_scope = list(sum([c.scope for c in children], []))
            s_scope = set(c_scope)
            # Children scopes must be pairwise disjoint (decomposability), whether or not a scope is given
            if len(c_scope) != len(s_scope):
                raise ValueError("Children of Product node don't have disjoint scopes")
            if scope is None:
                scope = c_scope
            elif set(scope) != s_scope:
                raise ValueError("Product node's scope doesn't match the union of its children's scopes")

        super().__init__(scope, children)

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        return np.prod(x, axis=1, keepdims=True)

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        return np.sum(x, axis=1, keepdims=True)


def assign_ids(root: Node) -> Node:
    """
    Assign the ids to the nodes of a SPN.

    :param root: The root of the SPN.
    :return: The same SPN with each node having modified ids.
    :raises ValueError: If the SPN structure is not a DAG.
    """
    nodes = topological_order(root)
    if nodes is None:
        raise ValueError("SPN structure is not a directed acyclic graph (DAG)")

    next_id = 0
    for node in nodes:
        node.id = next_id
        next_id += 1
    return root


def bfs(root: Node) -> Iterator[Node]:
    """
    Compute the Breadth First Search (BFS) ordering for an SPN, visiting shared nodes only once.

    :param root: The root of the SPN.
    :return: The BFS nodes iterator.
    """
    seen, queue = {root}, deque([root])
    while queue:
        node = queue.popleft()
        yield node
        for c in node.children:
            if c not in seen:
                seen.add(c)
                queue.append(c)


def dfs_post_order(root: Node) -> Iterator[Node]:
    """
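    Compute the Depth First Search (DFS) post-order ordering for an SPN.
    Each node is yielded only after all of its children, and nodes shared by
    multiple parents are yielded once.

    :param root: The root of the SPN.
    :return: The DFS post-order nodes iterator.
    """
    # A node is marked as done only once it has been yielded, so that a node whose child is
    # still pending (e.g. a child shared with a sibling) is never yielded before that child
    done, on_path = set(), {root}
    stack = [(root, iter(root.children))]
    while stack:
        node, children = stack[-1]
        for c in children:
            # Skip children already yielded, and back edges to nodes on the current path
            if c not in done and c not in on_path:
                on_path.add(c)
                stack.append((c, iter(c.children)))
                break
        else:
            # All the children have been yielded: yield the node itself
            stack.pop()
            on_path.discard(node)
            done.add(node)
            yield node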
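
# Illustrative sketch (not part of the library): the defining property of a DFS
# post-order is that every node comes after all of its children. In an SPN a child
# may be shared by several parents, so a node must be marked as finished only when
# it is yielded, not when it is first discovered. Here the DAG is a plain dict
# mapping each node to its children list, a hypothetical stand-in for Node.children.
from typing import Dict, Hashable, Iterator, List


def sketch_dfs_post_order(children: Dict[Hashable, List[Hashable]], root: Hashable) -> Iterator[Hashable]:
    done, on_path = set(), {root}
    stack = [(root, iter(children.get(root, [])))]
    while stack:
        node, pending = stack[-1]
        for c in pending:
            # Descend into children not yet yielded; ignore back edges to the current path
            if c not in done and c not in on_path:
                on_path.add(c)
                stack.append((c, iter(children.get(c, []))))
                break
        else:
            # Every child has been yielded, so the node itself can be yielded
            stack.pop()
            on_path.discard(node)
            done.add(node)
            yield node

# Usage: with {'r': ['a', 'b'], 'b': ['a'], 'a': []}, the shared child 'a' is
# yielded once and before both of its parents: ['a', 'b', 'r'].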


def topological_order(root: Node) -> Optional[List[Node]]:
    """
    Compute a topological ordering for an SPN, using Kahn's algorithm.

    :param root: The root of the SPN.
    :return: A list of nodes that form a topological ordering.
             If the SPN graph contains a cycle, it returns None.
    """
    ordering = list()
    num_incomings = defaultdict(int)
    num_incomings[root] = 0

    # Count the incoming edges (i.e. the number of parents) of each node
    for node in bfs(root):
        for c in node.children:
            num_incomings[c] += 1

    # Check the unusual case where the root node has incoming edges, i.e. a cycle through the root has been found
    if num_incomings[root] != 0:
        return None

    # Non-layered topological ordering implementation
    queue = deque([root])
    while queue:
        node = queue.popleft()
        ordering.append(node)
        for c in node.children:
            num_incomings[c] -= 1
            if num_incomings[c] == 0:
                queue.append(c)

    # Check if a cycle has been found, i.e. some nodes still have unprocessed parents
    if sum(num_incomings.values()) != 0:
        return None
    return ordering
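
# Illustrative sketch (not part of the library): Kahn's algorithm on a DAG given as
# a dict of children lists (a hypothetical stand-in for Node.children). Each node
# starts with a count of its incoming edges, i.e. of its parents; a node enters
# the queue once all of its parents have been emitted. If some node is never
# emitted, it lies on (or below) a cycle, and None is returned.
from collections import deque
from typing import Dict, Hashable, List, Optional


def sketch_kahn_order(children: Dict[Hashable, List[Hashable]], root: Hashable) -> Optional[List[Hashable]]:
    # Count the parents of every node reachable from the root (each edge once)
    in_degree, seen, queue = {root: 0}, {root}, deque([root])
    while queue:
        node = queue.popleft()
        for c in children.get(node, []):
            in_degree[c] = in_degree.get(c, 0) + 1
            if c not in seen:
                seen.add(c)
                queue.append(c)
    if in_degree[root] != 0:
        return None  # an edge points back to the root

    ordering, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        ordering.append(node)
        for c in children.get(node, []):
            in_degree[c] -= 1
            if in_degree[c] == 0:
                queue.append(c)
    return ordering if len(ordering) == len(in_degree) else None

# Usage: {'r': ['a', 'b'], 'a': ['c'], 'b': ['c'], 'c': []} gives ['r', 'a', 'b', 'c'],
# while adding the edge 'c' -> 'a' makes the function return None.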


def topological_order_layered(root: Node) -> Optional[List[List[Node]]]:
    """
    Compute a layered topological ordering for an SPN, using Kahn's algorithm.
    Each layer contains the nodes whose parents all belong to previous layers.

    :param root: The root of the SPN.
    :return: A list of layers that form a topological ordering.
             If the SPN graph contains a cycle, it returns None.
    """
    ordering = list()
    num_incomings = defaultdict(int)
    num_incomings[root] = 0

    # Count the incoming edges (i.e. the number of parents) of each node
    for node in bfs(root):
        for c in node.children:
            num_incomings[c] += 1

    # Check the unusual case where the root node has incoming edges, i.e. a cycle through the root has been found
    if num_incomings[root] != 0:
        return None

    # Layered topological ordering implementation
    ordering.append([root])
    while True:
        layer = list()
        for node in ordering[-1]:
            for c in node.children:
                num_incomings[c] -= 1
                if num_incomings[c] == 0:
                    layer.append(c)
        if not layer:
            break
        ordering.append(layer)

    # Check if a cycle has been found, i.e. some nodes still have unprocessed parents
    if sum(num_incomings.values()) != 0:
        return None
    return ordering
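
# Illustrative sketch (not part of the library): the layered variant of Kahn's
# algorithm groups nodes so that every node lies in a later layer than all of its
# parents. Evaluating the layers in reverse order therefore computes every child
# before its parents, which is the order a bottom-up evaluation of an SPN needs.
# The DAG is again a hypothetical dict of children lists.
from typing import Dict, Hashable, List, Optional


def sketch_kahn_layers(children: Dict[Hashable, List[Hashable]], root: Hashable) -> Optional[List[List[Hashable]]]:
    # Count the parents of every node reachable from the root (each edge once)
    in_degree, seen, stack = {root: 0}, {root}, [root]
    while stack:
        node = stack.pop()
        for c in children.get(node, []):
            in_degree[c] = in_degree.get(c, 0) + 1
            if c not in seen:
                seen.add(c)
                stack.append(c)
    if in_degree[root] != 0:
        return None

    layers = [[root]]
    while True:
        layer = []
        for node in layers[-1]:
            for c in children.get(node, []):
                in_degree[c] -= 1
                if in_degree[c] == 0:
                    layer.append(c)
        if not layer:
            break
        layers.append(layer)
    # Nodes left out of every layer lie on (or below) a cycle
    return layers if sum(len(layer) for layer in layers) == len(in_degree) else None

# Usage: {'r': ['a', 'c'], 'a': ['c'], 'c': []} gives [['r'], ['a'], ['c']]: 'c' is a
# child of the root, but it is placed after its other parent 'a'.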


#-------------------------------------------------------------------------- Leaf

class LeafType(Enum):
    """
    The type of the distribution leaf.
    It can be either discrete or continuous.
    """
    DISCRETE = 1
    CONTINUOUS = 2


class Leaf(Node):
    LEAF_TYPE = None

    def __init__(self, scope: Union[int, List[int]]):
        """
        Initialize a leaf node given its scope.

        :param scope: The scope of the leaf.
        """
        super().__init__([scope] if isinstance(scope, int) else scope)

    @abc.abstractmethod
    def em_init(self, random_state: np.random.RandomState):
        """
        Randomly initialize the leaf's parameters for Expectation-Maximization (EM).

        :param random_state: The random state.
        """

    @abc.abstractmethod
    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        """
        Compute a batch Expectation-Maximization (EM) step.

        :param stats: The sufficient statistics of each sample.
        :param data: The data regarding random variables of the leaf.
        :param step_size: The step size of update.
        """

    @abc.abstractmethod
    def fit(self, data: np.ndarray, domain: Union[list, tuple], **kwargs):
        """
        Fit the distribution parameters given the domain and some training data.

        :param data: The training data.
        :param domain: The domain of the distribution leaf.
        :param kwargs: Optional parameters.
        :raises ValueError: If a parameter is out of domain.
        """

    @abc.abstractmethod
    def likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the likelihood of the distribution leaf given some input.

        :param x: The inputs.
        :return: The resulting likelihoods.
        """

    @abc.abstractmethod
    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the logarithmic likelihood of the distribution leaf given some input.

        :param x: The inputs.
        :return: The resulting log-likelihoods.
        """

    @abc.abstractmethod
    def mpe(self, x: np.ndarray) -> np.ndarray:
        """
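        Compute the most probable explanation (MPE), i.e. replace the missing (NaN) input values
        with the mode of the distribution.

        :param x: The inputs, with possible NaN values.
        :return: The inputs, with NaN values replaced by the distribution's mode.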
        """

    @abc.abstractmethod
    def sample(self, x: np.ndarray) -> np.ndarray:
        """
        Sample from the leaf distribution.

        :param x: The samples with possible NaN values.
        :return: The completed samples.
        """

    @abc.abstractmethod
    def moment(self, k: int = 1) -> float:
        """
        Compute the moment of a given order.

        :param k: The order of the moment.
        :return: The moment of order k.
        """

    @abc.abstractmethod
    def params_count(self) -> int:
        """
        Get the number of parameters of the distribution leaf.

        :return: The number of parameters.
        """

    @abc.abstractmethod
    def params_dict(self) -> dict:
        """
        Get a dictionary representation of the distribution parameters.

        :return: A dictionary containing the distribution parameters.
        """


class Bernoulli(Leaf):
    LEAF_TYPE = LeafType.DISCRETE

    def __init__(self, scope: int, p: float = 0.5):
        """
        Initialize a Bernoulli leaf node given its scope.

        :param scope: The scope of the leaf.
        :param p: The Bernoulli probability.
        :raises ValueError: If a parameter is out of domain.
        """
        super().__init__(scope)
        if p < 0.0 or p > 1.0:
            raise ValueError("The Bernoulli probability must be in [0, 1]")

        self.p = p

    def fit(self, data: np.ndarray, domain: list, alpha: float = 0.1, **kwargs):
        """
        Fit the distribution parameters given the domain and some training data.

        :param data: The training data.
        :param domain: The domain of the distribution leaf.
        :param alpha: The Laplace smoothing factor.
        :param kwargs: Optional parameters.
        :raises ValueError: If a parameter is out of domain.
        """
        if domain != [0, 1]:
            raise ValueError("The domain must be binary for a Bernoulli distribution")
        if alpha < 0.0:
            raise ValueError("The Laplace smoothing factor must be non-negative")

        # Check the data dtype
        data = check_data_dtype(data, dtype=np.float32)

        # Estimate using Laplace smoothing
        self.p = (np.sum(data) + alpha) / (len(data) + 2 * alpha)

    def em_init(self, random_state: np.random.RandomState):
        self.p = random_state.rand()

    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        alpha = np.finfo(np.float16).eps  # Use a very small Laplace smoothing factor
        data = np.squeeze(data, axis=1)
        total_stats = np.sum(stats)
        p = (np.dot(stats, data) + alpha) / (total_stats + 2 * alpha)

        # Update the parameters
        self.p = (1.0 - step_size) * self.p + step_size * p

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        ls = np.ones([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        ls[~mask] = ss.bernoulli.pmf(x[~mask], self.p)
        return ls

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        lls = np.zeros([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        lls[~mask] = ss.bernoulli.logpmf(x[~mask], self.p)
        return lls

    def mpe(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = 0 if self.p < 0.5 else 1
        return x

    def sample(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = ss.bernoulli.rvs(self.p, size=np.count_nonzero(mask))
        return x

    def moment(self, k: int = 1) -> float:
        return ss.bernoulli.moment(k, self.p)

    def params_count(self) -> int:
        return 1

    def params_dict(self) -> dict:
        return {'p': self.p}
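
# Illustrative sketch (not part of the library), restating the Bernoulli leaf's
# estimates with plain NumPy; the function names are hypothetical:
# - fit: Laplace-smoothed frequency p = (#ones + alpha) / (n + 2 * alpha);
# - EM step: the same estimate with each sample weighted by its statistic, then a
#   convex combination with the old value controlled by step_size (step_size = 1
#   replaces p, smaller values give a damped, mini-batch style update);
# - log-likelihood: NaN entries are marginalized out, contributing log(1) = 0.
import numpy as np


def sketch_bernoulli_fit(data, alpha: float = 0.1) -> float:
    data = np.asarray(data, dtype=np.float64)
    return float((np.sum(data) + alpha) / (len(data) + 2.0 * alpha))


def sketch_bernoulli_em_update(p_old: float, stats, data, step_size: float,
                               alpha: float = float(np.finfo(np.float16).eps)) -> float:
    stats = np.asarray(stats, dtype=np.float64)
    data = np.asarray(data, dtype=np.float64)
    p_new = (np.dot(stats, data) + alpha) / (np.sum(stats) + 2.0 * alpha)
    return float((1.0 - step_size) * p_old + step_size * p_new)


def sketch_bernoulli_log_likelihood(x, p: float) -> np.ndarray:
    x = np.asarray(x, dtype=np.float64)
    lls = np.zeros_like(x)
    observed = ~np.isnan(x)
    lls[observed] = np.where(x[observed] == 1.0, np.log(p), np.log1p(-p))
    return lls

# Usage: sketch_bernoulli_fit([1, 0, 1, 1]) = (3 + 0.1) / (4 + 0.2), about 0.738.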


class Categorical(Leaf):
    LEAF_TYPE = LeafType.DISCRETE

    def __init__(
        self,
        scope: int,
        categories: Optional[Union[List, np.ndarray]] = None,
        probabilities: Optional[Union[List, np.ndarray]] = None
    ):
        """
        Initialize a Categorical leaf node given its scope.

        :param scope: The scope of the leaf.
        :param categories: The possible categories.
        :param probabilities: The probabilities associated with each category.
        :raises ValueError: If the parameters are inconsistent or only partially defined.
        """
        super().__init__(scope)
        self.categories = None
        self.probabilities = None
        self.distribution = None

        if categories is not None and probabilities is not None:
            if len(categories) != len(probabilities):
                raise ValueError("Each category must be associated with a probability")
            if not np.isclose(np.sum(probabilities), 1.0):
                raise ValueError("Probabilities parameter must sum up to 1")
            self.categories = np.array(categories, np.int64)
            self.probabilities = np.array(probabilities, np.float32)
            self.distribution = ss.rv_discrete(values=(self.categories, self.probabilities))
        elif categories is not None or probabilities is not None:
            raise ValueError("Partially defined parameters (categories, probabilities) are not handled")

    def fit(self, data: np.ndarray, domain: list, alpha: float = 0.1, **kwargs):
        """
        Fit the distribution parameters given the domain and some training data.

        :param data: The training data.
        :param domain: The domain of the distribution leaf.
        :param alpha: The Laplace smoothing factor.
        :param kwargs: Optional parameters.
        :raises ValueError: If a parameter is out of domain.
        """
        if not isinstance(domain, list):
            raise ValueError("The domain must be categorical for a Categorical distribution")
        if alpha < 0.0:
            raise ValueError("The Laplace smoothing factor must be non-negative")

        self.probabilities = np.empty(len(domain), np.float32)
        for i, d in enumerate(domain):
            self.probabilities[i] = (len(data[data == d]) + alpha) / (len(data) + len(domain) * alpha)
        self.categories = np.array(domain, np.int64)
        self.distribution = ss.rv_discrete(values=(self.categories, self.probabilities))

    def em_init(self, random_state: np.random.RandomState):
        """
        Randomly initialize the leaf's parameters for Expectation-Maximization (EM).

        :param random_state: The random state.
        :raises ValueError: If the categories are not initialized.
        """
        if self.categories is None:
            raise ValueError("Categorical leaf distribution is not initialized")

        # Initialize the category probabilities by sampling from a flat Dirichlet distribution
        self.probabilities = random_state.dirichlet(np.ones(len(self.categories)))
        self.distribution = ss.rv_discrete(values=(self.categories, self.probabilities))

    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        alpha = np.finfo(np.float16).eps  # Use a very small Laplace smoothing factor
        data = np.squeeze(data, axis=1)
        total_stats = np.sum(stats)

        # Compute the probabilities for each category
        probabilities = np.empty(len(self.categories), np.float32)
        for i, d in enumerate(self.categories):
            probabilities[i] = (np.sum(stats[data == d]) + alpha) / (total_stats + len(self.categories) * alpha)

        # Update the parameters
        self.probabilities = (1.0 - step_size) * self.probabilities + step_size * probabilities
        self.distribution = ss.rv_discrete(values=(self.categories, self.probabilities))

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        ls = np.ones([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        ls[~mask] = self.distribution.pmf(x[~mask].astype(np.int64, copy=False))
        return ls

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        lls = np.zeros([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        lls[~mask] = self.distribution.logpmf(x[~mask].astype(np.int64, copy=False))
        return lls

    def mpe(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = self.categories[self.probabilities.argmax()]
        return x

    def sample(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = self.distribution.rvs(size=np.count_nonzero(mask))
        return x

    def moment(self, k: int = 1) -> float:
        return self.distribution.moment(k)

    def params_count(self) -> int:
        return 2 * len(self.categories)

    def params_dict(self) -> dict:
        if self.distribution is None:
            return {'categories': None, 'probabilities': None}
        return {'categories': self.categories, 'probabilities': self.probabilities}
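
# Illustrative sketch (not part of the library): the Categorical leaf uses the
# K-category generalization of the Bernoulli estimates,
#     P(X = d) = (count(d) + alpha) / (n + K * alpha),
# and, in EM, the same formula with counts replaced by summed sample statistics.
# The probabilities sum to one only if every data value belongs to the domain.
# The function names are hypothetical; the EM alpha defaults to 0 here for
# readability, whereas em_step above uses a tiny positive value.
import numpy as np


def sketch_categorical_fit(data, domain, alpha: float = 0.1) -> np.ndarray:
    data = np.asarray(data)
    counts = np.array([np.sum(data == d) for d in domain], dtype=np.float64)
    return (counts + alpha) / (len(data) + len(domain) * alpha)


def sketch_categorical_em_probs(stats, data, categories, alpha: float = 0.0) -> np.ndarray:
    stats = np.asarray(stats, dtype=np.float64)
    data = np.asarray(data)
    # Sum the statistics of the samples falling into each category
    weighted = np.array([np.sum(stats[data == c]) for c in categories], dtype=np.float64)
    return (weighted + alpha) / (np.sum(stats) + len(categories) * alpha)

# Usage: sketch_categorical_fit([0, 1, 1, 2, 2, 2], [0, 1, 2]) gives
# [1.1, 2.1, 3.1] / 6.3, which sums to one.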


class Isotonic(Leaf):
    LEAF_TYPE = LeafType.CONTINUOUS

    def __init__(
        self,
        scope: int,
        densities: Optional[Union[List[float], np.ndarray]] = None,
        breaks: Optional[Union[List[float], np.ndarray]] = None
    ):
        """
        Initialize an Isotonic (histogram) leaf node given its scope.

        :param scope: The scope of the leaf.
        :param densities: The probability mass of each histogram bin. They must sum up to one.
        :param breaks: The breaks values, such that len(breaks) == len(densities) + 1.
        :raises ValueError: If a parameter is out of domain.
        """
        super().__init__(scope)
        self.densities = None
        self.breaks = None
        self.distribution = None

        if densities is not None and breaks is not None:
            if len(breaks) != len(densities) + 1:
                raise ValueError("Invalid histogram parameters shapes")
            if not np.isclose(np.sum(densities), 1.0):
                raise ValueError("Densities parameter must sum up to 1")
            if isinstance(densities, list):
                densities = np.array(densities, np.float32)
            if isinstance(breaks, list):
                breaks = np.array(breaks, np.float32)
            self.densities = densities
            self.breaks = breaks
            self.distribution = ss.rv_histogram(histogram=(densities, breaks))
        elif densities is not None or breaks is not None:
            raise ValueError("Partially defined parameters (densities, breaks) are not handled")

    def fit(self, data: np.ndarray, domain: tuple, alpha: float = 0.1, **kwargs):
        """
        Fit the distribution parameters given the domain and some training data.

        :param data: The training data.
        :param domain: The domain of the distribution leaf.
        :param alpha: The Laplace smoothing factor.
        :param kwargs: Optional parameters.
        :raises ValueError: If a parameter is out of domain.
        """
        if not isinstance(domain, tuple):
            raise ValueError("The domain must be continuous for an Isotonic distribution")
        if alpha < 0.0:
            raise ValueError("The Laplace smoothing factor must be non-negative")
        histogram, breaks = np.histogram(data, bins='fd')

        # Apply Laplace smoothing and obtain the densities
        densities = (histogram + alpha) / (len(data) + len(histogram) * alpha)
        densities = densities.astype(np.float32, copy=False)
        breaks = breaks.astype(np.float32, copy=False)

        # Build the distribution
        self.densities = densities
        self.breaks = breaks
        self.distribution = ss.rv_histogram(histogram=(densities, breaks))

    def em_init(self, random_state: np.random.RandomState):
        raise NotImplementedError("EM parameters initialization not yet implemented for Isotonic distributions")

    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        raise NotImplementedError("EM step not yet implemented for Isotonic distributions")

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        ls = np.ones([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        ood_mask = ~mask & ((x <= self.distribution.a) | (x >= self.distribution.b))
        ls[~mask] = self.distribution.pdf(x[~mask])
        ls[ood_mask] = np.finfo(np.float32).eps
        return ls

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        lls = np.zeros([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        ood_mask = ~mask & ((x <= self.distribution.a) | (x >= self.distribution.b))
        lls[~mask] = self.distribution.logpdf(x[~mask])
        lls[ood_mask] = np.log(np.finfo(np.float32).eps)
        return lls

    def mpe(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        idx = np.argmax(self.densities)
        x[mask] = (self.breaks[idx] + self.breaks[idx + 1]) / 2.0
        return x

    def sample(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = self.distribution.ppf(q=np.random.rand(np.count_nonzero(mask)))
        return x

    def moment(self, k: int = 1) -> float:
        return self.distribution.moment(k)

    def params_count(self) -> int:
        return 2 * len(self.densities) + 1

    def params_dict(self) -> dict:
        if self.distribution is None:
            return {'densities': None, 'breaks': None}
        return {'densities': self.densities, 'breaks': self.breaks}
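
# Illustrative sketch (not part of the library): the Isotonic leaf is a histogram
# density. Bin edges come from NumPy's Freedman-Diaconis rule (bins='fd'), the bin
# counts are Laplace-smoothed into probability masses that sum to one, and the
# density inside a bin is its mass divided by the bin width. Unlike the leaf, which
# assigns a small epsilon to out-of-range values, this sketch returns 0 for them.
# The function names are hypothetical.
import numpy as np


def sketch_histogram_fit(data, alpha: float = 0.1):
    data = np.asarray(data, dtype=np.float64).ravel()
    counts, breaks = np.histogram(data, bins='fd')
    masses = (counts + alpha) / (len(data) + len(counts) * alpha)
    return masses, breaks


def sketch_histogram_pdf(x, masses, breaks) -> np.ndarray:
    x = np.asarray(x, dtype=np.float64)
    widths = np.diff(breaks)
    inside = (x >= breaks[0]) & (x < breaks[-1])
    # Index of the bin containing each value (only meaningful where inside is True)
    idx = np.searchsorted(breaks, x, side='right') - 1
    pdf = np.zeros_like(x)
    pdf[inside] = masses[idx[inside]] / widths[idx[inside]]
    return pdf

# Usage: up to floating-point error, the density integrates to one, i.e. the
# density at the bin midpoints times the bin widths sums to 1.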


class Uniform(Leaf):
    LEAF_TYPE = LeafType.CONTINUOUS

    def __init__(self, scope: int, start: float = 0.0, width: float = 1.0):
        """
        Initialize a Uniform leaf node given its scope.

        :param scope: The scope of the leaf.
        :param start: The start of the uniform distribution.
        :param width: The width of the uniform distribution.
        """
        super().__init__(scope)
        self.start = start
        self.width = width

    def fit(self, data: np.ndarray, domain: tuple, **kwargs):
        if not isinstance(domain, tuple):
            raise ValueError("The domain must be continuous for a Uniform distribution")

        # Estimate the parameters of a uniform distribution
        self.start, self.width = ss.uniform.fit(data)

    def em_init(self, random_state: np.random.RandomState):
        raise NotImplementedError("EM parameters initialization not yet implemented for Uniform distributions")

    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        raise NotImplementedError("EM step not yet implemented for Uniform distributions")

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        ls = np.ones([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        ls[~mask] = ss.uniform.pdf(x[~mask], self.start, self.width)
        return ls

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        lls = np.zeros([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        lls[~mask] = ss.uniform.logpdf(x[~mask], self.start, self.width)
        return lls

    def mpe(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = self.start
        return x

    def sample(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = ss.uniform.rvs(self.start, self.width, size=np.count_nonzero(mask))
        return x

    def moment(self, k: int = 1) -> float:
        return ss.uniform.moment(k, self.start, self.width)

    def params_count(self) -> int:
        return 2

    def params_dict(self) -> dict:
        return {
            'start': self.start,
            'width': self.width
        }
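
# Illustrative sketch (not part of the library): for a uniform distribution the
# maximum-likelihood estimate is the narrowest interval covering the data, i.e.
# start = min(data) and width = max(data) - min(data); every observed value then
# has density 1 / width. The Uniform leaf delegates the fit to
# scipy.stats.uniform.fit, which performs maximum-likelihood estimation by default.
# The function names are hypothetical.
import numpy as np


def sketch_uniform_fit(data):
    data = np.asarray(data, dtype=np.float64)
    start = float(np.min(data))
    return start, float(np.max(data)) - start


def sketch_uniform_log_likelihood(x, start: float, width: float) -> np.ndarray:
    x = np.asarray(x, dtype=np.float64)
    inside = (x >= start) & (x <= start + width)
    # Constant log-density inside the support, -inf outside of it
    return np.where(inside, -np.log(width), -np.inf)

# Usage: sketch_uniform_fit([2.0, 5.0, 3.5]) returns (2.0, 3.0).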


class Gaussian(Leaf):
    LEAF_TYPE = LeafType.CONTINUOUS

    def __init__(self, scope: int, mean: float = 0.0, stddev: float = 1.0):
        """
        Initialize a Gaussian leaf node given its scope.

        :param scope: The scope of the leaf.
        :param mean: The mean parameter.
        :param stddev: The standard deviation parameter.
        :raises ValueError: If a parameter is out of domain.
        """
        super().__init__(scope)
        if stddev <= 1e-5:
            raise ValueError("The standard deviation of a Gaussian must be greater than 1e-5")

        self.mean = mean
        self.stddev = stddev

    def fit(self, data: np.ndarray, domain: tuple, **kwargs):
        if not isinstance(domain, tuple):
            raise ValueError("The domain must be continuous for a Gaussian distribution")

        self.mean, self.stddev = ss.norm.fit(data)
        self.stddev = max(self.stddev, 1e-5)

    def em_init(self, random_state: np.random.RandomState):
        self.mean = 1e-1 * random_state.randn()
        self.stddev = 0.5 + 1e-1 * np.tanh(random_state.randn())

    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        data = np.squeeze(data, axis=1)
        total_stats = np.sum(stats) + np.finfo(np.float32).eps
        mean = np.sum(stats * data) / total_stats
        stddev = np.sqrt(np.sum(stats * (data - mean) ** 2.0) / total_stats)
        stddev = max(stddev, 1e-5)

        # Update the parameters
        self.mean = (1.0 - step_size) * self.mean + step_size * mean
        self.stddev = (1.0 - step_size) * self.stddev + step_size * stddev

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        ls = np.ones([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        ls[~mask] = ss.norm.pdf(x[~mask], self.mean, self.stddev)
        return ls

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        lls = np.zeros([len(x), 1], dtype=np.float32)
        mask = np.isnan(x)
        lls[~mask] = ss.norm.logpdf(x[~mask], self.mean, self.stddev)
        return lls

    def mpe(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = self.mean
        return x

    def sample(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mask = np.isnan(x)
        x[mask] = ss.norm.rvs(self.mean, self.stddev, size=np.count_nonzero(mask))
        return x

    def moment(self, k: int = 1) -> float:
        return ss.norm.moment(k, self.mean, self.stddev)

    def params_count(self) -> int:
        return 2

    def params_dict(self) -> dict:
        return {
            'mean': self.mean,
            'stddev': self.stddev
        }
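
# Illustrative sketch (not part of the library): the Gaussian leaf's EM step is a
# weighted maximum-likelihood estimate, where each sample is weighted by its
# statistic (e.g. its responsibility), followed by a convex combination with the
# current parameters. As in em_step above, the standard deviation (not the
# variance) is interpolated, and it is clipped below at 1e-5. The function name
# is hypothetical.
import numpy as np


def sketch_gaussian_em_update(mean: float, stddev: float, stats, data, step_size: float,
                              min_stddev: float = 1e-5):
    stats = np.asarray(stats, dtype=np.float64)
    data = np.asarray(data, dtype=np.float64)
    total = np.sum(stats) + np.finfo(np.float32).eps  # avoid dividing by zero
    new_mean = np.sum(stats * data) / total
    new_stddev = max(float(np.sqrt(np.sum(stats * (data - new_mean) ** 2) / total)), min_stddev)
    return ((1.0 - step_size) * mean + step_size * float(new_mean),
            (1.0 - step_size) * stddev + step_size * new_stddev)

# Usage: with stats [0, 0, 1, 1] on data [1, 2, 3, 4] and step_size=1, only the last
# two samples count, giving a mean of about 3.5 and a standard deviation of about 0.5.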




In [None]:
#@title tree


class TreeNode:
    """A simple class to model a node of a tree."""
    def __init__(self, node_id: int, parent: Optional["TreeNode"] = None):
        """
        Initialize a tree node.

        :param node_id: The ID of the node.
        :param parent: The parent node.
        """
        self.id = node_id
        self.__parent = None
        self.__children = []
        self.set_parent(parent)

    def get_id(self) -> int:
        """
        Get the ID of the node.

        :return: The ID of the node.
        """
        return self.id

    def get_parent(self) -> Optional["TreeNode"]:
        """
        Get the parent node.

        :return: The parent node, None if the node has no parent.
        """
        return self.__parent

    def get_children(self) -> List["TreeNode"]:
        """
        Get the children list of the node.

        :return: The children list of the node.
        """
        return self.__children

    def set_parent(self, parent: Optional["TreeNode"]):
        """
        Set the parent node, if none has been set yet, and update the parent's children list.

        :param parent: The parent node.
        """
        if self.__parent is None and parent is not None:
            self.__parent = parent
            self.__parent.get_children().append(self)

    def is_leaf(self) -> bool:
        """
        Check whether the node is leaf.

        :return: True if the node is leaf, False otherwise.
        """
        return len(self.__children) == 0

    def get_n_nodes(self) -> int:
        """
        Get the number of nodes of the tree rooted at self.

        :return: The number of nodes of the tree rooted at self.
        """
        n_nodes = 0
        queue = deque([self])
        while queue:
            current_node = queue.popleft()
            queue.extend(current_node.get_children())
            n_nodes += 1
        return n_nodes

    def get_tree_scope(self) -> Tuple[list, list]:
        """
        Return the list of predecessors and the related scope of the tree rooted at self.
        In the returned list of predecessors, tree[root] is -1, as the root has no predecessor.

        :return: A tuple (tree, scope), where scope lists the node IDs in BFS order and
                 tree lists, for each of them, the index in scope of its predecessor.
        """
        tree = []
        scope = []
        queue = deque([self])
        while queue:
            current_node = queue.popleft()
            queue.extend(current_node.get_children())
            scope.append(current_node.id)
            tree.append(current_node.get_parent().id if current_node.get_parent() is not None else -1)
        tree[scope.index(self.id)] = -1
        tree = [scope.index(t) if t != -1 else -1 for t in tree]
        return tree, scope


def build_tree_structure(tree: Union[List[int], np.ndarray], scope: Optional[List[int]] = None) -> TreeNode:
    """
    Build a recursive TreeNode data structure given a tree structure encoded as a list of predecessors.
    Note that tree[root] must be -1, as it doesn't have a predecessor.
    Optionally, a scope can be used to specify the tree node ids.

    :param tree: The tree structure, as a sequence of predecessors.
    :param scope: An optional scope, as a list of ids.
    :return: The Tree node structure's root.
    :raises ValueError: If the tree structure is not compatible with the root node.
    :raises ValueError: If the scope contains duplicates.
    :raises ValueError: If the scope is incompatible with the tree structure.
    """
    if isinstance(tree, np.ndarray):
        tree = tree.tolist()
    if tree.count(-1) != 1:
        raise ValueError("Invalid tree structure")
    root_idx = tree.index(-1)

    if scope is None:
        root_id = root_idx
        nodes = [TreeNode(node_id) for node_id in range(len(tree))]
        for node_id, parent_id in enumerate(tree):
            if parent_id != -1:
                nodes[node_id].set_parent(nodes[parent_id])
    else:
        if len(set(scope)) != len(scope):
            raise ValueError("The scope must not contain duplicates")
        if len(scope) != len(tree):
            raise ValueError("Invalid scope's number of variables")

        root_id = scope[root_idx]
        nodes = {node_id: TreeNode(node_id) for node_id in scope}
        for node_idx, parent_idx in enumerate(tree):
            if parent_idx != -1:
                node_id = scope[node_idx]
                parent_id = scope[parent_idx]
                nodes[node_id].set_parent(nodes[parent_id])

    return nodes[root_id]


def compute_bfs_ordering(tree: Union[List[int], np.ndarray]) -> Union[List[int], np.ndarray]:
    """
    Compute the breadth-first-search variable ordering given a tree structure.
    Note that tree[root] must be -1, as it doesn't have a predecessor.

    :param tree: The tree structure, as a sequence of predecessors.
    :return: The BFS variable ordering, as a list if tree is a list, or as a NumPy array otherwise.
    """
    # Build the tree structure first
    root = build_tree_structure(tree)

    # Breadth-first (level-order) exploration
    ordering = list()
    nodes_queue = deque([root])
    while nodes_queue:
        node = nodes_queue.popleft()
        ordering.append(node.get_id())
        if not node.is_leaf():
            nodes_queue.extend(node.get_children())

    if isinstance(tree, list):
        return ordering
    return np.array(ordering, dtype=tree.dtype)
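
# Illustrative sketch (not part of the library): a tree encoded as a list of
# predecessors (tree[i] is the parent of node i, and tree[root] == -1) can be
# traversed breadth-first by first inverting it into children lists. Children are
# listed in increasing index order, matching how build_tree_structure attaches
# them. The function name is hypothetical.
from collections import deque
from typing import List


def sketch_bfs_from_predecessors(tree: List[int]) -> List[int]:
    if tree.count(-1) != 1:
        raise ValueError("Invalid tree structure")
    children = [[] for _ in tree]
    for node, parent in enumerate(tree):
        if parent != -1:
            children[parent].append(node)
    ordering, queue = [], deque([tree.index(-1)])
    while queue:
        node = queue.popleft()
        ordering.append(node)
        queue.extend(children[node])
    return ordering

# Usage: sketch_bfs_from_predecessors([1, -1, 1, 0]) returns [1, 0, 2, 3]:
# node 1 is the root, nodes 0 and 2 are its children, and node 3 is a child of 0.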


def maximum_spanning_tree(root: int, adj_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the maximum spanning tree of a graph starting from a given root node.

    :param root: The root node index.
    :param adj_matrix: The graph's adjacency matrix.
    :return: The breadth first traversal ordering and the maximum spanning tree.
             The maximum spanning tree is given as a list of predecessors.
    """
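    # SciPy only computes minimum spanning trees, so the weights are negated to get a maximum one.
    # One is added before negating so that no entry is zero (assuming non-negative weights, e.g.
    # mutual information): csgraph treats zero entries as missing edges, and the graph must be fully connected
    mst = sp.csgraph.minimum_spanning_tree(-(adj_matrix + 1.0), overwrite=True)
    bfs, tree = sp.csgraph.breadth_first_order(
        mst, directed=False, i_start=root, return_predecessors=True
    )
    # SciPy marks the missing predecessor of the root with -9999: use -1 instead
    tree[root] = -1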
    return bfs, tree
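
# Illustrative sketch (not part of the library): maximum_spanning_tree above negates
# the weights and relies on SciPy's minimum spanning tree. A predecessor list of the
# same form can be obtained directly with Prim's algorithm on a dense, symmetric
# weight matrix, growing the tree from the root and always adding the heaviest edge
# that reaches a new node. This swaps in Prim's algorithm for illustration only;
# with tied weights it may pick a different (equally heavy) tree than SciPy.
# The function name is hypothetical.
import numpy as np


def sketch_maximum_spanning_tree(root: int, weights: np.ndarray) -> np.ndarray:
    weights = np.asarray(weights, dtype=np.float64)
    n = weights.shape[0]
    in_tree = np.zeros(n, dtype=bool)
    in_tree[root] = True
    best = weights[root].copy()  # heaviest known edge from the tree to each node
    parent = np.full(n, root, dtype=np.int64)
    parent[root] = -1
    for _ in range(n - 1):
        v = int(np.argmax(np.where(in_tree, -np.inf, best)))
        in_tree[v] = True
        # A heavier edge through v may now reach some of the remaining nodes
        improve = ~in_tree & (weights[v] > best)
        best[improve] = weights[v][improve]
        parent[improve] = v
    return parent

# Usage: for weights [[0, 5, 1], [5, 0, 3], [1, 3, 0]] and root 0, the result is
# [-1, 0, 1]: edges 0-1 (weight 5) and 1-2 (weight 3).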



In [None]:
#@title partition


class Partition:

    def __init__(
        self,
        row_ids: list,
        col_ids: list,
        uncond_vars: list,
        parent_partition: Optional["Partition"] = None,
        is_naive: Optional[bool] = False,
        is_conj: Optional[bool] = False,
    ):
        """
        Create a partition, i.e. an object modeling a data slice (and some of its properties)
        by keeping track of its indices (i.e. row_ids and col_ids).

        :param row_ids: The row indices of the modeled slice.
        :param col_ids: The column indices of the modeled slice.
        :param uncond_vars: Ordered list of variables from which the conjunction variables will be extracted
         to horizontally split the current partition.
        :param parent_partition: The optional parent partition.
        :param is_naive: If True and determinism is not required, a naive factorization will be learnt over the data
         slice modeled by the current partition; otherwise, if True and determinism is required, a disjunction
         will be learnt over the data slice modeled by the current partition.
        :param is_conj: True if the modeled slice is associated to a conjunction, i.e. every row in the slice is
         equal to the others.
        """
        self.row_ids = np.array(row_ids)
        self.col_ids = np.array(col_ids)
        self.uncond_vars = list(uncond_vars)

        self.parent_partition = parent_partition
        self.set_parent_partition(parent_partition)
        self.sub_partitions = []

        self.is_naive = is_naive
        self.is_conj = is_conj
        # discarded assignments, see build_leaf() in xpc.py
        self.disc_assignments = None

    def set_parent_partition(self, parent_partition: Optional["Partition"]):
        """
        Set the parent partition and register self in its sub_partitions attribute.

        :param parent_partition: The parent partition (None for a root partition).
        """
        self.parent_partition = parent_partition
        if parent_partition is not None:
            parent_partition.sub_partitions.append(self)

    def is_partitioned(self):
        """
        :return: True if the partition is partitioned, False otherwise.
        """
        return len(self.sub_partitions) != 0

    def is_horizontally_partitioned(self):
        """
        :return: True if the partition is horizontally partitioned, False otherwise.
        """
        ret = False
        if self.is_partitioned():
            ret = len(self.row_ids) > len(self.sub_partitions[0].row_ids)
        return ret

    def get_slice(self, data: np.ndarray) -> np.ndarray:
        """
        Slice the input data matrix according to self.

        :param data: The data to be sliced.
        :return: The data slice.
        """
        return data[self.row_ids][:, self.col_ids]

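    def get_vertical_split(self) -> List[np.ndarray]:
        """
        If possible, split the current partition vertically into conditioned and unconditioned variables.

        :return: A list [cond_vars, uncond_vars] of column ids, or an empty list if no vertical split is possible.
        """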
        vertical_split = []
        cond_vars = [col_id for col_id in self.col_ids if col_id not in self.uncond_vars]
        if len(cond_vars) != 0 and len(cond_vars) != len(self.col_ids):
            vertical_split = [np.asarray(cond_vars), np.asarray(self.uncond_vars)]
        return vertical_split

    def get_conj_row_ids(
        self,
        data: np.ndarray,
        conj: list,
        min_part_inst: int,
    ) -> np.ndarray:
        """
        Return the row ids of the instances satisfying the given conjunction.
        The row ids must be found within the slice modeled by the self partition.

        :param data: The input data.
        :param conj: Conjunction modeled as a list of two lists: the first contains
                     the IDs of the variables, the second the related assignment.
                     For example, [[8,3],[1,0]] models the conjunction X8=1 and X3=0.
        :param min_part_inst: The minimum number of instances allowed to return.
        :return: The row ids of the instances satisfying the given conjunction iff
                 the number of such instances is greater than or equal to the minimum number
                 of instances allowed to return; otherwise, an empty array.
        """
        if len(self.row_ids) < min_part_inst:
            conj_row_ids = np.empty(0, dtype=np.int32)
        else:
            conj_row_ids = self.row_ids.copy()
            for i in range(len(conj[0])):
                conj_row_ids = conj_row_ids[np.where(data[np.array(conj_row_ids), conj[0][i]] == conj[1][i])[0]]
                if len(conj_row_ids) < min_part_inst:
                    conj_row_ids = np.empty(0, dtype=np.int32)
                    break
        return conj_row_ids

    def get_horizontal_split(
        self,
        data: np.ndarray,
        min_part_inst: int,
        conj_len: int,
        arity: int,
        sd: bool,
        random_state: np.random.RandomState
    ) -> Tuple[list, np.ndarray, list]:
        """
        If possible, split horizontally the current partition.

        :param data: The input data matrix.
        :param min_part_inst: The minimum number of instances allowed per partition.
        :param conj_len: The conjunction length.
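        :param arity: The maximum number of sub-partitions of a horizontally split partition.
        :param sd: True if the generated tree will be used to model an SD PC, False otherwise.
        :param random_state: The random state.
        :return: A tuple (conj_row_ids_l, discarded_row_ids, conj_vars): the row ids of every conjunction
         sub-partition, the row ids not covered by any of them, and the conjunction variables.
         Empty values are returned if no horizontal split is possible (all the rows are returned as
         discarded if they all satisfy the same assignment).
        """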
        if len(self.uncond_vars) < conj_len or len(self.row_ids) < 2 * min_part_inst:
            return [], np.array([]), []

        uncond_vars = self.uncond_vars.copy()
        if not sd:
            random_state.shuffle(uncond_vars)
        conj_vars = uncond_vars[:conj_len]

        # list of all possible assignments for a conjunction with length conj_len
        assignments = [list(assignment) for assignment in itertools.product([0, 1], repeat=len(conj_vars))]
        random_state.shuffle(assignments)

        discarded_row_ids = self.row_ids.copy()
        conj_row_ids_l = []
        for assignment in assignments:
            conj = [conj_vars, assignment]
            conj_row_ids = self.get_conj_row_ids(data, conj, min_part_inst)
            if len(conj_row_ids) == len(self.row_ids):
                return [], discarded_row_ids, conj_vars
            if len(conj_row_ids) != 0 and len(discarded_row_ids) - len(conj_row_ids) >= min_part_inst:
                discarded_row_ids = np.setdiff1d(discarded_row_ids, conj_row_ids)
                conj_row_ids_l.append(conj_row_ids)
                if len(conj_row_ids_l) == arity - 1 or len(discarded_row_ids) < 2 * min_part_inst:
                    break

        if conj_row_ids_l:
            return conj_row_ids_l, discarded_row_ids, conj_vars
        return [], np.array([]), []


def generate_random_partitioning(
    data: np.ndarray,
    min_part_inst: int,
    n_max_parts: int,
    conj_len: int,
    arity: int,
    sd: bool,
    uncond_vars: list,
    random_state: np.random.RandomState
):
    """
    Create a random partition tree.

    :param data: The input data matrix.
    :param min_part_inst: The minimum number of instances allowed per partition.
    :param n_max_parts: The maximum number of partitions in the tree.
    :param conj_len: The conjunction length.
    :param arity: The maximum number of sub-partitions of a horizontally split partition.
    :param sd: True if the generated tree will be used to model an SD PC, False otherwise.
    :param uncond_vars: Ordered list of variables from which the first *conj_len* ones
     are extracted as conjunction variables to partition the root partition.
    :param random_state: The random state.

    :return partition_root: The partition root of the tree.
    :return cl_parts_l: List containing the leaf partitions over which a CLTree will be learnt.
    :return conj_vars_l: List of lists. Every sublist contains the variables of a conjunction (e.g. [[3, 5]]).
     If a sublist occurs before another, then the former has been used first. There are no duplicates.
    :return n_partitions: The number of partitions in the generated tree.
    """
    partition_root = Partition(row_ids=np.arange(data.shape[0]),
                               col_ids=uncond_vars,
                               uncond_vars=uncond_vars,
                               parent_partition=None)
    n_partitions = 0
    conj_vars_l = []
    cl_parts_l = []
    leaves = [partition_root]
    while leaves and n_partitions + len(leaves) < n_max_parts:
        # randomly pop a leaf partition
        part = leaves.pop(random_state.randint(len(leaves)))

        conj_row_ids_l, discarded_row_ids, conj_vars = \
            part.get_horizontal_split(data, min_part_inst, conj_len, arity, sd, random_state)

        if len(discarded_row_ids):
            if conj_vars not in conj_vars_l:
                conj_vars_l.append(conj_vars)

            # this ensures a general definition of the list uncond_vars, preserving its order
            uncond_vars = [uv for uv in part.uncond_vars if uv not in conj_vars]

            part_buffer = [
                Partition(row_ids=discarded_row_ids,
                          col_ids=part.col_ids.copy(),
                          uncond_vars=uncond_vars.copy(),
                          parent_partition=part)]

            for conj_row_ids in conj_row_ids_l:
                part_buffer.append(
                    Partition(row_ids=conj_row_ids,
                              col_ids=part.col_ids.copy(),
                              uncond_vars=uncond_vars.copy(),
                              parent_partition=part))

            discarded_assignments = \
                {tuple(assignment) for assignment in itertools.product([0, 1], repeat=len(conj_vars))}

            for k in range(len(part_buffer)):
                part = part_buffer[k]
                vertical_split = part.get_vertical_split()
                if vertical_split:
                    n_partitions += 1
                    is_conj = k > 0  # part_buffer[0] holds the discarded rows, the others are conjunctions
                    p = Partition(row_ids=part.row_ids.copy(),
                                  col_ids=vertical_split[0].copy(),
                                  uncond_vars=[],
                                  parent_partition=part,
                                  is_naive=True,
                                  is_conj=is_conj)
                    if is_conj:
                        discarded_assignments.remove(tuple(p.get_slice(data)[0]))

                    leaves.append(
                        Partition(row_ids=part.row_ids.copy(),
                                  col_ids=vertical_split[1].copy(),
                                  uncond_vars=vertical_split[1].copy(),
                                  parent_partition=part))
                else:
                    leaves.append(part)

            if part_buffer[0].sub_partitions:
                part_buffer[0].sub_partitions[0].disc_assignments = \
                    np.array(list(discarded_assignments))
        else:
            n_partitions += 1
            cl_parts_l.append(part)

    # in case the process ended because n_partitions + len(leaves) >= n_max_parts, the remaining leaves become CLT partitions
    n_partitions += len(leaves)
    cl_parts_l.extend(leaves)
    return partition_root, cl_parts_l, conj_vars_l, n_partitions
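`get_horizontal_split` groups the rows of a partition by the assignment they take on the conjunction variables. It keeps a group only if both the group and the leftover rows have at least `min_part_inst` instances.

The sketch below shows that grouping on binary data in isolation. The name `conjunction_split` is hypothetical. Unlike the class method, it filters only the *remaining* rows. This yields the same groups, because different assignments select disjoint rows. The sketch also omits the special case where every row satisfies the same assignment.

```python
import itertools

import numpy as np


def conjunction_split(data, row_ids, conj_vars, min_part_inst, arity, rng):
    """Hypothetical sketch: group `row_ids` by their assignment to the binary `conj_vars`.

    Returns a list of (assignment, row_ids) groups and the remaining (discarded) row ids.
    """
    assignments = [list(a) for a in itertools.product([0, 1], repeat=len(conj_vars))]
    rng.shuffle(assignments)  # visit the assignments in random order
    remaining = np.asarray(row_ids)
    groups = []
    for assignment in assignments:
        # Rows of the remaining slice that satisfy the conjunction
        match = np.all(data[np.ix_(remaining, conj_vars)] == assignment, axis=1)
        selected = remaining[match]
        # Keep the group only if both it and the leftover rows are large enough
        if len(selected) >= min_part_inst and len(remaining) - len(selected) >= min_part_inst:
            groups.append((tuple(assignment), selected))
            remaining = remaining[~match]
            # At most arity - 1 conjunction groups: the leftover rows form the last sub-partition
            if len(groups) == arity - 1 or len(remaining) < 2 * min_part_inst:
                break
    return groups, remaining


# Every assignment of (X0, X1) occurs exactly twice
data = np.array([[0, 0], [0, 0], [0, 1], [0, 1], [1, 0], [1, 0], [1, 1], [1, 1]])
groups, remaining = conjunction_split(
    data, np.arange(len(data)), conj_vars=[0, 1], min_part_inst=2, arity=3,
    rng=np.random.RandomState(42),
)
print([(a, r.tolist()) for a, r in groups], remaining.tolist())
```

With `arity=3`, two conjunction groups are kept, and the four remaining rows play the role of the discarded partition.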

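Two smaller bookkeeping steps of `generate_random_partitioning` are sketched below in isolation. The helper names are hypothetical.

- The vertical split separates the variables already used in a conjunction from the still unconditioned ones.
- The discarded assignments are, roughly, the conjunction assignments that did not receive their own sub-partition. These are what end up in `disc_assignments`.

```python
import itertools


def vertical_split_sketch(col_ids, uncond_vars):
    """Hypothetical helper: split columns into conditioned and unconditioned ones.

    Returns None when the split would leave one side empty.
    """
    cond_vars = [c for c in col_ids if c not in uncond_vars]
    if 0 < len(cond_vars) < len(col_ids):
        return cond_vars, list(uncond_vars)
    return None


def discarded_assignments_sketch(conj_len, used_assignments):
    """Hypothetical helper: conjunction assignments without a dedicated sub-partition."""
    all_assignments = set(itertools.product([0, 1], repeat=conj_len))
    return sorted(all_assignments - {tuple(a) for a in used_assignments})


print(vertical_split_sketch([0, 1, 2, 3], uncond_vars=[2, 3]))  # ([0, 1], [2, 3])
print(discarded_assignments_sketch(2, [(0, 1), (1, 1)]))        # [(0, 0), (1, 0)]
print(vertical_split_sketch([2, 3], uncond_vars=[2, 3]))        # None
```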
In [None]:
#@title random graph generation

class RegionGraph:
    def __init__(self, n_features: int, depth: int, random_state: Optional["RandomState"] = None):
        """
        Initialize a region graph.

        A region graph is defined w.r.t. a set of indices of random variables in an SPN.
        A *region* R is defined as a non-empty subset of the indices,
        and is represented as a sorted tuple with unique entries.
        A *partition* P of a region R is defined as a collection of non-empty, non-overlapping sets
        whose union is R. R is also called the *parent* region of P.
        Any region C such that C is in partition P is called a *child region* of P.
        So, a region is represented as a sorted tuple of integers (unique elements)
        and a partition is represented as a sorted tuple of regions (non-overlapping, non-empty, at least 2).
        A *region graph* is an acyclic, directed, bipartite graph over regions and partitions.
        So, any child of a region R is a partition of R, and any child of a partition is
        a child region of the partition. The root of the region graph
        is a sorted tuple composed of all the elements. The leaves of the region graph must also be regions.
        They are called input regions, or leaf regions.
        Given a region graph, we can easily construct a corresponding SPN:
        1) Associate I distributions to each input region.
        2) Associate K sum nodes to each other (non-input) region.
        3) For each partition P in the region graph,
        take all cross-products (as product nodes) of distributions/sum nodes associated with the child regions.
        Connect these products as children of all sum nodes in the parent region of P.
        In the end, this procedure will always deliver a complete and decomposable SPN.

        :param n_features: The number of features.
        :param depth: The maximum depth.
        :param random_state: The random state. It can be either None, a seed integer or a Numpy RandomState.
        :raises ValueError: If a parameter is out of domain.
        """
        if n_features <= 0:
            raise ValueError("The number of features must be positive")
        if depth <= 0:
            raise ValueError("The region graph depth must be positive")
        if depth > int(np.log2(n_features)):
            raise ValueError("Invalid region graph depth based on the number of features")

        self.items = tuple(range(n_features))
        self.depth = depth

        # Check the random state
        self.random_state = check_random_state(random_state)

    def random_layers(self) -> List[List[tuple]]:
        """
        Generate a list of layers randomly over a single repetition of features.

        :return: A list of layers, alternating between regions and partitions.
        """
        root = [self.items]
        layers = [root]

        for i in range(self.depth):
            regions = []
            partitions = []
            for r in layers[i * 2]:
                mid = len(r) // 2
                permutation = self.random_state.permutation(r).tolist()
                p0 = tuple(sorted(permutation[:mid]))
                p1 = tuple(sorted(permutation[mid:]))
                regions.append(p0)
                regions.append(p1)
                partitions.append((p0, p1))
            layers.append(partitions)
            layers.append(regions)

        return layers

    def make_layers(self, n_repetitions: int = 1) -> List[List[tuple]]:
        """
        Generate a random graph's layers over multiple repetitions of features.

        :param n_repetitions: The number of repetitions.
        :return: A list of layers, alternating between regions and partitions.
        :raises ValueError: If a parameter is out of domain.
        """
        if n_repetitions <= 0:
            raise ValueError("The number of repetitions must be positive")

        root = [self.items]
        graph_layers = [root] + [[] for _ in range(self.depth * 2)]

        for _ in range(n_repetitions):
            layers = self.random_layers()
            for h in range(1, len(layers)):
                graph_layers[h] = graph_layers[h] + layers[h]

        return graph_layers
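The region-graph definitions in the `RegionGraph` docstring can be checked mechanically. The sketch below re-implements one random repetition of `random_layers` as a standalone function. It then verifies that every partition splits its parent region into two non-empty, non-overlapping child regions whose union is the parent. The names `random_bisection_layers` and `check_partitions` are hypothetical.

```python
import numpy as np


def random_bisection_layers(n_features, depth, rng):
    """Hypothetical standalone version of one random repetition of region-graph layers."""
    layers = [[tuple(range(n_features))]]
    for i in range(depth):
        regions, partitions = [], []
        for region in layers[2 * i]:
            # Randomly split the region into two halves (kept as sorted tuples)
            perm = rng.permutation(region).tolist()
            mid = len(region) // 2
            p0, p1 = tuple(sorted(perm[:mid])), tuple(sorted(perm[mid:]))
            regions += [p0, p1]
            partitions.append((p0, p1))
        layers += [partitions, regions]
    return layers


def check_partitions(layers):
    """Check that each partition splits its parent region into non-empty, disjoint children covering it."""
    for i in range(1, len(layers), 2):
        # Partitions are generated in the same order as the regions of the previous layer
        for parent, (c0, c1) in zip(layers[i - 1], layers[i]):
            if not c0 or not c1:
                return False
            if set(c0) & set(c1):
                return False
            if tuple(sorted(c0 + c1)) != parent:
                return False
    return True


layers = random_bisection_layers(10, depth=3, rng=np.random.RandomState(0))
print(len(layers), len(layers[-1]), check_partitions(layers))  # 7 8 True
```

The depth bound enforced in `RegionGraph.__init__` (at most `int(np.log2(n_features))`) is what keeps every split region non-empty here.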

In [None]:
#@title binary Chow-Liu Tree : CLT


# a random state type is either an integer seed value or a Numpy RandomState instance.
RandomState = Union[int, np.random.RandomState]


def check_random_state(random_state: Optional[RandomState] = None) -> np.random.RandomState:
    """
    Check a possible input random state and return it as a Numpy's RandomState object.

    :param random_state: The random state to check. If None a new Numpy RandomState will be returned.
                         If not None, it can be either a seed integer or a np.random.RandomState instance.
                         In the latter case, itself will be returned.
    :return: A Numpy's RandomState object.
    :raises ValueError: If the random state is not None or a seed integer or a Numpy RandomState object.
    """
    if random_state is None:
        return np.random.RandomState()
    if isinstance(random_state, int):
        return np.random.RandomState(random_state)
    if isinstance(random_state, np.random.RandomState):
        return random_state
    raise ValueError("The random state must be either None, a seed integer or a Numpy RandomState object")


class BinaryCLT(Leaf):
    LEAF_TYPE = LeafType.DISCRETE

    def __init__(
        self,
        scope: List[int],
        root: Optional[int] = None,
        tree: Optional[Union[List[int], np.ndarray]] = None,
        params: Optional[Union[List[List[List[float]]], np.ndarray]] = None
    ):
        """
        Initialize a binary Chow-Liu Tree (CLT) multivariate leaf node.

        :param scope: The scope of the leaf.
        :param root: The root node of the CLT. If None it will be chosen randomly.
        :param tree: A sequence of variable ids predecessors (encoding the tree structure).
        :param params: The CLT conditional probability tables (CPTs), as a (N, 2, 2) Numpy array in logarithmic scale.
                       Note that params[i, l, k] = log P(X_i=k | Pa(X_i)=l).
        :raises ValueError: If the root variable is not in scope.
        :raises ValueError: If the tree structure is not compatible with the number of variables and root node.
        :raises ValueError: If the CPTs parameters are invalid.
        """
        super().__init__(scope)

        if tree is not None:
            if isinstance(tree, list):
                tree = np.array(tree, dtype=np.int32)

            # Check tree structure with respect to the scope
            if len(tree) != len(self.scope):
                raise ValueError("Invalid tree structure's number of variables")

            # Check root node with respect to the tree structure
            if root is None:
                # Exactly one variable must have no predecessor (marked with -1)
                roots = np.flatnonzero(tree == -1)
                if len(roots) != 1:
                    raise ValueError("Invalid tree structure's root node")
                root = roots.item()
            elif root not in self.scope:
                raise ValueError("The root variable must be in scope")
            else:
                root = self.scope.index(root)
            if tree[root] != -1:
                raise ValueError("Invalid tree structure's root node")

            # Compute BFS variable ordering
            bfs = compute_bfs_ordering(tree)
        else:
            bfs = None
            # Check root node with respect to the scope
            if root is not None:
                if root not in self.scope:
                    raise ValueError("The root variable must be in scope")
                root = self.scope.index(root)
        self.root = root
        self.tree = tree
        self.bfs = bfs

        # Initialize the parameters (log-space CPTs), validating them if given as a list or an array
        if params is not None:
            params = np.asarray(params, dtype=np.float32)
            if params.shape != (len(self.scope), 2, 2):
                raise ValueError("Invalid conditional probability table (CPT) shape")
            if not np.allclose(np.exp(params).sum(axis=2), 1.0):
                raise ValueError("Invalid conditional probability table (CPT) values")
        self.params = params

    @staticmethod
    def compute_clt_parameters(
        bfs: np.ndarray,
        tree: np.ndarray,
        priors: np.ndarray,
        joints: np.ndarray
    ) -> np.ndarray:
        """
        Compute the parameters of the CLTree given the tree structure and the priors and joints distributions.

        This function returns the conditional probability tables (CPTs) in a tensorized form.
        Note that params[i, l, k] = P(X_i=k | Pa(X_i)=l).
        A special case is made for the root distribution which is not conditioned.
        Note that params[root, :, k] = P(X_root=k).

        :param bfs: The bfs structure, i.e. a sequence of successors in a breadth-first traversal.
        :param tree: The tree structure, i.e. a sequence of predecessors in a tree structure.
        :param priors: The priors distributions.
        :param joints: The joints distributions.
        :return: The conditional probability tables (CPTs) in a tensorized form.
        """
        root_id = bfs[0]
        n_features = len(bfs)
        vs = np.arange(n_features)

        # Compute the conditional probabilities (by einsum operation)
        params = np.einsum('ikl,il->ilk', joints[vs, tree], np.reciprocal(priors[tree]))
        params[root_id] = priors[root_id]

        # Re-normalize the factors, because there can be FP32 approximation errors
        params /= np.sum(params, axis=2, keepdims=True)
        return params

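    def em_init(self, random_state: np.random.RandomState):
        """
        Randomly initialize the CPTs (in log-space) before running EM, keeping the current tree structure.

        :param random_state: The random state.
        :raises ValueError: If the CLT's structure is not initialized.
        """
        if self.tree is None:
            raise ValueError("The CLT's structure must be already initialized")

        probs = random_state.rand(len(self.scope), 2)
        # The root is unconditioned, so both of its rows must hold the same distribution
        probs[self.root, 0] = probs[self.root, 1]
        params = np.empty(shape=(len(self.scope), 2, 2), dtype=np.float32)
        params[:, :, 1] = probs
        params[:, :, 0] = 1.0 - probs
        self.params = np.log(params)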

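    def em_step(self, stats: np.ndarray, data: np.ndarray, step_size: float):
        """
        Perform a (damped) EM update of the CPTs, keeping the tree structure fixed.

        :param stats: The per-sample weights (sufficient statistics) of the E-step, with shape (n_samples,).
        :param data: The binary data, with shape (n_samples, n_features).
        :param step_size: The interpolation factor between the old and the newly estimated parameters.
        :raises ValueError: If the CLT's structure is not initialized.
        """
        if self.tree is None:
            raise ValueError("The CLT's structure must be already initialized")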

        alpha = np.finfo(np.float16).eps  # Use a very small Laplace smoothing factor
        total_stats = np.sum(stats)
        weighted_features = np.expand_dims(stats, axis=1) * data

        # Compute prior distributions
        priors_stats = np.sum(weighted_features, axis=0)
        priors = np.empty(shape=(len(self.scope), 2), dtype=np.float32)
        priors[:, 1] = (priors_stats + 2.0 * alpha) / (total_stats + 4.0 * alpha)
        priors[:, 0] = 1.0 - priors[:, 1]

        # Compute conditional sufficient statistics
        conditional_stats = np.empty(shape=(len(self.scope), 2), dtype=np.float32)
        conditional_stats[:, 1] = np.sum(weighted_features * data[:, self.tree], axis=0)
        conditional_stats[:, 0] = priors_stats - conditional_stats[:, 1]

        # Update the parameters
        params = np.empty_like(self.params)
        params[:, :, 1] = (conditional_stats + alpha) / (total_stats * priors[self.tree] + 4.0 * alpha)
        params[:, :, 0] = 1.0 - params[:, :, 1]
        params[self.root, 0] = params[self.root, 1] = priors[self.root]
        params = (1.0 - step_size) * np.exp(self.params) + step_size * params

        # Re-normalize the factors, because there can be FP32 approximation errors
        params /= np.sum(params, axis=2, keepdims=True)
        self.params = np.log(params)

    def fit(
        self,
        data: np.ndarray,
        domain: List[list],
        alpha: float = 0.1,
        random_state: Optional[RandomState] = None,
        **kwargs
    ):
        """
        Fit the distribution parameters (and structure if necessary) given the domain and some training data.

        :param data: The training data.
        :param domain: The domain of the distribution leaf.
        :param alpha: The Laplace smoothing factor.
        :param random_state: The random state. It can be either None, a seed integer or a Numpy RandomState.
        :param kwargs: Optional parameters.
        :raises ValueError: If the random state is not valid.
        :raises ValueError: If a parameter is out of domain.
        """
        _, n_features = data.shape
        if len(domain) != n_features:
            raise ValueError("Each data column should correspond to a random variable having a domain")
        if not all(d == [0, 1] for d in domain):
            raise ValueError("The domains must be binary for a Binary CLT distribution")
        if alpha < 0.0:
            raise ValueError("The Laplace smoothing factor must be non-negative")

        # Check the random state
        random_state = check_random_state(random_state)

        # Choose a root variable randomly, if not specified
        if self.root is None:
            self.root = random_state.choice(len(self.scope))

        # Estimate the priors and joints probabilities
        priors, joints = estimate_priors_joints(data, alpha=alpha)

        if self.tree is None:
            # Compute the mutual information
            mutual_info = compute_mutual_information(priors, joints)

            # Compute the CLT structure
            self.bfs, self.tree = maximum_spanning_tree(self.root, mutual_info)

        # Compute the CLT parameters (in log-space), using the joints and priors probabilities
        params = self.compute_clt_parameters(self.bfs, self.tree, priors, joints)
        self.params = np.log(params)

    def message_passing(
        self, x: np.ndarray,
        obs_mask: np.ndarray,
        return_lls: bool = True,
        reduce: str = 'mar'
    ) -> np.ndarray:
        """
        Compute the messages passed from the leaves to the root node.

        :param x: The input data.
        :param obs_mask: The mask of observed values.
        :param return_lls: Whether to compute and return the log-likelihoods.
        :param reduce: The method used to reduce the messages of missing values.
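                       It can be either 'mar' (marginalize the message) or 'mpe' (most probable explanation).
        :return: The messages array if return_lls is False.
                 The log-likelihoods if return_lls is True.
        :raises ValueError: If the reduce method is unknown.
        """
        if reduce not in ('mar', 'mpe'):
            raise ValueError("Unknown reduce method called {}".format(reduce))
        n_samples, n_features = x.shape
        messages = np.zeros(shape=(n_features, n_samples, 2), dtype=np.float32)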

        # Let's proceed bottom-up
        for j in reversed(self.bfs[1:]):
            mask = obs_mask[:, j]
            mis_mask = ~mask
            obs_values = x[mask, j].astype(np.int64)
            msg = np.expand_dims(messages[j], axis=1)

            # Compute the messages for observed data
            messages[self.tree[j], mask] += self.params[j, :, obs_values] + msg[mask, :, obs_values]

            # Compute the messages for unobserved data
            if np.any(mis_mask):
                parent_msg = self.params[j] + msg[mis_mask]
                if reduce == 'mar':
                    messages[self.tree[j], mis_mask] += logsumexp(parent_msg, axis=2)
                elif reduce == 'mpe':
                    messages[self.tree[j], mis_mask] += np.max(parent_msg, axis=2)
                else:
                    raise ValueError("Unknown reduce method called {}".format(reduce))

        if not return_lls:
            return messages

        lls = np.empty(n_samples, dtype=np.float32)
        mask = obs_mask[:, self.root]
        mis_mask = ~mask
        obs_values = x[mask, self.root].astype(np.int64)
        msg = messages[self.root]

        # Compute the messages for observed data at root node
        lls[mask] = self.params[self.root, 0, obs_values] + msg[mask, obs_values]

        # Compute the messages for unobserved data at root node
        if np.any(mis_mask):
            lls[mis_mask] = logsumexp(self.params[self.root, 0] + msg[mis_mask], axis=1)

        return lls

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        return np.exp(self.log_likelihood(x))

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        n_samples, n_features = x.shape

        # Build the mask of samples with missing values (used for marginalization)
        mis_mask = np.isnan(x)
        mar_mask = np.any(mis_mask, axis=1)

        if np.any(mar_mask):
            evi_mask = ~mar_mask
            obs_mask = ~mis_mask
            lls = np.empty(n_samples, dtype=np.float32)

            # Vectorized implementation of full-evidence inference
            vs = np.arange(n_features)
            z = x[evi_mask]
            z_cond = z[:, self.tree].astype(np.int64, copy=False)
            z_vals = z[:, vs].astype(np.int64, copy=False)
            lls[evi_mask] = np.sum(self.params[vs, z_cond, z_vals], axis=1)

            # Semi-vectorized implementation of marginal inference
            z = x[mar_mask]
            lls[mar_mask] = self.message_passing(z, obs_mask[mar_mask], return_lls=True, reduce='mar')
            return np.expand_dims(lls, axis=1)

        # Vectorized implementation (without masking) of full-evidence inference
        vs = np.arange(n_features)
        x_cond = x[:, self.tree].astype(np.int64, copy=False)
        x_vals = x[:, vs].astype(np.int64, copy=False)
        lls = np.sum(self.params[vs, x_cond, x_vals], axis=1, keepdims=True)
        return lls

    def mpe(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mis_mask = np.isnan(x)
        obs_mask = ~mis_mask

        # Semi-vectorized implementation of MPE inference
        messages = self.message_passing(x, obs_mask, return_lls=False, reduce='mpe')

        # Compute MPE at the root feature
        mask = mis_mask[:, self.root]
        msg = self.params[self.root, 0] + messages[self.root, mask]
        x[mask, self.root] = np.argmax(msg, axis=1)

        # Compute MPE at the other features, by using the accumulated messages
        for j in self.bfs[1:]:
            mask = mis_mask[:, j]
            obs_parent_values = x[mask, self.tree[j]].astype(np.int64)
            msg = self.params[j, obs_parent_values] + messages[j, mask]
            x[mask, j] = np.argmax(msg, axis=1)
        return x

    def sample(self, x: np.ndarray) -> np.ndarray:
        x = np.copy(x)
        mis_mask = np.isnan(x)
        obs_mask = ~mis_mask

        # Semi-vectorized implementation of conditional sampling
        messages = self.message_passing(x, obs_mask, return_lls=False, reduce='mar')

        # Sample the root feature: normalize the upward messages into P(X_root | evidence)
        mask = mis_mask[:, self.root]
        log_joint = self.params[self.root, 0] + messages[self.root, mask]
        probs = np.exp(log_joint[:, 1] - logsumexp(log_joint, axis=1))
        x[mask, self.root] = ss.bernoulli.rvs(probs)

        # Sample the other features top-down, conditioning on the already sampled parent values
        for j in self.bfs[1:]:
            mask = mis_mask[:, j]
            obs_parent_values = x[mask, self.tree[j]].astype(np.int64)
            log_joint = self.params[j, obs_parent_values] + messages[j, mask]
            probs = np.exp(log_joint[:, 1] - logsumexp(log_joint, axis=1))
            x[mask, j] = ss.bernoulli.rvs(probs)
        return x

    def moment(self, k: int = 1) -> float:
        raise NotImplementedError("Computation of moments on Binary CLTs not yet implemented")

    def params_count(self) -> int:
        return 1 + len(self.tree) + self.params.size

    def params_dict(self) -> dict:
        return {
            'root': None if self.root is None else self.scope[self.root],
            'tree': self.tree,
            'params': self.params
        }

    def to_pc(self) -> Node:
        """
        Convert a Chow-Liu Tree into a smooth, deterministic and structured-decomposable PC.

        :return: A smooth, deterministic and structured-decomposable PC.
        """
        # Build the tree structure
        root = build_tree_structure(self.tree, scope=self.scope)

        # Build the factors dictionary
        factors = {self.scope[i]: np.exp(self.params[i]) for i in range(len(self.tree))}

        # Post-Order exploration
        neg_buffer, pos_buffer = [], []
        nodes_stack = [root]
        last_node_visited = None
        while nodes_stack:
            node = nodes_stack[-1]
            if node.is_leaf() or (last_node_visited in node.get_children()):
                leaves: List[Union[Bernoulli, Sum]] = [
                    Bernoulli(node.get_id(), p=0.0),
                    Bernoulli(node.get_id(), p=1.0)
                ]
                if not node.is_leaf():
                    neg_prod = Product(children=[leaves[0]] + neg_buffer[-len(node.get_children()):])
                    pos_prod = Product(children=[leaves[1]] + pos_buffer[-len(node.get_children()):])
                    del neg_buffer[-len(node.get_children()):]
                    del pos_buffer[-len(node.get_children()):]
                    sum_children = [neg_prod, pos_prod]
                else:
                    sum_children = leaves
                weights = factors[node.get_id()]
                neg_buffer.append(
                    Sum(children=sum_children, weights=weights[0])
                )
                pos_buffer.append(
                    Sum(children=sum_children, weights=weights[1])
                )
                last_node_visited = nodes_stack.pop()
            else:
                nodes_stack.extend(node.get_children())
        # Equivalently, pc = neg_buffer[0], since both weight vectors of the root are identical
        pc = pos_buffer[0]
        return assign_ids(pc)

    def get_scopes(self):
        """
        Return a list containing the scope of every node in the PC equivalent to the
        current CLTree (see to_pc() method). Every scope occurs once in the list.

        :return: The list of scopes.
        """
        scopes = []
        scopes_stack = []

        # Post-Order exploration
        root = build_tree_structure(self.tree, scope=self.scope)
        nodes_stack = [root]
        last_node_visited = None
        while nodes_stack:
            node = nodes_stack[-1]
            if node.is_leaf() or (last_node_visited in node.get_children()):
                if node.is_leaf():
                    scopes_stack.append([node.get_id()])
                else:
                    scopes_temp = scopes_stack[-len(node.get_children()):]
                    del scopes_stack[-len(node.get_children()):]
                    scopes_temp.append([node.get_id()])
                    merged_scope = [var for scope in scopes_temp for var in scope]
                    scopes_stack.append(merged_scope)
                    scopes.append(merged_scope)
                last_node_visited = nodes_stack.pop()
            else:
                nodes_stack.extend(node.get_children())

        return scopes
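The iterative post-order traversal used by `get_scopes()` can be tried out on a toy tree without building a `BinaryCLT`. The sketch below is an illustration under an assumed encoding, not the notebook's API: `collect_clt_scopes` is a hypothetical helper that takes the tree as a parent array (`tree[i]` is the parent of variable `i`, `-1` marks the root). As in `get_scopes()`, every internal node contributes the scope of its subtree, i.e. its children's scopes plus its own variable, while leaves only feed their singleton scope upwards.

```python
from typing import Dict, List


def collect_clt_scopes(tree: List[int]) -> List[List[int]]:
    """Hypothetical helper: scopes of the internal nodes of a tree given as a parent array."""
    children: Dict[int, List[int]] = {i: [] for i in range(len(tree))}
    root = -1
    for node, parent in enumerate(tree):
        if parent == -1:
            root = node
        else:
            children[parent].append(node)

    scopes: List[List[int]] = []
    scopes_stack: List[List[int]] = []
    nodes_stack = [root]
    last_visited = None
    # Iterative post-order: a node is processed once all its children have been popped
    while nodes_stack:
        node = nodes_stack[-1]
        kids = children[node]
        if not kids or last_visited in kids:
            if not kids:
                scopes_stack.append([node])  # leaf: singleton scope, passed to the parent
            else:
                merged = [v for s in scopes_stack[-len(kids):] for v in s] + [node]
                del scopes_stack[-len(kids):]
                scopes_stack.append(merged)
                scopes.append(merged)
            last_visited = nodes_stack.pop()
        else:
            nodes_stack.extend(kids)
    return scopes


# Root 0 with children 1 and 2; variable 3 hangs below 1
print(collect_clt_scopes([-1, 0, 0, 1]))  # [[3, 1], [2, 3, 1, 0]]
```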



In [None]:
#@title binary cutset network : CNet



class ORNode(Node):
    def __init__(
        self,
        scope: Optional[List[int]],
        children: Optional[List[Node]] = None,
        weights: Optional[Union[List[float], np.ndarray]] = None,
        or_id: Optional[int] = None
    ):
        """
        Initialize an OR node given its scope, child nodes, weights and id.

        :param scope: The scope of the OR node.
        :param children: The child nodes of the OR node.
        :param weights: The weights of the OR node.
        :param or_id: The id of the OR node.
        """
        if weights is not None:
            if isinstance(weights, list):
                weights = np.array(weights, dtype=np.float32)
            if not np.isclose(np.sum(weights), 1.0):
                raise ValueError("Weights don't sum up to 1")
        self.weights = weights
        self.or_id = or_id
        self.row_indices = None
        self.col_indices = None
        self.clt = None

        super().__init__(scope, children)

    def assign_indices(
        self,
        row_indices: Optional[Union[List[int], np.ndarray]],
        col_indices: Optional[Union[List[int], np.ndarray]]
    ):
        """
        Assign the corresponding indices of the OR node's partition in the original data set.

        :param row_indices: Row indices of the partition.
        :param col_indices: Column indices of the partition.
        :return:
        """
        self.row_indices = row_indices
        self.col_indices = col_indices

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        pass

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        pass


class BinaryCNet(ORNode):
    def __init__(
        self,
        scope: Optional[List[int]],
        children: Optional[List[Node]] = None,
        weights: Optional[Union[List[float], np.ndarray]] = None,
        or_id: Optional[int] = None
    ):
        """
        Initialize a binary cutset network (CNet).

        :param scope: The scope of the binary CNet.
        :param children: The child OR nodes of the binary CNet.
        :param weights: The weights of the current OR node.
        :param or_id: The id of the current OR node.
        """
        super().__init__(scope, children, weights, or_id)

    def fit(
        self,
        data: np.ndarray,
        alpha: float = 0.01,
        min_n_samples: int = 10,
        min_n_features: int = 1,
        min_mean_entropy: float = 0.01
    ):
        """
        Fit the structure and the MLE parameters given some training data and hyper-parameters.

        :param data: The training data.
        :param alpha: The Laplace smoothing factor.
        :param min_n_samples: Partitions with at most this many samples are not split further.
        :param min_n_features: Partitions with at most this many features are not split further.
        :param min_mean_entropy: Partitions whose mean RV entropy is below this value are not split further.
        :return:
        """
        n_samples, n_features = data.shape
        self.scope = list(range(n_features))
        self.assign_indices(row_indices=np.arange(n_samples), col_indices=np.arange(n_features))
        root = BinaryCNet(scope=list(range(n_features)))
        root.assign_indices(row_indices=np.arange(n_samples), col_indices=np.arange(n_features))
        node_stack = [root]
        while node_stack:
            node = node_stack.pop(0)
            partition = data[node.row_indices][:, node.col_indices]
            n_samples, n_features = partition.shape
            if n_samples <= min_n_samples or n_features <= min_n_features:
                # stopped due to few samples or features
                node.fit_clt(data=partition, alpha=alpha)
                continue
            best_or_idx, mean_entropy, max_info_gain = self.__select_variable_entropy(partition, alpha=alpha)
            if mean_entropy < min_mean_entropy or max_info_gain <= 0:
                # stopped due to small entropy or non-positive information gain
                node.fit_clt(data=partition, alpha=alpha)
                continue
            left_row_indices = node.row_indices[partition[:, best_or_idx] == 0]
            right_row_indices = node.row_indices[partition[:, best_or_idx] == 1]
            child_col_indices = np.delete(node.col_indices, obj=best_or_idx)
            left_weight = (len(left_row_indices) + alpha) / (len(node.row_indices) + 2 * alpha)
            right_weight = 1 - left_weight
            new_scope = node.scope.copy()
            del new_scope[best_or_idx]
            left_child = BinaryCNet(scope=new_scope)
            left_child.assign_indices(row_indices=left_row_indices, col_indices=child_col_indices)
            right_child = BinaryCNet(scope=new_scope)
            right_child.assign_indices(row_indices=right_row_indices, col_indices=child_col_indices)
            node_stack.append(left_child)
            node_stack.append(right_child)
            node.children = [left_child, right_child]
            node.weights = [left_weight, right_weight]
            node.or_id = node.scope[best_or_idx]
        self.or_id = root.or_id
        self.children = root.children
        self.weights = root.weights

    def fit_clt(
        self,
        data: np.ndarray,
        alpha: float = 0.01
    ):
        """
        Fit a Binary CLT for the RVs in the scope of the current OR node.

        :param data: The data partition.
        :param alpha: The Laplace smoothing factor.
        :return:
        """
        clt = BinaryCLT(scope=self.scope)
        clt.fit(data=data, domain=[[0, 1]] * len(self.scope), alpha=alpha)
        self.clt = clt

    @staticmethod
    def __select_variable_entropy(
        data: np.ndarray,
        alpha: float = 0.01
    ):
        """
        Select the best variable to cut on, based on the entropy reduction (information gain).

        :param data: The training data partition.
        :param alpha: The Laplace smoothing factor.
        :return: The index of the selected RV,
                 the mean entropy of the RVs in the partition,
                 the information gain of the selected RV.
        """
        n_samples, n_features = data.shape
        counts_features = np.sum(data, axis=0)

        prior_counts = compute_prior_counts(data)
        joint_counts = compute_joint_counts(data)
        priors = (prior_counts + 2 * alpha) / (n_samples + 4 * alpha)
        priors[:, 0] = 1.0 - priors[:, 1]
        mean_entropy = -(priors * np.log(priors)).sum() / n_features

        conditionals = np.empty((n_features, n_features, 2, 2), dtype=np.float32)
        # conditionals[i, j, a, b] estimates P(X_j = b | X_i = a) for every candidate cut variable X_i;
        # it is smoothed with the same Laplace terms (2 * alpha, 4 * alpha) used for the priors above
        conditionals[:, :, 0, 0] = ((joint_counts[:, :, 0, 0] + 2 * alpha).T / (prior_counts[:, 0] + 4 * alpha)).T
        conditionals[:, :, 0, 1] = ((joint_counts[:, :, 0, 1] + 2 * alpha).T / (prior_counts[:, 0] + 4 * alpha)).T
        conditionals[:, :, 1, 0] = ((joint_counts[:, :, 1, 0] + 2 * alpha).T / (prior_counts[:, 1] + 4 * alpha)).T
        conditionals[:, :, 1, 1] = ((joint_counts[:, :, 1, 1] + 2 * alpha).T / (prior_counts[:, 1] + 4 * alpha)).T

        vs = np.repeat(np.arange(n_features)[None, :], n_features, axis=0)
        vs = vs[~np.eye(vs.shape[0], dtype=bool)].reshape(vs.shape[0], -1)
        parents = np.repeat(np.arange(n_features)[:, None], n_features - 1, axis=1)

        ratio_features = counts_features / n_samples
        entropies = ratio_features * \
                    np.mean(-np.sum(conditionals[parents, vs, 1, :] * np.log(conditionals[parents, vs, 1, :]), axis=-1),
                            axis=1) + \
                    (1 - ratio_features) * \
                    np.mean(-np.sum(conditionals[parents, vs, 0, :] * np.log(conditionals[parents, vs, 0, :]), axis=-1),
                            axis=1)
        info_gains = mean_entropy - entropies
        selected_idx = np.argmax(info_gains)
        return selected_idx, mean_entropy, info_gains[selected_idx]

    def __is_leaf(self):
        """
        Check if the current OR node is a leaf.

        :return: True if the OR node has fitted a binary CLT, otherwise False
        """
        return self.clt is not None

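    def likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the likelihoods of the binary CNet given some inputs.

        :param x: The binary inputs.
        :return: The likelihood values.
        """
        return np.exp(self.log_likelihood(x))

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the log-likelihoods of the binary CNet given some inputs.
        Each sample is routed down the OR tree according to its values of the cut variables,
        accumulating the log-weights of the traversed OR nodes; the log-likelihood of the
        CLT reached at the leaf is then added.

        :param x: The binary inputs.
        :return: The log-likelihood values.
        """
        n_samples, n_features = x.shape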
        root = copy.copy(self)
        root.row_indices, root.col_indices = np.arange(n_samples), np.arange(n_features)
        node_stack = [root]
        log_likes = np.zeros(n_samples)
        while node_stack:
            node = node_stack.pop(0)
            partition = x[node.row_indices][:, node.col_indices]
            if node.__is_leaf():
                log_likes[node.row_indices] += node.clt.log_likelihood(partition).squeeze()
                continue
            node_idx = node.scope.index(node.or_id)
            left_child = copy.copy(node.children[0])
            right_child = copy.copy(node.children[1])
            left_child.row_indices = node.row_indices[partition[:, node_idx] == 0]
            right_child.row_indices = node.row_indices[partition[:, node_idx] == 1]
            log_likes[left_child.row_indices] += np.log(node.weights[0])
            log_likes[right_child.row_indices] += np.log(node.weights[1])
            left_child.col_indices = np.delete(node.col_indices, obj=node_idx)
            right_child.col_indices = np.delete(node.col_indices, obj=node_idx)
            node_stack.append(left_child)
            node_stack.append(right_child)
        return log_likes
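A single OR cut can be checked numerically. The sketch below is a simplification, not the `BinaryCNet` class: it performs one cut on a given variable and, as an assumption for brevity, replaces each Chow-Liu tree leaf with fully factorized Bernoulli marginals (`fit_bernoulli`, `fit_one_cut` and `one_cut_log_likelihood` are hypothetical helpers). It reuses the smoothed OR weight `(n_0 + alpha) / (n + 2 * alpha)` from `fit()` and routes each sample to the branch selected by its value of the cut variable, as `log_likelihood()` does, so the probabilities of all 2^d binary configurations should sum to 1.

```python
import itertools

import numpy as np


def fit_bernoulli(data: np.ndarray, alpha: float) -> np.ndarray:
    # Laplace-smoothed P(X_j = 1) per column; stands in for a CLT leaf
    return (data.sum(axis=0) + alpha) / (len(data) + 2 * alpha)


def fit_one_cut(data: np.ndarray, var: int, alpha: float = 0.01) -> dict:
    left = data[:, var] == 0
    # Same smoothed OR weight as in BinaryCNet.fit
    w_left = (left.sum() + alpha) / (len(data) + 2 * alpha)
    rest = np.delete(data, var, axis=1)  # the cut variable leaves the children scope
    return {
        "var": var,
        "weights": np.array([w_left, 1.0 - w_left]),
        "leaves": [fit_bernoulli(rest[left], alpha), fit_bernoulli(rest[~left], alpha)],
    }


def one_cut_log_likelihood(model: dict, x: np.ndarray) -> np.ndarray:
    branch = x[:, model["var"]].astype(int)  # 0 -> left child, 1 -> right child
    rest = np.delete(x, model["var"], axis=1)
    lls = np.log(model["weights"][branch])  # log-weight of the traversed OR branch
    for b in (0, 1):
        rows = branch == b
        p = model["leaves"][b]
        lls[rows] += (rest[rows] * np.log(p) + (1 - rest[rows]) * np.log1p(-p)).sum(axis=1)
    return lls


rng = np.random.default_rng(0)
data = rng.integers(0, 2, size=(200, 4))
model = fit_one_cut(data, var=1)
all_configs = np.array(list(itertools.product([0, 1], repeat=4)))
total = np.exp(one_cut_log_likelihood(model, all_configs)).sum()
print(total)
```

Swapping the Bernoulli leaves for fitted CLTs does not change the normalization argument: each branch contributes its OR weight times a normalized distribution over the remaining variables.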

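The cut-variable heuristic compares the mean entropy of a partition before and after conditioning on a candidate variable. The sketch below is a simplified, hypothetical variant (`information_gains` and `bernoulli_entropy` are not part of the notebook): for each candidate `v` it takes the mean Laplace-smoothed marginal entropy of the other variables minus their mean conditional entropy given `X_v`, weighted by the empirical `P(X_v = b)`. It does not reproduce `__select_variable_entropy` term by term, but it shows why a variable that determines others is preferred as a cut.

```python
import numpy as np


def bernoulli_entropy(p: np.ndarray) -> np.ndarray:
    # Entropy (in nats) of Bernoulli variables with P(X = 1) = p
    return -(p * np.log(p) + (1 - p) * np.log(1 - p))


def information_gains(data: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    n_samples, n_features = data.shape
    p1 = (data.sum(axis=0) + alpha) / (n_samples + 2 * alpha)
    h_marginal = bernoulli_entropy(p1)
    gains = np.empty(n_features)
    for v in range(n_features):
        others = np.arange(n_features) != v
        h_cond = 0.0
        for b in (0, 1):
            rows = data[:, v] == b
            p_b = rows.sum() / n_samples
            # Smoothed P(X_j = 1 | X_v = b) for every column j
            p1_given_b = (data[rows].sum(axis=0) + alpha) / (rows.sum() + 2 * alpha)
            h_cond += p_b * bernoulli_entropy(p1_given_b)[others].mean()
        gains[v] = h_marginal[others].mean() - h_cond
    return gains


rng = np.random.default_rng(0)
x0 = rng.integers(0, 2, size=1000)
x3 = rng.integers(0, 2, size=1000)
data = np.stack([x0, x0, x0, x3], axis=1)  # X1 and X2 copy X0, X3 is independent
gains = information_gains(data)
print(np.argmax(gains), gains.round(3))
```

Cutting on any copy of `X0` removes almost all uncertainty about two of the three remaining variables, whereas cutting on `X3` leaves the others unchanged, so its gain stays close to zero.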

In [None]:
## for saving, loading and plotting graphs: https://github.com/deeprob-org/deeprob-kit/blob/main/deeprob/spn/structure/io.py

# Sampling, evaluation and inference
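The validity checks in the evaluation cell below reduce to simple set tests on scopes. Decomposability requires the children scopes of every product to be pairwise disjoint, and two PCs are compatible when every pair of product scopes is either disjoint or nested, which is the `int_len != 0 and int_len != min(len(s1), len(s2))` test in `are_compatible()`. A standalone sketch of both tests on plain Python sets; `product_is_decomposable` and `scopes_compatible` are hypothetical helpers, not part of the notebook:

```python
from itertools import chain
from typing import List, Optional, Set


def product_is_decomposable(children_scopes: List[List[int]]) -> bool:
    """Hypothetical helper: True if no variable appears in two children scopes."""
    flat = list(chain.from_iterable(children_scopes))
    return len(flat) == len(set(flat))


def scopes_compatible(scopes_a: List[Set[int]], scopes_b: List[Set[int]]) -> Optional[str]:
    """Hypothetical helper: None if every pair of scopes is disjoint or nested."""
    for s1 in scopes_a:
        for s2 in scopes_b:
            overlap = len(s1 & s2)
            # Partial overlap (neither disjoint nor nested) breaks compatibility
            if overlap != 0 and overlap != min(len(s1), len(s2)):
                return f"Incompatibility found between scope {s1} and scope {s2}"
    return None


balanced = [{0, 1, 2, 3}, {0, 1}, {2, 3}]
shifted = [{0, 1, 2, 3}, {1, 2}, {0, 3}]
print(product_is_decomposable([[0, 1], [2, 3]]))  # disjoint children
print(scopes_compatible(balanced, balanced))      # a PC compared with itself
print(scopes_compatible(balanced, shifted))       # {0, 1} vs {1, 2} overlap partially
```

Comparing a PC's scopes with themselves, as `is_structured_decomposable()` does, is enough to detect product nodes that do not follow a single vtree.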

In [None]:
#@title evaluation

##---------------------------------------- evaluation helpers for validity check


def collect_nodes(root: Node) -> List[Node]:
    """
    Get all the nodes in an SPN.

    :param root: The root of the SPN.
    :return: A list of nodes.
    """
    return filter_nodes_by_type(root)


def filter_nodes_by_type(
    root: Node,
    ntype: Union[Type[Node], Tuple[Type[Node], ...]] = Node
) -> List[Union[Node, Leaf, Sum, Product]]:
    """
    Get the nodes of the specified types in an SPN.

    :param root: The root of the SPN.
    :param ntype: The node type. Multiple node types can be specified as a tuple.
    :return: A list of nodes of some specific types.
    """
    return list(filter(lambda n: isinstance(n, ntype), bfs(root)))

def check_spn(
    root: Node,
    labeled: bool = True,
    smooth: bool = False,
    decomposable: bool = False,
    structured_decomposable: bool = False
):
    """
    Check that an SPN has certain properties. By default, only the 'labeled' property is checked.
    This function combines several checks over an SPN, so that the nodes of the SPN
    are retrieved only once.

    :param root: The root node of the SPN.
    :param labeled: Whether to check if the SPN is correctly labeled.
    :param smooth: Whether to check if the SPN is smooth.
    :param decomposable: Whether to check if the SPN is decomposable.
    :param structured_decomposable: Whether to check if the SPN is structured decomposable.
    :raises ValueError: If the SPN doesn't have a certain property.
    """
    if not is_check_spn_enabled():  # Skip the checks entirely, if specified
        return

    # Collect the nodes starting from the root node (cache)
    nodes = collect_nodes(root)

    # Check the SPN nodes are correctly labeled
    if labeled:
        result = is_labeled(root, nodes=nodes)
        if result is not None:
            raise ValueError(f"SPN is not correctly labeled: {result}")

    # Check the SPN is smooth
    if smooth:
        result = is_smooth(root, nodes=nodes)
        if result is not None:
            raise ValueError(f"SPN is not smooth: {result}")

    # Check the SPN is decomposable
    if decomposable:
        result = is_decomposable(root, nodes=nodes)
        if result is not None:
            raise ValueError(f"SPN is not decomposable: {result}")

    # Check the SPN is structured decomposable
    if structured_decomposable:
        result = is_structured_decomposable(root, nodes=nodes)
        if result is not None:
            raise ValueError(f"SPN is not structured decomposable: {result}")


def is_labeled(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the SPN is labeled correctly.
    It checks that every node has a unique id, that the ids start at zero and that they are consecutive.

    :param root: The root of the SPN.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the SPN is labeled correctly, a reason otherwise.
    """
    if nodes is None:
        nodes = collect_nodes(root)

    ids = set(map(lambda n: n.id, nodes))
    if None in ids:
        return "Some nodes have missing ids"
    if len(ids) != len(nodes):
        return "Some nodes have repeated ids"
    if min(ids) != 0:
        return "Node ids are not starting at 0"
    if max(ids) != len(ids) - 1:
        return "Node ids are not consecutive"
    return None


def is_smooth(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the SPN is smooth (or complete).
    It checks that every child of a sum node has the same scope as the sum node itself.

    :param root: The root of the SPN.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the SPN is smooth, a reason otherwise.
    """
    if nodes is None:
        nodes = collect_nodes(root)
    sum_nodes: List[Sum] = list(filter(lambda n: isinstance(n, Sum), nodes))

    for node in sum_nodes:
        if len(node.children) == 0:
            return f"Sum node #{node.id} has no children"
        if len(node.children) != len(node.weights):
            return f"Weights and children length mismatch in node #{node.id}"
        if any(map(lambda c: set(c.scope) != set(node.scope), node.children)):
            return f"Children of Sum node #{node.id} have different scopes"
    return None


def is_decomposable(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the SPN is decomposable (a sufficient condition for consistency).
    It checks that the children of each product node have pairwise disjoint scopes
    whose union is the scope of the product node.

    :param root: The root of the SPN.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the SPN is decomposable, a reason otherwise.
    """
    if nodes is None:
        nodes = collect_nodes(root)
    product_nodes: List[Product] = list(filter(lambda n: isinstance(n, Product), nodes))

    for node in product_nodes:
        if len(node.children) == 0:
            return f"Product node #{node.id} has no children"
        children_scope = sum([list(c.scope) for c in node.children], [])
        # Disjointness: no variable may appear in the scopes of two different children
        if len(children_scope) != len(set(children_scope)):
            return f"Children of Product node #{node.id} don't have disjoint scopes"
        if set(node.scope) != set(children_scope):
            return f"Children of Product node #{node.id} don't cover the scope of the product node"
    return None


def is_structured_decomposable(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the PC is structured decomposable.
    It checks that product nodes follow a vtree.
    Note that if a PC is structured decomposable then it's also decomposable.

    :param root: The root of the PC.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the PC is structured decomposable, a reason otherwise.
    """
    # Shortcut: a PC is structured decomposable if it is compatible with itself
    if nodes is None:
        nodes = collect_nodes(root)
    return are_compatible(root, root, nodes_a=nodes, nodes_b=nodes)


def are_compatible(
    root_a: Node,
    root_b: Node,
    nodes_a: Optional[List[Node]] = None,
    nodes_b: Optional[List[Node]] = None
) -> Optional[str]:
    """
    Check if two PCs are compatible.

    :param root_a: The root of the first PC.
    :param root_b: The root of the second PC.
    :param nodes_a: The list of nodes of the first PC. If None, it will be retrieved starting from the root node.
    :param nodes_b: The list of nodes of the second PC. If None, it will be retrieved starting from the root node.
    :return: None if the two PCs are compatible, a reason otherwise.
    """
    if nodes_a is None:
        nodes_a = collect_nodes(root_a)
    if nodes_b is None:
        nodes_b = collect_nodes(root_b)

    # Check smoothness and decomposability first
    res = is_smooth(root_a, nodes_a)
    if res is not None:
        return f'First PC: {res}'
    res = is_decomposable(root_a, nodes_a)
    if res is not None:
        return f'First PC: {res}'
    res = is_smooth(root_b, nodes_b)
    if res is not None:
        return f'Second PC: {res}'
    res = is_decomposable(root_b, nodes_b)
    if res is not None:
        return f'Second PC: {res}'

    # Get scopes as sets
    scopes_a = collect_scopes(nodes_a)
    scopes_b = collect_scopes(nodes_b)
    scopes_a = list(map(lambda s: set(s), scopes_a))
    scopes_b = list(map(lambda s: set(s), scopes_b))

    # Quadratic in the number of product nodes
    for s1 in scopes_a:
        for s2 in scopes_b:
            int_len = len(s1.intersection(s2))
            if int_len != 0 and int_len != min(len(s1), len(s2)):
                return f"Incompatibility found between scope {s1} and scope {s2}"
    return None


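def collect_scopes(nodes: List[Node]) -> List[Tuple[int, ...]]:
    """
    Collect the scopes of the product nodes, including the scopes returned by get_scopes()
    for binary CLT nodes. Sum and leaf nodes are skipped.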

    :param nodes: The list of nodes.
    :return: A list of scopes.
    """
    scopes = list()
    for n in nodes:
        if isinstance(n, Product):
            scopes.append(tuple(sorted(n.scope)))
        elif isinstance(n, BinaryCLT):
            scopes.extend([tuple(sorted(scope)) for scope in n.get_scopes()])
        elif not isinstance(n, Sum) and not isinstance(n, Leaf):
            raise NotImplementedError(f"Case not considered for {type(n)} nodes")
    return scopes


##------------------------------------------------------------------- evaluation

def parallel_layerwise_eval(
    layers: List[List[Node]],
    eval_func: Callable[[Node], None],
    reverse: bool = False,
    n_jobs: int = -1
):
    """
    Execute a function per node layerwise in parallel.

    :param layers: The layers, i.e., the layered topological ordering.
    :param eval_func: The evaluation function for a given node.
    :param reverse: Whether to reverse the layered topological ordering.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention.
    """
    if reverse:
        layers = reversed(layers)

    # Run parallel threads using joblib
    with joblib.parallel_backend('threading', n_jobs=n_jobs):
        with joblib.Parallel() as parallel:
            for layer in layers:
                parallel(joblib.delayed(eval_func)(node) for node in layer)


def eval_bottom_up(
    root: Node,
    x: np.ndarray,
    leaf_func: Callable[[Leaf, np.ndarray, Any], np.ndarray],
    node_func: Callable[[Node, np.ndarray, Any], np.ndarray],
    leaf_func_kwargs: Optional[dict] = None,
    node_func_kwargs: Optional[dict] = None,
    return_results: bool = False,
    n_jobs: int = 0
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Evaluate the SPN bottom-up given some inputs, a leaf evaluation function and an inner node evaluation function.

    :param root: The root of the SPN.
    :param x: The inputs.
    :param leaf_func: The function to compute at each leaf node.
    :param node_func: The function to compute at each inner node.
    :param leaf_func_kwargs: The optional parameters of the leaf evaluation function.
    :param node_func_kwargs: The optional parameters of the inner nodes evaluation function.
    :param return_results: Whether to also return the outputs of every node of the SPN.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention. Set to 0 to disable.
    :return: The outputs. Additionally, it returns the output of each node.
    :raises ValueError: If a parameter is out of domain.
    """
    if leaf_func_kwargs is None:
        leaf_func_kwargs = dict()
    if node_func_kwargs is None:
        node_func_kwargs = dict()

    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    def eval_forward(n):
        if isinstance(n, Leaf):
            ls[n.id] = leaf_func(n, x[:, n.scope], **leaf_func_kwargs)
        else:
            children_ls = np.stack([ls[c.id] for c in n.children], axis=1)
            ls[n.id] = node_func(n, children_ls, **node_func_kwargs)

    if n_jobs == 0:
        # Compute the topological ordering
        ordering = topological_order(root)
        if ordering is None:
            raise ValueError("SPN structure is not a directed acyclic graph (DAG)")
        n_nodes, n_samples = len(ordering), len(x)
        ls = np.empty(shape=(n_nodes, n_samples), dtype=np.float32)
        for node in reversed(ordering):
            eval_forward(node)
    else:
        # Compute the layered topological ordering
        layers = topological_order_layered(root)
        if layers is None:
            raise ValueError("SPN structure is not a directed acyclic graph (DAG)")
        n_nodes, n_samples = sum(map(len, layers)), len(x)
        ls = np.empty(shape=(n_nodes, n_samples), dtype=np.float32)
        parallel_layerwise_eval(layers, eval_forward, reverse=True, n_jobs=n_jobs)

    if return_results:
        return ls[root.id], ls
    return ls[root.id]


def eval_top_down(
    root: Node,
    x: np.ndarray,
    lls: np.ndarray,
    leaf_func: Callable[[Leaf, np.ndarray, Any], np.ndarray],
    sum_func: Callable[[Sum, np.ndarray, Any], np.ndarray],
    leaf_func_kwargs: Optional[dict] = None,
    sum_func_kwargs: Optional[dict] = None,
    inplace: bool = False,
    n_jobs: int = 0
) -> np.ndarray:
    """
    Evaluate the SPN top-down given some inputs, the log-likelihoods of each node and a leaf evaluation function.
    The leaves to evaluate are reached by following, at every sum node, the branch chosen by the sum node evaluation function.

    :param root: The root of the SPN.
    :param x: The inputs with some NaN values.
    :param lls: The log-likelihoods of each node.
    :param leaf_func: The leaves evaluation function.
    :param sum_func: The sum nodes evaluation function.
    :param leaf_func_kwargs: The optional parameters of the leaf evaluation function.
    :param sum_func_kwargs: The optional parameters of the sum nodes evaluation function.
    :param inplace: Whether to make inplace assignments.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention. Set to 0 to disable.
    :return: The NaN-filled inputs.
    :raises ValueError: If a parameter is out of domain.
    """
    if leaf_func_kwargs is None:
        leaf_func_kwargs = dict()
    if sum_func_kwargs is None:
        sum_func_kwargs = dict()

    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    # Copy the input array, if not inplace mode
    if not inplace:
        x = np.copy(x)

    def eval_backward(n):
        if isinstance(n, Leaf):
            mask = np.ix_(masks[n.id], n.scope)
            x[mask] = leaf_func(n, x[mask], **leaf_func_kwargs)
        elif isinstance(n, Product):
            for c in n.children:
                masks[c.id] |= masks[n.id]
        elif isinstance(n, Sum):
            children_lls = np.stack([lls[c.id] for c in n.children], axis=1)
            branch = sum_func(n, children_lls, **sum_func_kwargs)
            for i, c in enumerate(n.children):
                masks[c.id] |= masks[n.id] & (branch == i)
        else:
            raise NotImplementedError(f"Top down evaluation not implemented for node of type {n.__class__.__name__}")

    if n_jobs == 0:
        # Compute the topological ordering
        ordering = topological_order(root)
        if ordering is None:
            raise ValueError("SPN structure is not a directed acyclic graph (DAG)")
        n_nodes, n_samples = len(ordering), len(x)

        # Build the array consisting of top-down path masks
        masks = np.zeros(shape=(n_nodes, n_samples), dtype=np.bool_)
        masks[root.id] = True
        for node in ordering:
            eval_backward(node)
    else:
        # Compute the layered topological ordering
        layers = topological_order_layered(root)
        if layers is None:
            raise ValueError("SPN structure is not a directed acyclic graph (DAG)")
        n_nodes, n_samples = sum(map(len, layers)), len(x)

        # Build the array consisting of top-down path masks
        masks = np.zeros(shape=(n_nodes, n_samples), dtype=np.bool_)
        masks[root.id] = True
        parallel_layerwise_eval(layers, eval_backward, reverse=False, n_jobs=n_jobs)

    return x


##--------------------------------------------------------------Eval Gradient

def eval_backward(root: Node, lls: np.ndarray) -> np.ndarray:
    """
    Compute the log-gradients at each SPN node.

    :param root: The root of the SPN.
    :param lls: The log-likelihoods at each node.
    :return: The log-gradients w.r.t. the nodes.
    :raises ValueError: If a parameter is out of domain.
    """
    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    nodes = topological_order(root)
    if nodes is None:
        raise ValueError("SPN structure is not a directed acyclic graph (DAG)")

    n_nodes, n_samples = lls.shape
    if n_nodes != len(nodes):
        raise ValueError("Incompatible log-likelihoods broadcasting at each node")

    # Initialize the log-gradients array and the cached log-gradients dictionary of lists
    grads = np.empty(shape=(n_nodes, n_samples), dtype=np.float32)
    cached_grads = defaultdict(list)

    # Initialize the identity log-gradient at root node
    grads[root.id] = 0.0

    for node in nodes:
        # Compute log-gradient at the underlying node by logsumexp
        # At this point of the topological ordering, all the parents of the node have already been visited,
        # so every incoming contribution is cached and the log-gradient w.r.t. this node can be finalized
        if node.id != root.id:
            grads[node.id] = logsumexp(cached_grads[node.id], axis=0)
            del cached_grads[node.id]  # Cached log-gradients no longer necessary

        if isinstance(node, Sum):
            for c, w in zip(node.children, node.weights):
                g = grads[node.id] + np.log(w)
                cached_grads[c.id].append(g)
        elif isinstance(node, Product):
            for c in node.children:
                g = grads[node.id] + lls[node.id] - lls[c.id]
                cached_grads[c.id].append(g)
        elif isinstance(node, Leaf):
            pass  # Leaves have no children
        else:
            raise NotImplementedError(
                "Gradient evaluation not implemented for node of type {}".format(node.__class__.__name__)
            )

    return grads
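The bottom-up and top-down passes above can be reproduced without the `Node` classes. The sketch below hand-builds a hypothetical two-variable SPN, a root sum over two products of Bernoulli leaves, stored as plain dictionaries with node ids in topological order (root first), and evaluates it in reverse order as `eval_bottom_up()` does. A NaN input marginalizes a leaf by giving it log-likelihood 0. The `mpe` helper, which assumes this fixed root-sum-of-products shape, picks the branch maximizing `log w + child log-likelihood` as `sum_mpe()` does, using the marginal log-likelihoods of the forward pass as `mpe()` above does, and fills each NaN with the mode of the reached Bernoulli leaf.

```python
import numpy as np

# Hypothetical toy SPN: S(x0, x1) = 0.3 * A0(x0) * B0(x1) + 0.7 * A1(x0) * B1(x1).
# Node ids follow a topological order with the root first.
NODES = {
    0: {"type": "sum", "children": [1, 2], "weights": np.array([0.3, 0.7])},
    1: {"type": "product", "children": [3, 4]},
    2: {"type": "product", "children": [5, 6]},
    3: {"type": "leaf", "var": 0, "p": 0.9},  # Bernoulli leaves, p = P(X = 1)
    4: {"type": "leaf", "var": 1, "p": 0.2},
    5: {"type": "leaf", "var": 0, "p": 0.1},
    6: {"type": "leaf", "var": 1, "p": 0.6},
}


def forward(x: np.ndarray) -> np.ndarray:
    """Return the (n_nodes, n_samples) log-likelihoods, children before parents."""
    lls = np.empty((len(NODES), len(x)))
    for i in reversed(range(len(NODES))):
        node = NODES[i]
        if node["type"] == "leaf":
            v = x[:, node["var"]]
            ll = np.where(v == 1, np.log(node["p"]), np.log1p(-node["p"]))
            lls[i] = np.where(np.isnan(v), 0.0, ll)  # NaN: marginalized, log(1) = 0
        elif node["type"] == "product":
            lls[i] = lls[node["children"]].sum(axis=0)
        else:  # sum node: log-sum-exp of the weighted children
            weighted = lls[node["children"]] + np.log(node["weights"])[:, None]
            lls[i] = np.logaddexp.reduce(weighted, axis=0)
    return lls


def mpe(x: np.ndarray) -> np.ndarray:
    """Fill NaNs top-down; assumes the root is a sum over products of leaves."""
    x = x.copy()
    lls = forward(x)
    root = NODES[0]
    scores = lls[root["children"]].T + np.log(root["weights"])
    branch = np.argmax(scores, axis=1)  # same rule as sum_mpe()
    for s in range(len(x)):
        product = NODES[root["children"][branch[s]]]
        for leaf_id in product["children"]:
            leaf = NODES[leaf_id]
            if np.isnan(x[s, leaf["var"]]):
                x[s, leaf["var"]] = float(leaf["p"] >= 0.5)  # mode of the Bernoulli leaf
    return x


x = np.array([[1.0, 0.0], [1.0, np.nan], [np.nan, np.nan]])
print(np.exp(forward(x)[0]))  # a full likelihood, then two marginals
print(mpe(x))
```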

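The backward pass in `eval_backward()` propagates `log dS/dS_n`: a sum node passes `g + log w` to each child, a product node passes `g + ll(parent) - ll(child)`, and contributions from several parents are combined with log-sum-exp. The sketch below checks these rules on a hypothetical hand-built SPN, `S = 0.3 * A0 * B0 + 0.7 * A1 * B1`, with fixed leaf values for one sample, where the derivative with respect to leaf `A0` is known in closed form: `0.3 * B0`. The dictionaries `CHILDREN`, `WEIGHTS` and `LEAF_VALUES` are assumptions for illustration only.

```python
from collections import defaultdict

import numpy as np

# Hypothetical toy SPN stored in topological order (root first)
CHILDREN = {0: [1, 2], 1: [3, 4], 2: [5, 6]}  # nodes 3..6 are leaves
WEIGHTS = {0: np.array([0.3, 0.7])}           # only node 0 is a sum node
LEAF_VALUES = np.array([0.9, 0.8, 0.1, 0.4])  # A0, B0, A1, B1 for one sample


def forward() -> np.ndarray:
    lls = np.empty(7)
    lls[3:] = np.log(LEAF_VALUES)
    for n in (2, 1, 0):  # reverse topological order over the inner nodes
        child_lls = lls[CHILDREN[n]]
        if n in WEIGHTS:
            lls[n] = np.logaddexp.reduce(child_lls + np.log(WEIGHTS[n]))
        else:
            lls[n] = child_lls.sum()
    return lls


def backward(lls: np.ndarray) -> np.ndarray:
    grads = np.empty(7)
    cached = defaultdict(list)
    grads[0] = 0.0  # log dS/dS = log 1
    for n in range(7):  # topological order: parents before children
        if n != 0:
            grads[n] = np.logaddexp.reduce(cached.pop(n))
        for i, c in enumerate(CHILDREN.get(n, [])):
            if n in WEIGHTS:  # sum node
                cached[c].append(grads[n] + np.log(WEIGHTS[n][i]))
            else:  # product node: dS/dS_c = dS/dS_n * (S_n / S_c)
                cached[c].append(grads[n] + lls[n] - lls[c])
    return grads


lls = forward()
grads = backward(lls)
print(np.exp(grads[3]), 0.3 * LEAF_VALUES[1])  # both should equal dS/dA0 = 0.3 * B0
```

Working in log space keeps the product rule a subtraction of log-likelihoods, which is why `node_log_likelihood()` clips `-inf` values before they can reach `lls[node.id] - lls[c.id]`.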
In [None]:
#@title inference

def likelihood(
    root: Node,
    x: np.ndarray,
    return_results: bool = False,
    n_jobs: int = 0
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Compute the likelihoods of the SPN given some inputs.

    :param root: The root of the SPN.
    :param x: The inputs. They can be marginalized using NaNs.
    :param return_results: A flag indicating if this function must return the likelihoods of each node of the SPN.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention. Set to 0 to disable.
    :return: The likelihood values. Additionally, it returns the likelihood values of each node.
    """
    return eval_bottom_up(
        root, x,
        leaf_func=node_likelihood,
        node_func=node_likelihood,
        return_results=return_results,
        n_jobs=n_jobs
    )


def log_likelihood(
    root: Node,
    x: np.ndarray,
    return_results: bool = False,
    n_jobs: int = 0
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Compute the logarithmic likelihoods of the SPN given some inputs.

    :param root: The root of the SPN.
    :param x: The inputs. They can be marginalized using NaNs.
    :param return_results: A flag indicating if this function must return the log likelihoods of each node of the SPN.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention. Set to 0 to disable.
    :return: The log likelihood values. Additionally, it returns the log likelihood values of each node.
    """
    return eval_bottom_up(
        root, x,
        leaf_func=node_log_likelihood,
        node_func=node_log_likelihood,
        return_results=return_results,
        n_jobs=n_jobs
    )


def mpe(root: Node, x: np.ndarray, inplace: bool = False, n_jobs: int = 0) -> np.ndarray:
    """
    Compute the Most Probable Explanation (MPE) of an SPN given some inputs.

    :param root: The root of the SPN.
    :param x: The inputs. The values to infer are marked with NaNs.
    :param inplace: Whether to make inplace assignments.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention. Set to 0 to disable.
    :return: The NaN-filled inputs.
    """
    _, lls = log_likelihood(root, x, return_results=True, n_jobs=n_jobs)
    with ContextState(check_spn=False):  # We've already checked the SPN in forward mode
        return eval_top_down(
            root, x, lls,
            leaf_func=leaf_mpe,
            sum_func=sum_mpe,
            inplace=inplace,
            n_jobs=n_jobs
        )


def node_likelihood(node: Node, x: np.ndarray) -> np.ndarray:
    """
    Compute the likelihood of a node, given either its inputs (leaf node) or the
    stacked likelihoods of its children (inner node).

    :param node: The node.
    :param x: The inputs of a leaf, or the array of likelihoods of the children of an inner node.
    :return: The likelihoods of the node given the inputs.
    """
    ls = node.likelihood(x)
    return np.squeeze(ls, axis=1)


def node_log_likelihood(node: Node, x: np.ndarray) -> np.ndarray:
    """
    Compute the log-likelihood of a node, given either its inputs (leaf node) or the
    stacked log-likelihoods of its children (inner node).

    :param node: The node.
    :param x: The inputs of a leaf, or the array of log-likelihoods of the children of an inner node.
    :return: The log-likelihoods of the node given the inputs.
    """
    lls = node.log_likelihood(x)
    # Clip -inf (zero-probability inputs) to a large finite negative value
    return np.squeeze(np.maximum(lls, -1e31), axis=1)


def leaf_mpe(node: Leaf, x: np.ndarray) -> np.ndarray:
    """
    Compute the most probable explanation (MPE) of a leaf node for its NaN inputs.

    :param node: The leaf node.
    :param x: The inputs with some NaN values.
    :return: The most probable explanation.
    """
    return node.mpe(x)


def sum_mpe(node: Sum, lls: np.ndarray) -> np.ndarray:
    """
    Choose, for every sample, the child branch that maximizes the weighted log-likelihood log(w_i) + ll_i.

    :param node: The sum node.
    :param lls: The log-likelihoods of the children nodes.
    :return: The index of the selected branch for every sample.
    """
    weighted_lls = lls + np.log(node.weights)
    return np.argmax(weighted_lls, axis=1)


##----------------------------------------- moments w.r.t. each random variable

def moment(root: Node, order: int = 1) -> np.ndarray:
    """
    Compute the non-central moments of a given order of a smooth and decomposable SPN.

    :param root: The root of the SPN.
    :param order: The order of the moment, used for all the random variables.
    :return: The non-central moments with respect to each variable in the scope.
    :raises ValueError: If the order of the moment is negative.
    """
    scope = root.scope
    if order < 0:
        raise ValueError("The order of the moment must be non-negative")
    if order == 0:  # Completely skip computation for 0-order moments
        return np.ones(len(scope), dtype=np.float32)

    # Compute the moments w.r.t. each random variable by proceeding bottom-up
    moments = np.ones(shape=[len(scope), len(scope)], dtype=np.float32)
    return eval_bottom_up(
        root, moments,
        leaf_func=leaf_moment,
        node_func=node_likelihood,
        leaf_func_kwargs={'order': order}
    )


def leaf_moment(node: Leaf, x: np.ndarray, order: int) -> np.ndarray:
    """
    Compute the moment of a leaf node.

    :param node: The leaf node.
    :param x: The inputs of the leaf. They are only used to infer the output shape.
    :param order: The order of the moment.
    :return: The moment of the leaf node.
    """
    m = np.ones(len(x), dtype=np.float32)
    m[node.scope] = node.moment(k=order)
    return m


def expectation(root: Node) -> np.ndarray:
    """
    Compute the expectation values of a SPN w.r.t. each of the random variables.

    :param root: The root of the SPN.
    :return: The expectation w.r.t. each of the random variables.
    """
    return moment(root, order=1)


def variance(root: Node) -> np.ndarray:
    """
    Compute the variance values of a SPN w.r.t. each of the random variables.

    :param root: The root of the SPN.
    :return: The variance w.r.t. each of the random variables.
    """
    fst_moment = moment(root, order=1)
    snd_moment = moment(root, order=2)
    return snd_moment - fst_moment ** 2.0


def skewness(root: Node) -> np.ndarray:
    """
    Compute the skewness values of a SPN w.r.t. each of the random variables.

    :param root: The root of the SPN.
    :return: The skewness w.r.t. each of the random variables.
    """
    # This implementation is derived by expanding the third central moment
    # E[(X - mu)^3] = E[X^3] - 3 mu E[X^2] + 2 mu^3 in terms of non-central moments
    fst_moment = moment(root, order=1)
    snd_moment = moment(root, order=2)
    thd_moment = moment(root, order=3)
    g1 = fst_moment ** 2.0
    g2 = snd_moment - g1
    g3 = 3.0 * snd_moment - 2.0 * g1
    return (thd_moment - fst_moment * g3) / (g2 ** 1.5)


def kurtosis(root: Node) -> np.ndarray:
    """
    Compute the kurtosis values of a SPN w.r.t. each of the random variables.
    This function returns the kurtosis based on Fisher's definition, i.e.
    3.0 is subtracted from the result to give 0.0 for a normal distribution.

    :param root: The root of the SPN.
    :return: The kurtosis w.r.t. each of the random variables.
    """
    # This implementation is derived from Moors' interpretation
    # (More @ https://en.wikipedia.org/wiki/Kurtosis#Moors'_interpretation)
    # by expanding Var[Z^2] + 1 and obtaining a definition based on non-central moments
    fst_moment = moment(root, order=1)
    snd_moment = moment(root, order=2)
    thd_moment = moment(root, order=3)
    fhd_moment = moment(root, order=4)
    g1 = fst_moment ** 2.0
    g2 = snd_moment - g1
    g3 = 4.0 * (g1 ** 2.0 + fst_moment * thd_moment)
    g4 = snd_moment * (8.0 * g1 - snd_moment)
    return -2.0 + (fhd_moment - g3 + g4) / (g2 ** 2.0)
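
The skewness and kurtosis functions above express central moments through the non-central moments $m_k = E[X^k]$ returned by `moment`. A quick way to sanity-check such expansions is to feed them raw moments whose answers are known. The sketch below does this in plain NumPy, independently of the SPN classes (`raw_to_shape_stats` is a hypothetical helper, not part of the notebook). It uses Exp(1), whose raw moments are $m_k = k!$ (skewness 2, excess kurtosis 6), and a standard normal (skewness 0, excess kurtosis 0).

```python
import numpy as np


def raw_to_shape_stats(m1, m2, m3, m4):
    """Variance, skewness and Fisher (excess) kurtosis from non-central moments."""
    m1, m2, m3, m4 = (np.asarray(m, dtype=np.float64) for m in (m1, m2, m3, m4))
    var = m2 - m1 ** 2.0
    # Third central moment: E[(X - mu)^3] = m3 - 3 mu m2 + 2 mu^3
    third = m3 - 3.0 * m1 * m2 + 2.0 * m1 ** 3.0
    # Fourth central moment: E[(X - mu)^4] = m4 - 4 mu m3 + 6 mu^2 m2 - 3 mu^4
    fourth = m4 - 4.0 * m1 * m3 + 6.0 * m1 ** 2.0 * m2 - 3.0 * m1 ** 4.0
    return var, third / var ** 1.5, fourth / var ** 2.0 - 3.0


# Column 0: Exp(1), raw moments 1, 2, 6, 24; column 1: N(0, 1), raw moments 0, 1, 0, 3
var, skew, kurt = raw_to_shape_stats([1.0, 0.0], [2.0, 1.0], [6.0, 0.0], [24.0, 3.0])
# expected: variance [1, 1], skewness [2, 0], excess kurtosis [6, 0]
```

Exp(1) is the more informative test case: its mean is nonzero, so every mean-dependent term of the expansions contributes, whereas a zero-mean normal leaves most of them untouched.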

In [None]:
#@title sampling


def sample(root: Node, x: np.ndarray, inplace: bool = False, n_jobs: int = 0) -> np.ndarray:
    """
    Sample some features from the distribution represented by the SPN.

    :param root: The root of the SPN.
    :param x: The inputs with possible NaN values to fill with sampled values.
    :param inplace: Whether to make inplace assignments.
    :param n_jobs: The number of parallel jobs. It follows the joblib's convention. Set to 0 to disable.
     Warning: disrupts seed determinism.
    :return: The inputs that are NaN-filled with samples from appropriate distributions.
    """
    # First evaluate the SPN bottom-up, then top-down
    _, lls = log_likelihood(root, x, return_results=True, n_jobs=n_jobs)
    with ContextState(check_spn=False):  # We've already checked the SPN in forward mode
        return eval_top_down(
            root, x, lls,
            leaf_func=leaf_sample,
            sum_func=sum_sample,
            inplace=inplace,
            n_jobs=n_jobs
        )


def leaf_sample(node: Leaf, x: np.ndarray) -> np.ndarray:
    """
    Sample some values from the distribution leaf.

    :param node: The distribution leaf node.
    :param x: The inputs with possible NaN values to fill with sampled values.
    :return: The completed samples.
    """
    return node.sample(x)


def sum_sample(node: Sum, lls: np.ndarray) -> np.ndarray:
    """
    Choose the sub-distribution from which to sample, using the Gumbel-max trick.

    :param node: The sum node.
    :param lls: The log-likelihoods of the children nodes.
    :return: The index of the sub-distribution to follow, for each sample.
    """
    n_samples, n_children = lls.shape
    # Add standard (right-skewed, max) Gumbel noise: the argmax then selects child i
    # with probability proportional to w_i * p_i(x)
    gumbel = stats.gumbel_r.rvs(0.0, 1.0, size=(n_samples, n_children))
    weighted_lls = lls + np.log(node.weights) + gumbel
    return np.argmax(weighted_lls, axis=1)

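`sum_sample` selects a child of a sum node by perturbing the weighted children log-likelihoods with Gumbel noise and taking the argmax. With standard (right-skewed, max) Gumbel noise, this Gumbel-max trick draws child $i$ with probability proportional to $w_i\,p_i(x)$. The sketch below checks this empirically with NumPy's `Generator.gumbel`. `gumbel_max_choice` is a hypothetical stand-in for `sum_sample`, and the weights and child likelihoods are made up.

```python
import numpy as np


def gumbel_max_choice(log_weights, lls, rng):
    """Pick one child index per row with probability proportional to exp(log_weights + lls)."""
    gumbel = rng.gumbel(loc=0.0, scale=1.0, size=lls.shape)  # standard (max) Gumbel noise
    return np.argmax(lls + log_weights + gumbel, axis=1)


rng = np.random.default_rng(0)
weights = np.array([0.5, 0.25, 0.25])
n = 200_000

# No evidence: every child explains the sample equally well, so the prior weights are recovered
prior_lls = np.zeros((n, 3))
prior_freqs = np.bincount(gumbel_max_choice(np.log(weights), prior_lls, rng), minlength=3) / n

# With evidence: children likelihoods 0.2, 0.6, 0.2 give a posterior proportional to w * p
child_probs = np.array([0.2, 0.6, 0.2])
post_lls = np.tile(np.log(child_probs), (n, 1))
post_freqs = np.bincount(gumbel_max_choice(np.log(weights), post_lls, rng), minlength=3) / n
expected_post = weights * child_probs / np.sum(weights * child_probs)  # [1/3, 1/2, 1/6]
```

The mirrored, left-skewed variant (`scipy.stats.gumbel_l`) still gives the right probabilities for a sum node with two children, but it biases the choice once there are three or more. For example, with weights (0.5, 0.25, 0.25) it selects the first child about 53% of the time instead of 50%.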
In [None]:
#@title prune the spn


def prune(root: Node, copy: bool = True) -> Node:
    """
    Prune (or simplify) the given SPN to a minimal and equivalent SPN.

    :param root: The root of the SPN.
    :param copy: Whether to copy the SPN before pruning it.
    :return: A minimal and equivalent SPN.
    :raises ValueError: If the SPN structure is not a directed acyclic graph (DAG).
    :raises ValueError: If an unknown node type is found.
    """
    # Copy the SPN before proceeding, if specified
    if copy:
        root = deepcopy(root)

    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    nodes = topological_order(root)
    if nodes is None:
        raise ValueError("SPN structure is not a directed acyclic graph (DAG)")

    # Build a dictionary that maps each id of a node to the corresponding node object
    nodes_map = dict(map(lambda n: (n.id, n), nodes))

    # Proceed by reversed topological order
    for node in reversed(nodes):
        # Skip leaves
        if isinstance(node, Leaf):
            continue

        # Retrieve the children nodes from the mapping
        children_nodes = list(map(lambda n: nodes_map[n.id], node.children))
        if len(children_nodes) == 1:
            nodes_map[node.id] = children_nodes[0]
        elif isinstance(node, Product):
            # Merge chained product nodes by concatenating the children of the product children
            children = list()
            for child in children_nodes:
                if not isinstance(child, Product):
                    children.append(child)
                    continue
                product_children = map(lambda n: nodes_map[n.id], child.children)
                children.extend(product_children)
            nodes_map[node.id].children = children
        elif isinstance(node, Sum):
            # Merge chained sum nodes by concatenating their children and multiplying the weights along each path
            # Important! Weights of children reached through several paths are summed, which also handles DAGs
            children_weights = defaultdict(float)
            for i, child in enumerate(children_nodes):
                if not isinstance(child, Sum):
                    children_weights[child] += node.weights[i]
                    continue
                sum_children = map(lambda n: nodes_map[n.id], child.children)
                for j, sum_child in enumerate(sum_children):
                    children_weights[sum_child] += node.weights[i] * child.weights[j]
            children, weights = zip(*children_weights.items())
            nodes_map[node.id].weights = np.array(weights, dtype=node.weights.dtype)
            nodes_map[node.id].children = list(children)
        else:
            raise ValueError("Unknown node type called {}".format(node.__class__.__name__))

    return assign_ids(nodes_map[root.id])


def marginalize(root: Node, keep_scope: List[int], copy: bool = True) -> Node:
    """
    Marginalize some random variables of a SPN, obtaining the compilation of a marginal query.

    :param root: The root of the SPN to marginalize.
    :param keep_scope: The scope of the random variables to keep.
                       All the other random variables will be marginalized.
    :param copy: Whether to copy the SPN before marginalizing it.
    :return: A SPN in which an EVI query is equivalent to a MAR query under the given scope.
    :raises ValueError: If the scope of the random variables to keep is not valid.
    :raises ValueError: If the SPN structure is not a directed acyclic graph (DAG).
    :raises ValueError: If an unknown node type is found.
    :raises NotImplementedError: If non-BinaryCLT multivariate leaves are found.
    """
    if not keep_scope:
        raise ValueError("The scope of the random variables to keep must not be empty")
    keep_scope_s = set(keep_scope)
    if len(keep_scope) != len(keep_scope_s):
        raise ValueError("The scope of the random variables to keep must not contain duplicates")
    if not keep_scope_s.issubset(set(root.scope)):
        raise ValueError("The scope of the random variables to keep must be a subset of the scope of the SPN")

    # Copy the SPN before proceeding, if specified
    if copy:
        root = deepcopy(root)

    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    nodes = topological_order(root)
    if nodes is None:
        raise ValueError("SPN structure is not a directed acyclic graph (DAG)")

    # Build a dictionary that maps each id of a node to the corresponding node object
    nodes_map = dict(map(lambda n: (n.id, n), nodes))

    # Proceed by reversed topological order
    for node in reversed(nodes):
        if isinstance(node, Leaf):
            # Marginalize leaves, set to None if the leaf is fully marginalized
            if isinstance(node, BinaryCLT):
                # Convert the binary Chow-Liu Tree to a SPN and marginalize that instead
                clt_scope = list(keep_scope_s.intersection(node.scope))
                if clt_scope:
                    with ContextState(check_spn=False):  # Disable checking the SPN obtained by CLT to PC conversion
                        nodes_map[node.id] = marginalize(node.to_pc(), clt_scope, copy=False)
                else:
                    nodes_map[node.id] = None
            elif len(node.scope) == 1:
                nodes_map[node.id] = node if node.scope[0] in keep_scope else None
            else:
                raise NotImplementedError(
                    "Structural marginalization for arbitrarily multivariate leaves not yet implemented"
                )
            continue

        # Retrieve the children nodes from the mapping
        children_nodes = list(filter(
            lambda n: n is not None, map(lambda n: nodes_map[n.id], node.children)
        ))

        if not children_nodes:
            nodes_map[node.id] = None
        elif len(children_nodes) == 1:
            nodes_map[node.id] = children_nodes[0]
        else:
            if isinstance(node, Product):
                nodes_map[node.id].scope = list(sum(map(lambda n: n.scope, children_nodes), []))
                nodes_map[node.id].children = children_nodes
            elif isinstance(node, Sum):
                nodes_map[node.id].scope = children_nodes[0].scope
                nodes_map[node.id].children = children_nodes
            else:
                raise ValueError("Unknown node type called {}".format(node.__class__.__name__))

    root = assign_ids(nodes_map[root.id])
    return prune(root, copy=False)
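
The sum-node branch of `prune` merges a sum child into its parent: weights along each path are multiplied, and contributions to a child reached through several paths are added. A minimal sketch of that rule follows. It uses a hypothetical tuple-based representation (not the notebook's `Sum`/`Leaf` classes) and checks that the flattened mixture evaluates to the same value as the nested one.

```python
from collections import defaultdict


def flatten_sum(weights, children):
    """Collapse a sum node whose children may be nested sums (toy representation).

    A child is either a leaf name (str) or a nested sum given as (weights, children),
    where the nested sum's own children are assumed to be leaves already.
    """
    merged = defaultdict(float)
    for w, child in zip(weights, children):
        if isinstance(child, tuple):
            # Nested sum: push the parent weight down onto each grandchild
            child_weights, grandchildren = child
            for cw, grandchild in zip(child_weights, grandchildren):
                merged[grandchild] += w * cw
        else:
            # Leaf: accumulate, so a leaf shared by several paths (DAG) is merged
            merged[child] += w
    return dict(merged)


def evaluate(weights, children, leaf_values):
    """Evaluate the toy (possibly nested) mixture given each leaf's density value."""
    total = 0.0
    for w, child in zip(weights, children):
        value = evaluate(*child, leaf_values) if isinstance(child, tuple) else leaf_values[child]
        total += w * value
    return total


nested = ([0.4, 0.6], ['A', 'B'])                  # inner sum node
root_weights, root_children = [0.3, 0.7], ['A', nested]
flat = flatten_sum(root_weights, root_children)   # A: 0.3 + 0.7 * 0.4, B: 0.7 * 0.6

leaf_values = {'A': 0.2, 'B': 0.5}                # made-up leaf densities at some point x
before = evaluate(root_weights, root_children, leaf_values)
after = evaluate(list(flat.values()), list(flat.keys()), leaf_values)
```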

In [None]:
#@title spn stats calculator

def compute_statistics(root: Node) -> dict:
    """
    Compute some statistics of a SPN given its root.
    The computed statistics are the following:

    - n_nodes, the number of nodes
    - n_sum, the number of sum nodes
    - n_prod, the number of product nodes
    - n_leaves, the number of leaves
    - n_edges, the number of edges
    - n_params, the number of parameters
    - depth, the depth of the network

    :param root: The root of the SPN.
    :return: A dictionary containing the statistics.
    """
    stats = {
        'n_nodes': len(collect_nodes(root)),
        'n_sum': len(filter_nodes_by_type(root, Sum)),
        'n_prod': len(filter_nodes_by_type(root, Product)),
        'n_leaves': len(filter_nodes_by_type(root, Leaf)),
        'n_edges': compute_edges_count(root),
        'n_params': compute_parameters_count(root),
        'depth': compute_depth(root)
    }
    return stats


def compute_edges_count(root: Node) -> int:
    """
    Get the number of edges of a SPN given its root.

    :param root: The root of the SPN.
    :return: The number of edges.
    """
    return sum(len(n.children) for n in filter_nodes_by_type(root, (Sum, Product)))


def compute_parameters_count(root: Node) -> int:
    """
    Get the number of parameters of a SPN given its root.

    :param root:  The root of the SPN.
    :return: The number of parameters.
    """
    n_weights = sum(len(n.weights) for n in filter_nodes_by_type(root, Sum))
    n_leaf_params = sum(n.params_count() for n in filter_nodes_by_type(root, Leaf))
    return n_weights + n_leaf_params


def compute_depth(root: Node) -> int:
    """
    Get the depth of the SPN given its root.

    :param root: The root of the SPN.
    :return: The depth of the network.
    """
    depths = dict()
    for node in bfs(root):
        d = depths.setdefault(node, 0)
        for c in node.children:
            depths[c] = d + 1
    return max(depths.values())
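
`compute_statistics` only needs the node/children structure. The sketch below computes the depth and the edge count of a small tree-shaped SPN stored as a plain dictionary (a hypothetical representation, not the notebook's `Node` objects) with a breadth-first traversal. It counts parameters as the sum weights plus one parameter per leaf, assuming Bernoulli leaves.

```python
from collections import deque


def toy_depth_and_edges(children, root):
    """BFS depth and edge count for a node -> list-of-children mapping (hypothetical format)."""
    depths = {root: 0}
    queue = deque([root])
    n_edges = 0
    while queue:
        node = queue.popleft()
        for c in children.get(node, []):
            n_edges += 1
            if c not in depths:  # first visit fixes the BFS depth
                depths[c] = depths[node] + 1
                queue.append(c)
    return max(depths.values()), n_edges


spn = {
    'S': ['P1', 'P2'],   # sum node with two weights
    'P1': ['L1', 'L2'],  # product nodes over two leaves each
    'P2': ['L3', 'L4'],
}
depth, n_edges = toy_depth_and_edges(spn, 'S')

sum_nodes = {'S'}
leaves = {c for cs in spn.values() for c in cs} - set(spn)          # nodes without children
n_params = sum(len(spn[s]) for s in sum_nodes) + 1 * len(leaves)   # one parameter per Bernoulli leaf
```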

# Splitting

In [None]:
#@title cluster and nonlinear independence test by randomized dependence coefficient (rdc)


#--------------------------------------------------------------- cluster methods

def gmm(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    n: int = 2
) -> np.ndarray:
    """
    Execute GMM clustering on some data.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param n: The number of clusters.
    :return: An array where each element is the cluster where the corresponding data belong.
    """
    # Convert the data using One Hot Encoding, in case of non-binary discrete features
    if any(len(d) > 2 for d in domains):
        data = mixed_ohe_data(data, domains)

    # Apply GMM
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings
        return mixture.GaussianMixture(n, n_init=3, random_state=random_state).fit_predict(data)


def kmeans(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    n: int = 2
) -> np.ndarray:
    """
    Execute K-Means clustering on some data.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param n: The number of clusters.
    :return: An array where each element is the cluster where the corresponding data belong.
    """
    # Convert the data using One Hot Encoding, in case of non-binary discrete features
    if any(len(d) > 2 for d in domains):
        data = mixed_ohe_data(data, domains)

    # Apply K-Means
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings
        return cluster.KMeans(n, n_init=5, random_state=random_state).fit_predict(data)


def kmeans_mb(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    n: int = 2
) -> np.ndarray:
    """
    Execute MiniBatch K-Means clustering on some data.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param n: The number of clusters.
    :return: An array where each element is the cluster where the corresponding data belong.
    """
    # Convert the data using One Hot Encoding, in case of non-binary discrete features
    if any(len(d) > 2 for d in domains):
        data = mixed_ohe_data(data, domains)

    # Apply K-Means MiniBatch
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings
        warnings.simplefilter(action='ignore', category=UserWarning)  # Ignore user warnings
        return cluster.MiniBatchKMeans(n, n_init=5, random_state=random_state).fit_predict(data)


def dbscan(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    n: int = 2
) -> np.ndarray:
    """
    Execute DBSCAN clustering on some data (only on discrete data).

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state (unused by DBSCAN).
    :param n: The number of clusters (unused: DBSCAN infers it from the data density).
    :return: An array where each element is the cluster where the corresponding data belong.
    :raises ValueError: If the leaf distributions are NOT discrete.
    """
    # Check that all the distributions are discrete
    if not all(x.LEAF_TYPE == LeafType.DISCRETE for x in distributions):
        raise ValueError('DBScan clustering can be applied only on discrete attributes')

    # Convert the data using One Hot Encoding, in case of non-binary discrete features
    if any(len(d) > 2 for d in domains):
        data = mixed_ohe_data(data, domains)

    # Apply DBSCAN
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings
        return cluster.DBSCAN(eps=0.25, n_jobs=-1).fit_predict(data)


def wald(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    n: int = 2
) -> np.ndarray:
    """
    Execute Ward (Hierarchical) clustering on some data (only discrete data).

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param n: The number of clusters.
    :return: An array where each element is the cluster where the corresponding data belong.
    :raises ValueError: If the leaf distributions are NOT discrete.
    """
    # Check that all the distributions are discrete
    if not all(x.LEAF_TYPE == LeafType.DISCRETE for x in distributions):
        raise ValueError('Ward clustering can be applied only on discrete attributes')

    # Convert the data using One Hot Encoding, in case of non-binary discrete features
    if any(len(d) > 2 for d in domains):
        data = mixed_ohe_data(data, domains)

    # Apply Ward (agglomerative) clustering
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings
        return cluster.AgglomerativeClustering(n, linkage='ward').fit_predict(data)


#--------------------------------------------------------------------------- rdc

def rdc_cols(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    d: float = 0.3,
    k: int = 20,
    s: float = 1.0 / 6.0,
    nl: Callable[[np.ndarray], np.ndarray] = np.sin
) -> np.ndarray:
    """
    Split the features using the RDC (Randomized Dependence Coefficient) method.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param d: The threshold value that regulates the independence tests among the features.
    :param k: The size of the latent space.
    :param s: The standard deviation of the gaussian distribution.
    :param nl: The non linear function to use.
    :return: A features partitioning.
    """
    # Compute the RDC scores matrix
    rdc_matrix = rdc_scores(data, distributions, domains, random_state, k=k, s=s, nl=nl)

    # Compute the adjacency matrix
    adj_matrix = (rdc_matrix > d).astype(np.int32)

    # Compute the connected components of the adjacency matrix
    adj_matrix = sp.csr_matrix(adj_matrix)
    _, clusters = sp.csgraph.connected_components(adj_matrix, directed=False, return_labels=True)
    return clusters


def rdc_rows(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    n: int = 2,
    k: int = 20,
    s: float = 1.0 / 6.0,
    nl: Callable[[np.ndarray], np.ndarray] = np.sin
) -> np.ndarray:
    """
    Split the samples using the RDC (Randomized Dependence Coefficient) method.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param n: The number of clusters for KMeans.
    :param k: The size of the latent space.
    :param s: The standard deviation of the gaussian distribution.
    :param nl: The non linear function to use.
    :return: A samples partitioning.
    """
    # Transform the samples by RDC
    rdc_samples = np.concatenate(
        rdc_transform(data, distributions, domains, random_state, k, s, nl), axis=1
    )

    # Apply K-Means to the transformed samples
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings for K-Means
        return cluster.KMeans(n, n_init=5, random_state=random_state).fit_predict(rdc_samples)


def rdc_scores(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    k: int = 20,
    s: float = 1.0 / 6.0,
    nl: Callable[[np.ndarray], np.ndarray] = np.sin
) -> np.ndarray:
    """
    Compute the RDC (Randomized Dependence Coefficient) score for each pair of features.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param k: The size of the latent space.
    :param s: The standard deviation of the gaussian distribution.
    :param nl: The non linear function to use.
    :return: The RDC score matrix.
    """
    # Apply RDC transformation to the features
    _, n_features = data.shape
    rdc_features = rdc_transform(data, distributions, domains, random_state, k, s, nl)
    pairwise_comparisons = list(combinations(range(n_features), 2))

    # Run Canonical Component Analysis (CCA) on RDC-transformed features
    rdc_matrix = np.empty(shape=(n_features, n_features), dtype=np.float32)
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=ConvergenceWarning)  # Ignore convergence warnings for CCA
        for i, j in pairwise_comparisons:
            score = rdc_cca(i, j, rdc_features)
            rdc_matrix[i, j] = rdc_matrix[j, i] = score
    np.fill_diagonal(rdc_matrix, 1.0)
    return rdc_matrix


def rdc_cca(i: int, j: int, features: List[np.ndarray]) -> float:
    """
    Compute the RDC (Randomized Dependence Coefficient) using CCA (Canonical Correlation Analysis).

    :param i: The index of the first feature.
    :param j: The index of the second feature.
    :param features: The list of the features.
    :return: The RDC coefficient (the largest canonical correlation coefficient).
    """
    cca = cross_decomposition.CCA(n_components=1)
    x_cca, y_cca = cca.fit_transform(features[i], features[j])
    x_cca, y_cca = x_cca.squeeze(), y_cca.squeeze()
    return np.corrcoef(x_cca, y_cca)[0, 1]


def rdc_transform(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    k: int = 20,
    s: float = 1.0 / 6.0,
    nl: Callable[[np.ndarray], np.ndarray] = np.sin
) -> List[np.ndarray]:
    """
    Execute the RDC (Randomized Dependence Coefficient) pipeline on some data.

    :param data: The data.
    :param distributions: The data distributions.
    :param domains: The data domains.
    :param random_state: The random state.
    :param k: The size of the latent space.
    :param s: The standard deviation of the gaussian distribution.
    :param nl: The non-linear function to use.
    :return: The transformed data.
    :raises ValueError: If an unknown distribution type is found.
    """
    features = []
    for i, dist in enumerate(distributions):
        if dist.LEAF_TYPE == LeafType.DISCRETE:
            feature_matrix = ohe_data(data[:, i], domains[i])
        elif dist.LEAF_TYPE == LeafType.CONTINUOUS:
            feature_matrix = np.expand_dims(data[:, i], axis=-1)
        else:
            raise ValueError("Unknown distribution type {}".format(dist.LEAF_TYPE))
        x = np.apply_along_axis(ecdf_data, 0, feature_matrix)
        features.append(x.astype(np.float32))

    samples = []
    for x in features:
        stddev = np.sqrt(s / x.shape[1])
        w = stddev * random_state.randn(x.shape[1], k).astype(np.float32)
        b = stddev * random_state.randn(k).astype(np.float32)
        y = nl(np.dot(x, w) + b)
        samples.append(y)
    return samples
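
The RDC pipeline above has three steps: a copula transform of each feature through its empirical CDF, random non-linear projections `sin(x w + b)`, and the largest canonical correlation between the projected features of a pair. The sketch below reproduces these steps for two continuous variables with NumPy only. It computes the canonical correlations exactly, as singular values of the product of two orthonormal bases, instead of using sklearn's iterative `CCA`. It also uses `k=10` to keep the finite-sample bias small on 2000 samples. `ecdf`, `random_sin_features`, `orthonormal_basis` and `rdc_sketch` are hypothetical helpers, not the notebook's `ecdf_data`/`rdc_transform`.

```python
import numpy as np


def ecdf(x):
    """Empirical CDF value of each sample (copula transform), in (0, 1]."""
    return np.searchsorted(np.sort(x), x, side='right') / len(x)


def random_sin_features(u, k, s, rng):
    """Random non-linear projections sin(u w + b), with weights of variance s / d_in."""
    d_in = u.shape[1]
    stddev = np.sqrt(s / d_in)
    w = stddev * rng.standard_normal((d_in, k))
    b = stddev * rng.standard_normal(k)
    return np.sin(u @ w + b)


def orthonormal_basis(a, tol=1e-10):
    """Orthonormal basis of the centered column space of a (numerically rank-truncated)."""
    a = a - a.mean(axis=0)
    left, sv, _ = np.linalg.svd(a, full_matrices=False)
    return left[:, sv > tol * sv[0]]


def rdc_sketch(x, y, k=10, s=1.0 / 6.0, seed=0):
    """Largest canonical correlation between random features of the two copula transforms."""
    rng = np.random.default_rng(seed)
    fx = random_sin_features(ecdf(x)[:, None], k, s, rng)
    fy = random_sin_features(ecdf(y)[:, None], k, s, rng)
    qx, qy = orthonormal_basis(fx), orthonormal_basis(fy)
    # The canonical correlations are the singular values of qx^T qy
    top = np.linalg.svd(qx.T @ qy, compute_uv=False)[0]
    return float(min(top, 1.0))


rng = np.random.default_rng(42)
x = rng.uniform(-1.0, 1.0, size=2000)
z = rng.uniform(-1.0, 1.0, size=2000)        # independent of x

pearson = abs(np.corrcoef(x, x ** 2)[0, 1])  # near 0: the dependence is not linear
rdc_dep = rdc_sketch(x, x ** 2)              # close to 1
rdc_indep = rdc_sketch(x, z)                 # small, but not exactly 0 (finite-sample bias)
```

The gap between `rdc_dep` and `rdc_indep` is what the threshold `d` in `rdc_cols` exploits. Because independent pairs still score above zero, the threshold has to sit above that bias, which grows with `k` and shrinks with the number of samples.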


In [None]:
#@title random selection and splitting of rows and columns

def random_rows(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    a: float = 2.0,
    b: float = 2.0
) -> np.ndarray:
    """
    Randomly choose a binary partition of the rows (horizontal split).
    The proportion of the split is sampled from a beta distribution.

    :param data: The data.
    :param distributions: The data distributions (not used).
    :param domains: The data domains (not used).
    :param random_state: The random state.
    :param a: The alpha parameter of the beta distribution.
    :param b: The beta parameter of the beta distribution.
    :return: A binary partition.
    """
    n_samples, _ = data.shape
    p = random_state.beta(a, b)
    return random_state.binomial(1, p, size=n_samples)


def random_cols(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    random_state: np.random.RandomState,
    a: float = 2.0,
    b: float = 2.0
) -> np.ndarray:
    """
    Randomly choose a binary partition of the columns (vertical split).
    The proportion of the split is sampled from a beta distribution.

    :param data: The data.
    :param distributions: The data distributions (not used).
    :param domains: The data domains (not used).
    :param random_state: The random state.
    :param a: The alpha parameter of the beta distribution.
    :param b: The beta parameter of the beta distribution.
    :return: A binary partition.
    """
    _, n_features = data.shape
    p = random_state.beta(a, b)
    return random_state.binomial(1, p, size=n_features)



SplitRowsFunc = Callable[
    [np.ndarray,                # The data
     List[Type[Leaf]],          # The distributions
     List[Union[list, tuple]],  # The domains
     np.random.RandomState,     # The random state
     Any],                      # Other arguments
    np.ndarray                  # The rows ids
]


def split_rows_clusters(
    data: np.ndarray,
    clusters: np.ndarray
) -> Tuple[List[np.ndarray], List[float]]:
    """
    Split the data horizontally given the clusters.

    :param data: The data.
    :param clusters: The clusters.
    :return: (slices, weights) where slices is a list of partial data and
             weights is a list of proportions of the local data with respect to the original data.
    """
    slices = list()
    weights = list()
    n_samples = len(data)
    unique_clusters = np.unique(clusters)
    for c in unique_clusters:
        local_data = data[clusters == c, :]
        slices.append(local_data)
        weights.append(len(local_data) / n_samples)
    return slices, weights


def get_split_rows_method(split_rows: str) -> SplitRowsFunc:
    """
    Get the rows splitting method given a string.

    :param split_rows: The name of the method to get.
    :return: The corresponding rows splitting function.
    :raises ValueError: If the rows splitting method is unknown.
    """
    if split_rows == 'kmeans':
        return kmeans
    if split_rows == 'kmeans_mb':
        return kmeans_mb
    if split_rows == 'dbscan':
        return dbscan
    if split_rows == 'wald':
        return wald
    if split_rows == 'gmm':
        return gmm
    if split_rows == 'rdc':
        return rdc_rows
    if split_rows == 'random':
        return random_rows
    raise ValueError("Unknown split rows method called {}".format(split_rows))


#---------------------------------------------------

SplitColsFunc = Callable[
    [np.ndarray,                # The data
     List[Type[Leaf]],          # The distributions
     List[Union[list, tuple]],  # The domains
     np.random.RandomState,     # The random state
     Any],                      # Other arguments
    np.ndarray                  # The columns ids
]


def split_cols_clusters(
    data: np.ndarray,
    clusters: np.ndarray,
    scope: List[int]
) -> Tuple[List[np.ndarray], List[List[int]]]:
    """
    Split the data vertically given the clusters.

    :param data: The data.
    :param clusters: The clusters.
    :param scope: The original scope.
    :return: (slices, scopes) where slices is a list of partial data and
             scopes is a list of partial scopes.
    """
    slices = list()
    scopes = list()
    scope = np.asarray(scope)
    unique_clusters = np.unique(clusters)
    for c in unique_clusters:
        cols = (clusters == c)
        slices.append(data[:, cols])
        scopes.append(scope[cols].tolist())
    return slices, scopes


def get_split_cols_method(split_cols: str) -> SplitColsFunc:
    """
    Get the columns splitting method given a string.

    :param split_cols: The name of the method to get.
    :return: The corresponding columns splitting function.
    :raises ValueError: If the columns splitting method is unknown.
    """
    # if split_cols == 'gvs':
    #     return gvs_cols
    # if split_cols == 'rgvs':
    #     return rgvs_cols
    # if split_cols == 'wrgvs':
    #     return wrgvs_cols
    # if split_cols == 'ebvs':
    #     return entropy_cols
    # if split_cols == 'ebvs_ae':
    #     return entropy_adaptive_cols
    # if split_cols == 'gbvs':
    #     return gini_cols
    # if split_cols == 'gbvs_ag':
    #     return gini_adaptive_cols
    if split_cols == 'rdc':
        return rdc_cols
    if split_cols == 'random':
        return random_cols
    raise ValueError("Unknown split columns method called {}".format(split_cols))

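A row split feeds a sum node and a column split feeds a product node. The sketch below replays `random_rows` followed by the logic of `split_rows_clusters` and `split_cols_clusters` on toy data. It uses NumPy's `Generator` API; the notebook passes a legacy `RandomState`, which exposes the same `beta`/`binomial` methods. The data, the `scope` ids and the column partition are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
data = rng.integers(0, 2, size=(10, 4))  # toy binary data: 10 samples, 4 features
scope = [3, 5, 7, 9]                     # hypothetical variable ids of the 4 columns

# Random row partition: proportion p ~ Beta(2, 2), then one Bernoulli(p) label per sample
p = rng.beta(2.0, 2.0)
row_clusters = rng.binomial(1, p, size=data.shape[0])

# Sum node: one child per cluster, weighted by the fraction of samples it receives
row_slices, sum_weights = [], []
for c in np.unique(row_clusters):
    local = data[row_clusters == c]
    row_slices.append(local)
    sum_weights.append(len(local) / len(data))

# Product node: a fixed column partition, each child keeps the matching part of the scope
col_clusters = np.array([0, 1, 0, 1])
scope_arr = np.asarray(scope)
col_slices = [data[:, col_clusters == c] for c in np.unique(col_clusters)]
child_scopes = [scope_arr[col_clusters == c].tolist() for c in np.unique(col_clusters)]
```

Since the Bernoulli labels can all come out equal, a random row split may produce a single cluster. In that case the resulting sum node has one child with weight 1.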
In [None]:
## Other splitting methods: https://github.com/deeprob-org/deeprob-kit/tree/main/deeprob/spn/learning/splitting
## If you use them, also uncomment the corresponding if clauses in get_split_cols_method (previous cell)

# Learning
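
The EM cell below learns all the SPN parameters from random mini-batches (`batch_perc`) with a step size. The sketch isolates the sum-node part on a toy two-component Gaussian mixture with fixed means, learning only the mixing weights. The E-step computes per-sample responsibilities on the batch, and the M-step averages them. The rest of the notebook's update is not visible in this cell, so blending old and new weights as `(1 - step_size) * old + step_size * new` is an assumption (a common stepwise-EM rule), as are all the names used here.

```python
import numpy as np


def gaussian_pdf(x, mean, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def minibatch_em_weights(data, means, num_iter=100, batch_perc=0.1, step_size=0.5, seed=0):
    """Learn only the mixing weights of a fixed-component Gaussian mixture (toy sum node)."""
    rng = np.random.default_rng(seed)
    means = np.asarray(means, dtype=np.float64)
    n_samples = len(data)
    batch_size = int(batch_perc * n_samples)
    weights = rng.dirichlet(np.ones(len(means)))  # random initialization on the simplex
    for _ in range(num_iter):
        batch = data[rng.choice(n_samples, size=batch_size, replace=False)]
        # E-step: responsibility of each component for each sample in the batch
        joint = weights * gaussian_pdf(batch[:, None], means[None, :])
        resp = joint / joint.sum(axis=1, keepdims=True)
        # M-step on the batch, blended with the previous weights (assumed update rule)
        weights = (1.0 - step_size) * weights + step_size * resp.mean(axis=0)
    return weights


rng = np.random.default_rng(123)
n = 10_000
from_first = rng.random(n) < 0.3  # about 30% of the samples come from the first component
data = np.where(from_first, rng.normal(-2.0, 1.0, n), rng.normal(2.0, 1.0, n))
learned = minibatch_em_weights(data, means=[-2.0, 2.0])
```

A convex blend of two points of the simplex stays on the simplex, so the weights remain normalized without an explicit renormalization step.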

In [None]:
#@title Expectation Maximization EM

def expectation_maximization(
    root: Node,
    data: np.ndarray,
    num_iter: int = 100,
    batch_perc: float = 0.1,
    step_size: float = 0.5,
    random_init: bool = True,
    random_state: Optional[Union[int, RandomState]] = None,
    verbose: bool = True
) -> Node:
    """
    Learn the parameters of a SPN by batch Expectation-Maximization (EM).
    See https://arxiv.org/abs/1604.07243 and https://arxiv.org/abs/2004.06231 for details.

    :param root: The spn structure.
    :param data: The data to use to learn the parameters.
    :param num_iter: The number of iterations.
    :param batch_perc: The percentage of data to use for each step.
    :param step_size: The step size for batch EM.
    :param random_init: Whether to randomly initialize the parameters of the SPN.
    :param random_state: The random state. It can be either None, a seed integer or a Numpy RandomState.
    :param verbose: Whether to enable verbose learning.
    :return: The spn with learned parameters.
    :raises ValueError: If a parameter is out of domain.
    """
    if num_iter <= 0:
        raise ValueError("The number of iterations must be positive")
    if batch_perc <= 0.0 or batch_perc >= 1.0:
        raise ValueError("The batch percentage must be in (0, 1)")
    if step_size <= 0.0 or step_size >= 1.0:
        raise ValueError("The step size must be in (0, 1)")

    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    # Compute the batch size
    n_samples = len(data)
    batch_size = max(1, int(batch_perc * n_samples))  # Use at least one sample per batch

    # Compute a list-based cache for accessing nodes
    cached_nodes = {
        'sum': filter_nodes_by_type(root, Sum),
        'leaf': filter_nodes_by_type(root, Leaf)
    }

    # Check the random state
    random_state = check_random_state(random_state)

    # Randomly initialize the parameters of the SPN, if specified
    if random_init:
        # Initialize the sum parameters
        for node in cached_nodes['sum']:
            node.em_init(random_state)

        # Initialize the leaf parameters
        for node in cached_nodes['leaf']:
            node.em_init(random_state)

    # Initialize the tqdm bar, if verbose is specified
    iterator = range(num_iter)
    if verbose:
        iterator = tqdm(
            iterator, leave=None, unit='batch',
            bar_format='{desc}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
        )

    for _ in iterator:
        # Sample a batch of data randomly with uniform distribution
        batch_indices = random_state.choice(n_samples, size=batch_size, replace=False)
        batch_data = data[batch_indices]

        # Skip checking the SPN at every forward inference step, since it was already checked above
        with ContextState(check_spn=False):
            # Forward step, obtaining the LLs at each node
            root_ll, lls = log_likelihood(root, batch_data, return_results=True)
        mean_ll = np.mean(root_ll)

        # Backward step, compute the log-gradients required to compute the sufficient statistics
        grads = eval_backward(root, lls)

        # Update the weights of each sum node
        for node in cached_nodes['sum']:
            children_ll = lls[list(map(lambda c: c.id, node.children))]
            stats = np.exp(children_ll - root_ll + grads[node.id])
            node.em_step(stats, step_size)

        # Update the parameters of each leaf node
        for node in cached_nodes['leaf']:
            stats = np.exp(lls[node.id] - root_ll + grads[node.id])
            node.em_step(stats, batch_data[:, node.scope], step_size)

        # Update the progress bar
        if verbose:
            iterator.set_description('Batch Avg. LL: {:.4f}'.format(mean_ll))

    return root
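To make the sum-node update above concrete, here is a minimal NumPy sketch for the simplest case: a root sum node over children whose log-likelihoods are already known. It assumes `grads` holds log partial derivatives of the root value, so the root's own entry is 0 and the statistics reduce to `exp(children_ll - root_ll)`. The body of `Sum.em_step` is not shown in this notebook; the sketch assumes it multiplies the statistics by the current weights, normalizes, and interpolates with `step_size`. The function name `em_root_sum_step` is hypothetical.

```python
import numpy as np


def em_root_sum_step(weights, children_ll, step_size=0.5):
    """One batch-EM update of the weights of a root sum node (illustrative sketch).

    weights:     (k,) current mixture weights
    children_ll: (k, n) log-likelihood of each child on each of n samples
    """
    weights = np.asarray(weights, dtype=np.float64)
    # Root log-likelihood log sum_k w_k p_k(x), computed with the log-sum-exp trick
    weighted = np.log(weights)[:, None] + children_ll
    m = weighted.max(axis=0)
    root_ll = m + np.log(np.exp(weighted - m).sum(axis=0))
    # Same statistics as in expectation_maximization, with a zero log-gradient at the root
    stats = np.exp(children_ll - root_ll)
    # Assumed M-step: w_k * stats are per-sample responsibilities; sum them and normalize
    unnorm = weights * stats.sum(axis=1)
    new_weights = unnorm / unnorm.sum()
    # Assumed interpolation between the old and new weights, controlled by step_size
    return (1.0 - step_size) * weights + step_size * new_weights
```

With `step_size=1.0` this is a plain EM step (the new weights are the mean responsibilities); smaller values damp the update, which helps when each step only sees a random mini-batch.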

In [None]:
#@title SPN leaf methods

#: A signature for a learn SPN leaf function.
LearnLeafFunc = Callable[
    [np.ndarray,                # The data
     List[Type[Leaf]],          # The distributions
     List[Union[list, tuple]],  # The domains
     List[int],                 # The scope
     Any],                      # Other arguments
    Node                        # A SPN node
]


def get_learn_leaf_method(learn_leaf: str) -> LearnLeafFunc:
    """
    Get the learn leaf method.

    :param learn_leaf: The learn leaf method string to use.
    :return: A learn leaf function.
    :raises ValueError: If the leaf learning method is unknown.
    """
    if learn_leaf == 'mle':
        return learn_mle
    if learn_leaf == 'isotonic':
        return learn_isotonic
    if learn_leaf == 'binary-clt':
        return learn_binary_clt
    raise ValueError("Unknown learn leaf method called {}".format(learn_leaf))


def learn_mle(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    alpha: float = 0.1,
    random_state: Optional[RandomState] = None
) -> Node:
    """
    Learn a leaf using Maximum Likelihood Estimate (MLE).
    If the data is multivariate, a naive factorized model is learned.

    :param data: The data, where each column corresponds to a random variable.
    :param distributions: The distributions of the random variables.
    :param domains: The domains of the random variables.
    :param scope: The scope of the leaf.
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state. It can be None.
    :return: A leaf distribution.
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    """
    if len(scope) != len(distributions) or len(domains) != len(distributions):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")

    if len(scope) == 1:
        sc, dist, dom = scope[0], distributions[0], domains[0]
        leaf = dist(sc)
        leaf.fit(data, dom, alpha=alpha)
        return leaf

    return learn_naive_factorization(
        data, distributions, domains, scope, learn_mle,
        alpha=alpha, random_state=random_state
    )


def learn_isotonic(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    alpha: float = 0.1,
    random_state: Optional[RandomState] = None
) -> Node:
    """
    Learn a leaf using the Isotonic method (continuous variables only; other variables use their own distribution).
    If the data is multivariate, a naive factorized model is learned.

    :param data: The data.
    :param distributions: The distribution of the random variables.
    :param domains: The domain of the random variables.
    :param scope: The scope of the leaf.
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state. It can be None.
    :return: A leaf distribution.
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    """
    if len(scope) != len(distributions) or len(domains) != len(distributions):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")

    if len(scope) == 1:
        sc, dist, dom = scope[0], distributions[0], domains[0]
        leaf = Isotonic(sc) if dist.LEAF_TYPE == LeafType.CONTINUOUS else dist(sc)
        leaf.fit(data, dom, alpha=alpha)
        return leaf

    return learn_naive_factorization(
        data, distributions, domains, scope, learn_isotonic,
        alpha=alpha, random_state=random_state
    )


def learn_binary_clt(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    to_pc: bool = False,
    alpha: float = 0.1,
    random_state: Optional[RandomState] = None
) -> Node:
    """
    Learn a leaf using a Binary Chow-Liu Tree (CLT).
    If the data is univariate, a Maximum Likelihood Estimate (MLE) leaf is returned.

    :param data: The data.
    :param distributions: The distributions of the random variables.
    :param domains: The domains of the random variables.
    :param scope: The scope of the leaf.
    :param to_pc: Whether to convert the CLT into an equivalent PC.
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state. It can be None.
    :return: A leaf distribution.
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    :raises ValueError: If the data doesn't follow a Bernoulli distribution.
    """
    if len(scope) != len(distributions) or len(domains) != len(distributions):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")
    if any(d != Bernoulli for d in distributions):
        raise ValueError("Binary Chow-Liu trees are only available for Bernoulli data")

    # If univariate, learn using MLE instead
    if len(scope) == 1:
        return learn_mle(
            data, distributions, domains, scope,
            alpha=alpha, random_state=random_state
        )

    # If multivariate, learn a binary CLTree
    leaf = BinaryCLT(scope)
    leaf.fit(data, domains, alpha=alpha, random_state=random_state)

    # Make the conversion to a probabilistic circuit, if specified
    if to_pc:
        return leaf.to_pc()
    return leaf


def learn_naive_factorization(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    learn_leaf_func: LearnLeafFunc,
    **learn_leaf_kwargs
) -> Node:
    """
    Learn a leaf as a naive factorized model.

    :param data: The data.
    :param distributions: The distribution of the random variables.
    :param domains: The domain of the random variables.
    :param scope: The scope of the leaf.
    :param learn_leaf_func: The function to use to learn the sub-distributions parameters.
    :param learn_leaf_kwargs: Additional parameters for learn_leaf_func.
    :return: A naive factorized model.
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    """
    if len(scope) != len(distributions) or len(domains) != len(distributions):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")

    node = Product(scope)
    for i, s in enumerate(scope):
        leaf = learn_leaf_func(data[:, [i]], [distributions[i]], [domains[i]], [s], **learn_leaf_kwargs)
        leaf.id = i + 1  # Set the leaves ids sequentially
        node.children.append(leaf)
    return node
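`learn_naive_factorization` places one univariate leaf per feature under a `Product` node, so the model's log-likelihood is the sum of the per-feature log-likelihoods. Below is a NumPy-only sketch for binary features. It assumes that the `alpha` passed to `learn_mle` acts as Laplace pseudo-counts added to both outcomes; the actual Bernoulli `fit` implementation is not shown here, and both function names are hypothetical.

```python
import numpy as np


def fit_naive_bernoulli(data, alpha=0.1):
    """Laplace-smoothed estimate of P(X_j = 1) for every binary column (illustrative)."""
    data = np.asarray(data, dtype=np.float64)
    # alpha pseudo-counts are added to both outcomes of each variable
    return (data.sum(axis=0) + alpha) / (data.shape[0] + 2 * alpha)


def naive_log_likelihood(data, probs):
    """Per-sample log-likelihood of a product of independent Bernoulli leaves."""
    data = np.asarray(data, dtype=np.float64)
    per_feature = data * np.log(probs) + (1.0 - data) * np.log1p(-probs)
    # The product node multiplies densities, i.e. it sums log-likelihoods
    return per_feature.sum(axis=1)
```

Smoothing keeps every probability strictly inside (0, 1), so a value never seen in training does not produce a log-likelihood of minus infinity.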

In [None]:
#@title learnSPN method


class OperationKind(Enum):
    """
    Operation kind used by LearnSPN algorithm.
    """
    REM_FEATURES = 1
    CREATE_LEAF = 2
    SPLIT_NAIVE = 3
    SPLIT_ROWS = 4
    SPLIT_COLS = 5


class Task(NamedTuple):
    """
    Recursive task information used by LearnSPN algorithm.
    """
    parent: Node
    data: np.ndarray
    scope: List[int]
    no_cols_split: bool = False
    no_rows_split: bool = False
    is_first: bool = False


def learn_spn(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    learn_leaf: Union[str, LearnLeafFunc] = 'mle',
    split_rows: Union[str, SplitRowsFunc] = 'kmeans',
    split_cols: Union[str, SplitColsFunc] = 'rdc',
    learn_leaf_kwargs: dict = None,
    split_rows_kwargs: dict = None,
    split_cols_kwargs: dict = None,
    min_rows_slice: int = 256,
    min_cols_slice: int = 2,
    random_state: Optional[RandomState] = None,
    verbose: bool = True
) -> Node:
    """
    Learn the structure and parameters of a SPN given some training data and several hyperparameters.

    :param data: The training data.
    :param distributions: A list of distribution classes (one for each feature).
    :param domains: A list of domains (one for each feature). Each domain is either a list of values, for discrete
                    distributions, or a tuple (consisting of min value and max value), for continuous distributions.
    :param learn_leaf: The method to use to learn a distribution leaf node.
                       It can be either 'mle', 'isotonic', 'binary-clt' or a custom LearnLeafFunc.
    :param split_rows: The rows splitting method.
                       It can be either 'kmeans', 'gmm', 'rdc', 'random' or a custom SplitRowsFunc function.
    :param split_cols: The columns splitting method.
                       It can be either 'rdc', 'random' or a custom SplitColsFunc function. The 'gvs', 'rgvs',
                       'wrgvs', 'ebvs', 'ebvs_ae', 'gbvs' and 'gbvs_ag' methods are only available after
                       enabling them in get_split_cols_method.
    :param learn_leaf_kwargs: The parameters of the learn leaf method.
    :param split_rows_kwargs: The parameters of the rows splitting method.
    :param split_cols_kwargs: The parameters of the cols splitting method.
    :param min_rows_slice: The minimum number of samples required to split horizontally.
    :param min_cols_slice: The minimum number of features required to split vertically.
    :param random_state: The random state. It can be either None, a seed integer or a Numpy RandomState.
    :param verbose: Whether to enable verbose mode.
    :return: A learned valid SPN.
    :raises ValueError: If a parameter is out of domain.
    """
    if len(distributions) == 0:
        raise ValueError("The list of distribution classes must be non-empty")
    if len(domains) == 0:
        raise ValueError("The list of domains must be non-empty")
    if min_rows_slice <= 0:
        raise ValueError("The minimum number of samples required to split horizontally must be positive")
    if min_cols_slice <= 0:
        raise ValueError("The minimum number of features required to split vertically must be positive")

    n_samples, n_features = data.shape
    if len(distributions) != n_features or len(domains) != n_features:
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")

    # Setup the learn leaf, split rows and split cols functions
    learn_leaf_func = get_learn_leaf_method(learn_leaf) if isinstance(learn_leaf, str) else learn_leaf
    split_rows_func = get_split_rows_method(split_rows) if isinstance(split_rows, str) else split_rows
    split_cols_func = get_split_cols_method(split_cols) if isinstance(split_cols, str) else split_cols

    if learn_leaf_kwargs is None:
        learn_leaf_kwargs = dict()
    if split_rows_kwargs is None:
        split_rows_kwargs = dict()
    if split_cols_kwargs is None:
        split_cols_kwargs = dict()

    # Setup the initial scope as [0, # of features - 1]
    initial_scope = list(range(n_features))

    # Check the random state
    random_state = check_random_state(random_state)

    # Add the random state to learning leaf parameters
    learn_leaf_kwargs['random_state'] = random_state

    # Initialize the progress bar (with unspecified total), if verbose is enabled
    if verbose:
        tk = tqdm(
            total=np.inf, leave=None, unit='node',
            bar_format='{n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}]'
        )

    tasks = deque()
    tmp_node = Product(initial_scope)
    tasks.append(Task(tmp_node, data, initial_scope, is_first=True))

    while tasks:
        # Get the next task
        task = tasks.popleft()

        # Select the operation to apply
        n_samples, n_features = task.data.shape
        # Get the indices of uninformative features
        zero_var_idx = np.isclose(np.var(task.data, axis=0), 0.0)
        # If all the features are uninformative, then split using Naive Bayes model
        if np.all(zero_var_idx):
            op = OperationKind.SPLIT_NAIVE
        # If only some of the features are uninformative, then remove them
        elif np.any(zero_var_idx):
            op = OperationKind.REM_FEATURES
        # Create a leaf node if the data split dimension is small or last rows splitting failed
        elif task.no_rows_split or n_features < min_cols_slice or n_samples < min_rows_slice:
            op = OperationKind.CREATE_LEAF
        # Use rows splitting if previous columns splitting failed or it is the first task
        elif task.no_cols_split or task.is_first:
            op = OperationKind.SPLIT_ROWS
        # Defaults to columns splitting
        else:
            op = OperationKind.SPLIT_COLS

        if op == OperationKind.REM_FEATURES:
            node = Product(task.scope)

            # Model the removed features using Naive Bayes
            rem_scope = [task.scope[i] for i, in np.argwhere(zero_var_idx)]
            dists, doms = [distributions[s] for s in rem_scope], [domains[s] for s in rem_scope]
            naive = learn_naive_factorization(
                task.data[:, zero_var_idx], dists, doms, rem_scope,
                learn_leaf_func=learn_leaf_func, **learn_leaf_kwargs
            )
            node.children.append(naive)

            # Add the tasks regarding non-removed features
            is_first = task.is_first and len(tasks) == 0
            oth_scope = [task.scope[i] for i, in np.argwhere(~zero_var_idx)]
            tasks.append(Task(node, task.data[:, ~zero_var_idx], oth_scope, is_first=is_first))
            task.parent.children.append(node)
        elif op == OperationKind.CREATE_LEAF:
            # Create a leaf node
            dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope]
            leaf = learn_leaf_func(task.data, dists, doms, task.scope, **learn_leaf_kwargs)
            task.parent.children.append(leaf)
        elif op == OperationKind.SPLIT_NAIVE:
            # Split the data using a naive factorization
            dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope]
            node = learn_naive_factorization(
                task.data, dists, doms, task.scope,
                learn_leaf_func=learn_leaf_func, **learn_leaf_kwargs
            )
            task.parent.children.append(node)
        elif op == OperationKind.SPLIT_ROWS:
            # Split the data by rows (sum node)
            dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope]
            clusters = split_rows_func(task.data, dists, doms, random_state, **split_rows_kwargs)
            slices, weights = split_rows_clusters(task.data, clusters)

            # Check whether only one partitioning is returned
            if len(slices) == 1:
                tasks.append(Task(task.parent, task.data, task.scope, no_cols_split=False, no_rows_split=True))
                continue

            # Add sub-tasks and append Sum node
            node = Sum(task.scope, weights=weights)
            for local_data in slices:
                tasks.append(Task(node, local_data, task.scope))
            task.parent.children.append(node)
        elif op == OperationKind.SPLIT_COLS:
            # Split the data by columns (product node)
            dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope]
            clusters = split_cols_func(task.data, dists, doms, random_state, **split_cols_kwargs)
            slices, scopes = split_cols_clusters(task.data, clusters, task.scope)

            # Check whether only one partitioning is returned
            if len(slices) == 1:
                tasks.append(Task(task.parent, task.data, task.scope, no_cols_split=True, no_rows_split=False))
                continue

            # Add sub-tasks and append Product node
            node = Product(task.scope)
            for i, local_data in enumerate(slices):
                tasks.append(Task(node, local_data, scopes[i]))
            task.parent.children.append(node)
        else:
            raise NotImplementedError("Operation of kind {} not implemented".format(op))

        if verbose:
            tk.update()
            tk.refresh()

    if verbose:
        tk.close()

    root = tmp_node.children[0]
    return assign_ids(root)
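The order of the checks at the top of the `while tasks` loop decides the shape of the learned SPN. The standalone function below restates that decision logic on a plain NumPy array, returning the `OperationKind` member names as strings, so it can be tried on toy inputs without building any nodes. It is an illustrative restatement; `select_operation` is a hypothetical name.

```python
import numpy as np


def select_operation(data, min_rows_slice=256, min_cols_slice=2,
                     no_rows_split=False, no_cols_split=False, is_first=False):
    """Mirror the order of checks used by learn_spn to pick the next operation."""
    n_samples, n_features = data.shape
    zero_var = np.isclose(np.var(data, axis=0), 0.0)
    if np.all(zero_var):
        return 'SPLIT_NAIVE'   # every feature is constant: fully factorize
    if np.any(zero_var):
        return 'REM_FEATURES'  # factor out only the constant features
    if no_rows_split or n_features < min_cols_slice or n_samples < min_rows_slice:
        return 'CREATE_LEAF'   # slice too small, or the rows could not be clustered
    if no_cols_split or is_first:
        return 'SPLIT_ROWS'    # start with (or fall back to) a sum node
    return 'SPLIT_COLS'        # default: try a product node
```

Because a failed split re-enqueues the same task with `no_rows_split` or `no_cols_split` set, the alternation is bounded: a failed column split forces a row split, and a failed row split forces a leaf.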

In [None]:
#@title learn binary cutset network CNet

def compute_or_bd_scores(
    data: np.ndarray,
    ess: float = 0.1
):
    """
    Compute the BDeu scores for the candidate OR nodes given the data.

    :param data: The binary data matrix.
    :param ess: The equivalent sample size (ESS).
    :return: The score array.
    """
    n_samples, n_features = data.shape
    prior_counts = compute_prior_counts(data=data)
    alpha_i = ess
    alpha_ik = ess / 2
    log_gamma_nodes = gammaln(alpha_i) - gammaln(n_samples + alpha_i) \
                      + np.sum(gammaln(prior_counts + alpha_ik) - gammaln(alpha_ik), axis=-1)
    return log_gamma_nodes


def compute_clt_bd_scores(
    data: np.ndarray,
    ess: float = 0.1
):
    """
    Compute the pairwise BDeu scores for constructing a CLT given the data.

    :param data: The binary data matrix.
    :param ess: The equivalent sample size (ESS).
    :return: The pairwise BDeu score matrix.
    """
    joint_counts = compute_joint_counts(data=data)
    alpha_ij = ess / 2
    alpha_ijk = ess / (2 * 2)
    parent_counts = np.sum(joint_counts, axis=-2)
    log_gamma_pairs = gammaln(alpha_ij) - gammaln(parent_counts + alpha_ij) \
                      + np.sum(gammaln(joint_counts + alpha_ijk) - gammaln(alpha_ijk), axis=-2)
    return np.sum(log_gamma_pairs, axis=-1)


def estimate_clt_params_bayesian(
    clt: BinaryCLT,
    data: np.ndarray,
    ess: float = 0.1
):
    """
    Compute the Bayesian posterior parameters for a CLT.

    :param clt: The CLT.
    :param data: The binary data matrix.
    :param ess: The equivalent sample size (ESS).
    :return: The CLT parameters in the log space.
    """
    n_samples, n_features = data.shape
    priors, joints = estimate_priors_joints(data, alpha=ess / 4)

    vs = np.arange(n_features)
    params = np.einsum('ikl,il->ilk', joints[vs, clt.tree], np.reciprocal(priors[clt.tree]))
    params[clt.root] = priors[clt.root]

    # Re-normalize the factors, because there can be FP32 approximation errors
    params /= np.sum(params, axis=2, keepdims=True)
    return np.log(params)


def eval_tree_score(
    tree: Union[List[int], np.ndarray],
    clt_scores: np.ndarray,
    or_scores: np.ndarray
):
    """
    Evaluate the BDeu score for a tree structure.

    :param tree: The tree structure, as an array of parent indices (the root's parent is -1).
    :param clt_scores: The pairwise score matrix.
    :param or_scores: The OR score array.
    :return: The BDeu score of the tree structure.
    """
    tree = np.asarray(tree)  # Python lists do not support argmin
    root_idx = tree.argmin()
    parent_indices_no_root = np.delete(tree, obj=root_idx)
    child_indices_no_root = np.delete(np.arange(len(tree)), obj=root_idx)
    return np.sum(clt_scores[child_indices_no_root, parent_indices_no_root]) + or_scores[root_idx]


def select_cand_cuts(
    data: np.ndarray,
    ess: float = 0.1,
    n_cand_cuts: int = 10
):
    """
    Select the candidate cutting nodes.

    :param data: The binary data.
    :param ess: The equivalent sample size (ESS).
    :param n_cand_cuts: The number of candidate cutting nodes.
    :return: The indices of the selected nodes.
    """
    # Compute the counts
    n_samples, n_features = data.shape
    counts_features = data.sum(axis=0)

    prior_counts = compute_prior_counts(data)
    joint_counts = compute_joint_counts(data)
    smoothing_joint, smoothing_prior = ess / 2, ess
    if ess < 0.01:
        prior_counts = prior_counts.astype(np.float64)
        joint_counts = joint_counts.astype(np.float64)
    log_priors = np.log(prior_counts + smoothing_joint) - np.log(n_samples + smoothing_prior)
    mean_entropy = -(log_priors * np.exp(log_priors)).sum() / n_features

    conditionals = np.empty((n_features, n_features, 2, 2), dtype=prior_counts.dtype)
    conditionals[:, :, 0, 0] = ((joint_counts[:, :, 0, 0] + smoothing_joint).T /
                                (prior_counts[:, 0] + smoothing_prior)).T
    conditionals[:, :, 0, 1] = ((joint_counts[:, :, 0, 1] + smoothing_joint).T /
                                (prior_counts[:, 0] + smoothing_prior)).T
    conditionals[:, :, 1, 0] = ((joint_counts[:, :, 1, 0] + smoothing_joint).T /
                                (prior_counts[:, 1] + smoothing_prior)).T
    conditionals[:, :, 1, 1] = ((joint_counts[:, :, 1, 1] + smoothing_joint).T /
                                (prior_counts[:, 1] + smoothing_prior)).T

    vs = np.repeat(np.arange(n_features)[None, :], n_features, axis=0)
    vs = vs[~np.eye(vs.shape[0], dtype=bool)].reshape(vs.shape[0], -1)
    parents = np.repeat(np.arange(n_features)[:, None], n_features - 1, axis=1)

    ratio_features = counts_features / n_samples
    entropies = ratio_features * \
                np.mean(-np.sum(conditionals[parents, vs, 1, :] * np.log(conditionals[parents, vs, 1, :]), axis=-1),
                        axis=1) + \
                (1 - ratio_features) * \
                np.mean(-np.sum(conditionals[parents, vs, 0, :] * np.log(conditionals[parents, vs, 0, :]), axis=-1),
                        axis=1)

    info_gains = mean_entropy - entropies
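    if n_cand_cuts == 1:
        # Wrap the single index in an array, so that callers can always iterate over the result
        return np.array([np.argmax(info_gains)])
    # Indices of the n_cand_cuts largest information gains (in no particular order)
    return np.argpartition(info_gains, -n_cand_cuts)[-n_cand_cuts:]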


def learn_cnet_bd(
    data: np.ndarray,
    ess: float = 0.1,
    n_cand_cuts: int = 10,
):
    """
    Learn a binary CNet using the Bayesian-Dirichlet equivalent uniform (BDeu) score.

    :param data: The training data.
    :param ess: The equivalent sample size (ESS).
    :param n_cand_cuts: The number of candidate cutting nodes.
    :return: A binary CNet.
    """
    n_samples, n_features = data.shape
    root = BinaryCNet(scope=list(range(n_features)))
    root.assign_indices(row_indices=np.arange(n_samples), col_indices=np.arange(n_features))
    root.fit_clt(data=data)
    # use Bayesian posterior parameters.
    root.clt.params = estimate_clt_params_bayesian(clt=root.clt, data=data, ess=ess)
    or_score_matrix = compute_or_bd_scores(data=data, ess=ess)
    clt_score_matrix = compute_clt_bd_scores(data=data, ess=ess)
    clt_score = eval_tree_score(tree=root.clt.tree, clt_scores=clt_score_matrix, or_scores=or_score_matrix)

    node_stack = [[root, ess, clt_score]]
    while node_stack:
        node, node_ess, node_clt_score = node_stack.pop(0)
        if len(node.scope) == 1:
            continue

        partition = data[node.row_indices][:, node.col_indices]
        or_score_matrix = compute_or_bd_scores(data=partition, ess=node_ess)

        k = min(n_cand_cuts, len(node.scope))
        search_indices = select_cand_cuts(data=partition, ess=node_ess, n_cand_cuts=k)

        best_or_idx = -1
        best_cnet_score = -np.inf
        best_left_clt = None
        best_right_clt = None
        best_left_clt_score = -np.inf
        best_right_clt_score = -np.inf

        for i in search_indices:
            left_row_indices = node.row_indices[partition[:, i] == 0]
            right_row_indices = node.row_indices[partition[:, i] == 1]

            if len(left_row_indices) == 0 or len(right_row_indices) == 0:
                continue

            child_col_indices = np.delete(node.col_indices, obj=i)
            left_partition = data[left_row_indices][:, child_col_indices]
            right_partition = data[right_row_indices][:, child_col_indices]
            new_scope = node.scope.copy()
            del new_scope[i]

            left_clt = BinaryCLT(scope=new_scope)
            left_or_score_matrix = compute_or_bd_scores(data=left_partition, ess=node_ess / 2)
            left_clt_score_matrix = compute_clt_bd_scores(data=left_partition, ess=node_ess / 2)

            right_clt = BinaryCLT(scope=new_scope)
            right_or_score_matrix = compute_or_bd_scores(data=right_partition, ess=node_ess / 2)
            right_clt_score_matrix = compute_clt_bd_scores(data=right_partition, ess=node_ess / 2)

            left_clt.fit(data=left_partition, domain=[[0, 1]] * len(new_scope), alpha=0.01)
            right_clt.fit(data=right_partition, domain=[[0, 1]] * len(new_scope), alpha=0.01)

            left_clt.params = estimate_clt_params_bayesian(left_clt, data=left_partition, ess=node_ess / 2)
            right_clt.params = estimate_clt_params_bayesian(right_clt, data=right_partition, ess=node_ess / 2)

            left_clt_score = eval_tree_score(tree=left_clt.tree,
                                             clt_scores=left_clt_score_matrix,
                                             or_scores=left_or_score_matrix)
            right_clt_score = eval_tree_score(tree=right_clt.tree,
                                              clt_scores=right_clt_score_matrix,
                                              or_scores=right_or_score_matrix)
            cnet_score = left_clt_score + right_clt_score + or_score_matrix[i]

            if cnet_score > best_cnet_score:
                best_cnet_score = cnet_score
                best_or_idx = i
                best_left_clt = left_clt
                best_right_clt = right_clt
                best_left_clt_score = left_clt_score
                best_right_clt_score = right_clt_score

        if best_cnet_score > node_clt_score:
            node.or_id = node.scope[best_or_idx]
            node.clt = None
            left_row_indices = node.row_indices[partition[:, best_or_idx] == 0]
            right_row_indices = node.row_indices[partition[:, best_or_idx] == 1]
            child_col_indices = np.delete(node.col_indices, obj=best_or_idx)
            left_weight = (len(left_row_indices) + node_ess / 2) / (len(node.row_indices) + node_ess)
            right_weight = 1 - left_weight
            new_scope = node.scope.copy()
            del new_scope[best_or_idx]

            left_child = BinaryCNet(scope=new_scope)
            left_child.clt = best_left_clt
            left_child.assign_indices(row_indices=left_row_indices, col_indices=child_col_indices)
            right_child = BinaryCNet(scope=new_scope)
            right_child.clt = best_right_clt
            right_child.assign_indices(row_indices=right_row_indices, col_indices=child_col_indices)

            node_stack.append([left_child, node_ess / 2, best_left_clt_score])
            node_stack.append([right_child, node_ess / 2, best_right_clt_score])
            node.weights = [left_weight, right_weight]
            node.children = [left_child, right_child]
    return root


def learn_cnet_bic(
    data: np.ndarray,
    alpha: float = 0.01,
    n_cand_cuts: int = 10,
):
    """
    Learn a binary CNet using the Bayesian Information Criterion (BIC) score.

    :param data: The binary data.
    :param alpha: The Laplace smoothing factor.
    :param n_cand_cuts: The number of candidate cutting nodes.
    :return: A binary CNet.
    """
    n_samples, n_features = data.shape
    root = BinaryCNet(scope=list(range(n_features)))
    root.assign_indices(row_indices=np.arange(n_samples), col_indices=np.arange(n_features))
    root.fit_clt(data=data)
    clt_score = np.sum(root.clt.log_likelihood(data)) - 0.5 * np.log(n_samples) * (2 * n_features - 1)

    node_stack = [[root, clt_score]]
    while node_stack:
        node, node_clt_score = node_stack.pop(0)
        if len(node.scope) == 1:
            continue

        partition = data[node.row_indices][:, node.col_indices]

        k = min(n_cand_cuts, len(node.scope))
        search_indices = select_cand_cuts(partition, ess=4 * alpha, n_cand_cuts=k)

        best_or_idx = -1
        best_cnet_score = -np.inf
        best_left_clt = None
        best_right_clt = None
        best_left_clt_score = 0.0
        best_right_clt_score = 0.0
        for i in search_indices:
            left_row_indices = node.row_indices[partition[:, i] == 0]
            right_row_indices = node.row_indices[partition[:, i] == 1]

            if len(left_row_indices) == 0 or len(right_row_indices) == 0:
                continue

            child_col_indices = np.delete(node.col_indices, obj=i)
            left_partition = data[left_row_indices][:, child_col_indices]
            right_partition = data[right_row_indices][:, child_col_indices]
            new_scope = node.scope.copy()
            del new_scope[i]

            left_weight = (len(left_row_indices) + alpha) / (len(node.row_indices) + 2 * alpha)
            right_weight = 1 - left_weight

            left_clt = BinaryCLT(scope=new_scope)
            right_clt = BinaryCLT(scope=new_scope)

            left_clt.fit(data=left_partition, domain=[[0, 1]] * len(new_scope), alpha=alpha)
            right_clt.fit(data=right_partition, domain=[[0, 1]] * len(new_scope), alpha=alpha)

            left_clt_score = np.sum(left_clt.log_likelihood(left_partition)) \
                             - 0.5 * np.log(len(data)) * (2 * len(new_scope) - 1)
            right_clt_score = np.sum(right_clt.log_likelihood(right_partition)) \
                              - 0.5 * np.log(len(data)) * (2 * len(new_scope) - 1)
            or_score = len(left_partition) * np.log(left_weight) + len(right_partition) * np.log(right_weight) \
                       - 0.5 * np.log(len(data))
            cnet_score = left_clt_score + right_clt_score + or_score

            if cnet_score > best_cnet_score:
                best_cnet_score = cnet_score
                best_or_idx = i
                best_left_clt = left_clt
                best_right_clt = right_clt
                best_left_clt_score = left_clt_score
                best_right_clt_score = right_clt_score

        if best_cnet_score > node_clt_score:
            node.or_id = node.scope[best_or_idx]
            node.clt = None
            left_row_indices = node.row_indices[partition[:, best_or_idx] == 0]
            right_row_indices = node.row_indices[partition[:, best_or_idx] == 1]
            child_col_indices = np.delete(node.col_indices, obj=best_or_idx)

            left_weight = (len(left_row_indices) + alpha) / (len(node.row_indices) + 2 * alpha)
            right_weight = 1 - left_weight
            new_scope = node.scope.copy()
            del new_scope[best_or_idx]

            left_child = BinaryCNet(scope=new_scope)
            left_child.clt = best_left_clt
            left_child.assign_indices(row_indices=left_row_indices, col_indices=child_col_indices)
            right_child = BinaryCNet(scope=new_scope)
            right_child.clt = best_right_clt
            right_child.assign_indices(row_indices=right_row_indices, col_indices=child_col_indices)

            node_stack.append([left_child, best_left_clt_score])
            node_stack.append([right_child, best_right_clt_score])
            node.weights = [left_weight, right_weight]
            node.children = [left_child, right_child]
    return root
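The loop above accepts an OR split only when its BIC-style score beats the score of the single CLT at the node. That score is the two child CLT scores plus an OR-node term. Below is a minimal sketch of the OR-node term on its own. It assumes a binary conditioning column and the same Laplace-smoothed weights. `or_split_score` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def or_split_score(column: np.ndarray, n_total: int, alpha: float = 0.01) -> float:
    """BIC-style score of an OR node that splits rows on a binary column.

    Mirrors the `or_score` term above: the log-likelihood of the smoothed
    mixing weights minus a 0.5 * log(N) penalty for the single free weight.
    """
    n_left = int(np.sum(column == 0))
    n_right = int(np.sum(column == 1))
    left_weight = (n_left + alpha) / (n_left + n_right + 2 * alpha)
    right_weight = 1.0 - left_weight
    return n_left * np.log(left_weight) + n_right * np.log(right_weight) - 0.5 * np.log(n_total)


col = np.array([0, 0, 0, 1])
print(round(or_split_score(col, n_total=4, alpha=0.0), 4))  # -2.9425
```

With alpha = 0 the weight part equals -N times the entropy of the split proportions, so it is lowest for a balanced split. The split wins only if the two child CLTs gain enough likelihood to compensate.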

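The `BinaryCLT` leaves fitted above follow the Chow-Liu recipe, and so does the `maximum_spanning_tree` call in `build_trees_dict`. The recipe is to estimate pairwise mutual information and then keep a maximum spanning tree. Below is a minimal numpy sketch of that recipe for binary data. `smoothed_mutual_information` and `chow_liu_predecessors` are hypothetical helpers. The additive smoothing of the 2x2 count tables is an assumption, and the notebook's `estimate_priors_joints` may smooth differently. The predecessor list with -1 at the root only mirrors the `ROOT = -1` convention used in `build_trees_dict`.

```python
import numpy as np


def smoothed_mutual_information(data: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    """Pairwise mutual information of binary columns, with additive smoothing of the counts."""
    _, d = data.shape
    mi = np.zeros((d, d))
    for i in range(d):
        for j in range(i + 1, d):
            # 2x2 table of joint counts, smoothed and normalized to a distribution
            joint = np.full((2, 2), alpha)
            for a in (0, 1):
                for b in (0, 1):
                    joint[a, b] += np.sum((data[:, i] == a) & (data[:, j] == b))
            joint /= joint.sum()
            pi = joint.sum(axis=1)  # marginal of column i
            pj = joint.sum(axis=0)  # marginal of column j
            mi[i, j] = mi[j, i] = np.sum(joint * np.log(joint / np.outer(pi, pj)))
    return mi


def chow_liu_predecessors(mi: np.ndarray, root: int = 0) -> list:
    """Prim's algorithm for a maximum spanning tree, returned as a predecessor list."""
    d = mi.shape[0]
    preds = [-1] * d
    in_tree = {root}
    while len(in_tree) < d:
        # Pick the heaviest edge between the tree and a variable outside it
        u, v = max(((u, v) for u in in_tree for v in range(d) if v not in in_tree),
                   key=lambda e: mi[e[0], e[1]])
        preds[v] = u
        in_tree.add(v)
    return preds


rng = np.random.default_rng(0)
x0 = rng.integers(0, 2, size=500)
toy_data = np.stack([x0, x0.copy(), rng.integers(0, 2, size=500)], axis=1)  # column 1 copies column 0
print(chow_liu_predecessors(smoothed_mutual_information(toy_data), root=0))  # column 1 hangs off column 0
```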

In [None]:
#@title learning eXtremely randomized Probabilistic Circuit (XPC) or its ensemble


# SD stands for Structured Decomposable
SD_LEVEL_0 = 0  # non-SD ensemble of non-SD PCs
SD_LEVEL_1 = 1  # non-SD ensemble of SD PCs
SD_LEVEL_2 = 2  # SD ensemble
SD_LEVELS = [SD_LEVEL_0, SD_LEVEL_1, SD_LEVEL_2]

ROOT = -1


def build_disjunction(
    data: np.ndarray,
    scope: list,
    assignments: Optional[np.ndarray] = None,
    alpha: float = 0.01
) -> Node:
    """
    Build a disjunction (sum node) of conjunctions (product nodes).
    If assignments are given, every conjunction is associated to a specific assignment (the number of conjunctions
    is the same as the given assignments); otherwise, every conjunction will be associated to a specific
    assignment occurring in the input data (the number of conjunctions is the same as the unique assignments
    occurring in the data).

    :param data: The input data matrix.
    :param scope: The scope.
    :param assignments: The optional assignments.
    :param alpha: Laplace smoothing factor.
    """
    unq_data, counts = np.unique(data, axis=0, return_counts=True)
    assignments = unq_data if assignments is None else assignments
    assert unq_data.shape[0] <= assignments.shape[0]

    weights = np.zeros(assignments.shape[0])
    for i in range(assignments.shape[0]):
        index = np.where(np.all(assignments[i] == unq_data, axis=1))[0]
        if len(index):
            weights[i] = counts[index[0]]
    weights = (weights + alpha) / (weights + alpha).sum()

    prod_nodes = []
    for i in range(assignments.shape[0]):
        children = []
        for j in range(assignments.shape[1]):
            children.append(Bernoulli(scope=[scope[j]], p=assignments[i, j]))
        prod_nodes.append(Product(children=children))

    disjunction = Sum(children=prod_nodes, weights=weights) if len(prod_nodes) > 1 else prod_nodes[0]
    return assign_ids(disjunction)


def build_leaf(
    data: np.ndarray,
    part: Partition,
    use_clt: bool,
    trees_dict: dict,
    det: bool,
    alpha: float
) -> Node:
    """
    Build a multivariate leaf distribution for an XPC.

    :param data: The input data matrix.
    :param part: The partition associated to the leaf to build.
    :param use_clt: True if it is possible to use CLTrees as leaf nodes, False otherwise.
    :param trees_dict: A dictionary of trees (see the function build_trees_dict).
    :param det: True to force determinism, False otherwise.
    :param alpha: Laplace smoothing factor.
    """
    data_slice = part.get_slice(data)
    scope = part.col_ids.tolist()

    if part.is_conj:
        leaf = Product(children=[Bernoulli(scope=[scope[k]], p=float(data_slice[0][k])) for k in range(len(scope))])

    elif part.is_naive or not use_clt:
        if not det or part.disc_assignments.shape[0] == 2 ** part.disc_assignments.shape[1]:
            leaf = learn_mle(data_slice, [Bernoulli] * len(scope), [[0, 1]] * len(scope), scope, alpha)
        else:
            leaf = build_disjunction(data=data_slice, scope=scope, assignments=part.disc_assignments, alpha=alpha)

    else:
        if trees_dict is not None:
            tree, scope = trees_dict[len(scope)]
            data_slice = data[part.row_ids][:, scope]
        else:
            tree, scope = None, part.col_ids.tolist()
        leaf = BinaryCLT(tree=tree, scope=scope)
        leaf.fit(data_slice, domain=[[0, 1]] * len(scope), alpha=alpha)

    return leaf


def greedy_vars_ordering(
    data: np.ndarray,
    conj_len: int,
    alpha: float = 0.01
) -> list:
    """
    Return the ordering of the random variables according to the implemented heuristic.

    :param data: The input data matrix.
    :param conj_len: The conjunction length.
    :param alpha: Laplace smoothing factor.

    :return ordering: The ordering.
    """
    priors, joints = estimate_priors_joints(data, alpha)
    mut_info = compute_mutual_information(priors, joints)
    sums = np.sum(mut_info, axis=0)

    ordering = []
    free_vars = np.arange(data.shape[1]).tolist()
    while free_vars:
        peek_var = free_vars[np.argmax(sums[free_vars])]
        free_vars.remove(peek_var)
        ordering.append(peek_var)
        if len(free_vars) > conj_len - 1:
            idx = np.argpartition(-mut_info[peek_var][free_vars], conj_len - 1)[:conj_len - 1]
            vars_ = np.array(free_vars)[idx].tolist()
        else:
            vars_ = free_vars.copy()
        free_vars = list(set(free_vars) - set(vars_))
        ordering.extend(vars_)
    return ordering


def build_trees_dict(
    data: np.ndarray,
    cl_parts_l: list,
    conj_vars_l: list,
    alpha: float,
    random_state: np.random.RandomState
) -> dict:
    """
    Return a dictionary where:
     - a key refers to a scope length
     - a value is a list of two lists: the first is a list of predecessors, the second its scope.

    :param data: The input data matrix.
    :param cl_parts_l: List of lists. Every sublist is associated to a specific XPC and contains
     the leaf partitions over which a CLTree will be learnt.
    :param conj_vars_l: List of lists. Every sublist contains the variables of a conjunction (e.g. [[3, 5]]).
     If a sublist occurs before another, then the former has been used first. There are no duplicates.
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state.

    :return tree_dict: The dictionary.
    """
    # Compute the mutual information for each slice associated to every partition in cl_parts_l
    # and add it to a cumulative matrix (cumulative_info).
    n_vars = data.shape[1]
    cumulative_info = np.zeros((n_vars, n_vars))
    for cl_parts in cl_parts_l:
        for part in cl_parts:
            priors, joints = estimate_priors_joints(part.get_slice(data), alpha)
            mi = compute_mutual_information(priors, joints)
            cumulative_info[part.col_ids[:, None], part.col_ids] += mi

    # Free_vars are the variables not involved in any conjunction and will appear at the bottom of the circuit
    free_vars = list(set(np.arange(n_vars)) - set([var for conj_vars in conj_vars_l for var in conj_vars]))

    # Create a tree for each scope in scopes
    scopes = conj_vars_l + [free_vars] if free_vars else conj_vars_l
    trees = []
    for scope in scopes:
        _, tree = maximum_spanning_tree(
            adj_matrix=cumulative_info[scope][:, scope],
            root=scope.index(random_state.choice(scope))
        )
        trees.append(list(tree))

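    # Concatenate the trees bottom-up, starting from the last scope (the free variables, if any), and fill the dictionary.
    # At each step the root of the accumulated tree becomes a child of the root of the next tree, so every entry
    # tree_dict[len(scope)] covers the free variables plus the conjunctions used after a given point.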
    tree_dict = dict()
    tree = trees[-1].copy()
    scope = scopes[-1].copy()
    for k in reversed(range(0, len(trees) - 1)):
        tree_dict[len(scope)] = [tree.copy(), scope.copy()]
        tree += [t + len(scope) if t != ROOT else t for t in trees[k]]
        tree[tree.index(ROOT)] = tree.index(ROOT, len(scope))
        scope += scopes[k]

    return tree_dict


def build_xpc(
    data: np.ndarray,
    part_root: Partition,
    trees_dict: dict,
    det: bool,
    use_clt: bool,
    alpha: float
) -> Node:
    """
    Build the XPC induced by the partitions tree in a bottom up way.
    The building process is based on the post-order traversal exploration of the partitions tree.

    :param data: The input data matrix.
    :param part_root: The root partition of the tree.
    :param trees_dict: None if no dependency tree has to be respected, a dictionary of trees otherwise.
    :param det: True to force determinism, False otherwise.
    :param use_clt: True to use CLTrees as leaf nodes, False otherwise.
    :param alpha: Laplace smoothing factor.

    :return: the XPC induced by the partition tree
    """
    partitions_stack = [part_root]
    pc_nodes_stack = []
    last_part_visited = None

    while partitions_stack:
        part = partitions_stack[-1]
        if not part.is_partitioned() or (last_part_visited in part.sub_partitions):
            if part.is_partitioned():
                pc_child_nodes = pc_nodes_stack[-len(part.sub_partitions):]
                pc_nodes_stack = pc_nodes_stack[:len(pc_nodes_stack) - len(part.sub_partitions)]
                if part.is_horizontally_partitioned():
                    # Create sum node
                    weights = [len(sub_part.row_ids) / len(part.row_ids) for sub_part in part.sub_partitions]
                    pc_nodes_stack.append(Sum(weights=weights, children=pc_child_nodes))
                else:
                    # Create product node
                    pc_child_nodes_ = []
                    for c in pc_child_nodes:
                        if isinstance(c, Product) or (isinstance(c, Sum) and len(c.children) == 1):
                            pc_child_nodes_.extend(c.children)
                        else:
                            pc_child_nodes_.append(c)
                    pc_prod_node = Product(children=pc_child_nodes_)
                    pc_nodes_stack.append(pc_prod_node)
            else:
                # Create leaf (it could be either a PC or a multivariate leaf)
                leaf = build_leaf(data, part, use_clt, trees_dict, det, alpha)
                pc_nodes_stack.append(leaf)
            last_part_visited = partitions_stack.pop()
        else:
            partitions_stack.extend(part.sub_partitions[::-1])

    xpc = pc_nodes_stack[0]
    assign_ids(xpc)
    return xpc


def learn_xpc(
    data: np.ndarray,
    det: bool,
    sd: bool,
    min_part_inst: int,
    conj_len: int,
    arity: int,
    n_max_parts: int = 200,
    use_clt: bool = True,
    use_greedy_ordering: Optional[bool] = False,
    alpha: float = 0.01,
    random_seed: int = 42
) -> Tuple[Node, dict]:
    """
    Learn an eXtremely randomized Probabilistic Circuit (XPC).

    :param data: The input data matrix.
    :param det: True to force determinism, False otherwise.
    :param sd: True to force structured decomposability, False otherwise.
    :param min_part_inst: The minimum number of instances allowed per partition.
    :param conj_len: The conjunction length.
    :param arity: The maximum number of children for a sum node.
    :param n_max_parts: The maximum number of partitions for the partitions tree.
    :param use_clt: True to use CLTrees as multivariate leaves, False otherwise.
    :param use_greedy_ordering: True to use a greedy ordering, False otherwise.
    :param alpha: Laplace smoothing factor.
    :param random_seed: The random seed.
    """
    assert 2 <= arity <= 2 ** conj_len, 'Arity must be in the interval [2, 2 ** conj_len]'
    assert sd or not use_greedy_ordering, 'Using the greedy ordering makes sense only if sd = True.'

    random_state = np.random.RandomState(random_seed)
    if use_greedy_ordering:
        ordering = greedy_vars_ordering(data, conj_len)
    else:
        ordering = np.arange(data.shape[1]).tolist()
        random_state.shuffle(ordering)

    part_root, cl_parts_l, conj_vars_l, n_parts = \
        generate_random_partitioning(
            data=data,
            sd=sd,
            min_part_inst=min_part_inst,
            conj_len=conj_len,
            arity=arity,
            n_max_parts=n_max_parts,
            uncond_vars=ordering,
            random_state=random_state)
    assert n_parts > 1, 'No partitioning found.'

    trees_dict = None
    if sd and use_clt:
        trees_dict = build_trees_dict(data, [cl_parts_l], conj_vars_l, alpha, random_state)

    # Collect the intermediate structures for later inspection
    utils = {'part_root': part_root, 'cl_parts_l': cl_parts_l, 'conj_vars_l': conj_vars_l,
             'n_parts': n_parts, 'trees_dict': trees_dict}
    xpc = build_xpc(data, part_root, trees_dict, det, use_clt, alpha)
    return xpc, utils


def learn_expc(
    data: np.ndarray,
    ensemble_dim: int,
    det: bool,
    sd_level: int,
    min_part_inst: int,
    conj_len: int,
    arity: int,
    n_max_parts: int = 200,
    use_clt: bool = True,
    alpha: float = 0.01,
    random_seed: int = 42
) -> Tuple[Node, list]:
    """
    Learn an Ensemble (i.e. a mixture) of eXtremely randomized Probabilistic Circuits (EXPC).

    :param data: The input data matrix.
    :param ensemble_dim: The number of circuits in the ensemble/mixture.
    :param det: True to force determinism, False otherwise.
    :param sd_level: 0 for a non-SD ensemble of non-SD PCs, 1 for a non-SD ensemble of SD PCs and 2 for an SD ensemble.
    :param min_part_inst: The minimum number of instances allowed per partition.
    :param conj_len: The conjunction length.
    :param arity: The maximum number of children for a Sum node.
    :param n_max_parts: The maximum number of partitions for the partitions tree.
    :param use_clt: True to use CLTrees as multivariate leaves, False otherwise.
    :param alpha: Laplace smoothing factor.
    :param random_seed: A random seed.
    """
    assert sd_level in SD_LEVELS, 'Choose a value in {0, 1, 2}.'
    assert 2 <= arity <= 2 ** conj_len, 'Arity must be in the interval [2, 2 ** conj_len].'
    assert not (sd_level == SD_LEVEL_2 and conj_len == 1), 'No randomness in this setting. Change hyper parameters.'

    random_state = np.random.RandomState(random_seed)
    conj_vars_l_l = [None] * ensemble_dim
    cl_parts_l_l = [None] * ensemble_dim
    trees_dict_l = [None] * ensemble_dim
    part_root_l = [None] * ensemble_dim
    n_parts_l = [None] * ensemble_dim
    xpc_l = [None] * ensemble_dim

    sd = (sd_level in [SD_LEVEL_1, SD_LEVEL_2])
    if sd_level == SD_LEVEL_2:
        ordering = greedy_vars_ordering(data, conj_len)
    else:
        ordering = np.arange(data.shape[1]).tolist()

    for i in range(ensemble_dim):
        if sd_level != SD_LEVEL_2:
            random_state.shuffle(ordering)  # use the seeded state so the ensemble is reproducible
        part_root_l[i], cl_parts_l_l[i], conj_vars_l_l[i], n_parts_l[i] = \
            generate_random_partitioning(
                data=data,
                sd=sd,
                min_part_inst=min_part_inst,
                conj_len=conj_len,
                arity=arity,
                n_max_parts=n_max_parts,
                uncond_vars=ordering,
                random_state=random_state)
    assert not all(n_parts == 1 for n_parts in n_parts_l), 'No Partitioning Found'

    if sd_level == SD_LEVEL_0 or not use_clt:
        # no tree structure to respect
        trees_dict = None
        for i in range(ensemble_dim):
            print('Learning XPC %s/%s' % (i + 1, ensemble_dim))
            xpc_l[i] = build_xpc(data, part_root_l[i], trees_dict, det, use_clt, alpha)
    elif sd_level == SD_LEVEL_1:
        for i in range(ensemble_dim):
            print('Learning XPC %s/%s' % (i + 1, ensemble_dim))
            # learn a tree for each XPC
            trees_dict_l[i] = build_trees_dict(data, [cl_parts_l_l[i]], conj_vars_l_l[i], alpha, random_state)
            xpc_l[i] = build_xpc(data, part_root_l[i], trees_dict_l[i], det, use_clt, alpha)
    elif sd_level == SD_LEVEL_2:
        # learn a tree structure for the whole ensemble
        print('Learning a dependency tree for the ensemble..')
        trees_dict = build_trees_dict(data, cl_parts_l_l, max(conj_vars_l_l, key=len), alpha, random_state)
        trees_dict_l = [trees_dict] * ensemble_dim
        for i in range(ensemble_dim):
            print('Building XPC %s/%s' % (i + 1, ensemble_dim))
            xpc_l[i] = build_xpc(data, part_root_l[i], trees_dict, det, use_clt, alpha)

    # Collect the intermediate structures of every XPC for later inspection
    utils = [{'part_root': part_root_l[i], 'cl_parts_l': cl_parts_l_l[i], 'conj_vars_l': conj_vars_l_l[i],
              'n_parts': n_parts_l[i], 'trees_dict': trees_dict_l[i]} for i in range(ensemble_dim)]
    expc = Sum(weights=np.full(ensemble_dim, 1 / ensemble_dim), children=xpc_l)
    assign_ids(expc)
    return expc, utils
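`learn_expc` places the K XPCs under a root sum node with uniform weights 1/K. The ensemble log-likelihood of a sample is therefore a log-sum-exp of the members' log-likelihoods, each shifted by -log K. Below is a minimal sketch that assumes the member log-likelihoods are stacked into a [K, n_samples] array. `ensemble_log_likelihood` is a hypothetical helper.

```python
import numpy as np
from scipy.special import logsumexp


def ensemble_log_likelihood(member_lls: np.ndarray) -> np.ndarray:
    """log p(x) = logsumexp_k(log(1/K) + log p_k(x)), computed per sample.

    :param member_lls: Array of shape [K, n_samples] with each circuit's log-likelihoods.
    :return: Array of shape [n_samples].
    """
    k = member_lls.shape[0]
    return logsumexp(member_lls - np.log(k), axis=0)


lls = np.log(np.array([[0.2, 0.5],
                       [0.4, 0.5]]))
print(np.exp(ensemble_log_likelihood(lls)))  # approximately [0.3, 0.5]
```

Working in log space with `logsumexp` avoids underflow when individual circuit likelihoods are tiny, which is the usual case for high-dimensional binary data.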

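`build_xpc` visits the partition tree in post-order using an explicit stack. A node is processed only when it is a leaf, or when the last popped partition is one of its children. At that point the results of its children sit on top of `pc_nodes_stack`. Below is a minimal sketch of the same stack discipline. `ToyPart` is a hypothetical stand-in for `Partition`. Internal nodes simply sum their children's values instead of building sum or product nodes.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class ToyPart:
    """Hypothetical stand-in for Partition: a leaf value or a list of sub-partitions."""
    value: int = 0
    sub_partitions: List["ToyPart"] = field(default_factory=list)

    def is_partitioned(self) -> bool:
        return len(self.sub_partitions) > 0


def post_order_reduce(root: ToyPart) -> int:
    parts_stack = [root]
    results = []  # plays the role of pc_nodes_stack
    last_visited: Optional[ToyPart] = None
    while parts_stack:
        part = parts_stack[-1]
        if not part.is_partitioned() or last_visited in part.sub_partitions:
            if part.is_partitioned():
                # All children are done: pop their results and combine them
                n = len(part.sub_partitions)
                children, results = results[-n:], results[:-n]
                results.append(sum(children))
            else:
                results.append(part.value)  # leaf: build it directly
            last_visited = parts_stack.pop()
        else:
            # Push children in reverse so that the first child is processed first
            parts_stack.extend(part.sub_partitions[::-1])
    return results[0]


tree = ToyPart(sub_partitions=[ToyPart(sub_partitions=[ToyPart(1), ToyPart(2)]), ToyPart(3)])
print(post_order_reduce(tree))  # 6
```

`eq=False` keeps membership tests identity-based, and the `last_visited in part.sub_partitions` check relies on that. With value equality, two structurally identical sibling partitions could be confused.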
In [None]:
#@title wrapped as estimator and classifier


def learn_estimator(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: Optional[List[Union[list, tuple]]] = None,
    method: str = 'learnspn',
    **kwargs
) -> Node:
    """
    Learn a SPN density estimator given some training data, the features distributions and domains.

    :param data: The training data.
    :param distributions: A list of distribution classes (one for each feature).
    :param domains: A list of domains (one for each feature). Each domain is either a list of values, for discrete
                    distributions, or a tuple (consisting of min value and max value), for continuous distributions.
                    If None, domains are determined automatically.
    :param method: The method used for structure learning. It can be either 'learnspn', 'xpc' or 'ensemble-xpc'.
    :param kwargs: Additional parameters for structure learning.
    :return: A learned valid and optimized SPN.
    :raises ValueError: If the method used for structure learning is not known.
    :raises ValueError: If the method is 'xpc' or 'ensemble-xpc' but the variable domains are not binary.
    """
    if domains is None:
        domains = compute_data_domains(data, distributions)

    if method == 'learnspn':
        root = learn_spn(data, distributions, domains, **kwargs)
        return prune(root, copy=False)
    if method == 'xpc':
        if not all(d == [0, 1] for d in domains):
            raise ValueError("The domains must be binary for learning a XPC")
        root, _ = learn_xpc(data, **kwargs)
        return root
    if method == 'ensemble-xpc':
        if not all(d == [0, 1] for d in domains):
            raise ValueError("The domains must be binary for learning an Ensemble-XPC")
        root, _ = learn_expc(data, **kwargs)
        return root
    raise ValueError("Unknown SPN learning method called {}".format(method))


def learn_classifier(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: Optional[List[Union[list, tuple]]] = None,
    class_idx: int = -1,
    verbose: bool = True,
    **kwargs
) -> Node:
    """
    Learn a SPN classifier given some training data, the features distributions and domains and
    the class index in the training data.

    :param data: The training data.
    :param distributions: A list of distribution classes (one for each feature).
    :param domains: A list of domains (one for each feature). Each domain is either a list of values, for discrete
                    distributions, or a tuple (consisting of min value and max value), for continuous distributions.
                    If None, domains are determined automatically.
    :param class_idx: The index of the class feature in the training data.
    :param verbose: Whether to enable verbose mode.
    :param kwargs: Other parameters for structure learning.
    :return: A learned valid and optimized SPN.
    """
    if domains is None:
        domains = compute_data_domains(data, distributions)

    n_samples, _ = data.shape
    classes = data[:, class_idx]

    # Initialize the tqdm wrapped unique classes array, if verbose is enabled
    unique_classes = np.unique(classes)
    if verbose:
        unique_classes = tqdm(unique_classes, bar_format='{l_bar}{bar:24}{r_bar}', unit='class')

    # Learn each sub-spn's structure individually
    weights = []
    children = []
    for c in unique_classes:
        local_data = data[classes == c]
        weight = len(local_data) / n_samples
        branch = learn_spn(local_data, distributions, domains, verbose=verbose, **kwargs)
        weights.append(weight)
        children.append(prune(branch, copy=False))

    root = Sum(children=children, weights=weights)
    return assign_ids(root)


def compute_data_domains(data: np.ndarray, distributions: List[Type[Leaf]]) -> List[Union[list, tuple]]:
    """
    Compute the domains based on the training data and the features distributions.

    :param data: The training data.
    :param distributions: A list of distribution classes.
    :return: A list of domains. Each domain is either a list of values, for discrete distributions, or
             a tuple (consisting of min value and max value), for continuous distributions.
    :raises ValueError: If an unknown distribution type is found.
    """
    domains = []
    for i, d in enumerate(distributions):
        col = data[:, i]
        if d.LEAF_TYPE == LeafType.DISCRETE:
            vals = np.unique(col).tolist()
            domains.append(vals)
        elif d.LEAF_TYPE == LeafType.CONTINUOUS:
            vmin = np.min(col).item()
            vmax = np.max(col).item()
            domains.append((vmin, vmax))
        else:
            raise ValueError("Unknown distribution type {}".format(d.LEAF_TYPE))
    return domains
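`compute_data_domains` gives each discrete column the sorted list of its observed values. Each continuous column gets its (min, max) range instead. Below is a minimal sketch of the same rule, with a hypothetical `Kind` enum standing in for `LeafType`.

```python
from enum import Enum

import numpy as np


class Kind(Enum):
    """Hypothetical stand-in for LeafType."""
    DISCRETE = 0
    CONTINUOUS = 1


def infer_domains(data: np.ndarray, kinds: list) -> list:
    domains = []
    for i, kind in enumerate(kinds):
        col = data[:, i]
        if kind is Kind.DISCRETE:
            domains.append(np.unique(col).tolist())  # sorted list of observed values
        elif kind is Kind.CONTINUOUS:
            domains.append((np.min(col).item(), np.max(col).item()))  # (min, max) tuple
        else:
            raise ValueError("Unknown kind {}".format(kind))
    return domains


x = np.array([[0.0, 1.5], [1.0, -2.0], [1.0, 0.25]])
print(infer_domains(x, [Kind.DISCRETE, Kind.CONTINUOUS]))  # [[0.0, 1.0], (-2.0, 1.5)]
```

Discrete domains come from the observed values only. A binary column that happens to contain only zeros therefore gets a one-element domain and fails the binary check in `learn_estimator`, so pass `domains` explicitly in that case. Float columns still pass that check, since `[0.0, 1.0] == [0, 1]` in Python.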

# Models

In [None]:
#@title vanilla spn

class SPNEstimator(BaseEstimator, DensityMixin):
    def __init__(
        self,
        distributions: List[Type[Leaf]],
        domains: Optional[List[Union[list, tuple]]] = None,
        **kwargs
    ):
        """
        Scikit-learn density estimator model for Sum Product Networks (SPNs).

        :param distributions: A list of distribution classes (one for each feature).
        :param domains: A list of domains (one for each feature).
        :param kwargs: Additional arguments to pass to the SPN learner.
        """
        super().__init__()
        self.distributions = distributions
        self.domains = domains
        self.kwargs = kwargs
        self.spn_ = None
        self.n_features_ = 0

    def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
        """
        Fit the SPN density estimator.

        :param X: The training data.
        :param y: Ignored, only for scikit-learn API convention.
        :return: Itself.
        """
        self.spn_ = learn_estimator(X, self.distributions, self.domains, **self.kwargs)
        self.n_features_ = X.shape[1]
        return self

    def predict_log_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the SPN density estimator, i.e. compute the log-likelihood.

        :param X: The inputs. They can be marginalized using NaNs.
        :return: The log-likelihood of the inputs.
        """
        return log_likelihood(self.spn_, X)

    def mpe(self, X: np.ndarray) -> np.ndarray:
        """
        Predict the unobserved variables by most probable explanation (MPE) inference.

        :param X: The inputs having some NaN values.
        :return: The MPE assignment to un-observed variables.
        """
        return mpe(self.spn_, X, inplace=False)

    def sample(self, n: Optional[int] = None, X: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Sample from the modeled distribution.

        :param n: The number of samples. It must be None if X is not None. If None, n=1 is assumed.
        :param X: Data used for conditional sampling. It can be None for full sampling.
        :return: The samples.
        :raise ValueError: If both parameters 'n' and 'X' are passed.
        """
        if n is not None and X is not None:
            raise ValueError("Only one between 'n' and 'X' can be specified")

        if X is not None:
            # Conditional sampling
            return sample(self.spn_, X, inplace=False)
        else:
            # Full sampling
            n = 1 if n is None else n
            x = np.tile(np.nan, [n, self.n_features_])
            return sample(self.spn_, x, inplace=True)

    def score(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> dict:
        """
        Return the mean log-likelihood and two standard errors of the mean on the given test data.

        :param X: The inputs. They can be marginalized using NaNs.
        :param y: Ignored. Specified only for scikit-learn API compatibility.
        :return: A dictionary consisting of two keys "mean_ll" and "stddev_ll",
                 representing respectively the mean log-likelihood and two standard errors of the mean.
        """
        ll = self.predict_log_proba(X)
        mean_ll = np.mean(ll)
        stddev_ll = np.std(ll)
        return {
            'mean_ll': mean_ll,
            'stddev_ll': 2.0 * stddev_ll / np.sqrt(len(X))
        }


class SPNClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        distributions: List[Type[Leaf]],
        domains: Optional[List[Union[list, tuple]]] = None,
        **kwargs
    ):
        """
        Scikit-learn classifier model for Sum Product Networks (SPNs).

        :param distributions: A list of distribution classes (one for each feature).
        :param domains: A list of domains (one for each feature).
        :param kwargs: Additional arguments to pass to the SPN learner.
        """
        super().__init__()
        self.distributions = distributions
        self.domains = domains
        self.kwargs = kwargs
        self.spn_ = None
        self.n_features_ = 0
        self.n_classes_ = 0

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit the SPN classifier.

        :param X: The training data.
        :param y: The data labels.
        :return: Itself.
        """
        # Build the training data by appending the labels as the last column
        y = np.expand_dims(y, axis=1)
        data = np.hstack([X, y])

        # Constructs the list of distributions
        n_classes = len(np.unique(y))
        if n_classes == 2:
            # Use a Bernoulli distribution for binary classification
            distributions = self.distributions + [Bernoulli]
        else:
            # otherwise, use a categorical distribution
            distributions = self.distributions + [Categorical]

        self.spn_ = learn_classifier(data, distributions, self.domains, **self.kwargs)
        self.n_features_ = X.shape[1]
        self.n_classes_ = n_classes
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the SPN classifier.

        :param X: The inputs. They can be marginalized using NaNs.
        :return: The predicted classes.
        """
        # Build the testing data, having X as features assignments and NaNs for labels
        data = np.hstack([X, np.full([len(X), 1], np.nan)])

        # Make classification using most probable explanation (MPE) inference
        mpe(self.spn_, data, inplace=True)

        # Return the classifications for each sample
        return data[:, -1]

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the SPN classifier, using probabilities.

        :param X: The inputs. They can be marginalized using NaNs.
        :return: The prediction probabilities for each class.
        """
        return np.exp(self.predict_log_proba(X))

    def predict_log_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the SPN classifier, using log-probabilities.

        :param X: The inputs. They can be marginalized using NaNs.
        :return: The prediction log-probabilities for each class.
        """
        # Build the testing data, having X as features assignments and NaNs for labels
        data = np.hstack([X, np.tile(np.nan, [len(X), 1])])

        # Make probabilistic classification by computing the log-likelihoods at sub-class SPN
        _, lls = log_likelihood(self.spn_, data, return_results=True)

        # Collect the predicted class probabilities
        class_ids = [c.id for c in self.spn_.children]
        class_ll = np.log(self.spn_.weights) + lls[class_ids]
        return log_softmax(class_ll, axis=1)

    def sample(self, n: Optional[int] = None, y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Sample from the modeled conditional distribution.

        :param n: The number of samples. It must be None if y is not None. If None, n=1 is assumed.
        :param y: Labels used for conditional sampling. It can be None for un-conditional sampling.
        :return: The samples.
        """
        if n is not None and y is not None:
            raise ValueError("Only one between 'n' and 'y' can be specified")

        # Conditional sampling
        if y is not None:
            y = np.expand_dims(y, axis=1)
            x = np.hstack([np.tile(np.nan, [len(y), self.n_features_]), y])
            return sample(self.spn_, x, inplace=False)

        # Full sampling
        n = 1 if n is None else n
        x = np.tile(np.nan, [n, self.n_features_ + 1])
        return sample(self.spn_, x, inplace=True)
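`SPNClassifier` is a generative classifier. The root sum weights are the class frequencies, and each child sub-SPN models p(x | y = c). `predict_log_proba` then normalizes log w_c + log p(x | y = c) over the classes with `log_softmax`. Below is a minimal sketch of that final step. It assumes the class-conditional log-likelihoods are arranged as an [n_samples, n_classes] array. `class_log_posteriors` is a hypothetical helper.

```python
import numpy as np
from scipy.special import log_softmax


def class_log_posteriors(class_weights: np.ndarray, class_lls: np.ndarray) -> np.ndarray:
    """Return log p(y | x) with shape [n_samples, n_classes].

    :param class_weights: Prior class weights w_c, shape [n_classes].
    :param class_lls: Class-conditional log-likelihoods log p(x | y=c), shape [n_samples, n_classes].
    """
    joint = np.log(class_weights) + class_lls  # broadcasts the log-priors over the rows
    return log_softmax(joint, axis=1)  # normalize over classes, one row per sample


w = np.array([0.5, 0.5])
lls = np.log(np.array([[0.1, 0.3],
                       [0.4, 0.1]]))
post = np.exp(class_log_posteriors(w, lls))
print(post.argmax(axis=1))  # [1 0]
```

The normalization runs over axis 1, with one row per sample. The class-conditional log-likelihoods must therefore be laid out with samples on the rows before the log-priors are added.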

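`SPNEstimator.score` reports the mean test log-likelihood together with twice the standard error of the mean. Under a normal approximation, that is the half-width of an approximate 95% confidence interval. Below is a minimal sketch of the same computation. `mean_ll_with_error` is a hypothetical helper. `np.std` uses the population (ddof=0) estimate, exactly as in the method above.

```python
import numpy as np


def mean_ll_with_error(ll: np.ndarray) -> dict:
    """Mean log-likelihood and two standard errors of the mean."""
    return {
        'mean_ll': float(np.mean(ll)),
        'stddev_ll': float(2.0 * np.std(ll) / np.sqrt(len(ll))),
    }


print(mean_ll_with_error(np.array([-1.0, -3.0, -1.0, -3.0])))  # {'mean_ll': -2.0, 'stddev_ll': 1.0}
```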
# Experiment

In [None]:
#@title naive spn

# Sample some binary data randomly
np.random.seed(42)
n_samples, n_features = 1000, 10
data = np.random.binomial(1, p=0.4, size=[n_samples, n_features])

# Set the features distributions and domains
distributions = [Bernoulli] * n_features
domains = [[0, 1]] * n_features  # Use lists to specify discrete domains

# Learn a naive factorized model from a subset of the data
scope = [5, 1, 7]
dists = [distributions[s] for s in scope]
doms = [domains[s] for s in scope]
naive = learn_naive_factorization(
    data[:, scope], dists, doms, scope,
    learn_leaf_func=learn_mle,  # Use MLE to learn the leaf distributions
    alpha=0.01  # Additional learn_mle parameters, for example the Laplace smoothing factor
)

# Compute the average likelihood
# `ls` is equivalent to evaluating scipy.stats.bernoulli.pmf(data[:, s], p_s) for every variable s in the scope
# (p_s being the parameter of the corresponding leaf in naive.children) and multiplying the results across the scope columns.
# The Bernoulli pmf returns p for ones and 1 - p for zeros; see the Bernoulli leaf class
# (for discrete variables the pmf plays the role that the pdf plays for continuous ones).

ls = likelihood(naive, data)
print("Average Likelihood: {:.4f}".format(np.mean(ls)))

# Print some statistics about the model's structure and parameters
print("SPN structure and parameters statistics:")
print(compute_statistics(naive))

len(data), data.shape, naive.children[0].p, (naive.children[0].p * naive.children[1].p * naive.children[2].p), set(ls)
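The likelihood of a naive factorization is the product of its univariate leaves. For binary data it therefore takes at most 2 ** len(scope) distinct values, which is what `set(ls)` above reveals. Below is a minimal standalone check with `scipy.stats.bernoulli`. It assumes each leaf parameter is the Laplace-smoothed frequency (count + alpha) / (n + 2 * alpha). This formula is an assumption about `learn_mle`, not taken from its source.

```python
import numpy as np
from scipy.stats import bernoulli

rng = np.random.RandomState(42)
samples = rng.binomial(1, p=0.4, size=[1000, 10])
scope = [5, 1, 7]
alpha = 0.01

# Assumed Laplace-smoothed MLE of each Bernoulli parameter
p = (samples[:, scope].sum(axis=0) + alpha) / (len(samples) + 2 * alpha)

# Product of per-variable pmfs: p where the value is 1, (1 - p) where it is 0
ls = np.prod(bernoulli.pmf(samples[:, scope], p), axis=1)
print("Average Likelihood: {:.4f}".format(np.mean(ls)))
print(len(set(np.round(ls, 10))))  # at most 2 ** len(scope) = 8 distinct values
```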

In [None]:
#@title spn classifier (from latent) part 1

import numpy as np
import sklearn as sk
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm


# Setup the MNIST datasets
n_classes = 10
n_features = (image_c, image_h, image_w) = (1, 28, 28)
n_dimensions = np.prod(n_features).item()
transform = transforms.ToTensor()
data_train = datasets.MNIST('datasets', train=True, transform=transform, download=True)
data_test = datasets.MNIST('datasets', train=False, transform=transform, download=True)

# Build the autoencoder for features extraction
latent_dim = 24  # Use 24 features as latent space
encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(n_dimensions, 512), nn.ReLU(inplace=True),
    nn.Linear(512, 256), nn.ReLU(inplace=True),
    nn.Linear(256, latent_dim), nn.Tanhshrink(),
)
decoder = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(inplace=True),
    nn.Linear(256, 512), nn.ReLU(inplace=True),
    nn.Linear(512, 784), nn.Sigmoid(),
    nn.Unflatten(1, n_features)
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder = nn.Sequential(encoder, decoder).to(device)

# Train the autoencoder, by minimizing the reconstruction binary cross-entropy
epochs = 25
batch_size = 100
lr = 1e-3
train_loader = data.DataLoader(data_train, batch_size=batch_size, shuffle=True)
optimizer = optim.Adam(autoencoder.parameters(), lr=lr)
criterion = nn.BCELoss()
tk_epochs = tqdm(range(epochs), bar_format='{l_bar}{bar:24}{r_bar}', unit='epoch')
for epoch in tk_epochs:
    train_loss = 0.0
    for (inputs, _) in train_loader:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inputs.shape[0]
    train_loss /= len(data_train)  # Average over samples: the loss was accumulated per sample, not per batch
    tk_epochs.set_description('Train Loss: {}'.format(round(train_loss, 4)))

# Compute the (train data) latent space features using the encoder
train_loader = data.DataLoader(data_train, batch_size=batch_size, shuffle=True)
x_train = np.empty([len(data_train), latent_dim], dtype=np.float32)
y_train = np.empty(len(data_train), dtype=np.int64)
with torch.no_grad():
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        outputs = encoder(inputs).cpu()
        x_train[i * batch_size:i * batch_size + batch_size] = outputs.numpy()
        y_train[i * batch_size:i * batch_size + batch_size] = targets.numpy()

# Compute the (test data) latent space features using the encoder
test_loader = data.DataLoader(data_test, batch_size=batch_size, shuffle=False)
x_test = np.empty([len(data_test), latent_dim], dtype=np.float32)
y_test = np.empty(len(data_test), dtype=np.int64)
with torch.no_grad():
    for i, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs = encoder(inputs).cpu()
        x_test[i * batch_size:i * batch_size + batch_size] = outputs.numpy()
        y_test[i * batch_size:i * batch_size + batch_size] = targets.numpy()

# Preprocess the datasets using standardization
transform = DataStandardizer()
transform.fit(x_train)
x_train = transform.forward(x_train)
x_test = transform.forward(x_test)
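The standardizer is fitted on the training features only and then applied to both splits; the last cell uses its inverse (`transform.backward`) to bring SPN samples back to the decoder's latent scale. Below is a minimal sketch of such a forward/backward pair. The `Standardizer` class and its `eps` guard are hypothetical; deeprob-kit's `DataStandardizer` may differ in details such as how it treats zero-variance features.

```python
import numpy as np


class Standardizer:
    """Hypothetical per-feature z-score transform with an exact inverse."""

    def __init__(self, eps=1e-7):
        self.eps = eps  # guards against division by zero for constant features
        self.mu = None
        self.sigma = None

    def fit(self, x):
        # Statistics are estimated on the training features only
        self.mu = x.mean(axis=0)
        self.sigma = x.std(axis=0)
        return self

    def forward(self, x):
        return (x - self.mu) / (self.sigma + self.eps)

    def backward(self, z):
        # Inverse of forward: maps standardized values back to the original scale
        return z * (self.sigma + self.eps) + self.mu


rng = np.random.default_rng(0)
x_tr = rng.normal(3.0, 2.0, size=(500, 24)).astype(np.float32)
x_te = rng.normal(3.0, 2.0, size=(100, 24)).astype(np.float32)
std = Standardizer().fit(x_tr)
z_tr, z_te = std.forward(x_tr), std.forward(x_te)
print(z_tr.mean(axis=0)[:3], z_tr.std(axis=0)[:3])
```

Keeping the training statistics for the inverse matters here: SPN samples are drawn in the standardized space, so they must be mapped back with the same mean and scale before the decoder sees them.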


In [None]:
#@title spn classifier (from latent) part 2

# Learn the SPN structure and parameters, as a classifier
# Note that we consider the train data as features + targets
distributions = [Gaussian] * latent_dim + [Categorical]
data_train = np.column_stack([x_train, y_train])
root = learn_classifier(
    data_train,
    distributions,
    learn_leaf='mle',     # Learn leaf distributions by MLE
    split_rows='kmeans',  # Use K-Means for splitting rows
    split_cols='rdc',     # Use RDC for splitting columns
    min_rows_slice=200,   # The minimum number of rows required to split further
    split_rows_kwargs={'n': 2},   # Use n=2 number of clusters for K-Means
    split_cols_kwargs={'d': 0.3}  # Use d=0.3 as threshold for RDC independence test
)

# Print some statistics about the model's structure and parameters
print("SPN structure and parameters statistics:")
print(compute_statistics(root))


# Make some predictions on the test set classes
# This is done by running a Most Probable Explanation (MPE) query
nan_classes = np.full([len(x_test), 1], np.nan)
data_test = np.column_stack([x_test, nan_classes])
mpe(root, data_test, inplace=True)   ## MPE runs a bottom-up pass, then a top-down pass (root to leaves) that selects leaf masks with np.ix_ and fills the NaN class column in place, see https://stackoverflow.com/questions/62505046/what-does-numpy-ix-function-do-and-what-is-the-output-used-for
y_pred = data_test[:, -1]

# Print a classification report
print("Classification Report:")
print(sk.metrics.classification_report(y_test, y_pred))

# Sample some examples for each class
# This is done by conditional sampling w.r.t. the example classes
n_samples = 10
nan_features = np.full([n_samples * n_classes, latent_dim], np.nan) ## 100 X 24 nans
classes = np.tile(np.arange(n_classes), [1, n_samples]).T           ## 100 X 1 (np.repeat repeats individual elements, np.tile repeats the whole array)
samples = np.column_stack([nan_features, classes])                  ## 100 X 25 (24 NaNs plus 1 class label in 0-9; each class appears 10 times across the 100 rows)
sample(root, samples, inplace=True)                                 ## Conditional sampling: a bottom-up pass, then a top-down pass (see the sampling section)
## - Leaf: sample directly from the leaf's distribution with the scipy.stats rvs() method (e.g. scipy.stats.cauchy.rvs() for a Cauchy leaf)
## - Product node: the node is combined with each of its children by union (rows reached through several paths are counted only once)
##   (Note: the node is above, its children are below, and we are moving down)
## - Sum node: add random Gumbel noise (scipy.stats.gumbel_l.rvs()) and the log-weights to the children's log-likelihoods and take the argmax;
##   the sum node is then combined by union with only one of its children, the one selected by the argmax
features = samples[:, :-1]

# Apply the inverse preprocessing transformation
# Then apply the features extractor's decoder and plot the examples on a grid
with torch.no_grad():
    images = transform.backward(features)
    inputs = torch.tensor(images, dtype=torch.float32, device=device)
    data_images = decoder(inputs).cpu()
    samples_filename = 'spn-latent-mnist-samples.png'
    print("Plotting generated samples to {} ...".format(samples_filename))
    torchvision.utils.save_image(data_images, samples_filename, nrow=n_samples, padding=0)
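At a sum node, both queries above choose one child per row from the scores log w_i + log p_i(evidence), where the child log-likelihoods come from the bottom-up pass. MPE takes the plain argmax, while conditional sampling first perturbs the scores with Gumbel noise, which selects child i with probability proportional to w_i * p_i(evidence) (the Gumbel-max trick). A minimal sketch follows; `select_children` is a hypothetical helper, not deeprob-kit's API. It uses the standard right-skewed Gumbel from `numpy.random.Generator.gumbel`, which is what the Gumbel-max trick requires; how the library applies `scipy.stats.gumbel_l` (for example with a sign flip) is not shown here.

```python
import numpy as np


def select_children(log_weights, child_lls, rng=None):
    """Pick one child per row at a sum node.

    log_weights: shape (k,), log of the sum-node weights.
    child_lls:   shape (n, k), log-likelihood of each child's evidence per row.
    rng=None gives the MPE choice (deterministic argmax);
    otherwise the choice is sampled with the Gumbel-max trick.
    """
    scores = child_lls + log_weights  # broadcast the weights over the rows
    if rng is not None:
        scores = scores + rng.gumbel(size=scores.shape)  # standard (right-skewed) Gumbel noise
    return np.argmax(scores, axis=1)


rng = np.random.default_rng(0)
weights = np.array([0.2, 0.5, 0.3])
child_lls = np.zeros((100_000, 3))  # uninformative evidence: only the weights matter
mpe_choice = select_children(np.log(weights), child_lls)       # every row picks child 1, the largest weight
sampled = select_children(np.log(weights), child_lls, rng)
print(np.bincount(sampled, minlength=3) / len(sampled))         # close to the weights
```

With informative evidence (non-zero `child_lls`), the same rule makes both the MPE class prediction and the conditional samples depend on how well each child explains the observed features, which is why fixing the class column steers the generated latent vectors.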