# SNANA_FITS_to_pd.py
import numpy as np
import pandas as pd
from pathlib import Path
from astropy.table import Table
"""
SNANA simulation/data format to pandas
"""
def read_fits(fname, drop_separators=False):
    """Load SNANA formatted data and cast it to PANDAS dataframes.

    The PHOT file stores every light curve back to back; rows with
    MJD == -777.0 separate consecutive light curves. The HEAD file, located by
    replacing "PHOT" with "HEAD" in the file name, is read as one row per
    light curve in the same order. Every photometry row is tagged with the
    SNID of the header row it belongs to.

    Args:
        fname (str or os.PathLike): path + name to PHOT.FITS file
        drop_separators (Boolean): if the MJD == -777.0 separator rows are
            to be dropped from the returned photometry

    Returns:
        (pandas.DataFrame) dataframe from HEAD.FITS file (SNID cast to int32)
        (pandas.DataFrame) dataframe from PHOT.FITS file (with SNID column)

    Raises:
        ValueError: if the file name does not contain "PHOT", or if the
            number of light curves found in the photometry differs from the
            number of header rows.
    """
    phot_path = Path(fname)
    if "PHOT" not in phot_path.name:
        raise ValueError(
            f"cannot derive HEAD file name from {fname!r}: no 'PHOT' in file name"
        )

    # load photometry
    df_phot = Table.read(fname, format="fits").to_pandas()
    # failsafe: a separator at either end would count as a spurious extra
    # light curve and misalign SNIDs with header rows (len() guards empty data)
    if len(df_phot) and df_phot.MJD.values[-1] == -777.0:
        df_phot = df_phot.drop(df_phot.index[-1])
    if len(df_phot) and df_phot.MJD.values[0] == -777.0:
        df_phot = df_phot.drop(df_phot.index[0])

    # load header: only the file name is rewritten, so directories whose
    # names happen to contain "PHOT" are left untouched
    head_path = phot_path.with_name(phot_path.name.replace("PHOT", "HEAD"))
    df_header = Table.read(str(head_path), format="fits").to_pandas()
    df_header["SNID"] = df_header["SNID"].astype(np.int32)

    # add SNID to phot for skimming. New light curves are identified by
    # MJD == -777.0: the running count of separators up to and including a
    # row is the positional index of that row's header entry (a separator
    # row gets the ID of the light curve that follows it).
    lc_index = np.cumsum(df_phot["MJD"].values == -777.0)
    n_lc = int(lc_index[-1]) + 1 if len(lc_index) else 0
    if n_lc != len(df_header):
        raise ValueError(
            f"{n_lc} light curves found in {fname} but "
            f"{len(df_header)} header rows in {head_path}"
        )
    df_phot["SNID"] = df_header["SNID"].values[lc_index]

    if drop_separators:
        df_phot = df_phot[df_phot.MJD != -777.000]
    return df_header, df_phot
def save_fits(df, fname):
    """Save data frame in a FITS table.

    Only the DataFrame's columns are written; its index is discarded.
    Missing parent directories are created and an existing file at
    ``fname`` is overwritten.

    Arguments:
        df {pandas.DataFrame} -- data to save
        fname {str or os.PathLike} -- output path (conventionally ending in .FITS)
    """
    # reset_index(drop=True) discards the index without inserting it as a
    # column; inserting it would fail if the index name matches a column.
    outtable = Table.from_pandas(df.reset_index(drop=True))
    Path(fname).parent.mkdir(parents=True, exist_ok=True)
    outtable.write(fname, format="fits", overwrite=True)
#
# Use examples (only run when executed as a script, not on import)
#
if __name__ == "__main__":
    # SNANA.FITS to pd
    df_header, df_phot = read_fits(
        "./raw/DES_Ia-0001_PHOT.FITS", drop_separators=True
    )
    # save one csv: header columns left-joined with photometry on SNID
    df_out = pd.merge(df_header, df_phot, on="SNID", how="left")
    df_out.to_csv("df_merged.csv")
    # pd to FITS
    # this saves the whole data frame as a 1-D FITS table
    # save_fits(df_header, "DES_Ia-0001_HEAD.FITS")