In [21]:
import readfof
from pyspark.sql import SparkSession
import numpy as np
import scipy.spatial as SS
from scipy.spatial import KDTree

In [22]:
# Create (or fetch, via getOrCreate) the Spark session on the standalone
# cluster whose master runs at spark://master:7077.
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("CosmoSparkApplication")
    .getOrCreate()
)

In [23]:
# Low-level SparkContext handle; used below for sc.parallelize(...) and sc.stop().
sc = spark.sparkContext

### Useful functions

In [24]:
def read_cosmo_data(file_path, snapnum=2):
    """Load the Friends-of-Friends (FoF) halo catalogue of one simulation.

    Parameters
    ----------
    file_path : str
        Simulation directory passed to ``readfof.FoF_catalog``.
    snapnum : int, optional
        Snapshot number to read. Defaults to 2, which this notebook
        annotates as redshift z=1.

    Returns
    -------
    The catalogue object returned by ``readfof.FoF_catalog``. The rest of
    the notebook reads its ``GroupPos`` and ``GroupMass`` attributes.
    """
    return readfof.FoF_catalog(
        file_path,           # simulation directory
        snapnum,             # snapshot number (2 -> z=1 per the original note)
        long_ids=False,
        swap=False,
        SFR=False,
        read_IDs=False,
    )

# Build the per-halo [id, x, y, z, mass] matrix from an FoF catalogue
def get_pos_mass(FoF):
    """Flatten an FoF catalogue into a ``(n_halos, 5)`` matrix.

    Columns are ``[id, x, y, z, mass]``:

    * ``id`` is the row index ``0..n_halos-1``;
    * ``x, y, z`` are ``FoF.GroupPos / 1e6`` (labelled Gpc/h here);
    * ``mass`` is ``FoF.GroupMass * 1e10`` (labelled Msun/h here).
    """
    n_halos = FoF.GroupPos.shape[0]
    positions = FoF.GroupPos / 1e06
    masses = FoF.GroupMass * 1e10
    row_ids = np.arange(n_halos, dtype=int)[:, np.newaxis]
    return np.concatenate((row_ids, positions, masses.reshape(n_halos, 1)), axis=1)


# Mass cut: keep a halo row only if it is heavy enough
def mass_filter(pos_mass_rdd, cut):
    """Return the halo row unchanged if its mass reaches ``cut``, else None.

    ``pos_mass_rdd`` is a single ``[id, x, y, z, mass]`` row (not an RDD);
    the mass is read from index 4. Rows below the cut map to None so that a
    following ``.filter(lambda x: x is not None)`` can drop them.
    """
    passes_cut = pos_mass_rdd[4] >= cut
    return pos_mass_rdd if passes_cut else None
    

# Assign each point to every box that contains it
def assign_box(point, boxes):
    """Tag a halo row with every box whose bounds contain its position.

    Parameters
    ----------
    point : sequence of length 5
        ``[id, x, y, z, mass]``; only x, y, z take part in the test.
    boxes : dict
        Maps box name -> ``[(x_min, x_max), (y_min, y_max), (z_min, z_max)]``.
        Bounds are inclusive on both ends.

    Returns
    -------
    list of (box_name, point)
        One entry per matching box, in ``boxes`` iteration order. With
        overlapping boxes a point may appear under several names; an empty
        list means no box contains it.
    """
    _halo_id, x, y, z, _mass = point
    return [
        (name, point)
        for name, ((x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi)) in boxes.items()
        if x_lo <= x <= x_hi and y_lo <= y <= y_hi and z_lo <= z <= z_hi
    ]


# KD-tree to find edges in boxes
def get_edges(pos_mass_points, r=0.2, boxsize=1.00001, leafsize=16):
    """Find all halo pairs closer than ``r`` inside one box.

    Parameters
    ----------
    pos_mass_points : sequence of array-like
        Rows of the form ``[id, x, y, z, mass]`` (the rows grouped per box).
    r : float, optional
        Linking radius passed to ``KDTree.query_pairs`` (default 0.2).
    boxsize : float, optional
        Periodic box size given to the KD-tree (default 1.00001; presumably
        chosen just above 1 so coordinates at the unit-box edge stay inside
        the periodic domain).
    leafsize : int, optional
        KD-tree leaf size (default 16).

    Returns
    -------
    numpy.ndarray
        Shape ``(n_edges, 2)``. Each row holds the two halo ids of an edge,
        smaller id first. The result is always 2-D, even when no edges are
        found, so it can be stacked with other boxes' edge arrays.

    NOTE(review): with the eight overlapping half-boxes built in this
    notebook (each at most 0.6 wide per axis), no box holds points on both
    sides of the periodic boundary. Wrap-around pairs are therefore never
    found here even though the tree is periodic. Confirm this is intended.
    """
    pos_mass_matrix = np.array(pos_mass_points)
    if pos_mass_matrix.size == 0:
        # No halos in this box -> no edges; keep the (0, 2) shape.
        return np.empty((0, 2))

    pos = pos_mass_matrix[:, 1:4]
    ids = pos_mass_matrix[:, 0]

    kd_tree = SS.KDTree(pos, leafsize=leafsize, boxsize=boxsize)
    pair_idx = np.asarray(
        kd_tree.query_pairs(r=r, output_type="ndarray"), dtype=np.intp
    ).reshape(-1, 2)

    # Map tree indices to halo ids and order each pair (smaller id first).
    # Vectorised, and unlike building np.array from a list of sorted pairs,
    # it keeps the (0, 2) shape when no pair is found. Without that, a
    # (0,)-shaped result breaks np.vstack when boxes are merged.
    return np.sort(ids[pair_idx], axis=1)


# Merge the edge arrays of two boxes and drop duplicated edges
def unique_pears(mat1, mat2):
    """Return the unique rows of two stacked ``(n, 2)`` edge arrays.

    Edges lying in the overlap between boxes are found by more than one
    box. Merging with ``np.unique(..., axis=0)`` removes those duplicates and
    returns the rows in lexicographically sorted order. (The name is a typo
    for "pairs", kept so existing callers still work.)

    Inputs with no elements, whatever their shape, are skipped, so a box
    without edges cannot make ``np.vstack`` fail. If both inputs are empty,
    an empty ``(0, 2)`` array is returned.
    """
    non_empty = [np.asarray(mat) for mat in (mat1, mat2) if np.size(mat) > 0]
    if not non_empty:
        return np.empty((0, 2))
    return np.unique(np.vstack(non_empty), axis=0)

In [25]:
sim_pars_file = np.loadtxt("/mnt/cosmo_GNN/latin_hypercube_params.txt", dtype=float)

# Load the FoF catalogue of simulation `sim_index` and flatten it into
# [id, x, y, z, mass] rows.
sim_index = 0
file_path = "/mnt/cosmo_GNN/Data/" + str(sim_index)
test_FoF = read_cosmo_data(file_path)
pos_mass_array = get_pos_mass(test_FoF)

# Mass cut: keep halos at or above the 99.7th percentile of mass (column 4).
mass_quantile = 0.997
cut = np.quantile(pos_mass_array[:, 4], mass_quantile)

# Distribute the rows, then drop the ones mass_filter rejects (it maps them to None).
pos_mass_rdd = sc.parallelize(pos_mass_array)
pos_mass_filtered = (
    pos_mass_rdd
    .map(lambda row: mass_filter(row, cut))
    .filter(lambda row: row is not None)
)

In [26]:
# Peek at the first two halos that passed the mass cut; each row is [id, x, y, z, mass].
pos_mass_filtered.take(2)

                                                                                

[array([0.00000000e+00, 6.76912606e-01, 6.60486296e-02, 2.68535763e-01,
        6.35113457e+14]),
 array([1.00000000e+00, 1.68456003e-01, 6.21953666e-01, 3.80254686e-01,
        4.60181352e+14])]

In [27]:
# Bounds of the unit simulation box (positions are in Gpc/h). No
# calculate_bounds() definition is visible in this notebook, so the
# bounds are fixed rather than computed from the data.
min_x, min_y, min_z = 0, 0, 0
max_x, max_y, max_z = 1, 1, 1

# Half-width of the overlap strip around each midplane.
r = 0.1

# Midpoint of every axis.
x_mid = np.mean([min_x, max_x])
y_mid = np.mean([min_y, max_y])
z_mid = np.mean([min_z, max_z])

# Split each axis into a lower and an upper half that overlap by 2*r around
# the midpoint. With r = 0.1 that overlap (0.2) matches the query radius
# used in get_edges. Any pair closer than that radius (ignoring periodic
# wrap-around) therefore lies inside at least one common box.
x_halves = [(min_x, x_mid + r), (x_mid - r, max_x)]
y_halves = [(min_y, y_mid + r), (y_mid - r, max_y)]
z_halves = [(min_z, z_mid + r), (z_mid - r, max_z)]

# Eight octant boxes numbered with x varying fastest, then y, then z
# (box1 = all lower halves, box8 = all upper halves).
octants = [(xb, yb, zb) for zb in z_halves for yb in y_halves for xb in x_halves]
boxes = {f"box{n}": [xb, yb, zb] for n, (xb, yb, zb) in enumerate(octants, start=1)}


In [28]:
# Route each halo row to every box containing its position. Boxes overlap
# around the midplanes, so a row can appear under several box names.
point_box_rdd = pos_mass_filtered.flatMap(lambda p: assign_box(p, boxes))

In [38]:
# Peek at the (box_name, row) pairs produced by assign_box.
# NOTE(review): the saved AttributeError output below was produced by running
# this cell as In [38], i.e. after In [37] stopped the SparkContext. It should
# presumably succeed when the notebook is run top to bottom.
point_box_rdd.take(3)

AttributeError: 'NoneType' object has no attribute 'sc'

In [30]:
# Gather the rows of each box into a Python list: (box_name, [row, row, ...]).
boxes_rdd = point_box_rdd.groupByKey().mapValues(list)


In [31]:
# Inspect up to three boxes together with their member rows.
boxes_rdd.take(3)

                                                                                

[('box1',
  [array([2.00000000e+00, 3.19514722e-01, 1.58893153e-01, 3.74134451e-01,
          4.34773265e+14]),
   array([1.70000000e+01, 3.14249396e-01, 1.65724963e-01, 3.81010741e-01,
          2.60210845e+14]),
   array([1.80000000e+01, 2.65082687e-01, 1.89858750e-01, 1.25514776e-01,
          2.56582219e+14]),
   array([2.30000000e+01, 2.94121563e-01, 2.84590393e-01, 3.06423098e-01,
          2.48234681e+14]),
   array([2.40000000e+01, 2.32297137e-01, 3.94813828e-02, 6.40881807e-02,
          2.44968493e+14]),
   array([3.00000000e+01, 2.31418371e-01, 3.64492759e-02, 6.89148679e-02,
          2.18838969e+14]),
   array([3.70000000e+01, 2.36052260e-01, 3.40070836e-02, 6.04751781e-02,
          2.06863828e+14]),
   array([3.90000000e+01, 4.26883042e-01, 5.51962435e-01, 1.11783393e-01,
          2.06500014e+14]),
   array([4.80000000e+01, 3.38469237e-01, 3.45230460e-01, 7.31916800e-02,
          1.91620704e+14]),
   array([5.80000000e+01, 6.53460622e-02, 2.38424584e-01, 1.08026162e-01

In [32]:
# Run the KD-tree edge search separately on each box's list of rows;
# each value becomes an array of [id_a, id_b] halo-id pairs.
edges_rdd = boxes_rdd.mapValues(get_edges)

In [34]:
# Inspect the edge array of one box (halo-id pairs, smaller id first).
edges_rdd.take(1)

[('box1',
  array([[ 58., 128.],
         [ 24., 128.],
         [ 30., 128.],
         [ 37., 128.],
         [ 73., 105.],
         [ 58.,  65.],
         [ 24.,  30.],
         [ 24.,  37.],
         [ 30.,  37.],
         [105., 110.],
         [128., 176.],
         [ 18., 128.],
         [ 68., 128.],
         [ 24., 176.],
         [ 18.,  24.],
         [ 24.,  68.],
         [ 30., 176.],
         [ 18.,  30.],
         [ 30.,  68.],
         [ 37., 176.],
         [ 18.,  37.],
         [ 37.,  68.],
         [ 17., 105.],
         [  2., 105.],
         [ 89., 105.],
         [ 18., 176.],
         [ 68., 176.],
         [ 18.,  68.],
         [  2.,  17.],
         [ 17.,  89.],
         [ 17.,  88.],
         [ 17.,  92.],
         [  2.,  89.],
         [  2.,  88.],
         [  2.,  92.],
         [ 88.,  89.],
         [ 89.,  92.],
         [ 88.,  92.],
         [128., 170.],
         [128., 156.],
         [ 73.,  98.],
         [ 73., 158.],
         [ 58.,  93.],
 

In [35]:
# Drop the box names and merge all per-box edge arrays pairwise into one
# duplicate-free array, returned to the driver. Despite its name,
# rdd_no_key is a NumPy array, not an RDD.
rdd_no_key = edges_rdd.values().reduce(unique_pears)
                            

In [36]:
# Display the merged, de-duplicated edge list (one [id_a, id_b] row per edge).
rdd_no_key

array([[  0.,   6.],
       [  0.,  83.],
       [  0., 116.],
       ...,
       [169., 183.],
       [171., 172.],
       [184., 185.]])

In [37]:
# Release the cluster resources at the end of the notebook.
# NOTE(review): SparkSession.stop() is documented to stop the underlying
# SparkContext as well, so the explicit sc.stop() is probably redundant.
sc.stop()
spark.stop()