# IAF
- Inverse Autoregressive Flow
- 내재 변수의 근사 사후확률 분포 $q_\phi(z|x)$ 를 자기회귀적으로 생성하여 고차원 내재공간을 효율적으로 학습하는 VAE 파생형 모델이다.
- 이제부터 본문은 편의상 영어로, 수식은 LaTeX 문법으로 작성할 것이다.

## Citation
- [Improving Variational Inference with Inverse Autoregressive Flow (by Diederik P. Kingma, Ilya Sutskever, et al.)](https://arxiv.org/abs/1606.04934)

# Training
## Hyperparameters

In [1]:
# Hyperparameters, declared once so the cells below can share them.
Z_DIM = 8  # not referenced in the visible code; presumably the latent size -- TODO confirm
N_BATCH = 32  # batch_size handed to every DataLoader
N_EPOCH = 200  # upper bound on the number of training epochs
LEARNING_RATE = 4e-4  # learning rate given to the Adam optimizer
R_LOSS_FACTOR = 1_000  # not referenced in the visible code; presumably a reconstruction-loss weight -- TODO confirm
N_THREAD = 2  # num_workers handed to every DataLoader

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import random
import numpy as np

In [3]:
# For reproducible experiments
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# torch.use_deterministic_algorithms(True)

# Get a working device
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


## Load Fashion MNIST

In [4]:
# Fashion-MNIST train and test splits, stored under ./data (download=True fetches
# the files when needed); every image goes through the ToTensor transform.
train_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

# Hold out 10% of the training split for validation. No generator is passed, so
# the split is drawn from torch's default RNG.
train_data, val_data = random_split(train_data, [0.9, 0.1])

train_data_loader = DataLoader(
    train_data,
    batch_size=N_BATCH,
    shuffle=True,
    num_workers=N_THREAD
)
# Evaluation loaders keep a fixed order: the last batch is partial and the
# trainer averages per-batch means, so shuffling would make the reported
# validation/test loss vary from pass to pass for an unchanged model.
val_data_loader = DataLoader(
    val_data,
    batch_size=N_BATCH,
    shuffle=False,
    num_workers=N_THREAD
)
test_data_loader = DataLoader(
    test_data,
    batch_size=N_BATCH,
    shuffle=False,
    num_workers=N_THREAD
)

# Report the number of samples in each split.
print()
print("train_data:", len(train_data_loader.dataset))
print("val_data:", len(val_data_loader.dataset))
print("test_data:", len(test_data_loader.dataset))

100%|██████████| 26.4M/26.4M [00:36<00:00, 730kB/s] 
100%|██████████| 29.5k/29.5k [00:00<00:00, 92.8kB/s]
100%|██████████| 4.42M/4.42M [00:10<00:00, 416kB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 5.15MB/s]


train_data: 54000
val_data: 6000
test_data: 10000





## Helper classes
- **Model**: model abstraction
- **Tracker**: save/load model and early stopping
- **Trainer**: handles the model training, validation, and testing

In [None]:
import time
import os
from tqdm import tqdm
from abc import ABC, abstractmethod


class Model(nn.Module, ABC):
    """Abstract base for the notebook's models: an ``nn.Module`` that carries a name.

    Subclasses must implement ``forward`` (returning a prediction/loss pair)
    and ``loss_fn``.
    """

    def __init__(self, name):
        """Initialise the underlying module and remember ``name``."""
        super().__init__()
        self.name = name

    @abstractmethod
    def forward(self, X) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the ``(prediction, loss)`` pair for the input ``X``."""

    @abstractmethod
    def loss_fn(self, *args) -> torch.Tensor:
        """Compute the loss tensor from ``args``."""

    def predict(self, X) -> torch.Tensor:
        """Run ``forward`` with gradient tracking disabled and return only the prediction."""
        with torch.no_grad():
            prediction, _loss = self.forward(X)
            return prediction


class Tracker:
    """Save/load model checkpoints and decide when training should stop early.

    A validation loss counts as an improvement when its negation is at least
    ``delta`` above the best negated loss seen so far. Each improvement writes
    a per-epoch checkpoint; ``patience`` consecutive non-improvements set the
    early-stop flag.
    """

    def __init__(self, model: Model, save_path="./checkpoints", patience=3, delta=0.001, verbose=True):
        """Create the tracker and make sure the checkpoint directory exists.

        Args:
            model: Model whose ``state_dict`` is saved and restored; its ``name``
                attribute is used to build checkpoint file names.
            save_path: Directory holding the checkpoint files (created if missing).
            patience: Number of consecutive non-improving calls to ``early_stop``
                after which stopping is requested.
            delta: Minimum decrease in validation loss that counts as an improvement.
            verbose: When true, print a message each time a checkpoint is saved.
        """
        self._early_stop = False
        self._model = model
        self._patience = patience
        self._verbose = verbose
        self._counter = 0  # consecutive calls without improvement
        self._best_score = -np.inf  # best negated validation loss so far
        self._delta = delta
        self._save_path = save_path
        self._val_loss_min = np.inf  # loss at the last saved checkpoint (used in the log message)
        os.makedirs(save_path, exist_ok=True)

    def early_stop(self, val_loss, epoch):
        """Record one validation result and report whether training should stop.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            epoch: Epoch number, used in the checkpoint file name on improvement.

        Returns:
            True once ``patience`` consecutive calls brought no improvement
            (the flag is never reset afterwards), otherwise False.
        """
        score = -val_loss
        # Phrased as ">=" and then negated on purpose: a NaN loss fails every
        # comparison, so it lands in the "no improvement" branch. With the
        # previous "<" test a NaN was treated as an improvement, got
        # checkpointed, and became the best score -- after which no comparison
        # could succeed and early stopping never triggered.
        improved = score >= self._best_score + self._delta
        if not improved:
            self._counter += 1
            print(f'\nEarly Stopping counter: {self._counter} out of {self._patience}')
            if self._counter >= self._patience:
                self._early_stop = True
        else:
            self._best_score = score
            if self._verbose:
                print(f'\nSaving model to checkpoint... (Validate loss: {self._val_loss_min:.5f} --> {val_loss:.5f})')
            self.save_checkpoint(epoch=epoch)
            self._val_loss_min = val_loss
            self._counter = 0
        return self._early_stop

    def save_checkpoint(self, epoch=None):
        """Write the model's ``state_dict`` into the checkpoint directory.

        Args:
            epoch: When given, the file is ``<name>.<epoch>.pt``; otherwise ``<name>.pt``.
        """
        model_name = self._model.name
        if epoch is None:
            filename = f"{model_name}.pt"
        else:
            filename = f"{model_name}.{epoch}.pt"
        fp = os.path.join(self._save_path, filename)
        torch.save(self._model.state_dict(), fp)

    def load_checkpoint(self, ckpt_name=None, map_location=None) -> Model:
        """Load a saved ``state_dict`` into the tracked model and return the model.

        Args:
            ckpt_name: File name inside the checkpoint directory; defaults to ``<name>.pt``.
            map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to read a
                checkpoint on a machine without the device it was saved from.
                ``None`` keeps ``torch.load``'s default behaviour.

        Returns:
            The tracked model, with the loaded parameters applied.
        """
        if ckpt_name is None:
            ckpt_name = f"{self._model.name}.pt"
        ckpt = torch.load(os.path.join(self._save_path, ckpt_name), map_location=map_location)
        self._model.load_state_dict(ckpt)
        return self._model


class Trainer:
    """Drive training, validation and testing of a ``Model``.

    The optimizer is Adam over ``model.parameters()`` using the module-level
    ``LEARNING_RATE``; the epoch budget is the module-level ``N_EPOCH``.
    The trainer moves input batches to ``device`` but does not move the model.
    """

    def __init__(self, model: Model, tracker: Tracker, train_data_loader, val_data_loader, test_data_loader,
                 device):
        """Store the collaborators and build the optimizer.

        Args:
            model: Network to optimise; calling it must return ``(prediction, loss)``.
            tracker: Provides checkpointing and the early-stopping decision.
            train_data_loader: Batches iterated by ``_train``.
            val_data_loader: Batches iterated by ``_validate()``.
            test_data_loader: Batches iterated by ``_validate(test=True)``.
            device: Target passed to ``Tensor.to`` for every input batch.
        """
        self.model = model
        self.device = device
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.test_data_loader = test_data_loader
        self.tracker = tracker
        self.optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    def fit(self):
        """Train for up to ``N_EPOCH`` epochs, stopping early when the tracker says so.

        After the loop the tracker saves one more checkpoint without an epoch
        number; it holds the weights as they are after the last epoch that ran.
        """
        start = time.time()
        for epoch in range(1, N_EPOCH + 1):
            print(f"Epoch {epoch}/{N_EPOCH}\n-------------------------------")
            train_loss = self._train(epoch)
            val_loss = self._validate()
            early_stop = self.tracker.early_stop(val_loss, epoch)
            print(f"Train loss: {train_loss:.5f}")
            print(f"Validate loss: {val_loss:.5f}")
            print("===============================\n")
            if early_stop:
                print("Early stopping now...")
                break
        self.tracker.save_checkpoint()
        print(f"Training complete! Elapsed time: {time.time() - start:.1f}s")

    def test(self):
        """Evaluate the model on the test loader and print the loss and elapsed time."""
        start = time.time()
        loss = self._validate(test=True)
        print(f"\nTest loss: {loss:.5f} \n")
        print(f"Test complete! Elapsed time: {time.time() - start:.1f}s")

    def _train(self, epoch) -> float:
        """Run one optimisation pass over the training loader.

        Args:
            epoch: Current epoch number (not used in the body; kept for the
                existing call signature).

        Returns:
            Mean of the per-batch losses.
        """
        dataloader = self.train_data_loader
        # len(dataloader) is the actual number of batches, including a trailing
        # partial batch that int(size / batch_size) would leave out.
        progress = tqdm(enumerate(dataloader), total=len(dataloader))
        loss_sum = 0
        self.model.train()
        for batch, (X, _) in progress:
            X = X.to(self.device)
            self.optimizer.zero_grad()
            _, loss = self.model(X)
            loss = loss.mean()
            loss.backward()
            self.optimizer.step()
            loss = loss.item()
            loss_sum += loss
            if batch % 100 == 0:
                progress.set_description(f"[Train loss: {loss:.5f}]")
        avg_loss = loss_sum / len(dataloader)
        return avg_loss

    def _validate(self, test=False) -> float:
        """Compute the average loss without updating the model.

        Args:
            test: When true, iterate the test loader; otherwise the validation loader.

        Returns:
            Mean of the per-batch losses.
        """
        # Evaluate on held-out data. Previously this always iterated the
        # training loader, so early stopping and the test report were both
        # based on training loss.
        dataloader = self.test_data_loader if test else self.val_data_loader
        # The progress-bar label follows the split being evaluated (the two
        # labels used to be swapped).
        label = "Test" if test else "Validate"
        progress = tqdm(enumerate(dataloader), total=len(dataloader))
        self.model.eval()
        loss_sum = 0
        with torch.no_grad():
            for batch, (X, _) in progress:
                X = X.to(self.device)
                _, loss = self.model(X)
                loss = loss.mean().item()
                loss_sum += loss
                if batch % 100 == 0:
                    progress.set_description(f"[{label} loss: {loss:.5f}]")
        avg_loss = loss_sum / len(dataloader)
        return avg_loss


## Model

TODO

In [None]:
class InverseAutoregressiveFlow(Model):
    """Placeholder for the IAF model; the network and its loss are not implemented yet."""

    def __init__(self):
        """Register the model under the name ``"IAF"``."""
        super().__init__(name="IAF")
        # TODO: define the model's layers.

    def forward(self, x):
        """Not implemented yet.

        Raises:
            NotImplementedError: Always. Returning ``None`` here would only
                surface later as an unrelated unpacking error in the caller,
                which expects a ``(prediction, loss)`` pair.
        """
        raise NotImplementedError("InverseAutoregressiveFlow.forward is not implemented yet")

    def loss_fn(self, x):
        """Not implemented yet.

        Raises:
            NotImplementedError: Always, until the loss is defined.
        """
        raise NotImplementedError("InverseAutoregressiveFlow.loss_fn is not implemented yet")


# Build the model, move it to the selected device, then print its module summary.
model = InverseAutoregressiveFlow()
model = model.to(device)

print(model)

In [None]:
# Wire the checkpoint tracker and the trainer together, then start training.
tracker = Tracker(model=model)
trainer = Trainer(
    model=model,
    tracker=tracker,
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    test_data_loader=test_data_loader,
    device=device,
)
trainer.fit()

In [None]:
# Run the trainer's evaluation pass; it prints the resulting loss and the elapsed time.
trainer.test()

## Load model from disk

In [None]:
# Restore the saved weights into the tracked model and switch it to evaluation mode.
model = tracker.load_checkpoint().eval()

# From here on, keep both the trainer's input batches and the model on the CPU.
device = "cpu"
model = model.to(device)
trainer.device = device