## 1. Load Data: Build a Synthetic Graph and Train an Edge Predictor

In [None]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data

class EdgePredictor(MessagePassing):
    """Predict an attribute vector for every edge of a graph.

    Node features and edge features are first projected by linear layers
    into a shared ``edge_features``-dimensional space. For each edge
    ``j -> i`` the projected edge attribute is added to both endpoint
    embeddings, the two sums are concatenated, and a final linear layer maps
    the ``2 * edge_features`` vector to the predicted edge attribute.

    Args:
        node_features: Size of each input node feature vector.
        edge_features: Size of each input edge feature vector; also the size
            of the hidden projections and of each predicted edge attribute.
    """

    def __init__(self, node_features, edge_features):
        # aggr='add' is kept so the MessagePassing base is configured as
        # before; forward() does not aggregate, because the prediction target
        # is per-edge rather than per-node.
        super().__init__(aggr='add')
        self.node_transform = torch.nn.Linear(node_features, edge_features)
        self.edge_transform = torch.nn.Linear(edge_features, edge_features)
        self.edge_predict = torch.nn.Linear(2 * edge_features, edge_features)

    def forward(self, x, edge_index, edge_attr):
        """Return one predicted attribute vector per edge.

        Args:
            x: Node features, shape ``[N, node_features]``.
            edge_index: Edge list, shape ``[2, E]``. Row 0 holds source nodes
                ``j`` and row 1 holds target nodes ``i`` (PyG's default
                ``source_to_target`` convention).
            edge_attr: Edge features, shape ``[E, edge_features]``.

        Returns:
            Tensor of shape ``[E, edge_features]``. Row ``k`` is the
            prediction for edge ``edge_index[:, k]``.
        """
        # Step 1: Transform node and edge features into a shared space.
        x = self.node_transform(x)
        edge_attr = self.edge_transform(edge_attr)

        # Step 2: Gather endpoint embeddings per edge. self.propagate() is not
        # used here: it would sum messages onto target nodes and return
        # [N, edge_features], discarding the per-edge predictions.
        src, dst = edge_index
        return self.message(x[dst], x[src], edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Compute the edge prediction from endpoint and edge embeddings.

        Args:
            x_i: Target-node embeddings, shape ``[E, edge_features]``.
            x_j: Source-node embeddings, shape ``[E, edge_features]``.
            edge_attr: Edge embeddings, shape ``[E, edge_features]``.

        Returns:
            Tensor of shape ``[E, edge_features]``.
        """
        # Step 3: Add the edge embedding to each endpoint, then concatenate.
        # A plain cat([x_i, x_j]) is [E, 2*edge_features] and cannot be added
        # to the [E, edge_features] edge_attr directly.
        edge_message = torch.cat([x_i + edge_attr, x_j + edge_attr], dim=1)

        # Step 4: Edge attribute prediction.
        return self.edge_predict(edge_message)

# Create a synthetic dataset: 10 nodes, of which only nodes 0-2 are connected
# by 4 directed edges (0->1, 1->0, 1->2, 2->1); nodes 3-9 are isolated.
# Seeding makes the random features, initial weights and losses reproducible.
torch.manual_seed(0)

num_nodes = 10
node_features = 5
edge_features = 5
x = torch.randn(num_nodes, node_features)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.randn((edge_index.shape[1], edge_features), dtype=torch.float32)

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Create the model and optimizer
model = EdgePredictor(node_features, edge_features)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Synthetic regression targets: one edge_features-sized vector per edge.
edge_attr_target = torch.randn((data.edge_index.shape[1], edge_features), dtype=torch.float32)

num_epochs = 10
for epoch in range(num_epochs):
    optimizer.zero_grad()
    edge_attr_pred = model(data.x, data.edge_index, data.edge_attr)
    # Mean squared error; mse_loss also warns if the prediction and target
    # shapes differ instead of silently broadcasting them.
    loss = torch.nn.functional.mse_loss(edge_attr_pred, edge_attr_target)
    loss.backward()
    optimizer.step()
    print('Epoch:', epoch+1, ', Loss:', loss.item())


: 