In [None]:
import numpy as np
from scipy.signal import find_peaks
import pandas as pd

In [None]:
def crosscheck(results, tol=0.1, min_datasets=2):
    """Cross-match candidate lines between datasets by wavelength.

    Every candidate line from every dataset is pooled, then clustered
    greedily: the first unused line seeds a cluster and absorbs every later
    unused line whose wavelength lies within ``tol`` of the *seed*
    (not of the running cluster mean). A cluster is reported only when its
    members come from at least ``min_datasets`` distinct datasets.

    Parameters
    ----------
    results : dict
        Maps a dataset label to an iterable of ``(snr, wave)`` pairs.
    tol : float, optional
        Maximum absolute wavelength difference from the seed line for two
        lines to be considered the same feature (same units as ``wave``).
        Defaults to 0.1, the previously hard-coded value.
    min_datasets : int, optional
        Minimum number of distinct datasets a cluster must span to be kept.
        Defaults to 2, the previously hard-coded value.

    Returns
    -------
    list of dict
        One entry per retained cluster with keys ``mean_wave``,
        ``sigma_wave`` (population standard deviation via ``np.std``),
        ``datasets`` (sorted labels), ``snrs`` (one per member line) and
        ``n_recovered`` (number of distinct datasets).
    """
    # Flatten to one record per candidate line, remembering where it came from.
    all_lines = [
        {"dataset": dataset, "wave": wave, "snr": snr}
        for dataset, lines in results.items()
        for snr, wave in lines
    ]

    # Marks lines that already belong to a cluster so each line is used once.
    used = np.zeros(len(all_lines), dtype=bool)
    final_results = []

    for i, seed in enumerate(all_lines):
        if used[i]:
            continue
        used[i] = True
        cluster = [seed]

        # Only later lines can still be free; earlier ones were seeds or members.
        for j in range(i + 1, len(all_lines)):
            if not used[j] and abs(all_lines[j]["wave"] - seed["wave"]) <= tol:
                cluster.append(all_lines[j])
                used[j] = True

        datasets = {member["dataset"] for member in cluster}

        # Several lines from the same dataset count once towards recovery.
        if len(datasets) >= min_datasets:
            waves = [member["wave"] for member in cluster]
            final_results.append({
                "mean_wave": np.mean(waves),
                "sigma_wave": np.std(waves),
                "datasets": sorted(datasets),
                "snrs": [member["snr"] for member in cluster],
                "n_recovered": len(datasets),
            })

    return final_results

In [None]:
def load_results(gal_type, dataset, runtype, snr_threshold=2.5):
    """Load one coadd result file and return its significant SNR extrema.

    Reads ``{gal_type}_{dataset}_{runtype}_coadd_results.npz`` (relative to
    the current working directory unless ``gal_type`` carries a directory
    prefix), restricts the arrays to ``valid_bins``, finds every local
    maximum and local minimum of the narrow-kernel SNR spectrum, and keeps
    those whose absolute SNR is at least ``snr_threshold``.

    Parameters
    ----------
    gal_type, dataset, runtype : str
        Components of the result file name.
    snr_threshold : float, optional
        Minimum ``|SNR|`` for an extremum to be kept. Defaults to 2.5, the
        previously hard-coded value.

    Returns
    -------
    set of tuple
        ``(snr, wavelength)`` pairs, one per retained extremum.

    Notes
    -----
    The file must provide the arrays ``wavelength``, ``valid_bins`` and
    ``snr_narrow``. ``valid_bins`` is used directly as a NumPy index
    (presumably a boolean mask or an index array -- confirm against the
    code that writes these files). According to the original inline
    comments, ``snr_narrow`` is the Gaussian-smoothed residual SNR for a
    3 A kernel. The ``noise_narrow`` array was previously read but never
    used, so it is no longer required.
    """
    path = f"{gal_type}_{dataset}_{runtype}_coadd_results.npz"

    # An .npz archive is read lazily and keeps the file open; the context
    # manager guarantees the handle is released once the arrays are copied out.
    with np.load(path) as data:
        valid_bins = data['valid_bins']
        valid_wave = data['wavelength'][valid_bins]
        valid_snr_narrow = data['snr_narrow'][valid_bins]

    # Local maxima of the SNR, and local minima found as maxima of -SNR.
    peaks, _ = find_peaks(valid_snr_narrow)
    valleys, _ = find_peaks(-valid_snr_narrow)
    extrema = np.sort(np.concatenate((peaks, valleys)))

    # Keep extrema of either sign whose magnitude reaches the threshold.
    significant = extrema[np.abs(valid_snr_narrow[extrema]) >= snr_threshold]
    result = {(valid_snr_narrow[i], valid_wave[i]) for i in significant}

    print('done loading results')
    return result

In [None]:
def reject_known_lines(matches):
    """Drop cross-matched lines that coincide with a catalogued emission line.

    Parameters
    ----------
    matches : iterable of dict
        Each entry must provide a ``"mean_wave"`` value.

    Returns
    -------
    list of dict
        The entries of ``matches`` (same objects, original order) whose
        ``"mean_wave"`` is more than 0.1 away from every catalogue
        wavelength below.
    """
    # Catalogue wavelengths transcribed from
    # http://astronomy.nmsu.edu/drewski/tableofemissionlines.html
    # (presumably in Angstrom, matching the units of "mean_wave" -- confirm).
    catalog = (
        770.409, 780.324, 937.814, 949.742, 977.030, 989.790, 991.514,
        991.579, 1025.722, 1031.912, 1037.613, 1066.660, 1215.670,
        1238.821, 1242.804, 1260.422, 1264.730, 1302.168, 1334.532,
        1335.708, 1393.755, 1397.232, 1399.780, 1402.770, 1486.496,
        1548.187, 1550.772, 1640.420, 1660.809, 1666.150, 1746.823,
        1748.656, 1854.716, 1862.790, 1892.030, 1908.734, 2142.780,
        2320.951, 2323.500, 2324.690, 2648.710, 2733.289, 2782.700,
        2795.528, 2802.705, 2829.360, 2835.740, 2853.670, 2868.210,
        2928.000, 2945.106, 3132.794, 3187.745, 3203.100, 3312.329,
        3345.821, 3425.881, 3444.052, 3466.497, 3466.543, 3487.727,
        3586.320, 3662.500, 3686.831, 3691.551, 3697.157, 3703.859,
        3711.977, 3721.945, 3726.032, 3728.815, 3734.369, 3750.158,
        3758.920, 3770.637, 3797.904, 3835.391, 3839.270, 3868.760,
        3888.647, 3889.064, 3891.280, 3911.330, 3967.470, 3970.079,
        4026.190, 4068.600, 4071.240, 4076.349, 4101.742, 4143.761,
        4178.862, 4180.600, 4233.172, 4227.190, 4287.394, 4303.176,
        4317.139, 4340.471, 4363.210, 4412.300, 4414.899, 4416.830,
        4452.098, 4471.479, 4489.183, 4491.405, 4510.910, 4522.634,
        4555.893, 4582.835, 4583.837, 4629.339, 4634.140, 4640.640,
        4641.850, 4647.420, 4650.250, 4651.470, 4658.050, 4685.710,
        4711.260, 4740.120, 4861.333, 4893.370, 4903.070, 4923.927,
        4958.911, 5006.843, 5018.440, 5084.770, 5145.750, 5158.890,
        5169.033, 5176.040, 5197.577, 5200.257, 5234.625, 5236.060,
        5270.400, 5276.002, 5276.380, 5302.860, 5309.110, 5316.615,
        5316.784, 5335.180, 5424.220, 5517.709, 5537.873, 5637.600,
        5677.000, 5695.920, 5720.700, 5754.590, 5801.330, 5811.980,
        5875.624, 6046.440, 6087.000, 6300.304, 6312.060, 6347.100,
        6363.776, 6369.462, 6374.510, 6516.081, 6548.050, 6562.819,
        6583.460, 6716.440, 6730.810, 7002.230, 7005.870, 7065.196,
        7135.790, 7155.157, 7170.620, 7172.000, 7236.420, 7237.260,
        7254.448, 7262.760, 7281.349, 7319.990, 7330.730, 7377.830,
        7411.160, 7452.538, 7468.310, 7611.000, 7751.060, 7816.136,
        7868.194, 7889.900, 7891.800, 8236.790, 8392.397, 8413.318,
        8437.956, 8446.359, 8467.254, 8498.020, 8502.483, 8542.090,
        8545.383, 8578.700, 8598.392, 8616.950, 8662.140, 8665.019,
        8680.282, 8703.247, 8711.703, 8750.472, 8862.782, 8891.910,
        9014.909, 9068.600, 9229.014, 9531.100, 9545.969, 9824.130,
        9850.260, 9913.000, 10027.730, 10031.160, 10049.368, 10286.730,
        10320.490, 10336.410, 10746.800, 10830.340, 10938.086,
    )

    def coincides_with_catalog(wavelength):
        # True as soon as any catalogue entry lies within the 0.1 window.
        for reference in catalog:
            if abs(wavelength - reference) <= 0.1:
                return True
        return False

    return [
        candidate
        for candidate in matches
        if not coincides_with_catalog(candidate["mean_wave"])
    ]

In [None]:
def main():
    """Run the full cross-match pipeline over every sample/release/run type.

    For each ``<sample>_<release>_<runtype>`` label the matching result file
    is loaded with ``load_results``; the candidate lines are cross-matched
    with ``crosscheck``; matches coinciding with catalogued lines are removed
    with ``reject_known_lines``; the survivors are written to
    ``LOA_GAL_crossmatch_results.csv`` and printed.
    """
    # Every combination of {BGS, LRG} x {loa, iron} x {galaxy, skyfiber}.
    # The final entry used to repeat 'LRG_loa_skyfiber', so the
    # LRG/iron/skyfiber results were never loaded.
    all_types = ['BGS_loa_galaxy', 'LRG_loa_galaxy', 'BGS_loa_skyfiber', 'LRG_loa_skyfiber',
                 'BGS_iron_galaxy', 'LRG_iron_galaxy', 'BGS_iron_skyfiber', 'LRG_iron_skyfiber']

    # Maps "<sample>_<release>_<runtype>_result" to its set of (snr, wave) pairs.
    results = {}
    for type_name in all_types:
        gal_type, dataset, runtype = type_name.split('_')
        key = f"{gal_type}_{dataset}_{runtype}_result"
        results[key] = load_results(gal_type, dataset, runtype)
        print(f"{gal_type}_{dataset}_{runtype} has {len(results[key])} SNR points with value >= 2.5")

    print("crossmatching....")
    matches = crosscheck(results)

    print(f"Found {len(matches)} lines recovered in ≥ 2 datasets")
    print("Ignoring known lines...")
    final_result = reject_known_lines(matches)

    df_filtered = pd.DataFrame(final_result)

    df_filtered.to_csv('LOA_GAL_crossmatch_results.csv', index=False)
    print("Results saved to 'LOA_GAL_crossmatch_results.csv'")

    for m in final_result:
        print(
            f"λ = {m['mean_wave']:.2f} Å | "
            f"N = {m['n_recovered']} | "
            f"datasets = {m['datasets']}"
        )