In [None]:
import random
import numpy as np
import torch

# Seed every source of randomness used in this notebook (Python, NumPy,
# PyTorch CPU and CUDA) so that runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently ignored when CUDA is unavailable

# Use the first GPU when one is available; otherwise fall back to the CPU so
# that the notebook can still be run end-to-end on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Semantic Probabilistic Layer using `cirkit`

In the [Semantic Loss](../semantic-loss) notebook we saw how to implement neuro-symbolic methods using `cirkit`. There, the circuit was used as a regularization term added to a standard loss, *encouraging* predictions that are consistent with some logical constraint $\phi$.

In this notebook we implement the **semantic probabilistic layer** (SPL) [(Ahmed et al., 2022)](https://arxiv.org/abs/2206.00426). SPL takes a different approach by combining a PC with a logical circuit: rather than merely encouraging consistency, it *guarantees* that every prediction satisfies $\phi$.

Recall that, given a logical constraint $\phi$, we can compile it into a circuit and implement it using `cirkit`. This very circuit is already a probabilistic circuit (PC). It defines a uniform, unnormalized distribution over the assignments that are models of $\phi$, and assigns zero probability to every other assignment. This is because all the sum units have unit weights.

By changing the parameters of the PC, we can *condition* its probability distribution on some external observation [(Shao et al., 2022)](https://www.sciencedirect.com/science/article/pii/S0888613X21001766). Crucially, doing so does not change the support of the circuit: only models of $\phi$ can have a non-zero probability.

SPL leverages such conditional circuits. The PC models the target labels of a classification task (for instance, the digit labels of MNIST images). Its parameters are predicted by a neural network that takes an MNIST image as input and outputs a parameter configuration for the PC.
At inference time, the prediction for an image is obtained by parameterizing the PC with the network's output for that image and computing the MAP state of the resulting distribution.

In this notebook, we show how to implement SPL with a simple constraint over the MNIST labels: each image belongs to exactly one class.

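We first construct the logical constraint $\phi$ (stored in the variable `alpha`) as an SDD over $N = 10$ Boolean variables $Y_1, \dots, Y_N$, one per class. At least one variable must be true, and no two variables can be true at the same time, i.e. $\phi = \left(\bigvee_{i=1}^{N} Y_i\right) \wedge \bigwedge_{1 \le i < j \le N} \left(\neg Y_i \vee \neg Y_j\right)$.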

In [19]:
from operator import and_, or_
from functools import reduce
from itertools import combinations

from pysdd.sdd import Vtree, SddManager

# One Boolean variable per class.
N = 10

# Balanced vtree over the variables 1..N, and an SDD manager built on it.
vtree = Vtree(N, list(range(1, N + 1)), "balanced")
manager = SddManager.from_vtree(vtree)

# "Exactly one" constraint, built as the conjunction of two parts:
#   - at least one variable is true:  Y_1 | Y_2 | ... | Y_N
#   - at most one variable is true:   ~Y_i | ~Y_j for every pair i < j
at_least_one = reduce(or_, manager.vars)
at_most_one = reduce(and_, (~a | ~b for a, b in combinations(manager.vars, 2)))
alpha = at_least_one & at_most_one

We then compile it into a symbolic circuit using `cirkit`.

In [20]:
from tempfile import NamedTemporaryFile
from cirkit.templates.logic import SDD
from cirkit.symbolic.layers import EmbeddingLayer
from cirkit.symbolic.parameters import Parameter, ConstantParameter
from cirkit.symbolic.io import plot_circuit

from IPython.display import Image

import os
from tempfile import TemporaryDirectory

# Round-trip the SDD through a file so that cirkit can parse it.
# A temporary *directory* is used rather than a NamedTemporaryFile: pysdd
# writes to the path through its own file handle, and reopening a
# NamedTemporaryFile by name while it is still open is not supported on
# Windows. (Flushing our own handle would also be a no-op, since nothing is
# written through it.)
with TemporaryDirectory() as tmp_dir:
    sdd_path = os.path.join(tmp_dir, "alpha.sdd")

    # export the SDD to a file
    alpha.save(sdd_path.encode())

    # parse the SDD representation using the SDD class
    alpha_sdd = SDD.from_file(sdd_path)

    # we make sure that the sum layers include a softmax activation in their
    # parameter graph
    alpha_symbolic = alpha_sdd.build_circuit(sum_weight_activation="softmax")

In this case, for the sake of simplicity, we only parameterize the sum units of the resulting PC. This allows us to train a single neural network to predict the parameters of the PC. Nonetheless, other combinations of parameters (including those of the input units) can be parameterized externally.

In [21]:
from cirkit.templates.logic import LiteralNode, NegatedLiteralNode
from cirkit.symbolic.functional import condition_circuit

# Collect the sum layers whose parameters are provided by the external gate
# function. Duplicates are removed with dict.fromkeys rather than set():
# a set of layer objects iterates in an arbitrary (typically identity/hash
# based) order that can change between runs, whereas dict.fromkeys keeps the
# order in which sum_layers yields them. This matters if the layer order
# determines which slice of the network output feeds which sum layer.
parametrization_map = {
    "sum": list(dict.fromkeys(alpha_symbolic.sum_layers))
}

# conditionally parametrize alpha by using the external gate function
conditional_alpha_symbolic, gf_specs = condition_circuit(alpha_symbolic, gate_functions=parametrization_map)

# report the parameter shape that each gate function must produce
for gf_k, gf_shape in gf_specs.items():
    print(f"Parameters {gf_k} need shape {gf_shape}")

Parameters sum.weight.0 need shape (9, 1, 2)


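We can now compile the circuit into a torch computational graph. Recall that, to make every sum unit compute a convex combination of its inputs, a softmax activation is applied to the sum-unit parameters (see the `sum_weight_activation="softmax"` argument used when building the circuit). The gate function registered below reshapes the flat output of the neural network into the shape required by the sum-unit parameters.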

In [22]:
from functools import partial
from cirkit.pipeline import PipelineContext

# Compile with the torch backend, folding layers, in the log-sum-exp semiring.
ctx = PipelineContext(backend="torch", semiring="lse-sum", fold=True)


def reshape_sum_weights(x):
    """Reshape a batch of flat network outputs into the sum-weight shape.

    The result has shape (-1, *gf_specs["sum.weight.0"]); presumably ``x`` is
    (batch, prod(gf_specs["sum.weight.0"])) -- TODO confirm.
    """
    return x.view(-1, *gf_specs["sum.weight.0"])


ctx.add_gate_function("sum.weight.0", reshape_sum_weights)
circuit = ctx.compile(conditional_alpha_symbolic)

We now load the MNIST dataset and create the data loaders used to train and evaluate the models.

In [23]:
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

from operator import mul

from cirkit.backend.torch.queries import MAPQuery

# Convert images to float32 tensors with pixel values scaled to [0, 1]
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# Load (downloading if necessary) the MNIST training and testing splits
data_train, data_test = (
    datasets.MNIST('datasets', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the testing loader keeps a fixed order
train_dataloader, test_dataloader = (
    DataLoader(split, shuffle=is_train, batch_size=256)
    for split, is_train in ((data_train, True), (data_test, False))
)

We first train a baseline MLP with one hidden layer of 256 units for 10 epochs.

In [None]:
num_epochs = 10

# Baseline: an MLP with one hidden layer of 256 units that outputs one logit
# per class.
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
).to(device)


# Initialize a torch optimizer of your choice
optimizer = optim.Adam(mlp.parameters(), lr=0.01)

for epoch_idx in range(num_epochs):
    # Accumulate the loss as a Python float: adding the loss tensor itself
    # would chain every batch's autograd graph into running_loss and keep it
    # alive until the end of the epoch.
    running_loss = 0.0
    running_samples = 0
    for x, y in train_dataloader:
        x, y = x.to(device), y.to(device)

        # compute the class logits using the MLP
        preds = mlp(x)

        loss = torch.nn.functional.cross_entropy(preds, y)

        # Update the parameters of the MLP, as any other model in PyTorch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item() * len(x)
        running_samples += len(x)

    average_nl = running_loss / running_samples

    # compute the accuracy on the testing set
    correct_predictions = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            pred = mlp(x.to(device))
            correct_predictions += (pred.argmax(dim=-1) == y.to(device)).sum().item()

    accuracy = correct_predictions / len(data_test)

    print(f"Epoch: {epoch_idx} | Loss: {average_nl} | Accuracy: {accuracy:.3f}")

Epoch: 0 | Loss: 0.22972959280014038 | Accuracy: 0.964
Epoch: 1 | Loss: 0.09498841315507889 | Accuracy: 0.976
Epoch: 2 | Loss: 0.07156120985746384 | Accuracy: 0.973
Epoch: 3 | Loss: 0.05734336003661156 | Accuracy: 0.975
Epoch: 4 | Loss: 0.048394735902547836 | Accuracy: 0.976
Epoch: 5 | Loss: 0.04697965085506439 | Accuracy: 0.974
Epoch: 6 | Loss: 0.037493132054805756 | Accuracy: 0.977
Epoch: 7 | Loss: 0.0397108793258667 | Accuracy: 0.975
Epoch: 8 | Loss: 0.03173552453517914 | Accuracy: 0.974
Epoch: 9 | Loss: 0.02695915289223194 | Accuracy: 0.973


Let's now train the SPL model. We use a similar MLP, again with one hidden layer of 256 units, to predict the parameters of all the sum units.

In [24]:
# Number of values the MLP must output: the product of the entries of the
# parameter shape required by the "sum.weight.0" gate function.
num_sum_weights = reduce(mul, gf_specs["sum.weight.0"])

# Same architecture as the baseline, but the output layer predicts the
# sum-unit parameters of the circuit instead of one logit per class.
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, num_sum_weights),
).to(device)


# The optimizer only receives the MLP parameters; the sum weights of the
# circuit are produced by the MLP through the gate function.
optimizer = optim.Adam(mlp.parameters(), lr=0.01)

# place the compiled circuit on the same device as the MLP
circuit = circuit.to(device)

# MAP query used at test time to extract a prediction from the circuit
map_query = MAPQuery(circuit)

In [25]:
num_epochs = 10

for epoch_idx in range(num_epochs):
    # Accumulate the loss as a Python float so that each batch's autograd
    # graph can be released right after its backward pass.
    running_loss = 0.0
    running_samples = 0
    for x, y in train_dataloader:
        x, y = x.to(device), y.to(device)

        # Parameterize the circuit with the MLP output for these images and
        # evaluate it on the one-hot encoded labels. The circuit is compiled
        # in the "lse-sum" semiring, so its outputs are log-space values.
        likelihoods = circuit(
            torch.nn.functional.one_hot(y, num_classes=10),
            gate_function_kwargs={
                "sum.weight.0": {"x": mlp(x)},
            })

        # negative log-likelihood averaged over the batch
        loss = -likelihoods.mean()

        # Update the parameters of the MLP that parameterizes the circuit
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item() * len(x)
        running_samples += len(x)

    average_nl = running_loss / running_samples

    # compute the accuracy on the testing set from the MAP state of the
    # conditionally parameterized circuit
    correct_predictions = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            _, pred = map_query(gate_function_kwargs={
                "sum.weight.0": {"x": mlp(x.to(device))},
            })
            # presumably pred is a (batch, 10) 0/1 assignment of the class
            # variables; argmax picks the index of the true one -- TODO confirm
            correct_predictions += (pred.argmax(dim=-1) == y.to(device)).sum().item()

    accuracy = correct_predictions / len(data_test)

    print(f"Epoch: {epoch_idx} | NL: {average_nl} | Accuracy: {accuracy:.3f}")

Epoch: 0 | NL: 0.27490881085395813 | Accuracy: 0.959
Epoch: 1 | NL: 0.12324030697345734 | Accuracy: 0.960
Epoch: 2 | NL: 0.08998344838619232 | Accuracy: 0.964
Epoch: 3 | NL: 0.07511863857507706 | Accuracy: 0.963
Epoch: 4 | NL: 0.06081372871994972 | Accuracy: 0.967
Epoch: 5 | NL: 0.058237019926309586 | Accuracy: 0.972
Epoch: 6 | NL: 0.051081474870443344 | Accuracy: 0.966
Epoch: 7 | NL: 0.050264522433280945 | Accuracy: 0.963
Epoch: 8 | NL: 0.04702281579375267 | Accuracy: 0.968
Epoch: 9 | NL: 0.04019814357161522 | Accuracy: 0.967
