In [1]:
import pandas as pd

import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

from joblib import Parallel, delayed

import os
import sys
sys.path.append(os.path.join(os.path.abspath(''), '..'))
from utils.graph import grid_to_graph

In [3]:
def process_id(id, grid_shape=(150, 150)):
    """Convert every row of one id's parquet files into a saved PyG ``Data`` graph.

    Reads ``../data/{id}_grid.parquet`` and ``../data/{id}_distribution.parquet``,
    pairs their rows by position, reshapes each row to ``grid_shape`` and passes
    the pair to ``grid_to_graph``. The five arrays it returns are converted to
    tensors, bundled into a ``Data`` object and written to
    ``../data/processed/{id}_{j}.pt``, where ``j`` is the row index.

    Parameters
    ----------
    id
        File-name prefix identifying the pair of parquet files to process.
        (The name shadows the ``id`` builtin inside this function; it is kept
        so existing callers remain valid.)
    grid_shape : tuple of int, default ``(150, 150)``
        Shape each flattened row is reshaped to before graph construction.

    Returns
    -------
    None
        The function is used for its side effect of writing ``.pt`` files.
    """
    out_dir = '../data/processed'
    # torch.save does not create missing parent directories. exist_ok=True keeps
    # this safe when several parallel workers reach this line at the same time.
    os.makedirs(out_dir, exist_ok=True)

    grid_df = pd.read_parquet(f'../data/{id}_grid.parquet')
    distribution_df = pd.read_parquet(f'../data/{id}_distribution.parquet')

    # zip pairs rows positionally and stops at the shorter of the two frames.
    for j, (grid, distribution) in enumerate(zip(grid_df.to_numpy(), distribution_df.to_numpy())):
        node_positions, edges, mask, input_nodes, target_nodes = grid_to_graph(
            grid.reshape(grid_shape), distribution.reshape(grid_shape)
        )
        node_positions = torch.tensor(node_positions, dtype=torch.long)
        # Transpose so the edge tensor has the (2, num_edges) layout used for
        # edge_index, and make it contiguous after the transpose.
        edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
        mask = torch.tensor(mask, dtype=torch.bool)
        input_nodes = torch.tensor(input_nodes, dtype=torch.float32)
        target_nodes = torch.tensor(target_nodes, dtype=torch.float32)

        data = Data(x=input_nodes, edge_index=edges, mask=mask, pos=node_positions, y=target_nodes)
        torch.save(data, f'{out_dir}/{id}_{j}.pt')

# Discover the distinct ids: the part of each parquet file name that precedes
# the first underscore.
filenames = os.listdir('../data')
ids = set(
    name.partition('_')[0]
    for name in filenames
    if name.endswith('.parquet')
)

# Run process_id once per id across 4 worker processes. The trailing semicolon
# keeps the notebook from echoing the returned list of None values.
Parallel(verbose=10, n_jobs=4)(
    delayed(process_id)(sample_id) for sample_id in ids
);

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


[Parallel(n_jobs=4)]: Done   5 tasks      | elapsed:  4.2min
[Parallel(n_jobs=4)]: Done  10 tasks      | elapsed:  7.1min
[Parallel(n_jobs=4)]: Done  17 tasks      | elapsed: 11.3min
[Parallel(n_jobs=4)]: Done  24 tasks      | elapsed: 14.9min
[Parallel(n_jobs=4)]: Done  33 tasks      | elapsed: 18.5min
[Parallel(n_jobs=4)]: Done  42 tasks      | elapsed: 21.8min
[Parallel(n_jobs=4)]: Done  53 tasks      | elapsed: 26.4min
[Parallel(n_jobs=4)]: Done  64 tasks      | elapsed: 32.0min
[Parallel(n_jobs=4)]: Done  77 tasks      | elapsed: 37.0min
[Parallel(n_jobs=4)]: Done  90 tasks      | elapsed: 44.6min
[Parallel(n_jobs=4)]: Done 100 out of 100 | elapsed: 49.2min finished


[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [4]:
class GraphDataset(Dataset):
    """Dataset of pre-computed graphs stored as one ``.pt`` file per sample.

    Every ``.pt`` file located directly inside ``root`` is treated as one
    sample. Files are not read at construction time; each one is loaded with
    ``torch.load`` when its index is requested.

    Parameters
    ----------
    root : str
        Directory containing the serialized graph files.
    """

    def __init__(self, root):
        super(GraphDataset, self).__init__(root)
        self.root = root
        # os.listdir returns entries in arbitrary order and also lists anything
        # else that lives in the directory. Keep only the serialized graphs and
        # sort them so that index -> sample is stable across runs and machines.
        self.file_names = sorted(
            name for name in os.listdir(root) if name.endswith('.pt')
        )

    def len(self):
        """Return the number of graph files found in ``root``."""
        return len(self.file_names)

    def get(self, idx):
        """Load and return the graph stored in the ``idx``-th file (sorted order)."""
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted data. Newer PyTorch releases default to weights_only=True,
        # which may reject pickled Data objects - confirm against the
        # installed version.
        return torch.load(os.path.join(self.root, self.file_names[idx]))
    
# Index the serialized graphs in the processed-data directory.
dataset = GraphDataset(root='../data/processed')
# Sanity check: report how many samples the dataset exposes.
print(len(dataset))

100000


In [5]:
# Smoke-test batching: build a shuffled loader that collates 32 graphs per
# mini-batch and print the summary of the first batch it yields.
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(batch)
    # Only the first batch is inspected; stop before iterating the rest.
    break

DataBatch(x=[35443, 13], edge_index=[2, 119356], y=[35443, 1], pos=[35443, 2], mask=[35443, 1], batch=[35443], ptr=[33])
