In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class MGATLayer(nn.Module):
    """Graph attention layer that averages several independent attention heads.

    One single-head ``GATConv`` is created per head. In ``forward`` every head
    is applied to the same ``(x, edge_index)`` pair and the per-head results
    are averaged element-wise.

    Args:
        in_channels: Input feature size passed to every ``GATConv``.
        out_channels: Output feature size passed to every ``GATConv``.
        num_heads: Number of independent heads to build; must be at least 1.
        dropout: Value forwarded as the ``dropout`` argument of every
            ``GATConv``.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, num_heads=1, dropout=0.6):
        super().__init__()
        # Fail early: with zero heads the layer would construct fine and then
        # crash in forward() when torch.stack receives an empty list.
        if num_heads < 1:
            raise ValueError(f"num_heads must be at least 1, got {num_heads}")

        self.num_heads = num_heads
        self.attention_layers = nn.ModuleList(
            [
                GATConv(in_channels, out_channels, heads=1, dropout=dropout)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x, edge_index):
        """Apply every head to ``(x, edge_index)`` and return their mean.

        Args:
            x: Feature tensor handed unchanged to each head
                (presumably ``(num_nodes, in_channels)`` -- confirm with caller).
            edge_index: Graph connectivity handed unchanged to each head.

        Returns:
            Tensor with the same shape as a single head's output.
        """
        head_outputs = [head(x, edge_index) for head in self.attention_layers]
        # Stack along a new axis 1, then average that axis away again.
        return torch.stack(head_outputs, dim=1).mean(dim=1)

class MGATModel(nn.Module):
    """Two stacked ``MGATLayer`` blocks followed by a log-softmax.

    The forward pass is: dropout -> first layer -> ELU -> dropout -> second
    layer -> ``log_softmax`` over dimension 1.

    Args:
        in_channels: Input feature size of the first layer.
        hidden_channels: Output size of the first layer and input size of the
            second layer.
        out_channels: Output size of the second layer.
        num_heads: Number of heads used by both ``MGATLayer`` instances.
        dropout: Dropout probability applied to the input features and to the
            hidden activations, and forwarded to both ``MGATLayer`` instances.
            Defaults to 0.6, the value that used to be hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=2, dropout=0.6):
        super().__init__()
        # Single source of truth for the dropout rate used throughout the model.
        self.dropout = dropout
        self.gat1 = MGATLayer(in_channels, hidden_channels, num_heads=num_heads, dropout=dropout)
        self.gat2 = MGATLayer(hidden_channels, out_channels, num_heads=num_heads, dropout=dropout)

    def forward(self, data):
        """Run the model on a graph object.

        Args:
            data: Object exposing ``x`` (features) and ``edge_index``
                (connectivity) attributes.

        Returns:
            Log-probabilities obtained by ``log_softmax`` over dimension 1 of
            the second layer's output.
        """
        x, edge_index = data.x, data.edge_index
        # F.dropout is a no-op unless the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Sample usage with PyTorch Geometric data
import torch_geometric.datasets as datasets
from torch_geometric.data import DataLoader

# Load the Cora dataset from the Planetoid collection; files are stored under /tmp/Cora
dataset = datasets.Planetoid(root='/tmp/Cora', name='Cora')

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Model sized from the dataset: node-feature width in, one output per class
model = MGATModel(
    dataset.num_node_features,
    hidden_channels=8,
    out_channels=dataset.num_classes,
).to(device)
data = dataset[0].to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# Full-batch training: every epoch runs the whole graph through the model and
# computes the loss only on the entries selected by data.train_mask
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    # Report the loss on every tenth epoch only
    if epoch % 10 != 0:
        continue
    print(f'Epoch {epoch}, Loss: {loss.item()}')

print("Training complete.")


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Epoch 0, Loss: 1.9830167293548584
Epoch 10, Loss: 1.7444876432418823
Epoch 20, Loss: 1.531516671180725
Epoch 30, Loss: 1.3431135416030884
Epoch 40, Loss: 1.2566040754318237
Epoch 50, Loss: 1.1759110689163208
Epoch 60, Loss: 0.9890807867050171
Epoch 70, Loss: 0.9573647379875183
Epoch 80, Loss: 0.8478729128837585
Epoch 90, Loss: 0.9188687205314636
Epoch 100, Loss: 0.7829055190086365
Epoch 110, Loss: 0.734627902507782
Epoch 120, Loss: 0.7490527033805847
Epoch 130, Loss: 0.6629327535629272
Epoch 140, Loss: 0.5397090315818787
Epoch 150, Loss: 0.6694390177726746
Epoch 160, Loss: 0.6100138425827026
Epoch 170, Loss: 0.5155973434448242
Epoch 180, Loss: 0.5343835353851318
Epoch 190, Loss: 0.4904060959815979
Training complete.
