#### Graph Classification Models
This notebook is a tutorial on graph classification in PyTorch Geometric (PyG), with a focus on graph pair classification models (i.e. a single label for a pair of graphs).

#### Enable GPU/MPS
If you have a CUDA-capable GPU or an Apple Silicon Mac, you can enable hardware acceleration for PyTorch as follows.


In [2]:
# Setup GPU
import sys
sys.path.append("../src")

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


On a Mac, you can check whether MPS acceleration is available as follows (both calls should print `True`).

In [2]:
print(torch.backends.mps.is_available())  # True if an MPS device is usable (requires macOS 12.3+)
print(torch.backends.mps.is_built())  # True if this PyTorch build was compiled with MPS support

True
True
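
The cell that sets `device` above only checks for CUDA. A minimal sketch of a selection helper that also falls back to MPS could look like the following; the name `pick_device` is hypothetical and not part of the original notebook.

```python
import torch

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, then the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is missing in older PyTorch releases,
    # so look it up defensively instead of assuming it exists.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")
```

On a machine with neither backend this simply selects the CPU, so the rest of the notebook runs unchanged.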


#### Graph Structures
PyG uses a class called `Data` (namely [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html)) to represent graph structures. A `Data` object has three main components. The first is `edge_index`, a torch tensor of shape `[2, num_edges]` where each column $[i \ \ j]^T$ represents a directed edge pointing from node $i$ to node $j$. The second is the node feature matrix `x`, a torch tensor of shape `[num_nodes, num_node_features]`. The third is the edge feature matrix `edge_attr`, a torch tensor of shape `[num_edges, num_edge_features]` whose rows follow the columns of `edge_index`; that is, the 3rd column of `edge_index` corresponds to the 3rd row of `edge_attr`, and so on.

In [3]:
from torch_geometric.data import Data

# Edges
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Node features
x = torch.tensor([[-1], [0], [1]], dtype=torch.float32)

# Edge features
edge_attr = torch.tensor([[1, 1, 2, 2]], dtype=torch.float32)
edge_attr = edge_attr.T

# Graph Label
y = torch.tensor([0], dtype=torch.long)

# We store each graph in a Data object. This Data object is customizable in its attributes.
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

We can call certain attributes of our `Data` object such as the number of nodes and number of edge features, and much more.

In [4]:
# Descriptive data of the Data object
print((data.to_dict()).keys())
print("Number of nodes:", data.num_nodes)
print("Number of edges:", data.num_edges)
print("Number of node features:", data.num_node_features)
print("Number of edge features:", data.num_edge_features)

dict_keys(['x', 'edge_index', 'edge_attr', 'y'])
Number of nodes: 3
Number of edges: 4
Number of node features: 1
Number of edge features: 1
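
Because `edge_index` stores directed edges, an undirected graph is written with each edge in both directions, and `edge_attr` must repeat the features in the same order. The sketch below builds both tensors in plain PyTorch from a hypothetical undirected edge list (`undirected_edges` and `weights` are illustrative names, not part of the notebook).

```python
import torch

# Hypothetical undirected edge list with one weight per undirected edge.
undirected_edges = [(0, 1), (1, 2)]
weights = [1.0, 2.0]

src = torch.tensor([i for i, _ in undirected_edges], dtype=torch.long)
dst = torch.tensor([j for _, j in undirected_edges], dtype=torch.long)

# Store each undirected edge as two directed edges: i -> j and j -> i.
edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])], dim=0)

# Repeat the weights in the same order, so that column k of edge_index
# lines up with row k of edge_attr.
w = torch.tensor(weights, dtype=torch.float32).unsqueeze(-1)
edge_attr = torch.cat([w, w], dim=0)

print(edge_index)  # shape [2, 4]
print(edge_attr)   # shape [4, 1]
```

The column order differs from the hand-written example above, but the set of edges and their features is the same.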


Once we've loaded all our graph representations into their respective `Data` objects, we place them in a list and use the PyG `DataLoader` to handle batching. It is important to use the PyG `DataLoader` rather than the native torch `DataLoader`: the PyG version knows how to collate `Data` objects, merging the graphs in a batch into a single large disconnected graph and offsetting the node indices in `edge_index` accordingly.

In [5]:
from torch_geometric.loader import DataLoader

# Load Data objects into a list
data_list = [data, data, data, data]

# Load data list into a DataLoader object
loader = DataLoader(data_list, batch_size = 2)

# Print out objects for a single batch (2 graphs)
for batch in loader:
  print("Node features:")
  print(batch.x)
  print("Edge index:")
  print(batch.edge_index)
  print("Edge features:")
  print(batch.edge_attr)
  print("Batch objects which tracks which nodes belong to which graph:")
  print(batch.batch)
  print("Graph labels:")
  print(batch.y)
  break

Node features:
tensor([[-1.],
        [ 0.],
        [ 1.],
        [-1.],
        [ 0.],
        [ 1.]])
Edge index:
tensor([[0, 1, 1, 2, 3, 4, 4, 5],
        [1, 0, 2, 1, 4, 3, 5, 4]])
Edge features:
tensor([[1.],
        [1.],
        [2.],
        [2.],
        [1.],
        [1.],
        [2.],
        [2.]])
Batch objects which tracks which nodes belong to which graph:
tensor([0, 0, 0, 1, 1, 1])
Graph labels:
tensor([0, 0])
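
To make the batching above concrete, here is a plain-PyTorch sketch of the collation step, assuming each graph is given as an `(x, edge_index)` tuple. The helper `collate_graphs` is hypothetical and only illustrates the idea; it is not PyG's actual implementation.

```python
import torch

def collate_graphs(graphs):
    """Merge a list of (x, edge_index) graphs into one disconnected graph."""
    xs, edge_indices, batch_vec = [], [], []
    node_offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node indices so they point at this graph's rows in the merged x.
        edge_indices.append(edge_index + node_offset)
        # Record which graph each node came from.
        batch_vec.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        node_offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(edge_indices, dim=1), torch.cat(batch_vec)

x = torch.tensor([[-1.0], [0.0], [1.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

batch_x, batch_edge_index, batch_vector = collate_graphs([(x, edge_index), (x, edge_index)])
print(batch_edge_index)
print(batch_vector)
```

For the two copies of the toy graph, this reproduces the `edge_index` and `batch` tensors printed by the loader above.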


#### Binary Classification
Below is an example of a binary classifier using a Graph Convolutional Network (GCN), which includes two `GCNConv` layers from [`conv.GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html#torch_geometric.nn.conv.GCNConv).

In [6]:
# A simple GCN model that classifies a single graph (one label per graph)
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super(GraphClassifier, self).__init__()

        # Node feature transformation layers
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 128)
        
        # Readout layer
        self.readout = global_mean_pool

        # Classifier
        self.classifier = Linear(128, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        # Update node features
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Readout layer to get graph-level representation
        x = self.readout(x, batch)  # <-- Use the batch vector here

        # Classifier to predict the graph label
        x = self.classifier(x)
        x = torch.sigmoid(x)

        return x.squeeze(-1)
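
The readout step in the model above averages node embeddings per graph using the batch vector. A minimal plain-PyTorch sketch of that operation is given below; the function `mean_pool` is a hypothetical stand-in for illustration, not PyG's `global_mean_pool` itself.

```python
import torch

def mean_pool(x, batch, num_graphs):
    """Average the rows of x that share the same graph id in batch."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)  # add each node's features to its graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(-1)
    return sums / counts

node_embeddings = torch.tensor([[-1.0], [0.0], [1.0], [2.0], [4.0], [6.0]])
batch_vector = torch.tensor([0, 0, 0, 1, 1, 1])

pooled = mean_pool(node_embeddings, batch_vector, num_graphs=2)
print(pooled)  # one row per graph
```

Without the batch vector, the mean would be taken over all six nodes at once and the two graphs would be indistinguishable.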

#### Training
Training in PyG is essentially the same as training in regular PyTorch, with a few caveats. We initialize the model, optimizer, and loss; then we run the training loop for a fixed number of epochs and iterate over the batches in the `DataLoader`. The difference is that the model inputs are read from attributes of the batch object (`batch.x`, `batch.edge_index`, and so on); you could also modify the model to take the batch object and unpack it internally. What is *very* important is the `batch.batch` vector, which records which graph each node in the batch belongs to, i.e. a value of $5$ means the node belongs to the 6th graph in the batch. It tells the GNN how to perform per-graph operations properly, such as global mean pooling.

In [7]:
# Initialize model
model = GraphClassifier(num_node_features=1, num_edge_features=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss().to(device)

# Training loop
for epoch in range(1, 20):
    for batch in loader:
        
        # Move all tensors in the batch to the selected device
        batch = batch.to(device)
        
        optimizer.zero_grad()
        
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        loss = criterion(out, batch.y.float())
        
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')

Epoch: 1, Loss: 0.6097297072410583
Epoch: 2, Loss: 0.47246265411376953
Epoch: 3, Loss: 0.27932965755462646
Epoch: 4, Loss: 0.1038670614361763
Epoch: 5, Loss: 0.021476250141859055
Epoch: 6, Loss: 0.00292340200394392
Epoch: 7, Loss: 0.00034466429497115314
Epoch: 8, Loss: 4.1500883526168764e-05
Epoch: 9, Loss: 5.58156352781225e-06
Epoch: 10, Loss: 8.78013509009179e-07
Epoch: 11, Loss: 1.6544642278404353e-07
Epoch: 12, Loss: 3.757103073098733e-08
Epoch: 13, Loss: 1.0240342795952984e-08
Epoch: 14, Loss: 3.310429708136553e-09
Epoch: 15, Loss: 1.2492684664522358e-09
Epoch: 16, Loss: 5.426748495018785e-10
Epoch: 17, Loss: 2.670455245823433e-10
Epoch: 18, Loss: 1.465278592904795e-10
Epoch: 19, Loss: 8.834161180359956e-11
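
The classifier above applies `torch.sigmoid` and then `BCELoss`. PyTorch also provides `BCEWithLogitsLoss`, which fuses the two steps and is documented as more numerically stable. The sketch below uses made-up logits and targets to show that the two formulations agree; it is an illustration, not a change to the notebook's model.

```python
import torch

# Made-up raw scores (logits) and binary targets for three graphs.
logits = torch.tensor([-2.0, 0.5, 3.0])
targets = torch.tensor([0.0, 1.0, 1.0])

# Formulation used in this notebook: sigmoid inside the model, then BCELoss.
loss_from_probs = torch.nn.BCELoss()(torch.sigmoid(logits), targets)

# Fused alternative: the model would return raw logits instead.
loss_from_logits = torch.nn.BCEWithLogitsLoss()(logits, targets)

print(loss_from_probs.item(), loss_from_logits.item())
```

To use the fused loss, the model's `forward` would return the classifier output without the final `torch.sigmoid`.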


#### Multiclass Classification
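Let's try out the main model we'll be using in BRAINGREG, which is given by `supervised_model` in `models.py`. We'll also test the data in the multiclass classification setting, so our label $y$ is no longer binary: it is one of three classes. Conceptually this is a 3D one-hot vector $\vec{y}\in\mathbb{R}^3$, but in the code below it is stored as an integer class index (here `2`), which is a target format `CrossEntropyLoss` accepts.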

In [8]:
import torch
from torch_geometric.data import Data

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Edges
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Node features
x = torch.tensor([[-1], [0], [1]], dtype=torch.float32)

# Edge features
edge_attr = torch.tensor([[1, 1, 2, 2]], dtype=torch.float32)

edge_attr = edge_attr.T

# Graph Label
y = torch.tensor(2, dtype=torch.long)

# We store each graph in a Data object. This Data object is customizable in its attributes.
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

In [9]:
from torch_geometric.loader import DataLoader

# Load Data objects into a list
data_list = [data, data, data, data]

# Load data list into a DataLoader object
loader = DataLoader(data_list, batch_size = 2)

In [10]:
for batch in loader:
  print((batch.edge_attr).shape)
  print(batch.y)

torch.Size([8, 1])
tensor([2, 2])
torch.Size([8, 1])
tensor([2, 2])
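
The labels printed above are integer class indices rather than one-hot vectors. The following sketch, using made-up logits, shows that cross-entropy computed from class indices matches the value computed by hand from the equivalent one-hot encoding.

```python
import torch
import torch.nn.functional as F

# Made-up logits for two graphs and three classes.
logits = torch.tensor([[0.1, 0.2, 2.0],
                       [1.5, 0.3, 0.2]])
y_index = torch.tensor([2, 0])                         # class indices
y_one_hot = F.one_hot(y_index, num_classes=3).float()  # equivalent one-hot rows

# Cross-entropy from class indices (the format used in this notebook).
loss_from_index = F.cross_entropy(logits, y_index)

# Same quantity by hand: -sum(one_hot * log_softmax), averaged over the batch.
loss_from_one_hot = -(y_one_hot * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(loss_from_index.item(), loss_from_one_hot.item())
```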


In [18]:
from models import supervised_model

# Initialize model
config = {"num_node_features": 1,"num_edge_features": 1, "hidden_channels": 64, "out_channels": 32, "dropout": 0.05}

model = supervised_model(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss().to(device)

# Training loop
for epoch in range(1, 10):
    for batch in loader:
        
        # Move all tensors in the batch to the selected device
        batch = batch.to(device)
        
        optimizer.zero_grad()

        out = model(batch, classify="multiclass", head="linear", dropout=True)
        loss = criterion(out, batch.y)
        
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')

Epoch: 1, Loss: 0.11627106368541718
Epoch: 2, Loss: 0.0007373987464234233
Epoch: 3, Loss: 5.918392344028689e-05
Epoch: 4, Loss: 0.0
Epoch: 5, Loss: 0.0
Epoch: 6, Loss: 0.0
Epoch: 7, Loss: 0.0
Epoch: 8, Loss: 0.0
Epoch: 9, Loss: 0.0
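
Once a multiclass model is trained, predictions are usually read off with an `argmax` over the class dimension. The sketch below assumes the model returns one row of logits per graph, which is what `CrossEntropyLoss` expects; the values are made up for illustration.

```python
import torch

# Made-up model outputs (logits) for three graphs and three classes.
logits = torch.tensor([[0.1, 0.2, 2.0],
                       [1.5, 0.3, 0.2],
                       [0.0, 0.9, 0.1]])
labels = torch.tensor([2, 0, 2])

# Predicted class = index of the largest logit in each row.
preds = logits.argmax(dim=1)
accuracy = (preds == labels).float().mean().item()

print(preds)
print(f"Accuracy: {accuracy:.3f}")
```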


#### Graph Pair Classification

If you're interested in graph pair models, where each example in the batch is now a *pair* of graphs and we're trying to predict their graph pair label, then we have to create a custom class `PairData` that inherits from the `torch_geometric.data.Data` class.

Load the graph representations into a `PairData` class.

In [19]:
from preprocess import PairData

# Node features of shape (num_nodes, num_node_features) and type torch.float32
x1 = torch.tensor([[0, 0, 0],
                    [1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]], dtype=torch.float32)
x2 = torch.tensor([[4, 4, 4],
                    [5, 5, 5],
                    [6, 6, 6]], dtype=torch.float32)

# Edge indices of shape (2, num_edges) and type torch.long
edge_index1 = torch.tensor([[0, 1, 1, 2, 2, 3],
                             [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_index2 = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]], dtype=torch.long)

# Edge features of shape (num_edges, num_edge_features) and type torch.float32
edge_attr1 = torch.tensor([[0],
                            [1],
                            [2],
                            [3],
                            [4],
                            [5]], dtype=torch.float32)
edge_attr2 = torch.tensor([[6],
                            [7],
                            [8],
                            [9]], dtype=torch.float32)

# Pair label of shape (1,) and type torch.float32 (BCELoss expects float targets)
y = torch.tensor([1], dtype=torch.float32)

data = PairData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,  # Graph 1.
                x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,  # Graph 2.
                y=y) #Graph pair label. 

print(data.edge_attr1)

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.]])


And for *continuous* labels, such as the temporal kernel coefficient found in VICRegT1 we can use `PairData` as well by changing `y` to be continuous.

In [20]:
from preprocess import PairData

# Node features of shape (num_nodes, num_node_features) and type torch.float32
x1 = torch.tensor([[0, 0, 0],
                    [1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]], dtype=torch.float32)
x2 = torch.tensor([[4, 4, 4],
                    [5, 5, 5],
                    [6, 6, 6]], dtype=torch.float32)

# Edge indices of shape (2, num_edges) and type torch.long
edge_index1 = torch.tensor([[0, 1, 1, 2, 2, 3],
                             [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_index2 = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]], dtype=torch.long)

# Edge features of shape (num_edges, num_edge_features) and type torch.float32
edge_attr1 = torch.tensor([[0],
                            [1],
                            [2],
                            [3],
                            [4],
                            [5]], dtype=torch.float32)
edge_attr2 = torch.tensor([[6],
                            [7],
                            [8],
                            [9]], dtype=torch.float32)

# Continuous pair label of shape (1,) and type torch.float32
y = torch.tensor([0.83], dtype=torch.float32)

data = PairData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,  # Graph 1.
                x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,  # Graph 2.
                y=y) #Graph pair label. 

print(data.y)

tensor([0.8300])


#### DataLoader for Graph Pairs
We can use `torch_geometric.loader.DataLoader` to load our list of `PairData` objects, with the additional argument `follow_batch=['x1', 'x2']`. This makes the loader create the assignment vectors `x1_batch` and `x2_batch`, which record which pair in the batch each node of the first and second graph belongs to.

In [13]:
# Dataloader for pairs of graphs
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

# We will have our list of graph pairs in the form of PairData objects
data_list = [data, data, data]

# Create the dataloader. follow_batch tells the dataloader to create the vectors x1_batch and x2_batch,
# which track which nodes belong to which graph in the giant disconnected graph that the batch creates.
# We can typically split the data_list into train, val, test and then create individual loaders correspondingly.
pair_loader = DataLoader(data_list, batch_size=2, follow_batch=['x1', 'x2'])


# We can iterate through batches with the following. Each batch is a data.Batch() object
for batch in pair_loader:
    inputs = ((batch.x1, batch.edge_index1, batch.edge_attr1), (batch.x2, batch.edge_index2, batch.edge_attr2))
    graph_1, graph_2 = inputs
    labels = batch.y
    print(batch.x2)
    print(labels)
    print("Which nodes correspond to which graph:", batch.x1_batch)
    print("Which nodes correspond to which graph:", batch.x2_batch)

tensor([[4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.],
        [4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.]])
tensor([1., 1.])
Which nodes correspond to which graph: tensor([0, 0, 0, 0, 1, 1, 1, 1])
Which nodes correspond to which graph: tensor([0, 0, 0, 1, 1, 1])
tensor([[4., 4., 4.],
        [5., 5., 5.],
        [6., 6., 6.]])
tensor([1.])
Which nodes correspond to which graph: tensor([0, 0, 0, 0])
Which nodes correspond to which graph: tensor([0, 0, 0])
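
The assignment vectors printed above can be reproduced from the per-graph node counts alone. This plain-PyTorch sketch rebuilds them for the first batch (two pairs, with 4 nodes in each first graph and 3 nodes in each second graph); it only illustrates what the vectors encode.

```python
import torch

# Node counts of graph 1 and graph 2 for each of the two pairs in the batch.
num_nodes1 = torch.tensor([4, 4])
num_nodes2 = torch.tensor([3, 3])
pair_ids = torch.arange(2)

# Repeat each pair id once per node of the corresponding graph.
x1_batch = torch.repeat_interleave(pair_ids, num_nodes1)
x2_batch = torch.repeat_interleave(pair_ids, num_nodes2)

print(x1_batch)
print(x2_batch)
```

The two vectors have different lengths because the two sides of a pair need not have the same number of nodes, which is why each side gets its own vector.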


#### GCN Model for Graph Pairs
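We use a simple GCN model for graph pairs, built from the same layers as before: both graphs are passed through the same encoder, and their pooled embeddings are concatenated before the classifier.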

In [13]:
# Simple graph pair classifier
class PairGraphClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super(PairGraphClassifier, self).__init__()

        # Node feature transformation layers
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 128)

        # Edge feature transformation layers
        self.edge_mlp = Sequential(Linear(num_edge_features, 32),
                                   ReLU(),
                                   Linear(32, 64))

        # Classifier
        self.classifier = Linear(256, 1)  # 128 features from each graph

    def forward_one(self, x, edge_index, edge_attr, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

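        # Note: the transformed edge features are computed here but never used afterwards,
        # since the GCNConv calls above only take x and edge_index.
        edge_attr = self.edge_mlp(edge_attr)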
        x = global_mean_pool(x, batch)  # Use batch vector for separate pooling
        return x

    def forward(self, x1, edge_index1, edge_attr1, batch1, x2, edge_index2, edge_attr2, batch2):
        # Encode each graph of the pair with the shared encoder
        x1 = self.forward_one(x1, edge_index1, edge_attr1, batch1)
        x2 = self.forward_one(x2, edge_index2, edge_attr2, batch2)

        # Concatenate the two graph-level embeddings: [batch_size, 256]
        x = torch.cat([x1, x2], dim=1)
        x = self.classifier(x)
        x = torch.sigmoid(x)

        return x.squeeze(-1)
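
The fusion step at the end of the pair model can be checked in isolation. The sketch below uses random tensors in place of the pooled graph embeddings (an assumption made purely to illustrate shapes) and follows the same concatenate, classify, sigmoid sequence.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the pooled embeddings of graph 1 and graph 2 for a batch of two pairs.
h1 = torch.randn(2, 128)
h2 = torch.randn(2, 128)

# Concatenate along the feature dimension: one 256-dimensional vector per pair.
pair_embedding = torch.cat([h1, h2], dim=1)

classifier = torch.nn.Linear(256, 1)
probs = torch.sigmoid(classifier(pair_embedding)).squeeze(-1)

print(pair_embedding.shape)  # torch.Size([2, 256])
print(probs.shape)           # torch.Size([2])
```

Note that concatenation is order-sensitive: swapping the two graphs in a pair generally changes the prediction.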

Train the graph pair model.

In [16]:
# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
model = PairGraphClassifier(num_node_features=3, num_edge_features=1)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()
criterion.to(device)

# Training loop
for epoch in range(1, 20):
    for batch in pair_loader:
        # Move all tensors in the batch (including x1_batch and x2_batch) to the device
        batch = batch.to(device)

        optimizer.zero_grad()
        
        out = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                    batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
        loss = criterion(out, batch.y)
        
        loss.backward()
        optimizer.step()
        
    print(f'Epoch: {epoch}, Loss: {loss.item()}')
print("We have a graph pair model working!!!")


Epoch: 1, Loss: 0.18142934143543243
Epoch: 2, Loss: 0.008289927616715431
Epoch: 3, Loss: 0.00026128129684366286
Epoch: 4, Loss: 8.821526535029989e-06
Epoch: 5, Loss: 3.576279254957626e-07
Epoch: 6, Loss: 0.0
Epoch: 7, Loss: 0.0
Epoch: 8, Loss: 0.0
Epoch: 9, Loss: 0.0
Epoch: 10, Loss: 0.0
Epoch: 11, Loss: 0.0
Epoch: 12, Loss: 0.0
Epoch: 13, Loss: 0.0
Epoch: 14, Loss: 0.0
Epoch: 15, Loss: 0.0
Epoch: 16, Loss: 0.0
Epoch: 17, Loss: 0.0
Epoch: 18, Loss: 0.0
Epoch: 19, Loss: 0.0
We have a graph pair model working!!!
