# DISTRIBUTED GRAPHS

In [1]:
import readfof
from pyspark.sql import SparkSession
import numpy as np
import scipy.spatial as SS
import math
import matplotlib.pyplot as plt

### Spark cluster

In [2]:
# Create (or reuse) the Spark session attached to the standalone cluster master
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("CosmoSparkApplication")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/02 09:18:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# SparkContext of the session; used below to build RDDs with sc.parallelize
sc = spark.sparkContext

### Useful functions

In [5]:
class graph:
    """Container for one halo graph together with its simulation metadata.

    Parameters
    ----------
    node_f : array-like
        Per-node feature. `plot3D` treats it as the node masses when sizing markers.
    pos : array-like
        Node positions, indexable as ``pos[i, 0..2]``. `plot3D` multiplies them by
        ``boxsize / 1e3`` before drawing.
    sim_pars : object
        Simulation parameters associated with this graph (stored as given).
    glob_f : object
        Global graph feature (stored as given).
    edge_idx : iterable of (src, dst)
        Edge list; each entry is a pair of node indices.
    edge_f : object
        Per-edge features (stored as given).
    """

    def __init__(self, node_f, pos, sim_pars, glob_f, edge_idx, edge_f):

        self.node_f = node_f
        self.pos = pos
        self.sim_pars = sim_pars
        self.glob_f = glob_f
        self.edge_idx = edge_idx
        self.edge_f = edge_f

        self.boxsize = 1.e6 # Box size in comoving kpc/h

    # Defined at class level so that it is a real method: a `def` nested inside
    # __init__ would only be a local function and could never be called on an instance.
    def plot3D(self, num, pars_file):
        """Draw the graph (edges as lines, nodes as mass-scaled markers) in 3D.

        Parameters
        ----------
        num : object
            Graph identifier interpolated into the figure title.
        pars_file : sequence
            Elements 0 and 1 are converted to float and shown in the title as
            Omega_m and sigma_8 respectively.
        """
        fig = plt.figure(figsize=(12, 12))
        fontsize = 12

        ax = fig.add_subplot(projection ="3d")

        # Work on a scaled copy (show in Mpc). Scaling self.pos in place would
        # compound on every call and corrupt the stored positions.
        pos = np.asarray(self.pos) * (self.boxsize / 1.e3)

        # Draw lines for each edge
        for (src, dst) in self.edge_idx:

            src_xyz = pos[int(src)].tolist()
            dst_xyz = pos[int(dst)].tolist()

            ax.plot([src_xyz[0], dst_xyz[0]], [src_xyz[1], dst_xyz[1]], zs=[src_xyz[2], dst_xyz[2]], linewidth=0.6, color='dimgrey')

        # Plot nodes, marker area proportional to (mass / mean mass)^2
        mass_mean = np.mean(self.node_f)
        for i, m in enumerate(self.node_f):
            ax.scatter(pos[i, 0], pos[i, 1], pos[i, 2], s=50*m*m/(mass_mean**2), zorder=1000, alpha=0.6, color = 'mediumpurple')

        ax.xaxis.set_tick_params(labelsize=fontsize)
        ax.yaxis.set_tick_params(labelsize=fontsize)
        ax.zaxis.set_tick_params(labelsize=fontsize)

        ax.set_xlabel('x (Mpc)', fontsize=16, labelpad=15)
        ax.set_ylabel('y (Mpc)', fontsize=16, labelpad=15)
        ax.set_zlabel('z (Mpc)', fontsize=16, labelpad=15)

        # NOTE(review): the title prints the linking radius as "0.2 Mpc"; confirm the
        # unit against the units in which the edges were actually built.
        rl = '$R_{link} = 0.2$'

        ax.set_title(f'\tGraph n°{num}, Masses $\\geq 99.7$% percentile, {rl} Mpc \t \n \n $\\Omega_m = {float(pars_file[0]):.3f}$ \t $\\sigma_8 = {float(pars_file[1]):.3f}$', fontsize=20)

        # fig.savefig("Plots/Graphs/graph_"+num+"_997.png", bbox_inches='tight', pad_inches=0.6, dpi=400)
        # plt.close(fig)

        plt.show()

In [6]:
# Read data
def read_cosmo_data(file_path, snapnum=2):
    """Read the FoF halo catalogue of one simulation.

    Parameters
    ----------
    file_path : str
        Simulation directory handed to ``readfof.FoF_catalog``.
    snapnum : int, optional
        Snapshot number, indicating the redshift. Defaults to 2 (z=1).

    Returns
    -------
    The catalogue object returned by ``readfof.FoF_catalog``.
    """
    return readfof.FoF_catalog(
        file_path,           # simulation directory
        snapnum,             # snapshot number, indicating the redshift
        long_ids = False,
        swap = False,
        SFR = False,
        read_IDs = False
        )

# Get masses and positions
def get_pos_mass(FoF):
    """Pack halo positions and masses of a catalogue into one 2D array.

    Reads ``FoF.GroupPos`` and ``FoF.GroupMass`` and returns an array with one
    row per halo: the three position columns followed by one mass column.
    """
    halo_pos = FoF.GroupPos / 1e06        # Halo positions in Gpc/h
    halo_mass = FoF.GroupMass * 1e10      # Halo masses in Msun/h

    n_halos = halo_pos.shape[0]
    mass_column = halo_mass.reshape(n_halos, 1)

    return np.concatenate((halo_pos, mass_column), axis=1)

# Mass cut function
def mass_filter(pos_mass_matrix):
    """Keep only the rows whose mass is at or above the 0.997 mass quantile.

    Column 3 of ``pos_mass_matrix`` is the mass and columns 0-2 the position.
    Returns a new array with the same four columns for the surviving rows.
    """
    halo_mass = pos_mass_matrix[:, 3]
    threshold = np.quantile(halo_mass, 0.997)
    keep = halo_mass >= threshold

    kept_pos = pos_mass_matrix[:, 0:3][keep]
    kept_mass = halo_mass[keep].reshape(-1, 1)

    return np.concatenate((kept_pos, kept_mass), axis=1)

# Get KDTree
def get_tree(pos_mass_matrix, leafsize=16, boxsize=1.0001):
    """Build a KD-tree with periodic boundaries on the halo positions.

    Parameters
    ----------
    pos_mass_matrix : ndarray
        2D array whose first three columns are the positions.
    leafsize : int, optional
        Forwarded to ``scipy.spatial.KDTree``. Defaults to 16.
    boxsize : float, optional
        Periodic box size forwarded to ``scipy.spatial.KDTree``. Defaults to 1.0001.

    Returns
    -------
    scipy.spatial.KDTree
    """
    pos = pos_mass_matrix[:, 0:3]
    return SS.KDTree(pos, leafsize=leafsize, boxsize=boxsize)

# Get edge indexes
def get_edges(tree, r_link=0.2):
    """Return the index pairs of all points closer than the linking radius.

    Parameters
    ----------
    tree : scipy.spatial.KDTree
        Tree to query (any object exposing a compatible ``query_pairs``).
    r_link : float, optional
        Linking radius passed as ``r`` to ``query_pairs``. Defaults to 0.2.

    Returns
    -------
    ndarray
        The result of ``tree.query_pairs(..., output_type="ndarray")``.
    """
    return tree.query_pairs(r=r_link, output_type="ndarray")

# Add reverse pairs
def rev_pairs(edge_index_array):
    """Append the reversed (dst, src) copy of every edge below the original list.

    Takes an (N, 2) array of index pairs and returns a (2N, 2) integer array:
    the input rows first, then the same rows with their columns swapped.
    """
    swapped = np.take(edge_index_array, [1, 0], axis=1)
    both_directions = np.concatenate((edge_index_array, swapped), axis=0)
    # cast so the result can be used directly for indexing
    return both_directions.astype(int)

# Get edge features 
def get_edge_feat(joined_tuple, r_link=0.2, boxsize=1.0):
    """Compute the edge features (normalised distance and two cosines).

    Parameters
    ----------
    joined_tuple : tuple
        ``(positions_masses, indexes)``: a 2D array whose first three columns
        are positions, and an (N, 2) integer array of edge index pairs.
    r_link : float, optional
        Linking radius used to normalise the distances. Defaults to 0.2.
    boxsize : float, optional
        Period of the box in the same units as the positions. Defaults to 1.0.

    Returns
    -------
    ndarray of shape (N, 3)
        Columns: distance / r_link, cosine between the two endpoint vectors
        measured from the centroid, cosine between the first endpoint vector
        and the separation vector.

    Notes
    -----
    Edges of zero length, or endpoints lying exactly on the centroid, yield NaN
    because the corresponding vector cannot be normalised.
    """
    pos_mass = joined_tuple[0]
    edg_idx = joined_tuple[1]
    pos = pos_mass[:, 0:3]
    row, col = edg_idx.T

    half_box = 0.5 * boxsize

    def wrap(delta):
        # Minimum-image convention: fold every component into [-boxsize/2, boxsize/2].
        # The fold must happen at half the box; a smaller threshold would shift
        # genuine separations (up to r_link) by a whole box length.
        delta = np.where(delta < -half_box, delta + boxsize, delta)
        return np.where(delta > half_box, delta - boxsize, delta)

    # separation vector of each edge, with periodic boundary conditions
    diff = wrap(pos[row] - pos[col])

    # distance = sqrt(dx^2+dy^2+dz^2)
    dist = np.linalg.norm(diff, axis=1)

    # centroid of halo catalogue (3d position of the centroid)
    centroid = np.mean(pos, axis=0)

    # endpoints measured from the centroid, with periodic boundary conditions
    row_vec = wrap(pos[row] - centroid)
    col_vec = wrap(pos[col] - centroid)

    # normalizing
    unitrow = row_vec / np.linalg.norm(row_vec, axis=1).reshape(-1, 1)
    unitcol = col_vec / np.linalg.norm(col_vec, axis=1).reshape(-1, 1)
    unitdiff = diff / dist.reshape(-1, 1)

    # row-wise dot products give the cosines for all edges at once
    cos1 = np.sum(unitrow * unitcol, axis=1)
    cos2 = np.sum(unitrow * unitdiff, axis=1)

    # Normalize distance by linking radius
    dist = dist / r_link

    # concatenate to get all edge attributes
    return np.concatenate([dist.reshape(-1, 1), cos1.reshape(-1, 1), cos2.reshape(-1, 1)], axis=1)

# Get global features
def get_glob_feat(pos_mass_matrix):
    """Return the global graph feature: log10 of the number of rows (halos)."""
    n_halos = pos_mass_matrix.shape[0]
    return math.log10(n_halos)

# Build graph
def get_graph(joined_tuple):
    """Assemble a ``graph`` object from one keyed record.

    ``joined_tuple`` is ``(key, payload)`` where ``payload`` is
    ``(positions_masses, edge_idx, simulation_parameters, edge_attr, global_features)``.
    The key is ignored; column 3 of ``positions_masses`` becomes the node
    feature and columns 0-2 the node positions.
    """
    payload = joined_tuple[1]
    pos_mass = payload[0]

    node_masses = pos_mass[:, 3]
    node_pos = pos_mass[:, :3]

    return graph(
        node_masses,
        node_pos,
        payload[2],    # simulation parameters
        payload[4],    # global features
        payload[1],    # edge indexes
        payload[3],    # edge attributes
    )

In [6]:
# simulation parameters, loaded as a 2D float table
sim_pars_file = np.loadtxt("/mnt/cosmo_GNN/latin_hypercube_params.txt", dtype=float)

# key every row by its row index so it can be joined with the other keyed RDDs
sim_pars_list = list(enumerate(sim_pars_file))
sim_pars_file_rdd = sc.parallelize(sim_pars_list)

sim_pars_file.shape

(2000, 2)

In [7]:
# number of simulations to use
N_sims = 5

# (key, value) pairs: simulation index -> simulation directory
path_list = [
    (sim_id, "/mnt/cosmo_GNN/Data/" + str(sim_id))
    for sim_id in range(N_sims)
]

# FoF RDD: (key, FoF catalogue read from the directory)
cosmo_rdd = sc.parallelize(path_list).mapValues(read_cosmo_data)

In [8]:
# Check how many partitions the catalogue RDD is split into
cosmo_rdd.getNumPartitions()

2

#### Getting positions and masses from files

In [42]:
# array RDD: (key, 2D array with columns x, y, z, mass — one row per halo)
pos_mass_rdd = cosmo_rdd.mapValues(get_pos_mass)

# filtering: keep only the halos at or above the 0.997 mass quantile of each catalogue
filtered_rdd = pos_mass_rdd.mapValues(mass_filter)

#### Clustering phase

In [None]:
# KDTree RDD: (key, KD-tree built on the filtered halo positions)
kdtree_rdd = filtered_rdd.mapValues(get_tree)

# edge indexes RDD: (key, array of halo index pairs returned by query_pairs with r=0.2)
edge_idx_rdd = kdtree_rdd.mapValues(get_edges)

#### Building graphs

In [None]:
# adding reverse pairs: every (i, j) edge also appears as (j, i)
edge_idx_rdd_r = edge_idx_rdd.mapValues(rev_pairs)

# join filtered arrays RDD and edge indexes RDD by key
# -> (key, (pos_mass, edge_idx))
# NOTE(review): filtered_rdd and joined_rdd each feed more than one downstream RDD
# and are not cached; presumably their lineage is recomputed per use — confirm and
# consider persisting them.
joined_rdd = filtered_rdd.join(edge_idx_rdd_r)

# edge features (dist, cos1, cos2)
# -> (key, edge_attr)
edge_feat_rdd = joined_rdd.mapValues(get_edge_feat)

# global features
# -> (key, log10(number of halos))
glob_feat_rdd = filtered_rdd.mapValues(get_glob_feat)

# join position and masses, indexes, simulation parameters
# join gives (key, ((pos_mass, edge_idx), sim_pars)); the map flattens it to
# (key, (pos_mass, edge_idx, sim_pars))
pos_mass_idx_simpar = joined_rdd.join(sim_pars_file_rdd)\
                                .map(lambda x: (x[0], (x[1][0][0], x[1][0][1], x[1][1])))

# join edge attributes 
# flattened to (key, (pos_mass, edge_idx, sim_pars, edge_attr))
semicomplete_rdd = pos_mass_idx_simpar.join(edge_feat_rdd)\
                                      .map(lambda x: (x[0], (x[1][0][0], x[1][0][1], x[1][0][2], x[1][1])))

# join all
# flattened to (key, (pos_mass, edge_idx, sim_pars, edge_attr, glob_feat)),
# which is the layout get_graph indexes into
complete_rdd = semicomplete_rdd.join(glob_feat_rdd)\
                               .map(lambda x: (x[0], (x[1][0][0], x[1][0][1], x[1][0][2], x[1][0][3], x[1][1])))

# graphs RDD: one graph object per record (the key is dropped by get_graph)
graph_rdd = complete_rdd.map(get_graph)

In [43]:
# Check how many partitions the final graph RDD has after the chain of joins
graph_rdd.getNumPartitions()

12

In [30]:
# Shut down the Spark context and the session
# NOTE(review): spark.stop() presumably also stops the underlying SparkContext,
# which would make the explicit sc.stop() redundant — confirm
sc.stop()
spark.stop()