In [0]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import distributions
import torchvision
import torchvision.transforms as transforms

import os

import cv2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (used for weight init, dequantization noise, latent sampling
# and DataLoader shuffling below) so Restart & Run All gives repeatable results.
SEED = 0
torch.manual_seed(SEED)

# Utility functions

In [0]:
def plot_imgs(samples, title=None, n_rows=2, n_cols=5):
    """Show the first ``n_rows * n_cols`` images of ``samples`` in a borderless grid.

    Args:
        samples: indexable collection of 2-D images (e.g. an ``(N, H, W)`` array);
            cells beyond ``len(samples)`` are left blank.
        title: optional figure title.
        n_rows, n_cols: grid layout; the defaults reproduce the original 2 x 5 grid.
    """
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    # Leave headroom for the title only when there is one; otherwise the images
    # fill the whole figure.
    top = 0.9 if title is not None else 1
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=top)
    if title is not None:
        # Set the font size on this title only instead of mutating global rcParams.
        fig.suptitle(title, fontsize=20)

    # axes.flat walks the grid row by row, matching the original i*n_cols + j order.
    for k, ax in enumerate(axes.flat):
        if k < len(samples):
            ax.imshow(samples[k], cmap='gray')
        ax.axis('off')
    plt.show()

In [0]:
def load_data(data_dir, dataset):
    """Load one of the notebook's datasets.

    Args:
        data_dir: directory used by 'MNIST' (download target) and 'MLINPL'
            (location of ``MLinPL.png``); ignored for 'MOONS'.
        dataset: one of 'MNIST', 'MOONS', 'MLINPL'.

    Returns:
        'MNIST': the torchvision MNIST training set.
        'MOONS': tensor of shape (N, 2) with points on two half-annuli.
        'MLINPL': tensor of shape (N, 2) holding the (column, -row) coordinates of
            the exactly-black pixels of the 20%-downscaled image, standardised
            per axis.

    Raises:
        ValueError: if ``dataset`` is not one of the names above.
        FileNotFoundError: if ``MLinPL.png`` cannot be read from ``data_dir``.
    """
    if dataset == 'MNIST':
        os.makedirs(data_dir, exist_ok=True)
        return torchvision.datasets.MNIST(root=data_dir, train=True,
                                          transform=transforms.ToTensor(), download=True)

    if dataset == 'MOONS':
        # Upper half-annulus (radius 0.9-1.1) centred at x = +0.5 and lower
        # half-annulus centred at x = -0.5; t is the outer loop, r the inner one.
        moon1 = [(0.5 + r*np.cos(t), r*np.sin(t))
                 for t in np.arange(0, np.pi, 0.01) for r in np.arange(0.9, 1.1, 0.01)]
        moon2 = [(-0.5 + r*np.cos(t), r*np.sin(t))
                 for t in np.arange(np.pi, 2*np.pi, 0.01) for r in np.arange(0.9, 1.1, 0.01)]
        return torch.tensor(moon1 + moon2)

    if dataset == 'MLINPL':
        path = os.path.join(data_dir, 'MLinPL.png')
        img = cv2.imread(path, 0)  # flag 0 = cv2.IMREAD_GRAYSCALE
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(path))
        img = img / 255
        n, m = img.shape
        n, m = int(0.2*n), int(0.2*m)
        img = cv2.resize(img, (m, n))  # cv2 expects (width, height)

        # np.nonzero returns indices in row-major order, i.e. the same order as
        # the original "for row: for col:" scan.
        rows, cols = np.nonzero(img == 0)
        points = np.stack([cols, -rows], axis=1)
        points = (points - np.mean(points, axis=0, keepdims=True)) / np.std(points, axis=0, keepdims=True)
        return torch.from_numpy(points)

    raise ValueError("Unknown dataset {!r}; expected 'MNIST', 'MOONS' or 'MLINPL'".format(dataset))

# Model implementation

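Implement the coupling layer of the Real NVP model here. Each layer uses a binary mask to split its input into two parts. The part $x_1$ passes through unchanged. The part $x_2$ is transformed by the scale network $s$ and the translation network $t$, both of which are conditioned on $x_1$. You will need to fill in the bodies of 3 functions:
* `forward` - forward pass of the coupling layer. It also returns $\log |\det J| = \sum_j s(x_1)_j$, summed over the transformed coordinates.
$$\begin{cases}
y_1 =& x_1\\ 
y_2 =& x_2 \odot \exp (s(x_1)) + t(x_1)
\end{cases}$$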
* `inverse` - inversion of the forward pass
$$\begin{cases}
x_1 =& y_1\\ 
x_2 =& (y_2 - t(y_1)) \odot \exp (-s(y_1))
\end{cases}$$
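* `get_mask` - build the checkerboard binary mask that selects the unchanged part $I_1$. Consecutive coupling layers use complementary masks, which swaps the processed part of the latent code:
$$\begin{cases}
y_{I_1} =& y_{I_2}\\ 
y_{I_2} =& y_{I_1}
\end{cases}$$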

In [0]:
class Coupling_layer(nn.Module):
    """Affine coupling layer of Real NVP acting on flattened samples.

    A fixed binary mask splits each flattened sample into the coordinates with
    ``mask == 1`` (x_1, copied through unchanged) and the rest (x_2), which are
    mapped to ``x_2 * exp(s(x_1)) + t(x_1)``. ``self.m`` is the scale network s
    (its final Tanh keeps exp(s) within (1/e, e)); ``self.a`` is the translation
    network t.

    Args:
        device: device the mask is placed on.
        input_dim: number of features of a flattened sample.
        data_dim: per-sample shape, e.g. ``(28, 28)`` or ``(1, 2)``. When it is a
            2-element tuple/list whose product equals ``input_dim`` the mask is a
            true 2-D checkerboard; otherwise it alternates along the flat index.
        n_layers: number of ``nn.Linear`` layers in each of s and t (at least 2
            are always built).
        mask_type: 0 -> the mask starts with 0, anything else -> starts with 1.
        hidden_dim: width of the hidden layers.
    """
    def __init__(self, device, input_dim, data_dim, n_layers, mask_type, hidden_dim=1024):
        super(Coupling_layer, self).__init__()

        self.device = device
        self.mask = self.get_mask(input_dim, mask_type, data_dim)

        # Scale network s(.) -- bounded output through the final Tanh.
        self.m = self._mlp(input_dim, hidden_dim, n_layers, final_tanh=True)
        # Translation network t(.) -- unbounded output.
        self.a = self._mlp(input_dim, hidden_dim, n_layers, final_tanh=False)


    @staticmethod
    def _mlp(input_dim, hidden_dim, n_layers, final_tanh):
        """Build a Linear/LeakyReLU stack input_dim -> hidden_dim ... -> input_dim."""
        layers = [nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(0.1)]
        for _ in range(n_layers-2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.LeakyReLU(0.1))
        layers.append(nn.Linear(hidden_dim, input_dim))
        if final_tanh:
            layers.append(nn.Tanh())
        return nn.Sequential(*layers)


    def forward(self, x):
        """Apply the coupling transform x -> z.

        Args:
            x: batch of samples; flattened internally to ``(batch, input_dim)``.

        Returns:
            ``(z, logdetJ)``: ``z`` has the same shape as ``x``; ``logdetJ`` has
            shape ``(batch,)`` and equals ``sum_j s(x_1)_j``, the log absolute
            Jacobian determinant of this layer.
        """
        flat = x.reshape(x.shape[0], -1)
        keep = self.mask          # 1 on the conditioning coordinates x_1
        change = 1 - keep         # 1 on the transformed coordinates x_2
        x_1 = flat * keep
        # Zero s and t on the kept coordinates so those stay exactly unchanged
        # and contribute nothing to the log-determinant.
        s = self.m(x_1) * change
        t = self.a(x_1) * change
        z = x_1 + change * (flat * torch.exp(s) + t)
        logdetJ = s.sum(dim=1)

        return z.view(x.shape), logdetJ


    def inverse(self, z):
        """Invert ``forward``: z -> x, returned with the same shape as ``z``."""
        flat = z.reshape(z.shape[0], -1)
        keep = self.mask
        change = 1 - keep
        # The kept coordinates are identical in x and z, so s and t can be
        # recomputed exactly from z.
        z_1 = flat * keep
        s = self.m(z_1) * change
        t = self.a(z_1) * change
        x = z_1 + change * (flat - t) * torch.exp(-s)

        return x.view(z.shape)


    def get_mask(self, input_dim, mask_type, data_dim=None):
        """Build a checkerboard binary mask of shape ``(1, input_dim)``.

        The mask starts with 0 when ``mask_type`` is 0 and with 1 otherwise, so
        mask types 0 and 1 are complementary. If ``data_dim`` is a 2-element
        tuple/list describing a ``rows x cols`` grid of ``input_dim`` cells, the
        pattern is a 2-D checkerboard over that grid; otherwise it alternates
        along the flattened index. Also stores the 1-D mask in ``self.mask``.
        """
        idx = torch.arange(input_dim)
        is_grid = (isinstance(data_dim, (tuple, list)) and len(data_dim) == 2
                   and int(data_dim[0]) * int(data_dim[1]) == input_dim)
        if is_grid:
            n_cols = int(data_dim[1])
            # Parity of (row + col) gives a true checkerboard even when the row
            # width is even (plain flat alternation would give column stripes).
            parity = (idx // n_cols + idx % n_cols) % 2
        else:
            parity = idx % 2
        parity = parity.float()
        self.mask = parity if mask_type == 0 else 1 - parity

        return self.mask.view(1,-1).to(self.device)

Here, you will need to implement two functions, responsible for calculating the flow and the inverse flow:
* `flow` - composition of the coupling layers applied to an input $x$: $f_K(f_{K-1}(\ldots f_1(x)))$. Returns $z$ and $\log \det J$, which is the sum of the per-layer log-determinants.
* `inv_flow` - composition of the inverses of the coupling layers, applied in reverse order to an encoding $z$: $f_1^{-1}(\ldots f_K^{-1}(z))$. Returns $x$.

In [0]:
class RealNVP():
    """Real NVP normalizing flow: a stack of ``Coupling_layer``s.

    Consecutive layers alternate ``mask_type`` 0 / 1, i.e. complementary
    checkerboard masks. The layers are kept in a plain Python list (this class is
    not an ``nn.Module``), so it provides its own helpers for weight
    initialisation, parameter collection and train/eval mode.
    """
    def __init__(self, input_dim, data_dim, n_layers, n_couplings, device):
        self.coupling_layers = [
            Coupling_layer(device, input_dim, data_dim, n_layers, i % 2).to(device)
            for i in range(n_couplings)
        ]


    def flow(self, x):
        """Map data to latent space, z = f_K(...f_1(x)).

        Returns:
            ``(z, logdetJ)`` where ``logdetJ`` has shape ``(batch, 1)`` and is the
            sum of the per-layer log-determinants.
        """
        # Accumulate into a tensor (same dtype/device as x) instead of the int 0,
        # so .view works even when there are no layers.
        logdetJ_ac = x.new_zeros(x.shape[0])
        for layer in self.coupling_layers:
            x, logdetJ = layer(x)
            logdetJ_ac = logdetJ_ac + logdetJ.reshape(-1)

        return x, logdetJ_ac.view(-1,1)


    def inv_flow(self, z):
        """Map latent codes back to data space, x = f_1^{-1}(...f_K^{-1}(z))."""
        # Inverses must be applied in the reverse order of the forward pass.
        for layer in reversed(self.coupling_layers):
            z = layer.inverse(z)

        return z


    def init_weights(self):
        """Re-initialise every Linear layer of every coupling layer."""
        for layer in self.coupling_layers:
            layer.apply(self.init_weights_helper)


    def init_weights_helper(self, Layer):
        """Weights ~ N(0, 0.02^2) and zero bias for modules whose class is ``Linear``."""
        if Layer.__class__.__name__ == 'Linear':
            torch.nn.init.normal_(Layer.weight, mean=0, std=0.02)
            if Layer.bias is not None:
                torch.nn.init.constant_(Layer.bias, 0)


    def get_parameters(self):
        """Return all trainable parameters of all coupling layers as one list."""
        return [p for layer in self.coupling_layers for p in layer.parameters()]


    def train_model(self, if_train=True):
        """Put all coupling layers into training (default) or evaluation mode."""
        for layer in self.coupling_layers:
            layer.train(bool(if_train))

In [0]:
def loss_fun(z, prior_z, logdetJ):
    """Negative mean log-likelihood via the change-of-variables formula.

    log p(x) = log p_Z(z) + log|det J|, with z = f(x).

    Args:
        z: latent codes; flattened to ``(batch, -1)`` before evaluating the prior.
        prior_z: ``torch.distributions`` object; its ``log_prob`` is evaluated on
            the CPU copy of ``z``.
        logdetJ: per-sample log|det J| with ``batch`` elements (``(batch,)`` or
            ``(batch, 1)``).

    Returns:
        Scalar tensor on the device of ``logdetJ``.
    """
    z = z.reshape(z.shape[0], -1)
    # The prior in this notebook is built from CPU tensors, so evaluate it on the
    # CPU and move the result to wherever the flow ran (gradients pass through).
    log_pz = prior_z.log_prob(z.cpu()).to(logdetJ.device).view(-1, 1)
    # Reshape so a 1-D logdetJ does not broadcast into a (batch, batch) matrix.
    ll_z = log_pz + logdetJ.reshape(-1, 1)
    return -torch.mean(ll_z)

# Experiments

## MNIST

In [0]:
mnist = load_data(r'datasets/', 'MNIST')

# Keep raw pixel intensities in [0, 255] (as floats). The training loop
# dequantizes with (x + U[0, 1)) / 255, which assumes integer-valued 0-255
# pixels; mean/std-normalised data would be squashed by that 1/255 factor.
# A new variable is used instead of overwriting mnist.data in place.
data = mnist.data.float()
targets = mnist.targets

# Keep only digits 0-4 (vectorised instead of a Python loop over every index).
keep = (targets >= 0) & (targets <= 4)
data = data[keep]
targets = targets[keep]

dataloader = DataLoader(data, batch_size=256, shuffle=True)

In [0]:
# Optimisation settings for the MNIST flow.
n_epochs, l_rate = 1000, 1e-4
n_layers = 6      # Linear layers in each coupling layer's s / t networks
n_couplings = 18  # number of stacked coupling layers

# Per-sample shape and its flattened size.
data_dim = data.shape[1:]
input_dim = data_dim.numel()

# Standard normal prior over the flattened latent space.
prior_z = distributions.MultivariateNormal(
    loc=torch.zeros(input_dim),
    covariance_matrix=torch.eye(input_dim),
)

realnvp = RealNVP(input_dim, data_dim, n_layers, n_couplings, device)
realnvp.init_weights()

optimizer = torch.optim.Adam(realnvp.get_parameters(), lr=l_rate)
# Multiply the learning rate by 0.9 every 20 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)

In [0]:
# Put every coupling layer into training mode.
realnvp.train_model()

for i in range(n_epochs):
    loss_acc = 0
    for j, x in enumerate(dataloader):
        # Dequantization: add U[0, 1) noise, then divide by 255 (assumes the
        # loader yields raw 0-255 intensities -- verify against the data cell).
        x = (x.float() + torch.rand(x.shape)).to(device) / 255

        z, logdetJ = realnvp.flow(x)
        loss = loss_fun(z, prior_z, logdetJ)
        loss_acc += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # StepLR advances once per epoch.
    scheduler.step()

    # Report and draw samples only every other epoch (i = 0, 2, 4, ...).
    if i % 2 != 0:
        continue

    print('Epoch: {}/{} Loss: {:.4f}'.format(i+1, n_epochs, loss_acc / (j+1)))
    with torch.no_grad():
        samples = realnvp.inv_flow(torch.randn(10, 28, 28).to(device))
        plot_imgs(samples.cpu().numpy())

In [0]:
# Switch every coupling layer to evaluation mode before the final samples.
realnvp.train_model(if_train=False)

with torch.no_grad():
    # Push 10 standard-normal 28x28 latents through the inverse flow.
    samples = realnvp.inv_flow(torch.randn(10, 28, 28).to(device))
    plot_imgs(samples.cpu().numpy(), title=None)

## Moons

In [0]:
# The two-moons cloud is generated in code; the directory argument is not used for it.
moons = load_data('datasets/', 'MOONS')

# Unpack the two columns (x, y) of the (N, 2) tensor straight into scatter.
plt.scatter(*moons.T, c='black')
plt.axis('off')
plt.show()

In [0]:
# Optimisation settings for the 2-D moons flow.
n_epochs, l_rate = 10000, 1e-4
n_layers = 6     # Linear layers in each coupling layer's s / t networks
n_couplings = 6  # number of stacked coupling layers

# Each sample is a single 2-D point.
data_dim, input_dim = (1, 2), 2

# Standard normal prior in 2-D.
prior_z = distributions.MultivariateNormal(
    loc=torch.zeros(input_dim),
    covariance_matrix=torch.eye(input_dim),
)

realnvp2D = RealNVP(input_dim, data_dim, n_layers, n_couplings, device)
realnvp2D.init_weights()

optimizer = torch.optim.Adam(realnvp2D.get_parameters(), lr=l_rate)
# Multiply the learning rate by 0.9 every 200 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)

In [0]:
# Full-batch training: every step uses the whole point cloud.
realnvp2D.train_model()
moons = moons.float().to(device)

for i in range(n_epochs):
    # Tiny uniform jitter (scale 1e-6) added to the points at every step.
    x = moons + 1e-6 * torch.rand(moons.shape).to(device)

    z, logdetJ = realnvp2D.flow(x)
    loss = loss_fun(z, prior_z, logdetJ)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Report and plot only every 50 steps.
    if i % 50 != 0:
        continue

    print('Epoch: {}/{} Loss: {:.4f}'.format(i+1, n_epochs, loss.item()))
    with torch.no_grad():
        samples = realnvp2D.inv_flow(torch.randn(1000, 2).to(device)).view(-1, 2)
        plt.scatter(*samples.cpu().numpy().T, c='black')
        plt.axis('off')
        plt.show()

In [0]:
# Evaluation mode, then draw 10k samples from the learned density.
realnvp2D.train_model(if_train=False)

with torch.no_grad():
    samples = realnvp2D.inv_flow(torch.randn(10000, 2).to(device)).view(-1, 2)
    # Low alpha so dense regions of the sample cloud stand out.
    plt.scatter(*samples.cpu().numpy().T, c='black', alpha=0.1)
    plt.axis('off')
    plt.show()

## Sign

In [0]:
# Point cloud built from the black pixels of images/MLinPL.png.
sign = load_data('images/', 'MLINPL')

# Unpack the two columns (x, y) of the (N, 2) tensor straight into scatter.
plt.scatter(*sign.T, c='black')
plt.axis('off')
plt.show()

In [0]:
# Optimisation settings for the 2-D sign flow (note the larger learning rate).
n_epochs, l_rate = 10000, 1e-3
n_layers = 6     # Linear layers in each coupling layer's s / t networks
n_couplings = 6  # number of stacked coupling layers

# Each sample is a single 2-D point.
data_dim, input_dim = (1, 2), 2

# Standard normal prior in 2-D.
prior_z = distributions.MultivariateNormal(
    loc=torch.zeros(input_dim),
    covariance_matrix=torch.eye(input_dim),
)

realnvp_sign = RealNVP(input_dim, data_dim, n_layers, n_couplings, device)
realnvp_sign.init_weights()

optimizer = torch.optim.Adam(realnvp_sign.get_parameters(), lr=l_rate)
# Multiply the learning rate by 0.9 every 200 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)

In [0]:
# Full-batch training: every step uses the whole point cloud.
realnvp_sign.train_model()
sign = sign.float().to(device)

for i in range(n_epochs):
    # Tiny uniform jitter (scale 1e-6) added to the points at every step.
    x = sign + 1e-6 * torch.rand(sign.shape).to(device)

    z, logdetJ = realnvp_sign.flow(x)
    loss = loss_fun(z, prior_z, logdetJ)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Report and plot only every 50 steps.
    if i % 50 != 0:
        continue

    print('Epoch: {}/{} Loss: {:.4f}'.format(i+1, n_epochs, loss.item()))
    with torch.no_grad():
        samples = realnvp_sign.inv_flow(torch.randn(1000, 2).to(device)).view(-1, 2)
        plt.scatter(*samples.cpu().numpy().T, c='black')
        plt.axis('off')
        plt.show()

In [0]:
# Evaluation mode, then draw 10k samples from the learned density.
realnvp_sign.train_model(if_train=False)

with torch.no_grad():
    samples = realnvp_sign.inv_flow(torch.randn(10000, 2).to(device)).view(-1, 2)
    # Low alpha so dense regions of the sample cloud stand out.
    plt.scatter(*samples.cpu().numpy().T, c='black', alpha=0.1)
    plt.axis('off')
    plt.show()