In [1]:
# Re-import edited modules (e.g. the project's src.* code) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import pickle

import numpyro
import numpyro.handlers as handlers
import jax.numpy as jnp
import jax

from jax import random
import optax
from numpyro.infer import Predictive, SVI, Trace_ELBO

from src.models.CCVAE import CCVAE
from src.models.encoder_decoder import MNISTEncoder, MNISTDecoder, CIFAR10Encoder, CIFAR10Decoder
from src.data_loading.loaders import get_data_loaders
from src.losses import CCVAE_ELBO

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Set up random seed
seed = 42

# DATASET: "MNIST" or "CIFAR10"
dataset_name = "MNIST"

# Per-dataset encoder class, decoder class and likelihood name passed to CCVAE.
# An explicit table (instead of `a if dataset_name == "MNIST" else b`) makes any
# other value -- e.g. the typo "mnist" -- fail here with a clear message rather
# than silently selecting the CIFAR10 components.
_dataset_components = {
    "MNIST": (MNISTEncoder, MNISTDecoder, "bernoulli"),
    "CIFAR10": (CIFAR10Encoder, CIFAR10Decoder, "laplace"),
}
if dataset_name not in _dataset_components:
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        f"expected one of {sorted(_dataset_components)}"
    )
encoder_class, decoder_class, distribution = _dataset_components[dataset_name]

In [4]:
# Build the test / validation / supervised / unsupervised loaders.
# For MNIST (60000 samples) the counts printed below are consistent with:
#   test         = 0.2  * 60000             = 12000
#   validation   = 0.2  * (60000 - 12000)   = 9600
#   supervised   = 0.05 * (48000 - 9600)    = 1920
#   unsupervised = the remaining 36480
img_shape, loader_dict, size_dict = get_data_loaders(
    dataset_name=dataset_name,
    p_test=0.2,
    p_val=0.2,
    p_supervised=0.05,
    batch_size=10,
    num_workers=0,
    seed=seed,
)

# Forwarded to CCVAE(scale_factor=...); its meaning is defined in src.models.CCVAE.
scale_factor = 1.

Successfully loaded MNIST dataset.
Total num samples 60000
Num test samples: 12000
Num validation samples: 9600
Num supervised samples: 1920
Num unsupervised samples: 36480


In [5]:
ccvae = CCVAE(
    encoder_class,
    decoder_class,
    10,  # NOTE(review): presumably the number of label classes -- confirm against CCVAE.__init__
    50,  # NOTE(review): presumably the latent dimension -- confirm against CCVAE.__init__
    img_shape,
    scale_factor=scale_factor,
    distribution=distribution,
    multiclass=False,
)

In [6]:
# Adam with a fixed learning rate; this one optax transformation is handed to
# both the supervised and the unsupervised SVI objects.
optimizer = optax.adam(learning_rate=1e-3)

In [7]:
# Supervised objective: CCVAE model/guide that take images `xs` and labels `ys`,
# scored with the project's custom CCVAE_ELBO loss.
svi_supervised = SVI(
    ccvae.model_supervised,
    ccvae.guide_supervised,
    optim=optimizer,
    loss=CCVAE_ELBO(),
)

In [8]:
# Initialise the supervised SVI state by running model/guide once on a dummy
# batch of size 1: an all-ones image with label 1.
state = svi_supervised.init(
    random.PRNGKey(seed),
    xs=jnp.ones((1,) + img_shape),
    ys=jnp.ones((1,), dtype=jnp.int32),
)

(1,)


In [9]:
# Labelled subset: iterating it yields (x, y) batches.
loader_supervised = loader_dict["supervised"]

In [10]:
# Smoke test: a single SVI step on the first labelled batch.
# `next(iter(...), None)` replaces `for ...: ...; break`, which silently skipped
# the update on an empty loader and left `loss` undefined or stale.
batch = next(iter(loader_supervised), None)
if batch is None:
    raise RuntimeError("loader_dict['supervised'] yielded no batches")
x, y = batch
# SVI.update returns the new SVI state and the loss for this step.
state, loss = svi_supervised.update(state, xs=x, ys=y)

(10,)


In [11]:
# Label the value: `loss` is reused by the unsupervised step further down.
print(f"supervised loss: {loss}")

4671.792


In [12]:
# Unsupervised objective: CCVAE model/guide on images only, scored with
# numpyro's standard Trace_ELBO.
svi_unsupervised = SVI(
    ccvae.model_unsupervised,
    ccvae.guide_unsupervised,
    optim=optimizer,
    loss=Trace_ELBO(),
)

In [13]:
state = svi_unsupervised.init(
    random.PRNGKey(seed), 
    xs=jnp.ones((1,)+img_shape)
)

  state = svi_unsupervised.init(


In [14]:
loader_unsupervised = loader_dict["unsupervised"]

In [15]:
for batch in loader_unsupervised:
    x = batch
    state, loss = svi_unsupervised.update(state, xs=x)
    break

In [16]:
# Labelled so it cannot be confused with the supervised loss printed earlier.
print(f"unsupervised loss: {loss}")

5655.629
