In [1]:
import anndata
import scanpy as sc
import pandas as pd
import numpy as np
import torch

from tqdm.auto import tqdm

In [2]:
# Set the internal precision mode PyTorch uses for float32 matrix multiplications to 'high'.
torch.set_float32_matmul_precision('high')

## Download data

In [2]:
# download from scVelo: https://github.com/theislab/scvelo/blob/main/scvelo/datasets/_datasets.py#L278-L302C17


def pbmc68k(file_path):
    """Load the 68k PBMC dataset, fetching it from figshare if needed.

    The data come from `Zheng et al. (2017) <https://doi.org/10.1038/ncomms14049>`__
    and contain 68k peripheral blood mononuclear cells (PBMC) measured using 10X.

    PBMCs are a diverse mixture of highly specialized immune cells. They
    descend from hematopoietic stem cells (HSCs) residing in the bone marrow,
    which give rise to all blood cells of the immune system (hematopoiesis):
    the myeloid lineage (monocytes, macrophages, granulocytes, megakaryocytes,
    dendritic cells, erythrocytes) and the lymphoid lineage (T cells, B cells,
    NK cells).

    .. image:: https://user-images.githubusercontent.com/31883718/118402351-e1243580-b669-11eb-8256-4a49c299da3d.png
       :width: 600px

    Parameters
    ----------
    file_path
        Location of the ``.h5ad`` file; passed to ``sc.read`` together with
        the figshare URL as ``backup_url``.

    Returns
    -------
    The `adata` object, with variable names made unique.
    """
    adata = sc.read(
        file_path,
        backup_url="https://ndownloader.figshare.com/files/27686886",
        sparse=True,
        cache=True,
    )
    # Operates on `adata` in place; the return value of the call is not used.
    adata.var_names_make_unique()
    return adata


In [8]:
# Download (on first use) to the exact path that the "Load and streamline data"
# section reads from, so a fresh Restart & Run All finds the file.
pbmc68k('/vol/data/h5ads/pbmc68k.h5ad')

100%|██████████| 118M/118M [00:26<00:00, 4.60MB/s] 


AnnData object with n_obs × n_vars = 65877 × 33939
    obs: 'celltype'
    var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand'
    obsm: 'X_tsne'
    layers: 'spliced', 'unspliced'

## Load and streamline data

In [3]:
# Read the PBMC68k AnnData file from disk.
adata = sc.read('/vol/data/h5ads/pbmc68k.h5ad')
# Drop cells with fewer than 50 total counts; this filters `adata` in place
# and adds the `n_counts` column to `adata.obs` (see the repr below).
sc.pp.filter_cells(adata, min_counts=50)
adata



AnnData object with n_obs × n_vars = 64846 × 33939
    obs: 'celltype', 'n_counts'
    var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand'
    obsm: 'X_tsne'
    layers: 'spliced', 'unspliced'

In [4]:
# Index the gene table by its 'Accession' column so genes can be matched by ID.
# NOTE: not idempotent -- running this cell twice raises a KeyError because
# 'Accession' is no longer a column after the first run.
adata.var = adata.var.set_index('Accession')

In [5]:
# Sanity check: the gene table should now be indexed by accession ID.
adata.var.head()

Unnamed: 0_level_0,Chromosome,End,Start,Strand
Accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ENSG00000237613,1,36081,34554,-
ENSG00000238009,1,133723,89295,-
ENSG00000239945,1,91105,89551,-
ENSG00000239906,1,140339,139790,-
ENSG00000284733,1,451697,450703,-


In [6]:
from cellnet.utils.data_loading import streamline_count_matrix
from scipy.sparse import csc_matrix


# Gene table that ships with the pre-trained model; its `feature_id` column
# defines the gene set used for the streamlined matrix.
genes_from_model = pd.read_parquet('/vol/data/merlin_cxg_2023_05_15_sf-log1p/var.parquet')

gene_names_model = genes_from_model.feature_id.tolist()
gene_names_raw = adata.var.index.tolist()

# Convert to CSC before subsetting: the slice below selects columns (genes).
adata.X = csc_matrix(adata.X)
# Keep only the genes that are also present in the model's gene table.
adata = adata[:, np.isin(gene_names_raw, gene_names_model)].copy()

# NOTE(review): `streamline_count_matrix` is a cellnet helper; presumably it
# aligns the columns to `gene_names_model` (the result has one column per model
# gene, see the shape below) -- confirm against its implementation.
x_streamlined = streamline_count_matrix(adata.X, adata.var.index.tolist(), gene_names_model)
x_streamlined.shape

(64846, 19331)

In [7]:
# Rebuild the AnnData object around the streamlined matrix: same cells (`obs`),
# but the model's gene table as `var`.
# NOTE(review): `var` is passed with whatever index the parquet file provides;
# verify that the resulting `var_names` are the intended gene identifiers.
adata = anndata.AnnData(X=x_streamlined, obs=adata.obs, var=genes_from_model)

In [8]:
# Normalize every cell to a total of 10,000 counts, then apply log1p.
# Both calls modify `adata` in place, so this cell must be run exactly once.
# NOTE(review): presumably this mirrors the "sf-log1p" preprocessing named in
# the model's data directory -- confirm.
sc.pp.normalize_total(adata, target_sum=10000)
sc.pp.log1p(adata)

In [9]:
# Inspect the storage format / dtype of the normalized expression matrix.
adata.X

<64846x19331 sparse matrix of type '<class 'numpy.float32'>'
	with 6941563 stored elements in Compressed Sparse Row format>

In [10]:
# Reduce the cell metadata to a single categorical label column named
# `cell_type` (the source data calls it `celltype`).
adata.obs = (
    adata.obs[['celltype']]
    .rename(columns={'celltype': 'cell_type'})
    .astype('category')
)

In [11]:
from sklearn.model_selection import train_test_split


# Stratified 70 / 15 / 15 split on the cell-type label.
# Step 1: 70 % of the cells go to training, 30 % are held out.
idxs_train, idxs_holdout = train_test_split(
    np.arange(adata.n_obs),
    train_size=0.7,
    test_size=0.3,
    stratify=adata.obs.cell_type,
    random_state=1,
)
# Step 2: the held-out cells are halved into test and validation sets.
idxs_test, idxs_val = train_test_split(
    idxs_holdout,
    train_size=0.5,
    test_size=0.5,
    stratify=adata.obs.cell_type.iloc[idxs_holdout],
    random_state=1,
)

splits = dict(train=idxs_train, val=idxs_val, test=idxs_test)

In [12]:
# Materialize one independent AnnData copy per split.
adata_train, adata_val, adata_test = (
    adata[splits[split_name]].copy() for split_name in ('train', 'val', 'test')
)

In [13]:
# Convert the expression matrix of every split into a dense array.
for adata_split in (adata_train, adata_val, adata_test):
    adata_split.X = adata_split.X.toarray()

## Train models

#### scTab

In [14]:
import torch
import torch.nn as nn
import lightning as L
import torch.nn.functional as F

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassF1Score


def augment_data(x: torch.Tensor, augmentation_vectors: torch.Tensor):
    """Perturb every row of ``x`` with a randomly chosen augmentation vector.

    For each row, one row of ``augmentation_vectors`` is drawn uniformly at
    random and either added to or subtracted from that row (each with
    probability 0.5). The result is clamped to the interval ``[0, 9]``.

    Parameters
    ----------
    x
        Input batch; rows are indexed along dimension 0.
    augmentation_vectors
        Pool of perturbation vectors, one per row.

    Returns
    -------
    The clamped, perturbed batch.
    """
    n_rows = x.shape[0]

    # One augmentation vector per row of the batch.
    picked = torch.randint(0, augmentation_vectors.shape[0], (n_rows, ), device=x.device)
    shifts = augmentation_vectors.index_select(0, picked)

    # Fair coin flip per row, mapped from {0, 1} to {-1, +1}.
    coin = torch.bernoulli(torch.full((n_rows, 1), .5, device=x.device))
    direction = 2. * coin - 1.

    perturbed = x + direction * shifts
    return perturbed.clamp(min=0., max=9.)


class LitTabnet(L.LightningModule):
    """Lightning module that trains a TabNet-style classifier on cell-type labels.

    The wrapped ``model`` is called as ``model(x)`` and must return a tuple
    ``(logits, m_loss)``. Training minimizes the cross-entropy of the logits
    minus ``lambda_sparse * m_loss``; validation and test use plain
    cross-entropy. Micro- and macro-averaged multiclass F1 scores are logged
    for every stage with the prefixes ``train_``, ``val_`` and ``test_``.

    Parameters
    ----------
    model
        Network returning ``(logits, m_loss)``.
    type_dim
        Number of classes, used to configure the F1 metrics.
    augmentations
        NumPy array of augmentation vectors; stored as a float32 buffer.
    learning_rate
        Learning rate passed to AdamW.
    weight_decay
        Weight decay passed to AdamW.
    lambda_sparse
        Weight of the ``m_loss`` term in the training loss.
        NOTE(review): presumably TabNet's mask-sparsity regularizer -- confirm
        against the TabNet implementation in use.
    augment
        If truthy, training batches are passed through ``augment_data``.
    """

    def __init__(
        self,
        model,
        type_dim,
        augmentations,
        learning_rate,
        weight_decay,
        lambda_sparse,
        augment
    ):
        super().__init__()
        self.model = model
        self.lr = learning_rate
        self.weight_decay = weight_decay
        self.lambda_sparse = lambda_sparse

        # One independent metric collection per stage so their states never mix.
        metrics = MetricCollection({
            'f1_micro': MulticlassF1Score(num_classes=type_dim, average='micro'),
            'f1_macro': MulticlassF1Score(num_classes=type_dim, average='macro'),
        })
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')
        self.augment = augment
        # Registered as a buffer so it is moved between devices with the module.
        self.register_buffer('augmentations', torch.tensor(augmentations.astype('f4')))

        self.optim = torch.optim.AdamW

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Unwrap the loader output and flatten the labels to a 1-D tensor."""
        with torch.no_grad():
            # The loader yields a sequence whose first element is the feature dict.
            batch = batch[0]
            # reshape(-1) instead of squeeze(): squeeze() without `dim` would
            # also remove the batch axis of a single-sample batch, producing a
            # 0-d target that no longer matches the (1, n_classes) logits.
            batch['cell_type'] = batch['cell_type'].reshape(-1)

        return batch

    def training_step(self, batch, batch_idx):
        """Run one optimization step and log training loss and F1 scores."""
        x = augment_data(batch['X'], self.augmentations) if self.augment else batch['X']
        logits, m_loss = self.model(x)
        preds = torch.argmax(logits, dim=1)
        loss = F.cross_entropy(logits, batch['cell_type']) - self.lambda_sparse * m_loss
        metrics = self.train_metrics(preds, batch['cell_type'])
        self.log_dict(metrics)
        self.log('train_loss', loss)

        return loss

    def _evaluation_step(self, batch, metric_collection, loss_name):
        """Shared validation/test logic: forward pass, cross-entropy, metric logging."""
        logits, _ = self.model(batch['X'])
        preds = torch.argmax(logits, dim=1)
        loss = F.cross_entropy(logits, batch['cell_type'])
        metrics = metric_collection(preds, batch['cell_type'])
        self.log_dict(metrics)
        self.log(loss_name, loss)

    def validation_step(self, batch, batch_idx):
        """Log validation loss (``val_loss``) and ``val_``-prefixed F1 scores."""
        self._evaluation_step(batch, self.val_metrics, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Log test loss (``test_loss``) and ``test_``-prefixed F1 scores."""
        self._evaluation_step(batch, self.test_metrics, 'test_loss')

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the predicted class index for every sample in the batch."""
        logits, _ = self.model(batch['X'])
        preds = torch.argmax(logits, dim=1)
        return preds

    def configure_optimizers(self):
        """Create the AdamW optimizer over all module parameters."""
        optimizer_config = {'optimizer': self.optim(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)}

        return optimizer_config


In [None]:
import os
import wandb

from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import seed_everything
from cellnet.tabnet.tab_network import TabNet
from cellnet.utils.data_loading import dataloader_factory

seed_everything(1)


VERSION = 'scTab_big'
# exist_ok=True: the directory already exists whenever this cell is re-run,
# which would otherwise abort the cell with FileExistsError.
os.makedirs(f'/vol/data/tb_logs/zheng68k/{VERSION}', exist_ok=True)


# init model
tabnet = TabNet(
    input_dim=adata.X.shape[1],
    output_dim=adata.obs.cell_type.nunique(),
    n_d=128,
    n_a=64,
    n_steps=1,
    gamma=1.3,
    n_independent=5,
    n_shared=3,
    epsilon=1e-15,
    virtual_batch_size=128,
    momentum=0.02,
    mask_type='entmax',
)
# Augmentation vectors stored alongside the model's gene table.
augmentations = np.load('/vol/data/merlin_cxg_2023_05_15_sf-log1p/augmentations.npy')
lit_tabnet = LitTabnet(
    tabnet, 
    adata.obs.cell_type.nunique(), 
    augmentations,
    learning_rate=5e-3,
    weight_decay=0.05,
    lambda_sparse=1e-5,
    augment=True
)

# train model
logger = WandbLogger(project='zheng68k', version=VERSION)
trainer = L.Trainer(
    max_epochs=15,
    gradient_clip_val=1.,
    gradient_clip_algorithm='norm',
    accelerator='gpu',
    devices=1,
    num_sanity_val_steps=0,
    logger=logger,
    callbacks=[
        # Keep only the checkpoint with the highest validation micro-F1.
        ModelCheckpoint(
            dirpath=f'/vol/data/tb_logs/zheng68k/{logger.version}',
            filename='val_f1_micro_{epoch}_{val_f1_micro:.3f}', 
            monitor='val_f1_micro', 
            mode='max', save_top_k=1),
    ]
)
loader_train = dataloader_factory(x=adata_train.X, obs=adata_train.obs, batch_size=256)
loader_val = dataloader_factory(x=adata_val.X, obs=adata_val.obs, batch_size=256)
trainer.fit(model=lit_tabnet, train_dataloaders=loader_train, val_dataloaders=loader_val)

# Close the Weights & Biases run opened by the logger.
wandb.finish()

#### CellTypist

In [30]:
import celltypist

In [None]:
# Train a CellTypist model on the same training split and labels as scTab,
# with a fixed random_state, then persist it for the evaluation section.
new_model = celltypist.train(
    adata_train, 
    labels='cell_type', 
    n_jobs=16,
    feature_selection=True,
    use_SGD=True, 
    mini_batch=True,
    with_mean=False,
    random_state=1
)
new_model.write('/vol/data/tb_logs/zheng68k/celltypist.pkl')

## Evaluate performance

In [26]:
from sklearn.metrics import classification_report

#### scTab

In [24]:
loader_test = dataloader_factory(x=adata_test.X, obs=adata_test.obs, batch_size=256)

# NOTE(review): the checkpoint file name encodes the epoch and score of one
# particular training run; it has to be updated after re-training.
batch_preds = trainer.predict(
    model=lit_tabnet,
    dataloaders=loader_test,
    ckpt_path='/vol/data/tb_logs/zheng68k/scTab_big/val_f1_micro_epoch=5_val_f1_micro=0.616.ckpt',
)
# Concatenate the per-batch predictions into one array of class indices.
preds_tabnet = torch.cat(batch_preds).numpy()

# Translate class indices into labels via the category order of the training labels.
celltype_mapping = dict(enumerate(adata_train.obs.cell_type.cat.categories))
preds_tabnet = np.array(list(map(celltype_mapping.__getitem__, preds_tabnet)))

Restoring states from the checkpoint path at /vol/data/tb_logs/zheng68k/scTab_big/val_f1_micro_epoch=5_val_f1_micro=0.616.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /vol/data/tb_logs/zheng68k/scTab_big/val_f1_micro_epoch=5_val_f1_micro=0.616.ckpt
/home/ubuntu/miniconda3/envs/merlin-2312/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [28]:
# Per-class precision / recall / F1 of scTab on the test split.
report_tabnet = classification_report(
    adata_test.obs.cell_type.to_numpy(),
    preds_tabnet,
    output_dict=True,
    zero_division=0.,
)
pd.DataFrame(report_tabnet).T

Unnamed: 0,precision,recall,f1-score,support
CD14+ Monocyte,0.751634,0.849754,0.797688,406.0
CD19+ B,0.915468,0.61697,0.737147,825.0
CD34+,0.903226,0.756757,0.823529,37.0
CD4+ T Helper2,0.0,0.0,0.0,12.0
CD4+/CD25 T Reg,0.356272,0.467963,0.40455,874.0
CD4+/CD45RA+/CD25- Naive T,0.0,0.0,0.0,268.0
CD4+/CD45RO+ Memory,0.0,0.0,0.0,435.0
CD56+ NK,0.830769,0.857143,0.84375,1260.0
CD8+ Cytotoxic T,0.659278,0.54147,0.594595,2966.0
CD8+/CD45RA+ Naive Cytotoxic,0.528167,0.804521,0.637691,2389.0


#### CellTypist

In [38]:
import celltypist

In [45]:
# Predict test-split labels with the CellTypist model saved during training.
preds = celltypist.annotate(adata_test, model='/vol/data/tb_logs/zheng68k/celltypist.pkl')

🔬 Input data has 9727 cells and 19331 genes
🔗 Matching reference genes in the model
🧬 2624 features used for prediction
⚖️ Scaling input data
🖋️ Predicting labels
✅ Prediction done!


In [46]:
# Per-class precision / recall / F1 of CellTypist on the test split.
# zero_division=0. matches the scTab report: classes without any predicted
# sample score 0 without raising an UndefinedMetricWarning.
pd.DataFrame(classification_report(
    adata_test.obs.cell_type.to_numpy(),
    preds.predicted_labels.predicted_labels.to_numpy().flatten(),
    output_dict=True,
    zero_division=0.
)).T

Unnamed: 0,precision,recall,f1-score,support
CD14+ Monocyte,0.743523,0.706897,0.724747,406.0
CD19+ B,0.692635,0.592727,0.638798,825.0
CD34+,0.769231,0.810811,0.789474,37.0
CD4+ T Helper2,0.0,0.0,0.0,12.0
CD4+/CD25 T Reg,0.306452,0.173913,0.221898,874.0
CD4+/CD45RA+/CD25- Naive T,0.140449,0.093284,0.112108,268.0
CD4+/CD45RO+ Memory,0.161943,0.091954,0.117302,435.0
CD56+ NK,0.734848,0.846825,0.786873,1260.0
CD8+ Cytotoxic T,0.494996,0.583614,0.535665,2966.0
CD8+/CD45RA+ Naive Cytotoxic,0.507305,0.523231,0.515145,2389.0
