In [13]:
import numpy as np
import argparse
import json
import logging
from time import time
import os
import torch_geometric.transforms as T
from MyLoader import HeteroDataset
from torch_geometric.loader import HGTLoader, NeighborLoader
# from dataloader import DataLoaderMasking 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from model import HGT
import pandas as pd
import pickle
import math
from torch_geometric.datasets import OGB_MAG
import torch.nn.init as init
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score, roc_auc_score,auc,balanced_accuracy_score,cohen_kappa_score,precision_recall_curve
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, EsmModel
import joblib

print(torch.cuda.is_available())

False


In [14]:
class args :
    def __init__(self):
        self.Full_data_path=r'../data/download_data/kgdata.pkl'
        self.node_type='gene/protein'
        self.Task_data_path = '../data/original_data/Cell_line_specific'
        self.Save_model_path = '../logs_models/train_logs_models/'
        self.processed_data_path = '../data/processed_data/'
        self.cv = 'CV3'
        self.n_fold = 5
        self.device = 'cuda:1'
        self.do_low_data = False
        self.sample_nodes = 1024
        self.sample_layers = 4
        self.num_workers = 8
        self.specific = True # Cell line specific
        self.adapted = True # Cell line adapted
        self.cell_line_list = ['A375','Jurkat','A549']
        self.test_cell_line = 'A549'
        self.freeze_graph_encoder = True
        self.freeze_esm_encoder = True
        self.folds = 5
        self.do_train = True
        self.train_batch_size = 128
        self.test_batch_size = 128
        self.hgt_emb_dim = 128
        self.hgt_num_heads = 4
        self.hgt_dropout_ratio = 0.2
        self.hgt_num_layer = 3
        self.mlp_hidden_dim = 128
        self.lr = 1e-5
        self.device_0 = 'cuda:0'
        self.device_1 = 'cuda:1'
        self.device_2 = 'cuda:2'
        self.device_3 = 'cuda:3'
        self.esm_sequence_max_length = 256
        self.epoch = 100
        self.use_esm_embedding = True
        self.esm_embedding_file = '../data/download_data/gene_esm2emb.pkl'
        
args=args()

In [3]:
def set_logger(args):
    '''
    Write logs to checkpoint and console 
    '''

    if args.do_train:
        # train_log=str(linear_layer_count)+'_'+args.lr+'_'+'train.log'
        log_file = os.path.join(args.Save_model_path or args.init_checkpoint, 'train.log') 
    else:
        log_file = os.path.join(args.Save_model_path or args.init_checkpoint, 'test.log') 
    
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s', 
        level=logging.INFO,  # 
        datefmt='%Y-%m-%d %H:%M:%S', 
        filename=log_file, 
        filemode='w'  
    )
    console = logging.StreamHandler() # 
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 
    console.setFormatter(formatter) 
    logging.getLogger('').addHandler(console) 
    
def compute_accuracy(target, pred, pred_edge):
    target = np.array(target)
    pred = np.array(pred)
    pred_edge = np.array(pred_edge)
    
    # 转换为 PyTorch 张量
    pred_edge_tensor = torch.tensor(pred_edge, dtype=torch.float32)
    scores = torch.softmax(pred_edge_tensor, dim=1).numpy()

    
    target = target.astype(int)
    

    
    # 计算各项指标
    aucu = roc_auc_score(target, scores[:, 1])
    precision_tmp, recall_tmp, _thresholds = precision_recall_curve(target, scores[:, 1])
    aupr = auc(recall_tmp, precision_tmp)
    f1 = f1_score(target, pred)
    kappa = cohen_kappa_score(target, pred)
    bacc = balanced_accuracy_score(target, pred)
    
    return aucu, aupr, f1, kappa, bacc




def load_cell_line_gene_data(args, cell_line):
    """
    load cell line specific gene data
    """
    cell_line_gene_data = pd.read_csv(f"{args.processed_data_path}/{cell_line}_all_data_gene.csv")
    return cell_line_gene_data

def load_esm_embedding_data(args, node_index_data):
    esm_embedding = joblib.load(args.esm_embedding_file )
    esm_embedding_geneid = {}
    for key, value in esm_embedding.items():
        if key not in node_index_data['gene/protein']:
            mapped_key = key  # Use original key or a placeholder if needed
            esm_embedding_geneid[mapped_key] = torch.zeros(1280)
        else:
            mapped_key = node_index_data['gene/protein'][key]
            esm_embedding_geneid[mapped_key] = value
    return esm_embedding_geneid


def Downstream_data_preprocess(args,n_fold,node_type_dict,cell_line): #FIXME
    """
    load SL data and preprocess before training 
    """
    task_data_path=args.Task_data_path
    train_data=pd.read_csv(f"{task_data_path}/{cell_line}/sl_train_{n_fold}.csv")
    test_data=pd.read_csv(f"{task_data_path}/{cell_line}/sl_test_{n_fold}.csv",)
    train_data.columns=[0,1,2,3]
    test_data.columns=[0,1,2,3]
    train_data[0]=train_data[0].astype(str).map(node_type_dict)
    train_data[1]=train_data[1].astype(str).map(node_type_dict)
    test_data[0]=test_data[0].astype(str).map(node_type_dict)
    test_data[1]=test_data[1].astype(str).map(node_type_dict)
    train_data=train_data.dropna()
    test_data=test_data.dropna()
    train_data[0]=train_data[0].astype(int)
    train_data[1]=train_data[1].astype(int)
    test_data[0]=test_data[0].astype(int)
    test_data[1]=test_data[1].astype(int)
    # low data scenario settings
    if args.do_low_data:
        num_sample=int(train_data.shape[0]*args.train_data_ratio)
        print(num_sample)
        train_data=train_data.sample(num_sample,replace=False,random_state=0)
        train_data.reset_index(inplace=True)
        print(f'train_data.size:{train_data.shape[0]}')

    train_node=list(set(train_data[0])|set(train_data[1]))
    print(f'train_node.size:{len(train_node)}')
    train_mask=torch.zeros((27671))
    test_mask=torch.zeros((27671))
    test_node=list(set(test_data[0])|set(test_data[1]))
    train_mask[train_node]=1
    test_mask[test_node]=1
    train_mask=train_mask.bool()
    test_mask=test_mask.bool()
    num_train_node=len(train_node)
    num_test_node=len(test_node)
    return train_data,test_data,train_mask,test_mask,num_train_node,num_test_node



class GenePairDataset(Dataset):
    def __init__(self, gene_pairs: pd.DataFrame):
        # drop column 2
        self.gene_pairs = gene_pairs.drop(columns=2).values
    
    def __len__(self):
        return len(self.gene_pairs)
    
    def __getitem__(self, idx):
        return self.gene_pairs[idx]


class sequence_dataset(Dataset):
    def __init__(self,sequence_data):
        self.sequence_data=sequence_data
    def __len__(self):
        return len(self.sequence_data)
    def __getitem__(self,idx):
        return self.sequence_data[idx]

def Construct_loader(args,kgdata,cell_line_gene_data,sequence_data,train_mask,test_mask,node_type,train_batch_size,test_batch_size):
    """
    construct loader for train/test data
    """
    
    train_loader = HGTLoader(kgdata,
    num_samples={key: [args.sample_nodes] * args.sample_layers for key in kgdata.node_types},shuffle=False,
    batch_size=train_batch_size,
    input_nodes=(node_type,train_mask),num_workers=args.num_workers)

    return train_loader



In [4]:
# main
set_logger(args)
with open (args.Full_data_path,'rb') as f:
    kgdata=pickle.load(f)
with open("../data/processed_data/gene_protein_2_id.json",'rb') as f:
    node_index=json.load(f)
sequence_data = pd.read_csv('../data/train_data/uniprot_results_filtered.csv')
sequence_data['Gene_id'] = sequence_data['Gene Name'].map(node_index['gene/protein'])
torch.manual_seed(0)
np.random.seed(0)
device = torch.device(args.device_0 )
device_0 = torch.device(args.device_0)
device_1 = torch.device(args.device_1)
device_2 = torch.device(args.device_2)
device_3 = torch.device(args.device_3)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

gene_protein=node_index[args.node_type] 
eval_metric_folds={'fold':[],'auc':[],'aupr':[],'f1':[],'bacc':[],'kappa':[]}
node_type=args.node_type
num_nodes_type=len(kgdata.node_types)
num_edge_type=len(kgdata.edge_types)
num_nodes=kgdata.num_nodes
input_node_embeddings = torch.nn.Embedding(num_nodes_type, 16)
torch.nn.init.xavier_uniform_(input_node_embeddings.weight.data)
for i in range(len(kgdata.node_types)):
    num_repeat=kgdata[kgdata.node_types[i]].x.shape[0]
    kgdata[kgdata.node_types[i]].x =input_node_embeddings(torch.tensor(i)).repeat([num_repeat,1]).detach()

HGT_model = HGT(kgdata,2*args.hgt_emb_dim,args.hgt_emb_dim,args.hgt_num_heads,args.hgt_num_layer).to(args.device_0)
# HGT_model = nn.DataParallel(HGT_model, device_ids=[args.device_0, args.device_1, args.device_2, args.device_3]) 
# freeze
# if args.freeze_graph_encoder:
#     for param in HGT_model.parameters():
#         param.requires_grad = False
        

if not args.use_esm_embedding:
    ESM_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    ESM_model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D").to(args.device_3)
    ESM_model = nn.DataParallel(ESM_model, device_ids=[args.device_3, args.device_2, args.device_1])  # 使用两个设备
    # freeze
    if args.freeze_esm_encoder:
        for param in ESM_model.parameters():
            param.requires_grad = False
else:
    esm_embedding_geneid = load_esm_embedding_data(args, node_index)

mlp = nn.Sequential(
                nn.Linear(2*args.hgt_emb_dim+2*320+2*3, args.mlp_hidden_dim),
                nn.ReLU(),
                nn.Linear(args.mlp_hidden_dim, 32),
                nn.ReLU(),
                nn.Linear(32, 2),
                ).to(args.device_2)

optimizer = torch.optim.Adam(mlp.parameters(), lr=args.lr)

if args.specific:# Do cv within cell line
    logging.info(f"Cell line specific")
    for cell_line in args.cell_line_list:
        cell_line_gene_data = load_cell_line_gene_data(args, cell_line) 
        cell_line_gene_data['Gene_id'] = cell_line_gene_data['Gene Name'].map(node_index['gene/protein'])   
        for fold in range(args.folds):
            n_fold = fold
            train_data,test_data,train_mask,test_mask,num_train_node,num_test_node=Downstream_data_preprocess(args,n_fold,gene_protein,cell_line)
            loss_sum = 0
            aucu_sum=0
            f1_sum=0
            bacc_sum=0
            kappa_sum=0
            aupr_sum=0
            edge_used=[]
             # map gene name(column name) to gene id
            training_logs = []
            testing_logs=[]
            prediction_result_log_fold=[]
            label_log_fold = []
        
            auc_sum_fold=[]
            aupr_sum_fold=[]
            f1_sum_fold=[]
            bacc_sum_fold=[]
            kappa_sum_fold=[]
            for s in tqdm(range(args.epoch)):
                gene_pair_loader = DataLoader(GenePairDataset(train_data), batch_size=args.train_batch_size, shuffle=True)
                # Train
                prediction_result_log_epoch=[]
                label_log_epoch = []
                for step,batch in enumerate(tqdm(gene_pair_loader)):
                    node_a = batch[:, 0]
                    node_b = batch[:, 1]
                    node = torch.cat([node_a, node_b], dim=0)
                    label = batch[:, 2].to(args.device_2)
                    node_set = set(node_a.numpy()) | set(node_b.numpy())
                    unique_node = list(node_set)
                    len_unique_node = len(unique_node)
                    node_mask = torch.zeros((27671)) # The number of gene/protein nodes in kg
                    node_mask[unique_node] = 1
                    node_mask = node_mask.bool()
                    # find corresponding sequence data of node
                    sequence_batch = []
                    for one_node in node:
                        sequence = sequence_data[sequence_data['Gene_id'] == one_node.item()]['Sequence'].values[0]
                        # cut
                        sequence = sequence[:args.esm_sequence_max_length]
                        sequence_batch.append(sequence_data[sequence_data['Gene_id'] == one_node.item()]['Sequence'].values[0])
                    
                    kg_loader = HGTLoader(kgdata,
                        num_samples={key: [args.sample_nodes] * args.sample_layers for key in kgdata.node_types},
                        shuffle=False,
                        batch_size=len_unique_node,
                        input_nodes=(node_type,node_mask),
                        num_workers=args.num_workers) # FIXME 这里效率瓶颈，每次都要重新构建loader，可以考虑提前构建好...吗？
                    # get the whole batch of kg_loader
                    
                    
                    # HGT forward
                    for kg_batch in kg_loader:
                        break
                    kg_batch.to(args.device_0)
                    node_rep= HGT_model(kg_batch.x_dict, kg_batch.edge_index_dict)
                    node_rep=node_rep[node_type]
                    node_set=pd.DataFrame(list(kg_batch[node_type].n_id[:len_unique_node].squeeze().detach().cpu().numpy()))
                    node_set.drop_duplicates(inplace=True,keep='first')
                    node_set[1]=range(node_set.shape[0])
                    node_map=dict(zip(node_set[0],node_set[1]))
                    # batch to pandas
                    batch=pd.DataFrame(batch.numpy())
                    # column name 0,1,2
                    prediction_edge=batch[[0,1]]
                    prediction_label=batch[2]
                    edge_used.append(prediction_edge.shape[0])
                    edge_a,edge_b=prediction_edge[0],prediction_edge[1]
                    edge_a=edge_a.map(node_map)
                    edge_b=edge_b.map(node_map)
                    HGT_nodea_emb=node_rep[edge_a.values]
                    HGT_nodeb_emb=node_rep[edge_b.values]
                    
                    # ESM forward
                    # tokenize sequence
                    if not args.use_esm_embedding:
                        esm_input = ESM_tokenizer(sequence_batch,padding = True,truncation=True,return_tensors='pt')
                        esm_input.to(args.device_3)
                        sequence_batch_embedding = ESM_model(**esm_input).pooler_output
                        ESM_nodea_emb = sequence_batch_embedding[:len(node_a)]
                        ESM_nodeb_emb = sequence_batch_embedding[len(node_a):]
                    else:
                        ESM_nodea_emb = torch.stack([esm_embedding_geneid[one_node.item()] for one_node in node_a])
                        ESM_nodeb_emb = torch.stack([esm_embedding_geneid[one_node.item()] for one_node in node_b])
                        # do a linear forward to 320 dim
                        ESM_nodea_emb = nn.Linear(1280, 320)(ESM_nodea_emb)
                        ESM_nodeb_emb = nn.Linear(1280, 320)(ESM_nodeb_emb)
                        
                    
                    cell_line_gene_data_nodea = []
                    for one_node in node_a:
                        # 使用loc来选择特定列并进行条件筛选
                        selected_data = cell_line_gene_data.loc[cell_line_gene_data['Gene_id'] == one_node.item(), ['CN', 'Expression', 'HotspotMutation']]
                        
                        # 将选中的数据转换为列表并添加到结果列表中
                        cell_line_gene_data_nodea.append(selected_data.values.tolist())
                    cell_line_gene_data_nodea_embedding = torch.tensor(np.array(cell_line_gene_data_nodea).squeeze())
                    cell_line_gene_data_nodeb = []
                    for one_node in node_b:
                        selected_data = cell_line_gene_data.loc[cell_line_gene_data['Gene_id'] == one_node.item(), ['CN', 'Expression', 'HotspotMutation']]
                        cell_line_gene_data_nodeb.append(selected_data.values.tolist())
                    cell_line_gene_data_nodeb_embedding = torch.tensor(np.array(cell_line_gene_data_nodeb).squeeze())
                    HGT_nodea_emb = HGT_nodea_emb.to(args.device_2)
                    HGT_nodeb_emb = HGT_nodeb_emb.to(args.device_2)                
                    ESM_nodea_emb = ESM_nodea_emb.to(args.device_2)
                    ESM_nodeb_emb = ESM_nodeb_emb.to(args.device_2)
                    cell_line_gene_data_nodea_embedding = cell_line_gene_data_nodea_embedding.to(args.device_2)
                    cell_line_gene_data_nodeb_embedding = cell_line_gene_data_nodeb_embedding.to(args.device_2)

                    nodea_embedding = torch.cat([HGT_nodea_emb, ESM_nodea_emb, cell_line_gene_data_nodea_embedding], dim=1)
                    nodeb_embedding = torch.cat([HGT_nodeb_emb, ESM_nodeb_emb, cell_line_gene_data_nodeb_embedding], dim=1)
                    # embedding to float
                    edge_embedding = torch.cat([nodea_embedding, nodeb_embedding], dim=1).float()

                    embedding_dim = edge_embedding.shape[1]
                    prediction_result = mlp(edge_embedding)
                    prediction_result_log_epoch.append(prediction_result.detach().cpu().numpy())
                    label_log_epoch.append(label.tolist())
                    
                    
                    
                    criterion = nn.CrossEntropyLoss()
                    loss = criterion(prediction_result, label)
                    loss.backward()
                    optimizer.step()
                    
                # prediction_result = prediction_result.flatten() 
                # prediction_result_log: (Step, Batch, 2) -> (Step * Batch, 2)
                prediction_result_log_epoch = np.concatenate(prediction_result_log_epoch)
                from itertools import chain
                label_log_epoch_flat = np.array(list(chain.from_iterable(label_log_epoch)))
                
                aucu, aupr, f1, kappa, bacc = compute_accuracy(label_log_epoch_flat, np.array(prediction_result_log_epoch).argmax(axis=1), prediction_result_log_epoch)
                auc_sum_fold.append(aucu)
                aupr_sum_fold.append(aupr)
                f1_sum_fold.append(f1)
                bacc_sum_fold.append(bacc)
                kappa_sum_fold.append(kappa)
                logging.info(f"Step {step}, Loss: {loss.item()}, AUC: {aucu}, AUPR: {aupr}, F1: {f1}, Kappa: {kappa}, BAcc: {bacc}")
                print(f"Step {step}, Loss: {loss.item()}, AUC: {aucu}, AUPR: {aupr}, F1: {f1}, Kappa: {kappa}, BAcc: {bacc}")
            prediction_result_log_fold.append(prediction_result_log_epoch)
            label_log_fold.append(label_log_epoch)
        break



RuntimeError: The NVIDIA driver on your system is too old (found version 11000). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.

In [None]:
label_log_epoch