In [3]:
from torch_geometric.datasets import UPFD

In [4]:
# different feature types can be selected: content(profile + spacy; dim: 310), profile(dim: 10), spacy(dim: 300)
# splits: train, test, val
# name: politifact, gossipcop
dataset = UPFD('data/upfd', name="politifact", feature='bert', split="train")

In [15]:
print(f"Number of graphs: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of edge features: {dataset.num_edge_features}")



Number of graphs: 62
Number of classes: 2
Number of features: 768
Number of node features: 768
Number of edge features: 0


In [33]:
graph = dataset[0]
print(f"Graph at index 0: {graph}")
print(f"Node features shape: {graph.x.shape}")
print(f"Node labels shape: {graph.y.shape}")

print(f"Edge index shape: {graph.edge_index.shape}")
print(f"Edge index: {graph.edge_index}")

Graph at index 0: Data(x=[72, 768], edge_index=[2, 71], y=[1])
Node features shape: torch.Size([72, 768])
Node labels shape: torch.Size([1])
Edge index shape: torch.Size([2, 71])
Edge index: tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  8,  8,  8, 16, 16, 16, 16, 16, 16,
         24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
         24, 24, 24, 24, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 60],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
         19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
         37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
         55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]])


graph.x:  
row 0 of graph.x is the encoded article  
row n is the encoded past user tweets for user n  

label:  
0 = real
1 = fake



# train gat

In [29]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, ModuleList, ReLU, Sequential
from typing import Callable, Dict, List, Optional, Tuple, Union, Final

from torch_geometric.nn import GATConv, GATv2Conv, global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor

class GATforGraphClassification(BasicGNN):
    def __init__(
      self,
      in_channels: int,
      hidden_channels: int,
      out_channels: int,
      num_layers: int,
      heads: int,
      dropout: float = 0.0,
      pooling: str = 'mean',
      **kwargs,
    ):
      self.out_channels_final = out_channels
      
      super().__init__(
        in_channels=in_channels,
        out_channels=None,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        dropout=dropout,
        **kwargs,
      )
      
      self.pooling = pooling
      
      if pooling == 'add':
        self.pool = global_add_pool
      elif pooling == 'mean':
        self.pool = global_mean_pool
      elif pooling == 'max':
        self.pool = global_max_pool
      else:
        raise ValueError(f"Pooling type {pooling} not supported.")
      
      self.classifier = Linear(self.out_channels, self.out_channels_final)
      
    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:

        v2 = kwargs.pop('v2', False)
        heads = kwargs.pop('heads', 1)
        concat = kwargs.pop('concat', True)

        # Do not use concatenation in case the layer `GATConv` layer maps to
        # the desired output channels (out_channels != None and jk != None):
        if getattr(self, '_is_conv_to_out', False):
            concat = False

        if concat and out_channels % heads != 0:
            raise ValueError(f"Ensure that the number of output channels of "
                             f"'GATConv' (got '{out_channels}') is divisible "
                             f"by the number of heads (got '{heads}')")

        if concat:
            out_channels = out_channels // heads

        Conv = GATConv if not v2 else GATv2Conv
        return Conv(in_channels, out_channels, heads=heads, concat=concat,
                    dropout=self.dropout.p, **kwargs)
      
    def forward(self, x, edge_index, batch=None, edge_attr=None):
        """
        Forward pass for graph classification.
        
        Args:
            x: Node features [num_nodes, in_channels]
            edge_index: Graph connectivity [2, num_edges]
            batch: Batch vector [num_nodes] mapping each node to its graph
            edge_attr: Edge features [num_edges, edge_dim] (optional)
        
        Returns:
            Graph classification predictions [batch_size, out_channels_final]
        """
        # Get node embeddings using the GNN layers from the parent class
        x = self.convs[0](x, edge_index, edge_attr=edge_attr)
        x = self.act(x)
        
        for i, conv in enumerate(self.convs[1:]):
            x = self.dropout(x)
            x = conv(x, edge_index, edge_attr=edge_attr)
            if i < len(self.convs) - 2:
                x = self.act(x)
        
            
        # Pool node features to graph-level representation
        if batch is None:
            # If no batch is provided, assume a single graph
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Apply pooling to get graph-level representation
        x = self.pool(x, batch)
        
        # Apply final classification layer
        x = self.dropout(x)
        x = self.classifier(x)
        
        return x

In [None]:
gat_model = GATforGraphClassification(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=2,
    num_layers=3,
    heads=8,
    dropout=0.5,
    pooling='mean',
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gat_model = gat_model.to(device)

dummy_graph = dataset[0].to(device)
dummy_graph.x = dummy_graph.x.to(device)
dummy_graph.edge_index = dummy_graph.edge_index.to(device)
dummy_graph.edge_attr = dummy_graph.edge_attr.to(device) if dummy_graph.edge_attr is not None else None
gat_model.eval()
out = gat_model(
    x=dummy_graph.x,
    edge_index=dummy_graph.edge_index,
    batch=dummy_graph.batch,
    edge_attr=dummy_graph.edge_attr,
)


Output shape: torch.Size([1, 2])


In [34]:
from torch_geometric.loader import DataLoader

train_dataset = UPFD('data/upfd', name='politifact', feature='bert', split='train')
val_dataset = UPFD('data/upfd', name='politifact', feature='bert', split='val')
test_dataset = UPFD('data/upfd', name='politifact', feature='bert', split='test')

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

In [37]:
# Training loop
import torch.nn.functional as F

optimizer = torch.optim.AdamW(gat_model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
epochs = 100

for epoch in range(epochs):
    gat_model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = gat_model(batch.x, batch.edge_index, batch.batch, batch.edge_attr)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
    
    # Validation loop
    gat_model.eval()
    correct = 0
    for batch in val_loader:
        batch = batch.to(device)
        out = gat_model(batch.x, batch.edge_index, batch.batch, batch.edge_attr)
        pred = out.argmax(dim=1)
        correct += (pred == batch.y).sum().item()
    accuracy = correct / len(val_loader.dataset)
    print(f"Validation Accuracy: {accuracy:.4f}")
    
# Test loop
gat_model.eval()
correct = 0
for batch in test_loader:
    batch = batch.to(device)
    out = gat_model(batch.x, batch.edge_index, batch.batch, batch.edge_attr)
    pred = out.argmax(dim=1)
    correct += (pred == batch.y).sum().item()
accuracy = correct / len(test_loader.dataset)
print(f"Test Accuracy: {accuracy:.4f}")




Epoch 1/100, Loss: 0.0856
Validation Accuracy: 0.8387
Epoch 2/100, Loss: 0.2887
Validation Accuracy: 0.8710
Epoch 3/100, Loss: 0.0280
Validation Accuracy: 0.7097
Epoch 4/100, Loss: 0.0616
Validation Accuracy: 0.7419
Epoch 5/100, Loss: 0.2679
Validation Accuracy: 0.8387
Epoch 6/100, Loss: 0.0335
Validation Accuracy: 0.8710
Epoch 7/100, Loss: 0.1323
Validation Accuracy: 0.8710
Epoch 8/100, Loss: 0.0783
Validation Accuracy: 0.8387
Epoch 9/100, Loss: 0.0534
Validation Accuracy: 0.8387
Epoch 10/100, Loss: 0.0710
Validation Accuracy: 0.8387
Epoch 11/100, Loss: 0.0587
Validation Accuracy: 0.8387
Epoch 12/100, Loss: 0.0547
Validation Accuracy: 0.8387
Epoch 13/100, Loss: 0.0947
Validation Accuracy: 0.8065
Epoch 14/100, Loss: 0.0421
Validation Accuracy: 0.7742
Epoch 15/100, Loss: 0.1004
Validation Accuracy: 0.7419
Epoch 16/100, Loss: 0.0136
Validation Accuracy: 0.7419
Epoch 17/100, Loss: 0.1595
Validation Accuracy: 0.7419
Epoch 18/100, Loss: 0.0355
Validation Accuracy: 0.7419
Epoch 19/100, Loss: