In [1]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.distributions import Normal

import math
import time,os

from Plot_utils import *
from flow_utils import *
#from utils import plot_s
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
class planar_net(nn.Module):
    """
    Network for planar flow with linear transform and tanh activation
    """
    def __init__(self, ):
        super().__init__()
                
        self.w = nn.Parameter(torch.randn(1, 2).normal_(0, 0.1))
        self.b = nn.Parameter(torch.randn(1).normal_(0, 0.1))
        self.u = nn.Parameter(torch.randn(1, 2).normal_(0, 0.1))
        
        if (torch.mm(self.u, self.w.T)< -1).any():   
            self.get_u_hat()


        
    def get_u_hat(self):
        """Enforce w^T u >= -1. When using h(.) = tanh(.), this is a sufficient condition 
        for invertibility of the transformation f(z). See Appendix A.1.
        """
        wtu = torch.mm(self.u, self.w.T)
        m_wtu = -1 + torch.log(1 + torch.exp(wtu))
        self.u.data = (self.u + (m_wtu - wtu) * self.w / torch.norm(self.w, p=2, dim=1) ** 2)
   
    def forward(self, x):
        
        z = torch.matmul(x, self.w.T)      
        z = torch.add(z, self.b)               
        z = nn.tanh(z)                       
        z = torch.matmul(z, self.u)           
        z = z + x                                
        
        return z

In [6]:
class Flow(nn.Module):
    """
    Generic class for flow functions
    """

    def __init__(self):
        super().__init__() 
        
    @property    
    def base_dist(self):
        return Normal(
            loc=torch.zeros(2,device=device),
            scale=torch.ones(2,device=device), validate_args=False
        )
      
        
    def build(self): 
        
        return NotImplemented
        
    def flow_outputs(self, x):
        
        log_det = torch.zeros(x.shape[0], device=self.device)
        z = x
        for bijection in self.flow:
            z, ldj = bijection(z)
            log_det += ldj
            
        return z, log_det
    
    def sample(self, num_samples):
        z = self.base_dist.sample((num_samples,))
        for bijection in reversed(self.flow):
            z = bijection.inverse(z)
        return z

In [13]:
def newton_method(function, initial, iteration=100, convergence=torch.Tensor([0.0001, 0.0001]).to(device)):
            for i in range(iteration): 
                previous_data = initial.clone()
                value = function(initial)
                value.sum().backward()
                # update 
                initial.data -= (value / initial.grad).data
                # zero out current gradient to hold new gradients in next iteration 
                initial.grad.data.zero_() 
#                 print("epoch {}, obtain {}".format(i, initial))
                # Check convergence. 
                # When difference current epoch result and previous one is less than 
                # convergence factor, return result.
                comp = torch.le(torch.abs(initial - previous_data).data, torch.tensor(convergence))
                
                if comp.all() == True:
                    return initial.data
            return initial.data # return our final after iteration

class Planar(nn.Module):
    """
    Planar flow as introduced in arXiv: 1505.05770
        f(z) = z + u * h(w * z + b)
    """
    
    def __init__(self, net):
        super().__init__()
        self.net = net
        
    def forward(self,x):
        
        z = self.net(x)
            
        for name, param in self.net.named_parameters():
            if name == 'u' : 
                self.u = param
            elif name == 'w' : 
                self.w = param
            elif name == 'b' : 
                self.b = param
        
        affine = torch.mm(x, self.w.T) + self.b          
        psi = (1 - nn.Tanh()(affine) ** 2) * self.w      
        abs_det = (1 + torch.mm(self.u, psi.T)).abs()   
        log_det = torch.log(1e-4 + abs_det).squeeze(0)   
        
        return z, log_det
    
    def inverse(self, z):
        
        sol = torch.zeros(torch.Size([450,2])).to(device)
        for idx, sample in enumerate(z):
            print(str(idx), sample, 'sample')
            sample.requires_grad_()
            s = newton_method(self.net, sample)
            print(s, 's')
            sol[idx] = s
            
        sol = sol.reshape([450,2])    
        return sol
    
class PlanarFlow(Flow):   

    def __init__(self, net = planar_net, dim=10):
        Flow.__init__(self) 
        self.net = net
        self.dim = dim
        self.bijections = []
        self.build()
        self.flow = nn.ModuleList(self.bijections)  

    def build(self): 
        for i in range(self.dim):
            self.bijections += [Planar(self.net)]

    def sample(self, num_samples):
        z = self.base_dist.sample((num_samples,))
        
        z = nn.Sequential(reversed(self.flow))
            
        return z

In [14]:
net=planar_net
flow_planar = PlanarFlow(net = net, dim= 5).to(device)
print(flow_planar.flow)

ModuleList(
  (0): Planar()
  (1): Planar()
  (2): Planar()
  (3): Planar()
  (4): Planar()
)


In [15]:
z_samples = flow_planar.sample(450)
z_samples

TypeError: reversed is not a Module subclass