In [8]:
! pip install pytorch-lightning-bolts==0.3.0

Defaulting to user installation because normal site-packages is not writeable
Collecting pytorch-lightning-bolts==0.3.0
  Obtaining dependency information for pytorch-lightning-bolts==0.3.0 from https://files.pythonhosted.org/packages/6d/c0/f600b26020bc74e24d3e12163a3a4730b269dc7095643750f4665707c526/pytorch_lightning_bolts-0.3.0-py3-none-any.whl.metadata
  Downloading pytorch_lightning_bolts-0.3.0-py3-none-any.whl.metadata (12 kB)
Downloading pytorch_lightning_bolts-0.3.0-py3-none-any.whl (247 kB)
   ---------------------------------------- 0.0/247.3 kB ? eta -:--:--
   ---------------------------------------- 0.0/247.3 kB ? eta -:--:--
   - -------------------------------------- 10.2/247.3 kB ? eta -:--:--
   - -------------------------------------- 10.2/247.3 kB ? eta -:--:--
   - -------------------------------------- 10.2/247.3 kB ? eta -:--:--
   --------- ----------------------------- 61.4/247.3 kB 365.7 kB/s eta 0:00:01
   ----------------- -------------------- 112.6/247.3 kB 5

In [9]:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)

In [10]:
class VAE(pl.LightningModule):
    """Variational autoencoder built from pl_bolts ResNet-18 encoder/decoder components.

    The encoder's feature vector is projected to the mean and log-variance of a
    diagonal Gaussian posterior q(z|x). A reparameterised sample z is decoded and
    scored under a Gaussian likelihood p(x|z) whose scale is a single learned
    parameter. Training minimises the negative ELBO (KL - log-likelihood),
    estimated from one Monte Carlo sample per input.

    Args:
        enc_out_dim: width of the encoder's output feature vector.
            NOTE(review): must equal the ResNet-18 encoder's output size
            (512 assumed) — confirm against pl_bolts.
        latent_dim: dimensionality of the latent code z.
        input_height: image height forwarded to the decoder.
        lr: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32, lr=1e-4):
        super().__init__()

        # Stores every constructor argument (including lr) in self.hparams.
        self.save_hyperparameters()

        # Encoder/decoder. The encoder's positional args are
        # (first_conv=False, maxpool1=False), mirroring the decoder keywords.
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False,
        )

        # Heads producing the posterior mean and log-variance.
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # Learned log standard deviation of the Gaussian likelihood p(x|z).
        # It is a single scalar shared by all elements, initialised to 0 (scale = 1).
        self.log_scale = nn.Parameter(torch.zeros(1))

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters using the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def gaussian_likelihood(self, mean, logscale, sample):
        """Log-likelihood of ``sample`` under Normal(mean, exp(logscale)).

        The result is summed over dims 1-3, so inputs must be 4-D (presumably
        (batch, channels, height, width) — TODO confirm). One value is returned
        per batch element. This is a log-likelihood (higher is better), not a loss.
        """
        scale = torch.exp(logscale)
        dist = torch.distributions.Normal(mean, scale)
        log_pxz = dist.log_prob(sample)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        """Single-sample Monte Carlo estimate of KL(q(z|x) || p(z)).

        q is Normal(mu, std) and p is a standard normal prior. The estimate is
        log q(z) - log p(z), evaluated at the given sample ``z`` and summed over
        the last dimension, giving one value per batch element.
        """
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        return (log_qzx - log_pz).sum(-1)

    def training_step(self, batch, batch_idx):
        """Compute and log the negative ELBO for one batch and return it as the loss.

        ``batch`` is unpacked as (inputs, targets); the targets are ignored.
        """
        x, _ = batch

        # Encode x into the parameters of q(z|x).
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)

        # Reparameterised sample, so gradients flow through mu and std.
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()

        # Decode with this module's own decoder, not a notebook-global `vae`.
        x_hat = self.decoder(z)

        # Log-likelihood of x under p(x|z), one value per sample.
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)

        # Monte Carlo KL term, one value per sample.
        kl = self.kl_divergence(z, mu, std)

        # Negative ELBO, the quantity to minimise. It is logged under the
        # original metric name 'elbo'.
        elbo = (kl - recon_loss).mean()

        # Same logged keys as before; the literal previously repeated 'kl'.
        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(),
            'reconstruction': recon_loss.mean(),
        })

        return elbo

In [12]:
# Display the installed PyTorch Lightning version. pytorch-lightning-bolts is pinned to 0.3.0 above,
# and the pl_bolts import in the CIFAR10DataModule cell fails.
# NOTE(review): likely a bolts-0.3.0 vs Lightning 2.x mismatch — confirm from the full traceback.
pl.__version__

'2.2.2'

In [11]:
from pl_bolts.datamodules import CIFAR10DataModule

# CIFAR-10 datamodule rooted at the current directory. data_dir is the
# datamodule's first parameter, so naming it explicitly is equivalent.
# NOTE(review): importing pl_bolts raised inside pl_bolts.callbacks in the recorded run
# (see traceback below) — likely a version mismatch with the installed pytorch_lightning.
datamodule = CIFAR10DataModule(data_dir='.')

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "d:\Software\Anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\Trial\AppData\Local\Temp\ipykernel_20104\1802927316.py", line 1, in <module>
    from pl_bolts.datamodules import CIFAR10DataModule
  File "C:\Users\Trial\AppData\Roaming\Python\Python311\site-packages\pl_bolts\__init__.py", line 47, in <module>
    from pl_bolts import callbacks, datamodules, datasets, losses, metrics, models, optimizers, transforms, utils
  File "C:\Users\Trial\AppData\Roaming\Python\Python311\site-packages\pl_bolts\callbacks\__init__.py", line 5, in <module>
    from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor  # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Trial\AppData\Roaming\Python\Python311\site-packages\pl_bolts\callbacks\data_monitor.py", line 6, in

In [None]:
from pl_bolts.datamodules import CIFAR10DataModule

# train_transforms = torchvision.transforms.Compose(
#     [
#         torchvision.transforms.RandomCrop(32, padding=4),
#         torchvision.transforms.RandomHorizontalFlip(),
#         torchvision.transforms.ToTensor(),
#         cifar10_normalization(),
#     ]
# )

# test_transforms = torchvision.transforms.Compose(
#     [
#         torchvision.transforms.ToTensor(),
#         cifar10_normalization(),
#     ]
# )

# cifar10_dm = CIFAR10DataModule(
#     data_dir='.',
#     batch_size=16,
#     num_workers=2,
#     train_transforms=train_transforms,
#     test_transforms=test_transforms,
#     val_transforms=test_transforms,
# )

# datamodule = CIFAR10DataModule('.')

# traindl = DataLoader(datamodule)

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
#                                         download=True)
# sampler = torch.utils.data.distributed.DistributedSampler(
#     trainset,
#     num_replicas=xm.xrt_world_size(),
#     shuffle=True)

# trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
#                                           shuffle=True)
