In [1]:
import torch
from datasets import load_mnist, get_observation_pixels
import numpy as np
import matplotlib.pyplot as plt

# Seed the global RNGs before anything stochastic happens (data shuffling /
# splitting inside load_mnist, weight init, sampling z during training) so a
# Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 64
# NOTE(review): the unpacking order (train, test, val) must match the return
# order of datasets.load_mnist — confirm against its definition.
train_loader, test_loader, val_loader = load_mnist(BATCH_SIZE)

In [2]:
import pytorch_lightning as pl
from models import LabelConditionalVAE

class LabelSCVAE(pl.LightningModule):
    """Lightning wrapper that trains a label-conditional VAE.

    The train, validation and test stages all run the same computation: the
    wrapped VAE processes ``x`` conditioned on ``y``, its ``loss`` method
    returns ``(recon_loss, kl_loss, loss)``, and the three values are logged
    as ``<stage>_recon_loss``, ``<stage>_kl_loss`` and ``<stage>_loss``.

    Args:
        lr: Learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        # Attribute name kept unchanged: it is the prefix of every parameter
        # key in the state_dict, so renaming it would break loading the
        # checkpoints already saved by this notebook.
        self.LabelConditionalVAE = LabelConditionalVAE()
        self.lr = lr

    def forward(self, x, y):
        """Run the wrapped VAE on ``x`` conditioned on labels ``y``.

        Returns whatever ``LabelConditionalVAE`` returns; the steps below
        unpack it as ``(output, z_mean, z_log_var, z)``.
        """
        return self.LabelConditionalVAE(x, y)

    def _shared_step(self, batch, stage):
        """Compute and log the VAE loss for one batch; return the total loss.

        Args:
            batch: ``(x, x_cond, y)`` tuple from the dataloader. ``x_cond`` is
                not used by this label-conditioned model.
            stage: Log-key prefix: ``'train'``, ``'val'`` or ``'test'``.
        """
        x, _x_cond, y = batch
        output, z_mean, z_log_var, _z = self(x, y)
        recon_loss, kl_loss, loss = self.LabelConditionalVAE.loss(x, output, z_mean, z_log_var)
        # The batch is a tuple of several tensors, so Lightning would have to
        # guess the batch size used for epoch-level averaging; pass it
        # explicitly. Assumes x is batch-first (TODO confirm in load_mnist).
        batch_size = len(x)
        self.log(f'{stage}_loss', loss, batch_size=batch_size)
        self.log(f'{stage}_recon_loss', recon_loss, batch_size=batch_size)
        self.log(f'{stage}_kl_loss', kl_loss, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        """Training loss for one batch (logged under ``train_*``)."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Validation loss for one batch (logged under ``val_*``).

        ``val_loss`` is the metric monitored by the ModelCheckpoint callback.
        """
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Test loss for one batch (logged under ``test_*``)."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate given at init."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [3]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Run configuration. The filenames below are derived from these constants so
# the final checkpoint name cannot drift out of sync with max_epochs.
CHECKPOINT_DIR = 'checkpoints/'
MAX_EPOCHS = 40
GPU_DEVICES = [5]  # physical GPU index on this machine; adjust per host
# NOTE(review): "latent_40" presumably records the latent size configured
# inside models.LabelConditionalVAE — confirm and keep in sync with it.
RUN_TAG = 'latent_40'

# Keep the 3 checkpoints with the lowest logged val_loss.
# NOTE: CHECKPOINT_DIR is shared across runs (Lightning warns it is not
# empty), so files from earlier runs may sit alongside these.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=CHECKPOINT_DIR,
    filename=f'label_scvae_{{epoch:02d}}_{RUN_TAG}',  # '{epoch:02d}' is filled in by Lightning
    save_top_k=3,
    mode='min',
)

model = LabelSCVAE()

trainer = pl.Trainer(
    accelerator='gpu',
    devices=GPU_DEVICES,
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, train_loader, val_loader)

# Save the final-epoch weights in addition to the top-k "best" checkpoints.
trainer.save_checkpoint(f'{CHECKPOINT_DIR}label_scvae_{MAX_EPOCHS}_{RUN_TAG}.ckpt')

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name                | Type                | Params
------------------------------------------------------------
0 | LabelConditionalVAE | LabelConditionalVAE | 2.4 M 
------------------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.660     

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=40` reached.
