In [1]:
import numpy as np
import matplotlib.pyplot as plt
from astropy import units
from scipy.stats import median_abs_deviation

from lsst.daf.butler import Butler
from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS

from tqdm import tqdm

import pickle
import treegp

from sklearn.neighbors import KNeighborsRegressor
import copy

ARCSEC = (180. / np.pi) * 60 * 60

In [4]:
butler = Butler("/repo/main", instrument="HSC")
camera = butler.get('camera', instrument="HSC", collections="HSC/runs/RC2/w_2024_14/DM-43718")

visits = []

for i, ref in enumerate(butler.registry.queryDatasets("preSourceTable_visit", collections="HSC/runs/RC2/w_2024_14/DM-43718")):
    if i>999:
        break
    else:
        visits.append(ref.dataId.mapping['visit'])


data = {}

for visit in tqdm(visits):

    if len(data) > 999:
        break

    cat_main = butler.get("preSourceTable_visit", visit=visit,
                          collections="HSC/runs/RC2/w_2024_14/DM-43718", storageClass="ArrowAstropy")
    #cat_cg = butler.get("preSourceTable_visit", visit=visit,
    #                    collections="u/erykoff/RC2/DM-38632/run2/step2b", storageClass="ArrowAstropy")


    use, = np.where((cat_main["sizeExtendedness"] < 0.2) & (cat_main["detect_isPrimary"]) & (cat_main["ap12Flux"]/cat_main["ap12FluxErr"] > 15.0))

    cat_main = cat_main[use]
    #cat_cg = cat_cg[use]

    mag_ap12_main = (cat_main["ap12Flux"]*units.nJy).to_value(units.ABmag)
    mag_ap12_main_err = (cat_main["ap12FluxErr"]*units.nJy).to_value(units.ABmag)
    mag_psf_main = (cat_main["psfFlux"]*units.nJy).to_value(units.ABmag)
    mag_psf_main_err = (cat_main["psfFluxErr"]*units.nJy).to_value(units.ABmag)
    
    

    bright, = np.where(mag_ap12_main < 30.)

    mag_ap12_main = mag_ap12_main[bright]
    mag_ap12_main_err =  mag_ap12_main_err[bright]
    mag_psf_main = mag_psf_main[bright]
    mag_psf_main_err = mag_psf_main_err[bright]
    x = cat_main["x"][bright]
    y = cat_main["y"][bright]
    detectors = cat_main["detector"][bright]
    band = cat_main["band"][bright]
    u = np.zeros(len(x))
    v = np.zeros(len(y))
    ccd_ids = set(detectors)

    for ccd in ccd_ids:
        # print(ccd)
        detector = camera[ccd]
        mapping = detector.getTransform(PIXELS, FIELD_ANGLE).getMapping()
        Filter = (detectors == ccd)
        points = np.array([[x[Filter], y[Filter]]]).T
        coord = np.array([mapping.applyForward(p) for p in points])
        u[Filter] = coord[:,0]
        v[Filter] = coord[:,1]


    data.update({visit: {
        'x':x,
        'y':y,
        'ccds': detectors, 
        'band':band,
        'u': u,
        'v': v,
        'mag_ap12_main':mag_ap12_main,
        'mag_psf_main': mag_psf_main
        
    }})

# f = open('fov_photometry_hsc.pkl', 'wb')
# pickle.dump(data, f)
# f.close()

  0%|          | 2/404 [00:00<02:11,  3.06it/s]


In [2]:
data = pickle.load(open('fov_photometry_hsc.pkl', 'rb'))

In [None]:
M = {'all': treegp.meanify(bin_spacing=20., statistics='median')}


# medianify = treegp.meanify(bin_spacing=20., statistics='median')

for vis in tqdm(data):
    coord = (np.array([data[vis]['u'], data[vis]['v']]).T)  * (180. / np.pi) * 60 * 60 
    param = data[vis]["mag_psf_main"] - data[vis]["mag_ap12_main"]
    # medianify.add_field(coord, param)
    M['all'].add_field(coord, param)
    band = data[vis]['band'][0]
    if band not in M:
        M.update({band: treegp.meanify(bin_spacing=20., statistics='median')})
    M[band].add_field(coord, param)
        

for key in tqdm(M): 
    M[key].meanify()    

In [None]:
for key in tqdm(M):
    MAX = 0.01
    COLORMAP = plt.cm.inferno #None# plt.cm.seismic
    plt.figure(figsize=(12,12))
    plt.scatter(M[key].coords0[:,0], M[key].coords0[:,1], c=M[key].params0, s = 1, vmin=-MAX, vmax=MAX, cmap = COLORMAP)
    cb = plt.colorbar()
    plt.axis('equal')
    plt.xlabel('x (arcsec)', fontsize=14)
    plt.ylabel('y (arcsec)', fontsize=14)
    plt.title(f'Median of residuals over 404 visits (Filter: {key} | bin size = 20 arcsec)', fontsize=14)
    cb.set_label("mag_psf_main - mag_ap12_main", fontsize=14)
    plt.savefig(f"stack_plot/median_phot_hsc_{key}.png")
    #plt.close()

In [None]:
KNN = {}

for key in tqdm(M):

    KNN.update({key: KNeighborsRegressor(n_neighbors=3)})
    KNN[key].fit(M[key].coords0, M[key].params0)
# average = neigh.predict(X)

In [None]:
for vis in tqdm(data):
    coord = (np.array([data[vis]['u'], data[vis]['v']]).T)  * (180. / np.pi) * 60 * 60
    band = data[vis]['band'][0]
    correction_all = KNN['all'].predict(coord)
    correction_band = KNN[band].predict(coord)
    data[vis].update({'correction_all': correction_all, 'correction_band':correction_band})

In [None]:
apperture = {}
residuals = {}
corrections_all = {}
corrections_band = {}

for vis in tqdm(data):
    band = data[vis]['band'][0]
    if band not in apperture:
        apperture.update({band:[]})
        residuals.update({band:[]})
        corrections_all.update({band:[]})
        corrections_band.update({band:[]})
        
    apperture[band].append(data[vis]["mag_ap12_main"])
    residuals[band].append(data[vis]["mag_psf_main"] - data[vis]["mag_ap12_main"])
    corrections_all[band].append(data[vis]["correction_all"])
    corrections_band[band].append(data[vis]["correction_band"])

for band in apperture:

    apperture[band] = np.concatenate(apperture[band])
    residuals[band] = np.concatenate(residuals[band])
    corrections_all[band] = np.concatenate(corrections_all[band])
    corrections_band[band] = np.concatenate(corrections_band[band])

In [None]:
for band in apperture:
    bright, = np.where(apperture[band]<19.5)
    mad = median_abs_deviation(residuals[band][bright], scale="normal")*1000.
    median_value = np.median(residuals[band][bright])*1000.
    plt.figure()
    plt.hexbin(apperture[band][bright], residuals[band][bright], bins="log", extent=[15.5, 19.5, -0.1, 0.1], vmin=1, vmax=500)
    plt.plot([15.5,19.5], [0,0], 'k--')
    plt.xlabel("ap12_main")
    plt.ylabel("psf_main - ap12_main")
    plt.title(f"All visits ({band}), \n median={median_value:.3f} mmag, mad={mad:.3f} mmag")
    plt.tight_layout()
    plt.colorbar()
    plt.savefig(f'stack_plot/psf-ap12_{band}_0_before_correction.png')

In [None]:
for band in apperture:
    bright, = np.where(apperture[band]<19.5)
    mad = median_abs_deviation(residuals[band][bright] - corrections_all[band][bright], scale="normal")*1000.
    median_value = np.median(residuals[band][bright] - corrections_all[band][bright])*1000.
    plt.figure()
    plt.hexbin(apperture[band][bright], residuals[band][bright] - corrections_all[band][bright], bins="log", extent=[15.5, 19.5, -0.1, 0.1], vmin=1, vmax=500)
    plt.plot([15.5,19.5], [0,0], 'k--')
    plt.xlabel("ap12_main")
    plt.ylabel("psf_main - ap12_main - median_all")
    plt.title(f"All visits ({band}), \n median={median_value:.3f} mmag, mad={mad:.3f} mmag")
    plt.tight_layout()
    plt.colorbar()
    plt.savefig(f'stack_plot/psf-ap12_{band}_1_after_correction_median_all.png')

In [None]:
for band in apperture:
    bright, = np.where(apperture[band]<19.5)
    mad = median_abs_deviation(residuals[band][bright] - corrections_band[band][bright], scale="normal")*1000.
    median_value = np.median(residuals[band][bright] - corrections_band[band][bright])*1000.
    plt.figure()
    plt.hexbin(apperture[band][bright], residuals[band][bright] - corrections_band[band][bright], bins="log", extent=[15.5, 19.5, -0.1, 0.1], vmin=1, vmax=500)
    plt.plot([15.5,19.5], [0,0], 'k--')
    plt.xlabel("ap12_main")
    plt.ylabel(f"psf_main - ap12_main - median_{band}")
    plt.title(f"All visits ({band}), \n median={median_value:.3f} mmag, mad={mad:.3f} mmag")
    plt.tight_layout()
    plt.colorbar()
    plt.savefig(f'stack_plot/psf-ap12_{band}_2_after_correction_median_per_band.png')

In [None]:
for vis in tqdm(data):
    #if i>0: 
    #    break
    bright, = np.where(data[vis]["mag_ap12_main"]<19.5)
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    residuals = copy.deepcopy(data[vis]["mag_psf_main"][bright] -data[vis]["mag_ap12_main"][bright])
    MAX = np.std(residuals)
    plt.hexbin(data[vis]["u"][bright] * ARCSEC, data[vis]["v"][bright] * ARCSEC, C=residuals, cmap=plt.cm.seismic, vmin=-MAX, vmax=MAX)
    plt.xlabel("u")
    plt.ylabel("v")
    plt.colorbar(label="psf_main - ap12_main")
    plt.title(f"Visit {vis} ({data[vis]['band'][0]})")
    plt.tight_layout()
    plt.axis('equal')
    plt.subplot(1,2,2)
    mad = median_abs_deviation(residuals, scale="normal")*1000.
    median_value = np.median(residuals) * 1000.
    plt.hexbin(data[vis]["mag_ap12_main"][bright], residuals, bins="log", extent=[15.5, 19.5, -0.1, 0.1], vmin=1, vmax=10)
    plt.plot([15.5,19.5], [0,0], 'k--')
    plt.xlabel("ap12_main")
    plt.title(f"Visit {vis}, \n median={median_value:.3f} mmag, mad={mad:.3f} mmag")
    plt.tight_layout()
    plt.colorbar()
    plt.savefig(f'visit_level_plot/{vis}_0_no_correction.png')
    plt.close()
    
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    residuals = copy.deepcopy(data[vis]["mag_psf_main"][bright]-data[vis]["mag_ap12_main"][bright]-data[vis]['correction_all'][bright])
    plt.hexbin(data[vis]["u"][bright] * ARCSEC, data[vis]["v"][bright] * ARCSEC, C=residuals, cmap=plt.cm.seismic, vmin=-MAX, vmax=MAX)
    plt.xlabel("u")
    plt.ylabel("v")
    plt.colorbar(label="psf_main - ap12_main - median_all")
    plt.title(f"Visit {vis} ({data[vis]['band'][0]})")
    plt.tight_layout()
    plt.axis('equal')
    plt.subplot(1,2,2)
    mad = median_abs_deviation(residuals, scale="normal")*1000.
    median_value = np.median(residuals) * 1000.
    plt.hexbin(data[vis]["mag_ap12_main"][bright], residuals, bins="log", extent=[15.5, 19.5, -0.1, 0.1], vmin=1, vmax=10)
    plt.plot([15.5,19.5], [0,0], 'k--')
    plt.xlabel("ap12_main")
    plt.title(f"Visit {vis}, \n median={median_value:.3f} mmag, mad={mad:.3f} mmag")
    plt.tight_layout()
    plt.colorbar()
    plt.savefig(f'visit_level_plot/{vis}_1_correction_median_all.png')
    plt.close()
    
    
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    residuals = copy.deepcopy(data[vis]["mag_psf_main"][bright]-data[vis]["mag_ap12_main"][bright]-data[vis]['correction_band'][bright])
    plt.hexbin(data[vis]["u"][bright] * ARCSEC, data[vis]["v"][bright] * ARCSEC, C=residuals, cmap=plt.cm.seismic, vmin=-MAX, vmax=MAX)
    plt.xlabel("u")
    plt.ylabel("v")
    plt.colorbar(label=f"psf_main - ap12_main - median_{data[vis]['band'][0]}")
    plt.title(f"Visit {vis} ({data[vis]['band'][0]})")
    plt.tight_layout()
    plt.axis('equal')
    plt.subplot(1,2,2)
    mad = median_abs_deviation(residuals, scale="normal")*1000.
    median_value = np.median(residuals) * 1000. 
    plt.hexbin(data[vis]["mag_ap12_main"][bright], residuals, bins="log", extent=[15.5, 19.5, -0.1, 0.1], vmin=1, vmax=10)
    plt.plot([15.5,19.5], [0,0], 'k--')
    plt.xlabel("ap12_main")
    plt.title(f"Visit {vis}, \n median={median_value:.3f} mmag, mad={mad:.3f} mmag")
    plt.tight_layout()
    plt.colorbar()
    plt.savefig(f'visit_level_plot/{vis}_2_correction_median_per_band.png')
    plt.close()

In [3]:
# 326 

data[326]

{'x': <Column name='x' dtype='float64' length=8770>
 1173.1422630926218
  433.7100044447944
 354.27967817275743
 1731.5226318344412
 1118.0266842143646
 1131.5735835433575
 214.75102204277525
 339.44244730706265
  1065.182461911502
  1557.037003852923
 1118.9841270973527
  251.2886695020959
                ...
  249.8223728543757
  419.4962976963897
 1861.1245071579988
 1209.6263694224504
 1391.0982792928671
 1046.0291172419732
 293.31374183192827
  683.1313927589509
   671.995297655686
 307.14019719745284
  984.8910580996709
  595.8143835259342,
 'y': <Column name='y' dtype='float64' length=8770>
  38.13682575149839
    276.19237680051
 352.12018481025746
 469.44447685745774
  566.5781422259448
  608.9231457246201
    738.44575668672
  756.2740198798123
  760.9503764119793
  897.7096518222652
  947.9247893326469
 1056.9825922583284
                ...
 3030.2101056707975
  3166.514055147941
  2112.864259347662
 478.73807781271097
  556.5315330571152
 1029.8880894701538
  1242.72354585

In [8]:
for key in cat_main.keys():
    if "psfFlux" in key:
        print(key)
    if "ap12Flux" in key: 
        print(key)

ap12Flux
ap12FluxErr
ap12Flux_flag
psfFlux
psfFluxErr
psfFlux_apCorr
psfFlux_apCorrErr
psfFlux_area
psfFlux_flag
psfFlux_flag_apCorr
psfFlux_flag_edge
psfFlux_flag_noGoodPixels
