In [5]:
%matplotlib inline

import sys
import os

# Add the 'src' directory to the path so Python can find your modules
sys.path.append(os.path.abspath("src"))

import subprocess 
from PyAstronomy import pyasl
import numpy as np
import pandas as pd
import dask
import dask.dataframe as dd
import dask.array as da
import math
import matplotlib.pyplot as plt

# -----------------------------------------------------

import star_pop as st_pop
from star_pop import generate_star_population, plot_distributions,quantize_population
from star_pop import SamplerContainer,save_samplers,load_samplers
from star_processing import process_matched_population


# -----------------------------------------------------

from matplotlib.pyplot import cm
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
import scipy.optimize as optimize
import numpy as np
import scipy as sp
from astropy import constants as const
from pylab import rcParams
import math
import random

from scipy.stats import kde
from matplotlib.collections import PathCollection
import datashader as ds
import datashader.transfer_functions as tf

import mesa_reader as mr
import matplotlib as mpl
import glob
import inspect
from pathlib import Path  # Import the Path class


plt.rcParams.update({
    'font.size': 20,             # default text size for every figure in the notebook
    'figure.figsize': (10, 10),  # default figure size in inches
})

## Population

In [None]:
# --- Workflow Configuration ---
# Options: 'generate_on_the_fly', 'create_and_save_samplers', 'load_samplers_and_generate'
MODE = 'load_samplers_and_generate'

SAVE_TO_CSV = True
# A new random seed is drawn on every run. Hard-code an integer here to reproduce a
# previous run, or set SEED = None to leave the random generators unseeded.
# The bounds must be integers (float bounds raise TypeError on Python >= 3.12), and the
# value must stay below 2**32 to be accepted by np.random.seed.
SEED = random.randint(0, 10**9)

# --- Simulation Configuration ---
NUM_STARS = 100_000_000
DISK_TYPE = "thin"
GALAXY_SHAPE = "spiral"    # "elliptical" or "spiral"
MZAMS_MIN, MZAMS_MAX = 0.11, 250.0
T_MAX_AGE = 13.6           # presumably Gyr; not passed to generate_star_population in this notebook
FEH_MIN = -3.0
FEH_MAX = 0.5

SAMPLER_RESOLUTION = 1_000_000  # Controls accuracy vs. setup time for samplers

# Saving the samplers means later runs do not have to rebuild the sampling
# probability distributions every time stars are generated.

OUTPUT_DIR = f"{GALAXY_SHAPE}_{DISK_TYPE}_disk_run"

SAMPLER_FILE = os.path.join("samplers", f"res{SAMPLER_RESOLUTION}_{MZAMS_MIN}_{MZAMS_MAX}_samplers.pkl")
POP_DIR = os.path.join(OUTPUT_DIR, "population")

# --- Spiral Arm Configuration ---
# Relative membership probability of each spiral arm (the values sum to 1).
ARM_MEMBERSHIP_PROBABILITY = [
        {'name': 'Scutum-Centaurus', 'probability': 0.30},
        {'name': 'Perseus', 'probability': 0.25},
        {'name': 'Sagittarius-Carina', 'probability': 0.20},
        {'name': 'Norma-Outer', 'probability': 0.15},
        {'name': 'Local (Orion Spur)', 'probability': 0.10}
]
# Physical shape of the logarithmic spiral arms. Each arm is described by a reference
# radius (kpc) and angle (deg), a pitch angle (deg) and the radial range (kpc) it spans.
# Parameters follow the synthesis table in Part IV of the research paper, primarily
# derived from Reid et al. (2019) and Vallée, giving an observationally grounded model
# of the Milky Way's current structure.
ARM_PARAMS = {
        'Scutum-Centaurus': {
            'ref_radius_kpc': 3.14,
            'ref_angle_deg': 25,
            'pitch_angle_deg': 12.0,
            'radial_range_kpc': [3.0, 16.0]
        },
        'Sagittarius-Carina': {
            'ref_radius_kpc': 4.93,
            'ref_angle_deg': -45,
            'pitch_angle_deg': 13.1,
            'radial_range_kpc': [4.0, 16.0]
        },
        'Perseus': {
            'ref_radius_kpc': 9.94,
            'ref_angle_deg': 150,
            'pitch_angle_deg': 9.5,
            'radial_range_kpc': [6.0, 18.0]
        },
        'Norma-Outer': {
            'ref_radius_kpc': 4.00,
            'ref_angle_deg': -100,
            'pitch_angle_deg': 13.0,
            'radial_range_kpc': [3.5, 20.0]
        },
        'Local (Orion Spur)': {
            'ref_radius_kpc': 8.15,
            'ref_angle_deg': 0, # Sun is reference for Local arm
            'pitch_angle_deg': 10.1,
            'radial_range_kpc': [6.0, 9.0]
        }
}

# --- Seed the random number generators for reproducibility ---
# Both numpy's global generator and the standard-library `random` module are seeded so
# that any later draw from either one is reproducible.
if SEED is not None:
    np.random.seed(SEED)
    random.seed(SEED)
    # SEED may have been drawn at random, so report it to allow re-running with it.
    print(f"Random seed: {SEED}")

# --- Main Workflow Logic ---
# Fail fast on a misspelled MODE instead of silently generating samplers on the fly.
if MODE not in ('generate_on_the_fly', 'create_and_save_samplers', 'load_samplers_and_generate'):
    raise ValueError(
        f"Unknown MODE {MODE!r}; expected 'generate_on_the_fly', "
        "'create_and_save_samplers' or 'load_samplers_and_generate'."
    )

samplers = None
if MODE == 'load_samplers_and_generate':
    if not os.path.exists(SAMPLER_FILE):
        raise FileNotFoundError(f"Sampler file not found: {SAMPLER_FILE}. Please run in 'create_and_save_samplers' mode first.")
    samplers = load_samplers(SAMPLER_FILE)
else:  # 'create_and_save_samplers' or 'generate_on_the_fly'
    # NOTE(review): assumes SamplerContainer's positional arguments are
    # (Mzams_min, Mzams_max, ...) -- the form used to build the saved sampler files, whose
    # name does not include DISK_TYPE. An earlier version also called it with DISK_TYPE
    # as the first argument; both forms cannot bind correctly. Confirm against
    # star_pop.SamplerContainer.
    samplers = SamplerContainer(MZAMS_MIN, MZAMS_MAX, resolution=SAMPLER_RESOLUTION)
    if MODE == 'create_and_save_samplers':
        # Make sure the target directory exists before writing the pickle.
        os.makedirs(os.path.dirname(SAMPLER_FILE) or ".", exist_ok=True)
        save_samplers(samplers, SAMPLER_FILE)


In [None]:
# 1. Generate the stellar population
population_df = generate_star_population(
    num_stars=NUM_STARS,
    samplers=samplers,
    disk_type=DISK_TYPE,
    galaxy_shape=GALAXY_SHAPE,
    arm_params_list=ARM_PARAMS,
    arm_membership_prob=ARM_MEMBERSHIP_PROBABILITY,
    seed=SEED,
    FeH_min=FEH_MIN,  # [Fe/H] limits from the configuration cell
    FeH_max=FEH_MAX,
)

# 2. Optionally save the raw (un-quantized) population. The "Load the population" cell
#    below reads it back from this same path.
if SAVE_TO_CSV:
    # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
    os.makedirs(POP_DIR, exist_ok=True)
    filepath = os.path.join(POP_DIR, f"star_population_{DISK_TYPE}_{NUM_STARS}.csv")
    population_df.to_csv(filepath, index=False)
    print(f"Original data saved to {filepath}")

# Quick sanity check of the generated data.
print("\nHere are the first 5 stars from the simulation:")
print(population_df.head())

print("\nHere are some basic statistics about the population:")
print(population_df.describe())


# 3. Visualize the generated distributions
plot_distributions(population_df, disk_type=DISK_TYPE)

### Load the saved population file

In [8]:
# Reload the population written by the generation cell (same path it saves to), so the
# analysis below can be re-run without regenerating the stars.
population_df = pd.read_csv(
    os.path.join(POP_DIR, f"star_population_{DISK_TYPE}_{NUM_STARS}.csv"),
    low_memory=False,
)

### Galactic position

In [None]:
plt.style.use('dark_background')
fig, ax = plt.subplots(1, 2, figsize=(30, 15))

# --- Dynamic marker size ---
# Marker size is interpolated in log-log space between 1e6 and 1e8 stars (np.interp
# clamps outside that range) so large populations do not saturate the panels.
total_stars = len(population_df)

if GALAXY_SHAPE == "spiral":
    log_s = np.interp(np.log10(total_stars), [6, 8], [np.log10(1.5), np.log10(0.001)])
    star_marker_size = 10**log_s

    arm_names = list(ARM_PARAMS.keys())
    colors = plt.cm.viridis(np.linspace(0.1, 1, len(arm_names)))
    color_map = {name: color for name, color in zip(arm_names, colors)}
    # Scutum-Centaurus is drawn in grey instead of its viridis colour.
    color_map["Scutum-Centaurus"] = mcolors.to_rgba('grey')

    # Extract the needed columns once instead of re-indexing the frame for every arm.
    arm_labels = population_df['Arm Position']
    radius_all = population_df['Radial Distance [kpc]'].to_numpy()
    alpha_all = np.deg2rad(population_df['alpha_angle_deg'].to_numpy())
    height_all = population_df['Vertical Distance [kpc]'].to_numpy()

    # 1. Per-arm Cartesian coordinates and per-point RGBA colours.
    all_x, all_y, all_z, all_c = [], [], [], []
    for arm_name in arm_names:
        in_arm = (arm_labels == arm_name).to_numpy()
        n_arm = int(in_arm.sum())
        if n_arm == 0:  # no stars in this arm: nothing to plot and no legend entry
            continue

        R = radius_all[in_arm]
        alpha = alpha_all[in_arm]
        # Random +/-1 sign per star places it above or below the plane. np.random.choice is
        # vectorised; random.choices would build a Python list with one int per star.
        z_sign = np.random.choice([-1, 1], size=n_arm)

        all_x.append(R * np.cos(alpha))
        all_y.append(R * np.sin(alpha))
        all_z.append(height_all[in_arm] * z_sign)
        all_c.append(np.tile(color_map[arm_name], (n_arm, 1)))

        # Empty scatter: only creates the legend entry for this arm.
        ax[0].scatter([], [], color=color_map[arm_name], s=2, label=arm_name)

    if not all_x:
        raise ValueError("No star has an 'Arm Position' matching an entry of ARM_PARAMS; nothing to plot.")

    # 2. One array per quantity so each panel is drawn with a single scatter call.
    final_x = np.concatenate(all_x)
    final_y = np.concatenate(all_y)
    final_z = np.concatenate(all_z)
    final_c = np.concatenate(all_c)

    # --- PLOT 1: Top-down (X vs Y) ---
    ax[0].scatter(final_x, final_y, c=final_c, s=star_marker_size, alpha=0.6)

    # --- PLOT 2: Edge-on (X vs Z) ---
    # Draw in order of increasing Y, so stars at larger Y are painted on top.
    sort_indices = np.argsort(final_y)
    x_sorted = final_x[sort_indices]
    z_sorted = final_z[sort_indices]
    c_sorted = final_c[sort_indices]
    ax[1].scatter(x_sorted, z_sorted, c=c_sorted, s=star_marker_size, alpha=0.6)

    ax[0].legend(loc='upper right', title='Spiral Arms', markerscale=5, fontsize=20, title_fontsize=20)
    # Galactic centre marker. It is added after legend(), so it is not listed in the legend.
    ax[0].plot(0, 0, 'X', color='yellow', markersize=12, label='Galactic Center')

else:
    # No arm structure: colour encodes radial distance.
    log_s = np.interp(np.log10(total_stars), [6, 8], [np.log10(1.5), np.log10(0.05)])
    star_marker_size = 10**log_s

    R = population_df['Radial Distance [kpc]'].to_numpy()
    # Azimuth is drawn uniformly at random; a random +/-1 sign puts each star above or
    # below the plane.
    alpha = np.random.uniform(0, 360, size=len(R))
    z_sign = np.random.choice([-1, 1], size=len(R))

    x_coords = R * np.cos(np.deg2rad(alpha))
    y_coords = R * np.sin(np.deg2rad(alpha))
    z_coords = population_df['Vertical Distance [kpc]'].to_numpy() * z_sign

    # Edge-on panel drawn in order of increasing Y, as in the spiral branch.
    sort_indices = np.argsort(y_coords)
    x_sorted = x_coords[sort_indices]
    z_sorted = z_coords[sort_indices]
    R_sorted = R[sort_indices]

    # Low alpha so dense regions do not saturate.
    ax[0].scatter(x_coords, y_coords, s=star_marker_size, alpha=0.05, c=R, cmap='viridis')
    ax[1].scatter(x_sorted, z_sorted, s=star_marker_size, alpha=0.05, c=R_sorted, cmap='viridis')

    ax[0].plot(0, 0, 'X', color='yellow', markersize=12)

# --- Formatting ---
ax[0].set_title('Simulated Stars Distribution (Top-down)', fontsize=25)
ax[1].set_title('Simulated Star Distribution (Edge-on)', fontsize=25)
ax[0].set_xlabel('Galactic X [kpc]', fontsize=25)
ax[1].set_xlabel('Galactic X [kpc]', fontsize=25)
ax[0].set_ylabel('Galactic Y [kpc]', fontsize=25)
ax[1].set_ylabel('Galactic Z [kpc]', fontsize=25)

for axs in ax:
    axs.tick_params(labelsize=20, axis='both', which='both', right=False, length=10)
    axs.grid(True)

ax[0].set_xlim(-21, 21)
ax[1].set_xlim(-21, 21)
ax[0].set_ylim(-21, 21)

plt.show()

### Quantization

In [None]:
plt.style.use('default')


# --- Model Grids for Quantization ---

# Stellar masses (M_sun) available in the model; quantize_population snaps the
# generated masses onto this grid.
stellar_mass_grid = np.array([
    0.11, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
    20, 30, 40, 50, 60, 70, 80, 90, 100
])

# Model metallicities expressed as [Fe/H] = log10(Z/Z_sun) / 0.977 (Bertelli et al. 1994a)
# for Z/Z_sun = 0.1, 0.3 and 1.
metallicity_grid = np.log10([0.1, 0.3, 1]) / 0.977

# in_place=False keeps population_df intact (slower) so the original and quantized
# histograms can be compared below; in_place=True is faster but overwrites population_df.
q_df = quantize_population(population_df, stellar_mass_grid, metallicity_grid, in_place=False)

# Save the quantized population, creating the output directory if needed.
filepath_q = os.path.join(OUTPUT_DIR, "quantized", f"star_population_{DISK_TYPE}_{NUM_STARS}_q.csv")
output_directory = os.path.dirname(filepath_q)
Path(output_directory).mkdir(parents=True, exist_ok=True)
q_df.to_csv(filepath_q, index=False)


# --- Compare quantized (left column) and original (right column) distributions ---
fig, ax = plt.subplots(2, 2, figsize=(20, 20))

# One row per quantity. Raw strings: "\m" and "\o" are invalid Python escape sequences.
hist_rows = [
    ("Mzams", r"$M_\mathrm{ZAMS}/M_\odot$"),
    ("Fe/H", "[Fe/H]"),
]
hist_columns = [(q_df, "Quantized"), (population_df, "Original")]

# Shared [Fe/H] x-range so both metallicity panels are directly comparable.
feh_lo = min(q_df["Fe/H"].min(), population_df["Fe/H"].min())
feh_hi = max(q_df["Fe/H"].max(), population_df["Fe/H"].max())

for row, (column, xlabel) in enumerate(hist_rows):
    for col_idx, (frame, label) in enumerate(hist_columns):
        axs = ax[row, col_idx]
        axs.hist(frame[column], bins=30)
        axs.set_yscale('log')
        axs.tick_params(labelsize=20, axis='both', which='both', top=True, right=True, length=10)
        axs.set_ylim(5, len(q_df))
        axs.set_xlabel(xlabel, fontsize=20)
        axs.set_title(f"{label} {column}", fontsize=20)
        if column == "Fe/H":
            axs.set_xlim(feh_lo, feh_hi)

plt.show()

### Combine with stellar evolution tracks (MESA or SSE)

In [42]:
# Reload the quantized population saved by the quantization step (same path it writes),
# so track matching can run without repeating the quantization.
q_df = pd.read_csv(
    os.path.join(OUTPUT_DIR, "quantized", f"star_population_{DISK_TYPE}_{NUM_STARS}_q.csv"),
    low_memory=False,
)

This step matches the quantized population to stellar evolution tracks (MESA or SSE) in chunks of at most $10^7$ stars, so that RAM is not overloaded. If the population is already smaller than that, no splitting is done.

In [None]:
# 'generate'     : build the track-matched population from q_df (processed in chunks)
# 'just_extract' : reuse a population that has already been generated
action = "just_extract"

MESA_or_SSE = "SSE"              # which stellar tracks to use ("SSE"; presumably "MESA" for the MESA grid)
mesa_dir = "./MESA Simulations"
sse_dir = "../SSE/Simulations"

output_dir_mesa = "./mesa_matched_chunks"
output_dir_sse = "./SSE_matched_chunks"

chunk_size = 10_000_000          # maximum number of stars processed per chunk
wr_cond = 'any'

# Fail fast on a typo instead of silently treating it as an extraction run.
if action not in ("generate", "just_extract"):
    raise ValueError(f"Unknown action {action!r}; expected 'generate' or 'just_extract'.")

# Arguments shared by both actions; only 'generate' needs the quantized input population.
matching_kwargs = dict(
    mode=MESA_or_SSE,
    what_to_do=action,
    mesa_dir=mesa_dir,
    sse_dir=sse_dir,
    output_dir_mesa=output_dir_mesa,
    output_dir_sse=output_dir_sse,
    chunk_size=chunk_size,
    wr_cond=wr_cond,
)
if action == "generate":
    matching_kwargs["q_df"] = q_df

df_stars = process_matched_population(**matching_kwargs)

In [None]:
# Linear metallicity Z/Z_sun, inverting the [Fe/H] = log10(Z/Z_sun) / 0.977 relation
# used to build the quantization metallicity grid.
df_stars['Z_val'] = 10 ** (0.977 * df_stars["Fe/H"])

# Mass grid (M_sun): same values as the quantization grid.
stellar_mass_grid = np.array(
    [0.11, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1,
     2, 3, 4, 5, 6, 7, 8, 9, 10,
     20, 30, 40, 50, 60, 70, 80, 90, 100]
)
# Metallicity grid here is linear Z/Z_sun (not [Fe/H]); its values are the colorbar ticks.
metallicity_grid = np.array([0.1, 0.3, 1])
# Tick positions for the logarithmic mass colorbar.
mass_ticks = [0.1, 0.2, 0.3, 0.4, 0.5, 1.0, 2.5, 5, 10, 25, 50, 100]
metal_ticks = metallicity_grid

# --- 2. Define Helper for Datashading ---
def _padded_range(lo, hi):
    """Return (lo, hi), widened symmetrically around lo when the range has zero width.

    A zero-width range (e.g. a column with a single distinct value) cannot be mapped
    onto pixels or colours, so it is expanded; non-degenerate and NaN ranges pass through.
    """
    if hi <= lo:
        half = max(abs(lo), 1.0) * 0.5
        return lo - half, lo + half
    return lo, hi


def datashade_image(df, x_col, y_col, color_col, cmap, norm_type='linear', width=1000, height=1000,
                    span=None, spread_px=0):
    """Rasterize a (possibly very large) scatter of ``df`` into a colour image with datashader.

    Each pixel is coloured by the mean of ``color_col`` over the points that fall into it.

    Parameters
    ----------
    df : pandas or dask DataFrame
        Source points. For dask frames the lazy min/max reductions are computed here.
    x_col, y_col : str
        Columns used for the horizontal and vertical axes.
    color_col : str
        Column averaged per pixel (``ds.mean``) to determine the colour.
    cmap : matplotlib Colormap or list of colours
        Matplotlib colormaps (anything with an ``N`` attribute) are sampled into 256 hex
        colours; anything else is passed to ``tf.shade`` unchanged.
    norm_type : str
        'log' selects datashader's log scaling; any other value means linear.
        NOTE(review): datashader's 'log' applies log1p to (value - span[0]), which is not
        the same mapping as matplotlib's LogNorm -- to match a LogNorm colorbar exactly,
        shade log10 of the data with 'linear'.
    width, height : int
        Raster size in pixels.
    span : [vmin, vmax] or None
        Data range mapped onto the colormap. If None, the min/max of the aggregated pixel
        values is used (0 / 1 when no pixel has data).
    spread_px : int
        If > 0, every non-empty pixel is dilated by this many pixels (``tf.spread``); px=1
        turns a single pixel into a 3x3 block.

    Returns
    -------
    img : datashader Image
    extent : list
        ``[xmin, xmax, ymin, ymax]`` of the canvas, suitable for ``imshow(extent=...)``.
    """
    # 1. Plot extents (evaluated together in one pass for lazy dask reductions).
    xmin, xmax = df[x_col].min(), df[x_col].max()
    ymin, ymax = df[y_col].min(), df[y_col].max()
    if hasattr(xmin, 'compute'):
        xmin, xmax, ymin, ymax = dask.compute(xmin, xmax, ymin, ymax)
    xmin, xmax = _padded_range(xmin, xmax)
    ymin, ymax = _padded_range(ymin, ymax)

    # 2. Canvas and per-pixel mean aggregation.
    cvs = ds.Canvas(plot_width=width, plot_height=height,
                    x_range=(xmin, xmax), y_range=(ymin, ymax))
    agg = cvs.points(df, x_col, y_col, ds.mean(color_col))

    # 3. Colour span: reduce the aggregate once for each bound. Empty pixels are NaN and
    #    skipped by the reduction, so NaN here means no pixel received any point.
    if span is None:
        vmin = float(agg.min().item())
        vmax = float(agg.max().item())
        if np.isnan(vmin):
            vmin = 0.0
        if np.isnan(vmax):
            vmax = 1.0
        span = list(_padded_range(vmin, vmax))

    # 4. Datashader expects a list of colours rather than a matplotlib Colormap object.
    if hasattr(cmap, 'N'):
        cmap_colors = [mcolors.to_hex(cmap(v)) for v in np.linspace(0, 1, 256)]
    else:
        cmap_colors = cmap

    how_method = 'log' if norm_type == 'log' else 'linear'

    # 5. Shade, then optionally enlarge the dots.
    img = tf.shade(agg, cmap=cmap_colors, how=how_method, span=span)
    if spread_px > 0:
        img = tf.spread(img, px=spread_px)

    return img, [xmin, xmax, ymin, ymax]

# --- 3. Plotting ---

plt.style.use('default')
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

# Colormaps
cmap_M = plt.get_cmap('cool')
cmap_Z = plt.get_cmap('copper')

# Colour normalisations shared by the images AND the colorbars, so that the colours
# drawn correspond exactly to the colorbar scales.
norm_M = mcolors.LogNorm(0.1, stellar_mass_grid.max())
norm_Z = plt.Normalize(metal_ticks.min(), metal_ticks.max())

# --- Datashader images ---
# Mass: shade log10(Mzams) *linearly* over [log10(vmin), log10(vmax)] of norm_M, which is
# exactly LogNorm's mapping. (Datashader's own how='log' works on log1p of offset data and
# would not match a LogNorm colorbar.) Pixels therefore show the mean of log10(Mzams).
df_stars['log10_Mzams'] = np.log10(df_stars["Mzams"])
img_mass, extent_mass = datashade_image(
    df_stars, "log_Teff", "log_L", "log10_Mzams", cmap=cmap_M, norm_type='linear',
    span=[np.log10(norm_M.vmin), np.log10(norm_M.vmax)], spread_px=4,
)

# Metallicity: linear shading over the same range as norm_Z.
img_metal, extent_metal = datashade_image(
    df_stars, "log_Teff", "log_L", "Z_val", cmap=cmap_Z, norm_type='linear',
    span=[norm_Z.vmin, norm_Z.vmax], spread_px=4,
)

# --- Display the rasters ---
# aspect='auto' lets the pixels stretch to fill the axes.
ax[0].imshow(img_mass.to_pil(), origin='upper', extent=extent_mass, aspect='auto')
ax[1].imshow(img_metal.to_pil(), origin='upper', extent=extent_metal, aspect='auto')

# --- Colorbars ---
# The images are plain RGB rasters, so the colorbars are driven by stand-alone
# ScalarMappables built from the same cmap/norm pairs used for shading.
sm_M = plt.cm.ScalarMappable(cmap=cmap_M, norm=norm_M)
cbar0 = fig.colorbar(sm_M, ax=ax[0], ticks=mass_ticks, pad=0.0)
cbar0.set_label(r'$M_\mathrm{ZAMS} / M_\odot$', rotation=270, labelpad=25, fontsize=20)
cbar0.ax.tick_params(labelsize=18, size=12, width=2)
cbar0.ax.set_yticklabels([f'{tick:g}' for tick in mass_ticks])

sm_Z = plt.cm.ScalarMappable(cmap=cmap_Z, norm=norm_Z)
cbar1 = fig.colorbar(sm_Z, ax=ax[1], ticks=metal_ticks, pad=0.0)
cbar1.set_label(r'$Z / Z_\odot$', rotation=270, labelpad=25, fontsize=20)
cbar1.ax.tick_params(labelsize=18, which='major', size=12, width=2)
cbar1.ax.set_yticklabels([f'{tick:.2f}' for tick in metal_ticks])


# --- Axis formatting ---
for axs in ax:
    axs.set_xlabel('log($T_{eff}/$K)', fontsize=25)
    axs.tick_params(labelsize=20, axis='both', which='both', right=False, length=10)
    axs.invert_xaxis()  # HR-diagram convention: temperature increases to the left
    axs.grid()
# Raw string: "\o" is an invalid Python escape sequence.
ax[0].set_ylabel(r'log($L/L_{\odot}$)', fontsize=25)

plt.tight_layout(pad=1.5)
plt.show()