# Group Theory VAE (BetaVAE Base) Testing Framework

In [1]:
# Reproducibility setup runs in the very first cell, before the main import cell.
# NOTE(review): set_deterministic_run presumably seeds the RNGs / enables deterministic
# settings -- its implementation lives in utils.reproducibility and is not visible here.
from utils.reproducibility import set_deterministic_run, get_deterministic_dataloader

# One seed for the whole notebook: passed to set_deterministic_run here and to every
# get_deterministic_dataloader call in the dataset-loading cells.
seed = 42
set_deterministic_run(seed=seed)

## 1. Imports

In [2]:
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils
from trainers import UnsupervisedTrainer
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers
from metrics.utils import MetricAggregator

## 2. Configuration

In [3]:
# --- General Hyperparameters ---
# Model / optimisation settings shared by every experiment in this notebook.
model_name, latent_dim = 'vae_locatello', 10
learning_rate, batch_size = 1e-4, 64
rec_dist = 'bernoulli'

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training budget expressed as (unit, count).
train_step_unit, num_train_steps = 'epoch', 5

# Iteration-based alternative to the epoch budget above:
# train_step_unit, num_train_steps = 'iteration', int(9e3)

In [4]:
# --- Loss Specific Hyperparameters ---
# `loss_name` selects the outer loss; `base_loss_name` names the base loss it is configured with.
loss_name = 'group_theory'
base_loss_name = 'betavae'

# Default keyword arguments handed to losses.select(loss_name, ...).
# Keys and insertion order are part of the printed setup summary, so keep them stable.
group_loss_kwargs = dict(
    base_loss_name=base_loss_name,
    # Keyword arguments for the base loss itself.
    base_loss_kwargs=dict(
        rec_dist=rec_dist,
        latent_dim=latent_dim,
        beta=16,  # beta value used for the base loss in this notebook
    ),
    rec_dist=rec_dist,
    device=device,
    # Settings for the "commutative" and "meaningful" terms of the loss.
    commutative_weight=1.0,
    meaningful_weight=1.0,
    commutative_component_order=2,
    meaningful_component_order=1,
    meaningful_transformation_order=1,
    # Critic settings for the "meaningful" term.
    meaningful_critic_gradient_penalty_weight=10,
    meaningful_critic_lr=1e-4,
    meaningful_n_critic=5,
    deterministic_rep=False,
    log_kl_components=True,
)

## 3. Load Datasets

In [5]:
# Load 3D Shapes
# get_dataset returns the dataset class; it is instantiated with every factor selected.
Shapes3D = get_dataset("shapes3d")
shapes3d_dataset = Shapes3D(selected_factors='all', not_selected_factors_index_value=None)
num_workers_3dshapes = 4

# Page-locked (pinned) host memory only speeds up host-to-GPU copies; on a CPU-only run it
# is useless and PyTorch warns about it, so enable it only when `device` is CUDA.
shapes3d_dataloader = get_deterministic_dataloader(dataset=shapes3d_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers_3dshapes,
                                                   seed=seed,
                                                   pin_memory=(device.type == 'cuda'))

print(f"Loaded 3D Shapes dataset with {len(shapes3d_dataset)} samples.")

Loaded 3D Shapes dataset with 480000 samples.


In [6]:
# Load dSprites
# get_dataset returns the dataset class; it is instantiated with every factor selected.
Dsprites = get_dataset('dsprites')
dsprites_dataset = Dsprites(selected_factors='all', not_selected_factors_index_value=None)
num_workers_dsprites = 7

# Page-locked (pinned) host memory only speeds up host-to-GPU copies; on a CPU-only run it
# is useless and PyTorch warns about it, so enable it only when `device` is CUDA.
dsprites_dataloader = get_deterministic_dataloader(dataset=dsprites_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers_dsprites,
                                                   seed=seed,
                                                   pin_memory=(device.type == 'cuda'))

print(f"Loaded dSprites dataset with {len(dsprites_dataset)} samples.")

Loaded dSprites dataset with 737280 samples.


## 4. Setup Model, Loss, and Optimizer

In [7]:
def setup_components(dataset, specific_loss_kwargs):
    """Instantiate the model, loss function, and optimizer for one dataset.

    Reads the notebook-level configuration (``model_name``, ``latent_dim``,
    ``learning_rate``, ``loss_name``, ``group_loss_kwargs``).

    Args:
        dataset: Indexable, sized dataset. ``dataset[0][0].shape`` is used as the
            model's ``img_size`` and ``len(dataset)`` as ``n_data``.
        specific_loss_kwargs: Mapping of loss keyword arguments applied on top of
            ``group_loss_kwargs``. ``None`` or an empty mapping means "no overrides".

    Returns:
        Tuple ``(model, loss_fn, optimizer)``.
    """
    img_size = dataset[0][0].shape
    n_data = len(dataset)

    model = vae_models.select(name=model_name, img_size=img_size, latent_dim=latent_dim)

    # Start from the shared defaults, then layer the dataset-specific overrides on top.
    final_loss_kwargs = group_loss_kwargs.copy()
    final_loss_kwargs.update(specific_loss_kwargs or {})

    # dict.copy() is shallow: give this call its own nested dict so nothing done with it
    # downstream can leak back into the shared defaults or the caller's overrides.
    nested_base_kwargs = final_loss_kwargs.get('base_loss_kwargs')
    if isinstance(nested_base_kwargs, dict):
        final_loss_kwargs['base_loss_kwargs'] = dict(nested_base_kwargs)

    # The configuration key is 'base_loss_name' (see group_loss_kwargs); looking up
    # 'base_loss' never matched, so n_data was never supplied for beta-TCVAE.
    effective_base_loss = final_loss_kwargs.get('base_loss_name')
    effective_rec_dist = final_loss_kwargs.get('rec_dist')
    if effective_base_loss == 'betatcvae':
        # NOTE(review): n_data is passed at the top level as the original code intended;
        # confirm whether the loss expects it inside 'base_loss_kwargs' instead.
        final_loss_kwargs['n_data'] = n_data

    loss_fn = losses.select(loss_name, **final_loss_kwargs)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Report the values actually handed to the loss, not the notebook-level globals,
    # so the summary stays truthful when overrides are in play.
    print(f"--- Setup for {dataset.__class__.__name__} --- ")
    print(f"Model: {model.__class__.__name__}" )
    print(f"Loss: {loss_fn.__class__.__name__} (Base: {effective_base_loss}, rec_dist={effective_rec_dist}), kwargs={final_loss_kwargs}")
    print(f"Optimizer: {optimizer.__class__.__name__}")
    print("---------------------------")

    return model, loss_fn, optimizer

## 5. Train and Visualize

### 5.1 - 3D Shapes

In [8]:
# Loss settings for the 3D Shapes run, starting from the shared defaults.
# The nested 'base_loss_kwargs' dict is copied as well: a plain .copy() is shallow, so an
# edit such as shapes3d_specific_loss_kwargs['base_loss_kwargs']['beta'] = ... would
# otherwise also change group_loss_kwargs (and every other run built from it).
shapes3d_specific_loss_kwargs = {
    **group_loss_kwargs,
    'base_loss_kwargs': dict(group_loss_kwargs['base_loss_kwargs']),
}

In [9]:
print("\n===== Training on 3D Shapes =====")
# Build a fresh model, loss and optimizer for the 3D Shapes dataset.
model_3dshapes, loss_fn_3dshapes, optimizer_3dshapes = setup_components(shapes3d_dataset, shapes3d_specific_loss_kwargs)

# No learning-rate scheduler is used. Loss logging is requested (return_log_loss=True)
# with an 'epoch' logging interval type.
# NOTE(review): the interval type is fixed to 'epoch' even though train_step_unit is
# configurable -- confirm this is intended when switching to iteration-based training.
trainer_3dshapes = UnsupervisedTrainer(model=model_3dshapes,
                               loss_fn=loss_fn_3dshapes,
                               optimizer=optimizer_3dshapes,
                               lr_scheduler=None,
                               device=device,
                               train_step_unit=train_step_unit,
                               return_log_loss=True,
                               log_loss_interval_type='epoch'
                               )

# NOTE(review): max_steps is presumably counted in `train_step_unit` units
# (epochs with the current configuration) -- confirm against UnsupervisedTrainer.train.
train_logs_3dshapes = trainer_3dshapes.train(shapes3d_dataloader, max_steps=num_train_steps)
print("Training Logs (3D Shapes):", train_logs_3dshapes)


===== Training on 3D Shapes =====
--- Setup for Shapes3D --- 
Model: Model
Loss: Loss (Base: betavae, rec_dist=bernoulli), kwargs={'base_loss_name': 'betavae', 'base_loss_kwargs': {'rec_dist': 'bernoulli', 'latent_dim': 10, 'beta': 16}, 'rec_dist': 'bernoulli', 'device': device(type='cuda'), 'commutative_weight': 1.0, 'meaningful_weight': 1.0, 'commutative_component_order': 2, 'meaningful_component_order': 1, 'meaningful_transformation_order': 1, 'meaningful_critic_gradient_penalty_weight': 10, 'meaningful_critic_lr': 0.0001, 'meaningful_n_critic': 5, 'deterministic_rep': False, 'log_kl_components': True}
Optimizer: Adam
---------------------------


Epoch 1:   0%|          | 0/7500 [00:00<?, ?it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])


Epoch 1:   0%|          | 1/7500 [00:01<4:06:27,  1.97s/it]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 2/7500 [00:02<1:51:21,  1.12it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 4/7500 [00:02<47:56,  2.61it/s]  

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 5/7500 [00:02<36:38,  3.41it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 7/7500 [00:02<25:37,  4.87it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 8/7500 [00:02<23:22,  5.34it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 10/7500 [00:03<20:52,  5.98it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 11/7500 [00:03<19:45,  6.32it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 13/7500 [00:03<18:35,  6.71it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 14/7500 [00:03<18:25,  6.77it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 16/7500 [00:04<18:43,  6.66it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 17/7500 [00:04<19:15,  6.47it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 18/7500 [00:04<19:18,  6.46it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 20/7500 [00:04<19:09,  6.51it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 21/7500 [00:04<19:29,  6.40it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 22/7500 [00:05<19:14,  6.48it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 23/7500 [00:05<30:32,  4.08it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 24/7500 [00:06<40:59,  3.04it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 26/7500 [00:06<40:27,  3.08it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 28/7500 [00:07<28:47,  4.33it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 29/7500 [00:07<25:16,  4.93it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 30/7500 [00:07<22:43,  5.48it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 31/7500 [00:07<24:03,  5.17it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 32/7500 [00:07<29:40,  4.19it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 34/7500 [00:08<27:24,  4.54it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 36/7500 [00:08<22:01,  5.65it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   0%|          | 37/7500 [00:08<20:28,  6.08it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 39/7500 [00:08<18:38,  6.67it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 40/7500 [00:09<18:08,  6.85it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 42/7500 [00:09<18:15,  6.81it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 43/7500 [00:09<17:34,  7.07it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 45/7500 [00:09<17:19,  7.17it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 46/7500 [00:09<17:11,  7.23it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 48/7500 [00:10<16:58,  7.32it/s]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 50/7500 [00:10<17:11,  7.22it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 51/7500 [00:10<17:15,  7.19it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 52/7500 [00:10<17:19,  7.17it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 54/7500 [00:11<17:32,  7.08it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 56/7500 [00:11<17:15,  7.19it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 57/7500 [00:11<17:13,  7.20it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 59/7500 [00:11<17:36,  7.04it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 60/7500 [00:11<17:42,  7.00it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 61/7500 [00:12<18:40,  6.64it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 63/7500 [00:12<18:20,  6.76it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 64/7500 [00:12<23:35,  5.25it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 65/7500 [00:13<30:48,  4.02it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 66/7500 [00:13<28:21,  4.37it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 68/7500 [00:13<22:13,  5.57it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 69/7500 [00:13<20:41,  5.99it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 70/7500 [00:13<27:36,  4.48it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 71/7500 [00:14<28:23,  4.36it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 73/7500 [00:14<22:02,  5.62it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 74/7500 [00:14<20:04,  6.16it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 76/7500 [00:14<17:57,  6.89it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 78/7500 [00:15<17:00,  7.27it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 79/7500 [00:15<16:43,  7.40it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 81/7500 [00:15<16:16,  7.60it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 83/7500 [00:15<16:35,  7.45it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 84/7500 [00:15<16:37,  7.44it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 86/7500 [00:16<16:32,  7.47it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 87/7500 [00:16<16:26,  7.52it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 89/7500 [00:16<17:09,  7.20it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 90/7500 [00:16<18:44,  6.59it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 92/7500 [00:17<17:33,  7.03it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|          | 93/7500 [00:17<17:22,  7.10it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 95/7500 [00:17<16:56,  7.28it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 96/7500 [00:17<16:38,  7.42it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 98/7500 [00:17<16:56,  7.28it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 99/7500 [00:17<17:27,  7.07it/s, critic_loss=-47.8, g_commutative_loss=17.3, generator_loss=-17.9, kl_loss=1.08, loss=8.58e+3, rec_loss=8.56e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 101/7500 [00:18<17:31,  7.04it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 102/7500 [00:18<17:18,  7.13it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 103/7500 [00:18<17:12,  7.16it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 104/7500 [00:18<21:59,  5.61it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 105/7500 [00:19<27:30,  4.48it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 106/7500 [00:19<25:00,  4.93it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 108/7500 [00:19<20:55,  5.89it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 109/7500 [00:19<19:55,  6.18it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 111/7500 [00:20<19:05,  6.45it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   1%|▏         | 112/7500 [00:20<18:38,  6.60it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 114/7500 [00:20<17:20,  7.10it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 115/7500 [00:20<17:07,  7.19it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 117/7500 [00:20<16:53,  7.29it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 119/7500 [00:21<16:55,  7.27it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 120/7500 [00:21<16:47,  7.32it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 122/7500 [00:21<16:53,  7.28it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 123/7500 [00:21<16:54,  7.27it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 125/7500 [00:21<16:46,  7.33it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 126/7500 [00:22<16:49,  7.31it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 128/7500 [00:22<17:45,  6.92it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 129/7500 [00:22<17:36,  6.98it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 130/7500 [00:22<17:40,  6.95it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 132/7500 [00:22<17:22,  7.07it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 134/7500 [00:23<17:19,  7.09it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 135/7500 [00:23<17:19,  7.08it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 136/7500 [00:23<17:23,  7.06it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 137/7500 [00:23<17:23,  7.05it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 138/7500 [00:24<26:52,  4.56it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 139/7500 [00:24<35:31,  3.45it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 141/7500 [00:24<26:20,  4.66it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 142/7500 [00:24<23:45,  5.16it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 143/7500 [00:25<22:24,  5.47it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 144/7500 [00:25<27:42,  4.43it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 146/7500 [00:25<27:01,  4.54it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 147/7500 [00:26<23:49,  5.14it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 149/7500 [00:26<20:30,  5.97it/s, critic_loss=-79.6, g_commutative_loss=10.6, generator_loss=-29.6, kl_loss=0.239, loss=8.43e+3, rec_loss=8.45e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 151/7500 [00:26<18:36,  6.58it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 152/7500 [00:26<18:00,  6.80it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 154/7500 [00:26<17:15,  7.09it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 155/7500 [00:27<16:51,  7.26it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 157/7500 [00:27<16:19,  7.50it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 159/7500 [00:27<16:07,  7.59it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 160/7500 [00:27<16:23,  7.47it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 162/7500 [00:28<16:19,  7.49it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 163/7500 [00:28<16:11,  7.55it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 165/7500 [00:28<16:33,  7.38it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 166/7500 [00:28<16:36,  7.36it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 168/7500 [00:28<16:44,  7.30it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 169/7500 [00:28<16:37,  7.35it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 171/7500 [00:29<16:30,  7.40it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 173/7500 [00:29<16:16,  7.51it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 174/7500 [00:29<16:15,  7.51it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 176/7500 [00:29<16:28,  7.41it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 177/7500 [00:30<16:38,  7.33it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 179/7500 [00:30<17:19,  7.04it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 180/7500 [00:30<17:23,  7.01it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 181/7500 [00:30<18:45,  6.50it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 182/7500 [00:31<26:17,  4.64it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 183/7500 [00:31<25:49,  4.72it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 185/7500 [00:31<20:33,  5.93it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   2%|▏         | 186/7500 [00:31<18:58,  6.43it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 188/7500 [00:31<17:37,  6.92it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 189/7500 [00:32<17:15,  7.06it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 191/7500 [00:32<16:49,  7.24it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 193/7500 [00:32<16:30,  7.38it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 194/7500 [00:32<16:25,  7.41it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 196/7500 [00:32<16:10,  7.52it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 197/7500 [00:33<16:17,  7.47it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 199/7500 [00:33<16:23,  7.43it/s, critic_loss=-76.3, g_commutative_loss=12.7, generator_loss=-18.1, kl_loss=0.881, loss=8.17e+3, rec_loss=8.16e+3]

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

Epoch 1:   3%|▎         | 201/7500 [00:33<16:25,  7.41it/s, critic_loss=-69.7, g_commutative_loss=7.55, generator_loss=-1.21, kl_loss=2.27, loss=7.8e+3, rec_loss=7.75e+3]  

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([6

                                                                                                                                                                          

variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])
variance_components.shape torch.Size([64, 10])
selected_indices.shape torch.Size([64, 1])
selected_variances.shape torch.Size([64, 1])
random_samples.shape torch.Size([64, 1])
transformation_parameters.shape torch.Size([64, 10])
transformation_values.shape torch.Size([64, 1])


KeyboardInterrupt: 

In [None]:
# Qualitative inspection of the model trained on 3D Shapes.
print("\n===== Visualizing 3D Shapes Results =====")
visualizer_3dshapes = utils.visualize.Visualizer(
    vae_model=model_3dshapes,
    dataset=shapes3d_dataset,
)

# Random reconstructions (called with 10 and mode='mean').
print("Plotting random reconstructions...")
visualizer_3dshapes.plot_random_reconstructions(10, mode='mean')
plt.show()

# Reconstructions for a fixed, hand-picked list of dataset indices.
indices_3dshapes = [5000, 6000, 7000, 100, 1000, 1024]
print("Plotting reconstructions from specific indices...")
visualizer_3dshapes.plot_reconstructions_sub_dataset(
    indices_3dshapes,
    mode='mean',
)
plt.show()

# Latent traversals (called with num_samples=20).
print("Plotting latent traversals...")
visualizer_3dshapes.plot_all_latent_traversals(num_samples=20)
plt.show()

### 5.1.1 Metric Evaluation (3D Shapes)

In [None]:
# Metric specifications shared by the metric-evaluation cells of this notebook.
# Each entry carries the metric's name and the keyword arguments stored under 'args'.
metrics_to_compute = [
    dict(name='dci_d', args=dict(num_train=5000, num_test=1000)),
    dict(name='mig', args=dict()),
]

In [None]:
# Aggregator used to evaluate the 3D Shapes model on the metrics listed in
# `metrics_to_compute`.
metric_aggregator_3dshapes = MetricAggregator(
    metrics=metrics_to_compute,
)

In [None]:
# Run every configured metric for the 3D Shapes model and report the results.
print("\n===== Computing Metrics for 3D Shapes =====")
metrics_results_3dshapes = metric_aggregator_3dshapes.compute(
    model=model_3dshapes,
    data_loader=shapes3d_dataloader,
    device=device,
)
print("3D Shapes Metrics:", metrics_results_3dshapes)

### 5.2 - dSprites

In [None]:
# dSprites-specific loss overrides handed to `setup_components`.
# Intentionally empty: no override is set here, so the values from
# `group_loss_kwargs` are presumably used as-is (verify in `setup_components`).
# Example of a disabled override: 'beta': 4
dsprites_specific_loss_kwargs = dict()

In [None]:
# Build model / loss / optimizer for dSprites, then train for `num_train_steps`
# steps (unit given by `train_step_unit`).
print("\n===== Training on dSprites =====")
(
    model_dsprites,
    loss_fn_dsprites,
    optimizer_dsprites,
) = setup_components(dsprites_dataset, dsprites_specific_loss_kwargs)

trainer_dsprites = UnsupervisedTrainer(
    model=model_dsprites,
    loss_fn=loss_fn_dsprites,
    optimizer=optimizer_dsprites,
    lr_scheduler=None,  # no learning-rate schedule is used
    device=device,
    train_step_unit=train_step_unit,
    return_log_loss=True,  # train() output is captured and printed below
    log_loss_interval_type='epoch',
)

train_logs_dsprites = trainer_dsprites.train(
    dsprites_dataloader,
    max_steps=num_train_steps,
)
print("Training Logs (dSprites):", train_logs_dsprites)

In [None]:
# Qualitative inspection of the model trained on dSprites.
print("\n===== Visualizing dSprites Results =====")
visualizer_dsprites = utils.visualize.Visualizer(
    vae_model=model_dsprites,
    dataset=dsprites_dataset,
)

# Random reconstructions (called with 10 and mode='mean').
print("Plotting random reconstructions...")
visualizer_dsprites.plot_random_reconstructions(10, mode='mean')
plt.show()

# Reconstructions for a fixed, hand-picked list of dataset indices.
# NOTE(review): 40000 and 50000 break the 100000-step pattern of the other
# entries -- confirm these are the intended indices.
indices_dsprites = [0, 100000, 200000, 300000, 40000, 50000]
print("Plotting reconstructions from specific indices...")
visualizer_dsprites.plot_reconstructions_sub_dataset(
    indices_dsprites,
    mode='mean',
)
plt.show()

# Latent traversals (called with num_samples=20).
print("Plotting latent traversals...")
visualizer_dsprites.plot_all_latent_traversals(num_samples=20)
plt.show()

### 5.2.1 Metric Evaluation (dSprites)

In [None]:
# Aggregator used to evaluate the dSprites model on the metrics listed in
# `metrics_to_compute`.
metric_aggregator_dsprites = MetricAggregator(
    metrics=metrics_to_compute,
)