In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from tabulate import tabulate
import pandas as pd
import os
import glob
import re
import sys
import pickle
import shutil

from photutils.aperture import CircularAperture, aperture_photometry
from spectral_cube import SpectralCube

from astropy.time import Time
from astropy.coordinates import SkyCoord
from astropy.table import Table
import astropy.units as u
from astropy.wcs import WCS
from astropy.constants import c
from astropy.io import fits
from astropy.visualization import simple_norm, imshow_norm
from astropy.visualization import AsinhStretch
from astropy.visualization.mpl_normalize import ImageNormalize
from astropy.visualization import SqrtStretch
from matplotlib.patches import Circle, Rectangle
from matplotlib.gridspec import GridSpec
from astropy.nddata import block_reduce
from astropy.nddata import Cutout2D
from astropy.stats import mad_std



# NOTE(review): machine-specific absolute path. Set the RESEARCH_HOME environment variable to
# run elsewhere; when it is unset the original path is used, so behavior is unchanged.
home_directory = os.environ.get("RESEARCH_HOME", "/d/ret1/Taylor/jupyter_notebooks/Research")
os.chdir(home_directory) #TJ change working directory to be the parent directory

# NOTE(review): star import hides provenance. Names used below but not defined in this cell
# (generate_list_of_files, filter_directory, image_directory, extract_filter_name, ...) are
# presumably supplied by this module -- confirm.
from Py_files.Functions import * #TJ import functions from custom package

# Project-owned pickles (trusted input; never pickle.load files from untrusted sources).
with open("Data_files/misc_data/jwst_pivots.pkl", "rb") as file:
    jwst_pivots = pickle.load(file)
with open("Data_files/misc_data/jwst_filter_means.pkl", "rb") as file:
    jwst_means = pickle.load(file)

image_files, filter_files = generate_list_of_files(filter_directory, image_directory)
full_raw_ifu_files_loc0 = ['Data_files/IFU_files/raw_IFUs/location_0/jw03435-o004_t005_nirspec_g140m-f100lp_s3d_trimmed.fits',
              'Data_files/IFU_files/raw_IFUs/location_0/jw03435-o004_t005_nirspec_g235m-f170lp_s3d_trimmed.fits',
              'Data_files/IFU_files/raw_IFUs/location_0/jw03435-o004_t005_nirspec_g395m-f290lp_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_0/Arm1_Level3_ch1-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_0/Arm1_Level3_ch2-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_0/Arm1_Level3_ch3-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_0/Arm1_Level3_ch4-shortmediumlong_s3d_trimmed.fits']
full_raw_ifu_files_loc1 = ['Data_files/IFU_files/raw_IFUs/location_1/jw03435-o012_t014_nirspec_g140m-f100lp_s3d_trimmed.fits',
              'Data_files/IFU_files/raw_IFUs/location_1/jw03435-o012_t014_nirspec_g235m-f170lp_s3d_trimmed.fits',
              'Data_files/IFU_files/raw_IFUs/location_1/jw03435-o012_t014_nirspec_g395m-f290lp_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_1/Arm2_Level3_ch1-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_1/Arm2_Level3_ch2-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_1/Arm2_Level3_ch3-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_1/Arm2_Level3_ch4-shortmediumlong_s3d_trimmed.fits']
#TJ location 2 also within loc1 files
full_raw_ifu_files_loc3 = ['Data_files/IFU_files/raw_IFUs/location_3/jw03435-o006_t010_nirspec_g140m-f100lp_s3d_trimmed.fits',
              'Data_files/IFU_files/raw_IFUs/location_3/jw03435-o006_t010_nirspec_g235m-f170lp_s3d_trimmed.fits',
              'Data_files/IFU_files/raw_IFUs/location_3/jw03435-o006_t010_nirspec_g395m-f290lp_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_3/Arm3_Level3_ch1-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_3/Arm3_Level3_ch2-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_3/Arm3_Level3_ch3-shortmediumlong_s3d.fits',
              'Data_files/IFU_files/raw_IFUs/location_3/Arm3_Level3_ch4-shortmediumlong_s3d_trimmed.fits']
#TJ this full_spectrum is built by stitching cube0 to cube1 anchored to cube1
#TJ then stitching cube2 to cube3 anchored to cube3, then stitching all the others unaltered
full_spec = 'Data_files/misc_data/Updated_flux_calibration/full_spectrum_loc0_rad1p25.npy'
# [RA, Dec] in degrees of the four pointings; create_data maps these, in this order, to loc_idx 0-3.
locations = [[202.5062429, 47.2143358], [202.4335225, 47.1729608], [202.4340450, 47.1732517], [202.4823742, 47.1958589]]
radius = 1.25*u.arcsec
# NOTE(review): presumably per-location aperture radii (r0 equals `radius`) -- confirm.
r0 = 1.25*u.arcsec
r1 = 0.9*u.arcsec
r2 = 1*u.arcsec
r3 = 1.1*u.arcsec
filter_names = [extract_filter_name(x) for x in filter_files]
# NOTE(review): get_largest_filter_within is defined in a LATER cell of this notebook. Unless
# Py_files.Functions also exports it, this loop raises NameError on Restart & Run All;
# consider moving this loop below that definition.
anchor_filters=[]
for file in full_raw_ifu_files_loc0:
    l = get_largest_filter_within(file)
    if l is not None:
        anchor_filters.append(l)

In [None]:
# NOTE(review): this cell contained only the bare name `Ip25`. No visible cell defines it; it
# could only come from the star import. On a fresh kernel it raised NameError and halted
# "Run All". If it was meant to display a value, restore it here under its real name.

In [None]:
def show_images(image_files, loc, radius, ncols=3, cmap='viridis'):
    """
    Create a collage of cutout images with an aperture overlay.

    Each panel is a square cutout of side ``3 * radius`` centred on ``loc``,
    displayed with an asinh stretch from 0 to the 99th percentile of the
    finite cutout pixels, with a red circle of radius ``radius`` at ``loc``.

    Parameters
    ----------
    image_files : list of str
        List of FITS image file paths (each must contain a SCI extension).
    loc : list, tuple, or SkyCoord
        Location of aperture center, either [RA, Dec] in degrees or a SkyCoord object.
    radius : Quantity
        Aperture radius (must have angular units, e.g. arcsec).
    ncols : int, optional
        Number of columns in the collage (default = 3).
    cmap : str, optional
        Colormap for displaying images (default = 'viridis').
    """
    # Derives the pixel scale from the full pixel-scale matrix, so CD-matrix headers
    # (whose CDELT is 1) are handled as well as PC+CDELT ones.
    from astropy.wcs.utils import proj_plane_pixel_scales

    # Make sure loc is SkyCoord
    if not isinstance(loc, SkyCoord):
        loc_sky = SkyCoord(ra=loc[0]*u.deg, dec=loc[1]*u.deg, frame='icrs')
    else:
        loc_sky = loc

    n_images = len(image_files)
    if n_images == 0:
        # plt.subplots(0, ...) would raise; nothing to draw.
        print('show_images: no image files given, nothing to plot')
        return
    nrows = int(np.ceil(n_images / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows),
                             subplot_kw={'projection': None})
    axes = np.atleast_1d(axes).ravel()  # Flatten in case of 1 row/col

    for ax, image_file in zip(axes, image_files):
        # fits.getdata closes the file; fits.open(...)['SCI'] leaked one handle per image.
        image, header = fits.getdata(image_file, 'SCI', header=True)
        wcs = WCS(header, naxis=2)
        pixel_scale = proj_plane_pixel_scales(wcs)[0] * 3600  # arcsec/pixel (assumes degree axes)

        # Make cutout
        cutout = Cutout2D(image, position=loc_sky, size=(radius*3, radius*3), wcs=wcs)

        # Convert SkyCoord -> pixel coords
        x_img, y_img = cutout.wcs.world_to_pixel(loc_sky)

        # nanpercentile: NaN pixels made np.percentile return NaN, leaving the stretch undefined.
        vmax = np.nanpercentile(cutout.data, 99)
        im = ax.imshow(cutout.data, origin='lower', cmap=cmap,
                       norm=ImageNormalize(cutout.data, stretch=AsinhStretch(),
                                           vmin=0, vmax=vmax))
        ax.add_patch(Circle((x_img, y_img),
                            (radius.to(u.arcsec).value)/pixel_scale,
                            ec='red', fc='none', lw=2, alpha=0.7))
        cbar = plt.colorbar(
            im,
            ax=ax,
            fraction=0.046,
            pad=0.04
        )
        cbar.set_label("Flux (native units)", fontsize=10)
        ax.set_title(image_file.split("/")[-1], fontsize=12)
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide empty panels if n_images doesn’t fill full grid
    for ax in axes[n_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def _interp_rows(x_new, x, rows):
    """Linearly interpolate every row of ``rows`` (sampled at ascending ``x``) onto ``x_new``.

    ``rows`` has shape (n_rows, len(x)); the result has shape (n_rows, len(x_new)).
    ``x_new`` must lie within [x[0], x[-1]] (no extrapolation) and ``x`` must not
    contain repeated values.
    """
    if x.size == 1:
        return np.repeat(rows[:, :1], x_new.size, axis=1)
    right = np.clip(np.searchsorted(x, x_new), 1, x.size - 1)
    left = right - 1
    weight = (x_new - x[left]) / (x[right] - x[left])
    return rows[:, left] * (1.0 - weight) + rows[:, right] * weight


def _stitch_spaxel_spectra(wl1, spec1, wl2, spec2):
    """Combine two sets of per-spaxel spectra onto one sorted wavelength axis.

    ``spec1`` is (n_spaxel, len(wl1)) and ``spec2`` is (n_spaxel, len(wl2)); row i of
    both describes the same spaxel. Samples outside the shared wavelength range are
    kept unchanged. Inside it, the grid of the input with the smaller mean sample
    spacing is used, and its values are averaged with the other input linearly
    interpolated onto that grid. Returns (wavelength, spectra).
    """
    order1, order2 = np.argsort(wl1), np.argsort(wl2)
    wl1, spec1 = wl1[order1], spec1[:, order1]
    wl2, spec2 = wl2[order2], spec2[:, order2]

    def mean_spacing(wl):
        # A single sample has no spacing; treat it as the coarsest grid.
        return np.mean(np.diff(wl)) if wl.size > 1 else np.inf

    if mean_spacing(wl1) < mean_spacing(wl2):
        wl_hi, sp_hi, wl_lo, sp_lo = wl1, spec1, wl2, spec2
    else:
        wl_hi, sp_hi, wl_lo, sp_lo = wl2, spec2, wl1, spec1

    start = max(wl_hi[0], wl_lo[0])
    end = min(wl_hi[-1], wl_lo[-1])

    in_overlap = (wl_hi >= start) & (wl_hi <= end)
    wl_overlap = wl_hi[in_overlap]
    sp_overlap = 0.5 * (sp_hi[:, in_overlap] + _interp_rows(wl_overlap, wl_lo, sp_lo))

    # Outside [start, end] only one input has samples (or, with no overlap, each sample
    # is outside), so every such sample is kept once.
    keep_hi = (wl_hi < start) | (wl_hi > end)
    keep_lo = (wl_lo < start) | (wl_lo > end)

    wl_all = np.concatenate([wl_hi[keep_hi], wl_overlap, wl_lo[keep_lo]])
    spec_all = np.concatenate([sp_hi[:, keep_hi], sp_overlap, sp_lo[:, keep_lo]], axis=1)
    order = np.argsort(wl_all, kind='stable')
    return wl_all[order], spec_all[:, order]


def show_image_and_synth(filter, ifu_fileset, loc, radius, image_files=v0p3_images, color_min_max = [1, 99.5]):
    """
    Show a real image and a synthetic IFU-derived image side by side.

    The synthetic image integrates every spaxel's spectrum through the filter's
    transmission curve (get_Fnu_transmission). If a single cube in ``ifu_fileset``
    spans the whole filter it is used alone. Otherwise the first two cubes that
    overlap the filter are stitched per spaxel:
      - the second cube is mapped onto the first cube's spatial grid (nearest spaxel);
      - in their shared wavelength range, the finer-sampled cube's grid is used and
        averaged with the other cube interpolated onto it.

    Parameters
    ----------
    filter : str
        Filter name, e.g. "F115W".
    ifu_fileset : list of str
        IFU cube paths (with a SCI extension), searched in the given order.
    loc : list, tuple, or SkyCoord
        Aperture centre, either [RA, Dec] in degrees or a SkyCoord object.
    radius : Quantity
        Aperture radius with angular units; cutouts are ``3 * radius`` on a side.
    image_files : list of str, optional
        Image paths; the first one whose filter matches ``filter`` is shown.
    color_min_max : sequence of two floats, optional
        Lower/upper percentiles of the combined finite pixels of both cutouts,
        used for the shared colour scale.

    Raises
    ------
    ValueError
        If no image matches ``filter``, no IFU cube overlaps it, or neither cutout
        has finite pixels.
    """

    # ------------------------------------------------------------
    # Location handling
    # ------------------------------------------------------------
    if not isinstance(loc, SkyCoord):
        loc_sky = SkyCoord(ra=loc[0] * u.deg, dec=loc[1] * u.deg, frame="icrs")
    else:
        loc_sky = loc

    def nearest_spaxel_map(cube_src, cube_target):
        # For each target spaxel, the (y, x) indices of the nearest source spaxel (clipped to the cube).
        ny, nx = cube_target.shape[1:]
        y_t, x_t = np.mgrid[:ny, :nx]

        world = cube_target.wcs.celestial.pixel_to_world(x_t, y_t)
        x_s, y_s = cube_src.wcs.celestial.world_to_pixel(world)

        x_s = np.clip(np.round(x_s).astype(int), 0, cube_src.shape[2] - 1)
        y_s = np.clip(np.round(y_s).astype(int), 0, cube_src.shape[1] - 1)

        return y_s, x_s

    # ------------------------------------------------------------
    # Locate real image and the IFU cube(s) covering the filter
    # ------------------------------------------------------------
    matching_images = [x for x in image_files if extract_filter_name(x) == filter]
    if not matching_images:
        raise ValueError(f"No image in image_files matches filter {filter}")
    real_image_file = matching_images[0]

    short_wl, long_wl = [x.value for x in get_filter_wl_range(filter)]

    # Keep the cube objects so each file is read only once.
    needed_cubes = []
    for ifu_path in ifu_fileset:
        cube = SpectralCube.read(ifu_path, hdu="SCI")
        wl = cube.spectral_axis.to(u.m).value
        if (wl[0] < short_wl) and (wl[-1] > long_wl):
            needed_cubes = [cube]
            break
        if (long_wl > wl[0]) and (short_wl < wl[-1]):
            needed_cubes.append(cube)
        if len(needed_cubes) > 1:
            break
    if not needed_cubes:
        raise ValueError(f"No IFU cube in ifu_fileset overlaps filter {filter}")

    cube1 = needed_cubes[0]
    cube2 = needed_cubes[1] if len(needed_cubes) > 1 else None

    # ------------------------------------------------------------
    # Per-spaxel spectra, shape (n_pix, n_wavelength)
    # ------------------------------------------------------------
    wl1 = cube1.spectral_axis.to(u.m).value
    ny, nx = cube1.shape[1:]
    n_pix = ny * nx
    d1 = cube1.unmasked_data[:].value.reshape(len(wl1), n_pix).T

    if cube2 is not None:
        wl2 = cube2.spectral_axis.to(u.m).value
        y2, x2 = nearest_spaxel_map(cube2, cube1)
        d2 = cube2.unmasked_data[:].value[:, y2, x2].reshape(len(wl2), n_pix).T
        # Average in the overlap instead of interleaving samples (the old code halved
        # only exactly-shared wavelengths, which corrupted those points).
        wl_all, spec_all = _stitch_spaxel_spectra(wl1, d1, wl2, d2)
    else:
        wl_all = wl1
        spec_all = d1

    # ------------------------------------------------------------
    # Synthetic photometry
    # ------------------------------------------------------------
    filter_wl, filter_trans = get_filter_data(filter)

    image = np.empty(n_pix)
    for i in range(n_pix):
        image[i] = get_Fnu_transmission(
            spec_all[i], wl_all, filter_trans, filter_wl, warnings=True
        )

    synth_image = image.reshape(ny, nx)

    # Attach cube1's celestial WCS to the synthetic image
    synth_hdu = fits.PrimaryHDU(
        synth_image,
        header=cube1.wcs.celestial.to_header()
    )

    # ------------------------------------------------------------
    # Cutouts (real file is closed by getdata)
    # ------------------------------------------------------------
    real_data, real_header = fits.getdata(real_image_file, "SCI", header=True)
    real_pix_size = real_header['PIXAR_A2']**0.5
    aperture_radius = radius.to(u.arcsec).value / real_pix_size
    cutout_real = Cutout2D(
        real_data,
        position=loc_sky,
        size=(radius * 3, radius * 3),
        wcs=WCS(real_header)
    )

    cutout_synth = Cutout2D(
        synth_hdu.data,
        position=loc_sky,
        size=(radius * 3, radius * 3),
        wcs=WCS(synth_hdu.header)
    )

    # ------------------------------------------------------------
    # Shared normalization from the combined finite pixels
    # ------------------------------------------------------------
    combined = np.concatenate([
        cutout_real.data[np.isfinite(cutout_real.data)],
        cutout_synth.data[np.isfinite(cutout_synth.data)]
    ])
    if combined.size == 0:
        raise ValueError("Neither cutout contains finite pixels")

    vmin = np.percentile(combined, color_min_max[0])
    vmax = np.percentile(combined, color_min_max[1])

    cmap = plt.get_cmap("viridis").copy()
    cmap.set_under("black")
    cmap.set_over("white")
    norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    pix_scale = np.abs(cutout_synth.wcs.wcs.cdelt[0]) * 3600
    r_ap_pix = radius.to(u.arcsec).value / pix_scale

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # REAL
    im0 = axes[0].imshow(
        cutout_real.data,
        origin="lower",
        cmap=cmap,
        norm=norm
    )
    axes[0].set_title(f"{filter} – Real")

    cbar0 = plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
    cbar0.set_label("Flux", fontsize=14)
    cbar0.ax.tick_params(labelsize=12)

    x_r, y_r = cutout_real.wcs.world_to_pixel(loc_sky)
    axes[0].add_patch(
        Circle((x_r, y_r), aperture_radius, edgecolor="red", facecolor="none", linewidth=2)
    )

    # SYNTHETIC
    im1 = axes[1].imshow(
        cutout_synth.data,
        origin="lower",
        cmap=cmap,
        norm=norm
    )
    axes[1].set_title(f"{filter} – Synthetic (IFU)")

    cbar1 = plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    cbar1.set_label("Flux", fontsize=14)
    cbar1.ax.tick_params(labelsize=12)

    x_s, y_s = cutout_synth.wcs.world_to_pixel(loc_sky)
    axes[1].add_patch(
        Circle((x_s, y_s), r_ap_pix, edgecolor="red", facecolor="none", linewidth=2)
    )

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.show()


def is_filter_relevent(filter, ifu_file):
    '''Tell whether a filter's mean wavelength lies strictly inside an IFU cube's spectral range.
    -------------

    Parameters
    -------------
    filter : type = string - filter name used as a key into jwst_means (e.g. "F115W")
    ifu_file : type = string - path to an IFU cube with a SCI extension

    Returns
    -------------
    Truth value of: first spectral sample < mean wavelength < last spectral sample
    (spectral axis converted to meters; assumes the axis is ascending).
    '''
    spectral_axis = SpectralCube.read(ifu_file, hdu = 'SCI').spectral_axis.to(u.m)
    blue_end = spectral_axis[0]
    red_end = spectral_axis[-1]
    mean_wavelength = jwst_means[filter]
    above_blue = mean_wavelength > blue_end
    below_red = mean_wavelength < red_end
    return above_blue & below_red
    
def adjust_spectrum(original_ifu, filter_name, image_files, location, radius, adjustment_operation = 'add'):
    '''Extract an IFU aperture spectrum and rescale it to match image photometry in one filter.
    -------------

    Parameters
    -------------
    original_ifu : type = string - path to the IFU file
    filter_name : type = string or None - filter name like "F115W". None means no anchor filter:
        the spectrum is returned unadjusted.
    image_files : type = list of strings - image paths; the one whose filter matches filter_name is used
    location : type = either SkyCoord or list of [ra, dec] values in degrees - location of center of aperture
    radius : type = angular size - radius of aperture, must have units attached.
    adjustment_operation (optional, defaults to 'add'): type = string - either 'add' or 'multiply'
        'add'      : intensity += (image flux - synthetic flux)
        'multiply' : intensity *= (image flux / synthetic flux)

    Returns
    -------------
    (spectrum, correction): spectrum is the structured array from get_IFU_spectrum
    ('wavelength' and 'intensity' keys) after the adjustment, correction is the value applied.
    With filter_name None the correction is the operation's identity (0 for 'add', 1 for 'multiply').

    Raises
    -------------
    ValueError if adjustment_operation is not 'add' or 'multiply' (previously printed a message
    and returned None, which crashed callers that unpack the result).
    '''
    # Identity element of each supported correction, reported when nothing is adjusted.
    identity_correction = {'add': 0, 'multiply': 1}
    if adjustment_operation not in identity_correction:
        raise ValueError(
            f'adjustment operation {adjustment_operation!r} not recognized, '
            'only "add" or "multiply" are currently implemented')

    if filter_name is None:
        raw_data = get_IFU_spectrum(original_ifu, location, radius, replace_negatives = False)
        return raw_data, identity_correction[adjustment_operation]

    image_file = [x for x in image_files if extract_filter_name(x)==filter_name][0]
    raw_data = get_IFU_spectrum(original_ifu, location, radius, replace_negatives = False)
    filter_wl, filter_trans = get_filter_data(filter_name) #TJ this is the transmission vs wavelength function for this filter
    image_flux = get_image_flux(image_file, location, radius, replace_negatives = False) #TJ this is the flux we SHOULD get
    initial_synth_flux = get_Fnu_transmission(raw_data['intensity'], raw_data['wavelength'], filter_trans, filter_wl, warnings = True) #TJ this is the current synthetic flux we get

    if adjustment_operation == 'add':
        correction = image_flux - initial_synth_flux
        raw_data['intensity'] = raw_data['intensity'] + correction
    else:
        # NOTE(review): a zero synthetic flux gives an infinite/NaN correction here.
        correction = image_flux/initial_synth_flux
        raw_data['intensity'] = raw_data['intensity']*correction
    return raw_data, correction #TJ now corrected to match photometry


def get_largest_filter_within(ifu_file):
    '''Pick the widest filter whose bandpass lies entirely inside an IFU cube.
    -------------

    Parameters
    -------------
    ifu_file : type = string - path to the IFU file

    Returns
    -------------
    The filter name (e.g. "F115W") with the largest (long edge - short edge) from
    get_filter_wl_range among filters in the global filter_files for which
    full_coverage(name, ifu_file) == "good". Ties keep the earliest filter.
    Returns None (after printing a message) when no filter qualifies.
    '''
    contained = []
    for filter_path in filter_files:
        name = extract_filter_name(filter_path)
        if full_coverage(name, ifu_file) == "good":
            contained.append(name)

    if not contained:
        print(f'No filters entirely within {ifu_file}')
        return None

    def bandwidth(name):
        short_edge, long_edge = get_filter_wl_range(name)
        return long_edge.value - short_edge.value

    return max(contained, key=bandwidth)

def needed_datasets(filter_name, datasets):
    '''Select the spectra that overlap a filter's wavelength range, even slightly.
    -------------

    Parameters
    -------------
    filter_name : type = string - name of filter ("F115W")
    datasets : type = list of structured arrays - each with 'wavelength' and 'intensity' keys;
        only the first and last wavelengths are checked, so each is assumed sorted ascending

    Returns
    -------------
    List of the entries of datasets, in their original order, whose wavelength span overlaps
    the span of the filter's transmission-curve wavelengths (assumed ascending).
    '''
    filter_wl, _ = get_filter_data(filter_name)
    filter_blue, filter_red = filter_wl[0], filter_wl[-1]
    return [
        spectrum
        for spectrum in datasets
        if (filter_blue < spectrum['wavelength'][-1]) & (filter_red > spectrum['wavelength'][0])
    ]

def merge_datasets(ds1, ds2):
    """
    Merge two structured arrays with 'wavelength' and 'intensity' keys.

    In the shared wavelength range, the grid of the dataset with the smaller mean sample
    spacing (ties go to ``ds2``) is used. There, intensity is the average of that dataset
    and the other dataset linearly interpolated onto the grid. Every sample outside the
    shared range is kept from whichever dataset has it, regardless of argument order.
    If the ranges do not overlap, the result is simply both datasets combined.

    Returns the merged array sorted by wavelength. If either input is empty, the other
    is returned (sorted).
    """
    # Sort by wavelength, just to be safe
    ds1 = np.sort(ds1, order='wavelength')
    ds2 = np.sort(ds2, order='wavelength')
    if ds1.size == 0:
        return ds2
    if ds2.size == 0:
        return ds1

    def mean_spacing(ds):
        # A single sample has no spacing; treat it as the coarsest possible grid.
        return np.mean(np.diff(ds['wavelength'])) if ds.size > 1 else np.inf

    # Assign high- and low-resolution datasets
    if mean_spacing(ds1) < mean_spacing(ds2):
        highres, lowres = ds1, ds2
    else:
        highres, lowres = ds2, ds1

    # Determine overlap region (start > end means the ranges do not overlap)
    overlap_start = max(highres['wavelength'][0], lowres['wavelength'][0])
    overlap_end   = min(highres['wavelength'][-1], lowres['wavelength'][-1])

    # Average highres with lowres interpolated onto the highres grid inside the overlap.
    overlap_mask = (highres['wavelength'] >= overlap_start) & (highres['wavelength'] <= overlap_end)
    overlap_wl = highres['wavelength'][overlap_mask]
    interp_flux = np.interp(overlap_wl, lowres['wavelength'], lowres['intensity'])
    overlap = np.zeros(overlap_wl.size, dtype=highres.dtype)
    overlap['wavelength'] = overlap_wl
    overlap['intensity'] = 0.5 * (highres['intensity'][overlap_mask] + interp_flux)

    # Keep every sample of either input outside the overlap. The old version took only
    # ds1's blue wing and ds2's red wing, silently dropping data when ds1 was the redder set.
    def outside_overlap(ds):
        wl = ds['wavelength']
        return ds[(wl < overlap_start) | (wl > overlap_end)]

    merged = np.concatenate([outside_overlap(ds1), overlap, outside_overlap(ds2)])
    return np.sort(merged, order='wavelength')


def get_all_fluxes(filter_files, spec_datasets, image_files, location, radius):
    '''Compute image and synthetic fluxes for every filter, plus the fully merged spectrum.
    Datasets are merged with merge_datasets in order of increasing starting wavelength,
    so the input order (e.g. from glob) does not matter. For each filter, every dataset
    that overlaps it (needed_datasets) is merged before integrating.
    -------------

    Parameters
    -------------
    filter_files : type = list of strings - filter files; names come from extract_filter_name
    spec_datasets : type = list of structured arrays - arrays with keys for 'wavelength' and 'intensity'
    image_files : type = list of strings - strings to image files
    location : type = either SkyCoord or list of [ra, dec] values in degrees - location of center of aperture
    radius : type = angular size - radius of aperture, must have units attached.

    Returns
    -------------
    A dictionary of numpy arrays with keys:
    - 'filter_name', 'mean_wl', 'synth_flux', 'photo_flux': one entry per filter;
      synth_flux is NaN when no dataset overlaps the filter.
    - 'wavelength', 'intensity': the merge of all spec_datasets.

    Raises
    -------------
    ValueError if spec_datasets is empty.
    '''
    if len(spec_datasets) == 0:
        raise ValueError('spec_datasets is empty; there is no spectrum to merge')

    def blue_edge(data):
        return np.min(data['wavelength'])

    def merge_all(datasets):
        # Blue-to-red reduction keeps each merge's inputs in a consistent order.
        ordered = sorted(datasets, key=blue_edge)
        merged = np.sort(ordered[0], order='wavelength')
        for data in ordered[1:]:
            merged = merge_datasets(merged, data)
        return merged

    # Works for a single dataset too (the old loop left the result undefined then).
    full_spectrum = merge_all(spec_datasets)

    filter_names, mean_wls, synth_fluxes, photo_fluxes = [], [], [], []
    for filter_file in filter_files:
        filter_name = extract_filter_name(filter_file)
        image_file = [x for x in image_files if extract_filter_name(x)==filter_name][0]
        photo_fluxes.append(get_image_flux(image_file, location, radius, replace_negatives = False))
        filter_wl, filter_trans = get_filter_data(filter_name)
        filter_names.append(filter_name)
        mean_wls.append(jwst_means[filter_name].value)

        needed_data = needed_datasets(filter_name, spec_datasets)
        if len(needed_data) == 0:
            # Previously fell through to needed_data[0] and raised IndexError.
            print('no spectral data was found for ', filter_name)
            synth_fluxes.append(np.nan)
            continue
        if len(needed_data) == 1:
            spectrum = needed_data[0]
        else:
            # Merge every overlapping dataset, not only the first two.
            spectrum = merge_all(needed_data)
        synth_fluxes.append(get_Fnu_transmission(spectrum['intensity'], spectrum['wavelength'], filter_trans, filter_wl, warnings = True))

    return {
        'filter_name': np.array(filter_names),
        'mean_wl': np.array(mean_wls),
        'synth_flux': np.array(synth_fluxes),
        'photo_flux': np.array(photo_fluxes),
        'wavelength': np.array(full_spectrum['wavelength']),
        'intensity': np.array(full_spectrum['intensity']),
    }


def get_overlap_region(ds1, ds2):
    """
    Extract the wavelength range shared by two 'wavelength'/'intensity' structured arrays.

    Within the shared range the samples are taken from whichever input has the smaller
    mean wavelength spacing (ties go to ``ds2``). Each sample's intensity is the mean of
    that input's intensity and the other input's intensity linearly interpolated
    (np.interp) onto the sample. When the ranges do not overlap, or merely touch at a
    single wavelength, an empty record array is returned.
    """
    first = np.sort(ds1, order='wavelength')
    second = np.sort(ds2, order='wavelength')

    first_step = np.diff(first['wavelength']).mean()
    second_step = np.diff(second['wavelength']).mean()
    fine, coarse = (first, second) if first_step < second_step else (second, first)

    fine_wl = fine['wavelength']
    coarse_wl = coarse['wavelength']
    lower = max(fine_wl[0], coarse_wl[0])
    upper = min(fine_wl[-1], coarse_wl[-1])
    if lower >= upper:
        return np.recarray(0, dtype=[('wavelength', float), ('intensity', float)])

    inside = (fine_wl >= lower) & (fine_wl <= upper)
    sample_wl = fine_wl[inside]
    partner = np.interp(sample_wl, coarse_wl, coarse['intensity'])
    blended = 0.5 * (fine['intensity'][inside] + partner)
    return np.rec.fromarrays([sample_wl, blended], names=('wavelength', 'intensity'))




def create_data(ifu_files, image_files, filter_files, loc, radius, anchor_filters = anchor_filters):
    '''Photometry-anchor every IFU spectrum and compute synthetic fluxes for both corrections.

    Steps:
      1. For each IFU file, adjust the aperture spectrum with adjust_spectrum twice,
         multiplicatively and additively. Both are anchored to the filter returned by
         get_largest_filter_within for that cube.
      2. Compute synthetic photometry for every filter in filter_files from both sets of
         adjusted spectra with get_all_fluxes.

    Parameters
    -------------
    ifu_files : list of str - IFU cube paths; their order is preserved in every per-cube list below
    image_files : list of str - image paths used for the photometric fluxes
    filter_files : list of str - filter files; one synthetic flux per filter
    loc : [ra, dec] in degrees (or SkyCoord) - aperture center
    radius : angular Quantity - aperture radius
    anchor_filters : not used by this function; kept for backward compatibility

    Returns
    -------------
    dict with keys 'add_datasets', 'mult_datasets', 'ifu_files', 'image_files', 'location',
    'loc_idx' (0-3 for the four known pointings, otherwise "?"), 'radius',
    'add_correction_values', 'mult_correction_values', 'filter_names', 'filter_wavelengths',
    'add_synthetic_fluxes', 'mult_synthetic_fluxes', 'photo_fluxes', 'wavelength',
    'add_intensity', 'mult_intensity'.

    Side effects: recreates Data_files/misc_data/temp_data and always deletes it before returning.
    '''
    temp_filepath = 'Data_files/misc_data/temp_data'
    if os.path.exists(temp_filepath):
        shutil.rmtree(temp_filepath)
    os.makedirs(temp_filepath)

    known_locations = {
        (202.5062429, 47.2143358): 0,
        (202.4335225, 47.1729608): 1,
        (202.4340450, 47.1732517): 2,
        (202.4823742, 47.1958589): 3,
    }
    try:
        # Accepts lists, tuples or arrays; anything else (e.g. SkyCoord) maps to "?".
        loc_index = known_locations.get(tuple(float(v) for v in loc), "?")
    except (TypeError, ValueError):
        loc_index = "?"

    results = {
        'add_datasets': [],
        'mult_datasets': [],
        'ifu_files': ifu_files,
        'image_files': image_files,  # was a bare lookup that raised KeyError
        'location': loc,
        'loc_idx': loc_index,
        'radius': radius,
        'add_correction_values': [],
        'mult_correction_values': [],
    }

    def _same_values(a, b):
        # Exact element-wise equality treating NaN == NaN (comparing means flagged any NaN
        # as a mismatch and could miss genuinely different arrays).
        a, b = np.asarray(a), np.asarray(b)
        return a.shape == b.shape and bool(np.all((a == b) | ((a != a) & (b != b))))

    print('adjusting spectra using additive and multiplicative corrections')
    add_files, mult_files = [], []
    try:
        for i, ifu_file in enumerate(ifu_files):
            # Computed once per cube: it evaluates coverage for every filter.
            anchor_filter = get_largest_filter_within(ifu_file)
            mult_data, mult_correction = adjust_spectrum(ifu_file, anchor_filter, image_files, loc, radius, adjustment_operation = 'multiply')
            results['mult_correction_values'].append(mult_correction)
            add_data, add_correction = adjust_spectrum(ifu_file, anchor_filter, image_files, loc, radius, adjustment_operation = 'add')
            results['add_correction_values'].append(add_correction)
            add_name = os.path.join(temp_filepath, f"add_grism_{i+1}_of_{len(ifu_files)}.npy")
            np.save(add_name, add_data)
            add_files.append(add_name)
            mult_name = os.path.join(temp_filepath, f"mult_grism_{i+1}_of_{len(ifu_files)}.npy")
            np.save(mult_name, mult_data)
            mult_files.append(mult_name)

            print(f'adjusted {i+1} of {len(ifu_files)}')

        # Reload in ifu_files order; glob order is arbitrary and used to scramble the datasets
        # relative to ifu_files and the correction-value lists.
        for name in add_files:
            results['add_datasets'].append(np.load(name))
        for name in mult_files:
            results['mult_datasets'].append(np.load(name))
        add_datasets = list(results['add_datasets'])
        mult_datasets = list(results['mult_datasets'])

        print('calculating additive corrected synthetic photometry...')
        add_results = get_all_fluxes(filter_files, add_datasets, image_files, loc, radius)
        print('calculating multiplicative corrected synthetic photometry...')
        mult_results = get_all_fluxes(filter_files, mult_datasets, image_files, loc, radius)

        print('Compiling results and cleaning up...')
    finally:
        # Remove the scratch directory even if a step above raised.
        shutil.rmtree(temp_filepath, ignore_errors=True)

    results['filter_names'] = add_results['filter_name']
    results['filter_wavelengths'] = add_results['mean_wl']
    results['add_synthetic_fluxes'] = add_results['synth_flux']
    results['mult_synthetic_fluxes'] = mult_results['synth_flux']
    if not _same_values(add_results['photo_flux'], mult_results['photo_flux']):
        print('!!!!!!!!Photo fluxes were not the same in the two datasets! Something has gone wrong')
    results['photo_fluxes'] = add_results['photo_flux']
    if not _same_values(add_results['wavelength'], mult_results['wavelength']):
        print('!!!!!!!!!Wavelength arrays were not the same in the two datasets! Something has gone wrong')
    results['wavelength'] = add_results['wavelength']
    results['add_intensity'] = add_results['intensity']
    results['mult_intensity'] = mult_results['intensity']

    return results


def plot_results(results, correction = 'mult', show_images = None, color_min_max = None, image_paths = None):
    '''
    Plot corrected IFU spectra, synthetic vs. photometric fluxes, and their ratio.

    The top panel shows each corrected cube spectrum (overlaps with the
    previous cube in black), synthetic fluxes (stars), and photometric fluxes
    (circles with bars spanning the filter range). The bottom panel shows
    synthetic/photometric ratios per filter. Filters listed in the
    notebook-level `anchor_filters` are labelled in red.

    Parameters
    ----------
    results : dict
        Output of create_data. Keys read here: '<correction>_datasets',
        '<correction>_intensity', '<correction>_synthetic_fluxes',
        '<correction>_correction_values', 'wavelength', 'filter_names',
        'filter_wavelengths', 'photo_fluxes', 'radius' (astropy Quantity),
        'loc_idx', and 'location' (RA, Dec in degrees).
    correction : {'mult', 'add'}
        Which corrected dataset to plot.
    show_images : list of str, optional
        Filter names (e.g. ['F560W', 'F2100W']) drawn as inset cutouts with
        the aperture overlaid, in the given order. At most 5 are supported.
    color_min_max : (float, float), optional
        Lower/upper percentiles used to scale each inset image. If None, the
        insets use vmin=0 and vmax=99th percentile.
    image_paths : list of str, optional
        Image files searched for the inset filters. Defaults to the
        notebook-level `image_files`.

    Returns
    -------
    None. Creates a new figure and prints the mean/std of the
    synthetic/photometric ratio and the corrections per solid angle.

    Raises
    ------
    ValueError
        If `correction` is not 'mult' or 'add', or more insets are requested
        than there are inset slots.
    '''
    from astropy.wcs.utils import proj_plane_pixel_scales

    method_labels = {'mult': 'multiplicative correction', 'add': 'additive correction'}
    if correction not in method_labels:
        raise ValueError(f"correction must be 'mult' or 'add', got {correction!r}")
    method = method_labels[correction]

    show_images = [] if show_images is None else list(show_images)
    image_paths = image_files if image_paths is None else image_paths
    image_locations = [(0.05, 0.75, 0.2, 0.2), (0.3, 0.75, 0.2, 0.2), (0.6, 0.41, 0.2, 0.2), (0.8, 0.41, 0.2, 0.2), (0.85, 0.65, 0.18, 0.18)]
    if len(show_images) > len(image_locations):
        raise ValueError(f'at most {len(image_locations)} inset images are supported, got {len(show_images)}')

    datasets = results[correction + '_datasets']
    synth_fluxes = results[correction + '_synthetic_fluxes']
    ratios = synth_fluxes / results['photo_fluxes']

    fig = plt.figure(figsize = (45,30))
    ax_spec = fig.add_axes((0.05, 0.4, 1, 0.6))
    ax_scat = fig.add_axes((0.05, 0.05, 1, 0.35))
    fontsize_sm = 35
    fontsize_lg = 45
    marker_size = 250
    cube_colors = ['purple', 'blue', 'cyan', 'green', 'orange', 'red', 'pink']

    # One line per corrected cube; the overlap with the previous cube is redrawn in black.
    for i, dataset in enumerate(datasets):
        ax_spec.plot(dataset['wavelength'], dataset['intensity'], alpha = 0.5,
                     color = cube_colors[i % len(cube_colors)], linewidth = 5)
        if i > 0:
            overlap_data = get_overlap_region(datasets[i-1], dataset)
            ax_spec.plot(overlap_data['wavelength'], overlap_data['intensity'], alpha = 0.5, color = 'black')

    # Invisible (alpha=0) line: it only contributes to autoscaling the ratio panel's x-range.
    ax_scat.plot(results['wavelength'], [1]*len(results[correction+'_intensity']), color = 'white', alpha = 0)

    # Synthetic (stars) and photometric (circles) fluxes; the empty scatters only create legend entries.
    ax_spec.scatter(results['filter_wavelengths'], synth_fluxes, marker = '*', s=marker_size, color = 'black')
    ax_spec.scatter([], [], marker = '*', s=marker_size, color = 'black', label = 'Synth')
    ax_spec.scatter(results['filter_wavelengths'], results['photo_fluxes'], marker = "o", s=marker_size, color = 'black')
    ax_spec.scatter([], [], marker = "o", s=marker_size, color = 'black', label = 'Photo')
    # Horizontal bars mark each filter's wavelength coverage.
    for filter_name, photo_flux in zip(results['filter_names'], results['photo_fluxes']):
        filter_short_wl, filter_long_wl = [wl.value for wl in get_filter_wl_range(filter_name)]
        ax_spec.hlines(y=photo_flux, xmin=filter_short_wl, xmax=filter_long_wl, color='black', alpha=0.7, linewidth=3)

    ax_scat.scatter(results['filter_wavelengths'], ratios, s=marker_size, color = 'black')

    for ax in (ax_scat, ax_spec):
        ax.tick_params(axis='x', which='minor', width=2, length=10, right=True, top=True, direction='in',
                       labelsize=fontsize_sm)
        ax.tick_params(axis='x', which='major', width=3, length=15, right=True, top=True, direction='in',
                       labelsize=fontsize_sm)
        ax.tick_params(axis='y', which='both', width=3, length=15, right=True, top=True, direction='in',
                       labelsize=fontsize_sm)
    ax_scat.set_xlabel('wavelength (m)', fontsize = 40)
    ax_scat.set_ylabel('synthetic/photometric flux', fontsize = 40)
    ax_spec.set_ylabel('Intensity (MJy/sr)', fontsize = 40)

    ax_spec.set_title(f"{results['radius']}-radius aperture at location {results['loc_idx']}\nUsing {method}", fontsize = 50)

    ax_scat.set_xscale('log')
    ax_spec.set_xscale('log')
    ax_spec.set_yscale('log')

    # Label each ratio point with its filter name, nudging labels that would collide.
    label_positions = []  # display-space positions of labels already placed
    for x, y, name in zip(results['filter_wavelengths'], ratios, results['filter_names']):
        filter_short_wl, filter_long_wl = [wl.value for wl in get_filter_wl_range(name)]
        ax_scat.hlines(y=y, xmin=filter_short_wl, xmax=filter_long_wl, color='black', alpha=0.7, linewidth=3)
        color = 'red' if name in anchor_filters else 'black'

        # Labels are drawn on ax_scat, so collisions are checked in ax_scat's display coordinates.
        x_disp, y_disp = ax_scat.transData.transform((x, y))
        y_offset = -0.05  # default: just below the point (data units)

        # Within 20 px horizontally and 5 px vertically of an existing label counts as overlapping.
        too_close = any(abs(x_disp - xx) < 20 and abs(y_disp + y_offset - yy) < 5
                        for xx, yy in label_positions)
        if too_close:
            y_offset = +0.2
        if y < 0.8:
            y_offset = +0.4
        label_positions.append((x_disp, y_disp + y_offset))
        # Hand-tuned offsets for these filters (applied after recording the position).
        if name in ('F182M', 'F212N'):
            y_offset = +0.25

        ax_scat.text(
            x, y + y_offset,
            name,
            ha="center", va="top",
            fontsize=fontsize_sm, rotation = 90, color = color,
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5)
        )

    # Inset image cutouts with the photometric aperture overlaid, titled by the requested filter.
    loc_sky = SkyCoord(ra=results['location'][0]*u.deg, dec=results['location'][1]*u.deg, frame='icrs')
    cutout_size = (results['radius']*3, results['radius']*3)
    slot = 0
    for filter_name in show_images:
        matches = [path for path in image_paths if extract_filter_name(path) == filter_name]
        if not matches:
            print(f'No image found for {filter_name}; skipping inset')
            continue
        # Copy data/header so the file can be closed immediately.
        with fits.open(matches[0]) as hdul:
            image = hdul['SCI'].data.copy()
            image_header = hdul['SCI'].header.copy()
        image_wcs = WCS(image_header, naxis=2)
        # Works for both CDELT+PC and CD-matrix WCS (wcs.cdelt alone is 1 for CD headers).
        pixel_scale = proj_plane_pixel_scales(image_wcs)[0] * 3600  # arcsec/pixel

        cutout = Cutout2D(data = image, position = loc_sky, size = cutout_size, wcs = image_wcs)
        x_img, y_img = cutout.wcs.world_to_pixel(loc_sky)
        # nanpercentile: NaN pixels in the cutout would otherwise make the colour limits NaN.
        if color_min_max is None:
            vmin, vmax = 0, np.nanpercentile(cutout.data, 99)
        else:
            vmin, vmax = np.nanpercentile(cutout.data, color_min_max)

        ax = fig.add_axes(image_locations[slot])
        slot += 1
        im = ax.imshow(cutout.data, origin='lower', cmap='viridis',
                       norm=ImageNormalize(cutout.data, stretch=AsinhStretch(), vmin=vmin, vmax=vmax))
        ax.add_patch(Circle((x_img, y_img), (results['radius'].to(u.arcsec).value)/pixel_scale, ec='red', fc='none', lw=3, ls='-', alpha = 0.7))
        cbar = plt.colorbar(
            im,
            ax=ax,
            fraction=0.046,
            pad=0.04
        )
        cbar.ax.tick_params(labelsize = fontsize_lg)
        ax.set_title(filter_name, fontsize = fontsize_lg)
        ax.set_xticks([])
        ax.set_yticks([])

    ymin, ymax = ax_scat.get_ylim()
    text_y_pos = ymin * 1.1
    ax_scat.axhline(y = 1, color = 'gray', linestyle = '--', linewidth = 4, alpha = 0.5)
    # Vertical divider between the NIRCam (left) and MIRI (right) filter sets, labelled below.
    ax_scat.axvline(x=7.650000025896587e-06, color='gray', linestyle='--', linewidth=4, alpha=0.7)
    ax_scat.text(7.25e-6, text_y_pos, "← NIRCam",
                 ha='right', va='center',
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2),
                 fontsize=fontsize_lg)
    ax_scat.text(8e-6, text_y_pos, "MIRI →",
                 ha='left', va='center',
                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2),
                 fontsize=fontsize_lg)
    ax_spec.legend(loc = 'upper left', bbox_to_anchor=(1, 1), fontsize = fontsize_lg)

    print('mean ratio : ', np.mean(ratios))
    print('ratio std : ', np.std(ratios))
    correction_factors = np.array(results[correction+'_correction_values'])
    print('corrections per solid angle : ', correction_factors/(np.pi*(results['radius'].value)**2))

    return None




In [None]:
# Compare the real F1500W image with the IFU-derived synthetic image at location 0.
# NOTE(review): show_image_and_synth is defined in a later cell of this notebook, so on
# a fresh kernel that definition cell must be run before this one. `radius` is not
# defined in this cell and must already exist in the session. color_min_max is meant
# as the (low, high) colour-scale percentiles — confirm the show_image_and_synth
# definition accepts this keyword.
show_image_and_synth('F1500W', full_raw_ifu_files_loc0, locations[0], radius, image_files=v0p3_images, color_min_max = [1, 99.5])

In [None]:

# Aperture radii to compare, largest first. The per-radius names unpacked below encode
# the radius: Ip25 = 1.25", Ip0 = 1.0", Op75 = 0.75", Op5 = 0.5", Op25 = 0.25", Op1 = 0.1".
aperture_radii = [1.25*u.arcsec, 1*u.arcsec, 0.75*u.arcsec, 0.5*u.arcsec, 0.25*u.arcsec, 0.1*u.arcsec]


def run_aperture_series(ifu_files, loc):
    """Run create_data at one location for every radius in `aperture_radii`.

    Plots the multiplicative and additive corrections for each radius and
    returns the list of create_data results in `aperture_radii` order.
    """
    series = []
    for aperture_radius in aperture_radii:
        result = create_data(ifu_files, v0p3_images, filter_files, loc, aperture_radius)
        plot_results(result, correction = 'mult')
        plot_results(result, correction = 'add')
        series.append(result)
    return series


# Each location uses its own IFU file list.
loc0_sets = run_aperture_series(full_raw_ifu_files_loc0, locations[0])
loc1_sets = run_aperture_series(full_raw_ifu_files_loc1, locations[1])
loc2_sets = run_aperture_series(full_raw_ifu_files_loc2, locations[2])
loc3_sets = run_aperture_series(full_raw_ifu_files_loc3, locations[3])

# Per-radius names used by later cells.
Ip25, Ip0, Op75, Op5, Op25, Op1 = loc0_sets
loc1_Ip25, loc1_Ip0, loc1_Op75, loc1_Op5, loc1_Op25, loc1_Op1 = loc1_sets
loc2_Ip25, loc2_Ip0, loc2_Op75, loc2_Op5, loc2_Op25, loc2_Op1 = loc2_sets
loc3_Ip25, loc3_Ip0, loc3_Op75, loc3_Op5, loc3_Op25, loc3_Op1 = loc3_sets


In [None]:
# Re-plot every location/aperture result (multiplicative correction) with F560W and
# F2100W image insets. The loop variable is `result_set`, not `set`, so the builtin
# `set` is not shadowed for the rest of the kernel session.
for result_set in loc0_sets + loc1_sets + loc2_sets + loc3_sets:
    plot_results(result_set, correction = 'mult', show_images = ['F560W', 'F2100W'], color_min_max = [1, 99.5])


In [None]:
def show_image_and_synth(filter, ifu_fileset, loc, radius, image_files=v0p3_images, color_min_max=(1, 99.5)):
    """
    Show a real image and the synthetic IFU-derived image side by side.

    The synthetic image is built by passing every spaxel's spectrum through the
    filter transmission (get_Fnu_transmission). If the filter band needs two
    IFU cubes, the second cube is mapped onto the first cube's spatial grid
    (nearest spaxel) and the two spectra are concatenated per spaxel.

    Parameters
    ----------
    filter : str
        Filter name, e.g. 'F1500W'.
    ifu_fileset : list of str
        IFU cube files, searched in order for coverage of the filter band.
    loc : SkyCoord or (ra, dec)
        Aperture centre; a non-SkyCoord is interpreted as ICRS degrees.
    radius : astropy Quantity
        Angular aperture radius; each cutout is 3 * radius on a side.
    image_files : list of str
        Image files searched for the real image of `filter`.
    color_min_max : (float, float)
        Lower/upper percentiles of the finite pixels of both cutouts, used for
        the shared colour scale.

    Raises
    ------
    ValueError
        If no image for `filter` is found in `image_files`, or no cube in
        `ifu_fileset` overlaps the filter band.
    """

    # ------------------------------------------------------------
    # Location handling
    # ------------------------------------------------------------
    if not isinstance(loc, SkyCoord):
        loc_sky = SkyCoord(ra=loc[0] * u.deg, dec=loc[1] * u.deg, frame="icrs")
    else:
        loc_sky = loc

    def nearest_spaxel_map(cube_src, cube_target):
        """Return (y, x) index arrays of the cube_src spaxel nearest each cube_target spaxel."""
        ny, nx = cube_target.shape[1:]
        y_t, x_t = np.mgrid[:ny, :nx]

        world = cube_target.wcs.celestial.pixel_to_world(x_t, y_t)
        x_s, y_s = cube_src.wcs.celestial.world_to_pixel(world)

        # Round to the nearest spaxel and clamp to cube_src's spatial bounds.
        x_s = np.clip(np.round(x_s).astype(int), 0, cube_src.shape[2] - 1)
        y_s = np.clip(np.round(y_s).astype(int), 0, cube_src.shape[1] - 1)

        return y_s, x_s

    # ------------------------------------------------------------
    # Locate real image
    # ------------------------------------------------------------
    matching_images = [x for x in image_files if extract_filter_name(x) == filter]
    if not matching_images:
        raise ValueError(f"No image for filter {filter} found in image_files")
    real_image_file = matching_images[0]

    short_wl, long_wl = [x.value for x in get_filter_wl_range(filter)]

    # ------------------------------------------------------------
    # Choose the IFU cube(s) covering the filter band
    # ------------------------------------------------------------
    # A cube that fully contains the band is used alone; otherwise the first two
    # cubes (in ifu_fileset order) that overlap the band are kept.
    # Assumes each cube's spectral axis is in ascending order (wl[0] < wl[-1]).
    needed_cubes = []
    for file in ifu_fileset:
        cube = SpectralCube.read(file, hdu="SCI")
        wl = cube.spectral_axis.to(u.m).value
        if (wl[0] < short_wl) and (wl[-1] > long_wl):
            needed_cubes = [cube]
            break
        if (long_wl > wl[0]) and (short_wl < wl[-1]):
            needed_cubes.append(cube)
        if len(needed_cubes) > 1:
            break
    if not needed_cubes:
        raise ValueError(f"No IFU cube in ifu_fileset overlaps the {filter} band")

    cube1 = needed_cubes[0]
    cube2 = needed_cubes[1] if len(needed_cubes) > 1 else None

    # ------------------------------------------------------------
    # Base cube quantities
    # ------------------------------------------------------------
    wl1 = cube1.spectral_axis.to(u.m).value
    d1 = cube1.unmasked_data[:].value

    ny, nx = cube1.shape[1:]
    n_pix = ny * nx
    d1 = d1.reshape(len(wl1), n_pix).T  # -> (n_pix, n_wl)

    # ------------------------------------------------------------
    # Stitch spectra if needed
    # ------------------------------------------------------------
    if cube2 is not None:
        wl2 = cube2.spectral_axis.to(u.m).value
        d2 = cube2.unmasked_data[:].value

        y2, x2 = nearest_spaxel_map(cube2, cube1)
        d2 = d2[:, y2, x2].reshape(len(wl2), n_pix).T

        wl_all = np.concatenate([wl1, wl2])
        sort_idx = np.argsort(wl_all)
        wl_all = wl_all[sort_idx]

        spec_all = np.concatenate([d1, d2], axis=1)[:, sort_idx]

        wl_min = max(wl1.min(), wl2.min())
        wl_max = min(wl1.max(), wl2.max())
        overlap = (wl_all >= wl_min) & (wl_all <= wl_max)
        # NOTE(review): this halves each sample whose wavelength appears exactly in both
        # grids but keeps both copies, rather than merging them into one averaged point;
        # confirm get_Fnu_transmission treats duplicate wavelengths as intended.
        both = np.isin(wl_all, wl1) & np.isin(wl_all, wl2) & overlap
        spec_all[:, both] *= 0.5
    else:
        wl_all = wl1
        spec_all = d1

    # ------------------------------------------------------------
    # Synthetic photometry
    # ------------------------------------------------------------
    filter_wl, filter_trans = get_filter_data(filter)

    image = np.empty(n_pix)
    for i in range(n_pix):
        image[i] = get_Fnu_transmission(
            spec_all[i], wl_all, filter_trans, filter_wl, warnings=True
        )

    synth_image = image.reshape(ny, nx)

    # ------------------------------------------------------------
    # Attach WCS to synthetic image
    # ------------------------------------------------------------
    synth_hdu = fits.PrimaryHDU(
        synth_image,
        header=cube1.wcs.celestial.to_header()
    )

    # ------------------------------------------------------------
    # Plot real vs synthetic
    # ------------------------------------------------------------
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # -------------------------
    # REAL IMAGE
    # -------------------------
    # Copy data/header so the FITS file is closed before plotting.
    with fits.open(real_image_file) as hdul:
        real_data = hdul["SCI"].data.copy()
        real_header = hdul["SCI"].header.copy()
    real_pix_size = real_header['PIXAR_A2']**0.5  # arcsec per pixel side
    aperture_radius = radius.to(u.arcsec).value / real_pix_size
    cutout_real = Cutout2D(
        real_data,
        position=loc_sky,
        size=(radius * 3, radius * 3),
        wcs=WCS(real_header)
    )

    # -------------------------
    # SYNTHETIC IMAGE (WCS CUTOUT)
    # -------------------------
    cutout_synth = Cutout2D(
        synth_hdu.data,
        position=loc_sky,
        size=(radius * 3, radius * 3),
        wcs=WCS(synth_hdu.header)
    )

    # ------------------------------------------------------------
    # Shared normalization (color_min_max percentiles of both cutouts)
    # ------------------------------------------------------------
    combined = np.concatenate([
        cutout_real.data[np.isfinite(cutout_real.data)],
        cutout_synth.data[np.isfinite(cutout_synth.data)]
    ])

    vmin, vmax = np.percentile(combined, color_min_max)

    cmap = plt.get_cmap("viridis").copy()
    cmap.set_under("black")
    cmap.set_over("white")
    norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    pix_scale = np.abs(cutout_synth.wcs.wcs.cdelt[0]) * 3600
    r_ap_pix = radius.to(u.arcsec).value / pix_scale

    # -------------------------
    # PLOT REAL
    # -------------------------
    im0 = axes[0].imshow(
        cutout_real.data,
        origin="lower",
        cmap=cmap,
        norm=norm
    )
    axes[0].set_title(f"{filter} – Real")

    cbar0 = plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
    cbar0.set_label("Flux", fontsize=14)
    cbar0.ax.tick_params(labelsize=12)

    x_r, y_r = cutout_real.wcs.world_to_pixel(loc_sky)
    axes[0].add_patch(
        Circle((x_r, y_r), aperture_radius, edgecolor="red", facecolor="none", linewidth=2)
    )

    # -------------------------
    # PLOT SYNTHETIC
    # -------------------------
    im1 = axes[1].imshow(
        cutout_synth.data,
        origin="lower",
        cmap=cmap,
        norm=norm
    )
    axes[1].set_title(f"{filter} – Synthetic (IFU)")

    cbar1 = plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    cbar1.set_label("Flux", fontsize=14)
    cbar1.ax.tick_params(labelsize=12)

    x_s, y_s = cutout_synth.wcs.world_to_pixel(loc_sky)
    axes[1].add_patch(
        Circle((x_s, y_s), r_ap_pix, edgecolor="red", facecolor="none", linewidth=2)
    )

    # -------------------------
    # Cleanup
    # -------------------------
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.show()






# Re-plot the 1.25" location-0 result (Ip25, created by the create_data cell above)
# with F560W and F2100W image insets. color_min_max = [1, 99.5] is intended as the
# inset colour-scale percentile range — confirm plot_results accepts this keyword.

plot_results(Ip25, correction = 'mult', show_images = ['F560W', 'F2100W'], color_min_max = [1, 99.5])

In [None]:
# Re-plot the location-1 results with F115W and F2100W image insets.
# `result_set` (not `set`) keeps the builtin `set` usable after this cell runs.
for result_set in loc1_sets:
    plot_results(result_set, correction = 'mult', show_images = ['F115W', 'F2100W'], color_min_max = [1, 99.5])


In [None]:
# Display the image collection passed as image_files to create_data and
# show_image_and_synth above (defined outside this part of the notebook).
v0p3_images