In [1]:
import numpy as np

from Starfish.grid_tools import download_PHOENIX_models

# Bounds of the PHOENIX sub-grid to fetch, one [min, max] pair per grid axis.
# The axis order here is reused by every later cell that receives `ranges`.
ranges = [
    [5700, 8600],  # T
    [4.0, 6.0],  # logg
    [-0.5, 0.5],  # Z
]

download_PHOENIX_models(path="PHOENIX", ranges=ranges)

lte08600-6.00+0.5.PHOENIX-ACES-AGSS-COND-2011-HiRes.fits: 100%|██████████| 330/330 [05:30<00:00,  1.00s/it]


In [2]:
from Starfish.grid_tools import PHOENIXGridInterfaceNoAlpha

# Grid interface over the raw PHOENIX files downloaded into ./PHOENIX above.
grid = PHOENIXGridInterfaceNoAlpha(path="PHOENIX")

In [3]:
import os

from Starfish.grid_tools import HDF5Creator
from Starfish.grid_tools.instruments import SPEX

# Processing the raw grid for the SPEX instrument and writing it to HDF5 takes
# several minutes, so skip it when the output file already exists. Nothing is
# loaded here: the emulator cell below opens the file by name.
# NOTE(review): an existing file is reused as-is, even if it came from an
# interrupted run or a different `ranges`; delete it to force a rebuild.
grid_file = "F_SPEX_grid.hdf5"
if os.path.exists(grid_file):
    print("Loading existing grid file")
else:
    # wl_range lower bound of 0.9e4 — presumably Angstrom; confirm against SPEX.
    creator = HDF5Creator(
        grid, grid_file, instrument=SPEX(), wl_range=(0.9e4, np.inf), ranges=ranges
    )
    creator.process_grid()

Processing [8.6e+03 6.0e+00 5.0e-01]: 100%|██████████| 330/330 [06:13<00:00,  1.13s/it]   


In [5]:
from Starfish.emulator import Emulator
grid_file = "F_SPEX_grid.hdf5"

# Build an (as yet untrained) emulator from the processed instrument grid.
# can load from string or HDF5Interface
emu = Emulator.from_grid(grid_file)
emu

Emulator
--------
Trained: False
lambda_xi: 1.000
Variances:
	10000.00
	10000.00
	10000.00
	10000.00
Lengthscales:
	[ 600.00  1.50  1.50 ]
	[ 600.00  1.50  1.50 ]
	[ 600.00  1.50  1.50 ]
	[ 600.00  1.50  1.50 ]
Log Likelihood: -1272.34

In [6]:
%time emu.train(options=dict(maxiter=1e5))
emu

KeyboardInterrupt: 

Emulator
--------
Trained: False
lambda_xi: 0.985
Variances:
	23248.07
	2029.89
	663.12
	156.11
Lengthscales:
	[ 1486.69  2.82  2.26 ]
	[ 1464.05  1.20  2.36 ]
	[ 1871.30  1.54  1.80 ]
	[ 868.99  1.00  1.63 ]
Log Likelihood: -815.60

In [None]:
# Visual check of the emulator produced above.
%matplotlib inline
from Starfish.emulator.plotting import plot_emulator

plot_emulator(emu)

In [None]:
from Starfish.spectrum import Spectrum

# Load the spectrum that the model below will be fit to.
data = Spectrum.load("data/example_spec.hdf5")


In [None]:
from Starfish.models import SpectrumModel

# Forward model built from the saved emulator file and the loaded data.
# grid_params presumably follow the grid axis order used for `ranges`
# (T, logg, Z) — confirm. The global covariance starts from the given
# log amplitude / log length-scale.
model = SpectrumModel(
    "F_SPEX_emu.hdf5",
    data,
    global_cov={"log_amp": 38, "log_ls": 2},
    grid_params=[6800, 4.2, 0],
    Av=0,
)
model

In [None]:
# Plot the model at its initial parameters (trailing ';' suppresses the repr).
model.plot();


In [None]:
import scipy.stats as st

# Exclude logg from the fit; it stays at its current value.
model.freeze("logg")

# Priors keyed by parameter label. scipy's `uniform(loc, scale)` covers
# [loc, loc + scale], so a metallicity prior spanning the grid's Z range
# [-0.5, 0.5] needs scale=1.0 (scale=0.5 would stop at Z=0).
priors = {
    "T": st.norm(6800, 100),
    "Z": st.uniform(-0.5, 1.0),
    "Av": st.halfnorm(0, 0.2),
    "global_cov:log_amp": st.norm(38, 1),
    "global_cov:log_ls": st.uniform(0, 10),
}

model.labels  # These are the fittable parameters

In [None]:
%time model.train(priors)
model

In [None]:
# Plot the model after MAP training.
model.plot();


In [None]:
# Record the emcee version: the cells below use its backends/get_chain API.
import emcee

emcee.__version__


In [None]:
# Start sampling from the saved MAP solution, holding the global covariance
# parameters fixed so only the remaining labels are sampled.
model.load("example_MAP.toml")
model.freeze("global_cov")
model.labels

In [None]:
import numpy as np

# Seeded generator so the walkers' starting positions are reproducible.
rng = np.random.default_rng(42)

# Set our walkers and dimensionality
nwalkers = 50
ndim = len(model.labels)

# Initialize a Gaussian ball around the current (MAP) parameter values; each
# parameter's spread comes from `scales`, keyed by label.
# NOTE(review): with Av near 0, some walkers may start at Av < 0 where the
# half-normal prior is zero — consider reflecting if that slows burn-in.
scales = {"T": 1, "Av": 0.01, "Z": 0.01}

centers = np.array([model[key] for key in model.labels])
widths = np.array([scales[key] for key in model.labels])
ball = centers + widths * rng.standard_normal((nwalkers, ndim))

In [None]:
# our objective to maximize
def log_prob(P, priors):
    """Log-probability evaluated by emcee at parameter vector ``P``.

    Sets ``P`` on the notebook-global ``model`` (presumably in ``model.labels``
    order) and returns ``model.log_likelihood(priors)``. Passing ``priors``
    suggests prior terms are folded in — confirm against Starfish's
    SpectrumModel docs. Side effect: mutates ``model``'s parameters.
    """
    model.set_param_vector(P)
    return model.log_likelihood(priors)


# Set up our backend and sampler
# Every step is written to example_chain.hdf5; reset() clears whatever chain the
# file already holds, so each run of this cell starts a fresh chain.
backend = emcee.backends.HDFBackend("example_chain.hdf5")
backend.reset(nwalkers, ndim)
sampler = emcee.EnsembleSampler(
    nwalkers, ndim, log_prob, args=(priors,), backend=backend
)

In [None]:
max_n = 1000

# Track how the mean autocorrelation-time estimate evolves. It is recorded once
# every 10 steps, so at most max_n // 10 + 1 entries are ever written. The
# array is NaN-filled so unused slots are visibly empty rather than holding
# uninitialized memory.
index = 0
autocorr = np.full(max_n // 10 + 1, np.nan)

# Previous tau estimate, used by the stability part of the convergence test
old_tau = np.inf

# Sample for up to max_n steps
for _state in sampler.sample(ball, iterations=max_n, progress=True):
    # Only check convergence every 10 steps
    if sampler.iteration % 10:
        continue

    # tol=0 always returns an estimate, even when the chain is still too short
    # for it to be trustworthy
    tau = sampler.get_autocorr_time(tol=0)
    autocorr[index] = np.mean(tau)
    index += 1

    # A NaN or zero estimate would make the relative-change test below
    # meaningless (or divide by zero), so wait for the next check
    if np.isnan(tau).any() or (tau == 0).any():
        continue

    # Converged once the chain is longer than 10 autocorrelation times and the
    # estimate changed by less than 1% since the previous check
    converged = np.all(tau * 10 < sampler.iteration)
    converged &= np.all(np.abs(old_tau - tau) / tau < 0.01)
    if converged:
        print(f"Converged at sample {sampler.iteration}")
        break
    old_tau = tau
else:
    print(f"Not converged after {sampler.iteration} samples")

# Keep only the estimates actually recorded
autocorr = autocorr[:index]

In [None]:
# Record versions of the posterior-analysis and plotting packages.
import arviz as az
import corner

print(az.__version__, corner.__version__)

In [None]:
# Reopen the chain written during sampling and wrap the full chain (burn-in
# included) as arviz data, naming variables after the sampled labels.
reader = emcee.backends.HDFBackend("example_chain.hdf5")
full_data = az.from_emcee(reader, var_names=model.labels)

In [None]:
# Trace plot of the full chain, burn-in included.
az.plot_trace(full_data);


In [None]:
tau = reader.get_autocorr_time(tol=0)
# Discard the first max(tau) steps as burn-in
burnin = int(tau.max())
# Thin by ~0.3 * min(tau). int() truncates to 0 when min(tau) < ~3.3, and the
# thinning step must be a positive integer, so never go below 1.
thin = max(1, int(0.3 * np.min(tau)))
burn_samples = reader.get_chain(discard=burnin, thin=thin)
log_prob_samples = reader.get_log_prob(discard=burnin, thin=thin)
# log_prob returns a bare scalar (no blobs), so the backend stores no
# per-step log-prior values to read back here.

# get_chain is (steps, walkers, ndim); transposing gives one (walkers, steps)
# array per label, which arviz reads as (chain, draw).
dd = dict(zip(model.labels, burn_samples.T))
burn_data = az.from_dict(dd)

In [None]:
# Trace plot after discarding burn-in and thinning.
az.plot_trace(burn_data);


In [None]:
# Summary statistics for the burned-in, thinned posterior samples.
az.summary(burn_data)


In [None]:
# Marginal posterior of each sampled parameter.
az.plot_posterior(burn_data, ["T", "Z", "Av"]);


In [None]:
# Contour levels enclosing the 1-sigma and 2-sigma regions of a 2-D Gaussian.
# See https://corner.readthedocs.io/en/latest/pages/sigmas.html#a-note-about-sigmas
sigmas = ((1 - np.exp(-0.5)), (1 - np.exp(-2)))
corner.corner(
    # Flatten (steps, walkers, ndim) into (samples, ndim); ndim is read from the
    # chain itself rather than hard-coded, so freezing/unfreezing parameters
    # doesn't break the plot.
    burn_samples.reshape((-1, burn_samples.shape[-1])),
    labels=model.labels,
    quantiles=(0.05, 0.16, 0.84, 0.95),
    levels=sigmas,
    show_titles=True,
);

In [None]:
# Posterior mean of each parameter from the burned-in, thinned samples. Taken
# straight from the posterior group because az.summary rounds its statistics
# (to 2 decimals by default), which needlessly truncates small-valued
# parameters such as Z and Av before they are set on the model and saved.
posterior_means = burn_data.posterior.mean()
best_fit = {name: float(value) for name, value in posterior_means.data_vars.items()}
model.set_param_dict(best_fit)
model

In [None]:
# Plot the model at the posterior-mean parameters.
model.plot();


In [None]:
# Save the current model state (posterior-mean parameters) to TOML.
model.save("example_sampled.toml")
