diff --git a/README.md b/README.md index 8d59280..8d75fcd 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ An AI ensemble model for predicting chemical classes in the ChEBI ontology. It integrates deep learning models, rule-based models and generative AI-based models. -A web application for the ensemble is available at https://chebifier.hastingslab.org/. +A web application for Chebifier is available at https://chebifier.hastingslab.org/. ## Installation @@ -38,23 +38,27 @@ The package provides a command-line interface (CLI) for making predictions using The ensemble configuration is given by a configuration file (by default, this is `chebifier/ensemble.yml`). If you want to change which models are included in the ensemble or how they are weighted, you can create your own configuration file. -Model weights for deep learning models are automatically downloaded from [Hugging Face](https://huggingface.co/chebai). -To use specific model weights from Hugging face, add the `load_model` key in your configuration file. For example: +Trained deep learning models are automatically downloaded from [Hugging Face](https://huggingface.co/chebai). +To access a model from Hugging Face, add the `load_model` key to your configuration file. For example: ```yaml my_electra: type: electra - load_model: "electra_chebi50_v241" + load_model: "electra_chebi50-3star_v244" ``` ### Available model weights: +* `resgated-aug_chebi50-3star_v244` +* `gat-aug_chebi50_v244` +* `electra_chebi50-3star_v244` +* `gat_chebi50_v244` * `electra_chebi50_v241` * `resgated_chebi50_v241` * `c3p_with_weights` -However, you can also supply your own model checkpoints (see `configs/example_config.yml` for an example). +You can also supply your own model checkpoints (see `configs/example_config.yml` for an example). ```bash # Make predictions @@ -72,12 +76,12 @@ python -m chebifier predict --help ### Python API -You can also use the package programmatically: +You can use the package programmatically as well: ```python from chebifier import BaseEnsemble -# Instantiate ensemble model. If desired, can pass +# Instantiate ensemble model. Optionally, you can pass # a path to a configuration, like 'configs/example_config.yml' ensemble = BaseEnsemble() @@ -100,11 +104,12 @@ Currently, the following models are supported: | Model | Description | #Classes | Publication | Repository | |-------|-------------|----------|-----------------------------------------------------------------------|----------------------------------------------------------------------------------------| -| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings. | 1522 | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) | -| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1522 | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) | +| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings.
| 1531* | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) | +| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1531* | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) | +| `gat` | A Graph Attention Network trained on ChEBI molecules. | 1531* | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) | | `chemlog_peptides` | A rule-based model specialised on peptide classes. | 18 | [Flügel, Simon, et al., 2025: ChemLog: Making MSOL Viable for Ontological Classification and Learning, arXiv](https://arxiv.org/abs/2507.13987) | [chemlog-peptides](https://github.com/sfluegel05/chemlog-peptides) | | `chemlog_element`, `chemlog_organox` | Extensions of ChemLog for classes that are defined either by the presence of a specific element or by the presence of an organic bond. | 118 + 37 | | [chemlog-extra](https://github.com/ChEB-AI/chemlog-extra) | -| `c3p` | A collection _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, arXiv](https://arxiv.org/abs/2505.18470) | [c3p](https://github.com/chemkg/c3p) | +| `c3p` | A collection of _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, Journal of Cheminformatics](https://link.springer.com/article/10.1186/s13321-025-01092-3) | [c3p](https://github.com/chemkg/c3p) | In addition, Chebifier also includes a ChEBI lookup that automatically retrieves the ChEBI superclasses for a class matched by a SMILES string. This is not activated by default, but can be included by adding @@ -116,6 +121,8 @@ chebi_lookup: to your configuration file. ### The ensemble +For an extended description of the ensemble, see [Flügel, Simon, et al., 2025: Chebifier 2: An Ensemble for Chemistry](https://ceur-ws.org/Vol-4064/SymGenAI4Sci-paper4.pdf). + ensemble_architecture Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows: @@ -146,10 +153,10 @@ Therefore, if in doubt, we are more confident in the negative prediction. Confidence can be disabled by the `use_confidence` parameter of the predict method (default: True). -The model_weight can be set for each model in the configuration file (default: 1). This is used to favor a certain +The `model_weight` can be set for each model in the configuration file (default: 1). This is used to favor a certain model independently of a given class. -Trust is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models -on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as 1 + the F1 score. +`Trust` is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models +on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as the class-wise F1 score raised to the power of 6.25 ($\mathrm{F1}^{6.25}$). If the `ensemble_type` is set to `mv` (the default), the trust is set to 1 for all models.
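To make the weighted vote concrete, the following is a minimal sketch of the score a single class receives (the helper name `net_score` and its signature are illustrative, not part of the chebifier API): each model that returns a prediction for the class contributes its confidence times its weight to either the positive or the negative side, and the class is predicted if the resulting net score is positive.

```python
# Minimal sketch of the per-class weighted vote described above (illustrative only,
# not the chebifier API). Each entry is (predicted probability, weight), where the
# weight combines model_weight and trust.
def net_score(model_outputs, threshold=0.5, use_confidence=True):
    positive_sum, negative_sum = 0.0, 0.0
    for prob, weight in model_outputs:
        # confidence is 0 at the decision threshold and 1 for a fully certain model
        confidence = 2 * abs(prob - threshold) if use_confidence else 1.0
        if prob > threshold:
            positive_sum += confidence * weight
        elif prob < threshold:
            negative_sum += confidence * weight
    return positive_sum - negative_sum

# Two fairly confident positive votes outweigh one weakly negative vote:
print(net_score([(0.9, 1.0), (0.8, 1.0), (0.45, 1.0)]))  # 0.8 + 0.6 - 0.1 = 1.3 > 0
```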
### Inconsistency resolution @@ -157,9 +164,7 @@ After a decision has been made for each class independently, the consistency of and disjointness axioms is checked. This is done in 3 steps: - (1) First, the hierarchy is corrected. For each pair of classes $A$ and $B$ where $A$ is a subclass of $B$ (following -the is-a relation in ChEBI), we set the ensemble prediction of $B$ to 1 if the prediction of $A$ is 1. Intuitively -speaking, if we have determined that a molecule belongs to a specific class (e.g., aromatic primary alcohol), it also -belongs to the direct and indirect superclasses (e.g., primary alcohol, aromatic alcohol, alcohol). +the is-a relation in ChEBI), we set the ensemble prediction of $A$ to $0$ if the _absolute value_ of $B$'s score is larger than that of $A$. For example, if $A$ has a net score of $3$ and $B$ has a net score of $-4$, the ensemble will set $A$ to $0$ (i.e., predict neither $A$ nor $B$). - (2) Next, we check for disjointness. This is not specified directly in ChEBI, but in an additional ChEBI module ([chebi-disjoints.owl](https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/)). We have extracted these disjointness axioms into a CSV file and added some more disjointness axioms ourselves (see `data>disjoint_chebi.csv` and `data>disjoint_additional.csv`). If two classes $A$ and $B$ are disjoint and we predict diff --git a/chebifier/cli.py b/chebifier/cli.py index a3db5d6..2f0468c 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -37,13 +37,6 @@ def cli(): default="wmv-f1", help="Type of ensemble to use (default: Weighted Majority Voting)", ) -@click.option( - "--chebi-version", - "-v", - type=int, - default=241, - help="ChEBI version to use for checking consistency (default: 241)", -) @click.option( "--use-confidence", "-c", @@ -58,23 +51,31 @@ def cli(): default=True, help="Resolve inconsistencies in predictions automatically (default: True)", ) +@click.option( + "--verbose", + "-v", + is_flag=True, + default=False, + help="Enable verbose output", +) def predict( ensemble_config, smiles, smiles_file, output, ensemble_type, - chebi_version, use_confidence, resolve_inconsistencies=True, + verbose=False, ): """Predict ChEBI classes for SMILES strings using an ensemble model.""" # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type]( ensemble_config, - chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies, + verbose_output=verbose, + use_confidence=use_confidence, ) # Collect SMILES strings from arguments and/or file @@ -88,9 +89,7 @@ def predict( return # Make predictions - predictions = ensemble.predict_smiles_list( - smiles_list, use_confidence=use_confidence - ) + predictions = ensemble.predict_smiles_list(smiles_list) if output: # save as json diff --git a/chebifier/ensemble.yml b/chebifier/ensemble.yml index 1744bad..3506005 100644 --- a/chebifier/ensemble.yml +++ b/chebifier/ensemble.yml @@ -1,15 +1,13 @@ -electra: - load_model: electra_chebi50_v241 -resgated: - load_model: resgated_chebi50_v241 -chemlog_peptides: - type: chemlog_peptides - model_weight: 100 -chemlog_element: - type: chemlog_element - model_weight: 100 -chemlog_organox: - type: chemlog_organox +electra_chebi50-3star_v244: + load_model: electra_chebi50-3star_v244 +gat_chebi50_v244: + load_model: gat_chebi50_v244 +gat-aug_chebi50_v244: + load_model: gat-aug_chebi50_v244 +resgated-aug_chebi50-3star_v244: + load_model: resgated-aug_chebi50-3star_v244 +chemlog: + type: chemlog model_weight: 100 c3p: load_model: c3p_with_weights diff --git
a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 50ab7ee..6dbe8ee 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,5 +1,4 @@ import importlib -import os import time from pathlib import Path from typing import Union @@ -10,7 +9,7 @@ from chebifier.check_env import check_package_installed from chebifier.hugging_face import download_model_files -from chebifier.inconsistency_resolution import PredictionSmoother +from chebifier.inconsistency_resolution import ScoreBasedPredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor from chebifier.utils import ( get_default_configs, @@ -24,8 +23,9 @@ class BaseEnsemble: def __init__( self, model_configs: Union[str, Path, dict, None] = None, - chebi_version: int = 241, resolve_inconsistencies: bool = True, + verbose_output: bool = False, + use_confidence: bool = True, ): # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES @@ -48,6 +48,8 @@ def __init__( model_registry = yaml.safe_load(f) processed_configs = process_config(config, model_registry) + self.verbose_output = verbose_output + self.use_confidence = use_confidence self.chebi_graph = load_chebi_graph() self.disjoint_files = get_disjoint_files() @@ -73,10 +75,11 @@ def __init__( self.models.append(model_instance) if resolve_inconsistencies: - self.smoother = PredictionSmoother( + self.smoother = ScoreBasedPredictionSmoother( self.chebi_graph, label_names=None, disjoint_files=self.disjoint_files, + verbose=self.verbose_output, ) else: self.smoother = None @@ -92,7 +95,8 @@ def gather_predictions(self, smiles_list): if logits_for_smiles is not None: for cls in logits_for_smiles: predicted_classes.add(cls) - print(f"Sorting predictions from {len(model_predictions)} models...") + if self.verbose_output: + print(f"Sorting predictions from {len(model_predictions)} models...") predicted_classes = sorted(list(predicted_classes)) predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)} ordered_logits = ( @@ -114,7 +118,11 @@ def gather_predictions(self, smiles_list): return ordered_logits, predicted_classes def consolidate_predictions( - self, predictions, classwise_weights, predicted_classes, **kwargs + self, + predictions, + classwise_weights, + return_intermediate_results=False, + **kwargs, ): """ Aggregates predictions from multiple models using weighted majority voting. 
@@ -137,7 +145,9 @@ def consolidate_predictions( predictions < self.positive_prediction_threshold ) & valid_predictions - if "use_confidence" in kwargs and kwargs["use_confidence"]: + # if use_confidence is passed in kwargs, it overrides the ensemble setting + use_confidence = kwargs.get("use_confidence", self.use_confidence) + if use_confidence: confidence = 2 * torch.abs( predictions.nan_to_num() - self.positive_prediction_threshold ) @@ -164,22 +174,39 @@ def consolidate_predictions( # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) + if return_intermediate_results: + return ( + net_score, + has_valid_predictions, + { + "positive_mask": positive_mask, + "negative_mask": negative_mask, + "confidence": confidence, + "positive_sum": positive_sum, + "negative_sum": negative_sum, + }, + ) + + return net_score, has_valid_predictions + def apply_inconsistency_resolution( + self, net_score, class_names, has_valid_predictions + ): # Smooth predictions start_time = time.perf_counter() - class_names = list(predicted_classes.keys()) if self.smoother is not None: self.smoother.set_label_names(class_names) smooth_net_score = self.smoother(net_score) class_decisions = ( - smooth_net_score > 0.5 + smooth_net_score > 0 ) & has_valid_predictions # Shape: (num_smiles, num_classes) else: class_decisions = ( net_score > 0 ) & has_valid_predictions # Shape: (num_smiles, num_classes) end_time = time.perf_counter() - print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") + if self.verbose_output: + print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") complete_failure = torch.all(~has_valid_predictions, dim=1) return class_decisions, complete_failure @@ -192,38 +219,28 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights def predict_smiles_list( - self, smiles_list, load_preds_if_possible=False, **kwargs + self, smiles_list, return_intermediate_results=False, **kwargs ) -> list: - preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" - predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" - if not load_preds_if_possible or not os.path.isfile(preds_file): - ordered_predictions, predicted_classes = self.gather_predictions( - smiles_list - ) - if len(predicted_classes) == 0: - print( - "Warning: No classes have been predicted for the given SMILES list." 
+ ordered_predictions, predicted_classes = self.gather_predictions(smiles_list) + if len(predicted_classes) == 0: + print("Warning: No classes have been predicted for the given SMILES list.") + predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)} + + classwise_weights = self.calculate_classwise_weights(predicted_classes) + if return_intermediate_results: + net_score, has_valid_predictions, intermediate_results_dict = ( + self.consolidate_predictions( + ordered_predictions, + classwise_weights, + return_intermediate_results=return_intermediate_results, ) - # save predictions - if load_preds_if_possible: - torch.save(ordered_predictions, preds_file) - with open(predicted_classes_file, "w") as f: - for cls in predicted_classes: - f.write(f"{cls}\n") - predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)} + ) else: - print( - f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}" + net_score, has_valid_predictions = self.consolidate_predictions( + ordered_predictions, classwise_weights ) - ordered_predictions = torch.load(preds_file) - with open(predicted_classes_file, "r") as f: - predicted_classes = { - line.strip(): i for i, line in enumerate(f.readlines()) - } - - classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions, is_failure = self.consolidate_predictions( - ordered_predictions, classwise_weights, predicted_classes, **kwargs + class_decisions, is_failure = self.apply_inconsistency_resolution( + net_score, list(predicted_classes.keys()), has_valid_predictions ) class_names = list(predicted_classes.keys()) @@ -239,6 +256,11 @@ def predict_smiles_list( ) for i, failure in zip(class_decisions, is_failure) ] + if return_intermediate_results: + intermediate_results_dict["predicted_classes"] = predicted_classes + intermediate_results_dict["classwise_weights"] = classwise_weights + intermediate_results_dict["net_score"] = net_score + return result, intermediate_results_dict return result diff --git a/chebifier/ensemble/weighted_majority_ensemble.py b/chebifier/ensemble/weighted_majority_ensemble.py index 97338f5..8579e9f 100644 --- a/chebifier/ensemble/weighted_majority_ensemble.py +++ b/chebifier/ensemble/weighted_majority_ensemble.py @@ -4,6 +4,21 @@ class WMVwithPPVNPVEnsemble(BaseEnsemble): + + def __init__( + self, config_path=None, weighting_strength=1, weighting_exponent=1, **kwargs + ): + """WMV ensemble that weights models based on their class-wise positive / negative predictive values. For each class, the weight is calculated as: + weight = (weighting_strength * PPV + (1 - weighting_strength)) ** weighting_exponent + where PPV is the class-specific positive predictive value of the model on the validation set + or (if the prediction is negative): + weight = (weighting_strength * NPV + (1 - weighting_strength)) ** weighting_exponent + where NPV is the class-specific negative predictive value of the model on the validation set. + """ + super().__init__(config_path, **kwargs) + self.weighting_strength = weighting_strength + self.weighting_exponent = weighting_exponent + def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. 
The @@ -18,21 +33,40 @@ def calculate_classwise_weights(self, predicted_classes): if model.classwise_weights is None: continue for cls, weights in model.classwise_weights.items(): - positive_weights[predicted_classes[cls], j] *= weights["PPV"] - negative_weights[predicted_classes[cls], j] *= weights["NPV"] + positive_weights[predicted_classes[cls], j] *= ( + weights["PPV"] * self.weighting_strength + + (1 - self.weighting_strength) + ) ** self.weighting_exponent + negative_weights[predicted_classes[cls], j] *= ( + weights["NPV"] * self.weighting_strength + + (1 - self.weighting_strength) + ) ** self.weighting_exponent - print( - "Calculated model weightings. The averages for positive / negative weights are:" - ) - for i, model in enumerate(self.models): + if self.verbose_output: print( - f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}" + "Calculated model weightings. The averages for positive / negative weights are:" ) + for i, model in enumerate(self.models): + print( + f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}" + ) return positive_weights, negative_weights class WMVwithF1Ensemble(BaseEnsemble): + + def __init__( + self, config_path=None, weighting_strength=1, weighting_exponent=6.25, **kwargs + ): + """WMV ensemble that weights models based on their class-wise F1 scores. For each class, the weight is calculated as: + weight = model_weight * (weighting_strength * F1 + (1 - weighting_strength)) ** weighting_exponent + where F1 is the class-specific F1 score ("trust") of the model on the validation set. + """ + super().__init__(config_path, **kwargs) + self.weighting_strength = weighting_strength + self.weighting_exponent = weighting_exponent + def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -52,10 +86,12 @@ def calculate_classwise_weights(self, predicted_classes): * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"]) ) - weights_by_cls[predicted_classes[cls], j] *= 1 + f1 - - print("Calculated model weightings. The average weights are:") - for i, model in enumerate(self.models): - print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}") + weights_by_cls[predicted_classes[cls], j] *= ( + self.weighting_strength * f1 + 1 - self.weighting_strength + ) ** self.weighting_exponent + if self.verbose_output: + print("Calculated model weightings. 
The average weights are:") + for i, model in enumerate(self.models): + print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}") return weights_by_cls, weights_by_cls diff --git a/chebifier/inconsistency_resolution.py b/chebifier/inconsistency_resolution.py index 6a9a45e..f442de8 100644 --- a/chebifier/inconsistency_resolution.py +++ b/chebifier/inconsistency_resolution.py @@ -56,10 +56,13 @@ def get_disjoint_groups(disjoint_files): class PredictionSmoother: """Removes implication and disjointness violations from predictions""" - def __init__(self, chebi_graph, label_names=None, disjoint_files=None): + def __init__( + self, chebi_graph, label_names=None, disjoint_files=None, verbose=False + ): self.chebi_graph = chebi_graph self.set_label_names(label_names) self.disjoint_groups = get_disjoint_groups(disjoint_files) + self.verbose = verbose def set_label_names(self, label_names): if label_names is not None: @@ -75,43 +78,26 @@ def set_label_names(self, label_names): self.label_successors[i, self.label_names.index(p)] = 1 self.label_successors = self.label_successors.unsqueeze(0) - def __call__(self, preds): - if preds.shape[1] == 0: - # no labels predicted - return preds - # preds shape: (n_samples, n_labels) - preds_sum_orig = torch.sum(preds) - # step 1: apply implications: for each class, set prediction to max of itself and all successors + def resolve_subsumption_violations(self, preds): preds = preds.unsqueeze(1) preds_masked_succ = torch.where(self.label_successors, preds, 0) # preds_masked_succ shape: (n_samples, n_labels, n_labels) + return preds_masked_succ.max(dim=2).values - preds = preds_masked_succ.max(dim=2).values - if torch.sum(preds) != preds_sum_orig: - print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") + def resolve_disjointness_violations(self, preds): preds_sum_orig = torch.sum(preds) - # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) - preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) + for disj_group in self.disjoint_groups: disj_group = [ self.label_names.index(g) for g in disj_group if g in self.label_names ] if len(disj_group) > 1: - old_preds = preds[:, disj_group] disj_max = torch.max(preds[:, disj_group], dim=1) for i, row in enumerate(preds): for l_ in range(len(preds[i])): if l_ in disj_group and l_ != disj_group[disj_max.indices[i]]: - preds[i, l_] = preds_bounded[i, l_] - samples_changed = 0 - for i, row in enumerate(preds[:, disj_group]): - if any(r != o for r, o in zip(row, old_preds[i])): - samples_changed += 1 - if samples_changed != 0: - print( - f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples" - ) - if torch.sum(preds) != preds_sum_orig: + preds[i, l_] = 0 + if self.verbose and torch.sum(preds) != preds_sum_orig: print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}") preds_sum_orig = torch.sum(preds) # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors @@ -120,6 +106,51 @@ def __call__(self, preds): torch.transpose(self.label_successors, 1, 2), preds, 1 ) preds = preds_masked_predec.min(dim=2).values - if torch.sum(preds) != preds_sum_orig: + if self.verbose and torch.sum(preds) != preds_sum_orig: print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}") return preds + + def __call__(self, preds): + if preds.shape[1] == 0: + # no labels predicted + return preds + # preds 
shape: (n_samples, n_labels) + preds_sum_orig = torch.sum(preds) + # step 1: apply implications: for each class, set prediction to max of itself and all successors + preds = self.resolve_subsumption_violations(preds) + + if self.verbose and torch.sum(preds) != preds_sum_orig: + print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") + # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) + preds = self.resolve_disjointness_violations(preds) + return preds + + +class PessimisticPredictionSmoother(PredictionSmoother): + """Always assumes the positive prediction is wrong (in case of implication violations)""" + + def resolve_subsumption_violations(self, preds): + preds = preds.unsqueeze(1) + preds_masked_predec = torch.where( + torch.transpose(self.label_successors, 1, 2), preds, 1 + ) + preds = preds_masked_predec.min(dim=2).values + return preds + + +class ScoreBasedPredictionSmoother(PredictionSmoother): + """Removes implication violations from predictions based on net scores: for A subclassOf B where score(A) > score(B), either set score(B) = max(score(B), score(A)) + if abs(score(A)) > abs(score(B)) or set score(A) = min(score(A), score(B)) otherwise. + """ + + def resolve_subsumption_violations(self, preds): + preds = preds.unsqueeze(1) + preds_masked_succ = torch.where(self.label_successors, preds, 0) + preds_optimistic = preds_masked_succ.max(dim=2).values + preds_masked_predec = torch.where( + torch.transpose(self.label_successors, 1, 2), preds, 1 + ) + preds_pessimistic = preds_masked_predec.min(dim=2).values + # take the one with the higher absolute value + preds_direction = preds_optimistic - preds_pessimistic > 0 + return torch.where(preds_direction, preds_optimistic, preds_pessimistic) diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 91958b1..3632662 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -11,6 +11,7 @@ ) from chebifier.prediction_models.c3p_predictor import C3PPredictor from chebifier.prediction_models.chemlog_predictor import ( + ChemlogAllPredictor, ChemlogOrganoXCompoundPredictor, ChemlogXMolecularEntityPredictor, ) @@ -27,6 +28,7 @@ "electra": ElectraPredictor, "resgated": ResGatedPredictor, "gat": GATPredictor, + "chemlog": ChemlogAllPredictor, "chemlog_peptides": ChemlogPeptidesPredictor, "chebi_lookup": ChEBILookupPredictor, "chemlog_element": ChemlogXMolecularEntityPredictor, diff --git a/chebifier/model_registry.yml b/chebifier/model_registry.yml index 0cef3af..4c3fedc 100644 --- a/chebifier/model_registry.yml +++ b/chebifier/model_registry.yml @@ -1,3 +1,91 @@ +electra_chebi50-3star_v244: + type: electra + hugging_face: + repo_id: chebai/electra_chebi50-3star_v244 + files: + ckpt_path: electra_chebi50-3star_v244_x2mngani_epoch=180.ckpt + target_labels_path: classes.txt + classwise_weights_path: electra_chebi50-3star_v244_x2mngani_epoch=180_trust_3star.json +gat_chebi50_v244: + type: gat + hugging_face: + repo_id: chebai/gat_chebi50_v244 + files: + ckpt_path: gat_chebi50_v244_0nfi19qt_epoch=198.ckpt + target_labels_path: classes.txt + classwise_weights_path: gat_chebi50_v244_0nfi19qt_epoch=198_trust_3star.json + dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties + molecular_properties: + - chebai_graph.preprocessing.properties.AtomType + - chebai_graph.preprocessing.properties.NumAtomBonds + - chebai_graph.preprocessing.properties.AtomCharge + - 
chebai_graph.preprocessing.properties.AtomAromaticity + - chebai_graph.preprocessing.properties.AtomHybridization + - chebai_graph.preprocessing.properties.AtomNumHs + - chebai_graph.preprocessing.properties.BondType + - chebai_graph.preprocessing.properties.BondInRing + - chebai_graph.preprocessing.properties.BondAromaticity + - chebai_graph.preprocessing.properties.RDKit2DNormalized +gat-aug_chebi50_v244: + type: gat + hugging_face: + repo_id: chebai/gat-aug_chebi50_v244 + files: + ckpt_path: gat-aug_chebi50_v244_8fky8tru_epoch=192.ckpt + target_labels_path: classes.txt + classwise_weights_path: gat-aug_chebi50_v244_8fky8tru_epoch=192_trust_3star.json + dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType + molecular_properties: + - chebai_graph.preprocessing.properties.AtomNodeLevel + # Atom Node type properties + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + # FG Node type properties + - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG + - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG + - chebai_graph.preprocessing.properties.IsFGAlkyl + # Graph Node type properties + - chebai_graph.preprocessing.properties.AugRDKit2DNormalized + # Bond properties + - chebai_graph.preprocessing.properties.BondLevel + - chebai_graph.preprocessing.properties.AugBondAromaticity + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondType +resgated-aug_chebi50-3star_v244: + type: resgated + hugging_face: + repo_id: chebai/resgated-aug_chebi50-3star_v244 + files: + ckpt_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190.ckpt + target_labels_path: classes.txt + classwise_weights_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190_trust_3star.json + dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType + molecular_properties: + - chebai_graph.preprocessing.properties.AtomNodeLevel + # Atom Node type properties + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + # FG Node type properties + - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG + - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG + - chebai_graph.preprocessing.properties.IsFGAlkyl + # Graph Node type properties + - chebai_graph.preprocessing.properties.AugRDKit2DNormalized + # Bond properties + - chebai_graph.preprocessing.properties.BondLevel + - chebai_graph.preprocessing.properties.AugBondAromaticity + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondType electra_chebi50_v241: type: electra hugging_face: @@ -31,4 +119,4 @@ c3p_with_weights: repo_id: chebai/chebifier repo_type: dataset files: - classwise_weights_path: c3p_trust.json \ No newline at end of file + classwise_weights_path: c3p_trust.json diff --git 
a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 5581e6e..0a7f0d6 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -32,6 +32,32 @@ } +class ChemlogAllPredictor(BasePredictor): + def __init__(self, model_name: str, **kwargs): + super().__init__(model_name, **kwargs) + self.chebi_graph = kwargs.get("chebi_graph", None) + self.predictors = [ + ChemlogXMolecularEntityPredictor("chemlog_x_molecular_entity", **kwargs), + ChemlogOrganoXCompoundPredictor("chemlog_organo_x_compound", **kwargs), + ChemlogPeptidesPredictor("chemlog_peptides", **kwargs), + ] + + @modelwise_smiles_lru_cache.batch_decorator + def predict_smiles_list(self, smiles_list: list[str]) -> list: + results = [] + for predictor in self.predictors: + predictor_results = predictor._predict_smiles_list(smiles_list) + for i, res in enumerate(predictor_results): + if i >= len(results): + results.append(dict()) + if res is not None: + results[i].update(res) + return results + + def explain_smiles(self, smiles): + return self.predictors[2].explain_smiles(smiles) + + class ChemlogExtraPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): @@ -41,6 +67,9 @@ def __init__(self, model_name: str, **kwargs): @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: + return self._predict_smiles_list(smiles_list) + + def _predict_smiles_list(self, smiles_list: list[str]) -> list: from chemlog.cli import _smiles_to_mol mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] @@ -65,7 +94,7 @@ def __init__(self, model_name: str, **kwargs): ) super().__init__(model_name, **kwargs) - self.classifier = XMolecularEntityClassifier() + self.classifier = XMolecularEntityClassifier(chebi_graph=self.chebi_graph) class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor): @@ -75,7 +104,7 @@ def __init__(self, model_name: str, **kwargs): ) super().__init__(model_name, **kwargs) - self.classifier = OrganoXCompoundClassifier() + self.classifier = OrganoXCompoundClassifier(chebi_graph=self.chebi_graph) class ChemlogPeptidesPredictor(BasePredictor): @@ -124,6 +153,9 @@ def predict_smiles(self, smiles: str) -> Optional[dict]: @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: + return self._predict_smiles_list(smiles_list) + + def _predict_smiles_list(self, smiles_list: list[str]) -> list: results = [] for i, smiles in tqdm.tqdm(enumerate(smiles_list)): results.append(self.predict_smiles(smiles)) diff --git a/chebifier/utils.py b/chebifier/utils.py index c547054..fbdcde5 100644 --- a/chebifier/utils.py +++ b/chebifier/utils.py @@ -97,7 +97,8 @@ def build_chebi_graph(chebi_version=241): # Only take the edges which connect the existing nodes, to avoid internal creation of obsolete nodes # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142 g.add_edges_from( - [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)] + [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)], + label="direct_child", ) return nx.transitive_closure_dag(g) @@ -153,3 +154,12 @@ def process_config(config, model_registry): else: new_config[model_name] = entry return new_config + + +if __name__ == "__main__": + chebi_graph = build_chebi_graph(chebi_version=244) + os.makedirs(os.path.join("data", "chebi_v244"), exist_ok=True) + pickle.dump( + chebi_graph, + open(os.path.join("data", 
"chebi_v244", "chebi_graph.pkl"), "wb"), + ) diff --git a/configs/example_config.yml b/configs/example_config.yml index bc8efbc..dab7744 100644 --- a/configs/example_config.yml +++ b/configs/example_config.yml @@ -5,7 +5,7 @@ chemlog_peptides: my_resgated: type: resgated ckpt_path: my_resgated.ckpt # checkpoint trained with chebai - target_labels_path: ../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt # from the chebai dataset + target_labels_path: ../python-chebai/data/chebi_v244/ChEBI50/processed/classes.txt # from the chebai dataset molecular_properties: # list of properties used during training - chebai_graph.preprocessing.properties.AtomType - chebai_graph.preprocessing.properties.NumAtomBonds @@ -22,5 +22,5 @@ my_resgated: my_electra: type: electra ckpt_path: my_electra.ckpt - target_labels_path: ../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt + target_labels_path: ../python-chebai/data/chebi_v244/ChEBI50/processed/classes.txt #classwise_weights_path: my_electra_metrics.json # can be calculated with chebai.results.generate_class_properties