# OGB

In [1]:
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset
from itertools import repeat
from copy import deepcopy, copy
from torch_geometric.data.separate import separate



# Candidate OGB molecule datasets; `i` selects which one this notebook loads.
ogb_list = ["hiv", "bbbp", "clintox", "tox21", "sider"]
i = 0
# Download (on first use) and process dataset "ogbg-mol<ogb_list[i]>" under
# '../data/'; with i = 0 this is "ogbg-molhiv".
dataset = PygGraphPropPredDataset(name="ogbg-mol" + ogb_list[i], root='../data/')

# Dataset-provided split: a dict of index sets keyed "train" / "valid" / "test".
split_idx = dataset.get_idx_split()

# Shared mini-batch size for all three loaders.
BATCH_SIZE = 32
# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=BATCH_SIZE, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# `aug` is a local module (not shown). It presumably provides drop_nodes,
# permute_edges, subgraph and mask_nodes, which the dataset wrappers below call.
# NOTE(review): later cells also use `torch` and `np` without importing them;
# presumably they arrive through this star import. Confirm, and prefer
# explicit imports.
from aug import *
# Display one sample graph to inspect its stored attributes.
dataset[2]

Data(edge_index=[2, 48], edge_attr=[48, 3], x=[21, 9], y=[1, 1], num_nodes=21)

In [3]:
class OGBDataset_aug(PygGraphPropPredDataset):
    """OGB graph-property dataset that yields ``(data, data_aug)`` pairs.

    Every graph gets one self-loop per node appended to ``edge_index``.
    It is then returned together with an augmented deep copy, for
    contrastive (GraphCL-style) training.

    Args:
        name: OGB dataset name, e.g. ``"ogbg-molhiv"``.
        root: Directory the dataset is stored/processed under.
        aug: One of ``'dnodes'``, ``'pedges'``, ``'subgraph'``,
            ``'mask_nodes'``, ``'none'`` or ``'random4'``. Any other value,
            including the default ``None``, makes :meth:`get` raise
            ``ValueError``.
        transform, pre_transform: Forwarded to ``PygGraphPropPredDataset``.
    """

    def __init__(self, name, root, aug=None, transform=None, pre_transform=None):
        super(OGBDataset_aug, self).__init__(name, root, transform, pre_transform)
        self.aug = aug

    @staticmethod
    def _add_self_loops(data):
        """Append an ``(n, n)`` edge for every node ``n`` to ``data.edge_index`` (in place)."""
        edge_index = data.edge_index
        # Use num_nodes (not edge_index.max()) so that the highest-indexed node
        # and isolated nodes also get a loop, and so edgeless graphs don't crash.
        loops = torch.arange(
            data.num_nodes, dtype=edge_index.dtype, device=edge_index.device
        ).repeat(2, 1)
        data.edge_index = torch.cat((edge_index, loops), dim=1)
        # NOTE(review): edge_attr is not extended with the new loops. OGB
        # molecule graphs carry edge_attr, so edge_index and edge_attr no longer
        # align one-to-one; models consuming edge_attr need matching loop features.

    def _augment(self, data):
        """Return an augmented deep copy of *data* selected by ``self.aug``."""
        single = {
            'dnodes': drop_nodes,
            'pedges': permute_edges,
            'subgraph': subgraph,
            'mask_nodes': mask_nodes,
        }
        if self.aug in single:
            return single[self.aug](deepcopy(data))
        if self.aug == 'none':
            # "No augmentation": same graph, node features replaced by constant ones.
            data_aug = deepcopy(data)
            data_aug.x = torch.ones((data.num_nodes, 1))
            return data_aug
        if self.aug == 'random4':
            # Pick one of the four augmentations uniformly at random.
            choices = (drop_nodes, permute_edges, subgraph, mask_nodes)
            return choices[np.random.randint(len(choices))](deepcopy(data))
        raise ValueError(
            f"augmentation error: unknown aug {self.aug!r}; expected one of "
            "'dnodes', 'pedges', 'subgraph', 'mask_nodes', 'none', 'random4'"
        )

    def get(self, idx: int):
        """Return ``(data, data_aug)`` for graph *idx*.

        ``data`` is the stored graph with self-loops added. ``data_aug`` is an
        augmented copy of that graph.

        Raises:
            ValueError: if ``self.aug`` is not a recognised augmentation name.
        """
        data = super().get(idx)
        self._add_self_loops(data)
        return data, self._augment(data)

In [4]:
from torch_geometric.datasets import TUDataset

class TUDataset_aug(TUDataset):
    """TUDataset that yields ``(data, data_aug)`` pairs for contrastive training.

    Every graph gets one self-loop per node appended to ``edge_index``.
    It is then returned together with an augmented deep copy.

    Args:
        name: TU dataset name.
        root: Directory the dataset is stored/processed under.
        aug: One of ``'dnodes'``, ``'pedges'``, ``'subgraph'``,
            ``'mask_nodes'``, ``'none'`` (default) or ``'random4'``. Any other
            value makes :meth:`get` raise ``ValueError``.
        transform, pre_transform: Forwarded to ``TUDataset``.
    """

    def __init__(self, name, root, aug="none", transform=None, pre_transform=None):
        super(TUDataset_aug, self).__init__(name=name, root=root, transform=transform, pre_transform=pre_transform)
        self.aug = aug

    @staticmethod
    def _add_self_loops(data):
        """Append an ``(n, n)`` edge for every node ``n`` to ``data.edge_index`` (in place)."""
        edge_index = data.edge_index
        # Use num_nodes (not edge_index.max()) so that the highest-indexed node
        # and isolated nodes also get a loop, and so edgeless graphs don't crash.
        loops = torch.arange(
            data.num_nodes, dtype=edge_index.dtype, device=edge_index.device
        ).repeat(2, 1)
        data.edge_index = torch.cat((edge_index, loops), dim=1)
        # NOTE(review): edge_attr, if the dataset provides one, is not extended,
        # so it would no longer align one-to-one with edge_index.

    def _augment(self, data):
        """Return an augmented deep copy of *data* selected by ``self.aug``."""
        single = {
            'dnodes': drop_nodes,
            'pedges': permute_edges,
            'subgraph': subgraph,
            'mask_nodes': mask_nodes,
        }
        if self.aug in single:
            return single[self.aug](deepcopy(data))
        if self.aug == 'none':
            # "No augmentation": same graph, node features replaced by constant ones.
            data_aug = deepcopy(data)
            data_aug.x = torch.ones((data.num_nodes, 1))
            return data_aug
        if self.aug == 'random4':
            # Pick one of the four augmentations uniformly at random.
            choices = (drop_nodes, permute_edges, subgraph, mask_nodes)
            return choices[np.random.randint(len(choices))](deepcopy(data))
        raise ValueError(
            f"augmentation error: unknown aug {self.aug!r}; expected one of "
            "'dnodes', 'pedges', 'subgraph', 'mask_nodes', 'none', 'random4'"
        )

    def get(self, idx: int):
        """Return ``(data, data_aug)`` for graph *idx*.

        ``data`` is the stored graph with self-loops added. ``data_aug`` is an
        augmented copy of that graph.

        Raises:
            ValueError: if ``self.aug`` is not a recognised augmentation name.
        """
        data = super(TUDataset_aug, self).get(idx)
        self._add_self_loops(data)
        return data, self._augment(data)