In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
from jax import random

import optax
from numpyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO

In [None]:
from src.models.M2VAE import M2VAE

from src.models.encoder_decoder import MNISTEncoder, MNISTDecoder
from src.data_loading.constants import MNIST_IMG_SHAPE

from src.data_loading.loaders import get_data_loaders

# Set Random Seed

In [None]:
# Single integer seed, forwarded to get_data_loaders in the data-loading cell.
seed: int = 42

# Load Data

In [None]:
# Arguments for get_data_loaders. p_test / p_val / p_supervised are fractions
# whose exact meaning lives in get_data_loaders (presumably the test,
# validation and labelled proportions — TODO confirm).
loader_kwargs = {
    "dataset_name": "MNIST",
    "p_test": 0.2,
    "p_val": 0.2,
    "p_supervised": 0.1,
    "batch_size": 2,
    "num_workers": 0,
    "seed": seed,
}
img_shape, loader_dict, size_dict = get_data_loaders(**loader_kwargs)

# 50 divided by the size of the "supervised" split; forwarded to M2VAE as
# `scale_factor` (its role inside the model is not visible here — presumably
# it weights the labelled/classification term).
n_supervised = size_dict["supervised"]
scale_factor = 50 / n_supervised

# Set up Model

In [None]:
# Build the M2 VAE from the MNIST encoder/decoder classes.
# NOTE(review): 10 and 30 are positional hyperparameters; by their values they
# look like the number of label classes (10 digits) and the latent size —
# TODO confirm against the M2VAE signature.
m2_vae = M2VAE(
    MNISTEncoder,
    MNISTDecoder,
    10,
    30,
    MNIST_IMG_SHAPE,
    scale_factor=scale_factor,
)

# Set up Optimizer

In [None]:
# Adam with a fixed learning rate; this one optax transformation is handed to
# every SVI object built in the next cell.
learning_rate = 1e-3
optimizer = optax.adam(learning_rate=learning_rate)

# Set up Stochastic Variational Inference (SVI) for all three objectives (supervised, unsupervised, and classification)

In [None]:
def _make_svi(model, guide):
    """Build an SVI object for one (model, guide) pair.

    Every objective uses the same optax `optimizer` and gets its own
    Trace_ELBO instance.
    """
    return SVI(model, guide, optim=optimizer, loss=Trace_ELBO())


# Labelled objective: updated with both xs and ys in the training loop.
svi_supervised = _make_svi(m2_vae.model_supervised, m2_vae.guide_supervised)

# Unlabelled objective: updated with xs only.
# NOTE: config_enumerate(m2_vae.guide_unsupervised) combined with
# TraceEnum_ELBO(max_plate_nesting=1) was noted as a better choice here;
# plain Trace_ELBO is used for now.
svi_unsupervised = _make_svi(m2_vae.model_unsupervised, m2_vae.guide_unsupervised)

# Classification objective: updated on the labelled batches.
svi_classify = _make_svi(m2_vae.model_classify, m2_vae.guide_classify)

In [None]:
# Dummy one-example batch, used only so each model/guide pair can be traced
# once and its parameters created. The int32 labels mirror the original
# dummy labels (assumed to match the loader's label dtype — TODO confirm).
dummy_xs = jnp.ones((1,) + MNIST_IMG_SHAPE)
dummy_ys = jnp.ones((1,), dtype=jnp.int32)
init_key = random.PRNGKey(seed)

# SVI.init must run on every SVI object, not only the one whose state is
# kept: besides returning a state it records the parameter transforms
# (constrain_fn) that SVI.update uses.
state_supervised = svi_supervised.init(init_key, xs=dummy_xs, ys=dummy_ys)
state_unsupervised = svi_unsupervised.init(init_key, xs=dummy_xs)
state_classify = svi_classify.init(init_key, xs=dummy_xs, ys=dummy_ys)

# The training loop threads ONE state through all three objectives, so they
# optimise a shared parameter set. A state from a single init only holds the
# params created by that objective's model/guide; if another guide uses a
# network the supervised pair does not (e.g. a classifier — not visible
# here), its params would be absent and never trained. Merge the
# unconstrained params of all three inits (supervised last, so its values win
# on overlap) and build the shared optimiser state from the union. When the
# supervised init already contains every param this equals the old state.
shared_params = {}
for svi_obj, svi_state in (
    (svi_unsupervised, state_unsupervised),
    (svi_classify, state_classify),
    (svi_supervised, state_supervised),
):
    shared_params.update(svi_obj.optim.get_params(svi_state.optim_state))

# NOTE(review): mutable_state (e.g. batch-norm statistics, if any) is still
# taken from the supervised init only, as before.
state = state_supervised._replace(
    optim_state=svi_supervised.optim.init(shared_params)
)

print("SVI set up complete!")

In [None]:
# One pass over the semi-supervised loader. Each item is (is_supervised, batch):
# labelled batches are (x, y) pairs, unlabelled batches are x alone. All three
# objectives update the same shared `state`.
# NOTE(review): updates are not jit-compiled; wrapping svi.update in jax.jit
# would likely be faster.
tot_loss_supervised = tot_loss_unsupervised = tot_loss_classify = 0.0

for is_supervised, batch in loader_dict["semi_supervised"]:
    if not is_supervised:
        # Unlabelled batch: single unsupervised ELBO step.
        x = batch
        state, loss = svi_unsupervised.update(state, xs=x)
        tot_loss_unsupervised += loss
        continue

    # Labelled batch: a supervised ELBO step followed by a classification
    # step, the second starting from the state produced by the first.
    x, y = batch
    state, loss = svi_supervised.update(state, xs=x, ys=y)
    tot_loss_supervised += loss
    state, loss = svi_classify.update(state, xs=x, ys=y)
    tot_loss_classify += loss

print(tot_loss_supervised, tot_loss_unsupervised, tot_loss_classify)