# Tutorial 1: Generate hard-links for consecutive ST slices

### loading packages

In [8]:
import scipy
import os
import pickle

from utils import (
    build_args_ST,
    create_optimizer
)

from models import build_model_ST

### HP setup

In [16]:
args = build_args_ST()
args.max_epoch=2000
args.max_epoch_triplet=500
args.dataset_name="DLPFC"
args.section_ids=["151507", "151508"]
args.num_class=7
args.num_hidden="512,32"
args.alpha_l=1
args.lam=1 
args.loss_fn="sce" 
args.mask_rate=0.5 
args.in_drop=0 
args.attn_drop=0 
args.remask_rate=0.1
args.seeds=[2023] 
args.num_remasking=1 
args.hvgs=3000 
args.dataset="DLPFC" 
args.consecutive_prior=1 
args.lr=0.001
args.scheduler = True

#### remember to change this to your data path
args.st_data_dir="/home/yunfei/spatial_benchmarking/benchmarking_data/DLPFC12"


In [3]:
args

Namespace(seeds=[2023], dataset='DLPFC', exp_fig_dir='./', h5ad_save_dir='./', st_data_dir='path/to/stdata', pi_dir='./', consecutive_prior=1, section_ids='151507,151508', num_class=7, hvgs=3000, device=3, max_epoch=2000, max_epoch_triplet=500, warmup_steps=-1, num_heads=1, num_out_heads=1, num_layers=2, num_dec_layers=2, num_remasking=1, num_hidden='512,32', residual=False, in_drop=0, attn_drop=0, norm=None, lr=0.001, weight_decay=0, negative_slope=0.2, activation='elu', mask_rate=0.5, remask_rate=0.1, remask_method='random', mask_type='mask', mask_method='random', drop_edge_rate=0.0, encoder='gat', decoder='gat', loss_fn='sce', alpha_l=1, optimizer='adam', linear_prob=False, no_pretrain=False, checkpoint_path=None, use_cfg=False, logging=False, scheduler=False, batch_size=512, sampling_method='saint', label_rate=1.0, lam=1, delayed_ema_epoch=0, replace_rate=0.0, momentum=0.996, load_model=False)

### data loader

In [4]:
from datasets.st_loading_utils import load_DLPFC, create_dictionary_mnn, load_mHypothalamus
from datasets.data_proc import Cal_Spatial_Net
import scanpy as sc
import anndata
import numpy as np
import dgl
import torch


def localOT_loader(section_ids=["151507", "151508"], dataname="DLPFC", hvgs=5000, st_data_dir="./", hard_links=None):
    # hard links is a mapping matrix (2d numpy array with the size of #slice1 spot by #slice2 spot)
    if dataname == "DLPFC":
        Batch_list = []
        adj_list = []
        for section_id in section_ids:
            ad_ = load_DLPFC(root_dir=st_data_dir, section_id=section_id)
            ad_.var_names_make_unique(join="++")
        
            # make spot name unique
            ad_.obs_names = [x+'_'+section_id for x in ad_.obs_names]
            
            # Constructing the spatial network
            Cal_Spatial_Net(ad_, rad_cutoff=150) # the spatial network are saved in adata.uns[‘adj’]
            
            # Normalization
            sc.pp.highly_variable_genes(ad_, flavor="seurat_v3", n_top_genes=hvgs)
            sc.pp.normalize_total(ad_, target_sum=1e4)
            sc.pp.log1p(ad_)
            ad_ = ad_[:, ad_.var['highly_variable']]

            adj_list.append(ad_.uns['adj'])
            Batch_list.append(ad_)
        adata_concat = anndata.concat(Batch_list, label="slice_name", keys=section_ids, uns_merge="same")
        adata_concat.obs['original_clusters'] = adata_concat.obs['original_clusters'].astype('category')
        adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')

        adj_concat = np.asarray(adj_list[0].todense())
        for batch_id in range(1,len(section_ids)):
            adj_concat = scipy.linalg.block_diag(adj_concat, np.asarray(adj_list[batch_id].todense()))
        
        """if hard links is not empty"""
        if hard_links != None:
            for i in range(hard_links.shape[0]):
                for j in range(hard_links.shape[1]):
                    if hard_links[i][j] > 0:
                        adj_concat[i][j+hard_links.shape[0]] = 1
                        adj_concat[j+hard_links.shape[0]][i] = 1

        edgeList = np.nonzero(adj_concat)
        graph = dgl.graph((edgeList[0], edgeList[1]))
        graph.ndata["feat"] = torch.tensor(adata_concat.X.todense())
        num_features = graph.ndata["feat"].shape[1]
    elif dataname == "mHypothalamus":
        Batch_list = []
        adj_list = []
        for section_id in section_ids:
            ad_ = load_mHypothalamus(root_dir=st_data_dir, section_id=section_id)
            ad_.var_names_make_unique(join="++")
        
            # make spot name unique
            ad_.obs_names = [x+'_'+section_id for x in ad_.obs_names]
            
            # Constructing the spatial network
            Cal_Spatial_Net(ad_, rad_cutoff=35) # the spatial network are saved in adata.uns[‘adj’]
            
            # Normalization
            sc.pp.normalize_total(ad_, target_sum=1e4)
            sc.pp.log1p(ad_)

            adj_list.append(ad_.uns['adj'])
            Batch_list.append(ad_)
        adata_concat = anndata.concat(Batch_list, label="slice_name", keys=section_ids, uns_merge="same")
        adata_concat.obs['original_clusters'] = adata_concat.obs['original_clusters'].astype('category')
        adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')

        adj_concat = np.asarray(adj_list[0].todense())
        for batch_id in range(1,len(section_ids)):
            adj_concat = scipy.linalg.block_diag(adj_concat, np.asarray(adj_list[batch_id].todense()))
        
        """if hard links is not empty"""
        if hard_links != None:
            for i in range(hard_links.shape[0]):
                for j in range(hard_links.shape[1]):
                    if hard_links[i][j] > 0:
                        adj_concat[i][j+hard_links.shape[0]] = 1
                        adj_concat[j+hard_links.shape[0]][i] = 1

        edgeList = np.nonzero(adj_concat)
        graph = dgl.graph((edgeList[0], edgeList[1]))
        graph.ndata["feat"] = torch.tensor(adata_concat.X).float()
        num_features = graph.ndata["feat"].shape[1]
    else:
        raise NotImplementedError
    return graph, num_features, adata_concat

2024-03-06 15:51:30,053 - INFO - Enabling RDKit 2022.09.5 jupyter extensions


### model setup

In [19]:
graph, num_features, ad_concat = localOT_loader(section_ids=args.section_ids, hvgs=args.hvgs, st_data_dir=args.st_data_dir, dataname=args.dataset_name)
args.num_features = num_features
model = build_model_ST(args)
print(model)

device = args.device if args.device >= 0 else "cpu"
model.to(device)
optimizer = create_optimizer(args.optimizer, model, args.lr, args.weight_decay)

if args.scheduler:
    scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / args.max_epoch) ) * 0.5
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
else:
    scheduler = None

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 24762 edges, 4221 cells.
5.8664 neighbors per cell on average.


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 25692 edges, 4381 cells.
5.8644 neighbors per cell on average.
=== Use sce_loss and alpha_l=1 ===
num_encoder_params: 601696, num_decoder_params: 605020, num_params_in_total: 1243282
PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=1140, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): ELU(alpha=1.0)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=32, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=32, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace=F

### masked reconstruction loss training

In [22]:
from tqdm import tqdm


x = graph.ndata["feat"]
model.to(device)
graph = graph.to(device)
x = x.to(device)

"""training"""
target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long)
print(args.max_epoch)
epoch_iter = tqdm(range(args.max_epoch))

print("training local clusters ... ")
for epoch in epoch_iter:
    model.train()
    # print(type(x), type(graph))
    loss = model(graph, x, targets=target_nodes)

    loss_dict = {"loss": loss.item()}
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")


with torch.no_grad():
    embedding = model.embed(graph, x)
ad_concat.obsm["maskgraphene"] = embedding.cpu().detach().numpy()

2000


  0%|          | 0/2000 [00:00<?, ?it/s]

training local clusters ... 


# Epoch 1999: train_loss: 0.2209: 100%|██████████| 2000/2000 [01:17<00:00, 25.71it/s]


### triplet loss training

In [23]:
from datasets.st_loading_utils import mclust_R
from sklearn.metrics import adjusted_rand_score as ari_score

mnn_dict = create_dictionary_mnn(ad_concat, use_rep="maskgraphene", batch_name='batch_name', k=50, verbose = 1, iter_comb=None)

anchor_ind = []
positive_ind = []
negative_ind = []
for batch_pair in mnn_dict.keys():  # pairwise compare for multiple batches
    batchname_list = ad_concat.obs['batch_name'][mnn_dict[batch_pair].keys()]

    cellname_by_batch_dict = dict()
    for batch_id in range(len(args.section_ids)):
        cellname_by_batch_dict[args.section_ids[batch_id]] = ad_concat.obs_names[
            ad_concat.obs['batch_name'] == args.section_ids[batch_id]].values

    anchor_list = []
    positive_list = []
    negative_list = []
    for anchor in mnn_dict[batch_pair].keys():
        anchor_list.append(anchor)
        ## np.random.choice(mnn_dict[batch_pair][anchor])
        positive_spot = mnn_dict[batch_pair][anchor][0]  # select the first positive spot
        positive_list.append(positive_spot)
        section_size = len(cellname_by_batch_dict[batchname_list[anchor]])
        negative_list.append(
            cellname_by_batch_dict[batchname_list[anchor]][np.random.randint(section_size)])

    batch_as_dict = dict(zip(list(ad_concat.obs_names), range(0, ad_concat.shape[0])))
    anchor_ind = np.append(anchor_ind, list(map(lambda _: batch_as_dict[_], anchor_list)))
    positive_ind = np.append(positive_ind, list(map(lambda _: batch_as_dict[_], positive_list)))
    negative_ind = np.append(negative_ind, list(map(lambda _: batch_as_dict[_], negative_list)))

epoch_iter = tqdm(range(args.max_epoch_triplet))
for epoch in epoch_iter:
    model.train()
    optimizer.zero_grad()

    _loss = model(graph, x, targets=target_nodes)
    if epoch % 100 == 0 or epoch == 500:
        with torch.no_grad():
            z = model.embed(graph, x)
        
        anchor_arr = z[anchor_ind,]
        positive_arr = z[positive_ind,]
        negative_arr = z[negative_ind,]

    triplet_loss = torch.nn.TripletMarginLoss(margin=1, p=2, reduction='mean')
    tri_output = triplet_loss(anchor_arr, positive_arr, negative_arr)

    loss = _loss + tri_output
    loss.backward()
    # torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
    # torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
    optimizer.step()

    if scheduler is not None:
        scheduler.step()
    loss_dict = {"loss": loss.item()}
    epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")

    # z = model.embed(graph, x)
with torch.no_grad():
    embedding = model.embed(graph, x)
ad_concat.obsm["maskgraphene_mnn"] = embedding.cpu().detach().numpy()


"""calculate ARI & umap & viz"""
mclust_R(ad_concat, modelNames='EEE', num_cluster=args.num_class, used_obsm='maskgraphene_mnn')


ad_temp = ad_concat[ad_concat.obs['original_clusters']!='unknown']


Batch_list = []
for section_id in args.section_ids:
    ad__ = ad_temp[ad_temp.obs['batch_name'] == section_id]
    Batch_list.append(ad__)
    print(section_id)
    print('mclust, ARI = %01.3f' % ari_score(ad__.obs['original_clusters'], ad__.obs['mclust']))

Processing datasets (0, 1)


# Epoch 499: train_loss: 0.2467: 100%|██████████| 500/500 [00:19<00:00, 26.28it/s]
2024-03-06 16:10:32,961 - INFO - cffi mode is CFFI_MODE.ANY
2024-03-06 16:10:33,057 - INFO - R home found: /usr/lib/R
2024-03-06 16:10:33,273 - INFO - R library path: /usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/default-java/lib/server
2024-03-06 16:10:33,276 - INFO - LD_LIBRARY_PATH: /usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/default-java/lib/server
2024-03-06 16:10:33,286 - INFO - Default options to initialize R: rpy2, --quiet, --no-save
2024-03-06 16:10:33,463 - INFO - R is already initialized. No need to initialize.
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 5.4.10
Type 'citation("mclust")' for citing this R package in publications.



[[ -2.1615634   -6.23971     -2.3459961  ...  -2.3171792   -8.681177
    0.41282746]
 [ -2.460916   -15.380771     0.09077121 ...  -2.9557252  -20.356768
   -3.8703644 ]
 [ -2.6018689   -3.7312057   -4.6612043  ...  -7.7043953   -6.703478
   -1.9260284 ]
 ...
 [ -2.2426836  -15.747475    -0.50736713 ...  -4.671777   -26.402672
   -3.7827911 ]
 [ -5.4853687   -0.9318517   -6.884806   ...  -9.694893   -11.444196
    1.916019  ]
 [ -1.3684633   -3.4660091   -2.5994942  ...  -7.151997    -6.307693
   -2.191994  ]]
fitting ...
151507
mclust, ARI = 0.444
151508
mclust, ARI = 0.460


### paste alignment to generate hard-links

In [24]:
import paste
import ot

slice1 = Batch_list[0]
slice2 = Batch_list[1]
global_PI = np.zeros((len(slice1.obs.index), len(slice2.obs.index)))
slice1_idx_mapping = {}
slice2_idx_mapping = {}
for i in range(len(slice1.obs.index)):
    slice1_idx_mapping[slice1.obs.index[i]] = i
for i in range(len(slice2.obs.index)):
    slice2_idx_mapping[slice2.obs.index[i]] = i

# temp_local_pi_list = []
for i in range(args.num_class):
    subslice1 = slice1[slice1.obs['mclust']==i+1]
    subslice2 = slice2[slice2.obs['mclust']==i+1]
    if subslice1.shape[0]>0 and subslice2.shape[0]>0:
        pi00 = paste.match_spots_using_spatial_heuristic(subslice1.obsm['spatial'], subslice2.obsm['spatial'], use_ot= True)
        local_PI = paste.pairwise_align(subslice1, subslice2, alpha=0.1, dissimilarity='kl', use_rep=None, G_init=pi00, use_gpu = True, backend = ot.backend.TorchBackend())
        for i in range(local_PI.shape[0]):
            for j in range(local_PI.shape[1]):
                global_PI[slice1_idx_mapping[subslice1.obs.index[i]]][slice2_idx_mapping[subslice2.obs.index[j]]] = local_PI[i][j]

gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.


### save/load Hard-links

In [None]:
S = scipy.sparse.csr_matrix(global_PI)
file = open(os.path.join(args.save_pi_dir, "pi_151507_151508.pickle"),'wb')
pickle.dump(S, file)

"""
file = open("pi_151507_151508.pickle",'rb') 
S = pickle.load(file)
S.toarray()

to retrieve from saved pi file
"""