Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes/49.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add `--max_tel_per_type` argument (e.g. `--max_tel_per_type 10`) to restrict the number of telescope parameters per telescope type.
Fix bug in indexing arrays with non-contiguous telescope identifiers.
8 changes: 8 additions & 0 deletions src/eventdisplay_ml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def configure_training(analysis_type):
help="Observatory/site name for geomagnetic field (default: VERITAS).",
default="VERITAS",
)
parser.add_argument(
"--max_tel_per_type",
type=int,
help="Maximum number of telescopes to keep per mirror area type (for feature reduction).",
default=None,
)

model_configs = vars(parser.parse_args())

Expand All @@ -96,6 +102,8 @@ def configure_training(analysis_type):
_logger.info(f"Random state: {model_configs['random_state']}")
_logger.info(f"Max events: {model_configs['max_events']}")
_logger.info(f"Max CPU cores: {model_configs['max_cores']}")
if model_configs.get("max_tel_per_type") is not None:
_logger.info(f"Max telescopes per mirror area type: {model_configs['max_tel_per_type']}")

model_configs["models"] = hyper_parameters(
analysis_type, model_configs.get("hyperparameter_config")
Expand Down
135 changes: 131 additions & 4 deletions src/eventdisplay_ml/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def _normalize_telescope_variable_to_tel_id_space(data, index_list, max_tel_id,
full_matrix = np.full((n_evt, max_tel_id + 1), DEFAULT_FILL_VALUE, dtype=np.float32)
row_indices, col_indices = np.where(~np.isnan(index_list))
tel_ids = index_list[row_indices, col_indices].astype(int)
valid_mask = tel_ids <= max_tel_id
# Filter for valid telescope IDs and valid column indices in data array
valid_mask = (tel_ids <= max_tel_id) & (col_indices < data.shape[1])
full_matrix[row_indices[valid_mask], tel_ids[valid_mask]] = data[
row_indices[valid_mask], col_indices[valid_mask]
]
Expand All @@ -349,8 +350,82 @@ def _clip_size_array(size_array):
return clipped


def _compute_telescope_indices_to_keep(tel_config, max_tel_id, max_tel_per_type):
    """
    Determine which telescope position indices to keep based on mirror area grouping.

    After sorting by (mirror_area desc, size desc), telescopes are grouped by mirror area.
    This function returns position indices to keep, limiting each mirror area group to
    max_tel_per_type telescopes.

    Parameters
    ----------
    tel_config : dict
        Telescope configuration with 'tel_ids' and 'mirror_area'/'mirror_areas'.
    max_tel_id : int
        Maximum telescope ID.
    max_tel_per_type : int or None
        Maximum telescopes to keep per mirror area type. If None, keep all.

    Returns
    -------
    list[int]
        List of telescope position indices (0, 1, 2, ...) to keep after
        truncation, in increasing order.
    """
    if max_tel_per_type is None:
        # No reduction requested: keep every telescope position.
        return list(range(max_tel_id + 1))

    # Accept either key spelling ('mirror_area' or 'mirror_areas');
    # configs from different readers are not consistent here.
    mirror_areas = tel_config.get("mirror_area")
    if mirror_areas is None:
        mirror_areas = tel_config.get("mirror_areas")
    if mirror_areas is None:
        _logger.warning("No mirror_area found in tel_config; keeping all telescopes")
        return list(range(max_tel_id + 1))

    # Group telescope IDs by mirror area, ignoring IDs above max_tel_id.
    # Round the area to avoid float-equality issues splitting one type into several.
    area_to_tids = {}
    for tid, area in zip(tel_config["tel_ids"], mirror_areas):
        if tid <= max_tel_id:
            area_key = round(float(area), 2)
            area_to_tids.setdefault(area_key, []).append(int(tid))

    # Walk the area types largest-first — this matches the (mirror_area desc,
    # size desc) sort applied to the flattened data, so each area type occupies
    # a contiguous run of positions.  Keep at most max_tel_per_type positions
    # from the start of each run.
    indices_to_keep = []
    current_position = 0
    for area_key in sorted(area_to_tids, reverse=True):
        n_tels_this_type = len(area_to_tids[area_key])
        n_to_keep = min(n_tels_this_type, max_tel_per_type)
        indices_to_keep.extend(range(current_position, current_position + n_to_keep))
        current_position += n_tels_this_type

    _logger.info(
        f"Feature reduction: {len(area_to_tids)} mirror area types, "
        f"keeping {len(indices_to_keep)} out of {max_tel_id + 1} telescope positions "
        f"(max {max_tel_per_type} per type)"
    )

    # Positions were generated in strictly increasing order, so no final sort
    # is needed (the original redundant sorted() call has been dropped).
    return indices_to_keep


def flatten_telescope_data_vectorized(
df, n_tel, features, analysis_type, training=True, tel_config=None, observatory="veritas"
df,
n_tel,
features,
analysis_type,
training=True,
tel_config=None,
observatory="veritas",
max_tel_per_type=None,
):
"""
Vectorized flattening of telescope array columns.
Expand All @@ -372,6 +447,10 @@ def flatten_telescope_data_vectorized(
If True, indicates training mode. Default is True.
tel_config : dict, optional
Telescope configuration dictionary with 'max_tel_id' and 'tel_types'.
observatory : str, optional
Observatory name for indexing mode detection. Default is "veritas".
max_tel_per_type : int, optional
Maximum number of telescopes to keep per mirror area type. If None, keep all.

Returns
-------
Expand Down Expand Up @@ -408,6 +487,14 @@ def flatten_telescope_data_vectorized(
# Sorting by mirror area (desc; proxy for telescope type), then size (desc)
sort_indices = _compute_size_area_sort_indices(size_data, active_mask, tel_config, max_tel_id)

# Determine which telescope positions to keep (for feature reduction)
if max_tel_per_type is not None and tel_config is not None:
tel_indices_to_keep = _compute_telescope_indices_to_keep(
tel_config, max_tel_id, max_tel_per_type
)
else:
tel_indices_to_keep = list(range(max_tel_id + 1))

for var in features:
if var == "mirror_area" and tel_config:
flat_features.update(
Expand Down Expand Up @@ -456,9 +543,25 @@ def flatten_telescope_data_vectorized(
# All variables are now in telescope-ID space; apply sorting and flatten uniformly
data_normalized = data_normalized[np.arange(n_evt)[:, np.newaxis], sort_indices]

for tel_idx in range(max_tel_id + 1):
for tel_idx in tel_indices_to_keep:
flat_features[f"{var}_{tel_idx}"] = data_normalized[:, tel_idx]

# Also filter synthetic features to only keep selected indices
# This applies to features like mirror_area, tel_active, etc. that were added above
if max_tel_per_type is not None:
filtered_features = {}
for key, value in flat_features.items():
# Extract the telescope index from feature names like "size_5", "mirror_area_10", etc.
parts = key.rsplit("_", 1)
if len(parts) == 2 and parts[1].isdigit():
tel_idx = int(parts[1])
if tel_idx in tel_indices_to_keep:
filtered_features[key] = value
else:
# Keep features without telescope index suffix
filtered_features[key] = value
flat_features = filtered_features

index = _get_index(df, n_evt)
df_flat = flatten_telescope_variables(n_tel, flat_features, index, tel_config, analysis_type)
return pd.concat(
Expand Down Expand Up @@ -596,7 +699,13 @@ def _get_index(df_like, n):


def flatten_feature_data(
group_df, ntel, analysis_type, training, tel_config=None, observatory="veritas"
group_df,
ntel,
analysis_type,
training,
tel_config=None,
observatory="veritas",
max_tel_per_type=None,
):
"""
Get flattened features for events.
Expand All @@ -615,6 +724,10 @@ def flatten_feature_data(
Whether in training mode.
tel_config : dict, optional
Telescope configuration dictionary.
observatory : str, optional
Observatory name for indexing mode detection.
max_tel_per_type : int, optional
Maximum number of telescopes to keep per mirror area type. If None, keep all.
"""
df_flat = flatten_telescope_data_vectorized(
group_df,
Expand All @@ -624,6 +737,7 @@ def flatten_feature_data(
training=training,
tel_config=tel_config,
observatory=observatory,
max_tel_per_type=max_tel_per_type,
)
max_tel_id = tel_config["max_tel_id"] if tel_config else ntel - 1
excluded_columns = set(features_module.target_features(analysis_type)) | set(
Expand Down Expand Up @@ -691,6 +805,18 @@ def load_training_data(model_configs, file_list, analysis_type):
if tel_config is None:
tel_config = read_telescope_config(root_file)
model_configs["tel_config"] = tel_config
else:
# Check if current file has a larger max_tel_id and update if needed
current_tel_config = read_telescope_config(root_file)
if current_tel_config["max_tel_id"] > tel_config["max_tel_id"]:
_logger.info(
f"Updating telescope configuration: max_tel_id from "
f"{tel_config['max_tel_id']} to {current_tel_config['max_tel_id']} "
f"(file: {f})"
)
# Replace the full telescope configuration to keep all fields consistent
tel_config = current_tel_config
model_configs["tel_config"] = tel_config

_logger.info(f"Processing file: {f} (file {file_idx}/{total_files})")
tree = root_file["data"]
Expand Down Expand Up @@ -728,6 +854,7 @@ def load_training_data(model_configs, file_list, analysis_type):
training=True,
tel_config=tel_config,
observatory=model_configs.get("observatory", "veritas"),
max_tel_per_type=model_configs.get("max_tel_per_type", None),
)
if analysis_type == "stereo_analysis":
new_cols = {
Expand Down