In [1]:
import numpy as np
from scipy import fftpack
import os
import subprocess
from astropy.io import fits
from astropy.table import Table, Column

In [None]:
################################################################################
def fits_in(in_file, meta_dict, test=False):
    """
    Read a FITS event list, split it into segments and return the Fourier
    transforms of the channel-of-interest and reference-band light curves of
    each segment.

    Events from PCU 2 form the channels of interest (CI); events from every
    other PCU form the reference band. A segment contains events with
    start_time <= time < seg_end_time, to avoid double-counting and/or
    skipping events. A segment in which either band has no events is
    skipped, and the next segment starts at the first time at which both
    bands have events again (handles gappy data).

    Parameters
    ----------
    in_file : str
        The full path of the FITS event-list file being analyzed. The TIME,
        CHANNEL and PCUID columns are read from extension 1.

    meta_dict : dict
        Meta-parameters for the analysis. Keys read here: 'n_bins' (must be
        a power of 2), 'detchans', 'n_seconds', 'dt', 'adjust_seg' and
        'ref_file'. The key 'obs_epoch' is written into it.

    test : boolean
        True to stop after the first segment (for testing), False to analyze
        the whole data file. Default=False

    Returns
    -------
    fft_ci_whole : np.array of complex
        FFT along the time axis of the mean-subtracted CI light curve of each
        segment, stacked along the last axis. Expected dimensions:
        [n_bins, detchans, n_seg] (assumes tools.make_2Dlightcurve returns a
        [n_bins, detchans] array -- TODO confirm).

    fft_ref_whole : np.array of complex
        FFT of the mean-subtracted, stacked reference band of each segment,
        reshaped to [n_bins, k] and stacked along the last axis. Dimensions:
        [n_bins, 1, n_seg] if stack_reference_band returns a 1-D light curve.

    n_seg : int
        Number of segments used.

    Raises
    ------
    ValueError
        If n_bins is not a power of 2, or the event list is empty.

    NotImplementedError
        If meta_dict['ref_file'] is set: reading the reference band from a
        separate rate file is not implemented in this function.

    IOError
        If the FITS file cannot be opened or read.

    """
    n_bins = meta_dict['n_bins']
    detchans = meta_dict['detchans']

    ## Explicit check instead of `assert`, which is stripped under `python -O`
    if not tools.power_of_two(n_bins):
        raise ValueError("ERROR: n_bins must be a power of 2.")

    ## The separate-reference-file path was never wired up (its rate/error
    ## arrays were always None), so fail clearly instead of deep in the loop.
    if meta_dict['ref_file']:
        raise NotImplementedError("Reading the reference band from a "
                "separate file is not supported by fits_in.")

    meta_dict['obs_epoch'] = tools.obs_epoch_rxte(in_file)

    print("Input file: %s" % in_file)

    ## How often (in segments) to print progress
    print_iterator = 1 if n_bins > 32768 else 10  #TODO: 20 for n_bins < 32768

    time, channel, pcuid = _read_event_list(in_file)
    if len(time) == 0:
        raise ValueError("ERROR: Event list is empty: %s" % in_file)

    ## Assumes the event list is time-ordered -- TODO confirm
    start_time = time[0]
    final_time = time[-1]

    ## Channels of interest come from PCU 2; reference band from all others
    is_pcu2 = pcuid == 2
    all_time_ci = np.asarray(time[is_pcu2], dtype=np.float64)
    all_energy_ci = np.asarray(channel[is_pcu2], dtype=np.float64)
    all_time_ref = np.asarray(time[~is_pcu2], dtype=np.float64)
    all_energy_ref = np.asarray(channel[~is_pcu2], dtype=np.float64)

    ## Extra length per segment, used to artificially line up the QPOs
    adjust = meta_dict['adjust_seg'] * meta_dict['dt']

    fft_ci_segs = []
    fft_ref_segs = []
    n_seg = 0
    seg_end_time = start_time + meta_dict['n_seconds']

    print("Segments computed:")

    while seg_end_time + adjust <= final_time:

        seg_end_time += adjust

        ## Take this segment's events and keep the rest for later iterations
        time_ci, energy_ci, all_time_ci, all_energy_ci = \
                _split_events(all_time_ci, all_energy_ci, seg_end_time)
        time_ref, energy_ref, all_time_ref, all_energy_ref = \
                _split_events(all_time_ref, all_energy_ref, seg_end_time)

        if len(time_ci) > 0 and len(time_ref) > 0:

            ## Populate the light curves for interest and reference bands
            rate_ci_2d = tools.make_2Dlightcurve(time_ci, energy_ci, n_bins,
                    detchans, start_time, seg_end_time)
            rate_ref_2d = tools.make_2Dlightcurve(time_ref, energy_ref,
                    n_bins, detchans, start_time, seg_end_time)
            rate_ref = stack_reference_band(rate_ref_2d, instrument="PCA",
                    obs_epoch=meta_dict['obs_epoch'])

            ## NOTE(review): FFTs are of mean-subtracted light curves (this
            ## only changes the zero-frequency bin); confirm this matches the
            ## downstream cross-spectrum code.
            fft_ci_segs.append(_mean_subtracted_fft(rate_ci_2d))
            fft_ref_segs.append(np.reshape(_mean_subtracted_fft(rate_ref),
                    (n_bins, -1)))
            n_seg += 1

            if n_seg % print_iterator == 0:
                print("\t", n_seg)

            if test is True and n_seg == 1:  # For testing
                break

            start_time = seg_end_time
            seg_end_time += meta_dict['n_seconds']

        else:
            ## Gap in the data. If either band has run out of events, no later
            ## segment can qualify; otherwise restart where both have events.
            if len(all_time_ci) == 0 or len(all_time_ref) == 0:
                break
            start_time = max(all_time_ci[0], all_time_ref[0])
            seg_end_time = start_time + meta_dict['n_seconds']

    ## Stack once at the end (repeated dstack in the loop is quadratic)
    if n_seg == 0:
        fft_ci_whole = np.zeros((n_bins, detchans, 0), dtype=np.complex128)
        fft_ref_whole = np.zeros((n_bins, 1, 0), dtype=np.complex128)
    else:
        fft_ci_whole = np.stack(fft_ci_segs, axis=-1)
        fft_ref_whole = np.stack(fft_ref_segs, axis=-1)

    return fft_ci_whole, fft_ref_whole, n_seg


def _read_event_list(in_file):
    """
    Read the TIME, CHANNEL and PCUID columns from extension 1 of a FITS file.

    The columns are copied so they stay valid after the file is closed.

    Raises
    ------
    IOError
        If the file cannot be opened or read.
    """
    try:
        with fits.open(in_file) as fits_hdu:
            data = fits_hdu[1].data  ## Data is in ext 1
            time = np.array(data.field('TIME'))
            channel = np.array(data.field('CHANNEL'))
            pcuid = np.array(data.field('PCUID'))
    except IOError as err:
        raise IOError("ERROR: Could not read file: %s" % in_file) from err
    return time, channel, pcuid


def _split_events(times, energies, boundary):
    """
    Split events at `boundary`.

    Returns the times and energies of events with time < boundary, followed
    by the times and energies of events with time >= boundary.
    """
    before = times < boundary
    after = times >= boundary
    return times[before], energies[before], times[after], energies[after]


def _mean_subtracted_fft(rate):
    """
    FFT along axis 0 of `rate` after subtracting its mean along axis 0.

    A 2-D input has each column's own mean subtracted.
    """
    rate = np.asarray(rate, dtype=np.float64)
    return fftpack.fft(rate - np.mean(rate, axis=0), axis=0)