# Sum of Squares Circuits

## Complex Squared Circuits

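A complex squared circuit represents a probability distribution as $p(\mathbf{x}) = |c(\mathbf{x})|^2 / Z$, where $c$ is a circuit with complex-valued parameters and $Z$ is the partition function obtained by integrating $|c(\mathbf{x})|^2$ over all the variables. We start by defining a function that builds the symbolic circuit $c$ over MNIST images for a given region graph, using complex-valued Embedding input layers and complex-valued sum layer weights.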

In [25]:
from cirkit.templates import circuit_templates
from cirkit.symbolic.circuit import Circuit

def build_symbolic_complex_circuit(region_graph: str) -> Circuit:
    """Build a symbolic circuit over MNIST images with complex-valued parameters.

    The circuit is obtained from the ``image_data`` template, using Embedding
    input layers and CP-T sum-product layers, with 32 units per layer.

    Args:
        region_graph: The name of the region graph passed to the template,
            e.g., 'quad-tree-4'.

    Returns:
        The symbolic circuit built by ``circuit_templates.image_data``.
    """
    # The shape of MNIST images, i.e., (num_channels, image_height, image_width)
    mnist_image_shape = (1, 28, 28)

    # The 'weight' parameter of the Embedding input layers is chosen to be complex-valued,
    # with real and imaginary parts sampled uniformly in [0, 1)
    embedding_weight_param = circuit_templates.Parameterization(
        dtype='complex', initialization='uniform'
    )

    # The sum layer weights are parameterized in the same way: complex-valued,
    # with real and imaginary parts sampled uniformly in [0, 1)
    sum_layer_weight_param = circuit_templates.Parameterization(
        dtype='complex', initialization='uniform'
    )

    return circuit_templates.image_data(
        mnist_image_shape,
        region_graph=region_graph,
        # Input layers: Embedding maps over the pixel values (0-255),
        # each layer consisting of 32 input units
        input_layer='embedding',
        num_input_units=32,
        input_params={'weight': embedding_weight_param},
        # Sum-product layers: CP-T layers, i.e., alternating hadamard product
        # layers and dense layers, each dense sum layer consisting of 32 sum units
        sum_product_layer='cp-t',
        num_sum_units=32,
        sum_weight_param=sum_layer_weight_param,
    )

In [14]:
# Build the symbolic complex circuit using the 'quad-tree-4' region graph
symbolic_circuit = build_symbolic_complex_circuit(region_graph='quad-tree-4')

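To normalize the squared circuit we need its partition function $Z$. We build it symbolically: we multiply the circuit by its complex conjugate, which gives a circuit computing $c(\mathbf{x}) \, c(\mathbf{x})^* = |c(\mathbf{x})|^2$, and then we integrate the resulting product circuit.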

In [15]:
import cirkit.symbolic.functional as SF

# Step 1: take the complex conjugate of the symbolic circuit
symbolic_conjugate_circuit = SF.conjugate(symbolic_circuit)

# Step 2: multiply the circuit by its conjugate
symbolic_squared_circuit = SF.multiply(symbolic_circuit, symbolic_conjugate_circuit)

# Step 3: integrate the product circuit, obtaining the symbolic circuit
# that is later compiled and evaluated as the partition function
symbolic_circuit_partition_func = SF.integrate(symbolic_squared_circuit)

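Next, we compile both the symbolic circuit and the symbolic circuit computing its partition function. Both are compiled within the same pipeline context, which uses PyTorch as backend and the `complex-lse-sum` semiring.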

In [16]:
from cirkit.pipeline import PipelineContext, compile

# Set up the pipeline context that controls how symbolic circuits are compiled
ctx = PipelineContext(
    backend='torch',             # Compile to the torch backend
    fold=True,                   # Enable the folding option
    optimize=True,               # Enable the optimization option
    # NOTE(review): presumably this semiring evaluates the circuits in log-space
    # over complex numbers -- confirm against the cirkit documentation
    semiring='complex-lse-sum',
)

# Compile the circuit and its partition function circuit in the same context
with ctx:
    circuit = compile(symbolic_circuit)
    circuit_partition_func = compile(symbolic_circuit_partition_func)

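We now load the MNIST data set. Each image is flattened into a vector of $28 \times 28 = 784$ integer pixel values in the range 0-255, which are the inputs expected by the Embedding layers. We also instantiate the data loaders and the optimizer over the parameters of the circuit.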

In [None]:
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Pre-processing applied to every MNIST image: convert it to a tensor, then
# flatten it and map the pixel values to integers in the [0-255] range
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda image: (255 * image.view(-1)).long()),
    ]
)

# Download (if needed) and load the training and testing splits of MNIST
data_train = datasets.MNIST(root='datasets', train=True, download=True, transform=transform)
data_test = datasets.MNIST(root='datasets', train=False, download=True, transform=transform)

# Data loaders: only the training data is shuffled
train_dataloader = DataLoader(dataset=data_train, batch_size=256, shuffle=True)
test_dataloader = DataLoader(dataset=data_test, batch_size=256, shuffle=False)

# Any torch optimizer can be used here: it only needs the parameters of the circuit.
# In this case we use Adam
optimizer = optim.Adam(params=circuit.parameters(), lr=0.01)

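We train the circuit by minimizing the negative log-likelihood. Since the compiled circuit outputs the complex logarithm of $c(\mathbf{x})$, the log-likelihood of a data point is $\log p(\mathbf{x}) = 2 \, \mathrm{Re}(\log c(\mathbf{x})) - \mathrm{Re}(\log Z)$, where $\log Z$ is computed by evaluating the partition function circuit at every optimization step.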

In [None]:
import torch

num_epochs = 5
step_idx = 0
# Sum of the per-sample losses, and number of samples seen, since the last report
running_loss = 0.0
running_samples = 0

# Choose the device to train on: use a GPU if available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the circuit to chosen device
# NOTE(review): circuit_partition_func is not moved explicitly; presumably it shares
# its parameters with circuit since they are compiled in the same context -- confirm
circuit = circuit.to(device)

for epoch_idx in range(num_epochs):
    for batch, _ in train_dataloader:
        # The circuit expects an input of shape (batch_dim, num_channels, num_variables),
        # so we unsqueeze a dimension for the channel.
        batch = batch.to(device).unsqueeze(dim=1)

        # Compute the logarithm of the squared scores of the batch, by evaluating the circuit
        log_scores = circuit(batch)
        log_squared_scores = 2.0 * log_scores.real

        # Compute the log-partition function. This must be re-evaluated at every step,
        # as the parameters change after each optimizer update
        log_partition_func = circuit_partition_func().real

        # Compute the log-likelihoods
        log_likelihoods = log_squared_scores - log_partition_func

        # We take the negated average log-likelihood as loss
        loss = -torch.mean(log_likelihoods)
        loss.backward()
        # Update the parameters of the circuits, as any other model in PyTorch
        optimizer.step()
        optimizer.zero_grad()

        # Accumulate the loss weighted by the batch size, and count the samples seen,
        # so that the reported average is correct even when the last batch of an
        # epoch is smaller than the others
        running_loss += loss.detach() * len(batch)
        running_samples += len(batch)
        step_idx += 1
        if step_idx % 100 == 0:
            print(f"Step {step_idx}: Average NLL: {running_loss / running_samples:.3f}")
            running_loss = 0.0
            running_samples = 0

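Finally, we evaluate the trained circuit on the test set, by computing the average test log-likelihood and the corresponding bits per dimension.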

In [None]:
import numpy as np
import torch

with torch.no_grad():
    test_lls = 0.0

    # Compute the log-partition function. The parameters are fixed during evaluation,
    # so it is enough to compute it once, outside of the loop
    log_partition_func = circuit_partition_func().real

    for batch, _ in test_dataloader:
        # The circuit expects an input of shape (batch_dim, num_channels, num_variables),
        # so we unsqueeze a dimension for the channel.
        batch = batch.to(device).unsqueeze(dim=1)

        # Compute the logarithm of the squared scores of the batch, by evaluating the circuit
        log_scores = circuit(batch)
        log_squared_scores = 2.0 * log_scores.real

        # Compute the log-likelihoods of the batch, by normalizing the squared scores
        # exactly as done during training
        log_likelihoods = log_squared_scores - log_partition_func

        # Accumulate the log-likelihoods
        test_lls += log_likelihoods.sum().item()

    # Compute average test log-likelihood and bits per dimension
    average_ll = test_lls / len(data_test)
    bpd = -average_ll / (28 * 28 * np.log(2.0))
    print(f"Average test LL: {average_ll:.3f}")
    print(f"Bits per dimension: {bpd:.3f}")

TODO: write