In [22]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import center_crop, resize, pad


In [23]:
class reslayer(nn.Module):
    def __init__(self, n_channels):
        super(reslayer, self).__init__()
        self.conv1 = nn.Conv2d(n_channels, n_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(n_channels, n_channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

In [24]:
class Snet(nn.Module):
    def __init__(self, in_channels, mid_channels, n_layers):
        super(Snet, self).__init__()
        in_layer = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 5, padding=2),
            nn.ReLU()
        )
        out_layer = nn.Sequential(
            nn.Conv2d(mid_channels, in_channels, 5, padding=2),
            nn.ReLU()
        )
        self.net = nn.Sequential(in_layer, *[reslayer(mid_channels) for _ in range(n_layers)], out_layer)

    def forward(self, x):
        return self.net(x)
    
class Tnet(nn.Module):
    def __init__(self, in_channels, mid_channels, n_layers):
        super(Tnet, self).__init__()
        in_layer = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 5, padding=2),
            nn.ReLU()
        )
        out_layer = nn.Sequential(
            nn.Conv2d(mid_channels, in_channels, 5, padding=2),
            nn.Tanh()
        )
        self.net = nn.Sequential(in_layer, *[reslayer(mid_channels) for _ in range(n_layers)], out_layer)

    def forward(self, x):
        return self.net(x)

In [34]:
class CouplingLayer(nn.Module):
    def __init__(self, mask_type='checkerboard', pattern=0, input_shape=(1, 3, 256, 256), mid_channels=64, n_layers=8):
        super(CouplingLayer, self).__init__()
        self.mask_type = mask_type
        self.input_shape = input_shape
        self.pattern = pattern
        self.create_mask()
        self.snet = Snet(in_channels=input_shape[1], mid_channels=mid_channels, n_layers=n_layers)
        self.tnet = Tnet(in_channels=input_shape[1], mid_channels=mid_channels, n_layers=n_layers)

    def create_mask(self):
        self.mask = torch.zeros(self.input_shape)
        if self.mask_type == 'checkerboard':
            if self.pattern == 0:
                self.mask[:, :, ::2, ::2] = 1.0
                self.mask[:, :, 1::2, 1::2] = 1.0
            else:
                self.mask[:, :, 1::2, ::2] = 1.0
                self.mask[:, :, ::2, 1::2] = 1.0
        elif self.mask_type == 'channel':
            if self.pattern == 0:
                self.mask[:, :self.input_shape[1]//2, :, :] = 1.0
            else:
                self.mask[:, self.input_shape[1]//2:, :, :] = 1.0
        else:
            raise ValueError('Invalid mask type')

    def forward(self, data):
        _, _, reverse = data
        if reverse == False:
            x, logdet, _ = data
            z = self.mask * x + (1 - self.mask) * (x * torch.exp(self.snet(self.mask * x)) + self.tnet(self.mask * x))
            logdet = torch.exp(self.snet(self.mask * x)).sum(dim=1).sum(dim=1).sum(dim=1)
            return z, logdet
        else:
            z, logdet, _ = data
            x = self.mask * z + (1 - self.mask) * (z - self.tnet(self.mask * z)) * torch.exp(-self.snet(self.mask * z))
            logdet = -torch.exp(self.snet(self.mask * z)).sum(dim=1).sum(dim=1).sum(dim=1)
            return x, logdet

In [26]:
class realNVP(nn.Module):
    def __init__(self, n_coupling_layers=3, input_shape=(1, 3, 256, 256), mid_channels=64, n_res_layers=8):
        super(realNVP, self).__init__()
        self.checker_in = nn.Sequential(*[CouplingLayer(input_shape=input_shape, mid_channels=mid_channels, n_layers=n_res_layers, mask_type='checkerboard', pattern=i%2==0) for i in range(n_coupling_layers)])
        self.channel_in = nn.Sequential(*[CouplingLayer(input_shape=(input_shape[0], input_shape[1]*4, input_shape[2]//2, input_shape[3]//2), mid_channels=mid_channels, n_layers=n_res_layers, mask_type='channel', pattern=i%2==1) for i in range(n_coupling_layers)])
        self.checker_out = nn.Sequential(*[CouplingLayer(input_shape=input_shape, mid_channels=mid_channels, n_layers=n_res_layers, mask_type='checkerboard', pattern=i%2==1) for i in range(n_coupling_layers)])

    def forward(self, x):
        logdet = torch.zeros(x.shape[0])
        x, logdet = self.checker_in((x, logdet, False))
        x, logdet = self.channel_in((x.reshape(x.shape[0], x.shape[1]*4, x.shape[2]//2, x.shape[3]//2), logdet, False))
        x, logdet = self.checker_out((x.reshape(x.shape[0], x.shape[1]//4, x.shape[2]*2, x.shape[3]*2), logdet, False))
        return x, logdet
    
    def reverse(self, z):
        logdet = torch.zeros(z.shape[0])
        z, logdet = self.checker_in((z, logdet, True))
        z, logdet = self.channel_in

In [36]:
# m = CouplingLayer(mask_type='checkerboard', pattern=0, input_shape=(1, 3, 256, 256))
m = CouplingLayer()
with torch.no_grad():
    x = torch.randn((3, 3, 256, 256))
    z, logdet = m((x, torch.zeros_like(x), False))

    x_hat = m((z, torch.zeros_like(x), True))
    print(z.shape, logdet)

torch.Size([3, 3, 256, 256]) tensor([197907.9375, 197907.9375, 197907.9375])
