In [None]:
import illustris_python as il
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# --- Configuration ---------------------------------------------------------
# Root directory of the TNG50-1 output (read by every illustris_python loader below).
basePath = './sims.TNG/TNG50-1/output'

# Target galaxy. `main_subhalo_id` is the Subfind ID whose merger tree is traced;
# `main_halo_id` is presumably its host group ID (TODO confirm) and is not used below.
main_subhalo_id = 329508
main_halo_id = 23

# Snapshots flagged as having missing data; not used by any executed code below.
missing_data_snaps = [31, 32, 35]
# Snapshots analysed, in ascending order (the plot cells label them z = 2 ... 0).
full_snaps = [
    33, 40, 50, 59, 67,
    72, 78, 84, 91, 99,
]

In [None]:
# Main progenitor branch (MPB) of the target subhalo, starting at snapshot 99 and
# following the SubLink tree back in time. Only each progenitor's Subfind ID and
# snapshot number are needed to locate its particles and catalogue entries later.
tree_fields = ['SubfindID', 'SnapNum']
subfindid_snapnum_tree = il.sublink.loadTree(basePath, 99, main_subhalo_id,
                                             fields=tree_fields, onlyMPB=True)
if subfindid_snapnum_tree is None:
    # loadTree returns None for a subhalo that is not present in the tree; fail
    # loudly here instead of with an opaque TypeError on the indexing below.
    raise ValueError(f'subhalo {main_subhalo_id} at snapshot 99 has no SubLink tree')
subfind_id = subfindid_snapnum_tree['SubfindID']
snapshot_num = subfindid_snapnum_tree['SnapNum']

In [None]:
def calculate_R90(Rmin, Rmax, MassRatio, Masses, Coordinates, CenterOfMass,
                  fraction=0.9, precision=1, max_iter=1000):
    """Find, by bisection, the radius that encloses ``fraction`` of the total mass.

    Parameters
    ----------
    Rmin, Rmax : float
        Initial search bracket, in the length units of ``Coordinates``.
        ``MassRatio`` evaluated at ``Rmax`` should be >= ``fraction``. If it is not,
        a RuntimeWarning is issued, because the result is then clamped near ``Rmax``.
    MassRatio : callable
        ``MassRatio(radius, Masses, Coordinates, CenterOfMass)`` returning the enclosed
        mass fraction at ``radius`` (e.g. ``mass_ratio_within_radius``). It is assumed
        to be non-decreasing in ``radius``.
    Masses, Coordinates, CenterOfMass
        Passed through unchanged to ``MassRatio``.
    fraction : float, optional
        Target enclosed-mass fraction (default 0.9, i.e. R_90).
    precision : float, optional
        Stop once the bracket is narrower than this (default 1, same units as the radii).
    max_iter : int, optional
        Maximum number of bisection steps (default 1000).

    Returns
    -------
    float
        Midpoint of the final bracket.
    """
    import warnings

    if MassRatio(Rmax, Masses, Coordinates, CenterOfMass) < fraction:
        warnings.warn(
            f'mass fraction at Rmax={Rmax} is below {fraction}; '
            'the returned radius is only a lower bound', RuntimeWarning)

    left, right = Rmin, Rmax
    for _ in range(max_iter):
        mid = (left + right) / 2
        # Enclosed fraction still below target -> the radius lies further out.
        if MassRatio(mid, Masses, Coordinates, CenterOfMass) < fraction:
            left = mid
        else:
            right = mid
        if abs(right - left) < precision:
            break
    # Returned outside the loop so the result is defined even when max_iter <= 0.
    return (left + right) / 2

def mass_ratio_within_radius(Radius, Masses, Coordinates, CenterOfMass):
    """Fraction of the total mass lying within ``Radius`` of ``CenterOfMass``.

    Parameters
    ----------
    Radius : float
        Aperture radius, in the same length units as ``Coordinates``.
    Masses : array_like, shape (N,)
        Mass of each particle/cell.
    Coordinates : array_like, shape (N, 3)
        Position of each particle/cell.
    CenterOfMass : array_like, shape (3,)
        Centre about which distances are measured.

    Returns
    -------
    float
        ``M(dist <= Radius) / M_total``. NaN when the total mass is zero
        (e.g. no particles), since the fraction is undefined.

    NOTE(review): distances are plain Euclidean; no periodic-box wrapping is applied.
    """
    masses = np.asarray(Masses)
    offsets = np.asarray(Coordinates) - np.asarray(CenterOfMass)

    total_mass = masses.sum()
    if total_mass == 0:
        # Avoid a 0/0 RuntimeWarning; the fraction is undefined.
        return np.nan

    # Distance of every particle from the centre.
    dist_to_com = np.sqrt(np.sum(offsets**2, axis=1))
    return masses[dist_to_com <= Radius].sum() / total_mass

def Binary_DataIter(Data_Set1,Data_Set2):
    """Lazily yield ``(Data_Set1[k], Data_Set2[k])`` for k = 0 .. min(len) - 1.

    Both inputs must support ``len()`` and integer indexing; iteration stops at the
    end of the shorter one.
    """
    shared_length = min(len(Data_Set1), len(Data_Set2))
    position = 0
    while position < shared_length:
        yield Data_Set1[position], Data_Set2[position]
        position += 1

In [None]:
# Scratch cell: `R_90` and `R_Crit` are only created by the per-snapshot loop cells
# further down, so evaluating `R_90/R_Crit` here raised NameError under
# Restart & Run All. The per-snapshot ratio is printed inside those loops instead.

In [None]:
# Accumulators for the gas (HI) loop: `R_virial_ratio_list` receives R_90 / R_200,crit
# for each analysed snapshot and `snaps_list` the matching snapshot number.
# The stellar list is reset here too; the stellar loop cell also resets it.
R_virial_ratio_list = []
stars_R_virial_ratio_list = []
snaps_list = []

In [None]:
# HI-gas extent along the main progenitor branch (MPB). For each snapshot in
# `full_snaps`, R_90 is the radius about the subhalo centre of mass that encloses 90%
# of the HI mass; it is stored in units of the host halo's R_200,crit.
# The accumulators are (re)created here so re-running the cell does not append duplicates.
R_virial_ratio_list = []
snaps_list = []

gas_subset_fields = ['Coordinates', 'Masses', 'NeutralHydrogenAbundance']
subhalos_fields = ['SubhaloCM', 'SubhaloGrNr']
halos_fields = ['Group_R_Crit200', 'Group_R_Mean200']

for current_subfind_id, current_snapshot_num in Binary_DataIter(subfind_id, snapshot_num):
    # The MPB is ordered by descending SnapNum, so stop once below snapshot 30. This
    # must come before the `full_snaps` filter, which would otherwise skip past it.
    if current_snapshot_num < 30:
        break
    if current_snapshot_num not in full_snaps:
        continue

    gas_data = il.snapshot.loadSubhalo(basePath, current_snapshot_num, current_subfind_id,
                                       'gas', fields=gas_subset_fields)
    if 'Masses' not in gas_data:
        # Empty result (no gas cells): R_90 is undefined for this progenitor.
        print(f'snapshot {current_snapshot_num}: subhalo {current_subfind_id} has no gas, skipped')
        continue

    # A multi-field request returns a dict, so one catalogue read serves both fields.
    subhalos = il.groupcat.loadSubhalos(basePath, current_snapshot_num, fields=subhalos_fields)
    center_mass = subhalos['SubhaloCM'][current_subfind_id]
    halo_id = subhalos['SubhaloGrNr'][current_subfind_id]

    # NeutralHydrogenAbundance is the neutral fraction of each cell's hydrogen; the
    # hydrogen mass fraction is omitted. To the extent it is uniform, it is a constant
    # factor that cancels in the enclosed-mass ratio used for R_90.
    gas_mass = gas_data['Masses']
    gas_HI_abundance = gas_data['NeutralHydrogenAbundance']
    gas_HI_mass = gas_mass * gas_HI_abundance
    gas_coordinates = gas_data['Coordinates']
    # NOTE(review): coordinates are not wrapped for the periodic box; confirm the
    # subhalo does not straddle a box edge.

    radius = il.groupcat.loadHalos(basePath, current_snapshot_num, fields=halos_fields)
    R_Crit = radius['Group_R_Crit200'][halo_id]
    R_Mean = radius['Group_R_Mean200'][halo_id]

    # The bisection bracket must enclose all of the gas. Otherwise R_90 is silently
    # clamped to R_Mean whenever the HI extends beyond R_200,mean.
    max_gas_distance = np.sqrt(np.sum((gas_coordinates - center_mass)**2, axis=1)).max()
    R_upper = max(R_Mean, max_gas_distance)
    R_90 = calculate_R90(0, R_upper, mass_ratio_within_radius,
                         gas_HI_mass, gas_coordinates, center_mass)
    Virial_ratio = R_90 / R_Crit
    print(current_snapshot_num, Virial_ratio)

    R_virial_ratio_list.append(Virial_ratio)
    snaps_list.append(current_snapshot_num)

In [None]:
# Freeze the per-snapshot HI results (and their snapshot numbers) into numpy arrays.
R_virial_ratio_array, snaps_array = map(np.array, (R_virial_ratio_list, snaps_list))

In [None]:
# HI R_90 / R_200,crit against snapshot number, with ticks labelled by redshift.
# `new_x` lists the redshifts of `full_snaps` in the same (ascending-snapshot) order.
# Labels are looked up per snapshot because `snaps_array` follows the main progenitor
# branch (snapshot 99 first); pairing it positionally with `new_x` reverses the labels.
new_x = [2, 1.5, 1, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0]
snap_redshift = dict(zip(full_snaps, new_x))

fig, ax = plt.subplots()
ax.plot(snaps_array, R_virial_ratio_array, marker='o')
ax.set_xticks(snaps_array)
ax.set_xticklabels([snap_redshift[snap] for snap in snaps_array])
ax.set_xlabel('redshift $z$')
ax.set_ylabel(r'$R_{90,\mathrm{HI}} / R_{200,\mathrm{crit}}$')
plt.show()

In [None]:
# Stellar extent along the main progenitor branch (MPB). For each snapshot in
# `full_snaps`, R_90 of the stellar mass about the subhalo centre of mass, stored in
# units of the host halo's R_200,crit. Accumulators are reset so re-running is safe.
stars_R_virial_ratio_list = []
snaps_list = []

stars_subset_fields = ['Coordinates', 'Masses', 'GFM_StellarFormationTime']
subhalos_fields = ['SubhaloCM', 'SubhaloGrNr']
halos_fields = ['Group_R_Crit200', 'Group_R_Mean200']

for current_subfind_id, current_snapshot_num in Binary_DataIter(subfind_id, snapshot_num):
    # The MPB is ordered by descending SnapNum, so stop once below snapshot 30. This
    # must come before the `full_snaps` filter, which would otherwise skip past it.
    if current_snapshot_num < 30:
        break
    if current_snapshot_num not in full_snaps:
        continue

    stars_data = il.snapshot.loadSubhalo(basePath, current_snapshot_num, current_subfind_id,
                                         'stars', fields=stars_subset_fields)
    if 'Masses' not in stars_data:
        print(f'snapshot {current_snapshot_num}: subhalo {current_subfind_id} has no stars, skipped')
        continue

    # The 'stars' particle type also holds stellar-wind particles, which are flagged by
    # GFM_StellarFormationTime <= 0; keep only genuine stars.
    is_star = stars_data['GFM_StellarFormationTime'] > 0
    # Use the star data (not the gas loop's leftover `gas_data`).
    stars_mass = stars_data['Masses'][is_star]
    stars_coordinates = stars_data['Coordinates'][is_star]
    if stars_mass.size == 0:
        print(f'snapshot {current_snapshot_num}: subhalo {current_subfind_id} has only wind particles, skipped')
        continue

    # A multi-field request returns a dict, so one catalogue read serves both fields.
    subhalos = il.groupcat.loadSubhalos(basePath, current_snapshot_num, fields=subhalos_fields)
    center_mass = subhalos['SubhaloCM'][current_subfind_id]
    halo_id = subhalos['SubhaloGrNr'][current_subfind_id]
    # NOTE(review): coordinates are not wrapped for the periodic box; confirm the
    # subhalo does not straddle a box edge.

    radius = il.groupcat.loadHalos(basePath, current_snapshot_num, fields=halos_fields)
    R_Crit = radius['Group_R_Crit200'][halo_id]
    R_Mean = radius['Group_R_Mean200'][halo_id]

    # The bisection bracket must enclose every star so R_90 cannot be clamped to R_Mean.
    max_star_distance = np.sqrt(np.sum((stars_coordinates - center_mass)**2, axis=1)).max()
    R_upper = max(R_Mean, max_star_distance)
    R_90 = calculate_R90(0, R_upper, mass_ratio_within_radius,
                         stars_mass, stars_coordinates, center_mass)
    Virial_ratio = R_90 / R_Crit
    print(current_snapshot_num, Virial_ratio)

    stars_R_virial_ratio_list.append(Virial_ratio)
    snaps_list.append(current_snapshot_num)

In [None]:
# Stellar R_90 / R_200,crit against snapshot number, ticks labelled by redshift.
stars_R_virial_ratio_array = np.array(stars_R_virial_ratio_list)
snaps_array = np.array(snaps_list)

# Redshift of each snapshot in `full_snaps` (same order). Looked up per snapshot because
# `snaps_array` follows the main progenitor branch (snapshot 99 first).
snap_redshift = dict(zip(full_snaps, [2, 1.5, 1, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0]))

fig, ax = plt.subplots()
ax.plot(snaps_array, stars_R_virial_ratio_array, marker='o')
ax.set_xticks(snaps_array)
ax.set_xticklabels([snap_redshift[snap] for snap in snaps_array])
ax.set_xlabel('redshift $z$')
ax.set_ylabel(r'$R_{90,\star} / R_{200,\mathrm{crit}}$')
plt.show()

In [None]:
# Gas (HI) R_90 / R_200,crit per snapshot, hard-coded from an earlier run of the gas loop.
# The values are in main-progenitor-branch order: snapshot 99 first (descending SnapNum).
# NOTE(review): these overwrite the arrays computed by the loop; record where they came
# from, and check whether the larger values were limited by an R_200,mean search bound.
R_virial_ratio_array = np.array([
    1.64081724107, 1.53052903296, 1.42584120527, 1.35851458861, 1.29856381891,
    1.22872999012, 1.18612066942, 1.11899748304, 1.07291643638, 0.659726475015,
])
snaps_array = np.flipud(np.array(full_snaps))

# Redshifts of `full_snaps` (ascending-snapshot order), flipped to line up with `snaps_array`.
new_x = np.flipud(np.array([2, 1.5, 1, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0]))

fig, ax = plt.subplots()
ax.plot(snaps_array, R_virial_ratio_array, marker='o')
ax.set_xticks(snaps_array)
ax.set_xticklabels(new_x)
ax.set_xlabel('redshift $z$')
ax.set_ylabel(r'$R_{90,\mathrm{HI}} / R_{200,\mathrm{crit}}$')
plt.show()

In [None]:
# Inspect the current R_90 / R_200,crit values (set by the plotting cell above).
R_virial_ratio_array

In [None]:
# Inspect the snapshot numbers paired with those values (set by the plotting cell above).
snaps_array