In [1]:
# Run on 2 GPUs. CUDA_VISIBLE_DEVICES is read when the CUDA runtime / JAX GPU backend initializes,
# so it must be set BEFORE importing jax or bhnerf (which may touch JAX at import time);
# setting it after those imports can be silently ignored.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'

# Original import order is kept on purpose: bhnerf (and its eht-imaging dependency) is imported
# first, as before, so any import-time side effects happen in the same order.
import bhnerf
from astropy import units
import jax
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import ruamel.yaml as yaml
import warnings
from bhnerf.optimization import LogFn

# Silence warnings only after imports, matching the original cell.
warnings.simplefilter("ignore")

Welcome to eht-imaging! v 1.2.5 





Load synthetic data
---
Data was precomputed using: `notebooks/Synthetic lightcurves 0 - Generate data.ipynb`

In [2]:
# Synthetic flux-tube lightcurves (columns shown below: t, I, Q, U)
data_dir = Path('../data/synthetic_lightcurves/flux_tube')
data_path = data_dir / 'sim1_lightcurve.csv'
lightcurves_df = pd.read_csv(data_path)
lightcurves_df.head()

Unnamed: 0,t,I,Q,U
0,9.340563,0.265225,0.047161,0.161291
1,9.35067,0.261085,0.058046,0.153017
2,9.360777,0.256678,0.068029,0.143825
3,9.370883,0.252142,0.076992,0.133942
4,9.38099,0.247599,0.084752,0.12358


# Recover 3D emission 
---
Recover the unknown 3D emission directly from the polarized lightcurves using bh-NeRF. \
This is an idealized recovery: no systematic noise is modeled.

In [3]:
# Load the parameters used to generate the synthetic data.
# NOTE(review): `yaml.load(..., Loader=yaml.Loader)` is ruamel.yaml's legacy, unsafe-loading API.
# It is acceptable only for trusted, locally generated files like this one. It is deprecated in newer
# ruamel.yaml releases; consider migrating to `yaml.YAML(typ=...)` — confirm against the pinned version.
with open(data_path.parent.joinpath('sim1_params.yaml'), 'r') as stream:
    simulation_params = yaml.load(stream, Loader=yaml.Loader)
# Expose the simulation-model parameters (e.g. fov_M, spin, used just below) as notebook-level names.
locals().update(simulation_params['model'])

recovery_params = { 
    # Domain dimensions
    'z_width': 4,                                   # maximum disk width [M]
    'rmax': fov_M / 2,                              # maximum recovery radius
    'rmin': float(bhnerf.constants.isco_pro(spin)), # minimum recovery radius
    'recovery_scale': 1.0,                          # feature scale for recovery [M] 
    
    # Optimization
    'stokes': ['I', 'Q', 'U'],
    'batchsize': 6,
    'sigma': 1.0,
    'hparams': {'num_iters': 50000, 'lr_init': 1e-4, 'lr_final': 1e-6, 'seed': 1}
}

# Merged parameters: recovery settings take precedence over model settings with the same key.
params = simulation_params['model'] | recovery_params

# Later cells read the recovery settings as bare names (z_width, rmax, rmin, stokes, sigma,
# batchsize, hparams). Expose the merged dict so those names exist on a fresh kernel and agree
# with the `params` passed to the ray tracer.
locals().update(params)

In [4]:
# Ground-truth flare, scaled by the simulation's emission_scale, kept for comparison in the logs
sim_name = simulation_params['name']
flare_path = Path(simulation_params['flare_path'])
emission_flare = emission_scale * xr.load_dataarray(flare_path)

# Ray-tracing geometry for the simulated inclination (degrees -> radians) and spin
inclination_rad = np.deg2rad(inclination)
geos, Omega, J = bhnerf.alma.image_plane_model(inclination_rad, spin, params)
# Injection time: negative of (geos.r_o + a quarter of the field of view)
t_injection = -float(geos.r_o + fov_M / 4)

# Network / optimization setup
t_start_obs_hr = t_start_obs * units.hr
raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_start_obs_hr, J)
# NOTE(review): rmax is passed twice positionally — confirm this matches NeRF_Predictor's signature.
predictor = bhnerf.network.NeRF_Predictor(rmax, rmin, rmax, z_width)

In [5]:
%matplotlib inline
log_period = 500
recovery_dir = data_path.parent.joinpath('recovery/{}_seed{}_stokes_{}_idealized'.format(sim_name, hparams['seed'], ''.join(stokes)))
t_frames, target = np.array(lightcurves_df['t']), np.array(lightcurves_df[stokes])
train_step = bhnerf.optimization.TrainStep.image(t_frames, target, sigma, dtype='lc')

writer = bhnerf.optimization.SummaryWriter(logdir=recovery_dir)
writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission_flare), global_step=0)
log_fns = [
    LogFn(lambda opt: writer.add_scalar('log_loss/train', np.log10(np.mean(opt.loss)), global_step=opt.step)), 
    LogFn(lambda opt: writer.recovery_3d(fov_M, emission_true=emission_flare)(opt), log_period=log_period),
    LogFn(lambda opt: writer.plot_lc_datafit(opt, 'training', train_step, target, stokes, t_frames, batchsize=20), log_period=log_period)
]

# Optimization
optimizer = bhnerf.optimization.Optimizer(hparams, predictor, raytracing_args, checkpoint_dir=recovery_dir)
optimizer.run(batchsize, train_step, raytracing_args, log_fns=log_fns)
writer.close()

out_params = {'simulation': simulation_params, 'recovery': recovery_params}
with open('{}/params.yaml'.format(recovery_dir), 'w') as file:
    yaml.dump(out_params, file, default_flow_style=False)

iteration:   0%|          | 0/50000 [00:00<?, ?it/s]

In [6]:
"""
Visualize the recovered 3D emission
This visualization requires ipyvolume: https://ipyvolume.readthedocs.io/en/latest/
"""
emission_estimate = bhnerf.network.sample_3d_grid(predictor.apply, optimizer.params, fov=fov_M)
bhnerf.visualization.ipyvolume_3d(emission_estimate, fov=fov_M, level=[0, 0.2, 0.7])

Container(figure=Figure(box_center=[0.5, 0.5, 0.5], box_size=[1.0, 1.0, 1.0], camera=PerspectiveCamera(fov=45.…