# Training an MHVAE model on PolyMNIST

In this notebook, we demonstrate how to define and train an MHVAE model on PolyMNIST. 

We refer to the following diagram when setting up the architectures: 

![image](../../docs/source/models/multimodal_vaes/mhvae_architectures.png)

In [1]:
# Dataset class and the helper used to split it
from multivae.data.datasets import MMNISTDataset
from torch.utils.data import random_split

# Location handed to the dataset class (adapt it to your own machine).
DATA_PATH = "/home/asenella/data"

# Build the full dataset, then keep 80% for training and 20% for evaluation.
full_set = MMNISTDataset(DATA_PATH)
train_set, eval_set = random_split(full_set, lengths=[0.8, 0.2])

First we need to define the architectures we are going to use. 

Here we use 3 levels of latent variables (`n_latent=3` in the model configuration below). We use the same architectures for all modalities. 
All bottom-up and top-down blocks are convolutional networks. 
We keep track of the input and output dimensions of each module below to make sure they match. 

latent_dim = 64

|block | input_dim | output_dim |
|-----|------------|------------|
|encoder|(3,28,28)| (32,14,14)  |
|bottom-up_1 | (32,14,14)|(64,7,7)|
|bottom-up_2 | (64,7,7)|latent_dim|
|top-down_2 |latent_dim|(64,7,7)|
|top-down_1 |(64,7,7)|(32,14,14)|
|decoder|(32,14,14)|(3,28,28)|
|prior_block_2|(64,7,7)|(64,7,7)|
|prior_block_1|(32,14,14)|(32,14,14)|
|posterior_block_2|(2*64,7,7)|(64,7,7)|
|posterior_block_1|(2*32,14,14)|(32,14,14)|


In [2]:
from multivae.models.mhvae import MHVAEConfig, MHVAE

# Model configuration: five modalities named "m0" ... "m4", each declared
# with input shape (3, 28, 28).
modality_shapes = {f"m{i}": (3, 28, 28) for i in range(5)}

model_config = MHVAEConfig(
    n_modalities=5,
    latent_dim=64,
    input_dims=modality_shapes,
    n_latent=3,
    beta=1,
)

In [3]:
from multivae.models.base import BaseEncoder, ModelOutput, BaseDecoder
from torch import nn

# Defining encoder and bottom-up blocks


class my_input_encoder(BaseEncoder):
    """Input encoder made of one strided convolution and a SiLU activation.

    The convolution maps 3 input channels to 32 channels with a 3x3 kernel,
    stride 2 and padding 1.
    """

    def __init__(self):
        super().__init__()

        self.conv0 = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
        )
        self.act_1 = nn.SiLU()

    def forward(self, x):
        """Return a ``ModelOutput`` whose ``embedding`` is ``act_1(conv0(x))``."""
        features = self.act_1(self.conv0(x))
        return ModelOutput(embedding=features)


# First bottom-up block: a 3x3 convolution with stride 2 and padding 1
# (32 -> 64 channels) followed by a SiLU activation.
bu_1 = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=True),
    nn.SiLU(),
)


class bu_2(BaseEncoder):
    """Last bottom-up block: conv features -> mean and log-variance vectors.

    A strided 3x3 convolution and a SiLU are followed by a flatten and a
    512-unit linear layer with ReLU. Two linear heads then produce the
    ``embedding`` and ``log_covariance`` fields of the returned
    ``ModelOutput``.

    Args:
        inchannels: number of channels of the input feature maps.
        outchannels: number of channels produced by the convolution.
        latent_dim: size of the two output vectors.
    """

    def __init__(self, inchannels, outchannels, latent_dim):
        super().__init__()

        # With 7x7 input maps (the shape used in this notebook), the stride-2
        # convolution below yields 4x4 maps, so the flattened size is
        # outchannels * 4 * 4. Deriving it from `outchannels` (instead of a
        # fixed 2048) keeps the block valid for values other than 128.
        flat_features = outchannels * 4 * 4

        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=inchannels,
                out_channels=outchannels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
            ),
            nn.SiLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
        )

        self.mu = nn.Linear(512, latent_dim)
        self.log_var = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Run the shared trunk, then both heads, on ``x``."""
        h = self.network(x)
        return ModelOutput(embedding=self.mu(h), log_covariance=self.log_var(h))


# Defining top-down blocks and decoder


class td_2(nn.Module):
    """Top-down block that turns a latent vector into feature maps.

    A linear layer with ReLU expands the input to 2048 features, which are
    viewed as 128 maps of size 4x4 and passed through a transposed
    convolution (128 -> 64 channels, 3x3 kernel, stride 2, padding 1) and a
    SiLU.

    Args:
        latent_dim: size of the last dimension of the input.
    """

    def __init__(self, latent_dim):
        super().__init__()

        expand = nn.Linear(in_features=latent_dim, out_features=2048)
        self.linear = nn.Sequential(expand, nn.ReLU())

        upsample = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
        )
        self.convs = nn.Sequential(upsample, nn.SiLU())

    def forward(self, x):
        """Apply the linear expansion, reshape to (N, 128, 4, 4), then upsample."""
        flat = self.linear(x)
        feature_maps = flat.view(flat.size(0), 128, 4, 4)
        return self.convs(feature_maps)


# Top-down block: a transposed 3x3 convolution (64 -> 32 channels, stride 2,
# padding 1, output_padding 1) followed by a SiLU activation.
td_1 = nn.Sequential(
    nn.ConvTranspose2d(
        in_channels=64,
        out_channels=32,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
        bias=True,
    ),
    nn.SiLU(),
)


class my_input_decoder(BaseDecoder):
    """Decoder made of two transposed convolutions ending in a sigmoid.

    The first layer keeps 32 channels and uses stride 2 (with
    output_padding 1); the second maps 32 channels to 3 with stride 1.
    """

    def __init__(self):
        super().__init__()

        upsample = nn.ConvTranspose2d(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        to_rgb = nn.ConvTranspose2d(
            in_channels=32,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            output_padding=0,
        )
        self.network = nn.Sequential(upsample, nn.SiLU(), to_rgb, nn.Sigmoid())

    def forward(self, x):
        """Return a ``ModelOutput`` whose ``reconstruction`` is ``network(x)``."""
        decoded = self.network(x)
        return ModelOutput(reconstruction=decoded)


# Defining prior blocks and posterior blocks


class prior_block(BaseEncoder):
    """Prior block built from two weight-normalised 1x1 convolutions.

    Both convolutions keep the number of channels unchanged; one produces
    the ``embedding`` and the other the ``log_covariance`` of the output.

    Args:
        n_channels: number of input (and output) channels.
    """

    def __init__(self, n_channels):
        super().__init__()

        weight_norm = nn.utils.parametrizations.weight_norm

        self.mu = weight_norm(
            nn.Conv2d(n_channels, n_channels, kernel_size=1, stride=1, padding=0)
        )
        self.logvar = weight_norm(
            nn.Conv2d(n_channels, n_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        """Apply both heads to ``x`` and wrap the results in a ``ModelOutput``."""
        mean = self.mu(x)
        log_variance = self.logvar(x)
        return ModelOutput(embedding=mean, log_covariance=log_variance)


class posterior_block(BaseEncoder):
    """Posterior block: a 3x3 convolution trunk followed by two 1x1 heads.

    The trunk takes ``2 * n_channels_before_concat`` input channels and
    reduces them to ``n_channels_before_concat``. Two weight-normalised 1x1
    convolutions then produce the ``embedding`` and the ``log_covariance``.

    Args:
        n_channels_before_concat: number of output channels; the input is
            expected to have twice as many.
    """

    def __init__(self, n_channels_before_concat):
        super().__init__()

        width = n_channels_before_concat
        weight_norm = nn.utils.parametrizations.weight_norm

        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=2 * width,
                out_channels=width,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
            nn.SiLU(),
        )

        self.mu = weight_norm(
            nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0)
        )
        self.logvar = weight_norm(
            nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        """Run the trunk on ``x``, then both heads on the trunk output."""
        hidden = self.network(x)
        mean = self.mu(hidden)
        log_variance = self.logvar(hidden)
        return ModelOutput(embedding=mean, log_covariance=log_variance)

In [4]:
from copy import deepcopy

# Assemble the model. Encoders, decoders and bottom-up blocks are given per
# modality ("m0" ... "m4"); top-down, prior and posterior blocks are given as
# single lists.
#
# `bu_1` is one module instance: putting that same object in every modality's
# list would tie its weights across all modalities, unlike the encoders and
# `bu_2`, which are created once per modality. Each modality therefore gets
# its own deep copy (the copies start from identical initial weights).
model = MHVAE(
    model_config=model_config,
    encoders={f"m{i}": my_input_encoder() for i in range(5)},
    decoders={f"m{i}": my_input_decoder() for i in range(5)},
    bottom_up_blocks={
        f"m{i}": [deepcopy(bu_1), bu_2(64, 128, model_config.latent_dim)]
        for i in range(5)
    },
    top_down_blocks=[td_1, td_2(model_config.latent_dim)],
    prior_blocks=[prior_block(32), prior_block(64)],
    posterior_blocks=[posterior_block(32), posterior_block(64)],
)

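Note that `posterior_blocks` is passed above as a single list rather than one list per modality: the weights of the posterior blocks are shared.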


In [6]:
# Sanity check: run a forward pass on one batch of 10 samples to verify that
# the data and the architectures defined above fit together.
from torch.utils.data import DataLoader

dl = DataLoader(dataset=train_set, batch_size=10)
sample = next(iter(dl))
model(sample)

ModelOutput([('loss', tensor(139578.3594, grad_fn=<MeanBackward0>)),
             ('loss_sum', tensor(139578.3594, grad_fn=<MeanBackward0>)),
             ('metrics',
              {'kl_3': tensor(306.9045, grad_fn=<SumBackward0>),
               'kl_2': tensor(15249.9697, grad_fn=<SumBackward0>),
               'kl_1': tensor(30313.1875, grad_fn=<SumBackward0>)})])

## Training

In [None]:
import os

from multivae.trainers import BaseTrainer, BaseTrainerConfig


# Training hyper-parameters.
trainer_config = BaseTrainerConfig(
    # Expand "~" explicitly: Python's path functions treat it as a literal
    # directory name, so an unexpanded path would end up under "./~/".
    output_dir=os.path.expanduser("~/experiments/mhvae_test"),
    per_device_eval_batch_size=64,
    per_device_train_batch_size=64,
    num_epochs=20,
    steps_predict=1,
    learning_rate=1e-3,
)

## Set up wandb callback
# wandb_cb = WandbCallback()
# wandb_cb.setup(trainer_config,model_config,'mhvae_test')


# Train on `train_set` and evaluate on `eval_set`.
trainer = BaseTrainer(
    training_config=trainer_config,
    model=model,
    train_dataset=train_set,
    eval_dataset=eval_set,
    #   callbacks=[wandb_cb]
)

trainer.train()