# DISTRIBUTED GRAPHS

In [99]:
import readfof
from pyspark.sql import SparkSession
import numpy as np
import scipy.spatial as SS

In [2]:
# Connect to the standalone Spark master and create (or reuse) the session.
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("CosmoSparkApplication")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/08/29 08:31:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# SparkContext of the session; the RDD cells below go through it (sc.parallelize).
sc = spark.sparkContext

In [4]:
# Read data
def read_cosmo_data(file_path, snapnum=2):
    """Load the Friends-of-Friends (FoF) halo catalogue of one simulation.

    Parameters
    ----------
    file_path : str
        Simulation directory handed to ``readfof.FoF_catalog``.
    snapnum : int, optional
        Snapshot number, which selects the redshift. Defaults to 2, the
        value this notebook uses for z=1.

    Returns
    -------
    FoF catalogue object
        Whatever ``readfof.FoF_catalog`` returns; the rest of the notebook
        reads its ``GroupPos`` and ``GroupMass`` attributes.
    """
    return readfof.FoF_catalog(
        file_path,
        snapnum,
        long_ids=False,
        swap=False,
        SFR=False,
        read_IDs=False,
    )

In [18]:
# get masses and positions

def get_pos_mass(FoF):
    """Pack the halo positions and masses of a FoF catalogue into one matrix.

    Columns 0-2 hold ``FoF.GroupPos / 1e06`` (halo positions in Gpc/h) and
    column 3 holds ``FoF.GroupMass * 1e10`` (halo masses in Msun/h), with one
    row per halo.
    """
    coords = FoF.GroupPos / 1e06
    # Turn the 1-D mass vector into a single column so it can be placed
    # next to the three coordinate columns.
    mass_column = (FoF.GroupMass * 1e10).reshape(coords.shape[0], 1)
    return np.hstack([coords, mass_column])

In [57]:
# Simulation parameters, loaded from the Latin-hypercube parameter file.
# NOTE(review): presumably one row per simulation; the array is not used by any later cell visible here.

sim_pars_file = np.loadtxt("/mnt/cosmo_GNN/latin_hypercube_params.txt", dtype=float)

sim_pars_file.shape

(2000, 2)

In [68]:
# How many simulations to process.
N_sims = 50
# Integer keys 0 .. N_sims-1, one per simulation.
key_list = np.arange(N_sims)
# (key, path) pairs: simulation i lives in the directory named after its index.
path_list = [
    (sim_id, "/mnt/cosmo_GNN/Data/" + str(sim_id))
    for sim_id in range(N_sims)
]

In [69]:
# Load the first simulation's catalogue on the driver as a local sanity check.
test_FoF = read_cosmo_data(path_list[0][1])

In [70]:
# Same scalings as get_pos_mass: positions divided by 1e06, masses multiplied by 1e10.
test_pos = test_FoF.GroupPos/1e06
test_masses = test_FoF.GroupMass*1e10

In [71]:
# Positions are 2-D: one row per halo, three coordinates each.
test_pos.shape

(62392, 3)

In [72]:
# Masses are 1-D: one value per halo, so they need reshaping before stacking.
test_masses.shape

(62392,)

In [73]:
# Check that positions and masses stack into an (n_halos, 4) matrix.
# reshape(-1, 1) infers the row count, so the cell keeps working when a
# catalogue with a different number of halos is loaded.
np.hstack([test_pos, test_masses.reshape(-1, 1)]).shape

(62392, 4)

In [74]:
# RDD of (simulation key, FoF catalogue) pairs built from the (key, path) list.
cosmo_rdd = (
    sc.parallelize(path_list)
    .mapValues(read_cosmo_data)
)

In [75]:
# Count the (key, catalogue) pairs; should equal N_sims.
cosmo_rdd.count()

50

In [76]:
# Number of partitions the RDD is split into.
cosmo_rdd.getNumPartitions()

16

In [79]:
# (key, catalogue) -> (key, matrix with position columns 0-2 and mass column 3).
pos_mass_rdd = cosmo_rdd.mapValues(get_pos_mass)

In [80]:
# Pull the first two (key, matrix) pairs to the driver for inspection.
cosa = pos_mass_rdd.take(2)

In [81]:
# Number of pairs returned by take(2).
len(cosa)

2

In [82]:
# Container type returned by take().
type(cosa)

list

In [89]:
# Shape of the second sample's matrix: cosa[1] is a (key, matrix) pair, [1] picks the matrix.
cosa[1][1].shape

(212944, 4)

In [91]:
# Split the second sampled matrix into coordinates (columns 0-2) and masses (column 3).
cosa_pos, cosa_mass = cosa[1][1][:, :3], cosa[1][1][:, 3]

# Mask the halos whose mass is at or above the 0.997 mass quantile.
cosa_cut = np.quantile(cosa_mass, q=0.997)
cosa_mask = cosa_mass >= cosa_cut

In [92]:
# Positions of the halos that pass the mass cut.
cosa_pos[cosa_mask].shape

(641, 3)

In [93]:
# Masses of the halos that pass the mass cut; the row count must match the positions above.
cosa_mass[cosa_mask].shape

(641,)

In [94]:
# Re-assemble the surviving halos into one matrix: positions plus a mass column.
cosa_dim = cosa_mass[cosa_mask].shape[0]
cosa_filtered = np.hstack(
    [cosa_pos[cosa_mask], cosa_mass[cosa_mask].reshape(cosa_dim, 1)]
)
cosa_filtered.shape

(641, 4)

In [95]:
def mass_filter(pos_mass_matrix, quantile=0.997):
    """Keep only the most massive halos of a position/mass matrix.

    Parameters
    ----------
    pos_mass_matrix : numpy.ndarray
        2-D array whose columns 0-2 are positions and whose column 3 is mass.
    quantile : float, optional
        Mass quantile used as the cut. Rows whose mass is greater than or
        equal to this quantile are kept. Defaults to 0.997.

    Returns
    -------
    numpy.ndarray
        New array holding the selected rows (first four columns), in their
        original order.
    """
    # An empty catalogue has no meaningful quantile; hand back an empty
    # matrix instead of passing a zero-length array to np.quantile.
    if pos_mass_matrix.shape[0] == 0:
        return pos_mass_matrix[:, :4].copy()

    mass = pos_mass_matrix[:, 3]
    keep = mass >= np.quantile(mass, quantile)

    # Select the rows directly; this yields the same matrix as splitting into
    # positions and masses, masking each, and stacking them back together.
    return pos_mass_matrix[keep, :4]


In [96]:
# Apply the mass cut to every simulation's matrix.
filtered_rdd = pos_mass_rdd.mapValues(mass_filter)

In [98]:
# Shape of the second simulation's filtered matrix; compare with the local cosa_filtered check.
filtered_rdd.take(2)[1][1].shape

(641, 4)

In [100]:
# get KDTree
def get_tree(pos, leafsize=16, boxsize=1.0001):
    """Build a periodic k-d tree over halo positions.

    Parameters
    ----------
    pos : numpy.ndarray
        Array of point coordinates, one row per halo.
    leafsize : int, optional
        Passed to ``scipy.spatial.KDTree``. Defaults to 16.
    boxsize : float or None, optional
        Period of the periodic box, passed to ``scipy.spatial.KDTree``.
        Defaults to 1.0001.
        NOTE(review): presumably chosen slightly above a unit box so that
        coordinates on the box edge stay inside the valid range; confirm.
        ``None`` disables periodic wrapping.

    Returns
    -------
    scipy.spatial.KDTree
        The tree built over ``pos``.
    """
    return SS.KDTree(pos, leafsize=leafsize, boxsize=boxsize)

In [101]:
# Build one tree per simulation from the position columns only (column 3, the mass, is left out).
kdtree_rdd = filtered_rdd.mapValues(lambda el: get_tree(el[:,0:3]))

In [102]:
# Fetch the first two (key, tree) pairs to confirm the trees can be built and returned to the driver.
kdtree_rdd.take(2)

                                                                                

[(0, <scipy.spatial._kdtree.KDTree at 0x7fcb0511bed0>),
 (1, <scipy.spatial._kdtree.KDTree at 0x7fcb049d24d0>)]

In [103]:
# get edge indexes

def get_edges(tree, radius=0.2):
    """Find all pairs of points in a k-d tree that lie within a linking radius.

    Parameters
    ----------
    tree : scipy.spatial.KDTree
        Tree to query, e.g. one built by ``get_tree``.
    radius : float, optional
        Maximum distance between two points for them to be paired, in the
        same units as the tree's coordinates. Defaults to 0.2.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_pairs, 2) holding the index pairs returned by
        ``tree.query_pairs``. Each pair appears once, so the edges are
        undirected and there are no self-loops.
    """
    return tree.query_pairs(r=radius, output_type="ndarray")

In [104]:
# (key, tree) -> (key, array of index pairs within the linking radius).
edge_idx_rdd = kdtree_rdd.mapValues(get_edges)

In [105]:
# Inspect the edge lists of the first two simulations.
edge_idx_rdd.take(2)

[(0,
  array([[ 30, 176],
         [ 37, 176],
         [ 24, 176],
         ...,
         [ 90, 161],
         [ 90, 145],
         [145, 161]])),
 (1,
  array([[138, 283],
         [138, 562],
         [138, 431],
         ...,
         [476, 563],
         [563, 639],
         [476, 639]]))]

In [49]:
# Shut down the Spark context and session once the analysis is finished.
sc.stop()
spark.stop()