In [9]:
import os
import scanpy as sc
import pandas as pd
import numpy as np
import squidpy as sq
import scarches as sca
import torch
from scipy.sparse import csr_matrix
import warnings

In [10]:
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings("ignore", category=UserWarning, message="Variable names are not unique.*")
print(device)

cpu


In [2]:
target_adata_raw = sc.read_h5ad('./GBM_data/P137_A1.h5ad')
target_adata_raw.var_names_make_unique()
target_adata_raw.obs['slide'] = 'target'

In [3]:
sq.gr.spatial_neighbors(target_adata_raw, coord_type="grid")
conn = target_adata_raw.obsp['spatial_connectivities']
X = target_adata_raw.X
if not isinstance(X, np.ndarray):
    X = X.toarray()
neighbor_sum = conn @ X
n_neighbors = np.array(conn.sum(axis=1)).flatten()
n_neighbors_safe = np.where(n_neighbors == 0, 1, n_neighbors)
neighbor_mean = neighbor_sum / n_neighbors_safe[:, None]
X_smoothed = 0.6 * X + 0.4 * neighbor_mean
target_adata_raw.X = csr_matrix(X_smoothed)

In [4]:
sp_path_GBM = './GBM_data/'
sp_files_GBM = os.listdir(sp_path_GBM)
sp_files_GBM = [x for x in sp_files_GBM if '.h5ad' in x]
sp_files_GBM.sort()
slide_ids_GBM = [x.split('.h5ad')[0] for x in sp_files_GBM]

In [5]:
tmp_df = pd.read_csv('./GBM_data/spatial_domain_table_ER.csv',index_col=0)
adata_list = []
for i, tmp_slide in enumerate(slide_ids_GBM):
    tmp_ad = sc.read_h5ad(f'{sp_path_GBM}{tmp_slide}.h5ad')
    tmp_ad.var_names_make_unique()
    tmp_ad.obs['spatial_domain'] = [f'SD-{x:01}' for x in tmp_df.loc[tmp_ad.obs_names]['spatial_domain']]
    adata_list.append(tmp_ad)
source_adata_raw = sc.concat(adata_list)
source_adata_raw.var_names_make_unique()

32it [00:12,  2.52it/s]


In [6]:
varlist_intersection = [var_name for var_name in target_adata_raw.var_names if var_name in source_adata_raw.var_names]
print(f"overlap gene num: {len(varlist_intersection)}")
source_adata = source_adata_raw[:,varlist_intersection].copy()
target_adata = target_adata_raw[:,varlist_intersection].copy()

overlap gene num: 16306


In [7]:
sca.models.SCVI.setup_anndata(source_adata, batch_key='slide', labels_key='spatial_domain')
vae = sca.models.SCVI(
    source_adata,
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)

In [8]:
vae.train()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 1/101:   0%|          | 0/101 [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
scanvae = sca.models.SCANVI.from_scvi_model(vae, unlabeled_category = "Unknown")
scanvae.train(max_epochs=20)
source_adata.obs['predictions'] = scanvae.predict()
# print("Acc: {}".format(np.mean(source_adata.obs.predictions == source_adata.obs['spatial_domain'])))

In [None]:
ref_path = f'./ref_model/'
scanvae.save(ref_path, overwrite=True)

In [None]:
savekey = 'spatial_domain' in target_adata.obs.columns
if savekey:
    tmp_obs = target_adata.obs['spatial_domain']
target_adata.obs['spatial_domain'] = scanvae.unlabeled_category_
sc.pp.filter_cells(target_adata, min_genes=20)

In [None]:
model = sca.models.SCANVI.load_query_data(
    target_adata,
    ref_path,
    freeze_dropout = True,
)
model._unlabeled_indices = np.arange(target_adata.n_obs)
model._labeled_indices = []

In [None]:
model.train(
    max_epochs=100,
    plan_kwargs=dict(weight_decay=0.0),
    check_val_every_n_epoch=10,
)

In [None]:
# surgery_path = f'./surgery_model'
# model.save(surgery_path, overwrite=True)

In [None]:
if savekey:
    target_adata.obs['spatial_domain'] = tmp_obs
target_adata.obs['predict_SD'] = model.predict()

In [None]:
target_adata.write_h5ad(f'./output/adata_predict_SD.h5ad')