In [2]:
import scanpy as sc
from torch.utils.tensorboard import SummaryWriter

import argparse
import sys
import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, remove_self_loops, add_self_loops, subgraph, k_hop_subgraph
from torch_scatter import scatter
from sklearn.neighbors import kneighbors_graph

# 因为修改了dataset文件，直接import得到的一直是之前未修改的代码，所以要reload一下，先import，再reload
import dataset
import importlib
importlib.reload(dataset)

import logger
importlib.reload(logger)

from logger import Logger
from dataset import input_dataset
from data_utils import load_fixed_splits, adj_mul, get_gpu_memory_map, to_sparse_tensor

import eval
importlib.reload(eval)

from eval import evaluate_cpu, eval_acc, eval_rocauc, eval_f1
from parse import parse_method, parser_add_main_args
import time

import warnings
warnings.filterwarnings('ignore')

# NOTE: for consistent data splits, see data_utils.rand_train_test_idx
def fix_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), and asks cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU rather than only the current one; it is
    # silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run, which
    # would undermine the deterministic flag above.
    torch.backends.cudnn.benchmark = False

In [4]:
# NOTE(review): this parser is not used anywhere in the visible cells; `args` is built
# directly from the `default_args` dict below instead of being parsed.
parser = argparse.ArgumentParser()

from nodeformer import *
from data_utils import normalize

# All run options live in one plain dict; it is turned into a Namespace so
# the rest of the notebook can read them as attributes (args.lr, args.tau, ...).
default_args = dict(
    method='nodeformer',
    dataset='data',
    sub_dataset='',
    data_dir='../data/asaprpca/',  # location of the input data
    device=0,
    seed=42,
    epochs=5000,
    eval_step=10,
    cpu=False,
    runs=1,
    train_prop=1,
    valid_prop=0,
    protocol='semi',
    rand_split=True,
    rand_split_class=False,
    label_num_per_class=20,
    metric='acc',
    knn_num=5,
    save_model=True,
    model_dir='../model/',
    hidden_channels=128,
    dropout=0.0,
    lr=1e-4,
    weight_decay=1e-3,
    num_layers=2,
    num_heads=1,
    M=35,
    use_gumbel=True,
    use_residual=True,
    use_bn=True,
    use_act=False,
    use_jk=False,
    K=15,
    tau=0.25,
    lamda=0.4,
    rb_order=2,
    rb_trans='sigmoid',
    batch_size=3000,
    hops=1,
    cached=False,
    gat_heads=2,
    out_heads=1,
    projection_matrix_type=True,
    lp_alpha=0.1,
    gpr_alpha=0.1,
    directed=False,
    jk_type='max',
    num_mlp_layers=1,
    num_batch=2,
)

args = argparse.Namespace(**default_args)
print(args)

##################### everything above sets up the (command-line style) arguments #####################
fix_seed(args.seed)

# Fall back to the CPU when it is forced via args.cpu or when CUDA is missing;
# otherwise use the GPU whose index is args.device.
if args.cpu or not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = torch.device(f"cuda:{args.device}")

Namespace(K=15, M=35, batch_size=3000, cached=False, cpu=False, data_dir='../data/asaprpca/', dataset='data', device=0, directed=False, dropout=0.0, epochs=5000, eval_step=10, gat_heads=2, gpr_alpha=0.1, hidden_channels=128, hops=1, jk_type='max', knn_num=5, label_num_per_class=20, lamda=0.4, lp_alpha=0.1, lr=0.0001, method='nodeformer', metric='acc', model_dir='../model/', num_batch=2, num_heads=1, num_layers=2, num_mlp_layers=1, out_heads=1, projection_matrix_type=True, protocol='semi', rand_split=True, rand_split_class=False, rb_order=2, rb_trans='sigmoid', runs=1, save_model=True, seed=42, sub_dataset='', tau=0.25, train_prop=1, use_act=False, use_bn=True, use_gumbel=True, use_jk=False, use_residual=True, valid_prop=0, weight_decay=0.001)


In [5]:
%%time
import time

### Load and preprocess data ###
# NOTE(review): this rebinds `dataset`, shadowing the `dataset` module that the
# first cell imports and reloads under the same name.
dataset = input_dataset(args.data_dir, args.dataset)

# Keep labels as a column: a 1-D label tensor gets a trailing axis of size 1.
if dataset.label.ndim == 1:
    dataset.label = dataset.label.unsqueeze(1)

CPU times: user 7.45 s, sys: 3.73 s, total: 11.2 s
Wall time: 7.94 s


In [6]:
# One split per run; which source is used depends on the split options in args.
if args.rand_split:
    split_idx_lst = [
        dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop)
        for _ in range(args.runs)
    ]
elif args.rand_split_class:
    split_idx_lst = [
        dataset.get_idx_split(split_type='class', label_num_per_class=args.label_num_per_class)
        for _ in range(args.runs)
    ]
elif args.dataset in ('ogbn-proteins', 'ogbn-arxiv', 'ogbn-products', 'amazon2m'):
    split_idx_lst = [dataset.load_fixed_splits() for _ in range(args.runs)]
else:
    # NOTE(review): `dataset` is passed positionally and again as a keyword here;
    # confirm this matches the signature of data_utils.load_fixed_splits.
    split_idx_lst = load_fixed_splits(args.data_dir, dataset, dataset=args.dataset, protocol=args.protocol)

n = dataset.graph['num_nodes']
d = dataset.graph['node_feat'].shape[1]
# Number of classes: largest integer label + 1 for index labels, or the label
# width for one-hot labels, whichever is larger.
c = max(dataset.label.max().item() + 1, dataset.label.shape[1])

# Symmetrize the edges unless the graph is declared directed (or is ogbn-proteins).
if not (args.directed or args.dataset == 'ogbn-proteins'):
    dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index'])

edge_index = dataset.graph['edge_index']
x = dataset.graph['node_feat']

print(f"num nodes {n} | num edges {edge_index.size(1)} | num classes {c} | num node feats {d}")

num nodes 8801 | num edges 51905 | num classes 7 | num node feats 14530


In [7]:
### Load method ###
model = parse_method(args, dataset, n, c, d, device)

### Loss function (Single-class, Multi-class) ###
criterion = (
    nn.BCEWithLogitsLoss()
    if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins')
    else nn.NLLLoss()
)

### Performance metric (Acc, AUC, F1) ###
# Anything other than 'rocauc' or 'f1' falls back to accuracy.
eval_func = {'rocauc': eval_rocauc, 'f1': eval_f1}.get(args.metric, eval_acc)

# NOTE(review): this rebinds `logger`, shadowing the `logger` module imported in the first cell.
logger = Logger(args.runs, args)

model.train()
print('MODEL:', model)

# adjs[0] is the input edge index with exactly one self loop per node; each
# further entry is the previous one combined with itself through adj_mul
# (edge index of a higher-order adjacency).
adj, _ = remove_self_loops(edge_index)
adj, _ = add_self_loops(adj, num_nodes=n)
adjs = [adj]
for _ in range(args.rb_order - 1):
    adj = adj_mul(adj, adj, n)
    adjs.append(adj)
dataset.graph['adjs'] = [order_adj.to(torch.int).to(torch.int64) for order_adj in adjs]

# dataset.edge is split at column n_infer: the first part feeds the "inter"
# loss and the remainder the "intra2" loss of the training loop.
adj_loss_inter, _ = remove_self_loops(dataset.edge[:, :dataset.n_infer])
adj_loss_intra2, _ = remove_self_loops(dataset.edge[:, dataset.n_infer:])

MODEL: NodeFormer(
  (convs): ModuleList(
    (0-1): 2 x NodeFormerConv(
      (Wk): Linear(in_features=128, out_features=128, bias=True)
      (Wq): Linear(in_features=128, out_features=128, bias=True)
      (Wv): Linear(in_features=128, out_features=128, bias=True)
      (Wo): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fcs): ModuleList(
    (0): Linear(in_features=14530, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=7, bias=True)
  )
  (bns): ModuleList(
    (0-2): 3 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (chonggou): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=14530, bias=True)
  )
)


In [8]:
import matplotlib.pyplot as plt
def plot_histogram(data, label, name = "", bin_width=0.1):
    """Plot, per confidence bin, how many predictions there are and how many are correct.

    Args:
        data: Pair ``(values, categories)`` such as the result of
            ``torch.max(p, dim=1)``: the top probability of every sample and
            the index of the class that attains it.
        label: Tensor with the reference class of every sample; it is
            flattened, so ``(N,)`` and ``(N, 1)`` are both accepted.
        name: Title of the figure.
        bin_width: Width of each probability bin on ``[0, 1]``.
    """
    # Detach and move to the CPU so the tensors can be handed to matplotlib
    # even when they come from a CUDA model or still carry autograd history.
    values = data[0].detach().cpu().reshape(-1)
    categories = data[1].detach().cpu().reshape(-1)
    label = label.detach().cpu().reshape(-1)

    n_bins = max(1, int(round(1 / bin_width)))

    # Bin index of every sample. A probability of exactly 1.0 would fall one
    # past the last bin and be dropped, so it is clamped into the last bin.
    bin_idx = (values / bin_width).long().clamp(min=0, max=n_bins - 1)
    is_correct = categories == label

    count_all = torch.bincount(bin_idx, minlength=n_bins)  # predictions per bin
    count_correct = torch.bincount(bin_idx[is_correct], minlength=n_bins)  # correct ones per bin

    # Left edge of every bin (0, bin_width, 2*bin_width, ...). torch.linspace(0, 1, n_bins)
    # would space the bars 1/(n_bins-1) apart and misalign them with the counts.
    left_edges = torch.arange(n_bins) * bin_width

    # The orange bars are drawn over the blue totals, so the blue part that
    # remains visible corresponds to the wrong predictions.
    plt.bar(left_edges.numpy(), count_all.numpy(), width=bin_width, align='edge', color='blue', label='False')
    plt.bar(left_edges.numpy(), count_correct.numpy(), width=bin_width, align='edge', color='orange', label='True', alpha=0.7)
    plt.xlabel('Probability')
    plt.ylabel('Count')
    plt.title(name)
    plt.legend()
    plt.show()

In [None]:
### Training loop ###
# Trains on mini-batches built from three index pools (rna_idx, train_atac_idx,
# atac_idx). Every 50 epochs, nodes of atac_idx predicted with probability
# above 0.95 are moved to train_atac_idx and their label in label_train is
# replaced by the prediction. The best state so far (lowest accumulated epoch
# loss) is written to disk, and training stops after 30 epochs without
# improvement.
# NOTE(review): `query_acc` and `modelname` are only defined once a checkpoint
# has been saved; the early-stop print and the following cells rely on that.
l = 100  # lowest value of L/3 seen so far (early-stopping reference)
num = 0  # consecutive epochs without improvement

for run in range(args.runs):
    split_idx = split_idx_lst[run]
    rna_idx = split_idx['train']
    atac_idx = torch.arange(dataset.n_data1, n)
    # Pool of pseudo-labelled nodes; starts empty. It has to be an integer
    # tensor: a float pool makes torch.cat promote the node indices to float,
    # which cannot represent large indices exactly.
    train_atac_idx = torch.empty(0, dtype=torch.long)
    label_train = dataset.label.squeeze(1).clone()

    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(),weight_decay=args.weight_decay, lr=args.lr)
    best_val = float('-inf')

    for epoch in range(args.epochs):
        model.to(device)
        model.train()

        # Number of nodes each pool contributes to one batch (remainders are dropped).
        num_batch1 = rna_idx.size(0) // args.num_batch
        num_batch2 = atac_idx.size(0) // args.num_batch
        num_batch3 = train_atac_idx.size(0) // args.num_batch

        label_train = label_train.to(device)

        # Fresh random permutation of every pool for this epoch.
        idx1 = torch.randperm(rna_idx.size(0))
        idx2 = torch.randperm(atac_idx.size(0))
        idx3 = torch.randperm(train_atac_idx.size(0))

        # Sum of the batch losses of this epoch. Accumulated as a Python float
        # so that no autograd-attached tensor is kept alive across batches.
        L = 0

        for i in range(args.num_batch):
            idx_i_rna = rna_idx[idx1[i*num_batch1:(i+1)*num_batch1]]
            idx_i_atac = atac_idx[idx2[i*num_batch2:(i+1)*num_batch2]]
            idx_i_train_atac = train_atac_idx[idx3[i*num_batch3:(i+1)*num_batch3]]

            # Batch order is rna, train_atac, atac: the first two groups form the
            # prefix that the cross-entropy below is computed on.
            idx_i = torch.cat((idx_i_rna, idx_i_train_atac, idx_i_atac), dim=0).long()
            x_i = x[idx_i].to(device)
            adjs_i = []
            sub_edge_inter, _, inter_mask_i = subgraph(idx_i, adj_loss_inter, num_nodes=n, relabel_nodes=True, return_edge_mask=True)
            sub_edge_intra2, _, intra2_mask_i = subgraph(idx_i, adj_loss_intra2, num_nodes=n, relabel_nodes=True, return_edge_mask=True)
            edge_index_i, _, edge_mask_i = subgraph(idx_i, adjs[0], num_nodes=n, relabel_nodes=True, return_edge_mask=True)
            adjs_i.append(edge_index_i.to(device))
            for k in range(args.rb_order - 1):
                edge_index_i, _ = subgraph(idx_i, adjs[k+1], num_nodes=n, relabel_nodes=True)
                adjs_i.append(edge_index_i.to(device))
            optimizer.zero_grad()
            out_i, link_loss_, z1, z2, chonggou = model(x_i, adjs_i, args.tau)
            out_i = F.log_softmax(out_i, dim=1)
            # softmax of log-probabilities gives back the class probabilities.
            p = F.softmax(out_i, dim=1)

            idx_cross_entropy = torch.cat((idx_i_rna, idx_i_train_atac), dim=0).long()

            # Supervised loss on the rna + train_atac prefix of the batch.
            loss1 = criterion(out_i[0:idx_cross_entropy.shape[0]], label_train[idx_cross_entropy])

            # MSE between the z2 features of the two end points of every "inter"
            # edge (anchor pairs), so that anchor pairs end up close together.
            # Only the part above 0.1 is penalised.
            # NOTE(review): a batch without inter/intra edges yields NaN here
            # (mean over an empty tensor / division by len(a) == 0).
            node1_inter = sub_edge_inter[0].int()
            node2_inter = sub_edge_inter[1].int()
            feature1_inter = z2[:,node1_inter, :]
            feature2_inter = z2[:,node2_inter, :]
            mse_loss_inter = F.mse_loss(feature1_inter, feature2_inter)
            mse_loss_inter = torch.clamp(mse_loss_inter - 0.1, min=0)
            loss2 = (mse_loss_inter)

            # Intra-graph connection loss on the "intra2" edges: share of edges
            # whose end points are NOT assigned the same class; again only the
            # part above 0.1 is penalised.
            values, indices = torch.max(p, dim=1)
            node1_intra = sub_edge_intra2[0]
            node2_intra = sub_edge_intra2[1]
            node1_values = values[node1_intra.int()]
            node2_values = values[node2_intra.int()]
            node1_indices = indices[node1_intra.int()]
            node2_indices = indices[node2_intra.int()]
            a = node1_values * node2_values * (node1_indices == node2_indices).float()
            a[a > 0.001] = 1
            loss_inter2 = 1 - torch.sum(a) / len(a)
            loss3 = torch.clamp(loss_inter2 - 0.1, min=0)

            loss = loss1 + loss2 + loss3
            L += loss.item()
            loss.backward()
            optimizer.step()

        # Evaluation interval is hard-coded to 5 epochs (args.eval_step is not used).
        if epoch % 5 == 0:
            result = evaluate_cpu(model, dataset, split_idx, eval_func, criterion, args)
            logger.add_result(run, result[:-1])

            print(f'Epoch: {epoch:02d}, '
                  f'Loss: {loss:.4f}, '
                  f'交叉熵: {loss1:.4f}, '
                  f'Hard正则: {loss2:.4f}, '
                  f'graph2正则: {loss3:.4f}, '
                  f'Train: {100 * result[0]:.2f}%, ,'
                  f'Query Criterion Loss: {result[2]:.4f}, '
                  f'Query: {100 * result[1]:.2f}%')
            if epoch % 50 == 0:
                model.eval()  # evaluation mode
                model.to(torch.device("cpu"))
                # Inference only: no autograd graph is needed for the full-graph pass.
                with torch.no_grad():
                    out, link_loss_, z1, z2, chonggou = model(x, adjs, args.tau)
                    p = F.softmax(out, dim=1)

                # Positions (within atac_idx) of nodes predicted with probability > 0.95.
                selected_indices = torch.nonzero(torch.max(p, axis = 1)[0][atac_idx] > 0.95).squeeze()
                if (selected_indices.numel()):
                    # A single hit squeezes to a 0-dim tensor; restore the 1-D shape.
                    if atac_idx[selected_indices].dim() == 0:
                        query_selected = atac_idx[selected_indices].unsqueeze(0)
                    else:
                        query_selected = atac_idx[selected_indices]
                    train_atac_idx = torch.cat((train_atac_idx, query_selected), dim=0)

                    label_pre = torch.max(p, axis=1)[1][query_selected].to(device)
                    print("123: ",sum(label_train[query_selected] == label_pre)/(1+int(label_train[query_selected].size(0))))
                    print("数目: ", int(label_train[query_selected].size(0)), ". 正确的数目: ", sum(label_train[query_selected] == label_pre))

                    # Overwrite the training label of the selected nodes with the prediction.
                    label_train[query_selected] = label_pre

                    # Drop the moved entries from atac_idx with a boolean mask.
                    mask = torch.ones_like(atac_idx, dtype=torch.bool)
                    mask[selected_indices] = 0
                    atac_idx = atac_idx[mask]

                plt.figure(figsize=(3,2))
                plot_histogram(torch.max(p[dataset.n_data1:,:], axis=1), dataset.label[dataset.n_data1:], name = "query", bin_width=0.1)
                model.to(device)

        if l > L/3:
            l = L/3
            num = 0
            # New best epoch loss: save the current results and model.

            resultname = args.data_dir + 'results/'
            modelname = args.data_dir + 'model/'

            if not os.path.exists(resultname):
                os.makedirs(resultname)
            if not os.path.exists(modelname):
                os.makedirs(modelname)

            filename = args.data_dir + f'results/{args.dataset}.csv'

            model.eval()  # evaluation mode
            model.to(torch.device("cpu"))
            # Inference only: the saved tensors do not need autograd history.
            with torch.no_grad():
                out, link_loss_, z1, z2, chonggou = model(x, adjs, args.tau)
                p = F.softmax(out, dim=1)

            # Accuracy on the nodes from index n_data1 onward.
            query_acc = sum(dataset.label.squeeze().numpy()[dataset.n_data1:] == p.argmax(dim=-1, keepdim=True).detach().squeeze().numpy()[dataset.n_data1:])/len(dataset.label.squeeze().numpy()[dataset.n_data1:])

            torch.save(z2, resultname + 'embedding.pt')
            torch.save(p, resultname + 'out.pt')
            torch.save(dataset.label, resultname + 'label.pt')
            torch.save(dataset.num_celltype, resultname + 'num_celltype.pt')
            torch.save(dataset.metadata, resultname + 'metadata.pt')
            torch.save(default_args, resultname + 'args.pt')
            torch.save(model.state_dict(), modelname + f'{args.dataset}-{args.method}.pkl')
            # NOTE(review): result[1] comes from the most recent evaluate_cpu call,
            # which can be up to four epochs older than the state saved here.
            with open(f"{filename}", 'w') as write_obj:
                write_obj.write(f"{args.method}," +
                                f"epoch: {epoch:02d}," +
                                f"Query: {100 * result[1]:.2f}%")
            model.to(device)
        else:
            num = num + 1

        # Early stop after 30 consecutive epochs without a new best loss.
        if num >= 30:
            print (f'Answer Query: {100 * query_acc:.2f}%')
            break
        
        
#     logger.print_statistics(run)

# results = logger.print_statistics()

In [None]:
# 画图
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch

# adjs = [adj.to(device) for adj in adjs]
# x = x.to(device)

# Restore the checkpoint written by the training loop and run the full graph
# through the model on the CPU in evaluation mode.
model.load_state_dict(torch.load(modelname + f'{args.dataset}-{args.method}.pkl'))
model.cpu()
model.eval()

out_i, link_loss_, z1, z2, chonggou = model(x, adjs, args.tau)
out_i = F.log_softmax(out_i, dim=1)
p = F.softmax(out_i, dim=1)

# First slice along z2's leading axis, as a NumPy array.
embedding = z2.cpu()[0].detach().numpy()

# Reference cell types.
label = dataset.label.cpu().squeeze().numpy()

metadata = dataset.metadata.cpu().squeeze().numpy()

# Predicted cell types.
out = p.cpu()
pre = out.argmax(dim=-1, keepdim=True).detach().squeeze().numpy()

# Mapping from class number to cell-type name.
num_celltype = dataset.num_celltype
print(num_celltype)

import umap
# NOTE: despite its name, `embedding_tsne` holds a 2-D UMAP projection.
umap_model = umap.UMAP(n_neighbors=5, n_components=2, metric='euclidean')
embedding_tsne = umap_model.fit_transform(embedding)

In [None]:
# celltype着色
%matplotlib inline
# One scatter layer per reference label value, so each gets its own legend entry.
labels = np.unique(label)  # distinct label values
for i in labels:
    pts = embedding_tsne[label == i]
    plt.scatter(pts[:, 0], pts[:, 1], s=0.5, label=str(i))
plt.legend()
plt.show()

In [None]:
# pre着色
%matplotlib inline
# One scatter layer per predicted class, so each gets its own legend entry.
pres = np.unique(pre)  # distinct predicted values
for i in pres:
    pts = embedding_tsne[pre == i]
    plt.scatter(pts[:, 0], pts[:, 1], s=0.5, label=str(i))
plt.legend()
plt.show()

In [None]:
# tech着色
%matplotlib inline
# One scatter layer per metadata value, so each gets its own legend entry.
metadatas = np.unique(metadata)  # distinct metadata values
for i in metadatas:
    pts = embedding_tsne[metadata == i]
    plt.scatter(pts[:, 0], pts[:, 1], s=0.01, label=str(i))
plt.legend()
plt.show()

In [None]:
from sklearn.metrics import silhouette_score,silhouette_samples, accuracy_score, f1_score
import pandas as pd
import numpy as np
## Silhouette coefficients
# Per-sample silhouette of the 2-D projection, once grouped by `label` and
# once grouped by `metadata`.
sil_type = silhouette_samples(np.array(embedding_tsne), label)
sil_omic = silhouette_samples(np.array(embedding_tsne), metadata)

# Rescale both to [0, 1]; the omic term is inverted so that a low silhouette
# with respect to `metadata` gives a high score. sil_f1 is the harmonic mean
# of the two rescaled terms.
omic_term = 1 - (sil_omic + 1) / 2
type_term = (sil_type + 1) / 2
sil_f1 = 2 * omic_term * type_term / (omic_term + type_term)
sil_type.mean(), sil_omic.mean(), sil_f1.mean()

In [14]:
## Save
def _load_cpu(filename):
    """Load one saved artifact from the results directory onto the CPU."""
    return torch.load(resultname + filename, map_location=torch.device('cpu'))

embedding = _load_cpu('embedding.pt')[0].detach().numpy()

label = _load_cpu('label.pt').squeeze().numpy()  # reference cell types

out = _load_cpu('out.pt')  # class probabilities used for the predictions
pre = out.argmax(dim=-1, keepdim=True).detach().squeeze().numpy()

metadata = _load_cpu('metadata.pt').squeeze().numpy()  # tech

num_celltype = _load_cpu('num_celltype.pt')  # mapping from class number to cell-type name
print(num_celltype)

# Probability of the predicted class for every row.
pro = out.max(dim=-1, keepdim=True)[0].detach().squeeze().numpy()
# Translate class numbers into cell-type names.
celltype = [num_celltype[i] for i in label]
pre = [num_celltype[i] for i in pre]
# Coordinates come from `embedding_tsne`, computed in an earlier cell.
umap1 = embedding_tsne[:,0]
umap2 = embedding_tsne[:,1]
# Metadata value 2 is written out as scATAC-seq, anything else as scRNA-seq.
metadata = ["scATAC-seq" if i == 2 else "scRNA-seq" for i in metadata]
import pandas as pd
# NOTE(review): `p` is reused here for the result table; earlier cells use it for the probability tensor.
p = pd.DataFrame({
    'UMAP1': umap1,
    'UMAP2': umap2,
    'celltype': celltype,
    'pre': pre,
    'tech': metadata,
    'pro': pro,
})
p.to_csv(resultname + 'data_list.csv', index=False)
# The same table restricted to the scATAC-seq rows.
p2 = p[p['tech'] == "scATAC-seq"]
p2.to_csv(resultname + 'data2_list.csv', index=False)

{0: 'Baso', 1: 'Bcell', 2: 'CD4', 3: 'CD8', 4: 'CLP', 5: 'CMP', 6: 'DC', 7: 'GMP', 8: 'HSC/MPP', 9: 'LMPP', 10: 'MEP', 11: 'Mono', 12: 'NK', 13: 'NaiveT', 14: 'early-Ery', 15: 'late-Ery', 16: 'pDC', 17: 'pro/pre-B'}


In [None]:
# Fraction of rows in p2 (the scATAC-seq rows) whose predicted cell type equals the reference cell type.
sum(p2['pre'] == p2['celltype'])/len(p2['pre'])

In [None]:
# Same kind of accuracy computed from the tensors: share of nodes from index dataset.n_data1 onward
# whose argmax over `out` equals the entry in dataset.label.
sum(dataset.label.squeeze().numpy()[dataset.n_data1:] == out.argmax(dim=-1, keepdim=True).detach().squeeze().numpy()[dataset.n_data1:])/len(dataset.label.squeeze().numpy()[dataset.n_data1:])