# In [3]:
import sys as sys
import unyt
import h5py
import logging
import numpy as np
logging.basicConfig(level = logging.INFO)
sys.path.append('/home/jovyan/Analysis/transfer')
sys.path.append('/home/jovyan/Analysis/Pylians3/library/build/lib.linux-x86_64-3.7')
import readgadget
import time
from unyt import unyt_array, unyt_quantity
from transfer import LOGGER
from scipy.spatial import cKDTree
from math import sqrt
import pandas as pd
from numba import njit, int32, int64, float64
from numba.experimental import jitclass
from numpy import zeros, array, full, int32, empty, arange
from typing import Union, Tuple, Dict, Optional

# Simulation names iterated by the driver loop: parameters p1..p6, each at the
# five variations "n2", "n1", "0", "1", "2" (parameter-major order, 30 names).
list_of_strings2 = [
    f"1P_p{parameter}_{variation}"
    for parameter in range(1, 7)
    for variation in ("n2", "n1", "0", "1", "2")
]

# Nearest-neighbour lookup on the initial particle distribution.
def find_closest_neighbours(
    dark_matter_coordinates,
    dark_matter_ids,
    boxsize,
    gas_coordinates,
    gas_ids,
):
    """Find each particle's nearest dark-matter neighbour in a periodic box.

    Parameters
    ----------
    dark_matter_coordinates : array-like, shape (N_dm, 3)
        Dark-matter positions used to build the periodic KD-tree.
    dark_matter_ids : numpy.ndarray
        IDs aligned with ``dark_matter_coordinates``.
    boxsize : float
        Periodic box length passed to ``cKDTree``.
    gas_coordinates : array-like, shape (N_gas, 3), or None
        Gas positions to match against the DM tree; ``None`` skips gas.
    gas_ids : array-like
        Accepted for interface symmetry; not used.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray or None)
        ID of every DM particle's nearest *other* DM particle, and the ID of
        every gas particle's nearest DM particle (``None`` without gas).
    """
    LOGGER.info("Building dark matter tree for spread metric")
    dm_tree = cKDTree(dark_matter_coordinates, boxsize=boxsize)
    LOGGER.info("Finished tree build")

    # Each DM particle is its own nearest hit, so ask for two neighbours and
    # keep the second column.
    _, dm_hits = dm_tree.query(x=dark_matter_coordinates, k=2, workers=-1)
    dm_result = np.array(dark_matter_ids[dm_hits[:, 1]])

    if gas_coordinates is None:
        return dm_result, None

    # Gas particles are not in the tree, so the single nearest hit is wanted.
    _, gas_hits = dm_tree.query(x=gas_coordinates, k=1, workers=-1)
    return dm_result, np.array(dark_matter_ids[gas_hits])

# This function loads data specifically from EAGLE's format
def eagle_data_loader(filename, particle_type, array_name, hubble=None, box_size=None):
    """Read one particle array from a Swift-EAGLE HDF5 file.

    Parameters
    ----------
    filename : str or path-like
        HDF5 file to read.
    particle_type : int
        0 = gas, 1 = dark matter (as used by the callers in this notebook).
        For ``"Coordinates"`` with type 0, the PartType1 positions are
        returned in place of gas positions.
    array_name : str
        ``"Coordinates"`` or ``"ParticleIDs"``.
    hubble : float, optional
        Factor applied (together with ``* 1000``) to coordinates. Defaults to
        the notebook-level global ``h``.
    box_size : float, optional
        Periodic box size in the scaled coordinate units. Defaults to the
        notebook-level global ``boxsize``.

    Returns
    -------
    numpy.ndarray
        For coordinates: positions scaled by ``hubble * 1000``, with negative
        values wrapped by ``+box_size`` and values ``>= box_size`` clamped to
        ``box_size - 0.001``. For IDs: the raw ``ParticleIDs`` dataset.

    Raises
    ------
    ValueError
        If ``array_name`` is not one of the supported names.
    """
    if array_name == "ParticleIDs":
        with h5py.File(str(filename), "r") as handle:
            return handle[f"/PartType{particle_type}/{array_name}"][:]

    if array_name != "Coordinates":
        raise ValueError(
            f"Unsupported array_name {array_name!r}; expected 'Coordinates' or 'ParticleIDs'"
        )

    # Resolve the global fall-backs lazily so ID reads never need them.
    hubble = h if hubble is None else hubble
    box_size = boxsize if box_size is None else box_size

    # Type-0 coordinates are taken from PartType1 (DM) positions, presumably
    # because these files carry no separate gas positions -- confirm.
    source_type = 1 if particle_type == 0 else particle_type
    with h5py.File(str(filename), "r") as handle:
        output = handle[f"/PartType{source_type}/Coordinates"][:] * hubble * 1000

    for axis, label in enumerate("XYZ"):
        # np.where(...)[0] returns indices, so test .size (not .any(), which
        # would ignore a lone offender at index 0). Wrap first, then clamp, so
        # a wrapped value that rounds up to box_size is still pulled inside.
        negative = np.where(output[:, axis] < 0)[0]
        if negative.size > 0:
            output[negative, axis] += box_size
            LOGGER.info(f"Fixed index {negative} of the {label} coordinate")
        too_large = np.where(output[:, axis] >= box_size)[0]
        if too_large.size > 0:
            output[too_large, axis] = box_size - 0.001
            LOGGER.info(f"Fixed index {too_large} of the {label} coordinate of particle type {particle_type}")

    return output

# This function loads data from normal CAMELS snapshots
def data_loader(filename, particle_type, array_name):
    """Read one particle array from a Gadget-format snapshot via ``readgadget``.

    Parameters
    ----------
    filename : str
        Snapshot or IC path understood by ``readgadget``.
    particle_type : int
        Particle type to read (callers use 0 for gas, 1 for dark matter).
    array_name : str
        ``"Coordinates"`` (read as block ``"POS "``) or ``"ParticleIDs"``
        (read as block ``"ID  "``).

    Returns
    -------
    unyt_array
        Coordinates carry the length unit derived from the header's Hubble
        parameter; values ``>= boxsize`` are clamped to ``boxsize - 0.1``
        (presumably so periodic KD-trees accept them). IDs are dimensionless.

    Raises
    ------
    ValueError
        If ``array_name`` is not one of the supported names.
    """
    block_names = {"Coordinates": "POS ", "ParticleIDs": "ID  "}
    if array_name not in block_names:
        raise ValueError(
            f"Unsupported array_name {array_name!r}; expected one of {sorted(block_names)}"
        )

    unit_length = unyt_quantity(1.0, "kpc")
    header = readgadget.header(filename)
    hubble = header.hubble
    units = unyt_quantity(1.0 / hubble, units=unit_length).to("Mpc")
    boxsize = unyt.unyt_array(header.boxsize, units=units)

    raw = readgadget.read_block(filename, block_names[array_name], [particle_type], verbose=True)
    if array_name == "ParticleIDs":
        return unyt_array(raw, units=None)

    output = unyt_array(raw, units=units)
    limit = boxsize.value
    for axis, label in enumerate("XYZ"):
        # np.where(...)[0] returns indices: test .size, not .any(), so an
        # out-of-range particle at index 0 is not skipped.
        outside = np.where(output[:, axis].value >= limit)[0]
        if outside.size > 0:
            output[outside, axis] = limit - 0.1
            LOGGER.info(f"Fixed index {outside} of the {label} coordinate of particle type {particle_type}")
    return output

# Loads data in TNG50 format
# Currently not used
def tng50_data_loader(filename, particle_type, array_name, number_of_chunks=16):
    """Read one particle array from a chunked ``snapdir_090`` snapshot.

    The header (Hubble parameter, box size) is read from ``filename``. The
    particle data are read from ``snapdir_090/snap_090.{k}.hdf5`` for
    ``k = 0 .. number_of_chunks - 1``. These paths are built from everything
    before the first ``"snap"`` in ``filename``.

    Parameters
    ----------
    filename : str
        Snapshot path; must contain ``"snap"``.
    particle_type : int
        Particle type to read.
    array_name : str
        ``"Coordinates"`` or ``"ParticleIDs"``.
    number_of_chunks : int, optional
        Number of chunk files to read (default 16).

    Returns
    -------
    unyt_array
        All chunks concatenated along the particle axis. Coordinates have
        shape ``(N, 3)``, with values ``>= boxsize`` clamped to
        ``boxsize - 0.1``. IDs have shape ``(N,)``.

    Raises
    ------
    ValueError
        If ``array_name`` is unsupported or ``filename`` contains no ``"snap"``.
    """
    block_names = {"Coordinates": "POS ", "ParticleIDs": "ID  "}
    if array_name not in block_names:
        raise ValueError(
            f"Unsupported array_name {array_name!r}; expected one of {sorted(block_names)}"
        )
    # Keep the block name in its own variable: rebinding array_name inside the
    # loop would stop every chunk after the first from being read.
    block_name = block_names[array_name]

    unit_length = unyt_quantity(1.0, "kpc")
    header = readgadget.header(filename)
    hubble = header.hubble
    units = unyt_quantity(1.0 / hubble, units=unit_length).to("Mpc")
    limit = unyt.unyt_array(header.boxsize, units=units).value

    prefix = filename[:filename.index("snap")]

    chunks = []
    for k in range(number_of_chunks):
        chunk_filename = f"{prefix}snapdir_090/snap_090.{k}.hdf5"
        raw = readgadget.read_block(chunk_filename, block_name, [particle_type], verbose=True)
        if array_name == "ParticleIDs":
            chunks.append(unyt_array(raw, units=None))
            continue

        output = unyt_array(raw, units=units)
        for axis, label in enumerate("XYZ"):
            # Test .size, not .any(): index 0 is falsy.
            outside = np.where(output[:, axis].value >= limit)[0]
            if outside.size > 0:
                output[outside, axis] = limit - 0.1
                LOGGER.info(f"Fixed index {outside} of the {label} coordinate of particle type {particle_type}")
        chunks.append(output)

    # Concatenate along the particle axis. np.append would flatten (N, 3)
    # coordinates and promote IDs to float.
    return unyt.uconcatenate(chunks, axis=0)

# This function gets the nearest initial neighbors and converts it to the "spread output" array order
# Currently not used
def nearest_neighbors(suite, sim):
    """Map initial nearest-neighbour IDs onto the final-snapshot particle order.

    Loads the initial conditions and ``snapshot_090`` of one simulation. For
    every initial gas particle it finds the nearest initial DM particle, and
    for every initial DM particle the nearest other DM particle. It then
    reorders those neighbour IDs to follow the final-snapshot particle order.

    Parameters
    ----------
    suite : str
        One of ``"Swift-EAGLE"``, ``"Astrid"``, ``"SIMBA"``, ``"IllustrisTNG"``.
    sim : str
        Simulation name such as ``"1P_p1_0"``; ``sim[:2]`` selects the set
        directory.

    Returns
    -------
    tuple
        ``(gas_neighbors, final_gas_neighbors, dm_neighbors,
        final_dm_neighbors, dark_matter_final_ids, gas_final_masses,
        indicies, still_gas_mask2, dark_matter_final_coordinates)``

    Raises
    ------
    ValueError
        If ``suite`` is not supported.
    """
    supported = ("Swift-EAGLE", "Astrid", "SIMBA", "IllustrisTNG")
    if suite not in supported:
        raise ValueError(f"Unsupported suite {suite!r}; expected one of {supported}")

    def _positions_of(initial_ids, final_ids):
        # Row of each final ID within initial_ids (NaN where absent).
        lookup = pd.DataFrame(initial_ids, columns=["IDs"])
        lookup["index"] = lookup.index
        lookup = lookup.set_index("IDs")
        lookup = lookup.reindex(index=pd.DataFrame(final_ids, columns=["IDs"])["IDs"])
        return lookup.reset_index()["index"].tolist()

    LOGGER.info("Loading simulation data")

    simdir = f'/home/jovyan/PUBLIC_RELEASE/Sims/{suite}/L25n256/{sim[:2]}/{sim}/'
    unit_length = unyt_quantity(1.0, "kpc")
    if suite == "Swift-EAGLE":
        initial_filename = simdir + 'snapshot_000.hdf5'
        h = 0.6711
        units = unyt_quantity(1.0 / h, units=unit_length).to("Mpc")
        boxsize = unyt.unyt_array(25000, units=units)
    else:
        initial_filename = simdir + 'ICs/ics'
        header = readgadget.header(initial_filename)
        h = header.hubble
        units = unyt_quantity(1.0 / h, units=unit_length).to("Mpc")
        boxsize = unyt.unyt_array(header.boxsize, units=units)

    LOGGER.info("running data_loader")

    if suite == "Swift-EAGLE":
        # NOTE(review): eagle_data_loader scales by the notebook-level globals
        # ``h`` and ``boxsize``, not by the locals defined above.
        gas_ids = unyt_array(np.arange(256**3) * 2 + 1, units=None, dtype=np.uint64)
        dm_ids = unyt_array(np.arange(256**3) * 2, units=None, dtype=np.uint64)
        gas_coords = eagle_data_loader(initial_filename, 0, "Coordinates")
        dm_coords = eagle_data_loader(initial_filename, 1, "Coordinates")
    elif suite == "Astrid":
        # NOTE(review): the driver cell uses np.arange(256**3, 2*(256**3)) + 1
        # for Astrid gas IDs (offset by one from these); confirm which is right.
        dm_ids = unyt_array(np.arange(256**3) + 1, units=None, dtype=np.uint64)
        gas_ids = unyt_array(np.arange(256**3, 2 * (256**3)), units=None, dtype=np.uint64)
        gas_coords = data_loader(initial_filename, 0, "Coordinates")
        dm_coords = data_loader(initial_filename, 1, "Coordinates")
    else:
        # SIMBA and IllustrisTNG: IDs are read from the ICs.
        gas_ids = data_loader(initial_filename, 0, "ParticleIDs")
        dm_ids = data_loader(initial_filename, 1, "ParticleIDs")
        gas_coords = data_loader(initial_filename, 0, "Coordinates")
        dm_coords = data_loader(initial_filename, 1, "Coordinates")

    final_filename = simdir + "snapshot_090.hdf5"
    if suite == "IllustrisTNG":
        # The IllustrisTNG final snapshot is split over 16 chunk files.
        prefix = final_filename[:final_filename.index("snap")]
        final_files = [f"{prefix}snapdir_090/snap_090.{k}.hdf5" for k in range(16)]
    else:
        final_files = [final_filename]

    datasets = {
        "dm_ids": "PartType1/ParticleIDs",
        "dm_coordinates": "PartType1/Coordinates",
        "gas_ids": "PartType0/ParticleIDs",
        "gas_coordinates": "PartType0/Coordinates",
        "gas_masses": "PartType0/Masses",
    }
    # Accumulate across *all* chunks, then concatenate along the particle axis
    # so coordinates keep their (N, 3) shape.
    pieces = {key: [] for key in datasets}
    for path in final_files:
        with h5py.File(path, "r") as handle:
            for key, dataset in datasets.items():
                pieces[key].append(handle[dataset][:])
    final = {
        key: unyt.unyt_array(np.concatenate(parts, axis=0), units=None)
        for key, parts in pieces.items()
    }
    dark_matter_final_ids = final["dm_ids"]
    dark_matter_final_coordinates = final["dm_coordinates"]
    gas_final_ids = final["gas_ids"]
    gas_final_coordinates = final["gas_coordinates"]
    gas_final_masses = final["gas_masses"]

    LOGGER.info("Finding nearest neighbors")

    dm_neighbors, gas_neighbors = find_closest_neighbours(
        dark_matter_coordinates=dm_coords,
        dark_matter_ids=dm_ids,
        boxsize=boxsize,
        gas_coordinates=gas_coords,
        gas_ids=gas_ids,
    )

    LOGGER.info("Running a variety of masking and reordering operations")

    # Keep only gas present in both the initial and the final snapshot.
    still_gas_mask = np.isin(gas_ids, gas_final_ids)
    still_gas_mask2 = np.isin(gas_final_ids, gas_ids)

    gas_final_ids = gas_final_ids[still_gas_mask2]
    # Mask coordinates as well so they stay aligned with gas_final_ids when the
    # np.unique indices are applied below.
    gas_final_coordinates = gas_final_coordinates[still_gas_mask2]
    gas_final_masses = gas_final_masses[still_gas_mask2]

    masked_initial_IDs = gas_ids[still_gas_mask]
    masked_gas_neighbors = gas_neighbors[still_gas_mask]

    # np.unique sorts the IDs and keeps the first occurrence of duplicates.
    gas_final_ids, indicies = np.unique(gas_final_ids, return_index=True)
    gas_final_coordinates = gas_final_coordinates[indicies]
    gas_final_masses = gas_final_masses[indicies]

    gas_indices = _positions_of(masked_initial_IDs, gas_final_ids)
    dm_indices = _positions_of(dm_ids, dark_matter_final_ids)

    final_gas_neighbors = masked_gas_neighbors[gas_indices]
    final_dm_neighbors = dm_neighbors[dm_indices]

    return (
        gas_neighbors,
        final_gas_neighbors,
        dm_neighbors,
        final_dm_neighbors,
        dark_matter_final_ids,
        gas_final_masses,
        indicies,
        still_gas_mask2,
        dark_matter_final_coordinates,
    )

def _load_halo_catalogue(suite, sim):
    """Return ``(halo_coordinates, halo_radii)`` from the ``_090`` group catalogue.

    IllustrisTNG catalogues are split over 16 ``fof_subhalo_tab_090.{m}.hdf5``
    files; chunk files without ``Group/GroupPos`` are skipped. Other suites
    use a single ``groups_090.hdf5``.
    """
    if suite == "IllustrisTNG":
        positions, radii = [], []
        for m in range(16):
            halo_filename = f'/home/jovyan/FOF_Subfind/IllustrisTNG_L50n512/{sim[:2]}/{sim}/groups_090/fof_subhalo_tab_090.{m}.hdf5'
            with h5py.File(halo_filename, "r") as handle:
                if "Group/GroupPos" not in handle:
                    continue
                positions.append(handle['Group/GroupPos'][:])
                radii.append(handle['Group/Group_R_Crit200'][:])
        if not positions:
            return np.empty((0, 3)), np.empty(0)
        return np.concatenate(positions, axis=0), np.concatenate(radii)

    halo_filename = f'/home/jovyan/PUBLIC_RELEASE/FOF_Subfind/{suite}/L25n256/{sim[:2]}/{sim}/' + "groups_090.hdf5"
    with h5py.File(halo_filename, "r") as handle:
        return handle['Group/GroupPos'][:], handle['Group/Group_R_Crit200'][:]


# This function creates the halo ID array for a given particle type
def halo_id_assigner(suite, sim, particle_type, coordinates, boxsize, block_size=1024):
    """Label every particle with the index of a halo whose R_Crit200 sphere contains it.

    Parameters
    ----------
    suite, sim : str
        Locate the group catalogue (see ``_load_halo_catalogue``).
    particle_type : int
        Accepted for interface compatibility; not used.
    coordinates : array-like, shape (N, 3)
        Particle positions, in the same units as the catalogue.
    boxsize : float
        Periodic box size for the KD-tree.
    block_size : int, optional
        Number of haloes queried per ``query_ball_point`` call (default 1024).
        Blocking helps the tree query parallelise.

    Returns
    -------
    (numpy.ndarray of int32, int)
        Halo index per particle (-1 if in no halo) and the number of haloes.
        A particle inside several haloes keeps the highest halo index.
    """
    halo_coordinates, halo_radii = _load_halo_catalogue(suite, sim)

    LOGGER.info("Building tree")
    tree = cKDTree(coordinates, boxsize=boxsize)
    haloes = full(coordinates[:, 0].size, -1, dtype=int32)

    number_of_haloes = halo_radii.size

    LOGGER.info("Beginning tree search")

    # Stop exactly at the last halo, so no empty trailing query is issued.
    for starting_index in range(0, number_of_haloes, block_size):
        ending_index = min(starting_index + block_size, number_of_haloes)

        particle_indicies = tree.query_ball_point(
            x=halo_coordinates[starting_index:ending_index],
            r=halo_radii[starting_index:ending_index],
            workers=-1,
        )

        for offset, members in enumerate(particle_indicies):
            haloes[members] = int32(starting_index + offset)

    return haloes, number_of_haloes

# This function loads the necessary data to calculate halo IDs
def halo_id_getter(suite, sim, particle_type, coordinates=None, box_size=None):
    """Assign halo IDs to the final-snapshot particles of one type.

    Parameters
    ----------
    suite, sim : str
        Passed through to ``halo_id_assigner`` to locate the group catalogue.
    particle_type : int
        0 for gas, 1 for dark matter.
    coordinates : array-like, optional
        Particle positions. Defaults to the notebook-level global
        ``gas_final_coordinates`` (type 0) or
        ``dark_matter_final_coordinates`` (type 1).
    box_size : float, optional
        Periodic box size. Defaults to the notebook-level global ``boxsize``.

    Returns
    -------
    tuple
        ``(halo_id_array, number_of_haloes)`` as returned by ``halo_id_assigner``.

    Raises
    ------
    ValueError
        If ``particle_type`` is neither 0 nor 1.
    """
    if particle_type not in (0, 1):
        raise ValueError(
            f"particle_type must be 0 (gas) or 1 (dark matter), got {particle_type!r}"
        )

    if coordinates is None:
        coordinates = gas_final_coordinates if particle_type == 0 else dark_matter_final_coordinates
    if box_size is None:
        box_size = boxsize

    halo_id_array, number_of_haloes = halo_id_assigner(suite, sim, particle_type, coordinates, box_size)
    return halo_id_array, number_of_haloes

# Map each gas particle's initial DM neighbour to that DM particle's halo.
def gas_L_regions_getter(dm_ids, gas_neighbors, dm_haloes):
    """Look up, for every gas particle, the halo of its initial DM neighbour.

    Parameters
    ----------
    dm_ids : array-like
        Dark-matter IDs aligned with ``dm_haloes`` (the driver passes the
        final-snapshot DM IDs).
    gas_neighbors : array-like
        For each gas particle, the ID of its nearest initial DM particle.
    dm_haloes : numpy.ndarray
        Halo index of each entry of ``dm_ids`` (-1 = no halo, as produced by
        ``halo_id_assigner``).

    Returns
    -------
    gas_L_regions : numpy.ndarray
        ``dm_haloes`` of each gas particle's neighbour, in gas order.
    dm_indices : list
        Row of each neighbour within the masked ``dm_ids``.
    dm_id_mask : numpy.ndarray of bool
        Which ``dm_ids`` occur in ``gas_neighbors``.
    """
    dm_id_mask = np.isin(dm_ids, gas_neighbors)
    kept_ids = dm_ids[dm_id_mask]
    kept_haloes = dm_haloes[dm_id_mask]

    # Label -> row table for the kept DM IDs.
    row_of_id = pd.DataFrame(kept_ids, columns=["IDs"])
    row_of_id["index"] = row_of_id.index
    row_of_id = row_of_id.set_index("IDs")

    # Look every neighbour ID up; an ID missing from dm_ids becomes a NaN row.
    neighbour_ids = pd.DataFrame(gas_neighbors, columns=["IDs"])["IDs"]
    matched = row_of_id.reindex(index=neighbour_ids).reset_index()
    dm_indices = matched["index"].tolist()

    return kept_haloes[dm_indices], dm_indices, dm_id_mask

# This function calculates the masses for each component
# Serial on purpose: many particles add into the same halo slot, so a
# prange loop would race. (parallel=True had no effect without prange.)
@njit
def transfer_masses(particle_masses, number_of_groups, gas_L_regions, haloes,
                    initial_particle_mass=0.00126659):
    """Accumulate per-halo gas masses split by Lagrangian-region (LR) origin.

    Parameters
    ----------
    particle_masses : 1-D float array
        Final mass of each gas particle.
    number_of_groups : int
        Number of haloes; sets the length of every output array.
    gas_L_regions : 1-D int array
        Halo index of each particle's Lagrangian region, or -1 for none.
    haloes : 1-D int array
        Halo index each particle sits in now, or -1 for none.
    initial_particle_mass : float, optional
        Mass credited to the LR side per particle. Presumably the initial
        gas particle mass in code units; confirm.

    Returns
    -------
    tuple of seven float64 arrays, each of length ``number_of_groups``
        ``in_halo``, ``in_halo_from_own_lr``, ``in_halo_from_other_lr``,
        ``in_halo_from_outside_lr`` (final masses, indexed by current halo);
        ``in_lr``, ``in_other_halo_from_lr``, ``outside_haloes``
        (``initial_particle_mass`` units, indexed by LR halo).
    """
    number_of_particles = len(gas_L_regions)

    in_halo = zeros(number_of_groups, dtype=float64)
    in_halo_from_own_lr = zeros(number_of_groups, dtype=float64)
    in_halo_from_other_lr = zeros(number_of_groups, dtype=float64)
    in_halo_from_outside_lr = zeros(number_of_groups, dtype=float64)

    in_lr = zeros(number_of_groups, dtype=float64)
    in_other_halo_from_lr = zeros(number_of_groups, dtype=float64)
    outside_haloes = zeros(number_of_groups, dtype=float64)

    for particle in range(number_of_particles):
        halo = haloes[particle]
        lr = gas_L_regions[particle]
        mass = particle_masses[particle]

        particle_in_halo = halo != -1
        particle_in_lr = lr != -1

        # Only index when the label is valid. Writing "+0.0" into slot -1
        # would touch memory out of bounds under numba when number_of_groups == 0.
        if particle_in_halo:
            in_halo[halo] += mass
            if not particle_in_lr:
                in_halo_from_outside_lr[halo] += mass
            elif halo == lr:
                in_halo_from_own_lr[halo] += mass
            else:
                in_halo_from_other_lr[halo] += mass
                in_other_halo_from_lr[lr] += initial_particle_mass

        if particle_in_lr:
            in_lr[lr] += initial_particle_mass
            if not particle_in_halo:
                outside_haloes[lr] += initial_particle_mass

    return in_halo, in_halo_from_own_lr, in_halo_from_other_lr, in_halo_from_outside_lr, in_lr, in_other_halo_from_lr, outside_haloes

def tracer_operations(
    gas_ids=None,
    gas_coordinates=None,
    gas_masses=None,
    gas_metallicity=None,
    parents=None,
    tracers=None,
):
    """Expand gas cells into one row per tracer that points at them.

    Each argument defaults to the notebook-level global filled by the
    IllustrisTNG branch of the driver loop: ``gas_final_ids``,
    ``gas_final_coordinates``, ``gas_final_masses``,
    ``gas_final_metallicity``, ``parent_IDs`` and ``tracer_IDs``.

    Parameters
    ----------
    gas_ids : numpy.ndarray
        Gas cell IDs.
    gas_coordinates : numpy.ndarray, shape (N, 3)
        Gas cell positions, aligned with ``gas_ids``.
    gas_masses, gas_metallicity : numpy.ndarray
        Per-cell mass and metallicity, aligned with ``gas_ids``.
    parents : numpy.ndarray
        Parent ID of each tracer.
    tracers : numpy.ndarray
        Tracer ID of each tracer, aligned with ``parents``.

    Returns
    -------
    tuple
        ``(coordinates (M, 3) float array, tracer IDs list, masses array,
        metallicities array)`` with one row per tracer whose parent is a gas
        cell. Each row carries its parent cell's values.
        NOTE(review): a cell with k tracers contributes its full mass k times
        -- confirm this is intended for mass sums.
    """
    gas_ids = gas_final_ids if gas_ids is None else gas_ids
    gas_coordinates = gas_final_coordinates if gas_coordinates is None else gas_coordinates
    gas_masses = gas_final_masses if gas_masses is None else gas_masses
    gas_metallicity = gas_final_metallicity if gas_metallicity is None else gas_metallicity
    parents = parent_IDs if parents is None else parents
    tracers = tracer_IDs if tracers is None else tracers

    # Gas cells that host at least one tracer, and tracers whose parent is gas.
    has_tracer = np.isin(gas_ids, parents)
    ids_with_tracer = gas_ids[has_tracer]
    parent_is_gas = np.isin(parents, ids_with_tracer)
    coords_with_tracer = gas_coordinates[has_tracer]

    gas_cells = pd.DataFrame({
        "IDs": ids_with_tracer,
        "XCoordinates": coords_with_tracer[:, 0],
        "YCoordinates": coords_with_tracer[:, 1],
        "ZCoordinates": coords_with_tracer[:, 2],
        "Masses": gas_masses[has_tracer],
        "Metals": gas_metallicity[has_tracer],
    })

    tracer_table = pd.DataFrame({
        "IDs": parents[parent_is_gas],
        "tracer_IDs": tracers[parent_is_gas],
    })

    # Left join on parent ID: one output row per (cell, tracer) pair.
    final_df = gas_cells.join(tracer_table.set_index('IDs'), on="IDs")

    dataframe_coordinates = np.stack(
        [final_df[column].tolist() for column in ("XCoordinates", "YCoordinates", "ZCoordinates")],
        axis=1,
    )
    dataframe_tracer_IDs = final_df['tracer_IDs'].tolist()
    dataframe_masses = np.array(final_df['Masses'].tolist())
    dataframe_z = np.array(final_df['Metals'].tolist())

    LOGGER.info("dataframe coordinates and IDs returned")

    return dataframe_coordinates, dataframe_tracer_IDs, dataframe_masses, dataframe_z

def spread_metric(gas_final_coordinates, gas_neighbors, gas_final_ids, dark_matter_coordinates,
                  dark_matter_ids, dm_indices, dm_id_mask, boxsize=None):
    """Periodic distance between each gas particle and its initial DM neighbour.

    Parameters
    ----------
    gas_final_coordinates : array, shape (N, 3)
        Final gas positions.
    gas_neighbors, gas_final_ids, dark_matter_ids
        Accepted for interface compatibility; not used.
    dark_matter_coordinates : array, shape (M, 3)
        Final DM positions, aligned with the ``dm_ids`` that produced
        ``dm_id_mask``.
    dm_indices : sequence of int
        Row, within ``dark_matter_coordinates[dm_id_mask]``, of each gas
        particle's neighbour (as returned by ``gas_L_regions_getter``).
    dm_id_mask : boolean array, length M
        Selects the DM particles that are somebody's neighbour.
    boxsize : float, optional
        Periodic box size. Defaults to the notebook-level global ``boxsize``
        if one exists (25000 or 50000 in the driver loop), else 25000.

    Returns
    -------
    numpy.ndarray, shape (N,)
        Minimum-image separation of each gas particle from its neighbour.
    """
    # The old hard-coded 25000 gave wrong wrapping for the 50000 TNG box.
    if boxsize is None:
        boxsize = globals().get("boxsize", 25000)
    half_boxsize = boxsize / 2

    paired_dm = dark_matter_coordinates[dm_id_mask][dm_indices]

    delta = gas_final_coordinates[:, :3] - paired_dm[:, :3]
    # Minimum-image convention: fold separations into (-L/2, L/2].
    delta -= (delta > half_boxsize) * boxsize
    delta += (delta <= -half_boxsize) * boxsize

    return np.sqrt((delta * delta).sum(axis=1))

# In [5]:
## New method of loading in data where arrays are left as global variables ##
# Driver: for every simulation name, load initial and snapshot_090 particle
# data into module-level globals and match initial to final particles by ID.
# halo_id_getter, tracer_operations and eagle_data_loader read some of these
# globals implicitly. The driver then assigns haloes / Lagrangian regions and
# computes the mass components and gas spread.
# NOTE(review): in the lines visible here, the per-simulation results
# (in_halo, ..., gas_spread) are overwritten on every iteration and never
# stored. Confirm they are saved by code after this point.

suite = "Swift-EAGLE"
h = 0.6711

# NOTE(review): list_of_strings2 holds "1P_*" names; the "CV set" label may be stale.
num_sims = 30 # for CV set

for i in range(num_sims):
    sim = list_of_strings2[i]
    if suite == "Swift-EAGLE":
        simdir = f'/home/jovyan/PUBLIC_RELEASE/Sims/{suite}/L25n256/{sim[:2]}/{sim}/'
        initial_filename = simdir + 'ICs/ics.hdf5'
        final_filename = simdir + 'snapshot_090.hdf5'
        boxsize = 25000

        # IDs are synthesised rather than read: gas gets odd IDs, DM even IDs
        # (assumes this matches the Swift ID convention for these runs -- confirm).
        gas_initial_ids = np.arange(256**3)*2 + 1
        gas_initial_coordinates = eagle_data_loader(initial_filename, 0, "Coordinates")
        dark_matter_initial_ids = np.arange(256**3)*2
        dark_matter_initial_coordinates = eagle_data_loader(initial_filename, 1, "Coordinates")

        # Lengths are scaled by h * 1000 and masses by h, presumably to match
        # the units of the Gadget-format suites below -- confirm.
        with h5py.File(final_filename, "r") as handle:
            gas_final_ids = handle["PartType0/ParticleIDs"][:]
            gas_final_coordinates = handle["PartType0/Coordinates"][:] * h * 1000
            gas_final_masses = handle["PartType0/Masses"][:] * h
            dark_matter_final_ids = handle["PartType1/ParticleIDs"][:] 
            dark_matter_final_coordinates = handle["PartType1/Coordinates"][:] * h * 1000

    if suite == "SIMBA":
        simdir = f'/home/jovyan/PUBLIC_RELEASE/Sims/{suite}/L25n256/{sim[:2]}/{sim}/'
        initial_filename = simdir + 'ICs/ics'
        final_filename = simdir + 'snapshot_090.hdf5'
        boxsize = 25000

        gas_initial_ids = data_loader(initial_filename, 0, "ParticleIDs")
        gas_initial_coordinates = data_loader(initial_filename, 0, "Coordinates")
        dark_matter_initial_ids = data_loader(initial_filename, 1, "ParticleIDs")
        dark_matter_initial_coordinates = data_loader(initial_filename, 1, "Coordinates")

        with h5py.File(final_filename, "r") as handle:
            gas_final_ids = handle["PartType0/ParticleIDs"][:]
            gas_final_coordinates = handle["PartType0/Coordinates"][:]
            gas_final_masses = handle["PartType0/Masses"][:]
            dark_matter_final_ids = handle["PartType1/ParticleIDs"][:] 
            dark_matter_final_coordinates = handle["PartType1/Coordinates"][:]

    if suite == "Astrid":
        simdir = f'/home/jovyan/PUBLIC_RELEASE/Sims/{suite}/L25n256/{sim[:2]}/{sim}/'
        initial_filename = simdir + 'ICs/ics'
        final_filename = simdir + 'snapshot_090.hdf5'
        boxsize = 25000

        # Synthesised IDs: gas 256**3+1 .. 2*256**3, DM 1 .. 256**3.
        gas_initial_ids = np.arange(256**3, 2*(256**3)) + 1
        gas_initial_coordinates = data_loader(initial_filename, 0, "Coordinates")
        dark_matter_initial_ids = np.arange((256**3)) + 1
        dark_matter_initial_coordinates = data_loader(initial_filename, 1, "Coordinates")

        with h5py.File(final_filename, "r") as handle:
            gas_final_ids = handle["PartType0/ParticleIDs"][:]
            gas_final_coordinates = handle["PartType0/Coordinates"][:]
            gas_final_masses = handle["PartType0/Masses"][:]
            dark_matter_final_ids = handle["PartType1/ParticleIDs"][:] 
            dark_matter_final_coordinates = handle["PartType1/Coordinates"][:]

    if suite == "IllustrisTNG":
        simdir = f'/home/jovyan/Sims/{suite}_L50n512/{sim[:2]}/{sim}/'
        initial_filename = simdir + 'ICs/ics'
        final_filename = simdir + 'snapshot_090.hdf5'
        boxsize = 50000

        gas_initial_ids = data_loader(initial_filename, 0, "ParticleIDs")
        gas_initial_coordinates = data_loader(initial_filename, 0, "Coordinates")
        dark_matter_initial_ids = data_loader(initial_filename, 1, "ParticleIDs")
        dark_matter_initial_coordinates = data_loader(initial_filename, 1, "Coordinates")

        # NOTE(review): gas IC IDs are shifted by +1 here; confirm why.
        gas_initial_ids = np.uint64(gas_initial_ids + 1)

        # Accumulators start as float64 empty arrays, so every np.append result
        # below (IDs included) is float64. Appending inside the loop also
        # re-copies the data each iteration (O(n^2)); collecting chunks in
        # lists and calling np.concatenate once would be cheaper.
        parent_IDs = np.array([])
        tracer_IDs = np.array([])
        gas_xcoords = np.array([])
        gas_ycoords = np.array([])
        gas_zcoords = np.array([])
        dm_xcoords = np.array([])
        dm_ycoords = np.array([])
        dm_zcoords = np.array([])
        gas_final_ids = np.array([])
        dark_matter_final_ids = np.array([])
        gas_final_masses = np.array([])
        gas_final_metallicity = np.array([])


        # The final snapshot is split over 16 chunk files in snapdir_090/.
        for k in range(16):
            index = final_filename.index("snap")
            filename = final_filename[:index]+f"snapdir_090/snap_090.{k}.hdf5"    

            with h5py.File(filename, "r") as handle:
                parent_IDs_temp = handle["PartType3/ParentID"][:]
                tracer_IDs_temp = handle["PartType3/TracerID"][:]
                gas_final_coordinates_temp = handle["PartType0/Coordinates"][:]
                gas_xcoords_temp = gas_final_coordinates_temp[:,0]
                gas_ycoords_temp = gas_final_coordinates_temp[:,1]
                gas_zcoords_temp = gas_final_coordinates_temp[:,2]
                
                gas_metals_temp = handle["PartType0/GFM_Metallicity"][:]

                gas_final_ids_temp = handle["PartType0/ParticleIDs"][:]
                gas_final_masses_temp = handle["PartType0/Masses"][:]

                dark_matter_final_ids_temp = handle["PartType1/ParticleIDs"][:]
                dark_matter_final_coordinates_temp = handle["PartType1/Coordinates"][:]

            parent_IDs = np.append(parent_IDs, parent_IDs_temp)
            tracer_IDs = np.append(tracer_IDs, tracer_IDs_temp)
            gas_xcoords = np.append(gas_xcoords, gas_xcoords_temp)
            gas_ycoords = np.append(gas_ycoords, gas_ycoords_temp)
            gas_zcoords = np.append(gas_zcoords, gas_zcoords_temp)
            gas_final_ids = np.append(gas_final_ids, gas_final_ids_temp)

            # Re-stacked on every iteration; only the last stack (all chunks)
            # is read, by tracer_operations below.
            gas_final_coordinates = np.stack((gas_xcoords, gas_ycoords, gas_zcoords), axis = 1)

            dm_xcoords_temp = dark_matter_final_coordinates_temp[:,0]
            dm_ycoords_temp = dark_matter_final_coordinates_temp[:,1]
            dm_zcoords_temp = dark_matter_final_coordinates_temp[:,2]

            dm_xcoords = np.append(dm_xcoords, dm_xcoords_temp)
            dm_ycoords = np.append(dm_ycoords, dm_ycoords_temp)
            dm_zcoords = np.append(dm_zcoords, dm_zcoords_temp)

            dark_matter_final_ids = np.append(dark_matter_final_ids, dark_matter_final_ids_temp)
            gas_final_masses = np.append(gas_final_masses, gas_final_masses_temp)
            
            gas_final_metallicity = np.append(gas_final_metallicity, gas_metals_temp)


        LOGGER.info("Starting tracer operations")
        # tracer_operations returns (coordinates, tracer IDs, masses, metallicity).
        # From here on the IllustrisTNG "gas" arrays have one row per tracer,
        # and gas_final_ids holds tracer IDs.
        gas_final_coordinates, gas_final_ids, gas_final_masses, gas_final_metallicity = tracer_operations()
        dark_matter_final_coordinates = np.stack((dm_xcoords, dm_ycoords, dm_zcoords), axis = 1)
        #gas_final_coordinates = np.stack((gas_xcoords, gas_ycoords, gas_zcoords), axis = 1)

    # Clamp coordinates at or beyond the box edge to just inside it (in place),
    # presumably so the periodic cKDTree accepts them. Negatives are not wrapped here.
    # NOTE(review): `ind_checkN.any()` tests index *values*, so a lone
    # out-of-range particle at index 0 is skipped; `ind_checkN.size > 0` is
    # the intended test. The Y/Z messages log ind_check0 instead of
    # ind_check1/ind_check2, and every message formats the whole array.
    for final_coordinates in [gas_initial_coordinates, dark_matter_initial_coordinates, 
                              gas_final_coordinates, dark_matter_final_coordinates]:
        ind_check0 = np.where(final_coordinates[:,0] >= boxsize)[0]
        if ind_check0.any() == True:
            final_coordinates[ind_check0, 0] = boxsize - 0.1
            LOGGER.info(f"Fixed index {ind_check0} of the X coordinate of {final_coordinates}")
        ind_check1 = np.where(final_coordinates[:,1] >= boxsize)[0]
        if ind_check1.any() == True:
            final_coordinates[ind_check1, 1] = boxsize - 0.1
            LOGGER.info(f"Fixed index {ind_check0} of the Y coordinate of {final_coordinates}")
        ind_check2 = np.where(final_coordinates[:,2] >= boxsize)[0]
        if ind_check2.any() == True:
            final_coordinates[ind_check2, 2] = boxsize - 0.1
            LOGGER.info(f"Fixed index {ind_check0} of the Z coordinate of {final_coordinates}")


    LOGGER.info("Finding nearest neighbors")

    # dm_neighbors: ID of each initial DM particle's nearest other DM particle.
    # gas_neighbors: ID of each initial gas particle's nearest DM particle.
    dm_neighbors, gas_neighbors = find_closest_neighbours(dark_matter_coordinates = dark_matter_initial_coordinates,
                                                      dark_matter_ids = dark_matter_initial_ids,
                                                      boxsize = boxsize,
                                                      gas_coordinates = gas_initial_coordinates,
                                                      gas_ids = gas_initial_ids
                                                     )
    
    # Keep only gas IDs present at both the initial and the final snapshot.
    gas_final_ids = np.uint64(gas_final_ids)
    still_gas_mask = np.isin(gas_initial_ids, gas_final_ids)
    still_gas_mask2 = np.isin(gas_final_ids, gas_initial_ids)

    masked_initial_IDs = gas_initial_ids[still_gas_mask]
    masked_gas_neighbors = gas_neighbors[still_gas_mask]

    gas_final_ids = gas_final_ids[still_gas_mask2]
    gas_final_coordinates = gas_final_coordinates[still_gas_mask2]
    gas_final_masses = gas_final_masses[still_gas_mask2]
    #gas_final_metallicity = gas_final_metallicity[still_gas_mask2]

    # np.unique sorts the IDs and keeps the first occurrence of any duplicate
    # (duplicates are expected for IllustrisTNG tracer rows).
    gas_final_ids, indicies = np.unique(gas_final_ids, return_index=True)
    gas_final_coordinates = gas_final_coordinates[indicies]
    gas_final_masses = gas_final_masses[indicies]
    #gas_final_metallicity = gas_final_metallicity[indicies]

    # gas_indices[i] = row of gas_final_ids[i] within masked_initial_IDs
    # (pandas reindex; NaN if the ID were absent).
    gasi_IDs = pd.DataFrame(masked_initial_IDs, columns = ["IDs"])
    gasf_IDs = pd.DataFrame(gas_final_ids, columns = ["IDs"])

    gasi_IDs['index'] =gasi_IDs.index 
    gasi_IDs = gasi_IDs.set_index('IDs')
    gasi_IDs = gasi_IDs.reindex(index=gasf_IDs['IDs'])
    gasi_IDs = gasi_IDs.reset_index()

    gas_indices = gasi_IDs['index'].tolist()

    # dm_indices[i] = row of dark_matter_final_ids[i] within dark_matter_initial_ids.
    dmi_IDs = pd.DataFrame(dark_matter_initial_ids, columns = ["IDs"])
    dmf_IDs = pd.DataFrame(dark_matter_final_ids, columns = ["IDs"])

    dmi_IDs['index'] =dmi_IDs.index 
    dmi_IDs = dmi_IDs.set_index('IDs')
    dmi_IDs = dmi_IDs.reindex(index=dmf_IDs['IDs'])
    dmi_IDs = dmi_IDs.reset_index()

    dm_indices = dmi_IDs['index'].tolist()

    # Neighbour IDs reordered to follow the final-snapshot particle order.
    final_gas_neighbors = masked_gas_neighbors[gas_indices]
    final_dm_neighbors = dm_neighbors[dm_indices]

    #Calling the rest of the functions

    # Halo index of every final DM (type 1) and gas (type 0) particle; both
    # calls read the module globals set above.
    dm_haloes, number_of_groups = halo_id_getter(suite, sim, 1)

    gas_haloes, _, = halo_id_getter(suite, sim, 0)

    # Lagrangian-region label of each gas particle = current halo of its
    # initial DM neighbour. Note: dm_indices is rebound here.
    gas_L_regions, dm_indices, dm_id_mask = gas_L_regions_getter(dark_matter_final_ids, final_gas_neighbors, dm_haloes)

    (in_halo, in_halo_from_own_lr, in_halo_from_other_lr, 
     in_halo_from_outside_lr, in_lr, in_other_halo_from_lr,
     outside_haloes) = transfer_masses(gas_final_masses, number_of_groups, gas_L_regions, gas_haloes)
    
    LOGGER.info('Computing gas spread')

    # NOTE(review): spread_metric hard-codes boxsize = 25000, which does not
    # match the 50000 box used for IllustrisTNG above.
    gas_spread = spread_metric(gas_final_coordinates, final_gas_neighbors, gas_final_ids, 
                               dark_matter_final_coordinates, dark_matter_final_ids, dm_indices, dm_id_mask)
    
    #for particle_type in ["gas"]:
        #pickle_filepath = f"/home/jovyan/home/spread_metric/{suite}/{particle_type}_{sim}_snap090.hdf5"
        #pickle_file = h5py.File(pickle_filepath, 'w')
        
    hf = h5py.File(f'/home/jovyan/home/spread_metric/{suite}/Spread_Output_{sim}_snap090.hdf5', 'w')
    gas = hf.create_group('gas')
    gas.create_dataset('spread', data = gas_spread)

        #file_in_halo = pickle_file.create_dataset('in_halo', data = in_halo)
        #file_in_halo_from_own_lr = pickle_file.create_dataset('in_halo_from_own_lr', data = in_halo_from_own_lr)
        #file_in_halo_from_other_lr = pickle_file.create_dataset('in_halo_from_other_lr', data = in_halo_from_other_lr)
        #file_in_halo_from_outside_lr = pickle_file.create_dataset('in_halo_from_outside_lr', data = in_halo_from_outside_lr)
        #file_in_lr = pickle_file.create_dataset('in_lr', data = in_lr)
        #file_in_other_halo_from_lr = pickle_file.create_dataset('in_other_halo_from_lr', data = in_other_halo_from_lr)
        #file_outside_haloes = pickle_file.create_dataset('outside_haloes', data = outside_haloes)
        
        #file_gas_coordinates = pickle_file.create_dataset('gas_final_coordinates', data = gas_final_coordinates)
        #file_gas_masses = pickle_file.create_dataset('gas_final_masses', data = gas_final_masses)
        #file_gas_metallicities = pickle_file.create_dataset('gas_final_metallicities', data = gas_final_metallicity)

        #file_L_regions = pickle_file.create_dataset('lagrangian_IDs', data = gas_L_regions)
        #file_haloes = pickle_file.create_dataset('halo_IDs', data = gas_haloes)

        #file_mask1 = pickle_file.create_dataset('mask1', data = indicies)
        #file_mask2 = pickle_file.create_dataset('mask2', data = still_gas_mask2)

        #pickle_file.close()

    

INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      61       62       63 ... 16711918 16711919 16711920] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      61       62       63 ... 16711918 16711919 16711920] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [] of the Y coordinate of [[  268.59107176   295.34782876    94.17218176]
 [  160.93992076   321.54757276    97.16528776]
 [  206.38010176   242.80069876   369.28291576]
 ...
 [24713.14825195 24649.13873395 24958.19370595]
 [24838.40906695 24633.00548995 24979.69574995]
 [24686.06936695 24832.71813895 24864.81685195]]
INFO:Transfe

INFO:TransferLogger:Finished tree build
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Computing gas spread
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [15845785] of the X coordinate of [[  119.15122001    60.65143301   104.74941401]
 [  131.15719901 

INFO:TransferLogger:Fixed index [16524699] of the X coordinate of [[  186.72902975   175.32704075   214.23070775]
 [   37.03346375    90.59395475    60.02534975]
 [   26.70523475    99.62696075    53.88478475]
 ...
 [24491.32770204 24306.03699204 24425.68070004]
 [24662.25687204 24284.44099404 24547.11624504]
 [24295.64836404 24576.94664004 24767.25717804]]
INFO:TransferLogger:Fixed index [16524699] of the Z coordinate of [[  186.72902975   175.32704075   214.23070775]
 [   37.03346375    90.59395475    60.02534975]
 [   26.70523475    99.62696075    53.88478475]
 ...
 [24491.32770204 24306.03699204 24425.68070004]
 [24662.25687204 24284.44099404 24547.11624504]
 [24295.64836404 24576.94664004 24767.25717804]]
INFO:TransferLogger:Finding nearest neighbors
INFO:TransferLogger:Building dark matter tree for spread metric
INFO:TransferLogger:Finished tree build
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Building tree
INFO:TransferLogger:

INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [] of the Z coordinate of [[7.56417400e+01 2.45510572e+02 2.08128700e+01]
 [1.76246341e+02 1.52583355e+02 8.93791570e+01]
 [3.07325593e+02 6.31592800e+01 1.59797680e+02]
 ...
 [2.43631744e+04 2.47115089e+04 2.47408360e+04]
 [2.46697732e+04 2.47139853e+04 2.42990039e+04]
 [2.49868948e+04 2.44713491e+04 2.45554513e+04]]
INFO:TransferLogger:Finding nearest neighbors
INFO:TransferLogger:Building dark matter tree for spread met

INFO:TransferLogger:Fixed index [16308702 16470405 16601669 16649599] of the Y coordinate of [[   64.02646153   294.20033953   356.20997953]
 [   63.89224153   294.51575653   359.65943353]
 [   47.93348353   235.64686453   363.43101553]
 ...
 [24220.38323083 24717.93677083 24874.50440083]
 [24926.61531583 24279.92993383 24396.46644883]
 [24438.14846983 24870.60530983 24900.63032383]]
INFO:TransferLogger:Fixed index [16308702 16470405 16601669 16649599] of the Z coordinate of [[   64.02646153   294.20033953   356.20997953]
 [   63.89224153   294.51575653   359.65943353]
 [   47.93348353   235.64686453   363.43101553]
 ...
 [24220.38323083 24717.93677083 24874.50440083]
 [24926.61531583 24279.92993383 24396.46644883]
 [24438.14846983 24870.60530983 24900.63032383]]
INFO:TransferLogger:Finding nearest neighbors
INFO:TransferLogger:Building dark matter tree for spread metric
INFO:TransferLogger:Finished tree build
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search


INFO:TransferLogger:Finding nearest neighbors
INFO:TransferLogger:Building dark matter tree for spread metric
INFO:TransferLogger:Finished tree build
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Computing gas spread
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 16776960] of the Z coordinate
INFO:TransferLogger:F

INFO:TransferLogger:Fixed index [16681949 16711246] of the Z coordinate of [[   63.77514659   298.05615659   366.76337459]
 [   47.90363159   243.32795159   357.54917159]
 [   80.11643159   229.14760859   352.97898059]
 ...
 [24582.12676123 24491.92421023 24267.06544423]
 [24554.88010123 24621.86259223 24439.67236423]
 [24543.84721723 24887.30948623 24935.80317223]]
INFO:TransferLogger:Finding nearest neighbors
INFO:TransferLogger:Building dark matter tree for spread metric
INFO:TransferLogger:Finished tree build
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Building tree
INFO:TransferLogger:Beginning tree search
INFO:TransferLogger:Computing gas spread
INFO:TransferLogger:Fixed index [    0     1     2 ... 65533 65534 65535] of the X coordinate
INFO:TransferLogger:Fixed index [      53       61       62 ... 16711919 16711920 16711921] of the Y coordinate
INFO:TransferLogger:Fixed index [       0      256      512 ... 16776448 16776704 