In [1]:
import os
# Sanity check before loading anything: show the kernel's working directory and
# whether the bioactivity checkpoint is reachable from it (the relative paths
# used in this notebook are resolved against that directory).
print(
    os.getcwd(),
    os.path.exists("checkpoints/bioactivity/best_reference_chemberta_xai.pth"),
    sep="\n",
)


d:\DA_Final_Tox21
True


In [2]:
# imports + load models
import torch

from models.bioactivity.loader import load_bioactivity
from models.bioactivity.infer import predict_bioactivity

from models.tox21.hf_loader import load_tox_hf
from models.tox21.hf_infer import predict_tox_hf

from pipeline.screening import screen_end_to_end

# Thresholds handed to the predictors as `tau_bio` / `tau_tox`.
# NOTE(review): presumably probability cut-offs for the active / toxic calls —
# confirm against the predictor implementations.
TAU_BIO = 0.5
TAU_TOX = 0.5

# Bioactivity model and its tokenizer, loaded on CPU from the checkpoint
# directory / weights file named below.
bio_model, bio_tok = load_bioactivity(
    device="cpu",
    model_dir="checkpoints/bioactivity",
    weights_name="best_reference_chemberta_xai.pth",
)

# Tox21 model and its tokenizer, loaded on CPU from the artifacts directory.
tox_model, tox_tok = load_tox_hf(
    device="cpu",
    artifacts_dir="artifacts/admet_chemberta_tox21",
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# define wrappers
def bio_fn(xs):
    """Run the bioactivity predictor on ``xs`` with this notebook's model.

    Forwards ``xs`` unchanged to ``predict_bioactivity`` together with the
    notebook-level ``bio_model``, ``bio_tok`` and ``TAU_BIO``, and returns
    whatever that call returns. The globals are looked up at call time, so
    re-running the loading cell changes what this wrapper uses.

    ``xs`` is presumably the SMILES input supplied by ``screen_end_to_end``
    — confirm against the pipeline.
    """
    return predict_bioactivity(
        xs,
        model=bio_model,
        tokenizer=bio_tok,
        tau_bio=TAU_BIO,
    )

def tox_fn(xs):
    """Run the Tox21 predictor on ``xs`` with this notebook's model.

    Forwards ``xs`` unchanged to ``predict_tox_hf`` together with the
    notebook-level ``tox_model``, ``tox_tok`` and ``TAU_TOX``, and returns
    whatever that call returns. The globals are looked up at call time.

    ``xs`` is presumably the SMILES input supplied by ``screen_end_to_end``
    — confirm against the pipeline.
    """
    return predict_tox_hf(xs, model=tox_model, tokenizer=tox_tok, tau_tox=TAU_TOX)

In [4]:
# run test
smiles_list = [
    "CC(=O)Oc1ccccc1C(=O)O",  # aspirin example
    "CCN(CC)CCCC(C)NC1=C2C=CC(=CC2=NC=C1)Cl",  # looks like chloroquine — TODO confirm
]

# Push every molecule through the screening pipeline using the two wrappers.
# `outs` is left as the cell's final expression so the notebook renders it.
outs = screen_end_to_end(
    smiles_list,
    bio_fn=bio_fn,
    tox_fn=tox_fn,
)
outs

[ScreenOut(smiles='CC(=O)Oc1ccccc1C(=O)O', bio=BioOut(p_active=0.06879109144210815, active=False, xai=None), tox=ToxOut(p_toxic=0.21201245486736298, non_toxic=True, xai=None), keep=False, reason='Inactive'),
 ScreenOut(smiles='CCN(CC)CCCC(C)NC1=C2C=CC(=CC2=NC=C1)Cl', bio=BioOut(p_active=0.1400400847196579, active=False, xai=None), tox=ToxOut(p_toxic=0.3383820950984955, non_toxic=True, xai=None), keep=False, reason='Inactive')]

In [None]:
# quick sanity check
# Tokenize a single SMILES string and run one forward pass through the tox
# model to inspect the shape of the logits it returns.
enc = tox_tok(
    ["CC(=O)Oc1ccccc1C(=O)O"],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=128,  # NOTE(review): confirm this matches the length used by predict_tox_hf
)

# Inference only: disable autograd so no computation graph is built or kept
# alive for this check. The logits' shape is unaffected.
with torch.no_grad():
    out = tox_model(**enc)

out.logits.shape

torch.Size([1, 2])