In [None]:
# !pip install pytorch-lightning
# !pip install lightning-bolts
!brew install pytorch
!brew install pytorchvision

pytorch 2.0.0 is already installed but outdated (so it will be upgraded).
[32m==>[0m [1mFetching dependencies for pytorch: [32mlibuv[39m, [32mmpfr[39m, [32mnumpy[39m, [32mprotobuf@21[39m, [32mpybind11[39m, [32mpython-typing-extensions[39m, [32mpyyaml[39m and [32mlibomp[39m[0m
[32m==>[0m [1mFetching [32mlibuv[39m[0m
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libuv/manifests/1.46.0[0m
######################################################################### 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libuv/blobs/sha256:8c3beb4d11ed[0m
######################################################################### 100.0%
[32m==>[0m [1mFetching [32mmpfr[39m[0m
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/mpfr/manifests/4.2.0-p12[0m
######################################################################### 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/mpfr/blobs/sha256:43db5951067

In [1]:
import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)

class VAE(pl.LightningModule):
    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=178):
        super().__init__()

        self.save_hyperparameters()

        # encoder, decoder
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        # distribution parameters
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # for the gaussian likelihood
        self.eps = nn.Torch.Distributions.Normal(0, 1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def gaussian_likelihood(self, mean, logscale, sample):
        scale = torch.exp(logscale)
        dist = torch.distributions.Normal(mean, scale)
        log_pxz = dist.log_prob(sample)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        # 2. get the probabilities from the equation
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        # kl
        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)
        return kl

    def training_step(self, batch, batch_idx):
        x = batch

        #encode 
        x_encoded = self.encoder(x)
        mu, sigma = self.fc_mu(x_encoded), torch.exp(self.fc_var(x_encoded))

        # get (sample) z from dist normal (q)
        q = self.eps.sample(mu, sigma)
        z = mu + sigma* q

        # decode
        x_hat = self.decoder(z)

        #loss
        recon_loss = torch.nn.MSELoss(x,x_hat)
        kl = self.kl_divergence(z, mu, sigma)
        elbo = (kl - recon_loss)
        elbo = elbo.mean()

        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(), 
            'kl': kl.mean(),
        })

        return elbo

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'pytorch_lightning'

In [30]:
import torchvisio

import os
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn.functional as F


path = './img_align_celeba/'
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        full_path = str(idx).rjust(6, '0') + '.jpg'
        img_path = os.path.join(self.img_dir, full_path)
        image = read_image(img_path)
        image = image.unsqueeze(0)
        image = F.interpolate(image, size=(3, 32, 32), mode='bilinear', align_corners=False)
        image = image.squeeze(0)
        image = image.float()
        if self.transform:
            image = self.transform(image)
        return image



path = './img_align_celeba/'
dataset = CustomImageDataset(path)
celeba = DataLoader(dataset, batch_size=32, shuffle=True)

ModuleNotFoundError: No module named 'torchvisio'

In [27]:
pl.seed_everything(1234)

vae = VAE()
trainer = pl.Trainer(max_epochs=7, accelerator='mps')
trainer.fit(vae, celeba)

Global seed set to 1234
  self.encoder = resnet18_encoder(False, False)
  return ResNetEncoder(EncoderBlock, [2, 2, 2, 2], first_conv, maxpool1)
  layers.append(block(self.inplanes, planes, stride, downsample))
  self.conv1 = conv3x3(inplanes, planes, stride)
  conv1x1(self.inplanes, planes * block.expansion, stride),
  self.decoder = resnet18_decoder(
  return ResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim, input_height, first_conv, maxpool1)
  resize_conv1x1(self.inplanes, planes * block.expansion, scale),
  return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes))
  layers.append(block(self.inplanes, planes, scale, upsample))
  self.conv1 = resize_conv3x3(inplanes, inplanes)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type          | Params
------------------------------------------
0 | encoder | ResNetEncoder | 11.2 M
1 | decode

Training: 0it [00:00, ?it/s]

ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [218, 178] and output size of (3, 32, 32). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

In [None]:
from matplotlib.pyplot import imshow, figure
import numpy as np
from torchvision.utils import make_grid
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
figure(figsize=(8, 3), dpi=300)

# Z COMES FROM NORMAL(0, 1)
num_preds = 16
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
z = p.rsample((num_preds,))

# SAMPLE IMAGES
with torch.no_grad():
    pred = vae.decoder(z.to(vae.device)).cpu()

# UNDO DATA NORMALIZATION
normalize = cifar10_normalization()
mean, std = np.array(normalize.mean), np.array(normalize.std)
img = make_grid(pred).permute(1, 2, 0).numpy() * std + mean

# PLOT IMAGES
imshow(img)