In [None]:
import numpy as np
import argparse
import json
import logging
from time import time
import os
import torch_geometric.transforms as T
from MyLoader import HeteroDataset
from torch_geometric.loader import HGTLoader, NeighborLoader
# from dataloader import DataLoaderMasking 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from model import HGT
import pandas as pd
import pickle
import math
from torch_geometric.datasets import OGB_MAG
import torch.nn.init as init
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score, roc_auc_score,auc,balanced_accuracy_score,cohen_kappa_score,precision_recall_curve, average_precision_score
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, EsmModel
import joblib
import torch_sparse
from itertools import chain

print(torch.cuda.is_available())

In [None]:
class args :
    """Hard-coded run configuration for this notebook.

    The class is instantiated once right below and the instance replaces the
    class under the same name ``args``; every other cell reads settings from
    that instance. Attributes marked "not read by the visible cells" are kept
    as-is because other parts of the project may use them.
    """
    def __init__(self):
        # --- paths -------------------------------------------------------
        self.Full_data_path=r'../data/download_data/kgdata.pkl'      # pickled knowledge graph loaded by the main cell
        self.node_type='gene/protein'                                # node type used as sampler input and for the pair embeddings
        self.Task_data_path = '../data/train_data/Cell_line_specific'  # root of <cell_line>/train_<fold>.csv and valid_<fold>.csv
        self.Save_model_path = '../logs_models/train_logs_models/'   # directory that receives train.log / test.log
        self.processed_data_path = '../data/processed_data/'         # directory of <cell_line>_all_data_gene.csv
        self.init_checkpoint = '../logs_models/pretrained_models/Primekg_HGT_0.2_0.001'  # directory holding config.json and checkpoint

        # --- cross-validation / data selection ---------------------------
        self.cv = 'CV3'            # not read by the visible cells
        self.n_fold = 5            # not read by the visible cells (the loop uses `folds`)
        self.do_low_data = False   # when True, Downstream_data_preprocess subsamples the training pairs
        # Fraction of training pairs kept when do_low_data is True. It was
        # missing from the config although Downstream_data_preprocess reads it,
        # so enabling do_low_data raised AttributeError. 1.0 keeps every pair;
        # lower it to simulate a low-data scenario.
        self.train_data_ratio = 1.0
        self.sample_nodes = 1024   # per-layer node budget per node type for HGTLoader
        self.sample_layers = 4     # number of sampling layers for HGTLoader
        self.num_workers = 8       # worker processes passed to HGTLoader
        self.specific = True # Cell line specific
        self.adapted = True # Cell line adapted
        self.cell_line_list = ['A549']   # cell lines iterated by the training loop
        self.test_cell_line = 'A549'     # not read by the visible cells
        self.freeze_graph_encoder = True # not read by the visible cells
        self.freeze_esm_encoder = True   # not read by the visible cells
        self.folds = 5                   # number of folds iterated by the training loop
        self.do_train = True             # selects train.log vs test.log in set_logger

        # --- model / optimisation ----------------------------------------
        self.train_batch_size = 256
        self.test_batch_size = 256
        self.hgt_emb_dim = 128
        self.hgt_num_heads = 4
        self.hgt_dropout_ratio = 0.2     # not read by the visible cells
        self.hgt_num_layer = 4
        self.mlp_hidden_dim = 128        # not read by the visible cells (the main cell passes 32)
        self.lr = 1e-5                   # replaced by the checkpoint's lr when override_config runs
        self.device = 'cuda:2'           # device the models and sampled sub-graphs are moved to
        self.device_0 = 'cuda:0'
        self.device_1 = 'cuda:1'
        self.device_2 = 'cuda:2'
        self.device_3 = 'cuda:3'
        self.esm_sequence_max_length = 256   # not read by the visible cells
        self.epoch = 30                      # epochs per fold
        self.use_esm_embedding = True        # not read by the visible cells
        self.esm_embedding_file = '../data/download_data/gene_esm2emb.pkl'  # read by load_esm_embedding_data
        self.decay = 0                       # weight_decay passed to Adam


# Replace the class by its single instance so that `args.<setting>` works everywhere.
args=args()

In [None]:
def set_logger(args):
    '''
    Write logs to a file in the checkpoint directory and to the console.

    The log file is ``train.log`` when ``args.do_train`` is truthy and
    ``test.log`` otherwise; it is placed in ``args.Save_model_path`` or, when
    that is empty/None, in ``args.init_checkpoint``. The directory is created
    if it does not exist yet.

    The root logger is reconfigured from scratch on every call (``force=True``),
    so re-running the notebook cell does not stack duplicate console handlers
    and does not silently keep writing to a previously configured file.
    '''
    log_dir = args.Save_model_path or args.init_checkpoint
    log_name = 'train.log' if args.do_train else 'test.log'

    # basicConfig opens the file immediately and fails if the directory is missing.
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, log_name)

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w',
        force=True,  # drop handlers left over from an earlier call / cell run
    )

    # Mirror every record to the console as well.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console)
    
def compute_accuracy(target, pred, pred_edge):
    """Compute binary-classification metrics from array-like inputs.

    Args:
        target: ground-truth labels; converted to an integer numpy array.
        pred: hard class predictions, compared against ``target``.
        pred_edge: raw two-class scores, one row per sample; a softmax over
            axis 1 is applied and column 1 is used as the positive-class score.

    Returns:
        Tuple ``(auc, aupr, f1, kappa, balanced_accuracy)`` where ``aupr`` is
        scikit-learn's ``average_precision_score``.
    """
    target = np.array(target).astype(int)
    pred = np.array(pred)

    # Turn the raw scores into probabilities and keep the positive-class column.
    logits = torch.tensor(np.array(pred_edge), dtype=torch.float32)
    positive_score = torch.softmax(logits, dim=1).numpy()[:, 1]

    # Threshold-free metrics use the probability, the others use the hard labels.
    # (The former trapezoidal PR-curve area was computed and then immediately
    # overwritten by average_precision_score, so it has been removed.)
    aucu = roc_auc_score(target, positive_score)
    aupr = average_precision_score(target, positive_score)
    f1 = f1_score(target, pred)
    kappa = cohen_kappa_score(target, pred)
    bacc = balanced_accuracy_score(target, pred)

    return aucu, aupr, f1, kappa, bacc


def compute_accuracy_2(target,pred, pred_edge):
    """Tensor-input variant of ``compute_accuracy``.

    Args:
        target: tensor of ground-truth labels; cast to int after moving to CPU.
        pred: tensor of hard class predictions.
        pred_edge: tensor of raw two-class scores, one row per sample; softmax
            over dim 1 is applied and column 1 is the positive-class score.

    Returns:
        Tuple ``(auc, aupr, f1, kappa, balanced_accuracy)`` where ``aupr`` is
        scikit-learn's ``average_precision_score``.
    """
    # Inputs are only read, so detaching is enough (no clone needed).
    target = target.detach().cpu().numpy().astype(int)
    pred = pred.detach().cpu().numpy()
    positive_score = torch.softmax(pred_edge.detach().cpu(), 1).numpy()[:, 1]

    # The previous PR-curve area (built from hard predictions) was overwritten
    # right away by average_precision_score, and the full label/score arrays
    # were printed on every call; both leftovers are removed.
    aucu = roc_auc_score(target, positive_score)
    aupr = average_precision_score(target, positive_score)
    f1 = f1_score(target, pred)
    kappa = cohen_kappa_score(target, pred)
    bacc = balanced_accuracy_score(target, pred)

    return aucu,aupr,f1,kappa,bacc


def load_cell_line_gene_data(args, cell_line):
    """
    Read the gene table of one cell line.

    The file ``<cell_line>_all_data_gene.csv`` is looked up under
    ``args.processed_data_path`` and returned as a pandas DataFrame.
    """
    csv_path = f"{args.processed_data_path}/{cell_line}_all_data_gene.csv"
    return pd.read_csv(csv_path)



def load_esm_embedding_data(args, node_index_data):
    """
    Load the ESM embedding dictionary and re-key it by node id.

    The dictionary stored in ``args.esm_embedding_file`` is read with joblib.
    Keys found in ``node_index_data['gene/protein']`` are replaced by their
    mapped id and keep their embedding; keys that are not found stay under
    their original key and get a zero vector of length 1280 instead.
    """
    raw_embeddings = joblib.load(args.esm_embedding_file )
    remapped = {}
    for gene, vector in raw_embeddings.items():
        if gene in node_index_data['gene/protein']:
            remapped[node_index_data['gene/protein'][gene]] = vector
        else:
            # Unknown gene: keep the original key, replace the embedding by zeros.
            remapped[gene] = torch.zeros(1280)
    return remapped


def _load_fold_split(task_data_path, cell_line, split, n_fold, node_type_dict):
    """Read one ``<split>_<n_fold>.csv`` file and map its two gene columns to integer node ids."""
    frame = pd.read_csv(f"{task_data_path}/{cell_line}/{split}_{n_fold}.csv")
    frame.columns = [0, 1, 2, 3]
    for col in (0, 1):
        frame[col] = frame[col].astype(str).map(node_type_dict)
    # Genes missing from the mapping become NaN; drop those rows (and any other
    # row containing NaN). Copy so the assignments below act on an owned frame.
    frame = frame.dropna().copy()
    for col in (0, 1):
        frame[col] = frame[col].astype(int)
    return frame


def _build_node_mask(frame, num_nodes):
    """Return ``(bool mask of length num_nodes, number of distinct nodes)`` for the pairs in ``frame``."""
    nodes = list(set(frame[0]) | set(frame[1]))
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    # Index with an explicit long tensor so an empty node list is handled too.
    mask[torch.tensor(nodes, dtype=torch.long)] = True
    return mask, len(nodes)


def Downstream_data_preprocess(args,n_fold,node_type_dict,cell_line,num_nodes=27671): #FIXME
    """
    Load the SL pair data of one fold and preprocess it before training.

    Args:
        args: configuration; ``Task_data_path`` and ``do_low_data`` are read,
            plus ``train_data_ratio`` when ``do_low_data`` is truthy.
        n_fold: fold index used in the file names ``train_<n_fold>.csv`` and
            ``valid_<n_fold>.csv``.
        node_type_dict: mapping from gene identifier (as string) to node id.
        cell_line: sub-directory of ``Task_data_path`` holding the fold files.
        num_nodes: length of the returned masks. Defaults to 27671, the value
            that was previously hard-coded; it must cover every mapped node id.

    Returns:
        ``(train_data, test_data, train_mask, test_mask, num_train_node,
        num_test_node)``. The frames have columns ``[0, 1, 2, 3]`` with columns
        0 and 1 holding integer node ids; the masks are boolean tensors marking
        the nodes that occur in the respective frame.
    """
    task_data_path = args.Task_data_path
    train_data = _load_fold_split(task_data_path, cell_line, 'train', n_fold, node_type_dict)
    test_data = _load_fold_split(task_data_path, cell_line, 'valid', n_fold, node_type_dict)

    # low data scenario settings
    if args.do_low_data:
        num_sample = int(train_data.shape[0] * args.train_data_ratio)
        print(num_sample)
        train_data = train_data.sample(num_sample, replace=False, random_state=0)
        # drop=True: the old in-place reset inserted the previous index as an
        # extra leading column, which shifted the columns read downstream.
        train_data = train_data.reset_index(drop=True)
        print(f'train_data.size:{train_data.shape[0]}')

    train_mask, num_train_node = _build_node_mask(train_data, num_nodes)
    print(f'train_node.size:{num_train_node}')
    test_mask, num_test_node = _build_node_mask(test_data, num_nodes)
    return train_data,test_data,train_mask,test_mask,num_train_node,num_test_node

def override_config(args):
    '''
    Override model and data configuration with the values stored next to the checkpoint.

    Reads ``config.json`` from ``args.init_checkpoint`` and copies the entries
    ``method``, ``lr``, ``num_layer``, ``emb_dim``, ``mask_rate`` and
    ``gnn_type`` onto ``args``. ``Save_model_path`` is only taken from the file
    when ``args.Save_model_path`` is None.
    '''
    config_file = os.path.join(args.init_checkpoint, 'config.json')
    with open(config_file, 'r') as handle:
        stored = json.load(handle)

    # 'epochs' is deliberately not restored from the stored configuration.
    for field in ('method', 'lr', 'num_layer', 'emb_dim', 'mask_rate', 'gnn_type'):
        setattr(args, field, stored[field])

    if args.Save_model_path is None:
        args.Save_model_path = stored['Save_model_path']

class GenePairDataset(Dataset):
    """Map-style dataset over the rows of a gene-pair table.

    The frame's column labelled ``2`` is discarded; every item is one row of
    the remaining columns as a numpy array.
    """

    def __init__(self, gene_pairs: pd.DataFrame):
        without_col2 = gene_pairs.drop(2, axis=1)
        self.gene_pairs = without_col2.values

    def __len__(self):
        return len(self.gene_pairs)

    def __getitem__(self, idx):
        row = self.gene_pairs[idx]
        return row


class sequence_dataset(Dataset):
    """Thin map-style dataset that exposes an indexable container unchanged."""

    def __init__(self, sequence_data):
        self.sequence_data = sequence_data

    def __len__(self):
        n_items = len(self.sequence_data)
        return n_items

    def __getitem__(self, idx):
        item = self.sequence_data[idx]
        return item

def Construct_loader(args,kgdata,cell_line_gene_data,sequence_data,train_mask,test_mask,node_type,train_batch_size,test_batch_size):
    """
    Build the HGTLoader over the training nodes.

    Only the training loader is created and returned; ``cell_line_gene_data``,
    ``sequence_data``, ``test_mask`` and ``test_batch_size`` are accepted but
    not used. Every node type gets the same sampling budget:
    ``args.sample_nodes`` nodes for each of ``args.sample_layers`` layers.
    """
    per_type_budget = {}
    for ntype in kgdata.node_types:
        per_type_budget[ntype] = [args.sample_nodes] * args.sample_layers

    return HGTLoader(
        kgdata,
        num_samples=per_type_budget,
        shuffle=False,
        batch_size=train_batch_size,
        input_nodes=(node_type, train_mask),
        num_workers=args.num_workers,
    )



# HGT Model with a classification head
class HGT4Classification(nn.Module):
    """HGT encoder followed by an MLP head that classifies node pairs.

    Args:
        args: configuration object; ``args.node_type`` selects which node type's
            representations are used for the pairs.
        hgt: encoder module called as ``hgt(x_dict, edge_index_dict)`` and
            returning a dict of node representations keyed by node type.
        emb_dim: size of one node representation (the head takes two of them
            concatenated).
        hidden_dim: hidden width of the MLP head.
        num_classes: number of output logits per pair.
        len_unique_node: kept for backward compatibility and stored as an
            attribute; ``forward`` derives the number of seed nodes from the
            batch itself.
    """
    def __init__(self,args, hgt,emb_dim,hidden_dim, num_classes, len_unique_node):
        super(HGT4Classification, self).__init__()
        self.hgt = hgt
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )
        self.args = args
        self.len_unique_node = len_unique_node

    def forward(self, kg_batch,batch):
        """Return the class logits for every node pair in ``batch``.

        Args:
            kg_batch: sampled sub-graph exposing ``x_dict``, ``edge_index_dict``
                and ``kg_batch[node_type].n_id`` (global ids of the sampled
                nodes). Assumes the first rows for the node type are the seed
                nodes of the sampler -- TODO confirm against the loader in use.
            batch: tensor whose columns 0 and 1 hold the global node ids of
                each pair (remaining columns are ignored here).

        Returns:
            Tensor with one row of logits per pair.

        Raises:
            KeyError: if a pair references a node id that is not among the
                seed nodes of ``kg_batch``.
        """
        # Use the module's own configuration rather than notebook globals
        # (`args`, `node_type`, `len_unique_node`), so the model works wherever
        # it is called from.
        node_type = self.args.node_type
        node_rep = self.hgt(kg_batch.x_dict, kg_batch.edge_index_dict)[node_type]

        src_ids = batch[:, 0].tolist()
        dst_ids = batch[:, 1].tolist()
        # Number of seed nodes = number of distinct nodes in the pair batch,
        # which is how the training loop sizes the sampler batch.
        num_seed = len(set(src_ids) | set(dst_ids))

        # Global id -> row position in node_rep (first occurrence wins).
        # reshape(-1) keeps this working when there is a single seed node.
        seed_ids = kg_batch[node_type].n_id[:num_seed].reshape(-1).tolist()
        node_map = {gid: pos for pos, gid in enumerate(dict.fromkeys(seed_ids))}

        src_rows = torch.tensor([node_map[g] for g in src_ids], dtype=torch.long, device=node_rep.device)
        dst_rows = torch.tensor([node_map[g] for g in dst_ids], dtype=torch.long, device=node_rep.device)

        edge_embedding = torch.cat([node_rep[src_rows], node_rep[dst_rows]], dim=1)
        return self.mlp(edge_embedding)
        
        
        
        

In [None]:
# main
# Set up logging, load the knowledge graph and id mappings, then build the
# HGT encoder (optionally restored from a checkpoint), the pair classifier and
# its optimizer.
set_logger(args)
# NOTE(security): unpickling can execute arbitrary code; only point
# Full_data_path at a file from a trusted source.
with open (args.Full_data_path,'rb') as f:
    kgdata=pickle.load(f)
with open("../data/processed_data/gene_protein_2_id.json",'rb') as f:
    node_index=json.load(f)
sequence_data = pd.read_csv('../data/train_data/uniprot_results_filtered.csv')
sequence_data['Gene_id'] = sequence_data['Gene Name'].map(node_index['gene/protein'])
if args.init_checkpoint:
    # Copies method/lr/num_layer/emb_dim/mask_rate/gnn_type from the checkpoint's
    # config.json onto args.
    # NOTE(review): this also replaces the args.lr chosen in the config cell --
    # confirm that reusing the stored learning rate is intended.
    override_config(args)


# Seed every random number generator used below.
torch.manual_seed(0)
np.random.seed(0)
device = torch.device(args.device_1 )
device_0 = torch.device(args.device_0)
device_1 = torch.device(args.device_1)
device_2 = torch.device(args.device_2)
device_3 = torch.device(args.device_3)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

gene_protein=node_index[args.node_type]
eval_metric_folds={'fold':[],'auc':[],'aupr':[],'f1':[],'bacc':[],'kappa':[]}
node_type=args.node_type
num_nodes_type=len(kgdata.node_types)
num_edge_type=len(kgdata.edge_types)
num_nodes=kgdata.num_nodes

# Give every node of a type the same 16-dimensional input feature: one
# Xavier-initialised embedding row per node type, repeated for all its nodes
# and detached so it is a fixed input rather than a trainable parameter.
input_node_embeddings = torch.nn.Embedding(num_nodes_type, 16)
torch.nn.init.xavier_uniform_(input_node_embeddings.weight.data)
for i in range(len(kgdata.node_types)):
    num_repeat=kgdata[kgdata.node_types[i]].x.shape[0]
    kgdata[kgdata.node_types[i]].x =input_node_embeddings(torch.tensor(i)).repeat([num_repeat,1]).detach()

HGT_model = HGT(kgdata,2*args.hgt_emb_dim,args.hgt_emb_dim,args.hgt_num_heads,args.hgt_num_layer).to(args.device)

if args.init_checkpoint:
    # Restore model from checkpoint directory
    logging.info('Loading checkpoint %s...' % args.init_checkpoint)
    # map_location: without it the tensors are restored onto the device they
    # were saved from, which fails when that device does not exist here.
    checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'), map_location=args.device)
    init_step = checkpoint['step']
    HGT_model.load_state_dict(checkpoint['model_state_dict'])

# NOTE(review): the MLP hidden size is the literal 32 here while the config
# defines mlp_hidden_dim -- confirm which one is intended.
HGT4Classification_model = HGT4Classification(args, HGT_model, args.hgt_emb_dim, 32, 2, args.train_batch_size).to(args.device)
# HGT4Classification_model = nn.DataParallel(HGT4Classification_model, device_ids=[0, 1, 2, 3])
optimizer_model = optim.Adam(HGT4Classification_model.parameters(), lr=args.lr, weight_decay=args.decay)

import copy  # stdlib; snapshots the initial weights so every fold can start from them

logging.info(f"Cell line specific")

# state_dict() hands out references to the live parameters, so take a deep copy
# before any training happens. Every fold is reset to this snapshot.
initial_model_state = copy.deepcopy(HGT4Classification_model.state_dict())
criterion = nn.CrossEntropyLoss()
# Mask length = number of nodes of the target type, read from its feature
# matrix instead of a hard-coded count.
num_kg_nodes = kgdata[node_type].x.shape[0]


def _sample_kg_batch(pair_batch):
    """Sample the KG sub-graph seeded by all distinct nodes of ``pair_batch``.

    Returns ``(kg_batch on args.device, number of distinct seed nodes)``.
    """
    seed_nodes = list(set(pair_batch[:, 0].tolist()) | set(pair_batch[:, 1].tolist()))
    seed_mask = torch.zeros(num_kg_nodes, dtype=torch.bool)
    seed_mask[torch.tensor(seed_nodes, dtype=torch.long)] = True
    # batch_size equals the number of seeds, so the first (and only needed)
    # batch of the loader contains every seed node.
    # NOTE(review): a loader with args.num_workers workers is built for every
    # mini-batch; that is likely expensive -- consider num_workers=0 here.
    kg_loader = HGTLoader(kgdata,
        num_samples={key: [args.sample_nodes] * args.sample_layers for key in kgdata.node_types},
        shuffle=False,
        batch_size=len(seed_nodes),
        input_nodes=(node_type,seed_mask),
        num_workers=args.num_workers)
    kg_batch = next(iter(kg_loader))
    return kg_batch.to(args.device), len(seed_nodes)


for cell_line in args.cell_line_list:
    cell_line_gene_data = load_cell_line_gene_data(args, cell_line)
    # map gene name (column name) to gene id
    cell_line_gene_data['Gene_id'] = cell_line_gene_data['Gene Name'].map(node_index['gene/protein'])
    for fold in range(args.folds):
        n_fold = fold
        train_data,test_data,train_mask,test_mask,num_train_node,num_test_node=Downstream_data_preprocess(args,n_fold,gene_protein,cell_line)

        # Start every fold from the same initial weights with a fresh optimizer.
        # Previously one model/optimizer was trained across all folds, so later
        # folds were evaluated on pairs the model had already been trained on.
        HGT4Classification_model.load_state_dict(initial_model_state)
        optimizer_model = optim.Adam(HGT4Classification_model.parameters(), lr=args.lr, weight_decay=args.decay)

        # Never updated in this cell; kept in case other cells reference them.
        aucu_sum = f1_sum = bacc_sum = kappa_sum = aupr_sum = 0
        edge_used=[]
        training_logs = []
        testing_logs=[]
        prediction_result_log_fold=[]
        label_log_fold = []

        # Per-epoch validation metrics of this fold.
        auc_sum_fold=[]
        aupr_sum_fold=[]
        f1_sum_fold=[]
        bacc_sum_fold=[]
        kappa_sum_fold=[]
        for epoch in tqdm(range(args.epoch)):
            # Train
            HGT4Classification_model.train()
            gene_pair_loader = DataLoader(GenePairDataset(train_data), batch_size=args.train_batch_size, shuffle=True)
            loss_sum = 0
            for step,batch in enumerate(tqdm(gene_pair_loader)):
                optimizer_model.zero_grad()
                label = batch[:, 2].to(args.device)
                # `len_unique_node` must stay a module-level name: the original
                # HGT4Classification.forward reads it as a global.
                kg_batch, len_unique_node = _sample_kg_batch(batch)
                prediction_result = HGT4Classification_model(kg_batch,batch)
                loss = criterion(prediction_result, label)
                loss.backward()
                loss_sum += loss.item()
                torch.nn.utils.clip_grad_norm_(HGT4Classification_model.parameters(), max_norm=1.0)
                optimizer_model.step()

            # valid
            HGT4Classification_model.eval()
            gene_pair_loader = DataLoader(GenePairDataset(test_data), batch_size=args.test_batch_size, shuffle=False)
            prediction_result_log_epoch=[]
            label_log_epoch = []
            with torch.no_grad():
                for step,batch in enumerate(tqdm(gene_pair_loader)):
                    label = batch[:, 2].to(args.device)
                    kg_batch, len_unique_node = _sample_kg_batch(batch)
                    prediction_result = HGT4Classification_model(kg_batch,batch)
                    prediction_result_log_epoch.append(prediction_result.detach().cpu().numpy())
                    label_log_epoch.append(label.tolist())
            prediction_result_log_epoch = np.concatenate(prediction_result_log_epoch)
            label_log_epoch_flat = np.array(list(chain.from_iterable(label_log_epoch)))

            # Metrics over the whole validation split of this epoch.
            aucu, aupr, f1, kappa, bacc = compute_accuracy(label_log_epoch_flat, prediction_result_log_epoch.argmax(axis=1), prediction_result_log_epoch)
            auc_sum_fold.append(aucu)
            aupr_sum_fold.append(aupr)
            f1_sum_fold.append(f1)
            bacc_sum_fold.append(bacc)
            kappa_sum_fold.append(kappa)
            print(f"Epoch {epoch}, Loss: {loss_sum}, AUC: {aucu}, AUPR: {aupr}, F1: {f1}, Kappa: {kappa}, BAcc: {bacc}")

In [None]:
import math

# Parameters
N = 12  # total number of cards
K = 3   # number of prize cards
r = 3   # goal: find the 3rd prize card

# Expectation (placeholder, never updated below)
E = 0
# Running product of the per-step "miss" probabilities
p_miss = 1

for i in range(r, N + 1):
    # Hypergeometric term: i drawn cards contain exactly r of the K prize cards.
    favourable = math.comb(K, r) * math.comb(N - K, i - r)
    p = favourable / math.comb(N, i)
    q = 1 - p
    # NOTE(review): multiplying the q values treats the steps as independent -- confirm this is intended.
    p_miss = p_miss * q
    print("i:",i,"P(K=r)累乘:",1-p_miss)
    







: 

In [None]:
# test model loading
