In [14]:
import os
import glob
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
import pandas as pd

In [2]:
%cd C:\Users\yuval\Projects\MolecularGraphs

C:\Users\yuval\Projects\MolecularGraphs


The `Dataset` class is designed to sample batches from storage without uploading all data into the RAM.
To create our own custom operation we need to create a class that inherent from `Dataset` class.

In the `init` method, the arguments that are pass to `Dataset` are:
* `root` (str, optional) - The root directory where the data should be saved.
This directory is going to have `raw` directory and `processed` directory.
The `raw` directory is where you have all files of the data, a file per instance.
The `processed` directory is where the class is going to saved all processed files.
The processing of files in our case is the convertion of the file into a `Data` object (including node features, edge index, label/s, and optional of edges features).
* `transform` (callable, optional) - not used - a function/transform that takes in an `Data` object and returns a transformed version. The `Data` object will be **transformed before every access**.
* `pre_transform` (callable, optional) - not used – a function/transform that takes in an `Data` object and returns a transformed version. The `Data` object will be **transformed before being saved to disk**. (default: None)
* `pre_filter` (callable, optional) - not used - a function that takes in an `Data` object and returns a boolean value, indicating whether the `Data` object should be included in the final dataset. 
* `log` (bool, optional) - whether to print any console output while downloading and processing the dataset.

Following the `init` method, we have two method decorated as property.
The decorator define the method as a "getter", i.e., getting an attribute of the class.
That means we can treat such method as an attribute and call it without parentheses.
Those two properties return all files names inside the previously mentioned directories - `raw` and `processed`.

The `process` method is called with calling the `Dataset`'s `init` method (I think).
In this method you iterate over all of the raw files and turn them into `Data` object of graph, including the `pre_transform` and `pre_filter` functions calls.

The `len` and `get` are self-explanable.


In [17]:
class MyDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return glob.glob('Zinc\GraphData\processed\data_*.pt') # Avoiding processed filters and transformers


    def process(self):
        idx = 0
        for raw_path in self.raw_paths:

            # Load the two arrays and scaler from the saved file using read_pickle()
            with open(raw_path, 'rb') as f:
                x, edge_index, y = pd.read_pickle(f)
            
            data_i = Data(x=x, edge_index=edge_index, y=y)

            if self.pre_filter is not None and not self.pre_filter(data_i):
                continue

            if self.pre_transform is not None:
                data_i = self.pre_transform(data_i)

            torch.save(data_i, os.path.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data_i = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data_i

In [18]:
dataset = MyDataset('Zinc\GraphData')


Processing...
Done!


In [19]:
data = dataset.get(10)
data

Data(x=[19, 11], edge_index=[2, 40], y=0.0)

In [20]:
data.x

array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.

In [21]:
data.edge_index

array([[ 0,  1,  1,  1,  2,  3,  3,  4,  4,  5,  5,  5,  6,  6,  6,  7,
         7,  8,  8,  9,  9,  9, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14,
        15, 15, 16, 16, 17, 17, 18, 18],
       [ 1,  0,  2,  3,  1,  1,  4,  3,  5,  4,  6, 13,  5,  7, 12,  6,
         8,  7,  9,  8, 10, 11,  9,  9, 12,  6, 11,  5, 14, 18, 13, 15,
        14, 16, 15, 17, 16, 18, 13, 17]], dtype=int64)

In [22]:
loader = DataLoader(dataset, batch_size=32, shuffle=True)

To split the data into training, validation and test sets, we use the `index_select` method which creates a subset of the dataset from specified indices idx.

In [23]:
N = dataset.len()
idx = torch.randperm(N) # Random permutation of integers from 0 to N - 1
idx_train, idx_val, idx_test = idx[:int(0.8 * N)], idx[int(0.8 * N): int(0.9 * N)], idx[int(0.9 * N):]

train_dataset = dataset.index_select(idx_train)
val_dataset = dataset.index_select(idx_val)
test_dataset = dataset.index_select(idx_test)

In [24]:
dataset.len()

2936

Now we define the `Dataloader`.
Note for thet `shuffle` parameter, if set to True, the data will be reshuffled at every epoch.
We do not want such thing for the validation and test sets.

In [25]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [26]:
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1199], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1307], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1179], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1130], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1134], ptr=[65])

Step 6:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1101], ptr=[65])

Step 7:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1152], ptr=[65])

Step 8:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[64], y=[64], batch=[1101], ptr=[65])

Step 9:
Number of graphs in the current batch: 64
DataBa