In [1]:
import torch
from datasets import load_mnist, get_observation_pixels
import numpy as np
import matplotlib.pyplot as plt

# Mini-batch size handed to the data-loading helper below.
BATCH_SIZE = 128
# load_mnist (project helper from `datasets`) is unpacked into three loaders.
# NOTE(review): the unpack order here is train, test, val -- confirm that this
# matches the order load_mnist actually returns them in.
train_loader, test_loader, val_loader = load_mnist(BATCH_SIZE)

In [2]:
import pytorch_lightning as pl
from models import LabelConditionalVAE

class Label_SCVAE(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``LabelConditionalVAE``.

    The wrapped network is stored on ``self.LabelConditionalVAE``. Every
    train/validation/test step runs the same forward pass and loss
    computation, logs the three returned loss values under a stage-specific
    key, and returns the last of them.
    """

    def __init__(self):
        super().__init__()
        self.LabelConditionalVAE = LabelConditionalVAE()

    def forward(self, x, y):
        """Delegate straight to the wrapped network with ``x`` and ``y``."""
        network = self.LabelConditionalVAE
        return network(x, y)

    def _evaluate_batch(self, batch):
        """Run the wrapped network on one batch and return its loss output.

        ``batch`` must unpack into exactly three items; only the first and
        third are used (the middle item is ignored, as in every step).
        Returns whatever ``LabelConditionalVAE.loss`` returns; callers unpack
        it into three values.
        """
        inputs, _unused_cond, labels = batch
        network = self.LabelConditionalVAE
        # The network output is a 4-tuple; the fourth element is not needed
        # for the loss.
        reconstruction, mu, log_var, _latent = network(inputs, labels)
        return network.loss(inputs, reconstruction, mu, log_var)

    def training_step(self, batch, batch_idx):
        """Log the training losses for ``batch`` and return the total loss."""
        recon_term, kl_term, total = self._evaluate_batch(batch)
        self.log('train_loss', total)
        self.log('train_recon_loss', recon_term)
        self.log('train_kl_loss', kl_term)
        return total

    def validation_step(self, batch, batch_idx):
        """Log the validation losses for ``batch`` and return the total loss."""
        recon_term, kl_term, total = self._evaluate_batch(batch)
        self.log('val_loss', total)
        self.log('val_recon_loss', recon_term)
        self.log('val_kl_loss', kl_term)
        return total

    def test_step(self, batch, batch_idx):
        """Log the test losses for ``batch`` and return the total loss."""
        recon_term, kl_term, total = self._evaluate_batch(batch)
        self.log('test_loss', total)
        self.log('test_recon_loss', recon_term)
        self.log('test_kl_loss', kl_term)
        return total

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer