In [1]:
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils.visualize
from trainers import UnsupervisedTrainer
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers

## 2. Configuration

In [3]:
# --- General Hyperparameters ---
model_name = 'toroidal_vae_burgess'  # Name of the model architecture file (e.g., 'vae_burgess')
latent_factor_num = 10
learning_rate = 1e-4
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rec_dist = 'bernoulli'  # Reconstruction distribution (e.g., 'bernoulli', 'gaussian')

# train_step_unit = 'iteration'  # Unit for training steps ('epoch' or 'iteration')
# num_train_steps = int(3e5)  # Number of training steps 

train_step_unit = 'epoch'  # Unit for training steps ('epoch' or 'iteration')
num_train_steps = 5  # Number of training steps (epochs)

In [4]:
# --- Loss Specific Hyperparameters ---
# Beta VAE
loss_name = 'beta_toroidal_vae' 
loss_kwargs = {
    'beta': 16,
    'rec_dist': rec_dist
               }

In [None]:
# # AnnealedVAE
# loss_name = 'annealedvae'  
# loss_kwargs = {
#     'C_init': 0.0,
#     'C_fin_3dshapes': 25.0,
#     'C_fin_dsprites': 15.0,
#     'gamma_annealed': 100.0,
#     'anneal_steps': 10000,
#     'rec_dist': rec_dist,
# }

In [None]:
# # BetaTCVAE
# loss_name = 'betatcvae'
# loss_kwargs = {
#     'alpha_tc': 1.0,
#     'beta_tc_3dshapes': 6.0,
#     'beta_tc_dsprites': 4.0,
#     'gamma_tc': 1.0,
#     'is_mss': True,
#     'rec_dist': rec_dist
# }

In [None]:
# # Factor VAE

# loss_name = 'factorvae'
# loss_kwargs = {
#     'device': device,
#     'gamma': 6.4,
#     'discr_lr': 5e-5,
#     'discr_betas': (0.5, 0.9),
#     'rec_dist': rec_dist,
# }

## 3. Load Datasets

In [5]:
# Load 3D Shapes
Shapes3D = get_dataset("shapes3d")
shapes3d_dataset = Shapes3D(selected_factors='all', not_selected_factors_index_value=None)
# num_workers_3dshapes = find_optimal_num_workers(shapes3d_dataset, batch_size=batch_size, num_batches_to_test='all')
num_workers_3dshapes = 4

shapes3d_dataloader = torch.utils.data.DataLoader(shapes3d_dataset, batch_size=batch_size, num_workers=num_workers_3dshapes, shuffle=True, pin_memory=True)
print(f"Loaded 3D Shapes dataset with {len(shapes3d_dataset)} samples.")

# Load dSprites
Dsprites = get_dataset('dsprites')

dsprites_dataset = Dsprites(selected_factors='all', not_selected_factors_index_value=None)
# num_workers_dsprites = find_optimal_num_workers(dsprites_dataset, batch_size=batch_size, num_batches_to_test='all')
num_workers_dsprites = 7

dsprites_dataloader = torch.utils.data.DataLoader(dsprites_dataset, batch_size=batch_size, num_workers=num_workers_dsprites, shuffle=True, pin_memory=True)
print(f"Loaded dSprites dataset with {len(dsprites_dataset)} samples.")

Loaded 3D Shapes dataset with 480000 samples.
Loaded dSprites dataset with 737280 samples.


## 4. Setup Model, Loss, and Optimizer

In [8]:
def setup_components(dataset, loss_kwargs):
    """Instantiates model, loss function, and optimizer based on config."""
    img_size = dataset[0][0].shape
    n_data = len(dataset)
    

    # Instantiate Model
    model = vae_models.select(name=model_name, img_size=img_size, latent_factor_num=latent_factor_num)

    if loss_name == 'betatcvae':
        loss_kwargs['n_data'] = n_data
    
    loss_fn = losses.select(loss_name, **loss_kwargs)

    # Instantiate Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print(f"--- Setup for {dataset.__class__.__name__} --- ")
    print(f"Model: {model.__class__.__name__}")
    print(f"Loss: {loss_fn.__class__.__name__} (rec_dist={rec_dist}), kwargs={loss_kwargs}")
    print(f"Optimizer: {optimizer.__class__.__name__}")
    print(f"---------------------------")

    return model, loss_fn, optimizer

## 5. Train and Visualize

### 5.1 3D Shapes

In [9]:
shapes3d_loss_kwargs =  {
     'beta': 16,
     'rec_dist': rec_dist
               }

In [10]:
print("\n===== Training on 3D Shapes =====")
model_3dshapes, loss_fn_3dshapes, optimizer_3dshapes = setup_components(shapes3d_dataset, shapes3d_loss_kwargs)

trainer_3dshapes = UnsupervisedTrainer(model=model_3dshapes,
                                      loss_fn=loss_fn_3dshapes,
                                      scheduler=None,
                                      optimizer=optimizer_3dshapes,
                                      device=device,
                                      train_step_unit=train_step_unit,
                                      )

trainer_3dshapes.train(shapes3d_dataloader, num_train_steps)


===== Training on 3D Shapes =====


ImportError: cannot import name 'kl_toroidal_loss' from 'losses.s_vae.kl_div' (/notebooks/losses/s_vae/kl_div.py)