# Data-related development

## Dataset

In [1]:
import sys
# Add the directory two levels up to the import path so the `ptgnn` imports
# below resolve (presumably the repository root — confirm). The path is
# relative, so it depends on the kernel's working directory.
sys.path.append("../../")

In [2]:
from ptgnn.dataset import RSDataset, BindingAffinityDataset, BaceDataset, Tox21Dataset, OGBDataset

In [3]:
# Test split of the OGB wrapper with ds_name="hiv" (presumably ogbg-molhiv — confirm).
ogb = OGBDataset(split="test", ds_name="hiv")
# Display the first sample to check which fields the dataset attaches.
ogb[0]

Data(x=[88, 118], edge_index=[2, 252], edge_attr=[252, 80], pos=[88, 6], parallel_node_index=[88], circle_index=[88], y=[1])

In [4]:
# Validation split of the BACE dataset with default settings.
# On the recorded run this call triggered preprocessing (see the progress output).
bac1 = BaceDataset(split='val')

Processing...
1060it [00:17, 60.26it/s] 
151it [00:07, 18.94it/s]
302it [00:09, 32.33it/s] 
Done!


In [5]:
bac1

BaceDataset(151)

In [6]:
bac1[0]

Data(x=[202, 118], edge_index=[2, 578], edge_attr=[578, 80], pos=[202, 6], parallel_node_index=[202], circle_index=[202], y=[1])

In [7]:
# Validation split of the Tox21 dataset with default settings.
tox1 = Tox21Dataset(split='val')

In [8]:
tox1[0]

Data(x=[40, 118], edge_index=[2, 104], edge_attr=[104, 80], pos=[40, 6], parallel_node_index=[40], circle_index=[40], y=[12])

In [9]:
# Validation split of RSDataset (presumably R/S stereo-label data — confirm).
# On the recorded run this call triggered preprocessing (see the progress output).
t1 = RSDataset(split='val')

Processing...
11748it [01:25, 138.09it/s]
Done!


In [10]:
t1[0]

Data(x=[72, 118], edge_index=[2, 192], edge_attr=[192, 80], pos=[72, 6], parallel_node_index=[72], circle_index=[72], y=[1])

In [11]:
t1

RSDataset(11740)

In [12]:
len(t1)

11740

In [13]:
# Test split of RSDataset, built the same way as the validation split.
t2 = RSDataset(split="test")

Processing...
11680it [01:24, 138.78it/s]
Done!


In [14]:
t2[0]

Data(x=[36, 118], edge_index=[2, 104], edge_attr=[104, 80], pos=[36, 6], parallel_node_index=[36], circle_index=[36], y=[1])

In [15]:
len(t2)

11676

In [16]:
# Training split of RSDataset; preprocessing took about six minutes on the
# recorded run, so expect this cell to be slow when it has to process.
t3 = RSDataset(split="train")

Processing...
55084it [06:16, 146.38it/s]
Done!


In [17]:
t3[0]

Data(x=[132, 118], edge_index=[2, 348], edge_attr=[348, 80], pos=[132, 6], parallel_node_index=[132], circle_index=[132], y=[1])

In [18]:
len(t3)

55068

In [19]:
# Validation split of the binding-affinity dataset with default settings.
ba1 = BindingAffinityDataset(split="val")

Processing...
10368it [00:38, 266.38it/s]
Done!


In [20]:
ba1

BindingAffinityDataset(10354)

In [21]:
ba1[0]

Data(x=[46, 118], edge_index=[2, 136], edge_attr=[136, 80], pos=[46, 6], parallel_node_index=[46], circle_index=[46], y=[1])

In [22]:
len(ba1)

10354

## Loader

In [23]:
import torch
from torch_geometric.data import Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.loader.dataloader import Collater
from typing import Union, List, Optional, Any
from ptgnn.loading.chienn_collate import collate_with_circle_index

In [24]:
class CustomCollater:
    """Collate callable that routes batches carrying ``circle_index``.

    When the first element of a batch is a ``BaseData`` with a
    ``circle_index`` attribute, the whole batch is handed to
    ``collate_with_circle_index``. Every other batch goes through a
    torch_geometric ``Collater`` built from the constructor arguments.

    Args:
        follow_batch: Passed to the fallback ``Collater`` and kept as an attribute.
        exclude_keys: Passed unchanged to the fallback ``Collater``. The
            ``exclude_keys`` attribute holds these keys plus ``'circle_index'``;
            nothing inside this class reads that attribute.
        n_neighbors_in_circle: Second argument given to
            ``collate_with_circle_index``.
    """

    def __init__(self, follow_batch=None, exclude_keys=None, n_neighbors_in_circle=None):
        self.follow_batch = follow_batch
        self.n_neighbors_in_circle = n_neighbors_in_circle
        # The fallback collater receives the caller's keys exactly as given (possibly None).
        self.collator = Collater(follow_batch, exclude_keys)
        # The stored key list always ends with 'circle_index'.
        self.exclude_keys = (exclude_keys or []) + ['circle_index']

    def __call__(self, batch: List[Any]):
        # Only the first element decides which collate path is taken.
        first = batch[0]
        has_circle_index = isinstance(first, BaseData) and hasattr(first, 'circle_index')
        if not has_circle_index:
            return self.collator(batch)
        return collate_with_circle_index(batch, self.n_neighbors_in_circle)

In [25]:
class CustomDataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` to a mini-batch.
    Data objects can be either of type :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData`.

    Batches are always assembled by :class:`CustomCollater`; a ``collate_fn``
    passed through ``**kwargs`` is silently discarded.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        n_neighbors_in_circle (int, optional): Forwarded unchanged to
            :class:`CustomCollater`, which hands it to
            ``collate_with_circle_index``. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """
    def __init__(
            self,
            dataset: Union[Dataset, List[BaseData]],
            batch_size: int = 1,
            shuffle: bool = False,
            follow_batch: Optional[List[str]] = None,
            exclude_keys: Optional[List[str]] = None,
            n_neighbors_in_circle: Optional[int] = None,
            **kwargs,
    ):
        # The collate function is fixed by this class, so drop any caller-supplied one.
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning: it may rebuild the loader from same-named
        # instance attributes, so every extra constructor argument is kept,
        # including n_neighbors_in_circle.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.n_neighbors_in_circle = n_neighbors_in_circle

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=CustomCollater(follow_batch, exclude_keys, n_neighbors_in_circle),
            **kwargs,
        )

In [43]:
# Re-create the binding-affinity validation split and wrap it in the custom loader:
# three graphs per batch, with n_neighbors_in_circle=3 forwarded to the collater.
ba1 = BindingAffinityDataset(split="val")
loader = CustomDataLoader(ba1, n_neighbors_in_circle=3, batch_size=3)

In [44]:
# Pull only the first batch to inspect the collated structure; the loop exits
# immediately, leaving `batch` bound for the inspection cells that follow.
for batch in loader:
    print(batch)
    break

DataBatch(x=[130, 118], edge_index=[2, 378], edge_attr=[378, 80], pos=[130, 6], parallel_node_index=[130], y=[3], batch=[130], ptr=[4], circle_index=[130, 5])


In [45]:
batch.batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [46]:
batch.circle_index

tensor([[ -1,  -1,  -1,  -1,  -1],
        [  3,   7,   5,   3,   7],
        [  0,   5,   7,   0,   5],
        [  9,  11,  13,   9,  11],
        [  0,   7,   3,   0,   7],
        [ 38,  45,  43,  38,  45],
        [  0,   3,   5,   0,   3],
        [ -1,  -1,  -1,  -1,  -1],
        [  2,  13,  11,   2,  13],
        [ 15,  19,  17,  15,  19],
        [  2,   9,  13,   2,   9],
        [ -1,  -1,  -1,  -1,  -1],
        [  2,  11,   9,   2,  11],
        [ -1,  -1,  -1,  -1,  -1],
        [  8,  17,  19,   8,  17],
        [ 21,  25,  23,  21,  25],
        [  8,  19,  15,   8,  19],
        [ -1,  -1,  -1,  -1,  -1],
        [  8,  15,  17,   8,  15],
        [ -1,  -1,  -1,  -1,  -1],
        [ 14,  23,  25,  14,  23],
        [ 27,  29,  31,  27,  29],
        [ 14,  25,  21,  14,  25],
        [ -1,  -1,  -1,  -1,  -1],
        [ 14,  21,  23,  14,  21],
        [ -1,  -1,  -1,  -1,  -1],
        [ 20,  31,  29,  20,  31],
        [ 33,  37,  35,  33,  37],
        [ 20,  27,  

In [47]:
batch[0]

Data(x=[46, 118], edge_index=[2, 136], edge_attr=[136, 80], pos=[46, 6], parallel_node_index=[46], y=[1])

In [48]:
ba1[0].circle_index

[[],
 [3, 7, 5],
 [0, 5, 7],
 [9, 11, 13],
 [0, 7, 3],
 [38, 45, 43],
 [0, 3, 5],
 [],
 [2, 13, 11],
 [15, 19, 17],
 [2, 9, 13],
 [],
 [2, 11, 9],
 [],
 [8, 17, 19],
 [21, 25, 23],
 [8, 19, 15],
 [],
 [8, 15, 17],
 [],
 [14, 23, 25],
 [27, 29, 31],
 [14, 25, 21],
 [],
 [14, 21, 23],
 [],
 [20, 31, 29],
 [33, 37, 35],
 [20, 27, 31],
 [],
 [20, 29, 27],
 [],
 [26, 35, 37],
 [39, 41],
 [26, 37, 33],
 [],
 [26, 33, 35],
 [],
 [32, 41],
 [4, 43, 45],
 [32, 39],
 [],
 [4, 45, 38],
 [],
 [4, 38, 43],
 []]

In [49]:
type(ba1[0].circle_index)

list

In [8]:
# Test for the ptree transformation: BACE validation split built with
# graph_mode="edge" and transformation_mode="chienn_tree_basic".
# NOTE: this rebinds `bac1`, replacing the default-settings dataset created earlier.
bac1 = BaceDataset(split='val', graph_mode="edge", transformation_mode="chienn_tree_basic")

Processing...
1060it [00:26, 40.64it/s]
151it [00:08, 17.14it/s]
302it [00:11, 25.66it/s]
Done!


In [9]:
bac1[0]

Data(x=[202, 118], edge_index=[2, 578], edge_attr=[578, 80], pos=[202, 6], ptree=[202], y=[1])

In [11]:
bac1[0].ptree

['P[1, Z[3, 7, 5]]',
 'P[0, Z[9, 13, 11]]',
 'P[3, Z[1, 5, 7]]',
 'P[2]',
 'P[5, Z[1, 7, 3]]',
 'P[4]',
 'P[7, Z[1, 3, 5]]',
 'P[6]',
 'P[9, Z[0, 11, 13]]',
 'P[8, Z[15, 19, 17]]',
 'P[11, Z[0, 13, 9]]',
 'P[10]',
 'P[13, Z[0, 9, 11]]',
 'P[12]',
 'P[15, Z[8, 17, 19]]',
 'P[14, Z[21]]',
 'P[17, Z[8, 19, 15]]',
 'P[16]',
 'P[19, Z[8, 15, 17]]',
 'P[18]',
 'P[21, Z[14]]',
 'P[20, Z[23, 25, 27]]',
 'P[23, Z[20, 27, 25]]',
 'P[22, Z[29, 33, 31]]',
 'P[25, Z[20, 23, 27]]',
 'P[24, Z[42, 201, 199]]',
 'P[27, Z[20, 25, 23]]',
 'P[26]',
 'P[29, Z[22, 31, 33]]',
 'P[28, Z[35, 37, 39]]',
 'P[31, Z[22, 33, 29]]',
 'P[30]',
 'P[33, Z[22, 29, 31]]',
 'P[32]',
 'P[35, Z[28, 39, 37]]',
 'P[34, Z[41, 43, 45]]',
 'P[37, Z[28, 35, 39]]',
 'P[36]',
 'P[39, Z[28, 37, 35]]',
 'P[38]',
 'P[41, Z[34, 45, 43]]',
 'P[40, Z[47, 49, 51]]',
 'P[43, Z[34, 41, 45]]',
 'P[42, Z[24, 199, 201]]',
 'P[45, Z[34, 43, 41]]',
 'P[44]',
 'P[47, Z[40, 51, 49]]',
 'P[46, Z[53]]',
 'P[49, Z[40, 47, 51]]',
 'P[48, Z[55, 59, 57]