# Background Subtraction and Stage 3 Processing

- Author: J. Aguilar (jaguilar@stsci.edu)
- Date: Aug 9, 2022


The presence of "glow sticks" in the MIRI coronagraph means that, for the time being, users must acquire background observations and subtract them from their science observations in order to perform accurate PSF subtraction. This notebook will guide users through the process of:
- Identifying which files are background observations
- Combining and subtracting the backgrounds from the science observations
- Running the new background-subtracted images through Stage 3 of the JWST calibration pipeline by creating a new asn file

## Data
The ERS data used in this tutorial are available on MAST (mast.stsci.edu). The following link provides a shortcut: https://mast.stsci.edu/portal/Mashup/Clients/Mast/Portal.html?searchQuery=%7B%22service%22%3A%22CAOMBYPROPID%22%2C%22inputText%22%3A%5B%7B%22paramName%22%3A%22proposal_id%22%2C%22niceName%22%3A%22proposal_id%22%2C%22values%22%3A%5B%221386%22%5D%2C%22valString%22%3A%221386%22%2C%22displayString%22%3A%221386%22%2C%22isDate%22%3Afalse%2C%22facetType%22%3A%22discrete%22%7D%5D%2C%22paramsService%22%3A%22Mast.Caom.Filtered%22%2C%22title%22%3A%22Proposal%20ID%20Results%22%2C%22columns%22%3A%22*%22%2C%22caomVersion%22%3Anull%7D

## Non-standard lib requirements
- `jwst` https://jwst-pipeline.readthedocs.io/
- `astropy` https://docs.astropy.org/

In [None]:
import jwst
print("Using pipeline version:", jwst.__version__)
from jwst.pipeline import Coron3Pipeline

In [None]:
from importlib import reload
from pathlib import Path
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from astropy.io import fits

# Load the observations

We're going to assume that there are `calints.fits` (aka Stage 2b) files available from MAST, but the backgrounds didn't get subtracted properly. If you only have `rateints.fits` (Stage 2a) or `uncal.fits` (Stage 1) files, refer to the pipeline documentation for how to process them up to stage 3.

Set the variable `datapath` to wherever you downloaded the MAST data. I assume all the `.fits` files are in a single folder. If you have it set up differently, the important thing is that the variable `data_files` is a list of paths to the `calints.fits` files.

In [None]:
datapath = Path("/Users/jaguilar/Projects/jwst_programs/mast_data/01386-2/01386/")
data_files = sorted(datapath.glob("j*calints.fits"))

In [None]:
# enforce that all the members of data_files are pathlib.Paths
for i, f in enumerate(data_files):
    data_files[i] = Path(f)
data_files

## Organize the science and background observations so you can subtract the appropriate ones from each other.

We're going to use the observation number as the key for associating the background observations with their corresponding science observations.

We'll store everything in dictionaries and use the file names and observation numbers to look things up.

Files that have the same observation number are different dithers from the same observation.

In [None]:
# index them by the observation number
obsnum_filenames = {f.name: Path(f).name[7:10] for f in data_files}
obsnum_filenames
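
The slice `[7:10]` relies on the standard JWST exposure naming scheme, `jw<ppppp><ooo><vvv>_...`, in which the three-digit observation number follows the five-digit program ID. Below is a minimal sketch of a more explicit parser, assuming that naming scheme. The helper name `parse_obsnum` and the example filename are illustrative and are not part of the `jwst` package.

```python
import re

# "jw" + 5-digit program ID + 3-digit observation number + 3-digit visit number
_JWST_NAME = re.compile(r"^jw(?P<program>\d{5})(?P<obs>\d{3})(?P<visit>\d{3})_")

def parse_obsnum(filename):
    """Return the 3-digit observation number from a JWST exposure filename."""
    match = _JWST_NAME.match(filename)
    if match is None:
        raise ValueError(f"Not a JWST exposure filename: {filename}")
    return match.group("obs")

# hypothetical filename, for illustration only
example = "jw01386019001_04101_00001_mirimage_calints.fits"
print(parse_obsnum(example))
```

Unlike a bare slice, this raises an error on files that do not follow the expected pattern instead of silently returning the wrong characters.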

# Match the background files

Connect the science observations (target and reference) with their corresponding background observations, using the observation numbers from the APT file. The result is a dict of `{sci: bgnd}` key-value pairs.

In [None]:
# Get the background exposures - there's one of each. 
# If science exposures have more than one associated background exposure, combine the backgrounds using min or median before subtracting
bgnd_obsnums = ['032', '033', '034', '035', '036','037']
bgnd_files  = {obsnum: [] for obsnum in bgnd_obsnums}
for f, num in obsnum_filenames.items():
    if num in bgnd_obsnums:
        bgnd_files[num].append(f)
    else:
        pass
bgnd_files

In [None]:
# combine the background images
bgnd_imgs = {}
for num, files in bgnd_files.items():
    img = np.stack([fits.getdata(datapath / f, 1) for f in files])
    
    # define the function you will use to combine the images. Here we use a simple one
    func = np.nanmedian if len(files) > 2 else np.nanmin
    img = func(img, axis=0)
    
    # flatten the background image until it's 2-D
    while np.ndim(img) > 2:
        img = np.nanmean(img, axis=0)
    
    bgnd_imgs[num] = img

In [None]:
bgnd_imgs.keys()
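
To make the array shapes in the combination step concrete, here is a small sketch that applies the same logic to synthetic arrays. The function name `combine_backgrounds` and the toy data are assumptions made for illustration; real `calints` data will be much larger.

```python
import numpy as np

def combine_backgrounds(cubes):
    """Combine background cubes into a single 2-D image.

    Mirrors the cell above: min for one or two exposures, median for more,
    then an average over any remaining leading axes (e.g. integrations).
    """
    stack = np.stack(cubes)                      # (n_files, n_ints, ny, nx)
    func = np.nanmedian if len(cubes) > 2 else np.nanmin
    img = func(stack, axis=0)                    # (n_ints, ny, nx)
    while img.ndim > 2:
        img = np.nanmean(img, axis=0)            # (ny, nx)
    return img

# synthetic stand-ins for two background exposures with 4 integrations each
bg_a = np.full((4, 3, 3), 2.0)
bg_b = np.full((4, 3, 3), 5.0)
bg_b[0, 0, 0] = np.nan                           # a NaN pixel is ignored, not propagated

combined = combine_backgrounds([bg_a, bg_b])
print(combined.shape)
```

Note that `np.stack` requires every exposure to have the same shape, so backgrounds with different numbers of integrations would need to be collapsed individually before stacking.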

In [None]:
bgnd_matched_obsnums = {
    # sci obs num : bgnd obs num
    '019': '032',
    '020': '032',
    '021': '033',
    '022': '034',
    '023': '034',
    '024': '035',
    '025': '035',
    '026': '036', 
    '027': '037',
}
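
Because this mapping is typed in by hand from the APT file, it can be worth checking that every background number it references is one that was actually loaded. Here is a small sketch of such a check. The function name `find_unmatched` is hypothetical, and the dictionaries are abbreviated stand-ins for the ones above.

```python
def find_unmatched(sci_to_bgnd, available_bgnds):
    """Return the science obsnums whose background obsnum is not available."""
    available = set(available_bgnds)
    return sorted(sci for sci, bgnd in sci_to_bgnd.items() if bgnd not in available)

# '099' is a deliberate typo to show what a failed match looks like
mapping = {'019': '032', '020': '032', '021': '033', '022': '099'}
loaded = ['032', '033', '034']
print(find_unmatched(mapping, loaded))
```

In the notebook, the equivalent call would be `find_unmatched(bgnd_matched_obsnums, bgnd_imgs.keys())`, and an empty list means every science observation has a background.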

In [None]:
# Match the science images with their corresponding background images.
# This convoluted inline expression just indexes all the dictionaries we've set up so far.
# In the end we're left with a dictionary containing each matched pair. 
# The pair is stored in a dict where the science observation has key 'sci', 
# and the background observation has key 'bgnd'
bgnd_matched_imgs = {sci_file: {'sci': fits.getdata(datapath / sci_file, 1), 'bgnd': bgnd_imgs[bgnd_matched_obsnums[sci_obsnum]]} 
                      for sci_file, sci_obsnum in obsnum_filenames.items() 
                      if sci_obsnum in bgnd_matched_obsnums.keys()
                    }

# Subtract the backgrounds

Now that you have matched each science file with its background (and combined backgrounds if necessary), you can subtract each background from its science image.

In [None]:
bgnd_sub_imgs = {} # this will be indexed by the original filename
for sci_file, pair in bgnd_matched_imgs.items():
    sci_img = pair['sci']
    bgnd_img = pair['bgnd']
    # Subtract the background off. Since the bgnd is 2-D, the array shapes should automatically broadcast
    img = sci_img - bgnd_img
    bgnd_sub_imgs[sci_file] = img
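
The subtraction relies on NumPy broadcasting: a 2-D background of shape `(ny, nx)` is subtracted from every integration of a 3-D science cube of shape `(n_ints, ny, nx)`. The sketch below demonstrates this on synthetic arrays and adds an explicit shape check. The helper name `subtract_background` is an assumption for illustration.

```python
import numpy as np

def subtract_background(sci_cube, bgnd_img):
    """Subtract a 2-D background from a 2-D or 3-D science array."""
    if sci_cube.shape[-2:] != bgnd_img.shape:
        raise ValueError(
            f"Shape mismatch: science {sci_cube.shape[-2:]} vs background {bgnd_img.shape}"
        )
    return sci_cube - bgnd_img   # broadcasts over the leading (integration) axis

sci = np.full((5, 4, 4), 10.0)   # 5 integrations of a 4x4 image
bgnd = np.full((4, 4), 3.0)      # one 2-D background
result = subtract_background(sci, bgnd)
print(result.shape, result[0, 0, 0])
```

The explicit check is useful because a science and background pair taken with different subarrays would otherwise fail with a less informative broadcasting error.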

## Preview the images

In [None]:
# Some observations have multiple dithers
ncols = 3
nrows = len(bgnd_matched_imgs)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6*ncols, 6*nrows))

for i, (sci_file, pair) in enumerate(bgnd_matched_imgs.items()):
    obsnum = sci_file[7:10]

    # pull the images
    sci_img = pair['sci']
    bgnd_img = pair['bgnd']
    bgsub_img = bgnd_sub_imgs[sci_file]
    # reduce dimensions for plotting
    while np.ndim(sci_img) > 2:
        sci_img = np.nanmedian(sci_img, axis=0)
    while np.ndim(bgnd_img) > 2:
        bgnd_img = np.nanmedian(bgnd_img, axis=0)
    while np.ndim(bgsub_img) > 2:
        bgsub_img = np.nanmedian(bgsub_img, axis=0)

    for img, ax in zip((sci_img, bgnd_img, bgsub_img), axes[i]):
        vmin, vmax = np.nanquantile(img, [0.01, 0.99])  # ignore NaN pixels when setting the stretch
        ax.imshow(img, vmin=vmin, vmax=vmax, origin='lower', cmap=mpl.cm.magma)
        ax.set_aspect('equal')

    axes[i, 0].set_ylabel(f"Obs {obsnum}", size='x-large')
    axes[i, 1].set_ylabel(f"Obs {bgnd_matched_obsnums[obsnum]}", size='x-large')

axes[0][0].set_title("Science", size='x-large')
axes[0][1].set_title("Background", size='x-large');
axes[0][2].set_title("Background Removed", size='x-large');

## Write images to file

You can either replace the SCI HDU data, or write a new FITS file. It's important to remember to preserve the ASDF extension because that contains the WCS information used by the pipeline for alignment. In this example, we are going to write new files with `_bgsub` appended to the end of the filename.

The target folder is `'../bgnd_sub_imgs/'` but you can set it to anything you like.

In [None]:
bgsub_path = Path("../bgnd_sub_imgs/")
bgsub_path.mkdir(parents=True, exist_ok=True)  # create the output folder if needed

def write_bgsub(sci_file, img, path):
    """Write the background-subtracted version of a file to disk"""
    sci_file = Path(sci_file)
    bgsub_name = Path(path) / (sci_file.stem + '_bgsub' + sci_file.suffix)
    with fits.open(sci_file) as hdulist:
        # copy every HDU (including ASDF) so the WCS information is preserved
        hdus = [hdu.copy() for hdu in hdulist]
    hdus[1].data = img  # replace the SCI data with the background-subtracted image
    bgsub_hdulist = fits.HDUList(hdus)
    bgsub_hdulist.writeto(bgsub_name, overwrite=True)
    print("Wrote", bgsub_name)

for sci_file, img in bgnd_sub_imgs.items():
    write_bgsub(datapath / sci_file, img, bgsub_path)

# Stage 3

Replace the filenames in the Stage 3 association file with the background-subtracted filenames, or write your own Stage 3 association file -- which is what we will do here -- and run Stage 3.

In [None]:
def write_dummy_asn(filename, name, filedict):
    """
    Write a dummy ASN file for manually processing files through Stage 3

    Parameters
    ----------
    filename: str or pathlib.Path
      path of the association file to write; the suffix is forced to .json
    name: str
      prefix for the Stage 3 output files
    filedict: dict
      dict where each key is the path to an exposure file (as it should appear
      in the association) and the value is its type, "science" or "psf"

    Output
    ------
    association file written to given location

    """
    # make the specifications for the image files
    def make_entries(filedict, ntabs=3):
        tab='\t'
        line = lambda key, val: f"{tab*(ntabs+1)}\"expname\": \"{key}\",\n{tab*(ntabs+1)}\"exptype\": \"{val}\""
        lines = f"\n{tab*ntabs}}},{{\n".join(line(key, val) for key, val in filedict.items())
        return lines
    file_str = make_entries(filedict, 5)

    template = f"""{{
    "asn_type": "coron3",
    "asn_rule": "candidate_Asn_Lv3Coron",
    "program": "{name}",
    "asn_id": "c1001",
    "target": "dummy",
    "asn_pool": "{name}-pool",
    "products": [{{
        "name": "{name}",
        "members": [{{
    {file_str}
    }}]
    }}]
    }}"""
    # make sure the file ends in .json
    filename = Path(filename).with_suffix(".json")
    with open(str(filename), 'w') as ff:
        ff.write(template)
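
Since the association file is plain JSON, the same structure can also be produced with Python's standard `json` module, which takes care of quoting and bracket matching. The sketch below mirrors the fields of the template above. The function name `build_asn_dict`, the product name `"example"`, and the file paths are placeholders for illustration.

```python
import json

def build_asn_dict(name, filedict):
    """Build a coron3-style association as a plain dict."""
    members = [
        {"expname": str(path), "exptype": exptype}
        for path, exptype in filedict.items()
    ]
    return {
        "asn_type": "coron3",
        "asn_rule": "candidate_Asn_Lv3Coron",
        "program": name,
        "asn_id": "c1001",
        "target": "dummy",
        "asn_pool": f"{name}-pool",
        "products": [{"name": name, "members": members}],
    }

# placeholder paths, for illustration only
files = {
    "../bgnd_sub_imgs/sci_bgsub.fits": "science",
    "../bgnd_sub_imgs/ref_bgsub.fits": "psf",
}
asn = build_asn_dict("example", files)
text = json.dumps(asn, indent=4)
print(text)
```

Writing `text` to a `.json` file gives a result equivalent to the string template, and `json.loads(text)` can be used to confirm that the output parses.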

In [None]:
files = sorted(bgsub_path.glob("jw*calints_bgsub.fits"))
obsnum_files = {f.name: f.name[7:10] for f in files}
obsnum_files

In [None]:
# output directory
output_folder = "pipeline_output/"

## 1065

Use the APT file to figure out which observation numbers correspond to the science target and which to the reference PSF target.
For the F1065C observations, Observations 19 and 20 are the science, and 21 is the reference.

In [None]:
sci = {f: obsnum for f, obsnum in obsnum_files.items() if obsnum in ['019','020']}
ref = {f: obsnum for f, obsnum in obsnum_files.items() if obsnum == '021'}

In [None]:
sci

In [None]:
ref

Copy the above filenames into a new association file and write it to disk

In [None]:
asn_file = "./stage3_asn_hd141569a_1065.json"

# list the files and whether they are science or psf
files = {}
for f in sci.keys():
    files["../bgnd_sub_imgs/" + f] = 'science'
for f in ref.keys():
    files["../bgnd_sub_imgs/" + f] = 'psf'
write_dummy_asn(asn_file, "01386_hd141569a_1065", files)
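
Before running the pipeline, it can save time to confirm that every file listed in the association actually exists. The sketch below is a hedged example of such a check. It assumes that relative `expname` entries are resolved from the folder containing the association file, which is worth confirming against the pipeline documentation. The helper name `missing_members` and the demo files are hypothetical.

```python
import json
import tempfile
from pathlib import Path

def missing_members(asn_path):
    """Return the expnames in an association file that do not exist on disk."""
    asn_path = Path(asn_path)
    with open(asn_path) as ff:
        asn = json.load(ff)
    missing = []
    for product in asn["products"]:
        for member in product["members"]:
            # assumption: relative expnames are resolved from the asn file's folder
            if not (asn_path.parent / member["expname"]).exists():
                missing.append(member["expname"])
    return missing

# self-contained demo with throwaway files
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    (tmp / "present.fits").touch()
    demo_asn = {"products": [{"members": [
        {"expname": "present.fits", "exptype": "science"},
        {"expname": "absent.fits", "exptype": "psf"},
    ]}]}
    (tmp / "demo_asn.json").write_text(json.dumps(demo_asn))
    demo_missing = missing_members(tmp / "demo_asn.json")
print(demo_missing)
```

In the notebook, the equivalent call would be `missing_members(asn_file)`, and an empty list means all members were found.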

Run Stage 3 with a hand-made association file. 

In [None]:
# it's a disk target so we don't want to use many KLIP modes - set it to 1. Default is 50
params = {'output_dir': "../pipeline_output/", # default is '.'
          'steps': { # this is optional
              'klip': {'truncate': 1}
                   }
         }
cor3 = Coron3Pipeline.call(str(asn_file), **params)


In [None]:
img = fits.getdata("../pipeline_output/01386_hd141569a_1065_i2d.fits")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8,8))
vmin, vmax = np.nanquantile(img, [0.01, 0.99])  # ignore NaN pixels when setting the stretch
ax.imshow(img, origin='lower', vmin=vmin, vmax=vmax)

## 1140
For F1140C, Observations 22 and 23 are the science, and 24 is the reference.

In [None]:
sci = {f: obsnum for f, obsnum in obsnum_files.items() if obsnum in ['022','023']}
ref = {f: obsnum for f, obsnum in obsnum_files.items() if obsnum == '024'}

In [None]:
sci

In [None]:
ref

Copy the above filenames into a new association file

In [None]:
# list the files and whether they are science or psf
files = {}
for f in sci.keys():
    files["../bgnd_sub_imgs/" + f] = 'science'
for f in ref.keys():
    files["../bgnd_sub_imgs/" + f] = 'psf'
files

In [None]:
# write the asn file
asn_file = "./stage3_asn_hd141569a_1140.json"

write_dummy_asn(asn_file, "01386_hd141569a_1140", files)

In [None]:
# it's a disk target so we don't want to use many KLIP modes - set it to 1. Default is 50
params = {'output_dir': "../pipeline_output/", # default is '.'
          'steps': { # this is optional
              'klip': {'truncate': 1}
                   }
         }
cor3 = Coron3Pipeline.call(str(asn_file), **params)


In [None]:
img = fits.getdata("../pipeline_output/01386_hd141569a_1140_i2d.fits")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8,8))
vmin, vmax = np.nanquantile(img, [0.01, 0.99])  # ignore NaN pixels when setting the stretch
ax.imshow(img, origin='lower', vmin=vmin, vmax=vmax)