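This notebook sanity-checks the individual RealNVP layers and the complete model: each one is run forward and then through its inverse, and the reconstructed output is compared with the original input.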

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

from normalizing_flows.src.realnvp.dataset import CelebADataset
from normalizing_flows.src.realnvp.model import layers, blocks, realnvp_flow


# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Test individual layers

In [None]:
# Invertibility check for CheckerboardBijection2D: run a random batch through
# the layer, map the result back with `inverse`, and compare with the input.
cb_bijection = layers.CheckerboardBijection2D(in_channels=12, hidden_channels=128)

# Uniform random input; presumably laid out as (batch, channels, height, width),
# with the channel dimension matching in_channels=12 -- TODO confirm.
inp = torch.rand((8, 12, 32, 32))

# Only values are compared, so skip building an autograd graph.
with torch.no_grad():
    # `ldj` is presumably the log-det-Jacobian term returned by the forward pass.
    z, ldj = cb_bijection(inp)
    out = cb_bijection.inverse(z)

print(out.shape, ldj.shape)
# Last expression of the cell: displays True when the round trip reproduces the input.
torch.allclose(inp, out, atol=1e-6)

In [None]:
# Invertibility check for ChannelwiseBijection2D: run a random batch through
# the layer, map the result back with `inverse`, and compare with the input.
cw_bijection = layers.ChannelwiseBijection2D(in_channels=12, hidden_channels=128)

# Uniform random input; presumably laid out as (batch, channels, height, width),
# with the channel dimension matching in_channels=12 -- TODO confirm.
inp = torch.rand((8, 12, 32, 32))

# Only values are compared, so skip building an autograd graph.
with torch.no_grad():
    # `ldj` is presumably the log-det-Jacobian term returned by the forward pass.
    z, ldj = cw_bijection(inp)
    out = cw_bijection.inverse(z)

print(out.shape, ldj.shape)
# Last expression of the cell: displays True when the round trip reproduces the input.
torch.allclose(inp, out, atol=1e-6)

In [None]:
# Invertibility check for BlockBijection2D: run a random batch through the
# block, map the result back with `inverse`, and compare with the input.
block_bijection = blocks.BlockBijection2D(in_channels=12, hidden_channels=128)

# Uniform random input; presumably laid out as (batch, channels, height, width),
# with the channel dimension matching in_channels=12 -- TODO confirm.
inp = torch.rand((8, 12, 32, 32))

# Only values are compared, so skip building an autograd graph.
with torch.no_grad():
    # `ldj` is presumably the log-det-Jacobian term returned by the forward pass.
    z, ldj = block_bijection(inp)
    out = block_bijection.inverse(z)

print(out.shape, ldj.shape)
# Last expression of the cell: displays True when the round trip reproduces the
# input. The tolerance is looser here (1e-3) than for the single-layer checks.
torch.allclose(inp, out, atol=1e-3)

## Create Model

In [None]:
# Preprocessing pipeline: take the central 128x128 crop, resize it to 32x32,
# then convert the image to a tensor.
transform = transforms.Compose(
    [
        transforms.CenterCrop(size=128),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

# CelebA training split. If the automatic download fails, download the images
# manually and move the zip file to ../data/celeba/.
ds_train = CelebADataset(
    root='../data', split='train', download=True, transform=transform
)

# Unshuffled loader with 4 worker processes; batches are assembled by the
# dataset's own `collate_fn_skip_errors`.
dl_train = DataLoader(
    ds_train,
    collate_fn=ds_train.collate_fn_skip_errors,
    batch_size=32,
    num_workers=4,
    shuffle=False,
)

# Summary of the dataset and loader sizes.
print(f"Train dataset size: {len(ds_train)}")
print(f"Number of batches: {len(dl_train)}")

# Fetch the first batch and keep at most 32 entries of its element 0
# (presumably the image tensor -- TODO confirm against the collate function).
first_batch = next(iter(dl_train))
sample_batch = first_batch[0][:32]
print(f"Batch shape: {sample_batch.shape}")


def show_images(images: torch.Tensor, nrow: int = 8) -> None:
    """Display a batch of images as one grid figure.

    No denormalization is applied; the batch is plotted as given.

    Args:
        images: Image batch accepted by ``torchvision.utils.make_grid``.
            Presumably shaped (batch, channels, height, width) -- TODO confirm.
            It may live on any device and may carry autograd history.
        nrow: Number of images per grid row, forwarded to ``make_grid``.
    """
    # Detach before plotting so tensors that still require grad (e.g. model
    # outputs) can always be converted by matplotlib.
    images = images.detach().cpu()
    grid = torchvision.utils.make_grid(images, nrow=nrow)
    # imshow clips float RGB data to [0, 1] and warns about it; clamp explicitly
    # so slightly out-of-range reconstructions plot without the warning.
    # Integer grids are left untouched.
    if grid.is_floating_point():
        grid = grid.clamp(0.0, 1.0)
    plt.figure(figsize=(15, 15))
    # make_grid returns channels-first data; imshow expects channels-last.
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()


show_images(sample_batch)

In [None]:
# RealNVP instance for RGB input (3 channels) with size=32, matching the
# Resize((32, 32)) step of the preprocessing pipeline.
# NOTE(review): the exact meaning of `final_size` and `n_residual_blocks` is
# defined in realnvp_flow.RealNVP -- confirm there.
model = realnvp_flow.RealNVP(
    in_channels=3, size=32, hidden_channels=128, n_residual_blocks=2, final_size=4
)

In [None]:
# Invertibility check for the complete model: run the sample batch forward,
# map the result back with `inverse`, and compare with the original images.
# Only values are compared, so skip building an autograd graph.
with torch.no_grad():
    # `ldj` is presumably the log-det-Jacobian term returned by the forward pass.
    z, ldj = model(sample_batch)
    out = model.inverse(z)

print(out.shape, ldj.shape)
# Last expression of the cell: displays True when the reconstruction matches
# the input within a loose absolute tolerance of 1e-1.
torch.allclose(sample_batch, out, atol=1e-1)

In [None]:
# Plot the batch reconstructed by `model.inverse` for visual comparison with the input images.
show_images(out)