In [1]:
import torch
from datasets import load_mnist, get_observation_pixels
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

# Mini-batch size shared by every MNIST loader created below.
BATCH_SIZE = 100

# NOTE(review): the unpacking order assumes load_mnist returns (train, test, val) — confirm in datasets.py.
(train_loader,
 test_loader,
 val_loader) = load_mnist(BATCH_SIZE)

In [3]:
import pytorch_lightning as pl

from models import VAE

class VAETrainer(pl.LightningModule):
    """Lightning wrapper that trains, validates and tests a ``VAE``.

    Args:
        hidden_dims: Hidden-layer sizes forwarded to ``VAE``. ``None`` (the
            default) means ``[128, 256]``. A fresh list is always passed to
            ``VAE`` so it can never mutate a shared default or the caller's list.
        latent_dim: Latent-space size forwarded to ``VAE``.
        lr: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, hidden_dims=None, latent_dim=2, lr=1e-3):
        super().__init__()
        # Store the constructor arguments in every checkpoint so that
        # ``load_from_checkpoint`` rebuilds the same architecture even when
        # non-default sizes are used. Lightning reads the current values of the
        # __init__ arguments, so call this before any of them could be rebound.
        self.save_hyperparameters()

        # Copy instead of sharing: a mutable default list would otherwise be
        # reused (and possibly mutated) across every VAETrainer instance.
        dims = [128, 256] if hidden_dims is None else list(hidden_dims)
        self.lr = lr
        self.model = VAE(dims, latent_dim)

    def forward(self, x, x_cond=None, y=None):
        """Run the wrapped VAE on ``x``.

        ``x_cond`` and ``y`` are accepted so a whole ``(x, x_cond, y)`` batch can
        be passed through, but they are currently ignored. Returns whatever
        ``VAE.__call__`` returns; ``step`` unpacks it as ``(x_hat, mu, log_var, z)``.
        """
        return self.model(x)

    def step(self, batch, batch_idx, mode='train'):
        """Compute and log the VAE loss for one batch.

        Args:
            batch: ``(x, x_cond, y)`` tuple yielded by the data loaders.
            batch_idx: Batch index (unused; kept for Lightning's hook signature).
            mode: Prefix for logged metric names, e.g. ``'val'`` -> ``val_loss``.

        Returns:
            ``loss['loss']`` from ``VAE.loss`` — the value Lightning optimises.
        """
        x, x_cond, y = batch
        x_hat, mu, log_var, z = self(x, x_cond, y)
        loss = self.model.loss(x, x_hat, mu, log_var)
        # Every entry of the loss dict is logged as "<mode>_<key>"; .item()
        # requires each value to be a single-element tensor or scalar.
        self.log_dict({f"{mode}_{key}": val.item() for key, val in loss.items()}, sync_dist=True, prog_bar=True)
        return loss['loss']

    def decode(self, z):
        """Decode latent codes ``z`` with the wrapped VAE's decoder."""
        return self.model.decode(z)

    def training_step(self, batch, batch_idx):
        """Lightning training hook; logs metrics with the ``train_`` prefix."""
        return self.step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs metrics with the ``val_`` prefix."""
        return self.step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        """Lightning test hook; logs metrics with the ``test_`` prefix."""
        return self.step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        """Adam over the wrapped VAE's parameters at learning rate ``self.lr``."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

In [3]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from pathlib import Path

# Run configuration. GPU_DEVICES indexes into CUDA_VISIBLE_DEVICES on the host
# this notebook was trained on; change it when running elsewhere.
SEED = 42
MAX_EPOCHS = 50
GPU_DEVICES = [5]
CHECKPOINT_DIR = Path('checkpoints')
FINAL_CHECKPOINT = CHECKPOINT_DIR / 'vae_50.ckpt'

if FINAL_CHECKPOINT.exists():
    # Guard the expensive fit so Restart & Run All reuses the saved weights
    # (the next cell reloads them from this file). Delete it to force retraining.
    print(f"Found {FINAL_CHECKPOINT}; skipping training.")
else:
    # Seed Python, NumPy and torch (and dataloader workers) for a reproducible run.
    pl.seed_everything(SEED, workers=True)

    # Keep the 3 epochs with the lowest 'val_loss' (logged by VAETrainer.step in 'val' mode).
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=CHECKPOINT_DIR,
        filename='base-vae-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )

    logger = TensorBoardLogger('./', version="vae_softadapt_50(W_N)_")

    model = VAETrainer()
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=GPU_DEVICES,
        max_epochs=MAX_EPOCHS,
        enable_progress_bar=True,
        callbacks=[checkpoint_callback],
        logger=logger,
    )
    trainer.fit(model, train_loader, val_loader)

    # Save the final-epoch weights at a fixed path, independent of the top-k files above.
    trainer.save_checkpoint(FINAL_CHECKPOINT)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type | Params
-------------------------------
0 | model | VAE  | 899 K 
-------------------------------
899 K     Trainable params
1         Non-trainable params
899 K     Total params
3.596     Total estimated model params size (MB)


Epoch 0:   5%|▌         | 32/600 [00:01<00:18, 29.93it/s, v_num=_N)_, train_recon_loss=5.37e+3, train_kl_loss=279.0, train_loss=2.73e+3, train_loss(no_weights)=5.65e+3]

  exp_fi = np.exp(self.beta * (fi - max_si))


Epoch 49: 100%|██████████| 600/600 [00:07<00:00, 79.86it/s, v_num=_N)_, train_recon_loss=2.74e+3, train_kl_loss=541.0, train_loss=1.58e+3, train_loss(no_weights)=3.28e+3, val_recon_loss=2.92e+3, val_kl_loss=537.0, val_loss=1.74e+3, val_loss(no_weights)=3.46e+3]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 600/600 [00:07<00:00, 79.84it/s, v_num=_N)_, train_recon_loss=2.74e+3, train_kl_loss=541.0, train_loss=1.58e+3, train_loss(no_weights)=3.28e+3, val_recon_loss=2.92e+3, val_kl_loss=537.0, val_loss=1.74e+3, val_loss(no_weights)=3.46e+3]


In [2]:
# Reload the final weights onto the CPU so the plots below need no GPU.
# Requires the VAETrainer class cell to have run in this kernel session.
model = VAETrainer.load_from_checkpoint(checkpoint_path='checkpoints/vae_50.ckpt', map_location='cpu')
model.train(False)  # same as model.eval(): switch every submodule to inference mode

from plotting import plot_samples_with_reconstruction, plot_latent_images

# First test batch shown next to its reconstructions, then the decoded latent grid.
plot_samples_with_reconstruction(model, next(iter(test_loader)))
plot_latent_images(model)

NameError: name 'VAETrainer' is not defined