# Conditional Circuits

Let's assume we want to parameterize a circuit by means of a neural network, i.e., build and learn a _conditional circuit_. We can do so in cirkit in three steps:
1. we instantiate the symbolic circuit we want to parameterize;
2. we call a functional that takes the symbolic circuit and returns another one that contains the additional information for the parameterization we want;
3. we register the parameterization with the compiler and then compile the symbolic circuit.

We start by instantiating a symbolic circuit on MNIST images.

In [1]:
from cirkit.templates import data_modalities, utils

symbolic_circuit = data_modalities.image_data(
    (1, 28, 28),                 # The shape of MNIST image, i.e., (num_channels, image_height, image_width)
    region_graph='quad-tree-4',  # Select the structure of the circuit to follow the QuadTree-4 region graph
    input_layer='categorical',   # Use Categorical distributions for the pixel values (0-255) as input layers
    num_input_units=64,          # Each input layer consists of 64 Categorical input units
    sum_product_layer='cp',      # Use CP sum-product layers, i.e., alternate dense layers with Hadamard product layers
    num_sum_units=64,            # Each dense sum layer consists of 64 sum units
    # ----- Do not use any parameterization for the sum weights and the Categorical probs ------ #
    sum_weight_param=utils.Parameterization(activation='none'),
    input_params={'probs': utils.Parameterization(activation='none')}
    # ------------------------------------------------------------------------------------------ #
)

Note that we did not specify any parameterization for the sum layer weights or for the probabilities of the Categorical input layers.

Then, we call the functional ```cirkit.symbolic.functional.model_parameterize``` to obtain another symbolic circuit that stores the additional information on how we want to parameterize it.

In [2]:
import cirkit.symbolic.functional as SF
from cirkit.symbolic.layers import CategoricalLayer, SumLayer

# Assume there exists a model called "my-neural-network" we will specify at compile-time later
symbolic_conditional_circuit, model_specs = SF.model_parameterize(
    symbolic_circuit,                          # The symbolic circuit we want to parameterize
    model_id="my-neural-network",              # The identifier of the neural network we will use
    filter_layers=[CategoricalLayer, SumLayer] # The layer filters. We want to parameterize both the Categorical and the sum layers
)

The ```model_parameterize``` functional also returns the shapes of the tensors to parameterize.

In [3]:
# The parameterize function returned the shape specification of the tensors we will need to return
model_specs

{'g0.CategoricalLayer.probs': (784, 64, 1, 256),
 'g1.SumLayer.weight': (784, 64, 64),
 'g2.SumLayer.weight': (196, 64, 64),
 'g3.SumLayer.weight': (49, 64, 64),
 'g4.SumLayer.weight': (15, 64, 64),
 'g5.SumLayer.weight': (4, 64, 64),
 'g6.SumLayer.weight': (1, 1, 64)}
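
To get a feeling for how much the neural network has to output, we can count the entries in these tensors. The sketch below copies the shapes printed above and groups them by layer class. It assumes the keys follow the pattern `<group>.<LayerClass>.<parameter>`, which is only inferred from the output above.

```python
from math import prod

# Shape specification copied from the output above
model_specs = {
    'g0.CategoricalLayer.probs': (784, 64, 1, 256),
    'g1.SumLayer.weight': (784, 64, 64),
    'g2.SumLayer.weight': (196, 64, 64),
    'g3.SumLayer.weight': (49, 64, 64),
    'g4.SumLayer.weight': (15, 64, 64),
    'g5.SumLayer.weight': (4, 64, 64),
    'g6.SumLayer.weight': (1, 1, 64),
}

# Number of scalar entries in each tensor
sizes = {name: prod(shape) for name, shape in model_specs.items()}

# Group the counts by layer class
# (assumes keys look like "<group>.<LayerClass>.<parameter>")
per_layer_type: dict[str, int] = {}
for name, size in sizes.items():
    layer_type = name.split('.')[1]
    per_layer_type[layer_type] = per_layer_type.get(layer_type, 0) + size

total = sum(sizes.values())
print(per_layer_type)  # {'CategoricalLayer': 12845056, 'SumLayer': 4292672}
print(total)           # 17137728
```

The network therefore has to produce about 17.1 million values, roughly three quarters of which are Categorical probabilities.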

Before compiling our conditional circuit, we define a neural network in PyTorch.

In [4]:
import torch
from torch import Tensor, nn

class MyNeuralNetwork(nn.Module):
    def __init__(self, output_specs: dict[str, tuple[int, ...]]):
        super().__init__()
        self._output_specs = output_specs
        # TODO: define some clever neural network ...
        ...

    def forward(self, z: Tensor) -> dict[str, Tensor]:
        # Evaluate my neural network on some input z
        # Return a dictionary mapping parameter names to actual parameter tensors
        # In this example, we just sample them randomly and use a softmax activation
        mean, stddev = torch.mean(z.flatten()), torch.std(z.flatten())
        return {
            name: torch.softmax(mean + stddev * torch.randn(*shape), dim=-1)
            for name, shape in self._output_specs.items()
        }
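
The class above ignores the TODO and samples its outputs randomly. As a minimal sketch of what a learnable version could look like, the module below encodes `z` with a small MLP and uses one linear head per parameter tensor. The names `SmallHyperNetwork`, `input_dim`, `hidden_dim` and `toy_specs` are hypothetical and not part of cirkit. The demo uses toy shapes, because one dense head per tensor would be very large for the full `model_specs`.

```python
import math

import torch
from torch import Tensor, nn


class SmallHyperNetwork(nn.Module):
    """Hypothetical learnable replacement for MyNeuralNetwork (illustration only)."""

    def __init__(self, output_specs: dict[str, tuple[int, ...]], input_dim: int, hidden_dim: int = 32):
        super().__init__()
        self._output_specs = dict(output_specs)
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # One linear head per parameter tensor. We use a ModuleList because
        # nn.ModuleDict keys cannot contain '.', which the parameter names do.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, math.prod(shape)) for shape in self._output_specs.values()]
        )

    def forward(self, z: Tensor) -> dict[str, Tensor]:
        # Average over the first dimension of z, so the parameters are not batch-dependent
        h = self.encoder(z).mean(dim=0)
        # Reshape each head output to the requested shape and normalize the last dimension
        return {
            name: torch.softmax(head(h).view(*shape), dim=-1)
            for (name, shape), head in zip(self._output_specs.items(), self.heads)
        }


# Toy shapes, much smaller than the real model_specs
toy_specs = {'g0.CategoricalLayer.probs': (4, 2, 1, 8), 'g1.SumLayer.weight': (4, 2, 2)}
net = SmallHyperNetwork(toy_specs, input_dim=127)
params = net(torch.randn(1, 127))
for name, tensor in params.items():
    print(name, tuple(tensor.shape))
```

Like `MyNeuralNetwork`, it returns a dictionary mapping each parameter name to a tensor of exactly the requested shape, normalized with a softmax over the last dimension.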

In [5]:
# Allocate our neural network
my_neural_net = MyNeuralNetwork(model_specs)

Next, we tell the compiler about our model by specifying the same model name we gave to the ```model_parameterize``` functional.

In [6]:
from cirkit.pipeline import PipelineContext

# Initialize a pipeline compilation context
# Let's try _without_ folding first
ctx = PipelineContext(semiring="lse-sum", backend='torch', fold=False, optimize=False)

# Register our neural network as an external model
ctx.add_external_model("my-neural-network", my_neural_net)

# Finally, we compile the conditional circuit
circuit = ctx.compile(symbolic_conditional_circuit)

In [7]:
circuit

TorchCircuit(
  (0): TorchCategoricalLayer(
    folds: 1  channels: 1  variables: 1  output-units: 64
    input-shape: (1, 1, -1, 1)
    output-shape: (1, -1, 64)
    (probs): TorchParameter(
      shape: (1, 64, 1, 256)
      (0): TorchModelParameter(output-shape: (1, 64, 1, 256))
    )
  )
  (1): TorchCategoricalLayer(
    folds: 1  channels: 1  variables: 1  output-units: 64
    input-shape: (1, 1, -1, 1)
    output-shape: (1, -1, 64)
    (probs): TorchParameter(
      shape: (1, 64, 1, 256)
      (0): TorchModelParameter(output-shape: (1, 64, 1, 256))
    )
  )
  (2): TorchCategoricalLayer(
    folds: 1  channels: 1  variables: 1  output-units: 64
    input-shape: (1, 1, -1, 1)
    output-shape: (1, -1, 64)
    (probs): TorchParameter(
      shape: (1, 64, 1, 256)
      (0): TorchModelParameter(output-shape: (1, 64, 1, 256))
    )
  )
  (3): TorchCategoricalLayer(
    folds: 1  channels: 1  variables: 1  output-units: 64
    input-shape: (1, 1, -1, 1)
    output-shape: (1, -1, 64)


Now, to evaluate the conditional circuit we specify additional arguments when calling it on some tensor, as shown below.

In [8]:
batch_size = 8
x = torch.randint(256, size=(batch_size, 1, 784))  # The circuit input
# TODO: make the parameters batch-dependent
z = torch.randn(size=(1, 127))  # Some dummy input to the neural net

# Evaluate the circuit on some input
# Note that we also pass some input to the external model
circuit(x, ext_model_kwargs={'my-neural-network': {'z': z}})

tensor([[[-4365.5845]],

        [[-4348.9673]],

        [[-4357.1748]],

        [[-4349.1304]],

        [[-4358.7705]],

        [[-4356.5312]],

        [[-4355.4453]],

        [[-4355.8472]]])
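
As a rough sanity check, we can compare these values against a simple baseline. We assume here that the outputs are log-likelihoods, which the `lse-sum` semiring and the softmax-normalized parameters suggest but the output itself does not state. A model that assigns uniform probability to each of the 256 values of each of the 784 pixels would give every image the following log-likelihood:

```python
import math

num_pixels = 28 * 28  # number of variables in an MNIST image
num_values = 256      # number of values each pixel can take

# Log-likelihood of any image under independent uniform pixel distributions
uniform_log_likelihood = -num_pixels * math.log(num_values)
print(round(uniform_log_likelihood, 2))  # -4347.42
```

The values above are slightly below this baseline, as we would expect for uniformly random inputs scored under randomly sampled parameters.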

The above parameterization is robust to changes in compilation flags, e.g., enabling folding and layer optimizations.

In [9]:
from cirkit.pipeline import PipelineContext

# Initialize a pipeline compilation context
# Let's enable folding and layer optimizations
ctx = PipelineContext(semiring="lse-sum", backend='torch', fold=True, optimize=True)

# Register our neural network as an external model
ctx.add_external_model("my-neural-network", my_neural_net)

# Finally, we compile the conditional circuit
circuit = ctx.compile(symbolic_conditional_circuit)

In [10]:
circuit

TorchCircuit(
  (0): TorchCategoricalLayer(
    folds: 784  channels: 1  variables: 1  output-units: 64
    input-shape: (784, 1, -1, 1)
    output-shape: (784, -1, 64)
    (probs): TorchParameter(
      shape: (784, 64, 1, 256)
      (0): TorchModelParameter(output-shape: (784, 64, 1, 256))
    )
  )
  (1): TorchSumLayer(
    folds: 784  arity: 1  input-units: 64  output-units: 64
    input-shape: (784, 1, -1, 64)
    output-shape: (784, -1, 64)
    (weight): TorchParameter(
      shape: (784, 64, 64)
      (0): TorchModelParameter(output-shape: (784, 64, 64))
    )
  )
  (2): TorchCPTLayer(
    folds: 196  arity: 4  input-units: 64  output-units: 64
    input-shape: (196, 4, -1, 64)
    output-shape: (196, -1, 64)
    (weight): TorchParameter(
      shape: (196, 64, 64)
      (0): TorchModelParameter(output-shape: (196, 64, 64))
    )
  )
  (3): TorchCPTLayer(
    folds: 49  arity: 4  input-units: 64  output-units: 64
    input-shape: (49, 4, -1, 64)
    output-shape: (49, -1, 64)
  

In [11]:
batch_size = 8
x = torch.randint(256, size=(batch_size, 1, 784))  # The circuit input
# TODO: make the parameters batch-dependent
z = torch.randn(size=(1, 127))  # Some dummy input to the neural net

# Evaluate the circuit on some input
# Note that we also pass some input to the external model
circuit(x, ext_model_kwargs={'my-neural-network': {'z': z}})

tensor([[[-4348.8794]],

        [[-4358.9419]],

        [[-4354.4893]],

        [[-4353.4937]],

        [[-4356.8423]],

        [[-4356.1177]],

        [[-4353.5776]],

        [[-4361.5161]]])
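
The numbers differ from the unfolded run because each cell draws a fresh `x` and `z`, and because `MyNeuralNetwork` samples new parameters with `torch.randn` on every call. The sketch below repeats the sampling scheme of `MyNeuralNetwork.forward` in a standalone function and shows that seeding PyTorch makes it reproducible. The names `sample_params` and `toy_specs` are hypothetical. It says nothing about the order in which a compiled circuit requests its parameters, so it does not by itself guarantee identical outputs across compilation flags.

```python
import torch


def sample_params(specs: dict[str, tuple[int, ...]], z: torch.Tensor) -> dict[str, torch.Tensor]:
    # Same sampling scheme as MyNeuralNetwork.forward
    mean, stddev = torch.mean(z.flatten()), torch.std(z.flatten())
    return {
        name: torch.softmax(mean + stddev * torch.randn(*shape), dim=-1)
        for name, shape in specs.items()
    }


toy_specs = {'g0.SumLayer.weight': (2, 3, 3)}  # toy shape, for illustration only
key = 'g0.SumLayer.weight'
z = torch.arange(4.0).unsqueeze(0)  # fixed dummy input with non-zero standard deviation

torch.manual_seed(0)
first = sample_params(toy_specs, z)
second = sample_params(toy_specs, z)  # no re-seeding: a fresh draw
torch.manual_seed(0)
third = sample_params(toy_specs, z)   # re-seeded: repeats the first draw

same_without_reseed = torch.equal(first[key], second[key])
same_with_reseed = torch.equal(first[key], third[key])
print(same_without_reseed, same_with_reseed)
```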