Skip to content

Commit

Permalink
Parallelised and improved run speed of cell type information permutat…
Browse files Browse the repository at this point in the history
…ion.
  • Loading branch information
BradBalderson committed Nov 21, 2022
1 parent 1c132d7 commit 8b9348c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 23 deletions.
17 changes: 12 additions & 5 deletions stlearn/tools/microenv/cci/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,9 @@ def run_cci(
min_spots: int = 3,
sig_spots: bool = True,
cell_prop_cutoff: float = 0.2,
p_cutoff=0.05,
n_perms=100,
p_cutoff: float=0.05,
n_perms: int=100,
n_cpus: int=1,
verbose: bool = True,
):
"""Calls significant celltype-celltype interactions based on cell-type data randomisation.
Expand Down Expand Up @@ -531,6 +532,8 @@ def run_cci(
raw counting of the cell type interactions with each LR hotspot. This
can still be visualised downstream by setting parameters to plot
significant interactions to false.
n_cpus: int
cpu resources to use.
verbose: bool
True if print dialogue to user during run-time.
Returns
Expand Down Expand Up @@ -565,6 +568,10 @@ def run_cci(
The same as f"per_lr_cci_raw_{use_label}", except
subsetted to significant CCIs.
"""
# Setting threads for parallelisation #
if type(n_cpus) != type(None):
numba.set_num_threads(n_cpus)

ran_lr = "lr_summary" in adata.uns
ran_sig = False if not ran_lr else "n_spots_sig" in adata.uns["lr_summary"].columns
if not ran_lr and not ran_sig:
Expand Down Expand Up @@ -657,7 +664,7 @@ def run_cci(
lr_n_spot_cci_sig = np.zeros((lr_summary.shape[0]))
lr_n_cci_sig = np.zeros((lr_summary.shape[0]))
with tqdm(
total=best_lrs,
total=len(best_lrs),
desc=f"Counting celltype-celltype interactions per LR and permutating {n_perms} times.",
bar_format="{l_bar}{bar} [ time left: {remaining} ]",
disable=verbose == False,
Expand Down Expand Up @@ -733,9 +740,9 @@ def run_cci(
if verbose:
print(
f"Significant counts of cci_rank interactions for all LR pairs in "
f"{f'lr_cci_{use_label}'}"
f"{f'data.uns[lr_cci_{use_label}]'}"
)
print(
f"Significant counts of cci_rank interactions for each LR pair "
f"stored in dictionary {f'per_lr_cci_{use_label}'}"
f"stored in dictionary {f'data.uns[per_lr_cci_{use_label}]'}"
)
25 changes: 16 additions & 9 deletions stlearn/tools/microenv/cci/het.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def count_interactions(

return int_matrix if trans_dir else int_matrix.transpose()


@jit(parallel=True, nopython=False)

This comment has been minimized.

Copy link
@sklam

sklam Oct 17, 2023

Numba is in the middle of removing object-mode fallback, and we are looking for code that relies on it by setting nopython=False instead of forcing object-mode use with forceobj=True. If it's convenient, can you reply to numba/numba#9247 explaining why this use of nopython=False is needed?

In addition, parallel=True will have no effect when Numba uses object-mode to compile this function. If setting nopython=True causes an error, it's likely a bug or a missing feature in Numba.

def get_interaction_pvals(
int_matrix,
n_perms,
Expand All @@ -207,8 +207,14 @@ def get_interaction_pvals(

# Counting how many times permutation of spots cell data creates interaction
# counts greater than that observed, in order to calculate p-values.
greater_counts = np.zeros(int_matrix.shape).astype(int)
indices = np.array([i for i in range(cell_data.shape[0])])
shape_ = (n_perms, int_matrix.shape[0], int_matrix.shape[1])
# Storing the instances where the count is greater randomly for each perm.
# Allows for embarrassing parallelisation.
greater_counts = np.zeros(shape_, dtype=np.int64)
indices = np.zeros((cell_data.shape[0]), dtype=np.int64)
for i in range(cell_data.shape[0]):
indices[i] = i

# If dealing with discrete data, no need to randomise columns independently #
discrete = np.all(np.logical_or(cell_data == 0, cell_data == 1))
for i in range(n_perms):
Expand All @@ -232,15 +238,16 @@ def get_interaction_pvals(
L_bool,
R_bool,
cell_prop_cutoff,
).astype(int)
perm_greater = (perm_matrix >= int_matrix).astype(int)
greater_counts += perm_greater
)
#perm_greater = (perm_matrix >= int_matrix).astype(int)
perm_greater = perm_matrix >= int_matrix
greater_counts[i,:,:] = perm_greater

# Calculating the pvalues #
int_pvals = greater_counts / n_perms
total_greater_counts = greater_counts.sum(axis=0) # cts * ct counts
int_pvals = total_greater_counts / n_perms
return int_pvals


@njit
def get_interaction_matrix(
cell_data,
Expand Down Expand Up @@ -283,7 +290,7 @@ def get_interaction_matrix(

# Counting the number of unique interacting edges
# between different cell types via the indicated LR
int_matrix = np.zeros((all_set.shape[0], all_set.shape[0]))
int_matrix = np.zeros((all_set.shape[0], all_set.shape[0]), dtype=np.int64)
edge_i = 0
for i in range(all_set.shape[0]):
for j in range(all_set.shape[0]):
Expand Down
14 changes: 5 additions & 9 deletions stlearn/tools/microenv/cci/het_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def get_data_for_counting(adata, use_label, mix_mode, all_set):
obs_key, uns_key = use_label, use_label

# Getting the neighbourhoods #
neighbours, neighbourhood_bcs, neighbourhood_indices = get_neighbourhoods(
adata)
#neighbours, neighbourhood_bcs, neighbourhood_indices = get_neighbourhoods(
# adata)

# Getting the cell type information; if not mixtures then populate
# matrix with one's indicating pure spots.
Expand All @@ -221,7 +221,7 @@ def get_data_for_counting(adata, use_label, mix_mode, all_set):
)

spot_bcs = adata.obs_names.values.astype(str)
return spot_bcs, cell_data, neighbourhood_bcs, neighbourhood_indices
return spot_bcs, cell_data, #neighbourhood_bcs, neighbourhood_indices

def get_data_for_counting_OLD(adata, use_label, mix_mode, all_set):
"""Retrieves the minimal information necessary to perform edge counting."""
Expand Down Expand Up @@ -255,7 +255,7 @@ def get_data_for_counting_OLD(adata, use_label, mix_mode, all_set):
spot_bcs = adata.obs_names.values.astype(str)
return spot_bcs, cell_data, neighbourhood_bcs, neighbourhood_indices

@njit
#@njit
def get_neighbourhoods_FAST(spot_bcs: np.array, spot_neigh_bcs: np.ndarray,
n_spots: int, str_dtype: str,
neigh_indices: np.array, neigh_bcs: np.array):
Expand All @@ -280,8 +280,7 @@ def get_neighbourhoods_FAST(spot_bcs: np.array, spot_neigh_bcs: np.ndarray,
neigh_bcs_sub.append( neigh_bc )

#neigh_bcs_array = np.empty((len(neigh_bcs_sub)), str_dtype)
neigh_bcs_array = np.empty(len(neigh_bcs_sub),
dtype=neigh_bcs_sub._dtype)
neigh_bcs_array = np.empty(len(neigh_bcs_sub), dtype=str_dtype)
neigh_indices = np.zeros((len(neigh_bcs_sub)), dtype=np.int64)
for j, neigh_bc in enumerate(neigh_bcs_sub):
neigh_bcs_array[j] = neigh_bc
Expand Down Expand Up @@ -340,9 +339,6 @@ def get_neighbourhoods(adata):
neigh_indices = np.zeros((n_spots), dtype=np.int64)
neigh_bcs = np.empty((n_spots), dtype=str_dtype)

print(type(neigh_indices))
print(type(neigh_bcs))

return get_neighbourhoods_FAST(spot_bcs, spot_neigh_bcs,
n_spots, str_dtype,
neigh_indices, neigh_bcs)
Expand Down

0 comments on commit 8b9348c

Please sign in to comment.