Merged
Changes from all commits
Commits
49 commits
3ab691f
add fingerprint dataset, logistic regression model
sfluegel05 Aug 20, 2025
bb73784
add batched pubchem dataset
sfluegel05 Aug 20, 2025
db5434f
update file name system for pubchem batched
sfluegel05 Aug 21, 2025
abffcae
fix k
sfluegel05 Aug 21, 2025
6d7ca43
fix k
sfluegel05 Aug 21, 2025
bf97527
add error handling for smiles tokenisation
sfluegel05 Aug 22, 2025
7f9e28d
add default model
sfluegel05 Aug 22, 2025
f76fe1c
fix lstm
sfluegel05 Aug 22, 2025
03d5e55
fix lstm
sfluegel05 Aug 22, 2025
e1256b0
fix lstm
sfluegel05 Aug 22, 2025
abc9b53
fix lstm
sfluegel05 Aug 22, 2025
3b233d6
streamline classic ml
sfluegel05 Sep 16, 2025
8c0454c
fix batched pubchem
sfluegel05 Sep 17, 2025
04abe66
fix pubchem batching
sfluegel05 Sep 17, 2025
97079c3
fix batch tokenisation
sfluegel05 Sep 17, 2025
5e6c508
fix batch tokenisation
sfluegel05 Sep 17, 2025
03cb212
fix batch tokenisation
sfluegel05 Sep 17, 2025
0f1e7c0
run n epochs with n different training files
sfluegel05 Sep 17, 2025
9eebad2
add logging
sfluegel05 Sep 18, 2025
faa3a72
lstm error logging
sfluegel05 Sep 18, 2025
7f92917
add more logging to find out if pubchemBatched actually works
sfluegel05 Sep 18, 2025
69908ba
fix print statement for fixing epoch issue
sfluegel05 Sep 19, 2025
6df484d
reformatting
sfluegel05 Sep 19, 2025
4288689
reformatting
sfluegel05 Sep 19, 2025
0e6afe2
add num_layers and dropout parameters, make lstm bidirectional
sfluegel05 Sep 22, 2025
940ce9d
multi-layer lstm
sfluegel05 Sep 22, 2025
1e68032
increase vocab_size for PubChem
sfluegel05 Sep 23, 2025
8ee5c4b
streamline batch size in PubchemBatched
sfluegel05 Sep 23, 2025
33de8f3
update tokens (full pubchem)
Sep 23, 2025
c73b0fb
fix number of expected pubchem batches
sfluegel05 Sep 24, 2025
03635ff
Merge branch 'feature/new-ensemble-models' of https://github.com/ChEB…
sfluegel05 Sep 24, 2025
905ffc2
more options for LR
sfluegel05 Sep 24, 2025
c1da092
reformat
sfluegel05 Sep 24, 2025
078bfb6
add subset parameter for chebi data
sfluegel05 Sep 24, 2025
2911f92
fix merge conflict
sfluegel05 Sep 24, 2025
85656da
add token (chebi_v243)
sfluegel05 Sep 24, 2025
5c84ec7
add custom fit loop for custom hook handling
sfluegel05 Sep 24, 2025
182a3b1
fix typo
sfluegel05 Sep 25, 2025
86044af
set subset before using it
sfluegel05 Sep 25, 2025
4c58dcb
add electra freeze option
sfluegel05 Sep 30, 2025
4f506b1
make processing label rows safe if input is numpy array
sfluegel05 Oct 1, 2025
eb86e3f
cast to model device
sfluegel05 Oct 8, 2025
4ab760e
add label filter
sfluegel05 Oct 14, 2025
dfc4db9
add id filter
sfluegel05 Nov 1, 2025
bcf96f6
add id filter
sfluegel05 Nov 1, 2025
abe9e2a
Merge branch 'dev' into feature/new-ensemble-models
sfluegel05 Nov 5, 2025
aba03da
fix term callback for clause without subset
sfluegel05 Nov 5, 2025
9015381
adapt reader test to fit bf97527477f84d6ac0752a196120ce424e6f9a9a
sfluegel05 Nov 5, 2025
08f6071
adapt test for subset
sfluegel05 Nov 5, 2025
3 changes: 2 additions & 1 deletion chebai/callbacks/epoch_metrics.py
@@ -62,7 +62,8 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
             labels (torch.Tensor): Ground truth labels.
         """
         tps = torch.sum(
-            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
+            torch.logical_and(preds > self.threshold, labels.to(torch.bool)),
+            dim=0,
         )
         self.true_positives += tps
         self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
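Note: the hunk above only re-wraps the tps expression; the counting logic is unchanged. For reference, a minimal standalone sketch of what update() accumulates per class (the threshold of 0.5 and the toy tensors are illustrative):

import torch

threshold = 0.5
preds = torch.tensor([[0.9, 0.2], [0.7, 0.6]])
labels = torch.tensor([[1, 0], [0, 1]])

# same per-class counting as the update() shown above
tps = torch.sum(
    torch.logical_and(preds > threshold, labels.to(torch.bool)),
    dim=0,
)
positive_predictions = torch.sum(preds > threshold, dim=0)
print(tps)                   # tensor([1, 1]): true positives per class
print(positive_predictions)  # tensor([2, 1]): predicted positives per class
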
13 changes: 12 additions & 1 deletion chebai/models/base.py
@@ -4,6 +4,7 @@
 
 import torch
 from lightning.pytorch.core.module import LightningModule
+from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
 from chebai.preprocessing.structures import XYData

@@ -106,7 +107,8 @@ def _get_prediction_and_labels(
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: Predictions and labels.
         """
-        return output, labels
+        # cast labels to int
+        return output, labels.to(torch.int) if labels is not None else labels
 
     def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor:
         """
@@ -158,6 +160,13 @@ def _process_for_loss(
         """
         return model_output, labels, loss_kwargs
 
+    def on_train_epoch_start(self) -> None:
+        # pass current epoch to datamodule if it has the attribute curr_epoch (for PubChemBatched dataset)
+        rank_zero_info(f"Starting epoch {self.current_epoch}")
+        if hasattr(self.trainer.datamodule, "curr_epoch"):
+            rank_zero_info(f"Setting datamodule.curr_epoch to {self.current_epoch}")
+            self.trainer.datamodule.curr_epoch = self.current_epoch
+
     def training_step(
         self, batch: XYData, batch_idx: int
     ) -> Dict[str, Union[torch.Tensor, Any]]:
@@ -310,6 +319,8 @@ def _execute(
                 for metric_name, metric in metrics.items():
                     metric.update(pr, tar)
                 self._log_metrics(prefix, metrics, len(batch))
+        if isinstance(d, dict) and "loss" not in d:
+            print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}")
         return d
 
     def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
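Note: the on_train_epoch_start hook added above only forwards self.current_epoch to the datamodule when it exposes a curr_epoch attribute. Below is a minimal sketch of a datamodule that consumes it; the class name, the round-robin shard selection and the random tensors are illustrative assumptions, not the actual PubChemBatched implementation, and the trainer would need to rebuild dataloaders each epoch (e.g. reload_dataloaders_every_n_epochs=1) for the switch to take effect.

import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ShardedDataModuleSketch(pl.LightningDataModule):
    """Illustrative datamodule that switches its training shard every epoch."""

    def __init__(self, n_shards: int = 5):
        super().__init__()
        self.n_shards = n_shards
        self.curr_epoch = 0  # set from the hook above at the start of every epoch

    def train_dataloader(self) -> DataLoader:
        shard = self.curr_epoch % self.n_shards  # pick this epoch's shard
        features = torch.randn(32, 16)           # stand-in for the shard's features
        labels = torch.randint(0, 2, (32, 4)).float()
        return DataLoader(TensorDataset(features, labels), batch_size=8)
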
97 changes: 97 additions & 0 deletions chebai/models/classic_ml.py
@@ -0,0 +1,97 @@
+import os
+import pickle as pkl
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import tqdm
+from sklearn.exceptions import NotFittedError
+from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
+
+from chebai.models.base import ChebaiBaseNet
+
+LR_MODEL_PATH = os.path.join("models", "LR")
+
+
+class LogisticRegression(ChebaiBaseNet):
+    """
+    Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
+    """
+
+    def __init__(
+        self,
+        out_dim: int,
+        input_dim: int,
+        only_predict_classes: Optional[List] = None,
+        n_classes=1528,
+        **kwargs,
+    ):
+        super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
+        self.models = [
+            SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
+        ]
+        # indices of classes (in the dataset used for training) where a model should be trained
+        self.only_predict_classes = only_predict_classes
+
+    def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
+        print(
+            f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}"
+        )
+        if self.training:
+            self.fit_sklearn(x["features"], x["labels"])
+        preds = []
+        for model in self.models:
+            try:
+                p = torch.from_numpy(model.predict(x["features"])).float()
+                p = p.to(x["features"].device)
+                preds.append(p)
+            except NotFittedError:
+                preds.append(
+                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
+                )
+            except AttributeError:
+                preds.append(
+                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
+                )
+        preds = torch.stack(preds, dim=1)
+        print(f"preds shape {preds.shape}")
+        return preds.squeeze(-1)
+
+    def fit_sklearn(self, X, y):
+        """
+        Fit the underlying sklearn model. X and y should be numpy arrays.
+        """
+        for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
+            import os
+
+            if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")):
+                print(f"Loading model {i} from file")
+                self.models[i] = pkl.load(
+                    open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb")
+                )
+            else:
+                if (
+                    self.only_predict_classes and i not in self.only_predict_classes
+                ):  # only try these classes
+                    continue
+                try:
+                    model.fit(X, y[:, i])
+                except ValueError:
+                    self.models[i] = PlaceholderModel()
+                # dump
+                pkl.dump(
+                    model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb")
+                )
+
+    def configure_optimizers(self, **kwargs):
+        pass
+
+
+class PlaceholderModel:
+    """Acts like a trained model, but isn't. Use this if training fails and you need a placeholder."""
+
+    def __init__(self, default_prediction=1):
+        self.default_prediction = default_prediction
+
+    def predict(self, preds):
+        return np.ones(preds.shape[0]) * self.default_prediction
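Note: a minimal usage sketch for the LogisticRegression wrapper above, assuming binary fingerprint features and multi-label targets. The sizes, the two-class setup and the directory creation are illustrative; the wrapper caches fitted per-class models under models/LR, so that directory must exist.

import os

import torch

from chebai.models.classic_ml import LogisticRegression

os.makedirs(os.path.join("models", "LR"), exist_ok=True)  # cache dir used by fit_sklearn

# illustrative sizes: 1024-bit fingerprints, 2 target classes
model = LogisticRegression(out_dim=2, input_dim=1024, n_classes=2)
model.train()  # forward() fits the per-class sklearn models while in training mode

batch = {
    "features": torch.randint(0, 2, (16, 1024)).float(),
    "labels": torch.randint(0, 2, (16, 2)).numpy(),
}
preds = model(batch)
print(preds.shape)  # torch.Size([16, 2])
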
5 changes: 5 additions & 0 deletions chebai/models/electra.py
@@ -224,6 +224,7 @@ def __init__(
         config: Optional[Dict[str, Any]] = None,
         pretrained_checkpoint: Optional[str] = None,
         load_prefix: Optional[str] = None,
+        freeze_electra: bool = False,
         **kwargs: Any,
     ):
         # Remove this property in order to prevent it from being stored as a
@@ -262,6 +263,10 @@
         else:
             self.electra = ElectraModel(config=self.config)
 
+        if freeze_electra:
+            for param in self.electra.parameters():
+                param.requires_grad = False
+
     def _process_for_loss(
         self,
         model_output: Dict[str, Tensor],
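Note: freeze_electra above simply disables gradients for the wrapped encoder so that only the remaining layers keep training. The same pattern on a generic module, with illustrative stand-in layers:

import torch
from torch import nn

encoder = nn.Linear(8, 8)  # stand-in for the pretrained ElectraModel
head = nn.Linear(8, 2)     # stand-in for the task-specific output layers

# mirror of the freeze_electra branch: turn off gradients for the encoder only
for param in encoder.parameters():
    param.requires_grad = False

trainable = [
    p for p in list(encoder.parameters()) + list(head.parameters()) if p.requires_grad
]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
print(len(trainable))  # 2: only the head's weight and bias remain trainable
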
50 changes: 37 additions & 13 deletions chebai/models/lstm.py
@@ -1,31 +1,55 @@
 import logging
 
 from torch import nn
-from torch.nn.utils.rnn import pack_padded_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 from chebai.models.base import ChebaiBaseNet
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemLSTM(ChebaiBaseNet):
-    def __init__(self, in_d, out_d, num_classes, **kwargs):
-        super().__init__(num_classes, **kwargs)
-        self.lstm = nn.LSTM(in_d, out_d, batch_first=True)
-        self.embedding = nn.Embedding(800, 100)
+    def __init__(
+        self,
+        out_d,
+        in_d,
+        num_classes,
+        criterion: nn.Module = None,
+        num_layers=6,
+        dropout=0.2,
+        **kwargs,
+    ):
+        super().__init__(
+            out_dim=out_d,
+            input_dim=in_d,
+            criterion=criterion,
+            num_classes=num_classes,
+            **kwargs,
+        )
+        self.lstm = nn.LSTM(
+            in_d,
+            out_d,
+            batch_first=True,
+            dropout=dropout,
+            bidirectional=True,
+            num_layers=num_layers,
+        )
+        self.embedding = nn.Embedding(1400, in_d)
         self.output = nn.Sequential(
-            nn.Linear(out_d, in_d),
+            nn.Linear(out_d * 2, out_d),
             nn.ReLU(),
             nn.Dropout(0.2),
-            nn.Linear(in_d, num_classes),
+            nn.Linear(out_d, num_classes),
         )
 
-    def forward(self, data):
-        x = data.x
-        x_lens = data.lens
+    def forward(self, data, *args, **kwargs):
+        x = data["features"]
+        x_lens = data["model_kwargs"]["lens"]
         x = self.embedding(x)
         x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)
-        x = self.lstm(x)[1][0]
-        # = pad_packed_sequence(x, batch_first=True)[0]
+        x = self.lstm(x)[0]
+        x = pad_packed_sequence(x, batch_first=True)[0][
+            :, 0
+        ]  # reduce sequence dimension to first element
         x = self.output(x)
-        return x.squeeze(0)
+        return x
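Note: shape-wise, the rewritten forward embeds the padded token batch, packs it, runs the bidirectional stack, unpacks, and keeps only the first time step before the output head (hence the out_d * 2 input of the first linear layer). A standalone sketch with illustrative sizes:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

in_d, out_d, num_classes, num_layers = 100, 128, 1528, 6  # illustrative sizes

embedding = nn.Embedding(1400, in_d)
lstm = nn.LSTM(
    in_d, out_d, batch_first=True, dropout=0.2, bidirectional=True, num_layers=num_layers
)
output = nn.Sequential(
    nn.Linear(out_d * 2, out_d),  # 2 * out_d: forward and backward directions
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(out_d, num_classes),
)

tokens = torch.randint(0, 1400, (4, 20))  # 4 sequences, padded to length 20
lens = torch.tensor([20, 17, 12, 5])      # true lengths before padding

x = embedding(tokens)                                  # (4, 20, in_d)
x = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
x = lstm(x)[0]                                         # packed bidirectional outputs
x = pad_packed_sequence(x, batch_first=True)[0][:, 0]  # (4, out_d * 2), first time step
print(output(x).shape)                                 # torch.Size([4, 1528])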