# Tutorial 3: Estimate 3D emission from observations

---
This tutorial demonstrates the recovery of 3D emission from synthetic observations. Complex visibilities are used to fit the parameters of a coordinate-based neural network (NeRF).

In [1]:
import bhnerf
import kgeo
import ehtim as eh
from ehtim.observing.obs_helpers import ftmatrix
import ehtim.const_def as ehc
import bhnerf.constants as consts
from astropy import units
import jax

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

import os
from datetime import datetime
from tqdm.notebook import tqdm

# Run on 2 GPUs: only physical CUDA devices 1 and 2 are made visible to this process.
# NOTE(review): JAX presumably reads this variable when its backend is first initialized, so it must be
# set before any JAX computation runs; placing it before `import jax` would be the safest spot.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1,2'})

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-wk_axo8u because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



2022-08-10 16:20:20.209619: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs


In [2]:
"""
Generate synthetic observations of a hot-spot to serve as ground-truth comparison 
"""
# Black-hole / viewing geometry (both are passed to the Kerr geodesic computation below)
spin = 0.2
inclination = np.deg2rad(60.0)                       # given in degrees, stored in radians

array = 'ngEHT'                      # array configuration; loaded from ../eht_arrays/<array>.txt below
nt = 64                             # number of image frames
flux_scaling = 0.02                 # scale image-plane values to `reasonable` fluxes  
fov_M = 16.0                        # field of view (M)
tstart = 2.0 * units.hour           # observation start time
tstop = tstart + 40.0 * units.min   # observation stop time

# Compute geodesics (see Tutorial1)
geos = bhnerf.emission.kerr_geodesics(
    spin, inclination, 
    num_alpha=64, num_beta=64, 
    alpha_range=[-fov_M/2, fov_M/2],
    beta_range=[-fov_M/2, fov_M/2]
)
# Angular velocity Omega = sqrt(M) / (r^{3/2} + a sqrt(M)) (prograde Keplerian form) at the geodesic radii
Omega = np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
# NOTE(review): presumably the hotspot injection time, offset by the observer distance r_o (see Tutorial2)
t_injection = -float(geos.r_o)

# Generate hotspot measurements (see Tutorial2) 
# Ground-truth 3D emission on a 64^3 grid spanning fov_M (GM/c^2 units), scaled by flux_scaling
emission_0 = flux_scaling * bhnerf.emission.generate_hotspot_xr(
    resolution=(64, 64, 64), 
    rot_axis=[0.0, 0.0, 1.0], 
    rot_angle=0.0,
    orbit_radius=5.5,
    std=0.7,
    r_isco=bhnerf.constants.isco_pro(spin),
    fov=(fov_M, 'GM/c^2')
)
# Parameters of an empty (unobserved) observation that defines the (u, v) sampling
obs_params = {
    'mjd': 57851,                       # night of april 6-7, 2017 -- NOTE(review): MJD 57851 appears to be 2017-04-08 (57849 = 2017-04-06); confirm
    'timetype': 'GMST',
    'nt': nt,                           # number of time samples 
    'tstart': tstart.to('hr').value,    # start of observations
    'tstop': tstop.to('hr').value,      # end of observation 
    'tint': 30.0,                       # integration time, (presumably seconds -- confirm ehtim convention)
    'array': eh.array.load_txt('../eht_arrays/{}.txt'.format(array))
}
obs_empty = bhnerf.observation.empty_eht_obs(**obs_params)
# Angular field of view: fov_M gravitational radii of Sgr A* (GM/c^2) divided by its distance, in radians
fov_rad = (fov_M * consts.GM_c2(consts.sgra_mass) / consts.sgra_distance.to('m')) * units.rad
# Angular size of one image pixel (the image has geos.alpha.size pixels across)
psize = fov_rad.value / geos.alpha.size 
# Copy sky position, frequency and date from the empty observation so the movie can be observed with it
obs_args = {'psize': psize, 'ra': obs_empty.ra, 'dec': obs_empty.dec, 'rf': obs_empty.rf, 'mjd': obs_empty.mjd}
# nt frame times spanning [tstart, tstop] (astropy Quantity; `.value` is in tstart's units)
t_frames = np.linspace(tstart, tstop, nt)
# Render the emission dynamics onto the image plane at every frame time
image_plane = bhnerf.emission.image_plane_dynamics(emission_0, geos, Omega, t_frames, t_injection)
movie = eh.movie.Movie(image_plane, times=t_frames.value, **obs_args)
# Observe the movie with the same sampling as obs_empty. Per the cell output this also applies Jones matrices
# and thermal noise; with seed=None the noise realization presumably differs on every run -- pass an
# integer seed for reproducible data.
obs = bhnerf.observation.observe_same(movie, obs_empty, ttype='direct', seed=None)

  result_data = func(*input_data)


Producing clean visibilities from movie with direct FT . . . 
Applying Jones Matrices to data . . . 
Applying Jones Matrices to data . . . 
Adding thermal noise to data . . . 
Applying a priori calibration with estimated Jones matrices . . . 


In [3]:
from jax import numpy as jnp
import optax
def run_optimization(runname, hparams, predictor, train_pstep, target, t_frames, geos, Omega, rmax, t_injection,
                     batched_args=None, emission_true=None, vis_res=64, log_period=100, save_period=1000):
    """
    Fit the network parameters to (batched) observations with a data-parallel (pmapped) training step.

    This is a local copy of the training loop; the tutorial cell below calls
    `bhnerf.network.run_optimization` rather than this function.

    Parameters
    ----------
    runname: str
        Checkpoints are written to '../checkpoints/<runname>' and TensorBoard logs to '../runs/<runname>'.
    hparams: dict
        Must contain 'lr_init', 'lr_final', 'num_iters' and 'batchsize'. If 'lr_inject' is present, the
        parameter named 't_injection' is optimized by a separate Adam optimizer with that learning rate.
    predictor: flax module
        Network whose `init`/`apply` are called as (variables, t, t_units, coords, Omega, t_start_obs, t_geos, t_injection).
    train_pstep: callable
        pmapped training step, called as `train_pstep(state, t_units, *sharded_batched_args, *rendering_args)`
        and returning `(loss, state, images)`. This argument order matches a pmap whose argument 1 is a static
        broadcast and whose next `len(batched_args)` arguments are device-mapped (as built in this tutorial).
    target:
        Unused here (the data enters through `batched_args`); kept for interface compatibility.
    t_frames: astropy Quantity array
        Frame times; `t_frames.unit` is used as the time unit and `t_frames[0]` as the observation start.
    geos, Omega, rmax, t_injection:
        Geodesics, angular velocity, maximum supervised radius and injection time used for rendering.
    batched_args: list, optional
        Per-frame arrays indexed along axis 0; a random batch of frames is drawn and sharded each step.
    emission_true: xarray.DataArray, optional
        Ground-truth emission; if given, its grid is used for visualization and MSE/PSNR are logged.
    vis_res: int
        Grid resolution per axis for visualization when `emission_true` is None.
    log_period, save_period: int
        Iteration periods for TensorBoard image logging and checkpointing (`save_period` non-scalar disables saving).

    Returns
    -------
    state: flax TrainState
        The final, un-replicated training state (fetched to the host).
    """
    from flax import jax_utils, traverse_util
    from flax.core import unfreeze
    from flax.training import checkpoints, train_state
    from flax.training.common_utils import shard
    from tensorboardX import SummaryWriter
    from tqdm.notebook import tqdm

    def _flattened_traversal(fn):
        """Build an optax mask function that evaluates `fn(path, leaf)` on every flattened parameter."""
        def mask(tree):
            flat = traverse_util.flatten_dict(tree)
            return traverse_util.unflatten_dict({path: fn(path, leaf) for path, leaf in flat.items()})
        return mask

    # Avoid a shared mutable default argument
    if batched_args is None:
        batched_args = []

    # Logging parameters
    checkpoint_dir = '../checkpoints/{}'.format(runname)
    logdir = '../runs/{}'.format(runname)

    # Image rendering arguments (identical for every training step)
    t_units = t_frames.unit
    coords = jnp.array([geos.x, geos.y, geos.z])
    umu = bhnerf.emission.azimuthal_velocity_vector(geos, Omega)
    g = jnp.array(bhnerf.emission.doppler_factor(geos, umu))
    Omega = jnp.array(Omega)
    dtau = jnp.array(geos.dtau)
    Sigma = jnp.array(geos.Sigma)
    t_start_obs = t_frames[0]
    t_geos = jnp.array(geos.t)
    rmin = geos.r.min().data
    rendering_args = [coords, Omega, g, dtau, Sigma, t_start_obs, t_geos, t_injection, rmin, rmax]

    # Grid for visualization (no interpolation is added for easy comparison)
    if emission_true is not None:
        vis_coords = np.array(np.meshgrid(np.linspace(emission_true.x[0], emission_true.x[-1], emission_true.shape[0]),
                                          np.linspace(emission_true.y[0], emission_true.y[-1], emission_true.shape[1]),
                                          np.linspace(emission_true.z[0], emission_true.z[-1], emission_true.shape[2]),
                                          indexing='ij'))
    else:
        grid_1d = np.linspace(-rmax, rmax, vis_res)
        vis_coords = np.array(np.meshgrid(grid_1d, grid_1d, grid_1d, indexing='ij'))

    params = predictor.init(jax.random.PRNGKey(1), t_frames[0], t_units, coords, Omega, t_start_obs, t_geos, t_injection)['params']

    tx = optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters']))
    if 'lr_inject' in hparams:
        # The injection time gets its own optimizer; every other parameter uses the scheduled Adam above
        tx = optax.chain(
            optax.masked(optax.adam(learning_rate=hparams['lr_inject']),
                         mask=_flattened_traversal(lambda path, _: path[-1] == 't_injection')),
            optax.masked(tx, mask=_flattened_traversal(lambda path, _: path[-1] != 't_injection')),
        )

    # `init` returns a FrozenDict in older flax and a plain dict in newer releases; `flax.core.unfreeze`
    # accepts both, whereas the `.unfreeze()` method only exists on FrozenDict.
    state = train_state.TrainState.create(apply_fn=predictor.apply, params=unfreeze(params), tx=tx)

    # A fresh TrainState starts at step 0; computing from state.step keeps step numbering correct on resume
    init_step = 1 + int(state.step)
    final_step = init_step + hparams['num_iters'] - 1   # last value produced by the range below

    # Replicate state for multiple gpus
    state = jax_utils.replicate(state)
    with SummaryWriter(logdir=logdir) as writer:
        if emission_true is not None:
            writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission_true), global_step=0)

        for i in tqdm(range(init_step, final_step + 1), desc='iteration'):
            batch = np.random.choice(len(t_frames), hparams['batchsize'], replace=False)
            bargs = [shard(arg[batch, ...]) for arg in batched_args]
            loss, state, images = train_pstep(state, t_units, *bargs, *rendering_args)
            is_last = (i == final_step)

            # Log the current state on TensorBoard
            writer.add_scalar('log_loss/train', np.log10(np.mean(loss)), global_step=i)
            if (i == 1) or (i % log_period == 0) or is_last:
                current_state = jax.device_get(jax_utils.unreplicate(state))
                emission_grid = current_state.apply_fn({'params': current_state.params}, t_frames[0], t_units,
                                                       vis_coords, 0.0, t_start_obs, 0.0, 0.0)
                emission_grid = bhnerf.emission.fill_unsupervised_emission(emission_grid, vis_coords, rmin, rmax)
                writer.add_images('emission/estimate', bhnerf.utils.intensity_to_nchw(emission_grid), global_step=i)
                if 'lr_inject' in hparams:
                    writer.add_scalar('t_injection', float(current_state.params['t_injection']), global_step=i)
                if emission_true is not None:
                    writer.add_scalar('emission/mse', bhnerf.utils.mse(emission_true.data, emission_grid), global_step=i)
                    writer.add_scalar('emission/psnr', bhnerf.utils.psnr(emission_true.data, emission_grid), global_step=i)

            # Save checkpoints occasionally (and always after the last iteration)
            if np.isscalar(save_period) and ((i % save_period == 0) or is_last):
                current_state = jax.device_get(jax_utils.unreplicate(state))
                checkpoints.save_checkpoint(checkpoint_dir, current_state, int(i), keep=5)

    # Always return the final state, even when the last iteration was neither logged nor saved
    return jax.device_get(jax_utils.unreplicate(state))

In [3]:
"""
Optimize network parameters to recover the 3D emission (as a continuous function) from observations.
Note that logging is done using tensorboardX. To view the tensorboard (from the main directory):
    `tensorboard --logdir runs`
"""
hparams = {'num_iters': 5000, 'lr_init': 1e-4, 'lr_final': 1e-6, 'batchsize': 6}
predictor = bhnerf.network.NeRF_Predictor()

# Data-parallel training step: argument 0 (the replicated train state) and arguments 2-5 are split across
# devices (presumably the per-frame batched_args below); argument 1 is a static, hashable broadcast
# (presumably the time units); the remaining rendering arguments are broadcast unchanged.
train_pstep = jax.pmap(bhnerf.network.train_step_eht, 
                       axis_name='batch', 
                       in_axes=(0, None, 0, 0, 0, 0, None, None, None, None, None, None, None, None, None, None), 
                       static_broadcasted_argnums=(1,))

# Observation parameters (split observations into frames)
# One data frame per rendered frame, so that data frame i lines up with t_frames[i] when batching
num_frames = nt
obs_frames = obs.split_obs(t_gather=(tstop-tstart).to('s').value / (num_frames+1))
assert len(obs_frames) == len(t_frames), \
    'expected {} observation frames, got {}'.format(len(t_frames), len(obs_frames))
prior = eh.image.make_square(obs, geos.alpha.size, fov_rad.value)
target, sigma, A = [np.array(out) for out in zip(*[eh.imaging.imager_utils.chisqdata_vis(obs_frame, prior, mask=[])
                                                    for obs_frame in obs_frames])]
batched_args = [target, sigma, A, t_frames]

# Run optimization (library implementation `bhnerf.network.run_optimization`)
# Note: rmax constrains the optimization domain to a radius.
rmax = fov_M / 2
current_time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
runname = 'tutorial3/recovery.{}'.format(current_time)
state = bhnerf.network.run_optimization(
    runname, hparams, predictor, train_pstep, target, t_frames, geos, Omega, rmax, 
    t_injection, batched_args, emission_true=emission_0, save_period=hparams['num_iters'])

Splitting Observation File into 64 times


  result_data = func(*input_data)


iteration:   0%|          | 0/5000 [00:00<?, ?it/s]

In [4]:
"""
Visualize the recovered 3D emission.
This visualization requires ipyvolume: https://ipyvolume.readthedocs.io/en/latest/
"""
import ipyvolume as ipv

# Regular (x, y, z) grid with the same coordinate range and resolution as the ground-truth emission
vis_coords = np.array(np.meshgrid(
    *[np.linspace(emission_0[dim][0], emission_0[dim][-1], num) for dim, num in zip(('x', 'y', 'z'), emission_0.shape)],
    indexing='ij'
))

# Evaluate the converged network on that grid at the observation start time, then fill the
# region outside the supervised radii [min geodesic radius, rmax]
emission_estimate = state.apply_fn({'params': state.params}, tstart, tstart.unit, vis_coords, 0.0, tstart, 0.0, 0.0)
emission_estimate = bhnerf.emission.fill_unsupervised_emission(emission_estimate, vis_coords, geos.r.min().data, rmax)

# Physical (min, max) extent of each axis for the volume rendering
extent = [(float(emission_0[dim].min()), float(emission_0[dim].max())) for dim in ('x', 'y', 'z')]

ipv.figure()
ipv.view(0, -60, distance=2.5)
ipv.volshow(emission_estimate, extent=extent, memorder='F',
            level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
ipv.show()

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…