# Graph Encoding for Candidate Formula Generation

In [3]:
import os
import pandas as pd
import torch
import networkx as nx
import logging
from torch_geometric.utils import from_networkx
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Send log records of level INFO and above to graph_processing.log instead of the console.
logging.basicConfig(filename='graph_processing.log', level=logging.INFO)

# Inputs and outputs share one directory: ../Data relative to the current working directory.
work_dir = os.getcwd()
input_data_dir = os.path.join(work_dir, '../Data')
output_dir = os.path.join(work_dir, '../Data')

# Input tables read below, and the file the encoded graphs are written to
file_path = os.path.join(input_data_dir, 'Test_input.xlsx')
chp_properties_path = os.path.join(input_data_dir, 'CHP_Medicinal_properties.tsv')
chp_encoder_path = os.path.join(input_data_dir, 'CHP_Encoder.tsv')
output_file = os.path.join(output_dir, 'all_graphs_to_be_predicted.pt')

# Load the formula table (Excel) and the two tab-separated CHP tables, in that order.
data = pd.read_excel(file_path)
chp_properties_data, chp_encoder_data = (
    pd.read_csv(path, sep='\t') for path in (chp_properties_path, chp_encoder_path)
)

# Main function to process individual graphs into PyG format
def process_graph_to_pyg(cpm_id, cpm_chp_data, chp_properties_data, chp_encoder_data):
    """Build the PyG graph for one formula (CPM_ID).

    The graph is built as a networkx MultiGraph and then converted:

    * one 'Actual' node per CHP_ID of the formula that has an encoder row. Its
      feature list is the encoder columns after the first, followed by the
      formula's Dosage_ratio for that CHP (non-numeric or missing -> 0);
    * three 'Virtual' nodes ('Medicinal flavor', 'Meridian tropism',
      'Therapeutic nature'), initialised with a copy of the last actual node's
      features and then blended with their neighbours by
      ``update_virtual_node_features``;
    * one edge per matching row of ``chp_properties_data``, from the CHP to the
      virtual node named in 'Class', with attribute ``[x_rank / 23, y_rank / 23]``;
    * edges between every pair of virtual nodes, carrying the element-wise mean
      of all CHP->virtual edge attributes (zeros if there are none).

    The graph is then made directed and converted with ``convert_to_pyg_graph``.
    The shared ``chp_properties_data`` / ``chp_encoder_data`` frames are only
    read, never modified, so concurrent calls do not race on them.

    Args:
        cpm_id: Formula identifier; rows of ``cpm_chp_data`` with another
            CPM_ID are ignored. It is stored on the result as ``cpm_id``.
        cpm_chp_data: Frame with at least 'CPM_ID', 'CHP_ID', 'Dosage_ratio'.
        chp_properties_data: Frame with 'CHP_ID', 'Class', 'x_rank', 'y_rank'.
        chp_encoder_data: Frame with 'CHP_ID' (first column, by assumption)
            followed by the feature columns.

    Returns:
        The PyG graph with extra ``node_names`` and ``cpm_id`` attributes, or
        ``None`` when the graph cannot be built. In that case the reason is
        logged; nothing is raised.
    """
    rank_cols = ['x_rank', 'y_rank']
    try:
        # Filter data to the requested formula
        cpm_data = cpm_chp_data[cpm_chp_data['CPM_ID'] == cpm_id]
        chp_ids = cpm_data['CHP_ID'].unique()

        # Work on filtered copies: the caller's frames are shared between worker
        # threads, so converting them in place would be a data race.
        chp_encoder = chp_encoder_data[chp_encoder_data['CHP_ID'].isin(chp_ids)].copy()
        if chp_encoder.empty:
            logging.warning(f"CPM_ID {cpm_id}: no encoder rows for its CHP_IDs; graph skipped")
            return None
        feature_cols = chp_encoder.columns[1:]
        chp_encoder[feature_cols] = chp_encoder[feature_cols].apply(pd.to_numeric, errors='coerce')

        # First Dosage_ratio listed for each CHP in this formula; non-numeric/NaN -> 0
        dosage_by_chp = pd.to_numeric(
            cpm_data.drop_duplicates('CHP_ID').set_index('CHP_ID')['Dosage_ratio'],
            errors='coerce',
        ).fillna(0)

        G = nx.MultiGraph()

        # Actual nodes: encoder features + Dosage_ratio as the last feature
        for _, row in chp_encoder.iterrows():
            chp_id = row['CHP_ID']
            chp_attr = row.iloc[1:].tolist()
            chp_attr.append(dosage_by_chp.loc[chp_id])
            G.add_node(chp_id, feature=chp_attr, type='Actual', name=chp_id)

        # Virtual nodes start from a copy of the last actual node's features
        # (chp_attr from the final loop iteration; the encoder frame is non-empty here).
        virtual_node_features = chp_attr.copy()
        virtual_nodes = ['Medicinal flavor', 'Meridian tropism', 'Therapeutic nature']
        for vn in virtual_nodes:
            G.add_node(vn, feature=virtual_node_features, type='Virtual', name=vn)

        # CHP -> virtual edges, weighted by the rank columns scaled by 1/23
        chp_properties = chp_properties_data[chp_properties_data['CHP_ID'].isin(chp_ids)].copy()
        for col in rank_cols:
            chp_properties[col] = pd.to_numeric(chp_properties[col], errors='coerce').astype(float) / 23

        for _, row in chp_properties.iterrows():
            chp_id, vn = row['CHP_ID'], row['Class']
            if chp_id not in G or vn not in virtual_nodes:
                # add_edge would create a node without 'feature'/'type', which makes
                # convert_to_pyg_graph fail and drops the whole graph; skip the row.
                logging.warning(
                    f"CPM_ID {cpm_id}: skipping property row CHP_ID={chp_id!r}, Class={vn!r}"
                )
                continue
            G.add_edge(chp_id, vn, attr=row[rank_cols].tolist())

        # Blend virtual node features with their connected actual nodes
        update_virtual_node_features(G, virtual_nodes, virtual_node_features)

        # Virtual <-> virtual edges carry the mean of the CHP -> virtual edge attributes
        initial_edge_attrs = calculate_initial_edge_attributes(G, virtual_nodes)
        if not initial_edge_attrs:
            # No CHP -> virtual edges: use zeros so every edge attribute keeps the
            # same length and the edge_attr tensor stays rectangular.
            initial_edge_attrs = [0.0] * len(rank_cols)
        for i, vn1 in enumerate(virtual_nodes):
            for vn2 in virtual_nodes[i + 1:]:
                G.add_edge(vn1, vn2, attr=initial_edge_attrs)

        # to_directed() replaces every undirected edge with two directed edges
        # carrying the same attribute data, so no extra symmetrisation is needed.
        G = G.to_directed()

        pyg_graph = convert_to_pyg_graph(G)
        pyg_graph.node_names = [G.nodes[node]['name'] for node in G.nodes]
        pyg_graph.cpm_id = cpm_id
        return pyg_graph
    except Exception as e:
        # Worker-thread boundary: record the failure with traceback and skip this graph.
        logging.exception(f"Error processing CPM_ID {cpm_id}: {e}")
        print(f"Error processing CPM_ID {cpm_id}: {e}")  # Debug output
        return None

# Update attributes for virtual nodes based on actual node connections
def update_virtual_node_features(G, virtual_nodes, node_attr_names):
    """Blend each virtual node's features with those of its 'Actual' neighbours.

    For a virtual node with at least one 'Actual' neighbour, the neighbours'
    features are combined with the edge-attribute weights returned by
    ``calculate_weighted_features``. The weighted sum is divided by the total
    weight, and that mean is averaged element-wise with the node's current
    features. Virtual nodes without actual neighbours, or whose total weight is
    0, keep their current features.

    Args:
        G: networkx graph whose nodes carry 'feature' and 'type'; updated in place.
        virtual_nodes: Names of the virtual nodes to update.
        node_attr_names: Forwarded to ``calculate_weighted_features``, which only
            uses its length (the number of feature positions to accumulate).
    """
    for vn in virtual_nodes:
        actual_neighbours = [
            neighbour for neighbour in G.neighbors(vn)
            if G.nodes[neighbour]['type'] == 'Actual'
        ]
        if not actual_neighbours:
            continue

        current = G.nodes[vn]['feature']
        weighted_sum, weight_total = calculate_weighted_features(
            G, actual_neighbours, vn, node_attr_names
        )
        if weight_total == 0:
            continue  # nothing to blend in; the features stay as they are

        G.nodes[vn]['feature'] = [
            (w / weight_total + c) / 2 for w, c in zip(weighted_sum, current)
        ]

# Calculate weighted features for virtual nodes
def calculate_weighted_features(G, connected_nodes, vn, node_attr_names):
    """Accumulate the edge-weighted features of ``connected_nodes`` towards ``vn``.

    Every edge between a node and ``vn`` (all parallel edges of the multigraph)
    contributes once per component ``w`` of its 'attr' list. Each contribution
    adds ``w * feature`` to the running sum and ``w`` to the total weight.

    Args:
        G: Graph whose nodes carry 'feature' and whose edges carry 'attr'.
        connected_nodes: Nodes adjacent to ``vn`` to accumulate over.
        vn: The virtual node the edges lead to.
        node_attr_names: Only its length is used: the running sum starts with
            that many zeros (zip truncates to the shorter of sum and features).

    Returns:
        ``(weighted_features, total_weight)``.
    """
    accumulated = [0] * len(node_attr_names)
    weight_sum = 0
    for neighbour in connected_nodes:
        feats = G.nodes[neighbour]['feature']
        parallel_edges = G.get_edge_data(neighbour, vn).values()
        for w in (component for edge in parallel_edges for component in edge['attr']):
            accumulated = [acc + f * w for acc, f in zip(accumulated, feats)]
            weight_sum += w
    return accumulated, weight_sum

# Calculate average edge attributes for initial connections between virtual nodes
def calculate_initial_edge_attributes(G, virtual_nodes):
    """Return the element-wise mean 'attr' of all actual->virtual edges.

    Every edge between a virtual node in ``virtual_nodes`` and an 'Actual'
    neighbour is counted, including parallel edges; neighbours of other types
    are ignored.

    Args:
        G: Graph whose nodes carry 'type' and whose edges carry 'attr'.
        virtual_nodes: Names of the virtual nodes whose edges are averaged.

    Returns:
        A list with one mean per attribute position, or ``[]`` when there are no
        such edges.
    """
    collected = []
    for vn in virtual_nodes:
        actual = (n for n in G.neighbors(vn) if G.nodes[n]['type'] == 'Actual')
        for neighbour in actual:
            collected.extend(edge['attr'] for edge in G.get_edge_data(neighbour, vn).values())
    if not collected:
        return []
    count = len(collected)
    return [sum(column) / count for column in zip(*collected)]

# Convert networkx graph to PyTorch Geometric format with features and edge attributes
def convert_to_pyg_graph(G):
    """Convert ``G`` with ``from_networkx`` and attach explicit float tensors.

    ``x`` is built from each node's 'feature' list in ``G.nodes`` order, and
    ``edge_attr`` from each edge's 'attr' list in ``G.edges`` order. That is
    presumably the same order ``from_networkx`` uses for ``edge_index``; confirm
    for the installed PyG version. Node 'type' strings are stored as a plain
    list in ``node_types``.

    Every node needs 'feature' and 'type', every edge needs 'attr', and the lists
    must have equal lengths, otherwise the lookups or ``torch.tensor`` raise.
    """
    pyg_graph = from_networkx(G)
    nodes = G.nodes
    pyg_graph.x = torch.tensor([nodes[n]['feature'] for n in nodes], dtype=torch.float)
    edges = G.edges
    pyg_graph.edge_attr = torch.tensor([edges[e]['attr'] for e in edges], dtype=torch.float)
    pyg_graph.node_types = [nodes[n]['type'] for n in nodes]
    return pyg_graph

# Process graphs concurrently and save results.
# Results are stored in the order the CPM_IDs first appear in the input, not in
# the order workers finish, so the saved list (and any index-based inspection of
# it) is the same on every run.
unique_cpm_ids = data['CPM_ID'].unique()

with ThreadPoolExecutor() as executor:
    future_to_index = {
        executor.submit(
            process_graph_to_pyg,
            cpm_id,
            data[data['CPM_ID'] == cpm_id],
            chp_properties_data,
            chp_encoder_data,
        ): index
        for index, cpm_id in enumerate(unique_cpm_ids)
    }
    results = [None] * len(future_to_index)

    # as_completed drives the progress bar; the index puts each result back in input order
    for future in tqdm(as_completed(future_to_index), total=len(future_to_index), desc="Processing graphs"):
        results[future_to_index[future]] = future.result()

# process_graph_to_pyg returns None (after logging) for graphs it could not build
pyg_graphs = [graph for graph in results if graph is not None]

# Save the processed PyG graphs to a file
torch.save(pyg_graphs, output_file)
print(f"Successfully saved {len(pyg_graphs)} graphs to '{output_file}'")


Processing graphs: 100%|█████████████████████████████████████████████████████████████████████████| 9/9 [00:00<?, ?it/s]

Successfully saved 9 graphs to 'D:\博士文件\TCMMKG\GraphAI-for-TCM\Python\../Data\all_graphs_to_be_predicted.pt'





# Examine Graph Data

In [4]:
import os
import torch

# Set working directory and define input/output paths (both ../Data)
work_dir = os.getcwd()  # Use the current directory as work_dir
input_data_dir = os.path.join(work_dir, '../Data')  # Set ../Data as input data location
output_dir = os.path.join(work_dir, '../Data')  # Set ../Data as output data location

# Load PyG graphs from the specified file.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which refuses
# to unpickle PyG Data objects; on such versions pass weights_only=False (only for
# files you produced yourself, since that unpickles arbitrary objects).
output_file = os.path.join(input_data_dir, 'all_graphs_to_be_predicted.pt')
loaded_pyg_graphs = torch.load(output_file)

# Check and display the number of graphs loaded
num_graphs = len(loaded_pyg_graphs)
print(f"Number of graphs: {num_graphs}")

# Display the number of label columns in each graph (if labels are present).
# NOTE(review): size(0) is the length of y's first dimension; whether that equals
# the number of label columns depends on how y was built — confirm.
if loaded_pyg_graphs:
    first_graph = loaded_pyg_graphs[0]
    labels = getattr(first_graph, 'y', None)
    num_label_columns = labels.size(0) if labels is not None else 0
    print(f"Number of label columns: {num_label_columns}")

# Display detailed information for the graph at graph_index
# (0-based, so 1 is the second graph; negative indices count from the end).
graph_index = 1  # Modify this index to view other graphs if needed
if -num_graphs <= graph_index < num_graphs:
    first_graph = loaded_pyg_graphs[graph_index]
    print(f"\nDetails for Graph {graph_index + 1}:")

    # Display node features
    print(f"Node Features:\n{first_graph.x}")

    # Display edge index (connections between nodes)
    print(f"Edge Index:\n{first_graph.edge_index}")

    # Optional fields. Accessing .y returns None rather than raising when no labels
    # are stored (see the label count above), so hasattr() alone cannot tell
    # whether a field is present; treat None as "missing".
    optional_fields = [
        ('edge_attr', "Edge Attributes:\n{}", "No Edge Attributes"),
        ('y', "Labels:\n{}", "No Labels"),
        ('cpm_id', "CPM_ID: {}", "No CPM_ID"),
        ('node_names', "Node Names:\n{}", "No Node Names"),
        ('node_types', "Node Types:\n{}", "No Node Types"),
    ]
    for attr_name, template, missing_message in optional_fields:
        value = getattr(first_graph, attr_name, None)
        print(template.format(value) if value is not None else missing_message)
else:
    print(f"\nGraph index {graph_index} is out of range for {num_graphs} graph(s)")


Number of graphs: 9
Number of label columns: 0

Details for Graph 2:
Node Features:
tensor([[ 1.0000,  0.0000,  0.0000,  ...,  0.1016, -0.2338,  0.0000],
        [ 1.0000,  0.0000,  0.0000,  ...,  0.1821, -0.1777,  0.0000],
        [ 1.0000,  0.0000,  0.0000,  ...,  0.0088, -0.0551,  0.0000],
        ...,
        [ 0.4009,  0.0495,  0.5495,  ...,  0.2321, -0.3584,  0.0000],
        [ 0.3612,  0.0543,  0.5845,  ...,  0.2225, -0.3496,  0.0000],
        [ 0.3765,  0.0618,  0.5618,  ...,  0.2204, -0.3452,  0.0000]])
Edge Index:
tensor([[ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,
          3,  3,  3,  3,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,  6,  6,  7,
          7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,
          9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11,
         11, 11, 11, 11, 11, 11, 11, 11],
        [ 9, 10, 10, 10