# Tutorial 3: Integrate MHypo data with MaskGraphene

In this tutorial, we show how to use MaskGraphene (MG) to integrate MHypo data. As an example, we analyze the -0.04 / -0.09 sample pair of the mouse hypothalamus (MHypo) dataset.

We acquired the data, including the manual annotations, from the spatialLIBD webpage. Before running the model, please download the input data from [Zenodo](https://zenodo.org/records/10698909).

### Loading packages

In [2]:
import logging
import numpy as np
from tqdm import tqdm
import torch
import pickle
import wandb

from utils import (
    build_args_ST,
    create_optimizer
)
from datasets.st_loading_utils import visualization_umap_spatial, create_dictionary_mnn
from models import build_model_ST
import os
import scanpy as sc
import sklearn.metrics.pairwise

### Hyperparameter (HP) setup

In [3]:
args = build_args_ST()

# Tutorial-specific overrides for the default arguments; they are applied to
# `args` one by one, in the order listed here.
hp_overrides = {
    "max_epoch": 2000,
    "max_epoch_triplet": 500,
    "section_ids": ["-0.04", "-0.09"],
    "num_class": 8,
    "num_hidden": "512,32",
    "alpha_l": 1,
    "lam": 1,
    "loss_fn": "sce",
    "mask_rate": 0.5,
    "in_drop": 0,
    "attn_drop": 0,
    "remask_rate": 0.1,
    "seeds": [2023],
    "num_remasking": 1,
    "hvgs": 3000,
    "dataset": "MHypo",
    "consecutive_prior": 1,
    "lr": 0.001,
}
for hp_name, hp_value in hp_overrides.items():
    setattr(args, hp_name, hp_value)

#### remember to change these paths to your data path/link path
args.st_data_dir = "/home/yunfei/spatial_benchmarking/benchmarking_data/mHypothalamus"
args.pi_dir = "/home/yunfei/spatial_dl_integration/MaskGraphene/PI"

### MG training

### Consecutive MHypo slices

In [5]:
import dgl
import scipy
import anndata
from datasets.st_loading_utils import load_DLPFC
from datasets.data_proc import Cal_Spatial_Net, simple_impute

def dlpfc_loader(dataset_name, pi, section_ids=None, hvgs=5000, st_data_dir="./"):
    """Load two DLPFC slices and build one joint graph over both of them.

    Each slice is loaded with ``load_DLPFC``, gets its spatial neighbour graph
    from ``Cal_Spatial_Net`` (read back from ``ad_.uns['adj']``), is passed
    through ``simple_impute`` together with ``pi``, and is then normalised,
    log-transformed and restricted to its highly variable genes. The per-slice
    adjacency matrices are stacked block-diagonally and, when ``pi`` is given,
    every positive entry ``pi[i, j]`` adds an undirected edge between spot
    ``i`` of the first slice and spot ``j`` of the second slice.

    Args:
        dataset_name: Dataset identifier; must contain the substring "DLPFC".
        pi: Dense 2-D array of shape (n_spots_slice_1, n_spots_slice_2) with
            the cross-slice alignment, or ``None`` to add no cross-slice edges.
            NOTE(review): ``pi`` is forwarded to ``simple_impute``
            unconditionally; whether that helper accepts ``None`` is not
            visible here.
        section_ids: Section identifiers to load. Defaults to
            ``["151507", "151508"]``. Only the first two loaded sections are
            passed to ``simple_impute`` and kept afterwards.
        hvgs: Number of highly variable genes selected per slice.
        st_data_dir: Root directory handed to ``load_DLPFC``.

    Returns:
        A tuple ``(graph, num_features, adata_concat)``: the DGL graph with the
        expression matrix stored in ``graph.ndata["feat"]``, the number of
        feature columns, and the concatenated AnnData object.

    Raises:
        ValueError: If ``dataset_name`` does not contain "DLPFC". Previously
            this case fell through and failed with an ``UnboundLocalError``.
        AssertionError: If the shape of ``pi`` does not match the joint
            adjacency matrix.
    """
    if "DLPFC" not in dataset_name:
        raise ValueError(
            "dlpfc_loader only supports DLPFC datasets, got dataset_name=%r" % (dataset_name,)
        )
    if section_ids is None:
        section_ids = ["151507", "151508"]

    Batch_list = []
    adj_list = []
    for section_id in section_ids:
        ad_ = load_DLPFC(root_dir=st_data_dir, section_id=section_id)
        ad_.var_names_make_unique(join="++")

        # make spot name unique across slices
        ad_.obs_names = [x+'_'+section_id for x in ad_.obs_names]

        # Constructing the spatial network; it is read back from ad_.uns['adj']
        Cal_Spatial_Net(ad_, rad_cutoff=150)
        adj_list.append(ad_.uns['adj'])
        Batch_list.append(ad_)

    ad1, ad2 = simple_impute(Batch_list[0], Batch_list[1], pi)

    Batch_list_new = []
    for ad_ in (ad1, ad2):
        # HVG selection runs on the data as returned by simple_impute, before
        # normalisation and log-transform.
        sc.pp.highly_variable_genes(ad_, flavor="seurat_v3", n_top_genes=hvgs)
        sc.pp.normalize_total(ad_, target_sum=1e4)
        sc.pp.log1p(ad_)
        Batch_list_new.append(ad_[:, ad_.var['highly_variable']])

    adata_concat = anndata.concat(Batch_list_new, label="slice_name", keys=section_ids, uns_merge="same")
    adata_concat.obs['original_clusters'] = adata_concat.obs['original_clusters'].astype('category')
    adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')

    # Joint adjacency: one diagonal block per slice, in loading order.
    adj_concat = scipy.linalg.block_diag(*[np.asarray(adj.todense()) for adj in adj_list])

    if pi is not None:
        assert adj_concat.shape[0] == pi.shape[0] + pi.shape[1], "adj matrix shape is not consistent with the pi matrix"

        # Connect aligned spots across the two slices in both directions.
        # Columns of pi index the second slice, whose nodes start at offset
        # pi.shape[0] in the joint matrix.
        offset = pi.shape[0]
        rows, cols = np.nonzero(np.asarray(pi) > 0)
        adj_concat[rows, cols + offset] = 1
        adj_concat[cols + offset, rows] = 1

    src, dst = np.nonzero(adj_concat)
    # Pass num_nodes explicitly: otherwise the node count is inferred from the
    # largest edge endpoint and trailing spots without any edge would be
    # dropped, which breaks the feature assignment below.
    graph = dgl.graph((src, dst), num_nodes=adj_concat.shape[0])
    graph.ndata["feat"] = torch.tensor(adata_concat.X)
    num_features = graph.ndata["feat"].shape[1]
    return graph, num_features, adata_concat

In [6]:
"""file save path"""
exp_fig_dir = args.exp_fig_dir
st_data_dir = args.st_data_dir
pi_dir= os.path.join(args.pi_dir, args.dataset+'_'.join(args.section_ids))


file = open(os.path.join(pi_dir, "S.pickle"),'rb') 
global_PI = pickle.load(file)
global_PI = global_PI.toarray()
"""
STAGE 1
"""

"""train with MSSL"""
graph, num_features, ad_concat = dlpfc_loader(dataset_name=args.dataset, pi=global_PI, section_ids=args.section_ids, hvgs=args.hvgs, st_data_dir=st_data_dir)
args.num_features = num_features
x = graph.ndata["feat"]

model = build_model_ST(args)
print(model)

device = args.device if args.device >= 0 else "cpu"
optimizer = create_optimizer(args.optimizer, model, args.lr, args.weight_decay)

if args.scheduler:
    logging.info("Use scheduler")
    scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / args.max_epoch) ) * 0.5
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
else:
    scheduler = None

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 24762 edges, 4221 cells.
5.8664 neighbors per cell on average.


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 25692 edges, 4381 cells.
5.8644 neighbors per cell on average.
=== Use sce_loss and alpha_l=1 ===
num_encoder_params: 1414752, num_decoder_params: 1422840, num_params_in_total: 2875746
PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=2728, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): ELU(alpha=1.0)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=32, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=32, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace

### training with masked reconstruction loss

In [9]:
# Put the model, the graph and the node features on the training device.
model.to(device)
graph = graph.to(device)
x = x.to(device)

# Every node of the joint graph is a reconstruction target.
target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long)
epoch_iter = tqdm(range(args.max_epoch))

for epoch in epoch_iter:
    model.train()
    optimizer.zero_grad()

    # The model's forward pass returns the training loss directly.
    loss = model(graph, x, targets=target_nodes)
    loss_dict = dict(loss=loss.item())

    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

# Extract the stage-1 embedding without tracking gradients.
with torch.no_grad():
    z = model.embed(graph, x)

# Store the embedding under the "MG" key for the next stage.
ad_concat.obsm["MG"] = z.detach().cpu().numpy()

# ari_1_pre = []
# ari_2_pre = []
# ari_ = visualization_umap_spatial(ad_temp=ad_concat, section_ids=args.section_ids, exp_fig_dir=exp_fig_dir, dataset_name=args.dataset_name, num_iter="1", identifier="stage1", num_class=args.num_class, use_key="MG")
# ari_1_pre.append(ari_[0])
# ari_2_pre.append(ari_[1])

# print(args.section_ids[0], ', ARI = %01.3f' % ari_[0])
# print(args.section_ids[1], ', ARI = %01.3f' % ari_[1])

100%|██████████| 2000/2000 [01:50<00:00, 18.18it/s]


### training with triplet loss

In [10]:
"""train with MSSL + triplet loss"""
# logging.info("Keep training Model with cse + triplet loss ")

graph = graph.to(device)
x = x.to(device)

target_nodes = torch.arange(x.shape[0], device=x.device, dtype=torch.long)
epoch_iter = tqdm(range(args.max_epoch))
section_ids = np.array(ad_concat.obs['batch_name'].unique())

"""a list of precomputed pi, it has the same length as the iter_comb"""
# mnn_dict = create_dictionary_otn(adata_concat_, pis_list_, section_ids, batch_name='batch_name', conf_thres = 0.0, mode="normal", verbose = 1, iter_comb=iter_comb)
mnn_dict = create_dictionary_mnn(ad_concat, use_rep="MG", batch_name='batch_name', k=50, iter_comb=None)
anchor_ind = []
positive_ind = []
negative_ind = []
for batch_pair in mnn_dict.keys():  # pairwise compare for multiple batches
    batchname_list = ad_concat.obs['batch_name'][mnn_dict[batch_pair].keys()]
    #             print("before add KNN pairs, len(mnn_dict[batch_pair]):",
    #                   sum(adata_new.obs['batch_name'].isin(batchname_list.unique())), len(mnn_dict[batch_pair]))

    cellname_by_batch_dict = dict()
    for batch_id in range(len(section_ids)):
        cellname_by_batch_dict[section_ids[batch_id]] = ad_concat.obs_names[
            ad_concat.obs['batch_name'] == section_ids[batch_id]].values

    anchor_list = []
    positive_list = []
    negative_list = []
    for anchor in mnn_dict[batch_pair].keys():
        anchor_list.append(anchor)
        ## np.random.choice(mnn_dict[batch_pair][anchor])
        positive_spot = mnn_dict[batch_pair][anchor][0]  # select the first positive spot
        positive_list.append(positive_spot)
        section_size = len(cellname_by_batch_dict[batchname_list[anchor]])
        negative_list.append(
            cellname_by_batch_dict[batchname_list[anchor]][np.random.randint(section_size)])

    batch_as_dict = dict(zip(list(ad_concat.obs_names), range(0, ad_concat.shape[0])))
    anchor_ind = np.append(anchor_ind, list(map(lambda _: batch_as_dict[_], anchor_list)))
    positive_ind = np.append(positive_ind, list(map(lambda _: batch_as_dict[_], positive_list)))
    negative_ind = np.append(negative_ind, list(map(lambda _: batch_as_dict[_], negative_list)))

for epoch in epoch_iter:


    model.train()
    optimizer.zero_grad()

    _loss = model(graph, x, targets=target_nodes)
    with torch.no_grad():
        z = model.embed(graph, x)

    anchor_arr = z[anchor_ind,]
    positive_arr = z[positive_ind,]
    negative_arr = z[negative_ind,]

    triplet_loss = torch.nn.TripletMarginLoss(margin=1, p=2, reduction='mean')
    tri_output = triplet_loss(anchor_arr, positive_arr, negative_arr)

    loss = _loss + tri_output
    loss.backward()
    optimizer.step()

    if scheduler is not None:
        scheduler.step()
    # loss_dict = {"loss": loss.item()}
    epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")

with torch.no_grad():
    z = model.embed(graph, x)

# model.eval()
ad_concat.obsm["MG_triplet"] = z.cpu().detach().numpy()



  0%|          | 0/2000 [00:00<?, ?it/s]

Processing datasets (0, 1)


# Epoch 1643: train_loss: 0.2435:  82%|████████▏ | 1643/2000 [01:40<00:21, 16.54it/s]

### evaluate

In [None]:
# Evaluate the stage-2 ("MG_triplet") embedding. The dataset name is taken from
# args.dataset, the attribute this notebook sets in the HP setup cell and
# passes to the loader; args.dataset_name is never set in this notebook.
ari_ = visualization_umap_spatial(ad_temp=ad_concat, section_ids=section_ids, exp_fig_dir=exp_fig_dir, dataset_name=args.dataset, num_iter="1", identifier="stage2", num_class=args.num_class, use_key="MG_triplet")
# The first two entries of the result are reported as the ARI of each section.
print(section_ids[0], ', ARI = %01.3f' % ari_[0])
print(section_ids[1], ', ARI = %01.3f' % ari_[1])