In [2]:
import torch
from datasets import load_mnist
import numpy as np
import matplotlib.pyplot as plt

# Value handed to load_mnist below; presumably the mini-batch size of every returned loader.
BATCH_SIZE = 128
# NOTE(review): the result is unpacked as (train, test, val). load_mnist is not
# visible here -- confirm this matches its return order, since a swap of the
# test/val loaders would go unnoticed at runtime.
train_loader, test_loader, val_loader = load_mnist(BATCH_SIZE)

In [3]:
import pytorch_lightning as pl
from models import MultiDecoderConditionalVAE

class MDSCVAE(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``MultiDecoderConditionalVAE``.

    The train, validation and test steps all run the same forward pass and
    loss computation; they differ only in the prefix used for logged metrics.

    Args:
        lr: Learning rate handed to the Adam optimizer. Defaults to ``1e-3``.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        # The attribute name is kept unchanged on purpose: as a registered
        # submodule it prefixes the parameter keys stored in checkpoints.
        self.MultiDecoderConditionalVAE = MultiDecoderConditionalVAE()
        self.lr = lr

    def forward(self, x, x_cond):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.MultiDecoderConditionalVAE(x, x_cond)

    def _shared_step(self, batch, stage):
        """Run forward pass and loss on one batch, logging metrics under ``stage``.

        Args:
            batch: A 3-tuple ``(x, x_cond, y)``; the third element is not used.
            stage: Prefix for the logged metric names (``'train'``, ``'val'``
                or ``'test'``).

        Returns:
            The total loss returned by the wrapped model's ``loss`` method.
        """
        x, x_cond, _ = batch
        # The fifth forward output (named ``z`` by the model) is not needed by the loss.
        x_hat, x_hat_2, z_mu, z_logvar, _ = self(x, x_cond)
        recon_loss_conditioned, recon_loss, kl_loss, loss = self.MultiDecoderConditionalVAE.loss(
            x, x_hat, x_hat_2, z_mu, z_logvar
        )
        metrics = (
            ('loss', loss),
            ('kl_loss', kl_loss),
            ('recon_loss', recon_loss),
            ('recon_loss_conditioned', recon_loss_conditioned),
        )
        for suffix, value in metrics:
            self.log(f'{stage}_{suffix}', value)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses for one batch; return the total loss."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses for one batch; return the total loss."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Compute and log the test losses for one batch; return the total loss."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    

In [4]:
from pytorch_lightning.callbacks import RichProgressBar
# One constant drives both the training budget and the checkpoint file name,
# so the two can no longer drift apart.
MAX_EPOCHS = 40
SEED = 42

# Seed Python, NumPy and torch RNGs before the model is built so weight
# initialisation and subsequent sampling are repeatable across kernel restarts.
pl.seed_everything(SEED)

model = MDSCVAE()
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    callbacks=[RichProgressBar()],
)
trainer.fit(model, train_loader, val_loader)

# Persist the final trainer/model state once training has finished.
trainer.save_checkpoint(f'checkpoints/mdscvae_{MAX_EPOCHS}.ckpt')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

`Trainer.fit` stopped: `max_epochs=40` reached.
