In [1]:
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils.visualize
from trainers import UnsupervisedTrainer
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers

## 2. Configuration

In [2]:
# --- General Hyperparameters ---
model_name = 'toroidal_vae_burgess'  # Name of the model architecture file (e.g., 'vae_burgess')
latent_factor_num = 10
learning_rate = 1e-4
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rec_dist = 'bernoulli'  # Reconstruction distribution (e.g., 'bernoulli', 'gaussian')

# Training length. For iteration-based training instead use:
#   train_step_unit = 'iteration'; num_train_steps = int(3e5)
train_step_unit = 'epoch'  # Unit for training steps ('epoch' or 'iteration')
num_train_steps = 5  # Number of training steps (epochs)
assert train_step_unit in ('epoch', 'iteration'), f"unknown train_step_unit: {train_step_unit!r}"

# Seed the RNGs used by this notebook (torch: model init and DataLoader
# shuffling; numpy) so that Restart & Run All reproduces the same results.
seed = 0
torch.manual_seed(seed)  # also seeds every CUDA device
np.random.seed(seed)

In [3]:
# --- Loss Specific Hyperparameters ---
# Beta VAE loss for the toroidal model; `beta` weights the KL term and
# `rec_dist` reuses the reconstruction distribution from the config cell.
loss_name = 'beta_toroidal_vae'
loss_kwargs = dict(beta=16, rec_dist=rec_dist)

In [4]:
# # AnnealedVAE
# loss_name = 'annealedvae'  
# loss_kwargs = {
#     'C_init': 0.0,
#     'C_fin_3dshapes': 25.0,
#     'C_fin_dsprites': 15.0,
#     'gamma_annealed': 100.0,
#     'anneal_steps': 10000,
#     'rec_dist': rec_dist,
# }

In [5]:
# # BetaTCVAE
# loss_name = 'betatcvae'
# loss_kwargs = {
#     'alpha_tc': 1.0,
#     'beta_tc_3dshapes': 6.0,
#     'beta_tc_dsprites': 4.0,
#     'gamma_tc': 1.0,
#     'is_mss': True,
#     'rec_dist': rec_dist
# }

In [6]:
# # Factor VAE

# loss_name = 'factorvae'
# loss_kwargs = {
#     'device': device,
#     'gamma': 6.4,
#     'discr_lr': 5e-5,
#     'discr_betas': (0.5, 0.9),
#     'rec_dist': rec_dist,
# }

## 3. Load Datasets

In [7]:
def build_dataloader(dataset, num_workers):
    """Wrap ``dataset`` in a shuffling DataLoader using the configured ``batch_size``.

    Pinned host memory is requested only when ``device`` is CUDA: pinning only
    speeds up host-to-GPU copies, and on a CPU-only machine PyTorch warns that
    it cannot be used.

    Args:
        dataset: Any map-style dataset accepted by ``torch.utils.data.DataLoader``.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=device.type == 'cuda',
    )


# Worker counts below are hand-tuned per dataset. To recompute one, use e.g.
#   find_optimal_num_workers(dataset, batch_size=batch_size, num_batches_to_test='all')

# Load 3D Shapes
Shapes3D = get_dataset("shapes3d")
shapes3d_dataset = Shapes3D(selected_factors='all', not_selected_factors_index_value=None)
num_workers_3dshapes = 4
shapes3d_dataloader = build_dataloader(shapes3d_dataset, num_workers_3dshapes)
print(f"Loaded 3D Shapes dataset with {len(shapes3d_dataset)} samples.")

# Load dSprites
Dsprites = get_dataset('dsprites')
dsprites_dataset = Dsprites(selected_factors='all', not_selected_factors_index_value=None)
num_workers_dsprites = 7
dsprites_dataloader = build_dataloader(dsprites_dataset, num_workers_dsprites)
print(f"Loaded dSprites dataset with {len(dsprites_dataset)} samples.")

Loaded 3D Shapes dataset with 480000 samples.
Loaded dSprites dataset with 737280 samples.


## 4. Setup Model, Loss, and Optimizer

In [8]:
def setup_components(dataset, loss_kwargs):
    """Instantiate the model, loss function and optimizer for one dataset.

    Uses the notebook-level configuration globals ``model_name``,
    ``latent_factor_num``, ``loss_name`` and ``learning_rate``.

    Args:
        dataset: Indexable, sized dataset. ``dataset[0][0]`` is taken as a
            sample image, and its ``.shape`` is passed to the model as ``img_size``.
        loss_kwargs: Keyword arguments forwarded to ``losses.select``. The dict
            is copied, never mutated, so the caller's config stays unchanged.

    Returns:
        Tuple ``(model, loss_fn, optimizer)``. ``optimizer`` is an Adam optimizer
        over ``model.parameters()`` with ``lr=learning_rate``.
    """
    # Work on a copy. Adding dataset-specific keys to the caller's dict would
    # leave e.g. a stale 'n_data' behind, which is then forwarded to a
    # different loss after `loss_name` changes and the cell is re-run.
    loss_kwargs = dict(loss_kwargs)

    img_size = dataset[0][0].shape

    # Instantiate Model
    model = vae_models.select(name=model_name, img_size=img_size, latent_factor_num=latent_factor_num)

    # BetaTCVAE additionally takes the number of training samples.
    if loss_name == 'betatcvae':
        loss_kwargs['n_data'] = len(dataset)

    loss_fn = losses.select(loss_name, **loss_kwargs)

    # Instantiate Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print(f"--- Setup for {dataset.__class__.__name__} --- ")
    print(f"Model: {model.__class__.__name__}")
    # Report the rec_dist actually passed to the loss, not the global default.
    print(f"Loss: {loss_fn.__class__.__name__} (rec_dist={loss_kwargs.get('rec_dist', '<loss default>')}), kwargs={loss_kwargs}")
    print(f"Optimizer: {optimizer.__class__.__name__}")
    print("---------------------------")

    return model, loss_fn, optimizer

## 5. Train and Visualize

### 5.1 3D Shapes

In [9]:
# Loss hyperparameters for the 3D Shapes run: KL weight `beta` and the
# reconstruction distribution chosen in the configuration cell.
shapes3d_loss_kwargs = dict(beta=16, rec_dist=rec_dist)

In [10]:
# Build a fresh model / loss / optimizer for 3D Shapes and train it for
# `num_train_steps` units of `train_step_unit` (both set in the configuration cell).
print("\n===== Training on 3D Shapes =====")
model_3dshapes, loss_fn_3dshapes, optimizer_3dshapes = setup_components(shapes3d_dataset, shapes3d_loss_kwargs)

# scheduler=None: no learning-rate scheduler is passed to the trainer.
trainer_3dshapes = UnsupervisedTrainer(model=model_3dshapes,
                                      loss_fn=loss_fn_3dshapes,
                                      scheduler=None,
                                      optimizer=optimizer_3dshapes,
                                      device=device,
                                      train_step_unit=train_step_unit,
                                      )

# NOTE(review): the per-batch "samples_qzx shape: ..." lines in the saved output
# are not printed by this notebook. They presumably come from a leftover debug
# print in the model or loss code; remove it there to keep the progress bar readable.
trainer_3dshapes.train(shapes3d_dataloader, num_train_steps)


===== Training on 3D Shapes =====
--- Setup for Shapes3D --- 
Model: Model
Loss: BetaToroidalVAELoss (rec_dist=bernoulli), kwargs={'beta': 16, 'rec_dist': 'bernoulli'}
Optimizer: Adam
---------------------------


Epoch 1:   0%|          | 0/7500 [00:00<?, ?it/s]

samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 4/7500 [00:03<1:22:48,  1.51it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 10/7500 [00:03<25:29,  4.90it/s] 

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 16/7500 [00:03<13:07,  9.51it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 22/7500 [00:04<08:37, 14.44it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 28/7500 [00:04<06:39, 18.69it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 34/7500 [00:04<05:50, 21.31it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   0%|          | 37/7500 [00:04<05:31, 22.55it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 43/7500 [00:05<05:20, 23.25it/s]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 49/7500 [00:05<05:05, 24.39it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 55/7500 [00:05<04:59, 24.87it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 61/7500 [00:05<04:58, 24.91it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 67/7500 [00:05<04:54, 25.23it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 73/7500 [00:06<04:57, 24.96it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 79/7500 [00:06<04:57, 24.91it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 82/7500 [00:06<05:01, 24.56it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|          | 88/7500 [00:06<05:15, 23.50it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|▏         | 94/7500 [00:07<05:15, 23.51it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|▏         | 97/7500 [00:07<05:22, 22.98it/s, kl_loss=0.332, loss=8.58e+3, rec_loss=8.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|▏         | 103/7500 [00:07<05:27, 22.61it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   1%|▏         | 109/7500 [00:07<05:22, 22.89it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 115/7500 [00:08<05:11, 23.73it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 118/7500 [00:08<05:05, 24.16it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 124/7500 [00:08<04:57, 24.78it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 130/7500 [00:08<05:00, 24.53it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 136/7500 [00:08<05:02, 24.34it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 142/7500 [00:09<04:57, 24.75it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 148/7500 [00:09<05:00, 24.50it/s, kl_loss=0.000906, loss=8.18e+3, rec_loss=8.18e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 151/7500 [00:09<04:55, 24.89it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3] 

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 157/7500 [00:09<05:04, 24.14it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 163/7500 [00:09<05:04, 24.06it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 169/7500 [00:10<04:57, 24.60it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 175/7500 [00:10<05:07, 23.82it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 178/7500 [00:10<05:08, 23.73it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   2%|▏         | 184/7500 [00:10<05:05, 23.97it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 190/7500 [00:11<05:06, 23.86it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 193/7500 [00:11<05:03, 24.08it/s, kl_loss=0.00605, loss=7.75e+3, rec_loss=7.75e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 199/7500 [00:11<05:02, 24.15it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 205/7500 [00:11<04:50, 25.07it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 211/7500 [00:11<04:38, 26.21it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 217/7500 [00:12<04:40, 25.96it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 223/7500 [00:12<04:42, 25.80it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 229/7500 [00:12<04:53, 24.77it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 232/7500 [00:12<04:57, 24.43it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 238/7500 [00:13<05:03, 23.93it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 241/7500 [00:13<05:06, 23.68it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 247/7500 [00:13<05:01, 24.07it/s, kl_loss=0.00492, loss=7.62e+3, rec_loss=7.62e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 253/7500 [00:13<04:54, 24.60it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   3%|▎         | 259/7500 [00:13<05:05, 23.72it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▎         | 265/7500 [00:14<04:59, 24.17it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▎         | 271/7500 [00:14<04:53, 24.66it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▎         | 277/7500 [00:14<04:53, 24.59it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▎         | 280/7500 [00:14<05:01, 23.95it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 286/7500 [00:15<05:06, 23.56it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 292/7500 [00:15<05:06, 23.49it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 298/7500 [00:15<04:59, 24.03it/s, kl_loss=0.00166, loss=7.57e+3, rec_loss=7.57e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 301/7500 [00:15<04:56, 24.25it/s, kl_loss=0.000899, loss=7.55e+3, rec_loss=7.55e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 307/7500 [00:15<04:51, 24.71it/s, kl_loss=0.000899, loss=7.55e+3, rec_loss=7.55e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 313/7500 [00:16<04:50, 24.77it/s, kl_loss=0.000899, loss=7.55e+3, rec_loss=7.55e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 319/7500 [00:16<04:43, 25.30it/s, kl_loss=0.000899, loss=7.55e+3, rec_loss=7.55e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


Epoch 1:   4%|▍         | 322/7500 [00:16<04:51, 24.62it/s, kl_loss=0.000899, loss=7.55e+3, rec_loss=7.55e+3]

samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])
samples_qzx shape: torch.Size([64, 20])


                                                                                                             

KeyboardInterrupt: 