In [1]:
import comet_ml
from comet_ml import API
from comet_ml import Experiment

import os

# Read the Comet API key from the environment instead of hard-coding it in the notebook.
# NOTE(review): the key previously committed in this cell should be revoked/rotated.
# If COMET_API_KEY is unset, None is passed; per the comet_ml docs the library then falls
# back to its own configuration lookup — confirm for the installed version.
experiment = Experiment(os.environ.get("COMET_API_KEY"), project_name="diff_sim_ffjord", workspace="schattengenie")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/schattengenie/diff-sim-ffjord/b2852c1881f24e158b30cb86dea8b948



In [2]:
from model import YModel

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch

import numpy as np
import pandas as pd
from tqdm import trange
%pylab inline

import matplotlib.pyplot as plt
import seaborn as sns

Populating the interactive namespace from numpy and matplotlib


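## Recovering conditional density with FFJORD

Note: in its current form the notebook fits an *unconditional* density p(hit_x, hit_y) on the `magn_len == 8` subset only; the conditioning inputs (energy, theta, phi, mu) are not yet fed to the flow.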

# All needed imports

In [3]:
# Prefer the second GPU (the original configuration), but fall back to the first GPU or
# the CPU so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
!pip install git+https://github.com/rtqichen/torchdiffeq.git

Collecting git+https://github.com/rtqichen/torchdiffeq.git
  Cloning https://github.com/rtqichen/torchdiffeq.git to /tmp/pip-req-build-5jredbvr
  Running command git clone -q https://github.com/rtqichen/torchdiffeq.git /tmp/pip-req-build-5jredbvr
Building wheels for collected packages: torchdiffeq
  Building wheel for torchdiffeq (setup.py) ... [?25ldone
[?25h  Stored in directory: /tmp/pip-ephem-wheel-cache-fu7w4n5u/wheels/f1/89/ce/78b4c1aabbb8dad56a2dbd776f9ffcbeca103b2ddae40d094b
Successfully built torchdiffeq


In [5]:
import sys
import torch
import torch.optim as optim
from IPython.display import clear_output

sys.path.append('./ffjord/')
import ffjord.lib.utils as utils
from ffjord.lib.visualize_flow import visualize_transform
import ffjord.lib.layers.odefunc as odefunc
from ffjord.train_misc import standard_normal_logprob
from ffjord.train_misc import count_nfe, count_parameters, count_total_time
from ffjord.train_misc import add_spectral_norm, spectral_norm_power_iteration
from ffjord.train_misc import create_regularization_fns, get_regularization, append_regularization_to_log
from ffjord.train_misc import build_model_tabular
import lib.layers as layers

In [6]:
# ODE solver names that can be passed as the CNF `solver` option.
SOLVERS = [
    "dopri5", "bdf", "rk4", "midpoint",
    'adams', 'explicit_adams', 'fixed_adams',
]

# Available activation names (presumably the keys usable as ODEnet's `nonlinearity`).
print(odefunc.NONLINEARITIES)

{'tanh': Tanh(), 'relu': ReLU(), 'softplus': Softplus(beta=1, threshold=20), 'elu': ELU(alpha=1.0), 'swish': Swish(), 'square': Lambda(), 'identity': Lambda()}


In [7]:
def set_cnf_options(model, solver, rademacher, residual, atol=1e-4, rtol=1e-4):
    """Push solver / ODE-function settings into every matching submodule of `model`.

    Every `layers.CNF` submodule gets `solver`, `atol` and `rtol`. For the fixed-grid
    Adams variants it also gets `solver_options['max_order'] = 4`. Every `layers.ODEfunc`
    submodule gets `rademacher` and `residual`. The model is modified in place via
    `model.apply`; nothing is returned.
    """
    def _configure(submodule):
        if isinstance(submodule, layers.CNF):
            submodule.solver = solver
            submodule.atol, submodule.rtol = atol, rtol
            # Keep the Adams order bounded for the fixed-grid / explicit variants.
            if solver in ('fixed_adams', 'explicit_adams'):
                submodule.solver_options['max_order'] = 4

        if isinstance(submodule, layers.ODEfunc):
            submodule.rademacher = rademacher
            submodule.residual = residual

    model.apply(_configure)
    
# layer_type - ["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"]
def build_model_tabular(dims=2,
                        layer_type='concatsquash',
                        nonlinearity='relu',
                        residual=False,
                        rademacher=False,
                        train_T=True,
                        solver='dopri5',
                        time_length=0.1,
                        divergence_fn='approximate',  # ["brute_force", "approximate"]
                        hidden_dims=(32, 32),
                        num_blocks=1, batch_norm=False,
                        bn_lag=0, regularization_fns=None,
                        atol=1e-4, rtol=1e-4):
    """Build an FFJORD model for flat `dims`-dimensional data.

    The model is a `layers.SequentialFlow` of `num_blocks` CNF blocks. Each block is an
    ODEnet MLP (`hidden_dims`, `layer_type`, `nonlinearity`) wrapped in an ODEfunc
    (`divergence_fn`, `residual`, `rademacher`), which is wrapped in a CNF
    (`T=time_length`, `train_T`, `regularization_fns`, `solver`). With `batch_norm=True`,
    a MovingBatchNorm1d layer is placed before the first block and after every block.

    Args:
        atol, rtol: solver tolerances forwarded to `set_cnf_options`. The defaults match
            the values that were previously always used (1e-4).

    Returns:
        The assembled `layers.SequentialFlow`, with solver options already applied.
    """

    def build_cnf():
        diffeq = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims,),
            strides=None,
            conv=False,
            layer_type=layer_type,
            nonlinearity=nonlinearity,
        )
        # Named ode_func so it is not confused with the imported `odefunc` module.
        ode_func = layers.ODEfunc(
            diffeq=diffeq,
            divergence_fn=divergence_fn,
            residual=residual,
            rademacher=rademacher,
        )
        return layers.CNF(
            odefunc=ode_func,
            T=time_length,
            train_T=train_T,
            regularization_fns=regularization_fns,
            solver=solver,
        )

    chain = [build_cnf() for _ in range(num_blocks)]
    if batch_norm:
        # Leading BN layer, then alternate CNF block / BN layer.
        bn_layers = [layers.MovingBatchNorm1d(dims, bn_lag=bn_lag) for _ in range(num_blocks)]
        bn_chain = [layers.MovingBatchNorm1d(dims, bn_lag=bn_lag)]
        for cnf, bn in zip(chain, bn_layers):
            bn_chain.append(cnf)
            bn_chain.append(bn)
        chain = bn_chain
    model = layers.SequentialFlow(chain)

    set_cnf_options(model, solver, rademacher, residual, atol=atol, rtol=rtol)

    return model


In [8]:
import ffjord.lib.layers.wrappers.cnf_regularization as reg_lib
import six

# Regularizer key -> ffjord regularization function. The insertion order here fixes the
# order of the tuples produced by create_regularization_fns.
REGULARIZATION_FNS = dict(
    l1int=reg_lib.l1_regularzation_fn,
    l2int=reg_lib.l2_regularzation_fn,
    dl2int=reg_lib.directional_l2_regularization_fn,
    JFrobint=reg_lib.jacobian_frobenius_regularization_fn,
    JdiagFrobint=reg_lib.jacobian_diag_frobenius_regularization_fn,
    JoffdiagFrobint=reg_lib.jacobian_offdiag_frobenius_regularization_fn,
)

def create_regularization_fns(regs=None):
    """Select the CNF regularizers named in `regs`.

    Args:
        regs: mapping from regularizer key (a key of REGULARIZATION_FNS, e.g. 'l1int',
            'JFrobint') to its loss coefficient. Keys that are not in REGULARIZATION_FNS
            are ignored. Defaults to {'l1int': 1., 'JFrobint': 1.}.

    Returns:
        (regularization_fns, regularization_coeffs): two tuples of equal length, ordered
        like REGULARIZATION_FNS (not like `regs`).
    """
    if regs is None:
        # Built per call, so no mutable dict is shared between calls via the default.
        regs = {'l1int': 1., 'JFrobint': 1.}

    selected = [(reg_fn, regs[key]) for key, reg_fn in REGULARIZATION_FNS.items() if key in regs]
    regularization_fns = tuple(reg_fn for reg_fn, _ in selected)
    regularization_coeffs = tuple(coeff for _, coeff in selected)
    return regularization_fns, regularization_coeffs


def get_regularization(model, regularization_coeffs):
    """Sum the regularization states of all CNF submodules of `model`, term by term.

    Returns None when `regularization_coeffs` is empty. Otherwise returns a tuple with one
    accumulated value per coefficient; it is all zeros if `model` has no CNF submodules.
    """
    n_terms = len(regularization_coeffs)
    if n_terms == 0:
        return None

    totals = tuple([0.] * n_terms)
    cnf_modules = (m for m in model.modules() if isinstance(m, layers.CNF))
    for cnf in cnf_modules:
        states = cnf.get_regularization_states()
        totals = tuple(running + term for running, term in zip(totals, states))
    return totals

In [9]:
import warnings
warnings.filterwarnings("ignore")

In [10]:
def get_transforms(model):
    """Return `(sample_fn, density_fn)` closures around a flow `model`.

    `sample_fn(z, logpz=None)` calls the model with `reverse=True`, and
    `density_fn(x, logpx=None)` calls it with `reverse=False`. The log-density argument
    is forwarded positionally only when it is not None.
    """
    def _run(inputs, logp, reverse):
        if logp is None:
            return model(inputs, reverse=reverse)
        return model(inputs, logp, reverse=reverse)

    def sample_fn(z, logpz=None):
        return _run(z, logpz, reverse=True)

    def density_fn(x, logpx=None):
        return _run(x, logpx, reverse=False)

    return sample_fn, density_fn

# Loading the training dataset

In [11]:
# Only simulations with this magnet length are used for training.
MAGN_LEN = 8

df = (
    pd.read_csv('./simple_surr.csv')
    # 'Unnamed: 0' is an unnamed leading column (presumably an index written by to_csv).
    .drop(columns=['Unnamed: 0'])
    .loc[lambda frame: frame.magn_len == MAGN_LEN]
)

In [12]:
def cart2sph(x, y, z):
    """Convert Cartesian coordinates to spherical ones (element-wise, NumPy-broadcast).

    Returns:
        (azimuth, polar, radius). `azimuth` is atan2(y, x). `polar` is the angle from the
        +z axis, i.e. pi/2 minus the elevation. `radius` is the Euclidean norm.
    """
    rho = np.hypot(x, y)              # distance from the z axis
    radius = np.hypot(rho, z)
    polar = np.pi / 2 - np.arctan2(z, rho)
    azimuth = np.arctan2(y, x)
    return azimuth, polar, radius
# Initial-direction columns plus a spherical decomposition of start_px/py/pz.
# NOTE(review): cart2sph's second output is pi/2 - elevation (a polar angle from +z), so
# `el` holds a polar angle. None of these five names is used by the cells shown here.
start_theta, start_phi = df.start_theta.values, df.start_phi.values
az, el, r = cart2sph(*(df[col].values for col in ('start_px', 'start_py', 'start_pz')))

In [13]:
# Training targets: (hit_x, hit_y) for every row of df, as a NumPy array.
# NOTE(review): an `init_cond` array (sqrt(start_px^2 + start_py^2 + start_pz^2) plus
# start_theta / start_phi) was sketched here but is never actually built.
result = df[['hit_x', 'hit_y']].to_numpy()

In [14]:
# Distribution of the raw (unscaled) hit_x coordinate.
plt.hist(result[:, 0], bins=100)
plt.xlabel('hit_x')
plt.ylabel('count');

In [15]:
# Shape of the training targets (`init_cond` is not defined anywhere in this notebook).
result.shape

In [None]:
# Hit coordinates are divided by HIT_SCALE before training; multiply flow samples by it to
# get back to the original units.
HIT_SCALE = 500.

# Built directly from `df`, so re-running this cell does not depend on the generic global
# `result`, which may have been rebound by later cells.
data = torch.as_tensor(
    df[['hit_x', 'hit_y']].to_numpy() / HIT_SCALE,
    dtype=torch.float32,
    device=device,
)

In [None]:
# Per-column standard deviation of the scaled training data.
torch.std(data, dim=0)

# Defining the FFJORD model

In [None]:
# SOLVERS is defined once, in the setup cell near the top of the notebook; it is not
# redefined here. Valid `layer_type` values for build_model_tabular:
# "ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend".

In [None]:
# Regularization is disabled. To enable it, use
# `regularization_fns, regularization_coeffs = create_regularization_fns()` instead; note that
# compute_loss does not add any regularization term to the loss.
regularization_fns = None

model = build_model_tabular(
    dims=data.size(1),
    layer_type='concatsquash',
    hidden_dims=(32, 32, 32),
    nonlinearity='tanh',
    num_blocks=2,
    time_length=.5,
    solver='rk4',
    rademacher=False,  # presumably selects Rademacher vs. Gaussian probe noise in ODEfunc — confirm
    batch_norm=False,
    regularization_fns=regularization_fns,
).to(device)

In [None]:
import math


def standard_normal_logprob(z):
    """Element-wise standard-normal log-density: -0.5*log(2*pi) - z**2 / 2.

    Redefines (shadows) the `standard_normal_logprob` imported from ffjord.train_misc in
    the imports cell.
    """
    logZ = -0.5 * math.log(2 * math.pi)
    return logZ - z.pow(2) / 2


def compute_loss(model, data, batch_size=None):
    """Mean negative log-likelihood of `data` under the flow `model`.

    The flow is called as `model(x, zero) -> (z, delta_logp)`, and
    log p(x) = sum_j log N(z_j; 0, 1) - delta_logp.

    Args:
        model: flow called as `model(x, logpx)`, returning `(z, delta_logp)`.
        data: 2-D tensor of shape (n_samples, n_dims).
        batch_size: if given, evaluate on a random subset of this many rows (drawn without
            replacement) instead of all of `data`. This argument was previously accepted
            but ignored; None keeps the old full-data behaviour.

    Returns:
        Scalar tensor: -mean(log p(x)).
    """
    if batch_size is not None:
        idx = torch.randperm(data.shape[0], device=data.device)[:batch_size]
        data = data[idx]

    # Initial log-density change is zero; match the data's device and dtype.
    zero = torch.zeros(data.shape[0], 1, device=data.device, dtype=data.dtype)
    z, delta_logp = model(data, zero)

    # log q(z) under the standard-normal base distribution, summed over dimensions.
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    logpx = logpz - delta_logp
    return -torch.mean(logpx)

In [None]:
from torch.nn.utils import clip_grad_norm
from tqdm import tqdm
from torchcontrib.optim import SWA

In [None]:
# Adam (lr 1e-5, weight decay 1e-5) wrapped in torchcontrib's SWA.
optimizer_base = optim.Adam(model.parameters(), weight_decay=1e-5, lr=1e-5)
optimizer = SWA(optimizer_base, swa_lr=1e-5, swa_freq=30, swa_start=200)

# Running-average meters (0.5 is passed as the meter's smoothing argument).
# NOTE(review): nfef_meter / nfeb_meter are not updated by any cell shown here.
loss_meter, nfef_meter, nfeb_meter = [utils.RunningAverageMeter(0.5) for _ in range(3)]

In [None]:
# Mini-batch size (rows per optimizer step) used by the training loop.
BATCH_SIZE = 1000
B = BATCH_SIZE  # short alias used by the training cell

In [None]:
N_ITERS = 10000      # number of optimizer steps
PLOT_EVERY = 50      # redraw the loss curve every this many steps
MAX_GRAD_NORM = 5.   # gradient-norm clipping threshold

# NOTE(review): per the torchcontrib docs, SWA keeps its weight average in separate
# buffers and only copies it into the model when optimizer.swap_swa_sgd() is called,
# which no cell in this notebook does.
model.train()
losses = []
for i in tqdm(range(N_ITERS)):
    optimizer.zero_grad()

    # Random mini-batch of B rows, drawn without replacement.
    batch = data[torch.randperm(len(data), device=data.device)[:B]]
    loss = compute_loss(model, batch)
    loss_meter.update(loss.item())

    loss.backward()
    # Clip BEFORE stepping: clipping after optimizer.step() has no effect on the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()

    losses.append(loss_meter.avg)
    if i % PLOT_EVERY == 0:
        clear_output()
        plt.figure()
        plt.plot(losses)
        plt.xlabel('iteration')
        plt.ylabel('running NLL')
        plt.show()

In [None]:
# Switch to evaluation mode, then build the sampling / density closures.
model.eval()
sample_fn, density_fn = get_transforms(model)

In [None]:
# Draw samples by pushing standard-normal noise through the flow in reverse.
# The dimensionality is taken from the training data instead of being hard-coded.
N_SAMPLES = 1000
sampled = sample_fn(torch.randn(N_SAMPLES, data.size(1)).to(device))

In [None]:
# Marginal histograms: training data vs. flow samples, both in the scaled units of `data`.
# density=True normalizes each histogram, because the two sets have very different sizes
# (all training rows vs. the drawn samples).
for i, col in enumerate(['hit_x', 'hit_y']):
    plt.figure(figsize=(6, 6))
    plt.hist(data[:, i].detach().cpu().numpy(), bins=100, density=True, label='true', alpha=0.5)
    plt.hist(sampled[:, i].detach().cpu().numpy(), bins=100, density=True, label='sampled', alpha=0.5)
    plt.xlabel('{} (scaled)'.format(col))
    plt.ylabel('density')
    plt.legend()
    plt.show()

In [None]:
# Peek at the training tensor instead of displaying all of it.
data.shape, data[:5]

In [None]:
%%time

from pyro import distributions as dist
my_cmap = plt.cm.jet
my_cmap.set_under('white')
mu_range = (1, 14)
mu = torch.linspace(*mu_range, 20).view(-1, 1).to(device)
N = 5000

results = []

for i in tqdm(range(len(mu))):
    mu_r = mu[i, :].reshape(1, -1).repeat(N, 1).to(device)
    init_ = dist.Uniform(low=torch.tensor([5., 0., -np.pi]), high=torch.tensor([10., 0.09, np.pi])).sample((N,)).to(device)
    inputs_test = torch.cat([
        torch.randn(len(mu_r), 2).float().to(device)
    ], dim=1)
    sampled_data = sample_fn(inputs_test)
    results.append(sampled_data)
    clear_output()

In [None]:
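# Column layout of the earlier conditional version of the data: energy, theta, phi, mu, x, y.
# The current `data` and flow samples have only two columns: (hit_x, hit_y) / 500.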

In [None]:
# One scatter plot per mu value. NOTE(review): the flow is unconditional, so mu does not
# influence these samples; each panel is an independent draw from the same distribution.
# The loop variable is not called `result`, so the NumPy `result` array is not overwritten.
for mu_value, sampled_batch in zip(mu.view(-1).tolist(), results):
    # Samples have two columns (scaled hit_x, hit_y); the former [4, 5] indices referred
    # to a six-column layout and raise IndexError on these tensors.
    plt.scatter(*sampled_batch[:, :2].t().cpu().detach().numpy())
    plt.title('FFJORD samples, mu = {:.2f}'.format(mu_value))
    plt.xlabel('hit_x (scaled)')
    plt.ylabel('hit_y (scaled)')
    plt.show()

In [None]:
# Raw (unscaled) hit positions. `df` is already restricted to magn_len == 8 when it is
# loaded, so no extra filter is needed. The axes follow the (hit_x, hit_y) column order
# used for `data` and the sample plots.
plt.scatter(*df[['hit_x', 'hit_y']].values.T)
plt.xlabel('hit_x')
plt.ylabel('hit_y');

In [None]:
# Shape of the scaled training tensor.
data.size()

In [None]:
# `data` has two columns (scaled hit_x, hit_y); the former [4, 5] indices referred to a
# six-column layout and raise IndexError.
plt.scatter(*data[:, :2].detach().cpu().numpy().T)
plt.xlabel('hit_x (scaled)')
plt.ylabel('hit_y (scaled)');

In [None]:
# `init_cond` is not built anywhere in this notebook. Show the ranges of the initial
# conditions directly from `df` instead (presumably to check the Uniform bounds used in
# the sampling cell).
init_ranges = pd.DataFrame({
    'p_abs': np.sqrt(df[['start_px', 'start_py', 'start_pz']].pow(2).sum(axis=1)),
    'start_theta': df['start_theta'],
    'start_phi': df['start_phi'],
}).agg(['min', 'max'])
init_ranges

In [None]:
# Final check of the training tensor's shape.
data.size()