In [None]:
#use env ../speedPy.yml
import numpy as np
import astropy.units as u
from astropy.coordinates import SkyCoord 
from scipy.spatial import KDTree
import time
import timeit
from astroquery.vizier import Vizier
from astropy.table import Table

def summary(name,timeit_results):
    r=np.array(timeit_results)
    print(f'{name} {r.size} runs: \n Mean {r.mean()} \n Standard deviation {r.std()}')

## Cross matching
A common task in astronomy is to match coordinates. Say I have two sets of coordinates A and B. For each point in A I want to find all points in B that are within $\delta\theta$ of the point in set A. In pseudo code you might write this as 

    matches=[]
    for a in A:
        close=[]
            for b in B:
                if distance(a,b)< dt:
                    close.append(b)
        matches.append(close)

There are two complications with this naive algorithm.
1) How to compute distance(a,b)
2) This is one to one |A||B| times. Can this be done like
    1) one to many |A| times  (measue a to B |A| times)
    2) many to many |1| times (compute all distances at once)

Along the way Ill show how you can using pre-existing packages and custom math the do this fast in the general and small angle separation case

### Defining coordinates

The choice of coordinates is important so we are all talking about the same stuff. Typically the polar angle (declination) is measured from the equator of the sphere letting it range from -90 to +90. I dont want to deal with negative numbers so Ill redefine coordinates like this

$$ (\phi,\theta) --> (\phi,|\theta-90|)=(\phi,\theta') $$

where $\theta$ is the declination and $\phi$ is the Right Accension. So $\theta'\in[0,180]$ and $\phi\in [0,360)$

In [2]:
np.random.seed(42)
size=10000
A=np.array([np.random.uniform(low=0,high=360,size=size),np.random.uniform(low=0,high=180,size=size)]).T
B=np.array([np.random.uniform(low=0,high=360,size=size),np.random.uniform(low=0,high=180,size=size)]).T
# get a single coordinate by indexing along the first axis

First lets implement the pseudo-code and see how fast it is. For this Ill use Astropy's skycoord object which has a distance built into it

In [3]:
#https://docs.astropy.org/en/stable/api/astropy.coordinates.SkyCoord.html
#dt= 0.000555556 # 2 arc sec
dt= 5
def simplest_cross_match(A,B,dt=5):
    matches=[]
    for a in A:
        a_coord=SkyCoord(ra=a[0],dec=90-a[1],unit='deg') #convert dec back to astro definition for this example
        close_index=[]
        for i,b in enumerate(B):
            b_coord=SkyCoord(ra=b[0],dec=90-b[1],unit='deg')
            if a_coord.separation(b_coord).value < dt:
                close_index.append(i)
        matches.append(close_index)
    return matches

In [4]:
simplest_cross_match_times=timeit.Timer(f'simplest_cross_match(A[0:100],B[0:100],dt=5)',globals=globals()).repeat(5,number=1)
summary('simplest cross match only 100 coords',simplest_cross_match_times)

simplest cross match only 100 coords 5 runs: 
 Mean 5.2450774168006316 
 Standard deviation 0.03122704825221704


5 seconds for only 100 source catalogs!!! This is quite slow. There are some more methods with SkyCoords that may be faster, lets look at those first before doing anything crazy.

In [5]:
def skycoord_cross_match(A,B,dt=5):
    A_coord=SkyCoord(ra=A.T[0],dec=90-A.T[1],unit='deg')
    B_coord=SkyCoord(ra=B.T[0],dec=90-B.T[1],unit='deg')
    idxsearch,idxself,d2d,d3d=A_coord.search_around_sky(B_coord,seplimit=dt*u.deg)
    match=[idxsearch[idxself==i] for i in range(len(A_coord))]
    return match

In [6]:
skycoord_cross_match_times=timeit.Timer(f'skycoord_cross_match(A,B,dt=5)',globals=globals()).repeat(5,number=1)
summary('skycoord cross_match ALL coords',skycoord_cross_match_times)

skycoord cross_match ALL coords 5 runs: 
 Mean 1.6309715501993196 
 Standard deviation 0.011992009147250276


Can we do better? Each time we want to do this we have the cost of creating a SkyCoord object. What is the cost of this?

In [7]:
times=[] #not using timeit this time
reps=100
size=10000
for i in range(reps):
    c1=np.array([np.random.uniform(low=0,high=360,size=size),np.random.uniform(low=0,high=180,size=size)]).T
    c2=np.array([np.random.uniform(low=0,high=360,size=size),np.random.uniform(low=0,high=180,size=size)]).T
    ts=time.time() # dont count generating the numbers
    A_coord=SkyCoord(ra=c1.T[0],dec=90-c1.T[1],unit='deg')
    B_coord=SkyCoord(ra=c2.T[0],dec=90-c2.T[1],unit='deg')
    times.append(time.time()-ts)
    #make sure its not hiding in memory still
    del c1
    del c2
    del A_coord
    del B_coord

summary('Time to make sky coords',times)

Time to make sky coords 100 runs: 
 Mean 0.0004695606231689453 
 Standard deviation 9.66069992530121e-05


Almost no cost, this is going to be hard to beat. The method I will use is the same as astropy but since I'm building it dedicated to cross matching I may be able to save some time. 

This is how the algorithm work.

1) Transform coordinates to cartesian coordinates
2) Convert the angular separation to a 3D distance to do the comparison
3) use a kdtree to do the matching

In [8]:
#Step 1 convert to cart
def sph_to_cart(catalog):
    catalog=(catalog*np.pi/180).T #convert to radians
    sindec=np.sin(catalog[1])
    sinra=np.sin(catalog[0])
    cosdec=np.cos(catalog[1])
    cosra=np.cos(catalog[0])
    return np.asarray([sindec*cosra,sindec*sinra,cosdec])

cord_transform=timeit.Timer(f'sph_to_cart(A)',globals=globals()).repeat(1000,number=1)
summary('Spherical to cartisan 10,000 points',cord_transform)

Spherical to cartisan 10,000 points 1000 runs: 
 Mean 0.0005650236640176445 
 Standard deviation 5.793141598973002e-05


step 2

The angular separation $\Delta\delta$ can corresponds to a chord on a unit circle so

$$
s^2=1^2+1^2-2\cdot1\cdot1\cos{\Delta\delta}=2\left(1-\cos{\Delta\delta}\right)=4\sin^2(\Delta\delta/2)
$$

So our search criterion is just the distance is less than $2\sin(\Delta\delta)$ !!

Step 3

Before we go to step 3 lets see the timing if we dont use a kdtree

In [9]:
def cart_cross_manual(A,B,sep=5): #sep in units of deg
    C1=sph_to_cart(A)
    C2=sph_to_cart(B).T
    dsep=2*np.sin(np.pi*sep/360)
    matches=[]
    for c1 in C1.T:
        dist=np.sqrt(((c1-C2)**2).sum(axis=1))
        matches.append(np.argwhere(dist<dsep)[:,0])
    return matches
loop_match=timeit.Timer(f'cart_cross_manual(A,B)',globals=globals()).repeat(100,number=1)
summary('Cross match loop 10,000 points',loop_match)

Cross match loop 10,000 points 100 runs: 
 Mean 0.44492804751982473 
 Standard deviation 0.02040022541295747


Haha!! I beat Astropy!!!!!! Now do it with a kdtree

In [10]:
def cross_match_kdtree(A,B,sep=0.000555556,njobs=-1,eps=0): #sep is 2 arc sec
    C1=sph_to_cart(A).T
    C2=sph_to_cart(B).T
    dsep=2*np.sin(np.pi*sep/360)
    tree=KDTree(C2)
    results=tree.query_ball_point(C1,dsep,workers=njobs,eps=eps)
    return results
kdt_match=timeit.Timer(f'cross_match_kdtree(A,B,njobs=1)',globals=globals()).repeat(100,number=1)
summary('KDTree match 10,000 points',kdt_match)

KDTree match 10,000 points 100 runs: 
 Mean 0.014575100420042873 
 Standard deviation 0.00028440704314661224


This is already 10x faster and its only using 1 core. By setting the workers parameter to -1 it maxes out the number of processesrs. Using the kwarg eps we can get an approximate solution which should speed things up.

In [11]:
kdt_match=timeit.Timer(f'cross_match_kdtree(A,B,njobs=1,eps=0.0000555556)',globals=globals()).repeat(100,number=1) #eps of .2 arcsec
summary('KDTree match 10,000 points',kdt_match)

KDTree match 10,000 points 100 runs: 
 Mean 0.014555466629972216 
 Standard deviation 0.0006061893097968629


This catalog is realistic. The points are too sparse so the kdtree is not doing any work. The following is a better way to test these methods.

# REAL EXAMPLE
For this example I will download two catalogs from Vizier through python and perform a crossmatch

## The task

You are a young researcher who is interested varible sources in ZTF (who could that be?). You have a list of potentially interesting sources but you only want to report sources that haven't been classified as variable by GAIA. How can you go about doing this fast? 

In [12]:
# read in the tables from the web
Vizier.ROW_LIMIT = 1e6
ZTF_cat = Vizier.get_catalogs('J/ApJS/249/18/table2')
column_filters={'DE_ICRS':f'>{ZTF_cat[0]["DEJ2000"].min()}',
                'DE_ICRS':f'<{ZTF_cat[0]["DEJ2000"].max()}'} #there are so many gaia sources that we can do a little filtering first
v=Vizier(column_filters=column_filters)
v.ROW_LIMIT=1e7
GAIA_cat = v.get_catalogs('I/358/vclassre')

In [13]:
# Save files so I dont have to ask for the data every time
# use the parquet format to speed things up
ZTF_cat[0].write('Chen2020_table2.pq',format='parquet')
GAIA_cat[0].write('Gaia_subset.pq',format='parquet')

In [14]:
ZTF_cat=Table.read('Chen2020_table2.pq',format='parquet')
GAIA_cat=Table.read('Gaia_subset.pq',format='parquet')

In [16]:
# extract subtables with only the coordinates
GAIA_coords=GAIA_cat[['RA_ICRS','DE_ICRS']]
ZTF_coords=ZTF_cat[['RAJ2000','DEJ2000']]
# we will ignore that the frames of the catalogs are different (j2000 vs icrs)
print(f'Number of sources in GAIA catalog: {len(GAIA_coords)}')
print(f'Number of sources in Chen2020 ZTF varible catalog: {len(ZTF_coords)}')

ts=time.time()
# convert the dec to the convention that I use
GAIA_coords['DE_ICRS']=-1*(GAIA_coords['DE_ICRS']-90)
ZTF_coords['DEJ2000']=-1*(ZTF_coords['DEJ2000']-90)
#cast to numpy
A=ZTF_coords.to_pandas().to_numpy()
B=GAIA_coords.to_pandas().to_numpy()
# run the matching
matches_zero_eps = cross_match_kdtree(A,B,sep=0.00055,njobs=-1) # 2 arc sec
print(f'cross match time KDT {time.time()-ts}')
unique_sources_mask=[len(l)==0 for l in matches_zero_eps]
unique_sources=ZTF_cat[unique_sources_mask]
print(f'There are {sum(unique_sources_mask)} sources in Chen2020 ZTF varible catalog not in this subset of GAIA')
del A
del B

Number of sources in GAIA catalog: 9953751
Number of sources in Chen2020 ZTF varible catalog: 781602
cross match time KDT 4.313793897628784
There are 216699 sources in Chen2020 ZTF varible catalog not in this subset of GAIA


In [17]:
# extract subtables with only the coordinates
from astropy.coordinates import SkyCoord
import astropy.units as u
GAIA_coords=GAIA_cat[['RA_ICRS','DE_ICRS']]
ZTF_coords=ZTF_cat[['RAJ2000','DEJ2000']]
#cast to numpy
ts=time.time()
A=SkyCoord(ra=ZTF_coords['RAJ2000'],dec=ZTF_coords['DEJ2000'],unit='deg')
B=SkyCoord(ra=GAIA_coords['RA_ICRS'],dec=GAIA_coords['DE_ICRS'],unit='deg')

# run the matching

matches = A.search_around_sky(B,seplimit=2*u.arcsec) # 2 arc sec
print(f'cross match time Skycoords {time.time()-ts}')
del A
del B

cross match time Skycoords 6.592687129974365


My implementation is a little faster than Astropy, but this is mostly because in the source of .search_around_sky() the tree query is forced to use only one worker. My method will perform better the more core that are available.