In [1]:
import time

import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import torch
import harmonypy as hm
import math

from scSLAT.utils import global_seed
from scSLAT.model import scanpy_workflow
from scSLAT.model.prematch import rotate_via_numpy, perturb_data

In [None]:
# parameters cell
# Tunable inputs for the whole notebook. The empty-string paths are
# placeholders and must be overridden before the notebook can run
# (presumably injected by a notebook runner — TODO confirm).
dataset1_file = ''  # path of the first .h5ad file, loaded with sc.read_h5ad
dataset2_file = ''  # path of the second .h5ad file, loaded with sc.read_h5ad
cells = 5000  # cells to subsample per dataset; skipped when <= 0 or larger than the dataset
seed = 0  # passed to global_seed() and used as random_state for subsampling
leiden_repo = 'X_harmony'  # obsm key given to sc.pp.neighbors(use_rep=...) before Leiden
adata1_out = ''  # output .h5ad path for adata1
adata2_out = ''  # output .h5ad path for adata2
rotation = False  # randomly rotate adata2's spatial coords; only honoured when cells == 0
perturb = False  # replace adata2 with a perturbed copy of adata1
inverse_noise = 5  # forwarded to perturb_data(inverse_noise=...); only used when perturb is True

In [None]:
# Seed the random state up front so later stochastic steps are repeatable.
# NOTE(review): the rotation angle is drawn with np.random.randint further
# down, so this presumably seeds NumPy's global RNG — confirm in scSLAT.utils.
global_seed(seed)

In [None]:
# Load both datasets from disk.
adata1 = sc.read_h5ad(dataset1_file)
adata2 = sc.read_h5ad(dataset2_file)

# Subsample each dataset to `cells` observations when that is a valid target
# size (positive and not larger than the dataset); otherwise keep every cell.
# Either way we end up with an independent copy of the loaded object.
if 0 < cells <= adata1.shape[0]:
    adata1 = sc.pp.subsample(adata1, n_obs=cells, random_state=seed, copy=True)
else:
    adata1 = adata1.copy()

if 0 < cells <= adata2.shape[0]:
    adata2 = sc.pp.subsample(adata2, n_obs=cells, random_state=seed, copy=True)
else:
    adata2 = adata2.copy()

In [None]:
# Keep the dataset with more cells as adata1: swap whenever adata2 is larger.
if adata2.shape[0] > adata1.shape[0]:
    adata1, adata2 = adata2, adata1

# Random rotation and perturbation

In [None]:
# Optional random rotation of adata2's spatial coordinates.
# Note: this only runs when `cells == 0`, so a requested rotation is ignored
# for any other value of `cells`.
if rotation and cells == 0:
    # Integer angle in degrees drawn from [0, 360); recorded in .uns below.
    angle_deg = np.random.randint(0, 360)
    rotated_coords = rotate_via_numpy(adata2.obsm['spatial'], np.deg2rad(angle_deg))
    adata2.obsm['spatial'] = rotated_coords
    adata2.uns['rotation'] = angle_deg

# Optional perturbation: adata2 is replaced by a copy of adata1 and that copy
# is handed to perturb_data. Because adata2 is reassigned here, any rotation
# applied to the previous adata2 above is discarded.
if perturb:
    adata2 = adata1.copy()
    # The return value is not used; presumably perturb_data modifies adata2
    # in place — TODO confirm.
    perturb_data(adata2, inverse_noise=inverse_noise)

In [None]:
# Stash a copy of each dataset's current X under layers['counts'] so it can be
# restored into X right before saving.
for _adata in (adata1, adata2):
    _adata.layers['counts'] = _adata.X.copy()

# PCA and Harmony

In [None]:
# Stack both datasets into one object, adata1's cells first and adata2's
# after; the later cells split the joint results back at adata1.shape[0].
# The 'batch' column in .obs passed to Harmony below is expected to come from
# this concatenation (AnnData.concatenate's default batch_key).
adata_all = adata1.concatenate(adata2)

# Durations are measured with time.perf_counter(): unlike time.time() it is
# monotonic, so the elapsed values cannot be skewed by system clock changes.
# Only differences of these timestamps are used later.
start = time.perf_counter()
# Preprocessing; Harmony below reads the resulting obsm['X_pca'].
adata_all = scanpy_workflow(adata_all)
end_pca = time.perf_counter()

# Harmony batch correction of the PCA embedding, using 'batch' as the covariate.
harm = hm.run_harmony(adata_all.obsm['X_pca'], adata_all.obs, 'batch', max_iter_harmony=20)
# Transposed so it can be stored in obsm and row-sliced per dataset later
# (assumes Z_corr is laid out as components x cells — TODO confirm for the
# installed harmonypy version).
Z = harm.Z_corr.T
end_harmony = time.perf_counter()
adata_all.obsm['X_harmony'] = Z

In [None]:
# Build the neighbour graph on the representation named by `leiden_repo`
# (default 'X_harmony', which the previous cell stores in adata_all.obsm).
sc.pp.neighbors(adata_all, use_rep=leiden_repo)
# Leiden clustering on that graph; the resulting obs['leiden'] labels are
# copied onto adata1 / adata2 in a later cell.
sc.tl.leiden(adata_all, resolution=0.5)
# UMAP is computed for the diagnostic plot below; within this notebook it is
# not copied to adata1 / adata2.
sc.tl.umap(adata_all)
# Visual check: cluster assignment and dataset of origin side by side.
sc.pl.umap(adata_all, color=["leiden", "batch"])

In [None]:
# Split the joint Harmony embedding back per dataset: the first
# adata1.shape[0] rows go to adata1, the remaining rows to adata2.
n_adata1 = adata1.shape[0]
adata1.obsm['X_harmony'], adata2.obsm['X_harmony'] = Z[:n_adata1, :], Z[n_adata1:, :]

In [None]:
# Split the joint PCA embedding the same way: leading rows belong to adata1,
# trailing rows to adata2.
n_adata1 = adata1.shape[0]
pca_all = adata_all.obsm['X_pca']
adata1.obsm['X_pca'] = pca_all[:n_adata1, :]
adata2.obsm['X_pca'] = pca_all[n_adata1:, :]

In [None]:
# Copy the joint Leiden labels back onto each dataset. `.values` drops the
# concatenated index so the labels are assigned by position, not by obs name.
n_adata1 = adata1.shape[0]
leiden_all = adata_all.obs['leiden']
adata1.obs['leiden'] = leiden_all[:n_adata1].values
adata2.obs['leiden'] = leiden_all[n_adata1:].values

# Save adatas

In [None]:
# Record the elapsed times (as strings) on both outputs.
# Both durations are measured from `start`, so 'harmony_time' is cumulative:
# it covers the PCA workflow as well as the Harmony run.
pca_elapsed = str(end_pca - start)
harmony_elapsed = str(end_harmony - start)
for _adata in (adata1, adata2):
    _adata.uns['pca_time'] = pca_elapsed
    _adata.uns['harmony_time'] = harmony_elapsed

In [None]:
# Put the stashed layers['counts'] matrix back into X on both datasets first,
# then write each dataset to its configured output path.
_outputs = ((adata1, adata1_out), (adata2, adata2_out))
for _adata, _ in _outputs:
    _adata.X = _adata.layers['counts'].copy()
for _adata, _out_path in _outputs:
    _adata.write_h5ad(_out_path)