In [17]:
import bhnerf
from astropy import units
import jax
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import ruamel.yaml as yaml

# Runing on 2 GPUs
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1,3'

import warnings
from bhnerf.optimization import LogFn
warnings.simplefilter("ignore")

# Load synthetic data
---

In [30]:
data_path = Path('../data/synthetic_lightcurves/flux_tube/sim1_lightcurve.csv')
lightcurves_df = pd.read_csv(data_path)
lightcurves_df.head()

Unnamed: 0,t,I,Q,U
0,9.340563,0.260572,0.096583,0.12208
1,9.35067,0.255594,0.102229,0.110132
2,9.360777,0.25052,0.106546,0.097898
3,9.370883,0.245289,0.109462,0.085561
4,9.38099,0.240443,0.110972,0.072903


# Recover 3D emission 
---
Recover the unknown 3D emission directly from the polarized lightcurves using bh-NeRF. \
This recovery is an idealized recovery with no systematic noise modeling.

In [31]:
with open(data_path.parent.joinpath('sim1_params.yml'), 'r') as stream:
    simulation_params = yaml.load(stream, Loader=yaml.Loader)
locals().update(simulation_params['model'])

recovery_params = { 
    # Domain dimensions
    'z_width': 4,                                   # maximum disk width [M]
    'rmax': fov_M / 2,                              # maximum recovery radius
    'rmin': float(bhnerf.constants.isco_pro(spin)), # minimum recovery radius
    'recovery_scale': 1.0,                          # feature scale for recovery [M] 
    
    # Optimization
    'stokes': ['I', 'Q', 'U'],
    'batchsize': 6,
    'sigma': 1.0,
    'hparams': {'num_iters': 50000, 'lr_init': 1e-4, 'lr_final': 1e-6, 'seed': 1}
}

locals().update(recovery_params)

In [32]:
# Load ground truth flare for comparison
sim_name = simulation_params['name']
flare_path = Path(simulation_params['flare_path'])
emission_flare = xr.load_dataarray(flare_path)

# Compute geodesics
geos = bhnerf.kgeo.image_plane_geos(
    spin, np.deg2rad(inclination), 
    num_alpha=num_alpha, num_beta=num_beta, 
    alpha_range=[-fov_M/2, fov_M/2],
    beta_range=[-fov_M/2, fov_M/2])
t_injection = -float(geos.r_o + fov_M/4)

# Keplerian velocity field
rot_sign = {'cw': -1, 'ccw': 1}
Omega = rot_sign[Omega_dir] * np.sqrt(geos.M) / (geos.r**(3/2) + geos.spin * np.sqrt(geos.M))
umu = bhnerf.kgeo.azimuthal_velocity_vector(geos, Omega)
g = bhnerf.kgeo.doppler_factor(geos, umu)

# Magnitude normalized magnetic field in fluid-frame
b = bhnerf.kgeo.magnetic_field_fluid_frame(geos, umu, **b_consts)
domain = np.bitwise_and(np.bitwise_and(np.abs(geos.z) < z_width, geos.r > rmin), geos.r < rmax)
b_mean = np.sqrt(np.sum(b[domain]**2, axis=-1)).mean()
b /= b_mean

# Polarized emission factors (including parallel transport)
J = np.nan_to_num(bhnerf.kgeo.parallel_transport(geos, umu, g, b, Q_frac=Q_frac, V_frac=0), 0.0)

# Network / Optimization parameters
raytracing_args = bhnerf.network.raytracing_args(geos, Omega, t_injection, t_start_obs*units.hr, J)
predictor = bhnerf.network.NeRF_Predictor(rmax, rmin, rmax, z_width, posenc_var=recovery_scale/fov_M)

In [33]:
%matplotlib inline
log_period = 500
recovery_dir = data_path.parent.joinpath('recovery/{}.seed_{}.no_systematics.{}'.format(sim_name, hparams['seed'], ''.join(stokes)))
t_frames, target = np.array(lightcurves_df['t']), np.array(lightcurves_df[stokes])
train_step = bhnerf.optimization.TrainStep.image(t_frames, target, sigma, dtype='lc')

writer = bhnerf.optimization.SummaryWriter(logdir=recovery_dir)
writer.add_images('emission/true', bhnerf.utils.intensity_to_nchw(emission_flare), global_step=0)
log_fns = [
    LogFn(lambda opt: writer.add_scalar('log_loss/train', np.log10(np.mean(opt.loss)), global_step=opt.step)), 
    LogFn(lambda opt: writer.recovery_3d(fov_M, emission_true=emission_flare)(opt), log_period=log_period),
    LogFn(lambda opt: writer.plot_lc_datafit(opt, target, stokes, t_frames, batchsize=20), log_period=log_period)
]

# Optimization
optimizer = bhnerf.optimization.Optimizer(hparams, predictor, raytracing_args, checkpoint_dir=recovery_dir)
optimizer.run(batchsize, train_step, raytracing_args, log_fns=log_fns)
writer.close()

params = {'simulation': simulation_params, 'recovery': recovery_params}
with open('{}/params.yml'.format(recovery_dir), 'w') as file:
    yaml.dump(params, file, default_flow_style=False)

iteration:   0%|          | 0/50000 [00:00<?, ?it/s]