In [None]:
import os
import yaml
from pathlib import Path

import scanpy as sc
import torch
import numpy as np
import pandas as pd

from scSLAT.model import spatial_match
from scSLAT.metrics import global_score, euclidean_dis
from scSLAT.viz import match_3D_multi

In [None]:
# Parameter cell: empty placeholders for the input/output paths used below.
# They are presumably overridden per run by whatever executes this notebook
# (TODO confirm the parameterization mechanism).
adata1_file = ''  # .h5ad file of the first dataset (read with sc.read_h5ad)
adata2_file = ''  # .h5ad file of the second dataset (read with sc.read_h5ad)
emb0_file = ''  # comma-delimited text embedding, passed first to spatial_match
emb1_file = ''  # comma-delimited text embedding, passed second to spatial_match
metrics_file = ''  # output YAML of metrics; run_time.yaml is read from its directory
matching_file = ''  # output text file that receives the matching array

In [None]:
# Load both datasets from their .h5ad files, in order (first, second).
adata1, adata2 = (
    sc.read_h5ad(h5ad_path) for h5ad_path in (adata1_file, adata2_file)
)

In [None]:
# Read each comma-delimited embedding file and wrap the resulting array
# as a torch tensor.
embd0 = torch.from_numpy(np.loadtxt(emb0_file, delimiter=','))
embd1 = torch.from_numpy(np.loadtxt(emb1_file, delimiter=','))

In [None]:
# Match the two embeddings against each other with smoothing disabled.
best, index, distance = spatial_match([embd0, embd1], smooth=False)

# Build a 2-row array: row 0 is 0..index.shape[0]-1, row 1 is `best`.
n_matched = index.shape[0]
matching = np.array([range(n_matched), best])

In [None]:
# Choose the metadata keys used as biological / topological labels, based on
# keywords found in the first dataset's path.
#
# Every keyword gets its own `in` test: an expression such as
# `'visium' and 'DLPFC' in path` parses as `'visium' and ('DLPFC' in path)`,
# which only checks the second keyword because a non-empty string is truthy.
if 'visium' in adata1_file and 'DLPFC' in adata1_file:
    biology_meta = 'cell_type'
    topology_meta = 'layer_guess'
elif 'merfish' in adata1_file and 'hypothalamic' in adata1_file:
    biology_meta = 'Cell_class'
    topology_meta = 'region'
elif 'stereo' in adata1_file and 'embryo' in adata1_file:
    biology_meta = 'annotation'
    topology_meta = 'region'
elif 'brain' in adata1_file:
    biology_meta = 'layer_guess'
    topology_meta = 'layer_guess'
else:
    # Fail here with a clear message instead of a NameError in a later cell
    # when biology_meta / topology_meta are first used.
    raise ValueError(
        "Cannot infer biology_meta/topology_meta from "
        f"adata1_file={adata1_file!r}: no known dataset keywords found"
    )

In [None]:
# Score the matching with both label sets together, then with each one alone.
matching_pairs = matching.T  # transpose of the 2-row matching array
overall_score = global_score(
    [adata1, adata2], matching_pairs, biology_meta, topology_meta
)
celltype_score = global_score(
    [adata1, adata2], matching_pairs, biology_meta=biology_meta
)
region_score = global_score(
    [adata1, adata2], matching_pairs, topology_meta=topology_meta
)

In [None]:
# Distance metric for the matching, computed by scSLAT's euclidean_dis.
# NOTE(review): `matching` is passed un-transposed here, whereas the
# global_score calls receive its transpose — confirm this is the orientation
# euclidean_dis expects.
eud = euclidean_dis(adata1, adata2, matching)

# Save

In [None]:
# run_time.yaml is read from the same directory as the metrics file.
out_dir = Path(os.path.dirname(metrics_file))
run_time_path = out_dir / 'run_time.yaml'
with run_time_path.open('r') as stream:
    run_time_dic = yaml.safe_load(stream)

In [None]:
# Collect every metric, plus the run time loaded from run_time.yaml,
# into a single mapping.
metric_dic = {
    'global_score': overall_score,
    'celltype_score': celltype_score,
    'region_score': region_score,
    'euclidean_dis': eud,
    'run_time': run_time_dic['run_time'],
}

# Write the metrics as YAML.
# NOTE(review): if any score is a numpy scalar, yaml.dump emits a
# Python-specific tag for it — confirm the downstream loader accepts that.
with open(metrics_file, "w") as f:
    yaml.dump(metric_dic, f)

# Write the matching array as integer text.
np.savetxt(matching_file, matching, fmt='%i')

# Plot

In [None]:
# adata1_df = pd.DataFrame({'index':range(embd0.shape[0]),
#                           'x': adata1.obsm['spatial'][:,0],
#                           'y': adata1.obsm['spatial'][:,1],
#                           'celltype':adata1.obs[biology_meta],
#                           'region':adata1.obs[topology_meta]})
# adata2_df = pd.DataFrame({'index':range(embd1.shape[0]),
#                           'x': adata2.obsm['spatial'][:,0],
#                           'y': adata2.obsm['spatial'][:,1],
#                           'celltype':adata2.obs[biology_meta],
#                           'region':adata2.obs[topology_meta]})
# matching = np.array([range(index.shape[0]), best])

In [None]:
# multi_align = match_3D_multi(adata1_df, adata2_df,matching, meta='celltype',
#                              scale_coordinate=True, subsample_size=300, exchange_xy=False)

# multi_align.draw_3D(size=[7, 8], line_width=1, point_size=[0.8,0.8], 
#                     hide_axis=True, show_error=True, save=out_dir / 'match_by_celltype.pdf')

In [None]:
# multi_align = match_3D_multi(adata1_df, adata2_df,matching, meta='region',
#                              scale_coordinate=True, subsample_size=300, exchange_xy=False)

# multi_align.draw_3D(size=[7, 8], line_width=1, point_size=[0.8,0.8], 
#                     hide_axis=True, show_error=True, save=out_dir / 'match_by_region.pdf')