Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions chebai/loss/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import math
import torch
from typing import Literal
from typing import Literal, Union

from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
Expand All @@ -14,7 +14,7 @@
class ImplicationLoss(torch.nn.Module):
def __init__(
self,
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
base_loss: torch.nn.Module = None,
tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
impl_loss_weight=0.1, # weight of implication loss in relation to base_loss
Expand Down Expand Up @@ -114,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
def __init__(
self,
path_to_disjointness,
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
base_loss: torch.nn.Module = None,
disjoint_loss_weight=100,
**kwargs,
Expand Down
3 changes: 2 additions & 1 deletion chebai/result/analyse_sem.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchmetrics.functional.classification import multilabel_f1_score
import wandb
import gc
from typing import List, Union
from utils import *

DEVICE = "cpu" # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -244,7 +245,7 @@ def analyse_run(
labeled_data_cls=ChEBIOver100, # use labels from this dataset for violations
chebi_version=231,
results_path=os.path.join("_semantic", "eval_results.csv"),
violation_metrics: [str | list[callable]] = "all",
violation_metrics: Union[str, List[callable]] = "all",
verbose_violation_output=False,
):
"""Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),
Expand Down
7 changes: 5 additions & 2 deletions chebai/trainer/CustomTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def predict_from_file(
smiles_strings = [inp.strip() for inp in input.readlines()]
loaded_model.eval()
predictions = self._predict_smiles(loaded_model, smiles_strings)
predictions_df = pd.DataFrame(predictions.detach().numpy())
predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
if classes_path is not None:
with open(classes_path, "r") as f:
predictions_df.columns = [cls.strip() for cls in f.readlines()]
Expand All @@ -74,7 +74,10 @@ def predict_from_file(
def _predict_smiles(self, model: LightningModule, smiles: List[str]):
reader = ChemDataReader()
parsed_smiles = [reader._read_data(s) for s in smiles]
x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True)
x = pad_sequence(
[torch.tensor(a, device=model.device) for a in parsed_smiles],
batch_first=True,
)
cls_tokens = (
torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1)
* CLS_TOKEN
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"iterative-stratification",
"wandb",
"chardet",
"yaml",
"pyyaml",
"torchmetrics",
],
extras_require={"dev": ["black", "isort", "pre-commit"]},
Expand Down
Loading