In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import Subset
from torch import Tensor
from typing import Tuple, Callable
from itertools import chain
import copy
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch.cuda.amp import GradScaler

In [2]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

scaler = GradScaler(device)

Using cuda device


  scaler = GradScaler(device)


In [3]:
# Load the FashionMNIST training split from '../data' as tensors.
# No `download` flag is passed here; add `download=True` to fetch the files on a first run.
training_data = datasets.FashionMNIST(root='../data', train=True, transform=ToTensor())

# Load the matching held-out test split from the same root.
test_data = datasets.FashionMNIST(root='../data', train=False, transform=ToTensor())

In [4]:
# Keep only the samples whose label equals `target_label` in both splits.
# NOTE: this cell reads `.targets` from `training_data` / `test_data` and then rebinds
# both names to Subset objects, so it expects to run right after the loading cell.
target_label = 0
training_data_incides = torch.nonzero(training_data.targets == target_label, as_tuple=True)[0]
test_data_incides = torch.nonzero(test_data.targets == target_label, as_tuple=True)[0]

training_data, test_data = (
    Subset(training_data, training_data_incides),
    Subset(test_data, test_data_incides),
)

In [5]:
# Fraction of the (label-filtered) training set that is held out for validation.
val_fraction = 0.1
dataset_size = len(training_data)
val_size = int(val_fraction * dataset_size)
train_size = dataset_size - val_size

# Seeded generator so the permutation, and therefore the split, is the same on every run.
generator = torch.Generator()
generator.manual_seed(42)
indices = torch.randperm(dataset_size, generator=generator).tolist()

# The first `train_size` shuffled indices go to training, the remainder to validation.
train_indices, val_indices = indices[:train_size], indices[train_size:]

# Views over `training_data` restricted to each index list.
train_data, val_data = Subset(training_data, train_indices), Subset(training_data, val_indices)

In [6]:
def preprocess_data(x, num_levels: int = 256, eps: float = 1e-6):
    """Dequantize images and map them to logit space for normalizing-flow training.

    Args:
        x: Image tensor. Assumed to hold values in [0, 1] on a grid of
            ``k / (num_levels - 1)`` (e.g. 8-bit images scaled by ``ToTensor``)
            -- TODO confirm against the caller.
        num_levels: Number of discrete intensity levels in the source images.
        eps: Margin kept away from 0 and 1 so the logit stays finite.

    Returns:
        Tensor of the same shape as ``x`` containing ``logit`` of the
        dequantized intensities.
    """
    # Uniform dequantization: recover the integer level k = x * (num_levels - 1),
    # spread it over [k, k + 1) with uniform noise, then rescale into [0, 1).
    # Adding noise / num_levels directly to x would push the brightest level above 1
    # and collapse it onto the clamp boundary.
    x = (x * (num_levels - 1) + torch.rand_like(x)) / num_levels
    # Guard against log(0) at either end (float rounding can still hit the boundary).
    x = torch.clamp(x, eps, 1 - eps)
    # Logit transform: maps (0, 1) onto the real line.
    return torch.log(x) - torch.log1p(-x)

In [7]:
class DataLoaderWrapper:
    """Iterable that post-processes every batch of an underlying loader.

    Each item produced by ``dl`` is unpacked into positional arguments for
    ``func``; the value returned by ``func`` is what iteration yields.
    """

    def __init__(self, dl, func):
        self.dl, self.func = dl, func

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        # Lazily apply `func` to each batch as it is requested.
        yield from (self.func(*batch) for batch in self.dl)


def to_device(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
    """Move both tensors of a batch to the module-level ``device``.

    Args:
        x: First tensor of the batch (the inputs, as used by the data loaders).
        y: Second tensor of the batch (the targets, as used by the data loaders).

    Returns:
        The pair ``(x, y)`` with each tensor moved to ``device``.
    """
    return x.to(device), y.to(device)

In [8]:
batch_size = 64

# Create data loaders.
# Training and validation read from the disjoint `train_data` / `val_data` subsets;
# building both from `training_data` would make the validation set identical to the
# training set and leave the split unused.
train_dataloader = DataLoaderWrapper(DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True), to_device)
val_dataloader = DataLoaderWrapper(DataLoader(val_data, batch_size=batch_size), to_device)
test_dataloader = DataLoaderWrapper(DataLoader(test_data, batch_size=batch_size), to_device)

# Peek at a single batch to confirm tensor shapes and the label dtype.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


$$
z_B = \exp\big(-s(z_A, w)\big) \odot \big(x_B - b(z_A, w)\big)
$$

$$
J = 
\begin{bmatrix}
I_d & 0 \\
\frac{\partial z_B}{\partial x_A} & \mathrm{diag}\big(\exp(-s)\big)
\end{bmatrix}
$$

$$
x_B = \exp\big(s(z_A, w)\big) \odot z_B + b(z_A, w)
$$

In [9]:
class CouplingLayer(nn.Module):
    """Affine coupling layer for a normalizing flow.

    The input is split along dim 1 into a pass-through part ``A`` and a
    transformed part ``B``:

        forward:  z_A = x_A,   z_B = exp(-s(z_A)) * (x_B - b(z_A))
        inverse:  x_A = z_A,   x_B = exp(s(z_A)) * z_B + b(z_A)

    Args:
        split_at: Index along dim 1 at which the input is cut in two.
        scale_net: Network ``s``; maps part A to the log-scale applied to part B.
            Its output must be broadcastable against part B.
        shift_net: Network ``b``; maps part A to the shift applied to part B.
            Its output must be broadcastable against part B.
        alternate_parts: If False, ``x[:, :split_at]`` is part A and the rest is
            part B. If True the roles are swapped.
    """

    def __init__(
        self,
        split_at: int,
        scale_net: nn.Module, # s
        shift_net: nn.Module, # b
        alternate_parts: bool = False
    ) -> None:
        super().__init__()
        self.split_at = split_at
        self.scale_net = scale_net
        self.shift_net = shift_net
        self.alternate_parts = alternate_parts


    def _split(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(part_A, part_B)`` of ``x``, honouring ``alternate_parts``."""
        if self.alternate_parts:
            return x[:, self.split_at:], x[:, :self.split_at]
        else:
            return x[:, :self.split_at], x[:, self.split_at:]


    def _merge(self, xA: Tensor, xB: Tensor) -> Tensor:
        """Concatenate the two parts along dim 1 in their original order (undoes ``_split``)."""
        if self.alternate_parts:
            return torch.cat((xB, xA), dim=1)
        else:
            return torch.cat((xA, xB), dim=1)


    def _get_scale_and_shift(self, zA: Tensor) -> Tuple[Tensor, Tensor]:
        """Evaluate ``s`` and ``b`` on the pass-through part.

        The log-scale is clamped to [-5, 3] so that ``exp(+/-log_scale)`` stays bounded.
        """
        log_scale = self.scale_net(zA)
        log_scale = torch.clamp(log_scale, min=-5, max=3)
        shift = self.shift_net(zA)
        return log_scale, shift


    def forward(self, x: Tensor, log_det_total: Tensor) -> Tuple[Tensor, Tensor]:
        """Map data ``x`` to latent ``z`` and accumulate the log-determinant.

        Args:
            x: Input batch; dim 0 is the batch dimension, dim 1 is split.
            log_det_total: Running per-sample log-determinant from previous layers.

        Returns:
            ``(z, log_det_total + log_det_of_this_layer)``.
        """
        xA, xB = self._split(x)
        zA = xA
        log_scale, shift = self._get_scale_and_shift(zA)

        zB = torch.exp(-log_scale) * (xB - shift)
        z = self._merge(zA, zB)

        # The Jacobian is triangular with diagonal exp(-log_scale) on part B, so its
        # log-determinant is minus the sum of every log-scale entry of the sample.
        # Flattening first keeps this one scalar per sample for inputs with more
        # than two dimensions (a no-op for 2-D inputs).
        log_det_current = -log_scale.flatten(start_dim=1).sum(dim=1)
        log_det_total = log_det_total + log_det_current
        return z, log_det_total


    def inverse(self, z: Tensor) -> Tensor:
        """Map latent ``z`` back to data space; exact inverse of ``forward``."""
        zA, zB = self._split(z)
        xA = zA
        log_scale, shift = self._get_scale_and_shift(zA)

        # Invert z_B = exp(-s) * (x_B - b): scale first, then add the shift.
        # (Scaling the sum z_B + b would also multiply the shift by exp(s).)
        xB = torch.exp(log_scale) * zB + shift
        x = self._merge(xA, xB)
        return x

In [None]:
class MLP(nn.Module):
    """Fully connected network with a zero-initialized output layer.

    Args:
        layers_dims: Layer widths ``[in, hidden_1, ..., out]``; at least three entries.
        activation_layer: Module inserted after every hidden linear layer. It is
            deep-copied per layer, so the instance passed in is never shared.
        bias: Whether the linear layers carry a bias term.
        device: Forwarded to every ``nn.Linear``.
        dtype: Forwarded to every ``nn.Linear``.
    """

    def __init__(self,
        layers_dims: list,
        activation_layer: nn.Module = nn.ReLU(),
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        assert len(layers_dims) > 2

        linear_kwargs = dict(bias=bias, device=device, dtype=dtype)
        *body_dims, out_dim = layers_dims

        # Hidden stack: Linear -> activation for each consecutive pair of widths.
        modules = []
        for fan_in, fan_out in zip(body_dims, body_dims[1:]):
            modules += [
                nn.Linear(fan_in, fan_out, **linear_kwargs),
                copy.deepcopy(activation_layer),
            ]

        # Output layer starts at exactly zero (weights and bias) for stability,
        # so the freshly built network outputs zeros.
        head = nn.Linear(body_dims[-1], out_dim, **linear_kwargs)
        nn.init.zeros_(head.weight)
        if bias:
            nn.init.zeros_(head.bias)
        modules.append(head)

        self.model = nn.Sequential(*modules)


    def forward(self, x) -> Tensor:
        """Apply the layer stack to ``x``."""
        return self.model(x)
        