# Estimating hotspot(s) emission and rotational axis from visibilities.

Notes:
Innermost stable circular orbit (ISCO): for spin=0 with r_g=2 this is at 3M \
Overleaf notes: https://www.overleaf.com/project/60ff0ece5aa4f90d07f2a417

In [1]:
import sys
sys.path.append('../bhnerf')

import os

import jax
from jax import random
from jax import numpy as jnp
import jax.scipy.ndimage as jnd
import scipy.ndimage as nd

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import xarray as xr
import flax
from flax.training import train_state
from flax.training import checkpoints
import optax
import numpy as np
import matplotlib.pyplot as plt

import utils, emission_utils, visualization, network_utils, observation_utils
from network_utils import shard

import ehtim as eh
import ehtim.const_def as ehc
from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

# from jax.config import config
# config.update("jax_debug_nans", True)
%load_ext autoreload
%autoreload 2

2021-11-01 16:13:36.433113: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs
Matplotlib created a temporary config/cache directory at /tmp/matplotlib-ujmwreb0 because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



In [9]:
# Express each value v on a decibel scale as 10 * log10(1 / v).
tmp = np.array([8.8556e-6, 3.5908e-5, 9.9678e-5, 2.6747e-4, 2.5619e-4, 5.4590e-4])
inverse = 1 / tmp
decibels = 10 * np.log10(inverse)
print(decibels)

[50.52782008 44.44808783 40.01400685 35.72724922 35.91437826 32.62886906]


In [2]:
# Generate hotspot emission (4 hotspots)
nt, nx, ny, nz = 128, 64, 64, 64
nt_test = 128
r_isco = 3.0

phi = 0.0            # azimuth angle (ccw from x-axis)
theta = np.pi/3      # zenith angle (pi/2 = equatorial plane)
orbit_radius = 3.5
std = .4 * np.ones_like(orbit_radius)

# Base frame generated from the scalar orbit radius / std above
# (presumably a single hotspot -- confirm against emission_utils.generate_hotspots_3d).
initial_frame = emission_utils.generate_hotspots_3d(nx, ny, nz, theta, phi, orbit_radius, std, r_isco, std_clip=np.inf)
rot_axis = np.array([np.cos(theta)*np.cos(phi), np.cos(theta)*np.sin(phi), -np.sin(theta)])

# Build the multi-spot initial frame: rotate the base frame with a constant
# velocity of 1 over nt frames, keep the frames listed below and sum them over t.
nspots = 4
frame_subset_inds = [60, 75, 105, 0]  # Generated by np.random.choice(range(0, nt, 5), size=nspots) with np.random.seed(0)
initial_frame = emission_utils.generate_orbit_3d(initial_frame, nt, 1, rot_axis).isel(
    t=frame_subset_inds).sum('t')

# Tie the period to orbit_radius so that velocity_field(orbit_radius) == 1
# instead of repeating the literal radius.
orbit_period = orbit_radius**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

emission = emission_utils.generate_orbit_3d(initial_frame, nt, velocity_field, rot_axis)
emission_test = emission_utils.generate_orbit_3d(initial_frame, nt_test, velocity_field, rot_axis)

normalization_factor = 1.0
emission *= normalization_factor
emission_test *= normalization_factor

print('rotation axis: {}'.format(rot_axis))

rotation axis: [ 0.5        0.        -0.8660254]


In [3]:
# Bounds of the emission volume: one (min, max) pair per spatial dimension.
extent = [
    (float(emission[axis].min()), float(emission[axis].max()))
    for axis in ('x', 'y', 'z')
]


@interact(t=widgets.IntSlider(min=0, max=emission.t.size-1, step=1, value=0))
def plot_vol(t):
    """Show the emission volume at time index ``t`` as an ipyvolume rendering."""
    frame = emission.isel(t=t)
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(
        frame,
        extent=extent,
        memorder='F',
        level=[0, 0.2, 0.7],
        opacity=[0, 0.2, 0.3],
        controls=False,
    )
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=127), Output()), _dom_classes=('widget-interact'…

In [4]:
# Generate image plane fluxes
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo100_npix4096.nc')

# Keep only entries with r < 5 (the rest become NaN), record the remaining
# radial range, then replace the NaNs with zeros.
sensor = sensor.where(sensor.r < 5)
r_min, r_max = sensor.r.min().data, sensor.r.max().data   # r_min: minimum supervision radius
sensor = sensor.fillna(0.0)

# Integrate the emission along the sensor rays, one image per time frame.
image_shape = (sensor.num_alpha, sensor.num_beta)
image_plane = emission_utils.integrate_rays(emission, sensor).data.reshape(nt, *image_shape)
image_plane_test = emission_utils.integrate_rays(emission_test, sensor).data.reshape(nt_test, *image_shape)

In [5]:
def intensity_to_nchw(intensity, cmap='viridis', gamma=0.5):
    """Convert a 3D intensity volume into a batch of RGB images in NCHW layout.

    The volume is min-max normalized to [0, 1], raised to the power ``gamma``
    and passed through a matplotlib colormap. The alpha channel is dropped and
    the axes are re-ordered so that an input of shape (n0, n1, n2) produces an
    output of shape (n2, 3, n1, n0): one 3-channel image per index of the last
    input axis.

    Args:
        intensity: 3D array-like of intensities.
        cmap: Name of the matplotlib colormap to apply.
        gamma: Exponent applied to the normalized intensities.

    Returns:
        Array of shape (n2, 3, n1, n0) holding the RGB colormap values.
    """
    cm = plt.get_cmap(cmap)
    intensity = np.asarray(intensity)
    minval = np.amin(intensity)
    maxval = np.amax(intensity)
    value_range = maxval - minval
    if value_range > 0:
        norm_images = ((intensity - minval) / value_range)**gamma
    else:
        # Constant input: avoid 0/0 (an all-NaN image) and use the low end of the colormap.
        norm_images = np.zeros(intensity.shape)
    nchw_images = np.moveaxis(cm(norm_images)[..., :3], (0, 1, 2, 3), (3, 2, 0, 1))
    return nchw_images

def train_grid(sensor, emission, emission_test, velocity_field, hparams, runname, 
                  log_period=100, save_period=10000, x_res_vis=64, y_res_vis=64, z_res_vis=64):
    
    # Training / testing coordinates
    train_coords = network_utils.get_input_coords(sensor, t_array=np.linspace(0., 1., emission.t.size), batch='t')
    t, x, y, z, d = train_coords.values()

    test_coords = network_utils.get_input_coords(sensor, t_array=np.linspace(0., 1., emission.t.size), batch='t')
    t_test, x_test, y_test, z_test, d_test = test_coords.values()

    # Emission visualization inputs
    t_res_vis = 1
    emission_extent = [emission.x[0], emission.x[-1], emission.y[0], emission.y[-1], emission.z[0], emission.z[-1]]
    t_vis, x_vis, y_vis, z_vis  = np.meshgrid(0.0, np.linspace(emission_extent[0], emission_extent[1], x_res_vis),
                                              np.linspace(emission_extent[2], emission_extent[3], y_res_vis),
                                              np.linspace(emission_extent[4], emission_extent[5], x_res_vis),
                                              indexing='ij')

    d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
    target_vis = np.ones([y_res_vis, x_res_vis]) # meaningless placeholder for emission visualization
    r_min = hparams['r_min']                 # Zero network output where there is no supervision (within black-hole radius)  
    r_max = hparams['r_max']
        
    # Model setup and initialization
    rng = jax.random.PRNGKey(1)
    predictor = network_utils.PREDICT_EMISSION_3D_GRID()
    params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field, rot_axis)['params']
    tx = optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters']))
    state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy
    
    if np.isscalar(save_period):
        checkpoint_dir = 'checkpoints/{}'.format(runname)
        state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    init_step = 1 + state.step

    train_pstep = jax.pmap(predictor.train_step, axis_name='batch', in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    rand_key = jax.random.split(rng, jax.local_device_count())
    state = flax.jax_utils.replicate(state)


    # TensorBoard logging
    time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
    logdir = 'runs/{}.{}'.format(runname, time)
    with SummaryWriter(logdir=logdir) as writer:

        # Log ground-truth data   
        %matplotlib inline
        images_true = intensity_to_nchw(emission.isel(t=0))        

        for i in tqdm(range(1, hparams['num_iters']+1), desc='iteration'):

            # Testing and Visualization
            if (i == 1) or (i % log_period) == 0:

                batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
                loss_test, _, _, rendering_test, _ = eval_pstep(
                    velocity_field, rot_axis, i, 
                    shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                    shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(image_plane_test[batch_inds, ...]), 
                    state, rand_key
                )
                writer.add_scalar('log loss/test', np.log10(np.mean(float(loss_test))), global_step=i)

                # Log prediction and estimate
                _, _, emission_vis, _, _ = eval_pstep(
                    velocity_field, rot_axis, i, shard(x_vis), shard(y_vis), shard(z_vis), 
                    shard(d_vis), shard(t_vis), shard(target_vis), state, rand_key
                )
                emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
                
                # Log initialization seperately 
                images = intensity_to_nchw(emission_vis[0])
                writer.add_images('emission/estimate', images, global_step=i)
                writer.add_images('emission/true', images_true, global_step=i)
                writer.add_image('rendering/test', rendering_test[0, 0, None, :, :], global_step=i)
                writer.add_image('rendering/true', image_plane_test[batch_inds[0], None, ...], global_step=i)
                writer.add_image('rendering/diff', np.abs(image_plane_test[batch_inds[0], None, ...] - rendering_test[0, 0, None, :, :]), global_step=i)
                
                writer.add_scalar('emission_mse', float(np.mean((emission.data[0] - emission_vis)**2)), global_step=i)
                
            # Training
            batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
            loss_train, state, _, _, rand_key = train_pstep(
                velocity_field, rot_axis, i, 
                shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(image_plane[batch_inds, ...]), 
                state, rand_key
            )
            writer.add_scalar('log loss/train', np.log10(float(np.mean(loss_train))), global_step=i)
            
            if np.isscalar(save_period) and ((i % save_period == 0) or (i == hparams['num_iters'])):
                state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(checkpoint_dir, state_to_save, int(i), keep=5)
            
    return state, emission_vis

In [None]:
# Training parameters
runname = 'nerf_vs_grid_11_1_21/grid_4_spots'
hparams = {
    'num_iters': 5000,
    'lr_init': 1e-2,
    'lr_final': 1e-4,
    'batchsize': 8,
    'r_min': r_min,
    'r_max': r_max,
}

# Same visualization resolution along every axis.
vis_resolution = {'x_res_vis': 64, 'y_res_vis': 64, 'z_res_vis': 64}
state, emission_vis = train_grid(
    sensor, emission, emission_test, velocity_field, hparams,
    runname=runname, log_period=100, save_period=5000, **vis_resolution
)
# emission_estimate = emission_utils.generate_orbit_3d(
#     xr.DataArray(emission_vis[0], coords=initial_frame.coords), nt, velocity_field, rot_axis)



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]

In [7]:
def intensity_to_nchw(intensity, cmap='viridis', gamma=0.5):
    """Convert a 3D intensity volume into a batch of RGB images in NCHW layout.

    The volume is min-max normalized to [0, 1], raised to the power ``gamma``
    and passed through a matplotlib colormap. The alpha channel is dropped and
    the axes are re-ordered so that an input of shape (n0, n1, n2) produces an
    output of shape (n2, 3, n1, n0): one 3-channel image per index of the last
    input axis.

    Args:
        intensity: 3D array-like of intensities.
        cmap: Name of the matplotlib colormap to apply.
        gamma: Exponent applied to the normalized intensities.

    Returns:
        Array of shape (n2, 3, n1, n0) holding the RGB colormap values.
    """
    cm = plt.get_cmap(cmap)
    intensity = np.asarray(intensity)
    minval = np.amin(intensity)
    maxval = np.amax(intensity)
    value_range = maxval - minval
    if value_range > 0:
        norm_images = ((intensity - minval) / value_range)**gamma
    else:
        # Constant input: avoid 0/0 (an all-NaN image) and use the low end of the colormap.
        norm_images = np.zeros(intensity.shape)
    nchw_images = np.moveaxis(cm(norm_images)[..., :3], (0, 1, 2, 3), (3, 2, 0, 1))
    return nchw_images

def train_network(sensor, emission, emission_test, velocity_field, hparams, runname, 
                  log_period=100, save_period=10000, x_res_vis=64, y_res_vis=64, z_res_vis=64):
    
    # Training / testing coordinates
    train_coords = network_utils.get_input_coords(sensor, t_array=np.linspace(0., 1., emission.t.size), batch='t')
    t, x, y, z, d = train_coords.values()

    test_coords = network_utils.get_input_coords(sensor, t_array=np.linspace(0., 1., emission.t.size), batch='t')
    t_test, x_test, y_test, z_test, d_test = test_coords.values()

    # Emission visualization inputs
    t_res_vis = 1
    emission_extent = [emission.x[0], emission.x[-1], emission.y[0], emission.y[-1], emission.z[0], emission.z[-1]]
    t_vis, x_vis, y_vis, z_vis  = np.meshgrid(0.0, np.linspace(emission_extent[0], emission_extent[1], x_res_vis),
                                              np.linspace(emission_extent[2], emission_extent[3], y_res_vis),
                                              np.linspace(emission_extent[4], emission_extent[5], x_res_vis),
                                              indexing='ij')

    d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
    target_vis = np.ones([y_res_vis, x_res_vis]) # meaningless placeholder for emission visualization
    r_min = hparams['r_min']                 # Zero network output where there is no supervision (within black-hole radius)  
    r_max = hparams['r_max']
    
    # Model setup and initialization
    rng = jax.random.PRNGKey(1)
    predictor = network_utils.PREDICT_EMISSION_3D()
    params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field, rot_axis)['params']
    tx = optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters']))
    state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy
    
    if np.isscalar(save_period):
        checkpoint_dir = 'checkpoints/{}'.format(runname)
        state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    init_step = 1 + state.step

    train_pstep = jax.pmap(predictor.train_step, axis_name='batch', in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, None, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
    rand_key = jax.random.split(rng, jax.local_device_count())
    state = flax.jax_utils.replicate(state)


    # TensorBoard logging
    time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
    logdir = 'runs/{}.{}'.format(runname, time)
    with SummaryWriter(logdir=logdir) as writer:

        # Log ground-truth data   
        %matplotlib inline
        images_true = intensity_to_nchw(emission.isel(t=0))        

        for i in tqdm(range(1, hparams['num_iters']+1), desc='iteration'):

            # Testing and Visualization
            if (i == 1) or (i % log_period) == 0:

                batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
                loss_test, _, _, rendering_test, _ = eval_pstep(
                    velocity_field, rot_axis, i, 
                    shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                    shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(image_plane_test[batch_inds, ...]), 
                    state, rand_key
                )
                writer.add_scalar('log loss/test', np.log10(np.mean(float(loss_test))), global_step=i)

                # Log prediction and estimate
                _, _, emission_vis, _, _ = eval_pstep(
                    velocity_field, rot_axis, i, shard(x_vis), shard(y_vis), shard(z_vis), 
                    shard(d_vis), shard(t_vis), shard(target_vis), state, rand_key
                )
                emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
                
                # Log initialization seperately 
                images = intensity_to_nchw(emission_vis[0])
                writer.add_images('emission/estimate', images, global_step=i)
                writer.add_images('emission/true', images_true, global_step=i)
                writer.add_image('rendering/test', rendering_test[0, 0, None, :, :], global_step=i)
                writer.add_image('rendering/true', image_plane_test[batch_inds[0], None, ...], global_step=i)
                writer.add_image('rendering/diff', np.abs(image_plane_test[batch_inds[0], None, ...] - rendering_test[0, 0, None, :, :]), global_step=i)
                
                writer.add_scalar('emission_mse', float(np.mean((emission.data[0] - emission_vis)**2)), global_step=i)
                
            # Training
            batch_inds = np.random.randint(0, x.shape[0], hparams['batchsize'])
            loss_train, state, _, _, rand_key = train_pstep(
                velocity_field, rot_axis, i, 
                shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(image_plane[batch_inds, ...]), 
                state, rand_key
            )
            writer.add_scalar('log loss/train', np.log10(float(np.mean(loss_train))), global_step=i)
            
            if np.isscalar(save_period) and ((i % save_period == 0) or (i == hparams['num_iters'])):
                state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(checkpoint_dir, state_to_save, int(i), keep=5)
            
    return state, emission_vis

In [None]:
# Training parameters
# NOTE(review): train_network as defined in this notebook only reads 'num_iters',
# 'lr_init', 'lr_final', 'batchsize', 'r_min' and 'r_max' from this dict.
runname = 'nerf_vs_grid_11_1_21/network_4_spots'
hparams = {
    'num_iters': 5000,
    'lr_init': 1e-4,
    'lr_final': 1e-6,
    'lr_axis': 1e-2,
    'posenc_deg': 3,
    'batchsize': 8,
    'r_min': r_min,
    'r_max': r_max,
}
# runname = '{}/known_riaf.flux{}.norm{}.rot_axis_MLP.fov{}.batch{}.phi{:2.1f}.theta{:2.1f}.nspots{}.nonoise'.format(
#     array_name, riaf_flux, normalization_factor, fov, hparams['batchsize'], np.rad2deg(phi), np.rad2deg(theta), nspots
# )

# Same visualization resolution along every axis.
vis_resolution = {'x_res_vis': 64, 'y_res_vis': 64, 'z_res_vis': 64}
state, emission_vis = train_network(
    sensor, emission, emission_test, velocity_field, hparams,
    runname=runname, log_period=100, save_period=5000, **vis_resolution
)
# emission_estimate = emission_utils.generate_orbit_3d(
#     xr.DataArray(emission_vis[0], coords=initial_frame.coords), nt, velocity_field, rot_axis_est)



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]

In [34]:
%matplotlib widget
ax = plt.figure(figsize=(7,7)).add_subplot(projection='3d')
ax.scatter(0,0,0, color='black', s=50)
ax.quiver(0,0,0,*rot_axis, length=0.05, linewidths=3,  label='true')
ax.quiver(0,0,0,*rot_axis_est, length=0.05, linewidths=3, label='estimated', color='r')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.legend(fontsize=14)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7fccbc6a53a0>