# In [3]:
import os
import pickle
from torch_geometric.data import Data
import torch

def process_graphs(graph_dict):
    """Convert a mapping of raw graph records into a list of ``Data`` objects.

    Records are visited in ascending key order, so the position of each
    element in the returned list follows ``sorted(graph_dict.keys())``.

    Args:
        graph_dict: Mapping whose values are indexable by the keys
            'barcode', 'PI', 'filtration_val', 'source_edge_index',
            'sink_edge_index' and 'label'.

    Returns:
        A list with one ``Data`` object per record, carrying the fields
        ``x``, ``source_edge_index``, ``sink_edge_index``,
        ``barcode_ground``, ``PI`` and ``label``.
    """

    def _to_data(record):
        # Build every tensor up front, in a fixed order, before assembling
        # the Data object.
        barcode_t = torch.tensor(record['barcode'])
        pi_t = torch.tensor(record['PI'])
        # Filtration values become a single-column node feature matrix.
        features_t = torch.tensor(record['filtration_val']).view(-1, 1)
        source_t = torch.tensor(record['source_edge_index'])
        sink_t = torch.tensor(record['sink_edge_index'])
        return Data(
            x=features_t,
            source_edge_index=source_t,
            sink_edge_index=sink_t,
            barcode_ground=barcode_t,
            PI=pi_t,
            # The label is wrapped in a list, giving a one-element tensor.
            label=torch.tensor([record['label']]),
        )

    ordered_keys = sorted(graph_dict.keys())
    return [_to_data(graph_dict[name]) for name in ordered_keys]

# Name of the dataset to convert; it selects the input pickle and is embedded
# in both output file names.
dataset_name = 'citation'

# Pick the directory holding the raw pickle for this dataset family.
if dataset_name in ['citation', 'bitcoin', 'question', 'social']:
    input_dir = 'dynamic_dataset'
elif dataset_name in ['RedditB', 'Reddit5K', 'Reddit12K']:
    input_dir = 'static_dataset'
else:
    # Fail here with a clear message instead of a NameError on `data_path`
    # further down.
    raise ValueError(f'Unknown dataset name: {dataset_name!r}')

data_path = os.path.join(input_dir, f'{dataset_name}.pkl')
# NOTE(review): the processed files are written under 'dynamic_dataset' for
# BOTH dataset families, exactly as before. For the static datasets this may
# be a copy-paste slip ('static_dataset' intended?) -- confirm against the
# code that loads these files before changing the location.
small_data_path = os.path.join('dynamic_dataset', f'{dataset_name}_small_graph.pkl')
large_data_path = os.path.join('dynamic_dataset', f'{dataset_name}_large_graph.pkl')

# pickle.load can execute arbitrary code while unpickling; only run this on
# dataset files from a trusted source.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

# The raw pickle is indexed by 'small_graph' and 'big_graph'; each entry is
# converted separately.
small_data = process_graphs(data['small_graph'])
large_data = process_graphs(data['big_graph'])

with open(small_data_path, 'wb') as f:
    pickle.dump(small_data, f, pickle.HIGHEST_PROTOCOL)

with open(large_data_path, 'wb') as f:
    pickle.dump(large_data, f, pickle.HIGHEST_PROTOCOL)


  # NOTE(review): these two indented statements repeat the edge-index tensor
  # construction found inside process_graphs, but they sit outside any
  # function and `graph` is not defined at module scope. Presumably leftover
  # notebook/paste residue -- confirm and remove.
  source_edge_index = torch.tensor(graph['source_edge_index'])
  sink_edge_index = torch.tensor(graph['sink_edge_index'])
