# IAF
- Inverse Autoregressive Flow
- 내재 변수의 사후확률 분포 $p_\phi(z|x)$ 를 자기회귀적으로 생성하여 고차원 내재공간을 효율적으로 학습하는 VAE 파생형 모델이다.
- 이제부터 본문은 편의상 영어로, 수식은 LaTeX 문법으로 작성할 것이다.

## Citation
- [Improving Variational Inference with Inverse Autoregressive Flow (by Diederik P. Kingma, Ilya Sutskever, et. al)](https://arxiv.org/abs/1606.04934)

# Training
## Hyperparameters

In [4]:
# Dataset options
N_BATCH = 32
N_THREAD = 2

# Learning options
N_EPOCH = 40
LEARNING_RATE = 0.0005
R_LOSS_FACTOR = 1000

# Model configuration
Z_DIM = 32  # latent space dimension
H_DIM = 128  # residual block dimension
CONV_OPTIONS = {
    'n_layers': 4,
    'in_channels': [1, 32, 64, 64],
    'out_channels': [32, 64, 64, 64],
    'kernel_sizes': [3, 3, 3, 3],
    'strides': [1, 2, 2, 1],
    'padding': [1, 1, 1, 1],
    'feature_shape': (64, 7, 7),
}

In [5]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import random
import numpy as np

In [6]:
# For reproducible experiments
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# torch.use_deterministic_algorithms(True)

# Get a working device
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


## Load Fashion MNIST

In [7]:
train_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

train_data, val_data = random_split(train_data, [0.9, 0.1])

train_data_loader = DataLoader(
    train_data,
    batch_size=N_BATCH,
    shuffle=True,
    num_workers=N_THREAD
)
val_data_loader = DataLoader(
    val_data,
    batch_size=N_BATCH,
    shuffle=True,
    num_workers=N_THREAD
)
test_data_loader = DataLoader(
    test_data,
    batch_size=N_BATCH,
    shuffle=True,
    num_workers=N_THREAD
)

print()
print("train_data:", len(train_data_loader.dataset))
print("val_data:", len(val_data_loader.dataset))
print("test_data:", len(test_data_loader.dataset))


train_data: 54000
val_data: 6000
test_data: 10000


## Helper class
- **Model**: model abstraction
- **Tracker**: save/load model and early stopping
- **Trainer**: handles the model training, validation, and testing

In [8]:
import time
import os
from tqdm import tqdm
from abc import ABC, abstractmethod


class Model(nn.Module, ABC):
    def __init__(self, name):
        super().__init__()
        self.name = name

    @abstractmethod
    def forward(self, X) -> tuple[torch.Tensor, torch.Tensor]:
        """Return prediction and loss tensor"""
        pass

    @abstractmethod
    def loss_fn(self, *args) -> torch.Tensor:
        """Return loss tensor"""
        pass

    def predict(self, X) -> torch.Tensor:
        with torch.no_grad():
            pred, _ = self.forward(X)
        return pred


class Tracker:
    def __init__(self, model: Model, save_path="./checkpoints", patience=3, delta=0.001, verbose=True):
        self._early_stop = False
        self._model = model
        self._patience = patience
        self._verbose = verbose
        self._counter = 0
        self._best_score = -np.inf
        self._delta = delta
        self._save_path = save_path
        self._val_loss_min = np.inf
        os.makedirs(save_path, exist_ok=True)

    def early_stop(self, val_loss, epoch):
        score = -val_loss
        if score < self._best_score + self._delta:
            self._counter += 1
            print(f'\nEarly Stopping counter: {self._counter} out of {self._patience}')
            if self._counter >= self._patience:
                self._early_stop = True
        else:
            self._best_score = score
            if self._verbose:
                print(f'\nSaving model to checkpoint... (Validate loss: {self._val_loss_min:.5f} --> {val_loss:.5f})')
            self.save_checkpoint(epoch=epoch)
            self._val_loss_min = val_loss
            self._counter = 0
        return self._early_stop

    def save_checkpoint(self, epoch=None):
        model_name = self._model.name
        if epoch is None:
            filename = f"{model_name}.pt"
        else:
            filename = f"{model_name}.{epoch}.pt"
        fp = os.path.join(self._save_path, filename)
        torch.save(self._model.state_dict(), fp)

    def load_checkpoint(self, ckpt_name=None) -> Model:
        if ckpt_name is None:
            ckpt_name = f"{self._model.name}.pt"
        ckpt = torch.load(os.path.join(self._save_path, ckpt_name))
        self._model.load_state_dict(ckpt)
        return self._model


class Trainer:
    def __init__(self, model: Model, tracker: Tracker, train_data_loader, val_data_loader, test_data_loader,
                 device):
        self.model = model
        self.device = device
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.test_data_loader = test_data_loader
        self.tracker = tracker
        self.optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    def fit(self):
        """Train the model."""
        start = time.time()
        for epoch in range(1, N_EPOCH + 1):
            print(f"Epoch {epoch}/{N_EPOCH}\n-------------------------------")
            train_loss = self._train(epoch)
            val_loss = self._validate()
            early_stop = self.tracker.early_stop(val_loss, epoch)
            print(f"Train loss: {train_loss:.5f}")
            print(f"Validate loss: {val_loss:.5f}")
            print("===============================\n")
            if early_stop:
                print("Early stopping now...")
                break
        self.tracker.save_checkpoint()
        print(f"Training complete! Elapsed time: {time.time() - start:.1f}s")

    def test(self):
        """Evaluate the model's performance."""
        start = time.time()
        loss = self._validate(test=True)
        print(f"\nTest loss: {loss:.5f} \n")
        print(f"Test complete! Elapsed time: {time.time() - start:.1f}s")

    def _train(self, epoch) -> float:
        """Returns average training loss"""
        dataloader = self.train_data_loader
        size = len(dataloader.dataset)
        progress = tqdm(enumerate(dataloader), total=int(size / dataloader.batch_size))
        loss_sum = 0
        self.model.train()
        for batch, (X, _) in progress:
            X = X.to(self.device)
            self.optimizer.zero_grad()
            pred, loss = self.model(X)
            loss = loss.mean()
            loss.backward()
            self.optimizer.step()
            loss = loss.item()
            loss_sum += loss
            if batch % 100 == 0:
                progress.set_description(f"[Train loss: {loss:.5f}]")
        avg_loss = loss_sum / len(dataloader)
        return avg_loss

    def _validate(self, test=False) -> float:
        """Returns average validate/test loss"""
        dataloader = self.train_data_loader
        size = len(dataloader.dataset)
        progress = tqdm(enumerate(dataloader), total=int(size / dataloader.batch_size))
        self.model.eval()
        loss_sum = 0
        with torch.no_grad():
            for batch, (X, _) in progress:
                X = X.to(self.device)
                pred, loss = self.model(X)
                loss = loss.mean().item()
                loss_sum += loss
                if batch % 100 == 0:
                    if test:
                        progress.set_description(f"[Validate loss: {loss:.5f}]")
                    else:
                        progress.set_description(f"[Test loss: {loss:.5f}]")
        avg_loss = loss_sum / len(dataloader)
        return avg_loss


## Model

TODO

In [11]:
class Encoder(nn.Module):
    def __init__(self, conv_options, z_dim: int, h_dim: int, batch_norm=True, dropout=0.1):
        super().__init__()
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.activation = nn.LeakyReLU()
        self.drop_out = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
        self.conv_layer = nn.ModuleList()
        self.batch_norm = nn.ModuleList()
        self.flatten = nn.Flatten()
        self.mu = nn.LazyLinear(self.z_dim)
        self.log_var = nn.LazyLinear(self.z_dim)
        # self.hidden = nn.LazyLinear(self.h_dim)

        for i in range(conv_options['n_layers']):
            layer = nn.LazyConv2d(
                out_channels=conv_options['out_channels'][i],
                kernel_size=conv_options['kernel_sizes'][i],
                stride=conv_options['strides'][i],
                padding=conv_options['padding'][i]
            )
            self.conv_layer.append(layer)
            if batch_norm:
                C = conv_options['out_channels'][i]
                self.batch_norm.append(nn.BatchNorm2d(num_features=C))
            else:
                self.batch_norm.append(nn.Identity())

    def forward(self, x):
        for i, conv2d in enumerate(self.conv_layer):
            x = conv2d(x)
            x = self.batch_norm[i](x)
            x = self.activation(x)
            x = self.drop_out(x)

        x = self.flatten(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        # hidden = self.hidden(x)
        return mu, log_var


class GaussianSampler(nn.Module):
    def __init__(self, z_dim: int):
        super().__init__()
        self.z_dim = z_dim

    def forward(self, mu, log_var):
        epsilon = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * log_var) * epsilon
        return z


class Decoder(nn.Module):
    def __init__(self, conv_options, z_dim: int, batch_norm=True, dropout=0.25):
        super().__init__()
        self.relu = nn.LeakyReLU()
        self.drop_out = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
        self.conv_layer = nn.ModuleList()
        self.batch_norm = nn.ModuleList()
        self.sigmoid = nn.Sigmoid()
        self.to_feature_shape = nn.Linear(z_dim, np.prod(conv_options['feature_shape']))
        self.unflatten = nn.Unflatten(1, conv_options['feature_shape'])

        in_channels = conv_options['in_channels'][::-1]
        kernel_sizes = conv_options['kernel_sizes'][::-1]
        strides = conv_options['strides'][::-1]
        padding = conv_options['padding'][::-1]

        for i in range(conv_options['n_layers']):
            layer = nn.LazyConvTranspose2d(
                out_channels=in_channels[i],
                kernel_size=kernel_sizes[i],
                stride=strides[i],
                padding=padding[i],
                # fix ambiguity of conv2d output with stride > 1
                output_padding=1 if strides[i] > 1 else 0
            )
            self.conv_layer.append(layer)
            if batch_norm:
                # Don't apply batch_norm on the last layer
                if i + 1 < conv_options['n_layers']:
                    C = in_channels[i]
                    self.batch_norm.append(nn.BatchNorm2d(num_features=C))
                    continue
            self.batch_norm.append(nn.Identity())

    def forward(self, z):
        x = self.to_feature_shape(z)
        x = self.unflatten(x)

        for i, convT2d in enumerate(self.conv_layer):
            x = convT2d(x)
            if i + 1 < len(self.conv_layer):
                x = self.batch_norm[i](x)
                x = self.relu(x)
                x = self.drop_out(x)
            else:
                x = self.sigmoid(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, out_channels, kernel_size=3, stride=1):
        """
        Never resize the `out_channels` unless `stride` is greater than 1.
        """
        super().__init__()
        if stride > 1:
            self.shortcut = nn.LazyConv2d(out_channels, kernel_size=1, stride=stride)
        else:
            self.shortcut = nn.Identity()
        self.conv_1 = nn.LazyConv2d(out_channels, kernel_size, stride=stride, padding=1)
        self.conv_2 = nn.LazyConv2d(out_channels, kernel_size, stride=1, padding=1)
        self.relu = nn.LeakyReLU()
        self.bn_1 = nn.LazyBatchNorm2d()
        self.bn_2 = nn.LazyBatchNorm2d()

    def forward(self, x):
        res = x                 # residual
        res = self.conv_1(res)
        res = self.bn_1(res)
        res = self.relu(res)
        res = self.conv_2(res)
        res = self.bn_2(res)
        x = self.shortcut(x)    # skip connection
        y = res + x             # H(x) = F(x) + x
        y = self.relu(y)
        return y


class InverseAutoregressiveFlow(Model):
    def __init__(self, conv_options, z_dim, h_dim, r_loss_factor, batch_norm=True, dropout=0.2):
        super().__init__(name="IAF")
        self.encoder = Encoder(conv_options, z_dim, h_dim, batch_norm, dropout)
        self.sampler = GaussianSampler(z_dim)
        self.decoder = Decoder(conv_options, z_dim, batch_norm, dropout)
        self.r_loss_factor = r_loss_factor

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.sampler(mu, log_var)
        x_hat = self.decoder(z)
        loss = self.loss_fn(x, x_hat, mu, log_var)
        return x_hat, loss

    def loss_fn(self, x, x_hat, mu, log_var):
        reconstruction_loss = torch.mean(torch.square(x_hat - x), dim=(1, 2, 3))
        kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=tuple(range(1, mu.ndim)))
        return self.r_loss_factor * reconstruction_loss + kl_divergence


model = InverseAutoregressiveFlow(
    conv_options=CONV_OPTIONS,
    z_dim=Z_DIM,
    h_dim=H_DIM,
    r_loss_factor=R_LOSS_FACTOR
)
model = model.to(device)

print(model)

InverseAutoregressiveFlow(
  (encoder): Encoder(
    (activation): LeakyReLU(negative_slope=0.01)
    (drop_out): Dropout2d(p=0.2, inplace=False)
    (conv_layer): ModuleList(
      (0): LazyConv2d(0, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1-2): 2 x LazyConv2d(0, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): LazyConv2d(0, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (batch_norm): ModuleList(
      (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1-3): 3 x BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (mu): LazyLinear(in_features=0, out_features=32, bias=True)
    (log_var): LazyLinear(in_features=0, out_features=32, bias=True)
  )
  (sampler): VariationalSampler()
  (decoder): Decoder(
    (activation): LeakyReLU(negative_slope=0.01)
    (drop_out): Dropout2d(p=0.2, inplace=False)
    (conv_

In [12]:
tracker = Tracker(model=model)
trainer = Trainer(model, tracker, train_data_loader, val_data_loader, test_data_loader, device=device)
trainer.fit()

Epoch 1/200
-------------------------------


[Train loss: 39.52104]: : 1688it [00:12, 137.08it/s]                        
[Test loss: 34.12180]: : 1688it [00:04, 342.95it/s]                        


Saving model to checkpoint... (Validate loss: inf --> 36.24759)
Train loss: 47.27183
Validate loss: 36.24759

Epoch 2/200
-------------------------------



[Train loss: 35.76083]: : 1688it [00:12, 136.19it/s]                        
[Test loss: 32.68049]: : 1688it [00:06, 271.09it/s]                        



Saving model to checkpoint... (Validate loss: 36.24759 --> 34.09967)
Train loss: 37.22656
Validate loss: 34.09967

Epoch 3/200
-------------------------------


[Train loss: 34.09849]: : 1688it [00:12, 139.23it/s]                        
[Test loss: 31.89107]: : 1688it [00:06, 278.28it/s]                        


Saving model to checkpoint... (Validate loss: 34.09967 --> 32.65034)
Train loss: 35.53799
Validate loss: 32.65034

Epoch 4/200
-------------------------------



[Train loss: 36.04593]: : 1688it [00:14, 119.20it/s]                        
[Test loss: 31.45315]: : 1688it [00:06, 270.38it/s]                        


Saving model to checkpoint... (Validate loss: 32.65034 --> 32.22507)
Train loss: 34.60477
Validate loss: 32.22507

Epoch 5/200
-------------------------------



[Train loss: 32.91176]: : 1688it [00:11, 153.41it/s]                        
[Test loss: 33.68723]: : 1688it [00:04, 340.57it/s]                        


Early Stopping counter: 1 out of 3
Train loss: 33.99758
Validate loss: 32.22496

Epoch 6/200
-------------------------------



[Train loss: 31.11660]: : 1688it [00:11, 147.87it/s]                        
[Test loss: 31.33584]: : 1688it [00:04, 337.61it/s]                        



Saving model to checkpoint... (Validate loss: 32.22507 --> 31.09518)
Train loss: 33.49625
Validate loss: 31.09518

Epoch 7/200
-------------------------------


[Train loss: 34.93422]: : 1688it [00:11, 147.91it/s]                        
[Test loss: 28.69375]: : 1688it [00:05, 307.97it/s]                        


Saving model to checkpoint... (Validate loss: 31.09518 --> 30.96019)
Train loss: 33.16609
Validate loss: 30.96019

Epoch 8/200
-------------------------------



[Train loss: 33.24113]: : 1688it [00:11, 142.93it/s]                        
[Test loss: 29.06083]: : 1688it [00:04, 341.89it/s]                        


Early Stopping counter: 1 out of 3
Train loss: 32.89965
Validate loss: 31.55124

Epoch 9/200
-------------------------------



[Train loss: 30.59657]: : 1688it [00:11, 148.26it/s]                        
[Test loss: 31.77951]: : 1688it [00:05, 317.73it/s]                        


Saving model to checkpoint... (Validate loss: 30.96019 --> 30.79113)
Train loss: 32.61495
Validate loss: 30.79113

Epoch 10/200
-------------------------------



[Train loss: 32.24335]: : 1688it [00:11, 144.11it/s]                        
[Test loss: 27.38361]: : 1688it [00:04, 350.99it/s]                        


Saving model to checkpoint... (Validate loss: 30.79113 --> 30.59121)
Train loss: 32.45739
Validate loss: 30.59121

Epoch 11/200
-------------------------------



[Train loss: 32.21578]: : 1688it [00:11, 149.22it/s]                        
[Test loss: 30.98756]: : 1688it [00:05, 336.85it/s]                        


Saving model to checkpoint... (Validate loss: 30.59121 --> 30.26175)
Train loss: 32.27373
Validate loss: 30.26175

Epoch 12/200
-------------------------------



[Train loss: 32.50203]: : 1688it [00:11, 147.23it/s]                        
[Test loss: 25.89852]: : 1688it [00:04, 346.60it/s]                        


Saving model to checkpoint... (Validate loss: 30.26175 --> 30.03311)
Train loss: 32.10883
Validate loss: 30.03311

Epoch 13/200
-------------------------------



[Train loss: 32.66205]: : 1688it [00:10, 155.21it/s]                        
[Test loss: 30.74993]: : 1688it [00:04, 354.44it/s]                        


Early Stopping counter: 1 out of 3
Train loss: 32.01613
Validate loss: 30.25191

Epoch 14/200
-------------------------------



[Train loss: 37.60696]: : 1688it [00:11, 143.91it/s]                        
[Test loss: 30.28551]: : 1688it [00:05, 335.80it/s]                        


Saving model to checkpoint... (Validate loss: 30.03311 --> 29.98724)
Train loss: 31.88461
Validate loss: 29.98724

Epoch 15/200
-------------------------------



[Train loss: 30.82298]: : 1688it [00:11, 142.98it/s]                        
[Test loss: 31.36378]: : 1688it [00:04, 360.22it/s]                        



Saving model to checkpoint... (Validate loss: 29.98724 --> 29.76610)
Train loss: 31.78623
Validate loss: 29.76610

Epoch 16/200
-------------------------------


[Train loss: 31.52369]: : 1688it [00:11, 144.46it/s]                        
[Test loss: 28.15342]: : 1688it [00:04, 350.06it/s]                        



Saving model to checkpoint... (Validate loss: 29.76610 --> 29.51611)
Train loss: 31.71532
Validate loss: 29.51611

Epoch 17/200
-------------------------------


[Train loss: 32.79802]: : 1688it [00:11, 151.50it/s]                        
[Test loss: 29.42417]: : 1688it [00:05, 319.97it/s]                        


Early Stopping counter: 1 out of 3
Train loss: 31.61036
Validate loss: 29.66283

Epoch 18/200
-------------------------------



[Train loss: 29.94351]: : 1688it [00:16, 102.73it/s]                        
[Test loss: 28.19838]: : 1688it [00:09, 175.85it/s]                        


Early Stopping counter: 2 out of 3
Train loss: 31.51767
Validate loss: 29.85725

Epoch 19/200
-------------------------------



[Train loss: 31.46125]: : 1688it [00:18, 91.20it/s]                         
[Test loss: 29.86788]: : 1688it [00:09, 186.51it/s]                        


Saving model to checkpoint... (Validate loss: 29.51611 --> 29.46683)
Train loss: 31.45978
Validate loss: 29.46683

Epoch 20/200
-------------------------------



[Train loss: 29.23955]: : 1688it [00:18, 89.44it/s]                         
[Test loss: 26.99115]: : 1688it [00:08, 208.26it/s]                        


Early Stopping counter: 1 out of 3
Train loss: 31.37956
Validate loss: 29.63823

Epoch 21/200
-------------------------------



[Train loss: 34.49886]: : 1688it [00:18, 89.74it/s]                         
[Test loss: 29.63166]: : 1688it [00:09, 183.33it/s]                        


Saving model to checkpoint... (Validate loss: 29.46683 --> 29.41381)
Train loss: 31.31680
Validate loss: 29.41381

Epoch 22/200
-------------------------------



[Train loss: 28.39809]: : 1688it [00:18, 90.51it/s]                         
[Test loss: 33.11047]: : 1688it [00:08, 192.22it/s]                        


Early Stopping counter: 1 out of 3
Train loss: 31.21849
Validate loss: 29.54057

Epoch 23/200
-------------------------------



[Train loss: 32.41564]: : 1688it [00:18, 92.39it/s]                         
[Test loss: 32.28059]: : 1688it [00:08, 202.71it/s]                        


Early Stopping counter: 2 out of 3
Train loss: 31.19972
Validate loss: 30.02030

Epoch 24/200
-------------------------------



[Train loss: 25.75775]: : 1688it [00:18, 90.18it/s]                         
[Test loss: 29.23729]: : 1688it [00:08, 202.58it/s]                        


Early Stopping counter: 3 out of 3
Train loss: 31.14254
Validate loss: 29.49472

Early stopping now...
Training complete! Elapsed time: 479.6s





In [13]:
trainer.test()

[Validate loss: 29.32512]: : 1688it [00:05, 331.55it/s]                        


Test loss: 29.45660 

Test complete! Elapsed time: 5.1s





## Load model from disk

In [14]:
# Load from disk
model = tracker.load_checkpoint()
model.eval()

# Relocate tensors to cpu
device = 'cpu'
trainer.device = device
model = model.to(device)