In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold
from scipy.stats import rankdata
import torch
import dgl
import time
from tqdm import tqdm
import random
import gc
import dgl.function as fn
from dgl.nn import GATConv

In [2]:
class GraphDataset(torch.utils.data.Dataset):
    def __init__(self, dir_path, indexes=None, add_self_loop=False):
        super(GraphDataset, self).__init__()
        self.dir_path = dir_path
        self.graphs, label_dict = dgl.load_graphs(self.dir_path+'/dgl_graph.bin')
        self.df = pd.read_csv(self.dir_path+'/overview_df.csv', index_col=0)
        self.add_self_loop = add_self_loop
        if indexes is None:
            self.indexes = self.df.index
        else:
            self.indexes = indexes

    def __getitem__(self, i):
        idx = self.indexes[i]
        
        row = self.df.loc[idx]
        
        graph_index = row.graph_index
        graph = self.graphs[graph_index].clone()
        if self.add_self_loop:
            graph = dgl.add_self_loop(graph)
        
        seq_feature = np.load(self.dir_path+'/'+row.seq_feature_path)
        seq = seq_feature['seq']
        seq = torch.tensor(seq) 
        seq = torch.cat((seq[:, :1280], seq[:, -16:]), dim=1)
        
        surface_aa_feature = np.load(self.dir_path+'/'+row.surface_aa_feature_path)
        surface_aa_seq = surface_aa_feature['surface_aa_seq']
        surface_aa_seq = torch.tensor(surface_aa_seq) 
        surface_aa_seq = torch.cat((surface_aa_seq[:, :1280], surface_aa_seq[:, -16:]), dim=1)
        
        surface_pos = surface_aa_feature['surface_pos']
        
        graph.ndata['seq'] = seq
        graph.ndata['surface_aa_seq'] = surface_aa_seq
        graph.ndata['surface_pos'] = torch.from_numpy(surface_pos)
        
        label = row.get('pHmin', np.nan)
        label_valid = True
             
        return graph, label, label_valid, idx

    def __len__(self):
        return len(self.indexes)

In [3]:
import torch
import torch.nn as nn
import dgl
from dgl.nn.pytorch import GraphConv

class GCN(nn.Module):
    def __init__(self, hidden_dim, layer_num=3):
        super(GCN, self).__init__()
        self.convs = nn.ModuleList()
        self.activations = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for i in range(layer_num):
            in_feats = hidden_dim if i == 0 else hidden_dim
            out_feats = hidden_dim
            self.convs.append(GraphConv(in_feats, out_feats))
            self.activations.append(nn.LeakyReLU())
            self.batch_norms.append(nn.BatchNorm1d(out_feats))

        self.layer_num = layer_num
        self.out_dim = hidden_dim * layer_num
        
    def forward(self, g, h):
        hs = [h]
        for conv, batch_norm, act in zip(self.convs, self.batch_norms, self.activations):
            h = conv(g, h)
            h = batch_norm(h)
            h = act(h)
            hs.append(h)
        return torch.cat(hs, dim=-1)
class GNNModel(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=256, dropout_rate=0.5):
        super(GNNModel, self).__init__()
        self.comp = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.LeakyReLU()
        )
        self.gcn = GCN(hidden_dim)
        self.head = torch.nn.Sequential(
            torch.nn.Dropout(0.5),
            torch.nn.Linear(2048, self.gcn.out_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(self.gcn.out_dim, 1),
        )
    
    def forward(self, g, wildtype_seq, surface_aa_seq, surface_pos):
        wildtype_h = self.comp(wildtype_seq)
        surface_aa_h = self.comp(surface_aa_seq)
        wildtype_h = self.gcn(g, wildtype_h)
        surface_aa_h = self.gcn(g, surface_aa_h)
        with g.local_scope():
            g.ndata['h'] = wildtype_h
            wildtype_hg = dgl.readout_nodes(g, 'h', op='sum') 

        with g.local_scope():
            g.ndata['h'] = surface_aa_h
            surface_aa_hp = dgl.readout_nodes(g, 'h', op='sum')
        h_all = torch.cat([wildtype_hg, surface_aa_hp], dim=-1)
        #print(h_all.shape)
        pred = self.head(h_all).squeeze()
        return pred

In [5]:
model = GNNModel(1296, 256).cuda()
model.load_state_dict(torch.load('ACENet.pth'))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
test_dataset = GraphDataset("1fhe")
test_dataloader = dgl.dataloading.GraphDataLoader(test_dataset, batch_size=16, shuffle=False, drop_last=False, num_workers=26)
test_labels = []
test_predictions = []
for graph, label, label_valid, original_index in tqdm(test_dataloader, leave=True):
    graph = graph.to(device)
    label = torch.tensor(label).to(device).unsqueeze(1)
    label = label.squeeze()
    seq, surface_aa_seq, surface_pos = graph.ndata['seq'], graph.ndata['surface_aa_seq'], graph.ndata['surface_pos']
    with torch.no_grad():   
        pred = model(graph, seq, surface_aa_seq, surface_pos)

    # Collect predictions
    # test_labels.extend(label.cpu().numpy())
    # test_predictions.extend(pred.cpu().numpy())
    print(pred)

  model.load_state_dict(torch.load('ACENet.pth'))
  label = torch.tensor(label).to(device).unsqueeze(1)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.76s/it]

tensor(6.2696, device='cuda:0')


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.04s/it]


In [6]:
import pandas as pd
df = pd.read_csv('1fhe.csv')  # 替换为你的文件路径
df['pHmin'] = test_predictions

# 打印结果以验证
print(df)
df.to_csv('1fhe.csv', index=False)

ValueError: Length of values (0) does not match length of index (1)