In [1]:
import os
from typing import Optional, Sequence

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, callbacks
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

# Use at most one GPU; when one is present, train with a larger batch.
AVAIL_GPUS = 1 if torch.cuda.device_count() > 0 else 0
BATCH_SIZE = 64 if not AVAIL_GPUS else 256

Runner
 VAE_1
    stage 
    input -> internal_latent -> comm -> internal_latent -> input
    
 VAE_2
    stage 
    input -> internal_latent -> comm -> internal_latent -> input

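Goal: have other creat… — TODO: this note is cut off in the original; finish stating the goal (presumably: have the *other* agent reconstruct the input from the communicated latent — confirm).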

In [2]:
from pytorch_vae.experiment import VAEXperiment
from pytorch_vae.models import VanillaVAE, BaseVAE
from typing import List
import math

class FlatVAE(BaseVAE):
    """Fully-connected VAE that flattens its input and reconstructs it to ``in_shape``.

    Encoder: ``Flatten`` followed by one ``Linear -> BatchNorm1d -> LeakyReLU``
    block per entry of ``hidden_dims``, then two linear heads producing the
    posterior mean and log-variance. Decoder: the same kind of blocks with the
    widths in reverse order, ending in a plain ``Linear`` back to
    ``prod(in_shape)`` features (no output activation), reshaped to ``in_shape``.

    Args:
        in_shape: Shape of one input sample, e.g. ``(1, 32, 32)``.
        latent_dim: Size of the latent code.
        hidden_dims: Widths of the encoder's hidden layers; the decoder uses them
            reversed. The sequence is copied and never modified.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
        in_shape: Sequence[int] = (1, 32, 32),
        latent_dim: int = 64,
        hidden_dims: Sequence[int] = (256,),
        **kwargs) -> None:
        super(FlatVAE, self).__init__()

        self.latent_dim = latent_dim
        # Kept as a list so reshapes like ``[-1, *self.in_shape]`` work whether
        # the caller passed a list or a tuple.
        self.in_shape = list(in_shape)
        self.in_dim = math.prod(self.in_shape)
        # Private copy: the caller's sequence (and the default) must never be
        # mutated -- previously ``hidden_dims.reverse()`` flipped the caller's list.
        hidden_dims = list(hidden_dims)

        # Build Encoder
        modules = [nn.Flatten()]
        prev_dim = self.in_dim
        for h_dim in hidden_dims:
            modules.append(self._dense_block(prev_dim, h_dim))
            prev_dim = h_dim
        self.encoder = nn.Sequential(*modules)
        # prev_dim is the encoder's output width (== in_dim when hidden_dims is empty).
        self.fc_mu = nn.Linear(prev_dim, latent_dim)
        self.fc_var = nn.Linear(prev_dim, latent_dim)

        # Build Decoder (mirror of the encoder's hidden stack)
        modules = []
        prev_dim = self.latent_dim
        for h_dim in reversed(hidden_dims):
            modules.append(self._dense_block(prev_dim, h_dim))
            prev_dim = h_dim
        modules.append(nn.Linear(prev_dim, self.in_dim))
        self.decoder = nn.Sequential(*modules)

    @staticmethod
    def _dense_block(in_features: int, out_features: int) -> nn.Module:
        """Linear -> BatchNorm1d -> LeakyReLU: the hidden layer used by both halves."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.LeakyReLU(),
        )

    def encode(self, input: Tensor) -> List[Tensor]:
        """Return ``[mu, log_var]``, each of shape ``(batch, latent_dim)``."""
        result = self.encoder(input)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """Map latents ``(batch, latent_dim)`` to tensors of shape ``(batch, *in_shape)``."""
        return self.decoder(z).view([-1, *self.in_shape])

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Encode, sample a latent, decode; returns ``[recons, input, mu, log_var]``."""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        recons = self.decode(z)
        return [recons, input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """Weighted VAE loss from the output of :meth:`forward`.

        Args:
            *args: ``recons, input, mu, log_var`` in that order.
            **kwargs: Must contain ``M_N``, the weight applied to the KL term.

        Returns:
            dict with ``'loss'`` (MSE + M_N * KL, differentiable),
            ``'Reconstruction_Loss'`` (detached MSE) and ``'KLD'`` (the detached
            KL term with its sign flipped).
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss = F.mse_loss(recons, input)

        # Closed-form KL(N(mu, sigma^2) || N(0, I)): summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss

        return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach()}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """Decode a straight-line latent walk from a random point toward its negation.

        Row ``i`` of the latent batch is ``start + (i / num_samples) * (end - start)``
        with ``start ~ N(0, I)`` and ``end = -start``; the endpoint itself is
        excluded. Returns a tensor of shape ``(num_samples, *in_shape)``.
        """
        start = torch.randn(self.latent_dim)
        end = -start
        # Vectorized form of the per-row torch.lerp loop: weights 0, 1/n, ..., (n-1)/n.
        weights = torch.arange(num_samples, dtype=start.dtype).unsqueeze(1) / num_samples
        z = start + weights * (end - start)

        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Return the reconstruction of ``x`` (first element of :meth:`forward`)."""
        return self.forward(x)[0]

In [3]:
import random
class MultiVAE(BaseVAE):
    """A population of independent ``FlatVAE`` "agents" sharing one latent size.

    Each call to :meth:`encode` and each call to :meth:`decode` picks an agent
    uniformly at random from the global ``random`` module, independently of
    each other, so a single :meth:`forward` pass may encode with one agent and
    decode with a different one.

    Args:
        in_shape: Shape of one input sample, e.g. ``(1, 32, 32)``.
        latent_dim: Latent size shared by every agent.
        hidden_dims: Hidden widths given to every agent (each gets its own copy).
        n_agents: Number of agents; must be at least 1.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``n_agents < 1`` (there would be no agent to pick).
    """

    def __init__(self,
        in_shape: Sequence[int] = (1, 32, 32),
        latent_dim: int = 64,
        hidden_dims: Sequence[int] = (256,),
        n_agents: int = 2,
        **kwargs) -> None:
        super(MultiVAE, self).__init__()
        if n_agents < 1:
            raise ValueError(f"n_agents must be >= 1, got {n_agents}")
        self.latent_dim = latent_dim
        # Kept as a list so ``[-1, *self.in_shape]`` works for list or tuple input.
        self.in_shape = list(in_shape)
        # Fresh copies per agent: agents must not alias one shared list, otherwise
        # any in-place change made while building one agent leaks into the next.
        self.agents = nn.ModuleList([
            FlatVAE(list(self.in_shape), latent_dim, list(hidden_dims))
            for _ in range(n_agents)
        ])

    def _random_agent(self):
        """Pick one agent uniformly at random (uses the global ``random`` state)."""
        return random.choice(self.agents)

    def encode(self, input: Tensor) -> List[Tensor]:
        """Encode with a randomly chosen agent; returns ``[mu, log_var]``."""
        mu, log_var = self._random_agent().encode(input)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """Decode with a randomly chosen agent (chosen independently of :meth:`encode`)."""
        return self._random_agent().decode(z)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Encode, sample a latent, decode; returns ``[recons, input, mu, log_var]``."""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        recons = self.decode(z)
        return [recons, input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """Weighted VAE loss from the output of :meth:`forward`.

        Args:
            *args: ``recons, input, mu, log_var`` in that order.
            **kwargs: Must contain ``M_N``, the weight applied to the KL term.

        Returns:
            dict with ``'loss'`` (MSE + M_N * KL, differentiable),
            ``'Reconstruction_Loss'`` (detached MSE) and ``'KLD'`` (the detached
            KL term with its sign flipped).
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # Closed-form KL(N(mu, sigma^2) || N(0, I)): summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss

        return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach()}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """Decode (with one random agent) a latent walk from a random point toward its negation.

        Row ``i`` of the latent batch is ``start + (i / num_samples) * (end - start)``
        with ``start ~ N(0, I)`` and ``end = -start``; the endpoint is excluded.
        Returns a tensor of shape ``(num_samples, *in_shape)``.
        """
        start = torch.randn(self.latent_dim)
        end = -start
        # Vectorized form of the per-row torch.lerp loop: weights 0, 1/n, ..., (n-1)/n.
        weights = torch.arange(num_samples, dtype=start.dtype).unsqueeze(1) / num_samples
        z = start + weights * (end - start)

        z = z.to(current_device)
        return self.decode(z).view([-1, *self.in_shape])

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Return the reconstruction of ``x`` (first element of :meth:`forward`)."""
        return self.forward(x)[0]

In [4]:
class MnistDM(LightningDataModule):
    """MNIST as a LightningDataModule, zero-padded from 28x28 to 32x32 and normalized.

    Args:
        data_dir: Directory MNIST is downloaded into / read from.
        batch_size: Batch size for every loader. ``None`` (the default) falls back
            to the notebook-level ``BATCH_SIZE`` at the time a loader is built.
        num_workers: ``num_workers`` passed to every DataLoader.
        val_size: Number of training images held out as the validation split.
        split_seed: Seed for the train/validation split, so repeated ``setup()``
            calls (and kernel restarts) produce the same split.
    """

    def __init__(self,
                 data_dir: str = ".",
                 batch_size: Optional[int] = None,
                 num_workers: int = 8,
                 val_size: int = 5000,
                 split_seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.split_seed = split_seed
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Pad(2),  # Get to 32 x 32
                # Commonly used MNIST mean / std.
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    def prepare_data(self):
        """Download the train and test sets; ``setup`` itself never downloads."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the seeded train/val split and the test set (``stage`` is ignored)."""
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        train_size = len(mnist_full) - self.val_size
        self.mnist_train, self.mnist_val = random_split(
            mnist_full,
            [train_size, self.val_size],
            # Seeded so the validation set is identical across setup() calls.
            generator=torch.Generator().manual_seed(self.split_seed),
        )

        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Shared DataLoader construction for all three splits."""
        batch_size = BATCH_SIZE if self.batch_size is None else self.batch_size
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        # Reshuffle every epoch; the evaluation loaders keep a fixed order.
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        return self._make_loader(self.mnist_test)

In [5]:
from pathlib import Path

from pytorch_lightning.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger(save_dir="logs", name="test1")
# Pre-create the per-run folders that sample / reconstruction images go into.
for subdir in ("Samples", "Reconstructions"):
    Path(f"{tb_logger.log_dir}/{subdir}").mkdir(parents=True, exist_ok=True)

Missing logger folder: logs/test1


In [6]:
dset = MnistDM()
dset.prepare_data()  # setup() does not download; needed on a fresh machine / kernel
dset.setup()
# NOTE: the constructor argument is `in_shape`; an unknown name such as `in_dim`
# is silently swallowed by **kwargs and the default shape is used instead.
vae = MultiVAE(in_shape=[1, 32, 32], latent_dim=16, hidden_dims=[256], n_agents=10)
x, y = next(iter(dset.train_dataloader()))
outputs = vae(x)
for item in outputs:
    print(item.shape)

# Sanity-check the forward contract: [recons, input, mu, log_var].
recons, _, mu, log_var = outputs
assert recons.shape == x.shape
assert mu.shape == log_var.shape == (x.shape[0], 16)

torch.Size([64, 1, 32, 32])
torch.Size([64, 1, 32, 32])
torch.Size([64, 16])
torch.Size([64, 16])


In [7]:
# Start TensorBoard inline (IPython magics), pointed at the "logs/" directory
# that the TensorBoardLogger in the earlier cell writes to (save_dir="logs").
%load_ext tensorboard
%tensorboard --logdir logs/

Reusing TensorBoard on port 6006 (pid 28225), started 1 day, 15:56:36 ago. (Use '!kill 28225' to kill it.)

In [8]:
# Hyperparameters handed to VAEXperiment; how each key is consumed is defined
# in pytorch_vae.experiment (not visible here).
params = dict(
    LR=0.001,
    weight_decay=0.0,
    scheduler_gamma=0.95,
    kld_weight=0.0001,
    manual_seed=42,
)
experiment = VAEXperiment(vae, params)
trainer = Trainer(
    max_epochs=60,
    gpus=AVAIL_GPUS,
    logger=tb_logger,
    progress_bar_refresh_rate=20,
)
trainer.fit(experiment, dset)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type     | Params
-----------------------------------
0 | model | MultiVAE | 5.4 M 
-----------------------------------
5.4 M     Trainable params
0         Non-trainable params
5.4 M     Total params
21.567    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1