In [1]:
import readfof
from pyspark.sql import SparkSession
import numpy as np
import scipy.spatial as SS

### Spark cluster

In [2]:
# Get (or create) the session attached to the standalone Spark master.
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("CosmoSparkApplication")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/02 11:03:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# SparkContext behind the session; used below for the RDD API (sc.parallelize).
sc = spark.sparkContext

### Useful functions

In [4]:
# Read data
def read_cosmo_data(file_path, snapnum=2):
    """Load the FoF halo catalogue of one simulation.

    Parameters
    ----------
    file_path : str
        Simulation directory handed to ``readfof.FoF_catalog``.
    snapnum : int, optional
        Snapshot number, which selects the redshift. Defaults to 2, the
        snapshot this notebook labels as z=1.

    Returns
    -------
    The catalogue object built by ``readfof.FoF_catalog``; ``get_pos_mass``
    reads its ``GroupPos`` and ``GroupMass`` attributes.
    """
    return readfof.FoF_catalog(
        file_path,           # simulation directory
        snapnum,             # snapshot number (redshift selector)
        long_ids=False,
        swap=False,
        SFR=False,
        read_IDs=False,
    )

# Get masses and positions from FoF
def get_pos_mass(FoF):
    """Build one row per halo holding its position followed by its mass.

    Parameters
    ----------
    FoF : object
        Catalogue exposing ``GroupPos`` (one row of coordinates per halo)
        and ``GroupMass`` (one value per halo).

    Returns
    -------
    numpy.ndarray
        ``GroupPos / 1e6`` with ``GroupMass * 1e10`` appended as the last
        column (Gpc/h and Msun/h according to the notebook's unit notes).
    """
    positions = FoF.GroupPos / 1e06     # rescaled positions (Gpc/h)
    masses = FoF.GroupMass * 1e10       # rescaled masses (Msun/h)

    # Turn the masses into a column so they can sit next to the coordinates.
    n_halos = positions.shape[0]
    mass_column = masses.reshape(n_halos, 1)

    return np.hstack((positions, mass_column))

# Mass cut function
def mass_filter(pos_mass_element, cut):
    """Return the halo row if its mass reaches the cut, otherwise ``None``.

    Parameters
    ----------
    pos_mass_element : sequence
        Halo row whose element at index 3 is the mass.
    cut : float
        Minimum mass (inclusive) a halo needs to be kept.

    Returns
    -------
    The same ``pos_mass_element`` object when ``mass >= cut``; ``None`` for
    every other case (including a NaN mass).
    """
    passes_cut = pos_mass_element[3] >= cut
    return pos_mass_element if passes_cut else None

In [5]:
# Simulation parameters, loaded with np.loadtxt as a float array
# (presumably one row per simulation of the Latin hypercube - TODO confirm).

sim_pars_file = np.loadtxt("/mnt/cosmo_GNN/latin_hypercube_params.txt", dtype=float)

In [6]:
# Read the FoF catalogue of one simulation and build the (x, y, z, mass) rows
file_path = "/mnt/cosmo_GNN/Data/" + str(13)
test_FoF = read_cosmo_data(file_path)
pos_mass_array = get_pos_mass(test_FoF)

# Mass cut: threshold at the 0.997 quantile of the halo masses
cut = np.quantile(pos_mass_array[:, 3], 0.997)

# Parallelize and filter. RDD.filter removes the rejected halos; RDD.map would
# instead keep a None in the RDD for every halo that mass_filter rejects.
pos_mass_rdd = sc.parallelize(pos_mass_array)\
                .filter(lambda el: mass_filter(el, cut) is not None)

In [7]:
def find_pairs(array, max_dist=0.2):
    """Find all pairs of rows whose Euclidean distance is within ``max_dist``.

    Parameters
    ----------
    array : numpy.ndarray
        Points to compare, one per row (``array.shape[0]`` points).
    max_dist : float, optional
        Largest separation (inclusive) for two points to count as a pair.
        Defaults to 0.2.

    Returns
    -------
    list of [int, int]
        Pairs ``[i, j]`` of row indices with ``i < j``, ordered by ``i`` and
        then ``j``. A point is never paired with itself.
    """
    n_points = array.shape[0]

    pairs = []
    for i in range(n_points):
        # Start at i + 1 so that j is the real row index of the second point
        # and the (i, i) self-pair is skipped.
        for j in range(i + 1, n_points):
            if np.linalg.norm(array[i] - array[j]) <= max_dist:
                pairs.append([i, j])

    return pairs

In [10]:
def find_pairs_distr(rdd, max_dist=0.2):
    """Distributed search for pairs of points closer than ``max_dist``.

    Parameters
    ----------
    rdd : RDD
        RDD with one point (array of coordinates) per element. The elements
        are subtracted from each other, so none of them may be ``None``.
    max_dist : float, optional
        Largest separation (inclusive) for two points to count as a pair.
        Defaults to 0.2.

    Returns
    -------
    list of (int, int)
        Collected ``(i, j)`` tuples with ``i < j``, where the indices are the
        ones assigned by ``rdd.zipWithIndex()``.
    """
    # Step 1: attach an index to every point and build all ordered pairs
    indexed_rdd = rdd.zipWithIndex()
    cartesian_rdd = indexed_rdd.cartesian(indexed_rdd)

    # Step 2: keep each unordered pair once (i < j) *before* computing any
    # distance, so no norm is evaluated for self-pairs or mirrored pairs
    unique_pairs_rdd = cartesian_rdd.filter(lambda x: x[0][1] < x[1][1])

    # Step 3: keep the pairs whose separation is within the threshold
    close_pairs_rdd = unique_pairs_rdd.filter(
        lambda x: np.linalg.norm(x[0][0] - x[1][0]) <= max_dist
    )

    # Step 4: reduce to the index tuples and bring them back to the driver
    pairs = close_pairs_rdd.map(lambda x: (x[0][1], x[1][1])).collect()

    return pairs

In [17]:
# Preview the first three (row, index) tuples produced by zipWithIndex
pos_mass_rdd.zipWithIndex().take(3)

[(array([7.4910957e-01, 3.9241824e-01, 7.2230619e-01, 1.3495033e+15],
        dtype=float32),
  0),
 (array([5.2160764e-01, 8.4915513e-01, 7.3328722e-01, 1.1827261e+15],
        dtype=float32),
  1),
 (array([3.10260355e-01, 2.15019077e-01, 3.10398489e-02, 1.08891796e+15],
        dtype=float32),
  2)]

In [20]:
# Preview the first three (x, y, z, mass) rows of the RDD
pos_mass_rdd.take(3)

[array([7.4910957e-01, 3.9241824e-01, 7.2230619e-01, 1.3495033e+15],
       dtype=float32),
 array([5.2160764e-01, 8.4915513e-01, 7.3328722e-01, 1.1827261e+15],
       dtype=float32),
 array([3.10260355e-01, 2.15019077e-01, 3.10398489e-02, 1.08891796e+15],
       dtype=float32)]

In [22]:
# Keep only the three coordinates of each halo. mass_filter yields None for a
# rejected halo and None cannot be sliced, so drop any such entry first.
pos_rdd = pos_mass_rdd.filter(lambda el: el is not None)\
                      .map(lambda el: el[:3])
pos_rdd.take(2)

[array([0.74910957, 0.39241824, 0.7223062 ], dtype=float32),
 array([0.52160764, 0.8491551 , 0.7332872 ], dtype=float32)]

In [None]:
def none_filter(el):
    """Predicate for ``RDD.filter``: True for every element except ``None``."""
    return el is not None

In [None]:
# RDD.filter expects a predicate returning a truth value; keep the elements
# that are not None.
pos_rdd_non_none = pos_rdd.filter(lambda x: x is not None)
pos_rdd_non_none.zipWithIndex()

In [None]:
# find_pairs_distr ends with collect(), so despite its name `pairs_rdd` is a
# plain Python list of (i, j) index tuples, not an RDD.
pairs_rdd = find_pairs_distr(pos_rdd)

In [None]:
# find_pairs_distr returns a collected Python list, which has no .take();
# slice it to preview the first three pairs.
pairs_rdd[:3]

In [26]:
# Shut down the SparkContext and the SparkSession at the end of the notebook
sc.stop()
spark.stop()