In [None]:
print("Preparing environment...")

import glob
import os
import platform
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from astropy import table as tbl
from astropy import units as u
from astropy.io import fits
from scipy import constants as c
from scipy import interpolate
from tqdm.auto import tqdm

# %matplotlib widget

print("Done.")

In [None]:
### All functions


def RemoveNaNs(tableObj):
    """Replace NaN entries of every floating-point column with zero, in place.

    Only the individual NaN cells are overwritten (not whole rows), so later
    interpolation/integration steps do not propagate NaNs.

    Parameters
    ----------
    tableObj : table-like
        Object exposing ``colnames`` and ``tableObj[col]`` item access that returns
        a mutable array (e.g. an astropy ``Table``/``QTable``).

    Returns
    -------
    The same ``tableObj``, modified in place.
    """
    for col in tableObj.colnames:
        column = tableObj[col]
        # np.asarray strips any Quantity/Column wrapper so isnan sees plain values.
        values = np.asarray(column)
        # Integer columns cannot hold NaN and np.isnan raises on string columns,
        # so only float/complex columns are scanned.
        if values.dtype.kind not in "fc":
            continue
        nanMask = np.isnan(values)
        if nanMask.any():
            # A bare 0 is written, exactly as the element-wise version did.
            column[nanMask] = 0
    return tableObj


def GetDirStruct():
    """Return the data-drive root prefix for the operating system in use.

    Raises
    ------
    Exception
        If ``platform.system()`` reports an OS with no configured prefix.
    """
    osName = platform.system()
    prefixByOS = {
        "Windows": "D:/",
        "Linux": "/mnt/d/",
        "Ubuntu": "/mnt/d/",
        "macOS": "/Volumes/Storage/",
        "Darwin": "/Volumes/Storage/",
    }
    if osName not in prefixByOS:
        raise Exception(f"OS not recognised: \"{osName}\". Please define a custom switch inside GetDirStruct().")
    return prefixByOS[osName]


def GetSpec():
    """List the 1D spectrum FITS files in the configured spectra folder.

    Uses the module-level ``specFolder`` (relative to ``GetDirStruct()``) and
    matches file names ending in ``1D.fits``.

    Returns
    -------
    specDir : str
        Directory holding the spectra (drive prefix + ``specFolder``).
    specList : list of str
        Sorted file names (without directory) of the matching spectra.
    specNames : list of str
        The part of each file name before its first ``_`` (used as galaxy ID).
    """
    specDir = GetDirStruct() + specFolder
    # glob instead of a shell `ls`: portable, safe with spaces in paths, and a
    # missing match simply yields an empty list.
    specPaths = sorted(glob.glob(glob.escape(specDir) + "*1D.fits"))
    specList = [os.path.basename(path) for path in specPaths]
    specNames = [file.split('_')[0] for file in specList]
    return specDir, specList, specNames


def ImportSpec(specDir, specList, specNames, fluxHDU=1, waveHDU=9):
    """Read spectra from FITS files into a dict of QTables keyed by galaxy name.

    Parameters
    ----------
    specDir : str
        Directory prefix prepended to every file name.
    specList, specNames : list of str
        Parallel lists of file names and the dictionary keys to store them under.
    fluxHDU, waveHDU : int, optional
        HDU indices whose data are read as flux (tagged Jy) and wavelength
        (tagged metres). Defaults 1 and 9 match the previous hard-coded layout.

    Returns
    -------
    dict
        ``name -> QTable`` with ``Wavelength`` and ``Flux`` columns, NaNs zeroed.
    """
    specData = {}
    for file, name in tqdm(zip(specList, specNames), desc="Importing spectra", total=len(specList)):
        # Context manager closes the file even if reading an HDU fails.
        with fits.open(specDir + file) as specRaw:
            specFlux = specRaw[fluxHDU].data * u.Jy
            specWave = specRaw[waveHDU].data * u.m
            specTable = tbl.QTable([specWave, specFlux], names=("Wavelength", "Flux"))
        specData[name] = RemoveNaNs(specTable)
    return specData


def PlotSpec(specData, specList, specNames):
    """Save a log-x plot of every spectrum to ``<output>/plots/<name>.png``.

    Uses the module-level ``outputFolder``. Spectra are looked up in ``specData``
    by name; ``specList`` only sets the progress-bar length.
    """
    plotDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}plots"
    os.makedirs(plotDir, exist_ok=True)
    for name in tqdm(specNames, desc="Plotting galaxies", total=len(specList)):
        spectrum = specData[name]
        fig, ax = plt.subplots()
        ax.plot(spectrum["Wavelength"], spectrum["Flux"])
        ax.set_xscale('log')
        ax.set_xlabel(f"Wavelength [{spectrum['Wavelength'].unit}]")
        ax.set_ylabel(f"Flux [{spectrum['Flux'].unit}]")
        ax.set_title(name)
        fig.savefig(f"{plotDir}/{name}.png")
        # Close each figure so memory does not grow with the number of spectra.
        plt.close(fig)
    return

def GetFilter():
    """List the filter-curve files in the configured filter folder.

    Uses the module-level ``filterFolder`` (relative to ``GetDirStruct()``).

    Returns
    -------
    filterDir : str
        Directory holding the curves.
    filterList : list of str
        Sorted, non-hidden regular file names in that directory.
    filterNames : list of str
        The part of each file name before its first ``_`` (the filter name).
    """
    filterDir = GetDirStruct() + filterFolder
    # Like `ls`, skip hidden entries (e.g. .DS_Store on the macOS drive); also skip
    # sub-directories, which could not be read as filter tables anyway.
    filterList = sorted(
        entry for entry in os.listdir(filterDir)
        if not entry.startswith('.') and os.path.isfile(os.path.join(filterDir, entry))
    )
    filterNames = [file.split('_')[0] for file in filterList]
    return filterDir, filterList, filterNames

def ImportFilter(filterDir, filterList, filterNames):
    """Read every filter throughput curve into a dict of QTables keyed by filter name.

    Each ASCII file must provide ``Microns`` and ``Throughput`` columns. ``Microns``
    is tagged with micron units and renamed ``Wavelength``; ``Throughput`` is tagged
    dimensionless; NaN cells are zeroed with RemoveNaNs.
    """
    filterData = {}
    pairs = zip(filterList, filterNames)
    for fileName, filterName in tqdm(pairs, desc="Importing filters", total=len(filterList)):
        curve = tbl.QTable.read(filterDir + fileName, format="ascii")
        curve["Microns"].unit = u.um
        curve.rename_column("Microns", "Wavelength")
        curve["Throughput"].unit = u.dimensionless_unscaled
        filterData[filterName] = RemoveNaNs(curve)
    return filterData

def InterpFunc(funcXs, funcYs, kind='cubic'):
    """Build an interpolating function through the points (funcXs, funcYs).

    Parameters
    ----------
    funcXs, funcYs : array-like
        Sample abscissae and ordinates, passed to ``scipy.interpolate.interp1d``.
    kind : str, optional
        Interpolation kind forwarded to ``interp1d``; defaults to ``'cubic'``.

    Returns
    -------
    scipy.interpolate.interp1d
        Callable; with interp1d's defaults it raises ValueError outside the range
        of ``funcXs``.
    """
    return interpolate.interp1d(funcXs, funcYs, kind=kind)

def FindGrid(targetGrid, inputGrid, inputData):
    """Return the points of (inputGrid, inputData) that lie inside the span of targetGrid.

    Every input point ``g`` with ``min(targetGrid) <= g <= max(targetGrid)`` is kept
    (both ends inclusive), and the result is sorted by ``inputGrid``.

    Parameters
    ----------
    targetGrid : array-like
        Grid whose minimum and maximum define the window.
    inputGrid : array-like
        Grid to cut; need not be sorted.
    inputData : array-like
        Values paired element-wise with ``inputGrid``.

    Returns
    -------
    overlapGrid, overlapData
        The sorted in-window grid points and their paired data (possibly empty).
    """
    idxSorted = np.argsort(inputGrid)
    sortedGrid = inputGrid[idxSorted]
    sortedData = inputData[idxSorted]
    targetMin = np.min(targetGrid)
    targetMax = np.max(targetGrid)
    # side="left": first index whose value is >= targetMin.
    idxLeft = np.searchsorted(sortedGrid, targetMin, side="left")
    # side="right": one past the last index whose value is <= targetMax, so the
    # slice stop below is exclusive yet still keeps a point equal to targetMax.
    idxRight = np.searchsorted(sortedGrid, targetMax, side="right")
    return sortedGrid[idxLeft:idxRight], sortedData[idxLeft:idxRight]

def ConvolveFunc(firstXs, firstYs, secondXs, secondYs):
    """Resample the first curve onto the part of the second curve's grid it covers.

    No convolution is actually performed. The first curve is interpolated with
    InterpFunc on wavelengths converted to metres. It is then evaluated at the
    points of ``secondXs`` that fall inside the span of ``firstXs``.

    Returns
    -------
    firstConvolved : first curve evaluated on the overlap grid
    secondOverlapGrid, secondOverlapData : the overlapping part of the second curve
    """
    interpolant = InterpFunc(firstXs.to(u.m), firstYs)
    overlapGrid, overlapData = FindGrid(firstXs, secondXs, secondYs)
    return interpolant(overlapGrid.to(u.m)), overlapGrid, overlapData

def MergeGrids(specFile, filterFile):
    """Put a spectrum and a filter curve on one common wavelength grid.

    The common grid is the union of:
      (a) the spectrum points inside the filter's span, where the filter throughput
          is interpolated;
      (b) the filter points inside the spectrum's span, where the flux is interpolated.

    Returns
    -------
    astropy.table.Table
        Columns ``Wavelength`` (m), ``Flux`` (Jy) and ``Throughput``, sorted by wavelength.
    """
    specWave = specFile["Wavelength"]
    specFlux = specFile["Flux"]
    filterWave = filterFile["Wavelength"]
    filterThrough = filterFile["Throughput"]
    # Spectrum resampled onto the filter's overlapping points, and vice versa.
    specConvolved, filterOverlapGrid, filterOverlapData = ConvolveFunc(specWave, specFlux, filterWave, filterThrough)
    filterConvolved, specOverlapGrid, specOverlapData = ConvolveFunc(filterWave, filterThrough, specWave, specFlux)
    # Assemble one table directly so each wavelength stays paired with its own
    # flux/throughput. Joining two tables on a float wavelength key could drop
    # points whose value differs in the last bit after a unit round trip.
    mergedWave = np.append(specOverlapGrid.to(u.m), filterOverlapGrid.to(u.m))
    mergedFlux = np.append(specOverlapData, specConvolved * u.Jy)
    mergedThrough = np.append(filterConvolved, filterOverlapData)
    mergedTable = tbl.Table([mergedWave, mergedFlux, mergedThrough], names=("Wavelength", "Flux", "Throughput"))
    mergedTable.sort("Wavelength")
    return mergedTable

def ShiftPhotonSpace(mergedTable):
    """Weight the Flux column by Wavelength in place (photon-space flux).

    The Flux unit is multiplied by the Wavelength unit to match, and the same
    table object is returned.
    """
    wavelength = mergedTable["Wavelength"]
    mergedTable["Flux"] *= wavelength
    mergedTable["Flux"].unit = mergedTable["Flux"].unit * wavelength.unit
    return mergedTable

def IntegFunc(firstYs, secondYs, commonGrid):
    """Trapezoid-integrate ``firstYs * secondYs`` over ``commonGrid`` and attach units.

    All three inputs must expose a ``.unit`` attribute (astropy Column or Quantity).
    The integral is computed on the bare values and multiplied exactly once by
    ``firstYs.unit * secondYs.unit * commonGrid.unit``.
    """
    # Strip units first: with Quantity inputs the integral would already carry
    # units, and multiplying by the units again would square them.
    product = np.asarray(firstYs) * np.asarray(secondYs)
    grid = np.asarray(commonGrid)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(product, x=grid) * (firstYs.unit * secondYs.unit * commonGrid.unit)

def CalcThroughput(mergedTable):
    """Integrate Flux x Throughput over Wavelength for a merged table (via IntegFunc)."""
    return IntegFunc(mergedTable["Flux"], mergedTable["Throughput"], mergedTable["Wavelength"])

def FindThroughput(specFile, filterFile):
    """Band-integrate a spectrum through a filter.

    Merges both curves onto a common grid, shifts the flux to photon space and
    integrates it against the throughput.

    Returns
    -------
    (band integral, merged photon-space table)
    """
    photonTable = ShiftPhotonSpace(MergeGrids(specFile, filterFile))
    bandIntegral = CalcThroughput(photonTable)
    return bandIntegral, photonTable

# Function to normalise the throughput to the correct zero-point
def NormaliseValue(specTable, specValue):
    """Normalise a band integral by the same-band integral of a flat 1 Jy reference.

    Parameters
    ----------
    specTable : table
        Merged table (as returned by FindThroughput) providing ``Throughput`` and
        ``Wavelength`` columns on a common grid.
    specValue : Quantity
        Band integral for the same table (first value returned by FindThroughput).

    Returns
    -------
    Quantity
        ``specValue / normValue`` converted to nJy.
    """
    specThrough = specTable["Throughput"]
    specGrid = specTable["Wavelength"]
    # norm_ref = 10**(48.6/(-2.5)) # reference zero magnitude
    # 1 Jy expressed in W m^-2 Hz^-1; Unit.to returns a bare float here (no unit attached).
    normRef = 1 * u.Jy.to(u.W / ((u.m)**2 * u.Hz)) #* 10**(-9) # reference flat-value in f_nu
    # f_lambda = f_nu * c / lambda^2, then multiplied by lambda for photon space (net c / lambda).
    # c.c is scipy.constants' speed of light, a bare float in m/s.
    # NOTE(review): normRef and c.c carry no units, so any unit on normYs comes from
    # specGrid's column metadata - confirm it is what the .to(u.nJy) below needs.
    normYs = normRef * c.c / specGrid**2 * specGrid # not squared, f_lambda in photon_space
    normValue = IntegFunc(specThrough, normYs, specGrid)
    specNormed = (specValue / normValue).to(u.nJy)
    return specNormed

def BalmerBreak(specFile, specName, redshiftFile, leftRange=(3500, 3650), rightRange=(3800, 3950)):
    """Measure the Balmer break of one spectrum from its rest-frame photon-space flux.

    The redshift is taken from the first row of ``redshiftFile`` whose ``redshiftID``
    column equals ``int(specName)`` and whose ``redshiftZ`` value is positive. Both
    column names are module-level settings. The spectrum is moved to the rest frame
    with that redshift. Flux x wavelength is then averaged in a blue and a red
    window.

    Parameters
    ----------
    specFile : table
        Spectrum with ``Wavelength`` and ``Flux`` columns.
    specName : str
        Galaxy ID, convertible with ``int``.
    redshiftFile : table
        Redshift catalogue.
    leftRange, rightRange : pair of numbers or Quantity, optional
        Blue and red rest-frame windows; bare numbers are taken as angstrom.

    Returns
    -------
    list
        ``[left mean, right mean, right / left]``. When no positive redshift is
        catalogued for this galaxy all three are NaN (with the same units).
    """
    photonUnit = specFile["Flux"].unit * specFile["Wavelength"].unit
    redshiftFactor = _FindRedshiftFactor(int(specName), redshiftFile)
    if not np.isfinite(redshiftFactor):
        return [np.nan * photonUnit, np.nan * photonUnit, np.nan * u.dimensionless_unscaled]
    restWave = specFile["Wavelength"] / redshiftFactor
    restFlux = specFile["Flux"] * redshiftFactor
    balmerLeftVal = _MeanPhotonFlux(u.Quantity(leftRange, u.angstrom), restWave, restFlux)
    balmerRightVal = _MeanPhotonFlux(u.Quantity(rightRange, u.angstrom), restWave, restFlux)
    return [balmerLeftVal, balmerRightVal, balmerRightVal / balmerLeftVal]


def _FindRedshiftFactor(galID, redshiftFile):
    """Return ``1 + z`` for the first catalogue row matching galID with z > 0, else NaN."""
    for row in redshiftFile:
        if int(row[redshiftID]) == galID and row[redshiftZ] > 0:
            return 1 + row[redshiftZ]
    return np.nan


def _MeanPhotonFlux(windowRange, restWave, restFlux):
    """Mean of flux x wavelength over the rest-frame points inside windowRange."""
    windowWave, windowFlux = FindGrid(windowRange, restWave, restFlux)
    windowTable = ShiftPhotonSpace(tbl.Table([windowWave, windowFlux], names=("Wavelength", "Flux")))
    photonFlux = windowTable["Flux"]
    return np.sum(photonFlux) / len(photonFlux) * photonFlux.unit

def HandleSpectrum(specFile, specName, filterData, redshiftFile):
    """Build one output row (without the ID) for a spectrum.

    For every filter, in filterData order: the normalised band flux followed by a
    NaN placeholder for its error. Then the three BalmerBreak values.
    """
    specValues = []
    for filterFile in filterData.values():
        bandIntegral, photonTable = FindThroughput(specFile, filterFile)
        specValues.extend([NormaliseValue(photonTable, bandIntegral), np.nan])
    specValues.extend(BalmerBreak(specFile, specName, redshiftFile))
    return specValues

def LoopSpectra(specData, filterData, redshiftFile):
    """Compute band fluxes and Balmer-break values for every spectrum and save them.

    The table is written to ``<output>/throughputs.fits`` (module-level
    ``outputFolder``), overwriting any previous file.

    Returns
    -------
    astropy.table.Table
        Columns: ``ID``, then per filter ``"<f> Through"`` and ``"<f> through Error"``
        (NaN placeholder), then ``Balmer_left``, ``Balmer_right``, ``Balmer_ratio``.
    """
    outDir = f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}"
    # os.makedirs instead of a shell `mkdir -p`: portable and no subprocess.
    os.makedirs(outDir, exist_ok=True)
    rows = []
    for specName, specFull in tqdm(specData.items(), desc="Galaxies", total=len(specData)):
        rows.append([int(specName)] + HandleSpectrum(specFull, specName, filterData, redshiftFile))
    names = ["ID"]
    for key in filterData:
        names += [key + " Through", key + " through Error"]
    names += ["Balmer_left", "Balmer_right", "Balmer_ratio"]
    outTable = tbl.Table(rows=rows, names=names)
    outTable.write(f"{outDir}throughputs.fits", overwrite=True)
    return outTable

def ImportPhot(photDir):
    """Read the photometry summary table ``<photDir>/summary.fits`` as a QTable."""
    summaryPath = f"{photDir}/summary.fits"
    return tbl.QTable.read(summaryPath)

def GrabSlitloss():
    """Index the slit-loss correction files for every pointing and exposure.

    Uses the module-level settings lossFolder, lossPoints, lossExps, lossPrefix01,
    lossPrefix02 and lossSuffix. Files are matched as
    ``<lossDir>/<point>/<lossPrefix01><point>_<exp><lossPrefix02>*<lossSuffix>``.

    Returns
    -------
    lossDir : str
        Root folder (drive prefix + lossFolder).
    lossList : dict
        ``lossList[point][exp]`` -> Table with ``ID`` (int) and ``Name``
        (``"<point>/<file name>"``, relative to lossDir).
    """
    lossDir = GetDirStruct() + lossFolder
    lossList = {}
    for lossPoint in tqdm(lossPoints, desc="Grabbing pointings"):
        lossList[lossPoint] = {}
        for lossExp in tqdm(lossExps, desc="Grabbing exposures", leave=False):
            # glob instead of a shell `ls`: portable, safe with spaces, and no match
            # simply yields an empty table.
            fixedPart = f"{lossDir}/{lossPoint}/{lossPrefix01}{lossPoint}_{lossExp}{lossPrefix02}"
            lossPaths = sorted(glob.glob(glob.escape(fixedPart) + "*" + glob.escape(lossSuffix)))
            lossFiles = [os.path.basename(path) for path in lossPaths]
            lossNames = [f"{lossPoint}/{file}" for file in lossFiles]
            # ID = characters 5..10 of the third-from-last "_"-separated field, i.e.
            # the 6 characters after "idcat" with the default lossPrefix02.
            lossIDs = [int(file.split('_')[-3][5:11]) for file in lossFiles]
            lossList[lossPoint][lossExp] = tbl.Table([lossIDs, lossNames], names=("ID", "Name"))
    return lossDir, lossList

def ImportSlitloss(lossDir, lossList, photTable):
    """Load the slit-loss curves of every galaxy that also appears in photTable.

    Parameters
    ----------
    lossDir : str
        Root folder that the ``Name`` entries of lossList are relative to.
    lossList : dict
        ``lossList[point][exposure]`` -> table with ``ID`` and ``Name`` (from GrabSlitloss).
    photTable : table
        Photometry; only its ``ID`` column is used, to select galaxies.

    Returns
    -------
    dict
        ``lossData[point][exposure][ID]`` -> QTable with ``Slitloss`` (dimensionless,
        from ``col1``) and ``Wavelength`` (micron, from ``col2``).
    """
    photIDs = tbl.Table([photTable["ID"]])
    lossData = {}
    for lossPoint in tqdm(lossList.keys(), desc="Matching pointings"):
        pointData = {}
        lossData[lossPoint] = pointData
        for lossExp in tqdm(lossList[lossPoint].keys(), desc="Matching exposures", leave=False):
            # Inner join keeps only files whose galaxy ID also has photometry.
            matched = tbl.join(lossList[lossPoint][lossExp], photIDs, keys="ID")
            exposureData = {}
            for entry in tqdm(matched, desc="Galaxies", leave=False):
                curve = tbl.QTable.read(lossDir + entry["Name"], format="ascii")
                curve.rename_column("col1", "Slitloss")
                curve.rename_column("col2", "Wavelength")
                curve["Slitloss"].unit = u.dimensionless_unscaled
                curve["Wavelength"].unit = u.um
                exposureData[entry["ID"]] = curve
            pointData[lossExp] = exposureData
    return lossData

def CalcSlitRatio(photTable, specTable):
    """Ratio of photometric flux to spectrum-derived flux for every matching filter.

    Assumes photTable is ``ID`` followed by alternating (flux, error) columns, and
    specTable is ``ID`` followed by alternating (through, error) columns as written
    by LoopSpectra. Columns are paired when the first whitespace-separated token of
    their names matches. Non-positive values are replaced by NaN first. The result
    is also written to ``<output>/ratios.fits``.

    Returns
    -------
    ratioTable : QTable
        ``ID`` + (flux, through, ratio, NaN ratio error) per matched filter.
    fluxNames : list of str
        The matched filter names, in photTable order.
    """
    combTable = tbl.join(photTable, specTable, keys="ID")
    nPhotCols = len(photTable.colnames)
    fluxCols = combTable.colnames[1:nPhotCols:2]
    throughCols = combTable.colnames[nPhotCols::2]
    ratioTable = tbl.QTable([combTable["ID"]])
    fluxNames = []
    # Zero / negative entries are treated as missing.
    for colName in combTable.colnames[1:]:
        combTable[colName] = np.where(combTable[colName] > 0, combTable[colName], np.nan)
    for fluxCol in fluxCols:
        fluxName = fluxCol.split()[0]
        for throughCol in [col for col in throughCols if col.split()[0] == fluxName]:
            ratioTable.add_columns(
                [combTable[fluxCol], combTable[throughCol], combTable[fluxCol] / combTable[throughCol], np.nan],
                names=(fluxCol, throughCol, f"{fluxName} Ratio", f"{fluxName} Ratio Error"),
            )
            fluxNames.append(fluxName)
    ratioTable.write(f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}ratios.fits", overwrite=True)
    return ratioTable, fluxNames

def CalcLossTable(lossData, fluxNames):
    """Evaluate each slit-loss curve at the centre of every filter and invert it.

    Filter centres are parsed from names such as ``F090W`` -> 0.90 micron
    (character 1, a decimal point, then characters 2-3).

    Returns
    -------
    dict
        ``lossTables[point][exposure]`` -> QTable with ``ID`` and one
        ``"<filter> slitloss"`` column per filter holding ``1 / loss(centre)``.
    """
    # Filter centres and column names do not depend on the galaxy; compute them once.
    centralWaves = [float(fluxName[1] + '.' + fluxName[2:4]) * u.um for fluxName in fluxNames]
    names = ["ID"] + [f"{fluxName} slitloss" for fluxName in fluxNames]
    lossTables = {}
    for point, exposures in lossData.items():
        lossTables[point] = {}
        for exposure, galaxies in exposures.items():
            rows = []
            for galaxy, curve in galaxies.items():
                lossInterpolated = interpolate.interp1d(curve["Wavelength"], curve["Slitloss"])
                rows.append([galaxy] + [1/lossInterpolated(centralWave) * u.dimensionless_unscaled for centralWave in centralWaves])
            lossTables[point][exposure] = tbl.QTable(rows=rows, names=names)
    return lossTables

def CalcDiffTable(ratioTable, lossTables):
    """Divide each flux ratio by the matching slit-loss factor for every pointing/exposure.

    Output columns are named ``<filter>_<point>_<exposure>``. The per-exposure
    results are outer-joined on ``ID``, so galaxies missing from some exposures
    get masked entries.
    """
    diffTable = tbl.QTable([ratioTable["ID"]])
    # In ratioTable the ratio columns sit at 3, 7, 11, ... (ID, then
    # flux / through / ratio / error per filter); loss columns follow them after the join.
    nRatioCols = len(ratioTable.colnames)
    for point, exposures in lossTables.items():
        for exposure, lossTable in exposures.items():
            matchedTable = tbl.join(ratioTable, lossTable, keys="ID")
            ratioNames = matchedTable.colnames[3:nRatioCols:4]
            lossNames = matchedTable.colnames[nRatioCols:]
            tempTable = tbl.QTable([matchedTable["ID"]])
            for ratioName in ratioNames:
                filterName = ratioName.split()[0]
                for lossName in lossNames:
                    if lossName.split()[0] != filterName:
                        continue
                    tempTable.add_column(matchedTable[ratioName]/matchedTable[lossName], name=f"{filterName}_{point}_{exposure}")
            diffTable = tbl.join(diffTable, tempTable, keys="ID", join_type='outer')
    return diffTable

# Function to append average differences across pointings and exposures
def CalcAverageDiff(diffTable):
    """Append averaged slit-loss-factor columns to diffTable (modified and returned).

    Expects the per-exposure columns produced by CalcDiffTable, named
    ``<filter>_<point>_<exposure>`` and ordered by pointing, then exposure, then
    filter. Appends:
      * ``<filter>_<point>``: average over the exposures of that pointing;
      * ``<filter>``: average over all pointings and exposures.
    Masked entries are filled with 0. The divisor counts only entries > 0, and a
    row with no positive entry gets NaN.

    NOTE(review): the sum adds every filled value, so negative or NaN (unmasked)
    entries still enter the numerator although they are not counted.
    NOTE(review): assumes every pointing has the same exposures and filters, so a
    fixed stride can be used.
    """
    # Captured once, so the columns appended below are never re-averaged.
    colNames = diffTable.colnames[1:]
    k = len(colNames)
    # Infer the strides from the column layout:
    #   stepSize = distance to the next column with the same filter and pointing
    #              (the number of filters per exposure);
    #   jumpSize = distance to the first column with the same filter but another
    #              pointing (filters x exposures per pointing).
    for i, colName_i in enumerate(colNames):
        stepSize = 0
        jumpSize = 0
        for j, colName_j in enumerate(colNames[i+1:]):
            if colName_i.split('_')[0:2] == colName_j.split('_')[0:2]:
                stepSize = j + 1
                break
        # NOTE(review): if stepSize is still 0 here, this slice has step 0 and raises
        # ValueError (e.g. with a single pointing no column ever sets jumpSize).
        for j, colName_j in enumerate(colNames[i+stepSize::stepSize]):
            if colName_i.split('_')[0] == colName_j.split('_')[0] and colName_i.split('_')[1] != colName_j.split('_')[1]:
                jumpSize = (j + 1) * stepSize
                break
        if 0 not in [stepSize, jumpSize]:
            # Both strides found: averaging starts from the first column.
            k = 0
            break
    # Per-pointing averages: walk one pointing block (jumpSize columns) at a time;
    # within it, average each filter's column over the pointing's exposures.
    while k < len(colNames):
        for i, colName_i in enumerate(colNames[k:k+stepSize:1]):
            tempCol = diffTable[colName_i].copy().filled(0)
            count = np.where(tempCol > 0, 1, 0)
            for j, colName_j in enumerate(colNames[k+i+stepSize:k+i+jumpSize:stepSize]):
                tempCol += diffTable[colName_j].filled(0)
                count += np.where(diffTable[colName_j].filled(0) > 0, 1, 0)
            # Zero counts become NaN so the division yields NaN instead of inf.
            count = np.where(count > 0, count, np.nan)
            diffTable.add_column(tempCol/count, name=f"{colName_i.split('_')[0]}_{colName_i.split('_')[1]}")
        k += jumpSize
    # Overall averages: for each filter, every stepSize-th column across all
    # pointings and exposures.
    # NOTE(review): if colNames is empty, stepSize is never assigned and this raises NameError.
    for i, colName_i in enumerate(colNames[0:stepSize]):
        tempCol = diffTable[colName_i].copy().filled(0)
        count = np.where(tempCol > 0, 1, 0)
        for j, colName_j in enumerate(colNames[i+stepSize::stepSize]):
            tempCol += diffTable[colName_j].filled(0)
            count += np.where(diffTable[colName_j].filled(0) > 0, 1, 0)
        count = np.where(count > 0, count, np.nan)
        diffTable.add_column(tempCol/count, name=f"{colName_i.split('_')[0]}")
    return diffTable

def HandleSlitloss(lossData, photTable, specTable):
    """Run the slit-loss pipeline and save the result to ``<output>/differences.fits``.

    Steps: flux ratios (CalcSlitRatio) -> slit-loss at the filter centres
    (CalcLossTable) -> per-exposure factors (CalcDiffTable) -> averages
    (CalcAverageDiff).
    """
    print("Calculating flux ratios...")
    ratioTable, fluxNames = CalcSlitRatio(photTable, specTable)
    print("Interpolating slit-losses...")
    lossTables = CalcLossTable(lossData, fluxNames)
    print("Calculating slit-loss factors...")
    diffTable = CalcAverageDiff(CalcDiffTable(ratioTable, lossTables))
    diffTable.write(f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}differences.fits", overwrite=True)
    return diffTable

In [None]:
print("Importing data...")

# Directories: each setting keeps a value already defined in the kernel (so it can
# be overridden before running this cell) and otherwise falls back to the
# HST Deep prism default. Defaulting them one by one means overriding one setting
# never resets the others.
filterFolder = globals().get("filterFolder", "Throughputs/nircam_throughputs/mean_throughputs/")
specFolder = globals().get("specFolder", "Spectra/HST_Deep/prism_v1.5/")
outputFolder = globals().get("outputFolder", "HST_Deep/prism_v1.5/")

# Set to True to also save a PNG of every spectrum (slow).
plotSpectra = False

# Spectra (expensive: only imported if not already in the kernel)
try:
    specData
except NameError:
    specDir, specList, specNames = GetSpec()
    specData = ImportSpec(specDir, specList, specNames)

# Filters (only imported if not already in the kernel)
try:
    filterData
except NameError:
    filterDir, filterList, filterNames = GetFilter()
    filterData = ImportFilter(filterDir, filterList, filterNames)

# Redshift catalogue. BalmerBreak needs the column names even when the catalogue
# itself is already cached, so they are set outside the guard.
redshiftID = globals().get("redshiftID", "ID")
redshiftZ = globals().get("redshiftZ", "Assigned_redshift")
try:
    redshiftFile
except NameError:
    redshiftFolder = GetDirStruct() + "Redshifts/Deep_HST.csv"
    redshiftFile = tbl.Table.read(redshiftFolder)

if plotSpectra:
    print("Plotting images...")
    PlotSpec(specData, specList, specNames)

print("Calculating throughputs...")

specTable = LoopSpectra(specData, filterData, redshiftFile)

print("Done.")

In [None]:
print("Importing data...")

# Directories and slit-loss file naming. Each setting keeps a value already defined
# in the kernel (e.g. by the previous cell or by hand) and otherwise falls back to
# the HST Deep R100 default; overriding one setting never resets the others.
filterFolder = globals().get("filterFolder", "Throughputs/nircam_throughputs/mean_throughputs/")
specFolder = globals().get("specFolder", "Spectra/HST_Deep/prism_v1.5/")
outputFolder = globals().get("outputFolder", "HST_Deep/prism_v1.5/")
photDir = globals().get("photDir", "../../Working_Directory/Apo_Phot/HST_Deep/")
lossFolder = globals().get("lossFolder", "Slit-losses/HST/pathlosses_deep_R100/R100_v0/")
lossPoints = globals().get("lossPoints", ["p01", "p02", "p03"])
lossPrefix01 = globals().get("lossPrefix01", "pathlosses_correction_deep_hst_1x1_")
lossPrefix02 = globals().get("lossPrefix02", "_idcat")
lossExps = globals().get("lossExps", ["exp00", "exp01", "exp02"])
lossSuffix = globals().get("lossSuffix", "_v0_point.txt")

# Spectra
try:
    specData
except NameError:
    print("Importing spectra...")
    specDir, specList, specNames = GetSpec()
    specData = ImportSpec(specDir, specList, specNames)

# Filters
try:
    filterData
except NameError:
    print("Importing filters...")
    filterDir, filterList, filterNames = GetFilter()
    filterData = ImportFilter(filterDir, filterList, filterNames)

# Photometry
try:
    photTable
except NameError:
    print("Importing apodised fluxes...")
    photTable = ImportPhot(photDir)

# Slitlosses
try:
    lossData
except NameError:
    print("Importing slitlosses...")
    lossDir, lossList = GrabSlitloss()
    lossData = ImportSlitloss(lossDir, lossList, photTable)

# Throughput table from the previous cell. In a fresh kernel, read the file that
# LoopSpectra writes instead of failing with NameError.
try:
    specTable
except NameError:
    print("Reading saved throughputs...")
    specTable = tbl.Table.read(f"../../Working_Directory/Apo_Phot_Utils/{outputFolder}throughputs.fits")

print("\nFinding slit-losses ratios:")

diffTable = HandleSlitloss(lossData, photTable, specTable)

print("\nDone.")

In [None]:
# Scratch check: centre wavelength (micron) parsed from the first filter name,
# e.g. "F090W" -> 0.90, the same parsing CalcLossTable uses.
firstFilterName = filterNames[0]
float(f"{firstFilterName[1]}.{firstFilterName[2:4]}")

In [None]:
# Summarise the nested slit-loss dict (pointing -> exposure -> number of galaxies)
# instead of dumping every table into the notebook output.
{point: {exposure: len(galaxies) for exposure, galaxies in exposures.items()} for point, exposures in lossData.items()}

In [None]:
# Path prefix of the p01/exp00 slit-loss files, built from the configured folders
# instead of a hard-coded macOS path (same value on macOS with the default settings).
testPrefix = f"{GetDirStruct()}{lossFolder}p01/{lossPrefix01}p01_exp00{lossPrefix02}"
testSuffix = "_v0_point.txt"

In [None]:
# List the matching slit-loss files with glob instead of a shell `ls`.
lossesPaths = sorted(glob.glob(glob.escape(testPrefix) + "*" + glob.escape(testSuffix)))
lossesFiles = [os.path.basename(path) for path in lossesPaths]
lossesList = [file.split('_')[-3] for file in lossesFiles]
lossesNames = [field[5:11] for field in lossesList]
lossesIDs = [int(name) for name in lossesNames]
lossesTable = tbl.Table([lossesIDs, lossesNames], names=("ID", "Name"))
# Store the real file names. Rebuilding them from the parsed ID would have to guess
# the separator after "idcat" and drop anything between the ID and the suffix.
lossesTable.add_column(lossesFiles, name="fileName")

In [None]:
# overwrite=True so re-running the notebook does not fail because the file exists.
lossesTable.write("lossesTable.csv", overwrite=True)

In [None]:
# NOTE(review): `photNames` is not defined anywhere in this notebook, so this cell
# raises NameError on Restart & Run All. Presumably a leftover - confirm the intent
# (perhaps photTable.colnames?) or delete the cell.
photNames

In [None]:
# NOTE(review): scratch plot. Filter_grid, Filter_conv and Spectrum_flux are not
# defined anywhere in this notebook, so this cell raises NameError on a fresh kernel
# (Restart & Run All). Re-create them or delete the cell.
plt.close("all")

# plt.plot(Spectrum["Wavelength"], Spectrum["Flux"])
# plt.plot(Filter["Wavelength"], Filter["Throughput"] * 10 **(-13))
# Filter curve scaled by 1e-13, presumably so it is visible on the flux axis.
plt.plot(Filter_grid, Filter_conv * 10 ** (-13), color="green")
plt.plot(Filter_grid, Spectrum_flux, color="blue")
# Product of filter and flux (presumably the band-integral integrand).
plt.plot(Filter_grid, Filter_conv * Spectrum_flux, color="red")