In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from tabulate import tabulate
from os import path
from astropy.io import fits
from sklearn.metrics import mean_squared_error

from spectrum import Spectrum, FitsSpectrum
from common import list_files, list_directories, tqdm
import functions as func

In [15]:
def load_target(target_dir: str) -> list[FitsSpectrum]:
    """Load every FITS spectrum stored below *target_dir*.

    Each sub-directory reported by ``list_directories(target_dir)``
    (presumably one per spectral band) is scanned with ``list_files`` and
    every file found is wrapped in a ``FitsSpectrum``. Spectra are returned
    grouped by sub-directory, in the order the two helpers list them.
    """
    spectra: list[FitsSpectrum] = []
    for band_directory in list_directories(target_dir):
        for fits_path in list_files(band_directory):
            spectra.append(FitsSpectrum(fits_path))
    return spectra

# Reference DIB spectrum and the nominal DIB centres to look for
# (presumably in Å, matching the table headers — confirm).
wavelength_dibs, spectrum_dibs = func.fileloader('data/dibs')
dib_centra_list = [5780, 5797, 6196, 6379, 6613, 7224]

# Locate the DIBs in the reference spectrum; the meaning of wave_range and
# threshold is defined by func.dib_finder.
dib_wavelengths, dib_spectra, dib_locs = func.dib_finder(
    wavelength_dibs,
    spectrum_dibs,
    locs=dib_centra_list,
    wave_range=3,
    threshold=0.03,
)

hd185859 = load_target('data/fits/HD185859')

# NOTE(review): this rebinds dib_spectra, discarding the reference DIB
# windows returned by dib_finder above. Nothing later in this notebook
# section reads the global afterwards — confirm the reset is intentional.
dib_spectra = []

def fit_dibs(dib_wvls, dib_spectra, model=func.skewed_gauss, *, center_tolerance=0.5):
    """Fit *model* to every DIB window and return ``func.fitter_plotter``'s result.

    Parameters
    ----------
    dib_wvls : sequence of array-like
        Wavelength samples of each DIB window.
    dib_spectra : sequence of array-like
        Flux samples, one entry per element of ``dib_wvls``.
    model : callable, optional
        Model passed on to ``func.fitter_plotter``; defaults to
        ``func.skewed_gauss``.
    center_tolerance : float, keyword-only
        Half-width of the window (in the units of ``dib_wvls``) around the
        initial centre within which the fitted centre is bounded.

    Returns
    -------
    Whatever ``func.fitter_plotter`` returns; callers in this notebook
    unpack it as ``params, predictions``.

    Raises
    ------
    ValueError
        If ``dib_wvls`` and ``dib_spectra`` have different lengths.
    """
    # Explicit check instead of assert: asserts are stripped under python -O.
    if len(dib_wvls) != len(dib_spectra):
        raise ValueError(
            f'len(dib_wvls)={len(dib_wvls)} and len(dib_spectra)={len(dib_spectra)}, '
            'but they must have the same length'
        )

    bounds_list = []
    p0_list = []
    centra_list = []

    for dib_wvl, dib_spectrum in zip(dib_wvls, dib_spectra):
        # The position found by func.min_finder seeds the centre parameter.
        dib_center = func.min_finder(dib_wvl, dib_spectrum)
        p0 = [dib_center, 0.1, 0.10, 2]

        # Bounds are sized from p0 so that their shape always matches the
        # initial guess. scipy.optimize.curve_fit / least_squares reject
        # mismatched shapes (assuming fitter_plotter forwards p0/bounds there
        # — TODO confirm). Only the centre is constrained; the other
        # parameters are left unbounded.
        n_free = len(p0) - 1
        lower = [dib_center - center_tolerance] + [-np.inf] * n_free
        upper = [dib_center + center_tolerance] + [np.inf] * n_free

        centra_list.append(dib_center)
        p0_list.append(p0)
        bounds_list.append(np.array((lower, upper)))

    return func.fitter_plotter(dib_wvls, dib_spectra, model, centra_list, p0_list=p0_list, bounds_list=bounds_list)

# Column headers of the per-DIB results table; the rows appended to
# table_data by fit_gaussian_for_target follow this order.
table_headers = [
    'Observation date',
    'Target',
    'Expected center [Å]',
    'Center [Å]',
    'Width',
    'Amplitude',
    'Skew',
    'Slope',
    'Start',
    'RMSE',
    'FWHM [Å]',
    'EW [Å]',
]
# Module-level accumulator of result rows (mutated in place).
table_data = []

def fit_gaussian_for_target(target: str):
    """Process every sub-spectrum of *target* and append DIB rows to ``table_data``.

    Each spectrum loaded from ``data/fits/<target>`` goes through three steps:
    outlier removal, ``normalize(max_degree=10)`` and ``correct_shift()``.
    Its DIB windows are then located with ``func.dib_finder`` and fitted
    with ``fit_dibs``.

    For every nominal centre in ``dib_centra_list`` that lies strictly
    inside the spectrum's wavelength range, ``select_dib(center)`` is
    called. When it returns something other than ``None``, one row
    (observation date, target, centre, the selected parameters, RMSE, FWHM,
    EW) is appended to the module-level ``table_data``. Returns ``None``.
    """
    for spectrum in load_target(f'data/fits/{target}'):
        spectrum.remove_outliers()
        spectrum.normalize(max_degree=10)
        spectrum.correct_shift()

        highest = np.max(spectrum.wavelength)
        lowest = np.min(spectrum.wavelength)

        window_wvls, window_fluxes, _found_centres = func.dib_finder(
            spectrum.wavelength,
            spectrum.flux,
            locs=dib_locs,
            wave_range=1,
            index_search=True,
            ref_wavelengths=wavelength_dibs,
        )

        # NOTE(review): these fit results (and the FWHM / EW derived from
        # them) are never recorded — the table rows come from select_dib
        # below. The call is kept because fit_dibs delegates to
        # func.fitter_plotter, which may plot; confirm whether it is wanted.
        fit_params, _fit_curves = fit_dibs(window_wvls, window_fluxes)
        _fit_fwhms = [func.FWHM(p[1]) for p in fit_params]
        _fit_ews = [p[1] * p[2] * np.sqrt(2 * np.pi) for p in fit_params]

        for center in dib_centra_list:
            # Only DIBs that fall inside this sub-spectrum are considered.
            if lowest < center < highest:
                selection = spectrum.select_dib(center)
                if selection is not None:
                    _, dib_params, rmse, fwhm, ew = selection
                    table_data.append(
                        [spectrum.obs_date, spectrum.target, center, *dib_params, rmse, fwhm, ew]
                    )

# Fit the DIBs of every HD185859 sub-spectrum and collect two flat lists
# across all sub-spectra:
#   - centra_total: the centres returned by func.dib_finder
#   - ew_total: the equivalent widths derived from the fits
# Set PLOT_FITS = True to draw each DIB window together with its fit.
PLOT_FITS = False

# The accumulators must exist before the loop extends them. Re-creating them
# here also makes the cell safe to re-run.
centra_total = []
ew_total = []

for subspectrum in hd185859:
    subspectrum.remove_outliers()
    subspectrum.normalize(max_degree=10)
    subspectrum.correct_shift()

    dib_wavelengths, dib_fluxes, centra_list = func.dib_finder(
        subspectrum.wavelength, subspectrum.flux,
        locs=dib_locs, wave_range=1, index_search=True, ref_wavelengths=wavelength_dibs
    )

    params, predictions = fit_dibs(dib_wavelengths, dib_fluxes)
    fwhm_list = [func.FWHM(param[1]) for param in params]
    # Gaussian area A * sigma * sqrt(2*pi). This presumes param[1] is the
    # width and param[2] the amplitude (the order used in table_headers) and
    # ignores any skew term — confirm against func.skewed_gauss.
    ew_list = [param[1] * param[2] * np.sqrt(2 * np.pi) for param in params]

    # extend() keeps both accumulators flat, so no post-loop flattening is needed.
    centra_total.extend(centra_list)
    ew_total.extend(ew_list)

    if PLOT_FITS and len(dib_wavelengths) > 0:
        ncols = 3
        nrows = int(np.ceil(len(dib_wavelengths) / ncols))
        fig, axes = plt.subplots(nrows, ncols, figsize=(15, nrows * 3))

        for wvl, flux, fit, ax, fwhm, ew in zip(dib_wavelengths, dib_fluxes, predictions, axes.flatten(), fwhm_list, ew_list):
            ax.plot(wvl, flux, '.', ms=2)
            ax.plot(wvl, fit)
            ax.set_title(f'Gaussian fit (FWHM={fwhm:.4g}, EW={ew:.4g})')
            ax.set_xticks(np.arange(np.min(wvl), np.max(wvl), 0.5))

        fig.tight_layout()

KeyboardInterrupt: 

In [14]:
# Write one row per DIB to a space-delimited text file:
# column 0 = DIB centre, column 1 = equivalent width.
np.savetxt(fname='equivalent_widths.txt', X=np.column_stack((centra_total, ew_total)))