# 👾 PixelCNN using PyTorch distributions

在这个 Notebook 中，我们将使用 **PyTorch**，基于 **混合 Logistic 分布（Mixture of Logistics）**在 **Fashion-MNIST** 数据集上训练一个 PixelCNN。

In [None]:
# %% 
# IPython magics: enable the autoreload extension in mode 2, so imported
# modules are reloaded before each cell runs and local source edits are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt

## 0. Parameters <a name="parameters"></a>

In [None]:
# Side length (in pixels) that images are resized to in `preprocess`, and the
# side length of the images produced by `sample`.
IMAGE_SIZE = 32
N_COMPONENTS = 5      # Mixture of Logistics components
# Number of full passes over the training set.
EPOCHS = 10
# Number of images per optimisation step.
BATCH_SIZE = 128
# Learning rate passed to the Adam optimizer.
LR = 1e-3

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# %%
# Load the Fashion-MNIST training split, downloading it into ./data if needed.
train_dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Take the raw image array straight from the dataset object.
# NOTE(review): reading `.data` presumably bypasses the `transform` above
# (it is applied on item access), which is why scaling happens in
# `preprocess` -- confirm against the torchvision docs.
x_train = train_dataset.data.numpy()

In [None]:
def preprocess(imgs):
    """Convert a stack of raw grayscale images into a float image tensor.

    Args:
        imgs: Array of images without a channel axis, i.e. (N, H, W).
            Presumably 8-bit intensities in 0..255, given the division
            by 255 below -- confirm against the caller.

    Returns:
        A float32 tensor of shape (N, 1, IMAGE_SIZE, IMAGE_SIZE), scaled by
        1/255 and resized with nearest-neighbour interpolation.
    """
    target_size = (IMAGE_SIZE, IMAGE_SIZE)

    # Insert the channel axis: (N, H, W) -> (N, 1, H, W).
    with_channel = np.expand_dims(imgs, 1)
    scaled = torch.tensor(with_channel, dtype=torch.float32) / 255.0

    return F.interpolate(scaled, size=target_size, mode="nearest")

# Training images as a float tensor, channel axis added and resized.
input_data = preprocess(x_train)

In [None]:
# %%
# Display samples
def display(images, n=10):
    """Show up to ``n`` single-channel images side by side in one row.

    Args:
        images: Indexable batch of images laid out as (N, 1, H, W); only
            channel 0 of each image is drawn.
        n: Maximum number of images to show. Fewer are shown when the batch
            is smaller than ``n``; nothing is drawn for an empty batch.
    """
    images = images[:n]
    count = len(images)
    if count == 0:
        # plt.subplots cannot create a grid with zero columns.
        return

    # squeeze=False keeps `axes` two-dimensional even when count == 1;
    # otherwise a bare Axes object is returned and cannot be iterated.
    fig, axes = plt.subplots(
        1, count, figsize=(count * 1.5, 1.5), squeeze=False
    )
    for ax, image in zip(axes[0], images):
        ax.imshow(image[0], cmap="gray")
        ax.axis("off")
    plt.show()

# Preview the first 10 preprocessed training images (display's default n).
display(input_data)

## 2. Build the PixelCNN <a name="build"></a>

In [None]:
# %%
class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel is masked for raster-scan causality.

    The mask zeroes every kernel tap below the centre row and every tap to
    the right of the centre on the centre row. Type ``"A"`` additionally
    zeroes the centre tap itself, so an output never sees its own input
    pixel; type ``"B"`` keeps the centre tap.

    Args:
        mask_type: ``"A"`` or ``"B"``.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``.

    Raises:
        ValueError: If ``mask_type`` is neither ``"A"`` nor ``"B"``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        # Validate with a real exception: ``assert`` is stripped under -O.
        if mask_type not in ("A", "B"):
            raise ValueError(
                f"mask_type must be 'A' or 'B', got {mask_type!r}"
            )
        super().__init__(*args, **kwargs)

        _, _, h, w = self.weight.shape
        mask = torch.ones_like(self.weight)
        # Centre row: block everything right of the centre, and the centre
        # itself for type A.
        first_blocked_col = w // 2 + (1 if mask_type == "B" else 0)
        mask[:, :, h // 2, first_blocked_col:] = 0
        # Block all rows below the centre row.
        mask[:, :, h // 2 + 1:, :] = 0
        # Registered as a buffer so it follows .to(device) and is saved in
        # the state dict under the key "mask".
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Convolve ``x`` with the masked kernel.

        The mask is applied functionally rather than by overwriting
        ``weight.data``: the stored parameters are left untouched and
        masked taps receive exactly zero gradient.
        """
        return self._conv_forward(x, self.weight * self.mask, self.bias)

In [None]:
# %%
class PixelCNN(nn.Module):
    """Small PixelCNN that predicts per-pixel mixture parameters.

    One type-A masked 7x7 convolution is followed by two type-B masked 3x3
    convolutions, each with a ReLU, and a final 1x1 convolution. For a
    single-channel input of shape (B, 1, H, W) the output has shape
    (B, 3 * n_components, H, W).

    Args:
        n_components: Number of mixture components predicted per pixel.
    """

    def __init__(self, n_components):
        super().__init__()
        self.n_components = n_components

        width = 64  # feature channels carried through the masked stack
        self.conv1 = MaskedConv2d("A", 1, width, kernel_size=7, padding=3)
        self.conv2 = MaskedConv2d("B", width, width, kernel_size=3, padding=1)
        self.conv3 = MaskedConv2d("B", width, width, kernel_size=3, padding=1)

        # Output: mixture logits / means / scales for every pixel.
        self.out = nn.Conv2d(width, 3 * n_components, kernel_size=1)

    def forward(self, x):
        """Return the raw mixture parameters for every pixel of ``x``."""
        for masked_conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(masked_conv(x))
        return self.out(x)

In [None]:
# %%
def mixture_logistic_loss(x, params, n_components):
    """Mean negative log-likelihood of ``x`` under a mixture of logistics.

    Args:
        x: Target images of shape (B, 1, H, W).
        params: Raw network output of shape (B, n_components * 3, H, W).
            Channels are grouped per component as
            (mixture logit, mean, log-scale).
        n_components: Number of mixture components.

    Returns:
        Scalar tensor: the negative log-density averaged over the batch and
        all pixels.

    Raises:
        ValueError: If ``x`` does not have exactly one channel.
    """
    B, channels, H, W = x.shape
    if channels != 1:
        raise ValueError(
            f"expected single-channel input, got {channels} channels"
        )
    params = params.view(B, n_components, 3, H, W)

    logits = params[:, :, 0]
    means = params[:, :, 1]
    # Lower-bound the log-scale so exp(-log_scale) cannot blow up.
    log_scales = torch.clamp(params[:, :, 2], min=-7)

    # x is (B, 1, H, W): its singleton channel axis already broadcasts over
    # the component axis of the (B, K, H, W) parameters. Adding another axis
    # would make x 5-D and broadcast each image against every other image's
    # parameters.
    z = (x - means) * torch.exp(-log_scales)

    # Logistic log-density: -z - log(s) - 2 * softplus(-z).
    component_log_probs = -z - log_scales - 2 * F.softplus(-z)

    # Add the log mixture weights, then marginalise over components.
    weighted_log_probs = component_log_probs + F.log_softmax(logits, dim=1)
    return -torch.mean(torch.logsumexp(weighted_log_probs, dim=1))

## 3. Train the PixelCNN <a name="train"></a>

In [None]:
# %%
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network on the chosen device and optimise all of its parameters
# with Adam.
model = PixelCNN(N_COMPONENTS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Wrap the image tensor on its own (no labels), so each batch from the
# loader is a 1-tuple holding a batch of images.
dataset = TensorDataset(input_data)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# %%
# Train for EPOCHS passes over the data, minimising the mixture-of-logistics
# negative log-likelihood of each image under the model's own predictions.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    # Each batch is a 1-tuple because the dataset wraps a single tensor.
    for (x,) in loader:
        x = x.to(device)

        optimizer.zero_grad()
        # The images are both the network input and the likelihood target.
        params = model(x)
        loss = mixture_logistic_loss(x, params, N_COMPONENTS)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Reported value is the mean of the per-batch mean losses.
    print(f"Epoch {epoch:02d} | NLL: {total_loss / len(loader):.4f}")

## 4. Generate images <a name="generate"></a>

In [None]:
# %%
@torch.no_grad()
def sample(model, n_samples, image_size=None):
    """Generate images pixel by pixel from a mixture-of-logistics PixelCNN.

    Pixels are filled in raster order. For each pixel the model is run on
    the partially generated images, a mixture component is drawn from the
    predicted weights, and the pixel value is drawn from that component's
    logistic distribution (mean plus scaled logistic noise). Values are
    clamped to [0, 1] before being fed back as context.

    The model is run once per pixel, i.e. ``image_size ** 2`` forward
    passes. The model is left in eval mode.

    Args:
        model: Network mapping (N, 1, S, S) images to (N, 3 * K, S, S)
            parameters, with channels grouped per component as
            (mixture logit, mean, log-scale).
        n_samples: Number of images to generate.
        image_size: Side length S of the generated images. Defaults to the
            module-level ``IMAGE_SIZE``.

    Returns:
        CPU tensor of shape (n_samples, 1, S, S).
    """
    size = IMAGE_SIZE if image_size is None else image_size

    model.eval()
    # Generate on the same device as the model's parameters rather than
    # relying on a global; fall back to the CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        dev = first_param.device
    else:
        dev = torch.device("cpu")

    images = torch.zeros(n_samples, 1, size, size, device=dev)
    rows = torch.arange(n_samples, device=dev)

    for i in range(size):
        for j in range(size):
            params = model(images)
            # Infer the component count from the output instead of a global.
            n_components = params.shape[1] // 3
            pixel = params.view(n_samples, n_components, 3, size, size)[
                ..., i, j
            ]
            logits, means, log_scales = pixel.unbind(dim=2)

            # squeeze(1), not squeeze(): keep the batch axis when
            # n_samples == 1.
            comp = torch.multinomial(F.softmax(logits, dim=1), 1).squeeze(1)

            mean = means[rows, comp]
            # Same lower bound on the log-scale as the training loss uses.
            scale = torch.exp(torch.clamp(log_scales[rows, comp], min=-7))

            # Inverse-CDF logistic sample: mean + scale * logit(u). u is
            # kept away from 0 and 1 so the logit stays finite.
            u = torch.rand(n_samples, device=dev).clamp(1e-5, 1 - 1e-5)
            value = mean + scale * (torch.log(u) - torch.log1p(-u))

            images[:, 0, i, j] = value.clamp(0.0, 1.0)

    return images.cpu()

In [None]:
# %%
# Generate two images with the model and show them side by side.
generated_images = sample(model, n_samples=2)
display(generated_images, n=2)