In [1]:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np 
import os
from tqdm import tqdm
import copy
# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class pMHCDataset(Dataset):
    """On-disk graph dataset of peptide / HLA pairs read from a CSV file.

    Every CSV row becomes one fully connected graph (no self-loops) whose
    nodes are the residues of ``pep`` followed by the residues of ``hla_seq``.
    Graphs are written to ``processed_dir`` as ``data_{row}.pt`` by
    ``process`` and read back one at a time by ``get``.
    """

    def __init__(self, root, filename, aaindex, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        filename = Name of the CSV file expected in raw_dir. The columns
        "pep", "hla_seq" and "type" are read from it.
        aaindex = Lookup table of per-residue descriptors; ``aaindex[aa]``
        must return an object exposing ``to_list()`` for every residue letter.
        """
        self.filename = filename
        self.aaindex = aaindex
        # Parsed raw CSV, filled lazily. It is set before the base-class
        # constructor runs because that constructor may already query
        # `processed_file_names`.
        self._raw_table = None
        super().__init__(root, transform, pre_transform)

    def _read_raw_table(self):
        """Parse the raw CSV once and hand back the cached frame afterwards."""
        if getattr(self, "_raw_table", None) is None:
            self._raw_table = pd.read_csv(self.raw_paths[0])
        return self._raw_table

    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)
        """
        return self.filename

    @property
    def processed_file_names(self):
        """ If these files are found in processed_dir, processing is skipped"""
        # Also populates `self.data`, which `len()` relies on when processing
        # is skipped. The CSV itself is parsed only once (see _read_raw_table).
        self.data = self._read_raw_table().reset_index()
        return [f'data_{i}.pt' for i in self.data.index]

    def download(self):
        # Nothing to download: the raw CSV is expected to be in raw_dir already.
        pass

    def process(self):
        """Convert every CSV row into a `Data` object and save it to processed_dir."""
        self.data = self._read_raw_table()
        # tqdm shows a progress bar while the rows are converted.
        for index, sample in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            # Get node features
            node_feats = self._get_node_features(sample["pep"], sample["hla_seq"], self.aaindex)
            edge_index = self._get_edge_index(sample["pep"], sample["hla_seq"])
            label = self._get_labels(sample["type"])
            # Create data object
            # NOTE(review): `index` is stored as the constant 0 for every graph;
            # confirm whether the row index was intended here.
            data = Data(x=node_feats, edge_index=edge_index, y=label, index=0)
            torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))

    def _get_node_features(self, pep, HLA, aaindex):
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]

        Each row is the descriptor list ``aaindex[aa].to_list()`` of the
        residue followed by a two-element one-hot flag: [1, 0] for peptide
        positions and [0, 1] for HLA positions.

        Raises KeyError if a residue letter is missing from ``aaindex``.
        """
        pep_len = len(pep)
        descriptors = {}  # residue letter -> descriptor list, looked up once
        all_node_feats = []
        for position, aa in enumerate(pep + HLA):
            if aa not in descriptors:
                try:
                    descriptors[aa] = aaindex[aa].to_list()
                except KeyError:
                    raise KeyError(
                        f"residue {aa!r} at position {position} has no entry in aaindex"
                    ) from None
            segment_onehot = [1, 0] if position < pep_len else [0, 1]
            all_node_feats.append(descriptors[aa] + segment_onehot)
        return torch.tensor(np.asarray(all_node_feats))

    def _get_labels(self, label):
        """Wrap a single label value into a 1-element tensor."""
        label = np.asarray([label])
        return torch.tensor(label)

    def _get_edge_index(self, pep, hla):
        """Build the edge list of a fully connected graph without self-loops.

        Nodes are numbered 0 .. len(pep)+len(hla)-1. The result has shape
        [2, n*(n-1)] and dtype long; edges are grouped by the node in row 0,
        and within a group the row-1 nodes are in ascending order.
        """
        n_nodes = len(pep) + len(hla)
        node_ids = torch.arange(n_nodes, dtype=torch.long)
        row = node_ids.repeat_interleave(n_nodes)   # 0,0,...,1,1,...
        col = node_ids.repeat(n_nodes)              # 0,1,...,0,1,...
        off_diagonal = row != col                   # drop self-loops
        return torch.stack([row[off_diagonal], col[off_diagonal]])

    def len(self):
        """Number of samples, i.e. rows of the raw CSV held in `self.data`."""
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [2]:
# Per-residue descriptor table used to build the node features
# (presumably PCA-reduced AAindex values, judging by the file name).
aaindex = pd.read_csv("../model/aaindex1_pca.csv")

# Positional arguments are (root, filename, aaindex).
train_dt = pMHCDataset(
    "/home/data/sda/wt/Neodb_model/",
    "train_data_iedb_2.csv",
    aaindex,
)

In [3]:
from torch_geometric.nn import TransformerConv,  GraphNorm
from torch.nn import Linear, ModuleList, LeakyReLU
from gtrick.pyg import VirtualNode
from torch.nn import LeakyReLU
class GNN(torch.nn.Module):
    """Graph transformer with one virtual node per message-passing layer.

    An initial TransformerConv block is followed by ``model_layers`` blocks of
    (virtual-node update -> TransformerConv -> Linear -> LeakyReLU -> GraphNorm).
    The prediction is computed from the virtual-node embedding produced by the
    last block, and the attention weights of one block are returned with it.
    """

    def __init__(self, feature_size, model_params):
        """
        feature_size = Number of input features per node.
        model_params = Mapping with the keys "model_embedding_size",
        "model_dense_neurons", "model_heads" and "model_layers". The optional
        key "model_attention_layer" (default 2) selects the block whose
        attention weights `forward` returns; it is clamped to the last block.

        Raises ValueError if "model_layers" is smaller than 1, because
        `forward` needs at least one block to produce its output.
        """
        super().__init__()

        embedding_size = model_params["model_embedding_size"]
        dense_neurons = model_params["model_dense_neurons"]
        n_heads = model_params["model_heads"]
        n_layers = model_params["model_layers"]
        if n_layers < 1:
            raise ValueError(f"model_layers must be >= 1, got {n_layers}")
        self.n_layers = n_layers
        # Block index whose attention is exposed. The default of 2 matches the
        # previously hard-coded value; clamping keeps shallower models usable.
        self.attention_layer = min(model_params.get("model_attention_layer", 2),
                                   n_layers - 1)
        # `top_k_every_n` and `pooling_layers` are not used by `forward`; they
        # are kept so the module's attribute set stays unchanged.
        self.top_k_every_n = 1
        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])
        self.vns = ModuleList()
        self.relu = LeakyReLU()

        # Transformation layer
        self.conv1 = TransformerConv(feature_size,
                                    embedding_size,
                                    heads=n_heads,
                                    beta=True)

        # The conv output is embedding_size*n_heads wide (see the Linear input
        # size); project it back to embedding_size.
        self.transf1 = Linear(embedding_size*n_heads, embedding_size)
        self.bn1 = GraphNorm(embedding_size)
        # Other layers
        for i in range(n_layers):
            self.conv_layers.append(TransformerConv(embedding_size,
                                                    embedding_size,
                                                    heads=n_heads,
                                                    beta=True))

            self.transf_layers.append(Linear(embedding_size*n_heads, embedding_size))
            self.bn_layers.append(GraphNorm(embedding_size))
            self.vns.append(VirtualNode(embedding_size, embedding_size))

        # Linear layers
        self.linear1 = Linear(embedding_size, dense_neurons)
        self.linear2 = Linear(dense_neurons, 1)

    def forward(self, x, edge_index, batch_index):
        """Run the network on a (batched) graph.

        x = Node feature matrix.
        edge_index = Edge list of the batch.
        batch_index = Graph assignment of every node.

        Returns ``(out, (edge, attention_weights))`` where ``out`` is the
        output of the final linear layer and the second element is what the
        selected TransformerConv block returned for
        ``return_attention_weights=True``. Both entries of the inner tuple are
        None only if "model_attention_layer" was configured as a negative index.
        """
        # Initial transformation
        x = self.conv1(x, edge_index)
        x = self.relu(self.transf1(x))
        x = self.bn1(x, batch_index)

        # Pre-bind so the return statement never hits an unbound local.
        edge, attention_weights = None, None
        for i in range(self.n_layers):
            x, vx = self.vns[i].update_node_emb(x, edge_index, batch_index)
            if i == self.attention_layer:
                x, (edge, attention_weights) = self.conv_layers[i](x, edge_index, return_attention_weights=True)
            else:
                x = self.conv_layers[i](x, edge_index)
            x = self.relu(self.transf_layers[i](x))
            x = self.bn_layers[i](x, batch_index)
            vx = self.vns[i].update_vn_emb(x, batch_index, vx)

        # Output block: the prediction is read from the virtual-node embedding.
        x = self.relu(self.linear1(vx))
        x = self.linear2(x)
        return x, (edge, attention_weights)

# Architecture settings; they have to match the checkpoint loaded below.
HYPERPARAMETERS = dict(
    model_embedding_size=64,
    model_dense_neurons=32,
    model_heads=3,
    model_layers=3,
)

# Only the "model_"-prefixed entries are forwarded to the network.
model_params = {
    name: HYPERPARAMETERS[name]
    for name in HYPERPARAMETERS
    if name.startswith("model_")
}

# feature_size=22 is the per-node feature width the network is built for.
model = GNN(feature_size=22, model_params=model_params)

# Restore the trained weights onto the CPU regardless of where they were saved.
model_file = "last_model.pt"
model.load_state_dict(
    torch.load(model_file, map_location=torch.device('cpu'))
)

<All keys matched successfully>

In [4]:
# Number of samples the dataset reports.
len(train_dt)

8412

In [5]:
from torch_geometric.loader import DataLoader
# Put the whole dataset into a single batch so that one forward pass yields the
# attention weights for every edge. The batch size follows the dataset instead
# of a hard-coded row count, so a changed CSV cannot silently split the data
# into several batches.
data_loader = DataLoader(train_dt, batch_size=len(train_dt), shuffle=False)

In [None]:
# Inference only: switch to eval mode once and disable autograd so no graph is
# recorded for the (very large) single batch.
model.eval()
with torch.no_grad():
    for batch in tqdm(data_loader):
        input_x = batch.x
        input_edge_index = batch.edge_index
        batch_index = batch.batch
        # `edge_index` / `weight` are the edges and attention weights returned
        # by the model; they are consumed by the cells below. With more than
        # one batch only the last batch's values would remain.
        pred, (edge_index, weight) = model(input_x.float(), input_edge_index, batch_index)

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Collapse dim=1 of the attention weights by taking the maximum.
# NOTE: despite its name this is a max, not a mean; the name is kept because
# later cells refer to it. `.cpu()` makes the conversion work for tensors on any
# device, matching how `edge_index` is converted.
weight_mean = torch.max(weight, dim=1).values.detach().cpu().numpy()

In [None]:
# Convert the returned edge list to NumPy. The guard keeps the cell re-runnable:
# on a second run `edge_index` is already an ndarray and is left as is.
if torch.is_tensor(edge_index):
    edge_index = edge_index.detach().cpu().numpy()

In [None]:
# Show both shapes side by side to check that they line up.
(edge_index.shape, weight_mean.shape)

In [None]:
# Display the per-edge attention values computed above.
weight_mean

In [None]:
# One row per edge: the two rows of `edge_index` become the two columns.
df = pd.DataFrame({
    'index1': edge_index[0],
    'index2': edge_index[1],
})

In [None]:
# Persist the edge table; with the default settings the row index is written too.
df.to_csv(path_or_buf="/home/data/sda/wt/model_data/Neodb_all_edge_index.csv")

In [None]:
# Wrap the attention values in a single-column frame so they can be saved.
# Note that this rebinds `weight_mean` from an array to a DataFrame.
weight_mean = pd.DataFrame(dict(weight_mean=weight_mean))

In [None]:
# Persist the attention table; with the default settings the row index is written too.
weight_mean.to_csv(path_or_buf="/home/data/sda/wt/model_data/Neodb_all_weight_mean.csv")