# Collecting Images For Subtraction

This notebook demonstrates how to collect template and science images from [Rob's database](https://github.com/Roman-Supernova-PIT/roman-desc-simdex/blob/main/documentation.ipynb) for image subtraction. A more detailed guide to searching the Roman simulated images can be found [here](https://github.com/Roman-Supernova-PIT/roman-desc-simdex/blob/main/documentation.ipynb).

In [1]:
import requests
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from astropy.wcs import WCS

from astropy.coordinates import SkyCoord
from astropy.wcs.utils import skycoord_to_pixel

def get_sne(z_cmb_min=0.15, z_cmb_max=0.16, gentype=10,
            fields=('id', 'ra', 'dec', 'z_cmb'),
            server_url='https://roman-desc-simdex.lbl.gov'):
    """Query the simdex server for transients within a CMB-frame redshift range.

    Parameters
    ----------
    z_cmb_min, z_cmb_max : float
        Redshift bounds passed to the server's ``findtransients`` endpoint.
    gentype : int
        Value of the ``gentype`` path component of the query (10 in the
        original hard-coded query).
    fields : sequence of str
        Columns to request for each transient. Other names such as
        'host_sn_sep', 'peak_mjd', 'peak_mag_g', 'model_name' were listed in
        an earlier version of this query and are presumably available too --
        TODO confirm against the simdex documentation.
    server_url : str
        Base URL of the simdex server.

    Returns
    -------
    pandas.DataFrame
        One row per transient, built from the server's JSON response.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    """
    url = (f'{server_url}/findtransients/z_cmb_min={z_cmb_min}'
           f'/z_cmb_max={z_cmb_max}/gentype={gentype}')
    # Use the session as a context manager so its connection pool is released.
    with requests.Session() as session:
        result = session.post(url, json={'fields': tuple(fields)})
    if result.status_code != 200:
        raise RuntimeError( f"Got status code {result.status_code}\n{result.text}" )
    return pd.DataFrame( result.json() )

def get_image_info(ra, dec, server_url='https://roman-desc-simdex.lbl.gov'):
    """Ask the simdex server for all Roman images whose footprint contains (ra, dec).

    Parameters
    ----------
    ra, dec : float
        Sky position sent to the ``findromanimages`` endpoint as ``containing``.
        Presumably in degrees -- TODO confirm against the simdex documentation.
    server_url : str
        Base URL of the simdex server.

    Returns
    -------
    pandas.DataFrame
        One row per image, built from the server's JSON response. The rest of
        this notebook reads the 'filter', 'mjd', 'pointing' and 'sca' columns.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    """
    # Use the session as a context manager so its connection pool is released.
    with requests.Session() as session:
        result = session.post(f'{server_url}/findromanimages',
                              json={'containing': [ra, dec]})
    if result.status_code != 200:
        raise RuntimeError( f"Got status code {result.status_code}\n{result.text}" )
    return pd.DataFrame( result.json() )

def load_wcs(image_path, hdu_id=0):
    """Build an astropy WCS from the header of one HDU of a FITS file.

    Parameters
    ----------
    image_path : str
        Path to the FITS file (astropy also opens gzipped files).
    hdu_id : int
        Index of the HDU whose header holds the WCS keywords.

    Returns
    -------
    astropy.wcs.WCS
    """
    # The WCS is built before the file is closed on leaving the ``with``.
    with fits.open(image_path) as hdu_list:
        return WCS(hdu_list[hdu_id].header)

def load_table(table_path):
    """Read an ASCII table with astropy and return it as a pandas DataFrame.

    Parameters
    ----------
    table_path : str
        Path to a text table readable by ``Table.read(..., format='ascii')``.

    Returns
    -------
    pandas.DataFrame
    """
    ascii_table = Table.read(table_path, format='ascii')
    return ascii_table.to_pandas()

def radec_to_xy(ra, dec, wcs, origin=0):
    """Project sky positions onto the pixel grid described by ``wcs``.

    Parameters
    ----------
    ra, dec : float or array-like
        ICRS right ascension and declination, in degrees.
    wcs : astropy.wcs.WCS
        World-coordinate solution of the target image.
    origin : int
        Pixel origin forwarded to ``skycoord_to_pixel`` (0 = 0-based,
        1 = 1-based FITS convention).

    Returns
    -------
    tuple
        ``(x, y)`` pixel coordinates.
    """
    positions = SkyCoord(ra, dec, frame='icrs', unit='deg')
    x_pix, y_pix = skycoord_to_pixel(coords=positions, wcs=wcs, origin=origin)
    return x_pix, y_pix

def xy_in_image(x, y, width, height, offset=0):
    """Element-wise test of whether pixel positions lie inside an image.

    A position passes when ``offset <= x < width - offset`` and
    ``offset <= y < height - offset``; ``offset`` shrinks the accepted box on
    every side.

    Parameters
    ----------
    x, y : scalar, numpy array or pandas Series
        Pixel coordinates.
    width, height : number
        Image size in pixels.
    offset : number
        Margin excluded from each edge.

    Returns
    -------
    bool or boolean array/Series matching the inputs.
    """
    # Use bitwise ``&`` (not ``and``) so array inputs are combined element-wise.
    x_ok = (x >= offset) & (x < width - offset)
    y_ok = (y >= offset) & (y < height - offset)
    return x_ok & y_ok

## Get information about the SNe

In [2]:
# Every simulated SN (as returned by get_sne) with CMB-frame redshift between 0.10 and 0.17.
sne = get_sne(z_cmb_min=0.10, z_cmb_max=0.17)
n_sne = len(sne)
print(f'Found {n_sne} SNe.')

Found 277 SNe.


## Pick one SN

In [3]:
band = 'R062'
# An SN qualifies only if it has at least this many images *in `band`*: the
# earliest becomes the template and the remaining ones are candidate science images.
MIN_IMAGES_IN_BAND = 11

for idx in range(len(sne)):
    sn = sne.iloc[idx]
    all_images = get_image_info(sn.ra, sn.dec)
    if all_images.empty:
        # No images contain this SN (and the frame has no 'mjd'/'filter' columns).
        continue
    all_images = all_images.sort_values(by=['mjd'])
    # Filter to `band` *before* counting: only `band` images are used later, so
    # counting images in every filter could accept an SN with too few usable ones.
    band_images = all_images[all_images['filter'] == band]
    if len(band_images) < MIN_IMAGES_IN_BAND:
        continue
    image_info = band_images.reset_index(drop=True)
    break
else:
    # Without this, `image_info` would silently hold the last SN's unfiltered images.
    raise RuntimeError(f'No SN has at least {MIN_IMAGES_IN_BAND} images in band {band}.')


We pick the earliest image as the template. For the science images, we go through the remaining images in reverse chronological order (latest first, moving back toward the template) and keep up to 10 of them. A science image is kept only if more than 50% of the truth-table sources that lie inside it also fall inside the template image's footprint. This ensures that both images share a sufficiently large overlapping region.

In [4]:
INPUT_IMAGE_PATTERN = ("/global/cfs/cdirs/lsst/shared/external/roman-desc-sims/Roman_data"
                       "/RomanTDS/images/simple_model/{band}/{pointing}/Roman_TDS_simple_model_{band}_{pointing}_{sca}.fits.gz")
INPUT_TRUTH_PATTERN = ("/global/cfs/cdirs/lsst/shared/external/roman-desc-sims/Roman_data"
                       "/RomanTDS/truth/{band}/{pointing}/Roman_TDS_index_{band}_{pointing}_{sca}.txt")

MATCH_RADIUS = 0.4
image_width = 4088
image_height = 4088
# Keep at most this many science images.
N_SCIENCE_IMAGES = 10
# Required fraction (strictly greater) of a science image's in-image truth
# sources that must land inside the template footprint.
MIN_OVERLAP_FRACTION = 0.5

selected_rows = []
template_info = image_info.iloc[0]

# The template does not change between iterations, so open its (gzipped) FITS
# file and build its WCS once instead of once per candidate science image.
template_id = {'band': band, 'pointing': int(template_info['pointing']), 'sca': int(template_info['sca'])}
template_image_path = INPUT_IMAGE_PATTERN.format(**template_id)
template_wcs = load_wcs(template_image_path, hdu_id=1)

# Walk from the latest image back toward the template (index 0 is excluded).
for i in range(len(image_info) - 1, 0, -1):
    science_info = image_info.iloc[i]
    science_id = {'band': band, 'pointing': int(science_info['pointing']), 'sca': int(science_info['sca'])}

    science_truth_path = INPUT_TRUTH_PATTERN.format(**science_id)
    science_truth = load_table(science_truth_path)

    # Some sources from the truth table are not in the image. We need to remove them.
    science_in_science = xy_in_image(science_truth.x, science_truth.y, width=image_width, height=image_height, offset=0)
    science_truth = science_truth[science_in_science].reset_index(drop=True)
    if science_truth.empty:
        # No in-image sources: the overlap fraction would be 0/0.
        continue

    # NOTE(review): origin=1 yields 1-based pixel coordinates, while xy_in_image
    # checks 0-based bounds [0, width); confirm this matches the convention of
    # the truth table's x/y columns used above.
    x_in_template, y_in_template = radec_to_xy(science_truth.ra, science_truth.dec, template_wcs, origin=1)
    science_in_template = xy_in_image(x_in_template, y_in_template, width=image_width, height=image_height, offset=0)

    # Mean of a boolean mask = fraction of sources inside the template.
    if science_in_template.mean() > MIN_OVERLAP_FRACTION:
        selected_rows.append(i)
        if len(selected_rows) == N_SCIENCE_IMAGES:
            break



In [5]:
if not selected_rows:
    raise RuntimeError('No science image overlaps the template enough; nothing to write.')

# image_info is sorted by mjd, so ascending row indices give chronological order.
selected_rows.sort()
selected_image_info = image_info.iloc[selected_rows]

# Size the constant columns from the actual selection instead of assuming
# exactly 10 rows (fewer may pass the overlap cut).
n_selected = len(selected_rows)
data_records = pd.DataFrame({
    'template_band': [band] * n_selected,
    'template_pointing': [template_info.pointing] * n_selected,
    'template_sca': [template_info.sca] * n_selected,
    'science_band': [band] * n_selected,
    'science_pointing': selected_image_info.pointing.values,
    'science_sca': selected_image_info.sca.values,
})
data_records.to_csv('../test/test_data_records.csv', index=False)