In [12]:
import pandas as pd
import numpy as np
import polars as pl
from time import time
from torch_geometric.data import Data, Batch
from tqdm import tqdm


In [7]:
graphs = pd.read_pickle('data/all_buildingblock.pkl')
graph_dict = {row['id']: row['mol_graph'] for _, row in graphs.iterrows()}
print(f'ID 0: {graph_dict[0]}\nID 1: {graph_dict[1]}')

ID 0: Data(x=[13, 8], edge_index=[2, 11])
ID 1: Data(x=[33, 8], edge_index=[2, 36])


In [10]:
#shuffle the dataset
def shuffle_dataset(buildingblock_ids, targets):
    len_dataset = len(buildingblock_ids)
    indices = np.arange(len_dataset)
    np.random.shuffle(indices)
    return buildingblock_ids[indices], targets[indices]

# Verify by loading the saved .npz file
loaded_data = np.load('data/dataset.npz')
loaded_buildingblock_ids = loaded_data['buildingblock_ids']
loaded_targets = loaded_data['targets']

In [9]:
len_dataset = len(loaded_buildingblock_ids)
batch_size = 2048


loaded_buildingblock_ids, loaded_targets = shuffle_dataset(loaded_buildingblock_ids, loaded_targets)
print(f'Loaded buildingblock_ids: {loaded_buildingblock_ids[:10]}')

Loaded buildingblock_ids: [[ 614  901  938]
 [1131  415  211]
 [   7  699  647]
 [ 249  418 1012]
 [ 462 1023  369]
 [ 650   67  394]
 [ 258  769  978]
 [ 679  296  379]
 [ 419  725  834]
 [ 450  830 1005]]


# If only it was this fast
It's very quick just going through the bb - ids and grouping them into batch_size. 

In [16]:

start_time = time()
for epoch in range(4):
    for i in tqdm(range(0, len_dataset, batch_size)):
        bbs_array = loaded_buildingblock_ids[i:i+batch_size]
        targets_array = loaded_targets[i:i+batch_size]
        flat_bbs = bbs_array.flatten()
    print(f"Epoch done in time: {time() - start_time:.2f} seconds")
    
    loaded_buildingblock_ids, loaded_targets = shuffle_dataset(loaded_buildingblock_ids, loaded_targets)

# Process the batch here
end_time = time()
print(f'Time taken to iterate over the dataset ({len_dataset}) in batches: {end_time - start_time:.2f} seconds')


100%|██████████| 48055/48055 [00:00<00:00, 124588.11it/s]


Epoch done in time: 0.39 seconds


100%|██████████| 48055/48055 [00:00<00:00, 119474.41it/s]


Epoch done in time: 31.24 seconds


100%|██████████| 48055/48055 [00:00<00:00, 103978.85it/s]


Epoch done in time: 60.43 seconds


100%|██████████| 48055/48055 [00:00<00:00, 104013.19it/s]


Epoch done in time: 90.30 seconds
Time taken to iterate over the dataset (98415610) in batches: 118.63 seconds


# El Problemo 

The biggest bottleneck is Crating batches of objects from lists of graphs. I tried just indexing the buildingblock dictionary to check if that is the bottleneck, but it's definitely the Batch Creation.

In [24]:
def custom_batching(flat_bbs):
    graphs_array_1 = [graph_dict[bb] for bb in flat_bbs[0::3]]
    graphs_array_2 = [graph_dict[bb] for bb in flat_bbs[1::3]]
    graphs_array_3 = [graph_dict[bb] for bb in flat_bbs[2::3]]
    return Batch.from_data_list(graphs_array_1), Batch.from_data_list(graphs_array_2), Batch.from_data_list(graphs_array_3)

start_time = time()
for epoch in range(4):
    for i in tqdm(range(0, len_dataset, batch_size)):
        bbs_array = loaded_buildingblock_ids[i:i+batch_size]
        targets_array = loaded_targets[i:i+batch_size]
        flat_bbs = bbs_array.flatten()
        batch_graphs, bgs, bg3 = custom_batching(flat_bbs) 
        if i % 50 == 0 and i > 0:
            break


# Process the batch here
end_time = time()
print(f'Time taken to iterate over 100 batches in batches: {end_time - start_time:.2f} seconds')


  0%|          | 25/48055 [00:05<2:41:21,  4.96it/s]
  0%|          | 25/48055 [00:05<2:47:40,  4.77it/s]
  0%|          | 25/48055 [00:04<2:39:53,  5.01it/s]
  0%|          | 25/48055 [00:05<2:42:23,  4.93it/s]

Time taken to iterate over 100 batches in batches: 20.35 seconds



