In [1]:
import os
import warnings

import networkx as nx
import torch
from torch_geometric.utils import from_networkx

In [None]:
def load_gt_for_pytorch(gt_file_path, node_features=None, edge_features=None, target=None, verbose=True):
    """
    Load a graph-tool ``.gt`` file and convert it to PyTorch Geometric format.

    NOTE(review): this function relies on two names that no visible cell of
    this notebook imports or defines: ``gt`` (presumably
    ``import graph_tool.all as gt``) and ``graphtool_to_pytorch_geometric``.
    On a fresh kernel it raises ``NameError`` until both are brought into scope.

    Parameters
    ----------
    gt_file_path : str
        Path to the ``.gt`` file.
    node_features : list of str, optional
        Vertex property names to use as node features.
    edge_features : list of str, optional
        Edge property names to use as edge features.
    target : str, optional
        Vertex property name holding the target/labels.
    verbose : bool, default True
        When True, print the vertex and edge property names stored in the
        file (handy for choosing ``node_features`` / ``target``). Pass False
        to suppress this diagnostic output.

    Returns
    -------
    torch_geometric.data.Data
        Whatever ``graphtool_to_pytorch_geometric`` returns for the loaded graph.
    """
    g_gt = gt.load_graph(gt_file_path)

    # Diagnostic listing of the property maps available in the file.
    if verbose:
        print(f"Available vertex properties: {list(g_gt.vertex_properties.keys())}")
        print(f"Available edge properties: {list(g_gt.edge_properties.keys())}")

    return graphtool_to_pytorch_geometric(
        g_gt,
        node_feature_props=node_features,
        edge_feature_props=edge_features,
        target_prop=target,
    )

In [2]:
def torching_data(data, data_label, features, target, out_dir="data"):
    """
    Stack named node attributes into ``x``/``y`` tensors and save the result.

    Works on a clone, so ``data`` itself is left untouched. Each name in
    ``features`` is read as ``data[name]``, turned into a column with
    ``unsqueeze(1)``, and the columns are concatenated, in the given order,
    into ``x``. If ``target`` is present it becomes ``y`` as a single
    column; otherwise ``y`` is left unset and a warning is emitted.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Graph data (e.g. from ``from_networkx``) holding the attributes named
        in ``features`` / ``target``. Presumably each attribute is a 1-D
        tensor of length num_nodes -- TODO confirm.
    data_label : str
        Suffix for the output file name ``data_<data_label>.pt``.
    features : list of str
        Attribute names stacked (in this order) into ``x``.
    target : str
        Attribute name used as ``y``.
    out_dir : str, default "data"
        Directory the ``.pt`` file is written to; created if it does not exist.

    Returns
    -------
    torch_geometric.data.Data
        The clone with ``x`` (when ``features`` is non-empty) and ``y``
        (when ``target`` is present) set.

    Raises
    ------
    KeyError
        Presumably, when a name in ``features`` is missing from ``data``.
    """
    data_clone = data.clone()

    # Read the caller's `features` argument. The previous version discarded it
    # and looped over the notebook-global `feature_keys`, which only worked
    # because a later cell had already defined that global.
    feature_columns = [data_clone[key].unsqueeze(1) for key in features]
    if feature_columns:
        data_clone.x = torch.cat(feature_columns, dim=1)  # shape: [num_nodes, len(features)]

    if target in data_clone:
        # torch.cat copies, so `y` does not share storage with data_clone[target].
        data_clone.y = torch.cat([data_clone[target].unsqueeze(1)], dim=1)  # shape: [num_nodes, 1]
    else:
        warnings.warn(
            f"target {target!r} not found in data; 'y' left unset",
            stacklevel=2,
        )

    # torch.save does not create parent directories itself.
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data_clone, os.path.join(out_dir, f"data_{data_label}.pt"))

    return data_clone

In [3]:
# Directory holding the three GraphML variants of the WikiData network, one per
# scaling scheme named in the file (minmax / robust / standard). Defined once so
# the location can be changed in a single place.
WIKI_GRAPH_DIR = "../Test_WikiDataNet/data"

G_wiki_minmax = nx.read_graphml(f"{WIKI_GRAPH_DIR}/G_wiki_minmax.graphml")
G_wiki_robust = nx.read_graphml(f"{WIKI_GRAPH_DIR}/G_wiki_robust.graphml")
G_wiki_standard = nx.read_graphml(f"{WIKI_GRAPH_DIR}/G_wiki_standard.graphml")

In [4]:
# Convert each NetworkX graph to a PyG Data object. The node attributes carried
# over here are the fields `torching_data` later reads by name.
data_minmax = from_networkx(G_wiki_minmax)
data_robust = from_networkx(G_wiki_robust)
data_standard = from_networkx(G_wiki_standard)

In [6]:
# Node attribute names passed to `torching_data`, which stacks them, in this
# order, into the feature matrix `x`. Groupings below are by attribute name.
feature_keys = [
    # page-level counts and size
    "num_categories",
    "num_links",
    "page_length",
    "num_references",
    "num_sections",
    "num_templates",
    # encoded categorical attributes
    "has_infobox_encoded",
    "protection_status_encoded",
    # presumably UMAP embedding components (per the `_umap_N` suffix)
    "assessment_source_umap_1",
    "assessment_source_umap_2",
    "categories_umap_1",
    "categories_umap_2",
    "categories_umap_3",
    "templates_umap_1",
    "templates_umap_2",
    "templates_umap_3",
    # graph-structure metrics
    "degree_centrality",
    "pagerank",
    "reciprocity",
    "hub",
    "authority",
    "eigen",
]

In [7]:
# Regression datasets: target attribute "QC_num_log", one saved .pt file per
# scaling variant.
data_standard_reg = torching_data(
    data=data_standard,
    data_label="standard_reg",
    features=feature_keys,
    target="QC_num_log",
)
data_robust_reg = torching_data(
    data=data_robust,
    data_label="robust_reg",
    features=feature_keys,
    target="QC_num_log",
)
data_minmax_reg = torching_data(
    data=data_minmax,
    data_label="minmax_reg",
    features=feature_keys,
    target="QC_num_log",
)

In [8]:
# Datasets with the "QC_cat" target (label suffix "_cat"), one per scaling variant.
data_standard_cat = torching_data(
    data=data_standard,
    data_label="standard_cat",
    features=feature_keys,
    target="QC_cat",
)
data_robust_cat = torching_data(
    data=data_robust,
    data_label="robust_cat",
    features=feature_keys,
    target="QC_cat",
)
data_minmax_cat = torching_data(
    data=data_minmax,
    data_label="minmax_cat",
    features=feature_keys,
    target="QC_cat",
)

In [9]:
# Datasets with the "QC_aggcat" target (label suffix "_catagg"), one per
# scaling variant. NOTE(review): presumably a coarser/aggregated version of
# QC_cat -- confirm against the upstream data preparation.
data_standard_catagg = torching_data(
    data=data_standard,
    data_label="standard_catagg",
    features=feature_keys,
    target="QC_aggcat",
)
data_robust_catagg = torching_data(
    data=data_robust,
    data_label="robust_catagg",
    features=feature_keys,
    target="QC_aggcat",
)
data_minmax_catagg = torching_data(
    data=data_minmax,
    data_label="minmax_catagg",
    features=feature_keys,
    target="QC_aggcat",
)