In [None]:
import time

import numpy as np
import matplotlib.pyplot as plt
import corner
from taurex.log.logger import root_logger
%matplotlib inline

In [None]:
from taurex.parameter import ParameterParser

# Read the TauREx parameter file and build what it describes: the observation,
# the binning request and the forward model. Order as written: read the file,
# apply global settings, then generate/build the objects.
pp = ParameterParser()

# Parse the input file (swap the commented line in to use the other parfile)
# input_file = "parfile.par"
input_file = "parfile-gpu.par"
pp.read(input_file)

# Setup global parameters from the parfile
pp.setup_globals()

# Get the spectrum. Later cells treat `observation` as one of: None, the string
# "self" (simulate it from an instrument), or a spectrum object.
observation = pp.generate_observation()

# Binning request; later cells handle None, "native", "observed", or a
# (binner, wavenumber grid) pair.
binning = pp.generate_binning()

# Generate a model from the input (the observation is passed in, presumably so
# the parser can choose a matching model type — TODO confirm)
model = pp.generate_appropriate_model(obs=observation)

# build the model
model.build()

In [None]:
# Resolve `binning` (from the parfile) into an actual binner object, and record
# the wavenumber grid that binner corresponds to in `wngrid`.
#   None       -> native model grid, or the observation's grid if one exists
#   "native"   -> the model's native grid
#   "observed" -> the observation's grid
#   otherwise  -> unpacked as a (binner, wngrid) pair
wngrid = None

if binning == "observed" and observation is None:
    root_logger.critical(
        "Binning selected from Observation yet None provided"
    )
    # Raise instead of calling quit(): inside Jupyter, quit() shuts the kernel
    # down and discards all state; an exception just stops execution here.
    raise ValueError("Binning selected from Observation yet None provided")

if binning is None:
    if observation is None or observation == "self":
        binning = model.defaultBinner()
        wngrid = model.nativeWavenumberGrid
    else:
        binning = observation.create_binner()
        wngrid = observation.wavenumberGrid
else:
    if binning == "native":
        binning = model.defaultBinner()
        wngrid = model.nativeWavenumberGrid
    elif binning == "observed":
        binning = observation.create_binner()
        wngrid = observation.wavenumberGrid
    else:
        binning, wngrid = binning

In [None]:
# Build the instrument (if the parfile defines one), binned with the binner
# resolved in the previous cell.
instrument = pp.generate_instrument(binner=binning)

# A non-None result is unpacked as an (instrument, number of observations) pair.
num_obs = 1
if instrument is not None:
    instrument, num_obs = instrument

# A "self" observation is simulated from the instrument, so one is mandatory.
if observation == "self" and instrument is None:
    root_logger.critical("Instrument must be specified when using self option")
    raise ValueError("No instrument specified for self option")

# Simulate the instrument's view of the current forward model; the next cell
# uses this when the observation is "self".
inst_result = None
if instrument is not None:
    inst_result = instrument.model_noise(
        model, model_res=model.model(), num_observations=num_obs
    )

In [None]:
# Observation on self: replace the "self" placeholder with a spectrum built
# from the instrument simulation computed in the previous cell.
if observation == "self":
    from taurex.data.spectrum import ArraySpectrum
    from taurex.util.util import wnwidth_to_wlwidth

    # Instrument output, on a wavenumber grid.
    inst_wngrid, inst_spectrum, inst_noise, inst_width = inst_result

    # Wavenumber -> wavelength for the grid centres (10000 / wn) ...
    inst_wlgrid = 10000 / inst_wngrid

    # ... and for the bin widths, via TauREx's helper.
    inst_wlwidth = wnwidth_to_wlwidth(inst_wngrid, inst_width)

    # One row per point; columns: wavelength, spectrum, noise, wavelength width.
    observation = ArraySpectrum(
        np.vstack(
            [
                inst_wlgrid,
                inst_spectrum,
                inst_noise,
                inst_wlwidth,
            ]
        ).T
    )

    # The model is then binned onto this new observation's grid.
    binning = observation.create_binner()

In [None]:
# Reset first, so an aborted re-run cannot leave a stale optimizer/solution
# from a previous run in the kernel namespace.
optimizer = None
solution = None

if observation is None:
    root_logger.critical("No spectrum is defined!!")
    # Raise instead of calling quit(): inside Jupyter, quit() shuts the kernel
    # down and discards all state; an exception just stops execution here.
    raise ValueError("No spectrum is defined!!")

# Create the optimizer described in the parfile, attach the model and data,
# then let the parser apply its optimizer settings.
optimizer = pp.generate_optimizer()
optimizer.set_model(model)
optimizer.set_observed(observation)
pp.setup_optimizer(optimizer)

In [None]:
# Run the retrieval. perf_counter() is a monotonic clock meant for measuring
# durations (time.time() can jump if the system clock is adjusted).
start_time = time.perf_counter()
solution = optimizer.fit()
end_time = time.perf_counter()

root_logger.info("Total Retrieval finish in %s seconds", end_time - start_time)

# Load the optimized parameters of the FIRST solution back into the model, so
# the model evaluated below reflects that fit. Does nothing if there is none.
first_solution = next(iter(optimizer.get_solution()), None)
if first_solution is not None:
    _, optimized, _, _ = first_solution
    optimizer.update_model(optimized)

result = model.model()

# Let's plot

In [None]:
# Y-axis label for the fitted-spectrum plot, keyed by forward-model class name.
# Each CUDA model shares the label of its CPU counterpart.
modelAxis = dict(
    [
        ("TransmissionModel", "$(R_p/R_*)^2$"),
        ("TransmissionCudaModel", "$(R_p/R_*)^2$"),
        ("EmissionModel", "$F_p/F_*$"),
        ("EmissionCudaModel", "$F_p/F_*$"),
        ("DirectImageModel", "$F_p$"),
        ("DirectImageCudaModel", "$F_p$"),
    ]
)

In [None]:
# Plot the observed spectrum with the fitted model and its 1-sigma / 2-sigma bands.
fig = plt.figure(figsize=(10.6, 7.0))

# Observed data as held by the optimizer.
obs_spectrum = optimizer._observed.spectrum
error = optimizer._observed.errorBar
wlgrid = optimizer._observed.wavelengthGrid

# Observed points: error bars plus diamond markers. They are drawn once here,
# outside the per-solution loop, so multiple solutions don't re-plot them.
plt.errorbar(
    wlgrid,
    obs_spectrum,
    error,
    lw=1,
    color="black",
    alpha=0.4,
    ls="none",
    zorder=0,
    label="Observed",
)
plt.scatter(
    wlgrid,
    obs_spectrum,
    marker="d",
    zorder=1,
    s=10,
    edgecolors="grey",
    color="black",
)

# Defined before the loop so it always exists; the corner-plot cell reuses `color`.
color = "C0"
label = "Fitted spectrum"

for solution_idx, solution_val in solution.items():
    spectra = solution_val["Spectra"]
    # Fitted spectrum and its spread, plotted against the observed wavelength grid.
    binned_spectrum = spectra["binned_spectrum"][...]
    binned_error = spectra["binned_std"][...]

    plt.plot(wlgrid, binned_spectrum, label=label, color=color, alpha=0.6)
    if binned_error is not None:
        # 1 sigma band (darker, in front), then 2 sigma band (lighter, behind).
        for n_sigma, band_alpha, band_zorder in ((1, 0.5, -2), (2, 0.2, -3)):
            plt.fill_between(
                wlgrid,
                binned_spectrum - n_sigma * binned_error,
                binned_spectrum + n_sigma * binned_error,
                alpha=band_alpha,
                zorder=band_zorder,
                color=color,
                edgecolor="none",
            )

# 5% padding on either side of the observed wavelength range.
wl_min, wl_max = np.min(wlgrid), np.max(wlgrid)
plt.xlim(wl_min - 0.05 * wl_min, wl_max + 0.05 * wl_max)
plt.xlabel(r"Wavelength ($\mu$m)")
# Model classes missing from modelAxis (e.g. subclasses or plugin models) get
# a generic label instead of raising KeyError after the retrieval has finished.
plt.ylabel(modelAxis.get(model.__class__.__name__, "Spectrum"))

# A wide wavelength range is easier to read on a log axis.
if wl_max - wl_min > 5:
    plt.xscale("log")
    plt.tick_params(axis="x", which="minor")
plt.legend(loc="best", ncol=2, frameon=False, prop={"size": 11})

plt.title("Fitted spectrum", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
def get_derived_parameters(solution):
    """Return the derived-parameter entries of one retrieval solution as a list.

    If ``solution`` has a ``'derived_params'`` mapping, return its values in
    iteration order. Otherwise, fall back to the single entry stored at
    ``solution['fit_params']['mu_derived']``, wrapped in a one-element list.
    """
    if 'derived_params' not in solution:
        # Fallback layout: only mu is available, stored alongside the fit params.
        return [solution['fit_params']['mu_derived']]
    return list(solution['derived_params'].values())

In [None]:
# Corner (posterior) plot of the fitted parameters, plus the derived mu when
# one is available. All solutions are overlaid on one figure.
fittingNames = [param[0] for param in optimizer.fitting_parameters]

# Tick-label sizes and scientific-notation limits. Set once, before corner
# creates its axes (previously these were re-applied on every loop iteration).
plt.rc("xtick", labelsize=10)  # size of individual labels
plt.rc("ytick", labelsize=10)
plt.rc("axes.formatter", limits=(-4, 5))  # scientific notation..

fig = plt.figure(figsize=(12, 12))

for solution_idx, solution_val in solution.items():
    tracedata = solution_val["tracedata"]
    weights = solution_val["weights"]

    mu_derived = get_derived_parameters(solution_val)
    if mu_derived:
        # NOTE(review): assumes the first derived parameter is mu, matching the
        # "mu (derived)" label — confirm when other derived params are enabled.
        _tracedata = np.column_stack((tracedata, mu_derived[0]["trace"]))
        # Build the labels per solution instead of appending to fittingNames,
        # which grew by one entry per solution and desynchronised labels/columns.
        corner_labels = fittingNames + ["mu (derived)"]
    else:
        _tracedata = tracedata
        corner_labels = fittingNames

    # Passing the same figure back in overlays each solution on one plot.
    fig = corner.corner(
        _tracedata,
        weights=weights,
        labels=corner_labels,
        label_kwargs=dict(fontsize=14),
        smooth=1.5,
        scale_hist=True,
        quantiles=[0.16, 0.5, 0.84],
        show_titles=True,
        title_kwargs=dict(fontsize=12),
        # truths=truths,
        truth_color="black",
        ret=True,
        fill_contours=True,
        color=color,  # defined in the fitted-spectrum plot cell
        top_ticks=False,
        bins=100,
        fig=fig,
    )
    # fig.gca().annotate(
    #     "Posterior %s" % (solution_idx),
    #     xy=(0.5, 0.96),
    #     xycoords="figure fraction",
    #     xytext=(0, -5),
    #     textcoords="offset points",
    #     ha="center",
    #     va="top",
    #     fontsize=20,
    # )
plt.show()