<a href="https://colab.research.google.com/github/Kevoen/Google_Colab_Rep/blob/master/GeometricLibrary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installing the pytorch-geometric library

In [13]:
# Install required packages.
!pip install -q torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.6.0.html
!pip install -q torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.6.0.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

  Building wheel for torch-geometric (setup.py) ... done


![](https://miro.medium.com/max/700/1*WbZ6BvixcLI0t6Zz3E4EIw.png)
<br>So there are 4 nodes in the graph, v1 … v4, each of which is associated with a 2-dimensional feature vector, and a label y indicating its class. These two can be represented as FloatTensors:
```python
x = torch.tensor([[2,1],[5,6],[3,7],[12,0]], dtype=torch.float)
y = torch.tensor([0,1,0,1], dtype=torch.float)
```
The graph connectivity (edge index) should be given in COO format: the first list contains the indices of the source nodes, and the second list contains the indices of the target nodes.
```python
edge_index = torch.tensor([[0,1,2,0,3],
              [1,0,1,3,2]],dtype=torch.long)
```
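Edges are often easier to write down as (source, target) pairs. The sketch below, which uses only plain PyTorch, shows one way to turn such a list into the `[2, num_edges]` COO layout described above. The name `edge_pairs` is introduced here for illustration only.

```python
import torch

# The same five directed edges as above, written as (source, target) pairs.
edge_pairs = [(0, 1), (1, 0), (2, 1), (0, 3), (3, 2)]

# Transpose [num_edges, 2] into the [2, num_edges] COO layout:
# row 0 holds the source nodes, row 1 holds the target nodes.
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()

print(edge_index.shape)
print(edge_index.tolist())
```

The `.contiguous()` call makes the transposed tensor contiguous in memory, which is the form usually passed on to `Data`.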

In [20]:
import torch
from torch_geometric.data import Data

x = torch.tensor([[2,1],[5,6],[3,7],[12,0]],dtype=torch.float)  # node features, shape=[node_num, node_feature_num]
y = torch.tensor([0,1,0,1],dtype=torch.float)   # node targets, shape=[node_num, *] or [1, *]

edge_index = torch.tensor([[0,1,2,3,0],
              [1,0,1,2,3]],dtype=torch.long)  # integer node indices in COO format
data = Data(x=x,y=y,edge_index=edge_index)

In [21]:
print(data['edge_index'])

tensor([[0, 1, 2, 3, 0],
        [1, 0, 1, 2, 3]])


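## Loading data: Data and DataLoader

### 1. PyG's Dataset inherits from torch.utils.data.Dataset and ships with many graph datasets. Taking TUDataset as an example, the following code loads the dataset; the root parameter sets where the data is downloaded. Each graph can be accessed by index.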

In [22]:
from torch_geometric.datasets import TUDataset   # load a dataset bundled with PyG
datasets = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Extracting /tmp/ENZYMES/ENZYMES/ENZYMES.zip
Processing...
Done!


In [28]:
data = datasets[0]
data

Data(edge_index=[2, 168], x=[37, 3], y=[1])

In [29]:
data.is_undirected()

True

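`Data(edge_index=[2, 168], x=[37, 3], y=[1])`<br>The contents of `Data` show that `datasets[0]` is an undirected graph with 37 nodes, 3 feature values per node, and 168/2 = 84 edges, and that it carries one graph-level class label.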
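
The halving of the edge count comes from how undirected graphs are stored: each undirected edge appears as two directed entries in `edge_index`, one per direction. The sketch below checks this on a toy graph with plain PyTorch. The helper `undirected_edge_count` is hypothetical and written for illustration; it is not part of PyG.

```python
import torch

def undirected_edge_count(edge_index):
    """Return (is_symmetric, number of distinct undirected edges)."""
    pairs = {(int(u), int(v)) for u, v in edge_index.t().tolist()}
    # The graph is undirected if every edge (u, v) also appears as (v, u).
    is_symmetric = all((v, u) in pairs for u, v in pairs)
    # Collapse both directions of an edge into one unordered pair.
    unique = {(min(u, v), max(u, v)) for u, v in pairs}
    return is_symmetric, len(unique)

# Toy path graph 0 - 1 - 2, stored as four directed entries.
toy = torch.tensor([[0, 1, 1, 2],
                    [1, 0, 2, 1]], dtype=torch.long)
symmetric, num_edges = undirected_edge_count(toy)
print(symmetric, num_edges)  # True 2
```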

In [30]:
train_datasets = datasets[:540]
test_datasets = datasets[540:]   # split the dataset, train:test = 9:1

In [32]:
print(len(train_datasets))
print(len(test_datasets))

540
60


If you are not sure whether the data has already been shuffled before splitting, use `shuffle()` to put it in random order.

In [33]:
datasets = datasets.shuffle()   # randomly reorder the dataset with shuffle()
print(len(datasets))

600


In [35]:
# equivalent to the shuffle() call above
perm = torch.randperm(len(datasets))    # random permutation of the dataset indices
datasets = datasets[perm]
print(len(datasets))

600
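
A shuffled 9:1 split can also be expressed purely in terms of indices. The sketch below is a minimal illustration in plain PyTorch. It assumes a dataset of 600 graphs, as above, and fixes the random seed (an addition not in the original cells) so that the split is reproducible.

```python
import torch

torch.manual_seed(0)            # assumption: fixed seed for a reproducible split

num_graphs = 600                # size of the dataset used above
perm = torch.randperm(num_graphs)

# The first 540 shuffled indices go to training, the remaining 60 to testing.
train_idx, test_idx = perm[:540], perm[540:]

print(len(train_idx), len(test_idx))
```

Indexing a dataset with such an index tensor, as in `datasets[perm]` above, then yields the corresponding subsets.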


### 2. Download another dataset, `Cora`, a benchmark dataset for semi-supervised node classification

In [36]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora',name='Cora')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [40]:
data = dataset[0]
data

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

In [42]:
data.is_undirected()

True

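`Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])`
<br>Compared with the previous example, there are three additional attributes, `test_mask`, `train_mask`, and `val_mask`:
- `train_mask` marks the nodes used for training
- `test_mask` marks the nodes used for testing
- `val_mask` marks the nodes used for validation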
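
Each mask has one entry per node, so it can be used to index node-level tensors directly. The sketch below illustrates this with plain PyTorch on a made-up graph of six nodes. The tensors are toy values, not data from Cora, and the masks are assumed to be boolean.

```python
import torch

# Toy labels for six nodes and two hypothetical boolean masks.
y = torch.tensor([0, 1, 2, 1, 0, 2])
train_mask = torch.tensor([True, True, False, False, False, False])
test_mask = torch.tensor([False, False, False, False, True, True])

# Boolean indexing keeps only the nodes selected by the mask.
train_labels = y[train_mask]
test_labels = y[test_mask]

print(int(train_mask.sum()), train_labels.tolist())  # 2 [0, 1]
print(int(test_mask.sum()), test_labels.tolist())    # 2 [0, 2]
```

In a training loop the same pattern is typically applied to the model output, so that the loss is computed only on the training nodes.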

### 3. Loading data with DataLoader

In [46]:
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:    # mini-batch
  print(batch)


Batch(batch=[1046], edge_index=[2, 4036], x=[1046, 21], y=[32])
Batch(batch=[980], edge_index=[2, 3814], x=[980, 21], y=[32])
Batch(batch=[1005], edge_index=[2, 3904], x=[1005, 21], y=[32])
Batch(batch=[975], edge_index=[2, 3880], x=[975, 21], y=[32])
Batch(batch=[1096], edge_index=[2, 3884], x=[1096, 21], y=[32])
Batch(batch=[967], edge_index=[2, 3604], x=[967, 21], y=[32])
Batch(batch=[997], edge_index=[2, 3848], x=[997, 21], y=[32])
Batch(batch=[1085], edge_index=[2, 4194], x=[1085, 21], y=[32])
Batch(batch=[1000], edge_index=[2, 3880], x=[1000, 21], y=[32])
Batch(batch=[1011], edge_index=[2, 3880], x=[1011, 21], y=[32])
Batch(batch=[1084], edge_index=[2, 4236], x=[1084, 21], y=[32])
Batch(batch=[1043], edge_index=[2, 3878], x=[1043, 21], y=[32])
Batch(batch=[1082], edge_index=[2, 4062], x=[1082, 21], y=[32])
Batch(batch=[989], edge_index=[2, 3838], x=[989, 21], y=[32])
Batch(batch=[1050], edge_index=[2, 3766], x=[1050, 21], y=[32])
Batch(batch=[1126], edge_index=[2, 4230], x=[1126,
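
Each `Batch` printed above holds 32 graphs merged into one large disconnected graph, with a `batch` vector recording which graph each node belongs to. The sketch below reproduces the idea in plain PyTorch on two toy graphs. It illustrates the concept under that assumption and is not PyG's actual implementation.

```python
import torch

# Two toy graphs: one with 2 nodes, one with 3 nodes.
graphs = [
    (torch.tensor([[0, 1], [1, 0]]), 2),
    (torch.tensor([[0, 1, 2], [1, 2, 0]]), 3),
]

shifted_edges, batch_parts, offset = [], [], 0
for graph_id, (edges, num_nodes) in enumerate(graphs):
    # Shift node indices so that each graph gets its own index range.
    shifted_edges.append(edges + offset)
    # Record the graph id once per node.
    batch_parts.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
    offset += num_nodes

edge_index = torch.cat(shifted_edges, dim=1)
batch = torch.cat(batch_parts)

print(edge_index.tolist())  # [[0, 1, 2, 3, 4], [1, 0, 3, 4, 2]]
print(batch.tolist())       # [0, 0, 1, 1, 1]
```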

## Transforms

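`transforms` are a common form of data augmentation in computer vision. PyG has its own `transforms`, which take a `Data` object as input and return a `Data` object. A sequence of `transforms` can be chained with `torch_geometric.transforms.Compose`. Taking the `ShapeNet` dataset (about 17,000 point clouds with per-point labels, drawn from 16 shape categories) as an example, we can use transforms to generate a nearest-neighbor graph from the point clouds: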
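
To make the idea of a nearest-neighbor graph concrete, the sketch below builds one by hand with plain PyTorch. The function `knn_edge_index` is a hypothetical helper written for illustration. It uses a dense distance matrix, so it only suits small point sets, and the edge direction (neighbor to center point) is an assumption.

```python
import torch

def knn_edge_index(pos, k):
    """Connect every point to its k nearest neighbors (brute force)."""
    dist = torch.cdist(pos, pos)                 # pairwise distances, [N, N]
    dist.fill_diagonal_(float("inf"))            # a point is not its own neighbor
    neighbors = dist.topk(k, largest=False).indices   # [N, k]
    targets = torch.arange(pos.size(0)).repeat_interleave(k)
    sources = neighbors.reshape(-1)
    return torch.stack([sources, targets], dim=0)     # [2, N * k]

# Four points on a line; with k=1 each point links to its closest neighbor.
pos = torch.tensor([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [7.0, 0.0]])
edge_index = knn_edge_index(pos, k=1)
print(edge_index.tolist())  # [[1, 0, 1, 2], [0, 1, 2, 3]]
```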

In [None]:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'], pre_transform=T.KNNGraph(k=6))
print(dataset[0])

In [None]:
dataset = ShapeNet(root='/tmp/ShapeNet', 
                   categories=['Airplane'], 
                   pre_transform=T.KNNGraph(k=6), 
                   transform=T.RandomTranslate(0.01))   # data augmentation: small random perturbation of the coordinates
print(dataset[0])
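
The effect of the random translation can be pictured as adding a small random offset to every point coordinate. The sketch below shows that idea in plain PyTorch. The function `random_translate` is a hypothetical stand-in that assumes uniform noise in `[-max_shift, max_shift]`; it is not the library's implementation.

```python
import torch

def random_translate(pos, max_shift):
    """Shift each coordinate by uniform noise in [-max_shift, max_shift]."""
    noise = (torch.rand_like(pos) * 2 - 1) * max_shift
    return pos + noise

torch.manual_seed(0)
pos = torch.zeros(5, 3)                  # five toy 3D points at the origin
augmented = random_translate(pos, 0.01)

# Every coordinate moves by at most max_shift.
print(bool((augmented - pos).abs().max() <= 0.01))  # True
```

Because the noise is drawn anew on every call, applying such a perturbation each time a sample is accessed produces a slightly different point cloud every epoch, whereas a one-time preprocessing step such as graph construction only needs to run once.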