In [1]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn import datasets
from torch import distributions
from torch import nn

In [2]:
# Use the CUDA device when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
# Phi.py
# neural network to model the potential function
import torch
import torch.nn as nn
import copy
import math

def antiderivTanh(x):
    """
    Activation function: an antiderivative of tanh.

    Mathematically equal to log(cosh(x)) + log(2), written as
    |x| + log(1 + exp(-2|x|)) so the exponential never exceeds 1 (no overflow).

    :param x: tensor of any shape
    :return:  tensor of the same shape
    """
    ax = torch.abs(x)
    return ax + torch.log(1 + torch.exp(-2.0 * ax))

def derivTanh(x):
    """
    Second derivative of the activation antiderivTanh, i.e. the derivative of tanh.

    :param x: tensor of any shape
    :return:  1 - tanh(x)^2, same shape as x
    """
    return 1 - torch.tanh(x) ** 2

class ResNN(nn.Module):
    def __init__(self, d, m, nTh=2):
        """
            ResNet N portion of Phi

            This implementation was first described in:

            @article{onken2020otflow,
               title={OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport},
                author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto},
                year={2020},
                journal = {arXiv preprint arXiv:2006.00104},
            }
        :param d:   int, dimension of space input (expect inputs to be d+1 for space-time)
        :param m:   int, hidden dimension
        :param nTh: int, number of resNet layers , (number of theta layers)
        :raises ValueError: if nTh < 2
        """
        super().__init__()

        if nTh < 2:
            # Raise instead of print + exit(1): terminating the process (or notebook kernel)
            # from a constructor prevents callers from handling the error.
            raise ValueError("nTh must be an integer >= 2")

        self.d = d
        self.m = m
        self.nTh = nTh
        self.layers = nn.ModuleList([])
        self.layers.append(nn.Linear(d + 1, m, bias=True))  # opening layer (K_0, b_0)
        self.layers.append(nn.Linear(m, m, bias=True))      # first resnet layer (K_1, b_1)
        # the remaining nTh-2 resnet layers start as exact copies of layer 1 (same initial weights)
        for _ in range(nTh - 2):
            self.layers.append(copy.deepcopy(self.layers[1]))
        self.act = antiderivTanh
        self.h = 1.0 / (self.nTh - 1)  # step size for the ResNet

    def forward(self, x):
        """
            N(s;theta). the forward propagation of the ResNet

            u_0 = act(K_0 s + b_0),  u_i = u_{i-1} + h * act(K_i u_{i-1} + b_i)

        :param x: tensor nex-by-d+1, inputs
        :return:  tensor nex-by-m,   outputs
        """
        x = self.act(self.layers[0](x))

        for i in range(1, self.nTh):
            x = x + self.h * self.act(self.layers[i](x))

        return x



class Phi(nn.Module):
    def __init__(self, nTh, m, d, r=10, alph=None):
        """
            neural network approximating Phi (see Eq. (9) in our paper)

            Phi( x,t ) = w'*ResNet( [x;t]) + 0.5*[x' t] * A'A * [x;t] + b'*[x;t] + c

        :param nTh:  int, number of resNet layers , (number of theta layers)
        :param m:    int, hidden dimension
        :param d:    int, dimension of space input (expect inputs to be d+1 for space-time)
        :param r:    int, rank r for the A matrix (capped at d+1)
        :param alph: list, alpha values / weighted multipliers for the optimization problem;
                     defaults to [1.0] * 5. Only stored on the instance, not used by Phi itself.
        """
        super().__init__()

        if alph is None:
            # build a fresh list per instance instead of sharing one mutable default argument
            alph = [1.0] * 5

        self.m    = m
        self.nTh  = nTh
        self.d    = d
        self.alph = alph

        r = min(r, d + 1)  # if number of dimensions is smaller than default r, use that

        self.A = nn.Parameter(torch.zeros(r, d + 1), requires_grad=True)
        nn.init.xavier_uniform_(self.A)  # in-place init of the registered parameter
        self.c = nn.Linear(d + 1, 1, bias=True)   # b'*[x;t] + c
        self.w = nn.Linear(m, 1, bias=False)

        self.N = ResNN(d, m, nTh=nTh)

        # set initial values: w = 1, b = 0, c = 0
        nn.init.ones_(self.w.weight)
        nn.init.zeros_(self.c.weight)
        nn.init.zeros_(self.c.bias)

    def forward(self, x):
        """
        Evaluate Phi(s, theta) directly (not used in OT-Flow, which only calls trHess).

        :param x: tensor nex-by-(d+1), space-time points [x;t]
        :return:  tensor nex-by-1
        """
        # A'A is symmetric positive semi-definite by construction
        symA = torch.matmul(torch.t(self.A), self.A)

        return self.w(self.N(x)) + 0.5 * torch.sum(torch.matmul(x, symA) * x, dim=1, keepdim=True) + self.c(x)

    def trHess(self, x, d=None, justGrad=False):
        """
        compute gradient of Phi wrt x and trace(Hessian of Phi); see Eq. (11) and Eq. (13), respectively
        recomputes the forward propagation portions of Phi

        :param x: input data, torch Tensor nex-by-(d+1) (space-time points [x;t])
        :param d: int, number of spatial dimensions; defaults to x.shape[1]-1
        :param justGrad: boolean, if True only return gradient, if False return (grad, trHess)
        :return: gradient nex-by-(d+1), trace of the Hessian over the first d (spatial)
                 coordinates as a length-nex tensor    OR    just gradient
        """

        # code in E = eye(d+1,d) as index slicing instead of matrix multiplication
        # assumes specific N.act as the antiderivative of tanh

        N    = self.N
        m    = N.layers[0].weight.shape[0]
        nex  = x.shape[0]  # number of examples in the batch
        if d is None:
            d    = x.shape[1] - 1
        symA = torch.matmul(self.A.t(), self.A)

        u = []  # hold the u_0,u_1,...,u_M for the forward pass
        z = N.nTh * [None]  # hold the z_0,z_1,...,z_M for the backward pass
        # preallocate z because we will store in the backward pass and we want the indices to match the paper

        # Forward of ResNet N and fill u
        opening = N.layers[0].forward(x)  # K_0 * S + b_0
        u.append(N.act(opening))  # u0
        feat = u[0]

        for i in range(1, N.nTh):
            feat = feat + N.h * N.act(N.layers[i](feat))
            u.append(feat)

        # going to be used more than once
        tanhopen = torch.tanh(opening)  # act'( K_0 * S + b_0 )

        # compute gradient and fill z
        for i in range(N.nTh - 1, 0, -1):  # work backwards, placing z_i in appropriate spot
            if i == N.nTh - 1:
                term = self.w.weight.t()
            else:
                term = z[i + 1]

            # z_i = z_{i+1} + h K_i' diag(...) z_{i+1}
            z[i] = term + N.h * torch.mm(N.layers[i].weight.t(), torch.tanh(N.layers[i].forward(u[i - 1])).t() * term)

        # z_0 = K_0' diag(...) z_1
        z[0] = torch.mm(N.layers[0].weight.t(), tanhopen.t() * z[1])
        grad = z[0] + torch.mm(symA, x.t()) + self.c.weight.t()

        if justGrad:
            return grad.t()

        # -----------------
        # trace of Hessian
        # -----------------

        # t_0, the trace of the opening layer
        Kopen = N.layers[0].weight[:, 0:d]  # indexed version of Kopen = torch.mm( N.layers[0].weight, E  )
        temp  = derivTanh(opening.t()) * z[1]
        trH   = torch.sum(temp.reshape(m, -1, nex) * torch.pow(Kopen.unsqueeze(2), 2), dim=(0, 1))  # trH = t_0

        # grad_s u_0 ^ T
        temp = tanhopen.t()  # act'( K_0 * S + b_0 )
        Jac  = Kopen.unsqueeze(2) * temp.unsqueeze(1)  # K_0' * act'( K_0 * S + b_0 )
        # Jac is shape m by d by nex

        # t_i, trace of the resNet layers
        # KJ is the K_i^T * grad_s u_{i-1}^T
        for i in range(1, N.nTh):
            KJ = torch.mm(N.layers[i].weight, Jac.reshape(m, -1))
            KJ = KJ.reshape(m, -1, nex)
            if i == N.nTh - 1:
                term = self.w.weight.t()
            else:
                term = z[i + 1]

            temp = N.layers[i].forward(u[i - 1]).t()  # (K_i * u_{i-1} + b_i)
            t_i = torch.sum((derivTanh(temp) * term).reshape(m, -1, nex) * torch.pow(KJ, 2), dim=(0, 1))
            trH = trH + N.h * t_i  # add t_i to the accumulate trace
            # update Jacobian for the next layer (i always < N.nTh inside this loop)
            Jac = Jac + N.h * torch.tanh(temp).reshape(m, -1, nex) * KJ

        return grad.t(), trH + torch.trace(symA[0:d, 0:d])
        # indexed version of: return grad.t() ,  trH + torch.trace( torch.mm( E.t() , torch.mm(  symA , E) ) )





In [4]:
import torch
from torch import nn
from torch.nn.functional import pad
from torch import distributions


class OTFlow(nn.Module):
    """
    OT-Flow for density estimation and sampling as described in

    @article{onken2020otflow,
        title={OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport},
        author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto},
        year={2020},
        journal = {arXiv preprint arXiv:2006.00104},
    }

    Particles move with velocity -grad_x Phi / alph[0], where Phi is the potential network `net`
    (anything providing trHess(z) -> (gradient, trace of Hessian)).
    """
    def __init__(self, net, nt, alph, prior, T=1.0):
        """
        Initialize OT-Flow

        :param net: network for value function (must provide trHess)
        :param nt: number of rk4 steps
        :param alph: penalty parameters; alph[0] scales the velocity and the transport cost
        :param prior: latent distribution, e.g., distributions.MultivariateNormal(torch.zeros(d), torch.eye(d))
        :param T: final time of the flow (default 1.0)
        """
        super().__init__()
        self.prior = prior
        self.nt = nt
        self.T = T
        self.net = net
        self.alph = alph

    def g(self, z, nt=None, storeAll=False):
        """
        Generative direction: integrate from time T back to 0.

        :param z: latent variable, tensor nex-by-d
        :param nt: number of rk4 steps (default self.nt)
        :param storeAll: if True, also return CPU copies of all intermediate states
        :return: g(z) and hidden states, i.e. the tuple (y, ys, l, v, r) from integrate
        """
        return self.integrate(z, [self.T, 0.0], nt, storeAll)

    def ginv(self, x, nt=None, storeAll=False):
        """
        Normalizing direction: integrate from time 0 forward to T.

        :param x: sample from dataset, tensor nex-by-d
        :param nt: number of rk4 steps (default self.nt)
        :param storeAll: if True, also return CPU copies of all intermediate states
        :return: g^(-1)(x), hidden states, value of log-determinant, transport costs, HJB penalty
        """
        return self.integrate(x, [0.0, self.T], nt, storeAll)

    def log_prob(self, x, nt=None):
        r"""
        Compute log-probability of a sample using change of variable formula

        :param x: sample from dataset
        :param nt: number of rk4 steps (default self.nt)
        :return: (logp_{\theta}(x), v, r) -- log-probability, accumulated transport costs, HJB penalty
        """
        z, _, log_det_ginv, v, r = self.ginv(x, nt)
        return self.prior.log_prob(z) - log_det_ginv, v, r

    def sample(self, s, nt=None):
        r"""
        Draw random samples from p_{\theta}

        :param s: number of samples to draw
        :param nt: number of rk4 steps (default self.nt)
        :return: tensor of s generated samples
        """
        z = self.prior.sample((s, 1)).squeeze(1)
        x, _, _, _, _ = self.g(z, nt)
        return x

    def f(self, x, t):
        """
        neural ODE combining the characteristics and log-determinant (see Eq. (2)), the transport costs (see Eq. (5)), and
        the HJB regularizer (see Eq. (7)).

        d_t  [x ; l ; v ; r] = odefun( [x ; l ; v ; r] , t )

        x - particle position
        l - log determinant
        v - accumulated transport costs (Lagrangian)
        r - accumulates violation of HJB condition along trajectory

        :param x: particle positions, tensor nex-by-d
        :param t: current time (float)
        :return: (dx, dl, dv, dr); dx is nex-by-d, the others have length nex
        """
        nex, d = x.shape
        z = pad(x[:, :d], (0, 1, 0, 0), value=t)  # append the time coordinate: z = [x; t]
        gradPhi, trH = self.net.trHess(z)

        dx = -(1.0 / self.alph[0]) * gradPhi[:, 0:d]
        dl = (1.0 / self.alph[0]) * trH
        # dv and dr were previously commented out, which made the return below raise NameError
        dv = 0.5 * torch.sum(torch.pow(dx, 2), 1)
        dr = torch.abs(-gradPhi[:, -1] + self.alph[0] * dv)

        return dx, dl, dv, dr

    def integrate(self, y, tspan, nt=None, storeAll=False):
        """
        RK4 time-stepping to integrate the neural ODE

        :param y: initial state, tensor nex-by-d
        :param tspan: time interval [t_start, t_end] (can go backward in time)
        :param nt: number of time steps (default is self.nt)
        :param storeAll: if True, collect a detached CPU copy of y after every step
        :return: y (final state), ys (all states or None), l (log determinant), v (transport costs), r (HJB penalty)
        """
        if nt is None:
            nt = self.nt

        nex, d = y.shape
        h = (tspan[1] - tspan[0]) / nt  # negative when integrating backward in time
        tk = tspan[0]

        l = torch.zeros((nex), device=y.device, dtype=y.dtype)
        v = torch.zeros((nex), device=y.device, dtype=y.dtype)
        r = torch.zeros((nex), device=y.device, dtype=y.dtype)
        if storeAll:
            ys = [torch.clone(y).detach().cpu()]
        else:
            ys = None

        # classical RK4 weights h/6 * [1, 2, 2, 1]
        w = [(h / 6.0), 2.0 * (h / 6.0), 2.0 * (h / 6.0), 1.0 * (h / 6.0)]
        for i in range(nt):
            y0 = y

            dy, dl, dv, dr = self.f(y0, tk)
            y = y0 + w[0] * dy
            l += w[0] * dl
            v += w[0] * dv
            r += w[0] * dr

            dy, dl, dv, dr = self.f(y0 + 0.5 * h * dy, tk + (h / 2))
            y += w[1] * dy
            l += w[1] * dl
            v += w[1] * dv
            r += w[1] * dr

            dy, dl, dv, dr = self.f(y0 + 0.5 * h * dy, tk + (h / 2))
            y += w[2] * dy
            l += w[2] * dl
            v += w[2] * dv
            r += w[2] * dr

            dy, dl, dv, dr = self.f(y0 + h * dy, tk + h)
            y += w[3] * dy
            l += w[3] * dl
            v += w[3] * dv
            r += w[3] * dr

            if storeAll:
                ys.append(torch.clone(y).detach().cpu())
            tk += h

        return y, ys, l, v, r



In [5]:

# Training data: 16 * sample_multiplier points drawn uniformly from the square [2, 6] x [2, 6].
DATA_SEED = 0  # fixed seed so the data set is identical across kernel restarts
sample_multiplier = 10
n_samples = 16 * sample_multiplier

data_rng = np.random.default_rng(DATA_SEED)
# np.random returns float64; cast to float32 to match the model parameters
x = torch.tensor(data_rng.uniform(2, 6, size=(n_samples, 2)), dtype=torch.float32, device=device)

In [6]:
def MMD(x, y, gamma=2):
    """
    Biased estimate of the squared maximum mean discrepancy between two sample sets.

    Uses the inverse multiquadric kernel k(a, b) = 1 / (1 + ||a - b||^2 / gamma^2):

        MMD^2 = mean_{i,j} k(x_i, x_j) + mean_{i,j} k(y_i, y_j) - 2 * mean_{i,j} k(x_i, y_j)

    The means include the diagonal terms (biased / V-statistic estimator).

    :param x: tensor n-by-d, first sample set
    :param y: tensor k-by-d, second sample set (same d, dtype and device as x)
    :param gamma: kernel bandwidth (default 2, the value previously hard-coded)
    :return: 0-dim tensor holding the MMD^2 estimate, on x's device
    """
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = xx.diag().unsqueeze(0).expand_as(xx)  # squared norms of x, broadcast along rows
    ry = yy.diag().unsqueeze(0).expand_as(yy)  # squared norms of y, broadcast along rows

    # pairwise squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    dxx = rx.t() + rx - 2. * xx
    dyy = ry.t() + ry - 2. * yy
    rxx = rx[0].repeat(y.shape[0], 1)
    ryy = ry[0].repeat(x.shape[0], 1)
    dxy = rxx.t() + ryy - 2. * zz

    # Kernel matrices computed directly on the inputs' device: no dependency on a global `device`.
    XX = 1 / (1 + dxx / gamma**2)
    YY = 1 / (1 + dyy / gamma**2)
    XY = 1 / (1 + dxy / gamma**2)
    return XX.mean() + YY.mean() - 2 * XY.mean()

In [7]:
# Hyper-parameter sweep: for every (width, depth) pair train one OT-Flow on `x`, log the MMD
# between generated samples and the data before/after training to results.txt, and save the
# per-epoch loss values plus diagnostic plots.
COL_RED = '#c61826'
COL_DARK_RED = '#590d08'
COL_BLUE = '#01024d'


def _save_sample_plots(samples, data, tag):
    """Save a scatter of the generated samples alone, and one overlaid on the data with the target square."""
    fig, ax = plt.subplots()
    ax.scatter(samples[:, 0], samples[:, 1], s=25, alpha=1, label="Generated Data", c=COL_RED)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_ylim(0, 8)
    ax.set_xlim(0, 8)
    fig.savefig("OTFlow_plot_square_gen_" + tag + ".png")
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.scatter(data[:, 0], data[:, 1], alpha=1, s=25, label="Original Data", c=COL_BLUE)
    ax.scatter(samples[:, 0], samples[:, 1], s=25, alpha=1, label="Generated Data", c=COL_RED)
    # outline of the square [2, 6] x [2, 6] that the training data is drawn from
    ax.add_patch(patches.Rectangle((2, 2), 4, 4, linewidth=1, edgecolor=COL_DARK_RED, facecolor='none'))
    ax.set_ylim(0, 8)
    ax.set_xlim(0, 8)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.legend()
    fig.savefig("OTFlow_plot_square_gen_smp_" + tag + ".png")
    plt.close(fig)


def _save_loss_plots(losses, tag):
    """Save the per-epoch loss curve on a logarithmic and on a linear y-axis."""
    fig, ax = plt.subplots()
    ax.semilogy(losses)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Log Likelihood Loss (log scale)")
    # NOTE(review): fixed limits kept from the original; losses below 50 fall outside this window
    ax.set_ylim(50, 200)
    fig.savefig("OTFlow_square_semilog_loss_plot_" + tag + ".png")
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Log Likelihood Loss")
    ax.set_ylim(25, 100)
    fig.savefig("OTFlow_square_loss_plot_" + tag + ".png")
    plt.close(fig)


x_cpu = x.detach().cpu()  # matplotlib cannot convert CUDA tensors, so plot a CPU copy of the data

# `with` closes results.txt even if the sweep is interrupted; the handle keeps the name `f`
# so the separate cleanup cell (f.close()) keeps working.
with open("results.txt", "w") as f:
    for width in [8, 16, 24, 32]:   # hidden width m of Phi's ResNet
        for nTh in [2, 4]:          # number of ResNet layers
            tag = "l" + str(nTh) + "w" + str(width)
            f.write("OTFlow layers: " + str(nTh) + " width: " + str(width) + "\n")
            # alph[0]: weight of the transport cost L (also scales the velocity -grad Phi / alph[0]),
            # alph[1]: weight of the negative log-likelihood, alph[2]: weight of the HJB penalty P
            alph = [1.0, 10.0, 5.0]
            net = Phi(nTh=nTh, m=width, d=2, alph=alph).to(device)
            prior = distributions.MultivariateNormal(torch.zeros(2).to(device), torch.eye(2).to(device))
            nt = 2  # number of rk4 steps to solve neural ODE
            flow = OTFlow(net, nt, alph, prior, T=1.0)
            loss_arr = []

            batch_size = 4000
            num_steps = 5000  # number of epochs over the data set

            # evaluation only: no autograd graph needed
            with torch.no_grad():
                mmd_before = MMD(flow.sample(x.shape[0]), x).item()
            f.write("MMD before training: " + str(mmd_before) + "\n")

            optim = torch.optim.Adam(net.parameters(), lr=0.01)  # lr=0.04 good

            dataloader = torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True)
            scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=2500, gamma=0.1)
            for step in range(num_steps):
                for x_batch in dataloader:
                    x_batch = x_batch.to(device)
                    logp, L, P = flow.log_prob(x_batch)
                    loss = (-alph[1] * logp + alph[0] * L + alph[2] * P).mean()
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                loss_arr.append(loss.item())  # loss of the last batch of this epoch
                scheduler.step()
                if step % 1000 == 999:
                    print("Step: {}, Loss: {}".format(step, loss.item()))

            with torch.no_grad():
                mmd_after = MMD(flow.sample(x.shape[0]), x).item()
                xs = flow.sample(1000).cpu()
            f.write("MMD after training: " + str(mmd_after) + "\n")
            f.write("\n")
            np.savetxt(tag + ".txt", loss_arr, fmt='%f')

            _save_sample_plots(xs, x_cpu, tag)
            _save_loss_plots(loss_arr, tag)


Step: 999, Loss: 47.059349060058594
Step: 1999, Loss: 45.915950775146484
Step: 2999, Loss: 44.71723175048828
Step: 3999, Loss: 44.58055877685547
Step: 4999, Loss: 44.39386749267578
Step: 999, Loss: 45.93810272216797


In [15]:
# Safety net: make sure the results file handle `f` from the training cell is closed.
# Closing an already-closed file object is a no-op, so this cell is safe to re-run.
if not f.closed:
    f.close()