# Group Theory VAE (BetaVAE Base) Testing Framework

In [3]:
from utils.reproducibility import set_deterministic_run, get_deterministic_dataloader

# Single seed for the whole notebook: it is passed to the global setup call
# below and to both dataloaders created later.
seed = 42
# Presumably seeds the Python/NumPy/PyTorch RNGs and enables deterministic
# backends -- TODO confirm in utils.reproducibility.
set_deterministic_run(seed=seed)

## 1. Imports

In [12]:
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils
from trainers import UnsupervisedTrainer
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers
from metrics.utils import MetricAggregator

## 2. Configuration

In [5]:
# --- General Hyperparameters ---
# Model architecture and reconstruction likelihood.
model_name = 'vae_locatello'
latent_dim = 10
rec_dist = 'bernoulli'

# Optimisation settings.
learning_rate = 1e-4
batch_size = 64

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Training length, expressed in the unit named by `train_step_unit`.
train_step_unit = 'epoch'
num_train_steps = 5

# Alternative: iteration-based training length.
# train_step_unit = 'iteration'
# num_train_steps = int(9e3)

In [6]:
# --- Loss Specific Hyperparameters ---
loss_name = 'group_theory'
base_loss_name = 'betavae'

# Default keyword arguments for the loss selected by `loss_name`.
# Built in three steps so each group of settings is easy to find; the
# resulting key order is: shared settings, group-theory settings, base-loss
# settings.
group_loss_kwargs = {
    'base_loss': base_loss_name,
    'rec_dist': rec_dist,
    'device': device,
}

# Settings for the commutative / meaningful terms and the critic.
group_loss_kwargs.update({
    'commutative_weight': 1.0,
    'meaningful_weight': 1.0,
    'commutative_component_order': 2,
    'meaningful_component_order': 1,
    'meaningful_transformation_order': 1,
    'meaningful_critic_gradient_penalty_weight': 10,
    'meaningful_critic_lr': 1e-4,
    'meaningful_n_critic': 5,
    'deterministic_rep': False,
})

# Base BetaVAE specific kwargs.
group_loss_kwargs.update({
    'beta': 4,  # default beta; a dataset section may override it
    'log_kl_components': True,
})

## 3. Load Datasets

In [7]:
# Load 3D Shapes: resolve the dataset class by name, instantiate it with all
# factors selected, and wrap it in a seeded, shuffling dataloader.
Shapes3D = get_dataset("shapes3d")
shapes3d_dataset = Shapes3D(
    selected_factors='all',
    not_selected_factors_index_value=None,
)
num_workers_3dshapes = 4

shapes3d_dataloader = get_deterministic_dataloader(
    dataset=shapes3d_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers_3dshapes,
    seed=seed,
    pin_memory=True,
)

print(f"Loaded 3D Shapes dataset with {len(shapes3d_dataset)} samples.")

Loaded 3D Shapes dataset with 480000 samples.


In [8]:
# Load dSprites: resolve the dataset class by name, instantiate it with all
# factors selected, and wrap it in a seeded, shuffling dataloader.
Dsprites = get_dataset('dsprites')
dsprites_dataset = Dsprites(
    selected_factors='all',
    not_selected_factors_index_value=None,
)
num_workers_dsprites = 7

dsprites_dataloader = get_deterministic_dataloader(
    dataset=dsprites_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers_dsprites,
    seed=seed,
    pin_memory=True,
)

print(f"Loaded dSprites dataset with {len(dsprites_dataset)} samples.")

Loaded dSprites dataset with 737280 samples.


## 4. Setup Model, Loss, and Optimizer

In [9]:
def setup_components(dataset, specific_loss_kwargs=None):
    """Instantiate the model, loss function, and optimizer for one dataset.

    Uses the notebook-level configuration (``model_name``, ``latent_dim``,
    ``learning_rate``, ``loss_name`` and ``group_loss_kwargs``).

    Args:
        dataset: Sized, indexable dataset. ``dataset[0][0].shape`` is passed
            to the model as ``img_size``.
        specific_loss_kwargs: Optional mapping of loss keyword arguments that
            override the defaults in ``group_loss_kwargs`` for this dataset.
            ``None`` or an empty mapping keeps the defaults.

    Returns:
        A ``(model, loss_fn, optimizer)`` tuple.
    """
    img_size = dataset[0][0].shape
    n_data = len(dataset)

    model = vae_models.select(name=model_name, img_size=img_size, latent_dim=latent_dim)

    # Work on a copy so per-dataset overrides never leak into the shared
    # defaults used by the other dataset sections.
    final_loss_kwargs = group_loss_kwargs.copy()
    final_loss_kwargs.update(specific_loss_kwargs or {})

    # Read the effective values from the merged kwargs: an override of
    # 'base_loss' or 'rec_dist' must be reflected both here and in the log.
    effective_base_loss = final_loss_kwargs.get('base_loss')
    effective_rec_dist = final_loss_kwargs.get('rec_dist')

    # The 'betatcvae' base loss is additionally given the dataset size.
    if effective_base_loss == 'betatcvae':
        final_loss_kwargs['n_data'] = n_data

    loss_fn = losses.select(loss_name, **final_loss_kwargs)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print(f"--- Setup for {dataset.__class__.__name__} --- ")
    print(f"Model: {model.__class__.__name__}")
    print(f"Loss: {loss_fn.__class__.__name__} (Base: {effective_base_loss}, rec_dist={effective_rec_dist}), kwargs={final_loss_kwargs}")
    print(f"Optimizer: {optimizer.__class__.__name__}")
    print("---------------------------")

    return model, loss_fn, optimizer

## 5. Train and Visualize

### 5.1 - 3D Shapes

In [10]:
# Loss overrides for 3D Shapes: replaces the default 'beta' from
# group_loss_kwargs when merged in setup_components.
shapes3d_specific_loss_kwargs = {'beta': 16}

In [None]:
# Build fresh components for 3D Shapes, train for `num_train_steps` units of
# `train_step_unit`, and print whatever logs train() returns.
print("\n===== Training on 3D Shapes =====")
model_3dshapes, loss_fn_3dshapes, optimizer_3dshapes = setup_components(
    shapes3d_dataset,
    shapes3d_specific_loss_kwargs,
)

trainer_3dshapes = UnsupervisedTrainer(
    model=model_3dshapes,
    loss_fn=loss_fn_3dshapes,
    optimizer=optimizer_3dshapes,
    lr_scheduler=None,
    device=device,
    train_step_unit=train_step_unit,
    return_log_loss=True,
    log_loss_interval_type='epoch',
)

train_logs_3dshapes = trainer_3dshapes.train(
    shapes3d_dataloader,
    max_steps=num_train_steps,
)
print("Training Logs (3D Shapes):", train_logs_3dshapes)

In [None]:
# Qualitative checks for the 3D Shapes model: random reconstructions,
# reconstructions of fixed dataset indices, and latent traversals.
print("\n===== Visualizing 3D Shapes Results =====")
visualizer_3dshapes = utils.visualize.Visualizer(
    vae_model=model_3dshapes,
    dataset=shapes3d_dataset,
)

# Fixed sample indices so the second figure is comparable across runs.
indices_3dshapes = [5000, 6000, 7000, 100, 1000, 1024]

print("Plotting random reconstructions...")
visualizer_3dshapes.plot_random_reconstructions(10, mode='mean')
plt.show()

print("Plotting reconstructions from specific indices...")
visualizer_3dshapes.plot_reconstructions_sub_dataset(
    indices_3dshapes,
    mode='mean',
)
plt.show()

print("Plotting latent traversals...")
visualizer_3dshapes.plot_all_latent_traversals(num_samples=20)
plt.show()

### 5.1.1 Metric Evaluation (3D Shapes)

In [None]:
# Metric specifications handed to MetricAggregator: one dict per metric with
# its name and the keyword arguments for that metric.
metrics_to_compute = [
    {
        'name': 'dci_d',
        'args': {'num_train': 5000, 'num_test': 1000},
    },
    {
        'name': 'mig',
        'args': {},
    },
]

In [None]:
# Aggregator for the 3D Shapes model, configured with the metric specs listed
# in metrics_to_compute.
metric_aggregator_3dshapes = MetricAggregator(metrics=metrics_to_compute)

In [None]:
# Evaluate the trained 3D Shapes model with the configured metrics and print
# the result returned by compute().
print("\n===== Computing Metrics for 3D Shapes =====")
metrics_results_3dshapes = metric_aggregator_3dshapes.compute(
    model=model_3dshapes,
    data_loader=shapes3d_dataloader,
    device=device,
)
print("3D Shapes Metrics:", metrics_results_3dshapes)

### 5.2 - dSprites

In [None]:
# No dSprites-specific overrides: the defaults in group_loss_kwargs apply
# unchanged (e.g. 'beta': 4).
dsprites_specific_loss_kwargs = {}

In [None]:
# Build fresh components for dSprites, train for `num_train_steps` units of
# `train_step_unit`, and print whatever logs train() returns.
print("\n===== Training on dSprites =====")
model_dsprites, loss_fn_dsprites, optimizer_dsprites = setup_components(
    dsprites_dataset,
    dsprites_specific_loss_kwargs,
)

trainer_dsprites = UnsupervisedTrainer(
    model=model_dsprites,
    loss_fn=loss_fn_dsprites,
    optimizer=optimizer_dsprites,
    lr_scheduler=None,
    device=device,
    train_step_unit=train_step_unit,
    return_log_loss=True,
    log_loss_interval_type='epoch',
)

train_logs_dsprites = trainer_dsprites.train(
    dsprites_dataloader,
    max_steps=num_train_steps,
)
print("Training Logs (dSprites):", train_logs_dsprites)

In [None]:
# Qualitative checks for the dSprites model: random reconstructions,
# reconstructions of fixed dataset indices, and latent traversals.
print("\n===== Visualizing dSprites Results =====")
visualizer_dsprites = utils.visualize.Visualizer(
    vae_model=model_dsprites,
    dataset=dsprites_dataset,
)

# Fixed sample indices so the second figure is comparable across runs.
# NOTE(review): 40000 and 50000 break the 100000-step pattern of the first
# four entries -- possibly meant 400000 / 500000; confirm before changing.
indices_dsprites = [0, 100000, 200000, 300000, 40000, 50000]

print("Plotting random reconstructions...")
visualizer_dsprites.plot_random_reconstructions(10, mode='mean')
plt.show()

print("Plotting reconstructions from specific indices...")
visualizer_dsprites.plot_reconstructions_sub_dataset(
    indices_dsprites,
    mode='mean',
)
plt.show()

print("Plotting latent traversals...")
visualizer_dsprites.plot_all_latent_traversals(num_samples=20)
plt.show()

### 5.2.1 Metric Evaluation (dSprites)

In [None]:
# Aggregator for the dSprites model. It reuses metrics_to_compute, which is
# defined in the 3D Shapes metric section, so that cell must have run first.
metric_aggregator_dsprites = MetricAggregator(metrics=metrics_to_compute)