In [22]:
import sys
sys.path.append('../bhnerf')

import os

import jax
from jax import random
from jax import numpy as jnp

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import xarray as xr
import flax
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as plt

import utils, emission_utils, visualization, network_utils
from network_utils import shard
import mediapy

from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

# from jax.config import config
# config.update("jax_debug_nans", True)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# NOTE(review): this scratch cell is incomplete. The `rotate_3d(gaussian, axis` call below is
# truncated (unclosed parenthesis), so the cell raises a SyntaxError on a fresh kernel, and the
# displayed output `array([0., 1.])` looks stale (it cannot come from this code). The complete
# disk-overlay versions appear in the later gaussian_3d / rotate_3d cells; finish or delete this one.
nt, nx, ny, nz = 64, 64, 64, 64
fov = (10.0, 'GM/c^2')
std = (10.0, 10.0, 1.0)
gaussian = emission_utils.gaussian_3d(nx, ny, nz, std, fov)
gaussian = emission_utils.rotate_3d(gaussian, axis

array([0., 1.])

In [None]:
def intensity_to_nchw(intensity, cmap='viridis', gamma=0.5):
    """Colorize a 3D intensity volume into an NCHW RGB image batch (e.g. for TensorBoard).

    The volume is min-max normalized to [0, 1], raised to ``gamma`` and mapped through the
    matplotlib colormap ``cmap``; the colormap's alpha channel is dropped.

    Args:
        intensity: 3D array-like with axes (a0, a1, a2) -- numpy/jax array or xarray DataArray.
        cmap: matplotlib colormap name.
        gamma: exponent applied to the normalized intensity (values < 1 brighten faint regions).

    Returns:
        float array of shape (a2, 3, a1, a0): one RGB image per slice along the last input axis.
        A constant volume maps to the colormap's lowest color instead of NaN.
    """
    cm = plt.get_cmap(cmap)
    values = np.asarray(intensity, dtype=float)
    vmin = np.min(values)
    value_range = np.max(values) - vmin
    if value_range == 0:
        # Constant volume (e.g. an all-zero network estimate early in training): 0/0 would give NaN.
        norm_images = np.zeros_like(values)
    else:
        norm_images = ((values - vmin) / value_range) ** gamma
    rgb = cm(norm_images)[..., :3]  # (a0, a1, a2, 3)
    # axis 0 -> 3, 1 -> 2, 2 -> 0, channel 3 -> 1  =>  (a2, 3, a1, a0)
    nchw_images = np.moveaxis(rgb, (0, 1, 2, 3), (3, 2, 0, 1))
    return nchw_images

def train_network(sensor, emission, emission_test, velocity_field, hparams, runname, 
                  log_period=100, x_res_vis=64, y_res_vis=64, z_res_vis=64,
                  measurements=None, measurements_test=None, rot_axis_true=None):
    """Fit the emission + rotation-axis network to ray-integrated measurements.

    Args:
        sensor: xarray Dataset of ray samples; passed to network_utils.get_input_coords and
            emission_utils.integrate_rays. ``sensor.r`` sets the smallest supervised radius.
        emission: ground-truth emission (has ``t``, ``x``, ``y``, ``z`` coordinates). Its ``t`` size
            sets the number of training frames, and its t=0 frame is logged as ground truth.
        emission_test: held-out emission used for the test loss.
        velocity_field: callable forwarded (as a static argument) to the network's train/eval steps.
        hparams: dict with 'num_iters', 'lr_init', 'lr_final', 'lr_axis', 'posenc_deg', 'batchsize'.
        runname: TensorBoard run name; logs go to ``runs/<runname>.<timestamp>``.
        log_period: evaluate and log every ``log_period`` iterations (and at iteration 1).
        x_res_vis, y_res_vis, z_res_vis: resolution of the emission visualization grid.
        measurements, measurements_test: optional precomputed ray integrals of ``emission`` /
            ``emission_test`` through ``sensor``. When None, they are computed with
            emission_utils.integrate_rays.
        rot_axis_true: optional true rotation axis for the 'rotation/dot_product' metric. When None,
            a notebook-global ``rot_axis`` is used if one exists; otherwise the metric is skipped.

    Returns:
        (state, emission_vis, axis_estimation):
            state: the replicated (pmapped) train state.
            emission_vis: the masked emission estimate on the visualization grid, with shape
                (1, x_res_vis, y_res_vis, z_res_vis).
            axis_estimation: the unit-norm rotation-axis estimate.
        Both estimates come from the most recent logging step.
    """
    # Ray-integrated measurements. These used to be read from notebook globals with the same
    # names; deriving them from the arguments keeps the function self-contained.
    if measurements is None:
        measurements = emission_utils.integrate_rays(emission, sensor)
    if measurements_test is None:
        measurements_test = emission_utils.integrate_rays(emission_test, sensor)
    if rot_axis_true is None:
        # Backward compatibility: the original code read the notebook-global `rot_axis`.
        rot_axis_true = globals().get('rot_axis')

    # Training / testing coordinates
    train_coords = network_utils.get_input_coords(sensor, nt=emission.t.size)
    t, x, y, z, d = train_coords.values()
    train_radiance = measurements.data.ravel()

    test_coords = network_utils.get_input_coords(sensor, nt=emission_test.t.size)
    t_test, x_test, y_test, z_test, d_test = test_coords.values()
    test_radiance = measurements_test.data.ravel()

    # Emission visualization inputs: a single t=0 slice on an x_res_vis x y_res_vis x z_res_vis grid
    t_res_vis = 1
    emission_extent = [emission.x[0], emission.x[-1], emission.y[0], emission.y[-1], emission.z[0], emission.z[-1]]
    t_vis, x_vis, y_vis, z_vis = np.meshgrid(0.0, np.linspace(emission_extent[0], emission_extent[1], x_res_vis),
                                             np.linspace(emission_extent[2], emission_extent[3], y_res_vis),
                                             np.linspace(emission_extent[4], emission_extent[5], z_res_vis),
                                             indexing='ij')

    d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
    target_vis = np.ones([y_res_vis, x_res_vis]) # meaningless placeholder for emission visualization
    # Zero network output where there is no supervision (within black-hole radius).
    # Non-positive radii are ignored: a cropped sensor may be padded with r = 0 (e.g. via
    # .fillna(0.0)), which would otherwise make r_min 0 and disable the mask entirely.
    r_min = sensor.r.where(sensor.r > 0).min().data

    # Model setup and initialization
    rng = jax.random.PRNGKey(1)
    predictor = network_utils.PREDICT_EMISSION_AND_ROTAXIS_3D(posenc_deg=hparams['posenc_deg'])
    params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field)['params']

    def flattened_traversal(fn):
        """Build an optax mask: apply fn(path, value) to every leaf of a nested param dict."""
        def mask(data):
            flat = flax.traverse_util.flatten_dict(data)
            return flax.traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
        return mask

    # The rotation-axis parameter gets its own constant learning rate; all other parameters
    # use a linear (power 1) decay from lr_init to lr_final over num_iters steps.
    tx = optax.chain(
        optax.masked(optax.adam(learning_rate=hparams['lr_axis']),
                     mask=flattened_traversal(lambda path, _: path[-1] == 'axis')),
        optax.masked(optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters'])),
                     mask=flattened_traversal(lambda path, _: path[-1] != 'axis')),
    )

    state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy

    # velocity_field (arg 0) is static and the step index (arg 1) is broadcast; data, state and key are mapped over devices.
    train_pstep = jax.pmap(predictor.train_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    rand_key = jax.random.split(rng, jax.local_device_count())
    state = flax.jax_utils.replicate(state)

    # TensorBoard logging
    time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
    logdir = 'runs/{}.{}'.format(runname, time)
    with SummaryWriter(logdir=logdir) as writer:

        # Log ground-truth data
        images = intensity_to_nchw(emission.isel(t=0))
        writer.add_images('emission/true', images, global_step=0)

        for i in tqdm(range(1, hparams['num_iters']+1), desc='iteration'):

            # Testing and Visualization
            if (i == 1) or (i % log_period) == 0:
                # Test loss on a random batch of held-out rays (previously this used the training rays)
                test_inds = np.random.randint(0, x_test.shape[0], hparams['batchsize'])
                loss_test, _, _, _, _, _ = eval_pstep(
                    velocity_field, i, shard(x_test[test_inds, ...]), shard(y_test[test_inds, ...]), shard(z_test[test_inds, ...]), 
                    shard(d_test[test_inds, ...]), shard(t_test[test_inds, ...]), shard(test_radiance[test_inds, ...]), 
                    state, rand_key
                )
                # Mean over devices first, then convert to a Python float (float() needs a scalar).
                writer.add_scalar('log loss/test', np.log10(float(np.mean(loss_test))), global_step=i)

                # Log prediction and estimate
                _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
                    velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), 
                    shard(d_vis), shard(t_vis), shard(target_vis), state, rand_key
                )
                # Take device 0's estimate and normalize it to unit length.
                axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
                emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 >= r_min**2, emission_vis, jnp.zeros_like(emission_vis))

                # Log emission and rotation axis
                images = intensity_to_nchw(emission_vis[0])
                writer.add_images('emission/estimate', images, global_step=i)
                if rot_axis_true is not None:
                    writer.add_scalar('rotation/dot_product', np.dot(rot_axis_true, axis_estimation), global_step=i)
                writer.add_scalar('rotation/x', axis_estimation[0], global_step=i)
                writer.add_scalar('rotation/y', axis_estimation[1], global_step=i)
                writer.add_scalar('rotation/z', axis_estimation[2], global_step=i)

            # Training
            batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
            loss_train, state, _, _, _, rand_key = train_pstep(
                velocity_field, i, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(train_radiance[batch_inds, ...]), 
                state, rand_key
            )
            writer.add_scalar('log loss/train', np.log10(float(np.mean(loss_train))), global_step=i)

    return state, emission_vis, axis_estimation

In [None]:
# Load the sensor dataset (ray samples; provides e.g. sensor.r, num_alpha, num_beta)
sensor_path = '../sensors/a0.00_th1.57_ngeo200_npix10000.nc'
sensor = xr.load_dataset(sensor_path)

In [213]:
# Generate hotspot emission
nt, nx, ny, nz = 64, 64, 64, 64
nt_test = 64
nspots = 1
r_isco = 3.0

# Seed the legacy global RNG so the random hotspot position (and downstream np.random
# batch sampling) is reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

phi = np.random.rand()*2*np.pi             # azimuth angle (ccw from x-axis)
theta = np.random.rand()*np.pi             # zenith angle (pi/2 = equatorial plane)

orbit_radius = 3.5
std = .4 * np.ones_like(orbit_radius)
initial_frame = emission_utils.generate_hotspots_3d(nx, ny, nz, theta, phi, orbit_radius, std, r_isco)

# Normalization chosen so that velocity_field(orbit_radius) == 1 (r**(-3/2) profile).
orbit_period = orbit_radius**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

rot_axis = np.array([np.cos(theta)*np.cos(phi), np.cos(theta)*np.sin(phi), -np.sin(theta)])
emission = emission_utils.generate_orbit_3d(initial_frame, nt, velocity_field, rot_axis)
# Held-out orbit for the test loss; the measurement and training cells below require it.
emission_test = emission_utils.generate_orbit_3d(initial_frame, nt_test, velocity_field, rot_axis)

In [183]:
# Display the sampled hotspot angles (azimuth phi, zenith theta) in radians
(phi, theta)

(5.623960004418948, 2.3204244726992567)

In [215]:
# Overlay a thin Gaussian disk, oriented using rot_axis, on the time-integrated hotspot orbit.
nt, nx, ny, nz = 64, 64, 64, 64
rho = np.sqrt(rot_axis[0]**2 + rot_axis[1]**2)

fov = (10.0, 'GM/c^2')
std = (0.5, 10.0, 10.0)

# NOTE: here phi_ is (minus) the polar angle of rot_axis from +z and theta_ is its azimuth --
# the opposite of the phi/theta naming used when generating the hotspot.
phi_ = -np.arctan2(rho, rot_axis[2])
theta_ = np.arctan2(rot_axis[1], rot_axis[0])
axis = [-np.cos(theta_), np.sin(theta_), 0]
gaussian = emission_utils.gaussian_3d(nx, ny, nz, std, fov)
gaussian = emission_utils.rotate_3d(gaussian, axis=[0, 0, 1], angle=phi_)
gaussian = emission_utils.rotate_3d(gaussian, axis=axis, angle=-theta_)

# Per-dimension (min, max) volume extent, defined here so this cell does not depend on a later cell.
extent = [(float(initial_frame[dim].min()), float(initial_frame[dim].max())) for dim in initial_frame.dims]

ipv.figure()
ipv.view(0, -60, distance=2.5)
emission_total = emission.sum('t')
ipv.volshow(gaussian + emission_total, 
            memorder='F', extent=extent, level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3])
ipv.show()

  gradient = gradient / np.sqrt(gradient[0] ** 2 + gradient[1] ** 2 + gradient[2] ** 2)


VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.0, max=1.0, step=0.00…

In [82]:
# Variant: tilt a thin Gaussian (std=(10, 10, 1)) by `theta` about the horizontal axis
# perpendicular to rot_axis, then overlay it on the time-integrated hotspot orbit.
nt, nx, ny, nz = 64, 64, 64, 64
rho = np.sqrt(rot_axis[0]**2 + rot_axis[1]**2)
# Unit vector in the xy-plane orthogonal to rot_axis (undefined when rot_axis is parallel to z, rho == 0).
axis =  np.array([-rot_axis[1]/rho, rot_axis[0]/rho, 0])

fov = (10.0, 'GM/c^2')
std = (10.0, 10.0, 1.0)
gaussian = emission_utils.gaussian_3d(nx, ny, nz, std, fov)
gaussian = emission_utils.rotate_3d(gaussian, axis, theta)

# Per-dimension (min, max) volume extent, defined here so this cell does not depend on a later cell.
extent = [(float(initial_frame[dim].min()), float(initial_frame[dim].max())) for dim in initial_frame.dims]

ipv.figure()
ipv.view(0, -60, distance=2.5)
ipv.volshow(gaussian + emission.sum('t'), memorder='F', extent=extent, level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3])
ipv.show()

VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.0, max=1.0, step=0.00…

In [61]:
extent = [(float(initial_frame[dim].min()), float(initial_frame[dim].max())) for dim in initial_frame.dims]
@interact(t=widgets.IntSlider(min=0, max=emission.t.size-1, step=1, value=0))
def plot_vol(t):
    """Render the ground-truth emission volume at frame index ``t``.

    The slider yields a position 0..nt-1, so positional ``isel`` is used. Label-based ``sel``
    only works if the t coordinate happens to equal 0..nt-1.
    """
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(emission.isel(t=t), extent=extent, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=63), Output()), _dom_classes=('widget-interact',…

In [97]:
# Visualize radiance measurements 
%matplotlib widget
measurements_2d = emission_utils.integrate_rays(emission, sensor.where(sensor.r < 5))
measurements_2d = xr.DataArray(measurements_2d.data.reshape(nt, sensor.num_alpha, sensor.num_beta), dims=['t', 'alpha', 'beta'])
anim = measurements_2d.utils_visualization.animate()

#outpath = 'gifs/unknown_rotaxis.measurements.theta_{:2.0f}.phi_{:2.0f}.gif'.format(np.rad2deg(theta), np.rad2deg(phi))
#anim.save(outpath, writer='imagemagick', fps=10)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [9]:
# Crop the sensor to the r < 5 domain. drop=True removes coordinate labels (e.g. rays) whose
# samples all have r >= 5, so these measurements can no longer be reshaped into an image.
# Remaining out-of-domain samples are NaN and get filled with 0.0.
# NOTE: fillna also sets those samples' r to 0, so sensor.r.min() can be 0 afterwards.
sensor = (
    sensor
    .where(sensor.r < 5, drop=True)
    .fillna(0.0)
)
measurements = emission_utils.integrate_rays(emission, sensor)
measurements_test = emission_utils.integrate_rays(emission_test, sensor)

In [24]:
# Training parameters
hparams = dict(
    num_iters=50000,
    lr_init=1e-4,
    lr_final=1e-6,
    lr_axis=1e-1,
    posenc_deg=3,
    batchsize=1024,
)
state, emission_vis, rot_axis_est = train_network(
    sensor, emission, emission_test, velocity_field, hparams,
    runname='unknown_rotation/init_axis_0.0.-1',
    log_period=100, x_res_vis=64, y_res_vis=64, z_res_vis=64,
)

# Propagate the recovered t=0 volume along the estimated axis to obtain the full estimated orbit
initial_estimate = xr.DataArray(emission_vis[0], coords=initial_frame.coords)
emission_estimate = emission_utils.generate_orbit_3d(initial_estimate, nt, velocity_field, rot_axis_est)



iteration:   0%|          | 0/50000 [00:00<?, ?it/s]

In [25]:
%matplotlib widget
ax = plt.figure(figsize=(7,7)).add_subplot(projection='3d')
ax.scatter(0,0,0, color='black', s=50)
ax.quiver(0,0,0,*rot_axis, length=0.05, linewidths=3,  label='true')
ax.quiver(0,0,0,*rot_axis_est, length=0.05, linewidths=3, label='estimated', color='r')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.legend(fontsize=14)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7fca500a4f10>

In [26]:
extent = [
    (float(initial_frame[dim].min()), float(initial_frame[dim].max()))
    for dim in initial_frame.dims
]
time_slider = widgets.IntSlider(min=0, max=emission.t.size - 1, step=1, value=0)

@interact(t=time_slider)
def plot_vol(t):
    """Render the estimated emission volume at frame index ``t`` (slider position)."""
    volume_kwargs = dict(extent=extent, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(emission_estimate.isel(t=t), **volume_kwargs)
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=63), Output()), _dom_classes=('widget-interact',…

In [12]:
# Save measurement prediction
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo200_npix10000.nc')
measurements_2d = emission_utils.integrate_rays(emission_estimate, sensor.where(sensor.r < 5))
measurements_2d = xr.DataArray(measurements_2d.data.reshape(nt, sensor.num_alpha, sensor.num_beta), dims=['t', 'alpha', 'beta'])

%matplotlib widget
outpath = 'gifs/unknown_rotaxis.prediction.theta_{:2.0f}.phi_{:2.0f}.gif'.format(np.rad2deg(theta), np.rad2deg(phi))
measurements_2d.utils_visualization.animate(output=outpath)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



<matplotlib.animation.FuncAnimation at 0x7fa8d45b4d30>