## Demo for Jax Gigalens env

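**Please check and, if necessary, CHANGE the file paths used in this notebook (e.g. the location of the GIGALens checkout added to `sys.path`) before running it.**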

Please run this notebook on an exclusive GPU node and select the JAX GIGALENS kernel.

In [1]:
from os.path import expanduser, join
import sys

# Home directory of the current user; the GIGALens checkout is looked up beneath it.
home = expanduser("~/")

# Set to True to import gigalens from the development checkout
# instead of the master-branch checkout.
sbalta01_dev = False

# Choose the checkout directory and the banner printed for the selected mode.
if sbalta01_dev:
    repo_dir, banner = 'gigalens-sbalta01-dev', 'DEVELOPER MODE'
else:
    repo_dir, banner = 'gigalens', 'MASTER BRANCH GIGALENS'

# Put <home>/<repo_dir>/src first on sys.path so that this checkout takes
# precedence when `gigalens` is imported. os.path.join is used because `home`
# already ends with a separator, so plain string concatenation produced "//".
sys.path.insert(0, join(home, repo_dir, 'src'))
print(banner)


import jax
print(jax.devices())


MASTER BRANCH GIGALENS
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0)]


In [2]:
from gigalens.jax.inference import ModellingSequence
from gigalens.jax.model import ForwardProbModel, BackwardProbModel
from gigalens.model import PhysicalModel
from gigalens.jax.simulator import LensSimulator
from gigalens.simulator import SimulatorConfig
from gigalens.jax.profiles.light import sersic
from gigalens.jax.profiles.mass import epl, shear

import tensorflow_probability.substrates.jax as tfp
import jax
from jax import random
import numpy as np
import optax
from jax import numpy as jnp
from matplotlib import pyplot as plt
import optax
tfd = tfp.distributions

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Display the accelerator devices visible to JAX in this kernel.
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0)]

In [4]:
# Prior over the lens mass model. It has two named components, in this order:
# the first carries theta_E / gamma / ellipticity / centre parameters and the
# second carries the two shear parameters (presumably matching the EPL and
# Shear profiles used for the physical model -- confirm against that cell).
lens_prior = tfd.JointDistributionSequential(
    [
        tfd.JointDistributionNamed(
            {
                "theta_E": tfd.LogNormal(jnp.log(1.25), 0.25),
                "gamma": tfd.TruncatedNormal(2, 0.25, 1, 3),
                "e1": tfd.Normal(0, 0.1),
                "e2": tfd.Normal(0, 0.1),
                "center_x": tfd.Normal(0, 0.05),
                "center_y": tfd.Normal(0, 0.05),
            }
        ),
        tfd.JointDistributionNamed(
            {
                "gamma1": tfd.Normal(0, 0.05),
                "gamma2": tfd.Normal(0, 0.05),
            }
        ),
    ]
)

# Prior over the lens light: a single Sersic-style component.
lens_light_prior = tfd.JointDistributionSequential(
    [
        tfd.JointDistributionNamed(
            {
                "R_sersic": tfd.LogNormal(jnp.log(1.0), 0.15),
                "n_sersic": tfd.Uniform(2, 6),
                "e1": tfd.TruncatedNormal(0, 0.1, -0.3, 0.3),
                "e2": tfd.TruncatedNormal(0, 0.1, -0.3, 0.3),
                "center_x": tfd.Normal(0, 0.05),
                "center_y": tfd.Normal(0, 0.05),
                "Ie": tfd.LogNormal(jnp.log(500.0), 0.3),
            }
        )
    ]
)

# Prior over the source light: a single Sersic-style component with wider
# bounds on ellipticity and position than the lens light prior.
source_light_prior = tfd.JointDistributionSequential(
    [
        tfd.JointDistributionNamed(
            {
                "R_sersic": tfd.LogNormal(jnp.log(0.25), 0.15),
                "n_sersic": tfd.Uniform(0.5, 4),
                "e1": tfd.TruncatedNormal(0, 0.15, -0.5, 0.5),
                "e2": tfd.TruncatedNormal(0, 0.15, -0.5, 0.5),
                "center_x": tfd.Normal(0, 0.25),
                "center_y": tfd.Normal(0, 0.25),
                "Ie": tfd.LogNormal(jnp.log(150.0), 0.5),
            }
        )
    ]
)

# Full prior: (lens mass, lens light, source light), in that order.
prior = tfd.JointDistributionSequential(
    [lens_prior, lens_light_prior, source_light_prior]
)



In [5]:
# Resolve the demo assets relative to the gigalens package that is actually
# imported in this kernel, instead of an absolute path into one user's home
# directory. This keeps the notebook runnable from any checkout.
import os
import gigalens

assets_dir = os.path.join(list(gigalens.__path__)[0], 'assets')

# PSF kernel, cast to float32 before being handed to the simulator config.
kernel = np.load(os.path.join(assets_dir, 'psf.npy')).astype(np.float32)
sim_config = SimulatorConfig(delta_pix=0.065, num_pix=60, supersample=2, kernel=kernel)

# Physical model: lens mass profiles (EPL + shear), lens light, source light.
# NOTE(review): the meaning of the argument 50 passed to EPL is not visible here -- confirm.
phys_model = PhysicalModel(
    [epl.EPL(50), shear.Shear()],
    [sersic.SersicEllipse(use_lstsq=False)],
    [sersic.SersicEllipse(use_lstsq=False)],
)
lens_sim = LensSimulator(phys_model, sim_config, bs=1)

# Observed image to be modelled, and the probabilistic model that combines it
# with the prior and the noise settings.
observed_img = np.load(os.path.join(assets_dir, 'demo.npy'))
prob_model = ForwardProbModel(prior, observed_img, background_rms=0.2, exp_time=100)
model_seq = ModellingSequence(phys_model, prob_model, sim_config)

In [6]:
# Step-size schedule for the MAP stage: moves from -1e-2 to -1e-2/3 over 500
# steps following a polynomial with power 0.5. The values are negative,
# presumably so that the scaled updates descend the loss -- confirm in MAP.
schedule_fn = optax.polynomial_schedule(
    init_value=-1e-2,
    end_value=-1e-2 / 3,
    power=0.5,
    transition_steps=500,
)

# Adam-style rescaling followed by the scheduled step size.
map_transforms = (
    optax.scale_by_adam(),
    optax.scale_by_schedule(schedule_fn),
)
opt = optax.chain(*map_transforms)

#* Returns parameters for each test particle
map_estimate = model_seq.MAP(opt, seed=0)

Chi-squared: 0.986: 100%|██████████| 350/350 [00:20<00:00, 17.04it/s] 


In [7]:
#* Number of test particles returned by MAP (rows of map_estimate). The simulator
#* batch size is taken from it instead of a hard-coded 500, so the two cannot
#* drift apart if the number of MAP particles changes.
n_particles = map_estimate.shape[0]
#* Finds the log probability of each test particle
lps = prob_model.log_prob(LensSimulator(phys_model, sim_config, bs=n_particles), map_estimate)[0]
#* Get parameters for best test particle (kept 2-D, with a leading axis of length 1)
best = map_estimate[jnp.argmax(lps)][jnp.newaxis, :]

[[{'theta_E': Array([1.101779], dtype=float32),
   'gamma': Array([1.8973744], dtype=float32),
   'e2': Array([0.09508309], dtype=float32),
   'e1': Array([0.08808332], dtype=float32),
   'center_y': Array([-0.0022622], dtype=float32),
   'center_x': Array([0.10186656], dtype=float32)},
  {'gamma2': Array([0.0355338], dtype=float32),
   'gamma1': Array([-0.00309265], dtype=float32)}],
 [{'n_sersic': Array([2.6345026], dtype=float32),
   'e2': Array([0.14800864], dtype=float32),
   'e1': Array([0.09624715], dtype=float32),
   'center_y': Array([0.00048006], dtype=float32),
   'center_x': Array([0.0996528], dtype=float32),
   'R_sersic': Array([0.8305198], dtype=float32),
   'Ie': Array([461.96448], dtype=float32)}],
 [{'n_sersic': Array([1.6322263], dtype=float32),
   'e2': Array([0.00659892], dtype=float32),
   'e1': Array([-0.00838381], dtype=float32),
   'center_y': Array([-0.0588885], dtype=float32),
   'center_x': Array([0.09711999], dtype=float32),
   'R_sersic': Array([0.2291743]

In [8]:
# Print the parameter vector of the highest-log-probability MAP particle.
print(best)

[[ 1.0186656e-01 -2.2622033e-03  8.8083319e-02  9.5083088e-02
  -2.0597641e-01  9.6926227e-02 -3.0926461e-03  3.5533804e-02
   6.1354880e+00 -1.8570353e-01  9.9652804e-02  4.8005613e-04
   6.6513050e-01  1.0809889e+00 -1.6684896e+00  4.9255633e+00
  -1.4732724e+00  9.7119994e-02 -5.8888499e-02 -3.3538461e-02
   2.6397170e-02 -7.3776418e-01]]


In [9]:
# Display the full array of MAP results (one row per test particle).
map_estimate

Array([[ 1.00590356e-01, -1.88869960e-03,  8.27936754e-02, ...,
        -5.80707341e-02,  4.30759013e-01, -5.77628791e-01],
       [-2.95598179e-01, -2.29641661e-01,  2.70568859e-02, ...,
         3.01679850e-01, -5.91614366e-01,  6.18013978e-01],
       [ 6.01531826e-02,  7.43113011e-02, -1.11437380e-01, ...,
        -1.13873208e+00,  3.15286458e-01, -4.71661472e+00],
       ...,
       [ 1.67986706e-01, -5.49579784e-02, -5.11551261e-01, ...,
        -1.20699131e+00,  1.03165436e+00,  6.17846727e-01],
       [-4.13139649e-02,  7.59302795e-01, -2.79936194e-01, ...,
         3.08737326e+00,  1.03677392e+00, -2.52393317e+00],
       [ 4.93191183e-01, -4.98633198e-02, -6.32646829e-02, ...,
        -9.57388878e-01, -5.87239623e-01,  6.81307495e-01]],      dtype=float32)

In [None]:
# Reference only (nothing is executed in this cell): NCCL environment variables.
# NCCL_IB_DISABLE=1 disables NCCL's InfiniBand transport and NCCL_SOCKET_IFNAME
# selects network interfaces by name prefix. Presumably these are to be exported
# in the shell before the kernel starts if multi-GPU communication fails -- TODO confirm.
# NCCL_IB_DISABLE=1
# NCCL_SOCKET_IFNAME=enp

In [None]:
# Step-size schedule for the SVI stage: moves from -1e-6 to -3e-3 over 300
# steps following a polynomial with power 2, i.e. the magnitude ramps up.
schedule_fn = optax.polynomial_schedule(
    init_value=-1e-6,
    end_value=-3e-3,
    power=2,
    transition_steps=300,
)

# Adam-style rescaling followed by the scheduled step size.
svi_transforms = (
    optax.scale_by_adam(),
    optax.scale_by_schedule(schedule_fn),
)
opt = optax.chain(*svi_transforms)

# Run SVI starting from the best MAP particle; returns the fitted result `qz`
# and the loss history.
qz, loss_hist = model_seq.SVI(best, opt, n_vi=1000, num_steps=1500)

In [None]:
# Draw the SVI loss history on a fresh figure and write it to disk.
fig, ax = plt.subplots()
ax.plot(loss_hist)
fig.savefig("./output.jpg")

In [None]:
# Run HMC starting from the SVI result `qz`. Going by the keyword names, 250
# burn-in steps are followed by 750 retained results -- confirm in ModellingSequence.HMC.
samples = model_seq.HMC(qz, num_burnin_steps=250, num_results=750)

In [None]:
# Reorder the axes of the HMC states so that original axis 1 comes first,
# followed by original axes 2 and 0. potential_scale_reduction treats the
# leading axis as the sample index and the next `independent_chain_ndims` axes
# as independent chains.
# NOTE(review): assumes samples.all_states is 4-D with draws on axis 1 -- confirm.
states_reordered = jnp.transpose(samples.all_states, axes=(1, 2, 0, 3))
rhat = tfp.mcmc.potential_scale_reduction(
    states_reordered,
    independent_chain_ndims=2,
)

print(rhat)