## Tutorial 3: Interfacing Graph `Data` with PyTorch Geometric `DataLoader`

In this tutorial, we will learn how to use the PyTorch Geometric (PyG) `DataLoader` to load graph data for mini-batch training. \
Q) Why would we want to use the PyG `DataLoader`? \
A) It inherits from the PyTorch `DataLoader`, so it is used in the same way and is backed by the PyTorch community.

Let's start the tutorial by implementing a graph dataset.

In [1]:
import warnings
warnings.filterwarnings(action='ignore') 

In [2]:
import torch
from torch.utils.data import Dataset
from torch_geometric.utils import erdos_renyi_graph
from torch_geometric.data import Data

In [3]:
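def generate_er_graph(num_nodes: int,
                      edge_prob: float,
                      feat_dim: int = 16):
    # Sample an Erdős-Rényi random graph; edge_idx has shape [2, num_edges]
    edge_idx = erdos_renyi_graph(num_nodes=num_nodes, edge_prob=edge_prob)
    x = torch.randn(num_nodes, feat_dim)  # random node features
    y = (x.sum() / num_nodes).view(1, 1)  # graph-level target of shape [1, 1]
    dummy = torch.randn(num_nodes, 32)  # extra attribute, excluded from batches later
    g = Data(x=x, edge_index=edge_idx, y=y, dummy=dummy)
    return g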

In [4]:
class ERDataset(Dataset):
    
    def __init__(self, 
                 num_graphs:int,
                 min_num_nodes: int = 32,
                 max_num_nodes: int = 64,
                 edge_prob: float = 0.3):
        
        # Sample one node count per graph; torch.randint excludes the upper bound
        num_nodes = torch.randint(min_num_nodes, max_num_nodes, (num_graphs,))
        self.gs = [
            generate_er_graph(int(n), edge_prob) for n in num_nodes
        ]
        
    def __getitem__(self, index):
        return self.gs[index]
        
    def __len__(self):
        return len(self.gs)

In [5]:
dataset = ERDataset(128)
print(dataset[0])

Data(x=[60, 16], edge_index=[2, 1038], y=[1, 1], dummy=[60, 32])


## PyG `DataLoader`

As mentioned earlier, the PyG `DataLoader` inherits from the PyTorch `DataLoader`, which means that we can
pass any positional or keyword arguments that the PyTorch `DataLoader` supports. Furthermore, the PyG `DataLoader`
supports graph mini-batching through a custom `collate_fn`, which is essentially a well-engineered version of `Batch.from_data_list`.
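
To make the batching idea concrete, here is a dependency-free sketch of what graph mini-batching amounts to: the graphs are merged into one large disconnected graph, edge indices are shifted by the number of nodes that precede each graph, and a `batch` vector records which graph each node came from. The function name `collate_graphs` and the plain-list graph representation are assumptions made for illustration; this is not PyG's actual implementation, which works on tensors and handles arbitrary attributes.

```python
def collate_graphs(graphs):
    """Merge graphs into one disjoint-union graph (illustrative only).

    Each graph is a dict with 'num_nodes' (int) and 'edge_index'
    (a pair of lists [sources, targets], mirroring the [2, num_edges] layout).
    """
    src, dst, batch, ptr = [], [], [], [0]
    for graph_id, g in enumerate(graphs):
        offset = ptr[-1]  # number of nodes in all previous graphs
        src += [s + offset for s in g["edge_index"][0]]
        dst += [t + offset for t in g["edge_index"][1]]
        batch += [graph_id] * g["num_nodes"]  # node -> graph assignment
        ptr.append(offset + g["num_nodes"])   # cumulative node counts
    return {"edge_index": [src, dst], "batch": batch, "ptr": ptr}


g0 = {"num_nodes": 3, "edge_index": [[0, 1], [1, 2]]}
g1 = {"num_nodes": 2, "edge_index": [[0], [1]]}
merged = collate_graphs([g0, g1])
print(merged["edge_index"])  # [[0, 1, 3], [1, 2, 4]]
print(merged["batch"])       # [0, 0, 0, 1, 1]
print(merged["ptr"])         # [0, 3, 5]
```

In this sketch `ptr` always has one more entry than there are graphs, and `batch` has one entry per node in the merged graph.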

In [6]:
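from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated

# follow_batch: attributes for which an extra `<name>_batch` assignment vector is created.
#   'batch' is not an attribute of our graphs, so it has no effect here;
#   the default `batch` vector (node -> graph) is always created.
# exclude_keys: attributes to drop from the batch.
dataloader = DataLoader(dataset,
                        follow_batch=['batch'],
                        exclude_keys=['dummy'],
                        batch_size=32, shuffle=True)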
batched_g = next(iter(dataloader))

print(batched_g)
print(f'Number of graphs in batch: {batched_g.num_graphs}')

DataBatch(x=[1429, 16], edge_index=[2, 19360], y=[32, 1], batch=[1429], ptr=[33])
Number of graphs in batch: 32


## Feed Batched Data to a GNN model

In [7]:
from torch_geometric.nn.models import GCN

model = GCN(in_channels=16,
            hidden_channels=32, 
            out_channels=1,
            num_layers=3)

pred = model(batched_g.x, batched_g.edge_index)
print(pred.shape) # [total number of nodes in the batch, output dim]

torch.Size([1429, 1])
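
Note that the model returns one prediction per node, while the targets `y` are defined per graph. One common way to bridge this gap is to average the node outputs within each graph using the `batch` vector; PyG provides `global_mean_pool` in `torch_geometric.nn` for this purpose. The sketch below shows the underlying idea in plain Python. The function name `mean_pool` and the toy values are illustrative assumptions, not part of the tutorial's code.

```python
def mean_pool(node_values, batch, num_graphs):
    """Average per-node scalars within each graph, as indexed by `batch`."""
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, graph_id in zip(node_values, batch):
        sums[graph_id] += value
        counts[graph_id] += 1
    # Guard against empty graphs to avoid division by zero.
    return [s / c if c > 0 else 0.0 for s, c in zip(sums, counts)]


node_pred = [1.0, 2.0, 3.0, 10.0, 20.0]  # one scalar prediction per node
batch = [0, 0, 0, 1, 1]                  # graph id of each node
graph_pred = mean_pool(node_pred, batch, num_graphs=2)
print(graph_pred)  # [2.0, 15.0]
```

With the tensors from the cell above, the analogous call would be `global_mean_pool(pred, batched_g.batch)`, which should produce a `[32, 1]` tensor that lines up with `batched_g.y`.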
