# generate_features.ipynb

This notebook is designed to generate features used for machine learning models that characterize structural properties of galaxy clusters in the TNG-Cluster simulation.  
It processes group catalogs and merger trees to produce inputs for machine learning algorithms.

---

## Main Functions

- Load group catalogs and cluster merger event files from TNG-Cluster.
- Track primary halos and extract main progenitor information using merger trees.
- Fit galaxy distributions in phase space using Gaussian Mixture Models (GMM).
- Adopt a 2-component GMM to model major/minor substructure, and extract features that quantify asymmetry, kinematics, and morphology.
- Save feature and label dictionaries for ML training and analysis.

---

## Input Requirements

- **Local access to**:
  - TNG-Cluster merger event file (`cluster_mergers.hdf5`)
  - Sorted group catalog (`TargetHalo_MergerCat_099.hdf5`)
  - TNG-Cluster catalog (`TNG-Cluster_Catalog.hdf5`)
  - Sublink merger trees (`tng_cluster_mpbs/*.hdf5`)

- Required parameters such as snapshot number (`snapnum=99`) should match the filenames.

- All file paths must be changed from the placeholder (`users_path/...`) to your actual local directory structure.

---

## Output

- A pickled `.pkl` file containing two dictionaries:
  - `feat_dict`: dictionary of feature vectors indexed by halo ID.
  - `label_dict`: corresponding classification labels for each halo (e.g., merger/non-merger).

- Output is saved to:  
  `feats_labels_TNGCluster_fullsample_final.pkl`


## Feature Vector Design 

The structure includes both raw GMM outputs and derived quantities.  
Derived features are placed **at the end** to allow easy removal for nonlinear models.
The features are described in detail in our paper.


### Base GMM Output Features (Raw Features)

| Feature Name      | Description                                      |
|-------------------|--------------------------------------------------|
| `n0`, `n1`        | Number of galaxies/subhalos in each component    |
| `mean_r_0`, `mean_r_1` | Mean position (projected) in each component  |
| `mean_v_0`, `mean_v_1` | Mean velocity (LOS) in each component        |
| `std_r_0`, `std_r_1`   | Std. dev. of projected positions              |
| `std_v_0`, `std_v_1`   | Std. dev. of line-of-sight velocities         |
| `bic_1`, `bic_2`       | Bayesian Information Criterion (BIC) values   |
| `elongation_ratio_xy` | $\lambda_2 / \lambda_1$ — Flattening ratio in XY-plane from PCA |




In [None]:
# Basic Packages
import numpy as np
import h5py
import logging
import os
import shutil
import gc
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import linregress, pearsonr, spearmanr
import seaborn as sns
from scipy.optimize import curve_fit
import scipy.stats as stats
import math
# Physics-related Packages
from astropy.cosmology import Planck15
from sklearn.mixture import GaussianMixture
from sklearn.metrics import pairwise_distances
from astropy.cosmology import Planck15 as cosmo
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

In [None]:
def get_subhalo_maxM(Halo_ID, Sub_GrNr, Sub_Mass):
    """
    Find the index of the most massive subhalo belonging to a given halo.

    Parameters:
    - Halo_ID: Halo ID to match against Sub_GrNr.
    - Sub_GrNr: Array indicating the group number (halo) each subhalo belongs to.
    - Sub_Mass: Array of subhalo masses, aligned with Sub_GrNr.

    Returns:
    - CenterSub_Index: 0-d numpy array holding the index (into Sub_GrNr / Sub_Mass)
      of the most massive subhalo of the halo.

    Raises:
    - IndexError: if no subhalo belongs to Halo_ID.
    """
    find_Sub = np.where(Sub_GrNr == Halo_ID)[0]
    if find_Sub.size == 0:
        # Same exception type as the bare `[-1]` lookup used to raise, but with
        # a message that names the offending halo.
        raise IndexError(f"No subhalo found for halo {Halo_ID}")

    # argsort(...)[-1] (rather than argmax) keeps the tie-breaking of the
    # original implementation when several subhalos share the maximum mass.
    heaviest = np.argsort(Sub_Mass[find_Sub])[-1]

    return np.array(find_Sub[heaviest])

def Get_AvgSFR(SubhaloGrNr, SubhaloSFR, FOF_Halo_IDs):
    """
    Compute the mean subhalo SFR of every listed FOF halo.

    Parameters:
    - SubhaloGrNr: Array giving the FOF halo each subhalo belongs to.
    - SubhaloSFR: Array of subhalo star-formation rates, aligned with SubhaloGrNr.
    - FOF_Halo_IDs: Array of FOF halo IDs to average over.

    Returns:
    - AvgSFR: Float array with the same shape as FOF_Halo_IDs; entry k is the
      mean of SubhaloSFR over the subhalos whose SubhaloGrNr equals
      FOF_Halo_IDs[k] (np.mean of an empty selection, i.e. NaN, when a halo
      has no subhalos).
    """
    AvgSFR = np.zeros(FOF_Halo_IDs.shape)

    for slot, fof_id in enumerate(FOF_Halo_IDs):
        members = SubhaloGrNr == fof_id
        AvgSFR[slot] = np.mean(SubhaloSFR[members])

    return AvgSFR
    
def Get_HaloIDs(TargetHalo_cat, SubhaloMassDef):
    """
    Load a target-halo catalog and gather properties of every FOF halo that hosts subhalos.

    Parameters:
    - TargetHalo_cat: Path to the HDF5 file containing the target halo catalog.
    - SubhaloMassDef: Name of the dataset under 'Subhalo/' that is used as the
      subhalo mass (e.g. 'SubhaloMass').

    Returns (as one tuple, in this order):
    - Target_Halo_IDs: IDs of the FOF halos whose GroupFirstSub is not -1.
    - Galaxy_nums: GroupNsubs of the FOF halos whose ID appears in Target_Halo_IDs.
    - Target_Halo_Rs_Crit200: Group_R_Crit200 of the selected halos.
    - Target_GroupPoses: GroupPos of the selected halos.
    - AvgSFR: Mean SubhaloSFR per FOF halo, evaluated for every entry of
      FOF_Halo_IDs (not only the selected halos).
    - Center_SubhaloIDs, Center_SubhaloMasses, Center_SubhaloVelDisp,
      Center_SubhaloSFR, Center_SubhaloSpin, Center_SubhaloMassRad:
      properties of the most massive subhalo of each selected halo
      (Center_SubhaloMassRad is twice SubhaloHalfmassRad). All are float arrays.
    """
    with h5py.File(TargetHalo_cat, 'r') as catalog:
        # FOF group datasets
        FOF_Halo_IDs = catalog['Group/FOF_Halo_IDs'][:]
        GroupFirstSub = catalog['Group/GroupFirstSub'][:]
        GroupPos = catalog['Group/GroupPos'][:]
        Group_R_Crit200 = catalog['Group/Group_R_Crit200'][:]
        Group_Nsubs = catalog['Group/GroupNsubs'][:]

        # Subhalo datasets
        SubhaloGrNr = catalog['Subhalo/SubhaloGrNr'][:]
        SubhaloMass = catalog[f'Subhalo/{SubhaloMassDef}'][:]
        SubhaloSpin = catalog['Subhalo/SubhaloSpin'][:]
        SubhaloVelDisp = catalog['Subhalo/SubhaloVelDisp'][:]
        SubhaloSFR = catalog['Subhalo/SubhaloSFR'][:]
        Subhalo_IDs = catalog['Subhalo/Subhalo_IDs'][:]
        Subhalo_MassRad = 2*catalog['Subhalo/SubhaloHalfmassRad'][:]

    # Keep only FOF halos that own at least one subhalo (GroupFirstSub == -1 means none)
    hosts = np.flatnonzero(GroupFirstSub != -1)
    Target_Halo_IDs = FOF_Halo_IDs[hosts]
    Target_Halo_Rs_Crit200 = Group_R_Crit200[hosts]
    Target_GroupPoses = GroupPos[hosts]

    Galaxy_nums = Group_Nsubs[np.isin(FOF_Halo_IDs, Target_Halo_IDs)]
    # NOTE(review): AvgSFR is indexed like FOF_Halo_IDs, whereas every other
    # returned array is indexed like Target_Halo_IDs — confirm this is intended.
    AvgSFR = Get_AvgSFR(SubhaloGrNr, SubhaloSFR, FOF_Halo_IDs)

    # Per-halo properties of the most massive ("central") subhalo
    per_halo_shape = Target_Halo_IDs.shape
    Center_SubhaloIDs = np.zeros(per_halo_shape)
    Center_SubhaloMasses = np.zeros(per_halo_shape)
    Center_SubhaloVelDisp = np.zeros(per_halo_shape)
    Center_SubhaloSFR = np.zeros(per_halo_shape)
    Center_SubhaloSpin = np.zeros((per_halo_shape[0], 3))
    Center_SubhaloMassRad = np.zeros(per_halo_shape)

    for row, Halo_ID in enumerate(Target_Halo_IDs):
        central = get_subhalo_maxM(Halo_ID, SubhaloGrNr, SubhaloMass)

        Center_SubhaloIDs[row] = Subhalo_IDs[central]
        Center_SubhaloMasses[row] = SubhaloMass[central]
        Center_SubhaloVelDisp[row] = SubhaloVelDisp[central]
        Center_SubhaloSFR[row] = SubhaloSFR[central]
        Center_SubhaloSpin[row] = SubhaloSpin[central]
        Center_SubhaloMassRad[row] = Subhalo_MassRad[central]

    return (
        Target_Halo_IDs, Galaxy_nums, Target_Halo_Rs_Crit200, Target_GroupPoses, AvgSFR,
        Center_SubhaloIDs, Center_SubhaloMasses, Center_SubhaloVelDisp,
        Center_SubhaloSFR, Center_SubhaloSpin, Center_SubhaloMassRad,
    )

In [None]:
# preparations for read box info from the url
import requests
import time

# NOTE(review): this is plain http, so the api-key header travels unencrypted;
# consider the https endpoint if it serves the same API — confirm before switching.
baseUrl = 'http://www.tng-project.org/api/'

def get(path, params=None, max_retries=5, backoff_factor=2, timeout=60):
    """
    Issue an HTTP GET request with retries and exponential backoff.

    Parameters:
    - path: URL to request.
    - params: optional dict of query parameters.
    - max_retries: maximum number of attempts before giving up.
    - backoff_factor: base of the exponential wait; after failed attempt k the
      function sleeps backoff_factor**k seconds.
    - timeout: timeout in seconds handed to requests.get, so a stalled
      connection cannot hang the notebook forever.

    Returns:
    - the decoded JSON payload when the response content type is JSON;
    - the local filename when the response has a content-disposition header
      (the payload is written to that file in the current working directory);
    - otherwise the raw requests.Response object.

    Raises:
    - RuntimeError: if every attempt failed; the last underlying error is
      attached as the cause.
    """
    # Take the key from the environment when available so it does not have to
    # be committed with the notebook; the literal placeholder stays as fallback.
    headers = {"api-key": os.environ.get("TNG_API_KEY", "API KEY")}

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            r = requests.get(path, params=params, headers=headers, timeout=timeout)
            r.raise_for_status()

            # Tolerate a missing header and parameters such as "; charset=utf-8"
            content_type = r.headers.get('content-type', '').split(';')[0].strip().lower()
            if content_type == 'application/json':
                return r.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a malformed JSON body
            last_error = e
            if attempt == max_retries:
                break  # no point sleeping before giving up
            wait = backoff_factor ** attempt
            print(f"[Retry {attempt}/{max_retries}] Request failed: {e}. Retrying in {wait}s...")
            time.sleep(wait)
            continue

        if 'content-disposition' in r.headers:
            filename = r.headers['content-disposition'].split("filename=")[1]
            # Drop optional quotes and any directory part supplied by the server
            filename = os.path.basename(filename.strip().strip('"'))
            with open(filename, 'wb') as f:
                f.write(r.content)
            return filename
        return r  # fallback

    raise RuntimeError(f"Failed to GET {path} after {max_retries} retries.") from last_error

# Query the API once at import time to obtain the TNG-Cluster simulation
# metadata and its snapshot list; `snaps` is read by functions defined in
# later cells.

# Issue a request to the API root
r = get(baseUrl)

# Collect all the simulation names
names = [sim['name'] for sim in r['simulations']]
# Get the index of TNG-Cluster
i = names.index('TNG-Cluster')
# Get the info of the TNG-Cluster simulation
sim = get( r['simulations'][i]['url'] )
# Bare expression: only displays the keys when it is the last line of a notebook cell
sim.keys()

# Get the snapshot list of this simulation
snaps = get(sim['snapshots'])

# Sim Box parameters
Snap_Index = 99 # index of the snapshot of interest among the 100 snapshots taken at different z
BoxSize = sim['boxsize'] # unit: ckpc/h
# NOTE(review): assumes position k of `snaps` is snapshot number k — confirm against the API response
Redshift = snaps[Snap_Index]['redshift'] # redshift of the snapshot selected above


In [None]:
def get_enlongation_projection(relative_Pos, relative_Vel, plane='xy'):
    """
    Project positions onto the first principal axis of a 2D plane.

    Parameters:
    - relative_Pos: (N, 3) positions.
    - relative_Vel: (N, 3) velocities.
    - plane: 'xy', 'yz', or 'xz'.

    Returns:
    - r_proj: in-plane positions projected onto the first PCA component
      fitted to those same in-plane positions.
    - v_los: velocity component along the axis not contained in the plane.

    Raises:
    - ValueError: if plane is not one of the three supported names.
    """
    # plane -> (position columns spanning the plane, velocity column normal to it)
    plane_layout = {
        'xy': (slice(None, 2), 2),
        'yz': (slice(1, None), 0),
        'xz': ([0, 2], 1),
    }
    if plane not in plane_layout:
        raise ValueError("Plane must be one of 'xy', 'yz', or 'xz'.")

    pos_columns, vel_column = plane_layout[plane]
    in_plane = relative_Pos[:, pos_columns]
    v_los = relative_Vel[:, vel_column]

    elongation_axis = PCA(n_components=1).fit(in_plane).components_[0]
    return in_plane @ elongation_axis, v_los


def compute_flattening_ratio(relative_Pos, plane='xy'):
    """
    Return the flattening ratio (λ2 / λ1) of a 2D projection, from a 2-component PCA.

    Parameters:
    -----------
    relative_Pos : ndarray of shape (N, 3)
        Relative positions of subhalos.
    plane : str
        Projection plane: 'xy', 'yz', or 'xz'

    Returns:
    --------
    flattening_ratio : float
        Second PCA explained variance divided by the first (plus 1e-8 in the
        denominator to avoid division by zero).

    Raises:
    -------
    ValueError
        If plane is not one of the three supported names.
    """
    columns_by_plane = {'xy': [0, 1], 'yz': [1, 2], 'xz': [0, 2]}
    if plane not in columns_by_plane:
        raise ValueError("Plane must be one of 'xy', 'yz', or 'xz'.")

    in_plane = relative_Pos[:, columns_by_plane[plane]]

    # explained_variance_ is sorted in decreasing order: major >= minor
    major, minor = PCA(n_components=2).fit(in_plane).explained_variance_
    return minor / (major + 1e-8)

def get_relative_pos_vel(cat_name, fofhalo_id):
    """
    Return subhalo positions and velocities of one FOF halo relative to its most massive subhalo.

    Parameters:
    - cat_name: path to the HDF5 catalog; the full Subhalo position, velocity,
      mass and group-number datasets are read on every call.
    - fofhalo_id: FOF halo ID matched against Subhalo/SubhaloGrNr.

    Returns:
    - (relative_Pos, relative_Vel): member positions / velocities minus those
      of the member with the largest SubhaloMass.
    - (np.nan, np.nan) when no subhalo belongs to fofhalo_id.
    """
    with h5py.File(cat_name, 'r') as f:
        SubhaloVel = f['Subhalo/SubhaloVel'][:]
        SubhaloGrNr = f['Subhalo/SubhaloGrNr'][:]
        SubhaloPos = f['Subhalo/SubhaloPos'][:]
        SubhaloMass = f['Subhalo/SubhaloMass'][:]

    members = SubhaloGrNr == fofhalo_id
    if np.count_nonzero(members) == 0:
        return np.nan, np.nan

    member_pos = SubhaloPos[members]
    member_vel = SubhaloVel[members]
    central = np.argmax(SubhaloMass[members])

    return member_pos - member_pos[central], member_vel - member_vel[central]


def plot_phase_space_projections(relative_Pos, relative_Vel, ax=None, if_plot=True):
    """
    Scatter-plot LOS velocity against elongation-axis position for the xy, yz and xz planes.

    Parameters:
    - relative_Pos: (N, 3) relative positions.
    - relative_Vel: (N, 3) relative velocities.
    - ax: optional collection (array or list) of at least 3 matplotlib axes;
      a new 1x3 figure is created when omitted.
    - if_plot: when False nothing is computed or drawn.

    Returns:
    - None
    """
    if not if_plot:
        # Previously this path crashed on `ax.flatten()` whenever ax was None.
        return None

    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(18, 10))  # one panel per projection plane

    # Accept ndarray grids as well as plain lists of axes
    panels = np.asarray(ax, dtype=object).ravel()

    # (plane, label of the velocity component normal to that plane)
    layout = (('xy', 'Vz'), ('yz', 'Vx'), ('xz', 'Vy'))
    for panel, (plane, v_label) in zip(panels, layout):
        r_enlong, v_los = get_enlongation_projection(relative_Pos, relative_Vel, plane=plane)
        panel.scatter(r_enlong, v_los, s=8, alpha=0.6)
        panel.set_title(f"Elongation {plane.upper()}")
        panel.set_xlabel(f"Elongated R ({plane})")
        panel.set_ylabel(v_label)

    plt.tight_layout()
    return None

In [None]:
def _fit_gmm_with_retry(X, n_components, max_iter, random_state):
    """Fit a full-covariance GMM; refit once with more iterations and restarts if it did not converge."""
    gmm = GaussianMixture(n_components=n_components, covariance_type='full',
                          max_iter=max_iter, n_init=5, random_state=random_state).fit(X)

    if not gmm.converged_:
        print(f"GMM{n_components} did not converge, retrying with n_init=10...")
        # max_iter must be an integer: 1.5 * max_iter is a float and is rejected
        # by scikit-learn's parameter validation.
        gmm = GaussianMixture(n_components=n_components, covariance_type='full',
                              max_iter=int(1.5 * max_iter), n_init=10,
                              random_state=random_state).fit(X)
    return gmm


def extract_gmm_features(R, V, random_state=42):
    """
    Fit 1- and 2-component GMMs to (R, V) phase space and extract component features.

    Parameters:
    -----------
    R, V : array-like
        Projected position and LOS velocity (same length).
    random_state : int
        Seed forwarded to every GaussianMixture fit.

    Returns:
    --------
    features : dict
        'bic_1', 'bic_2'        : BIC of the 1- and 2-component fits on the
                                  z-scored data.
        'mean_r_k', 'std_r_k',
        'mean_v_k', 'std_v_k',
        'nk' (k = 0, 1)         : statistics of the un-normalized R / V values
                                  assigned to component k of the 2-component fit.
        'std_r', 'std_v'        : standard deviation of all R / V values.
        Statistics of a component that received no points are NaN.
    """
    R = np.asarray(R)
    V = np.asarray(V)

    X = np.vstack([R, V]).T
    X = (X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8)  # z-score normalization

    # Small samples get a smaller iteration budget
    max_iterations = 150 if len(R) < 150 else 300

    gmm1 = _fit_gmm_with_retry(X, 1, max_iterations, random_state)
    gmm2 = _fit_gmm_with_retry(X, 2, max_iterations, random_state)

    bic1 = gmm1.bic(X)
    bic2 = gmm2.bic(X)

    # NOTE(review): component labels 0/1 come straight from the fit and are not
    # ordered (e.g. by size) — confirm downstream code does not rely on an order.
    labels = gmm2.predict(X)
    group0 = (labels == 0)
    group1 = (labels == 1)

    R0, V0 = R[group0], V[group0]
    R1, V1 = R[group1], V[group1]

    return {
        'bic_1': bic1, 'bic_2': bic2,
        'mean_r_0': np.mean(R0), 'mean_r_1': np.mean(R1),
        'std_r': np.std(R),
        'mean_v_0': np.mean(V0), 'mean_v_1': np.mean(V1),
        'std_r_0': np.std(R0),   'std_r_1': np.std(R1),
        'std_v_0': np.std(V0),   'std_v_1': np.std(V1),
        'std_v': np.std(V),
        'n0': len(R0), 'n1': len(R1),
    }


In [None]:
def calculate_all_features_for_snapshot(cat_name, fofhalo_id, random_state = 42):
    """
    Compute the GMM and elongation features of one FOF halo in the xy, yz and xz planes.

    Parameters:
    - cat_name: path to the HDF5 catalog of the snapshot.
    - fofhalo_id: FOF halo ID in that catalog.
    - random_state: seed forwarded to the GMM fits.

    Returns:
    - dict keyed by plane name ('xy', 'yz', 'xz'); each value is the dict from
      extract_gmm_features extended with 'elongation_ratio'.
      An empty dict is returned when the halo has no subhalos.
    """
    relative_Pos, relative_Vel = get_relative_pos_vel(cat_name, fofhalo_id)

    # get_relative_pos_vel signals "no subhalos" with a NaN float instead of an array
    halo_is_empty = isinstance(relative_Pos, float) and np.isnan(relative_Pos)
    if halo_is_empty:
        return {}

    def features_in(plane):
        # Position along the in-plane elongation axis and the velocity normal to the plane
        R_proj, V_los = get_enlongation_projection(relative_Pos, relative_Vel, plane=plane)
        flat_ratio = compute_flattening_ratio(relative_Pos, plane=plane)
        gmm_feats = extract_gmm_features(R_proj, V_los, random_state=random_state)
        return {**gmm_feats, 'elongation_ratio': flat_ratio}

    return {plane: features_in(plane) for plane in ('xy', 'yz', 'xz')}


In [None]:
def get_snaptime(snap):
    """
    Return the cosmic age at snapshot `snap` as a plain number.

    The redshift is looked up in the module-level `snaps` list and converted
    with the module-level `cosmo` cosmology; `.value` strips the astropy unit
    (presumably Gyr — confirm against the astropy version in use).
    """
    return cosmo.age(snaps[snap]['redshift']).value

In [None]:
def get_merger_score_all(
    target_halo: int,
    snapshot: int,
    HaloID: np.ndarray,
    Snap_coll: np.ndarray,
    N_window: int = 100,
    tau: float = 5.0,
) -> float:
    """
    Sum exponentially time-weighted contributions of a halo's mergers around a snapshot.

    Every entry of Snap_coll whose HaloID equals target_halo and whose value
    lies in [snapshot - N_window, snapshot + N_window] contributes
    exp(-|t(snapshot) - t(merger)| / tau), with t taken from get_snaptime.

    Parameters:
    - target_halo: halo ID matched against HaloID.
    - snapshot: reference snapshot number.
    - HaloID: per-merger halo IDs, aligned with Snap_coll.
    - Snap_coll: per-merger snapshot numbers.
    - N_window: half-width of the accepted snapshot window.
    - tau: e-folding time of the weight, in the units returned by get_snaptime.

    Returns:
    - The summed score (0.0 when no merger qualifies).
    """
    merger_snaps = Snap_coll[HaloID == target_halo]
    total = 0.0
    t_reference = get_snaptime(snapshot)

    # Symmetric window: mergers before and after the reference snapshot count
    accepted = range(snapshot - N_window, snapshot + N_window + 1)

    # NOTE(review): a negative value in Snap_coll that falls inside the window
    # would index `snaps` from the end in get_snaptime — confirm the catalog
    # holds only valid snapshot numbers.
    for merger_snap in merger_snaps:
        if merger_snap not in accepted:
            continue
        elapsed = np.abs(t_reference - get_snaptime(merger_snap))
        total += np.exp(-elapsed / tau)

    return total

In [None]:
def get_merger_score_pre(
    target_halo: int,
    snapshot: int,
    HaloID: np.ndarray,
    Snap_coll: np.ndarray,
    N_window: int = 100,
    tau: float = 5.0,
) -> float:
    """
    Sum exponentially time-weighted contributions of a halo's mergers up to a snapshot.

    Every entry of Snap_coll whose HaloID equals target_halo and whose value
    lies in [snapshot - N_window, snapshot] contributes
    exp(-|t(snapshot) - t(merger)| / tau), with t taken from get_snaptime.
    Mergers after the reference snapshot are ignored.

    Parameters:
    - target_halo: halo ID matched against HaloID.
    - snapshot: reference snapshot number.
    - HaloID: per-merger halo IDs, aligned with Snap_coll.
    - Snap_coll: per-merger snapshot numbers.
    - N_window: number of earlier snapshots accepted.
    - tau: e-folding time of the weight, in the units returned by get_snaptime.

    Returns:
    - The summed score (0.0 when no merger qualifies).
    """
    merger_snaps = Snap_coll[HaloID == target_halo]

    total = 0.0
    t_reference = get_snaptime(snapshot)

    # One-sided window: only mergers at or before the reference snapshot count
    accepted = range(snapshot - N_window, snapshot + 1)

    # NOTE(review): a negative value in Snap_coll that falls inside the window
    # would index `snaps` from the end in get_snaptime — confirm the catalog
    # holds only valid snapshot numbers.
    for merger_snap in merger_snaps:
        if merger_snap not in accepted:
            continue
        elapsed = np.abs(t_reference - get_snaptime(merger_snap))
        total += np.exp(-elapsed / tau)

    return total

In [None]:
def get_mass_ratio(cat_name, halo_id_snap):
    """
    Return the mass ratio of the second most massive to the most massive subhalo of a FOF halo.

    Parameters:
    - cat_name: path to the HDF5 catalog of the snapshot.
    - halo_id_snap: FOF halo ID matched against Subhalo/SubhaloGrNr.

    Returns:
    - second-largest SubhaloMass divided by the largest one;
    - 0 when the halo has exactly one subhalo;
    - np.nan (after printing a warning) when it has none.
    """
    with h5py.File(cat_name, 'r') as f:
        SubhaloGrNr_snap = f['Subhalo/SubhaloGrNr'][:]
        SubhaloMass_snap = f['Subhalo/SubhaloMass'][:]

    member_masses = SubhaloMass_snap[SubhaloGrNr_snap==halo_id_snap]
    n_members = len(member_masses)

    if n_members == 0:
        print(f'check halos {halo_id_snap} at {cat_name}!')
        return np.nan
    if n_members == 1:
        return 0

    ascending = np.sort(member_masses)
    return ascending[-2] / ascending[-1]

In [None]:
def get_fof_id(cat_name, sub_id):
    """
    Look up the FOF halo that a subhalo belongs to.

    Parameters:
    - cat_name: path to the HDF5 catalog of the snapshot.
    - sub_id: subhalo ID matched against Subhalo/Subhalo_IDs.

    Returns:
    - SubhaloGrNr of the first matching subhalo, as a Python int.

    Raises:
    - IndexError: if no subhalo in the catalog has this ID.
    """
    with h5py.File(cat_name, 'r') as catalog:
        parent_groups = catalog['Subhalo/SubhaloGrNr'][:]
        subhalo_ids = catalog['Subhalo/Subhalo_IDs'][:]

    matches = parent_groups[subhalo_ids == sub_id]
    return int(matches[0])

In [None]:
def get_feats_labels_dict_tv(fof_halo_id_at99, Center_SubhaloIDs, Target_Halo_IDs, HaloID, Snap_coll, feats_labels_dict, taus, random_state =42,
                             data_root='/users_path/merger_trace/data/tng_cluster', snap_range=range(72, 100)):
    """
    Trace one halo back along its main progenitor branch and store features and merger labels per snapshot.

    Parameters:
    - fof_halo_id_at99: FOF halo ID of the cluster in the final catalog.
    - Center_SubhaloIDs: central subhalo ID per entry of Target_Halo_IDs.
    - Target_Halo_IDs: FOF halo IDs aligned with Center_SubhaloIDs.
    - HaloID: per-merger halo IDs (same ID convention as fof_halo_id_at99).
    - Snap_coll: per-merger snapshot numbers, aligned with HaloID.
    - feats_labels_dict: dict that is updated in place and also returned.
    - taus: iterable of e-folding times; one pair of label scores is stored per value.
    - random_state: seed forwarded to the GMM fits.
    - data_root: directory holding 'tng_cluster_mpbs/' and 'tng_cluster_targetcat/'.
    - snap_range: snapshot numbers to process.

    Returns:
    - feats_labels_dict, where feats_labels_dict[fof_halo_id_at99][snap] holds
      'features', 'redshift', 'mass_ratio' and the keys
      'label_score_all_tau<τ>' / 'label_score_pre_tau<τ>' (τ with one decimal).
      Snapshots without a progenitor keep an empty dict.
    """
    sub1st_at99 = int(Center_SubhaloIDs[np.asarray(Target_Halo_IDs) == fof_halo_id_at99][0])
    feats_labels_dict[fof_halo_id_at99] = {}

    # Main progenitor branch of the central subhalo
    with h5py.File(f'{data_root}/tng_cluster_mpbs/sublink_mpb_{sub1st_at99}.hdf5', 'r') as mpb_f:
        SubfindID = mpb_f['SubfindID'][:]
        SnapNum = mpb_f['SnapNum'][:]

    for snap in snap_range:
        feats_labels_dict[fof_halo_id_at99][snap] = {}

        # Sorted halo catalog of this snapshot; zero-padded to three digits so
        # snapshots below 10 resolve as well
        cat_name_temp = f'{data_root}/tng_cluster_targetcat/targethalo_cat_{snap:03d}/TargetHalo_MergerCat_{snap:03d}.hdf5'

        progenitor_id_atmergersnap = SubfindID[SnapNum==snap]
        current_redshift = snaps[snap]['redshift']

        if len(progenitor_id_atmergersnap) == 0:
            continue  # branch does not reach this snapshot; keep the empty placeholder

        # Pass a scalar ID: handing over the whole array only worked by
        # broadcasting when it had exactly one element.
        # (The parent FOF halo could alternatively be fetched from the web API;
        # the local catalog is used instead.)
        progenitor_id = int(progenitor_id_atmergersnap[0])
        fofid_atmergercat = get_fof_id(cat_name_temp, progenitor_id)

        mass_ratio = get_mass_ratio(cat_name_temp, fofid_atmergercat)

        features = calculate_all_features_for_snapshot(cat_name_temp, fofid_atmergercat, random_state=random_state)

        label_scores_all = np.zeros(len(taus))
        label_scores_pre = np.zeros(len(taus))
        for k, tau_temp in enumerate(taus):
            label_scores_all[k] = get_merger_score_all(fof_halo_id_at99, snap, HaloID, Snap_coll, N_window=100, tau=tau_temp)
            label_scores_pre[k] = get_merger_score_pre(fof_halo_id_at99, snap, HaloID, Snap_coll, N_window=100, tau=tau_temp)

        # One flat entry per snapshot; every tau gets its own pair of keys
        feats_labels_dict[fof_halo_id_at99][snap] = {
            'features': features,
            'redshift': current_redshift,
            'mass_ratio': mass_ratio,
            **{f'label_score_all_tau{tau_temp:.1f}': score_all for tau_temp, score_all in zip(taus, label_scores_all)},
            **{f'label_score_pre_tau{tau_temp:.1f}': score_pre for tau_temp, score_pre in zip(taus, label_scores_pre)}
        }

        print(f'calculate features and labels for {fof_halo_id_at99} at {snap} where it was {fofid_atmergercat}')
    return feats_labels_dict

In [None]:
# Load the TNG-Cluster catalog, the final-snapshot target-halo catalog and the
# merger-event file, then group the merger snapshots by TNG-Cluster halo ID.
# All names defined here are notebook globals used by later cells.
with h5py.File('/users_path/merger_trace/data/tng_cluster/tng_cluster_catalog/TNG-Cluster_Catalog.hdf5', 'r') as f:
    origID_parentsim = f['origID'][:]  # not referenced again in the cells shown here
    haloID__TNGCluster = f['haloID'][:]

cat_name ='/users_path/merger_trace/data/tng_cluster/tng_cluster_targetcat/targethalo_cat_099/TargetHalo_MergerCat_099.hdf5'

results = Get_HaloIDs(cat_name, SubhaloMassDef='SubhaloMass')

# Unpack in the exact order Get_HaloIDs returns its tuple
(Target_Halo_IDs, Galaxy_nums, Target_Halo_Rs_Crit200, Target_GroupPoses, AvgSFR,
Center_SubhaloIDs, Center_SubhaloMasses, Center_SubhaloVelDisp, Center_SubhaloSFR, Center_SubhaloSpin, Center_SubhaloMassRad
) = results


with h5py.File('/users_path/merger_trace/data/tng_cluster/tng_cluster_cluster_mergers/cluster_mergers.hdf5', 'r') as f:
    HaloID = f['HaloID'][:]
    Snap_coll = f['Snap_coll'][:]
    print(f.keys())  # list the top-level entries of the merger file


# Unique merger-catalog halo IDs (not referenced again in the cells shown here)
HaloID_Mergercat = np.array(list(set(HaloID))).astype(int)


# Row indices 0..N-1 of the TNG-Cluster catalog; the merger file's HaloID is
# compared against these indices below
Merger_Target_Halo_ID = range(len(haloID__TNGCluster))
Merger_Target_Halo_ID = np.array(Merger_Target_Halo_ID)

# TNG-Cluster halo ID -> array of Snap_coll values of that halo's mergers
Snap_coll_eachhalo = dict()

for i in range(len(Merger_Target_Halo_ID)):
    current_halo_id = Merger_Target_Halo_ID[i]
    rel_merger_snap = Snap_coll[HaloID==current_halo_id]
    # Translate the row index into the halo ID stored in the TNG-Cluster catalog
    halo_id_frommergercat2TNGCluster = haloID__TNGCluster[current_halo_id] 
    Snap_coll_eachhalo[halo_id_frommergercat2TNGCluster] = rel_merger_snap

In [None]:
# For every merger event, replace the merger-catalog HaloID (used as a row
# index into the TNG-Cluster catalog) by the halo ID stored in that row.
MergerHaloID_TNGCluster = np.array(
    [haloID__TNGCluster[catalog_row] for catalog_row in HaloID]
).astype(int)

In [None]:
# Build the nested feature/label dictionary, one halo at a time.
feats_labels_dict = {}
FOFHaloID = np.array(list(Snap_coll_eachhalo)).astype(int)

# NOTE(review): iteration starts at the second halo, so FOFHaloID[0] is never
# processed — confirm that skipping it is intentional.
for fofhalo_id in FOFHaloID[1:]:
    # caution!! HaloID is different for TNG-Cluster between merger catalog and fof halo id,
    # hence the translated MergerHaloID_TNGCluster is passed instead of HaloID
    feats_labels_dict = get_feats_labels_dict_tv(
        fofhalo_id,
        Center_SubhaloIDs,
        Target_Halo_IDs,
        MergerHaloID_TNGCluster,
        Snap_coll,
        feats_labels_dict,
        taus=np.linspace(0.1, 4.0, 40),
        random_state=2025,
    )