In [None]:
import os
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold

# Compute device used by this notebook: CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# **Lazy Loading: Load data on demand**
class GraphDataset(Dataset):
    """Lazily loaded set of serialized graphs that all share one label.

    Only the file paths are kept in memory; a graph is deserialized with
    ``torch.load`` when it is requested through :meth:`get`.

    Args:
        root: Directory whose regular files are the serialized samples.
        label: Class label assigned to every sample from this directory.
        transform: Forwarded unchanged to the ``Dataset`` base class.
        pre_transform: Forwarded unchanged to the ``Dataset`` base class.
    """

    def __init__(self, root, label, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and any seeded split built on top of it)
        # reproducible. Sub-directories are skipped because they cannot be
        # loaded as samples.
        candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
        self.file_paths = [path for path in candidates if os.path.isfile(path)]
        self.label = label

    def len(self):
        """Return the number of sample files found in ``root``."""
        return len(self.file_paths)

    def get(self, idx):
        """Load sample ``idx`` from disk and attach the dataset label as ``y``.

        Returns:
            The object stored in the file, with ``y`` set to a one-element
            ``float32`` tensor holding ``self.label``.
        """
        file_path = self.file_paths[idx]
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only point this dataset at files from a trusted source.
        data = torch.load(file_path, weights_only=False)
        data.y = torch.tensor([self.label], dtype=torch.float32)
        return data

# **Data loading**
def load_data(data_dir):
    """Create one ``GraphDataset`` per split directory present under *data_dir*.

    The four expected sub-directories are ``neg-test``, ``neg-train``,
    ``pos-test`` and ``pos-train``. Splits whose name contains ``neg`` get
    label 0, the others label 1. A missing directory is reported with a
    printed warning and left out of the result.

    Returns:
        dict mapping split name to its dataset (only for existing splits).
    """
    split_names = ('neg-test', 'neg-train', 'pos-test', 'pos-train')
    found = {}
    for name in split_names:
        split_dir = os.path.join(data_dir, name)
        if not os.path.exists(split_dir):
            print(f"Warning: {split_dir} does not exist.")
            continue
        class_label = 0 if 'neg' in name else 1
        found[name] = GraphDataset(root=split_dir, label=class_label)
    return found

# **Get k-fold data loaders**
def get_kfold_data(data_dir, n_splits=5, batch_size=2):
    """Yield stratified k-fold ``(train_loader, val_loader)`` pairs.

    The training pool is ``neg-train`` followed by ``pos-train``. Folds are
    produced by ``StratifiedKFold`` with shuffling and a fixed seed of 42.

    Args:
        data_dir: Directory containing the four split sub-directories.
        n_splits: Number of folds.
        batch_size: Batch size used for both loaders of every fold.

    Yields:
        ``(train_loader, val_loader)`` for each fold.

    Returns:
        The concatenated test dataset. Because this is a generator, the value
        is only reachable through ``StopIteration.value`` (or ``yield from``);
        a plain ``for`` loop over the generator discards it.

    Raises:
        KeyError: If one of the four split directories was not found. As with
            any generator, this is raised on the first ``next()`` call.
    """
    datasets = load_data(data_dir)

    required = ('neg-train', 'pos-train', 'neg-test', 'pos-test')
    missing = [name for name in required if name not in datasets]
    if missing:
        raise KeyError(f"Missing dataset split(s) under {data_dir!r}: {missing}")

    neg_train, pos_train = datasets['neg-train'], datasets['pos-train']

    # Training dataset
    train_dataset = neg_train + pos_train
    # Testing dataset
    test_dataset = datasets['neg-test'] + datasets['pos-test']

    # Every sample of a split carries that split's label, so the label vector
    # follows from the split sizes. Iterating the dataset instead would read
    # every graph file from disk just to recover a constant.
    labels = [neg_train.label] * len(neg_train) + [pos_train.label] * len(pos_train)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Generator yields training and validation sets for each fold
    for train_idx, val_idx in skf.split(range(len(train_dataset)), labels):
        # Plain Python ints keep the Subset indexing independent of NumPy types.
        train_subset = torch.utils.data.Subset(train_dataset, train_idx.tolist())
        val_subset = torch.utils.data.Subset(train_dataset, val_idx.tolist())

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=0)

        yield train_loader, val_loader

    # Finally return the full test set (see "Returns" above for how to read it)
    return test_dataset


In [None]:
# import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_mean_pool, global_max_pool, GATConv
from torch_scatter import scatter_mean


##############  GNN Layer ##############
class GNNLayer(nn.Module):
    """One message-passing layer: attention, feed-forward, edge update, context.

    Args:
        num_hidden: Width of node and edge features. Must be divisible by
            ``num_heads`` so the attention output can be added back to the
            node features.
        dropout: Dropout probability used for the residual branches and
            passed to the attention layer and the edge MLP.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``num_hidden`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_hidden, dropout=0.2, num_heads=8):
        super().__init__()
        # Fail at construction time: with a truncated per-head width the
        # residual add in forward() would hit a shape mismatch much later.
        if num_hidden % num_heads != 0:
            raise ValueError(
                f"num_hidden ({num_hidden}) must be divisible by num_heads ({num_heads})"
            )
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([nn.LayerNorm(num_hidden) for _ in range(2)])  # LayerNorm
        self.attention = TransformerConv(
            in_channels=num_hidden,
            out_channels=num_hidden // num_heads,  # per-head width
            heads=num_heads,
            dropout=dropout,
            edge_dim=num_hidden,
            root_weight=False
        )
        self.PositionWiseFeedForward = nn.Sequential(
            nn.Linear(num_hidden, num_hidden * 4),
            nn.ReLU(),
            nn.Linear(num_hidden * 4, num_hidden)
        )
        self.edge_update = EdgeMLP(num_hidden, dropout)
        self.context = Context(num_hidden)

    def forward(self, h_V, edge_index, h_E, batch_id):
        """Update node and edge features.

        Args:
            h_V: Node features.
            edge_index: Edge index; row 0 / row 1 are used as edge endpoints.
            h_E: Edge features.
            batch_id: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(h_V, h_E)`` with the updated node and edge features.
        """
        # Attention block with residual connection and LayerNorm.
        dh = self.attention(h_V, edge_index, h_E)
        h_V = self.norm[0](h_V + self.dropout(dh))
        # Position-wise feed-forward block with residual connection and LayerNorm.
        dh = self.PositionWiseFeedForward(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))
        # Edges are refreshed from the updated node features.
        h_E = self.edge_update(h_V, edge_index, h_E)
        # The context output replaces h_V (no residual connection here).
        h_V = self.context(h_V, edge_index)
        return h_V, h_E


class EdgeMLP(nn.Module):
    """Residual edge update computed from each edge and its two endpoint nodes.

    Args:
        num_hidden: Width of node and edge features.
        dropout: Dropout probability applied to the computed edge message.
    """

    def __init__(self, num_hidden, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.BatchNorm1d(num_features=num_hidden)
        # Input is [first endpoint | edge | second endpoint], hence 3 * num_hidden.
        self.W11 = nn.Linear(in_features=3 * num_hidden, out_features=num_hidden, bias=True)
        self.W12 = nn.Linear(in_features=num_hidden, out_features=num_hidden, bias=True)
        self.act = nn.GELU()

    def forward(self, h_V, edge_index, h_E):
        """Return edge features after a residual MLP update and batch norm.

        Args:
            h_V: Node features, indexed by the entries of ``edge_index``.
            edge_index: Row 0 and row 1 hold the two endpoints of each edge.
            h_E: Edge features, one row per edge.
        """
        first_nodes, second_nodes = edge_index[0], edge_index[1]
        edge_inputs = torch.cat((h_V[first_nodes], h_E, h_V[second_nodes]), dim=-1)
        hidden = self.act(self.W11(edge_inputs))
        message = self.dropout(self.W12(hidden))
        return self.norm(h_E + message)


##############  Context (with GAT)  ##############
class Context(nn.Module):
    """Single-head ``GATConv`` that maps node features to same-width features.

    Args:
        num_hidden: Input and output width of the GAT layer.
    """

    def __init__(self, num_hidden):
        super().__init__()
        self.gat_conv = GATConv(num_hidden, num_hidden, heads=1, dropout=0.2)

    def forward(self, h_V, edge_index):
        """Return the GAT layer output for nodes ``h_V`` over ``edge_index``."""
        return self.gat_conv(h_V, edge_index)


##############  Graph Encoder ##############
class Graph_encoder(nn.Module):
    """Embed raw node/edge features and refine them with stacked ``GNNLayer``s.

    Args:
        node_in_dim: Width of the raw node features.
        edge_in_dim: Width of the raw edge features.
        hidden_dim: Width used inside the encoder for nodes and edges.
        num_layers: Number of ``GNNLayer`` blocks.
        drop_rate: Dropout probability passed to every ``GNNLayer``.
    """

    def __init__(self, node_in_dim, edge_in_dim, hidden_dim, num_layers=4, drop_rate=0.2):
        super().__init__()

        # Input projections.
        self.node_embedding = nn.Linear(node_in_dim, hidden_dim, bias=True)
        # NOTE: edge_embedding is not used in forward() (edge_transform is);
        # it is still created here, so it remains part of the state dict.
        self.edge_embedding = nn.Linear(edge_in_dim, hidden_dim, bias=True)
        self.edge_transform = nn.Linear(edge_in_dim, hidden_dim)

        # Normalisation followed by a second linear map for nodes and edges.
        self.norm_nodes = nn.BatchNorm1d(hidden_dim)
        self.norm_edges = nn.BatchNorm1d(hidden_dim)
        self.W_v = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_e = nn.Linear(hidden_dim, hidden_dim, bias=True)

        blocks = [
            GNNLayer(num_hidden=hidden_dim, dropout=drop_rate, num_heads=8)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, h_V, edge_index, h_E, batch_id):
        """Return ``(node_features, edge_features)`` after all GNN layers.

        Args:
            h_V: Raw node features.
            edge_index: Edge index passed through to every layer.
            h_E: Raw edge features.
            batch_id: Passed through to every layer.
        """
        nodes = self.node_embedding(h_V)
        nodes = self.norm_nodes(nodes)
        nodes = self.W_v(nodes)

        edges = self.edge_transform(h_E)
        edges = self.norm_edges(edges)
        edges = self.W_e(edges)

        for gnn_layer in self.layers:
            nodes, edges = gnn_layer(nodes, edge_index, edges, batch_id)
        return nodes, edges


class GPSol(nn.Module):
    """Graph-level binary classifier: graph encoder, pooling, transformer, MLP.

    Args:
        protein_node_input_dim: Width of the raw node features.
        edge_input_dim: Width of the raw edge features.
        protein_hidden_dim: Hidden width of the graph encoder. The pooled
            graph embedding is twice this width (mean pool + max pool).
        num_layers: Number of GNN layers in the graph encoder.
        dropout: Dropout probability for the encoder, transformer and head.
        device: Stored on the module as ``self.device``; not used in forward().
    """

    def __init__(self, protein_node_input_dim, edge_input_dim, protein_hidden_dim, num_layers, dropout, device):
        super().__init__()
        self.device = device

        self.Graph_encoder_protein = Graph_encoder(
            node_in_dim=protein_node_input_dim,
            edge_in_dim=edge_input_dim,
            hidden_dim=protein_hidden_dim,
            num_layers=num_layers,
            drop_rate=dropout
        )

        # NOTE(review): dim_feedforward=1028 looks like a typo for 1024; it is
        # left unchanged because altering it would invalidate saved checkpoints.
        self.transformer = nn.Transformer(
            d_model=2 * (protein_hidden_dim),
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=1028,
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)

        self.fc1 = nn.Linear(2 * protein_hidden_dim, 128)
        self.fc3 = nn.Linear(128, 1)
        # Not used in forward(); kept so the state-dict keys stay unchanged.
        self.norm = nn.LayerNorm(2 * protein_hidden_dim)

    def forward(self, h_V_protein, edge_index_protein, h_E_protein, batch_id_protein):
        """Return one sigmoid score per graph, shaped ``(num_graphs, 1)``.

        Args:
            h_V_protein: Raw node features of all graphs in the batch.
            edge_index_protein: Edge index of the batched graph.
            h_E_protein: Raw edge features.
            batch_id_protein: Graph assignment of every node, used for pooling.
        """
        h_V_protein, _ = self.Graph_encoder_protein(h_V_protein, edge_index_protein, h_E_protein, batch_id_protein)

        # Graph-level embedding: mean pool and max pool side by side.
        h_V_protein_pooled = global_mean_pool(h_V_protein, batch_id_protein)
        h_V_protein_maxpooled = global_max_pool(h_V_protein, batch_id_protein)
        h_V_combined = torch.cat([h_V_protein_pooled, h_V_protein_maxpooled], dim=1)

        # The transformer is batch_first, so every graph must be its own
        # length-1 sequence: (num_graphs, 1, d_model). Inserting the new axis
        # at position 0 instead would make the whole mini-batch one sequence,
        # letting graphs attend to each other so that a prediction depends on
        # which other samples share its batch. For a batch of one graph both
        # layouts are identical.
        h_V_combined = h_V_combined.unsqueeze(1)
        h_V_combined = self.transformer(h_V_combined, h_V_combined)
        h_V_combined = h_V_combined.squeeze(1)

        emb = F.relu(self.fc1(h_V_combined))
        emb = self.dropout(emb)
        output = torch.sigmoid(self.fc3(emb))

        return output


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score, matthews_corrcoef, precision_score, recall_score
import os

# Compute device for training/evaluation: CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# **Training function**
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over *dataloader* and report epoch metrics.

    Args:
        model: Called as ``model(x, edge_index, edge_attr, batch)``; its output
            is flattened and compared against a 0.5 threshold.
        dataloader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``y`` and ``batch``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, auc, aupr, f1, mcc, precision, recall)``;
        ``avg_loss`` is the mean of the per-batch losses.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    targets, scores = [], []

    for batch in dataloader:
        optimizer.zero_grad()
        batch = batch.to(device)
        labels = batch.y.float().to(device)

        node_feats = batch.x.to(device)
        edges = batch.edge_index.to(device)
        if batch.edge_attr is not None:
            edge_feats = batch.edge_attr.to(device)
        else:
            # No edge attributes: substitute zeros, one 450-wide row per edge.
            edge_feats = torch.zeros(batch.edge_index.size(1), 450).to(device)

        output = model(node_feats, edges, edge_feats, batch.batch).view(-1)

        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        hits = (output >= 0.5).float() == labels
        n_correct += hits.sum().item()
        n_seen += labels.size(0)
        targets.extend(labels.cpu().numpy())
        scores.extend(output.detach().cpu().numpy())

    # Ranking metrics work on the raw scores.
    targets = np.array(targets)
    scores = np.array(scores)
    auc = roc_auc_score(targets, scores)
    aupr = average_precision_score(targets, scores)

    # Threshold-based metrics work on hard 0/1 predictions.
    hard_preds = (scores >= 0.5).astype(int)

    avg_loss = running_loss / len(dataloader)
    accuracy = n_correct / n_seen
    f1 = f1_score(targets, hard_preds)
    mcc = matthews_corrcoef(targets, hard_preds)
    precision = precision_score(targets, hard_preds, zero_division=0)
    recall = recall_score(targets, hard_preds, zero_division=0)

    return avg_loss, accuracy, auc, aupr, f1, mcc, precision, recall


# **Validation function**
def validate(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Args:
        model: Called as ``model(x, edge_index, edge_attr, batch)``; its output
            is flattened and compared against a 0.5 threshold.
        dataloader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``y`` and ``batch``.
        criterion: Loss called as ``criterion(output, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, auc, aupr, f1, mcc, precision, recall)``;
        ``avg_loss`` is the mean of the per-batch losses.
    """
    model.eval()
    running_loss = 0
    total_correct = 0
    total_samples = 0
    targets, scores = [], []

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            labels = batch.y.float().to(device)

            node_feats = batch.x.to(device)
            edges = batch.edge_index.to(device)
            if batch.edge_attr is not None:
                edge_feats = batch.edge_attr.to(device)
            else:
                # No edge attributes: substitute zeros, one 450-wide row per edge.
                edge_feats = torch.zeros(batch.edge_index.size(1), 450).to(device)

            output = model(node_feats, edges, edge_feats, batch.batch).view(-1)

            running_loss += criterion(output, labels).item()

            targets.extend(labels.cpu().numpy())
            scores.extend(output.detach().cpu().numpy())

    targets = np.array(targets)
    scores = np.array(scores)

    # Hard 0/1 predictions for the threshold-based metrics.
    hard_preds = (scores >= 0.5).astype(int)

    avg_loss = running_loss / len(dataloader)
    accuracy = (hard_preds == targets).mean()
    auc = roc_auc_score(targets, scores)
    aupr = average_precision_score(targets, scores)
    f1 = f1_score(targets, hard_preds)
    mcc = matthews_corrcoef(targets, hard_preds)
    precision = precision_score(targets, hard_preds, zero_division=0)
    recall = recall_score(targets, hard_preds, zero_division=0)

    return avg_loss, accuracy, auc, aupr, f1, mcc, precision, recall


# **Train the final model directly**
num_epochs = 100
data_dir = r'D:\UESTC\odorant\graph'  # Change to your data directory

# Load all data. The split directories are scanned once and the result is
# reused; calling load_data() per split would rescan the disk and repeat any
# "does not exist" warning four times.
split_datasets = load_data(data_dir)
train_data = split_datasets['neg-train'] + split_datasets['pos-train']
test_data = split_datasets['neg-test'] + split_datasets['pos-test']

train_loader = DataLoader(train_data, batch_size=2, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

# Hyper-parameters shared by the trained model and the model that is later
# re-created to load the checkpoint, so the two cannot drift apart.
model_kwargs = dict(
    protein_node_input_dim=1217,
    edge_input_dim=450,
    protein_hidden_dim=256,
    num_layers=2,
    dropout=0.5,
    device=device,
)

# Initialize the model
model = GPSol(**model_kwargs).to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=0.00001)
criterion = nn.BCELoss().to(device)

# Directory to save models (exist_ok avoids the check-then-create race)
model_save_dir = r'D:\UESTC\odorant\model'
os.makedirs(model_save_dir, exist_ok=True)

# Train the model and save model for each epoch
for epoch in range(num_epochs):
    avg_loss, accuracy, auc, aupr, f1, mcc, precision, recall = train(model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, AUC: {auc:.4f}, AUPR: {aupr:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    # Evaluate after each epoch. NOTE: this uses the test set, not a held-out
    # validation split, so these numbers must not be used for model selection.
    val_loss, val_accuracy, val_auc, val_aupr, val_f1, val_mcc, val_precision, val_recall = validate(model, test_loader, criterion, device)
    print(f"Test Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}, AUC: {val_auc:.4f}, AUPR: {val_aupr:.4f}, F1: {val_f1:.4f}, MCC: {val_mcc:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}")

    # Save the model of current epoch
    torch.save(model.state_dict(), os.path.join(model_save_dir, f'model_epoch_{epoch+1}.pth'))
    print(f"Model saved for epoch {epoch+1}.")

# **Load and evaluate the final test model**
final_model = GPSol(**model_kwargs).to(device)

# Load the final epoch model (change if needed). map_location lets a
# checkpoint written on a GPU machine be restored on a CPU-only machine.
final_checkpoint = os.path.join(model_save_dir, f'model_epoch_{num_epochs}.pth')
final_model.load_state_dict(torch.load(final_checkpoint, map_location=device))
final_model.eval()

# **Final test evaluation**
final_loss, final_acc, final_auc, final_aupr, final_f1, final_mcc, final_precision, final_recall = validate(final_model, test_loader, criterion, device)
print(f"Final Test Results: Loss: {final_loss:.4f}, Accuracy: {final_acc:.4f}, AUC: {final_auc:.4f}, AUPR: {final_aupr:.4f}, F1: {final_f1:.4f}, MCC: {final_mcc:.4f}, Precision: {final_precision:.4f}, Recall: {final_recall:.4f}")
