In [1]:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np 
import os
from tqdm import tqdm
import copy
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class pMHCDataset(Dataset):
    """Graph dataset built from peptide / HLA sequence pairs.

    Every row of the raw CSV (columns ``pep``, ``hla_seq`` and ``type``) becomes
    one graph: the nodes are the residues of the peptide followed by the residues
    of the HLA sequence, every ordered pair of distinct nodes is connected, and
    ``type`` is stored as the graph label. Each graph is saved to
    ``processed_dir`` as ``data_<row index>.pt`` and loaded lazily by ``get``.
    """

    def __init__(self, root, filename, aaindex, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        filename = Name of the raw CSV file expected inside raw_dir.
        aaindex = Table indexed by residue letter; ``aaindex[aa].to_list()``
        must yield the numeric descriptors of residue ``aa``.
        """
        # Both attributes are assigned before the base-class constructor runs
        # because that constructor may trigger process(), which reads them.
        self.filename = filename
        self.aaindex = aaindex
        super(pMHCDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)
        """
        return self.filename

    @property
    def processed_file_names(self):
        """ If these files are found in processed_dir, processing is skipped.

        Side effect: the raw CSV is (re)read and kept as ``self.data``; ``len``
        relies on that attribute.
        """
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()
        return [f'data_{i}.pt' for i in list(self.data.index)]

    def download(self):
        # Nothing to download: the raw CSV is expected to be present already.
        pass

    def process(self):
        """Convert every row of the raw CSV into a ``Data`` object saved on disk."""
        self.data = pd.read_csv(self.raw_paths[0])
        # tqdm shows a progress bar while the rows are converted.
        for index, sample in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            node_feats = self._get_node_features(sample["pep"], sample["hla_seq"], self.aaindex)
            edge_index = self._get_edge_index(sample["pep"], sample["hla_seq"])
            label = self._get_labels(sample["type"])
            # `index=0` is a constant extra attribute; the training and prediction
            # loops in this notebook read it back as `batch.index`.
            data = Data(x=node_feats, edge_index=edge_index, y=label, index=0)
            torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))

    def _get_node_features(self, pep, HLA, aaindex):
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]

        Each row is the aaindex descriptor of one residue followed by a two-slot
        one-hot flag: ``[1, 0]`` for peptide residues, ``[0, 1]`` for HLA residues.
        """
        pep_len = len(pep)  # residues before this position belong to the peptide
        all_node_feats = []
        for position, aa in enumerate(pep + HLA):
            node_feats = list(aaindex[aa].to_list())
            node_feats.extend([1, 0] if position < pep_len else [0, 1])
            all_node_feats.append(node_feats)
        return torch.tensor(np.asarray(all_node_feats))

    def _get_labels(self, label):
        """Wrap a single label value in a one-element tensor."""
        label = np.asarray([label])
        return torch.tensor(label)

    def _get_edge_index(self, pep, hla):
        """Build the edges of a fully connected graph without self-loops.

        Nodes are numbered ``0 .. len(pep) + len(hla) - 1`` (peptide first). The
        result has shape ``[2, n * (n - 1)]``: sources are grouped by node id and,
        within each group, targets are listed in ascending order.
        """
        n_nodes = len(pep) + len(hla)
        node_ids = torch.arange(n_nodes)
        sources = node_ids.repeat_interleave(n_nodes)  # 0,0,...,1,1,...
        targets = node_ids.repeat(n_nodes)             # 0,1,...,0,1,...
        not_self_loop = sources != targets
        return torch.stack([sources[not_self_loop], targets[not_self_loop]])

    def len(self):
        """Number of samples, i.e. number of rows in the raw CSV."""
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may refuse pickled Data objects - confirm
        # against the installed PyTorch / PyG versions.
        data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [2]:
# Residue descriptor table handed to the dataset; pMHCDataset looks residues up
# with `aaindex[aa]`, so the file is expected to have one column per residue letter.
aaindex = pd.read_csv("../data/aaindex1_pca.csv")

# Training graphs; the raw CSV is looked up under the dataset's raw_dir.
train_dt = pMHCDataset(
    root="../data/models/train_data/",
    filename="train_data_iedb_2.csv",
    aaindex=aaindex,
)

In [3]:
##model
from torch_geometric.nn import TransformerConv, TopKPooling, GraphNorm
from torch.nn import Linear, BatchNorm1d, ModuleList, LeakyReLU
from gtrick import random_feature
from gtrick.pyg import VirtualNode
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.nn import MemPooling
from torch.nn import LeakyReLU
class GNN(torch.nn.Module):
    """Graph-level classifier built from TransformerConv blocks and virtual nodes.

    An input block (TransformerConv -> Linear -> LeakyReLU -> GraphNorm) is
    followed by ``model_layers`` blocks of the same shape, each wrapped by a
    ``VirtualNode``. The virtual-node embedding produced by the last block is fed
    through two linear layers to give one logit per graph.

    Args:
        feature_size: Width of the input node features.
        model_params: Mapping with the keys ``model_embedding_size``,
            ``model_dense_neurons``, ``model_heads`` and ``model_layers``.

    Raises:
        ValueError: If ``model_layers`` is smaller than 1; the output block reads
            the virtual-node embedding, which only exists after at least one layer.
    """

    def __init__(self, feature_size, model_params):
        super().__init__()

        embedding_size = model_params["model_embedding_size"]
        dense_neurons = model_params["model_dense_neurons"]
        n_heads = model_params["model_heads"]
        n_layers = model_params["model_layers"]
        if n_layers < 1:
            raise ValueError(
                f"model_layers must be >= 1 (got {n_layers}): forward() reads the "
                "virtual-node embedding produced inside the layer loop"
            )
        self.n_layers = n_layers
        # `top_k_every_n` and `pooling_layers` are not used by forward(); they are
        # kept so that the module's attributes stay unchanged.
        self.top_k_every_n = 1
        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])
        self.vns = ModuleList()
        self.relu = LeakyReLU()

        # Input block. The Linear layer maps the (embedding_size * n_heads)-wide
        # convolution output back to embedding_size.
        self.conv1 = TransformerConv(feature_size,
                                     embedding_size,
                                     heads=n_heads,
                                     beta=True)
        self.transf1 = Linear(embedding_size * n_heads, embedding_size)
        self.bn1 = GraphNorm(embedding_size)

        # Repeated blocks, one VirtualNode per block.
        for _ in range(n_layers):
            self.conv_layers.append(TransformerConv(embedding_size,
                                                    embedding_size,
                                                    heads=n_heads,
                                                    beta=True))
            self.transf_layers.append(Linear(embedding_size * n_heads, embedding_size))
            self.bn_layers.append(GraphNorm(embedding_size))
            self.vns.append(VirtualNode(embedding_size, embedding_size))

        # Output head: embedding -> dense_neurons -> 1 logit.
        self.linear1 = Linear(embedding_size, dense_neurons)
        self.linear2 = Linear(dense_neurons, 1)

    def forward(self, x, edge_index, batch_index, node_index=None):
        """Return one raw logit per graph in the batch.

        Args:
            x: Node feature matrix.
            edge_index: Edge list passed to every convolution.
            batch_index: Graph id of each node, used by GraphNorm and VirtualNode.
            node_index: Accepted for compatibility with existing callers; it is
                not used by the computation.
        """
        # Input block
        x = self.conv1(x, edge_index)
        x = self.relu(self.transf1(x))
        x = self.bn1(x, batch_index)

        vx = None
        for i in range(self.n_layers):
            x, vx = self.vns[i].update_node_emb(x, edge_index, batch_index)
            x = self.conv_layers[i](x, edge_index)
            x = self.relu(self.transf_layers[i](x))
            x = self.bn_layers[i](x, batch_index)
            vx = self.vns[i].update_vn_emb(x, batch_index, vx)

        # Output block: classify from the last virtual-node embedding.
        x = self.relu(self.linear1(vx))
        x = self.linear2(x)
        return x

In [4]:
def calculate_metrics(y_pred, y_true, epoch, type):
    """Print classification metrics for one epoch and return the key ones.

    Args:
        y_pred: Predicted labels. The callers in this notebook pass predictions
            already rounded to 0/1, so the ROC AUC is computed on hard labels.
        y_true: Ground-truth labels.
        epoch: Epoch number, only used in the printed header.
        type: Tag printed in the header (e.g. "train" or "test").

    Returns:
        Tuple ``(precision, recall, roc_auc)``. ``roc_auc`` is ``nan`` when
        ``roc_auc_score`` rejects the input with ``ValueError``.
    """
    print(f"{type}, Epoch: {epoch}: ")
    print(f"\n Confusion matrix: \n {confusion_matrix(y_true, y_pred)}")
    print(f"F1 Score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    # Bind a default first so the return statement is valid even when the score
    # cannot be computed.
    roc = float("nan")
    try:
        roc = roc_auc_score(y_true, y_pred)
        print(f"ROC AUC: {roc}")
    except ValueError:
        print("ROC AUC: notdefined")
    return prec, rec, roc

def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad is false) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn):
    """Run one optimisation pass over ``train_loader``.

    For every batch the tensors are moved to the module-level ``device``, the
    model is evaluated, the loss is back-propagated and the optimizer stepped.
    Predictions (sigmoid of the logits, rounded to 0/1) and labels are collected
    and handed to ``calculate_metrics`` at the end.

    Returns:
        Tuple ``(mean batch loss, precision, recall, roc_auc)``.
    """
    epoch_preds, epoch_labels = [], []
    loss_total = 0.0
    n_batches = 0

    for batch in tqdm(train_loader):
        labels = batch.y.to(device)
        features = batch.x.to(device)
        edges = batch.edge_index.to(device)
        graph_ids = batch.batch.to(device)
        node_ids = batch.index.to(device)

        optimizer.zero_grad()
        logits = model(features.float(), edges, graph_ids, node_ids)
        # Labels are reshaped to a column so they line up with the logits.
        batch_loss = loss_fn(logits, labels.reshape(logits.shape[0], 1).float())
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch summary
        loss_total += batch_loss.item()
        n_batches += 1
        epoch_preds.append(np.rint(torch.sigmoid(logits).cpu().detach().numpy()))
        epoch_labels.append(labels.cpu().detach().numpy())

    flat_preds = np.concatenate(epoch_preds).ravel()
    flat_labels = np.concatenate(epoch_labels).ravel()
    prec, rec, roc = calculate_metrics(flat_preds, flat_labels, epoch, "train")
    return loss_total / n_batches, prec, rec, roc

def test(epoch, model, test_loader, loss_fn):
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for batch in test_loader:
        target = batch.y
        input_x = batch.x
        edge_index = batch.edge_index
        target = target.to(device)
        input_x = input_x.to(device)
        edge_index = edge_index.to(device)
        batch_index = batch.batch.to(device)
        node_index = batch.index.to(device)
        pred = model(input_x.float(), edge_index, batch_index, node_index) 
        loss = loss_fn(pred, target.reshape(pred.shape[0],1).float())
         # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_labels.append(target.cpu().detach().numpy())
        
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    prec, rec, roc = calculate_metrics(all_preds, all_labels, epoch, "test")
    return running_loss/step, prec, rec, roc

# %% Run the training
def run_one_training(params, train_dt, loss_fn, writer):
    """Train a fresh GNN on ``train_dt`` and return it.

    Args:
        params: Hyperparameter mapping. Required keys: ``batch_size``,
            ``learning_rate``, ``weight_decay`` and the ``model_*`` keys consumed
            by ``GNN``. Optional keys (defaults keep the previous hard-coded
            behaviour): ``epochs`` (36), ``num_workers`` (20) and
            ``feature_size`` (22).
        train_dt: Dataset wrapped in a shuffling ``DataLoader``.
        loss_fn: Loss callable forwarded to ``train_one_epoch``.
        writer: Object with ``add_scalar(tag, value, step)``; receives the mean
            training loss of every epoch under ``"Train/Loss"``.

    Returns:
        The trained model (on the module-level ``device``).
    """
    n_epochs = params.get("epochs", 36)
    num_workers = params.get("num_workers", 20)
    feature_size = params.get("feature_size", 22)

    # Prepare training
    train_loader = DataLoader(train_dt,
                              batch_size=params["batch_size"],
                              num_workers=num_workers,
                              pin_memory=True,
                              shuffle=True)

    # Build the model from the "model_"-prefixed hyperparameters only.
    print("Loading model...")
    model_params = {k: v for k, v in params.items() if k.startswith("model_")}
    model = GNN(feature_size=feature_size, model_params=model_params)
    model = model.to(device)
    print(f"Number of parameters: {count_parameters(model)}")

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params["learning_rate"],
                                 weight_decay=params["weight_decay"])

    # Training loop; no held-out evaluation is run here.
    for epoch in range(n_epochs):
        model.train()
        loss, prec, rec, roc = train_one_epoch(epoch, model, train_loader, optimizer, loss_fn)
        print(f"Epoch {epoch} | Train Loss {loss}")
        writer.add_scalar("Train/Loss", loss, epoch)
    return model

In [5]:
import numpy as np
import pandas as pd
import torch
from torch.utils import data
from torch_geometric.loader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm
# Training configuration. run_one_training forwards the keys that start with
# "model_" to the GNN constructor and uses the others for the loader/optimizer.
HYPERPARAMETERS = dict(
    batch_size=8,
    weight_decay=0.00001,
    learning_rate=0.0001,
    model_embedding_size=64,
    model_dense_neurons=32,
    model_heads=3,
    model_layers=3,
)

# TensorBoard event files for this run are written below this directory.
writer = SummaryWriter("/root/tf-logs/gnn_neo_loss")

# pos_weight > 1 up-weights the positive class in the loss.
# NOTE(review): 1.4 is presumably the negative/positive ratio of the training
# set (the logged confusion matrices sum to 4945 vs 3467) - confirm.
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.4))

model_last = run_one_training(
    HYPERPARAMETERS,
    train_dt,
    loss_fn,
    writer,
)

Loading model...
Number of parameters: 273089


100%|██████████| 1052/1052 [00:57<00:00, 18.16it/s]


train, Epoch: 0: 

 Confusion matrix: 
 [[3040 1905]
 [1098 2369]]
F1 Score: 0.6120656245963054
Accuracy: 0.6430099857346647
Precision: 0.5542817033224146
Recall: 0.6832996827228152
ROC AUC: 0.6490310344857757
Epoch 0 | Train Loss 0.731081186308607


100%|██████████| 1052/1052 [00:54<00:00, 19.28it/s]


train, Epoch: 1: 

 Confusion matrix: 
 [[3232 1713]
 [ 751 2716]]
F1 Score: 0.6879432624113475
Accuracy: 0.7070851165002378
Precision: 0.6132309776473245
Recall: 0.7833862128641477
ROC AUC: 0.7184878485958758
Epoch 1 | Train Loss 0.6584465550435813


100%|██████████| 1052/1052 [00:56<00:00, 18.50it/s]


train, Epoch: 2: 

 Confusion matrix: 
 [[3359 1586]
 [ 753 2714]]
F1 Score: 0.6988541264323419
Accuracy: 0.7219448407037565
Precision: 0.6311627906976744
Recall: 0.7828093452552639
ROC AUC: 0.7310406685831425
Epoch 2 | Train Loss 0.6439041010893797


100%|██████████| 1052/1052 [00:56<00:00, 18.53it/s]


train, Epoch: 3: 

 Confusion matrix: 
 [[3318 1627]
 [ 698 2769]]
F1 Score: 0.7043113315528423
Accuracy: 0.7236091298145506
Precision: 0.6298908098271155
Recall: 0.7986732044995674
ROC AUC: 0.7348269965874985
Epoch 3 | Train Loss 0.6343430633950596


100%|██████████| 1052/1052 [00:55<00:00, 18.83it/s]


train, Epoch: 4: 

 Confusion matrix: 
 [[3284 1661]
 [ 652 2815]]
F1 Score: 0.7088002014352259
Accuracy: 0.7250356633380884
Precision: 0.6289097408400357
Recall: 0.8119411595038939
ROC AUC: 0.7380231581139287
Epoch 4 | Train Loss 0.6266339514475359


100%|██████████| 1052/1052 [00:56<00:00, 18.66it/s]


train, Epoch: 5: 

 Confusion matrix: 
 [[3423 1522]
 [ 740 2727]]
F1 Score: 0.7068429237947123
Accuracy: 0.7310984308131241
Precision: 0.6417980701341492
Recall: 0.7865589847130083
ROC AUC: 0.7393866713251593
Epoch 5 | Train Loss 0.620491536709519


100%|██████████| 1052/1052 [00:58<00:00, 18.06it/s]


train, Epoch: 6: 

 Confusion matrix: 
 [[3380 1565]
 [ 690 2777]]
F1 Score: 0.7112306313228326
Accuracy: 0.7319305753685211
Precision: 0.6395670198065407
Recall: 0.8009806749351024
ROC AUC: 0.7422496903492499
Epoch 6 | Train Loss 0.6100686796774429


100%|██████████| 1052/1052 [00:57<00:00, 18.40it/s]


train, Epoch: 7: 

 Confusion matrix: 
 [[3429 1516]
 [ 719 2748]]
F1 Score: 0.7109041521148622
Accuracy: 0.7343081312410842
Precision: 0.6444652908067542
Recall: 0.7926160946062879
ROC AUC: 0.7430218996792815
Epoch 7 | Train Loss 0.613002237399948


100%|██████████| 1052/1052 [00:56<00:00, 18.65it/s]


train, Epoch: 8: 

 Confusion matrix: 
 [[3447 1498]
 [ 703 2764]]
F1 Score: 0.7152283607193687
Accuracy: 0.7383499762244413
Precision: 0.6485218207414359
Recall: 0.797231035477358
ROC AUC: 0.7471493903372634
Epoch 8 | Train Loss 0.6049348601766865


100%|██████████| 1052/1052 [00:50<00:00, 20.79it/s]


train, Epoch: 9: 

 Confusion matrix: 
 [[3507 1438]
 [ 724 2743]]
F1 Score: 0.7173117154811716
Accuracy: 0.7429862101759391
Precision: 0.6560631427888065
Recall: 0.7911739255840785
ROC AUC: 0.7501875694654467
Epoch 9 | Train Loss 0.600368035553526


100%|██████████| 1052/1052 [00:57<00:00, 18.34it/s]


train, Epoch: 10: 

 Confusion matrix: 
 [[3494 1451]
 [ 680 2787]]
F1 Score: 0.7234263465282285
Accuracy: 0.7466714217784118
Precision: 0.6576215195847098
Recall: 0.8038650129795212
ROC AUC: 0.755218654113623
Epoch 10 | Train Loss 0.5946313865454478


100%|██████████| 1052/1052 [00:56<00:00, 18.68it/s]


train, Epoch: 11: 

 Confusion matrix: 
 [[3504 1441]
 [ 731 2736]]
F1 Score: 0.7158555729984301
Accuracy: 0.7417974322396577
Precision: 0.6550155614077089
Recall: 0.7891548889529852
ROC AUC: 0.7488747144461589
Epoch 11 | Train Loss 0.5995374696024697


100%|██████████| 1052/1052 [00:56<00:00, 18.51it/s]


train, Epoch: 12: 

 Confusion matrix: 
 [[3554 1391]
 [ 720 2747]]
F1 Score: 0.7224194608809994
Accuracy: 0.7490489776509748
Precision: 0.6638472692121798
Recall: 0.792327660801846
ROC AUC: 0.7555167120996085
Epoch 12 | Train Loss 0.5952884511236002


100%|██████████| 1052/1052 [00:53<00:00, 19.67it/s]


train, Epoch: 13: 

 Confusion matrix: 
 [[3577 1368]
 [ 741 2726]]
F1 Score: 0.721068641714059
Accuracy: 0.7492867332382311
Precision: 0.6658524670249145
Recall: 0.7862705509085665
ROC AUC: 0.7548137385483176
Epoch 13 | Train Loss 0.5893776307144546


100%|██████████| 1052/1052 [00:47<00:00, 22.37it/s]


train, Epoch: 14: 

 Confusion matrix: 
 [[3625 1320]
 [ 729 2738]]
F1 Score: 0.7277076411960133
Accuracy: 0.7564194008559201
Precision: 0.6747166091670774
Recall: 0.789731756561869
ROC AUC: 0.7613977286348272
Epoch 14 | Train Loss 0.5881547904343206


100%|██████████| 1052/1052 [00:46<00:00, 22.56it/s]


train, Epoch: 15: 

 Confusion matrix: 
 [[3559 1386]
 [ 726 2741]]
F1 Score: 0.7218856992362391
Accuracy: 0.7489300998573466
Precision: 0.664162830142961
Recall: 0.7905970579751946
ROC AUC: 0.7551569718591848
Epoch 15 | Train Loss 0.5775493247710707


100%|██████████| 1052/1052 [00:49<00:00, 21.39it/s]


train, Epoch: 16: 

 Confusion matrix: 
 [[3621 1324]
 [ 713 2754]]
F1 Score: 0.7300198807157058
Accuracy: 0.7578459343794579
Precision: 0.6753310446297205
Recall: 0.7943466974329392
ROC AUC: 0.7633007501320409
Epoch 16 | Train Loss 0.5764632602381615


100%|██████████| 1052/1052 [00:49<00:00, 21.31it/s]


train, Epoch: 17: 

 Confusion matrix: 
 [[3635 1310]
 [ 742 2725]]
F1 Score: 0.7264729405491868
Accuracy: 0.7560627674750356
Precision: 0.6753407682775713
Recall: 0.7859821171041246
ROC AUC: 0.760534031251759
Epoch 17 | Train Loss 0.5828078484778848


100%|██████████| 1052/1052 [00:46<00:00, 22.46it/s]


train, Epoch: 18: 

 Confusion matrix: 
 [[3641 1304]
 [ 722 2745]]
F1 Score: 0.7304417243214477
Accuracy: 0.7591535901093676
Precision: 0.6779451716473204
Recall: 0.7917507931929623
ROC AUC: 0.7640250427036601
Epoch 18 | Train Loss 0.5789848621297246


100%|██████████| 1052/1052 [00:49<00:00, 21.28it/s]


train, Epoch: 19: 

 Confusion matrix: 
 [[3681 1264]
 [ 710 2757]]
F1 Score: 0.7363782051282052
Accuracy: 0.7653352353780314
Precision: 0.6856503357373788
Recall: 0.7952119988462648
ROC AUC: 0.7698001349135268
Epoch 19 | Train Loss 0.5711950442608545


100%|██████████| 1052/1052 [00:53<00:00, 19.64it/s]


train, Epoch: 20: 

 Confusion matrix: 
 [[3628 1317]
 [ 737 2730]]
F1 Score: 0.726643598615917
Accuracy: 0.7558250118877794
Precision: 0.6745737583395107
Recall: 0.787424286126334
ROC AUC: 0.760547330120801
Epoch 20 | Train Loss 0.5814627334192225


100%|██████████| 1052/1052 [00:49<00:00, 21.29it/s]


train, Epoch: 21: 

 Confusion matrix: 
 [[3706 1239]
 [ 739 2728]]
F1 Score: 0.7339252085014797
Accuracy: 0.7648597242035188
Precision: 0.6876733047643055
Recall: 0.7868474185174502
ROC AUC: 0.768145650613629
Epoch 21 | Train Loss 0.5706989862336406


100%|██████████| 1052/1052 [00:49<00:00, 21.28it/s]


train, Epoch: 22: 

 Confusion matrix: 
 [[3692 1253]
 [ 698 2769]]
F1 Score: 0.7394845773801576
Accuracy: 0.7680694246314789
Precision: 0.6884634510193933
Recall: 0.7986732044995674
ROC AUC: 0.7726429723205623
Epoch 22 | Train Loss 0.5705427612273639


100%|██████████| 1052/1052 [00:48<00:00, 21.58it/s]


train, Epoch: 23: 

 Confusion matrix: 
 [[3684 1261]
 [ 729 2738]]
F1 Score: 0.7334583444950442
Accuracy: 0.7634331906799809
Precision: 0.684671167791948
Recall: 0.789731756561869
ROC AUC: 0.7673633504750699
Epoch 23 | Train Loss 0.5689839917047396


100%|██████████| 1052/1052 [00:49<00:00, 21.17it/s]


train, Epoch: 24: 

 Confusion matrix: 
 [[3700 1245]
 [ 720 2747]]
F1 Score: 0.7365598605711221
Accuracy: 0.7664051355206848
Precision: 0.68812625250501
Recall: 0.792327660801846
ROC AUC: 0.7702790983483446
Epoch 24 | Train Loss 0.5648963202513443


100%|██████████| 1052/1052 [00:49<00:00, 21.10it/s]


train, Epoch: 25: 

 Confusion matrix: 
 [[3709 1236]
 [ 719 2748]]
F1 Score: 0.7376191115286539
Accuracy: 0.7675939134569663
Precision: 0.6897590361445783
Recall: 0.7926160946062879
ROC AUC: 0.771333325361789
Epoch 25 | Train Loss 0.5628265371482408


100%|██████████| 1052/1052 [00:51<00:00, 20.51it/s]


train, Epoch: 26: 

 Confusion matrix: 
 [[3656 1289]
 [ 711 2756]]
F1 Score: 0.7337593184238551
Accuracy: 0.7622444127436995
Precision: 0.6813349814585908
Recall: 0.7949235650418229
ROC AUC: 0.7671281121467963
Epoch 26 | Train Loss 0.5631183696620364


100%|██████████| 1052/1052 [00:53<00:00, 19.63it/s]


train, Epoch: 27: 

 Confusion matrix: 
 [[3744 1201]
 [ 711 2756]]
F1 Score: 0.7424568965517242
Accuracy: 0.7727056585829767
Precision: 0.696487237806419
Recall: 0.7949235650418229
ROC AUC: 0.7760259887898701
Epoch 27 | Train Loss 0.5607368143934034


100%|██████████| 1052/1052 [00:52<00:00, 19.88it/s]


train, Epoch: 28: 

 Confusion matrix: 
 [[3728 1217]
 [ 672 2795]]
F1 Score: 0.7474261264874983
Accuracy: 0.7754398478364242
Precision: 0.6966600199401795
Recall: 0.8061724834150562
ROC AUC: 0.7800326522232005
Epoch 28 | Train Loss 0.5514995218343155


100%|██████████| 1052/1052 [00:56<00:00, 18.58it/s]


train, Epoch: 29: 

 Confusion matrix: 
 [[3743 1202]
 [ 738 2729]]
F1 Score: 0.7377669640443363
Accuracy: 0.7693770803613885
Precision: 0.6942253879419995
Recall: 0.7871358523218921
ROC AUC: 0.7720310201953242
Epoch 29 | Train Loss 0.5548050052737782


100%|██████████| 1052/1052 [00:56<00:00, 18.61it/s]


train, Epoch: 30: 

 Confusion matrix: 
 [[3739 1206]
 [ 688 2779]]
F1 Score: 0.7458400429414922
Accuracy: 0.7748454588682834
Precision: 0.6973651191969887
Recall: 0.8015575425439861
ROC AUC: 0.7788374163680497
Epoch 30 | Train Loss 0.5547089708735281


100%|██████████| 1052/1052 [00:56<00:00, 18.50it/s]


train, Epoch: 31: 

 Confusion matrix: 
 [[3712 1233]
 [ 691 2776]]
F1 Score: 0.7426431246655965
Accuracy: 0.7712791250594389
Precision: 0.6924420054876528
Recall: 0.8006922411306605
ROC AUC: 0.7756747353277165
Epoch 31 | Train Loss 0.5478256116886783


100%|██████████| 1052/1052 [00:51<00:00, 20.48it/s]


train, Epoch: 32: 

 Confusion matrix: 
 [[3742 1203]
 [ 741 2726]]
F1 Score: 0.7371552190373173
Accuracy: 0.7689015691868759
Precision: 0.693815220157801
Recall: 0.7862705509085665
ROC AUC: 0.7714972572540809
Epoch 32 | Train Loss 0.5535343329036191


100%|██████████| 1052/1052 [00:56<00:00, 18.67it/s]


train, Epoch: 33: 

 Confusion matrix: 
 [[3739 1206]
 [ 728 2739]]
F1 Score: 0.7390717754991905
Accuracy: 0.7700903471231574
Precision: 0.694296577946768
Recall: 0.7900201903663109
ROC AUC: 0.7730687402792121
Epoch 33 | Train Loss 0.559595361989374


100%|██████████| 1052/1052 [00:56<00:00, 18.71it/s]


train, Epoch: 34: 

 Confusion matrix: 
 [[3735 1210]
 [ 679 2788]]
F1 Score: 0.74695244474213
Accuracy: 0.7754398478364242
Precision: 0.6973486743371686
Recall: 0.804153446783963
ROC AUC: 0.7797309195497166
Epoch 34 | Train Loss 0.552113721791669


100%|██████████| 1052/1052 [00:56<00:00, 18.67it/s]

train, Epoch: 35: 

 Confusion matrix: 
 [[3747 1198]
 [ 675 2792]]
F1 Score: 0.7488266058736757
Accuracy: 0.7773418925344746
Precision: 0.699749373433584
Recall: 0.8053071820017306
ROC AUC: 0.7815211339735649
Epoch 35 | Train Loss 0.5449835354242715





In [7]:
# Persist only the learned weights (state dict) of the trained model.
trained_weights = model_last.state_dict()
torch.save(trained_weights, "../data/models/gnn_full_model_ext.pt")

In [14]:
### Prediction: build the graph dataset that will be scored
# NOTE(review): `dt` is not referenced by the statements that follow, and it
# reads a different file than the one handed to pMHCDataset - confirm which
# input is intended.
dt = pd.read_csv("../data/tesla_val.csv")
val_dt = pMHCDataset(
    root="../data/models/val_data/",
    filename="db_pre_dt.csv",
    aaindex=aaindex,
)

Processing...
100%|██████████| 1438/1438 [00:05<00:00, 260.94it/s]
Done!


In [15]:
# Score every graph in val_dt with the trained model and write the predicted
# probabilities next to the stored labels.
data_loader = DataLoader(val_dt, batch_size=64, shuffle=False)
all_preds = []
all_labels = []

# Switch to inference mode once, before the loop, and disable gradient tracking
# so no autograd graph is built while predicting.
model_last.eval()
with torch.no_grad():
    for batch in tqdm(data_loader):
        target = batch.y.to(device)
        input_x = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        batch_index = batch.batch.to(device)
        node_index = batch.index.to(device)
        pred = model_last(input_x.float(), edge_index, batch_index, node_index)
        # Sigmoid converts the raw logits to probabilities (not rounded here).
        all_preds.append(torch.sigmoid(pred).cpu().numpy())
        all_labels.append(target.cpu().numpy())

all_preds = np.concatenate(all_preds).ravel()
all_labels = np.concatenate(all_labels).ravel()
pred = pd.DataFrame({"all_preds":all_preds,"all_labels":all_labels})
pred.to_csv("../data/db_pred_gnn.csv")

100%|██████████| 23/23 [00:01<00:00, 17.01it/s]


In [16]:
# Pickle the whole model object (not just its state dict); loading this file
# later requires the GNN class definition to be importable.
torch.save(
    model_last,
    "../data/models/last_model.pt",
)