ProbabilisticCircuits.jl offers various parameter and structure learning algorithms for PCs. In this example, we demonstrate how to generate a particular PC model termed *Hidden Chow-Liu Tree (HCLT)* and use it to train a density (generative) model with **state-of-the-art (SoTA)** likelihood on MNIST in around **5 minutes**. In comparison, as shown in the table below, the VAE and flow baselines have worse likelihoods on MNIST and EMNIST. They also typically need several hours to train, which is over **10x slower than PCs**.

| Dataset | PC (ours) | IDF | Hierarchical VAE | PixelVAE |
| :--- | :---: | :---: | :---: | :---: |
| MNIST | **1.20** | 2.90 | 1.27 | 1.39 |
| FashionMNIST | 3.34 | 3.47 | **3.28** | 3.66 |
| EMNIST (Letter split) | **1.80** | 1.95 | 1.84 | 2.26 |
| EMNIST (ByClass split) | **1.85** | 1.98 | 1.87 | 2.23 |

\* Note: all reported numbers are in bits-per-dimension (bpd); lower is better. The results are taken from [1].

[1] Anji Liu, Stephan Mandt, and Guy Van den Broeck. Lossless Compression with Probabilistic Circuits. In International Conference on Learning Representations (ICLR), 2022.

We start by importing ProbabilisticCircuits.jl and other required packages:

In [1]:
using ProbabilisticCircuits
using MLDatasets
using CUDA

We first load the MNIST dataset from MLDatasets.jl, flatten each 28x28 image into a row of 784 features, and move the data to the GPU:

In [2]:
mnist_train_cpu = collect(transpose(reshape(MNIST.traintensor(UInt8), 28*28, :)))
mnist_test_cpu = collect(transpose(reshape(MNIST.testtensor(UInt8), 28*28, :)))
mnist_train_gpu = cu(mnist_train_cpu)
mnist_test_gpu = cu(mnist_test_cpu)
println("Dataset summary:\n - Number of training examples: $(size(mnist_train_cpu, 1))\n - Number of test examples: $(size(mnist_test_cpu, 1))\n - Number of features: $(size(mnist_train_cpu, 2))")

Dataset summary:
 - Number of training examples: 60000
 - Number of test examples: 10000
 - Number of features: 784
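
The reshape-and-transpose pattern used above can be checked on a small stand-in tensor. This is a minimal sketch in plain Julia; the array `toy` (four 2x3 "images") is made up for illustration and is not part of the dataset.

```julia
# Toy stand-in for the image tensor: height x width x number of images.
toy = reshape(collect(UInt8, 1:24), 2, 3, 4)

# Same pattern as in the cell above: flatten each image into a column,
# then transpose so that every row is one example.
flat = collect(transpose(reshape(toy, 2 * 3, :)))

println(size(flat))   # (number of examples, number of features)
println(flat[1, :])   # pixels of the first image, in column-major order
```

Each row of `flat` holds one flattened image, which matches the examples-by-features layout of `mnist_train_cpu`.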


We move on to generate the HCLT structure. `hclt` constructs a smooth and structured-decomposable PC whose structure depends on the input samples. Specifically, it computes the pairwise mutual information (MI) between the MNIST features (i.e., pixels) and uses the pairwise MI matrix to determine the PC structure, such that highly correlated features are placed "closer" in the PC to facilitate learning. In the following, `bits` is the number of low-order bits dropped from each pixel value to speed up the pairwise MI computation, and `latents` specifies the size of the generated HCLT.

In [3]:
bits = 4
latents = 32
println("Generating HCLT structure with $latents latents... ");
trunc_train = cu(mnist_train_cpu .÷ 2^bits)
@time pc = hclt(trunc_train, latents; num_cats = 256, pseudocount = 0.1, input_type = Categorical)
init_parameters(pc; perturbation = 0.4)
println("Number of free parameters: $(num_parameters(pc))")

Generating HCLT structure with 32 latents... 
 24.203915 seconds (82.55 M allocations: 6.716 GiB, 5.69% gc time, 50.04% compilation time)
Number of free parameters: 6980767
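
To make the two preprocessing ideas above concrete, the sketch below truncates a few pixel values and computes a pairwise MI matrix on a toy dataset in plain Julia. It is not the code that `hclt` runs internally; the helper names `mutual_information` and `pairwise_mi` and the toy data are hypothetical and only illustrate the quantities involved.

```julia
# Dropping the 4 low-order bits maps 256 pixel values onto 16 buckets.
truncated = UInt8[0, 15, 16, 200, 255] .÷ 2^4

# Empirical mutual information (in nats) between two columns whose
# entries are categories in 0:(num_cats - 1).
function mutual_information(x::AbstractVector{<:Integer},
                            y::AbstractVector{<:Integer}, num_cats::Integer)
    joint = zeros(Float64, num_cats, num_cats)
    for i in eachindex(x, y)
        joint[x[i] + 1, y[i] + 1] += 1.0   # shift to 1-based indices
    end
    joint ./= length(x)
    px = vec(sum(joint; dims = 2))         # marginal of x
    py = vec(sum(joint; dims = 1))         # marginal of y
    mi = 0.0
    for a in 1:num_cats, b in 1:num_cats
        p = joint[a, b]
        p > 0 && (mi += p * log(p / (px[a] * py[b])))
    end
    return mi
end

# Symmetric matrix of MI values between all pairs of features (columns).
function pairwise_mi(data::AbstractMatrix{<:Integer}, num_cats::Integer)
    d = size(data, 2)
    out = zeros(Float64, d, d)
    for i in 1:d, j in i:d
        out[i, j] = out[j, i] =
            mutual_information(view(data, :, i), view(data, :, j), num_cats)
    end
    return out
end

# Toy data: columns 1 and 2 are identical, column 3 is independent of both.
toy = [0 0 1; 1 1 1; 0 0 0; 1 1 0]
mi_matrix = pairwise_mi(toy, 2)
println(truncated)
println(mi_matrix)
```

In this toy example the identical columns share the maximum possible MI for a binary feature, while the independent column has zero MI with the others. A structure learner that follows the description above would place the first two features close together.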


To facilitate efficient parameter learning on GPUs, we first convert `pc` into an equivalent GPU-friendly low-level representation termed bits-circuit:

In [4]:
print("Moving pc to GPU... ")
CUDA.@time bpc = CuBitsProbCircuit(pc);

Moving pc to GPU...   1.974239 seconds (17.74 M CPU allocations: 1.215 GiB, 7.59% gc time) (7 GPU allocations: 76.784 MiB, 0.00% memmgmt time)


We are now ready to train the parameters of the PC. This is done by calling the high-level APIs `mini_batch_em` and `full_batch_em`:

In [5]:
num_epochs1       = 100
num_epochs2       = 250
num_epochs3       = 1
batch_size        = 512
pseudocount       = 0.1
param_inertia1    = 0.1 
param_inertia2    = 0.9
param_inertia3    = 0.95

@time begin
    mini_batch_em(bpc, mnist_train_gpu, num_epochs1; batch_size, pseudocount, 
                  param_inertia = param_inertia1, param_inertia_end = param_inertia2);
    mini_batch_em(bpc, mnist_train_gpu, num_epochs2; batch_size, pseudocount, 
                  param_inertia = param_inertia2, param_inertia_end = param_inertia3);
    full_batch_em(bpc, mnist_train_gpu, num_epochs3; batch_size, pseudocount);
end;

Mini-batch EM epoch 1; train LL -913.08734
Mini-batch EM epoch 2; train LL -847.4659
Mini-batch EM epoch 3; train LL -841.1994
Mini-batch EM epoch 4; train LL -836.8761
Mini-batch EM epoch 5; train LL -833.25433
Mini-batch EM epoch 6; train LL -830.4178
Mini-batch EM epoch 7; train LL -827.6333
Mini-batch EM epoch 8; train LL -825.24286
Mini-batch EM epoch 9; train LL -822.84735
Mini-batch EM epoch 10; train LL -820.6389
Mini-batch EM epoch 11; train LL -818.42413
Mini-batch EM epoch 12; train LL -816.32837
Mini-batch EM epoch 13; train LL -814.23914
Mini-batch EM epoch 14; train LL -812.36127
Mini-batch EM epoch 15; train LL -810.39795
Mini-batch EM epoch 16; train LL -808.6348
Mini-batch EM epoch 17; train LL -806.7153
Mini-batch EM epoch 18; train LL -804.8232
Mini-batch EM epoch 19; train LL -803.06384
Mini-batch EM epoch 20; train LL -801.2593
Mini-batch EM epoch 21; train LL -799.64825
Mini-batch EM epoch 22; train LL -797.8991
Mini-batch EM epoch 23; train LL -796.13086
...

Mini-batch EM epoch 138; train LL -664.47217
Mini-batch EM epoch 139; train LL -664.4386
Mini-batch EM epoch 140; train LL -664.33105
Mini-batch EM epoch 141; train LL -664.3288
Mini-batch EM epoch 142; train LL -664.2715
Mini-batch EM epoch 143; train LL -664.1785
Mini-batch EM epoch 144; train LL -664.0954
Mini-batch EM epoch 145; train LL -664.043
Mini-batch EM epoch 146; train LL -664.0289
Mini-batch EM epoch 147; train LL -664.0166
Mini-batch EM epoch 148; train LL -663.8913
Mini-batch EM epoch 149; train LL -663.7863
Mini-batch EM epoch 150; train LL -663.75977
Mini-batch EM epoch 151; train LL -663.7384
Mini-batch EM epoch 152; train LL -663.62286
Mini-batch EM epoch 153; train LL -663.56195
Mini-batch EM epoch 154; train LL -663.57666
Mini-batch EM epoch 155; train LL -663.42114
Mini-batch EM epoch 156; train LL -663.4205
Mini-batch EM epoch 157; train LL -663.3667
Mini-batch EM epoch 158; train LL -663.3468
Mini-batch EM epoch 159; train LL -663.29236
...
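
The role of `param_inertia` and `param_inertia_end` in the calls above can be pictured with a small sketch. It assumes, without confirmation from this tutorial, that the inertia moves linearly from its start value to its end value over the epochs and that each update is a convex blend of the old parameters and the new mini-batch estimate. The helpers `inertia_at` and `blend` are hypothetical and are not library functions.

```julia
# Hypothetical helper: inertia used in a given epoch, assuming a linear
# schedule from `first` to `last` (an assumption, not the library's code).
inertia_at(epoch, num_epochs, first, last) =
    first + (last - first) * (epoch - 1) / num_epochs

# Hypothetical helper: keep a fraction `inertia` of the old parameters and
# take the rest from the new mini-batch estimate.
blend(old, estimate, inertia) = inertia .* old .+ (1 - inertia) .* estimate

old      = [0.5, 0.5]   # current parameters of one sum node
estimate = [0.9, 0.1]   # estimate computed from one mini-batch

low  = blend(old, estimate, 0.1)   # low inertia: follows the estimate
high = blend(old, estimate, 0.9)   # high inertia: stays near the old value
println(low, " ", high)
println(inertia_at(1, 100, 0.1, 0.9))
```

Under these assumptions, the first stage (inertia 0.1 to 0.9) lets the parameters move quickly early on, and the second stage (0.9 to 0.95) takes smaller, smoother steps, which is consistent with the slowing improvement in the training log.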

Now we evaluate the trained PC:

In [6]:
train_ll = loglikelihood(bpc, mnist_train_gpu; batch_size)
test_ll = loglikelihood(bpc, mnist_test_gpu; batch_size)
train_bpd = -train_ll / log(2.0) / 28^2
test_bpd = -test_ll / log(2.0) / 28^2
println("Train_ll: $(train_ll)\nTrain_bpd: $(train_bpd)\nTest_ll: $(test_ll)\nTest_bpd: $(test_bpd)")

Train_ll: -632.69055
Train_bpd: 1.1642595936712987
Test_ll: -652.6873
Test_bpd: 1.2010570858863057
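
The conversion used in the cell above turns an average log-likelihood in nats per image into bits per dimension by dividing by `log(2)` and by the number of pixels. The sketch below wraps that arithmetic in a helper; the name `bits_per_dim` is hypothetical and is not a library function.

```julia
# Convert an average log-likelihood (nats per example) into bits per
# dimension: negate, change the log base to 2, divide by the feature count.
bits_per_dim(ll, num_features) = -ll / (log(2.0) * num_features)

test_bpd_check = bits_per_dim(-652.6873, 28^2)
println(test_bpd_check)
```

Plugging in the test log-likelihood printed above gives about 1.20 bpd, in line with the MNIST entry in the table at the top.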


Finally, we copy the learned parameters back from the bits-circuit `bpc` to the original PC `pc`:

In [7]:
update_parameters(bpc)