# In [None]:
import numpy as np
import fastcluster as fc
from scipy.cluster.hierarchy import fcluster



def encode(arr):
    """Collapse a 1-D sequence into one position-weighted checksum.

    Element ``k`` is multiplied by the weight ``(200 * k + 1) ** 2`` and all
    products are summed, so the result depends on both the values and their
    order in ``arr``.
    """
    positions = np.arange(len(arr))
    weights = np.square(positions * 200 + 1)
    weighted = arr * weights
    return np.sum(weighted)

def encode_intensity(cells):
    """Return the ``encode`` signature of every cell's inner STORM intensities.

    For each cell, ``cell.data.data_dict['storm_inner']['intensity']`` is passed
    to ``encode``. The results are returned as a NumPy array in input order.
    """
    # Codes are not written into a preallocated int array. That would truncate
    # float codes, and truncated codes could never compare equal to codes
    # computed directly with ``encode`` on STORM localizations.
    codes = [encode(cell.data.data_dict['storm_inner']['intensity']) for cell in cells]
    return np.asarray(codes)

def _cluster_localizations(st_elem, max_d):
    """Split one frame's localizations into spatial clusters cut at distance ``max_d``."""
    n_points = len(st_elem)
    if n_points < 2:
        # Hierarchical linkage needs at least two observations; zero or one
        # point trivially forms at most one cluster.
        return [st_elem] if n_points else []
    X = np.column_stack((st_elem['x'], st_elem['y']))
    linkage = fc.linkage(X)  # fastcluster's default linkage method
    labels = fcluster(linkage, max_d, criterion='distance')
    return [st_elem[labels == label] for label in np.unique(labels)]


def match_cells(gt_cells, m_cells, storm_input, filtered_binaries, max_d=5,
                max_dist=10, max_gt=9000):
    """Pair ground-truth cells with measured cells through their STORM localizations.

    For every frame in ``storm_input``, the localizations are clustered
    (hierarchical clustering cut at distance ``max_d``). Each cluster is then
    matched in two steps:

    * It is identified as a ground-truth cell by comparing ``encode`` of its
      intensities with ``encode_intensity`` of the ground-truth cells.
    * It is assigned to the measured cell of the same image whose region
      centre of mass lies nearest to the cluster's mean (y, x) position.

    Parameters
    ----------
    gt_cells : sequence of cells
        Ground-truth cells. Must support integer indexing and slicing.
    m_cells : array-like of cells
        Measured cells. Must support boolean-mask indexing. The first integer
        in ``cell.name`` is read as the 0-based image number and the second as
        the cell's label in ``filtered_binaries`` (assumed naming scheme — TODO
        confirm).
    storm_input : table / structured array
        Localizations with 'frame', 'x', 'y' and 'intensity' fields. Frame
        numbers are 1-based.
    filtered_binaries : sequence of 2D label images
        Indexed by 0-based frame; cell ``j`` occupies the pixels equal to ``j``.
    max_d : float
        Distance at which the cluster dendrogram is cut.
    max_dist : float
        Largest allowed distance between a cluster centroid and a cell's centre
        of mass; clusters further away are skipped.
    max_gt : int or None
        Only the first ``max_gt`` ground-truth cells are candidates for
        matching (``None`` uses all of them).

    Returns
    -------
    gt_matched, m_matched : list, list
        Matched ground-truth and measured cells, pairwise aligned.
    """
    import re
    from scipy import ndimage

    img_numbers = np.array([int(re.findall(r'(\d+)', cell.name)[0]) for cell in m_cells])
    encoded_gt = encode_intensity(gt_cells[:max_gt])

    gt_matched, m_matched = [], []

    for frame in np.unique(storm_input['frame']):  # Iteration starts at 1 (ImageJ indexing)
        frame_idx = int(frame) - 1  # 0-based index into images / cell names
        st_elem = storm_input[storm_input['frame'] == frame]
        clustered_st = _cluster_localizations(st_elem, max_d)

        s_cells = m_cells[img_numbers == frame_idx]
        if len(s_cells) == 0:
            # No measured cells in this image: nothing to assign clusters to.
            continue
        binary_img = filtered_binaries[frame_idx]
        cell_numbers = [int(re.findall(r'(\d+)', cell.name)[1]) for cell in s_cells]
        # (row, col) centre of mass of each measured cell's region
        coms_cells = np.array([ndimage.center_of_mass(binary_img == j) for j in cell_numbers])

        for cluster in clustered_st:
            # Find the GT cell. encode() is order-sensitive, so this assumes the
            # cluster rows appear in the same order as the cell's 'storm_inner'
            # rows — TODO confirm.
            code = encode(cluster['intensity'])
            idx_gt = np.flatnonzero(encoded_gt == code)
            if idx_gt.size == 0:
                print('Cluster not in cells, probably bordering cell')
                continue
            gt_cell = gt_cells[int(idx_gt[0])]  # first hit if codes collide

            # Find the M cell: nearest centre of mass to the cluster centroid (y, x)
            com_storm = (np.mean(cluster['y']), np.mean(cluster['x']))
            ds = np.hypot(coms_cells[:, 0] - com_storm[0], coms_cells[:, 1] - com_storm[1])
            # A label absent from the image yields a NaN centre; never select it.
            ds = np.where(np.isfinite(ds), ds, np.inf)
            idx_m = int(np.argmin(ds))
            if ds[idx_m] > max_dist:
                print('Too far away')
                continue
            m_cell = s_cells[idx_m]

            gt_matched.append(gt_cell)
            m_matched.append(m_cell)

    return gt_matched, m_matched
        
        