In [None]:
!pip install lightning --quiet
!pip install -U huggingface-hub --quiet
!pip install datasets --quiet

In [None]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils as utils
from torch.utils.data import DataLoader, random_split

import torchvision
import torchvision.transforms as T

import datasets
from datasets import load_dataset

import lightning as L
from lightning import seed_everything
from lightning.pytorch.callbacks import TQDMProgressBar

import matplotlib as  mpl
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Preprocessing applied to every image: reduce it to one grayscale channel.
transform = T.Compose([T.Grayscale(num_output_channels=1)])

In [None]:
# Load Fashion-MNIST, keep a single class, apply the image transform and drop
# the now-constant label column.
# Label 7 is presumably the "Sneaker" class -- verify against the dataset card.
dataset_dict = (
    load_dataset('fashion_mnist')
    .filter(lambda example: example['label'] == 7)
    .map(lambda example: {'image': transform(example['image'])})
    .remove_columns(['label'])
)

# Have indexing return torch tensors (in-place change of the output format).
dataset_dict.set_format('torch')

In [None]:
# Use only the 'train' split of the filtered dataset; it feeds the DataLoader below.
dataset = dataset_dict['train']

In [None]:
# Sanity check: display the shape of the first image as returned by the dataset.
dataset[0]['image'].shape

In [None]:
class Coupling(nn.Module):
    """Scale (`s`) and translation (`t`) networks of one affine coupling layer.

    Both networks share the same architecture: a stack of 3x3 convolutions
    with padding 1 (so the spatial size is preserved), ReLU activations in
    between, and a final Tanh that bounds the output to [-1, 1]. The output
    has the same number of channels as the input.

    Args:
        in_channels: Channels of the input (and of the `s`/`t` outputs).
        hidden_channels: Channels of every hidden convolution.
        num_hidden_layers: Number of hidden -> hidden convolutions placed
            between the input convolution and the output convolution.

    The defaults (1, 64, 5) reproduce the original fixed architecture,
    including the layer indices used in `state_dict` keys.

    Raises:
        ValueError: If `num_hidden_layers` is negative.
    """

    def __init__(self, in_channels=1, hidden_channels=64, num_hidden_layers=5):
        super().__init__()
        if num_hidden_layers < 0:
            raise ValueError("num_hidden_layers must be non-negative")

        # s_net is built before t_net so that parameter initialization draws
        # from the RNG in the same order as the hand-written version did.
        self.s_net = self._build_net(in_channels, hidden_channels, num_hidden_layers)
        self.t_net = self._build_net(in_channels, hidden_channels, num_hidden_layers)

    @staticmethod
    def _build_net(in_channels, hidden_channels, num_hidden_layers):
        """Build one conv/ReLU stack ending in a conv back to `in_channels` + Tanh."""
        widths = [in_channels] + [hidden_channels] * (num_hidden_layers + 1)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Conv2d(hidden_channels, in_channels, kernel_size=(3, 3), padding=1))
        layers.append(nn.Tanh())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the pair `(s, t)` computed from `x` by the two networks."""
        s = self.s_net(x)
        t = self.t_net(x)
        return s, t

In [None]:
class RealNVP(L.LightningModule):
    """Normalizing flow made of masked affine coupling layers.

    Each coupling layer conditions on one half of the image (top or bottom,
    alternating from layer to layer) and applies an affine map
    `x * exp(s) + t` to the other half. The model is trained by maximum
    likelihood under a standard normal prior on the latent `z`.

    Args:
        num_couplings: Number of coupling layers in the flow.
    """

    def __init__(self, num_couplings):
        super().__init__()
        self.num_couplings = num_couplings
        self.coupling = nn.ModuleList([Coupling() for _ in range(num_couplings)])
        # Standard normal prior over the latent variables.
        self.distribution = torch.distributions.Normal(0, 1)

    @staticmethod
    def _half_mask(x):
        """Return an (H, W) mask that is 0 on the top half and 1 on the bottom half.

        H and W are taken from the last two dimensions of `x`; for 28x28
        inputs this equals the previously hard-coded 14/14 split.
        """
        height, width = x.shape[-2:]
        mask = torch.zeros(height, width, device=x.device, dtype=x.dtype)
        mask[height // 2:] = 1
        return mask

    def forward(self, x, direction=1):
        """Run the flow in the data -> latent or latent -> data direction.

        Args:
            x: Batch of inputs with the batch dimension first and the two
                spatial dimensions last.
            direction: `1` maps data to latent space. Any other value (the
                notebook uses `-1`) applies the inverse map, visiting the
                coupling layers in reverse order.

        Returns:
            A tuple `(x, log_det)` where `x` is the transformed batch and
            `log_det` has shape `(batch,)` and holds, per sample, the log
            absolute determinant of the Jacobian of the applied map
            (negated sums of `s` for the inverse direction).
        """
        base_mask = self._half_mask(x)
        log_det = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)

        layers = list(enumerate(self.coupling))
        if direction != 1:
            layers.reverse()

        for index, coupling in layers:
            # The mask is tied to the layer index, not to the iteration
            # order, so the inverse pass undoes every layer with exactly the
            # mask that layer used in the forward pass.
            mask = base_mask if index % 2 == 0 else 1 - base_mask
            s, t = coupling(x * mask)
            # Only the half that was NOT used for conditioning is transformed.
            free = 1 - mask
            s = s * free
            t = t * free
            # Sum per sample so the log-det is on the same scale as the
            # per-sample log-probability it is combined with.
            layer_log_det = s.flatten(start_dim=1).sum(dim=1)
            if direction == 1:
                x = x * torch.exp(s) + t
                log_det = log_det + layer_log_det
            else:
                x = (x - t) * torch.exp(-s)
                log_det = log_det - layer_log_det
        return x, log_det

    def training_step(self, batch, batch_idx):
        """Compute and log the mean negative log-likelihood of one batch."""
        # Cast so the input dtype matches the float32 conv weights; this is a
        # no-op if the dataset already yields float32 tensors.
        x = batch['image'].float()
        z, log_det = self(x)
        log_prob = self.distribution.log_prob(z).flatten(start_dim=1).sum(dim=1)
        # Change of variables: log p(x) = log p(z) + log|det dz/dx|.
        loss = -(log_prob + log_det).mean()
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a learning rate of 2e-4."""
        return Adam(self.parameters(), lr=2e-4)

In [None]:
# Serve the training split in shuffled mini-batches of 1024 images.
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so that weight initialization,
# batch shuffling and the sampling further down are repeatable on a fresh
# kernel run.
seed_everything(42)

trainer = L.Trainer(
    max_epochs=50,
    callbacks=[TQDMProgressBar(refresh_rate=2)],
)
model = RealNVP(num_couplings=8)

# Train the flow on the shuffled mini-batches.
trainer.fit(model, dataloader)

In [None]:
# Draw one latent from the standard normal prior and push it through the
# inverse flow (direction=-1) to obtain a generated image.
model.eval()
with torch.no_grad():
    sample, _ = model(torch.randn(1, 1, 28, 28).to(model.device), direction=-1)

# Bring the result back to the CPU for visualization.
sample = sample.cpu()

In [None]:
# Render the generated sample (batch item 0, channel 0) as a PIL image for inline display.
# NOTE(review): to_pil_image appears to rescale float tensors as if their values
# were in [0, 1] -- confirm the sample's value range before trusting the rendering.
torchvision.transforms.functional.to_pil_image(sample[0, 0])