In [None]:
# When the kernel is started inside notebooks/, step up one directory and put
# that directory on sys.path, so that parent-level modules can be imported and
# the relative data/ paths used in this notebook resolve.
import os
import sys

if os.path.split(os.getcwd())[1] == "notebooks":
    os.chdir("..")
    sys.path.append(os.path.abspath("."))
    
from astropy.io import fits
import numpy as np
import numpy as np
import matplotlib.pyplot as plt
from astropy.coordinates import SkyCoord
import astropy.units as u
from tqdm import tqdm
import warnings
import pickle
from extreme_deconvolution import extreme_deconvolution
from scipy.linalg import det 
warnings.filterwarnings("ignore", category=FutureWarning)
import json

# Path of the FITS catalogue read by the Extreme Deconvolution cell. It is a
# relative path, so it resolves against the current working directory (the
# first lines of this cell move up one level when started from notebooks/).
# The XD cell reads HDU 1 and the columns ra, dec, pmra, pmdec and their
# matching *_error columns from this file.
filtered_corrected_data = 'data/Allsky_Gaia_42481846_extinction_corrected_filtered.fits'

# Extreme Deconvolution Pipeline
- This pipeline was originally built to apply extreme deconvolution to the whole dataset
- This yielded terrible results, as it wasn't able to separate noise from tiny-scale structures
- It could now be used to take the results from my custom algorithm and separate noise from cluster, i.e. create an inner spherical distribution from the 4D bin
- I did not have time to run this and process the results, so the pipeline is provided here as an extension


In [None]:
# Read the matched-cluster bounding boxes (keyed by cluster name) from JSON
with open("data/matched_clusters_LowerPM.json", "r") as f:
    matched_clusters_dict_lower = json.load(f)

# Flatten each entry into one tuple:
# (name, l_min, l_max, b_min, b_max, pm_ra_min, pm_ra_max, pm_dec_min, pm_dec_max)
cluster_ranges = [
    (name,) + tuple(
        bounds[key]
        for key in (
            "l_min", "l_max",
            "b_min", "b_max",
            "pm_ra_min", "pm_ra_max",
            "pm_dec_min", "pm_dec_max",
        )
    )
    for name, bounds in matched_clusters_dict_lower.items()
]


# Explanation:
- Per hypercube, this tries to fit two Gaussians: one for the background and one for the cluster
- It selects the Gaussian with the smallest covariance-matrix determinant (i.e. the tightest one), on the assumption that this will be the Gaussian fitting the cluster
- The background Gaussian will be broader

In [None]:
# ---------------------------------------------------------------------------
# Extreme Deconvolution (XD) of every matched cluster hypercube.
#
# For each cluster: select the stars inside a padded (RA, Dec, pmRA, pmDec)
# box, fit a two-component Gaussian mixture with XD (one tight "signal"
# component, one broad "background" component), keep the restart with the
# highest log-likelihood, and store the component with the smallest
# covariance determinant. Exactly one entry per cluster is appended to
# `final_results`, which is pickled at the end.
# ---------------------------------------------------------------------------

# Tunables
N_XD_RUNS = 5          # XD restarts per cluster
MIN_STARS = 10         # clusters with fewer selected stars are skipped
XD_MAX_ITER = 1000     # iteration cap handed to extreme_deconvolution
XD_SEED = 42           # seed for the restart jitter (reproducible runs)
MAS_PER_DEG = 3.6e6    # milliarcseconds per degree


def ra_window_mask(ra_values, ra_centre, half_width):
    """Return a boolean mask of RA values within `half_width` deg of `ra_centre`.

    Unlike a plain `lo <= ra <= hi` test, this also works when the window
    crosses the 0/360 degree seam (e.g. centre 1 deg, half-width 3 deg).
    `ra_centre` is expected in [0, 360).
    """
    if half_width >= 180.0:
        # The window covers the full circle.
        return np.ones(np.shape(ra_values), dtype=bool)
    lower, upper = ra_centre - half_width, ra_centre + half_width
    if lower < 0.0 or upper >= 360.0:
        # Window straddles the seam: accept both ends of the RA range.
        return (ra_values >= lower % 360.0) | (ra_values <= upper % 360.0)
    return (ra_values >= lower) & (ra_values <= upper)


def fit_signal_gaussian(cluster_id, stars, star_variances, signal_mean, background_mean,
                        jitter_scale, rng, n_runs=N_XD_RUNS):
    """Fit a two-Gaussian XD model several times and return the signal component.

    Parameters
    ----------
    cluster_id : identifier stored in the result and used in failure messages.
    stars : (n_stars, d) float64 array of observed values.
    star_variances : (n_stars, d) float64 array of per-star error *variances*.
    signal_mean, background_mean : length-d initial means of the two components.
    jitter_scale : length-d standard deviations of the Gaussian jitter added to
        the signal mean on restarts (run 0 always uses the nominal start).
    rng : numpy.random.Generator used for the jitter.
    n_runs : number of restarts.

    Returns
    -------
    dict with keys 'cluster_id', 'mean', 'covariance', 'log_likelihood' for the
    component with the smallest covariance determinant of the best restart,
    or None if every restart failed.
    """
    n_dim = stars.shape[1]
    best_result = None
    best_log_likelihood = -np.inf

    for run in range(n_runs):
        # Fresh initial guesses for every restart: XD updates them in place.
        xd_means = np.array([signal_mean, background_mean], dtype=np.float64)
        if run > 0:
            # Without a perturbation every restart would start from the same
            # point and converge to the same solution.
            xd_means[0] += rng.normal(loc=0.0, scale=jitter_scale)
        xd_covs = np.array([np.identity(n_dim) * 0.1, np.identity(n_dim) * 10.0])  # signal tight, background broad
        xd_weights = np.array([0.5, 0.5])  # equal initial amplitudes

        try:
            # Positional order of the wrapper is (ydata, ycovar, xamp, xmean, xcovar).
            log_likelihood = extreme_deconvolution(
                stars, star_variances, xd_weights, xd_means, xd_covs, maxiter=XD_MAX_ITER
            )
        except Exception as e:
            print(f"XD failed for cluster {cluster_id} on run {run}: {e}")
            continue

        # A NaN log-likelihood compares False here and is therefore ignored.
        if log_likelihood > best_log_likelihood:
            best_log_likelihood = log_likelihood
            best_result = {
                'cluster_id': cluster_id,
                'mean': xd_means.copy(),
                'covariance': xd_covs.copy(),
                'log_likelihood': log_likelihood
            }

    if best_result is None:
        return None

    # The component with the smallest covariance determinant is taken as the signal.
    cov_dets = [det(cov) for cov in best_result['covariance']]
    signal_idx = int(np.argmin(cov_dets))
    return {
        'cluster_id': best_result['cluster_id'],
        'mean': best_result['mean'][signal_idx],
        'covariance': best_result['covariance'][signal_idx],
        'log_likelihood': best_result['log_likelihood']
    }


# Store results (one entry per successfully fitted cluster)
final_results = []
xd_rng = np.random.default_rng(XD_SEED)

# Open the catalogue once; the column handles are reused for every cluster.
with fits.open(filtered_corrected_data, memmap=True) as hdul:
    data = hdul[1].data
    ra, dec, pm_ra, pm_dec = data['ra'], data['dec'], data['pmra'], data['pmdec']
    ra_error, dec_error = data['ra_error'], data['dec_error']
    pm_ra_error, pm_dec_error = data['pmra_error'], data['pmdec_error']

    # Iterate over all identified clusters
    for cluster in tqdm(cluster_ranges, desc="Processing Clusters", dynamic_ncols=True):
        cluster_id, l_min, l_max, b_min, b_max, pm_ra_min, pm_ra_max, pm_dec_min, pm_dec_max = cluster

        # Half-width of the positional search box: 1.5x the larger (l, b) extent.
        # NOTE(review): this width is applied directly in RA degrees, with no
        # 1/cos(dec) scaling (unchanged from the original selection).
        deg_range = (max(l_max - l_min, b_max - b_min) * 1.5) / 2
        l_med, b_med = (l_min + l_max) / 2, (b_min + b_max) / 2

        # Convert central L, B to RA, DEC
        coord_med = SkyCoord(l=l_med * u.degree, b=b_med * u.degree, frame='galactic')
        ra_med, dec_med = coord_med.icrs.ra.degree, coord_med.icrs.dec.degree

        # Extended search region: proper-motion ranges padded by 25% on each side
        dec_min_ext, dec_max_ext = dec_med - deg_range, dec_med + deg_range
        pm_ra_pad = 0.25 * (pm_ra_max - pm_ra_min)
        pm_dec_pad = 0.25 * (pm_dec_max - pm_dec_min)
        pm_ra_min_ext, pm_ra_max_ext = pm_ra_min - pm_ra_pad, pm_ra_max + pm_ra_pad
        pm_dec_min_ext, pm_dec_max_ext = pm_dec_min - pm_dec_pad, pm_dec_max + pm_dec_pad

        # Filter the data within the extended region
        mask = (
            ra_window_mask(ra, ra_med, deg_range) &
            (dec >= dec_min_ext) & (dec <= dec_max_ext) &
            (pm_ra >= pm_ra_min_ext) & (pm_ra <= pm_ra_max_ext) &
            (pm_dec >= pm_dec_min_ext) & (pm_dec <= pm_dec_max_ext)
        )

        # Too few stars for a meaningful two-component fit: skip
        if np.count_nonzero(mask) < MIN_STARS:
            continue

        # Express RA as a continuous offset around the box centre so a cluster
        # straddling RA = 0/360 is not split into two groups ~360 deg apart.
        cluster_ra_data = ra_med + ((ra[mask] - ra_med + 180.0) % 360.0 - 180.0)

        # Native float64 arrays, one row per star.
        cluster_data_set = np.stack(
            [cluster_ra_data, dec[mask], pm_ra[mask], pm_dec[mask]], axis=1
        ).astype(np.float64)

        # NOTE(review): the Gaia archive schema gives ra_error/dec_error in mas
        # while ra/dec are in degrees, so they are converted here. Confirm that
        # this file keeps the archive units. ra_error is the uncertainty in
        # ra*cos(dec); no 1/cos(dec) factor is applied.
        cluster_data_set_error = np.stack(
            [
                ra_error[mask] / MAS_PER_DEG,
                dec_error[mask] / MAS_PER_DEG,
                pm_ra_error[mask],
                pm_dec_error[mask],
            ],
            axis=1,
        ).astype(np.float64)
        # A 2-D ycovar is read by XD as the diagonal of each error covariance,
        # i.e. variances, not standard deviations.
        cluster_data_set_variance = cluster_data_set_error ** 2

        # Perform XD several times and keep the best restart's signal component
        signal_fit = fit_signal_gaussian(
            cluster_id,
            cluster_data_set,
            cluster_data_set_variance,
            signal_mean=[ra_med, dec_med, (pm_ra_min + pm_ra_max) / 2, (pm_dec_min + pm_dec_max) / 2],
            background_mean=[ra_med, dec_med, 0, 0],  # background: assume mean PM ~ 0
            jitter_scale=0.1 * np.abs([deg_range, deg_range, pm_ra_max - pm_ra_min, pm_dec_max - pm_dec_min]),
            rng=xd_rng,
        )
        if signal_fit is None:
            continue

        # Report the fitted RA back in the conventional [0, 360) range.
        signal_fit['mean'][0] %= 360.0
        final_results.append(signal_fit)

# Save only the final results
save_file_path = 'data/glob_clust_XD.pkl'
with open(save_file_path, 'wb') as f:
    pickle.dump(final_results, f)

print(f"Results saved successfully to {save_file_path}.")