In [56]:
from scipy.spatial import distance
import pandas as pd
import numba
from numba import njit, prange
import numpy as np
import dask_distance as dd
import dask.array as da
from random import choices
import time
from multiprocessing import Pool
import sparse

In [57]:
data = pd.read_csv("/media/austin/IPI_8plex_project/8plex/analysis/cell_data_tables/IPICRC058T1_8plex/IPICRC058T1_8plex_single_cell_data_gated_tumorannotated_20220131.csv")
in_tumor = 0 # 0 if analyzing stroma, 1 if analyzing tumor
data = data[data['in_tumor']==in_tumor]
ncells=len(data)

cell_types = list(set(data.cell_type))
try:
    cell_types.remove('other')
except KeyError:
    pass
ncell_types = len(cell_types)

bootstrap_num = 1000
close_num_rand = np.zeros((ncell_types, ncell_types, bootstrap_num))
distance_threshold= 100

In [58]:
centroid_coords = np.stack([data[data['in_tumor']==in_tumor]['centroid-0'],
                            data[data['in_tumor']==in_tumor]['centroid-1']]).astype(np.float32).transpose()

In [59]:
dcentroid_coords = da.from_array(centroid_coords, chunks=(20000,20000))
dist_mat = dd.cdist(dcentroid_coords, dcentroid_coords, metric="euclidean").astype(np.float32)<distance_threshold

  result = blockwise(


In [60]:
# build index dictionary so that truncated distance matrix only needs to be calculated once for all cell-cell pairs

@njit(parallel=True)
def sample_inds(ncells, n_ct1, n_ct2, bootstrap_num):
    ct1_rand_inds = np.zeros((n_ct1,bootstrap_num))
    ct2_rand_inds = np.zeros((n_ct2,bootstrap_num))
        
    for r in prange(bootstrap_num):
        ct1_rand_inds[:,r] = np.sort(np.random.choice(np.arange(ncells), size=n_ct1, replace=True))
        ct2_rand_inds[:,r] = np.sort(np.random.choice(np.arange(ncells), size=n_ct2, replace=True))
    return [ct1_rand_inds, ct2_rand_inds]

index_dictionary = {}
rand_dictionary = {}

for i in np.arange(ncell_types):
    
    ct1_pos_inds = (data['cell_type']==cell_types[i])
    n_ct1 = sum(ct1_pos_inds)

    for j in np.arange(ncell_types):
        
        ct2_pos_inds = (data['cell_type']==cell_types[j])
        n_ct2 = sum(ct2_pos_inds)

        index_dictionary[(i,j)] = [np.array(ct1_pos_inds), np.array(ct2_pos_inds)]
        
        rand_dictionary[i,j] = sample_inds(ncells, n_ct1, n_ct2, bootstrap_num)

In [61]:
def count_close_interactions(dist_mat, index_dictionary, rand_dictionary, ncell_types, bootstrap_num, threshold):
    
    close_num = np.zeros((ncell_types, ncell_types), dtype=np.uint64)
    close_rand = np.zeros((ncell_types, ncell_types, bootstrap_num), dtype=np.uint64)
    
    chunkx, chunky = dist_mat.chunksize
    max_x, max_y = dist_mat.shape
    
    rangex = np.arange(0,max_x-chunkx,chunkx)
    if rangex[-1] != max_x-1:
        np.append(rangex,max_x-1)
        
    rangey = np.arange(0,max_y-chunky,chunky)
    if rangey[-1] != max_y-1:
        np.append(rangey,max_y-1)
        
    
    for i in np.arange(len(rangex)-1):
        for j in np.arange(len(rangey)-1):
            
            start = time.time()
            
            trunc_mat = dist_mat[rangex[i]:rangex[i+1], rangey[i]:rangey[i+1]].compute()
            for ct1_i, ct2_i in index_dictionary.keys():
                
                ct1_idx, ct2_idx = index_dictionary[ct1_i,ct2_i]
                
                ct1=ct1_idx[rangex[i]:rangex[i+1]]
                ct2=ct2_idx[rangey[j]:rangey[j+1]]
                
                grid = np.ix_(ct1, ct2)
                close_num[ct1_i, ct2_i] += np.sum(trunc_mat[grid])
                                
                rand_ct1_idx, rand_ct2_idx = rand_dictionary[ct1_i, ct2_i]
                
                for r in np.arange(bootstrap_num):
                    ct1_idx_r = rand_ct1_idx[:,r]
                    ct2_idx_r = rand_ct2_idx[:,r]
                    rand_ct1 = ct1_idx_r[np.logical_and(ct1_idx_r>rangex[i], ct1_idx_r<rangex[i+1])]-rangex[i]
                    rand_ct2 = ct2_idx_r[np.logical_and(ct2_idx_r>rangey[j], ct2_idx_r<rangey[j+1])]-rangey[j]
    
                    close_rand[ct1_i, ct2_i, r] += np.sum(numba_ix(trunc_mat, rand_ct1, rand_ct2))
            end = time.time()
            
            print(i, j, (end-start)/60)
    return close_num, close_rand

In [62]:
@njit(parallel=True)
def numba_ix(arr, rows, cols):
    """
    Numba compatible implementation of arr[np.ix_(rows, cols)] for 2D arrays.
    :param arr: 2D array to be indexed
    :param rows: Row indices
    :param cols: Column indices
    :return: 2D array with the given rows and columns of the input array
    """
    
    one_d_index = np.zeros(len(rows) * len(cols), dtype=np.int32)
    for i, r in enumerate(rows):
        start = i * len(cols)
        one_d_index[start: start + len(cols)] = cols + arr.shape[1] * r

    arr_1d = arr.reshape((arr.shape[0] * arr.shape[1], 1))
    slice_1d = np.take(arr_1d, one_d_index)
    return slice_1d.reshape((len(rows), len(cols)))


In [63]:
start = time.time()
close_num, close_rand = count_close_interactions(dist_mat, index_dictionary, rand_dictionary, ncell_types, bootstrap_num, threshold=distance_threshold)
end = time.time()
print('Computation took {time_elapse} minutes'.format(time_elapse=(end-start)/60))

0 0 142.08189049164454


IndexError: Index out of bounds

In [64]:
2*(18*18)

648

In [None]:
z = zeros(ncell_types)
muhat = zeros(ncell_types)
sigmahat = zeros(ncell_types)
p = zeros((ncell_types, ncell_types, 2))

In [173]:
for j in np.arange(celltypeNum)
    for k in np.arange(celltypeNum)

        tmp= reshape(closeNumRand(j,k,:),BootstrapNum,1)
        [muhat(j,k),sigmahat(j,k)] = normfit(tmp)
        z(j,k) = (closeNum(j,k)-muhat(j,k))/sigmahat(j,k)
        p(j,k,1) = (1+(sum(tmp>=closeNum(j,k))))/(BootstrapNum+1)
        p(j,k,2) = (1+(sum(tmp<=closeNum(j,k))))/(BootstrapNum+1)


In [41]:
dist_mat

Unnamed: 0,Array,Chunk
Bytes,597.57 GiB,95.37 MiB
Shape,"(400511, 400511)","(5000, 5000)"
Count,46332 Tasks,6561 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 597.57 GiB 95.37 MiB Shape (400511, 400511) (5000, 5000) Count 46332 Tasks 6561 Chunks Type float32 numpy.ndarray",400511  400511,

Unnamed: 0,Array,Chunk
Bytes,597.57 GiB,95.37 MiB
Shape,"(400511, 400511)","(5000, 5000)"
Count,46332 Tasks,6561 Chunks
Type,float32,numpy.ndarray
