In [364]:
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.parallel import DistributedDataParallel
from torch_sparse import matmul, fill_diag
from torch_sparse import sum as sparsesum
from tqdm import tqdm

from torch_geometric.datasets import Reddit2
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborSampler, ClusterLoader
from torch_geometric.nn import Linear, GCNConv
from torch_geometric.loader import ClusterData
from ogb.nodeproppred import PygNodePropPredDataset

# Stand-alone ToSparseTensor transform instance.
# NOTE(review): the dataset-loading cell builds its own T.ToSparseTensor()
# rather than reusing this object; in the cells shown here `to_sparse` is unused.
to_sparse = T.ToSparseTensor()

In [366]:
# One METIS partition per visible GPU. Check this first: it is cheap, and with
# zero devices the partitioning cells below cannot work, so fail before the
# expensive dataset load instead of deep inside METIS.
world_size = torch.cuda.device_count()
if world_size < 1:
    raise RuntimeError(
        "No CUDA devices visible: world_size (the number of partitions) must be >= 1."
    )

# Dataset root is $DATA_DIR/ogb. Without this check an unset DATA_DIR surfaces
# as an opaque TypeError from os.path.join(None, 'ogb').
data_dir = os.environ.get('DATA_DIR')
if data_dir is None:
    raise RuntimeError(
        "Environment variable DATA_DIR is not set; set it to the data root "
        "(ogbn-products is stored under $DATA_DIR/ogb)."
    )
path = os.path.join(data_dir, 'ogb')

# Convert edge_index to a sparse adjacency (`adj_t`), which METIS consumes later.
transform = T.Compose([T.ToSparseTensor()])
dataset = PygNodePropPredDataset(name='ogbn-products', root=path, transform=transform)
num_classes = dataset.num_classes

In [None]:
# Alternative partitioning with PyG's ClusterData / ClusterLoader (one cluster
# per GPU, cached in the dataset's processed_dir).
# NOTE(review): this cell was never executed (In [None]), and the next cell
# reassigns `data_list` from SubgraphLoader, whose items provide the
# `batch_size` / `n_id` attributes the halo analysis below reads.
cluster_data = ClusterData(
    dataset[0],
    num_parts=world_size,
    recursive=False,
    save_dir=dataset.processed_dir,
)
cluster_loader = ClusterLoader(cluster_data, batch_size=1, shuffle=False)
data_list = list(cluster_loader)

In [368]:
from torch_geometric_autoscale import SubgraphLoader, metis, permute

# Partition the full graph into `world_size` parts with METIS and renumber the
# nodes with `perm`. Later cells treat partition p as owning the node ids
# ptr[p] .. ptr[p+1]-1. Then materialize one subgraph per partition.
graph = dataset[0]
perm, ptr = metis(graph.adj_t, world_size, log=True)
data = permute(graph, perm, log=True)
del graph  # keep only the permuted graph referenced
data_list = list(SubgraphLoader(data, ptr, batch_size=1, shuffle=False))

Computing METIS partitioning with 4 parts... Done! [15.65s]
Permuting data... Done! [27.85s]
Pre-processing subgraphs... Done! [3.54s]


In [463]:
for sub_data in data_list:
    # Per partition: its own node count (batch_size), then the node count of its
    # subgraph (presumably own nodes plus referenced halo nodes), then a blank line.
    print(sub_data.batch_size, sub_data.data.num_nodes, sep='\n', end='\n\n')

613761
814388

622604
814385

601931
882541

610733
903601



In [460]:
# Nodes per partition: consecutive differences of the partition pointer.
ptr.diff()

tensor([613761, 622604, 601931, 610733])

In [371]:
# Inspect partition 0: the first `batch_size` entries of n_id are its own nodes,
# and the remainder are nodes referenced from outside the partition.
sub_data = data_list[0]
batch_size, n_id = sub_data.batch_size, sub_data.n_id

# Disabled sanity check (kept for reference): the subgraph's out-of-partition
# features should match the permuted full graph's rows at those ids.
# torch.equal(sub_data.data.x[batch_size:], data.x[n_id[batch_size:]])

out_of_batch = n_id[batch_size:]
out_of_batch

tensor([ 677613, 1052124, 1937103,  ..., 1973056, 1305320, 1309691])

In [445]:
from torch_geometric.utils import mask_to_index

# Halo-exchange bookkeeping for the partitions in `data_list`.
# Partition p owns the global ids ptr[p] .. ptr[p+1]-1.
# For partition dst, sub_data.n_id[sub_data.batch_size:] are the nodes it
# references from outside itself.
#
# halo_nodes[src][dst] : LongTensor of src-local row indices (global id - ptr[src])
#                        of the nodes dst needs from src. Every entry is a tensor;
#                        diagonal entries stay empty because the loop skips src == dst.
# reverse_idx[dst]     : one index tensor per src (ascending src, dst skipped) giving
#                        the positions inside dst's out-of-partition block where the
#                        rows received from src belong.
# halo_counts[dst]     : number of halo rows dst receives from each src, in the
#                        same order as reverse_idx[dst].
# package_sizes[dst]   : the largest of those per-source counts for dst.
halo_nodes = [[torch.empty(0, dtype=torch.long) for _ in range(world_size)]
              for _ in range(world_size)]
reverse_idx = [[] for _ in range(world_size)]
halo_counts = [[] for _ in range(world_size)]

package_sizes = []
for batch_id, sub_data in enumerate(data_list):
    out_of_batch = sub_data.n_id[sub_data.batch_size:]

    package_size = 0
    for i in range(world_size):
        if batch_id == i:
            continue
        # Halo nodes of partition `batch_id` that are owned by partition i.
        mask = (out_of_batch >= ptr[i]) & (out_of_batch < ptr[i + 1])
        num_halos = mask.sum().item()
        package_size = max(package_size, num_halos)

        halo_nodes[i][batch_id] = out_of_batch[mask] - ptr[i]
        reverse_idx[batch_id].append(mask_to_index(mask))
        halo_counts[batch_id].append(num_halos)
    package_sizes.append(package_size)
package_sizes

[83025, 75734, 160531, 167743]

In [443]:
# Pad a halo "package" to a common length (presumably so that packages can be
# exchanged with fixed-size collectives). The length must be at least the
# largest package: F.pad with a negative pad amount silently *crops* rows, and
# the package_sizes printed above go well beyond the previously hard-coded 80000.
package_size = max(package_sizes)
temp = data_list[0].data.x[halo_nodes[0][1]]
assert temp.size(0) <= package_size, "package_size too small: F.pad would crop halo rows"
temp_ = F.pad(temp, (0, 0, 0, package_size - temp.size(0)))
temp_[:temp.size(0)].equal(temp)

True

In [375]:
# Gather onto root partition 0: every other partition src contributes the rows
# of its local features that partition 0 references (halo_nodes[src][0]),
# concatenated in ascending src order.
pieces = []
for src in range(1, world_size):
    pieces.append(data_list[src].data.x[halo_nodes[src][0]])
scattered_nodes = torch.cat(pieces)
scattered_nodes.shape

torch.Size([200627, 100])

In [376]:
# The gathered rows must equal partition 0's own out-of-partition features,
# reordered by reverse_idx[0]. Read partition 0's batch_size directly rather
# than relying on a generic global `batch_size` set in an earlier cell.
root_batch_size = data_list[0].batch_size
torch.equal(scattered_nodes, data_list[0].data.x[root_batch_size:][torch.cat(reverse_idx[0])])

True

In [377]:
# Update the current layer's halo-node representations of partition 0 from the
# gathered rows, i.e.
#     data.x[batch_size:][torch.cat(reverse_idx[0])] = scattered_nodes
# and check that this reproduces the original block exactly.
# The buffer starts as NaN (assumes floating-point features): any row that
# reverse_idx[0] fails to cover keeps NaN and makes the comparison False,
# whereas an all-ones sentinel could coincide with real feature rows.
root_part = data_list[0]
halo_x = root_part.data.x[root_part.batch_size:]
temp = torch.full_like(halo_x, float('nan'))
temp[torch.cat(reverse_idx[0])] = scattered_nodes
temp.equal(halo_x)

True