In [1]:
import bhnerf
from astropy import units
import jax

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
from flax.training import checkpoints
from pathlib import Path
import ruamel.yaml as yaml

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import warnings
warnings.simplefilter("ignore")

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-ud27qdxo because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



2023-05-19 10:27:32.658246: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs


In [25]:
# Naming pattern of the per-run checkpoint directories: inclination (one decimal) and seed.
basename = 'inc_{:.1f}.seed_{}'

recovery_path = Path('../data/synthetic_lightcurves/two_gaussians/recovery/sim2_quadrant_1_seed1_Q_0.25_intrinsic/')
with (recovery_path / 'params.yaml').open('r') as stream:
    # NOTE(review): this is not a "safe" YAML loader; confirm params.yaml is always a trusted file.
    config = yaml.load(stream, Loader=yaml.Loader)

# Expose the configuration entries as notebook-level variables (at cell level locals() is the
# notebook namespace). Sections are applied in order, so later ones win on key collisions.
for _cfg_keys in (('simulation', 'model'), ('recovery', 'model'), ('recovery', 'optimization')):
    locals().update(config[_cfg_keys[0]][_cfg_keys[1]])

# Preprocess / split data to train/validation
data_path = Path(config['simulation']['lightcurve_path'])
lightcurves_df = pd.read_csv(data_path)
target = np.array(lightcurves_df[stokes])
t_frames = np.array(lightcurves_df['t'])*units.hr

# Frames up to (and including) the split time are training data, later frames are validation data.
_split_time = t_start_obs*units.hr + train_split*units.min
train_idx = t_frames <= _split_time
val_idx = t_frames > _split_time
data_train, t_train = target[train_idx], t_frames[train_idx]
data_val, t_val = target[val_idx], t_frames[val_idx]

inc_true = config['simulation']['model']['inclination']
# Position of every requested Stokes component within the (I, Q, U) ordering.
_stokes_order = ['I', 'Q', 'U']
J_inds = [_stokes_order.index(s) for s in stokes]

In [26]:
def sample_3D_recovery(checkpoint_dir, coords, chunk=-1):
    """
    Sample a recovered emission network on a set of 3D coordinates.

    Parameters
    ----------
    checkpoint_dir: path-like
        Directory from which the network definition (bhnerf NeRF_Predictor.from_yml)
        and the saved state (flax restore_checkpoint) are read.
    coords: array
        Coordinates forwarded to bhnerf.network.sample_3d_grid.
    chunk: int, default=-1
        Forwarded unchanged to bhnerf.network.sample_3d_grid.

    Returns
    -------
    The output of bhnerf.network.sample_3d_grid for the restored parameters.
    """
    network = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    restored = checkpoints.restore_checkpoint(checkpoint_dir, None)
    return bhnerf.network.sample_3d_grid(
        network.apply, restored['params'], coords=coords, chunk=chunk
    )

def image_plane_model(inc, spin, randomize_subpixel_rays=False):
    """
    Build the ray-tracing arguments of the image-plane model for one (inclination, spin) pair.

    Besides its arguments, the function reads notebook-level settings: num_alpha, num_beta,
    fov_M, Omega_dir, b_consts, z_width, rmin, rmax, Q_frac, J_inds and t_start_obs.

    Parameters
    ----------
    inc: float
        Inclination passed to bhnerf.kgeo.image_plane_geos (callers in this notebook pass radians).
    spin: float
        Black-hole spin passed to bhnerf.kgeo.image_plane_geos.
    randomize_subpixel_rays: bool, default=False
        Forwarded to bhnerf.kgeo.image_plane_geos.

    Returns
    -------
    raytracing_args: the output of bhnerf.network.raytracing_args.
    """
    # Compute geodesics paths
    geos = bhnerf.kgeo.image_plane_geos(
        spin, inc, 
        num_alpha=num_alpha, 
        num_beta=num_beta, 
        alpha_range=[-fov_M/2, fov_M/2],
        beta_range=[-fov_M/2, fov_M/2],
        randomize_subpixel_rays=randomize_subpixel_rays
    )
    geos = geos.fillna(0.0)

    # Keplerian velocity and Doppler boosting
    rot_sign = {'cw': -1, 'ccw': 1}
    Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
    umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
    g = bhnerf.kgeo.doppler_factor(geos, umu)

    # Magnitude normalized magnetic field in fluid-frame
    # (normalized by the mean field magnitude inside |z| < z_width, rmin < r < rmax)
    b = bhnerf.kgeo.magnetic_field_fluid_frame(geos, umu, **b_consts)
    domain = np.bitwise_and(np.bitwise_and(np.abs(geos.z) < z_width, geos.r > rmin), geos.r < rmax)
    b_mean = np.sqrt(np.sum(b[domain]**2, axis=-1)).mean()
    b /= b_mean

    # Polarized emission factors (including parallel transport)
    # The second positional argument of np.nan_to_num is `copy`, not the fill value, so the
    # replacement value is passed by keyword.
    J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), nan=0.0)[J_inds]
    t_injection = -float(geos.r_o + fov_M/4)
    raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_start_obs*units.hr, J)
    return raytracing_args

def image_plane_model_perturb_rays(inc, spin):
    """
    Build `num_subrays` image-plane models with randomized sub-pixel rays.

    Each entry is the result of image_plane_model(inc, spin, randomize_subpixel_rays=True);
    `num_subrays` is read from the notebook namespace.

    Returns
    -------
    list with one ray-tracing argument set per sub-ray realization.
    """
    return [
        image_plane_model(inc, spin, randomize_subpixel_rays=True)
        for _ in tqdm(range(num_subrays), leave=False)
    ]

def image_plane_fit(raytracing_args, checkpoint_dir, t, data, rmin=0.0, rmax=np.inf, batchsize=20):
    """
    Evaluate how well a saved recovery reproduces a set of light curves.

    Parameters
    ----------
    raytracing_args:
        Ray-tracing arguments handed to the predictor and to bhnerf.optimization.total_movie_loss.
    checkpoint_dir: path-like
        Directory holding the network definition and the saved state.
    t: sequence
        Frame times; its length normalizes the returned data-fit.
    data: array
        Measured light curves compared against the image-plane sums.
    rmin, rmax: float
        Optional tightening of the radial domain stored with the predictor
        (the predictor's own limits are never widened).
    batchsize: int, default=20
        Forwarded to bhnerf.optimization.total_movie_loss.

    Returns
    -------
    datafit: sum of squared residuals scaled by the notebook-level `sigma`, divided by len(t).
    image_plane: frames returned by bhnerf.optimization.total_movie_loss.
    """
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    predictor.rmax = min(rmax, predictor.rmax)
    predictor.rmin = max(rmin, predictor.rmin)

    state = predictor.init_state(predictor.init_params(raytracing_args), checkpoint_dir=checkpoint_dir)
    train_step = bhnerf.optimization.TrainStep.image(t, data, sigma, dtype='lc')
    _, image_plane = bhnerf.optimization.total_movie_loss(
        batchsize, state, train_step, raytracing_args, return_frames=True
    )

    # Collapse the two trailing (image) axes to obtain the model light curve.
    model_lightcurve = image_plane.sum(axis=(-1,-2))
    scaled_residual = (model_lightcurve - data) / sigma
    datafit = np.sum(scaled_residual**2) / len(t)
    return datafit, image_plane

# Inclination data-fit
---
Data-fit as a function of inclination angle ("zero-order" marginal likelihood) \
Generalization: how well does the recovery perform on a validation dataset of lightcurves at different times (11.05 (t7) -- 12.71 UTC)?

In [16]:
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(inc_loss.index, np.nanmean(np.log10(inc_loss), axis=1), np.nanstd(np.log10(inc_loss), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data', markersize=5)
plt.title(r'Inclination data-fit: $\log \chi^2(\theta | {\bf w}^\star)$', fontsize=16)
plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.axvline(inc_true, color='black', linestyle='--',linewidth=1.3)
plt.text(inc_true*1.01, 8.5, r'$\theta_{\rm true}$', fontsize=16)
plt.legend()
plt.xticks([20,30,40,50,60,70,80])
plt.savefig(recovery_path.joinpath('inclination_loss_log_subrays_10.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
seeds = range(4)
inclinations = np.arange(56, 82, 2, dtype=float)

# Data-fit per (inclination, seed); entries stay NaN for runs without a saved checkpoint.
data_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
# Cache key of the last ray-traced model. NaN never compares equal, so the first
# available run always triggers the ray tracing.
inc_prev = spin_prev = np.nan
for i, inc in enumerate(tqdm(inclinations, desc='inc')):
    for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        # Only evaluate runs that contain a 'checkpoint_50000' entry.
        if os.path.exists(checkpoint_dir.joinpath('checkpoint_50000')):
            # The ray-tracing arguments depend only on (inc, spin): reuse them across seeds.
            if (inc_prev != inc) or (spin_prev != spin):
                raytrace_args = image_plane_model_perturb_rays(np.deg2rad(inc), spin)
                inc_prev, spin_prev = inc, spin
            data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)

inc:   0%|          | 0/13 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [17]:
seeds = range(4)
inclinations = np.arange(4, 82, 2, dtype=float)

# Data-fit per (inclination, seed) on the training and validation frames;
# entries stay NaN for runs without a saved checkpoint.
data_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
validation_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
# Cache key of the last ray-traced model. NaN never compares equal, so the first
# available run always triggers the ray tracing.
inc_prev = spin_prev = np.nan
for i, inc in enumerate(tqdm(inclinations, desc='inc')):
    for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        # Only evaluate runs that contain a 'checkpoint_50000' entry.
        if os.path.exists(checkpoint_dir.joinpath('checkpoint_50000')):
            # The ray-tracing arguments depend only on (inc, spin): reuse them across seeds.
            if (inc_prev != inc) or (spin_prev != spin):
                raytrace_args = image_plane_model_perturb_rays(np.deg2rad(inc), spin)
                inc_prev, spin_prev = inc, spin
            data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
            validation_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_val, data_val)

# Column labels follow `seeds`, so changing the seed list keeps the tables consistent.
seed_columns = ['seed {}'.format(seed_id) for seed_id in seeds]
data_fit_df = pd.DataFrame(data_fit, index=inclinations, columns=seed_columns)
validation_fit_df = pd.DataFrame(validation_fit, index=inclinations, columns=seed_columns)
data_fit_df.index.name = 'inc'
validation_fit_df.index.name = 'inc'
data_fit_df.to_csv(recovery_path.joinpath('inclination_loss_subrays_{}.csv'.format(num_subrays)))
validation_fit_df.to_csv(recovery_path.joinpath('inclination_loss_validation_subrays_{}.csv'.format(num_subrays)))

inc:   0%|          | 0/39 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [19]:
inc_loss = pd.read_csv(recovery_path.joinpath('inclination_loss_subrays_10.csv'), index_col=0)


plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(inc_loss.index, np.nanmean(np.log10(inc_loss), axis=1), np.nanstd(np.log10(inc_loss), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data', markersize=5)
#plt.plot(inclinations, np.log10(data_fit), markersize=5)
# plt.errorbar(inclinations, np.nanmean(np.log(validation_fit), axis=1), np.nanstd(np.log(validation_fit), axis=1), color='tab:blue', marker='o', mfc='blue', mec='blue', label='validation', markersize=4)
plt.title(r'Inclination data-fit: $\log \chi^2(\theta | {\bf w}^\star)$', fontsize=16)
plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.axvline(inc_true, color='black', linestyle='--',linewidth=1.3)
plt.text(inc_true*1.01, 8.5, r'$\theta_{\rm true}$', fontsize=16)
plt.legend()
plt.xticks([20,30,40,50,60,70,80])
plt.savefig(recovery_path.joinpath('inclination_loss_log_subrays_10.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
seeds = range(4)
inclinations = np.arange(4, 82, 2, dtype=float)

# Data-fit per (inclination, seed); entries stay NaN for runs without a saved checkpoint.
data_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
# Cache key of the last ray-traced model. NaN never compares equal, so the first
# available run always triggers the ray tracing.
inc_prev = spin_prev = np.nan
for i, inc in enumerate(tqdm(inclinations, desc='inc')):
    for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        # Only evaluate runs that contain a 'checkpoint_50000' entry.
        if os.path.exists(checkpoint_dir.joinpath('checkpoint_50000')):
            # The ray-tracing arguments depend only on (inc, spin): reuse them across seeds.
            if (inc_prev != inc) or (spin_prev != spin):
                raytrace_args = image_plane_model(np.deg2rad(inc), spin)
                inc_prev, spin_prev = inc, spin
            data_fit[i,j],_ = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)

# Column labels follow `seeds`, so changing the seed list keeps the table consistent.
data_fit_df = pd.DataFrame(data_fit, index=inclinations, columns=['seed {}'.format(seed_id) for seed_id in seeds])
data_fit_df.index.name = 'inc'
data_fit_df.to_csv(recovery_path.joinpath('inclination_loss.csv'))

inc:   0%|          | 0/39 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

seed:   0%|          | 0/4 [00:00<?, ?it/s]

In [24]:
inc_loss = pd.read_csv(recovery_path.joinpath('inclination_loss.csv'), index_col=0)

plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(inclinations, np.nanmean(np.log10(inc_loss), axis=1), np.nanstd(np.log10(inc_loss), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data', markersize=5)
#plt.plot(inclinations, np.log10(data_fit), markersize=5)
# plt.errorbar(inclinations, np.nanmean(np.log(validation_fit), axis=1), np.nanstd(np.log(validation_fit), axis=1), color='tab:blue', marker='o', mfc='blue', mec='blue', label='validation', markersize=4)
plt.title(r'Inclination data-fit: $\log \chi^2(\theta | {\bf w}^\star)$', fontsize=16)
plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.axvline(inc_true, color='black', linestyle='--',linewidth=1.3)
plt.text(inc_true*1.01, 8.5, r'$\theta_{\rm true}$', fontsize=16)
plt.legend()
plt.savefig(recovery_path.joinpath('inclination_loss_log.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
recovery_path = Path('../data/synthetic_lightcurves/two_gaussians/recovery/sim1_quadrant_1_seed1_Q_0.25_intrinsic/')
# Row-wise mean (across the CSV columns) of log10(loss).
inc_loss_mean = (
    pd.read_csv(recovery_path / 'inclination_loss_subrays_10.csv', index_col=0)
    .pipe(np.log10)
    .mean(axis=1)
)

In [None]:
recovery_path = Path('../data/synthetic_lightcurves/flux_tube/recovery/sim1_seed1_Q_0.25_intrinsic/')
# Row-wise mean (across the CSV columns) of log10(loss).
inc_loss_mean = (
    pd.read_csv(recovery_path / 'inclination_loss_subrays_10.csv', index_col=0)
    .pipe(np.log10)
    .mean(axis=1)
)

In [None]:
recovery_path = Path('../checkpoints/alma/intrinsic_fits/vertical_b_variable_pixels1/')
# Log-loss of the 'seed 0' column only.
# NOTE(review): this uses the natural log, whereas the neighbouring loader cells use log10 — confirm intended.
inc_loss_mean = (
    pd.read_csv(recovery_path / 'inclination_loss_subrays_10_0.csv', index_col=0)
    .pipe(np.log)
    .loc[:, 'seed 0']
)

In [209]:
%matplotlib widget
thresh = np.array(inc_loss_mean).min() + (np.array(inc_loss_mean).max() - np.array(inc_loss_mean).min()) / 5
mean_indices = inc_loss_mean < thresh
plt.plot(inc_loss_mean.index, inc_loss_mean, marker='^', zorder=1)
plt.axhline(thresh, color='black', linestyle='--',linewidth=1.3, label='20%')
plt.scatter(inc_loss_mean.index[mean_indices], inc_loss_mean[mean_indices], marker='^', color='r', zorder=2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x7f582ffb9340>

In [132]:
%matplotlib widget
thresh = (inc_loss_mean.max() - inc_loss_mean.min()) / 10.0
mean_indices = inc_loss_mean - inc_loss_mean.min() < thresh
plt.plot(inc_loss_mean.index, inc_loss_mean, marker='^', zorder=1)
plt.scatter(inc_loss_mean.index[mean_indices], inc_loss_mean[mean_indices], marker='^', color='r', zorder=2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x7f5832e6e250>

## Inclination Validation analysis

In [15]:
seed = 0
inclinations = [6, 10]
model_data, model_val = [], []
for inc in inclinations:
    checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
    raytrace_args = image_plane_model(np.deg2rad(inc), spin)
    # Evaluate the same recovery on the training frames first, then on the validation frames,
    # and keep the model light curves (image-plane sums) of each split.
    for _t_split, _data_split, _lightcurves in ((t_train, data_train, model_data),
                                                (t_val, data_val, model_val)):
        loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, _t_split, _data_split)
        _lightcurves.append(image_plane.sum(axis=(-1,-2)))

In [22]:
%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

colors = ['tab:red', 'tab:orange']
fmts = ['^', 'x']
fig, axes = plt.subplots(1, 2, figsize=(7,3))
for i in range(len(inclinations)):
    bhnerf.visualization.plot_stokes_lc(model_data[i], stokes, t_train, color=colors[i], fmt=fmts[i],  axes=axes, label=r'$\theta={}^\circ$'.format(inclinations[i]))
    bhnerf.visualization.plot_stokes_lc(model_val[i], stokes, t_val, color=colors[i], fmt=fmts[i], axes=axes)

bhnerf.visualization.plot_stokes_lc(target, stokes, t_frames, axes=axes, label='Data', color='tab:blue')

t0, t7 = t_train[0].value, t_train[-1].value
for s, ax in zip(stokes, axes):
    ax.set_title(r'$I_{}$ datafit'.format(s), fontsize=16)
    s_i = ['Q', 'U'].index(s)
    ymin = np.min([target[:,s_i].min(), model_data[0][:,s_i].min(), model_data[1][:,s_i].min(),
                   model_val[0][:,s_i].min(), model_val[1][:,s_i].min()])
    ymax = np.max([target[:,s_i].max(), model_data[0][:,s_i].max(), model_data[1][:,s_i].max(),
                   model_val[0][:,s_i].max(), model_val[1][:,s_i].max()])
    ymin -= 0.3*np.abs(ymin)
    ymax += 0.3*np.abs(ymax)
    ax.fill_between([t0, t7], [ymax, ymax], ymin, alpha=0.3, color='gray')
    ax.set_ylim([ymin, ymax])
    ax.set_xlim(left=t0)
    ax.text(9.5, ymin+0.05*np.abs(ymin), 'radio loops data', fontsize=12)
    ax.text(11.5, ymin+0.05*np.abs(ymin), 'validation', fontsize=12)
axes[0].legend(loc='best', bbox_to_anchor=(0.4, 0., 0.5, 0.5))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f0bdc7c4670>

In [123]:
# Data-fit vs. validation figure, saved to the recovery directory.
# NOTE(review): this cell reads `lightcurves_val_df`, which is not defined anywhere in the
# visible notebook; it presumably predates the train/validation split of `lightcurves_df`
# done in the configuration cell — confirm before relying on it.
%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

colors = ['tab:red', 'tab:orange']
fmts = ['^', 'x']
fig, axes = plt.subplots(1, 2, figsize=(7,3))
# Model light curves for every inclination (first call is labelled for the legend).
for i in range(len(inclinations)):
    bhnerf.visualization.plot_stokes_lc(model_data[i], stokes, np.array(lightcurves_df['t']), color=colors[i], fmt=fmts[i],  axes=axes, label=r'$\theta={}^\circ$'.format(inclinations[i]))
    bhnerf.visualization.plot_stokes_lc(model_val[i], stokes, np.array(lightcurves_val_df['t']), color=colors[i], fmt=fmts[i], axes=axes)

# Measured light curves from both data frames.
bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_df[stokes]), stokes, np.array(lightcurves_df['t']), axes=axes, label='Data', color='tab:blue')
bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_val_df[stokes]), stokes, np.array(lightcurves_val_df['t']), axes=axes,  color='tab:blue')

# First and last time stamp of `lightcurves_df`; used for the shaded span below.
t0, t7 = lightcurves_df['t'].iloc[0], lightcurves_df['t'].iloc[-1]
for s, ax in zip(stokes, axes):
    ax.set_title(r'$I_{}$ datafit'.format(s), fontsize=16)
    # Column of this Stokes component in the model arrays (assumes the order Q, U).
    s_i = ['Q', 'U'].index(s)
    # y-limits cover the data and the two model curves, with a 30% margin on both sides.
    ymin = np.min([lightcurves_df[s].min(), model_data[0][:,s_i].min(), model_data[1][:,s_i].min(),
                   lightcurves_val_df[s].min(), model_val[0][:,s_i].min(), model_val[1][:,s_i].min()])
    ymax = np.max([lightcurves_df[s].max(), model_data[0][:,s_i].max(), model_data[1][:,s_i].max(),
                   lightcurves_val_df[s].max(), model_val[0][:,s_i].max(), model_val[1][:,s_i].max()])
    ymin -= 0.3*np.abs(ymin)
    ymax += 0.3*np.abs(ymax)
    ax.fill_between([t0, t7], [ymax, ymax], ymin, alpha=0.3, color='gray')
    ax.set_ylim([ymin, ymax])
    ax.set_xlim(left=t0)
    ax.text(9.5, ymin+0.05*np.abs(ymin), 'radio loops data', fontsize=12)
    ax.text(11.5, ymin+0.05*np.abs(ymin), 'validation', fontsize=12)
axes[0].legend(loc='best', bbox_to_anchor=(0.4, 0., 0.5, 0.5))

# File name encodes the compared inclinations, e.g. '..._incs_6_10.pdf'.
inc_str = '_'.join([str(inc) for inc in inclinations])
plt.savefig(recovery_path.joinpath('datafit_vs_validation_incs_{}.pdf'.format(inc_str)))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Visualize a single recovery
---
Visualize a single 3D recovery / datafit for fixed black-hole parameters

In [20]:
# Load the simulated emission volume referenced by the configuration and show it in 3D.
flare_path = Path(config['simulation']['flare_path'])
emission_flare = xr.load_dataarray(flare_path)
_flare_view = dict(fov=fov_M, level=[0.1, .2, 0.6], elevation=0)
bhnerf.visualization.ipyvolume_3d(emission_flare, **_flare_view)

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, 0.0, 2.5), projectionMatrix=(1.0, 0.0,…

In [32]:
seed = 1
inclination = 60
resolution = 64
checkpoint_dir = recovery_path / basename.format(inclination, seed)

# Regular cubic grid spanning the field of view, `resolution` points along every axis.
grid_1d = np.linspace(-fov_M/2, fov_M/2, resolution)
coords = np.array(np.meshgrid(*[grid_1d]*3, indexing='ij'))
emission = sample_3D_recovery(checkpoint_dir, coords)

bhnerf.visualization.ipyvolume_3d(emission, fov=fov_M, level=[0.1, .2, 0.6], elevation=0)

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, 0.0, 2.5), projectionMatrix=(1.0, 0.0,…

In [108]:
checkpoint_dir = recovery_path.joinpath(basename.format(inclination, seed))
raytrace_args = image_plane_model(np.deg2rad(inclination), spin)
loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, t_train, data_train)
model = image_plane.sum(axis=(-1,-2))

%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(data_train, stokes, t_train, label='Data')
bhnerf.visualization.plot_stokes_lc(model, stokes, t_train, axes=axes, color='r', fmt='x', label='Model')

titles = [r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()
plt.savefig(checkpoint_dir.joinpath('QU.datafit.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Compare recovery to ground truth

In [33]:
# recovery_path = Path('../data/synthetic_lightcurves/two_gaussians/recovery/sim2_quadrant_2/')
# with open(recovery_path.joinpath('params.yaml'), 'r') as stream:
#     config = yaml.load(stream, Loader=yaml.Loader)
# locals().update(config['simulation']['model'])

seed = 1
inclination = 60
checkpoint_dir = recovery_path / basename.format(inclination, seed)

# Rendering settings
jit = False
resolution = 256
bh_radius = 1 + np.sqrt(1-spin**2)
cam_r = 55.
linewidth = 0.14
zenith = np.deg2rad(35)

visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)
visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=zenith)

# Ground truth: scaled simulation volume interpolated onto the visualizer coordinates;
# points without a value are set to zero.
emission_flare = emission_scale * xr.load_dataarray(config['simulation']['flare_path'])
_visualizer_points = {axis: xr.DataArray(getattr(visualizer, axis)) for axis in ('x', 'y', 'z')}
emission_true = emission_flare.interp(**_visualizer_points).fillna(0.0).data

# Both volumes are normalized by the ground-truth maximum and rendered with the same settings.
norm_const =  emission_true.max()
emission_rec = sample_3D_recovery(checkpoint_dir, visualizer.coords, chunk=32)
_render_settings = dict(facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth)
image_true = visualizer.render(emission_true / norm_const, **_render_settings)
image_rec = visualizer.render(emission_rec / norm_const, **_render_settings)

In [34]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize

%matplotlib widget
images = [image_true, image_rec]
titles = ['Ground truth', 'Recovery']
fig, axes = plt.subplots(1,2, figsize=(9,4))
for ax, img, title in zip(axes, images, titles):
    ax.imshow(img)
    ax.set_title(title, fontsize=18, y=0.78)
    ax.set_axis_off()
    
ax = fig.add_subplot(132)
ax.set_visible(False)
divider = make_axes_locatable(ax)
cax = divider.append_axes('bottom', size='3%', pad=-1)
cmap = plt.cm.ScalarMappable(norm=Normalize(0, norm_const, clip=True), cmap=plt.get_cmap('hot'))
cbar = fig.colorbar(cmap, cax=cax, orientation='horizontal', shrink=.0, extend='max')
cbar.ax.tick_params(labelsize=12) 
plt.tight_layout()
plt.savefig(checkpoint_dir.joinpath('gt_vs_rec.pdf'.format(inclination, seed)), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [129]:
raytrace_args = image_plane_model(np.deg2rad(inclination), spin)
loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, lightcurves_df)
model = image_plane.sum(axis=(-1,-2))

%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_df[stokes]), stokes, np.array(lightcurves_df['t']), label='Data')
bhnerf.visualization.plot_stokes_lc(model, stokes, t_frames, axes=axes, color='r', fmt='x', label='Model')

titles = [r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()
# plt.savefig(checkpoint_dir.joinpath('QU.datafit.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
%matplotlib widget
movie_list = [xr.DataArray(image_plane[:, i], dims=['t','beta','alpha']) for i in range(image_plane.shape[1])]
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
bhnerf.visualization.animate_movies_synced(movie_list, axes, titles=['I', 'Q', 'U'], vmin=[0, -.02, -.02], vmax=[.03, .02, .01])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7f143bbefd30>

# Visualize average recoveries
---
Visualize an average recovery with fixed black-hole parameters

In [254]:
# Load the simulated emission volume, apply the configured scale and show it in 3D.
flare_path = Path(config['simulation']['flare_path'])
emission_flare = emission_scale * xr.load_dataarray(flare_path)
_flare_view = dict(fov=fov_M, level=[0.1, .2, 0.6])
bhnerf.visualization.ipyvolume_3d(emission_flare, **_flare_view)


VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…

In [87]:
seed = 0
inc_grid = [4, 6, 8, 10, 12, 14, 16]
resolution = 64

# Regular cubic grid spanning the field of view, `resolution` points along every axis.
grid_1d = np.linspace(-fov_M/2, fov_M/2, resolution)
coords = np.array(np.meshgrid(*[grid_1d]*3, indexing='ij'))

# One sampled recovery per inclination, stacked along the first axis.
emission = np.empty((len(inc_grid),) + (resolution,)*3)
for i, inc in enumerate(tqdm(inc_grid, desc='inc')):
    checkpoint_dir = recovery_path / basename.format(inc, seed)
    emission[i] = sample_3D_recovery(checkpoint_dir, coords)

# Display the average over inclinations.
bhnerf.visualization.ipyvolume_3d(emission.mean(axis=0), fov=fov_M, level=[0.1, .2, 0.6])

inc:   0%|          | 0/7 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='t', max=6), Output()), _dom_classes=('widget-interact',)…

In [255]:
# Interpolate the simulation volume onto the visualizer coordinates (missing values -> 0),
# normalize by its maximum and render it.
_visualizer_points = {axis: xr.DataArray(getattr(visualizer, axis)) for axis in ('x', 'y', 'z')}
emission_true = emission_flare.interp(**_visualizer_points).fillna(0.0).data
norm_const =  emission_true.max()
_render_settings = dict(facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth)
image_true = visualizer.render(emission_true / norm_const, **_render_settings)

In [299]:
seed = 1
inc_grid = np.arange(34, 82, 2)
zenith = np.deg2rad(35)

# Rendering settings
jit = False
resolution = 256
bh_radius = 1 + np.sqrt(1-spin**2)
cam_r = 55.
linewidth = 0.14

visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)
visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=zenith)

# Running average of the sampled recoveries over the inclination grid.
emission = 0
for inc in tqdm(inc_grid, desc='inc', leave=False):
    checkpoint_dir = recovery_path / basename.format(inc, seed)
    emission += sample_3D_recovery(checkpoint_dir, visualizer.coords, chunk=32) / len(inc_grid)

# Render the mean volume (normalized by `norm_const`) and clip the image at 1.
_render_settings = dict(facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth)
image_rec = visualizer.render(emission / norm_const, **_render_settings).clip(a_max=1)

inc:   0%|          | 0/24 [00:00<?, ?it/s]

In [300]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize

%matplotlib widget
images = [image_true, image_rec]
titles = ['Ground truth', 'Recovery Mean']
fig, axes = plt.subplots(1,2, figsize=(9,4))
for ax, img, title in zip(axes, images, titles):
    ax.imshow(img)
    ax.set_title(title, fontsize=18, y=0.78)
    ax.set_axis_off()
    
ax = fig.add_subplot(132)
ax.set_visible(False)
divider = make_axes_locatable(ax)
cax = divider.append_axes('bottom', size='3%', pad=-1)
cmap = plt.cm.ScalarMappable(norm=Normalize(0, norm_const, clip=True), cmap=plt.get_cmap('hot'))
cbar = fig.colorbar(cmap, cax=cax, orientation='horizontal', shrink=.0, extend='max')
cbar.ax.tick_params(labelsize=12) 
plt.tight_layout()

outname = '3D_Recovery_mean_incs_{}_seed_{}.pdf'.format('-'.join([str(int(inc)) for inc in inc_grid]), seed)
plt.savefig(recovery_path.joinpath(outname), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

