In [16]:
import torch
import torch.nn as nn
import torch.distributions as D
import math
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os

In [17]:
def Gaussian_integral2(z, mu, sigma, base_sigma, k_dim=3, pi=math.pi):
    """Integrate an isotropic Gaussian over the 1-D fibre of the map x -> mu + x @ sigma.T.

    The base density is N(0, base_sigma * I) in ``k_dim`` dimensions (``base_sigma`` is the
    variance). For each row y of ``z``, the points x with mu + x @ sigma.T == y form a line
    parametrised by t = x_0:
        x_{1:} = h + t * c,   h = (y - mu) @ S^{-1},   c = -sigma.T[0] @ S^{-1},
    where S = sigma.T[1:]. Along that line |x|^2 = t^2 (1 + |c|^2) + 2 t (c . h) + |h|^2,
    so the density integrates in closed form:
        int exp(-a t^2 + b t + m) dt = sqrt(pi / a) * exp(m + b^2 / (4 a)).

    Assumes sigma has shape (d, d + 1) with an invertible trailing (d, d) block
    (in this notebook d = 2, k_dim = 3).

    Args:
        z: (batch, d) points in the image space.
        mu: (d,) offset of the affine map.
        sigma: (d, k_dim) linear part of the affine map.
        base_sigma: variance of each coordinate of the base Gaussian.
        k_dim: dimension of the base Gaussian (sets the normalising constant).
        pi: value used for pi (kept for signature compatibility).

    Returns:
        (batch,) tensor: the fibre integral of the base density for each row of ``z``.
        This equals p_y(y) * |det S|, where p_y is the pushed-forward density.
    """
    s_rest_inv = sigma.t()[1:].inverse()          # inverse of S = sigma.T[1:], computed once
    c = -torch.matmul(sigma.t()[0], s_rest_inv)   # direction of x_{1:} per unit of t = x_0
    h = torch.matmul(z - mu, s_rest_inv)          # x_{1:} at t = 0, one row per point
    # Quadratic coefficient is 1 + sum(c^2): the "1" comes from t^2 itself (not 1 per component).
    a = (1 + (c ** 2).sum()) / (2 * base_sigma)
    b = (-c * h).sum(1) / base_sigma
    m = (-h ** 2).sum(1) / (2 * base_sigma)
    # sqrt(pi / a) from the 1-D integral, times the (2 pi base_sigma)^(-k_dim / 2) normaliser.
    return (pi / (a * (2 * base_sigma * pi) ** k_dim)).sqrt() * (m + b ** 2 / (4 * a)).exp()

In [18]:
# --------------------
# Flow
# --------------------
class PlanarTransform(nn.Module):
    """Planar flow layer f(z) = z + u_hat * tanh(z @ w.T + b) acting on 2-D points.

    ``u`` and ``w`` are (1, 2) parameters and ``b`` is a (1,) bias. ``forward`` accepts either
    a batch ``z`` or a ``(z, sum_log_abs_det_jacobians)`` tuple, so several layers can be
    chained in an ``nn.Sequential`` and accumulate their log-determinants.
    """

    def __init__(self, init_sigma=0.01):
        super().__init__()
        # u, w ~ N(0, init_sigma^2); b = 0. The randn-then-overwrite calls are kept as-is so the
        # RNG stream (and therefore seeded initialisations of later layers) is unchanged.
        self.u = nn.Parameter(torch.randn(1, 2).normal_(0, init_sigma))
        self.w = nn.Parameter(torch.randn(1, 2).normal_(0, init_sigma))
        self.b = nn.Parameter(torch.randn(1).fill_(0))

    def forward(self, x, normalize_u=True):
        """Apply the layer.

        Args:
            x: tensor whose last dim is 2 (presumably (batch, 2)), or a tuple
                ``(z, sum_log_abs_det_jacobians)`` from a previous layer.
            normalize_u: if True, replace u by u_hat with w . u_hat = -1 + softplus(w . u) > -1,
                a sufficient condition for invertibility.

        Returns:
            (f_z, sum_log_abs_det_jacobians): transformed points and the running sum of
            log|det J| with this layer's per-sample term added.
        """
        # allow for a single forward pass over all the transforms in the flows with a Sequential container
        if isinstance(x, tuple):
            z, sum_log_abs_det_jacobians = x
        else:
            z, sum_log_abs_det_jacobians = x, 0

        u_hat = self.u
        if normalize_u:
            wtu = (self.w @ self.u.t()).squeeze()
            # softplus is the overflow-safe form of log1p(exp(.)); the naive form gives inf
            # (and then NaN outputs) once w . u exceeds ~88 in float32.
            m_wtu = -1 + nn.functional.softplus(wtu)
            u_hat = self.u + (m_wtu - wtu) * self.w / (self.w @ self.w.t())

        # tanh(z @ w.T + b) is needed for both the transform and its derivative; compute it once.
        activation = torch.tanh(z @ self.w.t() + self.b)
        f_z = z + u_hat * activation

        # det(I + u_hat psi^T) = 1 + psi . u_hat with psi = tanh'(.) * w
        psi = (1 - activation ** 2) @ self.w
        det = 1 + psi @ u_hat.t()
        log_abs_det_jacobian = torch.log(torch.abs(det) + 1e-6).squeeze(-1)
        sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian

        return f_z, sum_log_abs_det_jacobians

class AffineTransform(nn.Module):
    """First flow layer: maps base samples x to 2-D points z = u_hat2 * tanh(mu + x @ sigma.T).

    ``sigma`` is (2, 3), so x is expected to have last dim 3 (the 3-D base distribution).
    ``forward`` returns ``(z, sum_log_abs_det_jacobians)`` where the second term is
        log( |u0 u1 tanh'(p0) tanh'(p1)| * sqrt(det(sigma sigma^T)) * Gaussian_integral2(p, ...) ),
    with p = mu + x @ sigma.T the pre-activation.
    NOTE(review): the map is 3 -> 2, so this is not a Jacobian determinant in the usual sense;
    it is the correction term this notebook subtracts from log q0. Confirm the derivation.
    """

    def __init__(self, learnable=True, base_sigma=None):
        """Args:
            learnable: whether mu, sigma and u_hat2 receive gradients.
            base_sigma: variance of the base Gaussian passed to Gaussian_integral2. When None
                (default), the global ``args.base_sigma`` is read at call time.
        """
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(2)).requires_grad_(learnable)
        self.sigma = nn.Parameter(torch.ones(2, 3).normal_(0, 1)).requires_grad_(learnable)
        self.u_hat2 = nn.Parameter(torch.ones(2)).requires_grad_(learnable)
        self.base_sigma = base_sigma

    def forward(self, x):
        pre_activation = self.mu + torch.matmul(x, self.sigma.t())
        activation = torch.tanh(pre_activation)
        z = self.u_hat2 * activation
        dtanh = 1 - activation ** 2

        base_sigma = args.base_sigma if self.base_sigma is None else self.base_sigma
        # |u0 u1 tanh'(p0) tanh'(p1)| * sqrt(|det(sigma sigma^T)|)
        scale = ((self.u_hat2[0] * self.u_hat2[1] * dtanh[:, 0] * dtanh[:, 1]) ** 2
                 * (self.sigma @ self.sigma.t()).det()).abs() ** 0.5
        # Evaluate the fibre integral at the pre-activation directly. atanh(z / u_hat2) equals it
        # mathematically, but is inf when tanh saturates and NaN when a u_hat2 entry is 0.
        sum_log_abs_det_jacobians = torch.log(
            scale * Gaussian_integral2(pre_activation, self.mu, self.sigma, base_sigma))
        return z, sum_log_abs_det_jacobians

In [19]:
# --------------------
# Training
# --------------------

def optimize_flow(base_dist, flow, target_energy_potential, optimizer, args):
    """Train ``flow`` so that base samples pushed through it approximate exp(-U).

    Each step draws ``args.batch_size`` samples z ~ base_dist and minimises the annealed
    free energy
        mean[ log q0(z) - sum log|det J| + beta * temp(i) * U(f(z)) ]
    for steps ``args.start_step .. args.n_steps - 1``. Every ``args.log_every`` steps
    (10000 if that attribute is absent) it prints diagnostics and saves a checkpoint to
    ``args.output_dir``. It also writes a sample plot via ``plot_flow2``.

    Args:
        base_dist: distribution with ``sample`` and ``log_prob``. When ``args.learn_base`` is
            False it must also expose ``loc`` and ``covariance_matrix`` for logging.
        flow: ``nn.Sequential`` returning ``(z_k, sum_log_abs_det_jacobians)``.
        target_energy_potential: callable U(z) returning per-sample energies.
        optimizer: optimizer over ``flow.parameters()``.
        args: run configuration namespace.
    """
    log_every = getattr(args, 'log_every', 10000)

    def temp(i):
        # anneal rate for free energy: rises linearly from 0.01 and saturates at 1 from step 9900
        return min(1, 0.01 + i / 10000)

    for i in range(args.start_step, args.n_steps):
        z = base_dist.sample((args.batch_size,)).to(args.device)

        # log q0(z): no parameters involved, so it does not contribute to gradients
        base_log_prob = base_dist.log_prob(z)
        zk, sum_log_abs_det_jacobians = flow(z)
        # p = exp(-potential) ==> p_log_prob = - potential (tempered)
        p_log_prob = - temp(i) * target_energy_potential(zk)

        loss = (base_log_prob - sum_log_abs_det_jacobians - args.beta * p_log_prob).mean(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            log_qk = base_log_prob - sum_log_abs_det_jacobians
            if args.learn_base:
                base_mu = flow[0].mu.detach().cpu().numpy()
                base_sigma = flow[0].sigma.detach().cpu().numpy()
            else:
                base_mu = base_dist.loc.cpu().numpy()
                base_sigma = base_dist.covariance_matrix.cpu().diag().numpy()
            print('{}: step {:5d} / {}; loss {:.3f}; base_log_prob {:.3f}, sum log dets {:.3f}, '
                  'p_log_prob {:.3f}, max base = {:.3f}; max qk = {:.3f}; '
                  'zk_mean {}, zk_sigma {}; base_mu {}, base_sigma {}'.format(
                      args.target_potential, i, args.n_steps, loss.item(),
                      base_log_prob.mean(0).item(), sum_log_abs_det_jacobians.mean(0).item(),
                      p_log_prob.mean(0).item(), base_log_prob.exp().max().item(),
                      log_qk.exp().max().item(),
                      zk.mean(0).detach().cpu().numpy(), zk.var(0).sqrt().detach().cpu().numpy(),
                      base_mu, base_sigma))

            # Save the model. 'step' is the next step to run: step i has already been applied,
            # and the restore code resumes training from this value.
            torch.save({'step': i + 1,
                        'flow_state': flow.state_dict(),
                        'optimizer_state': optimizer.state_dict()},
                       os.path.join(args.output_dir, 'model_state_flow_length_{}.pt'.format(args.flow_length)))

            with torch.no_grad():
                plot_flow2(base_dist, flow, os.path.join(args.output_dir, 'approximating_flow_step{}.png'.format(i)), args)

In [20]:
# --------------------
# Plotting
# --------------------

def plot_flow2(base_dist, flow, filename, args, n_samples=10000, range_lim=4):
    """Save a scatter plot and a 2-D histogram of flow-transformed base samples to ``filename``.

    Args:
        base_dist: distribution to draw ``n_samples`` base points from.
        flow: callable returning ``(z_k, log_det)``; only the first two columns of z_k are plotted.
        filename: output image path.
        args: unused; kept so existing callers keep working.
        n_samples: number of base samples to transform.
        range_lim: both axes (and the histogram range) span [-range_lim, range_lim];
            the histogram uses 50 bins per unit.
    """
    with torch.no_grad():
        z = base_dist.sample((n_samples,))
        zk, _ = flow(z)
    zk = zk.detach().cpu().numpy()

    fig, axs = plt.subplots(1, 2, subplot_kw={'aspect': 'equal'})
    axs[0].scatter(zk[:, 0], zk[:, 1], s=10, alpha=0.4)
    axs[1].hist2d(zk[:, 0], zk[:, 1], bins=[range_lim * 50, range_lim * 50],
                  range=[[-range_lim, range_lim], [-range_lim, range_lim]], cmap=plt.cm.jet)

    for ax in axs:
        ax.set_xlim(-range_lim, range_lim)
        ax.set_ylim(-range_lim, range_lim)
        ax.get_xaxis().set_visible(True)
        ax.get_yaxis().set_visible(True)
        ax.invert_yaxis()

    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)

In [21]:
# --------------------
# Plotting
# --------------------

def _save_sample_panels(points, path, range_lim=4):
    """Save a scatter plot and 2-D histogram of the first two columns of ``points`` to ``path``."""
    fig, axs = plt.subplots(1, 2, subplot_kw={'aspect': 'equal'})
    axs[0].scatter(points[:, 0], points[:, 1], s=10, alpha=0.4)
    axs[1].hist2d(points[:, 0], points[:, 1], bins=[range_lim * 50, range_lim * 50],
                  range=[[-range_lim, range_lim], [-range_lim, range_lim]], cmap=plt.cm.jet)
    for ax in axs:
        ax.set_xlim(-range_lim, range_lim)
        ax.set_ylim(-range_lim, range_lim)
        ax.get_xaxis().set_visible(True)
        ax.get_yaxis().set_visible(True)
        ax.invert_yaxis()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def plot_flow(base_dist, flow, filename, args, n_samples=10000, range_lim=4):
    """Save one figure per flow stage, showing how base samples move through the flow.

    Writes ``{filename}_0.png`` for the raw base samples (first two coordinates). It then writes
    ``{filename}_{i}.png`` for the output of the first ``i`` layers, for i = 1 .. len(flow).
    The same base samples are used for every stage.

    Args:
        base_dist: distribution to draw ``n_samples`` base points from.
        flow: ``nn.Sequential`` of layers returning ``(z, log_det)``.
        filename: output path prefix.
        args: unused; kept so existing callers keep working.
        n_samples: number of base samples.
        range_lim: axes span [-range_lim, range_lim].
    """
    z = base_dist.sample((n_samples,))
    _save_sample_panels(z.detach().cpu().numpy(), filename + '_0.png', range_lim)

    with torch.no_grad():
        # Iterate over the actual number of layers; slicing past len(flow) would just
        # repeat the full flow and write duplicate images.
        for i in range(1, len(flow) + 1):
            zk, _ = flow[0:i](z)
            _save_sample_panels(zk.detach().cpu().numpy(), filename + '_%d.png' % i, range_lim)

In [22]:
def plot_target_density(u_z, ax, n=200, output_dir=None, range_lim=4):
    """Draw the unnormalised target density exp(-u_z(z)) on an n x n grid into ``ax``.

    The grid spans [-range_lim, range_lim]^2. All axes of ``ax``'s figure are clipped to that
    range, hidden, and y-inverted. If ``output_dir`` is given, the figure is saved there as
    ``target_potential_density.png`` and closed.

    Args:
        u_z: potential mapping an (N, 2) tensor to N energies.
        ax: matplotlib Axes to draw into.
        n: grid resolution per axis.
        output_dir: optional directory to save the figure into.
        range_lim: half-width of the plotted square.
    """
    x1 = torch.linspace(-range_lim, range_lim, n)
    x2 = torch.linspace(-range_lim, range_lim, n)
    xx, yy = torch.meshgrid((x1, x2))
    # (n*n, 2) grid points; evaluated on CPU, where the potentials are device-agnostic.
    zz = torch.stack((xx.flatten(), yy.flatten()), dim=-1)
    with torch.no_grad():
        density = torch.exp(-u_z(zz)).view(n, n)
    ax.pcolormesh(xx.numpy(), yy.numpy(), density.numpy(), cmap=plt.cm.jet)

    fig = ax.figure
    for axis in fig.axes:
        axis.set_xlim(-range_lim, range_lim)
        axis.set_ylim(-range_lim, range_lim)
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        axis.invert_yaxis()

    if output_dir:
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, 'target_potential_density.png'))
        plt.close(fig)

In [23]:
def plot_flow_density(base_dist, flow, ax, range_lim=4, n=100, output_dir=None):
    """Draw the flow's approximate density q_k at the images of a 3-D base grid into ``ax``.

    An n x n x n grid over [-range_lim, range_lim]^3 is pushed through ``flow``, with
    log q_k = log q0(z) - sum log|det J|. The n**3 results are reshaped to an (n1, n1) mesh,
    where n1 = n**1.5, so ``n`` must be a perfect square (the default 100 gives 1000 x 1000).
    NOTE(review): this reshape does not correspond to a structured 2-D grid in output space,
    so the mesh cells are not geometrically meaningful. Consider a scatter/hist of q_k instead.

    If ``output_dir`` is given, the figure is saved there as ``flow_k{len(flow)-1}_density.png``
    and closed.

    Raises:
        ValueError: if ``n`` is not a perfect square.
    """
    n1 = int(round(n ** 1.5))
    if n1 * n1 != n ** 3:
        raise ValueError('n must be a perfect square so that n**3 points form a square mesh; got n={}'.format(n))

    x = torch.linspace(-range_lim, range_lim, n)
    xx, yy, hh = torch.meshgrid((x, x, x))
    zz = torch.stack((xx.flatten(), yy.flatten(), hh.flatten()), dim=-1).to(args.device)

    # n**3 points: evaluate without building an autograd graph.
    with torch.no_grad():
        zzk, sum_log_abs_det_jacobians = flow(zz)
        log_qk = base_dist.log_prob(zz) - sum_log_abs_det_jacobians
    qk = log_qk.exp().cpu()
    zzk = zzk.cpu()

    ax.pcolormesh(zzk[:, 0].view(n1, n1).numpy(), zzk[:, 1].view(n1, n1).numpy(),
                  qk.view(n1, n1).numpy(), cmap=plt.cm.jet)
    ax.set_facecolor(plt.cm.jet(0.))

    fig = ax.figure
    for axis in fig.axes:
        axis.set_xlim(-range_lim, range_lim)
        axis.set_ylim(-range_lim, range_lim)
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        axis.invert_yaxis()

    if output_dir:
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, 'flow_k{}_density.png'.format(len(flow) - 1)))
        plt.close(fig)

In [24]:
# Helper warps and the 2-D test potentials U(z); each maps an (N, 2) tensor to N values.
def w1(z):
    """Sinusoidal warp sin(2*pi*z_0 / 4) of the first coordinate."""
    return torch.sin(2 * math.pi * z[:, 0] / 4)


def w2(z):
    """Gaussian bump of height 3 centred at z_0 = 1 with width 0.6."""
    return 3 * torch.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)


def w3(z):
    """Sigmoid step of height 3 centred at z_0 = 1 with scale 0.3."""
    return 3 * torch.sigmoid((z[:, 0] - 1) / 0.3)


def u_z1(z):
    """Ring of radius 2 split into two modes at z_0 = +/-2 (1e-10 guards log(0))."""
    ring = 0.5 * ((torch.norm(z, p=2, dim=1) - 2) / 0.4) ** 2
    modes = torch.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) + torch.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) + 1e-10
    return ring - torch.log(modes)


def u_z2(z):
    """Single sinusoidal ridge z_1 = w1(z)."""
    return 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2


def u_z3(z):
    """Two sinusoidal ridges, the second shifted down by the bump w2."""
    offset = z[:, 1] - w1(z)
    first = torch.exp(-0.5 * (offset / 0.35) ** 2)
    second = torch.exp(-0.5 * ((offset + w2(z)) / 0.35) ** 2)
    return - torch.log(first + second + 1e-10)


def u_z4(z):
    """Two sinusoidal ridges, the second shifted down by the step w3."""
    offset = z[:, 1] - w1(z)
    first = torch.exp(-0.5 * (offset / 0.4) ** 2)
    second = torch.exp(-0.5 * ((offset + w3(z)) / 0.35) ** 2)
    return - torch.log(first + second + 1e-10)

In [25]:
# --------------------
# Run
# --------------------
class args:
    """Run configuration, used as a plain attribute namespace (``args.device`` is set at setup)."""
    # actions
    use_cuda = True  # whether to use cuda
    restore_file = None  # Path to a checkpoint to restore; falsy means train from scratch.
    output_dir = './results_test_v2_u_z1_k8'  # Path to output folder.
    train = True  # Train a flow.
    evaluate = True  # Evaluate a flow.
    plot = True  # Plot a flow and target density.

    # flow params
    learn_base = True  # Whether to learn a mu-sigma affine transform of the base distribution.
    flow_length = 8  # Length of the flow (number of planar layers).
    # Variance (not std) of the base isotropic 0-mean Gaussian: the base covariance is base_sigma * I.
    base_sigma = 1.0

    # target potential
    target_potential = 'u_z1'  # Which potential function to approximate.

    # training params
    seed = 2  # Random seed.
    init_sigma = 1.0  # Initialization std for the trainable flow parameters.
    batch_size = 100  # Batch size in training.
    start_step = 0  # Starting step (overwritten from the checkpoint when resuming training).
    n_steps = 1000000  # Optimization steps.
    lr = 1e-5  # Learning rate.
    weight_decay = 1e-3  # Weight decay.
    beta = 1.0  # Multiplier for the target potential loss.

In [26]:

# cuda to device
args.device = torch.device('cuda:0' if torch.cuda.is_available() and args.use_cuda else 'cpu')

# set up random seed (before any parameter is initialised, so runs are reproducible)
torch.manual_seed(args.seed)
if args.device.type == 'cuda':
    torch.cuda.manual_seed(args.seed)

# setup target potential to approx, looked up by name among the notebook's globals
u_z = globals()[args.target_potential]

# setup base distribution: 3-D zero-mean Gaussian with covariance base_sigma * I
base_dist = D.MultivariateNormal(torch.zeros(3).to(args.device), args.base_sigma * torch.eye(3).to(args.device))

# Read the checkpoint (if any) BEFORE building the flow: the filename determines flow_length,
# and the contents decide whether the flow starts with an affine layer.
flow_state = None
optimizer_state = None
if args.restore_file:
    filename = os.path.basename(args.restore_file)
    args.flow_length = int(filename.partition('length_')[-1].rpartition('.')[0])
    # reset output dir
    args.output_dir = os.path.dirname(args.restore_file)
    # NOTE: torch.load unpickles the file -- only restore checkpoints you trust.
    state = torch.load(args.restore_file, map_location=args.device)
    if 'flow_state' in state:
        flow_state = state['flow_state']
        optimizer_state = state.get('optimizer_state')
        args.start_step = state.get('step', args.start_step)
    else:
        # earlier models saved a bare flow state_dict without step / optimizer checkpoints
        flow_state = state

# some saved checkpoints may not have a first affine layer
has_affine_layer = flow_state is None or '0.mu' in flow_state
if has_affine_layer:
    flow = nn.Sequential(AffineTransform(args.learn_base), *[PlanarTransform(args.init_sigma) for _ in range(args.flow_length)])
else:
    # NOTE(review): without the 3->2 affine layer, the 2-D planar layers cannot consume samples
    # from the 3-D base_dist above -- confirm how such legacy checkpoints are meant to be used.
    flow = nn.Sequential(*[PlanarTransform(args.init_sigma) for _ in range(args.flow_length)])
flow = flow.to(args.device)

if flow_state is not None:
    flow.load_state_dict(flow_state)

In [27]:
# Create the output folder; exist_ok avoids the check-then-create race and makes re-runs a no-op.
os.makedirs(args.output_dir, exist_ok=True)

In [28]:
if args.train:
    optimizer = torch.optim.RMSprop(flow.parameters(), lr=args.lr, momentum=0.9, alpha=0.90, eps=1e-6, weight_decay=args.weight_decay)
    if args.restore_file and optimizer_state:
        optimizer.load_state_dict(optimizer_state)
    # Remember the requested step count once, so re-running this cell does not keep
    # extending args.n_steps (optimize_flow treats n_steps as the absolute end step).
    if not hasattr(args, 'n_steps_requested'):
        args.n_steps_requested = args.n_steps
    args.n_steps = args.start_step + args.n_steps_requested
    optimize_flow(base_dist, flow, u_z, optimizer, args)

u_z1: step     0 / 1000000; loss 5.305; base_log_prob -4.281, sum log dets -9.488, p_log_prob -0.098, max base = 0.062; max qk = 66246.328                 zk_mean [-0.00331786 -0.01490544], zk_sigma [0.11861719 0.8653099 ]; base_mu [3.1620839e-05 3.1622192e-05], base_log_sigma [[ 0.3922652  -0.22359563 -0.3195319 ]
 [-1.2050055   1.0444951  -0.63319606]]
u_z1: step 10000 / 1000000; loss 0.408; base_log_prob -4.215, sum log dets -3.780, p_log_prob -0.843, max base = 0.059; max qk = 170.470                 zk_mean [0.06626673 0.30173114], zk_sigma [1.7746239 1.2242069]; base_mu [0.04641005 0.00958483], base_log_sigma [[ 0.04715291 -0.2188428  -0.26065427]
 [-0.7391984   1.0915568  -0.56721616]]
u_z1: step 20000 / 1000000; loss 0.183; base_log_prob -4.335, sum log dets -3.685, p_log_prob -0.833, max base = 0.062; max qk = 112.701                 zk_mean [-0.26971352 -0.05520277], zk_sigma [1.8085682 1.2111324]; base_mu [0.01047704 0.00060594], base_log_sigma [[ 0.04624397 -0.21629277 -0.2

u_z1: step 240000 / 1000000; loss -0.061; base_log_prob -4.085, sum log dets -3.301, p_log_prob -0.723, max base = 0.062; max qk = 14.129                 zk_mean [-0.51264274 -0.12776595], zk_sigma [1.7876046 1.1394285]; base_mu [0.07512118 0.31316882], base_log_sigma [[ 0.0012201  -0.04646339 -0.21345651]
 [-0.01778726  0.95293266 -0.33602467]]
u_z1: step 250000 / 1000000; loss -0.360; base_log_prob -4.086, sum log dets -3.195, p_log_prob -0.531, max base = 0.062; max qk = 2.206                 zk_mean [-0.27137855  0.11972418], zk_sigma [1.8328567 1.1076787]; base_mu [0.07590809 0.30227453], base_log_sigma [[ 0.00103848 -0.0413519  -0.21463732]
 [-0.00941067  0.9576634  -0.326898  ]]
u_z1: step 260000 / 1000000; loss -0.134; base_log_prob -4.363, sum log dets -3.497, p_log_prob -0.732, max base = 0.063; max qk = 2.558                 zk_mean [-0.15139993 -0.09724956], zk_sigma [1.8240047 1.2092304]; base_mu [0.07551573 0.31101236], base_log_sigma [[ 8.2767947e-04 -2.7286815e-02 -2.13

u_z1: step 480000 / 1000000; loss -0.155; base_log_prob -4.181, sum log dets -3.342, p_log_prob -0.684, max base = 0.059; max qk = 1.930                 zk_mean [-0.19116299  0.08708123], zk_sigma [1.8255352 1.1891335]; base_mu [0.15654764 0.33950076], base_log_sigma [[ 7.0888072e-04  5.9756685e-02 -1.9000770e-01]
 [ 1.1161401e-04  8.6085081e-01  1.4137396e-01]]
u_z1: step 490000 / 1000000; loss -0.014; base_log_prob -4.170, sum log dets -3.403, p_log_prob -0.754, max base = 0.061; max qk = 8.889                 zk_mean [-0.2095502  -0.09908077], zk_sigma [1.7501527 1.1791031]; base_mu [0.15554747 0.32884005], base_log_sigma [[-0.00090421  0.06383959 -0.18862669]
 [-0.00889247  0.87046504  0.13162218]]
u_z1: step 500000 / 1000000; loss 0.009; base_log_prob -4.290, sum log dets -3.488, p_log_prob -0.811, max base = 0.058; max qk = 2.461                 zk_mean [-0.43407884  0.00327169], zk_sigma [1.7743345 1.1372582]; base_mu [0.15860909 0.31607693], base_log_sigma [[-0.00289655  0.0638

u_z1: step 720000 / 1000000; loss -0.204; base_log_prob -4.780, sum log dets -3.725, p_log_prob -0.851, max base = 0.056; max qk = 2.851                 zk_mean [-0.20553118  0.07021464], zk_sigma [1.8281556 1.2033936]; base_mu [0.19776157 0.32278547], base_log_sigma [[-0.00150504  0.09653237 -0.19075443]
 [ 0.00818001  0.75245553  0.13631377]]
u_z1: step 730000 / 1000000; loss -0.080; base_log_prob -4.353, sum log dets -3.499, p_log_prob -0.774, max base = 0.060; max qk = 4.382                 zk_mean [-0.00930209 -0.07165689], zk_sigma [1.867768  1.1406188]; base_mu [0.19484356 0.33676004], base_log_sigma [[-0.00189991  0.09942336 -0.18807277]
 [ 0.01158103  0.7525999   0.12335551]]
u_z1: step 740000 / 1000000; loss 0.036; base_log_prob -4.383, sum log dets -3.697, p_log_prob -0.722, max base = 0.063; max qk = 6.592                 zk_mean [-0.01512934 -0.10657523], zk_sigma [1.847671  1.2157375]; base_mu [0.19668245 0.32948905], base_log_sigma [[-3.5680342e-04  9.3985736e-02 -1.8913

u_z1: step 960000 / 1000000; loss 0.034; base_log_prob -4.247, sum log dets -3.508, p_log_prob -0.772, max base = 0.055; max qk = 2.067                 zk_mean [ 0.10073623 -0.02578577], zk_sigma [1.7772486 1.2428977]; base_mu [0.21340333 0.32521868], base_log_sigma [[-0.00088182  0.05656326 -0.21544845]
 [ 0.01038726  0.68266004 -0.08968596]]
u_z1: step 970000 / 1000000; loss 0.130; base_log_prob -4.323, sum log dets -3.620, p_log_prob -0.834, max base = 0.056; max qk = 3.421                 zk_mean [-0.14885548  0.04407089], zk_sigma [1.8535575 1.0973   ]; base_mu [0.21585815 0.3244979 ], base_log_sigma [[ 0.00214734  0.04696902 -0.21540356]
 [ 0.006734    0.6782986  -0.09425732]]
u_z1: step 980000 / 1000000; loss -0.172; base_log_prob -4.362, sum log dets -3.364, p_log_prob -0.827, max base = 0.059; max qk = 1.638                 zk_mean [-0.16223717  0.0739743 ], zk_sigma [1.8198909 1.0746362]; base_mu [0.21332502 0.32304028], base_log_sigma [[-0.00274083  0.05034275 -0.21589331]
 

In [29]:
if args.evaluate:
    # Plotting only needs forward passes; no_grad avoids building an autograd graph per flow stage.
    with torch.no_grad():
        plot_flow(base_dist, flow, os.path.join(args.output_dir, 'approximating_flow'), args)

In [30]:
if args.plot:
    # plot_flow_density pushes an n**3 grid (10**6 points by default) through the flow;
    # evaluating it without autograd keeps memory bounded.
    with torch.no_grad():
        plot_target_density(u_z, plt.gca(), output_dir=args.output_dir)
        plot_flow_density(base_dist, flow, plt.gca(), output_dir=args.output_dir)