In [313]:
import readfof
from pyspark.sql import SparkSession
import numpy as np
import scipy.spatial as SS
from scipy.spatial import KDTree

In [314]:
spark = SparkSession.builder \
        .master("spark://master:7077")\
        .appName("CosmoSparkApplication")\
        .getOrCreate()

In [315]:
sc = spark.sparkContext

In [316]:
def read_cosmo_data(file_path):

    # Read Fof
    FoF = readfof.FoF_catalog(
        file_path,           # simulation directory
        2,                   # snapnum, indicating the redshift (z=1)
        long_ids = False,
        swap = False,
        SFR = False,
        read_IDs = False
        )

    pos = FoF.GroupPos/1e06             # Halo positions in Gpc/h 
    mass_raw = FoF.GroupMass * 1e10     # Halo masses in Msun/h

    dim = pos.shape[0]
    id = np.arange(dim, dtype=int).reshape(dim, 1)
    pos_mass_matrix = np.hstack([id, pos, mass_raw.reshape(dim, 1)])

    return pos_mass_matrix

def find_cut(pos_mass_array, percent):

    return np.quantile(pos_mass_array[:, 4], percent)

def mass_filter(pos_mass_cut_rdd):

    mass = pos_mass_cut_rdd[0][:, 4]
    cut = pos_mass_cut_rdd[1]
    
    #if mass >= cut:
        #return pos_mass_cut_rdd[0]
    return pos_mass_cut_rdd[0][mass >= cut, :]

def assign_box(point, boxes):
    id, x, y, z, m = point
    box_assign = []
    
    for box_name, ((x_min, x_max), (y_min, y_max), (z_min, z_max)) in boxes.items():
     if (x_min <= x <= x_max) and (y_min <= y <= y_max) and (z_min <= z <= z_max):
           box_assign.append((box_name, point))
    
    return box_assign

In [317]:
N_sims = 10

# path list (key, value)
path_list = [(i, "/mnt/cosmo_GNN/Data/"+str(i)) for i in range(N_sims)]

cosmo_rdd = sc.parallelize(path_list)\
            .mapValues(read_cosmo_data)


In [318]:
cut_rdd = cosmo_rdd.mapValues(lambda x: find_cut(x,0.997))

In [319]:
joined_rdd = cosmo_rdd.join(cut_rdd)

In [320]:
pos_mass_filtered = joined_rdd.mapValues(mass_filter)\
                                .filter(lambda x: x is not None)

In [321]:
cosmo_rdd = pos_mass_filtered.flatMapValues(lambda x: x)

In [322]:
cosmo_rdd.take(3)[0:2]

                                                                                

[(8,
  array([0.00000000e+00, 7.75257885e-01, 1.47849455e-01, 4.60891694e-01,
         7.21348146e+14])),
 (8,
  array([1.00000000e+00, 6.15689993e-01, 3.98693532e-01, 7.92618036e-01,
         5.53142664e+14]))]

In [323]:
min_x, min_y, min_z = 0, 0, 0 #min_coords
max_x, max_y, max_z = 1, 1, 1 #max_coords

r = 0.1  

# Compute the midpoint for every dimension
x_mid = np.mean([min_x, max_x])
y_mid = np.mean([min_y, max_y])
z_mid = np.mean([min_z, max_z])

boxes = {
    "box1": [(min_x    , x_mid + r ), (min_y    , y_mid + r), (min_z    , z_mid + r )],
    "box2": [(x_mid - r, max_x     ), (min_y    , y_mid + r), (min_z    , z_mid + r )],
    "box3": [(min_x    , x_mid + r ), (y_mid - r, max_y    ), (min_z    , z_mid + r )],
    "box4": [(x_mid - r, max_x     ), (y_mid - r, max_y    ), (min_z    , z_mid + r )],
    "box5": [(min_x    , x_mid + r ), (min_y    , y_mid + r), (z_mid - r, max_z    )],
    "box6": [(x_mid - r, max_x     ), (min_y    , y_mid + r), (z_mid - r, max_z    )],
    "box7": [(min_x    , x_mid + r ), (y_mid - r, max_y    ), (z_mid - r, max_z    )],
    "box8": [(x_mid - r, max_x     ), (y_mid - r, max_y    ), (z_mid - r, max_z    )],
}


In [324]:
point_box_rdd = cosmo_rdd.flatMapValues(lambda p: assign_box(p, boxes))

In [325]:
point_box_rdd.take(2)

[(0,
  ('box2',
   array([0.00000000e+00, 6.76912606e-01, 6.60486296e-02, 2.68535763e-01,
          6.35113457e+14]))),
 (0,
  ('box3',
   array([1.00000000e+00, 1.68456003e-01, 6.21953666e-01, 3.80254686e-01,
          4.60181352e+14])))]

In [346]:
point_box_rdd.values().take(3)
            #.keys()\
            #.take(3)

[('box2',
  array([0.00000000e+00, 7.75257885e-01, 1.47849455e-01, 4.60891694e-01,
         7.21348146e+14])),
 ('box6',
  array([0.00000000e+00, 7.75257885e-01, 1.47849455e-01, 4.60891694e-01,
         7.21348146e+14])),
 ('box6',
  array([1.00000000e+00, 6.15689993e-01, 3.98693532e-01, 7.92618036e-01,
         5.53142664e+14]))]

In [402]:
boxes_rdd_sim = point_box_rdd.groupByKey()\
                        .mapValues(lambda x: list(x))\
                        #.groupByKey()\
                        #.mapValues(lambda x: list(x))

In [430]:
boxes_rdd_sim.take(10)

[(0,
  [('box2',
    array([0.00000000e+00, 6.76912606e-01, 6.60486296e-02, 2.68535763e-01,
           6.35113457e+14])),
   ('box3',
    array([1.00000000e+00, 1.68456003e-01, 6.21953666e-01, 3.80254686e-01,
           4.60181352e+14])),
   ('box1',
    array([2.00000000e+00, 3.19514722e-01, 1.58893153e-01, 3.74134451e-01,
           4.34773265e+14])),
   ('box3',
    array([3.00000000e+00, 5.46004772e-01, 8.60320926e-01, 3.01911861e-01,
           4.13726381e+14])),
   ('box4',
    array([3.00000000e+00, 5.46004772e-01, 8.60320926e-01, 3.01911861e-01,
           4.13726381e+14])),
   ('box7',
    array([4.00000000e+00, 2.22168431e-01, 6.81388080e-01, 7.62438238e-01,
           4.13000498e+14])),
   ('box6',
    array([5.00000000e+00, 8.63869369e-01, 5.29710343e-03, 7.87207365e-01,
           4.05378978e+14])),
   ('box2',
    array([6.00000000e+00, 6.72135472e-01, 7.00396970e-02, 2.54645526e-01,
           3.66908453e+14])),
   ('box5',
    array([7.00000000e+00, 2.49624431e-01, 5.07

In [445]:
boxes_rdd = boxes_rdd_sim.values()\
                        .map(lambda x: x[0])\
                        .groupByKey()\
                        .mapValues(lambda x: list(x))

In [446]:
boxes_rdd.take(10)

[('box2',
  [array([0.00000000e+00, 6.76912606e-01, 6.60486296e-02, 2.68535763e-01,
          6.35113457e+14]),
   array([0.00000000e+00, 7.75257885e-01, 1.47849455e-01, 4.60891694e-01,
          7.21348146e+14])]),
 ('box3',
  [array([0.00000000e+00, 1.75386652e-01, 6.42486632e-01, 3.81892830e-01,
          8.81280917e+14]),
   array([0.00000000e+00, 8.89343396e-02, 9.59361136e-01, 4.08325493e-01,
          8.60586087e+14])]),
 ('box4',
  [array([0.00000000e+00, 6.09665453e-01, 6.43911779e-01, 5.03495038e-01,
          1.96273536e+15]),
   array([0.00000000e+00, 8.81940246e-01, 7.82104313e-01, 1.62193537e-01,
          5.66739188e+14])]),
 ('box7',
  [array([0.00000000e+00, 2.22270206e-01, 6.82273269e-01, 7.60720909e-01,
          1.13813613e+15]),
   array([0.00000000e+00, 1.47837684e-01, 9.53166068e-01, 7.76914358e-01,
          1.91964301e+15])]),
 ('box1',
  [array([0.00000000e+00, 1.57785699e-01, 3.36683959e-01, 3.64182174e-01,
          6.64899525e+14])]),
 ('box8',
  [array([0.

In [310]:
sc.stop()
spark.stop()