In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdrug import data

from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error
from math import sqrt
import numpy as np

from GearNet import GearNet
from pepGraph_model import BiLSTM, MixBiLSTM_GearNet
import matplotlib.pyplot as plt

In [27]:
def test_model(model, test_loader, device):
    y_pred = []
    y_true = []
    range_list = []
    chain_list = []
    model.eval()
    for i, graph_batch in enumerate(test_loader):
        graph_batch = graph_batch.to(device)
        targets = graph_batch.y
        outputs = model(graph_batch, graph_batch.residue_feature.float())
        #outputs = model(graph_batch)
        #outputs = model(graph_batch.seq_embedding)
        #range_list.extend(graph_batch.range.cpu().detach().numpy())
        #chain_list.extend(graph_batch.chain)

        y_pred.append(outputs.cpu().detach().numpy())
        y_true.append(targets.cpu().detach().numpy())
    y_pred = np.concatenate(y_pred, axis=0)
    y_true = np.concatenate(y_true, axis=0)
    return y_true, y_pred, range_list, chain_list

def plot_results(y_true, y_pred):
    """Scatter predictions against targets and report correlation metrics.

    Draws ``y_pred`` versus ``y_true`` together with a dashed red line from
    (0, 0) to (1, 1), and shows the Pearson and Spearman correlation
    coefficients as the legend title.

    Args:
        y_true: 1-D sequence of target values.
        y_pred: 1-D sequence of predicted values, same length as ``y_true``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Index [0] keeps the correlation statistic and drops the p-value.
    pcc = pearsonr(y_true, y_pred)[0]
    spR = spearmanr(y_true, y_pred)[0]

    fig, ax = plt.subplots()
    # Label each artist for what it is. Passing the metric strings as legend
    # labels would attach "PCC" to the scatter and "SPR" to the red line.
    ax.scatter(y_true, y_pred, s=1, label="Predictions")
    ax.plot([0, 1], [0, 1], color="red", linestyle="--", label="y = x")
    ax.legend(loc="upper left", title="PCC: %.3f\nSPR: %.3f" % (pcc, spR))
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    plt.show()

In [4]:
# Run-level settings; `training_args` below reuses several of these values.
config = dict(
    num_epochs=150,
    batch_size=16,
    learning_rate=0.001,
    weight_decay=5e-4,
    GNN_type='GAT',
    num_GNN_layers=3,
    cross_validation_num=1,
    num_workers=4,
)

# Model-construction arguments. GNN depth/type and batch size are copied from
# `config` so the two dictionaries cannot drift apart.
training_args = dict(
    num_hidden_channels=10,
    num_out_channels=20,
    feat_in_dim=56,
    topo_in_dim=42,
    num_heads=8,
    GNN_hidden_dim=32,
    GNN_out_dim=64,
    LSTM_out_dim=64,
    final_hidden_dim=16,
    drop_out=0.5,
    num_GNN_layers=config['num_GNN_layers'],
    GNN_type=config['GNN_type'],
    graph_hop='hop1',
    batch_size=config['batch_size'],
    result_dir='/home/lwang/models/HDX_LSTM/results/240601_finalExp',
    data_log=True,
)

In [31]:
# Select the evaluation device; falls back to CPU when CUDA is unavailable.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cluster = 'cluster1'
model_name = 'GearNet'

# Build the checkpoint path from the config cell instead of repeating the
# result directory and epoch count here.
model_fpath = f"{training_args['result_dir']}/model_{model_name}_epoch{config['num_epochs']}_{cluster}.pth"

# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# files from a trusted source.
# map_location='cpu' lets the datasets load on machines without a GPU; batches
# are moved to `device` later, inside test_model.
data_dir = '/home/lwang/models/HDX_LSTM/data/Latest_set/graph_ensemble_GearNetEdge'
test_apo = torch.load(f'{data_dir}/{cluster}/test_apo.pt', map_location='cpu')
test_complex = torch.load(f'{data_dir}/{cluster}/test_complex.pt', map_location='cpu')

# Node input width: the two feature widths from training_args (56 + 42).
node_feat_dim = training_args['feat_in_dim'] + training_args['topo_in_dim']

# Exactly one constructor below should be active and must match `model_name`,
# otherwise load_state_dict will fail on mismatched keys.

#GearNet
model = GearNet(input_dim=node_feat_dim, hidden_dims=[512, 512, 512],
                num_relation=7, batch_norm=True, concat_hidden=True,
                readout='sum', activation='relu', short_cut=True)

#GearNet-Edge
#model = GearNet(input_dim=node_feat_dim, hidden_dims=[512, 512, 512],
#                num_relation=7, edge_input_dim=59, num_angle_bin=8,
#                batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", activation = 'relu').to(device)

#BiLSTM
#model = BiLSTM(training_args).to(device)

#MixBiLSTM_GearNet or GearNetEdge
#model = MixBiLSTM_GearNet(training_args).to(device)

#MixBiLSTM_GVP
# TODO: no constructor is recorded here for this variant.

# Load the trained weights (remapped onto `device`) and move the model there.
model_state_dict = torch.load(model_fpath, map_location=device)
model.load_state_dict(model_state_dict)
model = model.to(device)

In [32]:
# Take the batch size from the config cell rather than repeating the literal.
eval_batch_size = config['batch_size']

# shuffle=False keeps a fixed sample order for evaluation.
apo_dataloader = data.DataLoader(test_apo, batch_size=eval_batch_size, shuffle=False)
complex_dataloader = data.DataLoader(test_complex, batch_size=eval_batch_size, shuffle=False)

# Pack the apo and complex test graphs together so metrics can be computed
# over both subsets at once.
total_set = data.Protein.pack(list(test_apo) + list(test_complex))
total_dataloader = data.DataLoader(total_set, batch_size=eval_batch_size, shuffle=False)

In [35]:

# Evaluate on the combined apo + complex loader and report correlations.
y_true, y_pred, range_list, chain_list = test_model(model, total_dataloader, device)
print(*(arr.shape for arr in (y_true, y_pred)))
# Index [0] keeps the correlation statistic and drops the p-value.
pcc, spR = (corr(y_true, y_pred)[0] for corr in (pearsonr, spearmanr))
print(f'PCC: {pcc}, SPR: {spR}')

(1420,) (1420,)
PCC: 0.8050080474128625, SPR: 0.804118834347377
