# Tutorial 2: Integrate DLPFC data with MaskGraphene

In this tutorial, we show how to use MaskGraphene (MG) to integrate DLPFC data. As an example, we analyze the 151507/151508 sample pair from the dorsolateral prefrontal cortex (DLPFC) dataset.

We obtained the data, including the manual annotations, from the spatialLIBD website. Before running the model, please download the input data from Zenodo: [zenodo link]() <!-- TODO: the link URL is missing -->

### Loading packages

In [1]:
import logging
import numpy as np
from tqdm import tqdm
import torch
import pickle
import wandb

from utils import (
    build_args_ST,
    create_optimizer,
    set_random_seed,
    get_current_lr,
)
from datasets.data_proc import load_ST_dataset_hard, load_ST_dataset_hard_erase
from datasets.st_loading_utils import create_dictionary_otn, visualization_umap_spatial, create_dictionary_mnn, cal_layer_based_alignment_result, cal_layer_based_alignment_result_mhypo
from models import build_model_ST
import os
import scanpy as sc
import sklearn.metrics.pairwise

# from maskgraphene_main import main


2024-03-06 16:54:33,522 - INFO - Enabling RDKit 2022.09.5 jupyter extensions


### Hyperparameter setup

In [2]:
args = build_args_ST()

# Tutorial-specific overrides of the defaults returned by build_args_ST().
# They are applied in the order listed below.
hyperparams = {
    "max_epoch": 2000,
    "max_epoch_triplet": 500,
    "section_ids": ["151507", "151508"],
    "num_class": 7,
    "num_hidden": "512,32",
    "alpha_l": 1,
    "lam": 1,
    "loss_fn": "sce",
    "mask_rate": 0.5,
    "in_drop": 0,
    "attn_drop": 0,
    "remask_rate": 0.1,
    "seeds": [2023],
    "num_remasking": 1,
    "hvgs": 3000,
    "dataset": "DLPFC",
    "consecutive_prior": 1,
    "lr": 0.001,
    # Remember to change these paths to your own data path / link (PI) path.
    "st_data_dir": "/home/yunfei/spatial_benchmarking/benchmarking_data/DLPFC12",
    "pi_dir": "/home/yunfei/spatial_dl_integration/MaskGraphene/PI",
}
for name, value in hyperparams.items():
    setattr(args, name, value)

### MG training

### consecutive slices

In [None]:
import dgl
import scipy
import anndata
from datasets.st_loading_utils import load_DLPFC, Cal_Spatial_Net, simple_impute

def dlpfc_loader(dataset_name, pi, section_ids=None, hvgs=5000, st_data_dir="./", rad_cutoff=150):
    """Load a pair of DLPFC sections and build one joint spot graph.

    Steps:
    - Each section is loaded with ``load_DLPFC``.
    - Spot names are suffixed with the section id.
    - Each section gets a spatial network from ``Cal_Spatial_Net``.
    - Both sections are passed, together with ``pi``, through ``simple_impute``.
    - Each section is preprocessed: seurat_v3 highly-variable-gene selection
      (computed before normalisation), ``normalize_total`` to 1e4, ``log1p``,
      then subsetting to the selected genes.
    - The sections are concatenated, and their adjacency matrices are placed on
      a block diagonal.
    - When ``pi`` is given, every ``(i, j)`` with ``pi[i, j] > 0`` adds a
      symmetric edge between spot ``i`` of the first section and spot ``j`` of
      the second section.

    Args:
        dataset_name: Dataset identifier; must contain ``"DLPFC"``.
        pi: Cross-section link matrix of shape (n_spots_section1,
            n_spots_section2), or ``None`` to skip cross-section edges. It is
            also forwarded to ``simple_impute``.
        section_ids: Exactly two section ids. Defaults to
            ``["151507", "151508"]``.
        hvgs: Number of highly variable genes selected per section.
        st_data_dir: Root directory passed to ``load_DLPFC``.
        rad_cutoff: Radius passed to ``Cal_Spatial_Net`` (default 150, as before).

    Returns:
        A tuple ``(graph, num_features, adata_concat)``:
        - ``graph``: a DGL graph whose ``ndata["feat"]`` is ``adata_concat.X``.
        - ``num_features``: the feature dimension.
        - ``adata_concat``: the concatenated AnnData, with ``obs["slice_name"]``
          and ``obs["batch_name"]`` added.

    Raises:
        ValueError: If ``dataset_name`` is not a DLPFC dataset, or if
            ``section_ids`` does not hold exactly two ids.
    """
    if "DLPFC" not in dataset_name:
        # Previously this fell through to an UnboundLocalError on `graph`.
        raise ValueError(f"dlpfc_loader only supports DLPFC datasets, got {dataset_name!r}")
    if section_ids is None:
        section_ids = ["151507", "151508"]
    if len(section_ids) != 2:
        # simple_impute below only ever receives the first two sections.
        raise ValueError(f"dlpfc_loader expects exactly two section ids, got {len(section_ids)}")

    Batch_list = []
    adj_list = []
    for section_id in section_ids:
        ad_ = load_DLPFC(root_dir=st_data_dir, section_id=section_id)
        ad_.var_names_make_unique(join="++")
        # Make spot names unique across sections before concatenation.
        ad_.obs_names = [x + '_' + section_id for x in ad_.obs_names]

        # Spatial network; Cal_Spatial_Net stores it in ad_.uns['adj'].
        Cal_Spatial_Net(ad_, rad_cutoff=rad_cutoff)
        adj_list.append(ad_.uns['adj'])
        Batch_list.append(ad_)

    ad1, ad2 = simple_impute(Batch_list[0], Batch_list[1], pi)

    Batch_list_new = []
    for ad_ in (ad1, ad2):
        # seurat_v3 HVG selection is run before normalisation/log1p.
        sc.pp.highly_variable_genes(ad_, flavor="seurat_v3", n_top_genes=hvgs)
        sc.pp.normalize_total(ad_, target_sum=1e4)
        sc.pp.log1p(ad_)
        Batch_list_new.append(ad_[:, ad_.var['highly_variable']])

    # NOTE: anndata.concat defaults to an inner join on genes, so only genes
    # kept in both sections survive.
    adata_concat = anndata.concat(Batch_list_new, label="slice_name", keys=section_ids, uns_merge="same")
    adata_concat.obs['original_clusters'] = adata_concat.obs['original_clusters'].astype('category')
    adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')

    # Block-diagonal adjacency: within-section edges only at this point.
    adj_concat = scipy.linalg.block_diag(*(np.asarray(adj.todense()) for adj in adj_list))

    if pi is not None:
        n_first = pi.shape[0]
        assert adj_concat.shape[0] == n_first + pi.shape[1], "adj matrix shape is not consistent with the pi matrix"
        # Vectorised equivalent of the former double loop over pi. It links spot i
        # of the first section to spot j of the second section (row n_first + j)
        # wherever pi[i, j] > 0, in both directions.
        rows, cols = np.nonzero(np.asarray(pi) > 0)
        adj_concat[rows, cols + n_first] = 1
        adj_concat[cols + n_first, rows] = 1

    edge_src, edge_dst = np.nonzero(adj_concat)
    # Pass num_nodes explicitly. Otherwise dgl.graph infers the node count from
    # the largest edge endpoint, which would drop trailing isolated spots and
    # break the ndata assignment.
    graph = dgl.graph((edge_src, edge_dst), num_nodes=adj_concat.shape[0])
    X = adata_concat.X
    if hasattr(X, "toarray"):
        # torch.tensor cannot consume a scipy sparse matrix directly.
        X = X.toarray()
    graph.ndata["feat"] = torch.tensor(X)
    num_features = graph.ndata["feat"].shape[1]
    return graph, num_features, adata_concat

In [None]:
# Setup when hard links (a precomputed PI matrix) are provided.

# File / directory paths.
exp_fig_dir = args.exp_fig_dir
st_data_dir = args.st_data_dir
# NOTE: there is no separator between the dataset name and the joined section ids
# (e.g. "DLPFC151507_151508"); this must match the directory name on disk.
pi_dir = os.path.join(args.pi_dir, args.dataset + '_'.join(args.section_ids))

# NOTE(review): pickle.load can execute arbitrary code; only load S.pickle files you trust.
with open(os.path.join(pi_dir, "S.pickle"), 'rb') as pi_file:
    global_PI = pickle.load(pi_file)
# The loaded object exposes .toarray() (presumably a scipy sparse matrix); densify it.
global_PI = global_PI.toarray()

# STAGE 1: train with MSSL.
graph, num_features, ad_concat = dlpfc_loader(dataset_name=args.dataset, pi=global_PI, section_ids=args.section_ids, hvgs=args.hvgs, st_data_dir=st_data_dir)
args.num_features = num_features
x = graph.ndata["feat"]

# args.seeds is configured above but was never applied; seed before the model is
# initialised so that runs are reproducible.
set_random_seed(args.seeds[0])
model = build_model_ST(args)
print(model)

device = args.device if args.device >= 0 else "cpu"
optimizer = create_optimizer(args.optimizer, model, args.lr, args.weight_decay)

if args.scheduler:
    logging.info("Use scheduler")

    def cosine_lr_factor(epoch):
        # Learning-rate multiplier decaying along a half cosine from 1 (epoch 0) to 0 (epoch args.max_epoch).
        return (1 + np.cos(epoch * np.pi / args.max_epoch)) * 0.5

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_lr_factor)
else:
    scheduler = None



### Training

In [None]:
# NOTE(review): MG / MG_triplet are not defined or imported anywhere else in this
# notebook. They presumably live in maskgraphene_main (its import is commented out
# in the first cell); confirm the module and names.
from maskgraphene_main import MG, MG_triplet

# Names used by the training code below. Previously they were undefined in this
# notebook, so the cell failed with NameError.
max_epoch = args.max_epoch
max_epoch_triplet = args.max_epoch_triplet
section_ids = args.section_ids
dataset_name = args.dataset
logger = None  # any object with a .log(dict) method (e.g. a wandb run) enables metric logging
counter = 0  # forwarded to visualization_umap_spatial as num_iter
ari_1_pre, ari_2_pre = [], []
ari_1, ari_2 = [], []

model.to(device)
graph = graph.to(device)
x = x.to(device)

# Stage 1: train with MSSL.
model, ad_concat_1 = MG(model, graph, x, optimizer, max_epoch, device, ad_concat, scheduler, logger=logger, key_="MG")
ari_ = visualization_umap_spatial(ad_temp=ad_concat_1, section_ids=section_ids, exp_fig_dir=exp_fig_dir, dataset_name=dataset_name, num_iter=counter, identifier="stage1", num_class=args.num_class, use_key="MG")
ari_1_pre.append(ari_[0])
ari_2_pre.append(ari_[1])
if logger is not None:
    logger.log({"slice1_ari_pre": ari_[0], "slice2_ari_pre": ari_[1]})
print(section_ids[0], ', ARI = %01.3f' % ari_[0])
print(section_ids[1], ', ARI = %01.3f' % ari_[1])

# Stage 2: keep training with MSSL + triplet loss, using the PI matrix as links.
logging.info("Keep training Model with cse + triplet loss ")
model, ad_concat_2 = MG_triplet(model, graph, x, optimizer, max_epoch_triplet, device, adata_concat_=ad_concat_1, pis_list_=[global_PI], scheduler=scheduler, logger=logger, key_="MG_triplet")
ari_ = visualization_umap_spatial(ad_temp=ad_concat_2, section_ids=section_ids, exp_fig_dir=exp_fig_dir, dataset_name=dataset_name, num_iter=counter, identifier="stage2", num_class=args.num_class, use_key="MG_triplet")
counter += 1
ari_1.append(ari_[0])
ari_2.append(ari_[1])
if logger is not None:
    logger.log({"slice1_ari_after": ari_[0], "slice2_ari_after": ari_[1]})
print(section_ids[0], ', ARI = %01.3f' % ari_[0])
print(section_ids[1], ', ARI = %01.3f' % ari_[1])