## Run cell-based matching

This takes a set of input metadetect catalogs and matches them using the cell-based `ShearMatch` matcher.

#### Standard imports

In [None]:
import hpmcm
import glob
import os
import numpy as np
import matplotlib.pyplot as plt

#### Set up the configuration

In [None]:
# --- Input selection ---------------------------------------------------------
DATADIR = "test_data"  # directory searched for the input catalogs
shear_st = "0p01"      # applied shear, as it is spelled in the file names
shear = 0.01           # the same applied shear as a float
shear_type = "wmom"    # object characterization, also part of the file names
tract = 10463          # tract to study

# Every matching `.pq` file, in reverse lexicographic order.
SOURCE_TABLEFILES = sorted(
    glob.glob(
        os.path.join(DATADIR, f"shear_{shear_type}_{shear_st}_uncleaned_{tract}_*.pq")
    ),
    reverse=True,
)
# One integer id per input file: 0 .. N-1, in the same order as the files.
VISIT_IDS = np.arange(len(SOURCE_TABLEFILES))

# --- Matching parameters -----------------------------------------------------
PIXEL_R2CUT = 4.0      # Cut at distance**2 = 4 pixels
PIXEL_MATCH_SCALE = 1  # Use pixel scale to do matching

#### Make the matcher, reduce the data

In [None]:
# Build the cell-based matcher. `deshear` is given the negative of the
# applied shear.
matcher = hpmcm.ShearMatch.createShearMatch(
    pixelR2Cut=PIXEL_R2CUT,
    pixelMatchScale=PIXEL_MATCH_SCALE,
    deshear=-shear,
)

In [None]:
# Hand the input table files and their ids (one id per file, same order)
# to the matcher for data reduction.
matcher.reduceData(SOURCE_TABLEFILES, VISIT_IDS)

#### This should have made 200 x 200 cells

In [None]:
# Display the matcher's cell count; presumably this shows the expected
# 200 x 200 grid -- confirm against the hpmcm documentation.
matcher.nCell

#### Run the data

Set `do_partial = False` in the cell below to run all the cells. By default only a small subset is run, for testing.

In [None]:
# True  -> analyze only a 20 x 20 block of cells (quick test run).
# False -> call analysisLoop() without a cell selection.
do_partial = True
if not do_partial:
    matcher.analysisLoop()
else:
    # For a single-cell run, use xRange = [55] and yRange = [170] instead.
    xRange = range(50, 70)
    yRange = range(170, 190)
    matcher.analysisLoop(xRange, yRange)

#### Show the source counts map for a single cell

The x and y axes here are in the cell frame.
The color scale shows the number of sources per pixel.
The analysis looks for clusters of adjacent pixels with counts.

In [None]:
# Pick the cell at grid position (50, 170) and analyze it on its own so the
# intermediate products (e.g. the counts map) are available for plotting.
cell = matcher.cellDict[matcher.getCellIdx(50, 170)]
od = cell.analyze(None, 4)

# Counts map with the origin in the lower-left corner.
plt.imshow(od["countsMap"], origin="lower")
plt.colorbar(label="n sources / pixel")
plt.xlabel(r"$x_{\rm cell}$ [pixels]")
_ = plt.ylabel(r"$y_{\rm cell}$ [pixels]")

#### Show a single cluster

The x and y axes here are in the cluster frame for a single cluster.
The color scale shows the number of sources per pixel.

The `x` markers are the original source positions. The `o` markers are the desheared positions.


In [None]:
# Take the first cluster stored for the cell analyzed above and draw it.
cluster = list(cell.clusterDict.values())[0]
fig = hpmcm.viz_utils.showCluster(od["image"], cluster, cell)

# Restrict the view and label the first axes of the returned figure.
fig.axes[0].set_xlim(-1, 1)
fig.axes[0].set_ylim(0, 2)
fig.axes[0].set_xlabel(r"$x_{\rm cluster}$ [pixels]")
_ = fig.axes[0].set_ylabel(r"$y_{\rm cluster}$ [pixels]")

#### Extract the output of the matching

There are a few empty cells to play around with the output data.

`stats` and `shear_stats` are both tuples of pandas.DataFrame 

In [None]:
# Per the notebook text, both calls return tuples of pandas.DataFrame.
stats = matcher.extractStats()
shear_stats = matcher.extractShearStats()
# Second table of the shear statistics -- presumably the per-object one;
# confirm against the hpmcm documentation.
obj_shear = shear_stats[1]

In [None]:
# Display the first table of the tuple returned by extractStats().
stats[0]

#### Get the offsets between the cluster centroid and the sources

This is to check that the deshearing is correctly applied

In [None]:
def get_offsets(matcher, n_cat=5):
    """Collect, per catalog, the source offsets from the object centroid.

    Only objects made of exactly ``n_cat`` sources coming from ``n_cat``
    distinct catalogs (``nUnique == n_cat`` and ``nSrc == n_cat``) are used,
    so every selected object contributes one source per catalog.

    The total of ``len(cellData.data[0])`` over all cells is printed as a
    side effect.

    Parameters
    ----------
    matcher
        Object exposing ``cellDict``. Each cell must provide ``data`` and
        ``objectDict``; each object must provide ``nUnique``, ``nSrc``,
        ``catIndices``, ``xPix``, ``yPix``, ``xCent``, ``yCent`` and a
        ``data`` table with ``xCell`` and ``yCell`` columns.
    n_cat : int, optional
        Number of catalogs being matched. Defaults to 5.

    Returns
    -------
    dict
        Maps each catalog index ``0 .. n_cat - 1`` to a dict of numpy arrays:
        ``dx`` and ``dy`` (source position minus centroid) and ``x`` and
        ``y`` (``xCell`` / ``yCell`` of the first source from that catalog).
    """
    dd = {iCat: dict(dx=[], dy=[], x=[], y=[]) for iCat in range(n_cat)}
    n = 0
    for cellData in matcher.cellDict.values():
        n += len(cellData.data[0])
        for obj in cellData.objectDict.values():
            # The parentheses matter: without them `not` binds only to the
            # first comparison and partially matched objects slip through.
            if not (obj.nUnique == n_cat and obj.nSrc == n_cat):
                continue
            for iCat in range(n_cat):
                mask = obj.catIndices == iCat
                if mask.sum() == 0:
                    continue
                # Cell coordinates of the first source from this catalog;
                # looked up once rather than once per offset.
                first = obj.data[mask].iloc[0]
                x_cell = float(first.xCell)
                y_cell = float(first.yCell)
                offsets = dd[iCat]
                for dx, dy in zip(obj.xPix[mask] - obj.xCent, obj.yPix[mask] - obj.yCent):
                    offsets["dx"].append(dx)
                    offsets["dy"].append(dy)
                    offsets["x"].append(x_cell)
                    offsets["y"].append(y_cell)

    # Convert the accumulated lists to arrays for plotting.
    for offsets in dd.values():
        for key in ("dx", "dy", "x", "y"):
            offsets[key] = np.array(offsets[key])
    print(n)
    return dd
                    
    

In [None]:
# Gather the centroid offsets from all cells; get_offsets also prints the
# count it accumulated while looping.
dd = get_offsets(matcher)

#### Plot the residuals; they should be flat

In [None]:
# Catalog index 4: x offset from the centroid versus x position in the cell.
_ = plt.scatter(dd[4]['x'], dd[4]['dx'])

#### Look at how the sources lie within the cells

In [None]:
# Histogram of xCellCoadd for fullData[0] (presumably the first input
# catalog), in 200 unit-width bins over [-100, 100].
_ = plt.hist(matcher.fullData[0].xCellCoadd, bins=np.linspace(-100, 100, 201))
# NOTE(review): the disabled variant below relies on `mask_0`, which is not
# defined in the visible cells; define it before re-enabling.
#_ = plt.hist(matcher.fullData[0].loc[stats[0][mask_0].id].xCell_coadd, bins=np.linspace(-100, 100, 201))

In [None]:
# Histogram of yCellCoadd for fullData[0] (presumably the first input
# catalog), in 200 unit-width bins over [-100, 100].
_ = plt.hist(matcher.fullData[0].yCellCoadd, bins=np.linspace(-100, 100, 201))
# NOTE(review): the disabled variant below relies on `mask_0`, which is not
# defined in the visible cells; define it before re-enabling.
#_ = plt.hist(matcher.fullData[0].loc[stats[0][mask_0].id].yCell_coadd, bins=np.linspace(-100, 100, 201))

#### Classify the objects by match type

This looks at the characteristics of the matched objects and categorizes them.

In [None]:
# Sort the matched objects into named categories using an SNR cut of 10,
# then report them with printObjectTypes.
objLists = hpmcm.classify.classifyObjects(matcher, SNRCut=10.)
hpmcm.classify.printObjectTypes(objLists)

#### Measure the matching efficiency for objects above the SNRCut

In [None]:
# Matching efficiency: the fraction of classified objects that ended up in
# the 'ideal' category, out of 'ideal' plus every category in bad_list.
n_good = len(objLists['ideal'])
bad_list = ['edge_mixed', 'edge_missing', 'edge_extra', 'orphan', 'missing', 'two_missing', 'many_missing', 'extra', 'caught']
n_bad = np.sum([len(objLists[x]) for x in bad_list])
n_total = n_good + n_bad
effic = n_good / n_total
# Binomial standard error on the efficiency.
effic_err = np.sqrt(effic * (1 - effic) / n_total)
# Both numbers use fixed-point formatting so they print with the same
# number of decimals.
print(f"Effic: {effic:.5f} +- {effic_err:.5f}")

#### Classify the clusters by match type

This looks at the characteristics of the matched clusters and categorizes them.

In [None]:
# Sort the matched clusters into named categories using an SNR cut of 10,
# then report them with printClusterTypes.
clusterLists = hpmcm.classify.classifyClusters(matcher, SNRCut=10.)
hpmcm.classify.printClusterTypes(clusterLists)

#### Display a few objects

The various markers show the sources from different shear catalogs:  `ns=.`, `1m = <`, `1p = >`, `2m = ^`, `2p = v`. 

In [None]:
# Draw the entry at index 5 of the 'ideal' clusters; assumes at least six
# such clusters exist in the analyzed cells.
_ = hpmcm.viz_utils.showShearObjs(matcher, clusterLists['ideal'][5])

In [None]:
# Draw the first object classified as 'many_missing'; assumes at least one
# such object exists in the analyzed cells.
_ = hpmcm.viz_utils.showShearObj(matcher, objLists['many_missing'][0])