In [6]:
import pytorch_lightning as pl
import pytorch_lightning as L  # later cells refer to Lightning through the alias `L`
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.datasets import MNIST

In [2]:
# Data configuration: tune these instead of editing the calls below.
DATA_DIR = "./data"
BATCH_SIZE = 64
NUM_WORKERS = 15  # worker processes for the training loader; adjust to the host's CPU count

# Convert each image to a tensor; no further augmentation or normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches MNIST into DATA_DIR when it is missing, so this cell
# also works on a fresh machine; an existing copy is reused.
train_data = MNIST(DATA_DIR, train=True, transform=transform, download=True)
test_data = MNIST(DATA_DIR, train=False, transform=transform, download=True)

# Only the training loader shuffles and uses worker processes.
trainloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
testloader = DataLoader(test_data, batch_size=BATCH_SIZE)



In [None]:
# Sanity check: pull a single batch from the training loader and show the
# Python type of its first element (the images).
_first_batch = next(iter(trainloader))
type(_first_batch[0])

: 

In [12]:
class VAE(L.LightningModule):
    """Fully connected variational autoencoder trained on flattened images.

    The encoder has two independent one-hidden-layer branches: ``fc1 -> fcmu``
    produces the latent mean and ``fc2 -> fcvar`` produces the latent
    log-variance. The decoder (``fc3 -> fc4``) maps a latent sample back to
    the input dimension and ends in a sigmoid.

    Args:
        input_dim: Number of features per sample after flattening
            (784 corresponds to a 28x28 single-channel image).
        hidden_dim: Width of every hidden layer.
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super().__init__()
        # Record the constructor arguments so a checkpoint can rebuild a
        # model with the same layer sizes.
        self.save_hyperparameters()

        # encoder (separate hidden layers for the mean and log-variance heads)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fcmu = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(input_dim, hidden_dim)
        self.fcvar = nn.Linear(hidden_dim, latent_dim)

        # decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        self.activation = nn.LeakyReLU()

    def forward(self, x):
        """Return a reconstruction of ``x``.

        ``x`` is flattened to ``(batch, features)`` first. A latent code is
        sampled on every call, so the output is stochastic.
        """
        x_hat, _, _ = self._encode_sample_decode(torch.flatten(x, start_dim=1))
        return x_hat

    def _encode_sample_decode(self, x_flat):
        """Encode ``x_flat``, sample a latent code, decode it.

        Returns:
            Tuple ``(x_hat, mu, log_var)``.
        """
        mu, log_var = self.encode(x_flat)
        # exp(0.5 * log_var) turns the log-variance into a standard deviation.
        z = self.reparameterization(mu, torch.exp(0.5 * log_var))
        return self.decode(z), mu, log_var

    def encode(self, x):
        """Map flattened inputs to ``(mu, log_var)`` of the latent Gaussian."""
        mu = self.fcmu(self.activation(self.fc1(x)))
        log_var = self.fcvar(self.activation(self.fc2(x)))
        return mu, log_var

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        hidden = self.activation(self.fc3(z))

        # apply sigmoid so all values will be between (0,1)
        # this allows for bernoulli distribution to be applied
        return torch.sigmoid(self.fc4(hidden))

    def reparameterization(self, mu, std):
        """Sample ``mu + std * epsilon`` with ``epsilon`` drawn from N(0, I).

        Drawing the noise separately keeps the sample differentiable with
        respect to ``mu`` and ``std``.
        """
        epsilon = torch.randn_like(std)
        return mu + std * epsilon

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one ``(inputs, labels)`` batch.

        The loss is the binary cross-entropy between reconstruction and input
        plus the KL divergence between the encoder's Gaussian and a standard
        normal. Both terms are summed over the batch (not averaged). Labels
        are ignored.
        """
        x, _ = batch
        x = torch.flatten(x, start_dim=1)
        x_hat, mu, log_var = self._encode_sample_decode(x)

        reconstruction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction="sum")
        # Closed-form KL( N(mu, sigma^2) || N(0, I) ).
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - torch.exp(log_var))
        loss = reconstruction_loss + kl_div
        self.log("training loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Use Adam with its default settings over all model parameters."""
        return torch.optim.Adam(self.parameters())


In [14]:
# Training run settings.
MAX_EPOCHS = 30
CHECKPOINT_ROOT = "checkpoints"  # root directory handed to the Trainer for its outputs

trainer = L.Trainer(
    default_root_dir=CHECKPOINT_ROOT,
    max_epochs=MAX_EPOCHS,
    accelerator="auto",
)

# Fresh, untrained model instance.
model = VAE()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Train the VAE on the MNIST training loader (no validation loader is used).
trainer.fit(model=model, train_dataloaders=trainloader)

You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type      | Params | Mode 
-------------------------------------------------
0 | fc1        | Linear    | 200 K  | train
1 | fcmu       | Linear    | 5.1 K  | train
2 | fc2        | Linear    | 200 K  | train
3 | fcvar      | Linear    | 5.1 K  | train
4 | fc3        | Linear    | 5.4 K  | train
5 | fc4        | Linear    | 201 K  | train
6 | activation | LeakyReLU | 0      | train
-------------------------------------------------
619 K     Trainable params
0         Non-trainable params
619 K     Total params
2.476     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [17]:
CHECKPOINT_PATH = "checkpoints/epoch=29-step=28140.ckpt"  # Replace with your checkpoint path

logger = TensorBoardLogger("tb_logs", name="vae_compare")

# Load the trained model and switch it to evaluation mode.
model = VAE.load_from_checkpoint(CHECKPOINT_PATH)
model.eval()

# Only the first test batch is logged, for demonstration purposes.
x, _ = next(iter(testloader))

# The restored model is not guaranteed to be on the CPU (the run above trained
# on a CUDA device), so move the batch to wherever the model lives. Gradients
# are not needed for inference.
with torch.no_grad():
    x_hat = model(x.to(model.device))

# Reshape the flat reconstructions back to the input's image layout and bring
# them to the CPU for grid building.
x_hat = x_hat.reshape(x.shape).cpu()

# Log originals and reconstructions as image grids at step 0.
grid_original = vutils.make_grid(x, normalize=True, scale_each=True)
grid_reconstructed = vutils.make_grid(x_hat, normalize=True, scale_each=True)
logger.experiment.add_image('Original Images', grid_original, 0)
logger.experiment.add_image('Reconstructed Images', grid_reconstructed, 0)

# Write pending events to disk now so the images show up in TensorBoard
# without waiting for the writer's periodic flush.
logger.experiment.flush()