In [1]:
# default imports
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import yaml
import os

from tqdm import tqdm as tqdm

# custom imports
from taurex_utils import get_mols, full_contribution_array



In [2]:
# Input dataset; opened read-only with xr.open_dataset further down.
dataset = "../DATA/proccessed_22.hdf5"
# YAML file holding the loop index at which processing should resume.
checkpoint_path = "../DATA/checkpoint_22_contribution_proccessing.yaml"
# Dataset that is reloaded when a partially completed run is resumed.
# NOTE(review): this is an absolute, machine-specific path, and it is not the
# file the processing loop writes ('../DATA/contribution_22_checkpoint.hdf5').
# Confirm that this backup already contains every sample below the index
# stored in the checkpoint, otherwise those samples stay unfilled on resume.
load_path = "/Users/jools/Documents/Uni/UCL/ARIEL/data_preparation/contribution_22_checkpoint_backup_10830.hdf5"

In [3]:
def m_to_rJ(distance_m):
    """
    Express a length given in metres as a multiple of the Jupiter radius.

    Works element-wise on anything that supports division by a float
    (Python numbers, NumPy arrays, ...).
    """
    # Jupiter radius in metres, as used throughout this notebook.
    jupiter_radius_m = 7.1492e7
    radius_in_rj = distance_m / jupiter_radius_m
    return radius_in_rj

def m_to_rS(distance_m):
    """
    Express a length given in metres as a multiple of the Solar radius.

    Works element-wise on anything that supports division by a float
    (Python numbers, NumPy arrays, ...).
    """
    # Solar radius in metres, as used throughout this notebook.
    solar_radius_m = 6.957e8
    radius_in_rs = distance_m / solar_radius_m
    return radius_in_rs

def kg_to_Mj(mass_kg):
    """
    Express a mass given in kilograms as a multiple of the Jupiter mass.

    Works element-wise on anything that supports division by a float
    (Python numbers, NumPy arrays, ...).
    """
    # Jupiter mass in kilograms, as used throughout this notebook.
    jupiter_mass_kg = 1.898e27
    mass_in_mj = mass_kg / jupiter_mass_kg
    return mass_in_mj




In [4]:

def save_checkpoint_yaml(checkpoint_path, current_index):
    """
    Save the current progress to a YAML checkpoint file.

    The payload ``{"current_index": current_index}`` is first written to a
    temporary file next to ``checkpoint_path`` and then moved into place with
    ``os.replace``. An interruption or error during the write therefore
    leaves any previously saved checkpoint untouched instead of truncating it.

    Parameters
    ----------
    checkpoint_path : str or path-like
        Destination of the checkpoint file.
    current_index : Any
        Value stored under the ``"current_index"`` key.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            yaml.dump({"current_index": current_index}, f)
        # swap the finished file in; the old checkpoint is never half-written
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # do not leave a partial temporary file behind
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_checkpoint_yaml(checkpoint_path):
    """
    Load the last saved progress from a YAML checkpoint file.

    Parameters
    ----------
    checkpoint_path : str or path-like
        Location of the checkpoint file.

    Returns
    -------
    The value stored under ``"current_index"``. ``0`` is returned when the
    file does not exist, is empty, or does not contain that key.

    Raises
    ------
    ValueError
        If the file holds YAML that is not a mapping.
    """
    if not os.path.exists(checkpoint_path):
        return 0

    with open(checkpoint_path, 'r') as f:
        payload = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document (e.g. a checkpoint
    # whose write was interrupted); treat that like "no progress yet".
    if payload is None:
        return 0
    if not isinstance(payload, dict):
        raise ValueError(
            f"Checkpoint file {checkpoint_path} does not contain a mapping: {payload!r}")
    return payload.get("current_index", 0)


In [5]:
# Create the checkpoint file on the first run so a fresh start resumes at 0.
if not os.path.exists(checkpoint_path):
    save_checkpoint_yaml(checkpoint_path, 0)

start_index = load_checkpoint_yaml(checkpoint_path)
try:
    start_index = int(start_index)
except (TypeError, ValueError) as err:
    # Only conversion failures are handled here (a bare `except:` would also
    # swallow KeyboardInterrupt). A non-numeric value is expected once the
    # final cell has stored "Completed" in the checkpoint.
    raise ValueError(f"Could not convert {start_index} to an integer") from err

In [6]:
# Open the input dataset; `ds` is only read from in the rest of the notebook.
ds = xr.open_dataset(dataset)

In [7]:
# Display the dataset summary (dimensions, coordinates, variables).
ds

In [8]:
# Names of the molecules handled by the taurex helper; the order of this list
# is reused for the `log_<name>` abundance variables and the species coordinate.
species = get_mols()
species


['H2O', 'CO2', 'CH4', 'CO', 'NH3']

In [9]:
# Labels of the `sample` dimension as a NumPy array.
sample_index = ds['sample'].values

# Quick sanity check of the first/last labels and the number of samples.
print(sample_index[:5], '...', sample_index[-5:])
sample_index.shape

[0 1 2 3 4] ... [91387 91388 91389 91390 91391]


(91392,)

In [10]:

# Work on an independent deep copy so the original `ds` stays untouched.
# `Dataset.copy(deep=True)` is the public API for what `__deepcopy__` does.
ds_c = ds.copy(deep=True)

In [11]:
# Register the molecule names as a `species` coordinate so the new variables
# can be addressed by label, e.g. .loc[dict(species='H2O')].
ds_c.coords['species'] = species

In [12]:
# NaN-filled placeholder for the per-species spectra: (sample, species, wavelength).
empty_arr = np.full((len(sample_index), len(species), len(ds_c['wavelength'])), np.nan)
ds_c['contributions'] = xr.DataArray(empty_arr, dims=['sample', 'species', 'wavelength'])
ds_c['contributions'].attrs = dict(units='transit depth', 
                                   dataset='taurex forward model',
                                   description='spectra per species if only that species was present in the atmosphere')

# The full-model spectra get their OWN buffer. Slicing `empty_arr[:,0,:]` yields
# a NumPy view of the array above, so the two variables could end up sharing
# memory: writing a full-model spectrum would then silently overwrite the
# first species' contribution for the same sample (and vice versa).
clean_arr = np.full((len(sample_index), len(ds_c['wavelength'])), np.nan)
ds_c['clean_forward_model'] = xr.DataArray(clean_arr, dims=['sample', 'wavelength'])
ds_c['clean_forward_model'].attrs = dict(units='transit depth',
                                        dataset='taurex forward model',
                                        description='forward model spectra with full species compliment present in the atmosphere, but no instrument noise simulated')


In [13]:
# Display the working copy, now including the two placeholder variables.
ds_c

In [14]:
# this is how to amend the contents of the dataset
# NOTE(review): this writes placeholder ones into sample 0 / 'H2O' of the
# working dataset itself; they only disappear if the processing loop later
# recomputes sample 0 or `ds_c` is replaced by a reloaded checkpoint.
ds_c['contributions'].loc[dict(sample=0, species='H2O')] = np.ones_like(ds_c['wavelength'])

In [15]:
# Peek at the planet temperature and the log-abundances (in `species` order)
# of sample 0 as a flat NumPy array.
ds[['planet_temp_k'] + [f'log_{s}' for s in species]].sel(sample=0).to_array().values

array([1108.72506695,   -6.48480938,   -7.01651169,   -3.29472856,
         -3.3519126 ,   -6.95582771])

In [16]:
# Look up the stellar temperature of a single sample.
# Note: `planet` is reused as the loop variable of the processing loop below.
planet = 1
ds['star_temperature_k'].sel(sample=planet).values

array(5071.)

In [17]:
if start_index != 0:
    print(f"Resuming at index {start_index}")
    # load_dataset reads the whole file into memory and closes it again.
    # The dataset is modified in place and re-saved many times below, so it
    # should not stay lazily bound to an open file handle on disk.
    ds_c = xr.load_dataset(load_path)
else:
    # Blocks until the user confirms that starting from index 0 is intended.
    input("Staring from scratch. WARNING This will overwrite existing data - Press enter to continue")

Resuming at index 11060


In [18]:
print(f'Processing {sample_index[start_index]} to {sample_index[-1]}')

# Names of the log-abundance variables, in the same order as `species`
# (loop-invariant, so built once).
abundance_vars = [f'log_{s}' for s in species]
# Number of samples between two on-disk checkpoints.
save_every = 10

for planet in tqdm(range(start_index, sample_index.size)):

    # inputs of the forward model for this sample
    abundancies = ds[abundance_vars].sel(sample=planet).to_array().values
    planet_temp = ds['planet_temp_k'].sel(sample=planet).values
    planet_radius = m_to_rJ(ds['planet_radius_m'].sel(sample=planet).values)

    planet_mass = kg_to_Mj(ds['planet_mass_kg'].sel(sample=planet).values)
    star_radius = m_to_rS(ds['star_radius_m'].sel(sample=planet).values)
    star_temp = ds['star_temperature_k'].sel(sample=planet).values

    # generate the contribution functions for all of the elements present in the planet
    # (the molecule list is taken from `species` so it always matches the
    # order of `abundancies` and the `species` coordinate of `ds_c`)
    contribs = full_contribution_array(list(species),
                                    abundancies,
                                    Rp=planet_radius,
                                    Tp=planet_temp,
                                    Mp=planet_mass,
                                    Rs= star_radius,
                                    Ts=star_temp)
    
    
    # save the contributions to the dataset
    for s in species:
        ds_c['contributions'].loc[dict(sample=planet, species=s)] = contribs[s][1][::-1] # these are also in reverse wavelength order for some reason!

    # save the clean forward model
    ds_c['clean_forward_model'].loc[dict(sample=planet)] = contribs['Full Model'][1][::-1]

    

    # save every `save_every` planets
    if planet % save_every == 0:
        try:
            ds_c.to_netcdf('../DATA/contribution_22_checkpoint.hdf5')
            # only advance the checkpoint once the dataset itself is on disk
            save_checkpoint_yaml(checkpoint_path, planet)
        except Exception as e:
            print(f"CAUTION!! Failed to save checkpoint: {e}")
            # fall back to a randomly suffixed file so the progress is not lost
            ds_c.to_netcdf(f'../DATA/contribution_22_checkpoint_EMERGENCY_SAVE_{np.random.randint(10000,99999)}.hdf5')
            

Processing 11060 to 91391


  0%|          | 0/80332 [00:00<?, ?it/s]

In [None]:
# Write the finished dataset, then flag the run as done in the checkpoint file.
# With "Completed" stored, the integer conversion of the checkpoint value at
# the top of the notebook raises on the next run.
ds_c.to_netcdf('../DATA/contribution_22_full.hdf5')
save_checkpoint_yaml(checkpoint_path, current_index="Completed")

In [None]:
# Overview figure for sample 0: per-species contributions, the noise-free
# full model and the stored spectrum, drawn with the explicit Axes interface.
fig, ax = plt.subplots()

model_wavelength = ds_c['wavelength']
half_bin_width = ds['bin_width']/2
noise_sample0 = ds['noise'].sel(sample=0)

for s in species:
    contribution = ds_c['contributions'].loc[dict(sample=0, species=s)]
    # grey error bars first, the labelled line on top
    ax.errorbar(x=model_wavelength,
                y=contribution,
                xerr=half_bin_width,
                yerr=noise_sample0,
                fmt=' ', color='lightgrey')
    ax.plot(model_wavelength, contribution, label=s)

ax.plot(model_wavelength,
        ds_c['clean_forward_model'].loc[dict(sample=0)],
        label='Full Model', color='black')

spectrum_sample0 = ds['spectrum'].sel(sample=0)
ax.errorbar(x=ds['wavelength'],
            y=spectrum_sample0,
            xerr=half_bin_width,
            yerr=noise_sample0,
            fmt=' ', color='lightgrey')
ax.plot(ds['wavelength'], spectrum_sample0, "--k", label='Data')

ax.set_xlabel('Wavelength (µm)')
ax.set_ylabel('Transit Depth')

ax.legend()