In [1]:
from ogb.nodeproppred import DglNodePropPredDataset
from dgl.dataloading.neighbor import MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from dgl.nn import GATConv

Using backend: pytorch


In [6]:
d = DglNodePropPredDataset(name='ogbn-products')

In [7]:
d.num_classes

47

In [8]:
g, labels = d[0]
g, labels

(Graph(num_nodes=2449029, num_edges=123718280,
       ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32)}
       edata_schemes={}),
 tensor([[0],
         [1],
         [2],
         ...,
         [8],
         [2],
         [4]]))

In [5]:
type(g)

dgl.heterograph.DGLHeteroGraph

In [5]:
g.ndata['feat'].shape

torch.Size([169343, 128])

In [6]:
d.get_idx_split()

{'train': tensor([     0,      1,      2,  ..., 169145, 169148, 169251]),
 'valid': tensor([   349,    357,    366,  ..., 169185, 169261, 169296]),
 'test': tensor([   346,    398,    451,  ..., 169340, 169341, 169342])}

In [7]:
g.ndata['feat'].device

device(type='cpu')

In [8]:
labels.squeeze(1)

tensor([ 4,  5, 28,  ..., 10,  4,  1])

In [9]:
sampler = MultiLayerNeighborSampler([15, 15, 15])
split_idx = d.get_idx_split()
train_idx = split_idx["train"]#.to(device)
valid_idx = split_idx["valid"]#.to(device)
test_idx = split_idx["test"]#.to(device)
train_dataloader = NodeDataLoader(g,
                                  train_idx,
                                  sampler,
                                  batch_size=2048,
                                  shuffle=True,
                                  drop_last=False,
#                                       num_workers=4,
                                  )    

In [39]:
train_idx.shape

torch.Size([90941])

In [56]:
input_nodes, seeds, blocks = next(iter(train_dataloader))

In [15]:
input_nodes.shape

torch.Size([871532])

In [61]:
seeds

tensor([ 62152,  24565, 118068,  ..., 128546,  25811, 181088])

In [57]:
blocks

[Block(num_src_nodes=880459, num_dst_nodes=176257, num_edges=2582594),
 Block(num_src_nodes=176257, num_dst_nodes=15676, num_edges=230989),
 Block(num_src_nodes=15676, num_dst_nodes=1024, num_edges=15050)]

In [58]:
blocks[0].number_of_dst_nodes()

176257

In [59]:
blocks[1].ndata['feat']['_N'].shape

torch.Size([176257, 100])

In [54]:
blocks[0].dstdata['feat'][0], blocks[0].srcdata['feat'][0]

(tensor([ 0.5732, -0.3104,  0.3387,  1.4321, -0.3140,  1.0019, -0.4314,  0.0523,
         -0.9290,  0.2076, -0.2914,  1.1955,  0.5039,  0.6504,  0.0198,  0.3723,
         -0.4141,  0.6265, -0.2357,  0.0458,  0.0891, -0.2032, -0.0938, -0.5144,
          0.3038, -0.6644, -0.6155,  0.2167,  0.6638,  0.7903, -0.2806, -0.2437,
         -0.4158,  1.0209, -0.3888,  0.7717, -0.1969, -0.5606,  0.3234,  0.6431,
          0.6773, -0.3765, -0.5819,  0.3278, -1.2331,  0.5929,  0.4997,  0.5553,
          0.3544, -0.8994,  0.0871,  2.0412,  1.9411, -0.4537, -0.0190, -1.5940,
          1.0367,  0.1635, -1.1756, -0.3969,  0.9795, -0.3577,  1.0506,  0.9920,
         -1.2444, -0.1570,  0.8460, -0.1106,  0.5636, -1.5201,  1.7386,  1.2931,
         -0.8524,  0.8753, -0.6568,  0.8314,  0.2860,  2.6468, -0.6079, -0.4724,
          0.6979, -1.0838, -2.3402, -0.2772, -1.2829, -0.1060, -1.6026, -0.6662,
          0.9065, -0.0234, -1.1551,  0.8833, -0.2914, -0.0225,  0.3380,  2.1940,
          0.9189, -0.5383, -

In [60]:
blocks[0].dstnodes(), blocks[0].srcnodes()

(tensor([     0,      1,      2,  ..., 176254, 176255, 176256]),
 tensor([     0,      1,      2,  ..., 880456, 880457, 880458]))

In [12]:
gat = GATConv(128,
        64,
        num_heads=3,
#         residual=True,
        allow_zero_in_degree=True)
gat

GATConv(
  (fc): Linear(in_features=128, out_features=192, bias=False)
  (feat_drop): Dropout(p=0.0, inplace=False)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (leaky_relu): LeakyReLU(negative_slope=0.2)
)

In [13]:
rst = gat(blocks[0], blocks[0].srcdata['feat'])

In [37]:
blocks

[Block(num_src_nodes=33991, num_dst_nodes=15218, num_edges=56051),
 Block(num_src_nodes=15218, num_dst_nodes=4416, num_edges=18135),
 Block(num_src_nodes=4416, num_dst_nodes=829, num_edges=3688)]

In [36]:
import gc
for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
    print(f"{step}: {blocks}")
    h = blocks[0].srcdata['feat']
    for i in range(len(blocks)):
        rst = gat(blocks[0], blocks[0].srcdata['feat'])
#         h = rst
#         print(rst.shape)


0: [Block(num_src_nodes=55811, num_dst_nodes=31271, num_edges=112971), Block(num_src_nodes=31271, num_dst_nodes=10372, num_edges=42329), Block(num_src_nodes=10372, num_dst_nodes=2048, num_edges=8878)]
1: [Block(num_src_nodes=52465, num_dst_nodes=29198, num_edges=102647), Block(num_src_nodes=29198, num_dst_nodes=9987, num_edges=39279), Block(num_src_nodes=9987, num_dst_nodes=2048, num_edges=8458)]
2: [Block(num_src_nodes=54001, num_dst_nodes=29980, num_edges=105787), Block(num_src_nodes=29980, num_dst_nodes=10145, num_edges=39893), Block(num_src_nodes=10145, num_dst_nodes=2048, num_edges=8621)]
3: [Block(num_src_nodes=54969, num_dst_nodes=31035, num_edges=108180), Block(num_src_nodes=31035, num_dst_nodes=10408, num_edges=41269), Block(num_src_nodes=10408, num_dst_nodes=2048, num_edges=8832)]
4: [Block(num_src_nodes=53765, num_dst_nodes=30379, num_edges=108058), Block(num_src_nodes=30379, num_dst_nodes=10139, num_edges=40819), Block(num_src_nodes=10139, num_dst_nodes=2048, num_edges=8633

In [70]:
rst.shape

torch.Size([176257, 3, 64])

In [20]:
blocks[0].__dir__()

['_graph',
 '_canonical_etypes',
 '_batch_num_nodes',
 '_batch_num_edges',
 '_ntypes',
 '_srctypes_invmap',
 '_dsttypes_invmap',
 '_is_unibipartite',
 '_etypes',
 '_etype2canonical',
 '_etypes_invmap',
 '_node_frames',
 '_edge_frames',
 '__module__',
 '__doc__',
 'is_block',
 '__repr__',
 '__init__',
 '_init',
 '__setstate__',
 '__copy__',
 'add_nodes',
 'add_edge',
 'add_edges',
 'remove_edges',
 'remove_nodes',
 '_reset_cached_info',
 'is_unibipartite',
 'ntypes',
 'etypes',
 'canonical_etypes',
 'srctypes',
 'dsttypes',
 'metagraph',
 'to_canonical_etype',
 'get_ntype_id',
 'get_ntype_id_from_src',
 'get_ntype_id_from_dst',
 'get_etype_id',
 'batch_size',
 'batch_num_nodes',
 'set_batch_num_nodes',
 'batch_num_edges',
 'set_batch_num_edges',
 'nodes',
 'srcnodes',
 'dstnodes',
 'ndata',
 'srcdata',
 'dstdata',
 'edges',
 'edata',
 '_find_etypes',
 '__getitem__',
 'number_of_nodes',
 'num_nodes',
 'number_of_src_nodes',
 'num_src_nodes',
 'number_of_dst_nodes',
 'num_dst_nodes',
 'nu