In [2]:
import os
import torch
import io
import gzip
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import torch.nn as nn
import seaborn as sns
from tqdm import tqdm
from astropy.io import fits
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from astropy import units as u
from astropy.coordinates import SkyCoord
from astroquery.gaia import Gaia
from scipy.interpolate import interp1d
from sklearn.preprocessing import MinMaxScaler, PowerTransformer
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score
from requests.exceptions import RequestException

In [22]:
def predict_star_labels(gaia_ids, model_path, lamost_catalogue, gaia_transformer_path):
    """
    Given a list of Gaia DR3 IDs, this function:
    1) Queries Gaia for star parameters.
    2) Cross-matches with LAMOST spectra.
    3) Downloads and processes LAMOST spectra.
    4) Normalizes both Gaia and LAMOST data.
    5) Applies a trained StarClassifierFusion model to predict labels.
    
    Returns:
        DataFrame with Gaia IDs and predicted multi-label classifications.
    """

    print("\n🚀 Step 1: Querying Gaia data...")
    print("🔗 Gaia IDs:", len(gaia_ids))
    df_gaia = query_gaia_data(gaia_ids)
    if df_gaia.empty:
        print("⚠️ No Gaia data found. Exiting.")
        return None
    print("🔗 Gaia data:", df_gaia.shape)

    print("\n🔄 Step 2: Cross-matching with LAMOST catalog...")
    
    df_matched = crossmatch_lamost(df_gaia, lamost_catalogue)
    if df_matched.empty:
        print("⚠️ No LAMOST matches found. Exiting.")
        return None

    print("\n📥 Step 3: Downloading LAMOST spectra (if needed)...")
    obsids = df_matched["obsid"].unique()
    spectra_folder = "lamost_spectra_uniques"
    download_lamost_spectra(obsids, save_folder=spectra_folder, num_workers=500)

    print("\n🔧 Step 4: Converting from FITS LAMOST spectra...")
    #process_lamost_fits_files(folder_path=spectra_folder, output_file="Pickles/lamost_data.csv")
    process_lamost_fits_files(folder_path="lamost_spectra_uniques", 
                          output_file="Pickles/lamost_data.csv", 
                          matched_obsids=obsids)

    print("\n📊 Step 5: Extracting and saving flux & frequency values...")
    extract_flux_frequency_from_csv(csv_path="Pickles/lamost_data.csv")

    print("\n📊 Step 6: Interpolating and normalizing LAMOST spectra...")
    nan_files = interpolate_spectrum("Pickles/flux_values.pkl", "Pickles/freq_values.pkl", "Pickles/lamost_data_interpolated.pkl")
    spectrum_interpolated = pd.read_pickle("Pickles/lamost_data_interpolated.pkl")
    spectrum_normalized = normalize_lamost_spectra(spectrum_interpolated)

    if spectrum_normalized.empty:
        print("⚠️ No processed LAMOST spectra found. Exiting.")
        return None

    print("\n📊 Step 7: Normalizing Gaia data...")
    with open(gaia_transformer_path, "rb") as f:
        gaia_transformers = pickle.load(f)   # Dict of {col_name: fitted PowerTransformer}
    gaia_normalized = apply_gaia_transforms(df_gaia, gaia_transformers)

    print("\n🔗 Step 8: Merging Gaia and LAMOST data...")
    gaia_lamost_match = df_matched[["source_id", "obsid"]]
    spectrum_normalized["obsid"] = spectrum_normalized["obsid"].astype(int)
    gaia_lamost_match["obsid"] = gaia_lamost_match["obsid"].astype(int)

    # Identify and remove all obsid values that appear more than once
    obsid_counts = gaia_lamost_match["obsid"].value_counts()
    unique_obsids = obsid_counts[obsid_counts == 1].index  # Keep only obsid values that appear once

    # Filter dataset to keep only unique obsid values
    gaia_lamost_match = gaia_lamost_match[gaia_lamost_match["obsid"].isin(unique_obsids)]

    # Now, map the cleaned obsid-to-source_id mapping
    spectrum_normalized["source_id"] = spectrum_normalized["obsid"].astype(int).map(
        gaia_lamost_match.set_index("obsid")["source_id"]
    )

    # Merge Gaia and LAMOST data
    gaia_lamost_merged = pd.merge(gaia_normalized, spectrum_normalized, on="source_id", how="inner")

    if gaia_lamost_merged.empty:
        print("⚠️ No valid data after merging. Exiting.")
        return None
    
    print("\n🤖 Step 9: Predicting labels using the trained model...")
    predictions = process_star_data_fusion(model_path, gaia_lamost_merged, "Pickles/Updated_List_of_Classes_ubuntu.pkl", sigmoid_constant=0.5)

    print("\n💾 Step 10: Saving predictions...")
    df_predictions = pd.DataFrame(predictions, columns=pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl"))
    df_predictions["source_id"] = gaia_lamost_merged["source_id"].values

    return df_predictions, gaia_lamost_merged

def process_lamost_fits_files(folder_path="lamost_spectra_uniques", 
                             output_file="Pickles/lamost_data.csv", 
                             batch_size=10000, 
                             matched_obsids=None):
    """
    Processes LAMOST FITS spectra by extracting both flux and wavelength data.
    Handles both regular and gzipped FITS files with efficient error handling.
    """
    print("\n📂 Processing LAMOST FITS files...")
    
    # Ensure output directory exists
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    
    # Define column headers
    columns = [f'col_{i}' for i in range(3748)] + ['file_name', 'row']
    
    # Initialize the CSV file with headers
    with open(output_file, 'w') as f:
        pd.DataFrame(columns=columns).to_csv(f, index=False)
    
    # Get list of FITS files in the folder
    if not os.path.exists(folder_path):
        print(f"⚠️ Folder {folder_path} does not exist. Creating it.")
        os.makedirs(folder_path)
        return 0
    
    all_files = os.listdir(folder_path)
    
    # Filter files based on matched_obsids
    if matched_obsids is not None:
        matched_obsids = set(str(obsid) for obsid in matched_obsids)
        all_files = [f for f in all_files if str(f).split('.')[0] in matched_obsids]
    
    total_files = len(all_files)
    if total_files == 0:
        print("⚠️ No matching FITS files found for processing.")
        return 0
    
    batch_list = []
    processed_count = 0
    error_count = 0
    
    # Sample a few files first to understand their structure
    sample_files = all_files[:min(3, len(all_files))]
    print("Examining sample files to determine FITS structure...")
    field_names = set()
    
    for filename in sample_files:
        file_path = os.path.join(folder_path, filename)
        try:
            hdul = open_fits_file(file_path)
            if hdul:
                with hdul:
                    for hdu in hdul:
                        if hasattr(hdu, 'data') and hdu.data is not None:
                            if isinstance(hdu.data, fits.fitsrec.FITS_rec) and len(hdu.data) > 0:
                                print(f"Found FITS_rec in {filename}, names: {hdu.data.dtype.names}")
                                field_names.update(hdu.data.dtype.names)
        except Exception as e:
            print(f"Error examining {filename}: {str(e)}")
    
    if field_names:
        print(f"Found field names across sample files: {field_names}")
    
    # Process all files
    with tqdm(total=total_files, desc='Processing FITS files') as pbar:
        for filename in all_files:
            file_path = os.path.join(folder_path, filename)
            
            try:
                # Skip very small files
                if os.path.getsize(file_path) < 100:
                    print(f"⚠️ Skipping {filename}: File too small")
                    error_count += 1
                    pbar.update(1)
                    continue
                
                hdul = open_fits_file(file_path)
                if not hdul:
                    error_count += 1
                    pbar.update(1)
                    continue
                
                with hdul:
                    # Extract both flux and wavelength data
                    flux_data, wave_data = extract_flux_and_wavelength(hdul, filename)
                    
                    # Process flux data if available
                    if flux_data is not None and len(flux_data) > 0:
                        # Ensure correct length
                        if len(flux_data) < 3748:
                            flux_data = np.pad(flux_data, (0, 3748-len(flux_data)), 'constant')
                        elif len(flux_data) > 3748:
                            flux_data = flux_data[:3748]
                        
                        # Create flux dictionary
                        flux_dict = {
                            f'col_{j}': float(value) if not np.isnan(value) else 0.0 
                            for j, value in enumerate(flux_data)
                        }
                        flux_dict['file_name'] = filename
                        flux_dict['row'] = 0  # Row 0 for flux data
                        batch_list.append(flux_dict)
                    
                    # Process wavelength data if available
                    if wave_data is not None and len(wave_data) > 0:
                        # Ensure correct length
                        if len(wave_data) < 3748:
                            wave_data = np.pad(wave_data, (0, 3748-len(wave_data)), 'constant')
                        elif len(wave_data) > 3748:
                            wave_data = wave_data[:3748]
                        
                        # Create wavelength dictionary
                        wave_dict = {
                            f'col_{j}': float(value) if not np.isnan(value) else 0.0 
                            for j, value in enumerate(wave_data)
                        }
                        wave_dict['file_name'] = filename
                        wave_dict['row'] = 2  # Row 2 for wavelength data
                        batch_list.append(wave_dict)
                    
                    # If we got both flux and wavelength, count as processed
                    if (flux_data is not None and len(flux_data) > 0) or (wave_data is not None and len(wave_data) > 0):
                        processed_count += 1
                    else:
                        error_count += 1
                        print(f"⚠️ Could not extract data from {filename}")
                
                # Write batch to CSV when it gets large enough
                if len(batch_list) >= batch_size:
                    pd.DataFrame(batch_list).to_csv(output_file, mode='a', header=False, index=False)
                    batch_list.clear()
            
            except Exception as e:
                error_count += 1
                print(f"⚠️ Error processing {filename}: {str(e)}")
            
            pbar.update(1)
        
        # Write any remaining data
        if batch_list:
            pd.DataFrame(batch_list).to_csv(output_file, mode='a', header=False, index=False)
    
    print(f"✅ Successfully processed {processed_count} files")
    print(f"⚠️ Encountered errors in {error_count} files")
    
    return processed_count


def open_fits_file(file_path):
    """
    Opens a FITS file, handling both regular and gzipped formats.
    """
    try:
        # Check if the file is gzipped
        with open(file_path, 'rb') as f:
            file_start = f.read(2)
            f.seek(0)  # Reset file pointer
            
            if file_start == b'\x1f\x8b':  # gzip magic number
                # Handle gzipped file
                with gzip.GzipFile(fileobj=f) as gz_f:
                    file_content = gz_f.read()
                return fits.open(io.BytesIO(file_content), ignore_missing_simple=True)
            else:
                # Handle regular file
                return fits.open(file_path, ignore_missing_simple=True)
    except Exception as e:
        print(f"Error opening file {os.path.basename(file_path)}: {str(e)}")
        return None


def extract_flux_and_wavelength(hdul, filename):
    """
    Extracts both flux and wavelength data from a FITS file.
    
    Args:
        hdul (HDUList): The FITS file HDUList
        filename (str): Filename for error reporting
        
    Returns:
        tuple: (flux_data, wavelength_data) - Both are arrays or None if extraction fails
    """
    flux_data = None
    wave_data = None
    
    try:
        # Try each HDU until we find data
        for hdu in hdul:
            if not hasattr(hdu, 'data') or hdu.data is None:
                continue
                
            # Handle FITS_rec (record array) format
            if isinstance(hdu.data, fits.fitsrec.FITS_rec):
                if len(hdu.data) == 0:
                    continue
                    
                # Get field names
                field_names = hdu.data.dtype.names
                
                # Look for flux and wavelength fields
                flux_field = None
                wave_field = None
                
                # Search for field names (case-insensitive)
                for field in field_names:
                    field_lower = field.lower()
                    if 'flux' in field_lower or 'spectrum' in field_lower or 'spec' in field_lower:
                        flux_field = field
                    elif 'wave' in field_lower or 'lambda' in field_lower or 'wavelength' in field_lower:
                        wave_field = field
                
                # Extract flux data
                if flux_field and flux_data is None:
                    try:
                        flux_array = hdu.data[0][flux_field]
                        if isinstance(flux_array, np.ndarray):
                            flux_data = flux_array
                        else:
                            flux_data = np.array([flux_array])
                        #print(f"Found flux data in {filename}, field: {flux_field}")
                    except Exception as e:
                        print(f"Error extracting flux from field {flux_field}: {str(e)}")
                
                # Extract wavelength data
                if wave_field and wave_data is None:
                    try:
                        wave_array = hdu.data[0][wave_field]
                        if isinstance(wave_array, np.ndarray):
                            wave_data = wave_array
                        else:
                            wave_data = np.array([wave_array])
                        #print(f"Found wavelength data in {filename}, field: {wave_field}")
                    except Exception as e:
                        print(f"Error extracting wavelength from field {wave_field}: {str(e)}")
            
            # Handle standard array data
            elif isinstance(hdu.data, np.ndarray):
                # In standard LAMOST FITS, HDU 0 often contains the flux, HDU 2 contains wavelength
                hdu_index = hdul.index(hdu)
                
                if flux_data is None and hdu.data.size > 0:
                    # Take the first row if there are multiple
                    if hdu.data.ndim > 1:
                        flux_data = hdu.data[0]
                    else:
                        flux_data = hdu.data
                    print(f"Found flux data in {filename}, HDU: {hdu_index}")
                
                # Try to find wavelength data in a different HDU if not already found
                if wave_data is None and hdu_index > 0:
                    # Check if this might be wavelength data
                    if hdu.data.size > 0:
                        # Take the first row if there are multiple
                        if hdu.data.ndim > 1:
                            wave_data = hdu.data[0]
                        else:
                            wave_data = hdu.data
                        print(f"Found potential wavelength data in {filename}, HDU: {hdu_index}")
            
            # If we have both flux and wavelength, we can stop
            if flux_data is not None and wave_data is not None:
                break
        
        # If we found data, convert to the right format
        if flux_data is not None and not isinstance(flux_data, np.ndarray):
            flux_data = np.array(flux_data)
        if wave_data is not None and not isinstance(wave_data, np.ndarray):
            wave_data = np.array(wave_data)
        
        return flux_data, wave_data
        
    except Exception as e:
        print(f"Error extracting data from {filename}: {str(e)}")
        return None, None


def extract_flux_frequency_from_csv(csv_path="Pickles/lamost_data.csv", 
                                   flux_pickle="Pickles/flux_values.pkl", 
                                   freq_pickle="Pickles/freq_values.pkl", 
                                   chunk_size=10000):
    """
    Extracts flux and frequency (wavelength) data from a CSV file and saves them as separate pickle files.
    """
    print("\n📊 Extracting flux and frequency values...")

    # Initialize empty dataframes
    flux_values = pd.DataFrame()
    freq_values = pd.DataFrame()

    try:
        # Count total rows for progress tracking
        total_rows = sum(1 for _ in open(csv_path)) - 1  # Subtract header row
        
        # Process the CSV in chunks
        for chunk in tqdm(pd.read_csv(csv_path, chunksize=chunk_size), 
                         total=total_rows // chunk_size + 1):
            # If 'row' column is present, use it to filter rows
            if 'row' in chunk.columns:
                flux_chunk = chunk[chunk['row'] == 0].drop(columns=['row'])
                wave_chunk = chunk[chunk['row'] == 2].drop(columns=['row'])
            else:
                # For old format data
                flux_mask = chunk.index % 3 == 0  # First row for each file
                flux_chunk = chunk[flux_mask]
                
                wave_mask = chunk.index % 3 == 2  # Third row for each file
                wave_chunk = chunk[wave_mask]
            
            # Concatenate with existing data
            flux_values = pd.concat([flux_values, flux_chunk])
            freq_values = pd.concat([freq_values, wave_chunk])
        
        # Ensure filenames match between flux and frequency
        if not freq_values.empty and not flux_values.empty:
            common_files = set(flux_values['file_name']).intersection(set(freq_values['file_name']))
            flux_values = flux_values[flux_values['file_name'].isin(common_files)]
            freq_values = freq_values[freq_values['file_name'].isin(common_files)]
        
        print(f"✅ Flux values shape: {flux_values.shape}, Frequency values shape: {freq_values.shape}")
        
        # If no wavelength data was found, generate synthetic wavelength data based on LAMOST standard wavelength range
        if freq_values.empty and not flux_values.empty:
            print("⚠️ No wavelength data found, generating synthetic wavelength range")
            # Create synthetic wavelength values (3800-9000 Å is typical LAMOST range)
            wavelength_range = np.linspace(3800, 9000, 3748)
            
            # Create wavelength entries for each file in flux_values
            wavelength_rows = []
            for filename in flux_values['file_name'].unique():
                wave_dict = {f'col_{j}': float(value) for j, value in enumerate(wavelength_range)}
                wave_dict['file_name'] = filename
                wavelength_rows.append(wave_dict)
            
            freq_values = pd.DataFrame(wavelength_rows)
            print(f"✅ Created synthetic wavelength data with shape: {freq_values.shape}")
    
    except Exception as e:
        print(f"⚠️ Error processing CSV: {str(e)}")
        return 0, 0
    
    # Save extracted values
    os.makedirs(os.path.dirname(flux_pickle), exist_ok=True)
    os.makedirs(os.path.dirname(freq_pickle), exist_ok=True)
    
    flux_values.to_pickle(flux_pickle)
    freq_values.to_pickle(freq_pickle)
    
    return flux_values.shape[0], freq_values.shape[0]

def query_gaia_data(gaia_id_list):
    """
    Given a list of Gaia DR3 source IDs, queries the Gaia archive
    for the relevant columns used during training.
    Returns a concatenated DataFrame of results.
    """
    # Columns you actually needed
    desired_cols = [
        "source_id", "ra", "ra_error", "dec", "dec_error",
        "pmra", "pmra_error", "pmdec", "pmdec_error",
        "parallax", "parallax_error",
        "phot_g_mean_flux", "phot_g_mean_flux_error",
        "phot_bp_mean_flux", "phot_bp_mean_flux_error",
        "phot_rp_mean_flux", "phot_rp_mean_flux_error"
    ]

    all_dfs = []
    chunks = split_ids_into_chunks(gaia_id_list, chunk_size=30000)
    for chunk in chunks:
        query = f"""
        SELECT {', '.join(desired_cols)}
        FROM gaiadr3.gaia_source
        WHERE source_id IN ({chunk})
        """
        job = Gaia.launch_job_async(query)
        tbl = job.get_results()
        df_tmp = tbl.to_pandas()
        all_dfs.append(df_tmp)

    # Convert string IDs to integers
    if isinstance(gaia_id_list[0], str):
        gaia_id_list = [int(x) for x in gaia_id_list]
    
    # Check for missing IDs
    all_ids = pd.concat(all_dfs)["source_id"].values
    missing_ids = set(gaia_id_list) - set(all_ids)
    if missing_ids:
        print(f"Warning: {len(missing_ids)} IDs not found in Gaia DR3.")
        print(f"Missing IDs: {missing_ids}")

    if not all_dfs:
        return pd.DataFrame(columns=desired_cols)
    else:
        return pd.concat(all_dfs, ignore_index=True)
    
def query_gaia_data(gaia_id_list, desired_cols=None, chunk_size=30000, MAX_RETRIES=5, RETRY_DELAY_SECONDS=5):
    """
    Given a list of Gaia DR3 source IDs, queries the Gaia archive
    for the relevant columns. Includes a retry mechanism for transient errors.
    Returns a concatenated DataFrame of results.
    """
    if desired_cols is None:
        desired_cols = [
            "source_id", "ra", "ra_error", "dec", "dec_error",
            "pmra", "pmra_error", "pmdec", "pmdec_error",
            "parallax", "parallax_error",
            "phot_g_mean_flux", "phot_g_mean_flux_error",
            "phot_bp_mean_flux", "phot_bp_mean_flux_error",
            "phot_rp_mean_flux", "phot_rp_mean_flux_error"
        ]

    # Convert input IDs to strings for consistency
    gaia_id_list_str = [str(x) for x in gaia_id_list]
    num_total_ids = len(gaia_id_list_str)

    chunks_sql = split_ids_into_chunks(gaia_id_list_str, chunk_size=chunk_size)
    num_chunks = len(chunks_sql)
    print(f"Gaia query split into {num_chunks} chunks of up to {chunk_size} IDs.")

    all_dfs = []
    failed_chunks = 0

    for i, chunk_sql in enumerate(chunks_sql):
        print(f"\nProcessing Chunk {i+1}/{num_chunks}...")
        df_tmp = None # Initialize chunk result to None

        for attempt in range(MAX_RETRIES):
            try:
                print(f"  Attempt {attempt + 1}/{MAX_RETRIES}: Launching Gaia job...")
                query = f"""
                SELECT {', '.join(desired_cols)}
                FROM gaiadr3.gaia_source
                WHERE source_id IN ({chunk_sql})
                """
                # Set a timeout for the job launch/get_results (optional, tune as needed)
                # Gaia.TIMEOUT = 60 # Example: 60 seconds timeout for HTTP requests
                # Gaia.POLL_INTERVAL = 10 # Example: Check job status every 10s

                job = Gaia.launch_job_async(query)
                print(f"  Job launched (ID: {job.jobid}). Waiting for results...")
                tbl = job.get_results() # This can also time out or fail
                print(f"  Results received for chunk {i+1}.")
                df_tmp = tbl.to_pandas()
                print(f"  Retrieved {len(df_tmp)} records for this chunk.")
                # Success! Break the retry loop for this chunk
                break

            # --- Catch Specific Exceptions ---
            # Catch the error you observed
            except ConnectionResetError as e:
                print(f"  Attempt {attempt + 1} failed: ConnectionResetError - {e}")
            # Catch common network/request errors from astroquery/requests
            except RequestException as e:
                print(f"  Attempt {attempt + 1} failed: RequestException - {e}")
            # Catch potential socket timeouts
            #except SocketTimeoutError as e:
            #     print(f"  Attempt {attempt + 1} failed: SocketTimeoutError - {e}")
            # Catch potential errors during job execution reported by Gaia TAP service
            except Exception as e:
                 # Check if it's a TAP service error indicating job failure
                 if "job execution failed" in str(e).lower() or "error executing query" in str(e).lower():
                     print(f"  Attempt {attempt + 1} failed: Gaia TAP execution error - {e}")
                     # Often retrying these won't help if the query is bad, but sometimes it's transient
                 else:
                     print(f"  Attempt {attempt + 1} failed with unexpected error: {type(e).__name__} - {e}")
                 # Decide if you want to retry unexpected errors or just fail the chunk
                 # For now, we'll retry them too.

            # --- Retry Logic ---
            if attempt < MAX_RETRIES - 1:
                print(f"  Retrying in {RETRY_DELAY_SECONDS} seconds...")
                time.sleep(RETRY_DELAY_SECONDS)
            else:
                print(f"  Max retries reached for chunk {i+1}. Skipping this chunk.")
                failed_chunks += 1

        # End of retry loop for one chunk

        # Append the result if successful
        if df_tmp is not None and not df_tmp.empty:
            all_dfs.append(df_tmp)

        # Optional short sleep between chunks to be nice to the server
        # Even if a chunk failed, wait before starting the next one
        time.sleep(2)

    # --- Consolidate Results ---
    print("\nConsolidating results...")
    if failed_chunks > 0:
        print(f"Warning: Failed to retrieve data for {failed_chunks} out of {num_chunks} chunks after retries.")

    if not all_dfs:
         print("No data was successfully retrieved from Gaia.")
         return pd.DataFrame(columns=desired_cols) # Return empty DataFrame

    final_df = pd.concat(all_dfs, ignore_index=True)
    print(f"Consolidated DataFrame shape: {final_df.shape}")

    # --- Final Checks (same as before) ---
    if gaia_id_list and isinstance(gaia_id_list[0], int):
         try:
             final_df['source_id'] = final_df['source_id'].astype(int)
             input_ids = set(gaia_id_list)
         except ValueError:
             print("Warning: Could not convert all retrieved source_ids to int.")
             input_ids = set(gaia_id_list_str) # Compare as strings
    else:
         input_ids = set(gaia_id_list_str) # Compare as strings

    retrieved_ids_final = set(final_df['source_id'].astype(str).tolist())
    missing_ids = input_ids - retrieved_ids_final
    if missing_ids:
        # This count includes IDs that might have been in failed chunks
        print(f"Warning: {len(missing_ids)} out of {num_total_ids} requested IDs were not found in the final results (includes skipped chunks).")
    else:
        print("All requested IDs that exist in Gaia DR3 and were in successful chunks seem to be retrieved.")

    return final_df
    
     
def split_ids_into_chunks(gaia_id_list, chunk_size=50000):
    """
    Takes a Python list of Gaia IDs (strings or ints),
    returns a list of comma-joined strings, each containing up to `chunk_size` IDs.
    """
    # Convert everything to string for the SQL query
    gaia_id_list = [str(x) for x in gaia_id_list]
    chunks = []
    for i in range(0, len(gaia_id_list), chunk_size):
        chunk = ", ".join(gaia_id_list[i:i+chunk_size])
        chunks.append(chunk)
    return chunks


def crossmatch_lamost(gaia_df, lamost_df, match_radius=3*u.arcsec):
    """
    Cross-matches Gaia sources with a local LAMOST catalogue.
    Returns a merged DataFrame of matched objects.
    """

    # Ensure RA/Dec are numeric
    gaia_df['ra'] = pd.to_numeric(gaia_df['ra'], errors='coerce')
    gaia_df['dec'] = pd.to_numeric(gaia_df['dec'], errors='coerce')
    lamost_df['ra'] = pd.to_numeric(lamost_df['ra'], errors='coerce')
    lamost_df['dec'] = pd.to_numeric(lamost_df['dec'], errors='coerce')

    # Drop NaN values
    gaia_df = gaia_df.dropna(subset=['ra', 'dec'])
    lamost_df = lamost_df.dropna(subset=['ra', 'dec'])

    print(f"After NaN removal: Gaia={gaia_df.shape}, LAMOST={lamost_df.shape}")

    # Check if LAMOST coordinates are in arcseconds (convert if necessary)
    if lamost_df['ra'].max() > 360:  # RA should not exceed 360 degrees
        print("⚠️ LAMOST RA/Dec seem to be in arcseconds. Converting to degrees.")
        lamost_df['ra'] /= 3600
        lamost_df['dec'] /= 3600

    # Convert to SkyCoord objects (ensuring same frame)
    gaia_coords = SkyCoord(ra=gaia_df['ra'].values*u.deg,
                           dec=gaia_df['dec'].values*u.deg,
                           frame='icrs')

    lamost_coords = SkyCoord(ra=lamost_df['ra'].values*u.deg,
                             dec=lamost_df['dec'].values*u.deg,
                             frame='icrs')

    # Perform crossmatch
    idx, d2d, _ = gaia_coords.match_to_catalog_sky(lamost_coords)

    # Apply matching radius filter
    matches = d2d < match_radius
    #print(f"Match distances (arcsec): {d2d.to(u.arcsec).value[matches]}")

    if matches.sum() == 0:
        print("⚠️ No matches found! Try increasing `match_radius`.")
        return pd.DataFrame()

    # Extract matched rows correctly
    gaia_matched = gaia_df.iloc[matches].copy().reset_index(drop=True)
    lamost_matched = lamost_df.iloc[idx[matches]].copy().reset_index(drop=True)

    print(f"Matched Gaia Objects: {gaia_matched.shape}")
    print(f"Matched LAMOST Objects: {lamost_matched.shape}")

    # Merge matches into final DataFrame
    final = pd.concat([gaia_matched, lamost_matched], axis=1)

    return final



def download_lamost_spectra(obsid_list, save_folder="lamost_spectra_uniques", num_workers=10):
    """
    Downloads LAMOST spectra by obsid in parallel with robust error handling
    and rate limiting to avoid overwhelming the server.
    
    Args:
        obsid_list (list): List of obsids to download
        save_folder (str): Folder where spectra will be saved
        num_workers (int): Number of parallel download threads
        
    Returns:
        list: List of successfully downloaded obsids
    """
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    
    # Check which files are already downloaded
    existing_files = set(os.listdir(save_folder))
    obsid_list = [obsid for obsid in obsid_list if str(obsid) not in existing_files]
    
    if not obsid_list:
        print("✅ All spectra are already downloaded.")
        return []
    
    print(f"📂 {len(obsid_list)} new spectra will be downloaded.")
    
    # Create a requests Session with conservative retry settings
    retries = Retry(
        total=5, 
        backoff_factor=2,  # Exponential backoff
        status_forcelist=[429, 500, 502, 503, 504],
        respect_retry_after_header=True
    )
    session = requests.Session()
    session.mount("https://", HTTPAdapter(max_retries=retries))
    
    # Split the obsid_list into smaller chunks to download in batches
    chunk_size = 100  # Download in smaller batches
    obsid_chunks = [obsid_list[i:i + chunk_size] for i in range(0, len(obsid_list), chunk_size)]
    
    downloaded_obsids = []
    
    for chunk_idx, chunk in enumerate(obsid_chunks):
        print(f"Processing batch {chunk_idx+1}/{len(obsid_chunks)} ({len(chunk)} files)")
        
        # Use ThreadPoolExecutor to download in parallel
        results = []
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_obsid = {
                executor.submit(download_one_spectrum, obsid, session, save_folder): obsid 
                for obsid in chunk
            }

            # Wrap with tqdm for progress bar
            for future in tqdm(as_completed(future_to_obsid), total=len(future_to_obsid), desc="Downloading Spectra"):
                obsid = future_to_obsid[future]
                try:
                    obsid, success, error_msg = future.result()
                    results.append((obsid, success, error_msg))
                except Exception as e:
                    results.append((obsid, False, str(e)))

        # Print any failures
        failures = [r for r in results if not r[1]]
        if failures:
            print(f"❌ Failed to download {len(failures)} spectra in this batch.")
        
        # Add successful downloads to the list
        batch_downloaded = [r[0] for r in results if r[1]]
        downloaded_obsids.extend(batch_downloaded)
        
        # Rest between batches to avoid overwhelming the server
        if chunk_idx < len(obsid_chunks) - 1:
            print(f"Sleeping between batches to avoid server overload...")
            time.sleep(2)  # Sleep 2 seconds between batches
    
    print(f"✅ Successfully downloaded {len(downloaded_obsids)} out of {len(obsid_list)} spectra.")
    return downloaded_obsids


def interpolate_spectrum(fluxes_loc, frequencies_loc, output_dir, limit=10, edge_limit=20):
    """Interpolates the flux values to fill in missing data points."""
    # Load the data from the pickle file    
    df_freq = pd.read_pickle(frequencies_loc).reset_index(drop=True)      
    df_flux = pd.read_pickle(fluxes_loc).reset_index(drop=True)  # Reset index for zero-based iteration

    # Initialize an empty list to store the results before concatenating into a DataFrame
    results_list = []

    # Initialize lists to store problematic file_names
    nan_files = []  

    # Count the number of successful interpolations
    cnt_success = 0

    # Debugging counters
    cnt_total_skipped = 0
    cnt_nan_skipped = 0
    cnt_zero_skipped = 0

    # Overwrite the output file at the beginning
    if os.path.exists(output_dir):
        os.remove(output_dir)

    # Loop through each row in the DataFrame (each row is a spectrum) with tqdm for progress bar
    for index, row in tqdm(df_flux.iterrows(), total=len(df_flux), desc='Interpolating spectra'):

        # Extract the fluxes (assuming they start at column 0 and continue to the last column)
        fluxes = row[:-2].values  # Exclude the last columns (file_name, label)

        # Extract the frequencies
        frequencies = df_freq.iloc[int(index), :-2].values  # Exclude the last columns (file_name, label)

        # Count the number of NaN and 0 values in the fluxes and frequencies
        fluxes = pd.to_numeric(row[:-2], errors='coerce').values  # Exclude and convert to numeric
        frequencies = pd.to_numeric(df_freq.iloc[index, :-2], errors='coerce').values  # Same for frequencies
        num_nan = np.isnan(fluxes).sum() + np.isnan(frequencies).sum()  # Count NaN values
        num_zero = (fluxes == 0).sum() + (frequencies == 0).sum()  # Count zero values
        num_freq_nan = np.isnan(frequencies).sum() + (frequencies == 0).sum()
        if num_freq_nan > 0:
            print(f"Number of NaN or zero frequency values: {num_freq_nan}")
        # Special handling for NaN values, counting nans in sequence, except for the first and last 10
        if num_nan > limit and index > edge_limit and index < len(fluxes)-edge_limit:
            cnt_nan_skipped += 1  # Debug: count NaN-skipped rows
            nan_files.append(row['file_name'])
            continue
        
        if num_zero > limit and index > edge_limit and index < len(fluxes)-edge_limit:
            cnt_zero_skipped += 1  # Debug: count zero-skipped rows
            nan_files.append(row['file_name'])
            continue

        # Deal with NaN values
        fluxes = fluxes[~np.isnan(fluxes)]
        frequencies = frequencies[~np.isnan(fluxes)]

        # Interpolate to fill in missing values
        f = interp1d(frequencies, fluxes, kind='linear', fill_value="extrapolate")
        new_frequencies = np.linspace(frequencies.min(), frequencies.max(), len(row[:-2].values))

        # Interpolated flux values
        interpolated_fluxes = f(new_frequencies)

        # Store the interpolated data along with labels and other metadata
        # Create a dictionary for the interpolated spectrum
        interpolated_data = {f'flux_{i}': value for i, value in enumerate(interpolated_fluxes)}

        # Add the original metadata back (e.g., file_name, label, row)
        interpolated_data['file_name'] = row['file_name']
                
        # Append the interpolated data to the results list
        results_list.append(interpolated_data)

        if index % 2000 == 0:  # Save every 5000 rows
            if os.path.exists(output_dir):
                existing_df = pd.read_pickle(output_dir)  # Load existing data
                new_df = pd.DataFrame(results_list)
                # Concatenate existing and new data
                combined_df = pd.concat([existing_df, new_df], ignore_index=True)
                combined_df.to_pickle(output_dir)  # Save combined DataFrame
            else:
                # If the file doesn't exist, create a new DataFrame and save
                pd.DataFrame(results_list).to_pickle(output_dir)
            cnt_success += len(results_list)  # Increment the count of successful interpolations
            results_list = []  # Clear list to free memory

    print(f"Initial number of rows: {len(df_flux)}")

    # After the loop, save any remaining results
    if results_list:
        if os.path.exists(output_dir):
            existing_df = pd.read_pickle(output_dir)
            new_df = pd.DataFrame(results_list)
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
            combined_df.to_pickle(output_dir)
        else:
            pd.DataFrame(results_list).to_pickle(output_dir)
        cnt_success += len(results_list)

    # Debugging information
    cnt_total_skipped = len(nan_files)
    print(f"Total successful interpolations: {cnt_success}")
    #print(f"Total skipped due to NaNs: {cnt_nan_skipped}")
    #print(f"Total skipped due to zeros: {cnt_zero_skipped}")
    print(f"Total skipped rows (NaNs + zeros): {cnt_total_skipped}")
    print(f"Final check: len(df_flux) == cnt_success + len(nan_files)? {len(df_flux) == cnt_success + cnt_total_skipped}")

    return nan_files


def normalize_lamost_spectra(spectra_df):
    """
    Reads LAMOST FITS spectra, applies interpolation, normalization, and transformation.
    Returns a DataFrame of final spectral features (one row per spectrum).
    """

    
    spectra = spectra_df.iloc[:, 100:-1].values  # Exclude the last column (file_name)

    #print(f"Shape of the spectra array: {spectra.shape}")

    # Normalize the spectra between 0 and 1
    min_max_scaler = MinMaxScaler()
    spectra_normalized = min_max_scaler.fit_transform(spectra.T).T

    #print(f"Shape of the normalized spectra array: {spectra_normalized.shape}")

    # Apply the Yeo-Johnson transformation to the spectra
    pt = PowerTransformer(method='yeo-johnson', standardize=True)
    spectra_transformed = pt.fit_transform(spectra_normalized.T).T

    # Create a new DataFrame with the transformed spectra
    df_transformed = pd.DataFrame(spectra_transformed, columns=spectra_df.columns[100:-1]) # Exclude the first 100+3 columns and the last column

    #print(f"Shape of the transformed spectra array: {spectra_transformed.shape}")

    # Add the file_name column back to the DataFrame
    #print(f"Available columns in spectra_df: {spectra_df.columns}")
    df_transformed['obsid'] = spectra_df['file_name']

    return df_transformed


def apply_gaia_transforms(gaia_df, transformers_dict):
    """
    Applies the same Yeo-Johnson (or other) transformations used in training
    to the relevant Gaia columns. 
    """
    # Fill the same NaN values or set the same flags as in training
    # e.g. if you flagged parallax=NaN => set parallax=0, error=10
    # do that here too, to keep consistent with your training pipeline
    #
    # Example based on your code:
    gaia_df['flagnopllx'] = np.where(gaia_df['parallax'].isna(), 1, 0)
    gaia_df['parallax']       = gaia_df['parallax'].fillna(0)
    gaia_df['parallax_error'] = gaia_df['parallax_error'].fillna(10)
    gaia_df['pmra']           = gaia_df['pmra'].fillna(0)
    gaia_df['pmra_error']     = gaia_df['pmra_error'].fillna(10)
    gaia_df['pmdec']          = gaia_df['pmdec'].fillna(0)
    gaia_df['pmdec_error']    = gaia_df['pmdec_error'].fillna(10)

    gaia_df['flagnoflux'] = 0
    # If G or BP or RP is missing
    missing_flux = gaia_df['phot_g_mean_flux'].isna() | gaia_df['phot_bp_mean_flux'].isna() 
    gaia_df.loc[missing_flux, 'flagnoflux'] = 1

    # fill flux with 0 and error with large number
    gaia_df['phot_g_mean_flux']       = gaia_df['phot_g_mean_flux'].fillna(0)
    gaia_df['phot_g_mean_flux_error'] = gaia_df['phot_g_mean_flux_error'].fillna(50000)
    gaia_df['phot_bp_mean_flux']      = gaia_df['phot_bp_mean_flux'].fillna(0)
    gaia_df['phot_bp_mean_flux_error']= gaia_df['phot_bp_mean_flux_error'].fillna(50000)
    gaia_df['phot_rp_mean_flux']      = gaia_df['phot_rp_mean_flux'].fillna(0)
    gaia_df['phot_rp_mean_flux_error']= gaia_df['phot_rp_mean_flux_error'].fillna(50000)

    # Drop any rows that are incomplete, if that was your final approach:
    gaia_df.dropna(axis=0, inplace=True)
    print(f"Dropped {len(gaia_df) - len(gaia_df.dropna())} rows with NaN values.")

    # Remove source_id and other columns not to be transformed to be added back later
    source_id = gaia_df['source_id']
    gaia_df = gaia_df.drop(columns=["source_id"])

    # Now apply the stored transformations:
    for col, transformer in transformers_dict.items():
        if col in gaia_df.columns:
            #print(f"Transforming column: {col}")
            gaia_df[col] = transformer.transform(gaia_df[[col]])
            #print(f"Transformed column: {col}")
        else:
            # If the column didn't exist, maybe set to 0 or skip?
            print(f"Warning: column {col} not found in new data, skipping transform.")

    # Add back the source_id column
    gaia_df['source_id'] = source_id
    return gaia_df


def process_star_data_fusion(
    model_path, 
    X, 
    classes_path, 
    d_model_spectra=2048, 
    d_model_gaia=2048,
    num_classes=55, 
    input_dim_spectra=3647, 
    input_dim_gaia=18, 
    depth=20, 
    sigmoid_constant=0.5,
    class_to_plot="AllStars***lamost"
):
    """Processes star data using the fused StarClassifierFusion model."""

    # Load the data
    classes = pd.read_pickle(classes_path)

    # Load the trained fusion model
    model = StarClassifierFusion(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        n_layers=depth,
        use_cross_attention=True,  # Change to False for late fusion
        n_cross_attn_heads=8
    )

    # Load the state dictionary
    state_dict = torch.load(model_path, weights_only=False)
    model.load_state_dict(state_dict)

    # Get multi-hot encoded labels
    #y = X[classes]

    # Define Gaia columns
    gaia_columns = [
        "parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec",
        "pmra_error", "pmdec_error", "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error",
        "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_bp_mean_flux_error", "phot_rp_mean_flux_error",
        "flagnoflux"
    ]

    # Separate Gaia and Spectra features
    X_spectra = X.drop(columns={"obsid","source_id", *gaia_columns})
    X_gaia = X[gaia_columns]

    print(f"X_spectra shape: {X_spectra.shape}")
    print(f"X_gaia shape: {X_gaia.shape}")
    #print(f"y shape: {y.shape}")

    if class_to_plot != "AllStars***lamost":
        # Filter for a specific class
        X_spectra = X_spectra[y[class_to_plot] == 1]
        X_gaia = X_gaia[y[class_to_plot] == 1]
        #y = y[y[class_to_plot] == 1]

        print(f"X_spectra shape after filtering for {class_to_plot}: {X_spectra.shape}")
        print(f"X_gaia shape after filtering for {class_to_plot}: {X_gaia.shape}")
       # print(f"y shape after filtering for {class_to_plot}: {y.shape}")

    # Drop label columns from spectra
    #X_spectra.drop(classes, axis=1, inplace=True)

    # Convert to tensors
    X_spectra = torch.tensor(X_spectra.values, dtype=torch.float32)
    X_gaia = torch.tensor(X_gaia.values, dtype=torch.float32)
    #y = torch.tensor(y.values, dtype=torch.float32)

    # Create DataLoader
    class MultiModalDataset(Dataset):
        def __init__(self, X_spectra, X_gaia):
            self.X_spectra = X_spectra
            self.X_gaia = X_gaia
            #self.y = y

        def __len__(self):
            return len(self.X_spectra)

        def __getitem__(self, idx):
            return self.X_spectra[idx], self.X_gaia[idx]

    dataset = MultiModalDataset(X_spectra, X_gaia)
    loader = DataLoader(dataset, batch_size=128, shuffle=False)

    # Move model to device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()

    all_predicted = []
    all_y = []

    with torch.no_grad():
        for X_spc, X_ga in loader:
            # Move data to device
            X_spc, X_ga = X_spc.to(device), X_ga.to(device)

            # Forward pass
            outputs = model(X_spc, X_ga)
            predicted = (torch.sigmoid(outputs) > sigmoid_constant).float()

            # Store predictions and labels
            all_predicted.append(predicted.cpu().numpy())
            #all_y.append(y_batch.cpu().numpy())

            # Free GPU memory
            torch.cuda.empty_cache()

    # Concatenate all predictions and labels
    #y_cpu = np.concatenate(all_y, axis=0)
    predicted_cpu = np.concatenate(all_predicted, axis=0)

    return predicted_cpu


def custom_precision_score(y_true, y_pred, zero_division=0):
    """
    Calculates a custom precision-like metric: TP / (TP + FN')
    where FN' are False Negatives for which the model made at least one
    (wrong) prediction for that sample.

    Args:
        y_true (np.ndarray): Ground truth labels (binary, samples x classes).
        y_pred (np.ndarray): Predicted labels (binary, samples x classes).
        zero_division (int or float): Value to return if the denominator (TP + FN') is 0.

    Returns:
        np.ndarray: Array of custom precision scores for each class.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape.")
    if y_true.ndim != 2:
         raise ValueError("y_true and y_pred must be 2D arrays (samples x classes).")

    num_samples, num_classes = y_true.shape
    custom_precisions = np.zeros(num_classes, dtype=float)

    # Identify samples where the model made *any* prediction
    # Shape: (num_samples,) - boolean mask
    model_made_a_prediction_mask = np.sum(y_pred, axis=1) > 0

    for c in range(num_classes):
        # Identify True Positives for class c
        # Shape: (num_samples,) - boolean mask
        tp_mask = (y_true[:, c] == 1) & (y_pred[:, c] == 1)
        tp = np.sum(tp_mask)

        # Identify standard False Negatives for class c
        # Shape: (num_samples,) - boolean mask
        fn_mask = (y_true[:, c] == 1) & (y_pred[:, c] == 0)

        # Identify the specific FN' subset:
        # These are samples that are FN for class c AND for which the model
        # made *some* prediction (could be for any class).
        # Shape: (num_samples,) - boolean mask
        fn_prime_mask = fn_mask & model_made_a_prediction_mask
        fn_prime = np.sum(fn_prime_mask)

        # Calculate the custom metric for class c
        denominator = tp + fn_prime
        if denominator > 0:
            custom_precisions[c] = tp / denominator
        else:
            # Assign the zero_division value if denominator is 0
            custom_precisions[c] = float(zero_division)

    return custom_precisions

def custom_f1_score(y_true, y_pred, zero_division=0):
    """
    Calculates a custom F1 score based on the harmonic mean of
    custom_precision_score and standard recall_score.

    Args:
        y_true (np.ndarray): Ground truth labels (binary, samples x classes).
        y_pred (np.ndarray): Predicted labels (binary, samples x classes).
        zero_division (int or float): Value to return for F1 when both
                                     custom precision and recall are 0.

    Returns:
        np.ndarray: Array of custom F1 scores for each class.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape.")
    if y_true.ndim != 2:
         raise ValueError("y_true and y_pred must be 2D arrays (samples x classes).")

    # Calculate custom precision per class
    p_custom = custom_precision_score(y_true, y_pred, zero_division=zero_division)

    # Calculate standard recall per class
    # Note: Using the same zero_division behavior for consistency here
    r_standard = recall_score(y_true, y_pred, average=None, zero_division=zero_division)

    # Calculate custom F1 using vectorization, handling division by zero
    denominator = p_custom + r_standard
    f1_custom = np.zeros_like(denominator, dtype=float) # Initialize with zeros

    # Create a mask for where the denominator is non-zero
    valid_mask = denominator > 0

    # Calculate F1 only where the denominator is valid
    f1_custom[valid_mask] = 2 * (p_custom[valid_mask] * r_standard[valid_mask]) / denominator[valid_mask]

    # For cases where denominator is zero, F1 should be zero_division
    # (If zero_division is 0, the initial np.zeros_like already handles this)
    if zero_division != 0:
        zero_mask = ~valid_mask # Where denominator is zero
        f1_custom[zero_mask] = float(zero_division) # Assign zero_division value

    return f1_custom

def exact_match_ratio(y_true, y_pred):
    """
    Calculates the Exact Match Ratio (Subset Accuracy).

    This metric requires the set of predicted labels for a sample to
    exactly match the set of true labels for that sample.

    Args:
        y_true (np.ndarray): Ground truth labels (binary, samples x classes).
        y_pred (np.ndarray): Predicted labels (binary, samples x classes).

    Returns:
        float: The Exact Match Ratio (a single value between 0 and 1).
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape.")
    if y_true.ndim != 2:
        raise ValueError("y_true and y_pred must be 2D arrays (samples x classes).")

    # Check equality for each sample across all classes
    # .all(axis=1) returns True only if all columns match for a given row (sample)
    exact_matches_mask = (y_true == y_pred).all(axis=1)

    # Calculate the ratio (mean of a boolean array gives the proportion of True values)
    ratio = np.mean(exact_matches_mask)

    return ratio


import torch
import torch.nn as nn
from functools import partial

# Import the needed components from your MambaOut implementation
from timm.models.layers import DropPath

class GatedCNNBlock(nn.Module):
    """Adaptation of GatedCNNBlock for sequence data with dynamic kernel size adaptation"""
    def __init__(self, dim, d_state=256, d_conv=4, expand=2, drop_path=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(expand * dim)
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = nn.GELU()
        
        # Store these for dynamic convolution sizing
        self.d_conv = d_conv
        self.hidden = hidden
        
        self.fc2 = nn.Linear(hidden, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
        # Use simpler approach for sequence length 1 (common case)
        # This avoids dynamic convolution creation
        self.use_identity_for_length_1 = True
        
        # Cache for static convolution with kernel size 1 (for length 1 sequences)
        self.conv1 = nn.Conv1d(
            in_channels=hidden,
            out_channels=hidden, 
            kernel_size=1,
            padding=0,
            groups=hidden
        )

    def forward(self, x):
        # Input shape: [B, seq_len, dim]
        shortcut = x
        x = self.norm(x)
        
        # Split the channels for gating mechanism
        x = self.fc1(x)  # [B, seq_len, hidden*2]
        g, c = torch.chunk(x, 2, dim=-1)  # Each: [B, seq_len, hidden]
        
        # Get sequence length
        batch_size, seq_len, channels = c.shape
        
        # Apply gating mechanism
        c_permuted = c.permute(0, 2, 1)  # [B, hidden, seq_len]
        
        # Special case for sequence length 1 (most common)
        if seq_len == 1 and self.use_identity_for_length_1:
            # Use the pre-created kernel size 1 conv, which is like identity but keeps channels
            c_conv = self.conv1(c_permuted)
        else:
            # For other sequence lengths, fallback to kernel size 1 to avoid issues
            # The conv1 layer is already initialized and on the correct device
            c_conv = self.conv1(c_permuted)
        
        c_final = c_conv.permute(0, 2, 1)  # [B, seq_len, hidden]
        
        # Gating mechanism
        x = self.fc2(self.act(g) * c_final)  # [B, seq_len, dim]
        
        x = self.drop_path(x)
        return x + shortcut

class SequenceMambaOut(nn.Module):
    """Adaptation of MambaOut for sequence data with a single stage"""
    def __init__(self, d_model, d_state=256, d_conv=4, expand=2, depth=1, drop_path=0.):
        super().__init__()
        
        # Create a sequence of GatedCNNBlocks
        self.blocks = nn.Sequential(
            *[GatedCNNBlock(
                dim=d_model,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                drop_path=drop_path
            ) for _ in range(depth)]
        )
    
    def forward(self, x):
        return self.blocks(x)

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim, n_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=n_heads,
            batch_first=True
        )
    
    def forward(self, x, context):
        """
        x: (B, seq_len_x, dim)
        context: (B, seq_len_context, dim)
        """
        x_norm = self.norm(x)
        attn_output, _ = self.attention(
            query=x_norm,
            key=context,
            value=context
        )
        return x + attn_output
    
def process_star_data_fusion(
    model_path, 
    X, 
    classes_path, 
    d_model_spectra=2048, 
    d_model_gaia=2048,
    num_classes=55, 
    input_dim_spectra=3647, 
    input_dim_gaia=18, 
    n_layers=20, 
    sigmoid_constant=0.5,
    classifier_hidden_dim_multiplier=5,
    class_to_plot="AllStars***lamost"
):
    """Processes star data using the fused StarClassifierFusion model."""

    # Load the data
    classes = pd.read_pickle(classes_path)

    # Load the trained fusion model
    model = StarClassifierFusion(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        n_layers=n_layers,
        use_cross_attention=True,  # set to False to compare with late fusion
        n_cross_attn_heads=8,
        classifier_hidden_dim_multiplier=classifier_hidden_dim_multiplier,
        classifier_dropout=0.2,
    )

    # Load the state dictionary
    state_dict = torch.load(model_path, weights_only=False)
    model.load_state_dict(state_dict)

    # Get multi-hot encoded labels
    #y = X[classes]

    # Define Gaia columns
    gaia_columns = [
        "parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec",
        "pmra_error", "pmdec_error", "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error",
        "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_bp_mean_flux_error", "phot_rp_mean_flux_error",
        "flagnoflux"
    ]

    # Separate Gaia and Spectra features
    X_spectra = X.drop(columns={"obsid","source_id", *gaia_columns})
    X_gaia = X[gaia_columns]

    print(f"X_spectra shape: {X_spectra.shape}")
    print(f"X_gaia shape: {X_gaia.shape}")
    #print(f"y shape: {y.shape}")

    if class_to_plot != "AllStars***lamost":
        # Filter for a specific class
        X_spectra = X_spectra[y[class_to_plot] == 1]
        X_gaia = X_gaia[y[class_to_plot] == 1]
        #y = y[y[class_to_plot] == 1]

        print(f"X_spectra shape after filtering for {class_to_plot}: {X_spectra.shape}")
        print(f"X_gaia shape after filtering for {class_to_plot}: {X_gaia.shape}")
       # print(f"y shape after filtering for {class_to_plot}: {y.shape}")

    # Drop label columns from spectra
    #X_spectra.drop(classes, axis=1, inplace=True)

    # Convert to tensors
    X_spectra = torch.tensor(X_spectra.values, dtype=torch.float32)
    X_gaia = torch.tensor(X_gaia.values, dtype=torch.float32)
    #y = torch.tensor(y.values, dtype=torch.float32)

    # Create DataLoader
    class MultiModalDataset(Dataset):
        def __init__(self, X_spectra, X_gaia):
            self.X_spectra = X_spectra
            self.X_gaia = X_gaia
            #self.y = y

        def __len__(self):
            return len(self.X_spectra)

        def __getitem__(self, idx):
            return self.X_spectra[idx], self.X_gaia[idx]

    dataset = MultiModalDataset(X_spectra, X_gaia)
    loader = DataLoader(dataset, batch_size=128, shuffle=False)

    # Move model to device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()

    all_predicted = []
    all_y = []

    with torch.no_grad():
        for X_spc, X_ga in loader:
            # Move data to device
            X_spc, X_ga = X_spc.to(device), X_ga.to(device)

            # Forward pass
            outputs = model(X_spc, X_ga)
            predicted = (torch.sigmoid(outputs) > sigmoid_constant).float()

            # Store predictions and labels
            all_predicted.append(predicted.cpu().numpy())
            #all_y.append(y_batch.cpu().numpy())

            # Free GPU memory
            torch.cuda.empty_cache()

    # Concatenate all predictions and labels
    #y_cpu = np.concatenate(all_y, axis=0)
    predicted_cpu = np.concatenate(all_predicted, axis=0)

    return predicted_cpu

class StarClassifierFusion(nn.Module):
    def __init__(
        self,
        d_model_spectra = 2048,
        d_model_gaia = 2048,
        num_classes = 55,
        input_dim_spectra = 3647,
        input_dim_gaia = 18,
        n_layers=20,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state=256,
        d_conv=4,
        expand=2,
    ):
        """
        Args:
            d_model_spectra (int): embedding dimension for the spectra MAMBA
            d_model_gaia (int): embedding dimension for the gaia MAMBA
            num_classes (int): multi-label classification
            input_dim_spectra (int): # of features for spectra
            input_dim_gaia (int): # of features for gaia
            n_layers (int): depth for each MAMBA
            use_cross_attention (bool): whether to use cross-attention
            n_cross_attn_heads (int): number of heads for cross-attention
        """
        super().__init__()

        # --- MambaOut for spectra ---
        self.mamba_spectra = nn.Sequential(
            *[SequenceMambaOut(
                d_model=d_model_spectra,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                depth=1,  # Each SequenceMambaOut has depth 1
                drop_path=0.1 if i > 0 else 0.0,  # Optional: add some dropout for regularization
            ) for i in range(n_layers)]
        )
        self.input_proj_spectra = nn.Linear(input_dim_spectra, d_model_spectra)

        # --- MambaOut for gaia ---
        self.mamba_gaia = nn.Sequential(
            *[SequenceMambaOut(
                d_model=d_model_gaia,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                depth=1,  # Each SequenceMambaOut has depth 1
                drop_path=0.1 if i > 0 else 0.0,  # Optional: add some dropout for regularization
            ) for i in range(n_layers)]
        )
        self.input_proj_gaia = nn.Linear(input_dim_gaia, d_model_gaia)

        # --- Cross Attention (Optional) ---
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = CrossAttentionBlock(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = CrossAttentionBlock(d_model_gaia, n_heads=n_cross_attn_heads)

        # --- Final Classifier ---
        fusion_dim = d_model_spectra + d_model_gaia
        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Linear(fusion_dim, num_classes)
        )
    
    def forward(self, x_spectra, x_gaia):
        """
        x_spectra : (batch_size, input_dim_spectra) or (batch_size, seq_len_spectra, input_dim_spectra)
        x_gaia    : (batch_size, input_dim_gaia) or (batch_size, seq_len_gaia, input_dim_gaia)
        """
        # For MambaOut, we expect shape: (B, seq_len, d_model). 
        # If input is just (B, d_in), we turn it into (B, 1, d_in).
        
        # --- Project to d_model and add sequence dimension if needed ---
        if len(x_spectra.shape) == 2:
            x_spectra = self.input_proj_spectra(x_spectra)  # (B, d_model_spectra)
            x_spectra = x_spectra.unsqueeze(1)              # (B, 1, d_model_spectra)
        else:
            x_spectra = self.input_proj_spectra(x_spectra)  # (B, seq_len, d_model_spectra)
        
        if len(x_gaia.shape) == 2:
            x_gaia = self.input_proj_gaia(x_gaia)           # (B, d_model_gaia)
            x_gaia = x_gaia.unsqueeze(1)                    # (B, 1, d_model_gaia)
        else:
            x_gaia = self.input_proj_gaia(x_gaia)           # (B, seq_len, d_model_gaia)

        # --- MambaOut encoding (each modality separately) ---
        x_spectra = self.mamba_spectra(x_spectra)  # (B, seq_len, d_model_spectra)
        x_gaia = self.mamba_gaia(x_gaia)           # (B, seq_len, d_model_gaia)

        # Optionally, use cross-attention to fuse the representations
        if self.use_cross_attention:
            # Cross-attention from spectra -> gaia
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia)
            # Cross-attention from gaia -> spectra
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra)
            
            # Update x_spectra and x_gaia
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
        
        # --- Pool across sequence dimension ---
        x_spectra = x_spectra.mean(dim=1)  # (B, d_model_spectra)
        x_gaia = x_gaia.mean(dim=1)        # (B, d_model_gaia)

        # --- Late Fusion by Concatenation ---
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)  # (B, d_model_spectra + d_model_gaia)

        # --- Final classification ---
        logits = self.classifier(x_fused)  # (B, num_classes)
        return logits
    

class StarClassifierFusion(nn.Module):
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        n_layers=20,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state=256,
        d_conv=4,
        expand=2,
        classifier_hidden_dim_multiplier=5, # Multiplier for hidden dim in classifier MLP
        classifier_dropout=0.2,
    ):
        """
        Args:
            d_model_spectra (int): embedding dimension for the spectra MAMBA
            d_model_gaia (int): embedding dimension for the gaia MAMBA
            num_classes (int): multi-label classification
            input_dim_spectra (int): # of features for spectra
            input_dim_gaia (int): # of features for gaia
            n_layers (int): depth for each MAMBA
            use_cross_attention (bool): whether to use cross-attention
            n_cross_attn_heads (int): number of heads for cross-attention
        """
        super().__init__()

        # --- MambaOut for spectra ---
        self.mamba_spectra = nn.Sequential(
            *[SequenceMambaOut(
                d_model=d_model_spectra,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                depth=1,  # Each SequenceMambaOut has depth 1
                drop_path=0.1 if i > 0 else 0.0,  # Optional: add some dropout for regularization
            ) for i in range(n_layers)]
        )
        self.input_proj_spectra = nn.Linear(input_dim_spectra, d_model_spectra)

        # --- MambaOut for gaia ---
        self.mamba_gaia = nn.Sequential(
            *[SequenceMambaOut(
                d_model=d_model_gaia,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                depth=1,  # Each SequenceMambaOut has depth 1
                drop_path=0.1 if i > 0 else 0.0,  # Optional: add some dropout for regularization
            ) for i in range(n_layers)]
        )
        self.input_proj_gaia = nn.Linear(input_dim_gaia, d_model_gaia)

        # --- Cross Attention (Optional) ---
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = CrossAttentionBlock(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = CrossAttentionBlock(d_model_gaia, n_heads=n_cross_attn_heads)

        # --- Final Classifier (Improved MLP) ---
        fusion_dim = d_model_spectra + d_model_gaia
        classifier_hidden_dim = int(fusion_dim * classifier_hidden_dim_multiplier)

        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Linear(fusion_dim, classifier_hidden_dim),
            nn.GELU(), # Or ReLU, SiLU, etc.
            nn.Dropout(classifier_dropout),
            nn.Linear(classifier_hidden_dim, num_classes)
        )
    
    def forward(self, x_spectra, x_gaia):
        """
        x_spectra : (batch_size, input_dim_spectra) or (batch_size, seq_len_spectra, input_dim_spectra)
        x_gaia    : (batch_size, input_dim_gaia) or (batch_size, seq_len_gaia, input_dim_gaia)
        """
        # For MambaOut, we expect shape: (B, seq_len, d_model). 
        # If input is just (B, d_in), we turn it into (B, 1, d_in).
        
        # --- Project to d_model and add sequence dimension if needed ---
        if len(x_spectra.shape) == 2:
            x_spectra = self.input_proj_spectra(x_spectra)  # (B, d_model_spectra)
            x_spectra = x_spectra.unsqueeze(1)              # (B, 1, d_model_spectra)
        else:
            x_spectra = self.input_proj_spectra(x_spectra)  # (B, seq_len, d_model_spectra)
        
        if len(x_gaia.shape) == 2:
            x_gaia = self.input_proj_gaia(x_gaia)           # (B, d_model_gaia)
            x_gaia = x_gaia.unsqueeze(1)                    # (B, 1, d_model_gaia)
        else:
            x_gaia = self.input_proj_gaia(x_gaia)           # (B, seq_len, d_model_gaia)

        # --- MambaOut encoding (each modality separately) ---
        x_spectra = self.mamba_spectra(x_spectra)  # (B, seq_len, d_model_spectra)
        x_gaia = self.mamba_gaia(x_gaia)           # (B, seq_len, d_model_gaia)

        # Optionally, use cross-attention to fuse the representations
        if self.use_cross_attention:
            # Cross-attention from spectra -> gaia
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia)
            # Cross-attention from gaia -> spectra
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra)
            
            # Update x_spectra and x_gaia
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
        
        # --- Pool across sequence dimension ---
        x_spectra = x_spectra.mean(dim=1)  # (B, d_model_spectra)
        x_gaia = x_gaia.mean(dim=1)        # (B, d_model_gaia)

        # --- Late Fusion by Concatenation ---
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)  # (B, d_model_spectra + d_model_gaia)

        # --- Final classification ---
        logits = self.classifier(x_fused)  # (B, num_classes)
        return logits

In [9]:
# Define ADQL query to fetch source IDs of eclipsing binaries
query = """
SELECT source_id
FROM gaiadr3.vari_eclipsing_binary
"""

# Run the query asynchronously
job = Gaia.launch_job_async(query)
results = job.get_results()

# Convert to Pandas DataFrame
gaia_ids = results.to_pandas()

# Convert to list
gaia_ids = gaia_ids['source_id'].values.tolist()

gaia_ids_small = gaia_ids[:1000]

print(f"✅ Retrieved {len(gaia_ids)} eclipsing binary sources from Gaia DR3.")


INFO: Query finished. [astroquery.utils.tap.core]
✅ Retrieved 2184477 eclipsing binary sources from Gaia DR3.


In [10]:
# Load the LAMOST catalog to cross-match with Gaia as csv
lamost_catalogue = pd.read_csv("lamost/minimal.csv")  # Load LAMOST catalog (Just obsid and Ra, Dec)
label_cols = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Example usage:
model_path = "Models/model_fusion_mambaoutv3.pth"
gaia_transformers = "Pickles/gaia_normalization.pkl"

df_predictions, gaia_lamost_merged = predict_star_labels(gaia_ids, model_path, lamost_catalogue, gaia_transformers)

# Save the predictions to a npy file
np.save("y_predictions_ecl.npy", df_predictions)

# Save the merged DataFrame to a CSV file
gaia_lamost_merged.to_csv("gaia_lamost_merged_ecl.csv", index=False)


🚀 Step 1: Querying Gaia data...
🔗 Gaia IDs: 2184477
Gaia query split into 73 chunks of up to 30000 IDs.

Processing Chunk 1/73...
  Attempt 1/5: Launching Gaia job...
INFO: Query finished. [astroquery.utils.tap.core]
  Job launched (ID: 1746194630281O). Waiting for results...
  Results received for chunk 1.
  Retrieved 30000 records for this chunk.

Processing Chunk 2/73...
  Attempt 1/5: Launching Gaia job...
INFO: Query finished. [astroquery.utils.tap.core]
  Job launched (ID: 1746194642178O). Waiting for results...
  Results received for chunk 2.
  Retrieved 30000 records for this chunk.

Processing Chunk 3/73...
  Attempt 1/5: Launching Gaia job...
INFO: Query finished. [astroquery.utils.tap.core]
  Job launched (ID: 1746194650894O). Waiting for results...
  Results received for chunk 3.
  Retrieved 30000 records for this chunk.

Processing Chunk 4/73...
  Attempt 1/5: Launching Gaia job...
INFO: Query finished. [astroquery.utils.tap.core]
  Job launched (ID: 1746194660981O). Wait

Processing FITS files: 100%|██████████| 34435/34435 [05:35<00:00, 102.61it/s] 


✅ Successfully processed 34435 files
⚠️ Encountered errors in 0 files

📊 Step 5: Extracting and saving flux & frequency values...

📊 Extracting flux and frequency values...


100%|██████████| 7/7 [00:14<00:00,  2.10s/it]


✅ Flux values shape: (34435, 3749), Frequency values shape: (34435, 3749)

📊 Step 6: Interpolating and normalizing LAMOST spectra...


Interpolating spectra: 100%|██████████| 34435/34435 [01:06<00:00, 516.98it/s]


Initial number of rows: 34435
Total successful interpolations: 34257
Total skipped rows (NaNs + zeros): 178
Final check: len(df_flux) == cnt_success + len(nan_files)? True

📊 Step 7: Normalizing Gaia data...
Dropped 0 rows with NaN values.

🔗 Step 8: Merging Gaia and LAMOST data...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  gaia_lamost_match["obsid"] = gaia_lamost_match["obsid"].astype(int)



🤖 Step 9: Predicting labels using the trained model...
X_spectra shape: (34253, 3647)
X_gaia shape: (34253, 18)

💾 Step 10: Saving predictions...


In [9]:
def custom_precision_score(y_true, y_pred, zero_division=0):
    """
    Calculates a custom precision-like metric: TP / (TP + FN')
    where FN' are False Negatives for which the model made at least one
    (wrong) prediction for that sample.

    Args:
        y_true (np.ndarray): Ground truth labels (binary, samples x classes).
        y_pred (np.ndarray): Predicted labels (binary, samples x classes).
        zero_division (int or float): Value to return if the denominator (TP + FN') is 0.

    Returns:
        np.ndarray: Array of custom precision scores for each class.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape.")
    if y_true.ndim != 2:
         raise ValueError("y_true and y_pred must be 2D arrays (samples x classes).")

    num_samples, num_classes = y_true.shape
    custom_precisions = np.zeros(num_classes, dtype=float)

    # Identify samples where the model made *any* prediction
    # Shape: (num_samples,) - boolean mask
    model_made_a_prediction_mask = np.sum(y_pred, axis=1) > 0

    for c in range(num_classes):
        # Identify True Positives for class c
        # Shape: (num_samples,) - boolean mask
        tp_mask = (y_true[:, c] == 1) & (y_pred[:, c] == 1)
        tp = np.sum(tp_mask)

        # Identify standard False Negatives for class c
        # Shape: (num_samples,) - boolean mask
        fn_mask = (y_true[:, c] == 1) & (y_pred[:, c] == 0)

        # Identify the specific FN' subset:
        # These are samples that are FN for class c AND for which the model
        # made *some* prediction (could be for any class).
        # Shape: (num_samples,) - boolean mask
        fn_prime_mask = fn_mask & model_made_a_prediction_mask
        fn_prime = np.sum(fn_prime_mask)

        # Calculate the custom metric for class c
        denominator = tp + fn_prime
        if denominator > 0:
            custom_precisions[c] = tp / denominator
        else:
            # Assign the zero_division value if denominator is 0
            custom_precisions[c] = float(zero_division)

    return custom_precisions

def custom_f1_score(y_true, y_pred, zero_division=0):
    """
    Calculates a custom F1 score based on the harmonic mean of
    custom_precision_score and standard recall_score.

    Args:
        y_true (np.ndarray): Ground truth labels (binary, samples x classes).
        y_pred (np.ndarray): Predicted labels (binary, samples x classes).
        zero_division (int or float): Value to return for F1 when both
                                     custom precision and recall are 0.

    Returns:
        np.ndarray: Array of custom F1 scores for each class.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape.")
    if y_true.ndim != 2:
         raise ValueError("y_true and y_pred must be 2D arrays (samples x classes).")

    # Calculate custom precision per class
    p_custom = custom_precision_score(y_true, y_pred, zero_division=zero_division)

    # Calculate standard recall per class
    # Note: Using the same zero_division behavior for consistency here
    r_standard = recall_score(y_true, y_pred, average=None, zero_division=zero_division)

    # Calculate custom F1 using vectorization, handling division by zero
    denominator = p_custom + r_standard
    f1_custom = np.zeros_like(denominator, dtype=float) # Initialize with zeros

    # Create a mask for where the denominator is non-zero
    valid_mask = denominator > 0

    # Calculate F1 only where the denominator is valid
    f1_custom[valid_mask] = 2 * (p_custom[valid_mask] * r_standard[valid_mask]) / denominator[valid_mask]

    # For cases where denominator is zero, F1 should be zero_division
    # (If zero_division is 0, the initial np.zeros_like already handles this)
    if zero_division != 0:
        zero_mask = ~valid_mask # Where denominator is zero
        f1_custom[zero_mask] = float(zero_division) # Assign zero_division value

    return f1_custom

def exact_match_ratio(y_true, y_pred):
    """
    Calculates the Exact Match Ratio (Subset Accuracy).

    This metric requires the set of predicted labels for a sample to
    exactly match the set of true labels for that sample.

    Args:
        y_true (np.ndarray): Ground truth labels (binary, samples x classes).
        y_pred (np.ndarray): Predicted labels (binary, samples x classes).

    Returns:
        float: The Exact Match Ratio (a single value between 0 and 1).
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape.")
    if y_true.ndim != 2:
        raise ValueError("y_true and y_pred must be 2D arrays (samples x classes).")

    # Check equality for each sample across all classes
    # .all(axis=1) returns True only if all columns match for a given row (sample)
    exact_matches_mask = (y_true == y_pred).all(axis=1)

    # Calculate the ratio (mean of a boolean array gives the proportion of True values)
    ratio = np.mean(exact_matches_mask)

    return ratio

In [34]:
# Load the predictions and class labels
y_pred = np.load("y_predictions_ecl.npy")
y_pred = np.array(df_predictions.iloc[:, :-1], dtype=int)
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Generate the expected y_true for eclipsing binaries
y_true = np.zeros_like(y_pred)
y_true[:, -1] = 1  # "EB*" column (last column)
y_true[:, 1] = 1   # "**" column (second column)

# Compute precision, recall, and F1-score for each class
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
custom_precision = custom_precision_score(y_true, y_pred, zero_division=0)
custom_f1 = custom_f1_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
exact_match = exact_match_ratio(y_true, y_pred)

# Create a DataFrame to store metrics per class
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision,
    "Custom Precision": custom_precision,
    "Recall": recall,
    "F1 Score": f1,
    "Custom F1 Score": custom_f1,
    "Exact Match Ratio": exact_match
})

# Identify Correctly Classified Samples (True Positives)
correct_predictions = (y_pred == y_true).all(axis=1)
correct_gaia_ids = df_predictions.loc[correct_predictions, "source_id"]

# Identify incorrectly classified samples (False Positives and False Negatives)
incorrect_predictions = (y_pred != y_true).any(axis=1)
incorrect_gaia_ids = df_predictions.loc[incorrect_predictions, "source_id"]

# Display incorrectly classified Gaia IDs
print("\n🔍 Incorrectly Classified Gaia IDs:")
print(pd.DataFrame({"source_id": incorrect_gaia_ids}))

# Display the performance metrics for the non-zero classes
print("\n📊 Performance Metrics:")
metrics_df = metrics_df[metrics_df["Precision"] > 0]
print(metrics_df)


🔍 Incorrectly Classified Gaia IDs:
                 source_id
1          544743587681792
4         1179432379804928
10        2367390269205248
12        2590827352388096
14        3715047927383424
...                    ...
34246  6911669219076603136
34247  6911722369297224704
34248  6912724166123684608
34251  6916019432537074048
34252  6916024384634026240

[15871 rows x 1 columns]

📊 Performance Metrics:
   Class  Precision  Custom Precision    Recall  F1 Score  Custom F1 Score  \
1     **        1.0          0.820649  0.677138  0.807492         0.742018   
54   EB*        1.0          0.758978  0.626252  0.770178         0.686256   

    Exact Match Ratio  
1            0.536654  
54           0.536654  


In [12]:
recall_predictions = process_star_data_fusion(model_path, gaia_lamost_merged, "Pickles/Updated_List_of_Classes_ubuntu.pkl", sigmoid_constant=0.1)

X_spectra shape: (34253, 3647)
X_gaia shape: (34253, 18)


In [35]:
# Load the predictions and class labels
y_pred = recall_predictions
print(f"y_pred shape: {y_pred.shape}")
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Generate the expected y_true for eclipsing binaries
y_true = np.zeros_like(y_pred)
y_true[:, -1] = 1  # "EB*" column (last column)
y_true[:, 1] = 1   # "**" column (second column)

# Compute precision, recall, and F1-score for each class
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
custom_precision = custom_precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
custom_f1 = custom_f1_score(y_true, y_pred, zero_division=0)
exact_match = exact_match_ratio(y_true, y_pred)

# Create a DataFrame to store metrics per class
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision,
    "Custom Precision": custom_precision,
    "Recall": recall,
    "F1 Score": f1,
    "Custom F1 Score": custom_f1,
    "Exact Match Ratio": exact_match
})

# Identify Correctly Classified Samples (True Positives)
correct_predictions = (y_pred == y_true).all(axis=1)
correct_gaia_ids = df_predictions.loc[correct_predictions, "source_id"]

# Identify incorrectly classified samples (False Positives and False Negatives)
incorrect_predictions = (y_pred != y_true).any(axis=1)
incorrect_gaia_ids = df_predictions.loc[incorrect_predictions, "source_id"]

# Display incorrectly classified Gaia IDs
print("\n🔍 Incorrectly Classified Gaia IDs:")
print(pd.DataFrame({"source_id": incorrect_gaia_ids}))

# Display the performance metrics for the non-zero classes
print("\n📊 Performance Metrics:")
metrics_df = metrics_df[metrics_df["Precision"] > 0]
print(metrics_df)

y_pred shape: (34253, 55)

🔍 Incorrectly Classified Gaia IDs:
                 source_id
0          138577120584320
1          544743587681792
2          608515262200832
4         1179432379804928
10        2367390269205248
...                    ...
34245  6911578200129779200
34248  6912724166123684608
34249  6912964203255804544
34251  6916019432537074048
34252  6916024384634026240

[15552 rows x 1 columns]

📊 Performance Metrics:
   Class  Precision  Custom Precision    Recall  F1 Score  Custom F1 Score  \
1     **        1.0          0.839757  0.826935   0.90527         0.833297   
54   EB*        1.0          0.856567  0.843488   0.91510         0.849977   

    Exact Match Ratio  
1            0.545967  
54           0.545967  


In [37]:
recall_predictions03 = process_star_data_fusion(model_path, gaia_lamost_merged, "Pickles/Updated_List_of_Classes_ubuntu.pkl", sigmoid_constant=0.3)

# Load the predictions and class labels
y_pred = recall_predictions03
print(f"y_pred shape: {y_pred.shape}")
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Generate the expected y_true for eclipsing binaries
y_true = np.zeros_like(y_pred)
y_true[:, -1] = 1  # "EB*" column (last column)
y_true[:, 1] = 1   # "**" column (second column)

# Compute precision, recall, and F1-score for each class
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
custom_precision = custom_precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
custom_f1 = custom_f1_score(y_true, y_pred, zero_division=0)
exact_match = exact_match_ratio(y_true, y_pred)

# Create a DataFrame to store metrics per class
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision,
    "Custom Precision": custom_precision,
    "Recall": recall,
    "F1 Score": f1,
    "Custom F1 Score": custom_f1,
    "Exact Match Ratio": exact_match
})

# Identify Correctly Classified Samples (True Positives)
correct_predictions = (y_pred == y_true).all(axis=1)
correct_gaia_ids = df_predictions.loc[correct_predictions, "source_id"]

# Identify incorrectly classified samples (False Positives and False Negatives)
incorrect_predictions = (y_pred != y_true).any(axis=1)
incorrect_gaia_ids = df_predictions.loc[incorrect_predictions, "source_id"]

# Display incorrectly classified Gaia IDs
print("\n🔍 Incorrectly Classified Gaia IDs:")
print(pd.DataFrame({"source_id": incorrect_gaia_ids}))

# Display the performance metrics for the non-zero classes
print("\n📊 Performance Metrics:")
metrics_df = metrics_df[metrics_df["Precision"] > 0]
print(metrics_df)

X_spectra shape: (34253, 3647)
X_gaia shape: (34253, 18)
y_pred shape: (34253, 55)

🔍 Incorrectly Classified Gaia IDs:
                 source_id
0          138577120584320
1          544743587681792
4         1179432379804928
14        3715047927383424
15        4326135874205184
...                    ...
34243  6911428735267723904
34246  6911669219076603136
34248  6912724166123684608
34251  6916019432537074048
34252  6916024384634026240

[13986 rows x 1 columns]

📊 Performance Metrics:
   Class  Precision  Custom Precision    Recall  F1 Score  Custom F1 Score  \
1     **        1.0          0.819891  0.743439  0.852842         0.779795   
54   EB*        1.0          0.793232  0.719265  0.836713         0.754440   

    Exact Match Ratio  
1            0.591685  
54           0.591685  


In [None]:
precision_predictions = process_star_data_fusion(model_path, gaia_lamost_merged, "Pickles/Updated_List_of_Classes_ubuntu.pkl", sigmoid_constant=0.9)

X_spectra shape: (34253, 3647)
X_gaia shape: (34253, 18)


In [36]:
# Load the predictions and class labels
y_pred = precision_predictions
print(f"y_pred shape: {y_pred.shape}")
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Generate the expected y_true for eclipsing binaries
y_true = np.zeros_like(y_pred)
y_true[:, -1] = 1  # "EB*" column (last column)
y_true[:, 1] = 1   # "**" column (second column)

# Compute precision, recall, and F1-score for each class
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
custom_precision = custom_precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
custom_f1 = custom_f1_score(y_true, y_pred, zero_division=0)
exact_match = exact_match_ratio(y_true, y_pred)

# Create a DataFrame to store metrics per class
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision,
    "Custom Precision": custom_precision,
    "Recall": recall,
    "F1 Score": f1,
    "Custom F1 Score": custom_f1,
    "Exact Match Ratio": exact_match
})

# Identify Correctly Classified Samples (True Positives)
correct_predictions = (y_pred == y_true).all(axis=1)
correct_gaia_ids = df_predictions.loc[correct_predictions, "source_id"]

# Identify incorrectly classified samples (False Positives and False Negatives)
incorrect_predictions = (y_pred != y_true).any(axis=1)
incorrect_gaia_ids = df_predictions.loc[incorrect_predictions, "source_id"]

# Display incorrectly classified Gaia IDs
print("\n🔍 Incorrectly Classified Gaia IDs:")
print(pd.DataFrame({"source_id": incorrect_gaia_ids}))

# Display the performance metrics for the non-zero classes
print("\n📊 Performance Metrics:")
metrics_df = metrics_df[metrics_df["Precision"] > 0]
print(metrics_df)

y_pred shape: (34253, 55)

🔍 Incorrectly Classified Gaia IDs:
                 source_id
1          544743587681792
2          608515262200832
4         1179432379804928
10        2367390269205248
12        2590827352388096
...                    ...
34246  6911669219076603136
34247  6911722369297224704
34248  6912724166123684608
34251  6916019432537074048
34252  6916024384634026240

[17312 rows x 1 columns]

📊 Performance Metrics:
   Class  Precision  Custom Precision    Recall  F1 Score  Custom F1 Score  \
1     **        1.0          0.821548  0.640032  0.780511         0.719518   
54   EB*        1.0          0.743751  0.579424  0.733715         0.651383   

    Exact Match Ratio  
1            0.494584  
54           0.494584  


In [None]:
# Load the predictions and class labels
y_pred = np.load("y_predictions_ecl.npy")
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Remove the last column (source_id) from y_pred
y_pred = y_pred[:, :-1]

# Number of times a Class was predicted by the model
predicted_classes = np.sum(y_pred, axis=0)

# Histogram of predicted classes
plt.figure(figsize=(12, 6))
plt.hist(predicted_classes, bins=range(len(classes)), align='left', rwidth=0.8)
plt.xticks(range(len(classes)), classes, rotation=90)
plt.xlabel('Classes')
plt.ylabel('Count')
plt.title('Histogram of Predicted Classes')
plt.tight_layout()
plt.show()

In [38]:
def download_one_spectrum(obsid, session, save_folder):
    """
    Helper function to download one spectrum file given an obsid.
    Uses the same session to get the file and saves it locally.
    
    Args:
        obsid: LAMOST observation ID
        session: Requests session to use
        save_folder: Folder to save the spectrum file
        
    Returns:
        tuple: (obsid, success, error_message)
    """
    url = f"https://www.lamost.org/dr10/v2.0/spectrum/fits/{obsid}"
    local_path = os.path.join(save_folder, str(obsid))
    
    # If already downloaded, skip
    if os.path.exists(local_path):
        return obsid, True, None
    
    # Add a small random delay to prevent hammering the server
    time.sleep(random.uniform(0.1, 0.5))
    
    try:
        resp = session.get(url, timeout=30)
        resp.raise_for_status()
        
        # Save the raw content - we'll handle format detection during processing
        with open(local_path, "wb") as f:
            f.write(resp.content)
            
        return obsid, True, None
    except Exception as e:
        return obsid, False, str(e)

# Doing the triple plot for ECB

In [26]:
y_pred_ = np.load("y_predictions_ecl.npy")
y_pred = np.array(y_pred_[:, :-1], dtype=int) # last column is source_id probably
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

print(classes)

EB_index = classes.index("EB*")
bin_index = classes.index("**")

print(f"EB_index: {EB_index}, bin_index: {bin_index}")  

# Generate the expected y_true for eclipsing binaries
y_true = np.zeros_like(y_pred)
y_true[:, EB_index] = 1  # "EB*" column
y_true[:, bin_index] = 1   # "**" column (second column)

# Count the number of misclassified samples
print(np.sum(y_true!=y_pred))

# Compute precision, recall, and F1-score for each class
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
custom_f1 = custom_f1_score(y_true, y_pred, zero_division=0)
custom_precision = custom_precision_score(y_true, y_pred, zero_division=0)
exact_match = exact_match_ratio(y_true, y_pred)
from sklearn.metrics import precision_recall_fscore_support as score
_, _, _, support = score(y_true, y_pred)

# Create a DataFrame to store metrics per class
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision,
    "Recall": recall,
    "F1 Score": f1,
    "Custom F1 Score": custom_f1,
    "Custom Precision": custom_precision,
    "Exact Match Ratio": exact_match,
    "Support": support
})

# Identify Correctly Classified Samples (True Positives)
correct_predictions = (y_pred == y_true).all(axis=1)
correct_gaia_ids = y_pred_[correct_predictions, -1]

# Identify incorrectly classified samples (False Positives and False Negatives)
incorrect_predictions = (y_pred != y_true).any(axis=1)
incorrect_gaia_ids = y_pred_[incorrect_predictions, -1]

# Not precise classification
print("\n--- Identifying Specific Misclassifications (FN') ---")

# 1. Find samples where the model predicted *at least one* label (any label)
model_predicted_something_mask = np.sum(y_pred, axis=1) > 0

# 2. Find samples where the model missed the target classes
missed_cv_mask = (y_true[:, EB_index] == 1) & (y_pred[:, EB_index] == 0)
missed_star_mask = (y_true[:, bin_index] == 1) & (y_pred[:, bin_index] == 0)

# 3. Find FN' samples for each target class
# FN' = Missed the target class AND predicted something else
fn_prime_cv_mask = missed_cv_mask & model_predicted_something_mask
fn_prime_star_mask = missed_star_mask & model_predicted_something_mask

# 4. Combine: Find samples that are FN' for *either* CV* OR **
overall_fn_prime_mask = fn_prime_cv_mask | fn_prime_star_mask

# 5. Get the corresponding Gaia IDs
# Ensure df_predictions index aligns with y_pred rows
# Using .iloc requires the integer indices based on the mask
fn_prime_indices = np.where(overall_fn_prime_mask)[0]
fn_prime_gaia_ids_ecl = y_pred_[fn_prime_indices, -1] # Assumes last col is source_id

print(f"\n🔍 Found {len(fn_prime_gaia_ids_ecl)} samples meeting the FN' criteria for CV* or **:")
print(pd.DataFrame({"source_id": fn_prime_gaia_ids_ecl}))

# State the labels given by the model to the FN' samples
print("\n🔍 Labels given by the model to FN' samples:")
fn_prime_labels = y_pred[overall_fn_prime_mask]
print_fn =pd.DataFrame(fn_prime_labels, columns=classes)
print(pd.DataFrame(fn_prime_labels, columns=classes))

# Display incorrectly classified Gaia IDs
print("\n🔍 Incorrectly Classified Gaia IDs:")
print(pd.DataFrame({"source_id": incorrect_gaia_ids}))

# Display the performance metrics for the non-zero classes
print("\n📊 Performance Metrics:")
metrics_df = metrics_df[metrics_df["Precision"] > 0]
print(metrics_df)
exact_match = np.mean((y_pred == y_true).all(axis=1))
print("Exact match accuracy:", exact_match)

['RS*', '**', 'El*', 'Y*O', 's*b', 'cC*', 'HB*', 'dS*', 'Or*', 'LP*', 'BS*', 'Ae*', 'WV*', 'HS*', 'Ev*', 'AB*', 'sg*', 's*r', 'Ce*', 'gD*', 'OH*', 'HXB', 'Pu*', 'RV*', 'Sy*', 'V*', 'TT*', 'SN*', 'Be*', 'SB*', 'Em*', 'Er*', 'PM*', 'HV*', 'pA*', 'C*', 'BY*', 'Ro*', 'XB*', 'Ma*', 'Pe*', 'CV*', 'bC*', 'RR*', 'Mi*', 'SX*', 'RG*', 'LM*', 'WD*', 'S*', 'MS*', 'Ir*', 'a2*', 'PN', 'EB*']
EB_index: 54, bin_index: 1
29377

--- Identifying Specific Misclassifications (FN') ---

🔍 Found 9482 samples meeting the FN' criteria for CV* or **:
         source_id
0     1.179432e+15
1     2.367390e+15
2     2.590827e+15
3     3.715048e+15
4     4.326136e+15
...            ...
9477  6.911429e+18
9478  6.911578e+18
9479  6.911722e+18
9480  6.916019e+18
9481  6.916024e+18

[9482 rows x 1 columns]

🔍 Labels given by the model to FN' samples:
      RS*  **  El*  Y*O  s*b  cC*  HB*  dS*  Or*  LP*  ...  SX*  RG*  LM*  \
0       0   0    0    0    0    0    0    0    0    0  ...    0    0    0   
1       0   1    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# Background: Random Stars in Milky Way for CMD

# Define the ADQL query to select stars with parallax > 0.03 mas (roughly within 30kpc, inside milky way)
query = """
SELECT TOP 1000000 source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE parallax > 0.03 
"""

# Launch an asynchronous job (this will return more than 2000 rows if available)
job = Gaia.launch_job_async(query)
gaia_table = job.get_results()

# Convert to a pandas DataFrame
df_all = gaia_table.to_pandas()
print(f"Total stars returned by the query: {len(df_all)}")

# Define the ADQL query to fetch detailed information for the Correctly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in correct_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
correct_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(correct_df)} correctly classified Gaia IDs.")

# Define the ADQL query to fetch detailed information for the incorrectly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in incorrect_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
incorrect_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(incorrect_df)} incorrectly classified Gaia IDs.")
# Define the ADQL query to fetch detailed information for the Correctly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in correct_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
correct_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(correct_df)} correctly classified Gaia IDs.")

# Define the ADQL query to fetch detailed information for the incorrectly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in incorrect_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
incorrect_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(incorrect_df)} incorrectly classified Gaia IDs.")

# --- Example: Prepare your Gaia DataFrame ---
# Clean out the bad parallax values (negative or zero)
print(f"Total stars before cleaning: {len(df_all)}")
df_all = df_all[df_all['parallax'] > 0].copy()

# Remove stars with large parallax errors (e.g., > 10% of the parallax value)
df_all = df_all[df_all['parallax_error'] < 0.1 * df_all['parallax']].copy()

print(f"Total stars after cleaning: {len(df_all)}")

# Combine the eclisping binary IDs with the Gaia DataFrame concatenated
df_sample = pd.concat([df_all, incorrect_df, correct_df], axis=0)
print(f"Combined DataFrame shape: {df_sample.shape}")

# Save the DataFrame to a CSV file
df_sample.to_csv("gaia_sample_ecl.csv", index=False)


INFO: Query finished. [astroquery.utils.tap.core]
Total stars returned by the query: 1000000


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import gzip
import io
import os

def add_vertical_line_between_groups(ax, labels):
    """
    Draws a vertical dashed line on the provided axis between the error and flux groups.
    
    :param ax: The matplotlib axes object where the bar chart is plotted.
    :param labels: List of labels for the bars, ordered such that error columns come first and flux columns second.
    """
    # Count the number of error bars (assumes errors come first)
    num_error = sum(1 for label in labels if label.endswith("_error"))
    if num_error and num_error < len(labels):
        # Vertical line placed between the last error and the first flux bar
        separation_index = num_error - 0.5
        ax.axvline(x=separation_index, color='black', linestyle='--', linewidth=5)


def open_fits_file(file_path):
    """
    Opens a FITS file, handling both regular and gzipped formats.
    
    :param file_path: Path to the FITS file
    :return: FITS HDU list or None if there was an error
    """
    try:
        # Check if the file is gzipped
        with open(file_path, 'rb') as f:
            file_start = f.read(2)
            f.seek(0)  # Reset file pointer
            if file_start == b'\x1f\x8b':  # gzip magic number
                # Handle gzipped file
                with gzip.GzipFile(fileobj=f) as gz_f:
                    file_content = gz_f.read()
                print(f"Opening gzipped file: {file_path}")
                return fits.open(io.BytesIO(file_content), ignore_missing_simple=True)
            else:
                # Handle regular file
                print(f"Opening regular file: {file_path}")
                return fits.open(file_path, ignore_missing_simple=True)
    except Exception as e:
        print(f"Error opening file {os.path.basename(file_path)}: {str(e)}")
        return None

def plot_spectrum_with_gaia_and_cmd(source_id, gaia_lamost_merged, df_sample, correct_df, incorrect_df, n,
                                    spectra_folder="lamost_spectra_uniques", save_path=None):
    """
    Plots the LAMOST spectrum, the Gaia parameters with issues, and a Color–Magnitude Diagram (CMD) 
    in a single figure with three subplots.
    
    :param source_id: Gaia Source ID of the incorrectly classified source.
    :param gaia_lamost_merged: DataFrame containing Gaia and LAMOST cross-matched data.
    :param df_sample: DataFrame containing Gaia photometric and parallax data for the CMD.
    :param correct_df: DataFrame containing correctly classified Gaia IDs.
    :param incorrect_df: DataFrame containing incorrectly classified Gaia IDs.
    :param spectra_folder: Path to the folder containing LAMOST FITS spectra.
    :param save_path: If provided, the complete figure is saved to this path.
    """
    try:
        if 'obsid' not in gaia_lamost_merged.columns:
            print("⚠️ 'obsid' column not found in gaia_lamost_merged.")
            return

        match = gaia_lamost_merged.loc[gaia_lamost_merged['source_id'] == source_id]
        if match.empty:
            print(f"⚠️ No LAMOST match found for source_id {source_id}.")
            return

        obsid = int(match.iloc[0]['obsid'])
        print(f"Found match: Source ID {source_id} -> ObsID {obsid}")

        fits_path = f"{spectra_folder}/{int(obsid)}"
        
        # Use the open_fits_file function to handle both regular and gzipped FITS files
        hdul = open_fits_file(fits_path)
        
        if hdul is None:
            print(f"⚠️ Failed to open FITS file for ObsID {obsid}.")
            return
            
        # Process the FITS data
        try:
            # After opening the FITS file, add debugging:
            print(f"FITS file structure for ObsID {obsid}:")
            for i, hdu in enumerate(hdul):
                print(f"  HDU {i}: {hdu.__class__.__name__}, shape={getattr(hdu.data, 'shape', 'No data')}")
            
            # LAMOST DR5 and later uses BinTableHDU in the first extension
            if len(hdul) > 1 and isinstance(hdul[1], fits.BinTableHDU):
                print("Using data from BinTableHDU (extension 1)")
                table_data = hdul[1].data
                
                # Debug table column names
                print(f"  BinTable columns: {table_data.names}")
                
                # For LAMOST spectra, typical column names are 'FLUX', 'WAVELENGTH', 'LOGLAM', etc.
                # Use appropriate column names based on what's available
                if 'FLUX' in table_data.names and 'WAVELENGTH' in table_data.names:
                    flux = table_data['FLUX'][0]  # First row
                    wavelength = table_data['WAVELENGTH'][0]
                    print(f"  Using FLUX and WAVELENGTH columns")
                elif 'FLUX' in table_data.names and 'LOGLAM' in table_data.names:
                    flux = table_data['FLUX'][0]  # First row
                    # Convert log wavelength to linear wavelength
                    log_wavelength = table_data['LOGLAM'][0]
                    wavelength = 10**log_wavelength
                    print(f"  Using FLUX and LOGLAM (converted) columns")
                # Add more conditions for different column naming conventions
                else:
                    # If column names don't match known formats, try first two columns
                    # (often wavelength is first, flux is second)
                    print(f"  Unknown column format, using first two columns")
                    wavelength = table_data[table_data.names[0]][0]
                    flux = table_data[table_data.names[1]][0]
            # Fallback to original method with primary HDU
            elif hdul[0].data is not None and len(hdul[0].data.shape) >= 1:
                print("Using data from PrimaryHDU")
                data = hdul[0].data
                if data.shape[0] < 3:
                    print(f"⚠️ Skipping {obsid}: Primary HDU data has insufficient dimensions: {data.shape}")
                    return
                flux = data[0]
                wavelength = data[2]
            else:
                print(f"⚠️ Skipping {obsid}: No usable data found in FITS file.")
                return
                
            # Check that we have valid data before proceeding
            if flux is None or wavelength is None or len(flux) == 0 or len(wavelength) == 0:
                print(f"⚠️ Skipping {obsid}: Empty flux or wavelength arrays")
                return
                
            print(f"  Data loaded successfully. Wavelength range: {min(wavelength):.2f}-{max(wavelength):.2f} Å")
            print(f"  Flux range: {min(flux):.2e}-{max(flux):.2e}")
            
            # Create a figure with three subplots using GridSpec.
            # Top row: two subplots (spectrum and Gaia issues), Bottom row: CMD spanning full width.
            fig = plt.figure(figsize=(24, 12))
            gs = fig.add_gridspec(1, 3, height_ratios=[1])
            ax1 = fig.add_subplot(gs[0, 1])
            ax2 = fig.add_subplot(gs[0, 2])
            ax3 = fig.add_subplot(gs[0, 0])
            plt.rcParams.update({'font.size': 16})

            # --- Subplot 1: LAMOST Spectrum ---
            ax1.plot(wavelength, flux, color='blue', alpha=0.7, lw=1, zorder=10)
            ax1.set_xlabel("Wavelength (Å)")
            ax1.set_ylabel("Flux")
            ax1.set_title(f"LAMOST Spectrum")
            ax1.grid(zorder=0)

            # --- Subplot 2: Gaia Parameters with Issues ---
            ax2.grid(zorder=0)
            gaia_info = match.iloc[[0]].drop(columns=["source_id", "obsid"], errors='ignore')
            issues_dict = {}
            issues_text_list = []
            for col in gaia_info.columns:
                value = gaia_info[col].values[0]
                if col.endswith("_error") and value > 1:
                    issues_dict[col] = value
                    issues_text_list.append(f"Large error in {col}")
                elif col.endswith("_flux") and value < -1:
                    issues_dict[col] = value
                    issues_text_list.append(f"Dim object in {col}")

            if issues_dict:
                # Order the labels: errors first, then fluxes.
                error_labels = [l for l in issues_dict.keys() if l.endswith("_error")]
                flux_labels = [l for l in issues_dict.keys() if l.endswith("_flux")]
                ordered_labels = error_labels + flux_labels
                ordered_values = [issues_dict[l] for l in ordered_labels]

                ax2.bar(ordered_labels, ordered_values, color='skyblue', zorder=3)
                ax2.tick_params("x", labelrotation=45)
                ax2.set_title("Gaia Parameters with Issues")
                ax2.set_ylabel("Standard Deviations from Mean")

                # Add a vertical dashed line between error and flux groups.
                add_vertical_line_between_groups(ax2, ordered_labels)
            else:
                ax2.text(0.5, 0.5, "No significant data issues", ha='center', va='center', fontsize=12)
                ax2.axis("off")

            # --- Subplot 3: Color–Magnitude Diagram (CMD) ---
            # Compute additional columns for CMD.
            df_sample['color'] = df_sample['phot_bp_mean_mag'] - df_sample['phot_rp_mean_mag']
            df_sample['distance_pc'] = 1000 / df_sample['parallax']
            df_sample['abs_mag'] = df_sample['phot_g_mean_mag'] - 5 * np.log10(df_sample['distance_pc'] / 10)
            df_sample['is_correct'] = df_sample['source_id'].isin(correct_df['source_id'])
            df_sample['is_incorrect'] = df_sample['source_id'].isin(incorrect_df['source_id'])

            # Plot background stars (those not flagged as correct or incorrect)
            mask_background = ~(df_sample['is_correct'] | df_sample['is_incorrect'])
            ax3.scatter(df_sample.loc[mask_background, 'color'], 
                        df_sample.loc[mask_background, 'abs_mag'],
                        s=3, color='gray', alpha=0.6, label='Nearby Stars')

            # Plot the incorrect in red.
            ax3.scatter(df_sample[df_sample['is_incorrect']]['color'],
                        df_sample[df_sample['is_incorrect']]['abs_mag'],
                        s=100, color='red', label='Incorrectly Classified', alpha=1, 
                        edgecolor='black', marker='H')
            
            # Plot the correct in green.
            #ax3.scatter(df_sample[df_sample['is_correct']]['color'],
            #            df_sample[df_sample['is_correct']]['abs_mag'],
            #            s=100, color='green', label='Correctly Classified', alpha=1, 
            #            edgecolor='black', marker='x')
            
            # Plot the target source in blue. FLUX IS NOT THE SAME AS MAGNITUDE, data for both exist in the Gaia table.
            target_color = df_sample.loc[df_sample['source_id'] == source_id, 'color'].values[0]
            target_abs_mag = df_sample.loc[df_sample['source_id'] == source_id, 'abs_mag'].values[0]
            ax3.scatter(target_color, target_abs_mag, s=200, color='blue', label='Target Source', alpha=1, edgecolor='black', marker='o')
            #target_abs_mag = match['phot_g_mean_flux'].values[0] - 5 * np.log10((1/match['parallax'].values[0] )/ 10)
            #ax3.scatter(target_color, target_abs_mag, s=200, color='blue', label='Target Source', alpha=1, edgecolor='black', marker='o')


            # In a CMD, brighter (lower) magnitudes are at the top.
            ax3.invert_yaxis()
            ax3.set_xlim(-0.5, 3.5)
            ax3.set_ylim(14, 0.5)
            ax3.set_xlabel('Colour (BP - RP)')
            ax3.set_ylabel('Absolute G Magnitude')
            ax3.set_title('Colour–Magnitude Diagram (CMD)')
            ax3.legend(loc='lower right')

            plt.tight_layout()
            if save_path:
                save_path= save_path.replace(".png", f"_{n}.png")
                plt.savefig(save_path)
            plt.show()
                
        except Exception as e:
            print(f"Error processing FITS data for source_id {source_id}: {e}")
            import traceback
            traceback.print_exc()
            
        finally:
            # Ensure proper cleanup of FITS file
            if hdul is not None:
                try:
                    hdul.close()
                except Exception as e:
                    print(f"Warning: Could not close FITS file: {e}")
                    
    except Exception as e:
        print(f"Error in overall processing for source_id {source_id}: {e}")
        import traceback
        traceback.print_exc()

gaia_lamost_merged = pd.read_csv("gaia_lamost_merged_ecl.csv")                                                                                                                                                                                                                                           


#gaia_lamost_merged = pd.DataFrame(gaia_lamost_merged, columns=["source_id", "obsid", "ra", "dec", "parallax", "phot_bp_mean_mag", "phot_rp_mean_mag", "phot_g_mean_mag", "parallax_error", "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_g_mean_flux", "parallax_error_flux", "phot_bp_mean_mag_error", "phot_rp_mean_mag_error", "phot_g_mean_mag_error"])


# Example type conversions (ensure these columns are in the correct type)
gaia_lamost_merged['obsid'] = gaia_lamost_merged['obsid'].astype(int)
gaia_lamost_merged['source_id'] = gaia_lamost_merged['source_id'].astype(int)

# Initialize a counter for the save path
n_=1

# Loop through incorrectly classified sources and plot all spectra with labels if Gaia data is problematic.
for source_id in fn_prime_gaia_ids_ecl.astype(int):
    plot_spectrum_with_gaia_and_cmd(source_id, gaia_lamost_merged, save_path=f"Images_and_Plots/CMD_Spectra_Gaia_CV_take2.png", df_sample=df_sample, correct_df=correct_df, incorrect_df=incorrect_df, n=n_)
    n_+=1

NameError: name 'fn_prime_gaia_ids' is not defined

In [None]:
cat

# Cataclysmic Binaries MambaOut WAS HERE

In [26]:
cat_gaia_ids = []
with open('Pickles/Cataclysmic Bin Catalogue Abrahams et al.txt', 'r') as file:
    for line in file:
        # Check if the line starts with a digit (to avoid header lines)
        if line and line[0].isdigit():
            cat_gaia_ids.append(line[:19].strip())
print(cat_gaia_ids)
print(len(cat_gaia_ids))

import pickle

with open('Pickles/gaia_normalization.pkl', 'rb') as f:
    data = pickle.load(f)
print(data)
print(len(data))

# Load the LAMOST catalog to cross-match with Gaia as csv
lamost_catalogue = pd.read_csv("lamost/minimal.csv")  # Load LAMOST catalog (Just obsid and Ra, Dec)
label_cols = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

# Example usage:
model_path = "Models/model_fusion_mambaoutv3.pth"
gaia_transformers = "Pickles/gaia_normalization.pkl"

df_predictions, gaia_lamost_merged = predict_star_labels(cat_gaia_ids, model_path, lamost_catalogue, gaia_transformers)

# Save the predictions to a npy file
np.save("y_predictions_bin_out.npy", df_predictions)

['1593140224924964864', '4471866725361723520', '2698490156365025536', '5171137394568701184', '6226943645600487552', '2163612727665972096', '2307289214897332480', '2104562321825510400', '1800384942558699008', '1013298268207936128', '1332378466733219456', '1563999425873420800', '2465053942183130240', '3876618514794039040', '3445477328117272576', '4714563374364671872', '4406459119386466176', '3681313024562519552', '2754909740118313344', '5099482805904892288', '1920126431748251776', '2477023401857408640', '1809844934461976832', '2923643719394227328', '2818311909906928384', '1030279027003254784', '1612331959869359872', '1203263915795342336', '5294908873052262016', '3859020040917830400', '4306244746253355776', '3688359000015020800', '2096274276193099136', '1558322303741820928', '2488974302977323008', '5745881603063095680', '2234727353044624128', '4545086473126911360', '2155490364688727168', '2355217815809560192', '1796893134146598144', '1374430388449392000', '6557154200328277120', '175915658

Processing FITS files:  28%|██▊       | 18/65 [00:00<00:00, 172.13it/s]

Opening gzipped file: lamost_spectra_uniques/446308146
Opening gzipped file: lamost_spectra_uniques/660601090
Opening gzipped file: lamost_spectra_uniques/457204145
Opening gzipped file: lamost_spectra_uniques/797001189
Opening gzipped file: lamost_spectra_uniques/814503145
Opening gzipped file: lamost_spectra_uniques/631913192
Opening gzipped file: lamost_spectra_uniques/3210010
Opening gzipped file: lamost_spectra_uniques/384707148
Opening gzipped file: lamost_spectra_uniques/565605203
Opening gzipped file: lamost_spectra_uniques/377711035
Opening gzipped file: lamost_spectra_uniques/266807242
Opening gzipped file: lamost_spectra_uniques/577507136
Opening gzipped file: lamost_spectra_uniques/250808080
Opening gzipped file: lamost_spectra_uniques/679513244
Opening gzipped file: lamost_spectra_uniques/315401163
Opening gzipped file: lamost_spectra_uniques/866605170
Opening gzipped file: lamost_spectra_uniques/641310250
Opening gzipped file: lamost_spectra_uniques/616216075
Opening gzip

Processing FITS files:  82%|████████▏ | 53/65 [00:00<00:00, 168.79it/s]

Opening gzipped file: lamost_spectra_uniques/247615160
Opening gzipped file: lamost_spectra_uniques/235613172
Opening gzipped file: lamost_spectra_uniques/372410008
Opening gzipped file: lamost_spectra_uniques/195807009
Opening gzipped file: lamost_spectra_uniques/409102003
Opening gzipped file: lamost_spectra_uniques/155803169
Opening gzipped file: lamost_spectra_uniques/453806027
Opening gzipped file: lamost_spectra_uniques/28908076
Opening gzipped file: lamost_spectra_uniques/807204250
Opening gzipped file: lamost_spectra_uniques/601205115
Opening gzipped file: lamost_spectra_uniques/565809243
Opening gzipped file: lamost_spectra_uniques/4315066
Opening gzipped file: lamost_spectra_uniques/241503093
Opening gzipped file: lamost_spectra_uniques/164609001
Opening gzipped file: lamost_spectra_uniques/564905074
Opening gzipped file: lamost_spectra_uniques/471703215
Opening gzipped file: lamost_spectra_uniques/757314103
Opening gzipped file: lamost_spectra_uniques/417206083
Opening gzipp

Processing FITS files: 100%|██████████| 65/65 [00:00<00:00, 98.87it/s] 


✅ Successfully processed 65 files
⚠️ Encountered errors in 0 files

📊 Step 5: Extracting and saving flux & frequency values...

📊 Extracting flux and frequency values...


100%|██████████| 1/1 [00:00<00:00, 23.40it/s]


✅ Flux values shape: (65, 3749), Frequency values shape: (65, 3749)

📊 Step 6: Interpolating and normalizing LAMOST spectra...


Interpolating spectra: 100%|██████████| 65/65 [00:00<00:00, 683.96it/s]


Initial number of rows: 65
Total successful interpolations: 64
Total skipped rows (NaNs + zeros): 1
Final check: len(df_flux) == cnt_success + len(nan_files)? True


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  gaia_lamost_match["obsid"] = gaia_lamost_match["obsid"].astype(int)



📊 Step 7: Normalizing Gaia data...
Dropped 0 rows with NaN values.

🔗 Step 8: Merging Gaia and LAMOST data...

🤖 Step 9: Predicting labels using the trained model...
X_spectra shape: (64, 3647)
X_gaia shape: (64, 18)

💾 Step 10: Saving predictions...


In [21]:
y_pred_ = np.load("y_predictions_bin_out.npy")
y_pred = np.array(y_pred_[:, :-1], dtype=int) # last column is source_id probably
classes = pd.read_pickle("Pickles/Updated_List_of_Classes_ubuntu.pkl")

print(classes)

cv_index = classes.index("CV*")
star_index = classes.index("**")

print(cv_index, star_index)

# Generate the expected y_true for eclipsing binaries
y_true = np.zeros_like(y_pred)
y_true[:, -14] = 1  # "CV*" column
y_true[:, 1] = 1   # "**" column (second column)

# Count the number of misclassified samples
print(np.sum(y_true!=y_pred))

# Compute precision, recall, and F1-score for each class
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
custom_f1 = custom_f1_score(y_true, y_pred, zero_division=0)
custom_precision = custom_precision_score(y_true, y_pred, zero_division=0)
exact_match = exact_match_ratio(y_true, y_pred)
from sklearn.metrics import precision_recall_fscore_support as score
_, _, _, support = score(y_true, y_pred)

# Create a DataFrame to store metrics per class
metrics_df = pd.DataFrame({
    "Class": classes,
    "Precision": precision,
    "Recall": recall,
    "F1 Score": f1,
    "Custom F1 Score": custom_f1,
    "Custom Precision": custom_precision,
    "Exact Match Ratio": exact_match,
    "Support": support
})

# Identify Correctly Classified Samples (True Positives)
correct_predictions = (y_pred == y_true).all(axis=1)
correct_gaia_ids = y_pred_[correct_predictions, -1]

# Identify incorrectly classified samples (False Positives and False Negatives)
incorrect_predictions = (y_pred != y_true).any(axis=1)
incorrect_gaia_ids = y_pred_[incorrect_predictions, -1]

# Not precise classification
print("\n--- Identifying Specific Misclassifications (FN') ---")

# 1. Find samples where the model predicted *at least one* label (any label)
model_predicted_something_mask = np.sum(y_pred, axis=1) > 0

# 2. Find samples where the model missed the target classes
missed_cv_mask = (y_true[:, cv_index] == 1) & (y_pred[:, cv_index] == 0)
missed_star_mask = (y_true[:, star_index] == 1) & (y_pred[:, star_index] == 0)

# 3. Find FN' samples for each target class
# FN' = Missed the target class AND predicted something else
fn_prime_cv_mask = missed_cv_mask & model_predicted_something_mask
fn_prime_star_mask = missed_star_mask & model_predicted_something_mask

# 4. Combine: Find samples that are FN' for *either* CV* OR **
overall_fn_prime_mask = fn_prime_cv_mask | fn_prime_star_mask

# 5. Get the corresponding Gaia IDs
# Ensure df_predictions index aligns with y_pred rows
# Using .iloc requires the integer indices based on the mask
fn_prime_indices = np.where(overall_fn_prime_mask)[0]
fn_prime_gaia_ids = y_pred_[fn_prime_indices, -1] # Assumes last col is source_id

print(f"\n🔍 Found {len(fn_prime_gaia_ids)} samples meeting the FN' criteria for CV* or **:")
print(pd.DataFrame({"source_id": fn_prime_gaia_ids}))

# State the labels given by the model to the FN' samples
print("\n🔍 Labels given by the model to FN' samples:")
fn_prime_labels = y_pred[overall_fn_prime_mask]
print_fn =pd.DataFrame(fn_prime_labels, columns=classes)
print(pd.DataFrame(fn_prime_labels, columns=classes))

# Display incorrectly classified Gaia IDs
print("\n🔍 Incorrectly Classified Gaia IDs:")
print(pd.DataFrame({"source_id": incorrect_gaia_ids}))

# Display the performance metrics for the non-zero classes
print("\n📊 Performance Metrics:")
metrics_df = metrics_df[metrics_df["Precision"] > 0]
print(metrics_df)
exact_match = np.mean((y_pred == y_true).all(axis=1))
print("Exact match accuracy:", exact_match)

['RS*', '**', 'El*', 'Y*O', 's*b', 'cC*', 'HB*', 'dS*', 'Or*', 'LP*', 'BS*', 'Ae*', 'WV*', 'HS*', 'Ev*', 'AB*', 'sg*', 's*r', 'Ce*', 'gD*', 'OH*', 'HXB', 'Pu*', 'RV*', 'Sy*', 'V*', 'TT*', 'SN*', 'Be*', 'SB*', 'Em*', 'Er*', 'PM*', 'HV*', 'pA*', 'C*', 'BY*', 'Ro*', 'XB*', 'Ma*', 'Pe*', 'CV*', 'bC*', 'RR*', 'Mi*', 'SX*', 'RG*', 'LM*', 'WD*', 'S*', 'MS*', 'Ir*', 'a2*', 'PN', 'EB*']
41 1
14

--- Identifying Specific Misclassifications (FN') ---

🔍 Found 5 samples meeting the FN' criteria for CV* or **:
      source_id
0  1.558322e+18
1  1.612332e+18
2  2.116888e+18
3  3.408324e+18
4  3.446676e+18

🔍 Labels given by the model to FN' samples:
   RS*  **  El*  Y*O  s*b  cC*  HB*  dS*  Or*  LP*  ...  SX*  RG*  LM*  WD*  \
0    0   0    0    0    0    0    0    0    0    0  ...    0    0    0    0   
1    0   0    0    0    0    0    0    0    0    0  ...    0    0    0    0   
2    0   0    0    0    0    0    0    0    0    0  ...    0    0    0    0   
3    0   1    0    0    0    0    0    0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# For plotting
#incorrect_gaia_ids = fn_prime_gaia_ids

In [33]:
def add_vertical_line_between_groups(ax, labels):
    """
    Draws a vertical dashed line on the provided axis between the error and flux groups.
    
    :param ax: The matplotlib axes object where the bar chart is plotted.
    :param labels: List of labels for the bars, ordered such that error columns come first and flux columns second.
    """
    # Count the number of error bars (assumes errors come first)
    num_error = sum(1 for label in labels if label.endswith("_error"))
    if num_error and num_error < len(labels):
        # Vertical line placed between the last error and the first flux bar
        separation_index = num_error - 0.5
        ax.axvline(x=separation_index, color='black', linestyle='--', linewidth=5)

def plot_spectrum_with_gaia(source_id, gaia_lamost_merged, spectra_folder="lamost_spectra_uniques", save_path=None):
    """
    Plots the LAMOST spectrum from FITS files and displays only Gaia parameters that have issues.
    
    :param source_id: Gaia Source ID of the incorrectly classified source
    :param gaia_lamost_merged: DataFrame containing Gaia and LAMOST cross-matched data
    :param spectra_folder: Path to the folder containing LAMOST FITS spectra
    """
    try:
        if 'obsid' not in gaia_lamost_merged.columns:
            print(f"⚠️ 'obsid' column not found in gaia_lamost_merged.")
            return
        
        match = gaia_lamost_merged.loc[gaia_lamost_merged['source_id'] == source_id]
        if match.empty:
            print(f"⚠️ No LAMOST match found for source_id {source_id}.")
            return
        
        obsid = int(match.iloc[0]['obsid'])
        print(f"Found match: Source ID {source_id} -> ObsID {obsid}")
        
        fits_path = f"{spectra_folder}/{int(obsid)}"
        
        # Load FITS data
        with fits.open(fits_path) as hdul:
            data = hdul[0].data
            if data is None or data.shape[0] < 3:
                print(f"⚠️ Skipping {obsid}: Data not found or incorrect format.")
                return
            
            flux = data[0]
            wavelength = data[2]
            
            fig, ax = plt.subplots(2, 1, figsize=(9, 12), gridspec_kw={'height_ratios': [3, 4]})
            plt.rcParams.update({'font.size': 16})
            
            # Plot the LAMOST spectrum
            ax[0].plot(wavelength, flux, color='blue', alpha=0.7, lw=1, zorder=10)
            ax[0].set_xlabel("Wavelength (Å)")
            ax[0].set_ylabel("Flux")
            ax[0].set_title(f"LAMOST Spectrum for Gaia ID: {source_id} (LAMOST ID: {obsid})")
            ax[0].grid(zorder=0)
            ax[1].grid(zorder=0)

            # Get Gaia info (dropping unneeded columns)
            gaia_info = match.iloc[[0]].drop(columns=["source_id", "obsid"], errors='ignore')
            
            # Build a dictionary for only the columns with issues.
            issues_dict = {}
            issues_text_list = []
            for col in gaia_info.columns:
                value = gaia_info[col].values[0]
                if col.endswith("_error") and value > 1:
                    issues_dict[col] = value
                    issues_text_list.append(f"Large error in {col}")
                elif col.endswith("_flux") and value < -1:
                    issues_dict[col] = value
                    issues_text_list.append(f"Dim object in {col}")
            
            issue_text = "; ".join(issues_text_list) if issues_text_list else "No significant data issues"
            # Plot only the problematic Gaia parameters if any exist.
            if issues_dict:
                # Order the labels: errors first, then fluxes.
                error_labels = [l for l in issues_dict.keys() if l.endswith("_error")]
                flux_labels = [l for l in issues_dict.keys() if l.endswith("_flux")]
                ordered_labels = error_labels + flux_labels
                ordered_values = [issues_dict[l] for l in ordered_labels]
                
                ax[1].bar(ordered_labels, ordered_values, color='skyblue', zorder=3)
                ax[1].tick_params("x", labelrotation=90)
                ax[1].set_title("Gaia Parameters with Issues")
                ax[1].set_ylabel("Standard Deviations from Mean")
                
                # Add a vertical dashed line between error and flux groups.
                add_vertical_line_between_groups(ax[1], ordered_labels)
            else:
                ax[1].text(0.5, 0.5, "No significant data issues", ha='center', va='center', fontsize=12)
                ax[1].axis("off")
            
            plt.tight_layout()
            if save_path:
                plt.savefig(save_path)
            plt.show()
    except Exception as e:
        print(f"Error loading {fits_path}: {e}")

# Example type conversions (ensure these columns are in the correct type)
gaia_lamost_merged['obsid'] = gaia_lamost_merged['obsid'].astype(int)
gaia_lamost_merged['source_id'] = gaia_lamost_merged['source_id'].astype(int)

# Loop through incorrectly classified sources and plot all spectra with labels if Gaia data is problematic.
for source_id in incorrect_gaia_ids.astype(int):
    plot_spectrum_with_gaia(source_id, gaia_lamost_merged, save_path=f"Images_and_Plots/{source_id}_spectrum.png")

Found match: Source ID 1558322303741820928 -> ObsID 566413139
⚠️ Skipping 566413139: Data not found or incorrect format.
Found match: Source ID 1612331959869359872 -> ObsID 814503145
⚠️ Skipping 814503145: Data not found or incorrect format.
⚠️ No LAMOST match found for source_id 2116887920191163904.
⚠️ No LAMOST match found for source_id 3408324422192886784.
⚠️ No LAMOST match found for source_id 3446676070669830656.


In [29]:
# Background: Random Stars in Milky Way for CMD

# Define the ADQL query to select stars with parallax > 0.03 mas (roughly within 30kpc, inside milky way)
query = """
SELECT TOP 1000000 source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE parallax > 0.03 
"""

# Launch an asynchronous job (this will return more than 2000 rows if available)
job = Gaia.launch_job_async(query)
gaia_table = job.get_results()

# Convert to a pandas DataFrame
df_all = gaia_table.to_pandas()
print(f"Total stars returned by the query: {len(df_all)}")

# Define the ADQL query to fetch detailed information for the Correctly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in correct_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
correct_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(correct_df)} correctly classified Gaia IDs.")

# Define the ADQL query to fetch detailed information for the incorrectly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in incorrect_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
incorrect_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(incorrect_df)} incorrectly classified Gaia IDs.")
# Define the ADQL query to fetch detailed information for the Correctly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in correct_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
correct_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(correct_df)} correctly classified Gaia IDs.")

# Define the ADQL query to fetch detailed information for the incorrectly Classified Gaia IDs
query = """
SELECT source_id, ra, dec, parallax, phot_bp_mean_mag, phot_rp_mean_mag, phot_g_mean_mag, parallax_error
FROM gaiadr3.gaia_source
WHERE source_id IN ({})
"""

# Join the source IDs into a single string
source_ids_str = ",".join([str(id) for id in incorrect_gaia_ids])
full_query = query.format(source_ids_str)

# Run the query asynchronously
job = Gaia.launch_job_async(full_query)
results = job.get_results()

# Convert to Pandas DataFrame
incorrect_df = results.to_pandas()

print(f"✅ Retrieved detailed information for {len(incorrect_df)} incorrectly classified Gaia IDs.")

# --- Example: Prepare your Gaia DataFrame ---
# Clean out the bad parallax values (negative or zero)
print(f"Total stars before cleaning: {len(df_all)}")
df_all = df_all[df_all['parallax'] > 0].copy()

# Remove stars with large parallax errors (e.g., > 10% of the parallax value)
df_all = df_all[df_all['parallax_error'] < 0.1 * df_all['parallax']].copy()

print(f"Total stars after cleaning: {len(df_all)}")

# Combine the eclisping binary IDs with the Gaia DataFrame concatenated
df_sample = pd.concat([df_all, incorrect_df, correct_df], axis=0)
print(f"Combined DataFrame shape: {df_sample.shape}")


INFO: Query finished. [astroquery.utils.tap.core]
Total stars returned by the query: 1000000
INFO: Query finished. [astroquery.utils.tap.core]
✅ Retrieved detailed information for 1 correctly classified Gaia IDs.
INFO: Query finished. [astroquery.utils.tap.core]
✅ Retrieved detailed information for 0 incorrectly classified Gaia IDs.
INFO: Query finished. [astroquery.utils.tap.core]
✅ Retrieved detailed information for 1 correctly classified Gaia IDs.
INFO: Query finished. [astroquery.utils.tap.core]
✅ Retrieved detailed information for 0 incorrectly classified Gaia IDs.
Total stars before cleaning: 1000000
Total stars after cleaning: 39158
Combined DataFrame shape: (39159, 8)


In [31]:
# Save df_sample to a CSV file
df_sample.to_csv("Pickles/gaia_sample_background.csv", index=False)

In [15]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import gzip
import io
import os

def add_vertical_line_between_groups(ax, labels):
    """
    Draws a vertical dashed line on the provided axis between the error and flux groups.
    
    :param ax: The matplotlib axes object where the bar chart is plotted.
    :param labels: List of labels for the bars, ordered such that error columns come first and flux columns second.
    """
    # Count the number of error bars (assumes errors come first)
    num_error = sum(1 for label in labels if label.endswith("_error"))
    if num_error and num_error < len(labels):
        # Vertical line placed between the last error and the first flux bar
        separation_index = num_error - 0.5
        ax.axvline(x=separation_index, color='black', linestyle='--', linewidth=5)


def open_fits_file(file_path):
    """
    Opens a FITS file, handling both regular and gzipped formats.
    
    :param file_path: Path to the FITS file
    :return: FITS HDU list or None if there was an error
    """
    try:
        # Check if the file is gzipped
        with open(file_path, 'rb') as f:
            file_start = f.read(2)
            f.seek(0)  # Reset file pointer
            if file_start == b'\x1f\x8b':  # gzip magic number
                # Handle gzipped file
                with gzip.GzipFile(fileobj=f) as gz_f:
                    file_content = gz_f.read()
                print(f"Opening gzipped file: {file_path}")
                return fits.open(io.BytesIO(file_content), ignore_missing_simple=True)
            else:
                # Handle regular file
                print(f"Opening regular file: {file_path}")
                return fits.open(file_path, ignore_missing_simple=True)
    except Exception as e:
        print(f"Error opening file {os.path.basename(file_path)}: {str(e)}")
        return None

def plot_spectrum_with_gaia_and_cmd(source_id, gaia_lamost_merged, df_sample, correct_df, incorrect_df, n,
                                    spectra_folder="lamost_spectra_uniques", save_path=None):
    """
    Plots the LAMOST spectrum, the Gaia parameters with issues, and a Color–Magnitude Diagram (CMD) 
    in a single figure with three subplots.
    
    :param source_id: Gaia Source ID of the incorrectly classified source.
    :param gaia_lamost_merged: DataFrame containing Gaia and LAMOST cross-matched data.
    :param df_sample: DataFrame containing Gaia photometric and parallax data for the CMD.
    :param correct_df: DataFrame containing correctly classified Gaia IDs.
    :param incorrect_df: DataFrame containing incorrectly classified Gaia IDs.
    :param spectra_folder: Path to the folder containing LAMOST FITS spectra.
    :param save_path: If provided, the complete figure is saved to this path.
    """
    try:
        if 'obsid' not in gaia_lamost_merged.columns:
            print("⚠️ 'obsid' column not found in gaia_lamost_merged.")
            return

        match = gaia_lamost_merged.loc[gaia_lamost_merged['source_id'] == source_id]
        if match.empty:
            print(f"⚠️ No LAMOST match found for source_id {source_id}.")
            return

        obsid = int(match.iloc[0]['obsid'])
        print(f"Found match: Source ID {source_id} -> ObsID {obsid}")

        fits_path = f"{spectra_folder}/{int(obsid)}"
        
        # Use the open_fits_file function to handle both regular and gzipped FITS files
        hdul = open_fits_file(fits_path)
        
        if hdul is None:
            print(f"⚠️ Failed to open FITS file for ObsID {obsid}.")
            return
            
        # Process the FITS data
        try:
            # After opening the FITS file, add debugging:
            print(f"FITS file structure for ObsID {obsid}:")
            for i, hdu in enumerate(hdul):
                print(f"  HDU {i}: {hdu.__class__.__name__}, shape={getattr(hdu.data, 'shape', 'No data')}")
            
            # LAMOST DR5 and later uses BinTableHDU in the first extension
            if len(hdul) > 1 and isinstance(hdul[1], fits.BinTableHDU):
                print("Using data from BinTableHDU (extension 1)")
                table_data = hdul[1].data
                
                # Debug table column names
                print(f"  BinTable columns: {table_data.names}")
                
                # For LAMOST spectra, typical column names are 'FLUX', 'WAVELENGTH', 'LOGLAM', etc.
                # Use appropriate column names based on what's available
                if 'FLUX' in table_data.names and 'WAVELENGTH' in table_data.names:
                    flux = table_data['FLUX'][0]  # First row
                    wavelength = table_data['WAVELENGTH'][0]
                    print(f"  Using FLUX and WAVELENGTH columns")
                elif 'FLUX' in table_data.names and 'LOGLAM' in table_data.names:
                    flux = table_data['FLUX'][0]  # First row
                    # Convert log wavelength to linear wavelength
                    log_wavelength = table_data['LOGLAM'][0]
                    wavelength = 10**log_wavelength
                    print(f"  Using FLUX and LOGLAM (converted) columns")
                # Add more conditions for different column naming conventions
                else:
                    # If column names don't match known formats, try first two columns
                    # (often wavelength is first, flux is second)
                    print(f"  Unknown column format, using first two columns")
                    wavelength = table_data[table_data.names[0]][0]
                    flux = table_data[table_data.names[1]][0]
            # Fallback to original method with primary HDU
            elif hdul[0].data is not None and len(hdul[0].data.shape) >= 1:
                print("Using data from PrimaryHDU")
                data = hdul[0].data
                if data.shape[0] < 3:
                    print(f"⚠️ Skipping {obsid}: Primary HDU data has insufficient dimensions: {data.shape}")
                    return
                flux = data[0]
                wavelength = data[2]
            else:
                print(f"⚠️ Skipping {obsid}: No usable data found in FITS file.")
                return
                
            # Check that we have valid data before proceeding
            if flux is None or wavelength is None or len(flux) == 0 or len(wavelength) == 0:
                print(f"⚠️ Skipping {obsid}: Empty flux or wavelength arrays")
                return
                
            print(f"  Data loaded successfully. Wavelength range: {min(wavelength):.2f}-{max(wavelength):.2f} Å")
            print(f"  Flux range: {min(flux):.2e}-{max(flux):.2e}")
            
            # Create a figure with three subplots using GridSpec.
            # Top row: two subplots (spectrum and Gaia issues), Bottom row: CMD spanning full width.
            fig = plt.figure(figsize=(24, 12))
            gs = fig.add_gridspec(1, 3, height_ratios=[1])
            ax1 = fig.add_subplot(gs[0, 1])
            ax2 = fig.add_subplot(gs[0, 2])
            ax3 = fig.add_subplot(gs[0, 0])
            plt.rcParams.update({'font.size': 16})

            # --- Subplot 1: LAMOST Spectrum ---
            ax1.plot(wavelength, flux, color='blue', alpha=0.7, lw=1, zorder=10)
            ax1.set_xlabel("Wavelength (Å)")
            ax1.set_ylabel("Flux")
            ax1.set_title(f"LAMOST Spectrum")
            ax1.grid(zorder=0)

            # --- Subplot 2: Gaia Parameters with Issues ---
            ax2.grid(zorder=0)
            gaia_info = match.iloc[[0]].drop(columns=["source_id", "obsid"], errors='ignore')
            issues_dict = {}
            issues_text_list = []
            for col in gaia_info.columns:
                value = gaia_info[col].values[0]
                if col.endswith("_error") and value > 1:
                    issues_dict[col] = value
                    issues_text_list.append(f"Large error in {col}")
                elif col.endswith("_flux") and value < -1:
                    issues_dict[col] = value
                    issues_text_list.append(f"Dim object in {col}")

            if issues_dict:
                # Order the labels: errors first, then fluxes.
                error_labels = [l for l in issues_dict.keys() if l.endswith("_error")]
                flux_labels = [l for l in issues_dict.keys() if l.endswith("_flux")]
                ordered_labels = error_labels + flux_labels
                ordered_values = [issues_dict[l] for l in ordered_labels]

                ax2.bar(ordered_labels, ordered_values, color='skyblue', zorder=3)
                ax2.tick_params("x", labelrotation=45)
                ax2.set_title("Gaia Parameters with Issues")
                ax2.set_ylabel("Standard Deviations from Mean")

                # Add a vertical dashed line between error and flux groups.
                add_vertical_line_between_groups(ax2, ordered_labels)
            else:
                ax2.text(0.5, 0.5, "No significant data issues", ha='center', va='center', fontsize=12)
                ax2.axis("off")

            # --- Subplot 3: Color–Magnitude Diagram (CMD) ---
            # Compute additional columns for CMD.
            df_sample['color'] = df_sample['phot_bp_mean_mag'] - df_sample['phot_rp_mean_mag']
            df_sample['distance_pc'] = 1000 / df_sample['parallax']
            df_sample['abs_mag'] = df_sample['phot_g_mean_mag'] - 5 * np.log10(df_sample['distance_pc'] / 10)
            df_sample['is_correct'] = df_sample['source_id'].isin(correct_df['source_id'])
            df_sample['is_incorrect'] = df_sample['source_id'].isin(incorrect_df['source_id'])

            # Plot background stars (those not flagged as correct or incorrect)
            mask_background = ~(df_sample['is_correct'] | df_sample['is_incorrect'])
            ax3.scatter(df_sample.loc[mask_background, 'color'], 
                        df_sample.loc[mask_background, 'abs_mag'],
                        s=3, color='gray', alpha=0.6, label='Nearby Stars')

            # Plot the incorrect in red.
            ax3.scatter(df_sample[df_sample['is_incorrect']]['color'],
                        df_sample[df_sample['is_incorrect']]['abs_mag'],
                        s=100, color='red', label='Incorrectly Classified', alpha=1, 
                        edgecolor='black', marker='H')
            
            # Plot the correct in green.
            #ax3.scatter(df_sample[df_sample['is_correct']]['color'],
            #            df_sample[df_sample['is_correct']]['abs_mag'],
            #            s=100, color='green', label='Correctly Classified', alpha=1, 
            #            edgecolor='black', marker='x')
            
            # Plot the target source in blue. FLUX IS NOT THE SAME AS MAGNITUDE, data for both exist in the Gaia table.
            target_color = df_sample.loc[df_sample['source_id'] == source_id, 'color'].values[0]
            target_abs_mag = df_sample.loc[df_sample['source_id'] == source_id, 'abs_mag'].values[0]
            ax3.scatter(target_color, target_abs_mag, s=200, color='blue', label='Target Source', alpha=1, edgecolor='black', marker='o')
            #target_abs_mag = match['phot_g_mean_flux'].values[0] - 5 * np.log10((1/match['parallax'].values[0] )/ 10)
            #ax3.scatter(target_color, target_abs_mag, s=200, color='blue', label='Target Source', alpha=1, edgecolor='black', marker='o')


            # In a CMD, brighter (lower) magnitudes are at the top.
            ax3.invert_yaxis()
            ax3.set_xlim(-0.5, 3.5)
            ax3.set_ylim(14, 0.5)
            ax3.set_xlabel('Colour (BP - RP)')
            ax3.set_ylabel('Absolute G Magnitude')
            ax3.set_title('Colour–Magnitude Diagram (CMD)')
            ax3.legend(loc='lower right')

            plt.tight_layout()
            if save_path:
                save_path= save_path.replace(".png", f"_{n}.png")
                plt.savefig(save_path)
            plt.show()
                
        except Exception as e:
            print(f"Error processing FITS data for source_id {source_id}: {e}")
            import traceback
            traceback.print_exc()
            
        finally:
            # Ensure proper cleanup of FITS file
            if hdul is not None:
                try:
                    hdul.close()
                except Exception as e:
                    print(f"Warning: Could not close FITS file: {e}")
                    
    except Exception as e:
        print(f"Error in overall processing for source_id {source_id}: {e}")
        import traceback
        traceback.print_exc()

gaia_lamost_merged = np.load("gaia_lamost_merged.npy", allow_pickle=True)                                                                                                                                                                                                                                                           


#gaia_lamost_merged = pd.DataFrame(gaia_lamost_merged, columns=["source_id", "obsid", "ra", "dec", "parallax", "phot_bp_mean_mag", "phot_rp_mean_mag", "phot_g_mean_mag", "parallax_error", "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_g_mean_flux", "parallax_error_flux", "phot_bp_mean_mag_error", "phot_rp_mean_mag_error", "phot_g_mean_mag_error"])


# Example type conversions (ensure these columns are in the correct type)
gaia_lamost_merged['obsid'] = gaia_lamost_merged['obsid'].astype(int)
gaia_lamost_merged['source_id'] = gaia_lamost_merged['source_id'].astype(int)

# Initialize a counter for the save path
n_=1

# Loop through incorrectly classified sources and plot all spectra with labels if Gaia data is problematic.
for source_id in fn_prime_gaia_ids.astype(int):
    plot_spectrum_with_gaia_and_cmd(source_id, gaia_lamost_merged, save_path=f"Images_and_Plots/CMD_Spectra_Gaia_CV_take2.png", df_sample=df_sample, correct_df=correct_df, incorrect_df=incorrect_df, n=n_)
    n_+=1

FileNotFoundError: [Errno 2] No such file or directory: 'gaia_lamost_merged.npy'

In [71]:
print("Columns in gaia_lamost_merged:")
col = np.array(gaia_lamost_merged.columns)

Columns in gaia_lamost_merged:
