Skip combine-preds for fields with existing file (#554)
* Skip combine-preds for fields with existing file

* Bugfix
bfhealy committed Mar 19, 2024
1 parent 6a74755 commit 46e9c0f
Showing 1 changed file with 99 additions and 76 deletions.
tools/combine_preds.py (175 changes: 99 additions & 76 deletions)
@@ -98,88 +98,111 @@ def combine_preds(
 
     if save:
         os.makedirs(path_to_preds / combined_preds_dirname, exist_ok=True)
 
+    done_fields = [
+        str(x).split("/")[-1].split(".")[0]
+        for x in (path_to_preds / combined_preds_dirname).glob("field_*.parquet")
+    ]
+
     preds_to_save = None
     counter = 0
     print(f"Processing {len(fields_dnn_dict)} fields/files...")
 
     for field in fields_dnn_dict.keys():
-        if field in fields_xgb_dict.keys():
-            try:
-                dnn_preds = read_parquet(fields_dnn_dict[field] / f"{field}.parquet")
-                xgb_preds = read_parquet(fields_xgb_dict[field] / f"{field}.parquet")
-            except FileNotFoundError:
-                print(f'Parquet file not found for field {field}')
-                continue
-
-            counter += 1
-
-            dnn_columns = [x for x in dnn_preds.columns]
-            xgb_columns = [x for x in xgb_preds.columns]
-
-            if not merge_dnn_xgb:
-                id_col = '_id' if dateobs is None else 'obj_id'
-
-                dnn_columns.remove(id_col)
-
-                new_xgb_columns = [x for x in xgb_columns if (x not in dnn_columns)]
-                xgb_preds_new = xgb_preds[new_xgb_columns]
-
-                preds_to_save = pd.merge(dnn_preds, xgb_preds_new, on=id_col)
-                meta_dict = None
-            else:
-                field = f"merged_{field}"
-
-                merged_preds = pd.merge(dnn_preds, xgb_preds, on='obj_id')
-                shared_obj_ids = merged_preds['obj_id'].values
-
-                # Rename e.g. vnv_dnn and vnv_xgb both to vnv
-                dnn_rename_mapper = {
-                    c: c.split('_')[0] for c in dnn_columns if '_dnn' in c
-                }
-                xgb_rename_mapper = {
-                    c: c.split('_')[0] for c in xgb_columns if '_xgb' in c
-                }
-
-                dnn_preds = dnn_preds.rename(dnn_rename_mapper, axis=1)
-                xgb_preds = xgb_preds.rename(xgb_rename_mapper, axis=1)
-
-                combined_preds = pd.concat([dnn_preds, xgb_preds])
-                combined_columns = [x for x in combined_preds.columns if '_id' not in x]
-                pred_columns = np.array(
-                    [x for x in combined_columns if x not in ['ra', 'dec', 'period']]
-                )
-
-                if agg_method not in ['mean', 'max']:
-                    raise ValueError(
-                        "Currently supported aggregation methods are 'mean', 'max'."
-                    )
-
-                agg_dct = {c: agg_method for c in combined_columns}
-                grouped_preds = combined_preds.groupby(['obj_id', 'survey_id'])
-                aggregated_preds = grouped_preds.agg(agg_dct)
-
-                preds_to_save = aggregated_preds.loc[shared_obj_ids].reset_index()
-
-                meta_dict = {}
-                for _, row in preds_to_save.iterrows():
-                    gt_threshold = (row[pred_columns] > p_threshold).values
-                    new_entry = {row['obj_id']: (pred_columns[gt_threshold]).tolist()}
-                    meta_dict.update(new_entry)
-
-            if save:
-                write_parquet(
-                    preds_to_save,
-                    path_to_preds / combined_preds_dirname / f"{field}.parquet",
-                )
-                if write_csv:
-                    preds_to_save.to_csv(
-                        path_to_preds / combined_preds_dirname / f"{field}.csv",
-                        index=False,
-                    )
-                if meta_dict is not None:
-                    with open(
-                        path_to_preds / combined_preds_dirname / f"{field}_meta.json",
-                        'w',
-                    ) as f:
-                        json.dump(meta_dict, f)
+        if field not in done_fields:
+            if field in fields_xgb_dict.keys():
+                try:
+                    dnn_preds = read_parquet(
+                        fields_dnn_dict[field] / f"{field}.parquet"
+                    )
+                    xgb_preds = read_parquet(
+                        fields_xgb_dict[field] / f"{field}.parquet"
+                    )
+                except FileNotFoundError:
+                    print(f'Parquet file not found for field {field}')
+                    continue
+
+                counter += 1
+
+                dnn_columns = [x for x in dnn_preds.columns]
+                xgb_columns = [x for x in xgb_preds.columns]
+
+                if not merge_dnn_xgb:
+                    id_col = '_id' if dateobs is None else 'obj_id'
+
+                    dnn_columns.remove(id_col)
+
+                    new_xgb_columns = [x for x in xgb_columns if (x not in dnn_columns)]
+                    xgb_preds_new = xgb_preds[new_xgb_columns]
+
+                    preds_to_save = pd.merge(dnn_preds, xgb_preds_new, on=id_col)
+                    meta_dict = None
+                else:
+                    field = f"merged_{field}"
+
+                    merged_preds = pd.merge(dnn_preds, xgb_preds, on='obj_id')
+                    shared_obj_ids = merged_preds['obj_id'].values
+
+                    # Rename e.g. vnv_dnn and vnv_xgb both to vnv
+                    dnn_rename_mapper = {
+                        c: c.split('_')[0] for c in dnn_columns if '_dnn' in c
+                    }
+                    xgb_rename_mapper = {
+                        c: c.split('_')[0] for c in xgb_columns if '_xgb' in c
+                    }
+
+                    dnn_preds = dnn_preds.rename(dnn_rename_mapper, axis=1)
+                    xgb_preds = xgb_preds.rename(xgb_rename_mapper, axis=1)
+
+                    combined_preds = pd.concat([dnn_preds, xgb_preds])
+                    combined_columns = [
+                        x for x in combined_preds.columns if '_id' not in x
+                    ]
+                    pred_columns = np.array(
+                        [
+                            x
+                            for x in combined_columns
+                            if x not in ['ra', 'dec', 'period']
+                        ]
+                    )
+
+                    if agg_method not in ['mean', 'max']:
+                        raise ValueError(
+                            "Currently supported aggregation methods are 'mean', 'max'."
+                        )
+
+                    agg_dct = {c: agg_method for c in combined_columns}
+                    grouped_preds = combined_preds.groupby(['obj_id', 'survey_id'])
+                    aggregated_preds = grouped_preds.agg(agg_dct)
+
+                    preds_to_save = aggregated_preds.loc[shared_obj_ids].reset_index()
+
+                    meta_dict = {}
+                    for _, row in preds_to_save.iterrows():
+                        gt_threshold = (row[pred_columns] > p_threshold).values
+                        new_entry = {
+                            row['obj_id']: (pred_columns[gt_threshold]).tolist()
+                        }
+                        meta_dict.update(new_entry)
+
+                if save:
+                    write_parquet(
+                        preds_to_save,
+                        path_to_preds / combined_preds_dirname / f"{field}.parquet",
+                    )
+                    if write_csv:
+                        preds_to_save.to_csv(
+                            path_to_preds / combined_preds_dirname / f"{field}.csv",
+                            index=False,
+                        )
+                    if meta_dict is not None:
+                        with open(
+                            path_to_preds
+                            / combined_preds_dirname
+                            / f"{field}_meta.json",
+                            'w',
+                        ) as f:
+                            json.dump(meta_dict, f)
 
     return preds_to_save
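
The change itself is a simple resume pattern: before combining a field, check whether its output parquet file already exists and skip it if so. Below is a minimal standalone sketch of the same pattern; the process_field stub, the preds_dnn_xgb directory name, and the sample field names are hypothetical, not taken from the repository.

from pathlib import Path

import pandas as pd


def process_field(field: str) -> pd.DataFrame:
    # Hypothetical stand-in for the real per-field DNN/XGB combination step.
    return pd.DataFrame({"obj_id": [1, 2], "field": [field, field]})


output_dir = Path("preds_dnn_xgb")  # hypothetical output directory
output_dir.mkdir(exist_ok=True)

# A field counts as done if its output file already exists; Path.stem
# recovers e.g. "field_296" from ".../field_296.parquet".
done_fields = {p.stem for p in output_dir.glob("field_*.parquet")}

for field in ["field_296", "field_297"]:
    if field in done_fields:
        print(f"{field} already combined, skipping")
        continue
    process_field(field).to_parquet(output_dir / f"{field}.parquet")

Re-running after an interruption then redoes only the fields without an existing file, which is what makes a long combine run restartable. Path.stem recovers the same name as the commit's str(x).split("/")[-1].split(".")[0] expression.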

