In [1]:
# Search all sources in catalog II for those within 
# a threshold distance of each source in catalog I 

In [7]:
import numpy as np
from collections import defaultdict
import astropy.units as u
#import astropy.coordinates as cor
from astropy.coordinates import SkyCoord
import math
import sys

In [82]:
def search_cat2(ras1, des1, ras2, des2, rad):
    '''
    For each source in catalog I, find every source in catalog II that lies
    within an angular radius `rad` of it.

    Inputs:
        ras1, des1: RA and Dec of the catalog I sources, array-like (a scalar
             is treated as a one-element catalog), unit: deg
        ras2, des2: RA and Dec of the catalog II sources, same conventions,
             unit: deg
        rad: matching radius, unit: arcsec

    Outputs:
        idx1_s, list of the unique indexes of matched sources in catalog I,
             in ascending order
        idx2_ss, list of arrays with the indexes of matched sources in
             catalog II, one array per entry of idx1_s,
             as [[id2_1, id2_2, ..], [id2_1, id2_5, ...], ...]
        dss, list of arrays with the distances of the pairs, laid out like
             idx2_ss, unit: arcsec
        All three lists are empty when no pair is closer than `rad`.

    Raises:
        ValueError: if a coordinate cannot be converted to a finite float,
             or if an RA array and its Dec array differ in shape.
    '''
    # Format unification: float arrays with at least one dimension.
    # Anything that cannot be converted (strings, None, ...) is reported with
    # the same error as NaN/inf instead of a TypeError from deep inside numpy.
    try:
        ras1, des1, ras2, des2 = (
            np.atleast_1d(np.asarray(coord, dtype=float))
            for coord in (ras1, des1, ras2, des2)
        )
    except (TypeError, ValueError) as err:
        raise ValueError('Error: non-float in ra,dec') from err

    # Check that all RA, DEC are finite numbers.
    # Raise instead of sys.exit() so that callers (and the notebook kernel)
    # get a normal, catchable exception.
    if not all(np.isfinite(coord).all() for coord in (ras1, des1, ras2, des2)):
        raise ValueError('Error: non-float in ra,dec')

    if ras1.shape != des1.shape:
        raise ValueError('ra1 and dec1 do not match!')
    if ras2.shape != des2.shape:
        raise ValueError('ra2 and dec2 do not match!')

    # Set the astropy coordinate system
    cor1 = SkyCoord(ra=ras1*u.deg, dec=des1*u.deg, frame='icrs')
    cor2 = SkyCoord(ra=ras2*u.deg, dec=des2*u.deg, frame='icrs')

    # Search cor2 around every position of cor1; the first returned index
    # refers to cor1, the second to cor2. The 3-D distance is not needed.
    idx1, idx2, sep2d, _ = cor2.search_around_sky(cor1, rad*u.arcsec)

    # Nothing matched: np.split would hand back one empty sub-array and
    # indexing its first element would fail, so return early.
    if idx1.size == 0:
        return [], [], []

    # Convert to units of arcsec explicitly rather than assuming degrees
    seps = sep2d.to(u.arcsec).value

    # Group the pairs by catalog I index. The stable sort is a no-op when the
    # pairs already arrive ordered by idx1, and keeps the grouping correct if
    # they do not.
    order = np.argsort(idx1, kind='stable')
    idx1, idx2, seps = idx1[order], idx2[order], seps[order]

    # `starts` holds the position where each distinct idx1 value first occurs;
    # splitting at every start but the first gives one subarray per source
    # in cat1.
    idx1_unique, starts = np.unique(idx1, return_index=True)
    idx2_ss = np.split(idx2, starts[1:])
    dss = np.split(seps, starts[1:])

    return list(idx1_unique), idx2_ss, dss