Commit 0d26d29

update card cell-type deconvolution example script using dance data object (#93)

* remove unused imports and lines

* use load_data to expose raw data

* refactor card model: take x and spatial as numpy arrays when calling fit

* implement shortcut function fit_and_predict

* update card to use dance data object

* fix numeric_only agg
RemyLau committed Dec 20, 2022
1 parent d03fb90 commit 0d26d29
Showing 3 changed files with 77 additions and 113 deletions.
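
At a glance, the interface change described by the commit message bullets: the Card model no longer binds the target spatial data at construction; instead, fit and the new fit_and_predict shortcut take the spot-by-gene counts and spot coordinates as numpy arrays. A condensed sketch of the new call pattern, mirroring the updated example script at the bottom of this diff (and, like that script, assuming the mixture and reference data share gene naming):

import numpy as np

from dance.datasets.spatial import CellTypeDeconvoDatasetLite
from dance.modules.spatial.cell_type_deconvo.card import Card

# load_data now exposes the raw data frames directly
dataset = CellTypeDeconvoDatasetLite(data_id="GSE174746", data_dir="data/spatial")
ref_count, ref_annot, count_matrix, cell_type_portion, spatial = dataset.load_data()
ct_select = sorted(set(ref_annot.cellType.unique()) & set(cell_type_portion.columns))

# Reference-only construction; the target data goes to fit_and_predict as arrays
model = Card(ref_count, ref_annot, ct_varname="cellType", ct_select=ct_select)
pred = model.fit_and_predict(count_matrix.values.astype(np.float32),
                             spatial.values.astype(np.float32))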
29 changes: 11 additions & 18 deletions dance/datasets/spatial.py
@@ -1,28 +1,15 @@
import csv
import glob
import os
import os.path as osp
import pickle as pkl
import random
import re
import time as tm
import warnings
from collections import defaultdict
from operator import itemgetter

import anndata
import cv2
import networkx as nx
import numpy as np
import pandas as pd
import rdata
import scanpy as sc
import scipy.sparse
from anndata import AnnData
from scipy.stats import uniform

from dance.data import download_file, download_unzip, unzip_file
from dance.transforms import preprocess

IGNORED_FILES = ["readme.txt"]

@@ -181,7 +168,6 @@ def is_complete(self):
check = [self.data_dir + "/mix_count.*", self.data_dir + "/ref_sc_count.*"]

for i in check:
#if not os.path.exists(i):
if not glob.glob(i):
print("lack {}".format(i))
return False
@@ -195,8 +181,6 @@ def load_data(self):

self.data = {}
files = os.listdir(self.data_dir + "/")
filenames = [f.split(".")[0] for f in files]
extensions = [f.split(".")[1] for f in files]
for f in files:
DataPath = self.data_dir + "/" + f
filename = f.split(".")[0]
@@ -225,9 +209,9 @@ def __init__(self, data_id="GSE174746", data_dir="data/spatial", build_graph_fn=
self.data_id = data_id
self.data_dir = osp.join(data_dir, data_id)
self.data_url = cellDeconvo_dataset[data_id]
self.load_data()
self._load_data()

def load_data(self):
def _load_data(self):
if not osp.exists(self.data_dir):
download_unzip(self.data_url, self.data_dir)

@@ -244,6 +228,15 @@ def load_data(self):
else:
warnings.warn(f"Unsupported file type {ext!r}. Use csv or h5ad file types.")

def load_data(self):
ref_count = self.data["ref_sc_count"]
ref_annot = self.data["ref_sc_annot"]
count_matrix = self.data["mix_count"]
cell_type_portion = self.data["true_p"]
if (spatial := self.data.get("spatial_location")) is None:
spatial = pd.DataFrame(0, index=count_matrix.index, columns=["x", "y"])
return ref_count, ref_annot, count_matrix, cell_type_portion, spatial
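
A detail worth flagging in the new load_data: when a dataset ships no spatial_location table, an all-zero coordinate frame is returned, which fit later recognizes via (spatial == 0).all() as the cue to skip the spatial kernel. A self-contained sketch of that fallback pattern:

import pandas as pd

data = {"mix_count": pd.DataFrame([[1, 2], [3, 4]], index=["spot1", "spot2"])}
count_matrix = data["mix_count"]
# The walrus operator binds the lookup result and tests it for None in one step
if (spatial := data.get("spatial_location")) is None:
    spatial = pd.DataFrame(0, index=count_matrix.index, columns=["x", "y"])
print((spatial.values == 0).all())  # True -> fit() will use kernel_mat = None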


class CARDSimulationRDataset:
ref_sc_count_url: str = "https://www.dropbox.com/s/wchoppxcsulk8ev/split2_ref_sc_count.h5ad?dl=1"
114 changes: 42 additions & 72 deletions dance/modules/spatial/cell_type_deconvo/card.py
@@ -12,7 +12,8 @@

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform

from dance.utils.matrix import pairwise_distance


def obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2=None):
@@ -118,10 +119,6 @@ class Card:
Reference single cell RNA-seq counts data.
sc_meta : pd.DataFrame
Reference cell-type label information.
spatial_count : pd.DataFrame
Target spatial RNA-seq counts data to be deconvoluted.
spatial_location : pd.DataFrame, optional
Optional spatial location for the spatial transcriptomic reads.
ct_varname : str, optional
Name of the cell-types column.
ct_select : str, optional
@@ -141,12 +138,10 @@
"""

def __init__(self, sc_count, sc_meta, spatial_count, spatial_location=None, ct_varname=None, ct_select=None,
cell_varname=None, sample_varname=None, minCountGene=100, minCountSpot=5, basis=None, markers=None):
def __init__(self, sc_count, sc_meta, ct_varname=None, ct_select=None, cell_varname=None, sample_varname=None,
minCountGene=100, minCountSpot=5, basis=None, markers=None):
self.sc_count = sc_count
self.sc_meta = sc_meta
self.spatial_count = spatial_count
self.spatial_location = spatial_location
self.ct_varname = ct_varname
self.ct_select = ct_select
self.cell_varname = cell_varname
Expand All @@ -161,6 +156,15 @@ def __init__(self, sc_count, sc_meta, spatial_count, spatial_location=None, ct_v
self.cell_varname = "auto_cell_index"
self.sc_meta[self.cell_varname] = self.sc_meta.index

self.createscRef() # create basis
all_genes = sc_count.columns.tolist()
gene_to_idx = {j: i for i, j in enumerate(all_genes)}
not_mt_genes = [i for i in all_genes if not i.lower().startswith("mt-")]
selected_genes = self.selectInfo(not_mt_genes)
selected_gene_idx = list(map(gene_to_idx.get, selected_genes))
self.gene_mask = np.zeros(len(all_genes), dtype=bool)
self.gene_mask[selected_gene_idx] = True
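
The lines above precompute a boolean gene mask at construction time: the reference basis is built once (createscRef), mitochondrial genes (those with an "mt-" prefix) are excluded, informative genes are picked by selectInfo, and the selection is stored as a mask over the reference gene order so that fit can cheaply intersect it with a per-dataset expression filter. A standalone numpy sketch of the same masking logic, with made-up gene names:

import numpy as np

all_genes = ["mt-nd1", "gene_a", "gene_b", "gene_c"]
selected_genes = ["gene_a", "gene_c"]  # stand-in for selectInfo output
gene_to_idx = {g: i for i, g in enumerate(all_genes)}
gene_mask = np.zeros(len(all_genes), dtype=bool)
gene_mask[[gene_to_idx[g] for g in selected_genes]] = True

x = np.array([[0, 5, 0, 2], [0, 3, 0, 1]])  # spots x genes
final_mask = gene_mask & (x.sum(0) > 0)  # as in fit: also require nonzero counts
print(final_mask)  # [False  True False  True]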

def createscRef(self):
"""CreatescRef - create reference basis matrix from reference scRNA-seq."""
countMat = self.sc_count.copy() # cell by gene matrix
@@ -193,20 +197,19 @@ def createscRef(self):
S = S_JK.mean(axis=0).to_frame().unstack().droplevel(0)
S = S[sc_meta[ct_varname].unique()]
countMat["ct_sample_id"] = ct_sample_id
Theta_S_colMean = countMat.groupby(ct_sample_id).agg("mean")
Theta_S_colMean = countMat.groupby(ct_sample_id).mean(numeric_only=True)
tbl_sample = countMat.groupby([ct_sample_id]).size()
tbl_sample = tbl_sample.reindex_like(Theta_S_colMean)
tbl_sample = tbl_sample.reindex(Theta_S_colMean.index)
Theta_S_colSums = countMat.groupby(ct_sample_id).agg("sum")
Theta_S_colSums = countMat.groupby(ct_sample_id).sum(numeric_only=True)
Theta_S = Theta_S_colSums.copy()
Theta_S["sum"] = Theta_S_colSums.sum(axis=1)
Theta_S = Theta_S[list(Theta_S.columns)[:-1]].div(Theta_S["sum"], axis=0)
Theta_S = Theta_S[list(Theta_S.columns)[:-1]]
grp = []
for ind in Theta_S.index:
grp.append(ind.split("$*$")[0])
Theta_S["grp"] = grp
Theta = Theta_S.groupby(grp).agg("mean")
Theta = Theta_S.groupby(grp).mean(numeric_only=True)
Theta = Theta.reindex(sc_meta[ct_varname].unique())
S = S[Theta.index]
Theta["S"] = S.iloc[0]
@@ -258,90 +261,58 @@ def selectInfo(self, common_gene):
genes = list(sd_within_colMean[genes_to_select].index)
return genes

def fit(self, max_iter=100, epsilon=1e-4):
def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1):
"""Fit function for model training.
Parameters
----------
x : np.ndarray
Target spatial RNA-seq counts data to be deconvoluted.
spatial : np.ndarray
2-D array of the spatial locations of each spot (spots x 2). If all zeros, then do not use spatial info.
max_iter : int
Maximum number of iterations for optimization.
epsilon : float
Optimization threshold.
sigma : float
Spatial gaussian kernel scaling factor.
"""
ct_select = self.ct_select
self.createscRef() # create basis
Basis = self.basis.copy()
Basis = Basis.loc[ct_select]
spatial_count = self.spatial_count.copy()

# select common genes
tmp = list(set(list(spatial_count.columns)).intersection(list(Basis.columns)))
tmp.sort()
common_gene = list()
for gene in tmp:
if "mt-" not in gene:
common_gene.append(gene)

# Select informative genes
common = self.selectInfo(common_gene)
Xinput = spatial_count.copy()
B = Basis.copy()
Xinput = Xinput[common]
B = B[common]
Xinput = Xinput.sort_index(axis="columns")
B = B.sort_index(axis="columns")
filter1 = Xinput.sum(axis="columns") > 0
filter2 = Xinput.sum(axis="index") > 0
Xinput = Xinput.loc[filter1]
Xinput = Xinput.loc[:, filter2]
B = B.loc[:, filter2]
# normalize count data
rowsumvec = Xinput.sum(axis=1)
Xinput_norm = Xinput.copy()
Xinput["rowsumvec"] = rowsumvec
Xinput_norm = Xinput[list(Xinput.columns)[:-1]].div(Xinput["rowsumvec"], axis=0)

gene_mask = self.gene_mask & (x.sum(0) > 0)
B = Basis.values[:, gene_mask].copy() # TODO: make it a numpy array
Xinput = x[:, gene_mask].copy()
Xinput_norm = Xinput / Xinput.sum(1, keepdims=True) # TODO: use the normalize util

# Spatial location
spatial_location = self.spatial_location
if spatial_location is None:
if (spatial == 0).all():
kernel_mat = None
else:
# TODO: refactor this to preprocess?
l1 = list(spatial_location.index)
l2 = list(spatial_count.index)
spatial_location = spatial_location.loc[[spot for spot in l1 if spot in l2]]
norm_cords = spatial_location[["x", "y"]]
norm_cords[["x"]] = norm_cords[["x"]] - norm_cords[["x"]].min()
norm_cords[["y"]] = norm_cords[["y"]] - norm_cords[["y"]].min()
scaleFactor = norm_cords.max().max()
norm_cords[["x"]] = norm_cords[["x"]].div(scaleFactor)
norm_cords[["y"]] = norm_cords[["y"]].div(scaleFactor)
ED = squareform(pdist(norm_cords, "euclidean"))
isigma = 0.1
kernel_mat = np.exp(-ED**2 / (2 * isigma**2))
norm_cords = (spatial - spatial.min(0)) # Q: why not min-max?
norm_cords /= norm_cords.max()
euclidean_distances = pairwise_distance(norm_cords.astype(np.float32), 0)
kernel_mat = np.exp(-euclidean_distances**2 / (2 * sigma**2))
np.fill_diagonal(kernel_mat, 0)
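
The kernel logic itself is preserved from the old code path: coordinates are shifted to the origin per axis, scaled by a single global factor, converted to a Gaussian affinity, and the diagonal is zeroed. The differences are that the hard-coded isigma = 0.1 became the sigma argument, and the distance computation moved from scipy to dance.utils.matrix.pairwise_distance (assumed here to be an equivalent pairwise Euclidean distance). A sketch using the scipy routines the old code relied on:

import numpy as np
from scipy.spatial.distance import pdist, squareform

spatial = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # spots x 2
sigma = 0.1
norm_cords = spatial - spatial.min(0)  # shift each axis to start at zero
norm_cords /= norm_cords.max()  # one global scale factor, not per-axis
dist = squareform(pdist(norm_cords, "euclidean"))  # spots x spots distances
kernel_mat = np.exp(-dist**2 / (2 * sigma**2))  # Gaussian affinity
np.fill_diagonal(kernel_mat, 0)  # drop self-affinity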

# Initialize the proportion matrix
rng = np.random.default_rng(20200107)
Vint1 = rng.dirichlet(np.repeat(10, B.shape[0], axis=0), Xinput_norm.shape[0])
phi = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
# scale the Xinput_norm and B to speed up the convergence.
mean_X = Xinput_norm.values.mean()
mean_B = B.values.mean()
Xinput_norm = Xinput_norm * 0.1 / mean_X
B = B * 0.1 / mean_B
# print("B", B.values.shape, mean_B.shape)
Xinput_norm = Xinput_norm * 0.1 / Xinput_norm.sum()

RemyLau (Author, Mar 20, 2023), commenting on the Xinput_norm line above:

Xinput_norm here is inconsistent with the original implementation, which should be divided by the mean instead of the sum, as pointed out by @jdevenegas. Although there is no difference in the final predictions.

B = B * 0.1 / B.mean()
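
The inline comment above can be made concrete with a small numpy check: dividing by the sum rather than the mean only changes the input by the constant factor X.size, and since predict row-normalizes the fitted proportion matrix, a global rescaling washes out of the final predictions (assuming, as the comment implies, that the fitted V inherits the input's global scale):

import numpy as np

X = np.random.default_rng(0).random((4, 6))
by_sum = X * 0.1 / X.sum()
by_mean = X * 0.1 / X.mean()
print(np.allclose(by_mean, by_sum * X.size))  # True: constant-factor difference

V = np.random.default_rng(1).random((4, 3))  # stand-in fitted proportions
prop = V / V.sum(axis=1, keepdims=True)
prop_scaled = (123.0 * V) / (123.0 * V).sum(axis=1, keepdims=True)
print(np.allclose(prop, prop_scaled))  # True: row-normalization cancels scale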

# Optimization
ResList = {}
Obj = np.array([])
for iphi in range(len(phi)):
res = CARDref(Xinput=Xinput_norm.T.values, U=B.T.values, W=kernel_mat, phi=phi[iphi], max_iter=max_iter,
epsilon=epsilon, V=Vint1, b=np.repeat(0, B.T.shape[1]).reshape(B.T.shape[1], 1), sigma_e2=0.1,
res = CARDref(Xinput=Xinput_norm.T, U=B.T, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon,
V=Vint1, b=np.repeat(0, B.T.shape[1]).reshape(B.T.shape[1], 1), sigma_e2=0.1,
Lambda=np.repeat(10, len(ct_select)))
# rownames(res$V) = colnames(Xinput_norm)
# colnames(res$V) = colnames(B)
ResList[str(iphi)] = res
Obj = np.append(Obj, res["Obj"])
self.Obj_hist = Obj
@@ -351,13 +322,8 @@ def fit(self, max_iter=100, epsilon=1e-4):
print("## Deconvolution Finish! ...\n")

self.info_parameters["phi"] = OptimalPhi
self.algorithm_matrix = {
"B": (B.values * mean_B) / 1e-01,
"Xinput_norm": (Xinput_norm.values * mean_X) / 1e-01,
"Res": OptimalRes
}
self.algorithm_matrix = {"B": B, "Xinput_norm": Xinput_norm, "Res": OptimalRes}

self.spatial_location = spatial_location
return self

def predict(self):
@@ -373,6 +339,10 @@ def predict(self):
prop_pred = optim_res["V"] / optim_res["V"].sum(axis=1, keepdims=True)
return prop_pred

def fit_and_predict(self, x, spatial, max_iter=100, epsilon=1e-4):
self.fit(x, spatial, max_iter=max_iter, epsilon=epsilon)
return self.predict()

@staticmethod
def score(x, y):
"""Model performance score measured by mean square error.
47 changes: 24 additions & 23 deletions examples/spatial/cell_type_deconvo/card.py
@@ -1,6 +1,10 @@
import argparse
from pprint import pprint

import numpy as np
from anndata import AnnData

from dance.data import Data
from dance.datasets.spatial import CellTypeDeconvoDatasetLite
from dance.modules.spatial.cell_type_deconvo.card import Card

@@ -18,33 +22,30 @@

# Load dataset
dataset = CellTypeDeconvoDatasetLite(data_id=args.dataset, data_dir=args.datadir)
sc_count = dataset.data["ref_sc_count"]
sc_meta = dataset.data["ref_sc_annot"]
spatial_count = dataset.data["mix_count"]
true_p = dataset.data["true_p"]
spatial_location = None if args.location_free else dataset.data["spatial_location"]
ct_select = sorted(set(sc_meta.cellType.unique().tolist()) & set(true_p.columns.tolist()))

ref_count, ref_annot, count_matrix, cell_type_portion, spatial = dataset.load_data()

# TODO: add ref index (or more flexible indexing option at init, e.g., as dict?) and combine with data
ref_adata = AnnData(X=ref_count, obsm={"annot": ref_annot}, dtype=np.float32)
adata = AnnData(X=count_matrix, obsm={"spatial": spatial, "cell_type_portion": cell_type_portion}, dtype=np.float32)

# TODO: deprecate the need for ct_select by doing this in a preprocessing step -> convert ct into one-hot matrix
ct_select = sorted(set(ref_annot.cellType.unique().tolist()) & set(cell_type_portion.columns.tolist()))
print(f"{ct_select=}")

# Initialize and train model
crd = Card(
sc_count=sc_count,
sc_meta=sc_meta,
spatial_count=spatial_count,
spatial_location=spatial_location,
ct_varname="cellType",
ct_select=ct_select,
cell_varname=None,
sample_varname=None,
)
crd.fit(max_iter=args.max_iter, epsilon=args.epsilon)

# Evaluate
pred = crd.predict()
mse = crd.score(pred, true_p[ct_select].values)
data = Data(adata)
data.set_config(feature_channel=[None, "spatial"], feature_channel_type=[None, "obsm"],
label_channel="cell_type_portion")

# TODO: after removing ct_select, return as numpy
(x_count, x_spatial), y = data.get_data(return_type="default")
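# Note on the wiring above: feature_channel=[None, "spatial"] paired with
# feature_channel_type=[None, "obsm"] selects adata.X (the mixture counts) and
# adata.obsm["spatial"] as the two feature channels, while
# label_channel="cell_type_portion" selects obsm["cell_type_portion"] as y;
# hence the nested (x_count, x_spatial), y returned by get_data.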

model = Card(ref_count, ref_annot, ct_varname="cellType", ct_select=ct_select)
pred = model.fit_and_predict(x_count, x_spatial.values, max_iter=args.max_iter, epsilon=args.epsilon)
mse = model.score(pred, y[ct_select].values)

print(f"Predicted cell-type proportions of sample 1: {pred[0].round(3)}")
print(f"True cell-type proportions of sample 1: {true_p[ct_select].iloc[0].tolist()}")
print(f"True cell-type proportions of sample 1: {y.iloc[0].tolist()}")
print(f"mse = {mse:7.4f}")
"""To reproduce CARD benchmarks, please refer to command lines belows:
