In [2]:
import anndata
import scipy.sparse as sp
import numpy as np
import scanpy as sc
import torch
from torch.utils.data.dataset import Dataset

In [3]:
class SingleCellData(Dataset):
    """Map-style PyTorch dataset over an ``.h5ad`` single-cell expression file.

    Construction pipeline:
      1. read the file and drop genes with fewer than 1 total count;
      2. choose genes with :meth:`select_gene` on the (filtered) raw matrix;
      3. if ``normalized``: scale each cell to 1e6 total counts, ``log1p``,
         keep the selected genes and z-score every kept gene across cells;
         otherwise keep the selected genes' raw values as a dense matrix.

    ``dataset[i]`` returns ``(expression_row, cell_type_id, batch_id)``; the ids
    decode through ``id_to_cell`` / ``id_to_batch``.

    Parameters
    ----------
    data_path : str
        Path to the ``.h5ad`` file. Its ``obs`` must contain ``'labels'`` and
        ``'batch_id'`` columns.
    num_gene : int or None
        Target gene count for :meth:`select_gene`; ``None`` applies the default
        selection curve without searching for a count.
    normalized : bool, default True
        Apply the log-CPM + per-gene z-score preprocessing.
    """

    def __init__(self, data_path, num_gene, normalized=True):
        self.data_path = data_path
        self.num_gene = num_gene
        self.normalized = normalized
        self.anndata = anndata.read_h5ad(data_path)
        # Drop genes with fewer than 1 total count.
        sc.pp.filter_genes(self.anndata, min_counts=1)
        self.gene_mask = self.select_gene(data=self.anndata.X, num_gene=self.num_gene)

        if self.normalized:
            self.anndata_preprocessed = self._normalize_and_select()
        else:
            # Apply the gene selection here too, so num_gene is honoured either way,
            # and densify: torch.tensor cannot consume a scipy sparse matrix.
            self.anndata_preprocessed = self.anndata[:, self.gene_mask].copy()
            self.anndata_preprocessed.X = self._to_dense(self.anndata_preprocessed.X)

        self.X = torch.tensor(self.anndata_preprocessed.X)
        self.id_to_batch, self.batch_to_id = self.get_batch_map()
        self.id_to_cell, self.cell_to_id = self.get_cell_map()

        # Read labels from the *preprocessed* object so they stay row-aligned with
        # self.X even if preprocessing removed cells.
        obs = self.anndata_preprocessed.obs
        self.cell_label_tensor = torch.tensor([self.cell_to_id[e] for e in obs['labels']])
        self.batch_id_tensor = torch.tensor([self.batch_to_id[e] for e in obs['batch_id']])

    def _normalize_and_select(self):
        """Log-CPM normalise all genes, keep ``self.gene_mask``, z-score each kept gene."""
        adata_norm = self.anndata.copy()
        # Library sizes are computed over *all* genes, so normalise before subsetting.
        # NOTE(review): per scanpy docs, normalize_per_cell also filters out cells
        # with fewer than `min_counts` (default 1) counts.
        sc.pp.normalize_per_cell(adata_norm, counts_per_cell_after=1_000_000)
        sc.pp.log1p(adata_norm)
        # z-scoring is per gene, so subsetting first gives the same values while the
        # dense matrix only has to hold the selected genes.
        adata_sel = adata_norm[:, self.gene_mask].copy()
        dense = self._to_dense(adata_sel.X)
        dense -= dense.mean(axis=0)
        std = dense.std(axis=0)
        std[std == 0] = 1  # constant genes become 0 instead of NaN
        dense /= std
        adata_sel.X = dense
        return adata_sel

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a NumPy array, densifying only if it is scipy-sparse."""
        return matrix.toarray() if sp.issparse(matrix) else np.asarray(matrix)

    def __getitem__(self, index):
        """Return ``(expression_row, cell_type_id, batch_id)`` for ``index``."""
        return self.X[index], self.cell_label_tensor[index], self.batch_id_tensor[index]

    def __len__(self):
        # Count rows of the feature tensor, which is what __getitem__ indexes.
        return self.X.shape[0]

    def _index_map(self, column):
        """Enumerate unique values of ``self.anndata.obs[column]`` in first-seen order.

        Returns ``(id -> value, value -> id)`` dictionaries.
        """
        values = self.anndata.obs[column].unique().tolist()
        id_to_value = dict(enumerate(values))
        value_to_id = {v: k for k, v in id_to_value.items()}
        return id_to_value, value_to_id

    def get_batch_map(self):
        """Return ``(id_to_batch, batch_to_id)`` built from ``obs['batch_id']``."""
        return self._index_map('batch_id')

    def get_cell_map(self):
        """Return ``(id_to_cell, cell_to_id)`` built from ``obs['labels']``."""
        return self._index_map('labels')

    def select_gene(self, data, num_gene, threshold=0, atleast=10, decay=1,
                    xoffset=5, yoffset=0.02):
        """Boolean gene mask: zero-rate above an exponential curve of mean log2 expression.

        A gene (column of ``data``) is kept when
            zero_rate > exp(-decay * (mean_log2_expr - xoffset)) + yoffset
        where ``zero_rate`` is the fraction of cells with value <= ``threshold`` and
        ``mean_log2_expr`` is the mean log2 value over the other cells. Genes with
        fewer than ``atleast`` such cells are never kept.

        If ``num_gene`` is not None, ``xoffset`` is bisected (bracket [0, 10], at
        most 100 steps) aiming at exactly ``num_gene`` genes; the final offset is
        printed. The count is not guaranteed to be hit exactly.

        ``data`` may be a dense array or a scipy sparse matrix (cells x genes).
        """
        if sp.issparse(data):
            zero_rate = 1 - np.squeeze(np.array((data > threshold).mean(axis=0)))
            positive = data.multiply(data > threshold)
            positive.data = np.log2(positive.data)
            mean_expr = np.zeros_like(zero_rate) * np.nan
            detected = zero_rate < 1
            # Column means include zeros; divide by the detected fraction to get the
            # mean over detected cells only.
            mean_expr[detected] = np.squeeze(
                np.array(positive[:, detected].mean(axis=0))
            ) / (1 - zero_rate[detected])
        else:
            zero_rate = 1 - np.mean(data > threshold, axis=0)
            mean_expr = np.zeros_like(zero_rate) * np.nan
            detected = zero_rate < 1
            # log2 is evaluated on every entry (zeros included) before np.where masks
            # them out; silence the resulting divide/invalid warnings.
            with np.errstate(divide="ignore", invalid="ignore"):
                mean_expr[detected] = np.nanmean(
                    np.where(data[:, detected] > threshold, np.log2(data[:, detected]), np.nan),
                    axis=0,
                )

        low_detection = np.array(np.sum(data > threshold, axis=0)).squeeze() < atleast
        zero_rate[low_detection] = np.nan
        mean_expr[low_detection] = np.nan
        valid = ~np.isnan(zero_rate)

        def above_curve(offset):
            mask = np.zeros(zero_rate.shape, dtype=bool)
            mask[valid] = zero_rate[valid] > np.exp(-decay * (mean_expr[valid] - offset)) + yoffset
            return mask

        # Decide on the *argument*, not self.num_gene, so explicit calls are honoured.
        if num_gene is None:
            return above_curve(xoffset)

        low, up = 0, 10
        for _ in range(100):
            selected = above_curve(xoffset)
            n_selected = np.sum(selected)
            if n_selected == num_gene:
                break
            if n_selected < num_gene:
                # Too few genes: a smaller offset lowers the curve and admits more.
                up, xoffset = xoffset, (xoffset + low) / 2
            else:
                low, xoffset = xoffset, (xoffset + up) / 2
        print("Chosen offset: {:.2f}".format(xoffset))
        return selected
    

In [4]:
# Build the dataset from the raw .h5ad file, keeping 3000 genes chosen by select_gene.
DATA_PATH = "./data/baron_2016h.h5ad"
NUM_GENES = 3000
sc_dataset = SingleCellData(data_path=DATA_PATH, num_gene=NUM_GENES)

Chosen offset: 0.18


In [None]:
# Shuffled mini-batch loader; each batch is (expression, cell-type id, batch id).
# Seeded generator so the shuffle order is reproducible across kernel restarts.
BATCH_SIZE = 256  # NOTE(review): placeholder value — tune for the model being trained
sc_dataloader_train = torch.utils.data.DataLoader(
    sc_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Take the first 100 examples directly from the dataset (__getitem__ slices its tensors),
# so this cell works on a fresh kernel without relying on later cells.
x, y, b = sc_dataset[:100]
x.shape, y.shape, b.shape

In [None]:
# First 100 expression rows.
x[:100]

In [None]:
# Cell-type label -> integer id, as used to build sc_dataset.cell_label_tensor.
sc_dataset.cell_to_id

In [None]:
# Batch value -> integer id, as used to build sc_dataset.batch_id_tensor.
sc_dataset.batch_to_id

In [None]:
# Independent reference: reload the raw file and apply the same gene filter as
# SingleCellData.__init__, so the dataset's rows can be rebuilt and compared.
adata = anndata.read_h5ad(
    "./data/baron_2016h.h5ad"
)
sc.pp.filter_genes(adata, min_counts=1, inplace=True)
adata

In [None]:
def select_genes(
    data,
    threshold=0,
    atleast=10,
    yoffset=0.02,
    xoffset=5,
    decay=1,
    n=None,
    alpha=1,
):
    """Boolean gene mask: zero-rate above an exponential curve of mean log2 expression.

    Standalone copy of the rule in ``SingleCellData.select_gene``, used to rebuild
    the dataset independently for comparison.

    A gene (column of ``data``) is kept when
        zero_rate > exp(-decay * (mean_log2_expr - xoffset)) + yoffset
    where ``zero_rate`` is the fraction of cells with value <= ``threshold`` and
    ``mean_log2_expr`` is the mean log2 value over the remaining cells. Genes with
    fewer than ``atleast`` such cells are never kept.

    If ``n`` is given, ``xoffset`` is bisected (bracket [0, 10], at most 100 steps)
    aiming at exactly ``n`` genes, and the final offset is printed.
    ``alpha`` is accepted but unused.
    """
    if sp.issparse(data):
        zero_frac = 1 - np.squeeze(np.array((data > threshold).mean(axis=0)))
        logged = data.multiply(data > threshold)
        logged.data = np.log2(logged.data)
        mean_log = np.zeros_like(zero_frac) * np.nan
        seen = zero_frac < 1
        # Column means include zeros; rescale to the mean over detected cells.
        mean_log[seen] = np.squeeze(np.array(logged[:, seen].mean(axis=0))) / (1 - zero_frac[seen])
    else:
        zero_frac = 1 - np.mean(data > threshold, axis=0)
        mean_log = np.zeros_like(zero_frac) * np.nan
        seen = zero_frac < 1
        cols = data[:, seen]
        mean_log[seen] = np.nanmean(np.where(cols > threshold, np.log2(cols), np.nan), axis=0)

    too_rare = np.array(np.sum(data > threshold, axis=0)).squeeze() < atleast
    zero_frac[too_rare] = np.nan
    mean_log[too_rare] = np.nan
    usable = ~np.isnan(zero_frac)

    def mask_for(offset):
        keep = np.zeros(zero_frac.shape, dtype=bool)
        keep[usable] = zero_frac[usable] > np.exp(-decay * (mean_log[usable] - offset)) + yoffset
        return keep

    if n is None:
        return mask_for(xoffset)

    lo, hi = 0, 10
    for _ in range(100):
        selected = mask_for(xoffset)
        count = np.sum(selected)
        if count == n:
            break
        if count < n:
            hi, xoffset = xoffset, (xoffset + lo) / 2
        else:
            lo, xoffset = xoffset, (xoffset + hi) / 2
    print("Chosen offset: {:.2f}".format(xoffset))
    return selected

In [None]:
# Reference gene mask on the filtered raw counts (same rule and target as the dataset).
gene_mask = select_genes(adata.X, threshold=0, n=3000)

# Log-CPM on a copy: scale each cell to 1e6 total counts, then log1p.
adata_norm = adata.copy()
sc.pp.normalize_per_cell(adata_norm, counts_per_cell_after=1_000_000)
sc.pp.log1p(adata_norm)

In [None]:
# z-score every gene across cells. The .uns flag makes the cell safe to re-run
# (the original crashed on re-run because .toarray() does not exist on a dense array),
# and densifying only when sparse also handles files stored dense.
if not adata_norm.uns.get("zscored", False):
    dense = adata_norm.X.toarray() if sp.issparse(adata_norm.X) else np.array(adata_norm.X)
    dense -= dense.mean(axis=0)
    gene_std = dense.std(axis=0)
    gene_std[gene_std == 0] = 1  # constant genes become 0 instead of NaN
    dense /= gene_std
    adata_norm.X = dense
    adata_norm.uns["zscored"] = True
    del dense, gene_std

In [None]:

# Keep only the reference-selected genes (materialised copy, not a view).
adata_3000 = adata_norm[:, gene_mask].copy()

In [None]:
# Row 4578 by position: obs is indexed by obs_names, and Series[int] positional
# fallback is deprecated in pandas, so use .iloc explicitly.
adata_3000.obs['batch_id'].iloc[4578]

In [None]:
# Row 4578 by position: obs is indexed by obs_names, and Series[int] positional
# fallback is deprecated in pandas, so use .iloc explicitly.
adata_3000.obs['labels'].iloc[4578]

In [None]:
# Fetch (expression, cell-type id, batch id) for row 4578 via normal indexing.
x, y, b = sc_dataset[4578]

In [None]:
# Reference row 4578 of the gene-selected, z-scored matrix.
adata_3000.X[4578]

In [None]:
# The Dataset row must match the independently rebuilt reference exactly:
# features, cell-type label and batch. Assertions fail loudly instead of
# silently displaying False.
np.testing.assert_array_equal(x.numpy(), adata_3000.X[4578])
assert sc_dataset.id_to_cell[int(y)] == adata_3000.obs['labels'].iloc[4578]
assert sc_dataset.id_to_batch[int(b)] == adata_3000.obs['batch_id'].iloc[4578]

In [None]:
# Integer batch id returned by the dataset; decode with sc_dataset.id_to_batch.
b

In [None]:
# Integer cell-type id returned by the dataset; decode with sc_dataset.id_to_cell.
y

In [None]:
# Integer id -> cell-type label (inverse of cell_to_id).
sc_dataset.id_to_cell

In [None]:
# Integer id -> batch value (inverse of batch_to_id).
sc_dataset.id_to_batch