<a href="https://colab.research.google.com/github/Yyzhang2000/mini-MoE/blob/main/ng/01_glow_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from matplotlib import pyplot as plt
from tqdm import tqdm

In [3]:
# Seed PyTorch's global RNG so that the random draws made later in the notebook
# (Jitter's torch.rand_like noise, DataLoader shuffling) are reproducible on a
# fresh Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

batch_size = 128     # images per training batch
data_dir = "./data"  # directory torchvision downloads / caches CIFAR-10 into

cpu


In [7]:
class Jitter:
    """Add uniform noise in [0, scale) to every element of a tensor.

    Used as the dequantization step of the input pipeline: the noise is drawn
    with ``torch.rand_like`` so it matches the input's shape, dtype and device.

    Args:
        scale: width of the uniform noise interval (default 1/256).
    """

    def __init__(self, scale = 1.0 / 256):
        self.scale = scale

    def __call__(self, x):
        # rand_like samples U[0, 1); stretching by `scale` gives U[0, scale).
        noise = torch.rand_like(x) * self.scale
        return x + noise

class Scale:
    """Multiply the input by a constant factor.

    Args:
        scale: multiplicative factor applied on each call (default 255/256).
    """

    def __init__(self, scale = 255.0 / 256.0):
        self.scale = scale

    def __call__(self, x):
        factor = self.scale
        return x * factor

In [8]:
# Dequantization pipeline:
#   * ToTensor turns each 8-bit image into floats in [0, 1] (pixel k -> k/255),
#   * Scale(255/256) maps that onto k/256,
#   * Jitter(1/256) adds U[0, 1/256) noise, so each value lies in [k/256, (k+1)/256).
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        Scale(255. / 256.),
        Jitter(1 / 256.),
    ]
)

train_data = torchvision.datasets.CIFAR10(
    data_dir,
    train=True,
    download=True,
    transform=transform,
)

# drop_last=True discards the final incomplete batch, so every batch holds
# exactly `batch_size` images.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

100%|██████████| 170M/170M [00:03<00:00, 43.4MB/s]


In [9]:
class Flow(nn.Module):
    """Abstract base for a flow layer.

    Subclasses override ``forward`` and, when they have a closed-form inverse,
    ``inverse``. The base implementations only raise NotImplementedError.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z):
        """Transform ``z``; must be provided by a subclass."""
        raise NotImplementedError("Forward pass has not been implemented.")

    def inverse(self, z):
        """Invert the transform on ``z``; raises unless a subclass overrides it."""
        raise NotImplementedError("This flow has no algebraic inverse.")

class GlowBlock(nn.Module):
    """One Glow step: affine coupling, then (if channels > 1) an invertible
    1x1 convolution, then ActNorm.

    The sub-flows are kept in ``self.flows`` in that order. ``forward`` runs
    them first-to-last and ``inverse`` last-to-first; both return the
    transformed tensor together with the sum of the log-determinants reported
    by each sub-flow (one entry per batch element).

    NOTE(review): ``nets``, ``AffineCouplingBlock``, ``Invertible1x1Conv`` and
    ``ActNorm`` are not defined or imported in this cell; they must be in scope
    before a GlowBlock is constructed — confirm they are provided elsewhere.

    Args:
        channels: number of channels of the input tensor.
        hidden_channels: width of the two hidden layers of the coupling network.
        scale: if True the coupling network emits 2 values per transformed
            channel (presumably shift and scale), otherwise 1.
        scale_map: forwarded to AffineCouplingBlock.
        split_mode: 'channel', 'channel_inv', or any string containing
            'checkerboard'; decides the coupling network's in/out channel counts.
        leaky: forwarded to nets.ConvNet2d (presumably a LeakyReLU slope).
        init_zeros: forwarded to nets.ConvNet2d.
        use_lu: forwarded to Invertible1x1Conv.
        net_actnorm: forwarded to nets.ConvNet2d as ``actnorm``.

    Raises:
        NotImplementedError: if ``split_mode`` is not one of the modes above.
    """

    def __init__(
            self,
            channels,
            hidden_channels,
            scale = True,
            scale_map = 'sigmoid',
            split_mode = 'channel',
            leaky = 0.0,
            init_zeros = True,
            use_lu = True,
            net_actnorm = False
    ):
        super().__init__()
        self.flows = nn.ModuleList([])

        # Channel counts the coupling network reads (net_in) and the number of
        # channels it produces parameters for (net_out).
        params_per_channel = 2 if scale else 1
        lower_half = channels // 2
        upper_half = (channels + 1) // 2
        if split_mode == "channel":
            net_in, net_out = upper_half, lower_half
        elif split_mode == "channel_inv":
            net_in, net_out = lower_half, upper_half
        elif "checkerboard" in split_mode:
            net_in, net_out = channels, channels
        else:
            raise NotImplementedError("Mode " + split_mode + " is not implemented.")
        layer_channels = (
            net_in,
            hidden_channels,
            hidden_channels,
            params_per_channel * net_out,
        )

        # (3, 1, 3): one kernel size per conv layer of the coupling network.
        param_map = nets.ConvNet2d(
            layer_channels, (3, 1, 3), leaky, init_zeros, actnorm=net_actnorm
        )
        self.flows.append(AffineCouplingBlock(param_map, scale, scale_map, split_mode))

        # Channel mixing is skipped for single-channel inputs.
        if channels > 1:
            self.flows.append(Invertible1x1Conv(channels, use_lu))

        self.flows.append(ActNorm((channels, 1, 1)))

    def forward(self, z):
        """Apply every sub-flow in order; return (output, summed log-det per sample)."""
        total_log_det = z.new_zeros((z.shape[0],))
        for layer in self.flows:
            z, step_log_det = layer(z)
            total_log_det += step_log_det
        return z, total_log_det

    def inverse(self, z):
        """Undo ``forward`` by inverting the sub-flows last-to-first; return
        (reconstruction, summed log-det per sample)."""
        total_log_det = z.new_zeros((z.shape[0],))
        for layer in reversed(self.flows):
            z, step_log_det = layer.inverse(z)
            total_log_det += step_log_det
        return z, total_log_det