In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import corner
import xarray as xr
import io

In [None]:
# Load the wavelength grid; the first CSV column is used as the row index and the
# two remaining headers (which carry a leading space in the file) get short names.
ariel_grid_tb = (
    pd.read_csv(grid_path, delimiter=',', index_col=0)
    .rename(
        columns={
            ' wavelength (micrometers)': "wavelength_µm",
            ' bin_width (micrometers)': "bin_width_µm",
        }
    )
)
# Drop the index label so it does not appear as an extra header row when displayed.
ariel_grid_tb.index.name = None
print(ariel_grid_tb.shape)
ariel_grid_tb.head()

In [None]:
# Read everything as text first: the file's first line is skipped, the next line is
# used as the header, and the rows labelled 0 and 1 are discarded before converting.
spectra_tb = (
    pd.read_csv(f'{root_path}/data.csv', delimiter=',', skiprows=1, dtype=str)
    .drop(index=[0, 1])
)

# The first header carries one extra leading character (presumably a comment marker
# -- TODO confirm against the file); strip it so the header parses as a number.
col1 = spectra_tb.columns[0]
spectra_tb = spectra_tb.rename(columns={col1: col1[1:]})

# Headers and cell values both become floats, then the column order is reversed.
spectra_tb.columns = spectra_tb.columns.astype(float)
spectra_tb = spectra_tb.astype(float).iloc[:, ::-1]

print(spectra_tb.shape)
spectra_tb.head()

In [None]:
# Load the label table, tidy the first header, then tag every column with '_l'
# so label columns stay distinguishable from the auxiliary ones.
labels_tb = (
    pd.read_csv(f'{root_path}/labels.csv', delimiter=',', header=0)
    .rename(columns={'# planet_temp': 'planet_temp_k'})
    .add_suffix('_l')  # appends to every column label; column order is untouched
)
print(labels_tb.shape)
labels_tb.head()

In [None]:
# Load the auxiliary table, tidy two headers, then tag every column with '_a'
# so auxiliary columns stay distinguishable from the label ones.
aux_full_tb = (
    pd.read_csv(f'{root_path}/aux_full.csv', delimiter=',', header=0)
    .rename(
        columns={
            '# star_distance': 'star_distance',
            'star_temperature': 'star_temperature_k',
        }
    )
    .add_suffix('_a')  # appends to every column label; column order is untouched
)
print(aux_full_tb.shape)
aux_full_tb.head()


In [None]:
# Quick look at a single contribution file (CH4): headerless CSV, columns reversed,
# then labelled with the wavelength grid.
file_path = f'{root_path}/contributions_CH4.csv'
a = pd.read_csv(file_path, delimiter=',', header=None).iloc[:, ::-1]
a.columns = ariel_grid_tb['wavelength_µm'].values

a.head()

In [None]:
# Load one contribution table per species into a dict keyed by species name.
species = ['CH4', 'CO', 'CO2', 'H2O', 'NH3']
species_tb = {}
for s in species:
    file_path = f'{root_path}/contributions_{s}.csv'
    # Headerless CSV; reverse the column order and label columns with the wavelength grid.
    contribution = pd.read_csv(file_path, delimiter=',', header=None).iloc[:, ::-1]
    contribution.columns = ariel_grid_tb['wavelength_µm'].values
    species_tb[s] = contribution
    print(contribution.shape)

In [None]:
# Print a one-shot summary of every table loaded so far (one message per line).
print(
    f'ariel_grid_tb shape is {ariel_grid_tb.shape}\n with headers {ariel_grid_tb.columns}\n',
    f'spectra_tb shape is {spectra_tb.shape}\n with headers equal to the column wavelength_µm of ariel_grid_tb\n',
    f'labels_tb shape is {labels_tb.shape}\n with headers {labels_tb.columns}\n',
    f'aux_full_tb shape is {aux_full_tb.shape}\n with headers {aux_full_tb.columns}\n',
    f'there are then {len(species_tb)} species_tables stored in a dictionary with keys {species_tb.keys()} which all look like this:',
    f'species_tb shape is {species_tb["CH4"].shape}\n with headers equal to the column wavelength_µm of ariel_grid_tb\n',
    sep='\n',
)


In [None]:


# --- Species axis -----------------------------------------------------------
# Alphabetical species order gives the stacked axis a stable layout.
species_order = sorted(species_tb)

# Stack the per-species tables along a new trailing axis.
species_data = np.stack([species_tb[sp].values for sp in species_order], axis=-1)

# 'observed' occupies the first slot on the species axis, ahead of the contributions.
species_list = ['observed', *species_order]
species_dataset = ['observation', *(['contribution'] * len(species_order))]

# --- Empty dataset with its coordinates --------------------------------------
ds = xr.Dataset(
    coords={
        'wavelength': ariel_grid_tb['wavelength_µm'].values,
        'sample': np.arange(spectra_tb.shape[0]),
        'species': species_list,
    }
)

# Spectral bin widths, one per wavelength.
ds['bin_width'] = xr.DataArray(ariel_grid_tb['bin_width_µm'].values, dims=['wavelength'])

# --- Spectra -----------------------------------------------------------------
# Give the observed spectra a length-1 species axis so they can be concatenated
# in front of the per-species contributions.
n_samples = spectra_tb.shape[0]
n_wavelengths = len(ariel_grid_tb['wavelength_µm'].values)
spectra_np = spectra_tb.values.reshape(n_samples, n_wavelengths, 1)
combined_data = np.concatenate([spectra_np, species_data], axis=-1)

ds['spectra'] = xr.DataArray(
    combined_data,
    dims=['sample', 'wavelength', 'species'],
    coords={'species': species_list},
)

# --- Per-sample scalars --------------------------------------------------------
# Labels first, then auxiliary parameters; the 'dataset' attribute records which
# table each variable came from.
for table, origin in ((labels_tb, 'label'), (aux_full_tb, 'auxiliary')):
    for column in table.columns:
        ds[column] = xr.DataArray(
            table[column].values,
            dims=['sample'],
            attrs={'dataset': origin},
        )

ds


In [None]:
# Display the wavelength coordinate of the dataset.
ds['wavelength']

In [None]:
# Observed spectrum of the sample labelled 0.
plt.plot(ds['wavelength'], ds['spectra'].sel(sample=0, species='observed'), label='observed')
# The line carries a label, so draw the legend; label the axes so the figure stands alone.
plt.xlabel('Wavelength (µm)')
plt.ylabel('Transit depth')
plt.legend()
plt.show()

In [None]:
# Average the H2O contribution over all samples, leaving one value per wavelength.
mean_h2o_spectrum = ds['spectra'].sel(species='H2O').mean(dim='sample')
plt.plot(ds['wavelength'], mean_h2o_spectrum, label='mean H2O contribution')
# The line carries a label, so draw the legend; label the axes so the figure stands alone.
plt.xlabel('Wavelength (µm)')
plt.ylabel('Transit depth')
plt.legend()
plt.show()

In [None]:
def _minmax_over_wavelength(spectra):
    """Rescale each spectrum to [0, 1] using its own min and max along 'wavelength'.

    A spectrum that is constant along 'wavelength' has zero range and therefore
    comes out as NaN (0 / 0).
    """
    lowest = spectra.min(dim='wavelength')
    span = spectra.max(dim='wavelength') - lowest
    return (spectra - lowest) / span


# Number of samples shown in each map. Positional slicing is used below so the cell
# also works when the dataset holds fewer samples than this.
n_preview = 2000

x = ds['spectra'].sel(species='observed')
# NOTE(review): 'CO' is not part of this sum although the dataset has a CO
# contribution -- confirm this omission is intentional.
y = ds['spectra'].sel(species=['H2O', 'CO2', 'CH4', 'NH3']).sum(dim='species')

x = _minmax_over_wavelength(x)
x.isel(sample=slice(n_preview)).plot(cmap='Spectral')
plt.figure()  # start a fresh figure so the second map does not draw over the first

y = _minmax_over_wavelength(y)
y.isel(sample=slice(n_preview)).plot(cmap='Spectral')


In [None]:
# Overlay the observed spectrum of one sample with every species contribution.
i = 0
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Every entry of the species coordinate except 'observed'.
species_values = ds['species'].values
s = species_values[species_values != 'observed']

sample_spectra = ds['spectra'].sel(sample=i)
ax.plot(ds['wavelength'], sample_spectra.sel(species='observed'), 'k--', label='observed')
# One line per species; the array of names supplies one legend label per line.
ax.plot(ds['wavelength'], sample_spectra.sel(species=s), label=s)

ax.legend()
ax.set_xscale('log')

# Fixed tick positions with plain-number labels on the logarithmic axis.
xticks = [0.5, 1, 2, 5, 8]
ax.set_xticks(xticks)
ax.set_xticklabels(xticks)

ax.set_xlabel('Wavelength (µm)')
ax.set_ylabel('Transit depth')


In [None]:
# Grid of small multiples: one panel per sample, observed spectrum plus contributions.
species_values = ds['species'].values
s = species_values[species_values != 'observed']  # every species except 'observed'

n_rows, n_cols = 5, 5
num_plots = n_rows * n_cols  # derived from the grid so it cannot drift out of sync
xticks = [0.5, 1, 2, 5, 8]

fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 10), sharex=True)
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.plot(ds['wavelength'], ds['spectra'].sel(sample=i, species='observed'), 'k--', label='observed')
    ax.plot(ds['wavelength'], ds['spectra'].sel(sample=i, species=s), label=s)
    ax.set_xscale('log')

axes[0].legend()

# The x axis is shared (sharex=True), so ticks set on one panel apply to all of them.
# Address that panel explicitly instead of relying on the loop variable left over
# from the loop above.
axes[-1].set_xticks(xticks)
axes[-1].set_xticklabels(xticks)

fig.supxlabel('Wavelength (µm)')
fig.supylabel('Transit depth')

# A single layout pass is enough to prevent overlap.
fig.tight_layout()
plt.show()


In [None]:
# H2O contribution of the first 1000 samples, one line per sample.
h2o_first_1000 = ds['spectra'].sel(species='H2O').values[:1000]
# Transpose so each column handed to matplotlib is one sample's spectrum.
plt.plot(ds['wavelength'], h2o_first_1000.T)
plt.xscale('log')

In [None]:
# H2O contribution of the first 1000 samples, transposed so each column is one sample.
# Extracted once instead of three times per iteration.
h2o_spectra = ds['spectra'].sel(species='H2O').values[:1000].T

for n in range(1, 6):
    # n-th root of the spectra and its relative spread (std / mean over all values).
    rooted = np.power(h2o_spectra, 1 / n)
    spread = rooted.std() / rooted.mean()

    plt.plot(ds["wavelength"], rooted)
    plt.xscale('log')
    # Values are computed outside the f-string: reusing the enclosing quote character
    # inside a replacement field is a SyntaxError before Python 3.12.
    plt.title(f'{n}th root, spread {spread:.4f}')
    plt.show()

In [None]:
# Convert every log10-valued variable back to linear units: drop 'log_' from the
# variable name and replace its values by 10**value.
renames = {name: name.replace('log_', '') for name in ds.data_vars if 'log_' in name}

# Work on a copy so `ds` itself is left untouched.
ds_natural = ds.copy().rename_vars(renames)
for new_var in renames.values():
    ds_natural[new_var] = np.power(10, ds_natural[new_var])

ds_natural

In [None]:
# Every entry of the species coordinate except 'observed'.
species_values = ds['species'].values
s = species_values[species_values != 'observed']

# Total contribution per sample, summed over all wavelengths and species; a sample
# is kept when that total is non-zero.
contribution_totals = ds['spectra'].sel(species=s).sum(dim=['wavelength', 'species'])
sample_mask_w_contribs = contribution_totals != 0

ds_c = ds.sel(sample=sample_mask_w_contribs)
ds_c


In [None]:
# Histogram of every per-sample parameter (everything except the spectral variables),
# with the mean and the 1/2/3 standard-deviation bands marked.
parameters = [var for var in ds.data_vars if var not in ("spectra", "wavelength", "bin_width")]
print(parameters)

lp = len(parameters)
fig, axes = plt.subplots(lp, 1, figsize=(8, lp * 2), constrained_layout=True)
if lp == 1:
    axes = [axes]  # Ensure axes is iterable when there's only one parameter

# (multiple of sigma, colour, line style) for each pair of band markers.
sigma_styles = ((1, "green", "--"), (2, "orange", "-."), (3, "red", ":"))

for ax, param in zip(axes, parameters):
    data = ds_c[param].values
    ax.hist(data.flatten(), bins=50, alpha=0.7, histtype="step", color="black", linewidth=1.5)

    dm, dstd = data.mean(), data.std()
    ax.axvline(dm, color="black", linestyle="-", label="Mean")
    for k, colour, dash in sigma_styles:
        # Only the upper marker is labelled so each band appears once in the legend.
        ax.axvline(dm + k * dstd, color=colour, linestyle=dash, label=f"{k} Std. Dev.")
        ax.axvline(dm - k * dstd, color=colour, linestyle=dash)

    ax.set_xlabel(param)
    ax.set_ylabel("Frequency")

    ax.set_xlim(data.min(), data.max())

# The marker lines carry labels, so show them once on the first panel
# (the later histogram cell does the same).
axes[0].legend()
plt.show()


In [None]:
# Per-sample parameters: every data variable except the spectral ones.
parameters = [
    var
    for var in ds.data_vars
    if var not in ("spectra", "wavelength", "bin_width")
]
ds[parameters]

In [None]:
# Per-parameter tolerances, wrapped in a Dataset with one scalar variable per parameter.
# NOTE(review): the tolerance cell further down compares these against a relative
# difference, so 1e-2 would mean "within 1%" (not 10%) -- confirm the intended width.
abundance_tolerance = 1e-2
# Very large value: a parameter given this tolerance is effectively unconstrained.
no_tolerance = 1e16

tolerance = {
    'planet_temp_k_l':    no_tolerance,
    'log_H2O_l':          no_tolerance,
    'log_CO2_l':          no_tolerance,
    'log_CH4_l':          abundance_tolerance,
    'log_CO_l':           no_tolerance,
    'log_NH3_l':          no_tolerance,
    'star_distance_a':    no_tolerance,
    'star_mass_kg_a':     no_tolerance,
    'star_radius_m_a':    no_tolerance,
    'star_temperature_k_a':no_tolerance,
    'planet_mass_kg_a':   no_tolerance,
    'planet_orbital_period_a':no_tolerance,
    'planet_distance_a':  no_tolerance,
    'planet_radius_m_a':  no_tolerance,
    'planet_surface_gravity_a':no_tolerance,
            }
tolerance_ds = xr.Dataset(tolerance)

# Target value for each parameter, keyed by the same names as `tolerance`.
targets = {
    'planet_temp_k_l': 1197.8374538093958,
    'log_H2O_l':-5.994919167786353,
    'log_CO2_l': -6.499649943854283,
    'log_CH4_l': -4,#-6.001007899837979,
    'log_CO_l': -4.496589021501109,
    'log_NH3_l': -6.491720080880544,
    'star_distance_a': 568.2065020332558,
    'star_mass_kg_a': 2.0357949035406833e+30,
    'star_radius_m_a': 855078079.4802036,
    'star_temperature_k_a': 5672.084205905214,
    'planet_mass_kg_a': 1.1245086149027514e+27,
    'planet_orbital_period_a': 24.335250716885003,
    'planet_distance_a': 0.11941006826573485,
    'planet_radius_m_a': 44984908.31604669,
    'planet_surface_gravity_a': 16.67067211510088
}
targets_ds = xr.Dataset(targets)

In [None]:
# Display the target values as a Dataset.
targets_ds

In [None]:

# Relative deviation of every parameter from its target: |value - target| / |target|.
# The difference is taken on the signed values; comparing magnitudes instead would
# treat a value and its negation as identical.
relative_diff = np.abs(ds_c[parameters] - targets_ds) / np.abs(targets_ds)

# A sample passes only if every parameter lies within its own tolerance.
within_tolerance = (relative_diff <= tolerance_ds).to_array().all(dim="variable")

# Use the mask to subset the dataset
reduced_ds = ds_c.sel(sample=within_tolerance)

# Inspect the resulting subset dataset
reduced_ds


In [None]:
# Histogram of every per-sample parameter for all samples in `ds_c`, overlaid (on a
# twin y axis) with the histogram of the tolerance-selected subset `reduced_ds`.
parameters = [var for var in ds_c.data_vars if var not in ("spectra", "wavelength", "bin_width")]
print(parameters)

n_params = len(parameters)
fig, axes = plt.subplots(n_params, 1, figsize=(8, n_params * 2), constrained_layout=True, dpi=500)
if n_params == 1:
    axes = [axes]  # Ensure axes is iterable when there's only one parameter

# (multiple of sigma, colour, line style) for each pair of band markers.
sigma_styles = ((1, "green", "--"), (2, "orange", "-."), (3, "red", ":"))

for ax, param in zip(axes, parameters):
    data = ds_c[param].values
    centre, sigma = data.mean(), data.std()  # computed once per parameter

    ax.hist(data.flatten(), bins=50, alpha=0.7, histtype="step", color="black", linewidth=1.5, label="all data")

    ax.axvline(centre, color="black", linestyle="-", label="Mean")
    for k, colour, dash in sigma_styles:
        # Only the upper marker is labelled so each band appears once in the legend.
        ax.axvline(centre + k * sigma, color=colour, linestyle=dash, label=f"{k} Std. Dev.")
        ax.axvline(centre - k * sigma, color=colour, linestyle=dash)

    # The subset gets its own y axis because its counts are on a different scale.
    data_sub = reduced_ds[param].values
    ax2 = ax.twinx()
    ax2.hist(data_sub.flatten(), bins=30, alpha=1, color="blue", linewidth=1.5, label="mean subset")

    ax.set_xlabel(param)
    ax.set_ylabel("Frequency")

    ax.set_xlim(data.min(), data.max())

    if ax is axes[0]:
        # A legend only collects artists from its own axes, so merge the handles of
        # both axes; otherwise the subset histogram is missing from the legend.
        handles, labels = ax.get_legend_handles_labels()
        handles2, labels2 = ax2.get_legend_handles_labels()
        # Attach it to the twin axes, which is drawn last, so the subset histogram
        # cannot cover it.
        ax2.legend(handles + handles2, labels + labels2)

plt.show()

In [None]:
# For each species, plot the product (left axis, black) and the mean (right axis, red)
# of its contribution over all samples in `reduced_ds`.
species_values = reduced_ds['species'].values
s = species_values[species_values != 'observed']  # every species except 'observed'
num_plots = 5
n_rows, n_cols = 3, 2

fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 6), dpi=500)
axes = axes.flatten()

# Build the title from pre-computed pieces: nesting the enclosing quote character or a
# backslash inside an f-string replacement field is a SyntaxError before Python 3.12.
mean_log_ch4 = reduced_ds["log_CH4_l"].mean().values
tolerance_pct = tolerance_ds["log_CH4_l"].values * 100
fig.suptitle(f'Fixed log_CH4 at {mean_log_ch4:.2f} ' + r'$\pm$' + f' {tolerance_pct:.0f}%')

xticks = [0.5, 1, 2, 5, 8]

for i, ax in enumerate(axes[:num_plots]):
    species_spectra = reduced_ds['spectra'].sel(species=s[i])

    # NOTE(review): a product over many samples can underflow towards zero if the
    # individual values are small -- confirm this is the intended statistic.
    ax.plot(reduced_ds['wavelength'],
            species_spectra.prod(dim='sample'),
            'k-',
            label=f'product {s[i]}')

    # The mean lives on its own y axis because its scale differs from the product's.
    ax2 = ax.twinx()
    ax2.plot(reduced_ds['wavelength'],
             species_spectra.mean(dim='sample'),
             'r--',
             label=f'mean {s[i]}')
    ax2.tick_params(axis='y', colors='red')

    ax.set_xscale('log')

    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks)

    # One combined legend for both y axes.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc='best')

# The last panel of the grid is unused; hide it.
axes[-1].axis('off')

fig.supxlabel('Wavelength (µm)')
fig.supylabel('Transit depth')

# A single layout pass is enough to prevent overlap.
fig.tight_layout()
plt.show()


In [None]:
# Let's save the dataset we have created to a netCDF file in the current working directory.

ds.to_netcdf('ariel_data.nc')


In [None]:
# Read the file back so it can be compared with the in-memory dataset.
# load_dataset reads the whole file into memory and closes it again; open_dataset
# would keep the file handle open, which can make re-running the save cell fail
# while 'ariel_data.nc' is still in use.
ds2 = xr.load_dataset('ariel_data.nc')
ds2

In [None]:
# Round-trip check: compare the dataset read back from disk with the in-memory one.
ds2.equals(ds)

In [None]:
# Now save the dataset under an '.hdf5' file name.
# NOTE(review): this is still `to_netcdf`, so the on-disk format is whatever the
# default netCDF backend writes -- confirm the result is actually an HDF5 container.

ds.to_netcdf('ariel_data.hdf5')

In [None]:
# Read the '.hdf5' file back for comparison. load_dataset reads everything into
# memory and closes the file; open_dataset would keep the handle open, which can make
# re-running the save cell fail while the file is still in use.
ds3 = xr.load_dataset('ariel_data.hdf5')
ds3

In [None]:
# Round-trip check: compare the dataset read back from the '.hdf5' file with the in-memory one.
ds3.equals(ds)

In [None]:
import xarray as xr
import matplotlib.pyplot as plt

In [None]:

# Open a separate checkpoint dataset from the current working directory.
# NOTE(review): this rebinds `ds_c`, which an earlier cell assigned to a filtered
# selection of `ds`; any earlier cell re-run after this one will see this dataset instead.
ds_c = xr.open_dataset('contribution_22_checkpoint_backup_10830.hdf5')

In [None]:
# Display the checkpoint dataset.
ds_c

In [None]:
# Plot one sample of the checkpoint dataset: every species contribution, the full
# forward model and the spectrum, with error bars drawn in light grey.
n = 1090

wavelength = ds_c['wavelength']
half_bin = ds_c['bin_width'] / 2          # horizontal error bar: half the bin width
noise_n = ds_c['noise'].sel(sample=n)     # vertical error bar for this sample

for s in ds_c['species'].values:
    contribution = ds_c['contributions'].sel(sample=n, species=s)
    # Marker-less error bars first, then the labelled line on top.
    plt.errorbar(x=wavelength, y=contribution, xerr=half_bin, yerr=noise_n,
                 fmt=' ', color='lightgrey')
    plt.plot(wavelength, contribution, label=s)

plt.plot(wavelength, ds_c['clean_forward_model'].sel(sample=n),
         label='Full Model', color='black')

spectrum_n = ds_c['spectrum'].sel(sample=n)
plt.errorbar(x=wavelength, y=spectrum_n, xerr=half_bin, yerr=noise_n,
             fmt=' ', color='lightgrey')
plt.plot(wavelength, spectrum_n, "--k", label='Data')

plt.xlabel('Wavelength (µm)')
plt.ylabel('Transit Depth')

plt.legend()