In [None]:
import numpy as np
import argparse
import json
import logging
from time import time
import os
import torch_geometric.transforms as T
from MyLoader import HeteroDataset
from torch_geometric.loader import HGTLoader, NeighborLoader
# from dataloader import DataLoaderMasking 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from model import HGT, HGT4Classification, HGT_ESM_4Classification, HGT_ESM_Attention_4Classification, HGT_ESM_CLdata_4Classification
import pandas as pd
import pickle
import math
from torch_geometric.datasets import OGB_MAG
import torch.nn.init as init
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score, roc_auc_score,auc,balanced_accuracy_score,cohen_kappa_score,precision_recall_curve, average_precision_score
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, EsmModel
import joblib
import torch_sparse
from itertools import chain
import datetime

from utils import generate_log_dir,set_logger, compute_accuracy, Downstream_data_preprocess, override_config, GenePairDataset, sequence_dataset , create_optimizer,FocalLoss , override_config, load_esm_embedding_data, load_cell_line_gene_data

# Sanity check: confirm the kernel can see a CUDA device before the GPU ids configured below are used.
cuda_visible = torch.cuda.is_available()
print(cuda_visible)

In [None]:
class args :
    """Static run configuration for cell-line-specific fine-tuning.

    The notebook replaces this class with a single instance (``args = args()``) and the
    rest of the script reads settings as attributes (``args.device``, ``args.folds``, ...).
    ``override_config`` may later overwrite some of them from the checkpoint directory.
    """

    def __init__(self):
        # Built fresh on every call so mutable values (e.g. cell_line_list) are never
        # shared between instances. Insertion order is the attribute order on the instance.
        settings = dict(
            # --- data / checkpoint locations ---
            Full_data_path=r'../data/download_data/kgdata.pkl',
            node_type='gene/protein',
            Task_data_path='../data/train_data/Cell_line_specific',
            Save_model_path='../logs_models/train_logs_models/',
            processed_data_path='../data/processed_data/',
            init_checkpoint='../logs_models/pretrained_models/Primekg_HGT_0.2_0.001',
            cv='CV3',
            do_low_data=False,
            # --- HGT neighbourhood sampling ---
            sample_nodes=1024,
            sample_layers=4,
            num_workers=8,
            # --- cell-line setup ---
            specific=True,  # Cell line specific
            adapted=True,  # Cell line adapted
            cell_line_list=['Jurkat'],
            test_cell_line='A549',
            freeze_graph_encoder=True,
            freeze_esm_encoder=True,
            # --- training schedule / model sizes ---
            folds=5,
            do_train=True,
            train_batch_size=512,
            test_batch_size=512,
            hgt_emb_dim=128,
            hgt_num_heads=4,
            hgt_dropout_ratio=0.2,
            hgt_num_layer=4,
            mlp_hidden_dim=256,
            lr=1e-5,
            # --- devices ---
            device='cuda:2',
            device_0='cuda:0',
            device_1='cuda:1',
            device_2='cuda:2',
            device_3='cuda:3',
            # --- ESM / classifier / optimiser ---
            esm_sequence_max_length=256,
            epoch=20,
            use_esm_embedding=True,
            esm_embedding_file='../data/download_data/gene_esm2emb.pkl',
            decay=1e-6,
            attention_classifier_num_heads=4,
            weight_decay=1e-5,
            esm_reduction_dim=256,
            hgt_lr=1e-5,
            fc_lr=1e-4,
            fc_weight_decay=1e-5,
            base_weight_decay=0,
            use_layer_lr=False,
            loss_function='FocalLoss',
        )
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


args=args()

In [None]:
log_dir = set_logger(args)
logger = logging.getLogger('')  # root logger (presumably configured by set_logger above)

# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open(args.Full_data_path, 'rb') as kg_file:
    kgdata = pickle.load(kg_file)
logger.info(f"Loaded kgdata from {args.Full_data_path}")

# Node-name -> node-id mapping per node type; used to translate gene names into KG ids.
with open("../data/processed_data/gene_protein_2_id.json", 'rb') as index_file:
    node_index = json.load(index_file)
sequence_data = pd.read_csv('../data/train_data/uniprot_results_filtered.csv').assign(
    Gene_id=lambda frame: frame['Gene Name'].map(node_index['gene/protein'])
)

if args.init_checkpoint:
    override_config(args)

esm_embedding_geneid = load_esm_embedding_data(args, node_index)
logger.info(f"Loaded ESM embeddings from {args.esm_embedding_file}")

# Seed torch (CPU and, when present, all CUDA devices) and NumPy for reproducibility.
torch.manual_seed(0)
np.random.seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)


gene_protein = node_index[args.node_type]
eval_metric_folds = {'fold': [], 'auc': [], 'aupr': [], 'f1': [], 'bacc': [], 'kappa': []}
node_type = args.node_type
num_nodes_type = len(kgdata.node_types)
num_edge_type = len(kgdata.edge_types)
num_nodes = kgdata.num_nodes

# Replace every node's input features with a 16-dim vector shared by all nodes of the
# same type. The vectors are Xavier-uniform initialised and detached, so they are fixed
# inputs rather than trained parameters.
input_node_embeddings = torch.nn.Embedding(num_nodes_type, 16)
torch.nn.init.xavier_uniform_(input_node_embeddings.weight.data)
for type_idx, ntype in enumerate(kgdata.node_types):
    n_rows = kgdata[ntype].x.shape[0]
    # weight[type_idx] is exactly what input_node_embeddings(torch.tensor(type_idx)) returns.
    kgdata[ntype].x = input_node_embeddings.weight[type_idx].repeat(n_rows, 1).detach()



# Select the classification loss. Unknown names fail fast here instead of surfacing
# later as a NameError on `criterion` inside the training loop.
if args.loss_function == 'FocalLoss':
    criterion = FocalLoss(alpha = 0.9, gamma = 2.0, reduction = 'mean') 
    logging.info("Using criterion FocalLoss, alpha = 0.9, gamma = 2.0")
elif args.loss_function == 'CrossEntropyLoss':
    criterion = nn.CrossEntropyLoss()
    logging.info("Using criterion CrossEntropyLoss")
else:
    raise ValueError(
        f"Unsupported loss_function {args.loss_function!r}; "
        "expected 'FocalLoss' or 'CrossEntropyLoss'"
    )
    


logging.info("Cell line specific")

# Feature columns taken from each cell line's per-gene table.
CELL_LINE_FEATURE_COLUMNS = ['CN', 'Expression', 'HotspotMutation']
# Length of the boolean seed mask handed to HGTLoader for `node_type`
# (previously hard-coded as 27671, the gene/protein node count of one KG build).
num_target_nodes = kgdata[node_type].num_nodes


def build_cell_line_feature_table(gene_data):
    """Index a cell line's gene table by KG node id, keeping only the feature columns.

    Rows whose 'Gene Name' did not map to a node id (NaN 'Gene_id') are dropped. If a
    node id occurs more than once, the first row is kept and a warning is logged, so
    each node id resolves to exactly one feature row.
    """
    mapped = gene_data.dropna(subset=['Gene_id']).astype({'Gene_id': 'int64'})
    is_duplicate = mapped['Gene_id'].duplicated(keep='first')
    if is_duplicate.any():
        logger.warning(f"{int(is_duplicate.sum())} duplicate Gene_id rows in cell line data; keeping the first occurrence")
    return mapped.loc[~is_duplicate].set_index('Gene_id')[CELL_LINE_FEATURE_COLUMNS]


def lookup_cell_line_features(feature_table, nodes):
    """Gather features for a 1-D tensor of node ids as a (len(nodes), 3) float32 tensor on args.device.

    The batch dimension is kept even for a batch of one (the old per-node list +
    `.squeeze()` collapsed it to shape (3,)). Raises KeyError if a node id has no row.
    """
    node_ids = nodes.tolist()
    missing = set(node_ids).difference(feature_table.index)
    if missing:
        raise KeyError(f"No cell line features for node ids: {sorted(missing)[:10]}")
    values = feature_table.loc[node_ids].to_numpy(dtype=np.float32)
    return torch.from_numpy(values).to(args.device)


def sample_kg_subgraph(node_a, node_b):
    """Sample an HGT subgraph of `kgdata` seeded at the unique nodes of a batch of pairs."""
    seed_ids = torch.unique(torch.cat([node_a, node_b]))
    node_mask = torch.zeros(num_target_nodes, dtype=torch.bool)
    node_mask[seed_ids] = True
    kg_loader = HGTLoader(kgdata,
        num_samples={key: [args.sample_nodes] * args.sample_layers for key in kgdata.node_types},
        shuffle=False,
        batch_size=seed_ids.numel(),
        input_nodes=(node_type, node_mask),
        num_workers=args.num_workers)
    # batch_size equals the number of seed nodes, so the first batch already covers all of them.
    kg_batch = next(iter(kg_loader))
    return kg_batch.to(args.device)


def prepare_batch(batch, feature_table):
    """Build model inputs from a batch whose columns 0/1/2 are node_a, node_b and the label.

    Returns (label, kg_batch, esm_a, esm_b, cell_line_a, cell_line_b), all on args.device.
    """
    node_a, node_b = batch[:, 0], batch[:, 1]
    label = batch[:, 2].to(args.device)
    kg_batch = sample_kg_subgraph(node_a, node_b)
    esm_a = torch.stack([esm_embedding_geneid[one_node] for one_node in node_a.tolist()]).to(args.device)
    esm_b = torch.stack([esm_embedding_geneid[one_node] for one_node in node_b.tolist()]).to(args.device)
    cell_line_a = lookup_cell_line_features(feature_table, node_a)
    cell_line_b = lookup_cell_line_features(feature_table, node_b)
    return label, kg_batch, esm_a, esm_b, cell_line_a, cell_line_b


for cell_line in args.cell_line_list:

    cell_line_log_dir = os.path.join(log_dir, cell_line)
    os.makedirs(cell_line_log_dir, exist_ok=True)
    cell_line_gene_data = load_cell_line_gene_data(args, cell_line)
    cell_line_gene_data['Gene_id'] = cell_line_gene_data['Gene Name'].map(node_index['gene/protein'])
    cell_line_feature_table = build_cell_line_feature_table(cell_line_gene_data)
    logger.info(f"Loaded {cell_line} gene data from {args.Task_data_path}")

    # Reset per cell line: previously folds accumulated across cell lines, so every cell
    # line after the first reported averages/maxima that included earlier cell lines.
    eval_metric_folds = {'fold': [], 'auc': [], 'aupr': [], 'f1': [], 'bacc': [], 'kappa': []}

    for fold in range(args.folds):

        fold_log_dir = os.path.join(cell_line_log_dir, f'fold_{fold}')
        os.makedirs(fold_log_dir, exist_ok=True)

        # Fresh graph encoder per fold so folds never share fine-tuned weights.
        HGT_model = HGT(kgdata, 2*args.hgt_emb_dim, args.hgt_emb_dim, args.hgt_num_heads, args.hgt_num_layer).to(args.device)

        if args.init_checkpoint:
            # Restore the pretrained encoder; map_location keeps the loaded tensors off
            # whichever GPU the checkpoint happened to be saved from.
            logging.info('Loading checkpoint %s...' % args.init_checkpoint)
            checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'), map_location=args.device)
            HGT_model.load_state_dict(checkpoint['model_state_dict'])

        HGT_ESM_CLdata_4Classification_model = HGT_ESM_CLdata_4Classification(args, HGT_model).to(args.device)
        logging.info("Using HGT_ESM_CLdata_4Classification_model")

        if args.use_layer_lr:
            optimizer_model = create_optimizer(HGT_ESM_CLdata_4Classification_model, args.hgt_lr, args.fc_lr, args.base_weight_decay, args.fc_weight_decay)
        else:
            optimizer_model = optim.Adam(HGT_ESM_CLdata_4Classification_model.parameters(), lr=args.hgt_lr, weight_decay=args.fc_weight_decay)

        train_data, test_data, train_mask, test_mask, num_train_node, num_test_node = Downstream_data_preprocess(args, fold, gene_protein, cell_line)
        # Built once per fold; the shuffled train loader still reshuffles on every epoch.
        train_loader = DataLoader(GenePairDataset(train_data), batch_size=args.train_batch_size, shuffle=True)
        test_loader = DataLoader(GenePairDataset(test_data), batch_size=args.test_batch_size, shuffle=False)

        best_metrics = {'auc': 0, 'aupr': 0, 'f1': 0, 'bacc': 0, 'kappa': 0}
        # Per-epoch validation metrics for this fold.
        auc_sum_fold = []
        aupr_sum_fold = []
        f1_sum_fold = []
        bacc_sum_fold = []
        kappa_sum_fold = []
        best_model_path = os.path.join(fold_log_dir, 'best_model.pth')

        logger.info(f"Training {cell_line} fold {fold}")

        for epoch in tqdm(range(args.epoch)):
            # Train
            HGT_ESM_CLdata_4Classification_model.train()
            loss_sum = 0
            for batch in tqdm(train_loader):
                optimizer_model.zero_grad()
                label, kg_batch, esm_a, esm_b, cell_line_a, cell_line_b = prepare_batch(batch, cell_line_feature_table)
                prediction_result = HGT_ESM_CLdata_4Classification_model(args.node_type, args.train_batch_size, kg_batch, batch, esm_a, esm_b, cell_line_a, cell_line_b)
                loss = criterion(prediction_result, label.long())
                loss.backward()
                loss_sum += loss.item()
                torch.nn.utils.clip_grad_norm_(HGT_ESM_CLdata_4Classification_model.parameters(), max_norm=1.0)
                optimizer_model.step()

            # Validate
            HGT_ESM_CLdata_4Classification_model.eval()
            prediction_result_log_epoch = []
            label_log_epoch = []
            with torch.no_grad():
                for batch in tqdm(test_loader):
                    label, kg_batch, esm_a, esm_b, cell_line_a, cell_line_b = prepare_batch(batch, cell_line_feature_table)
                    # Pass the batch size of the loader that produced this batch (test, not train).
                    prediction_result = HGT_ESM_CLdata_4Classification_model(args.node_type, args.test_batch_size, kg_batch, batch, esm_a, esm_b, cell_line_a, cell_line_b)
                    prediction_result_log_epoch.append(prediction_result.detach().cpu().numpy())
                    label_log_epoch.append(label.tolist())
            prediction_result_log_epoch = np.concatenate(prediction_result_log_epoch)
            label_log_epoch_flat = np.array(list(chain.from_iterable(label_log_epoch)))

            aucu, aupr, f1, kappa, bacc = compute_accuracy(label_log_epoch_flat, prediction_result_log_epoch.argmax(axis=1), prediction_result_log_epoch)
            auc_sum_fold.append(aucu)
            aupr_sum_fold.append(aupr)
            f1_sum_fold.append(f1)
            bacc_sum_fold.append(bacc)
            kappa_sum_fold.append(kappa)
            # Model selection on validation AUC; the metrics stored are those of the best-AUC epoch.
            if aucu > best_metrics['auc']:
                best_metrics = {'auc': aucu, 'aupr': aupr, 'f1': f1, 'bacc': bacc, 'kappa': kappa}
                torch.save(HGT_ESM_CLdata_4Classification_model.state_dict(), best_model_path)
            logger.info(f"Epoch {epoch}, Loss: {loss_sum}, AUC: {aucu}, AUPR: {aupr}, F1: {f1}, Kappa: {kappa}, BAcc: {bacc}")

        eval_metric_folds['fold'].append(fold)
        for metric_name in ('auc', 'aupr', 'f1', 'bacc', 'kappa'):
            eval_metric_folds[metric_name].append(best_metrics[metric_name])

    avg_metrics = {key: np.mean(values) for key, values in eval_metric_folds.items() if key != 'fold'}
    # cell line summary: mean over folds, and the per-metric maximum over folds
    logger.info(f"{cell_line} Average Metrics:{avg_metrics}")
    best_metrics = {key: max(values) for key, values in eval_metric_folds.items() if key != 'fold'}
    logger.info(f"{cell_line} Best Metrics: {best_metrics}")