# Tutorial 3: Estimate 3D emission from ngEHT observations

---
This tutorial demonstrates the recovery of 3D emission from ngEHT observations. The visibility measurements are used to fit the parameters of a coordinate-based neural network together with the rotation axis (dictated by the inclination angle). This tutorial assumes a precomputed sensor is available to load (see Tutorial 1 for more details).

In [1]:
import os
# Select the visible GPUs *before* importing JAX / Flax / TensorFlow-backed libraries:
# the setting is only guaranteed to take effect if it is in place before any of them
# initializes CUDA (the cell output shows CUDA libraries being probed during these imports).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import bhnerf
from bhnerf.network import flattened_traversal, shard
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import ehtim as eh
from ehtim.observing.obs_helpers import ftmatrix

import jax
import functools
from jax import jit
from jax import numpy as jnp
import flax
from flax.training import train_state
from flax.training import checkpoints
import optax

from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-yjzw_9lz because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



2022-02-02 14:58:20.298934: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs


In [2]:
"""
Generate synthetic ngEHT observations of an orbiting Gaussian hotspot. 
This cell is a condensed version of Tutorial 2.
This notebook requires eht-imaging: https://github.com/achael/eht-imaging
"""
spatial_res = (64, 64, 64)
nt = 128                     # number of emission frames; also the number of observation times below
r_isco = 3.0
orbit_radius = 3.5
std = .4
rot_angle = 0
rot_axis = [-0.5, 0.0, 0.8660254]
# Scale the r**(-3/2) (Keplerian-like) velocity profile so that velocity_field(orbit_radius) == 1.
orbit_period = orbit_radius**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

initial_frame = bhnerf.emission.generate_hotspot_xr(spatial_res, rot_axis, rot_angle, orbit_radius, std, r_isco)
emission_unnormalized = bhnerf.emission.generate_orbit(initial_frame, nt, velocity_field, rot_axis)

# Normalize emission values. The scaled field must be bound to `emission`: it is the field that is
# ray-traced below and used as the ground truth in the training and visualization cells.
normalization_factor = 0.02
with xr.set_options(keep_attrs=True):  # keep any metadata attrs attached by generate_orbit
    emission = emission_unnormalized * normalization_factor

# Load a precomputed sensor Dataset, keep only ray samples with r < sensor_r_max, and integrate
# over the 4D emission field to get a 3D (t, alpha, beta) image-plane movie.
sensor_r_max = 5
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo100_npix4096.nc')
sensor = sensor.where(sensor.r < sensor_r_max)
sensor.attrs.update(r_min=sensor.r.min().data, r_max=sensor.r.max().data)
sensor = sensor.fillna(0.0)  # masked samples become zeros (after r_min / r_max were recorded)
image_pixels = bhnerf.emission.integrate_rays(emission, sensor)
image_plane = image_pixels.data.reshape(nt, sensor.num_alpha, sensor.num_beta)


# Generate synthetic ngEHT observations of the image plane.
fov = 85.0  # field of view in micro-arcseconds (converted to a pixel size in radians via RADPERUAS)
obs_params = {
    'mjd': 57851,
    'array': eh.array.load_txt('../eht_arrays/ngEHT.txt'),
    'timetype': 'GMST',
    'nt': nt,            # number of time samples; must match the number of movie frames
    'tstart': 2.0,       # start of observations
    'tstop': 2.0 + 2/3,  # end of observation
    'tint': 15.0         # integration time
}
obs_empty = bhnerf.observation.empty_eht_obs(**obs_params)
obs_args = {
    'psize': fov / sensor.num_alpha * eh.ehc.RADPERUAS,
    'ra': obs_empty.ra,
    'dec': obs_empty.dec,
    'rf': obs_empty.rf,
    'mjd': obs_empty.mjd,
}
# One movie frame per time sample, spanning the observation window.
times = np.linspace(obs_params['tstart'], obs_params['tstop'], obs_params['nt'])
movie = eh.movie.Movie(image_plane, times=times, **obs_args)
obs = bhnerf.observation.observe_same(movie, obs_empty, ttype='direct', seed=None)

# Stack measurements and associated parameters for down-stream optimization.
# Padding: visibilities are padded with 0 and sigma with inf, so padded entries carry zero weight in a
# |residual| / sigma loss; NaN entries of the DTFT matrices (from padded uv) are replaced by 0.
measurements = bhnerf.observation.padded_obs(obs, 'vis', fill_value=0.0)
sigma = bhnerf.observation.padded_obs(obs, 'sigma', fill_value=np.inf)
uv = np.stack((bhnerf.observation.padded_obs(obs, 'u'), bhnerf.observation.padded_obs(obs, 'v')), axis=2)
ft_mats = np.nan_to_num(np.stack([ftmatrix(movie.psize, movie.xdim, movie.ydim, uv_t, pulse=movie.pulse) for uv_t in uv]), 0.0)
# Timestamp of each per-time data group (the first row's time represents the group).
obs_times = np.array([obsdata['time'][0] for obsdata in obs.tlist()])

Producing clean visibilities from movie with direct FT . . . 
Applying Jones Matrices to data . . . 
Applying Jones Matrices to data . . . 
Adding thermal noise to data . . . 
Applying a priori calibration with estimated Jones matrices . . . 


In [3]:
"""
Visualize the underlying image plane
"""
%matplotlib widget
image_plane_xr = xr.DataArray(image_plane, dims=['t', 'alpha', 'beta'])
image_plane_xr.visualization.animate(cmap='afmhot')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7f5f50761670>

In [5]:
"""
Define the forward operator (generate synthetic ngEHT measurements), loss function and training step.
"""
# Initial guess for the rotation axis: sampled randomly once, then kept fixed across experiments.
axis_init = jnp.array([-0.4556, 0.5549, 0.6961])

# Temporal extent of the observation and the velocity profile that drives the orbit.
orbit_args = dict(
    tstart=obs_times[0],
    tstop=obs_times[-1],
    axis_init=axis_init,
    velocity=velocity_field,
)

# Measurement model: the network predictor is wrapped by an emission operator, which in turn is
# wrapped by a visibility operator (used by the loss below).
predictor = bhnerf.network.NeRF_RotationAxis()
emission_op = bhnerf.network.EmissionOperator(predictor, orbit_args)
visibility_op = bhnerf.network.VisibilityOperator(emission_op)

def loss_fn(params, coordinates, d, target, sigma, dtft_matrices):
    """
    Noise-weighted mean squared error between predicted and measured visibilities.

    Returns `(loss, [visibilities, images])` so it can be used with
    `jax.value_and_grad(..., has_aux=True)`; the predictions ride along as auxiliary outputs.
    With sigma padded by inf (as in this notebook), padded entries contribute zero to the
    sum but are still counted by the mean.
    """
    pred_vis, pred_images = visibility_op(params, coordinates, d, dtft_matrices)
    weighted_residual = jnp.abs(pred_vis - target) / sigma
    return jnp.mean(weighted_residual**2), [pred_vis, pred_images]

def train_step(loss_fn, state, x, y, z, t, d, target, sigma, dtft_matrices):
    """
    One data-parallel optimization step, meant to run under `jax.pmap(..., axis_name='batch')`.

    No `jit` decorator: `pmap` already compiles the function, and the `pmean` over the
    'batch' axis means it cannot be called outside `pmap` anyway.

    Args:
        loss_fn: callable returning `(loss, [visibilities, images])`; must be hashable
            because it is passed as a static (broadcast) argument.
        state: Flax TrainState (per-device replica).
        x, y, z, t: coordinate arrays, passed to `loss_fn` as `[t, x, y, z]`.
        d, target, sigma, dtft_matrices: forwarded unchanged to `loss_fn`.

    Returns:
        (loss, updated state, predicted visibilities, predicted images) for this device.
    """
    grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)
    (loss, [visibilities, images]), grads = grad_fn(
        state.params, [t, x, y, z], d, target, sigma, dtft_matrices)
    # Average gradients across devices so every replica applies the identical update.
    grads = jax.lax.pmean(grads, axis_name='batch')
    state = state.apply_gradients(grads=grads)
    return loss, state, visibilities, images

# Parallel mapping across GPUs: argument 0 (`loss_fn`) is broadcast to every device as a static
# argument; the remaining nine arguments are split along their leading (device) axis.
train_pstep = jax.pmap(
    train_step,
    axis_name='batch',
    in_axes=(None,) + (0,) * 9,
    static_broadcasted_argnums=0,
)

In [6]:
"""
Define training parameters and run the training loop.
"""
log_period = 100       # Logging frequency to TensorBoard
save_period = 1000     # Checkpoint frequency; set to None to disable checkpoint restore/save

runname = 'tutorial3'
checkpoint_dir = 'checkpoints/{}'.format(runname)
logdir = 'runs/{}.{}'.format(runname, datetime.now().strftime('%Y-%m-%d.%H:%M:%S'))

# Training params
hparams = {
    'num_iters': 5000,
    'lr_init': 1e-4,
    'lr_final': 1e-6,
    'batchsize': 6,
    'lr_axis': 1e-2,
}

# Define supervision coordinates, batched along time.
# NOTE(review): the unpacking relies on the key order of the dict returned by get_input_coords;
# confirm it is (t, x, y, z, d).
t, x, y, z, d = bhnerf.network.get_input_coords(sensor, t_array=obs_times, batch='t').values()

# Grid for visualization (no interpolation is added for easy comparison)
vis_coords = np.meshgrid(orbit_args['tstart'],
                         np.linspace(emission.x[0], emission.x[-1], spatial_res[0]),
                         np.linspace(emission.y[0], emission.y[-1], spatial_res[1]),
                         np.linspace(emission.z[0], emission.z[-1], spatial_res[2]),
                         indexing='ij')

params = predictor.init(jax.random.PRNGKey(1), [t[:1, ...], x[:1, ...], y[:1, ...], z[:1, ...]], **orbit_args)['params']

# Split learning rate: a constant Adam rate for the rotation-axis parameter, and a linearly
# decaying (power-1 polynomial) Adam rate for all other (network) parameters.
axis_mask = flattened_traversal(lambda path, _: path[-1] == 'axis')
network_mask = flattened_traversal(lambda path, _: path[-1] != 'axis')
network_lr = optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters'])
tx = optax.chain(optax.masked(optax.adam(learning_rate=hparams['lr_axis']), mask=axis_mask),
                 optax.masked(optax.adam(learning_rate=network_lr), mask=network_mask))

# `predictor.init` returns a FrozenDict in older Flax and a plain dict in newer Flax;
# flax.core.unfreeze accepts both, whereas the `.unfreeze()` method only exists on FrozenDict.
state = train_state.TrainState.create(apply_fn=predictor.apply, params=flax.core.unfreeze(params), tx=tx)

# Restore the latest saved checkpoint, if any (restore_checkpoint returns `state` unchanged when
# none exists). Training then resumes at the step after the restored one.
if np.isscalar(save_period):
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
init_step = 1 + int(state.step)  # plain Python int: safe for range() and step comparisons

state = flax.jax_utils.replicate(state)  # For parallelization across GPUs

# Training loop with TensorBoard logging.
# Iterations run from init_step to final_step inclusive, so a run resumed from a checkpoint
# trains for another `num_iters` steps and still saves at its true final iteration.
final_step = init_step + hparams['num_iters'] - 1
with SummaryWriter(logdir=logdir) as writer:

    writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission.isel(t=0)), global_step=0)

    for i in tqdm(range(init_step, final_step + 1), desc='iteration'):
        batch_inds = np.random.choice(x.shape[0], hparams['batchsize'], replace=False)
        loss, state, visibilities, images = train_pstep(
            loss_fn, state, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]),
            shard(t[batch_inds, ...]), shard(d[batch_inds, ...]), shard(measurements[batch_inds, ...]),
            shard(sigma[batch_inds, ...]), shard(ft_mats[batch_inds, ...])
        )

        # Log the current state on TensorBoard (at the first iteration of this run and every log_period).
        writer.add_scalar('log_loss/train', np.log10(np.mean(loss)), global_step=i)
        is_log_step = (i == init_step) or (i % log_period == 0)
        if is_log_step:
            current_state = jax.device_get(flax.jax_utils.unreplicate(state))
            emission_grid = emission_op(current_state.params, vis_coords)
            emission_grid = bhnerf.emission.zero_unsupervised_emission(emission_grid, vis_coords[1:], sensor.r_min, sensor.r_max)
            rot_axis_estimate = bhnerf.utils.normalize(current_state.params['axis'])
            # images is indexed [device, per-device batch, ...]; [0, 0] corresponds to batch_inds[0].
            writer.add_image('image_plane/true', image_plane[batch_inds[0], None, ...], global_step=i)
            writer.add_image('image_plane/estimate', images[0, 0, None, :, :], global_step=i)
            writer.add_images('emission/estimate', bhnerf.utils.intensity_to_nchw(emission_grid), global_step=i)
            writer.add_scalar('emission/mse', bhnerf.utils.mse(emission.data[0], emission_grid), global_step=i)
            writer.add_scalar('emission/psnr', bhnerf.utils.psnr(emission.data[0], emission_grid), global_step=i)
            writer.add_scalar('rotation/dot_product', np.dot(rot_axis, rot_axis_estimate), global_step=i)

        # Save checkpoints every save_period iterations and at the final iteration. Refresh
        # current_state unless the logging branch already did so this iteration; otherwise a
        # stale state would be written under the current step number.
        if np.isscalar(save_period) and ((i % save_period == 0) or (i == final_step)):
            if not is_log_step:
                current_state = jax.device_get(flax.jax_utils.unreplicate(state))
            checkpoints.save_checkpoint(checkpoint_dir, current_state, int(i), keep=5)



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]

In [7]:
"""
Visualize the recovered 3D emission estimated from ngEHT measurements.
This visualization requires ipyvolume: https://ipyvolume.readthedocs.io/en/latest/
"""
import ipyvolume as ipv  # optional dependency, only needed for this visualization

# Take one copy of the replicated training state (the converged solution) and evaluate the
# network on the visualization grid; zero_unsupervised_emission presumably zeroes values
# outside the sensor's supervised [r_min, r_max] range.
current_state = jax.device_get(flax.jax_utils.unreplicate(state))
emission_grid = bhnerf.emission.zero_unsupervised_emission(
    emission_op(current_state.params, vis_coords),
    vis_coords[1:], sensor.r_min, sensor.r_max
)

# Bounding box of the ground-truth grid, one (min, max) pair per axis, used to place the volume.
extent = [
    (float(emission[axis_name].min()), float(emission[axis_name].max()))
    for axis_name in ('x', 'y', 'z')
]
ipv.figure()
ipv.view(0, -60, distance=2.5)
ipv.volshow(
    emission_grid,
    extent=extent,
    memorder='F',
    level=[0, 0.2, 0.7],
    opacity=[0, 0.2, 0.3],
    controls=False,
)
ipv.show()

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…