# DIA Pipeline Analysis 02
#### flag
This notebook is compatible with the `desc-stack-weekly` kernel.

In [1]:
import os
import re
import sqlite3

import numpy as np
import pandas as pd
from lsst.afw.table import BaseCatalog

#### Flag analysis

In [2]:
# ---------------------------------------------------------------------------------------------------------------------
# Flags evaluated as artifact filters.  Elsewhere in this notebook a detected source is kept by a flag when the
# flag's column equals 0.
# ---------------------------------------------------------------------------------------------------------------------

# Pixel-mask flags.  'base_NaiveCentroid_flag' could also be useful.
SAT_FLAGS = [
    'base_PixelFlags_flag_saturated',
    'base_PixelFlags_flag_saturatedCenter',
    'base_PixelFlags_flag_suspect',
    'base_PixelFlags_flag_suspectCenter',
    'base_PixelFlags_flag_offimage',
    'base_PixelFlags_flag_edge',
    'base_PixelFlags_flag_bad',
]

# Dipole-fit classification flags.
DIPOLE_FLAGS = [
    'ip_diffim_DipoleFit_flag_classification',
    'ip_diffim_DipoleFit_flag_classificationAttempted',
]

# Shape-measurement flags.  Alias relations noted by the author:
#   'base_GaussianFlux_flag_badShape' -> 'base_SdssShape_flag', 'slot_Shape' -> 'base_SdssShape'
SHAPE_FLAGS = [
    'base_SdssShape_flag',
    'base_GaussianFlux_flag_badShape',
    'slot_Shape_flag',
]

# Every flag above, in group order: pixel, dipole, shape.
SELECTED_FLAGS = [*SAT_FLAGS, *DIPOLE_FLAGS, *SHAPE_FLAGS]

# Multiply "artifacts per calexp" by this to get "artifacts per deg^2".
# Assumes a 4000 x 4072 pixel calexp at 0.2 arcsec/pixel (3600 arcsec per degree) -- TODO confirm image size.
CALEXP_TO_DEG2 = 1 / (4000 * 0.2 / 3600 * 4072 * 0.2 / 3600)

################################################### functions #######################################################
def cal_eff(synthetic_src_df, flag=None):
    """Return the detection efficiency of the injected (synthetic) sources.

    Parameters
    ----------
    synthetic_src_df : pandas.DataFrame
        One row per injected source.  Its ``matched_status`` column is truthy when the
        source was detected.
    flag : str, optional
        Name of a flag column.  When given, a detected source is only counted if this
        flag equals 0, i.e. if the source would survive a cut on the flag.

    Returns
    -------
    float
        Number of counted sources divided by the total number of injected sources.
        Returns NaN when ``synthetic_src_df`` is empty.
    """
    n_synthetic_src = len(synthetic_src_df)
    if n_synthetic_src == 0:
        # Efficiency is undefined for an empty sample; return NaN instead of dividing by zero.
        return float('nan')
    synthetic_src_det = synthetic_src_df.loc[synthetic_src_df.matched_status.astype(bool)]
    if flag:
        n_counted = (synthetic_src_det[flag] == 0).sum()
    else:
        n_counted = len(synthetic_src_det)
    return n_counted / n_synthetic_src

def cal_ar(art_df, n_calexp, flag=None, calexp_to_deg2=CALEXP_TO_DEG2):
    """Return the artifact rate, in artifacts per deg^2.

    Parameters
    ----------
    art_df : pandas.DataFrame
        One row per artifact detection.
    n_calexp : int
        Number of calexps the artifacts were collected from.
    flag : str, optional
        Name of a flag column.  When given, only artifacts whose flag equals 0 are counted,
        i.e. those that would survive a cut on the flag.
    calexp_to_deg2 : float
        Factor converting "per calexp" into "per deg^2".
    """
    n_art = (art_df[flag] == 0).sum() if flag else len(art_df)
    return n_art / n_calexp * calexp_to_deg2
    
def get_eff_ar(synthetic_src_df, art_df, n_calexp, flag_list, calexp_to_deg2=CALEXP_TO_DEG2):
    """Compute efficiency and artifact rate separately for each flag.

    Parameters
    ----------
    synthetic_src_df : pandas.DataFrame
        Injected sources, as accepted by ``cal_eff``.
    art_df : pandas.DataFrame
        Artifact detections, as accepted by ``cal_ar``.
    n_calexp : int
        Number of calexps the artifacts were collected from.
    flag_list : sequence of str
        Flag columns to evaluate one at a time.
    calexp_to_deg2 : float
        Factor converting "per calexp" into "per deg^2".  It is forwarded to ``cal_ar``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(flag_list), 2)``.  Row ``i`` is ``[efficiency, artifacts per deg^2]``
        for ``flag_list[i]``.  An empty ``flag_list`` gives shape ``(0, 2)``.
    """
    rows = [(cal_eff(synthetic_src_df, flag),
             cal_ar(art_df, n_calexp, flag, calexp_to_deg2=calexp_to_deg2))
            for flag in flag_list]
    # reshape keeps the 2-column layout even when rows is empty
    return np.array(rows, dtype=float).reshape(-1, 2)

def count_remaining_source(df, flag_list):
    """Count the rows of ``df`` in which every flag of ``flag_list`` equals 0.

    A row "remains" after the flag cut only if none of the listed flags is set.
    With an empty ``flag_list`` no cut is applied and every row remains.

    Parameters
    ----------
    df : pandas.DataFrame
        Table that contains every column named in ``flag_list``.
    flag_list : sequence of str
        Flag columns to apply together.

    Returns
    -------
    numpy integer
        Number of remaining rows.
    """
    # Boolean seed so the result is an integer count even when flag_list is empty.
    remaining = np.ones(len(df), dtype=bool)
    for flag in flag_list:
        remaining &= (df[flag] == 0).to_numpy()
    return remaining.sum()

def eval_flags(synthetic_src_df, art_df, n_calexp, flag_list, calexp_to_deg2=CALEXP_TO_DEG2):
    """Evaluate a set of flags as artifact filters.

    Parameters
    ----------
    synthetic_src_df : pandas.DataFrame
        Injected sources.  Its ``matched_status`` column is truthy for detected ones.
    art_df : pandas.DataFrame
        Artifact detections.
    n_calexp : int
        Number of calexps the artifacts were collected from.
    flag_list : sequence of str
        Flag columns to evaluate.
    calexp_to_deg2 : float
        Factor converting "per calexp" into "per deg^2".  It is used for every artifact rate returned.

    Returns
    -------
    dict
        ``'eff_ar'``: per-flag ``[efficiency, artifact rate]`` rows, from ``get_eff_ar``.
        ``'default_eff'`` and ``'default_ar'``: values with no flag applied.
        ``'union_eff'`` and ``'union_ar'``: values when a source must have all flags equal to 0.
    """
    # per-flag efficiency and artifact rate
    eff_ar = get_eff_ar(synthetic_src_df, art_df, n_calexp, flag_list, calexp_to_deg2=calexp_to_deg2)
    # baseline: no flag applied
    default_eff = cal_eff(synthetic_src_df, flag=None)
    default_ar = cal_ar(art_df, n_calexp, flag=None, calexp_to_deg2=calexp_to_deg2)
    # all flags applied together: efficiency is still relative to *all* injected sources
    n_synthetic_src = len(synthetic_src_df)
    synthetic_src_det = synthetic_src_df.loc[synthetic_src_df.matched_status.astype(bool)]
    n_remaining_synthetic = count_remaining_source(synthetic_src_det, flag_list)
    union_eff = n_remaining_synthetic / n_synthetic_src
    # artifact rate with all flags applied together
    n_remaining_art = count_remaining_source(art_df, flag_list)
    union_ar = n_remaining_art / n_calexp * calexp_to_deg2
    return {'eff_ar': eff_ar, 'default_eff': default_eff, 'default_ar': default_ar,
            'union_eff': union_eff, 'union_ar': union_ar}

def write_flag_table(m20_results, m23_results, flag_list, schema, caption, label, save_path):
    """Write a LaTeX (AASTeX ``deluxetable*``) table of per-flag efficiency and artifact rate.

    Parameters
    ----------
    m20_results, m23_results : dict
        Outputs of ``eval_flags`` for the MAG20 and MAG23 samples.  Each must have the keys
        ``'default_eff'``, ``'default_ar'``, ``'eff_ar'``, ``'union_eff'`` and ``'union_ar'``.
        Row ``i`` of ``'eff_ar'`` must correspond to ``flag_list[i]``.
    flag_list : sequence of str
        Flags to tabulate, one table row each.
    schema : lsst.afw.table.Schema
        Source of each flag's description, via ``schema.extract(flag)[flag].getField().getDoc()``.
    caption, label : str
        Inserted verbatim into ``\\tablecaption`` and ``\\label``.
    save_path : str
        Output file.  It is created or overwritten.
    """
    # Single-pass escaping of characters that are special in a LaTeX table cell ('&' separates
    # columns, '%' starts a comment, '#' and '_' are special); translate never double-escapes.
    tex_escapes = str.maketrans({'_': '\\_', '&': '\\&', '%': '\\%', '#': '\\#'})

    def numeric_cells(m20_eff, m20_ar, m23_eff, m23_ar):
        # efficiencies to 3 decimals, artifact rates truncated to integers, then the row terminator
        return f"{m20_eff:.3f} & {int(m20_ar)} & {m23_eff:.3f} & {int(m23_ar)}\\\\\n"

    with open(save_path, "w") as file:
        file.write("\\begin{longrotatetable}\n")
        file.write("\\begin{deluxetable*}{llrrrr}\n")
        file.write("\\tablecaption{"
                   f"{caption} "
                   "\\label{"
                   f"{label}"
                   "}}\n")
        file.write("\\tablewidth{0pt}\n")
        file.write("\\tabletypesize{\\scriptsize}\n")
        file.write("\\tablehead{\n")
        file.write("\\colhead{Flag} & \\colhead{Description} & \\colhead{Eff} & \\colhead{AR(deg$^{-2}$)} & "
                   "\\colhead{Eff} & \\colhead{AR(deg$^{-2}$)} \\\\\n")
        file.write("\\colhead{} & \\colhead{} & \\colhead{MAG20} & \\colhead{MAG20} & \\colhead{MAG23} & \\colhead{MAG23}}\n")
        file.write("\\startdata\n")
        # first row: no flag applied
        file.write("no flag applied & \\nodata & "
                   + numeric_cells(m20_results['default_eff'], m20_results['default_ar'],
                                   m23_results['default_eff'], m23_results['default_ar']))
        # one row per flag, with its schema description
        for i, flag in enumerate(flag_list):
            flag_name = flag.translate(tex_escapes)
            dscp = schema.extract(flag)[flag].getField().getDoc().translate(tex_escapes)
            file.write(f"{flag_name} & {dscp} & "
                       + numeric_cells(m20_results['eff_ar'][i, 0], m20_results['eff_ar'][i, 1],
                                       m23_results['eff_ar'][i, 0], m23_results['eff_ar'][i, 1]))
        # last row: all flags applied together
        file.write("Union of Flags & Apply all of above flags for source selection & "
                   + numeric_cells(m20_results['union_eff'], m20_results['union_ar'],
                                   m23_results['union_eff'], m23_results['union_ar']))
        file.write("\\enddata\n")
        file.write("\\tablecomments{Flags can be used for removing artifacts. "
                   "For a specific flag, detected sources are classified as artifacts if the flag is set to True. "
                   "The first row shows the results without applying flags. "
                   "The last row shows the results of applying all of above flags}\n")
        file.write("\\end{deluxetable*}\n")
        file.write("\\end{longrotatetable}\n")

In [4]:
# Root of the difference-imaging outputs analysed here.  Set the DIA_DATA_DIR environment
# variable to run outside the original author's scratch area.
DATA_DIR = os.environ.get(
    'DIA_DATA_DIR',
    '/pscratch/sd/s/shl159/Cori/projects/fake_injection_v23/dia_improvement/devel/data/patch_0to6/diff/al_default_v23')
OUTPUT_DIR = './plots_and_tables'
HOST_MAG = '20_21'  # value of the `host_mag` column selected from the detection database
n_calexp = 70       # number of calexps in the sample; normalizes artifact counts per image

# Get the table schema and every flag column from one representative DIA source catalog.
t = BaseCatalog.readFits(os.path.join(DATA_DIR, '00_20_21_1013665_79_i/diff_20/schema/deepDiff_diaSrc.fits'))
t_astropy = t.asAstropy()
schema = t.schema
full_flags = [col for col in t_astropy.columns if re.search('flag', col)]
print('The number of full flags: ', len(full_flags))

# Load each injected-magnitude sample and evaluate the selected flags and all flags on it.
db = os.path.join(DATA_DIR, 'detection', 'detection.sqlite')
conn = sqlite3.connect(db)
samples_by_mag = {}
for mag in (20, 23):
    print(f'M{mag}')
    # bound parameters rather than interpolating values into the SQL text
    query_params = (HOST_MAG, mag)
    synthetic_src = pd.read_sql_query(
        "SELECT * FROM fake_src WHERE host_mag = ? and fake_mag = ?", conn, params=query_params)
    art = pd.read_sql_query(
        "SELECT * FROM artifact WHERE host_mag = ? and fake_mag = ?", conn, params=query_params)
    n_detected = int(synthetic_src.matched_status.astype(bool).sum())
    print(f'# of detected synthetic sources: {n_detected}, # of artifacts {len(art)}')
    selected_res = eval_flags(synthetic_src, art, n_calexp, SELECTED_FLAGS, calexp_to_deg2=CALEXP_TO_DEG2)
    full_res = eval_flags(synthetic_src, art, n_calexp, full_flags, calexp_to_deg2=CALEXP_TO_DEG2)
    samples_by_mag[mag] = (synthetic_src, art, selected_res, full_res)
conn.close()
synthetic_src_m20, art_m20, m20_selected_results, m20_full_results = samples_by_mag[20]
synthetic_src_m23, art_m23, m23_selected_results, m23_full_results = samples_by_mag[23]

# Write one LaTeX table for the selected flags and one for all flags.
os.makedirs(OUTPUT_DIR, exist_ok=True)
table_specs = [
    ('selected', 'Selected', m20_selected_results, m23_selected_results, SELECTED_FLAGS),
    ('full', 'All', m20_full_results, m23_full_results, full_flags),
]
for suffix, which, m20_res, m23_res, flags in table_specs:
    save_path = os.path.join(OUTPUT_DIR, f'flags_{suffix}.txt')
    caption = f'Efficiency and Artifact Rate of {which} Flags'
    label = f'tab:flag_{suffix}'
    write_flag_table(m20_res, m23_res, flags, schema, caption, label, save_path)

The number of full flags:  109
M20
# of detected synthetic sources: 1184, # of artifacts 2446
M23
# of detected synthetic sources: 1003, # of artifacts 2464


The row for `ip_diffim_forced_PsfFlux_flag_edge` must be adjusted manually because its description is too long for a single line. Wrap the description in `\makecell[l]{...}` as shown below:
```
ip\_diffim\_forced\_PsfFlux\_flag\_edge & \makecell[l]{Forced PSF flux object was too close to the edge of the image \\to use the full PSF model.} & 1.000 & 689 & 0.847 & 694\\
```