# Tutorial 5: Visualize recovery results

---
This tutorial demonstrates how to visualize the 3D/4D recovery results. \
By sampling a trained neural network on a regular grid, we can visualize the recovered 3D emission.

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from flax.training import checkpoints
import os

# Expose only GPU index 3 to CUDA for this session (machine-specific; adjust for your setup).
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [59]:
"""
bhnerf methods
"""
from flax import linen as nn
from typing import Any, Callable
import jax, flax
from jax import numpy as jnp
from astropy import units
import functools

from astropy.constants import M_sun
# Mass of Sgr A*: 4.154e6 solar masses (astropy Quantity). Default black-hole mass for velocity_warp_coords.
sgra_mass = 4.154*10**6 * M_sun

def normalize(vector):
    """Return `vector` divided by its Euclidean norm (NumPy)."""
    squared_norm = np.dot(vector, vector)
    return vector / np.sqrt(squared_norm)


def safe_sin(x):
    """Sine of `x`, with the argument first wrapped modulo 100*pi (a multiple of the period)."""
    wrapped = x % (100 * jnp.pi)
    return jnp.sin(wrapped)

class NeRF_Predictor(nn.Module):
    """
    Flax module that predicts emission at ray sample points for given frame times.

    Sample coordinates are rotated back to the hotspot injection time with
    `velocity_warp_coords` (using the angular velocity `Omega`), positionally encoded with
    `posenc`, and fed to an `MLP`. Trainable parameters are the MLP weights and a scalar
    `t_injection`.
    
    Parameters
    ----------
    posenc_deg: int, default=3
    net_depth: int, default=4
    net_width: int, default=128
    activation: Callable[..., Any], default=nn.relu
    out_channel: int default=1
    do_skip: bool, default=True
    """
    posenc_deg: int = 3  # number of positional-encoding frequencies passed to posenc
    net_depth: int = 4  # number of hidden Dense layers in the MLP
    net_width: int = 128  # width of each hidden Dense layer
    activation: Callable[..., Any] = nn.relu  # hidden-layer activation
    out_channel: int = 1  # MLP output channels (only channel 0 is used below)
    do_skip: bool = True  # whether the MLP re-concatenates its inputs midway
    
    @nn.compact
    def __call__(self, t_frames, t_units, coords, Omega, t_start_obs, t_geos, t_injection):
        """
        Sample emission on given coordinates at specified times assuming a velocity model (Omega)
        
        Parameters
        ----------
        t_frames: array, 
            Array of time for each image frame
        t_units: astropy.units or None, 
            Time units of t_frames; forwarded to velocity_warp_coords.
        coords: list of arrays, 
            For 3D emission coords=[x, y, z] with each array shape=(nt, num_alpha, num_beta, ngeo)
            alpha, beta are image coordinates. These arrays contain the ray integration points
        Omega: array, 
            Angular velocity array sampled along the coords points
        t_start_obs: astropy.Quantity or float
            Start time for observations. This argument is required (no default) and is
            forwarded to velocity_warp_coords unchanged.
        t_geos: array, 
            Time along each geodesic (ray). This is used to account for slow light (light travels at finite velocity).
        t_injection: float, 
            Time of hotspot injection in M units. Only used to initialize the trainable
            `t_injection` parameter (flax `self.param`); once parameters exist, the stored
            value is used instead.
        
        Returns
        -------
        emission: jnp.array,
            Sigmoid emission in (0, 1) for each sample point, with the trailing coordinate
            axis of the warped coordinates removed. It is exactly 0 where the warped
            coordinates are non-finite (velocity_warp_coords returns NaN before injection).
        """
        # Positional args map to MLP fields in order: net_depth, net_width, activation, out_channel, do_skip.
        emission_MLP = MLP(self.net_depth, self.net_width, self.activation, self.out_channel, self.do_skip)
        def predict_emission(t_frames, t_units, coords, Omega, t_start_obs, t_geos, t_injection):
            # Rotate sample points back to their positions at the injection time.
            warped_coords = velocity_warp_coords(
                coords, Omega, t_frames, t_start_obs, t_geos, t_injection, t_units=t_units, use_jax=True
            )
            
            # Zero emission prior to injection time
            # (NaN coordinates are replaced by 0 before the network so they cannot propagate).
            valid_inputs_mask = jnp.isfinite(warped_coords)
            net_input = jnp.where(valid_inputs_mask, warped_coords, jnp.zeros_like(warped_coords))
            net_output = emission_MLP(posenc(net_input, self.posenc_deg))
            # The -10 offset maps a zero network output to sigmoid(-10) ~= 4.5e-5, i.e. near-zero
            # emission (presumably a prior toward empty space).
            emission = nn.sigmoid(net_output[..., 0] - 10.0)
            # Mask on the first coordinate component's finiteness.
            emission = jnp.where(valid_inputs_mask[..., 0], emission, jnp.zeros_like(emission))
            
            return emission
        
        # Scalar float32 trainable parameter, initialized from the `t_injection` argument.
        t_injection_param = self.param('t_injection', lambda key, values: jnp.array(values, dtype=jnp.float32), t_injection)
        emission = predict_emission(t_frames, t_units, coords, Omega, t_start_obs, t_geos, t_injection_param)
        return emission
    
class MLP(nn.Module):
    """
    Fully connected network with an optional input skip connection.

    Attributes
    ----------
    net_depth: int, default=4
        Number of hidden Dense layers.
    net_width: int, default=128
        Width of every hidden Dense layer.
    activation: Callable, default=nn.relu
        Activation applied after every hidden layer.
    out_channel: int, default=1
        Number of output channels of the final (linear) Dense layer.
    do_skip: bool, default=True
        If True, the raw inputs are concatenated to the hidden features every
        `net_depth // 2` layers (skipped entirely when net_depth < 2).
    """
    net_depth: int = 4
    net_width: int = 128
    activation: Callable[..., Any] = nn.relu
    out_channel: int = 1
    do_skip: bool = True
  
    @nn.compact
    def __call__(self, x):
        """
        Apply the Multi-Layer Perceptron (MLP).

        Parameters
        ----------
        x: jnp.ndarray(float32),
            Input features along the last axis, e.g. [batch_size * n_samples, feature].

        Returns
        -------
        out: jnp.ndarray(float32),
            Same leading shape as `x` with `out_channel` features on the last axis.
        """
        dense_layer = functools.partial(
            nn.Dense, kernel_init=jax.nn.initializers.he_uniform())

        # Skip interval; 0 disables the skip. Previously net_depth=1 with do_skip=True gave an
        # interval of 0 and `i % 0` raised ZeroDivisionError.
        skip_layer = self.net_depth // 2 if self.do_skip else 0

        inputs = x
        for i in range(self.net_depth):
            x = dense_layer(self.net_width)(x)
            x = self.activation(x)
            if skip_layer > 0 and i > 0 and i % skip_layer == 0:
                x = jnp.concatenate([x, inputs], axis=-1)
        # Final linear projection (no activation); layer creation order is unchanged so
        # parameter names stay compatible with existing checkpoints.
        out = dense_layer(self.out_channel)(x)

        return out
    
def velocity_warp_coords(coords, Omega, t_frames, t_start_obs, t_geos, t_injection, rot_axis=(0, 0, 1), M=sgra_mass, t_units=None, use_jax=False):
    """
    Generate a coordinate transform for the velocity warp.

    Every point is rotated about `rot_axis` by -Omega * t_M, where t_M is the time elapsed
    since `t_injection` (in units of GM/c^3). Points whose t_M is negative (before injection)
    are set to NaN.
    
    Parameters
    ----------
    coords: list of np arrays
        A list of arrays with grid coordinates [x, y, z].
    Omega: array, 
        Angular velocity array sampled along the coords points.
    t_frames: array, 
        Array of time for each image frame, optionally with astropy.units.
    t_start_obs: astropy.Quantity, float or None
        Start time for observations. If None, the first element of t_frames is used.
        If it is a Quantity, its unit overrides `t_units`.
    t_geos: array, 
        Time along each geodesic (ray). This is used to account for slow light (light travels at finite velocity).
    t_injection: float, 
        Time of hotspot injection in M units.
    rot_axis: array-like, default=(0, 0, 1)
        Currently only equatorial plane rotation is supported.
    M: astropy.Quantity, default=sgra_mass,
        Mass of the black hole used to convert frame times to space-time times in units of M.
    t_units: astropy.units, default=None,
        Time units. If None, units are taken from t_start_obs or t_frames (when those are
        Quantities); otherwise all times are assumed to already be in units of M.
    use_jax: bool, default=False,
        Using jax enables GPU accelerated computing.
        
    Returns
    -------
    warped_coords: array,
        The rotated coordinates, with the coordinate component axis moved to the end.
    """
    _np = jnp if use_jax else np
    coords = _np.array(coords)
    Omega = _np.array(Omega)

    # Resolve time units: t_start_obs's unit wins, then the explicit t_units, then t_frames' unit.
    if isinstance(t_start_obs, units.Quantity):
        t_units = t_start_obs.unit
        t_start_obs = t_start_obs.value
    if t_units is None and isinstance(t_frames, units.Quantity):
        t_units = t_frames.unit

    # GM/c^3 converts frame times into geometric units of M. The original referenced an
    # undefined `consts` module here (NameError), so compute it directly from astropy constants.
    GM_c3 = 1.0
    if t_units is not None:
        from astropy import constants as astro_consts
        GM_c3 = (astro_consts.G * M / astro_consts.c**3).to(t_units).value

    if isinstance(t_frames, units.Quantity):
        t_frames = t_frames.to(t_units).value
    t_frames = _np.array(t_frames)

    # Documented default: observations start at the first frame time.
    if t_start_obs is None:
        t_start_obs = _np.ravel(t_frames)[0]

    if (_np.isscalar(Omega) or Omega.ndim == 0):
        Omega = expand_dims(Omega, coords.ndim-1, axis=-1, use_jax=use_jax)

    # Extend the dimensions of `t_frames` and `coords` for an array of times
    if not (t_frames.ndim == 0):
        coords = expand_dims(coords, coords.ndim + t_frames.ndim, 1, use_jax)
        t_frames = expand_dims(t_frames, t_frames.ndim + Omega.ndim, -1, use_jax)

    # Convert frame times to M units and add the light-travel time along each geodesic.
    t_geos_total = (t_frames - t_start_obs)/GM_c3 + _np.array(t_geos)
    t_M = t_geos_total - t_injection

    # Insert nans for angles before the injection time
    theta_rot = _np.array(t_M * Omega)
    theta_rot = _np.where(t_M < 0.0, _np.full_like(theta_rot, fill_value=np.nan), theta_rot)

    inv_rot_matrix = rotation_matrix(rot_axis, -theta_rot, use_jax=use_jax)

    # Contract the matrix column axis with the coordinate component axis.
    warped_coords = _np.sum(inv_rot_matrix * coords, axis=1)
    warped_coords = _np.moveaxis(warped_coords, 0, -1)
    return warped_coords

def sample_3d_grid(apply_fn, params, rmin=0.0, rmax=np.inf, fov=None, coords=None, resolution=64):
    """
    Evaluate a trained emission network on 3D points and mask the result to a spherical shell.

    Parameters
    ----------
    apply_fn: callable
        The network's apply function (e.g. `NeRF_Predictor().apply`). It is called as
        apply_fn({'params': params}, t_frames, t_units, coords, Omega, t_start_obs, t_geos, t_injection).
    params: dict
        Network parameters (e.g. state.params), possibly replicated across devices.
    rmin: float, default=0
        Values at radii < rmin are set to zero.
    rmax: float, default=np.inf
        Values at radii > rmax are set to zero.
    fov: float, default=None
        Side length of the cubic sampling grid. Required when `coords` is None.
    coords: array(shape=(3, ...)), optional
        Grid coordinates (x, y, z). If omitted, a `resolution`^3 grid spanning [-fov/2, fov/2] is used.
    resolution: int, default=64
        Number of grid points along each of x, y, z (only used when `coords` is None).

    Returns
    -------
    emission: array
        Network output at `coords`, zeroed outside rmin <= r <= rmax.

    Raises
    ------
    AttributeError
        If neither `coords` nor `fov` is provided.
    """
    try:
        host_params = jax.device_get(flax.jax_utils.unreplicate(params))
    except IndexError:
        # NOTE(review): unreplicate indexes [0] into every leaf. An IndexError (e.g. from a 0-d
        # leaf) is taken to mean the params were never replicated. Confirm this for other models.
        host_params = jax.device_get(params)

    if coords is None:
        if fov is None:
            raise AttributeError('Either coords or fov+resolution must be provided')
        axis_1d = np.linspace(-fov / 2, fov / 2, resolution)
        coords = np.array(np.meshgrid(axis_1d, axis_1d, axis_1d, indexing='ij'))

    # All time-related arguments and Omega are 0.0, and t_units is None.
    emission = apply_fn({'params': host_params}, 0.0, None, coords, 0.0, 0.0, 0.0, 0.0)
    return fill_unsupervised_emission(emission, coords, rmin, rmax)

def fill_unsupervised_emission(emission, coords, rmin=0, rmax=np.inf, fill_value=0.0, use_jax=False):
    """
    Fill emission that is not within the supervision region (the shell rmin <= r <= rmax).
    
    Parameters
    ----------
    emission: np.array
        3D array with emission values.
    coords: list of np.arrays
        Spatial coordinate arrays (x, y, z), each shaped like emission after squeezing.
    rmin: float, default=0
        Values at radii < rmin are replaced by fill_value.
    rmax: float, default=np.inf
        Values at radii > rmax are replaced by fill_value.
    fill_value: float, default=0.0
        Value written outside the supervision region.
    use_jax: bool, default=False,
        Using jax enables GPU accelerated computing.
        
    Returns
    -------
    emission: np.array
        3D array with emission values filled in.
    """
    # `np.inf` rather than `np.Inf`: the capitalized alias was removed in NumPy 2.0, which made
    # evaluating this function's default argument fail at definition time.
    _np = jnp if use_jax else np
    r_sq = _np.sum(_np.array([_np.squeeze(x)**2 for x in coords]), axis=0)
    outside_shell = (r_sq < rmin**2) | (r_sq > rmax**2)
    return _np.where(outside_shell, _np.full_like(emission, fill_value=fill_value), emission)

def rotation_matrix(axis, angle, use_jax=False):
    """
    Build the matrix of a counterclockwise rotation by `angle` about `axis` (Euler-Rodrigues).

    Parameters
    ----------
    axis: list or np.array, dim=3
        Rotation axis; it is normalized internally.
    angle: float or array of floats
        Rotation angle(s) in radians.
    use_jax: bool, default=False
        Use jax.numpy instead of numpy.

    Returns
    -------
    rotation_matrix: array(shape=(3, 3, ...))
        The rotation matrix. For an array of angles, the angle dimensions trail the two
        matrix axes.

    References
    ----------
    [1] https://en.wikipedia.org/wiki/Euler%E2%80%93Rodrigues_formula
    [2] https://stackoverflow.com/questions/6802577/rotation-of-3d-vector
    """
    xp = jnp if use_jax else np

    unit_axis = xp.array(axis)
    unit_axis = unit_axis / xp.sqrt(xp.dot(unit_axis, unit_axis))

    # Euler-Rodrigues parameters: a = cos(theta/2), (b, c, d) = -axis * sin(theta/2).
    half_angle = angle / 2.0
    a = xp.cos(half_angle)
    sin_half = xp.sin(half_angle)
    b, c, d = xp.stack([-component * sin_half for component in unit_axis])

    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d

    row_x = [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)]
    row_y = [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)]
    row_z = [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]
    return xp.array([row_x, row_y, row_z])

def expand_dims(x, ndim, axis=0, use_jax=False):
    """
    Insert singleton axes at position `axis` until `x` has `ndim` dimensions.

    The insertion position is clamped to the current rank, so an `axis` past the end
    appends trailing axes. `x` is returned unchanged if it already has >= ndim dimensions.
    """
    xp = jnp if use_jax else np
    while xp.array(x).ndim < ndim:
        current_rank = xp.array(x).ndim
        x = xp.expand_dims(x, axis=min(axis, current_rank))
    return x

def posenc(x, deg):
    """
    Append a sinusoidal positional encoding of degree `deg` to the features of `x`.

    Each feature is scaled by 1, 2, ..., 2**(deg-1), and both sin and cos of every scaled
    value are appended. The cosines come from cos(t) = sin(t + pi/2), so a single vectorized
    sine call covers both.

    Parameters
    ----------
    x: jnp.ndarray
        Variables to encode along the last axis (expected to lie in [-pi, pi]).
    deg: int
        Number of frequency octaves; 0 returns `x` itself.

    Returns
    -------
    encoded: jnp.ndarray
        `x` concatenated with its 2 * deg * x.shape[-1] encoding features.
    """
    if deg == 0:
        return x
    scales = jnp.array([1 << k for k in range(deg)])
    # Scale every feature by every frequency, then flatten (frequency, feature) into one axis.
    scaled = x[..., None, :] * scales[:, None]
    xb = jnp.reshape(scaled, tuple(x.shape[:-1]) + (-1,))
    four_feat = safe_sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
    return jnp.concatenate([x, four_feat], axis=-1)

In [64]:
"""
Visualization methods -- needs debugging
"""
class VolumeVisualizer(object):
    """
    Pinhole-camera volume renderer for emission sampled along camera rays.

    Typical use: construct the visualizer, call `set_view` to place the camera, sample the
    emission at `visualizer.coords`, then pass the samples to `render`.
    """
    def __init__(self, width, height, samples, fov=0.7):
        """
        A Volume visualization class
        
        Parameters
        ----------
        width: int
            camera horizontal resolution.
        height: int
            camera vertical resolution.
        samples: int
            Number of integration points along a ray.
        fov: float, default=0.7
            Horizontal field of view of the camera in radians.
        """
        self.width = width
        self.height = height
        self.samples = samples
        # Pinhole focal length in pixels for a horizontal field of view of `fov` radians.
        self.focal = .5 * width / jnp.tan(.5 * fov)
        self._pts = None
        
    def set_view(self, radius, azimuth, zenith, up=None, near=15., far=35.):
        """
        Set camera view geometry and sample points along every camera ray.
        
        Parameters
        ----------
        radius: float,
            Distance of the camera from the origin.
        azimuth: float, 
            Azimuth angle in radians.
        zenith: float, 
            Zenith angle in radians.
        up: array, default=None
            The up direction determines roll of the camera; None means [0, 0, 1].
        near: float, default=15.
            Ray parameter of the first sample (distance along the unnormalized ray direction).
        far: float, default=35.
            Ray parameter of the last sample. Choose near/far so that [near, far] covers
            the region of interest around the origin for the given `radius`.
        """
        if up is None:
            up = np.array([0., 0., 1.])
        camorigin = radius * np.array([np.cos(azimuth) * np.sin(zenith),
                                       np.sin(azimuth) * np.sin(zenith),
                                       np.cos(zenith)])
        # Rays travel along -vec2 (see generate_rays), so using the camera position as the
        # "lookdir" points the camera back toward the origin.
        self._viewmatrix = self.viewmatrix(camorigin, up, camorigin)
        rays_o, rays_d = self.generate_rays(
            self._viewmatrix, self.width, self.height, self.focal)
        self._pts = self.sample_along_rays(rays_o, rays_d, near, far, self.samples)
        self.x, self.y, self.z = self._pts[..., 0], self._pts[..., 1], self._pts[..., 2]
        # Length of the segment from each sample to the next one; the last sample gets 0.
        self.d = jnp.linalg.norm(jnp.concatenate([jnp.diff(self._pts, axis=2),
                                                  jnp.zeros_like(self._pts[..., -1:, :])],
                                                 axis=2), axis=-1)
    
    def render(self, emission, facewidth, jit=False, bh_radius=0.0, linewidth=0.1, bh_albedo=(0, 0, 0), cmap='hot'):
        """
        Render an image of the 3D emission
        
        Parameters
        ----------
        emission: 3D array 
            Emission values sampled at `self.coords`, expected in [0, 1] for the colormap.
        facewidth: float
            Width of the enclosing wireframe cube face.
        jit: bool, default=False,
            Just in time compilation. Set true for rendering multiple frames.
            First rendering will take more time due to compilation.
        bh_radius: float, default=0.0
            Radius at which to draw a black hole (for visualization). 
            If bh_radius=0 then no black hole is drawn.
        linewidth: float, default=0.1
            Width of the cube lines.
        bh_albedo: sequence of 3 floats, default=(0, 0, 0)
            Albedo (rgb) of the black hole. Default is completely black.
        cmap: str, default='hot'
            Colormap for visualization.

        Returns
        -------
        rendering: array,
            Rendered RGB image of shape (height, width, 3).

        Raises
        ------
        AttributeError
            If `set_view` has not been called.
        """
        if self._pts is None: 
            raise AttributeError('must set view before rendering')

        # Previously this always used 'hot' and silently ignored `cmap`.
        colormap = plt.get_cmap(cmap)
        emission_cm = colormap(emission)
        emission_cm = jnp.clip(emission_cm - 0.05, 0.0, 1.0)
        # Opacity is emission normalized by its peak; a zero peak would give 0/0 = NaN everywhere.
        peak = jnp.amax(emission)
        opacity = emission / jnp.where(peak != 0, peak, 1.0)
        emission_cm = jnp.concatenate([emission_cm[..., :3], opacity[..., None]], axis=-1)

        if jit:
            emission_cube = draw_cube_jit(emission_cm, self._pts, facewidth, linewidth)
            if bh_radius > 0:
                emission_cube = draw_bh_jit(emission_cube, self._pts, bh_radius, bh_albedo)
        else:
            emission_cube = draw_cube(emission_cm, self._pts, facewidth, linewidth)
            if bh_radius > 0:
                emission_cube = draw_bh(emission_cube, self._pts, bh_radius, bh_albedo)
        # NOTE(review): alpha_composite uses its default inside_halfwidth (4.5), which does not
        # follow `facewidth`. Confirm this is intended.
        rendering = alpha_composite(emission_cube, self.d, self._pts, bh_radius)
        return rendering
    
    def viewmatrix(self, lookdir, up, position):
        """
        Construct a 3x4 look-at camera-to-world matrix.

        Columns are the camera right (vec0), up (vec1) and backward (vec2 = normalized
        lookdir) axes, followed by the camera position.
        """
        vec2 = normalize(lookdir)
        vec0 = normalize(np.cross(up, vec2))
        vec1 = normalize(np.cross(vec2, vec0))
        m = np.stack([vec0, vec1, vec2, position], axis=1)
        return m

    def generate_rays(self, camtoworlds, width, height, focal):
        """
        Generate one ray per pixel for a pinhole camera.

        Returns
        -------
        origins, directions: arrays of shape (height, width, 3)
            Ray origins (the camera position) and unnormalized world-space directions.
        """
        x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
            np.arange(width, dtype=np.float32),  # X-Axis (columns)
            np.arange(height, dtype=np.float32),  # Y-Axis (rows)
            indexing='xy')
        camera_dirs = np.stack(
            [(x - width * 0.5 + 0.5) / focal,
             -(y - height * 0.5 + 0.5) / focal, -np.ones_like(x)],
            axis=-1)
        # Rotate camera-frame directions into world coordinates.
        directions = ((camera_dirs[..., None, :] *
                       camtoworlds[None, None, :3, :3]).sum(axis=-1))
        origins = np.broadcast_to(camtoworlds[None, None, :3, -1],
                                  directions.shape)

        return origins, directions

    def sample_along_rays(self, rays_o, rays_d, near, far, num_samples):
        """Return `num_samples` evenly spaced points per ray, shape (H, W, num_samples, 3)."""
        t_vals = jnp.linspace(near, far, num_samples)
        pts = rays_o[..., None, :] + t_vals[None, None, :, None] * rays_d[..., None, :]
        return pts
    
    @property
    def coords(self):
        """Sample points with the xyz axis first, shape (3, H, W, samples); None before set_view."""
        coords = None if self._pts is None else jnp.moveaxis(self._pts, -1, 0)
        return coords

def alpha_composite(emission, dists, pts, bh_rad, inside_halfwidth=4.5):
    """
    Composite per-sample RGBA values along camera rays into an RGB image.

    Samples inside the cube |x|,|y|,|z| < inside_halfwidth and outside the black hole
    (r > bh_rad) are accumulated additively. All other samples are alpha-composited back to
    front. Uncovered opacity is filled with a white background.

    Parameters
    ----------
    emission: array, shape (H, W, S, 4)
        RGBA value per ray sample; clipped to [0, 1].
    dists: array, shape (H, W, S)
        Segment length associated with each sample; weights the sample's color.
    pts: array, shape (H, W, S, 3)
        Sample positions.
    bh_rad: float
        Black-hole radius used for the masking described above.
    inside_halfwidth: float, default=4.5
        Half width of the additive region.

    Returns
    -------
    rendering: array, shape (H, W, 3)
        RGB image (values may exceed 1 because of the additive region).
    """
    emission = np.clip(emission, 0., 1.)
    # Each sample is weighted by its own ray's segment length. The previous `dists[0, ..., None]`
    # reused the first image row's lengths for every row.
    color = emission[..., :-1] * dists[..., None]
    alpha = emission[..., -1:]

    # Samples inside the wireframe and outside the black hole are composited additively.
    inside = np.amax(np.abs(pts), axis=-1) < inside_halfwidth
    outside_bh = np.linalg.norm(pts, axis=-1) > bh_rad
    combined_mask = np.logical_and(inside, outside_bh)

    rendering = np.zeros_like(color[:, :, 0, :])
    acc = np.zeros_like(color[:, :, 0, 0])
    for ind in reversed(range(alpha.shape[-2])):
        # if pixels inside cube and outside black hole, don't alpha composite
        rendering = rendering + combined_mask[..., ind, None] * color[..., ind, :]

        # else, alpha composite ("over" operator, back to front)
        outside_alpha = alpha[..., ind, :] * (1. - combined_mask[..., ind, None])
        rendering = rendering * (1. - outside_alpha) + color[..., ind, :] * outside_alpha

        acc = alpha[..., ind, 0] + (1. - alpha[..., ind, 0]) * acc

    # White background wherever the accumulated opacity is below 1.
    rendering += np.array([1., 1., 1.])[None, None, :] * (1. - acc[..., None])
    return rendering

def draw_cube(emission_cm, pts, facewidth, linewidth):
    """
    Overlay the wireframe of an origin-centred cube onto an RGBA emission volume.

    Line segments of length `facewidth` start at every cube vertex in each of the six axis
    directions and are sampled at 64 points. Each sample point adds an opaque black
    contribution (alpha weight 1000) that decays as exp(-distance / linewidth**2). Points
    farther than facewidth/2 + linewidth from the origin (max-norm) are zeroed, which removes
    most of the segments that point away from the cube.

    Parameters
    ----------
    emission_cm: array, shape (..., 4)
        RGBA values at `pts`. Not modified; a new array is returned.
    pts: array, shape (..., 3)
        Sample positions.
    facewidth: float
        Cube side length.
    linewidth: float
        Controls the line thickness via the falloff scale linewidth**2.

    Returns
    -------
    out: jnp.ndarray
        RGBA values with the wireframe added and everything outside the cube zeroed.
    """
    linecolor = jnp.array([0.0, 0.0, 0.0, 1000.0])
    half = facewidth / 2.
    vertices = jnp.array([[-half, -half, -half],
                          [half, -half, -half],
                          [-half, half, -half],
                          [half, half, -half],
                          [-half, -half, half],
                          [half, -half, half],
                          [-half, half, half],
                          [half, half, half]])
    dirs = jnp.array([[-1., 0., 0.],
                      [1., 0., 0.],
                      [0., -1., 0.],
                      [0., 1., 0.],
                      [0., 0., -1.],
                      [0., 0., 1.]])

    # Loop invariants, hoisted out of the 8 * 6 * 64 inner iterations.
    offsets = jnp.linspace(0.0, facewidth, 64)
    falloff = linewidth ** 2
    # Out-of-place accumulation: `+=` mutated a caller's NumPy array in place.
    emission_cm = jnp.asarray(emission_cm)

    for i in range(vertices.shape[0]):
        for j in range(dirs.shape[0]):
            # Draw line segments from each vertex
            line_seg_pts = vertices[i, None, :] + offsets[:, None] * dirs[j, None, :]

            for k in range(line_seg_pts.shape[0]):
                dists = jnp.linalg.norm(pts - line_seg_pts[k], axis=-1)
                emission_cm = emission_cm + linecolor * jnp.exp(-1. * dists / falloff)[..., None]

    outside = jnp.greater(jnp.broadcast_to(jnp.amax(jnp.abs(pts), axis=-1, keepdims=True), emission_cm.shape),
                          facewidth/2. + linewidth)
    out = jnp.where(outside, jnp.zeros_like(emission_cm), emission_cm)
    return out


# Jitted variant with identical semantics (previously a verbatim copy of the body above).
# The Python loops unroll into 8 * 6 * 64 distance evaluations, so the first compile is slow.
draw_cube_jit = jax.jit(draw_cube)

def draw_bh(emission, pts, bh_radius, bh_albedo):
    """
    Paint an opaque, Lambert-like shaded sphere of radius `bh_radius` at the origin.

    Parameters
    ----------
    emission: array, shape (..., 4)
        RGBA values at `pts`.
    pts: array, shape (..., 3)
        Sample positions.
    bh_radius: float
        Sphere radius; samples with |pts| < bh_radius are replaced.
    bh_albedo: sequence of 3 floats
        RGB albedo of the sphere.

    Returns
    -------
    emission: jnp.ndarray
        RGBA values where points inside the sphere get color (lightdir . pts) * albedo and
        alpha 1. All other points are unchanged.
    """
    albedo = jnp.array(bh_albedo)[None, None, None, :]
    # Fixed light direction (normalized [-1, -1, 1]).
    lightdir = jnp.array([-1., -1., 1.])
    lightdir = lightdir / jnp.linalg.norm(lightdir, axis=-1, keepdims=True)
    bh_color = jnp.sum(lightdir * pts, axis=-1)[..., None] * albedo
    inside = jnp.less(jnp.linalg.norm(pts, axis=-1, keepdims=True), bh_radius)
    opaque_bh = jnp.concatenate([bh_color, jnp.ones_like(emission[..., 3:])], axis=-1)
    return jnp.where(inside, opaque_bh, emission)


# Jitted variant with identical semantics (previously a verbatim copy of the body above).
draw_bh_jit = jax.jit(draw_bh)

In [61]:
# Trained checkpoint from tutorial 3 (path is specific to that run; edit to point at your own).
checkpoint_dir = '../checkpoints/tutorial3/recovery.2022-11-16.16:19:12/'
# Default-configured network; its hyperparameters must match those used for training (TODO confirm).
predictor = NeRF_Predictor()
# target=None returns the raw checkpoint contents; its 'params' entry is used by later cells.
state = checkpoints.restore_checkpoint(checkpoint_dir, None)

In [62]:
"""
Define a VolumeVisualizer object
"""
# Square camera of resolution x resolution pixels, with `resolution` integration samples per ray.
resolution = 128
visualizer = VolumeVisualizer(resolution, resolution, resolution)

In [65]:
"""
Define the position of the camera and sample/integrate emission according to the camera rays.
View the estimated emission at t=0 from multiple angles. 
Note that jit is useful for acceleration however the initial rendering will be slow due to compilation.
"""
images = []
bh_radius = 2.0
# NOTE(review): norm_const is not used anywhere in this section of the notebook.
norm_const = 0.02
rmax = 8.0

# set_view expects radians. The previous np.linspace(0.0, 360, 6) passed 0..360 *radians*.
# endpoint=False avoids rendering 0 and 2*pi, which are the same view.
for azimuth in tqdm(np.deg2rad(np.linspace(0.0, 360.0, 6, endpoint=False))):
    visualizer.set_view(radius=32.5, azimuth=azimuth, zenith=np.pi/3)
    emission = sample_3d_grid(predictor.apply, state['params'], bh_radius, rmax, coords=visualizer.coords)
    # Normalize to a peak of 1 for the colormap.
    emission = emission / emission.max()
    images.append(visualizer.render(emission, facewidth=2*rmax, jit=True, bh_radius=bh_radius))

  0%|          | 0/6 [00:00<?, ?it/s]

In [66]:
%matplotlib widget
# One panel per rendered view from the previous cell.
fig, axes = plt.subplots(1, 6, figsize=(10,2))
# Trailing semicolons suppress the notebook's echo of return values.
# NOTE(review): renderings can exceed 1.0 (additive emission plus white background); imshow clips them.
for ax, image in zip(axes, images):
    ax.imshow(image);
    ax.axis('off');
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [71]:
import ipyvolume as ipv
# Sample the network on a regular grid spanning [-rmax, rmax]^3 (64^3 points by default),
# zeroed outside bh_radius <= r <= rmax.
emission = sample_3d_grid(predictor.apply, state['params'], rmin=bh_radius, rmax=rmax, fov=2*rmax)
ipv.figure()
ipv.view(0, -60, distance=2.5)
# The grid uses meshgrid(indexing='ij') (x is the first axis), hence memorder='F'. TODO confirm orientation.
ipv.volshow(emission, extent=[(-rmax, rmax)]*3, memorder='F', level=[0, 0.2, 0.7], opacity=[0, 0.2, 0.3], controls=False)
ipv.show()

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…