In [250]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [251]:
import jax
import jax.numpy as jnp
from jax import random

import optax
from numpyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO

from numpyro.contrib.funsor import config_enumerate

In [252]:
from src.models.M2VAE import M2VAE

from src.models.encoder_decoder import MNISTEncoder, MNISTDecoder
from src.data_loading.constants import MNIST_IMG_SHAPE

from src.data_loading.loaders import get_data_loaders

In [253]:
# Instantiate the M2VAE wrapper from the MNIST encoder/decoder classes.
# NOTE(review): the positional 10 and 30 are presumably the number of classes
# and the latent dimension — confirm against M2VAE's signature.
m2_vae = M2VAE(
    MNISTEncoder,
    MNISTDecoder,
    10,
    30,
    MNIST_IMG_SHAPE,
)

In [254]:
# Adam with a learning rate of 1e-3; this one optimizer object is passed to
# both SVI instances built below.
optimizer = optax.adam(learning_rate=1e-3)

In [255]:
# Labelled batches: the supervised model/guide pair scored with Trace_ELBO.
svi_supervised = SVI(
    model=m2_vae.model_supervised,
    guide=m2_vae.guide_supervised,
    optim=optimizer,
    loss=Trace_ELBO(),
)

# Unlabelled batches: the unsupervised model and guide are each wrapped in
# config_enumerate and scored with TraceEnum_ELBO.
enumerated_model = config_enumerate(m2_vae.model_unsupervised)
enumerated_guide = config_enumerate(m2_vae.guide_unsupervised)
svi_unsupervised = SVI(
    model=enumerated_model,
    guide=enumerated_guide,
    optim=optimizer,
    loss=TraceEnum_ELBO(),
)

In [256]:
# Initialise both SVI states from a dummy batch of one all-ones example,
# using the same PRNG key for each.
dummy_xs = jnp.ones((1,) + MNIST_IMG_SHAPE)
dummy_ys = jnp.ones((1,), dtype=jnp.int32)

state = svi_supervised.init(random.PRNGKey(0), xs=dummy_xs, ys=dummy_ys)
# NOTE(review): `state_unsup` is not read by the training loops visible in this
# notebook, which pass `state` to both SVI objects — confirm that is intended.
state_unsup = svi_unsupervised.init(random.PRNGKey(0), xs=dummy_xs)

In [257]:
seed = 42

# Data loading
# Per the recorded output, the fractions appear to be applied successively:
# 20% test, then 10% of the remainder for validation, then 20% of what is left
# as the supervised subset.
loader_options = dict(
    dataset_name="MNIST",
    p_test=0.2,
    p_val=0.1,
    p_supervised=0.2,
    batch_size=2,
    num_workers=0,
    seed=seed,
)
img_shape, loader_dict = get_data_loaders(**loader_options)

Successfully loaded MNIST dataset.
Total num samples 60000
Num test samples: 12000
Num validation samples: 4800
Num supervised samples: 8640
Num unsupervised samples: 34560


In [258]:
# Smoke test: run a few supervised SVI updates and print the summed loss.
# The unsupervised total is never updated in this cell; it is printed only so
# the output has the same two-value format as the semi-supervised loop.
tot_loss_supervised = 0.0
tot_loss_unsupervised = 0.0

# Zipping against range(n_batches) caps the run at seven updates and stops
# without requesting a further batch from the loader.
n_batches = 7
for _, (x, y) in zip(range(n_batches), loader_dict["supervised"]):
    state, loss = svi_supervised.update(state, xs=x, ys=y)
    tot_loss_supervised += loss
print(tot_loss_supervised, tot_loss_unsupervised)

7638.25 0.0


In [259]:
# Smoke test over the mixed loader: each item is an (is_supervised, batch)
# pair, routed to the matching SVI object. Both branches thread the same
# `state` through their updates.
# NOTE(review): the recorded run of this cell ends in
# "TypeError: Cannot concatenate arrays with different numbers of dimensions:
# got (2, 30), (10, 1, 10)". This presumably comes from a concatenation inside
# the unsupervised model/guide once the label is enumerated, so the fix likely
# belongs in M2VAE rather than in this loop — confirm.
tot_loss_supervised = 0.0
tot_loss_unsupervised = 0.0

# Zipping against range(n_batches) caps the run at seven updates and stops
# without requesting a further batch from the loader.
n_batches = 7
for _, (is_supervised, batch) in zip(range(n_batches), loader_dict["semi_supervised"]):
    if not is_supervised:
        # Unlabelled batch: the batch itself is the image tensor.
        state, loss = svi_unsupervised.update(state, xs=batch)
        tot_loss_unsupervised += loss
        continue
    # Labelled batch: unpack images and labels.
    x, y = batch
    state, loss = svi_supervised.update(state, xs=x, ys=y)
    tot_loss_supervised += loss
print(tot_loss_supervised, tot_loss_unsupervised)

TypeError: Cannot concatenate arrays with different numbers of dimensions: got (2, 30), (10, 1, 10).