# Exploring existing POSYDON MESA grids

**Main Goal:**

Become familiar with the information available in the grid:

- initial/final_values
- time series data of binaries

**Sub-learning goals:**

1. How to plot a grid slice.
2. How to plot specific information from a grid slice, e.g. `final_values['aaa']`.
3. Plot a specific model(s): HR diagrams, mass transfer evolution
4. Show the “downsampled” nature of the data and what information is available.
5. Extract a feature from the grid runs and plot it, for example the zeta parameter at the start of a mass transfer phase, the mass transfer efficiency of a specific mass transfer phase, the peak $\dot{M}$, the thermal timescale of the companion, or the tidal timescale.
6. Get familiar with the available profiles in the grid too and their potential downsides.



# 1. Grids

POSYDON comes with several pre-run grids to use for your science cases in population synthesis. However, sometimes you want to understand a specific model or look at a region of the grid to understand the physics better.

POSYDON comes with "three" binary grids and two single-star grids.
The single-star grids contain either non-rotating hydrogen-rich stars at the zero-age main sequence or non-rotating helium stars at the start of helium ignition.

The three binary grids come in two flavours: "standard" and `RLO`.

The standard grids start their evolution at hydrogen or helium ignition depending on the type of grid.

![POSYDON_DATA file structure](./file_structure.png)

## Loading a POSYDON binary grid


In [None]:
import os
from posydon.config import PATH_TO_POSYDON_DATA
from posydon.grids.psygrid import PSyGrid
import matplotlib.pyplot as plt
import numpy as np

In [None]:
grid_file_path = os.path.join(PATH_TO_POSYDON_DATA, 'HMS-HMS', '1e-02_Zsun.h5')

In [None]:
# Load the grid directly, or create an empty PSyGrid instance and load it afterwards.
# Depending on your hardware, the first load can take a while.
grid = PSyGrid(grid_file_path)
# or
#grid = PSyGrid()
#grid.load(grid_file_path)

In [None]:
# Get some information about the grid:
print(grid)
print(len(grid))


![PSyGrid structure figure](./PsyGrid_structure.png)

As the cartoon above depicts, the `PSyGrid` object contains a lot of information,
but its main building blocks are the MESA simulations. These are stored as `PSyView` objects inside `PSyGrid`.
We will discuss the other components afterwards.

A `PSyView` only loads data from the file when that specific information is requested.
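To make on-demand loading concrete, here is a toy illustration of the pattern in plain Python. `LazyRun` is a hypothetical class written for this sketch, not POSYDON's `PSyView`; it only shows that creating the object reads nothing and that the data is read at the moment the attribute is accessed.

```python
class LazyRun:
    """Toy stand-in for a lazily loaded grid run (hypothetical, not PSyView)."""

    def __init__(self, loader):
        self._loader = loader  # callable that reads the data when asked
        self.n_reads = 0       # how many times the loader has actually run

    @property
    def history(self):
        # Nothing is read when the object is created; the read happens
        # only at the moment `.history` is accessed.
        self.n_reads += 1
        return self._loader()


run = LazyRun(lambda: [1.0, 2.0, 3.0])
print(run.n_reads)  # 0: creating the object read nothing
data = run.history
print(run.n_reads)  # 1: the data was read on access
```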

You can easily select a MESA run using its index, like you would with an array/list:

In [None]:
index = 200
print(grid[index])
model = grid[index]

The original MESA simulations can be quite large. Each run in the `PSyGrid` comes with *downsampled* information from the MESA simulations.
The simplest two are the `initial_values` and `final_values` associated with the simulation.

As the names suggest, the `initial_values` contains the values of the first MESA history step, while `final_values` contains the final step.

The `final_values` also contain post-processing information, such as the termination flags `termination_flag_1` to `termination_flag_4` from the MESA simulation, information about the classification, the location of the core-envelope boundary for different hydrogen-fraction definitions, and the results of the standard supernova models (`SN_MODELS`) included in POSYDON DR2.
Their details are described in the [POSYDON documentation](https://posydon.org/POSYDON/latest/tutorials-examples/generating-datasets/plot_2D.html#MESA-Termination-Flag-1:~:text=MESA%20Termination%20Flag%201).

In [None]:
print(model.initial_values)

However, as you may notice, this does not show what each element is.
Luckily, the above array is a [structured numpy array](https://numpy.org/doc/stable/user/basics.rec.html) and has this information associated with each element.
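A minimal sketch of how a structured array pairs names with values, using a synthetic record whose field names mimic `initial_values` (the numbers are made up). Looping over `dtype.names` is one simple way to print each value next to its name.

```python
import numpy as np

# Synthetic stand-in for a single record such as `model.initial_values`
initial_values = np.array(
    (30.0, 24.0, 10.0),
    dtype=[('star_1_mass', 'f8'), ('star_2_mass', 'f8'), ('period_days', 'f8')],
)

print(initial_values.dtype.names)     # ('star_1_mass', 'star_2_mass', 'period_days')
print(initial_values['period_days'])  # 10.0

# Print every field together with its name
for name in initial_values.dtype.names:
    print(f"{name:>12s} = {initial_values[name]}")
```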

In [None]:
print(model.initial_values.dtype)
# or if you only want the names:
print(model.initial_values.dtype.names)

In [None]:
model.final_values.dtype.names

In [None]:
# Some examples of accessing the final values:
print(model.final_values['termination_flag_1'])
print(model.final_values['interpolation_class'])
print(model.final_values['S1_SN_MODEL_v2_01_mass'])

<div class="alert alert-success">

## Exercise: Accessing initial & final values

Questions about the binary with index 9000:
- What are the initial masses and period of the binary?
- How does the binary exit the MESA simulation?
- What kind of compact object does this system produce? Does this differ between the supernova prescriptions? Would you expect it to?
- What are the state and spin of the companion at the end of the simulation?
- What kind of mass transfer took place during the evolution of the binary and was it stable or unstable?

You will need to look through the different parameters in the `initial_values` and `final_values` to get these answers.

<div class="alert alert-info">

**Done quickly?**

Can you draw a random binary index between 0 and `len(grid) - 1` and answer the same questions for it?

</div>
</div>

In [None]:
index = 9000

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution (click to reveal):</summary></b>

```python
# Solution
index = 9000

model = grid[index]
# initial properties
print(model.initial_values['star_1_mass'])
print(model.initial_values['star_2_mass'])
print(model.initial_values['period_days'])

# how does the simulation end?
print(model.final_values['termination_flag_1']) # termination flag

# Compact object information
print(model.final_values['S1_SN_MODEL_v2_01_mass']) # Compact-object mass of S1
print(model.final_values['S1_SN_MODEL_v2_01_CO_type']) # Compact-object type of S1


# state of the companion at the end
print(model.final_values['termination_flag_3']) # State S1
print(model.final_values['termination_flag_4']) # State S2
print(model.final_values['S2_surf_avg_omega_div_omega_crit']) # Spin of S2

# Mass transfer details
print(model.final_values['termination_flag_2'])   # which mass transfer cases occurred
print(model.final_values['interpolation_class'])  # e.g. stable_MT or unstable_MT
```


</details>
</div>

## Inspecting the timeseries

The binary with index 9000 undergoes several mass transfer phases, which is quite common in the binary grids. In the example above, you examined only snapshots of the binary at its initial and final timesteps.
It can be important from a science perspective to understand what the binary is **exactly** doing throughout its lifetime.
Luckily, a downsampled timeseries, also called *history*, of the MESA simulation is available for each binary in the `PSyGrid`.
These can be accessed through the following attributes of the `PSyView` instance:

- `.binary_history` contains the MESA binary history, which includes the stellar masses, the orbit, and the mass transfer parameters.
- `.history1` contains the history of properties of the primary star.
- `.history2` contains the history of properties of the secondary star.

In [None]:
model = grid[9000]
print(model.binary_history)
print(model.history1)
print(model.history2)

These are also structured arrays, and we can use the same trick as before to get the names of the columns.
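Because some runs lack a history (the exercise on multiple binaries warns about this) and not every column is stored, a small guard can save you from errors when looping over many runs. `get_column` is a hypothetical helper written for this sketch, exercised on a synthetic array rather than real grid data.

```python
import numpy as np


def get_column(history, name):
    """Return column `name` of a structured history array, or None.

    Hypothetical helper: returns None when the history itself is missing
    (e.g. `model.history2 is None`) or when the column was not saved.
    """
    if history is None or history.dtype.names is None:
        return None
    if name not in history.dtype.names:
        return None
    return history[name]


# Synthetic two-step history with made-up values
history1 = np.array(
    [(0.0, 1.2), (1.0e6, 1.4)],
    dtype=[('star_age', 'f8'), ('log_R', 'f8')],
)

print(get_column(history1, 'log_R'))         # [1.2 1.4]
print(get_column(history1, 'he_core_mass'))  # None: column not present
print(get_column(None, 'log_R'))             # None: history missing
```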

In [None]:
print(model.binary_history.dtype.names)
print(model.history1.dtype.names)

You can use these arrays directly to plot the evolution of the binary, which gives you
full flexibility in what you plot:

In [None]:
plt.plot(model.binary_history['age'], model.binary_history['star_1_mass'], label='star 1')
plt.plot(model.binary_history['age'], model.binary_history['star_2_mass'], label='star 2')
plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')
plt.legend()
plt.show()

Alternatively, you can use the built-in `grid.plot` function, which reduces the plotting to
a single line and adds the appropriate axis labels for you.

<div class="alert alert-info">

**Note: Mixing histories**


One current restriction of `grid.plot` is that both quantities must come from the same history (`history1`, `history2`, or `binary_history`).
For example, `he_core_mass` (in `history1`) cannot be plotted against `age` (in `binary_history`) with this function.
</div>
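The manual plotting approach shown earlier works around this restriction. The solutions in this notebook pair columns from `binary_history` and `history1` row by row, which assumes that both histories were saved at the same downsampled steps; a quick length check makes that assumption explicit. A sketch with synthetic arrays:

```python
import numpy as np

# Synthetic stand-ins for two histories of the same run (made-up values)
binary_history = np.array([(0.0,), (5.0e6,), (6.0e6,)], dtype=[('age', 'f8')])
history1 = np.array([(0.0,), (8.0,), (9.5,)], dtype=[('he_core_mass', 'f8')])

# Guard against silently mis-pairing histories of different lengths
assert len(binary_history) == len(history1), "histories are not aligned"

x = binary_history['age']        # from binary_history
y = history1['he_core_mass']     # from history1
pairs = np.column_stack([x, y])  # one (age, core mass) row per saved step
print(pairs.shape)  # (3, 2)
```

With the real grid, `x` and `y` can then be passed to `plt.plot(x, y)`.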

In [None]:
index = 9000
# note that show_fig=True is needed to display the figure
grid.plot(index, 'age', 'star_1_mass', history='binary_history', show_fig=True)

You can also add a third quantity as a colour map on top of the evolutionary track, as in the example below.

The downsampling is clearly visible in this model.
During the main sequence, the mass changes little and only a few data points are stored,
but later in the evolution, during mass transfer, the mass changes rapidly and many more data points are kept.
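One way to quantify this is to look at the time between consecutive saved points, e.g. `np.diff(model.binary_history['age'])`. The sketch below does this on a made-up track (not grid data) and splits the intervals with an arbitrary mass-loss-rate threshold:

```python
import numpy as np

# Made-up track: slow mass loss on the main sequence (few saved points),
# then rapid mass transfer (many saved points), as in a downsampled history.
age = np.concatenate([np.linspace(0.0, 5.0e6, 5),
                      np.linspace(5.01e6, 5.05e6, 20)])
mass = np.concatenate([np.linspace(30.0, 29.5, 5),
                       np.linspace(29.4, 12.0, 20)])

dt = np.diff(age)            # time between consecutive saved points [yr]
rate = -np.diff(mass) / dt   # mass-loss rate over each interval [Msun/yr]

# Arbitrary, illustrative threshold separating slow and fast intervals
fast = rate > 1e-6
print(f"{fast.sum()} of {len(rate)} intervals are fast")
print(f"median dt: slow {np.median(dt[~fast]):.3g} yr, fast {np.median(dt[fast]):.3g} yr")
```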

In [None]:
index = 9000
# note that show_fig=True is needed to display the figure
grid.plot(index, 'age', 'star_1_mass', 'period_days', history='binary_history', show_fig=True)

One of the main features of the built-in plotting is the ability to easily plot multiple figures
while setting the `PLOT_PROPERTIES` once.

The `PLOT_PROPERTIES` dictionary passes styling options to the plotting function.
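Since `PLOT_PROPERTIES` is a plain dictionary, one convenient pattern (a general Python idiom, not a POSYDON feature) is to define the shared options once and merge in per-figure overrides with `{**base, **overrides}`. The keys below are the ones used in the following cells; `BASE_PLOT_PROPERTIES` is a name chosen for this sketch so that the imported defaults are not overwritten.

```python
# Shared styling options, defined once
BASE_PLOT_PROPERTIES = {
    'show_fig': True,
    'close_fig': True,
}

# Per-figure overrides: keys on the right win when the dicts are merged
log_axes = {**BASE_PLOT_PROPERTIES, 'log10_x': True, 'log10_y': True}
wide = {**BASE_PLOT_PROPERTIES, 'figsize': (6, 2)}

print(log_axes)
# Each merged dict is then unpacked as before, e.g.
# grid.plot(index, 'star_1_mass', 'star_2_mass', history='binary_history', **log_axes)
```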

In [None]:
index = [9000, 1000]

PLOT_PROPERTIES = {
    'show_fig' : True,
    'close_fig' : True,
}

grid.plot(index, 'age', 'star_1_mass', history='binary_history', **PLOT_PROPERTIES)

In [None]:
PLOT_PROPERTIES = {
    'show_fig' : True,
    'close_fig' : True,
    'figsize' : (6,2),
    'legend1D': dict(loc='upper right', lines_legend=['9000', '1000']),
}

grid.plot(index, 'age', 'star_1_mass', history='binary_history', **PLOT_PROPERTIES)


In [None]:
PLOT_PROPERTIES = {
    'show_fig' : True,
    'close_fig' : True,
    'figsize' : (6,2),
    'legend1D': dict(loc='upper right', lines_legend=['9000', '1000']),
    'log10_x' : True,
    'log10_y' : True,
}

grid.plot(index, 'star_1_mass', 'star_2_mass', history='binary_history', **PLOT_PROPERTIES)


There are many different variables you can pass into the plot properties.
To see them all, you can import the default `PLOT_PROPERTIES`.

In [None]:
from posydon.visualization.plot_defaults import PLOT_PROPERTIES
print(PLOT_PROPERTIES)

Besides `grid.plot`, there is also `grid.HR` for plotting Hertzsprung-Russell diagrams:

In [None]:
grid.HR(idx=[8000,9000], history='history1', show_fig=True)

<div class="alert alert-success">

## Exercise: Inspecting the mass transfer of binaries

With the knowledge of the previous section, we will start to extract specific mass transfer information from an individual run.
We start with our example binary: 9000.
It has two mass transfer phases: Case B and Case C.

Let's find out why by looking at the stellar radii and the binary separation,
while also inspecting the structure of the star that initiates the mass transfer.

Let's plot the following:
1. The radii of the stars and separation of the binary.
2. The instantaneous mass transfer rate over age.
3. The change in primary and secondary mass over age.
4. Focus on the donor star in each mass transfer phase and show the core and envelope structure over stellar age. Can you figure out why a Case C mass transfer takes place?

</div>

In [None]:
# SOLUTION 1.
model = grid[9000]

In [None]:
# SOLUTION 2. 
model = grid[9000]

In [None]:
# SOLUTION 3 & 4.
model = grid[9000]

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 1 (click to reveal):</summary></b>


```python

model = grid[9000]

plt.plot(model.binary_history['age'], 10**model.history1['log_R'], label='Primary')
plt.plot(model.binary_history['age'], 10**model.history2['log_R'], label='Secondary')
plt.plot(model.binary_history['age'], model.binary_history['binary_separation'], label='Separation', ls='--')
plt.yscale('log')
plt.xlabel('Age [yr]')
plt.legend()
plt.ylabel(r'Radius or Separation [$R_\odot$]')
plt.show()

```

</details>
</div>


<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 2 (click to reveal):</summary></b>


```python
model = grid[9000]

delta_t = np.diff(model.binary_history['age'])

plt.plot(model.binary_history['age'], 10**model.binary_history['lg_mtransfer_rate'], label='instantaneous MT rate')
plt.plot(model.binary_history['age'][:-1], -np.diff(model.binary_history['star_1_mass'])/delta_t, label=r'-$\Delta M_1$')
plt.plot(model.binary_history['age'][:-1], np.diff(model.binary_history['star_2_mass'])/delta_t, label=r'$\Delta M_2$')

plt.ylabel(r'Mass change rate [$M_\odot$/yr]')
plt.xlabel('Age [yr]')
plt.legend()
plt.xlim(5e6)
plt.show()

```

</details>
</div>


<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 3 (click to reveal):</summary></b>


```python 
model = grid[9000]

plt.plot(model.binary_history['age'], model.binary_history['star_1_mass'], label='$M_1$')
plt.plot(model.binary_history['age'], model.history1['he_core_mass'], label='He core $M_1$')
plt.plot(model.binary_history['age'], model.history1['co_core_mass'], label='CO core $M_1$')
plt.plot(model.binary_history['age'], model.binary_history['star_2_mass'], label='$M_2$')
plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')
plt.xlim(5e6)
plt.legend()
plt.show()

```

</details>
</div>


<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 4 (click to reveal):</summary></b>


During the Case B mass transfer the mass ratio reverses, but the primary is not fully stripped of its envelope.
As a result, once core-helium burning ends, the star expands again, leading to a second phase of mass transfer (Case C)!

</details>
</div>

As you can see, the mass transfer in this model is very short in time, as expected for the Case B mass transfer we identified earlier.
You have to zoom in quite a lot to see the actual changes.
Plotting against `model_number` instead of age spreads these points out, but it distorts the time evolution.
Alternatively, we can use other information in the history to identify the mass transfer phase.
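For example, the Roche-lobe overflow column used in the next exercise can set the zoom window automatically. A sketch on made-up arrays standing in for `model.binary_history['age']` and `model.binary_history['rl_relative_overflow_1']`; if there are several mass transfer phases, this window spans all of them, so you would apply it per phase:

```python
import numpy as np

# Made-up arrays standing in for binary_history columns
age = np.array([0.0, 4.0e6, 5.00e6, 5.01e6, 5.02e6, 5.03e6, 7.0e6])
rl_relative_overflow_1 = np.array([-0.5, -0.2, -0.01, 0.02, 0.05, 0.01, -0.3])

# Same criterion as in the exercise: the donor overflows its Roche lobe
in_rlo = rl_relative_overflow_1 >= 0
t_start, t_end = age[in_rlo].min(), age[in_rlo].max()

# Pad the window by its own width so the onset and end stay visible;
# the limits can then be passed to plt.xlim(t_min, t_max).
pad = t_end - t_start
t_min, t_max = t_start - pad, t_end + pad
print(t_min, t_max)
```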

In [None]:
import matplotlib.pyplot as plt
import numpy as np

model = grid[9000]

delta_t = np.diff(model.binary_history['age'])

plt.plot(model.binary_history['model_number'], 10**model.binary_history['lg_mtransfer_rate'], label='instantaneous MT rate')
plt.plot(model.binary_history['model_number'][:-1], -np.diff(model.binary_history['star_1_mass'])/delta_t, label=r'-$\Delta M_1$')
plt.plot(model.binary_history['model_number'][:-1], np.diff(model.binary_history['star_2_mass'])/delta_t, label=r'$\Delta M_2$')

plt.ylabel(r'Mass change rate [$M_\odot$/yr]')
plt.xlabel('Model number')
plt.legend()
plt.show()

<div class="alert alert-success">

## Exercise: Extracting information from the timeseries

It can be useful to extract information straight from the timeseries of a binary.
For example, the efficiency of the mass transfer. In the previous exercise, you saw that the primary loses quite a lot of mass during the Case B mass transfer,
while the companion initially accretes quite a bit, but quickly stops accreting.
We want to know the efficiency of each mass transfer phase.

For this we need to know when the mass transfer starts and ends.

1. First find the boundaries of the mass transfer phases.
You can use the relative overflow of a star compared to its Roche lobe to determine the start and end of the mass transfer phase. (See hints at the end, if you get stuck here.)
2. Then, let's calculate the efficiency per phase. Since the stellar masses are used in the downsampling, their changes should be better tracked than the instantaneous mass transfer rate, so use them to calculate the efficiency.

</div>

In [None]:
# SOLUTION
RLO1_mask = model.binary_history['rl_relative_overflow_1'] >= 0

# Find indices where RLO1 changes from False to True (start of RLO phases)
rlo_start_indices =  # Fill in

# Find indices where RLO1 changes from True to False (end of RLO phases)  
rlo_end_indices = # Fill in

print(f"RLO starts at indices: {rlo_start_indices}")
print(f"RLO ends at indices: {rlo_end_indices}")

# If RLO1 starts with True, the first phase starts at index 0
if len(RLO1_mask) > 0 and RLO1_mask[0]:
    rlo_start_indices = np.concatenate([[0], rlo_start_indices])

# If RLO1 ends with True, the last phase ends at the final index
if len(RLO1_mask) > 0 and RLO1_mask[-1]:
    rlo_end_indices = np.concatenate([rlo_end_indices, [len(RLO1_mask) - 1]])

print(f"Complete RLO start indices: {rlo_start_indices}")
print(f"Complete RLO end indices: {rlo_end_indices}")

# Now we can extract the start and end indices of each RLO phase
RLO_phases = # Fill in
print(f"RLO phases (start_index, end_index): {RLO_phases}")

# Calculate the efficiency of each mass transfer phase
for start, end in RLO_phases:
    delta_M1 = # Fill in
    delta_M2 = # Fill in
    efficiency = delta_M2 / delta_M1 if delta_M1 != 0 else 0
    print(f"RLO phase from index {start} to {end}:")
    print(f"  ΔM1 = {delta_M1:.4f} M_sun")
    print(f"  ΔM2 = {delta_M2:.4f} M_sun")
    print(f"  Efficiency (ΔM2/ΔM1) = {efficiency:.4f}")

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution (click to reveal):</summary></b>

```python

# SOLUTION
RLO1_mask = model.binary_history['rl_relative_overflow_1'] >= 0

# Find indices where RLO1 changes from False to True (start of RLO phases)
rlo_start_indices = np.where(np.diff(RLO1_mask.astype(int)) == 1)[0] + 1

# Find indices where RLO1 changes from True to False (end of RLO phases)  
rlo_end_indices = np.where(np.diff(RLO1_mask.astype(int)) == -1)[0] + 1

print(f"RLO starts at indices: {rlo_start_indices}")
print(f"RLO ends at indices: {rlo_end_indices}")

# If RLO1 starts with True, the first phase starts at index 0
if len(RLO1_mask) > 0 and RLO1_mask[0]:
    rlo_start_indices = np.concatenate([[0], rlo_start_indices])

# If RLO1 ends with True, the last phase ends at the final index
if len(RLO1_mask) > 0 and RLO1_mask[-1]:
    rlo_end_indices = np.concatenate([rlo_end_indices, [len(RLO1_mask) - 1]])

print(f"Complete RLO start indices: {rlo_start_indices}")
print(f"Complete RLO end indices: {rlo_end_indices}")

# Now we can extract the start and end indices of each RLO phase
RLO_phases = list(zip(rlo_start_indices, rlo_end_indices))
print(f"RLO phases (start_index, end_index): {RLO_phases}")

# Calculate the efficiency of each mass transfer phase
for start, end in RLO_phases:
    delta_M1 = model.binary_history['star_1_mass'][start] - model.binary_history['star_1_mass'][end]
    delta_M2 = model.binary_history['star_2_mass'][end] - model.binary_history['star_2_mass'][start]
    efficiency = delta_M2 / delta_M1 if delta_M1 != 0 else 0
    print(f"RLO phase from index {start} to {end}:")
    print(f"  ΔM1 = {delta_M1:.4f} M_sun")
    print(f"  ΔM2 = {delta_M2:.4f} M_sun")
    print(f"  Efficiency (ΔM2/ΔM1) = {efficiency:.4f}")
    print()

```


</details>
</div>

# 2. Multiple systems

We're now shifting our attention to multiple binaries at the same time. You're going to select binaries with the same initial masses but different initial periods, and make figures based on their evolution and final values.


<div class="alert alert-success">

## Exercise: Plotting multiple binaries


1. Find all indices with the same initial masses as our 9000 example binary. There should be 34 binaries with the same initial masses.
2. Plot the mass evolution of each model in a single figure. **NOTE: sometimes not all histories are available! You will need to account for this.**
3. Add a colourmap to the lines based on initial period.
4. Based on the mass transfer stability, can you mark the end of the evolution of each model with unstable mass transfer with a diamond, and of each model with stable mass transfer with a star?

</div>

In [None]:
# SOLUTION 1. 
index = 9000
model = grid[index]
mass_primary = model.initial_values['star_1_mass']
mass_secondary = model.initial_values['star_2_mass']
mask_m1 = # ? 
mask_m2 = # ?
combined_mask = mask_m1 & mask_m2
indices = # ?
print(len(indices))

In [None]:
# SOLUTION 2.

# ?
 
plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.show()

In [None]:
# SOLUTION 3.
periods = #
sorted_indices = #

cmap = plt.get_cmap('viridis', len(sorted_indices))
norm = plt.Normalize(vmin=min(np.log10(periods)), vmax=max(np.log10(periods)))

for idx in sorted_indices:
    # history might be None. 
    #
    x = #
    y = #
    label = #
    plt.plot(x, y,
             color=cmap(norm(np.log10(grid[idx].initial_values['period_days']))),
             label=label)
    
plt.xlabel('Age [yr]')
plt.ylabel('Mass [$M_\odot$]')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.show()

In [None]:
# SOLUTION 4.
for idx in sorted_indices:
    # history might be None.

    if grid[idx].final_values['interpolation_class'] == 'unstable_MT':
        # Fill in
        
    elif grid[idx].final_values['interpolation_class'] == 'stable_MT':
        # Fill in
    else:
        # Fill in
    
plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.show()

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 1 (click to reveal):</summary></b>

```python
index = 9000
model = grid[index]
mass_primary = model.initial_values['star_1_mass']
mass_secondary = model.initial_values['star_2_mass']
mask_m1 = np.isclose(grid.initial_values['star_1_mass'], mass_primary, rtol=0.05)
mask_m2 = np.isclose(grid.initial_values['star_2_mass'], mass_secondary, rtol=0.05)
combined_mask = mask_m1 & mask_m2
indices = np.where(combined_mask)[0]
print(len(indices))

```

</details>
</div>

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 2 (click to reveal):</summary></b>

```python
for idx in indices:
    if grid[idx].binary_history is None:
        continue
    plt.plot(grid[idx].binary_history['age'],
             grid[idx].binary_history['star_1_mass'],
             label=f"$P={grid[idx].initial_values['period_days']:.2f}$ days")
 
plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.show()

```

</details>
</div>

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 3 (click to reveal):</summary></b>

```python
periods =  grid.initial_values['period_days'][indices]
sorted_indices = indices[np.argsort(periods)]

cmap = plt.get_cmap('viridis', len(sorted_indices))
norm = plt.Normalize(vmin=min(np.log10(periods)), vmax=max(np.log10(periods)))

for idx in sorted_indices:
    if grid[idx].binary_history is None:
        continue
    x = grid[idx].binary_history['age']
    y = grid[idx].binary_history['star_1_mass']
    plt.plot(x, y,
             color=cmap(norm(np.log10(grid[idx].initial_values['period_days']))),
             label=f"$P={grid[idx].initial_values['period_days']:.2f}$ days")
    
plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.show()
```

</details>
</div>


<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 4 (click to reveal):</summary></b>

```python
for idx in sorted_indices:
    if grid[idx].binary_history is None:
        continue
    period = grid[idx].initial_values['period_days']
    color = cmap(norm(np.log10(period)))  # colour by log initial period
    label = f"$P={period:.2f}$ days"
    mt_class = grid[idx].final_values['interpolation_class']

    if mt_class == 'unstable_MT':
        # dashed track ending in a diamond
        plt.plot(grid[idx].binary_history['age'], grid[idx].binary_history['star_1_mass'],
                 ls='--', color=color, label=label)
        plt.scatter(grid[idx].final_values['age'], grid[idx].final_values['star_1_mass'],
                    marker='D', color=color)
    elif mt_class == 'stable_MT':
        # solid track ending in a star
        plt.plot(grid[idx].binary_history['age'], grid[idx].binary_history['star_1_mass'],
                 ls='-', color=color, label=label)
        plt.scatter(grid[idx].final_values['age'], grid[idx].final_values['star_1_mass'],
                    marker='*', color=color)
    else:
        # no end marker for the remaining classes
        plt.plot(grid[idx].binary_history['age'], grid[idx].binary_history['star_1_mass'],
                 ls='-', color=color, label=label)

plt.xlabel('Age [yr]')
plt.ylabel(r'Mass [$M_\odot$]')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.show()
```

</details>
</div>

## Extracting data from multiple systems

Using the systems you've found in the previous exercise, let's find the mass transfer efficiency of each stable mass transfer binary.

The function `efficiency_per_phase` returns the mass transfer efficiency and the amount of mass $M_1$ has lost during each mass transfer phase.
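Writing $\beta_i = \Delta M_{2,i}/\Delta M_{1,i}$ for the efficiency of phase $i$, weighting each phase by the mass the donor lost gives

$$\bar{\beta} = \frac{\sum_i \beta_i\,\Delta M_{1,i}}{\sum_i \Delta M_{1,i}} = \frac{\sum_i \Delta M_{2,i}}{\sum_i \Delta M_{1,i}},$$

so the mass-weighted efficiency is simply the total mass accreted divided by the total mass lost by $M_1$. A quick numerical check with made-up numbers:

```python
import numpy as np

# Made-up mass changes for two mass transfer phases [Msun]
delta_M1 = np.array([12.0, 3.0])  # mass lost by the donor in each phase
delta_M2 = np.array([3.0, 0.3])   # mass gained by the accretor in each phase

efficiency = delta_M2 / delta_M1                     # per-phase efficiency
weighted = np.average(efficiency, weights=delta_M1)  # donor-mass weighted mean
total_ratio = delta_M2.sum() / delta_M1.sum()        # total gained / total lost
print(weighted, total_ratio)  # equal up to floating-point rounding
```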

<div class='alert alert-success'>

## Exercise: multiple binaries data extraction

Using the `efficiency_per_phase` function, calculate the mass-weighted efficiency
of each binary that you selected in the previous exercise and that undergoes stable mass transfer.

Does the mass-weighted efficiency change as a function of initial period?

</div>

In [None]:
def efficiency_per_phase(model):
    """
    Calculate the mass transfer efficiency for each mass transfer phase in a binary model.
    
    Parameters:
    model : PSyView
        Binary model to analyze.
    Returns:
    efficiencies : list of float
        List of mass transfer efficiencies for each phase.
    mass_loss : list of float
        List of mass lost by the donor star in each phase.
    """
    
    RLO1_mask = model.binary_history['rl_relative_overflow_1'] >= 0

    # Find indices where RLO1 changes from False to True (start of RLO phases)
    rlo_start_indices = np.where(np.diff(RLO1_mask.astype(int)) == 1)[0] + 1

    # Find indices where RLO1 changes from True to False (end of RLO phases)  
    rlo_end_indices = np.where(np.diff(RLO1_mask.astype(int)) == -1)[0] + 1

    # If RLO1 starts with True, the first phase starts at index 0
    if len(RLO1_mask) > 0 and RLO1_mask[0]:
        rlo_start_indices = np.concatenate([[0], rlo_start_indices])

    # If RLO1 ends with True, the last phase ends at the final index
    if len(RLO1_mask) > 0 and RLO1_mask[-1]:
        rlo_end_indices = np.concatenate([rlo_end_indices, [len(RLO1_mask) - 1]])

    # Now we can extract the start and end indices of each RLO phase
    RLO_phases = list(zip(rlo_start_indices, rlo_end_indices))

    efficiencies = []
    mass_loss = []
    # Calculate the efficiency of each mass transfer phase
    for start, end in RLO_phases:
        delta_M1 = model.binary_history['star_1_mass'][start] - model.binary_history['star_1_mass'][end]
        delta_M2 = model.binary_history['star_2_mass'][end] - model.binary_history['star_2_mass'][start]
        efficiency = delta_M2 / delta_M1 if delta_M1 != 0 else 0
        efficiencies.append(efficiency)
        mass_loss.append(delta_M1)  # mass lost by the donor, as documented above

    return efficiencies, mass_loss
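The phase-finding step in `efficiency_per_phase` handles runs that touch the first or last step with two special cases. An equivalent, more compact formulation pads the mask with `False` on both sides, so every run has both a rising and a falling edge. This is a sketch, not part of POSYDON; `rlo_phases` is a hypothetical name, and it reproduces the (start, end) convention used above.

```python
import numpy as np


def rlo_phases(mask):
    """Return (start, end) index pairs for every run of True values in `mask`.

    `end` is the first index after the run, or the last index when the run
    reaches the end of the array, matching the convention used above.
    """
    mask = np.asarray(mask, dtype=bool)
    # Pad with False on both sides so every run has a rising and a falling edge
    padded = np.concatenate([[False], mask, [False]]).astype(int)
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)   # first True index of each run
    stops = np.flatnonzero(edges == -1)   # first index after each run
    ends = np.minimum(stops, len(mask) - 1)
    return list(zip(starts.tolist(), ends.tolist()))


# Example on a made-up overflow mask with two RLO phases
mask = np.array([False, True, True, False, False, True, True])
print(rlo_phases(mask))  # [(1, 3), (5, 6)]
```

With a real run, the input would be `model.binary_history['rl_relative_overflow_1'] >= 0`.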

In [None]:
# SOLUTION
mass_primary = model.initial_values['star_1_mass']
mass_secondary = model.initial_values['star_2_mass']

mask1 = #
mask2 = #
stable_mask = #

combined_mask = mask1 & mask2 & stable_mask
indices = np.where(combined_mask)[0]
print(indices)

# sort the systems by initial period
periods = grid.initial_values['period_days'][indices]
sorted_indices = indices[np.argsort(periods)]

for idx in sorted_indices:
    print(grid[idx].final_values['termination_flag_2'])
    
    # get a mass weighted efficiency
    weighted_efficiency = #
    print(weighted_efficiency)

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution (click to reveal):</summary></b>

```python
mass_primary = model.initial_values['star_1_mass']
mass_secondary = model.initial_values['star_2_mass']

mask1 = np.isclose(grid.initial_values['star_1_mass'], mass_primary, atol=0.1)
mask2 = np.isclose(grid.initial_values['star_2_mass'], mass_secondary, atol=0.1)

stable_mask = grid.final_values['interpolation_class'] == 'stable_MT'
combined_mask = mask1 & mask2 & stable_mask
indices = np.where(combined_mask)[0]
print(indices)

periods = grid.initial_values['period_days'][indices]
sorted_indices = indices[np.argsort(periods)]


for idx in sorted_indices:
    print(grid[idx].final_values['termination_flag_2'])
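    efficiency, mass_loss = efficiency_per_phase(grid[idx])
    if np.sum(mass_loss) <= 0:
        continue  # no donor mass loss during RLO, so the weighted mean is undefined
    # get a mass-weighted efficiency
    weighted_efficiency = np.average(efficiency, weights=mass_loss)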
    print(weighted_efficiency)

```
</details>
</div>

# 3. Exploring a complete grid

So far we've focussed on the different building blocks of the `PSyGrid` object,
and this can get you very far in your explorations of the grids.

POSYDON provides a few additional features to make the exploration of grid slices easier.
The `initial_values` and `final_values` are available on a grid level, such that you do not need to loop over the individual binaries to get these values.
Hopefully, you've seen and used this in the previous section already.

Row `i` of `grid.initial_values` and `grid.final_values` corresponds to the run `grid[i]`.
This makes it easy to select systems with a mask on the grid-level arrays and then extract more information about each binary.
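A sketch of this pattern on small synthetic arrays (made-up values standing in for the grid-level `initial_values` and `final_values` columns):

```python
import numpy as np

# Synthetic grid-level arrays; row i describes run i
star_1_mass = np.array([10.0, 10.0, 20.0, 20.0, 30.0])
period_days = np.array([1.0, 10.0, 1.0, 10.0, 1.0])
interpolation_class = np.array(['stable_MT', 'unstable_MT', 'stable_MT',
                                'stable_MT', 'unstable_MT'])

# One boolean mask over the grid-level arrays ...
mask = np.isclose(star_1_mass, 20.0) & (interpolation_class == 'stable_MT')
indices = np.where(mask)[0]
print(indices)               # [2 3]

# ... and the same indices pick the matching entries from every other array
print(period_days[indices])  # [ 1. 10.]
# With a real grid, grid[i] for i in indices returns the matching runs.
```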

<div class='alert alert-success'>

## Exercise: plot 2D grid slice


Now you're going to create a 2D grid slice using the POSYDON grid with $M_1$ and $P$ as the x-axis and y-axis, respectively.
We fix the mass ratio at $q = M_2/M_1 = 0.8$.

To help you, the cell below sets up arrays of the unique grid values and selects only the systems with $q=0.8$.
The POSYDON HMS-HMS grids are spaced every $\Delta q=0.05$, so we have to make sure to select only systems in a single grid slice.
</div>
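To see why the tolerance in the next cell isolates a single slice, here is a check on made-up masses whose mass ratios lie on the same $\Delta q = 0.05$ spacing. Any `atol` below the spacing excludes the neighbouring slices; a value near half the spacing (0.025) would leave equal margin on both sides if the actual ratios scatter slightly around their nominal values.

```python
import numpy as np

# Made-up primary masses and mass ratios on a Delta q = 0.05 spacing
m1 = np.repeat([10.0, 20.0, 40.0], 3)
q_true = np.tile([0.75, 0.80, 0.85], 3)
m2 = m1 * q_true

q = 0.8
# Same selection as in the next cell: |M2/M1 - q| <= atol (+ a tiny rtol term)
mask_q = np.isclose(m2 / m1, q, atol=0.04)
print(np.unique(np.round(m2[mask_q] / m1[mask_q], 2)))  # [0.8]
```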

In [None]:
import numpy as np
m1_unique = np.unique(grid.initial_values['star_1_mass'])
m2_unique = np.unique(grid.initial_values['star_2_mass'])
P_unique = np.unique(grid.initial_values['period_days'])

q = 0.8

# np.isclose(..., atol=0.04) keeps mass ratios within q ± 0.04 (plus a tiny default rtol term)
mask_q = np.isclose(grid.initial_values['star_2_mass']/grid.initial_values['star_1_mass'], q, atol=0.04)
indices = np.where(mask_q)[0]

In [None]:
x = #
y = #
plt.scatter(x,
            y,
            marker = '.', color='black')

plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$M_1$ [$M_\odot$]')
plt.ylabel('$P$ [days]')
plt.text(0.95, 0.95, f'$q={q}$', transform=plt.gca().transAxes,
         ha='right', va='top',
         bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', lw=0.5, alpha=1))

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution (click to reveal):</summary></b>

```python
x = grid.initial_values['star_1_mass'][indices]
y = grid.initial_values['period_days'][indices]
plt.scatter(x,
            y,
            marker = '.', color='black')

plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$M_1$ [$M_\odot$]')
plt.ylabel('$P$ [days]')
plt.text(0.95, 0.95, f'$q={q}$', transform=plt.gca().transAxes,
         ha='right', va='top',
         bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', lw=0.5, alpha=1))

```
</details>
</div>

You should now have plotted a single grid slice showing where all the MESA models lie in the $M_1$-$P$ ZAMS parameter space. It should also be obvious that the grid is denser below ${\sim}30\,M_\odot$, which makes plotting more complex but benefits the synthesis of stellar populations.

Let's add some more information to the grid slice plotting.
Since one of the main POSYDON features is the self-consistent tracking of rotation in the MESA models, let's look at the rotation of the companion star at the end of each MESA run.
You can add this as a colour map to the scatter plot.


In [None]:

plt.scatter(grid.initial_values['star_1_mass'][indices],
            grid.initial_values['period_days'][indices],
            c=grid.final_values['S2_surf_avg_omega_div_omega_crit'][indices],
            marker = '.', cmap='viridis', vmin=0, vmax=1)
            
plt.colorbar(label='$S_2$ spin parameter')
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$M_1$ [$M_\odot$]')
plt.ylabel('$P$ [days]')
plt.text(0.95, 0.95, f'$q={q}$', transform=plt.gca().transAxes,
         ha='right', va='top',
         bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', lw=0.5, alpha=1))

<div class="alert alert-success">
Do you notice any difference between the two grid plots?

</div>

At short periods, some systems are missing! What is going on here?
The explanation becomes clear once we plot the type of interaction each model undergoes.
For this, we use the `interpolation_class` and `termination_flag_1` of the models.

In [None]:

# only list the flags present in this slice, so the legend has no empty entries
unique_flags = np.unique(grid.final_values['interpolation_class'][indices])

for flag in unique_flags:
    mask_flag = grid.final_values['interpolation_class'][indices] == flag
    plt.scatter(grid.initial_values['star_1_mass'][indices][mask_flag],
                grid.initial_values['period_days'][indices][mask_flag],
                marker='.', label=str(flag))

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$M_1$ [$M_\odot$]')
plt.ylabel('$P$ [days]')
plt.text(0.95, 0.95, f'$q={q}$', transform=plt.gca().transAxes,
         ha='right', va='top',
         bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', lw=0.5, alpha=1))

In [None]:

# only list the flags present in this slice, so the legend has no empty entries
unique_flags = np.unique(grid.final_values['termination_flag_1'][indices])

for flag in unique_flags:
    mask_flag = grid.final_values['termination_flag_1'][indices] == flag
    plt.scatter(grid.initial_values['star_1_mass'][indices][mask_flag],
                grid.initial_values['period_days'][indices][mask_flag],
                marker='.', label=str(flag))

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$M_1$ [$M_\odot$]')
plt.ylabel('$P$ [days]')
plt.text(0.95, 0.95, f'$q={q}$', transform=plt.gca().transAxes,
         ha='right', va='top',
         bbox=dict(boxstyle='round', facecolor='white', edgecolor='black', lw=0.5, alpha=1))
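To see how many models fall into each category, and hence how many are missing from the spin plot, you can also count the flags in the slice. A minimal sketch, using a hypothetical helper `class_counts` (not part of POSYDON) and made-up labels standing in for `grid.final_values['interpolation_class'][indices]`:

```python
import numpy as np

def class_counts(labels):
    """Return a {label: count} dict for an array of categorical flags.

    Hypothetical helper (not part of POSYDON). Labels are cast to str so that
    object arrays containing None can still be counted.
    """
    names, counts = np.unique(np.asarray(labels).astype(str), return_counts=True)
    return dict(zip(names.tolist(), counts.tolist()))

# Made-up labels for illustration only
example = np.array(['no_MT', 'stable_MT', 'initial_MT', 'stable_MT', 'initial_MT'])
print(class_counts(example))
```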

The figures show that the models at the bottom of the slice (the shortest periods) are missing from the $S_2$ spin figure because they start in Roche-lobe overflow and are therefore not evolved!
These figures are fine for debugging and detailed inspection, but they pack a lot of information into a single figure and can be overwhelming.
To simplify grid-slice plotting, the `PSyGrid` class provides a built-in 2D plotting method, `plot2D`.
You can find more information about this function here: [plot2D](https://posydon.org/POSYDON/latest/tutorials-examples/generating-datasets/plot_2D.html)
The cells below reproduce the figures above.

In [None]:
# similar to the 1D plotting method, we can pass a dictionary with additional plotting properties
PLOT_PROPERTIES = {
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=None,

    # what termination flag to color the points by
    termination_flag='interpolation_class',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)

You can also plot a `final_values` column of the grid by passing its name as `z_var_str`, but this currently only works with `termination_flag='termination_flag_1'`.

In [None]:
# similar to the 1D plotting method, we can pass a dictionary with additional plotting properties
PLOT_PROPERTIES = {
    'figsize': (7,6.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str='S2_surf_avg_omega_div_omega_crit',
    
    # what termination flag to color the points by
    termination_flag='termination_flag_1',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)

A particularly useful option is the `"combined_TF12"` termination flag, which combines the last interaction of the system with its stability (based on the `interpolation_class`) to show in detail which types of interaction each system underwent. It also still marks the failed models and the models with initial RLO.

In [None]:
# similar to the 1D plotting method, we can pass a dictionary with additional plotting properties
PLOT_PROPERTIES = {
    'figsize': (6,5.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=None,
    
    # what termination flag to color the points by
    termination_flag='combined_TF12',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)

`plot2D` also makes it easier to plot categorical data, such as the outcomes of the `SN_MODEL`s.

POSYDON comes with a variety of pre-calculated supernova remnant-mass prescriptions, computed from the detailed stellar profiles at carbon depletion. This makes the collapse calculation more accurate than would be possible with the downsampled profiles in the public grids, which is especially important for the BH spins.

Categorical data is passed through the `termination_flag` argument instead of `z_var_str`, which can be confusing at first.

In [None]:
PLOT_PROPERTIES = {
    'figsize': (6,5.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=None,
    
    # what termination flag to color the points by
    termination_flag='S1_SN_MODEL_v2_01_CO_type',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)

<div class="alert alert-success">

## Exercise: plot different q values

Make two figures with the built-in `plot2D` function:
1. A mass transfer stability plot at q=0.4
2. The state of star 2 at the end of the MESA model at q=0.99
</div>

In [None]:
# EXERCISE 1: set q first so that the title below is correct
q = 0.4
PLOT_PROPERTIES = {
    'figsize': (5,4.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

# ?

In [None]:
# EXERCISE 2: set q first so that the title below is correct
q = 0.99
PLOT_PROPERTIES = {
    'figsize': (5,4.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

# ?

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 1 (click to reveal):</summary></b>

```python
# SOLUTION
q = 0.4

PLOT_PROPERTIES = {
    'figsize': (5,4.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=None,
    
    # what termination flag to color the points by
    termination_flag='interpolation_class',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)
```
</details>
</div>

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution 2 (click to reveal):</summary></b>

```python
# SOLUTION
q = 0.99

PLOT_PROPERTIES = {
    'figsize': (5,4.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=None,
    
    # what termination flag to color the points by
    termination_flag='termination_flag_4',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)
```
</details>
</div>

<div class='alert alert-success'>

## Exercise: customising the grid plot

Let's change our figure to plot the mass ratio on the x-axis and the period on the y-axis, for a slice in primary mass around $10\,M_\odot$.
You will need to choose a mass range around this value as part of the slice.
Look up the grid sizes and spacing in [Andrews et al. (2025)](https://ui.adsabs.harvard.edu/abs/2024arXiv241102376A/abstract) to figure out a suitable width.
</div>
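If you prefer to derive the slice width from the grid itself rather than from the paper, one option is to centre the slice on the grid mass closest to your target and use half of the smallest neighbouring gap as the width, so that no neighbouring grid mass falls inside the slice. A sketch with a hypothetical helper, `slice_around` (not part of POSYDON), tested on synthetic log-spaced masses; rounding to `decimals` places is an assumption to merge masses that differ only by floating-point noise.

```python
import numpy as np

def slice_around(values, target, decimals=3):
    """Centre and half-width of a slice containing only the grid value nearest `target`.

    Hypothetical helper (not part of POSYDON). Values are rounded to `decimals`
    places before finding unique grid points; at least two distinct values
    are assumed.
    """
    unique = np.unique(np.round(np.asarray(values, dtype=float), decimals))
    i = int(np.argmin(np.abs(unique - target)))      # nearest grid point
    neighbours = unique[max(i - 1, 0): i + 2]        # that point and its neighbours
    half_width = 0.5 * np.diff(neighbours).min()     # half the smallest adjacent gap
    return float(unique[i]), float(half_width)

# Synthetic log-spaced primary masses
masses = np.geomspace(7.0, 120.0, 30)
centre, width = slice_around(masses, target=10.0)
print(centre, width)
# With the real grid: slice_around(grid.initial_values['star_1_mass'], 10.0), then
# slice_3D_var_range=(centre - width, centre + width)
```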

In [None]:
M_1 = 10
delta_M = ...  # fill in the half-width (in Msun) of the mass slice you want to use

PLOT_PROPERTIES = {
    'figsize': (5,4.),
    'log10_x' : False,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': rf'$M_1\approx{M_1}\,M_\odot$',
}

# ?

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution (click to reveal):</summary></b>

```python
# SOLUTION
M_1 = 10
delta_M = 0.5

PLOT_PROPERTIES = {
    'figsize': (5,4.),
    'log10_x' : False,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': rf'$M_1\approx{M_1}\,M_\odot$',
}

grid.plot2D(
    # which parameters to plot
    x_var_str='mass_ratio',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=None,
    
    # what termination flag to color the points by
    termination_flag='interpolation_class',
            
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='star_1_mass',
    # select only the models with M_1 +/- delta_M
    # (the same idea as mask_q above, but slicing in primary mass)
    slice_3D_var_range=(M_1-delta_M, M_1+delta_M),
    **PLOT_PROPERTIES)


```
</details>
</div>

Well done! 

There's one final skill to learn!
You're now going to combine what you learned from looking at single MESA tracks with the 2D plotting.
Using the function that extracts the mass-transfer efficiency (`efficiency_per_phase`), we will compute the efficiency for every model in the grid and show it on a grid slice.
There are two ways of adding such information to a grid figure: you can either pass the values directly as `z_var_str`, or store them as a new column in the grid.

In [None]:
# This code calculates the mass transfer efficiencies for all models in the grid
# and stores the average efficiency (weighted by mass accreted) in an array.
from tqdm import tqdm

efficiencies = np.zeros(grid.n_runs)
for i in tqdm(range(grid.n_runs)):
    if grid[i].binary_history is not None:
        eff, mass = efficiency_per_phase(grid[i])
        if len(mass) != 0 and np.sum(mass) != 0.0:
            average = np.average(eff, weights=mass)
        else:
            average = np.nan
    else:
        average = np.nan
    efficiencies[i] = average
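The averaging step inside the loop can be pulled out into a small function, which makes its edge cases (no mass-transfer phases, zero total mass) easy to check in isolation. A sketch with a hypothetical name, `average_efficiency` (not part of POSYDON); like the loop, it assumes `efficiency_per_phase` returns one efficiency per phase together with matching mass weights.

```python
import numpy as np

def average_efficiency(eff, mass):
    """Mass-weighted mean efficiency, or NaN if there is nothing to weight.

    Hypothetical helper mirroring the loop above: `eff` holds one efficiency
    per mass-transfer phase and `mass` the corresponding weights.
    """
    eff = np.asarray(eff, dtype=float)
    mass = np.asarray(mass, dtype=float)
    # np.average raises ZeroDivisionError if the weights sum to zero
    if mass.size == 0 or np.sum(mass) == 0.0:
        return np.nan
    return float(np.average(eff, weights=mass))

print(average_efficiency([0.2, 0.8], [1.0, 3.0]))  # (0.2*1 + 0.8*3) / 4
print(average_efficiency([], []))                  # no phases -> nan
```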

In [None]:
# Plotting the average efficiencies on the grid
q = 0.8
PLOT_PROPERTIES = {
    'figsize': (6,5.),
    'log10_x' : True,
    'log10_y' : True,
    'show_fig': True,
    'close_fig': True,
    'title': f'$q={q}$',
    'zmin':0.0,
    'zmax':1.0,
    'colorbar': {'label':'Average MT efficiency'}
}

grid.plot2D(
    # which parameters to plot
    x_var_str='star_1_mass',
    y_var_str='period_days',
    # z_var parameter to color the points by (only compatible with termination_flag_1)
    z_var_str=efficiencies,
    # what termination flag to color the points by
    termination_flag='termination_flag_1',
    # set that the 3D grid should be used and what parameter to slice on.
    grid_3D=True, slice_3D_var_str='mass_ratio',
    # select only the models with q +/- 0.03
    # similar to mask_q above
    slice_3D_var_range=(q-0.03, q+0.03),
    **PLOT_PROPERTIES)

Adding the values to the grid requires a bit more work, but it lets you store the results of long calculations in the same file for later use.
Because multiple people are using the same files in `PATH_TO_POSYDON_DATA`, we should not write to them while others are working on them.
So we will first create a copy of the grid file as an example:

In [None]:
# Copy the grid file; this can take a while.
# Because multiple people might have the original file open, we work on a copy.
import shutil
shutil.copyfile(grid_file_path, './1e-02_Zsun_with_beta_MT.h5')

In [None]:
grid = PSyGrid('./1e-02_Zsun_with_beta_MT.h5')
# Store the per-model efficiencies as a new 'beta_MT' column in the copied grid file
grid.add_column('beta_MT', efficiencies)

In [None]:
# Load the grid again to verify the values are there!
grid = PSyGrid('./1e-02_Zsun_with_beta_MT.h5')
print('beta_MT' in grid.final_values.dtype.names)
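Beyond checking that the column exists, you can compare the stored values with the array you computed. Because many models have `NaN` efficiencies, the comparison must treat `NaN == NaN` as a match. A sketch with a hypothetical helper, `column_matches` (not part of POSYDON); on the real grid you would call `column_matches(grid.final_values['beta_MT'], efficiencies)`, assuming `add_column` wrote the values into `final_values`.

```python
import numpy as np

def column_matches(stored, computed):
    """True if two arrays have the same shape and agree element-wise, with NaN == NaN.

    Hypothetical helper (not part of POSYDON).
    """
    stored = np.asarray(stored, dtype=float)
    computed = np.asarray(computed, dtype=float)
    if stored.shape != computed.shape:
        return False
    return bool(np.allclose(stored, computed, equal_nan=True))

computed = np.array([0.3, np.nan, 0.9])
print(column_matches(computed.copy(), computed))
```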

<div class="alert alert-success">
    
## Exercise: create a 2D grid slice plot with beta_MT

Using the skills you've learned, can you use `grid.plot2D` to create a grid-slice plot at 
q=0.7 that shows the average efficiency (`beta_MT`) of the mass-transfer phases?
</div>

In [None]:
# ? 

<div class="alert alert-warning" style="margin-top: 20px">
<details>
    
<b><summary>Solution (click to reveal):</summary></b>

```python
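# SOLUTION
q = 0.7
# reuse PLOT_PROPERTIES from the efficiency plot above, updating the title for the new slice
PLOT_PROPERTIES['title'] = f'$q={q}$'
plot = grid.plot2D(x_var_str='star_1_mass', y_var_str='period_days', z_var_str='beta_MT',
            termination_flag='termination_flag_1',
            grid_3D=True, slice_3D_var_str='mass_ratio',
            slice_3D_var_range=(q-0.03, q+0.03),
            legend_pos = (3,3),
            verbose=True, **PLOT_PROPERTIES)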
```
</details>