# Convert raw datasets into PyTorch Geometric datasets and save them


## Graph creation

This notebook covers the generation, conversion, and storage of grid-based graph datasets for Graph Neural Networks (GNNs), tailored to PyTorch Geometric (PyG). The intended applications are environmental simulations such as terrain modeling or flood prediction, where the spatial arrangement of cells and their attributes are crucial.

### Function Descriptions

- **`center_grid_graph(dim1, dim2)`**:  
    Creates a directed graph from a rectangular grid of specified dimensions, where each grid cell represents a node. Edges connect adjacent nodes to model spatial continuity. Nodes are placed at the centers of grid cells, establishing the spatial structure for environmental modeling.

- **`get_coords(pos)`**:  
    Extracts x and y coordinates of each node from a position dictionary and returns them in a NumPy array. This utility is useful for numerical operations involving node positions, such as distance calculations.

- **`get_corners(pos)`** and **`get_contour(pos)`**:  
    These functions compute the coordinates of the grid's corners and its contour. Such utilities are helpful for tasks requiring awareness of spatial boundaries or the grid's external shape.

- **`reorder_dict(dictt)`**:  
    Sorts a dictionary by its keys and assigns new sequential integer keys in that order. This simplifies data access and ensures consistency when working with graph attributes.

- **`convert_to_pyg(graph, pos, DEM, WD, VX, VY)`**:  
    Converts a graph into a PyTorch Geometric `Data` object. Node attributes include position, elevation (DEM), water depth (WD), and the `VX` and `VY` fields; edge attributes include the distance, relative position, and elevation slope between connected nodes. These attributes give a GNN access to the physically relevant structure of the domain.

- **`create_grid_dataset(dataset_folder, n_sim, start_sim=1, number_grids=64)`**:  
    Automates the loading of simulation data, generating a dataset of `Data` objects. This function streamlines the preparation of datasets for GNN model training or testing.

- **`save_database(dataset, name, out_path='datasets')`**:  
    Saves the generated dataset to disk using pickle serialization. This ensures the dataset can be easily retrieved for future use.

- **`create_dataset_folders(dataset_folder='datasets')`**:  
    Prepares the directory structure for dataset storage, with separate folders for training and testing data. This organization aids in dataset management.
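
As a minimal sketch of what the graph-creation and conversion steps compute, the following pure-Python example rebuilds a 2 x 2 grid by hand and derives the per-edge slope in the same way as `convert_to_pyg` (elevation of the target node minus elevation of the source node, divided by the centre-to-centre distance). The helper names `grid_centers`, `grid_edges`, and `edge_slopes` and the toy elevations are hypothetical and not part of the notebook. The row-by-row node numbering is assumed to match the order in which `nx.grid_2d_graph` creates its nodes.

```python
import math

def grid_centers(dim1, dim2):
    """Map node index -> cell-centre position, numbering nodes row by row."""
    return {i * dim2 + j: (i + 0.5, j + 0.5)
            for i in range(dim1) for j in range(dim2)}

def grid_edges(dim1, dim2):
    """Directed edges (both directions) between 4-neighbour cells."""
    edges = []
    for i in range(dim1):
        for j in range(dim2):
            k = i * dim2 + j
            if i + 1 < dim1:  # neighbour in the first grid direction
                edges += [(k, k + dim2), (k + dim2, k)]
            if j + 1 < dim2:  # neighbour in the second grid direction
                edges += [(k, k + 1), (k + 1, k)]
    return edges

def edge_slopes(edges, pos, dem):
    """Slope per edge: elevation difference over centre-to-centre distance."""
    slopes = []
    for src, dst in edges:
        dist = math.dist(pos[src], pos[dst])
        slopes.append((dem[dst] - dem[src]) / dist)
    return slopes

pos = grid_centers(2, 2)
edges = grid_edges(2, 2)
dem = [0.0, 1.0, 2.0, 4.0]  # toy elevations, one per node
slopes = edge_slopes(edges, pos, dem)
```

Because every edge is stored in both directions, each pair of neighbouring cells contributes two slopes of equal magnitude and opposite sign.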



In [1]:
import numpy as np
import networkx as nx
import os
import pickle
from tqdm import tqdm
import torch
from torch_geometric.data import Data

def center_grid_graph(dim1, dim2):
    '''
    Create graph from a rectangular grid of dimensions dim1 x dim2
    Returns networkx graph connecting the grid centers and corresponding 
    node positions
    ------
    dim1: int
        number of grids in the x direction
    dim2: int
        number of grids in the y direction
    '''
    G = nx.grid_2d_graph(dim1, dim2, create_using=nx.DiGraph)
    # for the position, it is assumed that they are located in the centre of each grid
    pos = {i:(x+0.5,y+0.5) for i, (x,y) in enumerate(G.nodes())}
    
    #change keys from (x,y) format to i format
    mapping = dict(zip(G, range(0, G.number_of_nodes())))
    G = nx.relabel_nodes(G, mapping)

    return G, pos

def get_coords(pos):
    '''
    Returns array of dimensions (n_nodes, 2) containing x and y coordinates of each node
    ------
    pos: dict
        keys: (x,y) index of every node
        values: spatial x and y positions of each node
    '''
    return np.array([xy for xy in pos.values()])
	

def get_corners(pos):
    '''
    Returns the coordinates of the corners of a grid
    ------
    pos: dict
        keys: (x,y) index of every node
        values: spatial x and y positions of each node
    '''    
    BL = min(pos.values()) #bottom-left
    TR = max(pos.values()) #top-right
    BR = (TR[0], BL[1]) #bottom-right
    TL = (BL[0], TR[1]) #top-left
    
    return BL, TR, BR, TL

def get_contour(pos):
    '''
    Returns a dictionary with the contours of a grid
    ------
    pos: dict
        keys: (x,y) index of every node
        values: spatial x and y positions of each node
    '''
    BL, TR, BR, TL = get_corners(pos)
    
    x_pos = np.arange(BL[0], TR[0]+1)
    y_pos = np.arange(BL[1], TR[1]+1)
    
    bottom = [(x, BL[1]) for x in x_pos]
    left = [(BL[0], y) for y in y_pos]
    right = [(TR[0], y) for y in y_pos]
    top = [(x, TR[1]) for x in x_pos]
    
    contour = {}

    for point in (bottom + left + right + top):
        key = list(pos.keys())[list(pos.values()).index(point)]
        contour[point] = pos[key]
    
    return contour

def reorder_dict(dictt):
    '''
    Sorts a dictionary by its keys and replaces them with sequential integer keys
    '''
    new_dict = {}
    
    #sort the entries by key so the new integer keys follow a consistent order
    dictt = dict(sorted(dictt.items()))

    #change keys from (x,y) format to i format
    for i, key in enumerate(dictt.keys()):
        new_dict[i] = dictt[key]
        
    return new_dict

def convert_to_pyg(graph, pos, DEM, WD, VX, VY):
    '''Converts a graph or mesh into a PyTorch Geometric Data type 
    Then, add position, DEM, and water variables to data object'''
    DEM = DEM.reshape(-1)

    edge_index = torch.LongTensor(list(graph.edges)).t().contiguous()
    row, col = edge_index

    data = Data()

    delta_DEM = torch.FloatTensor(DEM[col]-DEM[row])
    coords = torch.FloatTensor(get_coords(pos))
    edge_relative_distance = coords[col] - coords[row]
    edge_distance = torch.norm(edge_relative_distance, dim=1)
    edge_slope = delta_DEM/edge_distance

    data.edge_index = edge_index
    data.edge_distance = edge_distance
    data.edge_slope = edge_slope
    data.edge_relative_distance = edge_relative_distance

    data.num_nodes = graph.number_of_nodes()
    data.pos = torch.tensor(list(pos.values()))
    data.DEM = torch.FloatTensor(DEM)
    data.WD = torch.FloatTensor(WD.T)
    data.VX = torch.FloatTensor(VX.T)
    data.VY = torch.FloatTensor(VY.T)
        
    return data

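def create_grid_dataset(dataset_folder, n_sim, start_sim=1, number_grids=64):
    '''
    Creates a list of PyTorch Geometric Data objects, one per simulation
    ------
    dataset_folder: str, path-like
        raw dataset folder containing the DEM, WD, VX, and VY subfolders
    n_sim: int
        number of simulations to load
    start_sim: int
        ID of the first simulation to load
    number_grids: int
        number of grid cells along each side of the square domain
    '''
    assert os.path.exists(dataset_folder), "There is no raw dataset folder"
    grid_dataset = []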

    graph, pos = center_grid_graph(number_grids,number_grids)
    
    for i in tqdm(range(start_sim, start_sim + n_sim)):
        DEM_path = os.path.join(dataset_folder, "DEM", f"DEM_{i}.txt")
        WD_path = os.path.join(dataset_folder, "WD", f"WD_{i}.txt")
        VX_path = os.path.join(dataset_folder, "VX", f"VX_{i}.txt")
        VY_path = os.path.join(dataset_folder, "VY", f"VY_{i}.txt")

        DEM = np.loadtxt(DEM_path)[:, 2]
        WD = np.loadtxt(WD_path)
        VX = np.loadtxt(VX_path)
        VY = np.loadtxt(VY_path)
        
        grid_i = convert_to_pyg(graph, pos, DEM, WD, VX, VY)
        grid_dataset.append(grid_i)
    
    return grid_dataset


def save_database(dataset, name, out_path='datasets'):
    '''
    This function saves the geometric database into a pickle file
    The file is saved as <out_path>/<name>.pkl and overwrites any existing file
    ------
    dataset: list
        list of geometric datasets for grid and mesh
    name: str
        name of saved dataset
    out_path: str, path-like
        output file location
    '''
    path = os.path.join(out_path, f"{name}.pkl")

    # create the output folder (including parent folders) if it does not exist
    os.makedirs(out_path, exist_ok=True)

    # "wb" mode overwrites an existing file; the context manager closes the file
    with open(path, "wb") as f:
        pickle.dump(dataset, f)
        
    return None

def create_dataset_folders(dataset_folder='datasets'):
    '''Creates the dataset folder with its train and test subfolders'''
    for subfolder in ('train', 'test'):
        # exist_ok=True makes the call a no-op when the folder already exists
        os.makedirs(os.path.join(dataset_folder, subfolder), exist_ok=True)



### Workflow Overview

1. **Simulation IDs Definition**:  
    - A list named `simulation_ids` is initialized, defining various simulation parameters. Each entry specifies the dataset type (e.g., `grid`, `random_breach_grid`), the directory for saving the dataset (`datasets/train` or `datasets/test`), the starting simulation ID, the number of simulations to generate, and the grid dimensions.

2. **Dataset Folder Preparation**:  
    - The `create_dataset_folders` function is called to ensure the appropriate directory structure exists within a base `datasets` folder. This structure includes separate subdirectories for training and testing datasets.

3. **Path Debugging**:  
    - A debug print statement confirms the absolute path for a DEM file, helping to verify the correct directory structure and file naming convention.

4. **Dataset Generation and Saving**:  
    - The script iterates over each entry in `simulation_ids`. For each entry, `create_grid_dataset` builds a list of PyG `Data` objects from the DEM, WD, VX, and VY text files of the given simulation range, and `save_database` then pickles that list to the specified directory.
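
To make the file naming convention in step 4 concrete, the sketch below lists the raw files that would be read for one range of simulation IDs, following the `DEM_{i}.txt`-style paths built in `create_grid_dataset`. The generator `simulation_files` is a hypothetical helper for illustration only, and `raw_datasets` is a placeholder folder name.

```python
import os

def simulation_files(dataset_folder, start_sim, n_sim):
    """Yield (sim_id, {variable: path}) for each simulation in the range."""
    for sim_id in range(start_sim, start_sim + n_sim):
        yield sim_id, {
            var: os.path.join(dataset_folder, var, f"{var}_{sim_id}.txt")
            for var in ("DEM", "WD", "VX", "VY")
        }

# Same start ID and count as the ['grid', 'datasets/test', 500, 20, 64] entry
files = dict(simulation_files("raw_datasets", start_sim=500, n_sim=20))

for var, path in files[500].items():
    print(var, path)
```

A check like this can be run before the full conversion to confirm that every expected file exists, for example with `os.path.exists(path)`.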

### Key Functions

- **`create_grid_dataset`**: Automates the creation of grid-based datasets from specified simulation parameters, loading environmental data and structuring it into a format suitable for GNNs.
- **`save_database`**: Serializes and saves the generated PyG datasets, facilitating easy access for future model training and evaluation.
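
Since `save_database` writes a plain pickle file, a saved dataset can be read back with `pickle.load`. The sketch below illustrates that round trip under stated assumptions: the helper `load_database` is hypothetical (the notebook defines no loader), and a list of dictionaries stands in for the list of `Data` objects so that the example runs without PyTorch.

```python
import os
import pickle
import tempfile

def load_database(name, in_path="datasets"):
    """Hypothetical counterpart to save_database: read '<in_path>/<name>.pkl'."""
    with open(os.path.join(in_path, f"{name}.pkl"), "rb") as f:
        return pickle.load(f)

# Stand-in for a list of PyG Data objects (assumption, keeps the sketch dependency-free)
dataset = [
    {"DEM": [0.0, 1.0], "WD": [[0.1, 0.2]]},
    {"DEM": [2.0, 3.0], "WD": [[0.0, 0.4]]},
]

with tempfile.TemporaryDirectory() as out_path:
    # Write the file with the same '<name>.pkl' layout that save_database uses
    with open(os.path.join(out_path, "grid.pkl"), "wb") as f:
        pickle.dump(dataset, f)
    restored = load_database("grid", in_path=out_path)
```

Pickle files should only be loaded from trusted sources, because unpickling can execute arbitrary code.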



In [None]:
# Set the dataset folder path
dataset_folder = '/home/jupyter/SWE-GNN-paper-repository-/database/raw_datasets'


In [None]:
# Define simulation IDs
simulation_ids = [
    ['grid', 'datasets/train', 1, 80, 64],
    ['grid', 'datasets/test', 500, 20, 64],
    ['random_breach_grid', 'datasets/test', 10001, 20, 64],
    ['big_random_breach_grid', 'datasets/test', 15001, 10, 128],
]


In [None]:
# Create dataset folders
create_dataset_folders(dataset_folder='datasets')

# Debugging the path
print(os.path.abspath(os.path.join(dataset_folder, "DEM", "DEM_1.txt")))

for dataset_name, dataset_dir, start_sim_id, n_sim, n_grids in simulation_ids:
    pyg_dataset = create_grid_dataset(dataset_folder, n_sim=n_sim, start_sim=start_sim_id, number_grids=n_grids)
    save_database(pyg_dataset, name=dataset_name, out_path=dataset_dir)


In [None]:
#This is what a sample of the dataset looks like (pyg_dataset holds the last dataset generated above):
pyg_dataset[0]