In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import add_self_loops

# Load the MUTAG dataset (Molecular graphs)
dataset = TUDataset(root="data/TUDataset", name="MUTAG")

# Split into train and test sets (MUTAG is small, so we use 80% train, 20% test)
train_size = int(0.8 * len(dataset))
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

class EdgePredictor(MessagePassing):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__(aggr='add')  # Message passing with sum aggregation
        self.node_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_channels + 1, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, edge_index, edge_attr):
        # Ensure dtype consistency
        x = x.float()
        edge_attr = edge_attr.float()

        # Apply node transformation
        x = self.node_mlp(x)

        # Message Passing
        edge_index, _ = add_self_loops(edge_index)  # Ignore edge_attr modification
        if edge_attr is not None:
            edge_attr = torch.cat([edge_attr, torch.zeros((edge_index.shape[1] - edge_attr.shape[0], 1), device=edge_attr.device)], dim=0)

        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        if edge_attr is None:
            edge_attr = torch.zeros((x_i.shape[0], 1), device=x_i.device)
        else:
            edge_attr = edge_attr.view(-1, 1)  # Ensure correct shape

        # Debugging prints
        print(f"x_i: {x_i.shape}, x_j: {x_j.shape}, edge_attr: {edge_attr.shape}")

        edge_features = torch.cat([x_i, x_j, edge_attr], dim=1)
        return self.edge_mlp(edge_features)


    def update(self, aggr_out):
        return aggr_out  # Directly use aggregated messages

# Define Model
class GNN(torch.nn.Module):
    def __init__(self, node_features, edge_features, hidden_dim):
        super().__init__()
        self.conv1 = EdgePredictor(node_features, hidden_dim, hidden_dim)
        self.conv2 = EdgePredictor(hidden_dim, hidden_dim, 1)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        edge_pred = self.conv1(x, edge_index, edge_attr)
        edge_pred = self.conv2(x, edge_index, edge_pred)
        return edge_pred

# Training function
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)

        # Ensure correct dtype
        data.x = data.x.float()
        data.edge_attr = data.edge_attr.float()

        optimizer.zero_grad()
        edge_pred = model(data).squeeze()
        loss = criterion(edge_pred, data.edge_attr)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Evaluation function
def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            # Ensure correct dtype
            data.x = data.x.float()
            data.edge_attr = data.edge_attr.float()

            edge_pred = model(data).squeeze()
            loss = criterion(edge_pred, data.edge_attr)
            total_loss += loss.item()
    return total_loss / len(loader)

# Initialize training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNN(node_features=dataset.num_node_features, edge_features=1, hidden_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

# Train and evaluate
for epoch in range(5):
    train_loss = train(model, train_loader, optimizer, criterion)
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

# Final test evaluation
final_test_loss = evaluate(model, test_loader, criterion)
print(f"Final Test Loss: {final_test_loss:.4f}")


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 4 but got size 1 for tensor number 1 in the list.