In [1]:
import torch
import os
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import re
import random
import itertools
import networkx as nx
import numpy as np
from collections import defaultdict
from torch.utils.data.sampler import Sampler
from torchvision.datasets import ImageFolder, MNIST, CIFAR10

In [2]:
class custom_datalist(list):
    """A plain ``list`` that also carries a ``start_pos_data`` attribute.

    All positional and keyword arguments are forwarded unchanged to ``list``.
    ``start_pos_data`` is a fresh, empty list on every new instance.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance bookkeeping list, filled in by the code that builds the data.
        self.start_pos_data = list()

In [3]:
# Build a toy dataset of random graphs, grouped by node count.
# Before each group is generated, the current length of `datalist` is recorded
# in `start_pos_data`, i.e. the index at which that group begins.
datalist = custom_datalist()
for num_node in [9, 11]:
    datalist.start_pos_data.append(len(datalist))
    for _ in range(1024):
        graph = nx.fast_gnp_random_graph(num_node, 0.5)
        # Force an (E, 2) int64 array: a graph that happens to have no edges
        # then still yields a well-formed [2, 0] tensor (so `size(1)` below
        # works), and the index dtype no longer depends on the platform's
        # default integer width.
        edge_array = np.array(list(graph.edges()), dtype=np.int64).reshape(-1, 2)
        edge_index = torch.from_numpy(edge_array).t().contiguous()
        datalist.append(
            Data(
                x=torch.rand(num_node, 5),
                edge_index=edge_index,
                # One random scalar attribute per edge.
                edge_attr=torch.rand(edge_index.size(1)),
            )
        )

In [4]:
from graphdataset import BucketSampler
# NOTE(review): BucketSampler is project-local; presumably it yields batches
# whose graphs all share one node count — confirm against graphdataset.
dataloader = DataLoader(
    datalist, batch_sampler=BucketSampler(datalist, batch_size=128)
)
for data in dataloader:
    # max(batch) + 1 is the number of graphs in this mini-batch. Convert it to
    # a Python int so the `//` below is ordinary integer division instead of
    # tensor floor-division.
    batch_size = int(data.batch.max()) + 1
    print(f"number of batches: {batch_size}, number of nodes:{data.x.size(0)//batch_size}, test:{data.num_nodes // data.num_graphs}")

number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:9, test:9
number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:11, test:11
number of batches: 128, number of nodes:11, test:11


  print(f"number of batches: {batch_size}, number of nodes:{data.x.size(0)//batch_size}, test:{data.num_nodes // data.num_graphs}")


In [6]:
# Total number of items held by the loader's underlying dataset.
len(dataloader.dataset)

2048

In [6]:
# Iterate the sampler on its own (no DataLoader) and report the size of every
# batch it yields, after showing where each group starts in `datalist`.
bucketSampler = BucketSampler(datalist, batch_size=32)
# print(bucketSampler.samplers)
print(datalist.start_pos_data)
for sampled_batch in bucketSampler:
    print(len(sampled_batch))

[0, 1024]
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32


In [6]:
def get_combined_iterator(*iterables):
    """Lazily interleave several iterables in random order.

    On each step one of the not-yet-exhausted inputs is picked with
    ``random.choice`` and its next item is yielded. An input found to be
    exhausted is dropped from the pool; the generator ends once the pool is
    empty. Items from the same input keep their relative order.
    """
    pending = [iter(source) for source in iterables]
    while len(pending) > 0:
        chosen = random.choice(pending)
        try:
            yield next(chosen)
        except StopIteration:
            # This input has nothing left; stop considering it.
            pending.remove(chosen)

In [7]:
# Distinct node counts that occur among the items of `datalist`.
num_nodes = {graph.num_nodes for graph in datalist}
num_nodes

{9, 11}

In [14]:
# Bucket the graphs by node count so every loader only ever sees graphs of
# one size.
datalists = defaultdict(list)
for data in datalist:
    datalists[data.num_nodes].append(data)
# Keep the loaders in a list rather than a generator expression: a generator
# can be unpacked only once, so any later `get_combined_iterator(*dataloaders)`
# (e.g. once per epoch) would silently receive no loaders and yield no batches.
dataloaders = [
    DataLoader(data, batch_size=128, shuffle=True) for data in datalists.values()
]
batches = get_combined_iterator(*dataloaders)

In [15]:
# Drain the combined iterator and report how many batches it produced.
i = sum(1 for _ in batches)
print(i)

16


In [None]:
# Skeleton of a training loop over the per-size loaders.
epochs = 2 # just for illustration
for epoch_num in range(epochs):
    print(f"Epoch {epoch_num}")
    # A fresh combined iterator is needed every epoch because the previous one
    # has been consumed. `dataloaders` must therefore be re-iterable (e.g. a
    # list): unpacking a one-shot generator here would produce no loaders
    # after its first use.
    batches = get_combined_iterator(*dataloaders)
    for batch in batches:
        # now each graph in a batch will have the same number of nodes
        # do training
        pass

In [2]:
# Location of the raw training files, relative to the working directory.
path_components = ("data_tgff", "multiple", "train", "raw")
root_data_path = os.path.join(*path_components)
root_data_path

'data_tgff/multiple/train/raw'

In [3]:
# Load the first two raw files (ordered by the integer that precedes the file
# extension) and wrap each one in its own shuffling DataLoader.
batch_size = 512
datas = []
dataloaders = []
raw_file_names = sorted(
    (
        entry
        for entry in os.listdir(root_data_path)
        if os.path.isfile(os.path.join(root_data_path, entry))
    ),
    # Split on "_" or "." and take the second-to-last piece as the sort key.
    key=lambda name: int(re.split('_|[.]', name)[-2]),
)[:2]
for raw_file_name in raw_file_names:
    raw_file_path = os.path.join(root_data_path, raw_file_name)
    # NOTE(review): torch.load unpickles its input — only use on trusted files.
    data = torch.load(raw_file_path)
    datas.append(data)
    dataloaders.append(DataLoader(data, batch_size=batch_size, shuffle=True))

In [4]:
# Concatenate the per-file collections into one flat list.
datas_flattened = list(itertools.chain.from_iterable(datas))

In [6]:
# Spot-check: node count of the first item in the flattened list.
datas_flattened[0].num_nodes

9