In [1]:
import pandas as pd
import h5py
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import pymetis
import networkx as nx
import time
from networkx.algorithms import community
from random import shuffle
import math
import torch
import torch.nn as nn
import torch_geometric as tg
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops, degree
from torch.nn import init
import pdb
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from torch_geometric.data import Data
import torch.optim as optim
import pywt
from scipy.stats import norm
import scipy.interpolate as interp
import collections

In [2]:
def open_data(file_path):
    """Load and return the pickled object stored at ``file_path``.

    The file is opened in binary mode inside a ``with`` block, so it is closed
    before returning even if unpickling raises.

    Security: ``pickle.load`` can execute arbitrary code while unpickling.
    Only call this on files produced by this project, never on untrusted input.

    Parameters
    ----------
    file_path : str or path-like
        Path to a pickle file.

    Returns
    -------
    object
        Whatever object was pickled into the file.
    """
    with open(file_path, "rb") as file:
        return pickle.load(file)

In [3]:
def read_data(file_path):
    """Read every top-level dataset of an HDF5 file into a dictionary.

    The dataset names found are printed first. Each dataset is then loaded
    fully into memory: scalar datasets (shape ``()``) via ``ds[()]``, all
    others via ``ds[:]``.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 file.

    Returns
    -------
    dict
        Mapping of dataset name -> loaded value.
    """
    with h5py.File(file_path, 'r') as h5:
        # Names of all top-level entries in the file
        names = [*h5.keys()]
        print("Datasets in the file:", names)
        loaded = {}
        for key in names:
            dataset = h5[key]
            # Scalars cannot be sliced with [:], so read them with [()]
            loaded[key] = dataset[()] if dataset.shape == () else dataset[:]
        return loaded

In [4]:
def renumber_subgraph(nodes, edge_index):
    """Relabel the endpoints of ``edge_index`` to contiguous ids ``0..n-1``.

    Each node id is replaced by its rank among the sorted unique values of
    ``nodes``. Therefore ``nodes`` must contain every id appearing in rows 0
    and 1 of ``edge_index``.

    Parameters
    ----------
    nodes : torch.Tensor
        Node ids of the sub-graph (duplicates allowed, any order).
    edge_index : torch.Tensor
        Integer tensor whose rows 0 and 1 hold the two endpoints of each edge,
        one column per edge.

    Returns
    -------
    torch.Tensor
        ``torch.long`` tensor of shape ``(2, num_edges)`` with relabelled ids.
        It has shape ``(2, 0)`` when there are no edges.

    Raises
    ------
    KeyError
        If an edge endpoint is not present in ``nodes``.
    """
    unique_nodes = torch.unique(nodes, sorted=True)
    edge_index = torch.as_tensor(edge_index)
    if edge_index.size(1) == 0:
        # Keep the (2, E) layout even with no edges.
        return torch.empty((2, 0), dtype=torch.long)
    endpoints = edge_index[:2].reshape(-1).to(unique_nodes.dtype)
    if unique_nodes.numel() == 0:
        raise KeyError(endpoints[0].item())
    # unique_nodes is sorted, so a binary search yields each id's rank directly.
    positions = torch.searchsorted(unique_nodes, endpoints)
    clamped = positions.clamp(max=unique_nodes.numel() - 1)
    missing = unique_nodes[clamped] != endpoints
    if bool(missing.any()):
        raise KeyError(endpoints[missing][0].item())
    return positions.view(2, -1)

In [5]:
def _anchor_set_num(node_num):
    """Return floor(0.5 * L) * L with L = floor(log2(node_num)); 0 for an empty sub-graph."""
    if node_num < 1:
        # log2(0) is -inf and int(-inf) raises; an empty sub-graph gets no anchors.
        return 0
    log_n = int(np.log2(node_num))
    return int(0.5 * log_n) * log_n


def subgraph_information(city_names, root_path):
    """Load every city's sub-graph edge lists and derive per-sub-graph metadata.

    For each city, every file in ``<root_path>/<city>/edge_pair/`` is unpickled
    as one sub-graph's edge list. Files are processed in sorted filename order,
    because ``os.listdir`` order is arbitrary. Sorting makes the k-th sub-graph
    the same on every machine.

    The returned dictionary holds, for each city ``name``:

    * ``name + "_nodes"``: list of sorted unique node-id tensors, one per sub-graph
    * ``name + "_edge_pair"``: list of edge tensors relabelled to ``0..n-1``
    * ``name + "_subgraph_node_num"``: list of node counts, one per sub-graph
    * ``name + "_anchor_set_num"``: list of ``floor(0.5*L) * L`` with
      ``L = floor(log2(node_count))`` (0 for an empty sub-graph)
    * ``name + "_city_node_num"``: node counts summed over the city's sub-graphs

    Parameters
    ----------
    city_names : list of str
        City directory names under ``root_path``.
    root_path : str
        Data root directory.

    Returns
    -------
    tuple
        ``(edge_pair_dictionary, max_ach_num, max_subgraph_node_num)``. The two
        maxima are taken across all cities and are 0 if nothing was found.
    """
    edge_pair_dictionary = {}
    max_ach_num = 0
    max_subgraph_node_num = 0

    for name in city_names:
        edge_dir = os.path.join(root_path, name, "edge_pair")
        # os.listdir order is arbitrary. Sort so sub-graph k is stable, because
        # per-sub-graph arrays (e.g. dist_max[k]) are indexed by position later.
        edge_files = sorted(os.listdir(edge_dir))

        # Edge pairs of every sub-graph, in the original node-id space.
        edge_pairs = [
            torch.tensor(open_data(os.path.join(edge_dir, fname)), dtype=torch.long)
            for fname in edge_files
        ]

        # Sorted unique node ids of every sub-graph.
        comm_node_list = [torch.unique(pairs, sorted=True) for pairs in edge_pairs]
        edge_pair_dictionary[name + "_nodes"] = comm_node_list

        # Edge pairs relabelled to 0..n-1 within each sub-graph.
        edge_pair_dictionary[name + "_edge_pair"] = [
            renumber_subgraph(nodes, pairs) for nodes, pairs in zip(comm_node_list, edge_pairs)
        ]

        # Node count of every sub-graph.
        graph_num = [nodes.shape[0] for nodes in comm_node_list]
        edge_pair_dictionary[name + "_subgraph_node_num"] = graph_num

        # Anchor-set count of every sub-graph.
        ach_set_nums = [_anchor_set_num(n) for n in graph_num]
        edge_pair_dictionary[name + "_anchor_set_num"] = ach_set_nums

        # Total node count of the city (sum over its sub-graphs).
        edge_pair_dictionary[name + "_city_node_num"] = sum(graph_num)

        # Running maxima across cities; default=0 tolerates a city with no files.
        max_ach_num = max(max_ach_num, max(ach_set_nums, default=0))
        max_subgraph_node_num = max(max_subgraph_node_num, max(graph_num, default=0))

    return edge_pair_dictionary, max_ach_num, max_subgraph_node_num

In [6]:
class CustomData(Data):
    """PyG ``Data`` container for one zero-padded sub-graph sample.

    Every constructor argument defaults to ``None``. This lets PyG instantiate
    the class without arguments, which its documentation recommends for custom
    ``Data`` subclasses used with batching. Existing positional calls behave
    as before.

    Attributes
    ----------
    trend, period : tensor
        Input features of the sub-graph's nodes.
    target_volume : tensor
        Target features of the sub-graph's nodes.
    target_label : tensor
        Per-node label values.
    edge_pairs : tensor
        Sub-graph edges with node ids relabelled to ``0..n-1``.
    subgraph_node_num : int
        Number of real (non-padding) nodes in the sub-graph.
    subgraph_nodes : tensor
        Original node ids of the sub-graph.
    city_node_num : int
        Total node count of the city the sub-graph belongs to.
    dist_max, dist_argmax : tensor
        Per-sub-graph slices of the city's distance dictionary.

    NOTE(review): when batching, PyG auto-offsets only attributes whose name
    contains "index". ``edge_pairs`` will therefore not be offset. Confirm
    this is intended, or override ``__inc__``.
    """

    def __init__(self, trend=None, period=None, target_volume=None, target_label=None,
                 edge_pairs=None, subgraph_node_num=None, subgraph_nodes=None,
                 city_node_num=None, dist_max=None, dist_argmax=None):
        super().__init__()
        self.trend = trend
        self.period = period
        self.target_volume = target_volume
        self.target_label = target_label
        self.edge_pairs = edge_pairs
        self.subgraph_node_num = subgraph_node_num
        self.subgraph_nodes = subgraph_nodes
        self.city_node_num = city_node_num
        self.dist_max = dist_max
        self.dist_argmax = dist_argmax

In [7]:
def _gather_subgraph(features, mask, num_nodes, max_nodes):
    """Select one sub-graph's nodes from a 3-D ``features`` tensor and zero-pad dim 1 to ``max_nodes``.

    ``mask`` is a boolean tensor of shape ``features.shape[:2] + (1,)`` that
    marks the sub-graph's nodes along dim 1. The selected rows fill the first
    ``num_nodes`` slots of dim 1; the remaining slots stay zero.
    """
    num_steps, _, num_feats = features.shape
    padded = torch.zeros((num_steps, max_nodes, num_feats))
    selected = features[mask.expand_as(features)].view(num_steps, -1, num_feats)
    padded[:, :num_nodes, :] = selected
    return padded


def source_data_prepare(city_names, root_path, purpose, seed=None):
    """Build zero-padded per-sub-graph ``CustomData`` samples for each city and save them.

    For each city, the function:

    * reads ``<city>/input_target/<purpose>.h5`` (trend / period / target
      templates) and ``<city>/input_target/dist_dictionary.h5``, a pickle
      holding ``dist_max`` and ``dist_argmax``;
    * splits every sample into one ``CustomData`` per sub-graph, with node
      features zero-padded to the largest sub-graph across all cities;
    * shuffles the list and writes it to
      ``<city>/input_target/<purpose>_regional_level.pt``.

    Column layout used below: column 0 holds the node id (matched against each
    sub-graph's nodes). Columns 1-4 are copied as features. Column 5 of the
    trend template is copied as ``target_label``.

    Parameters
    ----------
    city_names : list of str
        City directory names under ``root_path``.
    root_path : str
        Data root directory.
    purpose : str
        Split name (e.g. ``'train'``); selects the input file and output name.
    seed : int, optional
        If given, each city's list is shuffled with ``random.Random(seed)``, so
        the saved order is reproducible. ``None`` (default) keeps the original
        global ``random.shuffle`` behaviour.

    Returns
    -------
    list of CustomData
        The samples of the LAST city processed only, because each city's list
        is saved separately. Returns ``[]`` if ``city_names`` is empty.
    """
    import random  # local: only needed for the optional seeded shuffle

    edge_pair_dictionary, _, max_subgraph_node_num = subgraph_information(city_names, root_path)
    rng = random.Random(seed) if seed is not None else None

    # Defined up-front so an empty city_names returns [] rather than raising.
    data_list = []
    for city in city_names:
        data_list = []
        input_dir = os.path.join(root_path, city, "input_target")
        city_dict = read_data(os.path.join(input_dir, f"{purpose}.h5"))
        dist_dict = open_data(os.path.join(input_dir, "dist_dictionary.h5"))

        dist_max = dist_dict['dist_max']
        dist_argmax = dist_dict['dist_argmax'].to(torch.int)
        edge_pairs = edge_pair_dictionary[city + "_edge_pair"]
        subgraph_node_num = edge_pair_dictionary[city + "_subgraph_node_num"]
        subgraph_nodes = edge_pair_dictionary[city + "_nodes"]
        city_node_num = edge_pair_dictionary[city + "_city_node_num"]

        for j in range(city_dict["period_template"].shape[0]):
            trend = torch.tensor(city_dict["trend_template"][j, :, :, :6], dtype=torch.float)
            period = torch.tensor(city_dict["period_template"][j, :, :, :6], dtype=torch.float)
            targets = torch.tensor(city_dict["target_template"][j, :, :, :], dtype=torch.float)

            # One boolean mask per sub-graph, matched on column 0 (node id), with a trailing size-1 dim.
            input_masks = [torch.isin(trend[:, :, 0], nodes).unsqueeze(-1) for nodes in subgraph_nodes]
            target_masks = [torch.isin(targets[:, :, 0], nodes).unsqueeze(-1) for nodes in subgraph_nodes]

            for k, (in_mask, tgt_mask) in enumerate(zip(input_masks, target_masks)):
                n_k = subgraph_node_num[k]
                trend_k = _gather_subgraph(trend[:, :, 1:5], in_mask, n_k, max_subgraph_node_num)
                period_k = _gather_subgraph(period[:, :, 1:5], in_mask, n_k, max_subgraph_node_num)
                target_vol_k = _gather_subgraph(targets[:, :, 1:5], tgt_mask, n_k, max_subgraph_node_num)
                # NOTE(review): the label comes from column 5 of the *trend*
                # template, not the target template (the target-based variant
                # was commented out). Confirm this is intended.
                target_label_k = _gather_subgraph(trend[:, :, 5:6], in_mask, n_k, max_subgraph_node_num)

                data_list.append(CustomData(
                    trend_k, period_k, target_vol_k, target_label_k, edge_pairs[k], n_k,
                    subgraph_nodes[k], city_node_num, dist_max[k, :, :, :], dist_argmax[k, :, :, :]))

        print(f"{city} data is prepared")

        if rng is None:
            shuffle(data_list)
        else:
            rng.shuffle(data_list)
        torch.save(data_list, os.path.join(input_dir, f"{purpose}_regional_level.pt"))
    print('done')

    return data_list

In [8]:
# Cities to preprocess and the data split to build.
city_names = ["Barcelona", "Antwerp"]
# Data root. Set THESIS_DATA_ROOT to override it instead of editing this cell.
# It is normalized to end with a path separator, because path-building code in
# this notebook appends sub-paths to it directly.
root_path = os.path.join(os.environ.get("THESIS_DATA_ROOT", "D:/ThesisData/processed data/"), "")
purpose = 'train'

In [9]:
# Builds and saves <city>/input_target/<purpose>_regional_level.pt for every city;
# the returned list holds the samples of the last city processed.
data_list = source_data_prepare(city_names=city_names, root_path=root_path, purpose=purpose)

Datasets in the file: ['period_template', 'target_template', 'trend_template']
Barcelona data is prepared
Datasets in the file: ['period_template', 'target_template', 'trend_template']
Antwerp data is prepared
done
