In [4]:
import torch
from datasets import load_mnist, get_observation_pixels
import numpy as np
import matplotlib.pyplot as plt

# Mini-batch size passed to the project-local `load_mnist` helper.
BATCH_SIZE = 128
# `load_mnist` returns three objects that are unpacked as train / test / val.
# NOTE(review): the order here is (train, test, val), not the more common
# (train, val, test) -- confirm it matches what `load_mnist` actually returns,
# otherwise validation would silently run on the test split.
train_loader, test_loader, val_loader = load_mnist(BATCH_SIZE)

In [9]:
import pytorch_lightning as pl
from models import VAE

class VAEModel(pl.LightningModule):
    """Lightning wrapper around the project-local ``VAE`` network.

    The wrapped network is stored as ``self.ConditionalVAE``; that attribute
    name is part of the checkpoint's parameter keys, so it is kept unchanged.
    The train / validation / test steps share one implementation
    (``_shared_step``) and differ only in the prefix of the logged metrics.

    NOTE(review): the captured run of this notebook fails with
    ``TypeError: VAE.forward() takes 2 positional arguments but 3 were given``,
    i.e. the ``VAE`` that was imported accepts a single input while this class
    calls it with ``(x, x_cond)``. Either ``models.VAE.forward`` must accept
    the conditioning input, or the kernel holds a stale ``models`` import --
    confirm against ``models.py``.
    """

    def __init__(self):
        super().__init__()
        self.ConditionalVAE = VAE()

    def forward(self, x, x_cond):
        """Run the wrapped network on ``x`` with conditioning input ``x_cond``."""
        return self.ConditionalVAE(x, x_cond)

    def _shared_step(self, batch, stage):
        """Compute and log the losses for one batch.

        Args:
            batch: a 3-tuple ``(x, x_cond, y)``; the third element is unused.
            stage: metric-name prefix, one of ``'train'``, ``'val'``, ``'test'``.

        Returns:
            The total loss, i.e. the third value returned by
            ``self.ConditionalVAE.loss``.
        """
        x, x_cond, _ = batch
        # The network returns four values; the fourth one is not needed here.
        output, z_mean, z_log_var, _ = self.ConditionalVAE(x, x_cond)
        recon_loss, kl_loss, loss = self.ConditionalVAE.loss(
            x, output, z_mean, z_log_var
        )
        # Same keys and same logging order as before: total, recon, KL.
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_recon_loss', recon_loss)
        self.log(f'{stage}_kl_loss', kl_loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_*`` losses and return the loss to optimise."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log ``val_*`` losses and return the total loss."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log ``test_*`` losses and return the total loss."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [10]:
from pytorch_lightning.callbacks import RichProgressBar

# Single source of truth for the epoch count: it drives both the Trainer and
# the checkpoint filename, so the two cannot drift apart.
MAX_EPOCHS = 40

scvae = VAEModel()
# accelerator='gpu' requires a CUDA device to be present.
# NOTE(review): consider accelerator='auto' if this must also run on CPU.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    callbacks=[RichProgressBar()],
)
trainer.fit(scvae, train_loader, val_loader)

# Persist the trained weights; with MAX_EPOCHS = 40 this is 'scvae_40.ckpt'.
trainer.save_checkpoint(f'scvae_{MAX_EPOCHS}.ckpt')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

TypeError: VAE.forward() takes 2 positional arguments but 3 were given