# Notebook to mask persistence from saturation in MIRI imaging

Authors: S. Alberts<br>
Last Updated: December 5, 2025<br>
Tested using JWST Pipeline v1.19

---

## Purpose 
This notebook identifies saturated pixels in dithered MIRI imaging uncal files and sets the DO_NOT_USE data quality (DQ) flag on those pixels in subsequent exposures to mask persistence.  <span style="color:orange;">This algorithm only addresses persistence caused by sources that saturate during the input MIRI exposures.  Persistence artifacts can also be caused by saturation during observations and/or slews that precede the input MIRI observations.  If such pixels are identified by eye, a user-supplied pixel mask or list of affected pixels can be provided.</span>

## Persistence artifacts following saturation
Persistence artifacts can occur when bright (saturating or non-saturating) sources fall on the MIRI detectors, leaving a residual charge in affected pixels.  This residual charge, called persistence, decays on a timescale that depends on the brightness of the source, how long a pixel was exposed to that source, and other factors.  Because of this complexity, there is currently no correction for persistence; mitigation techniques include dithering during the observations and masking the affected pixels in post-processing.  When persistence comes from saturation occurring during an observation, we can use the saturation flagging in the pipeline to identify and mask pixels that are likely affected by persistence, as shown in this notebook.

Persistence can appear both as the affected pixels being over-luminous (positive persistence) or under-luminous (negative persistence).  Positive persistence decays more quickly, on the order of minutes.  The decay timescale for negative persistence is not well known, but it has been observed to last for tens of hours or more and to persist across integrations, dithers and filter changes. Because of this, a given affected pixel may need to be masked across subsequent integrations and/or exposures over a long timescale. For more information on persistence, including its different sources, visit the [JDocs MIRI Imaging Known Issues page](https://jwst-docs.stsci.edu/known-issues-with-jwst-data/miri-known-issues/miri-imaging-known-issues#MIRIImagingKnownIssues-Persistence) and see [Dicken et al. 2024](https://ui.adsabs.harvard.edu/abs/2024A%26A...689A...5D/abstract).

## Method
The input is a set of MIRI imaging uncal files, which can include multiple imaging filters. The uncal files are ordered by time at mid-exposure. The [DQInitStep](https://jwst-pipeline.readthedocs.io/en/latest/jwst/dq_init/index.html) and [SaturationStep](https://jwst-pipeline.readthedocs.io/en/stable/jwst/saturation/index.html) are then run on each uncal file to set up the [data quality (DQ) array](https://jwst-docs.stsci.edu/accessing-jwst-data/jwst-science-data-overview#JWSTScienceDataOverview-Dataqualityarrays(DQ)) and identify saturated or partially saturated pixels. Saturation flags are recorded group-by-group in the GROUPDQ array. The notebook then uses the GROUPDQ array to identify pixels to mask, based on a user-set minimum number of unsaturated groups; for multi-integration exposures, only the first integration is checked. For example, the user may specify that any pixel with 3 or fewer unsaturated groups will be masked. Pixels that meet this threshold are recorded and masked in ALL SUBSEQUENT exposures in the input uncal set by setting the DO_NOT_USE flag in the PIXELDQ array.  All input uncal exposures are checked for saturation in order of mid-exposure time, and the masking is cumulative.  The masked uncal files are written out with the suffix ```_masked_persistence_uncal.fits``` and can then be processed through [calwebb detector1](https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/calwebb_detector1.html) as per normal pipeline procedure.
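A toy sketch of the threshold logic on a made-up GROUPDQ-like array with shape (nints, ngroups, ny, nx). For each pixel it counts the groups in the first integration whose SATURATED bit (value 2) is not set, then flags pixels with ```n_unsat_threshold``` or fewer unsaturated groups. ```saturation_mask``` is a hypothetical stand-in for the notebook's ```find_saturated_pixels```. Testing the bit with a bitwise AND, rather than comparing the whole DQ value to 2, keeps a group counted as saturated even if other bits are set on it as well.

```python
import numpy as np

SATURATED = 2  # GROUPDQ bit used by the JWST pipeline to flag saturation

def saturation_mask(groupdq, n_unsat_threshold=3):
    """Flag pixels with <= n_unsat_threshold unsaturated groups in the first integration."""
    first_int = groupdq[0]                                # (ngroups, ny, nx)
    n_unsat = ((first_int & SATURATED) == 0).sum(axis=0)  # unsaturated groups per pixel
    return n_unsat <= n_unsat_threshold

# Toy exposure: 1 integration, 6 groups, a 1 x 3 pixel detector (values are made up)
gdq = np.zeros((1, 6, 1, 3), dtype=np.uint32)
gdq[0, 3:, 0, 0] = SATURATED      # pixel 0: saturated from group 4 on -> 3 unsaturated groups
gdq[0, 3:, 0, 1] = SATURATED | 1  # pixel 1: same, but SATURATED combined with another bit
                                  # pixel 2: never saturates -> 6 unsaturated groups

print(saturation_mask(gdq))                       # [[ True  True False]]
print(saturation_mask(gdq, n_unsat_threshold=2))  # [[False False False]]
```

With the default threshold of 3, a pixel that keeps exactly 3 unsaturated groups is masked, while one with 4 or more is not.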

---

In [None]:
import os

# Check whether the local CRDS cache directory has been set.
# If not, set it to the user home directory
if (os.getenv('CRDS_PATH') is None):
    os.environ['CRDS_PATH'] = os.path.join(os.path.expanduser('~'), 'crds')
# Check whether the CRDS server URL has been set.  If not, set it.
if (os.getenv('CRDS_SERVER_URL') is None):
    os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds.stsci.edu'

# Echo CRDS path and server in use
print('CRDS local filepath:', os.environ['CRDS_PATH'])
print('CRDS file server:', os.environ['CRDS_SERVER_URL'])


import glob
import numpy as np

# jwst and crds are needed for the version/context printout below
import crds
import jwst
from astropy.table import Table

from jwst import datamodels
from jwst.dq_init import DQInitStep
from jwst.saturation import SaturationStep

print("JWST Calibration Pipeline Version = {}".format(jwst.__version__))
print("Using CRDS Context = {}".format(crds.get_context_name('jwst')))


## <u>User-inputs</u>

### ---Set directories with uncal files to flag---

All files matching ```*mirimage_uncal.fits``` in each listed directory will be collected.  Multiple directories can be specified.
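A minimal, stdlib-only sketch of the directory scan, assuming the notebook's ```*mirimage_uncal.fits``` pattern. ```gather_uncal_files``` is a hypothetical helper, and the directory and file names in the demo are made up. Joining with ```os.path.join``` means an entry of ```''``` refers to the current working directory, and the pattern does not pick up the notebook's own ```_masked_persistence_uncal.fits``` outputs if the notebook is rerun on the same directory.

```python
import glob
import os
import tempfile

def gather_uncal_files(uncal_dirs, pattern='*mirimage_uncal.fits'):
    """Collect files matching `pattern` from each directory, dropping duplicates."""
    files = []
    for d in uncal_dirs:
        # os.path.join('', pattern) == pattern, i.e. the current working directory
        files.extend(glob.glob(os.path.join(d, pattern)))
    # sorted(set(...)) removes duplicates if a directory is listed twice
    return sorted(set(files))

# Demo with throwaway directories and hypothetical file names
with tempfile.TemporaryDirectory() as tmp:
    obs1 = os.path.join(tmp, 'obs1')
    obs2 = os.path.join(tmp, 'obs2')
    os.makedirs(obs1)
    os.makedirs(obs2)
    for path in [
        os.path.join(obs1, 'jw01234001001_02101_00001_mirimage_uncal.fits'),
        os.path.join(obs1, 'jw01234001001_02101_00001_mirimage_masked_persistence_uncal.fits'),
        os.path.join(obs2, 'jw01234001001_02101_00002_mirimage_uncal.fits'),
        os.path.join(obs2, 'jw01234001001_02101_00002_mirimage_rate.fits'),
    ]:
        open(path, 'w').close()  # create an empty placeholder file

    found = gather_uncal_files([obs1, obs2, obs2])
    names = [os.path.basename(f) for f in found]

print(names)
```

The notebook reorders the collected files by mid-exposure time afterwards, so the sort here only makes the demo output deterministic.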

### ---Set threshold for saturation---

Pixels with less than or equal to ```n_unsat_threshold``` **unsaturated** groups will be flagged for masking (default ```n_unsat_threshold=3```).

### ---Set mask for pixels identified by the user---

A custom mask can be used to mask user-identified pixels in all exposures.  This mask can be supplied either as 1) an input list of pixel coordinates (1-indexed) or 2) a boolean array, with the same dimensions as the data, stored in a FITS file.
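Whichever form the custom mask takes, it is applied by indexing the PIXELDQ plane with a (ny, nx) array. The toy sketch below, with made-up 4×4 arrays, shows why that array should be boolean: a 0/1 integer array used as an index is interpreted as a list of row numbers, not as a pixel mask. It also sets DO_NOT_USE (bit value 1) with a bitwise OR, so flags already present in PIXELDQ survive.

```python
import numpy as np

DO_NOT_USE = 1  # PIXELDQ bit for "do not use" in the JWST DQ scheme

# Toy 4 x 4 PIXELDQ plane with a pre-existing flag (value 4, made up) on pixel (0, 0)
pixeldq = np.zeros((4, 4), dtype=np.uint32)
pixeldq[0, 0] = 4

# A mask read from a FITS file may arrive as 0/1 integers rather than booleans
int_mask = np.zeros((4, 4), dtype=np.uint8)
int_mask[2, 3] = 1

# Pitfall: an integer array is treated as row indices, so rows 0 and 1 are overwritten
bad = pixeldq.copy()
bad[int_mask] = DO_NOT_USE

# Intended: index with a boolean mask and OR in DO_NOT_USE to keep existing bits
good = pixeldq.copy()
good[int_mask.astype(bool)] |= DO_NOT_USE

print(np.count_nonzero(bad), bad[0, 0])                # 8 1
print(np.count_nonzero(good), good[0, 0], good[2, 3])  # 2 4 1
```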

To use a custom mask, set ```use_custom_mask = True``` and supply a ```user_input``` as a <span style="color:blue;">**.txt**</span> or <span style="color:blue;">**.fits**</span> file.  A pixel coordinate text file needs to have integer value columns <span style="color:blue;">x</span> and <span style="color:blue;">y</span>.
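A minimal sketch of a pixel-list file and the coordinate conversion, mirroring ```create_user_mask_from_pix_list```. It assumes a whitespace-separated file with an ```x y``` header line, which is the simplest layout we expect ```Table.read(..., format='ascii')``` to accept. To keep the sketch dependency-light, ```np.loadtxt``` stands in for the astropy reader here. The file name and coordinates are hypothetical. As in the notebook's function, coordinates outside the detector are silently dropped.

```python
import os
import tempfile

import numpy as np

ny, nx = 1024, 1032  # full-frame MIRI imager dimensions (rows, columns)

# Hypothetical pixel list: 1-indexed x (column) and y (row) with an "x y" header
pixel_list = """x y
1 1
500 300
1032 1024
2000 10
"""

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'bad_pixels.txt')  # hypothetical file name
    with open(path, 'w') as f:
        f.write(pixel_list)
    # skiprows=1 skips the header; ndmin=2 keeps a 2D array even for a single row
    xy = np.loadtxt(path, dtype=int, skiprows=1, ndmin=2)

x = xy[:, 0] - 1  # 1-indexed -> 0-indexed
y = xy[:, 1] - 1
valid = (x >= 0) & (x < nx) & (y >= 0) & (y < ny)  # drops the out-of-range (2000, 10)
mask = np.zeros((ny, nx), dtype=bool)
mask[y[valid], x[valid]] = True  # numpy indexing is [row, column] = [y, x]

print(np.count_nonzero(mask))                        # 3
print(mask[0, 0], mask[299, 499], mask[1023, 1031])  # True True True
```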

In [None]:
# directories containing the uncal files to mask
# uncal_dirs = ['/path/to/uncal_files1/', '/path/to/uncal_files2/']
uncal_dirs = ['']

# pixels with this many or fewer unsaturated groups are masked
n_unsat_threshold = 3

# Define a custom mask, if desired
# use_custom_mask = False and/or user_input = '' will skip using a custom mask
use_custom_mask = False
user_input = '' 

## <u>Define Useful Functions</u>

- **get_obs_time(file, verbose=False):**  
    Returns the mid-exposure time (MJD) from a given uncal file.

- **order_by_time(files):**  
    Sorts a list of files by their mid-exposure time.

- **find_saturated_pixels(dq, n_unsat_threshold=3, dq_value=2):**  
    Identifies pixels with less than or equal to a threshold number of unsaturated groups in the first integration, indicating saturation.

- **check_user_input_type(user_input, uncal_file=None):**  
    Determines if the user-supplied mask is a FITS file or a pixel list, validates its shape, and returns a boolean mask.

- **create_user_mask_from_pix_list(x, y, uncal_file):**  
    Creates a boolean mask from lists of x and y pixel coordinates (1-indexed).
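The way ```order_by_time``` and the masking loop fit together can be sketched without any FITS files: exposures are sorted by mid-exposure MJD, and each exposure is masked with the union of the saturation masks from all earlier exposures, so the masked set only grows. The exposure names, MJD values, and 1×4 masks below are hypothetical stand-ins for ```get_obs_time``` and ```find_saturated_pixels```.

```python
import numpy as np

# Hypothetical mid-exposure times (MJD) and 1 x 4 saturation masks, listed out of order
mid_times = {'exp_b': 60500.20, 'exp_a': 60500.10, 'exp_c': 60500.30}
sat_masks = {
    'exp_a': np.array([True, False, False, False]),
    'exp_b': np.array([False, False, True, False]),
    'exp_c': np.array([False, True, False, False]),
}

# Same idea as order_by_time: sort names by a key function returning the mid-time
ordered = sorted(mid_times, key=mid_times.get)

cumulative = np.zeros(4, dtype=bool)  # pixels saturated in any earlier exposure
applied = {}
for name in ordered:
    applied[name] = cumulative.copy()  # mask applied to this exposure
    cumulative |= sat_masks[name]      # then add this exposure's saturated pixels

print(ordered)  # ['exp_a', 'exp_b', 'exp_c']
for name in ordered:
    print(name, applied[name].astype(int).tolist())
# exp_a [0, 0, 0, 0]
# exp_b [1, 0, 0, 0]
# exp_c [1, 0, 1, 0]
```

The first exposure receives no saturation-based masking, and saturation in the last exposure (pixel 1 here) is recorded but not applied to any exposure, since no later exposure exists in the set.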

In [None]:
def get_obs_time(file, verbose=False):
    dm = datamodels.RampModel(file)
    obs_time = dm.meta.exposure.mid_time_mjd
    dm.close()
    if verbose:
        print(obs_time)
    return obs_time

def order_by_time(files):
    return sorted(files, key=get_obs_time)

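def find_saturated_pixels(dq, n_unsat_threshold=3, dq_value=2):
    # dq_value = 2 is the SATURATED bit in the GROUPDQ array
    # Ensure the array is 4D: (nints, ngroups, ny, nx)
    assert dq.ndim == 4, "Input array must be 4-dimensional"

    # If multi-int, just make a mask of the first integration
    dq = dq[0, :, :, :]

    # Count groups whose SATURATED bit is not set; the bitwise test also treats
    # groups that carry other flags alongside SATURATED as saturated
    num_unsat = (np.bitwise_and(dq, dq_value) == 0).sum(axis=0)
    mask = num_unsat <= n_unsat_threshold

    return mask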

def check_user_input_type(user_input, uncal_file=None):
    from astropy.io import fits

    # check if user input is a fits file
    if user_input.lower().endswith('.fits'):
        if uncal_file is None:
            raise ValueError('Uncal file must be provided to check user mask dimensions.')

        dm = datamodels.RampModel(uncal_file)
        ny, nx = dm.data.shape[2], dm.data.shape[3]
        dm.close()

        # read the mask array (primary HDU, or the first extension with data)
        user_mask = fits.getdata(user_input)

        # numpy arrays are (rows, columns) = (ny, nx)
        if user_mask.shape != (ny, nx):
            raise ValueError(f'User mask dimensions {user_mask.shape} do not match data dimensions {(ny, nx)}.')

        # convert to boolean so the mask selects pixels (not rows) when indexing PIXELDQ
        return user_mask.astype(bool)

    # check if user input is a list of pixel coordinates
    elif user_input.lower().endswith('.txt'):
        # Read the table and check that it has "x" and "y" columns
        tab = Table.read(user_input, format='ascii')
        if 'x' in tab.colnames and 'y' in tab.colnames:
            x = tab['x'].data
            y = tab['y'].data
            return create_user_mask_from_pix_list(x, y, uncal_file)
        else:
            raise ValueError('Input text file must contain "x" and "y" columns.')

    else:
        print('User input type not recognized. Please provide a .fits mask or a .txt list of pixel coordinates.')
        return None

def create_user_mask_from_pix_list(x, y, uncal_file):
    dm = datamodels.RampModel(uncal_file)
    nx, ny = dm.data.shape[3], dm.data.shape[2]
    dm.close()

    mask = np.zeros((ny, nx), dtype=bool)
    
    # Ensure x and y are numpy arrays of integers
    # shift to zero indexed
    x = np.asarray(x, dtype=int) - 1
    y = np.asarray(y, dtype=int) - 1

    valid = (x >= 0) & (x < nx) & (y >= 0) & (y < ny)
    mask[y[valid], x[valid]] = True

    return mask



## Set up masking

In [None]:
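# gather all uncal files; os.path.join works with or without a trailing '/',
# and an empty string refers to the current working directory
uncal_files = []
for dirs in uncal_dirs:
    uncal_files += glob.glob(os.path.join(dirs, '*mirimage_uncal.fits'))

if len(uncal_files) == 0:
    raise FileNotFoundError(f'No *mirimage_uncal.fits files found in {uncal_dirs}.')

# sort by time
uncal_files = order_by_time(uncal_files)
print(f'Found and sorted {len(uncal_files)} uncal files.')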


# check if user input is provided for custom mask
if not use_custom_mask or user_input.strip() == '':
    user_mask = None
else:
    user_mask = check_user_input_type(user_input, uncal_file=uncal_files[0])

# set up logging to dump pipeline output to stpipe.log
logcfg = 'stpipe-log.cfg'
if not os.path.exists(logcfg):
    with open(logcfg, 'w') as f:
        f.write('[*]\nlevel = INFO\nhandler = append:stpipe.log')
if not os.path.exists('stpipe.log'):
    # create an empty log file for the handler to append to
    open('stpipe.log', 'w').close()

## Perform masking

In [None]:
# go through uncal files in order of time at mid-exposure

masks = {}
num_masked = []
combined_mask = None  # pixels saturated in any earlier exposure

for i, file in enumerate(uncal_files):
    print(f'Processing file {i+1}/{len(uncal_files)}: {file}')

    dm = datamodels.RampModel(file)

    if i == 0:
        # should be 1024, 1032 for full-frame MIRI imaging
        ny, nx = dm.data.shape[2], dm.data.shape[3]
        combined_mask = np.zeros((ny, nx), dtype=bool)

        # check if user mask is provided
        if user_mask is not None:
            print('Using user-supplied custom mask.')

            masks['user_mask'] = user_mask
            num_masked.append(np.count_nonzero(masks['user_mask']))
            print(f'User mask has {num_masked[-1]} pixels masked.')
        else:
            masks['user_mask'] = None

    else:
        # all files must have the same dimensions as the first file
        if dm.data.shape[2:] != (ny, nx):
            raise ValueError(f'{file} has dimensions {dm.data.shape[2:]}, '
                             f'but earlier files have {(ny, nx)}.')

    # Initialize the data quality array
    dm = DQInitStep.call(dm, logcfg=logcfg)

    # run the saturation step
    dm = SaturationStep.call(dm, logcfg=logcfg)

    # find saturated pixels and store the mask
    mask = find_saturated_pixels(dm.groupdq, n_unsat_threshold=n_unsat_threshold)
    masks[file] = {'mask': mask}

    n_sat = np.count_nonzero(mask)
    if n_sat == 0:
        print('No saturated pixels found.')
    else:
        num_masked.append(n_sat)
        print(f'Found {n_sat} saturated pixels ({100 * n_sat / (nx * ny):.2f}% of total pixels).\n')

    # apply user mask if provided; OR-ing in DO_NOT_USE (= 1) keeps existing PIXELDQ flags
    if masks['user_mask'] is not None:
        dm.pixeldq[masks['user_mask']] |= 1

    # mask pixels that were saturated in previous exposures
    # (combined_mask is all False for the first exposure)
    dm.pixeldq[combined_mask] |= 1

    # add this exposure's saturated pixels for use in later exposures
    combined_mask |= mask

    dm.to_fits(file.replace('_uncal.fits', '_masked_persistence_uncal.fits'), overwrite=True)
    dm.close()

print(f'Percentage of pixels saturated in at least one exposure: '
      f'{100 * np.count_nonzero(combined_mask) / (nx * ny):.2f}%')
