<a href="https://colab.research.google.com/github/13emilygriffith/FewProcessModel/blob/main/Notebooks/for_Griffith.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A two-process model for stellar abundances

## Authors:
- **Emily J. Griffith** (Colorado)
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Major bugs / to-do items:
- We need to produce some kind of error estimates on everything.
- We should add a "jitter" term to the observational errors on the abundances.
- Does moving to a robust loss function help or change the results in any way? Some least-squares optimizers have these built in.
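One possible shape for the jitter item, as a sketch only: add a jitter variance in quadrature to each observational variance before forming the inverse variances. The helper name `add_jitter` and the jitter value of 0.05 dex are hypothetical, and the toy inverse variances are not the real data. Bad data (inverse variance 0) keep zero weight.

```python
import numpy as np

def add_jitter(ivars, jitter):
    """Hypothetical helper: return 1 / (sigma^2 + jitter^2), keeping ivar = 0 for bad data."""
    ivars = np.asarray(ivars, dtype=float)
    good = ivars > 0                      # bad data are flagged with ivar = 0
    out = np.zeros_like(ivars)
    out[good] = 1.0 / (1.0 / ivars[good] + jitter ** 2)
    return out

# Toy inverse variances with one bad measurement (assumed values)
toy_ivars = np.array([[100.0, 0.0], [400.0, 25.0]])
jittered = add_jitter(toy_ivars, jitter=0.05)
print(jittered)
```

The jitter could be a single scalar or one value per element; that choice is left open here.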

## Minor bugs:
- Ought to compute and track the full objective function.
- Haven't looked at how the results depend on the regularization strengths.
- Shouldn't be taking a `sqrt` in the residual (chi) code.

## Comments
- `jax.vmap()` completely changed what was possible in this project. We must acknowledge and cite JAX in the paper.

In [None]:
!pip install jaxopt

In [None]:
!pip install wget

In [None]:
!pip install corner

In [None]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jaxopt
import wget
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
import os.path
from tqdm import tqdm
from jax import vmap, grad
import time
import corner

In [None]:
# Revise default plotting style
style_revisions = {
            'axes.linewidth': 1.5,
            'xtick.top' : True, 
            'ytick.right' : True, 
            'xtick.direction' : 'in',
            'ytick.direction' : 'in', 
            'xtick.major.size' : 11, 
            'ytick.major.size' : 11, 
            'xtick.minor.size' : 5.5, 
            'ytick.minor.size' : 5.5,
            'font.size' : 16,
            'figure.figsize' : [6, 6],
            'lines.linewidth' : 2.5,
        }
plt.rcParams.update(style_revisions)

In [None]:
# Set constants
ln10 = np.log(10.)

In [None]:
# Set hyper-parameters
processes = np.array(['CC', 'Ia']) # process names
K = len(processes) # number of processes!
Lambda_a = 1.e12 # regularization strength on Mg for CC and on Fe for Ia.
Lambda_b = 1.e12 # regularization strength on Mg for Ia

In [None]:
# Download files from Emily's website
url = 'https://www.emilyjgriffith.com/s/'
for fn in ['lnqs.npy', 'lnAs.npy', 'bins.npy', 'alldata.npy', 'allivars.npy']:
    # only download files that are not already present locally
    if not os.path.isfile(fn):
        wget.download(url + fn)

In [None]:
# Load numpy files
# N - number of stars = 34410
# M - number of elements = 16

elements  = np.array(['Mg','O','Si','S','Ca','CN','Na','Al','K','Cr','Fe','Ni','V','Mn','Co','Ce'])
metallicities = np.array(['-0.7', '-0.6', '-0.5', '-0.4', '-0.3', '-0.2', '-0.1', '0.0', '0.1', '0.2', '0.3', '0.4', ])

# lnqs: shape(2, 12, 16), 0 is qcc, 1 is qIa, replaced negative values with 0.05
w22_lnqs = np.load('lnqs.npy')
a, b, c = w22_lnqs.shape
if a < K:
    # Padding the W22 processes out to K > 2 processes is not implemented yet.
    raise NotImplementedError("cannot extend w22_lnqs to K processes yet")
assert a == len(processes)
assert b == len(metallicities)
assert c == len(elements)

# artificially raise the zeros above zero
w22_lnqs = np.clip(w22_lnqs, -7., None)

# lnAs: shape(2, 34410), 0 is Acc, 1 is AIa
w22_lnAs = np.load('lnAs.npy')
a, b = w22_lnAs.shape
if a < K:
    # Padding the W22 amplitudes out to K > 2 processes is not implemented yet.
    raise NotImplementedError("cannot extend w22_lnAs to K processes yet")
assert a == len(processes)

# bins: shape(34410) index of metallicity bin 0 = -0.7, 11 = 0.4 (spaced by 0.1 dex)
bins = np.load('bins.npy').astype(int)

# alldata: shape(34410, 16), bad data = 0
alldata = np.load('alldata.npy')

# allivars: shape(34410,16), bad data = 0
allivars = np.load('allivars.npy')
assert allivars.shape == alldata.shape

In [None]:
def all_stars_K_process_model(lnAs, lnqs, bins):
    """
    ## inputs
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `lnqs`: shape `(K, Nbin, M)` natural-logarithmic processes
    - `bins`: shape `(N, )` metallicity bin integers

    ## outputs
    shape `(N, M)` log_10 abundances

    ## comments
    - Note the `ln10`.
    """
    return logsumexp(lnAs[:, :, None] + lnqs[:, bins, :], axis=0) / ln10
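
As a quick sanity check of the model function above, the following NumPy-only sketch confirms that the log-sum-exp form equals `log10(sum_k A_k q_k)` after the metallicity-bin lookup. The toy sizes and random values are assumptions (not the real data), and `np.logaddexp.reduce` stands in for the JAX `logsumexp`.

```python
import numpy as np

rng = np.random.default_rng(0)
K, N, Nbin, M = 2, 5, 3, 4                       # toy sizes (hypothetical)
toy_lnAs = rng.normal(size=(K, N))               # ln amplitudes, one per process and star
toy_lnqs = rng.normal(size=(K, Nbin, M))         # ln process vectors per metallicity bin
toy_bins = rng.integers(0, Nbin, size=N)         # metallicity bin index of each star

# log-sum-exp over the process axis, then convert natural log to log_10
log10_model = np.logaddexp.reduce(
    toy_lnAs[:, :, None] + toy_lnqs[:, toy_bins, :], axis=0) / np.log(10.)

# direct evaluation of log10(sum_k A_k q_k) for comparison
direct = np.log10(np.einsum('kn,knm->nm',
                            np.exp(toy_lnAs), np.exp(toy_lnqs[:, toy_bins, :])))

print(log10_model.shape, np.allclose(log10_model, direct))
```

The output has one row per star and one column per element, i.e. shape `(N, M)`.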

In [None]:
def plot_xmg_resid(elem, alldata, synthdata, xmax=0.4, xmin=-0.7, ymax=0.2, 
                   ymin=-0.5, resid_lim=0.05):
  """
  ## inputs
  - `elem`: string of element name, capitalized
  - `alldata`: shape `(N, M)` log_10 abundance measurements
  - `synthdata`: shape `(N, M)` log_10 synthesized abundances
  - `xmax`: float plotting limit in [X/Mg] vs. [Mg/H] plots (default=0.4)
  - `xmin`: float plotting limit in [X/Mg] vs. [Mg/H] plots (default=-0.7)
  - `ymax`: float plotting limit in [X/Mg] vs. [Mg/H] plots (default=0.2)
  - `ymin`: float plotting limit in [X/Mg] vs. [Mg/H] plots (default=-0.5)
  - `resid_lim`: float plotting limit in Delta [X/H] vs. Delta [Mg/H] plots 
                  (default=0.05)

  ## outputs
  Plot with three panels:
    (1) [X/Mg] vs. [Mg/H] observed
    (2) [X/Mg] vs. [Mg/H] fit
    (3) Delta [X/H] vs. Delta [Mg/H] residuals
  """
  plt.figure(figsize=(12,4))

  if elem not in elements: 
    print('Element not valid')
    return None
  else: e_i = np.where(elements==elem)[0][0]

  plt.subplot(1,3,1)
  plt.hist2d(alldata[:,0], alldata[:,e_i]-alldata[:,0],cmap='magma', bins=100, 
          range=[[xmin,xmax],[ymin,ymax]], norm=LogNorm())
  plt.xlabel('[Mg/H]')
  plt.ylabel('['+elem+'/Mg]')
  plt.title('Observed')

  plt.subplot(1,3,2)
  plt.hist2d(synthdata[:,0], synthdata[:,e_i]-synthdata[:,0],cmap='magma', 
             bins=100, range=[[xmin,xmax],[ymin,ymax]], norm=LogNorm())
  plt.xlabel('[Mg/H]')
  plt.ylabel('['+elem+'/Mg]')
  plt.title('Simulated')

  plt.subplot(1,3,3)
  plt.hist2d(alldata[:,0]-synthdata[:,0],(alldata[:,e_i])-(synthdata[:,e_i]),
             cmap='magma', bins=100, norm=LogNorm(),
             range=[[-1*resid_lim,resid_lim],[-1*resid_lim,resid_lim]])
  plt.xlabel(r'$\Delta$ [Mg/H]')  # raw strings avoid the invalid '\D' escape
  plt.ylabel(r'$\Delta$ ['+elem+'/H]')

  plt.tight_layout()
  plt.show()


def plot_resid_corner(alldata, synthdata, elements, resid_lim=0.15):
  """
  ## inputs
  - `alldata`: shape `(N, M)` log_10 abundance measurements
  - `synthdata`: shape `(N, M)` log_10 synthesized abundances
  - `elements`: shape `(M)` string element names
  - `resid_lim`: float plotting limit in all 2D and 1D histograms
                  (default=0.15)

  ## outputs
  Corner plot of all possible combinations of elements
  """
  resid_data = np.array(alldata - synthdata)
  figure = corner.corner(resid_data, labels=elements, 
                         quantiles=[0.16, 0.5, 0.84], show_titles=True, 
                         title_kwargs={"fontsize": 12}, quiet=True,  
                         range=[(-1*resid_lim,resid_lim)]*len(elements))
  
def plot_star_abunds(star_idxs, alldata, allivars, synthdata, elements):
  """
  ## inputs
  - `star_idxs`: list of integer indexes of stars to plot (e.g. [0,1,2])
  - `alldata`: shape `(N, M)` log_10 abundance measurements
  - `allivars`: shape `(N, M)` inverse variances of observed abundances
  - `synthdata`: shape `(N, M)` log_10 synthesized abundances
  - `elements`: shape `(M)` string element names

  ## outputs
  Plot of predicted and observed abundances with a row for each star
  """
  n_stars = len(star_idxs)
  err = np.sqrt(1/(allivars))
  plt.figure(figsize=(12,2*n_stars))

  for i in range(n_stars):
    ax = plt.subplot(n_stars, 1, i+1)
    plt.errorbar(range(len(elements)), alldata[star_idxs[i],:],
                 yerr=err[star_idxs[i],:], fmt='ko-', 
                  lw=1.5, label='Obs')  
    plt.plot(synthdata[star_idxs[i],:], 'ko--', markerfacecolor='None', lw=1.5,
             label='Synth')
    plt.xticks(range(len(elements)), elements)
    plt.ylabel('[X/H]')
    chi2 = np.sum(((alldata[star_idxs[i],:] - synthdata[star_idxs[i],:])**2) * 
                  allivars[star_idxs[i],:])
    plt.text(0.03,0.75, ' star %1.0f \n chi2= %2.3f' % (star_idxs[i], chi2), 
             transform=ax.transAxes, fontsize=10)
    avg_y = np.mean(synthdata[star_idxs[i],:])
    plt.ylim(avg_y-0.3, avg_y+0.3)
    plt.legend(ncol=2, loc='upper right', fontsize=10)

  plt.tight_layout()

def plot_chi2(alldata, allivars, synthdata, elements):
  """
  ## inputs
  - `alldata`: shape `(N, M)` log_10 abundance measurements
  - `allivars`: shape `(N, M)` inverse variances of observed abundances
  - `synthdata`: shape `(N, M)` log_10 synthesized abundances
  - `elements`: shape `(M)` string element names

  ## outputs
  Plot with two panels: (1) histogram of stellar chi^2 values and (2) plot of
  chi^2 per element
  """
  chi2_stars = np.sum(((alldata - synthdata)**2) * allivars, axis=1)
  chi2_elems = np.sum(((alldata - synthdata)**2) * allivars, axis=0)

  plt.figure(figsize=(8,8))

  plt.subplot(2,1,1)
  plt.hist(chi2_stars, bins=100, color='mediumslateblue', histtype='stepfilled')
  plt.yscale('log')
  plt.ylabel('N Stars')
  plt.xlabel(r'$\chi^2$ per star')

  plt.subplot(2,1,2)
  plt.plot(elements, chi2_elems, 'o--', color='mediumslateblue', lw=1.5)
  plt.ylabel(r'$\chi^2$ per element')

  plt.tight_layout()


In [None]:
synthdata = all_stars_K_process_model(w22_lnAs, w22_lnqs, bins)

plot_xmg_resid('Fe', alldata, synthdata, resid_lim=0.1)
plot_resid_corner(alldata, synthdata, elements)
plot_star_abunds([0,1,2,674,10274], alldata, allivars, synthdata, elements)
plot_chi2(alldata, allivars, synthdata, elements)


In [None]:
def one_star_K_process_model(lnAs, lnqs):
    """
    ## inputs
    - `lnAs`: shape `(K,)` natural-logarithmic amplitudes
    - `lnqs`: shape `(K, M)` natural-logarithmic processes

    ## outputs
    shape `(M, )` log_10 abundances

    ## comments
    - Note the `ln10`.
    """
    return logsumexp(lnAs[:, None] + lnqs, axis=0) / ln10

def one_star_chi(lnAs, lnqs, data, ivars):
    """
    ## inputs
    - `lnAs`: shape `(2,)` natural-logarithmic amplitudes
    - `lnqs`: shape `(2, M)` natural-logarithmic processes
    - `data`: shape `(M, )` log_10 abundance measurements
    - `ivars`: shape `(M, )` inverse variances on the data

    ## outputs
    chi for this one star
    """
    return jnp.sqrt(ivars) * (data - one_star_K_process_model(lnAs, lnqs))

def one_star_A_step(lnqs, data, ivars, init=None):
    """
    ## inputs
    - `lnqs`: shape `(2, M)` natural-logarithmic processes
    - `data`: shape `(M, )` log_10 abundance measurements
    - `ivars`: shape `(M, )` inverse variances on the data

    ## outputs
    shape `(2,)` best-fit natural-logarithmic amplitudes

    ## bugs
    - Not at all tested.
    - Doesn't check the output of the optimizer AT ALL.
    """
    solver = jaxopt.GaussNewton(residual_fun=one_star_chi)
    if init is None:
        lnAs_init = jnp.log(jnp.array([0.5, 0.5]))
    else:
        lnAs_init = init.copy()

    # use jnp.sum so the reductions stay inside JAX when this function is vmapped
    chi2_init = jnp.sum(one_star_chi(lnAs_init, lnqs=lnqs, data=data, ivars=ivars) ** 2)
    res = solver.run(lnAs_init, lnqs=lnqs, data=data, ivars=ivars)
    chi2_res = jnp.sum(one_star_chi(res.params, lnqs=lnqs, data=data, ivars=ivars) ** 2)

    return res.params, chi2_init - chi2_res

def A_step(lnqs, alldata, allivars, bins, old_lnAs):
    """
    ## inputs
    - `lnqs`: shape `(2, Nbin, M)` natural-logarithmic processes
    - `alldata`: shape `(N, M)` log_10 abundance measurements
    - `allivars`: shape `(N, M)` inverse variances on alldata
    - `bins`: shape `(N, )` metallicity bin integers
    - `old_lnAs`: previous `lnAs`; used for initialization of the optimizer

    ## outputs
    shape `(2, N)` best-fit natural-logarithmic amplitudes

    ## bugs
    - Need to switch to a form of map that works well in jax??
    """
    N, M = alldata.shape
    foo, Nbin, bar = lnqs.shape
    assert lnqs.shape == (2, Nbin, M)
    assert allivars.shape == (N, M)
    assert bins.shape == (N, )
    assert old_lnAs.shape == (2, N)
    return vmap(one_star_A_step, in_axes=(1, 0, 0, 1), out_axes=(1, 0))(lnqs[:, bins], alldata, allivars, old_lnAs)

In [None]:
start = time.time()
new_lnAs, delta_chi2s = A_step(w22_lnqs, alldata, allivars, bins, w22_lnAs)
end = time.time()
print("A-step timing:", end - start)
print("A-step delta chi-squared range:", jnp.min(delta_chi2s), jnp.max(delta_chi2s))

In [None]:
def one_element_K_process_model(lnqs, lnAs, bins):
    """
    ## inputs
    - `lnqs`: shape `(K, Nbin)` natural-logarithmic process elements
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `bins`: shape `(N, )` metallicity bin integers

    ## outputs
    shape `(N, )` log_10 abundances

    ## comments
    - Note the `ln10`.
    """
    return logsumexp(lnqs[:, bins] + lnAs, axis=0) / ln10

def q_step_regularization(lnqs):
    """
    Build arrays that are used for the regularization of the q step.

    ## inputs
    - `lnqs`: shape `(2, Nbin, M)` current values of the lnqs.

    ## outputs
    `Lambdas, q0s` regularization amplitudes and mean values; same shape as `lnqs`.

    ## bugs:
    - Depends on many global variables and choices.
    """
    Lambdas = np.zeros_like(lnqs)
    q0s = np.zeros_like(lnqs) + 0.5 # default values

    # First point: Strongly require that q_Mg = 1 for CC
    elem = elements == "Mg"
    proc = processes == "CC"
    Lambdas[proc, :, elem] = Lambda_a
    q0s[    proc, :, elem] = 1.0

    # Second point: Strongly require that q_Fe = 0.5 for Ia (the value set in q0s below)
    elem = elements == "Fe"
    proc = processes == "Ia"
    Lambdas[proc, :, elem] = Lambda_a
    q0s[    proc, :, elem] = 0.5

    # Third point: Weakly require that q_Mg = 0 for Ia
    elem = elements == "Mg"
    proc = processes == "Ia"
    Lambdas[proc, :, elem] = Lambda_b
    q0s[    proc, :, elem] = 0.0

    return Lambdas, q0s

def one_element_chi(lnqs, lnAs, data, ivars, bins, Lambdas, q0s):
    """
    ## inputs
    - `lnqs`: shape `(2, Nbin)` natural-logarithmic process vectors
    - `lnAs`: shape `(2, N)` natural-logarithmic amplitudes
    - `data`: shape `(N, )` log_10 abundance measurements
    - `ivars`: shape `(N, )` inverse variances on the data
    - `bins`: shape `(N, )` metallicity bin integers
    - `Lambdas`: shape `(2, Nbin)` list of regularization amplitudes
    - `q0s`: shape `(2, Nbin)` regularization mean (target) values for `exp(lnqs)`

    ## outputs
    chi for this one element (weighted residuals, followed by the regularization terms)
    """
    return jnp.concatenate([jnp.sqrt(ivars) * (data - one_element_K_process_model(lnqs, lnAs, bins)),
                            jnp.ravel(jnp.sqrt(Lambdas) * (jnp.exp(lnqs) - q0s))])

def one_element_q_step(lnAs, data, ivars, bins, Nbin, Lambdas, q0s, init=None):
    """
    ## inputs
    - `lnAs`: shape `(2, N)` natural-logarithmic amplitudes
    - `data`: shape `(N, )` log_10 abundance measurements
    - `ivars`: shape `(N, )` inverse variances on the data
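    - `bins`: shape `(N, )` metallicity bin integers
    - `Nbin`: integer number of metallicity bins
    - `Lambdas`: shape `(2, Nbin)` regularization amplitudes
    - `q0s`: shape `(2, Nbin)` regularization mean values
    - `init`: shape `(2, Nbin)` optional initialization for the optimizer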

    ## outputs
    shape `(2, Nbin)` best-fit natural-logarithmic process elements

    ## bugs
    """
    solver = jaxopt.GaussNewton(residual_fun=one_element_chi)
    if init is None:
        lnqs_init = jnp.log(jnp.array([[0.5, ] * Nbin,[0.5, ] * Nbin])) 
    else:
        lnqs_init = init.copy()

    # use jnp.sum so the reductions stay inside JAX when this function is vmapped
    chi2_init = jnp.sum(one_element_chi(lnqs_init, lnAs=lnAs, data=data, ivars=ivars,
                               bins=bins, Lambdas=Lambdas, q0s=q0s) ** 2)
    res = solver.run(lnqs_init, lnAs=lnAs, data=data, ivars=ivars, bins=bins,
                     Lambdas=Lambdas, q0s=q0s)
    chi2_res = jnp.sum(one_element_chi(res.params, lnAs=lnAs, data=data, ivars=ivars,
                               bins=bins, Lambdas=Lambdas, q0s=q0s) ** 2)

    return res.params, chi2_init - chi2_res

def q_step(lnAs, alldata, allivars, bins, Lambdas, q0s, old_lnqs):
    """
    ## inputs
    - `lnAs`: shape `(2, N)` natural-logarithmic amplitudes
    - `alldata`: shape `(N, M)` log_10 abundance measurements
    - `allivars`: shape `(N, M)` inverse variances on alldata
    - `bins`: shape `(N, )` metallicity bin integers 
    - `Lambdas`: shape `(2, Nbin, M)` regularization amplitudes
    - `q0s`: shape `(2, Nbin, M)` regularization mean values
    - `old_lnqs`: shape `(2, Nbin, M)` initialization for optimizations

    ## outputs
    shape `(2, Nbin, M)` best-fit natural-logarithmic processes

    ## bugs
    - Should be a `vmap()` not a `for` loop. Maybe?
    """
    N, M = alldata.shape
    Nbin = np.max(bins) + 1
    assert lnAs.shape == (2, N)
    assert allivars.shape == (N, M)
    assert bins.shape == (N, )
    assert Lambdas.shape == (2, Nbin, M)
    assert q0s.shape == (2, Nbin, M)
    assert old_lnqs.shape == (2, Nbin, M)
    return vmap(one_element_q_step, in_axes=(None, 1, 1, None, None, 2, 2, 2),
               out_axes=(2, 0))(lnAs, alldata, allivars, bins, Nbin, Lambdas, q0s, old_lnqs)
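
The regularization in `one_element_chi` above works by appending extra "residuals" to the chi vector, so that a least-squares solver minimizes chi-squared plus a quadratic penalty. The NumPy sketch below illustrates that identity with toy numbers; all `toy_*` arrays, the sizes, and the value `1.e3` are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(42)
N, K, Nbin = 50, 2, 3                            # toy sizes (hypothetical)
data_resid = rng.normal(size=N)                  # stands in for sqrt(ivars) * (data - model)
toy_qs = rng.uniform(0.1, 1.0, size=(K, Nbin))   # stands in for exp(lnqs)
toy_q0s = np.full((K, Nbin), 0.5)                # regularization mean values
toy_Lambdas = np.zeros((K, Nbin))
toy_Lambdas[0, :] = 1.e3                         # regularize only the first process here

# Same construction as in one_element_chi: data residuals, then penalty residuals
chi = np.concatenate([data_resid,
                      np.ravel(np.sqrt(toy_Lambdas) * (toy_qs - toy_q0s))])

chi2_data = np.sum(data_resid ** 2)
penalty = np.sum(toy_Lambdas * (toy_qs - toy_q0s) ** 2)
print(np.isclose(np.sum(chi ** 2), chi2_data + penalty))
```

Entries with `Lambdas = 0` contribute zero residuals, so unregularized elements are unaffected.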

In [None]:
start = time.time()
Lambdas, q0s = q_step_regularization(w22_lnqs)
new_lnqs, delta_chi2s = q_step(w22_lnAs, alldata, allivars, bins, Lambdas, q0s, w22_lnqs)
end = time.time()
print("q-step timing:", end - start)
print("q-step shapes", w22_lnqs.shape, new_lnqs.shape)
print("q-step delta-chi-squareds:", delta_chi2s)

In [None]:
new_qs = np.exp(new_lnqs)
qs = np.exp(w22_lnqs)

elems = elements
MgH = [float(m) for m in metallicities]

plt.figure(figsize=(10,10))
for i in range(16):
  plt.subplot(4,4,i+1)
  new_qcc = new_qs[0,:,i]
  new_qIa = new_qs[1,:,i]
  qcc = qs[0,:,i]
  qIa = qs[1,:,i]

  plt.plot(MgH, qcc, 'b-', label='qcc W22')
  plt.plot(MgH, qIa, 'r-', label='qIa W22')

  plt.plot(MgH, new_qcc, 'c-', label='qcc new')
  plt.plot(MgH, new_qIa, 'm-', label='qIa new')

  plt.xlabel('[Mg/H]')
  plt.ylabel('q '+elems[i])

  if i==0:
    plt.legend(ncol=1, fontsize=10)
  #plt.ylim(-0.1,1.1)
plt.tight_layout()

In [None]:
synthdata1 = all_stars_K_process_model(new_lnAs, new_lnqs, bins)
plt.figure(figsize=(12,4))

plt.subplot(1,3,1)
plt.hist2d(alldata[:,0], alldata[:,10]-alldata[:,0],cmap='magma', bins=100, 
         range=[[-0.7,0.4],[-0.5,0.2]], norm=LogNorm())
plt.xlabel('[Mg/H]')
plt.ylabel('[Fe/Mg]')
plt.ylim(-0.5,0.2)
plt.title('Observed')

plt.subplot(1,3,2)
plt.hist2d(synthdata1[:,0], synthdata1[:,10]-synthdata1[:,0],cmap='magma', bins=100, 
         range=[[-0.7,0.4],[-0.5,0.2]], norm=LogNorm())
plt.xlabel('[Mg/H]')
plt.ylabel('[Fe/Mg]')
plt.ylim(-0.5,0.2)
plt.title('Fit')

plt.subplot(1,3,3)
plt.hist2d(alldata[:,0]-synthdata1[:,0],(alldata[:,10])-(synthdata1[:,10]),
           cmap='magma', bins=100, range=[[-0.05,0.05],[-0.05,0.05]], norm=LogNorm())
plt.xlabel('[Mg/H] obs - fit')
plt.ylabel('[Fe/H] obs - fit')

plt.tight_layout()


In [None]:
# Okay now do some iterations
list_lnAs = []
list_lnqs = []
new_lnAs = w22_lnAs.copy()
new_lnqs = w22_lnqs.copy() 
list_lnAs.append(new_lnAs)
list_lnqs.append(new_lnqs)

for i in range(33):
    start = time.time()
    new_lnAs, dc2s = A_step(new_lnqs, alldata, allivars, bins, new_lnAs)
    print("A-step timing:", i, time.time() - start)
    print("A-step Chi2:", i,  np.nanmin(dc2s), np.nanmax(dc2s))
    print("A-step bad elements:", i,  np.sum(np.isnan(new_lnAs)))
    start = time.time()
    Lambdas, q0s = q_step_regularization(new_lnqs)
    new_lnqs, dc2s = q_step(new_lnAs, alldata, allivars, bins, Lambdas, q0s, new_lnqs)
    print("q-step timing:", i, time.time() - start)
    print("q-step Chi2:", i,  np.nanmin(dc2s), np.nanmax(dc2s))
    print("q-step bad elements:", i,  np.sum(np.isnan(new_lnqs)))
    if i < 2 or np.isclose(np.log2(i), np.round(np.log2(i))):
        print("appending")
        list_lnAs.append(new_lnAs)
        list_lnqs.append(new_lnqs)
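
The condition `i < 2 or np.isclose(np.log2(i), np.round(np.log2(i)))` in the loop above saves snapshots only at iterations 0, 1, and powers of two. A small standard-library sketch of the same schedule, with an integer-only bit test shown as an alternative (the names `saved` and `saved_bits` are illustrative):

```python
import math

n_iter = 33  # same loop length as in the iteration cell above

# Same test as in the loop, written with the math module
saved = [i for i in range(n_iter)
         if i < 2 or math.isclose(math.log2(i), round(math.log2(i)))]

# Alternative integer-only test: a power of two has exactly one set bit
saved_bits = [i for i in range(n_iter) if i < 2 or (i & (i - 1)) == 0]

print(saved)  # [0, 1, 2, 4, 8, 16, 32]
```

Together with the initial W22 entry, this gives eight entries in `list_lnAs` and `list_lnqs`.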

In [None]:
NUM_COLORS = len(list_lnqs) - 1
cm = plt.get_cmap('jet')
colors =["grey"]
for i in range(NUM_COLORS):
    color = cm(1.*i/NUM_COLORS)
    colors.append(color)

In [None]:
# plot evolution of qcc
plt.figure(figsize=(15,10))
MgH = [float(m) for m in metallicities]
for j in range(16):
  plt.subplot(4,4,j+1)
  for i, lnqs in enumerate(list_lnqs):
    label = None
    if j == 0 and i == 0:
        label = "W22"
    if j == 0 and i + 1 == len(list_lnqs):
        label = "final iteration"
    qs = np.exp(lnqs)
    qcc = qs[0,:,j]
    plt.plot(MgH, qcc, '-', color=colors[i], lw=1.5, label=label)
    plt.xlabel('[Mg/H]')
    plt.ylabel('q_cc '+elements[j])
    plt.ylim(np.min(qcc)-0.101, np.max(qcc)+0.101)
  if j == 0:
      plt.legend()
plt.tight_layout()

In [None]:
# plot evolution of qIa
plt.figure(figsize=(15,10))
MgH = [float(m) for m in metallicities]
for j in range(16):
  plt.subplot(4,4,j+1)
  for i, lnqs in enumerate(list_lnqs):
    qs = np.exp(lnqs)
    qIa = qs[1,:,j]
    plt.plot(MgH, qIa, '-', color=colors[i], lw=1.5)
    plt.xlabel('[Mg/H]')
    plt.ylabel('q_Ia '+elements[j])
    plt.ylim(np.min(qIa)-0.101, np.max(qIa)+0.101)
plt.tight_layout()

In [None]:
synthdata1 = all_stars_K_process_model(new_lnAs, new_lnqs, bins)
plt.figure(figsize=(12,4))

plt.subplot(1,3,1)
plt.hist2d(alldata[:,0], alldata[:,10]-alldata[:,0],cmap='magma', bins=100, 
         range=[[-0.7,0.4],[-0.5,0.2]], norm=LogNorm())
plt.xlabel('[Mg/H]')
plt.ylabel('[Fe/Mg]')
plt.ylim(-0.5,0.2)
plt.title('Observed')

plt.subplot(1,3,2)
plt.hist2d(synthdata1[:,0], synthdata1[:,10]-synthdata1[:,0],cmap='magma', bins=100, 
         range=[[-0.7,0.4],[-0.5,0.2]], norm=LogNorm())
plt.xlabel('[Mg/H]')
plt.ylabel('[Fe/Mg]')
plt.ylim(-0.5,0.2)
plt.title('Fit')

plt.subplot(1,3,3)
plt.hist2d(alldata[:,0]-synthdata1[:,0],(alldata[:,10]-synthdata1[:,10]),
           cmap='magma', bins=100, range=[[-0.05,0.05],[-0.05,0.05]], norm=LogNorm())
plt.xlabel('[Mg/H] obs - fit')
plt.ylabel('[Fe/H] obs - fit')

plt.tight_layout()

In [None]:
# plot evolution of [X/H]_data - [X/H]_synth

plt.figure(figsize=(15,10))
sigma = 1/np.sqrt(allivars)


for j in range(16):
  plt.subplot(4,4,j+1)

  mask = np.where(np.isfinite(sigma[:,j]))[0]
  plt.text(-10,8E2, elements[j] + r' $\bar{\sigma}= $' + 
           str(np.round(sigma[:,j][mask].mean(), 4)))

  for i, (lnqs, lnAs) in enumerate(zip(list_lnqs, list_lnAs)):
    synthdata1 = all_stars_K_process_model(lnAs, lnqs, bins)
    plt.hist((alldata[:,j]-synthdata1[:,j])*np.sqrt(allivars[:,j]), color=colors[i], lw=1, 
             range=[-10,10], bins=100, histtype='step', stacked=True, 
             fill=False)
  plt.xlabel(r'(dat-mod)/$\sigma$')  # raw string avoids the invalid '\s' escape
  plt.ylabel('Counts')

plt.tight_layout()

The q vectors and A values can be used to derive the fractional CCSN contribution to element X at a given metallicity z:

$f_{\rm CC}^X = \dfrac{A_{\rm CC}\, q_{\rm CC}^X(z)}{A_{\rm CC}\, q_{\rm CC}^X(z) + A_{\rm Ia}\, q_{\rm Ia}^X(z)} = \left[1 + \dfrac{A_{\rm Ia}}{A_{\rm CC}}\,\dfrac{q_{\rm Ia}^X(z)}{q_{\rm CC}^X(z)}\right]^{-1}$
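
As a worked example of the expression above, the sketch below evaluates both forms of the CCSN fraction for made-up amplitudes and process values and checks that they agree. The function names and all numbers are illustrative assumptions, not fitted values.

```python
def fcc_direct(A_cc, q_cc, A_Ia, q_Ia):
    """CCSN fraction from the first form: A_cc q_cc / (A_cc q_cc + A_Ia q_Ia)."""
    return A_cc * q_cc / (A_cc * q_cc + A_Ia * q_Ia)

def fcc_ratio(A_cc, q_cc, A_Ia, q_Ia):
    """CCSN fraction from the second form: [1 + (A_Ia/A_cc)(q_Ia/q_cc)]^-1."""
    return 1.0 / (1.0 + (A_Ia / A_cc) * (q_Ia / q_cc))

# Toy Fe-like element: both processes contribute (assumed values)
f_direct = fcc_direct(1.2, 0.5, 0.8, 0.5)
f_ratio = fcc_ratio(1.2, 0.5, 0.8, 0.5)

# Toy Mg-like element: no SNIa contribution, so the fraction is exactly 1
f_pure_cc = fcc_ratio(1.2, 1.0, 0.8, 0.0)

print(f_direct, f_ratio, f_pure_cc)
```

The ratio form needs only `A_Ia/A_cc` and `q_Ia/q_cc`, which is why the W22 comparison curves later in the notebook can be built from tabulated amplitude ratios.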

In [None]:
w22_AIaAcc_lowIa = np.array([0.055, 0.036, 0.051, 0.053, 0.058, 0.089, 0.128, 
                             0.189, 0.350, 0.548, 0.636, 0.632])
w22_AIaAcc_highIa = np.array([0.710, 0.753, 0.734, 0.719, 0.766, 0.875, 0.960, 
                              1.000, 1.028, 1.042, 1.042, 1.018])

In [None]:
np.shape(w22_lnqs)

In [None]:
f_metals = [float(m) for m in metallicities]

plt.figure(figsize=(15,10))
for i in range(len(elements)):
  plt.subplot(4,4,i+1)
  fccs = []
  for j in range(len(metallicities)):
    mask = np.where(bins==j)[0]
    bin_lnAs = new_lnAs[:,mask]
    bin_Acc = np.nanmedian(np.exp(bin_lnAs[0,:]))
    bin_AIa = np.nanmedian(np.exp(bin_lnAs[1,:]))
    bin_qcc = np.exp(new_lnqs[0,j,i])
    bin_qIa = np.exp(new_lnqs[1,j,i])
    fcc = (1 + ((bin_AIa/bin_Acc)*(bin_qIa/bin_qcc)))**-1
    fccs.append(fcc)
    if metallicities[j] == '0.0': print(elements[i], fcc)

  w22_fcc_lowIa = (1 + ((w22_AIaAcc_lowIa)*(np.exp(w22_lnqs[1,:,i])/
                                            np.exp(w22_lnqs[0,:,i]))))**-1
  w22_fcc_highIa = (1 + ((w22_AIaAcc_highIa)*(np.exp(w22_lnqs[1,:,i])/
                                            np.exp(w22_lnqs[0,:,i]))))**-1

  plt.plot(f_metals, w22_fcc_lowIa, 'r-', lw=1, label='W22 lowIa')
  plt.plot(f_metals, w22_fcc_highIa, 'b-', lw=1, label='W22 highIa')
  plt.plot(f_metals, fccs, 'ko-', label='Final iteration')
  
  #plt.title(elements[i])
  plt.xlabel('[Mg/H]')
  plt.ylabel('f_cc '+ elements[i])
  plt.ylim(-0.05,1.05)
  if i ==0: plt.legend(fontsize=10)
plt.tight_layout()
plt.show()

In [None]:
def all_stars_fcc(lnAs, lnqs, bins):
    """
    ## inputs
    - `lnAs`: shape `(K, N)` natural-logarithmic amplitudes
    - `lnqs`: shape `(K, Nbin, M)` natural-logarithmic processes
    - `bins`: shape `(N, )` metallicity bin integers

    ## outputs
    shape `(N, M)` fractional CCSN contribution (fcc) for every star and element

    """
    Accqcc = np.exp(lnAs[0,:,None])*np.exp(lnqs[0,bins,:])
    AIaqIa = np.exp(lnAs[1,:,None])*np.exp(lnqs[1,bins,:])
    return Accqcc / (Accqcc + AIaqIa)

In [None]:
f_metals = [float(m) for m in metallicities]

plt.figure(figsize=(15,10))
fccs = all_stars_fcc(new_lnAs, new_lnqs, bins)
for i in range(len(elements)):
  plt.subplot(4,4,i+1)

  w22_fcc_lowIa = (1 + ((w22_AIaAcc_lowIa)*(np.exp(w22_lnqs[1,:,i])/
                                            np.exp(w22_lnqs[0,:,i]))))**-1
  w22_fcc_highIa = (1 + ((w22_AIaAcc_highIa)*(np.exp(w22_lnqs[1,:,i])/
                                            np.exp(w22_lnqs[0,:,i]))))**-1

  plt.plot(f_metals, w22_fcc_lowIa, 'r-', lw=1, label='W22 lowIa')
  plt.plot(f_metals, w22_fcc_highIa, 'b-', lw=1, label='W22 highIa')
  plt.hist2d(synthdata1[:,0], fccs[:,i], norm=LogNorm(), bins=100, range=[[-0.8,0.6],[-0.1,1.1]])
  
  #plt.title(elements[i])
  plt.xlabel('[Mg/H]')
  plt.ylabel('f_cc '+ elements[i])
  plt.ylim(-0.05,1.05)
  if i ==0: plt.legend(fontsize=10)
plt.tight_layout()
plt.show()

In [None]:
print(elements)

In [None]:
plt.plot(np.exp(new_lnAs[0,:]), np.exp(new_lnAs[1,:])/np.exp(new_lnAs[0,:]), ',')
plt.xlim(0,3)
plt.ylim(-0.5,2)
plt.xlabel('Acc')
plt.ylabel('AIa/Acc')

Thoughts 
- In W22 and G22, we plot AIa/Acc vs. Acc, which acts similarly to the Mg/H vs. Fe/H plot, where you can see the bimodality. In our work, the two populations are exactly separated because we use the populations and the median trends to solve for the qs and As. When doing the data-driven model, do we also preserve the bimodality?
- Knowing that CCSN and SNIa dominate the production of most of these elements, we could fit the two-process model and present the residuals, then fix the process vectors and amplitudes and fit a third process to see whether it can account for the remaining abundances.