diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 73d53ea679d..035965cf550 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -174,7 +174,8 @@ jobs: nltk \ fvcore \ scikit-optimize \ - flair + flair \ + optuna kill $KA cd src/main/python python -m unittest discover -s tests/scuro -p 'test_*.py' -v diff --git a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py index 0737f18d62d..0305f613b63 100644 --- a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py +++ b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py @@ -19,6 +19,7 @@ # # ------------------------------------------------------------- from typing import Dict, List, Tuple, Any, Optional +import inspect import os import numpy as np import logging @@ -31,7 +32,6 @@ import random from systemds.scuro.drsearch.representation_dag import ( RepresentationDAGBuilder, - RepresentationDag, RepresentationNode, ) from systemds.scuro.modality.modality import Modality @@ -40,12 +40,173 @@ from systemds.scuro.utils.checkpointing import CheckpointManager -def get_params_for_node(node_id, params): +def _get_params_for_node(node_id, params): return { k.split("-")[-1]: v for k, v in params.items() if k.startswith(node_id + "-") } +def _param_values_to_spec( + full_name: str, param_values: Any +) -> Optional[Dict[str, Any]]: + if isinstance(param_values, list): + return {"name": full_name, "type": "categorical", "domain": list(param_values)} + if isinstance(param_values, tuple) and len(param_values) == 2: + lo, hi = param_values + if isinstance(lo, int) and isinstance(hi, int): + return {"name": full_name, "type": "integer", "domain": (lo, hi)} + return {"name": full_name, "type": "real", "domain": (float(lo), float(hi))} + if isinstance(param_values, (str, int, float, bool)): + return {"name": full_name, "type": "categorical", "domain": [param_values]} + if hasattr(param_values, "__iter__") and not isinstance( + param_values, (str, bytes, dict) + ): + try: + domain = list(param_values) + except TypeError: + return None + if domain: + return {"name": full_name, "type": "categorical", "domain": domain} + return None + + +def _expand_aggregation_param_specs(op_id: str, agg_cls: Any) -> List[Dict[str, Any]]: + if not inspect.isclass(agg_cls): + return [] + + from systemds.scuro.representations.window_aggregation import ( + nested_aggregation_param_names, + ) + + nested_names = nested_aggregation_param_names(agg_cls) + if not nested_names: + return [] + + try: + instance = agg_cls() + except Exception: + return [] + + search_template = getattr(instance, "parameters", None) or {} + specs = [] + for nested_name in nested_names: + nested_values = search_template.get(nested_name) + if nested_values is None: + continue + full_name = f"{op_id}-aggregation_function_{nested_name}" + spec = _param_values_to_spec(full_name, nested_values) + if spec is not None: + specs.append(spec) + return specs + + +def _is_window_operation(op: Any) -> bool: + if not inspect.isclass(op): + return False + try: + from systemds.scuro.representations.window_aggregation import Window + + return issubclass(op, Window) + except ImportError: + return False + + +def _materialize_node_params( + node: RepresentationNode, flat_params: Dict[str, Any] +) -> Dict[str, Any]: + if not flat_params or node.operation is None: + return flat_params + if not _is_window_operation(node.operation): + return flat_params + + from systemds.scuro.representations.window_aggregation import ( + instantiate_nested_aggregation, + ) + + template = node.parameters or {} + agg_cls = template.get("aggregation_function") + + out: Dict[str, Any] = {} + agg_sub: Dict[str, Any] = {} + prefix = "aggregation_function_" + for key, value in flat_params.items(): + if key.startswith(prefix): + agg_sub[key[len(prefix) :]] = value + else: + out[key] = value + + if inspect.isclass(agg_cls): + out["aggregation_function"] = instantiate_nested_aggregation(agg_cls, agg_sub) + + return out + + +def _is_aggregated_representation_operation(op: Any) -> bool: + if not inspect.isclass(op): + return False + try: + from systemds.scuro.representations.aggregated_representation import ( + AggregatedRepresentation, + ) + + return issubclass(op, AggregatedRepresentation) + except ImportError: + return False + + +def _has_pushdown_aggregation(node_parameters: Optional[Dict[str, Any]]) -> bool: + return bool(node_parameters and "_pushdown_aggregation" in node_parameters) + + +def _apply_pushdown_trial_params( + base_params: Dict[str, Any], trial_params: Dict[str, Any] +) -> Dict[str, Any]: + """Merge trial values into node.parameters['_pushdown_aggregation'].""" + pushdown = copy.deepcopy(base_params.get("_pushdown_aggregation", {})) + prefix = "aggregation_function_" + top_level: Dict[str, Any] = {} + + for key, value in trial_params.items(): + if key.startswith(prefix): + pushdown[key] = value + elif key == "aggregation": + pushdown["aggregation_function_aggregation_function"] = value + elif key != "_pushdown_aggregation": + top_level[key] = value + + result = {**base_params, **top_level} + result["_pushdown_aggregation"] = pushdown + + for key in list(result.keys()): + if key.startswith(prefix): + result.pop(key, None) + return result + + +def _apply_trial_params_to_node( + node: RepresentationNode, global_params: Dict[str, Any] +) -> Dict[str, Any]: + base_params = copy.deepcopy(node.parameters) if node.parameters else {} + flat_params = _get_params_for_node(node.node_id, global_params) + if not flat_params: + return base_params + + trial_params = _materialize_node_params(node, flat_params) + + if _has_pushdown_aggregation(base_params): + return _apply_pushdown_trial_params(base_params, trial_params) + + if _is_aggregated_representation_operation(node.operation): + if "aggregation" in trial_params: + trial_params["aggregation_function_aggregation_function"] = trial_params[ + "aggregation" + ] + base_params.pop("aggregation_function_aggregation_function", None) + base_params.pop("aggregation_function_pad_modality", None) + + return {**base_params, **trial_params} + + @dataclass class HyperparamResult: representation_name: str @@ -91,7 +252,7 @@ def get_k_best_results(self, modality, task, performance_metric_name): prev_node_id = None for node in result.dag.nodes: if node.operation is not None and node.parameters: - params = get_params_for_node(node.node_id, result.best_params) + params = _apply_trial_params_to_node(node, result.best_params) prev_node_id = dag_with_best_params.create_operation_node( node.operation, [prev_node_id], params ) @@ -123,6 +284,12 @@ def __init__( random_state: int = 42, exhaustive_threshold: int = 256, local_search_patience: int = 3, + optuna_sampler: str = "tpe", # "tpe" | "random" | "bayes" + use_wandb: bool = False, + wandb_project: Optional[str] = None, + wandb_entity: Optional[str] = None, + wandb_group: Optional[str] = None, + wandb_tags: Optional[List[str]] = None, ): self.tasks = tasks self.unimodal_optimization_results = optimization_results @@ -137,7 +304,7 @@ def __init__( self.k_best_cache = None self.k_best_cache_by_modality = None self.k_best_representations = None - self.extract_k_best_modalities_per_task() + self.extract_k_best_modalities_per_task() # TODO: cache needed for multimodal optimization self.debug = debug self.logger = logging.getLogger(__name__) self.checkpoint_every = checkpoint_every @@ -152,10 +319,13 @@ def __init__( checkpoint_every=self.checkpoint_every, resume=self.resume, ) - if debug: - logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" - ) + self.optuna_sampler = optuna_sampler + self.use_wandb = use_wandb + self.wandb_project = wandb_project or "scuro-hyperparam" + self.wandb_entity = wandb_entity + self.wandb_group = wandb_group + self.wandb_tags = wandb_tags or [] + self._wandb_run = None def get_modalities_by_id(self, modality_ids: List[int]) -> Modality: modalities = [] @@ -187,7 +357,7 @@ def extract_k_best_modalities_per_task(self): for modality in self.modalities: k_best_results, cached_data = ( self.unimodal_optimization_results.get_k_best_results( - modality, task, self.scoring_metric + modality, task, self.scoring_metric, cache_needed=False ) ) representations[task.model.name][modality.modality_id] = k_best_results @@ -216,7 +386,7 @@ def resume_from_checkpoint(self): self.optimization_results.results = results def tune_unimodal_representations(self, max_eval_per_rep: Optional[int] = None): - self.resume_from_checkpoint() + # self.resume_from_checkpoint() for task in self.tasks: reps = self.k_best_representations[task.model.name] skip_remaining = 0 @@ -240,11 +410,11 @@ def tune_unimodal_representations(self, max_eval_per_rep: Optional[int] = None): self.optimization_results.add_result(results) self._checkpoint_manager.increment(task.model.name, len(results)) self._checkpoint_manager.checkpoint_if_due( - self.optimization_results.results, "eval_count_by_task" + self.optimization_results.results, ) except Exception: self._checkpoint_manager.save_checkpoint( - self.optimization_results.results, "eval_count_by_task", {} + self.optimization_results.results, {} ) raise @@ -269,7 +439,7 @@ def visit_node(node_id): visit_node(input_id) visited.add(node_id) if node.operation is not None: - params = self._get_params_for_node(node) + params = self.__get_params_for_node(node) if params: hyperparams[node_id] = params reps.append(node.operation) @@ -296,9 +466,9 @@ def visit_node(node_id): ) all_results = [baseline] else: - n_calls = max_evals if max_evals else 50 param_specs = self._build_param_specs(hyperparams) - default_config = {} + discrete_size = self._estimate_discrete_search_size(param_specs) + n_calls = min(discrete_size, max_evals) if max_evals else discrete_size all_results = self._search_best_configs( dag=dag, task=task, @@ -308,6 +478,7 @@ def visit_node(node_id): param_specs=param_specs, budget=n_calls, initial_config=None, + rep_name=rep_name, ) if not all_results: @@ -317,6 +488,8 @@ def get_score(result): score = result[1] if isinstance(score, PerformanceMeasure): return score.average_scores[self.scoring_metric] + elif isinstance(score, list): + return score[1] return score if self.maximize_metric: @@ -325,21 +498,7 @@ def get_score(result): best_params, best_score = min(all_results, key=get_score) tuning_time = time.time() - start_time - # results = self.unimodal_optimization_results.results[self.modalities[0].modality_id][task.model.name] - - # default_result = sorted( - # results, - # key=lambda r: r.val_score[self.scoring_metric], - # reverse=True, - # )[0] - # pm = PerformanceMeasure(name=self.scoring_metric, metrics=self.scoring_metric, higher_is_better=self.maximize_metric) - # pm.add_scores({self.scoring_metric: default_result.val_score[self.scoring_metric]}) - # default_params = self._get_default_params(dag) - # def_par ={} - # for k, v in default_params.items(): - # for k_v, v_v in v.items(): - # def_par[k+"-"+k_v] = v_v - # all_results.append((def_par, pm)) + best_result = HyperparamResult( representation_name=rep_name, best_params=best_params, @@ -354,11 +513,31 @@ def get_score(result): return best_result - def _get_params_for_node(self, node: RepresentationNode) -> Dict[str, Any]: - if not node.operation().parameters: + def __get_params_for_node(self, node: RepresentationNode) -> Dict[str, Any]: + try: + if node.parameters: + op = node.operation(params=node.parameters) + else: + op = node.operation() + except (TypeError, ValueError): + op = node.operation() + + if not op.parameters: return None - params = copy.deepcopy(node.operation().parameters) + params = copy.deepcopy(op.parameters) + if node.parameters: + if inspect.isclass(node.parameters.get("aggregation_function")): + params["aggregation_function"] = node.parameters["aggregation_function"] + for fixed_key in ("target_dimensions", "self_contained"): + if fixed_key in node.parameters: + params[fixed_key] = node.parameters[fixed_key] + + if _has_pushdown_aggregation(node.parameters): + from systemds.scuro.representations.aggregate import Aggregation + + params["aggregation_function"] = Aggregation + return params def _build_param_specs( @@ -367,23 +546,15 @@ def _build_param_specs( param_specs = [] for op_id, op_params in hyperparams.items(): for param_name, param_values in op_params.items(): + if param_name == "aggregation_function": + expanded = _expand_aggregation_param_specs(op_id, param_values) + if expanded: + param_specs.extend(expanded) + continue full_name = op_id + "-" + param_name - if isinstance(param_values, list): - param_type = "categorical" - domain = list(param_values) - elif isinstance(param_values, tuple) and len(param_values) == 2: - lo, hi = param_values - if isinstance(lo, int) and isinstance(hi, int): - param_type = "integer" - else: - param_type = "real" - domain = (lo, hi) - else: - param_type = "categorical" - domain = [param_values] - param_specs.append( - {"name": full_name, "type": param_type, "domain": domain} - ) + spec = _param_values_to_spec(full_name, param_values) + if spec is not None: + param_specs.append(spec) return param_specs def _config_key(self, params: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]: @@ -536,6 +707,24 @@ def _evaluate_configs( seen_configs[key] for key in unique_keys_in_order if key in seen_configs ] + def _suggest_config_from_specs( + self, trial, param_specs: List[Dict[str, Any]] + ) -> Dict[str, Any]: + + config = {} + for spec in param_specs: + name = spec["name"] + domain = spec["domain"] + if spec["type"] == "categorical": + config[name] = trial.suggest_categorical(name, list(domain)) + elif spec["type"] == "integer": + lo, hi = int(domain[0]), int(domain[1]) + config[name] = trial.suggest_int(name, lo, hi) + else: + lo, hi = float(domain[0]), float(domain[1]) + config[name] = trial.suggest_float(name, lo, hi) + return config + def _search_best_configs( self, dag, @@ -545,126 +734,137 @@ def _search_best_configs( modalities_override, param_specs: List[Dict[str, Any]], budget: int, - initial_config: Dict[str, Any], + initial_config: Optional[Dict[str, Any]], + rep_name: str = "", ) -> List[Tuple[Dict[str, Any], Any]]: + import optuna + from optuna.trial import TrialState + + optuna.logging.set_verbosity( + optuna.logging.INFO if self.debug else optuna.logging.WARNING + ) + budget = max(1, budget) - seen_configs: Dict[Tuple[Tuple[str, Any], ...], Tuple[Dict[str, Any], Any]] = {} all_results: List[Tuple[Dict[str, Any], Any]] = [] - best_score = np.nan - best_config = None - if initial_config is not None and budget > 0: - initial_results = self._evaluate_configs( + seen: Dict[Tuple[Tuple[str, Any], ...], Tuple[Dict[str, Any], Any]] = {} + + if initial_config is not None: + batch = self._evaluate_configs( dag, task, node_order, modality_ids, modalities_override, [initial_config], - seen_configs, + seen, ) - all_results.extend(initial_results) - if initial_results: - p, s = initial_results[0] - best_config = p - best_score = self._score_value(s) - budget -= 1 - - discrete_size = self._estimate_discrete_search_size(param_specs) - if discrete_size is not None and discrete_size <= min( - self.exhaustive_threshold, budget - ): - candidates = self._enumerate_configs(param_specs) - self._rng.shuffle(candidates) - candidates = candidates[:budget] - batch_results = self._evaluate_configs( - dag, - task, - node_order, - modality_ids, - modalities_override, - candidates, - seen_configs, - ) - all_results.extend(batch_results) + all_results.extend(batch) + budget = max(0, budget - len(batch)) + + if budget <= 0: return all_results - initial_budget = min(budget, max(8, len(param_specs) * 4)) - initial_candidates = [ - self._sample_random_config(param_specs) for _ in range(initial_budget) - ] - initial_results = self._evaluate_configs( - dag, - task, - node_order, - modality_ids, - modalities_override, - initial_candidates, - seen_configs, + direction = "maximize" if self.maximize_metric else "minimize" + sampler = ( + optuna.samplers.TPESampler(seed=self.random_state) + if self.optuna_sampler == "tpe" + else optuna.samplers.RandomSampler(seed=self.random_state) ) - all_results.extend(initial_results) - - for params, score in initial_results: - numeric_score = self._score_value(score) - if self._is_better(numeric_score, best_score): - best_score = numeric_score - best_config = params - eval_count = len(seen_configs) - no_improvement_rounds = 0 - step_scale = 0.5 + study = optuna.create_study( + direction=direction, + sampler=sampler, + study_name=f"{task.model.name}-{rep_name}"[:64], + ) - while eval_count < budget: - if best_config is None: - candidate_batch = [self._sample_random_config(param_specs)] - else: - candidate_batch = [] - batch_size = min( - max(2, abs(self.n_jobs) if self.n_jobs != 0 else 1), - budget - eval_count, + wandb_kwargs = {} + if self.use_wandb: + try: + import wandb + from optuna.integration.wandb import WeightsAndBiasesCallback + + wandb_kwargs["wandb_kwargs"] = { + "project": self.wandb_project, + "entity": self.wandb_entity, + "group": self.wandb_group or task.model.name, + "tags": self.wandb_tags + [rep_name, task.model.name], + "name": f"{task.model.name}-{rep_name}-{int(time.time())}", + "config": { + "task": task.model.name, + "representation": rep_name, + "scoring_metric": self.scoring_metric, + "budget": budget, + }, + } + wandb_cb = WeightsAndBiasesCallback( + metric_name=self.scoring_metric, + wandb_kwargs=wandb_kwargs["wandb_kwargs"], ) - for _ in range(batch_size): - candidate_batch.append( - self._generate_neighbor_config( - best_config, param_specs, step_scale - ) - ) + except ImportError: + self.logger.warning( + "wandb/optuna-integration not installed; disabling W&B" + ) + wandb_cb = None + else: + wandb_cb = None + + trial_results: List[Tuple[Dict[str, Any], Any]] = [] - if budget - eval_count > 3: - candidate_batch.append(self._sample_random_config(param_specs)) + def objective(trial: optuna.Trial) -> float: + config = self._suggest_config_from_specs(trial, param_specs) + key = self._config_key(config) + if key in seen: + trial.set_user_attr("duplicate", True) + raise optuna.TrialPruned() - batch_results = self._evaluate_configs( + params, scores = self.evaluate_dag_config( dag, - task, + config, node_order, modality_ids, - modalities_override, - candidate_batch, - seen_configs, + task, + modalities_override=modalities_override, ) - if not batch_results: - step_scale = max(0.05, step_scale * 0.5) - if step_scale <= 0.05: - break - continue + train_score = self._score_value(scores[0]) + val_score = self._score_value(scores[1]) + test_score = self._score_value(scores[2]) + if np.isnan(val_score): + raise optuna.TrialPruned() + + seen[self._config_key(params)] = ( + params, + [train_score, val_score, test_score], + ) + + trial_results.append((params, [train_score, val_score, test_score])) + return val_score + + callbacks = [c for c in [wandb_cb] if c is not None] + n_jobs = 1 if self.n_jobs == 0 else max(1, abs(self.n_jobs)) + try: + study.optimize( + objective, + n_trials=budget, + n_jobs=n_jobs, + callbacks=callbacks, + show_progress_bar=self.debug, + catch=(Exception,), + ) + finally: + if self.use_wandb and wandb.run is not None: + wandb.run.finish() - improved = False - for params, score in batch_results: - numeric_score = self._score_value(score) - if self._is_better(numeric_score, best_score): - best_score = numeric_score - best_config = params - improved = True - all_results.extend(batch_results) - eval_count = len(seen_configs) - - if improved: - no_improvement_rounds = 0 - step_scale = min(0.5, step_scale * 1.1) + all_results.extend(trial_results) + + for trial in study.trials: + if trial.state != TrialState.COMPLETE: + continue + config = trial.params + key = self._config_key(config) + if key in seen: + all_results.append(seen[key]) else: - no_improvement_rounds += 1 - step_scale = max(0.05, step_scale * 0.7) - if no_improvement_rounds >= self.local_search_patience: - break + pass return all_results @@ -684,11 +884,10 @@ def evaluate_dag_config( ): try: dag_copy = copy.deepcopy(dag) - for node_id in node_order: node = dag_copy.get_node_by_id(node_id) - if node.operation is not None and node.parameters: - node.parameters = get_params_for_node(node_id, params) + if node.operation is not None: + node.parameters = _apply_trial_params_to_node(node, params) modalities = ( modalities_override @@ -696,7 +895,7 @@ def evaluate_dag_config( else self.get_modalities_by_id(modality_ids) ) modified_modality = dag_copy.execute(modalities, task) - score = task.run(modified_modality.data)[1] + score = task.run(modified_modality.data) return params, score except Exception as e: @@ -704,7 +903,7 @@ def evaluate_dag_config( traceback.print_exc() self.logger.error(f"Error evaluating DAG with params {params}: {e}") - return params, np.nan + return params, [np.nan, np.nan, np.nan] def tune_multimodal_representations( self, diff --git a/src/main/python/systemds/scuro/representations/aggregate.py b/src/main/python/systemds/scuro/representations/aggregate.py index 8389fadbd0d..cf2c371676f 100644 --- a/src/main/python/systemds/scuro/representations/aggregate.py +++ b/src/main/python/systemds/scuro/representations/aggregate.py @@ -51,7 +51,7 @@ def _sum_agg(data, aggregate_dim=0): def __init__(self, aggregation_function="mean", pad_modality=True, params=None): if params is not None: aggregation_function = params["aggregation_function"] - pad_modality = params["pad_modality"] + pad_modality = params.get("pad_modality", True) if aggregation_function not in list(self._aggregation_function.keys()): raise ValueError("Invalid aggregation function") diff --git a/src/main/python/systemds/scuro/representations/bert.py b/src/main/python/systemds/scuro/representations/bert.py index fcaed8d4935..245466afb43 100644 --- a/src/main/python/systemds/scuro/representations/bert.py +++ b/src/main/python/systemds/scuro/representations/bert.py @@ -51,11 +51,13 @@ def __init__( aggregation=None, params=None, ): - parameters = {"batch_size": [1, 2, 4, 8, 16, 32, 64, 128]} + parameters = { + **(parameters or {}), + "batch_size": [1, 2, 4, 8, 16, 32, 64, 128], + } self.model_name = model_name super().__init__(representation_name, ModalityType.EMBEDDING, parameters) - - self.layer_name = layer + self.layer = layer self.output_file = output_file self.max_seq_length = max_seq_length self.needs_context = True @@ -67,6 +69,9 @@ def __init__( self.data_type = torch.float32 self.aggregation = aggregation self.params = params + if params is not None: + self.layer = params.get("layer", self.layer) + self.batch_size = int(params.get("batch_size", self.batch_size)) @property def gpu_id(self): @@ -83,12 +88,12 @@ def set_parameters( if params is not None: self.max_seq_length = int(params.get("max_seq_length", max_seq_length)) self.batch_size = int(params.get("batch_size", batch_size)) - self.layer_name = params.get("layer_name", layer) + self.layer = params.get("layer", layer) self.output_file = params.get("output_file", output_file) else: self.max_seq_length = max_seq_length self.batch_size = batch_size - self.layer_name = layer + self.layer = layer self.output_file = output_file def get_output_stats(self, input_stats) -> RepresentationStats: @@ -184,14 +189,19 @@ def transform(self, modality, aggregation=None): def get_activation(name): def hook(model, input, output): - self.bert_output = output.detach().cpu().numpy() + if isinstance(output, tuple): + self.bert_output = output[0] + elif hasattr(output, "last_hidden_state"): + self.bert_output = output.last_hidden_state + else: + self.bert_output = output return hook aggregate_dim = (0,) - if self.layer_name != "cls": + if self.layer != "cls": for name, layer in self.model.named_modules(): - if name == self.layer_name: + if name == self.layer: layer.register_forward_hook(get_activation(name)) break if ModalityType.TEXT.has_field(modality.metadata, "text_spans"): @@ -211,7 +221,6 @@ def hook(model, input, output): embeddings = self.create_embeddings( modality.data, self.model, tokenizer, aggregation ) - if self.output_file is not None: save_embeddings(embeddings, self.output_file) @@ -274,11 +283,15 @@ def create_embeddings(self, data, model, tokenizer, aggregation=None): with torch.no_grad(): outputs = model(**inputs) - if self.layer_name == "cls": + if self.layer == "cls": cls_embedding = outputs.last_hidden_state.detach().cpu().numpy() else: cls_embedding = self.bert_output.cpu().numpy() - if aggregation is not None: + if ( + aggregation is not None + and self.layer != "pooler" + and self.layer != "pooler.activation" + ): cls_embedding = aggregation.execute(cls_embedding) cls_embeddings.extend(cls_embedding) @@ -298,7 +311,7 @@ def __init__( self.set_parameters(params, max_seq_length, batch_size, layer, output_file) parameters = { - "layer_name": [ + "layer": [ "cls", "encoder.layer.0", "encoder.layer.1", @@ -342,7 +355,7 @@ def __init__( self.set_parameters(params, max_seq_length, batch_size, layer, output_file) parameters = { - "layer_name": [ + "layer": [ "cls", "encoder.layer.0", "encoder.layer.1", @@ -386,7 +399,7 @@ def __init__( self.set_parameters(params, max_seq_length, batch_size, layer, output_file) parameters = { - "layer_name": [ + "layer": [ "cls", "transformer.layer.0", "transformer.layer.1", @@ -420,7 +433,7 @@ def __init__( params=None, ): self.set_parameters(params, max_seq_length, batch_size, layer, output_file) - parameters = {"layer_name": ["cls", "encoder.albert_layer_groups.0", "pooler"]} + parameters = {"layer": ["cls", "encoder.albert_layer_groups.0", "pooler"]} super().__init__( "ALBERT", "albert-base-v2", @@ -445,7 +458,7 @@ def __init__( params=None, ): parameters = { - "layer_name": [ + "layer": [ "cls", "encoder.layer.0", "encoder.layer.1", diff --git a/src/main/python/systemds/scuro/representations/window_aggregation.py b/src/main/python/systemds/scuro/representations/window_aggregation.py index 541a7b68fb2..5d63e4790dd 100644 --- a/src/main/python/systemds/scuro/representations/window_aggregation.py +++ b/src/main/python/systemds/scuro/representations/window_aggregation.py @@ -19,6 +19,9 @@ # # ------------------------------------------------------------- +from concurrent.futures import ThreadPoolExecutor +import inspect +import os import numpy as np import math @@ -33,6 +36,37 @@ ) +def nested_aggregation_param_names(agg_cls): + if not inspect.isclass(agg_cls): + return set() + if agg_cls is Aggregation or agg_cls.__name__ == "Aggregation": + return {"aggregation_function", "pad_modality"} + try: + return set(agg_cls().parameters.keys()) + except Exception: + return set() + + +def instantiate_nested_aggregation(agg_cls, nested): + if not inspect.isclass(agg_cls): + return agg_cls + if not nested: + return agg_cls() + + if agg_cls is Aggregation or agg_cls.__name__ == "Aggregation": + return Aggregation(params=nested) + + allowed = nested_aggregation_param_names(agg_cls) + filtered = {key: value for key, value in nested.items() if key in allowed} + if not filtered: + return agg_cls() + + init_params = inspect.signature(agg_cls.__init__).parameters + if "params" in init_params: + return agg_cls(params=filtered) + return agg_cls(**filtered) + + class Window(Context): def __init__(self, name, aggregation_function): self.aggregation_function = aggregation_function @@ -117,23 +151,52 @@ def _rest_numel(shape): ) class WindowAggregation(Window): def __init__( - self, aggregation_function="mean", window_size=10, pad=True, params=None + self, + aggregation_function="mean", + window_size=10, + pad=True, + params=None, ): if params is not None: - aggregation_function = params["aggregation_function"] - try: - aggregation_function = aggregation_function() - except: - pass + if isinstance( + params.get("aggregation_function"), (Aggregation, Representation) + ): + aggregation_function = params["aggregation_function"] + else: + nested_agg = { + key[len("aggregation_function_") :]: value + for key, value in params.items() + if key.startswith("aggregation_function_") + } + agg_value = params.get("aggregation_function") + if nested_agg and inspect.isclass(agg_value): + aggregation_function = instantiate_nested_aggregation( + agg_value, nested_agg + ) + elif inspect.isclass(agg_value): + aggregation_function = agg_value() + else: + aggregation_function = params.get( + "aggregation_function", aggregation_function + ) window_size = params["window_size"] - pad = True + pad = params.get("pad", True) super().__init__("WindowAggregation", aggregation_function) - self.parameters["window_size"] = [5, 10, 15, 25, 50, 100] + self.parameters["window_size"] = (4, 128) self.window_size = int(window_size) self.pad = pad def get_output_stats(self, input_stats: RepresentationStats) -> tuple: - in_shape = tuple(int(s) for s in input_stats.output_shape) + if not isinstance(self.aggregation_function, Aggregation): + windowed_input_stats = RepresentationStats( + input_stats.num_instances, (self.window_size,) + ) + in_shape = self.aggregation_function.get_output_stats( + windowed_input_stats + ).output_shape + in_shape = (input_stats.output_shape[0], *in_shape) + else: + in_shape = tuple(int(s) for s in input_stats.output_shape) if len(in_shape) == 1: self.stats = RepresentationStats( input_stats.num_instances, @@ -163,6 +226,14 @@ def estimate_peak_memory_bytes(self, input_stats: RepresentationStats) -> dict: if len(in_shape) == 0: return {"cpu_peak_bytes": 0, "gpu_peak_bytes": 0} + out_stats = self.get_output_stats(input_stats) + out_shape = out_stats.output_shape + output_bytes = ( + input_stats.num_instances + * np.prod(out_shape) + * np.dtype(self.data_type).itemsize + ) + effective_seq_len = in_shape[0] in_numel = effective_seq_len * self._rest_numel(in_shape) output_bytes = self.estimate_output_memory_bytes(input_stats) @@ -188,7 +259,7 @@ def execute(self, modality): for instance in modality.data: new_length = math.ceil(len(instance) / self.window_size) if modality.get_data_layout() == DataLayout.SINGLE_LEVEL: - instance = np.array(instance) + instance = np.asarray(instance) instance.setflags(write=False) windowed_instance = self.window_aggregate_single_level( instance, new_length @@ -199,7 +270,7 @@ def execute(self, modality): windowed_instance = self.window_aggregate_nested_level( instance, new_length ) - original_lengths.append(new_length) + original_lengths.append(windowed_instance.shape[0]) windowed_data.append(windowed_instance) if self.pad and not isinstance(windowed_data, np.ndarray):