diff --git a/dance/datasets/spatial.py b/dance/datasets/spatial.py index df369070..19d67b76 100644 --- a/dance/datasets/spatial.py +++ b/dance/datasets/spatial.py @@ -1,28 +1,15 @@ -import csv import glob import os import os.path as osp -import pickle as pkl -import random -import re -import time as tm import warnings -from collections import defaultdict -from operator import itemgetter import anndata import cv2 -import networkx as nx -import numpy as np import pandas as pd import rdata import scanpy as sc -import scipy.sparse -from anndata import AnnData -from scipy.stats import uniform from dance.data import download_file, download_unzip, unzip_file -from dance.transforms import preprocess IGNORED_FILES = ["readme.txt"] @@ -181,7 +168,6 @@ def is_complete(self): check = [self.data_dir + "/mix_count.*", self.data_dir + "/ref_sc_count.*"] for i in check: - #if not os.path.exists(i): if not glob.glob(i): print("lack {}".format(i)) return False @@ -195,8 +181,6 @@ def load_data(self): self.data = {} files = os.listdir(self.data_dir + "/") - filenames = [f.split(".")[0] for f in files] - extensions = [f.split(".")[1] for f in files] for f in files: DataPath = self.data_dir + "/" + f filename = f.split(".")[0] @@ -225,9 +209,9 @@ def __init__(self, data_id="GSE174746", data_dir="data/spatial", build_graph_fn= self.data_id = data_id self.data_dir = osp.join(data_dir, data_id) self.data_url = cellDeconvo_dataset[data_id] - self.load_data() + self._load_data() - def load_data(self): + def _load_data(self): if not osp.exists(self.data_dir): download_unzip(self.data_url, self.data_dir) @@ -244,6 +228,15 @@ def load_data(self): else: warnings.warn(f"Unsupported file type {ext!r}. Use csv or h5ad file types.") + def load_data(self): + ref_count = self.data["ref_sc_count"] + ref_annot = self.data["ref_sc_annot"] + count_matrix = self.data["mix_count"] + cell_type_portion = self.data["true_p"] + if (spatial := self.data.get("spatial_location")) is None: + spatial = pd.DataFrame(0, index=count_matrix.index, columns=["x", "y"]) + return ref_count, ref_annot, count_matrix, cell_type_portion, spatial + class CARDSimulationRDataset: ref_sc_count_url: str = "https://www.dropbox.com/s/wchoppxcsulk8ev/split2_ref_sc_count.h5ad?dl=1" diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 0529fba0..11df0e2b 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -12,7 +12,8 @@ import numpy as np import pandas as pd -from scipy.spatial.distance import pdist, squareform + +from dance.utils.matrix import pairwise_distance def obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2=None): @@ -118,10 +119,6 @@ class Card: Reference single cell RNA-seq counts data. sc_meta : pd.DataFrame Reference cell-type label information. - spatial_count : pd.DataFrame - Target spatial RNA-seq counts data to be deconvoluted. - spatial_location : pd.DataFrame, optional - Optional spatial location for the spatial transcriptomic reads. ct_varname : str, optional Name of the cell-types column. ct_select : str, optional @@ -141,12 +138,10 @@ class Card: """ - def __init__(self, sc_count, sc_meta, spatial_count, spatial_location=None, ct_varname=None, ct_select=None, - cell_varname=None, sample_varname=None, minCountGene=100, minCountSpot=5, basis=None, markers=None): + def __init__(self, sc_count, sc_meta, ct_varname=None, ct_select=None, cell_varname=None, sample_varname=None, + minCountGene=100, minCountSpot=5, basis=None, markers=None): self.sc_count = sc_count self.sc_meta = sc_meta - self.spatial_count = spatial_count - self.spatial_location = spatial_location self.ct_varname = ct_varname self.ct_select = ct_select self.cell_varname = cell_varname @@ -161,6 +156,15 @@ def __init__(self, sc_count, sc_meta, spatial_count, spatial_location=None, ct_v self.cell_varname = "auto_cell_index" self.sc_meta[self.cell_varname] = self.sc_meta.index + self.createscRef() # create basis + all_genes = sc_count.columns.tolist() + gene_to_idx = {j: i for i, j in enumerate(all_genes)} + not_mt_genes = [i for i in all_genes if not i.lower().startswith("mt-")] + selected_genes = self.selectInfo(not_mt_genes) + selected_gene_idx = list(map(gene_to_idx.get, selected_genes)) + self.gene_mask = np.zeros(len(all_genes), dtype=np.bool) + self.gene_mask[selected_gene_idx] = True + def createscRef(self): """CreatescRef - create reference basis matrix from reference scRNA-seq.""" countMat = self.sc_count.copy() # cell by gene matrix @@ -193,20 +197,19 @@ def createscRef(self): S = S_JK.mean(axis=0).to_frame().unstack().droplevel(0) S = S[sc_meta[ct_varname].unique()] countMat["ct_sample_id"] = ct_sample_id - Theta_S_colMean = countMat.groupby(ct_sample_id).agg("mean") + Theta_S_colMean = countMat.groupby(ct_sample_id).mean(numeric_only=True) tbl_sample = countMat.groupby([ct_sample_id]).size() tbl_sample = tbl_sample.reindex_like(Theta_S_colMean) tbl_sample = tbl_sample.reindex(Theta_S_colMean.index) - Theta_S_colSums = countMat.groupby(ct_sample_id).agg("sum") + Theta_S_colSums = countMat.groupby(ct_sample_id).sum(numeric_only=True) Theta_S = Theta_S_colSums.copy() Theta_S["sum"] = Theta_S_colSums.sum(axis=1) Theta_S = Theta_S[list(Theta_S.columns)[:-1]].div(Theta_S["sum"], axis=0) - Theta_S = Theta_S[list(Theta_S.columns)[:-1]] grp = [] for ind in Theta_S.index: grp.append(ind.split("$*$")[0]) Theta_S["grp"] = grp - Theta = Theta_S.groupby(grp).agg("mean") + Theta = Theta_S.groupby(grp).mean(numeric_only=True) Theta = Theta.reindex(sc_meta[ct_varname].unique()) S = S[Theta.index] Theta["S"] = S.iloc[0] @@ -258,68 +261,41 @@ def selectInfo(self, common_gene): genes = list(sd_within_colMean[genes_to_select].index) return genes - def fit(self, max_iter=100, epsilon=1e-4): + def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1): """Fit function for model training. Parameters ---------- + x : np.ndarray + Target spatial RNA-seq counts data to be deconvoluted. + spatial : np.ndarray + 2-D array of the spatial locations of each spot (spots x 2). If all zeros, then do not use spatial info. max_iter : int Maximum number of iterations for optimization. epsilon : float Optimization threshold. + sigma : float + Spatial gaussian kernel scaling factor. """ ct_select = self.ct_select - self.createscRef() # create basis Basis = self.basis.copy() Basis = Basis.loc[ct_select] - spatial_count = self.spatial_count.copy() - - # select common genes - tmp = list(set(list(spatial_count.columns)).intersection(list(Basis.columns))) - tmp.sort() - common_gene = list() - for gene in tmp: - if "mt-" not in gene: - common_gene.append(gene) - - # Select informative genes - common = self.selectInfo(common_gene) - Xinput = spatial_count.copy() - B = Basis.copy() - Xinput = Xinput[common] - B = B[common] - Xinput = Xinput.sort_index(axis="columns") - B = B.sort_index(axis="columns") - filter1 = Xinput.sum(axis="columns") > 0 - filter2 = Xinput.sum(axis="index") > 0 - Xinput = Xinput.loc[filter1] - Xinput = Xinput.loc[:, filter2] - B = B.loc[:, filter2] - # normalize count data - rowsumvec = Xinput.sum(axis=1) - Xinput_norm = Xinput.copy() - Xinput["rowsumvec"] = rowsumvec - Xinput_norm = Xinput[list(Xinput.columns)[:-1]].div(Xinput["rowsumvec"], axis=0) + + gene_mask = self.gene_mask & (x.sum(0) > 0) + B = Basis.values[:, gene_mask].copy() # TODO: make it a numpy array + Xinput = x[:, gene_mask].copy() + Xinput_norm = Xinput / Xinput.sum(1, keepdims=True) # TODO: use the normalize util # Spatial location - spatial_location = self.spatial_location - if spatial_location is None: + if (spatial == 0).all(): kernel_mat = None else: # TODO: refactor this to preprocess? - l1 = list(spatial_location.index) - l2 = list(spatial_count.index) - spatial_location = spatial_location.loc[[spot for spot in l1 if spot in l2]] - norm_cords = spatial_location[["x", "y"]] - norm_cords[["x"]] = norm_cords[["x"]] - norm_cords[["x"]].min() - norm_cords[["y"]] = norm_cords[["y"]] - norm_cords[["y"]].min() - scaleFactor = norm_cords.max().max() - norm_cords[["x"]] = norm_cords[["x"]].div(scaleFactor) - norm_cords[["y"]] = norm_cords[["y"]].div(scaleFactor) - ED = squareform(pdist(norm_cords, "euclidean")) - isigma = 0.1 - kernel_mat = np.exp(-ED**2 / (2 * isigma**2)) + norm_cords = (spatial - spatial.min(0)) # Q: why not min-max? + norm_cords /= norm_cords.max() + euclidean_distances = pairwise_distance(norm_cords.astype(np.float32), 0) + kernel_mat = np.exp(-euclidean_distances**2 / (2 * sigma**2)) np.fill_diagonal(kernel_mat, 0) # Initialize the proportion matrix @@ -327,21 +303,16 @@ def fit(self, max_iter=100, epsilon=1e-4): Vint1 = rng.dirichlet(np.repeat(10, B.shape[0], axis=0), Xinput_norm.shape[0]) phi = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] # scale the Xinput_norm and B to speed up the convergence. - mean_X = Xinput_norm.values.mean() - mean_B = B.values.mean() - Xinput_norm = Xinput_norm * 0.1 / mean_X - B = B * 0.1 / mean_B - # print("B", B.values.shape, mean_B.shape) + Xinput_norm = Xinput_norm * 0.1 / Xinput_norm.sum() + B = B * 0.1 / B.mean() # Optimization ResList = {} Obj = np.array([]) for iphi in range(len(phi)): - res = CARDref(Xinput=Xinput_norm.T.values, U=B.T.values, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, - epsilon=epsilon, V=Vint1, b=np.repeat(0, B.T.shape[1]).reshape(B.T.shape[1], 1), sigma_e2=0.1, + res = CARDref(Xinput=Xinput_norm.T, U=B.T, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon, + V=Vint1, b=np.repeat(0, B.T.shape[1]).reshape(B.T.shape[1], 1), sigma_e2=0.1, Lambda=np.repeat(10, len(ct_select))) - # rownames(res$V) = colnames(Xinput_norm) - # colnames(res$V) = colnames(B) ResList[str(iphi)] = res Obj = np.append(Obj, res["Obj"]) self.Obj_hist = Obj @@ -351,13 +322,8 @@ def fit(self, max_iter=100, epsilon=1e-4): print("## Deconvolution Finish! ...\n") self.info_parameters["phi"] = OptimalPhi - self.algorithm_matrix = { - "B": (B.values * mean_B) / 1e-01, - "Xinput_norm": (Xinput_norm.values * mean_X) / 1e-01, - "Res": OptimalRes - } + self.algorithm_matrix = {"B": B, "Xinput_norm": Xinput_norm, "Res": OptimalRes} - self.spatial_location = spatial_location return self def predict(self): @@ -373,6 +339,10 @@ def predict(self): prop_pred = optim_res["V"] / optim_res["V"].sum(axis=1, keepdims=True) return prop_pred + def fit_and_predict(self, x, spatial, max_iter=100, epsilon=1e-4): + self.fit(x, spatial, max_iter=max_iter, epsilon=epsilon) + return self.predict() + @staticmethod def score(x, y): """Model performance score measured by mean square error. diff --git a/examples/spatial/cell_type_deconvo/card.py b/examples/spatial/cell_type_deconvo/card.py index f114a446..304fdee5 100644 --- a/examples/spatial/cell_type_deconvo/card.py +++ b/examples/spatial/cell_type_deconvo/card.py @@ -1,6 +1,10 @@ import argparse from pprint import pprint +import numpy as np +from anndata import AnnData + +from dance.data import Data from dance.datasets.spatial import CellTypeDeconvoDatasetLite from dance.modules.spatial.cell_type_deconvo.card import Card @@ -18,33 +22,30 @@ # Load dataset dataset = CellTypeDeconvoDatasetLite(data_id=args.dataset, data_dir=args.datadir) -sc_count = dataset.data["ref_sc_count"] -sc_meta = dataset.data["ref_sc_annot"] -spatial_count = dataset.data["mix_count"] -true_p = dataset.data["true_p"] -spatial_location = None if args.location_free else dataset.data["spatial_location"] -ct_select = sorted(set(sc_meta.cellType.unique().tolist()) & set(true_p.columns.tolist())) + +ref_count, ref_annot, count_matrix, cell_type_portion, spatial = dataset.load_data() + +# TODO: add ref index (or more flexible indexing option at init, e.g., as dict?) and combine with data +ref_adata = AnnData(X=ref_count, obsm={"annot": ref_annot}, dtype=np.float32) +adata = AnnData(X=count_matrix, obsm={"spatial": spatial, "cell_type_portion": cell_type_portion}, dtype=np.float32) + +# TODO: deprecate the need for ct_select by doing this in a preprocessing step -> convert ct into one-hot matrix +ct_select = sorted(set(ref_annot.cellType.unique().tolist()) & set(cell_type_portion.columns.tolist())) print(f"{ct_select=}") -# Initialize and train moel -crd = Card( - sc_count=sc_count, - sc_meta=sc_meta, - spatial_count=spatial_count, - spatial_location=spatial_location, - ct_varname="cellType", - ct_select=ct_select, - cell_varname=None, - sample_varname=None, -) -crd.fit(max_iter=args.max_iter, epsilon=args.epsilon) - -# Evaluate -pred = crd.predict() -mse = crd.score(pred, true_p[ct_select].values) +data = Data(adata) +data.set_config(feature_channel=[None, "spatial"], feature_channel_type=[None, "obsm"], + label_channel="cell_type_portion") + +# TODO: after removing ct_select, return as numpy +(x_count, x_spatial), y = data.get_data(return_type="default") + +model = Card(ref_count, ref_annot, ct_varname="cellType", ct_select=ct_select) +pred = model.fit_and_predict(x_count, x_spatial.values, max_iter=args.max_iter, epsilon=args.epsilon) +mse = model.score(pred, y[ct_select].values) print(f"Predicted cell-type proportions of sample 1: {pred[0].round(3)}") -print(f"True cell-type proportions of sample 1: {true_p[ct_select].iloc[0].tolist()}") +print(f"True cell-type proportions of sample 1: {y.iloc[0].tolist()}") print(f"mse = {mse:7.4f}") """To reproduce CARD benchmarks, please refer to command lines belows: