In [None]:
import yaml
import os
import time
from pathlib import Path

import scanpy as sc
import torch
import numpy as np
import pandas as pd
import torch_geometric

from matplotlib import pyplot as plt

from scSLAT.model import load_anndatas, run_SLAT, Cal_Spatial_Net, spatial_match
from scSLAT.utils import global_seed
from scSLAT.model.prematch import icp, alpha_shape
from scSLAT.metrics import global_score, euclidean_dis
from scSLAT.viz import match_3D_multi

In [None]:
# Figure resolution: 150 dpi for inline display, 200 dpi for saved files.
sc.set_figure_params(dpi=150, dpi_save=200)

In [None]:
# parameter cells
# Placeholder values; presumably overridden at execution time (e.g. papermill
# parameter injection) — keep one simple assignment per line.
adata1_file = ''   # .h5ad of the first slice; also used to pick the dataset preset
adata2_file = ''   # .h5ad of the second slice
noise_level = -1   # edge-dropout probability passed to dropout_adj on both spatial graphs; recorded in the metrics
seed = 0           # passed (as int) to scSLAT's global_seed
metric_file = ''   # YAML output for the metrics; its directory also receives the PDF plots
emb0_file = ''     # comma-separated text dump of adata1.obsm['X_slat']
emb1_file = ''     # comma-separated text dump of adata2.obsm['X_slat']
matching_file = '' # integer text dump of the matching array

In [None]:
# Load both slices (first, then second) from their .h5ad files.
adata1, adata2 = (sc.read_h5ad(path) for path in (adata1_file, adata2_file))

In [None]:
# Pick dataset-specific settings from the first slice's file path:
#   biology_meta / topology_meta - obs columns used by the scores and the plots
#   alpha                        - alpha passed to alpha_shape during prematch
#   LGCN_layer                   - LGCN_layer passed to run_SLAT
# Each `x in path and y in path` requires BOTH substrings. The previous
# `'visium' and 'DLPFC' in path` evaluated to `'DLPFC' in path` only.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
    alpha = 10
    LGCN_layer = 2
elif 'merfish' in adata1_file and 'hypothalamic' in adata1_file:
    biology_meta = 'Cell_class'
    topology_meta = 'region'
    alpha = 25
    LGCN_layer = 2
elif 'stereo' in adata1_file and 'embryo' in adata1_file:
    biology_meta = 'annotation'
    topology_meta = 'region'
    alpha = 3
    LGCN_layer = 1
elif 'brain' in adata1_file:
    biology_meta = 'layer_guess'
    topology_meta = 'layer_guess'
    alpha = 10
    LGCN_layer = 2
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"Cannot infer dataset settings from adata1_file={adata1_file!r}")

# Prematch

In [None]:
# Prematch: when slice 2 carries a ground-truth rotation, coarsely align it to
# slice 1 with ICP on the alpha-shape boundary points of each slice, then apply
# the estimated transform to slice 2's spatial coordinates.
# The runtime clock starts here, so the reported runtime includes prematch.
start = time.time()
if 'rotation' in adata2.uns.keys():
    boundary_1, edges_1, _ = alpha_shape(adata1.obsm['spatial'], alpha=alpha, only_outer=True)
    boundary_2, edges_2, _ = alpha_shape(adata2.obsm['spatial'], alpha=alpha, only_outer=True)
    # Boundary points are passed transposed, i.e. (2, n_points).
    # NOTE(review): T is presumably a homogeneous 3x3 matrix mapping slice 2
    # onto slice 1 — confirm against scSLAT's icp.
    T, error = icp(adata2.obsm['spatial'][boundary_2, :].T, adata1.obsm['spatial'][boundary_1, :].T)
    # arcsin only covers [-90, 90] degrees, so larger rotations cannot be reported here.
    rotation = np.degrees(np.arcsin(T[0, 1]))

    print("T", T)
    print("icp loss", error)
    print("rotation°", rotation)

    print(f"ground truth:{adata2.uns['rotation']}, prematch result:{rotation}, error is {adata2.uns['rotation'] - rotation}")

    # Apply the affine part of T (2x2 linear part + translation column) to every
    # spot. This matches the former `cv2.transform(...)[:, :2]` call; cv2 was
    # never imported, and np.squeeze broke the shape for a single spot.
    coords = np.asarray(adata2.obsm['spatial'], dtype=np.float32)
    trans = (coords @ T[:2, :2].T + T[:2, 2]).astype(np.float32)
    adata2.obsm['spatial'] = trans

# Run SLAT

In [None]:
# Seed the random generators through scSLAT's helper for reproducible runs.
seed_value = int(seed)
global_seed(seed_value)

In [None]:
# Build a KNN spatial network (k_cutoff=20) on each slice, first slice first.
# Presumably Cal_Spatial_Net stores the graph on the AnnData, since its return value is unused.
for slice_adata in (adata1, adata2):
    Cal_Spatial_Net(slice_adata, k_cutoff=20, model='KNN')
# Gather per-slice graph edges and 'harmony' features for SLAT.
edges, features = load_anndatas([adata1, adata2], feature='harmony')

In [None]:
# Optional structural noise: randomly drop spatial-graph edges with probability
# `noise_level` on both slices. A non-positive level (e.g. the -1 placeholder in
# the parameter cell) means "no noise". dropout_adj returns the edges unchanged
# for p=0 and rejects negative p. The former remove_isolated_nodes(...).shape
# lines were discarded expressions and had no effect, so they are gone.
if noise_level > 0:
    g1 = torch_geometric.utils.dropout_adj(edges[0], p=noise_level)[0]
    g2 = torch_geometric.utils.dropout_adj(edges[1], p=noise_level)[0]
else:
    g1, g2 = edges[0], edges[1]
edges = [g1, g2]

In [None]:
# Train SLAT on both slices and get one embedding tensor per slice.
# NOTE(review): the positional `6` is presumably the number of epochs — confirm in run_SLAT.
embd0, embd1, time1 = run_SLAT(features, edges, 6, LGCN_layer=LGCN_layer)
# Measure once so the printed and the saved runtime agree
# (the clock was started before prematch).
elapsed = time.time() - start
run_time = str(elapsed)
print('Runtime: ' + run_time)

In [None]:
# Keep the SLAT embeddings on each slice as gradient-free CPU NumPy arrays.
for slice_adata, slice_emb in ((adata1, embd0), (adata2, embd1)):
    slice_adata.obsm['X_slat'] = slice_emb.detach().cpu().numpy()

# Metric

In [None]:
# Match spots across slices in the SLAT embedding space.
embd0, embd1 = adata1.obsm['X_slat'], adata2.obsm['X_slat']
best, index, distance = spatial_match([embd0, embd1], adatas=[adata1, adata2])
# 2 x n integer array: row 0 is 0..n-1 (n = index.shape[0]), row 1 is `best`.
matching = np.array([np.arange(index.shape[0]), best])

In [None]:
# Score the matching on both annotations, then on each annotation alone.
slice_pair = [adata1, adata2]
matched_rows = matching.T  # n x 2: one row per matched pair
overall_score = global_score(slice_pair, matched_rows, biology_meta, topology_meta)
celltype_score = global_score(slice_pair, matched_rows, biology_meta=biology_meta)
region_score = global_score(slice_pair, matched_rows, topology_meta=topology_meta)

In [None]:
# Distance-based metric for the matching (presumably the Euclidean distance
# between matched spots' coordinates). It receives `matching` as built (2 x n),
# whereas global_score receives `matching.T`.
# NOTE(review): confirm euclidean_dis expects this orientation.
eud = euclidean_dis(adata1, adata2, matching)

# Save

In [None]:
def _to_yaml_native(value):
    """Return `value` with NumPy scalars/arrays converted to built-in Python types.

    yaml.dump writes NumPy objects as `!!python/object` tags that safe YAML
    loaders cannot read; built-in floats/lists are written as plain YAML.
    Anything that is not a NumPy object is returned unchanged.
    """
    if isinstance(value, (np.generic, np.ndarray)):
        return value.tolist()
    return value


metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'noise_level': noise_level,
    'euclidean_dis': eud,
    'run_time': run_time,
}
metric_dic = {key: _to_yaml_native(value) for key, value in metric_dic.items()}

with open(metric_file, "w") as f:
    yaml.dump(metric_dic, f)

# Embeddings as comma-separated floats; the matching as space-separated integers.
np.savetxt(emb0_file, adata1.obsm['X_slat'], delimiter=',')
np.savetxt(emb1_file, adata2.obsm['X_slat'], delimiter=',')
np.savetxt(matching_file, matching, fmt='%i')

# Plot

In [None]:
# Plots are written next to the metrics YAML ('.' when metric_file has no directory part).
out_dir = Path(metric_file).parent

In [None]:
# Show prematch results: slice 2 after the ICP transform, overlaid on slice 1.
# Equal axis scaling keeps the recovered rotation from looking distorted.
if 'rotation' in adata2.uns.keys():
    fig, ax = plt.subplots()
    ax.scatter(adata1.obsm['spatial'][:, 0], adata1.obsm['spatial'][:, 1], s=1, label='slice 1')
    ax.scatter(trans[:, 0], trans[:, 1], s=1, label='slice 2 (prematched)')
    ax.set_aspect('equal')
    ax.set(title='Prematch result', xlabel='x', ylabel='y')
    ax.legend(markerscale=5)
    plt.show()

In [None]:
# Optional UMAP of the joint SLAT embedding, coloured by biology annotation,
# topology annotation and batch. Off by default (it was commented out before).
PLOT_UMAP = False
if PLOT_UMAP:
    adata_all = adata1.concatenate(adata2)
    sc.pp.neighbors(adata_all, metric="cosine", use_rep='X_slat')
    sc.tl.umap(adata_all)
    # scanpy's `save=` only appends a suffix to its own figure directory, so a
    # full path does not land in out_dir; save the current figure explicitly.
    for color_key, file_name in ((biology_meta, 'biology.pdf'),
                                 (topology_meta, 'topology.pdf'),
                                 ("batch", 'batch.pdf')):
        sc.pl.umap(adata_all, color=color_key, show=False)
        plt.savefig(out_dir / file_name, bbox_inches='tight')
        plt.show()

In [None]:
def _spatial_frame(adata, celltype_key, region_key):
    """Per-spot table for match_3D_multi.

    Columns: 'index' (0..n_obs-1), 'x'/'y' from obsm['spatial'], and
    'celltype'/'region' copied from obs[celltype_key] / obs[region_key].
    The frame index comes from those obs Series, i.e. adata.obs_names.
    """
    coords = adata.obsm['spatial']
    return pd.DataFrame({'index': range(adata.n_obs),
                         'x': coords[:, 0],
                         'y': coords[:, 1],
                         'celltype': adata.obs[celltype_key],
                         'region': adata.obs[region_key]})


# `matching` is reused as built in the metric section; it was previously
# recomputed here from the same `index`/`best`.
adata1_df = _spatial_frame(adata1, biology_meta, topology_meta)
adata2_df = _spatial_frame(adata2, biology_meta, topology_meta)

In [None]:
# 3D view of the cross-slice matching, coloured by cell type (300 subsampled matches).
align_kwargs = dict(scale_coordinate=True, subsample_size=300, exchange_xy=False)
draw_kwargs = dict(size=[7, 8], line_width=1, point_size=[0.8, 0.8], hide_axis=True, show_error=False)

multi_align = match_3D_multi(adata1_df, adata2_df, matching, meta='celltype', **align_kwargs)

multi_align.draw_3D(save=out_dir / 'match_by_celltype.pdf', **draw_kwargs)

In [None]:
# 3D view of the cross-slice matching, coloured by region (300 subsampled matches).
align_kwargs = dict(scale_coordinate=True, subsample_size=300, exchange_xy=False)
draw_kwargs = dict(size=[7, 8], line_width=1, point_size=[0.8, 0.8], hide_axis=True, show_error=False)

multi_align = match_3D_multi(adata1_df, adata2_df, matching, meta='region', **align_kwargs)

multi_align.draw_3D(save=out_dir / 'match_by_region.pdf', **draw_kwargs)