# Generate graph data for training and testing the models

Uses the methods of the `DataGeneration` class.

--------------------

### Imports

In [None]:
import os
import joblib
import pickle
from torch.utils.data import ConcatDataset
from src.Data_generation import *

---------------

### Define the parameters for generating the graph datasets


In [None]:
# Type of set to be generated: 'train' or 'test'.
# Note: 'train' should only be used with simulation_1.
set_type = "train"

# Simulation name to be used on the inputs dir (also passed to DataGeneration).
simulation_name = "simulation_1"

# Names of the graph models to be generated.
graph_name_list = [
    "erdos",
    "geometric",
    "watts_strogatz",
    "k_regular",
    "barabasi",
]

# Number of simulations.
n_simulations = 10

# Number of graphs in each set for each simulation.
n_graphs = 40

# Number of nodes in each graph.
n_nodes = 100

#### Generation loop for train or test data

Generates and saves the graph datasets.

In [None]:
# Build the generator once; both the train and the test branch need the same
# setup (constructor followed by generate_siamese_data).
generator = DataGeneration(simulation_name, n_simulations)
generator.generate_siamese_data(n_graphs)

if set_type == 'train':
    # Generate the train dataset - only for Simulation 1 - correlation values
    # in range (-1,1) for training the GNN-Siamese.
    dataset_list1 = []
    dataset_list2 = []
    for graph_name in graph_name_list:
        # Keep the p_norm1 -> p_norm2 call order for every graph type.
        # NOTE(review): the simulation may consume random state, so the order
        # of these two calls is preserved from the original - confirm.
        dataset1 = generator.simulate_graph_and_nodes(generator.p_norm1, graph_name, n_nodes)
        dataset2 = generator.simulate_graph_and_nodes(generator.p_norm2, graph_name, n_nodes)
        dataset_list1.append(dataset1)
        dataset_list2.append(dataset2)

    # Concatenate the datasets of ALL graph types. Passing only the loop
    # variable (the dataset of the last graph type) would silently drop every
    # other graph type from the saved training sets.
    train_dataset1 = ConcatDataset(dataset_list1)
    train_dataset2 = ConcatDataset(dataset_list2)
    # The underlying list of per-graph-type datasets is what gets saved.
    torch.save(train_dataset1.datasets, '/graph_correlation/data/train/train_set1.pth')
    torch.save(train_dataset2.datasets, '/graph_correlation/data/train/train_set2.pth')

    # The scaler is built from both training sets together and saved.
    scaler = generator.generate_scaler(ConcatDataset(dataset_list1 + dataset_list2))
    joblib.dump(scaler, '/graph_correlation/data/scaler.pkl')
else:
    # Generate the test datasets: any set_type other than 'train' lands here.
    # One pair of files is written per graph type.
    for graph_name in graph_name_list:
        test_dataset1 = generator.simulate_graph_and_nodes(generator.p_norm1, graph_name, n_nodes)
        test_dataset2 = generator.simulate_graph_and_nodes(generator.p_norm2, graph_name, n_nodes)
        torch.save(test_dataset1, f'/graph_correlation/data/{simulation_name}/test_set1_{simulation_name}_{graph_name}_{n_nodes}_nodes_{n_graphs}_graphs.pth')
        torch.save(test_dataset2, f'/graph_correlation/data/{simulation_name}/test_set2_{simulation_name}_{graph_name}_{n_nodes}_nodes_{n_graphs}_graphs.pth')