In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the real (ASPCAP) and synthetic test spectra, then the wavelength array
# that both are plotted against below (read from a comma-delimited text file).
aspcap_spectra, synspec_spectra = (
    np.load(path) for path in ('../aspcap_spectra_test.npy', '../synspec_test.npy')
)
wavelengths = np.loadtxt('../apogee_wavelength_sol.csv', delimiter=',')

# Apply the project-wide Matplotlib style sheet, then force a sans-serif font.
plt.style.use('../../utils/mystyle.mplstyle')
plt.rcParams['font.family'] = 'sans-serif'
# Helvetica is not shipped with Matplotlib and is absent on many systems. List
# fallbacks (in order of preference) so a sans-serif face can still be resolved
# instead of relying on Helvetica being installed.
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Arial', 'DejaVu Sans']
# Pick the spectrum pair to display. The index is fixed for reproducibility;
# use `np.random.randint(0, len(aspcap_spectra))` to draw a random one instead.
random_index = 23161
aspcap_spectrum, synspec_spectrum = aspcap_spectra[random_index], synspec_spectra[random_index]

# Line colours: color1 (warm coral orange) is used for the synthetic spectrum,
# color2 (rich royal blue) for the real one.
color1, color2 = '#FF8C38', '#4B7BE5'

# Two equally tall, vertically stacked panels: full range on top, zoom below.
fig = plt.figure(figsize=(8, 5))
gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.3)

# Top subplot - full wavelength range
ax1 = fig.add_subplot(gs[0])

# The wavelength array jumps across two detector gaps. Plot each detector's
# pixels as its own line so no segment is drawn across a gap.
gap_indices = [3027, 5522]  # Index of the last pixel before each gap
split_points = [0] + [i + 1 for i in gap_indices] + [len(wavelengths)]

for i, (start, end) in enumerate(zip(split_points[:-1], split_points[1:])):
    section_wl = wavelengths[start:end]
    nan_positions = np.flatnonzero(np.isnan(aspcap_spectrum[start:end]))
    # Shade each NaN pixel of the real spectrum with a faint grey band spanning
    # +/- (local pixel step / 2.5). The step is measured *within* this section:
    # the section's first pixel uses the forward step, because its backward step
    # would reach across the detector gap and shade a huge span (and global
    # pixel 0 has no backward neighbour at all).
    if nan_positions.size and section_wl.size > 1:
        steps = np.diff(section_wl)
        steps = np.concatenate(([steps[0]], steps))
        for centre, half_width in zip(section_wl[nan_positions], steps[nan_positions] / 2.5):
            ax1.axvspan(centre - half_width, centre + half_width, color='gray', alpha=0.05)
    # Only the first section carries legend labels; empty labels are not listed.
    ax1.plot(section_wl, aspcap_spectrum[start:end], color=color2,
             label='Real' if i == 0 else "", lw=1)
    ax1.plot(section_wl, synspec_spectrum[start:end], color=color1,
             label='Synthetic' if i == 0 else "", lw=1, alpha=0.9, linestyle='-')

# Label each detector gap. This runs after all sections are plotted so every
# label uses the same, final autoscaled y-limits (querying them inside the loop
# placed the labels at different heights). Gap centres are derived from the
# wavelength array at `gap_indices` rather than hard-coded wavelengths that can
# drift out of sync with the split points.
y_min, y_max = ax1.get_ylim()
y_center = y_max - (y_max - y_min) / 3 - 0.1
for gap_index in gap_indices:
    # The +1 Angstrom offset is kept from the original hand-tuned placement.
    gap_wavelength = (wavelengths[gap_index] + wavelengths[gap_index + 1]) / 2 + 1
    ax1.text(gap_wavelength, y_center, 'Detector Gap', rotation=90,
             verticalalignment='center', horizontalalignment='center',
             fontsize=12, alpha=0.5)

# Set x-limits to 1% beyond min and max wavelengths
x_min = wavelengths.min()
x_max = wavelengths.max()
x_padding = (x_max - x_min) * 0.01
ax1.set_xlim(x_min - x_padding, x_max + x_padding)

# Choose zoom region (centered)
zoom_center = 16300  # Angstroms
zoom_width = 50  # Angstroms
zoom_lo = zoom_center - zoom_width / 2
zoom_hi = zoom_center + zoom_width / 2
zoom_mask = (wavelengths > zoom_lo) & (wavelengths < zoom_hi)

# Bottom subplot - zoomed
ax2 = fig.add_subplot(gs[1])

# Shade each NaN pixel of the real spectrum inside the zoom window. Only the
# first band is labelled so the legend gets a single 'NaN' entry.
zoom_indices = np.flatnonzero(zoom_mask)
nan_indices = zoom_indices[np.isnan(aspcap_spectrum[zoom_indices])]
for n, idx in enumerate(nan_indices):
    # Band half-width is the local pixel step / 2.5. Use the backward step
    # unless the pixel is the first of the array or the first after a detector
    # gap, where the backward step is missing or spans the gap; use the forward
    # step there (previously pixel 0 was skipped entirely).
    if idx > 0 and idx - 1 not in gap_indices:
        step = wavelengths[idx] - wavelengths[idx - 1]
    else:
        step = wavelengths[idx + 1] - wavelengths[idx]
    width = step / 2.5
    ax2.axvspan(wavelengths[idx] - width, wavelengths[idx] + width,
                color='gray', alpha=0.05, label='NaN' if n == 0 else "")

ax2.plot(wavelengths[zoom_mask], aspcap_spectrum[zoom_mask], color=color2, lw=1.7, label='Real')
ax2.plot(wavelengths[zoom_mask], synspec_spectrum[zoom_mask], color=color1, lw=2.2,
         label='Synthetic', linestyle='--')
ax2.set_xlabel('Wavelength (Å)')
ax2.legend(frameon=False, loc='lower center')
# Set x-limits for zoom plot to exactly the selected window
ax2.set_xlim(zoom_lo, zoom_hi)

# Shared y-axis label for both panels.
fig.supylabel('Normalized Flux', x=0.02)

# Outline the zoom window on the top panel and connect it to the bottom panel
# (loc1=1 is the upper-right corner, loc2=2 the upper-left corner).
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
mark_inset(ax1, ax2, loc1=1, loc2=2, fc="none", ec="0.5")

# Super title
fig.suptitle('Real and Synthetic APOGEE Spectra', y=0.94)
# Annotate the total pixel count in the bottom-left corner of the top panel
# (axes-fraction coordinates).
ax1.text(0.005, 0.02, f'Total Pixels: {len(wavelengths)}', transform=ax1.transAxes,
         verticalalignment='bottom', horizontalalignment='left', fontsize=12, alpha=0.5)

# Optional save. Set save_path (e.g. 'syn_vs_real.png') to write the figure.
# Keep the save ahead of plt.show(): on some backends show() closes the figure,
# so saving afterwards is unreliable. The default (None) only displays the figure.
save_path = None
if save_path is not None:
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()

