### Codage d'un VAE sur le dataset MNIST avec Torch

In [None]:
!pip install pytorch_lightning

In [None]:
!pip install pydantic

In [3]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import random
from torchvision.datasets import MNIST, FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from typing import Optional

In [4]:
class VAE(pl.LightningDataModule):

  def __init__(self, hidden_size: int, alpha: int, lr: float, batch_size: int,
              dataset: Optional[str] = None,
              save_images: Optional[bool] = None,
              save_path: Optional[str] = None, **kwargs):
    super().__init__()
    if save_images:
        self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
    self.save_hyperparameters()
    self.hidden_size = hidden_size
    self.alpha = alpha
    self.batch_size = batch_size
    self.learning_rate = lr
    self.dataset = dataset
    self.encoder = nn.Sequential(
            Flatten(), 
            nn.Linear(784, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
            nn.Linear(392, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
            nn.Linear(196, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
            nn.Linear(128, hidden_size))
    self.hidden_mu = nn.Linear(hidden_size, hidden_size)
    self.hidden_log_var = nn.Linear(hidden_size, hidden_size)

    self.decoder = nn.Sequential(
            nn.Linear(hidden_size, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
            nn.Linear(128, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
            nn.Linear(196, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
            nn.Linear(392, 784),
            Stack(1, 28, 28),
            nn.Tanh())
    
    self.data_tranformed = transforms.Compose([transforms.ToTensor, transforms.Lambda(lambda x: 2*x - 1.0)])


    def encode(self,x):
      hidden = self.encoder
      mu = self.hidden_mu(hidden)
      log_var = self.hidden_mu(hidden)
      return mu, logvar
    

    def decode(self, x):
      return self.decoder(x)


    def reparametrize(self, mu, log_var):
      sigma = torch.exp(0.5*log_var)
      z = torch.randn_like(sigma)
      return mu + sigma*z


    def training_step(self, batch, batch_idx):
      x, _ = batch
      mu, log_var, x_out = self.forward(x)
      kl_loss = (-0.5*(1+log_var - mu**2 -
                        torch.exp(log_var)).sum(dim=1)).mean(dim=0)
      recon_loss_criterion = nn.MSELoss()
      recon_loss = recon_loss_criterion(x, x_out)

      loss = recon_loss*self.alpha + kl_loss

      self.log('train_loss', loss, on_step=False,
                on_epoch=True, prog_bar=True)
      return loss
    

    def validation_step(self, batch, batch_idx):
      x, _ = batch
      mu, log_var, x_out = self.forward(x)

      kl_loss = (-0.5*(1+log_var - mu**2 -
                        torch.exp(log_var)).sum(dim=1)).mean(dim=0)
      recon_loss_criterion = nn.MSELoss()
      recon_loss = recon_loss_criterion(x, x_out)

      loss = recon_loss*self.alpha + kl_loss
      self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
      self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
      self.log('val_loss', loss, on_step=False, on_epoch=True)
      return x_out, loss

    
    def validation_epoch_end(self, outputs):
      if not self.save_images:
          return
      if not os.path.exists(self.save_path):
          os.makedirs(self.save_path)
      choice = random.choice(outputs)
      output_sample = choice[0]
      output_sample = output_sample.reshape(-1, 1, 28, 28)
      # output_sample = self.scale_image(output_sample)
      save_image(
          output_sample,
          f"{self.save_path}/epoch_{self.current_epoch+1}.png",
          # value_range=(-1, 1)
      )

    
    def configure_optimizers(self):
      optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
      lr_scheduler = ReduceLROnPlateau(optimizer,)
      return {
          "optimizer": optimizer, "lr_scheduler": lr_scheduler,
          "monitor": "val_loss"
      }


    def forward(self, x):
      mu, log_var = self.encode(x)
      hidden = self.reparametrize(mu, log_var)
      output = self.decoder(hidden)
      return mu, log_var, output

    # Functions for dataloading
    def train_dataloader(self):
      if self.dataset == "mnist":
          train_set = MNIST('data/', download=True,
                            train=True, transform=self.data_transform)
      return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
      if self.dataset == "mnist":
          val_set = MNIST('data/', download=True, train=False,
                          transform=self.data_transform)
      elif self.dataset == "fashion-mnist":
          val_set = FashionMNIST(
              'data/', download=True, train=False,
              transform=self.data_transform)
      return DataLoader(val_set, batch_size=64)

    def scale_image(self, img):
      out = (img + 1) / 2
      return out

    def interpolate(self, x1, x2):
      assert x1.shape == x2.shape, "Inputs must be of the same shape"
      if x1.dim() == 3:
          x1 = x1.unsqueeze(0)
      if x2.dim() == 3:
          x2 = x2.unsqueeze(0)
      if self.training:
          raise Exception(
              "This function should not be called when model is still "
              "in training mode. Use model.eval() before calling the "
              "function")
      mu1, lv1 = self.encode(x1)
      mu2, lv2 = self.encode(x2)
      z1 = self.reparametrize(mu1, lv1)
      z2 = self.reparametrize(mu2, lv2)
      weights = torch.arange(0.1, 0.9, 0.1)
      intermediate = [self.decode(z1)]
      for wt in weights:
          inter = (1.-wt)*z1 + wt*z2
          intermediate.append(self.decode(inter))
      intermediate.append(self.decode(z2))
      out = torch.stack(intermediate, dim=0).squeeze(1)
      return out, (mu1, lv1), (mu2, lv2)


class Stack(nn.Module):
  def __init__(self, channels, height, width):
    super(Stack, self).__init__()
    self.channels = channels
    self.height = height
    self.width = width

  def forward(self, x):
    return x.view(x.size(0), self.channels, self.height, self.width)


class Flatten(nn.Module):
  def forward(self,x):
    return x.view(x.size(0), -1)



In [5]:
__all__ = [
    'VAE', 'Flatten', 'Stack'
    'Conv_VAE',
]
vae_models = {
    #"conv-vae": Conv_VAE,
    "vae": VAE
}

In [6]:
import yaml

data = {'training_params': {'max_epochs': 30,
                            'auto_lr_find': False,
                            'gpus': 1}, 
        'logger_params': {'name': "conv-vae", 
                          'save_dir': "logs/"},
        'model_params': {'model_type': 'vae', # vae or conv-vae 
                        'lr': 0.005,
                        'batch_size': 144,
                        'hidden_size': 128,
                        'alpha': 1024,
                        'dataset': "mnist",
                        'save_images': True,
                        'save_path': "log_images/",
                        'channels': 1,
                        'height': 28,
                        'width': 28}
        }

def write_yaml(data):
    """ A function to write YAML file"""
    with open('config.yml', 'w') as f:
        yaml.dump(data, f)

write_yaml(data)

In [7]:
from pydantic import BaseModel
from typing import Optional, Union
import yaml


class TrainConfig(BaseModel):
    max_epochs: int
    auto_lr_find: Union[bool, int]
    gpus: int


class VAEConfig(BaseModel):
    model_type: str
    hidden_size: int
    alpha: int
    dataset: str
    batch_size: Optional[int] = 64
    save_images: Optional[bool] = False
    lr: Optional[float] = None
    save_path: Optional[str] = None


class ConvVAEConfig(VAEConfig):
    channels: int
    height: int
    width: int


class LoggerConfig(BaseModel):
    name: str
    save_dir: str


class Config(BaseModel):
    model_config: Union[VAEConfig, ConvVAEConfig]
    train_config: TrainConfig
    model_type: str
    log_config: LoggerConfig


def load_config(path="config.yml"):
    config = yaml.load(open(path), yaml.SafeLoader)
    model_type = config['model_params']['model_type']
    if model_type == "vae":
        model_config = VAEConfig(**config["model_params"])
    else:
        raise NotImplementedError(f"Model {model_type} is not implemented")
    train_config = TrainConfig(**config["training_params"])
    log_config = LoggerConfig(**config["logger_params"])
    config = Config(model_config=model_config, train_config=train_config,
                    model_type=model_type, log_config=log_config)
    return config


config = load_config()

In [8]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import os


def make_model(config):
    model_type = config.model_type
    model_config = config.model_config
    if model_type not in vae_models:
        raise NotImplementedError("Model Architecture not implemented")
    else:
        return vae_models[model_type](**model_config.dict())

In [9]:
  model = make_model(config)
  train_config = config.train_config
  logger = TensorBoardLogger(**config.log_config.dict())
  trainer = Trainer(**train_config.dict(), logger=logger,
                    callbacks=LearningRateMonitor())
  if train_config.auto_lr_find:
      lr_finder = trainer.tuner.lr_find(model)
      new_lr = lr_finder.suggestion()
      print("Learning Rate Chosen:", new_lr)
      model.lr = new_lr
      trainer.fit(model)
  else:
      trainer.fit(model)
  if not os.path.isdir("./saved_models"):
      os.mkdir("./saved_models")
  trainer.save_checkpoint(
      f"saved_models/{config.model_type}_alpha_{config.model_config.alpha}_dim_{config.model_config.hidden_size}.ckpt")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


MisconfigurationException: ignored