In [1]:
import os
import sys
import time
import glob
import shutil
import argparse as argp
from functools import partial
import yaml
import numpy as np
import minkasi
#from jack_minkasi import minkasi
from astropy.coordinates import Angle
from astropy import units as u
import minkasi_jax.presets_by_source as pbs
from minkasi_jax.utils import *
from minkasi_jax import helper
from minkasi_jax.core import model

%load_ext autoreload
%autoreload 2

importing mpi4py


In [2]:
def print_once(*args):
    """
    Helper function to print only once when running with MPI.
    Only the rank 0 process will print.

    Arguments:

        *args: Arguments to pass to print.
    """
    if minkasi.myrank == 0:
        print(*args)
        sys.stdout.flush()


In [3]:
import time

def timer(f, *args):
    starttime = time.time()
    result = f(*args)
    endtime = time.time()
    return endtime - starttime

In [4]:
with open('/home/r/rbond/jorlo/dev/minkasi_jax/configs/ms0735_noSub.yaml', "r") as file:
    cfg = yaml.safe_load(file)
fit = True

# Setup coordindate stuff
z = eval(str(cfg["coords"]["z"]))
da = get_da(z)
r_map = eval(str(cfg["coords"]["r_map"]))
dr = eval(str(cfg["coords"]["dr"]))
xyz = make_grid(r_map, dr)
coord_conv = eval(str(cfg["coords"]["conv_factor"]))
x0 = eval(str(cfg["coords"]["x0"]))
y0 = eval(str(cfg["coords"]["y0"]))

# Load TODs
tod_names = glob.glob(os.path.join(cfg["paths"]["tods"], cfg["paths"]["glob"]))
bad_tod, addtag = pbs.get_bad_tods(
    cfg["cluster"]["name"], ndo=cfg["paths"]["ndo"], odo=cfg["paths"]["odo"]
)
tod_names = minkasi.cut_blacklist(tod_names, bad_tod)
tod_names.sort()
tod_names = tod_names[minkasi.myrank :: minkasi.nproc]
print('tod #: ', len(tod_names))
minkasi.barrier()  # Is this needed?

#idx = np.random.randint(-307, 1058, (899454,))
#idy = np.random.randint(-307, 1058, (899454,))

tod #:  285


In [5]:
todvec = minkasi.TodVec()
n_tod = 10
for i, fname in enumerate(tod_names):
    if i >= n_tod: break
    dat = minkasi.read_tod_from_fits(fname)
    minkasi.truncate_tod(dat)

    # figure out a guess at common mode and (assumed) linear detector drifts/offset
    # drifts/offsets are removed, which is important for mode finding.  CM is *not* removed.
    dd, pred2, cm = minkasi.fit_cm_plus_poly(dat["dat_calib"], cm_ord=3, full_out=True)
    dat["dat_calib"] = dd
    dat["pred2"] = pred2
    dat["cm"] = cm

    # Make pixelized RA/Dec TODs
    idx, idy = tod_to_index(dat["dx"], dat["dy"], x0, y0, r_map, dr, coord_conv)
    idu, id_inv = np.unique(
        np.vstack((idx.ravel(), idy.ravel())), axis=1, return_inverse=True
    )
    dat["idx"] = idu[0]
    dat["idy"] = idu[1]
    dat["id_inv"] = id_inv

    tod = minkasi.Tod(dat)
    todvec.add_tod(tod)

nsamp and ndet are  182 24707.0 4496674  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s12.fits with lims  115.13548797090777 115.79331715025005 74.15935947949289 74.334082548651
truncating from  24707  to  24697
nsamp and ndet are  179 24708.0 4422732  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s13.fits with lims  115.13373944248026 115.79094024441889 74.15967366819471 74.33531881289076
truncating from  24708  to  24697
nsamp and ndet are  184 24708.0 4546272  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s14.fits with lims  115.1330017820499 115.7887955650195 74.15989906443733 74.33552371856587
truncating from  24708  to  24697
nsamp and ndet are  180 24708.0 4447440  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s15.fits with lims  115.13408095193876 115.79137737652576 74.15940046062792 74.3356056808359
truncating from  24

In [6]:
lims = todvec.lims()
pixsize = 2.0 / 3600 * np.pi / 180
skymap = minkasi.SkyMap(lims, pixsize)

Te = eval(str(cfg["cluster"]["Te"]))
freq = eval(str(cfg["cluster"]["freq"]))
beam = beam_double_gauss(
    dr,
    eval(str(cfg["beam"]["fwhm1"])),
    eval(str(cfg["beam"]["amp1"])),
    eval(str(cfg["beam"]["fwhm2"])),
    eval(str(cfg["beam"]["amp2"])),
)

funs = []
npars = []
labels = []
params = []
to_fit = []
priors = []
prior_vals = []
re_eval = []
par_idx = {}
for cur_model in cfg["models"].values():
    npars.append(len(cur_model["parameters"]))
    for name, par in cur_model["parameters"].items():
        labels.append(name)
        par_idx[name] = len(params)
        params.append(eval(str(par["value"])))
        to_fit.append(eval(str(par["to_fit"])))
        if "priors" in par:
            priors.append(par["priors"]["type"])
            prior_vals.append(eval(str(par["priors"]["value"])))
        else:
            priors.append(None)
            prior_vals.append(None)
        if "re_eval" in par and par["re_eval"]:
            re_eval.append(str(par["value"]))
        else:
            re_eval.append(False)
        2.627 * da, funs.append(eval(str(cur_model["func"])))

npars = np.array(npars)
labels = np.array(labels)
params = np.array(params)
to_fit = np.array(to_fit, dtype=bool)
priors = np.array(priors)

noise_class = eval(str(cfg["minkasi"]["noise"]["class"]))
noise_args = eval(str(cfg["minkasi"]["noise"]["args"]))
noise_kwargs = eval(str(cfg["minkasi"]["noise"]["kwargs"]))

In [7]:
#TODO: Implement tsBowl here 
if "bowling" in cfg:
    sub_poly = cfg["bowling"]["sub_poly"]

sim = False #This script is for simming, the option to turn off is here only for debugging
from minkasi_jax.core import model
#TODO: Write this to use minkasi_jax.core.model
for i, tod in enumerate(todvec.tods):
    
    temp_tod = tod.copy()
    if sim:
        pred = model(
            xyz,
            1,
            0,
            0,
            0,
            0,
            0,
            0,
            float(y2K_RJ(freq, Te)*dr*XMpc/me),
            beam,
            tod.info["idx"],
            tod.info["idy"],
            params
        )  

    if (sim) and ("id_inv" in tod.info):
        id_inv = tod.info["id_inv"]
        shape = tod.info["dx"].shape
        pred = pred[id_inv].reshape(shape)
        
    ipix = skymap.get_pix(tod)
    tod.info["ipix"] = ipix

    if sim:
        #Flip alternate TODs and add simulated profile on top
        if (i % 2) == 0:
            tod.info['dat_calib'] = -1*tod.info['dat_calib']
        else:
            tod.info['dat_calib'] = tod.info['dat_calib']

        tod.info['dat_calib'] = tod.info['dat_calib'] + np.array(pred)

In [8]:
# Figure out output
outdir = os.path.join(
    cfg["paths"]["outroot"],
    cfg["cluster"]["name"],
    "-".join(mn for mn in cfg["models"].keys()),
)
if "subdir" in cfg["paths"]:
    outdir = os.path.join(outdir, cfg["paths"]["subdir"])
if fit:
    outdir = os.path.join(outdir, "-".join(l for l in labels[to_fit]))
else:
    outdir = os.path.join(outdir, "not_fit")
if sub_poly:
    outdir += "-" + method + "_" + str(degree)
if sim:
    outdir += "-" + "sim"
print_once("Outputs can be found in", outdir)

Outputs can be found in /scratch/r/rbond/jorlo/Reductions/MS0735/double_isobeta_shock_bubbles-gauss/r1=r3/amp_1-amp_2-shock_val-b_ne_sup-b_sw_sup-sigma-amp


In [10]:
pars_fit = params
if fit:
    t1 = time.time()
    print_once("Started actual fitting")
    pars_fit, chisq, curve, errs = minkasi.fit_timestreams_with_derivs_manyfun(
        funs,
        params,
        npars,
        todvec,
        to_fit,
        maxiter=cfg["minkasi"]["maxiter"],
        priors=priors,
        prior_vals=prior_vals,
    )
    minkasi.comm.barrier()
    t2 = time.time()
    print_once("Took", t2 - t1, "seconds to fit")

    for i, re in enumerate(re_eval):
        if not re:
            continue
        pars_fit[i] = eval(re)

    print_once("Fit parameters:")
    for l, pf, err in zip(labels, pars_fit, errs):
        print_once("\t", l, "=", pf, "+/-", err)
    print_once("chisq =", chisq)

    if minkasi.myrank == 0:
        res_path = os.path.join(outdir, "results")
        print_once("Saving results to", res_path + ".npz")
        np.savez_compressed(
            res_path, pars_fit=pars_fit, chisq=chisq, errs=errs, curve=curve
        )

Started actual fitting


ValueError: at least one array or dtype is required

In [None]:
#params_extended = [0.0, 0.0, 0.0
model(
    xyz,
    1,
    0,
    0,
    1,
    0,
    0,
    0,
    float(y2K_RJ(freq, Te)*dr*XMpc/me),
    beam,
    idx,
    idy,
    params
)

In [None]:
runtime1 = timer(model, xyz, 1, 0, 0, 0, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, params) 
runtime2 = timer(model, xyz, 1, 0, 0, 0, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, params) 
compiletime = runtime1 - runtime2
print(f"runtime: {runtime2} compiletime: {compiletime}")

In [None]:
start = 0
delta = 9

In [None]:
jnp.array(params)[start : start + delta].shape

In [None]:
params = jnp.array(params)
params[start : start + delta].reshape((1, 9))

In [None]:
partial(helper, xyz=xyz, dx=float(y2K_RJ(freq, Te)*dr*XMpc/me), beam=beam, argnums=np.where(to_fit)[0], re_eval=re_eval, par_idx=par_idx, n_isobeta=2, n_gnfw=0, n_uniform=3, n_exponential=0)

In [None]:
runtime1 = timer(helper, xyz, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, np.where(to_fit)[0], re_eval, par_idx, 2, 0, 3, 0)
runtime2 = timer(helper, xyz, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, np.where(to_fit)[0], re_eval, par_idx, 2, 0, 3, 0)
compiletime = runtime1 - runtime2
print(f"runtime: {runtime2} compiletime: {compiletime}")