In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader , random_split,SubsetRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import pickle
from tqdm import tqdm
import networkx as nx
from networkx.algorithms.link_analysis import pagerank
import pickle
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
import warnings
from cell_net_omics import cellNetDataset
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

# Read the three per-cell-line omics tables in a fixed order (exp, mut, cnv).
# Judging by the file names these are expression, mutation and copy-number
# tables -- TODO confirm the column layout against cellNetDataset.
exp, mut, cnv = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/home/data/sdb/wt/model_data/cell_gene_exp_vs_normal_filter.csv",
        "/home/data/sdb/wt/model_data/mut_dt.csv",
        "/home/data/sdb/wt/model_data/cnv_dt.csv",
    )
)

# Arguments for the project-local graph dataset; the omics frames loaded above
# are handed over unchanged.
cell_net_kwargs = dict(
    root="/home/data/sdb/wt/model_data/omics_net",
    filename="train_cell_info_omics.csv",
    exp=exp,
    mut=mut,
    cnv=cnv,
    data_type="train",
    net_path="/home/data/sdb/wt/model_data/enzyme_train/",
    cores=30,
)
cell_net = cellNetDataset(**cell_net_kwargs)

from sklearn.model_selection import KFold
import random

# One seed drives everything stochastic in this notebook. Previously only the
# KFold split was seeded, so weight initialisation and sampler shuffling
# differed between runs even though the folds were identical.
SEED = 2023052701
random.seed(SEED)
np.random.seed(SEED)      # must fit in 32 bits; 2023052701 < 2**32
torch.manual_seed(SEED)   # seeds the default torch generator

# 10-fold shuffled cross-validation over dataset indices.
splits = KFold(n_splits=10, shuffle=True, random_state=SEED)

In [2]:
from torch_geometric.nn import GATv2Conv
class Net(torch.nn.Module):
    """Graph-attention network fused with three omics MLP encoders.

    Node features pass through three ``GATv2Conv`` layers. Each omics input
    (``exp``, ``mut``, ``cnv``) is compressed by its own three-layer MLP. The
    four embeddings are concatenated column-wise and fed to a three-layer
    head that emits a single logit per row.

    Args:
        node_dim: Width of the node feature matrix ``x``.
        exp_dim: Width of the ``exp`` input.
        mut_dim: Width of the ``mut`` input.
        cnv_dim: Width of the ``cnv`` input.
        hidden_dim: Per-head output width of every GAT layer.
        heads: Number of attention heads in every GAT layer.
        omics_dim: Output width of each omics encoder.

    The defaults reproduce the original hard-coded architecture, so ``Net()``
    builds exactly the same layers (same attribute names, same creation
    order) as before.
    """

    def __init__(self, node_dim=3247, exp_dim=7993, mut_dim=6806, cnv_dim=6336,
                 hidden_dim=512, heads=3, omics_dim=512):
        super().__init__()
        # GATv2Conv concatenates its heads, so each layer outputs heads * hidden_dim.
        gat_out = heads * hidden_dim
        self.conv1 = GATv2Conv(node_dim, hidden_dim, heads=heads)
        self.conv2 = GATv2Conv(gat_out, hidden_dim, heads=heads)
        self.conv3 = GATv2Conv(gat_out, hidden_dim, heads=heads)
        # Head input = graph embedding + the three omics embeddings.
        self.lin1 = torch.nn.Linear(gat_out + 3 * omics_dim, 1024)
        self.lin2 = torch.nn.Linear(1024, 512)
        self.lin3 = torch.nn.Linear(512, 1)

        self.encoder_exp = self._make_encoder(exp_dim, omics_dim)
        self.encoder_mut = self._make_encoder(mut_dim, omics_dim)
        self.encoder_cnv = self._make_encoder(cnv_dim, omics_dim)

    @staticmethod
    def _make_encoder(in_dim, out_dim):
        """Build the in_dim -> 4000 -> 1500 -> out_dim MLP shared by all omics inputs."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_dim, 4000),
            torch.nn.ReLU(),
            torch.nn.Linear(4000, 1500),
            torch.nn.ReLU(),
            torch.nn.Linear(1500, out_dim),
        )

    def forward(self, x, edge_index, exp, mut, cnv):
        """Compute one logit per row of ``x``.

        Args:
            x: Node feature matrix, ``node_dim`` columns.
            edge_index: Graph connectivity passed to every GAT layer.
            exp, mut, cnv: Omics inputs. They are concatenated with the node
                embedding along dim 1, so each must have the same number of
                rows as ``x``.

        Returns:
            Tensor with one column holding raw (pre-sigmoid) scores.
        """
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        # The attention weights of the last layer were requested but never
        # used; the layer output is the same without asking for them.
        x = torch.relu(self.conv3(x, edge_index))

        exp_end = self.encoder_exp(exp)
        mut_end = self.encoder_mut(mut)
        cnv_end = self.encoder_cnv(cnv)

        cat_feat = torch.cat((x, exp_end, mut_end, cnv_end), 1)
        out = torch.relu(self.lin1(cat_feat))
        out = torch.relu(self.lin2(out))
        out = self.lin3(out)
        return out

In [3]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Binary cross-entropy applied directly to raw logits (the sigmoid is folded
# into the loss), shared by the training and evaluation loops.
loss_op = torch.nn.BCEWithLogitsLoss()

In [4]:
from tqdm import tqdm

def train():
    """Run one optimisation epoch over the notebook-global ``train_loader``.

    Relies on the globals ``model``, ``train_loader``, ``optimizer``,
    ``loss_op`` and ``device`` being defined before the call.

    Returns:
        The batch losses weighted by ``num_graphs`` and averaged over the
        graphs actually iterated. The previous version divided by
        ``len(train_loader.dataset)``; with a subset sampler that is the size
        of the whole dataset rather than of the training fold, so the reported
        loss was too small. Returns ``0.0`` if the loader yields nothing.
    """
    model.train()

    total_loss = 0.0
    graphs_seen = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        pred = model(data.x.float(), data.edge_index, data.exp.float(), data.mut.float(), data.cnv.float())
        # Only the rows selected by y_index carry a label.
        pred = pred[data.y_index]
        loss = loss_op(
            pred,
            data.y.float().reshape(-1, 1),
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs
    return total_loss / graphs_seen if graphs_seen else 0.0


@torch.no_grad()
def test(epoch=None):
    """Evaluate the notebook-global ``model`` on ``test_loader``.

    Relies on the globals ``model``, ``test_loader``, ``loss_op`` and
    ``device``; prints metrics through ``calculate_metrics``.

    Args:
        epoch: Epoch number forwarded to ``calculate_metrics``. When omitted,
            a notebook-global ``epoch`` is used if one exists (the original
            behaviour), otherwise ``None`` -- so the function no longer raises
            ``NameError`` when called outside the training loop.

    Returns:
        ``(avg_loss, res)``. ``avg_loss`` is the batch losses weighted by
        ``num_graphs`` and averaged over the graphs actually iterated. It was
        previously divided by ``len(test_loader.dataset)``, the size of the
        whole dataset rather than of the evaluated subset. ``res`` is a
        DataFrame with columns ``preds`` (rounded probabilities),
        ``preds_raw`` (sigmoid outputs), ``label`` and ``genes``.
    """
    if epoch is None:
        epoch = globals().get("epoch")

    model.eval()

    ys, preds, preds_raw, genes = [], [], [], []
    total_loss = 0.0
    graphs_seen = 0
    for data in tqdm(test_loader):
        data = data.to(device)
        out = model(data.x.float(), data.edge_index, data.exp.float(), data.mut.float(), data.cnv.float())
        # Only the rows selected by y_index carry a label.
        out = out[data.y_index]
        loss = loss_op(
            out,
            data.y.float().reshape(-1, 1),
        )
        total_loss += loss.item() * data.num_graphs
        graphs_seen += data.num_graphs

        # Compute the probabilities once and derive the hard labels from them.
        probs = torch.sigmoid(out).detach().cpu().numpy()
        ys.append(data.y.detach().cpu().numpy())
        preds.append(np.rint(probs))
        preds_raw.append(probs)

        # nodename is iterated as one sequence of names per graph in the
        # batch; flatten it so y_index can address it like the node rows.
        gene_name = np.array([i for batch in data.nodename for i in batch])
        genes.append(gene_name[data.y_index.cpu().numpy()])

    all_preds = np.concatenate(preds).ravel()
    all_labels = np.concatenate(ys).ravel()
    all_preds_raw = np.concatenate(preds_raw).ravel()
    all_genes = np.concatenate(genes).ravel()

    res = pd.DataFrame({"preds":all_preds,"preds_raw":all_preds_raw,"label":all_labels,"genes":all_genes})
    calculate_metrics(all_preds, all_labels, epoch, "test")
    avg_loss = total_loss / graphs_seen if graphs_seen else 0.0
    return avg_loss, res


def calculate_metrics(y_pred, y_true, epoch, type, y_score=None):
    """Print binary-classification metrics for one evaluation pass.

    Args:
        y_pred: Hard 0/1 predictions.
        y_true: Ground-truth 0/1 labels.
        epoch: Accepted for interface compatibility; not used.
        type: Accepted for interface compatibility; not used.
        y_score: Optional continuous scores (e.g. sigmoid probabilities) for
            the ROC AUC. When omitted the AUC is computed from ``y_pred`` as
            before, which only evaluates a single operating point.

    Returns:
        None. Everything is written to stdout.
    """
    print(f"\n Confusion matrix: \n {confusion_matrix(y_true, y_pred)}")
    print(f"F1 Score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")

    auc_input = y_pred if y_score is None else y_score
    try:
        roc = roc_auc_score(y_true, auc_input)
        print(f"ROC AUC: {roc}")
    except ValueError:
        # Raised when the AUC is undefined (e.g. only one class in y_true).
        # Narrowed from a bare except so interrupts and real bugs still surface.
        print("ROC AUC: notdefined")

In [None]:
import os
from torch_geometric.loader import DataLoader

# Cross-validation settings, collected here so they can be tuned in one place.
N_EPOCHS = 15
BATCH_SIZE = 3
NUM_WORKERS = 10
LEARNING_RATE = 0.001
CV_OUTPUT_DIR = "cv/multi_omics"

# Create the output directory before any training starts. Otherwise a missing
# directory is only discovered by to_csv after a whole fold has been trained.
os.makedirs(CV_OUTPUT_DIR, exist_ok=True)

for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(cell_net)))):

    print('Fold {}'.format(fold + 1))

    # Both loaders wrap the full dataset; the samplers restrict each one to
    # the indices of its fold.
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    train_loader = DataLoader(cell_net, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=NUM_WORKERS)
    test_loader = DataLoader(cell_net, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=NUM_WORKERS)

    # Fresh model and optimizer per fold so no weights leak between folds.
    model = Net().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(N_EPOCHS):
        loss_train = train()
        print(f"Epoch: {epoch:03d}, Train Loss: {loss_train:.4f}")

    # Evaluate once per fold, after the last training epoch.
    loss_test, test_pre = test()
    print(f"Epoch: {epoch:03d}, Test Loss: {loss_test:.4f}")
    # File names stay 0-based (fold_0.csv ...) to match previously written results.
    test_pre.to_csv(f"{CV_OUTPUT_DIR}/fold_{fold}.csv")