In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [None]:
from glob import glob

# Column dtypes used when reading each per-event CSV file (32-bit floats/ints,
# 64-bit particle ids). Keys of DTYPES are the file-type names accepted by
# _load_event_data.
CELLS_DTYPES = {
    'hit_id': 'i4',
    'ch0': 'i4',
    'ch1': 'i4',
    'value': 'f4',
}
HITS_DTYPES = {
    'hit_id': 'i4',
    'x': 'f4',
    'y': 'f4',
    'z': 'f4',
    'volume_id': 'i4',
    'layer_id': 'i4',
    'module_id': 'i4',
}
PARTICLES_DTYPES = {
    'particle_id': 'i8',
    'vx': 'f4',
    'vy': 'f4',
    'vz': 'f4',
    'px': 'f4',
    'py': 'f4',
    'pz': 'f4',
    'q': 'i4',
    'nhits': 'i4',
}
TRUTH_DTYPES = {
    'hit_id': 'i4',
    'particle_id': 'i8',
    'tx': 'f4',
    'ty': 'f4',
    'tz': 'f4',
    'tpx': 'f4',
    'tpy': 'f4',
    'tpz': 'f4',
    'weight': 'f4',
}
DTYPES = {
    'cells': CELLS_DTYPES,
    'hits': HITS_DTYPES,
    'particles': PARTICLES_DTYPES,
    'truth': TRUTH_DTYPES,
}
def _load_event_data(prefix, name):
    """Load per-event data for one single type, e.g. hits, or particles.

    Parameters
    ----------
    prefix : str or path-like
        Event path prefix; the file is located with the glob pattern
        '<prefix>-<name>.csv*' (the trailing '*' also matches suffixed names
        such as '.csv.gz').
    name : str
        One of the keys of DTYPES: 'cells', 'hits', 'particles', 'truth'.

    Returns
    -------
    pandas.DataFrame
        The file contents read with the column dtypes from DTYPES[name].

    Raises
    ------
    KeyError
        If ``name`` is not a key of DTYPES.
    FileNotFoundError
        If no file matches the pattern.
    ValueError
        If more than one file matches the pattern.
    Both error classes derive from Exception, so existing ``except Exception``
    handlers keep working.
    """
    expr = '{!s}-{}.csv*'.format(prefix, name)
    dtype = DTYPES[name]
    # sorted() makes the candidate list deterministic across filesystems
    files = sorted(glob(expr))
    if not files:
        raise FileNotFoundError('No file matches \'{}\''.format(expr))
    if len(files) > 1:
        raise ValueError('More than one file matches \'{}\''.format(expr))
    return pd.read_csv(files[0], header=0, index_col=False, dtype=dtype)
def parse_event(event, sample_reduction=0.001, load_truth=True, random_state=None):
    """Load one event and keep only the hits of a random subset of particles.

    Parameters
    ----------
    event : str
        Event path prefix; files are located as '<event>-<name>.csv*'.
    sample_reduction : float
        Fraction of the particles with more than 2 hits to keep.
    load_truth : bool
        If False, return the raw hits table (no truth join, no sampling).
    random_state : None, int or anything accepted by ``DataFrame.sample``
        Seed for the particle sample and the final row shuffle. None keeps
        the previous non-reproducible behaviour.

    Returns
    -------
    pandas.DataFrame
        Hits of the sampled particles in shuffled row order, with columns
        hit_id, x, y, z, particle_id, weight, volume_id, layer_id, module_id.
    """
    hits = _load_event_data(event, 'hits')
    if not load_truth:
        return hits
    truth = _load_event_data(event, 'truth')
    particles = _load_event_data(event, 'particles')
    # Join truth on hit_id instead of assigning by row position, so the result
    # does not depend on both files listing hits in the same order. A left
    # merge keeps the hits' row order.
    hits = hits.merge(truth[['hit_id', 'particle_id', 'weight']], how='left', on='hit_id')

    # Keep a random fraction of the particles that leave more than 2 hits.
    sampled_particles = particles[particles['nhits'] > 2].sample(
        frac=sample_reduction, random_state=random_state)

    # The inner join drops hits whose particle was not sampled; frac=1 then
    # shuffles the rows.
    merged_hits = hits.merge(sampled_particles, how='inner', on='particle_id').sample(
        frac=1, random_state=random_state)
    return merged_hits[['hit_id', 'x', 'y', 'z', 'particle_id', 'weight',
                        'volume_id', 'layer_id', 'module_id']]

def full_event(event, sample_reduction=0.001, load_truth=True):
    """Load all hits of one event, optionally with truth labels attached.

    Parameters
    ----------
    event : str
        Event path prefix; files are located as '<event>-<name>.csv*'.
    sample_reduction : float
        Ignored; accepted only so the signature matches ``parse_event``.
    load_truth : bool
        If False, return the raw hits table only.

    Returns
    -------
    pandas.DataFrame
        The hits table plus 'particle_id' and 'weight' from the truth file
        when ``load_truth`` is True.
    """
    hits = _load_event_data(event, 'hits')
    if not load_truth:
        return hits
    truth = _load_event_data(event, 'truth')
    # The particles table is not needed here, so it is no longer read.
    # Join on hit_id rather than by row position; a left merge keeps hit order.
    return hits.merge(truth[['hit_id', 'particle_id', 'weight']], how='left', on='hit_id')

In [None]:
from scipy.spatial import KDTree
def kd_neighbors(X,region_size=0.5):
    """Return all index pairs (i, j), i < j, of rows of ``X`` within ``region_size``.

    Parameters
    ----------
    X : array of shape (n_points, n_dims)
    region_size : float
        Maximum Euclidean distance between the two points of a pair.

    Returns
    -------
    numpy.ndarray of integer dtype, shape (n_pairs, 2)
        When no pair is found this is an empty (0, 2) array rather than a
        1-D (0,) array, so callers can always slice ``[:, 0]`` / ``[:, 1]``.
    """
    kd = KDTree(X)
    all_pairs = kd.query_pairs(region_size)
    return np.array(list(all_pairs), dtype=np.intp).reshape(-1, 2)
def angular_distance(theta1, theta2):
    """Absolute angular separation between two angles, wrapped into [0, pi].

    Works elementwise and broadcasts like ordinary numpy arithmetic.
    """
    shifted = theta1 - theta2 + np.pi
    wrapped = (shifted % (2*np.pi)) - np.pi
    return np.abs(wrapped)
def eta(th):
    """Pseudorapidity for polar angle ``th`` in radians: -ln(tan(th / 2))."""
    half_angle = th / 2.0
    return np.negative(np.log(np.tan(half_angle)))
def pairwise_angular(x,xp,y,yp,rref,eps=1e-9):
    """Line parameters for point pairs (x, y) -> (xp, yp) in a 2-D plane.

    For the straight line through both points of each pair, returns three
    flattened 1-D arrays:

    * the line's distance from the origin, |xp*y - yp*x| / |p2 - p1|
      (``eps`` is added under the square root; the result is 0 when the
      cross term is 0);
    * the line's direction angle atan2(yp - y, xp - x), flipped by pi when
      the flipped direction is angularly closer to the polar angle of the
      first point;
    * the polar angle at which the line meets the circle of radius ``rref``
      (principal arcsin branch; NaN when the line does not reach that radius).

    ``rref`` may be a scalar or an array broadcastable against the inputs.
    """
    delta_y = yp - y
    delta_x = xp - x
    cross = xp*y - yp*x
    inv_len = np.sign(cross) / np.sqrt(delta_y*delta_y + delta_x*delta_x + eps)
    line_dist = cross * inv_len
    radius = np.sqrt(x*x + y*y)
    polar = np.arctan2(y, x)
    direction = np.arctan2(delta_y, delta_x)
    # Orient the line direction towards the first point's polar angle.
    flip = (angular_distance(polar, direction + np.pi) < angular_distance(polar, direction)).astype(np.float32)
    direction = direction + flip * np.pi
    crossing = np.arcsin((radius / rref) * np.sin(polar - direction)) + direction
    return line_dist.flatten(), direction.flatten(), crossing.flatten()
def create_fields(df):
    """Add derived coordinate columns to ``df`` in place (returns None).

    Reads columns x, y, z and adds, in this order:
    r (transverse radius), u and v (conformal coordinates x/r**2, y/r**2),
    rho (3-D radius), phi (azimuth), phi_x and phi_y (x/r, y/r),
    theta (polar angle atan2(r, z)) and eta (pseudorapidity of theta).
    """
    x, y, z = (df[col].values for col in ('x', 'y', 'z'))
    r_sq = x*x + y*y
    r = np.sqrt(r_sq)
    derived = [
        ('r', r),
        ('u', x / r_sq),
        ('v', y / r_sq),
        ('rho', np.sqrt(r_sq + z*z)),
        ('phi', np.arctan2(y, x)),
        ('phi_x', x / r),
        ('phi_y', y / r),
        ('theta', np.arctan2(r, z)),
    ]
    for column, values in derived:
        df[column] = values
    df['eta'] = eta(df['theta'].values)
def create_pairwise(df,truth=None,region_size=0.5,reference_sphere=200.0,kd_bias=(2,1,1)):
    """Build candidate hit pairs and their pairwise line features.

    Adds derived columns to ``df`` in place (via ``create_fields``), finds all
    row pairs within ``region_size`` in the scaled (eta, phi_x, phi_y) space
    and computes, per pair, ``pairwise_angular`` features in the r-z plane
    and in the u-v plane.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain hit_id, x, y, z; modified in place.
    truth : unused
        Kept for signature compatibility.
    region_size : float
        KD-tree search radius.
    reference_sphere : float
        Reference radius for the r-z features; also used to derive the
        per-pair u-v reference radius.
    kd_bias : sequence of 3 numbers
        Per-axis scale applied to (eta, phi_x, phi_y) before the search.

    Returns
    -------
    pandas.DataFrame
        Columns idx1, idx2 (hit ids), rz_d, rz_tha, rz_thi, uv_d, uv_tha,
        uv_thi; rows containing any NaN are dropped.
    """
    create_fields(df)
    scaled = df[['eta','phi_x','phi_y']].values * np.asarray(kd_bias)
    # reshape/astype keep an integer (n_pairs, 2) layout even when no pair is
    # found, so the column slices below cannot fail on an empty result.
    kd_pairs = np.asarray(kd_neighbors(scaled, region_size)).reshape(-1, 2).astype(np.intp)
    df1 = df.take(kd_pairs[:,0])
    df2 = df.take(kd_pairs[:,1])
    pairs = pd.DataFrame({'idx1': df1['hit_id'].values, 'idx2': df2['hit_id'].values})
    # r-z plane: z is passed as the first coordinate and r as the second.
    pairs['rz_d'],pairs['rz_tha'],pairs['rz_thi'] = pairwise_angular(df1['z'].values, df2['z'].values, df1['r'].values, df2['r'].values, reference_sphere)
    # Per-pair reference radius for the u-v plane, derived from reference_sphere and rz_thi.
    rref = 1.0 / reference_sphere / np.sin(pairs['rz_thi'].values)
    pairs['uv_d'],pairs['uv_tha'],pairs['uv_thi'] = pairwise_angular(df1['u'].values, df2['u'].values, df1['v'].values, df2['v'].values, rref)
    return pairs.dropna()
def get_pairwise(hits,truth=None,region_size=0.5,reference_sphere=200.0,z_max=10):
    """Build hit pairs, reject geometrically implausible ones, add cos/sin features.

    A pair is rejected when rz_d > z_max, or when rz_tha or rz_thi lies
    outside [0, pi]. For the kept pairs, cosine and sine columns are added
    for rz_tha, rz_thi, uv_tha and uv_thi (rz_cosa, rz_sina, rz_cosi,
    rz_sini, uv_cosa, uv_sina, uv_cosi, uv_sini, in that order).
    ``hits`` is modified in place by create_pairwise.
    """
    pairs = create_pairwise(hits, truth, region_size, reference_sphere)
    rz_d = pairs['rz_d']
    rz_tha = pairs['rz_tha']
    rz_thi = pairs['rz_thi']
    reject = (rz_d > z_max) | (rz_tha < 0) | (rz_tha > np.pi) | (rz_thi < 0) | (rz_thi > np.pi)
    filtered_pairs = pairs.loc[~reject].copy()
    trig_columns = (
        ('rz_tha', 'rz_cosa', 'rz_sina'),
        ('rz_thi', 'rz_cosi', 'rz_sini'),
        ('uv_tha', 'uv_cosa', 'uv_sina'),
        ('uv_thi', 'uv_cosi', 'uv_sini'),
    )
    for angle_col, cos_col, sin_col in trig_columns:
        angle = filtered_pairs[angle_col].values
        filtered_pairs[cos_col] = np.cos(angle)
        filtered_pairs[sin_col] = np.sin(angle)
    return filtered_pairs

In [None]:
# Angular grid used by find_collisions_bins: theta is wrapped into [0, pi)
# and phi into [0, 2*pi), each cut into 0.1-rad cells (with 2 extra bins
# per axis in the counts).
th_range = np.pi
bin_dth = 0.1
phi_range = 2*np.pi
bin_dphi = 0.1
th_bins = int(th_range/bin_dth) + 2
phi_bins = int(phi_range/bin_dphi) + 2
def add_to_boxes(boxes,th_list,phi_list,sign=1,add_neighbors=False):
    """Store a signed 1-based code for each point in the box of its (theta, phi) cell.

    Point k is appended as ``sign*(k+1)`` to ``boxes[i_th + i_phi*th_bins]``,
    where i_th / i_phi are the cell indices of the wrapped angles. With
    ``add_neighbors`` the point is also added to the 8 surrounding cells; then
    ``th_list`` and ``phi_list`` must be numpy arrays, since they are shifted
    by +/- one bin width.
    """
    for k in range(len(th_list)):
        col = int((th_list[k] % th_range) / bin_dth)
        row = int((phi_list[k] % phi_range) / bin_dphi)
        boxes[col + row*th_bins].append(sign*(k+1))
    if not add_neighbors:
        return
    # (theta step, phi step), in bins, of the 8 neighbouring cells in fill order
    neighbor_steps = ((1, 1), (1, 0), (1, -1), (0, 1), (0, -1), (-1, 1), (-1, 0), (-1, -1))
    for d_th, d_phi in neighbor_steps:
        add_to_boxes(boxes, th_list + d_th*bin_dth, phi_list + d_phi*bin_dphi, sign=sign)
def find_collisions_bins(inner_th,inner_phi,outer_th,outer_phi):
    """Box-accelerated matching of inner and outer (theta, phi) directions.

    Inner points are placed in their own angular cell and the 8 neighbouring
    cells (as negative codes); outer points go into their own cell only (as
    positive codes). Only cells holding both kinds are compared. A pair
    matches when sqrt(dphi**2 + dth**2) < 0.1, with dphi scaled by
    sin(mean theta).

    Returns
    -------
    (inner_idx, outer_idx, scores) : three lists of equal length.
    """
    boxes = [[] for _ in range(th_bins*phi_bins+1)]
    add_to_boxes(boxes, inner_th, inner_phi, sign=-1, add_neighbors=True)
    add_to_boxes(boxes, outer_th, outer_phi)
    inner_idx, outer_idx, scores = [], [], []
    for box in boxes:
        if not box or min(box) >= 0 or max(box) <= 0:
            continue
        box_inner = [-code - 1 for code in box if code < 0]
        box_outer = [code - 1 for code in box if code > 0]
        dphi = angular_distance(inner_phi[box_inner, None], outer_phi[box_outer]) * np.sin(.5*inner_th[box_inner, None] + .5*outer_th[box_outer])
        dth = angular_distance(inner_th[box_inner, None], outer_th[box_outer])
        pair_score = np.sqrt(dphi*dphi + dth*dth)
        rows, cols = np.where(pair_score < 0.1)
        inner_idx.extend(np.array(box_inner)[rows])
        outer_idx.extend(np.array(box_outer)[cols])
        scores.extend(pair_score[rows, cols])
    return inner_idx, outer_idx, scores
def find_collisions(inner_th,inner_phi,outer_th,outer_phi,correction_factor=1.0,max_score=0.1):
    """Match every inner direction with every outer direction that is close enough.

    For all (inner i, outer j) combinations, computes
    score = sqrt(dphi**2 + dth**2) * correction_factor, where dth is the theta
    difference divided by (0.05 + |sin(mean theta)|), which keeps the
    denominator at least 0.05.

    Parameters
    ----------
    inner_th, inner_phi : 1-D arrays of equal length
    outer_th, outer_phi : 1-D arrays of equal length
    correction_factor : float or broadcastable array
        Multiplies the score.
    max_score : float
        Pairs with score < max_score are returned (default 0.1, as before).

    Returns
    -------
    (inner_idx, outer_idx, scores)
        Index arrays of the matching combinations and a list of their scores.
    """
    dphi = angular_distance(inner_phi[:,None], outer_phi)
    dth = angular_distance(inner_th[:,None], outer_th)/(.05+np.abs(np.sin(.5*inner_th[:,None]+.5*outer_th)))
    pair_score = np.sqrt(dphi*dphi+dth*dth)*correction_factor
    inner_idx, outer_idx = np.where(pair_score < max_score)
    # fancy indexing gathers all matching scores at once (same values as a per-pair loop)
    scores = list(pair_score[inner_idx, outer_idx])
    return inner_idx, outer_idx, scores

In [None]:
def label_hits(filtered_pairs, labels,n_hits=200000):
    """Give each hit the label under which it appears in the most pairs.

    Parameters
    ----------
    filtered_pairs : pandas.DataFrame
        Integer columns idx1, idx2 (hit ids) and label_id.
    labels : iterable of positive ints
        Labels to consider, in processing order.
    n_hits : int
        Minimum largest hit id covered by the result. The array is grown
        automatically when a hit id in ``filtered_pairs`` exceeds it; the
        fixed size previously raised IndexError.

    Returns
    -------
    numpy.ndarray of int32, indexed by hit id
        0 means unlabelled. On ties, a hit keeps the first label (in
        ``labels`` order) that reached its highest pair count.
    """
    sorted_pairs = filtered_pairs.sort_values('label_id')[['idx1','idx2','label_id']].values
    if len(sorted_pairs):
        n_hits = max(n_hits, int(sorted_pairs[:, :2].max()))
    hit_labels = np.zeros(n_hits+1, dtype=np.int32)
    hit_label_count = np.zeros(n_hits+1, dtype=np.int32)
    n_labels = np.max(labels)+1
    # label_idx[k] = number of pairs with label_id <= k, so the rows of label L
    # are sorted_pairs[label_idx[L-1]:label_idx[L]].
    label_idx = np.searchsorted(sorted_pairs[:,2], range(n_labels+1), side='right')
    for label in labels:
        hit_index_all = sorted_pairs[label_idx[label-1]:label_idx[label], :2].ravel()
        hit_indices, hit_index_count = np.unique(hit_index_all, return_counts=True)
        for idx, count in zip(hit_indices, hit_index_count):
            if count > hit_label_count[idx]:
                hit_labels[idx] = label
                hit_label_count[idx] = count
    return hit_labels
def get_track_labels(labels,hits,pairs,min_count=2,max_iter=10):
    """Convert pair-level cluster labels into a track id for every hit.

    Each hit first gets the label under which it appears most often
    (``label_hits``). Labels carried by fewer than ``min_count`` hits are
    then dropped and the remaining labels re-assigned. This repeats until
    the label set stops changing or ``max_iter`` passes are done.

    Parameters
    ----------
    labels : array-like of int
        Pair labels; only the maximum is used. Labels are presumably 1..max,
        with 0 meaning unclustered — confirm against the caller.
    hits : pandas.DataFrame
        Must have a 'hit_id' column; the result follows its row order.
    pairs : pandas.DataFrame
        Passed to ``label_hits`` (columns idx1, idx2, label_id).
    min_count : int
        Minimum number of hits for a label to be kept.
    max_iter : int
        Maximum number of prune-and-relabel passes.

    Returns
    -------
    numpy.ndarray of int32, aligned with ``hits`` rows; 0 means no track.
    """
    n_labels = int(np.max(labels)) if len(labels) else 0
    if n_labels < 1:
        # Nothing was clustered: every hit stays unassigned.
        return np.zeros(len(hits), dtype=np.int32)
    hit_labels = label_hits(pairs, range(1, n_labels+1))
    for _ in range(max_iter):
        glab, gcount = np.unique(hit_labels, return_counts=True)
        good_labels = [lab for lab, count in zip(glab, gcount) if lab > 0 and count >= min_count]
        if len(good_labels) == n_labels:
            break
        n_labels = len(good_labels)
        if not good_labels:
            # label_hits cannot handle an empty label list (np.max of []).
            return np.zeros(len(hits), dtype=np.int32)
        hit_labels = label_hits(pairs, good_labels)
    return hit_labels[hits['hit_id'].values]
def full_score(hits, verbose=True):
    """Weighted track-finding score of hits['track_id'] against hits['particle_id'].

    For each particle, its best track is the track_id holding most of its
    hits, and pct is the share of the particle's hits on that track. The
    particle is classified as:

    * unmatched if pct <= 50;
    * remaining if the best track id is <= 0 (unassigned);
    * overflow if the particle's hits make up 50% or less of that track;
    * good otherwise. The best-track hit weight counts as good and the rest
      of the particle's weight as lost.

    Remaining/good/lost weights are binned by the lower edges 50/80/95 (%).
    With ``verbose`` a summary, as fractions of the total hit weight, is
    printed.

    Returns
    -------
    float
        Total good weight divided by the total hit weight.
    """
    track_bins = [50, 80, 95, 200]
    n_bins = len(track_bins)
    good_weights = np.zeros(n_bins)
    lost_weights = np.zeros(n_bins)
    remain_weights = np.zeros(n_bins)
    overflow_weight = 0.0
    unmatched_weight = 0.0
    total_weight = np.sum(hits['weight'].values)
    hits_by_track = hits.groupby('track_id')

    for _, particle_hits in hits.groupby('particle_id'):
        track_ids, track_counts = np.unique(particle_hits['track_id'].values, return_counts=True)
        best_id = track_ids[np.argmax(track_counts)]
        best_count = max(track_counts)
        best_pct = best_count / len(particle_hits) * 100
        best_weight = np.sum(particle_hits[particle_hits['track_id'] == best_id]['weight'].values)
        particle_weight = np.sum(particle_hits['weight'].values)

        if not best_pct > 50.0:
            unmatched_weight += particle_weight
            continue

        # index of the largest lower edge below best_pct
        tbin = max(idx for idx, edge in enumerate(track_bins) if edge < best_pct)

        if not best_id > 0:
            remain_weights[tbin] += best_weight
            lost_weights[tbin] += particle_weight - best_weight
            continue

        share_of_track = best_count / len(hits_by_track.get_group(best_id))
        if share_of_track > 0.5:
            good_weights[tbin] += best_weight
            lost_weights[tbin] += particle_weight - best_weight
        else:
            overflow_weight += particle_weight

    if verbose:
        print('overflow / unmatched: {:.4f} / {:.4f}'.format(overflow_weight/total_weight, unmatched_weight/total_weight))
        print('bin: good / lost / remaining')
        for edge, good, lost, remain in zip(track_bins[:-1], good_weights, lost_weights, remain_weights):
            print('{}%: {:.4f} / {:.4f} / {:.4f}'.format(edge, good/total_weight, lost/total_weight, remain/total_weight))
        print('total: {:.4f} / {:.4f} / {:.4f}'.format(np.sum(good_weights)/total_weight, np.sum(lost_weights)/total_weight, np.sum(remain_weights)/total_weight))
    return np.sum(good_weights)/total_weight
def run_cluster_round(hits,region_size=100,reference_sphere=500.0,z_max=10,eps=0.01,min_samples=3,min_count=2,parts=None):
    """Run one pair-clustering round and store per-hit track ids in hits['track_id'].

    Builds filtered hit pairs with get_pairwise, clusters them (fit_bypart
    when ``parts`` is given, otherwise fit_pairs) and writes the labels from
    get_track_labels into ``hits`` in place. Returns None.

    NOTE(review): fit_bypart and fit_pairs are not defined anywhere in the
    visible notebook, so this raises NameError unless they are defined
    elsewhere — confirm before calling.
    """
    pairs=get_pairwise(hits,region_size=region_size,reference_sphere=reference_sphere,z_max=z_max)
    if parts is not None:
        # parts is indexed as (parts[0], parts[1]) -> (iparts, jparts); TODO confirm semantics with fit_bypart
        labeled_pairs=fit_bypart(pairs,eps=eps,min_samples=min_samples,iparts=parts[0],jparts=parts[1])
    else:
        labels=fit_pairs(pairs,eps=eps,min_samples=min_samples)
        # Shift by one so that a noise label of -1 (presumably DBSCAN-style
        # output from fit_pairs) becomes 0, which get_track_labels ignores.
        pairs['label_id']=labels+1
        labeled_pairs=pairs
    hits['track_id'] = get_track_labels(labeled_pairs['label_id'].values,hits,labeled_pairs,min_count=min_count)
def apply_cluster_labels(hits,round_hits):
    """Merge one round's positive track ids into ``hits`` and return the unlabelled rest.

    Positive track ids in ``round_hits`` are shifted by the current maximum
    of hits['track_id'] (computed before the update), so they do not clash
    with existing ids, and are written into the matching rows of ``hits``
    (in place). Returns ``hits`` without the rows whose track_id is > 0.
    """
    label_offset = np.max(hits['track_id'])
    is_labelled = (round_hits['track_id'] > 0).values
    target_rows = round_hits.index[is_labelled]
    hits.loc[target_rows, 'track_id'] = round_hits['track_id'].values[is_labelled] + label_offset
    already_labelled = (hits['track_id'] > 0).values
    return hits.drop(hits.index[already_labelled])

In [None]:
# Load one training event, keeping the hits of a random 10% of the
# particles that have more than 2 hits.
hits=parse_event('../input/train_1/event000001045',sample_reduction=0.1)

In [None]:
# Derived coordinates plus a 4-D KD-tree over (eta, phi_x, phi_y, rp) used
# by the neighbour search in the pair-graph cell.
create_fields(hits)
# NOTE(review): kd_bias appears unused at top level in this notebook.
kd_bias=[2,1,1]
hits['rp']=np.sqrt(hits['rho'].values/300)
X=hits[['eta','phi_x','phi_y','rp']].values
kd=KDTree(X)
# Columns by position: 0=u, 1=v, 2=r, 3=z, 4=rho (indexed numerically below).
points=hits[['u','v','r','z','rho']].values

In [None]:
%%time
import networkx as nx
def get_angles(ref_pt,points,sign=1):
    """Direction angles from ``ref_pt`` to each row of ``points``.

    With the column layout of the notebook's ``points`` array
    (u, v, r, z, rho): phi is the angle of the (u, v) offset and theta the
    angle of the (z, r) offset, i.e. atan2(dr, dz). ``sign=-1`` reverses the
    offsets (direction from the point back to ``ref_pt``).

    Returns
    -------
    (theta, phi) : two 1-D arrays, one entry per row of ``points``.
    """
    offsets = sign * (points[:, :4] - np.asarray(ref_pt)[:4])
    phi = np.arctan2(offsets[:, 1], offsets[:, 0])
    theta = np.arctan2(offsets[:, 2], offsets[:, 3])
    return theta, phi
# Base of the pair encoding: a pair (a, b) is stored as a*_PAIR_BASE + b.
_PAIR_BASE = 1000000
def encode_pair(a,b):
    """Pack two non-negative hit positions into one integer key a*1e6 + b.

    Uses exact integer arithmetic, so the key never depends on float rounding.

    Raises
    ------
    ValueError
        If ``b`` is outside [0, 1e6); such a value would alias another pair.
    """
    a, b = int(a), int(b)
    if not 0 <= b < _PAIR_BASE:
        raise ValueError('second index {} does not fit in the pair encoding'.format(b))
    return a * _PAIR_BASE + b
def decode_pair(p):
    """Inverse of encode_pair: return the tuple (a, b) packed in key ``p``."""
    return divmod(int(p), _PAIR_BASE)
# Build the pair graph.
# For every hit i (row position in `hits`): take its KD-tree neighbours in
# (eta, phi_x, phi_y, rp) space, keep those farther than min_dr in r AND
# min_dz in z, and split them by rho into inner (rho < rho_i - min_drho)
# and outer (rho > rho_i + min_drho). Each (inner j, outer k) combination
# accepted by find_collisions yields two pair nodes, encode_pair(j, i) and
# encode_pair(i, k), joined by an edge carrying the collision score.
#   graph     : pair node -> list of (neighbour pair node, score), both directions
#   all_pairs : set of all pair nodes
#   hit_pairs : hit_pairs[h] = set of pair nodes that contain hit position h
graph = {}
all_pairs=set()
hit_pairs=[set() for _ in range(len(hits))]
radius=0.7
min_drho=1.0
min_dr=2.0
min_dz=2.0
from tqdm import tqdm_notebook as tqdm
for i,point in enumerate(tqdm(hits[['eta','phi_x','phi_y','rp']].values)):
    all_neighbors=np.array(kd.query_ball_point(point,radius))
    # points columns: 2 = r, 3 = z, -1 = rho
    rz_cond=np.logical_and(np.abs(points[all_neighbors,2]-points[i,2])>min_dr,np.abs(points[all_neighbors,3]-points[i,3])>min_dz)
    inner_cond=points[all_neighbors,-1]<points[i,-1]-min_drho
    outer_cond=points[all_neighbors,-1]>points[i,-1]+min_drho
    inner_neighbors=all_neighbors[np.logical_and(inner_cond,rz_cond)]
    outer_neighbors=all_neighbors[np.logical_and(outer_cond,rz_cond)]
    inner_points=points[inner_neighbors]
    outer_points=points[outer_neighbors]
    # inner directions are reversed (sign=-1) so both sets point outward through hit i
    inner_th,inner_phi=get_angles(points[i],inner_points,sign=-1)
    outer_th,outer_phi=get_angles(points[i],outer_points)
    inner_idx,outer_idx,scores=find_collisions(inner_th,inner_phi,outer_th,outer_phi,correction_factor=1.0/point[-1])
    for j,k,score in zip(inner_idx,outer_idx,scores):
        ij=inner_neighbors[j]
        ik=outer_neighbors[k]
        in_pair = encode_pair(ij,i)
        out_pair = encode_pair(i,ik)
        all_pairs.update([in_pair,out_pair])
        hit_pairs[i].update([in_pair,out_pair])
        hit_pairs[ij].add(in_pair)
        hit_pairs[ik].add(out_pair)
        # adjacency list, stored in both directions with the same score
        graph.setdefault(in_pair, []).append((out_pair,score))
        graph.setdefault(out_pair, []).append((in_pair,score))

In [None]:
# Build the undirected networkx graph G from the `graph` adjacency dict; the
# following cells (connected components, HDBSCAN, edge contraction) use G,
# which was otherwise never defined. `graph` stores each edge in both
# directions with the same score, so re-adding it just rewrites that weight.
G = nx.Graph()
for src_pair, edges in graph.items():
    for dst_pair, score in edges:
        G.add_edge(src_pair, dst_pair, weight=score)

In [None]:
# Connected components of the pair graph G; keep only the largest one (sg).
cc=list(nx.connected_components(G))
# NOTE(review): if the commented-out pruning below is re-enabled, rename its
# loop variable; `graph` would shadow the pair adjacency dict.
# small_graphs=[c for c in cc if len(c)<5]
# for graph in small_graphs:
#     for node in graph:
#         G.remove_node(node)
sg=G.subgraph(cc[np.argmax([len(c) for c in cc])])
# Sparse adjacency matrix of sg (entries from the 'weight' edge attribute by default).
# NOTE(review): to_scipy_sparse_matrix was removed in networkx 3.0
# (replacement: nx.to_scipy_sparse_array); this cell needs networkx < 3.
d=nx.convert_matrix.to_scipy_sparse_matrix(sg)
d

In [None]:
# Transform d with HDBSCAN's sparse mutual reachability (min_points=2), then
# drop from G every node outside the largest connected component of the result.
# NOTE(review): _hdbscan_reachability is a private hdbscan module; its API may
# change between hdbscan versions.
from hdbscan._hdbscan_reachability import sparse_mutual_reachability
from scipy.sparse import csgraph
dd=sparse_mutual_reachability(d.tolil(),min_points=2)
_,components=csgraph.connected_components(dd,directed=False)
cnum,counts=np.unique(components,return_counts=True)
largest_component=cnum[np.argmax(counts)]
# components[k] belongs to row k of d, i.e. the k-th node of sg (same node order).
nodes_to_drop=[n for c,n in zip(components,list(sg.nodes())) if c != largest_component]
G.remove_nodes_from(nodes_to_drop)
# sg is a subgraph view of G, so presumably the removal is visible through it — verify.
d=nx.convert_matrix.to_scipy_sparse_matrix(sg)

In [None]:
d

In [None]:
# Cluster the pair nodes of sg with HDBSCAN, using the sparse matrix d as
# precomputed distances ('leaf' selection favours many small clusters).
import hdbscan
clusterer = hdbscan.HDBSCAN(metric='precomputed',min_samples=2,cluster_selection_method='leaf')
clusterer.fit(d)
# One label per row of d (= per node of sg); -1 marks noise.
clusterer.labels_

In [None]:
# Hits belonging to HDBSCAN cluster 0. Nodes of sg are encode_pair integers
# (not tuples), so decode each into its two hit positions; dtype=int64 keeps
# the index valid even when the cluster is empty.
cluster0_hits=np.unique(np.array([decode_pair(p) for p,l in zip(sg.nodes(),clusterer.labels_) if l==0],dtype=np.int64))
hits.loc[hits.index[cluster0_hits]]

In [None]:
# Inspect the hits of one particular particle.
hits.query('particle_id==225182592708640768')

In [None]:
# NOTE(review): this rebinds all_pairs from the set built in the pair-graph
# cell to the list of nodes currently in G (after the node removal above).
all_pairs=list(G.nodes)

In [None]:
# A separate graph copied from G, used by the debug cells below.
F=nx.Graph(G)

In [None]:
# 'weight' attribute of each edge from the first pair node.
for n in F[all_pairs[0]]:
    print(F[all_pairs[0]][n]['weight'])

In [None]:
# Decoded (hit, hit) neighbours of the 4th pair node, read from the adjacency dict.
[decode_pair(p) for p,_ in graph[list(all_pairs)[3]]]

In [None]:
# NOTE(review): pair_labels is first assigned in the DBSCAN cell further down,
# so this raises NameError on Restart & Run All. 4000033 encodes the pair (4, 33).
pair_labels[4000033]

In [None]:
F[all_pairs[0]]

In [None]:
hits['particle_id'].values[[8826,1,6700]]

In [None]:
# Hand-written DBSCAN over pair nodes. Two pair nodes are neighbours when the
# edge joining them in `graph` has score < eps. A node with at least
# min_count neighbours is a core node that starts (or extends) a cluster;
# other nodes are marked noise (-1) unless reached from a core node.
pair_labels={}
C = 0
eps=0.01
min_count=1
noise = -1
def flat_neighbors(pair, eps):
    """Pair nodes joined to ``pair`` by an edge with score < eps (empty if unknown)."""
    # return [nbr for nbr in G[pair] if G[pair][nbr]['weight'] < eps]
    return [nbr for nbr,score in graph.get(pair, []) if score < eps]
for pair in tqdm(all_pairs):
    if pair in pair_labels:
        continue
    neighbors=flat_neighbors(pair,eps)
    neighbor_set=set(neighbors)
    if len(neighbors) < min_count:
        pair_labels[pair]=noise
        continue
    C = C + 1
    pair_labels[pair] = C
    # `neighbors` grows while being iterated: it is the cluster's expansion queue.
    for npair in neighbors:
        if npair in pair_labels:
            if pair_labels[npair] == noise:
                # former noise reached from a core node becomes a border node
                pair_labels[npair] = C
            continue
        pair_labels[npair] = C
        nneighbors = flat_neighbors(npair,eps)
        # Same core test as for seeds (>= min_count); previously this used
        # `> min_count`, so expansion needed one more neighbour than seeding.
        if len(nneighbors) >= min_count:
            new_neighbors=set(nneighbors).difference(neighbor_set)
            neighbors.extend(list(new_neighbors))
            neighbor_set.update(new_neighbors)
    

In [None]:
def get_hit_labels(restrict_to_labels=None):
    """Label each hit with the most frequent positive pair label among its pairs.

    Reads the notebook globals ``hits``, ``hit_pairs`` and ``pair_labels``.
    Pairs that are unlabelled or labelled <= 0 (noise) are ignored. When
    ``restrict_to_labels`` is given, only those labels count. Ties go to the
    smallest label; hits without a usable pair get 0.
    """
    # set membership is O(1); `x in ndarray` scans the whole array each time
    allowed = None if restrict_to_labels is None else set(restrict_to_labels)
    hit_labels=np.zeros(len(hits),dtype=np.int64)
    for i in range(len(hits)):
        this_hit_labels=[pair_labels[p] for p in hit_pairs[i] if pair_labels.get(p, 0) > 0]
        if allowed is not None:
            this_hit_labels=[x for x in this_hit_labels if x in allowed]
        if len(this_hit_labels) > 0:
            labels,counts=np.unique(this_hit_labels,return_counts=True)
            hit_labels[i]=labels[np.argmax(counts)]
    return hit_labels
# First pass with all labels, then keep only labels carried by more than 8 hits and relabel.
hit_labels=get_hit_labels()
labels,counts=np.unique(hit_labels,return_counts=True)
restricted_labels=labels[np.where(counts>8)]
hit_labels=get_hit_labels(restricted_labels)
hits['track_id']=hit_labels

In [None]:
# Score the DBSCAN-based track ids in hits['track_id'].
full_score(hits)

In [None]:
hits.sample(n=10)

In [None]:
# Inspect one particle's hits.
hits.query('particle_id==558447315866615808')

In [None]:
# Per-volume value stored as hits['sigma'] (presumably a position resolution
# per detector volume — confirm). Raises KeyError for a volume_id not listed.
sigma_map={7: 0.03, 8: 0.03, 9: 0.03, 12: 0.3, 13: 0.3, 14: 0.3, 16: 3, 17: 3, 18: 3}
hits['sigma']=[sigma_map[x] for x in hits['volume_id'].values]

In [None]:
# Array views of the pair nodes: codes, decoded (hit, hit) rows and their DBSCAN labels.
all_pairs_list=np.array(list(all_pairs))
decoded_pairs_list=np.array([list(decode_pair(p)) for p in all_pairs_list])
pair_label_list=np.array([pair_labels[p] for p in all_pairs_list])
max_label=max(pair_label_list)+1

In [None]:
# NOTE(review): label_index is not defined anywhere in this notebook; this
# cell raises NameError on a fresh kernel.
[i for i in range(len(hits)) if label_index[i]>0]

In [None]:
hits.sample(n=10)

In [None]:
hits.query('track_id==672').sort_values('rho')

In [None]:
hits.query('particle_id==567455133596647424').sort_values('rho')

In [None]:
# Row positions (not hit_ids) of all hits of particle 567455133596647424.
part_index=[]
for hit_id in hits.query('particle_id==567455133596647424')['hit_id'].values:
    part_index.extend(np.where(hits['hit_id'].values==hit_id)[0].tolist())
part_index

In [None]:
# Row position of hit_id 95336.
np.where(hits['hit_id'].values==95336)

In [None]:
# Debug view for hit row 696 (presumably the position of hit_id 95336 found
# above). Print the pairs containing it whose two hits both belong to the
# particle in part_index, then their neighbours (and neighbours' neighbours)
# in F that stay within that particle, with edge weight and pair label
# (0 = unlabelled). Pair nodes are encode_pair integers, so they are decoded
# before the membership test instead of being indexed like tuples.
part_set=set(part_index)
def _pair_in_particle(pair_code):
    """True if both hits of the encoded pair are rows of the selected particle."""
    a, b = decode_pair(pair_code)
    return a in part_set and b in part_set
for test_pair in hit_pairs[696]:
    # nodes may have been removed from G (and hence F) by the component pruning
    if test_pair not in F or not _pair_in_particle(test_pair):
        continue
    print('test pair: {}'.format(test_pair))
    for nbr in F[test_pair]:
        if _pair_in_particle(nbr):
            print('{}: {} [{}]'.format(nbr,F[test_pair][nbr]['weight'],pair_labels[nbr] if nbr in pair_labels else 0))
            for nnbr in F[nbr]:
                if _pair_in_particle(nnbr):
                    print('-{}: {} [{}]'.format(nnbr,F[nbr][nnbr]['weight'],pair_labels[nnbr] if nnbr in pair_labels else 0))

In [None]:
# Recompute, for hit row 3385 only, which positive pair label wins without
# any label restriction.
debug_hit=3385
this_hit_labels=[pair_labels[p] for p in hit_pairs[debug_hit] if pair_labels[p] > 0]
if len(this_hit_labels) > 0:
    labels,counts=np.unique(this_hit_labels,return_counts=True)
    # Keep the result separate: the previous version wrote it to hit_labels[i]
    # with a stale global loop index `i`, overwriting another hit's label.
    debug_label=labels[np.argmax(counts)]

In [None]:
# Candidate labels of the debug hit (from the previous cell).
labels

In [None]:
labels[np.argmax(counts)]

In [None]:
# Label actually assigned to hit row 3385 by get_hit_labels.
hit_labels[3385]

In [None]:
# Collect the (neighbour pair, score) edges of every pair containing hit row 5.
# in_graph / out_graph were never defined; `graph` already stores every edge
# in both directions, so it provides the same neighbourhood.
nhood=[]
for pair in hit_pairs[5]:
    nhood.extend(graph.get(pair, []))
# pair label (0 if unlabelled, e.g. a node pruned from G) and edge score
[(pair_labels.get(p, 0),s) for p,s in nhood]

In [None]:
# NOTE(review): repeats the display from the previous cell.
[(pair_labels[p],s) for p,s in nhood]

In [None]:
# Row position of hit_id 17943.
np.where(hits['hit_id'].values==17943)

In [None]:
# NOTE(review): earlier version of the edge-contraction cell below.
# all_edges, inner_pairs and outer_pairs are not defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. Kept for reference.
npedge=np.array(all_edges)
# sort edges by weight (column 2)
npedge=npedge[npedge[:,2].argsort()]
all_pairs=list(inner_pairs|outer_pairs)
pair_labels={}
merged_labels={}
C=0
for edge in npedge:
    p1,p2,_=edge
    # both ends labelled: record that the larger label merges into the smaller
    if p1 in pair_labels and p2 in pair_labels:
        l1=pair_labels[p1]
        l2=pair_labels[p2]
        if l1 != l2:
            merged_labels[max(l1,l2)]=min(l1,l2)
        continue
    # one end labelled: extend that label to the other end
    if p1 in pair_labels:
        pair_labels[p2]=pair_labels[p1]
        continue
    if p2 in pair_labels:
        pair_labels[p1]=pair_labels[p2]
        continue
    # neither end labelled: start a new label
    C=C+1
    pair_labels[p1]=C
    pair_labels[p2]=C

In [None]:
# Edges of G as rows (pair node, pair node, weight). np.array makes this a
# float array, so the node codes become floats here.
npedge=np.array(list(G.edges.data('weight')))

In [None]:
npedge[:10]

In [None]:
# contracting edges in order of weight
npedge=np.array(list(G.edges.data('weight')))
npedge=npedge[npedge[:,2].argsort()]
all_pairs=list(G.nodes)
pair_labels={}
pair_state={}
merged_labels={}
C=0
for edge in G.edges.data('weight'):
    p1,p2,_=edge
    if p1 in pair_state and pair_state[p1] < 1:
        continue
    if p2 in pair_state and pair_state[p2] > -1:
        continue
    if p1 in pair_labels and p2 in pair_labels:
        pair_state[p1]=0
        pair_state[p2]=0
        l1=pair_labels[p1]
        l2=pair_labels[p2]
        if l1 != l2:
            merged_labels[max(l1,l2)]=min(l1,l2)
        continue
    if p1 in pair_labels:
        pair_state[p1]=0
        pair_state[p2]=1
        pair_labels[p2]=pair_labels[p1]
        continue
    if p2 in pair_labels:
        pair_state[p1]=-1
        pair_state[p2]=0
        pair_labels[p1]=pair_labels[p2]
        continue
    C=C+1
    pair_labels[p1]=C
    pair_labels[p2]=C
    pair_state[p1]=-1
    pair_state[p2]=1

In [None]:
# Number of pair nodes per contraction label; show labels with more than 3 pairs.
labels,counts=np.unique([pair_labels[p] for p in all_pairs if p in pair_labels],return_counts=True)
labels[counts>3],counts[counts>3]

In [None]:
# Pair nodes carrying label 7.
[p for p in all_pairs if p in pair_labels and pair_labels[p]==7]

In [None]:
# Hits (row positions) covered by the pairs with contraction label 1. Pair
# nodes are encode_pair integers, so they are decoded rather than indexed
# like tuples; np.unique flattens and de-duplicates.
test_pts=np.unique(np.array([decode_pair(p) for p in all_pairs if pair_labels.get(p)==1],dtype=np.int64))
hits.loc[hits.index[test_pts]]

In [None]:
# Fraction of particles with at least one hit in a pair whose label has more
# than 2 pairs. Pair nodes are decoded with decode_pair. dtype=int64 keeps the
# index valid when nothing matches; the label set gives O(1) membership.
long_track_labels=set(labels[counts>2])
test_pts=np.array([decode_pair(p) for p in all_pairs if pair_labels.get(p) in long_track_labels],dtype=np.int64).reshape(-1)
len(np.unique(hits['particle_id'].values[test_pts]))/len(np.unique(hits['particle_id']))

In [None]:
# Purity check: for each contraction label with more than 2 pairs, do all of
# its hits come from a single particle? Pair nodes are decoded with
# decode_pair. Hits are grouped per label in one pass over all_pairs instead
# of one full pass per label.
long_track_labels=set(labels[counts>2])
hits_by_label={}
for p in all_pairs:
    lab=pair_labels.get(p)
    if lab in long_track_labels:
        hits_by_label.setdefault(lab,[]).extend(decode_pair(p))
particle_ids=hits['particle_id'].values
all_same_track=[]
for lab in labels[counts>2]:
    test_pts=np.array(hits_by_label[lab],dtype=np.int64)
    all_same_track.append(np.all(particle_ids[test_pts]==particle_ids[test_pts[0]]))
np.sum(all_same_track)/len(all_same_track)

In [None]:
# Apply the recorded label merges. merged_labels maps a label to the smaller
# label it merged into. Follow the chain to its end so multi-step merges
# collapse; the previous one-step lookup left 5 -> 3 when 3 -> 1 was also
# recorded. The chain terminates because each step goes to a smaller label.
# NOTE(review): merged_labels keeps one target per label, so a label merged
# with two different smaller labels only keeps the last one.
def _resolve_label(label):
    """Follow merged_labels from ``label`` to the smallest reachable label."""
    while label in merged_labels:
        label = merged_labels[label]
    return label
for p in all_pairs:
    if p in pair_labels:
        pair_labels[p]=_resolve_label(pair_labels[p])

In [None]:
def get_hit_labels(restrict_to_labels=None):
    """Label each hit with the most frequent label among its labelled pairs.

    Redefines the earlier get_hit_labels for the edge-contraction labels:
    every pair present in ``pair_labels`` counts (no > 0 filter). Reads the
    notebook globals ``hits``, ``hit_pairs`` and ``pair_labels``. When
    ``restrict_to_labels`` is given, only those labels count. Ties go to the
    smallest label; hits without a usable pair get 0.
    """
    # set membership is O(1); `x in ndarray` scans the whole array each time
    allowed = None if restrict_to_labels is None else set(restrict_to_labels)
    hit_labels=np.zeros(len(hits),dtype=np.int64)
    for i in range(len(hits)):
        this_hit_labels=[pair_labels[p] for p in hit_pairs[i] if p in pair_labels]
        if allowed is not None:
            this_hit_labels=[x for x in this_hit_labels if x in allowed]
        if len(this_hit_labels) > 0:
            labels,counts=np.unique(this_hit_labels,return_counts=True)
            hit_labels[i]=labels[np.argmax(counts)]
    return hit_labels
# First pass with all labels, then keep only labels carried by more than 2 hits and relabel.
hit_labels=get_hit_labels()
labels,counts=np.unique(hit_labels,return_counts=True)
restricted_labels=labels[np.where(counts>2)]
hit_labels=get_hit_labels(restricted_labels)
hits['track_id']=hit_labels

In [None]:
# Track sizes: hits per final track label (0 = unassigned).
np.unique(hit_labels,return_counts=True)

In [None]:
# NOTE(review): duplicate of the previous cell.
np.unique(hit_labels,return_counts=True)

In [None]:
# Score the edge-contraction track ids.
full_score(hits)