## Imports

In [None]:
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx

## Graph Generation Functions

In [None]:
def find_connected_subgraph(G, size=4):
    """Pick `size` nodes of G that induce a connected subgraph.

    Connected components are scanned in the order NetworkX yields them. In the
    first component holding at least `size` nodes, a start node is drawn with
    NumPy's global RNG and the first `size` nodes in BFS order from that start
    are returned.

    Args:
        G: NetworkX graph accepted by `nx.connected_components` (undirected).
        size: Number of nodes to return.

    Returns:
        A list of `size` node labels, or None when no component is large enough.
    """
    for component in nx.connected_components(G):
        # Skip undersized components before building a subgraph view for them.
        if len(component) < size:
            continue
        subgraph = G.subgraph(component)
        candidates = list(subgraph.nodes())
        # Draw a position rather than a label: np.random.choice applied to the
        # labels returns a NumPy scalar and rejects non-scalar (tuple) labels.
        start_node = candidates[np.random.choice(len(candidates))]
        # A prefix of the BFS order is connected because every node's BFS
        # parent appears before it.
        return list(nx.bfs_tree(subgraph, start_node))[:size]
    return None

def generate_graph(num_nodes=100, edge_prob=0.05):
    """Sample Erdos-Renyi graphs until one has a connected component of 4+ nodes.

    Args:
        num_nodes: Number of nodes in each sampled graph.
        edge_prob: Independent probability of each edge.

    Returns:
        A tuple `(G, connected_nodes)` with the accepted graph and the four
        connected nodes chosen by `find_connected_subgraph`.

    Raises:
        ValueError: If the parameters make a 4-node connected component
            impossible (fewer than 4 nodes, or a non-positive edge
            probability), which would otherwise loop forever.
    """
    required_size = 4
    if num_nodes < required_size:
        raise ValueError(
            f"num_nodes must be at least {required_size}, got {num_nodes}"
        )
    if edge_prob <= 0:
        raise ValueError(f"edge_prob must be positive, got {edge_prob}")

    # Rejection sampling: a sparse draw may have no component of the needed size.
    while True:
        G = nx.erdos_renyi_graph(n=num_nodes, p=edge_prob)
        connected_nodes = find_connected_subgraph(G, size=required_size)
        if connected_nodes is not None:
            return G, connected_nodes

## Feature Computation Functions

In [None]:
def compute_features(G, nodes):
    """Build the flat feature vector for a graph and four of its nodes.

    Layout of the result (46 values):
      * 9 values for each of the four nodes, in the order given: degree,
        clustering coefficient, mean neighbor degree, betweenness, closeness,
        PageRank, eigenvector centrality, core number, and the local
        efficiency of the whole graph (the same value for every node);
      * the graph-wide means of the first eight metrics followed by the
        graph's local efficiency;
      * the graph density.

    Args:
        G: NetworkX graph.
        nodes: Sequence of exactly four node labels present in G.

    Returns:
        A 1-D float32 tensor of length 46 (all zeros when G has no nodes).

    Raises:
        ValueError: If `nodes` is None or does not contain exactly 4 entries.
        KeyError: If one of `nodes` is not a node of G.
    """
    nodes_per_sample = 4
    metrics_per_node = 9

    if nodes is None or len(nodes) != nodes_per_sample:
        raise ValueError("Must provide exactly 4 nodes for feature computation")

    num_nodes = G.number_of_nodes()
    if num_nodes == 0:
        # Keep the same length as the regular output so callers can rely on it.
        total = nodes_per_sample * metrics_per_node + metrics_per_node + 1
        return torch.zeros(total, dtype=torch.float32)

    # Every metric below is a whole-graph computation; run each one once and
    # reuse the per-node dictionaries for both the selected nodes and the means.
    degree = dict(G.degree())
    clustering = nx.clustering(G)
    avg_neighbor_degree = {}
    for node in G.nodes():
        neighbor_degrees = [degree[nbr] for nbr in G.neighbors(node)]
        # Isolated nodes have no neighbors to average over.
        avg_neighbor_degree[node] = np.mean(neighbor_degrees) if neighbor_degrees else 0

    betweenness = nx.betweenness_centrality(G)
    closeness = nx.closeness_centrality(G)
    pagerank = nx.pagerank(G)

    # Fall back to zeros when the eigenvector solver rejects the graph.
    try:
        eigenvector = nx.eigenvector_centrality_numpy(G)
    except (nx.NetworkXError, nx.AmbiguousSolution):
        eigenvector = dict.fromkeys(G.nodes(), 0)

    core_number = nx.core_number(G)
    local_efficiency = nx.local_efficiency(G)

    per_node_metrics = [
        degree,
        clustering,
        avg_neighbor_degree,
        betweenness,
        closeness,
        pagerank,
        eigenvector,
        core_number,
    ]

    features = []

    # Node-specific features for the provided nodes
    for node in nodes:
        features.extend(metric[node] for metric in per_node_metrics)
        features.append(local_efficiency)

    # Global node-level features (averaged over all nodes)
    features.extend(np.mean(list(metric.values())) for metric in per_node_metrics)
    features.append(local_efficiency)

    # Global features
    features.append(nx.density(G))

    return torch.tensor(features, dtype=torch.float32)

def prepare_node_features(G):
    """Build the per-node input feature matrix, including a removal flag.

    Row `i` describes the i-th node in `G.nodes()` order with the columns:
    degree, clustering coefficient, mean neighbor degree, betweenness
    centrality, closeness centrality, removal flag. The flag is left at 0 here
    and is meant to be set by the caller.

    Args:
        G: NetworkX graph.

    Returns:
        A float tensor of shape `[num_nodes, 6]`.
    """
    num_nodes = G.number_of_nodes()
    # Basic features for each node (5 base features + 1 removal flag)
    features = torch.zeros(num_nodes, 6)
    if num_nodes == 0:
        return features

    # Whole-graph computations: run once instead of once per node.
    degree = dict(G.degree())
    clustering = nx.clustering(G)
    betweenness = nx.betweenness_centrality(G)
    closeness = nx.closeness_centrality(G)

    # Address every metric by node label so the rows stay consistent even when
    # the labels are not exactly 0..n-1.
    for row, node in enumerate(G.nodes()):
        neighbor_degrees = [degree[nbr] for nbr in G.neighbors(node)]
        features[row] = torch.tensor(
            [
                degree[node],
                clustering[node],
                np.mean(neighbor_degrees) if neighbor_degrees else 0,
                betweenness[node],
                closeness[node],
                0,  # Removal flag, will be set later
            ],
            dtype=torch.float32,
        )
    return features

def prepare_edge_index(G):
    """Convert NetworkX graph edges to a PyG edge index.

    Node labels are translated to their position in `G.nodes()` order (the
    identity for labels 0..n-1). For an undirected graph every edge is emitted
    in both directions, since `G.edges()` lists each undirected edge only once
    while message passing over `edge_index` is directional.

    Args:
        G: NetworkX graph.

    Returns:
        A contiguous long tensor of shape `[2, num_edges]`; `[2, 0]` when the
        graph has no edges.
    """
    position = {node: idx for idx, node in enumerate(G.nodes())}
    pairs = [(position[u], position[v]) for u, v in G.edges()]

    if not G.is_directed():
        # Self-loops are already symmetric; do not duplicate them.
        pairs += [(v, u) for u, v in pairs if u != v]

    if not pairs:
        # torch.tensor([]) would give a 1-D float tensor instead of [2, 0].
        return torch.zeros((2, 0), dtype=torch.long)
    return torch.tensor(pairs, dtype=torch.long).t().contiguous()

## Data Processing Functions

In [None]:
def process_graph_data(G, nodes_to_remove):
    """Turn a graph and the nodes to remove into one PyG training sample.

    Features are computed twice: on G with `nodes_to_remove` as the selected
    nodes, and on the largest connected component of the graph that remains
    after removing them (using the first four nodes of that component).

    Args:
        G: NetworkX graph.
        nodes_to_remove: The four node labels to remove from G.

    Returns:
        A `Data` object with `x`, `edge_index`, `removed_nodes`,
        `original_features` and `residual_features`.

    Raises:
        ValueError: If `nodes_to_remove` does not hold exactly 4 nodes, or if
            the residual graph has no connected component with at least 4
            nodes.
    """
    # Original graph features
    original_features = compute_features(G, nodes_to_remove)

    # Create residual graph
    residual_G = G.copy()
    residual_G.remove_nodes_from(nodes_to_remove)

    # Get features for the largest component of the residual graph; fail with
    # an explicit message when nothing (or too little) is left.
    largest_component = max(
        nx.connected_components(residual_G), key=len, default=None
    )
    if largest_component is None or len(largest_component) < 4:
        raise ValueError(
            "Residual graph has no connected component with at least 4 nodes"
        )
    residual_G_main = residual_G.subgraph(largest_component)
    residual_nodes = list(residual_G_main.nodes())[:4]
    residual_features = compute_features(residual_G_main, residual_nodes)

    # Create PyG data object
    return Data(
        x=prepare_node_features(G),
        edge_index=prepare_edge_index(G),
        # Stored as long so the tensor can be used directly for indexing.
        removed_nodes=torch.tensor(nodes_to_remove, dtype=torch.long),
        original_features=original_features,
        residual_features=residual_features,
    )

## Model Definition

In [None]:
class EnhancedGNNModel(nn.Module):
    """Three-layer GCN that predicts the flat graph feature vector.

    The output concatenates 9 predicted values for each node listed in
    `data.removed_nodes` with 10 predicted graph-level values, giving
    `[1, 46]` for four removed nodes in a single graph.

    Args:
        node_feature_dim: Number of input features per node.
        hidden_dim: Width of the GCN layers and of the MLP hidden layers.
    """

    def __init__(self, node_feature_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(node_feature_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Separate MLPs for different feature types
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 9)  # 9 features per node
        )

        self.global_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10)  # Global graph features
        )

        self.dropout = nn.Dropout(0.2)

    def forward(self, data):
        """Run the network on one graph.

        Args:
            data: Object exposing `x`, `edge_index`, `removed_nodes` and,
                optionally, `batch`.

        Returns:
            Tensor of shape `[1, 9 * num_removed + 10]` for a single graph.
        """
        x, edge_index = data.x, data.edge_index

        # `batch` may be missing or present-but-None; in both cases treat all
        # nodes as one graph and build the vector on the same device as x.
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Graph convolutions
        x = self.conv1(x, edge_index).relu()
        x = self.dropout(x)
        x = self.conv2(x, edge_index).relu()
        x = self.dropout(x)
        x = self.conv3(x, edge_index).relu()

        # Process removed nodes in one indexed call instead of a Python loop
        removed_index = torch.as_tensor(
            data.removed_nodes, dtype=torch.long, device=x.device
        )
        removed_features = self.node_mlp(x[removed_index])  # Shape: [4, 9]
        removed_features = removed_features.reshape(1, -1)  # Shape: [1, 36]

        # Global graph features
        global_x = global_mean_pool(x, batch)  # Shape: [1, hidden_dim]
        global_features = self.global_mlp(global_x)  # Shape: [1, 10]

        return torch.cat([removed_features, global_features], dim=1)  # Shape: [1, 46]

## Training Functions

## Execution

In [None]:
# Sample one random graph together with the four connected nodes to remove,
# then turn it into a PyG sample
G, selected_nodes = generate_graph(edge_prob=0.05, num_nodes=100)
data = process_graph_data(G, selected_nodes)

# Mark the selected nodes in the last feature column (the removal flag)
data.x[selected_nodes, -1] = 1

# Build the model and fit it against the original-graph feature vector
# NOTE(review): train_model is not defined in the visible cells - confirm it
# is provided by the "Training Functions" section before running this cell.
model = EnhancedGNNModel(hidden_dim=64, node_feature_dim=6)
target = data.original_features.unsqueeze(0)
train_model(model, data, target)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 1 for tensor number 1 in the list.