In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import Subset
from torch import Tensor
from typing import Tuple

In [2]:
# Load the FashionMNIST training split from the '../data' directory.
# NOTE(review): `download=True` is not passed, so this presumably expects the
# dataset files to already be present under that directory — confirm.
training_data = datasets.FashionMNIST(root='../data', train=True, transform=ToTensor())

# Load the matching test split from the same directory, converted with the
# same ToTensor transform.
test_data = datasets.FashionMNIST(root='../data', train=False, transform=ToTensor())

In [3]:
# Restrict both splits to the samples of a single class label.
target_label = 0

# Positions (along the first axis of `.targets`) of the samples whose label
# equals `target_label`, computed separately for each split.
training_data_incides = torch.nonzero(training_data.targets == target_label, as_tuple=True)[0]
test_data_incides = torch.nonzero(test_data.targets == target_label, as_tuple=True)[0]

# Rebind the dataset names to Subset wrappers over those positions.
# NOTE(review): after this rebinding the names no longer refer to the original
# datasets, so re-running this cell without re-running the loading cell first
# will likely fail — confirm.
training_data = Subset(training_data, training_data_incides)
test_data = Subset(test_data, test_data_incides)

In [4]:
class DataLoaderWrapper:
    """Iterate over a data loader, keeping only samples with desired labels.

    Each ``(X, y)`` batch produced by the wrapped loader is filtered with a
    boolean mask so that only the entries whose label appears in
    ``desiried_labels`` are yielded.

    Args:
        dl: The wrapped loader; any sized iterable yielding ``(X, y)`` pairs
            of tensors that share their first dimension works.
        desiried_labels: The labels to keep. Either a tensor or anything
            ``torch.tensor`` accepts (e.g. a list of ints).
    """

    def __init__(self, dl: DataLoader, desiried_labels: torch.Tensor):
        self.dl = dl
        if isinstance(desiried_labels, torch.Tensor):
            # Copy an existing tensor with detach().clone() instead of
            # torch.tensor(tensor), which is the copy-construct pattern
            # PyTorch warns against.
            self.desiried_labels = desiried_labels.detach().clone()
        else:
            self.desiried_labels = torch.tensor(desiried_labels)

    def __len__(self):
        """Return the number of batches of the wrapped loader.

        This counts batches before filtering; it is not the number of
        samples that survive the label mask.
        """
        return len(self.dl)

    def __iter__(self):
        """Yield ``(X, y)`` batches reduced to the desired labels.

        A batch with no matching sample is still yielded, with a first
        dimension of size zero.
        """
        for X, y in self.dl:
            # Compare on the batch's device so torch.isin never receives
            # tensors that live on different devices.
            mask = torch.isin(y, self.desiried_labels.to(y.device))
            yield X[mask], y[mask]

In [5]:
batch_size = 32

# Create data loaders. The wrapper keeps only the samples whose label is in the
# given list; pass `target_label` (the label the datasets were subset to)
# instead of repeating the literal 0, so the two filters cannot drift apart.
train_dataloader = DataLoaderWrapper(DataLoader(training_data, batch_size=batch_size), [target_label])
test_dataloader = DataLoaderWrapper(DataLoader(test_data, batch_size=batch_size), [target_label])

# Peek at the first test batch only, to check the tensor shapes and label dtype.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([32, 1, 28, 28])
Shape of y: torch.Size([32]) torch.int64


$$
z_B = \exp\left(-s(z_A, w)\right) \odot \left(x_B - b(z_A, w)\right)
$$

$$
J = 
\begin{bmatrix}
I_d & 0 \\
\frac{\partial z_B}{\partial x_A} & \mathrm{diag}\big(\exp(-s)\big)
\end{bmatrix}
$$

$$
x_B = \exp\big(s(z_A, w)\big) \odot z_B + b(z_A, w)
$$

In [None]:
class CouplingLayer(nn.Module):
    """Affine coupling layer with an explicit inverse.

    The input is split along dimension 1 into a part A, which passes through
    unchanged, and a part B, which is rescaled and shifted using quantities
    computed from A:

        forward:  z_A = x_A,  z_B = exp(-s(z_A)) * (x_B - b(z_A))
        inverse:  x_A = z_A,  x_B = exp(s(z_A)) * z_B + b(z_A)

    Args:
        split_at: Column index (along dimension 1) at which the input is cut.
        scale_net: Network ``s``; maps part A to the log-scale applied to
            part B. Its output is multiplied elementwise with part B.
        shift_net: Network ``b``; maps part A to the shift applied to part B.
        alternate_parts: If False, A is the first ``split_at`` columns and B
            the rest. If True the roles are swapped.
        log_scale_min: Lower clamp applied to the output of ``scale_net``.
        log_scale_max: Upper clamp applied to the output of ``scale_net``.

    Raises:
        ValueError: If ``log_scale_min`` is greater than ``log_scale_max``.
    """

    def __init__(
        self,
        split_at: int,
        scale_net: nn.Module, # s
        shift_net: nn.Module, # b
        alternate_parts: bool = False,
        log_scale_min: float = -5.0,
        log_scale_max: float = 5.0,
    ) -> None:
        super().__init__()
        if log_scale_min > log_scale_max:
            raise ValueError(
                f"log_scale_min ({log_scale_min}) must not exceed "
                f"log_scale_max ({log_scale_max})"
            )
        self.split_at = split_at
        self.scale_net = scale_net
        self.shift_net = shift_net
        self.alternate_parts = alternate_parts
        # The clamp bounds are read by _get_scale_and_shift and therefore have
        # to exist on the instance; they are exposed as keyword arguments with
        # defaults so existing four-argument constructions keep working.
        self.log_scale_min = log_scale_min
        self.log_scale_max = log_scale_max

    def _split(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Cut ``x`` along dimension 1 and return ``(part_A, part_B)``."""
        if self.alternate_parts:
            return x[:, self.split_at:], x[:, :self.split_at]
        else:
            return x[:, :self.split_at], x[:, self.split_at:]

    def _merge(self, xA: Tensor, xB: Tensor) -> Tensor:
        """Concatenate the parts back in the column order used by ``_split``."""
        if self.alternate_parts:
            return torch.cat((xB, xA), dim=1)
        else:
            return torch.cat((xA, xB), dim=1)

    def _get_scale_and_shift(self, zA: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the clamped log-scale and the shift computed from part A."""
        log_scale = self.scale_net(zA)
        # Bound the log-scale so that exp(+/-log_scale) stays in a fixed range.
        log_scale = torch.clamp(log_scale, min=self.log_scale_min, max=self.log_scale_max)
        shift = self.shift_net(zA)
        return log_scale, shift

    def forward(self, x: Tensor, log_det_total: Tensor) -> Tuple[Tensor, Tensor]:
        """Map ``x`` to ``z`` and accumulate the log-determinant.

        Args:
            x: Input tensor; it is split along dimension 1.
            log_det_total: Running log-determinant to which this layer's
                contribution is added.

        Returns:
            ``(z, log_det_total)`` where ``z`` has the column layout of ``x``
            and ``log_det_total`` is the input total plus ``-sum(log_scale)``
            taken over dimension 1.
        """
        xA, xB = self._split(x)
        zA = xA
        log_scale, shift = self._get_scale_and_shift(zA)

        zB = torch.exp(-log_scale) * (xB - shift)
        z = self._merge(zA, zB)

        log_det_current = -torch.sum(log_scale, dim=1)
        log_det_total = log_det_total + log_det_current
        return z, log_det_total

    def inverse_forward(self, z: Tensor) -> Tensor:
        """Map ``z`` back to ``x``, undoing ``forward``.

        Args:
            z: Tensor with the column layout produced by ``forward``.

        Returns:
            The reconstructed ``x``.
        """
        zA, zB = self._split(z)
        xA = zA
        log_scale, shift = self._get_scale_and_shift(zA)

        # Solving z_B = exp(-s) * (x_B - b) for x_B: undo the scaling first,
        # then add the shift back. The shift itself must not be rescaled.
        xB = torch.exp(log_scale) * zB + shift
        x = self._merge(xA, xB)
        return x

