In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
from tabulate import tabulate
from astropy.time import Time
from tqdm import tqdm
import csv
import os
import glob
import multiprocessing
from collections import defaultdict
from scipy.spatial import cKDTree


In [None]:
# Preliminary frequency reduction: list the candidate hit files that will be
# screened against the 2 GHz (2000 MHz) frequency cut by check_file below.


directory = "/datag/users/ctremblay/"

# Candidate files are the "Summer_Project*.pkl" pickles directly inside `directory`.
files = [f for f in os.listdir(directory) if f.startswith("Summer_Project") and f.endswith(".pkl")]
# Absolute paths, so worker processes can open the files on their own.
full_paths = [os.path.join(directory, f) for f in files]

def check_file(file_path):
    """Return ``file_path`` if the pickled hit table reaches 2000 MHz, else ``None``.

    A file is kept when at least one ``signal_frequency`` value is >= 2000, so
    only files whose frequencies are *all* below 2 GHz are rejected.  Files
    that mix low and high frequencies are kept on purpose: the rows below the
    cut are removed later, row by row.

    Parameters
    ----------
    file_path : str
        Path of a pickled pandas DataFrame.

    Returns
    -------
    str or None
        ``file_path`` when the file passes the cut.  ``None`` when the table is
        empty, has no ``signal_frequency`` column, holds no valid frequency at
        or above 2000, or cannot be read (the error is printed, not raised, so
        one bad file does not abort a whole pool run).
    """
    try:
        # NOTE(review): read_pickle executes arbitrary code from the file;
        # only point this at trusted data.
        df = pd.read_pickle(file_path)
        if not df.empty and 'signal_frequency' in df.columns:
            # max() skips NaN; it is NaN only when every value is missing.
            max_freq = df['signal_frequency'].max()
            if pd.notna(max_freq) and max_freq >= 2000:
                return file_path
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
    return None

# Screen every candidate file in parallel, one worker process per CPU core.
if __name__ == "__main__":
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        # pool.map keeps input order; each entry is a path or None.
        results = pool.map(check_file, full_paths)

    # Keep only the paths that check_file accepted; reused by later cells.
    high_freq_files = [r for r in results if r is not None]
    print(f"\n {len(high_freq_files)} files with frequency ≥ 2000 MHz")


In [None]:


# --- 1. Load Crickets CSV ---
crickets_path = "Full_Crickets.csv"
intervals = pd.read_csv(crickets_path)
# Header names may carry stray whitespace; strip it so the rename below matches.
intervals.columns = intervals.columns.str.strip()

# Rename the bin-edge columns to the start/end names used in the rest of the notebook
intervals = intervals.rename(columns={
    "rfi_bin_bots": "start_frequency",
    "rfi_bin_tops": "end_frequency"
})

# Coerce both edges to numbers (unparseable entries become NaN) and drop incomplete rows
intervals['start_frequency'] = pd.to_numeric(intervals['start_frequency'], errors='coerce')
intervals['end_frequency'] = pd.to_numeric(intervals['end_frequency'], errors='coerce')
intervals = intervals.dropna(subset=['start_frequency', 'end_frequency'])

# --- 2. Define NRAO manual bands ---
# Hand-entered (start, end) pairs; entries with identical start and end mark a
# single frequency. Some entries overlap each other; they are merged in step 4.
# NOTE(review): presumably MHz, matching signal_frequency — confirm against the source list.
nrao_ranges = [
    (2178.0, 2195.0), (2106.4, 2106.4), (2204.5, 2204.5), (2180, 2290),
    (2227.5, 2231.5), (2246.5, 2252.5), (2268.5, 2274.5), (2282.5, 2288.5),
    (2314.5, 2320.5), (2320.0, 2332.5), (2324.5, 2330.5), (2332.5, 2345.0),
    (2334.5, 2340.5), (2387.5, 2387.5), (2400.0, 2483.5), (2411.0, 2413.0),
    (2483.5, 2500.0), (2741.0, 2741.0), (2791.0, 2791.0), (3700.0, 4200.0),
    (5648.5, 5663.5), (5659.5, 5670.5), (5695.5, 5704.5), (5742.5, 5757.5),
    (5765.0, 5769.0), (5796.0, 5804.0), (6108.1, 6138.1), (6137.75, 6167.75),
    (6182.0, 6212.0), (6360.14, 6390.14), (6389.79, 6419.79), (6772.0, 6778.0),
    (7250.0, 7850.0), (9300.0, 9900.0), (9300.0, 9500.0), (10740.0, 10770.0),
    (10820.0, 10850.0), (10957.0, 10993.0), (11037.0, 11073.0), (11230.0, 11260.0),
    (11310.0, 11340.0), (11447.0, 11483.0), (11527.0, 11563.0), (11700.0, 12000.0),
    (12000.0, 12700.0), (13400.0, 13750.0), (17800.0, 20200.0), (29500.0, 30000.0),
    (34875.0, 34875.0), (36286.0, 36286.0)
]

# --- 3. Combine Crickets + Manual RFI bands ---
manual_df = pd.DataFrame(nrao_ranges, columns=['start_frequency', 'end_frequency'])
# Only the two edge columns are kept from the Crickets table; the index is rebuilt.
combined_df = pd.concat([intervals[['start_frequency', 'end_frequency']], manual_df], ignore_index=True)

# --- 4. Merge overlapping bands ---
def merge_overlapping_bands(bands, tol=1e-6):
    """Collapse overlapping or touching ``(start, end)`` bands into disjoint ones.

    Parameters
    ----------
    bands : iterable of 2-item sequences
        ``(start, end)`` pairs in any order.
    tol : float, optional
        A band is merged into its predecessor when its start is at most
        ``tol`` beyond the predecessor's end.

    Returns
    -------
    list
        Bands ordered by start.  Bands that absorbed a neighbour are returned
        as new tuples; untouched bands are returned as the objects passed in.
    """
    result = []
    for interval in sorted(bands, key=lambda pair: pair[0]):
        if result:
            last_lo, last_hi = result[-1]
            lo, hi = interval
            if lo <= last_hi + tol:
                # Overlaps or touches the previous band: extend it in place.
                result[-1] = (last_lo, max(last_hi, hi))
                continue
        result.append(interval)
    return result

# Turn the combined table into (start, end) tuples and merge overlapping bands
combined_bands = list(zip(combined_df['start_frequency'], combined_df['end_frequency']))
merged_bands = merge_overlapping_bands(combined_bands)

# --- 5. Save merged bands to CSV ---
# Written to the current working directory; the RFI-removal cell reads it back
# by the same file name.
merged_df = pd.DataFrame(merged_bands, columns=['start_frequency', 'end_frequency'])
save_path = os.path.join(os.getcwd(), "Full_Crickets_merged.csv")
merged_df.to_csv(save_path, index=False)
print(f" Saved merged RFI band list to {save_path}")


In [None]:
# NRAO and CRICKETS RFI frequency elimination

# --- Update these paths ---
# NOTE: high_freq_files is produced by the frequency-screening cell; run it first.
filtered_files = high_freq_files  # List of full file paths
output_dir = '/datax/scratch/ellambishop/rfi_removed_hits'
os.makedirs(output_dir, exist_ok=True)

# --- Load and clean RFI bands ---
intervals = pd.read_csv("Full_Crickets_merged.csv")
intervals.columns = intervals.columns.str.strip()
# The merged CSV is written with start_frequency/end_frequency headers, so this
# rename only has an effect if a raw Crickets file is substituted.
intervals = intervals.rename(columns={
    "rfi_bin_bots": "start_frequency",
    "rfi_bin_tops": "end_frequency"
})
intervals['start_frequency'] = pd.to_numeric(intervals['start_frequency'], errors='coerce')
intervals['end_frequency'] = pd.to_numeric(intervals['end_frequency'], errors='coerce')
# Two-column array of (start, end) rows, read as a global by filter_hits_by_rfi.
rfi_bands = intervals.dropna(subset=['start_frequency', 'end_frequency'])[['start_frequency', 'end_frequency']].values


def filter_hits_by_rfi(df, bands=None):
    """Drop hits whose ``signal_frequency`` lies inside any RFI band.

    Each band is applied as one vectorised comparison over the whole frequency
    column; the bands themselves are walked in a Python loop, so memory stays
    proportional to the number of hits.

    Parameters
    ----------
    df : pandas.DataFrame
        Hit table with a ``signal_frequency`` column.
    bands : iterable of (low, high), optional
        RFI bands in the same units as ``signal_frequency``.  Both edges are
        inclusive.  Defaults to the module-level ``rfi_bands`` array.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df`` that fall outside every band (original index kept).
        Rows with a NaN frequency match no band and are therefore kept.
    """
    if bands is None:
        bands = rfi_bands

    freqs = df['signal_frequency'].values
    keep_mask = np.ones(len(freqs), dtype=bool)

    for low, high in bands:
        keep_mask &= ~((freqs >= low) & (freqs <= high))  # Mask out RFI frequencies

    return df[keep_mask]


def process_file(filepath):
    """Filter one pickled hit table and write the surviving rows to ``output_dir``.

    Rows with ``signal_frequency`` below 2000 are discarded first, then
    ``filter_hits_by_rfi`` removes rows inside the RFI bands.  The result is
    pickled under the input's base name inside ``output_dir``; nothing is
    written when no rows survive.  Any exception is printed rather than
    raised, and the function always returns ``None``.
    """
    destination = os.path.join(output_dir, os.path.basename(filepath))

    try:
        hits = pd.read_pickle(filepath)

        # Row-level cut: keep only signal_frequency ≥ 2000 MHz
        above_cut = hits[hits['signal_frequency'] >= 2000]

        # Remove everything that falls in a known RFI band
        clean_hits = filter_hits_by_rfi(above_cut)

        if not clean_hits.empty:
            clean_hits.to_pickle(destination)
            print(f"{filepath}: Saved filtered hits ({len(clean_hits)} rows)")
        else:
            print(f"{filepath}: No hits remain after filtering, skipping save.")

    except Exception as err:
        print(f"Error processing {filepath}: {err}")


# Filter every accepted file in parallel, one worker process per CPU core.
# process_file returns nothing; results are the pickles written to output_dir.
if __name__ == "__main__":
    print(f"Processing {len(filtered_files)} files with multiprocessing...")
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        pool.map(process_file, filtered_files)

    print("Done filtering all files.")


In [6]:
# For every (file_uri, tstart) group, pair each incoherent-beam hit with the
# nearest non-incoherent hit within `freq_tolerance` in frequency, and record
# SNR and power of both for the later power analysis.

file_list = glob.glob("/datax/scratch/ellambishop/rfi_removed_hits/Summer*.pkl")
freq_tolerance = 3e-3  # same units as signal_frequency

# All three containers are keyed by tstart and accumulate across files.
incoh_freqs_by_tstart = defaultdict(set)
phase_freqs_by_tstart = defaultdict(set)
# Per tstart: (freq, incoh_snr, incoh_power, coh_power, coh_snr, power_ratio),
# where freq is the incoherent hit's frequency.
incoh_data_by_tstart = defaultdict(list)

for file in file_list:
    if 'Dec-45' in file:
        continue
    try:
        df = pd.read_pickle(file)
    except Exception as e:
        print(f"Error reading {file}: {e}")
        continue

    required_cols = {"source_name", "signal_frequency", "tstart", "signal_snr", "signal_power", "file_uri"}
    if not required_cols.issubset(df.columns):
        print(f"Skipping {file}: missing required columns")
        continue

    df["source_lower"] = df["source_name"].fillna("").str.lower()

    for (file_uri, tstart), group in df.groupby(["file_uri", "tstart"]):
        is_incoherent = group["source_lower"].str.contains("incoherent")
        has_positive_snr = group["signal_snr"] > 0

        incoh_df = group[is_incoherent & has_positive_snr]
        # Everything that is not incoherent counts as coherent here,
        # phase-center rows included.
        coh_df = group[~is_incoherent & has_positive_snr]
        phase_df = group[group["source_lower"].str.contains("phase_center")]

        # Add frequencies to sets (extend the existing set, don't overwrite)
        incoh_freqs_by_tstart[tstart].update(incoh_df["signal_frequency"].unique())
        phase_freqs_by_tstart[tstart].update(phase_df["signal_frequency"].unique())

        if incoh_df.empty or coh_df.empty:
            continue

        coh_freqs = coh_df["signal_frequency"].values.reshape(-1, 1)
        tree = cKDTree(coh_freqs)

        incoh_freqs = incoh_df["signal_frequency"].values.reshape(-1, 1)
        dists, idxs = tree.query(incoh_freqs, distance_upper_bound=freq_tolerance)

        # cKDTree reports "no neighbour within the bound" as index == n.
        matched_mask = idxs != len(coh_freqs)
        if not matched_mask.any():
            continue

        # Select both sides by position so the pairing stays correct even when
        # the frame's index labels are not unique.
        matched_incoh = incoh_df.iloc[np.flatnonzero(matched_mask)]
        matched_coh = coh_df.iloc[idxs[matched_mask]]

        incoh_power = matched_incoh["signal_power"].to_numpy(dtype=float)
        coh_power = matched_coh["signal_power"].to_numpy(dtype=float)

        # Ratio is NaN wherever the incoherent power is not strictly positive.
        power_ratio = np.full(len(incoh_power), np.nan)
        np.divide(coh_power, incoh_power, out=power_ratio, where=incoh_power > 0)

        # Save all useful values, one tuple per matched incoherent hit
        incoh_data_by_tstart[tstart].extend(zip(
            matched_incoh["signal_frequency"].to_numpy(),
            matched_incoh["signal_snr"].to_numpy(),
            incoh_power,
            coh_power,
            matched_coh["signal_snr"].to_numpy(),
            power_ratio,
        ))


# After all files processed, build max snr and power dicts keyed by tstart.
# Each value is itself a dict keyed by `freq`, the first field of the stored
# tuples, holding the maximum over all matches recorded at that frequency.
incoh_freq_to_max_snr_by_tstart = {}
incoh_freq_to_max_power_by_tstart = {}
coh_freq_to_power_by_tstart = {}
coh_freq_to_snr_by_tstart = {}


for tstart, data in incoh_data_by_tstart.items():
    freq_to_incoh_snrs = defaultdict(list)
    freq_to_incoh_powers = defaultdict(list)
    freq_to_coh_snrs = defaultdict(list)
    freq_to_coh_powers = defaultdict(list)
    # NOTE(review): freq_to_ratios is never filled and power_ratio is never
    # aggregated below — confirm whether a per-frequency ratio was intended.
    freq_to_ratios = defaultdict(list)

    for freq, incoh_snr, incoh_power, coh_power, coh_snr, power_ratio in data:
        freq_to_incoh_snrs[freq].append(incoh_snr)
        freq_to_incoh_powers[freq].append(incoh_power)
        freq_to_coh_snrs[freq].append(coh_snr)
        freq_to_coh_powers[freq].append(coh_power)

    # Collapse each per-frequency list to its maximum.
    incoh_freq_to_max_snr_by_tstart[tstart] = {f: max(v) for f, v in freq_to_incoh_snrs.items()}
    incoh_freq_to_max_power_by_tstart[tstart] = {f: max(v) for f, v in freq_to_incoh_powers.items()}
    coh_freq_to_snr_by_tstart[tstart] = {f: max(v) for f, v in freq_to_coh_snrs.items()}
    coh_freq_to_power_by_tstart[tstart] = {f: max(v) for f, v in freq_to_coh_powers.items()}


KeyboardInterrupt: 

In [18]:
# Split hits by observation type (VLASS / non-VLASS; incoherent / phase center / other) and attach the incoherent-match values from the dictionaries built in the matching cell above.

# Remove hits at the excluded drift rate (0 by default) and hits with SNR above 16
def set_filter(file, drift_max=0):
    """Apply the drift-rate and SNR cuts and round frequencies to 4 decimals.

    Parameters
    ----------
    file : pandas.DataFrame
        Hit table with ``signal_drift_rate``, ``signal_snr`` and
        ``signal_frequency`` columns.  It is not modified.
    drift_max : number, optional
        Drift rate to exclude; rows whose ``signal_drift_rate`` equals it are
        dropped.

    Returns
    -------
    pandas.DataFrame
        A new frame holding the rows with ``signal_drift_rate != drift_max``
        and ``signal_snr <= 16`` (rows with NaN SNR fail the comparison and are
        dropped), with ``signal_frequency`` rounded to 4 decimal places.
    """
    keep = (file['signal_drift_rate'] != drift_max) & (file['signal_snr'] <= 16)

    # Explicit copy: the column assignment below then never targets a slice of
    # the caller's frame, and no module-level global is needed to hold it.
    filtered = file.loc[keep].copy()
    filtered['signal_frequency'] = filtered['signal_frequency'].round(4)

    return filtered


file_list = glob.glob("/datax/scratch/ellambishop/rfi_removed_hits/Summer*.pkl")
# Initialize category buckets: each list collects DataFrame slices from every
# processed file and is concatenated once when saving.
vlass_incoherent = []
vlass_phase_center = []
vlass_other = []

non_vlass_incoherent = []
non_vlass_phase_center = []
non_vlass_other = []

# Destination for the categorized pickles; read as a global by save_if_not_empty.
output_dir = "/datax/scratch/ellambishop/hit_test_clean"
os.makedirs(output_dir, exist_ok=True)

def categorize_and_append(df_subset, incoh_list, phase_list, other_list):
    """Split ``df_subset`` by source type and append each non-empty part.

    Rows are classified from the ``source_lower`` column: those containing
    "incoherent" go to ``incoh_list``, those containing "phase_center" go to
    ``phase_list``, and rows matching neither go to ``other_list``.  The three
    lists are modified in place; a list is left untouched when no row falls in
    its category.
    """
    labels = df_subset["source_lower"]
    is_incoherent = labels.str.contains("incoherent")
    is_phase_center = labels.str.contains("phase_center")
    is_other = ~(is_incoherent | is_phase_center)

    routing = (
        (incoh_list, is_incoherent),
        (phase_list, is_phase_center),
        (other_list, is_other),
    )
    for bucket, selector in routing:
        if selector.any():
            bucket.append(df_subset[selector])

def save_if_not_empty(dfs_list, filename):
    """Concatenate ``dfs_list`` and pickle it as ``filename`` inside ``output_dir``.

    Does nothing when the list is empty.  The concatenated frame gets a fresh
    0..n-1 index, and the destination path is printed after writing.
    """
    if not dfs_list:
        return
    out_path = os.path.join(output_dir, filename)
    combined = pd.concat(dfs_list, ignore_index=True)
    combined.to_pickle(out_path)
    print(f"Saved: {out_path}")

print("Starting processing with tstart grouping...")
total_coh_and_incoh = 0
total_coh_only = 0
total_incoh_only = 0

# Every column this loop (and set_filter) reads; a file lacking any of them is
# skipped up front instead of failing part-way through with a KeyError.
required_cols = {
    "source_name", "signal_frequency", "tstart",
    "file_uri", "signal_drift_rate", "signal_snr",
}

for file in file_list:
    if 'Dec-45' in file:
        continue
    try:
        df = pd.read_pickle(file)
    except Exception as e:
        print(f"Error reading {file}: {e}")
        continue

    print(f"\nProcessing file: {file}")

    if not required_cols.issubset(df.columns):
        print(f"Skipping {file}, missing required columns.")
        continue

    df["source_lower"] = df["source_name"].fillna("").str.lower()
    df["signal_frequency"] = df["signal_frequency"].round(4)  # Ensure exact rounding

    # Attach the incoherent-match values per tstart: the lookup dictionaries
    # built by the matching cell are keyed by tstart, then by frequency.
    all_subframes = []
    for tstart, group in df.groupby("tstart"):

        snr_dict = incoh_freq_to_max_snr_by_tstart.get(tstart, {})
        power_dict = incoh_freq_to_max_power_by_tstart.get(tstart, {})
        phase_freqs = phase_freqs_by_tstart.get(tstart, set())

        group = group.copy()  # avoid SettingWithCopyWarning
        group["incoh_snr"] = group["signal_frequency"].map(snr_dict)
        group["incoh_power"] = group["signal_frequency"].map(power_dict)
        group["freq_has_incoherent_counterpart"] = group["incoh_snr"].notnull()
        group["freq_has_phase_center_counterpart"] = group["signal_frequency"].isin(phase_freqs)

        num_flagged = group["freq_has_incoherent_counterpart"].sum()
        print(f"tstart: {tstart} | Flagged incoherent freqs: {num_flagged}")

        if num_flagged > 0:
            print("Sample flagged rows:")
            print(group[group["freq_has_incoherent_counterpart"]].head(3)[["tstart", "signal_frequency", "incoh_snr", "source_lower"]])

        all_subframes.append(group)

    # An empty file, or one whose tstart values are all missing, yields no
    # groups; pd.concat([]) would raise, so skip the file instead.
    if not all_subframes:
        print(f"Skipping {file}, no rows with a valid tstart.")
        continue

    df = pd.concat(all_subframes, ignore_index=True)

    print(f"Total rows before filtering: {len(df)}")
    df = set_filter(df, drift_max=0)
    print(f"Total rows after filtering: {len(df)}")

    is_incoherent = df["source_lower"].str.contains("incoherent")
    coh_df = df[~is_incoherent].copy()
    incoh_df = df[is_incoherent].copy()

    # Compare the (tstart, frequency) pairs seen on each side.
    coh_keys = set(tuple(x) for x in coh_df[["tstart", "signal_frequency"]].drop_duplicates().values)
    incoh_keys = set(tuple(x) for x in incoh_df[["tstart", "signal_frequency"]].drop_duplicates().values)

    both = coh_keys & incoh_keys
    coh_only = coh_keys - incoh_keys
    incoh_only = incoh_keys - coh_keys

    # Add to global totals
    total_coh_and_incoh += len(both)
    total_coh_only += len(coh_only)
    total_incoh_only += len(incoh_only)

    # VLASS / non-VLASS split
    vlass_mask = df["file_uri"].str.contains("vlass", case=False, na=False)
    df_vlass = df[vlass_mask].copy()
    df_non_vlass = df[~vlass_mask].copy()

    if not df_vlass.empty:
        categorize_and_append(df_vlass, vlass_incoherent, vlass_phase_center, vlass_other)
    if not df_non_vlass.empty:
        categorize_and_append(df_non_vlass, non_vlass_incoherent, non_vlass_phase_center, non_vlass_other)

print("\nFinal totals:")
print(f"Total coherent + incoherent matches: {total_coh_and_incoh}")
print(f"Total coherent only: {total_coh_only}")
print(f"Total incoherent only: {total_incoh_only}")

# Save all categorized dfs
save_if_not_empty(vlass_incoherent, "vlass_incoherent.pkl")
save_if_not_empty(vlass_phase_center, "vlass_phase_center.pkl")
save_if_not_empty(vlass_other, "vlass_other.pkl")

save_if_not_empty(non_vlass_incoherent, "non_vlass_incoherent.pkl")
save_if_not_empty(non_vlass_phase_center, "non_vlass_phase_center.pkl")
save_if_not_empty(non_vlass_other, "non_vlass_other.pkl")


print("Processing complete.")


Starting processing with tstart grouping...
Starting processing with tstart grouping...

Processing file: /datax/scratch/ellambishop/rfi_removed_hits/Summer_Project_RA10_Dec0.26.pkl
tstart: 60757.07977429824 | Flagged incoherent freqs: 128
Sample flagged rows:
            tstart  signal_frequency  incoh_snr         source_lower
1346  60757.079774         2659.9439  15.844229  3996427005473589888
1347  60757.079774         2659.6450  19.924137  3996427005473589888
1348  60757.079774         2659.8862   8.446166  3996427005473589888
tstart: 60757.08016265975 | Flagged incoherent freqs: 118
Sample flagged rows:
            tstart  signal_frequency   incoh_snr         source_lower
1475  60757.080163         2659.6681   82.247833  3996427005473589888
1476  60757.080163         2659.5680  120.397026  3996427005473589888
1477  60757.080163         2659.9191   85.393219  3996427005473589888
tstart: 60757.08025975013 | Flagged incoherent freqs: 5
Sample flagged rows:
           tstart  signal_f

In [None]:
# Collect, across all RFI-cleaned files, every distinct frequency seen on an
# incoherent source and on a phase_center source (regardless of tstart).
file_list = glob.glob("/datax/scratch/ellambishop/rfi_removed_hits/Summer*.pkl")

global_incoherent_freqs = set()
global_phase_center_freqs = set()

for file in file_list:
    if 'Dec-45' in file:
        continue  # Skip large files
    try:
        df = pd.read_pickle(file)
    except Exception as e:
        print(f"Error reading {file}: {e}")
        continue
    # Files missing either column are skipped without a message.
    if "source_name" in df.columns and "signal_frequency" in df.columns:
        # Missing names become "" so they match neither pattern.
        source_lower = df["source_name"].fillna("").str.lower()
        incoherent_freqs = df.loc[source_lower.str.contains("incoherent"), "signal_frequency"].unique()
        phase_center_freqs = df.loc[source_lower.str.contains("phase_center"), "signal_frequency"].unique()
        global_incoherent_freqs.update(incoherent_freqs)
        global_phase_center_freqs.update(phase_center_freqs)

print(f"Total unique incoherent frequencies: {len(global_incoherent_freqs)}")
print(f"Total unique phase_center frequencies: {len(global_phase_center_freqs)}")

In [14]:
def processing(df):
    """Drop hits matching the timestep/SNR RFI conditions, ordered by group.

    A hit is dropped when either condition holds:

    * ``16 <= signal_num_timesteps <= 64`` and ``signal_snr > 10``
    * ``signal_num_timesteps < 16`` and ``signal_snr > 15``

    Parameters
    ----------
    df : pandas.DataFrame
        Hit table with ``file_uri``, ``observation_id``,
        ``signal_num_timesteps`` and ``signal_snr`` columns.  It is not
        modified.

    Returns
    -------
    pandas.DataFrame
        The surviving rows, ordered by ``(file_uri, observation_id)`` group
        with the original row order inside each group, and a fresh 0..n-1
        index.  Rows whose grouping keys are missing are not returned.  An
        empty frame with the same columns is returned when nothing survives,
        including when ``df`` itself is empty.
    """
    # Both conditions are row-wise, so the mask is computed once for the whole
    # frame; grouping afterwards only fixes the output order.
    drop_mid_length = (
        (df['signal_num_timesteps'] >= 16) &
        (df['signal_num_timesteps'] <= 64) &
        (df['signal_snr'] > 10)
    )
    drop_short = (
        (df['signal_num_timesteps'] < 16) &
        (df['signal_snr'] > 15)
    )
    kept = df.loc[~(drop_mid_length | drop_short)]

    filtered_groups = [
        group_df for _, group_df in kept.groupby(['file_uri', 'observation_id'])
    ]

    # pd.concat raises on an empty list; return an empty frame instead.
    if not filtered_groups:
        return kept.iloc[0:0].reset_index(drop=True)

    return pd.concat(filtered_groups, ignore_index=True)


# Re-run the categorisation pipeline on the RFI-cleaned files, this time using
# processing() for the cut. These assignments reset the bucket lists, so any
# frames collected by an earlier cell are discarded.
file_list = glob.glob("/datax/scratch/ellambishop/rfi_removed_hits/Summer*.pkl")

vlass_incoherent = []
vlass_phase_center = []
vlass_other = []

non_vlass_incoherent = []
non_vlass_phase_center = []
non_vlass_other = []

def set_filter(df, drift_max=0):
    """Return a new frame without the rows whose drift rate equals ``drift_max``.

    Unlike the earlier ``set_filter`` in this notebook, which this definition
    replaces once the cell runs, no SNR cut and no frequency rounding are
    applied.  The input frame is not modified.
    """
    drifting = df['signal_drift_rate'] != drift_max
    return df.copy()[drifting]

def categorize_and_append(df_subset, incoh_list, phase_list, other_list):
    """Split ``df_subset`` by source type and append each non-empty part.

    Classification uses the lower-cased ``source_name`` (missing names count
    as ""): "incoherent" rows go to ``incoh_list``, "phase_center" rows to
    ``phase_list``, everything else to ``other_list``.  The lists are
    modified in place.
    """
    source_lower_sub = df_subset["source_name"].fillna("").str.lower()
    incoherent_mask = source_lower_sub.str.contains("incoherent")
    phase_center_mask = source_lower_sub.str.contains("phase_center")
    other_mask = ~(incoherent_mask | phase_center_mask)

    if incoherent_mask.any():
        incoh_list.append(df_subset[incoherent_mask])
    if phase_center_mask.any():
        phase_list.append(df_subset[phase_center_mask])
    if other_mask.any():
        other_list.append(df_subset[other_mask])


# Columns read by set_filter, processing and the split below; checked before
# any of them runs so an incomplete file is skipped instead of raising.
pipeline_required_cols = {
    "file_uri", "source_name", "observation_id",
    "signal_drift_rate", "signal_num_timesteps", "signal_snr",
}

for file in file_list:
    if 'Dec-45' in file:
        continue

    try:
        df = pd.read_pickle(file)
    except Exception as e:
        print(f"Error reading {file}: {e}")
        continue

    if not pipeline_required_cols.issubset(df.columns):
        print(f"Skipping {file}: missing required columns")
        continue

    df = set_filter(df)
    if df.empty:
        # Nothing left after the drift-rate cut; processing() has nothing to group.
        continue

    df = processing(df)  # Apply the timestep/SNR RFI filtering
    if df.empty:
        continue

    vlass_mask = df["file_uri"].str.contains("vlass", case=False, na=False)
    df_vlass = df[vlass_mask].copy()
    df_non_vlass = df[~vlass_mask].copy()

    if not df_vlass.empty:
        categorize_and_append(df_vlass, vlass_incoherent, vlass_phase_center, vlass_other)

    if not df_non_vlass.empty:
        categorize_and_append(df_non_vlass, non_vlass_incoherent, non_vlass_phase_center, non_vlass_other)

# Destination for this pipeline's pickles; rebinding output_dir also changes
# where any later call to save_if_not_empty writes.
output_dir = "/datax/scratch/ellambishop/test_refine"
os.makedirs(output_dir, exist_ok=True)

def save_if_not_empty(dfs_list, filename):
    """Pickle the concatenation of ``dfs_list`` as ``filename`` in ``output_dir``.

    An empty list is ignored.  The combined frame is re-indexed from 0 and the
    written path is printed.
    """
    if not dfs_list:
        return
    out_path = os.path.join(output_dir, filename)
    merged = pd.concat(dfs_list, ignore_index=True)
    merged.to_pickle(out_path)
    print(f"Saved: {out_path}")

# Write one pickle per category; empty categories produce no file.
save_if_not_empty(vlass_incoherent, "vlass_incoherent.pkl")
save_if_not_empty(vlass_phase_center, "vlass_phase_center.pkl")
save_if_not_empty(vlass_other, "vlass_other.pkl")

save_if_not_empty(non_vlass_incoherent, "non_vlass_incoherent.pkl")
save_if_not_empty(non_vlass_phase_center, "non_vlass_phase_center.pkl")
save_if_not_empty(non_vlass_other, "non_vlass_other.pkl")


Saved: /datax/scratch/ellambishop/test_refine/vlass_incoherent.pkl
Saved: /datax/scratch/ellambishop/test_refine/vlass_phase_center.pkl
Saved: /datax/scratch/ellambishop/test_refine/vlass_other.pkl
Saved: /datax/scratch/ellambishop/test_refine/non_vlass_incoherent.pkl
Saved: /datax/scratch/ellambishop/test_refine/non_vlass_phase_center.pkl
Saved: /datax/scratch/ellambishop/test_refine/non_vlass_other.pkl


In [15]:
def qa_summary(name, dfs):
    """Print one summary line with the total row count of a category.

    ``dfs`` is a list of DataFrames.  A non-empty list prints the combined
    number of rows; an empty list prints a zero row that also carries a
    "Flagged" field.
    """
    if not dfs:
        print(f"{name:<30} | Total:      0 | Flagged:      0")
        return
    total = len(pd.concat(dfs))
    print(f"{name:<30} | Total: {total:>6} ")

# Row counts per category for the bucket lists currently in memory.
print("\n==== QA Summary ====")
qa_summary("VLASS Incoherent", vlass_incoherent)
qa_summary("VLASS Phase Center", vlass_phase_center)
qa_summary("VLASS Other", vlass_other)
qa_summary("Non-VLASS Incoherent", non_vlass_incoherent)
qa_summary("Non-VLASS Phase Center", non_vlass_phase_center)
qa_summary("Non-VLASS Other", non_vlass_other)

# TODO: no totals are printed under this header yet — only the header itself.
print("\n==== Final Totals ====")

print("Processing complete.")



==== QA Summary ====
VLASS Incoherent               | Total:  21287 
VLASS Phase Center             | Total:    309 
VLASS Other                    | Total: 186145 
Non-VLASS Incoherent           | Total: 232695 
Non-VLASS Phase Center         | Total: 2181407 
Non-VLASS Other                | Total: 804791 

==== Final Totals ====
Processing complete.


In [None]:

def qa_summary(name, dfs):
    """Print total and flagged row counts for one category.

    Parameters
    ----------
    name : str
        Label printed in the first column.
    dfs : list of pandas.DataFrame
        Frames collected for the category.

    "Flagged" counts rows whose ``freq_has_incoherent_counterpart`` is True.
    An empty list prints a zero row so every category appears in the report.
    When the flag column is absent from the frames, "n/a" is printed in place
    of the count rather than raising a KeyError.
    """
    if not dfs:
        print(f"{name:<30} | Total: {0:>6} | Flagged: {0:>5}")
        return

    df_all = pd.concat(dfs)
    total = len(df_all)

    if "freq_has_incoherent_counterpart" in df_all.columns:
        # eq(True) treats missing values (frames without the column) as not flagged.
        flagged = int(df_all["freq_has_incoherent_counterpart"].eq(True).sum())
    else:
        flagged = "n/a"

    print(f"{name:<30} | Total: {total:>6} | Flagged: {flagged:>5}")

# Total and flagged counts per category for the bucket lists currently in memory.
print("\n==== QA Summary ====")
qa_summary("VLASS Incoherent", vlass_incoherent)
qa_summary("VLASS Phase Center", vlass_phase_center)
qa_summary("VLASS Other", vlass_other)
qa_summary("Non-VLASS Incoherent", non_vlass_incoherent)
qa_summary("Non-VLASS Phase Center", non_vlass_phase_center)
qa_summary("Non-VLASS Other", non_vlass_other)


8875194

==== QA Summary ====
VLASS Incoherent               | Total:  98718 | Flagged: 96063
VLASS Phase Center             | Total:   1735 | Flagged:     0
VLASS Other                    | Total: 638961 | Flagged: 247922
Non-VLASS Incoherent           | Total: 8575202 | Flagged: 3108957
Non-VLASS Phase Center         | Total: 7350331 | Flagged:     0
Non-VLASS Other                | Total: 11482838 | Flagged: 8875194


In [13]:
# Count distinct (tstart, signal_frequency) pairs flagged as having an
# incoherent counterpart, over the four non-incoherent categories.
# Pairs are de-duplicated within each category only, so a pair present in
# several categories is counted once per category.
flagged_coherent_unique = 0
for dfs in [vlass_phase_center, vlass_other, non_vlass_phase_center, non_vlass_other]:
    if dfs:
        df_all = pd.concat(dfs)
        unique_pairs = df_all[df_all["freq_has_incoherent_counterpart"]].drop_duplicates(subset=["tstart", "signal_frequency"])
        flagged_coherent_unique += len(unique_pairs)

print(f"Unique flagged coherent freq/tstart pairs: {flagged_coherent_unique}")
print(f"Matched pairs from final totals: {total_coh_and_incoh}")
# NOTE(review): both sets below are created empty and never filled in this
# cell, so the two differences printed afterwards are always 0 and say nothing
# about the data. TODO: populate them with the real (tstart, freq) pairs.
flagged_coh_pairs = set()  # e.g., (tstart, freq) pairs flagged as having incoherent counterpart
matched_pairs = set()      # your final matched (tstart, freq) pairs

print(f"Flagged but unmatched coherent pairs: {len(flagged_coh_pairs - matched_pairs)}")
print(f"Matched pairs not flagged: {len(matched_pairs - flagged_coh_pairs)}")


Unique flagged coherent freq/tstart pairs: 750953
Matched pairs from final totals: 500243
Flagged but unmatched coherent pairs: 0
Matched pairs not flagged: 0


In [15]:

def analyze_matched_hits_strict(pickle_path, output_path=None, freq_tolerance=1e-3):
    """Match incoherent hits to coherent hits and plot the filter-stage counts.

    Within every ``(file_uri, tstart)`` group, each incoherent hit with
    positive SNR is paired with the highest-SNR coherent hit (positive SNR,
    neither incoherent nor phase-center) whose frequency lies within
    ``freq_tolerance``.  At most one pair is produced per incoherent hit.

    Parameters
    ----------
    pickle_path : str
        Pickled hit table with ``file_uri``, ``tstart``, ``source_name``,
        ``signal_frequency``, ``signal_snr``, ``signal_power`` and
        ``noise_est`` columns.
    output_path : str, optional
        Where to save the matched table: CSV when the name ends in ".csv",
        pickle otherwise.  Nothing is written when omitted.
    freq_tolerance : float, optional
        Maximum absolute frequency difference for a match, in the units of
        ``signal_frequency``.

    Returns
    -------
    pandas.DataFrame or None
        One row per matched pair with the derived ratio and dB columns, or
        ``None`` (after printing a message) when nothing matched.
    """
    df = pd.read_pickle(pickle_path)

    matched_pairs = []
    total_incoh = 0
    snr_filtered_incoh = 0

    for fov_uri, fov_df in df.groupby('file_uri'):
        for tstart, time_df in fov_df.groupby('tstart'):
            names = time_df['source_name']
            # na=False: a missing source name is neither incoherent nor
            # phase-center, and keeps both masks strictly boolean.
            is_incoherent = names.str.contains('Incoherent', case=False, na=False)
            is_excluded = names.str.contains('Incoherent|PHASE_CENTER', case=False, na=False)

            incoh_hits = time_df[is_incoherent]
            coh_hits = time_df[~is_excluded]

            total_incoh += len(incoh_hits)
            incoh_hits = incoh_hits[incoh_hits['signal_snr'] > 0]
            coh_hits = coh_hits[coh_hits['signal_snr'] > 0]
            snr_filtered_incoh += len(incoh_hits)

            if incoh_hits.empty or coh_hits.empty:
                continue

            # Plain arrays for the inner loop: no column is written onto a
            # slice, and the best match is picked by position.
            coh_freqs = coh_hits['signal_frequency'].to_numpy()
            coh_snrs = coh_hits['signal_snr'].to_numpy()

            for _, incoh_row in incoh_hits.iterrows():
                freq = incoh_row['signal_frequency']
                close_positions = np.flatnonzero(np.abs(coh_freqs - freq) <= freq_tolerance)

                if close_positions.size:
                    # Pick best match: highest SNR (first one on ties)
                    best_position = close_positions[np.argmax(coh_snrs[close_positions])]
                    best_match = coh_hits.iloc[best_position]
                    matched_pairs.append({
                        'fov_uri': fov_uri,
                        'tstart': tstart,
                        'freq_mhz': freq,
                        'incoh_noise_est': incoh_row['noise_est'],
                        'incoh_snr': incoh_row['signal_snr'],
                        'coh_noise_est': best_match['noise_est'],
                        'coh_snr': best_match['signal_snr'],
                        'incoh_power': incoh_row['signal_power'],
                        'coh_power': best_match['signal_power']
                    })

    matched_df = pd.DataFrame(matched_pairs)
    if matched_df.empty:
        print("No matched hits found.")
        return None

    # Print diagnostic: are duplicates fixed?
    print("Unique incoherent hits:", matched_df[['fov_uri', 'tstart', 'freq_mhz']].drop_duplicates().shape[0])
    print("Total matched rows:", matched_df.shape[0])

    # Compute ratio-based metrics
    # NOTE(review): 'power_ratio' is computed from the noise estimates, not
    # from signal_power — confirm the naming is intended.
    matched_df['power_ratio'] = matched_df['coh_noise_est'] / matched_df['incoh_noise_est']
    # NOTE(review): 15-25 is the hard-coded "expected" window — document its origin.
    matched_df['ratio_within_expected'] = matched_df['power_ratio'].between(15, 25)
    matched_df['coh_power_db'] = 10 * np.log10(matched_df['coh_power'])
    matched_df['incoh_power_db'] = 10 * np.log10(matched_df['incoh_power'])

    if output_path is not None:
        if str(output_path).lower().endswith('.csv'):
            matched_df.to_csv(output_path, index=False)
        else:
            matched_df.to_pickle(output_path)

    # Track filter stages
    counts = [
        total_incoh,
        snr_filtered_incoh,
        len(matched_df),
        matched_df['ratio_within_expected'].sum()
    ]
    stages = ['total', 'snr_filtered', 'matched', 'antenna_ratio_ok']

    # Plot filtering stages
    plt.figure(figsize=(8, 5))
    plt.bar(stages, counts, color='mediumslateblue')
    plt.ylabel('Number of Hits')
    plt.title('Hits Remaining After Each Filter Stage')
    plt.tight_layout()
    plt.show()

    return matched_df



In [37]:
# Spot-check one saved category: number of flagged rows vs. total rows, and
# the columns the file carries.
df = pd.read_pickle('/datax/scratch/ellambishop/hit_test_clean/vlass_other.pkl')
count = df['freq_has_incoherent_counterpart'].sum()
print(count, len(df))
print(df.columns)

247922 638961
Index(['id', 'beam_id', 'observation_id', 'tuning', 'subband_offset',
       'file_uri', 'file_local_enumeration', 'signal_frequency',
       'signal_index', 'signal_drift_steps', 'signal_drift_rate', 'signal_snr',
       'signal_coarse_channel', 'signal_beam', 'signal_num_timesteps',
       'signal_power', 'signal_incoherent_power', 'source_name', 'fch1_mhz',
       'foff_mhz', 'tstart', 'tsamp', 'ra_hours', 'dec_degrees',
       'telescope_id', 'num_timesteps', 'num_channels', 'coarse_channel',
       'start_channel', 'source_lower', 'incoh_snr', 'incoh_power',
       'freq_has_incoherent_counterpart', 'freq_has_phase_center_counterpart'],
      dtype='object')


In [None]:
# Find which columns have the fewest distinct values across all hit files
# (candidates for grouping keys).
# os, pandas and defaultdict come from the notebook's import cell.

# Track unique values for each column
global_uniques = defaultdict(set)
# Object columns whose values cannot be hashed (e.g. lists/arrays) are excluded.
unhashable_cols = set()

directory = '/datax/scratch/ellambishop/rfi_removed_hits/'
total = 0

for file in sorted(os.listdir(directory)):
    if 'Dec-45' in file:
        continue  # Skip large files
    if not file.endswith('.pkl'):
        continue  # Only pickled hit tables can be read below

    file_path = os.path.join(directory, file)
    try:
        df = pd.read_pickle(file_path)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        continue

    print(f"Processing {file} ({len(df)} rows)")
    total += len(df)

    for col in df.columns:
        if col in unhashable_cols:
            continue
        # Limit to only object/int/str columns (skip arrays, large blobs)
        if df[col].dtype.kind in {'O', 'i', 'u', 'S'}:
            try:
                unique_vals = df[col].unique()
                global_uniques[col].update(unique_vals)
            except TypeError:
                # Unhashable entries: this column cannot be a grouping key.
                unhashable_cols.add(col)
                global_uniques.pop(col, None)

# Sort and print columns with fewest unique values
print("\nGlobal grouping candidates with fewest unique values:")
summary = {col: len(vals) for col, vals in global_uniques.items()}
for col, count in sorted(summary.items(), key=lambda x: x[1]):
    print(f"{col:<25} → {count} unique values")

print(f"\nTotal rows across all files: {total}")


Processing Summer_Project_RA10_Dec0.1.pkl (712727 rows)
Processing Summer_Project_RA10_Dec0.10.pkl (464136 rows)
Processing Summer_Project_RA10_Dec0.11.pkl (157980 rows)
Processing Summer_Project_RA10_Dec0.12.pkl (179218 rows)
Processing Summer_Project_RA10_Dec0.13.pkl (112923 rows)
Processing Summer_Project_RA10_Dec0.14.pkl (1000000 rows)


KeyboardInterrupt: 

In [None]:
# Re-print the summary from whatever global_uniques/total currently hold, e.g.
# partial results after the scan in the previous cell was interrupted.
print("\nGlobal grouping candidates with fewest unique values:")
summary = {col: len(vals) for col, vals in global_uniques.items()}
for col, count in sorted(summary.items(), key=lambda x: x[1]):
    print(f"{col:<25} → {count} unique values")

print(f"\nTotal rows across all files: {total}")
