In [161]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import Data
import numpy as np
import networkx as nx
from torch_geometric.nn import GATConv
import torch.optim as optim
import torch.nn as nn 

In [162]:
class GATEdgePredictor(nn.Module):
    def __init__(self, in_node_features, in_edge_features, hidden_dim):
        super(GATEdgePredictor, self).__init__()

        self.gat1 = GATConv(in_node_features, hidden_dim, heads=1, concat=True)
        self.gat2 = GATConv(hidden_dim, hidden_dim, heads=1, concat=True)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + in_edge_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index, edge_attr):
        # Node embeddings
        x = F.relu(self.gat1(x, edge_index))
        x = F.relu(self.gat2(x, edge_index))

        src, dst = edge_index

        # Concatenate node embeddings and edge features
        edge_features = torch.cat([x[src], x[dst], edge_attr], dim=1)

        # Predict edge probabilities
        edge_probs = torch.sigmoid(self.edge_mlp(edge_features)).squeeze()

        return edge_probs

In [267]:
# Set random seed for reproducibility
#torch.manual_seed(42)
#np.random.seed(42)

# Generate a random graph
num_nodes = 100
G = nx.erdos_renyi_graph(n=num_nodes, p=0.05, directed=False)
print(G)
# Assign random node features (each node has 3 features)
num_node_features = 3
X = torch.randn((num_nodes, num_node_features))

# Assign random edge features (each edge has 2 features)
num_edge_features = 2
edge_index = torch.tensor(list(G.edges), dtype=torch.long).t()
E = torch.randn((edge_index.shape[1], num_edge_features))

# Fixed coefficients for true edge probabilities
W_node = torch.tensor([0.5, -0.3, 0.2])  # Weights for node features
W_edge = torch.tensor([0.7, -0.5])        # Weights for edge features
bias = 0.01

# Compute true edge probabilities using a linear combination
X_src = X[edge_index[0]]  # Source node features
X_dst = X[edge_index[1]]  # Destination node features

# Linear combination of node and edge features
#W_true = torch.sigmoid(
#    (X_src * X_dst)**2 @ W_node + torch.relu(E @ W_edge)**2 + bias
#)

import torch.nn.functional as F
W_neighbor = torch.randn(num_node_features, 1)

def compute_w_true(X, edge_index, E, W_node, W_edge, W_neighbor, bias):
    src, dst = edge_index

    # Compute node interaction term x_u^T W x_v
    node_interaction = (X[src] @ W_node) * (X[dst] @ W_node)  # (num_edges, hidden_dim)

    # Compute edge contribution term e_{uv}^T W_e
    edge_contribution = E @ W_edge  # (num_edges, hidden_dim)

    # Initialize neighbor sum storage
    neighbor_sum = torch.zeros_like(X)  # (num_nodes, num_node_features)

    # Compute neighbor sum manually
    for i in range(X.shape[0]):  # Iterate over nodes
        neighbors1 = edge_index[1][edge_index[0] == i]  # Find neighbors of node i
        neighbors2 = edge_index[0][edge_index[1] == i]
        if len(neighbors1)+len(neighbors2) > 0:
            neighbor_sum[i] = X[neighbors1].sum(dim=0)+X[neighbors2].sum(dim=0)  # Sum over neighbors

    # Compute neighborhood effect for each edge
    neighbor_effect = neighbor_sum[src] + neighbor_sum[dst]  # (num_edges, num_node_features)

    # Apply learnable transformation
    neighbor_effect = neighbor_effect @ W_neighbor
    
    # Compute final edge probabilities
    #node_interaction + edge_contribution + neighbor_effect + bias
    W_true = torch.sigmoid(neighbor_effect)
    return W_true


W_true = compute_w_true(X, edge_index, E, W_node, W_edge, W_neighbor, bias)

# Ensure W_true is between 0 and 1
W_true = W_true.squeeze()

# Create PyTorch Geometric Data object
graph_data = Data(x=X, edge_index=edge_index, edge_attr=E, y=W_true)

Graph with 100 nodes and 249 edges


In [268]:
# Define hyperparameters
num_node_features = X.shape[1]
num_edge_features = E.shape[1]
hidden_dim = 16
epochs = 100
lr = 0.01

# Initialize model, loss function, and optimizer
model = GATEdgePredictor(num_node_features, num_edge_features, hidden_dim)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.MSELoss()

# Training loop
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()

    # Forward pass
    W_pred = model(X, edge_index, E)

    # Compute loss
    loss = loss_fn(W_pred, W_true)
    
    # Backward pass
    loss.backward()
    optimizer.step()

    # Print progress
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

print("Training complete!")

Epoch 0, Loss: 0.1779
Epoch 10, Loss: 0.1471
Epoch 20, Loss: 0.1320
Epoch 30, Loss: 0.1136
Epoch 40, Loss: 0.0931
Epoch 50, Loss: 0.0771
Epoch 60, Loss: 0.0646
Epoch 70, Loss: 0.0549
Epoch 80, Loss: 0.0476
Epoch 90, Loss: 0.0414
Training complete!


In [269]:
from sklearn.linear_model import LinearRegression

# Prepare inputs (sum of node features + edge features)
X_train = torch.cat([X[edge_index[0]] + X[edge_index[1]], E], dim=1).detach().numpy()
y_train = W_true.detach().numpy()

# Train a linear model
lin_reg = LinearRegression().fit(X_train, y_train)

# Compute MSE
y_pred = lin_reg.predict(X_train)
mse_baseline = np.mean((y_pred - y_train) ** 2)
print(f"Linear Regression Baseline MSE: {mse_baseline:.4f}")


Linear Regression Baseline MSE: 0.1727


In [259]:
X_train[0]

array([ 0.4233604 ,  0.49810147,  2.888595  , -0.14879291, -0.7156355 ],
      dtype=float32)

In [173]:
E

tensor([[ 1.0723e+00,  1.2485e+00],
        [ 4.4386e-01, -2.2930e-01],
        [ 7.1540e-01, -4.8257e-02],
        [ 6.2648e-02,  1.2718e+00],
        [-2.4801e-01,  3.8076e-01],
        [ 7.1859e-01,  1.4812e+00],
        [-1.2312e+00,  1.0442e-03],
        [-1.0367e+00, -1.2849e+00],
        [-3.2516e-01, -2.5357e-02],
        [-2.2495e-01, -1.1798e+00],
        [ 2.3872e-01,  8.3738e-01],
        [ 9.2240e-01,  1.6296e+00],
        [ 1.4948e+00,  9.5620e-01],
        [ 5.6154e-01,  5.6149e-01],
        [-1.2400e+00,  6.8579e-01],
        [ 2.2173e+00, -3.3435e-01],
        [-1.3294e+00, -3.9427e-02],
        [ 4.6671e-01, -8.4786e-01],
        [-1.3032e+00, -2.1927e+00],
        [-2.0101e-01, -5.1045e-01],
        [ 5.1959e-01,  5.6179e-01],
        [ 3.7321e-01, -2.8861e-01],
        [ 8.8765e-01, -5.8171e-01],
        [ 9.7880e-02, -1.3898e-01],
        [ 5.4765e-01, -1.0651e+00],
        [ 1.0277e+00, -5.3899e-01],
        [-1.8798e+00,  6.6874e-01],
        [-2.5868e-01, -7.879