# Conditional Circuits

Let's assume we want to parameterize a circuit by means of a neural network, i.e., build and learn a _conditional circuit_. We can do so in cirkit in three steps:
1. we instantiate the symbolic circuit we want to parameterize;
2. we call a functional that takes the symbolic circuit and returns another one that contains the additional information for the parameterization we want;
3. we compile the resulting symbolic circuit, after first registering the parameterization with the compiler.
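The three steps can be mirrored in a toy, dependency-free sketch. Every name in it (`ToyCircuit`, `condition`, `ToyCompiler`) is hypothetical and is not cirkit's API. It also simplifies key naming: cirkit appears to stack layers with equal shapes under a single key (see the `gf_specs` output further down), whereas the toy gives each parameter its own key.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple


@dataclass
class ToyCircuit:
    # Hypothetical stand-in for a symbolic circuit: parameter name -> shape
    params: Dict[str, Tuple[int, ...]]
    # Parameter name -> gate key, for parameters provided by gate functions
    gated: Dict[str, str] = field(default_factory=dict)


def condition(circuit: ToyCircuit, groups: Dict[str, List[str]]):
    """Step 2 (toy): mark parameters as gated and return their shape specs."""
    gated = dict(circuit.gated)
    specs = {}
    for group, names in groups.items():
        for i, name in enumerate(names):
            key = f"{group}.{i}"
            gated[name] = key
            specs[key] = circuit.params[name]
    # The input circuit is left untouched; a new one is returned
    return ToyCircuit(dict(circuit.params), gated), specs


class ToyCompiler:
    """Step 3 (toy): refuses to compile until every gate key has a function."""

    def __init__(self) -> None:
        self.gate_functions: Dict[str, Callable] = {}

    def add_gate_function(self, key: str, fn: Callable) -> None:
        self.gate_functions[key] = fn

    def compile(self, circuit: ToyCircuit) -> ToyCircuit:
        missing = set(circuit.gated.values()) - set(self.gate_functions)
        if missing:
            raise ValueError(f"no gate function registered for {sorted(missing)}")
        return circuit


circuit = ToyCircuit({"w1": (4, 4), "w2": (1, 4)})          # step 1
cond, specs = condition(circuit, {"sum": ["w1", "w2"]})    # step 2
compiler = ToyCompiler()
for key in specs:
    compiler.add_gate_function(key, lambda z: z)
compiled = compiler.compile(cond)                          # step 3
print(specs)
```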

We start by instantiating a symbolic circuit over MNIST images.

In [2]:
from cirkit.templates import data_modalities, utils

symbolic_circuit = data_modalities.image_data(
    (1, 28, 28),                 # The shape of MNIST image, i.e., (num_channels, image_height, image_width)
    region_graph='quad-tree-4',  # Select the structure of the circuit to follow the QuadTree-4 region graph
    input_layer='categorical',   # Use Categorical distributions for the pixel values (0-255) as input layers
    num_input_units=64,          # Each input layer consists of 64 Categorical input units
    sum_product_layer='cp',      # Use CP sum-product layers, i.e., alternate dense layers with Hadamard product layers
    num_sum_units=64,            # Each dense sum layer consists of 64 sum units
)

Note that we did not specify any parameterization for the sum layer parameters and the logits of the Categorical input layers.

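Then, we call the functional ```cirkit.symbolic.functional.condition_circuit```, passing it a mapping from a group name to the layers whose parameters should be provided by gate functions. It returns another symbolic circuit that stores the additional information on how we want to parameterize it.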

In [3]:
import cirkit.symbolic.functional as SF
from cirkit.symbolic.layers import CategoricalLayer, SumLayer

parametrization_map = {
    "sum-layers": list(symbolic_circuit.sum_layers)
}

symbolic_conditional_circuit, gf_specs = SF.condition_circuit(
    symbolic_circuit,                          # The symbolic circuit we want to parameterize
    gate_functions=parametrization_map
)

The ```condition_circuit``` functional also returns the shapes of the tensors to parameterize.

In [4]:
# condition_circuit returned the shape specification of the tensors our gating functions will need to return
gf_specs

{'sum-layers.weight.0': (1048, 64, 64), 'sum-layers.weight.1': (1, 1, 64)}

Before compiling our conditional circuit, we define the gating functions. A gating function can be any function, as long as it outputs a tensor whose shape is compatible with its entry in the gate function specifications (`gf_specs`).
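In this notebook, the gating functions return one parameter tensor of the specified shape per batch element, stacked along a leading batch dimension (the test further down yields `(3, 1048, 64, 64)` for `'sum-layers.weight.0'` and a batch of 3). Assuming that convention, a hypothetical plain-Python helper (not part of cirkit) can check an output shape before a function is registered:

```python
from typing import Sequence, Tuple


def check_gate_output(spec: Sequence[int], output_shape: Sequence[int], batch_size: int) -> None:
    """Raise ValueError unless output_shape == (batch_size, *spec).

    Assumes the convention used in this notebook: one parameter tensor of
    shape `spec` per batch element, stacked along dimension 0.
    """
    expected: Tuple[int, ...] = (batch_size, *spec)
    if tuple(output_shape) != expected:
        raise ValueError(f"expected shape {expected}, got {tuple(output_shape)}")


# The specification reported above for the inner sum layers
spec = (1048, 64, 64)
check_gate_output(spec, (3, 1048, 64, 64), batch_size=3)  # passes silently
```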

Note that a gating function is only responsible for providing the parameter values: any additional operations declared in the symbolic circuit's parameter graph, such as activations, are preserved.

In [5]:
print("Parameter graph for a sum layer in the symbolic circuit:")
for n in next(symbolic_circuit.sum_layers).params["weight"].nodes:
    print("\t", n)

Parameter graph for a sum layer in the symbolic circuit:
	 TensorParameter(shape=(64, 64), learnable=True, dtype=2, initializer=NormalInitializer(mean=0.0, stddev=1.0))
	 <cirkit.symbolic.parameters.SoftmaxParameter object at 0x703441fad4c0>


In the original symbolic circuit, we have one tensor parameterizing the sum layer weights and a softmax activation that ensures they form a convex combination.
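As a reminder of what the softmax node guarantees, the plain-Python sketch below applies a numerically stable softmax along the last axis of a small weight matrix, the same axis that `torch.softmax(..., dim=-1)` uses later in this notebook. Every row becomes strictly positive and sums to one, i.e., a convex combination. The numbers are arbitrary.

```python
import math
from typing import List


def softmax_rows(matrix: List[List[float]]) -> List[List[float]]:
    """Apply a numerically stable softmax independently to each row."""
    out = []
    for row in matrix:
        m = max(row)  # subtract the row maximum to avoid overflow in exp
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


raw = [[1.0, -2.0, 0.5], [3.0, 3.0, -1.0]]  # unconstrained real-valued weights
weights = softmax_rows(raw)
for row in weights:
    print([round(w, 3) for w in row], "sum =", round(sum(row), 6))
```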

In [6]:
print("Parameter graph for a sum layer in the conditioned circuit:")
for n in next(symbolic_conditional_circuit.sum_layers).params["weight"].nodes:
    print("\t", n)

Parameter graph for a sum layer in the conditioned circuit:
	 GateFunctionParameter(shape=(64, 64), name='sum-layers.weight.0', index=0)
	 <cirkit.symbolic.parameters.SoftmaxParameter object at 0x703441fad4c0>


The conditional circuit has the same structure, with the crucial difference that the tensor parameter has been replaced by a gate function parameter.
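The `index=0` field, together with the `(1048, 64, 64)` specification for `'sum-layers.weight.0'`, suggests that one gate-function output is shared by a stack of sum layers, each reading its own `(64, 64)` slice. This is our reading, not a description of cirkit's internals. The toy sketch below (hypothetical `gate_slice`, nested lists instead of tensors) illustrates that bookkeeping.

```python
from typing import List


def gate_slice(gate_output: List[list], index: int) -> List[list]:
    """Select the parameters of the layer at `index` for every batch element.

    gate_output is nested as [batch][layer][...]; the result is [batch][...].
    Under our assumption, this is how a layer with a given `index` would read
    its weights from a stacked gate-function output.
    """
    return [sample[index] for sample in gate_output]


# Batch of 2 samples, 3 stacked layers, each with a 1x2 weight matrix
gate_output = [
    [[[0.1, 0.9]], [[0.5, 0.5]], [[0.7, 0.3]]],  # sample 0
    [[[0.2, 0.8]], [[0.4, 0.6]], [[0.0, 1.0]]],  # sample 1
]
layer1 = gate_slice(gate_output, index=1)
print(layer1)
```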

For a simple test, let's parameterize the sum layers of the circuit by randomly sampling their weights. To do so, we define gating functions that take an external tensor, say `z`, as input and output tensors whose shapes are compatible with the specifications.

In [7]:
import torch
from functools import partial

def random_sum_weights(shape, z: torch.Tensor):
    # compute, for each batch element, the mean and standard deviation of z over its last dimension
    mean, stddev = torch.mean(z, dim=-1), torch.std(z, dim=-1)
    # compute weights by randomly sampling
    samples = torch.randn(*shape)
    weight = mean.view(-1, 1, 1, 1) + stddev.view(-1, 1, 1, 1) * samples
    return weight

# test that the function outputs proper weights
weights = random_sum_weights(
    (3, *gf_specs["sum-layers.weight.0"]), 
    torch.randn(3, 256)
)

print("Weights shape:", weights.shape)

Weights shape: torch.Size([3, 1048, 64, 64])
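The expression `mean.view(-1, 1, 1, 1) + stddev.view(-1, 1, 1, 1) * samples` relies on broadcasting: reshaping the per-sample statistics to `(B, 1, 1, 1)` lets them combine with samples of shape `(1048, 64, 64)` into a `(B, 1048, 64, 64)` tensor. The helper below reimplements the standard NumPy/PyTorch broadcasting rule in plain Python, for illustration only (`torch.broadcast_shapes` performs the same computation), to show the resulting shapes.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Broadcast two shapes following the NumPy/PyTorch rule.

    Shapes are aligned from the right (missing leading sizes count as 1);
    each pair of sizes must be equal or contain a 1, and the result keeps
    the size that is not 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible sizes {x} and {y}")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


batch = 10
# statistics reshaped with .view(-1, 1, 1, 1) vs. samples drawn with the spec shape
print(broadcast_shape((batch, 1, 1, 1), (1048, 64, 64)))
# the root sum layer specification broadcasts the same way
print(broadcast_shape((batch, 1, 1, 1), (1, 1, 64)))
```

Without the `.view(-1, 1, 1, 1)`, the shape `(B,)` would be aligned with the *last* axis of the samples: `broadcast_shape((10,), (1048, 64, 64))` raises, and with a batch of exactly 64 it would even succeed silently with the wrong meaning.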


We can now register the gating functions with the compiler, which takes care of compiling the conditional circuit, keeping track of which function to call, and executing the functions efficiently.

In [8]:
from cirkit.pipeline import PipelineContext

# Initialize a pipeline compilation context
# Let's try _without_ folding first
ctx = PipelineContext(semiring="lse-sum", backend='torch', fold=False, optimize=False)

# Register one gating function per specification, binding its shape with partial
ctx.add_gate_function("sum-layers.weight.0", partial(random_sum_weights, gf_specs["sum-layers.weight.0"]))
ctx.add_gate_function("sum-layers.weight.1", partial(random_sum_weights, gf_specs["sum-layers.weight.1"]))

# Finally, we compile the conditional circuit
circuit = ctx.compile(symbolic_conditional_circuit)
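The registration above binds each shape specification to `random_sum_weights` with `functools.partial`, so the registered callable only needs the runtime keyword argument `z`. The plain-Python snippet sketches that pattern with a hypothetical `make_weights` gate function (it returns nested lists filled with the mean of each sample's `z` instead of a tensor) and a plain dict standing in for the compiler's registry; the shapes are shrunk for readability.

```python
from functools import partial
from typing import Callable, Dict, Tuple


def make_weights(shape: Tuple[int, ...], z: list) -> list:
    """Hypothetical gate function: one constant block of `shape` per batch element."""

    def block(dims: Tuple[int, ...], value: float):
        # Build a nested list of the given dimensions filled with `value`
        if not dims:
            return value
        return [block(dims[1:], value) for _ in range(dims[0])]

    # Each batch element's block is filled with the mean of its z vector
    return [block(shape, sum(zi) / len(zi)) for zi in z]


specs = {"sum-layers.weight.0": (2, 3, 3), "sum-layers.weight.1": (1, 1, 3)}

# A plain dict plays the role of the compiler's gate-function registry
registry: Dict[str, Callable] = {
    name: partial(make_weights, shape) for name, shape in specs.items()
}

z = [[1.0, 3.0], [0.0, 1.0]]  # batch of two conditioning vectors
out = registry["sum-layers.weight.1"](z=z)
print(out)
```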

We then evaluate the conditional circuit, passing the keyword arguments of each gating function through `gate_function_kwargs`.

In [9]:
x = torch.randint(256, size=(10, 784))  # The circuit input
z = torch.randn(size=(10, 127))  # Some dummy input to the gating functions

# Evaluate the circuit on some input
# Note that we also pass some input to the external model
circuit(x, gate_function_kwargs={'sum-layers.weight.0': {'z': z}, 'sum-layers.weight.1': {'z': z}})

tensor([[[-4357.3911]],

        [[-4362.3804]],

        [[-4361.9121]],

        [[-4357.7598]],

        [[-4358.0020]],

        [[-4352.8735]],

        [[-4358.7144]],

        [[-4353.7617]],

        [[-4352.5952]],

        [[-4359.7930]]], grad_fn=<TransposeBackward0>)
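The call above routes each entry of `gate_function_kwargs` to the gate function registered under the same name. The sketch below imitates that dispatch with a hypothetical `evaluate_gates` helper; it is not cirkit's implementation, and requiring an exact match between registered names and provided entries is an assumption we make for illustration.

```python
from typing import Any, Callable, Dict


def evaluate_gates(
    registry: Dict[str, Callable[..., Any]],
    gate_function_kwargs: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Call every registered gate function with its own keyword arguments."""
    missing = registry.keys() - gate_function_kwargs.keys()
    unknown = gate_function_kwargs.keys() - registry.keys()
    if missing or unknown:
        raise KeyError(f"missing={sorted(missing)}, unknown={sorted(unknown)}")
    return {name: fn(**gate_function_kwargs[name]) for name, fn in registry.items()}


# Toy gate functions operating on plain lists
registry = {
    "sum-layers.weight.0": lambda z: [2 * v for v in z],
    "sum-layers.weight.1": lambda z: [v + 1 for v in z],
}
z = [0.5, -1.0]
params = evaluate_gates(
    registry,
    {"sum-layers.weight.0": {"z": z}, "sum-layers.weight.1": {"z": z}},
)
print(params)
```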

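The above parameterization is robust to changes in the compilation flags: for example, we can enable folding and layer optimizations. The log-likelihoods below differ from the previous ones because the gating functions sample new random weights at every call.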

In [10]:
# folding and optimization are enabled
ctx = PipelineContext(semiring="lse-sum", backend='torch', fold=True, optimize=True)

ctx.add_gate_function("sum-layers.weight.0", partial(random_sum_weights, gf_specs["sum-layers.weight.0"]))
ctx.add_gate_function("sum-layers.weight.1", partial(random_sum_weights, gf_specs["sum-layers.weight.1"]))

circuit = ctx.compile(symbolic_conditional_circuit)

circuit(x, gate_function_kwargs={'sum-layers.weight.0': {'z': z}, 'sum-layers.weight.1': {'z': z}})

tensor([[[-4354.7549]],

        [[-4353.4976]],

        [[-4355.6870]],

        [[-4361.1377]],

        [[-4352.8804]],

        [[-4354.7881]],

        [[-4356.4692]],

        [[-4356.9868]],

        [[-4356.1367]],

        [[-4359.3481]]], grad_fn=<TransposeBackward0>)

The conditional parameterization is batch-dependent: each element of the batch is parameterized independently, with its own weights computed from its own input `z`.
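A tiny plain-Python illustration of what per-element parameters mean: a single sum unit over two inputs, where both batch elements see the same input values but, as a gate function would provide, different mixture weights. All numbers are made up.

```python
from typing import List


def sum_unit(input_probs: List[float], weights: List[float]) -> float:
    """Output of one sum unit: a weighted combination of its inputs."""
    return sum(w * p for w, p in zip(weights, input_probs))


# The same input values for both batch elements...
batch_inputs = [[0.2, 0.6], [0.2, 0.6]]
# ...but each element gets its own weights
batch_weights = [[0.75, 0.25], [0.25, 0.75]]

outputs = [sum_unit(p, w) for p, w in zip(batch_inputs, batch_weights)]
print(outputs)  # same inputs, different outputs
```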

We can see how this affects the output by deliberately producing invalid weights, so that the circuit no longer encodes a valid probability distribution. Specifically, we set all the weights of the first element in the batch to zero. Intuitively, the circuit should then produce a *strange* likelihood for the first batch element, while behaving normally on the others.

To do so, however, we have to disable the parameter activation in the symbolic circuit; otherwise, the softmax activation would act as a guard and map our weights back to a valid convex combination.
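To see why the softmax would hide the problem: zeroing the *inputs* of a softmax does not produce zero weights, it produces uniform ones. A plain-Python check:

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax of a 1-D list."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# All-zero logits are mapped to a uniform convex combination, not to zeros
print(softmax([0.0, 0.0, 0.0, 0.0]))  # [0.25, 0.25, 0.25, 0.25]
```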

In [11]:
from cirkit.templates.utils import Parameterization

symbolic_circuit = data_modalities.image_data(
    (1, 28, 28),
    region_graph='quad-tree-4',
    input_layer='categorical',
    num_input_units=64,
    sum_product_layer='cp',
    num_sum_units=64,
    sum_weight_param=Parameterization(activation="none"), # disable the softmax activation
)

symbolic_conditional_circuit, gf_specs = SF.condition_circuit(
    symbolic_circuit,                          # The symbolic circuit we want to parameterize
    gate_functions={
        "sum-layers": list(symbolic_circuit.sum_layers)
    }
)

print("Parameter graph for a sum layer in the conditioned circuit:")
for n in next(symbolic_conditional_circuit.sum_layers).params["weight"].nodes:
    print("\t", n)

Parameter graph for a sum layer in the conditioned circuit:
	 GateFunctionParameter(shape=(64, 64), name='sum-layers.weight.0', index=0)


The softmax node is gone: the sum weights are now parameterized by the gate function alone, which is therefore solely responsible for normalizing them.
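Since the gate function now carries the responsibility for normalization, it can be worth validating its output during development. Below is a hypothetical plain-Python checker over nested lists (with tensors one would use the equivalent tensor operations) that reports which batch elements are not valid convex combinations along the last axis.

```python
from typing import List


def invalid_batch_elements(weights: list, tol: float = 1e-6) -> List[int]:
    """Return the indices b whose weights[b] contains a non-normalized vector.

    `weights` is nested as [batch][...][last_axis]; every innermost list must
    be non-negative and sum to one (within `tol`).
    """

    def rows(block):
        # Yield the innermost lists, i.e. the vectors normalized along the last axis
        if block and isinstance(block[0], list):
            for sub in block:
                yield from rows(sub)
        else:
            yield block

    bad = []
    for b, sample in enumerate(weights):
        for row in rows(sample):
            if any(w < 0 for w in row) or abs(sum(row) - 1.0) > tol:
                bad.append(b)
                break
    return bad


# Two batch elements with 2x2 weight matrices; the first one is all zeros
weights = [
    [[0.0, 0.0], [0.0, 0.0]],
    [[0.3, 0.7], [0.5, 0.5]],
]
print(invalid_batch_elements(weights))
```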

In [12]:
def random_sum_weights_zero_first_sample(shape, z: torch.Tensor):
    # compute, for each batch element, the mean and standard deviation of z over its last dimension
    mean, stddev = torch.mean(z, dim=-1), torch.std(z, dim=-1)
    # compute weights by randomly sampling
    samples = torch.randn(*shape)
    weight = mean.view(-1, 1, 1, 1) + stddev.view(-1, 1, 1, 1) * samples
    
    # manually apply the softmax activation
    weight = torch.softmax(weight, dim=-1)

    # set all the weights of the first batch element to zero
    weight[0] = 0
    return weight

# register the new gate function and compile the circuit
ctx = PipelineContext(semiring="lse-sum", backend='torch', fold=True, optimize=True)
ctx.add_gate_function("sum-layers.weight.0", partial(random_sum_weights_zero_first_sample, gf_specs["sum-layers.weight.0"]))
ctx.add_gate_function("sum-layers.weight.1", partial(random_sum_weights_zero_first_sample, gf_specs["sum-layers.weight.1"]))
circuit = ctx.compile(symbolic_conditional_circuit)

# run the circuit on the same dummy inputs
circuit(x, gate_function_kwargs={'sum-layers.weight.0': {'z': z}, 'sum-layers.weight.1': {'z': z}})

tensor([[[      -inf]],

        [[-4364.7510]],

        [[-4356.0972]],

        [[-4359.3677]],

        [[-4361.9619]],

        [[-4366.1802]],

        [[-4354.0938]],

        [[-4353.7583]],

        [[-4350.3306]],

        [[-4357.8169]]], grad_fn=<TransposeBackward0>)

And indeed, the first batch element evaluates to a log-likelihood of $-\infty$ (i.e., a likelihood of zero), while the other elements still receive finite values.
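The $-\infty$ follows from evaluating a sum unit in log space, as the `lse-sum` semiring does: the unit computes $\log \sum_i w_i \exp(\ell_i)$, where $\ell_i$ are the input log-likelihoods. If all $w_i = 0$, the sum is zero and its logarithm is $-\infty$. A plain-Python sketch with made-up inputs:

```python
import math
from typing import List


def log_sum_unit(log_inputs: List[float], weights: List[float]) -> float:
    """Compute log(sum_i w_i * exp(l_i)) with the log-sum-exp trick.

    Zero weights have log-weight -inf and contribute nothing; if every weight
    is zero the result is -inf.
    """
    terms = [math.log(w) + l for w, l in zip(weights, log_inputs) if w > 0]
    if not terms:
        return -math.inf  # log(0)
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


log_inputs = [-2.0, -5.0, -1.0]
print(log_sum_unit(log_inputs, [0.2, 0.3, 0.5]))  # finite log-likelihood
print(log_sum_unit(log_inputs, [0.0, 0.0, 0.0]))  # -inf, as for the first batch element
```

Since product layers add log-values in this semiring, a single $-\infty$ propagates all the way to the circuit output.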