In [None]:
import os
import time
import yaml
from pathlib import Path
import _pickle as cpickle

import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
from matplotlib import pyplot as plt
import cv2
from sklearn.metrics import f1_score

from scSLAT.model import load_anndatas, run_SLAT, Cal_Spatial_Net, spatial_match
from scSLAT.model.prematch import icp, alpha_shape
from scSLAT.metrics import global_score, euclidean_dis, rotation_angle
from scSLAT.viz import match_3D_multi, matching_2d, Sankey

In [None]:
# Figure resolution: dpi=150 for rendered figures, dpi_save=200 for saved figure files.
sc.set_figure_params(dpi=150, dpi_save=200)

In [None]:
# parameter cells
# Defaults are empty placeholders; presumably overridden by an injected
# parameters cell (e.g. papermill) when the notebook runs inside a pipeline.

# Inputs: the two slices (.h5ad files) to align.
adata1_file: str = ''
adata2_file: str = ''

# Outputs written in the "Save" section.
metric_file: str = ''    # YAML metrics; its directory also receives the figures
emb0_file: str = ''      # comma-separated slide-1 SLAT embedding
emb1_file: str = ''      # comma-separated slide-2 SLAT embedding
graphs_file: str = ''    # pickled graph edges
matching_file: str = ''  # integer 2 x n matching array; 'perturb' in the name enables match_ratio

# Forwarded as `ground_truth` to rotation_angle (presumably the true rotation in degrees — confirm).
ground_truth: int = 60

In [None]:
# Load both slices; later cells match every slide-2 spot to a slide-1 spot.
adata1, adata2 = (sc.read_h5ad(path) for path in (adata1_file, adata2_file))

# Dataset-specific parameters

In [None]:
# Dataset-specific settings, chosen from keywords in the slide-1 file path.
#   biology_meta  - obs column holding cell-type labels
#   topology_meta - obs column holding spatial region / layer labels
#   alpha         - alpha passed to alpha_shape in the prematch step
#   spot_size     - spot size passed to matching_2d
# Each keyword is tested on its own: `'visium' and 'DLPFC' in s` evaluates to just
# `'DLPFC' in s`, because a non-empty string literal is always truthy.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
    alpha = 10
    spot_size = 5
elif 'merfish' in adata1_file and 'hypothalamic' in adata1_file:
    biology_meta = 'Cell_class'
    topology_meta = 'region'
    alpha = 25
    spot_size = 15
elif 'stereo' in adata1_file and 'embryo' in adata1_file:
    biology_meta = 'annotation'
    topology_meta = 'region'
    alpha = 3
    spot_size = 5
elif 'brain' in adata1_file:
    biology_meta = 'layer_guess'
    topology_meta = 'layer_guess'
    alpha = 10
    spot_size = 5
else:
    # Fail here, not later with a NameError or with stale values from an earlier kernel run.
    raise ValueError(f"Cannot infer dataset settings from adata1_file={adata1_file!r}")

# Prematch

In [None]:
# Rigid prematch, run only when slide 2 records a known rotation in `.uns`
# (presumably synthetically rotated benchmark data). The outer alpha-shape
# boundaries of both slices are aligned with ICP, and the resulting transform is
# applied to slide 2's coordinates. The timer started here also covers the SLAT run.
start = time.time()
if 'rotation' in adata2.uns.keys():
    boundary_1, _, _ = alpha_shape(adata1.obsm['spatial'], alpha=alpha, only_outer=True)
    boundary_2, _, _ = alpha_shape(adata2.obsm['spatial'], alpha=alpha, only_outer=True)
    # Source: slide-2 boundary points; target: slide-1 boundary points (transposed, points as columns).
    T, error = icp(adata2.obsm['spatial'][boundary_2, :].T, adata1.obsm['spatial'][boundary_1, :].T)
    # arctan2 recovers the full (-180°, 180°] range, whereas arcsin(T[0, 1]) alone folds
    # any rotation beyond ±90° back into [-90°, 90°]. Within ±90° the two agree.
    # Assumes the top-left 2x2 block of T is a rotation — TODO confirm for icp's output.
    rotation = np.degrees(np.arctan2(T[0, 1], T[0, 0]))

    print("T",  T)
    print("icp loss", error)
    print("rotation°", rotation)

    print(f"ground truth: {adata2.uns['rotation']}, prematch result:{rotation}, error is {adata2.uns['rotation'] - rotation}")

    # Apply T to the homogeneous 2D coordinates and keep only the x/y columns.
    trans = np.squeeze(cv2.transform(np.array([adata2.obsm['spatial']], copy=True).astype(np.float32), T))[:, :2]
    adata2.obsm['spatial'] = trans

# Run SLAT with DPCA features

In [None]:
# Spatial KNN graph (k_cutoff=20) for each slice.
Cal_Spatial_Net(adata1, k_cutoff=20, model='KNN')
Cal_Spatial_Net(adata2, k_cutoff=20, model='KNN')

# 30-dimensional DPCA features and graph edges for both slices.
edges, features = load_anndatas(
    [adata1, adata2],
    feature='dpca',
    singular=True,
    dim=30,
)

# Train SLAT; the positional 6 and LGCN_layer=3 are forwarded unchanged to run_SLAT.
embd0, embd1, time1 = run_SLAT(features, edges, 6, LGCN_layer=3)

# Elapsed time since `start` (prematch cell), kept as a string for the metrics file.
run_time = str(time.time() - start)
print(f'Runtime: {run_time}')

In [None]:
# Move both embeddings to CPU, detach them, and store them as numpy arrays in obsm.
adata1.obsm['X_slat'], adata2.obsm['X_slat'] = (
    emb.cpu().detach().numpy() for emb in (embd0, embd1)
)

# Metric

In [None]:
# Match slide-2 spots to slide-1 spots in the SLAT embedding space.
embd0, embd1 = adata1.obsm['X_slat'], adata2.obsm['X_slat']
best, index, distance = spatial_match([embd0, embd1], adatas=[adata1, adata2])

# 2 x n array: row 0 = slide-2 spot positions 0..n-1, row 1 = matched slide-1 positions.
matching = np.array([range(index.shape[0]), best])

In [None]:
# angle
# Dense n1 x n2 indicator of matched pairs (row = slide-1 spot, column = slide-2 spot),
# evaluated by rotation_angle against `ground_truth`.
data = np.ones(matching.shape[1])
matching_sparse = sp.coo_matrix(
    (data, (matching[1], matching[0])),
    shape=(adata1.n_obs, adata2.n_obs),
)
angle = rotation_angle(
    adata1.obsm['spatial'],
    adata2.obsm['spatial'],
    matching_sparse.toarray(),
    ground_truth=ground_truth,
)

In [None]:
# global_score on (slide-2, slide-1) index pairs: both annotations, cell type only, region only.
matched_pairs = matching.T
overall_score = global_score([adata1, adata2], matched_pairs, biology_meta, topology_meta)
celltype_score = global_score([adata1, adata2], matched_pairs, biology_meta=biology_meta)
region_score = global_score([adata1, adata2], matched_pairs, topology_meta=topology_meta)

In [None]:
# Distance metric over the matched spots (presumably their spatial Euclidean distance).
eud = euclidean_dis(
    adata1,
    adata2,
    matching,
)

## F1 score

In [None]:
# Per-spot labels for the F1 scores: each slide-2 spot's own labels versus the labels
# of the slide-1 spot it was matched to (matching[1]).
# Both keywords are tested separately: `'visium' and 'DLPFC' in s` only checks 'DLPFC'.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    # Prefix DLPFC cell-type labels with 'celltype_' (presumably to keep them distinct
    # from the layer labels — confirm). Labels that already carry the prefix are left
    # as-is, so re-running this cell does not produce 'celltype_celltype_...'.
    for _adata in (adata2, adata1):
        _labels = _adata.obs[biology_meta].astype('str')
        _adata.obs[biology_meta] = _labels.where(_labels.str.startswith('celltype_'),
                                                 'celltype_' + _labels)

adata2.obs['target_celltype'] = adata1.obs[biology_meta].iloc[matching[1]].to_list()
adata2.obs['target_region'] = adata1.obs[topology_meta].iloc[matching[1]].to_list()
adata2.obs['target_celltype_region'] = (adata2.obs['target_celltype'].astype('str') + '_'
                                        + adata2.obs['target_region'].astype('str'))
adata2.obs['celltype_region'] = (adata2.obs[biology_meta].astype('str') + '_'
                                 + adata2.obs[topology_meta].astype('str'))

In [None]:
def _f1(true_col: str, pred_col: str, average: str) -> float:
    """F1 of ``adata2.obs[pred_col]`` against the reference ``adata2.obs[true_col]``."""
    return f1_score(adata2.obs[true_col], adata2.obs[pred_col], average=average)


# Cell type, region and combined "celltype_region" labels, each with macro and micro averaging.
celltype_macro_f1 = _f1(biology_meta, 'target_celltype', 'macro')
celltype_micro_f1 = _f1(biology_meta, 'target_celltype', 'micro')

region_macro_f1 = _f1(topology_meta, 'target_region', 'macro')
region_micro_f1 = _f1(topology_meta, 'target_region', 'micro')

total_macro_f1 = _f1('celltype_region', 'target_celltype_region', 'macro')
total_micro_f1 = _f1('celltype_region', 'target_celltype_region', 'micro')

## Ground-truth match ratio (perturbed datasets only)

In [None]:
# Fraction of slide-2 spots matched to the slide-1 spot with the same position
# (presumably the true partner in perturbed data); -1 marks "not applicable".
match_ratio = (matching[0] == matching[1]).mean() if 'perturb' in matching_file else -1

# Save

In [None]:
# All metrics in one mapping; scalar metrics are cast to plain float before dumping to YAML.
metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'run_time': run_time,
    'euclidean_dis': eud,
    'angle_delta': float(angle),

    'celltype_macro_f1': float(celltype_macro_f1),
    'celltype_micro_f1': float(celltype_micro_f1),
    'region_macro_f1': float(region_macro_f1),
    'region_micro_f1': float(region_micro_f1),
    'total_macro_f1': float(total_macro_f1),
    'total_micro_f1': float(total_micro_f1),

    'match_ratio': float(match_ratio),
}

with open(metric_file, "w") as f:
    yaml.dump(metric_dic, f)

# Embeddings as comma-separated text; the 2 x n matching as integers.
np.savetxt(emb0_file, adata1.obsm['X_slat'], delimiter=',')
np.savetxt(emb1_file, adata2.obsm['X_slat'], delimiter=',')
np.savetxt(matching_file, matching, fmt='%i')

In [None]:
# save graphs for edge score
edges = [edge.cpu().detach() for edge in edges]
with open(graphs_file, 'wb') as f:
    cpickle.dump(edges, f)

# Plot

In [None]:
# Figures are written next to the metrics file; scanpy saves into sc.settings.figdir.
out_dir = Path(os.path.dirname(metric_file))
sc.settings.figdir = out_dir
print("Saving figures to " + str(out_dir))

In [None]:
# show prematch results: slide-2 coordinates after the ICP transform (`trans` has x, y columns).
if 'rotation' in adata2.uns.keys():
    plt.scatter(*trans.T, s=1)
    plt.show()

In [None]:
# Disabled: joint UMAP of both slices on the SLAT embedding ('X_slat'), coloured by
# the biology annotation, the topology annotation and batch. Uncomment to regenerate.
# NOTE(review): scanpy's `save=` is usually a filename suffix written under
# sc.settings.figdir (set to out_dir in the cell above), so passing `out_dir / ...`
# may duplicate the directory — confirm before re-enabling.
# adata_all = adata1.concatenate(adata2)
# sc.pp.neighbors(adata_all, metric="cosine", use_rep='X_slat')
# sc.tl.umap(adata_all)
# sc.pl.umap(adata_all, color=biology_meta, save=out_dir / 'biology.pdf')
# sc.pl.umap(adata_all, color=topology_meta, save=out_dir / 'topology.pdf')
# sc.pl.umap(adata_all, color="batch", save=out_dir / 'batch.pdf')

In [None]:
def _spots_frame(adata) -> pd.DataFrame:
    """Per-spot table for the 3D plots: positional index, x/y coordinates, cell type and region."""
    xy = adata.obsm['spatial']
    # obsm rows always equal n_obs, so this matches the SLAT embedding's row count.
    return pd.DataFrame({'index': range(adata.n_obs),
                         'x': xy[:, 0],
                         'y': xy[:, 1],
                         'celltype': adata.obs[biology_meta],
                         'region': adata.obs[topology_meta]})


adata1_df = _spots_frame(adata1)
adata2_df = _spots_frame(adata2)
# Rebuilt from the spatial_match output exactly as in the Metric section.
matching = np.array([range(index.shape[0]), best])

In [None]:
# 3D matching by cell type
multi_align = match_3D_multi(
    adata1_df, adata2_df, matching,
    meta='celltype',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)

multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=True,
    save=out_dir / 'match_by_celltype.pdf',
)

In [None]:
# 3D matching by region
multi_align = match_3D_multi(
    adata1_df, adata2_df, matching,
    meta='region',
    scale_coordinate=True,
    subsample_size=300,
    exchange_xy=False,
)

multi_align.draw_3D(
    size=[7, 8],
    line_width=1,
    point_size=[0.8, 0.8],
    hide_axis=True,
    show_error=True,
    save=out_dir / 'match_by_region.pdf',
)

In [None]:
# 2D matching plot (slide 2 -> slide 1), coloured by both annotations.
matching_2d(
    matching, adata1, adata2,
    biology_meta, topology_meta,
    spot_size,
    save='matching_2d.pdf',
)

In [None]:
# Sankey plot
# Matched slide-1 labels for every slide-2 spot (recomputed so this cell does not
# depend on the F1 section having been run).
adata2.obs['target_celltype'] = adata1.obs[biology_meta].iloc[matching[1]].to_list()
adata2.obs['target_region'] = adata1.obs[topology_meta].iloc[matching[1]].to_list()


def _matching_table(obs: pd.DataFrame, source_col: str, target_col: str) -> pd.DataFrame:
    """Count spots per (``source_col`` label, ``target_col`` label) pair.

    Rows are ``source_col`` labels and columns are ``target_col`` labels, kept exactly
    as groupby/unstack produce them (sorted). Relabelling the axes with ``.unique()``
    (order of appearance) would attach the counts to the wrong labels.
    ``observed=True`` keeps only categories that actually occur in categorical columns.
    """
    return obs.groupby([source_col, target_col], observed=True).size().unstack(fill_value=0)


## by cell type
# NOTE(review): rows hold slide-2 labels and columns the matched slide-1 labels;
# confirm this agrees with the `prefix` order Sankey expects.
matching_table = _matching_table(adata2.obs, biology_meta, 'target_celltype')
print(matching_table)

Sankey(matching_table, prefix=['Slide1', 'Slide2'], save_name=str(out_dir/'celltype_sankey'),
       format='svg', width=1000, height=1000)

## by region
matching_table = _matching_table(adata2.obs, topology_meta, 'target_region')
print(matching_table)

Sankey(matching_table, prefix=['Slide1', 'Slide2'], save_name=str(out_dir/'region_sankey'),
       format='svg', width=1000, height=1000)


# Reverse matching

In [None]:
# Reverse direction: embeddings and slices swapped, reorder=False forwarded to spatial_match.
best_rev, index_rev, _ = spatial_match(
    [embd1, embd0],
    adatas=[adata2, adata1],
    reorder=False,
)
# Row 0 = slide-1 spot positions, row 1 = matched slide-2 positions.
matching_rev = np.array([range(index_rev.shape[0]), best_rev])

In [None]:
# 2D plot of the reverse matching (slices passed in swapped order).
matching_2d(
    matching_rev, adata2, adata1,
    biology_meta, topology_meta,
    spot_size,
    save='matching_rev_2d.pdf',
)