# Analyze the effects of measurements sparsity on 3D emission recovery

---
This notebook analyses the effects of measurements sparsity on the recovery of 3D emission. The measurements used for recovery increase in sparsity from full image plane to ngEHT and EHT visibilities. Emission is either recovered jointly with an unknown rotation axis as well as (for comparison) with a known rotation axis. Measurements are used to fit the parameters of a coordinate-based neural network.

In [1]:
import bhnerf
from bhnerf.network import flattened_traversal, shard
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import ehtim as eh
from ehtim.observing.obs_helpers import ftmatrix

import jax
import functools
from jax import jit
from jax import numpy as jnp
import jax.scipy as jsp
import flax
from flax.training import train_state
from flax.training import checkpoints
import optax

from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-thif4_nw because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



2022-01-19 17:25:59.387410: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs


In [2]:
"""
Generate emission hotspots and image plane movies
"""
spatial_res = (64, 64, 64)
nt = 128
r_isco = 3.0
orbit_radius = 3.5 
std = .4

orbit_period = orbit_radius**(-3./2.) 
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

rot_axis = [-0.5, 0.0, 0.8660254]
rot_angles = np.array([0.73631078, 2.9943305 , 3.38702958, 4.17242774, 5.39961237])

initial_frame = 0.0
for rot_angle in rot_angles: 
    initial_frame += bhnerf.emission.generate_hotspot_xr(spatial_res, rot_axis, rot_angle, orbit_radius, std, r_isco)
    
# Shear and normalize emission values
normalization_factor = 0.02
emission = bhnerf.emission.generate_orbit(initial_frame, nt, velocity_field, rot_axis)

# Load a precomuted sensor Dataset and integrate over the 4D emission field to get 3D image-plane movie.
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo100_npix4096.nc')
sensor = sensor.where(sensor.r < 5)
sensor.attrs.update(r_min=sensor.r.min().data, r_max=sensor.r.max().data)
sensor = sensor.fillna(0.0)
image_pixels = bhnerf.emission.integrate_rays(emission, sensor)
image_plane = image_pixels.data.reshape(nt, sensor.num_alpha, sensor.num_beta)

# anti-aliasing for visualization
x = jnp.linspace(-sensor.num_alpha/2.0, sensor.num_alpha/2.0, sensor.num_alpha)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
window = window / window.sum()
image_plane_blurred = bhnerf.utils.anti_aliasing_filter(image_plane, window)

# normalization for `reasonable` EHT fluxes
emission_normalized = normalization_factor * emission
image_pixels_normalized = bhnerf.emission.integrate_rays(emission_normalized, sensor)
image_plane_normalized = image_pixels_normalized.data.reshape(nt, sensor.num_alpha, sensor.num_beta)

In [3]:
"""
Visualize image-plane movie and ngEHT measurements
"""
image_plane_xr = xr.DataArray(image_plane, dims=['t', 'alpha', 'beta'])
image_plane_blurred_xr = xr.DataArray(image_plane_blurred, dims=['t', 'alpha', 'beta'])

%matplotlib widget
fig, ax = plt.subplots(1, 2, figsize=(6,3.5))
ax[0].set_title('Image plane')
ax[1].set_title('Image plane w/ anti-aliasing')
anim1 = image_plane_xr.visualization.animate(ax=ax[0], cmap='afmhot', add_colorbar=False)
anim2 = image_plane_blurred_xr.visualization.animate(ax=ax[1], cmap='afmhot', add_colorbar=False)
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [4]:
"""
Slider visualization to illustrate the volumetric shearing
This visualization requires ipyvolume: https://ipyvolume.readthedocs.io/en/latest/
"""
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

extent = [(float(emission[dim].min()), float(emission[dim].max())) for dim in ('x', 'y', 'z')]
@interact(t=widgets.IntSlider(min=0, max=emission.t.size-1, step=1, value=0))
def plot_vol(t):
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(emission.isel(t=t), extent=extent, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=127), Output()), _dom_classes=('widget-interact'…

In [6]:
"""
Generate synthetic ngEHT/EHT observations of the image plane
"""
fov = 85.0             
obs_params = {
    'mjd': 57851,
    'timetype': 'GMST',
    'nt': 128,           # number of time samples 
    'tstart': 2.0,       # start of observations
    'tstop': 2.0 + 2/3,  # end of observation 
    'tint': 15.0         # integration time
}

visibilities = dict()
for array in ['ngEHT', 'EHT2017']:
    obs_params['array'] = eh.array.load_txt('../eht_arrays/{}.txt'.format(array))
    obs_empty = bhnerf.observation.empty_eht_obs(**obs_params)
    obs_args = {
        'psize': fov / sensor.num_alpha * eh.ehc.RADPERUAS,
        'ra': obs_empty.ra, 
        'dec': obs_empty.dec,
        'rf': obs_empty.rf, 
        'mjd': obs_empty.mjd,
    }
    times = np.linspace(obs_params['tstart'], obs_params['tstop'], obs_params['nt'])
    movie = eh.movie.Movie(image_plane_normalized, times=times, **obs_args)
    obs = bhnerf.observation.observe_same(movie, obs_empty, ttype='direct', seed=None)

    # Stack measurements and associated parameters for down-stream optimization
    uv = np.stack((bhnerf.observation.padded_obs(obs, 'u'), bhnerf.observation.padded_obs(obs, 'v')), axis=2)
    obs_times = np.array([np.mean(obsdata['time'][0]) for obsdata in obs.tlist()])
    visibilities[array] = {
        'obs': obs,
        'measurements': bhnerf.observation.padded_obs(obs, 'vis', fill_value=0.0),
        'sigma': bhnerf.observation.padded_obs(obs, 'sigma', fill_value=np.inf),
        'ft_mats': np.nan_to_num(np.stack([ftmatrix(movie.psize, movie.xdim, movie.ydim, uv_t, pulse=movie.pulse) for uv_t in uv]), 0.0)
    }

Producing clean visibilities from movie with direct FT . . . 
Applying Jones Matrices to data . . . 
Applying Jones Matrices to data . . . 
Adding thermal noise to data . . . 
Applying a priori calibration with estimated Jones matrices . . . 
Producing clean visibilities from movie with direct FT . . . 
Applying Jones Matrices to data . . . 
Applying Jones Matrices to data . . . 
Adding thermal noise to data . . . 
Applying a priori calibration with estimated Jones matrices . . . 


In [7]:
"""
Visualize visibility measurements
"""
%matplotlib widget
fig, ax = plt.subplots(1, 2, figsize=(7,3.5))
for i, (array, data) in enumerate(visibilities.items()):
    ax[i].set_title(array)
    data['obs'].plotall('uvdist', 'amp', axis=ax[i])
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
"""
Define the loss functions and training steps for image plane and visibility optimization
"""
def loss_fn_image(params, x, y, z, t, d, target):
    images = measurement_op(params, x, y, z, t, d)
    loss = jnp.mean(jnp.abs(images - target)**2)
    return loss, [images]

def loss_fn_eht(params, x, y, z, t, d, target, sigma, dtft_matrices):
    visibilities, images = measurement_op(params, x, y, z, t, d, dtft_matrices)
    loss = jnp.mean((jnp.abs(visibilities - target)/sigma)**2)
    return loss, [visibilities, images]

@functools.partial(jit, static_argnums=(0))
def train_step_image(loss_fn, state, x, y, z, t, d, target):
    (loss, [images]), grads = jax.value_and_grad(loss_fn, argnums=(0), has_aux=True)(state.params, x, y, z, t, d,  target)
    grads = jax.lax.pmean(grads, axis_name='batch')
    state = state.apply_gradients(grads=grads)
    return loss, state, images

@functools.partial(jit, static_argnums=(0))
def train_step_eht(loss_fn, state, x, y, z, t, d, target, sigma, dtft_matrices):
    (loss, [visibilities, images]), grads = jax.value_and_grad(loss_fn, argnums=(0), has_aux=True)(
        state.params, x, y, z, t, d,  target, sigma, dtft_matrices)
    grads = jax.lax.pmean(grads, axis_name='batch')
    state = state.apply_gradients(grads=grads)
    return loss, state, visibilities, images

In [14]:
"""
Define training parameters and run the training loop.
"""
log_period = 100       # Logging frequency to TensorBoard
save_period = 1000     # Saving checkpoints 

runname = 'sparsity'
estimate_axis_flag = [False, True]
measurement_types = ['full_image', 'ngEHT', 'EHT2017']

current_time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')

hparams = {
    'num_iters': 5000,
    'lr_init': 1e-4,
    'lr_final': 1e-6,
    'batchsize': 6,
}
orbit_args = {
    'tstart': obs_times[0], 
    'tstop': obs_times[-1],
    'velocity': velocity_field 
}

# Define supervision coordinates
t, x, y, z, d = bhnerf.network.get_input_coords(sensor, t_array=obs_times, batch='t').values()

# Grid for visualization (no interpolation factor is added for easy comparison) 
x_res_vis, y_res_vis, z_res_vis = spatial_res
t_vis, x_vis, y_vis, z_vis  = np.meshgrid(orbit_args['tstart'], 
                                          np.linspace(emission.x[0], emission.x[-1], x_res_vis),
                                          np.linspace(emission.y[0], emission.y[-1], y_res_vis),
                                          np.linspace(emission.z[0], emission.z[-1], z_res_vis),
                                          indexing='ij')

for estimate_axis in estimate_axis_flag:

    if estimate_axis:
        axis_str = 'estimated_axis'
        hparams['lr_axis'] = 1e-2
        orbit_args['axis_init'] = jnp.array([-0.4556, 0.5549, 0.6961])
    else:
        axis_str = 'true_axis'
        hparams['lr_axis'] = 0.0
        orbit_args['axis_init'] = rot_axis
    
    # Measurement model and loss/training setup
    predictor = bhnerf.network.NeRF3D_RotationAxis()
    emission_op = bhnerf.network.EmissionOperator(predictor, orbit_args)
    params = predictor.init(jax.random.PRNGKey(1), x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], **orbit_args)['params']
    
    # Split learning rate for axis / network parameters
    tx = optax.chain(optax.masked(optax.adam(learning_rate=hparams['lr_axis']), mask=flattened_traversal(lambda path, _: path[-1] == 'axis')),
                     optax.masked(optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters'])),
                                  mask=flattened_traversal(lambda path, _: path[-1] != 'axis')))
    
    for measurement_type in measurement_types:
        if measurement_type == 'full_image':
            emission_true = emission
            measurement_op = bhnerf.network.ImagePlaneOperator(emission_op)
            train_pstep = jax.pmap(train_step_image, axis_name='batch', in_axes=(None, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
        elif measurement_type in visibilities.keys():
            obs, measurements, sigma, ft_mats = visibilities[measurement_type].values()
            emission_true = emission_normalized
            measurement_op = bhnerf.network.VisibilityOperator(emission_op)
            train_pstep = jax.pmap(train_step_eht, axis_name='batch', in_axes=(None, 0, 0, 0, 0, 0, 0, 0, 0, 0), static_broadcasted_argnums=(0))
        else:
            raise AttributeError('Undefined measurement type: {}'.format(measurement_type))
    
        state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy
        full_runname = runname + '/{}.nspots{}.{}.{}'.format(measurement_type, len(rot_angles), axis_str, current_time)
        checkpoint_dir = 'checkpoints/{}'.format(full_runname)
        logdir = 'runs/{}'.format(full_runname)

        # Restore saved checkpoint
        if np.isscalar(save_period): state = checkpoints.restore_checkpoint(checkpoint_dir, state)
        init_step = 1 + state.step
        state = flax.jax_utils.replicate(state)

        # Training loop with TensorBoard logging 
        with SummaryWriter(logdir=logdir) as writer:
            writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission_true.isel(t=0)), global_step=0)

            for i in tqdm(range(init_step, init_step + hparams['num_iters']), desc='iteration'):
                batch_inds = np.random.choice(range(x.shape[0]), hparams['batchsize'], replace=False)
                
                if measurement_type == 'full_image':
                    loss, state, images = train_pstep(
                        loss_fn_image, state, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                    shard(t[batch_inds, ...]), shard(d[batch_inds, ...]), shard(image_plane[batch_inds, ...])
                )
                else:
                    loss, state, vis, images = train_pstep(
                        loss_fn_eht, state, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                        shard(t[batch_inds, ...]), shard(d[batch_inds, ...]), shard(measurements[batch_inds, ...]), 
                        shard(sigma[batch_inds, ...]), shard(ft_mats[batch_inds, ...])
                    )


                # Log the current state on TensorBoard
                writer.add_scalar('log_loss/train', np.log10(np.mean(loss)), global_step=i)
                if (i == 1) or (i % log_period) == 0:
                    current_state = jax.device_get(flax.jax_utils.unreplicate(state))
                    emission_grid = emission_op(current_state.params, x_vis, y_vis, z_vis, t_vis)[0]
                    emission_grid = bhnerf.emission.zero_unsupervised_emission(
                        emission_grid, x_vis[0], y_vis[0], z_vis[0], sensor.r_min, sensor.r_max)
                    rot_axis_estimate = bhnerf.utils.normalize(current_state.params['axis'])
                    writer.add_image('image_plane/true', image_plane[batch_inds[0], None, ...], global_step=i)
                    writer.add_image('image_plane/estimate', images[0, 0, None, :, :], global_step=i)
                    writer.add_images('emission/estimate', bhnerf.utils.intensity_to_nchw(emission_grid), global_step=i)
                    writer.add_scalar('emission/mse', bhnerf.utils.mse(emission_true.data[0], emission_grid), global_step=i)
                    writer.add_scalar('emission/psnr', bhnerf.utils.psnr(emission_true.data[0], emission_grid), global_step=i)
                    writer.add_scalar('rotation/dot_product', np.dot(rot_axis, rot_axis_estimate), global_step=i)

                # Save checkpoints occasionally
                if np.isscalar(save_period) and ((i % save_period == 0) or (i == hparams['num_iters'])):
                    if (save_period % log_period): current_state = jax.device_get(flax.jax_utils.unreplicate(state))
                    checkpoints.save_checkpoint(checkpoint_dir, current_state, int(i), keep=5)



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]

In [63]:
def train_network(sensor, emission_true, emission_test, velocity_field, obs_times, hparams, runname, 
                  log_period=100, save_period=10000, x_res_vis=64, y_res_vis=64, z_res_vis=64):
    
    bg_image = jnp.array(riaf_emission.imarr(), dtype=np.float32)
    tstart = obs_times[0]
    tstop = obs_times[-1]

    # Training / testing coordinates
    train_coords = network_utils.get_input_coords(sensor, t_array=obs_times, batch='t')
    t, x, y, z, d = train_coords.values()
    train_vis = image_plane

    test_coords = network_utils.get_input_coords(sensor, t_array=obs_times, batch='t')
    test_vis = image_plane_test
    t_test, x_test, y_test, z_test, d_test = test_coords.values()

    # Emission visualization inputs
    t_res_vis = hparams['batchsize']
    emission_extent = [emission_true.x[0], emission_true.x[-1], emission_true.y[0], emission_true.y[-1], emission_true.z[0], emission_true.z[-1]]
    t_vis, x_vis, y_vis, z_vis  = np.meshgrid(np.full(t_res_vis, fill_value=tstart), 
                                              np.linspace(emission_extent[0], emission_extent[1], x_res_vis),
                                              np.linspace(emission_extent[2], emission_extent[3], y_res_vis),
                                              np.linspace(emission_extent[4], emission_extent[5], x_res_vis),
                                              indexing='ij')

    d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
    ft_mats = np.zeros((nt, 1, 1))               # meaningless placeholder
    sigma = np.zeros((nt, 1))                    # meaningless placeholder
    sigma_test = np.zeros((nt_test, 1))          # meaningless placeholder 
    r_min = hparams['r_min']                     # Zero network output where there is no supervision (within black-hole radius)  
    r_max = hparams['r_max']
    
    # Model setup and initialization
    rng = jax.random.PRNGKey(1)
    predictor = network_utils.PREDICT_EMISSION_AND_MLP_ROTAXIS_3D_FROM_IMAGE_W_BG(
        posenc_deg=hparams['posenc_deg'],
        axis_net_depth=hparams['axis_net_depth']
    )
    params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field, tstart, tstop)['params']
    tx = optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters']))
    state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy
    
    if np.isscalar(save_period):
        checkpoint_dir = 'checkpoints/{}'.format(runname)
        state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    init_step = 1 + state.step
    train_pstep = jax.pmap(predictor.train_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None, 0, 0, None, None, 0, 0), static_broadcasted_argnums=(0))
    eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None, 0, 0, None, None, 0, 0), static_broadcasted_argnums=(0))
    rand_key = jax.random.split(rng, jax.local_device_count())
    state = flax.jax_utils.replicate(state)

    # TensorBoard logging
    time = datetime.now().strftime('%Y-%m-%d.%H:%M:%S')
    logdir = 'runs/{}.{}'.format(runname, time)
       
    with SummaryWriter(logdir=logdir) as writer:

        # Log ground-truth data   
        %matplotlib inline
        images_true = intensity_to_nchw(emission_true.isel(t=0))

        for i in tqdm(range(init_step, init_step+hparams['num_iters']), desc='iteration'):

            # Testing and Visualization
            if (i == 1) or (i % log_period) == 0:
                batch_inds = np.random.choice(range(x.shape[0]), hparams['batchsize'], replace=False)

                loss_test, _, _, rendering_test, _, _ = eval_pstep(
                    velocity_field, i, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                    shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(ft_mats[batch_inds, ...]),
                    tstart, tstop, shard(test_vis[batch_inds,...]),  shard(sigma_test[batch_inds, ...]), 
                    bg_image, window, state, rand_key
                )
                writer.add_scalar('log loss/test', np.log10(np.mean(loss_test)), global_step=i)
                
                # Log prediction and estimate
                _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
                    velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
                    shard(ft_mats[batch_inds, ...]), tstart, tstop, shard(test_vis[batch_inds,...]),
                    shard(sigma_test[batch_inds, ...]), bg_image, window, state, rand_key
                )
                axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
                emission_vis = np.reshape(emission_vis[0,0], [x_res_vis, y_res_vis, z_res_vis])
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
                emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
                
                
                # Log emission and rotation axis
                images = intensity_to_nchw(emission_vis[0])
                writer.add_images('emission/estimate', images, global_step=i)
                writer.add_images('emission/true', images_true, global_step=i)
                emission_mse = float(np.mean((emission.data[0] - emission_vis)**2))
                emission_psnr =  float(10.0 * np.log10(np.max(emission.data[0])**2 / emission_mse))
                writer.add_scalar('emission_mse', emission_mse, global_step=i)
                writer.add_scalar('emission_psnr', emission_psnr, global_step=i)
                writer.add_image('input_bg', bg_image[None, ...], global_step=i)
                writer.add_image('rendering/test', rendering_test[0, 0, None, :, :], global_step=i)
                writer.add_image('rendering/true', image_plane_test[batch_inds[0], None, ...], global_step=i)
                writer.add_image('rendering/diff', np.abs(image_plane_test[batch_inds[0], None, ...] - rendering_test[0, 0, None, :, :]), global_step=i)
                writer.add_scalar('rotation/dot_product', np.dot(rot_axis, axis_estimation), global_step=i)
                writer.add_scalar('rotation/x', axis_estimation[0], global_step=i)
                writer.add_scalar('rotation/y', axis_estimation[1], global_step=i)
                writer.add_scalar('rotation/z', axis_estimation[2], global_step=i)
                
            # Training
            batch_inds = np.random.choice(range(x.shape[0]), hparams['batchsize'], replace=False)
            loss_train, state, _, _, _, rand_key = train_pstep(
                velocity_field, i, shard(x[batch_inds, ...]), shard(y[batch_inds, ...]), shard(z[batch_inds, ...]), 
                shard(d[batch_inds, ...]), shard(t[batch_inds, ...]), shard(ft_mats[batch_inds, ...]), tstart, tstop,
                shard(train_vis[batch_inds, ...]), shard(sigma[batch_inds, ...]), bg_image, window, state, rand_key
            )
            writer.add_scalar('log loss/train', np.log10(np.mean(loss_train)), global_step=i)
    
            
            if np.isscalar(save_period) and ((i % save_period == 0) or (i == hparams['num_iters'])):
                state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(checkpoint_dir, state_to_save, int(i), keep=5)
            
    return state, emission_vis, axis_estimation

In [48]:
# Training params
hparams = {
    'num_iters': 5000,
    'lr_init': 1e-4,
    'lr_final': 1e-6,
    'posenc_deg': 3,
    'axis_net_depth': 3,
    'batchsize': 8,
    'r_min': r_min,
    'r_max': r_max,
}

obs_times = np.linspace(0.0, 1.0, nt)
runname = 'eht_arrays_40min/{}_nspots{}_phi{:2.2f}_theta{:2.2f}_riaf_flux{:2.1f}.norm_factor{:2.3f}'.format(
    'full_image', nspots, phi, theta, riaf_flux, normalization_factor)

state, emission_vis, rot_axis_est = train_network(
    sensor, emission, emission_test, velocity_field, obs_times, hparams, 
    runname=runname, log_period=100, save_period=1000, x_res_vis=64, y_res_vis=64, z_res_vis=64
)



iteration:   0%|          | 0/5000 [00:00<?, ?it/s]

In [124]:
rng = jax.random.PRNGKey(1)
predictor = network_utils.PREDICT_EMISSION_AND_MLP_ROTAXIS_3D_FROM_IMAGE_W_BG()
params = predictor.init(rng, x_vis[:1, ...], y_vis[:1, ...], z_vis[:1, ...], jnp.ones_like(x_vis)[:1, ...], velocity_field, 0.0, 1.0)['params']
tx = optax.adam(learning_rate=optax.polynomial_schedule(1, 1, 1, 1))
state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  

rng = jax.random.PRNGKey(1)
rand_key = jax.random.split(rng, jax.local_device_count())
eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None, 0, 0, None, None, 0, 0), static_broadcasted_argnums=(0))

ft_mats = np.zeros((nt, 1, 1))               # meaningless placeholder
sigma = np.zeros((nt, 1))                    # meaningless placeholder
sigma_test = np.zeros((nt_test, 1))          # meaningless placeholder 
batch_inds = 0
tstart = 0.0
tstop = 1.0


In [128]:

# anti-aliasing blurring
x = jnp.linspace(-sensor.num_alpha/2.0, sensor.num_alpha/2.0, x_res_vis)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
window = window / window.sum()
bg_image = np.ones((x_res_vis, y_res_vis))

poses = generate_elevated_orbit(25., 12.5, n_frames=60)
# poses = generate_elevated_orbit(25., 22.5, n_frames=60)

height = y_res_vis
width = x_res_vis
focal = .5 * width / jnp.tan(.5 * 0.7)
rays_o, rays_d = generate_rays(poses, width, height, focal)
pts = sample_along_rays(rays_o, rays_d, 15., 35., z_res_vis)

i = 5
x_vis = pts[i, :, :, :, 0]
y_vis = pts[i, :, :, :, 1]
z_vis = pts[i, :, :, :, 2]
d_vis = jnp.linalg.norm(jnp.concatenate([jnp.diff(pts[i:i+1, ...], axis=3),
                                         jnp.zeros_like(pts[i:i+1, :, :, -1:])], 
                                         axis=3), axis=-1)

images_gt = []
vmax = 0.01

# unknown rotation axis
checkpoint_paths = [
    'checkpoints/eht_arrays/full_image_nspots4.00_phi0.63_theta2.62_riaf.true_axis/checkpoint_5000',
]

rendering = []
for checkpoint_path in checkpoint_paths:
    print('loading checkpoint...')
    t0 = time.time()
    state = checkpoints.restore_checkpoint(checkpoint_path, state)
    state = flax.jax_utils.replicate(state)
    
    images = []
    for t in [50]:
        t_vis = (t/nt) * jnp.ones_like(x_vis)
        _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
            velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
            shard(ft_mats[batch_inds, ...]), tstart, tstop, shard(x_vis[batch_inds,...]),
            shard(sigma_test[batch_inds, ...]), bg_image, window, state, rand_key
        )
        axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
        emission_vis = np.reshape(emission_vis[0], [x_res_vis, y_res_vis, z_res_vis])
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = emission_vis / vmax
        print('loading time [secs]: {}'.format(time.time() - t0))
        print('rendering cube visualization...')
        t0 = time.time()
        emission_vis_cube = draw_cube(emission_vis, jnp.stack([x_vis, y_vis, z_vis], axis=-1))
        image, _ = alpha_composite(emission_vis_cube, d_vis, pts[i, ...])
        print('rendering time [secs]: {}'.format(time.time() - t0))
        image = jnp.clip(image, 0.0, 1.0)
        images.append(image)
    rendering.append(np.array(images))

loading checkpoint...
loading time [secs]: 27.596198320388794
rendering cube visualization...
rendering time [secs]: 386.38378977775574


In [130]:
%matplotlib widget
plt.figure()
plt.imshow(image)
plt.axis('off')
plt.tight_layout()
plt.savefig('figures/increasing_meas_sparsity1_t50.pdf', transparent=False, bbox_inches=0)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [39]:
# unknown rotation axis
checkpoint_paths = [
    'checkpoints/eht_arrays/full_image_nspots5_phi0.00_theta1.05_riaf.blurred/checkpoint_5000',
    'checkpoints/eht_arrays/ngEHT_nspots5_phi0.00_theta1.05_riaf/checkpoint_5000',
    'checkpoints/eht_arrays/EHT2017_nspots5_phi0.00_theta1.05_riaf/checkpoint_5000',
    'checkpoints/eht_arrays/full_image_nspots5_phi0.00_theta1.05_riaf.true_axis/checkpoint_5000',
    'checkpoints/eht_arrays/ngEHT_nspots5_phi0.00_theta1.05_riaf.true_axis/checkpoint_5000',
    'checkpoints/eht_arrays/EHT2017_nspots5_phi0.00_theta1.05_riaf.true_axis/checkpoint_5000'
]

rendering = []
for checkpoint_path in checkpoint_paths:
    print('loading checkpoint...')
    t0 = time.time()
    state = checkpoints.restore_checkpoint(checkpoint_path, state)
    state = flax.jax_utils.replicate(state)
    
    images = []
    for t in [0, 25]:
        t_vis = (t/nt) * jnp.ones_like(x_vis)
        _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
            velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
            shard(ft_mats[batch_inds, ...]), tstart, tstop, shard(x_vis[batch_inds,...]),
            shard(sigma_test[batch_inds, ...]), bg_image, window, state, rand_key
        )
        axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
        emission_vis = np.reshape(emission_vis[0], [x_res_vis, y_res_vis, z_res_vis])
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = emission_vis / vmax
        print('loading time [secs]: {}'.format(time.time() - t0))
        print('rendering cube visualization...')
        t0 = time.time()
        emission_vis_cube = draw_cube(emission_vis, jnp.stack([x_vis, y_vis, z_vis], axis=-1))
        image, _ = alpha_composite(emission_vis_cube, d_vis, pts[i, ...])
        print('rendering time [secs]: {}'.format(time.time() - t0))
        image = jnp.clip(image, 0.0, 1.0)
        images.append(image)
    rendering.append(np.array(images))

loading checkpoint...




loading time [secs]: 13.843168258666992
rendering cube visualization...
rendering time [secs]: 254.84529948234558
loading time [secs]: 266.7535593509674
rendering cube visualization...
rendering time [secs]: 254.46521711349487
loading checkpoint...
loading time [secs]: 13.161495685577393
rendering cube visualization...
rendering time [secs]: 257.35203075408936
loading time [secs]: 269.76148200035095
rendering cube visualization...
rendering time [secs]: 258.206862449646
loading checkpoint...
loading time [secs]: 13.281271934509277
rendering cube visualization...
rendering time [secs]: 256.8532314300537
loading time [secs]: 269.88253831863403
rendering cube visualization...
rendering time [secs]: 256.3417339324951
loading checkpoint...
loading time [secs]: 13.229191541671753
rendering cube visualization...
rendering time [secs]: 255.95570492744446
loading time [secs]: 267.8052384853363
rendering cube visualization...
rendering time [secs]: 253.07994198799133
loading checkpoint...
loadin

In [40]:
%matplotlib widget
fig, axes = plt.subplots(2,7, figsize=(18,6))
for i in range(2):
    axes[i,0].axis('off')
    axes[i,0].imshow(images_gt[i])
    for j in range(6):
        axes[i,j+1].axis('off')
        axes[i,j+1].imshow(rendering[j][i])
plt.tight_layout()
plt.savefig('figures/increasing_meas_sparsity2.pdf', transparent=False, bbox_inches=0)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [109]:
# unknown rotation axis
checkpoint_paths = [
    'checkpoints/eht_arrays/full_image_nspots4.00_phi0.63_theta2.62_riaf.blurred/checkpoint_5000',
    'checkpoints/eht_arrays/ngEHT_nspots4.00_phi0.63_theta2.62_riaf/checkpoint_5000',
    'checkpoints/eht_arrays/EHT2017_nspots4.00_phi0.63_theta2.62_riaf/checkpoint_5000'
]

rendering = []
for checkpoint_path in checkpoint_paths:
    print('loading checkpoint...')
    t0 = time.time()
    state = checkpoints.restore_checkpoint(checkpoint_path, state)
    state = flax.jax_utils.replicate(state)
    
    images = []
    for t in [0, 25]:
        t_vis = (t/nt) * jnp.ones_like(x_vis)
        _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
            velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
            shard(ft_mats[batch_inds, ...]), tstart, tstop, shard(x_vis[batch_inds,...]),
            shard(sigma_test[batch_inds, ...]), bg_image, state, rand_key
        )
        axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
        emission_vis = np.reshape(emission_vis[0], [x_res_vis, y_res_vis, z_res_vis])
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = emission_vis / vmax
        print('loading time [secs]: {}'.format(time.time() - t0))
        print('rendering cube visualization...')
        t0 = time.time()
        emission_vis_cube = draw_cube(emission_vis, jnp.stack([x_vis, y_vis, z_vis], axis=-1))
        image, _ = alpha_composite(emission_vis_cube, d_vis, pts[i, ...])
        print('rendering time [secs]: {}'.format(time.time() - t0))
        image = jnp.clip(image, 0.0, 1.0)
        images.append(image)
    rendering.append(np.array(images))

loading checkpoint...
loading time [secs]: 13.291164875030518
rendering cube visualization...
rendering time [secs]: 254.44943928718567
loading time [secs]: 266.0489263534546
rendering cube visualization...
rendering time [secs]: 254.58410501480103
loading checkpoint...
loading time [secs]: 13.33150839805603
rendering cube visualization...
rendering time [secs]: 254.0204586982727
loading time [secs]: 265.65969920158386
rendering cube visualization...
rendering time [secs]: 255.10928678512573
loading checkpoint...
loading time [secs]: 12.956752300262451
rendering cube visualization...
rendering time [secs]: 255.31900429725647
loading time [secs]: 267.2135634422302
rendering cube visualization...
rendering time [secs]: 254.40299558639526


In [119]:
rng = jax.random.PRNGKey(1)
predictor = network_utils.PREDICT_EMISSION_AND_MLP_ROTAXIS_3D_FROM_IMAGE_W_BG()
params = predictor.init(rng, x_vis[:1, ...], y_vis[:1, ...], z_vis[:1, ...], jnp.ones_like(x_vis)[:1, ...], velocity_field, 0.0, 1.0)['params']
tx = optax.adam(learning_rate=optax.polynomial_schedule(1, 1, 1, 1))
state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  

rng = jax.random.PRNGKey(1)
rand_key = jax.random.split(rng, jax.local_device_count())
eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None, 0, 0, None, 0, 0), static_broadcasted_argnums=(0))

ft_mats = np.zeros((nt, 1, 1))               # meaningless placeholder
sigma = np.zeros((nt, 1))                    # meaningless placeholder
sigma_test = np.zeros((nt_test, 1))          # meaningless placeholder 
batch_inds = 0
tstart = 0.0
tstop = 1.0
bg_image = np.ones((x_res_vis, y_res_vis))

In [120]:
# unknown rotation axis
checkpoint_paths = [
    'checkpoints/eht_arrays/full_image_nspots4.00_phi0.63_theta2.62_riaf.true_axis/checkpoint_5000',
    'checkpoints/eht_arrays/ngEHT_nspots4.00_phi0.63_theta2.62_riaf.true_axis/checkpoint_5000',
    'checkpoints/eht_arrays/EHT2017_nspots4.00_phi0.63_theta2.62_riaf.true_axis/checkpoint_5000'
]

for checkpoint_path in checkpoint_paths:
    print('loading checkpoint...')
    t0 = time.time()
    state = checkpoints.restore_checkpoint(checkpoint_path, state)
    state = flax.jax_utils.replicate(state)
    
    images = []
    for t in [0, 25]:
        t_vis = (t/nt) * jnp.ones_like(x_vis)
        _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
            velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
            shard(ft_mats[batch_inds, ...]), tstart, tstop, shard(x_vis[batch_inds,...]),
            shard(sigma_test[batch_inds, ...]), bg_image, state, rand_key
        )
        axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
        emission_vis = np.reshape(emission_vis[0], [x_res_vis, y_res_vis, z_res_vis])
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = emission_vis / vmax
        print('loading time [secs]: {}'.format(time.time() - t0))
        print('rendering cube visualization...')
        t0 = time.time()
        emission_vis_cube = draw_cube(emission_vis, jnp.stack([x_vis, y_vis, z_vis], axis=-1))
        image, _ = alpha_composite(emission_vis_cube, d_vis, pts[i, ...])
        print('rendering time [secs]: {}'.format(time.time() - t0))
        image = jnp.clip(image, 0.0, 1.0)
        images.append(image)
    rendering.append(np.array(images))

loading checkpoint...




loading time [secs]: 13.335416316986084
rendering cube visualization...
rendering time [secs]: 254.50141215324402
loading time [secs]: 266.46852016448975
rendering cube visualization...
rendering time [secs]: 255.80104041099548
loading checkpoint...
loading time [secs]: 12.318540096282959
rendering cube visualization...
rendering time [secs]: 254.04741740226746
loading time [secs]: 265.38665795326233
rendering cube visualization...
rendering time [secs]: 255.44973230361938
loading checkpoint...
loading time [secs]: 12.54409384727478
rendering cube visualization...
rendering time [secs]: 255.49767804145813
loading time [secs]: 267.72049474716187
rendering cube visualization...
rendering time [secs]: 253.87019658088684


In [124]:
%matplotlib widget
fig, axes = plt.subplots(2,7, figsize=(18,6))
for i in range(2):
    axes[i,0].axis('off')
    axes[i,0].imshow(images_gt[i])
    for j in range(6):
        axes[i,j+1].axis('off')
        axes[i,j+1].imshow(rendering[j][i])
plt.tight_layout()
plt.savefig('figures/increasing_meas_sparsity.pdf', transparent=False, bbox_inches=0)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [116]:
%matplotlib widget
fig, axes = plt.subplots(2,4, figsize=(12,6))
for i in range(2):
    axes[i,0].axis('off')
    axes[i,0].imshow(images_gt[i])
    for j in range(3):
        axes[i,j+1].axis('off')
        axes[i,j+1].imshow(rendering[j][i])
plt.tight_layout()
plt.savefig('figures/increasing_meas_sparsity1.pdf', transparent=False, bbox_inches=0)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [88]:
plt.figure()
plt.imshow(rendering[0])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7fc5dc3a3370>

In [82]:
emission_estimate = emission_utils.generate_orbit_3d(
    xr.DataArray(emission_vis, coords={'x': np.linspace(-5,5.0, x_res_vis), 
                                       'y': np.linspace(-5,5.0, y_res_vis), 
                                       'z': np.linspace(-5,5.0, z_res_vis)}, dims=['x','y','z']), 
    nt, velocity_field, axis_estimation)

In [83]:
extent = [(float(emission_estimate[dim].min()), float(emission_estimate[dim].max())) for dim in ('x', 'y', 'z')]
@interact(t=widgets.IntSlider(min=0, max=emission_estimate.t.size-1, step=1, value=0))
def plot_vol(t):
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(emission_estimate.isel(t=t), extent=extent, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=127), Output()), _dom_classes=('widget-interact'…

In [35]:
checkpoint_paths = [
    'checkpoints/eht_arrays/ngEHT_nspots4.00_phi0.63_theta2.62_riaf/checkpoint_5000',
]

rendering = []
for checkpoint_path in checkpoint_paths:
    print('loading checkpoint...')
    t0 = time.time()
    state = checkpoints.restore_checkpoint(checkpoint_path, state)
    state = flax.jax_utils.replicate(state)
    
    for t in :
        t_vis = (t/nt) * jnp.ones_like(x_vis)
        _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
            velocity_field, i, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
            shard(ft_mats[batch_inds, ...]), tstart, tstop, shard(test_vis[batch_inds,...]),
            shard(sigma_test[batch_inds, ...]), bg_image, state, rand_key
        )
        axis_estimation = axis_estimation[0] / np.sqrt(np.dot(axis_estimation[0], axis_estimation[0]))
        emission_vis = np.reshape(emission_vis[0,0], [x_res_vis, y_res_vis, z_res_vis])
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 < r_min**2, jnp.zeros_like(emission_vis), emission_vis)
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 > r_max**2, jnp.zeros_like(emission_vis), emission_vis)

                
                
        _, _, emission_vis, rendering_vis, axis_est, _ = eval_pstep(
            velocity_field, 0, 
            shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis), shard(np.ones_like(x_vis)), 
            100, 0, 1, shard(test_vis), shard(jnp.zeros((1, y_res_vis, x_res_vis))), shard(jnp.ones_like(test_vis)), 
            jnp.zeros((y_res_vis, x_res_vis)),
            state, rand_key
        )

        emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 >= hparams['r_min']**2, emission_vis, jnp.zeros_like(emission_vis))
        emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 <= sensor.r.max().data**2, emission_vis, jnp.zeros_like(emission_vis))
        emission_vis = emission_vis / vmax
        print('loading time [secs]: {}'.format(time.time() - t0))

        print('rendering cube visualization...')
        t0 = time.time()
        emission_vis_cube = draw_cube(emission_vis[0], jnp.stack([x_vis[0], y_vis[0], z_vis[0]], axis=-1))
        image, _ = alpha_composite(emission_vis_cube, d_vis, pts[i, ...])
        print('rendering time [secs]: {}'.format(time.time() - t0))
        image = jnp.clip(image, 0.0, 1.0)
        rendering.append(image)

loading checkpoint...


NameError: name 'batch_inds' is not defined

In [12]:
%matplotlib widget
fig, axes = plt.subplots(1,2, figsize=(8,4))
for i, ax in enumerate(axes):
    ax.axis('off')
    ax.imshow(images_gt[i])
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …