In [None]:
# Basic Packages
import numpy as np
import h5py
import logging
import os
import shutil
import gc
import matplotlib.pyplot as plt
import pickle
# Physics-related Packages
from astropy.cosmology import Planck15

In [None]:
# Set up HTTP access to the TNG web API, which is used below to read the simulation box info
import requests
import time

# Root endpoint of the API; the simulation and snapshot URLs used below are taken from its JSON responses
baseUrl = 'http://www.tng-project.org/api/'

def get(path, params=None, max_retries=5, backoff_factor=2):
    """Issue an HTTP GET request with retries and exponential back-off.

    Parameters
    ----------
    path : str
        URL to request.
    params : dict, optional
        Query parameters forwarded to ``requests.get``.
    max_retries : int
        Maximum number of attempts before giving up.
    backoff_factor : int or float
        Base of the back-off; the wait after failed attempt ``k`` is
        ``backoff_factor ** k`` seconds.

    Returns
    -------
    dict or list
        Parsed JSON when the response is ``application/json``.
    str
        Name of the file written to the current working directory when the
        response carries a ``content-disposition`` header.
    requests.Response
        The raw response in every other case.

    Raises
    ------
    RuntimeError
        When every attempt failed; the last underlying error is chained as
        ``__cause__``.
    """
    # The key is read from the TNG_API_KEY environment variable so that it does
    # not have to be committed with the notebook; the placeholder literal is
    # kept as the fallback so existing setups behave as before.
    headers = {"api-key": os.environ.get("TNG_API_KEY", "API KEY")}

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            r = requests.get(path, params=params, headers=headers)
            r.raise_for_status()

            # Compare only the media type: servers may append parameters such as
            # "; charset=utf-8", and the header may be absent altogether.
            content_type = r.headers.get('content-type', '')
            if content_type.split(';')[0].strip().lower() == 'application/json':
                return r.json()

            if 'content-disposition' in r.headers:
                filename = r.headers['content-disposition'].split("filename=")[1]
                with open(filename, 'wb') as f:
                    f.write(r.content)
                return filename
            return r  # fallback

        except Exception as e:
            # Deliberately broad: any failure (network, HTTP status, parsing) is retried.
            last_error = e
            if attempt == max_retries:
                # No retry follows, so do not announce one or sleep before raising.
                break
            wait = backoff_factor ** attempt
            print(f"[Retry {attempt}/{max_retries}] Request failed: {e}. Retrying in {wait}s...")
            time.sleep(wait)

    raise RuntimeError(f"Failed to GET {path} after {max_retries} retries.") from last_error

# Issue a request to the API root
r = get(baseUrl)

# Collect the names of all simulations listed by the API
names = [sim['name'] for sim in r['simulations']]
# Get the index of TNG-Cluster in that list
i = names.index('TNG-Cluster')
# Get the metadata of the TNG-Cluster simulation
sim = get( r['simulations'][i]['url'] )

# Get the snapshot list of this simulation
snaps = get(sim['snapshots'])

# Sim Box parameters
# NOTE(review): snaps is indexed by position below; this assumes entry k of the
# list is snapshot number k -- confirm against the API response.
Snap_Index = 99 # the snapshots index in the total 100 snapshots taking at different z
BoxSize = sim['boxsize'] # unit: ckpc/h
Redshift = snaps[Snap_Index]['redshift'] # current redshift of our current snap


In [None]:
def get_snaptime(snap):
    """Return the Planck15 age of the Universe at the redshift of snapshot `snap`.

    The redshift is read from the module-level `snaps` list; the result is the
    bare number from `Planck15.age(...).value` (no unit attached).
    """
    redshift_at_snap = snaps[snap]['redshift']
    return Planck15.age(redshift_at_snap).value

# Cosmic time of snapshots 72..99; element k corresponds to snapshot 72 + k,
# which is why the helpers below add or subtract 72 when converting indices.
snaptimes = np.array([get_snaptime(snap) for snap in range(72,100)])

In [None]:
def from_time2snap(t_cosmic):
    """Return the snapshot number (72..99) whose tabulated time is closest to `t_cosmic`.

    `snaptimes` starts at snapshot 72, so the position of the nearest entry is
    shifted by 72 to turn it back into a snapshot number.
    """
    time_offsets = np.abs(snaptimes - t_cosmic)
    return np.argmin(time_offsets) + 72

In [None]:
def get_sfr_range(halo_id_99, center_snap,sfr_dict, tau=0.5):
    """Extract one halo's SFR history in a time window around `center_snap`.

    The window spans ``[t_center - tau, t_center + tau]`` in the units of the
    module-level `snaptimes` table, and is converted to the nearest snapshot
    numbers, clamped to 72..99.

    Parameters
    ----------
    halo_id_99 : hashable
        Key of the halo in `sfr_dict`.
    center_snap : int
        Snapshot number at the centre of the window (72..99).
    sfr_dict : dict
        ``sfr_dict[halo_id_99]`` must provide the aligned sequences
        ``'snaps'``, ``'avgsfr'`` and ``'galnum'``.
    tau : float
        Half-width of the time window.

    Returns
    -------
    tuple
        ``(avg_sfr_vec, galnum_vec, selected_snaps)`` as arrays covering the
        window, or ``([], [], [])`` when either window edge does not occur
        exactly once in the halo's snapshot list.
    """
    # calculate snap range based on tau
    t_center = snaptimes[center_snap-72]
    snap_begin = max(from_time2snap(t_center - tau), 72)
    snap_end = min(from_time2snap(t_center + tau), 99)

    # Coerce to arrays: comparing a plain list with `==` yields a single False
    # instead of an element-wise mask, which would silently reject every halo.
    halo_track = sfr_dict[halo_id_99]
    snap_range = np.asarray(halo_track['snaps'])
    sfrs = np.asarray(halo_track['avgsfr'])
    galnums = np.asarray(halo_track['galnum'])

    begin_hits = np.where(snap_range==snap_begin)[0]
    end_hits = np.where(snap_range==snap_end)[0]
    if len(begin_hits) != 1 or len(end_hits) != 1:
        return [], [], []

    # extract sfr data within range (both edges inclusive)
    window = slice(begin_hits[0], end_hits[0]+1)
    return sfrs[window], galnums[window], np.array(snap_range[window])

In [None]:
def get_scores_fromdict(halo_id_99, center_snap, features_dict):
    """Look up the pre- and post-merger scores of one halo at one snapshot.

    Parameters
    ----------
    halo_id_99 : hashable
        Key of the halo in `features_dict`.
    center_snap : hashable
        Key of the snapshot inside ``features_dict[halo_id_99]``.
    features_dict : dict
        Nested mapping ``halo -> snapshot -> feature/label mapping``.

    Returns
    -------
    tuple
        ``(pre_merger_score, post_merger_score)``, where the post-merger score
        is ``label_score_all_tau2.0 - label_score_pre_tau2.0``. Returns
        ``(None, None)`` when the halo, the snapshot, or either label is missing.
    """
    # A halo or snapshot without an entry is treated like an entry without
    # labels, so callers can skip it instead of crashing on KeyError.
    try:
        halo_data = features_dict[halo_id_99][center_snap]
    except KeyError:
        return None, None

    if 'label_score_all_tau2.0' not in halo_data or 'label_score_pre_tau2.0' not in halo_data:
        return None, None

    all_merger_score = halo_data['label_score_all_tau2.0']
    pre_merger_score = halo_data['label_score_pre_tau2.0']

    # The "all" score is split into its pre-merger part and the remainder.
    post_merger_score = all_merger_score - pre_merger_score

    return pre_merger_score, post_merger_score

In [None]:
# Load the per-halo SFR tracking table. The code below reads
# sfr_dict[halo_id]['snaps'], ['avgsfr'] and ['galnum'] from it.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you
# produced yourself. The absolute path is machine-specific.
with open('/users_path/merger_trace/data/tng_cluster/tng_cluster_products/sfr_tracking.pkl', 'rb') as f:
    sfr_dict = pickle.load(f)


In [None]:
# Load the feature/label table. The code below reads
# features_dict[halo_id][snap]['label_score_all_tau2.0'] and ['label_score_pre_tau2.0'].
# NOTE(review): pickle.load can execute arbitrary code -- only load files you
# produced yourself. The absolute path is machine-specific.
with open('/users_path/merger_trace/data/tng_cluster/tng_cluster_products/feats_labels_dict_tngcluster.pkl', 'rb') as f:
    features_dict = pickle.load(f)

In [None]:
# Snapshot closest in time to 2 time units before the last tabulated snapshot
# (snaptimes[-1] is snapshot 99). Presumably the unit is Gyr -- TODO confirm.
# Only referenced by the commented-out loop bound in the sampling cell below.
snap_cut = from_time2snap(snaptimes[-1]-2)

In [None]:
# Display the cut-off snapshot number computed in the previous cell
snap_cut

In [None]:
# Accumulators: one entry per accepted (halo, snapshot) sample; the five lists stay aligned.
norm_curvature_vec = []      # curvature of the mean-normalised SFR window
curvature_vec = []           # curvature of the raw SFR window
meta_info_vec = []           # (halo_id_99, center_snap) of each sample
pre_merger_score_vec = []
post_merger_score_vec = []


for center_snap in range(72, 100):
#for center_snap in range(72, snap_cut+1):
    for halo_id_99 in sfr_dict.keys():

        snap_range = np.array(sfr_dict[halo_id_99]['snaps'])
        center_matches = np.where(snap_range==center_snap)[0]

        # Run the checks from cheapest to most specific, and only look up the
        # SFR window and the scores for halos that actually have this snapshot.
        usable = len(center_matches)==1
        if usable:
            avg_sfr_vec, galnum_vec, selected_snaps = get_sfr_range(halo_id_99, center_snap, sfr_dict, tau=0.5)
            # At least begin, centre and end are needed for a curvature estimate.
            usable = len(selected_snaps)>=3
        if usable:
            pre_merger_score, post_merger_score = get_scores_fromdict(halo_id_99, center_snap, features_dict)
            usable = pre_merger_score is not None and post_merger_score is not None

        if not usable:
            print(f'not enough sfr values or do not find progenitor at {center_snap} for {halo_id_99}')
            continue

        # SFR at the central snapshot, and the window mean used for normalisation.
        center_sfr = sfr_dict[halo_id_99]['avgsfr'][center_matches[0]]
        mean_sfr = np.mean(avg_sfr_vec)

        # A window with zero mean SFR yields NaN here; such samples must be
        # filtered out before statistics are computed on the result.
        norm_sfr_vec = avg_sfr_vec/mean_sfr
        center_sfr_norm = center_sfr/mean_sfr

        # Three-point curvature proxy: 2*centre - begin - end
        # (positive when the SFR peaks at the centre of the window).
        curvature_value_norm = 2*center_sfr_norm - norm_sfr_vec[0] - norm_sfr_vec[-1]
        curvature_value = 2*center_sfr - avg_sfr_vec[0] - avg_sfr_vec[-1]

        norm_curvature_vec.append(curvature_value_norm)
        curvature_vec.append(curvature_value)
        meta_info_vec.append((halo_id_99, center_snap))

        pre_merger_score_vec.append(pre_merger_score)
        post_merger_score_vec.append(post_merger_score)


In [None]:
pre_merger_score_vec = np.array(pre_merger_score_vec)
norm_curvature_vec = np.array(norm_curvature_vec)
print(len(norm_curvature_vec))

# Bin on filtered copies only: the full-length arrays must stay aligned with the
# other per-sample vectors. Non-finite samples (e.g. NaN curvature from a
# zero-SFR window) would otherwise turn the percentiles and bin means into NaN.
pre_valid_mask = np.isfinite(pre_merger_score_vec) & np.isfinite(norm_curvature_vec)
pre_scores_valid = pre_merger_score_vec[pre_valid_mask]
pre_curvs_valid = norm_curvature_vec[pre_valid_mask]

# Quintile edges of the pre-merger score
percentiles = np.percentile(pre_scores_valid, [0, 20, 40, 60, 80, 100])
print(percentiles)
bin_centers = []
bin_means = []
bin_errors = []

n_bins = len(percentiles) - 1
for bin_idx in range(n_bins):
    lower, upper = percentiles[bin_idx], percentiles[bin_idx+1]
    if bin_idx == n_bins - 1:
        # Close the last bin on the right, otherwise the samples with the
        # maximum score belong to no bin at all.
        mask = (pre_scores_valid >= lower) & (pre_scores_valid <= upper)
    else:
        mask = (pre_scores_valid >= lower) & (pre_scores_valid < upper)
    values = pre_curvs_valid[mask]

    if len(values) > 0:
        bin_center = np.mean(pre_scores_valid[mask])
        bin_centers.append(bin_center)
        bin_means.append(np.mean(values))
        bin_errors.append(np.std(values) / np.sqrt(len(values)))  # error bar: standard error of the mean
    else:
        # Empty bin (tied percentile edges): place it at the edge midpoint
        # instead of taking the mean of an empty slice.
        bin_centers.append((lower + upper) / 2)
        bin_means.append(np.nan)
        bin_errors.append(0)

# plot
plt.errorbar(bin_centers, bin_means, yerr=bin_errors, fmt='o', capsize=4)
plt.xlabel('Pre-merger score (binned by quantiles)')
plt.ylabel('Normalized curvature')
plt.title('Normalized SFR curvature vs. Pre-merger score')
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Disabled export of the pre-merger binning results: the code is wrapped in a
# string literal, so nothing is written to disk.
# NOTE(review): before re-enabling, note that `pd` is only imported in a later
# cell, and 'sample_count' would hold the total sample size for every row
# rather than a per-bin count.
'''
df = pd.DataFrame({
    'bin_center': bin_centers,
    'mean_curvature': bin_means,
    'error_bar': bin_errors,
    'sample_count': len(norm_curvature_vec)
})

df.to_csv('sfr_curvature_vs_pre_merger_score.csv', index=False)
'''


In [None]:
import pandas as pd
post_merger_score_vec = np.array(post_merger_score_vec)
norm_curvature_vec = np.array(norm_curvature_vec)
# print(len(norm_curvature_vec))

# Bin on NaN-free copies; the full-length arrays are left untouched so they
# stay aligned with the other per-sample vectors.
valid_mask = ~np.isnan(post_merger_score_vec) & ~np.isnan(norm_curvature_vec)
score_vec = post_merger_score_vec[valid_mask]
curv_vec = norm_curvature_vec[valid_mask]


# Quintile edges of the post-merger score
percentiles = np.percentile(score_vec, [0, 20, 40, 60, 80, 100])
print(percentiles)
# initialize
bin_centers, bin_means, bin_errors, bin_counts = [], [], [], []

# loop over all bins
n_bins = len(percentiles) - 1
for bin_idx in range(n_bins):
    lower, upper = percentiles[bin_idx], percentiles[bin_idx+1]
    if bin_idx == n_bins - 1:
        # Close the last bin on the right, otherwise the samples with the
        # maximum score belong to no bin at all.
        mask = (score_vec >= lower) & (score_vec <= upper)
    else:
        mask = (score_vec >= lower) & (score_vec < upper)
    scores_in_bin = score_vec[mask]
    curvs_in_bin = curv_vec[mask]

    if len(curvs_in_bin) > 0:
        bin_center = np.mean(scores_in_bin)
        mean_val = np.mean(curvs_in_bin)
        # error bar: standard error of the mean
        error_val = np.std(curvs_in_bin) / np.sqrt(len(curvs_in_bin))
        count_val = len(curvs_in_bin)
    else:
        # Empty bin (tied percentile edges): keep a placeholder at the edge midpoint.
        bin_center = (lower + upper) / 2
        mean_val = np.nan
        error_val = 0
        count_val = 0

    bin_centers.append(bin_center)
    bin_means.append(mean_val)
    bin_errors.append(error_val)
    bin_counts.append(count_val)

# save to .csv
df = pd.DataFrame({
    'bin_center': bin_centers,
    'mean_curvature': bin_means,
    'error_bar': bin_errors,
    'sample_count': bin_counts
})
df.to_csv('post_merger_sfr_curvature_bins.csv', index=False)

# plot
plt.errorbar(bin_centers, bin_means, yerr=bin_errors, fmt='o', capsize=4)
plt.xlabel('Post-merger score (binned by quantiles)')
plt.ylabel('Normalized curvature')
plt.title('Normalized SFR curvature vs. post-merger score')
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Disabled export of the post-merger binning results: the code is wrapped in a
# string literal, so nothing is written to disk.
# NOTE(review): before re-enabling, note that 'sample_count' would hold the
# total sample size for every row rather than a per-bin count, and the target
# file name is the same one the post-merger binning cell already writes.
'''
df = pd.DataFrame({
    'bin_center': bin_centers,
    'mean_curvature': bin_means,
    'error_bar': bin_errors,
    'sample_count': len(norm_curvature_vec)
})

df.to_csv('post_merger_sfr_curvature_bins.csv', index=False)
'''