In [1]:
import os
# Restrict this process to GPU 0; set before the libraries below are imported,
# since any of them may initialise CUDA on import.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
from pathlib import Path
import anndata as ad
import scanpy as sc
import numpy as np
import warnings
# Silence FutureWarnings emitted by the scientific stack to keep cell output readable.
warnings.simplefilter(action='ignore', category=FutureWarning)
import seaborn as sns
from matplotlib import pyplot as plt

# Make the project's method wrappers importable (path is relative to this notebook).
sys.path.insert(0, "../../scripts/methods/")
from my_slat import slat_align, slat_align_ref
# NOTE(review): absolute, user-specific path — adjust for your machine.
sys.path.insert(0, "/home/ylu/project")
# NOTE(review): star import; presumably this is where `solve_RT_by_correspondence`
# (used in the alignment cell) comes from. TODO: replace with explicit imports.
from utils import *
import time

# Reload edited helper modules (my_slat, utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Input/output locations.
# NOTE(review): data_folder is an absolute, user-specific path — adjust for your machine.
data_folder = "/home/ylu/project/MOSTA/data/"
results_folder = "./results/SLAT/"
figures_folder = "./results/figures/SLAT"
# Create every output directory up front so later save calls cannot fail on a
# missing folder when the notebook is run on a fresh checkout.
Path(results_folder).mkdir(parents=True, exist_ok=True)
Path(figures_folder).mkdir(parents=True, exist_ok=True)

In [3]:
## load the data: sections E16.5_E2S1 ... E16.5_E2S13, kept in section order
from tqdm import tqdm
slice_files = [
    os.path.join(data_folder, f"E16.5_E2S{i}.MOSTA.h5ad")
    for i in range(1, 14)
]
slices = [ad.read_h5ad(path) for path in tqdm(slice_files)]

100%|██████████████████████| 13/13 [00:48<00:00,  3.75s/it]


In [4]:
# .obsm key holding the original (unrotated) coordinates; the rotation cell reads its first two columns.
spatial_key = 'spatial'

In [5]:
## Rotate the data: centre each slice's 2-D coordinates on their centroid, apply
## that slice's stored rotation matrix, then shift back so the centroid is unchanged.
rotate_key = "spatial_rot"
rotations = np.load("./results/random_rotations.npy", allow_pickle=True)
for idx, adata in enumerate(slices):
    xy = adata.obsm[spatial_key][:, :2]
    centroid = xy.mean(axis=0)
    # Coordinates are rows, so the rotation is applied as xy @ R^T.
    adata.obsm[rotate_key] = (xy - centroid) @ rotations[idx].T + centroid
    

In [6]:
# From here on, align the rotated coordinates; SLAT writes its aligned
# coordinates into .obsm[key_added].
spatial_key, key_added = "spatial_rot", "aligned_spatial"

In [7]:
## Perform the SLAT alignment for every pair of consecutive slices (i, i+1).
## Each slice is subsampled to at most `sampling_num` cells. The subsample is drawn
## from a seeded generator so that re-running the notebook reproduces the saved results.
sampling_num = 20000
seed = 0  # change to draw a different (but still reproducible) subsample
rng = np.random.default_rng(seed)


def _subsample_indices(n_obs, n_keep, rng):
    """Choose which cells of a slice to keep.

    Parameters
    ----------
    n_obs : int
        Number of cells in the slice.
    n_keep : int
        Maximum number of cells to keep.
    rng : numpy.random.Generator
        Source of randomness for the draw.

    Returns
    -------
    numpy.ndarray
        ``n_keep`` distinct indices drawn without replacement (in random order)
        when ``n_obs > n_keep``; otherwise every index ``0 .. n_obs - 1``.
    """
    if n_obs > n_keep:
        return rng.choice(n_obs, n_keep, replace=False)
    return np.arange(n_obs)


for i in tqdm(range(len(slices) - 1)):
    sampline_idx1 = _subsample_indices(slices[i].shape[0], sampling_num, rng)
    sampline_idx2 = _subsample_indices(slices[i + 1].shape[0], sampling_num, rng)
    # Subset first, then copy. This avoids duplicating each full slice and hands
    # SLAT standalone AnnData objects instead of views, which it would otherwise
    # have to materialise itself.
    slice1 = slices[i][sampline_idx1, :].copy()
    slice2 = slices[i + 1][sampline_idx2, :].copy()

    # perf_counter is monotonic, so it is the appropriate clock for elapsed time.
    time_start = time.perf_counter()
    align_slices, pis = slat_align(
        models=[slice1, slice2],
        spatial_key=spatial_key,
        key_added=key_added,
    )
    time_end = time.perf_counter()
    matches = pis

    # Presumably recovers the rotation R / translation t relating the second
    # slice's aligned and input coordinates (see utils.solve_RT_by_correspondence).
    R, t = solve_RT_by_correspondence(align_slices[1].obsm[key_added], align_slices[1].obsm[spatial_key])
    alignment_results = {
        'sampling_idx1': sampline_idx1,
        'sampling_idx2': sampline_idx2,
        # Misspelled legacy key, kept so existing loaders of these files keep working.
        'sampline_idx2': sampline_idx2,
        'R': R,
        't': t,
        'matches': matches,
        'time': time_end - time_start,
    }
    np.save(os.path.join(results_folder, f"slice_{i}_{i+1}_sampling_{sampling_num}.npy"), alignment_results, allow_pickle=True)

  0%|                               | 0/12 [00:00<?, ?it/s]

Calculating spatial neighbor graph ...
The graph contains 222556 edges, 20000 cells.
11.1278 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 223188 edges, 20000 cells.
11.1594 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 2.64


  8%|█▉                     | 1/12 [00:47<08:40, 47.30s/it]

Calculating spatial neighbor graph ...
The graph contains 223270 edges, 20000 cells.
11.1635 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224114 edges, 20000 cells.
11.2057 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.67


 17%|███▊                   | 2/12 [01:30<07:28, 44.85s/it]

Calculating spatial neighbor graph ...
The graph contains 224692 edges, 20000 cells.
11.2346 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224478 edges, 20000 cells.
11.2239 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.66


 25%|█████▊                 | 3/12 [02:12<06:34, 43.81s/it]

Calculating spatial neighbor graph ...
The graph contains 224224 edges, 20000 cells.
11.2112 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224724 edges, 20000 cells.
11.2362 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.70


 33%|███████▋               | 4/12 [02:58<05:56, 44.59s/it]

Calculating spatial neighbor graph ...
The graph contains 224260 edges, 20000 cells.
11.213 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 225098 edges, 20000 cells.
11.2549 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.69


 42%|█████████▌             | 5/12 [03:51<05:33, 47.64s/it]

Calculating spatial neighbor graph ...
The graph contains 225026 edges, 20000 cells.
11.2513 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 225132 edges, 20000 cells.
11.2566 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.81


 50%|███████████▌           | 6/12 [04:57<05:21, 53.66s/it]

Calculating spatial neighbor graph ...
The graph contains 225100 edges, 20000 cells.
11.255 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224932 edges, 20000 cells.
11.2466 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.70


 58%|█████████████▍         | 7/12 [06:13<05:05, 61.05s/it]

Calculating spatial neighbor graph ...
The graph contains 224658 edges, 20000 cells.
11.2329 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224854 edges, 20000 cells.
11.2427 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.67


 67%|███████████████▎       | 8/12 [07:24<04:17, 64.29s/it]

Calculating spatial neighbor graph ...
The graph contains 224774 edges, 20000 cells.
11.2387 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224432 edges, 20000 cells.
11.2216 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.67


 75%|█████████████████▎     | 9/12 [08:10<02:55, 58.42s/it]

Calculating spatial neighbor graph ...
The graph contains 224380 edges, 20000 cells.
11.219 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 224314 edges, 20000 cells.
11.2157 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.69


 83%|██████████████████▎   | 10/12 [08:50<01:45, 52.98s/it]

Calculating spatial neighbor graph ...
The graph contains 224396 edges, 20000 cells.
11.2198 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 223870 edges, 20000 cells.
11.1935 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.71


 92%|████████████████████▏ | 11/12 [09:30<00:48, 48.78s/it]

Calculating spatial neighbor graph ...
The graph contains 223956 edges, 20000 cells.
11.1978 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 223450 edges, 20000 cells.
11.1725 neighbors per cell on average.
Use DPCA feature to format graph


  concat_annot[label] = label_col
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  utils.warn_names_duplicates("obs")
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 1.69


100%|██████████████████████| 12/12 [10:11<00:00, 50.96s/it]
