In [3]:
import numpy as np

In [33]:
def sort(crd, axis=-1):
    """Sort coordinates by putting NaN at the end for each list of spots.
    
    Args:
        crd (np.array((ncell, ndomain, ncopy_max, nspot_max, 3), dtype=float))
    """
    
    # Mask the NaN values (taking the x coord, doesn't matter which one)
    nan_mask = np.isnan(crd[:, :, :, :, 0])  # np.array((ncell, ndomain, ncopy_max, nspot_max), dtype=bool)
    # Sort the coordinates by the mask (False first, then True),
    # only on the last axis (the spots axis)
    sort_subgrid = np.argsort(nan_mask, axis=axis)
    # Create a grid of indices to sort the coordinates
    grid = np.meshgrid(*[np.arange(i) for i in nan_mask.shape], indexing='ij')
    # Replace the last axis of the grid by the sorted indices
    grid[axis] = sort_subgrid
    # Convert the grid to a tuple of indices
    idx = tuple(grid)
    # Initialize the sorted coordinates
    crd_srt = np.copy(crd)
    # Sort the coordinates for each spatial dimension
    for i in range(3):
        crd_srt[:, :, :, :, i] = crd[:, :, :, :, i][idx]
    return crd_srt

In [34]:
# Create a sample array
ncell, ndomain, ncopy_max, nspot_max = 3, 10, 2, 4
np.random.seed(42)
crd = np.random.rand(ncell, ndomain, ncopy_max, nspot_max, 3)

# introduce some random NaNs
nan_prob = 0.3
for cellnum in range(ncell):
    for domainnum in range(ndomain):
        for copynum in range(ncopy_max):
            for spotnum in range(nspot_max):
                if np.random.rand() < nan_prob:
                    crd[cellnum, domainnum, copynum, spotnum, :] = np.array([np.nan, np.nan, np.nan])

crd_srt = sort(crd)

In [35]:
for cellnum in range(ncell):
    for domainnum in range(ndomain):
        for copy in range(ncopy_max):
            for i in range(3):
                x = crd[cellnum, domainnum, copy, :, i]
                x_srt = crd_srt[cellnum, domainnum, copy, :, i]
                # check that x and x_srt have the same number of NaNs
                assert np.sum(np.isnan(x)) == np.sum(np.isnan(x_srt))
                # check that has NaNs only at the end
                if np.sum(np.isnan(x_srt)) == 0:
                    continue
                first_nan = np.where(np.isnan(x_srt))[0][0]
                assert np.all(np.isnan(x_srt[first_nan:]))

In [36]:
crd[...]  # np.array((ndomain_chr, ncopy_max, nspot_max, 3), dtype=float)
chroms = ['chr1', 'chr2', 'chr3', 'chrX']
chromstr = np.array(['chr1', 'chr1', 'chr1', 'chr2', 'chr2', 'chr3', 'chrX', 'chrX'])

for cellnum in range(ncell):
    for chrom in chroms:
        crd_cellchr = crd[cellnum, chromstr==chrom, :, :, :]  # np.array((ndomain_chr, ncopy_max, nspot_max, 3), dtype=float)
        crd_cellchr_srt = np.copy(crd_cellchr)
        # loop copy in reverse order
        for cp in range(ncopy_max-1, -1, -1):
            crd_cellchr_cp = crd_cellchr[:, cp, :, :]
            if np.all(np.isnan(crd_cellchr_cp)):
                continue
            for cp2 in range(cp):
                crd_cellchr_cp2 = crd_cellchr[:, cp2, :, :]
                if not np.all(np.isnan(crd_cellchr_cp2)):
                    continue
                # swap the two copies
                crd_cellchr_srt[:, cp, :, :] = crd_cellchr[:, cp2, :, :]
                crd_cellchr_srt[:, cp2, :, :] = crd_cellchr[:, cp, :, :]

SyntaxError: invalid syntax (2017268016.py, line 4)

In [None]:
nchrom = len(chroms)
mask_nan = np.isnan(crd[:, :, :, :, 0])  # np.array((ncell, ndomain, ncopy_max, nspot_max), dtype=bool)
# perform all operations on the last axis (the spots axis)
mask_nan2 = np.all(mask_nan, axis=-1)  # np.array((ncell, ndomain, ncopy_max), dtype=bool)
mask_nan3 = np.zeros((ncell, nchrom, ncopy_max), dtype=bool)
for chromnum, chrom in enumerate(chroms):
    mask_nan3[:, chromnum, :] = np.all(mask_nan2[:, chromstr==chrom, :], axis=1)