# Estimating hotspot(s) emission and rotational axis from visibilities

Notes:
Innermost stable circular orbit (ISCO): for spin=0 with r_g=2, this is at 3M \
Overleaf notes: https://www.overleaf.com/project/60ff0ece5aa4f90d07f2a417

In [1]:
import sys
sys.path.append('../bhnerf')

import os

import jax
from jax import random
from jax import numpy as jnp
import jax.scipy.ndimage as jnd
import scipy.ndimage as nd

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import xarray as xr
import flax
from flax.training import train_state
from flax.training import checkpoints
import optax
import numpy as np
import matplotlib.pyplot as plt

import utils, emission_utils, visualization, network_utils, observation_utils
from network_utils import shard

import ehtim as eh
import ehtim.const_def as ehc
from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

# from jax.config import config
# config.update("jax_debug_nans", True)
%load_ext autoreload
%autoreload 2

2021-09-24 15:43:01.499923: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs
Matplotlib created a temporary config/cache directory at /tmp/matplotlib-g3il66x0 because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



In [7]:
def load_state(checkpoint_path, velocity_field, nt=64, nx=64, ny=64, nz=64, x_res_vis=64, y_res_vis=64, z_res_vis=64):
    """Restore a trained emission/rotation-axis network and evaluate it on a 3D grid.

    Rebuilds the model and optimizer structure used for training and restores the
    checkpoint into it. Evaluates the network once on a regular grid spanning
    ``emission_extent``, zeroes the emission inside ``r_min``, and passes the result to
    ``emission_utils.generate_orbit_3d`` together with ``nt`` and the estimated axis.

    NOTE: this function reads the notebook globals ``sensor`` and ``r_min`` (defined in
    the sensor-loading cell), so that cell must be run before calling it.

    Args:
        checkpoint_path: path passed to ``flax.training.checkpoints.restore_checkpoint``.
        velocity_field: callable forwarded to the network and to ``generate_orbit_3d``.
            It is a static ``pmap`` argument, so it must be hashable.
        nt: number of time samples in [0, 1] used for the init inputs; also forwarded
            to ``generate_orbit_3d``.
        nx, ny, nz: unused; kept for backward compatibility. The output coordinates
            follow the evaluation grid set by ``x_res_vis``/``y_res_vis``/``z_res_vis``.
        x_res_vis, y_res_vis, z_res_vis: number of grid points along x, y and z.

    Returns:
        Tuple ``(emission_estimation, axis_estimation)``. The first element is the
        output of ``generate_orbit_3d``. The second is the estimated rotation axis
        from the first device, normalized to unit length.
    """
    hparams = {
        'num_iters': 50000,
        'lr_init': 1e-4,
        'lr_final': 1e-6,
        'lr_axis': 1e-1,
        'posenc_deg': 3,
        'batchsize': 1,
        'r_min': r_min
    }

    # Training coordinates: only a slice is used, as example inputs for `predictor.init`.
    train_coords = network_utils.get_input_coords(sensor, t_array=np.linspace(0, 1, nt), batch='t')
    t, x, y, z, _ = train_coords.values()
    t = np.linspace(0, 1, nt)

    # Emission visualization inputs: a single frame (t=0) sampled on a regular x/y/z grid.
    # The same 1D axes are reused below as the DataArray coordinates, so the two always agree.
    t_res_vis = 1
    emission_extent = [-5, 5, -5, 5, -5, 5]
    x_axis = np.linspace(emission_extent[0], emission_extent[1], x_res_vis)
    y_axis = np.linspace(emission_extent[2], emission_extent[3], y_res_vis)
    z_axis = np.linspace(emission_extent[4], emission_extent[5], z_res_vis)  # must match the reshape below
    t_vis, x_vis, y_vis, z_vis = np.meshgrid(0, x_axis, y_axis, z_axis, indexing='ij')

    d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
    # Dummy uv coordinates / visibilities: eval_step takes them as inputs, but only the
    # emission and axis outputs are used here.
    # NOTE(review): 171 presumably matches the number of baselines used in training — confirm.
    uv = np.zeros((1, 171, 2))
    test_vis = np.zeros((1, 171))

    # Model setup and initialization
    rng = jax.random.PRNGKey(1)
    predictor = network_utils.PREDICT_EMISSION_AND_ROTAXIS_3D_FROM_VIS(posenc_deg=hparams['posenc_deg'])
    params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field, 0, 1)['params']

    def flattened_traversal(fn):
        """Build an optax mask: apply ``fn(path, value)`` to every leaf of a nested param dict."""
        def mask(data):
            flat = flax.traverse_util.flatten_dict(data)
            return flax.traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
        return mask

    # Separate optimizers: a fixed learning rate for the 'axis' parameter and a
    # polynomially decaying one for everything else. This must mirror the training setup
    # so the restored optimizer state has the same structure.
    tx = optax.chain(
        optax.masked(optax.adam(learning_rate=hparams['lr_axis']),
                     mask=flattened_traversal(lambda path, _: path[-1] == 'axis')),
        optax.masked(optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters'])),
                     mask=flattened_traversal(lambda path, _: path[-1] != 'axis')),
    )

    state = train_state.TrainState.create(apply_fn=predictor.apply, params=params.unfreeze(), tx=tx)  # TODO(pratul): this unfreeze feels sketchy
    state = checkpoints.restore_checkpoint(checkpoint_path, state)
    state = flax.jax_utils.replicate(state)

    # in_axes: sharded per-device arguments map to 0; broadcast scalars/callables map to None.
    eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), static_broadcasted_argnums=(0))
    rng = jax.random.PRNGKey(1)
    rand_key = jax.random.split(rng, jax.local_device_count())
    # Only the emission (3rd) and axis (5th) outputs are needed here.
    _, _, emission_vis, _, axis_estimation, _ = eval_pstep(
        velocity_field, 1, shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis),
        shard(uv[0, ...]), 100, 0, 1, shard(test_vis[0, ...]), state, rand_key
    )
    # Take the first device's estimate and normalize it to unit length.
    axis_estimation = axis_estimation[0] / np.linalg.norm(axis_estimation[0])
    emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
    # Zero out emission inside the minimum supervision radius r_min.
    emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 >= hparams['r_min']**2, emission_vis, jnp.zeros_like(emission_vis))
    emission_estimation = emission_utils.generate_orbit_3d(
        xr.DataArray(emission_vis[0], dims=['x', 'y', 'z'],
                     coords={'x': x_axis, 'y': y_axis, 'z': z_axis}),
        nt, velocity_field, axis_estimation)

    return emission_estimation, axis_estimation

In [3]:
# Ray/sensor geometry used for supervision. Samples with r >= 5 are masked out
# (set to NaN by `where`).
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo100_npix4096.nc')
sensor = sensor.where(sensor['r'] < 5)
# Minimum supervision radius. It is taken before `fillna`, so masked entries
# are not counted as 0.
r_min = sensor['r'].min().data
# Replace the masked-out entries with zeros.
sensor = sensor.fillna(0.0)

In [16]:
# Ground-truth rotation axis, for comparison with the network's estimate.
phi = 0.0                # azimuth angle (ccw from x-axis)
theta = np.pi/3          # zenith angle (pi/2 = equatorial plane)
rot_axis = np.array([
    np.cos(theta) * np.cos(phi),
    np.cos(theta) * np.sin(phi),
    -np.sin(theta),
])

# The velocity profile scales as r**(-3/2) and is normalized to equal 1 at r = 3.5.
orbit_period = 3.5**(-3./2.)


def velocity_field(r):
    """Return the r**(-3/2) velocity profile at radius ``r``, scaled by 1/orbit_period."""
    return (1.0 / orbit_period) * r**(-3/2)


checkpoint_path = 'checkpoints/ngEHT/image_res64x64.fov100.0.ngeo100.nspots10.axis_init_(0,0,-1).nobias/checkpoint_50000'
emission_est, axis_est = load_state(checkpoint_path, velocity_field)



In [14]:
# (min, max) bounds of the estimated emission along each spatial axis, used as the volume extent.
extent = [(float(emission_est[axis].min()), float(emission_est[axis].max())) for axis in 'xyz']


@interact(t=widgets.IntSlider(min=0, max=emission_est.t.size-1, step=1, value=0))
def plot_vol(t):
    """Render time frame ``t`` of ``emission_est`` as an interactive ipyvolume volume."""
    frame = emission_est.isel(t=t)
    ipv.figure()
    ipv.view(0, -60, distance=2.5)
    ipv.volshow(frame, extent=extent, memorder='F',
                level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
    ipv.show()

interactive(children=(IntSlider(value=0, description='t', max=63), Output()), _dom_classes=('widget-interact',…

In [17]:
%matplotlib widget
ax = plt.figure(figsize=(7,7)).add_subplot(projection='3d')
ax.scatter(0,0,0, color='black', s=50)
ax.quiver(0,0,0,*rot_axis, length=0.05, linewidths=3,  label='true')
ax.quiver(0,0,0,*axis_est, length=0.05, linewidths=3, label='estimated', color='r')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.legend(fontsize=14)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f6de1f9d370>