In [1]:
import shap
from biomart import BiomartServer
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
import pickle
import torch
from anndata import AnnData
import scanpy as sc
import scvi
import seaborn as sns
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from sklearn.metrics import confusion_matrix
sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

Global seed set to 0
  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
print(f"CUDA is available: {torch.cuda.is_available()}")

CUDA is available: True


In [3]:
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="fly",
    do_train=False,
    load_model="./model.pt",
    mask_ratio=0.0,
    epochs=10,
    n_bins=51,
    MVC=False, # Masked value prediction for cell embedding
    ecs_thres=0.0, # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=0.0,
    lr=1e-4,
    batch_size=1,
    layer_size=512,
    nlayers=12,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead=8,  # number of heads in nn.MultiheadAttention
    dropout=0.2,  # dropout probability
    schedule_ratio=0.9,  # ratio of epochs for learning rate schedule
    save_eval_interval=5,
    fast_transformer=True,
    pre_norm=False,
    amp=True,  # Automatic Mixed Precision
    include_zero_gene = False,
    freeze = False, #freeze
    DSBN = False,  # Domain-spec batchnorm
)

In [4]:
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = hyperparameter_defaults["mask_ratio"]
mask_value = "auto"  # for masked values, now it should always be auto
include_zero_gene = hyperparameter_defaults["include_zero_gene"]  # if True, include zero genes among hvgs in the training
max_seq_len = 4716
n_bins = hyperparameter_defaults["n_bins"]

# input/output representation
input_style = "binned"  # "normed_raw", "log1p", or "binned"
output_style = "binned"  # "normed_raw", "log1p", or "binned"

# settings for training
MLM = False  # whether to use masked language modeling, currently it is always on.
CLS = True  # classification objective
ADV = False  # Adversarial training for batch correction
CCE = False  # Contrastive cell embedding objective
MVC = hyperparameter_defaults["MVC"]  # Masked value prediction for cell embedding
ECS = hyperparameter_defaults["ecs_thres"] > 0  # Elastic cell similarity objective
DAB = False  # Domain adaptation by reverse backpropagation, set to 2 for separate optimizer
INPUT_BATCH_LABELS = False  # TODO: have these help MLM and MVC, while not to classifier
input_emb_style = "continuous"  # "category" or "continuous" or "scaling"
cell_emb_style = "cls"  # "avg-pool" or "w-pool" or "cls"
adv_E_delay_epochs = 0  # delay adversarial training on encoder for a few epochs
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = hyperparameter_defaults["ecs_thres"]
dab_weight = hyperparameter_defaults["dab_weight"]
explicit_zero_prob = MLM and include_zero_gene  # whether explicit bernoulli for zeros
do_sample_in_train = False and explicit_zero_prob  # sample the bernoulli in training
per_seq_batch_sample = False

# settings for optimizer
lr = hyperparameter_defaults["lr"]  # TODO: test learning rate ratio between two tasks
lr_ADV = 1e-3  # learning rate for discriminator, used when ADV is True
batch_size = hyperparameter_defaults["batch_size"]
eval_batch_size = hyperparameter_defaults["batch_size"]
epochs = hyperparameter_defaults["epochs"]
schedule_interval = 1

# settings for the model
fast_transformer = hyperparameter_defaults["fast_transformer"]
fast_transformer_backend = "flash"  # "linear" or "flash"
embsize = hyperparameter_defaults["layer_size"]  # embedding dimension
d_hid = hyperparameter_defaults["layer_size"]  # dimension of the feedforward network in TransformerEncoder
nlayers = hyperparameter_defaults["nlayers"]  # number of TransformerEncoderLayer in TransformerEncoder
nhead = hyperparameter_defaults["nhead"]  # number of heads in nn.MultiheadAttention
dropout = hyperparameter_defaults["dropout"]  # dropout probability

# logging
log_interval = 100  # iterations
save_eval_interval = hyperparameter_defaults["save_eval_interval"]  # epochs
do_eval_scib_metrics = True

In [5]:
# %% validate settings
assert input_style in ["normed_raw", "log1p", "binned"]
assert output_style in ["normed_raw", "log1p", "binned"]
assert input_emb_style in ["category", "continuous", "scaling"]
if input_style == "binned":
    if input_emb_style == "scaling":
        raise ValueError("input_emb_style `scaling` is not supported for binned input.")
elif input_style == "log1p" or input_style == "normed_raw":
    if input_emb_style == "category":
        raise ValueError(
            "input_emb_style `category` is not supported for log1p or normed_raw input."
        )

if input_emb_style == "category":
    mask_value = n_bins + 1
    pad_value = n_bins  # for padding gene expr values
    n_input_bins = n_bins + 2
else:
    mask_value = -1
    pad_value = -2
    n_input_bins = n_bins

if ADV and DAB:
    raise ValueError("ADV and DAB cannot be both True.")
DAB_separate_optim = True if DAB > 1 else False

In [6]:
# data_dir = Path("/home/phv028r/scGPT-aging/raw-data/fly")
adata = sc.read("./adata_headBody_S_v1.0.h5ad")
adata.obs["age"] = adata.obs["age"].astype("category")
adata.var["gene_name"] = adata.var.index
adata.var.set_index(adata.var["gene_name"], inplace=True)
data_is_raw = False
filter_gene_by_counts = False

ensembl_server = BiomartServer("http://www.ensembl.org/biomart")
ensembl_dataset = ensembl_server.datasets['dmelanogaster_gene_ensembl']
ensembl_query = ensembl_dataset.search({
#         "filters": {
#             "external_gene_name": adata.var["gene_name"].tolist(),
#         },
    'attributes': [
#             'flybase_gene_id',        # Fly gene ID
        'external_gene_name',     # Fly gene name
        'hsapiens_homolog_ensembl_gene', # Human gene ID
        'hsapiens_homolog_associated_gene_name' # Human gene name
    ]
})

ensembl_query_text = ensembl_query.text
ensembl_query_lines = ensembl_query_text.strip().split("\n")
ensembl_query_columns = ["fly_gene", "human_ensembl_id", "human_gene"]
ensembl_query_text = [line.split("\t") for line in ensembl_query_lines[1:]]
ensembl_query_df = pd.DataFrame(ensembl_query_text, columns=ensembl_query_columns)
ensembl_query_df['human_gene'].replace('', pd.NA, inplace=True)
filtered_ensembl_query_df = ensembl_query_df.dropna(subset=["human_gene"])
map_fly_to_human_dict = pd.Series(filtered_ensembl_query_df.human_gene.values, index=filtered_ensembl_query_df.fly_gene).to_dict()
reverse_mapping = dict()
for fly_gene, human_gene in map_fly_to_human_dict.items():
    reverse_mapping[human_gene] = fly_gene
reverse_mapping.pop("AGO2")
reverse_mapping.pop("PCNA")
reverse_mapping.pop("SPR")
reverse_mapping.pop("PGAP2")
reverse_mapping.pop("MED22")
one_to_one_mapping = {v: k for k, v in reverse_mapping.items()}
new_index = []
for gene in adata.var.index:
    if gene in one_to_one_mapping:
        new_index.append(one_to_one_mapping[gene])
    else:
        new_index.append(gene)
adata.var.index = pd.Index(new_index)


# make the batch category column
age_id_labels = adata.obs["age"].astype("category").cat.codes.values
ages = adata.obs["age"].unique()
num_types = len(np.unique(age_id_labels))
id2type = dict(enumerate(adata.obs["age"].astype("category").cat.categories))
adata.obs["age_id"] = age_id_labels
adata.var["gene_name"] = adata.var.index.tolist()

# randomly sample cells
sampled_indices = []
for age_category in ages:
    age_group = adata.obs[adata.obs['age'] == age_category]
    sampled_group = age_group.sample(n=250, random_state=42)
    sampled_indices.extend(sampled_group.index)
adata = adata[sampled_indices, :]

In [37]:
adata.T

AnnData object with n_obs × n_vars = 4715 × 1000
    obs: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name', 'id_in_vocab'
    var: 'tissue', 'sex', 'age', 'sex_age', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'log1p_total_counts_mt', 'dataset', 'fca_annotation', 'afca_annotation', 'afca_annotation_broad', 'age_id'
    uns: 'age_colors', 'hvg', 'leiden', 'leiden_colors', 'neighbors', 'pca', 'sex_colors', 'tissue_colors', 'tsne', 'umap'
    varm: 'X_pca', 'X_tsne', 'X_umap', 'bin_edges'
    layers: 'X_binned'

In [39]:
adata.write_csvs("./adata.csv", skip_data=False)

In [45]:
dataset_gene = pd.read_csv("/home/qxy699/scGPT/adata/obs.csv")
dataset_gene["Unnamed: 0"]

0              TCATATCTCACCGGGT-1_AFCA_female_body_30_S4
1                TCCGAAAGTAGGCTGA-1_AFCA_male_body_30_S5
2                CCGAACGAGGCCCACT-1_AFCA_male_body_30_S3
3              CGTGATATCCCTAGGG-1_AFCA_female_body_30_S2
4              CAGATTGAGAATTGTG-1_AFCA_female_body_30_S5
                             ...                        
995    GTGATGTTCGAGTGGA-f8548b44__FCA27_Female_body_a...
996    GTCTTTAAGGTTGTTC-f7dc3ba8__FCA13_Female_head_a...
997      ACGTCCTAGTAGATCA-d541ae4e__FCA1_MaleFemale_Head
998    TTTCACATCCGTCCTA-f8548b44__FCA27_Female_body_a...
999    GGCTTTCAGGCATGGT-f8438c5e__FCA25_Female_body_a...
Name: Unnamed: 0, Length: 1000, dtype: object

In [33]:
adata

AnnData object with n_obs × n_vars = 1000 × 4715
    obs: 'tissue', 'sex', 'age', 'sex_age', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'log1p_total_counts_mt', 'dataset', 'fca_annotation', 'afca_annotation', 'afca_annotation_broad', 'age_id'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_name', 'id_in_vocab'
    uns: 'age_colors', 'hvg', 'leiden', 'leiden_colors', 'neighbors', 'pca', 'sex_colors', 'tissue_colors', 'tsne', 'umap'
    obsm: 'X_pca', 'X_tsne', 'X_umap', 'bin_edges'
    layers: 'X_binned'

In [7]:
# model_dir = Path(hyperparameter_defaults["load_model"])
model_file = "./model.pt"
vocab_file = "./vocab.json"
vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

adata.var["id_in_vocab"] = [
    1 if gene in vocab else -1 for gene in adata.var["gene_name"]
]
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
print(
    f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
    f"in vocabulary of size {len(vocab)}."
)

adata = adata[:, adata.var["id_in_vocab"] >= 0]

match 4715/15992 genes in vocabulary of size 60697.


In [8]:
from typing import Dict, Optional, Union

import numpy as np
import torch
from scipy.sparse import issparse
import scanpy as sc
from scanpy.get import _get_obs_rep, _set_obs_rep
from anndata import AnnData

from scgpt import logger


class Preprocessor_Edit:
    """
    Prepare data into training, valid and test split. Normalize raw expression
    values, binning or using other transform into the preset model input format.
    """

    def __init__(
        self,
        use_key: Optional[str] = None,
        filter_gene_by_counts: Union[int, bool] = False,
        filter_cell_by_counts: Union[int, bool] = False,
        normalize_total: Union[float, bool] = 1e4,
        result_normed_key: Optional[str] = "X_normed",
        log1p: bool = False,
        result_log1p_key: str = "X_log1p",
        subset_hvg: Union[int, bool] = False,
        hvg_use_key: Optional[str] = None,
        hvg_flavor: str = "seurat_v3",
        binning: Optional[int] = None,
        result_binned_key: str = "X_binned",
    ):
        r"""
        Set up the preprocessor, use the args to config the workflow steps.

        Args:

        use_key (:class:`str`, optional):
            The key of :class:`~anndata.AnnData` to use for preprocessing.
        filter_gene_by_counts (:class:`int` or :class:`bool`, default: ``False``):
            Whther to filter genes by counts, if :class:`int`, filter genes with counts
        filter_cell_by_counts (:class:`int` or :class:`bool`, default: ``False``):
            Whther to filter cells by counts, if :class:`int`, filter cells with counts
        normalize_total (:class:`float` or :class:`bool`, default: ``1e4``):
            Whether to normalize the total counts of each cell to a specific value.
        result_normed_key (:class:`str`, default: ``"X_normed"``):
            The key of :class:`~anndata.AnnData` to store the normalized data. If
            :class:`None`, will use normed data to replce the :attr:`use_key`.
        log1p (:class:`bool`, default: ``True``):
            Whether to apply log1p transform to the normalized data.
        result_log1p_key (:class:`str`, default: ``"X_log1p"``):
            The key of :class:`~anndata.AnnData` to store the log1p transformed data.
        subset_hvg (:class:`int` or :class:`bool`, default: ``False``):
            Whether to subset highly variable genes.
        hvg_use_key (:class:`str`, optional):
            The key of :class:`~anndata.AnnData` to use for calculating highly variable
            genes. If :class:`None`, will use :attr:`adata.X`.
        hvg_flavor (:class:`str`, default: ``"seurat_v3"``):
            The flavor of highly variable genes selection. See
            :func:`scanpy.pp.highly_variable_genes` for more details.
        binning (:class:`int`, optional):
            Whether to bin the data into discrete values of number of bins provided.
        result_binned_key (:class:`str`, default: ``"X_binned"``):
            The key of :class:`~anndata.AnnData` to store the binned data.
        """
        self.use_key = use_key
        self.filter_gene_by_counts = filter_gene_by_counts
        self.filter_cell_by_counts = filter_cell_by_counts
        self.normalize_total = normalize_total
        self.result_normed_key = result_normed_key
        self.log1p = log1p
        self.result_log1p_key = result_log1p_key
        self.subset_hvg = subset_hvg
        self.hvg_use_key = hvg_use_key
        self.hvg_flavor = hvg_flavor
        self.binning = binning
        self.result_binned_key = result_binned_key

    def __call__(self, adata: AnnData, batch_key: Optional[str] = None) -> Dict:
        """
        format controls the different input value wrapping, including categorical
        binned style, fixed-sum normalized counts, log1p fixed-sum normalized counts, etc.

        Args:

        adata (:class:`AnnData`):
            The :class:`AnnData` object to preprocess.
        batch_key (:class:`str`, optional):
            The key of :class:`AnnData.obs` to use for batch information. This arg
            is used in the highly variable gene selection step.
        """
        key_to_process = self.use_key
        # preliminary checks, will use later
        if key_to_process == "X":
            key_to_process = None  # the following scanpy apis use arg None to use X
        is_logged = self.check_logged(adata, obs_key=key_to_process)

        # step 1: filter genes
        if self.filter_gene_by_counts:
            logger.info("Filtering genes by counts ...")
            sc.pp.filter_genes(
                adata,
                min_counts=self.filter_gene_by_counts
                if isinstance(self.filter_gene_by_counts, int)
                else None,
            )

        # step 2: filter cells
        if (
            isinstance(self.filter_cell_by_counts, int)
            and self.filter_cell_by_counts > 0
        ):
            logger.info("Filtering cells by counts ...")
            sc.pp.filter_cells(
                adata,
                min_counts=self.filter_cell_by_counts
                if isinstance(self.filter_cell_by_counts, int)
                else None,
            )

        # step 3: normalize total
        if self.normalize_total:
            logger.info("Normalizing total counts ...")
            normed_ = sc.pp.normalize_total(
                adata,
                target_sum=self.normalize_total
                if isinstance(self.normalize_total, float)
                else None,
                layer=key_to_process,
                inplace=False,
            )["X"]
            key_to_process = self.result_normed_key or key_to_process
            _set_obs_rep(adata, normed_, layer=key_to_process)

        # step 4: log1p
        if self.log1p:
            logger.info("Log1p transforming ...")
            if is_logged:
                logger.warning(
                    "The input data seems to be already log1p transformed. "
                    "Set `log1p=False` to avoid double log1p transform."
                )
            if self.result_log1p_key:
                _set_obs_rep(
                    adata,
                    _get_obs_rep(adata, layer=key_to_process),
                    layer=self.result_log1p_key,
                )
                key_to_process = self.result_log1p_key
            sc.pp.log1p(adata, layer=key_to_process)

        # step 5: subset hvg
        if self.subset_hvg:
            logger.info("Subsetting highly variable genes ...")
            if batch_key is None:
                logger.warning(
                    "No batch_key is provided, will use all cells for HVG selection."
                )
            sc.pp.highly_variable_genes(
                adata,
                layer=self.hvg_use_key,
                n_top_genes=self.subset_hvg
                if isinstance(self.subset_hvg, int)
                else None,
                batch_key=batch_key,
                flavor=self.hvg_flavor,
                subset=True,
            )

        # step 6: binning
        if self.binning:
            logger.info("Binning data ...")
            if not isinstance(self.binning, int):
                raise ValueError(
                    "Binning arg must be an integer, but got {}.".format(self.binning)
                )
            n_bins = self.binning  # NOTE: the first bin is always a spectial for zero
            binned_rows = []
            bin_edges = []
            layer_data = _get_obs_rep(adata, layer=key_to_process)
            layer_data = layer_data.toarray() if issparse(layer_data) else layer_data
            if layer_data.min() < 0:
                raise ValueError(
                    f"Assuming non-negative data, but got min value {layer_data.min()}."
                )
            for row in layer_data:
                if row.max() == 0:
                    logger.warning(
                        "The input data contains all zero rows. Please make sure "
                        "this is expected. You can use the `filter_cell_by_counts` "
                        "arg to filter out all zero rows."
                    )
                    binned_rows.append(np.zeros_like(row, dtype=np.int64))
                    bin_edges.append(np.array([0] * n_bins))
                    continue
                non_zero_ids = row.nonzero()
                non_zero_row = row[non_zero_ids]
                bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
                # bins = np.sort(np.unique(bins))
                # NOTE: comment this line for now, since this will make the each category
                # has different relative meaning across datasets
                non_zero_digits = _digitize(non_zero_row, bins)
                assert non_zero_digits.min() >= 1
                assert non_zero_digits.max() <= n_bins - 1
                binned_row = np.zeros_like(row, dtype=np.int64)
                binned_row[non_zero_ids] = non_zero_digits
                binned_rows.append(binned_row)
                bin_edges.append(np.concatenate([[0], bins]))
            adata.layers[self.result_binned_key] = np.stack(binned_rows)
            adata.obsm["bin_edges"] = np.stack(bin_edges)

    def check_logged(self, adata: AnnData, obs_key: Optional[str] = None) -> bool:
        """
        Check if the data is already log1p transformed.

        Args:

        adata (:class:`AnnData`):
            The :class:`AnnData` object to preprocess.
        obs_key (:class:`str`, optional):
            The key of :class:`AnnData.obs` to use for batch information. This arg
            is used in the highly variable gene selection step.
        """
        data = _get_obs_rep(adata, layer=obs_key)
        max_, min_ = data.max(), data.min()
        if max_ > 30:
            return False
        if min_ < 0:
            return False

        non_zero_min = data[data > 0].min()
        if non_zero_min >= 1:
            return False

        return True


def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray:
    """
    Digitize the data into bins. This method spreads data uniformly when bins
    have same values.

    Args:

    x (:class:`np.ndarray`):
        The data to digitize.
    bins (:class:`np.ndarray`):
        The bins to use for digitization, in increasing order.
    side (:class:`str`, optional):
        The side to use for digitization. If "one", the left side is used. If
        "both", the left and right side are used. Default to "one".

    Returns:

    :class:`np.ndarray`:
        The digitized data.
    """
    assert x.ndim == 1 and bins.ndim == 1

    left_digits = np.digitize(x, bins)
    if side == "one":
        return left_digits

    right_difits = np.digitize(x, bins, right=True)

    rands = np.random.rand(len(x))  # uniform random numbers

    digits = rands * (right_difits - left_digits) + left_digits
    digits = np.ceil(digits).astype(np.int64)
    return digits


def binning(
    row: Union[np.ndarray, torch.Tensor], n_bins: int
) -> Union[np.ndarray, torch.Tensor]:
    """Binning the row into n_bins."""
    dtype = row.dtype
    return_np = False if isinstance(row, torch.Tensor) else True
    row = row.cpu().numpy() if isinstance(row, torch.Tensor) else row
    # TODO: use torch.quantile and torch.bucketize

    if row.max() == 0:
        logger.warning(
            "The input data contains row of zeros. Please make sure this is expected."
        )
        return (
            np.zeros_like(row, dtype=dtype)
            if return_np
            else torch.zeros_like(row, dtype=dtype)
        )

    if row.min() <= 0:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(non_zero_row, bins)
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
    else:
        bins = np.quantile(row, np.linspace(0, 1, n_bins - 1))
        binned_row = _digitize(row, bins)
    return torch.from_numpy(binned_row) if not return_np else binned_row.astype(dtype)

In [9]:
# set up the preprocessor, use the args to config the workflow
preprocessor = Preprocessor_Edit(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=filter_gene_by_counts,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=False,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=False,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)

# Preprocess training and testing data
preprocessor(adata, batch_key=None)

scGPT - INFO - Binning data ...


In [10]:
input_layer_key = {  # the values of this map coorespond to the keys in preprocessing
    "normed_raw": "X_normed",
    "log1p": "X_normed",
    "binned": "X_binned",
}[input_style]
all_counts = (
    adata.layers[input_layer_key].A
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()

age_labels = adata.obs["age_id"].tolist()  # make sure count from 0
age_labels = np.array(age_labels)

In [11]:
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)

In [12]:
tokenized_data = tokenize_and_pad_batch(
    all_counts,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)

print(
    f"data set number of samples: {tokenized_data['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_data['genes'].shape[1]}"
)

data set number of samples: 1000, 
	 feature length: 2837


In [56]:
# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    if num_workers == 0:
        num_workers = min(len(os.sched_getaffinity(0)), batch_size // 2)

    dataset = SeqDataset(data_pt)
    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        print(f"{batch_labels_array=}")
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader
    
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader


def prepare_data(tokenized_data_dxe, sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    masked_values = random_mask_value(
        tokenized_data_dxe["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    
    input_gene_ids = tokenized_data_dxe["genes"]
    input_values = masked_values
    target_values = tokenized_data_dxe["values"]
    tensor_age_labels = torch.from_numpy(age_labels).long()

    data_pt = {
        "gene_ids": input_gene_ids,
        "values": input_values,
        "target_values": target_values,
    }

    return data_pt

In [14]:
print(f"Cuda is available: {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=3,
    n_cls=num_types if CLS else 1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=MVC,
    do_dab=DAB,
    use_batch_labels=INPUT_BATCH_LABELS,
    num_batch_labels=None,
    domain_spec_batchnorm=hyperparameter_defaults["DSBN"],
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=hyperparameter_defaults["pre_norm"],
)
if hyperparameter_defaults["load_model"] is not None:
    try:
        model.load_state_dict(torch.load(model_file, map_location=device))
        print(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file, map_location=device)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            print(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

# Freeze all pre-decoder weights
for name, para in model.named_parameters():
    print("-"*20)
    print(f"name: {name}")
    if hyperparameter_defaults["freeze"] and "encoder" in name and "transformer_encoder" not in name:
        print(f"freezing weights for: {name}")
        para.requires_grad = False

post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

model.to(device)

if ADV:
    discriminator = AdversarialDiscriminator(
        d_model=embsize,
        n_cls=num_batch_types,
    ).to(device)

Cuda is available: True
Loading all model params from ./model.pt
--------------------
name: encoder.embedding.weight
--------------------
name: encoder.enc_norm.weight
--------------------
name: encoder.enc_norm.bias
--------------------
name: value_encoder.linear1.weight
--------------------
name: value_encoder.linear1.bias
--------------------
name: value_encoder.linear2.weight
--------------------
name: value_encoder.linear2.bias
--------------------
name: value_encoder.norm.weight
--------------------
name: value_encoder.norm.bias
--------------------
name: transformer_encoder.layers.0.self_attn.in_proj_weight
--------------------
name: transformer_encoder.layers.0.self_attn.in_proj_bias
--------------------
name: transformer_encoder.layers.0.self_attn.out_proj.weight
--------------------
name: transformer_encoder.layers.0.self_attn.out_proj.bias
--------------------
name: transformer_encoder.layers.0.linear1.weight
--------------------
name: transformer_encoder.layers.0.linear1.bi

In [15]:
def evaluate(model: nn.Module, input_gene_ids, input_values):
    model.eval()
    predictions = []
    with torch.no_grad():
        # Move data to device
        input_gene_ids = torch.tensor(input_gene_ids).to(device)
        input_values = torch.tensor(input_values).to(device)
        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        
#         with torch.cuda.amp.autocast(enabled=hyperparameter_defaults["amp"]):
        with torch.cuda.amp.autocast(enabled=False):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=None,
                CLS=CLS,  # evaluation does not need CLS or CCE
                CCE=False,
                MVC=False,
                ECS=False,
                do_sample=do_sample_in_train,
                # generative_training = False,
            )
            output_values = output_dict["cls_output"]
            return output_values

In [16]:
def tokenize(x):
    all_counts = (
        x.layers[input_layer_key].A
        if issparse(x.layers[input_layer_key])
        else x.layers[input_layer_key]
    )

    age_labels = x.obs["age_id"].tolist()  # make sure count from 0
    age_labels = np.array(age_labels)

    # Tokenize and pad a batch of data. Returns a list of tuple (gene_id, count)
    tokenized_test = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=include_zero_gene,
    )

    input_values_test = random_mask_value(
        tokenized_test["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )

    test_data_pt = {
        "gene_ids": tokenized_test["genes"],
        "values": input_values_test,
        "age_labels": torch.from_numpy(age_labels).long(),
    }
    
    return test_data_pt

In [17]:
data_tokenized = tokenize(adata)
gene_ids_tokenized = data_tokenized['gene_ids']
values_tokenized = data_tokenized['values']

num_samples = 1000
sample_size = 100
background_size = 100
np.random.seed(42)
sample_indices = np.random.choice(num_samples, sample_size, replace=False)
np.random.seed(24)
background_indices = np.random.choice(num_samples, background_size, replace=False)

sample_data = {
    'gene_ids': gene_ids_tokenized[sample_indices],
    'values': values_tokenized[sample_indices],
}

background_data = {
    'gene_ids': gene_ids_tokenized[background_indices],
    'values': values_tokenized[background_indices],
}


In [18]:
pred  = evaluate(model, background_data['gene_ids'], background_data['values'])

In [51]:
pred

tensor([[-4.9980,  1.2646, -0.9831,  0.9999],
        [-6.0103,  0.6579,  2.2665, -1.5088],
        [ 8.1151, -0.0442, -2.2191, -2.8731],
        [-5.3844, -0.7016,  1.4141,  0.0278],
        [-4.0334, -1.4406,  4.1624, -1.4413],
        [-3.7026, -1.5635,  0.5189,  1.2883],
        [-2.7651, -2.2523, -0.8181,  4.7030],
        [-5.9006, -0.8736,  2.7056, -0.4837],
        [-1.9611, -2.0895, -1.6023,  5.1898],
        [ 0.4272, -1.3041, -1.0851,  2.0024],
        [-5.6489,  0.0294,  0.5312,  0.2782],
        [-3.7749,  2.9418,  0.6464, -2.1891],
        [ 8.0499, -0.0371, -2.2557, -2.7854],
        [ 8.2493, -0.1994, -2.2436, -2.8290],
        [-4.9783, -1.2129,  0.1304,  2.0930],
        [-3.6091, -2.4759,  0.2529,  3.4584],
        [-5.6139, -0.8703,  3.5868, -1.1431],
        [ 8.1481, -0.0901, -2.2289, -2.8522],
        [ 5.8468,  1.2667, -1.2594, -2.4659],
        [ 8.2085, -0.2041, -2.2305, -2.7986],
        [-2.5261, -1.4744, -1.6605,  4.9951],
        [-5.9979, -0.3291,  3.2254

In [50]:
explainer = shap.Explainer(pred)

RuntimeError: Failed to import transformers.modeling_tf_utils because of the following error (look up to see its traceback):
Unable to convert function return value to a Python type! The signature was
	() -> handle

In [21]:
from scipy.sparse import issparse
x = adata.copy()
all_counts = (
        x.layers[input_layer_key].A
        if issparse(x.layers[input_layer_key])
        else x.layers[input_layer_key]
    )

age_labels = x.obs["age_id"].tolist()  # make sure count from 0
age_labels = np.array(age_labels)

tokenized_test = tokenize_and_pad_batch(
    all_counts,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)

tokenized_test["genes"].shape

torch.Size([1000, 2837])

In [19]:
wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mqxy699[0m ([33mqxy699-The University of Tennessee at Chattanooga[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 0, 'dataset_name': 'fly', 'do_train': False, 'load_model': './model.pt', 'mask_ratio': 0.0, 'epochs': 10, 'n_bins': 51, 'MVC': False, 'ecs_thres': 0.0, 'dab_weight': 0.0, 'lr': 0.0001, 'batch_size': 1, 'layer_size': 512, 'nlayers': 12, 'nhead': 8, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'include_zero_gene': False, 'freeze': False, 'DSBN': False}


In [172]:
import scipy as sp
def f(tokenized_test = tokenized_test):
    torch.cuda.empty_cache()
    with torch.cuda.amp.autocast():
        tv = tokenized_test["genes"][0].cuda()
        values_model = tokenized_test["values"][0].cuda()
        # print(f'{tv=}')
        # attention_mask = (tv != 0).type(torch.int64).cuda()
        src_key_padding_mask = tv.eq(vocab[pad_token]).cuda()
        print(f"{len(src_key_padding_mask)=}")
        outputs = model(tv, values_model, src_key_padding_mask=src_key_padding_mask)
        scores = (np.exp(outputs[:2]).T / np.exp(outputs[:2]).sum(-1)).T
        val = sp.special.logit(scores)
        return val

In [57]:
def evaluate(model: nn.Module, loader: DataLoader, return_raw: bool = False) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []
    criterion_cls = nn.CrossEntropyLoss()
    criterion_dab = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_data in loader:
            print(f"{batch_data=}")
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            # batch_labels = batch_data["batch_labels"].to(device)
            # celltype_labels = batch_data["celltype_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                    CLS=CLS,  # evaluation does not need CLS or CCE
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                    #generative_training = False,
                )
                output_values = output_dict["cls_output"]
                print(f'{output_values.shape=}')
                print(f'{target_values.shape=}')
                loss = criterion_cls(output_values, target_values)

                if DAB:
                    print(f"{output_dict['dab_output']=}")
                    loss_dab = criterion_dab(output_dict["dab_output"], target_values)

            total_loss += loss.item() * len(input_gene_ids)
            accuracy = (output_values.argmax(1) == celltype_labels).sum().item()
            total_error += (1 - accuracy / len(input_gene_ids)) * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids) if DAB else 0.0
            total_num += len(input_gene_ids)
            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)

    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/err": total_error / total_num,
            "valid/dab": total_dab / total_num,
            "valid/sum_mse_dab": (total_loss + dab_weight * total_dab) / total_num,
            "epoch": epoch,
        },
    )

    if return_raw:
        return np.concatenate(predictions, axis=0)

    return total_loss / total_num, total_error / total_num

In [58]:
valid_data_pt = prepare_data(tokenized_test, sort_seq_batch=per_seq_batch_sample)
valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=eval_batch_size,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
)
# val_loss, val_err = evaluate(
#         model,
#         loader=valid_loader,
#     )

In [61]:
valid_data_pt

{'gene_ids': tensor([[60695, 35664,  1302,  ..., 60694, 60694, 60694],
         [60695, 35664,  1714,  ..., 60694, 60694, 60694],
         [60695,  8034,  2852,  ..., 60694, 60694, 60694],
         ...,
         [60695, 31038, 11348,  ..., 60694, 60694, 60694],
         [60695, 31038,  1302,  ..., 60694, 60694, 60694],
         [60695,  1714,  1966,  ..., 60694, 60694, 60694]]),
 'values': tensor([[ 0., 43., 42.,  ..., -2., -2., -2.],
         [ 0., 24., 40.,  ..., -2., -2., -2.],
         [ 0., 10., 13.,  ..., -2., -2., -2.],
         ...,
         [ 0.,  9., 27.,  ..., -2., -2., -2.],
         [ 0., 25., 41.,  ..., -2., -2., -2.],
         [ 0.,  7., 43.,  ..., -2., -2., -2.]]),
 'target_values': tensor([[ 0., 43., 42.,  ..., -2., -2., -2.],
         [ 0., 24., 40.,  ..., -2., -2., -2.],
         [ 0., 10., 13.,  ..., -2., -2., -2.],
         ...,
         [ 0.,  9., 27.,  ..., -2., -2., -2.],
         [ 0., 25., 41.,  ..., -2., -2., -2.],
         [ 0.,  7., 43.,  ..., -2., -2., -2.

In [60]:

for batch_data in valid_loader:
    input_gene_ids = batch_data["gene_ids"].to(device)
    input_values = batch_data["values"].to(device)
    target_values = batch_data["target_values"].to(device)
    batch_labels = batch_data["batch_labels"].to(device)
    celltype_labels = batch_data["celltype_labels"].to(device)

KeyError: 'batch_labels'

In [52]:
batch_data

{'gene_ids': tensor([[60695, 35664,  1302,  ..., 60694, 60694, 60694]]),
 'values': tensor([[ 0., 43., 42.,  ..., -2., -2., -2.]]),
 'target_values': tensor([[ 0., 43., 42.,  ..., -2., -2., -2.]])}

In [22]:
with torch.cuda.amp.autocast(enabled=config.amp):
    tv = tokenized_test["genes"][0].cuda()
    values_model = tokenized_test["values"][0].cuda()
    src_key_padding_mask = tv.eq(vocab[pad_token]).cuda()
    print(f"{len(src_key_padding_mask)=}")
    outputs = model(tv, values_model, src_key_padding_mask=src_key_padding_mask)
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores)

len(src_key_padding_mask)=2837


IndexError: too many indices for tensor of dimension 2

In [187]:
tv.shape

torch.Size([2837])

In [125]:
method = "custom tokenizer"

# build an explainer by passing a transformers tokenizer
if method == "transformers tokenizer":
    explainer = shap.Explainer(f, tokenizer, output_names=labels)

# build an explainer by explicitly creating a masker
elif method == "default masker":
    masker = shap.maskers.Text(r"\W")  # this will create a basic whitespace tokenizer
    explainer = shap.Explainer(f, masker, output_names=labels)

# build a fully custom tokenizer
elif method == "custom tokenizer":
    import re

    def custom_tokenizer(s, return_offsets_mapping=True):
        """Custom tokenizers conform to a subset of the transformers API."""
        pos = 0
        offset_ranges = []
        input_ids = []
        for m in re.finditer(r"\W", s):
            start, end = m.span(0)
            offset_ranges.append((pos, start))
            input_ids.append(s[pos:start])
            pos = end
        if pos != len(s):
            offset_ranges.append((pos, len(s)))
            input_ids.append(s[pos:])
        out = {}
        out["input_ids"] = input_ids
        if return_offsets_mapping:
            out["offset_mapping"] = offset_ranges
        return out

In [110]:
masker = shap.maskers.Text(custom_tokenizer)
explainer = shap.Explainer(f, masker, output_names=age_labels)

In [111]:
shap_values = explainer(adata.var["gene_name"][:5])

RuntimeError: Could not infer dtype of dict

In [107]:
custom_tokenizer(adata.var["gene_name"][:5][0], return_offsets_mapping=True)

{'input_ids': ['DRG1'], 'offset_mapping': [(0, 4)]}

In [160]:
try:
    result = f(adata.var["gene_name"][0])  # Assume you manage how this is passed to f
    print(result)
except RuntimeError as e:
    print("Failed due to memory issue:", e)

len(src_key_padding_mask)=2837
Failed due to memory issue: CUDA out of memory. Tried to allocate 124.00 MiB. GPU 0 has a total capacty of 23.69 GiB of which 118.00 MiB is free. Process 34076 has 9.57 GiB memory in use. Including non-PyTorch memory, this process has 12.57 GiB memory in use. Process 1156862 has 1.41 GiB memory in use. Of the allocated memory 12.25 GiB is allocated by PyTorch, and 18.45 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [158]:
len(src_key_padding_mask)

NameError: name 'src_key_padding_mask' is not defined

In [119]:
adata.var["gene_name"]

(4715,)

In [93]:
gene_ids_in_vocab 

array([ 1,  1,  1, ...,  1,  1, -1])

In [97]:
adata.layers["X_binned"]

array([[ 0, 43,  0, ..., 36,  0,  0],
       [ 0, 24,  0, ..., 45,  0,  0],
       [10,  0,  0, ..., 32,  0,  0],
       ...,
       [ 0,  0,  9, ...,  0,  0,  0],
       [ 0,  0, 25, ..., 18,  0,  0],
       [ 0,  0,  0, ..., 43,  0,  0]])

In [98]:
gene_ids

array([ 8034, 35664, 31038, ..., 32191, 19734, 31375])