In [21]:
import readfof
from pyspark.sql import SparkSession
import numpy as np
import scipy.spatial as SS
from scipy.spatial import KDTree

In [22]:
spark = SparkSession.builder \
        .master("spark://master:7077")\
        .appName("CosmoSparkApplication")\
        .getOrCreate()

In [23]:
sc = spark.sparkContext

In [24]:
def read_cosmo_data(file_path):

    # Read Fof
    FoF = readfof.FoF_catalog(
        file_path,           # simulation directory
        2,                   # snapnum, indicating the redshift (z=1)
        long_ids = False,
        swap = False,
        SFR = False,
        read_IDs = False
        )

    pos = FoF.GroupPos/1e06             # Halo positions in Gpc/h 
    mass_raw = FoF.GroupMass * 1e10     # Halo masses in Msun/h

    dim = pos.shape[0]
    id = np.arange(dim, dtype=int).reshape(dim, 1)
    pos_mass_matrix = np.hstack([id, pos, mass_raw.reshape(dim, 1)])

    return pos_mass_matrix

def find_cut(pos_mass_array, percent):

    return np.quantile(pos_mass_array[:, 4], percent)

def mass_filter(pos_mass_cut_rdd):

    mass = pos_mass_cut_rdd[0][:, 4]
    cut = pos_mass_cut_rdd[1]
    
    #if mass >= cut:
        #return pos_mass_cut_rdd[0]
    return pos_mass_cut_rdd[0][mass >= cut, :]

def assign_box(point, boxes):
    id, x, y, z, m = point
    box_assign = []
    
    for box_name, ((x_min, x_max), (y_min, y_max), (z_min, z_max)) in boxes.items():
     if (x_min <= x <= x_max) and (y_min <= y <= y_max) and (z_min <= z <= z_max):
           box_assign.append((box_name, point))
    
    return box_assign

# KD-tree to find edges in boxes
def get_edges(pos_mass_points):
    pos_mass_matrix = np.array(pos_mass_points)
    pos = pos_mass_matrix[:,1:4]
    id = pos_mass_matrix[:,0]

    kd_tree = SS.KDTree(pos, leafsize=16, boxsize=1.00001)
    edge_idx = kd_tree.query_pairs(r=0.2, output_type="ndarray")
    edge_idx = np.array([sorted((id[i], id[j])) for i, j in edge_idx])
    
    return edge_idx


# find unique pairs of edges between 2 boxes (pere uniche)
def unique_pears(mat1, mat2):
    mat = np.vstack((mat1, mat2))
    
    return np.unique(mat, axis=0)


In [25]:
N_sims = 1000

# path list (key, value)
path_list = [(i, "/mnt/cosmo_GNN/Data/"+str(i)) for i in range(N_sims)]

cosmo_rdd = sc.parallelize(path_list)\
            .mapValues(read_cosmo_data)


In [26]:
cut_rdd = cosmo_rdd.mapValues(lambda x: find_cut(x,0.997))

In [27]:
joined_rdd = cosmo_rdd.join(cut_rdd)

In [28]:
pos_mass_filtered = joined_rdd.mapValues(mass_filter)\
                                .filter(lambda x: x is not None)

In [29]:
cosmo_rdd = pos_mass_filtered.flatMapValues(lambda x: x)

In [30]:
min_x, min_y, min_z = 0, 0, 0 #min_coords
max_x, max_y, max_z = 1, 1, 1 #max_coords

r = 0.1  

# Compute the midpoint for every dimension
x_mid = np.mean([min_x, max_x])
y_mid = np.mean([min_y, max_y])
z_mid = np.mean([min_z, max_z])

boxes = {
    "box1": [(min_x    , x_mid + r ), (min_y    , y_mid + r), (min_z    , z_mid + r )],
    "box2": [(x_mid - r, max_x     ), (min_y    , y_mid + r), (min_z    , z_mid + r )],
    "box3": [(min_x    , x_mid + r ), (y_mid - r, max_y    ), (min_z    , z_mid + r )],
    "box4": [(x_mid - r, max_x     ), (y_mid - r, max_y    ), (min_z    , z_mid + r )],
    "box5": [(min_x    , x_mid + r ), (min_y    , y_mid + r), (z_mid - r, max_z    )],
    "box6": [(x_mid - r, max_x     ), (min_y    , y_mid + r), (z_mid - r, max_z    )],
    "box7": [(min_x    , x_mid + r ), (y_mid - r, max_y    ), (z_mid - r, max_z    )],
    "box8": [(x_mid - r, max_x     ), (y_mid - r, max_y    ), (z_mid - r, max_z    )],
}


In [31]:
point_box_rdd = cosmo_rdd.flatMapValues(lambda p: assign_box(p, boxes))

In [32]:
boxes_rdd_sim = point_box_rdd.groupByKey()\
                            .mapValues(lambda x: list(x))
                        #.map(lambda x: (str(x[0]) + list(x[1])[0][0], list(x[1])[0][1]))\
                        #.groupByKey()\
                        

In [33]:
boxes_rdd = point_box_rdd.map(lambda x: (str(x[0]) + '_' + list(x[1])[0], list(x[1])[1]))\
                        .groupByKey()\
                        .mapValues(list)

In [34]:
edges_rdd = boxes_rdd.mapValues(get_edges)

In [35]:
rdd_key_sim = edges_rdd.map(lambda x: (int(x[0][0]), x[1]))\
                     #.reduce(lambda a, b: unique_pears(a[1], b[1]))

In [36]:
final = rdd_key_sim.reduceByKey(lambda a, b: unique_pears(a, b))

In [37]:
final.collect()

                                                                                

[(0,
  array([[  0.,   6.],
         [  0.,  83.],
         [  0., 116.],
         ...,
         [169., 183.],
         [171., 172.],
         [184., 185.]])),
 (1,
  array([[0.000e+00, 1.000e+00],
         [0.000e+00, 2.000e+00],
         [0.000e+00, 3.000e+00],
         ...,
         [1.159e+03, 1.161e+03],
         [1.163e+03, 1.170e+03],
         [1.168e+03, 1.169e+03]])),
 (2,
  array([[0.000e+00, 1.000e+00],
         [0.000e+00, 2.000e+00],
         [0.000e+00, 3.000e+00],
         ...,
         [1.028e+03, 1.037e+03],
         [1.034e+03, 1.048e+03],
         [1.035e+03, 1.037e+03]])),
 (3,
  array([[0.000e+00, 1.000e+00],
         [0.000e+00, 2.000e+00],
         [0.000e+00, 3.000e+00],
         ...,
         [1.143e+03, 1.147e+03],
         [1.147e+03, 1.159e+03],
         [1.149e+03, 1.156e+03]])),
 (4,
  array([[0.000e+00, 1.000e+00],
         [0.000e+00, 2.000e+00],
         [0.000e+00, 3.000e+00],
         ...,
         [1.093e+03, 1.104e+03],
         [1.094e+03, 1.096e+0

In [143]:
#flattened_rdd = boxes_rdd_sim.flatMap(lambda x: [(x[0], box, arr) for box, arr in x[1]])
#
## Step 2: Group by 'box', keeping 'n' as the first element of the tuple
#grouped_by_box_rdd = flattened_rdd.map(lambda x: (x[1], (x[0], x[2])))\
#                                  .groupByKey()\
#                                  .mapValues(lambda x: list(x))
#
## Step 3: Regroup by 'n', so that the output is (n, [(box, [array, array, ...])])
#final_rdd = grouped_by_box_rdd.flatMap(lambda x: [(n, (x[0], arr)) for n, arr in x[1]])\
#                              .groupByKey()\
#                              .mapValues(lambda x: list(x))
#
## Collecting the results to see the output
#flattened_rdd.take(1)

In [20]:
sc.stop()
spark.stop()