In [1]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Dataset
from typing import Optional, Any, Callable, List, Tuple, Union
from torch import Tensor
from torch_geometric.typing import OptTensor
import torch
import os
import pandas as pd
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
import shutil
from tqdm import tqdm
from torch_geometric.data.data import BaseData
from ogb.lsc import PCQM4MEvaluator
from gin_graph import GINGraphPooling
from torch.utils.tensorboard import SummaryWriter

### 小图合并组成大图
`PyTorch Geometric`中将多个图封装成批（batch）的方式是：把各个小图作为彼此不相连的连通组件（connected component）合并，构建一个大图。于是各小图的邻接矩阵以分块对角（block-diagonal）的形式排列在大图邻接矩阵的对角线上。
#### 小图的属性增值与拼接
将小图存储到大图中时需要对小图的属性做一些修改，最显著的例子就是对节点序号做增值。在最一般的情况下，`PyTorch Geometric`的`DataLoader`类会自动对`edge_index`张量做增值，增加的值为当前图之前所有图的累积节点数量。比方说，现在对第$k$个图的`edge_index`张量做增值，前面$k-1$个图的累积节点数量为$n$，那么第$k$个图的`edge_index`张量的增值就是$n$。增值后，所有图的`edge_index`张量（其形状为`[2, num_edges]`）沿第二维拼接起来。
然而，在一些特殊的场景中，我们希望根据需求修改这一行为。`PyTorch Geometric`允许我们通过覆盖`torch_geometric.data.Data.__inc__()`和`torch_geometric.data.Data.__cat_dim__()`方法来实现所需的行为。

案例：**图的匹配（Pairs of Graphs）**
如果你想在一个`Data`对象中存储多个图（例如用于图匹配等应用），我们需要确保这些图都能被正确地封装成批。例如，考虑将两个图——一个源图$G_s$和一个目标图$G_t$——存储在一个`Data`对象中。在这种情况下，`edge_index_s`应该根据源图$G_s$的节点数做增值，即`x_s.size(0)`；而`edge_index_t`应该根据目标图$G_t$的节点数做增值，即`x_t.size(0)`。

In [2]:
class PairData(Data):
    """One `Data` object carrying two graphs: a source graph (`*_s`) and a target graph (`*_t`).

    When batched, each graph's edge index is offset by the node count of its own
    graph (`x_s` rows for `edge_index_s`, `x_t` rows for `edge_index_t`) instead of
    by a single shared node count.
    """

    def __init__(self, edge_index_s, x_s, edge_index_t, x_t, **kwargs):
        super().__init__(**kwargs)
        # Source graph attributes first, then target graph attributes.
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Return the offset applied to `key` when several PairData objects are concatenated."""
        # Which node-feature tensor supplies the node count for each edge-index attribute.
        node_attr = {'edge_index_s': 'x_s', 'edge_index_t': 'x_t'}.get(key)
        if node_attr is None:
            return super().__inc__(key, value, *args, **kwargs)
        return getattr(self, node_attr).size(0)

In [3]:
# Source graph G_s: node 0 linked to nodes 1..4 -> 5 nodes, 16 features each.
edge_index_s = torch.tensor([[0] * 4, list(range(1, 5))])
x_s = torch.randn(5, 16)
# Target graph G_t: node 0 linked to nodes 1..3 -> 4 nodes, 16 features each.
edge_index_t = torch.tensor([[0] * 3, list(range(1, 4))])
x_t = torch.randn(4, 16)

# Batch two copies of the same pair to observe how each edge index is offset.
data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data] * 2
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch, batch.edge_index_s, batch.x_s.shape, batch.edge_index_t, sep='\n')

PairDataBatch(edge_index_s=[2, 8], x_s=[10, 16], edge_index_t=[2, 6], x_t=[8, 16])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
        [1, 2, 3, 4, 6, 7, 8, 9]])
torch.Size([10, 16])
tensor([[0, 0, 0, 4, 4, 4],
        [1, 2, 3, 5, 6, 7]])


由于`PyTorch Geometric`无法识别`PairData`对象中实际包含的图，所以`batch`属性（将大图中的每个节点映射到其所属的小图）没有被正确创建。此时就需要`DataLoader`的`follow_batch`参数发挥作用：通过它，我们可以指定要为哪些属性维护批信息。
`follow_batch=['x_s', 'x_t']`成功地为节点特征`x_s`和`x_t`分别创建了名为`x_s_batch`和`x_t_batch`的归属向量（assignment vector）。这些信息可以用来在单个`Batch`对象中对多个图分别进行聚合操作，例如全局池化。

In [4]:
# follow_batch makes the loader build x_s_batch / x_t_batch vectors that map every
# node of the respective node set to the index of the graph it came from.
loader = DataLoader(data_list, follow_batch=['x_s', 'x_t'], batch_size=2)
batch = next(iter(loader))

print(batch, batch.x_s_batch, batch.x_t_batch, sep='\n')

PairDataBatch(edge_index_s=[2, 8], x_s=[10, 16], x_s_batch=[10], x_s_ptr=[3], edge_index_t=[2, 6], x_t=[8, 16], x_t_batch=[8], x_t_ptr=[3])
tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 1, 1, 1, 1])


#### 二部图
二部图的邻接矩阵定义了两种类型节点之间的连接关系。一般来说，两种类型的节点数量不必相同，于是二部图的邻接矩阵$A \in \{0,1\}^{N \times M}$可能不是方阵，即可能有$N \neq M$。
为了对二部图实现正确的封装成批，我们需要告诉`PyTorch Geometric`，它应该在`edge_index`中独立地为边的源节点和目标节点做增值操作。
其中，`edge_index[0]`（边的源节点）根据`x_s.size(0)`做增值运算，而`edge_index[1]`（边的目标节点）根据`x_t.size(0)`做增值运算。

In [5]:
class BipartiteData(Data):
    """A bipartite graph with source-node features `x_s` and target-node features `x_t`.

    `edge_index[0]` holds source-node indices (into `x_s`) and `edge_index[1]` holds
    target-node indices (into `x_t`), so when batching the two rows must be offset
    by different node counts.

    All constructor arguments default to None (and extra attributes may be passed as
    keyword arguments) so the class can also be instantiated without arguments.
    """

    def __init__(self, edge_index=None, x_s=None, x_t=None, **kwargs):
        super().__init__(**kwargs)
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Offset for `key` when concatenating: per-row node counts for `edge_index`."""
        if key == 'edge_index':
            # Shape [2, 1] broadcasts over the [2, num_edges] edge_index:
            # row 0 (source nodes) shifts by |x_s|, row 1 (target nodes) by |x_t|.
            return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
        # Forward any extra positional/keyword arguments instead of dropping them.
        return super().__inc__(key, value, *args, **kwargs)

In [6]:
# Edges go from 2 source nodes (rows of x_s) to 3 target nodes (rows of x_t).
edge_index = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]])
x_s = torch.randn(2, 16)
x_t = torch.randn(3, 16)

# Batch two copies to check that source and target rows are offset independently.
data = BipartiteData(edge_index, x_s, x_t)
data_list = [data] * 2
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch, batch.edge_index, sep='\n')

BipartiteDataBatch(edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16], batch=[6], ptr=[3])
tensor([[0, 0, 1, 1, 2, 2, 3, 3],
        [0, 1, 1, 2, 3, 4, 4, 5]])




### 使用PCQM4M-LSC数据集来实践
PCQM4M-LSC是一个分子图的量子化学性质回归数据集，共包含3,803,453个图；任务是根据分子（以SMILES字符串给出）预测其HOMO-LUMO能隙（`homolumogap`）。

In [7]:
# Turn off all RDKit ('rdApp.*') log channels so messages emitted while parsing
# SMILES strings in MyPCQM4MDataset.get do not flood the notebook output.
RDLogger.DisableLog('rdApp.*')


class MyPCQM4MDataset(Dataset):
    """PCQM4M-LSC molecules as PyG graphs, featurized lazily from SMILES.

    The raw CSV (`<root>/raw/data.csv.gz`) provides a `smiles` column and a
    `homolumogap` regression target. Every `get` call converts one SMILES string into
    a `Data` graph with `smiles2graph`; nothing beyond the raw files is stored on disk.
    """

    def __init__(self, root: str | None = None, transform: Callable[..., Any] | None = None, pre_transform: Callable[..., Any] | None = None, pre_filter: Callable[..., Any] | None = None, log: bool = True):
        # Set before Dataset.__init__, because the base class invokes download() when raw files are missing.
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        # Forward every hook so a caller-supplied transform/filter actually takes effect.
        super().__init__(root, transform, pre_transform, pre_filter, log=log)
        # self.raw_paths is the same location download() moves the CSV to.
        data_df = pd.read_csv(self.raw_paths[0])
        self.smiles_list = data_df['smiles']
        self.homolumogap_list = data_df['homolumogap']

    @property
    def raw_file_names(self) -> str | List[str] | Tuple:
        """The single raw file this dataset needs inside `raw/`."""
        return 'data.csv.gz'

    def download(self):
        """Fetch the OGB-LSC archive, unpack it into `root`, and move the CSV into `raw/`."""
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        # The extracted folder stays in place (get_idx_split reads split_dict.pt from it);
        # only the CSV is moved to where raw_paths expects it.
        shutil.move(os.path.join(self.root, 'pcqm4m_kddcup2021', 'raw', 'data.csv.gz'), self.raw_paths[0])

    def len(self) -> int:
        """Number of molecules listed in the raw CSV."""
        return len(self.smiles_list)

    def get(self, idx: int) -> BaseData:
        """Featurize molecule `idx` into a Data graph.

        Node and edge features are cast to int64; the target `y` is a 1-element float tensor.
        """
        # Positional access, independent of the Series' index labels.
        smiles = self.smiles_list.iloc[idx]
        homolumogap = self.homolumogap_list.iloc[idx]
        graph = smiles2graph(smiles)
        assert len(graph['edge_feat']) == graph['edge_index'].shape[1]
        assert len(graph['node_feat']) == graph['num_nodes']

        x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        y = torch.tensor([homolumogap], dtype=torch.float32)
        num_nodes = int(graph['num_nodes'])
        return Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)

    def get_idx_split(self):
        """Load the train/valid/test index split stored in the extracted archive folder."""
        # NOTE(review): newer torch releases default torch.load to weights_only=True; if
        # split_dict.pt contains numpy arrays this load may need weights_only=False — confirm.
        split_dict = replace_numpy_with_torchtensor(torch.load(os.path.join(self.root, 'pcqm4m_kddcup2021', 'split_dict.pt')))
        return split_dict
    
    
def prepartion(task_name, device):
    """Create a fresh output directory for a run and select the compute device.

    Args:
        task_name: run name. Outputs go to ``saves/<task_name>``. If that directory
            already exists, they go to the first free ``saves/<task_name>=<k>`` (k = 0, 1, ...).
        device: CUDA device index, used only when CUDA is available.

    Returns:
        Tuple ``(save_dir, torch_device, output_file)``. ``output_file`` is
        ``<save_dir>/output`` opened in append mode; the caller must close it.
    """
    save_dir = os.path.join('saves', task_name)
    if os.path.exists(save_dir):
        # Probe suffixes without an upper bound so an earlier run is never reused or overwritten.
        suffix = 0
        while os.path.exists(f'{save_dir}={suffix}'):
            suffix += 1
        save_dir = f'{save_dir}={suffix}'

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device(f'cuda:{device}') if torch.cuda.is_available() else torch.device('cpu')
    output_file = open(os.path.join(save_dir, 'output'), 'a', encoding='utf-8')
    return save_dir, device, output_file


def train(model, device, loader, optimizer, criterion_fn):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module mapping a batch to predictions (flattened to 1-D here).
        device: device each batch is moved to via ``batch.to(device)``.
        loader: iterable of batches exposing ``.to(device)`` and a target ``.y``.
        optimizer: optimizer that updates ``model``'s parameters.
        criterion_fn: loss called as ``criterion_fn(pred, batch.y)``.

    Returns:
        Average of ``criterion_fn`` over all batches. With ``torch.nn.MSELoss`` this is
        a mean squared error, not an MAE.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_accum = 0.0
    num_batches = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        pred = model(batch).view(-1)
        optimizer.zero_grad()
        loss = criterion_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        loss_accum += loss.item()
        num_batches += 1

    # Guard against dividing by zero (and an unbound loop variable) on an empty loader.
    if num_batches == 0:
        raise ValueError('train() received an empty loader; there are no batches to average over')
    return loss_accum / num_batches


def eval(model, device, loader, evaluator):
    """Score ``model`` on every batch of ``loader`` and return the evaluator's ``"mae"``.

    Predictions and targets are flattened, moved to CPU, concatenated across batches,
    and handed to ``evaluator.eval`` as ``{"y_true": ..., "y_pred": ...}``.
    """
    model.eval()
    targets, predictions = [], []

    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(device)
            out = model(batch).view(-1)
            # Reshape targets to match predictions so the evaluator compares like with like.
            targets.append(batch.y.view(out.shape).detach().cpu())
            predictions.append(out.detach().cpu())

    scores = evaluator.eval({
        "y_true": torch.cat(targets, dim=0),
        "y_pred": torch.cat(predictions, dim=0),
    })
    return scores["mae"]


def test(model, device, loader):
    model.eval()
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(loader):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim=0)
    return y_pred

In [8]:
# Run configuration: every knob a reader might tune lives in this cell.
task_name = 'GINGraphPooling'      # run name; also names the output directory under saves/
num_layers = 5                     # passed to GINGraphPooling
graph_pooling = 'sum'              # passed to GINGraphPooling
emb_dim = 256                      # passed to GINGraphPooling
drop_ratio = 0.                    # passed to GINGraphPooling (presumably a dropout probability)
save_test = True                   # on each new best validation MAE: save checkpoint + test submission
batch_size = 1024
epochs = 100
weight_decay = 5e-5                # Adam weight decay
early_stop = 10                    # stop after this many epochs without validation improvement
dataset_root = '../datasets/PCQM4M'
seed = 42                          # makes model initialization and shuffle order repeatable

torch.manual_seed(seed)

In [9]:
# Output directory + device, model hyper-parameters, and one loader per split.
save_dir, device, output_file = prepartion(task_name, 0)
nn_params = dict(
    num_layers=num_layers,
    emb_dim=emb_dim,
    drop_ratio=drop_ratio,
    graph_pooling=graph_pooling,
)

dataset = MyPCQM4MDataset(dataset_root)
split_idx = dataset.get_idx_split()
train_data, valid_data, test_data = (dataset[split_idx[split]] for split in ('train', 'valid', 'test'))

# Only training batches are shuffled; valid/test keep a fixed order so predictions align with the split.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader, test_loader = (DataLoader(subset, batch_size=batch_size, shuffle=False) for subset in (valid_data, test_data))


In [10]:

# Model, loss, optimizer and learning-rate schedule.
evaluator = PCQM4MEvaluator()
criterion_fn = torch.nn.MSELoss()
model = GINGraphPooling(**nn_params).to(device)
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}', file=output_file, flush=True)
print(model, file=output_file, flush=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=weight_decay)
# Multiply the learning rate by 0.25 every 30 epochs (scheduler.step() runs once per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.25)

writer = SummaryWriter(save_dir)
not_improved = 0
best_valid_mae = float('inf')

try:
    for epoch in range(1, epochs + 1):
        print(f'======Epoch {epoch}======', file=output_file, flush=True)
        print('Training....', file=output_file, flush=True)
        # train() averages criterion_fn (MSELoss) over batches: this is an MSE loss, not an MAE.
        train_loss = train(model, device, train_loader, optimizer, criterion_fn)
        print('Evaluating....', file=output_file, flush=True)
        valid_mae = eval(model, device, valid_loader, evaluator)
        print(f'Train loss (MSE) {train_loss}, Validation MAE {valid_mae}', file=output_file, flush=True)
        writer.add_scalar('valid/mae', valid_mae, epoch)
        writer.add_scalar('train/mse_loss', train_loss, epoch)
        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if save_test:
                print('Saving checkpoint....', file=output_file, flush=True)
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Same `<name>_state_dict` key convention as the model/optimizer entries.
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params,
                }
                torch.save(checkpoint, os.path.join(save_dir, 'checkpoint.pt'))
                print('Predicting on test data....', file=output_file, flush=True)
                y_pred = test(model, device, test_loader)
                print('Saving test submission file....', file=output_file, flush=True)
                evaluator.save_test_submission({'y_pred': y_pred}, save_dir)

            not_improved = 0
        else:
            not_improved += 1
            if not_improved >= early_stop:
                print(f'Have not improved for {not_improved} epochs.', file=output_file, flush=True)
                break
        scheduler.step()
finally:
    # Release the TensorBoard writer and the log file even if training is interrupted.
    writer.close()
    output_file.close()


100%|██████████| 2974/2974 [16:11<00:00,  3.06it/s]
100%|██████████| 372/372 [01:51<00:00,  3.35it/s]
100%|██████████| 2974/2974 [16:04<00:00,  3.08it/s]
100%|██████████| 372/372 [01:51<00:00,  3.35it/s]
100%|██████████| 2974/2974 [16:02<00:00,  3.09it/s]
100%|██████████| 372/372 [01:50<00:00,  3.37it/s]
100%|██████████| 2974/2974 [16:03<00:00,  3.09it/s]
100%|██████████| 372/372 [01:50<00:00,  3.37it/s]
100%|██████████| 2974/2974 [15:56<00:00,  3.11it/s]
100%|██████████| 372/372 [01:50<00:00,  3.38it/s]
100%|██████████| 2974/2974 [16:02<00:00,  3.09it/s]
100%|██████████| 372/372 [01:50<00:00,  3.38it/s]
100%|██████████| 2974/2974 [15:59<00:00,  3.10it/s]
100%|██████████| 372/372 [01:49<00:00,  3.39it/s]
100%|██████████| 2974/2974 [16:00<00:00,  3.10it/s]
100%|██████████| 372/372 [01:49<00:00,  3.39it/s]
100%|██████████| 2974/2974 [15:59<00:00,  3.10it/s]
100%|██████████| 372/372 [01:50<00:00,  3.38it/s]
100%|██████████| 2974/2974 [15:58<00:00,  3.10it/s]
100%|██████████| 372/372 [01:4