Merged

Changes from all commits (29 commits)
89ffbd3  Data handling restructure (aditya0by0, May 27, 2024)
f07f312  Merge branch 'dev' into feature/testing_framework (aditya0by0, May 27, 2024)
bd6382b  Update chebi tests for dynamic splits (aditya0by0, Jun 5, 2024)
d8abee2  Dynamic split for chebi_version_train + changes (aditya0by0, Jun 8, 2024)
91aa484  Update dynamic split tests (aditya0by0, Jun 8, 2024)
22f882c  Update chebi + dynamic test (aditya0by0, Jun 10, 2024)
dde4196  Update setup.py (aditya0by0, Jun 11, 2024)
aecb7e6  Update Evaluation notebook + rel. code (aditya0by0, Jun 12, 2024)
98342af  set split variables when required instead of during setup (Jun 13, 2024)
89cbdb6  remove unnecessary class instantiation (Jun 13, 2024)
8b22601  Merge branch 'refs/heads/dev' into feature/testing_framework (Jun 13, 2024)
b2439f8  add isort to pre-commit, reformat with isort (Jun 13, 2024)
ec6254d  Update .gitignore (aditya0by0, Jun 13, 2024)
c1b6b0d  remove commented out cells - eval notebook (aditya0by0, Jun 13, 2024)
667b079  add filename parameter to load_processed_data (aditya0by0, Jun 13, 2024)
8c9dfe1  Updated chebi.py for train_version restructure (aditya0by0, Jun 18, 2024)
cd03023  minor changes in data split code (aditya0by0, Jun 19, 2024)
0584345  Merge branch 'dev' into feature/testing_framework (aditya0by0, Jun 23, 2024)
f747257  fix: test for consistency across runs did validate the same run twice (Jun 27, 2024)
a87dd35  migration script for chebi data for new data restructure (aditya0by0, Jul 1, 2024)
d8e68cc  argparser + fixes (aditya0by0, Jul 1, 2024)
ae61d10  transform data.pkl to data.pt instead of combining .pt split files (aditya0by0, Jul 1, 2024)
9c25543  migration - raw data error fix + id col error (aditya0by0, Jul 3, 2024)
0c2fca1  pd.to_pickle instead of pickle.dump for code consistency (aditya0by0, Jul 3, 2024)
1c4acea  migration : added docstring + type hints (aditya0by0, Jul 3, 2024)
9992a15  logic to generate splits csv + use csv if provided (aditya0by0, Jul 3, 2024)
07340cb  read only first row to validate presence of relevant columns in csv (aditya0by0, Jul 3, 2024; see the sketch after this list)
bc19a21  add jsonargparse cli to migration, gentle file-not-found handling (Jul 5, 2024)
8b0b505  add documentation for users (Jul 5, 2024)
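Commits 9992a15 and 07340cb above generate a splits CSV and validate it cheaply by reading only its first row. As a rough illustration of that pandas idiom (not the PR's actual code; the real column names and error handling live in the PR's data modules, and `id`/`split` here are placeholder names):

```python
import pandas as pd

# Placeholder column names; the actual required columns are defined in the PR's code.
REQUIRED_COLUMNS = {"id", "split"}


def validate_splits_csv(path: str) -> None:
    """Fail fast if the splits CSV lacks required columns, without loading the full file."""
    first_row = pd.read_csv(path, nrows=1)  # nrows=1: parse the header plus one data row only
    missing = REQUIRED_COLUMNS - set(first_row.columns)
    if missing:
        raise ValueError(f"Splits CSV '{path}' is missing columns: {sorted(missing)}")
```

Reading a single row keeps the check O(1) in file size, which matters for large precomputed split tables.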
Files changed
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
@@ -7,4 +7,4 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
-      - uses: psf/black@stable
\ No newline at end of file
+      - uses: psf/black@stable
5 changes: 5 additions & 0 deletions .gitignore
@@ -161,3 +161,8 @@ cython_debug/
 #.idea/
 
 # configs/ # commented as new configs can be added as a part of a feature
+/.idea
+/data
+/logs
+/results_buffer
+electra_pretrained.ckpt
26 changes: 21 additions & 5 deletions .pre-commit-config.yaml
@@ -1,9 +1,25 @@
 repos:
-#- repo: https://github.com/PyCQA/isort
-#  rev: "5.12.0"
-#  hooks:
-#  - id: isort
 - repo: https://github.com/psf/black
   rev: "24.2.0"
   hooks:
-  - id: black
\ No newline at end of file
+  - id: black
+  - id: black-jupyter # for formatting jupyter-notebook
+
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.2
+  hooks:
+  - id: isort
+    name: isort (python)
+    args: ["--profile=black"]
+
+- repo: https://github.com/asottile/seed-isort-config
+  rev: v2.2.0
+  hooks:
+  - id: seed-isort-config
+
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.6.0
+  hooks:
+  - id: check-yaml
+  - id: end-of-file-fixer
+  - id: trailing-whitespace
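Most of the Python hunks in the rest of this PR are mechanical consequences of the isort hook added here. As the diffs below show, `isort --profile=black` rewrites each module's import block into three alphabetized groups separated by blank lines (standard library, third-party, first-party), with plain `import x` statements ahead of `from x import y` within each group. A minimal sketch of the resulting layout, reusing the imports from the `chebai/preprocessing/collate.py` hunk below (the `os` line is added purely to illustrate the standard-library group):

```python
# Layout enforced by the isort hook with --profile=black:

import os  # 1) standard library, alphabetized

import torch  # 2) third-party: plain imports first ...
from torch.nn.utils.rnn import pad_sequence  # ... then from-imports

from chebai.preprocessing.structures import XYData  # 3) first-party (this repo)
```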
12 changes: 6 additions & 6 deletions README.md
@@ -1,6 +1,6 @@
 # ChEBai
 
-ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI. 
+ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI.
 The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process.
 
 ## Installation
@@ -21,7 +21,7 @@ pip install .
 
 ## Usage
 
-The training and inference is abstracted using the Pytorch Lightning modules. 
+The training and inference is abstracted using the Pytorch Lightning modules.
 Here are some CLI commands for the standard functionalities of pretraining, ontology extension, fine-tuning for toxicity and prediction.
 For further details, see the [wiki](https://github.com/ChEB-AI/python-chebai/wiki).
 If you face any problems, please open a new [issue](https://github.com/ChEB-AI/python-chebai/issues/new).
@@ -55,18 +55,18 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 
 ## Evaluation
 
-An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`. 
+An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
 It takes in the finetuned model as input for performing the evaluation.
 
 ## Cross-validation
-You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test 
+You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
 set. For that, you need to specify the total_number of folds as
 ```
 --data.init_args.inner_k_folds=K
 ```
 and the fold to be used in the current optimisation run as
-``` 
+```
 --data.init_args.fold_index=I
 ```
-To train K models, you need to do K such calls, each with a different `fold_index`. On the first call with a given 
+To train K models, you need to do K such calls, each with a different `fold_index`. On the first call with a given
 `inner_k_folds`, all folds will be created and stored in the data directory
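For context, a single fold's training call might look like the sketch below. This is not taken from the PR: it assumes the `python -m chebai fit` entry point and the config-file conventions used elsewhere in this README, and `K=5` is an arbitrary choice.

```
python -m chebai fit --model=configs/model/electra.yml --data=configs/data/chebi50.yml \
  --data.init_args.inner_k_folds=5 --data.init_args.fold_index=0
```

Repeating the call with `fold_index` 0 through 4 trains the five models; the folds themselves are created and stored once, on the first call.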
2 changes: 1 addition & 1 deletion chebai/callbacks.py
@@ -1,8 +1,8 @@
 import json
 import os
 
-from lightning.pytorch.callbacks import BasePredictionWriter
 import torch
+from lightning.pytorch.callbacks import BasePredictionWriter
 
 
 class ChebaiPredictionWriter(BasePredictionWriter):
5 changes: 3 additions & 2 deletions chebai/callbacks/prediction_callback.py
@@ -1,8 +1,9 @@
-from lightning.pytorch.callbacks import BasePredictionWriter
-import torch
 import os
 import pickle
+
+import torch
+from lightning.pytorch.callbacks import BasePredictionWriter
 
 
 class PredictionWriter(BasePredictionWriter):
     def __init__(self, output_dir, write_interval):
6 changes: 3 additions & 3 deletions chebai/loggers/custom.py
@@ -1,11 +1,11 @@
-from datetime import datetime
-from typing import Literal, Optional, Union, List
 import os
+from datetime import datetime
+from typing import List, Literal, Optional, Union
 
+import wandb
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import WandbLogger
-import wandb
 
 
 class CustomLogger(WandbLogger):
8 changes: 5 additions & 3 deletions chebai/loss/bce_weighted.py
@@ -1,9 +1,11 @@
+import os
+import pickle
+
+import pandas as pd
 import torch
+
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
-import pandas as pd
-import os
-import pickle
 
 
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
7 changes: 4 additions & 3 deletions chebai/loss/semantic.py
@@ -1,14 +1,15 @@
 import csv
+import math
 import os
 import pickle
 
-import math
 import torch
+
 from typing import Literal, Union
 
-from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
-from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
 from chebai.loss.bce_weighted import BCEWeighted
+from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
+from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
 
 
 class ImplicationLoss(torch.nn.Module):
4 changes: 2 additions & 2 deletions chebai/models/base.py
@@ -1,9 +1,9 @@
-from typing import Optional
 import logging
 import typing
+from typing import Optional
 
-from lightning.pytorch.core.module import LightningModule
 import torch
+from lightning.pytorch.core.module import LightningModule
 
 from chebai.preprocessing.structures import XYData
 
4 changes: 2 additions & 2 deletions chebai/models/chemberta.py
@@ -1,7 +1,8 @@
-from tempfile import TemporaryDirectory
 import logging
 import random
+from tempfile import TemporaryDirectory
 
+import torch
 from torch import nn
 from torch.nn.functional import one_hot
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -11,7 +12,6 @@
     RobertaModel,
     RobertaTokenizer,
 )
-import torch
 
 from chebai.models.base import ChebaiBaseNet
 
4 changes: 2 additions & 2 deletions chebai/models/chemyk.py
@@ -3,11 +3,11 @@
 import pickle
 import sys
 
+import networkx as nx
+import torch
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.functional import pad
-import networkx as nx
-import torch
 
 from chebai.models.base import ChebaiBaseNet
 
4 changes: 2 additions & 2 deletions chebai/models/electra.py
@@ -1,7 +1,8 @@
+import logging
 from math import pi
 from tempfile import TemporaryDirectory
-import logging
 
+import torch
 from torch import nn
 from torch.nn.utils.rnn import pad_sequence
 from transformers import (
@@ -10,7 +11,6 @@
     ElectraForPreTraining,
     ElectraModel,
 )
-import torch
 
 from chebai.loss.pretraining import ElectraPreLoss  # noqa
 from chebai.models.base import ChebaiBaseNet
4 changes: 2 additions & 2 deletions chebai/models/lnn_model.py
@@ -1,8 +1,8 @@
-from lnn import Implies, Model, Not, Predicate, Variable, World
-from owlready2 import get_ontology
 import fastobo
 import pyhornedowl
 import tqdm
+from lnn import Implies, Model, Not, Predicate, Variable, World
+from owlready2 import get_ontology
 
 
 def get_name(iri: str):
2 changes: 1 addition & 1 deletion chebai/models/recursive.py
@@ -1,9 +1,9 @@
 import logging
 
-from torch import exp, nn, tensor
 import networkx as nx
 import torch
 import torch.nn.functional as F
+from torch import exp, nn, tensor
 
 from chebai.models.base import ChebaiBaseNet
 
2 changes: 1 addition & 1 deletion chebai/preprocessing/bin/BPE_SWJ/vocab.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion chebai/preprocessing/collate.py
@@ -1,5 +1,5 @@
-from torch.nn.utils.rnn import pad_sequence
 import torch
+from torch.nn.utils.rnn import pad_sequence
 
 from chebai.preprocessing.structures import XYData
 
6 changes: 3 additions & 3 deletions chebai/preprocessing/collect_all.py
@@ -2,16 +2,16 @@
 import os
 import sys
 
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
 from pytorch_lightning import loggers as pl_loggers
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.metrics import F1
 from sklearn.metrics import f1_score
 from torch import nn
 from torch_geometric import nn as tgnn
 from torch_geometric.data import DataLoader
-import pytorch_lightning as pl
-import torch
-import torch.nn.functional as F
 
 from data import ClassificationData, JCIClassificationData
 
8 changes: 4 additions & 4 deletions chebai/preprocessing/datasets/base.py
@@ -1,14 +1,14 @@
-from typing import List, Union
 import os
 import random
 import typing
+from typing import List, Union
 
-from lightning.pytorch.core.datamodule import LightningDataModule
-from lightning_utilities.core.rank_zero import rank_zero_info
-from torch.utils.data import DataLoader
 import lightning as pl
 import torch
 import tqdm
+from lightning.pytorch.core.datamodule import LightningDataModule
+from lightning_utilities.core.rank_zero import rank_zero_info
+from torch.utils.data import DataLoader
 
 from chebai.preprocessing import reader as dr