In [1]:
# Install DGL pinned to version 1.0.1 (cu117 build); -f points pip at DGL's cu117 wheel index.
!pip install dgl==1.0.1+cu117 -f https://data.dgl.ai/wheels/cu117/repo.html

Looking in links: https://data.dgl.ai/wheels/cu117/repo.html
Collecting dgl==1.0.1+cu117
  Downloading https://data.dgl.ai/wheels/cu117/dgl-1.0.1%2Bcu117-cp310-cp310-manylinux1_x86_64.whl (266.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.0.1+cu117


In [2]:
import os, torch, time
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

torch.cuda.is_available()

True

In [23]:
class GATLayer(nn.Module):
  """Single-head graph attention layer bound to a fixed graph ``g``.

  Node features are projected by a bias-free linear map, every edge receives
  an additive-attention score computed from its two endpoint projections, and
  each node's output is the attention-weighted sum of the projections sent to
  it along its incoming edges.

  Args:
    g: graph object the layer operates on; stored as ``self.g`` and mutated
      during ``forward`` (see there).
    in_dim: size of the input node features.
    out_dim: size of the projected / output node features.
  """

  def __init__(self, g, in_dim, out_dim):
    super().__init__()
    self.g = g
    # Shared projection applied to every node feature vector.
    self.fc = nn.Linear(in_dim, out_dim, bias=False)
    # Additive attention: maps a concatenated (source, destination) projection
    # pair to one unnormalized score.
    self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    self.reset_parameters()

  def reset_parameters(self):
    """Re-draw both weight matrices with Xavier-normal init and ReLU gain."""
    relu_gain = nn.init.calculate_gain("relu")
    # Order matters for RNG reproducibility: projection first, then attention.
    for linear in (self.fc, self.attn_fc):
      nn.init.xavier_normal_(linear.weight, gain=relu_gain)

  def edge_attention(self, edges):
    """Edge function: return ``{'e': score}`` for a batch of edges.

    The score is ``leaky_relu(attn_fc([z_src || z_dst]))`` where ``z`` is the
    projected feature stored under ``'z'`` on the endpoint nodes.
    """
    src_z = edges.src['z']
    dst_z = edges.dst['z']
    pair = torch.cat((src_z, dst_z), dim=1)
    raw_score = self.attn_fc(pair)
    return {'e': F.leaky_relu(raw_score)}

  def message_func(self, edges):
    """Message function: forward the source projection and the edge score."""
    return dict(z=edges.src['z'], e=edges.data['e'])

  def reduce_func(self, nodes):
    """Reduce function: softmax the received scores and sum the weighted messages.

    Both the softmax and the sum run along dim 1 of the mailbox tensors.
    Returns ``{'h': aggregated}``.
    """
    scores = nodes.mailbox['e']
    messages = nodes.mailbox['z']
    weights = F.softmax(scores, dim=1)
    return {'h': (weights * messages).sum(dim=1)}

  def forward(self, h):
    """Compute attention-aggregated node features for input features ``h``.

    Side effects on ``self.g``: the projection is written to ``ndata['z']``
    and left there, edge scores are produced via ``apply_edges``, and the
    aggregated ``'h'`` entry is removed from ``ndata`` before being returned.
    """
    graph = self.g
    graph.ndata['z'] = self.fc(h)
    graph.apply_edges(self.edge_attention)
    graph.update_all(self.message_func, self.reduce_func)
    return graph.ndata.pop('h')

In [24]:
class MultiHeadGATLayer(nn.Module):
  """Runs several independent ``GATLayer`` heads and merges their outputs.

  Args:
    g: graph object handed to every head.
    in_dim: size of the input node features.
    out_dim: size of each head's output features.
    num_heads: number of attention heads to create.
    merge: ``'cat'`` (default) concatenates head outputs along dim 1, giving
      ``out_dim * num_heads`` features per node; any other value averages the
      head outputs element-wise, giving ``out_dim`` features per node.
  """

  def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
    super().__init__()
    self.heads = nn.ModuleList(
        [GATLayer(g, in_dim, out_dim) for _ in range(num_heads)]
    )
    self.merge = merge

  def forward(self, h):
    """Apply every head to ``h`` and merge the results according to ``self.merge``."""
    head_outs = [attn_head(h) for attn_head in self.heads]
    if self.merge == 'cat':
      return torch.cat(head_outs, dim=1)
    # Average across the head axis only (dim 0 of the stacked tensor). Reducing
    # without ``dim`` would collapse all nodes and features into one scalar.
    return torch.mean(torch.stack(head_outs), dim=0)

In [25]:
class GAT(nn.Module):
  """Two-layer graph attention network.

  Args:
    g: graph object passed to both attention layers.
    in_dim: size of the input node features.
    hidden_dim: per-head output size of the first layer.
    out_dim: output size of the second layer.
    num_heads: number of attention heads in the first layer.
  """

  def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
    super().__init__()
    # The first layer uses MultiHeadGATLayer's default merge, so the second
    # layer is sized for hidden_dim * num_heads input features.
    merged_width = hidden_dim * num_heads
    self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
    # The output layer uses a single attention head.
    self.layer2 = MultiHeadGATLayer(g, merged_width, out_dim, 1)

  def forward(self, h):
    """Return ``layer2(elu(layer1(h)))``."""
    hidden = F.elu(self.layer1(h))
    return self.layer2(hidden)

In [26]:
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_cora_data():
    """Load the Cora citation dataset and unpack what the training cell needs.

    Returns:
        A 4-tuple ``(graph, features, labels, train_mask)``: the first graph
        of the dataset, its ``"feat"`` and ``"label"`` node data, and its
        ``"train_mask"`` node data wrapped in a ``torch.BoolTensor``.
    """
    graph = citegrh.load_cora()[0]
    node_data = graph.ndata
    train_mask = torch.BoolTensor(node_data["train_mask"])
    return graph, node_data["feat"], node_data["label"], train_mask

import time
import numpy as np

# Build the graph, the two-layer GAT and its optimizer.
g, features, labels, mask = load_cora_data()
net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# Wall-clock duration of each timed epoch; the first `warmup_epochs` epochs
# are excluded from the running average.
warmup_epochs = 3
dur = []
for epoch in range(30):
    if epoch >= warmup_epochs:
        t0 = time.time()

    # Forward pass over all nodes; the loss only uses the nodes selected by `mask`.
    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= warmup_epochs:
        dur.append(time.time() - t0)

    # np.mean([]) yields nan but also emits RuntimeWarnings, so report nan
    # explicitly until the first timed epoch has been recorded.
    mean_dur = np.mean(dur) if dur else float("nan")
    print(
        "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), mean_dur
        )
    )

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 00000 | Loss 1.9461 | Time(s) nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 00001 | Loss 1.9443 | Time(s) nan
Epoch 00002 | Loss 1.9426 | Time(s) nan
Epoch 00003 | Loss 1.9408 | Time(s) 0.0899
Epoch 00004 | Loss 1.9390 | Time(s) 0.0956
Epoch 00005 | Loss 1.9372 | Time(s) 0.0940
Epoch 00006 | Loss 1.9354 | Time(s) 0.0969
Epoch 00007 | Loss 1.9336 | Time(s) 0.0964
Epoch 00008 | Loss 1.9319 | Time(s) 0.0970
Epoch 00009 | Loss 1.9301 | Time(s) 0.0965
Epoch 00010 | Loss 1.9283 | Time(s) 0.0965
Epoch 00011 | Loss 1.9264 | Time(s) 0.0957
Epoch 00012 | Loss 1.9246 | Time(s) 0.0947
Epoch 00013 | Loss 1.9228 | Time(s) 0.0943
Epoch 00014 | Loss 1.9210 | Time(s) 0.0939
Epoch 00015 | Loss 1.9191 | Time(s) 0.0933
Epoch 00016 | Loss 1.9173 | Time(s) 0.0931
Epoch 00017 | Loss 1.9154 | Time(s) 0.0928
Epoch 00018 | Loss 1.9135 | Time(s) 0.0922
Epoch 00019 | Loss 1.9117 | Time(s) 0.0932
Epoch 00020 | Loss 1.9098 | Time(s) 0.0930
Epoch 00021 | Loss 1.9078 | Time(s) 0.0924
Epoch 00022 | Loss 1.9059 | Time(s) 0.0922
Epoch 00023 | Loss 1.9040 | Time(s) 0.0918
Epoch 00024 | Los