In [1]:
# Import torch & Check CUDA availability
import torch

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())

True
1
0


In [2]:
# Get CUDA device name
print(torch.cuda.device(0))
print(torch.cuda.get_device_name(0))

<torch.cuda.device object at 0x7f9a9a7e0150>
NVIDIA A30


#### Import Reddit

In [3]:
from torch_geometric.datasets import Reddit
import torch_geometric.transforms as T

# Import dataset from PyTorch Geometric
dataset = Reddit(root="/dfs6/pub/seminl1/Reddit", transform=T.ToSparseTensor())
data = dataset[0]
data = data.pin_memory()

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')

Dataset: Reddit()
-------------------
Number of graphs: 1
Number of nodes: 232965
Number of features: 602
Number of classes: 41

Graph:
------
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False


#### NeighborLoader

In [4]:
from torch_geometric.loader import NeighborLoader

# NeighborLoader
train_loader = NeighborLoader(
    data,
    num_neighbors=[5],
    batch_size=8192,
    pin_memory=True,
)

In [5]:
# Print and save each subgraph
for i, subgraph in enumerate(train_loader):
    print(f'Subgraph {i}: {subgraph}')
    torch.save(subgraph, "/dfs6/pub/seminl1/Reddit/train_loader_neighbor_sampling_{0}.pt".format(i))

Subgraph 0: Data(x=[40024, 602], y=[40024], train_mask=[40024], val_mask=[40024], test_mask=[40024], adj_t=[40024, 40024, nnz=40440], n_id=[40024], e_id=[40440], num_sampled_nodes=[2], num_sampled_edges=[1], input_id=[8192], batch_size=8192)
Subgraph 1: Data(x=[40034, 602], y=[40034], train_mask=[40034], val_mask=[40034], test_mask=[40034], adj_t=[40034, 40034, nnz=40348], n_id=[40034], e_id=[40348], num_sampled_nodes=[2], num_sampled_edges=[1], input_id=[8192], batch_size=8192)
Subgraph 2: Data(x=[40031, 602], y=[40031], train_mask=[40031], val_mask=[40031], test_mask=[40031], adj_t=[40031, 40031, nnz=40332], n_id=[40031], e_id=[40332], num_sampled_nodes=[2], num_sampled_edges=[1], input_id=[8192], batch_size=8192)
Subgraph 3: Data(x=[40033, 602], y=[40033], train_mask=[40033], val_mask=[40033], test_mask=[40033], adj_t=[40033, 40033, nnz=40372], n_id=[40033], e_id=[40372], num_sampled_nodes=[2], num_sampled_edges=[1], input_id=[8192], batch_size=8192)
Subgraph 4: Data(x=[40103, 602],

In [6]:
train_loader = torch.load("/dfs6/pub/seminl1/Reddit/train_loader_neighbor_sampling_0.pt")

In [7]:
print(train_loader)

Data(x=[40024, 602], y=[40024], train_mask=[40024], val_mask=[40024], test_mask=[40024], adj_t=[40024, 40024, nnz=40440], n_id=[40024], e_id=[40440], num_sampled_nodes=[2], num_sampled_edges=[1], input_id=[8192], batch_size=8192)
