# Data Handling of Graphs

In [None]:
# Just getting my feet wet with graphs
from torch_geometric.data import Data
from torch import tensor
import torch

# Toy graph with three nodes. Every undirected edge is listed as two directed
# (source, target) pairs and then transposed into the 2 x num_edges layout.
edge_index = tensor([[0, 1],
                     [1, 0],
                     [1, 2],
                     [2, 1]], dtype=torch.long).t().contiguous()
# One scalar feature per node, shaped (3, 1).
x = tensor([-1, 0, 1], dtype=torch.float).unsqueeze(-1)

data = Data(x=x, edge_index=edge_index)
print(data)


# Stored attributes can be listed and read dictionary-style.
print(f'Keys: {data.keys}')
print(f"Data in graph: {data['x']}")

# Iterating over the graph object yields (name, value) pairs.
for key, item in data:
    print("{} found in data".format(key))

# Only x and edge_index were passed above, so check whether 'edge_attr' is present.
print(f"Edge attribute in data: {'edge_attr' in data}")

# Common Benchmark Datasets

In [None]:
from torch_geometric.datasets import TUDataset, Planetoid

# Two benchmark datasets, each stored under its own root directory in /tmp.
enzymes_dataset = TUDataset(name='ENZYMES', root='/tmp/ENZYMES')
cora_dataset = Planetoid(name='Cora', root='/tmp/Cora')

In [None]:
# Print the same one-line summary for each benchmark dataset.
for ds in (enzymes_dataset, cora_dataset):
    print(f'Dataset: {ds.name} | size : {len(ds)} | # of classes: {ds.num_classes} | # of node features {ds.num_node_features}')

# shuffle() hands back a randomly permuted dataset (comparable to indexing with a randperm).
enzymes_dataset, cora_dataset = enzymes_dataset.shuffle(), cora_dataset.shuffle()

# Mini-batches

In [None]:
from torch_geometric.data import DataLoader

loader = DataLoader(enzymes_dataset, shuffle=True, batch_size=32)

# Only the first mini-batch is inspected.
batch = next(iter(loader))
print(batch)
print(batch.num_graphs)

 <h1 style="text-align: center;"> batch is a column vector which maps each node to its respective graph in the batch: </h1>
 <p style="text-align: center;"> $\mathrm{batch} = {\begin{bmatrix} 0 & \cdots & 0 & 1 & \cdots & n - 2 & n -1 & \cdots & n - 1 \end{bmatrix}}^{\top}$ <br>You can use it to, e.g., average node features in the node dimension for each graph individually:</p>

In [None]:
from torch_scatter import scatter_mean


# Average the node features of every graph in each mini-batch.
# Dedicated names (`batch`, `pooled_x`) keep this loop from overwriting the
# single-graph `data` object and its feature tensor `x` built in the first cell.
for batch in loader:
    print(f'Data: {batch}')
    print(f'# of graphs: {batch.num_graphs}')
    print(f'Size before average: {batch.x.size()}')
    # batch.batch holds the graph index of each node; rows of x sharing an
    # index are averaged together along dim 0.
    pooled_x = scatter_mean(batch.x, batch.batch, dim=0)
    print(f'Size after average: {pooled_x.size()}')
    print('=' * 50)

# Data Transforms 
#### Let’s look at an example, where we apply transforms on the ShapeNet dataset (containing 17,000 3D shape point clouds and per point labels from 16 shape categories).

In [32]:
from torch_geometric.datasets import ShapeNet
import open3d as o3d
# Load only the 'Airplane' category; no transform is passed at this point.
dataset = ShapeNet(categories=['Airplane'], root='/tmp/ShapeNet')

# The first sample is the cell's displayed value.
dataset[0]

Data(pos=[2518, 3], y=[2518])

## Run this to visualize the first point cloud

In [15]:
# Wrap the positions of the first sample as Open3D points and show them.
points_xyz = o3d.utility.Vector3dVector(dataset[0].pos.numpy())
pcd = o3d.geometry.PointCloud()
pcd.points = points_xyz
o3d.visualization.draw_geometries([pcd])

## We can convert the point cloud dataset into a graph dataset by generating nearest neighbor graphs from the point clouds via transforms:

In [35]:
import torch_geometric.transforms as T
# Build the k-NN graph version in its own root directory.
# Processed files are cached per root and reused when they already exist, so
# pointing at '/tmp/ShapeNet' (already processed WITHOUT a pre_transform by the
# earlier cell) would skip KNNGraph and return point clouds without edge_index.
# Trade-off: the raw archives are downloaded once more into the new root.
dataset = ShapeNet(root='/tmp/ShapeNet_KNN', categories=['Airplane'],
                   pre_transform=T.KNNGraph(k=6))

Downloading https://shapenet.cs.stanford.edu/iccv17/partseg/train_data.zip
Extracting /tmp/ShapeNet/raw/train_data.zip
Downloading https://shapenet.cs.stanford.edu/iccv17/partseg/train_label.zip
Extracting /tmp/ShapeNet/raw/train_label.zip
Downloading https://shapenet.cs.stanford.edu/iccv17/partseg/val_data.zip
Extracting /tmp/ShapeNet/raw/val_data.zip
Downloading https://shapenet.cs.stanford.edu/iccv17/partseg/val_label.zip
Extracting /tmp/ShapeNet/raw/val_label.zip
Downloading https://shapenet.cs.stanford.edu/iccv17/partseg/test_data.zip
Extracting /tmp/ShapeNet/raw/test_data.zip
Downloading https://shapenet.cs.stanford.edu/iccv17/partseg/test_label.zip
Extracting /tmp/ShapeNet/raw/test_label.zip
Processing...
Done!


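## We use the `pre_transform` to convert the data before saving it to disk (leading to faster loading times). Processed files are cached under `root`, so the next time the dataset is initialized from the same root it will already contain graph edges, even if you do not pass any transform. Conversely, if `root` already holds files processed without the `pre_transform`, they are reused as-is; delete the `processed` directory (or use a fresh root) to regenerate them.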

In [51]:
# Show the dataset summary next to its first sample.
first_sample = dataset[0]
dataset, first_sample

(ShapeNet(2349, categories=['Airplane']),
 Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518]))

In [53]:
# Same visualization as for the first sample, now for the sample at index 100.
sample_xyz = dataset[100].pos.numpy()
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(sample_xyz)
o3d.visualization.draw_geometries([pcd])