In [1]:
import torch
from torch_geometric.data import Data

In [2]:
# Toy graph with three nodes and the directed edges 0->1, 1->0, 1->2, 2->1,
# written as one (source, target) pair per row.
edge_index = torch.tensor(
    [[0, 1], [1, 0], [1, 2], [2, 1]],
    dtype=torch.long,
)
# One scalar feature per node.
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# Transpose the (4, 2) pair list into a (2, 4) tensor and make it contiguous
# before handing it to Data.
data = Data(edge_index=torch.transpose(edge_index, 0, 1).contiguous(), x=x)

In [11]:
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm

def load_data(dir_name):
    """
    Load every graph stored in ``dir_name`` as a list of PyG ``Data`` objects.

    Four text files are read from ``dir_name``:

    * ``graph_indicator.txt``  - one graph id per node,
    * ``edgelist.txt``         - comma-separated node index pairs,
    * ``node_attributes.txt``  - comma-separated node features, one row per node,
    * ``edge_attributes.txt``  - comma-separated edge features, one row per edge.

    The adjacency matrix built from the edge list is symmetrised (``A + A.T``)
    and then cut into one diagonal block per graph. Graph ``i`` owns the next
    ``graph_size[i]`` consecutive nodes, so nodes must be stored contiguously
    per graph, in the order of the sorted graph ids.

    Args:
        dir_name: Directory containing the four files. A trailing path
            separator is optional.

    Returns:
        list: One ``Data(x, edge_index, edge_attr)`` per graph, where ``x`` and
        ``edge_attr`` are ``torch.float64`` tensors and ``edge_index`` is a
        ``torch.long`` tensor with two rows.

    Raises:
        ValueError: If a built graph does not pass ``Data.validate()``.
    """
    import os  # local import: only needed here to build the file paths

    def _path(file_name):
        # os.path.join works with or without a trailing separator on dir_name.
        return os.path.join(dir_name, file_name)

    print('load graph indicator')
    graph_indicator = np.loadtxt(_path("graph_indicator.txt"), dtype=np.int64)
    # Number of nodes belonging to each graph id (ids are returned sorted).
    _, graph_size = np.unique(graph_indicator, return_counts=True)

    print('load edges')
    # ndmin=2 keeps the (rows, cols) layout even when a file has a single row.
    edges = np.loadtxt(_path("edgelist.txt"), dtype=np.int64, delimiter=",", ndmin=2)
    num_nodes = graph_indicator.size
    A = sp.csr_matrix(
        (np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
        shape=(num_nodes, num_nodes),
    )
    # Symmetrise so every edge is present in both directions.
    A = (A + A.T).tocsr()

    print('load nodes')
    x = np.loadtxt(_path("node_attributes.txt"), delimiter=",", ndmin=2)
    print('load edge attributes')
    edge_attr = np.loadtxt(_path("edge_attributes.txt"), delimiter=",", ndmin=2)

    idx_n = 0  # index of the first node of the current graph
    idx_m = 0  # index of the first edge-attribute row of the current graph
    datasets = []

    print('build graphs')
    for i in tqdm(range(graph_size.size)):
        node_stop = idx_n + int(graph_size[i])
        node_feature = torch.tensor(x[idx_n:node_stop, :], dtype=torch.float64)

        # Diagonal block of the global adjacency = adjacency of graph i.
        adj = A[idx_n:node_stop, idx_n:node_stop]
        coo = adj.tocoo()  # converted once, used for both rows and columns
        edge_index = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)

        # NOTE(review): edge-attribute rows are consumed adj.nnz at a time, in
        # the order of the symmetrised adjacency's stored entries. This assumes
        # edge_attributes.txt is laid out that way - confirm against the data.
        edge_stop = idx_m + adj.nnz
        # Converted to a tensor so that every attribute of Data is a tensor.
        edge_feature = torch.tensor(edge_attr[idx_m:edge_stop, :], dtype=torch.float64)

        data = Data(x=node_feature, edge_index=edge_index, edge_attr=edge_feature)
        # Fail loudly instead of returning a silently truncated dataset.
        if not data.validate(raise_on_error=True):
            raise ValueError(f"graph {i} failed validation: {data}")
        datasets.append(data)

        idx_n = node_stop
        idx_m = edge_stop

    return datasets

In [12]:
# Build the list of graphs from the raw text files.
datasets = load_data(dir_name='./data/raw/')

load graph indicator
load edges
load nodes
load edge attributes
build graphs


100%|██████████| 6111/6111 [00:08<00:00, 755.97it/s]


In [13]:
# Number of graphs returned by load_data.
len(datasets)

6111

In [14]:
# Show the first graph's summary (shapes of x, edge_index and edge_attr).
datasets[0]

Data(x=[327, 86], edge_index=[2, 6233], edge_attr=[6233, 5])

In [17]:
import os.path as osp
# Output directory for the per-graph .pt files written by the cell below.
dir_name = './data/'

In [18]:
# Write every graph to its own file, named data_<position in datasets>.pt.
for idx in range(len(datasets)):
    data = datasets[idx]
    save_path = osp.join(dir_name, f'data_{idx}.pt')
    torch.save(data, save_path)