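# VAE

Train a β-VAE (pythae `BetaVAE`) on the dummy data set, log the run to Weights & Biases, then evaluate the model on the test split with `MeanSquaredErrorMetric`.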

In [1]:
import os

from pythae.pipelines import TrainingPipeline
from pythae.trainers import BaseTrainerConfig
from pythae.trainers.training_callbacks import WandbCallback
from pythae.models import AutoModel, BetaVAE, BetaVAEConfig

from cnn_framework.dummy_vae_model.data_set import DummyVAEDataSet
from cnn_framework.dummy_vae_model.model_params import DummyVAEModelParams

from cnn_framework.dummy_vae_model.decoder import CustomDecoder
from cnn_framework.dummy_vae_model.encoder import CustomEncoder

from cnn_framework.utils.data_managers.default_data_manager import DefaultDataManager
from cnn_framework.utils.data_loader_generators.data_loader_generator import DataLoaderGenerator
from cnn_framework.utils.metrics.mean_squared_error_metric import MeanSquaredErrorMetric
from cnn_framework.utils.model_managers.model_manager import ModelManager
from cnn_framework.utils.model_managers.vae_model_manager import VAEModelManager
from cnn_framework.utils.create_dummy_data_set import generate_data_set

In [2]:
# Instantiate the run parameters and finalize them.
# NOTE(review): the "Model time id" / hyper-parameter summary lines in the recorded
# output are emitted by one of these two calls — which one is not visible from here.
params = DummyVAEModelParams()
params.update()

# Generate the dummy data set only when its directory does not exist yet
data_dir = params.data_dir
if not os.path.exists(data_dir):
    generate_data_set(data_dir)
    print(f"\nData set created in {data_dir}")

# Report where this run's artifacts will be written (one line per destination)
print(
    f"\nModel will be saved in {params.models_folder}",
    f"Predictions will be saved in {params.output_dir}",
    f"Tensorboard logs will be saved in {params.tensorboard_folder_path}",
    sep="\n",
)

Model time id: 20230908-155007-local
epochs 30 | batch 2 | lr 0.0001 | weight decay 0.05 | dropout 0.0 | latent dim 16 | beta 1 | gamma 0 | delta 0 | depth 5 | kld loss standard | encoder name timm-efficientnet-b0

Model will be saved in C:\Users\thoma\models/local/dummy_vae/20230908-155007-local
Predictions will be saved in C:\Users\thoma\predictions/local/dummy_vae/20230908-155007-local
Tensorboard logs will be saved in C:\Users\thoma\tensorboard/local/20230908-155007-local_dummy_vae


In [3]:
# Loader factory for the dummy VAE data set, driven by the run parameters
loader_generator = DataLoaderGenerator(
    params,
    DummyVAEDataSet,
    DefaultDataManager,
)

# The generator hands back the three loaders in train / val / test order
data_loaders = loader_generator.generate_data_loader()
train_dl, val_dl, test_dl = data_loaders

### Data source ###
train data is loaded from C:\Users\thoma\data\dummy - 80% elements
val data is loaded from C:\Users\thoma\data\dummy - 10% elements
test data is loaded from C:\Users\thoma\data\dummy - 10% elements
###################
train has 160 images.
val has 20 images.
test has 20 images.
###################


In [4]:
# Make sure the folder receiving the trained model exists
os.makedirs(params.models_folder, exist_ok=True)

# Trainer settings: everything tunable is read from `params`; only the
# ReduceLROnPlateau patience / factor are fixed in this cell.
trainer_settings = dict(
    output_dir=params.models_folder,
    num_epochs=params.num_epochs,
    learning_rate=params.learning_rate,
    per_device_train_batch_size=params.batch_size,
    per_device_eval_batch_size=params.batch_size,
    train_dataloader_num_workers=params.num_workers,
    eval_dataloader_num_workers=params.num_workers,
    steps_saving=None,
    optimizer_cls="AdamW",
    optimizer_params=dict(
        weight_decay=params.weight_decay,
        betas=(params.beta1, params.beta2),
    ),
    scheduler_cls="ReduceLROnPlateau",
    scheduler_params=dict(patience=5, factor=0.5),
)
my_training_config = BaseTrainerConfig(**trainer_settings)

# Model settings: the first input axis holds one entry per
# (c_indexes, z_indexes) combination, followed by height and width.
nb_input_planes = len(params.c_indexes) * len(params.z_indexes)
input_shape = (
    nb_input_planes,
    params.input_dimensions.height,
    params.input_dimensions.width,
)
my_vae_config = BetaVAEConfig(
    reconstruction_loss=params.reconstruction_loss,
    input_dim=input_shape,
    latent_dim=params.latent_dim,
    beta=params.beta,
    uses_default_decoder=False,
    uses_default_encoder=False,
)

# Build the model: from scratch with the custom encoder / decoder,
# or from a previously saved folder when a pretrained path is configured.
if not params.model_pretrained_path:
    encoder = CustomEncoder(params, my_vae_config)
    nb_encoder_params = sum(weights.numel() for weights in encoder.parameters())
    print(f"Number of parameters in encoder: {nb_encoder_params}")
    decoder = CustomDecoder(params, my_vae_config)
    nb_decoder_params = sum(weights.numel() for weights in decoder.parameters())
    print(f"Number of parameters in decoder: {nb_decoder_params}")
    vae_model = BetaVAE(encoder=encoder, decoder=decoder, model_config=my_vae_config)
else:
    vae_model = AutoModel.load_from_folder(params.model_pretrained_path)
    # Overwrite the modifiable settings of the loaded model with the current ones
    vae_model.model_config = my_vae_config
    vae_model.beta = my_vae_config.beta

# Build the training pipeline around the model
pipeline = TrainingPipeline(model=vae_model, training_config=my_training_config)

# Compute and save mean / std on the train and val loaders for future normalization
model_manager = ModelManager(vae_model, params, None)
model_manager.compute_and_save_mean_std(train_dl, val_dl)

# Transforms are initialized only once the statistics above have been saved
# (presumably so that they pick up the fresh mean / std — confirm in the data set class)
for data_loader in (train_dl, val_dl):
    data_loader.dataset.initialize_transforms()

# Create your callbacks: the TrainingPipeline expects them as a list
callbacks = []
wandb_cb = WandbCallback()
# Bind the callback to this run's configuration and wandb destination
wandb_cb.setup(
    training_config=my_training_config,
    model_config=my_vae_config,
    project_name=params.wandb_project,  # wandb project
    entity_name=params.wandb_entity,  # wandb entity
    run_name=params.format_now,  # name of the run
)
callbacks.append(wandb_cb)

Number of parameters in encoder: 4171420
Number of parameters in decoder: 2498675
Current commit hash: 589eca1e521ba7aa867d3d42d738cdbe2fda362f


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthomas-bonte[0m ([33mcbio-bis[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Launch the Pipeline
# The pipeline receives the underlying datasets, not the loaders themselves
# (it accepts torch.Tensor, np.array or torch datasets).
training_set = train_dl.dataset
validation_set = val_dl.dataset
pipeline(train_data=training_set, eval_data=validation_set, callbacks=callbacks)

Checking train dataset...
Checking eval dataset...
Using Base Trainer

Model passed sanity check !
Ready for training.

Created C:\Users\thoma\models/local/dummy_vae/20230908-155007-local\BetaVAE_training_2023-09-08_15-50-33. 
Training config, checkpoints and final model will be saved here.

Training params:
 - max_epochs: 30
 - per_device_train_batch_size: 2
 - per_device_eval_batch_size: 2
 - checkpoint saving every: None
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.91, 0.995)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 0.05
)
Scheduler: <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x000001B5164D73D0>

Successfully launched training !



Training of epoch 1/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 1/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 39875.8289
Eval loss: 33147.3303
--------------------------------------------------------------------------


Training of epoch 2/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 2/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 27631.5861
Eval loss: 35421.3955
--------------------------------------------------------------------------


Training of epoch 3/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 3/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 23508.8734
Eval loss: 34472.6119
--------------------------------------------------------------------------


Training of epoch 4/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 4/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 21705.2487
Eval loss: 29933.3256
--------------------------------------------------------------------------


Training of epoch 5/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 5/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 21443.8914
Eval loss: 29437.7758
--------------------------------------------------------------------------


Training of epoch 6/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 6/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 20599.7188
Eval loss: 23193.0669
--------------------------------------------------------------------------


Training of epoch 7/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 7/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 19253.1897
Eval loss: 22259.6152
--------------------------------------------------------------------------


Training of epoch 8/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 8/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 20406.5709
Eval loss: 23432.1786
--------------------------------------------------------------------------


Training of epoch 9/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 9/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 19568.0903
Eval loss: 21097.9493
--------------------------------------------------------------------------


Training of epoch 10/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 10/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 19367.0471
Eval loss: 22287.7409
--------------------------------------------------------------------------


Training of epoch 11/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 11/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 18086.2306
Eval loss: 17234.4842
--------------------------------------------------------------------------


Training of epoch 12/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 12/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 18310.3057
Eval loss: 19190.3052
--------------------------------------------------------------------------


Training of epoch 13/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 13/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 18294.8927
Eval loss: 17747.0697
--------------------------------------------------------------------------


Training of epoch 14/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 14/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 17621.7826
Eval loss: 19643.9727
--------------------------------------------------------------------------


Training of epoch 15/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 15/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 18089.1695
Eval loss: 21676.2707
--------------------------------------------------------------------------


Training of epoch 16/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 16/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 16826.1758
Eval loss: 18468.9005
--------------------------------------------------------------------------


Training of epoch 17/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 17/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 18016.539
Eval loss: 19676.1385
--------------------------------------------------------------------------


Training of epoch 18/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 18/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 16319.8229
Eval loss: 18066.4786
--------------------------------------------------------------------------


Training of epoch 19/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 19/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15745.186
Eval loss: 18086.8271
--------------------------------------------------------------------------


Training of epoch 20/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 20/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 16359.6618
Eval loss: 16573.503
--------------------------------------------------------------------------


Training of epoch 21/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 21/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15718.5758
Eval loss: 17498.7309
--------------------------------------------------------------------------


Training of epoch 22/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 22/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15340.4717
Eval loss: 19182.915
--------------------------------------------------------------------------


Training of epoch 23/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 23/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15301.5871
Eval loss: 18563.6477
--------------------------------------------------------------------------


Training of epoch 24/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 24/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15485.1641
Eval loss: 18199.0125
--------------------------------------------------------------------------


Training of epoch 25/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 25/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15789.922
Eval loss: 16344.8273
--------------------------------------------------------------------------


Training of epoch 26/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 26/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 16144.8122
Eval loss: 16302.5508
--------------------------------------------------------------------------


Training of epoch 27/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 27/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 14940.17
Eval loss: 16850.3422
--------------------------------------------------------------------------


Training of epoch 28/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 28/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 15189.2762
Eval loss: 15290.8796
--------------------------------------------------------------------------


Training of epoch 29/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 29/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 14688.3155
Eval loss: 15747.205
--------------------------------------------------------------------------


Training of epoch 30/30:   0%|          | 0/80 [00:00<?, ?batch/s]

Eval of epoch 30/30:   0%|          | 0/10 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 14475.4916
Eval loss: 17511.2616
--------------------------------------------------------------------------
Training ended!
Saved final model in C:\Users\thoma\models/local/dummy_vae/20230908-155007-local\BetaVAE_training_2023-09-08_15-50-33\final_model


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/epoch_loss,▇██▆▆▄▃▄▃▃▂▂▂▃▃▂▃▂▂▁▂▂▂▂▁▁▂▁▁▂
train/epoch_loss,█▅▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
train/global_step,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███

0,1
eval/epoch_loss,17511.26162
train/epoch_loss,14475.49155
train/global_step,30.0


In [6]:
# Test and save images: wrap the model in a VAE manager scored with
# mean squared error, then run prediction over the test loader.
evaluation_metric = MeanSquaredErrorMetric
manager = VAEModelManager(
    vae_model,
    params,
    evaluation_metric,
)
manager.predict(test_dl)

Current commit hash: 589eca1e521ba7aa867d3d42d738cdbe2fda362f
Model evaluation in progress: 100.0% | Batch #9
Average MeanSquaredError: -0.37
