18 changes: 13 additions & 5 deletions README.md
@@ -70,18 +70,26 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co

### Predicting classes given SMILES strings
```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
```
The input file should contain one SMILES string per line. This generates a CSV file that contains one row for each SMILES string and one column for each class.
The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.

* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).

* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.

* **`--save_to`** *(optional)*: Path where the predictions are saved as a CSV file (default: `predictions.csv`). The CSV contains one row per SMILES string and one column per predicted class.

* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs.

  * If provided, the CSV columns will be named using the ChEBI IDs.
  * If omitted, the script will try to locate the file automatically. If it cannot be located, the columns will be numbered sequentially.
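
For example, assuming a finetuned checkpoint saved as `model.ckpt` and an input file `smiles.txt` (both names are placeholders), a call might look like this:
```
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=model.ckpt --smiles_file_path=smiles.txt --save_to=predictions.csv
```
Here `smiles.txt` contains one SMILES string per line, e.g. `CC(=O)OC1=CC=CC=C1C(=O)O` (aspirin), and `predictions.csv` receives one row per input SMILES string.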

## Evaluation

You can evaluate a model trained on the ontology extension task in one of two ways:

### 1. Using the Jupyter Notebook
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
- Load your finetuned model and run the evaluation cells to compute metrics on the test set.

### 2. Using the Lightning CLI
1 change: 0 additions & 1 deletion chebai/cli.py
@@ -83,7 +83,6 @@ def subcommands() -> Dict[str, Set[str]]:
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"predict_from_file": {"model"},
}


10 changes: 9 additions & 1 deletion chebai/models/base.py
@@ -232,7 +232,15 @@ def predict_step(
Returns:
Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
"""
return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
assert isinstance(batch, XYData)
batch = batch.to(self.device)
data = self._process_batch(batch, batch_idx)
model_output = self(data, **data.get("model_kwargs", dict()))

# Dummy labels to avoid errors in _get_prediction_and_labels
labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
pr, _ = self._get_prediction_and_labels(data, labels, model_output)
return pr

def _execute(
self,
63 changes: 57 additions & 6 deletions chebai/preprocessing/datasets/base.py
@@ -339,8 +339,14 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
for d in tqdm.tqdm(self._load_dict(path), total=lines)
if d["features"] is not None
]

return self._filter_to_token_limit(data)

def _filter_to_token_limit(
self, data: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
# Filter out entries with missing features and keep the feature length below the token limit
data = [
return [
val
for val in data
if val["features"] is not None
@@ -349,8 +355,6 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
)
]

return data

def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
"""
Returns the train DataLoader.
@@ -400,10 +404,14 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
Returns:
Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
"""

return self.dataloader("test", shuffle=False, **kwargs)

def predict_dataloader(
self, *args, **kwargs
self,
smiles_list: List[str],
model_hparams: Optional[dict] = None,
**kwargs,
) -> Union[DataLoader, List[DataLoader]]:
"""
Returns the predict DataLoader.
@@ -415,7 +423,38 @@ def predict_dataloader(
Returns:
Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
"""
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)

data = self._process_input_for_prediction(smiles_list, model_hparams)
return DataLoader(
data,
collate_fn=self.reader.collator,
batch_size=self.batch_size,
**kwargs,
)

def _process_input_for_prediction(
self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> list:
"""
Process input data for prediction.

Args:
smiles_list (List[str]): List of SMILES strings.
model_hparams (Optional[dict]): Model hyperparameters; not used by the base implementation, but dataset-specific pipelines may need them.

Returns:
List[Dict[str, Any]]: Processed input data.
"""
# Add dummy labels because the collate function requires them.
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
data = [
self.reader.to_data(
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
)
for idx, smiles in enumerate(smiles_list)
]
data = self._filter_to_token_limit(data)
return data

def prepare_data(self, *args, **kwargs) -> None:
if self._prepare_data_flag != 1:
@@ -1190,7 +1229,8 @@ def _retrieve_splits_from_csv(self) -> None:
print(f"Applying label filter from {self.apply_label_filter}...")
with open(self.apply_label_filter, "r") as f:
label_filter = [line.strip() for line in f]
with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:

with open(self.classes_txt_file_path, "r") as cf:
classes = [line.strip() for line in cf]
# reorder labels
old_labels = np.stack(df_data["labels"])
@@ -1315,3 +1355,14 @@ def processed_file_names_dict(self) -> dict:
if self.n_token_limit is not None:
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
return {"data": "data.pt"}

@property
def classes_txt_file_path(self) -> str:
"""
Returns the path to the classes text file.

Returns:
str: Path to the dataset's `classes.txt` file in the main processed directory.
"""
# This property is also used by the custom trainer in `chebai/trainer/CustomTrainer.py`.
return os.path.join(self.processed_dir_main, "classes.txt")
151 changes: 151 additions & 0 deletions chebai/result/prediction.py
@@ -0,0 +1,151 @@
import os
from typing import List, Optional

import pandas as pd
import torch
from jsonargparse import CLI
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.cli import instantiate_module
from torch.utils.data import DataLoader

from chebai.models.base import ChebaiBaseNet
from chebai.preprocessing.datasets.base import XYBaseDataModule


class Predictor:
def __init__(
self,
checkpoint_path: _PATH,
batch_size: Optional[int] = None,
compile_model: bool = True,
):
"""Initializes the Predictor with a model loaded from the checkpoint.

Args:
checkpoint_path: Path to the model checkpoint.
batch_size: Optional batch size for the DataLoader. If not provided,
the default from the datamodule will be used.
compile_model: Whether to compile the model using torch.compile. Default is True.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_file = torch.load(
checkpoint_path, map_location=self.device, weights_only=False
)
print("-" * 50)
print(f"For Loaded checkpoint from: {checkpoint_path}")
print("Below are the modules loaded from the checkpoint:")

self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
self._dm_hparams.pop("splits_file_path")
self._dm: XYBaseDataModule = instantiate_module(
XYBaseDataModule, self._dm_hparams
)
if batch_size is not None and int(batch_size) > 0:
self._dm.batch_size = int(batch_size)
print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}")

self._model_hparams = ckpt_file["hyper_parameters"]
self._model: ChebaiBaseNet = instantiate_module(
ChebaiBaseNet, self._model_hparams
)
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")

if compile_model:
self._model = torch.compile(self._model)
self._model.eval()
print("-" * 50)

def predict_from_file(
self,
smiles_file_path: _PATH,
save_to: _PATH = "predictions.csv",
classes_path: Optional[_PATH] = None,
) -> None:
"""
Loads a model from a checkpoint and makes predictions on input data from a file.

Args:
smiles_file_path: Path to the input file containing SMILES strings.
save_to: Path to save the predictions CSV file.
classes_path: Optional path to a file containing class names.
If not provided, the script tries to locate the file via the
datamodule; if that also fails, the columns are numbered.
"""
with open(smiles_file_path, "r") as f:
smiles_strings = [line.strip() for line in f]

preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)

predictions_df = pd.DataFrame(preds.detach().cpu().numpy())

def _add_class_columns(class_file_path: _PATH):
with open(class_file_path, "r") as f:
predictions_df.columns = [cls.strip() for cls in f.readlines()]

if classes_path is not None:
_add_class_columns(classes_path)
elif os.path.exists(self._dm.classes_txt_file_path):
_add_class_columns(self._dm.classes_txt_file_path)

predictions_df.index = smiles_strings
predictions_df.to_csv(save_to)

@torch.inference_mode()
def predict_smiles(
self,
smiles: List[str],
) -> torch.Tensor:
"""
Predicts the output for a list of SMILES strings using the model.

Args:
smiles: A list of SMILES strings.

Returns:
A tensor containing the predictions.
"""
# For certain data prediction pipelines, we may need model hyperparameters
pred_dl: DataLoader = self._dm.predict_dataloader(
smiles_list=smiles, model_hparams=self._model_hparams
)

preds = []
for batch_idx, batch in enumerate(pred_dl):
# For certain model prediction pipelines, we may need data module hyperparameters
preds.append(
self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
)

return torch.cat(preds)


class MainPredictor:
@staticmethod
def predict_from_file(
checkpoint_path: _PATH,
smiles_file_path: _PATH,
save_to: _PATH = "predictions.csv",
classes_path: Optional[_PATH] = None,
batch_size: Optional[int] = None,
) -> None:
predictor = Predictor(checkpoint_path, batch_size)
predictor.predict_from_file(
smiles_file_path,
save_to,
classes_path,
)

@staticmethod
def predict_smiles(
checkpoint_path: _PATH,
smiles: List[str],
batch_size: Optional[int] = None,
) -> torch.Tensor:
predictor = Predictor(checkpoint_path, batch_size)
return predictor.predict_smiles(smiles)


if __name__ == "__main__":
# python chebai/result/prediction.py predict_from_file --help
# python chebai/result/prediction.py predict_smiles --help
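# Example invocation (all paths are placeholders, not files shipped with this repository):
# python chebai/result/prediction.py predict_from_file --checkpoint_path=model.ckpt --smiles_file_path=smiles.txt --save_to=predictions.csv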
CLI(MainPredictor, as_positional=False)