# Analyze the effect of measurement sparsity on 3D emission recovery

---
This notebook analyzes the effect of measurement sparsity on the recovery of 3D emission. The measurements used for recovery increase in sparsity, from the full image plane to ngEHT visibilities and then EHT visibilities. Emission is recovered either jointly with an unknown rotation axis or (for comparison) with a known rotation axis. In both cases the measurements are used to fit the parameters of a coordinate-based neural network.

In [3]:
import bhnerf
from bhnerf.network import flattened_traversal, shard
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import ehtim as eh
from ehtim.observing.obs_helpers import ftmatrix

import jax
import functools
from jax import jit
from jax import numpy as jnp
import jax.scipy as jsp
import flax
from flax.training import train_state
from flax.training import checkpoints
import optax

from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

In [4]:
"""
Generate emission hotspots and image plane movies
"""
# Hotspot / orbit configuration. These values are passed to
# bhnerf.emission.generate_hotspot_xr and bhnerf.emission.generate_orbit below.
spatial_res = (64, 64, 64)
nt = 128
r_isco = 3.0
orbit_radius = 3.5 
std = .4

# With this scaling velocity_field(orbit_radius) == 1.0, i.e. the velocity profile is
# normalized to unity at the hotspot orbit radius and falls off as r**(-3/2).
orbit_period = orbit_radius**(-3./2.) 
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

# Rotation axis shared by all hotspots, and one angle per hotspot (five hotspots in total).
rot_axis = [-0.5, 0.0, 0.8660254]
rot_angles = np.array([0.73631078, 2.9943305 , 3.38702958, 4.17242774, 5.39961237])

# The initial frame is the sum of the individual hotspot volumes.
initial_frame = 0.0
for rot_angle in rot_angles: 
    initial_frame += bhnerf.emission.generate_hotspot_xr(spatial_res, rot_axis, rot_angle, orbit_radius, std, r_isco)
    
# Shear the initial frame into a time-dependent emission field with `nt` frames.
# `normalization_factor` is only applied further below (see `emission_normalized`).
normalization_factor = 0.02
emission = bhnerf.emission.generate_orbit(initial_frame, nt, velocity_field, rot_axis)

# Load a precomputed sensor Dataset and integrate over the 4D emission field to get 3D image-plane movie.
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo100_npix4096.nc')
# Keep only entries with r < 5 (the rest become NaN), record the remaining radial
# range as attributes, and only then replace the NaNs with zeros. The order matters:
# r_min / r_max must be computed before the fill.
sensor = sensor.where(sensor.r < 5)
sensor.attrs.update(r_min=sensor.r.min().data, r_max=sensor.r.max().data)
sensor = sensor.fillna(0.0)
image_pixels = bhnerf.emission.integrate_rays(emission, sensor)
# One (num_alpha, num_beta) image per time frame.
image_plane = image_pixels.data.reshape(nt, sensor.num_alpha, sensor.num_beta)

# anti-aliasing for visualization
# The window is a separable 2D Gaussian (outer product of two 1D normal pdfs) normalized to unit sum.
x = jnp.linspace(-sensor.num_alpha/2.0, sensor.num_alpha/2.0, sensor.num_alpha)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
window = window / window.sum()
image_plane_blurred = bhnerf.utils.anti_aliasing_filter(image_plane, window)

# normalization for `reasonable` EHT fluxes
# The scaled emission is re-integrated through the same sensor; this normalized movie
# is the one used to synthesize the visibility measurements in a later cell.
emission_normalized = normalization_factor * emission
image_pixels_normalized = bhnerf.emission.integrate_rays(emission_normalized, sensor)
image_plane_normalized = image_pixels_normalized.data.reshape(nt, sensor.num_alpha, sensor.num_beta)

In [5]:
"""
Visualize image-plane movie and ngEHT measurements
"""
image_plane_xr = xr.DataArray(image_plane, dims=['t', 'alpha', 'beta'])
image_plane_blurred_xr = xr.DataArray(image_plane_blurred, dims=['t', 'alpha', 'beta'])

%matplotlib widget
fig, ax = plt.subplots(1, 2, figsize=(6,3.5))
ax[0].set_title('Image plane')
ax[1].set_title('Image plane w/ anti-aliasing')
anim1 = image_plane_xr.visualization.animate(ax=ax[0], cmap='afmhot', add_colorbar=False)
anim2 = image_plane_blurred_xr.visualization.animate(ax=ax[1], cmap='afmhot', add_colorbar=False)
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:
"""
Slider visualization to illustrate the volumetric shearing
This visualization requires ipyvolume: https://ipyvolume.readthedocs.io/en/latest/
"""
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

# (min, max) bounds of the emission grid along each spatial dimension
extent = []
for dim in ('x', 'y', 'z'):
    axis_coords = emission[dim]
    extent.append((float(axis_coords.min()), float(axis_coords.max())))

# Slider over all time indices of the emission field
frame_slider = widgets.IntSlider(min=0, max=emission.t.size-1, step=1, value=0)

@interact(t=frame_slider)
def plot_vol(t):
    """Volume-render the emission frame at time index `t` using ipyvolume."""
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    frame = emission.isel(t=t)
    ipv.volshow(frame, extent=extent, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=127), Output()), _dom_classes=('widget-interact'…

In [7]:
"""
Generate synthetic ngEHT/EHT observations of the image plane
"""
# Field of view; converted to a pixel size in radians below via eh.ehc.RADPERUAS
fov = 85.0
obs_params = {
    'mjd': 57851,
    'timetype': 'GMST',
    'nt': 128,           # number of time samples 
    'tstart': 2.0,       # start of observations
    'tstop': 2.0 + 2/3,  # end of observation 
    'tint': 15.0         # integration time
}

# Movie frame times do not depend on the array, so compute them once
times = np.linspace(obs_params['tstart'], obs_params['tstop'], obs_params['nt'])

visibilities = dict()
for array in ['ngEHT', 'EHT2017']:
    obs_params['array'] = eh.array.load_txt('../eht_arrays/{}.txt'.format(array))
    obs_empty = bhnerf.observation.empty_eht_obs(**obs_params)
    obs_args = {
        'psize': fov / sensor.num_alpha * eh.ehc.RADPERUAS,
        'ra': obs_empty.ra, 
        'dec': obs_empty.dec,
        'rf': obs_empty.rf, 
        'mjd': obs_empty.mjd,
    }
    movie = eh.movie.Movie(image_plane_normalized, times=times, **obs_args)
    # NOTE(review): seed=None means the noise realization is not pinned, so the
    # synthetic measurements differ between runs -- confirm this is intended.
    obs = bhnerf.observation.observe_same(movie, obs_empty, ttype='direct', seed=None)

    # Stack measurements and associated parameters for down-stream optimization
    uv = np.stack((bhnerf.observation.padded_obs(obs, 'u'), bhnerf.observation.padded_obs(obs, 'v')), axis=2)
    # One time stamp per observation scan. This is overwritten on every pass, so
    # later cells see the values from the last array in the list.
    obs_times = np.array([np.mean(obsdata['time'][0]) for obsdata in obs.tlist()])

    # One DTFT matrix per time step. NaN entries (presumably from padded uv
    # points -- TODO confirm) are zeroed. `nan` must be passed by keyword:
    # the second positional argument of np.nan_to_num is `copy`, not `nan`.
    ft_mats = np.stack([ftmatrix(movie.psize, movie.xdim, movie.ydim, uv_t, pulse=movie.pulse) for uv_t in uv])
    ft_mats = np.nan_to_num(ft_mats, nan=0.0)

    visibilities[array] = {
        'obs': obs,
        # Padding uses zero visibilities with infinite sigma, so padded entries
        # contribute nothing to the sigma-weighted residual in loss_fn_eht.
        'measurements': bhnerf.observation.padded_obs(obs, 'vis', fill_value=0.0),
        'sigma': bhnerf.observation.padded_obs(obs, 'sigma', fill_value=np.inf),
        'ft_mats': ft_mats
    }

Producing clean visibilities from movie with direct FT . . . 
Applying Jones Matrices to data . . . 
Applying Jones Matrices to data . . . 
Adding thermal noise to data . . . 
Applying a priori calibration with estimated Jones matrices . . . 
Producing clean visibilities from movie with direct FT . . . 
Applying Jones Matrices to data . . . 
Applying Jones Matrices to data . . . 
Adding thermal noise to data . . . 
Applying a priori calibration with estimated Jones matrices . . . 


In [8]:
"""
Visualize visibility measurements
"""
%matplotlib widget
fig, ax = plt.subplots(1, 2, figsize=(7,3.5))
for i, (array, data) in enumerate(visibilities.items()):
    ax[i].set_title(array)
    data['obs'].plotall('uvdist', 'amp', axis=ax[i])
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
"""
Define the loss functions and training steps for image plane and visibility optimization
"""
def loss_fn_image(params, coordinates, d, target):
    """
    Mean squared error between predicted and target image-plane frames.

    Relies on the module-level `measurement_op`, which is assigned in the
    training cell before this function is traced.

    Returns
    -------
    (loss, [images]): the scalar loss and the predicted images as auxiliary output.
    """
    predicted = measurement_op(params, coordinates, d)
    residual = predicted - target
    return jnp.mean(jnp.abs(residual)**2), [predicted]

def loss_fn_eht(params, coordinates, d, target, sigma, dtft_matrices):
    """
    Mean of the squared, sigma-weighted visibility residuals.

    Relies on the module-level `measurement_op`, which is assigned in the
    training cell before this function is traced. Entries whose `sigma` is
    infinite contribute zero to the loss.

    Returns
    -------
    (loss, [visibilities, images]): the scalar loss plus the predicted
    visibilities and images as auxiliary outputs.
    """
    vis_pred, img_pred = measurement_op(params, coordinates, d, dtft_matrices)
    weighted_residual = jnp.abs(vis_pred - target) / sigma
    chi2 = jnp.mean(weighted_residual**2)
    return chi2, [vis_pred, img_pred]

@functools.partial(jit, static_argnums=0)
def train_step_image(loss_fn, state, x, y, z, t, d, target):
    """
    One gradient step on image-plane measurements.

    `loss_fn` is a static argument and must return `(loss, [images])`.
    Gradients are averaged over the 'batch' axis, so this function has to run
    inside a `jax.pmap` that declares `axis_name='batch'`.

    Returns
    -------
    (loss, state, images): the loss, the updated train state and the predicted images.
    """
    value_and_grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)
    coordinates = [t, x, y, z]
    (loss, [images]), grads = value_and_grad_fn(state.params, coordinates, d, target)
    # Average the gradients across devices before applying the update
    mean_grads = jax.lax.pmean(grads, axis_name='batch')
    new_state = state.apply_gradients(grads=mean_grads)
    return loss, new_state, images

@functools.partial(jit, static_argnums=0)
def train_step_eht(loss_fn, state, x, y, z, t, d, target, sigma, dtft_matrices):
    """
    One gradient step on visibility measurements.

    `loss_fn` is a static argument and must return `(loss, [visibilities, images])`.
    Gradients are averaged over the 'batch' axis, so this function has to run
    inside a `jax.pmap` that declares `axis_name='batch'`.

    Returns
    -------
    (loss, state, visibilities, images): the loss, the updated train state and
    the predicted visibilities and images.
    """
    value_and_grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)
    coordinates = [t, x, y, z]
    (loss, [visibilities, images]), grads = value_and_grad_fn(
        state.params, coordinates, d, target, sigma, dtft_matrices)
    # Average the gradients across devices before applying the update
    mean_grads = jax.lax.pmean(grads, axis_name='batch')
    new_state = state.apply_gradients(grads=mean_grads)
    return loss, new_state, visibilities, images

In [None]:
"""
Define training parameters and run the training loop.
"""
# Run identification and the experiment grid swept by the training loop
runname = 'sparsity_test'
current_time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
estimate_axis_flag = [False, True]
measurement_types = ['full_image', 'ngEHT', 'EHT2017']

log_period = 100       # Logging frequency to TensorBoard
save_period = 1000     # Saving checkpoints 

# Optimization hyper-parameters ('lr_axis' is filled in by the training loop)
hparams = dict(
    num_iters=5000,
    lr_init=1e-4,
    lr_final=1e-6,
    batchsize=6,
)

# Orbit description handed to the network ('axis_init' is filled in by the training loop)
orbit_args = dict(
    tstart=obs_times[0],
    tstop=obs_times[-1],
    velocity=velocity_field,
)

# Define supervision coordinates
input_coords = bhnerf.network.get_input_coords(sensor, t_array=obs_times, batch='t')
t, x, y, z, d = input_coords.values()

# Grid for visualization (no interpolation is added for easy comparison) 
# A single time sample (tstart) combined with the full spatial grid of the true emission.
spatial_axes = [
    np.linspace(emission[dim][0], emission[dim][-1], num_points)
    for dim, num_points in zip(('x', 'y', 'z'), spatial_res)
]
vis_coords = np.meshgrid(orbit_args['tstart'], *spatial_axes, indexing='ij')

# Sweep over (known / estimated rotation axis) x (measurement type). Each combination
# trains from the same initial parameters and gets its own TensorBoard log directory
# and checkpoint directory.
for estimate_axis in estimate_axis_flag:

    if estimate_axis:
        # Learn the axis, starting from an initial guess that differs from rot_axis
        axis_str = 'estimated_axis'
        hparams['lr_axis'] = 1e-2
        orbit_args['axis_init'] = jnp.array([-0.4556, 0.5549, 0.6961])
    else:
        # A zero learning rate keeps the axis parameter at the true axis
        axis_str = 'true_axis'
        hparams['lr_axis'] = 0.0
        orbit_args['axis_init'] = rot_axis
    
    # Measurement model and loss/training setup
    predictor = bhnerf.network.NeRF_RotationAxis()
    emission_op = bhnerf.network.EmissionOperator(predictor, orbit_args)
    params = predictor.init(jax.random.PRNGKey(1), [t[:1, ...], x[:1, ...], y[:1, ...], z[:1, ...]], **orbit_args)['params']
    
    # Split learning rate for axis / network parameters
    tx = optax.chain(optax.masked(optax.adam(learning_rate=hparams['lr_axis']), mask=flattened_traversal(lambda path, _: path[-1] == 'axis')),
                     optax.masked(optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters'])),
                                  mask=flattened_traversal(lambda path, _: path[-1] != 'axis')))
    
    for measurement_type in measurement_types:
        # NOTE: `measurement_op` is read as a global by loss_fn_image / loss_fn_eht,
        # so it must be assigned here before the first training step is traced.
        if measurement_type == 'full_image':
            emission_true = emission
            measurement_op = bhnerf.network.ImagePlaneOperator(emission_op)
            train_pstep = jax.pmap(train_step_image, axis_name='batch', in_axes=(None, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
        elif measurement_type in visibilities.keys():
            # Look the entries up by key instead of relying on dict insertion order
            vis_data = visibilities[measurement_type]
            obs = vis_data['obs']
            measurements = vis_data['measurements']
            sigma = vis_data['sigma']
            ft_mats = vis_data['ft_mats']
            emission_true = emission_normalized
            measurement_op = bhnerf.network.VisibilityOperator(emission_op)
            train_pstep = jax.pmap(train_step_eht, axis_name='batch', in_axes=(None, 0, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
        else:
            raise AttributeError('Undefined measurement type: {}'.format(measurement_type))
    
        # flax.core.unfreeze accepts both a FrozenDict and a plain dict, so this works
        # regardless of which of the two `predictor.init` returns.
        state = train_state.TrainState.create(apply_fn=predictor.apply, params=flax.core.unfreeze(params), tx=tx)
        full_runname = runname + '/{}.nspots{}.{}.{}'.format(measurement_type, len(rot_angles), axis_str, current_time)
        checkpoint_dir = 'checkpoints/{}'.format(full_runname)
        logdir = 'runs/{}'.format(full_runname)

        # Restore saved checkpoint
        if np.isscalar(save_period): state = checkpoints.restore_checkpoint(checkpoint_dir, state)
        init_step = 1 + int(state.step)
        last_step = init_step + hparams['num_iters'] - 1
        state = flax.jax_utils.replicate(state)

        # Training loop with TensorBoard logging 
        with SummaryWriter(logdir=logdir) as writer:
            writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission_true.isel(t=0)), global_step=0)

            for i in tqdm(range(init_step, last_step + 1), desc='iteration'):
                # Random subset of time frames; `shard` splits it across the pmapped devices
                batch_inds = np.random.choice(range(x.shape[0]), hparams['batchsize'], replace=False)
                
                if measurement_type == 'full_image':
                    loss, state, images = train_pstep(
                        loss_fn_image, state, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                        shard(t[batch_inds, ...]), shard(d[batch_inds, ...]), shard(image_plane[batch_inds, ...])
                    )
                else:
                    loss, state, vis, images = train_pstep(
                        loss_fn_eht, state, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                        shard(t[batch_inds, ...]), shard(d[batch_inds, ...]), shard(measurements[batch_inds, ...]), 
                        shard(sigma[batch_inds, ...]), shard(ft_mats[batch_inds, ...])
                    )

                # Log the current state on TensorBoard
                writer.add_scalar('log_loss/train', np.log10(np.mean(loss)), global_step=i)
                if (i == init_step) or (i % log_period) == 0:
                    current_state = jax.device_get(flax.jax_utils.unreplicate(state))
                    emission_grid = emission_op(current_state.params, vis_coords)
                    emission_grid = bhnerf.emission.zero_unsupervised_emission(emission_grid, vis_coords[1:], sensor.r_min, sensor.r_max)
                    rot_axis_estimate = bhnerf.utils.normalize(current_state.params['axis'])
                    writer.add_image('image_plane/true', image_plane[batch_inds[0], None, ...], global_step=i)
                    writer.add_image('image_plane/estimate', images[0, 0, None, :, :], global_step=i)
                    writer.add_images('emission/estimate', bhnerf.utils.intensity_to_nchw(emission_grid), global_step=i)
                    writer.add_scalar('emission/mse', bhnerf.utils.mse(emission_true.data[0], emission_grid), global_step=i)
                    writer.add_scalar('emission/psnr', bhnerf.utils.psnr(emission_true.data[0], emission_grid), global_step=i)
                    writer.add_scalar('rotation/dot_product', np.dot(rot_axis, rot_axis_estimate), global_step=i)

                # Save checkpoints occasionally, and always at the last iteration of this run.
                # The state is fetched from the devices right here so the checkpoint can never
                # hold parameters left over from an earlier logging step.
                if np.isscalar(save_period) and ((i % save_period == 0) or (i == last_step)):
                    current_state = jax.device_get(flax.jax_utils.unreplicate(state))
                    checkpoints.save_checkpoint(checkpoint_dir, current_state, int(i), keep=5)