This code is implemented with DGL on a PyTorch backend and is based on the official DGL tutorial for the Graph Attention Network (GAT): https://docs.dgl.ai/en/1.1.x/tutorials/models/1_gnn/9_gat.html. That tutorial serves as the boilerplate for this implementation, on top of which grid-search hyperparameter tuning is added.

In [None]:
!pip install dgl

Collecting dgl
  Downloading dgl-2.1.0-cp310-cp310-manylinux1_x86_64.whl (8.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.5/8.5 MB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cublas_cu12-12.1.3.1

In [None]:
import dgl
import dgl.nn as dglnn
from dgl import AddSelfLoop
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset, CoauthorCSDataset, CoauthorPhysicsDataset
import torch.nn.functional as F
import torch.optim as optim
import time
import torch.nn as nn
import torch
import numpy as np

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [None]:
class GAT(nn.Module):
    """Two-layer Graph Attention Network built from DGL ``GATConv`` layers.

    Layer 1 maps ``in_size`` -> ``hid_size`` per head with ``heads[0]`` heads and an
    ELU activation; its per-head outputs are concatenated (flattened).
    Layer 2 maps ``hid_size * heads[0]`` -> ``out_size`` with ``heads[1]`` heads and
    no activation; its per-head outputs are averaged to give the node logits.

    Args:
        in_size: input node-feature dimension.
        hid_size: hidden size per attention head of the first layer.
        out_size: logit dimension (number of classes).
        heads: two head counts, one per layer, e.g. ``[8, 1]``.
        feat_drop: dropout rate on each layer's input features (default 0.6).
        attn_drop: dropout rate on the attention weights (default 0.6).
    """

    def __init__(self, in_size, hid_size, out_size, heads, feat_drop=0.6, attn_drop=0.6):
        super().__init__()
        # Layers are constructed in the same order as before (hidden, then output),
        # so parameter initialisation consumes the RNG identically.
        self.gat_layers = nn.ModuleList(
            [
                dglnn.GATConv(
                    in_size,
                    hid_size,
                    heads[0],
                    feat_drop=feat_drop,
                    attn_drop=attn_drop,
                    activation=F.elu,
                ),
                dglnn.GATConv(
                    hid_size * heads[0],
                    out_size,
                    heads[1],
                    feat_drop=feat_drop,
                    attn_drop=attn_drop,
                    activation=None,
                ),
            ]
        )

    def forward(self, g, inputs):
        """Run both GAT layers on graph ``g`` and return the node logits."""
        last_index = len(self.gat_layers) - 1
        h = inputs
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            # GATConv yields one output per head (dim 1): concatenate them for hidden
            # layers, average them for the output layer.
            h = h.mean(1) if i == last_index else h.flatten(1)
        return h


In [None]:
# Module-level training defaults, read by `train()` at call time
# (a later configuration cell may re-assign them).
epochs, min_delta = 1000, 1e-3

def evaluate(g, features, labels, mask, model):
    """Score ``model`` on the nodes selected by ``mask``, without gradients.

    Puts the model in eval mode (dropout off), runs one full forward pass
    ``model(g, features)`` and restricts logits and labels to ``mask``.

    Returns:
        (loss, accuracy): mean cross-entropy as a Python float, and the fraction
        of masked nodes whose arg-max logit equals their label.
    """
    model.eval()
    with torch.no_grad():
        masked_logits = model(g, features)[mask]
        masked_labels = labels[mask]
        loss = nn.CrossEntropyLoss()(masked_logits, masked_labels)
        predictions = masked_logits.argmax(dim=1)
        num_correct = (predictions == masked_labels).sum().item()
        return loss.item(), num_correct * 1.0 / len(masked_labels)

def train(g, features, labels, masks, model, learning_rate, weight_dec, patience, verbose):
    """Train ``model`` with Adam and early stopping on the validation loss.

    Runs at most ``epochs`` full-graph epochs (module-level global). After every
    optimizer step the validation loss is measured with ``evaluate`` (eval mode,
    i.e. without dropout, and with the freshly updated weights). A loss counts as
    an improvement only if it beats the best so far by more than ``min_delta``
    (module-level global); once more than ``patience`` consecutive epochs fail to
    improve, training stops. In every case the weights from the last improving
    epoch are loaded back into ``model`` before returning.

    Args:
        g: DGL graph passed to ``model``.
        features: node-feature tensor passed to ``model``.
        labels: per-node integer class labels.
        masks: sequence of boolean node masks; ``masks[0]`` is the training mask,
            ``masks[1]`` the validation mask (any further entries are ignored).
        model: module called as ``model(g, features)`` that returns node logits.
        learning_rate: Adam learning rate.
        weight_dec: Adam weight decay.
        patience: number of non-improving epochs tolerated before stopping.
        verbose: if true, print per-epoch metrics and the early-stopping epoch.

    Returns:
        The best validation loss observed (callers may ignore it).
    """
    import copy  # local import keeps this cell self-contained

    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_dec)

    best_val_loss = float('inf')
    best_state = copy.deepcopy(model.state_dict())
    current_patience = 0

    for epoch in range(epochs):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure validation loss with dropout disabled and after the update;
        # the training logits above are dropout-noisy and pre-step, which made
        # the early-stopping signal unreliable.
        curr_val_loss, acc_val = evaluate(g, features, labels, val_mask, model)

        if verbose:
            _, acc_train = evaluate(g, features, labels, train_mask, model)
            print("Epoch {:05d} | Train Loss {:.4f} | Train Acc {:.4f} | Val Loss {:.4f} | Val Accuracy {:.4f} ".format(
                epoch, loss.item(), acc_train, curr_val_loss, acc_val))

        # Early stopping check
        if best_val_loss - curr_val_loss > min_delta:
            best_val_loss = curr_val_loss
            best_state = copy.deepcopy(model.state_dict())
            current_patience = 0
        else:
            current_patience += 1

        if current_patience > patience:
            if verbose:
                print(f'Early stopping at epoch {epoch}')
            break

    # Roll back to the best checkpoint so callers evaluate the selected weights,
    # not the ones from `patience` epochs past the optimum.
    model.load_state_dict(best_state)
    return best_val_loss

In [None]:
from itertools import product
from tqdm import tqdm

# Training settings: epoch budget / early-stopping threshold and the grid axes.
epochs, min_delta = 1000, 1e-3

hidden_list = [8]
head_list = [4, 8]
lr_list = [1e-2, 1e-3]
weight_decay_list = [1e-3, 1e-4, 5e-4]

# Store results
results = []

# Early-stopping patience per learning rate (the smaller rate gets more slack).
lr_patience_dict = dict(zip(lr_list, (25, 50)))

In [None]:
def split_train_val_test_ids(labels, train_samples_per_class=20, val_samples_per_class=30, rng=None):
    """Randomly split node ids into per-class train / val / test sets.

    For every distinct label, that class's node ids are shuffled; the first
    ``train_samples_per_class`` go to train, the next ``val_samples_per_class``
    go to val, and all remaining ones go to test. A class with too few nodes
    simply contributes fewer (possibly zero) val/test ids.

    Args:
        labels: 1-D array-like of per-node class labels.
        train_samples_per_class: number of training ids taken per class.
        val_samples_per_class: number of validation ids taken per class.
        rng: optional NumPy ``Generator``/``RandomState`` used for shuffling,
            for reproducible splits. Defaults to NumPy's global RNG.

    Returns:
        (train_ids, val_ids, test_ids): Python lists of node indices, grouped
        class by class in ascending label order.
    """
    labels = np.asarray(labels)  # also accept plain lists
    shuffle = np.random.shuffle if rng is None else rng.shuffle
    train_end = train_samples_per_class
    val_end = train_samples_per_class + val_samples_per_class

    train_ids, val_ids, test_ids = [], [], []
    for label in np.unique(labels):
        label_indices = np.flatnonzero(labels == label)
        shuffle(label_indices)
        train_ids.extend(label_indices[:train_end])
        val_ids.extend(label_indices[train_end:val_end])
        test_ids.extend(label_indices[val_end:])

    return train_ids, val_ids, test_ids

In [None]:
from dgl.data import DGLDataset

def _symmetric_edges_with_self_loops(adj):
    """Edge list of the symmetrised adjacency ``A + A^T + I``, built sparsely.

    Args:
        adj: square adjacency matrix. In this notebook it is the object unpickled
            from an .npz; the loaders previously called ``.toarray()`` on it, so it
            is presumably a scipy sparse matrix (a dense array also works).

    Returns:
        (src, dst, num_nodes): int64 NumPy arrays with the non-zero positions in
        row-major order (the same order ``np.nonzero`` gives on the dense sum),
        and the node count.
    """
    import scipy.sparse as sp  # scipy is already needed to unpickle the sparse matrices

    adj = sp.csr_matrix(adj)
    num_nodes = adj.shape[0]
    # Stay sparse: densifying an N x N adjacency costs O(N^2) memory.
    sym = (adj + adj.T + sp.identity(num_nodes, format="csr")).tocsr()
    sym.sum_duplicates()  # canonical CSR: sorted column indices -> row-major edge order
    src, dst = sym.nonzero()
    return src.astype(np.int64), dst.astype(np.int64), num_nodes


class Blogcatalog(DGLDataset):
    """Single-graph BlogCatalog dataset read from ``blogcatalog.npz`` in the CWD.

    The .npz must provide ``node_label`` (shifted down by 1 here, presumably
    because the stored ids are 1-based), ``node_attr`` and ``adj_matrix``
    (pickled sparse matrices). It is loaded with ``allow_pickle=True``, which can
    execute arbitrary code, so only use a trusted file. The resulting graph has
    symmetric edges, one self-loop per node, and node data ``feat``/``label``.
    """

    def __init__(self):
        super().__init__(name="Blogcatalog")

    def process(self):
        """Build ``self.graph`` and ``self.num_classes`` from the .npz file."""
        print("Loading Blogcatalog Graph data...")
        with np.load('blogcatalog.npz', allow_pickle=True) as data:
            feat = torch.tensor(data['node_attr'].tolist().toarray()).float()
            labels = torch.tensor(data['node_label']).to(torch.int64) - 1
            src, dst, num_nodes = _symmetric_edges_with_self_loops(data['adj_matrix'].tolist())

        self.graph = dgl.graph((src, dst), num_nodes=num_nodes)
        self.graph.ndata["feat"] = feat
        self.graph.ndata["label"] = labels
        # max + 1 (not the count of distinct labels) keeps every label a valid
        # class index for CrossEntropyLoss even if some class id never occurs.
        self.num_classes = int(labels.max()) + 1
        print("Data loaded.")

    def __getitem__(self, i):
        # Raising IndexError lets `for g in dataset` / `list(dataset)` terminate.
        if i not in (0, -1):
            raise IndexError(f"{type(self).__name__} holds a single graph; got index {i}")
        return self.graph

    def __len__(self):
        return 1


class DBLP(DGLDataset):
    """Single-graph DBLP dataset read from ``DBLP_BERT_graph_data.npz`` in the CWD.

    The .npz must provide ``labels``, a dense ``feature_matrix`` and ``adj_mat``
    (a pickled sparse matrix). It is loaded with ``allow_pickle=True``, which can
    execute arbitrary code, so only use a trusted file. The resulting graph has
    symmetric edges, one self-loop per node, and node data ``feat``/``label``.
    """

    def __init__(self):
        super().__init__(name="DBLP")

    def process(self):
        """Build ``self.graph`` and ``self.num_classes`` from the .npz file."""
        print("Loading DBLP Graph data...")
        with np.load('DBLP_BERT_graph_data.npz', allow_pickle=True) as data:
            feat = torch.tensor(data['feature_matrix']).float()
            labels = torch.tensor(data['labels']).to(torch.int64)
            src, dst, num_nodes = _symmetric_edges_with_self_loops(data['adj_mat'].tolist())

        self.graph = dgl.graph((src, dst), num_nodes=num_nodes)
        self.graph.ndata["feat"] = feat
        self.graph.ndata["label"] = labels
        # max + 1 (not the count of distinct labels) keeps every label a valid
        # class index for CrossEntropyLoss even if some class id never occurs.
        self.num_classes = int(labels.max()) + 1
        print("Data loaded.")

    def __getitem__(self, i):
        # Raising IndexError lets `for g in dataset` / `list(dataset)` terminate.
        if i not in (0, -1):
            raise IndexError(f"{type(self).__name__} holds a single graph; got index {i}")
        return self.graph

    def __len__(self):
        return 1

Change `dataset_name` in the cell below to run the experiment on a specific dataset: one of `cora`, `citeseer`, `pubmed`, `CS`, `Physics`, `Blogcatalog` or `DBLP`.

In [None]:
dataset_name = "DBLP"  # change to any key of DATASET_LOADERS below

# Seed every RNG in play (NumPy drives the random train/val/test split; torch and
# DGL drive weight initialisation and dropout) so Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
dgl.seed(SEED)

# Lazy constructors, so only the requested dataset is downloaded/loaded.
DATASET_LOADERS = {
    "cora": lambda: CoraGraphDataset(transform=AddSelfLoop()),
    "citeseer": lambda: CiteseerGraphDataset(transform=AddSelfLoop()),
    "pubmed": lambda: PubmedGraphDataset(transform=AddSelfLoop()),
    "CS": lambda: CoauthorCSDataset(transform=AddSelfLoop()),
    "Physics": lambda: CoauthorPhysicsDataset(transform=AddSelfLoop()),
    "Blogcatalog": Blogcatalog,
    "DBLP": DBLP,
}
# Datasets whose graph already carries train/val/test masks in `ndata`;
# all others get a random per-class split.
PREDEFINED_SPLIT_DATASETS = {"cora", "citeseer", "pubmed"}

if dataset_name not in DATASET_LOADERS:
    raise NotImplementedError(
        f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(DATASET_LOADERS)}"
    )
dataset = DATASET_LOADERS[dataset_name]()

num_classes = dataset.num_classes
g = dataset[0]
labels = g.ndata['label']
features = g.ndata['feat']

print(labels.dtype)
print(features.dtype)


def _ids_to_mask(node_ids, num_nodes):
    """Boolean torch mask of length ``num_nodes`` that is True at ``node_ids``."""
    mask = np.zeros(num_nodes, dtype=bool)
    mask[node_ids] = True
    return torch.from_numpy(mask)


if dataset_name in PREDEFINED_SPLIT_DATASETS:
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]
else:
    split_ids = split_train_val_test_ids(labels.numpy())
    masks = tuple(_ids_to_mask(ids, g.num_nodes()) for ids in split_ids)

# GAT input/output sizes
in_size = features.shape[1]
out_size = dataset.num_classes


Loading DBLP Graph data...
[[1. 2. 2. ... 0. 0. 0.]
 [2. 1. 0. ... 0. 0. 0.]
 [2. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]
Data loaded.
torch.int64
torch.float32


In [None]:

t_start = time.time()

# Reset here, not only in the settings cell: re-running this cell (or switching
# datasets) must not mix stale entries into the best-result selection below.
results = []

# Grid search over every (hidden, lr, weight_decay, heads) combination.
search_space = list(product(hidden_list, lr_list, weight_decay_list, head_list))
for hidden, lr, weight_decay, num_heads in tqdm(search_space, desc="Hyperparameter Grid Search"):
    model = GAT(in_size, hidden, out_size, heads=[num_heads, 1])
    train(g, features, labels, masks, model, lr, weight_decay, lr_patience_dict[lr], False)
    val_loss, val_acc = evaluate(g, features, labels, masks[1], model)

    results.append({
        'hidden': hidden,
        'heads': num_heads,
        'lr': lr,
        'weight_decay': weight_decay,
        'val_loss': val_loss,
        'val_acc': val_acc,
    })

    # tqdm.write prints without breaking the progress bar.
    tqdm.write(f"Hidden: {hidden} heads: {num_heads} lr: {lr} weight_decay: {weight_decay} "
               f"val_loss: {val_loss} val_acc: {val_acc:.4f}")

# Select the configuration with the lowest validation loss.
best_result = min(results, key=lambda x: x['val_loss'])

t_end = time.time()
print("------------------------")
print(f"Total Time Elapsed to Find Best Hyper-parameters: {t_end-t_start} seconds")
print("------------------------")

print("Best Hyperparameters:")
print(f"Hidden: {best_result['hidden']}")
print(f"Heads: {best_result['heads']}")
print(f"Learning Rate: {best_result['lr']}")
print(f"Weight Decay: {best_result['weight_decay']}")
print(f"Validation Loss: {best_result['val_loss']}")
print(f"Validation Accuracy: {best_result['val_acc']:.4f}")

print("------------------------")

# Retrain from scratch with the selected hyperparameters, then report test accuracy.
print("Now training with best Hyper-paramater settings")

hidden = best_result['hidden']
lr = best_result['lr']
weight_decay = best_result['weight_decay']
num_heads = best_result['heads']

model = GAT(in_size, hidden, out_size, heads=[num_heads, 1])
train(g, features, labels, masks, model, lr, weight_decay, lr_patience_dict[lr], True)

print("Testing...")
_, acc = evaluate(g, features, labels, masks[2], model)
print("Test accuracy {:.4f}".format(acc))
