# Estimating hotspot(s) emission using NeRF. 

Notes:
Innermost stable circular orbit (ISCO): for spin=0 with r_g=2 this is at 3M \
Overleaf notes: https://www.overleaf.com/project/60ff0ece5aa4f90d07f2a417

In [1]:
import sys
sys.path.append('../bhnerf')

import os

import jax
from jax import random
from jax import numpy as jnp

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import xarray as xr
import flax
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as plt

import utils, emission_utils, visualization, network_utils
from network_utils import shard
import mediapy

from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

%load_ext autoreload
%autoreload 2

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-444xloxi because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
def intensity_to_nchw(intensity, cmap='viridis', gamma=0.5):
    """Convert a 3D intensity volume into a stack of RGB images for TensorBoard.

    The volume is min-max normalized to [0, 1], gamma-compressed, passed
    through a matplotlib colormap (alpha channel dropped) and its axes are
    permuted so that an input with axes (a0, a1, a2) yields an array with
    axes (a2, rgb, a1, a0).

    Args:
        intensity: 3D array-like of intensities (anything ``np.asarray`` accepts).
        cmap: Name of the matplotlib colormap.
        gamma: Exponent applied to the normalized intensities.

    Returns:
        np.ndarray with axes (a2, 3, a1, a0).
    """
    values = np.asarray(intensity, dtype=float)
    lowest = np.min(values)
    span = np.max(values) - lowest

    if span == 0:
        # A constant volume has no dynamic range; dividing by zero would turn
        # every pixel into NaN. Map it to the bottom of the colormap instead.
        norm_images = np.zeros(values.shape)
    else:
        norm_images = ((values - lowest) / span) ** gamma

    rgb = plt.get_cmap(cmap)(norm_images)[..., :3]
    return np.moveaxis(rgb, (0, 1, 2, 3), (3, 2, 0, 1))

def train_network(sensor, emission, emission_test, velocity_field, hparams, runname, 
                  log_period=100, x_res_vis=64, y_res_vis=64, z_res_vis=64):
    
    # Training / testing coordinates
    train_coords = network_utils.get_input_coords(sensor, nt=emission.t.size)
    t, x, y, z, d = train_coords.values()
    train_radiance = measurements.data.ravel() 

    test_coords = network_utils.get_input_coords(sensor, nt=emission_test.t.size)
    test_radiance = measurements_test.data.ravel() 
    t_test, x_test, y_test, z_test, d_test = test_coords.values()

    # Emission visualization inputs
    t_res_vis = 1
    emission_extent = [emission.x[0], emission.x[-1], emission.y[0], emission.y[-1], emission.z[0], emission.z[-1]]
    t_vis, x_vis, y_vis, z_vis  = np.meshgrid(0.0, np.linspace(emission_extent[0], emission_extent[1], x_res_vis),
                                              np.linspace(emission_extent[2], emission_extent[3], y_res_vis),
                                              np.linspace(emission_extent[4], emission_extent[5], x_res_vis),
                                              indexing='ij')

    d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
    target_vis = np.ones([y_res_vis, x_res_vis]) # meaningless placeholder for emission visualization
    r_min = sensor.r.min().data                  # Zero network output where there is no supervision (within black-hole radius)  

    # Model setup and initialization
    rng = jax.random.PRNGKey(1)
    predictor = network_utils.PREDICT_EMISSION_3D(posenc_deg=hparams['posenc_deg'])
    params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field, rot_axis)['params']
    tx = optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters']))
    state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy

    train_pstep = jax.pmap(predictor.train_step, axis_name='batch', in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    rand_key = jax.random.split(rng, jax.local_device_count())
    state = flax.jax_utils.replicate(state)


    # TensorBoard logging
    time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
    logdir = 'runs/{}.{}'.format(runname, time)
    with SummaryWriter(logdir=logdir) as writer:

        # Log ground-truth data   
        %matplotlib inline
        images = intensity_to_nchw(emission.isel(t=0))
        writer.add_images('emission/true', images, global_step=0)

        for i in tqdm(range(1, hparams['num_iters']+1), desc='iteration'):

            # Testing and Visualization
            if (i == 1) or (i % log_period) == 0:

                batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
                loss_test, _, _, _, _ = eval_pstep(
                    velocity_field, rot_axis, i, 
                    shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                    shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(train_radiance[batch_inds, ...]), 
                    state, rand_key
                )
                writer.add_scalar('log loss/test', np.log10(np.mean(float(loss_test))), global_step=i)

                # Log prediction and estimate
                _, _, emission_vis, _, _ = eval_pstep(
                    velocity_field, rot_axis, i, shard(x_vis), shard(y_vis), shard(z_vis), 
                    shard(d_vis), shard(t_vis), shard(target_vis), state, rand_key
                )
                emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 >= r_min**2, emission_vis, jnp.zeros_like(emission_vis))
                
                # Log initialization seperately 
                images = intensity_to_nchw(emission_vis[0])
                if (i == 1): 
                    writer.add_images('emission/initial', images, global_step=i)
                else:
                    writer.add_images('emission/esimate', images, global_step=i)

            # Training
            batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
            loss_train, state, _, _, rand_key = train_pstep(
                velocity_field, rot_axis, i, 
                shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(train_radiance[batch_inds, ...]), 
                state, rand_key
            )
            writer.add_scalar('log loss/train', np.log10(float(np.mean(loss_train))), global_step=i)
            
    return state, emission_vis

In [3]:
# Load the sensor dataset from disk into the notebook-level `sensor` variable.
# Later cells filter it with `sensor.r < 5` and pass it to integrate_rays / train_network.
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo200_npix10000.nc')

# Multiple Hotspots

In [14]:
# Generate hotspot emission
nt, nx, ny, nz = 64, 64, 64, 64
nt_test = 64
nspots = 8
r_isco = 3.0

phi = np.pi            # azimuth angle (ccw from x-axis)
theta = np.pi / 3.0    # zenith angle (pi/2 = equatorial plane)
orbit_radius = 3.5
std = .4 * np.ones_like(orbit_radius)

# Start from a single hotspot ...
initial_frame = emission_utils.generate_hotspots_3d(nx, ny, nz, theta, phi, orbit_radius, std, r_isco)

# ... then build the multi-hotspot frame by summing `nspots` frames of a
# constant-velocity orbit of that hotspot. Frames are drawn WITHOUT replacement:
# with the default (replace=True) the same frame can be picked more than once,
# which yields fewer than `nspots` distinct hotspots.
rot_axis = np.array([np.cos(theta)*np.cos(phi), np.cos(theta)*np.sin(phi), -np.sin(theta)])
initial_frame = emission_utils.generate_orbit_3d(initial_frame, nt, 1, rot_axis).isel(
    t=np.random.choice(range(0, nt, 5), size=nspots, replace=False)).sum('t')

# Radial velocity profile normalized so that velocity_field(orbit_radius) == 1
# (derived from orbit_radius instead of repeating the literal 3.5).
orbit_period = orbit_radius**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)
emission = emission_utils.generate_orbit_3d(initial_frame, nt, velocity_field, rot_axis)
emission_test = emission_utils.generate_orbit_3d(initial_frame, nt_test, velocity_field, rot_axis)

In [16]:
# (min, max) bounds of the initial frame along each of its dimensions.
extent = [
    (float(initial_frame[axis].min()), float(initial_frame[axis].max()))
    for axis in initial_frame.dims
]


@interact(t=widgets.IntSlider(value=0, min=0, max=emission.t.size-1, step=1))
def plot_vol(t):
    """Volume-render frame `t` of the ground-truth emission."""
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(
        emission.isel(t=t),
        extent=extent,
        memorder='F',
        level=[0, 0.2, 0.7],
        opacity=[0, 0.2, 0.3],
        controls=False,
    )
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=63), Output()), _dom_classes=('widget-interact',…

In [17]:
# Visualize radiance measurements 
measurements_2d = emission_utils.integrate_rays(emission, sensor.where(sensor.r < 5))
measurements_2d = xr.DataArray(measurements_2d.data.reshape(nt, sensor.num_alpha, sensor.num_beta), dims=['t', 'alpha', 'beta'])

%matplotlib widget
out_path = 'gifs/multiple_hs.measurements.theta_{:2.0f}.phi_{:2.0f}.nspots{:d}.gif'.format(np.rad2deg(theta), np.rad2deg(phi), nspots)
measurements_2d.utils_visualization.animate(output=out_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

MovieWriter imagemagick unavailable; using Pillow instead.


<matplotlib.animation.FuncAnimation at 0x7f4f255fbbb0>

In [18]:
# Drop rays which do not cross the domain (this gives measurements that cannot be reshaped into an image)
sensor = sensor.where(sensor.r < 5, drop=True)
measurements = emission_utils.integrate_rays(emission, sensor)
measurements_test = emission_utils.integrate_rays(emission_test, sensor)

In [19]:
# Optimization settings consumed by train_network
hparams = dict(
    num_iters=50000,
    lr_init=1e-4,
    lr_final=1e-6,
    posenc_deg=3,
    batchsize=1024,
)

# Fit the network to the multi-hotspot measurements; logs go to runs/multiple_shearing_hs_3D.<time>
state, emission_vis = train_network(
    sensor,
    emission,
    emission_test,
    velocity_field,
    hparams,
    runname='multiple_shearing_hs_3D',
    log_period=100,
    x_res_vis=64,
    y_res_vis=64,
    z_res_vis=64,
)



iteration:   0%|          | 0/50000 [00:00<?, ?it/s]

In [20]:
# Wrap the recovered volume (first entry of emission_vis) in a DataArray on the
# initial frame's coordinates and run it through generate_orbit_3d with the same
# nt, velocity_field and rot_axis that were used for the ground-truth emission.
emission_estimate = emission_utils.generate_orbit_3d(xr.DataArray(emission_vis[0], coords=initial_frame.coords), nt, velocity_field, rot_axis)

In [21]:
# (min, max) bounds of the initial frame along each of its dimensions.
extent = [
    (float(initial_frame[axis].min()), float(initial_frame[axis].max()))
    for axis in initial_frame.dims
]


@interact(t=widgets.IntSlider(value=0, min=0, max=emission.t.size-1, step=1))
def plot_vol(t):
    """Volume-render frame `t` of the estimated emission."""
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(
        emission_estimate.isel(t=t),
        extent=extent,
        memorder='F',
        level=[0, 0.2, 0.7],
        opacity=[0, 0.2, 0.3],
        controls=False,
    )
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=63), Output()), _dom_classes=('widget-interact',…

In [22]:
# Save measurement prediction
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo200_npix10000.nc')
measurements_2d = emission_utils.integrate_rays(emission_estimate, sensor.where(sensor.r < 5))
measurements_2d = xr.DataArray(measurements_2d.data.reshape(nt, sensor.num_alpha, sensor.num_beta), dims=['t', 'alpha', 'beta'])

%matplotlib widget
out_path = 'gifs/multiple_hs.prediction.theta_{:2.0f}.phi_{:2.0f}.nspots{:d}.gif'.format(np.rad2deg(theta), np.rad2deg(phi), nspots)
measurements_2d.utils_visualization.animate(output=out_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



<matplotlib.animation.FuncAnimation at 0x7f4f1f6063a0>

# Single Hotspot

In [5]:
# Generate hotspot emission
nt, nx, ny, nz = 64, 64, 64, 64
nt_test = 64
nspots = 1
r_isco = 3.0

phi = np.pi            # azimuth angle (ccw from x-axis)
theta = np.pi / 3.0    # zenith angle (pi/2 = equatorial plane)

orbit_radius = 3.5
std = .4 * np.ones_like(orbit_radius)
initial_frame = emission_utils.generate_hotspots_3d(nx, ny, nz, theta, phi, orbit_radius, std, r_isco)

# Radial velocity profile normalized so that velocity_field(orbit_radius) == 1
# (derived from orbit_radius instead of repeating the literal 3.5, so the two
# cannot silently drift apart when the radius is changed).
orbit_period = orbit_radius**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

rot_axis = np.array([np.cos(theta)*np.cos(phi), np.cos(theta)*np.sin(phi), -np.sin(theta)])
emission = emission_utils.generate_orbit_3d(initial_frame, nt, velocity_field, rot_axis)
emission_test = emission_utils.generate_orbit_3d(initial_frame, nt_test, velocity_field, rot_axis)

In [6]:
# (min, max) bounds of the initial frame along each of its dimensions.
extent = [
    (float(initial_frame[axis].min()), float(initial_frame[axis].max()))
    for axis in initial_frame.dims
]


@interact(t=widgets.IntSlider(value=0, min=0, max=emission.t.size-1, step=1))
def plot_vol(t):
    """Volume-render frame `t` of the single-hotspot emission."""
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(
        emission.isel(t=t),
        extent=extent,
        memorder='F',
        level=[0, 0.2, 0.7],
        opacity=[0, 0.2, 0.3],
        controls=False,
    )
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=63), Output()), _dom_classes=('widget-interact',…

In [6]:
# Visualize radiance measurements 
measurements_2d = emission_utils.integrate_rays(emission, sensor.where(sensor.r < 5))
measurements_2d = xr.DataArray(measurements_2d.data.reshape(nt, sensor.num_alpha, sensor.num_beta), dims=['t', 'alpha', 'beta'])

%matplotlib widget
out_path = 'gifs/orbiting_hs.theta_{:2.0f}.phi_{:2.0f}.gif'.format(np.rad2deg(theta), np.rad2deg(phi))
measurements_2d.utils_visualization.animate(output=out_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7fb8f82ef400>

In [96]:
# Drop rays which do not cross the domain (this gives measurements that cannot be reshaped into an image)
# NOTE: this rebinds the notebook-level `sensor` to the filtered dataset.
sensor = sensor.where(sensor.r < 5, drop=True)
# `measurements` and `measurements_test` are read as globals inside train_network.
measurements = emission_utils.integrate_rays(emission, sensor)
measurements_test = emission_utils.integrate_rays(emission_test, sensor)

In [136]:
# Optimization settings consumed by train_network
hparams = dict(
    num_iters=50000,
    lr_init=1e-4,
    lr_final=1e-6,
    posenc_deg=3,
    batchsize=1024,
)

# Fit the network to the single-hotspot measurements; logs go to runs/shearing_hs_3D.<time>
state, emission_vis = train_network(
    sensor,
    emission,
    emission_test,
    velocity_field,
    hparams,
    runname='shearing_hs_3D',
    log_period=100,
    x_res_vis=64,
    y_res_vis=64,
    z_res_vis=64,
)



iteration:   0%|          | 0/50000 [00:00<?, ?it/s]

In [None]:
# Volume-render the recovered emission (first entry of emission_vis) with the
# same camera and transfer-function settings as the ground-truth plots.
# Relies on `extent` having been defined by an earlier plotting cell.
ipv.figure()
ipv.view(0, -60, distance=2.5)
ipv.volshow(emission_vis[0], extent=extent, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
ipv.show()