In [None]:
%matplotlib inline
import pandas as pd
from rich import print
import numpy as np
import matplotlib.pyplot as plt
#from quickbin import bin2d
from scipy.stats import binned_statistic_2d
import pdr
from astropy.visualization import ZScaleInterval
from astropy import wcs as pywcs
import warnings
from pyarrow import ArrowInvalid

In [None]:
# Collect the per-tile catalog paths with pure Python instead of `!ls`:
# glob returns [] when nothing matches, whereas `!ls` returns the
# "ls: cannot access ..." error text as if it were a filename (which the
# loader below would then try to read). sorted() keeps a stable order.
import glob
nd_catfiles = sorted(glob.glob('data/*/*nd*catalog*'))
fd_catfiles = sorted(glob.glob('data/*/*fd*catalog*'))

In [None]:
def _load_catalogs(paths):
    """Read every parquet catalog in ``paths`` and stack them into one DataFrame.

    Files that pyarrow cannot parse (``ArrowInvalid``) are reported and
    skipped. Frames are collected in a list and concatenated once, instead of
    re-copying a growing frame with ``pd.concat`` on every iteration. Each
    file's own index is kept (no ``ignore_index``), so index labels may repeat.

    Returns an empty ``pd.DataFrame()`` when no file could be read.
    """
    frames = []
    for path in paths:
        try:
            frames.append(pd.read_parquet(path))
        except ArrowInvalid:
            print(f'Unable to open {path}')
    # pd.concat([]) raises ValueError, so guard the all-failed / no-files case.
    return pd.concat(frames) if frames else pd.DataFrame()


tbl_nd = _load_catalogs(nd_catfiles)
tbl_fd = _load_catalogs(fd_catfiles)

In [None]:
# Inspect the concatenated NUV ("nd") catalog table. Rows come from every
# readable *nd*catalog* file, and each file's own index is kept, so index
# labels may repeat.
tbl_nd

In [None]:
# Inspect the concatenated FUV ("fd") catalog table. Rows come from every
# readable *fd*catalog* file, and each file's own index is kept, so index
# labels may repeat.
tbl_fd

In [None]:
def counts2mag(cps, band):
    """Convert GALEX count rates to AB magnitudes.

    ``mag = -2.5 * log10(cps) + ZP`` with ZP = 18.82 for FUV and 20.08 for NUV.

    Parameters
    ----------
    cps : scalar or array-like
        Count rate(s). Non-positive values do not raise: ``cps == 0`` gives
        ``+inf`` and ``cps < 0`` gives ``NaN``; the floating-point warnings
        for both are suppressed.
    band : str
        ``'FUV'`` or ``'NUV'`` (case-insensitive).

    Returns
    -------
    Magnitude(s), in whatever container ``np.log10`` returns for ``cps``.

    Raises
    ------
    ValueError
        If ``band`` is not FUV or NUV. Previously any other string, including
        lower-case ``'fuv'``, silently fell back to the NUV zero point.
    """
    zero_points = {'FUV': 18.82, 'NUV': 20.08}
    try:
        scale = zero_points[str(band).upper()]
    except KeyError:
        raise ValueError(f"band must be 'FUV' or 'NUV', got {band!r}") from None
    # divide: log10(0) -> -inf; invalid: log10(<0) -> NaN. Both are expected
    # for faint/negative background-subtracted rates and are filtered later.
    with np.errstate(divide='ignore', invalid='ignore'):
        return -2.5 * np.log10(cps) + scale

In [None]:
def _quality_mask(tbl, band):
    """Boolean mask of rows in ``tbl`` passing the quality cuts for ``band``.

    Keeps sources with finite ``{band}_MDL_SIGMA`` strictly between 0.2 and
    10, all four ``{band}_*_FLAG_A6`` edge/ghost/hotspot flags equal to 0, and
    a ``counts2mag`` magnitude (from ``{band}_MDL_CPS``) brighter than 25.

    The former ``(SIGMA) & (SIGMA)`` term is dropped: it only re-tested that
    sigma is non-zero, which ``sigma > 0.2`` already guarantees.
    """
    sigma = tbl[f'{band}_MDL_SIGMA']
    return (
        np.isfinite(sigma)
        & (sigma < 10)
        & (sigma > 0.2)
        & (tbl[f'{band}_SOFTEDGE_FLAG_A6'] == 0)
        & (tbl[f'{band}_HARDEDGE_FLAG_A6'] == 0)
        & (tbl[f'{band}_GHOST_FLAG_A6'] == 0)
        & (tbl[f'{band}_HOTSPOT_FLAG_A6'] == 0)
        & (counts2mag(tbl[f'{band}_MDL_CPS'].values, band) < 25)
    )


# One 2D histogram of NUV vs FUV magnitude per catalog, each filtered on its
# own detection band. A logarithmic colour stretch keeps both dense and sparse
# regions visible. After the loop, `ix` and `h` hold the FUV-catalog values,
# as they did when this was two copy-pasted sections.
for tbl, band in [(tbl_nd, 'NUV'), (tbl_fd, 'FUV')]:
    ix = np.where(_quality_mask(tbl, band))
    fig, ax = plt.subplots(figsize=(12, 5))
    h = ax.hist2d(
        tbl['NUV_MAG_A4'].iloc[ix],
        tbl['FUV_MAG_A4'].iloc[ix],
        bins=100,
        range=[[15, 22], [15, 25]],
        cmap='Greys',
        norm=plt.matplotlib.colors.LogNorm(vmin=0.1),
    )
    fig.colorbar(h[3], ax=ax, label='Counts (log scale)')
    ax.set_xlabel('NUV MAG A4')
    ax.set_ylabel('FUV MAG A4')
    ax.set_title(f'{band} Catalog --- NUV vs FUV Magnitude (log-stretched 2D histogram)')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.patches as mpatches

def create_enhanced_uv_plots(tbl_nd, tbl_fd, counts2mag):
    """
    Plot NUV vs FUV magnitude for the NUV and FUV catalogs side by side.

    Each panel is a log-scaled 2D histogram of ``NUV_MAG_A4`` (x) against
    ``FUV_MAG_A4`` (y) for sources passing the quality cuts of that catalog's
    band. Each panel also has a 1:1 reference line and a box with the number
    of filtered sources (N) and their Pearson correlation (rho).

    Parameters
    ----------
    tbl_nd : pandas.DataFrame
        NUV catalog; filtered on its ``NUV_*`` quality columns.
    tbl_fd : pandas.DataFrame
        FUV catalog; filtered on its ``FUV_*`` quality columns.
    counts2mag : callable
        ``counts2mag(cps, band)`` converting count rates to magnitudes.

    Returns
    -------
    matplotlib.figure.Figure
        The two-panel figure.
    """
    mag_range = [[15, 22], [15, 25]]  # [[NUV lo, hi], [FUV lo, hi]] plot window
    diag_x = np.linspace(15, 22, 100)

    def get_clean_indices(tbl, band_prefix):
        # Finite sigma in (0.2, 10), no edge/ghost/hotspot flags, mag < 25.
        return np.where(
            np.isfinite(tbl[f'{band_prefix}_MDL_SIGMA'])
            & (tbl[f'{band_prefix}_MDL_SIGMA'] < 10)
            & (tbl[f'{band_prefix}_MDL_SIGMA'] > 0.2)
            & (tbl[f'{band_prefix}_SOFTEDGE_FLAG_A6'] == 0)
            & (tbl[f'{band_prefix}_HARDEDGE_FLAG_A6'] == 0)
            & (tbl[f'{band_prefix}_GHOST_FLAG_A6'] == 0)
            & (tbl[f'{band_prefix}_HOTSPOT_FLAG_A6'] == 0)
            & (counts2mag(tbl[f'{band_prefix}_MDL_CPS'].values, band_prefix) < 25)
        )

    def add_stats_box(ax, x_data, y_data):
        # N counts every quality-filtered source. rho uses only pairs where both
        # magnitudes are finite: a single NaN would make np.corrcoef return NaN.
        x = np.asarray(x_data, dtype=float)
        y = np.asarray(y_data, dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        if finite.sum() >= 2:
            corr_coeff = np.corrcoef(x[finite], y[finite])[0, 1]
        else:
            corr_coeff = np.nan
        ax.text(0.95, 0.05, f'N = {len(x):,}\nρ = {corr_coeff:.3f}',
                transform=ax.transAxes, fontsize=10,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8),
                verticalalignment='bottom', horizontalalignment='right')

    # (table, clean indices, colormap, panel title) for each panel.
    panels = [
        (tbl_nd, get_clean_indices(tbl_nd, 'NUV'), 'viridis',
         'NUV-Selected Sources\n(Quality-Filtered)'),
        (tbl_fd, get_clean_indices(tbl_fd, 'FUV'), 'plasma',
         'FUV-Selected Sources\n(Quality-Filtered)'),
    ]

    # Use the default style for this figure only. plt.style.use() would
    # permanently overwrite the caller's global rcParams.
    with plt.style.context('default'):
        fig, axes = plt.subplots(1, 2, figsize=(16, 7))
        fig.suptitle('GALEX UV Magnitude Correlation Analysis', fontsize=16, fontweight='bold')

        for ax, (tbl, ix, cmap, title) in zip(axes, panels):
            nuv_mag = tbl['NUV_MAG_A4'].iloc[ix]
            fuv_mag = tbl['FUV_MAG_A4'].iloc[ix]
            h = ax.hist2d(
                nuv_mag,
                fuv_mag,
                bins=120,
                range=mag_range,
                cmap=cmap,  # perceptually uniform, unlike 'Greys'
                norm=LogNorm(vmin=0.5, vmax=None),
                alpha=0.8,
            )
            ax.plot(diag_x, diag_x, 'r--', alpha=0.6, linewidth=2, label='1:1 Line')

            ax.set_xlabel('NUV Magnitude (AB)', fontsize=12, fontweight='bold')
            ax.set_ylabel('FUV Magnitude (AB)', fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3, linestyle=':')
            ax.legend(loc='upper left')

            cbar = fig.colorbar(h[3], ax=ax, shrink=0.8)
            cbar.set_label('Source Count (log scale)', fontsize=11, fontweight='bold')

            add_stats_box(ax, nuv_mag, fuv_mag)

        fig.tight_layout()
        fig.subplots_adjust(top=0.9)  # leave room for the suptitle

    return fig

# Alternative: Single plot with subsampling for better performance
def create_combined_enhanced_plot(tbl_nd, tbl_fd, counts2mag, subsample_factor=1, random_state=None):
    """
    Plot NUV vs FUV magnitude for both catalogs pooled into one 2D histogram.

    Each catalog is quality-filtered on its own band. The surviving
    ``NUV_MAG_A4`` / ``FUV_MAG_A4`` pairs are concatenated, optionally
    subsampled, and binned. Reference lines are drawn at FUV = NUV and ±1 mag.

    Parameters
    ----------
    tbl_nd, tbl_fd : pandas.DataFrame
        NUV and FUV catalogs.
    counts2mag : callable
        ``counts2mag(cps, band)`` converting count rates to magnitudes.
    subsample_factor : int or float, default 1
        When > 1, keep a random ``len // subsample_factor`` of the pooled
        points, drawn without replacement.
    random_state : None, int, or numpy.random.Generator, default None
        Randomness for subsampling. ``None`` draws from the global
        ``np.random`` state, as before. Anything else is passed to
        ``np.random.default_rng`` for a reproducible draw.

    Returns
    -------
    matplotlib.figure.Figure
    """
    def clean_indices(tbl, band):
        # Finite sigma in (0.2, 10), no edge/ghost/hotspot flags, mag < 25.
        sigma = tbl[f'{band}_MDL_SIGMA']
        return np.where(
            np.isfinite(sigma)
            & (sigma < 10)
            & (sigma > 0.2)
            & (tbl[f'{band}_SOFTEDGE_FLAG_A6'] == 0)
            & (tbl[f'{band}_HARDEDGE_FLAG_A6'] == 0)
            & (tbl[f'{band}_GHOST_FLAG_A6'] == 0)
            & (tbl[f'{band}_HOTSPOT_FLAG_A6'] == 0)
            & (counts2mag(tbl[f'{band}_MDL_CPS'].values, band) < 25)
        )

    ix_nuv = clean_indices(tbl_nd, 'NUV')
    ix_fuv = clean_indices(tbl_fd, 'FUV')

    # NOTE(review): a source present in both catalogs contributes one point
    # from each here. Confirm whether the catalogs overlap before reading
    # the density as a count of distinct sources.
    combined_x = np.concatenate([tbl_nd['NUV_MAG_A4'].iloc[ix_nuv],
                                 tbl_fd['NUV_MAG_A4'].iloc[ix_fuv]]).astype(float, copy=False)
    combined_y = np.concatenate([tbl_nd['FUV_MAG_A4'].iloc[ix_nuv],
                                 tbl_fd['FUV_MAG_A4'].iloc[ix_fuv]]).astype(float, copy=False)

    if subsample_factor > 1:
        # The np.random module and a Generator both expose choice(a, size, replace).
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        # int(): a fractional factor (e.g. 1.5) would otherwise give a float size.
        n_keep = int(len(combined_x) // subsample_factor)
        keep = rng.choice(len(combined_x), size=n_keep, replace=False)
        combined_x = combined_x[keep]
        combined_y = combined_y[keep]

    fig, ax = plt.subplots(figsize=(10, 8))
    h = ax.hist2d(combined_x, combined_y,
                  bins=150,
                  range=[[15, 22], [15, 25]],
                  cmap='magma',
                  norm=LogNorm(vmin=1),
                  alpha=0.8)

    diag_x = np.linspace(15, 22, 100)
    ax.plot(diag_x, diag_x, 'cyan', linestyle='--', alpha=0.8, linewidth=2, label='NUV = FUV')
    ax.plot(diag_x, diag_x + 1, 'lime', linestyle=':', alpha=0.6, linewidth=1.5, label='FUV = NUV + 1')
    ax.plot(diag_x, diag_x - 1, 'yellow', linestyle=':', alpha=0.6, linewidth=1.5, label='FUV = NUV - 1')

    ax.set_xlabel('NUV Magnitude (AB)', fontsize=14, fontweight='bold')
    ax.set_ylabel('FUV Magnitude (AB)', fontsize=14, fontweight='bold')
    # Magnitude vs magnitude, not colour vs magnitude, so not a "CMD".
    ax.set_title('GALEX NUV vs FUV Magnitude\n(Combined NUV & FUV Selected Sources)',
                 fontsize=16, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle=':')
    ax.legend(loc='upper left', fontsize=11)

    cbar = fig.colorbar(h[3], ax=ax, shrink=0.8)
    cbar.set_label('Source Density (log scale)', fontsize=12, fontweight='bold')

    # Correlation over finite pairs only: any NaN would make np.corrcoef NaN.
    finite = np.isfinite(combined_x) & np.isfinite(combined_y)
    if finite.sum() >= 2:
        corr_coeff = np.corrcoef(combined_x[finite], combined_y[finite])[0, 1]
    else:
        corr_coeff = np.nan
    stats_text = f'Total Sources: {len(combined_x):,}\nCorrelation: ρ = {corr_coeff:.3f}'
    ax.text(0.05, 0.95, stats_text, transform=ax.transAxes, fontsize=12,
            bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.9),
            verticalalignment='top')

    fig.tight_layout()
    return fig

# Render the two-panel NUV/FUV comparison. For a single pooled panel, call
# create_combined_enhanced_plot(tbl_nd, tbl_fd, counts2mag) instead.
fig1 = create_enhanced_uv_plots(tbl_nd, tbl_fd, counts2mag)
plt.show()