In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import sys
from pathlib import Path
import anndata as ad
import scanpy as sc
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import seaborn as sns
from matplotlib import pyplot as plt

sys.path.insert(0, "../../scripts/methods/")
from my_slat import slat_align, slat_align_ref
sys.path.insert(0, "/home/ylu/project")
from utils import *
import time
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
# Load data
data_folder = "../../data/BARseq/BARseq_Perturbed/"
results_folder = "./results/SLAT/"
Path(results_folder).mkdir(parents=True, exist_ok=True)

In [None]:
subsample_num = 20000
  
for i in tqdm(range(1, 40)):
    slice1 = ad.read_h5ad(filename=os.path.join(data_folder,"slice_{:0>2d}.h5ad".format(i)))
    slice2 = ad.read_h5ad(filename=os.path.join(data_folder,"slice_{:0>2d}.h5ad".format(i+1)))
    
    # subsample
    subsample1 = np.random.choice(slice1.shape[0], subsample_num, replace=False) if slice1.shape[0] > subsample_num else np.arange(slice1.shape[0])
    subsample2 = np.random.choice(slice2.shape[0], subsample_num, replace=False) if slice2.shape[0] > subsample_num else np.arange(slice2.shape[0])
    slice1 = slice1[subsample1,:]
    slice2 = slice2[subsample2,:]

    spatial_key = 'perturbed_spatial'
    key_added = 'align_spatial'
    align_slices, pis = slat_align(
        models = [slice1, slice2],
        spatial_key=spatial_key,
        key_added=key_added,
    )  
    # recover true R and t
    R1, t1=solve_RT_by_correspondence(align_slices[1].obsm[key_added], align_slices[1].obsm['perturbed_spatial'])
    R2, t2=solve_RT_by_correspondence(align_slices[0].obsm['perturbed_spatial'], align_slices[0].obsm[key_added])
    t = t1 @ R2.T + t2
    R = R2 @ R1
    results = {"R":R, "t":t, "subsample": [subsample1, subsample2]}
    np.save(os.path.join(results_folder, "slice_{}_{}.npy".format(i,i+1)), results, allow_pickle=True)

  0%|                                                           | 0/39 [00:00<?, ?it/s]

Calculating spatial neighbor graph ...
The graph contains 231210 edges, 20000 cells.
11.5605 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 230742 edges, 20000 cells.
11.5371 neighbors per cell on average.
Use DPCA feature to format graph


  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.97


  3%|█▎                                                 | 1/39 [00:22<14:33, 22.98s/it]

Calculating spatial neighbor graph ...
The graph contains 230736 edges, 20000 cells.
11.5368 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 231460 edges, 20000 cells.
11.573 neighbors per cell on average.
Use DPCA feature to format graph


  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.59


  5%|██▌                                                | 2/39 [00:42<13:05, 21.23s/it]

Calculating spatial neighbor graph ...
The graph contains 231118 edges, 20000 cells.
11.5559 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 230672 edges, 20000 cells.
11.5336 neighbors per cell on average.
Use DPCA feature to format graph


  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.76


  8%|███▉                                               | 3/39 [01:04<12:46, 21.29s/it]

Calculating spatial neighbor graph ...
The graph contains 230572 edges, 20000 cells.
11.5286 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 232118 edges, 20000 cells.
11.6059 neighbors per cell on average.
Use DPCA feature to format graph


  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 2.24


 10%|█████▏                                             | 4/39 [01:29<13:13, 22.68s/it]

Calculating spatial neighbor graph ...
The graph contains 232428 edges, 20000 cells.
11.6214 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 232808 edges, 20000 cells.
11.6404 neighbors per cell on average.
Use DPCA feature to format graph


  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.66


 13%|██████▌                                            | 5/39 [01:54<13:28, 23.78s/it]

Calculating spatial neighbor graph ...
The graph contains 232622 edges, 20000 cells.
11.6311 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 231778 edges, 20000 cells.
11.5889 neighbors per cell on average.
Use DPCA feature to format graph


  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.70
