# Draw a training sample from COSMOS2020 data
The sample is drawn using only simple `pandas` functions.

In [None]:
import os
import pandas as pd
import numpy as np
import h5py
from sedpy import observate
import matplotlib as mpl
import matplotlib.pyplot as plt

## Load the data

In [None]:
# Absolute path to the input catalogue under ./data.
# Alternative input file: 'COSMOS2020_emu_CC.h5'.
infile = os.path.abspath(os.path.join(
    '.',
    'data',
    'COSMOS2020_emu_hscOnly_CC_zinf3_noNaN.h5',
))

In [None]:
# Load the full catalogue. No key is passed, so pandas requires the file to
# hold a single pandas object (it raises otherwise).
indf = pd.read_hdf(path_or_buf=infile)
indf

## Draw an appropriate sample

In [None]:
# Training-sample size:
#   - more than 10000 rows -> 10% of the rows, but at least 5000;
#   - otherwise            -> 30% of the rows.
# Both values are truncated to an integer.
nrows = indf.shape[0]
if nrows > 10000:
    ntrain = int(max(10 * nrows / 100, 5000))
else:
    ntrain = int(30 * nrows / 100)
ntrain

In [None]:
# Draw `ntrain` distinct rows (no replacement) from the catalogue.
# A fixed random_state makes the draw reproducible: re-running the notebook
# yields the same training sample, and therefore the same output file.
# Set random_seed to None to get a fresh random sample on every run.
random_seed = 42
traindf = indf.sample(n=ntrain, replace=False, random_state=random_seed)
traindf

## Write it to a new file

In [None]:
# Open the input file read-only to find the name of the stored dataset, so the
# sample can later be written under the same key. The first key is used.
with h5py.File(infile, 'r') as h5f:
    available_keys = h5f.keys()
    print(available_keys)
    dfkey = [*available_keys][0]
dfkey

In [None]:
# Output file: same directory and extension as the input, with a
# "_sample<size>" suffix appended to the stem.
outdir, bn = os.path.dirname(infile), os.path.basename(infile)
in_n, ext = os.path.splitext(bn)
if ntrain > 1000:
    # Size is floored to whole thousands, e.g. 5000 -> "5k".
    endstr = f"sample{ntrain//1000}k"
else:
    endstr = f"sample{ntrain}"
outfile = os.path.join(outdir, f"{in_n}_{endstr}{ext}")
outfile

In [None]:
# Writing is disabled; uncomment the line below to save the sample.
# It stores `traindf` in `outfile` under the same key as the input file
# (`dfkey`). With mode='w', pandas replaces any existing file at `outfile`.
#traindf.to_hdf(outfile, key=dfkey, mode='w')

## Plot the filters
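This step is optional: it is not needed to build the sample, but gives a quick visual check of the filters present in the catalogue.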

In [None]:
# Filter names come from the magnitude columns (those containing 'mag' but
# not 'err'). The first '_'-separated token of each column name is dropped;
# presumably the columns look like '<prefix>_<filter name>', so confirm this
# against the catalogue.
filt_list = ['_'.join(_str.split('_')[1:]) for _str in traindf.columns if 'mag' in _str and 'err' not in _str]
# Query sedpy's bundled filter list once, instead of once per filter, and use
# a set for O(1) membership tests.
sedpy_builtin = set(observate.list_available_filters())
custom_filter_dir = os.path.join('.', 'data', 'FILTER', 'filt_cosmos')
# Value '' -> the filter ships with sedpy.
# Otherwise the value is the local directory holding the filter file.
filt_dic = {_filt: '' if _filt in sedpy_builtin else custom_filter_dir for _filt in filt_list}
filt_dic

In [None]:
# Build one sedpy Filter per entry. An empty directory string means the filter
# is loaded from sedpy's own data; otherwise it is read from the given directory.
sedpyfilts = [
    observate.Filter(name, directory=fdir) if fdir != '' else observate.Filter(name)
    for name, fdir in filt_dic.items()
]

In [None]:
# One rainbow colour per filter, evenly spaced across the colormap.
filtcols = plt.cm.rainbow(np.linspace(0, 1, len(sedpyfilts)))

ax = plt.gca()
for band, colour in zip(sedpyfilts, filtcols):
    curve = band.transmission
    peak = np.max(curve)
    if peak > 1:
        # Curves peaking above 1 are rescaled to a peak of 1.
        # Curves with peak <= 1 are drawn unscaled.
        curve = curve / peak
    ax.plot(band.wavelength, curve, c=colour, label=band.name)
    ax.fill_between(band.wavelength, curve, color=colour, alpha=0.5)
# ax.set_xscale('log')  # uncomment for a logarithmic wavelength axis
ax.grid()
ax.set_xlabel(r'Wavelength $\mathrm{[\AA]}$')
ax.set_ylabel('Transmission [arbitrary unit]')
ax.legend(loc='upper right', bbox_to_anchor=(1., 1.), ncol=2)
plt.show()