## Import necessary packages

In [1]:
import warnings
warnings.filterwarnings("ignore")

import hdf5plugin
import numpy as np
import anndata as ad
from scipy.sparse import csr_matrix
from CellPLM.utils import set_seed
from CellPLM.utils.data import stratified_sample_genes_by_sparsity
from CellPLM.pipeline.ssl import SSLPipeline, SSLDefaultPipelineConfig, SSLDefaultModelConfig

## Specify important parameters before getting started

In [2]:
# Which example tissue to run: "Liver" or "Lung"
DATASET = "Liver"
# Prefix identifying the pretrained checkpoint to load
PRETRAIN_VERSION = "20231027_85M"
# Torch device string handed to the pipeline calls
DEVICE = "cuda:0"

## Load Downstream Dataset

The example datasets here are taken from `HumanLungCancerPatient2` from [Lung cancer 2](https://info.vizgen.com/ffpe-showcase?submissionGuid=88ba0a44-26e2-47a2-8ee4-9118b9811fbf), `GSE131907_Lung` from [GSE131907](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE131907), `HumanLiverCancerPatient2` from [Liver cancer 2](https://info.vizgen.com/ffpe-showcase?submissionGuid=88ba0a44-26e2-47a2-8ee4-9118b9811fbf) and `GSE151530_Liver` from [GSE151530](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE151530).

The data we released are already preprocessed: we subset the SRT dataset by selecting the first 100 FOVs. For scRNA-seq datasets, we only preserve genes that overlap with the SRT dataset. This is to ensure that for all the genes involved in this example, we know the ground-truth gene expressions from the SRT dataset. Later, we will hold out part of the genes from the SRT dataset, and leverage information from the scRNA-seq dataset to impute them. Therefore, this gene filtering is only for the convenience of evaluation. In practice, we can leverage the scRNA-seq dataset to impute unmeasured genes in the SRT dataset.

After the preprocessing, the AnnData object must contain the following information:

* `.obs['platform']` A string label for identification of SRT data. When platform is set to 'cosmx' or 'merfish', spatial positional information will be loaded.
* `.obs['x_FOV_px']` For SRT data, please store the float/int type X coordinate of each cell here.
* `.obs['y_FOV_px']` For SRT data, please store the float/int type Y coordinate of each cell here.
* `.obs['batch']` For SRT data, batch refers to an FOV. For scRNA-seq data, batch refers to a sample. Please store a string type batch identifier here.

In [3]:
set_seed(11)

# (query SRT file, reference scRNA-seq file) for every supported value of DATASET.
dataset_files = {
    'Lung': ('HumanLungCancerPatient2_filtered_ensg.h5ad', 'GSE131907_Lung_ensg.h5ad'),
    'Liver': ('HumanLiverCancerPatient2_filtered_ensg.h5ad', 'GSE151530_Liver_ensg.h5ad'),
}
if DATASET not in dataset_files:
    # Fail here with a clear message rather than with a NameError on `query_data` below.
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(dataset_files)}")

query_dataset, ref_dataset = dataset_files[DATASET]
query_data = ad.read_h5ad(f'../data/{query_dataset}')
ref_data = ad.read_h5ad(f'../data/{ref_dataset}')

target_genes = stratified_sample_genes_by_sparsity(query_data, seed=11) # This is for reproducing the hold-out gene lists in our paper
# Keep the measured values of the held-out genes for later evaluation ...
query_data.obsm['truth'] = query_data[:, target_genes].X.toarray()
# ... then blank them in the query data.
# NOTE(review): relies on assignment through an AnnData view reaching `query_data` - confirm for your anndata version.
query_data[:, target_genes].X = 0
train_data = query_data.concatenate(ref_data, join='outer', batch_key=None, index_unique=None)

# Everything trains except the last-listed batch of each dataset, which is used for validation.
train_data.obs['split'] = 'train'
valid_batches = [query_data.obs['batch'].iloc[-1], ref_data.obs['batch'].iloc[-1]]
# Single .loc assignment: chained `obs['split'][mask] = ...` does not update the frame under pandas Copy-on-Write.
train_data.obs.loc[train_data.obs['batch'].isin(valid_batches), 'split'] = 'valid'

## Specify gene to impute
In the last step, we merged the query dataset (SRT) and the reference dataset (scRNA-seq). However, the query dataset does not measure all the genes. For fine-tuning the model, we need to specify which genes are measured in each dataset. Therefore, we create a dictionary for the imputation pipeline.

In [4]:
# Set lookup keeps the filter below linear instead of rescanning `target_genes` for every gene.
held_out_genes = set(target_genes)
# Genes still present in the query data once the held-out targets are removed.
query_genes = [g for g in query_data.var.index if g not in held_out_genes]
ref_genes = ref_data.var.index.tolist()

query_batches = list(query_data.obs['batch'].unique())
ref_batches = list(ref_data.obs['batch'].unique())

# Map every batch identifier to the list of genes measured in that batch.
# Reference batches are written last, so they win if an identifier appears in both datasets.
batch_gene_list = {batch: query_genes for batch in query_batches}
batch_gene_list.update({batch: ref_genes for batch in ref_batches})
print('query_genes:', query_genes)

query_genes: ['ENSG00000041982', 'ENSG00000158296', 'ENSG00000081041', 'ENSG00000113318', 'ENSG00000165699', 'ENSG00000171223', 'ENSG00000110777', 'ENSG00000006210', 'ENSG00000197461', 'ENSG00000121594', 'ENSG00000174807', 'ENSG00000109072', 'ENSG00000174125', 'ENSG00000178999', 'ENSG00000181449', 'ENSG00000150045', 'ENSG00000184451', 'ENSG00000180644', 'ENSG00000105976', 'ENSG00000164867', 'ENSG00000114013', 'ENSG00000078098', 'ENSG00000187498', 'ENSG00000017427', 'ENSG00000103855', 'ENSG00000169248', 'ENSG00000204103', 'ENSG00000163508', 'ENSG00000171552', 'ENSG00000092969', 'ENSG00000163513', 'ENSG00000077150', 'ENSG00000122512', 'ENSG00000125657', 'ENSG00000169896', 'ENSG00000170458', 'ENSG00000121858', 'ENSG00000115414', 'ENSG00000177606', 'ENSG00000101017', 'ENSG00000203747', 'ENSG00000188676', 'ENSG00000138722', 'ENSG00000137673', 'ENSG00000186891', 'ENSG00000138755', 'ENSG00000111206', 'ENSG00000172216', 'ENSG00000111276', 'ENSG00000123384', 'ENSG00000211445', 'ENSG00000160683'

## Overwrite parts of the default config
These hyperparameters are recommended for general purposes. We did not tune them for individual datasets. You may update them if needed.

In [5]:
# Work on copies so that edits made here do not touch the imported default dictionaries.
model_config = SSLDefaultModelConfig.copy()
pipeline_config = SSLDefaultPipelineConfig.copy()

# Display both configurations for inspection.
pipeline_config, model_config

({'lr': 0.0002,
  'wd': 1e-08,
  'scheduler': 'plat',
  'epochs': 1,
  'max_eval_batch_size': 100000,
  'patience': 5,
  'workers': 0},
 {'objective': 'SSL',
  'mask_node_rate': 0.75,
  'mask_feature_rate': 0.25,
  'head_type': 'ssl',
  'max_batch_size': 70000})

## Fine-tuning

Efficient data setup and fine-tuning can be seamlessly conducted using the CellPLM built-in `pipeline` module.

First, initialize an `SSLPipeline`. This pipeline will automatically load a pretrained model.

In [6]:
pipeline = SSLPipeline(
    pretrain_prefix=PRETRAIN_VERSION,  # Specify the pretrain checkpoint to load
    overwrite_config=model_config,     # This is for overwriting part of the pretrain config
    pretrain_directory='../ckpt',
)
model = pipeline.model

# Print the starting values of every parameter that will receive gradients.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(f"{name}:")
    print("Initial weights:", param.data)


in thre vae:::::::::::::::::::::::::::::::::::::::::::::::::::::::::))))))))))))))))))))))))))))))))))))))
self.head_type: ssl
55555555555555555555555555555555555555555555self.pre_latent_norm = PreLatentNorm(ln, enc_hid)
embedder.norm0.weight:
Initial weights: tensor([1., 1., 1.,  ..., 1., 1., 1.])
embedder.norm0.bias:
Initial weights: tensor([0., 0., 0.,  ..., 0., 0., 0.])
embedder.extra_linear.0.weight:
Initial weights: tensor([[-0.0219, -0.0008,  0.0304,  ...,  0.0036,  0.0086, -0.0162],
        [ 0.0261,  0.0049, -0.0098,  ...,  0.0007, -0.0030,  0.0145],
        [ 0.0102, -0.0045, -0.0115,  ...,  0.0306, -0.0108, -0.0021],
        ...,
        [-0.0011, -0.0138,  0.0118,  ...,  0.0008,  0.0182, -0.0075],
        [ 0.0122,  0.0076,  0.0231,  ...,  0.0296, -0.0130, -0.0008],
        [ 0.0014, -0.0158,  0.0034,  ..., -0.0281, -0.0267, -0.0173]])
embedder.extra_linear.0.bias:
Initial weights: tensor([-0.0262,  0.0247, -0.0041,  ...,  0.0117,  0.0010,  0.0041])
embedder.extra_linear.3.

In [7]:
# Add up the element count of every parameter tensor in the model.
total_params = sum(map(lambda p: p.numel(), model.parameters()))

print(f'Total number of parameters in the model: {total_params}')

Total number of parameters in the model: 86168494


In [8]:
# Total number of parameter elements in `model`.
total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params}')

# A bare expression is only rendered when it is the last statement of the cell,
# so `model` goes at the end to actually show the module structure.
model


Total parameters: 86168494


Next, employ the `fit` function to fine-tune the model on your downstream dataset. This dataset should be in the form of an AnnData object, where `.X` is a csr_matrix. See previous section for more details.

Typically, a dataset containing approximately 20,000 cells can be trained in under 10 minutes using a V100 GPU card, with an expected GPU memory consumption of around 8GB.

In [9]:
# Print the summary of the merged AnnData object as a sanity check before fitting.
print('train_data:', train_data)

train_data: AnnData object with n_obs × n_vars = 77105 × 407
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'max_x', 'min_y', 'max_y', 'n_genes', 'batch', 'x_FOV_px', 'y_FOV_px', 'platform', 'Sample', 'Type', 'split'
    var: 'n_cells-0', 'n_cells-1', 'ENSG-1'
    obsm: 'truth'


In [10]:
pipeline.fit(
    train_data,                       # An AnnData object
    pipeline_config,                  # The config dictionary we created previously, optional
    split_field='split',              # Column in .obs that contains split information
    train_split='train',
    valid_split='valid',
    batch_gene_list=batch_gene_list,  # Genes that are measured in each batch, see previous section for more details
    device=DEVICE,
)

After filtering, 407 genes remain.
!!!!!!len(batch_gene_list[4]!!!!!!!!!!!!!!!!!!!!!!!): 307
!!!!!!len(self.gene_list!!!!!!!!!!!!!!!!!!!!!!!): 407


  0%|                                                                                             | 0/1 [00:01<?, ?it/s]


IndexError: The shape of the mask [407] at index 0 does not match the shape of the indexed tensor [226, 19374] at index 1

## Inference and evaluation
Once the pipeline has been fitted to the downstream datasets, performing inference or evaluation on new datasets can be easily accomplished using the built-in `predict` and `score` functions.

In [None]:
pipeline.predict(
    query_data,       # An AnnData object
    pipeline_config,  # The config dictionary we created previously, optional
    device=DEVICE,
)

In [None]:
pipeline.score(
    query_data,                                        # An AnnData object
    evaluation_config={'target_genes': target_genes},  # Built inline; presumably restricts scoring to the held-out genes - confirm against the pipeline
    label_fields=['truth'],                            # A field in .obsm that stores the ground-truth for evaluation
    device=DEVICE,
)

In [None]:
# Hand-pasted metric values kept for comparison; nothing in this cell recomputes them.
# NOTE(review): presumably copied from earlier `pipeline.score` runs - confirm provenance.

#untrained
untrained_scores = {'mse': 0.1641050273599103,
 'rmse': 0.4326375951244845,
 'mae': 0.2050452692620456,
 'corr': 0.3440739825367928,
 'cos': 0.4543862997740507}

#pretrained
pretrained_scores = {'mse': 0.14237468457315117,
 'rmse': 0.4044635064227159,
 'mae': 0.19642182688228785,
 'corr': 0.3790783796971664,
 'cos': 0.4802470280230045}

# Two bare dict literals in one cell would only display the last one;
# binding both and ending on a combined mapping shows them side by side.
{'untrained': untrained_scores, 'pretrained': pretrained_scores}

In [None]:
import torch

# Toy ground truth and predictions spanning several orders of magnitude.
truth = torch.tensor([0.0002, 1.1, 9.0, 95.0])
out_dict = {'pred': torch.tensor([0.0001, 1.0, 10.0, 100.0])}

# log1p(x) = log(1 + x), applied element-wise to both tensors.
log_pred = out_dict['pred'].log1p()
log_truth = truth.log1p()

print("Log-transformed predictions:", log_pred)
print("Log-transformed true values:", log_truth)

