In [1]:
# Run on 2 GPUs. CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised, so it must be
# set before importing any GPU-backed library (jax, and bhnerf which also imports tensorflow).
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import bhnerf
from astropy import units
import jax

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
from flax.training import checkpoints
from pathlib import Path
import ruamel.yaml as yaml

import warnings
warnings.simplefilter("ignore")

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-jt344vsk because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Welcome to eht-imaging! v 1.2.2 



2023-03-28 18:59:17.544817: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /.singularity.d/libs


In [80]:
basename = 'inc_{:.1f}.seed_{}'  # checkpoint sub-directory pattern: inclination [deg], seed

# Experiment directory holding params.yaml and the per-(inclination, seed) checkpoints
recovery_path = Path('../data/synthetic_lightcurves/single_gaussian/recovery/sim3/')
with recovery_path.joinpath('params.yaml').open('r') as stream:
    # NOTE(review): yaml.Loader is the full (non-safe) loader -- only use on trusted params files
    config = yaml.load(stream, Loader=yaml.Loader)

# Expose every simulation / recovery parameter as a notebook global; when keys repeat,
# the later section wins. At notebook top level locals() is the global namespace.
locals().update(config['simulation']['model'])
locals().update(config['recovery']['model'])
locals().update(config['recovery']['optimization'])

inc_true = config['simulation']['model']['inclination']
lightcurves_df = pd.read_csv(config['simulation']['lightcurve_path'])
lightcurves_val_df = pd.read_csv(config['simulation']['validation_path'])
# Row indices of the fitted Stokes parameters within the (I, Q, U) emission factors
J_inds = [['I', 'Q', 'U'].index(name) for name in stokes]

In [3]:
def sample_3D_recovery(checkpoint_dir, coords, chunk=-1):
    """Evaluate a trained NeRF emission model at a set of 3D coordinates.

    Args:
        checkpoint_dir: directory holding the predictor's yml config and the flax checkpoint.
        coords: coordinates forwarded to ``bhnerf.network.sample_3d_grid``.
        chunk: chunk size forwarded to ``sample_3d_grid`` (-1 presumably means a single
            chunk -- TODO confirm against bhnerf).

    Returns:
        Whatever ``bhnerf.network.sample_3d_grid`` returns for the restored parameters.
    """
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    restored_state = checkpoints.restore_checkpoint(checkpoint_dir, None)
    network_params = restored_state['params']
    return bhnerf.network.sample_3d_grid(
        predictor.apply, network_params, coords=coords, chunk=chunk
    )

def image_plane_model(inc, spin):
    """Build the ray-tracing arguments for a Keplerian, polarized emission model.

    Reads notebook globals loaded from params.yaml: num_alpha, num_beta, fov_M, Omega_dir,
    b_consts, z_width, rmin, rmax, Q_frac, J_inds and t_start_obs.

    Args:
        inc: observer inclination in radians (callers pass ``np.deg2rad(...)``).
        spin: black-hole spin forwarded to the geodesic solver.

    Returns:
        The object built by ``bhnerf.network.raytracing_args``.
    """
    half_fov = fov_M / 2
    geos = bhnerf.kgeo.image_plane_geos(
        spin, inc,
        num_alpha=num_alpha,
        num_beta=num_beta,
        alpha_range=[-half_fov, half_fov],
        beta_range=[-half_fov, half_fov],
    ).fillna(0.0)

    # Keplerian angular velocity; the sign encodes the rotation direction
    direction = {'cw': -1, 'ccw': 1}[Omega_dir]
    sqrt_M = np.sqrt(geos.M)
    Omega = direction * sqrt_M / (geos.r**(3/2) + geos.spin * sqrt_M)
    umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
    g = bhnerf.kgeo.doppler_factor(geos, umu)  # Doppler boosting factor

    # Fluid-frame magnetic field, normalised by its mean magnitude inside the emission domain
    b = bhnerf.kgeo.magnetic_field_fluid_frame(geos, umu, **b_consts)
    in_domain = (np.abs(geos.z) < z_width) & (geos.r > rmin) & (geos.r < rmax)
    b /= np.sqrt(np.sum(b[in_domain]**2, axis=-1)).mean()

    # Polarized emission factors (including parallel transport), keeping only the fitted Stokes rows.
    # The positional 0.0 is nan_to_num's `copy` argument (falsy); NaNs become the default nan=0.0.
    J_all = bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0)
    J = np.nan_to_num(J_all, 0.0)[J_inds]

    t_injection = -float(geos.r_o + fov_M/4)
    return bhnerf.network.raytracing_args(geos, Omega, t_injection, t_start_obs*units.hr, J)

def image_plane_fit(raytracing_args, checkpoint_dir, lightcurves_df, batchsize=20):
    """Evaluate a trained model's lightcurve data-fit.

    Reads the notebook globals ``stokes`` (fitted Stokes columns) and ``sigma``.

    Args:
        raytracing_args: output of ``image_plane_model`` for the model geometry.
        checkpoint_dir: directory holding the predictor yml config and checkpoint.
        lightcurves_df: DataFrame with a 't' column and one column per entry of ``stokes``.
        batchsize: batch size forwarded to ``total_movie_loss``.

    Returns:
        A ``(datafit, image_plane)`` tuple from ``bhnerf.optimization.total_movie_loss``
        called with ``return_frames=True``.
    """
    predictor = bhnerf.network.NeRF_Predictor.from_yml(checkpoint_dir)
    state = predictor.init_state(
        predictor.init_params(raytracing_args), checkpoint_dir=checkpoint_dir
    )
    train_step = bhnerf.optimization.TrainStep.image(
        np.array(lightcurves_df['t']),
        np.array(lightcurves_df[stokes]),
        sigma,
        dtype='lc',
    )
    loss_value, frames = bhnerf.optimization.total_movie_loss(
        batchsize, state, train_step, raytracing_args, return_frames=True
    )
    return loss_value, frames

# Inclination data-fit
---
Data-fit as a function of inclination angle ("zero-order" marginal likelihood) \
Generalization: how well does the recovery perform on a validation dataset of lightcurves at different times (11.05 (t7) -- 12.71 UTC)?

In [4]:
seeds = [1]
inclinations = np.arange(4, 82, 2, dtype=float)  # 4, 6, ..., 80 degrees

# (inclination, seed) grids of the data-fit loss on the training and validation lightcurves.
# Entries stay NaN when a recovery checkpoint is missing, so np.nanmean / np.nanstd ignore them.
data_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
validation_fit = np.full((len(inclinations), len(seeds)), fill_value=np.nan)
for i, inc in enumerate(tqdm(inclinations, desc='inc')):
    raytrace_args = image_plane_model(np.deg2rad(inc), spin)
    for j, seed in enumerate(tqdm(seeds, desc='seed', leave=False)):
        checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
        if not checkpoint_dir.exists():
            # Missing run: leave NaN instead of aborting the whole sweep
            continue
        data_fit[i, j], _ = image_plane_fit(raytrace_args, checkpoint_dir, lightcurves_df)
        validation_fit[i, j], _ = image_plane_fit(raytrace_args, checkpoint_dir, lightcurves_val_df)

inc:   0%|          | 0/39 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

seed:   0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

%matplotlib widget
plt.figure(figsize=(5,4))
plt.errorbar(inclinations, np.nanmean(np.log(data_fit), axis=1), np.nanstd(np.log(data_fit), axis=1), color='tab:orange', marker='^', mfc='r', mec='r', label='data', markersize=5)
plt.errorbar(inclinations, np.nanmean(np.log(validation_fit), axis=1), np.nanstd(np.log(validation_fit), axis=1), color='tab:blue', marker='o', mfc='blue', mec='blue', label='validation', markersize=4)
plt.title(r'Inclination data-fit: $\log \chi^2(\theta | {\bf w}^\star)$', fontsize=16)
plt.xticks(fontsize='14')
plt.yticks(fontsize='14')
plt.axhline(0, color='black', linestyle='--',linewidth=0.75)
plt.axvline(inc_true, color='black', linestyle='--',linewidth=1.3)
plt.text(inc_true*1.01, 8.5, r'$\theta_{\rm true}$', fontsize=16)
plt.legend()
plt.savefig(recovery_path.joinpath('inclination_loss.pdf'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [110]:
# Spatially integrated model lightcurves at 6 deg and at the true inclination.
# NOTE(review): this overwrites the `inclinations` grid of the sweep above; re-run the sweep
# before re-plotting the inclination loss.
seed = 1
inclinations = [6, inc_true]
model_data, model_val = [], []
for inc in inclinations:
    raytrace_args = image_plane_model(np.deg2rad(inc), spin)
    checkpoint_dir = recovery_path.joinpath(basename.format(inc, seed))
    # Training-window fit first, then validation-window fit
    for lc_df, lc_models in ((lightcurves_df, model_data), (lightcurves_val_df, model_val)):
        loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, lc_df)
        lc_models.append(image_plane.sum(axis=(-1, -2)))

In [123]:
%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica"})

colors = ['tab:red', 'tab:orange']
fmts = ['^', 'x']
fig, axes = plt.subplots(1, 2, figsize=(7,3))
for i in range(len(inclinations)):
    bhnerf.visualization.plot_stokes_lc(model_data[i], stokes, np.array(lightcurves_df['t']), color=colors[i], fmt=fmts[i],  axes=axes, label=r'$\theta={}^\circ$'.format(inclinations[i]))
    bhnerf.visualization.plot_stokes_lc(model_val[i], stokes, np.array(lightcurves_val_df['t']), color=colors[i], fmt=fmts[i], axes=axes)

bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_df[stokes]), stokes, np.array(lightcurves_df['t']), axes=axes, label='Data', color='tab:blue')
bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_val_df[stokes]), stokes, np.array(lightcurves_val_df['t']), axes=axes,  color='tab:blue')

t0, t7 = lightcurves_df['t'].iloc[0], lightcurves_df['t'].iloc[-1]
for s, ax in zip(stokes, axes):
    ax.set_title(r'$I_{}$ datafit'.format(s), fontsize=16)
    s_i = ['Q', 'U'].index(s)
    ymin = np.min([lightcurves_df[s].min(), model_data[0][:,s_i].min(), model_data[1][:,s_i].min(),
                   lightcurves_val_df[s].min(), model_val[0][:,s_i].min(), model_val[1][:,s_i].min()])
    ymax = np.max([lightcurves_df[s].max(), model_data[0][:,s_i].max(), model_data[1][:,s_i].max(),
                   lightcurves_val_df[s].max(), model_val[0][:,s_i].max(), model_val[1][:,s_i].max()])
    ymin -= 0.3*np.abs(ymin)
    ymax += 0.3*np.abs(ymax)
    ax.fill_between([t0, t7], [ymax, ymax], ymin, alpha=0.3, color='gray')
    ax.set_ylim([ymin, ymax])
    ax.set_xlim(left=t0)
    ax.text(9.5, ymin+0.05*np.abs(ymin), 'radio loops data', fontsize=12)
    ax.text(11.5, ymin+0.05*np.abs(ymin), 'validation', fontsize=12)
axes[0].legend(loc='best', bbox_to_anchor=(0.4, 0., 0.5, 0.5))

inc_str = '_'.join([str(inc) for inc in inclinations])
plt.savefig(recovery_path.joinpath('datafit_vs_validation_incs_{}.pdf'.format(inc_str)))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [44]:
# Experiment directory of the sim3_intrinsic recovery (params.yaml + checkpoints)
recovery_path = Path('../data/synthetic_lightcurves/single_gaussian/recovery/sim3_intrinsic/')
with recovery_path.joinpath('params.yaml').open('r') as stream:
    # NOTE(review): yaml.Loader is the full (non-safe) loader -- only use on trusted params files
    config = yaml.load(stream, Loader=yaml.Loader)

# Expose every simulation / recovery parameter as a notebook global; when keys repeat,
# the later section wins. At notebook top level locals() is the global namespace.
locals().update(config['simulation']['model'])
locals().update(config['recovery']['model'])
locals().update(config['recovery']['optimization'])

inc_true = config['simulation']['model']['inclination']
lightcurves_df = pd.read_csv(config['simulation']['lightcurve_path'])
lightcurves_val_df = pd.read_csv(config['simulation']['validation_path'])
# Row indices of the fitted Stokes parameters within the (I, Q, U) emission factors
J_inds = [['I', 'Q', 'U'].index(name) for name in stokes]

In [38]:
inclination = 10.
checkpoint_dir = recovery_path.joinpath(basename.format(inclination, seed))
raytrace_args = image_plane_model(np.deg2rad(inclination), spin)
loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, lightcurves_df)
model = image_plane.sum(axis=(-1,-2))

%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_df[stokes]), stokes, np.array(lightcurves_df['t']), label='Data')
bhnerf.visualization.plot_stokes_lc(model, stokes, np.array(lightcurves_df['t']), axes=axes, color='r', fmt='x', label='Model')

titles = [r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()
# plt.savefig(checkpoint_dir.joinpath('QU.datafit.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

0.3125

In [135]:
# NOTE(review): image_plane_flare is first assigned in the flare/disk simulation cell further
# down (In [172]); on a fresh kernel this cell raises NameError -- run that cell first.
image_plane_flare.shape

(98, 3, 128, 128)

In [172]:
flare_path = Path(config['simulation']['flare_path'])
emission_flare = xr.load_dataarray(flare_path)

# Observation times of the training and validation lightcurves. They must be defined before
# the flare ray-tracing below uses them (they were previously assigned only after use).
t_frames = lightcurves_df['t']
t_frames_val = lightcurves_val_df['t']

geos = bhnerf.kgeo.image_plane_geos(
    spin, np.deg2rad(inclination), 
    num_alpha=64, 
    num_beta=64, 
    alpha_range=[-fov_M/2, fov_M/2],
    beta_range=[-fov_M/2, fov_M/2]
)
geos = geos.fillna(0.0)

# Keplerian velocity and Doppler boosting
rot_sign = {'cw': -1, 'ccw': 1}
Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
g = bhnerf.kgeo.doppler_factor(geos, umu)

# Magnitude normalized magnetic field in fluid-frame
b = bhnerf.kgeo.magnetic_field_fluid_frame(geos, umu, **b_consts)
domain = np.bitwise_and(np.bitwise_and(np.abs(geos.z) < z_width, geos.r > rmin), geos.r < rmax)
b_mean = np.sqrt(np.sum(b[domain]**2, axis=-1)).mean()
b /= b_mean

# Polarized emission factors (including parallel transport): flare uses Q_frac, disk Q_frac_disk.
# The positional 0.0 is nan_to_num's `copy` argument (falsy); NaNs become the default nan=0.0.
J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), 0.0)
J_disk = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac_disk, V_frac=0), 0.0)

t_injection = -float(geos.r_o + fov_M/4)

# Flare: ray-trace the simulated 3D emission at the training and validation times
image_plane_flare = bhnerf.emission.image_plane_dynamics(
    emission_flare, geos, Omega, t_frames, t_injection, J, t_start_obs=t_start_obs*units.hr
)
image_plane_flare_val = bhnerf.emission.image_plane_dynamics(
    emission_flare, geos, Omega, t_frames_val, t_injection, J, t_start_obs=t_start_obs*units.hr
)

# NOTE(review): pixel-area scaling with an extra factor 2.5 -- confirm the constant's origin
psize = 2.5*(fov_M/geos.alpha.size)**2

lightcurves_flare = psize*image_plane_flare.sum(axis=(-1,-2))
lightcurves_flare_val = psize*image_plane_flare_val.sum(axis=(-1,-2))

""" Background accretion modeled as Gaussian Random Field (GRF) """
grf = xr.load_dataarray(config['simulation']['grf_path'])

# Map the training time window linearly onto the GRF's full time range
grf_interp = grf.interp(t=float(grf.t[-1]) * (t_frames - t_frames.min()) / (t_frames.max() - t_frames.min()))
image_plane_disk = bhnerf.emission.grf_to_image_plane(grf_interp, geos, Omega, J_disk, diameter_M, alpha, H_r)

# Rescale so the mean Stokes-I disk flux equals I_mean_disk
lightcurves_disk = psize*image_plane_disk.sum(axis=(-1,-2))
image_plane_disk *=  I_mean_disk/lightcurves_disk.mean(axis=0)[0] 
lightcurves_disk *=  I_mean_disk/lightcurves_disk.mean(axis=0)[0]

# Same for the validation time window
grf_interp = grf.interp(t=float(grf.t[-1]) * (t_frames_val - t_frames_val.min()) / (t_frames_val.max() - t_frames_val.min()))
image_plane_disk_val = bhnerf.emission.grf_to_image_plane(grf_interp, geos, Omega, J_disk, diameter_M, alpha, H_r)
lightcurves_disk_val = psize*image_plane_disk_val.sum(axis=(-1,-2))
image_plane_disk_val *=  I_mean_disk/lightcurves_disk_val.mean(axis=0)[0] 
lightcurves_disk_val *=  I_mean_disk/lightcurves_disk_val.mean(axis=0)[0]

#simulation_params['disk_LP'] = np.sqrt(np.sum(lightcurves_disk_val[:,1:]**2, axis=1)).mean()

image_plane = image_plane_flare + image_plane_disk
image_plane_val = image_plane_flare_val + image_plane_disk_val
lightcurves = lightcurves_flare + lightcurves_disk
# Built from the psize-scaled components, like `lightcurves` (a raw image-plane sum would omit psize)
lightcurves_val = lightcurves_flare_val + lightcurves_disk_val

In [167]:
# Mean linear-polarization magnitude sqrt(Q^2 + U^2) (columns 1: of the I, Q, U lightcurves): disk vs flare
np.linalg.norm(lightcurves_disk[:, 1:], axis=1).mean(), np.linalg.norm(lightcurves_flare[:, 1:], axis=1).mean()

(0.11784497391999049, 0.14790203371248212)

In [168]:
# Mean linear-polarization magnitude sqrt(Q^2 + U^2) over the validation window: disk vs flare
np.linalg.norm(lightcurves_disk_val[:, 1:], axis=1).mean(), np.linalg.norm(lightcurves_flare_val[:, 1:], axis=1).mean()

(0.1109917672229807, 0.053739042551203306)

In [175]:
%matplotlib widget
lc_flare = np.vstack((lightcurves_flare, lightcurves_flare_val))
lc_disk = np.vstack((lightcurves_disk, lightcurves_disk_val))
t_all = np.concatenate((t_frames, t_frames_val))
axes = bhnerf.visualization.plot_stokes_lc(lc_flare, ['I','Q','U'], t_all, label='flare')
bhnerf.visualization.plot_stokes_lc(lc_disk, ['I','Q','U'], t_all, axes=axes)
plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7fd1d04e8ca0>

In [187]:
from bhnerf import constants as consts
# Gravitational time unit (GM/c^3, per the helper's name) for the Sgr A* mass, in minutes
GM_c3 = consts.GM_c3(consts.sgra_mass).to(units.min)

In [188]:
# Orbital period 2*pi/Omega at r = 11, using the same Keplerian Omega expression as the
# simulation (positive / 'ccw' sign), scaled by GM_c3 to get minutes.
# NOTE(review): r = 11 is hard-coded -- presumably the flare's orbital radius; confirm.
(2*np.pi / (np.sqrt(geos.M) / (11**(3/2) + geos.spin * np.sqrt(geos.M)))) * GM_c3

In [166]:
%matplotlib widget
lc_flare = np.vstack((lightcurves_flare, lightcurves_flare_val))
lc_disk = np.vstack((lightcurves_disk, lightcurves_disk_val))
t_all = np.concatenate((t_frames, t_frames_val))
axes = bhnerf.visualization.plot_stokes_lc(lc_flare, ['I','Q','U'], t_all, label='flare')
bhnerf.visualization.plot_stokes_lc(lc_disk, ['I','Q','U'], t_all, axes=axes)
plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7fd1d0711c70>

In [43]:
inclination = 10.
checkpoint_dir = recovery_path.joinpath(basename.format(inclination, seed))
raytrace_args = image_plane_model(np.deg2rad(inclination), spin)
loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, lightcurves_val_df)
model = image_plane.sum(axis=(-1,-2))

%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_val_df[stokes]), stokes, np.array(lightcurves_val_df['t']), label='Data')
bhnerf.visualization.plot_stokes_lc(model, stokes, np.array(lightcurves_val_df['t']), axes=axes, color='r', fmt='x', label='Model')

titles = [r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()
# plt.savefig(checkpoint_dir.joinpath('QU.datafit.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Visualize a single recovery
---
Visualize a single 3D recovery / datafit for fixed black-hole parameters

In [32]:
# Switch back to the sim3 experiment. Reload everything derived from its config (not only the
# parameters), so lightcurves / inc_true / J_inds left over from a previously loaded experiment
# (e.g. sim3_intrinsic) are not silently reused by the cells below.
recovery_path = Path('../data/synthetic_lightcurves/single_gaussian/recovery/sim3/')
with open(recovery_path.joinpath('params.yaml'), 'r') as stream:
    # NOTE(review): yaml.Loader is the full (non-safe) loader -- only use on trusted params files
    config = yaml.load(stream, Loader=yaml.Loader)

locals().update(config['simulation']['model'])
locals().update(config['recovery']['model'])
locals().update(config['recovery']['optimization']) 

lightcurves_df = pd.read_csv(config['simulation']['lightcurve_path'])
lightcurves_val_df = pd.read_csv(config['simulation']['validation_path'])
inc_true = config['simulation']['model']['inclination']
J_inds = [['I', 'Q', 'U'].index(s) for s in stokes]

In [31]:
seed = 1
inclination = 10
resolution = 64
checkpoint_dir = recovery_path.joinpath(basename.format(inclination, seed))

# Regular resolution^3 grid spanning the field of view, stacked to shape (3, res, res, res)
grid_1d = np.linspace(-fov_M/2, fov_M/2, resolution)
coords = np.stack(np.meshgrid(*[grid_1d] * 3, indexing='ij'))
emission = sample_3D_recovery(checkpoint_dir, coords)

bhnerf.visualization.ipyvolume_3d(emission, fov=fov_M, level=[0.1, .2, 0.6])

VBox(children=(Figure(camera=PerspectiveCamera(fov=45.0, position=(0.0, -2.1650635094610964, 1.250000000000000…

## Compare recovery to ground truth

In [45]:
# Volume-rendering settings
jit = False
resolution = 256
bh_radius = 1 + np.sqrt(1-spin**2)  # presumably the outer horizon radius r_+ in units of M
cam_r = 55.
linewidth = 0.14
zenith = np.deg2rad(35)

visualizer = bhnerf.visualization.VolumeVisualizer(resolution, resolution, resolution)
visualizer.set_view(cam_r=cam_r, domain_r=rmax, azimuth=0.0, zenith=zenith)

# Ground-truth flare emission resampled at the visualizer's sample points
emission_flare = xr.load_dataarray(config['simulation']['flare_path'])
sample_points = dict(x=xr.DataArray(visualizer.x),
                     y=xr.DataArray(visualizer.y),
                     z=xr.DataArray(visualizer.z))
emission_true = emission_flare.interp(**sample_points).fillna(0.0).data

# Both volumes are divided by the ground-truth peak so the two renders share one scale
norm_const = emission_true.max()
emission_rec = sample_3D_recovery(checkpoint_dir, visualizer.coords, chunk=32)

render_kwargs = dict(facewidth=1.9*rmax, jit=jit, bh_radius=bh_radius, linewidth=linewidth)
image_true = visualizer.render(emission_true / norm_const, **render_kwargs).clip(a_max=1)
image_rec = visualizer.render(emission_rec / norm_const, **render_kwargs).clip(a_max=1)

In [46]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize

%matplotlib widget
# Side-by-side renders of the ground-truth and recovered emission, sharing one colorbar.
images = [image_true, image_rec]
titles = ['Ground truth', 'Recovery']
fig, axes = plt.subplots(1,2, figsize=(9,4))
for ax, img, title in zip(axes, images, titles):
    ax.imshow(img)
    ax.set_title(title, fontsize=18, y=0.78)  # y < 1 places the title inside the axes area
    ax.set_axis_off()
    
# Invisible helper axes in the middle slot of a 1x3 grid, used only to position a horizontal
# colorbar below it (the negative pad presumably pulls the colorbar up towards the images).
ax = fig.add_subplot(132)
ax.set_visible(False)
divider = make_axes_locatable(ax)
cax = divider.append_axes('bottom', size='3%', pad=-1)
# Colorbar spans 0..norm_const (un-normalised emission units).
# NOTE(review): assumes the renders use the 'hot' colormap -- confirm against VolumeVisualizer.render.
cmap = plt.cm.ScalarMappable(norm=Normalize(0, norm_const, clip=True), cmap=plt.get_cmap('hot'))
cbar = fig.colorbar(cmap, cax=cax, orientation='horizontal', shrink=.0)
cbar.ax.tick_params(labelsize=12) 
plt.tight_layout()
plt.savefig(checkpoint_dir.joinpath('gt_vs_rec.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [129]:
raytrace_args = image_plane_model(np.deg2rad(inclination), spin)
loss, image_plane = image_plane_fit(raytrace_args, checkpoint_dir, lightcurves_df)
model = image_plane.sum(axis=(-1,-2))

%matplotlib widget
plt.rcParams.update({"text.usetex": True, "font.family": "Helvetica",})
axes = bhnerf.visualization.plot_stokes_lc(np.array(lightcurves_df[stokes]), stokes, np.array(lightcurves_df['t']), label='Data')
bhnerf.visualization.plot_stokes_lc(model, stokes, t_frames, axes=axes, color='r', fmt='x', label='Model')

titles = [r'$I_U$ datafit', r'$Q-U$ datafit']
for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=16)
    ax.legend()
    
axes[0].set_xlabel('Time [UT]', fontsize=12)
axes[1].set_xlabel('Time [UT]', fontsize=12)
plt.tight_layout()
# plt.savefig(checkpoint_dir.joinpath('QU.datafit.pdf'), bbox_inches='tight')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
%matplotlib widget
movie_list = [xr.DataArray(image_plane[:, i], dims=['t','beta','alpha']) for i in range(image_plane.shape[1])]
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
bhnerf.visualization.animate_movies_synced(movie_list, axes, titles=['I', 'Q', 'U'], vmin=[0, -.02, -.02], vmax=[.03, .02, .01])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.animation.FuncAnimation at 0x7f143bbefd30>