<table align="center" style="text-align:center; border-collapse:collapse; border-spacing:0; width:100%;">
    <tr>
        <td style="width:25%; padding:0;">
            <img src="../docs/_static/astronuc-header.png" style="max-width:100%;" />
        </td>
        <td style="width:50%; padding:0;">
            <h1 style="font-size:50px">
                AstroNuc 2026<br><code>cogsworth</code> tutorial
            </h1>
            <h2 style="font-size:20px;">
                <i>led by <a href="https://www.tomwagg.com">Tom Wagg</a> (Postdoc at the Flatiron Institute)</i>
            </h2>
            <p style="font-size:15px;">
                This lab focuses on how to track the timing and location of SNe in galaxies with <code>cogsworth</code>
            </p>
        </td>
        <td style="width:25%; padding:0;">
            <img src="../docs/_static/astronuc-header.png" style="max-width:100%;" />
        </td>
    </tr>
</table>

In [1]:
import cogsworth
import gala.potential as gp
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
import pandas as pd

In [2]:
# this all just makes plots look nice
%config InlineBackend.figure_format = 'retina'

plt.rc('font', family='serif')
plt.rcParams['text.usetex'] = False
fs = 24

# update various fontsizes to match
params = {'figure.figsize': (12, 8),
          'legend.fontsize': fs,
          'axes.labelsize': fs,
          'xtick.labelsize': 0.9 * fs,
          'ytick.labelsize': 0.9 * fs,
          'axes.linewidth': 1.1,
          'xtick.major.size': 7,
          'xtick.minor.size': 4,
          'ytick.major.size': 7,
          'ytick.minor.size': 4}
plt.rcParams.update(params)

# this makes sure every column in Pandas dataframes is shown
pd.set_option('display.max_columns', None)

<hr>

# Part 1: Your first population

## Demo

More detailed explanations of all of this code can be found [on the lab page](https://teamcogsworth.github.io/cogsworth-school/pages/labs/astronuc/part-1.html#demo).

### Initialise a Population

In [3]:
p = cogsworth.pop.Population(
    n_binaries=1000,
    use_default_BSE_settings=True
)
p

### Initial sampling

In [4]:
p.sample_initial_binaries()

In [5]:
p.initial_binaries.head()

In [6]:
p.initial_galaxy

In [7]:
print(p.initial_galaxy.positions)
print(p.initial_galaxy.tau)

### Stellar evolution

In [8]:
p.perform_stellar_evolution()

In [9]:
p.bpp.loc[:5]

In [10]:
p.kick_info.loc[:5]

In [11]:
p.final_bpp.loc[:5]

### Galactic orbit integration

In [12]:
p.perform_galactic_evolution()

In [13]:
p.orbits[:5]

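### Future shortcut

In future, you can run all three steps above (sampling, stellar evolution and galactic orbit integration) in one go with `p.create_population()`.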

### Inspect the most massive binary

Let's find the binary with the most massive primary star at ZAMS and take a look at its evolution.

In [14]:
most_massive = p.bin_nums[p.initial_binaries["mass_1"].argmax()]
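`argmax()` returns the *position* of the largest value, which is why it is passed through `p.bin_nums` to get a `bin_num` label. A minimal pandas sketch on a made-up table, assuming (as the cell above does) that `p.initial_binaries` is indexed by `bin_num` in the same order as `p.bin_nums`, shows that `idxmax()` returns that label directly:

```python
import pandas as pd

# toy stand-in for p.initial_binaries: indexed by (non-contiguous) bin_num, values are made up
initial_binaries = pd.DataFrame(
    {"mass_1": [12.0, 45.0, 8.5, 30.0]},
    index=pd.Index([3, 7, 11, 20], name="bin_num"),
)
bin_nums = initial_binaries.index.values

# argmax() gives a position, which we map back to a bin_num
pos = initial_binaries["mass_1"].argmax()
most_massive = bin_nums[pos]

# idxmax() returns the index label (the bin_num) directly
assert most_massive == initial_binaries["mass_1"].idxmax()
print(pos, most_massive)
```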

In [15]:
p.bpp.loc[most_massive]

In [16]:
p.kick_info.loc[most_massive]

In [17]:
fig, ax = p.plot_cartoon_binary(bin_num=most_massive)

In [18]:
fig, axes = p.plot_orbit(bin_num=most_massive)

In [19]:
fig, axes = p.plot_orbit(bin_num=most_massive, t_max=100 * u.Myr)

## Tasks

### Task 1.1: Your own population

To start, initialise a population with 1000 binaries, then sample the initial binaries, evolve them, and integrate their orbits.

What are the initial properties of the first few binaries in the population?

In [20]:
# your code here

### Task 1.2: Distributions

Now let's make some plots. First, what does **the distribution of galactic birth times** look like for the binaries in the population?

In [21]:
# your code here

### Task 1.3: Your favourite binary

Now pick a binary of interest to you, inspect its evolution with a cartoon plot, and look at its orbit through the galaxy.

Some inspiration for picking a binary:

- The most massive binary in the population
- A binary that ends by creating at least one neutron star
- A random binary!

In [22]:
# your code here

<hr>

# Part 2: Selecting subpopulations of interest

In [23]:
p = cogsworth.pop.Population(
    n_binaries=10000,
    use_default_BSE_settings=True
)
p.create_population()

### Inspect initial conditions

In [24]:
fig, ax = plt.subplots()
ax.scatter(p.initial_binaries["mass_1"], p.initial_binaries["mass_2"],
           s=1, rasterized=True)

ax.set(
    xscale="log",
    yscale="log",
    xlabel="Initial primary mass, $M_{1, i}$ [M$_\odot$]",
    ylabel="Initial secondary mass, $M_{2, i}$ [M$_\odot$]",
)
plt.show()

### Mask based on initial conditions

In [25]:
mass_ratio_mask = (p.initial_binaries["mass_2"] / p.initial_binaries["mass_1"]) < 0.5

In [26]:
selected_bin_nums = p.bin_nums[mass_ratio_mask]
print(selected_bin_nums)

In [27]:
# mask with bin_nums
new_pop = p[selected_bin_nums]

In [28]:
# mask with boolean array with same length as p.bin_nums
new_pop = p[mass_ratio_mask]
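Both styles select the same binaries. As a plain-pandas analogy only (this does not use cogsworth's own indexing, and the table values are made up), a boolean mask and the `bin_num` labels it picks out select identical rows:

```python
import pandas as pd

# toy initial-conditions table indexed by bin_num (values are made up)
initial_binaries = pd.DataFrame(
    {"mass_1": [10.0, 20.0, 15.0, 40.0], "mass_2": [9.0, 4.0, 12.0, 10.0]},
    index=pd.Index([0, 1, 2, 3], name="bin_num"),
)
bin_nums = initial_binaries.index.values

q = initial_binaries["mass_2"] / initial_binaries["mass_1"]
mass_ratio_mask = (q < 0.5).values  # boolean array, same length as bin_nums

# option 1: select rows with the boolean mask
by_mask = initial_binaries[mass_ratio_mask]

# option 2: convert the mask to bin_nums, then select rows by label
selected_bin_nums = bin_nums[mass_ratio_mask]
by_label = initial_binaries.loc[selected_bin_nums]

assert by_mask.equals(by_label)
print(selected_bin_nums)
```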

In [29]:
# how many binaries met this?
print(len(new_pop))

# range of mass ratios in this population (all should be below 0.5)
q = new_pop.initial_binaries["mass_2"] / new_pop.initial_binaries["mass_1"]
print(q.min(), q.max())

### Mask based on final state

In [30]:
primary_ends_as_wd = p.final_bpp["kstar_1"].isin([10, 11, 12])
secondary_ends_as_wd = p.final_bpp["kstar_2"].isin([10, 11, 12])
has_a_wd = primary_ends_as_wd | secondary_ends_as_wd
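The integers passed to `isin` are BSE/COSMIC stellar-type (`kstar`) codes: 10, 11 and 12 are helium, carbon-oxygen and oxygen-neon white dwarfs, while 13 and 14 are neutron stars and black holes. A small sketch of wrapping this pattern in a reusable helper; `has_remnant` is a hypothetical name, and the `final_bpp`-like table is made up:

```python
import pandas as pd

# BSE/COSMIC stellar-type (kstar) codes for stellar remnants
WHITE_DWARF = [10, 11, 12]   # He WD, CO WD, ONe WD
COMPACT_OBJECT = [13, 14]    # neutron star, black hole

def has_remnant(final_bpp, kstars):
    """Hypothetical helper: True for binaries where either star ends with a kstar in `kstars`."""
    return final_bpp["kstar_1"].isin(kstars) | final_bpp["kstar_2"].isin(kstars)

# made-up final states for four binaries
final_bpp = pd.DataFrame({"kstar_1": [11, 14, 1, 13], "kstar_2": [0, 13, 12, 15]})

has_a_wd = has_remnant(final_bpp, WHITE_DWARF)
has_a_co = has_remnant(final_bpp, COMPACT_OBJECT)
print(has_a_wd.tolist())  # [True, False, True, False]
print(has_a_co.tolist())  # [False, True, False, True]
```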

In [31]:
wd_pop = p[has_a_wd]
print(wd_pop)
print(wd_pop.final_bpp.head())

## Tasks

### Task 2.1: Initial conditions of mergers

#### Task 2.1.1: Initial scatter plot

First, make a plot of the initial orbital period vs the initial primary mass for all binaries in the population.

In [32]:
# your code here

#### Task 2.1.2: Mask mergers

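Now create a mask that selects the binaries that merge. Merged binaries have a final separation of zero, and the final separation is given by the ``sep`` column in the ``final_bpp`` table, which you can access with ``p.final_bpp``.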

In [33]:
# your code here

#### Task 2.1.3: Highlight mergers on plot

Now, update your plot to highlight the binaries that will eventually merge (however you like: outline the merger points, overplot them in a different color, etc.).

In [34]:
# your code here

#### Task 2.1.4: Plot discussion

What trends do you notice in your plot? Which conditions seem to lead to mergers? Why?

### Task 2.2: [BONUS] Final positions of compact objects

#### Task 2.2.1: Final positions plot

First, make a plot of the final positions of the primary star from each binary in the population. Plot the Galactocentric radius ($R = \sqrt{x^2 + y^2}$) on the x-axis and the absolute Galactocentric height ($|z|$) on the y-axis. I recommend using a log-scale for both axes.
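Computing these coordinates from Cartesian positions takes one NumPy call each. A sketch on made-up positions, assuming you have collected them into an `(N, 3)` array of $x$, $y$, $z$ in kpc:

```python
import numpy as np

# made-up Galactocentric positions in kpc, one row per binary: columns are x, y, z
positions = np.array([
    [3.0, 4.0, -0.2],
    [-6.0, 8.0, 1.5],
    [0.5, -1.2, 0.05],
])
x, y, z = positions.T

R = np.hypot(x, y)   # cylindrical radius, sqrt(x^2 + y^2)
abs_z = np.abs(z)    # absolute height above or below the disc

print(R, abs_z)
```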

In [35]:
# your code here

#### Task 2.2.2: Mask compact objects

Now, create a mask that selects only binaries where either star ends as a neutron star or black hole (i.e. that receive a natal kick).

In [36]:
# your code here

#### Task 2.2.3: Highlight compact objects

Now, update your plot to highlight the binaries where the primary star ends as a neutron star or black hole.

In [37]:
# your code here

#### Task 2.2.4: Discuss trends

What trends do you notice in your plot? Do the compact objects seem to have different final positions than the rest of the population? Is that true for all of them? Why/why not?

In [38]:
# your words here

<hr>

# Part 3: Finding the timing and location of SNe

## Demo

In [39]:
p = cogsworth.pop.Population(
    n_binaries=10000,
    use_default_BSE_settings=True,
    final_kstar1=[13, 14],
    final_kstar2=[13, 14]
)
p.create_population()

In [73]:
# evol_type 15 marks a supernova of the primary, 16 a supernova of the secondary
primary_sn = p.bpp["evol_type"] == 15
secondary_sn = p.bpp["evol_type"] == 16
sn_mask = primary_sn | secondary_sn

In [74]:
bins = np.linspace(0, 200, 40)

fig, ax = plt.subplots()

ax.hist(p.bpp["tphys"][primary_sn], bins=bins, density=True, label="Primary SN")
ax.hist(p.bpp["tphys"][secondary_sn], bins=bins, alpha=0.7, density=True, label="Secondary SN")
ax.set(
    xlabel="Time in frame of binary, $t_b$ [Myr]",
    ylabel=r"${\rm d}N/{\rm d}t_b$",
)
ax.legend()
plt.show()

In [75]:
# get the bin_nums of the supernova events
primary_sn_bin_nums = p.bpp["bin_num"][primary_sn]
secondary_sn_bin_nums = p.bpp["bin_num"][secondary_sn]

# get the indices of these bin_nums in the p.bin_nums array
primary_sn_indices = np.searchsorted(p.bin_nums, primary_sn_bin_nums)
secondary_sn_indices = np.searchsorted(p.bin_nums, secondary_sn_bin_nums)

# use these indices to get tau
primary_sn_tau = p.initial_galaxy.tau[primary_sn_indices]
secondary_sn_tau = p.initial_galaxy.tau[secondary_sn_indices]

# compute the galactic times
primary_sn_t_gal = p.max_ev_time - primary_sn_tau + p.bpp["tphys"][primary_sn].values * u.Myr
secondary_sn_t_gal = p.max_ev_time - secondary_sn_tau + p.bpp["tphys"][secondary_sn].values * u.Myr

In [76]:
bins = np.linspace(0, 12, 20)
fig, ax = plt.subplots()
ax.hist(primary_sn_t_gal.to(u.Gyr).value, bins=bins, density=True,
        label="Primary SN")
ax.hist(secondary_sn_t_gal.to(u.Gyr).value, bins=bins, alpha=0.7,
        density=True, label="Secondary SN")
ax.set(
    xlabel=r"Age of Milky Way at SN, $t_{\rm gal}$ [Gyr]",
    ylabel=r"${\rm d}N/{\rm d}t_{\rm gal}$",
)
ax.legend()
plt.show()

In [44]:
# let's take the first orbit
orbit_example = p.orbits[0]

# it stores the time and position at each timestep
print(orbit_example.t)
print(orbit_example.pos.xyz)

In [77]:
primary_sn_positions = np.zeros((len(primary_sn_indices), 3)) * u.kpc
secondary_sn_positions = np.zeros((len(secondary_sn_indices), 3)) * u.kpc

for i in range(len(primary_sn_indices)):
    # find the corresponding orbit
    primary_sn_orbit = p.primary_orbits[primary_sn_indices[i]]

    # compute the last timestep where orbit.t is less than primary_sn_t_gal[i]
    closest_time_index = np.where(primary_sn_orbit.t < primary_sn_t_gal[i])[0][-1]

    # get the position of the binary at this time
    primary_sn_positions[i] = primary_sn_orbit.pos.xyz[:, closest_time_index]

# same for secondaries
for i in range(len(secondary_sn_indices)):
    secondary_sn_orbit = p.secondary_orbits[secondary_sn_indices[i]]
    closest_time_index = np.where(secondary_sn_orbit.t < secondary_sn_t_gal[i])[0][-1]
    secondary_sn_positions[i] = secondary_sn_orbit.pos.xyz[:, closest_time_index]
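The loops take the last stored timestep before each supernova, so positions are only accurate to within one orbit timestep. A NumPy sketch on a made-up orbit showing an equivalent `np.searchsorted` lookup (valid when the times are strictly increasing) and, as an alternative not used in the demo, linear interpolation of each coordinate to the exact event time:

```python
import numpy as np

# made-up orbit: times (Gyr) and x, y, z positions (kpc) at each timestep, shape (3, n_steps)
t = np.array([8.0, 8.5, 9.0, 9.5, 10.0])
xyz = np.array([
    [8.0, 5.7, 0.0, -5.7, -8.0],   # x
    [0.0, 5.7, 8.0, 5.7, 0.0],     # y
    [0.1, 0.0, -0.1, 0.0, 0.1],    # z
])
t_sn = 9.2  # event time in the same units as t

# last timestep strictly before the event, same as np.where(t < t_sn)[0][-1]
i_before = np.searchsorted(t, t_sn, side="left") - 1
assert i_before == np.where(t < t_sn)[0][-1]

# alternative: linearly interpolate each coordinate to the exact event time
pos_interp = np.array([np.interp(t_sn, t, coord) for coord in xyz])

print(i_before, xyz[:, i_before], pos_interp)
```

Linear interpolation is itself an approximation of a curved orbit, but with short timesteps it is usually closer than snapping to the previous step.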

In [78]:
fig, axes = plt.subplots(2, 1, figsize=(8, 9), gridspec_kw={"height_ratios": [1, 4]})

# axis limits in kpc (constant, so set once outside the loop)
XMAX = 30
ZMAX = 7.5

for pos, times in zip(
    [primary_sn_positions, secondary_sn_positions],
    [primary_sn_t_gal, secondary_sn_t_gal],
):
    # edge-on view (x-z plane), coloured by the age of the Milky Way at the SN
    axes[0].scatter(
        pos[:, 0], pos[:, 2],
        c=times.to(u.Gyr).value, s=5,
        cmap="magma", vmin=0, vmax=12
    )

    # face-on view (x-y plane)
    axes[1].scatter(
        pos[:, 0], pos[:, 1],
        c=times.to(u.Gyr).value, s=5,
        cmap="magma", vmin=0, vmax=12
    )
axes[0].set(
    ylabel="$z$ [kpc]",
    xlim=(-XMAX, XMAX),
    ylim=(-ZMAX, ZMAX),
    aspect="equal",
)
axes[1].set(
    xlabel="Galactocentric $x$ [kpc]",
    ylabel="Galactocentric $y$ [kpc]",
    xlim=(-XMAX, XMAX),
    ylim=(-XMAX, XMAX),
    aspect="equal",
)

fig.colorbar(axes[0].collections[0], ax=axes, label="Age of Milky Way at SN [Gyr]")

plt.show()

## Tasks

### Task 3.1: Find the CEs

Create a population like the one above (~10,000 binaries, preferentially sampling higher-mass binaries). Write a mask for the ``bpp`` table that selects only the rows corresponding to common-envelope events. It may be useful to know that common-envelope events are labelled as ``evol_type == 7``.
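Note that such a mask selects *rows*, not binaries: a binary that goes through more than one common-envelope phase appears more than once. A sketch on a made-up `bpp`-like table (indexed by `bin_num`, as `p.bpp` is) showing how to count both:

```python
import pandas as pd

# made-up bpp-like table: one row per evolutionary event, indexed by bin_num
bpp = pd.DataFrame(
    {"tphys": [0.0, 5.1, 6.3, 0.0, 4.2, 9.8, 0.0],
     "evol_type": [1, 7, 7, 1, 3, 15, 1]},
    index=pd.Index([0, 0, 0, 1, 1, 1, 2], name="bin_num"),
)

ce_mask = bpp["evol_type"] == 7
ce_rows = bpp[ce_mask]

# a binary can go through more than one common envelope, so rows != binaries
n_ce_events = len(ce_rows)
n_ce_binaries = ce_rows.index.nunique()
print(n_ce_events, n_ce_binaries)  # 2 1
```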

In [48]:
# your code here

### Task 3.2: CE time histogram (binary frame)

Now make a histogram that shows the distribution of common-envelope event times in the frame of the binary.

What drives the timing of these common-envelope events?

What would you expect to see if you made a scatter plot of these times against the initial primary mass of the binary? Or against the initial orbital period?

In [49]:
# your code here

### Task 3.3: CE time histogram (galaxy frame)

Now compute the timing of these common-envelope events in the frame of the galaxy. What drives the distribution of these common-envelope event times on Galactic timescales?

In [50]:
# your code here

### Task 3.4: CE positions

Last but not least, let's find the positions of these common-envelope events in the galaxy!

Follow the same method as above to find the positions of these common-envelope events in the galaxy and make a plot of these positions like the one above. Limits of 20 kpc in $x$ and $y$ and 5 kpc in $z$ should work well for this. (Why do you think these limits are smaller than for the supernova plot above?)

In [51]:
# your code here

<hr>

# Part 4: Vary your assumptions

## Demo

In [52]:
template = cogsworth.pop.Population(
    n_binaries=10000,
    use_default_BSE_settings=True,
    final_kstar1=[13, 14],
    final_kstar2=[13, 14],
)
template.sample_initial_binaries()

In [53]:
fiducial = template.copy()
fiducial.perform_stellar_evolution()
fiducial.perform_galactic_evolution()

### Initial population distributions

In [54]:
diff_porb = template.copy()
diff_porb.sampling_params["porb_model"] = {'min': 0.15, 'max': 5.5, 'slope': 0.5}
diff_porb.create_population()
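A minimal sketch of what the custom `porb_model` dictionary might be drawing, assuming (not confirmed here, so check the COSMIC sampling docs) that it describes a power law $p(x) \propto x^{\rm slope}$ in $x = \log_{10}(P/{\rm day})$ between `min` and `max`, matching the Sana+2012 form. Inverse-transform sampling of such a power law (valid for slope $\neq -1$) looks like this; `sample_power_law` is a hypothetical helper, not part of COSMIC:

```python
import numpy as np

def sample_power_law(n, x_min, x_max, slope, rng):
    """Draw n samples from p(x) ∝ x**slope on [x_min, x_max] by inverting the CDF (slope != -1)."""
    k = slope + 1.0
    cdf = rng.uniform(size=n)  # uniform draws of the cumulative probability
    return (x_min**k + cdf * (x_max**k - x_min**k)) ** (1.0 / k)

rng = np.random.default_rng(42)

# ASSUMPTION: x = log10(P / day), with min/max/slope as in the porb_model dict above
log_p = sample_power_law(100_000, 0.15, 5.5, 0.5, rng)
porb_days = 10**log_p

# for comparison, a distribution flat in log10(P) over the same range
log_p_flat = sample_power_law(100_000, 0.15, 5.5, 0.0, rng)

print(np.median(log_p), np.median(log_p_flat))
```

With `slope=0.5` the median lands near $\log_{10}(P/{\rm day}) \approx 3.5$, compared with about 2.8 for a distribution flat in $\log_{10} P$ over the same range, so a positive slope shifts binaries towards wider orbits.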

In [55]:
fig, ax = plt.subplots()

bins = np.logspace(0.15, 5.5, 30)
ax.hist(fiducial.initial_binaries["porb"], bins=bins, label="Sana+2012")
ax.hist(diff_porb.initial_binaries["porb"], bins=bins, label="Custom power law", alpha=0.7)
ax.set(
    xscale="log",
    xlabel="Initial orbital period [days]",
    ylabel="Number of binaries",
)
ax.legend()
plt.show()

In [56]:
for pop, label in [(fiducial, "Sana+2012"), (diff_porb, "Custom power law")]:
    n_mergers = (pop.final_bpp["sep"] == 0.0).sum()
    print(f"Number of mergers with {label} porb distribution:", n_mergers)

### Binary physics settings

In [57]:
weak_kick = template.copy()
weak_kick.BSE_settings["sigma"] = 20  # km/s
weak_kick.perform_stellar_evolution()
weak_kick.perform_galactic_evolution()

In [58]:
for pop, label in [(fiducial, "Fiducial"), (weak_kick, "Weak kicks")]:
    n_disrupted = pop.disrupted.sum()
    print(f"Number of disrupted binaries with {label}:", n_disrupted)

In [59]:
fid_dis_nums = fiducial.bin_nums[fiducial.disrupted]
weak_dis_nums = weak_kick.bin_nums[weak_kick.disrupted]

# find one that is disrupted in fiducial but not in weak_kick
in_fid_dis = np.isin(weak_kick.bin_nums, fid_dis_nums)
in_weak_dis = np.isin(weak_kick.bin_nums, weak_dis_nums)
example = weak_kick.bin_nums[in_fid_dis & ~in_weak_dis][0]
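The selection above keeps `bin_num`s that appear in `fid_dis_nums` but not in `weak_dis_nums`. Since both runs are copies of one template and so share the same `bin_nums` (assumed sorted here), `np.setdiff1d` expresses the same set difference more directly. A sketch on made-up arrays:

```python
import numpy as np

# made-up disrupted bin_nums under each assumption
fid_dis_nums = np.array([3, 8, 12, 21, 30])
weak_dis_nums = np.array([8, 30])

# bin_nums disrupted in the fiducial run but not with weak kicks (returned sorted and unique)
only_fiducial = np.setdiff1d(fid_dis_nums, weak_dis_nums)
example = only_fiducial[0]
print(only_fiducial, example)
```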

In [60]:
for pop, label in [(fiducial, "Fiducial"), (weak_kick, "Weak kicks")]:
    pop.plot_cartoon_binary(example)

In [61]:
fig, axes = fiducial.plot_orbit(example, show=False, t_max=200 * u.Myr)
fig, axes = weak_kick.plot_orbit(example, fig=fig, axes=axes, show_legend=False, t_max=200 * u.Myr)

### Galactic potential

In [62]:
nfw = gp.NFWPotential(m=1e12, r_s=15.63, units="galactic")

nfw_pop = fiducial.copy()
nfw_pop.galactic_potential = nfw
nfw_pop.perform_galactic_evolution()

In [63]:
# pick an arbitrary disrupted binary from the fiducial population
disrupted_num = fiducial.bin_nums[fiducial.disrupted][14]
for pop, label in [(fiducial, "fid"), (nfw_pop, "nfw")]:
    fig, axes = pop.plot_orbit(disrupted_num)

## Tasks

### Task 4.1: Vary your initial conditions

#### Task 4.1.1: Choose an initial condition

Choose an initial condition to vary! Your full range of options is given [here in the COSMIC docs](https://cosmic-popsynth.github.io/COSMIC/pages/inifile.html#sampling).

Some inspiration for you:

- Try one of the other built-in initial orbital period distributions.
- How does making the initial population entirely circular change things?
- What if you set the minimum mass ratio to a larger value like 0.5?

In [64]:
# your code here

#### Task 4.1.2: Compare initial distributions

Create a template population and then make two copies of it. For one copy, your "fiducial" simulation, just call ``fiducial.create_population()``, which samples the initial binaries and then evolves them.

For the other copy, change one of the sampling parameters as we did above, and then re-run the sampling step and the evolution steps (``create_population()`` does all of these).

Make a plot of the initial distribution that you changed for both populations to check that it changed in the way you expected.

In [65]:
# your code here

#### Task 4.1.3: Effect on supernovae

Use your code from Part 3 to get the timing and location of all supernovae in both populations. How do the supernova properties change when you change the initial conditions?

In [66]:
# your code here

### Task 4.2: Vary your binary physics assumptions

#### Task 4.2.1: Choose a setting

Choose a binary physics assumption to vary! Your full range of options is given [here in the COSMIC documentation](https://cosmic-popsynth.github.io/COSMIC/pages/inifile.html#binary-physics).

Some inspiration for you:

- Perhaps you could make common envelopes 10x more efficient (``alpha1 = 10``).
- What if you make stable mass transfer always nonconservative (``acc_lim = 0``)?
- Or maybe change how angular momentum is lost during Roche-lobe overflow at super-Eddington mass transfer rates (``gamma``)?

In [67]:
# your code here

#### Task 4.2.2: Compare evolution

Create a template population, sample its initial binaries with ``template.sample_initial_binaries()``, and then make two copies of it. For one copy, your "fiducial" simulation, run just the stellar evolution and galactic evolution steps, as in the demo.

For the other copy, change one of the binary physics parameters as we did above, and then run just the stellar evolution and galactic evolution steps (be careful not to do the sampling or you'll get a different initial population!).

Pick a random binary and plot a cartoon of its evolution in both populations. Does it change how you would expect?

In [68]:
# your code here

#### Task 4.2.3: Effect on supernovae

Use your code from Part 3 to get the timing and location of all supernovae in both populations. How do the supernova properties change when you change the binary physics assumptions?

In [69]:
# your code here

### Task 4.3: Vary galactic potential

#### Task 4.3.1: Choose a potential

Try creating a different galactic potential and evolving your population through it! You can use any of the potentials implemented in [gala](https://gala.adrian.pw/en/latest/potential/index.html), but I'd recommend an NFW potential or a Miyamoto-Nagai potential for this task, with a mass similar to that of the Milky Way.
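To judge whether a potential is Milky Way-like, it can help to check its enclosed mass and circular velocity near the Sun. A plain-NumPy sketch, assuming gala's convention that the NFW `m` parameter is a scale mass (in $M_\odot$, with `r_s` in kpc) such that $\Phi(r) = -\frac{G m}{r}\ln(1 + r/r_s)$, which gives $M(<r) = m\,[\ln(1+x) - x/(1+x)]$ with $x = r/r_s$. For an actual gala potential object you could instead use its own `circular_velocity` method rather than re-deriving the formula:

```python
import numpy as np

G = 4.30091e-6  # gravitational constant in kpc (km/s)^2 / Msun

def nfw_enclosed_mass(r, m, r_s):
    """Mass (Msun) within radius r (kpc) for an NFW profile with scale mass m and scale radius r_s."""
    x = r / r_s
    return m * (np.log1p(x) - x / (1.0 + x))

def nfw_circular_velocity(r, m, r_s):
    """Circular velocity in km/s, v_c = sqrt(G M(<r) / r)."""
    return np.sqrt(G * nfw_enclosed_mass(r, m, r_s) / r)

# the demo's parameters
m, r_s = 1e12, 15.63
r = np.array([4.0, 8.0, 20.0])
print(nfw_enclosed_mass(r, m, r_s))
print(nfw_circular_velocity(r, m, r_s))
```

With the demo's `m=1e12` and `r_s=15.63` this gives $v_c \approx 200$ km/s at 8 kpc, a reasonable sanity check when choosing your own parameters.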

In [70]:
# your code here

#### Task 4.3.2: Compare orbits

Create a template population, sample its initial binaries, and run the stellar evolution and galactic evolution steps to make your "fiducial" simulation. Then make a copy of it (as in the demo, ``nfw_pop = fiducial.copy()``).

For the copy, update the potential as we did above, and then run just the galactic evolution step (be careful not to redo the sampling or stellar evolution, or you'll get a different population!).

Pick a random binary and plot its galactic orbit in both populations. Does it change how you would expect?

In [71]:
# your code here

#### Task 4.3.3: Effect on supernovae

Use your code from Part 3 to get the timing and location of all supernovae in both populations. How do the supernova properties change when you change the galactic potential?

In [72]:
# your code here