In [None]:
## This document is meant to be a repository of useful classes and functions for a scattering model
#By Felix W.
#Jan 20, 2023

In [None]:
#General Suite – Same dependencies as for General Fitting
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
import astropy.units as u
import matplotlib 
from astropy import constants as const
import astropy as astr
import scipy as sci
import scipy.signal as signal
from astropy.time import Time

#PINT Init – Same dependencies as for General Fitting
import pint.fitter
import pint
from pint.models import get_model_and_toas, get_model
from pint.residuals import Residuals
from pint.toa import get_TOAs
import pint.logging
from pint.models.timing_model import (TimingModel,Component,DelayComponent,prefixParameter)
import pint.models.timing_model as timmdl
import pint.models.parameter as p
from pint.toa_select import TOASelect
from collections import OrderedDict

In [None]:
##Multiple Scattering Delay Component | SMX
#Works using a structure similar to the DMX component, but without a separate SM function

##Parameters
#power – Exponent of the frequency dependence: each range adds a delay of SMX_i / f**power (f in MHz): default = 4
#SMX_i - Local SM value
#SMXR1_i - Beginning of the SMX epoch
#SMXR2_i - End of the SMX epoch

#Note: Initialization adds a default SMX_0001 range with SMX = 0 and unset (None) MJD bounds

##Useful functions:
#add_SMX_range(mjd_start, mjd_end, index=None, smx=0, frozen=False)
#remove_SMX_range(index)

class ScatteringSMX(DelayComponent):
    """Piecewise-constant scattering delay organised like PINT's DMX ranges.

    Every range ``i`` holds three parameters: ``SMX_i`` (value, units
    ``s**(1 - power)``), ``SMXR1_i`` and ``SMXR2_i`` (MJD bounds). TOAs
    selected by a range receive the delay ``SMX_i / f**power``, with ``f`` in
    MHz and the result converted to seconds. ``f`` is the parent model's
    barycentric radio frequency when available, otherwise the topocentric
    ``freq`` column of the TOA table (a warning is issued).

    The units of ``SMX``/``SMX_i`` are fixed from ``power`` when those
    parameters are created, so ``power`` must keep that value for the unit
    conversion in ``smx_delay`` to succeed.

    NOTE(review): the reference parameter ``SMX`` is not part of the delay;
    only the ``SMX_i`` range values contribute.
    """

    register = True

    def __init__(self):
        super().__init__()
        # The previous spellings ``Frozen`` / ``longdouble`` are not
        # floatParameter's keyword names, so they had no effect.
        self.add_param(
            p.floatParameter(
                name="power",
                value=4,
                units=u.s / u.s,
                description="Scattering Power",
                long_double=False,
                frozen=True,
            )
        )
        self.add_param(
            p.floatParameter(
                name="SMX",
                value=0,
                units=u.s ** (1 - self.power.value),
                description="Scattering Measure Ref",
                long_double=True,
                frozen=True,
            )
        )
        # Placeholder range with unset bounds, mirroring the DMX component.
        self.add_SMX_range(None, None, smx=0, frozen=False, index=1)
        self.delay_funcs_component += [self.smx_delay]
        self.set_special_params(["SMX_0001", "SMXR1_0001", "SMXR2_0001"])

    def add_SMX_range(self, mjd_start, mjd_end, index=None, smx=0, frozen=False):
        """Add an SMX range (``SMX_i``, ``SMXR1_i``, ``SMXR2_i``).

        Parameters
        ----------
        mjd_start, mjd_end : float, astropy Time, astropy Quantity or None
            Range bounds; either both or neither must be given.
        index : int, optional
            Index of the new range. Defaults to one more than the largest
            existing SMX index, or 1 when there is none.
        smx : float or Quantity
            Initial SMX value; a Quantity is converted to ``s**(1 - power)``.
        frozen : bool
            Whether ``SMX_i`` starts frozen.

        Returns
        -------
        The index used for the new range.

        Raises
        ------
        ValueError
            If only one bound is set, the bounds are reversed, or the index
            is already in use.
        """
        smx_mapping = self.get_prefix_mapping_component("SMX_")
        if index is None:
            # ``default=0`` keeps this working after every range was removed.
            index = max(smx_mapping.keys(), default=0) + 1
        i = f"{int(index):04d}"

        # Normalise to plain MJD values first so the comparison below works
        # for mixed Time / Quantity / float inputs.
        mjd_start = self._to_mjd(mjd_start)
        mjd_end = self._to_mjd(mjd_end)

        if mjd_start is not None and mjd_end is not None:
            if mjd_end < mjd_start:
                raise ValueError("Starting MJD is greater than ending MJD.")
        elif mjd_start is not None or mjd_end is not None:
            raise ValueError("Only one MJD bound is set.")

        if int(index) in smx_mapping:
            raise ValueError(
                "Index '%s' is already in use in this model. Please choose another."
                % index
            )

        if isinstance(smx, u.quantity.Quantity):
            smx = smx.to_value(u.s ** (1 - self.power.value))

        self.add_param(
            prefixParameter(
                name="SMX_" + i,
                units=u.s ** (1 - self.power.value),
                value=smx,
                description="Scatter measure variation",
                parameter_type="float",
                frozen=frozen,
            )
        )
        self.add_param(
            prefixParameter(
                name="SMXR1_" + i,
                units="MJD",
                description="Beginning of SMX interval",
                parameter_type="MJD",
                time_scale="utc",
                value=mjd_start,
            )
        )
        self.add_param(
            prefixParameter(
                name="SMXR2_" + i,
                units="MJD",
                description="End of SMX interval",
                parameter_type="MJD",
                time_scale="utc",
                value=mjd_end,
            )
        )
        self.setup()
        self.validate()
        return index

    @staticmethod
    def _to_mjd(mjd):
        """Convert a Time/Quantity bound to a plain MJD value; pass others through."""
        if isinstance(mjd, Time):
            return mjd.mjd
        if isinstance(mjd, u.quantity.Quantity):
            return mjd.value
        return mjd

    def remove_SMX_range(self, index):
        """Remove all SMX parameters associated with one or more indices.

        Parameters
        ----------
        index : int, float, numpy scalar, list, tuple or np.ndarray
            SMX index, or a collection of indices, to remove from the model.

        Raises
        ------
        TypeError
            If ``index`` is not a number or a list/tuple/array of numbers.
        """
        if isinstance(index, (int, float, np.integer, np.floating)):
            indices = [index]
        elif isinstance(index, (list, tuple, np.ndarray)):
            indices = index
        else:
            raise TypeError(
                f"index must be a float, int, list, or array - not {type(index)}")
        for idx in indices:
            index_rf = f"{int(idx):04d}"
            for prefix in ("SMX_", "SMXR1_", "SMXR2_"):
                self.remove_param(prefix + index_rf)
        self.validate()

    def get_indices(self):
        """Return the SMX range indices (parameters named ``SMX_<i>``) as an array.

        Returns
        -------
        inds : np.ndarray
            Integer indices, in the order of ``self.params``.
        """
        # ``name`` rather than ``p`` so the module alias ``p`` is not shadowed.
        return np.array(
            [int(name.split("_")[-1]) for name in self.params if "SMX_" in name]
        )

    def setup(self):
        """Register delay-derivative functions for SMX, power and every SMX_i."""
        super().setup()
        self.register_deriv_funcs(self.d_delay_d_smx, "SMX")
        self.register_deriv_funcs(self.d_delay_d_power, "power")
        for prefix_par in self.get_params_of_type("prefixParameter"):
            if prefix_par.startswith("SMX_"):
                self.register_deriv_funcs(self.d_delay_d_smparam, prefix_par)

    def validate(self):
        """Check that SMX_, SMXR1_ and SMXR2_ parameters share the same indices.

        Raises
        ------
        ValueError
            If the index sets differ; the message lists the mismatched indices.
        """
        super().validate()
        smx_keys = set(self.get_prefix_mapping_component("SMX_"))
        for prefix in ("SMXR1_", "SMXR2_"):
            other_keys = set(self.get_prefix_mapping_component(prefix))
            if smx_keys != other_keys:
                raise ValueError(
                    f"SMX_ parameters do not match {prefix} parameters. "
                    "Please check your prefixed parameters. "
                    f"Mismatched indices: {sorted(smx_keys ^ other_keys)}"
                )

    def validate_toas(self, toas):
        """Raise MissingTOAs for free SMX_i whose range contains no TOAs.

        A free range with unset (None) bounds is reported as well, since it
        cannot select any TOA.
        """
        try:
            from pint.exceptions import MissingTOAs
        except ImportError:  # older PINT releases defined it here
            from pint.models.timing_model import MissingTOAs

        smx_mapping = self.get_prefix_mapping_component("SMX_")
        smxr1_mapping = self.get_prefix_mapping_component("SMXR1_")
        smxr2_mapping = self.get_prefix_mapping_component("SMXR2_")
        mjds = toas.get_mjds()  # loop-invariant
        bad_parameters = []
        for k in smxr1_mapping:
            if self._parent[smx_mapping[k]].frozen:
                continue
            r1 = self._parent[smxr1_mapping[k]].quantity
            r2 = self._parent[smxr2_mapping[k]].quantity
            if r1 is None or r2 is None:
                bad_parameters.append(smx_mapping[k])
                continue
            b = r1.mjd * u.d
            e = r2.mjd * u.d
            if not np.any((b <= mjds) & (mjds < e)):
                bad_parameters.append(smx_mapping[k])
        if bad_parameters:
            raise MissingTOAs(bad_parameters)

    def _radio_freq(self, toas):
        """Barycentric radio frequency per TOA, falling back to topocentric 'freq'."""
        try:
            return self._parent.barycentric_radio_freq(toas)
        except AttributeError:
            # ``warn`` was previously used without being imported (NameError).
            from warnings import warn

            warn("Using topocentric frequency for dedispersion!")
            return toas.table["freq"]

    def smx_sm(self, toas):
        """Return the SMX value applying to each TOA, as a Quantity array.

        TOAs outside every range, and ranges with unset bounds, contribute 0.
        """
        tbl = toas.table
        self.smx_toas_selector = TOASelect(is_range=True)
        smx_mapping = self.get_prefix_mapping_component("SMX_")
        smxr1_mapping = self.get_prefix_mapping_component("SMXR1_")
        smxr2_mapping = self.get_prefix_mapping_component("SMXR2_")
        condition = {}
        for epoch_ind, smx_name in smx_mapping.items():
            r1 = getattr(self, smxr1_mapping[epoch_ind]).quantity
            r2 = getattr(self, smxr2_mapping[epoch_ind]).quantity
            if r1 is None or r2 is None:
                # Unset bounds (e.g. the default SMX_0001) select no TOAs.
                continue
            condition[smx_name] = (r1.mjd, r2.mjd)
        select_idx = self.smx_toas_selector.get_select_index(
            condition, tbl["mjd_float"]
        )
        sm = np.zeros(len(tbl)) * self.SMX.units
        for smx_name, rows in select_idx.items():
            sm[rows] = getattr(self, smx_name).quantity
        return sm

    def smx_delay(self, toas, delay):
        """Scattering delay ``SMX_i / f**power`` per TOA (f in MHz), in seconds."""
        bfreq = self._radio_freq(toas).to(u.MHz)
        return (self.smx_sm(toas) / bfreq ** self.power.value).to(u.s)

    def d_delay_d_smparam(self, toas, param_name, delay):
        """Derivative of the delay w.r.t. one ``SMX_i`` parameter.

        ``1 / f**power`` (in ``s**power``) for TOAs inside range i, 0 elsewhere.
        """
        bfreq = self._radio_freq(toas).to(u.MHz)
        d_sm_d_param = self.d_sm_d_SMX(toas, param_name)
        return (d_sm_d_param / bfreq ** self.power.value).to(u.s ** self.power.value)

    def d_sm_d_SMX(self, toas, param_name, acc_delay=None):
        """Derivative of the per-TOA SM array w.r.t. ``param_name`` (an SMX_i).

        Returns a dimensionless Quantity array: 1 for TOAs selected by that
        parameter's range, 0 elsewhere (all zeros if its bounds are unset).
        """
        tbl = toas.table
        smx_index = getattr(self, param_name).index
        smxr1_mapping = self.get_prefix_mapping_component("SMXR1_")
        smxr2_mapping = self.get_prefix_mapping_component("SMXR2_")
        r1 = getattr(self, smxr1_mapping[smx_index]).quantity
        r2 = getattr(self, smxr2_mapping[smx_index]).quantity
        deriv = np.zeros(len(tbl))
        if r1 is not None and r2 is not None:
            selector = TOASelect(is_range=True)
            select_idx = selector.get_select_index(
                {param_name: (r1.mjd, r2.mjd)}, tbl["mjd_float"]
            )
            for rows in select_idx.values():
                deriv[rows] = 1.0
        return deriv * u.dimensionless_unscaled

    def d_delay_d_smx(self, toas, param, delay):
        """Derivative w.r.t. the reference ``SMX``: ``1 / f**power`` in ``s**power``.

        NOTE(review): ``smx_delay`` does not include ``SMX``, so the exact
        derivative is zero; this matters only if ``SMX`` (frozen by default)
        is unfrozen.
        """
        bfreq = self._radio_freq(toas).to(u.MHz)
        return (1 / bfreq ** self.power.value).to(u.s ** self.power.value)

    def d_delay_d_power(self, toas, param, delay):
        """Derivative of the SMX delay w.r.t. ``power``, in seconds.

        Since SMX_i is in ``s**(1 - power)``, ``smx_delay`` evaluates to
        ``SMX_i * (f/Hz)**(-power)`` seconds, so the derivative is
        ``-smx_delay * ln(f/Hz)``. The log therefore takes f in Hz (it
        previously used MHz). WIP: the SMX units are fixed at creation,
        so ``power`` cannot really be varied yet.
        """
        bfreq = self._radio_freq(toas)
        delay_s = (self.smx_sm(toas) / bfreq.to(u.MHz) ** self.power.value).to(u.s)
        return -delay_s * np.log(bfreq.to(u.Hz).value)

    def get_prefix_mapping_component(self, prefix):
        """Get the index mapping for the prefix parameters.

        Parameters
        ----------
        prefix : str
           Name of prefix.

        Returns
        -------
        OrderedDict
           Prefix-parameter index -> parameter name, sorted by index.
        """
        mapping = {}
        for parname in (x for x in self.params if x.startswith(prefix)):
            par = getattr(self, parname)
            if par.is_prefix and par.prefix == prefix:
                mapping[par.index] = parname
        return OrderedDict(sorted(mapping.items()))

In [None]:
##Useful Functions

def FreezeComp(model, comp):
    """Freeze every ``<comp>_<digits>`` parameter (e.g. all DMX_i or SMX_i).

    Parameters
    ----------
    model : pint TimingModel (anything with a ``free_params`` list property)
        Model modified in place.
    comp : str
        Prefix without underscore, e.g. ``'DMX'`` or ``'SMX'``. Only names of
        the form ``comp + '_' + digits`` are affected, so e.g. the reference
        ``DMX`` parameter and ``DMXR1_i`` bounds are left alone.

    Returns
    -------
    The same ``model`` object.
    """
    prefix = comp + "_"
    # Pattern match instead of a fixed 1..262 index loop, so every range is
    # frozen regardless of how many the model has.
    model.free_params = [
        name
        for name in model.free_params
        if not (name.startswith(prefix) and name[len(prefix):].isdigit())
    ]
    return model


def UnFreezeComp(model, comp):
    """Unfreeze every ``<comp>_<digits>`` parameter present in the model.

    Parameters
    ----------
    model : pint TimingModel (needs ``params`` and a ``free_params`` property)
        Model modified in place.
    comp : str
        Prefix without underscore, e.g. ``'DMX'`` or ``'SMX'``.

    Returns
    -------
    The same ``model`` object.
    """
    prefix = comp + "_"
    free = list(model.free_params)
    # Only add names the model really has: the previous fixed loop appended
    # comp_0001..comp_0262 unconditionally, and assigning non-existent names
    # to free_params is rejected by PINT.
    free += [
        name
        for name in model.params
        if name.startswith(prefix)
        and name[len(prefix):].isdigit()
        and name not in free
    ]
    model.free_params = free
    return model


def AddScatter(model, freeze):
    """Add a ScatteringSMX component whose ranges copy the model's DMX ranges.

    Each ``DMXR1_i`` / ``DMXR2_i`` pair becomes ``SMXR1_i`` / ``SMXR2_i`` with
    the same index ``i`` and ``SMX_i = 0``.

    Parameters
    ----------
    model : pint TimingModel
        Model that already contains DMX ranges; modified in place.
    freeze : bool
        Whether the added SMX_i parameters start frozen.

    Returns
    -------
    The same ``model`` object.

    Raises
    ------
    ValueError
        If ``model`` has no ``DMXR1_`` parameters to copy.
    """
    # Build the parameter dict once (it was rebuilt twice per range before).
    all_params = model.get_params_dict(which="all")
    prefix = "DMXR1_"
    # index -> original suffix text, so lookups use the model's exact names.
    dmx_suffixes = {
        int(name[len(prefix):]): name[len(prefix):]
        for name in all_params
        if name.startswith(prefix) and name[len(prefix):].isdigit()
    }
    if not dmx_suffixes:
        raise ValueError("Model has no DMX ranges to copy into SMX ranges.")

    scatteringx = Component.component_types["ScatteringSMX"]()
    model.add_component(scatteringx, validate=False)

    # Drop the placeholder SMX_0001 (unset bounds) so each SMX_i can take the
    # bounds of DMX range i, however many ranges the model has (the previous
    # code assumed exactly 262).
    model.remove_SMX_range(index=1)
    for idx in sorted(dmx_suffixes):
        suffix = dmx_suffixes[idx]
        model.add_SMX_range(
            all_params[prefix + suffix].value,
            all_params["DMXR2_" + suffix].value,
            index=idx,
            frozen=freeze,
        )
    return model