In [1]:
from pathlib import Path
import pickle

import jax
from jax import jit, device_put
import jax.numpy as jnp
from jax import random
from tqdm import tqdm
import optax
from numpyro.infer import SVI, Trace_ELBO
import numpy as np
import matplotlib.pyplot as plt

from src.models.CCVAE import CCVAE
from src.models.encoder_decoder import MNISTEncoder, MNISTDecoder, CIFAR10Encoder, CIFAR10Decoder
from src.data_loading.loaders import get_data_loaders
from src.losses import CCVAE_ELBO


# Reproducibility: a single integer seed, passed explicitly to the data split
# below and used for the PRNGKey when SVI is initialised.
seed = 42

# DATASET -- names seen in this notebook: "MNIST", "CIFAR10", "CELEBA".
# "MNIST" selects the MNIST networks with a Bernoulli likelihood; every other
# name falls through to the CIFAR10 networks with a Laplace likelihood.
# NOTE(review): the stored output of this cell reports "Successfully loaded MNIST
# dataset.", so it was produced with a different dataset_name -- re-run the
# notebook from a fresh kernel before trusting the logged numbers.
dataset_name = "CELEBA"

if dataset_name == "MNIST":
    encoder_class, decoder_class, distribution = MNISTEncoder, MNISTDecoder, "bernoulli"
else:
    encoder_class, decoder_class, distribution = CIFAR10Encoder, CIFAR10Decoder, "laplace"

# Data loading. Going by the stored MNIST output, test = 20% of all samples,
# validation = 20% of the remainder, and 90% of the rest is labelled.
loader_options = {
    "p_test": 0.2,
    "p_val": 0.2,
    "p_supervised": 0.9,
    "batch_size": 128,
    "num_workers": 6,
    "seed": seed,
}
img_shape, loader_dict, size_dict = get_data_loaders(dataset_name=dataset_name, **loader_options)

# IMPORTANT: scale_factor is handed to CCVAE; the 0.1 multiplier is worth a
# grid search (0.3 reportedly worked on CIFAR10).
scale_factor = 0.1 * size_dict["supervised"]

# Set up model.
# NOTE(review): 10 and 50 are passed positionally to CCVAE; presumably the
# number of labels and the latent size -- confirm against CCVAE's signature.
model_options = dict(
    scale_factor=scale_factor,
    distribution=distribution,
    multiclass=False,
)
ccvae = CCVAE(encoder_class, decoder_class, 10, 50, img_shape, **model_options)
print("Model set up!")

# Learning-rate schedules for the VAE and the classifier. Both keep their
# initial rate for 20 epochs' worth of updates (the training loop performs one
# update per batch of the semi-supervised loader) and are then scaled by 0.7.
lr_decay_step = 20 * len(loader_dict["semi_supervised"])

lr_schedule_vae = optax.piecewise_constant_schedule(
    init_value=2e-3,
    boundaries_and_scales={lr_decay_step: 0.7},
)
lr_schedule_classifier = optax.piecewise_constant_schedule(
    init_value=1e-3,
    boundaries_and_scales={lr_decay_step: 0.7},
)

# Optimizers: plain Adam for the VAE, AdamW (weight decay 3e-4) for the classifier.
optimizer_vae = optax.adam(lr_schedule_vae)
optimizer_classifier = optax.adamw(lr_schedule_classifier, weight_decay=3e-4)

  from .autonotebook import tqdm as notebook_tqdm


Successfully loaded MNIST dataset.
Total num samples 60000
Num test samples: 12000
Num validation samples: 9600
Num supervised samples: 34560
Num unsupervised samples: 3840
Model set up!


In [2]:
import flax
def flattened_traversal(fn):
  """Lift a per-leaf function ``fn(path, param)`` to a whole-tree function.

  The returned callable flattens a nested dict with
  ``flax.traverse_util.flatten_dict`` (keys become tuples of path components),
  replaces every leaf ``param`` by ``fn(path, param)``, and rebuilds the nested
  structure with ``flax.traverse_util.unflatten_dict``.

  Args:
    fn: callable taking ``(path_tuple, leaf)`` and returning the new leaf.

  Returns:
    A function ``tree -> tree`` with the same nesting as its input.
  """
  def mask(tree):
    traverse_util = flax.traverse_util
    mapped_leaves = {}
    for path, param in traverse_util.flatten_dict(tree).items():
      mapped_leaves[path] = fn(path, param)
    return traverse_util.unflatten_dict(mapped_leaves)
  return mask

In [3]:
def _optimizer_label(path, param):
    """Map one flattened parameter path to an ``optax.multi_transform`` label.

    Args:
        path: tuple of key components, as produced by
            ``flax.traverse_util.flatten_dict``.
        param: the parameter leaf (unused; the label depends only on the path).

    Returns:
        'low' (classifier optimizer) when any path component contains the
        substring "Classifier", otherwise 'high' (VAE optimizer).
    """
    # Match substrings per component. An exact tuple-membership test
    # (`"Classifier" in path`) only matches a component equal to "Classifier".
    # NOTE(review): numpyro's `flax_module(name, ...)` stores params under the
    # key `name + "$params"`, which exact membership would never match --
    # confirm the parameter names CCVAE actually registers.
    return 'low' if any("Classifier" in str(component) for component in path) else 'high'


# Builds a label tree with the same nesting as the params tree.
vae_classifier_fn = flattened_traversal(_optimizer_label)

# Set up optimizer: 'low' leaves use optimizer_classifier, 'high' leaves use optimizer_vae.
optimizer = optax.multi_transform(
    {'low': optimizer_classifier, 'high': optimizer_vae}, vae_classifier_fn
)

In [4]:

# Set up SVI: one SVI object per kind of batch, both driven by the same
# `optimizer`. Labelled batches use the project's CCVAE_ELBO; unlabelled
# batches use numpyro's Trace_ELBO.
svi_supervised = SVI(
    ccvae.model_supervised,
    ccvae.guide_supervised,
    optim=optimizer,
    loss=CCVAE_ELBO(),
)
svi_unsupervised = SVI(
    ccvae.model_unsupervised,
    ccvae.guide_unsupervised,
    optim=optimizer,
    loss=Trace_ELBO(),
)

# Initialise parameters by tracing model/guide on one all-ones dummy example.
dummy_xs = jnp.ones((1,) + img_shape)
state = svi_supervised.init(
    random.PRNGKey(seed),
    xs=dummy_xs,
    ys=jnp.ones((1,), dtype=jnp.int32),
)
# The unsupervised init result is discarded: the training loop threads the
# supervised `state` through both SVI objects.
svi_unsupervised.init(random.PRNGKey(seed), xs=dummy_xs)
print("SVI set up!")


# Train functions (jit-compiled; each performs exactly one SVI update).
@jit
def train_step_supervised(state, batch):
    """Run one SVI step on a labelled batch.

    ``batch`` is an ``(images, labels)`` pair, fed to ``svi_supervised`` as
    ``xs``/``ys``. Returns ``(new_state, loss)``.
    """
    images, labels = batch
    new_state, supervised_loss = svi_supervised.update(state, xs=images, ys=labels)
    return new_state, supervised_loss

@jit
def train_step_unsupervised(state, batch):
    """Run one SVI step on an unlabelled batch.

    ``batch`` holds only the inputs and is passed to ``svi_unsupervised`` as
    ``xs``. Returns ``(new_state, loss)``.
    """
    new_state, unsupervised_loss = svi_unsupervised.update(state, xs=batch)
    return new_state, unsupervised_loss

# Training: pull the three loaders out of loader_dict.
semi_supervised_loader, validation_loader, test_loader = (
    loader_dict[split] for split in ("semi_supervised", "validation", "test")
)

print("Start training.")
# Per-epoch histories, appended to by the training loop.
loss_rec_supervised, loss_rec_unsupervised, validation_accuracy_rec = [], [], []


SVI set up!
Start training.


  svi_unsupervised.init(


In [5]:
num_epochs = 100
for epoch in tqdm(range(1, num_epochs + 1)):
    loss_rec_step_supervised = []
    loss_rec_step_unsupervised = []

    # Training: the semi-supervised loader yields (is_supervised, batch) pairs.
    # Labelled batches are (x, y); unlabelled batches are x alone. Both step
    # functions advance the same shared `state`.
    for is_supervised, batch in semi_supervised_loader:
        batch = device_put(batch)

        if is_supervised:
            state, loss_supervised = train_step_supervised(state, batch)
            loss_rec_step_supervised.append(loss_supervised)
        else:
            state, loss_unsupervised = train_step_unsupervised(state, batch)
            loss_rec_step_unsupervised.append(loss_unsupervised)

    # np.mean([]) warns and returns nan. Report nan explicitly if an epoch had
    # no batches of one kind.
    loss_epoch_supervised = (
        np.mean(loss_rec_step_supervised) if loss_rec_step_supervised else np.nan
    )
    loss_epoch_unsupervised = (
        np.mean(loss_rec_step_unsupervised) if loss_rec_step_unsupervised else np.nan
    )

    loss_rec_supervised.append(loss_epoch_supervised)
    loss_rec_unsupervised.append(loss_epoch_unsupervised)

    # Validation. Read params through SVI's public accessor (not by positional
    # indexing into the SVI state's internals), once per epoch.
    params = svi_supervised.get_params(state)
    num_correct = 0
    num_compared = 0
    for batch in validation_loader:
        x, y = device_put(batch)
        ypred = ccvae.classify(params, x)
        matches = y == ypred
        # Accumulate counts rather than averaging per-batch means, so a smaller
        # final batch is not over-weighted.
        num_correct += jnp.sum(matches)
        num_compared += matches.size

    validation_accuracy = (
        float(num_correct) / num_compared if num_compared else float("nan")
    )
    validation_accuracy_rec.append(validation_accuracy)

    print("\nEpoch:", 
          epoch, 
          "loss sup:", 
          loss_epoch_supervised, 
          "loss unsup:", 
          loss_epoch_unsupervised, 
          "val acc:", 
          validation_accuracy
    )

print("Training finished!")

  1%|          | 1/100 [00:09<16:12,  9.82s/it]


Epoch: 1 loss sup: 34727.152 loss unsup: 33238.176 val acc: 0.13854165


  2%|▏         | 2/100 [00:12<09:28,  5.80s/it]


Epoch: 2 loss sup: 25699.863 loss unsup: 25543.617 val acc: 0.19864583


  3%|▎         | 3/100 [00:15<07:11,  4.44s/it]


Epoch: 3 loss sup: 23984.86 loss unsup: 23946.72 val acc: 0.21135415


  4%|▍         | 4/100 [00:18<06:02,  3.78s/it]


Epoch: 4 loss sup: 23022.28 loss unsup: 22859.684 val acc: 0.20947915


  5%|▌         | 5/100 [00:21<05:28,  3.45s/it]


Epoch: 5 loss sup: 21797.152 loss unsup: 21809.682 val acc: 0.28333333


  6%|▌         | 6/100 [00:24<05:04,  3.24s/it]


Epoch: 6 loss sup: 20970.73 loss unsup: 20873.309 val acc: 0.28395832


  7%|▋         | 7/100 [00:26<04:46,  3.08s/it]


Epoch: 7 loss sup: 20386.822 loss unsup: 20280.588 val acc: 0.20677082


  8%|▊         | 8/100 [00:29<04:34,  2.98s/it]


Epoch: 8 loss sup: 19823.402 loss unsup: 19777.717 val acc: 0.18968749


  9%|▉         | 9/100 [00:32<04:26,  2.92s/it]


Epoch: 9 loss sup: 19540.625 loss unsup: 19439.469 val acc: 0.21916665


 10%|█         | 10/100 [00:35<04:22,  2.92s/it]


Epoch: 10 loss sup: 19133.135 loss unsup: 19029.094 val acc: 0.23718749


 11%|█         | 11/100 [00:38<04:18,  2.90s/it]


Epoch: 11 loss sup: 18771.557 loss unsup: 18643.617 val acc: 0.25062498


 12%|█▏        | 12/100 [00:41<04:14,  2.90s/it]


Epoch: 12 loss sup: 18437.111 loss unsup: 18298.188 val acc: 0.25020832


 13%|█▎        | 13/100 [00:43<04:07,  2.85s/it]


Epoch: 13 loss sup: 18121.967 loss unsup: 18014.018 val acc: 0.2153125


 14%|█▍        | 14/100 [00:46<04:03,  2.83s/it]


Epoch: 14 loss sup: 17768.744 loss unsup: 17731.725 val acc: 0.23385416


 15%|█▌        | 15/100 [00:49<03:59,  2.81s/it]


Epoch: 15 loss sup: 17571.018 loss unsup: 17458.055 val acc: 0.22781248


 16%|█▌        | 16/100 [00:52<03:59,  2.86s/it]


Epoch: 16 loss sup: 17414.635 loss unsup: 17287.768 val acc: 0.20208332


 17%|█▋        | 17/100 [00:55<03:57,  2.86s/it]


Epoch: 17 loss sup: 17245.963 loss unsup: 17137.467 val acc: 0.20010416


 18%|█▊        | 18/100 [00:58<03:54,  2.86s/it]


Epoch: 18 loss sup: 17058.24 loss unsup: 17009.574 val acc: 0.23416665


 19%|█▉        | 19/100 [01:01<03:55,  2.90s/it]


Epoch: 19 loss sup: 16953.396 loss unsup: 16875.633 val acc: 0.22864582


 20%|██        | 20/100 [01:04<03:55,  2.95s/it]


Epoch: 20 loss sup: 16923.492 loss unsup: 16776.94 val acc: 0.19749999


 21%|██        | 21/100 [01:06<03:50,  2.91s/it]


Epoch: 21 loss sup: 16786.924 loss unsup: 16629.652 val acc: 0.27072915


 22%|██▏       | 22/100 [01:09<03:45,  2.89s/it]


Epoch: 22 loss sup: 16693.17 loss unsup: 16565.732 val acc: 0.26135415


 23%|██▎       | 23/100 [01:12<03:40,  2.86s/it]


Epoch: 23 loss sup: 16650.73 loss unsup: 16492.709 val acc: 0.15145832


 24%|██▍       | 24/100 [01:15<03:37,  2.86s/it]


Epoch: 24 loss sup: 16604.213 loss unsup: 16435.045 val acc: 0.23166665


 25%|██▌       | 25/100 [01:18<03:34,  2.86s/it]


Epoch: 25 loss sup: 16530.596 loss unsup: 16403.584 val acc: 0.19885416


 26%|██▌       | 26/100 [01:21<03:30,  2.84s/it]


Epoch: 26 loss sup: 16460.688 loss unsup: 16356.245 val acc: 0.16031249


 27%|██▋       | 27/100 [01:23<03:26,  2.83s/it]


Epoch: 27 loss sup: 16442.41 loss unsup: 16313.684 val acc: 0.13260417


 28%|██▊       | 28/100 [01:26<03:23,  2.83s/it]


Epoch: 28 loss sup: 16441.889 loss unsup: 16277.297 val acc: 0.14281249


 29%|██▉       | 29/100 [01:29<03:21,  2.84s/it]


Epoch: 29 loss sup: 16395.377 loss unsup: 16234.754 val acc: 0.16833332


 30%|███       | 30/100 [01:32<03:17,  2.83s/it]


Epoch: 30 loss sup: 16328.849 loss unsup: 16179.568 val acc: 0.19010416


 31%|███       | 31/100 [01:35<03:14,  2.82s/it]


Epoch: 31 loss sup: 16283.828 loss unsup: 16152.885 val acc: 0.24770832


 32%|███▏      | 32/100 [01:38<03:11,  2.82s/it]


Epoch: 32 loss sup: 16244.9 loss unsup: 16101.917 val acc: 0.21822916


 33%|███▎      | 33/100 [01:40<03:11,  2.85s/it]


Epoch: 33 loss sup: 16181.922 loss unsup: 16048.37 val acc: 0.19281249


 34%|███▍      | 34/100 [01:43<03:06,  2.83s/it]


Epoch: 34 loss sup: 16125.422 loss unsup: 16025.371 val acc: 0.1559375


 35%|███▌      | 35/100 [01:46<03:04,  2.85s/it]


Epoch: 35 loss sup: 16137.444 loss unsup: 16012.215 val acc: 0.14156249


 36%|███▌      | 36/100 [01:49<03:02,  2.84s/it]


Epoch: 36 loss sup: 16098.013 loss unsup: 15974.236 val acc: 0.15291665


 37%|███▋      | 37/100 [01:52<02:58,  2.83s/it]


Epoch: 37 loss sup: 16074.535 loss unsup: 15943.369 val acc: 0.14354166


 38%|███▊      | 38/100 [01:55<02:55,  2.83s/it]


Epoch: 38 loss sup: 16032.004 loss unsup: 15905.512 val acc: 0.20854166


 39%|███▉      | 39/100 [01:57<02:53,  2.84s/it]


Epoch: 39 loss sup: 16049.132 loss unsup: 15865.081 val acc: 0.14447916


 40%|████      | 40/100 [02:00<02:50,  2.85s/it]


Epoch: 40 loss sup: 15959.389 loss unsup: 15845.765 val acc: 0.16572917


 41%|████      | 41/100 [02:03<02:47,  2.84s/it]


Epoch: 41 loss sup: 15950.628 loss unsup: 15821.748 val acc: 0.13812499


 42%|████▏     | 42/100 [02:06<02:44,  2.84s/it]


Epoch: 42 loss sup: 15977.748 loss unsup: 15797.662 val acc: 0.190625


 43%|████▎     | 43/100 [02:09<02:42,  2.84s/it]


Epoch: 43 loss sup: 15898.25 loss unsup: 15791.412 val acc: 0.18333332


 44%|████▍     | 44/100 [02:12<02:40,  2.87s/it]


Epoch: 44 loss sup: 15904.035 loss unsup: 15745.317 val acc: 0.27135417


 45%|████▌     | 45/100 [02:15<02:36,  2.85s/it]


Epoch: 45 loss sup: 15799.967 loss unsup: 15718.785 val acc: 0.18416665


 46%|████▌     | 46/100 [02:18<02:35,  2.89s/it]


Epoch: 46 loss sup: 15851.539 loss unsup: 15701.318 val acc: 0.1509375


 47%|████▋     | 47/100 [02:20<02:31,  2.86s/it]


Epoch: 47 loss sup: 15823.06 loss unsup: 15694.226 val acc: 0.14458333


 48%|████▊     | 48/100 [02:23<02:28,  2.85s/it]


Epoch: 48 loss sup: 15832.645 loss unsup: 15681.464 val acc: 0.20239583


 49%|████▉     | 49/100 [02:26<02:24,  2.84s/it]


Epoch: 49 loss sup: 15820.715 loss unsup: 15662.493 val acc: 0.19218749


 50%|█████     | 50/100 [02:29<02:21,  2.83s/it]


Epoch: 50 loss sup: 15762.604 loss unsup: 15649.024 val acc: 0.20968749


 51%|█████     | 51/100 [02:32<02:17,  2.80s/it]


Epoch: 51 loss sup: 15743.752 loss unsup: 15627.614 val acc: 0.16135415


 52%|█████▏    | 52/100 [02:34<02:15,  2.82s/it]


Epoch: 52 loss sup: 15771.826 loss unsup: 15637.632 val acc: 0.20791666


 53%|█████▎    | 53/100 [02:37<02:12,  2.83s/it]


Epoch: 53 loss sup: 15750.663 loss unsup: 15617.801 val acc: 0.17062499


 54%|█████▍    | 54/100 [02:40<02:10,  2.84s/it]


Epoch: 54 loss sup: 15739.235 loss unsup: 15587.946 val acc: 0.144375


 55%|█████▌    | 55/100 [02:43<02:08,  2.86s/it]


Epoch: 55 loss sup: 15742.017 loss unsup: 15576.811 val acc: 0.14656249


 56%|█████▌    | 56/100 [02:46<02:07,  2.90s/it]


Epoch: 56 loss sup: 15738.872 loss unsup: 15584.099 val acc: 0.13833332


 57%|█████▋    | 57/100 [02:49<02:04,  2.90s/it]


Epoch: 57 loss sup: 15777.893 loss unsup: 15578.849 val acc: 0.12635416


 58%|█████▊    | 58/100 [02:52<02:00,  2.86s/it]


Epoch: 58 loss sup: 15699.67 loss unsup: 15570.157 val acc: 0.13010415


 59%|█████▉    | 59/100 [02:55<01:57,  2.87s/it]


Epoch: 59 loss sup: 15726.546 loss unsup: 15536.75 val acc: 0.13114582


 60%|██████    | 60/100 [02:57<01:53,  2.84s/it]


Epoch: 60 loss sup: 15715.226 loss unsup: 15535.338 val acc: 0.12989582


 61%|██████    | 61/100 [03:00<01:50,  2.82s/it]


Epoch: 61 loss sup: 15675.548 loss unsup: 15526.353 val acc: 0.1334375


 62%|██████▏   | 62/100 [03:03<01:49,  2.88s/it]


Epoch: 62 loss sup: 15671.304 loss unsup: 15532.184 val acc: 0.125625


 63%|██████▎   | 63/100 [03:06<01:45,  2.84s/it]


Epoch: 63 loss sup: 15673.909 loss unsup: 15514.686 val acc: 0.13083333


 64%|██████▍   | 64/100 [03:09<01:42,  2.84s/it]


Epoch: 64 loss sup: 15679.866 loss unsup: 15506.186 val acc: 0.13635416


 65%|██████▌   | 65/100 [03:12<01:39,  2.83s/it]


Epoch: 65 loss sup: 15650.058 loss unsup: 15512.869 val acc: 0.1359375


 66%|██████▌   | 66/100 [03:14<01:35,  2.81s/it]


Epoch: 66 loss sup: 15638.361 loss unsup: 15486.181 val acc: 0.143125


 67%|██████▋   | 67/100 [03:17<01:32,  2.81s/it]


Epoch: 67 loss sup: 15673.487 loss unsup: 15470.75 val acc: 0.13645832


 68%|██████▊   | 68/100 [03:20<01:29,  2.81s/it]


Epoch: 68 loss sup: 15661.926 loss unsup: 15471.215 val acc: 0.13114582


 69%|██████▉   | 69/100 [03:23<01:27,  2.83s/it]


Epoch: 69 loss sup: 15659.2705 loss unsup: 15467.585 val acc: 0.13479166


 70%|███████   | 70/100 [03:26<01:26,  2.88s/it]


Epoch: 70 loss sup: 15651.004 loss unsup: 15466.888 val acc: 0.13489583


 71%|███████   | 71/100 [03:29<01:22,  2.85s/it]


Epoch: 71 loss sup: 15608.1 loss unsup: 15447.159 val acc: 0.13385417


 72%|███████▏  | 72/100 [03:31<01:19,  2.86s/it]


Epoch: 72 loss sup: 15573.293 loss unsup: 15450.476 val acc: 0.13937499


 73%|███████▎  | 73/100 [03:34<01:17,  2.87s/it]


Epoch: 73 loss sup: 15619.991 loss unsup: 15427.679 val acc: 0.14229167


 74%|███████▍  | 74/100 [03:37<01:15,  2.89s/it]


Epoch: 74 loss sup: 15632.444 loss unsup: 15423.746 val acc: 0.13874999


 75%|███████▌  | 75/100 [03:40<01:13,  2.93s/it]


Epoch: 75 loss sup: 15631.812 loss unsup: 15416.0 val acc: 0.14010416


 76%|███████▌  | 76/100 [03:43<01:09,  2.90s/it]


Epoch: 76 loss sup: 15618.326 loss unsup: 15430.234 val acc: 0.13510416


 77%|███████▋  | 77/100 [03:46<01:06,  2.88s/it]


Epoch: 77 loss sup: 15594.728 loss unsup: 15408.135 val acc: 0.13947916


 78%|███████▊  | 78/100 [03:49<01:02,  2.84s/it]


Epoch: 78 loss sup: 15604.174 loss unsup: 15410.968 val acc: 0.14093749


 79%|███████▉  | 79/100 [03:51<00:59,  2.83s/it]


Epoch: 79 loss sup: 15570.597 loss unsup: 15403.946 val acc: 0.13895832


 80%|████████  | 80/100 [03:54<00:57,  2.87s/it]


Epoch: 80 loss sup: 15569.855 loss unsup: 15388.723 val acc: 0.14854166


 81%|████████  | 81/100 [03:57<00:54,  2.89s/it]


Epoch: 81 loss sup: 15597.355 loss unsup: 15401.138 val acc: 0.15270832


 82%|████████▏ | 82/100 [04:00<00:51,  2.87s/it]


Epoch: 82 loss sup: 15552.758 loss unsup: 15384.892 val acc: 0.1490625


 83%|████████▎ | 83/100 [04:03<00:48,  2.86s/it]


Epoch: 83 loss sup: 15544.153 loss unsup: 15377.168 val acc: 0.148125


 84%|████████▍ | 84/100 [04:06<00:45,  2.85s/it]


Epoch: 84 loss sup: 15559.409 loss unsup: 15356.798 val acc: 0.14625


 85%|████████▌ | 85/100 [04:09<00:42,  2.83s/it]


Epoch: 85 loss sup: 15593.237 loss unsup: 15337.984 val acc: 0.14697915


 86%|████████▌ | 86/100 [04:12<00:39,  2.83s/it]


Epoch: 86 loss sup: 15573.074 loss unsup: 15367.243 val acc: 0.14958332


 87%|████████▋ | 87/100 [04:14<00:37,  2.87s/it]


Epoch: 87 loss sup: 15546.733 loss unsup: 15344.735 val acc: 0.15395832


 88%|████████▊ | 88/100 [04:17<00:34,  2.84s/it]


Epoch: 88 loss sup: 15556.6 loss unsup: 15329.002 val acc: 0.15958333


 89%|████████▉ | 89/100 [04:20<00:31,  2.82s/it]


Epoch: 89 loss sup: 15540.609 loss unsup: 15317.309 val acc: 0.16520832


 90%|█████████ | 90/100 [04:23<00:28,  2.84s/it]


Epoch: 90 loss sup: 15574.814 loss unsup: 15302.134 val acc: 0.1540625


 91%|█████████ | 91/100 [04:26<00:25,  2.83s/it]


Epoch: 91 loss sup: 15550.541 loss unsup: 15295.022 val acc: 0.15562499


 92%|█████████▏| 92/100 [04:29<00:22,  2.82s/it]


Epoch: 92 loss sup: 15521.339 loss unsup: 15315.699 val acc: 0.15760416


 93%|█████████▎| 93/100 [04:31<00:19,  2.81s/it]


Epoch: 93 loss sup: 15532.491 loss unsup: 15299.627 val acc: 0.16104166


 94%|█████████▍| 94/100 [04:34<00:16,  2.81s/it]


Epoch: 94 loss sup: 15548.13 loss unsup: 15292.337 val acc: 0.16385415


 95%|█████████▌| 95/100 [04:37<00:14,  2.81s/it]


Epoch: 95 loss sup: 15543.655 loss unsup: 15276.782 val acc: 0.16208333


 96%|█████████▌| 96/100 [04:40<00:11,  2.82s/it]


Epoch: 96 loss sup: 15527.742 loss unsup: 15276.851 val acc: 0.15885416


 97%|█████████▋| 97/100 [04:43<00:08,  2.81s/it]


Epoch: 97 loss sup: 15525.865 loss unsup: 15265.209 val acc: 0.15770833


 98%|█████████▊| 98/100 [04:45<00:05,  2.80s/it]


Epoch: 98 loss sup: 15522.7705 loss unsup: 15265.436 val acc: 0.15906249


 99%|█████████▉| 99/100 [04:48<00:02,  2.84s/it]


Epoch: 99 loss sup: 15494.393 loss unsup: 15245.385 val acc: 0.16645832


100%|██████████| 100/100 [04:51<00:00,  2.92s/it]


Epoch: 100 loss sup: 15503.29 loss unsup: 15239.277 val acc: 0.16854165
Training finished!



