I was initially concerned about outlier events (likely satellite trails) in the flat images. These appear in a handful of images as excessively bright, saturated banded features.
However, when the images are combined using a median, these features are mostly removed.

Indeed, as shown in the histograms below, removing these frames altogether shifts any pixel by less than the global noise level, and at worst results in a median deviation of order noise / 4. In the worst case, only about 50 pixels deviate by more than 5 sigma. As a result, we chose to simply use the mean-combined images for further analysis.
We note that we chose to use median image combination for the calibration images to avoid dealing with these issues explicitly and to reduce possible systematic biases as a result.

In [None]:
# Night folder analysed by this notebook; change the index to pick another night.
foldername = ("20230708", "20230709", "20230723")[0]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import ccdproc
from ccdproc import ImageFileCollection

from astropy.nddata import CCDData
import arya

In [None]:
from convenience_functions import show_image, combine_images, show_image_residual, mad_std

In [None]:
def get_flats(foldername, filt):
    """Collect the flat frames of one filter found in a folder.

    Parameters
    ----------
    foldername : str
        Directory handed to ``ImageFileCollection``.
    filt : str
        Filter label substituted into the ``flat_{filt}_*.fits`` file pattern.

    Returns
    -------
    ImageFileCollection
        Collection restricted to the files matching that pattern.
    """
    pattern = f"flat_{filt}_*.fits"
    return ImageFileCollection(foldername, glob_include=pattern)

In [None]:
# Filter labels used as dictionary keys throughout the notebook.
filts = list("gri")

In [None]:
# One file collection per filter, read from the night's "unbiased/" subfolder.
imgfiles = dict(
    (band, get_flats(foldername + "/unbiased/", band)) for band in filts
)

In [None]:
# Previously written stacked flat of each filter, loaded from the night folder.
flats_stacked = {
    band: CCDData.read(f"{foldername}/flat_{band}_stacked.fits")
    for band in filts
}

## Alternative flat stacking methods

In [None]:
# Flat frames to exclude, listed by file name under the night folder they belong to.
bad_flats = {
    "20230708": [
        "flat_i_05.fits",
    ],
    "20230709": [
        "flat_g_06.fits",
        "flat_g_07.fits",
        "flat_i_01.fits",
        "flat_i_03.fits",
    ],
    "20230723": [
        "flat_r_04.fits",
        "flat_i_04.fits",
    ],
    "20230816": [],
}


In [None]:
def stack_flats(foldername, **kwargs):
    """Normalise and combine the flat frames of each filter for one night.

    Every frame found in ``foldername + "/unbiased/"`` is divided by its own
    NaN-ignoring median, and the frames of each filter are then handed to
    ``combine_images``.

    Parameters
    ----------
    foldername : str
        Night folder whose ``unbiased/`` subfolder holds the flat frames.
    **kwargs
        Forwarded unchanged to ``combine_images`` (e.g. ``method="average"``).

    Returns
    -------
    dict
        Maps each filter name (``"g"``, ``"r"``, ``"i"``) to whatever
        ``combine_images`` returns for that filter.

    Raises
    ------
    ValueError
        If no flat frames are found for one of the filters.
    """
    flats = {}
    for filt in ("g", "r", "i"):
        imgfiles = get_flats(foldername + "/unbiased/", filt)
        print(f"read in image files for band {filt} ", imgfiles.files)

        imgs = list(imgfiles.ccds())
        if not imgs:
            # Fail with an explicit message instead of an IndexError further down.
            raise ValueError(
                f"no flat frames found for band {filt} in {foldername}/unbiased/"
            )

        for img in imgs:
            # Rebind rather than divide in place so integer-typed frames are
            # promoted to float instead of raising a casting error; the median
            # is taken on the raw pixel array explicitly.
            img.data = img.data / np.nanmedian(img.data)

        flats[filt] = combine_images(imgs, **kwargs)

    return flats

In [None]:
def stack_only_good_flats(foldername):
    """Combine the flats of each filter after dropping the frames in ``bad_flats``.

    Parameters
    ----------
    foldername : str
        Night folder whose ``unbiased/`` subfolder holds the flat frames. Its
        final path component is used as the key into ``bad_flats``, so a
        trailing slash or a leading parent directory is tolerated.

    Returns
    -------
    dict
        Maps each filter name (``"g"``, ``"r"``, ``"i"``) to whatever
        ``combine_images`` returns for the retained files of that filter.

    Raises
    ------
    KeyError
        If ``bad_flats`` has no entry for the night.
    """
    import os

    night = os.path.basename(os.path.normpath(foldername))
    try:
        # Build the rejection set once; it does not depend on the filter.
        rejected = set(bad_flats[night])
    except KeyError:
        raise KeyError(
            f"no bad-flat list defined for night {night!r}; add an entry to bad_flats"
        ) from None

    flats = {}
    for filt in ("g", "r", "i"):
        imgfiles = get_flats(foldername + "/unbiased/", filt)
        print(f"read in image files for band {filt} ", imgfiles.files)

        files = imgfiles.files_filtered(include_path=True)

        # Compare on the bare file name, independent of the path separator in use.
        files_filtered = [
            file for file in files if os.path.basename(file) not in rejected
        ]

        print("stacking: ", files_filtered)
        # Each frame is scaled by the inverse of its median before combination.
        flats[filt] = combine_images(files_filtered, scale=lambda x: 1 / np.median(x))

    return flats

In [None]:
# Stacked flats built only from the frames not listed in bad_flats.
flats_stacked_filtered = stack_only_good_flats(foldername=foldername)

In [None]:
# Stacked flats from all frames, with method="average" forwarded to combine_images.
flats_stacked_mean = stack_flats(foldername=foldername, method="average")

In [None]:
# Per-pixel scatter of each filter: every frame is divided by its own median,
# the frames are stacked along a new leading axis, and mad_std is taken
# along that axis.
flat_std = {
    band: mad_std(
        np.stack([frame / np.median(frame) for frame in imgfiles[band].data()]),
        axis=0,
    )
    for band in filts
}

In [None]:
# Global noise level: mean over the filters of the median per-pixel scatter.
noise = np.mean([np.median(flat_std[band]) for band in filts])
noise

# Plots

In [None]:
# Number of elements in the g-band scatter map. The size is already a single
# number, so np.prod only wraps it.
Npix = np.prod(np.size(flat_std["g"]))

In [None]:
def plot_noise_hist(Npix=Npix, ymin=1):
    """Draw, in red, a Gaussian centred on 1 whose width is the global ``noise``.

    The curve is evaluated at 1000 points between 0.9 and 1.1 and scaled by
    ``Npix``; only the points where it exceeds ``ymin`` are plotted.

    Parameters
    ----------
    Npix : number, optional
        Multiplicative scale of the curve (1 gives a unit-area density).
    ymin : number, optional
        Curve values at or below this threshold are not drawn.
    """
    grid = 1 + np.linspace(-0.1, 0.1, 1000)
    amplitude = Npix * 1 / np.sqrt(2*np.pi * noise**2)
    curve = amplitude * np.exp(-(grid-1)**2 / (2 * noise**2))

    visible = curve > ymin
    plt.plot(grid[visible], curve[visible], color="red")

In [None]:
# Per filter: histogram of the pixel-wise ratio between the average-combined
# flat and the stacked flat read from disk, over the expected noise curve.
for filt in filts:
    plt.figure()
    plot_noise_hist(1, ymin=1e-4)

    ratio_image = flats_stacked_mean[filt].data / flats_stacked[filt].data
    x = ratio_image.flatten()

    plt.hist(x, histtype="step", density=True, color=arya.COLORS[1])
    plt.yscale("log")
    # Span the x-axis over the full range of ratios.
    plt.xlim(np.min(x), np.max(x))

In [None]:
# Per filter: compare both the stacked flat read from disk and the
# average-combined flat against the stack that excludes the bad frames.
for filt in filts:
    plt.figure()
    plot_noise_hist(1, ymin=1e-4)

    reference = flats_stacked_filtered[filt].data

    x = (flats_stacked[filt].data / reference).flatten()
    if np.all(x == 1):
        # Identical stacks: there is nothing to histogram.
        print("skipping")
    else:
        plt.hist(x, histtype="step", density=True)
        print("std:", np.std(x))
        # Number of pixels deviating from unity by more than 5 times the noise.
        print(np.sum(np.abs(x-1) / noise > 5))

    x = (flats_stacked_mean[filt].data / reference).flatten()
    plt.hist(x, histtype="step", density=True, color=arya.COLORS[1])
    plt.yscale("log")
    plt.xlim(np.min(x), np.max(x))
    


In [None]:
def flat_scale(img):
    """Return the reciprocal of the median of ``img.data``.

    Multiplying a frame by this factor brings its median to 1.
    """
    median_level = np.median(img.data)
    return 1 / median_level

In [None]:
# For every individual flat frame: scale it by the inverse of its median,
# divide by the filtered stack of its filter, then show the resulting image
# beside a histogram of its values with the expected noise curve.
for filt in filts:
    flat = flats_stacked_filtered[filt]
    for img, fname in imgfiles[filt].ccds(return_fname=True):
        img_reduced = img.data * flat_scale(img) / flat
        fig, axs = plt.subplots(1, 2, figsize=(5, 2.5))
        ax_image, ax_hist = axs

        show_image(img_reduced, fig=fig, ax=ax_image, clabel = "flat / flat mean")
        ax_image.set_title(fname)

        # Make the right-hand panel current so plot_noise_hist draws onto it.
        plt.sca(ax_hist)
        ax_hist.hist(img_reduced.data.flatten(), density=True)
        ax_hist.set_yscale("log")
        ax_hist.set_xlabel("relative value")
        ax_hist.set_ylabel("pixel count")
        plot_noise_hist(1, ymin=1e-4)

        plt.tight_layout()
        # savefig("flat_residual." + fname)