In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn  
import pickle
from tqdm import tqdm
import networkx as nx
import pickle
import os.path
import os
import argparse
import shutil
import warnings
from cell_net_omics import cellNetDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv

# Load the three omics tables (expression, mutation, copy-number) in one pass;
# the paths are read in the same order as the names on the left-hand side.
exp, mut, cnv = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/home/data/sdb/wt/model_data/cell_gene_exp_vs_normal_filter.csv",
        "/home/data/sdb/wt/model_data/mut_dt.csv",
        "/home/data/sdb/wt/model_data/cnv_dt.csv",
    )
)

# Build the training graph dataset from the omics tables.
# NOTE(review): `cores` is presumably the number of worker processes used
# while processing -- confirm in cell_net_omics.
cell_net = cellNetDataset(
    root="/home/data/sdb/wt/model_data/omics_net",
    filename="train_cell_info_omics.csv",
    data_type="train",
    net_path="/home/data/sdb/wt/model_data/enzyme_train/",
    cores=30,
    exp=exp,
    mut=mut,
    cnv=cnv,
)

Processing...
Done!


In [3]:
# Display the first graph in the dataset to inspect its attributes and tensor shapes.
cell_net[0]

Data(edge_index=[2, 14216], nodename=[824], x=[824, 3247], y=[378], y_index=[378], exp=[824, 7993], mut=[824, 6806], cnv=[824, 6336], cell='ACH-000001')

In [4]:
class Net(torch.nn.Module):
    """Graph attention network with per-node omics encoders.

    Three stacked ``GATv2Conv`` layers (3 heads x 512 channels each) embed the
    node features, while three independent MLP encoders compress the
    expression, mutation and copy-number matrices to 512 features apiece.
    The four embeddings are concatenated per node (3 * 512 + 3 * 512 = 3072
    features) and passed through a three-layer MLP head that emits one value
    per node. No sigmoid is applied here; callers apply it themselves.
    """

    @staticmethod
    def _omics_encoder(in_features):
        """Return the MLP ``in_features -> 4000 -> 1500 -> 512`` used for one omics matrix."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, 4000),
            torch.nn.ReLU(),
            torch.nn.Linear(4000, 1500),
            torch.nn.ReLU(),
            torch.nn.Linear(1500, 512)
        )

    def __init__(self):
        super().__init__()
        # Each GAT layer outputs heads * 512 = 1536 channels, which is why the
        # following layer takes 3 * 512 inputs.
        self.conv1 = GATv2Conv(3247, 512, heads=3)
        self.conv2 = GATv2Conv(3 * 512, 512, heads=3)
        self.conv3 = GATv2Conv(3 * 512, 512, heads=3)
        # Head input = GAT embedding (3 * 512) + three omics embeddings (512 * 3).
        self.lin1 = torch.nn.Linear(3 * 512 + 512 * 3, 1024)
        self.lin2 = torch.nn.Linear(1024, 512)
        self.lin3 = torch.nn.Linear(512, 1)

        # Attribute names and Sequential layout are kept as-is so existing
        # checkpoints (state-dict keys) still load.
        self.encoder_exp = self._omics_encoder(7993)
        self.encoder_mut = self._omics_encoder(6806)
        self.encoder_cnv = self._omics_encoder(6336)

    def forward(self, x, edge_index, exp, mut, cnv):
        """Compute per-node outputs and the last layer's attention weights.

        Args:
            x: Node feature matrix with 3247 columns.
            edge_index: Graph connectivity passed to every GAT layer.
            exp: Per-node expression matrix with 7993 columns.
            mut: Per-node mutation matrix with 6806 columns.
            cnv: Per-node copy-number matrix with 6336 columns.

        Returns:
            A tuple ``(out, (idx, atten))`` where ``out`` has one column per
            node row, and ``(idx, atten)`` is the edge index / attention
            weight pair returned by ``conv3``.
        """
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        x, (idx, atten) = self.conv3(x, edge_index, return_attention_weights=True)
        exp_end = self.encoder_exp(exp)
        mut_end = self.encoder_mut(mut)
        cnv_end = self.encoder_cnv(cnv)

        x = torch.relu(x)
        # Fixed: the original concatenated an undefined name `cell_end`; the
        # expression embedding computed above is `exp_end`.
        cat_feat = torch.cat((x, exp_end, mut_end, cnv_end), 1)
        out = torch.relu(self.lin1(cat_feat))
        out = torch.relu(self.lin2(out))
        out = self.lin3(out)
        return out, (idx, atten)

In [3]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the network and move its parameters to the chosen device
# (Module.to returns the module itself, so `model` stays the same object).
model = Net()
model = model.to(device)

# Restore the trained weights, remapping stored tensors onto `device`.
model.load_state_dict(
    torch.load("/root/cancer_target/enzyme_model_filterV2.pt", map_location=device)
)

<All keys matched successfully>

In [4]:
# Wrap the dataset in a loader that yields one graph per batch.
train_loader = DataLoader(cell_net, batch_size=1)

In [5]:
# Grab the first batch from the loader for ad-hoc inspection.
tt = next(iter(train_loader))

In [5]:
# Run the trained model over every graph in the loader and write the last GAT
# layer's per-edge attention weights to one TSV file per cell line.
out_dir = "/root/autodl-tmp/atten"
os.makedirs(out_dir, exist_ok=True)  # to_csv does not create missing directories

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for data in tqdm(train_loader):
        data = data.to(device)
        # Fixed: Net.forward requires the mutation and copy-number matrices in
        # addition to expression; the original call passed only `exp`.
        res, atten = model(data.x.float(), data.edge_index,
                           data.exp.float(), data.mut.float(), data.cnv.float())
        idx = atten[0].cpu().numpy()
        weight = atten[1].cpu().numpy()
        # Sigmoid + round turns the raw model output into a 0/1 prediction per node.
        res = np.rint(torch.sigmoid(res.cpu()).numpy().flatten())

        # Convert to data frames: one row per edge, one weight column per head.
        idx_dt = pd.DataFrame(np.transpose(idx), columns=["source", "target"])
        weight_dt = pd.DataFrame(weight, columns=["w1", "w2", "w3"])
        all_res = pd.concat([idx_dt, weight_dt], axis=1)
        # Optional filter (disabled): keep only edges touching the nodes in y_index.
        # need_idx = data.y_index.cpu().detach().numpy()
        # all_res_filter = all_res[(all_res.source.isin(need_idx) | all_res.target.isin(need_idx))].reset_index().drop(columns=["index"])

        # With batch_size=1 the batched `nodename` holds a single list of names.
        need_genes = np.array(data.nodename[0])
        gene_idx = pd.DataFrame({"source": range(len(need_genes)),
                                 "target": range(len(need_genes)),
                                 "genes": need_genes,
                                 "preds": res})
        # Attach gene name and prediction for both endpoints of each edge;
        # pandas suffixes the duplicated columns with _x (source) and _y (target).
        all_res = (
            all_res
            .merge(gene_idx.drop(columns=["target"]), on="source", how="left")
            .merge(gene_idx.drop(columns=["source"]), on="target", how="left")
        )
        cell_id = data.cell[0]
        all_res["cell"] = cell_id  # scalar broadcasts to every row
        all_res.to_csv(os.path.join(out_dir, cell_id + ".tsv"), sep="\t")

100%|██████████| 684/684 [03:42<00:00,  3.07it/s]


In [6]:
# Run a single forward pass on the batch fetched earlier to inspect the outputs.
data = tt.to(device)
# Fixed: Net.forward requires the mutation and copy-number matrices in
# addition to expression; the original call passed only `exp`.
res, atten = model(data.x.float(), data.edge_index,
                   data.exp.float(), data.mut.float(), data.cnv.float())

In [7]:
# Display the (edge index, attention weights) pair returned by the model.
atten

(tensor([[  0,   0,   0,  ..., 821, 822, 823],
         [  1,  43,  46,  ..., 821, 822, 823]], device='cuda:0'),
 tensor([[1.4214e-15, 1.4970e-13, 7.2325e-15],
         [5.0834e-35, 6.9566e-20, 1.9366e-34],
         [1.5292e-20, 6.8977e-22, 1.8351e-21],
         ...,
         [1.0000e+00, 1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00],
         [3.8354e-05, 1.1435e-01, 1.6849e-05]], device='cuda:0',
        grad_fn=<DivBackward0>))