In [33]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
import datetime
import argparse,copy
from sklearn.model_selection import train_test_split
from torch.utils.data import  DataLoader
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt 
from sklearn.metrics import accuracy_score, f1_score
from torch.optim.lr_scheduler import StepLR
from sklearn.manifold import TSNE
import torch.utils.data as Data

import warnings
import sklearn.exceptions
warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)

import time

import sys
sys.path.append('./ScanNet')
from scannet import *
from datasets import GCNDataset

In [34]:
def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string (so ``--cuda False`` would enable CUDA). This parser
    accepts the usual textual spellings instead.

    Args:
        value: a bool, or a string such as 'true', 'false', 'yes', 'no', '1', '0'.

    Returns:
        The corresponding bool.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False, which is what bool('') returned before.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Experiment configuration. The notebook parses an empty argument list, so the
# defaults below are the values actually used unless edited here.
parser = argparse.ArgumentParser(description='description of')
parser.add_argument('--dataset', default='Muraro', type=str)
parser.add_argument('--cross_protocol', default=False, type=_str2bool)
parser.add_argument('--lr', default=0.01, type=float, help='Initial learning rate')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay (L2 loss on parameters)')
parser.add_argument('--type_fusion', default='att', type=str, help='fusion method')
parser.add_argument('--type_att_size', default=32, type=int, help='attention parameter dimension')
parser.add_argument('--cuda', default=True, type=_str2bool, help='cpu or gpu')
parser.add_argument('--epochs', default=60, type=int, help='Number of epoch')
parser.add_argument('--batch_size', default=64, type=int, help='Number of batch size')
parser.add_argument('--in_dim', default=1, type=int, help='dim of input')
# args=[] keeps argparse from reading the Jupyter kernel's own sys.argv.
args = parser.parse_args(args=[])

In [35]:
# Use the GPU only when it was requested and is actually available; otherwise run on the CPU.
use_gpu = bool(args.cuda) and torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')

In [36]:
def loadbench(dataset,
              data_dir='./data/pbmcbench/data_pbmcbench.csv',
              adj_dir='./data/pbmcbench/cpmadj_tf_gene.csv',
              valid_size=0.2,
              random_state=42):
    """Load the cross-protocol benchmark and split it into train/valid/test.

    Rows whose 'protocol' value is listed in `dataset` form the supervised
    pool (split further into train and validation, stratified by label);
    every other row becomes the test set.

    Args:
        dataset: protocol name(s) used for training, joined with '+'
            (e.g. 'A+B').
        data_dir: CSV with one row per sample. The first `tf_num` columns are
            read as TF features and the next `gene_num` columns as gene
            features; 'protocol' and 'cell_type_label' columns are required.
        adj_dir: CSV holding the TF-by-gene adjacency matrix; its row count is
            used as `tf_num` and its column count as `gene_num`.
        valid_size: fraction of the supervised pool held out for validation.
        random_state: seed for the train/validation split.

    Returns:
        (ft_dict_train, ft_dict_valid, ft_dict_test,
         label_train, label_valid, label_test, adj_tf_gene), where each
        ft_dict_* is a list of {'tf': (tf_num, 1) tensor, 'gene': (gene_num, 1)
        tensor} dicts and adj_tf_gene is a float32 tensor.

    Raises:
        ValueError: if no row matches the requested protocol(s).
    """
    protocols=dataset.split('+')
    data=pd.read_csv(data_dir,index_col=0,header=0)
    in_train=data['protocol'].isin(protocols)
    train_df=data[in_train]
    test_df=data[~in_train]
    if train_df.empty:
        raise ValueError('no samples found for protocol(s) {!r} in {}'.format(protocols,data_dir))

    adj_frame=pd.read_csv(adj_dir,index_col=0,header=0)
    tf_num,gene_num=adj_frame.shape[0],adj_frame.shape[1]
    adj_tf_gene=torch.tensor(adj_frame.values,dtype=torch.float32)

    def _rows_to_ft_dicts(frame):
        # One feature dict per sample; each feature vector becomes a column of shape (n, 1).
        tf_values=frame.iloc[:,:tf_num].values
        gene_values=frame.iloc[:,tf_num:tf_num+gene_num].values
        return [
            {'tf':torch.tensor(tf_row.reshape(-1,1),dtype=torch.float32),
             'gene':torch.tensor(gene_row.reshape(-1,1),dtype=torch.float32)}
            for tf_row,gene_row in zip(tf_values,gene_values)
        ]

    ft_dict_test=_rows_to_ft_dicts(test_df)
    label_test=torch.tensor(test_df['cell_type_label'].values,dtype=torch.int64)

    ft_dict_sup=_rows_to_ft_dicts(train_df)
    label_sup=torch.tensor(train_df['cell_type_label'].values,dtype=torch.int64)

    ft_dict_train,ft_dict_valid,label_train,label_valid = train_test_split(
        ft_dict_sup,label_sup,test_size=valid_size,random_state=random_state,stratify=label_sup)
    return ft_dict_train,ft_dict_valid,ft_dict_test,label_train,label_valid,label_test,adj_tf_gene


In [37]:
if args.cross_protocol:
    # Cross-protocol benchmark: train on the named protocol(s), test on the rest.
    (ft_dict_train,ft_dict_valid,ft_dict_test,
     label_train,label_valid,label_test,adj_tf_gene)=loadbench(args.dataset)
else:
    # Single-dataset setting: load the pre-built tensors for args.dataset.
    # NOTE(review): torch.load unpickles the files - only load trusted data.
    load_path=f'./data/{args.dataset}/'
    ft_dict_list,label,adj_tf_gene=(
        torch.load(load_path+file_name)
        for file_name in ('logft_dict_list.pt','label.pt','cpmadj_tf_gene.pt')
    )
    # 70/30 supervised/test split, then 80/20 train/valid, both stratified by label.
    ft_dict_sup,ft_dict_test,label_sup,label_test=train_test_split(
        ft_dict_list,label,
        test_size=0.3,random_state=42,stratify=label)
    ft_dict_train,ft_dict_valid,label_train,label_valid=train_test_split(
        ft_dict_sup,label_sup,
        test_size=0.2,random_state=42,stratify=label_sup)


  ft_dict_list=torch.load(load_path+'logft_dict_list.pt')
  label=torch.load(load_path+'label.pt')
  adj_tf_gene=torch.load(load_path+'cpmadj_tf_gene.pt')


In [38]:
# Build the initial adj_dict: the TF-gene adjacency normalised on both sides
# as D_tf^-1/2 * A * D_gene^-1/2 (degrees are sums of absolute edge weights).
adj_dict={'tf':None,'gene':None}
adj_gene_tf=adj_tf_gene.T


def _inv_sqrt_degree(degree):
    """Return degree ** -0.5, mapping zero-degree (isolated) nodes to 0.

    Without the guard a zero degree yields inf, and inf * 0 inside the matrix
    product below turns the whole row/column of that node into NaN.
    """
    inv=torch.pow(degree,-0.5)
    return torch.where(torch.isinf(inv),torch.zeros_like(inv),inv)


degree_tf=torch.abs(adj_tf_gene).sum(dim=1)      # one degree per TF (row sums)
degree_tf_inv=_inv_sqrt_degree(degree_tf)
degree_gene=torch.abs(adj_tf_gene).sum(dim=0)    # one degree per gene (column sums)
degree_gene_inv=_inv_sqrt_degree(degree_gene)
D_tf_inv=torch.diag_embed(degree_tf_inv)
D_gene_inv=torch.diag_embed(degree_gene_inv)
adj_dict['tf']=torch.matmul(torch.matmul(D_tf_inv,adj_tf_gene),D_gene_inv)

# The gene-side adjacency is the transpose of the TF-side one.
adj_dict['gene']=adj_dict['tf'].t()

In [None]:
# Number of classes = distinct labels over *all* splits. Counting the test
# labels alone under-sizes the classifier whenever a class seen during
# training happens to be absent from the test split (possible in the
# cross-protocol setting).
unique_elements = torch.unique(
    torch.cat([torch.as_tensor(split_label).reshape(-1)
               for split_label in (label_train,label_valid,label_test)]),
    return_inverse=False)
args.class_num=len(unique_elements)
# dataloader
dataset={'train':None,'valid':None,'test':None}
dataset['train']=GCNDataset(args,ft_dict_train,label_train)
dataset['valid']=GCNDataset(args,ft_dict_valid,label_valid)

# The 'test' entries are filled in by the evaluation cell.
dataloader={'train':None,'valid':None,'test':None}
dataloader['train']=DataLoader(dataset['train'],batch_size=args.batch_size,shuffle=True,pin_memory=True)
dataloader['valid']=DataLoader(dataset['valid'],batch_size=args.batch_size,shuffle=True,pin_memory=True)
# Number of training samples; used as the step count of one epoch by the LR schedule.
args.train_size=label_train.shape[0]

In [40]:
# Model and optimizer
# Each node type exchanges messages with the other type only.
net_schema = {'tf':['gene'],'gene':['tf']}

# Node counts are read off the first test sample.
sample_ft = ft_dict_test[0]
tf_nodes = len(sample_ft['tf'])
gene_nodes = len(sample_ft['gene'])
type_nodes = {'tf':tf_nodes,'gene':gene_nodes}
all_nodes = tf_nodes+gene_nodes

# Input dim -> three hidden widths -> one output unit per class.
layer_shape = [args.in_dim,8,16,32,args.class_num]

model = ScanNet(
    net_schema=net_schema,
    layer_shape=layer_shape,
    all_nodes=all_nodes,
    tf_nodes=tf_nodes,
    type_fusion=args.type_fusion,
    type_att_size=args.type_att_size,
)
model = model.to(device)
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=0.9)

In [41]:
# Learning-rate schedule state: the rate is multiplied by `decay` once per
# `decay_steps` processed samples (decay_steps = one pass over the training set).
global_step = 0
decay = 0.95
decay_steps = args.train_size


def adjust_learning_rate(optimizer, lr):
    """Set every param group of `optimizer` to the exponentially decayed rate.

    Args:
        optimizer: optimizer whose param groups are updated in place.
        lr: base (initial) learning rate.

    Returns:
        The decayed learning rate that was applied.
    """
    completed_periods = float(global_step // decay_steps)  # decay by one epoch
    lr = lr * decay ** completed_periods
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


# Best validation score seen so far and the weights that produced it.
best_val_f1 = 0.0
best_model_weights = None

In [42]:
def _prepare_batch(batch):
    """Move one (feature dict, label) batch to `device`.

    Returns:
        (ft_dict, label, ft) where `ft` is the 'tf' and 'gene' features
        concatenated along dim 1 with the trailing dim 2 squeezed out; it is
        passed to model.loss together with the model outputs.
    """
    ft_dict,label=batch
    for k in ft_dict.keys():
        ft_dict[k]=ft_dict[k].to(device)
    label=label.to(device)
    ft=torch.cat((ft_dict['tf'],ft_dict['gene']),dim=1).squeeze(2)
    return ft_dict,label,ft


def train():
    """Run one training epoch followed by a full validation pass.

    Side effects (module-level state):
        * global_step advances by the number of samples actually processed,
          so the learning rate decays exactly once per epoch even when the
          last batch is smaller than args.batch_size.
        * best_val_f1 / best_model_weights are updated whenever the macro-F1
          on the validation set improves. The weights are deep-copied:
          model.state_dict() returns references to the live parameters, which
          later optimizer steps would otherwise overwrite.

    Returns:
        (epoch_loss_train, epoch_loss_val, accuracy_val, f1_val, cur_lr):
        mean per-batch train and validation loss, validation accuracy,
        validation macro-F1 and the learning rate used for this epoch.
    """
    global global_step,best_val_f1,best_model_weights

    model.train()
    for k in adj_dict.keys():
        adj_dict[k]=adj_dict[k].to(device)
    cur_lr = adjust_learning_rate(optimizer, args.lr)

    # ---- training pass ----
    total_loss = 0
    for imbalanced_batch in dataloader['train']:
        optimizer.zero_grad()
        im_ft_dict,im_label,ft=_prepare_batch(imbalanced_batch)

        logits,gnn_re,cell_embd=model(im_ft_dict,adj_dict)
        loss=model.loss(gnn_re,ft,logits,im_label,args,cell_embd)

        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.item()
        # Count real samples: the final batch may be shorter than batch_size.
        global_step += im_label.shape[0]

    epoch_loss_train=total_loss/len(dataloader['train'])

    # ---- validation pass ----
    model.eval()
    total_loss_val = 0
    val_labels = []
    val_predictions = []
    with torch.no_grad():
        for batch in dataloader['valid']:
            ft_dict,label,ft=_prepare_batch(batch)

            logits,gnn_re,cell_embd=model(ft_dict,adj_dict)
            loss=model.loss(gnn_re,ft,logits,label,args,cell_embd)
            total_loss_val = total_loss_val + loss.item()

            val_labels.extend(label.cpu().numpy())
            val_predictions.extend(F.softmax(logits,-1).detach().cpu().numpy())

    epoch_loss_val=total_loss_val/len(dataloader['valid'])
    val_labels = np.array(val_labels)
    val_predictions = np.argmax(np.array(val_predictions),axis=1)

    accuracy_val = accuracy_score(val_labels, val_predictions)
    f1_val = f1_score(val_labels, val_predictions,average='macro')

    # Model selection is on macro-F1, so the tracked best must be F1 as well.
    if f1_val>best_val_f1:
        best_val_f1=f1_val
        best_model_weights=copy.deepcopy(model.state_dict())  # snapshot, not a live reference

    return epoch_loss_train,epoch_loss_val,accuracy_val,f1_val,cur_lr


In [43]:
## train
# Per-epoch history of the values returned by train().
train_loss_list,val_loss_list,val_acc_list,val_f1_list=[],[],[],[]
for epoch in range(args.epochs):
    epoch_stats=train()
    train_loss,val_loss,val_acc,val_f1,cur_lr=epoch_stats
    histories=(train_loss_list,val_loss_list,val_acc_list,val_f1_list)
    for history,value in zip(histories,epoch_stats):
        history.append(value)
    # Report every fifth epoch, starting with the first.
    if not epoch%5:
        print('train loss:',train_loss,'valid loss:',val_loss,'acc:',val_acc,'f1_macro:',val_f1,'learning rate:',cur_lr)

train loss: 18.275768079255755 valid loss: 15.551719284057617 acc: 0.8754208754208754 f1_macro: 0.6393104968723973 learning rate: 0.01
train loss: 4.064474582672119 valid loss: 4.580415058135986 acc: 0.9730639730639731 f1_macro: 0.9687376584270809 learning rate: 0.007737809374999998
train loss: 3.9953305972249886 valid loss: 4.6282837867736815 acc: 0.9764309764309764 f1_macro: 0.9751958739958154 learning rate: 0.005987369392383787
train loss: 3.970906295274433 valid loss: 4.402698040008545 acc: 0.9764309764309764 f1_macro: 0.9751958739958154 learning rate: 0.00463291230159753
train loss: 3.948802057065462 valid loss: 4.3958250999450685 acc: 0.9764309764309764 f1_macro: 0.9751958739958154 learning rate: 0.0035848592240854188
train loss: 3.9343292838648747 valid loss: 4.405804443359375 acc: 0.9764309764309764 f1_macro: 0.9751958739958154 learning rate: 0.002773895731218338
train loss: 3.9181151892009534 valid loss: 4.351572799682617 acc: 0.9764309764309764 f1_macro: 0.9751958739958154 le

In [None]:
# test
dataset['test']=GCNDataset(args,ft_dict_test,label_test)
dataloader['test']=DataLoader(dataset['test'],batch_size=args.batch_size,shuffle=False,pin_memory=True)

# Evaluate the checkpoint selected on the validation set.
if best_model_weights is None:
    raise RuntimeError('best_model_weights is empty - run the training cell before testing')
model.load_state_dict(best_model_weights)
model.eval()
all_cell_embd=[]
all_tf_embd=[]
all_tg_embd=[]
all_labels=[]
all_predictions=[]
# The adjacency does not change between batches: move it to the device once.
for k in adj_dict.keys():
    adj_dict[k]=adj_dict[k].to(device)
with torch.no_grad():
    for i, batch in enumerate(dataloader['test']):
        ft_dict,label=batch
        for k in ft_dict.keys():
            ft_dict[k] = ft_dict[k].to(device)
        label=label.to(device)

        logits,output_re,cell_embd=model(ft_dict,adj_dict)

        all_cell_embd.extend(cell_embd.cpu().numpy())
        # reshape(-1) keeps a 1-D array even for a final batch of one sample;
        # squeeze() would give a 0-d array there, which extend() cannot iterate.
        all_labels.extend(label.reshape(-1).detach().cpu().numpy())
        all_predictions.extend(F.softmax(logits,-1).detach().cpu().numpy())

test_labels=np.array(all_labels)
test_pred = np.argmax(np.array(all_predictions),axis=1)   # predicted class per sample
test_predictions=np.array(all_predictions)                # per-class softmax scores
test_embd=np.array(all_cell_embd)


  best_model_weights=torch.load('/home/lyy/oie-HGCN/results/Muraro/HGCN_pure_20251204_1917.pth')


In [48]:
from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score

def evaluate(y_true,y_pred,y_score):
    """Print weighted F1, precision and recall plus the AUPRC.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, scored with the support-weighted average.
        y_score: prediction scores handed to average_precision_score.

    Returns:
        None; the four metrics are printed on one line with four decimals.
    """
    weighted_scorers=(('f1',f1_score),('precision',precision_score),('recall',recall_score))
    weighted={metric:scorer(y_true,y_pred,average='weighted') for metric,scorer in weighted_scorers}
    # Compute AUPRC
    auprc=average_precision_score(y_true,y_score)
    report="f1:{:.4f}, precision:{:.4f},recall:{:.4f}, AUPRC:{:.4f}"
    print(report.format(weighted['f1'],weighted['precision'],weighted['recall'],auprc))

In [49]:
# Report the test-set metrics: true labels, arg-max predictions and the softmax scores.
evaluate(y_true=test_labels,y_pred=test_pred,y_score=test_predictions)

f1:0.9718, precision:0.9726,recall:0.9717, AUPRC:0.9869


In [None]:
# save model
import os

# Checkpoints are stored as ./results/<dataset>/<run>/<run>.pth, where <run>
# carries a minute-resolution timestamp.
string='ScanNet_{}'.format(datetime.datetime.now().strftime('%Y%m%d_%H%M'))
path='./results/{}/'.format(args.dataset)
os.makedirs(path,exist_ok=True)
save_dir=os.path.join(path,string)
print(save_dir)
os.makedirs(save_dir,exist_ok=True)
checkpoint_file=os.path.join(save_dir,string+'.pth')
torch.save(model.state_dict(),checkpoint_file)