# Graph Data Loading

In [10]:
import os
import torch
import numpy as np

# Define working and data directories.
# The graph data is expected in a '../Data' directory relative to the notebook's working directory.
work_dir = os.getcwd()
# normpath collapses the '..' segment so printed paths are clean.
input_data_dir = os.path.normpath(os.path.join(work_dir, '../Data'))
output_dir = input_data_dir  # Outputs are written next to the inputs

# Load the preprocessed graph dataset (PyTorch Geometric format).
# The file holds pickled graph objects rather than plain tensors, so weights_only=False is
# required on PyTorch >= 2.6, where torch.load defaults to weights_only=True.
# Only load files you trust: unpickling can execute arbitrary code.
merged_file = os.path.join(input_data_dir, 'all_graphs_to_be_predicted.pt')
merged_graphs = torch.load(merged_file, weights_only=False)

assert len(merged_graphs) > 0, f"No graphs found in {merged_file}"

# Number of graphs successfully loaded
len(merged_graphs)


6

# GAT Model Definition

In [11]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, global_mean_pool
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Function to fix random seed for reproducibility
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator, and PyTorch's generator for the current CUDA device. It does not
    switch cuDNN into deterministic mode.

    Parameters:
        seed (int): Seed value applied to every generator.
    """
    import random
    import numpy as np

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

set_seed(42)  # Fixed seed so weight initialization and dropout are repeatable

# Definition of the Graph Attention Network (GAT) model
class GATModel(nn.Module):
    """Four-layer Graph Attention Network mapping each graph to ``out_dim`` scores.

    Node features pass through three multi-head GATConv layers (ReLU after each) and
    a single-head fourth layer. They are then mean-pooled per graph and projected by a
    linear layer. The attention weights of every layer are returned for interpretation.
    """

    # Column of ``data.x`` holding the dosage ratio (the 91st feature), and the range it
    # is clamped to after scaling. These are class attributes rather than attributes set
    # in __init__, so models unpickled from an older class definition still resolve them.
    DOSAGE_FEATURE_INDEX = 90
    DOSAGE_MIN = 0
    DOSAGE_MAX = 10

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_rate=0.3, dosage_weight=1.0):
        """
        GAT-based neural network for graph-level prediction tasks.

        Parameters:
            in_dim (int): Input feature dimension for each node (must be > DOSAGE_FEATURE_INDEX).
            hidden_dim (int): Per-head hidden dimension of the GAT layers.
            out_dim (int): Output dimension (e.g., number of classes or regression targets).
            num_heads (int): Number of attention heads in layers 1-3 (layer 4 uses one head).
            dropout_rate (float): Dropout rate passed to every GATConv layer.
            dosage_weight (float): Scaling factor for the dosage feature (feature index 90).
        """
        super().__init__()
        self.dosage_weight = dosage_weight

        # Stack of 4 GATConv layers; heads are concatenated, hence hidden_dim * num_heads inputs
        self.layer1 = GATConv(in_dim, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer3 = GATConv(hidden_dim * num_heads, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer4 = GATConv(hidden_dim * num_heads, hidden_dim, heads=1, dropout=dropout_rate)

        # Maps the pooled graph-level representation to the final output
        self.fc = nn.Linear(hidden_dim, out_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Apply Xavier-uniform initialization to the linear projection of each GAT layer
        (and zero its bias when present). The final ``fc`` layer keeps PyTorch's default init.
        """
        for layer in [self.layer1, self.layer2, self.layer3, self.layer4]:
            nn.init.xavier_uniform_(layer.lin.weight)
            if layer.lin.bias is not None:
                nn.init.zeros_(layer.lin.bias)

    def forward(self, data):
        """
        Forward pass of the GAT model. ``data`` itself is not modified.

        Parameters:
            data (torch_geometric.data.Data): Graph (or batch of graphs) exposing
                ``x``, ``edge_index`` and ``batch`` attributes.

        Returns:
            out (Tensor): Final prediction output (one row per graph).
            hg (Tensor): Graph-level pooled embeddings.
            attn_weights (Tuple): One ``(edge_index, attention)`` pair per GAT layer.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Scale the dosage feature on a copy. Assigning into data.x directly would rescale
        # the caller's graph permanently, so every repeated pass over the same sample
        # (each training epoch, or re-running inference) would compound dosage_weight.
        x = x.clone()
        idx = self.DOSAGE_FEATURE_INDEX
        x[:, idx] = torch.clamp(
            x[:, idx] * self.dosage_weight, min=self.DOSAGE_MIN, max=self.DOSAGE_MAX
        )

        h, attn_weights_1 = self.layer1(x, edge_index, return_attention_weights=True)
        h = torch.relu(h)

        h, attn_weights_2 = self.layer2(h, edge_index, return_attention_weights=True)
        h = torch.relu(h)

        h, attn_weights_3 = self.layer3(h, edge_index, return_attention_weights=True)
        h = torch.relu(h)

        # Fourth layer: single head, no activation before pooling
        h, attn_weights_4 = self.layer4(h, edge_index, return_attention_weights=True)

        # Average node features into one embedding per graph
        hg = global_mean_pool(h, batch)

        out = self.fc(hg)

        return out, hg, (attn_weights_1, attn_weights_2, attn_weights_3, attn_weights_4)

# ---------------- Model Instantiation ----------------

# Hyperparameters (kept as module-level names because later cells reuse them)
in_dim = 91         # Node feature width; feature index 90 is the dosage ratio
hidden_dim = 64     # Per-head hidden size of each GAT layer
out_dim = 4         # Number of output scores (e.g., 4 therapeutic effect categories)
num_heads = 2       # Attention heads in GAT layers 1-3
dropout_rate = 0.4  # Dropout passed to every GATConv layer
dosage_weight = 1   # Scaling applied to the dosage feature before clamping

model = GATModel(
    in_dim=in_dim,
    hidden_dim=hidden_dim,
    out_dim=out_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    dosage_weight=dosage_weight,
)
print(model)


GATModel(
  (layer1): GATConv(91, 64, heads=2)
  (layer2): GATConv(128, 64, heads=2)
  (layer3): GATConv(128, 64, heads=2)
  (layer4): GATConv(128, 64, heads=1)
  (fc): Linear(in_features=64, out_features=4, bias=True)
)


# GAT Model Loading and Memory Management

In [12]:
# ---------------------- Model Loading and Memory Cleanup ----------------------

import os
import torch
import gc

# Drop the reference to any previously built model so its memory can be reclaimed.
# Rebinding (instead of `del model`) avoids a NameError when this cell runs on a fresh kernel.
model = None
gc.collect()

# Path of the trained model; input_data_dir is defined in the data-loading cell.
model_save_path = os.path.join(input_data_dir, 'gat_model.pth')

# The checkpoint is a pickled object (a full nn.Module or a state_dict), so weights_only=False
# is required on PyTorch >= 2.6, where torch.load defaults to weights_only=True.
# Only load checkpoints you trust: unpickling can execute arbitrary code.
checkpoint = torch.load(model_save_path, weights_only=False)

if isinstance(checkpoint, torch.nn.Module):
    # Saved via torch.save(model): the object already carries architecture and weights.
    model = checkpoint
else:
    # Saved via torch.save(model.state_dict()): rebuild the architecture, then load weights.
    model = GATModel(in_dim, hidden_dim, out_dim, num_heads, dropout_rate, dosage_weight=dosage_weight)
    model.load_state_dict(checkpoint)

# Evaluation mode disables dropout for inference
model.eval()

print(f"Model loaded from {model_save_path}")


Model loaded from D:\博士文件\博士毕业课题材料\维吾尔医药配伍机制量化分析\GraphAI-for-Uyghur-Medicine\Python\../Data\gat_model.pth


In [13]:
# Re-check how many graphs are queued for inference before running the model
graph_count = len(merged_graphs)
graph_count

6

# Model Inference and TSV Output for Predictions and Attention Weights

In [14]:
# ---------------------- GAT Model Inference and Result Export ----------------------

import os
import torch
import pandas as pd
import numpy as np

# Utility function to format floating point numbers to 4 decimal places
def format_value(val, ndigits=4):
    """Round a scalar to ``ndigits`` decimals and return it as a built-in float.

    Model outputs arrive as NumPy float32 scalars. Rounding those in float32 keeps
    representation noise (0.1235 is stored as 0.12349999...), so the value is
    converted to a Python float before rounding.

    Parameters:
        val: Real scalar (Python number, NumPy scalar, or single-element tensor).
        ndigits (int): Number of decimal places to keep (default 4).

    Returns:
        float: ``val`` rounded to ``ndigits`` decimal places.
    """
    return round(float(val), ndigits)


def _prediction_record(cpm_id, out, hg):
    """Build one row of the prediction table: class probabilities plus the graph embedding."""
    # Sigmoid is applied element-wise, so the class probabilities need not sum to 1.
    probs = torch.sigmoid(out).detach().cpu().numpy().flatten()
    embedding = hg.detach().cpu().numpy().flatten()
    return {
        "cpm_id": cpm_id,
        **{f"Class_{j + 1}": format_value(v) for j, v in enumerate(probs)},
        **{f"hg_{j + 1}": format_value(v) for j, v in enumerate(embedding)},
    }


def _attention_records(cpm_id, node_names, attn_weights):
    """Build one row per (source, target) edge, with one column per layer/head attention weight."""
    edge_rows = {}
    for layer_idx, (edge_index, attn_weight) in enumerate(attn_weights, start=1):
        edges = edge_index.detach().cpu().numpy().T     # [E, 2]
        weights = attn_weight.detach().cpu().numpy()    # [E, num_heads]
        for (src, dst), head_values in zip(edges, weights):
            source, target = node_names[int(src)], node_names[int(dst)]
            row = edge_rows.setdefault(
                (source, target),
                {"cpm_id": cpm_id, "Source": source, "Target": target},
            )
            for head_idx, value in enumerate(head_values, start=1):
                row[f"attn_weights_{layer_idx}_head_{head_idx}"] = format_value(value)
    return list(edge_rows.values())


def _export_tsv(records, filename, label, out_dir):
    """Write ``records`` as a tab-separated file in ``out_dir`` and report the path."""
    path = os.path.join(out_dir, filename)
    pd.DataFrame(records).to_csv(path, sep='\t', index=False)
    print(f"{label} exported to {path} as TSV")


def predict_samples(start_index, end_index, graphs=None, net=None, out_dir=None):
    """Run GAT inference on graphs[start_index..end_index] (inclusive) and export two TSV tables.

    Writes ``prediction_outputs.tsv`` (one row per graph: sigmoid class scores and the pooled
    embedding) and ``attention_weights.tsv`` (one row per edge: attention weight per layer
    and head).

    Parameters:
        start_index (int): First graph index to score.
        end_index (int): Last graph index to score (inclusive).
        graphs: Sequence of graphs exposing ``cpm_id`` and ``node_names``; defaults to the
            notebook global ``merged_graphs``.
        net: Model returning ``(out, hg, attn_weights)``; defaults to the global ``model``.
        out_dir (str): Output directory; defaults to the global ``output_dir``.

    Raises:
        IndexError: If ``end_index`` is past the last available graph.
    """
    graphs = merged_graphs if graphs is None else graphs
    net = model if net is None else net
    out_dir = output_dir if out_dir is None else out_dir

    # Fail before doing any work instead of midway through the loop.
    if end_index >= len(graphs):
        raise IndexError(f"end_index {end_index} is out of range for {len(graphs)} graphs")

    net.eval()  # Disable dropout so the exported scores are deterministic
    output_results = []
    attn_results = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i in range(start_index, end_index + 1):
            sample = graphs[i]
            out, hg, attn_weights = net(sample)
            output_results.append(_prediction_record(sample.cpm_id, out, hg))
            attn_results.extend(_attention_records(sample.cpm_id, sample.node_names, attn_weights))

    _export_tsv(output_results, 'prediction_outputs.tsv', "Prediction outputs", out_dir)
    _export_tsv(attn_results, 'attention_weights.tsv', "Attention weights", out_dir)

# ---------------------- Inference Execution ----------------------

# Inclusive range of graph indices to score: the first six graphs (indices 0-5)
start_index, end_index = 0, 5

# Score the selected graphs and write both TSV tables
predict_samples(start_index, end_index)


Prediction outputs exported to D:\博士文件\博士毕业课题材料\维吾尔医药配伍机制量化分析\GraphAI-for-Uyghur-Medicine\Python\../Data\prediction_outputs.tsv as TSV
Attention weights exported to D:\博士文件\博士毕业课题材料\维吾尔医药配伍机制量化分析\GraphAI-for-Uyghur-Medicine\Python\../Data\attention_weights.tsv as TSV
