In [1]:
import numpy as np
import argparse
import json
import logging
from time import time
import os
import torch_geometric.transforms as T
from MyLoader import HeteroDataset
from torch_geometric.loader import HGTLoader, NeighborLoader
# from dataloader import DataLoaderMasking 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from model import HGT
import pandas as pd
import pickle
import math
from torch_geometric.datasets import OGB_MAG
import torch.nn.init as init
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score, roc_auc_score,auc,balanced_accuracy_score,cohen_kappa_score,precision_recall_curve

In [2]:
class args :
    def __init__(self):
        self.Full_data_path=r'../data/download_data/kgdata.pkl'
        self.node_type='gene/protein'
        self.Task_data_path = '../data/original_data/Cell_line_specific'
        self.cv = 'CV3'
        self.n_fold = 5
        self.device = 'cuda'
        self.do_low_data = False
        self.sample_nodes = 1024
        self.sample_layers = 4
        self.num_workers = 8
        self.specific = True # Cell line specific
        self.adapted = True # Cell line adapted
        self.cell_line_list = ['A375','Jurkat','A549']
        self.test_cell_line = 'A549'
        self.freeze_graph_encoder = True
        self.freeze_esm_encoder = True
        self.folds = 5
        self.do_train = True
        
args=args()

In [4]:
def set_logger(args):
    '''
    Write logs to checkpoint and console 
    '''

    if args.do_train:
        # train_log=str(linear_layer_count)+'_'+args.lr+'_'+'train.log'
        log_file = os.path.join(args.Save_model_path or args.init_checkpoint, 'train.log') 
    else:
        log_file = os.path.join(args.Save_model_path or args.init_checkpoint, 'test.log') 
    
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s', 
        level=logging.INFO,  # 
        datefmt='%Y-%m-%d %H:%M:%S', 
        filename=log_file, 
        filemode='w'  
    )
    console = logging.StreamHandler() # 
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 
    console.setFormatter(formatter) 
    logging.getLogger('').addHandler(console) 
    
def compute_accuracy(target,pred, pred_edge):
    
    target=target.clone().detach().cpu().numpy()
    pred=pred.clone().detach().cpu().numpy()
    pred_edge=pred_edge.clone().detach().cpu()
    scores = torch.softmax(pred_edge, 1).numpy()
    target=target.astype(int)
   
    aucu=roc_auc_score(target,scores[:,1])
    precision_tmp, recall_tmp, _thresholds = precision_recall_curve(target, pred)
    aupr = auc(recall_tmp, precision_tmp)
    f1 = f1_score(target,pred)
    kappa=cohen_kappa_score(target,pred)
    bacc=balanced_accuracy_score(target,pred)
    
    return aucu,aupr,f1,kappa,bacc
def Downstream_data_preprocess(args,n_fold,node_type_dict,cell_line): #FIXME
    """
    load SL data and preprocess before training 
    """
    task_data_path=args.Task_data_path
    train_data=pd.read_csv(f"{task_data_path}/{cell_line}/sl_train_{n_fold}.csv")
    test_data=pd.read_csv(f"{task_data_path}/{cell_line}/sl_test_{n_fold}.csv",)
    train_data.columns=[0,1,2,3]
    test_data.columns=[0,1,2,3]
    train_data[0]=train_data[0].astype(str).map(node_type_dict)
    train_data[1]=train_data[1].astype(str).map(node_type_dict)
    test_data[0]=test_data[0].astype(str).map(node_type_dict)
    test_data[1]=test_data[1].astype(str).map(node_type_dict)
    train_data=train_data.dropna()
    test_data=test_data.dropna()
    train_data[0]=train_data[0].astype(int)
    train_data[1]=train_data[1].astype(int)
    test_data[0]=test_data[0].astype(int)
    test_data[1]=test_data[1].astype(int)
    # low data scenario settings
    if args.do_low_data:
        num_sample=int(train_data.shape[0]*args.train_data_ratio)
        print(num_sample)
        train_data=train_data.sample(num_sample,replace=False,random_state=0)
        train_data.reset_index(inplace=True)
        print(f'train_data.size:{train_data.shape[0]}')

    train_node=list(set(train_data[0])|set(train_data[1]))
    print(f'train_node.size:{len(train_node)}')
    train_mask=torch.zeros((27671))
    test_mask=torch.zeros((27671))
    test_node=list(set(test_data[0])|set(test_data[1]))
    train_mask[train_node]=1
    test_mask[test_node]=1
    train_mask=train_mask.bool()
    test_mask=test_mask.bool()
    num_train_node=len(train_node)
    num_test_node=len(test_node)
    return train_data,test_data,train_mask,test_mask,num_train_node,num_test_node

def Construct_loader(args,kgdata,train_mask,test_mask,node_type,num_train_node,num_test_node):
    """
    construct loader for train/test data
    """

    train_loader = HGTLoader(kgdata,
    num_samples={key: [args.sample_nodes] * args.sample_layers for key in kgdata.node_types},shuffle=False,
    batch_size=num_train_node,
    input_nodes=(node_type,train_mask),num_workers=args.num_workers)

    test_loader=HGTLoader(kgdata,
    num_samples={key: [args.sample_nodes] * args.sample_layers for key in kgdata.node_types},
    batch_size=num_test_node,
    input_nodes=(node_type,test_mask),num_workers=args.num_workers,shuffle=False)

    return train_loader,test_loader

class MyModel(nn.Module):
    def __init__(
        self,
        graph_embedding_dim = 128,
        transformer_depth = 3,
        num_transformer_heads = 4,
        num_classes = 2,
        graph_encoder = HGT,
        ):
        

In [6]:
# main
set_logger(args)
with open (args.Full_data_path,'rb') as f:
    kgdata=pickle.load(f)
with open("../data/processed_data/gene_protein_2_id.json",'rb') as f:
    node_index=json.load(f)
torch.manual_seed(0)
np.random.seed(0)
device = torch.device(args.device)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

gene_protein=node_index[args.node_type] 
eval_metric_folds={'fold':[],'auc':[],'aupr':[],'f1':[],'bacc':[],'kappa':[]}
node_type=args.node_type
num_nodes_type=len(kgdata.node_types)
num_edge_type=len(kgdata.edge_types)
num_nodes=kgdata.num_nodes
input_node_embeddings = torch.nn.Embedding(num_nodes_type, 16)
torch.nn.init.xavier_uniform_(input_node_embeddings.weight.data)
for i in range(len(kgdata.node_types)):
    num_repeat=kgdata[kgdata.node_types[i]].x.shape[0]
    kgdata[kgdata.node_types[i]].x =input_node_embeddings(torch.tensor(i)).repeat([num_repeat,1]).detach()

if args.specific:# Do cv within cell line
    logging.info(f"Cell line specific")
    for i in range(args.folds):
        n_fold = i
        train_data,test_data,train_mask,test_mask,num_train_node,num_test_node=Downstream_data_preprocess(args,args.cv,n_fold,gene_protein)
        train_loader,test_loader=Construct_loader(args,kgdata,train_mask,test_mask,node_type,num_train_node,num_test_node)


In [19]:
train_data

Unnamed: 0,0,1,2,3
0,27482,14366,A375,0
1,1592,16112,A375,0
2,13513,15543,A375,0
3,15020,25816,A375,0
4,27431,2881,A375,0
...,...,...,...,...
423,16390,26768,A375,0
424,19399,14930,A375,1
425,26053,26913,A375,0
426,7635,1035,A375,0
