<a href="https://colab.research.google.com/github/LiuLab-Bioelectronics-Harvard/UnitedNet/blob/main/colab_notebooks/DBiTseq_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the single-cell packages this notebook needs on Colab.
# NOTE(review): versions are unpinned, so results may drift with new anndata/scanpy
# releases; consider pinning them for reproducibility.
!pip install anndata scanpy

In [None]:
# Mount Google Drive so the repository and data under "My Drive" are reachable.
from google.colab import drive
drive.mount('/content/gdrive')
#============================clone the repository to the folder 'UnitedNet_TestCodes' in google drive============
# First run only: uncomment the two lines below to clone the repository into
# 'My Drive/UnitedNet_TestCodes' (git creates the UnitedNet/ subfolder used below).
# %cd /content/gdrive/My Drive/UnitedNet_TestCodes
# ! git clone https://github.com/LiuLab-Bioelectronics-Harvard/UnitedNet.git
# Work from the repository root so the `src.*` imports and the relative ./data paths resolve.
%cd /content/gdrive/My Drive/UnitedNet_TestCodes/UnitedNet

# dbitseq

In [None]:
import anndata as ad
import numpy as np
import scanpy as sc
import pandas as pd


import seaborn as sns
from collections import Counter
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import sys
sys.path.append('..')
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score

from src.interface import UnitedNet
from src.configs import *
from scipy.stats import spearmanr, pearsonr

In [None]:
def split_data(test_batch, adatas=None):
    """Split each modality into train/test parts using ``obs['batch']``.

    Parameters
    ----------
    test_batch
        Value of ``obs['batch']`` that marks the held-out cells.
    adatas : sequence, optional
        AnnData-like objects (RNA, morphology, mRNA niche, in that order) that
        expose ``obs['batch']`` and support boolean-mask indexing. When omitted,
        the module-level ``adata_rna_all``, ``adata_morph_all`` and
        ``adata_mrna_niche_all`` are used, as in the original version.
        NOTE(review): those globals are not defined anywhere in this notebook as
        shown, so calling without ``adatas`` raises NameError here.

    Returns
    -------
    (list, list)
        Train parts and test parts, in the same modality order as the input.
    """
    if adatas is None:
        adatas = [adata_rna_all, adata_morph_all, adata_mrna_niche_all]
    adatas_train = [adata[adata.obs['batch'] != test_batch] for adata in adatas]
    adatas_test = [adata[adata.obs['batch'] == test_batch] for adata in adatas]
    return adatas_train, adatas_test


def concat_adatas(adatas_train, adatas_test):
    """Concatenate matching train/test AnnData objects modality by modality.

    The i-th result is ``ad.concat`` of ``adatas_train[i]`` followed by
    ``adatas_test[i]``; pairing stops at the shorter sequence (``zip`` semantics).
    """
    merged = []
    for pair in zip(adatas_train, adatas_test):
        merged.append(ad.concat(list(pair)))
    return merged


def save_umap(adata_all,label,test_batch,nametype):
    """Draw a UMAP coloured by ``label`` and save it as a 300-dpi PNG.

    The file is written to
    ``{root_save_path}/plot/{test_batch}_{nametype}_{label}.png``. The ``plot``
    sub-directory is created first, because ``savefig`` raises
    FileNotFoundError when the target directory is missing.

    Parameters
    ----------
    adata_all : AnnData
        Passed to ``sc.pl.umap``. It presumably already holds a UMAP embedding
        (TODO confirm at the call site).
    label : str
        Forwarded as ``color=`` to ``sc.pl.umap``; also part of the file name.
    test_batch, nametype
        Used only to build the file name.

    Notes
    -----
    Relies on the module-level ``root_save_path``.
    """
    import os

    plot_dir = os.path.join(root_save_path, 'plot')
    os.makedirs(plot_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6,4))
    sc.pl.umap(adata_all, color=label, ax=ax, show=False)
    fig.savefig(f'{plot_dir}/{test_batch}_{nametype}_{label}.png', dpi=300)
  

def generate_adata(data, nonnan_indices, cell_type_label, cols, rows, batch):
    """Build an AnnData from the rows of ``data`` selected by ``nonnan_indices``.

    Parameters
    ----------
    data : pandas.DataFrame
        Cells x features table; its index supplies the cell names (obs_names).
    nonnan_indices
        Anything that can index ``data.index`` (boolean mask or integer
        positions) to choose the rows to keep.
    cell_type_label, cols, rows
        Per-cell values stored in ``obs['label']``, ``obs['imagecol']`` and
        ``obs['imagerow']``. They must line up with the selected rows.
    batch
        Stored in ``obs['batch']``.

    Returns
    -------
    anndata.AnnData
    """
    data = data.loc[data.index[nonnan_indices]]
    # AnnData expects ``obs`` to be a DataFrame or mapping; a bare list of names is
    # not turned into obs_names. Build an index-only frame instead, with string
    # names as AnnData requires.
    obs = pd.DataFrame(index=data.index.astype(str))
    adata = ad.AnnData(X=np.array(data), obs=obs)
    adata.obs['label'] = cell_type_label
    adata.obs['imagecol'] = cols
    adata.obs['imagerow'] = rows
    adata.obs['batch'] = batch
    return adata

In [None]:
from sklearn import preprocessing
def change_label(adata,batch):
    """Copy the spatial and cell-type annotations into the column names used downstream.

    Writes into ``adata.obs`` in place:
    ``batch`` <- the ``batch`` argument, ``imagecol`` <- ``array_col``,
    ``imagerow`` <- ``array_row``, ``label`` <- ``cell_type``.
    Returns the same (mutated) object so the call can be used in an assignment.
    """
    adata.obs['batch'] = batch
    column_map = {
        'imagecol': 'array_col',
        'imagerow': 'array_row',
        'label': 'cell_type',
    }
    for target, source in column_map.items():
        adata.obs[target] = adata.obs[source]
    return adata

In [None]:
def pre_ps(adata_list,sc_pre = None):
    """Scale every modality feature-wise on copies of the inputs.

    Parameters
    ----------
    adata_list : list
        AnnData-like objects with a ``.copy()`` method and a dense ``.X``.
        The inputs themselves are not modified.
    sc_pre : sequence of fitted scalers, optional
        One object with a ``transform`` method per modality, for example the
        scalers returned by an earlier call on the training split. When given,
        they are applied as-is. Otherwise a new
        ``sklearn.preprocessing.StandardScaler`` is fitted on each modality's
        own ``.X``.

    Returns
    -------
    (list, list)
        The scaled copies and the scalers that were applied, in input order.

    Raises
    ------
    AssertionError
        If the first modality contains negative values.
    """
    adata_list_all = [ad_x.copy() for ad_x in adata_list]
    scalars = []
    # Only the first modality is checked. Presumably it holds raw, non-negative
    # expression values. TODO confirm whether the others should be checked too.
    assert (adata_list_all[0].X>=0).all(), "poluted input"
    for idx, mod in enumerate(adata_list_all):
        # `is not None` rather than `!= None`: with an ndarray of scalers, `!=` is
        # element-wise and the `if` would raise "truth value ... is ambiguous".
        if sc_pre is not None:
            scaler = sc_pre[idx]
        else:
            scaler = preprocessing.StandardScaler().fit(mod.X)
        # `mod` is the copy stored in adata_list_all, so assigning X is enough.
        mod.X = scaler.transform(mod.X)
        scalars.append(scaler)

    return adata_list_all,scalars

# Load the DBiT-seq data

In [None]:
technique = 'dbitseq'
#========================change for colab===========================
#data_path = f"../data/{technique}"
data_path = f"./data/{technique}"
# Use the GPU when the runtime has one and fall back to the CPU otherwise. A
# hard-coded "cuda:0" fails on CPU-only Colab runtimes as soon as anything is
# moved to that device.
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"

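Add a shortcut to the `dbitseq` data folder shared at https://drive.google.com/drive/folders/1Aj01ufOiDrdCRYe_7wvLAC9tUs1LGzGj?usp=sharing. Place it inside `My Drive/UnitedNet_TestCodes/UnitedNet/data` so that it resolves to `./data/dbitseq` from the repository root the notebook has `%cd`-ed into.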

In [None]:
# Each modality ships as a pre-split train/test pair of .h5ad files under data_path.
adata_rna_train = sc.read_h5ad(f'{data_path}/adata_rna_train.h5ad')
adata_protein_train = sc.read_h5ad(f'{data_path}/adata_protein_train.h5ad')
adata_niche_rna_train = sc.read_h5ad(f'{data_path}/adata_niche_rna_train.h5ad')

adata_rna_test = sc.read_h5ad(f'{data_path}/adata_rna_test.h5ad')
adata_protein_test = sc.read_h5ad(f'{data_path}/adata_protein_test.h5ad')
adata_niche_rna_test = sc.read_h5ad(f'{data_path}/adata_niche_rna_test.h5ad')


In [None]:
# Copy array_col/array_row/cell_type into imagecol/imagerow/label and tag each split.
adata_rna_train = change_label(adata_rna_train,'train')
adata_protein_train = change_label(adata_protein_train,'train')
adata_niche_rna_train = change_label(adata_niche_rna_train,'train')

adata_rna_test = change_label(adata_rna_test,'test')
adata_protein_test = change_label(adata_protein_test,'test')
adata_niche_rna_test = change_label(adata_niche_rna_test,'test')

# Modality order (RNA, protein, RNA niche) must be the same in every list.
adatas_train = [adata_rna_train, adata_protein_train, adata_niche_rna_train]
adatas_test = [adata_rna_test, adata_protein_test, adata_niche_rna_test]

# Train+test concatenation per modality, used for fine-tuning and prediction.
# `concatenate` records the split of origin in obs['sample'] ('0' = train, '1' = test).
# obs['batch'] is then overwritten with 'test' for every cell.
# NOTE(review): confirm that labelling all cells 'test' is intended.
adatas_all = []
for ad_train, ad_test in zip(adatas_train,adatas_test):
    ad_all = ad_train.concatenate(ad_test,batch_key='sample')
    ad_all = change_label(ad_all,'test')
    adatas_all.append(ad_all)
# The combined data is standardized with its own statistics.
adatas_all,_ = pre_ps(adatas_all)

# Fit the scalers on the training split only and reuse them for the test split.
# This keeps one feature scale for both splits and stops test statistics from leaking
# into it. It assumes each modality has the same features in train and test.
adatas_train, scalers_train = pre_ps(adatas_train)
adatas_test,_ = pre_ps(adatas_test, sc_pre=scalers_train)


# Train and finetune UnitedNet

In [None]:
# True: train on adatas_train and fine-tune on adatas_all in the next cell.
# False: the next cell only builds the model; the prediction cell loads
# f"{root_save_path}/train_best.pt" in either case.
train_model: bool = True

In [None]:
# Settings shared by both paths. root_save_path is also where the prediction
# cell looks for train_best.pt.
# NOTE(review): dbitseq results are saved under ".../dlpfc", which looks copied from
# the DLPFC notebook. Confirm before renaming, because loading depends on this path.
technique = 'dbitseq'
#========================change for colab===========================
#data_path = f"../data/{technique}"
#root_save_path = f"../saved_results/dlpfc"
data_path = f"./data/{technique}"
root_save_path = f"./saved_results/dlpfc"
model = UnitedNet(root_save_path, device=device, technique=dbitseq_config)

if train_model:
    # Train on the training split, then fine-tune on train+test combined.
    model.train(adatas_train,verbose=True)
    model.finetune(adatas_all,verbose=True)



# Predict on all data

In [None]:

# Rebuild the model and restore the checkpoint saved during training.
model = UnitedNet(root_save_path, device=device, technique=dbitseq_config)
checkpoint_path = f"{root_save_path}/train_best.pt"
model.load_model(checkpoint_path)
predict_label = model.predict_label(adatas_all)

# Agreement between the predicted clusters and the annotated cell types.
ari = adjusted_rand_score(adatas_all[0].obs['cell_type'], predict_label)
print(root_save_path, 'ari:', ari)

In [None]:
from src.data import create_dataloader
# Evaluation loader: fixed order (no shuffling) and the model's training batch size.
# NOTE(review): despite its name, this wraps adatas_train, so the metrics below
# describe the training split. Use adatas_test if held-out metrics were intended.
eval_batch_size = model.model.config["train_batch_size"]
dataloader_test = create_dataloader(
    model.model,
    adatas_train,
    batch_size=eval_batch_size,
    shuffle=False,
)

In [None]:
import torch
# Overwrite the model's best_head with index 7, stored as a tensor. Presumably
# run_evaluate in the next cell uses it to pick the output head.
# NOTE(review): hard-coded choice; document why head 7 is used.
forced_head = 7
model.model.best_head = torch.tensor(forced_head)

In [None]:
from src.scripts import run_evaluate
# Evaluate the model over dataloader_test; the returned metrics are shown in the next cell.
metrics = run_evaluate(model.model, dataloader_test)

In [None]:
# Display the metrics returned by run_evaluate.
metrics

In [None]:
# Spatially smooth the predicted clusters: each spot takes the majority label among
# its 5 nearest spots on the (array_row, array_col) grid. The query spot is normally
# one of them, because kneighbors is given the fitted points themselves.

coord=np.array((list(adatas_all[0].obs['array_row'].astype('int')),
                list(adatas_all[0].obs['array_col'].astype('int')))).T

united_clus=list(predict_label)
# Built once, instead of converting the whole list again for every spot.
united_clus_arr = np.array(united_clus)

nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(coord)
distances,indices = nbrs.kneighbors(coord)

# Majority vote. Counter.most_common keeps first-seen order among ties, so a tie
# goes to the label that appears first in the neighbour list.
united_clus_new = [Counter(united_clus_arr[neighbour_idx]).most_common(1)[0][0]
                   for neighbour_idx in indices]

# Fixed colour order for the cluster plot: 10 colours picked from tab20.
cluster_pl = sns.color_palette('tab20',20)
color_list = [cluster_pl[i] for i in (5, 1, 2, 4, 11, 6, 3, 7, 8, 0)]

In [None]:
#================================================================================================
# for import errors in colab, restart runtime and run the codes again
#================================================================================================
# Spatial map of the smoothed clusters on the (array_row, array_col) grid.
# sorted() gives a deterministic cluster -> colour assignment.
cluster_ids = sorted(set(united_clus_new))
assert len(cluster_ids) <= len(color_list), "more clusters than colours in color_list"
# united_clus_new is a plain list, so `united_clus_new == clus_id` is one bool,
# not a per-spot mask. Compare against an array to get one bool per spot.
labels_smoothed = np.asarray(united_clus_new)
spot_rows = adatas_all[0].obs['array_row']
spot_cols = adatas_all[0].obs['array_col']

plt.figure(figsize=(6,5))
for idx, clus_id in enumerate(cluster_ids):
    in_cluster = labels_smoothed == clus_id
    # No `cmap` here: matplotlib ignores it when an explicit `color` is given.
    plt.scatter(spot_rows[in_cluster], spot_cols[in_cluster],
                color=color_list[idx], label=str(clus_id))
plt.axis('off')
plt.legend(title='cluster', loc='upper left', bbox_to_anchor=(1.0, 1.0), frameon=False)
# plt.savefig('dbitseq.png',dpi=300)
plt.show()