In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc, accuracy_score, roc_curve, recall_score

import matplotlib.pyplot as plt
import warnings
from util import normalization
import random
import copy

warnings.filterwarnings('ignore')

In [None]:
# Feature loading: every .npy matrix is normalized on its own right after reading.
def _read_normalized(npy_path):
    """Load the array stored at `npy_path` and pass it through `normalization`."""
    return normalization(np.load(npy_path))


def _renormalized_tensor(features):
    """Run `normalization` on a copy of `features` and return it as a FloatTensor."""
    return torch.from_numpy(normalization(np.copy(features))).float()


# Training-split features
X_train_gd = _read_normalized('./data/feature/X_train_gd_nl_new.npy')
X_train_gmd = _read_normalized('./data/feature/X_train_gmd_nl.npy')
X_train_gld = _read_normalized('./data/feature/X_train_gld_nl.npy')

# Test-split features
X_test_gd = _read_normalized('./data/feature/X_test_gd_nl_new.npy')
X_test_gmd = _read_normalized('./data/feature/X_test_gmd_nl.npy')
X_test_gld = _read_normalized('./data/feature/X_test_gld_nl.npy')

# Stack train rows on top of test rows for each feature family
X_gd = np.concatenate([X_train_gd, X_test_gd], axis=0)
X_gmd = np.concatenate([X_train_gmd, X_test_gmd], axis=0)
X_gld = np.concatenate([X_train_gld, X_test_gld], axis=0)

# Labels, one file per split
y_train = np.loadtxt('./data/y_train.txt')
y_test = np.loadtxt('./data/y_test.txt')

# The stacked matrices go through `normalization` a second time before
# being converted to tensors (the per-file pass above is not the only one).
X_gd_all = _renormalized_tensor(X_gd)
X_gmd_all = _renormalized_tensor(X_gmd)
X_gld_all = _renormalized_tensor(X_gld)

# Labels become a single column so each row carries exactly one target value
y_all = torch.from_numpy(
    np.concatenate([y_train, y_test], axis=0).reshape(-1, 1)
).float()

In [None]:
class GraphTransformer(nn.Module):
    """One transformer-encoder-style block followed by a binary classification head.

    The three feature matrices passed to ``forward`` are concatenated along the
    feature axis, sent through multi-head self-attention and a feed-forward
    sub-layer (each wrapped in a residual connection plus LayerNorm), and then
    projected to a single probability per row.

    Args:
        input_dim: Width of the concatenated feature vector, i.e. the sum of
            the feature widths of the three ``forward`` inputs. It is used as
            the attention ``embed_dim`` and therefore has to be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        ff_hidden_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability for the attention weights and the
            feed-forward sub-layer.
    """

    def __init__(self, input_dim, n_heads=8, ff_hidden_dim=128, dropout=0.2):
        super(GraphTransformer, self).__init__()
        self.n_heads = n_heads
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=n_heads, dropout=dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        # The second norm sees the feed-forward *output*, which fc2 maps back
        # to input_dim, so it must be sized input_dim (not ff_hidden_dim).
        self.layer_norm2 = nn.LayerNorm(input_dim)

        self.fc1 = nn.Linear(input_dim, ff_hidden_dim)
        self.fc2 = nn.Linear(ff_hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        # Output head: one logit per row so the result lines up with (n, 1)
        # targets as required by nn.BCELoss.
        self.classifier = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, x_gmd, x_gld):
        """Return a ``(batch, 1)`` tensor of probabilities.

        Args:
            x, x_gmd, x_gld: 2-D tensors with the same number of rows; their
                feature widths must add up to ``input_dim``.
        """
        # Combine input features
        x_combined = torch.cat((x, x_gmd, x_gld), dim=1)
        # The attention layer is built with the default batch_first=False, so
        # this (batch, 1, dim) tensor is read as (seq_len=batch, batch=1, dim):
        # the rows of one mini-batch attend to each other.
        # NOTE(review): if every row should be processed independently, build
        # the attention layer with batch_first=True instead - confirm intent.
        x_combined = x_combined.unsqueeze(1)

        # Attention mechanism
        attn_output, _ = self.attention(x_combined, x_combined, x_combined)
        x = self.layer_norm1(x_combined + attn_output)

        # Feed-forward network (torch.relu avoids depending on an `F` alias
        # that the notebook never imports)
        ff_output = self.fc2(self.dropout(torch.relu(self.fc1(x))))
        x = self.layer_norm2(x + ff_output)

        x = x.squeeze(1)  # Drop the axis added by unsqueeze(1)
        return self.sigmoid(self.classifier(x))

    def predict(self, x, x_gmd, x_gld):
        """Return hard 0./1. labels by thresholding the probabilities at 0.5."""
        return (self.predict_probability(x, x_gmd, x_gld) > 0.5).float()

    def predict_probability(self, x, x_gmd, x_gld):
        """Return probabilities computed without gradients and without dropout.

        The module is switched to eval mode for the forward pass so that
        dropout is inactive, and its previous train/eval mode is restored
        afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self(x, x_gmd, x_gld)
        finally:
            self.train(was_training)

def train_cv(model, X, X_gmd, X_gld, y, opt, criterion, batch_size=64):
    """Train `model` for one pass over the data in consecutive mini-batches.

    Rows are consumed in their stored order (no shuffling) in slices of
    `batch_size`; the final slice may be shorter.

    Args:
        model: Module called as ``model(x_batch, x_gmd_batch, x_gld_batch)``.
        X, X_gmd, X_gld: Feature tensors sliced along their first axis.
        y: Target tensor sliced with the same row ranges as the features.
        opt: Optimizer whose ``zero_grad``/``step`` are called once per batch.
        criterion: Callable ``criterion(y_hat, y_batch)`` returning a scalar loss.
        batch_size: Number of rows per mini-batch.

    Returns:
        ``(losses, mean_loss)``: the per-batch loss values as Python floats
        and their unweighted mean. When no batch is produced (``X`` has zero
        rows) the result is ``([], nan)`` instead of a ZeroDivisionError.
    """
    model.train()
    losses = []
    for start in range(0, X.size(0), batch_size):
        rows = slice(start, start + batch_size)

        opt.zero_grad()
        y_hat = model(X[rows], X_gmd[rows], X_gld[rows])
        loss = criterion(y_hat, y[rows])
        loss.backward()
        opt.step()

        losses.append(loss.item())

    if not losses:
        # Nothing was trained on; there is no meaningful average to report.
        return losses, float('nan')
    return losses, sum(losses) / len(losses)

In [None]:
# 5-fold stratified cross-validation: a fresh model is trained per fold and
# scored on that fold's held-out rows.
cv = StratifiedKFold(n_splits=5)
num_epochs = 30
fold_loss = []
fpr_list, tpr_list, aucs, recal = [], [], [], []

# The model consumes the three feature matrices concatenated column-wise, so
# its input width is the sum of their widths (they need not be equal).
input_dim = X_gd_all.size(1) + X_gmd_all.size(1) + X_gld_all.size(1)

# scikit-learn expects plain 1-D label arrays rather than (n, 1) tensors.
y_labels = y_all.numpy().ravel()

for train_idx, test_idx in cv.split(X_gd_all.numpy(), y_labels):
    model = GraphTransformer(input_dim)
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
    criterion = nn.BCELoss()

    losses = []  # stays defined even when num_epochs is 0
    for epoch in range(num_epochs):
        losses, avg_loss = train_cv(model, X_gd_all[train_idx], X_gmd_all[train_idx], X_gld_all[train_idx], y_all[train_idx], optimizer, criterion)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")

    # Only the per-batch losses of the last epoch are kept for each fold.
    fold_loss.append(losses)

    # Score with dropout disabled so the held-out predictions are deterministic.
    model.eval()
    y_pred_prob = model.predict_probability(X_gd_all[test_idx], X_gmd_all[test_idx], X_gld_all[test_idx])
    y_pred = (y_pred_prob > 0.5).float()

    y_true = y_labels[test_idx]
    y_score = y_pred_prob.numpy().ravel()

    fpr, tpr, _ = roc_curve(y_true, y_score)
    fpr_list.append(fpr)
    tpr_list.append(tpr)
    # AUC must be computed from the continuous scores - the same values the
    # ROC curve is drawn from - not from the thresholded 0/1 predictions.
    aucs.append(roc_auc_score(y_true, y_score))
    recal.append(recall_score(y_true, y_pred.numpy().ravel()))

In [None]:
# One ROC curve per cross-validation fold, drawn on the current axes
ax = plt.gca()
for idx, fold_auc in enumerate(aucs):
    ax.plot(fpr_list[idx], tpr_list[idx], label=f"Fold-{idx + 1} (AUC={fold_auc:.3f})")

# Dashed diagonal from (0, 0) to (1, 1) as a visual reference line
ax.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')

ax.set_xlabel('FPR')
ax.set_ylabel('TPR')
ax.set_title('ROC Curves for Graph Transformer')
ax.legend()

# Write the figure to disk before showing it
plt.savefig('./data/nl_ROC.svg', dpi=300)
plt.show()