# Estimating hotspot(s) emission and rotational axis from visibilities.

Notes:
Innermost stable circular orbit (ISCO): for spin=0 with r_g=2, this is at 3M \
Overleaf notes: https://www.overleaf.com/project/60ff0ece5aa4f90d07f2a417

In [20]:
import sys
sys.path.append('../bhnerf')

import os

import jax
from jax import random
from jax import numpy as jnp
import jax.scipy.ndimage as jnd
import scipy.ndimage as nd

os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import xarray as xr
import flax
from flax.training import train_state
from flax.training import checkpoints
import optax
import numpy as np
import matplotlib.pyplot as plt

import utils, emission_utils, visualization, network_utils, observation_utils
from network_utils import shard

import ehtim as eh
import ehtim.const_def as ehc
from tensorboardX import SummaryWriter
from datetime import datetime
from tqdm.notebook import tqdm
import ipyvolume as ipv
from ipywidgets import interact
import ipywidgets as widgets

import mediapy as media
from sklearn.gaussian_process.kernels import RBF
from sklearn.gaussian_process import GaussianProcessRegressor as GPR

# from jax.config import config
# config.update("jax_debug_nans", True)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Perspective camera functions

def normalize(x):
    """Return `x` divided by its Euclidean norm (a zero-length input is not guarded)."""
    magnitude = np.linalg.norm(x)
    return x / magnitude

def viewmatrix(lookdir, up, position):
    """Build a 3x4 look-at camera-to-world matrix.

    Columns are, in order: the camera x-axis (up x lookdir), the re-orthogonalized
    camera y-axis, the camera z-axis (unit `lookdir`) and the camera `position`.
    """
    def _unit(vec):
        return vec / np.linalg.norm(vec)

    z_axis = _unit(lookdir)
    x_axis = _unit(np.cross(up, z_axis))
    # Recompute the up direction so the three axes are mutually orthogonal.
    y_axis = _unit(np.cross(z_axis, x_axis))
    return np.column_stack([x_axis, y_axis, z_axis, position])

def generate_hemispherical_orbit(radius, n_frames=120):
    """Return `n_frames` camera poses circling the z-axis in the plane z = 0.

    Each camera sits at distance `radius` from the origin. Its pose comes from
    `viewmatrix` with its own position as look direction and +z as the up hint.
    Angles span [0, 2*pi] inclusive, so the first and last poses coincide.
    Output shape is (n_frames, 3, 4).
    """
    world_up = np.array([0., 0., 1.])
    angles = np.linspace(0., 2. * np.pi, n_frames)
    origins = [radius * np.array([np.cos(a), np.sin(a), 0.]) for a in angles]
    return np.stack([viewmatrix(o, world_up, o) for o in origins], axis=0)

def generate_elevated_orbit(radius, height, n_frames=120):
    """Calculate a circular render path around the z-axis at a fixed elevation.

    Cameras lie on the circle z = `height` whose points are at distance `radius`
    from the origin. The circle's radius is sqrt(radius**2 - height**2). Each pose
    comes from `viewmatrix` with the camera position as look direction and +z as
    the up hint.

    Args:
      radius: distance of every camera from the origin.
      height: z coordinate of the orbit plane; must satisfy |height| < |radius|.
      n_frames: number of poses. Angles span [0, 2*pi] inclusive, so the first and
        last poses coincide.

    Returns:
      Array of shape (n_frames, 3, 4) of camera-to-world matrices.

    Raises:
      ValueError: if |height| >= |radius|. The orbit circle would be imaginary, or
        would collapse onto the z-axis where the look-at frame is undefined; both
        cases would otherwise produce NaN poses.
    """
    if height**2 >= radius**2:
        raise ValueError(
            f"height ({height}) must be smaller in magnitude than radius ({radius})")

    orbit_radius = np.sqrt(radius**2 - height**2)

    up = np.array([0., 0., 1.])
    render_poses = []
    for theta in np.linspace(0., 2. * np.pi, n_frames):
        camorigin = np.array(
            [orbit_radius * np.cos(theta), orbit_radius * np.sin(theta), height])
        render_poses.append(viewmatrix(camorigin, up, camorigin))

    return np.stack(render_poses, axis=0)

def generate_rays(camtoworlds, width, height, focal):
    """Generate one ray per pixel for every camera pose.

    Args:
      camtoworlds: array of shape (N, 3, 4). The left 3x3 block maps camera-frame
        directions to world frame; the last column is the camera origin.
      width, height: image size in pixels.
      focal: focal length in pixels.

    Returns:
      (origins, directions), each of shape (N, height, width, 3). Directions are
      not normalized. In the camera frame they have z = -1, and image row 0
      maps to +y.
    """
    cols = np.arange(width, dtype=np.float32)
    rows = np.arange(height, dtype=np.float32)
    px, py = np.meshgrid(cols, rows, indexing='xy')

    # Pixel centres relative to the optical axis, divided by the focal length.
    dir_x = (px - width * 0.5 + 0.5) / focal
    dir_y = -(py - height * 0.5 + 0.5) / focal
    dir_z = -np.ones_like(px)
    cam_dirs = np.stack([dir_x, dir_y, dir_z], axis=-1)

    # Rotate every camera-frame direction by each pose's 3x3 rotation block.
    rotations = camtoworlds[:, None, None, :3, :3]
    world_dirs = (cam_dirs[None, ..., None, :] * rotations).sum(axis=-1)

    cam_origins = camtoworlds[:, None, None, :3, -1]
    return np.broadcast_to(cam_origins, world_dirs.shape), world_dirs

def sample_along_rays(rays_o, rays_d, near, far, num_samples):
    """Sample `num_samples` evenly spaced points on each ray between `near` and `far`.

    Points are o + t * d with t = linspace(near, far, num_samples). `rays_d` is
    not normalized here, so t is measured in units of |d|.

    Args:
      rays_o: ray origins, shape (..., 3).
      rays_d: ray directions, shape (..., 3), broadcastable against `rays_o`.
      near, far: first and last ray parameter t.
      num_samples: number of samples per ray.

    Returns:
      Array of shape (..., num_samples, 3). Any number of leading batch
      dimensions is accepted, e.g. (N, H, W) or a flat (R,).
    """
    t_vals = jnp.linspace(near, far, num_samples)
    return rays_o[..., None, :] + t_vals[:, None] * rays_d[..., None, :]

def draw_cube(emission, pts, facewidth=10., linewidth=0.15):
    """Colour-map an emission volume and overlay the wireframe of an origin-centred cube.

    Args:
      emission: scalar emission sampled at `pts` (shape `pts.shape[:-1]`). It is
        passed through matplotlib's 'hot' colormap, so values are presumably
        expected in [0, 1].
      pts: sample coordinates, shape (..., 3).
      facewidth: edge length of the axis-aligned cube.
      linewidth: controls the exp(-dist / linewidth**2) fall-off of the edge lines.

    Returns:
      RGB volume of shape `pts.shape[:-1] + (3,)`. Samples whose max-norm exceeds
      facewidth / 2 + linewidth are set to zero.
    """
    half = facewidth / 2.
    linecolor = jnp.array([1000.0, 1000.0, 1000.0])
    # The 8 cube corners (x varies fastest, then y, then z).
    corners = np.array([[cx, cy, cz]
                        for cz in (-half, half)
                        for cy in (-half, half)
                        for cx in (-half, half)])

    cm = plt.get_cmap('hot')
    rgb = cm(emission)[..., :3]
    rgb = jnp.clip(rgb - 0.05, 0.0, 1.0)

    samples = np.linspace(0.0, facewidth, 64)
    for corner in corners:
        for axis in range(3):
            # Trace each of the 12 edges exactly once, from the corner whose
            # coordinate along `axis` is -half. Other directions would leave the
            # cube (those samples are masked off below) or re-trace an edge.
            if corner[axis] > 0:
                continue
            step = np.zeros(3)
            step[axis] = 1.0
            for p in corner + samples[:, None] * step:
                dists = jnp.linalg.norm(pts - p, axis=-1)
                # Weight 2 reproduces the brightness of tracing every edge from
                # both endpoints (the two sample sets coincide).
                rgb += 2.0 * linecolor * jnp.exp(-1. * dists / linewidth ** 2)[..., None]

    outside = jnp.max(jnp.abs(pts), axis=-1, keepdims=True) > half + linewidth
    return jnp.where(outside, 0.0, rgb)

In [22]:
# Load the sensor dataset and keep only samples with r < 5. Masked samples become
# NaN, so the minimum below ignores them; they are then zeroed.
sensor = xr.load_dataset('../sensors/a0.00_th1.57_ngeo100_npix4096.nc').pipe(lambda ds: ds.where(ds.r < 5))
r_min = sensor.r.min().data   # Minimum supervision radius
sensor = sensor.fillna(0.0)

In [23]:
phi = 0.0                # azimuth angle (ccw from x-axis)
theta = np.pi/3          # zenith angle (pi/2 = equatorial plane)
rot_axis = np.array([np.cos(theta)*np.cos(phi), np.cos(theta)*np.sin(phi), -np.sin(theta)])
# Angular velocity ~ r**(-3/2), scaled so that velocity_field(3.5) == 1.
orbit_period = 3.5**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

# The phi/theta part of the run name is derived from the angles above, so the
# restored checkpoint cannot silently disagree with them. The amp / corrlen /
# random_state tags are still hard-coded to match the GP-noise settings used in
# the next cell.
checkpoint_path = ('checkpoints/v_noise/ngEHT_amp0.35_corrlen0.01_random_state0_'
                   f'phi{phi:.2f}_theta{theta:.2f}/checkpoint_5000')


In [24]:
# Gaussian-process perturbation applied on top of the Keplerian velocity profile.
velocity_std = 0.35
length_scale = 0.01
random_state = 0
radii = np.linspace(1, 10, 100)

# Ground-truth hotspot volume geometry.
nt, nx, ny, nz = 64, 128, 128, 128
nspots = 1
r_isco = 3.0
orbit_radius = 3.5
std = np.ones_like(orbit_radius) * .4
orbit_period = orbit_radius**(-3./2.)
velocity_field = lambda r: (1.0 / orbit_period) * r**(-3/2)

phi = 0.0
theta = np.pi / 3
rot_axis = np.array([np.cos(theta) * np.cos(phi), np.cos(theta) * np.sin(phi), -np.sin(theta)])
initial_frame = emission_utils.generate_hotspots_3d(
    nx, ny, nz, theta, phi, orbit_radius, std, r_isco, std_clip=np.inf)

# One prior sample of an RBF Gaussian process on `radii`, linearly interpolated
# and added to the Keplerian velocity.
gpr = GPR(kernel=RBF(length_scale=length_scale))
gp_noise = np.squeeze(
    velocity_std * gpr.sample_y(radii[:, None], n_samples=1, random_state=random_state))
noisy_velocity_field = lambda r: velocity_field(r) + np.interp(r, radii, gp_noise)
emission = emission_utils.generate_orbit_3d(initial_frame, nt, noisy_velocity_field, rot_axis)

In [25]:
nt = 64
nx = 64
ny = 64
nz = 64

# Resolution of the emission visualization grid along x, y and z.
x_res_vis = 128
y_res_vis = 128
z_res_vis = 128

hparams = {
    'num_iters': 50000,
    'lr_init': 1e-4,
    'lr_final': 1e-6,
    'lr_axis': 1e-1,
    'posenc_deg': 3,
    'batchsize': 1,
    'r_min': r_min
}

# Training / testing coordinates
train_coords = network_utils.get_input_coords(sensor, t_array=np.linspace(0, 1, nt), batch='t')
t, x, y, z, d = train_coords.values()
# NOTE(review): this replaces the `t` unpacked just above with a plain 1-D time
# grid (it feeds `t[:1, ...]` in predictor.init below). Confirm this is intended.
t = np.linspace(0, 1, nt)

# Emission visualization inputs: one time slice on a regular grid over emission_extent.
t_res_vis = 1
emission_extent = [-5, 5, -5, 5, -5, 5]
t_vis, x_vis, y_vis, z_vis = np.meshgrid(0, np.linspace(emission_extent[0], emission_extent[1], x_res_vis),
                                         np.linspace(emission_extent[2], emission_extent[3], y_res_vis),
                                         np.linspace(emission_extent[4], emission_extent[5], z_res_vis),
                                         indexing='ij')

d_vis = np.ones_like(y_vis)                  # meaningless placeholder for emission visualization
# NOTE(review): 171 is presumably the number of visibility samples expected by the model. Confirm.
uv = np.zeros((1,171,2))
test_vis = np.zeros((1,171))

# Model setup and initialization
rng = jax.random.PRNGKey(1)
predictor = network_utils.PREDICT_EMISSION_AND_MLP_ROTAXIS_3D_FROM_VIS_W_BG_PRATUL(posenc_deg=hparams['posenc_deg'])
params = predictor.init(rng, x[:1, ...], y[:1, ...], z[:1, ...], t[:1, ...], velocity_field, 0, 1)['params']

def flattened_traversal(fn):
    """Turn a per-leaf predicate into a mask function for `optax.masked`.

    The returned `mask(data)` calls `fn(path, leaf)` for every leaf of the nested
    dict `data`. `path` is the tuple of keys leading to the leaf. The results are
    returned with the same nesting as `data`.
    """
    def mask(data):
        leaves = flax.traverse_util.flatten_dict(data)
        decisions = {path: fn(path, leaf) for path, leaf in leaves.items()}
        return flax.traverse_util.unflatten_dict(decisions)

    return mask

# Two Adam optimizers on disjoint parameter subsets:
# - leaves named 'axis' get the fixed learning rate lr_axis;
# - all others follow a linear (power=1) decay from lr_init to lr_final.
tx = optax.chain(
    optax.masked(optax.adam(learning_rate=hparams['lr_axis']),
                 mask=flattened_traversal(lambda path, _: path[-1] == 'axis')),
    optax.masked(optax.adam(learning_rate=optax.polynomial_schedule(hparams['lr_init'], hparams['lr_final'], 1, hparams['num_iters'])),
                 mask=flattened_traversal(lambda path, _: path[-1] != 'axis')),
)

# `flax.core.unfreeze` accepts both a FrozenDict (older flax `init` output) and a
# plain dict (newer flax). The `.unfreeze()` method only exists on FrozenDict.
state = train_state.TrainState.create(apply_fn=predictor.apply, params=flax.core.unfreeze(params), tx=tx)
state = checkpoints.restore_checkpoint(checkpoint_path, state)
state = flax.jax_utils.replicate(state)

# Per-device evaluation step; in_axes=0 marks arguments that are sharded/replicated across devices.
eval_pstep = jax.pmap(predictor.eval_step, axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0, None, 0, 0), static_broadcasted_argnums=(0))
rng = jax.random.PRNGKey(1)
rand_key = jax.random.split(rng, jax.local_device_count())



In [42]:
# poses = generate_hemispherical_orbit(25., n_frames=60)
poses = generate_elevated_orbit(25., 20., n_frames=60)
height = y_res_vis
width = x_res_vis
focal = .5 * width / jnp.tan(.5 * 0.7)
rays_o, rays_d = generate_rays(poses, width, height, focal)
pts = sample_along_rays(rays_o, rays_d, 15., 35., z_res_vis)

In [43]:
# Preview: render camera pose `i` of the network's emission estimate at time `t`.
i = 0
t = 0.0
x_vis = pts[i:i+1, :, :, :, 0]
y_vis = pts[i:i+1, :, :, :, 1]
z_vis = pts[i:i+1, :, :, :, 2]
# Distance between consecutive samples along each ray (0 for the last sample).
d_vis = jnp.linalg.norm(jnp.concatenate([jnp.diff(pts[i:i+1, ...], axis=3),
                                         jnp.zeros_like(pts[i:i+1, :, :, -1:])],
                                        axis=3), axis=-1)

t_vis = t * jnp.ones_like(x_vis)

_, _, emission_vis, rendering_vis, axis_est, _ = eval_pstep(
    velocity_field, 0,
    shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis), shard(uv),
    100, 0, 1,
    shard(test_vis), shard(jnp.zeros((1, y_res_vis, x_res_vis))), shard(jnp.ones_like(test_vis)),
    jnp.zeros((y_res_vis, x_res_vis)),
    state, rand_key
)
axis_est = axis_est[0] / np.sqrt(np.dot(axis_est[0], axis_est[0]))
emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
# Keep only emission between the minimum supervision radius and the sensor's outer radius.
emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 >= hparams['r_min']**2, emission_vis, jnp.zeros_like(emission_vis))
emission_vis = jnp.where(x_vis**2 + y_vis**2 + z_vis**2 <= sensor.r.max().data**2, emission_vis, jnp.zeros_like(emission_vis))
# Scale to a peak of 1 for the colormap in draw_cube. An all-zero (fully masked)
# volume is left as is instead of becoming NaN from 0 / 0.
peak = emission_vis.max()
emission_vis = emission_vis / jnp.where(peak > 0, peak, 1.0)
emission_vis_cube = draw_cube(emission_vis[0], jnp.stack([x_vis[0], y_vis[0], z_vis[0]], axis=-1))

# Emission-only ray integration: sum of colour times step length along each ray.
rendering = jnp.sum(emission_vis_cube * d_vis[0, ..., None], axis=-2)
rendering = jnp.clip(rendering, 0.0, 1.0)

In [41]:
%matplotlib widget
# Display the current `rendering` image (set by whichever render cell ran last).
plt.figure()
plt.imshow(rendering)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7fae731afc10>

In [33]:
# Reference: render camera pose `i` from the ground-truth emission volume at time `t`.
i = 0
t = 0.
x_vis, y_vis, z_vis = (pts[i:i+1, :, :, :, axis] for axis in range(3))
# Distance between consecutive samples along each ray (0 for the last sample).
d_vis = jnp.linalg.norm(
    jnp.concatenate([jnp.diff(pts[i:i+1, ...], axis=3), jnp.zeros_like(pts[i:i+1, :, :, -1:])], axis=3),
    axis=-1,
)

# Nearest time slice, then interpolate the volume at every ray sample.
emission_gt = emission.sel(t=t, method='nearest').interp(
    **{name: xr.DataArray(coord[0]) for name, coord in zip('xyz', (x_vis, y_vis, z_vis))}
).data
emission_vis_cube = draw_cube(emission_gt, jnp.stack([x_vis[0], y_vis[0], z_vis[0]], axis=-1))

rendering = jnp.clip(jnp.sum(emission_vis_cube * d_vis[0, ..., None], axis=-2), 0.0, 1.0)

In [34]:
%matplotlib widget
# Display the current `rendering` image (set by whichever render cell ran last).
plt.figure()
plt.imshow(rendering)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7fae73617220>

In [8]:
# Render the video frames in three phases:
#   1. orbit through every camera pose at t = 0;
#   2. hold the last pose while t sweeps 0 -> 0.5;
#   3. orbit again, starting at pose 30, while t sweeps 0.5 -> 1.
renderings = []


def _render_frame(pose_idx, time):
    """Render camera pose `pose_idx` of `pts` at time `time` with the trained network.

    Returns (rendering, axis_est). `rendering` is the (H, W, 3) image clipped to
    [0, 1]. `axis_est` is the rotation-axis estimate from the first device,
    scaled to unit length.
    """
    pose_pts = pts[pose_idx:pose_idx + 1]
    x_vis = pose_pts[..., 0]
    y_vis = pose_pts[..., 1]
    z_vis = pose_pts[..., 2]
    # Distance between consecutive samples along each ray (0 for the last sample).
    d_vis = jnp.linalg.norm(jnp.concatenate([jnp.diff(pose_pts, axis=3),
                                             jnp.zeros_like(pose_pts[:, :, :, -1:])],
                                            axis=3), axis=-1)
    t_vis = time * jnp.ones_like(x_vis)

    _, _, emission_vis, _, axis_est, _ = eval_pstep(
        velocity_field, 0,
        shard(x_vis), shard(y_vis), shard(z_vis), shard(d_vis), shard(t_vis), shard(uv),
        100, 0, 1,
        shard(test_vis), shard(jnp.zeros((1, y_res_vis, x_res_vis))), shard(jnp.ones_like(test_vis)),
        jnp.zeros((y_res_vis, x_res_vis)),
        state, rand_key
    )
    axis_est = axis_est[0] / np.sqrt(np.dot(axis_est[0], axis_est[0]))
    emission_vis = np.reshape(emission_vis, [t_res_vis, x_res_vis, y_res_vis, z_res_vis])
    # Keep only emission between the minimum supervision radius and the sensor's outer radius.
    r2 = x_vis**2 + y_vis**2 + z_vis**2
    emission_vis = jnp.where(r2 >= hparams['r_min']**2, emission_vis, jnp.zeros_like(emission_vis))
    emission_vis = jnp.where(r2 <= sensor.r.max().data**2, emission_vis, jnp.zeros_like(emission_vis))
    # NOTE(review): unlike the single-pose preview cell, emission is not divided by its
    # maximum before the colormap in draw_cube. Confirm intended.
    emission_vis_cube = draw_cube(emission_vis[0], jnp.stack([x_vis[0], y_vis[0], z_vis[0]], axis=-1))

    rendering = jnp.sum(emission_vis_cube * d_vis[0, ..., None], axis=-2)
    return jnp.clip(rendering, 0.0, 1.0), axis_est


num_poses = pts.shape[0]

# Phase 1: camera orbit at a fixed time.
t = 0.
for i in range(num_poses):
    print('rendering pose:', i, 'at time:', t)
    frame, axis_est = _render_frame(i, t)
    renderings.append(frame)

# Phase 2: hold the last pose (explicitly, rather than relying on the loop variable above).
i = num_poses - 1
for t in jnp.linspace(0.0, 0.5, num_poses):
    print(i, t)
    frame, axis_est = _render_frame(i, t)
    renderings.append(frame)

# Phase 3: orbit while time advances.
for i, t in enumerate(jnp.linspace(0.5, 1., num_poses)):
    # An offset of 30 alone would index past the last pose for the second half of
    # the sweep (empty slice -> crash), so wrap around the orbit.
    # NOTE(review): confirm this is the intended camera path.
    i = (i + 30) % num_poses
    print(i, t)
    frame, axis_est = _render_frame(i, t)
    renderings.append(frame)

rendering pose: 0 at time: 0.0
rendering pose: 1 at time: 0.0
rendering pose: 2 at time: 0.0
rendering pose: 3 at time: 0.0
rendering pose: 4 at time: 0.0
rendering pose: 5 at time: 0.0
rendering pose: 6 at time: 0.0
rendering pose: 7 at time: 0.0
rendering pose: 8 at time: 0.0
rendering pose: 9 at time: 0.0
rendering pose: 10 at time: 0.0
rendering pose: 11 at time: 0.0
rendering pose: 12 at time: 0.0
rendering pose: 13 at time: 0.0
rendering pose: 14 at time: 0.0
rendering pose: 15 at time: 0.0
rendering pose: 16 at time: 0.0
rendering pose: 17 at time: 0.0
rendering pose: 18 at time: 0.0
rendering pose: 19 at time: 0.0
rendering pose: 20 at time: 0.0
rendering pose: 21 at time: 0.0
rendering pose: 22 at time: 0.0
rendering pose: 23 at time: 0.0
rendering pose: 24 at time: 0.0
rendering pose: 25 at time: 0.0
rendering pose: 26 at time: 0.0
rendering pose: 27 at time: 0.0
rendering pose: 28 at time: 0.0
rendering pose: 29 at time: 0.0
rendering pose: 30 at time: 0.0
rendering pose: 31

KeyboardInterrupt: 

In [11]:
# Play the collected frames inline as an H.264 video at 30 fps, 400 px wide.
media.show_video(renderings, width=400, fps=30, codec='h264')

0
This browser does not support the video tag.


In [None]:
%matplotlib widget
# 3D comparison of the true rotation axis (default colour) and the network estimate (red).
ax = plt.figure(figsize=(7,7)).add_subplot(projection='3d')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.scatter(0, 0, 0, color='black', s=50)   # origin marker
ax.quiver(0, 0, 0, *rot_axis, label='true', linewidths=3, length=0.05)
ax.quiver(0, 0, 0, *axis_est, label='estimated', color='r', linewidths=3, length=0.05)
plt.legend(fontsize=14)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

NameError: name 'axis_est' is not defined