In [1]:
import os
import math
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from torchsummary import summary
from fastprogress import progress_bar

In [2]:
# Preprocessing: resize to 128x128, convert to a tensor in [0, 1], then map to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download SUN397 into `dataset/` (skipped if already present).
train_dataset = torchvision.datasets.SUN397(root='dataset', transform=transform, download=True)

# Scene categories kept for training.
SELECTED_CLASSES = {
    'bedroom', 'beach', 'skyscraper', 'lighthouse', 'windmill',
    'mountain', 'castle', 'rice_paddy', 'forest_path', 'bridge',
}

# Fail loudly on a misspelt class name instead of silently training on fewer classes.
missing_classes = SELECTED_CLASSES - set(train_dataset.classes)
if missing_classes:
    raise ValueError(f'Unknown SUN397 classes: {sorted(missing_classes)}')

# select classes (indices into train_dataset.classes)
classes_idx = [idx for idx, label in enumerate(train_dataset.classes) if label in SELECTED_CLASSES]

# Filter the dataset by class. `_labels` is a private torchvision attribute
# (per-image class indices); NOTE(review): may change across torchvision versions.
# Labels are not remapped, so the subset keeps the original SUN397 indices
# (harmless here: the VAE ignores labels).
selected_idx = set(classes_idx)
train_dataset = torch.utils.data.Subset(
    train_dataset,
    [idx for idx, label in enumerate(train_dataset._labels) if label in selected_idx],
)

print('Dataset size:', len(train_dataset))

In [3]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Extract one batch; `batch_images` is reused later in the "Test Model" section.
batch_images, batch_labels = next(iter(train_loader))

# The transform normalized pixels to [-1, 1]; map them back to [0, 1] for display,
# otherwise matplotlib clips every negative value and the preview looks washed out.
preview_grid = torchvision.utils.make_grid(batch_images[:8], padding=0)
preview_grid = (preview_grid * 0.5 + 0.5).clamp(0, 1)

# display a grid of images
plt.figure(figsize=(16, 12))
plt.imshow(preview_grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()

#### Residual Block

In [4]:
class ResidualBlock(torch.nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip connection.

    Spatial size is preserved (3x3 convolutions with padding=1). The skip path
    is an identity when the channel count is unchanged, otherwise a 1x1 conv.

    Args:
        in_channels: channels of the input; must be divisible by 32 (GroupNorm groups).
        out_channels: channels of the output; must be divisible by 32.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order is kept as-is: it determines the state_dict keys
        # and the order in which parameters are randomly initialised.
        self.group_norm_1 = torch.nn.GroupNorm(32, in_channels)
        self.conv_1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.group_norm_2 = torch.nn.GroupNorm(32, out_channels)
        self.conv_2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.residual_layer = (
            torch.nn.Identity()
            if in_channels == out_channels
            else torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x):
        hidden = self.conv_1(torch.nn.functional.silu(self.group_norm_1(x)))
        hidden = self.conv_2(torch.nn.functional.silu(self.group_norm_2(hidden)))
        # Skip path projects the untouched input to the output channel count.
        return hidden + self.residual_layer(x)

#### Attention Block

In [5]:
class SelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        n_heads: number of attention heads.
        d_embed: token embedding size; assumed divisible by ``n_heads``
            (the per-head reshape fails otherwise).
        in_proj_bias: whether the joint Q/K/V projection has a bias.
        out_proj_bias: whether the output projection has a bias.

    ``forward(x, causal_mask=False)`` takes ``x`` of shape
    (batch, seq_len, d_embed) and returns a tensor of the same shape. With
    ``causal_mask=True`` each position attends only to itself and earlier ones.
    """

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # One Linear produces Q, K and V concatenated along the last axis.
        self.in_proj = torch.nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = torch.nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        batch, tokens, _ = x.shape
        per_head = (batch, tokens, self.n_heads, self.d_head)

        # (batch, tokens, d_embed) -> (batch, n_heads, tokens, d_head)
        query, key, value = (
            part.view(per_head).transpose(1, 2)
            for part in self.in_proj(x).chunk(3, dim=-1)
        )

        scores = query @ key.transpose(-1, -2)
        if causal_mask:
            # Block attention to later positions (strictly above the diagonal).
            upper = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores.masked_fill_(upper, -torch.inf)
        scores /= math.sqrt(self.d_head)
        probs = torch.nn.functional.softmax(scores, dim=-1)

        # Merge heads back: (batch, n_heads, tokens, d_head) -> (batch, tokens, d_embed)
        merged = (probs @ value).transpose(1, 2).reshape(x.shape)
        return self.out_proj(merged)

In [6]:
class AttentionBlock(torch.nn.Module):
    """Single-head self-attention across the spatial positions of a feature map, with a skip.

    Input and output are (batch, channels, height, width); ``channels`` must be
    divisible by 32 for the GroupNorm.
    """

    def __init__(self, channels):
        super().__init__()
        self.group_norm = torch.nn.GroupNorm(32, channels)
        # One head whose embedding size equals the channel count.
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        normed = self.group_norm(x)
        batch, channels, height, width = normed.shape

        # Flatten the grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
        tokens = normed.view((batch, channels, height * width)).transpose(-1, -2)
        attended = self.attention(tokens)
        # Restore the feature-map layout: (B, H*W, C) -> (B, C, H, W).
        out = attended.transpose(-1, -2).view((batch, channels, height, width))

        # Skip connection with the block's (un-normalized) input.
        out += x
        return out

#### Encoder Block

In [7]:
class Encoder(torch.nn.Sequential):
    """VAE encoder mapping RGB images to a scaled latent sample.

    Three stride-2 convolutions reduce each spatial side by 8x (e.g. a 128x128
    input gives a 16x16 latent). The last layer emits 8 channels that are split
    into a 4-channel mean and a 4-channel log-variance.

    ``forward(x)`` returns ``(mean, log_variance, z)``:
    - ``log_variance`` is clamped to [-30, 20];
    - ``z = (mean + std * eps) * 0.18215`` with ``eps ~ N(0, I)`` drawn on the
      same device and dtype as ``mean``.
    """

    def __init__(self):
        super().__init__(
            torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
            ResidualBlock(64, 64),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0),
            ResidualBlock(64, 128),
            torch.nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            ResidualBlock(128, 256),
            torch.nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            ResidualBlock(256, 256),
            AttentionBlock(256),
            ResidualBlock(256, 256),
            torch.nn.GroupNorm(32, 256),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 8, kernel_size=3, padding=1),
            torch.nn.Conv2d(8, 8, kernel_size=1, padding=0),
        )

    def forward(self, x):
        for module in self:
            # Stride-2 convs are built with padding=0; pad one pixel on the right
            # and bottom so even-sized inputs halve exactly (128 -> 64).
            if getattr(module, 'stride', None) == (2, 2):
                x = torch.nn.functional.pad(x, (0, 1, 0, 1))
            x = module(x)

        mean, log_variance = torch.chunk(x, 2, dim=1)
        # Keep exp() in a numerically safe range.
        log_variance = torch.clamp(log_variance, -30, 20)
        stdev = log_variance.exp().sqrt()

        # Reparameterisation trick. randn_like follows mean's device and dtype
        # rather than hard-coding .cuda(), so the encoder also runs on CPU and
        # avoids drawing on the host and copying to the GPU on every call.
        noise = torch.randn_like(mean)
        x = mean + stdev * noise

        # Fixed latent scale factor; the Decoder in this notebook divides by the same value.
        x *= 0.18215
        return mean, log_variance, x

In [8]:
encoder = Encoder().float().cuda()

# Shape smoke test on a random batch. no_grad: nothing here needs gradients, and it
# stops the z_mean / z_log_var / z globals from holding an autograd graph in GPU memory.
with torch.no_grad():
    z_mean, z_log_var, z = encoder(torch.rand((8, 3, 128, 128)).float().cuda())

print('Latent space shape', z_mean.shape)

Latent space shape torch.Size([8, 4, 16, 16])


In [9]:
# Per-layer output shapes and parameter counts for one (3, 128, 128) image.
summary(encoder, (3, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,792
         GroupNorm-2         [-1, 64, 128, 128]             128
            Conv2d-3         [-1, 64, 128, 128]          36,928
         GroupNorm-4         [-1, 64, 128, 128]             128
            Conv2d-5         [-1, 64, 128, 128]          36,928
          Identity-6         [-1, 64, 128, 128]               0
     ResidualBlock-7         [-1, 64, 128, 128]               0
            Conv2d-8           [-1, 64, 64, 64]          36,928
         GroupNorm-9           [-1, 64, 64, 64]             128
           Conv2d-10          [-1, 128, 64, 64]          73,856
        GroupNorm-11          [-1, 128, 64, 64]             256
           Conv2d-12          [-1, 128, 64, 64]         147,584
           Conv2d-13          [-1, 128, 64, 64]           8,320
    ResidualBlock-14          [-1, 128,

#### Decoder Block

In [10]:
class Decoder(torch.nn.Sequential):
    """VAE decoder mapping a scaled 4-channel latent back to a 3-channel image.

    Three 2x upsampling stages grow each spatial side by 8x (e.g. a 16x16 latent
    gives a 128x128 image).

    ``forward(x)`` first divides by 0.18215, which undoes the Encoder's latent
    scaling. It does this without modifying the caller's tensor.
    """

    def __init__(self):
        super().__init__(
            torch.nn.Conv2d(4, 4, kernel_size=1, padding=0),
            torch.nn.Conv2d(4, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 256),
            AttentionBlock(256),
            ResidualBlock(256, 256),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 256),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 128),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
            ResidualBlock(128, 64),
            torch.nn.GroupNorm(32, 64),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Divide out-of-place. An in-place `/=` would silently rescale the caller's
        # tensor (e.g. the encoder's `z`) and raises for leaf tensors that require grad.
        # Sequential.forward then applies every module in order.
        return super().forward(x / 0.18215)

In [11]:
decoder = Decoder().float().cuda()

# Shape smoke test on a random latent. no_grad: nothing here needs gradients, and it
# stops the `reconstruction` global from holding an autograd graph in GPU memory.
with torch.no_grad():
    reconstruction = decoder(torch.rand((8, 4, 16, 16)).float().cuda())

print('Reconstruction shape', reconstruction.shape)

Reconstruction shape torch.Size([8, 3, 128, 128])


In [12]:
# Per-layer output shapes and parameter counts for one (4, 16, 16) latent.
summary(decoder, (4, 16, 16))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 4, 16, 16]              20
            Conv2d-2          [-1, 256, 16, 16]           9,472
         GroupNorm-3          [-1, 256, 16, 16]             512
            Conv2d-4          [-1, 256, 16, 16]         590,080
         GroupNorm-5          [-1, 256, 16, 16]             512
            Conv2d-6          [-1, 256, 16, 16]         590,080
          Identity-7          [-1, 256, 16, 16]               0
     ResidualBlock-8          [-1, 256, 16, 16]               0
         GroupNorm-9          [-1, 256, 16, 16]             512
           Linear-10             [-1, 256, 768]         197,376
           Linear-11             [-1, 256, 256]          65,792
    SelfAttention-12             [-1, 256, 256]               0
   AttentionBlock-13          [-1, 256, 16, 16]               0
        GroupNorm-14          [-1, 256,

#### Train Variational Auto-Encoder

In [13]:
# Training hyper-parameters: number of passes over the data and the AdamW learning rate.
EPOCHS, LEARNING_RATE = 100, 5e-3

def loss_fn(recon_x, x, mean, log_var):
    """VAE loss: summed squared reconstruction error plus a Gaussian KL term.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target images.
        mean: latent means from the encoder.
        log_var: latent log-variances from the encoder (same shape as ``mean``).

    Returns:
        ``(total, mse, kld)`` as scalar tensors, where ``total = mse + kld``.

    Note the reductions differ. ``mse`` is summed over every element, while
    ``kld`` (the closed-form KL(N(mean, exp(log_var)) || N(0, 1))) is averaged
    over every latent element. Compared with an all-sum ELBO, the KL term is
    therefore down-weighted by the number of latent elements in the batch.
    """
    # reduction='sum' replaces the deprecated size_average=False (same value, no warning).
    mse = torch.nn.functional.mse_loss(recon_x, x, reduction='sum')
    kld = -0.5 * torch.mean(1 + log_var - torch.pow(mean, 2) - torch.exp(log_var))

    return mse + kld, mse, kld

def save_checkpoint(checkpoint_path="diffusion_vae"):
    """Save the global ``encoder`` and ``decoder`` state dicts into ``checkpoint_path``.

    Writes ``encoder_ckpt.pt`` and ``decoder_ckpt.pt``. The directory (and any
    parents) is created first, because ``torch.save`` does not create missing
    directories and would raise on the very first save.
    """
    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save(encoder.state_dict(), os.path.join(checkpoint_path, "encoder_ckpt.pt"))
    torch.save(decoder.state_dict(), os.path.join(checkpoint_path, "decoder_ckpt.pt"))

# AdamW over the parameters of both halves of the VAE.
optimizer = torch.optim.AdamW(
    [*encoder.parameters(), *decoder.parameters()],
    lr=LEARNING_RATE,
    eps=1e-5,
)
# Cosine learning-rate decay spanning EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Gradient scaler for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()

In [7]:
for epoch in range(EPOCHS):
    p_bar = progress_bar(train_loader, leave=False)
    # Accumulate plain Python floats. Summing the loss tensors themselves would keep
    # every batch's autograd graph (and its activations) alive for the whole epoch.
    sum_loss = 0.
    sum_kl_loss = 0.
    sum_reconstruction_loss = 0.
    for image, _ in p_bar:
        # The DataLoader yields CPU tensors while the models live on the GPU.
        image = image.cuda()
        # Enter both contexts with a comma: `autocast(...) and enable_grad()` evaluates
        # to enable_grad() alone and silently skips mixed precision.
        with torch.autocast("cuda"), torch.enable_grad():
            # forward pass
            z_mean, z_log_var, z = encoder(image)
            reconstruction = decoder(z)
            # calculate loss
            loss, mse, kld = loss_fn(reconstruction, image, z_mean, z_log_var)
        # backward pass (GradScaler scales the loss to protect fp16 gradients)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # logger
        loss_value, mse_value, kld_value = loss.item(), mse.item(), kld.item()
        sum_loss += loss_value
        sum_kl_loss += kld_value
        sum_reconstruction_loss += mse_value
        p_bar.comment = f"total_loss: {loss_value:.2e}, reconstruction_loss: {mse_value:.2e}, kl_loss: {kld_value:.2e}"
    # CosineAnnealingLR was created with T_max=EPOCHS, so advance it once per epoch.
    # Stepping per batch advanced it len(train_loader) times per epoch.
    scheduler.step()
    # log average loss over the epoch's batches
    num_batches = max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{EPOCHS}: total_loss: {sum_loss / num_batches:.2e}, reconstruction_loss: {sum_reconstruction_loss / num_batches:.2e}, kl_loss: {sum_kl_loss / num_batches:.2e}")
    # save checkpoint
    save_checkpoint("diffusion_vae")

In [None]:
# Final explicit save (the training loop also saves after every epoch).
save_checkpoint(checkpoint_path="diffusion_vae")

#### Test Model

In [None]:
# Restore the weights written by save_checkpoint("diffusion_vae").
encoder.load_state_dict(state_dict=torch.load(os.path.join("diffusion_vae", "encoder_ckpt.pt")))
decoder.load_state_dict(state_dict=torch.load(os.path.join("diffusion_vae", "decoder_ckpt.pt")))

In [None]:
# Encode and decode the batch previewed earlier. The models live on the GPU, so move
# the images there. no_grad avoids building a graph, which also lets us call .numpy() below.
with torch.no_grad():
    mean, z_log_var, z = encoder(batch_images.cuda())
    reconstructed = decoder(z)

# Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1], clamped for display.
reconstruction_grid = torchvision.utils.make_grid(reconstructed[:8].cpu(), padding=0)
reconstruction_grid = (reconstruction_grid * 0.5 + 0.5).clamp(0, 1)

# display the reconstructed images
plt.figure(figsize=(16, 12))
plt.imshow(reconstruction_grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()