# Explore your model with Wandb

This tutorial walks through a simple example of how to monitor training and validation with WandB.

First, if you don't have wandb installed yet, follow the installation instructions:

## Install wandb
1. Install wandb ```$ pip install wandb```
2. Create a wandb account [online](https://wandb.ai/)
3. Once you are logged in, go to this [page](https://wandb.ai/authorize) and copy the API key. 
4. In your terminal, enter ```$ wandb login``` and then paste your API key when prompted.
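
Before moving on, it can help to confirm that the setup above worked. The following is a minimal sketch using only the standard library. It assumes that wandb also accepts the API key through the `WANDB_API_KEY` environment variable as an alternative to `wandb login`; the helper name `wandb_setup_status` is hypothetical.

```python
import importlib.util
import os


def wandb_setup_status(env=None):
    """Return (installed, has_env_key) without importing wandb itself."""
    env = os.environ if env is None else env
    # find_spec returns None when the package cannot be located
    installed = importlib.util.find_spec("wandb") is not None
    # Assumption: an API key may be supplied via WANDB_API_KEY
    has_env_key = bool(env.get("WANDB_API_KEY"))
    return installed, has_env_key


installed, has_env_key = wandb_setup_status()
print(f"wandb installed: {installed}")
print(f"WANDB_API_KEY set: {has_env_key}")
```

A `False` for the environment variable does not mean you are logged out, since `wandb login` stores the key on disk rather than in the environment.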

## Define your dataset and model

First define the dataset and model you want to use.

In [1]:
from multivae.data.datasets.mnist_labels import MnistLabels

# Import the dataset
DATA_PATH = "./data"  # Set the path where to download the data
dataset = MnistLabels(DATA_PATH, "test", download=True)  # Set download to True

In [2]:
# Import the model of your choice
from multivae.models import MVTCAE, MVTCAEConfig

In [3]:
# Define the model configuration

model_config = MVTCAEConfig(
    n_modalities=2,
    latent_dim=20,
    input_dims={"images": (1, 28, 28), "labels": (1, 10)},
    decoders_dist={
        "images": "normal",
        "labels": "categorical",
    },  # Distributions to use for the decoders. It defines the reconstruction loss.
    alpha=2.0 / 3.0,  # hyperparameters specific to this model
    beta=2.5,
    uses_likelihood_rescaling=True,  # rescale the reconstruction loss for better results
    rescale_factors=dict(images=1, labels=50),
)
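
To get a feel for why `rescale_factors` weights the labels more heavily than the images, it helps to compare the size of each modality. The sketch below only does arithmetic on the `input_dims` shown above. The interpretation (that a larger modality tends to dominate a summed reconstruction loss) is a common motivation for rescaling and is an assumption here, not a description of MultiVAE internals.

```python
import math

# Values copied from the model configuration above
input_dims = {"images": (1, 28, 28), "labels": (1, 10)}
rescale_factors = {"images": 1, "labels": 50}

# Number of scalar values per sample in each modality
n_values = {name: math.prod(shape) for name, shape in input_dims.items()}
print(n_values)  # {'images': 784, 'labels': 10}

# How many times more values an image carries than a label vector
size_ratio = n_values["images"] / n_values["labels"]
print(size_ratio)  # 78.4
```

Under that assumption, a factor of 50 for the labels is of the same order of magnitude as the size ratio between the two modalities.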

In [4]:
# Initialize the model

# If no encoder/decoder architectures are specified, default MLPs are used
model = MVTCAE(model_config=model_config)

## Create a wandb callback and pass it to your trainer

In [5]:
from multivae.trainers import BaseTrainer, BaseTrainerConfig
from multivae.trainers.base.callbacks import WandbCallback

# Define the training configuration
trainer_config = BaseTrainerConfig(
    num_epochs=30,
    learning_rate=1e-2,
    optimizer_cls="Adam",
    output_dir="dummy_output_dir",
    steps_predict=5,  # !! set this argument to log generated images to wandb every 5 epochs !!
)

# !Define your wandb callback!
wandb_cb = WandbCallback()
# Pass the training config and model config
wandb_cb.setup(
    training_config=trainer_config,
    model_config=model_config,
    project_name="wandb_notebook",
)

# Define the trainer
trainer = BaseTrainer(
    model=model,
    training_config=trainer_config,
    train_dataset=dataset,
    callbacks=[wandb_cb],  ## !!! Pass the callback to the trainer !!!
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: asenellart (multimodal_vaes) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


! No eval dataset provided ! -> keeping best model on train.

Model passed sanity check !
Ready for training.

Setting the optimizer with learning rate 0.01
Created dummy_output_dir/MVTCAE_training_2025-03-14_16-05-17. 
Training config, checkpoints and final model will be saved here.
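
Conceptually, a callback is an object whose hooks the trainer calls at fixed points of the training loop. The toy sketch below illustrates that pattern in plain Python. `ToyTrainer`, `RecordingCallback`, and `on_epoch_end` are hypothetical names chosen for illustration and are not the MultiVAE or wandb API.

```python
class RecordingCallback:
    """Hypothetical callback that stores the metrics it receives."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, metrics):
        # A real logging callback would forward these values to a tracker
        self.history.append((epoch, dict(metrics)))


class ToyTrainer:
    """Hypothetical trainer that notifies its callbacks after each epoch."""

    def __init__(self, num_epochs, callbacks=()):
        self.num_epochs = num_epochs
        self.callbacks = list(callbacks)

    def train(self):
        for epoch in range(1, self.num_epochs + 1):
            loss = 1.0 / epoch  # stand-in for a real training epoch
            for cb in self.callbacks:
                cb.on_epoch_end(epoch, {"train/epoch_loss": loss})


recorder = RecordingCallback()
ToyTrainer(num_epochs=3, callbacks=[recorder]).train()
print(recorder.history[-1])  # (3, {'train/epoch_loss': 0.3333333333333333})
```

Keeping the logging logic in a separate object means the training loop itself does not need to know which tracker, if any, is in use.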



In [6]:
# Now we train:

trainer.train()

Training params:
 - max_epochs: 30
 - per_device_train_batch_size: 64
 - per_device_eval_batch_size: 64
 - checkpoint saving every: None
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)
Scheduler: None

Successfully launched training !

Training of epoch 1/30: 100%|██████████| 157/157 [00:03<00:00, 42.53batch/s]
New best model on train saved!
--------------------------------------------------------------------------
Train loss: 564.5122
--------------------------------------------------------------------------
Training of epoch 2/30: 100%|██████████| 157/157 [00:01<00:00, 130.94batch/s]
New best model on train saved!
-----------------

Run history:

| Metric | History |
|---|---|
| train/epoch_loss | █▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁ |
| train/global_step | ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇▇▇▇█████ |
| train/images | █▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train/joint_divergence | ▆█▇▅▅▄▃▄▄▃▃▃▃▃▄▃▂▃▃▃▂▃▂▂▃▃▁▃▃▃ |
| train/kld_images | ▆▅▄▃▃▂▃▂▂▂▁▂▃▂▂▂▃▂▄▁▂▂▂▂█▇▄▅▄▅ |
| train/kld_labels | █▇▆▅▅▅▄▄▄▄▄▄▃▅▄▃▄▅▂▆▄▅▃▆▁▃▃▄▄▄ |
| train/labels | █▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |

Run summary:

| Metric | Final value |
|---|---|
| train/epoch_loss | 557.54183 |
| train/global_step | 30.0 |
| train/images | 47707.08203 |
| train/joint_divergence | 276.43442 |
| train/kld_images | 260.86557 |
| train/kld_labels | 116.92857 |
| train/labels | 4743.45508 |


## Compute metrics and log to WandB

When you compute metrics after training, you can log the results to the same wandb run by passing its path to the evaluator.
If you reload your model in a different script and don't know where to find the wandb path, see [this page](https://multivae.readthedocs.io/en/latest/metrics/info_wandb.html).
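
A wandb run path is a string of the form `entity/project/run_id`. The sketch below splits such a string into its parts, which can be handy for checking a path copied from the web interface. The helper `split_wandb_path` and the example values are hypothetical.

```python
from typing import NamedTuple


class WandbPath(NamedTuple):
    entity: str
    project: str
    run_id: str


def split_wandb_path(path: str) -> WandbPath:
    """Split 'entity/project/run_id' into named parts."""
    parts = path.strip("/").split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected 'entity/project/run_id', got {path!r}")
    return WandbPath(*parts)


# Made-up values for illustration only
print(split_wandb_path("my_team/wandb_notebook/abc123"))
```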

In [8]:
from multivae.models.auto_model import AutoModel

# reload the best model
best_model = AutoModel.load_from_folder(
    f"{trainer.training_dir}/final_model"
)  # Copy the path to final model.

In [9]:
from multivae.metrics.likelihoods import (
    LikelihoodsEvaluator,
    LikelihoodsEvaluatorConfig,
)

# here we get the path from the wandb_cb object that we created earlier
wandb_path = wandb_cb.run.path

ll_config = LikelihoodsEvaluatorConfig(
    batch_size=128,
    num_samples=100,
    wandb_path=wandb_path,  # ! pass the wandb_path here !
)

ll = LikelihoodsEvaluator(best_model, dataset, eval_config=ll_config)

ll.eval()  # might take some time
ll.finish()  # to finish the wandb run

100%|██████████| 79/79 [00:10<00:00,  7.53it/s]
Mean Joint likelihood : tensor(752.7980)


ModelOutput([('joint_likelihood', tensor(752.7980))])
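
The `num_samples=100` setting suggests a Monte Carlo estimate of the likelihood. A common way to form such an estimate for VAEs is importance sampling, where per-sample log-weights are combined with a numerically stable log-mean-exp. The sketch below shows only that combination step in plain Python. Whether `LikelihoodsEvaluator` works exactly this way is an assumption, and `log_mean_exp` is a hypothetical helper.

```python
import math


def log_mean_exp(log_weights):
    """Compute log(mean(exp(w))) without overflow or underflow."""
    m = max(log_weights)
    # Subtracting the maximum keeps every exponent at or below zero
    total = sum(math.exp(w - m) for w in log_weights)
    return m + math.log(total / len(log_weights))


# Naively exponentiating -1000 underflows to 0; the shifted form does not
print(log_mean_exp([-1000.0, -1000.0]))  # -1000.0
```

Under this assumption, a larger `num_samples` gives a tighter estimate at the cost of a longer evaluation.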