**Author:** Anowar Shajib

*With thanks to David Law and Takahiro Morishita*

In [None]:
import pathlib

import numpy as np
from scipy.ndimage import median_filter

import matplotlib.pyplot as plt

from scipy.ndimage import gaussian_filter
from scipy import ndimage
import lacosmic

from util import *

In [None]:
# Master switch: when False, the outlier/cosmic-ray cleaning loop below is skipped
# entirely and no processed files are written.
clean_image = True
# When True, the L.A.Cosmic detection step (and its diagnostic plots) is run on
# each file after the median-filter outlier clipping.
use_lacosmic = True

# Sigma clipping and CR cleaning

In [None]:
def _grow_mask_by_one(mask):
    """Return ``mask`` with every True pixel also flagging its 4 direct neighbours.

    The growth is clipped at the array edges: pixels on the border neither wrap
    around to the opposite side nor index out of bounds.
    """
    grown = mask.copy()
    grown[1:, :] |= mask[:-1, :]
    grown[:-1, :] |= mask[1:, :]
    grown[:, 1:] |= mask[:, :-1]
    grown[:, :-1] |= mask[:, 1:]
    return grown


def _flag_row_outliers(image, median_scale, threshold, positive):
    """Flag pixels deviating from a scaled 1x11 running median along axis 1.

    Non-finite pixels are treated as 0.0 before filtering.  A pixel is flagged
    when ``image - median_scale * median`` is above ``threshold``
    (``positive=True``) or below ``threshold`` (``positive=False``).  The
    returned boolean mask is already grown by one pixel in the four cardinal
    directions.
    """
    filled = image.copy()
    filled[~np.isfinite(filled)] = 0.0
    running_median = median_filter(filled, size=(1, 11))
    diff = filled - running_median * median_scale
    flagged = diff > threshold if positive else diff < threshold
    return _grow_mask_by_one(flagged)


if clean_image:
    files = np.array(sorted(glob.glob(stage1_directory + "*nrs1_rate.*fits")))

    # Cut-out window and colour limits used by the diagnostic plots.
    s = 500
    x_start = 150
    y_start = 1150
    vmax = 1.0
    vmin = -1.0

    def clip(arr):
        """Return the s x s diagnostic cut-out of ``arr``."""
        return arr[y_start : y_start + s, x_start : x_start + s]

    def plot(n, arr):
        """Show log10 of the Gaussian-smoothed ``arr`` on panel ``n`` of the current ``ax``."""
        temp = arr.copy()
        temp[~np.isfinite(temp)] = 0.0

        ax[n].matshow(
            np.log10(gaussian_filter(temp, sigma=2)),
            vmax=3,
            vmin=-7,
            cmap="cubehelix",
        )

    for file in files:
        with fits.open(file) as hdu:
            sci = hdu["SCI"].data
            sci_original = sci.copy()
            noise = hdu["ERR"].data

            median_clip_mask = np.zeros_like(sci, dtype=bool)

            # Flag positive outliers (more than 0.5 above twice the running
            # median), grown by 1 pixel in both directions.
            positive_outliers = _flag_row_outliers(sci, 2, 0.5, positive=True)
            sci[positive_outliers] = np.nan
            median_clip_mask |= positive_outliers

            # Flag negative outliers (more than 0.5 below half the running
            # median).  This pass runs on the already-clipped image, so pixels
            # removed above enter the median as 0.0.
            negative_outliers = _flag_row_outliers(sci, 0.5, -0.5, positive=False)
            sci[negative_outliers] = np.nan
            median_clip_mask |= negative_outliers

            fig, ax = plt.subplots(1, 2, figsize=(20, 10))

            plot(0, clip(sci_original))
            ax[0].set_title("Original")

            plot(1, clip(sci))
            ax[1].set_title("Outlier cleaned")

            plt.show()

            print(
                "original non-NaN pixels:",
                np.sum(~np.isnan(sci_original)),
                "sigma_clipped:",
                np.sum(~np.isnan(sci)),
            )

            # Default to "no cosmic rays found" so the masking step below is
            # well-defined (and never reuses a stale mask from a previous file)
            # when the L.A.Cosmic step is disabled.
            crmask = np.zeros_like(sci, dtype=bool)

            if use_lacosmic:
                # NOTE(review): ext=3 is presumably the DQ extension of the
                # rate file -- confirm against the file layout.
                dq_flag = fits.getdata(file, ext=3)
                # NOTE(review): `jump` is computed but not used in this cell.
                jump = (dq_flag & 4) == 4
                do_not_use = (dq_flag & 1) == 1

                fig, ax = plt.subplots(1, 2, figsize=(20, 10))
                ax[0].matshow(clip(sci_original), vmax=vmax, vmin=vmin, cmap="coolwarm")
                ax[0].set_title("Original")
                ax[1].matshow(clip(sci), vmax=vmax, vmin=vmin, cmap="coolwarm")
                ax[1].set_title("Outlier cleaned")
                plt.show()

                fig, ax = plt.subplots(1, 2, figsize=(20, 10))

                sci_masked_do_not_use = sci.copy()
                sci_masked_do_not_use[do_not_use] = np.nan

                # Positional arguments are presumably contrast, cr_threshold
                # and neighbor_threshold -- confirm against the installed
                # lacosmic version.
                # NOTE(review): the mask only covers pixels that are BOTH
                # do-not-use AND NaN; `|` may have been intended so that every
                # NaN pixel is ignored.  Left unchanged to preserve results.
                crclean, crmask = lacosmic.lacosmic(
                    sci,
                    3,
                    4,
                    3,
                    error=noise,
                    maxiter=4,
                    readnoise=13.0,
                    mask=do_not_use & np.isnan(sci),
                )

                # Also flag pixels fully enclosed by detected cosmic-ray pixels.
                crmask = ndimage.binary_fill_holes(crmask)

                ax[0].matshow(
                    clip(sci_masked_do_not_use),
                    vmax=vmax,
                    vmin=vmin,
                    cmap="coolwarm",
                )
                ax[0].set_title("Original")

                # -100 is far below vmin, so flagged pixels saturate the
                # colour map and stand out in the diagnostic panel.
                cr_masked = sci_masked_do_not_use.copy()
                cr_masked[crmask] = -100

                ax[1].matshow(clip(cr_masked), vmax=vmax, vmin=vmin, cmap="coolwarm")
                ax[1].set_title("CR cleaned")

                print(crmask.shape, sci_original.shape)
                plt.show()

            sci[crmask] = np.nan

            # Write the cleaned SCI array under the same file name in the
            # processed directory; all other extensions are written unchanged.
            hdu["SCI"].data = sci
            name = pathlib.Path(file).name
            hdu.writeto(stage1_processed_directory + name, overwrite=True)