In [1]:
import os
import sys
import time
import glob
import shutil
import argparse as argp
from functools import partial
import yaml
import numpy as np
import minkasi
#from jack_minkasi import minkasi
from astropy.coordinates import Angle
from astropy import units as u
import minkasi_jax.presets_by_source as pbs
from minkasi_jax.utils import *
from minkasi_jax import helper
from minkasi_jax.core import model

%load_ext autoreload
%autoreload 2

importing mpi4py


In [2]:
from functools import partial
import jax
import jax.numpy as jnp
from minkasi_jax.utils import fft_conv
from minkasi_jax.structure import (
    isobeta,
    gnfw,
    gaussian,
    add_uniform,
    add_exponential,
    add_powerlaw,
    add_powerlaw_cos,
)

import jax.lax as lax

N_PAR_ISOBETA = 9
N_PAR_GNFW = 14
N_PAR_GAUSSIAN = 9
N_PAR_UNIFORM = 8
N_PAR_EXPONENTIAL = 14
N_PAR_POWERLAW = 11

ARGNUM_SHIFT = 11

@partial(
    jax.jit,
    static_argnums=(1, 2, 3, 4, 5, 6, 7, 8),
)
def model(
    xyz,
    n_isobeta,
    n_gnfw,
    n_gaussian,
    n_uniform,
    n_exponential,
    n_powerlaw,
    n_powerlaw_cos,
    dx,
    beam,
    idx,
    idy,
    *params
):
    """
    Generically create models with substructure.

    Arguments:

        xyz: Coordinate grid to compute profile on.

        n_isobeta: Number of isobeta profiles to add.

        n_gnfw: Number of gnfw profiles to add.

        n_gaussian: Number of gaussians to add.

        n_uniform: Number of uniform ellipsoids to add.

        n_exponential: Number of exponential ellipsoids to add.

        n_powerlaw: Number of power law ellipsoids to add.

        n_powerlaw_cos: Number of radial power law ellipsoids with angulas cos term to add.

        dx: Factor to scale by while integrating.
            Since it is a global factor it can contain unit conversions.
            Historically equal to y2K_RJ * dr * da * XMpc / me.

        beam: Beam to convolve by, should be a 2d array.

        idx: RA TOD in units of pixels.
             Should have Dec stretch applied.

        idy: Dec TOD in units of pixels.

        params: 1D array of model parameters.

    Returns:

        model: The model with the specified substructure.
    """
    params = jnp.array(params)
    params = jnp.ravel(params)
    isobetas = jnp.zeros((1, 1), dtype=float)
    gnfws = jnp.zeros((1, 1), dtype=float)
    gaussians = jnp.zeros((1, 1), dtype=float)
    uniforms = jnp.zeros((1, 1), dtype=float)
    exponentials = jnp.zeros((1, 1), dtype=float)
    powerlaws = jnp.zeros((1, 1), dtype=float)

    start = 0
    if n_isobeta:
        delta = n_isobeta * N_PAR_ISOBETA
        #isobetas = lax.slice(params, (1, start), (1, start + delta)).reshape((n_isobeta, N_PAR_ISOBETA))
        isobetas = params[start : start + delta].reshape((n_isobeta, N_PAR_ISOBETA))
        start += delta
    if n_gnfw:
        delta = n_gnfw * N_PAR_GNFW
        gnfws = params[start : start + delta].reshape((n_gnfw, N_PAR_GNFW))
        start += delta
    if n_gaussian:
        delta = n_gaussian * N_PAR_GAUSSIAN
        gaussians = params[start : start + delta].reshape((n_gaussian, N_PAR_GAUSSIAN))
        start += delta
    if n_uniform:
        delta = n_uniform * N_PAR_UNIFORM
        uniforms = params[start : start + delta].reshape((n_uniform, N_PAR_UNIFORM))
        start += delta
    if n_exponential:
        delta = n_exponential * N_PAR_EXPONENTIAL
        exponentials = params[start : start + delta].reshape(
            (n_exponential, N_PAR_EXPONENTIAL)
        )
        start += delta
    if n_powerlaw:
        delta = n_powerlaw * N_PAR_POWERLAW
        powerlaws = params[start : start + delta].reshape((n_powerlaw, N_PAR_POWERLAW))
        start += delta
    if n_powerlaw_cos:
        delta = n_powerlaw_cos * N_PAR_POWERLAW
        powerlaw_coses = params[start : start + delta].reshape(
            (n_powerlaw_cos, N_PAR_POWERLAW)
        )
        start += delta

    pressure = jnp.zeros((xyz[0].shape[1], xyz[1].shape[0], xyz[2].shape[2]))
    for i in range(n_isobeta):
        pressure = jnp.add(pressure, isobeta(*isobetas[i], xyz))

    for i in range(n_gnfw):
        pressure = jnp.add(pressure, gnfw(*gnfws[i], xyz))

    for i in range(n_gaussian):
        pressure = jnp.add(pressure, gaussian(*gaussians[i], xyz))

    for i in range(n_uniform):
        pressure = add_uniform(pressure, xyz, *uniforms[i])

    for i in range(n_exponential):
        pressure = add_exponential(pressure, xyz, *exponentials[i])

    for i in range(n_powerlaw):
        pressure = add_powerlaw(pressure, xyz, *powerlaws[i])

    for i in range(n_powerlaw_cos):
        pressure = add_powerlaw_cos(pressure, xyz, *powerlaw_coses[i])

    # Integrate along line of site
    ip = jnp.trapz(pressure, dx=dx, axis=-1)

    bound0, bound1 = int((ip.shape[0] - beam.shape[0]) / 2), int(
        (ip.shape[1] - beam.shape[1]) / 2
    )
    beam = jnp.pad(
        beam,
        (
            (bound0, ip.shape[0] - beam.shape[0] - bound0),
            (bound1, ip.shape[1] - beam.shape[1] - bound1),
        ),
    )

    ip = fft_conv(ip, beam)

    # return jsp.ndimage.map_coordinates(ip, (idy, idx), order=0)
    return ip[idy.ravel(), idx.ravel()].reshape(idx.shape)

In [3]:
def print_once(*args):
    """
    Helper function to print only once when running with MPI.
    Only the rank 0 process will print.

    Arguments:

        *args: Arguments to pass to print.
    """
    if minkasi.myrank == 0:
        print(*args)
        sys.stdout.flush()


In [4]:
import time

def timer(f, *args):
    starttime = time.time()
    result = f(*args)
    endtime = time.time()
    return endtime - starttime

In [5]:
with open('/home/r/rbond/jorlo/dev/minkasi_jax/configs/ms0735_noSub.yaml', "r") as file:
    cfg = yaml.safe_load(file)
#with open('/home/r/rbond/jorlo/dev/minkasi_jax/configs/ms0735/ms0735.yaml', "r") as file:
#    cfg = yaml.safe_load(file)
fit = True

# Setup coordindate stuff
z = eval(str(cfg["coords"]["z"]))
da = get_da(z)
r_map = eval(str(cfg["coords"]["r_map"]))
dr = eval(str(cfg["coords"]["dr"]))
xyz = make_grid(r_map, dr)
coord_conv = eval(str(cfg["coords"]["conv_factor"]))
x0 = eval(str(cfg["coords"]["x0"]))
y0 = eval(str(cfg["coords"]["y0"]))

# Load TODs
tod_names = glob.glob(os.path.join(cfg["paths"]["tods"], cfg["paths"]["glob"]))
bad_tod, addtag = pbs.get_bad_tods(
    cfg["cluster"]["name"], ndo=cfg["paths"]["ndo"], odo=cfg["paths"]["odo"]
)
tod_names = minkasi.cut_blacklist(tod_names, bad_tod)
tod_names.sort()
tod_names = tod_names[minkasi.myrank :: minkasi.nproc]
print('tod #: ', len(tod_names))
minkasi.barrier()  # Is this needed?

#idx = np.random.randint(-307, 1058, (899454,))
#idy = np.random.randint(-307, 1058, (899454,))

tod #:  285


In [6]:
todvec = minkasi.TodVec()
n_tod = 10
for i, fname in enumerate(tod_names):
    if i >= n_tod: break
    dat = minkasi.read_tod_from_fits(fname)
    minkasi.truncate_tod(dat)

    # figure out a guess at common mode and (assumed) linear detector drifts/offset
    # drifts/offsets are removed, which is important for mode finding.  CM is *not* removed.
    dd, pred2, cm = minkasi.fit_cm_plus_poly(dat["dat_calib"], cm_ord=3, full_out=True)
    dat["dat_calib"] = dd
    dat["pred2"] = pred2
    dat["cm"] = cm

    # Make pixelized RA/Dec TODs
    idx, idy = tod_to_index(dat["dx"], dat["dy"], x0, y0, r_map, dr, coord_conv)
    idu, id_inv = np.unique(
        np.vstack((idx.ravel(), idy.ravel())), axis=1, return_inverse=True
    )
    dat["idx"] = idu[0]
    dat["idy"] = idu[1]
    dat["id_inv"] = id_inv
    dat["model_idx"] = idx
    dat["model_idy"] = idy

    tod = minkasi.Tod(dat)
    todvec.add_tod(tod)

nsamp and ndet are  182 24707.0 4496674  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s12.fits with lims  115.13548797090777 115.79331715025005 74.15935947949289 74.334082548651
truncating from  24707  to  24697
nsamp and ndet are  179 24708.0 4422732  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s13.fits with lims  115.13373944248026 115.79094024441889 74.15967366819471 74.33531881289076
truncating from  24708  to  24697
nsamp and ndet are  184 24708.0 4546272  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s14.fits with lims  115.1330017820499 115.7887955650195 74.15989906443733 74.33552371856587
truncating from  24708  to  24697
nsamp and ndet are  180 24708.0 4447440  on  /scratch/r/rbond/jorlo/MS0735/TS_EaCMS0f0_51_5_Oct_2021/Signal_TOD-AGBT19A_092_01-s15.fits with lims  115.13408095193876 115.79137737652576 74.15940046062792 74.3356056808359
truncating from  24

In [7]:
lims = todvec.lims()
pixsize = 2.0 / 3600 * np.pi / 180
skymap = minkasi.SkyMap(lims, pixsize)

Te = eval(str(cfg["cluster"]["Te"]))
freq = eval(str(cfg["cluster"]["freq"]))
beam = beam_double_gauss(
    dr,
    eval(str(cfg["beam"]["fwhm1"])),
    eval(str(cfg["beam"]["amp1"])),
    eval(str(cfg["beam"]["fwhm2"])),
    eval(str(cfg["beam"]["amp2"])),
)

funs = []
npars = []
labels = []
params = []
to_fit = []
priors = []
prior_vals = []
re_eval = []
par_idx = {}
for cur_model in cfg["models"].values():
    npars.append(len(cur_model["parameters"]))
    for name, par in cur_model["parameters"].items():
        labels.append(name)
        par_idx[name] = len(params)
        params.append(eval(str(par["value"])))
        to_fit.append(eval(str(par["to_fit"])))
        if "priors" in par:
            priors.append(par["priors"]["type"])
            prior_vals.append(eval(str(par["priors"]["value"])))
        else:
            priors.append(None)
            prior_vals.append(None)
        if "re_eval" in par and par["re_eval"]:
            re_eval.append(str(par["value"]))
        else:
            re_eval.append(False)
        2.627 * da, funs.append(eval(str(cur_model["func"])))

npars = np.array(npars)
labels = np.array(labels)
params = np.array(params)
to_fit = np.array(to_fit, dtype=bool)
priors = np.array(priors)

noise_class = eval(str(cfg["minkasi"]["noise"]["class"]))
noise_args = eval(str(cfg["minkasi"]["noise"]["args"]))
noise_kwargs = eval(str(cfg["minkasi"]["noise"]["kwargs"]))

In [8]:
#TODO: Implement tsBowl here 
if "bowling" in cfg:
    sub_poly = cfg["bowling"]["sub_poly"]

sim = False #This script is for simming, the option to turn off is here only for debugging
from minkasi_jax.core import model
#TODO: Write this to use minkasi_jax.core.model
for i, tod in enumerate(todvec.tods):
    
    temp_tod = tod.copy()
    if sim:
        pred = model(
            xyz,
            1,
            0,
            0,
            0,
            0,
            0,
            0,
            float(y2K_RJ(freq, Te)*dr*XMpc/me),
            beam,
            tod.info["idx"],
            tod.info["idy"],
            params
        )  

    if (sim) and ("id_inv" in tod.info):
        id_inv = tod.info["id_inv"]
        shape = tod.info["dx"].shape
        pred = pred[id_inv].reshape(shape)
        
    ipix = skymap.get_pix(tod)
    tod.info["ipix"] = ipix

    if sim:
        #Flip alternate TODs and add simulated profile on top
        if (i % 2) == 0:
            tod.info['dat_calib'] = -1*tod.info['dat_calib']
        else:
            tod.info['dat_calib'] = tod.info['dat_calib']

        tod.info['dat_calib'] = tod.info['dat_calib'] + np.array(pred)

In [9]:
# Figure out output
outdir = os.path.join(
    cfg["paths"]["outroot"],
    cfg["cluster"]["name"],
    "-".join(mn for mn in cfg["models"].keys()),
)
if "subdir" in cfg["paths"]:
    outdir = os.path.join(outdir, cfg["paths"]["subdir"])
if fit:
    outdir = os.path.join(outdir, "-".join(l for l in labels[to_fit]))
else:
    outdir = os.path.join(outdir, "not_fit")
if sub_poly:
    outdir += "-" + method + "_" + str(degree)
if sim:
    outdir += "-" + "sim"
print_once("Outputs can be found in", outdir)

Outputs can be found in /scratch/r/rbond/jorlo/Reductions/MS0735/double_isobeta_shock_bubbles-gauss/r1=r3/amp_1-amp_2-shock_val-b_ne_sup-b_sw_sup-sigma-amp


In [16]:
ids_inv = np.array([])
idxs = np.array([])
idys = np.array([])

for tod in todvec.tods:
    ids_inv = np.append(ids_inv, tod.info["id_inv"])
    idxs = np.append(idxs, tod.info["idx"])
    idys = np.append(idys, tod.info["idy"])
    
ids_inv = np.unique(ids_inv)


def sample(ids_inv):
    pred = pred = model(
            xyz,
            1,
            0,
            0,
            0,
            0,
            0,
            0,
            float(y2K_RJ(freq, Te)*dr*XMpc/me),
            beam,
            idxs,
            idys,
            params
        )  



In [None]:
def get_chis(dat, pred):
    resid = dat - pred
    pred[:,0]=0.5*pred[:,0]
    pred[:,-1]=0.5*pred[:,-1]
    pred_rot = np.dot(self.v,pred)
    predft = mkfftw.fft_r2r(pred_rot)
    nn=predft.shape[1]
    chisq = np.sum(self.mywt[:,:nn]*predft**2)

In [18]:
new_pars = jnp.array([0.0,0.0,0.0,.341,.249,.341,np.deg2rad(97),0.98,1e-5,
                      0.0,0.0,0.0,.167,.122,.167,np.deg2rad(97),8.93,1e-5])    
    

In [19]:
temp_tod = np.array(tod.info["dat_calib"])
idx, idy = tod.info["model_idx"], tod.info["model_idy"]
pred = model(xyz, 2, 0, 0, 0, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, new_pars)
jfft = jax.jit(jnp.fft.rfft)
 
def just_fft(tod, pred):
    resid = tod - pred
    _ = jfft(resid).block_until_ready()
vfft = jax.vmap(just_fft)

2023-08-15 09:29:13.543326: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:461] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.76GiB (rounded to 2965299968)requested by op 
2023-08-15 09:29:13.543478: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:472] ************************************________________________________________________________________
2023-08-15 09:29:13.543508: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2965299856 bytes.


RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2965299856 bytes.

In [None]:
vfft(temp_tod)

In [18]:

                      #0.0,0.0,0.0,.320,.320,.320,np.deg2rad(97),0.26])

idx, idy = np.array([]), np.array([])

for tod in todvec.tods:
    idx = np.append(idx, tod.info['model_idx'])
    idy = np.append(idy, tod.info['model_idy'])

print(idx.shape)

(43145659,)


In [22]:
id_inv = tod.info["id_inv"]
shape = tod.info["dx"].shape
#pred = pred[id_inv].reshape(shape)


In [20]:
starttime = time.time()
model(xyz, 2, 0, 0, 0, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, tod.info["model_idx"], tod.info["model_idy"], new_pars).block_until_ready()
endtime = time.time()
    
print(endtime-starttime)

0.04680347442626953


In [52]:
import time
import numpy as np

import jax
from jax import numpy as jnp

from minkasi import mkfftw

residuals = tod.info["dat_calib"] - model(xyz, 2, 0, 0, 0, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, params)

jfft = jax.jit(jnp.fft.rfft)

R = 100

ts = time.time()
for i in range(R):
    _ = jfft(residuals).block_until_ready()
print('jax fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

jax fft execution time [ms]:	 1.3693642616271973


In [40]:
fftw_res = np.array(residuals)

R = 100

ts = time.time()
for i in range(R):
    _ = mkfftw.fft_r2r(fftw_res)
print('mkfftw fft execution time [ms]:\t', (time.time()-ts)/R * 1000)



mkfftw fft execution time [ms]:	 4.868004322052002


In [None]:
start = 0
delta = 9

In [None]:
jnp.array(params)[start : start + delta].shape

In [None]:
params = jnp.array(params)
params[start : start + delta].reshape((1, 9))

In [None]:
partial(helper, xyz=xyz, dx=float(y2K_RJ(freq, Te)*dr*XMpc/me), beam=beam, argnums=np.where(to_fit)[0], re_eval=re_eval, par_idx=par_idx, n_isobeta=2, n_gnfw=0, n_uniform=3, n_exponential=0)

In [13]:
runtime1 = timer(helper, xyz, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, np.where(to_fit)[0], re_eval, par_idx, 2, 0, 3, 0)
runtime2 = timer(helper, xyz, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, np.where(to_fit)[0], re_eval, par_idx, 2, 0, 3, 0)
compiletime = runtime1 - runtime2
print(f"runtime: {runtime2} compiletime: {compiletime}")

AttributeError: 'float' object has no attribute 'info'

In [17]:
starttime = time.time()
model(xyz, 2, 0, 0, 1, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, new_pars).block_until_ready()
endtime = time.time()

print(endtime-starttime)

2023-08-14 13:25:45.026716: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:461] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.76GiB (rounded to 2961169920)requested by op 
2023-08-14 13:25:45.026929: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:472] *___________________________________________________________________________________________________
2023-08-14 13:25:45.026989: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2961169856 bytes.


RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2961169856 bytes.

In [None]:
R = 100

ts = time.time()
for i in range(R):
    model(xyz, 2, 0, 0, 1, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, new_pars).block_until_ready()
print('jax fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

(718, 718, 718)

In [22]:
jnp.zeros((xyz[0].shape[1], xyz[1].shape[0], xyz[2].shape[2]))

2023-08-14 13:27:14.483879: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:461] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.76GiB (rounded to 2961169920)requested by op 
2023-08-14 13:27:14.484083: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:472] **__________________________________________________________________________________________________
2023-08-14 13:27:14.484131: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2961169856 bytes.


RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2961169856 bytes.

In [13]:
idx.shape

(164, 24697)

In [12]:
tods = []
ids_inv = np.array([])
for tod in todvec.tods:
    tods.append([tod.info["model_idx"], tod.info["model_idy"], tod.info["dat_calib"]])
    

In [21]:
def sample(tods):
    pred = model(xyz, 2, 0, 0, 0, 0, 0, 0, float(y2K_RJ(freq, Te)*dr*XMpc/me), beam, idx, idy, new_pars)
    for tod in tods:
        idx, idy, dat = tod[0], tod[1], tod[2]
        
        resid = dat - pred
        _ = jfft(resid).block_until_ready() 


vsample = jax.vmap(sample)

In [22]:
jfft = jax.jit(jnp.fft.rfft)
new_pars = jnp.array([0.0,0.0,0.0,.341,.249,.341,np.deg2rad(97),0.98,1e-5,
                      0.0,0.0,0.0,.167,.122,.167,np.deg2rad(97),8.93,1e-5])
                      #0.0,0.0,0.0,.320,.320,.320,np.deg2rad(97),0.26])
R = 100

id_inv = tod.info["id_inv"]
shape = tod.info["dx"].shape
pred = pred[id_inv].reshape(shape)

ts = time.time()
for i in range(R):
    vsample(tods)
print('jax fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([[182, 182, 182], [179, 179, 179], [184, 184, 184], [180, 180, 180], [177, 177, 177], [180, 180, 180], [161, 161, 161], [164, 164, 164], [176, 176, 176], [164, 164, 164]],)

In [26]:
tods[0]

[array([[ 502,  502,  502, ...,   -2,   -3,   -4],
        [ 529,  528,  528, ...,   21,   21,   20],
        [ 559,  559,  559, ...,   54,   54,   53],
        ...,
        [ 411,  411,  411, ...,  -99,  -99, -100],
        [ 496,  496,  496, ...,  -15,  -15,  -16],
        [ 470,  470,  470, ...,  -39,  -39,  -40]]),
 array([[921, 922, 923, ..., 576, 575, 575],
        [871, 871, 872, ..., 524, 523, 523],
        [919, 919, 920, ..., 570, 570, 569],
        ...,
        [810, 811, 812, ..., 470, 469, 468],
        [790, 790, 791, ..., 445, 444, 443],
        [840, 841, 842, ..., 497, 496, 495]]),
 array([[ 0.72333378,  0.71857547,  0.7030193 , ..., -1.16568549,
         -1.16030384, -1.15937239],
        [ 0.66077794,  0.65884039,  0.65510284, ..., -1.03088957,
         -1.06317214, -1.07669928],
        [ 0.73743407,  0.74065586,  0.73807716, ..., -1.14617984,
         -1.16268876, -1.16181935],
        ...,
        [ 0.72745254,  0.73138371,  0.72921792, ..., -1.17590974,
         

In [28]:
tod.info["idx"].shape

(899454,)

362467097

In [None]:
pars_fit = params
if fit:
    t1 = time.time()
    print_once("Started actual fitting")
    pars_fit, chisq, curve, errs = minkasi.fit_timestreams_with_derivs_manyfun(
        funs,
        params,
        npars,
        todvec,
        to_fit,
        maxiter=cfg["minkasi"]["maxiter"],
        priors=priors,
        prior_vals=prior_vals,
    )
    minkasi.comm.barrier()
    t2 = time.time()
    print_once("Took", t2 - t1, "seconds to fit")

    for i, re in enumerate(re_eval):
        if not re:
            continue
        pars_fit[i] = eval(re)

    print_once("Fit parameters:")
    for l, pf, err in zip(labels, pars_fit, errs):
        print_once("\t", l, "=", pf, "+/-", err)
    print_once("chisq =", chisq)

    if minkasi.myrank == 0:
        res_path = os.path.join(outdir, "results")
        print_once("Saving results to", res_path + ".npz")
        np.savez_compressed(
            res_path, pars_fit=pars_fit, chisq=chisq, errs=errs, curve=curve
        )