In [96]:
print("Preparing environment...")

import os
import platform
from itertools import chain

import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from astropy import constants as c
from astropy import table as tbl
from astropy import units as u
from astropy.io import fits
from matplotlib.backends.backend_pdf import PdfPages
from scipy import interpolate, odr
from specutils.utils import wcs_utils as suw
from tqdm.auto import tqdm

# %matplotlib widget
plt.rcParams["savefig.facecolor"] = "w"

print("Done.")

Preparing environment...
Done.


In [99]:
### All functions


# Function to parse tables and set any NaN entries to zero, to avoid errors
def RemoveNaNs(tableObj):
    """Replace NaN entries in every floating-point column of a table with zero, in place.

    Columns whose dtype cannot hold NaN (integers, booleans, strings) are skipped;
    the previous element-by-element ``np.isnan`` raised ``TypeError`` on string columns
    and looped in pure Python over every cell.

    Parameters
    ----------
    tableObj : table-like
        Object exposing ``colnames`` and ``tableObj[col]`` returning a mutable array
        (e.g. an astropy ``Table``/``QTable``). Columns are modified in place.

    Returns
    -------
    The same ``tableObj`` (returned for call chaining).
    """
    for col in tableObj.colnames:
        column = tableObj[col]
        # Only float / complex dtypes can contain NaN
        if np.asarray(column).dtype.kind not in "fc":
            continue
        nanMask = np.isnan(column)
        if np.any(nanMask):
            # Masked assignment writes through to the table's own column object,
            # assigning 0 exactly as the former per-element loop did
            column[nanMask] = 0
    return tableObj


# Function to use appropriate OS directory structure
def GetDirStruct():
    """Return the drive/volume prefix under which the data folders live on this machine.

    The prefix is chosen from ``platform.system()``.

    Raises
    ------
    Exception
        If the reported OS has no configured prefix.
    """
    prefixByOS = {
        "Windows": "D:/",
        "Linux": "/mnt/d/",
        "Ubuntu": "/mnt/d/",
        "macOS": "/Volumes/Storage/",
        "Darwin": "/Volumes/Storage/",
    }
    systemName = platform.system()
    if systemName in prefixByOS:
        return prefixByOS[systemName]
    raise Exception(
        f'OS not recognised: "{platform.system()}". Please define a custom switch inside GetDirStruct().'
    )


# Function to grab list of spectra in a directory
def GetSpec():
    """List the 1D spectra (``*1D.fits``) in the configured spectra folder.

    Uses the module-level ``specFolder`` appended to ``GetDirStruct()``. File names are
    later appended directly to ``specDir``, so ``specFolder`` presumably ends in "/".

    Returns
    -------
    specDir : str
        Full folder path.
    specList : list of str
        Matching file names (no directory), sorted.
    specNames : list of str
        The part of each file name before its first "_".
    """
    import glob

    dirPrefix = GetDirStruct()
    specDir = dirPrefix + specFolder
    # Pure-Python glob instead of `!bash -c "ls ..."`: runs outside IPython, is not
    # broken by spaces in the path, and yields an empty list when nothing matches
    # (previously ls's error message was returned as if it were a file name).
    specPaths = sorted(glob.glob(specDir + "*1D.fits"))
    specList = [os.path.basename(path) for path in specPaths]
    specNames = [file.split("_")[0] for file in specList]
    return specDir, specList, specNames


# Function to import files to dictionary
def ImportSpec(specDir, specList, specNames):
    """Read each 1D spectrum FITS file into a QTable keyed by its spectrum name.

    HDU layout used here: extension 1 = flux, 2 = flux error (both tagged W / m^3),
    9 = wavelength (tagged m). NaNs are zeroed with RemoveNaNs.

    Parameters
    ----------
    specDir : str
        Folder prefix; each file name is appended to it directly.
    specList, specNames : sequences of str
        File names and the matching dictionary keys (paired positionally).

    Returns
    -------
    dict mapping name -> QTable with columns Wavelength, Flux, Error.
    """
    specData = {}
    for file, name in tqdm(list(zip(specList, specNames)), desc="Importing spectra"):
        # The context manager closes the file even when an HDU is missing or malformed;
        # the table is built inside it so no (possibly memory-mapped) data is read after close.
        with fits.open(specDir + file) as specRaw:
            specFlux = specRaw[1].data * (u.W / u.m**3)
            specErr = specRaw[2].data * (u.W / u.m**3)
            specWave = specRaw[9].data * u.m
            specTable = tbl.QTable(
                [specWave, specFlux, specErr], names=("Wavelength", "Flux", "Error")
            )
        specData[name] = RemoveNaNs(specTable)
    return specData


# Function to plot the spectra
def _FlambdaToJy(fLambda, wave):
    """Convert an f_lambda array to f_nu in Jy (f_nu = f_lambda * lambda^2 / c); returns bare values."""
    return (fLambda * wave**2 / c.c).to(u.jansky).value


def _PlotSpecPage(pdf, xVals, yVals, yLow, yHigh, xLabel, yLabel, title):
    """Draw one spectrum with a shaded [yLow, yHigh] band on a log x-axis and append it to ``pdf``."""
    fig, ax = plt.subplots()
    ax.plot(xVals, yVals, lw=0.5)
    ax.fill_between(xVals, yLow, yHigh, alpha=0.25)
    ax.set_xscale("log")
    ax.set_xlabel(xLabel)
    ax.set_ylabel(yLabel)
    ax.set_title(title)
    fig.savefig(pdf, format="pdf")
    # Close explicitly so a long loop does not accumulate open figures
    plt.close(fig)


def PlotSpec(specData, specList, specNames):
    """Write three multi-page PDFs with one page per spectrum (flux with a +/- error band).

    - Spectra_Lambda.pdf : f_lambda [erg / (cm2 s Angstrom)] vs wavelength [Angstrom]
    - Spectra_NuAng.pdf  : f_nu [Jy] vs wavelength [Angstrom]
    - Spectra_NuHz.pdf   : f_nu [Jy] vs frequency [Hz]

    Output folder: ../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/spectra/
    (module-level ``outputFolder``).

    Fixes relative to the previous version:
    - the directory was created at ``{outputFolder}/plots`` but the PDFs were opened at
      ``{outputFolder}plots`` (missing "/");
    - the frequency plot's x-axis was labelled "Wavelength (Hz)";
    - the PDF files were not closed if plotting raised.
    """
    plotDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/spectra/"
    # os.makedirs replaces `!bash -c "mkdir -p ..."`: no IPython/bash dependency
    os.makedirs(plotDir, exist_ok=True)
    fLambdaUnit = u.erg / u.cm**2 / u.s / u.angstrom
    waveLabel = f"Wavelength ({u.angstrom})"
    fnuLabel = f"Flux ({u.jansky})"
    with (
        PdfPages(plotDir + "Spectra_Lambda.pdf") as pdfWave,
        PdfPages(plotDir + "Spectra_NuAng.pdf") as pdfFreq,
        PdfPages(plotDir + "Spectra_NuHz.pdf") as pdfFreqTrue,
    ):
        for _file, name in tqdm(
            zip(specList, specNames), desc="Plotting spectra", total=len(specList)
        ):
            wave = specData[name]["Wavelength"]
            flux = specData[name]["Flux"]
            err = specData[name]["Error"]
            title = f"Galaxy {name}"
            waveAng = wave.to(u.angstrom).value
            fluxLow = flux - err
            fluxHigh = flux + err

            # Page 1: f_lambda against wavelength
            _PlotSpecPage(
                pdfWave,
                waveAng,
                flux.to(fLambdaUnit).value,
                fluxLow.to(fLambdaUnit).value,
                fluxHigh.to(fLambdaUnit).value,
                waveLabel,
                f"Flux ({fLambdaUnit})",
                title,
            )

            # Pages 2 and 3 share the f_nu values
            fnu = _FlambdaToJy(flux, wave)
            fnuLow = _FlambdaToJy(fluxLow, wave)
            fnuHigh = _FlambdaToJy(fluxHigh, wave)
            _PlotSpecPage(pdfFreq, waveAng, fnu, fnuLow, fnuHigh, waveLabel, fnuLabel, title)

            freqHz = wave.to(u.Hz, equivalencies=u.spectral()).value
            _PlotSpecPage(
                pdfFreqTrue, freqHz, fnu, fnuLow, fnuHigh, f"Frequency ({u.Hz})", fnuLabel, title
            )
    return


# Function to grab list of filter curves
def GetFilter():
    """List the filter-curve files in the configured filter folder.

    Uses the module-level ``filterFolder`` appended to ``GetDirStruct()``.

    Returns
    -------
    filterDir : str
        Full folder path.
    filterList : list of str
        File names (no directory), sorted; hidden files and sub-directories are skipped.
    filterNames : list of str
        The part of each file name before its first "_".
    """
    dirPrefix = GetDirStruct()
    filterDir = dirPrefix + filterFolder
    # os.listdir instead of `!bash -c "ls ..."`: no IPython dependency, no shell
    # word-splitting, and a missing folder raises FileNotFoundError instead of ls's
    # error message being returned as a "file name". Hidden entries are skipped as ls
    # does; sub-directories are skipped because ImportFilter cannot read them.
    filterList = sorted(
        entry
        for entry in os.listdir(filterDir)
        if not entry.startswith(".") and os.path.isfile(os.path.join(filterDir, entry))
    )
    filterNames = [file.split("_")[0] for file in filterList]
    return filterDir, filterList, filterNames


# Function to import filters to a dictionary
def ImportFilter(filterDir, filterList, filterNames):
    """Read each filter-curve ascii file into a QTable keyed by filter name.

    Each file must provide ``Microns`` and ``Throughput`` columns. ``Microns`` is tagged
    with um and renamed to ``Wavelength``; ``Throughput`` is tagged dimensionless; NaNs
    are then zeroed with RemoveNaNs.
    """
    filterData = {}
    filePairs = zip(filterList, filterNames)
    for fileName, filterName in tqdm(
        filePairs, desc="Importing filters", total=len(filterList)
    ):
        curve = tbl.QTable.read(filterDir + fileName, format="ascii")
        curve["Microns"].unit = u.um
        curve.rename_column("Microns", "Wavelength")
        curve["Throughput"].unit = u.dimensionless_unscaled
        filterData[filterName] = RemoveNaNs(curve)
    return filterData


# Function to interpolate datapoints
def InterpFunc(funcXs, funcYs):
    """Build a cubic interpolant through the points (funcXs, funcYs).

    Returns the ``scipy.interpolate.interp1d`` callable. With interp1d's default
    settings, evaluating outside the sampled x-range raises ValueError.
    """
    cubicInterpolant = interpolate.interp1d(funcXs, funcYs, kind="cubic")
    return cubicInterpolant


# Function to find grid overlap for convolutions of target using input
def FindGrid(targetGrid, inputGrid, inputData):
    """Return the points of ``inputGrid`` (and matching ``inputData``) inside the range of ``targetGrid``.

    The overlap is the closed interval [min(targetGrid), max(targetGrid)], so every
    returned grid point can be safely evaluated by an interpolant built on ``targetGrid``.
    Results are sorted by ``inputGrid``.

    Parameters
    ----------
    targetGrid : array-like (astropy Quantity in this notebook)
        Only its minimum and maximum are used.
    inputGrid : array-like, same physical dimension as targetGrid
        Need not be sorted.
    inputData : array-like, same length as inputGrid
        Values carried along with inputGrid.

    Returns
    -------
    overlapGrid, overlapData
        The overlapping section. When there is no overlap, each is a single-element NaN
        array (with the input's unit, if any) so downstream code does not fail on
        empty arrays.

    Fixes relative to the previous version:
    - boundary checks indexed the *unsorted* ``inputGrid`` with indices computed on the
      sorted grid;
    - when the input grid ended inside the target range, ``idxRight = -1`` silently
      dropped the last in-range point;
    - an exact match at the upper bound was excluded while one at the lower bound was
      included.
    """
    # Sort once so that a contiguous slice corresponds to an interval of the grid
    idxSorted = np.argsort(inputGrid)
    sortedGrid = inputGrid[idxSorted]
    sortedData = inputData[idxSorted]
    targetLow = np.min(targetGrid)
    targetHigh = np.max(targetGrid)
    # side="left" / side="right" make both ends inclusive: targetLow <= x <= targetHigh
    idxLeft = np.searchsorted(sortedGrid, targetLow, side="left")
    idxRight = np.searchsorted(sortedGrid, targetHigh, side="right")
    overlapGrid = sortedGrid[idxLeft:idxRight]
    overlapData = sortedData[idxLeft:idxRight]
    # In case of zero overlap, return NaN placeholders (keeping units) to prevent errors
    if len(overlapGrid) == 0:
        overlapGrid = np.full(1, np.nan) * getattr(overlapGrid, "unit", 1)
        overlapData = np.full(1, np.nan) * getattr(overlapData, "unit", 1)
    return overlapGrid, overlapData


# Function to resample (interpolate) the first array onto the part of the second array's grid that it covers
def ConvolveFunc(firstXs, firstYs, secondXs, secondYs):
    """Resample the first curve onto the portion of the second curve's grid that it covers.

    ("Convolve" in this notebook means cubic interpolation onto another grid.)

    Returns
    -------
    resampled : first curve evaluated on ``overlapGrid``, carrying ``firstYs.unit``
    overlapGrid : points of ``secondXs`` inside the range of ``firstXs`` (via FindGrid)
    overlapData : the matching ``secondYs`` values
    """
    # Cubic spline of the first curve, built on a grid expressed in metres
    spline = InterpFunc(firstXs.to(u.m), firstYs)
    # Second curve's points that fall inside the first curve's range
    overlapGrid, overlapData = FindGrid(firstXs, secondXs, secondYs)
    # Evaluate in the same unit (metres) the spline was built with, then restore units
    resampled = spline(overlapGrid.to(u.m)) * firstYs.unit
    return resampled, overlapGrid, overlapData


# Function to convolve spectra and filters onto same grid
def MergeGrids(specFile, filterFile):
    """Put a spectrum and a filter curve onto one common wavelength grid.

    Each curve is interpolated onto the other's points inside their mutual overlap. The
    union of both point sets (in metres) is then inner-joined on exact wavelength.

    Returns
    -------
    mergedTable : QTable with Wavelength, Flux, Throughput on the union grid.
    errorTable : QTable with Wavelength, Error, Throughput on the spectral points only.
        Errors are not interpolated onto the filter grid.

    Fix: ``np.append`` converts to the unit of its *first* argument. The old filter
    table appended the spectral grid (m) onto the filter grid (um) and then converted
    back to m. That m -> um -> m round trip can change the last bits of a wavelength, so
    the exact-match join could silently drop spectral points. Each grid is now converted
    to metres once, and the same arrays are used in both tables.
    """
    specWave = specFile["Wavelength"]
    filterWave = filterFile["Wavelength"]
    # Spectrum interpolated onto the filter points inside the spectral range
    specConvolved, filterOverlapGrid, filterOverlapData = ConvolveFunc(
        specWave, specFile["Flux"], filterWave, filterFile["Throughput"]
    )
    # Filter interpolated onto the spectral points inside the filter range
    filterConvolved, specOverlapGrid, specOverlapData = ConvolveFunc(
        filterWave, filterFile["Throughput"], specWave, specFile["Flux"]
    )
    # Convert each part to metres exactly once so join keys are bit-identical
    specGridM = specOverlapGrid.to(u.m)
    filterGridM = filterOverlapGrid.to(u.m)
    specTable = tbl.QTable(
        [
            np.append(specGridM, filterGridM),
            np.append(specOverlapData, specConvolved),
        ],
        names=("Wavelength", "Flux"),
    )
    specTable.sort("Wavelength")
    filterTable = tbl.QTable(
        [
            np.append(filterGridM, specGridM),
            np.append(filterOverlapData, filterConvolved),
        ],
        names=("Wavelength", "Throughput"),
    )
    filterTable.sort("Wavelength")
    mergedTable = tbl.join(specTable, filterTable, keys="Wavelength")
    # Error table on the spectral grid only. The filter was already interpolated onto
    # exactly these points above (same grids), so reuse it instead of rebuilding the
    # spline; only the matching error values need selecting.
    _, errOverlapData = FindGrid(filterWave, specWave, specFile["Error"])
    errorTable = tbl.QTable(
        [specGridM, errOverlapData, filterConvolved],
        names=("Wavelength", "Error", "Throughput"),
    )
    return mergedTable, errorTable


# Function to shift flux to photon space
def ShiftPhotonSpace(tableFlux, tableWave):
    """Convert an energy flux density into a photon flux density, in place.

    ``tableFlux`` is multiplied by lambda / (h c), i.e. divided by the energy per photon.
    The array passed in (for example a table column) is modified; nothing is returned.
    """
    photonsPerEnergy = tableWave / (c.h * c.c)
    tableFlux *= photonsPerEnergy
    return


# Function to calculate throughput on a merged grid
def CalcThroughput(
    firstYs, secondYs, commonGrid, errYs=None, errY2s=None, errGrid=None
):
    """Integrate firstYs * secondYs over commonGrid with the trapezoidal rule.

    Parameters
    ----------
    firstYs, secondYs : arrays sampled on ``commonGrid``
    commonGrid : array of abscissae, assumed sorted ascending
    errYs, errY2s, errGrid : optional arrays
        Uncertainty of the first factor, the second factor's values on the error grid,
        and that grid. If any of them is None, no error is computed.

    Returns
    -------
    (specThrough, errThrough); errThrough is None when the error inputs are incomplete.
    """
    combinedYs = firstYs * secondYs
    # Trapezoidal rule: SUM( 1/2 * (f(a) + f(b)) * (b - a) )
    specThrough = np.sum(
        (combinedYs[1:] + combinedYs[:-1]) / 2 * np.diff(commonGrid)
    )
    # Error propagation: SQRT( SUM( ( SQRT( (df(a)^2 + df(b)^2) / 2 ) * (b - a) )^2 ) )
    # Note that the 1/2 is inside the square root. This seems to stem from Pooled Variance
    # estimates in statistics. See [https://en.wikipedia.org/wiki/Pooled_variance].
    # This is the only way to obtain SNRs that match crude estimates.
    # See also [https://en.wikipedia.org/wiki/Propagation_of_uncertainty] and the
    # arithmetic mean.
    # `is None` instead of the former chained `errYs == errY2s == errGrid == None`:
    # `==` between arrays is element-wise, so that chain raised "truth value of an array
    # is ambiguous" whenever the two error arrays had comparable units.
    if errYs is None or errY2s is None or errGrid is None:
        errThrough = None
    else:
        combinedErrs2 = (errYs * errY2s) ** 2
        errThrough = np.sqrt(
            np.sum(
                (
                    np.sqrt((combinedErrs2[1:] + combinedErrs2[:-1]) / 2)
                    * np.diff(errGrid)
                )
                ** 2
            )
        )
    return specThrough, errThrough


# Function to find throughput given a spectrum and filter
def FindThroughput(specFile, filterFile):
    """Photon flux of a spectrum integrated through a filter, with its uncertainty.

    The spectrum and filter are merged onto common grids (MergeGrids). Both flux columns
    are converted to photon space in place, then integrated with CalcThroughput.

    Returns
    -------
    (specThrough, mergedTable, errThrough, errorTable). The two integrals are in
    1 / (cm2 s); the tables are returned after the in-place photon-space conversion.
    """
    mergedTable, errorTable = MergeGrids(specFile, filterFile)
    for table, fluxKey in ((mergedTable, "Flux"), (errorTable, "Error")):
        ShiftPhotonSpace(table[fluxKey], table["Wavelength"])
    specThrough, errThrough = CalcThroughput(
        mergedTable["Flux"],
        mergedTable["Throughput"],
        mergedTable["Wavelength"],
        errorTable["Error"],
        errorTable["Throughput"],
        errorTable["Wavelength"],
    )
    photonRate = 1 / u.cm**2 / u.s
    return (
        specThrough.to(photonRate),
        mergedTable,
        errThrough.to(photonRate),
        errorTable,
    )


# Function to grab central wavelength of filter
def FindCentral(filterWave, filterThrough):
    """Return the throughput-weighted mean wavelength of a filter curve.

    Wavelengths where the filter transmits little contribute proportionally less.
    """
    centralWave = np.average(filterWave, weights=filterThrough)
    return centralWave


# Function to normalise the throughput to the correct zero-point
def NormaliseValue(specTable, specValue, filterFile):
    """Express a photon-space filter integral as a flux density in nJy.

    The same integral (CalcThroughput) is computed for a flat 1 Jy f_nu reference source,
    using specTable's wavelength grid and Throughput column. The result is
    specValue / reference * 1 Jy, returned in nJy. ``filterFile`` is accepted for
    interface compatibility but not used.
    """
    refGrid = specTable["Wavelength"]
    refThroughput = specTable["Throughput"]
    # Flat 1 Jy in f_nu expressed as f_lambda: f_lambda = f_nu * c / lambda^2
    flatRef = (1 * u.Jy * c.c / refGrid**2).to(u.erg / u.cm**2 / u.s / u.angstrom)
    # Move the reference into photon space, like the integrand behind specValue
    ShiftPhotonSpace(flatRef, refGrid)
    refIntegral = CalcThroughput(refThroughput, flatRef, refGrid)[0].to(specValue.unit)
    return (specValue / refIntegral * u.Jy).to(u.nJy)


def _BalmerWindowMean(specFile, restRange, redshiftFactor):
    """Mean photon-space flux and its RMS error inside one rest-frame wavelength window."""
    # Shift to the rest frame: wavelength / (1+z), f_lambda * (1+z)
    restWave = specFile["Wavelength"] / redshiftFactor
    windowWave, windowFlux = FindGrid(restRange, restWave, specFile["Flux"] * redshiftFactor)
    _, windowErr = FindGrid(restRange, restWave, specFile["Error"] * redshiftFactor)
    windowTable = tbl.QTable(
        [windowWave, windowFlux, windowErr], names=("Wavelength", "Flux", "Error")
    )
    ShiftPhotonSpace(windowTable["Flux"], windowTable["Wavelength"])
    ShiftPhotonSpace(windowTable["Error"], windowTable["Wavelength"])
    nPix = len(windowTable)
    meanFlux = np.sum(windowTable["Flux"]) / nPix
    # NOTE(review): this is the RMS per-pixel error sqrt(mean(sigma^2)), not the standard
    # error of the mean (which is a further 1/sqrt(N) smaller). Kept as before; confirm intent.
    rmsErr = np.sqrt(np.sum(windowTable["Error"] ** 2) / nPix)
    return meanFlux, rmsErr


def BalmerBreak(specFile, specName, redshiftFile):
    """Measure the Balmer break of one spectrum from two rest-frame windows.

    The redshift is taken from the first row of ``redshiftFile`` whose ``redshiftID``
    column matches ``int(specName)`` and whose ``redshiftZ`` column is > 0. These column
    names are module-level globals.

    Returns
    -------
    list of six values: mean photon flux in 3500-3650 Angstrom and its error, mean in
    3800-3950 Angstrom and its error, the ratio (right / left) and its propagated error.
    If no redshift is found, the factor is NaN and all six values come out NaN.

    Fixes: the redshift factor is initialised before the search, so an empty
    ``redshiftFile`` no longer raises UnboundLocalError. The old ``!= 0`` guard was always
    true (1+z > 1, and NaN != 0) and has been removed. Behaviour for a NaN factor is
    unchanged.
    """
    balmerLeftRange = np.array([3500, 3650]) * u.angstrom
    balmerRightRange = np.array([3800, 3950]) * u.angstrom
    redshiftFactor = np.nan
    for row in redshiftFile:
        if int(specName) == int(row[redshiftID]) and row[redshiftZ] > 0:
            redshiftFactor = 1 + row[redshiftZ]
            break
    balmerLeftVal, balmerLeftErr = _BalmerWindowMean(specFile, balmerLeftRange, redshiftFactor)
    balmerRightVal, balmerRightErr = _BalmerWindowMean(specFile, balmerRightRange, redshiftFactor)
    balmerRatio = balmerRightVal / balmerLeftVal
    balmerRatioErr = (
        np.sqrt(
            (balmerLeftErr / balmerLeftVal) ** 2 + (balmerRightErr / balmerRightVal) ** 2
        )
        * balmerRatio
    )
    return [
        balmerLeftVal,
        balmerLeftErr,
        balmerRightVal,
        balmerRightErr,
        balmerRatio,
        balmerRatioErr,
    ]


# Function to loop through filters and balmer breaks for one spectrum
def HandleSpectrum(specFile, specName, filterData, redshiftFile):
    """Collect every measurement for one spectrum, in output-column order.

    For each filter (in filterData's iteration order), append the normalised throughput
    and then its normalised error. Finally append the six BalmerBreak values.
    """
    specValues = []
    for filterFile in filterData.values():
        throughValue, throughTable, errValue, errTable = FindThroughput(specFile, filterFile)
        specValues.append(NormaliseValue(throughTable, throughValue, filterFile))
        specValues.append(NormaliseValue(errTable, errValue, filterFile))
    specValues.extend(BalmerBreak(specFile, specName, redshiftFile))
    return specValues


# Function to loop through each spectrum and save values to a table
def LoopSpectra(specData, filterData, redshiftFile):
    """Compute all per-spectrum values and save them as throughputs.fits.

    One row per spectrum: ID (``int(specName)``), then "<filter> Throughput" and
    "<filter> Throughput Error" for each filter, then the six Balmer columns. The table
    is sorted by ID and written to
    ../../Working_Directory/Apo_Phot_Utils/{outputFolder}/throughputs.fits.

    Returns the written Table.
    """
    outDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}"
    # os.makedirs replaces `!bash -c "mkdir -p ..."`: no IPython/bash dependency
    os.makedirs(outDir, exist_ok=True)
    rows = []
    for specName, specFull in tqdm(
        specData.items(), desc="Calculating throughputs", total=len(specData)
    ):
        rows.append(
            [int(specName)] + HandleSpectrum(specFull, specName, filterData, redshiftFile)
        )
    names = ["ID"]
    for key in filterData.keys():
        names += [key + " Throughput", key + " Throughput Error"]
    names += [
        "Balmer_left",
        "Balmer_left_err",
        "Balmer_right",
        "Balmer_right_err",
        "Balmer_ratio",
        "Balmer_ratio_err",
    ]
    outTable = tbl.Table(rows=rows, names=names)
    outTable.sort("ID")
    outTable.write(f"{outDir}/throughputs.fits", overwrite=True)
    return outTable


# Function to import photometry files
def ImportPhot():
    """Read every ``*_summary.fits`` photometry table under ``{photDir}/FilterData``.

    Returns
    -------
    dict mapping pointing number (the integer before the first "_" in the file name)
    to the QTable read from that file.
    """
    import glob

    # glob replaces the `ls -1 ... | xargs -n 1 basename` pipeline: no IPython/bash
    # dependency, and no error text parsed as a file name when nothing matches
    photPaths = glob.glob(f"{photDir}/FilterData/*_summary.fits")
    photList = sorted(os.path.basename(path) for path in photPaths)
    photTables = {}
    for file in tqdm(photList, desc="Importing apodised fluxes"):
        pointingNum = int(file.split("_")[0])
        photTables[pointingNum] = tbl.QTable.read(f"{photDir}/FilterData/{file}")
    return photTables


# Function to import Astrometry
def ImportAstro():
    """Load, for each pointing, its source list joined (on ID) with the slit astrometry.

    File locations come from module-level configuration (astroDir, pointFolder/Prefix/
    Suffix, slitFolder/Prefix/Suffix, slitNums, pointNums). Source_RA and Source_Dec are
    tagged in degrees.

    Returns dict mapping pointing -> joined QTable.
    """
    dirPrefix = GetDirStruct()
    astroTables = {}
    for slit, point in tqdm(
        list(zip(slitNums, pointNums)), desc="Importing astrometry"
    ):
        pointPath = f"{dirPrefix}/Astrometry/{astroDir}/{pointFolder}/{pointPrefix}{point}{pointSuffix}"
        slitPath = f"{dirPrefix}/Astrometry/{astroDir}/{slitFolder}/{slitPrefix}{slit}{slitSuffix}"
        joined = tbl.join(
            tbl.QTable.read(pointPath, format="ascii"),
            tbl.QTable.read(slitPath, format="ascii"),
            keys="ID",
        )
        joined["Source_RA"].unit = u.degree
        joined["Source_Dec"].unit = u.degree
        astroTables[point] = joined
    return astroTables


# Function to grab list of slitloss corrections
def GrabSlitloss():
    """Index the per-galaxy slit-loss files for every pointing and exposure.

    Returns
    -------
    lossDir : str
    lossList : dict
        lossList[pointing][exposure] is a Table with columns:
        ID (int parsed from the "idcat<N>" token, third "_"-field from the end of the
        name) and Name ("<pointing folder>/<file name>", relative to lossDir).
    """
    import glob

    lossDir = GetDirStruct() + lossFolder
    lossList = {}
    for lossPoint in tqdm(lossPoints, desc="Identifying available slitlosses"):
        lossList[lossPoint] = {}
        # The inner bar iterates exposures (it was mislabelled "Pointings")
        for lossExp in tqdm(lossExps, desc="Exposures", leave=False):
            # glob instead of `!bash -c "ls ..."`: with no matches, ls's error message
            # used to be parsed as a file name and int() below crashed
            lossPaths = sorted(
                glob.glob(
                    f"{lossDir}/{lossPoint}/{lossPrefix01}{lossPoint}_{lossExp}{lossPrefix02}*{lossSuffix}"
                )
            )
            lossNames = [
                f"{os.path.basename(os.path.dirname(path))}/{os.path.basename(path)}"
                for path in lossPaths
            ]
            lossIDs = [
                int(file.split("_")[-3].split("idcat")[-1]) for file in lossNames
            ]
            lossList[lossPoint][lossExp] = tbl.Table(
                [lossIDs, lossNames], names=("ID", "Name")
            )
    return lossDir, lossList


# Function to import slitloss corrections
def ImportSlitloss(lossDir, lossList, photTables):
    """Read the slit-loss curve of every photometered galaxy, per pointing and exposure.

    Only galaxies listed both in ``lossList`` and in ``photTables[pointNum]["ID"]`` are
    read.

    Returns
    -------
    lossData[pointingName][exposure][galaxyID] -> QTable with columns Slitloss
    (dimensionless) and Wavelength (um), renamed from col1 / col2 of the ascii file.
    """
    lossData = {}
    for pointNum, pointName in tqdm(
        list(zip(pointNums, lossPoints)), desc="Importing slitlosses"
    ):
        photIDs = tbl.Table([photTables[pointNum]["ID"]])
        pointLoss = lossData[pointName] = {}
        for lossExp in tqdm(
            lossList[pointName].keys(), desc="Parsing exposures", leave=False
        ):
            expLoss = pointLoss[lossExp] = {}
            # Keep only galaxies that also appear in the photometry table
            matched = tbl.join(lossList[pointName][lossExp], photIDs, keys="ID")
            for matchRow in tqdm(range(len(matched)), desc="Matching IDs", leave=False):
                galID = matched[matchRow]["ID"]
                curve = tbl.QTable.read(lossDir + matched[matchRow]["Name"], format="ascii")
                curve.rename_column("col1", "Slitloss")
                curve.rename_column("col2", "Wavelength")
                curve["Slitloss"].unit = u.dimensionless_unscaled
                curve["Wavelength"].unit = u.um
                expLoss[galID] = curve
    return lossData


# Function to count High SNR Bands
def CountSNR(countTable, fluxCols, fluxErrCols, addNewCol=False):
    """Count, per row, how many bands have flux / error > 5 and store it in the table.

    Parameters
    ----------
    countTable : table, modified in place
    fluxCols, fluxErrCols : sequences of column names, paired positionally
    addNewCol : bool
        When exactly True the counts are added with ``add_column``; otherwise the
        "Total High SNR Bands" column is assigned (created or overwritten).

    Returns None.
    """
    snrThreshold = 5.0
    bandPairs = list(zip(fluxCols, fluxErrCols))
    SNRCounts = [
        sum(
            1
            for fluxName, errName in bandPairs
            if countTable[fluxName][rowIdx] / countTable[errName][rowIdx] > snrThreshold
        )
        for rowIdx in range(len(countTable))
    ]
    if addNewCol != True:
        countTable["Total High SNR Bands"] = SNRCounts
    else:
        countTable.add_column(SNRCounts, name="Total High SNR Bands")


# Function to calculate photometry ratios
def _PlotCumulativeSNR(snrCounts, nBands, pointNum):
    """Save a cumulative histogram of per-galaxy high-SNR band counts for one pointing."""
    plotDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/SNRcriteria/"
    os.makedirs(plotDir, exist_ok=True)
    counts, bins = np.histogram(snrCounts, bins=nBands)
    fig, ax = plt.subplots()
    # Re-draw the pre-computed histogram; density + cumulative makes the last bin equal 1
    ax.hist(bins[:-1], bins, weights=counts, density=True, cumulative=True)
    ax.set_xlabel("Number of bands with SNR > 5")
    ax.set_ylabel("Cumulative fraction of galaxies")
    ax.set_title(f"Pointing {pointNum}")
    fig.savefig(f"{plotDir}cumulativeSNR_{pointNum}.png")
    plt.close(fig)


def CalcSlitRatio(photTable, specTable, pointNum):
    """Compare measured photometry with spectral throughputs for one pointing.

    The tables are joined on ID. Non-positive fluxes, flux errors, throughputs and
    throughput errors are set to NaN, and high-SNR bands are counted (CountSNR). For
    every filter present in both tables, the table gets the flux, the throughput, their
    ratio and the ratio's propagated error.

    Side effects: writes plots/SNRcriteria/cumulativeSNR_{pointNum}.png and
    RawOutput/ratios_{pointNum}.fits under the output folder (directories are created if
    needed).

    Returns
    -------
    (ratioTable, fluxNames) where fluxNames lists the filters that were matched.
    """
    combTable = tbl.join(photTable, specTable, keys="ID")
    fluxCols = [
        colname
        for colname in combTable.colnames
        if "Actual Flux" in colname and "Density" not in colname
    ]
    fluxErrCols = [colname for colname in combTable.colnames if "Flux Error" in colname]
    throughCols = [
        colname
        for colname in combTable.colnames
        if "Throughput" in colname and "Error" not in colname
    ]
    throughErrCols = [
        colname for colname in combTable.colnames if "Throughput Error" in colname
    ]
    ratioTable = tbl.QTable([combTable["ID"]])
    fluxNames = []
    # Non-positive values are treated as invalid (NaN) so they drop out of ratios and SNR counts
    for col in fluxCols + fluxErrCols + throughCols + throughErrCols:
        combTable[col] = np.where(combTable[col] > 0, combTable[col], np.nan)
    # 5 sigma selection criteria
    CountSNR(combTable, fluxCols, fluxErrCols)
    _PlotCumulativeSNR(combTable["Total High SNR Bands"], len(fluxCols), pointNum)
    for fluxCol, fluxErrCol in zip(fluxCols, fluxErrCols):
        fluxName = fluxCol.split()[0]
        for throughCol, throughErrCol in zip(throughCols, throughErrCols):
            if throughCol.split()[0] != fluxName:
                continue
            ratio = combTable[fluxCol] / combTable[throughCol]
            # Relative errors of flux and throughput added in quadrature
            ratioErr = ratio * np.sqrt(
                (combTable[fluxErrCol] / combTable[fluxCol]) ** 2
                + (combTable[throughErrCol] / combTable[throughCol]) ** 2
            )
            ratioTable.add_columns(
                [combTable[fluxCol], combTable[throughCol], ratio, ratioErr],
                names=(
                    fluxCol,
                    throughCol,
                    f"{fluxName} Ratio",
                    f"{fluxName} Ratio Error",
                ),
            )
            fluxNames += [fluxName]
    ratioTable.add_column(combTable["Total High SNR Bands"])
    ratioTable.sort("ID")
    rawDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/RawOutput/"
    os.makedirs(rawDir, exist_ok=True)
    ratioTable.write(f"{rawDir}ratios_{pointNum}.fits", overwrite=True)
    return ratioTable, fluxNames


# Function to calculate slit-losses at centre of filters
def CalcLossTable(lossData, fluxNames, lossTables, pointName):
    """Evaluate each galaxy's slit-loss correction (1 / slit-loss) at every filter's centre.

    The central wavelength is parsed from the filter name: characters 1 and 2-3 give
    "X.YY" micron (e.g. "F115W" -> 1.15 um). NOTE(review): a name with a longer
    wavelength code (e.g. "F1000W") would be misparsed; confirm only such names occur.

    ``lossTables[pointName]`` is (re)created as {exposure: QTable(ID, "<filter> slitloss" ...)}.
    The same ``lossTables`` dict is mutated and returned.
    """
    pointLosses = {}
    lossTables[pointName] = pointLosses
    for exposure, expData in lossData.items():
        expRows = []
        for galaxy, galLoss in expData.items():
            interpLoss = interpolate.interp1d(galLoss["Wavelength"], galLoss["Slitloss"])
            corrections = [
                1
                / interpLoss(float(fluxName[1] + "." + fluxName[2:4]) * u.um)
                * u.dimensionless_unscaled
                for fluxName in fluxNames
            ]
            expRows.append([galaxy] + corrections)
        colNames = ["ID"] + [f"{fluxName} slitloss" for fluxName in fluxNames]
        pointLosses[exposure] = tbl.QTable(rows=expRows, names=colNames)
    return lossTables


# Function to calculate slit-loss factor between photometry and model
def CalcDiffTable(ratioTable, lossTables, pointName):
    """Per exposure, divide each filter's photometry/spectrum ratio by its slit-loss correction.

    Output columns: "<filter>_<exposure>_Diff" and "<filter>_<exposure>_DiffErr". The
    error carries only the ratio's relative error, since the slit-loss tables have no
    uncertainty column. Exposures are combined with an outer join on
    (ID, Total High SNR Bands), sorted by ID, and written to
    RawOutput/differences_{pointName}_detailed.fits.

    Columns are now selected by name ("<filter> Ratio" as written by CalcSlitRatio, and
    every non-ID slit-loss column). Previously they were selected by hard-coded positions
    (offset 3, stride 4), which broke silently if the ratio table's layout changed.
    """
    diffTable = tbl.QTable([ratioTable["ID"], ratioTable["Total High SNR Bands"]])
    ratioNames = [name for name in ratioTable.colnames if name.endswith(" Ratio")]
    for exposure, lossTable in lossTables.items():
        matchedTable = tbl.join(ratioTable, lossTable, keys="ID")
        lossNames = [name for name in lossTable.colnames if name != "ID"]
        tempTable = tbl.QTable(
            [matchedTable["ID"], matchedTable["Total High SNR Bands"]]
        )
        for ratioName in ratioNames:
            filterName = ratioName.split()[0]
            for lossName in lossNames:
                if lossName.split()[0] != filterName:
                    continue
                diffVal = matchedTable[ratioName] / matchedTable[lossName]
                diffErr = (
                    matchedTable[f"{ratioName} Error"] / matchedTable[ratioName]
                ) * diffVal
                tempTable.add_columns(
                    [diffVal, diffErr],
                    names=(
                        f"{filterName}_{exposure}_Diff",
                        f"{filterName}_{exposure}_DiffErr",
                    ),
                )
        diffTable = tbl.join(
            diffTable, tempTable, keys=("ID", "Total High SNR Bands"), join_type="outer"
        )
    diffTable.sort("ID")
    diffTable.write(
        f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/RawOutput/differences_{pointName}_detailed.fits",
        overwrite=True,
    )
    return diffTable


# Function to append average differences across pointings and exposures
def _FilledNaN(column):
    """Return ``column`` with masked entries (e.g. rows absent from an outer join) replaced by NaN."""
    # Masked columns / masked quantities expose .filled(); plain arrays are returned as-is
    if hasattr(column, "filled"):
        return column.filled(np.nan)
    return column


def CalcAverageDiff(diffTable, astroTable, pointName):
    """Average each filter's per-exposure difference factors, attach astrometry, and save.

    For every filter, the "<filter>_<exposure>_Diff" columns are averaged over the
    exposures where the factor is > 0. NaN and masked entries are excluded from both the
    sum and the count. The error is sqrt(sum sigma^2) / sqrt(N), the same convention as
    before. Results go in "<filter>_Diff" / "<filter>_DiffErr". The table is then
    left-joined with ``astroTable`` on ID, high-SNR bands are recounted on the averaged
    columns, and the result is written to {outputFolder}/differences_{pointName}.fits.

    Fixes relative to the previous version:
    - columns were picked with a positional stride that skipped every second filter and,
      from the third filter on, averaged another filter's columns;
    - a single exposure crashed;
    - invalid entries were left out of the count but still added to the sum;
    - the output path was missing a "/";
    - the columns passed to CountSNR were found positionally up to "Source_RA".
    """
    # Group per-exposure Diff columns by filter (text before the first "_"), first-seen order
    diffGroups = {}
    for colName in diffTable.colnames:
        if colName.endswith("_Diff"):
            diffGroups.setdefault(colName.split("_")[0], []).append(colName)
    tempTable = tbl.QTable([diffTable["ID"], diffTable["Total High SNR Bands"]])
    avgNames = []
    avgErrNames = []
    for filterName, memberCols in diffGroups.items():
        total = 0
        totalErr2 = 0
        count = 0
        for colName in memberCols:
            values = _FilledNaN(diffTable[colName])
            errors = _FilledNaN(diffTable[f"{colName}Err"])
            # Only strictly positive factors are used; NaN compares False
            valid = values > 0
            total = total + np.where(valid, values, 0)
            totalErr2 = totalErr2 + np.where(valid, errors**2, 0)
            count = count + np.where(valid, 1, 0)
        count = np.where(count > 0, count, np.nan)
        avgName = f"{filterName}_Diff"
        tempTable.add_columns(
            [total / count, np.sqrt(totalErr2) / np.sqrt(count)],
            names=(avgName, f"{avgName}Err"),
        )
        avgNames.append(avgName)
        avgErrNames.append(f"{avgName}Err")
    diffTable = tbl.join(tempTable, astroTable, keys="ID", join_type="left")
    diffTable.sort("ID")
    CountSNR(diffTable, avgNames, avgErrNames)
    diffTable.write(
        f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/differences_{pointName}.fits",
        overwrite=True,
    )
    return diffTable


# Function to handle slit-losses
def HandleSlitloss(lossData, photTables, astroTables, specTable):
    """Build ratio, slit-loss and difference-factor tables for every pointing.

    For each ``(pointNum, pointName)`` in ``zip(pointNums, lossPoints)`` this chains
    ``CalcSlitRatio`` -> ``CalcLossTable`` -> ``CalcDiffTable`` -> ``CalcAverageDiff``.

    Parameters
    ----------
    lossData : dict
        Slit-loss data keyed by pointing name.
    photTables, astroTables : dict
        Photometry and astrometry tables keyed by pointing number.
    specTable : table
        Table returned by ``LoopSpectra`` for the spectra, shared by all pointings.

    Returns
    -------
    tuple of dict
        ``(ratioTables, lossTables, diffTables)``: ratio and difference tables keyed by
        pointing number, loss tables keyed by pointing name.
    """
    outDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}"
    # Create the output folders from Python instead of `!bash -c "mkdir -p ..."`: that needs
    # a bash executable (not guaranteed on the Windows set-up GetDirStruct supports), the
    # unquoted path breaks on spaces, and IPython does not raise if the command fails.
    for subDir in ("RawOutput", os.path.join("plots", "SNRcriteria")):
        os.makedirs(os.path.join(outDir, subDir), exist_ok=True)
    lossTables = {}
    diffTables = {}
    ratioTables = {}
    for pointNum, pointName in tqdm(
        list(zip(pointNums, lossPoints)), desc="Calculating difference factors"
    ):
        ratioTable, fluxNames = CalcSlitRatio(photTables[pointNum], specTable, pointNum)
        # CalcLossTable returns the accumulated dict, now including this pointing's entry
        lossTables = CalcLossTable(
            lossData[pointName], fluxNames, lossTables, pointName
        )
        diffTable = CalcDiffTable(ratioTable, lossTables[pointName], pointName)
        diffTables[pointNum] = CalcAverageDiff(
            diffTable, astroTables[pointNum], pointName
        )
        ratioTables[pointNum] = ratioTable
    return ratioTables, lossTables, diffTables

# Function to import difference factors
def ImportDiff():
    """Load saved difference-factor tables (one FITS file per pointing) from the output folder.

    Files matching ``differences*.fits`` in
    ``../../Working_Directory/Apo_Phot_Utils/{outputFolder}/`` are considered, skipping any
    whose path contains ``detailed.fits``. A file belongs to a pointing when the text after
    the last underscore of its file name (minus ``.fits``) equals the pointing name, e.g.
    ``differences_p01.fits`` -> ``"p01"``.

    Returns
    -------
    dict
        Maps entries of ``pointNums`` to the ``QTable`` read for the matching entry of
        ``lossPoints``; pointings without a matching file are left out.
    """
    import glob

    # `diffList` used to come from a shell `ls ... | grep -v "detailed.fits"` that had been
    # commented out, so this function raised NameError. Build the same listing in Python
    # (sorted like `ls`, and independent of bash being available).
    outDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}"
    pattern = os.path.join(glob.escape(outDir), "differences*.fits")
    diffList = sorted(
        file for file in glob.glob(pattern) if "detailed.fits" not in file
    )
    diffTables = {}
    for pointName, pointNum in zip(lossPoints, pointNums):
        for file in diffList:
            if os.path.basename(file).split("_")[-1].split(".fits")[0] == pointName:
                diffTables[pointNum] = tbl.QTable.read(file)
    return diffTables


# Function to plot all difference factors by RA/Dec
def PlotDiffCoords(diffTables):
    """Plot difference factors against sky coordinates for every pointing.

    For each pointing in ``zip(lossPoints, pointNums)`` PNGs are written to
    ``../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/RA_Dec/``:

    * ``{RA|Dec}_{lin|log}_byFilter_*``: difference factor vs coordinate, coloured by filter;
    * ``{RA|Dec}_{lin|log}_byCoord_*``: difference factor vs filter wavelength, coloured by
      the coordinate;
    * ``Both_{filter}_*``: RA vs Dec, coloured by log10(difference factor) on a fixed
      0.2-5 scale.

    Only rows with ``Total High SNR Bands >= 2`` are plotted.

    Parameters
    ----------
    diffTables : dict
        Difference-factor tables keyed by pointing number. The columns between
        ``"Total High SNR Bands"`` and ``"Source_RA"`` are read as alternating value/error
        pairs (``<filter>_Diff``, ``<filter>_DiffErr``).
    """

    def _StyleDiffAxis(ax, plotLog):
        # Shared y-axis styling: unity reference line, label and lin/log limits
        ax.axhline(1)
        ax.set_ylabel("Difference Factor")
        if plotLog == "lin":
            ax.set_ylim(-1, 10)
        elif plotLog == "log":
            ax.set_yscale("log")
            ax.set_ylim(10 ** (-2), 10**2)

    print("Plotting coordinates...")
    plotDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/RA_Dec/"
    # Created from Python (once) rather than `!bash -c "mkdir -p"` per pointing, which
    # requires bash on PATH.
    os.makedirs(plotDir, exist_ok=True)
    for pointName, pointNum in tqdm(list(zip(lossPoints, pointNums)), desc="Pointings"):
        diffTable = diffTables[pointNum]
        colIndices = np.arange(
            diffTable.colnames.index("Total High SNR Bands") + 1,
            diffTable.colnames.index("Source_RA"),
        )
        colNames = diffTable.colnames[colIndices[0] : colIndices[-1] : 2]
        errNames = diffTable.colnames[colIndices[0] + 1 : colIndices[-1] + 1 : 2]
        filterLabels = [colName.split("_")[0] for colName in colNames]
        # Table column index of each filter's value column; used as the filter colour key
        filterIdx = colIndices[0::2]
        valTable = diffTable[colNames]
        errTable = diffTable[errNames]
        highSNR = np.asarray(diffTable["Total High SNR Bands"] >= 2)
        highSNRRows = np.flatnonzero(highSNR)

        for plotCoord in tqdm(["RA", "Dec"], desc="Coordinates", leave=False):
            coordCol = diffTable[f"Source_{plotCoord}"]
            for plotLog in tqdm(["lin", "log"], desc="Plots", leave=False):
                # --- Difference factor vs coordinate, one colour per filter ---
                cmap = mpl.colormaps["jet_r"]
                norm = mpl.colors.Normalize(vmin=min(filterIdx), vmax=max(filterIdx))
                fig, ax = plt.subplots()
                for galaxy in tqdm(highSNRRows, desc="Galaxies", leave=False):
                    allPos = np.array([coordCol[galaxy].value] * len(colNames))
                    allVals = np.array(list(valTable[galaxy]))
                    allErrs = np.array(list(errTable[galaxy]))
                    valid = ~np.isnan(allVals)
                    colours = cmap(norm(filterIdx[valid]))
                    ax.scatter(
                        allPos[valid], allVals[valid], c=colours, s=2, zorder=100
                    )
                    ax.errorbar(
                        allPos[valid],
                        allVals[valid],
                        yerr=allErrs[valid],
                        linestyle="",
                        marker=None,
                        ecolor=colours,
                        mew=0,
                        lw=0.5,
                        zorder=0,
                    )
                # Build the colour bar from the points' own cmap/norm (plt.colorbar() on
                # explicit RGBA colours showed the default colormap instead) and put one
                # tick at each filter so the labels line up.
                cbar = fig.colorbar(
                    mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=filterIdx
                )
                cbar.set_ticklabels(filterLabels)
                cbar.ax.set_title("Filter")
                ax.set_xlabel(f"{plotCoord} ({u.deg})")
                _StyleDiffAxis(ax, plotLog)
                ax.set_title(f"Difference Factors vs {plotCoord}", wrap=True)
                fig.savefig(
                    f"{plotDir}{plotCoord}_{plotLog}_byFilter_{pointNum}-{pointName}.png",
                    dpi=300,
                )
                plt.close(fig)

                # --- Difference factor vs wavelength, coloured by the coordinate ---
                cmap = mpl.colormaps["plasma"]
                norm = mpl.colors.Normalize(
                    vmin=min(coordCol.value), vmax=max(coordCol.value)
                )
                coordColours = cmap(norm(coordCol[highSNR]))
                fig, ax = plt.subplots()
                for colName in tqdm(colNames, desc="Filters", leave=False):
                    validVals = diffTable[colName][highSNR]
                    validErrs = diffTable[f"{colName}Err"][highSNR]
                    # Nominal wavelength in nm read from the filter name, e.g. F444W -> 4440
                    validPos = np.full(len(validVals), float(colName[1:4]) * 10)
                    ax.scatter(validPos, validVals, c=coordColours, s=2)
                    ax.errorbar(
                        validPos,
                        validVals,
                        yerr=validErrs,
                        linestyle="",
                        marker=None,
                        ecolor=coordColours,
                        mew=0,
                        lw=0.5,
                        zorder=0,
                    )
                cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
                cbar.ax.set_title(f"{plotCoord} ({u.degree})")
                ax.set_xlabel("Wavelength (nm)")
                _StyleDiffAxis(ax, plotLog)
                ax.set_title("Difference Factors vs Wavelength", wrap=True)
                fig.savefig(
                    f"{plotDir}{plotCoord}_{plotLog}_byCoord_{pointNum}-{pointName}.png",
                    dpi=300,
                )
                plt.close(fig)

        # --- RA vs Dec per filter, coloured by log10(difference factor) ---
        # A fixed 0.2-5 scale is used for both the points and the colour bar (previously the
        # points used a data-driven norm while the colour bar was clamped to 0.2-5, and the
        # data-driven norm failed on all-NaN columns or -inf from zero factors).
        cmap = mpl.colormaps["gist_rainbow"]
        norm = mpl.colors.Normalize(vmin=np.log10(0.2), vmax=np.log10(5))
        for colName in tqdm(colNames, desc="Filters", leave=False):
            filterLabel = colName.split("_")[0]
            # Non-positive factors give -inf/NaN: out-of-range values take the colormap's
            # end colours and NaN its "bad" colour, so only the warnings are silenced here.
            with np.errstate(divide="ignore", invalid="ignore"):
                logDiff = np.log10(diffTable[colName][highSNR])
            fig, ax = plt.subplots()
            ax.scatter(
                diffTable["Source_RA"][highSNR],
                diffTable["Source_Dec"][highSNR],
                c=cmap(norm(logDiff)),
                s=2,
            )
            cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
            cbar.ax.set_title("Log10(Difference factor)")
            ax.set_xlabel(f"RA ({u.deg})")
            ax.set_ylabel(f"Dec ({u.deg})")
            ax.set_title(f"RA vs Dec for {filterLabel}", wrap=True)
            fig.savefig(
                f"{plotDir}Both_{filterLabel}_{pointNum}-{pointName}.png",
                dpi=300,
            )
            plt.close(fig)
    return


# Function to plot all difference factors by offset
def PlotDiffOffsets(diffTables):
    """Plot difference factors against the ``Offset_x``/``Offset_y`` columns for every pointing.

    For each pointing in ``zip(lossPoints, pointNums)`` PNGs are written to
    ``../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/offsets/``:

    * ``{Offset_x|Offset_y}_{lin|log}_byFilter_*``: difference factor vs offset, coloured
      by filter;
    * ``{Offset_x|Offset_y}_{lin|log}_byOffset_*``: difference factor vs filter wavelength,
      coloured by the offset;
    * ``Both_{filter}_*``: Offset_y vs Offset_x, coloured by log10(difference factor) on a
      fixed 0.2-5 scale.

    Only rows with ``Total High SNR Bands >= 2`` are plotted.

    Parameters
    ----------
    diffTables : dict
        Difference-factor tables keyed by pointing number. The columns between
        ``"Total High SNR Bands"`` and ``"Source_RA"`` are read as alternating value/error
        pairs.
    """

    def _StyleDiffAxis(ax, plotLog):
        # Shared y-axis styling: unity reference line, label and lin/log limits
        ax.axhline(1)
        ax.set_ylabel("Difference Factor")
        if plotLog == "lin":
            ax.set_ylim(-1, 10)
        elif plotLog == "log":
            ax.set_yscale("log")
            ax.set_ylim(10 ** (-2), 10**2)

    print("Plotting offsets...")
    plotDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/offsets/"
    # Created from Python (once) rather than `!bash -c "mkdir -p"` per pointing, which
    # requires bash on PATH.
    os.makedirs(plotDir, exist_ok=True)
    offsetAxes = [("Offset_x", "Offset (x)"), ("Offset_y", "Offset (y)")]
    for pointName, pointNum in tqdm(list(zip(lossPoints, pointNums)), desc="Pointings"):
        diffTable = diffTables[pointNum]
        colIndices = np.arange(
            diffTable.colnames.index("Total High SNR Bands") + 1,
            diffTable.colnames.index("Source_RA"),
        )
        colNames = diffTable.colnames[colIndices[0] : colIndices[-1] : 2]
        errNames = diffTable.colnames[colIndices[0] + 1 : colIndices[-1] + 1 : 2]
        filterLabels = [colName.split("_")[0] for colName in colNames]
        # Table column index of each filter's value column; used as the filter colour key
        filterIdx = colIndices[0::2]
        valTable = diffTable[colNames]
        errTable = diffTable[errNames]
        highSNR = np.asarray(diffTable["Total High SNR Bands"] >= 2)
        highSNRRows = np.flatnonzero(highSNR)

        for plotOffset, offsetLabel in tqdm(offsetAxes, desc="Offsets", leave=False):
            offsetCol = diffTable[plotOffset]
            for plotLog in tqdm(["lin", "log"], desc="Plots", leave=False):
                # --- Difference factor vs offset, one colour per filter ---
                cmap = mpl.colormaps["jet_r"]
                norm = mpl.colors.Normalize(vmin=min(filterIdx), vmax=max(filterIdx))
                fig, ax = plt.subplots()
                for galaxy in tqdm(highSNRRows, desc="Galaxies", leave=False):
                    allPos = np.array([offsetCol[galaxy]] * len(filterIdx))
                    allVals = np.array(list(valTable[galaxy]))
                    allErrs = np.array(list(errTable[galaxy]))
                    # NaN factors would not be drawn anyway; dropping them keeps the
                    # colours aligned with the filters that are drawn
                    valid = ~np.isnan(allVals)
                    colours = cmap(norm(filterIdx[valid]))
                    ax.scatter(
                        allPos[valid], allVals[valid], c=colours, s=2, zorder=100
                    )
                    ax.errorbar(
                        allPos[valid],
                        allVals[valid],
                        yerr=allErrs[valid],
                        linestyle="",
                        marker=None,
                        ecolor=colours,
                        mew=0,
                        lw=0.5,
                        zorder=0,
                    )
                # Colour bar from the points' own cmap/norm, one tick (and label) per filter.
                # The previous labels used colNames[0:-1], silently dropping the last filter.
                cbar = fig.colorbar(
                    mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=filterIdx
                )
                cbar.set_ticklabels(filterLabels)
                cbar.ax.set_title("Filter")
                ax.set_xlabel(offsetLabel)
                _StyleDiffAxis(ax, plotLog)
                ax.set_title(f"Difference Factors vs {offsetLabel}", wrap=True)
                fig.savefig(
                    f"{plotDir}{plotOffset}_{plotLog}_byFilter_{pointNum}-{pointName}.png",
                    dpi=300,
                )
                plt.close(fig)

                # --- Difference factor vs wavelength, coloured by the offset ---
                cmap = mpl.colormaps["plasma"]
                norm = mpl.colors.Normalize(vmin=min(offsetCol), vmax=max(offsetCol))
                offsetColours = cmap(norm(offsetCol[highSNR]))
                fig, ax = plt.subplots()
                for colName, errName in tqdm(
                    list(zip(colNames, errNames)), desc="Filters", leave=False
                ):
                    validVals = diffTable[colName][highSNR]
                    validErrs = diffTable[errName][highSNR]
                    # Nominal wavelength in nm read from the filter name, e.g. F444W -> 4440
                    validPos = np.full(len(validVals), float(colName[1:4]) * 10)
                    ax.scatter(
                        validPos, validVals, c=offsetColours, s=2, zorder=100
                    )
                    ax.errorbar(
                        validPos,
                        validVals,
                        yerr=validErrs,
                        linestyle="",
                        marker=None,
                        ecolor=offsetColours,
                        mew=0,
                        lw=0.5,
                        zorder=0,
                    )
                cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
                cbar.ax.set_title(offsetLabel)
                ax.set_xlabel("Wavelength (nm)")
                _StyleDiffAxis(ax, plotLog)
                ax.set_title("Difference Factors vs Wavelength", wrap=True)
                fig.savefig(
                    f"{plotDir}{plotOffset}_{plotLog}_byOffset_{pointNum}-{pointName}.png",
                    dpi=300,
                )
                plt.close(fig)

        # --- Offset_y vs Offset_x per filter, coloured by log10(difference factor) ---
        # Same fixed 0.2-5 scale for points and colour bar (they previously disagreed).
        cmap = mpl.colormaps["gist_rainbow"]
        norm = mpl.colors.Normalize(vmin=np.log10(0.2), vmax=np.log10(5))
        for colName in tqdm(colNames, desc="Filters", leave=False):
            filterLabel = colName.split("_")[0]
            # Non-positive factors give -inf/NaN; matplotlib draws them with the colormap's
            # end/"bad" colours, so only the warnings are silenced.
            with np.errstate(divide="ignore", invalid="ignore"):
                logDiff = np.log10(diffTable[colName][highSNR])
            fig, ax = plt.subplots()
            ax.scatter(
                diffTable["Offset_x"][highSNR],
                diffTable["Offset_y"][highSNR],
                c=cmap(norm(logDiff)),
                s=2,
            )
            cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
            cbar.ax.set_title("Log10(Difference factor)")
            ax.set_xlabel("Offset (x)")
            ax.set_ylabel("Offset (y)")
            ax.set_title(f"Offset (y) vs Offset (x) for {filterLabel}", wrap=True)
            fig.savefig(
                f"{plotDir}Both_{filterLabel}_{pointNum}-{pointName}.png",
                dpi=300,
            )
            plt.close(fig)
    return


# Function for plotting galaxy wavelength dependence
def PlotDiffGals(diffTables):
    """Plot each galaxy's difference factor against filter wavelength.

    For every pointing in ``zip(lossPoints, pointNums)``, and for both a linear and a log
    y-axis, one PNG per selection criterion is written to
    ``../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/galaxies/``:

    * ``2-sigma``: rows with ``Total High SNR Bands >= 2``;
    * ``all-bands-observed``: rows with ``Total High SNR Bands`` >= the number of filters.

    Each galaxy is drawn as one line, coloured by its row index.

    Parameters
    ----------
    diffTables : dict
        Difference-factor tables keyed by pointing number. The columns between
        ``"Total High SNR Bands"`` and ``"Source_RA"`` are read as alternating value/error
        pairs.
    """
    print("Plotting galaxies...")
    plotDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/galaxies/"
    # Created from Python (once) rather than `!bash -c "mkdir -p"` per pointing, which
    # requires bash on PATH.
    os.makedirs(plotDir, exist_ok=True)
    for pointName, pointNum in tqdm(list(zip(lossPoints, pointNums)), desc="Pointings"):
        diffTable = diffTables[pointNum]
        colIndices = np.arange(
            diffTable.colnames.index("Total High SNR Bands") + 1,
            diffTable.colnames.index("Source_RA"),
        )
        colNames = diffTable.colnames[colIndices[0] : colIndices[-1] : 2]
        errNames = diffTable.colnames[colIndices[0] + 1 : colIndices[-1] + 1 : 2]
        valTable = diffTable[colNames]
        errTable = diffTable[errNames]
        nHighSNR = diffTable["Total High SNR Bands"]
        nGalaxies = len(diffTable["ID"])
        # Nominal wavelengths in nm from the filter names (F444W -> 4440); identical for
        # every galaxy, so computed once per pointing
        wavelengths = np.array([float(colName[1:4]) * 10 for colName in colNames])
        cmap = mpl.colormaps["jet_r"]
        norm = mpl.colors.Normalize(vmin=0, vmax=nGalaxies)
        # De-duplicate: with exactly two filters both criteria coincide and the same
        # "2-sigma" plot would otherwise be drawn and saved twice
        conditions = list(dict.fromkeys([2, len(colNames)]))
        for plotLog in tqdm(["lin", "log"], desc="Plots", leave=False):
            for condition in tqdm(conditions, desc="Selection criteria", leave=False):
                fig, ax = plt.subplots()
                for galaxy in tqdm(range(nGalaxies), desc="Galaxies", leave=False):
                    if nHighSNR[galaxy] < condition:
                        continue
                    colour = cmap(norm(galaxy))
                    allVals = np.array(list(valTable[galaxy]))
                    allErrs = np.array(list(errTable[galaxy]))
                    # `color=` rather than `c=`: a single RGBA tuple passed as `c` is
                    # colormapped as data when there are exactly 3 or 4 points
                    ax.scatter(wavelengths, allVals, color=colour, s=2, zorder=200)
                    ax.plot(wavelengths, allVals, color=colour, lw=0.5, zorder=100)
                    ax.errorbar(
                        wavelengths,
                        allVals,
                        yerr=allErrs,
                        color=colour,
                        linestyle="",
                        marker=None,
                        mew=0,
                        lw=0.5,
                        zorder=0,
                    )
                if plotLog == "lin":
                    ax.set_ylim(0, 5)
                elif plotLog == "log":
                    ax.set_yscale("log")
                    ax.set_ylim(10 ** (-2), 10**2)
                ax.set_xlabel("Wavelength (nm)")
                ax.set_ylabel("Difference Factor")
                ax.axhline(1)
                selCrit = "2-sigma" if condition == 2 else "all-bands-observed"
                ax.set_title(
                    f"Difference Factor vs Wavelength, with {selCrit} selection criteria",
                    wrap=True,
                )
                fig.savefig(
                    f"{plotDir}galaxies_{selCrit}_{plotLog}_{pointNum}-{pointName}.png",
                    dpi=300,
                )
                plt.close(fig)
    return


# Function to handle plotting of difference factors
def PlotDiff(diffTables):
    """Run every difference-factor plotter: by sky coordinate, by offset, then per galaxy."""
    for plotter in (PlotDiffCoords, PlotDiffOffsets, PlotDiffGals):
        plotter(diffTables)

In [None]:
# Directories and file-name conventions.
# These are cheap constants, so they are always (re)defined. The former
# `try: ... except NameError:` guard also tested `lossPrefix`, which nothing in this
# notebook section defines (the prefixes are `lossPrefix01`/`lossPrefix02`), so it fell
# through on every run anyway. Defining them unconditionally makes that explicit and discards
# values left behind by cells that reassign these globals (e.g. the SMACS ERC test cell).
print("Defining directories...")
filterFolder = "Throughputs/nircam_throughputs/mean_throughputs/"
specFolder = "Spectra/HST_Deep/prism_v1.5/"
outputFolder = "HST_Deep/prism_v1.5/"
photDir = "../../Working_Directory/Apo_Phot/HST_Deep/"
astroDir = "HST_Deep/for_emma_v2a_clean2/prism_trial_01_v2a_clean2/"
slitFolder = "m_make_output/"
slitPrefix = "table"
slitSuffix = ".txt"
slitNums = [3, 2, 1]
pointFolder = "m_check_output/triple_1/"
pointPrefix = "pointing_"
pointSuffix = "/ds9_targ_regions.txt"
pointNums = [7, 8, 13]
lossFolder = "Slit-losses/HST/pathlosses_deep_R100/R100_v0/"
lossPoints = ["p01", "p02", "p03"]
lossPrefix01 = "pathlosses_correction_deep_hst_1x1_"
lossPrefix02 = "_idcat"
lossExps = ["exp00", "exp01", "exp02"]
lossSuffix = "_v0_point.txt"

# Filters (cached: importing is slow)
try:
    filterData
except NameError:
    filterDir, filterList, filterNames = GetFilter()
    filterData = ImportFilter(filterDir, filterList, filterNames)

# Redshift catalogue. The path and column names are constants and are always reset (the
# test cell overrides redshiftFolder/redshiftID/redshiftZ); only the file read is cached.
redshiftFolder = GetDirStruct() + "Redshifts/Deep_HST.csv"
redshiftID = "ID"
redshiftZ = "Assigned_redshift"
try:
    redshiftFile
except NameError:
    redshiftFile = tbl.Table.read(redshiftFolder)

# Spectra (cached)
try:
    specData
    specTable
except NameError:
    specDir, specList, specNames = GetSpec()
    specData = ImportSpec(specDir, specList, specNames)
    specTable = LoopSpectra(specData, filterData, redshiftFile)

# Plots
# Comment out if you want to disable
if not os.path.isdir(f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}/plots/spectra"):
    PlotSpec(specData, specList, specNames)

# Photometry (cached)
try:
    photTables
except NameError:
    photTables = ImportPhot()

# Astrometry (cached)
try:
    astroTables
except NameError:
    astroTables = ImportAstro()

# Raw slit-loss data (cached)
try:
    lossData
except NameError:
    lossDir, lossList = GrabSlitloss()
    lossData = ImportSlitloss(lossDir, lossList, photTables)

# Slit-loss ratios and difference factors (cached)
try:
    diffTables
except NameError:
    ratioTables, lossTables, diffTables = HandleSlitloss(lossData, photTables, astroTables, specTable)

# Plotting of difference factors
PlotDiff(diffTables)

print("\nDone.")

In [126]:
# Helper: look up one galaxy's redshift in a redshift catalogue
def _LookupRedshift(galaxy, redshiftFile, idCol, zCol):
    """Return ``redshiftFile[zCol]`` at the first row whose ``idCol`` equals ``int(galaxy)``.

    Raises
    ------
    KeyError
        If no row of ``redshiftFile[idCol]`` matches the galaxy.
    """
    matches = np.flatnonzero(np.asarray(redshiftFile[idCol]) == int(galaxy))
    if matches.size == 0:
        raise KeyError(f"Galaxy {galaxy} not found in redshift catalogue column '{idCol}'")
    return redshiftFile[zCol][matches[0]]


# Function to redshift input spectra
def RedshiftSpectra(specData, redshiftFile, idCol=None, zCol=None):
    """Apply each galaxy's redshift to its spectrum, in place.

    Wavelengths are multiplied by ``1 + z``; fluxes and errors are divided by ``1 + z``.
    Previously the ``redshiftFile`` argument was ignored and the global ``redshiftFileTest``
    was used instead.

    Parameters
    ----------
    specData : dict
        Maps galaxy ID (anything ``int()`` accepts) to a mapping with ``"Wavelength"``,
        ``"Flux"`` and ``"Error"`` arrays, which are modified in place.
    redshiftFile : table-like
        Catalogue indexed by column name.
    idCol, zCol : str, optional
        ID and redshift column names; default to the globals ``redshiftID``/``redshiftZ``.
    """
    idCol = redshiftID if idCol is None else idCol
    zCol = redshiftZ if zCol is None else zCol
    for galaxy in specData:
        scale = 1 + _LookupRedshift(galaxy, redshiftFile, idCol, zCol)
        specData[galaxy]["Wavelength"] *= scale
        specData[galaxy]["Flux"] /= scale
        specData[galaxy]["Error"] /= scale
    return


# Function to remove redshift from input spectra
def BlueshiftSpectra(specData, redshiftFile, idCol=None, zCol=None):
    """Undo each galaxy's redshift on its spectrum, in place (inverse of RedshiftSpectra).

    Wavelengths are divided by ``1 + z``; fluxes and errors are multiplied by ``1 + z``.
    Previously the ``redshiftFile`` argument was ignored and the global ``redshiftFileTest``
    was used instead. Parameters are as for ``RedshiftSpectra``.
    """
    idCol = redshiftID if idCol is None else idCol
    zCol = redshiftZ if zCol is None else zCol
    for galaxy in specData:
        scale = 1 + _LookupRedshift(galaxy, redshiftFile, idCol, zCol)
        specData[galaxy]["Wavelength"] /= scale
        specData[galaxy]["Flux"] *= scale
        specData[galaxy]["Error"] *= scale
    return

In [146]:
# Validation run on the SMACS ERC test spectra.
# NOTE(review): GetSpec reads the global `specFolder`, and the other helpers presumably read
# the other globals set here, so this cell overwrites specFolder, filterFolder,
# outputFolder, redshiftFolder, redshiftID and redshiftZ for the rest of the session.
# Re-run the main pipeline cell afterwards if you need its settings back.

# Grab test spectra
specFolder = "Spectra/SMACS_ERC/"
specDirTest, specListTest, specNamesTest = GetSpec()
specDataTest = ImportSpec(specDirTest, specListTest, specNamesTest)

# Find redshifts
redshiftFolder = GetDirStruct() + "NIRSpec/Glamdring/Data/NIRCamSpec.fits"
redshiftID = "ID"
redshiftZ = "Z_INPUT"
redshiftFileTest = tbl.QTable.read(redshiftFolder)

# Grab filters
filterFolder = "Throughputs/nircam_throughputs/mean_throughputs/"
filterDirTest, filterListTest, filterNamesTest = GetFilter()
filterDataTest = ImportFilter(filterDirTest, filterListTest, filterNamesTest)

# Calculate throughputs
# NOTE(review): other output folders end with "/" (e.g. "HST_Deep/prism_v1.5/"); confirm
# "testData" without a trailing slash is intended.
outputFolder = "testData"
# BlueshiftSpectra(specDataTest, redshiftFileTest)
specTableTest = LoopSpectra(specDataTest, filterDataTest, redshiftFileTest)

# Grab test fluxes. This is the same catalogue file as the redshifts above, so reuse its path
# (the hand-built f"{GetDirStruct()}/NIRSpec/..." version inserted a doubled slash). It is
# read as a separate table so edits to testData do not affect redshiftFileTest.
testData = tbl.QTable.read(redshiftFolder)



Importing spectra:   0%|          | 0/3 [00:00<?, ?it/s]

Importing filters:   0%|          | 0/30 [00:00<?, ?it/s]

Calculating throughputs:   0%|          | 0/3 [00:00<?, ?it/s]

  result = super().__array_ufunc__(function, method, *arrays, **kwargs)


In [151]:
# Display the table returned by LoopSpectra for the test spectra
specTableTest

ID,F070W Throughput,F070W Throughput Error,F090W Throughput,F090W Throughput Error,F115W Throughput,F115W Throughput Error,F140M Throughput,F140M Throughput Error,F150W2 Throughput,F150W2 Throughput Error,F150W Throughput,F150W Throughput Error,F162M Throughput,F162M Throughput Error,F164N Throughput,F164N Throughput Error,F182M Throughput,F182M Throughput Error,F187N Throughput,F187N Throughput Error,F200W Throughput,F200W Throughput Error,F210M Throughput,F210M Throughput Error,F212N Throughput,F212N Throughput Error,F250M Throughput,F250M Throughput Error,F277W Throughput,F277W Throughput Error,F300M Throughput,F300M Throughput Error,F322W2 Throughput,F322W2 Throughput Error,F323N Throughput,F323N Throughput Error,F335M Throughput,F335M Throughput Error,F356W Throughput,F356W Throughput Error,F360M Throughput,F360M Throughput Error,F405N Throughput,F405N Throughput Error,F410M Throughput,F410M Throughput Error,F430M Throughput,F430M Throughput Error,F444W Throughput,F444W Throughput Error,F460M Throughput,F460M Throughput Error,F466N Throughput,F466N Throughput Error,F470N Throughput,F470N Throughput Error,F480M Throughput,F480M Throughput Error,WLP4 Throughput,WLP4 Throughput Error,Balmer_left,Balmer_left_err,Balmer_right,Balmer_right_err,Balmer_ratio,Balmer_ratio_err
Unnamed: 0_level_1,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,nJy,W / (J m3),W / (J m3),W / (J m3),W / (J m3),Unnamed: 65_level_1,Unnamed: 66_level_1
int64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64
4590,,,,,,,,,,,,,,,,,,,,,,,,,,,,,37.71038985116001,3.073690031475175,35.77300854666388,2.8936671761435058,40.21518792883823,1.3845605095866502,56.78585924891328,6.4780355095499775,39.37895422444612,2.213085108295835,41.04774564308435,1.5452902934916064,39.8792042690869,2.185859748531461,33.046876270823105,7.445735474695367,40.92732533318183,2.2489958603085283,1.9657843776865676,1.492221721914653,101.03657154824144,2.2863622385461704,224.80045417975967,6.745124395456287,107.10454545016496,13.794236259591807,259.6308902443005,16.53770970729897,253.0191665788235,7.260753503242993,,,177685.6210450158,148971.31919369058,129321.83906916418,142106.14987790378,0.7278126294552619,1.0059616996059095
6355,,,,,,,,,,,,,,,,,,,,,,,,,,,,,106.0319035635302,3.143860266466537,106.21231413181656,2.9674075289135105,111.45477218773426,1.3931916454448494,286.0968553306949,7.491851985200487,140.5391473177916,2.244326026549833,114.78971066310524,1.5481098068323431,100.84296418957828,2.2085431609050605,49.15967640917487,7.446610211692012,195.28748518837736,2.584298932945875,810.8603297463145,4.836466431790129,236.597840905887,1.914412193193201,68.7210219987567,4.660225070540485,85.74328140726593,10.99135457205773,20.204279626028296,4.287482305955082,17.51351357793429,1.7971895903035129,,,471339.444631515,169520.50719672305,701483.0728196707,150695.85460395733,1.4882757656068395,0.6234840887451969
10612,,,,,,,,,,,,,,,,,,,,,,,,,,,,,70.63568975214622,2.9876179084682013,69.51552944990304,2.8169401939642764,66.20188935180802,1.360228222209329,67.46232906024098,6.400101067957965,71.89227640943264,2.213128477691793,66.37699767501319,1.524815875081814,65.97116192341063,2.1645444649165624,70.25454850520309,7.481483601776743,102.47382382942216,2.5574316597490783,374.89541469958834,4.326861243613829,114.29938192782951,1.884581881237786,24.86051663308393,2.889662120435564,0.1110602764555498,0.0186890320450371,0.0003499319239318,0.0001013507189011,19.680954229569167,3.4828524199541744,,,299966.44593309634,163593.65904500702,378032.8242274929,147252.83990984966,1.260250369175652,0.8446130200273309


In [None]:
# Functions to calculate average offset corrections


def _DiffColumnPairs(diffTable):
    """Return ``(valueNames, errorNames)`` for one pointing's difference table.

    The columns strictly between "Total High SNR Bands" and "Source_RA" are
    assumed to alternate value, error, value, error, ... (value first).

    Raises
    ------
    ValueError
        If there are no columns between the two marker columns.
    """
    names = diffTable.colnames
    start = names.index("Total High SNR Bands") + 1
    stop = names.index("Source_RA")
    if stop <= start:
        raise ValueError('No difference columns between "Total High SNR Bands" and "Source_RA".')
    return names[start:stop:2], names[start + 1 : stop : 2]


def CalcOffCorr(diffTables, pointings=None):
    """Average difference (offset) factors per pointing and across pointings,
    then fit a straight line of averaged correction factor vs. wavelength.

    Parameters
    ----------
    diffTables : mapping or sequence of astropy tables
        Indexed by pointing number. Each table needs a "Total High SNR Bands"
        column, a "Source_RA" column, and between them alternating
        value/error column pairs whose names split on "_" into e.g.
        "F090W" + "Diff" / "DiffErr" (naming inferred from the parsing below).
    pointings : iterable, optional
        Pointing numbers to process. Defaults to the notebook-global
        ``pointNums`` so existing calls keep working.

    Returns
    -------
    corrTable : astropy.table.QTable
        One row per pointing, plus a final averaged row whose "Pointing"
        entry is NaN. Columns are "<value>_corr", "<error>_corr" pairs.
    fitResults : scipy.odr.Output
        Result of an ``odr.unilinear`` fit of the averaged correction factors
        (with their errors as ``sy``) against wavelength in nm.

    Raises
    ------
    ValueError
        If no pointings are given, a table has no difference columns, or a
        correction column is neither a "Diff" nor a "DiffErr" column.
    """
    if pointings is None:
        pointings = pointNums  # notebook-global default, as before
    pointings = list(pointings)
    if not pointings:
        raise ValueError("CalcOffCorr needs at least one pointing.")

    corrRows = []
    corrNames = None
    print("Averaging correction factors within pointings...")
    for pointNum in tqdm(pointings, desc="Pointings"):
        diffTable = diffTables[pointNum]
        colNames, errNames = _DiffColumnPairs(diffTable)
        if corrNames is None:
            # Output column layout comes from the first pointing, interleaved
            # as value, error, value, error, ...
            corrNames = ["Pointing"] + [
                name + "_corr" for pair in zip(colNames, errNames) for name in pair
            ]

        # Only sources detected at high SNR in the maximum number of bands
        # (within this pointing) contribute; computed once per pointing.
        snrBands = diffTable["Total High SNR Bands"]
        keep = snrBands >= max(snrBands)

        corrRow = [pointNum]
        for colName, errName in tqdm(list(zip(colNames, errNames)), desc="Averaging difference factors", leave=False):
            colFacts = np.where(keep, diffTable[colName], np.nan)
            colErrs = np.where(keep, diffTable[errName], np.nan)
            valid = ~np.isnan(colErrs)
            colFacts = colFacts[valid]
            colErrs = colErrs[valid]
            # Mean value and standard error of the mean: sqrt(sum(err^2)) / N.
            corrRow += [np.mean(colFacts), np.sqrt(np.mean(colErrs**2) / len(colErrs))]
        corrRows.append(corrRow)
    corrTable = tbl.QTable(rows=corrRows, names=corrNames)

    print("Averaging correction factors across pointings...")
    nPointings = len(corrTable)
    newRow = [np.nan]
    for colName in tqdm(corrTable.colnames[1:], desc="Photometry bands"):
        kind = colName.split("_")[1]
        if kind == "Diff":
            newRow.append(np.mean(corrTable[colName]))
        elif kind == "DiffErr":
            newRow.append(np.sqrt(np.mean(corrTable[colName] ** 2) / nPointings))
        else:
            # Skipping silently would misalign newRow against corrNames.
            raise ValueError(f'Unexpected correction column "{colName}"; expected a "Diff" or "DiffErr" column.')
    aveTable = tbl.QTable(rows=[newRow], names=corrNames)
    corrTable = tbl.vstack([corrTable, aveTable])

    print("Calculating line of best fit")
    # The digits of the filter name (e.g. "090" in "F090W") are read as a
    # wavelength in units of 10 nm, giving x in nm.
    fitWave = [float(name.split("_")[0][1:-1]) * 10 for name in corrTable.colnames[1::2]]
    fitCorr = [corrFact.value for corrFact in newRow[1::2]]
    fitErr = [corrErr.value for corrErr in newRow[2::2]]
    fitData = odr.RealData(fitWave, fitCorr, sy=fitErr)
    # NOTE(review): wavelengths carry no uncertainty here; ordinary least
    # squares (fitObj.set_job(fit_type=2)) may be more appropriate. Left as
    # explicit ODR to preserve existing results.
    fitObj = odr.ODR(fitData, odr.unilinear)
    fitResults = fitObj.run()
    print("Done.")
    return corrTable, fitResults

In [None]:
# Per-pointing + averaged offset corrections, and the straight-line fit result.
corrTable, corrResults = CalcOffCorr(diffTables=diffTables)

In [None]:
# Print scipy.odr's text summary of the straight-line fit.
corrResults.pprint();

In [None]:
# Wavelengths (nm) of the value (non-error) correction columns only, i.e. the
# same x values CalcOffCorr fits: "F090W" -> 90 * 10 = 900 nm.
np.array([float(name.split("_")[0][1:-1]) * 10 for name in corrTable.colnames[1::2]])

In [None]:
# Cross-check of the ODR fit: unweighted least-squares straight line through
# the averaged correction factors. CalcOffCorr appends the averaged row last,
# so use corrTable[-1] rather than a hard-coded row index. Only the value
# columns ([1::2]) are fitted; the error columns are not data points.
polyWaveNm = np.array([float(name.split("_")[0][1:-1]) * 10 for name in corrTable.colnames[1::2]])
polyCorrAve = np.array([item.value for item in list(corrTable[-1])[1::2]])
# Polynomial.fit stores coefficients in a scaled window; convert() expresses
# them in the original x units (nm).
print(np.polynomial.Polynomial.fit(polyWaveNm, polyCorrAve, 1).convert())