# Obtaining DIA cutouts from DP1

Example tutorial on DIA sources: [link](https://dp1.lsst.io/tutorials/notebook/201/notebook-201-5.html)

Example tutorial on ECDFS: [link](https://dp1.lsst.io/tutorials/notebook/301/notebook-301-4.html)

Schema: [link](https://sdm-schemas.lsst.io/)

Example tutorial on cutout triplet: [link](https://dp1.lsst.io/tutorials/notebook/306/notebook-306-2.html)

In [None]:
from lsst.rsp import get_tap_service
from lsst.rsp.service import get_siav2_service
from lsst.rsp.utils import get_pyvo_auth

import lsst.afw.display as afwDisplay
from lsst.afw.image import ExposureF
from lsst.afw.math import Warper, WarperConfig
from lsst.afw.fits import MemFileManager
import lsst.geom as geom

from pyvo.dal.adhoc import DatalinkResults, SodaQuery

import numpy as np
import matplotlib.pyplot as plt
from astropy import units as u

%matplotlib inline

In [None]:
# Catalog queries go through TAP; image discovery goes through SIAv2.
# Explicit checks (rather than bare asserts) give a readable message and are
# not stripped when Python runs with -O.
service = get_tap_service("tap")
if service is None:
    raise RuntimeError("Could not obtain the TAP service 'tap'")

sia_service = get_siav2_service("dp1")
if sia_service is None:
    raise RuntimeError("Could not obtain the SIAv2 service 'dp1'")

# Route afwDisplay output through matplotlib so it renders inline.
afwDisplay.setDefaultBackend('matplotlib')

## ECDFS

Coordinates (RA, Dec in degrees) of the ECDFS (Extended Chandra Deep Field South) field center.

In [None]:
# ECDFS field center (degrees, as used by the ADQL CIRCLE below).
ra_cen = 53.16
dec_cen = -28.10

In [None]:
# Cone search (radius in degrees) on the DP1 DiaSource table around the
# ECDFS center, returning photometry, shape, dipole and quality columns.
query_radius_deg = 1.0
query = f"""SELECT diaSourceId, band, visit, midpointMjdTai,
        ra, dec, raErr, decErr, ra_dec_Cov,
        x, y, xErr, yErr, centroid_flag,
        psfFlux, psfFluxErr, psfFlux_flag,
        apFlux, apFluxErr, apFlux_flag, snr,
        scienceFlux, scienceFluxErr, forced_PsfFlux_flag,
        isDipole, dipoleChi2, dipoleFluxDiff, dipoleFluxDiffErr,
        dipoleLength, dipoleMeanFlux, dipoleMeanFluxErr, dipoleNdata,
        extendedness,
        ixx, iyy, ixy, shape_flag,
        ixxPSF, iyyPSF, ixyPSF, psfNdata,
        pixelFlags, reliability, trailLength, trailFlux
        FROM dp1.DiaSource
        WHERE CONTAINS(POINT('ICRS', ra, dec),
        CIRCLE('ICRS', {ra_cen}, {dec_cen}, {query_radius_deg})) = 1
        ORDER BY diaSourceId ASC"""
print(query)

# Run asynchronously and block until the job finishes one way or the other.
job = service.submit_job(query)
job.run()
job.wait(phases=['COMPLETED', 'ERROR'])
print('Job phase is', job.phase)
if job.phase == 'ERROR':
    job.raise_if_error()
assert job.phase == 'COMPLETED'

In [None]:
# Pull the completed job's rows into an astropy Table and report the count.
result_set = job.fetch_result()
results = result_set.to_table()
print(len(results))

Check the ranges of the aperture and PSF fluxes.

In [None]:
# Min and max of each flux column, in scientific notation.
for flux_col in ('apFlux', 'psfFlux'):
    flux = results[flux_col]
    print("%e, %e" % (np.min(flux), np.max(flux)))

Select sources detected at high signal-to-noise on the difference image (|flux / fluxErr| > 5 in both aperture and PSF photometry).

In [None]:
# Require |flux / fluxErr| above the threshold in BOTH aperture and PSF
# photometry (absolute value keeps negative difference-image sources).
snr_min = 5
ap_snr = np.abs(results['apFlux'] / results['apFluxErr'])
psf_snr = np.abs(results['psfFlux'] / results['psfFluxErr'])
sel_snr = (ap_snr > snr_min) & (psf_snr > snr_min)

print(np.sum(sel_snr))

Select sources that are not flagged as dipoles.

In [None]:
# True for sources whose isDipole column is not set.
is_dipole = results['isDipole']
sel_not_dipole = ~is_dipole
n_not_dipole = np.sum(sel_not_dipole)
print(n_not_dipole)

Select good sources, i.e. those with neither the aperture-flux nor the PSF-flux flag set.

In [None]:
# A source is "clean" when neither flux flag is raised:
# ~a & ~b  ==  ~(a | b).
sel_no_flag = ~(results['apFlux_flag'] | results['psfFlux_flag'])
print(np.sum(sel_no_flag))

### Check flux

In [None]:
# Split sources by the sign of each flux measurement.
# The negative masks use explicit "< 0" comparisons: negating "> 0" would
# also classify zero fluxes (and NaN fluxes, if any) as negative.
sel_pos_ap = results['apFlux'] > 0
sel_pos_psf = results['psfFlux'] > 0

sel_neg_ap = results['apFlux'] < 0
sel_neg_psf = results['psfFlux'] < 0

Check whether positive sources always have apFlux > psfFlux, defining "positive" first by apFlux and then by psfFlux.

In [None]:
# Fraction of positive sources with apFlux > psfFlux, first using apFlux to
# decide positivity, then psfFlux.
results_pos_ap, results_pos_psf = results[sel_pos_ap], results[sel_pos_psf]
for subset in (results_pos_ap, results_pos_psf):
    test = subset['apFlux'] > subset['psfFlux']
    print(np.sum(test) / len(test))

Check whether negative sources always have apFlux < psfFlux, defining "negative" first by apFlux and then by psfFlux.

In [None]:
# Fraction of negative sources with apFlux < psfFlux, first using apFlux to
# decide negativity, then psfFlux. Both subsets must use the SAME comparison
# ("<"); previously the psfFlux subset tested ">" by mistake.
results_neg_ap = results[sel_neg_ap]
test = results_neg_ap['apFlux'] < results_neg_ap['psfFlux']
print(np.sum(test) / len(test))

results_neg_psf = results[sel_neg_psf]
test = results_neg_psf['apFlux'] < results_neg_psf['psfFlux']
print(np.sum(test) / len(test))

### Check distribution

In [None]:
# Flux value against its own flag, one panel per photometry type, so it is
# clear which flux ranges tend to be flagged.
fig, axs = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
flux_flag_pairs = [('apFlux', 'apFlux_flag'), ('psfFlux', 'psfFlux_flag')]
for ax, (flux_col, flag_col) in zip(axs, flux_flag_pairs):
    ax.scatter(results[flux_col], results[flag_col], s=1, alpha=0.1)
    ax.set_xlabel(flux_col)
    ax.set_ylabel(flag_col)

Plot psfFlux against apFlux for successively stricter selections.

In [None]:
# psfFlux vs apFlux for successively stricter selections; each panel adds
# one more cut to the previous one.
fig, axs = plt.subplots(2, 2, figsize=(6, 6), layout='constrained')

xlabel = 'psfFlux'
ylabel = 'apFlux'
scale = 'symlog'  # 'asinh' is an alternative for the wide dynamic range
flux_lim = 5e7
flux_ticks = [-1e8, -1e5, -1e2, 0, 1e2, 1e5, 1e8]

axs = axs.flatten()

panels = [
    ('all', results),
    ('no flags', results[sel_no_flag]),
    ('no flags, SNR>5', results[sel_no_flag & sel_snr]),
    ('no flags, SNR>5, non-dipole',
     results[sel_no_flag & sel_snr & sel_not_dipole]),
]

for ind, (ax, (label, subset)) in enumerate(zip(axs, panels)):
    ax.scatter(subset[xlabel], subset[ylabel], s=0.5, alpha=0.1)
    ax.set_xscale(scale)
    ax.set_yscale(scale)
    # set_xticks/set_yticks expand the view limits to include every tick,
    # so ticks go first and limits last; otherwise the +/-1e8 ticks would
    # silently override the requested +/-flux_lim range.
    ax.set_xticks(flux_ticks)
    ax.set_yticks(flux_ticks)
    ax.set_xlim([-flux_lim, flux_lim])
    ax.set_ylim([-flux_lim, flux_lim])
    ax.axline((0, 0), (1, 1), c='r', ls=':')  # one-to-one line
    ax.set_aspect('equal')
    ax.set_title(f'{ind}: {label}', fontsize='small')

fig.supxlabel(xlabel)
fig.supylabel(ylabel)

## Cutout triplet

In [None]:
def get_cutout(dl_result, spherePoint, session, fov):
    """Fetch a circular SODA cutout around a sky position as an ExposureF.

    Parameters
    ----------
    dl_result : pyvo.dal.adhoc.DatalinkResults
        DataLink result for the image to cut; it must expose the
        ``cutout-sync-exposure`` ad-hoc service.
    spherePoint : lsst.geom.SpherePoint
        Center of the cutout circle.
    session
        Authenticated session for the request (e.g. ``get_pyvo_auth()``).
    fov : float or astropy.units.Quantity
        Radius of the cutout circle. A bare number is taken as degrees; an
        angular Quantity is converted to degrees.

    Returns
    -------
    lsst.afw.image.ExposureF
        The cutout, built in memory from the returned bytes.

    Raises
    ------
    Whatever ``SodaQuery.raise_if_error`` raises when the service reports an
    error.
    """
    import contextlib

    sq = SodaQuery.from_resource(dl_result,
                                 dl_result.get_adhocservice_by_id("cutout-sync-exposure"),
                                 session=session)
    # Quantity(x, u.deg) yields degrees for a bare number and converts an
    # angular Quantity (e.g. arcsec) to degrees, so both call styles work.
    sphereRadius = u.Quantity(fov, u.deg)
    sq.circle = (spherePoint.getRa().asDegrees() * u.deg,
                 spherePoint.getDec().asDegrees() * u.deg,
                 sphereRadius)
    # Close the response stream after reading so the connection is released.
    with contextlib.closing(sq.execute_stream()) as stream:
        cutout_bytes = stream.read()
    sq.raise_if_error()
    # Hand the raw bytes to afw through an in-memory file.
    mem = MemFileManager(len(cutout_bytes))
    mem.setData(cutout_bytes, len(cutout_bytes))
    return ExposureF(mem)

In [None]:
# Very bright sources make the cutout triplet easy to read.
sel_high_snr = results['snr'] > 2000
print(np.sum(sel_high_snr))

results_s = results[sel_no_flag&sel_snr&sel_not_dipole&sel_high_snr]
if len(results_s) == 0:
    raise ValueError("No DiaSource passes the flag, SNR, dipole and "
                     "snr > 2000 cuts; relax the selection.")

# Take the first survivor (lowest diaSourceId, given the query's ORDER BY).
row = results_s[0]
ra = row['ra']
dec = row['dec']
visit = row['visit']
midpointMjdTai = row['midpointMjdTai']
band = row['band']

print(ra, dec, visit, midpointMjdTai, band)

spherePoint = geom.SpherePoint(ra*geom.degrees, dec*geom.degrees)

In [None]:
# SIA search region (ra, dec, radius), all in degrees.
sia_radius_deg = 0.0001
circle = (ra, dec, sia_radius_deg)

In [None]:
lvl2_table = sia_service.search(pos=circle, calib_level=2).to_table()

# Keep the visit_image rows belonging to the selected DiaSource's visit.
is_visit_image = lvl2_table['dataproduct_subtype'] == 'lsst.visit_image'
is_our_visit = lvl2_table['lsst_visit'] == visit
sci_table = lvl2_table[is_visit_image & is_our_visit]
print(len(sci_table))
print(sci_table.colnames)

In [None]:
lvl3_table = sia_service.search(pos=circle, calib_level=3).to_table()

# Template coadds covering the position, in the DiaSource's band.
is_template = lvl3_table['dataproduct_subtype'] == 'lsst.template_coadd'
is_same_band = lvl3_table['lsst_band'] == band
ref_table = lvl3_table[is_template & is_same_band]
print(len(ref_table))
print(ref_table.colnames)
ref_table
# NOTE(review): more than one row may mean the position sits near a
# tract/patch edge or corner — to be confirmed.

In [None]:
# Difference images for the selected visit.
is_difference = lvl3_table['dataproduct_subtype'] == 'lsst.difference_image'
is_our_visit = lvl3_table['lsst_visit'] == visit
diff_table = lvl3_table[is_difference & is_our_visit]
print(len(diff_table))
print(diff_table.colnames)

In [None]:
# Distinct calib_level=3 data product subtypes found at this position.
subtypes = set(lvl3_table['dataproduct_subtype'])
print(subtypes)

In [None]:
fig, ax = plt.subplots()

# Shade each template footprint. s_region is assumed to be
# "POLYGON ICRS ra1 dec1 ra2 dec2 ..." (TODO confirm): skip the first two
# tokens, then take alternating RA/Dec values for however many vertices exist.
for region in ref_table['s_region']:
    coords = np.array(region.split()[2:], dtype=float)
    ax.fill(coords[0::2], coords[1::2], alpha=0.2)

# Mark the selected DiaSource.
ax.scatter([ra], [dec], c='k')
ax.set_xlabel('RA [deg]')
ax.set_ylabel('Dec [deg]')

In [None]:
# Resolve the DataLink document for the first science, template and
# difference image rows (in that order), each with its own auth session.
dl_result_sci, dl_result_ref, dl_result_diff = (
    DatalinkResults.from_result_url(table['access_url'][0], session=get_pyvo_auth())
    for table in (sci_table, ref_table, diff_table)
)

In [None]:
fov = 0.003  # cutout radius, deg
# Science, template and difference cutouts around the same position.
sci, ref, diff = (
    get_cutout(dl_result, spherePoint, get_pyvo_auth(), fov)
    for dl_result in (dl_result_sci, dl_result_ref, dl_result_diff)
)

In [None]:
# Resample the template onto the science cutout's WCS and pixel grid so the
# two panels line up pixel-for-pixel.
warper = Warper.fromConfig(WarperConfig())
warped_ref = warper.warpExposure(sci.getWcs(), ref, destBBox=sci.getBBox())

In [None]:
# Science, warped template and difference cutouts side by side, each with a
# linear zscale stretch and drawn with the same afwDisplay call.
panels = [('science', sci.image),
          ('template', warped_ref.image),
          ('difference', diff.image)]

fig, ax = plt.subplots(1, 3, figsize=(8, 3))

for panel_ax, (title, image) in zip(ax, panels):
    # Make this panel current before creating the display, presumably so
    # afwDisplay draws into it (same pattern as the DP1 tutorials).
    plt.sca(panel_ax)
    disp = afwDisplay.Display(frame=fig)
    disp.scale('linear', 'zscale')
    disp.image(image)
    panel_ax.set_title(title)
    panel_ax.set_axis_off()

plt.tight_layout()
fig.suptitle(f'visit: {visit}')