# Learning a Circuit

TODO: explain what are we going to do (high level)

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt

## Load MNIST Dataset

TODO: stress we can use any library to load data sets, and everything will work

Load the training and test splits of MNIST, and preprocess them by flattening the tensor images.

In [None]:
from torchvision import transforms, datasets

# Convert each image to a tensor, then flatten it into a vector of integer
# pixel intensities (ToTensor rescales pixels to [0, 1], so we scale back by 255).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1).mul(255).long()),
    ]
)

# Training and test splits, downloaded into the local 'datasets' folder if missing
data_train = datasets.MNIST(
    'datasets', train=True, download=True, transform=transform
)
data_test = datasets.MNIST(
    'datasets', train=False, download=True, transform=transform
)

# Each flattened pixel is one variable of the model
num_variables = data_train[0][0].size(0)
height = 28
width = 28
print(f"Number of variables: {num_variables}")

## Instantiating a Circuit structure Template: the Region Graph

TODO: explain what a region graph is + add citation

Initialize a _Quad Tree_ region graph.

In [None]:
from cirkit.templates.region_graph import QuadTree
# Build a Quad Tree region graph over the (height, width) pixel grid.
# NOTE(review): `num_patch_splits=4` presumably splits each image patch into
# four sub-patches -- confirm against the QuadTree documentation.
region_graph = QuadTree(shape=(height, width), num_patch_splits=4)

## Constructing the Symbolic Circuit Representation

TODO: refactor this, explain symbolic circuits as an intermediate representation before compilation (link section)

From the region graph definition above, we now construct the symbolic circuit representation. Note that this circuit representation is _not_ executable, i.e., you cannot learn it or do inference with it. It will be compiled later, by choosing a backend such as torch.

To do so, we first define the factories that will be used to construct symbolic layers. Note that we choose the parameterization at the symbolic level. That is, we guarantee non-negative parameters by passing them through an exponential function. Moreover, we can choose how to parameterize the categorical distributions used to model the distribution of pixel values in the 0-255 range. In this case, we use a softmax function. We choose to initialize the weights of the circuit by sampling from a normal distribution.

In [None]:
from cirkit.utils.scope import Scope
from cirkit.symbolic.parameters import SoftmaxParameter, ExpParameter, TensorParameter, Parameter
from cirkit.symbolic.layers import CategoricalLayer, DenseLayer, KroneckerLayer, MixingLayer
from cirkit.symbolic.initializers import NormalInitializer

In [None]:
# TODO: remove sum/product/mixing factories and use the higher level APIs


# TODO: Remove categorical factory, Add string option in the lib
def categorical_layer_factory(
    scope: Scope,
    num_units: int,
    num_channels: int
) -> CategoricalLayer:
    """Build a symbolic categorical input layer with 256 categories.

    The probabilities are obtained by applying a softmax parameterization to
    an unconstrained tensor initialized with ``NormalInitializer(0.0, 1e-2)``.

    Args:
        scope: The scope passed to the layer.
        num_units: The number of units passed to the layer.
        num_channels: The number of channels passed to the layer.

    Returns:
        The symbolic categorical layer.
    """

    def softmax_probs(shape):
        # Wrap an unconstrained tensor parameter into a softmax
        softmax = SoftmaxParameter(shape)
        logits = TensorParameter(
            *shape, initializer=NormalInitializer(0.0, 1e-2)
        )
        return Parameter.from_unary(softmax, logits)

    layer = CategoricalLayer(
        scope,
        num_units,
        num_channels,
        num_categories=256,
        probs_factory=softmax_probs,
    )
    return layer


def kronecker_layer_factory(
    scope: Scope,
    num_input_units: int,
    arity: int
) -> KroneckerLayer:
    """Build a symbolic Kronecker layer; it takes no parameter factory.

    Args:
        scope: The scope passed to the layer.
        num_input_units: The number of input units passed to the layer.
        arity: The arity passed to the layer.

    Returns:
        The symbolic Kronecker layer.
    """
    layer = KroneckerLayer(scope, num_input_units, arity)
    return layer


def dense_layer_factory(
    scope: Scope,
    num_input_units: int,
    num_output_units: int
) -> DenseLayer:
    """Build a symbolic dense layer whose weights go through an exponential.

    The weights are obtained by applying an exponential parameterization to
    an unconstrained tensor initialized with ``NormalInitializer(0.0, 1e-1)``.

    Args:
        scope: The scope passed to the layer.
        num_input_units: The number of input units passed to the layer.
        num_output_units: The number of output units passed to the layer.

    Returns:
        The symbolic dense layer.
    """

    def exp_weight(shape):
        # Wrap an unconstrained tensor parameter into an exponential
        exp = ExpParameter(shape)
        raw_weight = TensorParameter(
            *shape, initializer=NormalInitializer(0.0, 1e-1)
        )
        return Parameter.from_unary(exp, raw_weight)

    layer = DenseLayer(
        scope,
        num_input_units,
        num_output_units,
        weight_factory=exp_weight,
    )
    return layer


def mixing_layer_factory(
    scope: Scope,
    num_units: int,
    arity: int
) -> MixingLayer:
    """Build a symbolic mixing layer whose weights go through an exponential.

    The weights are obtained by applying an exponential parameterization to
    an unconstrained tensor initialized with ``NormalInitializer(0.0, 1e-1)``.

    Args:
        scope: The scope passed to the layer.
        num_units: The number of units passed to the layer.
        arity: The arity passed to the layer.

    Returns:
        The symbolic mixing layer.
    """

    def exp_weight(shape):
        # Wrap an unconstrained tensor parameter into an exponential
        exp = ExpParameter(shape)
        raw_weight = TensorParameter(
            *shape, initializer=NormalInitializer(0.0, 1e-1)
        )
        return Parameter.from_unary(exp, raw_weight)

    layer = MixingLayer(scope, num_units, arity, weight_factory=exp_weight)
    return layer

Then, we call a function to construct the symbolic circuit from the region graph, by specifying the number of units and the factories to build layers.

In [None]:
from cirkit.symbolic.circuit import Circuit

In [None]:
# TODO: implement this

# NOTE(review): `templates` was not imported anywhere in this notebook. The
# `cirkit.templates` package exists (the region graph is imported from it),
# but whether it exposes `image_data` still has to be confirmed -- see the
# "implement this" TODO above.
from cirkit import templates

# Build the symbolic circuit for single-channel 28x28 images.
symbolic_circuit = templates.image_data(
    image_shape=(1, 28, 28),
    region_graph='quad-tree',
    num_input_units=8,           # a missing comma here used to be a SyntaxError
    num_sum_units=8,
    input_layer='categorical',   # TODO: implement this, rather than using the factory
    sum_product='cp',
    sum_weight_param='softmax'   # TODO: implement this instead of dense_weight_factory
)

TODO: discuss structural properties

We can retrieve some information about the circuit and its structural properties as follows.

In [None]:
# Report the structural properties of the symbolic circuit, one per line
for label, attribute in (
    ('Smooth', 'is_smooth'),
    ('Decomposable', 'is_decomposable'),
    ('Number of variables', 'num_variables'),
    ('Number of channels per variable', 'num_channels'),
):
    print(f'{label}: {getattr(symbolic_circuit, attribute)}')

## Compiling the Symbolic Circuit

TODO: explain compilation procedure, we choose the torch backend

We are ready to compile the symbolic circuit constructed above into another one that we can learn and/or do inference with. To do so, we have to choose a compilation backend. In this case, we choose torch as a backend.

In [None]:
import random

import torch

# Prefer the GPU, but fall back to the CPU so that the notebook also runs on
# machines without CUDA (a hard-coded 'cuda' device fails as soon as the
# circuit is moved to it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # The device to use

# Seed every random number generator in use, for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)  # Silently ignored when CUDA is not available

We first need to instantiate a circuit pipeline context and specify the backend to be used, as well as optional compilation flags, e.g., whether to fold the circuit or which inference semiring to use. Finally, we use the pipeline context to compile the symbolic circuit.

In [None]:
from cirkit.pipeline import compile

In [None]:
%%time
# Compile the symbolic circuit into an executable one.
# NOTE: `compile` here is the function imported from `cirkit.pipeline`, which
# shadows Python's built-in `compile` for the rest of the notebook.
circuit = compile(symbolic_circuit)

Note that the compilation step, comprising the folding optimization, required just a few seconds for a circuit with ~5000 layers and ~400M learnable parameters.

In [None]:
# Number of layers of the symbolic circuit
print(sum(1 for _ in symbolic_circuit.layers))
# Number of learnable parameters of the compiled circuit
print(
    sum(
        param.numel()
        for param in circuit.parameters()
        if param.requires_grad
    )
)

## Training and Testing

TODO: refactor this comment, stress the user can choose any optimizer

We are now ready to learn the parameters and do inference. First, we wrap our data into PyTorch data loaders by specifying the batch size. Then, we initialize any PyTorch optimizer, e.g. Adam in this case.

In [None]:
from torch import optim
from torch.utils.data import DataLoader

# Training batches are shuffled, and the last incomplete batch is dropped
train_dataloader = DataLoader(
    data_train,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)
# Test batches keep the dataset order and include the last incomplete batch
test_dataloader = DataLoader(
    data_test,
    batch_size=256,
    shuffle=False,
    num_workers=4,
)
# Any PyTorch optimizer can be used on the parameters of the compiled circuit
optimizer = optim.Adam(circuit.parameters(), lr=0.01)

In [None]:
# Move the compiled circuit to the selected device, so that it lives where
# the data batches are sent during training and evaluation
circuit = circuit.to(device)

In [None]:
# Train the circuit by minimizing the average negative log-likelihood (NLL).
#
# The loop used to subtract the output of `pf_circuit()`, a name that is never
# defined in this notebook, so it raised a NameError.
# NOTE(review): the log output of the circuit is used directly as the
# log-likelihood, which assumes the circuit is normalized by construction
# (softmax-parameterized sum weights and categorical inputs) -- confirm; if it
# is not, the log partition function must be subtracted again.

start_time = time.perf_counter()
num_epochs = 3
step_idx = 0
running_loss = 0.0
for epoch_idx in range(num_epochs):
    for batch, _ in train_dataloader:
        batch = batch.to(device).unsqueeze(dim=1)   # Add a channel dimension
        optimizer.zero_grad()                       # Start each step from clean gradients
        lls = circuit(batch)                        # Compute the log-likelihoods of the batch
        loss = -torch.mean(lls)   # The loss is the negative average log-likelihood
        loss.backward()
        optimizer.step()
        # Weight by the batch size so that the running sum is per-sample
        running_loss += loss.detach() * len(batch)
        step_idx += 1
        if step_idx % 100 == 0:
            # All batches have the same size (drop_last=True in the loader)
            print(f"Step {step_idx}: Average NLL: {running_loss / (100 * len(batch)):.3f}")
            running_loss = 0.0
end_time = time.perf_counter()
print(f"Training time: {end_time - start_time:.1f} seconds")

We then evaluate our model on test data by computing the average log-likelihood and bits per dimension.

In [None]:
# Evaluate on the test split: average log-likelihood and bits per dimension.
#
# The cell used to call `pf_circuit`, a name that is never defined in this
# notebook, so it raised a NameError.
# NOTE(review): the log output of the circuit is used directly as the
# log-likelihood, which assumes the circuit is normalized by construction
# (softmax-parameterized sum weights and categorical inputs) -- confirm; if it
# is not, the log partition function must be subtracted again.
circuit.eval()

with torch.no_grad():
    test_lls = 0.0
    for batch, _ in test_dataloader:
        batch = batch.to(device).unsqueeze(dim=1)   # Add a channel dimension
        lls = circuit(batch)                        # Compute the log-likelihoods of the batch
        test_lls += lls.sum().item()
    average_ll = test_lls / len(data_test)
    # Convert nats to bits and normalize by the number of variables
    bpd = -average_ll / (num_variables * np.log(2.0))
    print(f"Average test LL: {average_ll:.3f}")
    print(f"Bits per dimension: {bpd:.3f}")

TODO: show people we can do marginals, use integrate in cirkit.pipeline