# Data Loading

In [19]:
import os
import torch

# Resolve the data locations relative to the directory the notebook is run from.
# os.getcwd() only reads the current directory; nothing is changed on disk here.
work_dir = os.getcwd()  # Directory the kernel is currently running in
input_data_dir = os.path.join(work_dir, '../Data')  # Graphs to predict are read from ../Data
output_dir = os.path.join(work_dir, '../Data')  # Results go to the same ../Data folder

# Load the collection of graphs that predictions will be made for.
# NOTE(review): torch.load presumably unpickles this file — only load files from a trusted source.
merged_file = os.path.join(input_data_dir, 'all_graphs_to_be_predicted.pt')
merged_graphs = torch.load(merged_file)
n = len(merged_graphs)  # Number of graphs loaded; used later to derive the last index to predict
n  # Show the count as the cell output


9

# GAT Model

In [20]:
import torch.nn as nn
from torch_geometric.nn import GATConv, global_mean_pool
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Seed every random number generator the notebook relies on so runs are repeatable
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and CUDA) with the same value.

    Args:
        seed: Integer seed handed unchanged to each library's seeding function.
    """
    import random
    import numpy as np

    # Same four calls, in the same order, expressed as data instead of statements.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,  # GPU generator
    )
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(42)  # Fix the seed used by this notebook


# GAT-based model producing one prediction vector per graph
class GATModel(nn.Module):
    """Three stacked GATConv layers, mean pooling over nodes, then a linear head.

    Args:
        in_dim: Size of each node's input feature vector.
        hidden_dim: Output size of each GAT layer per attention head.
        out_dim: Size of the final prediction vector.
        num_heads: Attention heads used by the first two GAT layers (the third uses one).
        dropout_rate: Dropout value passed to every GAT layer.
        dosage_weight: Stored on the module as `self.dosage_weight`; `forward` does not read it.
        initialize_weights: When true, run `_initialize_weights` after the layers are built.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_rate=0.3, dosage_weight=1.0, initialize_weights=True):
        super().__init__()
        self.dosage_weight = dosage_weight  # Kept for reference only; unused in forward()

        # Input width of layers 2 and 3: the multi-head output of the layer before them.
        multi_head_dim = hidden_dim * num_heads
        self.layer1 = GATConv(in_dim, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer2 = GATConv(multi_head_dim, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer3 = GATConv(multi_head_dim, hidden_dim, heads=1, dropout=dropout_rate)
        self.fc = nn.Linear(hidden_dim, out_dim)

        if initialize_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialise each GAT layer's `lin` weight and zero its bias when it has one."""
        for conv in (self.layer1, self.layer2, self.layer3):
            projection = conv.lin
            nn.init.xavier_uniform_(projection.weight)
            if projection.bias is None:
                continue
            nn.init.zeros_(projection.bias)

    def forward(self, data):
        """Run the network on a graph object exposing `x`, `edge_index` and `batch`.

        Returns:
            A 3-tuple `(out, hg, attention)`: the linear head's output, the pooled
            graph embedding it was computed from, and a tuple holding what each of
            the three GAT layers returned as attention weights.
        """
        node_feats, edge_index, batch = data.x, data.edge_index, data.batch

        attention = []
        h = node_feats
        # The first two layers are each followed by a ReLU.
        for conv in (self.layer1, self.layer2):
            h, layer_attn = conv(h, edge_index, return_attention_weights=True)
            h = torch.relu(h)
            attention.append(layer_attn)

        # The last GAT layer feeds the pooling step directly, without an activation.
        h, layer_attn = self.layer3(h, edge_index, return_attention_weights=True)
        attention.append(layer_attn)

        # Average node representations into one vector per graph.
        hg = global_mean_pool(h, batch)
        return self.fc(hg), hg, tuple(attention)


# Hyperparameters used to build the GAT network
in_dim = 91        # Features per node
hidden_dim = 64    # Hidden size of each GAT layer
out_dim = 5        # One output value for each of the 5 labels
num_heads = 4      # Attention heads in the GAT layers
dropout_rate = 0.5
dosage_weight = 1  # Scaling factor for dosage feature

# Build the network and print its layer summary
model = GATModel(
    in_dim=in_dim,
    hidden_dim=hidden_dim,
    out_dim=out_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    dosage_weight=dosage_weight,
)
print(model)



GATModel(
  (layer1): GATConv(91, 64, heads=4)
  (layer2): GATConv(256, 64, heads=4)
  (layer3): GATConv(256, 64, heads=1)
  (fc): Linear(in_features=64, out_features=5, bias=True)
)


# Model Loading

In [21]:
import gc

# Drop the previous model to free memory.
# pop() is used instead of `del` so the cell can be re-run: `del model` raises
# NameError when the name no longer exists in the notebook namespace.
globals().pop('model', None)
gc.collect()  # Force garbage collection; its return value is shown as the cell output




27

In [22]:
import os
import torch

# Resolve the model and output locations relative to the current working directory
work_dir = os.getcwd()  # Directory the kernel is currently running in
input_data_dir = os.path.join(work_dir, '../Model')  # From here on input_data_dir points at ../Model
output_dir = os.path.join(work_dir, '../Data')  # Folder the prediction tables are written to

# No graph data is loaded in this cell; only the saved model is read.

# Path of the saved GAT model inside the ../Model folder
model_save_path = os.path.join(input_data_dir, 'gat_model.pth')

# Load the saved model object.
# NOTE(review): torch.load presumably unpickles the file — only load model files from a
# trusted source, and confirm the GATModel class must already be defined for this to work.
model = torch.load(model_save_path)
model.eval()  # Set the model to evaluation mode
print(f"Model loaded successfully from {model_save_path}")


Model loaded successfully from D:\博士文件\TCMMKG\GraphAI-for-TCM\Python\../Model\gat_model.pth


# Model Prediction

In [23]:
import os
import torch
import pandas as pd

# Helper that rounds a value to a fixed number of decimal places (four by default)
def format_value(val, ndigits=4):
    """Round `val` to `ndigits` decimal places.

    This is decimal-place rounding, not significant-figure rounding: with the
    default of four places, 0.000012345 becomes 0.0 and 1234.56789 becomes
    1234.5679.

    Args:
        val: Number to round; anything accepted by the built-in `round`.
        ndigits: Number of decimal places to keep. Defaults to 4.

    Returns:
        Whatever `round(val, ndigits)` returns for the given value.
    """
    return round(val, ndigits)

# Helpers used by predict_samples

def _prediction_row(cpm_id, out, hg):
    """Build one prediction-table row: cpm_id, then out_1..out_k, then hg_1..hg_m."""
    row = {"cpm_id": cpm_id}
    for prefix, tensor in (("out", out), ("hg", hg)):
        values = tensor.detach().cpu().numpy().flatten()
        for j, val in enumerate(values, start=1):
            row[f"{prefix}_{j}"] = format_value(val)
    return row


def _attention_rows(cpm_id, node_names, attn_weights):
    """Build one attention-table row per (source name, target name) pair of a graph.

    Each entry of `attn_weights` is unpacked as `(edge_index, attn_weight)`. Row 0 of
    `edge_index` is reported as Source and row 1 as Target, and every head of every
    layer gets its own `attn_weights_<layer>_head_<head>` column. Rows are keyed by
    the pair of node names, so a later edge with the same names overwrites the
    values stored for that layer.
    """
    edge_rows = {}
    for layer_idx, (edge_index, attn_weight) in enumerate(attn_weights, start=1):
        edges = edge_index.detach().cpu().numpy().T  # [2, E] -> [E, 2]
        heads_per_edge = attn_weight.detach().cpu().numpy()  # One row of head values per edge

        for (src_idx, dst_idx), head_values in zip(edges, heads_per_edge):
            source = node_names[int(src_idx)]
            target = node_names[int(dst_idx)]
            edge_key = (source, target)

            # First time this edge is seen: create its row with the identifying columns.
            if edge_key not in edge_rows:
                edge_rows[edge_key] = {
                    "cpm_id": cpm_id,
                    "Source": source,
                    "Target": target
                }

            for head_idx, attn_value in enumerate(head_values, start=1):
                edge_rows[edge_key][f"attn_weights_{layer_idx}_head_{head_idx}"] = format_value(attn_value)

    return list(edge_rows.values())


def _write_tsv(rows, filename, label):
    """Write `rows` as a tab-separated file in `output_dir` and print a confirmation."""
    path = os.path.join(output_dir, filename)
    pd.DataFrame(rows).to_csv(path, sep='\t', index=False)
    print(f"{label} exported to {path} as TSV")


# Prediction function to process a range of samples
def predict_samples(start_index, end_index):
    """Run the loaded model on `merged_graphs[start_index..end_index]` and export two TSV files.

    Reads the notebook-level `merged_graphs`, `model` and `output_dir`. Writes
    `prediction_outputs.tsv` (one row per graph: cpm_id, model outputs and graph
    embedding) and `attention_weights.tsv` (one row per edge: cpm_id, source and
    target node names, and per-layer, per-head attention values) into `output_dir`,
    overwriting files of the same name.

    Args:
        start_index: Position of the first graph to predict.
        end_index: Position of the last graph to predict (inclusive).

    Returns:
        None. Results are only written to disk.
    """
    output_results = []
    attn_results = []

    # Inference only: disabling autograd avoids building a computation graph that is
    # never used. The values written to the files are unchanged.
    with torch.no_grad():
        for i in range(start_index, end_index + 1):
            sample = merged_graphs[i]
            cpm_id = sample.cpm_id  # Unique identifier for each sample
            out, hg, attn_weights = model(sample)

            output_results.append(_prediction_row(cpm_id, out, hg))
            # Each graph is expected to carry a `node_names` attribute, indexed by node position.
            attn_results.extend(_attention_rows(cpm_id, sample.node_names, attn_weights))

    _write_tsv(output_results, 'prediction_outputs.tsv', "Prediction outputs")
    _write_tsv(attn_results, 'attention_weights.tsv', "Attention weights")

# Choose which graphs to predict; both bounds are inclusive positions in merged_graphs
start_index = 0  # First graph to predict
end_index = n - 1  # Last graph to predict; n - 1 covers every loaded graph (lower it for a subset)
predict_samples(start_index, end_index)



Prediction outputs exported to D:\博士文件\TCMMKG\GraphAI-for-TCM\Python\../Data\prediction_outputs.tsv as TSV
Attention weights exported to D:\博士文件\TCMMKG\GraphAI-for-TCM\Python\../Data\attention_weights.tsv as TSV
