In [1]:
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl.nn as dglnn
from dgl.data import AsNodePredDataset
from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse
import ast
import sklearn.metrics
import numpy as np
import time

class SAGE(nn.Module):
    """GraphSAGE model built from mean-aggregator ``SAGEConv`` layers.

    The network is ``in_size -> hid_size -> ... -> hid_size -> out_size``.
    ReLU and dropout (p=0.5) follow every layer except the last one.

    Args:
        in_size: Size of the input node features.
        hid_size: Output size of every layer except the last.
        out_size: Output size of the last layer.
        num_layers: Number of ``SAGEConv`` layers to stack. Must be >= 1;
            with 1 the single layer maps ``in_size`` directly to ``out_size``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_size, hid_size, out_size,num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(
                "num_layers must be >= 1, got {}".format(num_layers))
        # Build exactly `num_layers` layers. The previous construction always
        # added an input and an output layer, so any value below 2 silently
        # produced a two-layer network.
        sizes = [in_size] + [hid_size] * (num_layers - 1) + [out_size]
        self.layers = nn.ModuleList([
            dglnn.SAGEConv(size_in, size_out, 'mean')
            for size_in, size_out in zip(sizes[:-1], sizes[1:])
        ])
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Run the sampled mini-batch through the layers.

        Args:
            blocks: One message-flow block per layer; layers and blocks are
                paired with ``zip``, so extras on either side are ignored.
            x: Input features for the source nodes of ``blocks[0]``.

        Returns:
            The output of the last paired layer.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g,device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings.

        Each layer is evaluated for every node of ``g`` (using all neighbours)
        before the next layer starts, and the per-layer results are staged in
        a CPU buffer. Gradients are not recorded. Dropout is still invoked, so
        put the module in ``eval()`` mode first to get deterministic output.

        Args:
            g: Graph whose ``ndata['feat']`` holds the input features.
            device: Device on which each layer is evaluated (a
                ``torch.device`` or anything ``torch.device()`` accepts).
            batch_size: Number of output nodes per batch.

        Returns:
            A CPU tensor with ``g.num_nodes()`` rows holding the output of the
            last layer.
        """
        # Normalise so the comparison below also works when a string is given.
        device = torch.device(device)
        feat = g.ndata['feat']
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
        dataloader = DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=0)
        buffer_device = torch.device('cpu')
        # Pin the staging buffer only when results are copied across devices.
        pin_memory = (buffer_device != device)

        # Without no_grad the in-place writes into `y` would keep the autograd
        # history of every batch alive for the whole pass over the graph.
        with torch.no_grad():
            for l, layer in enumerate(self.layers):
                is_last = (l == len(self.layers) - 1)
                y = torch.empty(
                    g.num_nodes(), self.out_size if is_last else self.hid_size,
                    device=buffer_device, pin_memory=pin_memory)
                feat = feat.to(device)
                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                    x = feat[input_nodes]
                    h = layer(blocks[0], x) # len(blocks) = 1
                    if not is_last:
                        h = F.relu(h)
                        h = self.dropout(h)
                    # The loader walks torch.arange(...) without shuffling, so
                    # the output nodes of a batch form a contiguous ID range.
                    y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
                # The output of this layer is the input of the next one.
                feat = y
        return y

def evaluate(model, graph, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and left there; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: Callable as ``model(blocks, x)`` returning per-class scores.
        graph: Unused; kept so existing call sites keep working.
        dataloader: Iterable of ``(input_nodes, output_nodes, blocks)`` where
            ``blocks[0].srcdata['feat']`` holds the input features and
            ``blocks[-1].dstdata['label']`` the target labels.

    Returns:
        Whatever ``sklearn.metrics.accuracy_score`` returns for the collected
        labels and arg-max predictions.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    ys = []
    y_hats = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'].cpu().numpy())
            y_hats.append(model(blocks, x).argmax(1).cpu().numpy())
    if not ys:
        # Previously this surfaced as an UnboundLocalError on `labels`.
        raise ValueError("evaluate() got a dataloader that yielded no batches")
    # Concatenate once after the loop; doing it per batch re-copied every
    # earlier batch on each iteration.
    predictions = np.concatenate(y_hats)
    labels = np.concatenate(ys)
    return sklearn.metrics.accuracy_score(labels, predictions)

In [2]:
# Load ogbn-papers100M from the local OGB cache and wrap it for node
# prediction; `g` is the single graph held by the dataset.
dataset = AsNodePredDataset(
    DglNodePropPredDataset(
        name='ogbn-papers100M',
        root="/home/bear/workspace/singleGNN/data/dataset",
    )
)
g = dataset[0]

# Alternative dataset (swap with the block above to use ogbn-products):
# dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products',root="/home/bear/workspace/singleGNN/data/dataset"))
# g = dataset[0]

In [3]:
# Show the graph summary: node/edge counts and the per-node data schemes.
g

Graph(num_nodes=111059956, num_edges=1615685872,
      ndata_schemes={'test_mask': Scheme(shape=(), dtype=torch.uint8), 'val_mask': Scheme(shape=(), dtype=torch.uint8), 'train_mask': Scheme(shape=(), dtype=torch.uint8), 'feat': Scheme(shape=(128,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.float32), 'year': Scheme(shape=(1,), dtype=torch.int64)}
      edata_schemes={})

In [21]:

# Dump the edge list as two raw binary files (source IDs and destination IDs).
# g.edges() is called once and unpacked: calling it separately for each
# endpoint materialised the full edge list twice.
# ndarray.tofile writes the raw values only (no dtype or shape header), so
# readers must know the element type in advance.
# NOTE(review): the output directory is named "products_bin" while the visible
# load cell selects ogbn-papers100M -- confirm the intended dataset.
src, dst = (endpoint.numpy() for endpoint in g.edges())
src.tofile("/raid/bear/products_bin/srcList.bin")
dst.tofile("/raid/bear/products_bin/dstList.bin")

In [22]:
# Dump node features and labels as raw binary files.
# ndarray.tofile() returns None, so chaining it onto the assignment left
# `feat` and `label` bound to None. Keep the arrays under their names and
# write them in a separate step so they can still be inspected afterwards
# (e.g. feat.shape, label.shape).
# NOTE(review): the graph summary lists 'label' as float32; the file is
# written with that dtype -- confirm the consumer expects float labels.
feat = g.ndata['feat'].numpy()
label = g.ndata['label'].numpy()
feat.tofile("/raid/bear/products_bin/feat.bin")
label.tofile("/raid/bear/products_bin/label.bin")

In [13]:
# For each split, collect the IDs of the nodes whose mask entry is non-zero
# and write them out as a raw binary file.
for mask_name, out_path in (
    ('train_mask', "/raid/bear/products_bin/trainIDs.bin"),
    ('val_mask', "/raid/bear/products_bin/valIDs.bin"),
    ('test_mask', "/raid/bear/products_bin/testIDs.bin"),
):
    split_ids = torch.nonzero(g.ndata[mask_name]).squeeze()
    split_ids.numpy().tofile(out_path)



In [None]:
# Build a two-layer GraphSAGE classifier (hidden width 256) on the first GPU.
# Input width comes from the feature matrix, output width from the dataset.
device = torch.device('cuda:0')
in_size = g.ndata['feat'].shape[1]
out_size = dataset.num_classes
model = SAGE(in_size, hid_size=256, out_size=out_size, num_layers=2)
model = model.to(device)

In [None]:
# Neighbour sampler with fan-out 10 for each of the two layers; features and
# labels are prefetched together with every sampled block.
sampler = NeighborSampler([10, 10],
                              prefetch_node_feats=['feat'],
                              prefetch_labels=['label'])
train_idx = dataset.train_idx
val_idx = dataset.val_idx
test_idx = dataset.test_idx
# Deliver the sampled batches on the device that holds the model parameters.
# This cell used to rebind `device` to the CPU, so `val_dataloader` produced
# CPU batches while `model` had already been moved to the GPU, which breaks
# evaluate(model, g, val_dataloader). Sampling itself still runs in the CPU
# worker processes because the graph `g` stays on the CPU.
device = next(model.parameters()).device
train_dataloader = DataLoader(g, train_idx, sampler, device=device,
                                  batch_size=1024, shuffle=True,
                                  drop_last=False, num_workers=8)
# Accuracy does not depend on batch order, so validation is not shuffled.
val_dataloader = DataLoader(g, val_idx, sampler, device=device,
                                batch_size=1024, shuffle=False,
                                drop_last=False, num_workers=8,
                                )
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

In [4]:
# Sample one epoch of training mini-batches (fan-out 15 and 50) with UVA and
# record, in global node IDs, every input node, every seed node and every
# sampled edge.
sampler = NeighborSampler([15, 50],
                              prefetch_node_feats=['feat'],
                              prefetch_labels=['label'])
train_idx = dataset.train_idx
val_idx = dataset.val_idx
test_idx = dataset.test_idx
device = torch.device('cuda:0')
train_dataloader = DataLoader(g, train_idx, sampler, device=device,
                                  batch_size=1024, shuffle=True,
                                  drop_last=False, num_workers=0,use_uva=True)

# Each list collects one tensor per batch (per block for `edges`).
# list.extend(tensor) would instead split every tensor into 0-d scalars,
# which is very slow and cannot be concatenated with torch.cat afterwards.
t_list= []
out_list = []
edges = [[],[]]

for epoch in range(1):
    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
        # Results are moved to the CPU so they can later index the CPU-resident
        # graph data and do not pile up in GPU memory.
        t_list.append(input_nodes.cpu())
        out_list.append(output_nodes.cpu())
        for block in blocks:
            # Block edges use batch-local IDs; `input_nodes` maps them back to
            # global IDs. NOTE(review): this relies on every block's node IDs
            # being a prefix of `input_nodes` -- confirm against the DGL docs.
            local_src, local_dst = block.edges()
            edges[0].append(input_nodes[local_src].cpu())
            edges[1].append(input_nodes[local_dst].cpu())

# Python lists have no .unique(); concatenate first, then deduplicate.
unique_input_elements = torch.cat(t_list).unique()
unique_seed_elements = torch.cat(out_list).unique()
      


In [None]:
# Gather the features of every sampled input node and the labels of every
# seed node, then build a graph from the recorded edges.
node_data = g.ndata
subG_feats = node_data['feat'][unique_input_elements]
subG_labels = node_data['label'][unique_seed_elements]
# edges[0] holds the source endpoints, edges[1] the destination endpoints.
src, dst = (torch.cat(endpoint_chunks) for endpoint_chunks in edges)
test_g = dgl.graph((src, dst))

In [None]:
# Rebuild the sub-graph tensors and report the feature / label shapes.
node_data = g.ndata
subG_feats = node_data['feat'][unique_input_elements]
subG_labels = node_data['label'][unique_seed_elements]
# edges[0] holds the source endpoints, edges[1] the destination endpoints.
src, dst = (torch.cat(endpoint_chunks) for endpoint_chunks in edges)
for sub_tensor in (subG_feats, subG_labels):
    print(sub_tensor.shape)

In [None]:
# Build the sampled graph and attach features / labels taken from `g`.
test_g = dgl.graph((src, dst))
# Node IDs in `test_g` are the original IDs from `g`, so node i keeps row i of
# the original feature and label tensors. The slice bound is read from the
# graph rather than hard-coded (it was 111059954): the node count depends on
# which edges were sampled, and a stale constant makes the assignment fail
# with a row-count mismatch.
num_sub_nodes = test_g.num_nodes()
test_g.ndata['feat'] = g.ndata['feat'][:num_sub_nodes]
# cross_entropy needs integer class indices; the graph summary shows labels
# stored as float32.
# NOTE(review): if unlabelled nodes carry NaN, the cast yields arbitrary
# integers for them -- confirm only labelled seeds are ever trained on.
test_g.ndata['label'] = g.ndata['label'][:num_sub_nodes].to(torch.int64)
print(test_g)

In [None]:
# Train on the sampled graph `test_g` for 10 epochs, seeding from the nodes
# collected earlier, and report validation accuracy after every epoch.
seeds = unique_seed_elements
seed_sampler = NeighborSampler([5, 5],
                              prefetch_node_feats=['feat'],
                              prefetch_labels=['label'])
trains_dataloader = DataLoader(test_g, seeds, seed_sampler, device=device,
                                  batch_size=1024, shuffle=True,
                                  drop_last=False, num_workers=8)

with trains_dataloader.enable_cpu_affinity():
    for epoch in range(10):
        # evaluate() leaves the model in eval mode, so re-enable training
        # behaviour (dropout) at the start of every epoch.
        model.train()
        total_loss = 0
        startTime = time.time()
        # Index of the last processed batch; stays 0 if the loader is empty.
        count = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(trains_dataloader):
            x = blocks[0].srcdata['feat']
            y = blocks[-1].dstdata['label']
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            count = it
        print("count=",count)
        print("time=",time.time()-startTime)
        acc = evaluate(model, g, val_dataloader)
        # float(acc) accepts both a NumPy scalar and a builtin float, whereas
        # acc.item() fails when accuracy_score returns a builtin float.
        # count + 1 is the number of batches seen; unlike `it`, it is always
        # defined.
        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
                .format(epoch, total_loss / (count+1), float(acc)))