In [None]:
import argparse
import os
import sys
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as tforms
from torchvision.utils import save_image


""" PROXIMAL IMPORTS """
sys.path.append('../../')
from torchdiffeq import odeint_adjoint as odeint
from proxNode.lib.prox.adjoint import odeint_adjoint as odeint_prox



""" FFJORD IMPORTS """
from .wrappers.cnf_regularization import RegularizedODEfunc

import proxNode.lib.layers as layers
import proxNode.lib.utils as utils
import proxNode.lib.odenvp as odenvp
import proxNode.lib.multiscale_parallel as multiscale_parallel

from train_misc import standard_normal_logprob
from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time
from train_misc import add_spectral_norm, spectral_norm_power_iteration
from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log

__all__ = ["CNF"]

**Manual Adjustment of Proximal Solver in the Forward Block Below**

In [None]:
class CNF(nn.Module):
    """Continuous normalizing flow layer that integrates ``odefunc`` over time.

    The integration end time is stored as ``sqrt_end_time`` (the square root
    of ``T``) and squared again in ``forward``. Two solver back-ends are
    supported: ``'dopri5'`` (torchdiffeq ``odeint``) and ``'proximal'``
    (``odeint_prox``).

    Args:
        odefunc: module called by the solver; must expose ``before_odeint()``
            and ``_num_evals``.
        T: integration end time.
        train_T: when true ``sqrt_end_time`` is registered as a parameter,
            otherwise as a buffer.
        regularization_fns: optional sequence of regularization functions; when
            given, ``odefunc`` is wrapped in ``RegularizedODEfunc`` and one
            extra scalar state per function is integrated during training.
        solver: ``'dopri5'`` or ``'proximal'``.
        atol, rtol: tolerances passed to the dopri5 solver. The ``test_*``
            attributes start out equal to the training values.
    """

    def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solver='dopri5', atol=1e-5, rtol=1e-5):
        super(CNF, self).__init__()
        if train_T:
            self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))
        else:
            self.register_buffer("sqrt_end_time", torch.sqrt(torch.tensor(T)))

        nreg = 0
        if regularization_fns is not None:
            odefunc = RegularizedODEfunc(odefunc, regularization_fns)
            nreg = len(regularization_fns)
        self.odefunc = odefunc
        self.nreg = nreg
        self.regularization_states = None
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}

    def forward(self, z, logpz=None, integration_times=None, reverse=False):
        """Integrate ``(z, logpz)`` through the flow.

        Args:
            z: input state; its first dimension is used as the batch size.
            logpz: optional log-density accumulator; zeros of shape
                ``(z.shape[0], 1)`` are used when omitted.
            integration_times: optional 1-D tensor of time points; defaults to
                ``[0, sqrt_end_time ** 2]``.
            reverse: when true the time points are traversed in reverse order.

        Returns:
            ``z_t`` when ``logpz`` was not supplied, otherwise
            ``(z_t, logpz_t)``.

        Raises:
            ValueError: if ``self.solver`` is neither ``'dopri5'`` nor
                ``'proximal'``.
        """
        if logpz is None:
            _logpz = torch.zeros(z.shape[0], 1).to(z)
        else:
            _logpz = logpz

        if integration_times is None:
            # NOTE(review): torch.tensor(...) builds a fresh leaf tensor, so this
            # end time has no autograd link back to sqrt_end_time even when
            # train_T=True -- confirm whether that is intended.
            integration_times = torch.tensor([0.0, self.sqrt_end_time * self.sqrt_end_time]).to(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        # Regularization states are only integrated while training; evaluation
        # integrates the plain (z, logpz) pair for every solver.
        if self.training:
            reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))
            states = (z, _logpz) + reg_states
        else:
            states = (z, _logpz)
        times = integration_times.to(z)

        # NOTE(review): the back-end is chosen from self.solver in both modes;
        # self.test_solver is only forwarded as the dopri5 `method` in eval.
        if self.solver == 'dopri5':
            if self.training:
                state_t = odeint(
                    self.odefunc,
                    states,
                    times,
                    atol=self.atol,
                    rtol=self.rtol,
                    method=self.solver,
                    options=self.solver_options,
                )
            else:
                state_t = odeint(
                    self.odefunc,
                    states,
                    times,
                    atol=self.test_atol,
                    rtol=self.test_rtol,
                    method=self.test_solver,
                )
        elif self.solver == 'proximal':  #### PROXIMAL ADJUSTMENTS HERE
            # The proximal keywords below belong to odeint_prox, not to the
            # torchdiffeq odeint used in the dopri5 branch.
            state_t = odeint_prox(
                self.odefunc,
                states,
                times,
                opt_method='fletch_reeves',
                prox_method='bdf4',
                int_step=.2,
                opt_step=.3,
                adjoint_options={
                    # Follow the input's device instead of assuming 'cuda:0'.
                    'device': str(z.device),
                    'opt_iters': 60,
                    'tol': .01,
                    'grad_tol': .01,
                },
            )
        else:
            raise ValueError('CNF solver not in [dopri5,proximal]')

        # With exactly two time points only the final state is of interest.
        if len(integration_times) == 2:
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t = state_t[:2]
        self.regularization_states = state_t[2:]

        if logpz is not None:
            return z_t, logpz_t
        else:
            return z_t

    def get_regularization_states(self):
        """Return the regularization states of the last ``forward`` and clear them."""
        reg_states = self.regularization_states
        self.regularization_states = None
        return reg_states

    def num_evals(self):
        """Return the evaluation counter kept by ``odefunc`` as a Python number."""
        return self.odefunc._num_evals.item()


def _flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]


In [None]:
# Command-line configuration for the CNF training run, followed by logger setup.
# Ask cuDNN to benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True
SOLVERS = ["dopri5", 'proximal']
parser = argparse.ArgumentParser("Continuous Normalizing Flow")
# NOTE(review): get_dataset also has a 'celeba' branch, but it is not selectable
# through these choices -- confirm whether it should be listed here.
parser.add_argument("--data", choices=["mnist", "svhn", "cifar10", 'lsun_church'], type=str, default="mnist")
# --dims and --strides are comma-separated integer lists (split in create_model).
parser.add_argument("--dims", type=str, default="64,64,64")
parser.add_argument("--strides", type=str, default="1,1,1,1")
parser.add_argument("--num_blocks", type=int, default=2, help='Number of stacked CNFs.')

# NOTE(review): every `type=eval` option below passes the raw command-line string
# to Python's eval(); only run this with trusted arguments.
parser.add_argument("--conv", type=eval, default=True, choices=[True, False])
parser.add_argument(
    "--layer_type", type=str, default="concat",
    choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"]
)
parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"])
parser.add_argument(
    "--nonlinearity", type=str, default="softplus", choices=["tanh", "relu", "softplus", "elu", "swish"]
)
# Training-time ODE solver settings.
parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
parser.add_argument('--atol', type=float, default=1e-5)
parser.add_argument('--rtol', type=float, default=1e-7)
parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.")

# Evaluation-time solver settings; all default to None.
parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None])
parser.add_argument('--test_atol', type=float, default=None)
parser.add_argument('--test_rtol', type=float, default=None)

parser.add_argument("--imagesize", type=int, default=None)
parser.add_argument("--alpha", type=float, default=1e-6)
parser.add_argument('--time_length', type=float, default=1.0)
# Unlike the other boolean flags, --train_T has no `choices` restriction.
parser.add_argument('--train_T', type=eval, default=True)

# Optimisation schedule.
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=200)
parser.add_argument(
    "--batch_size_schedule", type=str, default="", help="Increases the batchsize at every given epoch, dash separated."
)
parser.add_argument("--test_batch_size", type=int, default=200)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--warmup_iters", type=float, default=1000)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--spectral_norm_niter", type=int, default=10)

# Boolean model/data switches (parsed with eval, see note above).
parser.add_argument("--add_noise", type=eval, default=True, choices=[True, False])
parser.add_argument("--batch_norm", type=eval, default=False, choices=[True, False])
parser.add_argument('--residual', type=eval, default=False, choices=[True, False])
parser.add_argument('--autoencode', type=eval, default=False, choices=[True, False])
parser.add_argument('--rademacher', type=eval, default=True, choices=[True, False])
parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False])
parser.add_argument('--multiscale', type=eval, default=True, choices=[True, False])
parser.add_argument('--parallel', type=eval, default=False, choices=[True, False])

# Regularizations
parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1")
parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2")
parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2")
parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F")
parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F")
parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F")

parser.add_argument("--time_penalty", type=float, default=0, help="Regularization on the end_time.")
parser.add_argument(
    "--max_grad_norm", type=float, default=1e10,
    help="Max norm of graidents (default is just stupidly high to avoid any clipping)"
)

# Checkpointing, resuming and logging cadence.
parser.add_argument("--begin_epoch", type=int, default=1)
parser.add_argument("--resume", type=str, default=None)
parser.add_argument("--save", type=str, default="experiments/cnf")
parser.add_argument("--val_freq", type=int, default=1)
parser.add_argument("--log_freq", type=int, default=10)

# NOTE(review): parse_args() reads sys.argv. Inside a Jupyter kernel sys.argv holds
# the kernel's own arguments -- confirm this cell is meant to run as a script.
args = parser.parse_args()

# Create the output directory and a logger that writes to <save>/logs.
utils.makedirs(args.save)
# NOTE(review): __file__ is normally undefined in notebook cells -- confirm.
logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

# Blend layers force the integration time length to 1.0.
if args.layer_type == "blend":
    logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.")
    args.time_length = 1.0

logger.info(args)


In [None]:

def add_noise(x):
    """
    Dequantize ``x``: [0, 1] -> [0, 255] -> add uniform noise in [0, 1) -> divide by 256.

    ``x`` is returned untouched when the global ``args.add_noise`` flag is off.
    """
    if not args.add_noise:
        return x
    jitter = torch.rand_like(x)
    return (x * 255 + jitter) / 256


def update_lr(optimizer, itr):
    """Apply linear learning-rate warm-up to every parameter group.

    The rate grows with ``(itr + 1) / max(args.warmup_iters, 1)`` and is capped
    at ``args.lr`` once that fraction reaches 1.
    """
    warmup_span = max(args.warmup_iters, 1)
    scale = min(float(itr + 1) / warmup_span, 1.0)
    new_lr = args.lr * scale
    for group in optimizer.param_groups:
        group["lr"] = new_lr


def get_train_loader(train_set, epoch):
    """Build the shuffled training ``DataLoader`` for ``epoch``.

    With a non-empty ``args.batch_size_schedule`` (dash-separated epochs) the
    batch size is ``args.batch_size`` multiplied by the number of milestones,
    including an implicit epoch 0, that are ``<= epoch``.
    """
    batch_size = args.batch_size
    if args.batch_size_schedule != "":
        milestones = [0] + [int(tok) for tok in args.batch_size_schedule.split("-")]
        n_reached = sum(1 for milestone in milestones if milestone <= epoch)
        batch_size = int(args.batch_size * n_reached)

    loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )
    logger.info("===> Using batch size {}. Total {} iterations/epoch.".format(batch_size, len(loader)))
    return loader


def get_dataset(args):
    """Create the datasets selected by ``args.data``.

    Args:
        args: namespace providing ``data``, ``imagesize``, ``conv`` and
            ``test_batch_size``.

    Returns:
        ``(train_set, test_loader, data_shape)`` where ``data_shape`` is
        ``(channels, size, size)``, or the flattened ``(channels*size*size,)``
        when ``args.conv`` is false.

    Raises:
        ValueError: if ``args.data`` names no known dataset.
    """

    def pipeline(*pre_steps):
        # Every pipeline ends with tensor conversion followed by add_noise.
        return tforms.Compose([*pre_steps, tforms.ToTensor(), add_noise])

    def pick_size(default):
        # An explicit --imagesize overrides the dataset's native size.
        return default if args.imagesize is None else args.imagesize

    if args.data == "mnist":
        im_dim = 1
        im_size = pick_size(28)
        train_set = dset.MNIST(
            root="./data", train=True, transform=pipeline(tforms.Resize(im_size)), download=True
        )
        test_set = dset.MNIST(
            root="./data", train=False, transform=pipeline(tforms.Resize(im_size)), download=True
        )
    elif args.data == "svhn":
        im_dim = 3
        im_size = pick_size(32)
        train_set = dset.SVHN(
            root="./data", split="train", transform=pipeline(tforms.Resize(im_size)), download=True
        )
        test_set = dset.SVHN(
            root="./data", split="test", transform=pipeline(tforms.Resize(im_size)), download=True
        )
    elif args.data == "cifar10":
        im_dim = 3
        im_size = pick_size(32)
        train_set = dset.CIFAR10(
            root="./data", train=True, download=True,
            transform=pipeline(tforms.Resize(im_size), tforms.RandomHorizontalFlip()),
        )
        test_set = dset.CIFAR10(
            root="./data", train=False, transform=pipeline(tforms.Resize(im_size)), download=True
        )
    elif args.data == 'celeba':
        im_dim = 3
        im_size = pick_size(64)
        # torchvision's CelebA takes `root` and `split` (there is no `train`
        # keyword) and already yields PIL images, so no ToPILImage step.
        # NOTE(review): download is not requested, so the files are presumably
        # expected to be present under ./data already -- confirm.
        train_set = dset.CelebA(
            root="./data", split="train",
            transform=pipeline(tforms.Resize(im_size), tforms.RandomHorizontalFlip()),
        )
        test_set = dset.CelebA(
            root="./data", split="test", transform=pipeline(tforms.Resize(im_size))
        )
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = pick_size(64)
        # Resize to 96, take a random 64x64 crop, then resize to the target size.
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'],
            transform=pipeline(tforms.Resize(96), tforms.RandomCrop(64), tforms.Resize(im_size)),
        )
        test_set = dset.LSUN(
            'data', ['church_outdoor_val'],
            transform=pipeline(tforms.Resize(96), tforms.RandomCrop(64), tforms.Resize(im_size)),
        )
    else:
        # Fail clearly instead of hitting an unbound `im_dim` further down.
        raise ValueError("Unknown dataset: {!r}".format(args.data))

    data_shape = (im_dim, im_size, im_size)
    if not args.conv:
        data_shape = (im_dim * im_size * im_size,)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.test_batch_size, shuffle=False, drop_last=True
    )
    return train_set, test_loader, data_shape


def compute_bits_per_dim(x, model):
    """Return the bits-per-dimension of batch ``x`` under ``model``.

    ``model`` is called as ``model(x, zeros)`` and must return ``(z, delta_logp)``.
    The data log-likelihood is the standard-normal log-probability of ``z``
    (summed per sample) minus ``delta_logp``. Its mean per element is offset
    by ``log(256)`` and converted from nats to bits.
    """
    init_delta = torch.zeros(x.shape[0], 1).to(x)

    # Run the flow forward to obtain latents and the accumulated log-density change.
    z, delta_logp = model(x, init_delta)

    prior_logp = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)
    data_logp = prior_logp - delta_logp

    # Average over every element of the batch.
    mean_logp = torch.sum(data_logp) / x.nelement()
    return -(mean_logp - np.log(256)) / np.log(2)


def create_model(args, data_shape, regularization_fns):
    """Build the flow model selected by the command-line options.

    * ``args.multiscale`` -> ``odenvp.ODENVP``
    * ``args.parallel``   -> ``multiscale_parallel.MultiscaleParallelCNF``
    * otherwise           -> a ``layers.SequentialFlow`` made of an input
      transform, ``args.num_blocks`` CNF blocks and, optionally, a moving
      batch-norm layer.

    Args:
        args: parsed options (see the argument parser).
        data_shape: shape of one sample, without the batch dimension.
        regularization_fns: regularization functions handed to each CNF (may be None).

    Returns:
        The constructed model.
    """
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    if args.multiscale:
        return odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={"T": args.time_length, "train_T": args.train_T, "regularization_fns": regularization_fns},
        )

    if args.parallel:
        return multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )

    # Keyword sets shared by the autoencoder and the plain variants.
    net_kwargs = dict(
        hidden_dims=hidden_dims,
        input_shape=data_shape,
        strides=strides,
        conv=args.conv,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    func_kwargs = dict(
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )

    def build_cnf():
        # One CNF block; only the dynamics network differs between variants.
        if args.autoencode:
            odefunc = layers.AutoencoderODEfunc(
                autoencoder_diffeq=layers.AutoencoderDiffEqNet(**net_kwargs),
                **func_kwargs
            )
        else:
            odefunc = layers.ODEfunc(diffeq=layers.ODEnet(**net_kwargs), **func_kwargs)
        # train_T is forwarded for both variants so --train_T is honoured
        # with --autoencode as well.
        return layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    # alpha > 0 selects the logit transform, otherwise zero-mean centring.
    chain = [layers.LogitTransform(alpha=args.alpha)] if args.alpha > 0 else [layers.ZeroMeanTransform()]
    chain = chain + [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    return layers.SequentialFlow(chain)



In [None]:
# Select the device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# cvt casts a tensor to float32 and moves it to `device`.
cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

# load dataset
train_set, test_loader, data_shape = get_dataset(args)

# build model
regularization_fns, regularization_coeffs = create_regularization_fns(args)
model = create_model(args, data_shape, regularization_fns)

if args.spectral_norm: add_spectral_norm(model, logger)
set_cnf_options(args, model)

logger.info(model)
logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

# optimizer
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# restore parameters
if args.resume is not None:
    # The map_location callable returns the storage unchanged, i.e. tensors
    # are loaded onto the CPU first.
    # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
    checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpt["state_dict"])
    if "optim_state_dict" in checkpt.keys():
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        # Manually move optimizer state to device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = cvt(v)

# Wrap in DataParallel whenever CUDA is available; the checkpoint code later
# unwraps it again through `model.module`.
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model).cuda()

# For visualization.
# A fixed batch of 100 latent samples reused for the sample images of every epoch.
fixed_z = cvt(torch.randn(100, *data_shape))

# Running averages (constructed with 0.97) for the training log.
time_meter = utils.RunningAverageMeter(0.97)
loss_meter = utils.RunningAverageMeter(0.97)
steps_meter = utils.RunningAverageMeter(0.97)
bkwd_steps_meter = utils.RunningAverageMeter(0.97)
grad_meter = utils.RunningAverageMeter(0.97)
tt_meter = utils.RunningAverageMeter(0.97)

# Only run the 500 initial power iterations on a fresh (non-resumed) run.
if args.spectral_norm and not args.resume: spectral_norm_power_iteration(model, 500)

# Main loop: one training pass per epoch, periodic validation with
# checkpointing of the best model, and a grid of generated samples.
best_loss = float("inf")
itr = 0
for epoch in range(args.begin_epoch, args.num_epochs + 1):
    model.train()
    train_loader = get_train_loader(train_set, epoch)
    for x, _ in train_loader:
        start = time.time()
        update_lr(optimizer, itr)
        optimizer.zero_grad()

        if not args.conv:
            x = x.view(x.shape[0], -1)

        # cast data and move to device
        x = cvt(x)
        # compute loss
        loss = compute_bits_per_dim(x, model)
        if regularization_coeffs:
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0
            )
            loss = loss + reg_loss
        total_time = count_total_time(model)
        loss = loss + total_time * args.time_penalty

        # Function evaluations counted so far belong to the forward pass.
        fwd_steps = count_nfe(model)

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        optimizer.step()

        if args.spectral_norm:
            spectral_norm_power_iteration(model, args.spectral_norm_niter)

        # The count after backward includes the backward-pass evaluations.
        total_steps = count_nfe(model)
        time_meter.update(time.time() - start)
        loss_meter.update(loss.item())
        steps_meter.update(fwd_steps)
        bkwd_steps_meter.update(total_steps - fwd_steps)
        # Feed plain floats to the meters. total_time presumably derives from
        # model tensors (TODO confirm in train_misc); storing it as a tensor
        # would keep autograd history alive inside the running average.
        grad_meter.update(float(grad_norm))
        tt_meter.update(float(total_time))

        if itr % args.log_freq == 0:
            log_message = (
                "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | "
                "NFE-F {:.0f}({:.2f}) | NFE-B {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time {:.2f}({:.2f})".format(
                    itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, steps_meter.val,
                    steps_meter.avg, bkwd_steps_meter.val, bkwd_steps_meter.avg, grad_meter.val, grad_meter.avg, tt_meter.val, tt_meter.avg
                )
            )
            if regularization_coeffs:
                log_message = append_regularization_to_log(log_message, regularization_fns, reg_states)
            logger.info(log_message)

        itr += 1

    # compute test loss
    model.eval()
    if epoch % args.val_freq == 0:
        with torch.no_grad():
            start = time.time()
            logger.info("validating...")
            losses = []
            for x, _ in test_loader:
                if not args.conv:
                    x = x.view(x.shape[0], -1)
                x = cvt(x)
                # Store Python floats rather than (possibly CUDA) tensors.
                losses.append(compute_bits_per_dim(x, model).item())

            # An empty test loader yields NaN, which never beats best_loss.
            loss = float(np.mean(losses)) if losses else float("nan")
            logger.info("Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}".format(epoch, time.time() - start, loss))
            if loss < best_loss:
                best_loss = loss
                utils.makedirs(args.save)
                torch.save({
                    "args": args,
                    # The model is wrapped in DataParallel exactly when CUDA is available.
                    "state_dict": model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict": optimizer.state_dict(),
                }, os.path.join(args.save, "checkpt.pth"))

    # visualize samples and density
    with torch.no_grad():
        fig_filename = os.path.join(args.save, "figs", "{:04d}.jpg".format(epoch))
        utils.makedirs(os.path.dirname(fig_filename))
        # Run the flow in reverse on the fixed latent batch to generate images.
        generated_samples = model(fixed_z, reverse=True).view(-1, *data_shape)
        save_image(generated_samples, fig_filename, nrow=10)