# Building and Learning Sum of Squares Circuits

## Goal

By the end of this notebook, you will know how to compose symbolic **circuit operators** to build and learn a Probabilistic Circuit (PC). In particular, you will know how to build and learn Sum-of-Squares (SOS) circuits for distribution estimation, as introduced in the paper [Sum of Squares Circuits](https://arxiv.org/abs/2408.11778). We start by introducing complex squared circuits.

## Complex Squared Circuits

PCs are typically learned by assuming their parameters to be non-negative, i.e., they are _monotonic_. For example, the PC learned in the notebook [learning-a-circuit.ipynb](learning-a-circuit.ipynb) is monotonic, as it consists of input layers encoding Categorical likelihoods and the parameters are obtained by applying a softmax activation. To build a more expressive distribution estimator, one can instead use a circuit with complex parameters, i.e., a complex circuit.

Similarly to the [learning-a-circuit.ipynb](learning-a-circuit.ipynb) notebook, we aim to build a circuit that estimates the probability distribution of MNIST images. To this end, we construct a complex circuit using the ```circuit_templates.image_data``` utility, as shown in the following function.

In [1]:
from cirkit.templates import circuit_templates
from cirkit.symbolic.circuit import Circuit

def build_symbolic_complex_circuit(region_graph: str) -> Circuit:
    return circuit_templates.image_data(
        (1, 28, 28),                 # The shape of MNIST image, i.e., (num_channels, image_height, image_width)
        region_graph=region_graph,
        # ----------- Input layers hyperparameters ----------- #
        input_layer='embedding',     # Use Embedding maps for the pixel values (0-255) as input layers
        num_input_units=32,          # Each input layer consists of 32 input units that output Embedding entries
        input_params={               # Set how to parameterize the input layers parameters
            # In this case we parameterize the 'weight' parameter of Embedding layers,
            # by choosing them to be complex-valued, with real and imaginary parts sampled uniformly in [0, 1)
            'weight': circuit_templates.Parameterization(dtype='complex', initialization='uniform'),
        },
        # -------- Sum-product layers hyperparameters -------- #
        sum_product_layer='cp-t',    # Use CP-T sum-product layers, i.e., alternate Hadamard product layers and dense layers
        num_sum_units=32,            # Each dense sum layer consists of 32 sum units
        # Set how to parameterize the sum layers parameters
        # We parameterize them to be complex-valued, with real and imaginary parts sampled uniformly in [0, 1)
        sum_weight_param=circuit_templates.Parameterization(dtype='complex', initialization='uniform')
    )

In the above, we choose input layers encoding complex embeddings, i.e., each input unit maps a pixel value in $\{0,1,\ldots,255\}$ to the corresponding entry of an embedding in $\mathbb{C}^{256}$. In addition, we make use of CP-T as sum-product layers, where sum layers are parameterized with complex weights. For more details about this and other layers, see the [region-graph-and-parameterisations.ipynb](region-graph-and-parameterisations.ipynb) notebook.

To encode a probability distribution, we must at least encode a non-negative real function. To do so with a complex circuit, we take the squared modulus of its output. Formally, let $c$ be a complex circuit over variables $\mathbf{X}$, i.e., $c(\mathbf{X})\in\mathbb{C}$. We can then encode a probability distribution $p(\mathbf{X})$ such that $p(\mathbf{X})=Z^{-1} |c(\mathbf{X})|^2 = Z^{-1} c(\mathbf{X}) c(\mathbf{X})^*$, where $(\ \cdot\ )^*$ denotes complex conjugation and $Z = \int_{\mathrm{dom}(\mathbf{X})} |c(\mathbf{x})|^2 \mathrm{d}\mathbf{x}$ denotes the partition function (the integral is a sum when the variables are discrete, as for pixel values). Equivalently, we can write $p(\mathbf{X})$ as proportional to the **sum of two squares**, i.e., $p(\mathbf{X}) \propto \Re(c(\mathbf{X}))^2 + \Im(c(\mathbf{X}))^2$, where $\Re$ and $\Im$ denote the real and imaginary part, respectively, which guarantees that it is non-negative.
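
As a minimal sketch of this construction, the following plain-Python snippet normalizes the squared modulus of a complex-valued function over a single discrete variable. The dictionary `c` is a hypothetical stand-in for a circuit and its values are arbitrary; nothing here uses the cirkit API.

```python
# Hypothetical toy "circuit": a complex-valued function of one variable X in {0, 1, 2}.
# The values are arbitrary and only serve the illustration.
c = {0: 1.0 + 2.0j, 1: -0.5 + 0.5j, 2: 0.0 - 1.5j}

# Unnormalized non-negative scores |c(x)|^2 = c(x) * conj(c(x))
scores = {x: (v * v.conjugate()).real for x, v in c.items()}

# The same scores written as a sum of two squares: Re(c(x))^2 + Im(c(x))^2
two_squares = {x: v.real ** 2 + v.imag ** 2 for x, v in c.items()}

# Partition function Z: on a discrete domain the integral is a finite sum
Z = sum(scores.values())

# Normalized distribution p(x) = |c(x)|^2 / Z
p = {x: s / Z for x, s in scores.items()}

for x in c:
    print(f"x={x}: |c(x)|^2={scores[x]:.3f}, p(x)={p[x]:.3f}")
```

For a real circuit over 784 pixels the domain is far too large to enumerate, which is why the partition function is computed with circuit operators instead.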

## Composing Circuit Operators

To enable the exact and efficient computation of probabilities, we need to renormalize $p$, i.e., compute the normalization constant $Z$. To do so, we can use the **circuit operators** in the ```cirkit.symbolic.functional``` module to automatically construct the circuit computing $Z$. All we need is to _compose the operators_ so that they encode the formula $Z = \int_{\mathrm{dom}(\mathbf{X})} |c(\mathbf{x})|^2 \mathrm{d}\mathbf{x}$ as yet another circuit.

More specifically, each of the operators we will use has **pre-conditions** on the structural properties of the input circuits, and **post-conditions** on the properties and semantics of the output circuit:
- ```c' = multiply(c1, c2)```:
  - Pre-condition: ```c1``` and ```c2``` are _compatible_, i.e., they share the same partitionings of variables at the products.
  - Post-condition: ```c'``` is _smooth_ and _decomposable_ and encodes the product of ```c1``` and ```c2```.
- ```c' = conjugate(c)```:
  - Pre-condition: ```c``` is a circuit with possibly complex parameters.
  - Post-condition: ```c'``` is a circuit with the same structure as ```c``` that computes the complex conjugate of ```c```.
- ```c' = integrate(c)```:
  - Pre-condition: ```c``` is a _smooth_ and _decomposable_ circuit.
  - Post-condition: ```c'``` is a circuit exactly encoding the integral of ```c``` over the whole domain of its variables.

To satisfy these pre-conditions, we construct a complex circuit from a structured-decomposable region graph. This yields a circuit that is compatible with itself (and with its conjugate, which has the same structure), so we can apply the ```multiply``` operator to square it. The circuit resulting from ```multiply``` is then smooth and decomposable, and therefore satisfies the pre-conditions of the ```integrate``` operator.
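
To see why these structural properties matter, here is a small plain-Python sketch (not the cirkit API) that mimics the three operators on a toy two-variable function $c(x_1, x_2) = \sum_k w_k f_k(x_1) g_k(x_2)$. The sizes `K` and `D`, the random parameters, and the helper names are all assumptions made for illustration. Multiplying $c$ by its conjugate pairs up units $(k, l)$, and because the variables are split in the same way in both factors, the sum over the domain can be pushed down to the inputs.

```python
import itertools
import random

random.seed(0)
K, D = 3, 4  # hypothetical sizes: K units per layer, D values per variable

def rand_complex():
    return complex(random.random(), random.random())

# Toy circuit c(x1, x2) = sum_k w[k] * f[k][x1] * g[k][x2]
w = [rand_complex() for _ in range(K)]
f = [[rand_complex() for _ in range(D)] for _ in range(K)]
g = [[rand_complex() for _ in range(D)] for _ in range(K)]

def c(x1, x2):
    return sum(w[k] * f[k][x1] * g[k][x2] for k in range(K))

# Brute force: enumerate the whole domain (exponential in the number of variables)
Z_brute = sum(abs(c(x1, x2)) ** 2 for x1, x2 in itertools.product(range(D), repeat=2))

def inner(u, v):
    # Sum over one variable of u(x) * conj(v(x)), i.e., an "integrated" product of inputs
    return sum(a * b.conjugate() for a, b in zip(u, v))

# Structured: integrate each variable separately for every pair of units (k, l)
Z_struct = sum(
    w[k] * w[l].conjugate() * inner(f[k], f[l]) * inner(g[k], g[l])
    for k in range(K)
    for l in range(K)
)

print(f"brute force: {Z_brute:.6f}, structured: {Z_struct.real:.6f}")
```

The structured computation touches each variable once per pair of units instead of enumerating every joint assignment, which is the kind of saving the operators exploit at scale.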

We build the symbolic circuit below and show its structural properties.

In [2]:
# Build a symbolic complex circuit by overparameterizing a Quad-Tree (4) region graph, which is structured-decomposable
symbolic_circuit = build_symbolic_complex_circuit('quad-tree-4')

# Print which structural properties the circuit satisfies
print('Structural properties:')
print(f'  - Smoothness: {symbolic_circuit.is_smooth}')
print(f'  - Decomposability: {symbolic_circuit.is_decomposable}')
print(f'  - Structured-decomposability: {symbolic_circuit.is_structured_decomposable}')

Structural properties:
  - Smoothness: True
  - Decomposability: True
  - Structured-decomposability: True


Next, we compose the circuit operators mentioned above to construct the circuit computing $Z$.

In [3]:
import cirkit.symbolic.functional as SF

# Construct the circuit computing |c(X)|^2 = c(X) c(X)^*
symbolic_squared_circuit = SF.multiply(symbolic_circuit, SF.conjugate(symbolic_circuit))

# Construct the circuit computing Z, i.e., the integral of |c(X)|^2 over the complete domain of X
symbolic_circuit_partition_func = SF.integrate(symbolic_squared_circuit)

### Compiling and Learning Complex Squared Circuits

Since we want to estimate the distribution of MNIST images, here we learn complex squared circuits by maximizing the log-likelihoods of observed images. Formally, given a complex circuit $c$, we can write the log-likelihood of a data point $\mathbf{x}$ modeled by the complex squared circuit as $\log p(\mathbf{x}) = -\log Z + 2 \log |c(\mathbf{x})|$. Therefore, we need to compile two circuits for this purpose: (1) the circuit $c$, and (2) the circuit computing $Z$.

We choose PyTorch as the compilation backend, and set random seeds and the device below.

In [4]:
import random
import numpy as np
import torch

# Set some seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Set the torch device to use (falls back to the CPU if no GPU is available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

To compile the circuits, we instantiate a ```PipelineContext``` object; see the [compilation-options.ipynb](compilation-options.ipynb) notebook for a tutorial on compiling circuits and on the meaning of the different flags. One important flag here is the evaluation semiring. To ensure numerical stability, we evaluate circuits by computing sums and products as if they were operations of a semiring where the addition is the LogSumExp and the multiplication is the addition, i.e., we compute in log-space. Since our complex circuit can have negative real or complex parameters, we choose a generalization of this semiring over the complex plane.

In [5]:
from cirkit.pipeline import PipelineContext, compile

# Instantiate the pipeline context
ctx = PipelineContext(
    backend='torch',  # Choose PyTorch as compilation backend
    # ---- Use the evaluation semiring (C, +, x), where + is the numerically stable LogSumExp and x is the sum ---- #
    semiring='complex-lse-sum',
    # ------------------------------------------------------------------------------------------------------------- #
    fold=True,     # Fold the circuit to better exploit GPU parallelism
    optimize=True  # Optimize the layers of the circuit
)

with ctx:  # Compile the circuits computing log |c(X)| and log |Z|
    circuit = compile(symbolic_circuit)
    circuit_partition_func = compile(symbolic_circuit_partition_func)

In the above code, since we chose the ```complex-lse-sum``` semiring, ```circuit``` computes a complex logarithm of $c(\mathbf{x})$ whose real part is $\log |c(\mathbf{x})|$, while ```circuit_partition_func``` computes $\log Z$ in the same representation. This is why the code below takes the ```.real``` part of both outputs. Both circuits are PyTorch modules.
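
The idea behind such a semiring can be sketched with Python's standard ```cmath``` module. This is only an illustration of the principle and not how cirkit implements it; ```complex_lse``` is a hypothetical helper. A negative or complex value is stored as its complex logarithm, semiring "addition" is a stabilized LogSumExp, and semiring "multiplication" is a plain sum of logarithms.

```python
import cmath

def complex_lse(log_values):
    """Numerically stable log(sum(exp(v))) for complex v (assumes the sum is non-zero)."""
    # Shift by the largest real part so that the exponentials cannot overflow
    shift = max(v.real for v in log_values)
    total = sum(cmath.exp(v - shift) for v in log_values)
    return shift + cmath.log(total)

# Values that a non-monotonic circuit may produce: positive, complex, and negative
values = [3.0 + 0.0j, -1.0 + 2.0j, -0.5 + 0.0j]
log_values = [cmath.log(v) for v in values]

log_sum = complex_lse(log_values)  # semiring "addition": log of 3 + (-1 + 2j) + (-0.5)
log_prod = sum(log_values)         # semiring "multiplication": log of 3 * (-1 + 2j) * (-0.5)

# The real part of a complex logarithm is the logarithm of the modulus
print(f"log |sum| = {log_sum.real:.6f}")
print(f"sum recovered from log-space = {cmath.exp(log_sum):.6f}")
```

The imaginary part carries the phase (for example, $\pi$ for a negative real number), so sign and phase information survives the log-space computation while the real part stays in a numerically safe range.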

Next, we load the MNIST dataset using ```torchvision```.

In [6]:
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Load the MNIST data set and data loaders
transform = transforms.Compose([
    transforms.ToTensor(),
    # Flatten the images and set pixel values in the [0-255] range
    transforms.Lambda(lambda x: (255 * x.view(-1)).long())
])
data_train = datasets.MNIST('datasets', train=True, download=True, transform=transform)
data_test = datasets.MNIST('datasets', train=False, download=True, transform=transform)

# Instantiate the training and testing data loaders
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=256)

# Initialize a torch optimizer of your choice,
#  e.g., Adam, by passing the parameters of the circuit
optimizer = optim.Adam(circuit.parameters(), lr=0.01)

In the following training loop, we move the circuit parameters to the chosen device, and then learn the parameters of the complex squared circuit by minimizing the negative log-likelihood computed on MNIST images.

In [7]:
num_epochs = 15
step_idx = 0
running_loss = 0.0

# Move the circuit to chosen device
circuit = circuit.to(device)

for epoch_idx in range(num_epochs):
    for i, (batch, _) in enumerate(train_dataloader):
        # The circuit expects an input of shape (batch_dim, num_channels, num_variables),
        # so we unsqueeze a dimension for the channel.
        batch = batch.to(device).unsqueeze(dim=1)

        # -------- Computation of the negative log-likelihoods loss -------- #
        # Compute the logarithm of the squared scores of the batch, by evaluating the circuit
        log_scores = circuit(batch)                 # complex logarithm of c(x)
        log_squared_scores = 2.0 * log_scores.real  # 2 * log |c(x)|, i.e., equivalent to log |c(x)|^2
        # Compute the log-partition function
        log_partition_func = circuit_partition_func().real  # log Z
        # Compute the log-likelihoods, log p(x) = 2 * log |c(X)| - log Z
        log_likelihoods = log_squared_scores - log_partition_func
        # We take the negated average log-likelihood as loss
        loss = -torch.mean(log_likelihoods)
        # ------------------------------------------------------------------ #

        # Update the parameters of the circuits, as any other model in PyTorch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.detach() * len(batch)
        step_idx += 1
        if step_idx % 300 == 0:
            print(f"Step {step_idx}: Average NLL: {running_loss / (300 * len(batch)):.3f}")
            running_loss = 0.0

Step 300: Average NLL: 1277.903
Step 600: Average NLL: 759.341
Step 900: Average NLL: 711.138
Step 1200: Average NLL: 683.216
Step 1500: Average NLL: 668.198
Step 1800: Average NLL: 659.092
Step 2100: Average NLL: 653.012
Step 2400: Average NLL: 644.388
Step 2700: Average NLL: 641.965
Step 3000: Average NLL: 640.378
Step 3300: Average NLL: 636.618


Next, we evaluate the model on the test MNIST images, and show the bits-per-dimension metric.

In [8]:
with torch.no_grad():
    # -------- Compute the log-partition function -------- #
    # Note that we need to do it just once, since we are not updating the parameters here
    log_partition_func = circuit_partition_func().real
    # ---------------------------------------------------- #

    test_lls = 0.0
    for batch, _ in test_dataloader:
        batch = batch.to(device).unsqueeze(dim=1)

        # -------- Compute the log-likelihoods of the unseen samples -------- #
        # Compute the logarithm of the squared scores of the batch, by evaluating the circuit
        log_scores = circuit(batch)
        log_squared_scores = 2.0 * log_scores.real
        # Compute the log-likelihoods
        log_likelihoods = log_squared_scores - log_partition_func
        # ------------------------------------------------------------------- #

        test_lls += log_likelihoods.sum().item()

    # Compute average test log-likelihood and bits per dimension
    average_ll = test_lls / len(data_test)
    bpd = -average_ll / (28 * 28 * np.log(2.0))
    print(f"Average test LL: {average_ll:.3f}")
    print(f"Bits per dimension: {bpd:.3f}")

Average test LL: -680.237
Bits per dimension: 1.252
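
The conversion from an average log-likelihood (in nats) to bits per dimension used above can be isolated in a small helper. The function name ```bits_per_dimension``` is our own choice for this sketch; the formula is the one in the evaluation cell.

```python
import math

def bits_per_dimension(average_ll: float, num_dims: int) -> float:
    """Convert an average log-likelihood in nats into bits per dimension."""
    # Dividing by log(2) converts nats to bits; dividing by num_dims averages over pixels
    return -average_ll / (num_dims * math.log(2.0))

# Reproduce the number reported above for 28 x 28 MNIST images
bpd = bits_per_dimension(-680.237, 28 * 28)
print(f"Bits per dimension: {bpd:.3f}")  # prints "Bits per dimension: 1.252"
```

Lower is better: a model that assigned uniform probability to each of the 256 values of every pixel would score exactly 8 bits per dimension.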


## Learning a Sum of Exponentially many Squared Circuits

As we also observed above, the complex squared circuit we have built encodes a probability distribution that is the sum of two squares (real and imaginary part of the complex circuit). In the paper [Sum of Squares Circuits](https://arxiv.org/abs/2408.11778), a sum of exponentially many squared circuits is modelled, which can be more expressive than both a single squared circuit with real parameters and a structured-decomposable circuit with positive parameters only. In this section, we construct and learn such a model using cirkit.

Given a complex circuit $c_2$ like the one we have previously built, we construct a monotonic PC $c_1$ that has the same structure as $c_2$. Because the two circuits share the same structure, we can model the distribution $p(\mathbf{X})$ as proportional to the product of $c_1$ and the squared modulus of $c_2$, i.e., $p(\mathbf{X}) = \frac{1}{Z} c_1(\mathbf{X}) |c_2(\mathbf{X})|^2$, where $Z = \int_{\mathrm{dom}(\mathbf{X})} c_1(\mathbf{x}) |c_2(\mathbf{x})|^2 \mathrm{d}\mathbf{x}$ is the partition function. Since $c_1$ implicitly encodes a mixture model whose number of components is exponential in its circuit depth, the product of $c_1(\mathbf{X})$ and $|c_2(\mathbf{X})|^2$ results in a mixture of exponentially many squared circuits that share parameters. See Appendix C.3 of the paper [Sum of Squares Circuits](https://arxiv.org/abs/2408.11778) for more details.
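
The following plain-Python sketch illustrates, in a deliberately simplified setting, why such a product is a sum of squares. It is not the construction from the paper's appendix: here $c_1$ is a flat mixture with only three components over one discrete variable, and all weights and values are made up. Each term $\pi_i m_i(x) |c_2(x)|^2$ equals $|\sqrt{\pi_i m_i(x)}\, c_2(x)|^2$ because $\pi_i m_i(x) \geq 0$.

```python
import math

# Hypothetical toy setup over one variable X in {0, 1, 2, 3}
domain = range(4)

# A tiny monotonic "circuit" c_1: a mixture of non-negative components
mix_weights = [0.2, 0.5, 0.3]
components = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.4, 0.4, 0.1],
    [0.25, 0.25, 0.25, 0.25],
]

def c1(x):
    return sum(pi * m[x] for pi, m in zip(mix_weights, components))

# A tiny complex "circuit" c_2, given by a table of arbitrary values
c2_values = [1.0 + 1.0j, -2.0 + 0.5j, 0.5 - 1.0j, -1.0 - 1.0j]

def product_score(x):
    # c_1(x) * |c_2(x)|^2
    return c1(x) * abs(c2_values[x]) ** 2

def sum_of_squares_score(x):
    # sum_i | sqrt(pi_i * m_i(x)) * c_2(x) |^2, one squared term per mixture component
    return sum(
        abs(math.sqrt(pi * m[x]) * c2_values[x]) ** 2
        for pi, m in zip(mix_weights, components)
    )

# Normalize the product to obtain a distribution
Z = sum(product_score(x) for x in domain)
p = [product_score(x) / Z for x in domain]
print([round(v, 4) for v in p])
```

In a deep monotonic circuit the number of mixture components grows exponentially with depth, while the parameters are shared across components, so the squared terms are never enumerated explicitly.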

Here, we start by constructing the monotonic and the complex circuits.

In [9]:
def build_symbolic_monotonic_circuit(region_graph: str) -> Circuit:
    return circuit_templates.image_data(
        (1, 28, 28),                 # The shape of MNIST image, i.e., (num_channels, image_height, image_width)
        region_graph=region_graph,
        # ----------- Input layers hyperparameters ----------- #
        input_layer='embedding',     # Use Embedding maps for the pixel values (0-255) as input layers
        num_input_units=4,           # Each input layer consists of 4 input units that output Embedding entries
        input_params={               # Set how to parameterize the input layers parameters
            # In this case we parameterize the 'weight' parameter of Embedding layers,
            # by choosing them to be parameterized with a softmax, and initialized by sampling from a standard normal distribution
            'weight': circuit_templates.Parameterization(activation='softmax', initialization='normal'),
        },
        # -------- Sum-product layers hyperparameters -------- #
        sum_product_layer='cp-t',    # Use CP-T sum-product layers, i.e., alternate Hadamard product layers and dense layers
        num_sum_units=4,             # Each dense sum layer consists of 4 sum units
        # Set how to parameterize the sum layers parameters
        # We parameterize them with a softmax, and initialize them by sampling from a standard normal distribution
        sum_weight_param=circuit_templates.Parameterization(activation='softmax', initialization='normal')
    )

The function above is very similar to the function we used to construct a symbolic complex circuit. There are two differences. First, we use fewer units, i.e., 4 instead of 32 input and sum units per layer. Second, we parameterize the embeddings and the sum weights using a softmax activation function, thus guaranteeing the non-negativity of the parameters, and therefore of the outputs of $c_1$.

By using the same region graph, we construct the symbolic complex and monotonic circuits.

In [10]:
# Build a symbolic monotonic circuit, i.e., c_1, by overparameterizing a Quad-Tree (4) region graph
symbolic_monotonic_circuit = build_symbolic_monotonic_circuit('quad-tree-4')

# Build a symbolic complex circuit, i.e., c_2, by overparameterizing the same region graph
symbolic_complex_circuit = build_symbolic_complex_circuit('quad-tree-4')

Since we used the same region graph, the two circuits will be _compatible_, thus satisfying the pre-conditions of the ```multiply``` operator.

In the following code snippet, we construct the symbolic circuit computing the partition function of our model, i.e., $Z = \int_{\mathrm{dom}(\mathbf{X})} c_1(\mathbf{x}) |c_2(\mathbf{x})|^2 \mathrm{d}\mathbf{x}$.

In [11]:
# Construct the circuit computing c_1(X) |c_2(X)|^2 = c_1(X) c_2(X) c_2(X)^*
symbolic_expsos_circuit = SF.multiply(
    symbolic_monotonic_circuit,
    SF.multiply(symbolic_complex_circuit, SF.conjugate(symbolic_complex_circuit))
)

# Construct the circuit computing Z, i.e., the integral of c_1(X) |c_2(X)|^2 over the complete domain of X
symbolic_circuit_partition_func = SF.integrate(symbolic_expsos_circuit)

As done in the previous section, we compile the circuits we need during learning. We observe we can decompose the log-likelihood $\log p(\mathbf{x})$ of a data point $\mathbf{x}$ as
$$\log p(\mathbf{x}) = -\log Z + \log c_1(\mathbf{x}) + 2 \log |c_2(\mathbf{x})|,$$
thus requiring us to compile three circuits: (1) the monotonic circuit $c_1$, (2) the complex circuit $c_2$, and (3) the circuit computing $Z$.

In [12]:
# Free-up some memory
del circuit, circuit_partition_func, ctx

# Instantiate the pipeline context
ctx = PipelineContext(
    backend='torch',  # Choose PyTorch as compilation backend
    semiring='complex-lse-sum',
    fold=True,     # Fold the circuit to better exploit GPU parallelism
    optimize=True  # Optimize the layers of the circuit
)

with ctx:  # Compile the circuits computing log c_1(X), log |c_2(X)|, and log |Z|
    monotonic_circuit = compile(symbolic_monotonic_circuit)
    complex_circuit = compile(symbolic_complex_circuit)
    circuit_partition_func = compile(symbolic_circuit_partition_func)

In the following, we use Adam as optimizer in PyTorch and optimize the learnable parameters of both the monotonic and complex circuits.

In [13]:
import itertools

# Initialize a torch optimizer of your choice,
#  e.g., Adam, by passing the parameters of the circuits
optimizer = optim.Adam(itertools.chain(monotonic_circuit.parameters(), complex_circuit.parameters()), lr=0.01)

As done for the complex squared circuit above, we optimize the parameters by minimizing the negative log-likelihood computed on MNIST images.

In [14]:
num_epochs = 15
step_idx = 0
running_loss = 0.0

# Move the circuits to chosen device
monotonic_circuit = monotonic_circuit.to(device)
complex_circuit = complex_circuit.to(device)

for epoch_idx in range(num_epochs):
    for i, (batch, _) in enumerate(train_dataloader):
        # The circuit expects an input of shape (batch_dim, num_channels, num_variables),
        # so we unsqueeze a dimension for the channel.
        batch = batch.to(device).unsqueeze(dim=1)

        # -------- Computation of the negative log-likelihoods loss -------- #
        # Compute the logarithm of the scores of the batch, by evaluating the circuits
        log_monotonic_scores = monotonic_circuit(batch).real    # log c_1(x)
        log_squared_scores = 2.0 * complex_circuit(batch).real  # 2 * log |c_2(x)|
        # Compute the log-partition function
        log_partition_func = circuit_partition_func().real  # log Z
        # Compute the log-likelihoods, log p(x) = log c_1(x) + 2 * log |c_2(x)| - log Z
        log_likelihoods = log_monotonic_scores + log_squared_scores - log_partition_func
        # We take the negated average log-likelihood as loss
        loss = -torch.mean(log_likelihoods)
        # ------------------------------------------------------------------ #

        # Update the parameters of the circuits, as any other model in PyTorch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.detach() * len(batch)
        step_idx += 1
        if step_idx % 300 == 0:
            print(f"Step {step_idx}: Average NLL: {running_loss / (300 * len(batch)):.3f}")
            running_loss = 0.0

Step 300: Average NLL: 1192.256
Step 600: Average NLL: 731.775
Step 900: Average NLL: 689.649
Step 1200: Average NLL: 665.281
Step 1500: Average NLL: 651.504
Step 1800: Average NLL: 642.295
Step 2100: Average NLL: 636.493
Step 2400: Average NLL: 628.150
Step 2700: Average NLL: 626.221
Step 3000: Average NLL: 623.579
Step 3300: Average NLL: 620.558


Then, we test the model by computing the bits-per-dimension on unseen MNIST images.

In [15]:
with torch.no_grad():
    # -------- Compute the log-partition function -------- #
    # Note that we need to do it just once, since we are not updating the parameters here
    log_partition_func = circuit_partition_func().real
    # ---------------------------------------------------- #

    test_lls = 0.0
    for batch, _ in test_dataloader:
        batch = batch.to(device).unsqueeze(dim=1)

        # -------- Compute the log-likelihoods of the unseen samples -------- #
        # Compute the logarithm of the scores of the batch, by evaluating the circuits
        log_monotonic_scores = monotonic_circuit(batch).real    # log c_1(x)
        log_squared_scores = 2.0 * complex_circuit(batch).real  # 2 * log |c_2(x)|
        # Compute the log-likelihoods
        log_likelihoods = log_monotonic_scores + log_squared_scores - log_partition_func
        # ------------------------------------------------------------------- #

        test_lls += log_likelihoods.sum().item()

    # Compute average test log-likelihood and bits per dimension
    average_ll = test_lls / len(data_test)
    bpd = -average_ll / (28 * 28 * np.log(2.0))
    print(f"Average test LL: {average_ll:.3f}")
    print(f"Bits per dimension: {bpd:.3f}")

Average test LL: -667.903
Bits per dimension: 1.229


Note that the bits-per-dimension dropped from 1.252 to 1.229 when compared to the complex squared circuit alone. This reduction comes with only a small increase in the total number of parameters, namely those of the monotonic circuit: as counted below, it adds about 0.8M parameters to the roughly 13.4M real-valued parameters of the complex circuit, i.e., around 6%.

In [16]:
# Count the learnable parameters; each complex parameter counts as two real numbers
num_monotonic_params = sum(t.numel() for t in monotonic_circuit.parameters() if t.requires_grad)
num_complex_params = sum(2 * t.numel() for t in complex_circuit.parameters() if t.requires_grad)
print(f"Monotonic circuit - Num. of parameters: {num_monotonic_params}")
print(f"Complex circuit - Num. of parameters: {num_complex_params}")

Monotonic circuit - Num. of parameters: 807044
Complex circuit - Num. of parameters: 13385792
