# Training an MHVAE model on PolyMNIST

In [1]:
from multivae.data.datasets import MMNISTDataset

import os

# Root directory of the PolyMNIST data. Set the POLYMNIST_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_DIR = os.environ.get('POLYMNIST_DATA_DIR', '/Users/agathe/dev/data')

train_set = MMNISTDataset(DATA_DIR)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/agathe/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


First we need to define the architectures we are going to use. 

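Here we use a hierarchy of three latent variables (`n_latent=3`). The first is a top-level latent vector of size `latent_dim`, produced by the last bottom-up block. The other two are convolutional latent levels of shapes (64,7,7) and (32,14,14). We use the same architectures for all modalities.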

latent_dim = 64 (the value passed to `MHVAEConfig` below)

|block | input_dim | output_dim |
|-----|------------|------------|
|encoder|(3,28,28)| (32,14,14)  |
|bottom-up_1 | (32,14,14)|(64,7,7)|
|bottom-up_2 | (64,7,7)|latent_dim|
|top-down_2 |latent_dim|(64,7,7)|
|top-down_1 |(64,7,7)|(32,14,14)|
|decoder|(32,14,14)|(3,28,28)|
|prior_block_2|(64,7,7)|(64,7,7)|
|prior_block_1|(32,14,14)|(32,14,14)|
|posterior_block_2|(2*64,7,7)|(64,7,7)|
|posterior_block_1|(2*32,14,14)|(32,14,14)|


In [2]:
from multivae.models.mhvae import MHVAEConfig, MHVAE

# Five PolyMNIST modalities named 'm0'..'m4', each a (3, 28, 28) image.
# latent_dim is the size of the top-level latent vector; n_latent=3 latent levels.
model_config = MHVAEConfig(
    n_modalities=5,
    latent_dim=64,
    input_dims=dict.fromkeys([f'm{i}' for i in range(5)], (3, 28, 28)),
    n_latent=3,
    beta=1,
)

In [3]:
from multivae.models.base import BaseEncoder, ModelOutput, BaseDecoder
from torch import nn

# Defining encoder and bottom-up blocks

class my_input_encoder(BaseEncoder):
    """Per-modality input encoder: one stride-2 convolution followed by SiLU.

    For the (3, 28, 28) PolyMNIST images this yields a (32, 14, 14) feature
    map, returned as the ``embedding`` field of a ``ModelOutput``.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernel, stride 2, padding 1: halves the spatial resolution.
        self.conv0 = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
        )
        self.act_1 = nn.SiLU()

    def forward(self, x):
        return ModelOutput(embedding=self.act_1(self.conv0(x)))


# First bottom-up block: (32, 14, 14) -> (64, 7, 7) with a stride-2 convolution + SiLU.
bu_1 = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=True),
    nn.SiLU(),
)

        
class bu_2(BaseEncoder):
    """Last bottom-up block: feature map -> parameters of the top-level latent.

    A 3x3 / stride-2 / padding-1 convolution maps an input of shape
    ``(inchannels, in_spatial, in_spatial)`` to ``(outchannels, s, s)`` with
    ``s = (in_spatial - 1) // 2 + 1``. The result is flattened, passed through
    a hidden linear layer + ReLU, then two linear heads produce the mean
    (``embedding``) and log-variance (``log_covariance``).

    Args:
        inchannels: number of channels of the input feature map.
        outchannels: number of channels produced by the convolution.
        latent_dim: dimension of the top-level latent vector.
        hidden_dim: width of the hidden linear layer (default 512).
        in_spatial: height/width of the square input feature map
            (default 7, the (64, 7, 7) output of ``bu_1``).
    """

    def __init__(self, inchannels, outchannels, latent_dim, hidden_dim=512, in_spatial=7):
        super().__init__()

        # Conv output size: floor((in + 2*1 - 3) / 2) + 1 == (in - 1) // 2 + 1, e.g. 7 -> 4.
        out_spatial = (in_spatial - 1) // 2 + 1
        # Derived from the conv output instead of a hard-coded 2048, which only
        # matched outchannels == 128 (128 * 4 * 4 == 2048 with the defaults).
        flat_dim = outchannels * out_spatial * out_spatial

        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=inchannels,
                out_channels=outchannels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
            ),
            nn.SiLU(),
            nn.Flatten(),
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
        )

        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.network(x)
        return ModelOutput(
            embedding=self.mu(h),
            log_covariance=self.log_var(h),
        )
        
# Defining top-down blocks and decoder
        
class td_2(nn.Module):
    """Top-down block from the top-level latent vector to a (64, 7, 7) map.

    The latent vector is projected to 2048 = 128 * 4 * 4 features (+ ReLU),
    reshaped to a (128, 4, 4) map, then upsampled by a stride-2 transposed
    convolution (+ SiLU): (4 - 1) * 2 - 2 * 1 + 3 = 7.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.linear = nn.Sequential(nn.Linear(latent_dim, 2048), nn.ReLU())
        self.convs = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, bias=True),
            nn.SiLU(),
        )

    def forward(self, x):
        projected = self.linear(x)
        grid = projected.view(projected.shape[0], 128, 4, 4)
        return self.convs(grid)
    
# First top-down block: (64, 7, 7) -> (32, 14, 14). output_padding=1 turns the
# transposed-conv size (7 - 1) * 2 - 2 + 3 = 13 into 14.
td_1 = nn.Sequential(
    nn.ConvTranspose2d(
        in_channels=64,
        out_channels=32,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
        bias=True,
    ),
    nn.SiLU(),
)

class my_input_decoder(BaseDecoder):
    """Per-modality decoder: (32, 14, 14) feature map -> (3, 28, 28) image.

    A stride-2 transposed convolution (with output_padding=1) doubles the
    resolution, and a sigmoid maps the result into (0, 1). The output is
    returned as the ``reconstruction`` field of a ``ModelOutput``.
    """

    def __init__(self):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels=32,
            out_channels=3,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.network = nn.Sequential(upsample, nn.Sigmoid())

    def forward(self, x):
        reconstruction = self.network(x)
        return ModelOutput(reconstruction=reconstruction)
        
# Defining prior blocks and posterior blocks

class prior_block(BaseEncoder):
    """Prior head for one intermediate latent level.

    Two weight-normalised 1x1 convolutions map a feature map with
    ``n_channels`` channels to the mean (``embedding``) and log-variance
    (``log_covariance``) at that level; channel count and spatial size are
    preserved.
    """

    def __init__(self, n_channels):
        super().__init__()

        def make_head():
            return nn.utils.weight_norm(
                nn.Conv2d(n_channels, n_channels, kernel_size=1, stride=1, padding=0)
            )

        # Created in the same order as before (mu first) so seeded init is unchanged.
        self.mu = make_head()
        self.logvar = make_head()

    def forward(self, x):
        mean = self.mu(x)
        log_var = self.logvar(x)
        return ModelOutput(embedding=mean, log_covariance=log_var)

class posterior_block(BaseEncoder):
    """Posterior head for one intermediate latent level.

    Takes a feature map with ``2 * n_channels_before_concat`` channels
    (presumably two feature maps concatenated on the channel axis by the
    caller; the concatenation is not done here). It fuses them with a 3x3
    convolution + SiLU, then predicts the mean (``embedding``) and
    log-variance (``log_covariance``) with two weight-normalised 1x1
    convolutions.
    """

    def __init__(self, n_channels_before_concat):
        super().__init__()
        channels = n_channels_before_concat
        self.network = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.SiLU(),
        )
        self.mu = nn.utils.weight_norm(
            nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        )
        self.logvar = nn.utils.weight_norm(
            nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        fused = self.network(x)
        return ModelOutput(embedding=self.mu(fused), log_covariance=self.logvar(fused))
    



In [4]:
import copy

# One name per modality ('m0', 'm1', ...), matching the keys of model_config.input_dims.
modalities = [f'm{i}' for i in range(model_config.n_modalities)]

model = MHVAE(
    model_config=model_config,
    encoders={m: my_input_encoder() for m in modalities},
    decoders={m: my_input_decoder() for m in modalities},
    # Each modality gets its own copy of `bu_1`. Reusing the single instance
    # would silently tie its weights across all modalities, whereas every other
    # per-modality module (encoder, decoder, bu_2) is instantiated separately.
    bottom_up_blocks={
        m: [copy.deepcopy(bu_1), bu_2(64, 128, model_config.latent_dim)]
        for m in modalities
    },
    # Listed from the level closest to the data (32 channels, 14x14) upward.
    top_down_blocks=[td_1, td_2(model_config.latent_dim)],
    prior_blocks=[prior_block(32), prior_block(64)],
    posterior_blocks=[posterior_block(32), posterior_block(64)],
)



In [5]:
from torch.utils.data import DataLoader

# A small batch of 10 samples, used only to sanity-check a forward pass.
dl = DataLoader(train_set, batch_size=10)

sample = next(iter(dl))

model(sample)

ModelOutput([('loss', tensor(139815.3594, grad_fn=<MeanBackward0>)),
             ('loss_sum', tensor(139815.3594, grad_fn=<MeanBackward0>))])

In [6]:
# Encode the sample batch with the model (no training has been run at this point).
embedding = model.encode(sample)

In [7]:
# Decode the embedding; the displayed ModelOutput is keyed by modality name (e.g. 'm0').
model.decode(embedding)


ModelOutput([('m0',
              tensor([[[[0.5244, 0.5870, 0.5265,  ..., 0.4207, 0.4948, 0.5235],
                        [0.5522, 0.3758, 0.6480,  ..., 0.6652, 0.4468, 0.4567],
                        [0.4979, 0.4653, 0.4306,  ..., 0.5740, 0.4931, 0.5435],
                        ...,
                        [0.5095, 0.6717, 0.5708,  ..., 0.4583, 0.5793, 0.4649],
                        [0.5892, 0.3328, 0.5288,  ..., 0.5557, 0.6067, 0.4503],
                        [0.4664, 0.4112, 0.5552,  ..., 0.6430, 0.5149, 0.5641]],
              
                       [[0.5373, 0.4324, 0.5524,  ..., 0.4831, 0.2908, 0.5536],
                        [0.4036, 0.6111, 0.6207,  ..., 0.6712, 0.4758, 0.5188],
                        [0.4826, 0.3381, 0.3900,  ..., 0.4418, 0.5407, 0.5870],
                        ...,
                        [0.6473, 0.5735, 0.3285,  ..., 0.4246, 0.3577, 0.3698],
                        [0.5637, 0.5633, 0.4532,  ..., 0.4145, 0.5569, 0.3902],
                        [0