From 64ff1844ee998bbe300089ddf59408425d017e2b Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:44:23 +0000
Subject: [PATCH 1/4] Initial plan

From 408acef703004420f0226995e454fe393921cccc Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:56:01 +0000
Subject: [PATCH 2/4] Upgrade Lightning to 2.6.1 with before_instantiate_classes hook

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
---
 chebai/cli.py  | 103 +++++++++++++++++++++++++++++++------------------
 pyproject.toml |   2 +-
 2 files changed, 67 insertions(+), 38 deletions(-)

diff --git a/chebai/cli.py b/chebai/cli.py
index 1aaba53c..ad3b6d83 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -31,6 +31,68 @@ def __init__(self, *args, **kwargs):
         """
         super().__init__(trainer_class=CustomTrainer, *args, **kwargs)
 
+    def before_instantiate_classes(self) -> None:
+        """
+        Hook called before instantiating classes (Lightning 2.6+ compatible).
+        Instantiate the datamodule early to compute num_labels and feature_vector_size.
+        """
+        # Get the current subcommand config (fit, test, validate, predict, etc.)
+        subcommand = self.config.get(self.config["subcommand"])
+
+        if subcommand and "data" in subcommand:
+            # Instantiate datamodule to get num_labels and feature_vector_size
+            data_config = subcommand["data"]
+
+            if "class_path" in data_config:
+                # Import and instantiate the datamodule class
+                module_path, class_name = data_config["class_path"].rsplit(".", 1)
+                import importlib
+                module = importlib.import_module(module_path)
+                data_class = getattr(module, class_name)
+
+                # Instantiate with init_args
+                init_args = data_config.get("init_args", {})
+                data_instance = data_class(**init_args)
+
+                # Call prepare_data and setup to initialize dynamic properties
+                if hasattr(data_instance, "_num_of_labels") and data_instance._num_of_labels is None:
+                    data_instance.prepare_data()
+                    data_instance.setup()
+
+                num_labels = data_instance.num_of_labels
+                feature_vector_size = data_instance.feature_vector_size
+
+                # Update config with the computed values if not already set
+                if "model" in subcommand and "init_args" in subcommand["model"]:
+                    model_init_args = subcommand["model"]["init_args"]
+                    if model_init_args.get("out_dim") is None:
+                        model_init_args["out_dim"] = num_labels
+                    if model_init_args.get("input_dim") is None:
+                        model_init_args["input_dim"] = feature_vector_size
+
+                    # Update metrics num_labels in all metrics configurations
+                    for kind in ("train", "val", "test"):
+                        metrics_key = f"{kind}_metrics"
+                        if metrics_key in model_init_args and model_init_args[metrics_key]:
+                            metrics_config = model_init_args[metrics_key]
+                            if "init_args" in metrics_config and "metrics" in metrics_config["init_args"]:
+                                for metric_name, metric_config in metrics_config["init_args"]["metrics"].items():
+                                    if "init_args" in metric_config and "num_labels" in metric_config["init_args"]:
+                                        if metric_config["init_args"]["num_labels"] is None:
+                                            metric_config["init_args"]["num_labels"] = num_labels
+
+                # Update trainer callbacks num_labels
+                if "trainer" in subcommand and "callbacks" in subcommand["trainer"]:
+                    callbacks = subcommand["trainer"]["callbacks"]
+                    if isinstance(callbacks, list):
+                        for callback in callbacks:
+                            if "init_args" in callback and "num_labels" in callback["init_args"]:
+                                if callback["init_args"]["num_labels"] is None:
+                                    callback["init_args"]["num_labels"] = num_labels
+                    elif "init_args" in callbacks and "num_labels" in callbacks["init_args"]:
and "num_labels" in callbacks["init_args"]: + if callbacks["init_args"]["num_labels"] is None: + callbacks["init_args"]["num_labels"] = num_labels + def add_arguments_to_parser(self, parser: LightningArgumentParser): """ Link input parameters that are used by different classes (e.g. number of labels) @@ -40,25 +102,8 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): parser (LightningArgumentParser): Argument parser instance. """ - def call_data_methods(data: Type[XYBaseDataModule]): - if data._num_of_labels is None: - data.prepare_data() - data.setup() - return data.num_of_labels - - parser.link_arguments( - "data", - "model.init_args.out_dim", - apply_on="instantiate", - compute_fn=call_data_methods, - ) - - parser.link_arguments( - "data.feature_vector_size", - "model.init_args.input_dim", - apply_on="instantiate", - ) - + # Link num_labels to metrics configurations + # These links use the values set in before_instantiate_classes() for kind in ("train", "val", "test"): for average in ( "micro-f1", @@ -70,30 +115,14 @@ def call_data_methods(data: Type[XYBaseDataModule]): "rmse", "r2", ): - # When using lightning > 2.5.1 then need to uncomment all metrics that are not used - # for average in ("mse", "rmse","r2"): # for regression - # for average in ("f1", "roc-auc"): # for binary classification - # for average in ("micro-f1", "macro-f1", "roc-auc"): # for multilabel classification - # for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"): # for multilabel classification using balanced-accuracy parser.link_arguments( - "data.num_of_labels", + "model.init_args.out_dim", f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels", - apply_on="instantiate", ) parser.link_arguments( - "data.num_of_labels", "trainer.callbacks.init_args.num_labels" + "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" ) - # parser.link_arguments( - # "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" - # ) - # parser.link_arguments( - # "data", "model.init_args.criterion.init_args.data_extractor" - # ) - # parser.link_arguments( - # "data.init_args.chebi_version", - # "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version", - # ) parser.link_arguments( "data", "model.init_args.criterion.init_args.data_extractor" diff --git a/pyproject.toml b/pyproject.toml index b3652b00..f407f340 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "transformers", "pysmiles==1.1.2", "rdkit==2024.3.6", - "lightning==2.5.1", + "lightning==2.6.1", ] [project.optional-dependencies] From cbed081b33743c8185a823506dde68441c0188e7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 20:58:29 +0000 Subject: [PATCH 3/4] Refactor before_instantiate_classes for better code quality Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com> --- chebai/cli.py | 131 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 52 deletions(-) diff --git a/chebai/cli.py b/chebai/cli.py index ad3b6d83..c26d4feb 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -39,59 +39,86 @@ def before_instantiate_classes(self) -> None: # Get the current subcommand config (fit, test, validate, predict, etc.) 
subcommand = self.config.get(self.config["subcommand"]) - if subcommand and "data" in subcommand: - # Instantiate datamodule to get num_labels and feature_vector_size - data_config = subcommand["data"] + if not (subcommand and "data" in subcommand): + return - if "class_path" in data_config: - # Import and instantiate the datamodule class - module_path, class_name = data_config["class_path"].rsplit(".", 1) - import importlib - module = importlib.import_module(module_path) - data_class = getattr(module, class_name) - - # Instantiate with init_args - init_args = data_config.get("init_args", {}) - data_instance = data_class(**init_args) - - # Call prepare_data and setup to initialize dynamic properties - if hasattr(data_instance, "_num_of_labels") and data_instance._num_of_labels is None: - data_instance.prepare_data() - data_instance.setup() - - num_labels = data_instance.num_of_labels - feature_vector_size = data_instance.feature_vector_size - - # Update config with the computed values if not already set - if "model" in subcommand and "init_args" in subcommand["model"]: - model_init_args = subcommand["model"]["init_args"] - if model_init_args.get("out_dim") is None: - model_init_args["out_dim"] = num_labels - if model_init_args.get("input_dim") is None: - model_init_args["input_dim"] = feature_vector_size - - # Update metrics num_labels in all metrics configurations - for kind in ("train", "val", "test"): - metrics_key = f"{kind}_metrics" - if metrics_key in model_init_args and model_init_args[metrics_key]: - metrics_config = model_init_args[metrics_key] - if "init_args" in metrics_config and "metrics" in metrics_config["init_args"]: - for metric_name, metric_config in metrics_config["init_args"]["metrics"].items(): - if "init_args" in metric_config and "num_labels" in metric_config["init_args"]: - if metric_config["init_args"]["num_labels"] is None: - metric_config["init_args"]["num_labels"] = num_labels - - # Update trainer callbacks num_labels - if "trainer" in subcommand and "callbacks" in subcommand["trainer"]: - callbacks = subcommand["trainer"]["callbacks"] - if isinstance(callbacks, list): - for callback in callbacks: - if "init_args" in callback and "num_labels" in callback["init_args"]: - if callback["init_args"]["num_labels"] is None: - callback["init_args"]["num_labels"] = num_labels - elif "init_args" in callbacks and "num_labels" in callbacks["init_args"]: - if callbacks["init_args"]["num_labels"] is None: - callbacks["init_args"]["num_labels"] = num_labels + data_config = subcommand["data"] + if "class_path" not in data_config: + return + + # Import and instantiate the datamodule class + module_path, class_name = data_config["class_path"].rsplit(".", 1) + import importlib + module = importlib.import_module(module_path) + data_class = getattr(module, class_name) + + # Instantiate with init_args + init_args = data_config.get("init_args", {}) + data_instance = data_class(**init_args) + + # Call prepare_data and setup to initialize dynamic properties + # We need to check the private attribute to avoid calling the property which has an assert + if hasattr(data_instance, "_num_of_labels") and data_instance._num_of_labels is None: + data_instance.prepare_data() + data_instance.setup() + + num_labels = data_instance.num_of_labels + feature_vector_size = data_instance.feature_vector_size + + # Update model init args + self._update_model_args(subcommand, num_labels, feature_vector_size) + + # Update trainer callbacks + self._update_trainer_callbacks(subcommand, num_labels) + + def 
+        """Helper method to update model initialization arguments."""
+        if "model" not in subcommand or "init_args" not in subcommand["model"]:
+            return
+
+        model_init_args = subcommand["model"]["init_args"]
+
+        # Set out_dim and input_dim if not already set
+        if model_init_args.get("out_dim") is None:
+            model_init_args["out_dim"] = num_labels
+        if model_init_args.get("input_dim") is None:
+            model_init_args["input_dim"] = feature_vector_size
+
+        # Update metrics num_labels in all metrics configurations
+        for kind in ("train", "val", "test"):
+            metrics_key = f"{kind}_metrics"
+            metrics_config = model_init_args.get(metrics_key)
+            if metrics_config:
+                self._update_metrics_num_labels(metrics_config, num_labels)
+
+    def _update_metrics_num_labels(self, metrics_config: dict, num_labels: int) -> None:
+        """Helper method to update num_labels in metrics configuration."""
+        init_args = metrics_config.get("init_args", {})
+        metrics_dict = init_args.get("metrics", {})
+
+        for metric_name, metric_config in metrics_dict.items():
+            metric_init_args = metric_config.get("init_args", {})
+            if "num_labels" in metric_init_args and metric_init_args["num_labels"] is None:
+                metric_init_args["num_labels"] = num_labels
+
+    def _update_trainer_callbacks(self, subcommand: dict, num_labels: int) -> None:
+        """Helper method to update num_labels in trainer callbacks."""
+        if "trainer" not in subcommand or "callbacks" not in subcommand["trainer"]:
+            return
+
+        callbacks = subcommand["trainer"]["callbacks"]
+
+        if isinstance(callbacks, list):
+            for callback in callbacks:
+                self._set_callback_num_labels(callback, num_labels)
+        else:
+            self._set_callback_num_labels(callbacks, num_labels)
+
+    def _set_callback_num_labels(self, callback: dict, num_labels: int) -> None:
+        """Helper method to set num_labels in a single callback configuration."""
+        init_args = callback.get("init_args", {})
+        if "num_labels" in init_args and init_args["num_labels"] is None:
+            init_args["num_labels"] = num_labels
 
     def add_arguments_to_parser(self, parser: LightningArgumentParser):
         """

From 4351bc5c8bf9367bf1b14c48bd69959371d32494 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 21:00:39 +0000
Subject: [PATCH 4/4] Add documentation explaining link_arguments pattern for Lightning 2.6+

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
---
 chebai/cli.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/chebai/cli.py b/chebai/cli.py
index c26d4feb..fbbc5d39 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -127,10 +127,16 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
 
         Args:
             parser (LightningArgumentParser): Argument parser instance.
+
+        Note:
+            In Lightning 2.6+, we use model.init_args.out_dim as the source for linking
+            because it's set during before_instantiate_classes() from the computed num_labels.
+            This avoids issues with linking from data.num_of_labels which is a property
+            that requires the datamodule to be instantiated.
""" - # Link num_labels to metrics configurations - # These links use the values set in before_instantiate_classes() + # Link num_labels (via out_dim) to metrics configurations + # out_dim is set in before_instantiate_classes() from data.num_of_labels for kind in ("train", "val", "test"): for average in ( "micro-f1", @@ -147,10 +153,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels", ) + # Link out_dim to trainer callbacks parser.link_arguments( "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" ) + # Link datamodule to criterion's data extractor parser.link_arguments( "data", "model.init_args.criterion.init_args.data_extractor" )