In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class MGCN(torch.nn.Module):
    def __init__(self, in_channels_node, in_channels_edge, hidden_channels, out_channels):
        super(MGCN, self).__init__()
        # GCN for node features
        self.gcn_node = GCNConv(in_channels_node, hidden_channels)
        # GCN for edge features
        self.gcn_edge = GCNConv(in_channels_edge, hidden_channels)
        # Fully connected layer for combining modalities
        self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)

    def forward(self, node_x, edge_index, edge_attr):
        # Apply GCN for node features
        node_out = self.gcn_node(node_x, edge_index)
        node_out = F.relu(node_out)

        # Apply GCN for edge features (reshape to match node dimensions)
        edge_out = self.gcn_edge(edge_attr, edge_index)
        edge_out = F.relu(edge_out)

        # Aggregate edge features per node (e.g., averaging edge features per node)
        node_count = node_out.size(0)
        edge_aggregated = torch.zeros(node_count, edge_out.size(1), device=edge_out.device)
        for idx, (src, dst) in enumerate(edge_index.t()):
            edge_aggregated[src] += edge_out[idx]
            edge_aggregated[dst] += edge_out[idx]
        edge_aggregated /= 2  # Average contribution per node

        # Concatenate node and aggregated edge features
        combined = torch.cat([node_out, edge_aggregated], dim=1)
        output = self.fc(combined)
        return F.log_softmax(output, dim=1)

# Example usage
# Define graph data
node_features = torch.rand(5, 4)  # 5 nodes with 4 features each
edge_index = torch.tensor([[0, 1, 2, 3, 4, 1], [1, 2, 3, 4, 0, 3]], dtype=torch.long)  # 6 edges
edge_features = torch.rand(edge_index.size(1), 4)  # Each edge has 4 features

data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)

# Initialize and train the MGCN model
model = MGCN(in_channels_node=4, in_channels_edge=4, hidden_channels=8, out_channels=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

# Example training loop
for epoch in range(50):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.edge_attr)
    # Dummy target for demonstration
    target = torch.randint(0, 3, (data.x.size(0),))
    loss = F.nll_loss(out, target)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')


Epoch 1, Loss: 0.9426
Epoch 2, Loss: 1.0486
Epoch 3, Loss: 1.3446
Epoch 4, Loss: 1.1110
Epoch 5, Loss: 1.1140
Epoch 6, Loss: 1.0580
Epoch 7, Loss: 1.2117
Epoch 8, Loss: 1.0299
Epoch 9, Loss: 1.3309
Epoch 10, Loss: 1.0776
Epoch 11, Loss: 1.1281
Epoch 12, Loss: 1.0638
Epoch 13, Loss: 1.3151
Epoch 14, Loss: 1.1224
Epoch 15, Loss: 0.9981
Epoch 16, Loss: 1.1843
Epoch 17, Loss: 1.0982
Epoch 18, Loss: 1.0561
Epoch 19, Loss: 1.2300
Epoch 20, Loss: 1.1762
Epoch 21, Loss: 1.1720
Epoch 22, Loss: 1.0444
Epoch 23, Loss: 1.1433
Epoch 24, Loss: 1.1184
Epoch 25, Loss: 1.0847
Epoch 26, Loss: 1.1027
Epoch 27, Loss: 1.1135
Epoch 28, Loss: 1.1146
Epoch 29, Loss: 1.1313
Epoch 30, Loss: 1.0815
Epoch 31, Loss: 1.0722
Epoch 32, Loss: 1.1015
Epoch 33, Loss: 1.1313
Epoch 34, Loss: 1.1468
Epoch 35, Loss: 1.0922
Epoch 36, Loss: 1.0994
Epoch 37, Loss: 1.0940
Epoch 38, Loss: 1.0431
Epoch 39, Loss: 1.1204
Epoch 40, Loss: 1.1051
Epoch 41, Loss: 1.0841
Epoch 42, Loss: 1.1173
Epoch 43, Loss: 1.1167
Epoch 44, Loss: 1.14