# MIRI Imaging Data Reduction Notebook

By the JWST TEMPLATES ERS Team 

Kedar A. Phadke

Last updated: December 18th, 2023

Tested by Lily Kettler

Mostly compiled from JWebbinar 3: https://github.com/spacetelescope/jwebbinar_prep/tree/webbinar3
and https://github.com/STScI-MIRI/Imaging_ExampleNB/blob/main/helpers/miri_clean.py (for de-striping)

Requirements: an installation of the JWST pipeline and MIRI Imaging data. This notebook shows how to create stage 3 calibrated products from uncalibrated data products. Relevant sections can also be used to run only specific stages of the pipeline — for example, downloading level 2b products and running only the stage 3 section. It is always better to ensure the CRDS pmap version is consistent across all stages.

We recommend creating a separate folder for each filter. For organizational purposes, this notebook assumes that all 'UNCAL' files for a given filter are in a folder called 'uncal' in the current working directory.

In [None]:

# Standard-library modules:
import copy
import glob
import os
import shutil
import subprocess
import warnings

# Packages that allow us to get information about objects:
import asdf

# Numpy library:
import numpy as np
# Astropy tools:
from astropy.io import fits


In [None]:
import jwst

# Record which jwst pipeline version is producing these products.
print('Pipeline version:', jwst.__version__)

# datamodels opens pipeline products; dqflags holds the data quality flag
# definitions (dqflags.pixel) used by the de-striping helpers.
from jwst import datamodels
from jwst.datamodels import dqflags
# To read association file
import json

In [None]:
#Uncomment below two lines if you have not defined CRDS elsewhere
# os.environ["CRDS_PATH"] = "path to CRDS_cache"   #Change this to appropriate path on your disk
# os.environ["CRDS_SERVER_URL"] = "https://jwst-crds.stsci.edu"

## Stage 1

In [None]:
# The entire calwebb_detector1 pipeline
from jwst.pipeline import calwebb_detector1

# Individual steps that make up calwebb_detector1
from jwst.group_scale import GroupScaleStep
from jwst.dq_init import DQInitStep
from jwst.saturation import SaturationStep
from jwst.firstframe import FirstFrameStep
from jwst.lastframe import LastFrameStep
from jwst.reset import ResetStep
from jwst.linearity import LinearityStep
from jwst.rscd import RscdStep
from jwst.dark_current import DarkCurrentStep                                                                                   
from jwst.refpix import RefPixStep
from jwst.jump import JumpStep
from jwst.ramp_fitting import RampFitStep
from jwst.persistence import PersistenceStep
from jwst.gain_scale import GainScaleStep
from jwst.ipc import IPCStep 

In [None]:
output_dir = "./stage1/"

# One output directory per calwebb_detector1 step. If you do not want to keep a
# step's intermediate product, comment out its line here and the matching
# entries in the next two cells.
output_dir_dq = f"{output_dir}dq/"
output_dir_saturation = f"{output_dir}saturation/"
output_dir_ipc = f"{output_dir}ipc/"
output_dir_firstframe = f"{output_dir}firstframe/"
output_dir_lastframe = f"{output_dir}lastframe/"
output_dir_reset = f"{output_dir}reset/"
output_dir_linearity = f"{output_dir}linearity/"
output_dir_rscd = f"{output_dir}rscd/"
output_dir_darkcurrent = f"{output_dir}darkcurrent/"
output_dir_refpix = f"{output_dir}refpix/"
output_dir_jump = f"{output_dir}jump/"
output_dir_rampfitting = f"{output_dir}rampfitting/"
output_dir_gainscale = f"{output_dir}gainscale/"

In [None]:
# Build the stage-1 output tree: the top-level directory first, then one
# sub-directory per pipeline step (keep in sync with the previous cell).
for _stage1_dir in (
    output_dir,
    output_dir_dq,
    output_dir_saturation,
    output_dir_ipc,
    output_dir_firstframe,
    output_dir_lastframe,
    output_dir_reset,
    output_dir_linearity,
    output_dir_rscd,
    output_dir_darkcurrent,
    output_dir_refpix,
    output_dir_jump,
    output_dir_rampfitting,
    output_dir_gainscale,
):
    os.makedirs(_stage1_dir, exist_ok=True)

In [None]:
# Per-step overrides handed to Detector1Pipeline via ``steps=``. Each step saves
# its intermediate product into its own directory; edit an entry to change the
# parameters of that step.
parameter_dict = {
    "dq_init": dict(output_dir=output_dir_dq, save_results=True),
    "saturation": dict(output_dir=output_dir_saturation, save_results=True),
    "ipc": dict(output_dir=output_dir_ipc, save_results=True),
    "firstframe": dict(output_dir=output_dir_firstframe, save_results=True),
    "lastframe": dict(output_dir=output_dir_lastframe, save_results=True),
    "reset": dict(output_dir=output_dir_reset, save_results=True),
    "linearity": dict(output_dir=output_dir_linearity, save_results=True),
    "rscd": dict(output_dir=output_dir_rscd, save_results=True),
    "dark_current": dict(output_dir=output_dir_darkcurrent, save_results=True),
    "refpix": dict(
        output_dir=output_dir_refpix,
        save_results=True,
        use_side_ref_pixels=False,
    ),
    # If cosmic rays are not being flagged properly, this is the step to tune.
    "jump": dict(
        rejection_threshold=5,
        output_dir=output_dir_jump,
        save_results=True,
    ),
    "ramp_fit": dict(output_dir=output_dir_rampfitting, save_results=True),
    "gain_scale": dict(output_dir=output_dir_gainscale, save_results=True),
}

In [None]:
input_dir = "./uncal/"  # folder holding this filter's *_uncal.fits files

In [None]:
# Collect every uncalibrated exposure in the input folder.
list_files = glob.glob(f"{input_dir}*_uncal.fits")
print("No of files to be processed:", len(list_files))

In [None]:
# Run calwebb_detector1 on every uncal file, applying the per-step overrides in
# parameter_dict. Each call writes its final product into output_dir; only the
# result of the last call is kept in miri_output.
for miri_uncal_file in list_files:
    print('File currently being processed:', miri_uncal_file)
    # Call the pipeline method using the dictionary
    miri_output = calwebb_detector1.Detector1Pipeline.call(
        miri_uncal_file,
        output_dir=output_dir,
        save_results=True,
        steps=parameter_dict,
        logcfg='stage1-log.cfg',
    )
    

## Stage 2

In [None]:
# The entire calwebb_image2 pipeline
from jwst.pipeline import calwebb_image2

# Individual steps that make up calwebb_image2
from jwst.background import BackgroundStep
from jwst.assign_wcs import AssignWcsStep
from jwst.flatfield import FlatFieldStep
from jwst.photom import PhotomStep
from jwst.resample import ResampleStep

In [None]:
output_dir = "./stage2/"
# One output directory per calwebb_image2 step.
output_dir_bkg = f"{output_dir}bkg/"
output_dir_assign_wcs = f"{output_dir}assign_wcs/"
output_dir_flatfield = f"{output_dir}flatfield/"
output_dir_photom = f"{output_dir}photom/"
output_dir_resample = f"{output_dir}resample/"

In [None]:
# Create the stage-2 output tree (top-level directory first, then one per step).
for _stage2_dir in (
    output_dir,
    output_dir_bkg,
    output_dir_assign_wcs,
    output_dir_flatfield,
    output_dir_photom,
    output_dir_resample,
):
    os.makedirs(_stage2_dir, exist_ok=True)

In [None]:
input_dir = "./stage1/"  # where calwebb_detector1 wrote its products

In [None]:
# Copy the stage-1 rate files into the stage-2 directory so the association
# built in the next cell can reference them. glob + shutil (instead of a shell
# `cp`) works on any OS and turns a missing input into an explicit error.
rate_files = sorted(glob.glob(input_dir + '*_rate.fits'))
if not rate_files:
    raise FileNotFoundError('No *_rate.fits files found in ' + input_dir)
for rate_file in rate_files:
    shutil.copy(rate_file, output_dir)

In [None]:
# Build the level-2 association file from the copied rate files. Passing an
# argument list to subprocess (rather than a shell string to os.system) avoids
# shell quoting issues, and check=True stops the notebook if asn_from_list
# fails instead of continuing with a missing or stale association file.
asn_file = 'level2_asn.json'
level2_inputs = sorted(glob.glob(output_dir + '*_rate.fits'))
subprocess.run(
    ['asn_from_list', '-o', asn_file, '-r', 'DMSLevel2bBase', *level2_inputs],
    check=True,
)

In [None]:
# Per-step overrides handed to Image2Pipeline via ``steps=``; each step saves
# its intermediate product into its own directory.
parameter_dict = {
    "bkg_subtract": dict(sigma=4, output_dir=output_dir_bkg, save_results=True),
    "assign_wcs": dict(output_dir=output_dir_assign_wcs, save_results=True),
    "flat_field": dict(output_dir=output_dir_flatfield, save_results=True),
    "photom": dict(output_dir=output_dir_photom, save_results=True),
    "resample": dict(pixfrac=1.0, output_dir=output_dir_resample, save_results=True),
}

In [None]:
# Run calwebb_image2 on the level-2 association with the overrides above.
call_output = calwebb_image2.Image2Pipeline.call(
    asn_file,
    output_dir=output_dir,
    save_results=True,
    steps=parameter_dict,
    logcfg='stage2-log.cfg',
)


In [None]:
# Remove the temporary copies of the rate files from the stage-2 directory now
# that calwebb_image2 has run (the originals remain in the stage-1 directory).
# os.remove on the globbed list avoids shelling out to `rm`.
for rate_copy in glob.glob(output_dir + '*_rate.fits'):
    os.remove(rate_copy)

## Stage 3

In [None]:
# The entire calwebb_image3 pipeline
from jwst.pipeline import calwebb_image3

# Individual steps that make up calwebb_image3
from jwst.tweakreg import TweakRegStep
from jwst.skymatch import SkyMatchStep
from jwst.outlier_detection import OutlierDetectionStep
from jwst.resample import ResampleStep
from jwst.source_catalog import SourceCatalogStep
from jwst.associations import asn_from_list
from jwst.associations.lib.rules_level3_base import DMS_Level3_Base


In [None]:
#For de-striping
from astropy.modeling import models, fitting

from astropy.stats import sigma_clipped_stats
from astropy.convolution import Gaussian1DKernel, convolve
from astropy.wcs import WCS

The two functions below are from https://github.com/STScI-MIRI/Imaging_ExampleNB/blob/main/helpers/miri_clean.py
They perform column and row median removal (de-striping).

In [None]:
def cal_column_clean(mfile, exclude_above=None):
    """
    Remove the median of each column to suppress residual detector artifacts.

    Works on cal images. The column medians are high-pass filtered (a smoothed
    copy is subtracted) so that only column-to-column residuals are removed,
    not large-scale structure. The corrected image is written next to the
    input, with ``cal.fits`` in the name replaced by ``cccal.fits``.

    Parameters
    ----------
    mfile : str
        filename with a MIRI cal image (i.e., xxx_cal.fits)
    exclude_above : float, optional
        value above which to exclude data from calculating the column median.
        ``None`` (default) keeps all data.

    Returns
    -------
    str
        filename of the de-striped image that was written

    Raises
    ------
    ValueError
        if ``mfile`` does not contain ``cal.fits``; the output name would then
        equal the input name and the input would be overwritten.
    """
    nfile = mfile.replace("cal.fits", "cccal.fits")
    if nfile == mfile:
        raise ValueError(f"expected a *cal.fits file, got {mfile!r}")

    # Create kernel
    g = Gaussian1DKernel(stddev=65)

    # read in the final cal image; the context manager closes the file
    with datamodels.open(mfile) as rdata:
        rimage = rdata.data.copy()

        # use the cal file dq flags as only after flat fielding are the
        # outside the FOV regions flagged
        bdata = (rdata.dq & dqflags.pixel["DO_NOT_USE"]) > 0

        # mask all the do_not_use data with NaNs (np.nan: the np.NaN alias
        # was removed in NumPy 2.0)
        rimage[bdata] = np.nan
        # mask data above a threshold (documented option, previously ignored)
        if exclude_above is not None:
            rimage[rimage > exclude_above] = np.nan
        # exclude that bright column near the right edge
        rimage[:, 1024:] = 0
        # compute the median of each column
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
            colmeds = np.nanmedian(rimage, axis=0)
        # create a smoothed version to avoid removing large scale structure
        colmeds_smooth = convolve(colmeds - np.nanmedian(colmeds), g)
        # remove large scale structure from column medians
        colmeds_sub = colmeds - colmeds_smooth
        # make the 2D image version: every row is a copy of the column residuals
        colimage = np.tile(colmeds_sub, (rimage.shape[0], 1))
        # NaN all the no data pixels so they are not included in the median
        colimage[bdata] = np.nan
        # subtract the median as we only want to remove residuals
        colimage -= np.nanmedian(colimage)
        # zero all the no data pixels
        colimage[bdata] = 0.0

        rdata.data -= colimage

        # save the column-cleaned cal image
        rdata.save(nfile)

    return nfile


def cal_row_clean(mfile, exclude_above=None):
    """
    Remove the median of each row to suppress residual detector artifacts.

    Works on column-cleaned cal images: the input name must contain
    ``cccal.fits`` (the suffix written by ``cal_column_clean``). The row
    medians are high-pass filtered (a smoothed copy is subtracted) so that only
    row-to-row residuals are removed. The corrected image is written with
    ``cccal.fits`` in the name replaced by ``cccrcal.fits``.

    Parameters
    ----------
    mfile : str
        filename with a column-cleaned MIRI cal image (i.e., xxx_cccal.fits)
    exclude_above : float, optional
        value above which to exclude data from calculating the row median.
        ``None`` (default) keeps all data.

    Returns
    -------
    str
        filename of the de-striped image that was written

    Raises
    ------
    ValueError
        if ``mfile`` does not contain ``cccal.fits``; the output name would then
        equal the input name and the input would be overwritten.
    """
    nfile = mfile.replace("cccal.fits", "cccrcal.fits")
    if nfile == mfile:
        raise ValueError(f"expected a *cccal.fits file, got {mfile!r}")

    # Create kernel
    g = Gaussian1DKernel(stddev=330)

    # read in the column-cleaned cal image; the context manager closes the file
    with datamodels.open(mfile) as rdata:
        # use the cal file dq flags as only after flat fielding are the
        # outside the FOV regions flagged
        bdata = (rdata.dq & dqflags.pixel["DO_NOT_USE"]) > 0

        rimage = rdata.data.copy()

        # mask all the do_not_use data with NaNs (np.nan: the np.NaN alias
        # was removed in NumPy 2.0)
        rimage[bdata] = np.nan
        # also remove zeros due to 2nd+ integration bug
        rimage[rimage == 0.0] = np.nan
        # mask data above a threshold
        if exclude_above is not None:
            rimage[rimage > exclude_above] = np.nan
        # exclude everything to the left of the imager FOV (basically the Lyot)
        rimage[:, 0:325] = np.nan
        # exclude that bright column near the right edge
        rimage[:, 1024:] = 0
        # compute the median of each row
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore", message="All-NaN slice encountered")
            rowmeds = np.nanmedian(rimage, axis=1)
        # create a smoothed version to avoid removing large scale structure
        rowmeds_smooth = convolve(rowmeds - np.nanmedian(rowmeds), g)
        # remove large scale structure from row medians
        rowmeds_sub = rowmeds - rowmeds_smooth
        # make the 2D image version: every column is a copy of the row residuals
        rowimage = np.tile(rowmeds_sub[:, np.newaxis], (1, rimage.shape[1]))
        # NaN all the no data pixels so they are not included in the median
        rowimage[bdata] = np.nan
        # subtract the median as we only want to remove residuals
        rowimage -= np.nanmedian(rowimage)
        # zero all the no data pixels
        rowimage[bdata] = 0.0

        rdata.data -= rowimage

        # save the row-cleaned cal image
        rdata.save(nfile)

    return nfile


In [None]:
output_dir = "./stage3/"
# One output directory per calwebb_image3 step.
output_dir_tweakreg = f"{output_dir}tweakreg/"
output_dir_skymatch = f"{output_dir}skymatch/"
output_dir_outlier_detection = f"{output_dir}outlier_detection/"
output_dir_resample = f"{output_dir}resample/"
output_dir_source_catalog = f"{output_dir}source_catalog/"

# Create the stage-3 output tree (top-level directory first, then one per step).
for _stage3_dir in (
    output_dir,
    output_dir_tweakreg,
    output_dir_skymatch,
    output_dir_outlier_detection,
    output_dir_resample,
    output_dir_source_catalog,
):
    os.makedirs(_stage3_dir, exist_ok=True)

In [None]:
input_dir = "./stage2/"  # where calwebb_image2 wrote its *_cal.fits products

In [None]:
# Copy the stage-2 cal files into the stage-3 directory so de-striping and the
# association work on copies, then list them. `glob` is imported as a module,
# so the function must be called as glob.glob (glob(...) raises TypeError).
for cal_file in glob.glob(input_dir + '*_cal.fits'):
    shutil.copy(cal_file, output_dir)

list_files = sorted(glob.glob(output_dir + '*_cal.fits'))


In [None]:
# De-stripe: first remove column-median residuals from each cal file (writes
# *_cccal.fits), then remove row-median residuals from those outputs (writes
# *_cccrcal.fits). glob is a module here, so call glob.glob.
for cal_file in list_files:
    cal_column_clean(cal_file)

for cccal_file in sorted(glob.glob(output_dir + '*cccal.fits')):
    cal_row_clean(cccal_file)

In [None]:
# Build the level-3 association file from the de-striped *_cccrcal.fits files.
# If no de-striping was done, change the pattern below to '*_cal.fits'.
# An argument list (no shell) plus check=True makes a failing asn_from_list
# stop the notebook instead of leaving a missing or stale association file.
asn_file = 'level3_asn.json'
level3_inputs = sorted(glob.glob(output_dir + '*_cccrcal.fits'))
subprocess.run(
    ['asn_from_list', '-o', asn_file, '--product-name', 'l3_results', *level3_inputs],
    check=True,
)

In [None]:
# Per-step overrides handed to Image3Pipeline via ``steps=``; each step saves
# its intermediate product into its own directory.
parameter_dict = {
    # kernel_fwhm can also be added to tweakreg for better centroiding
    "tweakreg": dict(
        snr_threshold=10.0,
        brightest=100,
        output_dir=output_dir_tweakreg,
        save_results=True,
        abs_refcat='GAIA_DR3',
    ),
    "skymatch": dict(
        skip=False,
        skymethod='local',
        output_dir=output_dir_skymatch,
        save_results=True,
        subtract=True,
        match_down=True,
    ),
    "outlier_detection": dict(
        output_dir=output_dir_outlier_detection,
        save_results=True,
    ),
    # pixfrac can be higher if one wants lesser resolution
    "resample": dict(
        pixfrac=1.0 / 4,
        output_dir=output_dir_resample,
        save_results=True,
    ),
    "source_catalog": dict(
        snr_threshold=10.0,
        output_dir=output_dir_source_catalog,
        save_results=True,
    ),
}

In [None]:
# Run calwebb_image3 on the level-3 association with the overrides above.
call_output = calwebb_image3.Image3Pipeline.call(
    asn_file,
    output_dir=output_dir,
    save_results=True,
    steps=parameter_dict,
    logcfg='stage3-log.cfg',
)

# Optional clean-up of the copied cal files once stage 3 has finished:
# os.system('rm '+output_dir+'*_cal.fits')

We encountered an oblique cosmic-ray (CR) hit in observations of SGAS 1226, so we created a custom function that flags the pixels inside a DS9 region file as DO_NOT_USE (DQ value 1). The function is defined below.

In [None]:
# Function to remove a cosmic-ray artifact from filters affected by the CR hit;
# inspired by the column and row de-trending functions.
def cal_cr_remove(mfile, regfile):
    """
    Flag a cosmic-ray artifact region as DO_NOT_USE in a MIRI cal image.

    Intended for the artifact where a previous cosmic ray makes certain pixels
    behave non-linearly. It can also flag any region(s) drawn in a DS9 region
    file. Pixels inside any region in the file get the DO_NOT_USE bit set in
    the DQ array, with other DQ bits preserved, and their data values set to 0.
    The result is written with ``cal.fits`` in the name replaced by
    ``crstreak_fixcal.fits``.

    tested to work on cal images for now

    Parameters
    ----------
    mfile : str
        filename with a MIRI cal image (i.e., xxx_cal.fits)
    regfile : str
        DS9 region filename covering the artifact. NOTE(review): ``to_mask`` is
        called directly on each region, so the regions presumably need to be in
        pixel (image) coordinates; confirm for sky-coordinate region files.

    Returns
    -------
    str
        filename of the corrected image that was written

    Raises
    ------
    ValueError
        if ``mfile`` does not contain ``cal.fits`` (the input would be
        overwritten) or if ``regfile`` contains no regions.
    """
    # This function already depends on astropy-regions (Regions.read), which
    # the notebook never imported; import it here so the call does not fail
    # with NameError.
    from regions import Regions

    nfile = mfile.replace("cal.fits", "crstreak_fixcal.fits")
    if nfile == mfile:
        raise ValueError(f"expected a *cal.fits file, got {mfile!r}")

    # read the cosmic ray region file
    cr_regions = Regions.read(regfile, format='ds9')
    if len(cr_regions) == 0:
        raise ValueError(f"no regions found in {regfile!r}")

    # read in the final image; the context manager closes the file
    with datamodels.open(mfile) as cdata:
        im_shape = cdata.data.shape
        # union of all region masks rasterized onto the image grid
        cr_mask = np.zeros(im_shape, dtype=bool)
        for region in cr_regions:
            region_im = region.to_mask().to_image(im_shape)
            # to_image may return None when the region does not overlap the image
            if region_im is not None:
                cr_mask |= region_im >= 1
        # DQ is a bit mask: set DO_NOT_USE without erasing other flags
        cdata.dq[cr_mask] |= dqflags.pixel["DO_NOT_USE"]
        # change the data values to 0 for the affected region
        cdata.data[cr_mask] = 0
        # save the corrected cal image
        cdata.save(nfile)

    return nfile
