Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
An AI ensemble model for predicting chemical classes in the ChEBI ontology. It integrates deep learning models,
rule-based models and generative AI-based models.

A web application for the ensemble is available at https://chebifier.hastingslab.org/.
A web application for Chebifier is available at https://chebifier.hastingslab.org/.

## Installation

Expand Down Expand Up @@ -38,23 +38,27 @@ The package provides a command-line interface (CLI) for making predictions using
The ensemble configuration is given by a configuration file (by default, this is `chebifier/ensemble.yml`). If you
want to change which models are included in the ensemble or how they are weighted, you can create your own configuration file.

Model weights for deep learning models are automatically downloaded from [Hugging Face](https://huggingface.co/chebai).
To use specific model weights from Hugging face, add the `load_model` key in your configuration file. For example:
Trained deep learning models are automatically downloaded from [Hugging Face](https://huggingface.co/chebai).
To access a model from Hugging face, add the `load_model` key in your configuration file. For example:

```yaml
my_electra:
type: electra
load_model: "electra_chebi50_v241"
load_model: "electra_chebi50-3star_v244"
```

### Available model weights:

* `resgated-aug_chebi50-3star_v244`
* `gat-aug_chebi50_v244`
* `electra_chebi50-3star_v244`
* `gat_chebi50_v244`
* `electra_chebi50_v241`
* `resgated_chebi50_v241`
* `c3p_with_weights`


However, you can also supply your own model checkpoints (see `configs/example_config.yml` for an example).
You can also supply your own model checkpoints (see `configs/example_config.yml` for an example).

```bash
# Make predictions
Expand All @@ -72,12 +76,12 @@ python -m chebifier predict --help

### Python API

You can also use the package programmatically:
You can use the package programmatically as well:

```python
from chebifier import BaseEnsemble

# Instantiate ensemble model. If desired, can pass
# Instantiate ensemble model. Optionally, you can pass
# a path to a configuration, like 'configs/example_config.yml'
ensemble = BaseEnsemble()

Expand All @@ -100,11 +104,12 @@ Currently, the following models are supported:

| Model | Description | #Classes | Publication | Repository |
|-------|-------------|----------|-----------------------------------------------------------------------|----------------------------------------------------------------------------------------|
| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings. | 1522 | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) |
| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1522 | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings. | 1531* | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) |
| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1531* | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
| `gat` | A Graph Attention Network trained on ChEBI molecules. | 1531* | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
| `chemlog_peptides` | A rule-based model specialised on peptide classes. | 18 | [Flügel, Simon, et al., 2025: ChemLog: Making MSOL Viable for Ontological Classification and Learning, arXiv](https://arxiv.org/abs/2507.13987) | [chemlog-peptides](https://github.com/sfluegel05/chemlog-peptides) |
| `chemlog_element`, `chemlog_organox` | Extensions of ChemLog for classes that are defined either by the presence of a specific element or by the presence of an organic bond. | 118 + 37 | | [chemlog-extra](https://github.com/ChEB-AI/chemlog-extra) |
| `c3p` | A collection _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, arXiv](https://arxiv.org/abs/2505.18470) | [c3p](https://github.com/chemkg/c3p) |
| `c3p` | A collection _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, Journal of Cheminsformatics](https://link.springer.com/article/10.1186/s13321-025-01092-3) | [c3p](https://github.com/chemkg/c3p) |

In addition, Chebifier also includes a ChEBI lookup that automatically retrieves the ChEBI superclasses for a class
matched by a SMILES string. This is not activated by default, but can be included by adding
Expand All @@ -116,6 +121,8 @@ chebi_lookup:
to your configuration file.

### The ensemble
For an extended description of the ensemble, see [Flügel, Simon, et al., 2025: Chebifier 2: An Ensemble for Chemistry](https://ceur-ws.org/Vol-4064/SymGenAI4Sci-paper4.pdf).

<img width="700" alt="ensemble_architecture" src="https://github.com/user-attachments/assets/9275d3cd-ac88-466f-a1e9-27d20d67543b" />

Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows:
Expand Down Expand Up @@ -146,20 +153,18 @@ Therefore, if in doubt, we are more confident in the negative prediction.

Confidence can be disabled by the `use_confidence` parameter of the predict method (default: True).

The model_weight can be set for each model in the configuration file (default: 1). This is used to favor a certain
The`model_weight` can be set for each model in the configuration file (default: 1). This is used to favor a certain
model independently of a given class.
Trust is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models
on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as 1 + the F1 score.
`Trust` is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models
on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as F1-score $^{6.25}$.
If the `ensemble_type` is set to `mv` (the default), the trust is set to 1 for all models.

### Inconsistency resolution
After a decision has been made for each class independently, the consistency of the predictions with regard to the ChEBI hierarchy
and disjointness axioms is checked. This is
done in 3 steps:
- (1) First, the hierarchy is corrected. For each pair of classes $A$ and $B$ where $A$ is a subclass of $B$ (following
the is-a relation in ChEBI), we set the ensemble prediction of $B$ to 1 if the prediction of $A$ is 1. Intuitively
speaking, if we have determined that a molecule belongs to a specific class (e.g., aromatic primary alcohol), it also
belongs to the direct and indirect superclasses (e.g., primary alcohol, aromatic alcohol, alcohol).
the is-a relation in ChEBI), we set the ensemble prediction of $A$ to $0$ if the _absolute value_ of $B$'s score is large than that of $A$. For example, if $A$ has a net score of $3$ and $B$ has a net score of $-4$, the ensemble will set $A$ to $0$ (i.e., predict neither $A$ nor $B$).
- (2) Next, we check for disjointness. This is not specified directly in ChEBI, but in an additional ChEBI module ([chebi-disjoints.owl](https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/)).
We have extracted these disjointness axioms into a CSV file and added some more disjointness axioms ourselves (see
`data>disjoint_chebi.csv` and `data>disjoint_additional.csv`). If two classes $A$ and $B$ are disjoint and we predict
Expand Down
23 changes: 11 additions & 12 deletions chebifier/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ def cli():
default="wmv-f1",
help="Type of ensemble to use (default: Weighted Majority Voting)",
)
@click.option(
"--chebi-version",
"-v",
type=int,
default=241,
help="ChEBI version to use for checking consistency (default: 241)",
)
@click.option(
"--use-confidence",
"-c",
Expand All @@ -58,23 +51,31 @@ def cli():
default=True,
help="Resolve inconsistencies in predictions automatically (default: True)",
)
@click.option(
"--verbose",
"-v",
is_flag=True,
default=False,
help="Enable verbose output",
)
def predict(
ensemble_config,
smiles,
smiles_file,
output,
ensemble_type,
chebi_version,
use_confidence,
resolve_inconsistencies=True,
verbose=False,
):
"""Predict ChEBI classes for SMILES strings using an ensemble model."""

# Instantiate ensemble model
ensemble = ENSEMBLES[ensemble_type](
ensemble_config,
chebi_version=chebi_version,
resolve_inconsistencies=resolve_inconsistencies,
verbose_output=verbose,
use_confidence=use_confidence,
)

# Collect SMILES strings from arguments and/or file
Expand All @@ -88,9 +89,7 @@ def predict(
return

# Make predictions
predictions = ensemble.predict_smiles_list(
smiles_list, use_confidence=use_confidence
)
predictions = ensemble.predict_smiles_list(smiles_list)

if output:
# save as json
Expand Down
22 changes: 10 additions & 12 deletions chebifier/ensemble.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
electra:
load_model: electra_chebi50_v241
resgated:
load_model: resgated_chebi50_v241
chemlog_peptides:
type: chemlog_peptides
model_weight: 100
chemlog_element:
type: chemlog_element
model_weight: 100
chemlog_organox:
type: chemlog_organox
electra_chebi50-3star_v244:
load_model: electra_chebi50-3star_v244
gat_chebi50_v244:
load_model: gat_chebi50_v244
gat-aug_chebi50_v244:
load_model: gat-aug_chebi50_v244
resgated-aug_chebi50-3star_v244:
load_model: resgated-aug_chebi50-3star_v244
chemlog:
type: chemlog
model_weight: 100
c3p:
load_model: c3p_with_weights
98 changes: 60 additions & 38 deletions chebifier/ensemble/base_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib
import os
import time
from pathlib import Path
from typing import Union
Expand All @@ -10,7 +9,7 @@

from chebifier.check_env import check_package_installed
from chebifier.hugging_face import download_model_files
from chebifier.inconsistency_resolution import PredictionSmoother
from chebifier.inconsistency_resolution import ScoreBasedPredictionSmoother
from chebifier.prediction_models.base_predictor import BasePredictor
from chebifier.utils import (
get_default_configs,
Expand All @@ -24,8 +23,9 @@ class BaseEnsemble:
def __init__(
self,
model_configs: Union[str, Path, dict, None] = None,
chebi_version: int = 241,
resolve_inconsistencies: bool = True,
verbose_output: bool = False,
use_confidence: bool = True,
):
# Deferred Import: To avoid circular import error
from chebifier.model_registry import MODEL_TYPES
Expand All @@ -48,6 +48,8 @@ def __init__(
model_registry = yaml.safe_load(f)

processed_configs = process_config(config, model_registry)
self.verbose_output = verbose_output
self.use_confidence = use_confidence

self.chebi_graph = load_chebi_graph()
self.disjoint_files = get_disjoint_files()
Expand All @@ -73,10 +75,11 @@ def __init__(
self.models.append(model_instance)

if resolve_inconsistencies:
self.smoother = PredictionSmoother(
self.smoother = ScoreBasedPredictionSmoother(
self.chebi_graph,
label_names=None,
disjoint_files=self.disjoint_files,
verbose=self.verbose_output,
)
else:
self.smoother = None
Expand All @@ -92,7 +95,8 @@ def gather_predictions(self, smiles_list):
if logits_for_smiles is not None:
for cls in logits_for_smiles:
predicted_classes.add(cls)
print(f"Sorting predictions from {len(model_predictions)} models...")
if self.verbose_output:
print(f"Sorting predictions from {len(model_predictions)} models...")
predicted_classes = sorted(list(predicted_classes))
predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
ordered_logits = (
Expand All @@ -114,7 +118,11 @@ def gather_predictions(self, smiles_list):
return ordered_logits, predicted_classes

def consolidate_predictions(
self, predictions, classwise_weights, predicted_classes, **kwargs
self,
predictions,
classwise_weights,
return_intermediate_results=False,
**kwargs,
):
"""
Aggregates predictions from multiple models using weighted majority voting.
Expand All @@ -137,7 +145,9 @@ def consolidate_predictions(
predictions < self.positive_prediction_threshold
) & valid_predictions

if "use_confidence" in kwargs and kwargs["use_confidence"]:
# if use_confidence is passed in kwargs, it overrides the ensemble setting
use_confidence = kwargs.get("use_confidence", self.use_confidence)
if use_confidence:
confidence = 2 * torch.abs(
predictions.nan_to_num() - self.positive_prediction_threshold
)
Expand All @@ -164,22 +174,39 @@ def consolidate_predictions(

# Determine which classes to include for each SMILES
net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
if return_intermediate_results:
return (
net_score,
has_valid_predictions,
{
"positive_mask": positive_mask,
"negative_mask": negative_mask,
"confidence": confidence,
"positive_sum": positive_sum,
"negative_sum": negative_sum,
},
)

return net_score, has_valid_predictions

def apply_inconsistency_resolution(
self, net_score, class_names, has_valid_predictions
):
# Smooth predictions
start_time = time.perf_counter()
class_names = list(predicted_classes.keys())
if self.smoother is not None:
self.smoother.set_label_names(class_names)
smooth_net_score = self.smoother(net_score)
class_decisions = (
smooth_net_score > 0.5
smooth_net_score > 0
) & has_valid_predictions # Shape: (num_smiles, num_classes)
else:
class_decisions = (
net_score > 0
) & has_valid_predictions # Shape: (num_smiles, num_classes)
end_time = time.perf_counter()
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
if self.verbose_output:
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")

complete_failure = torch.all(~has_valid_predictions, dim=1)
return class_decisions, complete_failure
Expand All @@ -192,38 +219,28 @@ def calculate_classwise_weights(self, predicted_classes):
return positive_weights, negative_weights

def predict_smiles_list(
self, smiles_list, load_preds_if_possible=False, **kwargs
self, smiles_list, return_intermediate_results=False, **kwargs
) -> list:
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
if not load_preds_if_possible or not os.path.isfile(preds_file):
ordered_predictions, predicted_classes = self.gather_predictions(
smiles_list
)
if len(predicted_classes) == 0:
print(
"Warning: No classes have been predicted for the given SMILES list."
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
if len(predicted_classes) == 0:
print("Warning: No classes have been predicted for the given SMILES list.")
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}

classwise_weights = self.calculate_classwise_weights(predicted_classes)
if return_intermediate_results:
net_score, has_valid_predictions, intermediate_results_dict = (
self.consolidate_predictions(
ordered_predictions,
classwise_weights,
return_intermediate_results=return_intermediate_results,
)
# save predictions
if load_preds_if_possible:
torch.save(ordered_predictions, preds_file)
with open(predicted_classes_file, "w") as f:
for cls in predicted_classes:
f.write(f"{cls}\n")
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
)
else:
print(
f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
net_score, has_valid_predictions = self.consolidate_predictions(
ordered_predictions, classwise_weights
)
ordered_predictions = torch.load(preds_file)
with open(predicted_classes_file, "r") as f:
predicted_classes = {
line.strip(): i for i, line in enumerate(f.readlines())
}

classwise_weights = self.calculate_classwise_weights(predicted_classes)
class_decisions, is_failure = self.consolidate_predictions(
ordered_predictions, classwise_weights, predicted_classes, **kwargs
class_decisions, is_failure = self.apply_inconsistency_resolution(
net_score, list(predicted_classes.keys()), has_valid_predictions
)

class_names = list(predicted_classes.keys())
Expand All @@ -239,6 +256,11 @@ def predict_smiles_list(
)
for i, failure in zip(class_decisions, is_failure)
]
if return_intermediate_results:
intermediate_results_dict["predicted_classes"] = predicted_classes
intermediate_results_dict["classwise_weights"] = classwise_weights
intermediate_results_dict["net_score"] = net_score
return result, intermediate_results_dict

return result

Expand Down
Loading