In [1]:
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.distributions import MultivariateNormal
import matplotlib.pyplot as plt
from utility import *

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class coupling_layer(nn.Module):
    """
    Implements coupling layer:
    Forward transform (x = f(z)):
    x_1:d = z_1:d
    x_d+1:D = (z_d+1:D - \mu_d+1:D)*exp(-\alpha_d+1:D)
    
    Inverse transform (z = f^-1(x)):
    z_1:d = x_1:d
    z_d+1:D = x_d+1_D*exp(\alpha_d+1:D)+\mu_d+1:D
    
    d is a parameter here (although no choice in 2D case).
    """
    def __init__(self, net):
        """
        Init net for parametrizing $\alpha$ and $\mu$.
        """
        super().__init__()
        self.net = net
        
    def forward(self, z):
        with torch.no_grad():
            z_1_d, z_d_D = torch.chunk(z, 2, dim=-1)
            outs = self.net(z_1_d)
            mu, alpha = torch.chunk(outs, 2, dim=-1)
            x_1_d = z_1_d
            x_d_D = (z_d_D-mu)*torch.exp(-alpha)
        x = torch.cat([x_1_d, x_d_D], dim=-1)
        return x    
    
    def inverse(self, x):
        x_1_d, x_d_D = torch.chunk(x, 2, dim=-1)
        outs = self.net(x_1_d)
        mu, alpha = torch.chunk(outs, 2, dim=-1)
        z_1_d = x_1_d
        z_d_D = x_d_D*torch.exp(alpha)+mu
        z = torch.cat([z_1_d, z_d_D], dim=-1)
        return z, alpha       
    
class CouplingFlow(nn.Module):
    """
    Now create the planar flow. Stack layers of transforms and forward.
    *K is a choice in planar flows.*
    """
    def __init__(self, dim, K):
        super().__init__()
        
        self.transform_layers = [PlanarTransform(dim) for i in range(K)]
        self.model = nn.Sequential(*self.transform_layers)
        
    def forward(self, z):
        log_det_J = 0
        
        for layer in self.transform_layers:
            log_det_J += layer.log_determ(z)
            z = layer(z)
            
        return z, log_det_J

In [8]:
cl = coupling_layer(1)
print(cl(torch.Tensor([[1,2]])))

(tensor([[1.]]), tensor([[2.]]))
