# FISH decoding

Aim:  
Identify RNA species from spots coordinates `[round, z, y, x]` and a codebook `{roud | barcode: species}` for non barcoded and barcoded FSIH experiments.

In [43]:
import numpy as np
import pandas as pd
from localize_psf import fish
from tysserand import tysserand as ty
from scipy.spatial.distance import cdist
from itertools import permutations
from copy import copy

In [2]:
def generate_nonbarcoded_spots(species, n_spots, tile_shape):
    """
    Parameters
    ----------
    species: list
        Name of species whose FISH spots are simulated
    n_spots : int | array | list
        Number of spots per species, identical for all species if it's an int,
        otherwise specific to each species if list of array.
    tile_shape : array | list
        Image dimensions, can be any length but 2 or 3 make sense.
    
    Returns
    -------
    coords : DataFrame
        Coordinates of all spots with their identity.
    codebook : dict
        Association from species to rounds.

    Example
    -------
    >>> species = ['a', 'b', 'c']
    >>> n_spots = 10
    >>> tile_shape = [50, 250, 500]
    >>> generate_nonbarcoded_spots(species, n_spots, tile_shape)
    """

    nb_species = len(species)
    nb_dims = len(tile_shape)
    if isinstance(n_spots, int):
        n_spots = [n_spots] * nb_species
    coords = []
    round_ids = []
    specs = []

    for round_id, (spec, n_spec) in enumerate(zip(species, n_spots)):
        spec_coords = [np.random.random(n_spec) * dim for dim in tile_shape]
        spec_coords = np.vstack(spec_coords).T
        coords.append(spec_coords)
        round_ids.extend([round_id] * n_spec)
        specs.extend([spec] * n_spec)
    coords = np.vstack(coords)

    if nb_dims == 2:
        coords_names = ['y', 'x']
    elif nb_dims == 3:
        coords_names = ['z', 'y', 'x']
    else:
        coords_names = [f'dim-{x}' for x in range(nb_dims)]
    
    coords = pd.DataFrame(data=coords, columns=coords_names)
    coords['rounds'] = round_ids
    coords['species'] = specs

    codebook = {key: var for key, var in zip(species, range(nb_species))}

    return coords, codebook
    

In [3]:
def identify_nonbarcoded_spots(coords, codebook):
    """
    Parameters
    ----------
    coords : DataFrame
        Coordinates of all spots.
    codebook : dict
        Association from species to rounds.
    
    Returns
    -------
    species : DataFrame
        Decoded FISH spots identities with their coordinates.

    Example
    -------
    >>> species_labels = ['a', 'b', 'c']
    >>> n_spots = 10
    >>> tile_shape = [50, 250, 500]
    >>> coords, codebook = generate_nonbarcoded_spots(species_labels, n_spots, tile_shape)
    >>> measured_coords = coords.drop(columns=['species'])
    >>> identify_nonbarcoded_spots(measured_coords, codebook)
    """

    # make dictionnary round --> species from dictionnary species --> round
    inv_codebook = {val: key for key, val in codebook.items()}
    species = coords.copy()
    species['species'] = species['rounds'].map(inv_codebook)
    return species

In [4]:
def generate_barcoded_spots(species, n_spots, tile_shape, codebook):
    """
    Parameters
    ----------
    species: list
        Name of species whose FISH spots are simulated
    n_spots : int | array | list
        Number of spots per species, identical for all species if it's an int,
        otherwise specific to each species if list of array.
    tile_shape : array | list
        Image dimensions, can be any length but 2 or 3 make sense.
    codebook : dict
        Association from species to rounds barcodes, given by
        a string of zeros and ones like '01001101'.
    
    Returns
    -------
    coords : DataFrame
        Coordinates of all spots with their identity.

    Example
    -------
    >>> species = ['a', 'b', 'c']
    >>> n_spots = 4
    >>> tile_shape = [25, 250, 500]
    >>> codebook = {'a': '01', 'b': '10', 'c': '11'}
    >>> generate_barcoded_spots(species, n_spots, tile_shape, codebook)
    """

    # TODO: implement noise in successive coordinates, use uniform distribution
    nb_species = len(species)
    nb_dims = len(tile_shape)
    nb_rounds = len(next(iter(codebook.values())))
    if isinstance(n_spots, int):
        n_spots = [n_spots] * nb_species
    coords = []
    round_ids = []
    specs = []

    for (spec, n_spec) in zip(species, n_spots):
        barcode = codebook[spec]
        # go from '01001101' to [1, 4, 5, 7]
        spec_rounds = [i for i, x in enumerate(barcode) if x == '1']
        # number of positive rounds
        n_pos_round = sum([int(x) for x in barcode])
        # total number of spots across images
        n_spec_tot  = n_spec * n_pos_round

        # generate coordinates, identical across rounds
        spec_coords = [np.random.random(n_spec) * dim for dim in tile_shape]
        spec_coords = np.vstack(spec_coords).T
        # repeat and stack coordinates to match number of rounds
        for _ in range(n_pos_round - 1):
            spec_coords = np.vstack([spec_coords, spec_coords])
        coords.append(spec_coords)
        # indicate at what round coordinates are observed
        for round_id in spec_rounds:
            round_ids.extend([round_id] * n_spec)
        # indicate the ground truth species for all spots
        specs.extend([spec] * n_spec_tot)
    coords = np.vstack(coords)

    if nb_dims == 2:
        coords_names = ['y', 'x']
    elif nb_dims == 3:
        coords_names = ['z', 'y', 'x']
    else:
        coords_names = [f'dim-{x}' for x in range(nb_dims)]
    
    coords = pd.DataFrame(data=coords, columns=coords_names)
    coords['rounds'] = round_ids
    coords['species'] = specs

    return coords

In [5]:
species = ['a', 'b', 'c']
n_spots = 3
tile_shape = [25, 250, 500]
codebook = {'a': '01', 'b': '10', 'c': '11'}
coords = generate_barcoded_spots(species, n_spots, tile_shape, codebook)

In [6]:
coords

Unnamed: 0,z,y,x,rounds,species
0,6.52619,66.26363,37.694211,1,a
1,21.896308,70.17479,383.429816,1,a
2,6.455076,159.293727,387.835716,1,a
3,4.398017,189.69977,42.396012,0,b
4,5.090927,161.164849,211.913255,0,b
5,17.802741,174.475235,88.545973,0,b
6,20.079602,14.120343,270.644448,0,c
7,18.047167,33.550688,33.45474,0,c
8,11.23932,114.389863,212.461001,0,c
9,20.079602,14.120343,270.644448,1,c


In [7]:
def identify_barcoded_spots(coords, codebook, distances):
    """
    Parameters
    ----------
    coords : DataFrame
        Coordinates of all spots.
    codebook : dict
        Association from species to rounds.
    
    Returns
    -------
    species : DataFrame
        Decoded FISH spots identities with their coordinates.

    Example
    -------
    >>> species_labels = ['a', 'b', 'c']
    >>> n_spots = 10
    >>> tile_shape = [50, 250, 500]
    >>> coords, codebook = generate_nonbarcoded_spots(species_labels, n_spots, tile_shape)
    >>> measured_coords = coords.drop(columns=['species'])
    >>> identify_nonbarcoded_spots(measured_coords, codebook)
    """

    nb_rounds = len(next(iter(codebook.values())))
    # make dictionnary barcode --> species from dictionnary species --> barcode
    inv_codebook = {val: key for key, val in codebook.items()}
    max_z = distances[0]
    max_xy = distances[1]
    
    # max projection of coordinates over rounds:
    # ignore the round information and fuse neighbooring spots
    proj_coords = coords.iloc[:, 0:3]
    # proj_coords = fish.merge_peaks(proj_coords, max_z=distances[0], max_xy=distances[1], weights=None, method='network', verbose=True)
    nb_nodes = len(proj_coords)
    for node_id in range(nb_nodes):
        for round_id in range(nb_rounds):
            # build the radial distance network using the bigest radius: max distance along z axis
            pairs = ty.build_rdn(coords=proj_coords, r=max_z)
            if len(pairs) == 0:
                # all nodes are well separated from each other, do nothing
                merged_coords = proj_coords
            else:
                # fuse nodes that have too close neighbors
                source = proj_coords[pairs[:, 0]]
                target = proj_coords[pairs[:, 1]]
                # compute the 2 distances arrays
                dist_z, dist_xy = fish.compute_distances(source, target)
                # perform grph cut from the 2 distance thresholds
                _, pairs = fish.cut_graph_bidistance(dist_z, dist_xy, max_z, max_xy, pairs=pairs)




Barcode reconstruction algorithm:
  - for each spot in each round, we reconstruct all potential barcodes:
    - we look for its neighbors in a given radius across rounds, there can be multiple ones per round
    - if there is at least 1 neighbors in a round, the corresponding barcode bit is 1, else 0
  - clean redundant barcodes related to same locations:
    - merge identical barcodes within the same area
    - keep differing barcodes within the same area

If barcode reconstruction per round is a single function, it's easy to parallelize per round. If we have many more cores, we can split spots coordinates per round across several cores.

In [114]:
a = np.arange(20) 
# for i in np.nditer(a, op_flags=['readwrite']):
#     print(i)
#     print(a)
#     a = np.delete(a, i+1)
#     print(a)

# i = 0
# while i < len(a):
#     print(i, a[i])
#     print(a)
#     a = np.delete(a, i+1)
#     print(a)
#     i += 1

a = np.array([[1, 1], [0, 1], [2, 3], [3, 1], [4, 4], [5, 4]])



k = 0
while k < len(a):
    i, j = a[k]
    print(k, a[k])
    print(a)
    if i == j:
        a = a[a[:, 1] != j, :]
    else:
        k += 1
    print(a)
    


0 [1 1]
[[1 1]
 [0 1]
 [2 3]
 [3 1]
 [4 4]
 [5 4]]
[[2 3]
 [4 4]
 [5 4]]
1 [4 4]
[[2 3]
 [4 4]
 [5 4]]
[[2 3]]


In [163]:
def array_to_dict(arr):
    return dict(enumerate(arr))

def dict_to_array(dico):
    return np.array(list(dico.values()))
    
def compute_distances(source, target, dist_method='xy_z_orthog', metric='euclidean', tilt_vector=None):
    """
    Parameters
    ----------
    source : ndarray
        Coordinates of the first set of points.
    target : ndarray
        Coordinates of the second set of points.
    dist_method : str
        Method used to compute distances. If 'isotropic', standard distances are computed considering all axes
        simultaneously. If 'xy_z_orthog' 2 distances are computed, for the xy plane and along the z axis 
        respectively. If 'xy_z_tilted' 2 distances are computed for the tilted plane and its normal axis.
    
    Example
    -------
    >>> source = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]])
    >>> target = np.array([[0, 0, 0], [-3, 0, 2], [0, 0, 10]])
    >>> compute_distances(source, target)
        (array([0, 4, 0]), array([0., 2., 5.]))
    >>> compute_distances(source, target, metric='L1')
        (array([0, 4, 0]), array([0, 2, 7]))
    
    """
    if dist_method == 'isotropic':
        dist = cdist(source, target, metric=metric)
        return dist
        
    elif dist_method == 'xy_z_orthog':
        dist_xy = cdist(source[:, 1:], target[:, 1:], metric=metric)
        dist_z = cdist(source[:, 0].reshape(-1, 1), target[:, 0].reshape(-1, 1), metric=metric)
        return dist_z, dist_xy

    elif dist_method == 'xy_z_tilted':
        raise NotImplementedError("Method 'xy_z_tilted' will be implemented soon")
        
        
def find_neighbor_spots_in_round(source, target, dist_method='xy_z_orthog', 
                                 metric='euclidean', dist_params=None,
                                 return_bool=False):
    """
    For each spot in a given round ("source"), find if there are neighbors 
    in another round ("target") within a given distance.

    Parameters
    ----------
    source : ndarray
        Coordinates of spots in the source round.
    target : ndarray
        Coordinates of spots in the target round.
    dist_method : str
        Method used to compute distance between spots.
        Can be isotropic, or xy_z_orthog
    dist_params : float or array
        Threshold distance to classify spots as neighbors. 
        Multiple threshold can be used depending on the method, typically 2 
        to have a threshold for the xy plane and one for the z axis.
    return_bool : bool
        If True, return a vector indicating the presence of neighbors
        for spots in the source set.

    Returns
    -------
    pairs : ndarray
        Pairs of neighbors.
    has_neighb : array
        Array indicating the presence of neighbors for each spot in their source round.
    
    Example
    -------
    >>> source = np.array([[0, 0, 0],
                           [0, 2, 0]])
    >>> target = np.array([[0, 0, 0],
                           [1, 0, 0],
                           [0, 2, 0],
                           [0, 0, 3]])
    """

    # Compute all distances between spots of given round and all other spots of other round
    dist = compute_distances(source, target, dist_method=dist_method, metric=metric)
    # check if distances below threshold for all dimensions
    if dist_method == 'xy_z_orthog':
        is_neighb = np.logical_and(dist[0] < dist_params[0], dist[1] < dist_params[1])
    elif dist_method == 'isotropic':
        is_neighb = np.logical_and(dist < dist_params)
    
    if return_bool:
        # detect if there is any neighbor for each spot
        has_neighb = np.any(is_neighb, axis=1)
        return has_neighb
    else:
        # extract pairs of neighboring spots
        y, x = np.where(is_neighb)
        pairs = np.vstack([y, x]).T
        return pairs


def make_all_rounds_pairs(start=0, end=16):
    pairs_rounds = list(permutations(range(0, 16), 2))
    return pairs_rounds


def assemble_barcodes(neighbors):
    """
    Parameters
    ----------
    neighbors : dict[dict[array]]
        Dictionnary of dictionnaries, where the fist level of keys is the
        set of source rounds, and the second level of key is the set of
        target round. Each second level value is an array indicating the 
        presence of neighbors from spots in the source round to spots in
        the target round.
    
    Returns
    -------
    barcodes : dict[array]
        Dictionnary of barcodes found for each spot in a source round,
        keys indicate the id of the source round.
    
    Example
    -------
    >>> neighbors = {2: {1: [0, 1, 2, 3],
                         0: [4, 5, 6, 7]},
                     0: {1: [8, 9, 10, 11],
                         2: [12, 13, 14, 15]},
                     1: {2: [16, 17, 18, 19],
                         0: [20, 21, 22, 23]}}
    >>> assemble_barcodes(neighbors)
    {2: array([[4, 0, 1],
               [5, 1, 1],
               [6, 2, 1],
               [7, 3, 1]]),
     0: array([[ 1,  8, 12],
               [ 1,  9, 13],
               [ 1, 10, 14],
               [ 1, 11, 15]]),
     1: array([[20,  1, 16],
               [21,  1, 17],
               [22,  1, 18],
               [23,  1, 19]])}
    """

    # dictionary storing all barcodes matrices for each round
    barcodes = {}
    # for each round, stack vectors of neighbors into arrays across target rounds
    for round_source, round_targets in neighbors.items():
        # get sorted list of target round IDs
        round_ids = np.unique([i for i in round_targets.keys()])
        # initialize empty array
        nb_neigh = len(round_targets[round_ids[0]])
        nb_rounds = round_ids.size + 1 # because we consider current source round
        # get the type of data and choose between1, 1.0 and True
        fill_value = round_targets[round_ids[0]][0]
        if isinstance(fill_value, bool):
            fill_value = True
        elif isinstance(fill_value, int):
            fill_value = 1
        else:
            fill_value = 1.0
        # initilize array, which sets bits of the current source round to 1 or True
        round_barcode = np.full(shape=(nb_neigh, nb_rounds), fill_value=fill_value)
        # stack each vector in the array
        for round_id in round_ids:
            round_barcode[:, round_id] = round_targets[round_id]
        # save the array in the barcode dictionary
        barcodes[round_source] = round_barcode
    return barcodes


def clean_barcodes(barcodes, coords, min=3, max=5):
    """
    Remove barcodes and their corresponding coordinates if they
    have too few or too many positive bits.

    Parameters
    ----------
    barcodes : dict[array]
        Dictionnary of barcodes found for each spot in a source round,
        keys indicate the id of the source round.
    coords : list(arrays)
        Coordinates of all spots in rounds.
    min : int
        Minimum number of positive bits each barcode should have.
    max : int
        Maximum number of positive bits each barcode should have.
    
    Returns
    -------
    barcodes : dict[array]
        Dictionnary of barcodes found for each spot in a source round,
        keys indicate the id of the source round.
    coords : list(arrays)
        Coordinates of all spots in rounds.
    
    Example
    -------
    >>> barcodes = {0: np.array([[1, 1, 1, 1, 1, 1],
                                 [0, 0, 1, 1, 1, 0]]),
                    1: np.array([[1, 0, 0, 1, 0, 0],
                                 [1, 0, 1, 0, 1, 1]])}
    >>> coords = [np.array([[1, 2, 3],
                            [4, 5, 6]]),
                  np.array([[7, 8, 9],
                            [10, 11, 12]])]
    >>> clean_barcodes(barcodes, coords, min=3, max=5)
    ({0: array([[0, 0, 1, 1, 1, 0]]), 1: array([[1, 0, 1, 0, 1, 1]])},
    [array([[4, 5, 6]]), array([[10, 11, 12]])])
    """

    for rd_id, rd_barcode in barcodes.items():
        select = np.logical_and(rd_barcode.sum(axis=1) >= min,
                                rd_barcode.sum(axis=1) <= max)
        # selection by key to be sure assignment is effective, necessary?
        barcodes[rd_id] = rd_barcode[select, :] 
        coords[rd_id] = coords[rd_id][select, :]
    return barcodes, coords

def merge_barcodes_pairs_rounds(barcodes_1, barcodes_2, coords_1, coords_2, 
                                dist_method='xy_z_orthog', metric='euclidean', 
                                dist_params=None):
    """
    Merge barcodes and their corresponding coordinates in a pair of rounds
    when they are identical and close enough to each other.

    Parameters
    ----------
    barcodes_1 : ndarray
        Barcodes of first round, shape (n_barcodes, n_rounds).
    barcodes_2 : ndarray
        Barcodes of second round, shape (n_barcodes, n_rounds).
    coords_1 : ndarray
        Coordinates of barcodes of first round, shape (n_barcodes, dim_image).
    coords_2 : ndarray
        Coordinates of barcodes of second round, shape (n_barcodes, dim_image).
    dist_method : str
        Method used to compute distance between spots.
        Can be isotropic, or xy_z_orthog
    metric : str
        Metric used to compute distance between 2 points.
    dist_params : float or array
        Threshold distance to classify spots as neighbors. 
        Multiple threshold can be used depending on the method, typically 2 
        to have a threshold for the xy plane and one for the z axis.
    
    Returns
    -------
    barcodes_out : ndarray
        Merged barcodes.
    coords_out : ndarray
        Merged coordinates of barcodes.

    Example
    -------
    >>> barcodes_1 = np.array([[1, 1, 1, 1],
                               [0, 0, 0, 0]])
    >>> barcodes_2 = np.array([[1, 1, 1, 1],
                               [0, 0, 0, 0],
                               [0, 0, 0, 0]])
    >>> coords_1 = np.array([[0, 0, 0],
                             [1, 2, 2]])
    >>> coords_2 = np.array([[0, 0, 0],
                             [0, 0, 0],
                             [2, 2, 2]])
    >>> merge_barcodes_pairs_rounds(barcodes_1, barcodes_2, coords_1, coords_2, 
                                    dist_params=[0.6, 0.2])
        (array([[1, 1, 1, 1],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0]]),
        array([[0, 0, 0],
                [1, 2, 2],
                [0, 0, 0],
                [2, 2, 2]]))
    """

    # find all pairs of neighbors between the 2 rounds
    pairs = find_neighbor_spots_in_round(coords_1, coords_2, dist_method, 
                                            metric, dist_params)
    # change to dictionary to delete entries / indices without shifting
    # element by more than one index if previous elements need to be discarded
    barcodes_2 = array_to_dict(barcodes_2)
    coords_2 = array_to_dict(coords_2)

    # very manual iteration to allow modification of the `pairs` array 
    # while iterating over it
    k = 0
    while k < len(pairs):
        i, j = pairs[k]
        if np.all(barcodes_1[i] == barcodes_2[j]):
            # delete barcode and coordinates in the second set
            del barcodes_2[j]
            del coords_2[j]
            select = np.logical_or(k < np.arange(len(pairs)), pairs[:, 1] != j)
            pairs = pairs[select, :]
        else:
            # need to increment only if we don't delete current element in pairs
            k += 1
    # convert back to array for stacking and future distance computation
    barcodes_2 = dict_to_array(barcodes_2)
    coords_2 = dict_to_array(coords_2)

    # stack all remaining barcodes and coordinates
    barcodes_out = np.vstack([barcodes_1, barcodes_2])
    coords_out = np.vstack([coords_1, coords_2])
    
    return barcodes_out, coords_out


def make_pyramidal_pairs(base):
    """
    Make successive lists resulting from merging pairs in previous list,
    until a list of a unique pair is reached.

    Parameters
    ----------
    base : list
        A list of elements that will be successively merged.
    
    Returns
    -------
    pyramid : list
        A list of lists, each of them containing pairs of merged
        items from the previous list.
    
    Example
    -------
    >>> base = list(range(5))
    >>> make_pyramidal_pairs(base)
    [[0, 1, 2, 3, 4], [[0, 1], [2, 3], [4]], [[0, 2], [4]], [[0, 4]]]
    """

    # Make first level of the pyramid with base
    pyramid = [base]
    # First iteration in numbers, not pairs of number
    level = [[pyramid[-1][2*i], pyramid[-1][2*i + 1]] for i in range(len(pyramid[-1]) // 2)]
    if len(pyramid[-1]) % 2 == 1:
        level.append([pyramid[-1][-1]])
    pyramid.append(level)
    # Next iterations on pairs of numbers
    while len(pyramid[-1]) > 1:
        level = [[pyramid[-1][2*i][0], pyramid[-1][2*i + 1][0]] for i in range(len(pyramid[-1]) // 2)]
        if len(pyramid[-1]) % 2 == 1:
            level.append([pyramid[-1][-1][0]])
        pyramid.append(level)
    return pyramid
    

def merge_barcodes(barcodes, coords, dist_method='xy_z_orthog', 
                   metric='euclidean', dist_params=None):
    """
    Merge all barcodes and their corresponding coordinates in all rounds
    when they are identical and close enough to each other.

    Parameters
    ----------
    barcodes : ndarray
        Barcodes detected starting from spots in all rounds, shape (n_barcodes, n_rounds).
    coords : ndarray
        Coordinates of barcodes, shape (n_barcodes, dim_image).
    dist_method : str
        Method used to compute distance between spots.
        Can be isotropic, or xy_z_orthog
    metric : str
        Metric used to compute distance between 2 points.
    dist_params : float or array
        Threshold distance to classify spots as neighbors. 
        Multiple threshold can be used depending on the method, typically 2 
        to have a threshold for the xy plane and one for the z axis.
    
    Returns
    -------
    barcodes_out : ndarray
        Merged barcodes.
    coords_out : ndarray
        Merged coordinates of barcodes.

    Example
    -------
    >>> barcodes = {0: np.array([[1, 1, 1, 1],
                                 [1, 1, 0, 0],
                                 [1, 0, 1, 0],
                                 [1, 0, 0, 1]),
                    1: np.array([[1, 1, 1, 1],
                                 [0, 1, 1, 0]]),
                    2: np.array([[1, 1, 1, 1],
                                 [1, 0, 1, 0]]),
                    3: np.array([[1, 1, 1, 1],
                                 [1, 1, 1, 1],
                                 [1, 0, 0, 1]]),}
    >>> coords = [np.zeros_like(barcodes[0],
                  np.zeros_like(barcodes[1]),
                  np.zeros_like(barcodes[2]),
                  np.array([[0, 0, 0],
                            [5, 5, 5],
                            [0, 0, 0]])]
    >>> merge_barcodes(barcodes, coords, dist_params=[0.6, 0.2])
    """

    # get sorted list of round IDs
    round_ids = np.unique([i for i in barcodes.keys()])

    # Pyramidal merge of pairs of rounds
    pyram_levels = make_pyramidal_pairs(round_ids)
    # something like [[0, 1, 2, 3], [[0, 1], [2, 3]], [[0, 2]]]
    for level_pairs in pyram_levels[1:]:
        # for ex: [[0, 1], [2, 3]]
        print(level_pairs)
        for pair in level_pairs:
            print(pair)
            # for ex: [0, 1]
            if len(pair) == 2:
                # avoid runing on a singlet
                barcodes_1 = barcodes[pair[0]]
                barcodes_2 = barcodes[pair[1]]
                coords_1 = coords[pair[0]]
                coords_2 = coords[pair[1]]

                barcodes[pair[0]], coords[pair[0]] = merge_barcodes_pairs_rounds(
                    barcodes_1, barcodes_2, coords_1, coords_2, 
                    dist_method, metric, dist_params,
                    )
                # # clean-up space
                # barcodes[pair[1]] = None
                # coords[pair[1]] = None
    return barcodes, coords
            
    

def find_neighbor_spots_across_rounds(coords, round_id, 
                                      dist_method='xy_z_orthog', metric='euclidean', 
                                      dist_params=None):
    """
    For each spot in a given round, find if there are neighbors 
    in each other rounds within a given distance, and reconstruct barcodes from that.

    Parameters
    ----------
    spots_coords : ndarray
        Coordinates of spots in a given round.
    coords : list(arrays)
        Coordinates of all spots in rounds.
    round_id : int
    dist_method : str
        Method used to compute distance between spots.
        Can be isotropic, or xy_z_orthog
    metric : str
        Metric used to compute distance between 2 points.
    dist_params : float or array
        Threshold distance to classify spots as neighbors. 
        Multiple threshold can be used depending on the method, typically 2 
        to have a threshold for the xy plane and one for the z axis.

    Returns
    -------
    barcode : array
        Reconstructed barcode from data around each spot location.
    
    Example
    -------
    >>> spots_coords = np.array([[0, 0, 0],
                                [0, 2, 0]])
    >>> round_coords = np.array([[0, 0, 0],
                                 [1, 0, 0],
                                 [0, 2, 0],
                                 [0, 0, 3]])
    """
    
    nb_rounds = len(coords)
    round_pairs = make_all_rounds_pairs(start=0, end=nb_rounds)

    # store all potential neighbors decected from each round to the other
    all_neighbors= {}
    for pair in round_pairs:
        neighbors = find_neighbor_spots_in_round(coords[pair[0]], coords[pair[1]], return_bool=True)

    barcodes = assemble_barcodes(neighbors)

    # remove barcodes that have too few or too many positive bits
    barcodes = clean_barcodes(barcodes, coords)

    barcodes = merge_barcodes(barcodes, coords)

    return barcodes


In [164]:
barcodes = {0: np.array([[1, 1, 1, 1],
                         [1, 1, 0, 0],
                         [1, 0, 1, 0],
                         [1, 0, 0, 1]]),
            1: np.array([[1, 1, 1, 1],
                         [0, 1, 1, 0]]),
            2: np.array([[1, 1, 1, 1],
                         [1, 0, 1, 0]]),
            3: np.array([[1, 1, 1, 1],
                         [1, 1, 1, 1],
                         [1, 0, 0, 1]]),}
coords = [np.zeros_like(barcodes[0]),
          np.zeros_like(barcodes[1]),
          np.zeros_like(barcodes[2]),
          np.array([[0, 0, 0, 0],
                    [5, 5, 5, 5],
                    [0, 0, 0, 0]])]

In [165]:
barcodes

{0: array([[1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 0, 1]]),
 1: array([[1, 1, 1, 1],
        [0, 1, 1, 0]]),
 2: array([[1, 1, 1, 1],
        [1, 0, 1, 0]]),
 3: array([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 0, 0, 1]])}

In [166]:
coords

[array([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]),
 array([[0, 0, 0, 0],
        [0, 0, 0, 0]]),
 array([[0, 0, 0, 0],
        [0, 0, 0, 0]]),
 array([[0, 0, 0, 0],
        [5, 5, 5, 5],
        [0, 0, 0, 0]])]

In [167]:
dist_method='xy_z_orthog'
metric='euclidean'
dist_params=[0.6, 0.2]

# get sorted list of round IDs
round_ids = np.unique([i for i in barcodes.keys()])

# Pyramidal merge of pairs of rounds
pyram_levels = make_pyramidal_pairs(round_ids)
# something like [[0, 1, 2, 3], [[0, 1], [2, 3]], [[0, 2]]]
for level_pairs in pyram_levels[1:]:
    # for ex: [[0, 1], [2, 3]]
    print(level_pairs)
    for pair in level_pairs:
        print(pair)
        # for ex: [0, 1]
        barcodes_1 = barcodes[pair[0]]
        barcodes_2 = barcodes[pair[1]]
        coords_1 = coords[pair[0]]
        coords_2 = coords[pair[1]]

        barcodes[pair[0]], coords[pair[0]] = merge_barcodes_pairs_rounds(
            barcodes_1, barcodes_2, coords_1, coords_2, 
            dist_method, metric, dist_params,
            )
        barcodes[pair[1]] = None
        coords[pair[1]] = None 

[[0, 1], [2, 3]]
[0, 1]


KeyError: 0

In [145]:
barcodes

{0: array([[1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 0, 1],
        [1, 1, 1, 1],
        [0, 1, 1, 0],
        [1, 1, 1, 1],
        [1, 0, 1, 0],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 0, 0, 1]]),
 1: None,
 2: None,
 3: None}

In [146]:
coords

[array([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [5, 5, 5, 5],
        [0, 0, 0, 0]]),
 None,
 None,
 None]