# CHARIS MCMC

To speed things up, we use jax functions.

In [1]:
import multiprocessing as mp
import os
mp.set_start_method("spawn", force=True) # Jax was slowing down from os.fork() and this fixed it
os.environ["JAX_PLATFORM_NAME"] = "cpu" # Jax wasn't working with our GPU for unknown reasons
import sys
import numpy as np
from pathlib import Path
parent_dir = Path.cwd().parent
sys.path.append(str(parent_dir))
import re
from csv_tools import read_csv_physical_model_all_bins,arr_csv_HWP
from instruments_jax import *
from utils import generate_system_mueller_matrix
from pyMuellerMat.physical_models.charis_physical_models import HWP_retardance,IMR_retardance
from scipy import stats as scipy_stats
import h5py
import corner
import shutil
import tqdm

First, I'm going to read the CSVs for all wavelength bins at once so we can do a global fit. 

In [2]:
csv_dir = Path('datacsvs/csvs_pickoff')
interleaved_values_all, interleaved_stds_all, configuration_list_all = read_csv_physical_model_all_bins(csv_dir)


Now, I'm going to set everything up for MCMC. 

In [5]:
# Getting the system dictionary setup and defining starting guesses values
wavelength_bins = np.array([1159.5614, 1199.6971, 1241.2219, 1284.184 , 1328.6331, 1374.6208,
1422.2002, 1471.4264, 1522.3565, 1575.0495, 1629.5663, 1685.9701,
1744.3261, 1804.7021, 1867.1678, 1931.7956, 1998.6603, 2067.8395,
2139.4131, 2213.4641, 2290.0781, 2369.3441])
wavelength_bin = 15 # placeholder
epsilon_cal = 1 # defining as perfect, reasoning in Joost t Hart 2021
offset_imr = 0 
offset_hwp = 0
offset_cal = 0
imr_theta = 0 # placeholder 
hwp_theta = 0 # placeholder
# Past fits from scipy minimize on the naive fits
d = 259.7 
wsio2 = 1.617
wmgf2 = 1.264

# Define instrument configuration as system dictionary
# Wollaston beam, imr theta/phi, and hwp theta/phi will all be updated within functions, so don't worry about their values here

system_dict = {
"components" : {
    "wollaston" : {
        "type" : "wollaston_prism_function",
        "properties" : {"beam": 'o'}, 
        "tag": "internal",
    },
    "image_rotator" : {
        "type" : "naive_IMR_function",
        "properties" : {"wavelength":wavelength_bins[wavelength_bin]},
        "tag": "internal",
    },
    "hwp" : {
        "type" : "two_layer_HWP_function",
        "properties" : {"wavelength": wavelength_bins[wavelength_bin], "w_SiO2": wsio2, "w_MgF2": wmgf2, "theta":hwp_theta, "delta_theta": offset_hwp},
        "tag": "internal",
    },
    "lprot": { # changed from delta_theta to match Joost t Hart
        "type": "rotator_function",
        "properties" : {'pa':offset_cal},
        "tag": "internal",
    },
    "lp" : {  # calibration polarizer for internal calibration source
        "type": "diattenuator_retarder_function",
        "properties": {"epsilon": epsilon_cal},
        "tag": "internal",
    }}
}
    
# Starting guesses

p0_dict = {
    "image_rotator" : 
        {"delta_theta": offset_imr},
    "hwp" :  
        {"w_SiO2": wsio2, "w_MgF2": wmgf2,"delta_theta": offset_hwp},
    "lprot" : 
        {"pa": offset_cal}
}

system_mm = generate_system_mueller_matrix(system_dict) # Generating pyMuellerMat system MM

p0 = [1.623, 1.268, 262.56] # Starting guesses from Joost t Hart 2021 
offset_bounds = (-5.0,5.0) # Wider bounds than first MCMC run, offsets were converging to their bounds at first
d_bounds = (0.8*p0[2], 1.2*p0[2]) # Physical parameters shouldn't have changed much
imr_offset_bounds = offset_bounds
wsio2_bounds = (0.8*p0[0], 1.2*p0[0])
wmgf2_bounds = (0.8*p0[1], 1.2*p0[1])
hwp_offset_bounds = offset_bounds
cal_offset_bounds = offset_bounds

bounds = {
    "image_rotator" : 
        {"delta_theta": offset_bounds},
    "hwp" :  
        {"w_SiO2": wsio2_bounds, "w_MgF2": wmgf2_bounds, "delta_theta": offset_bounds},
    "lprot" : 
        {"pa": offset_bounds}
}

# define priors
prior_dict = {
    "image_rotator": {
        "delta_theta": {"type": "uniform", "kwargs": {"low":-5, "high": 5}},
    },
    "hwp": {
        "w_SiO2": {"type": "uniform", "kwargs": {"low": 0.8*p0[0], "high": 1.2*p0[0]}},
        "w_MgF2":{"type": "uniform", "kwargs": {"low": 0.8*p0[1], "high": 1.2*p0[1]}},
        "delta_theta": {"type": "uniform", "kwargs": {"low":-5, "high": 5}},
    },
    "lprot": {
        "pa": {"type": "uniform", "kwargs": {"low":-5, "high": 5}},
}}

# backend h5
output_h5 = Path('mcmc_output_dds_naive_imr_uniform_priors.h5')

In [6]:
# Minimize everything globally with least squares
    # Counters for iterative fitting
from vampires_calibration.fitting import minimize_system_mueller_matrix

iteration = 1
previous_logl = 1000000
new_logl = 0
boundslist = [d_bounds,offset_bounds,wsio2_bounds,wmgf2_bounds,offset_bounds,offset_bounds]
interleaved_difs = interleaved_values_all
difer = interleaved_stds_all

    # Perform iterative fitting
    # MODIFY THE BOUNDS INPUT HERE IF YOU WANT TO CHANGE PARAMETERS
while abs(previous_logl - new_logl) > 0.01*abs(previous_logl):
    if iteration > 1:
        previous_logl = new_logl
    # Configuring minimization function for CHARIS
    result, new_logl,err = minimize_system_mueller_matrix(p0_dict, system_mm, interleaved_difs, 
         configuration_list_all,s_in=[1,0,0,0], process_dataset=process_dataset,process_model=process_model,process_errors=process_errors,bounds = boundslist,include_sums=False,mode='least_squares')
    print(result)

    # Update p0 with new values

    update_p0(p0_dict, result.x)
    iteration += 1

   Iteration     Total nfev        Cost      Cost reduction    Step norm     Optimality   
       0              1         4.2030e+00                                    8.55e+00    
       1              2         2.9950e+00      1.21e+00       1.07e+00       1.36e+00    
       2              3         2.9168e+00      7.81e-02       5.14e-01       8.41e-01    
       3              4         2.9117e+00      5.18e-03       3.32e-01       2.01e-01    
       4              5         2.9084e+00      3.30e-03       1.63e+00       3.75e-01    
       5              6         2.9069e+00      1.46e-03       1.67e+00       4.12e-01    
       6              7         2.9066e+00      2.91e-04       6.63e-01       1.74e-01    
       7              8         2.9065e+00      6.90e-05       2.72e-01       7.63e-02    
       8              9         2.9065e+00      1.10e-05       1.02e-01       3.01e-02    
       9             10         2.9065e+00      1.54e-06       4.34e-02       1.18e-02    

Now I'm going to run the simulation. 

In [6]:
# Interactive plotting
%matplotlib notebook 
ndim = 6  # Number of parameters to fit
# Detect computing resources
pool_processes = max(1, 12) # Leaving one free
nwalkers = max(2 * ndim, pool_processes * 2)
if nwalkers % pool_processes != 0:
    nwalkers += pool_processes - (nwalkers % pool_processes)

print(f"Auto-detected: {pool_processes} processes, {nwalkers} walkers for {ndim} parameters")
sampler, p_keys = run_mcmc(p0_dict, system_mm, interleaved_values_all,configuration_list_all,prior_dict,bounds,output_h5,nwalkers=nwalkers,log_f=-3.0, pool_processes=pool_processes,process_model=process_model, process_errors=process_errors,process_dataset=process_dataset,nsteps=40000,plot=True, include_sums=False)

Auto-detected: 12 processes, 24 walkers for 6 parameters


<IPython.core.display.Javascript object>

 29%|██▉       | 11656/40000 [2:00:04<4:52:00,  1.62it/s]


emcee: Exception while calling your likelihood function:emcee: Exception while calling your likelihood function:emcee: Exception while calling your likelihood function:
emcee: Exception while calling your likelihood function:  params:

   params:  params:
 emcee: Exception while calling your likelihood function:  params:  
emcee: Exception while calling your likelihood function:  params:emcee: Exception while calling your likelihood function: emcee: Exception while calling your likelihood function:

  params:  params:
    params: emcee: Exception while calling your likelihood function:emcee: Exception while calling your likelihood function:

  params:  params:  emcee: Exception while calling your likelihood function:emcee: Exception while calling your likelihood function:

  params:  params:  [ 1.96358807e+000  1.43763810e+000  1.13368382e+000 -4.74553618e-001
 -3.46841728e+000  7.17592084e+127][ 8.30696462e+000  1.02320549e+000  8.35769570e-001 -1.57458125e+000
  9.79266841e+000 -8.19

KeyboardInterrupt: 