In [1]:
from __future__ import division
from __future__ import print_function

import torch
import torch.optim as optim
import numpy as np
from utils import *
from models import DMGCN
import os
from config import Config
from sklearn import metrics
import pandas as pd
import matplotlib.pyplot as plt
import igraph
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Train DMGCN on every slice of the selected data collection, keep the epoch
# with the best ARI against the manual annotation, and save figures plus the
# annotated AnnData object for each slice.
import random

data = "DLPFC"  # optional: 'HBC', 'MOSTA'
datasets = ['151507', '151508', '151509', '151510', '151669', '151670',
            '151671', '151672', '151673', '151674', '151675', '151676']
# Optional dataset lists for the other collections:
#   'HBC':   ['Human_Breast_Cancer']
#   'MOSTA': ['MOSTA']
config_file = f'../ini/{data}.ini'
config = Config(config_file)

# Input / output roots (values unchanged; hoisted so they are easy to edit).
DATA_ROOT = "/gemini/code/generate_data/"
RESULT_ROOT = '../result_test/'


def _seed_everything(seed, use_cuda):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses).
    os.environ['PYTHONHASHSEED'] = str(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode may pick different algorithms from run to run, which
        # defeats the deterministic setting above, so it must be disabled.
        torch.backends.cudnn.benchmark = False


for dataset in datasets:
    print(dataset)
    print("load data:")

    path = os.path.join(DATA_ROOT, dataset, "DMGCN.h5ad")
    adata = sc.read_h5ad(path)
    print(adata)

    features = torch.FloatTensor(adata.X)
    labels = adata.obs['ground']
    fadj = sparse_mx_to_torch_sparse_tensor(adata.obsm['fadj'])
    sadj = sparse_mx_to_torch_sparse_tensor(adata.obsm['sadj'])
    graph_nei = torch.LongTensor(adata.obsm['graph_nei'])
    graph_neg = torch.LongTensor(adata.obsm['graph_neg'])

    plt.rcParams["figure.figsize"] = (3, 3)
    savepath = RESULT_ROOT + dataset + '/'
    # makedirs also creates the missing parent directory and tolerates an
    # already existing target, unlike a guarded os.mkdir.
    os.makedirs(savepath, exist_ok=True)
    title = "Manual annotation (slice #" + dataset + ")"
    sc.pl.spatial(adata, img_key="hires", color=['ground_truth'], title=title,
                  show=False)
    plt.savefig(savepath + 'Manual Annotation.jpg', bbox_inches='tight', dpi=600)
    plt.show()
    plt.close()

    cuda = not config.no_cuda and torch.cuda.is_available()
    # NOTE(review): use_seed is computed but seeding below is unconditional,
    # as in the original cell; confirm whether config.no_seed should be honored.
    use_seed = not config.no_seed

    # Encode the annotation as integer ids to count the target clusters.
    _, ground = np.unique(np.array(labels, dtype=str), return_inverse=True)
    ground = torch.LongTensor(ground)
    n_clusters = len(ground.unique())

    if cuda:
        features = features.cuda()
        sadj = sadj.cuda()
        fadj = fadj.cuda()
        graph_nei = graph_nei.cuda()
        graph_neg = graph_neg.cuda()
    # Single source of truth for where the inputs live, so the losses follow
    # the same CPU/GPU decision instead of assuming 'cuda:0'.
    device = features.device

    _seed_everything(config.seed, cuda)

    print(dataset, ' ', config.lr, ' ', config.alpha, ' ', config.beta, ' ', config.gamma)
    model = DMGCN(nfeat=config.fdim,
                  nhid1=config.nhid1,
                  nhid2=config.nhid2,
                  dropout=config.dropout)

    if cuda:
        model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    # Start below any attainable ARI so the first epoch is always recorded,
    # even when the ARI is zero or negative.
    epoch_max, ari_max, nmi_max = 0, float('-inf'), 0
    pred_max, mean_max, emb_max = [], [], []

    add_contrastive_label(adata)
    features_a = permutation(features)
    pretrain = False

    # Loop-invariant pieces of the contrastive loss, built once per dataset.
    bce_loss = torch.nn.BCEWithLogitsLoss(reduction='sum').to(device)
    label_CSL = torch.FloatTensor(adata.obsm['label_CSL']).to(device)

    for epoch in range(config.epochs + 1):

        model.train()
        optimizer.zero_grad()
        com1, com2, emb, pi, disp, mean, ret_1, ret_2 = model(features, features_a, sadj, fadj)

        cl_loss_1 = bce_loss(ret_1.to(device), label_CSL)
        cl_loss_2 = bce_loss(ret_2.to(device), label_CSL)
        co_loss = config.beta * (cl_loss_1 + cl_loss_2) / 2 + config.gamma * consistency_loss(com1, com2)
        zinb_loss = config.alpha * ZINB(pi, theta=disp, ridge_lambda=0).loss(features, mean, mean=True)
        total_loss = zinb_loss + co_loss

        # NumPy snapshots of this epoch's outputs, with NaNs replaced by 0 so
        # KMeans and the later neighbor graph do not fail on them.
        emb = pd.DataFrame(emb.cpu().detach().numpy()).fillna(0).values
        mean = pd.DataFrame(mean.cpu().detach().numpy()).fillna(0).values
        total_loss.backward()
        optimizer.step()

        kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(emb)

        pred = kmeans.labels_
        ari_res = metrics.adjusted_rand_score(labels, pred)
        nmi_res = metrics.normalized_mutual_info_score(labels, pred)
        # Model selection: keep everything from the epoch with the best ARI.
        if ari_res > ari_max:
            ari_max = ari_res
            nmi_max = nmi_res
            epoch_max = epoch
            pred_max = pred
            mean_max = mean
            emb_max = emb

        print(dataset, ' epoch: ', epoch,
                  ' zinb_loss = {:.2f}'.format(zinb_loss),
                  ' co_loss = {:.2f}'.format(co_loss),
                  ' total_loss = {:.2f}'.format(total_loss),
                  ' ARI = {:.2f}'.format(ari_max),
                  ' NMI = {:.2f}'.format(nmi_max),
             )

    print(dataset, 'ARI', ari_max)
    print(dataset, 'NMI', nmi_max)

    # Store the best-epoch results on the AnnData object.
    title = 'DMGCN(ARI={:.2f})'.format(ari_max)
    adata.uns['ARI'] = ari_max
    adata.uns['NMI'] = nmi_max
    adata.obs['DMGCN'] = pred_max.astype(str)
    adata.obsm['emb'] = emb_max
    adata.obsm['mean'] = mean_max

    sc.pl.spatial(adata, img_key="hires", color=['DMGCN'], title=title, show=False)
    plt.savefig(savepath + 'DMGCN.jpg', bbox_inches='tight', dpi=600)
    plt.show()
    plt.close()

    # UMAP / PAGA view computed on the 'mean' representation stored above.
    sc.pp.neighbors(adata, use_rep='mean')
    sc.tl.umap(adata)
    plt.rcParams["figure.figsize"] = (3, 3)
    sc.tl.paga(adata, groups='DMGCN')
    sc.pl.paga_compare(adata, legend_fontsize=10, frameon=False, size=20, title=title, legend_fontoutline=2, show=False)
    plt.savefig(savepath + 'DMGCN_umap.jpg', bbox_inches='tight', dpi=600)
    plt.show()
    plt.close()

    adata.layers['X'] = adata.X
    adata.layers['mean'] = mean_max
    adata.write(savepath + 'DMGCN.h5ad')
