In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv

# Toy graph: 4 nodes, 3 features per node, connected as a directed 4-cycle
# (0 -> 1, 1 -> 2, 2 -> 3, 3 -> 0). Only these four directed edges are stored;
# the reverse directions are NOT included.

# Connectivity, written as (source, target) pairs for readability and then
# transposed into the [2, num_edges] layout (row 0 = sources, row 1 = targets
# under PyG's default source-to-target message flow).
edge_index = torch.tensor(
    [(0, 1), (1, 2), (2, 3), (3, 0)],
    dtype=torch.long,
).t().contiguous()

# Node feature matrix, shape [num_nodes, num_features] = [4, 3].
x = torch.tensor(
    [
        [1.0, 2.0, 3.0],     # Node 0
        [4.0, 5.0, 6.0],     # Node 1
        [7.0, 8.0, 9.0],     # Node 2
        [10.0, 11.0, 12.0],  # Node 3
    ],
    dtype=torch.float32,
)

# Bundle the node features and the edge connectivity into one graph object.
# The model's forward() reads them back as `data.x` and `data.edge_index`.
data = Data(x=x, edge_index=edge_index)

# Two-layer GATv2 network: a multi-head first layer feeding a single-head
# output layer, with a ReLU after each layer.
class GATv2(torch.nn.Module):
    """Two stacked ``GATv2Conv`` layers, each followed by a ReLU.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Output size of the first layer, per attention head.
        out_channels: Output size of the second layer.
        heads: Number of attention heads used by the first layer only; the
            second layer always uses a single head.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, heads: int) -> None:
        super().__init__()
        # Layer 1: `heads` attention heads, so the second layer is fed
        # `hidden_channels * heads` features per node.
        self.gatv2_1 = GATv2Conv(in_channels, hidden_channels, heads=heads)
        # Layer 2: one attention head mapping down to `out_channels`.
        self.gatv2_2 = GATv2Conv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, data) -> torch.Tensor:
        """Run both layers on a graph and return the activated node outputs.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` (connectivity).

        Returns:
            The ReLU-activated output of the second layer. Because the ReLU
            is also applied to the final layer, every value is non-negative.
        """
        hidden = F.relu(self.gatv2_1(data.x, data.edge_index))
        return F.relu(self.gatv2_2(hidden, data.edge_index))

# Build the network. The input width is read from the node feature matrix
# (3 for this graph); the first layer uses 8 heads with 4 hidden features
# each, and the model emits 2 output features per node.
model = GATv2(
    in_channels=data.x.size(-1),
    hidden_channels=4,
    out_channels=2,
    heads=8,
)

# Single forward pass over the toy graph; prints one output row per node.
# No random seed is set in this cell, so the printed values differ per run.
out = model(data)
print(out)


tensor([[0.8772, 4.8757],
        [0.6484, 1.6767],
        [0.6600, 1.7931],
        [0.8126, 3.6130]], grad_fn=<ReluBackward0>)
