# Graph Encoding for Uighur herbal formulas

In [None]:
import os
import pandas as pd
import torch
import networkx as nx
import logging
from torch_geometric.utils import from_networkx
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Send INFO-and-above log records to a file in the current working directory.
logging.basicConfig(level=logging.INFO, filename='graph_processing.log')

# Inputs are read from, and outputs written to, the ../Data folder relative to
# the current working directory.
work_dir = os.getcwd()
input_data_dir = os.path.join(work_dir, '../Data')
output_dir = os.path.join(work_dir, '../Data')

# Input files: the formulas to encode plus the two herb lookup tables.
file_path = os.path.join(input_data_dir, 'Test_input.xlsx')
chp_properties_path = os.path.join(input_data_dir, 'UHP_Medicinal_properties_encode.tsv')
chp_encoder_path = os.path.join(input_data_dir, 'UHP_Encoder.tsv')

# Output file: the list of encoded graphs.
output_file = os.path.join(output_dir, 'all_graphs_to_be_predicted.pt')

# Load the formula table (Excel) and the two tab-separated lookup tables.
data = pd.read_excel(file_path)
chp_properties_data, chp_encoder_data = (
    pd.read_csv(tsv_path, sep='\t')
    for tsv_path in (chp_properties_path, chp_encoder_path)
)

# Main function: construct each graph and convert it to PyTorch Geometric format
def process_graph_to_pyg(cpm_id, cpm_chp_data, chp_properties_data, chp_encoder_data):
    """Build the graph of one formula and convert it to a PyTorch Geometric object.

    The graph holds one "Actual" node per herb (``CHP_ID``) of the formula and
    four "Virtual" therapeutic-property nodes. Each herb is linked to the
    property nodes listed for it in ``chp_properties_data`` (edge attribute is
    ``[Level]``), or to 'Cold therapeutic' and 'Hot therapeutic' with ``[0]``
    when it has no property row. Virtual nodes are then linked to each other,
    except for the Dry/Moist and Cold/Hot pairs.

    Args:
        cpm_id: Value of ``CPM_ID`` identifying the formula.
        cpm_chp_data: DataFrame with ``CPM_ID``, ``CHP_ID`` and ``Dosage_ratio``
            columns.
        chp_properties_data: DataFrame with ``CHP_ID``, ``Medicinal_properties``
            and ``Level`` columns.
        chp_encoder_data: DataFrame with a ``CHP_ID`` column; every column from
            the second one on is used as a node feature (assumes ``CHP_ID`` is
            the first column). The frame is only read, never modified, so it
            can be shared between worker threads.

    Returns:
        The PyG graph with ``node_names`` and ``cpm_id`` attached, or ``None``
        when any step fails (the error is logged with its traceback and
        printed).
    """
    virtual_nodes = ['Dry therapeutic', 'Moist therapeutic', 'Cold therapeutic', 'Hot therapeutic']
    # These pairs of virtual nodes are never connected to each other.
    unlinked_pairs = {
        frozenset(('Dry therapeutic', 'Moist therapeutic')),
        frozenset(('Cold therapeutic', 'Hot therapeutic')),
    }

    try:
        cpm_data = cpm_chp_data[cpm_chp_data['CPM_ID'] == cpm_id]
        chp_ids = cpm_data['CHP_ID'].unique()
        encoder_rows = chp_encoder_data[chp_encoder_data['CHP_ID'].isin(chp_ids)]

        # Fail early with a readable message: a herb without encoder features
        # would otherwise become an attribute-less node and break a later step.
        unique_ids = pd.Series(chp_ids)
        missing_ids = unique_ids[~unique_ids.isin(encoder_rows['CHP_ID'])].tolist()
        if missing_ids:
            raise ValueError(f"CHP_ID(s) without encoder features: {missing_ids}")
        if encoder_rows.empty:
            raise ValueError("No herbs found for this CPM_ID")

        G = nx.MultiGraph()

        # Add actual nodes with features
        for _, row in encoder_rows.iterrows():
            chp_id = row['CHP_ID']
            # Convert per row instead of rewriting the shared encoder table;
            # unparseable values become NaN.
            chp_attr = pd.to_numeric(row.iloc[1:], errors='coerce').tolist()
            dosage_ratio = cpm_data.loc[cpm_data['CHP_ID'] == chp_id, 'Dosage_ratio']
            dosage_ratio = pd.to_numeric(dosage_ratio, errors='coerce').fillna(0).iloc[0]
            chp_attr.append(dosage_ratio)
            G.add_node(chp_id, feature=chp_attr, type='Actual', name=chp_id)

        # Add virtual nodes (therapeutic properties)
        # NOTE(review): the virtual nodes start from a copy of the feature
        # vector of the LAST herb added above - confirm this seeding is intended.
        virtual_node_features = chp_attr.copy()
        for vn in virtual_nodes:
            G.add_node(vn, feature=virtual_node_features, type='Virtual', name=vn)

        # Connect actual nodes to virtual nodes based on properties
        chp_properties = chp_properties_data[chp_properties_data['CHP_ID'].isin(chp_ids)].copy()
        chp_properties['Level'] = chp_properties['Level'].astype(float)

        unknown_properties = set(chp_properties['Medicinal_properties']) - set(virtual_nodes)
        if unknown_properties:
            raise ValueError(f"Unknown medicinal properties: {list(unknown_properties)}")

        for chp_id in chp_ids:
            chp_rows = chp_properties[chp_properties['CHP_ID'] == chp_id]
            if not chp_rows.empty:
                for _, row in chp_rows.iterrows():
                    G.add_edge(chp_id, row['Medicinal_properties'], attr=[row['Level']])
            else:
                # Default edge to 'Cold' and 'Hot' when property is missing
                default_level = [0]
                G.add_edge(chp_id, 'Cold therapeutic', attr=default_level)
                G.add_edge(chp_id, 'Hot therapeutic', attr=default_level)

        # Update virtual node features based on neighbors
        update_virtual_node_features(G, virtual_nodes, virtual_node_features)

        # Compute average edge attributes between virtual nodes
        initial_edge_attrs = calculate_initial_edge_attributes(G, virtual_nodes)

        # Connect every unordered pair of virtual nodes except the excluded ones
        for i, vn1 in enumerate(virtual_nodes):
            for vn2 in virtual_nodes[i + 1:]:
                if frozenset((vn1, vn2)) in unlinked_pairs:
                    continue
                G.add_edge(vn1, vn2, attr=initial_edge_attrs)

        # Convert to directed graph and ensure bidirectional edge attributes
        G = G.to_directed()
        for u, v, k, edge_data in G.edges(keys=True, data=True):
            if 'attr' in edge_data:
                G.edges[v, u, k]['attr'] = edge_data['attr']

        # Convert to PyG format
        pyg_graph = convert_to_pyg_graph(G)
        pyg_graph.node_names = [G.nodes[node]['name'] for node in G.nodes]
        pyg_graph.cpm_id = cpm_id

        return pyg_graph
    except Exception as e:
        # Best-effort batch processing: record the failure (with traceback)
        # and let the caller skip this formula.
        logging.exception("Error processing CPM_ID %s: %s", cpm_id, e)
        print(f"Error processing CPM_ID {cpm_id}: {e}")
        return None

# Update virtual node features using weighted average from neighbors
def update_virtual_node_features(G, virtual_nodes, node_attr_names):
    """Blend each virtual node's features with those of its 'Actual' neighbours.

    For every virtual node that has at least one neighbour of type 'Actual',
    the edge-weighted mean of the neighbours' feature vectors is computed and
    averaged element-wise with the node's current features. The node is left
    as it is when it has no such neighbour or when the total weight is zero.

    Args:
        G: Graph whose nodes carry 'feature' and 'type' attributes.
        virtual_nodes: Names of the virtual nodes to update.
        node_attr_names: Passed through to ``calculate_weighted_features``,
            which only uses its length.

    Returns:
        None. ``G`` is modified in place.
    """
    for virtual in virtual_nodes:
        herbs = [nbr for nbr in G.neighbors(virtual) if G.nodes[nbr]['type'] == 'Actual']
        if not herbs:
            continue

        current = G.nodes[virtual]['feature']
        sums, weight_sum = calculate_weighted_features(G, herbs, virtual, node_attr_names)
        if weight_sum == 0:
            # Nothing to average with; keep the existing feature vector.
            G.nodes[virtual]['feature'] = current
            continue

        means = [s / weight_sum for s in sums]
        G.nodes[virtual]['feature'] = [(m + c) / 2 for m, c in zip(means, current)]

# Helper: calculate weighted feature contributions to virtual nodes
def calculate_weighted_features(G, connected_nodes, vn, node_attr_names):
    """Accumulate the edge-weighted feature sum of ``connected_nodes`` towards ``vn``.

    Every value in the 'attr' list of every edge between a connected node and
    ``vn`` acts as a weight: the node's feature vector is scaled by it and
    added to the running sum.

    Args:
        G: Graph exposing ``nodes[node]['feature']`` and ``get_edge_data``.
        connected_nodes: Nodes whose features contribute to the sum.
        vn: The virtual node the edges lead to.
        node_attr_names: Only its length is used, to size the initial
            all-zero accumulator.

    Returns:
        A ``(weighted_features, total_weight)`` tuple.
    """
    accumulated = [0] * len(node_attr_names)
    weight_sum = 0
    for herb in connected_nodes:
        herb_features = G.nodes[herb]['feature']
        parallel_edges = G.get_edge_data(herb, vn)
        # Flatten the weights of all edges between this node and vn.
        weights = (w for payload in parallel_edges.values() for w in payload['attr'])
        for weight in weights:
            accumulated = [
                acc + value * weight
                for acc, value in zip(accumulated, herb_features)
            ]
            weight_sum += weight
    return accumulated, weight_sum

# Compute average attribute value between virtual nodes
def calculate_initial_edge_attributes(G, virtual_nodes):
    """Average the 'attr' vectors of all edges joining 'Actual' nodes to virtual nodes.

    Args:
        G: Graph exposing ``neighbors``, ``nodes[node]['type']`` and
            ``get_edge_data``.
        virtual_nodes: Names of the virtual nodes whose incident edges are
            collected.

    Returns:
        The element-wise mean of the collected 'attr' lists, or an empty list
        when no such edge exists.
    """
    collected = []
    for vn in virtual_nodes:
        actual_neighbours = (
            nbr for nbr in G.neighbors(vn) if G.nodes[nbr]['type'] == 'Actual'
        )
        for neighbour in actual_neighbours:
            collected.extend(
                payload['attr']
                for payload in G.get_edge_data(neighbour, vn).values()
            )

    count = len(collected)
    # zip(*collected) walks the attribute vectors column by column.
    return [sum(column) / count for column in zip(*collected)]

# Convert NetworkX graph to PyTorch Geometric graph
def convert_to_pyg_graph(G):
    """Convert ``G`` with ``from_networkx`` and attach explicit feature tensors.

    After the conversion, ``x`` is overwritten with the nodes' 'feature'
    values and ``edge_attr`` with the edges' 'attr' values (both as float
    tensors, in the iteration order of ``G.nodes`` / ``G.edges``), and the
    nodes' 'type' values are stored as the plain list ``node_types``.

    Args:
        G: NetworkX graph whose nodes carry 'feature' and 'type' and whose
            edges carry 'attr'.

    Returns:
        The converted PyG graph object.
    """
    pyg_graph = from_networkx(G)

    node_features = [attrs['feature'] for _, attrs in G.nodes(data=True)]
    pyg_graph.x = torch.tensor(node_features, dtype=torch.float)

    edge_features = [G.edges[edge_key]['attr'] for edge_key in G.edges]
    pyg_graph.edge_attr = torch.tensor(edge_features, dtype=torch.float)

    pyg_graph.node_types = [attrs['type'] for _, attrs in G.nodes(data=True)]
    return pyg_graph

# Process all graphs in parallel and save the result
pyg_graphs = []
unique_cpm_ids = data['CPM_ID'].unique()

with ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(
            process_graph_to_pyg,
            cpm_id,
            data[data['CPM_ID'] == cpm_id],
            chp_properties_data,
            chp_encoder_data
        ) for cpm_id in unique_cpm_ids
    ]
    # Collect results in submission order rather than completion order, so the
    # saved list has the same order on every run (it follows the first
    # appearance of each CPM_ID in the input). Anything that indexes the list
    # by position then stays reproducible.
    for future in tqdm(futures, total=len(futures), desc="Processing graphs"):
        result = future.result()
        # process_graph_to_pyg returns None for formulas it could not encode.
        if result is not None:
            pyg_graphs.append(result)

skipped_count = len(unique_cpm_ids) - len(pyg_graphs)
if skipped_count:
    logging.warning(
        "%d of %d formulas could not be converted and were skipped",
        skipped_count, len(unique_cpm_ids),
    )

# Save the PyTorch Geometric graph list to file
torch.save(pyg_graphs, output_file)
print(f"Successfully saved {len(pyg_graphs)} graphs to '{output_file}'")

# Examine Graph Data

In [None]:
import os
import torch
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib import font_manager

# Configure fonts that can render node names containing Chinese characters.
# matplotlib uses the first family in this list that is installed, so the
# same cell works on Windows, macOS and Linux.
plt.rcParams['font.sans-serif'] = [
    'Microsoft YaHei',     # Windows
    'Arial Unicode MS',    # macOS
    'Noto Sans CJK SC',    # common on Linux
    'DejaVu Sans',         # bundled with matplotlib (final fallback)
]

# Set working directory and define input path
work_dir = os.getcwd()  # Use current working directory
input_data_dir = os.path.join(work_dir, '../Data')  # Input data path

# Load saved PyG graphs
output_file = os.path.join(input_data_dir, 'all_graphs_to_be_predicted.pt')
# NOTE: weights_only=False unpickles arbitrary Python objects; only load files
# that you produced yourself or otherwise trust.
loaded_pyg_graphs = torch.load(output_file, weights_only=False)

# Display number of graphs
num_graphs = len(loaded_pyg_graphs)
print(f"图的数量 (Number of graphs): {num_graphs}")

# Display number of label columns for the first graph (if exists)
if num_graphs > 0:
    first_graph = loaded_pyg_graphs[0]
    first_labels = getattr(first_graph, 'y', None)
    num_label_columns = first_labels.size(0) if first_labels is not None else 0
    print(f"标签列的数量 (Number of label columns): {num_label_columns}")

# Select and inspect a specific graph (by index)
graph_index = 1  # Change this index to inspect a different graph
if num_graphs == 0:
    raise IndexError(f"No graphs were loaded from '{output_file}'; nothing to inspect.")
if not -num_graphs <= graph_index < num_graphs:
    # E.g. the default index 1 when only a single graph was saved.
    print(f"graph_index {graph_index} is out of range for {num_graphs} graph(s); "
          f"showing the last graph instead.")
    graph_index = num_graphs - 1
selected_graph = loaded_pyg_graphs[graph_index]
# Alias kept for any later code that still refers to the inspected graph by
# its previous name.
first_graph = selected_graph
print(f"\n第 {graph_index + 1} 张图的详细信息 (Details of Graph {graph_index + 1}):")

# Display node features (x)
print(f"节点特征 (Node Features):\n{selected_graph.x}")

# Display edge index (edge connections)
print(f"边索引 (Edge Index):\n{selected_graph.edge_index}")

# Display edge attributes if available. A missing attribute is checked with
# "is not None" because the attribute may exist on the object yet be None.
if getattr(selected_graph, 'edge_attr', None) is not None:
    print(f"边属性 (Edge Attributes):\n{selected_graph.edge_attr}")
else:
    print("没有边属性 (No Edge Attributes)")

# Display label vector (if exists)
if getattr(selected_graph, 'y', None) is not None:
    print(f"标签 (Labels):\n{selected_graph.y}")
else:
    print("没有标签 (No Labels)")

# Display associated CPM_ID (if exists)
if hasattr(selected_graph, 'cpm_id'):
    print(f"CPM_ID: {selected_graph.cpm_id}")
else:
    print("没有 CPM_ID")

# Display node names (for labeling in the graph)
if hasattr(selected_graph, 'node_names'):
    print(f"节点名称 (Node Names):\n{selected_graph.node_names}")
    node_names = selected_graph.node_names
else:
    print("没有节点名称 (No Node Names)")
    node_names = None

# Display node types (actual vs virtual)
if hasattr(selected_graph, 'node_types'):
    print(f"节点类型 (Node Types):\n{selected_graph.node_types}")
else:
    print("没有节点类型 (No Node Types)")

# Construct a NetworkX graph for visualization
G = nx.Graph()

# Register every node first so that nodes without edges are still drawn and
# every label has a position; then add the edges using plain Python ints.
G.add_nodes_from(range(selected_graph.num_nodes))
edge_index = selected_graph.edge_index.numpy()
G.add_edges_from(zip(edge_index[0].tolist(), edge_index[1].tolist()))

# If node names are available, use them for labels
if node_names is not None:
    labels = dict(enumerate(node_names))
else:
    labels = None

# Visualize the graph
plt.figure(figsize=(10, 10))
pos = nx.spring_layout(G, seed=42)  # Use a fixed layout for reproducibility
nx.draw(
    G, pos,
    # Draw the numeric node indices only when there are no names; otherwise
    # the two sets of labels would be printed on top of each other.
    with_labels=labels is None,
    node_size=500, node_color='skyblue',
    font_size=10, font_weight='bold', font_color='black'
)

# If node names exist, display them using the font list configured above
if labels is not None:
    nx.draw_networkx_labels(
        G, pos, labels=labels, font_size=12,
        font_family='sans-serif', font_weight='bold'
    )

plt.title(f"Graph {graph_index + 1} Visualization")
plt.show()
