In [1]:
import numpy as np
import pandas as pd
import os
from scipy.stats import wasserstein_distance
import pandas as pd
import scanpy as sc
import warnings
from scipy.stats import spearmanr, pearsonr
from scipy.spatial import distance_matrix
from sklearn.metrics import matthews_corrcoef
import torch
from scipy.spatial.distance import cdist
import sys
from os.path import join
from IPython.display import display

from model.stDiff_model import DiT_stDiff
from model.stDiff_scheduler import NoiseScheduler
from model.stDiff_train import normal_train_stDiff
from model.sample import sample_stDiff
from process.result_analysis import clustering_metrics

# Silence all Python warnings for the rest of the kernel session.
warnings.filterwarnings('ignore')
# Make new float tensors default to CUDA tensors.
# NOTE(review): torch.set_default_tensor_type is deprecated (PyTorch >= 2.1) in favour of
# torch.set_default_dtype / torch.set_default_device. Per the PyTorch docs, tensors created
# this way land on the *current* CUDA device (cuda:0 unless changed), while the model is later
# moved to cuda:1 — confirm this mismatch is intended (the saved training-cell output is a
# device-mismatch RuntimeError).
torch.set_default_tensor_type('torch.cuda.FloatTensor')

from process.data import *

  from .autonotebook import tqdm as notebook_tqdm
2024-06-03 23:47:57,898	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-06-03 23:47:58,018	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


## preprocess
For scRNA-seq data, data augmentation is performed first, followed by the standard `normalize_total` and `log1p` preprocessing, and finally scaling. \
ST data requires the same standard preprocessing (`normalize_total` and `log1p`) followed by scaling.

In [2]:
# ******** preprocess ********
# Load the spatial (ST) and single-cell (scRNA-seq) AnnData objects.
adata_spatial = sc.read_h5ad(join('datasets/sp', 'dataset2_spatial_33.h5ad'))
adata_seq = sc.read_h5ad(join('datasets/sc', 'dataset2_seq_33.h5ad'))

# Both matrices are later indexed column-by-column with one shared gene mask, and the
# scRNA-seq frame below is labelled with the ST gene names, so the two objects must list
# the same genes in the same order.
assert adata_seq.var_names.equals(adata_spatial.var_names), (
    'scRNA-seq and ST data must share identical genes in identical order'
)

# scRNA-seq: augment, then normalize_total + log1p, then scale (stDiff needs scaled input).
adata_seq2 = data_augment(adata_seq.copy(), True, noise_std=10)
sc.pp.normalize_total(adata_seq2, target_sum=1e4)
sc.pp.log1p(adata_seq2)
adata_seq2 = scale(adata_seq2)
data_seq_array = adata_seq2.X

# ST: same normalize_total + log1p + scale, on a copy so adata_spatial keeps the raw counts.
adata_spatial2 = adata_spatial.copy()
sc.pp.normalize_total(adata_spatial2, target_sum=1e4)
sc.pp.log1p(adata_spatial2)
adata_spatial2 = scale(adata_spatial2)
data_spatial_array = adata_spatial2.X

sp_genes = np.array(adata_spatial.var_names)
sp_data = pd.DataFrame(data=data_spatial_array, columns=sp_genes)
sc_data = pd.DataFrame(data=data_seq_array, columns=sp_genes)

In [3]:
# Downstream cells multiply these matrices element-wise by a per-gene mask and wrap them in
# DataFrames, which requires dense ndarrays with the same number of gene columns.
assert isinstance(data_seq_array, np.ndarray), type(data_seq_array)
assert isinstance(data_spatial_array, np.ndarray), type(data_spatial_array)
assert data_seq_array.shape[1] == data_spatial_array.shape[1]
type(data_seq_array)

numpy.ndarray

## method
In this example, we mask out a random 20% of the genes in the ST data and then impute them. \
To impute genes that are genuinely missing, set the mask to 0 at the positions of those genes. \
Before imputing missing genes in the ST data, the model must first be trained on the corresponding complete scRNA-seq data.

In [4]:

# Hyperparameters used for this demo.
lr = 0.00016046744893538737
depth = 6
num_epoch = 900
diffusion_step = 1500
batch_size = 2048
hidden_size = 512
head = 16

# mask: 1 = gene observed, 0 = gene held out for imputation
cell_num = data_spatial_array.shape[0]
gene_num = data_spatial_array.shape[1]
mask = np.ones((gene_num,), dtype='float32')

# gene_id_test: randomly hold out (1 - train_size) of the genes as the test set
train_size = 0.8
gene_names_rnaseq = sp_genes
np.random.seed(0)
n_genes = len(gene_names_rnaseq)
gene_ids_train = sorted(
    np.random.choice(range(n_genes), int(n_genes * train_size), False)
)
gene_ids_test = sorted(set(range(n_genes)) - set(gene_ids_train))  # test

mask[gene_ids_test] = 0

seq = data_seq_array
st = data_spatial_array
data_seq_masked = seq * mask
data_spatial_masked = st * mask

# Map values with x * 2 - 1 (presumably `scale` produced [0, 1], giving [-1, 1] — TODO confirm).
# Masked (test) genes are 0 before this step and therefore -1 after it.
seq = seq * 2 - 1
data_seq_masked = data_seq_masked * 2 - 1

st = st * 2 - 1
data_spatial_masked = data_spatial_masked * 2 - 1

dataloader = get_data_loader(
    seq,  # all genes
    data_seq_masked,  # test genes masked (value -1 after rescaling)
    batch_size=batch_size,
    is_shuffle=True)

seed = 1202
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

model = DiT_stDiff(
    input_size=gene_num,
    hidden_size=hidden_size,
    depth=depth,
    num_heads=head,
    classes=6,
    mlp_ratio=4.0,
    dit_type='dit'
)

# NOTE(review): hard-coded GPU index; the default CUDA tensor type set in the import cell
# places new tensors on the *current* CUDA device, which is not necessarily cuda:1 — confirm.
device = torch.device('cuda:1')
model.to(device)

save_path_prefix = 'ckpt/demo.pt'
# train, or reuse a cached checkpoint so Restart & Run All stays cheap
model.train()
if not os.path.isfile(save_path_prefix):
    # torch.save does not create parent directories; create them before the long run
    # so a missing folder cannot discard a finished training job.
    os.makedirs(os.path.dirname(save_path_prefix) or '.', exist_ok=True)

    normal_train_stDiff(model,
                        dataloader=dataloader,
                        lr=lr,
                        num_epoch=num_epoch,
                        diffusion_step=diffusion_step,
                        device=device,
                        pred_type='noise',
                        mask=mask)

    torch.save(model.state_dict(), save_path_prefix)
else:
    # map_location keeps the checkpoint loadable even if it was saved from another device.
    model.load_state_dict(torch.load(save_path_prefix, map_location=device))



  0%|                                                                       | 0/900 [00:00<?, ?it/s]


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

In [6]:
# sample
# Condition the sampler on the masked ST data; the held-out genes are generated.
gt = data_spatial_masked
noise_scheduler = NoiseScheduler(
    num_timesteps=diffusion_step,
    beta_schedule='cosine'
)

dataloader = get_data_loader(
    data_spatial_masked,  # test genes masked (-1 after rescaling)
    data_spatial_masked,  # test genes masked (-1 after rescaling)
    batch_size=batch_size,
    is_shuffle=False)  # keep cell order so outputs line up with the rows used below


model.eval()
imputation = sample_stDiff(model,
                           device=device,
                           dataloader=dataloader,
                           noise_scheduler=noise_scheduler,
                           mask=mask,
                           gt=gt,
                           num_step=diffusion_step,
                           sample_shape=(cell_num, gene_num),
                           is_condi=True,
                           sample_intermediate=diffusion_step,
                           model_pred_type='noise',
                           is_classifier_guidance=False,
                           omega=0.2)

# Fill only the held-out genes with the model output. Work on a copy so that
# data_spatial_masked (also `gt` and the loader input above) is not overwritten;
# otherwise re-running this cell would condition on already-imputed values.
impu = np.array(data_spatial_masked, copy=True)
impu[:, gene_ids_test] = imputation[:, gene_ids_test]
# Undo the x * 2 - 1 rescaling applied before training.
impu = (impu + 1) / 2

time: 0: 100%|██████████| 1500/1500 [01:53<00:00, 13.25it/s]  


In [7]:
# ********** metrics **********
def imputation_metrics(original, imputed):
    """Compare imputed expression values against the ground truth.

    Parameters
    ----------
    original : np.ndarray
        Ground-truth values, indexed as [cell, gene].
    imputed : np.ndarray
        Imputed values with the same shape as ``original``.

    Returns
    -------
    dict
        ``median_absolute_error_per_gene`` / ``mean_absolute_error_per_gene``:
        per-gene (axis 0) statistics of ``|original - imputed|``.
        ``mean_relative_error`` / ``median_relative_error``: per-cell (axis 1)
        statistics of the absolute error divided by ``max(|original|, 1)``.
        ``spearman_per_gene``: Spearman correlation for each gene column.
        ``median_spearman_per_gene``: median of the per-gene correlations.

    Genes whose Spearman correlation is undefined (a constant column in either
    input, including an all-zero imputation) are scored 0 rather than NaN, so a
    single constant gene cannot turn the median into NaN.
    """
    absolute_error = np.abs(original - imputed)
    # Floor the denominator at 1 so near-zero ground truth does not blow up the ratio.
    relative_error = absolute_error / np.maximum(
        np.abs(original), np.ones_like(original)
    )

    def _is_constant(column):
        # Spearman is undefined (scipy returns NaN) for constant or empty input.
        return column.size == 0 or bool(np.all(column == column[0]))

    spearman_gene = []
    for g in range(imputed.shape[1]):
        orig_col = original[:, g]
        imp_col = imputed[:, g]
        if _is_constant(imp_col) or _is_constant(orig_col):
            correlation = 0
        else:
            correlation = spearmanr(orig_col, imp_col)[0]
        spearman_gene.append(correlation)

    return {
        "median_absolute_error_per_gene": np.median(absolute_error, axis=0),
        "mean_absolute_error_per_gene": np.mean(absolute_error, axis=0),
        "mean_relative_error": np.mean(relative_error, axis=1),
        "median_relative_error": np.median(relative_error, axis=1),
        "spearman_per_gene": np.array(spearman_gene),

        # Metric we report in the GimVI paper:
        "median_spearman_per_gene": np.median(spearman_gene),
    }

# Score only the held-out (test) genes, with both inputs on the scaled (pre x*2-1) range.
truth_test = np.array(data_spatial_array[:, gene_ids_test])
impu_test = impu[:, gene_ids_test]
tmp = imputation_metrics(truth_test, impu_test)
tmp['median_spearman_per_gene']


0.3104034168196477