# Quick plotting for SFR79 from Observations and Simulations

#### <span style='color:blue'> Adapted to also include xCG individual data points in addition to binned-SDSS based grid of values! </span>

### Imports

In [None]:
import numpy as np
import pandas as pd

from pathlib import Path
from datetime import datetime
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm, trange

import SiGMo as sgm

### Defining Star-forming Galaxy Main Sequence

Follows the definition of Saintonge+2016, Eq. 5, which is based on SDSS DR7 galaxies with 0.01 < z < 0.05.

def GMS_Saintonge2016(logMstar):
    """Return log10 SFR on the Galaxy Main Sequence of Saintonge+2016, Eq. 5.

    :param logMstar: log10 stellar mass (scalar or numpy array)
    :return: log10 SFR of the Main Sequence at that stellar mass, same shape as the input
    """
    # cubic polynomial in log Mstar without a constant term
    linear_term = -2.332 * logMstar
    quadratic_term = 0.4156 * logMstar ** 2
    cubic_term = 0.01828 * logMstar ** 3
    return linear_term + quadratic_term - cubic_term

### Helper function: compute SFR79 from SFR

Generalised helper function that works both for SDSS grid-style data and for individual data points (treated as a 1-d grid). It computes log SFR79 from the "actual" simulated SFR averages, where

$$\mathrm{SFR79} = \frac{\mathrm{SFR}_\mathrm{5 Myr}}{\mathrm{SFR}_\mathrm{800 Myr}}$$

def _windowed_mean(values, in_window):
    """Mean of `values` where `in_window` is True; NaN (without warning) if the window is empty."""
    if not np.any(in_window):
        return np.nan
    return np.mean(values[in_window])


def compute_SFR79_from_SFR(gal_a, *, t_short=0.005, t_long=0.8, show_progress=True):
    """
    Helper function to compute the star-formation change parameter SFR79 directly
    from the SFR over time data

    :param gal_a: numpy array of SiGMo.Snapshot objects over time (any objects with a
        ``.data`` dict holding "lookbacktime" and "SFR" work). Can have arbitrary
        shape, as long as the different Snapshots over time for each object are in
        the last dimension
    :param t_short: lookback-time limit of the short-term SFR average, in the units of
        the snapshots' "lookbacktime" (default 0.005, i.e. 5 Myr for lookback times in Gyr)
    :param t_long: lookback-time limit of the long-term SFR average
        (default 0.8, i.e. 800 Myr for lookback times in Gyr)
    :param show_progress: if True (default), show a tqdm progress bar over the objects
    :return: SFR79_a: object-dtype numpy array with same shape as gal_a except for the
        last dimension; the time dimension is replaced by three summary values
        returned in the new last dimension:
         [0] average SFR over snapshots with lookbacktime <= t_short
             (in the same units as the snapshots' "SFR");
         [1] average SFR over snapshots with lookbacktime <= t_long (same units);
         [2] log SFR79 (being log10(avrgSFR_short / avrgSFR_long)).
        An average over a window containing no snapshots is NaN, and so is the
        corresponding log SFR79.
    Example: if gal_a.shape = (10, 16, 4, 1999), then SFR79.shape = (10, 16, 4, 3)
    """
    gal_a = np.asarray(gal_a, dtype=object)
    SFR79_a = np.empty(shape=(*gal_a.shape[:-1], 3), dtype=object)
    # np.empty returns a C-contiguous array, so this reshape is a view: writes land in SFR79_a
    SFR79_a_fl = SFR79_a.reshape(-1, 3)
    gal_a_fl = gal_a.reshape(-1, *gal_a.shape[-1:])

    sequences = tqdm(gal_a_fl) if show_progress else gal_a_fl
    for i, gal_seq in enumerate(sequences):
        lookbacktime_a = np.array([gal.data["lookbacktime"] for gal in gal_seq])
        SFR_a = np.array([gal.data["SFR"] for gal in gal_seq])

        SFR7_avg = _windowed_mean(SFR_a, lookbacktime_a <= t_short)
        SFR9_avg = _windowed_mean(SFR_a, lookbacktime_a <= t_long)

        SFR79_a_fl[i, 0] = SFR7_avg
        SFR79_a_fl[i, 1] = SFR9_avg
        # log10(SFR_short / SFR_long)
        SFR79_a_fl[i, 2] = np.log10(SFR7_avg / SFR9_avg)

    return SFR79_a

### Helper function: read in all data from directory

#### Version for *both*: huge numbers of single snapshots *or* one large multi snapshot per timestep
The reader (`sgm.read_all_snapshots_from_dir`) must be told how the snapshots were written:
- snapshots written with `single_snapshot(s) = True` ➡️ set `single_snapshots = True`
- snapshots written with `single_snapshot(s) = False` ➡️ set `single_snapshots = False`

### Define directories

In [None]:
project_dir = Path.cwd().parent
sfr79_dir = project_dir / 'data' / "SFR79_grids"
tmp_outputs_dir = project_dir / 'outputs' / '_tmp'

# older runs, kept for reference:
# snp_dir = tmp_outputs_dir / "2022.03.08-17.52.38" / "0_forward_800Myr_1e-4"  # forwards at higher than standard time res
# snp_dir = tmp_outputs_dir / "2022.03.08-17.52.38" / "1_backward_800Myr_1e-4"  # backwards at higher than standard time res
# snp_dir_SDSSsim = tmp_outputs_dir / "2022.03.10-21.01.29" / "0_backward_2Gyr_1e-3"
# snp_dir_xCGsim = tmp_outputs_dir / "2022.03.22-20.00.02" / "0_backward_2Gyr_dt1e-4_wtd10"   # longer backwards at old standard time res
# snp_dir_xCGsim = tmp_outputs_dir / "2022.03.24-11.43.40" / "0_backward_2Gyr_dt1e-3"   # longer backwards at old standard time res

# 2 Gyr backwards runs (new, June 22): datetime-named output directory for each
# off-MS sMIR scaling basefactor
# SDSS-based simulations
SDSSsim_runs_by_basefactor = {
    0.0: "2022.06.22-11.41.56",
    2.0: "2022.06.22-19.32.36",
    4.0: "2022.06.22-20.52.43",
    6.0: "2022.06.22-21.32.20",
    8.0: "2022.06.22-22.08.31",
    10.0: "2022.06.22-22.40.42",
}
# xCOLD GASS-based simulations
xCGsim_runs_by_basefactor = {
    0.0: "2022.06.22-14.13.10",
    2.0: "2022.06.22-19.53.41",
    4.0: "2022.06.22-21.02.00",
    6.0: "2022.06.22-21.41.37",
    8.0: "2022.06.22-22.19.06",
    10.0: "2022.06.22-22.55.02",
}

# choose ONE basefactor: SDSS- and xCG-based runs are then always picked consistently
sMIR_scaling_basefactor = 0.0
run_subdir_name = f"0_backward_2Gyr_dt1e-3_sMIR_scaling_basefactor{float(sMIR_scaling_basefactor)}"
snp_dir_SDSSsim = tmp_outputs_dir / SDSSsim_runs_by_basefactor[sMIR_scaling_basefactor] / run_subdir_name
snp_dir_xCGsim = tmp_outputs_dir / xCGsim_runs_by_basefactor[sMIR_scaling_basefactor] / run_subdir_name

# ATTENTION! plotting directory is based on the xCG directory parent name/name here!
plot_dir = project_dir / 'plots' / '_tmp' / snp_dir_xCGsim.parent.name / snp_dir_xCGsim.name   # plot dir has now same datetime name as output dir
plot_dir.mkdir(exist_ok=True, parents=True)   # create plot dir only if necessary

### Read-in: Observational Data (SDSS and xCG)

In [None]:
# reading observational results


def _load_sfr79_table(file_name):
    """Load one text table (np.loadtxt defaults) from the SFR79 data directory."""
    return np.loadtxt(str(sfr79_dir / file_name))


# SDSS: binned 2-d histogram products (medians, bin-edge meshes, objects per bin)
sfr79_medians = _load_sfr79_table("SFR79_2dhist_medians.txt")
mstar_mesh = _load_sfr79_table("SFR79_2dhist_binedges_mstar_mesh.txt")
sfr_mesh = _load_sfr79_table("SFR79_2dhist_binedges_sfr_mesh.txt")
n_binned = _load_sfr79_table("SFR79_2dhist_binnumbers.txt")

# blank out low-number bins (same cut as when setting up the simulation);
# to keep every bin, comment out the next two statements
n_binned_min = 40
sfr79_medians = np.where(n_binned >= n_binned_min, sfr79_medians, np.nan)

# xCOLD GASS: keep rows whose Mstar, SFR and H2 mass are all above the -900
# placeholder threshold, and which have a CO detection (FLAG_CO == 1)
xCG_df = pd.read_csv(sfr79_dir / "xCOLD_GASS_with_SDSS_SFR79_df.csv")
xCG_minimal_selector = (
    (xCG_df.LOGMSTAR > -900)
    & (xCG_df.LOGSFR_BEST > -900)
    & (xCG_df.LOGMH2 > -900)
)
xCG_gasdetect_selector = xCG_df.FLAG_CO == 1
xCG_sfr79 = np.squeeze(
    xCG_df.loc[xCG_minimal_selector & xCG_gasdetect_selector, "SFR79values"].to_numpy()
)

### Read in SDSS-based grid-like simulation data (that was created using binned SDSS data as ICs)

In [None]:
# snapshot-type flag -- False: read from one big Environment snapshot per timestep;
# True: read from many individual snapshots
single_snapshots: bool = False

In [None]:
# how many objects per timestep?
n_envs = 1
# assumes the sims used the same mstar-sfr pairs as the obs: one halo (with one
# galaxy) per bin whose median survived the n_binned_min cut, i.e. per non-NaN median
n_halos = int(np.count_nonzero(~np.isnan(sfr79_medians)))
n_gals = n_halos

# reading simulation results (SDSS)
env_grid_SDSSsim, halo_grid_SDSSsim, gal_grid_SDSSsim = sgm.read_all_snapshots_from_dir(
    snp_dir=snp_dir_SDSSsim,
    n_envs=n_envs,
    n_halos=n_halos,
    n_gals=n_gals,
    single_snapshots=single_snapshots,
)

In [None]:
# re-create/populate halo_grid and galaxy_grid (for FIRST Environment only) if necessary
if not single_snapshots:
    # halos not holding exactly one galaxy: the grid has ONE galaxy slot per halo, so
    # extra galaxies overwrite it (last one wins) and a galaxy-less halo leaves it unchanged
    n_irregular_halo_snps = 0
    for t, env_snp in enumerate(tqdm(env_grid_SDSSsim[0])):  # HARDCODED to 1st Environment only!
        for i_halo, halo in enumerate(env_snp.data['halos']):
            halo_grid_SDSSsim[i_halo, t] = sgm.Snapshot(halo)
            galaxies = list(halo['galaxies'])
            if len(galaxies) != 1:
                n_irregular_halo_snps += 1
            for gal in galaxies:
                gal_grid_SDSSsim[i_halo, t] = sgm.Snapshot(gal)  # HARDCODED to 1 Galaxy per Halo!!
    if n_irregular_halo_snps:
        warnings.warn(f"{n_irregular_halo_snps} halo snapshot(s) did not contain exactly one galaxy; "
                      "gal_grid_SDSSsim keeps only the last galaxy per halo (slot unchanged if none)")

### SDSS-based sims: calculate SFR79 from SFRs

Compute the log SFR79 from the "actual" simulation SFR averages:

$$\mathrm{SFR79} = \frac{\mathrm{SFR}_\mathrm{5 Myr}}{\mathrm{SFR}_\mathrm{800 Myr}}$$

In [None]:
# per-galaxy SFR79 summary of the SDSS-based sims (last axis presumably:
# <SFR>_5Myr, <SFR>_800Myr, log SFR79 -- confirm against sgm)
SFR79_grid_SDSSsim = sgm.compute_SFR79_from_SFR(gal_grid_SDSSsim)
print(f"SFR79_grid_SDSSsim:\n {SFR79_grid_SDSSsim}")

### Read in xCOLD GASS-based simulation data (that was created using individual xCG detections as ICs)

In [None]:
# snapshot-type flag -- False: read from one big Environment snapshot per timestep;
# True: read from many individual snapshots
single_snapshots: bool = False

In [None]:
# how many objects per timestep?
n_envs = 1
# halo/galaxy counts are not known here; with None they are determined from an
# Environment snapshot by the reader
n_halos = None
n_gals = None

# reading simulation results (xCOLD GASS)
env_grid_xCGsim, halo_grid_xCGsim, gal_grid_xCGsim = sgm.read_all_snapshots_from_dir(snp_dir=snp_dir_xCGsim,
                                                                                        n_envs=n_envs,
                                                                                        n_halos=n_halos,
                                                                                        n_gals=n_gals,
                                                                                        single_snapshots=single_snapshots)

In [None]:
# re-create/populate halo_grid and galaxy_grid (for FIRST Environment only) if necessary
if not single_snapshots:
    # halos not holding exactly one galaxy: the grid has ONE galaxy slot per halo, so
    # extra galaxies overwrite it (last one wins) and a galaxy-less halo leaves it unchanged
    n_irregular_halo_snps = 0
    for t, env_snp in enumerate(tqdm(env_grid_xCGsim[0])):  # HARDCODED to 1st Environment only!
        for i_halo, halo in enumerate(env_snp.data['halos']):
            halo_grid_xCGsim[i_halo, t] = sgm.Snapshot(halo)
            galaxies = list(halo['galaxies'])
            if len(galaxies) != 1:
                n_irregular_halo_snps += 1
            for gal in galaxies:
                gal_grid_xCGsim[i_halo, t] = sgm.Snapshot(gal)  # HARDCODED to 1 Galaxy per Halo!!
    if n_irregular_halo_snps:
        warnings.warn(f"{n_irregular_halo_snps} halo snapshot(s) did not contain exactly one galaxy; "
                      "gal_grid_xCGsim keeps only the last galaxy per halo (slot unchanged if none)")

### xCOLD GASS-based sims: calculate SFR79 from SFRs

Compute the log SFR79 from the "actual" simulation SFR averages:

$$\mathrm{SFR79} = \frac{\mathrm{SFR}_\mathrm{5 Myr}}{\mathrm{SFR}_\mathrm{800 Myr}}$$

In [None]:
# per-galaxy SFR79 summary of the xCOLD GASS-based sims (last axis presumably:
# <SFR>_5Myr, <SFR>_800Myr, log SFR79 -- confirm against sgm)
SFR79_grid_xCGsim = sgm.compute_SFR79_from_SFR(gal_grid_xCGsim)
print(f"SFR79_grid_xCGsim:\n {SFR79_grid_xCGsim}")

### **SFR79 from observations vs from simulations**

In [None]:
# place each SDSS-based simulation into the observational Mstar-SFR grid, so that its
# log SFR79 can be drawn with the same pcolormesh as the observations
mstar_lower, mstar_upper = 10**mstar_mesh[:-1, :-1], 10**mstar_mesh[1:, 1:]
# factor 10**9 converts the mesh SFRs from 'per yr' to the simulations' 'per Gyr'
sfr_lower, sfr_upper = 10**(sfr_mesh[:-1, :-1]) * 10**9, 10**(sfr_mesh[1:, 1:]) * 10**9

sfr79_SDSSsim_grid = np.full_like(sfr79_medians, np.nan)
# position taken from the FIRST snapshot (index 0); use gal_grid_SDSSsim[:, -1] for the final state
for sfr79_value, gal_snp in zip(SFR79_grid_SDSSsim[:, -1], gal_grid_SDSSsim[:, 0]):
    mstar_now = gal_snp.data['mstar']
    sfr_now = gal_snp.data['SFR']
    # strict inequalities: a galaxy exactly on a bin edge is not assigned to that bin
    in_bin = ((mstar_lower < mstar_now) & (mstar_upper > mstar_now)
              & (sfr_lower < sfr_now) & (sfr_upper > sfr_now))
    sfr79_SDSSsim_grid[in_bin] = sfr79_value

In [None]:
# print some statistics (also displayed in plot)
for _stat_name, _stat_func in (("min", np.nanmin), ("max", np.nanmax)):
    print(f"{_stat_name} log SFR79 from sims", _stat_func(sfr79_SDSSsim_grid))

In [None]:
# shared SFR79 colour scale (mapper is reused for the xCG scatter points later on)
sfr79_range = (-2, 2)
cmap = mpl.cm.RdBu
norm = mpl.colors.Normalize(vmin=sfr79_range[0], vmax=sfr79_range[1])
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

# three panels side by side: observations | simulations | temporary colourbar slot
fig, (ax_obs, ax_sims, ax_cbar) = plt.subplots(
    1, 3,
    gridspec_kw={'width_ratios': (9, 9, 1), 'hspace': 0.05},
    figsize=(21, 9),
)

# identical binned heat maps, differing only in the SFR79 values shown
heatmap_style = dict(cmap=cmap, norm=norm)
im_obs = ax_obs.pcolormesh(mstar_mesh, sfr_mesh, sfr79_medians, **heatmap_style)
im_sims = ax_sims.pcolormesh(mstar_mesh, sfr_mesh, sfr79_SDSSsim_grid, **heatmap_style)

# the colourbar takes its space from ax_cbar, which is removed afterwards
fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
    ax=ax_cbar,
    fraction=0.8,
    extend='both',
    anchor=(0.0, 0.0),
    label='log SFR79',
)

# Galaxy Main Sequence overlay (Saintonge+2016, Eq. 5) plus axis labels on both data panels
GMS_x = np.linspace(np.min(mstar_mesh), np.max(mstar_mesh), 1000, endpoint=True)
for ax in (ax_obs, ax_sims):
    ax.plot(GMS_x, sgm.GMS_Saintonge2016(GMS_x), color='xkcd:magenta', ls=':')
    ax.set_xlabel(r'log $M_\star$ [$M_\odot$]')
    ax.set_ylabel(r'log SFR [$M_\odot \, yr^{-1}$]')

ax_cbar.remove()

# annotate the simulation panel with its SFR79 extremes
sim_sfr79_min = np.nanmin(sfr79_SDSSsim_grid)
sim_sfr79_max = np.nanmax(sfr79_SDSSsim_grid)
ax_sims.text(
    0.95, 0.05,
    f"min(log SFR79) = {sim_sfr79_min:.3f}\nmax(log SFR79) = {sim_sfr79_max:.3f}",
    transform=ax_sims.transAxes,
    va='bottom', ha='right',
)

In [None]:
# add xCG observational and simulation SFR79

# first: check that SFR79 arrays from sim and obs are SAME LENGTH!
# (explicit raise instead of assert, which is skipped under `python -O`)
if len(xCG_sfr79) != len(SFR79_grid_xCGsim[:, -1]):
    raise ValueError(f"xCG obs ({len(xCG_sfr79)}) and sim ({len(SFR79_grid_xCGsim[:, -1])}) "
                     "SFR79 arrays differ in length")

# define which data to plot
gal_grid = gal_grid_xCGsim
halo_grid = halo_grid_xCGsim
env_grid = env_grid_xCGsim

# which snp is to be used for mstar-SFR values (aka position on the plot)?
which_snp = 0  # initial mstar, SFR

which_objects = range(0, len(gal_grid), 1)
x_data = []
y_data = []
c_data_obs = []
c_data_sim = []
c_colours_obs = []
c_colours_sim = []
for i in tqdm(which_objects):
    x_data.append([np.log10(gal_grid[i, which_snp].data['mstar'])])
    y_data.append([np.log10(gal_grid[i, which_snp].data['SFR']) - 9.])  # -9. is for conversion from /Gyr to /yr
    c_data_obs.append(xCG_sfr79[i])
    c_data_sim.append(SFR79_grid_xCGsim[i, -1])
    c_colours_obs.append(mapper.to_rgba(xCG_sfr79[i]))
    c_colours_sim.append(mapper.to_rgba(SFR79_grid_xCGsim[i, -1]))

# colours are already RGBA (via mapper, same cmap/norm as the heat maps); passing
# cmap/norm as well would only make matplotlib warn that it ignores them
# plotting the xCG obs data points
ax_obs.scatter(x=x_data, y=y_data, c=c_colours_obs, edgecolors='xkcd:light grey')

# plotting the xCG sim data points
ax_sims.scatter(x=x_data, y=y_data, c=c_colours_sim, edgecolors='xkcd:light grey')

In [None]:
# save heat-map plot to disk, timestamped so earlier versions are not overwritten
timestamp_str = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
fig.savefig(plot_dir / f'SFR79_obs_vs_sims_{timestamp_str}.png')

### SDSS binned obs SFR79 with individual xCOLD GASS sims

In [None]:
# colormap/norm etc. for the SFR79 colour scale (heat map, scatter and colourbar)
sfr79_range = (-2, 2)
cmap = mpl.cm.RdBu
norm = mpl.colors.Normalize(vmin=sfr79_range[0], vmax=sfr79_range[1])
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

# set up figure and axes: one data panel plus a narrow colourbar axis
fig, (ax_obs, ax_cbar) = plt.subplots(
    1, 2,
    gridspec_kw={
       # 'width_ratios': (18, 1),
       'width_ratios': (18, 1.2),
       'hspace': 0.05
    },
    # figsize=(11, 9),
    # figsize=(7.5, 4.5),
    figsize=(7.2, 5.5),
    # tight_layout=True
    constrained_layout=True
)

# plot observational data (binned SDSS medians)
im_obs = ax_obs.pcolormesh(
    mstar_mesh, sfr_mesh,
    sfr79_medians,
    cmap=cmap,
    norm=norm,
    label="observational SFR79: SDSS",
)

# plot colorbar into its dedicated axis
fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
             cax=ax_cbar,
             fraction=0.8,
             extend='both',
             label='log SFR79')

# adding the Galaxy Main Sequence on top (Saintonge & Catinella 2022)
GMS_x = np.linspace(start=np.min(mstar_mesh),
                    stop=np.max(mstar_mesh),
                    num=1000,
                    endpoint=True)
handle_GMS = ax_obs.plot(GMS_x, sgm.GMS_Saintonge2022(GMS_x, log=True),
                         color='xkcd:magenta', ls='--', label="GMS: Saintonge & Catinella (2022)")

# figure labelling etc
ax_obs.set_xlabel(r'log $M_\star$ [$M_\odot$]')
ax_obs.set_ylabel(r'log SFR [$M_\odot \, yr^{-1}$]')
ax_obs.set(xlim=(8.45, 11.6), ylim=(-2.3, 1.8))
ax_obs.tick_params(axis='both', which='both', direction='in', bottom=True, top=True, left=True, right=True)


# add xCG simulation SFR79

# first: check that SFR79 arrays from sim and obs are SAME LENGTH!
# (explicit raise instead of assert, which is skipped under `python -O`)
if len(xCG_sfr79) != len(SFR79_grid_xCGsim[:, -1]):
    raise ValueError(f"xCG obs ({len(xCG_sfr79)}) and sim ({len(SFR79_grid_xCGsim[:, -1])}) "
                     "SFR79 arrays differ in length")

# define which data to plot
gal_grid = gal_grid_xCGsim
halo_grid = halo_grid_xCGsim
env_grid = env_grid_xCGsim

# which snp is to be used for mstar-SFR values (aka position on the plot)?
which_snp = 0  # initial mstar, SFR

which_objects = range(0, len(gal_grid), 1)
x_data = []
y_data = []
c_data_obs = []
c_data_sim = []
c_colours_obs = []
c_colours_sim = []
for i in tqdm(which_objects):
    x_data.append([np.log10(gal_grid[i, which_snp].data['mstar'])])
    y_data.append([np.log10(gal_grid[i, which_snp].data['SFR']) - 9.])  # -9. is for conversion from /Gyr to /yr
    c_data_obs.append(xCG_sfr79[i])
    c_data_sim.append(SFR79_grid_xCGsim[i, -1])
    c_colours_obs.append(mapper.to_rgba(xCG_sfr79[i]))
    c_colours_sim.append(mapper.to_rgba(SFR79_grid_xCGsim[i, -1]))

# plotting the xCG sim data points; colours are already RGBA (via mapper), so no
# cmap/norm is passed (matplotlib would ignore them and warn)
handle_scatter = ax_obs.scatter(
    x=x_data,
    y=y_data,
    s=(mpl.rcParams['lines.markersize']*2)**2,
    c=c_colours_sim,
    edgecolors='xkcd:light grey',
    label="simulated SFR79: xCOLD GASS",
)


# custom legend handles: a few sample colours per data set, built from the public
# getters of the plotted artists rather than their private attributes
legend_marker_size = np.sqrt(handle_scatter.get_sizes()[0])
legend_edge_colour = handle_scatter.get_edgecolor()[0]

handle_scatter_custom = tuple(
    mpl.lines.Line2D([0], [0], color=cmap(frac), ls='', marker='o',
                     markeredgecolor=legend_edge_colour, markersize=legend_marker_size)
    for frac in (.75, .5, .25)
)

handle_pcolormesh_custom = tuple(
    mpl.lines.Line2D([0], [0], color=cmap(frac), ls='', marker='s', markersize=legend_marker_size)
    for frac in (.65, .45)
)

handle_GMS_custom = mpl.lines.Line2D([0], [0], color=handle_GMS[0].get_color(),
                                     linestyle=handle_GMS[0].get_linestyle())

# tuples of handles are drawn side by side in one legend entry via HandlerTuple
ax_obs.legend(
    [handle_pcolormesh_custom, handle_scatter_custom, handle_GMS_custom],
    [im_obs.get_label(), handle_scatter.get_label(), handle_GMS[0].get_label()],
    handler_map={tuple: mpl.legend_handler.HandlerTuple(ndivide=None)},
    loc='lower right',
    framealpha=0.87,
)

In [None]:
# save heat-map plot to disk (300 dpi), timestamped so earlier versions are kept
timestamp_str = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
fig.savefig(plot_dir / f'SFR79_SDSSobs_vs_xCGsims_{timestamp_str}.png', dpi=300)

### **select galaxies/halos: plot quantities and differences for individual param combinations over time**

In [None]:
# define which data to plot
gal_grid = gal_grid_xCGsim

# define what to plot
x_type = 'lookbacktime'
y_type = 'SFR'
# which_objects = [0, 100, 200, 300, 400, 500, 600, 700, 800]
# which_objects = [0, 50, 100, 150, 200, 250, 300]
which_objects = [0, 50, 100, 150, 200, 250, -1]  # negative indices count from the end

# grab plotting data
x_data = []
y_data = []
label_l = []
for i in tqdm(which_objects):
    x_data.append([gal.data[x_type] for gal in gal_grid[i][:]])
    y_data.append([gal.data[y_type] for gal in gal_grid[i][:]])
    label_l.append(r"$\mathrm{M}_\star$" + f"= {gal_grid[i, 0].data['mstar']:.2e} \t SFR = {gal_grid[i, 0].data['SFR'] / 10 ** 9:.2e}")

# initialise plot
fig, ax = plt.subplots(nrows=2, figsize=(9, 9), constrained_layout=True)

# plot actual values
for x, y, label in zip(x_data, y_data, label_l):
    ax[0].plot(x, y, label=label)


# plot one quantity like SFR79, at the object's non-negative simulation number so that
# e.g. index -1 lands at the last simulation instead of at x = -1
object_numbers = range(len(gal_grid))
for i, label in zip(which_objects, label_l):
    ax[1].plot(object_numbers[i], SFR79_grid_xCGsim[i, -1], 'o', label=label)

# additional fig and axes config
fig.suptitle(f'{y_type}: comparing {len(which_objects)} simulated galaxies', fontsize=16)

ax[0].invert_xaxis()
ax[0].set_xlabel(x_type)
ax[0].set_yscale('log')
ax[0].set_ylabel(y_type)

ax[1].set_xlabel('simulation number (arb. index)')
ax[1].set_ylabel('log SFR79')

ax[1].legend()

# save to disk
fig.savefig(plot_dir / f'comparing_{y_type}_of_{len(which_objects)}_simulated_galaxies'
                       f'_{datetime.now().strftime("%Y.%m.%d-%H.%M.%S")}.png')

### **all galaxies/halos: plot quantities and differences for individual param combinations over time**

In [None]:
# define which data to plot
gal_grid = gal_grid_xCGsim
halo_grid = halo_grid_xCGsim
env_grid = env_grid_xCGsim
SFR79_grid = SFR79_grid_xCGsim

# # define which data to plot
# gal_grid = gal_grid_SDSSsim
# halo_grid = halo_grid_SDSSsim
# env_grid = env_grid_SDSSsim
# SFR79_grid = SFR79_grid_SDSSsim

# define what to plot
x_type = 'lookbacktime'
y_type = 'SFR'
# colour quantity: the halo's 'mtot', optionally divided by the galaxy's 'mstar' or 'mgas'
# c_type = 'mtot'
# c_type = 'mtot_over_mstar'
c_type = 'mtot_over_mgas'
c_snp = 0
# c_snp = -1
c_alpha = 0.25
which_objects = range(0, len(gal_grid), 1)

# galaxy quantity dividing the halo mass for each supported c_type (None: no division)
_c_type_denominators = {'mtot': None, 'mtot_over_mstar': 'mstar', 'mtot_over_mgas': 'mgas'}
if c_type not in _c_type_denominators:
    raise ValueError(f"unsupported c_type {c_type!r}; choose one of {list(_c_type_denominators)}")
_c_denominator_key = _c_type_denominators[c_type]

# grab plotting data
x_data = []
y_data = []
c_data = []
label_l = []
for i in tqdm(which_objects):
    x_data.append([gal.data[x_type] for gal in gal_grid[i][:]])
    y_data.append([gal.data[y_type] for gal in gal_grid[i][:]])
    # y_data.append([halo.data[y_type] for halo in halo_grid[i][:]])
    mtot = halo_grid[i][c_snp].data['mtot']
    if _c_denominator_key is None:
        _c_data_i = [mtot]
    else:
        denominator = gal_grid[i][c_snp].data[_c_denominator_key]
        # a zero denominator (e.g. a gas-free galaxy) has no finite ratio -> NaN,
        # also when mtot is 0 (which previously slipped through as 0/0)
        _c_data_i = [np.nan if denominator == 0 else mtot / denominator]
    c_data.append(_c_data_i)
    label_l.append(r"$\mathrm{M}_\star$" + f"= {gal_grid[i, 0].data['mstar']:.2e} \t SFR = {gal_grid[i, 0].data['SFR'] / 10 ** 9:.2e}")


# normalise the colour data to make colour the data according to it (e.g. total halo mass mtot)
c_data_range = (np.nanmin(c_data), np.nanmax(c_data))
cmap = mpl.cm.viridis_r
norm = mpl.colors.LogNorm(vmin=c_data_range[0], vmax=c_data_range[1])
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)


# initialise plot
fig, ax = plt.subplots(nrows=2, figsize=(9, 9), constrained_layout=True)

# plot actual values
for x, y, c, label in zip(x_data, y_data, c_data, label_l):
    c_mapped = mapper.to_rgba(c)
    c_mapped[:, -1] = c_alpha
    ax[0].plot(x, y, c=c_mapped, label=label)


# plot one quantity like SFR79 against the colour quantity of the same object
# (c comes from the zip, so this stays correct for any which_objects selection)
for i, c, label in zip(which_objects, c_data, label_l):
    c_mapped = mapper.to_rgba(c)
    c_mapped[:, -1] = c_alpha
    ax[1].plot(c, SFR79_grid[i, -1], c=c_mapped, marker='o', label=label)


# additional fig and axes config
fig.suptitle(f'{y_type}: comparing {len(which_objects)} simulated galaxies', fontsize=16)


ax[0].invert_xaxis()
ax[0].set_xlabel(x_type)
ax[0].set_yscale('log')
ax[0].set_ylabel(y_type)

ax[1].set_xscale('log')
ax[1].set_xlabel(c_type)
ax[1].set_ylabel('log SFR79')
ax[1].text(0.95, 0.05,
           f"{c_type} at lookbacktime = {env_grid[0, c_snp].data['lookbacktime']:.3f}",
           transform=ax[1].transAxes,
           va='bottom', ha='right')


In [None]:
# save to disk; the filename encodes the plotted quantity, object count and a timestamp
timestamp_str = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
fig.savefig(plot_dir / f'comparing_{y_type}_of_{len(which_objects)}_simulated_galaxies_{timestamp_str}.png')

### Residuals between xCG SFR79 from observations and from simulations

In [None]:
# show the full xCOLD GASS table (read from "xCOLD_GASS_with_SDSS_SFR79_df.csv") as this cell's output
xCG_df

In [None]:
# array versions of the per-object lists collected in earlier cells
x_data_a, y_data_a = np.array(x_data), np.array(y_data)
c_data_obs_a, c_data_sim_a = np.array(c_data_obs), np.array(c_data_sim)

# per-galaxy xCG properties, selected with the same row mask as xCG_sfr79
_xCG_selection = xCG_minimal_selector & xCG_gasdetect_selector
xCG_gas_mass_fraction, xCG_SFRexcessGMS, xCG_median_SFR79_in_bin = (
    np.squeeze(xCG_df.loc[_xCG_selection, column].to_numpy())
    for column in ("gas_mass_fraction", "SFRexcessGMS", "median_SFR79_in_bin")
)


In [None]:
fig, ax = plt.subplots(2, 1, figsize=(12, 12), constrained_layout=True)

# residual per xCG galaxy: observed minus simulated log SFR79
c_data_plot_a = np.asarray(c_data_obs_a, dtype=float) - np.asarray(c_data_sim_a, dtype=float)

# raw strings: the LaTeX backslashes are not Python escape sequences
gas_mass_fraction_label = r"Gas Mass Fraction $\log(M_\mathrm{H2} \; / \; M_{\star})$"
delta_SFR_label = r"Δ SFR [dex] above/below GMS"


def _plot_linear_fit(ax_i, x, y):
    """Overplot on `ax_i` a least-squares straight line through the finite (x, y) pairs."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)  # NaN/inf would break the least-squares fit
    fit = np.poly1d(np.polyfit(x=x[finite], y=y[finite], deg=1))
    x_lin = np.linspace(np.min(x[finite]), np.max(x[finite]), 1000)
    ax_i.plot(x_lin, fit(x_lin), zorder=-10, c='xkcd:grey', ls='--')


# top: residual vs gas mass fraction, coloured by offset from the GMS
cmap = mpl.cm.viridis
norm = mpl.colors.Normalize(vmin=np.nanmin(xCG_SFRexcessGMS), vmax=np.nanmax(xCG_SFRexcessGMS))
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
ax[0].scatter(xCG_gas_mass_fraction, c_data_plot_a, c=xCG_SFRexcessGMS, cmap=cmap, norm=norm, marker='o', s=50)
ax[0].set_xlabel(gas_mass_fraction_label)
fig.colorbar(mapper, ax=ax[0], label=delta_SFR_label)
_plot_linear_fit(ax[0], xCG_gas_mass_fraction, c_data_plot_a)

# bottom: residual vs offset from the GMS, coloured by gas mass fraction
# (the same norm is given to scatter and colourbar so both always agree)
cmap = mpl.cm.magma
norm = mpl.colors.Normalize(vmin=np.nanmin(xCG_gas_mass_fraction), vmax=np.nanmax(xCG_gas_mass_fraction))
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
ax[1].scatter(xCG_SFRexcessGMS, c_data_plot_a, c=xCG_gas_mass_fraction, cmap=cmap, norm=norm, marker='o', s=50)
ax[1].set_xlabel(delta_SFR_label)
fig.colorbar(mapper, ax=ax[1], label=gas_mass_fraction_label)
_plot_linear_fit(ax[1], xCG_SFRexcessGMS, c_data_plot_a)

# common y-limits; residuals beyond them are counted once and reported in each panel
ylim_vals = (None, 2.6)
out_of_plot_down = 0 if ylim_vals[0] is None else int(np.sum(c_data_plot_a < ylim_vals[0]))
out_of_plot_up = 0 if ylim_vals[1] is None else int(np.sum(c_data_plot_a > ylim_vals[1]))
out_of_plot = out_of_plot_up + out_of_plot_down
for ax_i in ax:
    ax_i.set_ylabel(r'log(SFR79$_\mathrm{obs}$) - log(SFR79$_\mathrm{sim}$)')
    ax_i.set_ylim(*ylim_vals)
    ax_i.text(0.03, 0.94,
              (f"data points out of plot: {out_of_plot} ({out_of_plot_up}" +
               r'$\uparrow$ + ' + f"{out_of_plot_down}" + r'$\downarrow$)'),
              transform=ax_i.transAxes,
              va='top', ha='left')

In [None]:
# save residuals plot to disk, timestamped so earlier versions are kept
timestamp_str = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
fig.savefig(plot_dir / f'residuals_SFR79_obs_vs_sims_{timestamp_str}.png')

In [None]:
fig, ax = plt.subplots(figsize=(12, 9), constrained_layout=True)

sfr79_range = (-2, 2)
cmap = mpl.cm.RdBu
norm = mpl.colors.Normalize(vmin=sfr79_range[0], vmax=sfr79_range[1])
mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

# xCG galaxies in the gas-fraction vs GMS-offset plane, coloured by observed log SFR79
ax.scatter(xCG_gas_mass_fraction, xCG_SFRexcessGMS, c=c_data_obs_a, cmap=cmap, norm=norm, marker='o', edgecolor='xkcd:grey', s=50)
fig.colorbar(mapper, ax=ax, label=r"log SFR79 (obs)", extend='both')
# ax.set_facecolor('xkcd:grey')

# poly fit through the finite pairs only (NaN/inf would break the least-squares fit)
_fit_x = np.asarray(xCG_gas_mass_fraction, dtype=float)
_fit_y = np.asarray(xCG_SFRexcessGMS, dtype=float)
_finite = np.isfinite(_fit_x) & np.isfinite(_fit_y)
exp_fit = np.polyfit(x=_fit_x[_finite], y=_fit_y[_finite], deg=1)
p = np.poly1d(exp_fit)
x_lin = np.linspace(np.min(_fit_x[_finite]), np.max(_fit_x[_finite]), 1000)
ax.plot(x_lin, p(x_lin), zorder=-10, c='xkcd:grey', ls='--')

# formatting to figure (raw string: the LaTeX backslashes are not Python escapes)
ax.set_xlabel(r"Gas Mass Fraction $\log(M_\mathrm{H2} \; / \; M_{\star})$")
ax.set_ylabel(r"Δ SFR [dex] above/below GMS")

In [None]:
# save the gas-fraction / GMS-offset correlation plot to disk, timestamped
timestamp_str = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
fig.savefig(plot_dir / f'correlation_gasmassfrac_to_deltaSFR_cc_SFR79obs_{timestamp_str}.png')