In [4]:
#import spektral
import numpy as np
import tensorflow as tf
from ogb.graphproppred import GraphPropPredDataset
#from spektral.data import Dataset, Graph
#from spektral.datasets import TUDataset, QM9


In [46]:
config = {
    'seed': 1,
    'epochs': 10,
    'batch_size': 32,
    'learning_rate': 0.001,
    'dataset': 'ogbg-molesol', #JA: QM9, ogbg-molesol, ogbg-molfreesolv, ogbg-mollipo, ZINC| NEIN: aspirin
    'train_test_split': 0.8
}

np.random.seed(config['seed'])
tf.random.set_seed(config['seed'])

In [47]:
class OGBDataset(Dataset):
    '''
    (spektral) Dataset class wrapper for Open Graph Benchmark datasets.
    '''
    def __init__(self, name, **kwargs):
        self.name = name
        super().__init__(**kwargs)

    def read(self):
        dataset = GraphPropPredDataset(name=self.name)
        graphs = []
        for data in dataset:
            edge_index = data[0]['edge_index']
            edge_feat = data[0]['edge_feat']
            node_feat = data[0]['node_feat']
            label = data[1]

            # Create adjacency matrix
            num_nodes = node_feat.shape[0]
            adj = np.zeros((num_nodes, num_nodes))
            for edge in edge_index.T:
                adj[edge[0], edge[1]] = 1

            # Create spektral Graph object
            graphs.append(Graph(x=node_feat, a=adj, e=edge_feat, y=label))
            
        self.size = len(graphs)

        return graphs

def ogb_available_datasets():
    #These regression datasets have size % 2 == 0 number of graphs
    return ['ogbg-molesol', 'ogbg-molfreesolv', 'ogbg-mollipo']

In [48]:
def _load_data(name: str):
    '''
    Loads a dataset from [TUDataset, OGB]
    '''
    if name in ogb_available_datasets():
        dataset= OGBDataset(name)
    else:
        raise ValueError(f'Dataset {name} unknown')

    return dataset, dataset.n_labels

In [49]:
def _split_data(data, train_test_split, seed):
    '''
    Split the data into train and test sets
    '''
    np.random.seed(seed)
    idxs = np.random.permutation(len(data))
    split = int(train_test_split * len(data))
    idx_train, idx_test = np.split(idxs, [split])
    train, test = data[idx_train], data[idx_test]
    train.size = len(train)
    test.size = len(test)
    return train, test

In [50]:
def get_data(config):
    seed = config['seed']
    train_test_split = config['train_test_split']
    name = config['dataset']

    # Load data
    data, config['n_out'] = _load_data(name)
    # Split data
    train_data, test_data = _split_data(data, train_test_split, seed)

    return train_data, test_data

In [51]:
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [52]:
dataset_train, dataset_test = get_data(config)

In [53]:
len(dataset_train), len(dataset_test)

(902, 226)

In [54]:
dataset_train

OGBDataset(n_graphs=902)

In [55]:
us = [g.y for g in dataset_test]
us

[array([-2.55]),
 array([-0.03]),
 array([-0.96]),
 array([-0.8]),
 array([-1.28]),
 array([-1.37]),
 array([-3.493]),
 array([-1.34]),
 array([-3.77]),
 array([1.11]),
 array([-5.28]),
 array([-1.44]),
 array([-4.57]),
 array([-4.871]),
 array([-3.46]),
 array([-5.68]),
 array([-1.03]),
 array([-0.62]),
 array([-3.499]),
 array([-3.45]),
 array([-4.4]),
 array([-7.85]),
 array([-5.05]),
 array([-2.843]),
 array([-3.091]),
 array([-2.523]),
 array([-2.484]),
 array([-3.2]),
 array([-6.876]),
 array([-3.18]),
 array([-3.65]),
 array([-0.45]),
 array([-3.54]),
 array([-4.63]),
 array([-6.09]),
 array([-1.6]),
 array([-3.04]),
 array([-2.78]),
 array([-0.63]),
 array([-8.4]),
 array([0.]),
 array([-5.47]),
 array([-4.36]),
 array([-2.03]),
 array([-4.376]),
 array([-6.124]),
 array([-6.291]),
 array([-3.953]),
 array([-2.92]),
 array([0.]),
 array([-4.805]),
 array([-5.26]),
 array([-1.23]),
 array([-2.41]),
 array([-4.2]),
 array([-4.4]),
 array([-6.9]),
 array([-1.74]),
 array([-3.14]),

In [56]:
idx_sort = np.argsort(us, axis=0)

In [59]:
a = dataset_train.size

In [60]:
a

902

In [61]:
def iterate_train_random(elements):
    objects, sort_idx = elements#, _, sort_idx
    olen = objects.size
    seed = 42 + olen#self.seed
    pair_count = (olen * (olen - 1)) // 2
    sample_size = min(int(20 * pair_count), pair_count)#self.sample_ratio
    rng = np.random.default_rng(seed)

    sample = rng.choice(pair_count, sample_size, replace=False)
    sample_b = (np.sqrt(sample * 2 + 1/4) + 1/2).astype(np.int)
    sample_a = sample - (sample_b * (sample_b - 1)) // 2
    idx_a = sort_idx[sample_a]
    idx_b = sort_idx[sample_b]

    yield from zip(idx_a, idx_b)

In [78]:
dl = iterate_train_random((dataset_test, us))

In [81]:
while True:
    i,j=next(dl)
    print(i,j)

StopIteration: 

In [77]:
a

902