In [8]:
import bhnerf
from astropy import units
import jax

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
from flax.training import checkpoints
from pathlib import Path
import ruamel.yaml as yaml

# Running on 2 GPUs
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Import script function
import sys
sys.path.append('../scripts/')
from Fit_ALMA_LP_Apr11_SgrA_Flare import preprocess_data

In [2]:
def sample_3D_recovery(checkpoint_dir, coords, chunk=-1):
    """Evaluate a trained NeRF emission recovery on a set of 3D coordinates.

    Args:
        checkpoint_dir: directory containing the predictor yml and the saved checkpoint.
        coords: sampling coordinates forwarded to `bhnerf.network.sample_3d_grid`.
        chunk: forwarded to `sample_3d_grid` (presumably a batching size; -1 default) — TODO confirm.

    Returns:
        Whatever `bhnerf.network.sample_3d_grid` returns for these coordinates.
    """
    network = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    restored_state = checkpoints.restore_checkpoint(checkpoint_dir, None)
    trained_params = restored_state['params']
    return bhnerf.network.sample_3d_grid(network.apply, trained_params, coords=coords, chunk=chunk)

def image_plane_model(inc, spin, randomize_subpixel_rays=False):
    """Build ray-tracing arguments for a Keplerian, polarized emission model.

    Reads notebook-level globals set by the config cell: num_alpha, num_beta, fov_M,
    Omega_dir, b_consts, z_width, rmin, rmax, de_rot_angle, Q_frac and t_start_obs.

    Args:
        inc: inclination angle; callers in this notebook pass radians (np.deg2rad(...)).
        spin: black-hole spin parameter, forwarded to bhnerf.kgeo.image_plane_geos.
        randomize_subpixel_rays: forwarded to bhnerf.kgeo.image_plane_geos
            (presumably jitters rays within each pixel — TODO confirm).

    Returns:
        The result of bhnerf.network.raytracing_args for these geodesics, velocities,
        injection time, observation start time and rotated polarization factors.
    """
    # Compute geodesics paths over a square image plane spanning [-fov_M/2, fov_M/2]
    geos = bhnerf.kgeo.image_plane_geos(
        spin, inc, 
        num_alpha=num_alpha, 
        num_beta=num_beta, 
        alpha_range=[-fov_M/2, fov_M/2],
        beta_range=[-fov_M/2, fov_M/2],
        randomize_subpixel_rays=randomize_subpixel_rays
    )
    # Replace NaN entries with 0 so they do not propagate through the arithmetic below
    geos = geos.fillna(0.0)
    # Keplerian velocity and Doppler boosting:
    # Omega = sign * sqrt(M) / (r^(3/2) + a*sqrt(M)), sign chosen by Omega_dir ('cw' -> -1, 'ccw' -> +1)
    rot_sign = {'cw': -1, 'ccw': 1}
    Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
    umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
    g = bhnerf.kgeo.doppler_factor(geos, umu)

    # Magnitude normalized magnetic field in fluid-frame: divide by the mean |b| inside the
    # emission domain |z| < z_width, rmin < r < rmax (rmin/rmax are notebook globals)
    b = bhnerf.kgeo.magnetic_field_fluid_frame(geos, umu, **b_consts)
    domain = np.bitwise_and(np.bitwise_and(np.abs(geos.z) < z_width, geos.r > rmin), geos.r < rmax)
    b_mean = np.sqrt(np.sum(b[domain]**2, axis=-1)).mean()
    b /= b_mean

    # Polarized emission factors (including parallel transport)
    # NOTE(review): a hard-coded +20 deg is added to the configured de-rotation angle; its origin
    # is not documented here — confirm.
    de_rot_model = np.deg2rad(de_rot_angle + 20.0)
    J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), 0.0)
    J_rot = bhnerf.emission.rotate_evpa(J, de_rot_model)

    # Injection time offset: minus (observer distance r_o + a quarter of the field of view)
    t_injection = -float(geos.r_o + fov_M/4)
    raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_start_obs*units.hr, J_rot)
    return raytracing_args

def image_plane_model_perturb_rays(inc, spin):
    """Build `num_subrays` ray-tracing argument sets, each with randomized sub-pixel rays.

    Args:
        inc: inclination angle in radians (forwarded to `image_plane_model`).
        spin: black-hole spin (forwarded to `image_plane_model`).

    Returns:
        A list of length `num_subrays` (notebook global) of `image_plane_model` outputs.
    """
    return [
        image_plane_model(inc, spin, randomize_subpixel_rays=True)
        for _ in tqdm(range(num_subrays), leave=False)
    ]

def image_plane_fit(raytracing_args, checkpoint_dir, t, data, rmin=0.0, rmax=np.inf, batchsize=20):
    """Render a trained recovery on the image plane and score it against light-curve data.

    Args:
        raytracing_args: ray-tracing inputs (e.g. from `image_plane_model`).
        checkpoint_dir: directory with the predictor yml and trained checkpoint.
        t: times of the frames to render; `len(t)` normalizes the chi^2.
        data: observed values compared against the per-frame image-plane sum.
        rmin, rmax: radial clipping; they can only tighten the predictor's own bounds.
        batchsize: forwarded to `bhnerf.optimization.total_movie_loss`.

    Returns:
        (datafit, image_plane): sum of squared residuals normalized by the notebook-level
        `sigma`, divided by `len(t)`; and the frames returned by `total_movie_loss`.
    """
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    # Clipping bounds never enlarge the trained radial support
    predictor.rmax = min(rmax, predictor.rmax)
    predictor.rmin = max(rmin, predictor.rmin)
    initial_params = predictor.init_params(raytracing_args)
    state = predictor.init_state(initial_params, checkpoint_dir=checkpoint_dir)

    step = bhnerf.optimization.TrainStep.image(t, data, sigma, dtype='lc')
    _, image_plane = bhnerf.optimization.total_movie_loss(
        batchsize, state, step, raytracing_args, return_frames=True)

    light_curve = image_plane.sum(axis=(-1, -2))
    normalized_residuals = (light_curve - data) / sigma
    datafit = np.sum(normalized_residuals**2) / len(t)
    return datafit, image_plane

In [17]:
recovery_path = Path('../checkpoints/alma/intrinsic_fits/toroidal_b_variable_pixels1_new/')
# Figures go to a local mirror of the recovery directory: the leading '..' and 'checkpoints'
# path components are dropped (e.g. 'alma/intrinsic_fits/<run>/').
outpath = Path(*recovery_path.parts[2:])
outpath.mkdir(parents=True, exist_ok=True)

# NOTE(review): ruamel.yaml removed the PyYAML-style `load(..., Loader=...)` API in 0.18;
# this call requires ruamel.yaml < 0.18 (or a port to `yaml.YAML(...).load(stream)`).
with open(recovery_path.joinpath('config.yml'), 'r') as stream:
    config = yaml.load(stream, Loader=yaml.Loader)

# Expose every config entry as a notebook global (the cells below read names such as
# `train_split` and `fov_M` directly). At notebook scope `locals()` is the same dict as
# `globals()`; `globals()` states that intent explicitly. Later sections win on key collisions.
for section in ('preprocess', 'model', 'optimization'):
    globals().update(config[section])

# Preprocess the light curves and split them in time: frames up to `train_split` minutes
# after `t_start` are used for training, the rest for validation.
target, t_frames = preprocess_data(**config['preprocess'])
t_split = config['preprocess']['t_start']*units.hr + train_split*units.min
train_idx = t_frames <= t_split
val_idx = t_frames > t_split
data_train, data_val = target[train_idx], target[val_idx]
t_train, t_val = t_frames[train_idx], t_frames[val_idx]

# The emission domain is bounded by the field of view; an 'ISCO' inner radius is resolved
# for the configured spin via bhnerf.constants.isco_pro.
rmax = fov_M / 2
if rmin == 'ISCO':
    rmin = float(bhnerf.constants.isco_pro(spin))

# Inclination data-fit
---
Data-fit as a function of inclination angle ("zero-order" marginal likelihood)

In [5]:
basename = 'inc_{:.1f}.seed_{}'
seeds = range(4)
inclinations = np.arange(4, 82, 2, dtype=float)

data_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
# Cache keys for the expensive ray tracing (the spin key was previously mistyped as `spin_pref`)
inc_prev = spin_prev = np.nan
for i, inc in enumerate(tqdm(inclinations, desc='inc')):
    for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        # Only evaluate runs that have a `checkpoint_50000` file
        if checkpoint_dir.joinpath('checkpoint_50000').exists():
            # Ray trace once per (inclination, spin); all seeds share the same geometry
            if (inc_prev != inc) or (spin_prev != spin):
                raytrace_args = image_plane_model(np.deg2rad(inc), spin)
                inc_prev, spin_prev = inc, spin
            data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)

data_fit_df = pd.DataFrame(data_fit, index=inclinations, columns=['seed {}'.format(s) for s in seeds])
data_fit_df.index.name = 'inc'
data_fit_df.to_csv(recovery_path.joinpath('inclination_loss.csv'))

inc:   0%|          | 0/39 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
inclinations = np.arange(4, 82, 2, dtype=float)
inc_loss = pd.read_csv(recovery_path.joinpath('inclination_loss.csv'), index_col=0)
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4.8))
plt.errorbar(inclinations, np.nanmean(np.log10(inc_loss), axis=1), np.nanstd(np.log10(inc_loss), axis=1), 
             color='tab:orange', marker='^', mfc='r', mec='r', 
             label=r'$\log \chi^2(\theta)$', markersize=5)
plt.ylim([-1, 3])
plt.legend(loc='best', fontsize=14)
plt.axhline(0, color='gray', linestyle='--',linewidth=0.8)
plt.xlabel(r'Inclination [deg]', fontsize=16)
plt.title(r'Inclination data-fit: $\log \chi^2(\theta | {\bf w}^\star)$', fontsize=16)
plt.tight_layout()
plt.savefig(outpath.joinpath('inclination_loss.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
basename = 'inc_{:.1f}.seed_{}'
seeds = range(4)
inclinations = np.arange(4, 82, 2, dtype=float)
num_subrays = 10

data_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
validation_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
# Cache keys for the expensive ray tracing (the spin key was previously mistyped as `spin_pref`)
inc_prev = spin_prev = np.nan
for i, inc in enumerate(tqdm(inclinations, desc='inc')):
    for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        if checkpoint_dir.joinpath('checkpoint_50000').exists():
            # Ray trace (with randomized sub-pixel rays) once per (inclination, spin)
            if (inc_prev != inc) or (spin_prev != spin):
                raytrace_args = image_plane_model_perturb_rays(np.deg2rad(inc), spin)
                inc_prev, spin_prev = inc, spin
            data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
            # The chi^2 is normalized by len(t); skip the held-out fit when there is no validation data
            if len(t_val) > 0:
                validation_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_val, data_val)

seed_columns = ['seed {}'.format(s) for s in seeds]
data_fit_df = pd.DataFrame(data_fit, index=inclinations, columns=seed_columns)
data_fit_df.index.name = 'inc'
data_fit_df.to_csv(recovery_path.joinpath('inclination_loss_subrays_{}_0.csv'.format(num_subrays)))

# Persist the held-out chi^2 as well (it was previously computed and then discarded)
validation_fit_df = pd.DataFrame(validation_fit, index=inclinations, columns=seed_columns)
validation_fit_df.index.name = 'inc'
validation_fit_df.to_csv(recovery_path.joinpath('inclination_validation_loss_subrays_{}_0.csv'.format(num_subrays)))

inc:   0%|          | 0/39 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

# Spin data-fit
---
Data-fit as a function of spin

In [32]:
basename = 'spin_{:.1f}.seed_{}'
seeds = range(4)
spin_grid = np.linspace(1e-3, 0.99, 10)
num_subrays = 10 

# The config cell already turned an 'ISCO' rmin into a float for the configured spin, so the
# raw setting is re-read here to recompute the inner radius for every spin on the grid.
# Sections are merged in the same order the config cell applies them (later ones win).
rmin_setting = {**config['preprocess'], **config['model'], **config['optimization']}.get('rmin', rmin)
rmin_config = rmin

data_fit = np.full((len(spin_grid), len(seeds)), fill_value=np.nan)
inc_prev = spin_prev = np.nan
try:
    # Iterate with `spin_i` so the notebook-level `spin` used by later cells is not overwritten
    for i, spin_i in enumerate(tqdm(spin_grid, desc='spin')):
        # image_plane_model reads the global `rmin`, so it is reassigned per spin
        rmin = float(bhnerf.constants.isco_pro(spin_i)) if rmin_setting == 'ISCO' else rmin_config
        for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
            checkpoint_dir = recovery_path.joinpath(basename.format(spin_i, seed))
            if checkpoint_dir.joinpath('checkpoint_50000').exists():
                # Ray trace once per (inclination, spin); all seeds share the same geometry
                if (inc_prev != inclination) or (spin_prev != spin_i):
                    raytrace_args = image_plane_model_perturb_rays(np.deg2rad(inclination), spin_i)
                    inc_prev, spin_prev = inclination, spin_i
                data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
finally:
    # Restore the configured inner radius for subsequent cells
    rmin = rmin_config

data_fit_df = pd.DataFrame(data_fit, index=spin_grid, columns=['seed {}'.format(s) for s in seeds])
data_fit_df.index.name = 'spin'
data_fit_df.to_csv(recovery_path.joinpath('spin_loss_subrays_{}_1.csv'.format(num_subrays)))

spin:   0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [45]:
basename = 'spin_{:.1f}.seed_{}'
seeds = range(4)
spin_grid = np.linspace(1e-3, 0.99, 10)
num_subrays = 10 

# The config cell already turned an 'ISCO' rmin into a float for the configured spin, so the
# raw setting is re-read here to recompute the inner radius for every spin on the grid.
# Sections are merged in the same order the config cell applies them (later ones win).
rmin_setting = {**config['preprocess'], **config['model'], **config['optimization']}.get('rmin', rmin)
rmin_config = rmin

data_fit = np.full((len(spin_grid), len(seeds)), fill_value=np.nan)
inc_prev = spin_prev = np.nan
try:
    # Iterate with `spin_i` so the notebook-level `spin` used by later cells is not overwritten
    for i, spin_i in enumerate(tqdm(spin_grid, desc='spin')):
        # image_plane_model reads the global `rmin`, so it is reassigned per spin
        rmin = float(bhnerf.constants.isco_pro(spin_i)) if rmin_setting == 'ISCO' else rmin_config
        for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
            checkpoint_dir = recovery_path.joinpath(basename.format(spin_i, seed))
            if checkpoint_dir.joinpath('checkpoint_50000').exists():
                # Ray trace once per (inclination, spin); all seeds share the same geometry
                if (inc_prev != inclination) or (spin_prev != spin_i):
                    raytrace_args = image_plane_model_perturb_rays(np.deg2rad(inclination), spin_i)
                    inc_prev, spin_prev = inclination, spin_i
                data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
finally:
    # Restore the configured inner radius for subsequent cells
    rmin = rmin_config

data_fit_df = pd.DataFrame(data_fit, index=spin_grid, columns=['seed {}'.format(s) for s in seeds])
data_fit_df.index.name = 'spin'
data_fit_df.to_csv(recovery_path.joinpath('spin_loss_subrays_{}_3.csv'.format(num_subrays)))

spin:   0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [47]:
spin_loss = pd.read_csv(recovery_path.joinpath('spin_loss_subrays_10.csv'), index_col=0)

plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(spin_grid, np.nanmean(np.log10(spin_loss), axis=1), np.nanstd(np.log10(spin_loss), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data-fit', markersize=5)
plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.title(r'Spin data-fit: $\log \chi^2(a | {\bf w}^\star)$', fontsize=16)
# plt.legend(loc='best', fontsize=14)
# plt.savefig(outpath.joinpath('spin_loss_subrays_10.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Spin data-fit: $\\log \\chi^2(a | {\\bf w}^\\star)$')

In [14]:
basename = 'spin_{:.1f}.seed_{}'
seeds = range(4)
spin_grid = np.linspace(1e-3, 0.99, 10)

# The config cell already turned an 'ISCO' rmin into a float for the configured spin, so the
# raw setting is re-read here to recompute the inner radius for every spin on the grid.
# Sections are merged in the same order the config cell applies them (later ones win).
rmin_setting = {**config['preprocess'], **config['model'], **config['optimization']}.get('rmin', rmin)
rmin_config = rmin

data_fit = np.full((len(spin_grid), len(seeds)), fill_value=np.nan)
inc_prev = spin_prev = np.nan
try:
    # Iterate with `spin_i` so the notebook-level `spin` used by later cells is not overwritten
    for i, spin_i in enumerate(tqdm(spin_grid, desc='spin')):
        # image_plane_model reads the global `rmin`, so it is reassigned per spin
        rmin = float(bhnerf.constants.isco_pro(spin_i)) if rmin_setting == 'ISCO' else rmin_config
        for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
            checkpoint_dir = recovery_path.joinpath(basename.format(spin_i, seed))
            if checkpoint_dir.joinpath('checkpoint_50000').exists():
                # Ray trace lazily: once per spin, and only when a checkpoint exists for it
                if (inc_prev != inclination) or (spin_prev != spin_i):
                    raytrace_args = image_plane_model(np.deg2rad(inclination), spin_i)
                    inc_prev, spin_prev = inclination, spin_i
                data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
finally:
    # Restore the configured inner radius for subsequent cells
    rmin = rmin_config

data_fit_df = pd.DataFrame(data_fit, index=spin_grid, columns=['seed {}'.format(s) for s in seeds])
data_fit_df.index.name = 'spin'
data_fit_df.to_csv(recovery_path.joinpath('spin_loss.csv'))

spin:   0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

In [16]:
spin_loss = pd.read_csv(recovery_path.joinpath('spin_loss.csv'), index_col=0)

plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(spin_grid, np.nanmean(np.log10(spin_loss), axis=1), np.nanstd(np.log10(spin_loss), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data-fit', markersize=5)

plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.title(r'Spin data-fit: $\log \chi^2(a | {\bf w}^\star)$', fontsize=16)
# plt.legend(loc='best', fontsize=14)
plt.savefig(outpath.joinpath('spin_loss.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
spin_loss = pd.read_csv(recovery_path.joinpath('spin_loss.csv'), index_col=0)
spin_loss_subrays = pd.read_csv(recovery_path.joinpath('spin_loss_subrays_10.csv'), index_col=0)

plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(spin_grid, np.nanmean(np.log10(spin_loss), axis=1), np.nanstd(np.log10(spin_loss), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data-fit', markersize=5)
plt.errorbar(spin_grid, np.nanmean(np.log10(spin_loss_subrays), axis=1), np.nanstd(np.log10(spin_loss_subrays), axis=1), color='tab:blue', marker='o', mfc='b', mec='b', label='validation', markersize=5)

plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.title(r'Spin data-fit: $\log \chi^2(a | {\bf w}^\star)$', fontsize=16)
plt.legend(loc='best', fontsize=14)
plt.savefig(outpath.joinpath('spin_loss_w_subray_validation.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Visualize a single recovery
---
Visualize a single 3D recovery / datafit for fixed black-hole parameters

In [19]:
basename = 'inc_{:.1f}.seed_{}'
recovery_path = Path('../checkpoints/alma/intrinsic_fits/vertical_b_variable_pixels1/')
# Keep the figure output directory in sync with the recovery being visualized; otherwise the
# figures saved by the following cells land in the previously configured recovery's folder.
outpath = Path(*recovery_path.parts[2:])
outpath.mkdir(parents=True, exist_ok=True)
# NOTE(review): model globals (fov_M, b_consts, ...) still come from the config.yml loaded by the
# config cell, which points at a different recovery directory — confirm the settings match.

seed = 0
inclination = 12.0
resolution = 64
checkpoint_dir = recovery_path.joinpath(basename.format(inclination, seed))

grid_1d = np.linspace(-fov_M/2, fov_M/2, resolution)
coords = np.array(np.meshgrid(grid_1d, grid_1d, grid_1d, indexing='ij'))
emission = sample_3D_recovery(checkpoint_dir, coords)
bhnerf.visualization.ipyvolume_3d(emission, fov=fov_M, level=[0.1, .2, 0.6])

# clip_rmin, clip_rmax = 0, 10
# emission_clipped = bhnerf.emission.fill_unsupervised_emission(emission, coords, clip_rmin, clip_rmax, z_width)
# bhnerf.visualization.ipyvolume_3d(emission_clipped, fov=fov_M, level=[0.1, .2, 0.6])

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…

In [7]:
# Ray trace once for the fixed (inclination, spin), then render and score the recovery
raytrace_args = image_plane_model(inc=np.deg2rad(inclination), spin=spin)
datafit, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, t=t_train, data=data_train)

In [11]:
%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(data_train[:,1:], ['Q','U'], t_train, label='ALMA', color='black', plot_qu=True)
bhnerf.visualization.plot_stokes_lc(model[:,1:], ['Q','U'], t_train, axes=axes, color='r', fmt='x', label='Model', plot_qu=True)

titles = [r'$I_Q$ datafit', r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()

savepath = outpath.joinpath(basename.format(inclination, seed))
savepath.mkdir(parents=True, exist_ok=True)
plt.savefig(savepath.joinpath('QU.datafit.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [266]:
# Neither `consts` nor `eh` is imported earlier in the notebook; bring them into scope here
import ehtim as eh
from bhnerf import constants as consts

# Angular field of view: fov_M gravitational radii (GM/c^2) of Sgr A* at its distance.
# `decompose()` reduces the length ratio to a plain number before radians are attached, so the
# value does not depend on which length unit GM_c2 returns.
fov_rad = (fov_M * consts.GM_c2(consts.sgra_mass) / consts.sgra_distance.to('m')).decompose() * units.rad
psize = fov_rad.value / num_alpha

# First frame, first Stokes component, rotated by 180 degrees (np.rot90 with k=2)
im = eh.image.Image(np.rot90(image_plane[0,0], 2), psize, 0, 0)
im.display();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [44]:
%matplotlib widget
movie_list = [xr.DataArray(image_plane[:, i], dims=['t','beta','alpha']) for i in range(image_plane.shape[1])]
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
bhnerf.visualization.animate_movies_synced(movie_list, axes, titles=['I', 'Q', 'U'], vmin=[0, -.02, -.02], vmax=[.03, .02, .01])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7f2a0458b8e0>

# Component decomposition

In [None]:
raytrace_args = image_plane_model(np.deg2rad(inclination), spin)
model_decomposition = []
for clip_rmin, clip_rmax in zip([0, 13, 18], [13, 18, np.inf]):
    datafit, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train, clip_rmin, clip_rmax)
    model_decomposition.append(image_plane.sum(axis=(-1,-2)))
    
datafit, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(data_train[:,1:], ['Q','U'], t_train, label='ALMA', color='black')
colors = ['tab:orange', 'tab:green', 'tab:blue']
for i, model in enumerate(model_decomposition):
    bhnerf.visualization.plot_stokes_lc(model[:,1:], ['Q','U'], t_train, axes=axes, label='region {}'.format(i), color=colors[i])
bhnerf.visualization.plot_stokes_lc(np.sum(model_decomposition, axis=0)[:,1:], ['Q','U'], t_train, axes=axes, color='r', fmt='x', label='Model')

titles = [r'$I_Q$ datafit', r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()

savepath = outpath.joinpath(basename.format(inclination, seed))
savepath.mkdir(parents=True, exist_ok=True)
plt.savefig(savepath.joinpath('QU.datafit.multiple_components.pdf'), bbox_inches='tight')

In [34]:
view_zeniths = [np.deg2rad(35), np.deg2rad(65)]
# The renderer is JIT-compiled only when there are 5 or more views
jit = len(view_zeniths) >= 5
resolution = 256
# Kerr outer horizon radius r_+ = 1 + sqrt(1 - a^2)
bh_radius = 1 + np.sqrt(1-spin**2)
cam_r = 55.
linewidth = 0.14
norm_const =  0.05
visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)
radial_regions = list(zip([0, 13, 18], [13, 18, np.inf]))

# images[view, region]: one rendering per (view angle, radial region); last axis has size 3
images = np.empty((len(view_zeniths), len(radial_regions), resolution, resolution, 3))
for i, zenith in enumerate(tqdm(view_zeniths, desc='view angle')):
    visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=zenith)
    checkpoint_dir = recovery_path.joinpath(basename.format(inclination, seed))
    # visualizer.coords is read after set_view, so the recovery is re-sampled for each view
    emission = sample_3D_recovery(checkpoint_dir, visualizer.coords, chunk=32)
    for j, (clip_rmin, clip_rmax) in enumerate(radial_regions):
        emission_clipped = bhnerf.emission.fill_unsupervised_emission(emission, visualizer.coords, clip_rmin, clip_rmax, z_width)
        rendering = visualizer.render(emission_clipped / norm_const, facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth)
        images[i, j] = rendering.clip(a_max=1)

view angle:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize

# One figure per radial component; within a figure, one panel per view angle
for j in range(3):
    component_views = images[:, j]
    fig, axes = plt.subplots(1, len(component_views), figsize=(9,4))
    for ax, image in zip(axes, component_views):
        ax.imshow(image)
        ax.set_axis_off()

    # Adds an invisible subplot at grid position 132 (kept from the original layout)
    ax = fig.add_subplot(132)
    ax.set_visible(False)
    plt.tight_layout()

    savepath = outpath.joinpath(basename.format(inclination, seed))
    savepath.mkdir(parents=True, exist_ok=True)
    plt.savefig(savepath.joinpath('3D.component_{}.pdf'.format(j)), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Visualize average recoveries
---
Visualize an average recovery with fixed black-hole parameters

In [9]:
basename = 'inc_{:.1f}.seed_{}'
recovery_path = Path('../checkpoints/alma/intrinsic_fits/vertical_b_variable_pixels1/')

seeds = range(4)
inc = 12
resolution = 64

# Cubic 3D sampling grid spanning the full field of view along every axis
grid_1d = np.linspace(-fov_M/2, fov_M/2, resolution)
coords = np.array(np.meshgrid(grid_1d, grid_1d, grid_1d, indexing='ij'))

# Sample every seed's recovery on the same grid, then display the average over seeds
emission = np.empty((len(seeds),) + (resolution,) * 3)
for i, seed in enumerate(tqdm(seeds, desc='seed')):
    checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
    emission[i] = sample_3D_recovery(checkpoint_dir, coords)

emission_seed_mean = emission.mean(axis=0)
bhnerf.visualization.ipyvolume_3d(emission_seed_mean, fov=fov_M, level=[0.1, .2, 0.6])

seed:   0%|          | 0/4 [00:00<?, ?it/s]

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…

In [76]:
basename = 'inc_{:.1f}.seed_{}'
recovery_path = Path('../checkpoints/alma/intrinsic_fits/vertical_b_variable_pixels1/')

# Visualize image plane frames
seeds = range(4)
inc = 12

# A single set of perturbed sub-pixel rays is shared by every seed
raytrace_args = image_plane_model_perturb_rays(np.deg2rad(inc), spin)
num_seeds = len(seeds)
image_plane_average = 0
for i, seed in enumerate(tqdm(seeds, desc='seed')):
    checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
    datafit, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
    # Running mean over seeds, accumulated one weighted frame stack at a time
    image_plane_average = image_plane_average + image_plane / num_seeds

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

In [194]:
import math
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import ehtim as eh
import matplotlib.font_manager as fm
fontprops = fm.FontProperties(size=16)

P = image_plane_average[:,0]
psize = 200 * eh.const_def.RADPERUAS / 64
%matplotlib widget
fig, axes = plt.subplots(1, 4, figsize=(9,3))
for i, t_idx in enumerate([0, 19, 40, 70]):
    image = eh.image.Image(np.flipud(P[t_idx]), psize=psize, ra=0, dec=0)
    image.display(axis=axes[i], cbar_orientation='horizontal', cbar_lims=[0, 0.005])
    t0, t1 = math.modf(t_frames[t_idx].value)
    axes[i].set_aspect('equal')
    axes[i].get_xaxis().set_visible(False)
    axes[i].get_yaxis().set_visible(False)
    axes[i].set_title('{}:{:02d} UTC'.format(int(t1), int(t0*60)), fontsize=18)
    
    scalebar = AnchoredSizeBar(axes[i].transData,
                   50 / 200 * 64, 
                   r'$50 \mu as$', 
                   'lower center', 
                   pad=0.0,
                   color='white',
                   frameon=False,
                   size_vertical=0.1,
                   fontproperties=fontprops)
    axes[i].add_artist(scalebar)
plt.tight_layout()
plt.savefig(recovery_path.joinpath('image_plane_I.pdf'))


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [144]:
import math
from mpl_toolkits.axes_grid1 import make_axes_locatable

P = np.sqrt(image_plane_average[:,1]**2 + image_plane_average[:,2]**2)

%matplotlib widget
fig, axes = plt.subplots(1, 5, figsize=(10,3))
for i, t_idx in enumerate([0, 20, 40, 60, 80]):
    im = axes[i].imshow(P[t_idx], cmap='afmhot', origin='lower', vmax=0.006)
    divider = make_axes_locatable(axes[i])
    cax = divider.append_axes('bottom', size='5%', pad=0.05)
    cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
    cbar.set_ticks([0, 0.0025, 0.005])
    t0, t1 = math.modf(t_frames[t_idx].value)
    axes[i].set_aspect('equal')
    axes[i].get_xaxis().set_visible(False)
    axes[i].get_yaxis().set_visible(False)
    axes[i].set_title('{}:{:02d} UTC'.format(int(t1), int(t0*60)), fontsize=18)
plt.tight_layout()
# plt.savefig('hs+bg illustration/image_plane_stokes.pdf')


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [77]:
%matplotlib widget
movie_list = [xr.DataArray(image_plane_average[:, i], dims=['t','beta','alpha']) for i in range(image_plane_average.shape[1])]
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
bhnerf.visualization.animate_movies_synced(movie_list, axes, titles=['I', 'Q', 'U'], vmin=[0, -.02, -.02], vmax=[.03, .02, .01])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7fd0c4519520>

In [64]:
recovery_path = Path('../checkpoints/alma/intrinsic_fits/vertical_b_variable_pixels1/')
inc_loss = 0.5 * (
    np.log(pd.read_csv(recovery_path.joinpath('inclination_loss_subrays_10_0.csv'), index_col=0)) + 
    np.log(pd.read_csv(recovery_path.joinpath('inclination_loss_subrays_10_1.csv'), index_col=0))
)

%matplotlib widget
thresh = np.array(inc_loss_mean).min() + (np.array(inc_loss_mean).max() - np.array(inc_loss_mean).min()) / 10
mean_indices = inc_loss_mean < thresh
plt.plot(inc_loss_mean.index, inc_loss_mean, marker='^', zorder=1)
plt.axhline(thresh, color='black', linestyle='--',linewidth=1.3, label='20%')
plt.legend(inc_loss_mean.columns)
# plt.scatter(inc_loss_mean.index[mean_indices], inc_loss_mean[mean_indices], marker='^', color='r', zorder=2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f8fc81c8ee0>

In [11]:
# 3D rotation around the recovered volume (emission averaged over seeds)
seeds = range(4)
inc = 12

zenith_deg = 60
zenith = np.deg2rad(zenith_deg)
view_azimuths = np.linspace(0, 2*np.pi, 30)

jit = True
resolution = 256
bh_radius = 1 + np.sqrt(1-spin**2)
cam_r = 55.
linewidth = 0.14
norm_const =  0.05

visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)

# Load each seed's network once: only the sampling coordinates change with the view azimuth,
# so re-reading every checkpoint for every azimuth was redundant I/O.
recovered_models = []
for seed in seeds:
    checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    state = checkpoints.restore_checkpoint(checkpoint_dir, None)
    recovered_models.append((predictor, state))

images = []
for azimuth in tqdm(view_azimuths):
    visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=azimuth, zenith=zenith)
    emission = 0
    for predictor, state in recovered_models:
        emission += bhnerf.network.sample_3d_grid(
            predictor.apply, state['params'], coords=visualizer.coords, chunk=32) / len(seeds)

    image = visualizer.render(emission / norm_const, facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth).clip(a_max=1)
    images.append(image)

movie = np.array(images)

# Save under the directory of the recovery actually rendered (previously a radial_b path was
# hard-coded here although recovery_path points at the vertical_b fits).
seeds_tag = '{}-{}'.format(min(seeds), max(seeds) + 1)
movie_path = Path('alma/intrinsic_fits') / recovery_path.name / \
    'inc{}.seeds{}.recovery_3D_rotation_zenith_{}.npy'.format(inc, seeds_tag, zenith_deg)
movie_path.parent.mkdir(parents=True, exist_ok=True)
np.save(movie_path, movie)

  0%|          | 0/30 [00:00<?, ?it/s]

In [18]:
# Store the 3D rotation movie alongside the vertical_b recovery outputs
rotation_movie_path = 'alma/intrinsic_fits/vertical_b_variable_pixels1/inc12.seeds0-4.recovery_3D_rotation_zenith_60.npy'
np.save(rotation_movie_path, movie)

In [22]:
from matplotlib import animation
def animate(movie, t_frames, fps=10, output=None, writer='ffmpeg'):
    """Play a stack of frames as a matplotlib animation.

    Args:
        movie: array of frames indexed along axis 0 (``movie.shape[0]`` frames).
        t_frames: frame times; currently unused (time-stamp titles are disabled).
        fps: playback rate; the frame interval is 1000 / fps milliseconds.
        output: optional path; when given the animation is written with ``writer``.
        writer: matplotlib animation writer name.

    Returns:
        The ``FuncAnimation`` object.
    """
    num_frames = movie.shape[0]

    fig, axes = plt.subplots(1,1, figsize=(4,4))
    im = axes.imshow(movie[0])
    axes.set_axis_off()
    plt.tight_layout()

    def animate_frame(frame_idx):
        # Swap the pixel data of the existing image artist
        im.set_array(movie[frame_idx])
        return im

    anim = animation.FuncAnimation(fig, animate_frame, frames=num_frames, interval=1e3 / fps)

    if output is not None:
        anim.save(output, writer=writer, fps=fps)
    return anim

%matplotlib widget
animate(movie, t_frames, output='alma/intrinsic_fits/vertical_b_variable_pixels1/inc12.seeds0-4.recovery_3D_rotation_zenith_60.gif')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7f4a08541c70>

In [55]:
# 10 percent configuration: average recoveries over paired (inclination, seed) entries
seeds = [0, 1, 2, 3, 0, 1]
inc_grid = [12, 12, 12, 12, 16, 16]
# zip() below pairs inc_grid[k] with seeds[k]; a length mismatch would silently drop models
assert len(seeds) == len(inc_grid), 'seeds and inc_grid must have the same length'

view_zeniths = [np.deg2rad(35), np.deg2rad(65)]

# Renderer jit only for many views (presumably to amortize compilation time)
jit = len(view_zeniths) >= 5
resolution = 256
bh_radius = 1 + np.sqrt(1-spin**2)
cam_r = 55.
linewidth = 0.14
norm_const =  0.05

visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)

images = []
for zenith in tqdm(view_zeniths, desc='view angle'):
    emission = 0
    visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=zenith)
    # zip has no len(); pass total so the progress bar shows real progress instead of "0it"
    for inc, seed in tqdm(zip(inc_grid, seeds), total=len(inc_grid), desc='inc/seed', leave=False):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        emission += sample_3D_recovery(checkpoint_dir, visualizer.coords, chunk=32) / len(inc_grid)

    image = visualizer.render(emission / norm_const, facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth).clip(a_max=1)
    images.append(image)

view angle:   0%|          | 0/2 [00:00<?, ?it/s]

inc/seed: 0it [00:00, ?it/s]

inc/seed: 0it [00:00, ?it/s]

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize

fig, axes = plt.subplots(1, len(images), figsize=(9,4))
for panel, frame in zip(axes, images):
    panel.imshow(frame)
    panel.set_axis_off()

# Hidden middle slot of a 1x3 grid, used only to anchor one shared horizontal colorbar
ax = fig.add_subplot(132)
ax.set_visible(False)
cax = make_axes_locatable(ax).append_axes('bottom', size='3%', pad=-1)
colormap = plt.cm.ScalarMappable(norm=Normalize(0, norm_const, clip=True), cmap=plt.get_cmap('hot'))
cbar = fig.colorbar(colormap, cax=cax, orientation='horizontal', shrink=.0)
cbar.ax.tick_params(labelsize=12)
plt.tight_layout()

# NOTE(review): the filename says 20 percent while the rendering cell's comment says 10 percent — confirm
seeds_tag = '-'.join(map(str, seeds))
incs_tag = '-'.join(str(int(inc)) for inc in inc_grid)
outname = '3D_Recovery_20precent_seeds_{}_incs_{}.pdf'.format(seeds_tag, incs_tag)
plt.savefig(outpath.joinpath(outname), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [69]:
seeds = [0]
# inc_grid = np.arange(4, 24, 2, dtype=float)
inc_grid = [12]

view_zeniths = [np.deg2rad(35), np.deg2rad(65)]

jit = len(view_zeniths) >= 5
resolution = 256
bh_radius = 1 + np.sqrt(1 - spin**2)
cam_r = 55.
linewidth = 0.14
norm_const = 0.05

visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)

# Every (inclination, seed) combination contributes equally to the averaged emission
num_models = len(seeds) * len(inc_grid)

images = []
for zenith in tqdm(view_zeniths, desc='view angle'):
    visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=zenith)
    emission = 0
    for inc in tqdm(inc_grid, desc='inc', leave=False):
        for seed in tqdm(seeds, desc='seed', leave=False):
            checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
            emission += sample_3D_recovery(checkpoint_dir, visualizer.coords, chunk=32) / num_models

    image = visualizer.render(emission / norm_const, facewidth=1.9*rmax, jit=jit,
                              bh_radius=bh_radius, linewidth=linewidth).clip(a_max=1)
    images.append(image)

view angle:   0%|          | 0/2 [00:00<?, ?it/s]

inc:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

inc:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

In [70]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize

fig, axes = plt.subplots(1, len(images), figsize=(9,4))
for panel, frame in zip(axes, images):
    panel.imshow(frame)
    panel.set_axis_off()

# Hidden middle slot of a 1x3 grid, used only to anchor one shared horizontal colorbar
ax = fig.add_subplot(132)
ax.set_visible(False)
cax = make_axes_locatable(ax).append_axes('bottom', size='3%', pad=-1)
colormap = plt.cm.ScalarMappable(norm=Normalize(0, norm_const, clip=True), cmap=plt.get_cmap('hot'))
cbar = fig.colorbar(colormap, cax=cax, orientation='horizontal', shrink=.0)
cbar.ax.tick_params(labelsize=12)
plt.tight_layout()

seeds_tag = '-'.join(map(str, seeds))
incs_tag = '-'.join(str(int(inc)) for inc in inc_grid)
outname = '3D_Recovery_seeds_{}_incs_{}.pdf'.format(seeds_tag, incs_tag)
plt.savefig(outpath.joinpath(outname), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Animate recoveries


In [6]:
from matplotlib import animation

def animate_synced(movie1, movie2, loss, inc_grid, axes, cmap='RdBu_r', fps=10, output=None, writer='imagemagick'):
    """Animate two emission movies side by side, synced with a marker on an inclination-loss curve.

    Frame ``i`` shows ``movie1[i]`` on ``axes[0]`` and ``movie2[i]`` on ``axes[1]``, and moves a
    vertical line on ``axes[2]`` (where ``loss`` is plotted against ``inc_grid``) to ``inc_grid[i]``.

    Args:
        movie1, movie2: sequences of images; the number of frames is ``len(movie1)`` and
            ``movie2`` / ``inc_grid`` must have at least that many entries.
        loss: values plotted against ``inc_grid`` on ``axes[2]`` (same length as ``inc_grid``).
        inc_grid: inclination of each frame, used in titles and for the marker position.
        axes: at least three matplotlib Axes of the same figure.
        cmap: accepted for backward compatibility; not used.
        fps: frames per second; also sets the frame interval (1000 / fps ms).
        output: optional path; when given, the animation is saved with ``writer``.
        writer: matplotlib animation writer name.

    Returns:
        matplotlib.animation.FuncAnimation
    """
    num_frames = len(movie1)
    # The figure that owns the axes, not whichever figure happens to be current
    fig = axes[0].figure

    axes[0].set_title('Emission estimate')
    for ax in axes[:2]:
        ax.set_xticks([])
        ax.set_yticks([])

    im1 = axes[0].imshow(np.zeros_like(movie1[0]))
    im2 = axes[1].imshow(np.zeros_like(movie2[0]))
    axes[2].plot(inc_grid, loss)
    line = axes[2].axvline(0, color='green', linestyle='--', label='inc hypothesis')
    axes[2].set_title('Inclination Log(loss)')
    axes[2].legend(loc='upper left')

    def animate_frame(i):
        axes[0].set_title('Emission estimate: inc={:1.1f}'.format(inc_grid[i]))
        im1.set_array(movie1[i])

        axes[1].set_title('Emission estimate: inc={:1.1f}'.format(inc_grid[i]))
        im2.set_array(movie2[i])
        return im1, im2

    def animate_plot(i):
        # Line2D.set_xdata needs a sequence (scalars are rejected by recent matplotlib);
        # the axvline consists of two points sharing the same x.
        line.set_xdata([inc_grid[i], inc_grid[i]])
        return line,

    def animate_both(i):
        return animate_frame(i), animate_plot(i)

    fig.tight_layout()
    anim = animation.FuncAnimation(fig, animate_both, frames=num_frames, interval=1e3 / fps)

    if output is not None:
        anim.save(output, writer=writer, fps=fps)
    return anim

In [8]:
# Checkpoint template of the constant-I scan; fields: (seed, Q_frac, inclination in degrees)
checkpoint_dir_fmt = (
    '../checkpoints/alma/intrinsic_fits/const_I/'
    'IQU.sigmas_1e-01_1e-02_1e-02.spin_0.0.initkey{}.rmin6.0.z_width4.scale1.0.Qfrac{:.1f}.inc_{:1.1f}/'
)
inc_grid = np.deg2rad(np.linspace(2, 80, 40))  # 40 inclinations from 2 to 80 deg, stored in radians
seeds = [2, 3, 4]
Q_frac = 0.85

In [8]:
import warnings
warnings.simplefilter("ignore")

batchsize = 20
J_inds = [['I', 'Q', 'U'].index(s) for s in stokes]

# loss[i, j]: total movie loss of seeds[i] at inc_grid[j]; stays NaN where no checkpoint exists
loss = np.full((len(seeds), len(inc_grid)), fill_value=np.nan)

for i, seed in enumerate(tqdm(seeds, desc='seed')):
    for j, inclination in enumerate(tqdm(inc_grid, desc='inc', leave=False)):
        checkpoint_dir = checkpoint_dir_fmt.format(seed, Q_frac, np.rad2deg(inclination))
        if not os.path.exists(checkpoint_dir):
            continue

        predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)

        # Geodesics for a 64x64 image plane spanning [-rmax, rmax] on both axes
        geos = bhnerf.kgeo.image_plane_geos(
            spin, inclination,
            num_alpha=64, num_beta=64,
            alpha_range=[-rmax, rmax],
            beta_range=[-rmax, rmax],
        ).fillna(0.0)
        t_injection = -float(geos.r_o + fov_M/4)

        # Keplerian velocity field, direction from rot_sign[Omega_dir].
        # NOTE(review): rot_sign and de_rot_model are locals of image_plane_model; this cell needs
        # them as notebook globals — confirm where they are defined.
        Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
        umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
        g = bhnerf.kgeo.doppler_factor(geos, umu)
        # NOTE(review): b_consts is unpacked positionally here but with ** in image_plane_model — confirm its type
        b = bhnerf.kgeo.magnetic_field(geos, *b_consts)
        J_all = bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0)
        J = np.nan_to_num(J_all, 0.0)[J_inds]
        J_rot = bhnerf.emission.rotate_evpa(J, de_rot_model)

        raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0], J_rot)
        params = predictor.init_params(raytracing_args)
        state = predictor.init_state(params, checkpoint_dir=checkpoint_dir)
        loss[i, j] = bhnerf.optimization.total_movie_loss(batchsize, state, train_step, raytracing_args)

seed:   0%|          | 0/3 [00:00<?, ?it/s]

inc:   0%|          | 0/40 [00:00<?, ?it/s]

inc:   0%|          | 0/40 [00:00<?, ?it/s]

inc:   0%|          | 0/40 [00:00<?, ?it/s]

In [23]:
%matplotlib widget
for i in range(len(seeds)):
    plt.scatter(np.rad2deg(inc_grid), loss[i],  marker='^')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Typeset figure text with LaTeX using Helvetica (usetex needs a LaTeX installation)
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "Helvetica"


In [66]:
%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(np.rad2deg(inc_grid)[1:], np.nanmean(np.log(loss/nt), axis=0)[1:],  
             np.nanstd(np.log(loss/nt), axis=0)[1:], marker='^', mfc='r', mec='r')
plt.title(r'Inclination data-fit: $\log \chi^2(\theta | \hat{\bf w})$', fontsize=16)
plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=1)
plt.savefig('alma/inclination_loss.pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
%matplotlib widget
plt.scatter(np.rad2deg(inc_grid), np.log(loss),  marker='^')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x7f9cd8057f10>

In [18]:
# Inclination scan restricted to the Q and U Stokes components: loss_qu[j] is the total movie loss
# at inc_grid[j], or NaN when the checkpoint directory does not exist.
import warnings
warnings.simplefilter("ignore")  # NOTE(review): silences all warnings for the rest of the kernel session

batchsize = 20
stokes= ['Q', 'U']
# assumes axis 1 of target / sigma is Stokes (I, Q, U) and drops I — TODO confirm
target_qu = target[:,1:]
train_step = bhnerf.optimization.TrainStep.image(t_frames, target_qu, sigma[1:], dtype='lc')
J_inds = [['I', 'Q', 'U'].index(s) for s in stokes]

loss_qu = []
for inclination in tqdm(inc_grid, desc='inc'):
    
    # NOTE(review): only one argument is passed, but the const_I template defined earlier in the
    # notebook has three fields (seed, Q_frac, inc) and would raise IndexError here; this cell
    # appears to expect a single-field template — confirm which checkpoint_dir_fmt is active.
    checkpoint_dir = checkpoint_dir_fmt.format(np.rad2deg(inclination))
    if os.path.exists(checkpoint_dir):
        predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)

        # Geodesics for a 64x64 image plane spanning [-rmax, rmax] on both axes
        geos = bhnerf.kgeo.image_plane_geos(
            spin, inclination, 
            num_alpha=64, num_beta=64, 
            alpha_range=[-rmax, rmax],
            beta_range=[-rmax, rmax]
        )
        geos = geos.fillna(0.0)
        t_injection = -float(geos.r_o + fov_M/4)

        # Keplerian prograde velocity field
        # (direction actually comes from rot_sign[Omega_dir]; rot_sign and de_rot_model must be
        # notebook globals here — in image_plane_model they are locals. NOTE(review): confirm)
        Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
        # Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (11.0**(3/2) + geos.spin * np.sqrt(geos.M))

        umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
        g = bhnerf.kgeo.doppler_factor(geos, umu)
        # NOTE(review): b_consts unpacked positionally here but with ** in image_plane_model — confirm
        b = bhnerf.kgeo.magnetic_field(geos, *b_consts)
        # Keep only the Stokes rows listed in `stokes` (Q, U)
        J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), 0.0)[J_inds]
        J_rot = bhnerf.emission.rotate_evpa(J, de_rot_model)
        raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0], J_rot)
        params = predictor.init_params(raytracing_args)
        state = predictor.init_state(params, checkpoint_dir=checkpoint_dir)

        loss_qu.append(bhnerf.optimization.total_movie_loss(batchsize, state, train_step, raytracing_args))
    else:
        # Missing checkpoint: record NaN so loss_qu stays aligned with inc_grid
        loss_qu.append(np.nan)
    
loss_qu = np.array(loss_qu)

inc:   0%|          | 0/40 [00:00<?, ?it/s]

In [18]:
# NOTE(review): this rebinds the global checkpoint_dir_fmt to a single-field template, which later
# cells that pass (seed, Q_frac, inc) will then pick up.
checkpoint_dir_fmt = '../checkpoints/alma/intrinsic_fits/inc_scan_const_omega_r_11/' + \
                     'QU.spin_0.0.initkey2.rmin6.0.z_width4.inc_{:1.1f}/'

import warnings
warnings.simplefilter("ignore")

batchsize = 20
J_inds = [['I', 'Q', 'U'].index(s) for s in stokes]

# Inclination scan of the no-shear fits. A missing checkpoint now records NaN (as the other
# scans do) instead of aborting the scan and losing every loss computed so far.
loss_noshear = []
for inclination in tqdm(inc_grid, desc='inc'):
    checkpoint_dir = checkpoint_dir_fmt.format(np.rad2deg(inclination))
    if not os.path.exists(checkpoint_dir):
        loss_noshear.append(np.nan)
        continue
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)

    geos = bhnerf.kgeo.image_plane_geos(
        spin, inclination,
        num_alpha=64, num_beta=64,
        alpha_range=[-rmax, rmax],
        beta_range=[-rmax, rmax]
    )
    geos = geos.fillna(0.0)
    t_injection = -float(geos.r_o)

    # Constant angular velocity: the Keplerian expression evaluated at r = 11 everywhere (no shear)
    Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (11.0**(3/2) + geos.spin * np.sqrt(geos.M))

    umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
    g = bhnerf.kgeo.doppler_factor(geos, umu)
    b = bhnerf.kgeo.magnetic_field(geos, *b_consts)
    J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), 0.0)[J_inds]
    J_rot = bhnerf.emission.rotate_evpa(J, de_rot_model)
    raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0], J_rot)
    params = predictor.init_params(raytracing_args)
    state = predictor.init_state(params, checkpoint_dir=checkpoint_dir)

    loss_noshear.append(bhnerf.optimization.total_movie_loss(batchsize, state, train_step, raytracing_args))

loss_noshear = np.array(loss_noshear)

inc:   0%|          | 0/40 [00:00<?, ?it/s]

In [38]:
%matplotlib widget

snr = 1
min_loss = np.rad2deg(inc_grid)[np.argmin(loss)]
plt.title('Inclination data-fit [log-scale]')
plt.scatter(np.rad2deg(inc_grid), np.log(loss * snr), label='With shear', marker='x')
plt.scatter(np.rad2deg(inc_grid), np.log(loss_noshear * snr), label='Without shear', marker='^')
plt.legend()
plt.savefig('alma/intrinsic fits/inc_loss_both.pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Animate recovered volumes

In [16]:
"""
Animate the recovered volumes from two view angles.
Note: this is time consuming to produce high res images.
"""
seed = 2
inc_grid = list(range(4, 22, 2))
from flax.training import checkpoints

jit = True
resolution = 256
bh_radius = 2.0
cam_r = 55.
linewidth = 0.1
bh_radius = 1 + np.sqrt(1-spin**2)
norm_const =  0.1
visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)
visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=np.deg2rad(35))

images1, images2 = [], []
for i, inclination in enumerate(tqdm(inc_grid)):
    checkpoint_dir = checkpoint_dir_fmt.format(seed, Q_frac, inclination)
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    state = checkpoints.restore_checkpoint(checkpoint_dir, None)
    
    emission = bhnerf.network.sample_3d_grid(predictor.apply, state['params'], 
                                             coords=visualizer.coords, chunk=32)
    image = visualizer.render(emission / norm_const, facewidth=1.9*rmax, jit=jit, 
                              bh_radius=bh_radius, linewidth=linewidth)
    images1.append(image.clip(a_max=1))
    
    visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=np.deg2rad(65))
    emission = bhnerf.network.sample_3d_grid(predictor.apply, state['params'], 
                                             coords=visualizer.coords, chunk=32)
    image = visualizer.render(emission / norm_const, facewidth=1.9*rmax, jit=jit, 
                              bh_radius=bh_radius, linewidth=linewidth)
    images2.append(image.clip(a_max=1))

  0%|          | 0/9 [00:00<?, ?it/s]

IndexError: Replacement index 2 out of range for positional args tuple

In [None]:
%matplotlib widget

fig, axes = plt.subplots(1, 3,figsize=(9,3.5))
output = 'alma/intrinsic fits/emission_estimate_1fps.mp4'
animate_synced(images, images_view2, np.log(loss), np.rad2deg(inc_grid), axes, output=output, fps=1, , writer='ffmpeg')

## Animate data fits across inclinations

In [79]:
%matplotlib inline

import glob, warnings, subprocess
from pathlib import Path
import imageio

plt.style.use('default')

warnings.simplefilter("ignore")

directory = 'alma/datafit/'
frame_fmt = directory + 'frame{:03d}.png'
Path(directory).mkdir(parents=True, exist_ok=True)
gif_writer = imageio.get_writer(directory + 'datafit.gif', fps=1)
mp4_writer = imageio.get_writer(directory + 'datafit.mp4', fps=1)

stokes = ['Q', 'U']
J_inds = [['I', 'Q', 'U'].index(s) for s in stokes]
target = np.array(alma_lc_means[stokes])
sigma = 1.0
snr = 1.0
batchsize = 20
train_step = bhnerf.optimization.TrainStep.image(t_frames, target, sigma, dtype='lc')

for i, inclination in enumerate(tqdm(inc_grid, desc='inc')):
    checkpoint_dir = checkpoint_dir_fmt.format(np.rad2deg(inclination))
    
    geos = bhnerf.kgeo.image_plane_geos(
        spin, inclination, 
        num_alpha=64, num_beta=64, 
        alpha_range=[-rmax, rmax],
        beta_range=[-rmax, rmax]
    )
    geos = geos.fillna(0.0)
    t_injection = -float(geos.r_o)

    # Keplerian prograde velocity field
    Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
    umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
    g = bhnerf.kgeo.doppler_factor(geos, umu)
    b = bhnerf.kgeo.magnetic_field(geos, *b_consts)
    J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), 0.0)
    
    raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_frames[0], J[J_inds])
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    params = predictor.init_params(raytracing_args)
    state = predictor.init_state(params, checkpoint_dir=checkpoint_dir)
    
    _, movie = bhnerf.optimization.total_movie_loss(batchsize, state, train_step, raytracing_args, True)
    lc_est = movie.sum(axis=(-1,-2))
    
    fig, axes = plt.subplots(1, 4, figsize=(14, 3.5))
    axes[3].set_title(r'Inclination log(loss): ${:1.1f}^\circ$'.format(np.rad2deg(inclination)))
    axes[3].plot(np.rad2deg(inc_grid), np.log(loss * snr))
    axes[3].axvline(np.rad2deg(inclination), color='green', linestyle='--', label='inc hypothesis')
    bhnerf.visualization.plot_stokes_lc(target, stokes, t_frames, axes=axes[:3], label='True')
    bhnerf.visualization.plot_stokes_lc(lc_est, stokes, t_frames, axes=axes[:3], color='r', fmt='x', label='Estimate')
    for ax in axes:
        ax.legend()
        
    plt.savefig(frame_fmt.format(i), dpi=100)
    plt.close()
    
    im = imageio.imread(frame_fmt.format(i))
    gif_writer.append_data(im)
    mp4_writer.append_data(im)
    
gif_writer.close()
mp4_writer.close()

inc:   0%|          | 0/40 [00:00<?, ?it/s]

