From 3ab691fc016877b61341efa91ea3f591603b95d7 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 20 Aug 2025 16:21:01 +0200 Subject: [PATCH 01/46] add fingerprint dataset, logistic regression model --- chebai/models/base.py | 3 +- chebai/models/classic_ml.py | 62 ++++++++++++++++++++++++++ chebai/preprocessing/datasets/chebi.py | 16 +++++++ chebai/preprocessing/reader.py | 30 +++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 chebai/models/classic_ml.py diff --git a/chebai/models/base.py b/chebai/models/base.py index e657963f..2da6002e 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -106,7 +106,8 @@ def _get_prediction_and_labels( Returns: Tuple[torch.Tensor, torch.Tensor]: Predictions and labels. """ - return output, labels + # cast labels to int + return output, labels.to(torch.int) if labels is not None else labels def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: """ diff --git a/chebai/models/classic_ml.py b/chebai/models/classic_ml.py new file mode 100644 index 00000000..f0b10f78 --- /dev/null +++ b/chebai/models/classic_ml.py @@ -0,0 +1,62 @@ +from typing import Any, Dict + +import torch +import tqdm +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression + +from chebai.models.base import ChebaiBaseNet + + +class LogisticRegression(ChebaiBaseNet): + """ + Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface. + """ + + def __init__(self, out_dim: int, input_dim: int, **kwargs): + super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs) + self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(5)] + + def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: + print( + f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}" + ) + if self.training: + self.fit_sklearn(x["features"], x["labels"]) + try: + preds = [ + torch.from_numpy(model.predict(x["features"])) + .to( + x["features"].device + if isinstance(x["features"], torch.Tensor) + else "cpu" + ) + .float() + for model in self.models + ] + except NotFittedError: + # Not fitted yet, return zeros + print( + f"returning default 0s with shape {(x['features'].shape[0], self.out_dim)}" + ) + return torch.zeros( + (x["features"].shape[0], self.out_dim), + device=( + x["features"].device + if isinstance(x["features"], torch.Tensor) + else "cpu" + ), + ) + preds = torch.stack(preds, dim=1) + print(f"preds shape {preds.shape}") + return preds + + def fit_sklearn(self, X, y): + """ + Fit the underlying sklearn model. X and y should be numpy arrays. + """ + for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"): + model.fit(X, y[:, i]) + + def configure_optimizers(self, **kwargs): + pass diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9fa1c1c7..437a23d3 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -809,6 +809,22 @@ class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50): pass +class ChEBIOverXFingerprints(ChEBIOverX): + """A class that uses Fingerprints for the processed data (used for fixed-length ML models).""" + + READER = dr.FingerprintReader + + +class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): + """ + A class for extracting data from the ChEBI dataset with Fingerprints reader and a threshold of 100. 
+ + Inherits from ChEBIOverXFingerprints and ChEBIOver100. + """ + + pass + + class JCIExtendedBPEData(JCIExtendedBase): READER = dr.ChemBPEReader diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 4b1b0353..0cf81dbd 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -372,3 +372,33 @@ def name(cls) -> str: def _read_data(self, raw_data: str) -> List[int]: """Convert characters in raw data to their ordinal values.""" return [ord(s) for s in raw_data] + + +class FingerprintReader(DataReader): + """ + Data reader for chemical data using RDKit fingerprints. + + Args: + collator_kwargs: Optional dictionary of keyword arguments for the collator. + kwargs: Additional keyword arguments. + """ + + COLLATOR = DefaultCollator + + def __init__(self, fingerprint_size=1024, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fingerprint_size = fingerprint_size + + @classmethod + def name(cls) -> str: + """Returns the name of the data reader.""" + return "rdkit_fingerprint" + + def _read_data(self, raw_data: str) -> List[int]: + """Generate RDKit fingerprint from raw SMILES data.""" + mol = Chem.MolFromSmiles(raw_data.strip()) + if mol is None: + raise ValueError(f"Invalid SMILES: {raw_data}") + return list( + Chem.RDKFingerprint(mol, fpSize=self.fingerprint_size).ToBitString() + ) From bb73784ac6e463737a195dfc0749a97002ae4534 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 20 Aug 2025 16:22:42 +0200 Subject: [PATCH 02/46] add batched pubchem dataset --- chebai/preprocessing/datasets/pubchem.py | 72 ++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index a1879fe7..59b3247e 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -212,6 +212,78 @@ def _perform_data_preparation(self, *args, **kwargs): print("Done") +class PubChemBatched(PubChem): + """Store train data as batches of 10m, validation and test should each be 100k max""" + + def __init__(self, *args, **kwargs): + super(PubChemBatched, self).__init__(*args, **kwargs) + self.train_batch_size = 10_000_000 + if self.k != self.FULL: + self.val_batch_size = ( + 100_000 + if self.validation_split * self.k > 100_000 + else int(self.validation_split * self.k) + ) + self.test_batch_size = ( + 100_000 + if self.test_split * self.k > 100_000 + else int(self.test_split * self.k) + ) + else: + self.val_batch_size = 100_000 + self.test_batch_size = 100_000 + + @property + def processed_file_names(self) -> List[str]: + """ + Returns: + List[str]: List of processed data file names. + """ + train_samples = ( + self._k if self._k != self.FULL else 120_000_000 + ) # estimate size + train_samples -= self.val_batch_size + self.test_batch_size + train_batches = ( + ["train.pt"] + if train_samples <= self.train_batch_size + else [ + f"train_{i}.pt" + for i in range((train_samples // self.train_batch_size) + 1) + ] + ) + return train_batches + ["test.pt", "validation.pt"] + + def setup_processed(self): + """ + Prepares processed data and saves them as Torch tensors. 
+ """ + filename = os.path.join(self.raw_dir, self.raw_file_names[0]) + print("Load data from file", filename) + data = self._load_data_from_file(filename) + print("Create splits") + train, test = train_test_split( + data, test_size=self.test_batch_size + self.val_batch_size + ) + del data + test, val = train_test_split(test, train_size=self.test_batch_size) + torch.save(test, os.path.join(self.processed_dir, "test.pt")) + torch.save(val, os.path.join(self.processed_dir, "validation.pt")) + + # batch training if necessary + if len(train) > self.train_batch_size: + train_batches = [ + train[i : i + self.train_batch_size] + for i in range(0, len(train), self.train_batch_size) + ] + train = [torch.tensor(batch) for batch in train_batches] + for i, batch in enumerate(train): + torch.save(batch, os.path.join(self.processed_dir, f"train_{i}.pt")) + else: + torch.save(train, os.path.join(self.processed_dir, "train.pt")) + + self.reader.on_finish() + + class PubChemDissimilar(PubChem): """ Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint) From db5434f4ebaa6152f13351068d79637d26d14326 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 21 Aug 2025 16:27:07 +0200 Subject: [PATCH 03/46] update file name system for pubchem batched --- chebai/preprocessing/datasets/pubchem.py | 33 +++++++++++++++--------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 59b3247e..38b32d19 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -176,12 +176,16 @@ def raw_file_names(self) -> List[str]: return ["smiles.txt"] @property - def processed_file_names(self) -> List[str]: + def processed_file_names_dict(self) -> List[str]: """ Returns: List[str]: List of processed data file names. """ - return ["test.pt", "train.pt", "validation.pt"] + return { + "train": "train.pt", + "test": "test.pt", + "validation": "validation.pt" + } def _set_processed_data_props(self): """ @@ -215,9 +219,12 @@ def _perform_data_preparation(self, *args, **kwargs): class PubChemBatched(PubChem): """Store train data as batches of 10m, validation and test should each be 100k max""" - def __init__(self, *args, **kwargs): + READER: Type[dr.ChemDataReader] = dr.ChemDataReader + + def __init__(self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs): super(PubChemBatched, self).__init__(*args, **kwargs) - self.train_batch_size = 10_000_000 + self.pc_train_batch_idx = pc_train_batch_idx + self.train_batch_size = train_batch_size if self.k != self.FULL: self.val_batch_size = ( 100_000 @@ -234,7 +241,7 @@ def __init__(self, *args, **kwargs): self.test_batch_size = 100_000 @property - def processed_file_names(self) -> List[str]: + def processed_file_names_dict(self) -> List[str]: """ Returns: List[str]: List of processed data file names. 
@@ -244,14 +251,16 @@ def processed_file_names(self) -> List[str]: ) # estimate size train_samples -= self.val_batch_size + self.test_batch_size train_batches = ( - ["train.pt"] + {"train": "train.pt"} if train_samples <= self.train_batch_size - else [ - f"train_{i}.pt" + else { + f"train" if i == self.pc_train_batch_idx else f"train_{i}": f"train_{i}.pt" for i in range((train_samples // self.train_batch_size) + 1) - ] + } ) - return train_batches + ["test.pt", "validation.pt"] + train_batches["test"] = "test.pt" + train_batches["validation"] = "validation.pt" + return train_batches def setup_processed(self): """ @@ -266,8 +275,8 @@ def setup_processed(self): ) del data test, val = train_test_split(test, train_size=self.test_batch_size) - torch.save(test, os.path.join(self.processed_dir, "test.pt")) - torch.save(val, os.path.join(self.processed_dir, "validation.pt")) + torch.save(test, os.path.join(self.processed_dir, self.processed_file_names_dict["test"])) + torch.save(val, os.path.join(self.processed_dir, self.processed_file_names_dict["validation"])) # batch training if necessary if len(train) > self.train_batch_size: From abffcaea46ade01426d0d6260a1691557c127639 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 21 Aug 2025 16:40:23 +0200 Subject: [PATCH 04/46] fix k --- chebai/preprocessing/datasets/pubchem.py | 26 ++++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 38b32d19..79edff10 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -181,11 +181,7 @@ def processed_file_names_dict(self) -> List[str]: Returns: List[str]: List of processed data file names. """ - return { - "train": "train.pt", - "test": "test.pt", - "validation": "validation.pt" - } + return {"train": "train.pt", "test": "test.pt", "validation": "validation.pt"} def _set_processed_data_props(self): """ @@ -221,7 +217,9 @@ class PubChemBatched(PubChem): READER: Type[dr.ChemDataReader] = dr.ChemDataReader - def __init__(self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs): + def __init__( + self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs + ): super(PubChemBatched, self).__init__(*args, **kwargs) self.pc_train_batch_idx = pc_train_batch_idx self.train_batch_size = train_batch_size @@ -254,7 +252,9 @@ def processed_file_names_dict(self) -> List[str]: {"train": "train.pt"} if train_samples <= self.train_batch_size else { - f"train" if i == self.pc_train_batch_idx else f"train_{i}": f"train_{i}.pt" + ( + "train" if i == self.pc_train_batch_idx else f"train_{i}" + ): f"train_{i}.pt" for i in range((train_samples // self.train_batch_size) + 1) } ) @@ -275,8 +275,16 @@ def setup_processed(self): ) del data test, val = train_test_split(test, train_size=self.test_batch_size) - torch.save(test, os.path.join(self.processed_dir, self.processed_file_names_dict["test"])) - torch.save(val, os.path.join(self.processed_dir, self.processed_file_names_dict["validation"])) + torch.save( + test, + os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), + ) + torch.save( + val, + os.path.join( + self.processed_dir, self.processed_file_names_dict["validation"] + ), + ) # batch training if necessary if len(train) > self.train_batch_size: From 6d7ca439d674e1e4383d8decfc54bae1c6769742 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 21 Aug 2025 16:41:20 +0200 Subject: [PATCH 05/46] fix k --- 
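Note (not part of this commit): the `self.k` -> `self._k` fix below feeds the arithmetic that decides how many processed train files PubChemBatched expects. A minimal, self-contained sketch of that naming scheme, assuming the 100k cap on validation/test and a configurable train batch size; the helper name `estimate_train_files` and the example numbers are made up for illustration only.

    def estimate_train_files(k, validation_split, test_split, train_batch_size=10_000_000):
        # mirror the capping logic: val/test are at most 100k samples each
        val = min(100_000, int(validation_split * k))
        test = min(100_000, int(test_split * k))
        train = k - val - test
        files = {"validation": "validation.pt", "test": "test.pt"}
        if train <= train_batch_size:
            files["train"] = "train.pt"
        else:
            # one file per train batch, named train_0.pt, train_1.pt, ...
            for i in range(train // train_batch_size + 1):
                files[f"train_{i}"] = f"train_{i}.pt"
        return files

    # e.g. estimate_train_files(5_000_000, 0.05, 0.05, train_batch_size=1_000_000)
    # -> validation.pt, test.pt, train_0.pt ... train_4.pt
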
chebai/preprocessing/datasets/pubchem.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 79edff10..28a341a2 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -223,16 +223,16 @@ def __init__( super(PubChemBatched, self).__init__(*args, **kwargs) self.pc_train_batch_idx = pc_train_batch_idx self.train_batch_size = train_batch_size - if self.k != self.FULL: + if self._k != self.FULL: self.val_batch_size = ( 100_000 - if self.validation_split * self.k > 100_000 - else int(self.validation_split * self.k) + if self.validation_split * self._k > 100_000 + else int(self.validation_split * self._k) ) self.test_batch_size = ( 100_000 - if self.test_split * self.k > 100_000 - else int(self.test_split * self.k) + if self.test_split * self._k > 100_000 + else int(self.test_split * self._k) ) else: self.val_batch_size = 100_000 From bf97527477f84d6ac0752a196120ce424e6f9a9a Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Aug 2025 09:50:12 +0200 Subject: [PATCH 06/46] add error handling for smiles tokenisation --- chebai/preprocessing/reader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 0cf81dbd..af7b0bb1 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -205,8 +205,12 @@ def _read_data(self, raw_data: str) -> List[int]: except Exception as e: print(f"RDKit failed to process {raw_data}") print(f"\t{e}") - - return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] + try: + return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] + except ValueError as e: + print(f"could not process {raw_data}") + print(f"\t{e}") + return None class DeepChemDataReader(ChemDataReader): From 7f9e28d34f8633ffb42e442a42f1ad28adb258b8 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Aug 2025 09:51:57 +0200 Subject: [PATCH 07/46] add default model --- chebai/models/classic_ml.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/chebai/models/classic_ml.py b/chebai/models/classic_ml.py index f0b10f78..a92e388f 100644 --- a/chebai/models/classic_ml.py +++ b/chebai/models/classic_ml.py @@ -1,5 +1,6 @@ from typing import Any, Dict - +import pickle as pkl +import numpy as np import torch import tqdm from sklearn.exceptions import NotFittedError @@ -15,7 +16,7 @@ class LogisticRegression(ChebaiBaseNet): def __init__(self, out_dim: int, input_dim: int, **kwargs): super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs) - self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(5)] + self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(out_dim)] def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: print( @@ -56,7 +57,27 @@ def fit_sklearn(self, X, y): Fit the underlying sklearn model. X and y should be numpy arrays. 
""" for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"): - model.fit(X, y[:, i]) + import os + if os.path.exists(f"LR_CHEBI100_model_{i}.pkl"): + print(f"Loading model {i} from file") + self.models[i] = pkl.load(open(f"LR_CHEBI100_model_{i}.pkl", "rb")) + else: + try: + model.fit(X, y[:, i]) + except ValueError as e: + self.models[i] = PlaceholderModel() + # dump + pkl.dump(model, open(f"LR_CHEBI100_model_{i}.pkl", "wb")) def configure_optimizers(self, **kwargs): pass + + +class PlaceholderModel: + """Acts like a trained model, but isn't. Use this if training fails and you need a placeholder.""" + + def __init__(self, default_prediction=1): + self.default_prediction = default_prediction + + def predict(self, preds): + return np.ones(preds.shape[0]) * self.default_prediction \ No newline at end of file From f76fe1c7c858a7d1d22a46060637132c04b1c859 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Aug 2025 13:47:45 +0200 Subject: [PATCH 08/46] fix lstm --- chebai/callbacks/epoch_metrics.py | 5 ++++- chebai/models/lstm.py | 22 ++++++++++++++-------- configs/model/lstm.yml | 7 +++++++ 3 files changed, 25 insertions(+), 9 deletions(-) create mode 100644 configs/model/lstm.yml diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py index c1cf7bd3..f5b30662 100644 --- a/chebai/callbacks/epoch_metrics.py +++ b/chebai/callbacks/epoch_metrics.py @@ -62,7 +62,10 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None: labels (torch.Tensor): Ground truth labels. """ tps = torch.sum( - torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0 + torch.logical_and( + preds > self.threshold, labels.to(torch.bool) + ), + dim=0, ) self.true_positives += tps self.positive_predictions += torch.sum(preds > self.threshold, dim=0) diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index 3a0949c4..caf7bb5c 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -9,20 +9,26 @@ class ChemLSTM(ChebaiBaseNet): - def __init__(self, in_d, out_d, num_classes, **kwargs): - super().__init__(num_classes, **kwargs) - self.lstm = nn.LSTM(in_d, out_d, batch_first=True) - self.embedding = nn.Embedding(800, 100) + def __init__(self, out_d, in_d, num_classes, criterion : nn.Module=None, **kwargs): + super().__init__( + out_dim=out_d, + input_dim=in_d, + criterion=criterion, + num_classes=num_classes, + **kwargs + ) + self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=0.2, bidirectional=True) + self.embedding = nn.Embedding(1400, in_d) self.output = nn.Sequential( - nn.Linear(out_d, in_d), + nn.Linear(out_d * 2, in_d), nn.ReLU(), nn.Dropout(0.2), nn.Linear(in_d, num_classes), ) - def forward(self, data): - x = data.x - x_lens = data.lens + def forward(self, data, *args, **kwargs): + x = data["features"] + x_lens = data["model_kwargs"]["lens"] x = self.embedding(x) x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False) x = self.lstm(x)[1][0] diff --git a/configs/model/lstm.yml b/configs/model/lstm.yml new file mode 100644 index 00000000..b3d6e3db --- /dev/null +++ b/configs/model/lstm.yml @@ -0,0 +1,7 @@ +class_path: chebai.models.lstm.ChemLSTM +init_args: + in_d: 100 + out_d: 100 + num_classes: 1528 + optimizer_kwargs: + lr: 1e-3 \ No newline at end of file From 03d5e55d127cf4d4a71274511c0c770bc8542ec6 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Aug 2025 14:08:23 +0200 Subject: [PATCH 09/46] fix lstm --- chebai/models/lstm.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) 
diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index caf7bb5c..4450e21f 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -9,21 +9,23 @@ class ChemLSTM(ChebaiBaseNet): - def __init__(self, out_d, in_d, num_classes, criterion : nn.Module=None, **kwargs): + def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, **kwargs): super().__init__( out_dim=out_d, input_dim=in_d, criterion=criterion, num_classes=num_classes, - **kwargs + **kwargs, + ) + self.lstm = nn.LSTM( + in_d, out_d, batch_first=True, dropout=0.2, bidirectional=True ) - self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=0.2, bidirectional=True) self.embedding = nn.Embedding(1400, in_d) self.output = nn.Sequential( - nn.Linear(out_d * 2, in_d), + nn.Linear(out_d, out_d), nn.ReLU(), nn.Dropout(0.2), - nn.Linear(in_d, num_classes), + nn.Linear(out_d, num_classes), ) def forward(self, data, *args, **kwargs): From e1256b02684281a084e93bd0e1fef70be4771973 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Aug 2025 14:10:00 +0200 Subject: [PATCH 10/46] fix lstm --- configs/model/lstm.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/lstm.yml b/configs/model/lstm.yml index b3d6e3db..9ee3f183 100644 --- a/configs/model/lstm.yml +++ b/configs/model/lstm.yml @@ -4,4 +4,4 @@ init_args: out_d: 100 num_classes: 1528 optimizer_kwargs: - lr: 1e-3 \ No newline at end of file + lr: 1e-3 From abc9b53a7bca75a57c311433adc50425bf874135 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Aug 2025 14:26:12 +0200 Subject: [PATCH 11/46] fix lstm --- chebai/models/lstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index 4450e21f..529158a1 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -18,7 +18,7 @@ def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, **kwar **kwargs, ) self.lstm = nn.LSTM( - in_d, out_d, batch_first=True, dropout=0.2, bidirectional=True + in_d, out_d, batch_first=True, dropout=0.2 ) self.embedding = nn.Embedding(1400, in_d) self.output = nn.Sequential( From 3b233d6a88683ff08183ddda406cb28fee0de3b3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 16 Sep 2025 17:18:33 +0200 Subject: [PATCH 12/46] streamline classic ml --- chebai/callbacks/epoch_metrics.py | 4 +-- chebai/models/classic_ml.py | 48 ++++++++++++------------------- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py index f5b30662..76d6a8fd 100644 --- a/chebai/callbacks/epoch_metrics.py +++ b/chebai/callbacks/epoch_metrics.py @@ -62,9 +62,7 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None: labels (torch.Tensor): Ground truth labels. 
""" tps = torch.sum( - torch.logical_and( - preds > self.threshold, labels.to(torch.bool) - ), + torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0, ) self.true_positives += tps diff --git a/chebai/models/classic_ml.py b/chebai/models/classic_ml.py index a92e388f..554f19df 100644 --- a/chebai/models/classic_ml.py +++ b/chebai/models/classic_ml.py @@ -1,5 +1,6 @@ -from typing import Any, Dict import pickle as pkl +from typing import Any, Dict + import numpy as np import torch import tqdm @@ -16,7 +17,9 @@ class LogisticRegression(ChebaiBaseNet): def __init__(self, out_dim: int, input_dim: int, **kwargs): super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs) - self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(out_dim)] + self.models = [ + SklearnLogisticRegression(solver="liblinear") for _ in range(300) + ] def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: print( @@ -24,33 +27,19 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: ) if self.training: self.fit_sklearn(x["features"], x["labels"]) - try: - preds = [ - torch.from_numpy(model.predict(x["features"])) - .to( - x["features"].device - if isinstance(x["features"], torch.Tensor) - else "cpu" - ) - .float() - for model in self.models - ] - except NotFittedError: - # Not fitted yet, return zeros - print( - f"returning default 0s with shape {(x['features'].shape[0], self.out_dim)}" - ) - return torch.zeros( - (x["features"].shape[0], self.out_dim), - device=( - x["features"].device - if isinstance(x["features"], torch.Tensor) - else "cpu" - ), - ) + preds = [] + for model in self.models: + try: + p = torch.from_numpy(model.predict(x["features"])).float() + p = p.to(x["features"].device) + preds.append(p) + except NotFittedError: + preds.append(torch.zeros((x["features"].shape[0], 1), device=(x["features"].device))) + except AttributeError: + preds.append(torch.zeros((x["features"].shape[0], 1), device=(x["features"].device))) preds = torch.stack(preds, dim=1) print(f"preds shape {preds.shape}") - return preds + return preds.squeeze(-1) def fit_sklearn(self, X, y): """ @@ -58,13 +47,14 @@ def fit_sklearn(self, X, y): """ for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"): import os + if os.path.exists(f"LR_CHEBI100_model_{i}.pkl"): print(f"Loading model {i} from file") self.models[i] = pkl.load(open(f"LR_CHEBI100_model_{i}.pkl", "rb")) else: try: model.fit(X, y[:, i]) - except ValueError as e: + except ValueError: self.models[i] = PlaceholderModel() # dump pkl.dump(model, open(f"LR_CHEBI100_model_{i}.pkl", "wb")) @@ -80,4 +70,4 @@ def __init__(self, default_prediction=1): self.default_prediction = default_prediction def predict(self, preds): - return np.ones(preds.shape[0]) * self.default_prediction \ No newline at end of file + return np.ones(preds.shape[0]) * self.default_prediction From 8c0454c47f94ec61176f1d35e55651ed2f5bfe72 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 17 Sep 2025 10:50:36 +0200 Subject: [PATCH 13/46] fix batched pubchem --- chebai/preprocessing/datasets/pubchem.py | 60 +++++++++++++++++++----- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 28a341a2..d466dcb1 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -262,25 +262,63 @@ def processed_file_names_dict(self) -> List[str]: train_batches["validation"] = "validation.pt" return train_batches + def 
_tokenize_batched(self, data): + """ + Load data from a file and return a list of dictionaries, batched in 1,000,000 entries. + + Args: + path (str): The path to the input file. + batch_size (int): The size of each batch. + batch_idx (int): The index of the batch to load. + + Returns: + List: A list of dictionaries containing the features and labels. + """ + print(f"Processing {len(data)} lines...") + batch = [] + for i, d in enumerate(tqdm.tqdm(data, total=len(data))): + if d["features"] is not None: + batch.append(self.reader.to_data(d)) + if i % 1_000_000 == 0 and i > 0: + print(f"Saving batch {i // 1_000_000}") + batch = [ + b + for b in batch + if b["features"] is not None + and self.n_token_limit is None + or len(b["features"]) <= self.n_token_limit + ] + yield batch + batch = [] + print("Saving final batch") + batch = [ + b + for b in batch + if b["features"] is not None + and self.n_token_limit is None + or len(b["features"]) <= self.n_token_limit + ] + yield batch + def setup_processed(self): """ Prepares processed data and saves them as Torch tensors. """ filename = os.path.join(self.raw_dir, self.raw_file_names[0]) print("Load data from file", filename) - data = self._load_data_from_file(filename) + data_not_tokenized = [entry for entry in self._load_dict(filename)] print("Create splits") train, test = train_test_split( - data, test_size=self.test_batch_size + self.val_batch_size + data_not_tokenized, test_size=self.test_batch_size + self.val_batch_size ) - del data + del data_not_tokenized test, val = train_test_split(test, train_size=self.test_batch_size) torch.save( - test, + self._tokenize_batched(test), os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), ) torch.save( - val, + self._tokenize_batched(val), os.path.join( self.processed_dir, self.processed_file_names_dict["validation"] ), @@ -288,15 +326,13 @@ def setup_processed(self): # batch training if necessary if len(train) > self.train_batch_size: - train_batches = [ - train[i : i + self.train_batch_size] - for i in range(0, len(train), self.train_batch_size) - ] - train = [torch.tensor(batch) for batch in train_batches] - for i, batch in enumerate(train): + for i, batch in enumerate(self._tokenize_batched(train)): torch.save(batch, os.path.join(self.processed_dir, f"train_{i}.pt")) else: - torch.save(train, os.path.join(self.processed_dir, "train.pt")) + torch.save( + self._tokenize_batched(train), + os.path.join(self.processed_dir, "train.pt"), + ) self.reader.on_finish() From 04abe6670a18d6bd3d9234fd5524d4f8d2bb0186 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 17 Sep 2025 12:55:24 +0200 Subject: [PATCH 14/46] fix pubchem batching --- chebai/preprocessing/datasets/pubchem.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index d466dcb1..dd12862d 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -313,12 +313,14 @@ def setup_processed(self): ) del data_not_tokenized test, val = train_test_split(test, train_size=self.test_batch_size) + # Save first (and only) test batch torch.save( - self._tokenize_batched(test), + next(self._tokenize_batched(test)), os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), ) + # save first (and only) validation batch torch.save( - self._tokenize_batched(val), + next(self._tokenize_batched(val)), os.path.join( self.processed_dir, self.processed_file_names_dict["validation"] ), @@ 
-330,7 +332,7 @@ def setup_processed(self): torch.save(batch, os.path.join(self.processed_dir, f"train_{i}.pt")) else: torch.save( - self._tokenize_batched(train), + next(self._tokenize_batched(train)), os.path.join(self.processed_dir, "train.pt"), ) From 97079c3dc1671cbd4597883225e9e62f60bc9075 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 17 Sep 2025 13:14:06 +0200 Subject: [PATCH 15/46] fix batch tokenisation --- chebai/preprocessing/datasets/pubchem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index dd12862d..7b9a61b0 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -295,8 +295,8 @@ def _tokenize_batched(self, data): b for b in batch if b["features"] is not None - and self.n_token_limit is None - or len(b["features"]) <= self.n_token_limit + and (self.n_token_limit is None + or len(b["features"]) <= self.n_token_limit) ] yield batch From 5e6c508f7780f7ebc52bc3e4d3024403745a72e2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 17 Sep 2025 13:33:56 +0200 Subject: [PATCH 16/46] fix batch tokenisation --- chebai/preprocessing/datasets/pubchem.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 7b9a61b0..cf06c116 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -285,8 +285,8 @@ def _tokenize_batched(self, data): b for b in batch if b["features"] is not None - and self.n_token_limit is None - or len(b["features"]) <= self.n_token_limit + and (self.n_token_limit is None + or len(b["features"]) <= self.n_token_limit) ] yield batch batch = [] @@ -295,8 +295,7 @@ def _tokenize_batched(self, data): b for b in batch if b["features"] is not None - and (self.n_token_limit is None - or len(b["features"]) <= self.n_token_limit) + and (self.n_token_limit is None or len(b["features"]) <= self.n_token_limit) ] yield batch From 03cb2126250acc53e9f9765c6533235caec9f95d Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 17 Sep 2025 14:23:40 +0200 Subject: [PATCH 17/46] fix batch tokenisation --- chebai/preprocessing/datasets/pubchem.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index cf06c116..62434db0 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -285,9 +285,9 @@ def _tokenize_batched(self, data): b for b in batch if b["features"] is not None - and (self.n_token_limit is None - or len(b["features"]) <= self.n_token_limit) ] + if self.n_token_limit is not None: + batch = [b for b in batch if len(b["features"]) <= self.n_token_limit] yield batch batch = [] print("Saving final batch") @@ -295,8 +295,9 @@ def _tokenize_batched(self, data): b for b in batch if b["features"] is not None - and (self.n_token_limit is None or len(b["features"]) <= self.n_token_limit) ] + if self.n_token_limit is not None: + batch = [b for b in batch if len(b["features"]) <= self.n_token_limit] yield batch def setup_processed(self): From 0f1e7c0fe607fbbecb517c6790996bc3803c9582 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 17 Sep 2025 14:33:28 +0200 Subject: [PATCH 18/46] run n epochs with n different training files --- chebai/models/base.py | 5 +++++ chebai/preprocessing/datasets/pubchem.py | 28 +++++++++++++++++++----- 2 
files changed, 27 insertions(+), 6 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 2da6002e..af08baa6 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -158,6 +158,11 @@ def _process_for_loss( Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: Model output, labels, and loss kwargs. """ return model_output, labels, loss_kwargs + + def on_train_epoch_start(self) -> None: + # pass current epoch to datamodule if it has the attribute curr_epoch (for PubChemBatched dataset) + if hasattr(self.trainer.datamodule, "curr_epoch"): + self.trainer.datamodule.curr_epoch = self.current_epoch def training_step( self, batch: XYData, batch_idx: int diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 62434db0..ed6d60c9 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -14,7 +14,7 @@ import tempfile import time from datetime import datetime -from typing import Generator, List, Optional, Tuple, Type +from typing import Generator, List, Optional, Tuple, Type, Union import numpy as np import pandas as pd @@ -218,10 +218,10 @@ class PubChemBatched(PubChem): READER: Type[dr.ChemDataReader] = dr.ChemDataReader def __init__( - self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs + self, train_batch_size=10_000_000, *args, **kwargs ): super(PubChemBatched, self).__init__(*args, **kwargs) - self.pc_train_batch_idx = pc_train_batch_idx + self.curr_epoch = 0 self.train_batch_size = train_batch_size if self._k != self.FULL: self.val_batch_size = ( @@ -252,9 +252,7 @@ def processed_file_names_dict(self) -> List[str]: {"train": "train.pt"} if train_samples <= self.train_batch_size else { - ( - "train" if i == self.pc_train_batch_idx else f"train_{i}" - ): f"train_{i}.pt" + f"train_{i}": f"train_{i}.pt" for i in range((train_samples // self.train_batch_size) + 1) } ) @@ -338,6 +336,24 @@ def setup_processed(self): self.reader.on_finish() + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the train DataLoader. This swaps the training batch for each epoch. + + Args: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + DataLoader: A DataLoader object for training data. + """ + return self.dataloader( + "train" if "train" in self.processed_file_names_dict else f"train_{self.curr_epoch}", + shuffle=True, + num_workers=self.num_workers, + persistent_workers=True, + **kwargs, + ) class PubChemDissimilar(PubChem): """ From 9eebad2b20691f65a143b35a1d4ec73d6af48bf4 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 18 Sep 2025 13:58:01 +0200 Subject: [PATCH 19/46] add logging --- chebai/preprocessing/datasets/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4a1898bc..6aa6edcc 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -256,6 +256,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader: Returns: DataLoader: A DataLoader object. 
""" + rank_zero_info(f"Loading {kind} data...") dataset = self.load_processed_data(kind) if "ids" in kwargs: ids = kwargs.pop("ids") @@ -439,6 +440,7 @@ def setup(self, *args, **kwargs) -> None: rank_zero_info(f"Check for processed data in {self.processed_dir}") rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}") + rank_zero_info(f"Looking for files: {self.processed_file_names}") if any( not os.path.isfile(os.path.join(self.processed_dir, f)) for f in self.processed_file_names From faa3a728c4193316bcb7434e89ef314778624bf8 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 18 Sep 2025 14:23:06 +0200 Subject: [PATCH 20/46] lstm error logging --- chebai/models/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/models/base.py b/chebai/models/base.py index af08baa6..869d196a 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -316,6 +316,8 @@ def _execute( for metric_name, metric in metrics.items(): metric.update(pr, tar) self._log_metrics(prefix, metrics, len(batch)) + if isinstance(d, dict) and not "loss" in d: + print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}") return d def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int): From 7f92917283b775dd40400a2d5cce43ae882b9da6 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 18 Sep 2025 16:42:07 +0200 Subject: [PATCH 21/46] add more logging to find out if pubchemBatched actually works --- chebai/models/base.py | 7 +++++-- chebai/preprocessing/datasets/base.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 869d196a..c12bdf03 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -1,6 +1,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Optional, Union +from lightning.pytorch.utilities.rank_zero import rank_zero_info import torch from lightning.pytorch.core.module import LightningModule @@ -158,10 +159,12 @@ def _process_for_loss( Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: Model output, labels, and loss kwargs. """ return model_output, labels, loss_kwargs - + def on_train_epoch_start(self) -> None: # pass current epoch to datamodule if it has the attribute curr_epoch (for PubChemBatched dataset) + rank_zero_info(f"Starting epoch {self.current_epoch}") if hasattr(self.trainer.datamodule, "curr_epoch"): + rank_zero_info(f"Setting datamodule.curr_epoch to {self.current_epoch}") self.trainer.datamodule.curr_epoch = self.current_epoch def training_step( @@ -316,7 +319,7 @@ def _execute( for metric_name, metric in metrics.items(): metric.update(pr, tar) self._log_metrics(prefix, metrics, len(batch)) - if isinstance(d, dict) and not "loss" in d: + if isinstance(d, dict) and "loss" not in d: print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}") return d diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 6aa6edcc..bc5cd9f7 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -256,7 +256,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader: Returns: DataLoader: A DataLoader object. """ - rank_zero_info(f"Loading {kind} data...") + rank_zero_info(f"Loading {kind} data... 
(datamodule.current_epoch={self.current_epoch})") dataset = self.load_processed_data(kind) if "ids" in kwargs: ids = kwargs.pop("ids") From 69908baccfaef00764a4c9abde0d583281e906c4 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 19 Sep 2025 10:40:23 +0200 Subject: [PATCH 22/46] fix print statement for fixing epoch issue --- chebai/preprocessing/datasets/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index bc5cd9f7..a8600383 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -256,7 +256,9 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader: Returns: DataLoader: A DataLoader object. """ - rank_zero_info(f"Loading {kind} data... (datamodule.current_epoch={self.current_epoch})") + rank_zero_info( + f"Loading {kind} data... (datamodule.current_epoch={self.curr_epoch if hasattr(self, 'curr_epoch') else 'N/A'})" + ) dataset = self.load_processed_data(kind) if "ids" in kwargs: ids = kwargs.pop("ids") From 6df484ddb57d42ee6a5222e12c2bf051784ee214 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 19 Sep 2025 11:25:48 +0200 Subject: [PATCH 23/46] reformatting --- chebai/models/base.py | 2 +- chebai/models/classic_ml.py | 12 ++++++++++-- chebai/models/lstm.py | 4 +--- chebai/preprocessing/datasets/pubchem.py | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index c12bdf03..7653f13c 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -1,10 +1,10 @@ import logging from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Optional, Union -from lightning.pytorch.utilities.rank_zero import rank_zero_info import torch from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.utilities.rank_zero import rank_zero_info from chebai.preprocessing.structures import XYData diff --git a/chebai/models/classic_ml.py b/chebai/models/classic_ml.py index 554f19df..0c201f91 100644 --- a/chebai/models/classic_ml.py +++ b/chebai/models/classic_ml.py @@ -34,9 +34,17 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: p = p.to(x["features"].device) preds.append(p) except NotFittedError: - preds.append(torch.zeros((x["features"].shape[0], 1), device=(x["features"].device))) + preds.append( + torch.zeros( + (x["features"].shape[0], 1), device=(x["features"].device) + ) + ) except AttributeError: - preds.append(torch.zeros((x["features"].shape[0], 1), device=(x["features"].device))) + preds.append( + torch.zeros( + (x["features"].shape[0], 1), device=(x["features"].device) + ) + ) preds = torch.stack(preds, dim=1) print(f"preds shape {preds.shape}") return preds.squeeze(-1) diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index 529158a1..3613db68 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -17,9 +17,7 @@ def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, **kwar num_classes=num_classes, **kwargs, ) - self.lstm = nn.LSTM( - in_d, out_d, batch_first=True, dropout=0.2 - ) + self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=0.2) self.embedding = nn.Embedding(1400, in_d) self.output = nn.Sequential( nn.Linear(out_d, out_d), diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index ed6d60c9..64d932f4 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -245,7 +245,7 @@ def 
processed_file_names_dict(self) -> List[str]: List[str]: List of processed data file names. """ train_samples = ( - self._k if self._k != self.FULL else 120_000_000 + self._k if self._k != self.FULL else 120_000_000 # estimated PubChem size ) # estimate size train_samples -= self.val_batch_size + self.test_batch_size train_batches = ( From 428868984114f20409b470f9623890e979537e5e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 19 Sep 2025 11:26:09 +0200 Subject: [PATCH 24/46] reformatting --- chebai/preprocessing/datasets/pubchem.py | 29 +++++++++++------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 64d932f4..ecde5fac 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -217,9 +217,7 @@ class PubChemBatched(PubChem): READER: Type[dr.ChemDataReader] = dr.ChemDataReader - def __init__( - self, train_batch_size=10_000_000, *args, **kwargs - ): + def __init__(self, train_batch_size=10_000_000, *args, **kwargs): super(PubChemBatched, self).__init__(*args, **kwargs) self.curr_epoch = 0 self.train_batch_size = train_batch_size @@ -245,7 +243,7 @@ def processed_file_names_dict(self) -> List[str]: List[str]: List of processed data file names. """ train_samples = ( - self._k if self._k != self.FULL else 120_000_000 # estimated PubChem size + self._k if self._k != self.FULL else 120_000_000 # estimated PubChem size ) # estimate size train_samples -= self.val_batch_size + self.test_batch_size train_batches = ( @@ -279,21 +277,15 @@ def _tokenize_batched(self, data): batch.append(self.reader.to_data(d)) if i % 1_000_000 == 0 and i > 0: print(f"Saving batch {i // 1_000_000}") - batch = [ - b - for b in batch - if b["features"] is not None - ] + batch = [b for b in batch if b["features"] is not None] if self.n_token_limit is not None: - batch = [b for b in batch if len(b["features"]) <= self.n_token_limit] + batch = [ + b for b in batch if len(b["features"]) <= self.n_token_limit + ] yield batch batch = [] print("Saving final batch") - batch = [ - b - for b in batch - if b["features"] is not None - ] + batch = [b for b in batch if b["features"] is not None] if self.n_token_limit is not None: batch = [b for b in batch if len(b["features"]) <= self.n_token_limit] yield batch @@ -348,13 +340,18 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader DataLoader: A DataLoader object for training data. 
""" return self.dataloader( - "train" if "train" in self.processed_file_names_dict else f"train_{self.curr_epoch}", + ( + "train" + if "train" in self.processed_file_names_dict + else f"train_{self.curr_epoch}" + ), shuffle=True, num_workers=self.num_workers, persistent_workers=True, **kwargs, ) + class PubChemDissimilar(PubChem): """ Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint) From 0e6afe2c9120cad513e68a62dc8bba4251e8dde5 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 22 Sep 2025 11:10:09 +0200 Subject: [PATCH 25/46] add num_layers and dropout parameters, make lstm bidirectional --- chebai/models/lstm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index 3613db68..8aac3ed3 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -9,7 +9,7 @@ class ChemLSTM(ChebaiBaseNet): - def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, **kwargs): + def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, num_layers=6, dropout=0.2, **kwargs): super().__init__( out_dim=out_d, input_dim=in_d, @@ -17,7 +17,7 @@ def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, **kwar num_classes=num_classes, **kwargs, ) - self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=0.2) + self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=dropout, bidirectional=True, num_layers=num_layers) self.embedding = nn.Embedding(1400, in_d) self.output = nn.Sequential( nn.Linear(out_d, out_d), From 940ce9d5caf6bea78ef0c0b1a424a6ae33902b7b Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 22 Sep 2025 12:02:04 +0200 Subject: [PATCH 26/46] multi-layer lstm --- chebai/models/lstm.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py index 8aac3ed3..96ecc944 100644 --- a/chebai/models/lstm.py +++ b/chebai/models/lstm.py @@ -1,7 +1,7 @@ import logging from torch import nn -from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from chebai.models.base import ChebaiBaseNet @@ -9,7 +9,16 @@ class ChemLSTM(ChebaiBaseNet): - def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, num_layers=6, dropout=0.2, **kwargs): + def __init__( + self, + out_d, + in_d, + num_classes, + criterion: nn.Module = None, + num_layers=6, + dropout=0.2, + **kwargs, + ): super().__init__( out_dim=out_d, input_dim=in_d, @@ -17,10 +26,17 @@ def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, num_la num_classes=num_classes, **kwargs, ) - self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=dropout, bidirectional=True, num_layers=num_layers) + self.lstm = nn.LSTM( + in_d, + out_d, + batch_first=True, + dropout=dropout, + bidirectional=True, + num_layers=num_layers, + ) self.embedding = nn.Embedding(1400, in_d) self.output = nn.Sequential( - nn.Linear(out_d, out_d), + nn.Linear(out_d * 2, out_d), nn.ReLU(), nn.Dropout(0.2), nn.Linear(out_d, num_classes), @@ -31,7 +47,9 @@ def forward(self, data, *args, **kwargs): x_lens = data["model_kwargs"]["lens"] x = self.embedding(x) x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False) - x = self.lstm(x)[1][0] - # = pad_packed_sequence(x, batch_first=True)[0] + x = self.lstm(x)[0] + x = pad_packed_sequence(x, batch_first=True)[0][ + :, 0 + ] # reduce sequence dimension to first element x = 
self.output(x) - return x.squeeze(0) + return x From 1e68032a2ccbe6cf080e260df9e820a36cc4e1f1 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 23 Sep 2025 14:02:55 +0200 Subject: [PATCH 27/46] increase vocab_size for PubChem --- configs/model/electra.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/electra.yml b/configs/model/electra.yml index c3cf2fdf..663a8fa1 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -3,7 +3,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 1400 + vocab_size: 4400 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 From 8ee5c4ba96d14cfdcd84583ce44a10667b5e82e2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 23 Sep 2025 14:03:48 +0200 Subject: [PATCH 28/46] streamline batch size in PubchemBatched --- chebai/preprocessing/datasets/pubchem.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index ecde5fac..fb18c025 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -217,7 +217,7 @@ class PubChemBatched(PubChem): READER: Type[dr.ChemDataReader] = dr.ChemDataReader - def __init__(self, train_batch_size=10_000_000, *args, **kwargs): + def __init__(self, train_batch_size=1_000_000, *args, **kwargs): super(PubChemBatched, self).__init__(*args, **kwargs) self.curr_epoch = 0 self.train_batch_size = train_batch_size @@ -275,8 +275,8 @@ def _tokenize_batched(self, data): for i, d in enumerate(tqdm.tqdm(data, total=len(data))): if d["features"] is not None: batch.append(self.reader.to_data(d)) - if i % 1_000_000 == 0 and i > 0: - print(f"Saving batch {i // 1_000_000}") + if i % self.train_batch_size == 0 and i > 0: + print(f"Saving batch {i // self.train_batch_size}") batch = [b for b in batch if b["features"] is not None] if self.n_token_limit is not None: batch = [ From 33de8f317e543e92b144f9c894af707768d20bb3 Mon Sep 17 00:00:00 2001 From: sifluegel Date: Tue, 23 Sep 2025 14:19:29 +0200 Subject: [PATCH 29/46] update tokens (full pubchem) --- .../preprocessing/bin/smiles_token/tokens.txt | 3386 +++++++++++++++++ 1 file changed, 3386 insertions(+) diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index 9ce39f9d..79600dc5 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -984,3 +984,3389 @@ p [ClH2+] [BrH2+] [IH2+] +[RuH+2] +[RuH2+2] +[p-] +[15NH] +[Fe+4] +[11CH3] +[P@H] +[Ru+8] +[15n] +[15nH] +[Er+3] +[14CH2] +[Si+3] +[B@@-] +[76Br] +[IH+] +[128Ba] +[BiH] +[14cH] +[14c] +[13NH2] +[Nb+5] +[IH] +[14CH] +[ReH] +[18FH] +[c+] +[RuH2] +[Ru+6] +[IrH4] +[Pt+] +[Mo+2] +[20OH] +[Tc+3] +b +[Dy+3] +[195Pt] +[p+] +[si] +[18OH] +[36Ar] +[68Ga+3] +[RuH3] +[66Ga] +[Al+2] +[18C] +[Nb+3] +[siH] +[75As+3] +[Mn+] +[ClH4+3] +[68Ga] +[Ru+5] +[Mo+] +[Tc+4] +[11CH2] +[211At] +[77Br] +[99Tc+3] +[oH+] +[Nb-2] +[InH] +[P-2] +[184Hf] +[B@-] +[PoH] +[124I] +[14C@H] +[Si@@H] +[35Cl] +[W+] +[37Cl] +[Bi+2] +[13CH4] +[18F-] +[15NH3+] +[Si@H] +[Nb+2] +[98Tc+5] +[Ta+2] +[Rh-] +[151Eu+3] +[RuH] +[63Ni+2] +[NiH] +[PdH2] +[52Mn] +[16OH-] +[Fe+6] +[64Cu] +[194Os] +[Ir+] +[13C-] +[121I] +[Tm+3] +[19BH2] +[Sn+3] +[AlH2+] +[186Re] +[XeH] +[Os+6] +[15N+] +[122I] +[99Tc+6] +[GaH] +[12CH3] +[12C@H] +[AlH2-] +[16OH] +[GeH2-] +[49Ti] +[SiH-2] +[14C@@H] +[11CH4] +[197Hg+] +[Rh+] +[Th+2] +[Yb+2] +[145Eu] +[Cu-] +[RuH+3] +[20CH2] 
+[SnH2+2] +[136Ba] +[188Re] +[b-] +[se+] +[212Pb+2] +[Ga-] +[WH2] +[232Th] +[225Ac] +[89Zr] +[214Bi] +[pH+] +[TlH] +[99Tc+4] +[10CH2] +[AlH6-3] +[12CH2] +[123IH] +[14C@@] +[6Li+] +[SnH+] +[SnH+3] +[Tb+4] +[99Tc+5] +[125IH] +[144Pm] +[IrH2] +[10BH-] +[10BH2] +[60Co+3] +[14C-] +[NiH2] +[140Ce] +[125I-] +[177Lu+3] +[169Lu] +[85Sr] +[OsH6] +[7Li+] +[18o] +[InH4-] +[OsH2] +[In-] +[11CH] +[ClH] +[13CH2-] +[35P] +[15NH4+] +[RhH+2] +[86Rb+] +[166Ho+3] +[RuH+] +[75Br] +[SiH2-2] +[I@-] +[227Th] +[90Y] +[11c] +[11cH] +[PtH+] +[FeH] +[si-] +[213Bi+3] +[Os+5] +[Te@] +[64Cu+2] +[SbH+] +[14nH] +[14n] +[99Tc+7] +[12C@@H] +[192Bi] +[PtH] +[TaH2] +[32Cl] +[153Sm] +[255Fm] +[133IH] +[12C@] +[AlH-] +[61Cu] +[52Ti] +[117Sn+4] +[83Rb+] +[18O-] +[238Pu] +[165Dy] +[AlH+2] +[16N+] +[141Cs] +[67Cu+2] +[239Am] +[B@H-] +[201Hg] +[231Th] +[126Te] +[17OH] +[66Zn+2] +[Ge-2] +[98Tc+7] +[15n+] +[203Hg+] +[124I-] +[Ge@@] +[207At] +[Tc+5] +[177Lu] +[111In+3] +[CoH2] +[PdH+] +[12c] +[10CH3] +[YH] +[TaH3] +[TaH5] +[12CH] +[Tc+2] +[244Am] +[68Ge] +[35SH] +[RhH] +[MoH2] +[34SH] +[111In] +[RuH4] +[17C] +[Se@] +[65Zn+2] +[15N-] +[PtH2] +[135I] +[123Xe] +[62Zn] +[122Sb] +[si+] +[137La] +[ZrH2] +[53Mn] +[111In-] +[125Cs] +[Tc+6] +[106Pd] +[194Ir] +[159Gd] +[FeH4] +[141Sm] +[111InH3] +[Tc+] +[si+2] +[64Zn+2] +[te+] +[HgH] +[Pd-] +[Zr-2] +[10B-] +[10BH] +[8BH2] +[85Sr+2] +[IrH+2] +[PbH2+2] +[Re-2] +[12B] +[Zr+] +[10BH3] +[11BH3] +[91Y] +[218AtH] +[Ge@] +[CuH+] +[86Y] +[170Yb] +[63Cu+2] +[164Dy] +[173Ta] +[16C] +[ClH+2] +[153Gd+3] +[OsH] +[11C-] +[231Pa] +[TiH] +[229Th] +[72Zn] +[ZrH] +[67Cu] +[14O] +[156Eu] +[155Sm] +[138Ce] +[B@@H-] +[MnH2] +[16NH2] +[51Mn] +[42K] +[MoH5] +[128Sn] +[ClH2+2] +[17F] +[77BrH] +[16n+] +[ZnH+] +[153Sm+3] +[100Tc+4] +[94Ru] +[98Tc] +[IrH3] +[132La] +[242Am] +[14NH] +[162Er] +[208Bi] +[127Xe] +[11CH3-] +[Os+7] +[137Cs+] +[201Tl+] +[13CH+] +[ClH+3] +[129Cs] +[105Rh+3] +[127Sb+3] +[131Cs] +[168Yb] +[17NH] +[9C-] +[33SH2] +[13NH] +[Ge@@H] +[105Ru] +[PdH] +[82Br] +[12cH] +[41Ca] +[184Ir] +[82Rb+] +[14NH2] +[94Zr+4] +[74Se] +[80Br] +[123Te] +[70Zn] +[Tc+7] +[160Dy] +[P@@H] +[148Pm] +[64Zn] +[136Eu+3] +[SnH2+] +[232U] +[234U] +[246Cm] +[24Mg] +[Se@@] +[142Sm] +[68GaH] +[40K+] +[173Yb] +[45Ca+2] +[126IH] +[55Fe+3] +[Ta-2] +[151Nd] +[91Sr] +[Bi-2] +[130Te] +[GaH4-] +[BrH] +[SbH-] +[13CH3+] +[RhH2] +[38Cl-] +[75Ge] +[239Pu] +[ReH7] +[99Tc+2] +[RhH3] +[26Mg] +[Os+8] +[CuH2] +[122Xe] +[Pr+] +[74As] +[239Th] +[SeH2+] +[17OH2] +[136Cs+] +[13CH3-] +[IrH] +[11B-] +[Te@@] +[195Pt+2] +[134Cs+] +[TiH2] +[90Nb] +[146Eu] +[45Ca] +[15NH3] +[SnH-] +[176W] +[110Ru] +[237Pu] +[RuH6] +[217Bi] +[11C@@H] +[150Sm] +[179Lu] +[65Cu+] +[180W] +[132Te] +[90Sr+2] +[14c-] +[213BiH] +[145Pm] +[131SbH3] +[60Co] +[66Ga+3] +[225Ra] +[165Er] +[147Sm] +[129Sb] +[179Hf] +[129Cs+] +[AuH3] +[92Nb] +[GeH6-2] +[233Ra] +[FeH2] +[149Pm] +[ZnH2] +[99Ru] +[AgH] +[1HH] +[200Hg] +[16CH2] +[131I-] +[248Cf] +[CuH] +[232Pa] +[135I-] +[Ge@H] +[AuH] +[67Ga] +[193Pt+4] +[125Te+4] +[7Be] +[10c] +[WH] +[22CH3-] +[105Rh] +[OsH-] +[TaH] +[237Np] +[47V] +[191Pt+2] +[127Cs] +[13O] +[15NH+] +[135Ba] +[67GaH3] +[15OH] +[151Sm] +[18CH2] +[145Nd] +[97Zr] +[249Cf] +[100Tc+] +[I@@-] +[57Fe+2] +[102Pd] +[52Fe+3] +[181Ta+2] +[123I-] +[127I-] +[202Bi] +[106Ru] +[174Yb] +[81Rb+] +[150Pm] +[22C] +[143La] +[66Ni] +[126Sb] +[68GaH3] +[13c-] +[35S-] +[12C-] +[62Cu] +[183Hf] +[VH2] +[182Ta] +[15n-] +[230U] +[253Fm] +[90Y+3] +[237Am] +[173Lu] +[71Ge] +[204TlH] +[SbH2+] +[172Er] +[144Ce] +[107Ag] +[34s] +[CeH3] +[131I+2] +[59Fe+2] +[Mn-2] +[96Tc] +[68Cu] +[25Mg+2] +[105Ag] +[76Se] 
+[245Bk] +[111InH2] +[93Mo] +[154Gd] +[127Sn] +[Cl@-] +[76As] +[101Mo] +[152Gd] +[193Pt+2] +[12CH4] +[99Y+3] +[173Tm] +[9CH] +[113In+3] +[237U] +[88Sr+2] +[176Yb] +[75BrH] +[BiH4] +[15NH2+] +[242Cm] +[12BH2] +[59Fe] +[14NH3] +[79Kr] +[siH-] +[TcH4] +[69Zn] +[177Hf] +[89Zr+4] +[CrH2] +[125Sb] +[41Ar] +[70Ga] +[69Ga] +[78As] +[143Nd] +[51Cr+3] +[73AsH3] +[167Tm] +[13NH3] +[126SbH3] +[74AsH3] +[WH4] +[9c] +[100Mo] +[199PbH2] +[115Sb] +[176Lu] +[99Ru+2] +[100Pd] +[240Np] +[198Au] +[233Np] +[130I-] +[NbH3] +[95Y] +[16n] +[196Bi] +[181Os] +[CoH+] +[MnH+] +[10Be] +[44Ca+2] +[183Ta] +[155Gd] +[140Ba+2] +[77AsH3] +[235U] +[86Zr] +[131Te] +[17O-] +[17FH] +[250Bk] +[125Xe] +[AsH+] +[187Re] +[79BrH] +[192Ir] +[169Er+3] +[147Tb] +[AlH2-2] +[186Os] +[11CH3+] +[15nH+] +[152Sm+3] +[40PH] +[101Pd] +[47Ti] +[CoH+2] +[53Cr+6] +[227Ac] +[182Re] +[40Ar] +[191Pt+4] +[241Am] +[227Th+4] +[YH2] +[CoH3] +[149Gd] +[137Ba+2] +[39K+] +[Zr-3] +[161Er] +[Os-3] +[181Ta] +[49Ca] +[169Yb] +[45K] +[184W] +[196Au] +[179Ta] +[72Se] +[80Se] +[14CH4] +[210Tl] +[37SH2] +[FeH3] +[62Zn+2] +[15NH-] +[Re-] +[194Au] +[87Sr+2] +[131Ba] +[104Cd] +[131IH] +[124Xe] +[BiH2+2] +[88Nb] +[175Yb+3] +[240U] +[193Pt] +[62Cu+2] +[32P+] +[32PH] +[8B] +[132Cs+] +[LaH3] +[236Np] +[siH+] +[Zr-] +[18OH-] +[134Cs] +[ClH3+2] +[42K+] +[42Ca] +[94Tc+7] +[192Os] +[22Na+] +[38K] +[109Ag] +[136Eu] +[22Na] +[121Sn+2] +[173Hf] +[120I] +[149Tb] +[203Hg+2] +[139Pr] +[73Se] +[240Cm] +[162Dy] +[39Ar] +[89Nb] +[Cd-] +[115Cd] +[253Cf] +[235Pu] +[144Cs] +[18OH3+] +[186Ta] +[115Ag] +[169Yb+3] +[77Kr] +[TiH+] +[138Nd] +[18n] +[34SH2] +[39S] +[92Y] +[135Ce] +[236Pu] +[92Zr] +[50Ti] +[65Ga] +[189Os] +[184Os] +[15CH4] +[131Cs+] +[151Tb] +[38Ar] +[99Mo] +[161Gd] +[CrH+2] +[CoH] +[203PbH] +[81Rb] +[163Dy] +[166Tm] +[bH-] +[31SH] +[86Sr] +[189Ir] +[171Tm] +[194Pb] +[204Hg+] +[231U] +[ZnH] +[59Ni] +[19FH] +[13C+] +[118Sb] +[28Mg+2] +[22c] +[241Cm] +[144Ce+4] +[44Sc] +[38Cl] +[187Ir] +[148Eu] +[57Co+2] +[201TlH3] +[153Pm] +[203PbH2] +[36Cl] +[69Ga+3] +[Co-] +[81Br] +[95Tc+4] +[22CH2] +[170Tm] +[234Np] +[110Sn] +[SH2] +[36ClH] +[TiH4] +[218Pb] +[141Cs+] +[223Ac] +[104Tc] +[239Np] +[198Au+3] +[130SbH3] +[198Bi] +[134Xe] +[109Pd] +[153Gd] +[203Bi] +[253Es] +[XeH2] +[244Cm] +[79Rb+] +[141Pr+3] +[15NH2-] +[86Tc] +[103Pd+2] +[17c] +[82Br-] +[20CH] +[112Pd] +[165Tm] +[89Y+3] +[174Lu] +[23Na+] +[164Ho] +[201Au] +[115In] +[99Tc+] +[19B] +[238Am] +[127Te] +[133I-] +[130Xe] +[83Sr+2] +[184Ta] +[240Am] +[15C] +[197Hg+2] +[186Lu] +[155Eu] +[178Yb] +[35Cl-] +[166Ho] +[70AsH3] +[58Co+2] +[14CH2-] +[137Pr] +[135IH] +[99Y] +[85Rb+] +[13OH] +[90Tc] +[Sn@] +[113In] +[95Ru] +[ReH4] +[15C@@H] +[15CH2] +[109Pd+2] +[47Ca+2] +[17C-] +[17CH] +[58Co] +[38PH3] +[134Ce] +[71Zn] +[110Pd] +[148Nd] +[14N+] +[CrH3] +[58Fe+2] +[235U+2] +[167Er] +[178Ta] +[101Tc] +[130Cs] +[122I-] +[CuH2-] +[158Gd] +[238Th] +[238Np] +[160Tb] +[168Er] +[83BrH] +[246Am] +[199Pb] +[79SeH2] +[157Dy] +[9C] +[FeH6] +[76Kr] +[243Am] +[34S-] +[88Rb+] +[WH3] +[MoH] +[13CH-] +[40PH3] +[218Rn] +[59Co+3] +[172Tm] +[209Bi] +[199Tl+] +[66Ge] +[95Zr] +[71As] +[46Ti] +[232Np] +[48Sc] +[90Zr] +[123I+2] +[159Ho] +[40Ca] +[44K+] +[ZrH2+2] +[19C] +[195Tl] +[126Ba] +[159Gd+3] +[167Yb] +[12C@@] +[13OH2] +[195Ir] +[109Cd] +[109Cd+2] +[87Y] +[35s] +[148Tb] +[81BrH] +[ZrH3] +[162Tm] +[206Bi] +[72AsH3] +[146Nd] +[239U] +[246Bk] +[87Rb+] +[177W] +[176Hf] +[GaH-] +[156Ho] +[101Rh] +[212Bi] +[257Md] +[190Os] +[OsH4] +[46Ca] +[250Es] +[70As] +[57Co] +[55Fe+2] +[122SbH3] +[156Sm] +[ThH4] +[94Mo] +[181Re] +[105Pd] +[13N+] +[139Ba] +[30PH3] +[120I-] +[155Dy] 
+[84BrH] +[116In] +[PtH4] +[60Ni+2] +[186W] +[107Cd] +[46Sc] +[11C@H] +[95Tc] +[67Zn+2] +[13B] +[112Sn] +[128I] +[193Au] +[103Ru+2] +[136Ce] +[195Pb] +[89Sr+2] +[210PoH2] +[70Se] +[138Xe] +[35SH2] +[UH2] +[BH+] +[61Co] +[VH] +[178W] +[124IH] +[185Ir] +[99Rh] +[18O-2] +[209PbH2] +[120IH] +[91Zr] +[Hf+] +[15C-] +[OsH3] +[119SbH3] +[148Sm] +[149Sm] +[118Pd+2] +[BH4+] +[NiH+] +[29Al] +[58Co+3] +[142Pr] +[212PbH2] +[144Ce+3] +[47Sc] +[200Pb] +[224Rn] +[133Ba] +[53Cr] +[7Be+2] +[26AlH3] +[188Pt] +[12NH3] +[77As] +[182Hf] +[33PH] +[193Os] +[248Cm] +[113Sn] +[121SnH2] +[110Cd] +[43K+] +[NbH2] +[116Te] +[168Tm] +[165Dy+3] +[154Sm] +[162Yb] +[89Rb+] +[47Ca] +[18CH3] +[135Cs+] +[223Fr] +[61Ni] +[24Na+] +[174Hf+4] +[167Ho] +[84Rb+] +[50Cr] +[153Eu] +[38PH] +[194Bi] +[ReH3] +[60Co+2] +[110In] +[77Ge] +[177Re] +[211Bi] +[94Nb] +[222Ra] +[159Dy] +[136Cs] +[ReH6] +[170Lu] +[129I+2] +[61Cu+] +[134Te] +[HgH2] +[93Y] +[BiH2+] +[MnH] +[CeH] +[18o+] +[39ClH] +[EuH3] +[148Gd] +[133Xe] +[142Nd] +[36SH] +[Cl@@-] +[209BiH3] +[210BiH3] +[200Bi] +[SiH4-] +[11CH-] +[52V] +[58Ni] +[185W] +[249Bk] +[72BrH] +[185Ta] +[251Es] +[158Eu] +[243Pu] +[205Pb] +[84Sr] +[37Ar] +[82BrH] +[79Rb] +[208TlH] +[207Bi] +[172Lu] +[15OH2] +[157Tb] +[244Cf] +[15CH] +[95Nb] +[83Kr] +[110Ag+] +[77Br-] +[199TlH] +[17OH-] +[86Y+3] +[90Mo] +[65Cu+2] +[202Hg] +[171Lu] +[13NH2-] +[178Lu] +[212Ra] +[10CH4] +[9CH4] +[171Er] +[125Sn] +[P@@H+] +[142Ce] +[254Fm] +[67Ge] +[87Y+3] +[108Pd] +[104Rh] +[201Bi] +[18CH] +[64Ni] +[181Hf] +[156Dy] +[35S-2] +[151Pm] +[182Ir] +[71Se] +[88Kr] +[56Ni] +[60Fe] +[161Ho] +[NiH2+2] +[84Kr] +[234Pu] +[179W] +[217At] +[54Fe] +[37Cl-] +[MoH4] +[71Ga] +[238U] +[127Cs+] +[76BrH] +[157Ho] +[100Tc] +[234Pa] +[218PoH2] +[17O+] +[HgH+] +[230Th] +[77se] +[35ClH] +[18O+] +[Os-] +[34Cl-] +[228Ac] +[195Pt+4] +[132I-] +[189Re] +[142Ba+2] +[Ta+] +[45Ti] +[254Es] +[203TlH] +[122IH] +[142Pm] +[136Nd] +[80Kr] +[102Ag] +[32ClH] +[13cH-] +[124Sb] +[27Mg] +[113Ag] +[228Pa] +[144Nd] +[44Ca] +[P@H+] +[54Cr] +[246Cf] +[155Tb] +[124Sn] +[201TlH] +[155Ho] +[TiH+3] +[20Ne] +[201Pb] +[166Dy] +[138Cs] +[162Ho] +[211Rn] +[204Tl] +[186Pt] +[228Th] +[170Tm+3] +[100Rh] +[193Ir] +[213Bi] +[157Lu] +[142Ba] +[36SH2] +[15O+] +[129IH] +[230Pu] +[19OH2] +[154Eu+3] +[157Sm] +[195Hg] +[175Yb] +[121Xe] +[112Ag] +[15O-2] +[ClH3+3] +[37ClH] +[252Cf] +[158Dy] +[40K] +[78BrH] +[111Cd+2] +[103Pd] +[88Rb] +[132Xe] +[190Ir] +[22Ne] +[31P-3] +[57Co+3] +[72As] +[122Te] +[90Zr+4] +[57Mn] +[175Hf] +[198Pb] +[96Mo] +[152Dy] +[203Pb] +[34ClH] +[102Rh] +[194Hg] +[233U+4] +[187W] +[54Mn] +[117Sb] +[139Nd] +[117Cd] +[126Sb+3] +[54Fe+3] +[235Np] +[15CH3] +[16CH3] +[SeH5] +[128Te] +[194Tl] +[204Pb] +[200Tl] +[106Rh] +[87Sr] +[125I+2] +[56Co] +[172Hf] +[18C@@H] +[78AsH3] +[49V] +[112In] +[102Ru] +[178Hf] +[167Dy] +[104Pd] +[220Fr] +[14CH-] +[31PH3] +[210PbH2] +[147Eu] +[43Sc] +[31PH] +[191Ir] +[191Os] +[YbH2] +[164Er] +[9Li] +[22nH] +[68Zn] +[132Cs] +[81Se] +[69As] +[86Kr] +[245Am] +[131Sb] +[51Ti] +[58Fe+3] +[166Yb] +[208PbH2] +[InH-] +[157Gd+3] +[144Pr] +[218At] +[164Dy+3] +[117In] +[202Pb] +[94Zr] +[149Eu] +[238Cm] +[139Ce] +[AlH5-2] +[245Pu] +[75Br-] +[82Sr+2] +[94Tc] +[141Pm] +[28Mg] +[133Ba+2] +[114Sn] +[PtH2+2] +[172Yb] +[245Cm] +[103Ag] +[142La] +[169Er] +[32PH3] +[233U] +[74BrH] +[203Pb+2] +[133Te] +[52Cr] +[Zr-4] +[18C-] +[63Ni] +[135La] +[97Tc] +[208Tl] +[89Zr+3] +[16O+] +[97Ru] +[44K] +[48Cr] +[151Gd] +[130Cs+] +[141La] +[205Bi+3] +[103Ru] +[108Cd] +[131La] +[141Ce+3] +[38K+] +[94Y] +[66Cu] +[16OH2] +[14CH3-] +[204Hg] +[224Ac] +[205Bi] +[113I] +[36Cl-] +[170Hf] 
+[82Rb] +[31S] +[83Rb] +[65Ni] +[74Br-] +[139Cs] +[70Ge] +[106Cd] +[160Gd] +[75SeH] +[199Au] +[84Rb] +[107Rh] +[210Bi] +[121Te] +[188Ir] +[ThH2] +[GeH5-] +[116SbH3] +[21NH3] +[88Y] +[138Pr] +[117SnH2] +[156Gd] +[141Ce] +[19Ne] +[191Pt] +[55Fe] +[118Pd] +[14OH2] +[202PbH2] +[80Sr] +[82Se-2] +[240Pu] +[104Ag] +[114In+3] +[210At] +[196Pb] +[197Pb] +[209Pb] +[210Pb] +[211Pb] +[212Pb] +[213Pb] +[214Pb] +[147Pm] +[126I-] +[141Pr] +[203Tl+] +[SmH3] +[76AsH3] +[24Na] +[107Pd] +[121I-] +[258Md] +[103Rh] +[226Th] +[236U] +[174Ta] +[228Rn] +[138Ba] +[154Tb] +[136Pr] +[80BrH] +[146Ce] +[182W] +[188Os] +[131Xe] +[132Ba] +[252Fm] +[83Se] +[140Ba] +[51Fe] +[246Pu] +[106Ag] +[38SH2] +[48Ca] +[58Fe] +[16NH3] +[63Zn] +[111Sn] +[62Ga] +[44Ti] +[76Br-] +[181W] +[KrH] +[141Nd] +[60Cu] +[9cH] +[56Mn] +[209Tl] +[137Ba] +[248Am] +[216Bi] +[Ti-] +[128Sb] +[146Gd] +[82Kr] +[53Ni] +[108Ag] +[145Gd] +[229Rn] +[85Kr] +[211PbH2] +[180Os] +[166Er] +[81Br-] +[SeH4] +[242Pu] +[154Eu] +[ScH3] +[41Ca+2] +[129I-] +[72Br-] +[75As+5] +[43K] +[116Sb] +[120Te] +[150Nd] +[130Sb] +[195Au] +[175Tm] +[As@] +[ClH2+3] +[73Ga] +[254Cf] +[69Ge] +[247Cm] +[83Sr] +[RuH5] +[98Nb] +[147Nd] +[150Eu] +[MoH3] +[119In] +[144Pr+3] +[97Mo] +[129Te] +[188W] +[206Tl] +[149Nd] +[200Pt] +[82Se+6] +[97Nb] +[149Pr] +[198Hg] +[49Cr] +[135Xe] +[52Fe] +[177Yb] +[48V] +[62Ni] +[21Ne] +[185Os] +[178Re] +[62Co] +[120Sb] +[EuH2] +[182Os] +[127Sb] +[221Fr] +[244Pu] +[68Ge+4] +[197Tl] +[172Ta] +[80Br-] +[BiH+] +[170Er] +[123Sn] +[161Dy] +[202Tl] +[89Sr] +[147Gd] +[150Tb] +[43Ca+2] +[BiH3+2] +[96Zr] +[98Tc+4] +[110Te] +[89Kr] +[145Pr] +[49Sc] +[17NH4+] +[180Hf] +[44Sc+3] +[73As] +[140La] +[137Ce] +[119Sb] +[247Bk] +[76Ge] +[121Sn] +[220Ra] +[156Tb] +[208Tl+] +[153Tb] +[16O-] +[130IH] +[20CH3] +[187Os] +[14NH4+] +[50Cr+3] +[81Sr] +[222Fr] +[55Co] +[41K] +[72Ga] +[78Se] +[137Xe] +[103Cd] +[93Zr] +[126Xe] +[80Rb] +[176Ta] +[199Pt] +[205PbH2] +[197Pt] +[200Au] +[120Xe] +[136Xe] +[20C] +[100Tc+5] +[157Gd] +[17B] +[198Tl] +[SnH2-] +[127IH] +[65Cu] +[186Ir] +[193Hg] +[132IH] +[147Pr] +[145Sm] +[122Sn] +[161Tb] +[110Ag] +[250Cf] +[33PH3] +[241Pu] +[32SH2] +[185Re] +[78Ge] +[106Ru+3] +[146Sm] +[109In] +[17NH3] +[233Pa] +[134IH] +[92Sr+2] +[BH3+] +[64Ga] +[92Sr] +[82Se+4] +[62Cu+] +[226Ac] +[171Yb] +[34S-2] +[249Cm] +[56Fe] +[227Ra] +[143Ce] +[226Rn] +[64Cu+] +[152Tb] +[34S+] +[207Tl] +[111Ag] +[227Pa] +[157Eu] +[184Re] +[72Ge] +[SnH+2] +[117Sn+2] +[230Pa] +[78Kr] +[134Ba] +[199Hg] +[13CH2+] +[250Cm] +[183Re] +[121IH] +[251Cf] +[81Kr] +[125Cs+] +[208Pb] +[143Pm] +[114In] +[113Sn+4] +[82Sr] +[74Ge] +[UH3] +[52Mn+2] +[114Cd] +[33ClH] +[79Br-] +[22CH4] +[70Zn+2] +[144Sm] +[124Te] +[seH+] +[51Cr+6] +[152Sm] +[130Ba] +[Po@] +[174Hf] +[141Ba] +[128IH] +[27Al+3] +[234Th] +[88Zr] +[111IH] +[177Ta] +[191Os+4] +[152Eu] +[48Ti] +[87Kr] +[91Y+3] +[180Ta] +[128Xe] +[143Cs] +[86Rb] +[45K+] +[180Re] +[126Sn] +[146Pm] +[143Pr] +[116Cd] +[89Rb] +[230Ra] +[WH6] +[167Tm+3] +[96Nb] +[92Mo] +[57Ni] +[189Pt] +[134La] +[79Se] +[38ClH] +[125Sn+4] +[243Cm] +[257Fm] +[85Br] +[206Pb] +[138Cs+] +[175Ta] +[16nH] +[138La] +[112Cd] +[93Tc] +[28SiH3] +[166Tb] +[161Tb+3] +[158Tb] +[90Sr] +[32PH2] +[RuH+2] +[RuH2+2] +[p-] +[15NH] +[Fe+4] +[11CH3] +[P@H] +[Ru+8] +[15n] +[15nH] +[Er+3] +[14CH2] +[Si+3] +[B@@-] +[76Br] +[IH+] +[128Ba] +[BiH] +[14cH] +[14c] +[13NH2] +[Nb+5] +[IH] +[14CH] +[ReH] +[18FH] +[c+] +[RuH2] +[Ru+6] +[IrH4] +[Pt+] +[Mo+2] +[20OH] +[Tc+3] +b +[Dy+3] +[195Pt] +[p+] +[si] +[18OH] +[36Ar] +[68Ga+3] +[RuH3] +[66Ga] +[Al+2] +[18C] +[Nb+3] +[siH] +[75As+3] +[Mn+] +[ClH4+3] +[68Ga] +[Ru+5] 
+[Mo+] +[Tc+4] +[11CH2] +[211At] +[77Br] +[99Tc+3] +[oH+] +[Nb-2] +[InH] +[P-2] +[184Hf] +[B@-] +[PoH] +[124I] +[14C@H] +[Si@@H] +[35Cl] +[W+] +[37Cl] +[Bi+2] +[13CH4] +[18F-] +[15NH3+] +[Si@H] +[Nb+2] +[98Tc+5] +[Ta+2] +[Rh-] +[151Eu+3] +[RuH] +[63Ni+2] +[NiH] +[PdH2] +[52Mn] +[16OH-] +[Fe+6] +[64Cu] +[194Os] +[Ir+] +[13C-] +[121I] +[Tm+3] +[19BH2] +[Sn+3] +[AlH2+] +[186Re] +[XeH] +[Os+6] +[15N+] +[122I] +[99Tc+6] +[GaH] +[12CH3] +[12C@H] +[AlH2-] +[16OH] +[GeH2-] +[49Ti] +[SiH-2] +[14C@@H] +[11CH4] +[197Hg+] +[Rh+] +[Th+2] +[Yb+2] +[145Eu] +[Cu-] +[RuH+3] +[20CH2] +[SnH2+2] +[136Ba] +[188Re] +[b-] +[se+] +[212Pb+2] +[Ga-] +[WH2] +[232Th] +[225Ac] +[89Zr] +[214Bi] +[pH+] +[TlH] +[99Tc+4] +[10CH2] +[AlH6-3] +[12CH2] +[123IH] +[14C@@] +[6Li+] +[SnH+] +[SnH+3] +[Tb+4] +[99Tc+5] +[125IH] +[144Pm] +[IrH2] +[10BH-] +[10BH2] +[60Co+3] +[14C-] +[NiH2] +[140Ce] +[125I-] +[177Lu+3] +[169Lu] +[85Sr] +[OsH6] +[7Li+] +[18o] +[InH4-] +[OsH2] +[In-] +[11CH] +[ClH] +[13CH2-] +[35P] +[15NH4+] +[RhH+2] +[86Rb+] +[166Ho+3] +[RuH+] +[75Br] +[SiH2-2] +[I@-] +[227Th] +[90Y] +[11c] +[11cH] +[PtH+] +[FeH] +[si-] +[213Bi+3] +[Os+5] +[Te@] +[64Cu+2] +[SbH+] +[14nH] +[14n] +[99Tc+7] +[12C@@H] +[192Bi] +[PtH] +[TaH2] +[32Cl] +[153Sm] +[255Fm] +[133IH] +[12C@] +[AlH-] +[61Cu] +[52Ti] +[117Sn+4] +[83Rb+] +[18O-] +[238Pu] +[165Dy] +[AlH+2] +[16N+] +[141Cs] +[67Cu+2] +[239Am] +[B@H-] +[201Hg] +[231Th] +[126Te] +[17OH] +[66Zn+2] +[Ge-2] +[98Tc+7] +[15n+] +[203Hg+] +[124I-] +[Ge@@] +[207At] +[Tc+5] +[177Lu] +[111In+3] +[CoH2] +[PdH+] +[12c] +[10CH3] +[YH] +[TaH3] +[TaH5] +[12CH] +[Tc+2] +[244Am] +[68Ge] +[35SH] +[RhH] +[MoH2] +[34SH] +[111In] +[RuH4] +[17C] +[Se@] +[65Zn+2] +[15N-] +[PtH2] +[135I] +[123Xe] +[62Zn] +[122Sb] +[si+] +[137La] +[ZrH2] +[53Mn] +[111In-] +[125Cs] +[Tc+6] +[106Pd] +[194Ir] +[159Gd] +[FeH4] +[141Sm] +[111InH3] +[Tc+] +[si+2] +[64Zn+2] +[te+] +[HgH] +[Pd-] +[Zr-2] +[10B-] +[10BH] +[8BH2] +[85Sr+2] +[IrH+2] +[PbH2+2] +[Re-2] +[12B] +[Zr+] +[10BH3] +[11BH3] +[91Y] +[218AtH] +[Ge@] +[CuH+] +[86Y] +[170Yb] +[63Cu+2] +[164Dy] +[173Ta] +[16C] +[ClH+2] +[153Gd+3] +[OsH] +[11C-] +[231Pa] +[TiH] +[229Th] +[72Zn] +[ZrH] +[67Cu] +[14O] +[156Eu] +[155Sm] +[138Ce] +[B@@H-] +[MnH2] +[16NH2] +[51Mn] +[42K] +[MoH5] +[128Sn] +[ClH2+2] +[17F] +[77BrH] +[16n+] +[ZnH+] +[153Sm+3] +[100Tc+4] +[94Ru] +[98Tc] +[IrH3] +[132La] +[242Am] +[14NH] +[162Er] +[208Bi] +[127Xe] +[11CH3-] +[Os+7] +[137Cs+] +[201Tl+] +[13CH+] +[ClH+3] +[129Cs] +[105Rh+3] +[127Sb+3] +[131Cs] +[168Yb] +[17NH] +[9C-] +[33SH2] +[13NH] +[Ge@@H] +[105Ru] +[PdH] +[82Br] +[12cH] +[41Ca] +[184Ir] +[82Rb+] +[14NH2] +[94Zr+4] +[74Se] +[80Br] +[123Te] +[70Zn] +[Tc+7] +[160Dy] +[P@@H] +[148Pm] +[64Zn] +[136Eu+3] +[SnH2+] +[232U] +[234U] +[246Cm] +[24Mg] +[Se@@] +[142Sm] +[68GaH] +[40K+] +[173Yb] +[45Ca+2] +[126IH] +[55Fe+3] +[Ta-2] +[151Nd] +[91Sr] +[Bi-2] +[130Te] +[GaH4-] +[BrH] +[SbH-] +[13CH3+] +[RhH2] +[38Cl-] +[75Ge] +[239Pu] +[ReH7] +[99Tc+2] +[RhH3] +[26Mg] +[Os+8] +[CuH2] +[122Xe] +[Pr+] +[74As] +[239Th] +[SeH2+] +[17OH2] +[136Cs+] +[13CH3-] +[IrH] +[11B-] +[Te@@] +[195Pt+2] +[134Cs+] +[TiH2] +[90Nb] +[146Eu] +[45Ca] +[15NH3] +[SnH-] +[176W] +[110Ru] +[237Pu] +[RuH6] +[217Bi] +[11C@@H] +[150Sm] +[179Lu] +[65Cu+] +[180W] +[132Te] +[90Sr+2] +[14c-] +[213BiH] +[145Pm] +[131SbH3] +[60Co] +[66Ga+3] +[225Ra] +[165Er] +[147Sm] +[129Sb] +[179Hf] +[129Cs+] +[AuH3] +[92Nb] +[GeH6-2] +[233Ra] +[FeH2] +[149Pm] +[ZnH2] +[99Ru] +[AgH] +[1HH] +[200Hg] +[16CH2] +[131I-] +[248Cf] +[CuH] +[232Pa] +[135I-] +[Ge@H] +[AuH] +[67Ga] +[193Pt+4] +[125Te+4] +[7Be] +[10c] +[WH] 
+[22CH3-] +[105Rh] +[OsH-] +[TaH] +[237Np] +[47V] +[191Pt+2] +[127Cs] +[13O] +[15NH+] +[135Ba] +[67GaH3] +[15OH] +[151Sm] +[18CH2] +[145Nd] +[97Zr] +[249Cf] +[100Tc+] +[I@@-] +[57Fe+2] +[102Pd] +[52Fe+3] +[181Ta+2] +[123I-] +[127I-] +[202Bi] +[106Ru] +[174Yb] +[81Rb+] +[150Pm] +[22C] +[143La] +[66Ni] +[126Sb] +[68GaH3] +[13c-] +[35S-] +[12C-] +[62Cu] +[183Hf] +[VH2] +[182Ta] +[15n-] +[230U] +[253Fm] +[90Y+3] +[237Am] +[173Lu] +[71Ge] +[204TlH] +[SbH2+] +[172Er] +[144Ce] +[107Ag] +[34s] +[CeH3] +[131I+2] +[59Fe+2] +[Mn-2] +[96Tc] +[68Cu] +[25Mg+2] +[105Ag] +[76Se] +[245Bk] +[111InH2] +[93Mo] +[154Gd] +[127Sn] +[Cl@-] +[76As] +[101Mo] +[152Gd] +[193Pt+2] +[12CH4] +[99Y+3] +[173Tm] +[9CH] +[113In+3] +[237U] +[88Sr+2] +[176Yb] +[75BrH] +[BiH4] +[15NH2+] +[242Cm] +[12BH2] +[59Fe] +[14NH3] +[79Kr] +[siH-] +[TcH4] +[69Zn] +[177Hf] +[89Zr+4] +[CrH2] +[125Sb] +[41Ar] +[70Ga] +[69Ga] +[78As] +[143Nd] +[51Cr+3] +[73AsH3] +[167Tm] +[13NH3] +[126SbH3] +[74AsH3] +[WH4] +[9c] +[100Mo] +[199PbH2] +[115Sb] +[176Lu] +[99Ru+2] +[100Pd] +[240Np] +[198Au] +[233Np] +[130I-] +[NbH3] +[95Y] +[16n] +[196Bi] +[181Os] +[CoH+] +[MnH+] +[10Be] +[44Ca+2] +[183Ta] +[155Gd] +[140Ba+2] +[77AsH3] +[235U] +[86Zr] +[131Te] +[17O-] +[17FH] +[250Bk] +[125Xe] +[AsH+] +[187Re] +[79BrH] +[192Ir] +[169Er+3] +[147Tb] +[AlH2-2] +[186Os] +[11CH3+] +[15nH+] +[152Sm+3] +[40PH] +[101Pd] +[47Ti] +[CoH+2] +[53Cr+6] +[227Ac] +[182Re] +[40Ar] +[191Pt+4] +[241Am] +[227Th+4] +[YH2] +[CoH3] +[149Gd] +[137Ba+2] +[39K+] +[Zr-3] +[161Er] +[Os-3] +[181Ta] +[49Ca] +[169Yb] +[45K] +[184W] +[196Au] +[179Ta] +[72Se] +[80Se] +[14CH4] +[210Tl] +[37SH2] +[FeH3] +[62Zn+2] +[15NH-] +[Re-] +[194Au] +[87Sr+2] +[131Ba] +[104Cd] +[131IH] +[124Xe] +[BiH2+2] +[88Nb] +[175Yb+3] +[240U] +[193Pt] +[62Cu+2] +[32P+] +[32PH] +[8B] +[132Cs+] +[LaH3] +[236Np] +[siH+] +[Zr-] +[18OH-] +[134Cs] +[ClH3+2] +[42K+] +[42Ca] +[94Tc+7] +[192Os] +[22Na+] +[38K] +[109Ag] +[136Eu] +[22Na] +[121Sn+2] +[173Hf] +[120I] +[149Tb] +[203Hg+2] +[139Pr] +[73Se] +[240Cm] +[162Dy] +[39Ar] +[89Nb] +[Cd-] +[115Cd] +[253Cf] +[235Pu] +[144Cs] +[18OH3+] +[186Ta] +[115Ag] +[169Yb+3] +[77Kr] +[TiH+] +[138Nd] +[18n] +[34SH2] +[39S] +[92Y] +[135Ce] +[236Pu] +[92Zr] +[50Ti] +[65Ga] +[189Os] +[184Os] +[15CH4] +[131Cs+] +[151Tb] +[38Ar] +[99Mo] +[161Gd] +[CrH+2] +[CoH] +[203PbH] +[81Rb] +[163Dy] +[166Tm] +[bH-] +[31SH] +[86Sr] +[189Ir] +[171Tm] +[194Pb] +[204Hg+] +[231U] +[ZnH] +[59Ni] +[19FH] +[13C+] +[118Sb] +[28Mg+2] +[22c] +[241Cm] +[144Ce+4] +[44Sc] +[38Cl] +[187Ir] +[148Eu] +[57Co+2] +[201TlH3] +[153Pm] +[203PbH2] +[36Cl] +[69Ga+3] +[Co-] +[81Br] +[95Tc+4] +[22CH2] +[170Tm] +[234Np] +[110Sn] +[SH2] +[36ClH] +[TiH4] +[218Pb] +[141Cs+] +[223Ac] +[104Tc] +[239Np] +[198Au+3] +[130SbH3] +[198Bi] +[134Xe] +[109Pd] +[153Gd] +[203Bi] +[253Es] +[XeH2] +[244Cm] +[79Rb+] +[141Pr+3] +[15NH2-] +[86Tc] +[103Pd+2] +[17c] +[82Br-] +[20CH] +[112Pd] +[165Tm] +[89Y+3] +[174Lu] +[23Na+] +[164Ho] +[201Au] +[115In] +[99Tc+] +[19B] +[238Am] +[127Te] +[133I-] +[130Xe] +[83Sr+2] +[184Ta] +[240Am] +[15C] +[197Hg+2] +[186Lu] +[155Eu] +[178Yb] +[35Cl-] +[166Ho] +[70AsH3] +[58Co+2] +[14CH2-] +[137Pr] +[135IH] +[99Y] +[85Rb+] +[13OH] +[90Tc] +[Sn@] +[113In] +[95Ru] +[ReH4] +[15C@@H] +[15CH2] +[109Pd+2] +[47Ca+2] +[17C-] +[17CH] +[58Co] +[38PH3] +[134Ce] +[71Zn] +[110Pd] +[148Nd] +[14N+] +[CrH3] +[58Fe+2] +[235U+2] +[167Er] +[178Ta] +[101Tc] +[130Cs] +[122I-] +[CuH2-] +[158Gd] +[238Th] +[238Np] +[160Tb] +[168Er] +[83BrH] +[246Am] +[199Pb] +[79SeH2] +[157Dy] +[9C] +[FeH6] +[76Kr] +[243Am] +[34S-] +[88Rb+] +[WH3] +[MoH] +[13CH-] 
+[40PH3] +[218Rn] +[59Co+3] +[172Tm] +[209Bi] +[199Tl+] +[66Ge] +[95Zr] +[71As] +[46Ti] +[232Np] +[48Sc] +[90Zr] +[123I+2] +[159Ho] +[40Ca] +[44K+] +[ZrH2+2] +[19C] +[195Tl] +[126Ba] +[159Gd+3] +[167Yb] +[12C@@] +[13OH2] +[195Ir] +[109Cd] +[109Cd+2] +[87Y] +[35s] +[148Tb] +[81BrH] +[ZrH3] +[162Tm] +[206Bi] +[72AsH3] +[146Nd] +[239U] +[246Bk] +[87Rb+] +[177W] +[176Hf] +[GaH-] +[156Ho] +[101Rh] +[212Bi] +[257Md] +[190Os] +[OsH4] +[46Ca] +[250Es] +[70As] +[57Co] +[55Fe+2] +[122SbH3] +[156Sm] +[ThH4] +[94Mo] +[181Re] +[105Pd] +[13N+] +[139Ba] +[30PH3] +[120I-] +[155Dy] +[84BrH] +[116In] +[PtH4] +[60Ni+2] +[186W] +[107Cd] +[46Sc] +[11C@H] +[95Tc] +[67Zn+2] +[13B] +[112Sn] +[128I] +[193Au] +[103Ru+2] +[136Ce] +[195Pb] +[89Sr+2] +[210PoH2] +[70Se] +[138Xe] +[35SH2] +[UH2] +[BH+] +[61Co] +[VH] +[178W] +[124IH] +[185Ir] +[99Rh] +[18O-2] +[209PbH2] +[120IH] +[91Zr] +[Hf+] +[15C-] +[OsH3] +[119SbH3] +[148Sm] +[149Sm] +[118Pd+2] +[BH4+] +[NiH+] +[29Al] +[58Co+3] +[142Pr] +[212PbH2] +[144Ce+3] +[47Sc] +[200Pb] +[224Rn] +[133Ba] +[53Cr] +[7Be+2] +[26AlH3] +[188Pt] +[12NH3] +[77As] +[182Hf] +[33PH] +[193Os] +[248Cm] +[113Sn] +[121SnH2] +[110Cd] +[43K+] +[NbH2] +[116Te] +[168Tm] +[165Dy+3] +[154Sm] +[162Yb] +[89Rb+] +[47Ca] +[18CH3] +[135Cs+] +[223Fr] +[61Ni] +[24Na+] +[174Hf+4] +[167Ho] +[84Rb+] +[50Cr] +[153Eu] +[38PH] +[194Bi] +[ReH3] +[60Co+2] +[110In] +[77Ge] +[177Re] +[211Bi] +[94Nb] +[222Ra] +[159Dy] +[136Cs] +[ReH6] +[170Lu] +[129I+2] +[61Cu+] +[134Te] +[HgH2] +[93Y] +[BiH2+] +[MnH] +[CeH] +[18o+] +[39ClH] +[EuH3] +[148Gd] +[133Xe] +[142Nd] +[36SH] +[Cl@@-] +[209BiH3] +[210BiH3] +[200Bi] +[SiH4-] +[11CH-] +[52V] +[58Ni] +[185W] +[249Bk] +[72BrH] +[185Ta] +[251Es] +[158Eu] +[243Pu] +[205Pb] +[84Sr] +[37Ar] +[82BrH] +[79Rb] +[208TlH] +[207Bi] +[172Lu] +[15OH2] +[157Tb] +[244Cf] +[15CH] +[95Nb] +[83Kr] +[110Ag+] +[77Br-] +[199TlH] +[17OH-] +[86Y+3] +[90Mo] +[65Cu+2] +[202Hg] +[171Lu] +[13NH2-] +[178Lu] +[212Ra] +[10CH4] +[9CH4] +[171Er] +[125Sn] +[P@@H+] +[142Ce] +[254Fm] +[67Ge] +[87Y+3] +[108Pd] +[104Rh] +[201Bi] +[18CH] +[64Ni] +[181Hf] +[156Dy] +[35S-2] +[151Pm] +[182Ir] +[71Se] +[88Kr] +[56Ni] +[60Fe] +[161Ho] +[NiH2+2] +[84Kr] +[234Pu] +[179W] +[217At] +[54Fe] +[37Cl-] +[MoH4] +[71Ga] +[238U] +[127Cs+] +[76BrH] +[157Ho] +[100Tc] +[234Pa] +[218PoH2] +[17O+] +[HgH+] +[230Th] +[77se] +[35ClH] +[18O+] +[Os-] +[34Cl-] +[228Ac] +[195Pt+4] +[132I-] +[189Re] +[142Ba+2] +[Ta+] +[45Ti] +[254Es] +[203TlH] +[122IH] +[142Pm] +[136Nd] +[80Kr] +[102Ag] +[32ClH] +[13cH-] +[124Sb] +[27Mg] +[113Ag] +[228Pa] +[144Nd] +[44Ca] +[P@H+] +[54Cr] +[246Cf] +[155Tb] +[124Sn] +[201TlH] +[155Ho] +[TiH+3] +[20Ne] +[201Pb] +[166Dy] +[138Cs] +[162Ho] +[211Rn] +[204Tl] +[186Pt] +[228Th] +[170Tm+3] +[100Rh] +[193Ir] +[213Bi] +[157Lu] +[142Ba] +[36SH2] +[15O+] +[129IH] +[230Pu] +[19OH2] +[154Eu+3] +[157Sm] +[195Hg] +[175Yb] +[121Xe] +[112Ag] +[15O-2] +[ClH3+3] +[37ClH] +[252Cf] +[158Dy] +[40K] +[78BrH] +[111Cd+2] +[103Pd] +[88Rb] +[132Xe] +[190Ir] +[22Ne] +[31P-3] +[57Co+3] +[72As] +[122Te] +[90Zr+4] +[57Mn] +[175Hf] +[198Pb] +[96Mo] +[152Dy] +[203Pb] +[34ClH] +[102Rh] +[194Hg] +[233U+4] +[187W] +[54Mn] +[117Sb] +[139Nd] +[117Cd] +[126Sb+3] +[54Fe+3] +[235Np] +[15CH3] +[16CH3] +[SeH5] +[128Te] +[194Tl] +[204Pb] +[200Tl] +[106Rh] +[87Sr] +[125I+2] +[56Co] +[172Hf] +[18C@@H] +[78AsH3] +[49V] +[112In] +[102Ru] +[178Hf] +[167Dy] +[104Pd] +[220Fr] +[14CH-] +[31PH3] +[210PbH2] +[147Eu] +[43Sc] +[31PH] +[191Ir] +[191Os] +[YbH2] +[164Er] +[9Li] +[22nH] +[68Zn] +[132Cs] +[81Se] +[69As] +[86Kr] +[245Am] +[131Sb] +[51Ti] +[58Fe+3] +[166Yb] 
+[208PbH2] +[InH-] +[157Gd+3] +[144Pr] +[218At] +[164Dy+3] +[117In] +[202Pb] +[94Zr] +[149Eu] +[238Cm] +[139Ce] +[AlH5-2] +[245Pu] +[75Br-] +[82Sr+2] +[94Tc] +[141Pm] +[28Mg] +[133Ba+2] +[114Sn] +[PtH2+2] +[172Yb] +[245Cm] +[103Ag] +[142La] +[169Er] +[32PH3] +[233U] +[74BrH] +[203Pb+2] +[133Te] +[52Cr] +[Zr-4] +[18C-] +[63Ni] +[135La] +[97Tc] +[208Tl] +[89Zr+3] +[16O+] +[97Ru] +[44K] +[48Cr] +[151Gd] +[130Cs+] +[141La] +[205Bi+3] +[103Ru] +[108Cd] +[131La] +[141Ce+3] +[38K+] +[94Y] +[66Cu] +[16OH2] +[14CH3-] +[204Hg] +[224Ac] +[205Bi] +[113I] +[36Cl-] +[170Hf] +[82Rb] +[31S] +[83Rb] +[65Ni] +[74Br-] +[139Cs] +[70Ge] +[106Cd] +[160Gd] +[75SeH] +[199Au] +[84Rb] +[107Rh] +[210Bi] +[121Te] +[188Ir] +[ThH2] +[GeH5-] +[116SbH3] +[21NH3] +[88Y] +[138Pr] +[117SnH2] +[156Gd] +[141Ce] +[19Ne] +[191Pt] +[55Fe] +[118Pd] +[14OH2] +[202PbH2] +[80Sr] +[82Se-2] +[240Pu] +[104Ag] +[114In+3] +[210At] +[196Pb] +[197Pb] +[209Pb] +[210Pb] +[211Pb] +[212Pb] +[213Pb] +[214Pb] +[147Pm] +[126I-] +[141Pr] +[203Tl+] +[SmH3] +[76AsH3] +[24Na] +[107Pd] +[121I-] +[258Md] +[103Rh] +[226Th] +[236U] +[174Ta] +[228Rn] +[138Ba] +[154Tb] +[136Pr] +[80BrH] +[146Ce] +[182W] +[188Os] +[131Xe] +[132Ba] +[252Fm] +[83Se] +[140Ba] +[51Fe] +[246Pu] +[106Ag] +[38SH2] +[48Ca] +[58Fe] +[16NH3] +[63Zn] +[111Sn] +[62Ga] +[44Ti] +[76Br-] +[181W] +[KrH] +[141Nd] +[60Cu] +[9cH] +[56Mn] +[209Tl] +[137Ba] +[248Am] +[216Bi] +[Ti-] +[128Sb] +[146Gd] +[82Kr] +[53Ni] +[108Ag] +[145Gd] +[229Rn] +[85Kr] +[211PbH2] +[180Os] +[166Er] +[81Br-] +[SeH4] +[242Pu] +[154Eu] +[ScH3] +[41Ca+2] +[129I-] +[72Br-] +[75As+5] +[43K] +[116Sb] +[120Te] +[150Nd] +[130Sb] +[195Au] +[175Tm] +[As@] +[ClH2+3] +[73Ga] +[254Cf] +[69Ge] +[247Cm] +[83Sr] +[RuH5] +[98Nb] +[147Nd] +[150Eu] +[MoH3] +[119In] +[144Pr+3] +[97Mo] +[129Te] +[188W] +[206Tl] +[149Nd] +[200Pt] +[82Se+6] +[97Nb] +[149Pr] +[198Hg] +[49Cr] +[135Xe] +[52Fe] +[177Yb] +[48V] +[62Ni] +[21Ne] +[185Os] +[178Re] +[62Co] +[120Sb] +[EuH2] +[182Os] +[127Sb] +[221Fr] +[244Pu] +[68Ge+4] +[197Tl] +[172Ta] +[80Br-] +[BiH+] +[170Er] +[123Sn] +[161Dy] +[202Tl] +[89Sr] +[147Gd] +[150Tb] +[43Ca+2] +[BiH3+2] +[96Zr] +[98Tc+4] +[110Te] +[89Kr] +[145Pr] +[49Sc] +[17NH4+] +[180Hf] +[44Sc+3] +[73As] +[140La] +[137Ce] +[119Sb] +[247Bk] +[76Ge] +[121Sn] +[220Ra] +[156Tb] +[208Tl+] +[153Tb] +[16O-] +[130IH] +[20CH3] +[187Os] +[14NH4+] +[50Cr+3] +[81Sr] +[222Fr] +[55Co] +[41K] +[72Ga] +[78Se] +[137Xe] +[103Cd] +[93Zr] +[126Xe] +[80Rb] +[176Ta] +[199Pt] +[205PbH2] +[197Pt] +[200Au] +[120Xe] +[136Xe] +[20C] +[100Tc+5] +[157Gd] +[17B] +[198Tl] +[SnH2-] +[127IH] +[65Cu] +[186Ir] +[193Hg] +[132IH] +[147Pr] +[145Sm] +[122Sn] +[161Tb] +[110Ag] +[250Cf] +[33PH3] +[241Pu] +[32SH2] +[185Re] +[78Ge] +[106Ru+3] +[146Sm] +[109In] +[17NH3] +[233Pa] +[134IH] +[92Sr+2] +[BH3+] +[64Ga] +[92Sr] +[82Se+4] +[62Cu+] +[226Ac] +[171Yb] +[34S-2] +[249Cm] +[56Fe] +[227Ra] +[143Ce] +[226Rn] +[64Cu+] +[152Tb] +[34S+] +[207Tl] +[111Ag] +[227Pa] +[157Eu] +[184Re] +[72Ge] +[SnH+2] +[117Sn+2] +[230Pa] +[78Kr] +[134Ba] +[199Hg] +[13CH2+] +[250Cm] +[183Re] +[121IH] +[251Cf] +[81Kr] +[125Cs+] +[208Pb] +[143Pm] +[114In] +[113Sn+4] +[82Sr] +[74Ge] +[UH3] +[52Mn+2] +[114Cd] +[33ClH] +[79Br-] +[22CH4] +[70Zn+2] +[144Sm] +[124Te] +[seH+] +[51Cr+6] +[152Sm] +[130Ba] +[Po@] +[174Hf] +[141Ba] +[128IH] +[27Al+3] +[234Th] +[88Zr] +[111IH] +[177Ta] +[191Os+4] +[152Eu] +[48Ti] +[87Kr] +[91Y+3] +[180Ta] +[128Xe] +[143Cs] +[86Rb] +[45K+] +[180Re] +[126Sn] +[146Pm] +[143Pr] +[116Cd] +[89Rb] +[230Ra] +[WH6] +[167Tm+3] +[96Nb] +[92Mo] +[57Ni] +[189Pt] +[134La] +[79Se] +[38ClH] 
+[125Sn+4]
+[243Cm]
+[257Fm]
+[85Br]
+[206Pb]
+[138Cs+]
+[175Ta]
+[16nH]
+[138La]
+[112Cd]
+[93Tc]
+[28SiH3]
+[166Tb]
+[161Tb+3]
+[158Tb]
+[90Sr]
+[32PH2]
From c73b0fb9ccb949ebc976a3d31adb3dbe7f2c349c Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 24 Sep 2025 10:06:10 +0200
Subject: [PATCH 30/46] fix number of expected pubchem batches

---
 chebai/preprocessing/datasets/pubchem.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py
index fb18c025..0625040e 100644
--- a/chebai/preprocessing/datasets/pubchem.py
+++ b/chebai/preprocessing/datasets/pubchem.py
@@ -251,7 +251,7 @@ def processed_file_names_dict(self) -> List[str]:
             if train_samples <= self.train_batch_size
             else {
                 f"train_{i}": f"train_{i}.pt"
-                for i in range((train_samples // self.train_batch_size) + 1)
+                for i in range(train_samples // self.train_batch_size)
             }
         )
         train_batches["test"] = "test.pt"
@@ -276,7 +276,7 @@ def _tokenize_batched(self, data):
             if d["features"] is not None:
                 batch.append(self.reader.to_data(d))
             if i % self.train_batch_size == 0 and i > 0:
-                print(f"Saving batch {i // self.train_batch_size}")
+                print(f"Generating batch {i // self.train_batch_size - 1}")
                 batch = [b for b in batch if b["features"] is not None]
                 if self.n_token_limit is not None:
                     batch = [
                         b for b in batch if len(b["features"]) <= self.n_token_limit
                     ]
                 yield batch
                 batch = []
-        print("Saving final batch")
+        print("Generating final batch")
         batch = [b for b in batch if b["features"] is not None]
         if self.n_token_limit is not None:
             batch = [b for b in batch if len(b["features"]) <= self.n_token_limit]
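
Note (not part of the patch series): the change above only affects how many train_{i}.pt files are expected. A minimal sketch of that arithmetic, reusing the 120 million sample estimate and the 10 million / 100 thousand batch sizes from the surrounding code; the variable names below are illustrative.

# Illustrative only: mirrors the file-name arithmetic of the batched PubChem dataset.
train_batch_size = 10_000_000
val_batch_size = test_batch_size = 100_000

total_samples = 120_000_000  # size estimate used for the "FULL" dataset
train_samples = total_samples - val_batch_size - test_batch_size

# After the fix, only full batches are listed (no trailing "+ 1").
expected_train_files = [
    f"train_{i}.pt" for i in range(train_samples // train_batch_size)
]
print(len(expected_train_files), expected_train_files[0], expected_train_files[-1])
# 11 train_0.pt train_10.pt
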
""" - def __init__(self, out_dim: int, input_dim: int, **kwargs): + def __init__(self, out_dim: int, input_dim: int, only_predict_classes: Optional[List] = None, n_classes=1528, **kwargs): super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs) self.models = [ - SklearnLogisticRegression(solver="liblinear") for _ in range(300) + SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes) ] + # indices of classes (in the dataset used for training) where a model should be trained + self.only_predict_classes = only_predict_classes def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: print( @@ -36,13 +40,13 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: except NotFittedError: preds.append( torch.zeros( - (x["features"].shape[0], 1), device=(x["features"].device) + (x["features"].shape[0]), device=(x["features"].device) ) ) except AttributeError: preds.append( torch.zeros( - (x["features"].shape[0], 1), device=(x["features"].device) + (x["features"].shape[0]), device=(x["features"].device) ) ) preds = torch.stack(preds, dim=1) @@ -56,16 +60,18 @@ def fit_sklearn(self, X, y): for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"): import os - if os.path.exists(f"LR_CHEBI100_model_{i}.pkl"): + if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")): print(f"Loading model {i} from file") - self.models[i] = pkl.load(open(f"LR_CHEBI100_model_{i}.pkl", "rb")) + self.models[i] = pkl.load(open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb")) else: + if self.only_predict_classes and i not in self.only_predict_classes: # only try these classes + continue try: model.fit(X, y[:, i]) except ValueError: self.models[i] = PlaceholderModel() # dump - pkl.dump(model, open(f"LR_CHEBI100_model_{i}.pkl", "wb")) + pkl.dump(model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb")) def configure_optimizers(self, **kwargs): pass From c1da0923641c613fae7a53f564255e1fb0828a0f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 24 Sep 2025 16:25:18 +0200 Subject: [PATCH 32/46] reformat --- chebai/models/classic_ml.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/chebai/models/classic_ml.py b/chebai/models/classic_ml.py index 83479f97..c63d94a6 100644 --- a/chebai/models/classic_ml.py +++ b/chebai/models/classic_ml.py @@ -1,3 +1,4 @@ +import os import pickle as pkl from typing import Any, Dict, List, Optional @@ -8,16 +9,23 @@ from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression from chebai.models.base import ChebaiBaseNet -import os LR_MODEL_PATH = os.path.join("models", "LR") + class LogisticRegression(ChebaiBaseNet): """ Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface. 
""" - def __init__(self, out_dim: int, input_dim: int, only_predict_classes: Optional[List] = None, n_classes=1528, **kwargs): + def __init__( + self, + out_dim: int, + input_dim: int, + only_predict_classes: Optional[List] = None, + n_classes=1528, + **kwargs, + ): super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs) self.models = [ SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes) @@ -39,15 +47,11 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor: preds.append(p) except NotFittedError: preds.append( - torch.zeros( - (x["features"].shape[0]), device=(x["features"].device) - ) + torch.zeros((x["features"].shape[0]), device=(x["features"].device)) ) except AttributeError: preds.append( - torch.zeros( - (x["features"].shape[0]), device=(x["features"].device) - ) + torch.zeros((x["features"].shape[0]), device=(x["features"].device)) ) preds = torch.stack(preds, dim=1) print(f"preds shape {preds.shape}") @@ -62,16 +66,22 @@ def fit_sklearn(self, X, y): if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")): print(f"Loading model {i} from file") - self.models[i] = pkl.load(open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb")) + self.models[i] = pkl.load( + open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb") + ) else: - if self.only_predict_classes and i not in self.only_predict_classes: # only try these classes + if ( + self.only_predict_classes and i not in self.only_predict_classes + ): # only try these classes continue try: model.fit(X, y[:, i]) except ValueError: self.models[i] = PlaceholderModel() # dump - pkl.dump(model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb")) + pkl.dump( + model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb") + ) def configure_optimizers(self, **kwargs): pass From 078bfb6a4c6af587b8a8e6b610afdca63dfd5286 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 24 Sep 2025 18:27:36 +0200 Subject: [PATCH 33/46] add subset parameter for chebi data --- chebai/preprocessing/datasets/chebi.py | 33 +++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 437a23d3..7f5cba70 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -13,7 +13,7 @@ import pickle from abc import ABC from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union import fastobo import networkx as nx @@ -110,15 +110,14 @@ class _ChEBIDataExtractor(_DynamicDataset, ABC): chebi_version will be used for training, validation and test. Defaults to None. single_class (int, optional): The ID of the single class to predict. If not set, all available labels will be predicted. Defaults to None. - dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. - splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + subset (Literal["2_STAR", "3_STAR"], optional): If set, only use entities that are part of the given subset. **kwargs: Additional keyword arguments (passed to XYBaseDataModule). Attributes: single_class (Optional[int]): The ID of the single class to predict. chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation. - dynamic_data_split_seed (int): The seed for random data splitting, default is 42. 
From 078bfb6a4c6af587b8a8e6b610afdca63dfd5286 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 24 Sep 2025 18:27:36 +0200
Subject: [PATCH 33/46] add subset parameter for chebi data

---
 chebai/preprocessing/datasets/chebi.py | 33 +++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 6 deletions(-)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 437a23d3..7f5cba70 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -13,7 +13,7 @@ import pickle
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
 
 import fastobo
 import networkx as nx
@@ -110,15 +110,14 @@ class _ChEBIDataExtractor(_DynamicDataset, ABC):
             chebi_version will be used for training, validation and test. Defaults to None.
         single_class (int, optional): The ID of the single class to predict. If not set, all available labels
             will be predicted. Defaults to None.
-        dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
-        splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
+        subset (Literal["2_STAR", "3_STAR"], optional): If set, only use entities that are part of the given subset.
         **kwargs: Additional keyword arguments (passed to XYBaseDataModule).
 
     Attributes:
         single_class (Optional[int]): The ID of the single class to predict.
        chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation.
-        dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
-        splits_file_path (Optional[str]): Path to csv file containing split assignments.
+        subset (Optional[Literal["2_STAR", "3_STAR"]]): If set, only use entities that are part of the given subset.
+
     """
 
     # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------
@@ -134,6 +133,7 @@ def __init__(
         self,
         chebi_version_train: Optional[int] = None,
         single_class: Optional[int] = None,
+        subset: Optional[Literal["2_STAR", "3_STAR"]] = None,
         **kwargs,
     ):
         # predict only single class (given as id of one of the classes present in the raw data set)
@@ -153,6 +153,8 @@ def __init__(
             **_init_kwargs,
         )
 
+        self.subset = subset
+
     # ------------------------------ Phase: Prepare data -----------------------------------
     def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
@@ -246,7 +248,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
                 and term_doc.id.prefix == "CHEBI"
             ):
                 term_dict = term_callback(term_doc)
-                if term_dict:
+                if term_dict and (
+                    not self.subset or term_dict["subset"] == self.subset
+                ):
                     elements.append(term_dict)
 
         g = nx.DiGraph()
@@ -515,6 +519,20 @@ def base_dir(self) -> str:
         """
         return os.path.join("data", f"chebi_v{self.chebi_version}")
 
+    @property
+    def processed_dir_main(self) -> str:
+        """
+        Returns the main directory path where processed data is stored.
+
+        Returns:
+            str: The path to the main processed data directory, based on the base directory and the instance's name.
+        """
+        return os.path.join(
+            self.base_dir,
+            self._name if self.subset is None else f"{self._name}_{self.subset}",
+            "processed",
+        )
+
     @property
     def processed_dir(self) -> str:
         """
@@ -890,6 +908,8 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
             parents.append(chebi_to_int(str(clause.term)))
         elif isinstance(clause, fastobo.term.NameClause):
             name = str(clause.name)
+        elif isinstance(clause, fastobo.term.SubsetClause):
+            subset = str(clause.subset)
 
         if isinstance(clause, fastobo.term.IsObsoleteClause):
             if clause.obsolete:
@@ -902,6 +922,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
         "has_part": parts,
         "name": name,
         "smiles": smiles,
+        "subset": subset,
     }
From 85656dad6a680604a747bb2c2e1ffce51e2bfaa3 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 24 Sep 2025 18:48:43 +0200
Subject: [PATCH 34/46] add token (chebi_v243)

---
 chebai/preprocessing/bin/smiles_token/tokens.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt
index 79600dc5..960173cd 100644
--- a/chebai/preprocessing/bin/smiles_token/tokens.txt
+++ b/chebai/preprocessing/bin/smiles_token/tokens.txt
@@ -4370,3 +4370,4 @@ b
 [158Tb]
 [90Sr]
 [32PH2]
+[CaH2]
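
Note (not part of the patch): the new subset option only keeps ontology terms whose subset clause matches the requested value, exactly as in the _extract_class_hierarchy change above. A minimal sketch of that check with made-up term dicts shaped like the ones term_callback now returns (IDs and star ratings below are illustrative):

requested_subset = "3_STAR"

terms = [
    {"id": 15377, "name": "water", "subset": "3_STAR", "smiles": "O"},
    {"id": 99999, "name": "draft entry", "subset": "2_STAR", "smiles": None},
]

# mirrors: if term_dict and (not self.subset or term_dict["subset"] == self.subset)
kept = [t for t in terms if not requested_subset or t["subset"] == requested_subset]
print([t["name"] for t in kept])  # ['water']
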
From 5c84ec7d38f719ac07a26a3c7d661e750c610655 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 24 Sep 2025 20:10:57 +0200
Subject: [PATCH 35/46] add custom fit loop for custom hook handling

---
 chebai/trainer/CustomTrainer.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index 2ecee680..9cb3d2d2 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -4,8 +4,11 @@ import pandas as pd
 import torch
 from lightning import LightningModule, Trainer
+from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loops.fit_loop import _FitLoop
+from lightning.pytorch.trainer import call
 from torch.nn.utils.rnn import pad_sequence
 
 from chebai.loggers.custom import CustomLogger
@@ -39,6 +42,9 @@ def __init__(self, *args, **kwargs):
                 log_kwargs[log_key] = log_value
             self.logger.log_hyperparams(log_kwargs)
 
+        # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops)
+        self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs)
+
     def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
         """
         Resolves logging arguments, handling nested structures such as lists and complex objects.
@@ -147,3 +153,35 @@ def log_dir(self) -> Optional[str]:
             dirpath = self.strategy.broadcast(dirpath)
 
         return dirpath
+
+
+class LoadDataLaterFitLoop(_FitLoop):
+
+    def on_advance_start(self) -> None:
+        """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary
+        so that the dataloaders can get information from the model. For example: The on_train_epoch_start
+        hook sets the curr_epoch attribute of the PubChemBatched dataset. With the Lightning configuration,
+        the dataloaders would always load batch 0 first, run an epoch, then get the epoch number (usually 0,
+        unless resuming from a checkpoint), then load batch 0 again (or some other batch). With this
+        implementation, the dataloaders are setup after the epoch number is set, so that the correct
+        batch is loaded."""
+        trainer = self.trainer
+
+        # update the epoch value for all samplers
+        assert self._combined_loader is not None
+        for i, dl in enumerate(self._combined_loader.flattened):
+            _set_sampler_epoch(dl, self.epoch_progress.current.processed)
+
+        self.restarted
+        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
+            if not self.restarted_on_epoch_start:
+                self.epoch_progress.increment_ready()
+
+            call._call_callback_hooks(trainer, "on_train_epoch_start")
+            call._call_lightning_module_hook(trainer, "on_train_epoch_start")
+
+            self.epoch_progress.increment_started()
+
+        # this is usually at the front of advance_start, but here we need it at the end
+        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
+        self.setup_data()
From 182a3b133c737d4cfe363f7eb1b490c938c0bdb8 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Thu, 25 Sep 2025 09:56:02 +0200
Subject: [PATCH 36/46] fix typo

---
 chebai/trainer/CustomTrainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index 9cb3d2d2..f7c7c3e2 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -172,7 +172,6 @@ def on_advance_start(self) -> None:
         for i, dl in enumerate(self._combined_loader.flattened):
             _set_sampler_epoch(dl, self.epoch_progress.current.processed)
 
-        self.restarted
         if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
             if not self.restarted_on_epoch_start:
                 self.epoch_progress.increment_ready()
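
Note (not part of the patch): the loop above is only useful together with a hook that feeds the epoch number back into the data pipeline before the dataloaders are rebuilt. A sketch of what such a hook could look like; the class and attribute names (other than on_train_epoch_start) are illustrative, and the trainer presumably also needs reload_dataloaders_every_n_epochs=1 so that setup_data() actually loads a new batch file each epoch.

from lightning import LightningModule


class EpochAwareModule(LightningModule):
    """Toy module that tells its datamodule which training batch file to use."""

    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule

    def on_train_epoch_start(self) -> None:
        # With the stock FitLoop this would run after the dataloaders were built,
        # so every epoch would start from batch file 0; LoadDataLaterFitLoop calls
        # the hook first and only then sets up the data.
        self.datamodule.curr_epoch = self.trainer.current_epoch
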
From 86044affa54c614018f023af17e12e410c96b1c5 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Thu, 25 Sep 2025 12:51:08 +0200
Subject: [PATCH 37/46] set subset before using it

---
 chebai/preprocessing/datasets/chebi.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 7260e65e..9ab0ce05 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -149,6 +149,8 @@ def __init__(
     ):
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
+        self.subset = subset
+
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
         # use different version of chebi for training and validation (if not None)
         # (still uses self.chebi_version for test set)
@@ -164,8 +166,6 @@ def __init__(
             **_init_kwargs,
         )
 
-        self.subset = subset
-
     # ------------------------------ Phase: Prepare data -----------------------------------
     def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
From 4c58dcb4fd4a616758077cf6517f9d50113cac70 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Tue, 30 Sep 2025 10:26:18 +0200
Subject: [PATCH 38/46] add electra freeze option

---
 chebai/models/electra.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/chebai/models/electra.py b/chebai/models/electra.py
index 589f0b02..c053db1c 100644
--- a/chebai/models/electra.py
+++ b/chebai/models/electra.py
@@ -224,6 +224,7 @@ def __init__(
         config: Optional[Dict[str, Any]] = None,
         pretrained_checkpoint: Optional[str] = None,
         load_prefix: Optional[str] = None,
+        freeze_electra: bool = False,
         **kwargs: Any,
     ):
         # Remove this property in order to prevent it from being stored as a
@@ -262,6 +263,10 @@ def __init__(
         else:
             self.electra = ElectraModel(config=self.config)
 
+        if freeze_electra:
+            for param in self.electra.parameters():
+                param.requires_grad = False
+
     def _process_for_loss(
         self,
         model_output: Dict[str, Tensor],
From 4f506b1763f2753c85253da7ebfeccc91cc6ab42 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 1 Oct 2025 13:51:07 +0200
Subject: [PATCH 39/46] make processing label rows safe if input is numpy array

---
 chebai/preprocessing/collate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py
index ecbcb876..b420ef47 100644
--- a/chebai/preprocessing/collate.py
+++ b/chebai/preprocessing/collate.py
@@ -130,7 +130,7 @@ def process_label_rows(self, labels: Tuple) -> torch.Tensor:
         """
         return pad_sequence(
             [
-                torch.tensor([v if v is not None else False for v in row])
+                torch.tensor([bool(v) if v is not None else False for v in row])
                 for row in labels
             ],
             batch_first=True,
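
Note (not part of the patch): the bool(...) cast normalises whatever scalars a label row carries (plain bools, numpy bool_/int values, or None) to Python bools before the tensor is built, so the padded result stays a bool tensor. A small sketch of the same expression on toy rows:

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

rows = [np.array([True, False, True]), [np.True_, None]]  # toy label rows
padded = pad_sequence(
    [
        torch.tensor([bool(v) if v is not None else False for v in row])
        for row in rows
    ],
    batch_first=True,
)
print(padded.dtype, padded.shape)  # torch.bool torch.Size([2, 3])
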
From eb86e3fb42436b67fbfcb5dfd8b7821584925d62 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 8 Oct 2025 15:28:00 +0200
Subject: [PATCH 40/46] cast to model device

---
 chebai/result/generate_class_properties.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/chebai/result/generate_class_properties.py b/chebai/result/generate_class_properties.py
index 8c8f96bf..3cb51b0c 100644
--- a/chebai/result/generate_class_properties.py
+++ b/chebai/result/generate_class_properties.py
@@ -172,13 +172,18 @@ def generate_props(
         class_names = self.load_class_labels(classes_file)
         num_classes = len(class_names)
         metrics_obj_dict: dict[str, torchmetrics.Metric] = {
-            "cm": MultilabelConfusionMatrix(num_labels=num_classes),
-            "f1": MultilabelF1Score(num_labels=num_classes, average=None),
+            "cm": MultilabelConfusionMatrix(num_labels=num_classes).to(
+                device=model.device
+            ),
+            "f1": MultilabelF1Score(num_labels=num_classes, average=None).to(
+                device=model.device
+            ),
         }
 
         for batch_idx, batch in enumerate(data_loader):
             data = model._process_batch(batch, batch_idx=batch_idx)
-            labels = data["labels"]
+            labels = data["labels"].to(device=model.device)
+            data["features"][0].to(device=model.device)
             model_output = model(data, **data.get("model_kwargs", {}))
             preds, targets = model._get_prediction_and_labels(
                 data, labels, model_output
             )
@@ -241,7 +246,8 @@ def generate(
 
 if __name__ == "__main__":
-    # _generate_classes_props_json.py generate \
+    # Usage:
+    # generate_classes_properties.py generate \
     #     --data_partition "val" \
     #     --model_ckpt_path "model/ckpt/path" \
     #     --model_config_file_path "model/config/file/path" \
""" # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------ @@ -722,6 +726,7 @@ class _DynamicDataset(XYBaseDataModule, ABC): def __init__( self, + apply_label_filter: Optional[str] = None, **kwargs, ): super(_DynamicDataset, self).__init__(**kwargs) @@ -735,6 +740,7 @@ def __init__( self.splits_file_path = self._validate_splits_file_path( kwargs.get("splits_file_path", None) ) + self.apply_label_filter = apply_label_filter @staticmethod def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: @@ -1134,6 +1140,18 @@ def _retrieve_splits_from_csv(self) -> None: ) df_data = pd.DataFrame(data) + if self.apply_label_filter: + print(f"Applying label filter from {self.apply_label_filter}...") + with open(self.apply_label_filter, "r") as f: + label_filter = [line.strip() for line in f] + with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + classes = [line.strip() for line in cf] + # reorder labels + old_labels = np.stack(df_data["labels"]) + label_mapping = [classes.index(lbl) for lbl in label_filter] + new_labels = old_labels[:, label_mapping] + df_data["labels"] = list(new_labels) + train_ids = splits_df[splits_df["split"] == "train"]["id"] validation_ids = splits_df[splits_df["split"] == "validation"]["id"] test_ids = splits_df[splits_df["split"] == "test"]["id"] diff --git a/chebai/result/generate_class_properties.py b/chebai/result/generate_class_properties.py index 3cb51b0c..8d744b7b 100644 --- a/chebai/result/generate_class_properties.py +++ b/chebai/result/generate_class_properties.py @@ -168,7 +168,10 @@ def generate_props( raise ValueError(f"Unknown data partition: {data_partition}") print(f"Running inference on {data_partition} data...") - classes_file = Path(data_module.processed_dir_main) / "classes.txt" + if data_module.apply_label_filter is not None: + classes_file = data_module.apply_label_filter + else: + classes_file = Path(data_module.processed_dir_main) / "classes.txt" class_names = self.load_class_labels(classes_file) num_classes = len(class_names) metrics_obj_dict: dict[str, torchmetrics.Metric] = { @@ -181,6 +184,7 @@ def generate_props( } for batch_idx, batch in enumerate(data_loader): + batch = batch.to(device=model.device) data = model._process_batch(batch, batch_idx=batch_idx) labels = data["labels"].to(device=model.device) data["features"][0].to(device=model.device) From dfc4db9c5777878104795a9127a7a4f615c03741 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Sat, 1 Nov 2025 13:47:10 +0100 Subject: [PATCH 42/46] add id filter --- chebai/preprocessing/datasets/base.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 0d29e7f7..4d62193e 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -710,13 +710,17 @@ class _DynamicDataset(XYBaseDataModule, ABC): dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. apply_label_filter (Optional[str]): Path to a classes.txt file - only labels that are in the labels filter - file will be used (in that order). All labels in the label filter have to be present in the dataset. + file will be used (in that order). All labels in the label filter have to be present in the dataset. This filter + is only active when loading splits from a CSV file. 
From dfc4db9c5777878104795a9127a7a4f615c03741 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Sat, 1 Nov 2025 13:47:10 +0100
Subject: [PATCH 42/46] add id filter

---
 chebai/preprocessing/datasets/base.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 0d29e7f7..4d62193e 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -710,13 +710,17 @@ class _DynamicDataset(XYBaseDataModule, ABC):
         dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
         splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
         apply_label_filter (Optional[str]): Path to a classes.txt file - only labels that are in the labels filter
-            file will be used (in that order). All labels in the label filter have to be present in the dataset.
+            file will be used (in that order). All labels in the label filter have to be present in the dataset. This filter
+            is only active when loading splits from a CSV file. Defaults to None.
+        apply_id_filter (Optional[str]): Path to a data.pt file from a different dataset - only IDs that are in the
+            id filter file will be used. Defaults to None. This filter is only active when loading splits from a CSV file.
         **kwargs: Additional keyword arguments passed to XYBaseDataModule.
 
     Attributes:
         dynamic_data_split_seed (int): The seed for random data splitting, default is 42.
         splits_file_path (Optional[str]): Path to the CSV file containing split assignments.
         apply_label_filter (Optional[str]): Path to a classes.txt file for label filtering.
+        apply_id_filter (Optional[str]): Path to a data.pt file for ID filtering.
     """
 
     # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------
@@ -727,6 +731,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):
     def __init__(
         self,
         apply_label_filter: Optional[str] = None,
+        apply_id_filter: Optional[str] = None,
         **kwargs,
     ):
         super(_DynamicDataset, self).__init__(**kwargs)
@@ -741,6 +746,7 @@ def __init__(
             kwargs.get("splits_file_path", None)
         )
         self.apply_label_filter = apply_label_filter
+        self.apply_id_filter = apply_id_filter
 
     @staticmethod
     def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -1140,6 +1146,15 @@ def _retrieve_splits_from_csv(self) -> None:
         )
         df_data = pd.DataFrame(data)
 
+        if self.apply_id_filter:
+            print(f"Applying ID filter from {self.apply_id_filter}...")
+            with open(self.apply_id_filter, "r") as f:
+                id_filter = [
+                    line["ident"]
+                    for line in torch.load(self.apply_id_filter, weights_only=False)
+                ]
+            df_data = df_data[df_data["ident"].isin(id_filter)]
+
         if self.apply_label_filter:
             print(f"Applying label filter from {self.apply_label_filter}...")
             with open(self.apply_label_filter, "r") as f:
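
Note (not part of the patch): the ID filter keeps only those rows of the assembled dataframe whose ident also appears in the reference data.pt file. A toy version of the same isin filter; the list-of-dicts shape (each entry carrying an "ident" key) matches what torch.load would return for such a file, while the IDs themselves are illustrative.

import pandas as pd

reference_data = [{"ident": 15377}, {"ident": 16236}]  # stands in for torch.load(data.pt)
id_filter = [entry["ident"] for entry in reference_data]

df_data = pd.DataFrame({"ident": [15377, 16236, 17790], "split": ["train", "test", "train"]})
df_data = df_data[df_data["ident"].isin(id_filter)]
print(df_data["ident"].tolist())  # [15377, 16236]
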
From bcf96f6be6b7a0c871183a32bb0f4a0193db41f0 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Sat, 1 Nov 2025 13:48:15 +0100
Subject: [PATCH 43/46] add id filter

---
 chebai/result/generate_class_properties.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/chebai/result/generate_class_properties.py b/chebai/result/generate_class_properties.py
index 8d744b7b..6a043e5a 100644
--- a/chebai/result/generate_class_properties.py
+++ b/chebai/result/generate_class_properties.py
@@ -121,6 +121,7 @@ def generate_props(
         model_config_file_path: str,
         data_config_file_path: str,
         output_path: str | None = None,
+        apply_id_filter: str | None = None,
     ) -> None:
         """
         Run inference on validation set, compute TPV/NPV per class, and save to JSON.
@@ -132,11 +133,13 @@ def generate_props(
             data_config_file_path: Path to yaml config file of the data.
             output_path: Optional path where to write the JSON metrics file. Defaults to '/classes.json'.
+            apply_id_filter: Optional path to a (data.pt) file containing IDs to filter the dataset. This is useful for comparing datasets with different ids.
         """
         data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
         data_module: XYBaseDataModule = load_data_instance(
             data_cls_path, data_cls_kwargs
         )
+        data_module.apply_id_filter = apply_id_filter
 
         splits_file_path = Path(data_module.processed_dir_main, "splits.csv")
         if data_module.splits_file_path is None:
@@ -222,6 +225,7 @@ def generate(
         model_config_file_path: str,
         data_config_file_path: str,
         output_path: str | None = None,
+        apply_id_filter: str | None = None,
     ) -> None:
         """
         CLI command to generate JSON with metrics on validation set.
@@ -246,6 +250,7 @@ def generate(
             model_config_file_path,
             data_config_file_path,
             output_path,
+            apply_id_filter=apply_id_filter,
         )
From aba03da592238661de04e60488c6d4070cc9d2d6 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 5 Nov 2025 12:32:04 +0100
Subject: [PATCH 44/46] fix term callback for clause without subset

---
 chebai/preprocessing/datasets/chebi.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index b12d4b6e..edcc8c41 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -14,18 +14,8 @@ import random
 from abc import ABC
 from collections import OrderedDict
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Generator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-)
 from itertools import cycle, permutations, product
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -1038,6 +1028,7 @@ def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]:
     parents = []
     name = None
     smiles = None
+    subset = None
     for clause in doc:
         if isinstance(clause, fastobo.term.PropertyValueClause):
             t = clause.property_value
From 901538100f6abc4b5c0be5527fa37d61aea5bcdb Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 5 Nov 2025 12:35:28 +0100
Subject: [PATCH 45/46] adapt reader test to fit bf97527477f84d6ac0752a196120ce424e6f9a9a

---
 tests/unit/readers/testChemDataReader.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py
index 5e6fb099..ec018f00 100644
--- a/tests/unit/readers/testChemDataReader.py
+++ b/tests/unit/readers/testChemDataReader.py
@@ -97,12 +97,15 @@ def test_read_data_with_new_token(self) -> None:
     def test_read_data_with_invalid_input(self) -> None:
         """
         Test the _read_data method with an invalid input.
-        The invalid token should raise an error or be handled appropriately.
+        The invalid token should prompt a return value None
         """
         raw_data = "%INVALID%"
-        with self.assertRaises(ValueError):
-            self.reader._read_data(raw_data)
+        result = self.reader._read_data(raw_data)
+        self.assertIsNone(
+            result,
+            "The output for invalid token '%INVALID%' should be None.",
+        )
 
     @patch("builtins.open", new_callable=mock_open)
     def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
From 08f60715649bc475fd70fc9ff95bfa0293d1d370 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Wed, 5 Nov 2025 12:39:13 +0100
Subject: [PATCH 46/46] adapt test for subset

---
 tests/unit/dataset_classes/testChebiTermCallback.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/unit/dataset_classes/testChebiTermCallback.py b/tests/unit/dataset_classes/testChebiTermCallback.py
index 8680760e..9ea77177 100644
--- a/tests/unit/dataset_classes/testChebiTermCallback.py
+++ b/tests/unit/dataset_classes/testChebiTermCallback.py
@@ -36,6 +36,7 @@ def test_process_valid_terms(self) -> None:
             "has_part": set(),
             "name": "Compound A",
             "smiles": "C1=CC=CC=C1",
+            "subset": "2_STAR",
         }
         actual_dict: Dict[str, Any] = term_callback(